def test_dataset_with_extra_test(self):
   """Sums Dataset.range(5) while the running total stays below 3."""
   def below_three(s):
     return s < 3

   def add(i, s):
     return (s + i,)

   numbers = dataset_ops.Dataset.range(5)
   s = control_flow.for_stmt(
       numbers,
       extra_test=below_three,
       body=add,
       init_state=(constant_op.constant(0, dtype=dtypes.int64),))
   # Adds 0, 1, 2; the total reaches 3 and extra_test ends the loop.
   self.assertEqual(self.evaluate(s), (3,))
 def test_python(self):
   """Sums a plain Python range(5); expects the state tuple (10,)."""
   def always(s):
     return True

   def add(i, s):
     return (s + i,)

   result = control_flow.for_stmt(
       range(5),
       extra_test=always,
       body=add,
       init_state=(0,))
   self.assertEqual((10,), result)
 def test_fn():
   """Folds the digits from an iterator over Dataset.range(5) into one number."""
   ds_iterator = iter(dataset_ops.Dataset.range(5))

   def append_digit(i, s):
     return (s * 10 + i,)

   return control_flow.for_stmt(
       ds_iterator,
       extra_test=None,
       body=append_digit,
       init_state=(constant_op.constant(0, dtype=dtypes.int64),))
 def test_tensor(self):
   """Folds the tensor [1, 2, 3, 4] into 1234 inside an explicit graph."""
   def append_digit(i, s):
     return (s * 10 + i,)

   with ops.Graph().as_default():
     digits = constant_op.constant([1, 2, 3, 4])
     s = control_flow.for_stmt(
         digits,
         extra_test=lambda s: True,
         body=append_digit,
         init_state=(0,))
     self.assertEqual(self.evaluate(s), (1234,))
 def test_tf_dataset(self):
   """Folds Dataset.range(5) into 1234 inside an explicit graph, no extra_test."""
   def append_digit(i, s):
     return (s * 10 + i,)

   with ops.Graph().as_default():
     numbers = dataset_ops.Dataset.range(5)
     s = control_flow.for_stmt(
         numbers,
         extra_test=None,
         body=append_digit,
         init_state=(constant_op.constant(0, dtype=dtypes.int64),))
     # Digits 0..4 fold to 01234, i.e. 1234.
     self.assertEqual(self.evaluate(s), (1234,))
 def test_tensor(self):
   """Sums the elements of the tensor [1, 2, 3, 4]; expects (10,)."""
   values = constant_op.constant([1, 2, 3, 4])

   def add(i, s):
     return (s + i,)

   s = control_flow.for_stmt(
       values,
       extra_test=lambda s: True,
       body=add,
       init_state=(0,))
   with self.cached_session():
     self.assertEqual((10,), self.evaluate(s))
 def test_dataset(self):
   """Sums Dataset.range(5) after mapping elements to int32; expects (10,)."""
   def to_int32(i):
     return math_ops.cast(i, dtypes.int32)

   int32_numbers = dataset_ops.Dataset.range(5).map(to_int32)
   s = control_flow.for_stmt(
       int32_numbers,
       extra_test=lambda s: True,
       body=lambda i, s: (s + i,),
       init_state=(0,))
   with self.cached_session():
     self.assertEqual((10,), self.evaluate(s))
 def run_loop():
   """Runs for_stmt over `gen` with an extra_test that is False from the start.

   `gen` comes from the enclosing scope (not visible here).
   """
   def never(s):
     # Break before loop: the condition fails before the first iteration.
     return False

   def append_digit(i, s):
     return (s * 10 + i,)

   return control_flow.for_stmt(
       gen,
       extra_test=never,
       body=append_digit,
       get_state=None,
       set_state=None,
       init_vars=(0,),
       basic_symbol_names=('s',),
       composite_symbol_names=(),
       opts={})
 def run_loop():
   """Runs for_stmt over `gen`, stopping after one iteration via a counter.

   `gen` comes from the enclosing scope (not visible here).
   """
   def first_only(s, c):
     # c counts completed iterations; break after the first one.
     return c == 0

   def step(i, s, c):
     return (s * 10 + i, c + 1)

   return control_flow.for_stmt(
       gen,
       extra_test=first_only,
       body=step,
       get_state=None,
       set_state=None,
       init_vars=(0, 0),
       basic_symbol_names=('s', 'c'),
       composite_symbol_names=(),
       opts={})
