Esempio n. 1
0
 def test2DRaggedTensorWithOneRaggedDimension(self):
   """Checks bounding_shape for 2-D ragged tensors over one flat values list.

   Each case pairs a row_splits vector with the expected
   [number of rows, longest row length] bounding shape.
   """
   values = ['a', 'b', 'c', 'd', 'e', 'f', 'g']
   cases = [
       ([0, 2, 5, 6, 6, 7], [5, 3]),  # row lengths 2, 3, 1, 0, 1
       ([0, 7], [1, 7]),  # a single row holding all seven values
       ([0, 0, 7, 7], [3, 7]),  # row lengths 0, 7, 0
   ]
   for splits, expected_shape in cases:
     rt = ragged.from_row_splits(values, splits)
     actual_shape = self.evaluate(ragged.bounding_shape(rt)).tolist()
     self.assertEqual(actual_shape, expected_shape)
 def test3DRaggedTensorWithOneRaggedDimension(self):
   """Checks bounding_shape for 3-D ragged tensors with a uniform inner axis.

   The outer row_splits make axis 1 ragged; every inner value has length 2,
   so the last entry of each expected bounding shape is 2.
   """
   values = [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11], [12, 13]]
   rt1 = ragged.from_row_splits(values, [0, 2, 5, 6, 6, 7])
   rt2 = ragged.from_row_splits(values, [0, 7])
   rt3 = ragged.from_row_splits(values, [0, 0, 7, 7])
   # self.evaluate replaces the deprecated test_session()/Tensor.eval() pattern.
   self.assertEqual(
       self.evaluate(ragged.bounding_shape(rt1)).tolist(), [5, 3, 2])
   self.assertEqual(
       self.evaluate(ragged.bounding_shape(rt2)).tolist(), [1, 7, 2])
   self.assertEqual(
       self.evaluate(ragged.bounding_shape(rt3)).tolist(), [3, 7, 2])
Esempio n. 3
0
 def test3DRaggedTensorWithOneRaggedDimension(self):
   """Checks bounding_shape for 3-D ragged tensors with a uniform inner axis.

   Each case pairs outer row_splits with the expected bounding shape; the
   trailing 2 is the length shared by every inner value.
   """
   values = [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11], [12, 13]]
   cases = [
       ([0, 2, 5, 6, 6, 7], [5, 3, 2]),
       ([0, 7], [1, 7, 2]),
       ([0, 0, 7, 7], [3, 7, 2]),
   ]
   for splits, expected_shape in cases:
     rt = ragged.from_row_splits(values, splits)
     actual_shape = self.evaluate(ragged.bounding_shape(rt)).tolist()
     self.assertEqual(actual_shape, expected_shape)
 def testRaggedTensorSplitsMismatchErrorAtRuntime(self):
   """map_inner_values reports mismatched row splits when the graph runs.

   Both splits tensors are wrapped in placeholder_with_default(..., None), and
   the test expects the mismatch to be reported only when the result is
   evaluated.
   """
   splits1 = array_ops.placeholder_with_default(
       constant_op.constant([0, 3, 3, 5], dtypes.int64), None)
   splits2 = array_ops.placeholder_with_default(
       constant_op.constant([0, 1, 3, 5], dtypes.int64), None)
   x = ragged.from_row_splits([3, 1, 4, 1, 5], splits1)
   y = ragged.from_row_splits([1, 2, 3, 4, 5], splits2)
   result = ragged.map_inner_values(math_ops.add, x, y)
   # cached_session() is the supported replacement for the deprecated
   # test_session().
   with self.cached_session():
     with self.assertRaisesRegexp(
         errors.InvalidArgumentError,
         r'\[Inputs must have identical ragged splits\] '
         r'\[Condition x == y did not hold element-wise:\].*'):
       result.eval()
 def testRaggedTensorSplitsMismatchErrorAtRuntime(self):
     """map_inner_values reports mismatched row splits when the graph runs.

     Both splits tensors are wrapped in placeholder_with_default(..., None),
     and the test expects the mismatch to be reported only when the result
     is evaluated.
     """
     splits1 = array_ops.placeholder_with_default(
         constant_op.constant([0, 3, 3, 5], dtypes.int64), None)
     splits2 = array_ops.placeholder_with_default(
         constant_op.constant([0, 1, 3, 5], dtypes.int64), None)
     x = ragged.from_row_splits([3, 1, 4, 1, 5], splits1)
     y = ragged.from_row_splits([1, 2, 3, 4, 5], splits2)
     result = ragged.map_inner_values(math_ops.add, x, y)
     # cached_session() is the supported replacement for the deprecated
     # test_session().
     with self.cached_session():
         with self.assertRaisesRegexp(
                 errors.InvalidArgumentError,
                 r'\[Inputs must have identical ragged splits\] '
                 r'\[Condition x == y did not hold element-wise:\].*'):
             result.eval()
    def testShapeMismatchError2(self):
        """segment_sum rejects segment_ids whose shape is not a prefix of rt's.

        Row 1 of `rt` is empty while row 1 of `segment_ids` has one entry, so
        segment_ids.shape cannot be a prefix of data.shape.
        """
        rt = ragged.constant([
            [[111, 112, 113, 114], [121]],  # row 0
            [],  # row 1
            [[], [321, 322], [331]],  # row 2
            [[411, 412]]  # row 3
        ])  # pyformat: disable
        segment_ids = ragged.constant([[1, 2], [1], [1, 1, 2], [2]])

        # Error is raised at graph-building time if we can detect it then.
        self.assertRaisesRegexp(
            errors.InvalidArgumentError,
            'segment_ids.shape must be a prefix of data.shape.*',
            ragged.segment_sum, rt, segment_ids, 3)

        # Otherwise, error is raised when we run the graph. Passing the
        # components through placeholder_with_default(..., None) hides their
        # static shapes, so the check is deferred to run time.
        segment_ids2 = ragged.from_row_splits(
            array_ops.placeholder_with_default(segment_ids.values, None),
            array_ops.placeholder_with_default(segment_ids.row_splits, None))
        segmented2 = ragged.segment_sum(rt, segment_ids2, 3)
        # cached_session() is the supported replacement for the deprecated
        # test_session().
        with self.cached_session():
            with self.assertRaisesRegexp(
                    errors.InvalidArgumentError,
                    'segment_ids.shape must be a prefix of data.shape.*'):
                segmented2.eval()
