示例#1
0
def test_batch_gather_invalid_index(graph):
    """An out-of-range index must raise InvalidArgumentError when evaluated.

    Row 0 of ``indices`` requests column 4 of a 3x4 ``params`` matrix, one
    past the last valid column. Building the gather op is expected to
    succeed; the error is expected only when the op is run in a session.
    """
    with graph.as_default():
        params = tf.convert_to_tensor(np.arange(12).reshape(3, 4))
        # Column 4 in row 0 is out of bounds (valid columns are 0..3).
        indices = tf.convert_to_tensor([[0, 4], [3, 1], [2, 2]])
        gathered = tfu.batch_gather(params, indices)

        with tf.Session() as session:
            # Scope the expectation to the one call that should fail, so an
            # unrelated InvalidArgumentError raised while opening the session
            # cannot make this test pass spuriously.
            # NOTE(review): if tfu.batch_gather is built on tf.gather, GPU
            # kernels may zero-fill out-of-range slots instead of raising —
            # confirm this test is pinned to CPU.
            with pytest.raises(tf.errors.InvalidArgumentError):
                session.run(gathered)
示例#2
0
def test_batch_gather_list_inputs(graph):
    """Plain Python lists are accepted for both ``params`` and ``indices``.

    The static output shape must be ``(3,)`` and the evaluated result must
    equal ``params[indices]`` element-wise.
    """
    with graph.as_default():
        source = [0, 1, 2, 3, 4]
        picks = [3, 1, 2]
        result = tfu.batch_gather(source, picks)
        assert result.shape == (3,)

        with tf.Session() as sess:
            result_value = sess.run(result)
        expected = np.array([3, 1, 2])
        assert np.array_equal(result_value, expected)
示例#3
0
def test_batch_gather_1d(graph):
    """Rank-1 tensor inputs: the output equals ``params[indices]``.

    ``params`` is ``[0, 1, 2, 3, 4]``, so each gathered value equals the
    index that selected it.
    """
    with graph.as_default():
        values = tf.range(5)
        positions = tf.convert_to_tensor([3, 1, 2])
        out = tfu.batch_gather(values, positions)
        assert out.shape == (3,)

        with tf.Session() as sess:
            out_value = sess.run(out)
        expected = np.array([3, 1, 2])
        assert np.array_equal(out_value, expected)
示例#4
0
def test_batch_gather_2d(graph):
    """Row ``i`` of ``indices`` selects columns from row ``i`` of ``params``.

    A 3x4 ``params`` gathered with a 3x2 ``indices`` yields a 3x2 result.
    """
    with graph.as_default():
        matrix = tf.convert_to_tensor(np.arange(12).reshape(3, 4))
        column_ids = tf.convert_to_tensor([[0, 1], [3, 1], [2, 2]])
        picked = tfu.batch_gather(matrix, column_ids)
        assert picked.shape == (3, 2)

        with tf.Session() as sess:
            picked_value = sess.run(picked)

        # params, with the (first) and [second] selected column per row:
        #   (0) [1]  2    3
        #    4  [5]  6   (7)
        #    8   9 [(10)] 11
        expected = np.array([[0, 1], [7, 5], [10, 10]])
        assert np.array_equal(picked_value, expected)
示例#5
0
def test_batch_gather_2d_slice(graph):
    """Trailing dimensions of ``params`` are carried through the gather.

    ``params`` is 3x4x2 and ``indices`` is 3x2, so every gathered entry is a
    length-2 slice and the result has shape ``(3, 2, 2)``.
    """
    with graph.as_default():
        cube = tf.convert_to_tensor(np.arange(24).reshape(3, 4, 2))
        slot_ids = tf.convert_to_tensor([[0, 1], [3, 1], [2, 2]])
        picked = tfu.batch_gather(cube, slot_ids)
        assert picked.shape == (3, 2, 2)

        with tf.Session() as sess:
            picked_value = sess.run(picked)

        # First element of each length-2 slice, with the (first) and
        # [second] selected slot per row marked:
        #   (0)  [2]    4     6
        #    8   [10]   12   (14)
        #    16   18  [(20)]  22
        expected = np.array(
            [
                [[0, 1], [2, 3]],
                [[14, 15], [10, 11]],
                [[20, 21], [20, 21]],
            ]
        )
        assert np.array_equal(picked_value, expected)
示例#6
0
def test_batch_gather_none_axis(graph):
    """An unknown (``None``) gather axis is preserved in the static shape.

    Both inputs are int32 placeholders of shape ``(3, None)``. The static
    output shape must be ``[3, None]``. Feeding concrete arrays must then
    produce the same values as the fully static 2-D case.
    """
    with graph.as_default():
        params_ph = tf.placeholder(dtype=tf.int32, shape=(3, None))
        indices_ph = tf.placeholder(dtype=tf.int32, shape=(3, None))
        out = tfu.batch_gather(params_ph, indices_ph)
        assert out.shape.as_list() == [3, None]

        feeds = {
            params_ph: np.arange(12).reshape(3, 4),
            indices_ph: np.array([[0, 1], [3, 1], [2, 2]]),
        }
        with tf.Session() as sess:
            out_value = sess.run(out, feed_dict=feeds)

        expected = np.array([[0, 1], [7, 5], [10, 10]])
        assert np.array_equal(out_value, expected)