def test_function_with_default_bool_input(self, cycles):
  """Python bool arguments, including a defaulted one, survive save/load."""

  def func(x, training=False):
    if training:
      return 2 * x
    else:
      return 7

  root = tracking.AutoTrackable()
  root.f = def_function.function(func)

  def run(fn, value, *extra_args):
    # Feed `value` as a tensor and forward any extra positional arguments.
    return fn(constant_op.constant(value), *extra_args).numpy()

  # Exercise both values of `training` before exporting.
  self.assertEqual(20, run(root.f, 10, True))
  self.assertEqual(7, run(root.f, 1))
  self.assertEqual(2, run(root.f, 1, True))

  imported = self.cycle(root, cycles)
  self.assertEqual(4, run(imported.f, 2, True))
  self.assertEqual(7, run(imported.f, 2))
def test_prefer_specific_trace(self, cycles):
  """After loading, calls are served by the matching saved trace."""

  @def_function.function(autograph=False)
  def func(a):
    if isinstance(a, int):
      return a
    else:
      return a + 1

  # Trace once with a Python int and once with a tensor.
  self.assertAllEqual(2, func(2).numpy())
  self.assertAllEqual(3, func(constant_op.constant(2)).numpy())

  original = tracking.AutoTrackable()
  original.f = func
  loaded = self.cycle(original, cycles)

  # The int 2 matches the int trace; the never-traced int 3 produces 4,
  # i.e. the `a + 1` branch.
  self.assertAllEqual(2, loaded.f(2).numpy())
  self.assertAllEqual(4, loaded.f(3).numpy())
  self.assertAllEqual(3, loaded.f(constant_op.constant(2)).numpy())
  self.assertAllEqual(4, loaded.f(constant_op.constant(3)).numpy())
def _create_unsupported_saved_model(self):
  """Saves a model whose exported function calls `tf.linalg.diag`.

  The op is marked below as unsupported (presumably by the converter under
  test). The model is written to `SAVED_MODEL_DIR` under `self._tmp_dir`.
  """
  model = tracking.AutoTrackable()
  model.w = variables.Variable(tf.random.uniform([2, 2]))

  @def_function.function
  def exported_function(x):
    model.x = constant_op.constant([[37.0, -23.0], [1.0, 4.0]])
    model.y = tf.matmul(model.x, model.w)
    # unsupported op: linalg.diag
    model.z = tf.linalg.diag(model.y)
    return model.z * x

  model.f = exported_function
  scalar_spec = tensor_spec.TensorSpec([], dtypes.float32)
  concrete_fn = model.f.get_concrete_function(scalar_spec)
  export_dir = os.path.join(self._tmp_dir, SAVED_MODEL_DIR)
  save(model, export_dir, concrete_fn)
def test_non_strict_predicate(self):
  """A non-strict saver lets a checkpoint be read into a non-matching object."""

  class NonStrictPredicateClass(tracking.AutoTrackable):
    pass

  registration.register_checkpoint_saver(
      name="NonStrictPredicate",
      predicate=lambda x: isinstance(x, NonStrictPredicateClass),
      save_fn=lambda **kwargs: [],
      restore_fn=lambda **kwargs: None,
      strict_predicate_restore=False)

  saved_obj = NonStrictPredicateClass()
  ckpt_path = os.path.join(self.get_temp_dir(), "ckpt")
  util.Checkpoint(saved_obj).write(ckpt_path)

  # A plain AutoTrackable fails the predicate, yet reading must not raise
  # because strict_predicate_restore=False.
  target = tracking.AutoTrackable()
  util.Checkpoint(target).read(ckpt_path)
def test_captures_unreachable_variable(self):
  """Saving fails if a function captures a variable not reachable from root.

  `unreachable_variable` is captured by `root.f` but never attached to `root`,
  so `save.save` is expected to raise a KeyError mentioning reachability.
  """
  root = tracking.AutoTrackable()
  # Deliberately NOT attached to `root`.
  unreachable_variable = variables.Variable([5.0, 2.0])
  root.reachable_variable = variables.Variable([1.0, 3.0])

  @def_function.function
  def increase_variable(x):
    return 2 * unreachable_variable * x + root.reachable_variable

  root.f = increase_variable
  # 2 * [5, 2] * [10, 20] + [1, 3] == [101, 83].
  self.assertAllEqual([101.0, 83.0],
                      root.f(constant_op.constant([10.0, 20.0])).numpy())
  save_dir = os.path.join(self.get_temp_dir(), "saved_model")
  # `assertRaisesRegexp` is a deprecated alias removed in Python 3.12; use
  # the canonical `assertRaisesRegex`.
  with self.assertRaisesRegex(KeyError, "not reachable from root"):
    save.save(root, save_dir)