Esempio n. 7
0
 def testUnknownRankError(self):
   """ragged.add raises ValueError when an operand's rank is not static."""
   x = ragged.constant([[1, 2], [3]])
   # shape=None hides the static shape of y's flat values.
   flat_values = array_ops.placeholder_with_default([1, 2, 3], shape=None)
   y = ragged.from_row_splits(flat_values, x.row_splits)
   expected_error = r'Unable to broadcast: unknown rank'
   with self.assertRaisesRegexp(ValueError, expected_error):
     ragged.add(x, y)
 def testUnknownRankError(self):
   """Broadcasting in ragged.add needs an operand rank known statically."""
   x = ragged.constant([[1, 2], [3]])
   # y shares x's row partition but its flat values have no static shape.
   y = ragged.from_row_splits(
       array_ops.placeholder_with_default([1, 2, 3], shape=None),
       x.row_splits)
   with self.assertRaisesRegexp(ValueError,
                                r'Unable to broadcast: unknown rank'):
     ragged.add(x, y)
 def testExplicitAxisOptimizations(self):
   """Checks bounding_shape with an explicit scalar or list `axis` argument.

   The row_splits describe 5 rows whose longest has 3 values, so axis 0 -> 5,
   axis 1 -> 3, and axis [1, 0] -> [3, 5].
   """
   rt = ragged.from_row_splits(b'a b c d e f g'.split(), [0, 2, 5, 6, 6, 7])
   # self.evaluate replaces the deprecated test_session()/Tensor.eval() pattern.
   self.assertEqual(self.evaluate(ragged.bounding_shape(rt, 0)).tolist(), 5)
   self.assertEqual(self.evaluate(ragged.bounding_shape(rt, 1)).tolist(), 3)
   self.assertEqual(
       self.evaluate(ragged.bounding_shape(rt, [1, 0])).tolist(), [3, 5])
 def testUnknownRankError(self):
   """ragged.add requires operands whose rank is statically known."""
   x = ragged.constant([[1, 2], [3]])
   # shape=None hides the static shape of y's flat values.
   unknown_rank_values = array_ops.placeholder_with_default(
       [1, 2, 3], shape=None)
   y = ragged.from_row_splits(unknown_rank_values, x.row_splits)
   expected_error = (r'Ragged elementwise ops require that rank \(number '
                     r'of dimensions\) be statically known.')
   with self.assertRaisesRegexp(ValueError, expected_error):
     ragged.add(x, y)
 def testExplicitAxisOptimizations(self):
     """Checks bounding_shape with an explicit scalar or list `axis` argument.

     The row_splits describe 5 rows whose longest has 3 values, so
     axis 0 -> 5, axis 1 -> 3, and axis [1, 0] -> [3, 5].
     """
     rt = ragged.from_row_splits(b'a b c d e f g'.split(),
                                 [0, 2, 5, 6, 6, 7])
     # self.evaluate replaces the deprecated test_session()/Tensor.eval()
     # pattern.
     self.assertEqual(
         self.evaluate(ragged.bounding_shape(rt, 0)).tolist(), 5)
     self.assertEqual(
         self.evaluate(ragged.bounding_shape(rt, 1)).tolist(), 3)
     self.assertEqual(
         self.evaluate(ragged.bounding_shape(rt, [1, 0])).tolist(), [3, 5])
