def test_auto_cotrol_deps(self):
  """Checks that an assignment made inside an auto-deps scope takes effect.

  `v` is reassigned inside a FunctionScope with `use_auto_deps=True`. Only
  the value passed to `mark_return_value` is evaluated explicitly. The test
  then expects the variable to read back as 2.
  """
  # NOTE(review): "cotrol" is a typo, kept deliberately. Renaming this method
  # to `test_auto_control_deps` could shadow or be shadowed by a test of that
  # name in the same class.
  var = variables.Variable(1)
  with function_wrappers.FunctionScope(
      False, None, use_auto_deps=True) as fn_scope:
    var.assign(2)
    returned = fn_scope.mark_return_value(constant_op.constant(1))
  # Evaluate only after the scope has exited.
  self.evaluate(returned)
  final_value = self.evaluate(var.read_value())
  self.assertEqual(final_value, 2)
def test_name_scope(self):
  """Checks that a tensor created in a named FunctionScope carries that name.

  Skipped under eager execution, where tensor names are unavailable.
  """
  if context.executing_eagerly():
    self.skipTest('Tensor names are disabled in eager')
  # NOTE(review): the flags are positional; the third is presumably
  # `use_auto_deps` (disabled here). Confirm against FunctionScope's signature.
  with function_wrappers.FunctionScope(True, 'test_name', False):
    tensor = constant_op.constant(1)
  self.assertIn('test_name', tensor.name)
def test_auto_control_deps(self):
  """Checks that the AUTO_CONTROL_DEPS feature keeps an in-scope assignment.

  `v` is reassigned inside a FunctionScope whose options enable
  `converter.Feature.AUTO_CONTROL_DEPS`. Only the value produced by
  `scope.ret` is evaluated explicitly. The test then expects the variable
  to read back as 2.
  """
  var = variables.Variable(1)
  auto_deps_options = converter.ConversionOptions(
      optional_features=converter.Feature.AUTO_CONTROL_DEPS)
  with function_wrappers.FunctionScope(
      '_', None, auto_deps_options) as fn_scope:
    var.assign(2)
    returned = fn_scope.ret(constant_op.constant(1), True)
  # Evaluate only after the scope has exited.
  self.evaluate(returned)
  final_value = self.evaluate(var.read_value())
  self.assertEqual(final_value, 2)
def test_name_scope(self):
  """Checks that NAME_SCOPES puts the function name into created tensor names.

  Skipped under eager execution, where tensor names are unavailable.
  """
  if context.executing_eagerly():
    self.skipTest('Tensor names are disabled in eager')
  name_scope_options = converter.ConversionOptions(
      optional_features=converter.Feature.NAME_SCOPES)
  with function_wrappers.FunctionScope('test_name', None, name_scope_options):
    tensor = constant_op.constant(1)
  self.assertIn('test_name', tensor.name)
def test_basic(self):
  """Checks that a converted `assert` fires when the scope's return value runs.

  `test_fn` is converted with the `function_scopes` and `asserts`
  converters. It is then called with a False tensor inside a FunctionScope
  that uses automatic control dependencies. Evaluating the value marked as
  the scope's return is expected to raise InvalidArgumentError carrying
  the assert message.
  """

  def test_fn(a):
    # The `asserts` converter presumably rewrites this into a graph-level
    # assertion. TODO confirm.
    assert a, 'testmsg'

  with ops.Graph().as_default():
    with self.converted(test_fn, (function_scopes, asserts), {}) as result:
      with function_wrappers.FunctionScope(
          False, None, use_auto_deps=True) as scope:
        result.test_fn(constant_op.constant(False))
        op = scope.mark_return_value(constant_op.constant(1))
      # `assertRaisesRegexp` is a deprecated alias that was removed in
      # Python 3.12. Use the canonical name.
      with self.assertRaisesRegex(errors_impl.InvalidArgumentError,
                                  'testmsg'):
        self.evaluate(op)
def _basic_function_scope(self):
  """Builds a FunctionScope named 'test_function_name' / 'test_scope'.

  The scope uses a `converter.ConversionOptions()` constructed with no
  arguments.
  """
  options = converter.ConversionOptions()
  return function_wrappers.FunctionScope(
      'test_function_name',
      'test_scope',  # Note: this must match the name in the `with` statement.
      options)
def test_all_disabled(self):
  """Checks that a constant built in a STANDARD_OPTIONS scope evaluates as-is.

  Judging by the test name, `converter.STANDARD_OPTIONS` presumably enables
  no optional features. Confirm against the converter module.
  """
  with function_wrappers.FunctionScope(
      None, None, converter.STANDARD_OPTIONS):
    const = constant_op.constant(1)
  self.assertEqual(self.evaluate(const), 1)
def test_all_disabled(self):
  """Checks that a constant built in a scope with all flags off evaluates as-is."""
  # NOTE(review): the flags are positional; the third is presumably
  # `use_auto_deps`. Confirm against FunctionScope's signature.
  with function_wrappers.FunctionScope(False, None, False):
    const = constant_op.constant(1)
  self.assertEqual(self.evaluate(const), 1)