def test_strict_predicate(self):
  """A strict saver refuses to read a checkpoint into a non-matching object."""

  class StrictPredicateClass(tracking.AutoTrackable):
    pass

  registration.register_checkpoint_saver(
      name="StrictPredicate",
      predicate=lambda x: isinstance(x, StrictPredicateClass),
      save_fn=lambda **kwargs: [],
      restore_fn=lambda **kwargs: None,
      strict_predicate_restore=True)

  writer_root = StrictPredicateClass()
  ckpt_path = os.path.join(self.get_temp_dir(), "ckpt")
  util.Checkpoint(writer_root).write(ckpt_path)

  # A plain AutoTrackable does not satisfy the predicate.
  reader_root = tracking.AutoTrackable()
  with self.assertRaisesRegex(ValueError, "saver cannot be used"):
    util.Checkpoint(reader_root).read(ckpt_path)
def test_partial_with_passed_fn_as_default(self, cycles):
  """A partial that binds a Python callable argument (currently skipped)."""
  # TODO(b/124441704): Figure out the story for FunctionSpec vs partial.
  self.skipTest("Partial does not work for serialization.")

  def f(x, y):
    return x(3) + y

  def my_func(a):
    return 2 * a

  root = tracking.AutoTrackable()
  root.f = def_function.function(functools.partial(f, my_func))
  # my_func(3) + 3 == 9.
  self.assertEqual(root.f(constant_op.constant(3)).numpy(), 9)

  restored = self.cycle(root, cycles)
  self.assertEqual(restored.f(constant_op.constant(3)).numpy(), 9)
def test_method_save_list_func(self):
  """Exports a signature whose body calls switch_case over a list of lambdas."""
  root = tracking.AutoTrackable()

  @def_function.function
  def case_fn(x):
    branch_index = constant_op.constant(1)
    branches = [lambda: x, lambda: x + 1]
    case_out = control_flow_ops.switch_case(branch_index, branches)
    return case_out

  float_signature = [tensor_spec.TensorSpec(None, dtypes.float32)]
  root.f = def_function.function(
      lambda x: 2. * case_fn(x), input_signature=float_signature)
  root.f(constant_op.constant(1.))

  save_dir = os.path.join(self.get_temp_dir(), "saved_model")
  save.save(root, save_dir, root.f)
  # Branch 1 computes x + 1, so f(1.) == 2. * (1. + 1.) == 4.
  expected = {"output_0": 4.}
  self.assertEqual(expected, _import_and_infer(save_dir, {"x": 1.}))
def testConstModel(self):
  """Converts a variable-free SavedModel to TFLite and compares outputs."""
  input_data = constant_op.constant(1., shape=[1])
  root = tracking.AutoTrackable()
  root.f = def_function.function(lambda x: 2. * x)
  concrete_fn = root.f.get_concrete_function(input_data)
  save_dir = os.path.join(self.get_temp_dir(), 'saved_model')
  save(root, save_dir, concrete_fn)

  # Convert the SavedModel with the TFLite converter.
  tflite_model = lite.TFLiteConverterV2.from_saved_model(save_dir).convert()

  # The TFLite result must match the original TF function's result.
  tf_result = root.f(input_data)
  tflite_result = self._evaluateTFLiteModel(tflite_model, [input_data])
  self.assertEqual(tf_result.numpy(), tflite_result)
