def testBroadcastDimension(self, axis, row_length, original_dim_sizes,
                           broadcast_dim_sizes):
  """Tests for the broadcast_dimension method.

  Verifies that:

  * `original.broadcast_dimension(axis, row_length) == broadcast`
  * `broadcast.broadcast_dimension(axis, row_length) == broadcast`
  * `broadcast.broadcast_dimension(axis, 1) == broadcast`

  Args:
    axis: The axis to broadcast.
    row_length: The slice lengths to broadcast to.
    original_dim_sizes: The dimension sizes before broadcasting.
      original_dim_sizes[axis] should be equal to `1` or `row_length`.
    broadcast_dim_sizes: The dimension sizes after broadcasting.
  """
  original_shape = RaggedTensorDynamicShape.from_dim_sizes(original_dim_sizes)
  bcast_shape = RaggedTensorDynamicShape.from_dim_sizes(broadcast_dim_sizes)
  # Broadcasting never changes the rank, so both shapes must agree on it.
  self.assertEqual(original_shape.rank, bcast_shape.rank)
  # shape[axis].value == 1 and row_length > 1:
  bcast1 = original_shape.broadcast_dimension(axis, row_length)
  # shape[axis].value > 1 and row_length == shape[axis].value:
  bcast2 = bcast_shape.broadcast_dimension(axis, row_length)
  # shape[axis].value > 1 and row_length == 1:
  bcast3 = bcast_shape.broadcast_dimension(axis, 1)
  self.assertShapeEq(bcast1, bcast_shape)
  self.assertShapeEq(bcast2, bcast_shape)
  self.assertShapeEq(bcast3, bcast_shape)
def testBroadcastDimension(self, axis, row_length, original_dim_sizes,
                           broadcast_dim_sizes):
  """Tests for the broadcast_dimension method.

  Verifies that:

  * `original.broadcast_dimension(axis, row_length) == broadcast`
  * `broadcast.broadcast_dimension(axis, row_length) == broadcast`
  * `broadcast.broadcast_dimension(axis, 1) == broadcast`

  Args:
    axis: The axis to broadcast.
    row_length: The slice lengths to broadcast to.
    original_dim_sizes: The dimension sizes before broadcasting.
      original_dim_sizes[axis] should be equal to `1` or `row_length`.
    broadcast_dim_sizes: The dimension sizes after broadcasting.
  """
  original_shape = RaggedTensorDynamicShape.from_dim_sizes(original_dim_sizes)
  bcast_shape = RaggedTensorDynamicShape.from_dim_sizes(broadcast_dim_sizes)
  # Broadcasting never changes the rank, so both shapes must agree on it.
  self.assertEqual(original_shape.rank, bcast_shape.rank)
  # shape[axis].value == 1 and row_length > 1:
  bcast1 = original_shape.broadcast_dimension(axis, row_length)
  # shape[axis].value > 1 and row_length == shape[axis].value:
  bcast2 = bcast_shape.broadcast_dimension(axis, row_length)
  # shape[axis].value > 1 and row_length == 1:
  bcast3 = bcast_shape.broadcast_dimension(axis, 1)
  self.assertShapeEq(bcast1, bcast_shape)
  self.assertShapeEq(bcast2, bcast_shape)
  self.assertShapeEq(bcast3, bcast_shape)
def testBroadcastDynamicShape(self, x_dims, y_dims, expected_dims):
  """Broadcasts two shapes against each other in both argument orders.

  Both `broadcast_dynamic_shape(x, y)` and `broadcast_dynamic_shape(y, x)`
  must equal the shape built from `expected_dims`.

  Args:
    x_dims: Dimension sizes of the first input shape.
    y_dims: Dimension sizes of the second input shape.
    expected_dims: Dimension sizes of the expected broadcast result.
  """
  from_dims = RaggedTensorDynamicShape.from_dim_sizes
  x_shape, y_shape, expected = (
      from_dims(x_dims), from_dims(y_dims), from_dims(expected_dims))
  # Compute both orderings before asserting anything.
  results = [
      ragged_tensor_shape.broadcast_dynamic_shape(lhs, rhs)
      for lhs, rhs in ((x_shape, y_shape), (y_shape, x_shape))
  ]
  for result in results:
    self.assertShapeEq(expected, result)
def testBroadcastDynamicShape(self, x_dims, y_dims, expected_dims):
  """Checks broadcast_dynamic_shape with its arguments in either order.

  Args:
    x_dims: Dimension sizes of the first input shape.
    y_dims: Dimension sizes of the second input shape.
    expected_dims: Dimension sizes that both orderings must produce.
  """
  x_shape, y_shape, expected = [
      RaggedTensorDynamicShape.from_dim_sizes(dims)
      for dims in (x_dims, y_dims, expected_dims)
  ]
  forward = ragged_tensor_shape.broadcast_dynamic_shape(x_shape, y_shape)
  backward = ragged_tensor_shape.broadcast_dynamic_shape(y_shape, x_shape)
  self.assertShapeEq(expected, forward)
  self.assertShapeEq(expected, backward)
