1. ホーム
  2. python

[解決済み] torch.nn.Parameterを理解する

2022-02-01 17:36:13

質問事項

私はpytorchの初心者ですが、どのように理解するのが難しいですか? torch.nn.Parameter() が動作します。

のドキュメントに目を通しました。 https://pytorch.org/docs/stable/nn.html が、ほとんど意味をなしていない。

誰か助けてください。

私が作業しているコードスニペット。

def __init__(self, weight):
    super(Net, self).__init__()
    # initializes the weights of the convolutional layer to be the weights of the 4 defined filters
    k_height, k_width = weight.shape[2:]
    # assumes there are 4 grayscale filters
    self.conv = nn.Conv2d(1, 4, kernel_size=(k_height, k_width), bias=False)
    self.conv.weight = torch.nn.Parameter(weight)

解決方法は?

分解して説明します。テンソルとは、ご存知のように多次元の行列のことです。パラメータはそのままの形ではテンソル、すなわち多次元行列である。これは Variable クラスのサブクラスである。

VariableとParameterの違いは、モジュールと関連付けるときに出てきます。Parameterがモデル属性としてモジュールに関連付けられると、自動的にパラメータリストに追加され、'parameters' iteratorを使用してアクセスすることができます。

当初、Torchでは変数(例えば、中間状態)も割り当て時にモデルのパラメータとして追加されました。その後、変数をパラメータリストに追加するのではなく、キャッシュする必要性があることがユースケースとして認識されるようになりました。

RNNでは、最後の隠された状態を保存して、何度も渡す必要がありません。変数が自動的にモデルのパラメータとして登録されるのではなく、キャッシュする必要があるため、モデルにパラメータを登録する明示的な方法、つまり nn.Parameter クラスが用意されています。

例えば、以下のコードを実行してみてください。

import torch
import torch.nn as nn
from torch.optim import Adam

class NN_Network(nn.Module):
    def __init__(self,in_dim,hid,out_dim):
        super(NN_Network, self).__init__()
        self.linear1 = nn.Linear(in_dim,hid)
        self.linear2 = nn.Linear(hid,out_dim)
        self.linear1.weight = torch.nn.Parameter(torch.zeros(in_dim,hid))
        self.linear1.bias = torch.nn.Parameter(torch.ones(hid))
        self.linear2.weight = torch.nn.Parameter(torch.zeros(in_dim,hid))
        self.linear2.bias = torch.nn.Parameter(torch.ones(hid))

    def forward(self, input_array):
        h = self.linear1(input_array)
        y_pred = self.linear2(h)
        return y_pred

in_d = 5
hidn = 2
out_d = 3
net = NN_Network(in_d, hidn, out_d)

さて、このモデルに関連するパラメータリストを確認します -。

for param in net.parameters():
    print(type(param.data), param.size())

""" Output
<class 'torch.FloatTensor'> torch.Size([5, 2])
<class 'torch.FloatTensor'> torch.Size([2])
<class 'torch.FloatTensor'> torch.Size([5, 2])
<class 'torch.FloatTensor'> torch.Size([2])
"""

または試してみてください。

list(net.parameters())

これは簡単にオプティマイザーに供給することができます。

opt = Adam(net.parameters(), learning_rate=0.001)

また、パラメータはデフォルトでrequire_gradが設定されていることに注意してください。