def test_gather_vectorized(self):
    """Check `gather` on a batched input with an `IndexMap` built with `batch_dims=1`.

    For the index rows `[0, 1]` and `[1, 0]`, the expected output keeps the
    rows of the first batch entry in order and swaps the rows of the second.
    """
    table = tf.constant([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
    batched_index = IndexMap(
        indices=tf.convert_to_tensor([[0, 1], [1, 0]]),
        num_segments=2,
        batch_dims=1,
    )

    gathered = gather(table, batched_index)

    expected = [[[1, 2], [3, 4]], [[7, 8], [5, 6]]]
    # Compared with np.testing.assert_array_equal rather than TensorFlow's
    # assertAllEqual.
    np.testing.assert_array_equal(gathered.numpy(), expected)
def test_gather(self):
    """Check that `reduce_sum` followed by `gather` yields per-cell sums.

    The sums are computed over a `ProductIndexMap` of the row and column
    indices from `self._prepare_tables()` and then gathered back with the
    same index map.
    """
    values, row_index, col_index = self._prepare_tables()
    cell_index = ProductIndexMap(row_index, col_index)

    # Reduce to one sum per cell, then gather those sums back through the
    # same index. The gathered tensor should have the shape of the original
    # table, with each element holding the sum of the values in its cell.
    per_cell_sums, _ = reduce_sum(values, cell_index)
    cell_sum = gather(per_cell_sums, cell_index)
    assert cell_sum.shape == values.shape

    expected_sums = [
        [[3.0, 3.0, 3.0], [2.0, 2.0, 1.0], [4.0, 4.0, 4.0]],
        [[1.0, 2.0, 3.0], [2.0, 0.0, 1.0], [1.0, 3.0, 4.0]],
    ]
    # Compared with np.testing.assert_allclose rather than TensorFlow's
    # assertAllEqual.
    np.testing.assert_allclose(cell_sum.numpy(), expected_sums)