# coding: utf-8
"""MNIST classification with a single-layer (softmax regression) network.

Trains the model, saves a checkpoint after every epoch, restores the newest
checkpoint in a fresh session, prints predictions for two test images plus
the accuracy on the full test set, and displays the two images with pylab.
"""
import os

import pylab
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("/Users/ming/Downloads/zhangming/pytorch_demo/data/mnist", one_hot=True)

tf.reset_default_graph()

# Inputs: flattened 784-pixel images and 10-way one-hot labels.
x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10])

w = tf.Variable(initial_value=tf.random_normal([784, 10]), name="weight")
b = tf.Variable(initial_value=tf.zeros([10]), name="bias")

model = tf.matmul(x, w) + b  # logits
predict = tf.nn.softmax(model)
# Cross-entropy from log_softmax of the logits instead of tf.log(softmax(...)):
# if a softmax probability underflows to 0, tf.log gives -inf and the
# y * log(p) term becomes 0 * -inf = NaN, poisoning the whole loss.
cost = tf.reduce_mean(-tf.reduce_sum(y * tf.nn.log_softmax(model), axis=1))
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01).minimize(cost)

# Evaluation ops, built once against the label placeholder (not a numpy batch).
correct_pred = tf.equal(tf.argmax(predict, 1), tf.argmax(y, 1))
acc = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

saver = tf.train.Saver(max_to_keep=2)

batch_size = 32
epochs = 20
display_step = 2
save_model = "mnist_model/mnist.cpkt"
save_dir = os.path.dirname(save_model)

# Saver.save refuses to write when the checkpoint's parent directory is missing.
if save_dir and not os.path.isdir(save_dir):
    os.makedirs(save_dir)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    total_batch = int(mnist.train.num_examples / batch_size)
    for epoch in range(epochs):
        epoch_loss = 0.0
        for _ in range(total_batch):
            batch_x, batch_y = mnist.train.next_batch(batch_size)
            _, loss = sess.run([optimizer, cost], feed_dict={x: batch_x, y: batch_y})
            epoch_loss += loss
        epoch_loss = epoch_loss / total_batch
        # With global_step the checkpoint is written as "<save_model>-<epoch>".
        saver.save(sess, save_model, global_step=epoch)
        if epoch % display_step == 0:
            print("epoch %d , loss %.2f" % (epoch, epoch_loss))
    print("done...")

# Predict with the restored model.
with tf.Session() as sess2:
    # No file is named exactly save_model (every save appended "-<epoch>"),
    # so look up the newest checkpoint recorded in save_dir instead.
    latest_ckpt = tf.train.latest_checkpoint(save_dir or ".")
    if latest_ckpt is None:
        raise ValueError("No checkpoint found in %r" % (save_dir or "."))
    # restore() assigns every saved variable, so no initializer run is needed.
    saver.restore(sess2, latest_ckpt)

    x_test, y_test = mnist.test.next_batch(2)
    probs = sess2.run(predict, feed_dict={x: x_test})
    # Single-argument prints render the same under Python 2 and Python 3.
    print("predict: %s" % probs)
    print("predicted digits: %s, true digits: %s"
          % (probs.argmax(axis=1), y_test.argmax(axis=1)))

    # Accuracy over the whole test set; two samples say little about the model.
    test_acc = sess2.run(acc, feed_dict={x: mnist.test.images, y: mnist.test.labels})
    print("test acc: %.4f" % test_acc)

    for img in x_test:
        pylab.imshow(img.reshape([-1, 28]))
        pylab.show()
输出:
/usr/local/bin/python2.7 /Users/ming/Downloads/zhangming/tf_demo/3.tf_mnist_1_layer.py
WARNING:tensorflow:From /Users/ming/Downloads/zhangming/tf_demo/3.tf_mnist_1_layer.py:7: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
WARNING:tensorflow:From /usr/local/lib/python2.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Please write your own downloading logic.
WARNING:tensorflow:From /usr/local/lib/python2.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting /Users/ming/Downloads/zhangming/pytorch_demo/data/mnist/train-images-idx3-ubyte.gz
WARNING:tensorflow:From /usr/local/lib/python2.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting /Users/ming/Downloads/zhangming/pytorch_demo/data/mnist/train-labels-idx1-ubyte.gz
WARNING:tensorflow:From /usr/local/lib/python2.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:110: dense_to_one_hot (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.one_hot on tensors.
Extracting /Users/ming/Downloads/zhangming/pytorch_demo/data/mnist/t10k-images-idx3-ubyte.gz
Extracting /Users/ming/Downloads/zhangming/pytorch_demo/data/mnist/t10k-labels-idx1-ubyte.gz
WARNING:tensorflow:From /usr/local/lib/python2.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:290: __init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
-11-17 23:44:02.858203: I tensorflow/core/platform/cpu_feature_guard.cc:140] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
epoch 0 , loss 5.37
epoch 2 , loss 1.45
epoch 4 , loss 1.07
epoch 6 , loss 0.91
epoch 8 , loss 0.82
epoch 10 , loss 0.76
epoch 12 , loss 0.71
epoch 14 , loss 0.67
epoch 16 , loss 0.64
epoch 18 , loss 0.62
done...
('predict:', array([[6.7180810e-12, 6.8759458e-14, 8.5211534e-12, 2.8013984e-09,
9.9993575e-01, 1.3801176e-08, 9.9942827e-09, 8.1797918e-10,
2.9035888e-05, 3.5194662e-05],
[3.0729464e-06, 4.6555795e-09, 1.1277327e-07, 1.6401351e-06,
1.5443068e-06, 1.3794046e-07, 9.7556288e-08, 6.1844635e-01,
1.1378842e-04, 3.8143331e-01]], dtype=float32))
acc: 1.0000
Process finished with exit code 0