def test_tensor_undefined_output(self):
  """A tensor cond rejects a branch that leaves its output undefined.

  Both expected error messages name the output variable 'x'. The Undefined
  placeholders are named 'x' too, so the main-branch and else-branch cases
  mirror each other exactly.
  """
  # The main (true) branch yields an Undefined while the else branch yields 1.
  with self.assertRaisesRegex(
      ValueError, "'x' must also be initialized in the main branch"):
    self._basic_cond(lambda: variable_operators.Undefined('x'), lambda: 1)
  # Mirror case: the main branch yields 1 while the else branch is Undefined.
  # NOTE(review): the message presumably takes its name from the symbol names
  # used inside `_basic_cond` rather than from the placeholder -- confirm.
  with self.assertRaisesRegex(
      ValueError, "'x' must also be initialized in the else branch"):
    self._basic_cond(lambda: 1, lambda: variable_operators.Undefined('x'))
def test_undefined(self):
  """Undefined keeps its symbol name; equal names do not make instances equal."""
  first, second = (variables.Undefined('name'), variables.Undefined('name'))

  for placeholder in (first, second):
    self.assertEqual(placeholder.symbol_name, 'name')

  # Two placeholders created with the same name must still compare unequal.
  self.assertNotEqual(first, second)
def test_undefined_operations(self):
  """Attribute and item access on an Undefined produce more Undefineds."""
  undefined_symbol = variables.Undefined('name')

  attribute_result = undefined_symbol.foo
  self.assertIsInstance(attribute_result, variables.Undefined)

  item_result = undefined_symbol[0]
  self.assertIsInstance(item_result, variables.Undefined)

  # `__class__` must resolve to the real class object, not a derived Undefined.
  self.assertNotIsInstance(undefined_symbol.__class__, variables.Undefined)
def test_tensor_creating_variable(self):
  """A tensor while loop can define a variable that starts out Undefined."""
  i = variable_operators.Undefined('i')
  s = constant_op.constant(0)

  def loop_cond():
    return math_ops.equal(s, 0)

  def body():
    nonlocal i, s
    i = constant_op.constant(2)
    s = i**5

  def get_state():
    return (i, s)

  def set_state(loop_vars):
    nonlocal i, s
    i, s = loop_vars

  control_flow.while_stmt(
      test=loop_cond,
      body=body,
      get_state=get_state,
      set_state=set_state,
      symbol_names=('i', 's'),
      opts={})

  self.assertEqual(i, 2)
  self.assertEqual(s, 32)
  self.assertOpCreated('StatelessWhile')
  # Check that the temporary staging of the body did not create extra ops.
  # Node naming is inconsistent between V1 and V2.
  self.assertGraphContains(r'(while/)?pow$', 1)
def test_tensor_creating_variable_of_dynamic_shape(self):
  """A loop can define a variable whose leading dimension is random."""
  i = variable_operators.Undefined('i')
  s = constant_op.constant(0.0)

  def loop_cond():
    return math_ops.equal(s, 0)

  def body():
    nonlocal i, s
    # The row count comes from a random op, so it is only known at run time.
    num_rows = random_ops.random_uniform(minval=1, maxval=4, shape=())
    i = array_ops.ones([num_rows, 7])
    s = math_ops.reduce_sum(i)

  def get_state():
    return (i, s)

  def set_state(loop_vars):
    nonlocal i, s
    i, s = loop_vars

  control_flow.while_stmt(
      test=loop_cond,
      body=body,
      get_state=get_state,
      set_state=set_state,
      symbol_names=('i', 's'),
      opts={})

  self.assertEqual(i[0][0], 1)
  self.assertGreaterEqual(s, 7)
  self.assertOpCreated('While')  # Not stateless because of the random op.
