Ejemplo n.º 1
0
    def testRaggedRankTwo(self):
        """Checks segment_sum on a rank-2 ragged tensor for two row mappings.

        Rows sent to the same segment are added position by position, with
        shorter rows behaving as if zero-padded (e.g. 113 and 114 pass through
        unchanged below).  A segment that receives no rows comes out empty.
        """
        data = ragged.constant([
            [[111, 112, 113, 114], [121]],  # row 0
            [],  # row 1
            [[], [321, 322], [331]],  # row 2
            [[411, 412]]  # row 3
        ])  # pyformat: disable
        cases = [
            # Segment 0 <- row 0; segment 1 <- nothing; segment 2 <- rows 1-3.
            ([0, 2, 2, 2],
             [[[111, 112, 113, 114], [121]],  # segment 0
              [],  # segment 1
              [[411, 412], [321, 322], [331]]]),  # segment 2
            # Segment 1 <- rows 0, 2 and 3; segment 2 <- the empty row 1.
            ([1, 2, 1, 1],
             [[],
              [[111 + 411, 112 + 412, 113, 114], [121 + 321, 322], [331]],
              []]),
        ]  # pyformat: disable
        for row_to_segment, want in cases:
            self.assertRaggedEqual(
                ragged.segment_sum(data, row_to_segment, 3), want)
    def testShapeMismatchError2(self):
        """Checks that segment_sum rejects segment_ids not shaped like data.

        Row 1 of `segment_ids` holds one id while row 1 of `rt` is empty, so
        segment_ids.shape is not a prefix of data.shape.  The error must be
        raised at graph-building time when the mismatch is statically visible,
        and at run time when the static shapes are hidden.
        """
        rt = ragged.constant([
            [[111, 112, 113, 114], [121]],  # row 0
            [],  # row 1
            [[], [321, 322], [331]],  # row 2
            [[411, 412]]  # row 3
        ])  # pyformat: disable
        segment_ids = ragged.constant([[1, 2], [1], [1, 1, 2], [2]])
        expected_msg = 'segment_ids.shape must be a prefix of data.shape.*'

        # Error is raised at graph-building time if we can detect it then.
        # (assertRaisesRegex replaces the deprecated assertRaisesRegexp alias,
        # which was removed from unittest in Python 3.12.)
        with self.assertRaisesRegex(errors.InvalidArgumentError, expected_msg):
            ragged.segment_sum(rt, segment_ids, 3)

        # Otherwise, error is raised when we run the graph.  Passing shape=None
        # to placeholder_with_default leaves the component shapes unknown, so
        # the mismatch cannot be detected while building the graph.
        segment_ids2 = ragged.from_row_splits(
            array_ops.placeholder_with_default(segment_ids.values, None),
            array_ops.placeholder_with_default(segment_ids.row_splits, None))
        segmented2 = ragged.segment_sum(rt, segment_ids2, 3)
        # cached_session replaces the deprecated test_session.
        with self.cached_session():
            with self.assertRaisesRegex(
                    errors.InvalidArgumentError, expected_msg):
                segmented2.eval()
