def test_passing_text(self):
   """Squeezing axes [0, 1, 3, 6, 7] of padded text recovers the text.

   The 8-D input is derived from the expected 3-level structure
   (outer list -> word lists -> character lists) by inserting size-1
   dimensions at axes 0, 1, 3, 6 and 7.
   """
   expected_chars = [[['H', 'e', 'l', 'l', 'o'],
                      ['W', 'o', 'r', 'l', 'd', '!']],
                     [['T', 'h', 'i', 's'], ['i', 's'],
                      ['M', 'e', 'h', 'r', 'd', 'a', 'd'], ['.']]]
   # Outer `[[[` supplies axes 0, 1 (size 1) and 2 (one entry per group);
   # each group is wrapped once more (axis 3), and every character is
   # wrapped twice (axes 6 and 7).
   padded_chars = [[[
       [[[[[ch]] for ch in word] for word in group]]
       for group in expected_chars
   ]]]
   padded_rt = ragged_factory_ops.constant(padded_chars)
   expected_rt = ragged_factory_ops.constant(expected_chars)
   squeezed_rt = ragged_squeeze_op.squeeze(padded_rt, [0, 1, 3, 6, 7])
   self.assertRaggedEqual(squeezed_rt, expected_rt)
 def test_passing_text(self):
     """Removing the size-1 axes 0, 1, 3, 6 and 7 leaves the nested text.

     The padded input is generated from the expected value with small
     helpers so the relationship between the two is explicit.
     """
     sentences = [[['H', 'e', 'l', 'l', 'o'],
                   ['W', 'o', 'r', 'l', 'd', '!']],
                  [['T', 'h', 'i', 's'], ['i', 's'],
                   ['M', 'e', 'h', 'r', 'd', 'a', 'd'], ['.']]]

     def pad_word(word):
         # Each character gets two size-1 dimensions (axes 6 and 7).
         return [[[ch]] for ch in word]

     def pad_sentence(sentence):
         # The word list is wrapped in one size-1 dimension (axis 3).
         return [[pad_word(w) for w in sentence]]

     # Two leading size-1 dimensions (axes 0 and 1) around axis 2.
     padded = [[[pad_sentence(s) for s in sentences]]]

     source_rt = ragged_factory_ops.constant(padded)
     want_rt = ragged_factory_ops.constant(sentences)
     got_rt = ragged_squeeze_op.squeeze(source_rt, [0, 1, 3, 6, 7])
     self.assertRaggedEqual(got_rt, want_rt)
Example #3
0
def _ragged_squeeze_v1(input, axis=None, name=None, squeeze_dims=None):  # pylint: disable=redefined-builtin
    """Squeeze `input`, accepting the deprecated `squeeze_dims` alias.

    Args:
      input: Value forwarded to `ragged_squeeze_op.squeeze` (presumably a
        ragged tensor -- TODO confirm against the dispatch registration).
      axis: Axes to squeeze.
      name: Optional op name, forwarded unchanged.
      squeeze_dims: Deprecated spelling of `axis`.

    Returns:
      The result of `ragged_squeeze_op.squeeze(input, <axis>, name)`.
    """
    # Collapse the new/old argument pair into one value. Conflict and
    # warning behaviour is whatever `deprecated_argument_lookup` implements.
    resolved_axis = deprecation.deprecated_argument_lookup(
        'axis', axis, 'squeeze_dims', squeeze_dims)
    return ragged_squeeze_op.squeeze(input, resolved_axis, name)
 def test_passing_empty(self, input_list, squeeze_ranks=None):
     """Ragged squeeze of `input_list` must equal the dense squeeze.

     NOTE(review): the name suggests the parameters supply empty inputs;
     confirm against the parameterization, which is not visible here.
     """
     ragged_squeezed = ragged_squeeze_op.squeeze(
         ragged_factory_ops.constant(input_list), squeeze_ranks)
     dense_squeezed = array_ops.squeeze(
         constant_op.constant(input_list), squeeze_ranks)
     as_dense = ragged_conversion_ops.to_tensor(ragged_squeezed)
     self.assertRaggedEqual(as_dense, dense_squeezed)
 def test_failing_axis_is_not_a_list(self, input_list, squeeze_ranks):
     """`squeeze` must raise TypeError when `axis` is a Tensor, not a list.

     The fixtures are built before entering `assertRaises`. Otherwise a
     TypeError raised while constructing them would satisfy the assertion
     and hide a regression in `squeeze` itself.
     """
     tensor_ranks = constant_op.constant(squeeze_ranks)
     rt = ragged_factory_ops.constant(input_list)
     with self.assertRaises(TypeError):
         ragged_squeeze_op.squeeze(rt, tensor_ranks)
 def test_failing_no_squeeze_dim_specified(self, input_list):
     """`squeeze` called without an axis must raise ValueError.

     The ragged input is constructed outside `assertRaises` so that only
     the `squeeze` call can satisfy the expected ValueError.
     """
     rt = ragged_factory_ops.constant(input_list)
     with self.assertRaises(ValueError):
         ragged_squeeze_op.squeeze(rt)
 def test_failing_InvalidArgumentError(self, input_list, squeeze_ranks):
     """Squeezing the given axes of `input_list` must fail with
     InvalidArgumentError.

     The input is built outside `assertRaises`, so only the squeeze (or its
     evaluation) can satisfy the expectation.
     """
     rt = ragged_factory_ops.constant(input_list)
     with self.assertRaises(errors.InvalidArgumentError):
         # The error may surface either when the op is built or when it is
         # evaluated (presumably depending on eager vs. graph mode), so both
         # stay inside the context.
         self.evaluate(ragged_squeeze_op.squeeze(rt, squeeze_ranks))
 def test_passing_ragged(self, input_list, output_list, squeeze_ranks=None):
     """Squeezing ragged `input_list` yields ragged `output_list`."""
     squeezed = ragged_squeeze_op.squeeze(
         ragged_factory_ops.constant(input_list), squeeze_ranks)
     expected = ragged_factory_ops.constant(output_list)
     self.assertRaggedEqual(squeezed, expected)
 def test_passing_simple_from_dense(self, input_list, squeeze_ranks=None):
     """A ragged tensor made from a dense one squeezes like the dense one."""
     dense = constant_op.constant(input_list)
     squeezed_ragged = ragged_squeeze_op.squeeze(
         ragged_conversion_ops.from_tensor(dense), squeeze_ranks)
     squeezed_dense = array_ops.squeeze(dense, squeeze_ranks)
     self.assertRaggedEqual(
         ragged_conversion_ops.to_tensor(squeezed_ragged), squeezed_dense)