def test_tensor_failing_to_determine_placeholder(self):
  """A loop that defines a variable of an unsupported user type fails."""

  class UserType:
    pass

  v = variable_operators.Undefined('v')

  def always_true():
    return constant_op.constant(True)

  def body():
    nonlocal v
    v = UserType()

  def get_state():
    return (v,)

  def set_state(loop_vars):
    nonlocal v
    (v,) = loop_vars

  expected_error = re.compile(
      'must be defined.*tried to define.*unsupported type', re.DOTALL)
  with self.assertRaisesRegex(ValueError, expected_error):
    control_flow.while_stmt(
        test=always_true,
        body=body,
        get_state=get_state,
        set_state=set_state,
        symbol_names=('v',),
        opts={})
def test_tensor_creating_dynamic_shape_variable(self):
  """A loop can define a variable whose length grows with the counter."""
  i = constant_op.constant(0)
  y = variable_operators.Undefined('y')

  def loop_cond():
    return math_ops.less(i, 3)

  def body():
    nonlocal i, y
    i += 1
    # The length of y follows the loop counter, so its shape varies per step.
    y = random_ops.random_uniform([i])

  def get_state():
    return (i, y)

  def set_state(loop_vars):
    nonlocal i, y
    i, y = loop_vars

  control_flow.while_stmt(
      test=loop_cond,
      body=body,
      get_state=get_state,
      set_state=set_state,
      symbol_names=('i', 'y'),
      opts={})

  self.assertEqual(i, 3)
  self.assertLess(y[0], 3)
def test_tensor_creating_dynamic_shape_variable_preserves_shape_invar(self):
  """A shape invariant may be given for a loop variable that starts Undefined.

  The test only checks that the loop stages and that `y` can be evaluated.
  """
  i = constant_op.constant(0)
  y = variable_operators.Undefined('y')

  def loop_cond():
    return math_ops.less(i, 3)

  def body():
    nonlocal i, y
    i += 1
    y = array_ops.zeros([1])

  def get_state():
    return (i, y)

  def set_state(loop_vars):
    nonlocal i, y
    i, y = loop_vars

  # The invariant is keyed on the initial (Undefined) value of y.
  invariants = ((y, tensor_shape.TensorShape([1])),)
  control_flow.while_stmt(
      test=loop_cond,
      body=body,
      get_state=get_state,
      set_state=set_state,
      symbol_names=('i', 'y'),
      opts={'shape_invariants': invariants})

  self.evaluate(y)
def test_tensor_failing_to_stage_loop_body(self):
  """An error raised while staging the body is reported in the failure."""
  i = variable_operators.Undefined('i')
  s = constant_op.constant(0)

  def loop_cond():
    return math_ops.equal(s, 0)

  def body():
    nonlocal i, s
    i = constant_op.constant(2)
    raise ValueError('testing')
    s = i**5  # pylint: disable=unreachable

  def get_state():
    return (i, s)

  def set_state(loop_vars):
    nonlocal i, s
    i, s = loop_vars

  # The expected message includes the body's own 'testing' error text.
  expected_error = re.compile(
      'must be defined.*tried to define.*testing', re.DOTALL)
  with self.assertRaisesRegex(ValueError, expected_error):
    control_flow.while_stmt(
        test=loop_cond,
        body=body,
        get_state=get_state,
        set_state=set_state,
        symbol_names=('i', 's'),
        opts={})
def test_read_undefined(self):
  """Loading an Undefined through `variables.ld` raises UnboundLocalError."""
  with self.assertRaisesRegex(UnboundLocalError, 'used before assignment'):
    placeholder = variables.Undefined('a')
    variables.ld(placeholder)
def test_tensor_illegal_input(self):
  """None or Undefined initial values for the loop variable 's' are rejected.

  NOTE(review): the messages name 's' even though the placeholder's own name
  is empty. That presumably comes from the symbol names used inside
  `_basic_loop` -- confirm.
  """

  def passthrough(i, s):
    del i  # Unused; the body simply returns the current value of s.
    return s

  with self.assertRaisesRegex(ValueError, "'s' may not be None"):
    self._basic_loop(None, passthrough)
  with self.assertRaisesRegex(ValueError, "'s' must be defined"):
    self._basic_loop(variable_operators.Undefined(''), passthrough)