Example #10
0
 def test_range_tensor_explicit_limit_delta(self):
     """Folds range(-17, -3, 5) = -17, -12, -7 as base-100 digits into -171207."""

     def shift_in(i, s):
         return (s * 100 + i, )

     sequence = math_ops.range(-17, -3, 5)
     s = control_flow.for_stmt(
         sequence,
         extra_test=lambda s: True,
         body=shift_in,
         get_state=lambda: (),
         set_state=lambda _: None,
         init_vars=(0, ),
         basic_symbol_names=('s', ),
         composite_symbol_names=(),
         opts={})
     self.assertEqual(self.evaluate(s), (-171207, ))
Example #11
0
 def test_tensor(self):
     """Folds the tensor [1, 2, 3, 4] into (1234,) using the init_vars API."""

     def append_digit(i, s):
         return (s * 10 + i, )

     digits = constant_op.constant([1, 2, 3, 4])
     s = control_flow.for_stmt(
         digits,
         extra_test=lambda s: True,
         body=append_digit,
         get_state=lambda: (),
         set_state=lambda _: None,
         init_vars=(0, ),
         basic_symbol_names=('s', ),
         composite_symbol_names=(),
         opts={})
     self.assertEqual(self.evaluate(s), (1234, ))
Example #12
0
 def test_python(self):
     """Folds a Python range(5) into (1234,) using the init_vars API."""

     def append_digit(i, s):
         return (s * 10 + i, )

     s = control_flow.for_stmt(
         range(5),
         extra_test=lambda s: True,
         body=append_digit,
         get_state=None,
         set_state=None,
         init_vars=(0, ),
         basic_symbol_names=('s', ),
         composite_symbol_names=(),
         opts={})
     self.assertEqual(s, (1234, ))
    def test_dataset_with_extra_test(self):
        """Folds Dataset.range(5) into s while s < 3; the loop stops at 12."""
        s = constant_op.constant(0, dtype=dtypes.int64)

        def body(i):
            nonlocal s
            s = s * 10 + i

        def get_state():
            return (s, )

        def set_state(loop_vars):
            nonlocal s
            (s, ) = loop_vars

        def extra_test():
            return s < 3

        control_flow.for_stmt(dataset_ops.Dataset.range(5),
                              extra_test=extra_test,
                              body=body,
                              get_state=get_state,
                              set_state=set_state,
                              symbol_names=('s', ),
                              opts={})

        # s goes 0 -> 0 -> 1 -> 12, at which point s < 3 fails.
        self.assertEqual(s, (12, ))
        self.assertOpCreated('ScanDataset')
 def test_dataset_no_state(self):
   """A loop with empty state whose body only updates a variable; expects 10."""
   v = variables.Variable(0, dtype=dtypes.int64)

   def stateless_with_side_effects(i):
     # No loop state is carried; progress is visible only through v.
     v.assign(v.read_value() + i)

   loop = control_flow.for_stmt(
       dataset_ops.Dataset.range(5),
       extra_test=None,
       body=stateless_with_side_effects,
       init_state=())
   self.evaluate(loop)
   self.assertEqual(self.evaluate(v.read_value()), 10)
    def test_tf_iterator(self):
        """Folds an iterator over Dataset.range(5) into s == 1234."""
        s = constant_op.constant(0, dtype=dtypes.int64)

        def body(i):
            nonlocal s
            s = s * 10 + i

        def get_state():
            return (s, )

        def set_state(loop_vars):
            nonlocal s
            (s, ) = loop_vars

        numbers = iter(dataset_ops.Dataset.range(5))
        control_flow.for_stmt(numbers,
                              extra_test=None,
                              body=body,
                              get_state=get_state,
                              set_state=set_state,
                              symbol_names=('s', ),
                              opts={})

        self.assertEqual(s, 1234)
        # The loop is expected to emit an IteratorGetNextAsOptional op.
        self.assertOpCreated('IteratorGetNextAsOptional')
Example #16
0
    def test_python(self):
        """A Python range(5) loop folds into s == 1234 and creates no TF ops."""
        s = 0

        def body(i):
            nonlocal s
            s = s * 10 + i

        def get_state():
            return (s, )

        def set_state(loop_vars):
            nonlocal s
            (s, ) = loop_vars

        def extra_test():
            return True

        control_flow.for_stmt(range(5),
                              extra_test=extra_test,
                              body=body,
                              get_state=get_state,
                              set_state=set_state,
                              symbol_names=('s', ),
                              opts={})

        self.assertEqual(s, 1234)
        self.assertNoOpsCreated()