def test_initialize_with_root_object_and_kwargs(self):
  """Checkpoint(root, **kwargs) tracks root's children alongside kwargs.

  A kwarg colliding with an existing root attribute raises ValueError. The
  saved checkpoint is then restored in three different layouts, and each
  layout must consume the checkpoint and recover v == 3 and
  separate_variable == 5.
  """
  model = self._create_trackable()
  model.v.assign(3.)
  separate_variable = variables_lib.Variable(5.)
  # `v` already exists on the root object, so passing it as a kwarg fails.
  with self.assertRaisesRegex(ValueError, "root.v already exists"):
    trackable_utils.Checkpoint(model, v=separate_variable)
  checkpoint = trackable_utils.Checkpoint(
      model, separate_variable=separate_variable)
  checkpoint_directory = self.get_temp_dir()
  checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
  save_path = checkpoint.save(checkpoint_prefix)

  # Case 1: Loading checkpoint with same configuration.
  new_model = self._create_trackable()
  separate_variable = variables_lib.Variable(1.)
  load_checkpoint = trackable_utils.Checkpoint(
      new_model, separate_variable=separate_variable)
  load_checkpoint.restore(save_path).assert_consumed()
  self.assertEqual(self.evaluate(new_model.v), 3)
  self.assertEqual(self.evaluate(separate_variable), 5)
  self.assertEqual(self.evaluate(load_checkpoint.save_counter), 1)

  # Case 2: Loading checkpoint where v and separate_variable are swapped:
  # v is not attached to the root, while separate variable is attached to root
  new_model = tracking.AutoTrackable()
  new_model.separate_variable = variables_lib.Variable(200.)
  v = variables_lib.Variable(100.)
  load_checkpoint = trackable_utils.Checkpoint(new_model, v=v)
  load_checkpoint.restore(save_path).assert_consumed()
  self.assertEqual(self.evaluate(v), 3)
  self.assertEqual(self.evaluate(new_model.separate_variable), 5)
  self.assertEqual(self.evaluate(load_checkpoint.save_counter), 1)

  # Case 3: Loading checkpoint where no root object is specified
  separate_variable = variables_lib.Variable(200.)
  v = variables_lib.Variable(100.)
  load_checkpoint = trackable_utils.Checkpoint(
      v=v, separate_variable=separate_variable)
  load_checkpoint.restore(save_path).assert_consumed()
  self.assertEqual(self.evaluate(v), 3)
  # Check the variable restored in THIS case. `new_model.separate_variable` is
  # left over from Case 2 and is already 5, so asserting on it proves nothing.
  self.assertEqual(self.evaluate(separate_variable), 5)
  self.assertEqual(self.evaluate(load_checkpoint.save_counter), 1)
def test_control_outputs(self, cycles):
  """The named assign op remains a control output after save/load cycles."""
  op_name = "should_be_control_output"
  exported = tracking.AutoTrackable()
  exported.v = variables.Variable(1.)
  exported.f = def_function.function(
      lambda: exported.v.assign(2., name=op_name))

  def assert_is_control_output(graph):
    self.assertIn(graph.get_operation_by_name(op_name), graph.control_outputs)

  assert_is_control_output(exported.f.get_concrete_function().graph)

  imported = self.cycle(exported, cycles)
  # Calling get_concrete_function wraps in a second call operation; we want to
  # inspect the original function body for the control output; digging into
  # graph.as_graph_def() and its FunctionDefLibrary is another option.
  imported_concrete, = imported.f._concrete_functions
  assert_is_control_output(imported_concrete.graph)
def testTxtSignatureDefs(self):
  """The analyzer's text report lists both signatures and their tensors."""
  with tempfile.TemporaryDirectory() as tmp_dir:

    def two_float_specs():
      # Each decorator gets its own fresh list of specs.
      return [
          tf.TensorSpec(shape=None, dtype=tf.float32),
          tf.TensorSpec(shape=None, dtype=tf.float32)
      ]

    @tf.function(input_signature=two_float_specs())
    def add(a, b):
      return {'add_result': tf.add(a, b)}

    @tf.function(input_signature=two_float_specs())
    def sub(x, y):
      return {'sub_result': tf.subtract(x, y)}

    root = tracking.AutoTrackable()
    root.f1 = add.get_concrete_function()
    root.f2 = sub.get_concrete_function()
    signatures = {'add': root.f1, 'sub': root.f2}
    tf.saved_model.save(root, tmp_dir, signatures=signatures)
    fb_model = tf.lite.TFLiteConverter.from_saved_model(tmp_dir).convert()

    # Capture what the analyzer prints to stdout.
    captured = io.StringIO()
    with test.mock.patch.object(sys, 'stdout', captured):
      analyzer.ModelAnalyzer.analyze(model_content=fb_model)
    txt = captured.getvalue()

    expected_fragments = (
        'Your TFLite model has ‘2’ signature_def(s).',
        "Signature#0 key: 'add'",
        " 'a' : T#1",
        " 'b' : T#0",
        " 'add_result' : T#2",
        "Signature#1 key: 'sub'",
        " 'x' : T#1_1",
        " 'y' : T#1_0",
        " 'sub_result' : T#1_2",
    )
    for fragment in expected_fragments:
      self.assertIn(fragment, txt)