Example #10
0
def _ragged_squeeze_v1(input, axis=None, name=None, squeeze_dims=None):  # pylint: disable=redefined-builtin
  """Squeeze `input`, mapping the deprecated `squeeze_dims` onto `axis`.

  Args:
    input: Value forwarded to `ragged_squeeze_op.squeeze` (presumably a
      ragged tensor -- TODO confirm).
    axis: Axes to squeeze.
    name: Optional op name, forwarded unchanged.
    squeeze_dims: Deprecated spelling of `axis`.

  Returns:
    The result of `ragged_squeeze_op.squeeze(input, <axis>, name)`.
  """
  # Pick whichever of the new/old argument names was supplied; conflict
  # and warning handling is delegated to `deprecated_argument_lookup`.
  effective_axis = deprecation.deprecated_argument_lookup(
      'axis', axis, 'squeeze_dims', squeeze_dims)
  return ragged_squeeze_op.squeeze(input, effective_axis, name)
 def test_passing_empty(self, input_list, squeeze_ranks=None):
   """Ragged and dense squeezes of `input_list` must agree.

   NOTE(review): the name implies empty inputs; verify against the
   parameterization, which is outside this view.
   """
   squeezed_rt = ragged_squeeze_op.squeeze(
       ragged_factory_ops.constant(input_list), squeeze_ranks)
   squeezed_dt = array_ops.squeeze(
       constant_op.constant(input_list), squeeze_ranks)
   densified = ragged_conversion_ops.to_tensor(squeezed_rt)
   self.assertRaggedEqual(densified, squeezed_dt)
 def test_failing_axis_is_not_a_list(self, input_list, squeeze_ranks):
   """`squeeze` must raise TypeError when `axis` is a Tensor, not a list.

   Both fixtures are created before `assertRaises`, so a TypeError raised
   while building them cannot make the test pass spuriously.
   """
   tensor_ranks = constant_op.constant(squeeze_ranks)
   rt = ragged_factory_ops.constant(input_list)
   with self.assertRaises(TypeError):
     ragged_squeeze_op.squeeze(rt, tensor_ranks)
 def test_failing_no_squeeze_dim_specified(self, input_list):
   """`squeeze` without an axis must raise ValueError.

   The input is built outside `assertRaises` so a ValueError from
   constructing it cannot stand in for the expected one from `squeeze`.
   """
   rt = ragged_factory_ops.constant(input_list)
   with self.assertRaises(ValueError):
     ragged_squeeze_op.squeeze(rt)
 def test_failing_InvalidArgumentError(self, input_list, squeeze_ranks):
   """Squeezing the given axes of `input_list` must fail with
   InvalidArgumentError.

   Only the squeeze and its evaluation are inside `assertRaises`; the input
   is constructed beforehand.
   """
   rt = ragged_factory_ops.constant(input_list)
   with self.assertRaises(errors.InvalidArgumentError):
     # The error may be raised while building or while evaluating the op
     # (presumably eager vs. graph mode), so both stay in the context.
     self.evaluate(ragged_squeeze_op.squeeze(rt, squeeze_ranks))
 def test_passing_ragged(self, input_list, output_list, squeeze_ranks=None):
   """Squeezing ragged `input_list` produces ragged `output_list`."""
   actual = ragged_squeeze_op.squeeze(
       ragged_factory_ops.constant(input_list), squeeze_ranks)
   wanted = ragged_factory_ops.constant(output_list)
   self.assertRaggedEqual(actual, wanted)
 def test_passing_simple_from_dense(self, input_list, squeeze_ranks=None):
   """Ragged-from-dense squeeze must match the plain dense squeeze."""
   dense_in = constant_op.constant(input_list)
   ragged_out = ragged_squeeze_op.squeeze(
       ragged_conversion_ops.from_tensor(dense_in), squeeze_ranks)
   dense_out = array_ops.squeeze(dense_in, squeeze_ranks)
   self.assertRaggedEqual(
       ragged_conversion_ops.to_tensor(ragged_out), dense_out)