Example #17
0
    def test_range_tensor_explicit_limit_negative_delta(self):
        """Folds range(17, 3, -5) = 17, 12, 7 as base-100 digits into 171207."""
        s = 0

        def body(i):
            nonlocal s
            s = s * 100 + i

        def get_state():
            return (s, )

        def set_state(loop_vars):
            nonlocal s
            (s, ) = loop_vars

        countdown = math_ops.range(17, 3, -5)
        control_flow.for_stmt(countdown,
                              extra_test=lambda: True,
                              body=body,
                              get_state=get_state,
                              set_state=set_state,
                              symbol_names=('s', ),
                              opts={'iterate_names': 'i'})

        self.assertEqual(s, (171207, ))
        self.assertOpCreated('StatelessWhile')
    def test_dataset_with_extra_test_iteration_limiting(self):
        """Checks that extra_test stops the dataset loop once i reaches 3."""
        i = constant_op.constant(0, dtype=dtypes.int64)

        def body(it):
            nonlocal i
            # The Assert fails if the body is entered while i >= 3.
            with ops.control_dependencies(
                    (control_flow_ops.Assert(i < 3, (i, )), )):
                i = it

        def get_state():
            return (i, )

        def set_state(loop_vars):
            nonlocal i
            (i, ) = loop_vars

        def extra_test():
            return i < 3

        control_flow.for_stmt(dataset_ops.Dataset.range(5),
                              extra_test=extra_test,
                              body=body,
                              get_state=get_state,
                              set_state=set_state,
                              symbol_names=('i', ),
                              opts={})
        self.assertEqual(self.evaluate(i), (3, ))
    def test_dataset_with_extra_test_collection_vars(self):
        """Loop state mixes a list element (l[0]) with a plain local (s)."""
        s = constant_op.constant(0, dtype=dtypes.int64)
        l = [constant_op.constant(0, dtype=dtypes.int64)]

        def body(i):
            nonlocal s
            l[0] += i
            s += i

        def get_state():
            return (l[0], s)

        def set_state(loop_vars):
            nonlocal s
            l[0], s = loop_vars

        def extra_test():
            return s < 3

        control_flow.for_stmt(dataset_ops.Dataset.range(5),
                              extra_test=extra_test,
                              body=body,
                              get_state=get_state,
                              set_state=set_state,
                              symbol_names=('l[0]', 's'),
                              opts={})
        # Both accumulators add 0 + 1 + 2 before s < 3 fails.
        self.assertEqual(self.evaluate((l[0], s)), (3, 3))
    def test_tf_ragged_tensor(self):
        """Iterates the rows of a ragged tensor, folding each row's first item."""
        s = 0

        def body(i):
            nonlocal s
            s = s * 10 + i[0]

        def get_state():
            return (s, )

        def set_state(loop_vars):
            nonlocal s
            (s, ) = loop_vars

        # Row heads are 1, 2, 3, which fold into 123.
        rows = ragged_factory_ops.constant([[1], [2, 4], [3]])
        control_flow.for_stmt(rows,
                              extra_test=None,
                              body=body,
                              get_state=get_state,
                              set_state=set_state,
                              symbol_names=('s', ),
                              opts={})

        self.assertEqual(s, (123, ))
        self.assertOpCreated('StatelessWhile')
        def test_fn():
            """Appends each element of Dataset.range(5) to a growing 1-D tensor."""
            s = constant_op.constant([], dtype=dtypes.int64)

            def body(i):
                nonlocal s
                s = array_ops.concat([s, [i]], 0)

            def set_state(loop_vars):
                nonlocal s
                (s, ) = loop_vars

            # s grows on every iteration, so its length is declared unknown.
            invariants = [(s, tensor_shape.TensorShape([None]))]
            control_flow.for_stmt(iter(dataset_ops.Dataset.range(5)),
                                  extra_test=None,
                                  body=body,
                                  get_state=lambda: (s, ),
                                  set_state=set_state,
                                  symbol_names=('s', ),
                                  opts={'shape_invariants': invariants})
            return s
