
3. TensorFlow single-layer neural network for MNIST digit recognition: training, loading the model, predicting images

Time: 2021-02-07 09:48:51


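#coding:utf-8
"""
MNIST classification with a single-layer neural network:
save the model, load the model, predict images
"""
import os

import pylab
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("/Users/ming/Downloads/zhangming/pytorch_demo/data/mnist", one_hot=True)

tf.reset_default_graph()
# Inputs: flattened 28x28 images and one-hot labels
x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10])
# Single layer: 784 -> 10 weights plus a bias
w = tf.Variable(initial_value=tf.random_normal([784, 10]), name="weight")
b = tf.Variable(initial_value=tf.zeros([10]), name="bias")
model = tf.matmul(x, w) + b
predict = tf.nn.softmax(model)
# Mean cross-entropy; tf.log can give inf/NaN if a probability underflows to 0
cost = tf.reduce_mean(-tf.reduce_sum(y * tf.log(predict), reduction_indices=1))
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01).minimize(cost)
# Accuracy ops are built once here and fed the true labels through y
correct_pred = tf.equal(tf.argmax(predict, 1), tf.argmax(y, 1))
acc = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

saver = tf.train.Saver(max_to_keep=2)  # keep only the two newest checkpoints
batch_size = 32
epochs = 20
display_step = 2
save_dir = "mnist_model"
save_model = os.path.join(save_dir, "mnist.ckpt")
if not os.path.exists(save_dir):
    os.makedirs(save_dir)  # Saver.save fails if the parent directory is missing

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    total_batch = int(mnist.train.num_examples / batch_size)
    for epoch in range(epochs):
        epoch_loss = 0
        for i in range(total_batch):
            batch_x, batch_y = mnist.train.next_batch(batch_size)
            _, loss = sess.run([optimizer, cost], feed_dict={x: batch_x, y: batch_y})
            epoch_loss += loss
        epoch_loss = epoch_loss / total_batch
        # global_step appends "-<epoch>" to the file name, e.g. mnist.ckpt-19
        saver.save(sess, save_model, global_step=epoch)
        if epoch % display_step == 0:
            print("epoch %d , loss %.2f" % (epoch, epoch_loss))
    print("done...")

# Prediction: load the saved model
with tf.Session() as sess2:
    # Restore the newest checkpoint; the bare save_model prefix is never
    # written because of global_step, so it cannot be restored directly
    saver.restore(sess2, tf.train.latest_checkpoint(save_dir))
    # New names, so the placeholder y and the tensor predict are not overwritten
    x_test, y_test = mnist.test.next_batch(2)
    pred_val, accuracy = sess2.run([predict, acc], feed_dict={x: x_test, y: y_test})
    print("predict:", pred_val)
    print("acc: %.4f" % accuracy)
    # Show both test images as 28x28 pictures
    img1 = x_test[0].reshape([-1, 28])
    pylab.imshow(img1)
    pylab.show()
    img2 = x_test[1].reshape([-1, 28])
    pylab.imshow(img2)
    pylab.show()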
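The script above is softmax regression: one matrix multiply plus a bias, a softmax, a mean cross-entropy loss and plain gradient descent. The pure-Python sketch below spells out what a single `GradientDescentOptimizer` step computes, assuming the standard analytic gradient of mean softmax cross-entropy with respect to the logits, (p - y)/N. It is an illustration, not the author's code: the toy 2-feature, 2-class data, the zero initialisation (instead of `tf.random_normal`), the learning rate and helper names such as `sgd_step` are all hypothetical, and the softmax subtracts the largest logit before exponentiating, a common stability trick.

```python
import math


def softmax(logits):
    # Subtract the largest logit before exponentiating to avoid overflow
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def forward(x, w, b):
    # One row of tf.matmul(x, w) + b followed by tf.nn.softmax
    logits = [sum(x[i] * w[i][j] for i in range(len(x))) + b[j]
              for j in range(len(b))]
    return softmax(logits)


def cross_entropy(p, y):
    # -sum(y * log(p)) for one example; terms with y_j == 0 contribute nothing
    return -sum(yj * math.log(pj) for yj, pj in zip(y, p) if yj > 0)


def sgd_step(batch_x, batch_y, w, b, lr):
    """Hypothetical helper: one gradient-descent step on the mean cross-entropy.

    Updates w and b in place and returns the batch loss measured before the update.
    """
    n = len(batch_x)
    grad_w = [[0.0] * len(b) for _ in w]
    grad_b = [0.0] * len(b)
    loss = 0.0
    for x, y in zip(batch_x, batch_y):
        p = forward(x, w, b)
        loss += cross_entropy(p, y)
        for j in range(len(b)):
            # d(mean loss)/d(logit_j) = (p_j - y_j) / n
            d = (p[j] - y[j]) / n
            grad_b[j] += d
            for i in range(len(x)):
                grad_w[i][j] += x[i] * d
    for i in range(len(w)):
        for j in range(len(b)):
            w[i][j] -= lr * grad_w[i][j]
    for j in range(len(b)):
        b[j] -= lr * grad_b[j]
    return loss / n


# Hypothetical toy data: two features, two one-hot classes
X = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]]
Y = [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]]
w = [[0.0, 0.0], [0.0, 0.0]]
b = [0.0, 0.0]

# With zero weights every class gets probability 0.5, so first_loss is ln 2
first_loss = sgd_step(X, Y, w, b, lr=0.5)
for _ in range(200):
    last_loss = sgd_step(X, Y, w, b, lr=0.5)
print(round(first_loss, 4), round(last_loss, 4))
```

Vectorised over a batch of 32 images, 784 pixels and 10 classes, this is mathematically the same update the script's `optimizer` op applies with `learning_rate=0.01`, which is why the per-epoch loss in the log keeps falling.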

Output:

/usr/local/bin/python2.7 /Users/ming/Downloads/zhangming/tf_demo/3.tf_mnist_1_layer.py
WARNING:tensorflow:From /Users/ming/Downloads/zhangming/tf_demo/3.tf_mnist_1_layer.py:7: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
WARNING:tensorflow:From /usr/local/lib/python2.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Please write your own downloading logic.
WARNING:tensorflow:From /usr/local/lib/python2.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting /Users/ming/Downloads/zhangming/pytorch_demo/data/mnist/train-images-idx3-ubyte.gz
WARNING:tensorflow:From /usr/local/lib/python2.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting /Users/ming/Downloads/zhangming/pytorch_demo/data/mnist/train-labels-idx1-ubyte.gz
WARNING:tensorflow:From /usr/local/lib/python2.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:110: dense_to_one_hot (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.one_hot on tensors.
Extracting /Users/ming/Downloads/zhangming/pytorch_demo/data/mnist/t10k-images-idx3-ubyte.gz
Extracting /Users/ming/Downloads/zhangming/pytorch_demo/data/mnist/t10k-labels-idx1-ubyte.gz
WARNING:tensorflow:From /usr/local/lib/python2.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:290: __init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
-11-17 23:44:02.858203: I tensorflow/core/platform/:140] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
epoch 0 , loss 5.37
epoch 2 , loss 1.45
epoch 4 , loss 1.07
epoch 6 , loss 0.91
epoch 8 , loss 0.82
epoch 10 , loss 0.76
epoch 12 , loss 0.71
epoch 14 , loss 0.67
epoch 16 , loss 0.64
epoch 18 , loss 0.62
done...
('predict:', array([[6.7180810e-12, 6.8759458e-14, 8.5211534e-12, 2.8013984e-09,
        9.9993575e-01, 1.3801176e-08, 9.9942827e-09, 8.1797918e-10,
        2.9035888e-05, 3.5194662e-05],
       [3.0729464e-06, 4.6555795e-09, 1.1277327e-07, 1.6401351e-06,
        1.5443068e-06, 1.3794046e-07, 9.7556288e-08, 6.1844635e-01,
        1.1378842e-04, 3.8143331e-01]], dtype=float32))
acc: 1.0000

Process finished with exit code 0
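Each row of the `predict:` array is a softmax distribution over the ten digits, and the predicted digit is the index of its largest entry, which is what `tf.argmax(predict, 1)` computes. The short sketch below decodes the two printed rows; the numbers are copied from the output above and the helper name `top2` is hypothetical. Row 0 puts about 0.9999 on digit 4, while row 1 splits roughly 0.62 / 0.38 between 7 and 9, so the second image was classified far less confidently. Because `acc: 1.0000` was reported, the two true labels must have been 4 and 7. With a batch of only two test images the accuracy can only be 0.0, 0.5 or 1.0, so it says little about performance on the full test set.

```python
# Softmax rows copied from the "predict:" output above (values as printed)
probs = [
    [6.7180810e-12, 6.8759458e-14, 8.5211534e-12, 2.8013984e-09,
     9.9993575e-01, 1.3801176e-08, 9.9942827e-09, 8.1797918e-10,
     2.9035888e-05, 3.5194662e-05],
    [3.0729464e-06, 4.6555795e-09, 1.1277327e-07, 1.6401351e-06,
     1.5443068e-06, 1.3794046e-07, 9.7556288e-08, 6.1844635e-01,
     1.1378842e-04, 3.8143331e-01],
]


def top2(row):
    """Hypothetical helper: (best digit, its prob, runner-up digit, its prob)."""
    order = sorted(range(len(row)), key=lambda i: row[i], reverse=True)
    return order[0], row[order[0]], order[1], row[order[1]]


for k, row in enumerate(probs):
    best, p_best, second, p_second = top2(row)
    # A softmax row should sum to 1 up to float32 rounding
    print("image %d: digit %d (p=%.4f), runner-up %d (p=%.4f), row sum %.4f"
          % (k, best, p_best, second, p_second, sum(row)))
```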