def test_revived_concrete_function_tensorspec_kwargs(self, cycles):
  """Named TensorSpecs let a *args concrete function be called by keyword."""

  @def_function.function
  def func(*args):
    x, y = args
    return x * (y + 1.)

  root = tracking.AutoTrackable()
  root.f = func.get_concrete_function(
      tensor_spec.TensorSpec([], dtypes.float32, name="x"),
      tensor_spec.TensorSpec([], dtypes.float32, name="y"))

  def call_by_keyword(fn):
    # Keywords are given in reverse order, so binding must go by name:
    # 2. * (3. + 1.) == 8.
    return fn(y=constant_op.constant(3.), x=constant_op.constant(2.)).numpy()

  self.assertEqual(8., call_by_keyword(root.f))
  imported = self.cycle(root, cycles, signatures={})
  self.assertEqual(8., call_by_keyword(imported.f))
def test_partial(self, cycles):
  """A partial binding every argument to tensors (currently skipped)."""
  # TODO(vbardiovsky): Figure out the story for FunctionSpec vs partial vs
  # input_signature.
  self.skipTest("Partial does not work for serialization.")

  def f(x, y):
    return x + y

  bound = functools.partial(f, x=array_ops.zeros([1]), y=array_ops.zeros([1]))
  root = tracking.AutoTrackable()
  root.f = def_function.function(bound)
  self.assertAllEqual(root.f(), [0.0])

  restored = self.cycle(root, cycles)
  self.assertAllEqual(restored.f(), [0.0])
def _create_saved_model_v2_complex64(save_dir):
  """Saves a small TF2 model that produces a complex64 result.

  The exported function computes `tf.complex(x, x) + v1`, where `v1` is a
  complex64 variable initialized to `3 + 1j`.

  Args:
    save_dir: directory name of where the saved model will be stored.

  Returns:
    A dict with an "async" flag (False) plus the input and expected output
    descriptions (value, shape, dtype). The output value [4, 2] presumably
    encodes the real and imaginary parts of 4+2j; confirm against the consumer.
  """
  input_data = constant_op.constant(1., shape=[1])
  root = tracking.AutoTrackable()
  root.v1 = variables.Variable(3 + 1j, dtype=tf.complex64)
  root.f = def_function.function(lambda x: tf.complex(x, x) + root.v1)
  save(root, save_dir, root.f.get_concrete_function(input_data))

  input_info = {"value": [1], "shape": [1], "dtype": 'float32'}
  output_info = {"value": [4, 2], "shape": [1], "dtype": "complex64"}
  return {
      "async": False,
      "inputs": {"x": input_info},
      "outputs": {"Identity:0": output_info},
  }
