Ejemplo n.º 1
0
    def testConstruction(self):
        """Builds RaggedTensors over WrappedTensor values with each factory.

        For every RaggedTensor factory exercised here, checks that the result's
        `values` is still a `WrappedTensor` whose `.value` equals the original
        string tensor, and that the matching partition accessor reports the
        partition that was passed in.
        """
        raw_values = constant_op.constant(
            ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h'])
        wrapped = WrappedTensor(raw_values)

        def check(ragged, get_partition, expected_partition):
            # The wrapped values must keep their type and content...
            self.assertIsInstance(ragged.values, WrappedTensor)
            self.assertAllEqual(ragged.values.value, raw_values)
            # ...and the partition must round-trip through its accessor.
            # It is read lazily so it is evaluated after the value checks.
            self.assertAllEqual(get_partition(ragged), expected_partition)

        splits = constant_op.constant([0, 2, 2, 5, 6, 8], dtypes.int64)
        check(RaggedTensor.from_row_splits(wrapped, splits),
              lambda r: r.row_splits, splits)

        starts = constant_op.constant([0, 2, 2, 5, 6], dtypes.int64)
        check(RaggedTensor.from_row_starts(wrapped, starts),
              lambda r: r.row_starts(), starts)

        limits = constant_op.constant([2, 2, 5, 6, 8], dtypes.int64)
        check(RaggedTensor.from_row_limits(wrapped, limits),
              lambda r: r.row_limits(), limits)

        lengths = constant_op.constant([2, 0, 3, 1, 2], dtypes.int64)
        check(RaggedTensor.from_row_lengths(wrapped, lengths),
              lambda r: r.row_lengths(), lengths)

        check(RaggedTensor.from_uniform_row_length(wrapped, 4),
              lambda r: r.uniform_row_length, 4)
Ejemplo n.º 2
0
 def testErrorsWithUniformRowLength(self, slice_spec, expected, message):
   """Test error handling of rt.__getitem__(slice_spec) on a uniform-row rt.

   Wraps a ragged tensor (built from EXAMPLE_RAGGED_TENSOR_3D_VALUES and
   EXAMPLE_RAGGED_TENSOR_3D_SPLITS) in an outer partition whose uniform row
   length is EXAMPLE_RAGGED_TENSOR_3D_ROWLEN, checks that it equals
   EXAMPLE_RAGGED_TENSOR_3D, and then delegates the indexing check to
   `_TestGetItemException`.

   Args:
     slice_spec: The index/slice applied to `rt`.
     expected: Forwarded to `_TestGetItemException`; presumably the exception
       type indexing should raise (TODO confirm against the helper).
     message: Forwarded to `_TestGetItemException`; presumably the expected
       error-message pattern (TODO confirm against the helper).
   """
   rt = RaggedTensor.from_uniform_row_length(
       RaggedTensor.from_row_splits(EXAMPLE_RAGGED_TENSOR_3D_VALUES,
                                    EXAMPLE_RAGGED_TENSOR_3D_SPLITS),
       EXAMPLE_RAGGED_TENSOR_3D_ROWLEN)
   # Sanity check: the uniform-row-length construction matches the reference.
   self.assertAllEqual(rt, EXAMPLE_RAGGED_TENSOR_3D)
   self._TestGetItemException(rt, slice_spec, expected, message)
Ejemplo n.º 3
0
  def testWithUniformRowLength(self, slice_spec, expected, expected_shape):
    """Test that rt[slice_spec] == expected for a uniform-row-length rt.

    `rt` is EXAMPLE_RAGGED_TENSOR_3D rebuilt so that its outermost partition
    has the uniform row length EXAMPLE_RAGGED_TENSOR_3D_ROWLEN. Besides the
    generic `_TestGetItem` check, when the indexed result is still 3D this
    verifies that indexing preserved a uniform row length and that it equals
    `expected_shape[1]`.

    Args:
      slice_spec: The index/slice applied to `rt`.
      expected: Forwarded to `_TestGetItem` as the expected result.
      expected_shape: Forwarded to `_TestGetItem`; element 1 is also the
        expected uniform row length when the result has rank 3.
    """
    rt = RaggedTensor.from_uniform_row_length(
        RaggedTensor.from_row_splits(EXAMPLE_RAGGED_TENSOR_3D_VALUES,
                                     EXAMPLE_RAGGED_TENSOR_3D_SPLITS),
        EXAMPLE_RAGGED_TENSOR_3D_ROWLEN)
    self.assertAllEqual(rt, EXAMPLE_RAGGED_TENSOR_3D)
    self.assertIsNotNone(rt.uniform_row_length)
    self._TestGetItem(rt, slice_spec, expected, expected_shape)

    # If the result is 3D, then check that it still has a uniform row length:
    actual = rt[slice_spec]
    if actual.shape.rank == 3:
      self.assertIsNotNone(actual.uniform_row_length)
      self.assertAllEqual(actual.uniform_row_length, expected_shape[1])