Example #22
0
 def test_dataset_with_extra_test(self):
   """Sums Dataset.range(5) while the total is below 3; expects (3,)."""
   def below_three(s):
     return s < 3

   def add(i, s):
     return (s + i,)

   s = control_flow.for_stmt(
       dataset_ops.Dataset.range(5),
       extra_test=below_three,
       body=add,
       get_state=lambda: (),
       set_state=lambda _: None,
       init_vars=(constant_op.constant(0, dtype=dtypes.int64),),
       basic_symbol_names=('s',),
       composite_symbol_names=(),
       opts={})
   self.assertEqual(self.evaluate(s), (3,))
Example #23
0
 def test_range_tensor_random_negative_delta(self):
     """Uses a random-tensor step for range(17, 3, step); expects (171207,)."""
     # The step is drawn from [-5, -4); the expected value assumes it is -5.
     random_neg_five = random_ops.random_uniform((), -5, -4, dtype=dtypes.int32)

     def shift_in(i, s):
         return (s * 100 + i, )

     countdown = math_ops.range(17, 3, random_neg_five)
     s = control_flow.for_stmt(countdown,
                               extra_test=lambda s: True,
                               body=shift_in,
                               get_state=lambda: (),
                               set_state=lambda _: None,
                               init_vars=(0, ))
     self.assertEqual(self.evaluate(s), (171207, ))
    def test_tf_iterator_shape_invariants(self):
        """A loop variable that grows in length needs a relaxed shape invariant."""
        s = constant_op.constant([], dtype=dtypes.int64)

        def body(i):
            nonlocal s
            s = array_ops.concat([s, [i]], 0)

        def get_state():
            return (s, )

        def set_state(loop_vars):
            nonlocal s
            (s, ) = loop_vars

        invariants = [(s, tensor_shape.TensorShape([None]))]
        control_flow.for_stmt(iter(dataset_ops.Dataset.range(5)),
                              extra_test=None,
                              body=body,
                              get_state=get_state,
                              set_state=set_state,
                              symbol_names=('s', ),
                              opts={'shape_invariants': invariants})

        self.assertAllEqual(s, [0, 1, 2, 3, 4])
        self.assertOpCreated('IteratorGetNextAsOptional')
    def test_range_tensor_random_delta(self):
        """Uses a random-tensor step for range(0, 5, step); expects s == 1234."""
        s = 0

        def body(i):
            nonlocal s
            s = s * 10 + i

        def get_state():
            return (s, )

        def set_state(loop_vars):
            nonlocal s
            (s, ) = loop_vars

        # The step is drawn from [1, 2); the expected value assumes it is 1.
        random_one = random_ops.random_uniform((), 1, 2, dtype=dtypes.int32)
        control_flow.for_stmt(math_ops.range(0, 5, random_one),
                              extra_test=lambda: True,
                              body=body,
                              get_state=get_state,
                              set_state=set_state,
                              symbol_names=('s', ),
                              opts={'iterate_names': 'i'})

        self.assertEqual(s, (1234, ))
        self.assertOpCreated('StatelessWhile')
Example #26
0
    def test_dataset_no_state(self):
        """A loop with empty state still runs; its effect shows up only in v."""
        v = variables.Variable(0, dtype=dtypes.int64)

        def stateless_with_side_effects(i):
            v.assign(v.read_value() + i)

        loop = control_flow.for_stmt(dataset_ops.Dataset.range(5),
                                     extra_test=None,
                                     body=stateless_with_side_effects,
                                     init_state=())
        self.evaluate(loop)
        # 0 + 1 + 2 + 3 + 4
        self.assertEqual(self.evaluate(v.read_value()), 10)
  def test_dataset_with_extra_test_no_extra_iterations(self):
    """Checks that extra_test stops the loop before the body sees i >= 3."""

    def guarded_body(i, s):
      # The Assert fails if the body is ever invoked with i >= 3.
      guard = control_flow_ops.Assert(i < 3, (i,))
      with ops.control_dependencies((guard,)):
        return (s + i,)

    def below_three(s):
      return s < 3

    s = control_flow.for_stmt(
        dataset_ops.Dataset.range(5),
        extra_test=below_three,
        body=guarded_body,
        init_state=(constant_op.constant(0, dtype=dtypes.int64),))
    self.assertEqual(self.evaluate(s), (3,))