def testCheckpointStateChangingVarList(self):
  """Saving while the tracked variable set grows keeps every checkpoint usable.

  Ten variables are added one at a time with a save after each. The test then
  checks the files on disk, the recorded checkpoint state, and restoring an
  older and the newest checkpoint.
  """
  checkpoint_directory = self.get_temp_dir()
  checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
  obj = tracking.AutoTrackable()
  obj.var = variable_scope.get_variable(name="v", initializer=0.)
  self.evaluate(trackable_utils.gather_initializers(obj))
  checkpoint = trackable_utils.Checkpoint(obj=obj)
  looped_variables = []
  # Iteration i attaches var_i (initialized to i) and then saves, so save
  # number i + 1 contains var_0 .. var_i.
  for iteration in range(10):
    new_variable = resource_variable_ops.ResourceVariable(iteration)
    self.evaluate(new_variable.initializer)
    setattr(checkpoint, "var_%d" % iteration, new_variable)
    checkpoint.save(checkpoint_prefix)
    looped_variables.append(new_variable)
  expected_filenames = ["checkpoint"]
  # We've copied the saver each time, but checkpoint management should still
  # be consistent. Nothing gets deleted.
  for checkpoint_number in range(1, 11):
    expected_filenames.append("ckpt-%d.index" % (checkpoint_number, ))
  self.assertEmpty(
      set(expected_filenames)
      - set(os.listdir(checkpoint_directory)))
  self.assertEqual(
      checkpoint_prefix + "-10",
      checkpoint_management.latest_checkpoint(
          checkpoint_directory))
  # The checkpoint list only contains the most recent checkpoint, but they're
  # all on disk. This means we won't eventually run into proto size limits.
  self.assertEqual(
      [checkpoint_prefix + "-10"],
      (checkpoint_management.get_checkpoint_state(
          checkpoint_directory).all_model_checkpoint_paths))
  # Overwrite every looped variable so restored values are distinguishable.
  for v in looped_variables:
    self.evaluate(v.assign(314))
  # Checkpoint -6 was written when only var_0 .. var_5 existed, so var_6 and
  # later keep the overwritten value 314.
  checkpoint.restore(checkpoint_prefix + "-6").run_restore_ops()
  self.assertEqual(314, self.evaluate(checkpoint.var_9))
  self.assertEqual(314, self.evaluate(checkpoint.var_8))
  self.assertEqual(314, self.evaluate(checkpoint.var_6))
  self.assertEqual(5, self.evaluate(checkpoint.var_5))
  self.assertEqual(1, self.evaluate(checkpoint.var_1))
  self.assertEqual(0, self.evaluate(checkpoint.var_0))
  # The latest checkpoint contains all ten variables with their initial values.
  checkpoint.restore(checkpoint_prefix + "-10").run_restore_ops()
  self.assertEqual(9, self.evaluate(checkpoint.var_9))
  self.assertEqual(8, self.evaluate(checkpoint.var_8))
  self.assertEqual(1, self.evaluate(checkpoint.var_1))
  self.assertEqual(0, self.evaluate(checkpoint.var_0))
def testConstSavedModel(self):
  """Freezes the serving signature of a loaded, variable-free SavedModel."""
  input_data = {"x": constant_op.constant(1., shape=[1])}
  root = tracking.AutoTrackable()
  root.f = def_function.function(lambda x: 2. * x)
  concrete_fn = root.f.get_concrete_function(input_data["x"])
  save_dir = os.path.join(self.get_temp_dir(), "saved_model")
  save(root, save_dir, concrete_fn)

  loaded_model = load(save_dir)
  input_func = loaded_model.signatures["serving_default"]
  loaded_graph_def = input_func.graph.as_graph_def()
  # No variables, but the loaded signature references library functions.
  self.assertEqual(0, self._getNumVariables(loaded_graph_def))
  self.assertTrue(loaded_graph_def.library.function)

  frozen_func = convert_to_constants.convert_variables_to_constants_v2(
      input_func)
  self._testConvertedFunction(root, root.f, frozen_func, input_data)
def test_weakref_root(self):
  """A Checkpoint whose root is a weakref saves/restores and keeps no strong ref.

  After `root` is deleted, the only remaining reference to the variable must
  be the one held through `ref`.
  """
  root = tracking.AutoTrackable()
  root.v = variables_lib.Variable(1)
  ref = root.v.ref()
  ckpt = trackable_utils.Checkpoint(root=weakref.ref(root))
  save_path = ckpt.save(os.path.join(self.get_temp_dir(), "ckpt"))
  # Change the value so the restore below is observable.
  root.v.assign(2)
  ckpt.restore(save_path)
  self.assertEqual(root.v.numpy(), 1)
  del root
  # Verifying if the variable is only referenced from `ref`.
  # We expect the reference counter to be 1, but `sys.getrefcount` reports
  # one higher reference counter because a temporary is created when we call
  # sys.getrefcount(). Hence check if the number returned is 2.
  # https://docs.python.org/3/library/sys.html#sys.getrefcount
  self.assertEqual(sys.getrefcount(ref.deref()), 2)
def test_shapes_available(self, cycles):
  """Static input and output shapes survive save/load cycles."""

  @def_function.function(input_signature=[
      tensor_spec.TensorSpec([None, 3], dtypes.int32),
      tensor_spec.TensorSpec([None, 2], dtypes.int32)
  ])
  def func(x, y):
    return array_ops.concat([x, y], axis=1)

  root = tracking.AutoTrackable()
  root.f = func
  loaded = self.cycle(root, cycles)

  graph = loaded.f.get_concrete_function().graph
  input_x, input_y = graph.inputs
  for tensor, expected_shape in ((input_x, [None, 3]), (input_y, [None, 2])):
    self.assertEqual(expected_shape, tensor.shape.as_list())
  # Concatenating along axis 1 gives 3 + 2 == 5 columns.
  output, = graph.outputs
  self.assertEqual([None, 5], output.shape.as_list())
