Example 1
0
    def testRaggedRankTwo(self):
        """Checks segment_sum on a ragged tensor with nested ragged rows.

        Rows that share a segment id are added elementwise; a shorter row or
        inner list simply contributes nothing past its end.  Segments that
        receive no rows come back empty.
        """
        row0 = [[111, 112, 113, 114], [121]]
        row1 = []
        row2 = [[], [321, 322], [331]]
        row3 = [[411, 412]]
        rt = ragged_factory_ops.constant([row0, row1, row2, row3])

        # Each case is (segment_ids, expected result for 3 segments).
        cases = [
            # Row 0 alone in segment 0; rows 1-3 merge into segment 2.
            ([0, 2, 2, 2],
             [[[111, 112, 113, 114], [121]],
              [],
              [[411, 412], [321, 322], [331]]]),
            # Rows 0, 2 and 3 merge into segment 1; empty row 1 -> segment 2.
            ([1, 2, 1, 1],
             [[],
              [[111 + 411, 112 + 412, 113, 114], [121 + 321, 322], [331]],
              []]),
        ]  # pyformat: disable
        for ids, want in cases:
            got = ragged_math_ops.segment_sum(rt, ids, 3)
            self.assertAllEqual(got, want)
  def testRaggedRankTwo(self):
    """segment_sum over a 4-row ragged tensor whose rows are ragged lists.

    Rows sharing a segment id are added elementwise (a shorter row or inner
    list contributes nothing past its end); segments that receive no rows
    come out empty.
    """
    rt = ragged_factory_ops.constant([
        [[111, 112, 113, 114], [121]],  # row 0
        [],                             # row 1
        [[], [321, 322], [331]],        # row 2
        [[411, 412]],                   # row 3
    ])  # pyformat: disable

    # Row 0 alone in segment 0; rows 1-3 merge into segment 2.
    merged_tail = ragged_math_ops.segment_sum(rt, [0, 2, 2, 2], 3)
    self.assertRaggedEqual(
        merged_tail,
        [[[111, 112, 113, 114], [121]],
         [],
         [[411, 412], [321, 322], [331]]])  # pyformat: disable

    # Rows 0, 2 and 3 merge into segment 1; only the empty row 1 is routed
    # to segment 2, and nothing reaches segment 0.
    merged_spread = ragged_math_ops.segment_sum(rt, [1, 2, 1, 1], 3)
    self.assertRaggedEqual(
        merged_spread,
        [[],
         [[111 + 411, 112 + 412, 113, 114], [121 + 321, 322], [331]],
         []])  # pyformat: disable
 def testRaggedSegmentIds(self):
   """segment_sum accepts ragged segment_ids that mirror data's outer rows.

   Each id routes one inner list of `rt` to an output segment; lists landing
   in the same segment are added elementwise, and segment 0 stays empty
   because no id maps to it.
   """
   outer_rows = [
       [[111, 112, 113, 114], [121]],
       [],
       [[], [321, 322], [331]],
       [[411, 412]],
   ]  # pyformat: disable
   ids_per_row = [[1, 2], [], [1, 1, 2], [2]]

   rt = ragged_factory_ops.constant(outer_rows)
   segment_ids = ragged_factory_ops.constant(ids_per_row)
   result = ragged_math_ops.segment_sum(rt, segment_ids, 3)

   segment_1 = [111 + 321, 112 + 322, 113, 114]
   segment_2 = [121 + 331 + 411, 412]
   self.assertRaggedEqual(result, [[], segment_1, segment_2])
