Example #1
0
    def testBroadcastDimension(self, axis, row_length, original_dim_sizes,
                               broadcast_dim_sizes):
        """Checks RaggedTensorDynamicShape.broadcast_dimension.

        Each of the following calls must produce a shape equal to
        `broadcast`:

        * `original.broadcast_dimension(axis, row_length)`
        * `broadcast.broadcast_dimension(axis, row_length)`
        * `broadcast.broadcast_dimension(axis, 1)`

        Args:
          axis: The axis to broadcast.
          row_length: The slice lengths to broadcast to.
          original_dim_sizes: The dimension sizes before broadcasting.
            original_dim_sizes[axis] should be equal to `1` or `row_length`.
          broadcast_dim_sizes: The dimension sizes after broadcasting.
        """
        make_shape = RaggedTensorDynamicShape.from_dim_sizes
        original_shape = make_shape(original_dim_sizes)
        bcast_shape = make_shape(broadcast_dim_sizes)
        self.assertEqual(original_shape.rank, bcast_shape.rank)

        # All three results are computed before any comparison is made.
        candidates = (
            # shape[axis].value == 1 and row_length > 1.
            original_shape.broadcast_dimension(axis, row_length),
            # shape[axis].value > 1 and row_length == shape[axis].value.
            bcast_shape.broadcast_dimension(axis, row_length),
            # shape[axis].value > 1 and row_length == 1.
            bcast_shape.broadcast_dimension(axis, 1),
        )
        for candidate in candidates:
            self.assertShapeEq(candidate, bcast_shape)
  def testBroadcastDimension(self, axis, row_length, original_dim_sizes,
                             broadcast_dim_sizes):
    """Checks RaggedTensorDynamicShape.broadcast_dimension.

    Each of the following calls must produce a shape equal to `broadcast`:

    * `original.broadcast_dimension(axis, row_length)`
    * `broadcast.broadcast_dimension(axis, row_length)`
    * `broadcast.broadcast_dimension(axis, 1)`

    Args:
      axis: The axis to broadcast.
      row_length: The slice lengths to broadcast to.
      original_dim_sizes: The dimension sizes before broadcasting.
        original_dim_sizes[axis] should be equal to `1` or `row_length`.
      broadcast_dim_sizes: The dimension sizes after broadcasting.
    """
    make_shape = RaggedTensorDynamicShape.from_dim_sizes
    original_shape = make_shape(original_dim_sizes)
    bcast_shape = make_shape(broadcast_dim_sizes)
    self.assertEqual(original_shape.rank, bcast_shape.rank)

    # (source shape, target row length) pairs, in the order they are checked:
    #   shape[axis].value == 1 and row_length > 1;
    #   shape[axis].value > 1 and row_length == shape[axis].value;
    #   shape[axis].value > 1 and row_length == 1.
    cases = ((original_shape, row_length), (bcast_shape, row_length),
             (bcast_shape, 1))
    results = [source.broadcast_dimension(axis, length)
               for source, length in cases]

    for result in results:
      self.assertShapeEq(result, bcast_shape)
Example #3
0
 def testBroadcastDynamicShape(self, x_dims, y_dims, expected_dims):
     """Checks that broadcast_dynamic_shape is symmetric in its arguments.

     Args:
       x_dims: Dimension sizes of the first operand's shape.
       y_dims: Dimension sizes of the second operand's shape.
       expected_dims: Dimension sizes of the expected broadcast shape.
     """
     make_shape = RaggedTensorDynamicShape.from_dim_sizes
     x_shape = make_shape(x_dims)
     y_shape = make_shape(y_dims)
     expected = make_shape(expected_dims)

     broadcast = ragged_tensor_shape.broadcast_dynamic_shape
     forward = broadcast(x_shape, y_shape)
     backward = broadcast(y_shape, x_shape)
     for result in (forward, backward):
         self.assertShapeEq(expected, result)
 def testBroadcastDynamicShape(self, x_dims, y_dims, expected_dims):
   """Checks broadcast_dynamic_shape with both argument orders.

   Args:
     x_dims: Dimension sizes of the first operand's shape.
     y_dims: Dimension sizes of the second operand's shape.
     expected_dims: Dimension sizes of the expected broadcast shape.
   """
   x_shape, y_shape, expected = (
       RaggedTensorDynamicShape.from_dim_sizes(dims)
       for dims in (x_dims, y_dims, expected_dims))
   # Broadcasting must give the same answer regardless of operand order.
   outcomes = [
       ragged_tensor_shape.broadcast_dynamic_shape(lhs, rhs)
       for lhs, rhs in ((x_shape, y_shape), (y_shape, x_shape))
   ]
   for outcome in outcomes:
     self.assertShapeEq(expected, outcome)
 def testRepr(self):
   """Checks the textual form of repr(RaggedTensorDynamicShape).

   The pattern expects two partitioned-dimension entries and one
   inner-dimension entry, each rendered as `<...>`.
   """
   shape = RaggedTensorDynamicShape.from_dim_sizes([2, (2, 1), 2, 1])
   # assertRegexpMatches is a deprecated unittest alias that was removed in
   # Python 3.12; assertRegex has the same (text, regex) semantics.
   self.assertRegex(
       repr(shape),
       r'RaggedTensorDynamicShape\('
       r'partitioned_dim_sizes=\(<[^>]+>, <[^>]+>\), '
       r'inner_dim_sizes=<[^>]+>\)')
 def testRepr(self):
   """Checks the textual form of repr(RaggedTensorDynamicShape).

   The pattern expects two partitioned-dimension entries and one
   inner-dimension entry, each rendered as `<...>`.
   """
   shape = RaggedTensorDynamicShape.from_dim_sizes([2, (2, 1), 2, 1])
   # assertRegexpMatches is a deprecated unittest alias that was removed in
   # Python 3.12; assertRegex has the same (text, regex) semantics.
   self.assertRegex(
       repr(shape),
       r'RaggedTensorDynamicShape\('
       r'partitioned_dim_sizes=\(<[^>]+>, <[^>]+>\), '
       r'inner_dim_sizes=<[^>]+>\)')