def test_revived_concrete_function_kwargs(self, cycles):
  """A concrete function can be called with keywords before and after loading."""

  @def_function.function
  def func(x, y):
    return x * (y + 1.)

  scalar = tensor_spec.TensorSpec([], dtypes.float32)
  root = tracking.AutoTrackable()
  root.f = func.get_concrete_function(
      scalar, tensor_spec.TensorSpec([], dtypes.float32))

  def call_by_keyword(fn):
    # Keywords given in reverse order: 2. * (3. + 1.) == 8.
    return fn(y=constant_op.constant(3.), x=constant_op.constant(2.)).numpy()

  self.assertEqual(8., call_by_keyword(root.f))
  # TODO(andresp): Fix exporting of loaded concrete functions as signatures.
  imported = self.cycle(root, cycles, signatures={})
  self.assertEqual(8., call_by_keyword(imported.f))
def test_save_variable_devices(self, save_devices, meta_graph_only):
  """Variable device strings are exported only under SAVE_VARIABLE_DEVICES.

  Two variables are placed on CPU:0 and CPU:1. The model is exported either as
  a bare MetaGraph or as a full SavedModel, and the resulting GraphDef is
  inspected.
  """
  context._reset_context()
  cpu_devices = context.context().list_physical_devices("CPU")
  # With a single physical CPU, split it into two logical devices so that
  # both CPU:0 and CPU:1 exist.
  if len(cpu_devices) == 1:
    context.context().set_logical_device_configuration(
        cpu_devices[0],
        [context.LogicalDeviceConfiguration(),
         context.LogicalDeviceConfiguration()])
  context.ensure_initialized()

  root = tracking.AutoTrackable()
  with ops.device("CPU:0"):
    root.v0 = variables.Variable(1., name="v0")
  with ops.device("CPU:1"):
    root.v1 = variables.Variable(1., name="v1")

  options = save_options.SaveOptions(
      experimental_variable_policy=save_devices)
  file_name = os.path.join(self.get_temp_dir(), "saved_model")
  if meta_graph_only:
    save.export_meta_graph(obj=root, filename=file_name, options=options)
    graph_def = meta_graph.read_meta_graph_file(file_name).graph_def
  else:
    save.save(obj=root, export_dir=file_name, options=options)
    graph_def = loader_impl.parse_saved_model(
        file_name).meta_graphs[0].graph_def

  def find_node(node_name):
    return next((n for n in graph_def.node if n.name == node_name), None)

  v0_node = find_node("v0")
  v1_node = find_node("v1")
  self.assertIsNotNone(v0_node)
  self.assertIsNotNone(v1_node)
  if save_devices == save_options.VariablePolicy.SAVE_VARIABLE_DEVICES:
    self.assertIn("CPU:0", v0_node.device)
    self.assertIn("CPU:1", v1_node.device)
  else:
    self.assertEmpty(v0_node.device)
    self.assertEmpty(v1_node.device)
def _create_saved_model_with_debug_ops(self):
  """Saves a model whose exported function also builds debug ops.

  Inside the function, Print, Assert and check_numerics ops are created on
  `x` but not used in the returned value. The model is written to
  `SAVED_MODEL_DIR` under `self._tmp_dir`.
  """
  model = tracking.AutoTrackable()
  model.w = variables.Variable(tf.random.uniform([2, 2]))

  @def_function.function
  def exported_function(x):
    model.x = constant_op.constant([[37.0, -23.0], [1.0, 4.0]])
    model.y = tf.matmul(model.x, model.w)
    tf.compat.v1.Print(model.x, [model.x])
    tf.compat.v1.Assert(tf.greater(tf.reduce_max(model.x), 0), [model.x])
    tf.compat.v1.check_numerics(model.x, 'NaN found')
    return model.y * x

  model.f = exported_function
  scalar_spec = tensor_spec.TensorSpec([], dtypes.float32)
  concrete_fn = model.f.get_concrete_function(scalar_spec)
  export_dir = os.path.join(self._tmp_dir, SAVED_MODEL_DIR)
  save(model, export_dir, concrete_fn)