Example #28
0
 def test_range_tensor_random_delta(self):
     """Uses a random-tensor step for range(0, 5, step); expects (1234,)."""
     # The step is drawn from [1, 2); the expected value assumes it is 1.
     random_one = random_ops.random_uniform((), 1, 2, dtype=dtypes.int32)

     def append_digit(i, s):
         return (s * 10 + i, )

     sequence = math_ops.range(0, 5, random_one)
     s = control_flow.for_stmt(sequence,
                               extra_test=lambda s: True,
                               body=append_digit,
                               get_state=lambda: (),
                               set_state=lambda _: None,
                               init_vars=(0, ),
                               basic_symbol_names=('s', ),
                               composite_symbol_names=(),
                               opts={})
     self.assertEqual(self.evaluate(s), (1234, ))
Example #29
0
 def test_tf_ragged_tensor(self):
   """Folds the first item of each ragged row (1, 2, 3) into (123,)."""
   def append_row_head(i, s):
     return (s * 10 + i[0],)

   rows = ragged_factory_ops.constant([[1], [2, 4], [3]])
   s = control_flow.for_stmt(
       rows,
       extra_test=lambda s: True,
       body=append_row_head,
       get_state=lambda: (),
       set_state=lambda _: None,
       init_vars=(0,),
       basic_symbol_names=('s',),
       composite_symbol_names=(),
       opts={})
   self.assertEqual(self.evaluate(s), (123,))
    def test_dataset_with_extra_test_no_extra_iterations(self):
        """Checks that extra_test stops the loop before the body sees i >= 3."""

        def guarded_body(i, s):
            # The Assert fails if the body is ever invoked with i >= 3.
            guard = control_flow_ops.Assert(i < 3, (i, ))
            with ops.control_dependencies((guard, )):
                return (s + i, )

        def below_three(s):
            return s < 3

        s = control_flow.for_stmt(
            dataset_ops.Dataset.range(5),
            extra_test=below_three,
            body=guarded_body,
            init_state=(constant_op.constant(0, dtype=dtypes.int64), ))
        self.assertEqual(self.evaluate(s), (3, ))
Example #31
0
 def test_fn():
   """Folds the digits from an iterator over Dataset.range(5) via init_vars."""
   ds_iterator = iter(dataset_ops.Dataset.range(5))

   def append_digit(i, s):
     return (s * 10 + i,)

   return control_flow.for_stmt(
       ds_iterator,
       extra_test=None,
       body=append_digit,
       get_state=lambda: (),
       set_state=lambda _: None,
       init_vars=(constant_op.constant(0, dtype=dtypes.int64),),
       basic_symbol_names=('s',),
       composite_symbol_names=(),
       opts={})
Example #32
0
    def test_range_tensor_random_negative_delta(self):
        """Uses a random-tensor step for range(17, 3, step); expects 171207."""
        s = 0

        def body(i):
            nonlocal s
            s = s * 100 + i

        def get_state():
            return (s, )

        def set_state(loop_vars):
            nonlocal s
            (s, ) = loop_vars

        # The step is drawn from [-5, -4); the expected value assumes it is -5.
        random_neg_five = random_ops.random_uniform((), -5, -4,
                                                    dtype=dtypes.int32)
        control_flow.for_stmt(math_ops.range(17, 3, random_neg_five),
                              extra_test=lambda: True,
                              body=body,
                              get_state=get_state,
                              set_state=set_state,
                              symbol_names=('s', ),
                              opts={'iterate_names': 'i'})
        self.assertEqual(s, (171207, ))
    def test_tf_ragged_tensor_higher_dimensional(self):
        """Iterates a 3-D ragged tensor and folds each slice's [0][0] into 12."""
        s = 0

        def body(i):
            nonlocal s
            s = s * 10 + i[0][0]

        def get_state():
            return (s, )

        def set_state(loop_vars):
            nonlocal s
            (s, ) = loop_vars

        # Two outer slices whose [0][0] entries are 1 and 2.
        ragged_3d = [
            [[1], [1, 1], [1]],
            [[2], [2]],
        ]
        slices = ragged_factory_ops.constant(ragged_3d)
        control_flow.for_stmt(slices,
                              extra_test=None,
                              body=body,
                              get_state=get_state,
                              set_state=set_state,
                              symbol_names=('s', ),
                              opts={})
        self.assertEqual(self.evaluate(s), (12, ))
    def test_range_tensor_random_delta(self):
        """Random-tensor step for range(0, 5, step) inside an explicit graph."""

        def append_digit(i, s):
            return (s * 10 + i, )

        with ops.Graph().as_default():
            # The step is drawn from [1, 2); the expected value assumes it is 1.
            random_one = random_ops.random_uniform((), 1, 2,
                                                   dtype=dtypes.int32)
            sequence = math_ops.range(0, 5, random_one)
            s = control_flow.for_stmt(sequence,
                                      extra_test=lambda s: True,
                                      body=append_digit,
                                      get_state=lambda: (),
                                      set_state=lambda _: None,
                                      init_vars=(0, ))
            self.assertEqual(self.evaluate(s), (1234, ))