Esempio n. 12
0
  def testRaggedBatchGatherUnknownRankError(self):
    """batch_gather rejects dense and ragged indices of unknown shape."""
    params = [['a', 'b'], ['c', 'd']]
    indices = array_ops.placeholder(dtypes.int32, shape=None)
    ragged_indices = ragged.from_row_splits(indices, [0, 2, 4])
    message = 'batch_gather does not allow indices with unknown shape.'

    # Check the dense indices first, then the ragged wrapper around them.
    for candidate in (indices, ragged_indices):
      with self.assertRaisesRegexp(ValueError, message):
        ragged.batch_gather(params, candidate)
  def testRaggedBatchGatherUnknownRankError(self):
    """Both dense and ragged indices with unknown shape are rejected."""
    params = [['a', 'b'], ['c', 'd']]
    dense_indices = array_ops.placeholder(dtypes.int32, shape=None)
    ragged_indices = ragged.from_row_splits(dense_indices, [0, 2, 4])
    expected_error = 'batch_gather does not allow indices with unknown shape.'

    with self.assertRaisesRegexp(ValueError, expected_error):
      ragged.batch_gather(params, dense_indices)

    with self.assertRaisesRegexp(ValueError, expected_error):
      ragged.batch_gather(params, ragged_indices)
Esempio n. 14
0
    def testKernelErrors(self):
        """Malformed row_splits make ragged.to_sparse fail when evaluated.

        Each sub-case builds an invalid ragged tensor and checks the
        InvalidArgumentError message produced when its to_sparse result runs.
        cached_session() is used as the supported replacement for the
        deprecated test_session().
        """
        # An empty vector, defined using a placeholder to ensure that we can't
        # determine that it's invalid at graph-construction time.
        empty_vector = array_ops.placeholder_with_default(
            array_ops.zeros([0], dtypes.int64), shape=None)

        # row_splits must start at 0.
        bad_rt1 = ragged.from_row_splits(row_splits=[2, 3], values=[1, 2, 3])
        with self.cached_session():
            bad_split0_error = r'First value of ragged splits must be 0.*'
            self.assertRaisesRegexp(errors.InvalidArgumentError,
                                    bad_split0_error,
                                    ragged.to_sparse(bad_rt1).eval)

        # Final split value (5) exceeds the number of values (0), both for a
        # flat tensor and when the bad tensor is nested one level down.
        bad_rt2 = ragged.from_row_splits(row_splits=[0, 5],
                                         values=empty_vector)
        bad_rt3 = ragged.from_row_splits(
            row_splits=[0, 1],
            values=ragged.from_row_splits(row_splits=[0, 5],
                                          values=empty_vector))
        with self.cached_session():
            split_mismatch1_error = r'Final value of ragged splits must match.*'
            for rt in [bad_rt2, bad_rt3]:
                self.assertRaisesRegexp(errors.InvalidArgumentError,
                                        split_mismatch1_error,
                                        ragged.to_sparse(rt).eval)

        # Outer splits end at 5, but the inner ragged values (row_splits=[0])
        # have no rows.
        bad_rt4 = ragged.from_row_splits(
            row_splits=[0, 5],
            values=ragged.from_row_splits(row_splits=[0],
                                          values=empty_vector))
        with self.cached_session():
            split_mismatch2_error = r'Final value of ragged splits must match.*'
            self.assertRaisesRegexp(errors.InvalidArgumentError,
                                    split_mismatch2_error,
                                    ragged.to_sparse(bad_rt4).eval)

        # row_splits must contain at least one element.
        bad_rt5 = ragged.from_row_splits(row_splits=empty_vector, values=[])
        with self.cached_session():
            empty_splits_error = (r'ragged splits may not be empty.*')
            self.assertRaisesRegexp(errors.InvalidArgumentError,
                                    empty_splits_error,
                                    ragged.to_sparse(bad_rt5).eval)
  def testShapeMismatchError2(self):
    """segment_sum rejects segment_ids whose shape is not a prefix of rt's.

    Row 1 of `rt` is empty while row 1 of `segment_ids` has one entry, so
    segment_ids.shape cannot be a prefix of data.shape.
    """
    rt = ragged.constant([
        [[111, 112, 113, 114], [121]],  # row 0
        [],                             # row 1
        [[], [321, 322], [331]],        # row 2
        [[411, 412]]                    # row 3
    ])  # pyformat: disable
    segment_ids = ragged.constant([[1, 2], [1], [1, 1, 2], [2]])

    # Error is raised at graph-building time if we can detect it then.
    self.assertRaisesRegexp(
        errors.InvalidArgumentError,
        'segment_ids.shape must be a prefix of data.shape.*',
        ragged.segment_sum, rt, segment_ids, 3)

    # Otherwise, error is raised when we run the graph. Passing the components
    # through placeholder_with_default(..., None) hides their static shapes,
    # so the check is deferred to run time.
    segment_ids2 = ragged.from_row_splits(
        array_ops.placeholder_with_default(segment_ids.values, None),
        array_ops.placeholder_with_default(segment_ids.row_splits, None))
    segmented2 = ragged.segment_sum(rt, segment_ids2, 3)
    # cached_session() is the supported replacement for the deprecated
    # test_session().
    with self.cached_session():
      with self.assertRaisesRegexp(
          errors.InvalidArgumentError,
          'segment_ids.shape must be a prefix of data.shape.*'):
        segmented2.eval()
  def testKernelErrors(self):
    """Malformed row_splits make ragged.to_sparse fail when evaluated.

    Each sub-case builds an invalid ragged tensor and checks the
    InvalidArgumentError message produced when its to_sparse result runs.
    cached_session() is used as the supported replacement for the deprecated
    test_session().
    """
    # An empty vector, defined using a placeholder to ensure that we can't
    # determine that it's invalid at graph-construction time.
    empty_vector = array_ops.placeholder_with_default(
        array_ops.zeros([0], dtypes.int64), shape=None)

    # row_splits must start at 0.
    bad_rt1 = ragged.from_row_splits(row_splits=[2, 3], values=[1, 2, 3])
    with self.cached_session():
      bad_split0_error = r'First value of ragged splits must be 0.*'
      self.assertRaisesRegexp(errors.InvalidArgumentError, bad_split0_error,
                              ragged.to_sparse(bad_rt1).eval)

    # Final split value (5) exceeds the number of values (0), both for a flat
    # tensor and when the bad tensor is nested one level down.
    bad_rt2 = ragged.from_row_splits(row_splits=[0, 5], values=empty_vector)
    bad_rt3 = ragged.from_row_splits(
        row_splits=[0, 1],
        values=ragged.from_row_splits(row_splits=[0, 5], values=empty_vector))
    with self.cached_session():
      split_mismatch1_error = r'Final value of ragged splits must match.*'
      for rt in [bad_rt2, bad_rt3]:
        self.assertRaisesRegexp(errors.InvalidArgumentError,
                                split_mismatch1_error,
                                ragged.to_sparse(rt).eval)

    # Outer splits end at 5, but the inner ragged values (row_splits=[0])
    # have no rows.
    bad_rt4 = ragged.from_row_splits(
        row_splits=[0, 5],
        values=ragged.from_row_splits(row_splits=[0], values=empty_vector))
    with self.cached_session():
      split_mismatch2_error = r'Final value of ragged splits must match.*'
      self.assertRaisesRegexp(errors.InvalidArgumentError,
                              split_mismatch2_error,
                              ragged.to_sparse(bad_rt4).eval)

    # row_splits must contain at least one element.
    bad_rt5 = ragged.from_row_splits(row_splits=empty_vector, values=[])
    with self.cached_session():
      empty_splits_error = (r'ragged splits may not be empty.*')
      self.assertRaisesRegexp(errors.InvalidArgumentError, empty_splits_error,
                              ragged.to_sparse(bad_rt5).eval)
Esempio n. 17
0
 def testExplicitAxisOptimizations(self):
   """Checks bounding_shape with an explicit scalar or list `axis` argument.

   The row_splits describe 5 rows whose longest has 3 values.
   """
   rt = ragged.from_row_splits(b'a b c d e f g'.split(), [0, 2, 5, 6, 6, 7])
   for axis, expected in ((0, 5), (1, 3), ([1, 0], [3, 5])):
     actual = self.evaluate(ragged.bounding_shape(rt, axis)).tolist()
     self.assertEqual(actual, expected)