def test_no_padding(self):
  """Checks that passing None for both pads returns the data unchanged."""
  strings = tf.constant([['1', '1', '1'], ['2', '2', '2']])
  result = text.pad_along_dimension(
      data=strings, axis=1, left_pad=None, right_pad=None)
  # With neither side padded, the input tensor itself is the expected value.
  self.assertAllEqual(strings, result)
def test_float_partial_no_padding(self):
  """Checks right-only padding of a float tensor along axis 1."""
  values = tf.constant([[1.0, 1.0, 1.0]])
  # left_pad is omitted, so only the three right pad values should be added.
  want = tf.constant([[1.0, 1.0, 1.0, 3.5, 3.5, 3.5]])
  got = text.pad_along_dimension(
      data=values, axis=1, right_pad=[3.5, 3.5, 3.5])
  self.assertAllEqual(want, got)
def test_no_right_padding(self):
  """Checks that omitting right_pad adds values on the left side only."""
  rows = tf.constant([[1, 1, 1], [2, 2, 1], [3, 3, 1]])
  # Each row gains a single leading 0 and nothing at the end.
  want = tf.constant([[0, 1, 1, 1], [0, 2, 2, 1], [0, 3, 3, 1]])
  got = text.pad_along_dimension(data=rows, axis=1, left_pad=[0])
  self.assertAllEqual(want, got)
def test_invalid_axis(self):
  """Checks the errors raised for an out-of-range axis and a tensor axis."""
  # Arguments shared by both failing calls below.
  pad_kwargs = dict(
      data=tf.constant([[1, 1, 1], [2, 2, 1], [3, 3, 1]]),
      left_pad=[0, 0],
      right_pad=[9, 9, 9])

  # Case 1: axis=-4 on the 2-D input above is expected to be rejected.
  range_error = 'axis must be between -k <= axis <= -1 OR 0 <= axis < k'
  with self.assertRaisesRegexp(tf.errors.InvalidArgumentError, range_error):
    text.pad_along_dimension(axis=-4, **pad_kwargs)

  # Case 2: axis passed as a tensor instead of a Python int.
  type_error = 'axis must be an int'
  with self.assertRaisesRegexp(TypeError, type_error):
    text.pad_along_dimension(axis=tf.constant(0), **pad_kwargs)
def test_string_partial_no_padding(self):
  """Checks left-only padding of a string tensor along axis 1."""
  strings = tf.constant([['1', '1', '1'], ['2', '2', '2']])
  # right_pad is omitted, so each row only gains the two leading '0' values.
  want = tf.constant([['0', '0', '1', '1', '1'],
                      ['0', '0', '2', '2', '2']])
  got = text.pad_along_dimension(
      data=strings, axis=1, left_pad=['0', '0'])
  self.assertAllEqual(want, got)
def test_pads_along_negative_outer_dimension(self):
  """Checks padding of whole rows when the outer axis is given as -2."""
  rows = tf.constant([[1, 1, 1], [2, 2, 1], [3, 3, 1]])
  # One full row is expected before the data and one after it.
  want = tf.constant([[0, 0, 0],
                      [1, 1, 1],
                      [2, 2, 1],
                      [3, 3, 1],
                      [9, 9, 9]])
  got = text.pad_along_dimension(
      data=rows,
      axis=-2,
      left_pad=[[0, 0, 0]],
      right_pad=[[9, 9, 9]])
  self.assertAllEqual(want, got)
def test_pads_along_positive_inner_dimension(self):
  """Checks padding on both sides of each row when the inner axis is 1."""
  rows = tf.constant([[1, 1, 1], [2, 2, 1], [3, 3, 1]])
  # Every row should gain a leading 0 and a trailing 9.
  want = tf.constant([[0, 1, 1, 1, 9],
                      [0, 2, 2, 1, 9],
                      [0, 3, 3, 1, 9]])
  got = text.pad_along_dimension(
      data=rows, axis=1, left_pad=[0], right_pad=[9])
  self.assertAllEqual(want, got)
def testRaggedPadDimension(self,
                           descr,
                           data,
                           axis,
                           expected,
                           left_pad=None,
                           right_pad=None,
                           ragged_rank=None):
  """Pads converted ragged input along `axis` and compares with `expected`.

  Args:
    descr: Not read by this method. Presumably a case label supplied by a
      parameterized-test decorator that is not visible here -- confirm.
    data: Input values, converted with `self._convert_ragged` before padding.
    axis: Axis to pad along; may be negative.
    expected: Value the padded result is compared against.
    left_pad: Optional left pad values, converted like `data`.
    right_pad: Optional right pad values, converted like `data`.
    ragged_rank: Forwarded to `self._convert_ragged` when converting `data`.
  """
  data = self._convert_ragged(data, ragged_rank)

  # Turn a negative axis into its non-negative equivalent using data's rank.
  if axis >= 0:
    positive_axis = axis
  else:
    positive_axis = axis + data.shape.ndims
  assert positive_axis >= 0

  # Both pads are converted with the same ragged rank argument.
  pad_ragged_rank = data.ragged_rank - positive_axis
  left_pad = self._convert_ragged(left_pad, pad_ragged_rank)
  right_pad = self._convert_ragged(right_pad, pad_ragged_rank)

  padded = text.pad_along_dimension(data, axis, left_pad, right_pad)
  self.assertRaggedEqual(padded, expected)
def test_padding_tensor_of_unknown_shape(self):
  """Checks padding of a placeholder input created with shape=None."""
  # shape=None means the placeholder carries no static shape information.
  default_rows = tf.constant([[1, 1, 1], [2, 2, 1], [3, 3, 1]])
  unknown_shape_rows = tf.placeholder_with_default(default_rows, shape=None)
  want = tf.constant([[0, 1, 1, 1, 9],
                      [0, 2, 2, 1, 9],
                      [0, 3, 3, 1, 9]])
  got = text.pad_along_dimension(
      data=unknown_shape_rows, axis=1, left_pad=[0], right_pad=[9])
  self.assertAllEqual(want, got)