Example 4
0
 def testRaggedSegmentIds(self):
   """Checks segment_sum when segment_ids is itself a RaggedTensor.

   segment_ids has the same outer row lengths as the data, so every inner
   list of the data gets its own id.  Lists sharing an id are added
   elementwise and the unused segment 0 comes back empty.
   """
   data = ragged_factory_ops.constant([
       [[111, 112, 113, 114], [121]],  # row 0
       [],                             # row 1
       [[], [321, 322], [331]],        # row 2
       [[411, 412]],                   # row 3
   ])  # pyformat: disable
   ids = ragged_factory_ops.constant([[1, 2], [], [1, 1, 2], [2]])
   expected_by_segment = [
       [],                                # segment 0: no ids map here
       [111 + 321, 112 + 322, 113, 114],  # segment 1
       [121 + 331 + 411, 412],            # segment 2
   ]  # pyformat: disable
   self.assertAllEqual(
       ragged_math_ops.segment_sum(data, ids, 3), expected_by_segment)
  def testShapeMismatchError2(self):
    """segment_sum rejects ragged segment_ids whose shape doesn't prefix data.

    Row 1 of `segment_ids` holds one id while row 1 of `rt` is empty, so
    segment_ids.shape is not a prefix of data.shape.  The error must surface
    both when the mismatch is visible while building the graph and when it
    is only detectable at run time.
    """
    rt = ragged_factory_ops.constant([
        [[111, 112, 113, 114], [121]],  # row 0
        [],                             # row 1
        [[], [321, 322], [331]],        # row 2
        [[411, 412]]                    # row 3
    ])  # pyformat: disable
    segment_ids = ragged_factory_ops.constant([[1, 2], [1], [1, 1, 2], [2]])
    expected_error = 'segment_ids.shape must be a prefix of data.shape.*'

    # Error is raised at graph-building time if we can detect it then.
    # assertRaisesRegex is used because the assertRaisesRegexp alias was
    # deprecated in Python 3.2 and removed from unittest in Python 3.12.
    self.assertRaisesRegex(errors.InvalidArgumentError, expected_error,
                           ragged_math_ops.segment_sum, rt, segment_ids, 3)

    # Otherwise, error is raised when we run the graph.  The components are
    # rebuilt through placeholder_with_default(..., shape=None), presumably
    # to hide their static shapes so the check is deferred to evaluation.
    segment_ids2 = ragged_tensor.RaggedTensor.from_row_splits(
        array_ops.placeholder_with_default(segment_ids.values, None),
        array_ops.placeholder_with_default(segment_ids.row_splits, None))
    with self.assertRaisesRegex(errors.InvalidArgumentError, expected_error):
      self.evaluate(ragged_math_ops.segment_sum(rt, segment_ids2, 3))
Example 6
0
  def testShapeMismatchError2(self):
    """segment_sum rejects ragged segment_ids whose shape doesn't prefix data.

    Row 1 of `segment_ids` holds one id while row 1 of `rt` is empty, so
    segment_ids.shape is not a prefix of data.shape.  The error must surface
    both when the mismatch is visible while building the graph and when it
    is only detectable at run time.
    """
    rt = ragged_factory_ops.constant([
        [[111, 112, 113, 114], [121]],  # row 0
        [],                             # row 1
        [[], [321, 322], [331]],        # row 2
        [[411, 412]]                    # row 3
    ])  # pyformat: disable
    segment_ids = ragged_factory_ops.constant([[1, 2], [1], [1, 1, 2], [2]])
    expected_error = 'segment_ids.shape must be a prefix of data.shape.*'

    # Error is raised at graph-building time if we can detect it then.
    # assertRaisesRegex is used because the assertRaisesRegexp alias was
    # deprecated in Python 3.2 and removed from unittest in Python 3.12.
    self.assertRaisesRegex(errors.InvalidArgumentError, expected_error,
                           ragged_math_ops.segment_sum, rt, segment_ids, 3)

    # Otherwise, error is raised when we run the graph.  The components are
    # rebuilt through placeholder_with_default(..., shape=None), presumably
    # to hide their static shapes so the check is deferred to evaluation.
    segment_ids2 = ragged_tensor.RaggedTensor.from_row_splits(
        array_ops.placeholder_with_default(segment_ids.values, None),
        array_ops.placeholder_with_default(segment_ids.row_splits, None))
    with self.assertRaisesRegex(errors.InvalidArgumentError, expected_error):
      self.evaluate(ragged_math_ops.segment_sum(rt, segment_ids2, 3))