Example #7
0
 def testBroadcastToRank(self, dim_sizes, rank, expected_dim_sizes):
     """Checks that broadcast_to_rank produces the expected shape and rank.

     Args:
       dim_sizes: Dimension sizes of the shape to broadcast.
       rank: The target rank passed to broadcast_to_rank.
       expected_dim_sizes: Dimension sizes of the expected result.
     """
     make_shape = RaggedTensorDynamicShape.from_dim_sizes
     source, expected = make_shape(dim_sizes), make_shape(expected_dim_sizes)
     result = source.broadcast_to_rank(rank)
     self.assertShapeEq(result, expected)
     self.assertEqual(result.rank, rank)
Example #8
0
 def testFromTensor(self, value, expected_dim_sizes):
     """Checks RaggedTensorDynamicShape.from_tensor against from_dim_sizes.

     Args:
       value: The tensor whose dynamic shape is extracted.
       expected_dim_sizes: Dimension sizes describing the expected shape.
     """
     actual = RaggedTensorDynamicShape.from_tensor(value)
     self.assertShapeEq(
         actual, RaggedTensorDynamicShape.from_dim_sizes(expected_dim_sizes))
Example #9
0
 def testRaggedBroadcastTo(self, x, dim_sizes, expected):
     """Checks ragged_tensor_shape.broadcast_to against an expected value.

     Args:
       x: The value to broadcast.
       dim_sizes: Dimension sizes of the target shape.
       expected: The expected broadcast result.
     """
     target = RaggedTensorDynamicShape.from_dim_sizes(dim_sizes)
     result = ragged_tensor_shape.broadcast_to(x, target)
     # Values without a ragged_rank attribute are treated as ragged_rank 0.
     actual_ragged_rank = getattr(result, 'ragged_rank', 0)
     expected_ragged_rank = getattr(expected, 'ragged_rank', 0)
     self.assertEqual(actual_ragged_rank, expected_ragged_rank)
     self.assertAllEqual(result, expected)
 def testFromTensor(self, value, expected_dim_sizes):
   """Checks that from_tensor(value) matches from_dim_sizes(expected).

   Args:
     value: The tensor whose dynamic shape is extracted.
     expected_dim_sizes: Dimension sizes describing the expected shape.
   """
   shape_cls = RaggedTensorDynamicShape
   derived = shape_cls.from_tensor(value)
   reference = shape_cls.from_dim_sizes(expected_dim_sizes)
   self.assertShapeEq(derived, reference)
 def testRaggedBroadcastTo(self, x, dim_sizes, expected):
   """Checks ragged_tensor_shape.broadcast_to against an expected value.

   Args:
     x: The value to broadcast.
     dim_sizes: Dimension sizes of the target shape.
     expected: The expected broadcast result.
   """
   result = ragged_tensor_shape.broadcast_to(
       x, RaggedTensorDynamicShape.from_dim_sizes(dim_sizes))
   # Values without a ragged_rank attribute are treated as ragged_rank 0.
   ragged_ranks = [getattr(item, 'ragged_rank', 0) for item in (result, expected)]
   self.assertEqual(*ragged_ranks)
   self.assertRaggedEqual(result, expected)
 def testBroadcastToRank(self, dim_sizes, rank, expected_dim_sizes):
   """Checks that broadcast_to_rank produces the expected shape and rank.

   Args:
     dim_sizes: Dimension sizes of the shape to broadcast.
     rank: The target rank passed to broadcast_to_rank.
     expected_dim_sizes: Dimension sizes of the expected result.
   """
   original, expected = [
       RaggedTensorDynamicShape.from_dim_sizes(sizes)
       for sizes in (dim_sizes, expected_dim_sizes)
   ]
   widened = original.broadcast_to_rank(rank)
   self.assertShapeEq(widened, expected)
   self.assertEqual(widened.rank, rank)