python周りの環境構築
これが良かった。
Python3なTensorFlow環境構築 (Macとpyenv virtualenv) - Qiita
- pythonは2,3のバージョンがある
- TensorFlowはどちらのバージョンでも動くが、githubレポジトリのinput_data.pyはpython3向けに書かれている
Fix dataset encoding in MNIST example for Python 3 · tensorflow/tensorflow@b44abb8 · GitHub
pyenvでバージョン管理
virtualenvでtensorFlow環境を作っても、python3にupdateされずにうまくいかなかった。pyenv-virtualenvで環境を構築して、python3向けのtensorFlowをinstallするとうまくいった。
学習のためのコード
https://www.tensorflow.org/versions/master/tutorials/mnist/beginners/index.html#the-mnist-data
英語なので日本語ソースを探す。
これを見た。
TensorFlow Tutorial MNIST For ML Beginners やった - Qiita
# -*- coding: utf-8 -*- import input_data import tensorflow as tf def main(): mnist = input_data.read_data_sets("MNIST_data/", one_hot=True) # 重みと閾値 W = tf.Variable(tf.zeros([784, 10])) b = tf.Variable(tf.zeros([10])) # 特徴ベクトルを入れる x = tf.placeholder("float", [None, 784]) # softmax y = tf.nn.softmax(tf.matmul(x, W) + b) # 真のラベル y_ = tf.placeholder("float", [None, 10]) # 損失関数 cross_entropy = -tf.reduce_sum(y_ * tf.log(y)) # 学習の仕方を定義 train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy) # セッションを準備 sess = tf.Session() # 変数を初期化 init = tf.initialize_all_variables() sess.run(init) for i in range(1000): batch_xs, batch_ys = mnist.train.next_batch(100) # 勾配を用いた更新 sess.run(train_step, feed_dict={x :batch_xs, y_: batch_ys}) # 正答率を返す関数を定義 correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)) # accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float")) accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) # 結果を眺める print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})) if __name__ == "__main__": main()
Successfully downloaded train-images-idx3-ubyte.gz 9912422 bytes. Extracting MNIST_data/train-images-idx3-ubyte.gz Successfully downloaded train-labels-idx1-ubyte.gz 28881 bytes. Extracting MNIST_data/train-labels-idx1-ubyte.gz Successfully downloaded t10k-images-idx3-ubyte.gz 1648877 bytes. Extracting MNIST_data/t10k-images-idx3-ubyte.gz Successfully downloaded t10k-labels-idx1-ubyte.gz 4542 bytes. Extracting MNIST_data/t10k-labels-idx1-ubyte.gz 0.92
92%ということだろうか。動いて感動した。