Example #35
0
 def test_tf_ragged_tensor_higher_dimensional(self):
     """Folds the [0][0] entry of each outer slice of a 3-D ragged tensor."""

     def append_corner(i, s):
         return (s * 10 + i[0][0], )

     # Two outer slices whose [0][0] entries are 1 and 2 -> 12.
     ragged_3d = [
         [[1], [1, 1], [1]],
         [[2], [2]],
     ]
     s = control_flow.for_stmt(ragged_factory_ops.constant(ragged_3d),
                               extra_test=lambda s: True,
                               body=append_corner,
                               get_state=lambda: (),
                               set_state=lambda _: None,
                               init_vars=(0, ),
                               basic_symbol_names=('s', ),
                               composite_symbol_names=(),
                               opts={})
     self.assertEqual(self.evaluate(s), (12, ))
Example #36
0
  def test_dataset_with_extra_test_no_extra_iterations(self):
    """Checks that extra_test stops the loop before the body sees i >= 3."""

    def guarded_body(i, s):
      # The Assert fails if the body is ever invoked with i >= 3.
      guard = control_flow_ops.Assert(i < 3, (i,))
      with ops.control_dependencies((guard,)):
        return (s + i,)

    def below_three(s):
      return s < 3

    s = control_flow.for_stmt(
        dataset_ops.Dataset.range(5),
        extra_test=below_three,
        body=guarded_body,
        get_state=lambda: (),
        set_state=lambda _: None,
        init_vars=(constant_op.constant(0, dtype=dtypes.int64),),
        basic_symbol_names=('s',),
        composite_symbol_names=(),
        opts={})
    self.assertEqual(self.evaluate(s), (3,))
    def test_dataset_with_extra_test_and_state(self):
        """Combines a returned loop var (s) with get/set-managed state[0]."""
        state = [constant_op.constant(0, dtype=dtypes.int64)]

        def get_state():
            return (state[0], )

        def set_state(new_state):
            (state[0], ) = new_state

        def body(i, s):
            state[0] += i
            return (s + i, )

        def below_three(s):
            return s < 3

        result = control_flow.for_stmt(
            dataset_ops.Dataset.range(5),
            extra_test=below_three,
            body=body,
            get_state=get_state,
            set_state=set_state,
            init_vars=(constant_op.constant(0, dtype=dtypes.int64), ))
        # Both the loop var and the side state accumulate 0 + 1 + 2.
        self.assertEqual(self.evaluate(result), (3, ))
        self.assertEqual(self.evaluate(state[0]), (3, ))
 def test_fn():
     """Runs stateless_with_side_effects over Dataset.range(5); no return value.

     `stateless_with_side_effects` comes from the enclosing scope.
     """
     numbers = dataset_ops.Dataset.range(5)
     control_flow.for_stmt(numbers,
                           extra_test=None,
                           body=stateless_with_side_effects,
                           init_state=())
 def test_python(self):
     """Sums a plain Python range(5); expects the state tuple (10, )."""

     def add(i, s):
         return (s + i, )

     result = control_flow.for_stmt(range(5),
                                    extra_test=lambda s: True,
                                    body=add,
                                    init_state=(0, ))
     self.assertEqual((10, ), result)
 def test_fn():
   """Runs stateless_with_side_effects over Dataset.range(5); no return value.

   `stateless_with_side_effects` comes from the enclosing scope.
   """
   numbers = dataset_ops.Dataset.range(5)
   control_flow.for_stmt(
       numbers,
       extra_test=None,
       body=stateless_with_side_effects,
       init_state=())