def testRepr(self):
  """Checks the overall layout of repr(RaggedTensorDynamicShape).

  The pattern expects exactly two `partitioned_dim_sizes` entries and one
  `inner_dim_sizes` entry. Each entry is matched loosely as an opaque
  `<...>` repr.
  """
  shape = RaggedTensorDynamicShape.from_dim_sizes([2, (2, 1), 2, 1])
  # assertRegex replaces the deprecated assertRegexpMatches alias, which
  # was removed from unittest in Python 3.12.
  self.assertRegex(
      repr(shape), r'RaggedTensorDynamicShape\('
      r'partitioned_dim_sizes=\(<[^>]+>, <[^>]+>\), '
      r'inner_dim_sizes=<[^>]+>\)')
def testRepr(self):
  """Checks the overall layout of repr(RaggedTensorDynamicShape).

  The pattern expects exactly two `partitioned_dim_sizes` entries and one
  `inner_dim_sizes` entry. Each entry is matched loosely as an opaque
  `<...>` repr.
  """
  shape = RaggedTensorDynamicShape.from_dim_sizes([2, (2, 1), 2, 1])
  # assertRegex replaces the deprecated assertRegexpMatches alias, which
  # was removed from unittest in Python 3.12.
  self.assertRegex(
      repr(shape), r'RaggedTensorDynamicShape\('
      r'partitioned_dim_sizes=\(<[^>]+>, <[^>]+>\), '
      r'inner_dim_sizes=<[^>]+>\)')
def testBroadcastToRank(self, dim_sizes, rank, expected_dim_sizes):
  """Checks that broadcast_to_rank yields the expected shape and rank.

  Args:
    dim_sizes: Dimension sizes of the shape being broadcast.
    rank: Target rank passed to `broadcast_to_rank`.
    expected_dim_sizes: Dimension sizes of the expected result.
  """
  source_shape = RaggedTensorDynamicShape.from_dim_sizes(dim_sizes)
  expected_shape = RaggedTensorDynamicShape.from_dim_sizes(expected_dim_sizes)
  actual = source_shape.broadcast_to_rank(rank)
  self.assertShapeEq(actual, expected_shape)
  self.assertEqual(actual.rank, rank)
def testFromTensor(self, value, expected_dim_sizes):
  """Checks the shape that from_tensor infers for `value`.

  Args:
    value: The value passed to `RaggedTensorDynamicShape.from_tensor`.
    expected_dim_sizes: Dimension sizes the inferred shape must match.
  """
  actual = RaggedTensorDynamicShape.from_tensor(value)
  self.assertShapeEq(
      actual, RaggedTensorDynamicShape.from_dim_sizes(expected_dim_sizes))
def testRaggedBroadcastTo(self, x, dim_sizes, expected):
  """Broadcasts `x` to the shape given by `dim_sizes` and checks the result.

  Args:
    x: The value to broadcast.
    dim_sizes: Dimension sizes of the target shape.
    expected: The expected broadcast value.
  """

  def ragged_rank_of(item):
    # Values without a `ragged_rank` attribute are treated as ragged rank 0.
    return getattr(item, 'ragged_rank', 0)

  target_shape = RaggedTensorDynamicShape.from_dim_sizes(dim_sizes)
  broadcasted = ragged_tensor_shape.broadcast_to(x, target_shape)
  self.assertEqual(ragged_rank_of(broadcasted), ragged_rank_of(expected))
  self.assertAllEqual(broadcasted, expected)
def testFromTensor(self, value, expected_dim_sizes):
  """Compares from_tensor(value) against an explicitly built shape.

  Args:
    value: The value whose dynamic shape is extracted.
    expected_dim_sizes: Dimension sizes describing the expected shape.
  """
  shape_cls = RaggedTensorDynamicShape
  inferred, expected = (shape_cls.from_tensor(value),
                        shape_cls.from_dim_sizes(expected_dim_sizes))
  self.assertShapeEq(inferred, expected)
def testRaggedBroadcastTo(self, x, dim_sizes, expected):
  """Checks broadcast_to against an expected value, including ragged rank.

  Args:
    x: The value to broadcast.
    dim_sizes: Dimension sizes of the target shape.
    expected: The expected broadcast value.
  """
  broadcasted = ragged_tensor_shape.broadcast_to(
      x, RaggedTensorDynamicShape.from_dim_sizes(dim_sizes))
  # A missing `ragged_rank` attribute counts as ragged rank 0.
  actual_ragged_rank, expected_ragged_rank = (
      getattr(broadcasted, 'ragged_rank', 0),
      getattr(expected, 'ragged_rank', 0))
  self.assertEqual(actual_ragged_rank, expected_ragged_rank)
  self.assertRaggedEqual(broadcasted, expected)
def testBroadcastToRank(self, dim_sizes, rank, expected_dim_sizes):
  """Broadcasts a shape to `rank` and compares it with the expected shape.

  Args:
    dim_sizes: Dimension sizes of the input shape.
    rank: The rank to broadcast to.
    expected_dim_sizes: Dimension sizes of the expected broadcast shape.
  """
  build = RaggedTensorDynamicShape.from_dim_sizes
  original, expected = build(dim_sizes), build(expected_dim_sizes)
  result = original.broadcast_to_rank(rank)
  self.assertShapeEq(result, expected)
  self.assertEqual(result.rank, rank)