def test_save_composite_tensor_signature(self):
  """A RaggedTensor signature works through both the V2 and V1 loaders."""

  @def_function.function(
      input_signature=[ragged_tensor.RaggedTensorSpec(ragged_rank=2)])
  def f(x):
    return {"output_key": x}

  root = tracking.AutoTrackable()
  path = os.path.join(self.get_temp_dir(), "saved_model")
  inp = ragged_factory_ops.constant([[[1.0, 2.0], [3.0]], [[5.]]])
  # Flat component tensors fed to the signature. They appear to be the flat
  # values plus the two levels of row splits of `inp`.
  flat_inp = {
      "x": constant_op.constant([1., 2., 3., 5]),
      "x_1": constant_op.constant([0, 2, 3], dtype=dtypes.int64),
      "x_2": constant_op.constant([0, 2, 3, 4], dtype=dtypes.int64)
  }
  save.save(root, path, signatures={"key": f.get_concrete_function()})

  # V2 APIs: the loaded signature reassembles the ragged output.
  imported = load.load(path)
  v2_outputs = imported.signatures["key"](**flat_inp)
  self.assertAllEqual(inp, v2_outputs["output_key"])

  # V1 APIs: run the signature in a Session using the component tensors.
  graph = ops.Graph()
  with graph.as_default(), session_lib.Session() as session:
    meta_graph_def = loader.load(session, [tag_constants.SERVING], path)
    signature = meta_graph_def.signature_def["key"]
    tensor_by_name = session.graph.get_tensor_by_name
    feed_dict = {
        tensor_by_name(signature.inputs[arg_name].name): arg_value.numpy()
        for arg_name, arg_value in flat_inp.items()
    }
    # The composite output is exposed as its component tensors.
    output_components = (
        signature.outputs["output_key"].composite_tensor.components)
    components_keys = ["x", "x_1", "x_2"]
    fetches = {
        key: tensor_by_name(tensor_info.name)
        for key, tensor_info in zip(components_keys, output_components)
    }
    outputs = session.run(fetches, feed_dict)
  self.assertAllClose(flat_inp, outputs)
def test_concrete_function_backprop(self, cycles):
  """Gradients flow through a concrete function before and after loading."""

  @def_function.function(
      input_signature=[tensor_spec.TensorSpec([None], dtypes.float32)])
  def func(x):
    return x ** 2.

  root = tracking.AutoTrackable()
  root.f = func.get_concrete_function()

  def gradient_at_one(fn):
    # d/dx x**2 evaluated at x == 1 is 2.
    with backprop.GradientTape() as tape:
      point = constant_op.constant(1.)
      tape.watch(point)
      result = fn(point)
    return tape.gradient(result, point)

  self.assertEqual(2., gradient_at_one(root.f).numpy())
  # TODO(andresp): Fix exporting of loaded concrete functions as signatures.
  imported = self.cycle(root, cycles, signatures={})
  self.assertEqual(2., gradient_at_one(imported.f).numpy())
def test_captured_constant(self, cycles):
  """Two functions capturing one constant share a single capture after load."""
  const = array_ops.zeros([100])
  root = tracking.AutoTrackable()
  root.f = def_function.function(lambda: const + 1.)
  root.g = def_function.function(lambda: const + 2.)

  def check_values(obj):
    self.assertAllClose(array_ops.ones([100]), obj.f())
    self.assertAllClose(2. * array_ops.ones([100]), obj.g())

  check_values(root)
  imported = self.cycle(root, cycles)
  check_values(imported)

  # TODO(b/123408994): Use the public get_concrete_function.
  f_concrete, g_concrete = (
      fn._list_all_concrete_functions_for_serialization()[0]
      for fn in (imported.f, imported.g))
  self.assertLen(f_concrete.captured_inputs, 1)
  self.assertLen(g_concrete.captured_inputs, 1)
  # Both functions must hold the very same captured EagerTensor; the constant
  # must not have been duplicated.
  self.assertIs(f_concrete.captured_inputs[0], g_concrete.captured_inputs[0])
def test_named_tuple(self, cycles):
  """A namedtuple-subclass argument structure survives save/load cycles."""

  class NamedTupleType(collections.namedtuple("NamedTupleType", ["a", "b"])):
    pass

  @def_function.function
  def f(x):
    return x.a + x.b

  spec = NamedTupleType(
      a=tensor_spec.TensorSpec(None, dtypes.float32, name="a"),
      b=tensor_spec.TensorSpec(None, dtypes.float32, name="b"))
  f.get_concrete_function(spec)

  model = tracking.AutoTrackable()
  model.__call__ = f
  loaded = self.cycle(model, cycles)
  argument = NamedTupleType(a=constant_op.constant(1.),
                            b=constant_op.constant(2.))
  self.assertAllClose(3., loaded(argument))
