1. ホーム
  2. python

[解決済み] なぜtorch.catでPyTorchのテンソルを追加できないのですか?

2022-02-14 10:37:41

質問

あります。

import torch

input_sliced = torch.rand(180, 161)
output_sliced = torch.rand(180,)

batched_inputs = torch.Tensor()
batched_outputs = torch.Tensor()

print('input_sliced.size', input_sliced.size())
print('output_sliced.size', output_sliced.size())

batched_inputs = torch.cat((batched_inputs, input_sliced))
batched_outputs = torch.cat((batched_outputs, output_sliced))

print('batched_inputs.size', batched_inputs.size())
print('batched_outputs.size', batched_outputs.size())


これが出力されます。

input_sliced.size torch.Size([180, 161])
output_sliced.size torch.Size([180])

batched_inputs.size torch.Size([180, 161])
batched_outputs.size torch.Size([180])

を必要とします。 batched が追加されますが torch.cat がうまくいきません。何が間違っているのでしょうか?

解決方法を教えてください。

ループでやっているとして、こんな感じでやるのがいいんじゃないでしょうか。

import torch

batch_input, batch_output = [], []
for i in range(10):  # assuming batch_size=10
    batch_input.append(torch.rand(180, 161))
    batch_output.append(torch.rand(180,))

batch_input = torch.stack(batch_input)
batch_output = torch.stack(batch_output)

print(batch_input.shape)   # output: torch.Size([10, 180, 161])
print(batch_output.shape)  # output: torch.Size([10, 180])

もし、結果の batch_* 形状 アプリオリ の場合、最終的な Tensor で、各サンプルをバッチ内の対応する位置に割り当てるだけです。その方がメモリ効率が良いのです。