Example #1
0
 def _prepare_tables(self):
     """Prepares two tables, both with three distinct rows.

     The first table has two columns (its first column spans two cells):
     1.0, 2.0 | 3.0
     2.0, 0.0 | 1.0
     1.0, 3.0 | 4.0
     The second table has three columns:
     1.0 | 2.0 | 3.0
     2.0 | 0.0 | 1.0
     1.0 | 3.0 | 4.0

     The two tables are stacked along a leading batch dimension of size 2.

     Returns:
       A tuple ``(values, row_index, col_index)``:
         values: float tensor of shape (2, 3, 3) with the cell values of
           both tables.
         row_index: ``IndexMap`` (3 segments, ``batch_dims=1``) mapping each
           cell to its row.
         col_index: ``IndexMap`` (3 segments, ``batch_dims=1``) mapping each
           cell to its column.
     """
     values = torch.tensor([
         [[1.0, 2.0, 3.0], [2.0, 0.0, 1.0], [1.0, 3.0, 4.0]],
         [[1.0, 2.0, 3.0], [2.0, 0.0, 1.0], [1.0, 3.0, 4.0]],
     ])
     # Indices are passed as tensors, matching how the other tests build their
     # IndexMaps; torch.as_tensor on Python ints yields an int64 tensor.
     row_index = segmented_tensor.IndexMap(
         indices=torch.as_tensor([
             [[0, 0, 0], [1, 1, 1], [2, 2, 2]],
             [[0, 0, 0], [1, 1, 1], [2, 2, 2]],
         ]),
         num_segments=3,
         batch_dims=1)
     # In the first table, cells 0 and 1 of each row share column 0.
     col_index = segmented_tensor.IndexMap(
         indices=torch.as_tensor([
             [[0, 0, 1], [0, 0, 1], [0, 0, 1]],
             [[0, 1, 2], [0, 1, 2], [0, 1, 2]],
         ]),
         num_segments=3,
         batch_dims=1)
     return values, row_index, col_index
Example #2
0
    def test_reduce_max(self):
        # Segment 0 holds {2., 0.}; segment 1 holds {1., 3.}.
        data = torch.as_tensor([2., 1., 0., 3.])
        segment_ids = torch.as_tensor([0, 1, 0, 1])
        index = segmented_tensor.IndexMap(indices=segment_ids, num_segments=2)

        per_segment_max, _unused_index = segmented_tensor.reduce_max(data, index)

        # numpy's assert_array_equal stands in for TensorFlow's assertAllEqual.
        expected = [2, 3]
        np.testing.assert_array_equal(per_segment_max.numpy(), expected)
Example #3
0
    def test_gather_vectorized(self):
        # Two batch elements, each holding two segments of 2-element vectors.
        source = torch.as_tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
        # Batch 0 keeps its segment order; batch 1 swaps its two segments.
        lookup = torch.as_tensor([[0, 1], [1, 0]])
        index = segmented_tensor.IndexMap(
            indices=lookup, num_segments=2, batch_dims=1)

        gathered = segmented_tensor.gather(source, index)

        expected = [
            [[1, 2], [3, 4]],
            [[7, 8], [5, 6]],
        ]
        # numpy's assert_array_equal stands in for TensorFlow's assertAllEqual.
        np.testing.assert_array_equal(gathered.numpy(), expected)
Example #4
0
    def test_reduce_sum_vectorized(self):
        # Rows 0 and 1 fall in segment 0; row 2 is alone in segment 1.
        rows = torch.as_tensor([
            [1.0, 2.0, 3.0],
            [2.0, 3.0, 4.0],
            [3.0, 4.0, 5.0],
        ])
        segment_ids = torch.as_tensor([0, 0, 1])
        index = segmented_tensor.IndexMap(
            indices=segment_ids, num_segments=2, batch_dims=0)

        segment_sums, output_index = segmented_tensor.reduce_sum(rows, index)

        # numpy's testing helpers stand in for TensorFlow's assertAllClose and
        # assertAllEqual.
        expected_sums = [[3.0, 5.0, 7.0], [3.0, 4.0, 5.0]]
        np.testing.assert_allclose(segment_sums.numpy(), expected_sums)
        # The returned index is expected to enumerate the output segments.
        np.testing.assert_array_equal(output_index.indices.numpy(), [0, 1])
        np.testing.assert_array_equal(output_index.num_segments.numpy(), 2)
        np.testing.assert_array_equal(output_index.batch_dims, 0)
Example #5
0
    def test_flatten(self):
        """Checks that flatten() offsets each batch element's segment ids.

        After flattening, all batch elements share one flat segment space.
        """
        _, row_index, col_index = self._prepare_tables()
        row_index_flat = segmented_tensor.flatten(row_index)
        col_index_flat = segmented_tensor.flatten(col_index)

        shape = [3, 4, 5]
        num_cells = int(np.prod(shape))
        # Allocate the index directly as int64. This avoids creating a float
        # tensor and casting it through the legacy `.type(torch.LongTensor)` API.
        batched_index = segmented_tensor.IndexMap(
            indices=torch.zeros(shape, dtype=torch.long),
            num_segments=1,
            batch_dims=3)
        batched_index_flat = segmented_tensor.flatten(batched_index)

        # We use np.testing.assert_array_equal rather than TensorFlow's
        # assertAllEqual. Ids of the second table are shifted by num_segments (3).
        np.testing.assert_array_equal(
            row_index_flat.indices.numpy(),
            [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5])
        np.testing.assert_array_equal(
            col_index_flat.indices.numpy(),
            [0, 0, 1, 0, 0, 1, 0, 0, 1, 3, 4, 5, 3, 4, 5, 3, 4, 5])
        # Every dimension is a batch dimension with a single segment, so each
        # cell ends up as its own segment in the flattened index.
        self.assertEqual(batched_index_flat.num_segments.item(), num_cells)
        np.testing.assert_array_equal(batched_index_flat.indices.numpy(),
                                      np.arange(num_cells))