Mnist之RNN实现方式二_第1页
Mnist之RNN实现方式二_第2页
Mnist之RNN实现方式二_第3页
全文预览已结束

下载本文档

版权说明:本文档由用户提供并上传,收益归属内容提供方,若内容存在侵权,请进行举报或认领

文档简介

1、Mnist 之 RNN 实现方式二# -*- coding: utf-8 -*-import tensorflow as tf import numpy as np# 导入模型数据 from tensorflow.examples.tutorials.mnist import input_data mnist=input_data.read_data_sets(Mnist_Data,one_hot=T rue)# 输入图片是 28*28 n_inputs = 28 # 输入一行,一行有 28 个数据 max_time = 28 # 一共 28 行 lstm_size = 100 # 隐层单元

2、hiddenn_classes = 10 #10 个分类 batch_size = 10 # 每个批次 100 个样本 #总的训练次数 training_iters=20000# 这里的 none 表示第一个维度可以是 任意的长度x = tf.placeholder(tf.float32, None, 784)y = tf.placeholder(tf.float32, None, 10)# 初始化权值 lstm_size 隐层单元weights = tf.Variable(tf.truncated_normal(lstm_size, n_classes, stddev=0.1) #100,1

3、0# 初始化偏置值biases = tf.Variable(tf.constant(0.1, shape=n_classes)# 定义 RNN 网络 def RNN(X, weights, biases):#input=batch_size, max_time, n_inputs#输入变成10, 28, 28inputs = tf.reshape(X, -1, max_time, n_inputs)#定义 LSTM 基本 CELL# lstm_cell = tf.contrib.rnn.core_rnn_cell.BasicLSTMCell(lstm_size)lstm_cell = tf.n

4、n.rnn_cell.BasicLSTMCell(lstm_size, forget_bias=0.1, state_is_tuple=True)#final_state0 是 cell state#final_state1 是 hidden_state outputs,final_state = tf.nn.dynamic_rnn(lstm_cell, inputs, dtype=tf.float32)print(outputs: + str(outputs) print(final_state: + str(final_state)results = tf.nn.softmax(tf.ma

5、tmul(final_state1, weights)+biases)return results# 计算 RNN 的返回结果 prediction = RNN(x, weights, biases) #损失代价函数 cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(lo gits=prediction, labels=y)#使用 AdamOptimizer 进行优化 train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy) #结

6、果存放在一个布尔列表中correct_prediction = tf.equal(tf.argmax(prediction, 1), tf.argmax(y, 1) #argmax 返回一维张量中最大的值所在的 位置#求准确率accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)# 开始训练with tf.Session() as sess: sess.run(tf.global_variables_initializer()for step inrange(training_iters):batch_xs, batch_ys = mnist.train.next_batch(batch_size)sess.run(train_step, feed_dict=x: batch_xs, y: batch_ys) if step % 10 = 0:acc = sess.run(accuracy, feed_

温馨提示

  • 1. 本站所有资源如无特殊说明,都需要本地电脑安装OFFICE2007和PDF阅读器。图纸软件为CAD,CAXA,PROE,UG,SolidWorks等.压缩文件请下载最新的WinRAR软件解压。
  • 2. 本站的文档不包含任何第三方提供的附件图纸等,如果需要附件,请联系上传者。文件的所有权益归上传用户所有。
  • 3. 本站RAR压缩包中若带图纸,网页内容里面会有图纸预览,若没有图纸预览就没有图纸。
  • 4. 未经权益所有人同意不得将文件中的内容挪作商业或盈利用途。
  • 5. 人人文库网仅提供信息存储空间,仅对用户上传内容的表现方式做保护处理,对用户上传分享的文档内容本身不做任何修改或编辑,并不能对任何下载内容负责。
  • 6. 下载文件中如有侵权或不适当内容,请与我们联系,我们立即纠正。
  • 7. 本站不保证下载资源的准确性、安全性和完整性, 同时也不承担用户因使用这些下载资源对自己和他人造成任何形式的伤害或损失。

评论

0/150

提交评论