def _prepare_tables(self):
    """Builds a batch of two tables, each with three distinct rows.

    Both tables carry the same nine cell values; they differ only in how the
    cells of a row are assigned to columns.

    The first table has two columns (the first two cells share column 0):
        1.0, 2.0 | 3.0
        2.0, 0.0 | 1.0
        1.0, 3.0 | 4.0

    The second table has three columns:
        1.0 | 2.0 | 3.0
        2.0 | 0.0 | 1.0
        1.0 | 3.0 | 4.0

    Returns:
        A tuple `(values, row_index, col_index)`: the cell values as a float
        tensor, followed by two `IndexMap` objects (both created with
        `num_segments=3` and `batch_dims=1`) holding the per-cell row ids and
        column ids respectively.
    """
    # Per-table layouts; each is reused for both entries of the batch where
    # the two tables agree.
    cells = [[1.0, 2.0, 3.0], [2.0, 0.0, 1.0], [1.0, 3.0, 4.0]]
    row_ids = [[0, 0, 0], [1, 1, 1], [2, 2, 2]]
    two_column_ids = [[0, 0, 1], [0, 0, 1], [0, 0, 1]]
    three_column_ids = [[0, 1, 2], [0, 1, 2], [0, 1, 2]]

    values = torch.tensor([cells, cells])
    row_index = IndexMap(
        indices=torch.tensor([row_ids, row_ids]),
        num_segments=3,
        batch_dims=1,
    )
    col_index = IndexMap(
        indices=torch.tensor([two_column_ids, three_column_ids]),
        num_segments=3,
        batch_dims=1,
    )
    return values, row_index, col_index
def test_reduce_max(self):
    """Checks `reduce_max` on a flat tensor whose entries alternate between two segments."""
    segment_ids = torch.as_tensor([0, 1, 0, 1])
    values = torch.as_tensor([2.0, 1.0, 0.0, 3.0])
    index = IndexMap(indices=segment_ids, num_segments=2)

    maximum, _ = reduce_max(values, index)

    # Segment 0 holds {2.0, 0.0} and segment 1 holds {1.0, 3.0}.
    expected_maximum = [2, 3]
    # NumPy's array comparison stands in for TensorFlow's assertAllEqual.
    np.testing.assert_array_equal(maximum.numpy(), expected_maximum)
def test_gather_vectorized(self):
    """Checks `gather` on a batch of two examples whose entries are 2-element vectors."""
    values = torch.as_tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
    segment_ids = torch.as_tensor([[0, 1], [1, 0]])
    index = IndexMap(indices=segment_ids, num_segments=2, batch_dims=1)

    result = gather(values, index)

    # The first example keeps its order; the second is looked up as [1, 0],
    # so its two vectors come back swapped.
    expected = [[[1, 2], [3, 4]], [[7, 8], [5, 6]]]
    # NumPy's array comparison stands in for TensorFlow's assertAllEqual.
    np.testing.assert_array_equal(result.numpy(), expected)
def test_reduce_sum_vectorized(self):
    """Checks `reduce_sum` when every indexed entry is a 3-element vector."""
    values = torch.as_tensor([[1.0, 2.0, 3.0], [2.0, 3.0, 4.0], [3.0, 4.0, 5.0]])
    segment_ids = torch.as_tensor([0, 0, 1])
    index = IndexMap(indices=segment_ids, num_segments=2, batch_dims=0)

    sums, new_index = reduce_sum(values, index)

    # The first two vectors fall in segment 0 and are added element-wise;
    # the third is alone in segment 1.
    expected_sums = [[3.0, 5.0, 7.0], [3.0, 4.0, 5.0]]
    # NumPy's tolerance-based comparison stands in for TensorFlow's assertAllClose.
    np.testing.assert_allclose(sums.numpy(), expected_sums)

    # NumPy's exact comparison stands in for TensorFlow's assertAllEqual.
    np.testing.assert_array_equal(new_index.indices.numpy(), [0, 1])
    np.testing.assert_array_equal(new_index.num_segments.numpy(), 2)
    np.testing.assert_array_equal(new_index.batch_dims, 0)
def test_flatten(self):
    """Checks `flatten` on the prepared table indices and on a fully batched all-zero index."""
    _, table_rows, table_cols = self._prepare_tables()
    row_index_flat = flatten(table_rows)
    col_index_flat = flatten(table_cols)

    # An index in which every dimension is a batch dimension.
    shape = [3, 4, 5]
    zero_ids = torch.zeros(shape).type(torch.LongTensor)
    batched_index = IndexMap(indices=zero_ids, num_segments=1, batch_dims=3)
    batched_index_flat = flatten(batched_index)
    element_count = np.prod(shape)

    # Ids of the second table are shifted by 3 relative to the first one.
    expected_row_ids = [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5]
    expected_col_ids = [0, 0, 1, 0, 0, 1, 0, 0, 1, 3, 4, 5, 3, 4, 5, 3, 4, 5]

    # NumPy's array comparison stands in for TensorFlow's assertAllEqual.
    np.testing.assert_array_equal(row_index_flat.indices.numpy(), expected_row_ids)
    np.testing.assert_array_equal(col_index_flat.indices.numpy(), expected_col_ids)

    # Each of the 3 * 4 * 5 batch elements is expected to get its own segment id.
    self.assertEqual(batched_index_flat.num_segments.numpy(), element_count)
    np.testing.assert_array_equal(batched_index_flat.indices.numpy(), range(element_count))