1. ホーム
  2. python

[解決済み] TensorFlowでNumpyのwhere indexを実装する方法とは?

2022-02-12 13:08:41

質問

以下のような操作をしています。 numpy.where :

    mat = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.int32)
    index = np.array([[1,0,0],[0,1,0],[0,0,1]])
    mat[np.where(index>0)] = 100
    print(mat)

TensorFlowで同等のものを実装するには?

mat = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.int32)
index = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
tf_mat = tf.constant(mat)
tf_index = tf.constant(index)
indi = tf.where(tf_index>0)
tf_mat[indi] = -1   <===== not allowed 

解決方法は?

変数の更新ではなく、いくつかの要素を置き換えた新しいテンソルを作りたいのだと仮定すると、以下のようなことができる。

import numpy as np
import tensorflow as tf

mat = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.int32)
index = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
tf_mat = tf.constant(mat)
tf_index = tf.constant(index)
tf_mat = tf.where(tf_index > 0, -tf.ones_like(tf_mat), tf_mat)
with tf.Session() as sess:
    print(sess.run(tf_mat))

出力します。

[[-1  2  3]
 [ 4 -1  6]
 [ 7  8 -1]]