示例#1
0
 def test_tensor_undefined_output(self):
   """Check that a branch leaving its output undefined raises ValueError.

   `_basic_cond` is invoked twice, each time with one of its two callables
   returning an `Undefined` placeholder; the expected message names the
   branch that failed to initialize the symbol.
   """
   # First callable yields the placeholder: reported as the main branch.
   with self.assertRaisesRegex(
       ValueError, "'x' must also be initialized in the main branch"):
     self._basic_cond(lambda: variable_operators.Undefined('x'), lambda: 1)
   # Second callable yields the placeholder: reported as the else branch.
   # The placeholder is named 'x' so it agrees with the asserted message
   # and with the first case.
   with self.assertRaisesRegex(
       ValueError, "'x' must also be initialized in the else branch"):
     self._basic_cond(lambda: 1, lambda: variable_operators.Undefined('x'))
示例#2
0
    def test_undefined(self):
        """Two `Undefined` markers keep their name but do not compare equal."""
        first = variables.Undefined('name')
        second = variables.Undefined('name')

        for marker in (first, second):
            self.assertEqual(marker.symbol_name, 'name')
        # Sharing a symbol name must not make distinct instances equal.
        self.assertNotEqual(first, second)
示例#3
0
    def test_undefined_operations(self):
        """Attribute and item access on `Undefined` yield `Undefined` again."""
        marker = variables.Undefined('name')
        via_attribute = marker.foo
        via_index = marker[0]
        marker_type = marker.__class__

        self.assertIsInstance(via_attribute, variables.Undefined)
        self.assertIsInstance(via_index, variables.Undefined)
        # `__class__` must not itself be reported as an Undefined instance.
        self.assertNotIsInstance(marker_type, variables.Undefined)
示例#4
0
    def test_tensor_creating_variable(self):
        """Run a tensor `while_stmt` whose body is the first to define `i`."""
        # `i` starts out undefined; only `s` has a value before the loop.
        i = variable_operators.Undefined('i')
        s = constant_op.constant(0)

        def loop_test():
            return math_ops.equal(s, 0)

        def loop_body():
            nonlocal i, s
            i = constant_op.constant(2)
            s = i**5

        def read_state():
            return (i, s)

        def write_state(loop_vars):
            nonlocal i, s
            i, s = loop_vars

        control_flow.while_stmt(
            test=loop_test,
            body=loop_body,
            get_state=read_state,
            set_state=write_state,
            symbol_names=('i', 's'),
            opts={},
        )

        self.assertEqual(i, 2)
        self.assertEqual(s, 32)
        self.assertOpCreated('StatelessWhile')
        # Temporarily staging the body must not leave extra ops behind, so
        # exactly one pow node is expected. The optional `while/` prefix
        # covers the node naming difference between V1 and V2.
        self.assertGraphContains(r'(while/)?pow$', 1)
示例#5
0
  def test_tensor_creating_variable_of_dynamic_shape(self):
    """Run a tensor `while_stmt` whose body defines `i` with a random dim."""
    # `i` starts out undefined; only `s` has a value before the loop.
    i = variable_operators.Undefined('i')
    s = constant_op.constant(0.0)

    def loop_test():
      return math_ops.equal(s, 0)

    def loop_body():
      nonlocal i, s
      # The leading dimension is drawn at random; the trailing one is 7.
      rows = random_ops.random_uniform(minval=1, maxval=4, shape=())
      i = array_ops.ones([rows, 7])
      s = math_ops.reduce_sum(i)

    def read_state():
      return (i, s)

    def write_state(loop_vars):
      nonlocal i, s
      i, s = loop_vars

    control_flow.while_stmt(
        test=loop_test,
        body=loop_body,
        get_state=read_state,
        set_state=write_state,
        symbol_names=('i', 's'),
        opts={},
    )

    self.assertEqual(i[0][0], 1)
    self.assertGreaterEqual(s, 7)
    # The random op keeps this loop from being a stateless while.
    self.assertOpCreated('While')
