Beispiel #1
0
    def testRaggedMap(
        self,
        fn,
        elems,
        expected_output,
        expected_ragged_rank=None,
        result_ragged_rank=None,
        elems_ragged_rank=None,
        dtype=dtypes.int64,
        result_dtype=None,
        infer_shape=False,
    ):
        """Checks `ragged.map_fn(fn, elems)` against `expected_output`.

        Args:
          fn: Function passed to `ragged.map_fn`.
          elems: Nested Python list converted with `ragged.constant`.
          expected_output: Nested Python list the mapped result must equal.
          expected_ragged_rank: `ragged_rank` used when converting
            `expected_output` (None lets `ragged.constant` choose).
          result_ragged_rank: Accepted but not used by this body.
          elems_ragged_rank: `ragged_rank` used when converting `elems`.
          dtype: dtype used when converting `elems`.
          result_dtype: Forwarded to `map_fn` as its `dtype` argument.
          infer_shape: Forwarded to `map_fn`.
        """
        elems = ragged.constant(elems, dtype, elems_ragged_rank)
        output = ragged.map_fn(fn=fn,
                               elems=elems,
                               dtype=result_dtype,
                               infer_shape=infer_shape)

        expected_rt = ragged.constant(expected_output,
                                      ragged_rank=expected_ragged_rank)
        with self.test_session():
            # `expected_output` is a plain Python list, which is never a
            # RaggedTensor, so the rank check must look at the converted
            # `expected_rt`; otherwise it would never run.
            if ragged.is_ragged(expected_rt):
                # Guard first so a dense result fails cleanly instead of
                # raising AttributeError on `.ragged_rank`.
                self.assertTrue(ragged.is_ragged(output))
                self.assertEqual(output.ragged_rank, expected_rt.ragged_rank)
            output_values = self.evaluate(output)
            self.assertAllEqual(expected_output, output_values.tolist())
  def testRaggedMapOnStructure_RaggedOutputs(self):
    """map_fn over a dict of ragged tensors can return a dict of ragged ones."""
    batman = ragged.constant([[1, 2, 3], [4], [5, 6, 7]])
    # Every inner value times 10: [[10, 20, 30], [40], [50, 60, 70]].
    robin = ragged.map_inner_values(mo.multiply, batman, 10)

    structure = {'batman': batman, 'robin': robin}

    def _increment(f):
      # Add one to every inner value of both entries, keeping the dict shape.
      return {key: ragged.add(f[key], 1) for key in ('batman', 'robin')}

    # One distinct int32 ragged_rank=1 spec per output key.
    result_types = {
        key: ragged.RaggedTensorType(dtype=dtypes.int32, ragged_rank=1)
        for key in ('batman', 'robin')
    }
    output = ragged.map_fn(
        fn=_increment,
        elems=structure,
        infer_shape=False,
        dtype=result_types,
    )

    expected = {
        'batman': [[2, 3, 4], [5], [6, 7, 8]],
        'robin': [[11, 21, 31], [41], [51, 61, 71]],
    }
    with self.test_session():
      for key in ('batman', 'robin'):
        self.assertAllEqual(output[key].eval().tolist(), expected[key])
 def testMismatchRaggedRank2(self):
   """map_fn rejects a declared ragged_rank (10) that differs from the result."""
   elems = ragged.constant([[1, 2, 3], [4, 5], [6, 7]])

   def fn(x):
     # Build a RaggedTensor whose single row starts at offset 0.
     return ragged.from_row_starts(x, [0])

   expected_message = r'The declared ragged rank (10) mismatches the result (1)'
   with self.assertRaisesWithLiteralMatch(ValueError, expected_message):
     declared = ragged.RaggedTensorType(dtype=dtypes.int64, ragged_rank=10)
     _ = ragged.map_fn(fn, elems, dtype=declared)
 def testMismatchRaggedRank(self):
   """map_fn rejects a declared ragged_rank (23) that differs from the result."""
   elems = ragged.constant([[[1, 2, 3]], [[4, 5], [6, 7]]])

   def fn(x):
     # Reduce each mapped element along its outermost axis.
     return ragged.reduce_sum(x, axis=0)

   expected_message = r'The declared ragged rank (23) mismatches the result (1)'
   with self.assertRaisesWithLiteralMatch(ValueError, expected_message):
     declared = ragged.RaggedTensorType(dtype=dtypes.int64, ragged_rank=23)
     _ = ragged.map_fn(fn, elems, dtype=declared)
 def testMapOnSparseTensor(self):
     """An identity map_fn preserves a RaggedTensor built from a SparseTensor."""
     sparse = sparse_tensor.SparseTensor(
         indices=[[0, 0], [0, 1], [1, 0], [1, 1]],
         values=[0, 5, 0, 4],
         dense_shape=[2, 2],
     )
     rt = ragged.RaggedTensor.from_sparse(sparse)

     def _identity(row):
         return row

     mapped = ragged.map_fn(_identity, rt)
     self.assertRaggedEqual(mapped, [[0, 5], [0, 4]])
    def testRaggedMapOnStructure(self):
        """map_fn over a dict of ragged tensors can yield one value per row."""
        batman = ragged.constant([[1, 2, 3], [4], [5, 6, 7]])
        # Every inner value times 10: [[10, 20, 30], [40], [50, 60, 70]].
        robin = ragged.map_flat_values(mo.multiply, batman, 10)

        def _reduce_sum_from_all(f):
            # Per-row total across both entries of the structure.
            batman_total = mo.reduce_sum(f['batman'])
            robin_total = mo.reduce_sum(f['robin'])
            return batman_total + robin_total

        output = ragged.map_fn(
            fn=_reduce_sum_from_all,
            elems={'batman': batman, 'robin': robin},
            dtype=dtypes.int32,
        )

        # Row i totals 11 * sum(batman[i]): 11 * 6, 11 * 4, 11 * 18.
        expected = [66, 44, 198]
        self.assertRaggedEqual(output, expected)
  def testRaggedMapOnStructure(self):
    """map_fn over a dict of ragged tensors can yield one value per row."""
    batman = ragged.constant([[1, 2, 3], [4], [5, 6, 7]])
    # Every inner value times 10: [[10, 20, 30], [40], [50, 60, 70]].
    robin = ragged.map_inner_values(mo.multiply, batman, 10)

    def _reduce_sum_from_all(f):
      # Per-row total across both entries of the structure.
      batman_total = mo.reduce_sum(f['batman'])
      robin_total = mo.reduce_sum(f['robin'])
      return batman_total + robin_total

    output = ragged.map_fn(
        fn=_reduce_sum_from_all,
        elems={'batman': batman, 'robin': robin},
        dtype=dtypes.int32,
    )

    # Row i totals 11 * sum(batman[i]): 11 * 6, 11 * 4, 11 * 18.
    expected = [66, 44, 198]
    with self.test_session():
      actual = output.eval().tolist()
      self.assertAllEqual(actual, expected)
    def testBatchGather(self):
        """map_fn gathers from each token row with its matching index row."""
        tokens = ragged.constant([['hello', '.', 'there'], ['merhaba'],
                                  ['bonjour', '.', 'ca va', '?']])
        indices = ragged.constant([[0, 2], [0], [0, 2]])

        def gather(x):
            # `x` holds one row of `tokens` and the matching row of `indices`.
            row_tokens, row_indices = x
            return array_ops.gather(row_tokens, row_indices)

        result_type = ragged.RaggedTensorType(dtype=dtypes.string,
                                              ragged_rank=1)
        out = ragged.map_fn(gather, (tokens, indices),
                            dtype=result_type,
                            infer_shape=False)

        expected = [[b'hello', b'there'], [b'merhaba'], [b'bonjour', b'ca va']]
        self.assertRaggedEqual(out, expected)
  def testBatchGather(self):
    """map_fn gathers from each token row with its matching index row."""
    tokens = ragged.constant([['hello', '.', 'there'], ['merhaba'],
                              ['bonjour', '.', 'ca va', '?']])
    indices = ragged.constant([[0, 2], [0], [0, 2]])

    def gather(x):
      # `x` holds one row of `tokens` and the matching row of `indices`.
      row_tokens, row_indices = x
      return array_ops.gather(row_tokens, row_indices)

    result_type = ragged.RaggedTensorType(dtype=dtypes.string, ragged_rank=1)
    out = ragged.map_fn(
        gather, (tokens, indices), dtype=result_type, infer_shape=False)

    expected = [[b'hello', b'there'], [b'merhaba'], [b'bonjour', b'ca va']]
    with self.test_session():
      self.assertAllEqual(self.evaluate(out).tolist(), expected)
    def testZip(self):
        """map_fn pairs each row's index with every value in that row."""
        x = ragged.constant(
            [[10, 20], [30, 40], [50, 60], [70], [80, 90, 100]], dtypes.int64)
        row_ids = mo.range(x.nrows(), dtype=dtypes.int64)
        y = array_ops.expand_dims(row_ids, axis=1)

        def _zip(pair):
            row_id, row = pair
            # Repeat the row id to the row's length, then stack into
            # [id, value] pairs.
            repeated_id = backend.tile(row_id, array_ops.shape(row))
            return array_ops.stack([repeated_id, row], axis=1)

        result_type = ragged.RaggedTensorType(dtype=dtypes.int64,
                                              ragged_rank=1)
        output = ragged.map_fn(_zip, (y, x),
                               dtype=result_type,
                               infer_shape=False)

        expected = [[[0, 10], [0, 20]], [[1, 30], [1, 40]],
                    [[2, 50], [2, 60]], [[3, 70]],
                    [[4, 80], [4, 90], [4, 100]]]
        self.assertRaggedEqual(output, expected)
  def testZip(self):
    """map_fn pairs each row's index with every value in that row."""
    x = ragged.constant([[10, 20], [30, 40], [50, 60], [70], [80, 90, 100]],
                        dtypes.int64)
    row_ids = mo.range(ragged.nrows(x), dtype=dtypes.int64)
    y = array_ops.expand_dims(row_ids, axis=1)

    def _zip(pair):
      row_id, row = pair
      # Repeat the row id to the row's length, then stack into [id, value].
      repeated_id = backend.tile(row_id, array_ops.shape(row))
      return array_ops.stack([repeated_id, row], axis=1)

    result_type = ragged.RaggedTensorType(dtype=dtypes.int64, ragged_rank=1)
    output = ragged.map_fn(
        _zip, (y, x), dtype=result_type, infer_shape=False)

    expected = [[[0, 10], [0, 20]], [[1, 30], [1, 40]], [[2, 50], [2, 60]],
                [[3, 70]], [[4, 80], [4, 90], [4, 100]]]
    with self.test_session():
      result = self.evaluate(output).tolist()
      self.assertAllEqual(result, expected)
    def testRaggedMap(
        self,
        fn,
        elems,
        expected_output,
        expected_ragged_rank=None,
        result_ragged_rank=None,
        elems_ragged_rank=None,
        dtype=dtypes.int64,
        result_dtype=None,
        infer_shape=False,
    ):
        """Checks `ragged.map_fn(fn, elems)` against `expected_output`.

        `elems` and `expected_output` are nested Python lists converted with
        `ragged.constant` (using `dtype`/`elems_ragged_rank` and
        `expected_ragged_rank` respectively). `result_dtype` and
        `infer_shape` are forwarded to `map_fn`. `result_ragged_rank` is
        accepted but not used here.
        """
        rt_elems = ragged.constant(elems, dtype, elems_ragged_rank)
        mapped = ragged.map_fn(
            fn=fn, elems=rt_elems, dtype=result_dtype, infer_shape=infer_shape)

        expected_rt = ragged.constant(
            expected_output, ragged_rank=expected_ragged_rank)
        self.assertRaggedEqual(expected_rt, mapped)
  def testRaggedMap(
      self,
      fn,
      elems,
      expected_output,
      expected_ragged_rank=None,
      result_ragged_rank=None,
      elems_ragged_rank=None,
      dtype=dtypes.int64,
      result_dtype=None,
      infer_shape=False,
  ):
    """Checks `ragged.map_fn(fn, elems)` against `expected_output`.

    Args:
      fn: Function passed to `ragged.map_fn`.
      elems: Nested Python list converted with `ragged.constant`.
      expected_output: Nested Python list the mapped result must equal.
      expected_ragged_rank: `ragged_rank` used when converting
        `expected_output` (None lets `ragged.constant` choose).
      result_ragged_rank: Accepted but not used by this body.
      elems_ragged_rank: `ragged_rank` used when converting `elems`.
      dtype: dtype used when converting `elems`.
      result_dtype: Forwarded to `map_fn` as its `dtype` argument.
      infer_shape: Forwarded to `map_fn`.
    """
    elems = ragged.constant(elems, dtype, elems_ragged_rank)
    output = ragged.map_fn(
        fn=fn, elems=elems, dtype=result_dtype, infer_shape=infer_shape)

    expected_rt = ragged.constant(
        expected_output, ragged_rank=expected_ragged_rank)
    with self.test_session():
      # `expected_output` is a plain Python list, which is never a
      # RaggedTensor, so the rank check must look at the converted
      # `expected_rt`; otherwise it would never run.
      if ragged.is_ragged(expected_rt):
        # Guard first so a dense result fails cleanly instead of raising
        # AttributeError on `.ragged_rank`.
        self.assertTrue(ragged.is_ragged(output))
        self.assertEqual(output.ragged_rank, expected_rt.ragged_rank)
      output_values = self.evaluate(output)
      self.assertAllEqual(expected_output, output_values.tolist())