1. ホーム
  2. python

[解決済み] PyTorchのmodel.train()は何をするのですか?

2022-05-13 14:12:42

質問

これは forward()nn.Module ? モデルを呼び出すときに思ったのですが forward メソッドが使われると思っていました。 なぜtrain()を指定する必要があるのでしょうか?

どのように解決するのですか?

model.train() は、モデルに学習中であることを伝えます。そのため、dropoutやbatchnormなど、訓練時とテスト時で挙動が異なるレイヤーは、何が起こっているのかを知り、それに応じて挙動を変えることができるのです。

より詳細です。 モードをtrainに設定します。 (参照 ソースコード ). のどちらかを呼び出すことができます。 model.eval() または model.train(mode=False) でテスト中であることがわかります。 直感的に理解できるのは train 関数がモデルを学習させることを期待するのは直感的ですが、これはそうではありません。モードを設定するだけです。