示例#6
0
    def test_tensor_failing_to_determine_placeholder(self):
        """Expect ValueError when the body defines `v` as a plain object."""
        class UserType:
            pass

        v = variable_operators.Undefined('v')

        def loop_body():
            nonlocal v
            v = UserType()

        def write_state(loop_vars):
            nonlocal v
            (v,) = loop_vars

        # DOTALL lets `.*` cross line breaks in the error text.
        expected_error = re.compile(
            'must be defined.*tried to define.*unsupported type', re.DOTALL)

        with self.assertRaisesRegex(ValueError, expected_error):
            control_flow.while_stmt(
                test=lambda: constant_op.constant(True),
                body=loop_body,
                get_state=lambda: (v,),
                set_state=write_state,
                symbol_names=('v',),
                opts={},
            )
示例#7
0
    def test_tensor_creating_dynamic_shape_variable(self):
        """Run a tensor `while_stmt` whose body defines `y` sized by `i`."""
        i = constant_op.constant(0)
        # `y` has no value until the loop body assigns it.
        y = variable_operators.Undefined('y')

        def loop_test():
            return math_ops.less(i, 3)

        def loop_body():
            nonlocal i, y
            i += 1
            y = random_ops.random_uniform([i])

        def read_state():
            return (i, y)

        def write_state(loop_vars):
            nonlocal i, y
            i, y = loop_vars

        control_flow.while_stmt(
            test=loop_test,
            body=loop_body,
            get_state=read_state,
            set_state=write_state,
            symbol_names=('i', 'y'),
            opts={},
        )

        self.assertEqual(i, 3)
        self.assertLess(y[0], 3)
示例#8
0
  def test_tensor_creating_dynamic_shape_variable_preserves_shape_invar(self):
    """Run `while_stmt` with a shape invariant keyed on an undefined `y`."""
    i = constant_op.constant(0)
    y = variable_operators.Undefined('y')

    def loop_test():
      return math_ops.less(i, 3)

    def loop_body():
      nonlocal i, y
      i += 1
      y = array_ops.zeros([1])

    def read_state():
      return (i, y)

    def write_state(loop_vars):
      nonlocal i, y
      i, y = loop_vars

    # Built before the loop runs, so the invariant is keyed on the
    # still-undefined placeholder object bound to `y`.
    loop_options = {
        'shape_invariants': ((y, tensor_shape.TensorShape([1])),),
    }
    control_flow.while_stmt(
        test=loop_test,
        body=loop_body,
        get_state=read_state,
        set_state=write_state,
        symbol_names=('i', 'y'),
        opts=loop_options,
    )

    self.evaluate(y)
示例#9
0
    def test_tensor_failing_to_stage_loop_body(self):
        """Expect ValueError carrying the body's own error when it raises."""
        i = variable_operators.Undefined('i')
        s = constant_op.constant(0)

        def loop_body():
            nonlocal i, s
            i = constant_op.constant(2)
            # Fail on purpose before `s` can be reassigned.
            raise ValueError('testing')
            s = i**5  # pylint: disable=unreachable

        def write_state(loop_vars):
            nonlocal i, s
            i, s = loop_vars

        # DOTALL lets `.*` cross line breaks in the error text.
        expected_error = re.compile(
            'must be defined.*tried to define.*testing', re.DOTALL)

        with self.assertRaisesRegex(ValueError, expected_error):
            control_flow.while_stmt(
                test=lambda: math_ops.equal(s, 0),
                body=loop_body,
                get_state=lambda: (i, s),
                set_state=write_state,
                symbol_names=('i', 's'),
                opts={},
            )
示例#10
0
 def test_read_undefined(self):
     """Loading an `Undefined` via `variables.ld` raises UnboundLocalError."""
     expected_message = 'used before assignment'
     with self.assertRaisesRegex(UnboundLocalError, expected_message):
         placeholder = variables.Undefined('a')
         variables.ld(placeholder)
示例#11
0
 def test_tensor_illegal_input(self):
     """Check that `_basic_loop` rejects a None or undefined initial value."""
     def passthrough(i, s):
         return s

     with self.assertRaisesRegex(ValueError, "'s' may not be None"):
         self._basic_loop(None, passthrough)

     with self.assertRaisesRegex(ValueError, "'s' must be defined"):
         undefined_init = variable_operators.Undefined('')
         self._basic_loop(undefined_init, passthrough)