1. ホーム
  2. python

[解決済み] TensorFlowチュートリアルのbatch_xs, batch_ys = mnist.train.next_batch(100) のnext_batchはどこから来ているのでしょうか?

2022-02-17 23:28:34

質問

TensorFlowのチュートリアルを試しているのですが、この行のnext_batchがどこから来ているのか理解できません。

 batch_xs, batch_ys = mnist.train.next_batch(100)

を見てみました。

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

そして、そこにもnext_batchは見当たりませんでした。

自分のコードで next_batch を試してみると、次のようになります。

AttributeError: 'numpy.ndarray' object has no attribute 'next_batch'

そこで、next_batchがどこから来るのかを理解したいと思います。

どのように解決するのですか?

next_batch はメソッド DataSet クラス(参照 https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/learn/python/learn/datasets/mnist.py をご覧ください(授業の内容はこちら)。

mnistデータを読み込んで変数に代入すると mnist を使っています。

mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

のクラスを見る。 mnist.train . 入力することで見ることができます。

print mnist.train.__class__

以下のように表示されます。

<class 'tensorflow.contrib.learn.python.learn.datasets.mnist.Dataset'>

なぜなら mnist.train はクラス DataSet を使用すると、このクラスの関数 next_batch . 授業の詳細は、以下をご覧ください。 ドキュメンテーション .