def test_batch_gather_invalid_index(graph):
    """Evaluating batch_gather with an out-of-range index must fail.

    The params tensor is 3x4, so the index ``4`` in the first row of the
    indices points one past the last column.  The test expects
    ``tf.errors.InvalidArgumentError`` when the result is evaluated.
    """
    with graph.as_default():
        table = tf.convert_to_tensor(np.arange(12).reshape(3, 4))
        bad_picks = tf.convert_to_tensor([[0, 4], [3, 1], [2, 2]])
        # Building the op sits outside the `raises` context on purpose: only
        # running it in a session is expected to raise.
        result = tfu.batch_gather(table, bad_picks)

        with pytest.raises(tf.errors.InvalidArgumentError), \
                tf.Session() as sess:
            sess.run(result)
def test_batch_gather_list_inputs(graph):
    """batch_gather accepts plain Python lists for params and indices.

    With params ``[0, 1, 2, 3, 4]`` and indices ``[3, 1, 2]`` the result is
    expected to have static shape ``(3,)`` and evaluate to ``[3, 1, 2]``.
    """
    with graph.as_default():
        values = list(range(5))
        picks = [3, 1, 2]
        result = tfu.batch_gather(values, picks)

        # The static shape must be known before anything is evaluated.
        assert result.shape == (3, )

        with tf.Session() as sess:
            result_value = sess.run(result)

        expected = np.array([3, 1, 2])
        assert np.array_equal(result_value, expected)
def test_batch_gather_1d(graph):
    """batch_gather on rank-1 tensors picks the elements at ``indices``.

    ``tf.range(5)`` gathered at ``[3, 1, 2]`` is expected to have static
    shape ``(3,)`` and evaluate to ``[3, 1, 2]``.
    """
    with graph.as_default():
        values = tf.range(5)
        picks = tf.convert_to_tensor([3, 1, 2])
        result = tfu.batch_gather(values, picks)

        # The static shape must be known before anything is evaluated.
        assert result.shape == (3, )

        with tf.Session() as sess:
            result_value = sess.run(result)

        expected = np.array([3, 1, 2])
        assert np.array_equal(result_value, expected)
def test_batch_gather_2d(graph):
    """batch_gather on a 3x4 matrix selects per-row columns.

    Each row of ``indices`` lists the columns to take from the matching row
    of ``params``; the result is expected to have static shape ``(3, 2)``.
    """
    with graph.as_default():
        table = tf.convert_to_tensor(np.arange(12).reshape(3, 4))
        picks = tf.convert_to_tensor([[0, 1], [3, 1], [2, 2]])
        result = tfu.batch_gather(table, picks)

        assert result.shape == (3, 2)

        with tf.Session() as sess:
            result_value = sess.run(result)

        # Layout of params; (x) marks the element chosen by the first index
        # of a row, [x] the element chosen by the second index:
        #
        #   (0)  [1]    2     3
        #    4   [5]    6    (7)
        #    8    9  [(10)]  11
        expected = np.array([[0, 1], [7, 5], [10, 10]])
        assert np.array_equal(result_value, expected)
def test_batch_gather_2d_slice(graph):
    """batch_gather on a 3x4x2 tensor gathers whole trailing slices.

    Indices address the second axis per row, so every gathered entry is a
    length-2 slice and the result is expected to have static shape
    ``(3, 2, 2)``.
    """
    with graph.as_default():
        table = tf.convert_to_tensor(np.arange(24).reshape(3, 4, 2))
        picks = tf.convert_to_tensor([[0, 1], [3, 1], [2, 2]])
        result = tfu.batch_gather(table, picks)

        assert result.shape == (3, 2, 2)

        with tf.Session() as sess:
            result_value = sess.run(result)

        # First element of each length-2 slice in params; (x) marks the
        # slice chosen by the first index of a row, [x] by the second:
        #
        #   (0)  [2]    4     6
        #    8  [10]   12   (14)
        #   16   18  [(20)]  22
        expected = np.array([
            [[0, 1], [2, 3]],
            [[14, 15], [10, 11]],
            [[20, 21], [20, 21]],
        ])
        assert np.array_equal(result_value, expected)
def test_batch_gather_none_axis(graph):
    """batch_gather handles placeholders whose second axis is unknown.

    Both inputs are int32 placeholders of shape ``(3, None)``.  The result is
    expected to keep the static shape ``[3, None]`` and, once 3x4 params and
    3x2 indices are fed, to evaluate to the same values as the 2-D case.
    """
    with graph.as_default():
        params_ph = tf.placeholder(dtype=tf.int32, shape=(3, None))
        indices_ph = tf.placeholder(dtype=tf.int32, shape=(3, None))
        result = tfu.batch_gather(params_ph, indices_ph)

        # The unknown axis stays unknown in the static shape.
        assert result.shape.as_list() == [3, None]

        with tf.Session() as sess:
            feeds = {
                params_ph: np.arange(12).reshape(3, 4),
                indices_ph: np.array([[0, 1], [3, 1], [2, 2]]),
            }
            result_value = sess.run(result, feed_dict=feeds)

        expected = np.array([[0, 1], [7, 5], [10, 10]])
        assert np.array_equal(result_value, expected)