def test_functions_list(self, cycles):
  """A list of tf.functions over a list of variables survives save/load.

  One of the functions creates the second variable itself. After loading,
  reassigning the variables must be reflected in both losses.
  """
  root = tracking.AutoTrackable()
  v1 = variables.Variable(1.)
  root.losses = [def_function.function(lambda: math_ops.reduce_sum(v1 ** 2))]
  root.variables = [v1]

  @def_function.function
  def _v2_loss():
    # Guarded so the variable is created only once, while root.variables
    # still holds a single entry.
    if len(root.variables) == 1:
      v2 = variables.Variable(2.)
      root.variables.append(v2)
    return math_ops.reduce_sum(root.variables[1] ** 2)

  root.losses.append(_v2_loss)
  self.assertAllClose([1., 4.], [loss() for loss in root.losses])
  imported = self.cycle(root, cycles)
  self.assertAllClose([1., 4.], [loss() for loss in imported.losses])
  # New values 3 and 4 give squared losses 9 and 16.
  imported.variables[0].assign(3.)
  imported.variables[1].assign(4.)
  self.assertAllClose([9., 16.], [loss() for loss in imported.losses])
def test_concrete_function(self, cycles):
  """A [None]-shaped concrete function accepts any length, even after load."""

  @def_function.function(
      input_signature=[tensor_spec.TensorSpec([None], dtypes.int32)])
  def func(x):
    return 2 * x

  root = tracking.AutoTrackable()
  root.f = func.get_concrete_function()

  def doubled(fn, values):
    return fn(constant_op.constant(values)).numpy()

  self.assertAllEqual([2], doubled(root.f, [1]))
  self.assertAllEqual([2, 4], doubled(root.f, [1, 2]))
  # TODO(andresp): Fix exporting of loaded concrete functions as signatures.
  imported = self.cycle(root, cycles, signatures={})
  self.assertAllEqual([2, 4, 6, 8], doubled(imported.f, [1, 2, 3, 4]))
  self.assertAllEqual([2, 4, 6], doubled(imported.f, [1, 2, 3]))
def test_load_in_graph_mode(self, cycles):
  """A SavedModel can be loaded and run inside a TF1-style graph."""
  root = tracking.AutoTrackable()
  root.v1 = variables.Variable(1.)
  root.v2 = variables.Variable(2.)
  scalar_signature = [tensor_spec.TensorSpec(None, dtypes.float32)]
  root.f = def_function.function(
      lambda x: root.v2 * x, input_signature=scalar_signature)

  # The explicit save/load below is the final cycle, so run one fewer first.
  # `root` is rebound on purpose: the lambda above refers to it by name.
  if cycles > 1:
    root = self.cycle(root, cycles - 1)
  path = tempfile.mkdtemp(prefix=self.get_temp_dir())
  save.save(root, path)

  with ops.Graph().as_default():
    imported = load.load(path)
    var_v1 = imported.v1
    output = imported.f(constant_op.constant(2.))
    with monitored_session.MonitoredSession() as sess:
      self.assertEqual(1.0, sess.run(var_v1))
      # v2 * 2. == 4.
      self.assertEqual(4.0, sess.run(output))
def testScalarModel(self):
  """Freezing a two-variable scalar model removes all variables and calls."""
  input_data = {"x": constant_op.constant(1., shape=[])}
  root = tracking.AutoTrackable()
  root.v1 = variables.Variable(3.)
  root.v2 = variables.Variable(2.)
  root.f = def_function.function(lambda x: root.v1 * root.v2 * x)

  input_func = root.f.get_concrete_function(input_data["x"])
  self.assertEqual(2, self._getNumVariables(input_func.graph.as_graph_def()))

  output_func = convert_to_constants.convert_variables_to_constants_v2(
      input_func)
  frozen_graph_def = output_func.graph.as_graph_def()
  self.assertEqual(0, self._getNumVariables(frozen_graph_def))
  self.assertFalse(self._hasStatefulPartitionedCallOp(frozen_graph_def))

  self._testConvertedFunction(root, root.f, output_func, input_data)