Ejemplo n.º 3
0
  def testRaggedRankTwo(self):
    """Checks segment_sum on a rank-2 ragged tensor for two row mappings.

    Rows mapped to the same segment are summed position by position (shorter
    rows behave as if zero-padded); segments that receive no rows are empty.
    """
    data = ragged.constant([
        [[111, 112, 113, 114], [121]],  # row 0
        [],                             # row 1
        [[], [321, 322], [331]],        # row 2
        [[411, 412]]                    # row 3
    ])  # pyformat: disable

    def check(row_to_segment, want):
      result = ragged.segment_sum(data, row_to_segment, 3)
      self.assertEqual(self.evaluate(result).tolist(), want)

    # Segment 0 <- row 0; segment 1 <- nothing; segment 2 <- rows 1, 2, 3.
    check([0, 2, 2, 2],
          [[[111, 112, 113, 114], [121]],
           [],
           [[411, 412], [321, 322], [331]]])  # pyformat: disable
    # Segment 1 <- rows 0, 2, 3; segment 2 <- the empty row 1.
    check([1, 2, 1, 1],
          [[],
           [[111+411, 112+412, 113, 114], [121+321, 322], [331]],
           []])  # pyformat: disable
 def testRaggedSegmentIds(self):
   """Checks segment_sum with ragged segment_ids, one id per inner list.

   Each inner list of `data` goes to the segment named by the matching entry
   of `segment_ids`; lists sharing a segment are summed position by position
   (shorter lists behave as if zero-padded), and segment 0 gets no inputs.
   """
   data = ragged.constant([
       [[111, 112, 113, 114], [121]],  # row 0
       [],                             # row 1
       [[], [321, 322], [331]],        # row 2
       [[411, 412]]                    # row 3
   ])  # pyformat: disable
   # Same row lengths as data's first two dimensions (2, 0, 3, 1).
   ids = ragged.constant([[1, 2], [], [1, 1, 2], [2]])
   want = [
       [],                            # segment 0: no inputs
       [111+321, 112+322, 113, 114],  # segment 1: [111..114] + [] + [321, 322]
       [121+331+411, 412],            # segment 2: [121] + [331] + [411, 412]
   ]  # pyformat: disable
   got = self.evaluate(ragged.segment_sum(data, ids, 3))
   self.assertEqual(got.tolist(), want)
Ejemplo n.º 5
0
 def testRaggedSegmentIds(self):
   """Checks segment_sum when segment_ids is itself a ragged tensor.

   segment_ids assigns a segment to every inner list of `values`.  Lists that
   share a segment are added elementwise, shorter ones behaving as if
   zero-padded, and the unused segment 0 comes back as an empty list.
   """
   values = ragged.constant([
       [[111, 112, 113, 114], [121]],  # row 0
       [],                             # row 1
       [[], [321, 322], [331]],        # row 2
       [[411, 412]]                    # row 3
   ])  # pyformat: disable
   seg_ids = ragged.constant([[1, 2], [], [1, 1, 2], [2]])
   summed = ragged.segment_sum(values, seg_ids, 3)
   self.assertEqual(
       self.evaluate(summed).tolist(),
       [[],                            # segment 0
        [111+321, 112+322, 113, 114],  # segment 1
        [121+331+411, 412]])           # segment 2  # pyformat: disable
  def testShapeMismatchError2(self):
    """Checks that segment_sum rejects segment_ids not shaped like data.

    Row 1 of `segment_ids` holds one id while row 1 of `rt` is empty, so
    segment_ids.shape is not a prefix of data.shape.  The error must be
    raised at graph-building time when the mismatch is statically visible,
    and at run time when the static shapes are hidden.
    """
    rt = ragged.constant([
        [[111, 112, 113, 114], [121]],  # row 0
        [],                             # row 1
        [[], [321, 322], [331]],        # row 2
        [[411, 412]]                    # row 3
    ])  # pyformat: disable
    segment_ids = ragged.constant([[1, 2], [1], [1, 1, 2], [2]])
    expected_msg = 'segment_ids.shape must be a prefix of data.shape.*'

    # Error is raised at graph-building time if we can detect it then.
    # (assertRaisesRegex replaces the deprecated assertRaisesRegexp alias,
    # which was removed from unittest in Python 3.12.)
    with self.assertRaisesRegex(errors.InvalidArgumentError, expected_msg):
      ragged.segment_sum(rt, segment_ids, 3)

    # Otherwise, error is raised when we run the graph.  Passing shape=None to
    # placeholder_with_default leaves the component shapes unknown, so the
    # mismatch cannot be detected while building the graph.
    segment_ids2 = ragged.from_row_splits(
        array_ops.placeholder_with_default(segment_ids.values, None),
        array_ops.placeholder_with_default(segment_ids.row_splits, None))
    segmented2 = ragged.segment_sum(rt, segment_ids2, 3)
    # cached_session replaces the deprecated test_session.
    with self.cached_session():
      with self.assertRaisesRegex(errors.InvalidArgumentError, expected_msg):
        segmented2.eval()