def test_create_variables_with_same_name(self):
  """Same-named variables are uniquified, separately for each wrapped trace."""

  def names_to_values(wrapped):
    # Snapshot of the variables held by `wrapped`, keyed by variable name.
    # pylint: disable=protected-access
    held_variables = wrapped._variable_holder.variables.values()
    # pylint: enable=protected-access
    return {var.name: var.numpy() for var in held_variables}

  def make_first_pair():
    first = variables.Variable(0, name='v')
    second = variables.Variable(1, name='v')
    return first, second

  first_wrapped = wrap_function.wrap_function(make_first_pair, [])
  # Assert that variable names are uniquified.
  self.assertDictEqual({'v:0': 0, 'v_1:0': 1},
                       names_to_values(first_wrapped))

  # Uniquification should reset in separate calls to wrap_function.
  def make_second_pair():
    first = variables.Variable(3, name='v')
    second = variables.Variable(4, name='v')
    return first, second

  second_wrapped = wrap_function.wrap_function(make_second_pair, [])
  self.assertDictEqual({'v:0': 3, 'v_1:0': 4},
                       names_to_values(second_wrapped))
def test_to_proto(self):
  """Saves and restores a variable through ops named by `saver.to_proto()`."""
  original = resource_variable_ops.ResourceVariable(2.)
  saver = functional_saver.MultiDeviceSaver(
      saveable_object_util.saveable_objects_for_op(original, "x"))
  prefix = os.path.join(self.get_temp_dir(), "ckpt")

  protos = []

  def record_proto():
    # Called while tracing; stashes the proto so the test can read it.
    protos.append(saver.to_proto())

  wrapped = wrap_function.wrap_function(record_proto, signature=())
  self.assertEqual(1, len(protos))
  proto = protos[0]

  graph = wrapped.graph
  save = wrapped.prune(
      feeds=graph.get_tensor_by_name(proto.filename_tensor_name),
      fetches=graph.get_tensor_by_name(proto.save_tensor_name))
  restore = wrapped.prune(
      feeds=graph.get_tensor_by_name(proto.filename_tensor_name),
      fetches=graph.get_operation_by_name(proto.restore_op_name))

  # Save the value 2., overwrite it, then restore and expect 2. again.
  save_path = save(constant_op.constant(prefix))
  original.assign(1.)
  restore(constant_op.constant(save_path))
  self.assertEqual(2., self.evaluate(original))

  # A second saver restores the same checkpoint into a different variable.
  other = resource_variable_ops.ResourceVariable(3.)
  second_saver = functional_saver.MultiDeviceSaver(
      saveable_object_util.saveable_objects_for_op(other, "x"))
  second_saver.restore(save_path)
  self.assertEqual(2., self.evaluate(other))
def load(self, tags):
  """Creates an object from the MetaGraph identified by `tags`.

  Args:
    tags: Forwarded to `self.get_meta_graph_def_from_tags`.

  Returns:
    A `tracking.AutoCheckpointable` with `signatures` and `variables` set.

  Raises:
    NotImplementedError: If any node in the GraphDef has op "VariableV2".
  """
  meta_graph_def = self.get_meta_graph_def_from_tags(tags)
  has_ref_variable = any(
      node.op == "VariableV2" for node in meta_graph_def.graph_def.node)
  if has_ref_variable:
    raise NotImplementedError(
        "Importing a SavedModel which contains RefVariables. This is not "
        "currently supported. Running tf.enable_resource_variables() "
        "before creating exported variables will fix this.")

  # One-element list handed to `load_graph`; presumably filled in with the
  # saver while tracing (confirm against `load_graph`).
  graph_outputs = [None]
  trace_fn = functools.partial(self.load_graph, graph_outputs, meta_graph_def)
  wrapped = wrap_function.wrap_function(trace_fn, signature=[])
  (saver,) = graph_outputs
  self.restore_variables(wrapped, saver)

  with wrapped.graph.as_default():
    init_op = loader_impl.get_init_op(meta_graph_def)
  if init_op is not None:
    # TODO(allenl): Deal with assets
    run_init = wrapped.prune(
        feeds=[], fetches=[wrapped.graph.as_graph_element(init_op)])
    run_init()

  signature_functions = self._extract_signatures(wrapped, meta_graph_def)
  root = tracking.AutoCheckpointable()
  root.signatures = signature_serialization.create_signature_map(
      signature_functions)
  root.variables = list(wrapped.graph.variables)
  return root
def testImportedFunctionsRegistered(self):
  """Wraps an imported GraphDef containing a dataset reduce and runs it."""
  if test.is_built_with_gpu_support():
    self.skipTest(
        "Disabling this new test due to errors with cuda and rocm")

  with ops.Graph().as_default() as graph:
    variant_in = array_ops.placeholder(dtypes.variant, shape=[], name='foo')
    element_structure = structure.TensorStructure(dtypes.int32, [])
    dataset = dataset_ops.from_variant(variant_in, structure=element_structure)
    total = dataset.reduce(
        array_ops.zeros([], dtype=dtypes.int32), lambda p, q: p + q)

  graph_def = graph.as_graph_def()

  def fn_to_wrap(a):
    # Re-import the graph with the placeholder remapped to the argument.
    imported = graph_def_importer.import_graph_def(
        graph_def,
        input_map={variant_in.name: a},
        return_elements=[total.name])
    return imported[0]

  variant_spec = tensor_spec.TensorSpec((), dtypes.variant)
  wrapped_fn = wrap_function.wrap_function(fn_to_wrap, [variant_spec])

  source = dataset_ops.Dataset.from_tensor_slices([10, 20])
  variant = dataset_ops.to_variant(source)
  self.evaluate(wrapped_fn(variant))
def testNoArguments(self):
  """A wrapped zero-argument function returns its constant."""

  def make_one():
    return constant_op.constant(1.)

  wrapped = wrap_function.wrap_function(make_one, [])
  result = wrapped()
  self.assertAllEqual(1.0, result)
def load(self, tags):
  """Creates an object from the MetaGraph identified by `tags`.

  Args:
    tags: Forwarded to `self.get_meta_graph_def_from_tags`.

  Returns:
    A `tracking.AutoTrackable` with `asset_paths`, `signatures` and
    `variables` set, plus `initializer` when the MetaGraph has an init op.
  """
  meta_graph_def = self.get_meta_graph_def_from_tags(tags)

  # One-element list handed to `load_graph`; presumably filled in with the
  # saver while tracing (confirm against `load_graph`).
  graph_outputs = [None]
  trace_fn = functools.partial(self.load_graph, graph_outputs, meta_graph_def)
  wrapped = wrap_function.wrap_function(trace_fn, signature=[])
  (saver,) = graph_outputs
  self.restore_variables(wrapped, saver)

  with wrapped.graph.as_default():
    init_op = loader_impl.get_init_op(meta_graph_def)

  root = tracking.AutoTrackable()
  if init_op is None:
    root.asset_paths = []
  else:
    asset_feed_tensors = []
    asset_paths = []
    asset_tensors = loader_impl.get_asset_tensors(
        self._export_dir, meta_graph_def)
    for tensor_name, value in asset_tensors.items():
      asset_feed_tensors.append(wrapped.graph.as_graph_element(tensor_name))
      asset_paths.append(tracking.TrackableAsset(value))
    # The pruned init function is fed the asset tensors.
    init_fn = wrapped.prune(
        feeds=asset_feed_tensors,
        fetches=[wrapped.graph.as_graph_element(init_op)])
    initializer = _Initializer(init_fn, asset_paths)
    initializer._initialize()  # pylint: disable=protected-access
    root.initializer = initializer
    root.asset_paths = asset_paths

  signature_functions = self._extract_signatures(wrapped, meta_graph_def)
  root.signatures = signature_serialization.create_signature_map(
      signature_functions)
  root.variables = list(wrapped.graph.variables)
  return root
def testGradientsOfPrune(self):
  """Gradients agree between a wrapped function and its pruned version."""
  v1 = variables.Variable(2.)
  v2_holder = []

  def f(z):
    v2 = variables.Variable(3.)
    v2_holder.append(v2)
    return array_ops.identity(v1 * v2 * z, 'fetch')

  f_wrapped = wrap_function.wrap_function(
      f, [tensor_spec.TensorSpec((), dtype=dtypes.float32)])

  def check_output_and_gradients(fn):
    # Calls `fn` on a watched constant and checks d(out)/d[x, v1, v2].
    x = constant_op.constant(1.)
    with backprop.GradientTape() as tape:
      tape.watch(x)
      out = fn(x)
    grads = tape.gradient(out, [x, v1, v2_holder[0]])
    self.assertAllEqual(6.0, out)
    self.assertAllEqual([6.0, 3.0, 2.0], grads)

  check_output_and_gradients(f_wrapped)

  pruned = f_wrapped.prune(
      feeds=f_wrapped.inputs,
      fetches=f_wrapped.graph.get_tensor_by_name('fetch:0'))
  check_output_and_gradients(pruned)
def load(self, tags):
  """Creates an object from the MetaGraph identified by `tags`.

  Args:
    tags: Forwarded to `self.get_meta_graph_def_from_tags`.

  Returns:
    A `tracking.AutoTrackable` with `asset_paths`, `signatures` and
    `variables` set, plus `initializer` when the MetaGraph has an init op.
  """
  meta_graph_def = self.get_meta_graph_def_from_tags(tags)

  # One-element list handed to `load_graph`; presumably filled in with the
  # saver while tracing (confirm against `load_graph`).
  graph_outputs = [None]
  trace_fn = functools.partial(self.load_graph, graph_outputs, meta_graph_def)
  wrapped = wrap_function.wrap_function(trace_fn, signature=[])
  (saver,) = graph_outputs
  self.restore_variables(wrapped, saver)

  with wrapped.graph.as_default():
    init_op = loader_impl.get_init_op(meta_graph_def)

  root = tracking.AutoTrackable()
  if init_op is None:
    root.asset_paths = []
  else:
    asset_feed_tensors = []
    asset_paths = []
    asset_tensors = loader_impl.get_asset_tensors(
        self._export_dir, meta_graph_def)
    for tensor_name, value in asset_tensors.items():
      asset_feed_tensors.append(wrapped.graph.as_graph_element(tensor_name))
      asset_paths.append(tracking.TrackableAsset(value))
    # The pruned init function is fed the asset tensors.
    init_fn = wrapped.prune(
        feeds=asset_feed_tensors,
        fetches=[wrapped.graph.as_graph_element(init_op)])
    initializer = _Initializer(init_fn, asset_paths)
    initializer.initialize()
    root.initializer = initializer
    root.asset_paths = asset_paths

  signature_functions = self._extract_signatures(wrapped, meta_graph_def)
  root.signatures = signature_serialization.create_signature_map(
      signature_functions)
  root.variables = list(wrapped.graph.variables)
  return root
def wrap_and_execute(self):
  """Runs `graph_function` on `self`, via `wrap_function` when eager."""
  if not context.executing_eagerly():
    # use the original function
    graph_function(self)
    return
  # use the wrapped graph function
  wrapped = wrap_function.wrap_function(graph_function, [self])
  wrapped()
def testCollectionsIsolation(self):
  """Collections of separately wrapped functions are not shared."""
  shared = variables.Variable(2.)
  losses_key = ops.GraphKeys.LOSSES
  trainable_key = ops.GraphKeys.TRAINABLE_VARIABLES

  f_created = []

  def f():
    inner = variables.Variable(3.)
    f_created.append(inner)
    ops.add_to_collection(losses_key, inner * constant_op.constant(3.))
    return array_ops.identity(
        shared * inner * constant_op.constant(1.), 'fetch')

  f_wrapped = wrap_function.wrap_function(f, [])
  self.assertAllEqual(6.0, f_wrapped())
  self.assertEqual(len(f_wrapped.graph.get_collection(losses_key)), 1)
  f_trainable = f_wrapped.graph.get_collection(trainable_key)
  self.assertEqual(len(f_trainable), 1)
  self.assertIs(f_trainable[0], f_created[0])

  g_created = []

  def g():
    inner = variables.Variable(4.)
    g_created.append(inner)
    ops.add_to_collection(losses_key, inner * constant_op.constant(3.))
    return array_ops.identity(
        shared * inner * constant_op.constant(1.), 'fetch')

  g_wrapped = wrap_function.wrap_function(g, [])
  self.assertAllEqual(8.0, g_wrapped())
  self.assertEqual(len(g_wrapped.graph.get_collection(losses_key)), 1)
  g_trainable = g_wrapped.graph.get_collection(trainable_key)
  self.assertEqual(len(g_trainable), 1)
  self.assertIs(g_trainable[0], g_created[0])

  # Both have only one value, and their values aren't equal. So no sharing.
  g_losses = g_wrapped.graph.get_collection(losses_key)
  f_losses = f_wrapped.graph.get_collection(losses_key)
  self.assertNotEqual(g_losses, f_losses)
def load(self, tags):
  """Creates an object from the MetaGraph identified by `tags`.

  Args:
    tags: Forwarded to `self.get_meta_graph_def_from_tags`.

  Returns:
    A `tracking.AutoTrackable` carrying the initializer, asset paths,
    signatures, variables, TensorFlow version strings, the wrapped graph and
    its `prune` method.
  """
  meta_graph_def = self.get_meta_graph_def_from_tags(tags)

  # One-element list handed to `load_graph`; presumably filled in with the
  # saver while tracing (confirm against `load_graph`).
  graph_outputs = [None]
  trace_fn = functools.partial(self.load_graph, graph_outputs, meta_graph_def)
  wrapped = wrap_function.wrap_function(trace_fn, signature=[])
  (saver,) = graph_outputs
  self.restore_variables(wrapped, saver)

  with wrapped.graph.as_default():
    init_op = loader_impl.get_init_op(meta_graph_def)
    if not init_op:
      init_op = monitored_session.Scaffold.default_local_init_op()
    # Add a dummy Tensor we know we can fetch to add control dependencies to.
    init_anchor = constant_op.constant(0., name="dummy_fetch")

  root = tracking.AutoTrackable()
  asset_feed_tensors = []
  asset_paths = []
  asset_tensors = loader_impl.get_asset_tensors(
      self._export_dir, meta_graph_def)
  for tensor_name, value in asset_tensors.items():
    asset_feed_tensors.append(wrapped.graph.as_graph_element(tensor_name))
    asset_paths.append(tracking.TrackableAsset(value))

  init_fn = wrapped.prune(
      feeds=asset_feed_tensors,
      fetches=[init_anchor, wrapped.graph.as_graph_element(init_op)])
  initializer = _Initializer(init_fn, asset_paths)
  # pylint: disable=protected-access
  local_init_op, _ = initializer._initialize()
  # pylint: enable=protected-access

  with ops.init_scope():
    if not context.executing_eagerly():
      ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, local_init_op)
      local_variables = wrapped.graph.get_collection_ref(
          ops.GraphKeys.LOCAL_VARIABLES)
      for variable in local_variables:
        # pylint: disable=protected-access
        variable._initializer_op = local_init_op
        # pylint: enable=protected-access

  root.initializer = initializer
  root.asset_paths = asset_paths
  signature_functions = self._extract_signatures(wrapped, meta_graph_def)
  root.signatures = signature_serialization.create_signature_map(
      signature_functions)
  root.variables = list(wrapped.graph.variables)

  meta_info = meta_graph_def.meta_info_def
  root.tensorflow_version = meta_info.tensorflow_version
  root.tensorflow_git_version = meta_info.tensorflow_git_version
  root.graph = wrapped.graph
  root.prune = wrapped.prune
  return root
def test_updates_in_wrap_function(self):
  """BatchNormalization reports two updates while traced by wrap_function."""

  def my_func():
    layer = normalization.BatchNormalization()
    ones = array_ops.ones((10, 1))
    outputs = layer(ones, training=True)
    # Updates should be tracked in a `wrap_function`.
    self.assertLen(layer.updates, 2)
    return outputs

  wrapped_fn = wrap_function.wrap_function(my_func, [])
  wrapped_fn()
def testPruneOperations(self):
  """Pruning to an Operation fetch runs it and yields `None` for it."""
  counter = variables.Variable(0)

  def add_one():
    counter.assign_add(1, name='increment', read_value=False)

  add_one_wrapped = wrap_function.wrap_function(add_one, [])
  increment_op = add_one_wrapped.graph.get_operation_by_name('increment')
  pruned = add_one_wrapped.prune(feeds=(), fetches=(increment_op,))
  self.assertEqual((None,), pruned())
  self.assertEqual(1, self.evaluate(counter))

  del add_one, add_one_wrapped

  def add_step():
    step = array_ops.placeholder(shape=[], dtype=dtypes.int32, name='step')
    counter.assign_add(step, name='increment', read_value=False)
    return constant_op.constant(1, name='other')

  step_wrapped = wrap_function.wrap_function(add_step, [])
  graph = step_wrapped.graph

  increments = step_wrapped.prune(
      feeds=graph.get_tensor_by_name('step:0'),
      fetches=(graph.get_operation_by_name('increment'),
               graph.get_tensor_by_name('other:0')))
  first_output, second_output = increments(constant_op.constant(2))
  input_names = [t.name for t in increments.inputs]
  self.assertEqual(['Placeholder:0', 'Placeholder_1:0'], input_names)
  self.assertIs(None, first_output)
  self.assertEqual(1, second_output.numpy())
  self.assertEqual(3, counter.numpy())

  # Fetching only 'other' must leave the counter untouched.
  does_not_increment = step_wrapped.prune(
      feeds=graph.get_tensor_by_name('step:0'),
      fetches=graph.get_tensor_by_name('other:0'))
  self.assertEqual(1, does_not_increment(constant_op.constant(3)).numpy())
  self.assertEqual(3, counter.numpy())
def testVariableLifting(self):
  """Variables from an imported MetaGraph are lifted out of the graph."""
  save_prefix = os.path.join(self.get_temp_dir(), 'meta_graph_test')

  # Export a small graph with one variable and an 'output' tensor.
  export_graph = ops.Graph()
  with export_graph.as_default():
    exported_var = variables.Variable(1.)
    array_ops.identity(exported_var + 1., name='output')
    saver = saver_lib.Saver([exported_var])
    with self.test_session() as session:
      session.run(exported_var.initializer)
      saver.save(session, save_prefix)

  def importer():
    saver_lib.import_meta_graph(save_prefix + '.meta')
    return ops.get_default_graph().as_graph_element('output:0')

  wrapped = wrap_function.wrap_function(importer, [])
  lifted_variables = list(wrapped.graph.variables)
  self.assertLen(lifted_variables, 1)

  init_element = wrapped.graph.as_graph_element(exported_var.initializer.name)
  initializer = wrapped.prune([], init_element)
  self.assertEqual(lifted_variables, list(initializer.graph.variables))
  self.assertEqual(initializer.graph.external_captures,
                   wrapped.graph.external_captures)

  @def_function.function
  def wraps_initializer():
    initializer()

  wraps_initializer()
  self.assertEqual(1., lifted_variables[0].numpy())
  initializer_graph = wraps_initializer.get_concrete_function().graph
  self._assert_single_captured_variable_argument(
      initializer_graph.as_graph_def())

  @def_function.function
  def wraps_wrapped():
    return wrapped()

  # Verify that the original graph also has the correct signature.
  wrapped_graph = wraps_wrapped.get_concrete_function().graph
  self._assert_single_captured_variable_argument(wrapped_graph.as_graph_def())

  # Now check that the graph runs wrapped, from eager, and when pruned.
  self.assertAllEqual(wraps_wrapped().numpy(),
                      lifted_variables[0].numpy() + 1.)
  self.assertAllEqual(wrapped().numpy(), lifted_variables[0].numpy() + 1.)
  pruned = wrapped.prune([], wrapped.graph.as_graph_element('output:0'))
  self.assertAllEqual(wrapped().numpy(), pruned().numpy())
def my_function_from_graph_def(graph_def, inputs, outputs, ref_captures):
  """Imports `graph_def` and returns a function pruned to inputs/outputs.

  Args:
    graph_def: GraphDef imported (with an empty name scope) while tracing.
    inputs: Structure of names resolved with `as_graph_element` as feeds.
    outputs: Structure of names resolved with `as_graph_element` as fetches.
    ref_captures: Iterable of `(tensor, placeholder)` pairs; each placeholder
      is looked up by name in the imported graph.

  Returns:
    The result of `prune` on the wrapped import.
  """

  def _imports_graph_def():
    importer.import_graph_def(graph_def, name="")

  wrapped_import = wrap_function.wrap_function(_imports_graph_def, [])
  import_graph = wrapped_import.graph

  # Pair each captured tensor with the same-named tensor in the new graph.
  remapped_captures = []
  for tensor, placeholder in ref_captures:
    internal = import_graph.get_tensor_by_name(placeholder.name)
    remapped_captures.append((tensor, internal))
  wrapped_import.graph.reset_captures(remapped_captures)

  feeds = nest.map_structure(import_graph.as_graph_element, inputs)
  fetches = nest.map_structure(import_graph.as_graph_element, outputs)
  return wrapped_import.prune(feeds, fetches)
def test_updates_in_wrap_function(self):
  """A Keras BatchNormalization layer tracks two updates after wrapping."""
  with context.eager_mode():
    layer = keras.layers.BatchNormalization()

    def my_func():
      ones = array_ops.ones((10, 1))
      return layer(ones, training=True)

    wrapped_fn = wrap_function.wrap_function(my_func, [])
    wrapped_fn()

    # Updates should be tracked in a `wrap_function`.
    self.assertLen(layer.updates, 2)
def load(self, tags):
  """Creates an object from the MetaGraph identified by `tags`.

  Args:
    tags: Forwarded to `self.get_meta_graph_def_from_tags`.

  Returns:
    A `tracking.AutoTrackable` carrying the initializer, asset paths,
    signatures, variables, TensorFlow version strings, the wrapped graph and
    its `prune` method.
  """
  meta_graph_def = self.get_meta_graph_def_from_tags(tags)

  # One-element list handed to `load_graph`; presumably filled in with the
  # saver while tracing (confirm against `load_graph`).
  graph_outputs = [None]
  trace_fn = functools.partial(self.load_graph, graph_outputs, meta_graph_def)
  wrapped = wrap_function.wrap_function(trace_fn, signature=[])
  (saver,) = graph_outputs
  self.restore_variables(wrapped, saver)

  with wrapped.graph.as_default():
    init_op = loader_impl.get_init_op(meta_graph_def)
    if not init_op:
      init_op = monitored_session.Scaffold.default_local_init_op()
    # Add a dummy Tensor we know we can fetch to add control dependencies to.
    init_anchor = constant_op.constant(0., name="dummy_fetch")

  root = tracking.AutoTrackable()
  asset_feed_tensors = []
  asset_paths = []
  asset_tensors = loader_impl.get_asset_tensors(
      self._export_dir, meta_graph_def)
  for tensor_name, value in asset_tensors.items():
    asset_feed_tensors.append(wrapped.graph.as_graph_element(tensor_name))
    asset_paths.append(tracking.TrackableAsset(value))

  init_fn = wrapped.prune(
      feeds=asset_feed_tensors,
      fetches=[init_anchor, wrapped.graph.as_graph_element(init_op)])
  initializer = _Initializer(init_fn, asset_paths)
  # pylint: disable=protected-access
  local_init_op, _ = initializer._initialize()
  # pylint: enable=protected-access

  with ops.init_scope():
    if not context.executing_eagerly():
      ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, local_init_op)
      local_variables = wrapped.graph.get_collection_ref(
          ops.GraphKeys.LOCAL_VARIABLES)
      for variable in local_variables:
        # pylint: disable=protected-access
        variable._initializer_op = local_init_op
        # pylint: enable=protected-access

  root.initializer = initializer
  root.asset_paths = asset_paths
  signature_functions = self._extract_signatures(wrapped, meta_graph_def)
  root.signatures = signature_serialization.create_signature_map(
      signature_functions)
  root.variables = list(wrapped.graph.variables)

  meta_info = meta_graph_def.meta_info_def
  root.tensorflow_version = meta_info.tensorflow_version
  root.tensorflow_git_version = meta_info.tensorflow_git_version
  root.graph = wrapped.graph
  root.prune = wrapped.prune
  return root
def testCollectionsIsolation(self):
  """Collections of separately wrapped functions are not shared."""
  shared = variables.Variable(2.)
  losses_key = ops.GraphKeys.LOSSES
  trainable_key = ops.GraphKeys.TRAINABLE_VARIABLES

  f_created = []

  def f():
    inner = variables.Variable(3.)
    f_created.append(inner)
    ops.add_to_collection(losses_key, inner * constant_op.constant(3.))
    return array_ops.identity(
        shared * inner * constant_op.constant(1.), 'fetch')

  f_wrapped = wrap_function.wrap_function(f, [])
  self.assertAllEqual(6.0, f_wrapped())
  self.assertEqual(len(f_wrapped.graph.get_collection(losses_key)), 1)
  f_trainable = f_wrapped.graph.get_collection(trainable_key)
  self.assertEqual(len(f_trainable), 1)
  self.assertIs(f_trainable[0], f_created[0])

  g_created = []

  def g():
    inner = variables.Variable(4.)
    g_created.append(inner)
    ops.add_to_collection(losses_key, inner * constant_op.constant(3.))
    return array_ops.identity(
        shared * inner * constant_op.constant(1.), 'fetch')

  g_wrapped = wrap_function.wrap_function(g, [])
  self.assertAllEqual(8.0, g_wrapped())
  self.assertEqual(len(g_wrapped.graph.get_collection(losses_key)), 1)
  g_trainable = g_wrapped.graph.get_collection(trainable_key)
  self.assertEqual(len(g_trainable), 1)
  self.assertIs(g_trainable[0], g_created[0])

  # Both have only one value, and their values aren't equal. So no sharing.
  g_losses = g_wrapped.graph.get_collection(losses_key)
  f_losses = f_wrapped.graph.get_collection(losses_key)
  self.assertNotEqual(g_losses, f_losses)
def testPruneOperations(self):
  """Pruning to an Operation fetch runs it and yields `None` for it."""
  counter = variables.Variable(0)

  def add_one():
    counter.assign_add(1, name='increment', read_value=False)

  add_one_wrapped = wrap_function.wrap_function(add_one, [])
  increment_op = add_one_wrapped.graph.get_operation_by_name('increment')
  pruned = add_one_wrapped.prune(feeds=(), fetches=(increment_op,))
  self.assertEqual((None,), pruned())
  self.assertEqual(1, self.evaluate(counter))

  del add_one, add_one_wrapped

  def add_step():
    step = array_ops.placeholder(shape=[], dtype=dtypes.int32, name='step')
    counter.assign_add(step, name='increment', read_value=False)
    return constant_op.constant(1, name='other')

  step_wrapped = wrap_function.wrap_function(add_step, [])
  graph = step_wrapped.graph

  increments = step_wrapped.prune(
      feeds=graph.get_tensor_by_name('step:0'),
      fetches=(graph.get_operation_by_name('increment'),
               graph.get_tensor_by_name('other:0')))
  first_output, second_output = increments(constant_op.constant(2))
  input_names = [t.name for t in increments.inputs]
  self.assertEqual(['step:0', 'increment/resource:0'], input_names)
  self.assertIs(None, first_output)
  self.assertEqual(1, second_output.numpy())
  self.assertEqual(3, counter.numpy())

  # Fetching only 'other' must leave the counter untouched.
  does_not_increment = step_wrapped.prune(
      feeds=graph.get_tensor_by_name('step:0'),
      fetches=graph.get_tensor_by_name('other:0'))
  self.assertEqual(1, does_not_increment(constant_op.constant(3)).numpy())
  self.assertEqual(3, counter.numpy())
def testCaptures(self):
  """A wrapped function and its pruned form both see captured variables."""
  outer = variables.Variable(2.)

  def f():
    inner = variables.Variable(3.)
    product = outer * inner * constant_op.constant(1.)
    return array_ops.identity(product, 'fetch')

  f_wrapped = wrap_function.wrap_function(f, [])
  self.assertAllEqual(6.0, f_wrapped())

  fetch = f_wrapped.graph.get_tensor_by_name('fetch:0')
  pruned = f_wrapped.prune(feeds=(), fetches=fetch)
  self.assertAllEqual(6.0, pruned())
def test_create_variables_with_same_name(self):
  """Same-named variables are uniquified, separately for each wrapped trace."""

  def names_to_values(wrapped):
    # Snapshot of the variables held by `wrapped`, keyed by variable name.
    # pylint: disable=protected-access
    held_variables = wrapped._variable_holder.variables.values()
    # pylint: enable=protected-access
    return {var.name: var.numpy() for var in held_variables}

  def make_first_pair():
    first = variables.Variable(0, name='v')
    second = variables.Variable(1, name='v')
    return first, second

  first_wrapped = wrap_function.wrap_function(make_first_pair, [])
  # Assert that variable names are uniquified.
  self.assertDictEqual({'v:0': 0, 'v_1:0': 1},
                       names_to_values(first_wrapped))

  # Uniquification should reset in separate calls to wrap_function.
  def make_second_pair():
    first = variables.Variable(3, name='v')
    second = variables.Variable(4, name='v')
    return first, second

  second_wrapped = wrap_function.wrap_function(make_second_pair, [])
  self.assertDictEqual({'v:0': 3, 'v_1:0': 4},
                       names_to_values(second_wrapped))
def testDocString(self):
  """Each wrap_function call gets its own trace and its own variable."""

  def f(x, do_add):
    v = variables.Variable(5.0)
    # `do_add` is a plain Python bool, so the branch is chosen at trace time.
    op = v.assign_add(x) if do_add else v.assign_sub(x)
    with ops.control_dependencies([op]):
      return v.read_value()

  scalar_spec = tensor_spec.TensorSpec((), dtypes.float32)
  f_add = wrap_function.wrap_function(f, [scalar_spec, True])

  self.assertAllEqual(f_add(1.0), 6.0)
  self.assertAllEqual(f_add(1.0), 7.0)

  # Can call tf.compat.v1.wrap_function again to get a new trace, a new set
  # of variables, and possibly different non-template arguments.
  scalar_spec = tensor_spec.TensorSpec((), dtypes.float32)
  f_sub = wrap_function.wrap_function(f, [scalar_spec, False])

  self.assertAllEqual(f_sub(1.0), 4.0)
  self.assertAllEqual(f_sub(1.0), 3.0)
def testVariableLifting(self):
  """Variables from an imported MetaGraph are lifted out of the graph."""
  save_prefix = os.path.join(self.get_temp_dir(), 'meta_graph_test')

  # Export a small graph with one variable and an 'output' tensor.
  export_graph = ops.Graph()
  with export_graph.as_default():
    exported_var = variables.Variable(1.)
    array_ops.identity(exported_var + 1., name='output')
    saver = saver_lib.Saver([exported_var])
    with self.test_session() as session:
      session.run(exported_var.initializer)
      saver.save(session, save_prefix)

  def importer():
    saver_lib.import_meta_graph(save_prefix + '.meta')
    return ops.get_default_graph().as_graph_element('output:0')

  wrapped = wrap_function.wrap_function(importer, [])
  lifted_variables = list(wrapped.graph.variables)
  self.assertLen(lifted_variables, 1)

  init_element = wrapped.graph.as_graph_element(exported_var.initializer.name)
  initializer = wrapped.prune([], init_element)
  self.assertEqual(lifted_variables, list(initializer.graph.variables))
  self.assertEqual(initializer.graph.external_captures,
                   wrapped.graph.external_captures)

  @def_function.function
  def wraps_initializer():
    initializer()

  wraps_initializer()
  self.assertEqual(1., lifted_variables[0].numpy())
  initializer_graph = wraps_initializer.get_concrete_function().graph
  self._assert_single_captured_variable_argument(
      initializer_graph.as_graph_def())

  @def_function.function
  def wraps_wrapped():
    return wrapped()

  # Verify that the original graph also has the correct signature.
  wrapped_graph = wraps_wrapped.get_concrete_function().graph
  self._assert_single_captured_variable_argument(wrapped_graph.as_graph_def())

  # Now check that the graph runs wrapped, from eager, and when pruned.
  self.assertAllEqual(wraps_wrapped().numpy(),
                      lifted_variables[0].numpy() + 1.)
  self.assertAllEqual(wrapped().numpy(), lifted_variables[0].numpy() + 1.)
  pruned = wrapped.prune([], wrapped.graph.as_graph_element('output:0'))
  self.assertAllEqual(wrapped().numpy(), pruned().numpy())
def testPruneStatefulOpsFromWrappedFunc(self):
  """Stateful ops outside the fetches' fanin do not run when wrapped."""
  v0 = variables.Variable(0)
  v1 = variables.Variable(0)

  # When we wrap a function, we expect it to be executed with 'tf.Graph`
  # rules: it's allowed to prune all ops that are not in transitive fanin of
  # the fetches.
  def f(x):
    v0.assign_add(1, name='increment_v0')
    v1.assign_add(1, name='increment_v1')
    return x

  def assert_counters_untouched():
    self.assertEqual(0, v0.numpy())
    self.assertEqual(0, v1.numpy())

  f_wrapped = wrap_function.wrap_function(f, [1])
  self.assertEqual(1, f_wrapped().numpy())
  assert_counters_untouched()

  f_wrapped_with_name = wrap_function.wrap_function(f, [2], name='func')
  self.assertEqual(2, f_wrapped_with_name().numpy())
  assert_counters_untouched()
def testPrune(self):
  """Prunes a two-output function down to the `x * x` tensor."""
  captured_inputs = []
  captured_squares = []

  def f(x, y):
    captured_inputs.append(x)
    xx = x * x
    captured_squares.append(xx)
    return xx, 2 * y*y

  scalar_spec = tensor_spec.TensorSpec((), dtypes.float32)
  f_wrapped = wrap_function.wrap_function(f, [scalar_spec] * 2)

  # Feed the traced `x`, fetch only the traced square.
  f_pruned = f_wrapped.prune(captured_inputs[0], [captured_squares[0]])
  result = f_pruned(ops.convert_to_tensor(2.0))
  self.assertAllEqual(result, [4.0])
def test_operation_returned(self):
  """A pruned Operation fetch becomes a control output and runs on call."""
  target = variables.Variable(0)

  def f():
    target.assign(1, read_value=False, name='assign_to_v')

  f_wrapped = wrap_function.wrap_function(f, [])
  assign_op = f_wrapped.graph.get_operation_by_name('assign_to_v')
  f_pruned = f_wrapped.prune([], assign_op)

  control_output_names = [
      operation.name for operation in f_pruned.graph.control_outputs
  ]
  self.assertEqual(['assign_to_v'], control_output_names)

  # The assignment only happens once the pruned function is called.
  self.assertEqual(0, target.numpy())
  f_pruned()
  self.assertEqual(1, target.numpy())
def test_operation_returned(self):
  """A pruned Operation fetch becomes a control output and runs on call."""
  target = variables.Variable(0)

  def f():
    target.assign(1, read_value=False, name='assign_to_v')

  f_wrapped = wrap_function.wrap_function(f, [])
  assign_op = f_wrapped.graph.get_operation_by_name('assign_to_v')
  f_pruned = f_wrapped.prune([], assign_op)

  control_output_names = [
      operation.name for operation in f_pruned.graph.control_outputs
  ]
  self.assertEqual(['assign_to_v'], control_output_names)

  # The assignment only happens once the pruned function is called.
  self.assertEqual(0, target.numpy())
  f_pruned()
  self.assertEqual(1, target.numpy())
def load(self, tags):
  """Creates an object from the MetaGraph identified by `tags`.

  Args:
    tags: Forwarded to `self.get_meta_graph_def_from_tags`.

  Returns:
    A `tracking.AutoCheckpointable` with `signatures` and `variables` set.
  """
  meta_graph_def = self.get_meta_graph_def_from_tags(tags)

  # One-element list handed to `load_graph`; presumably filled in with the
  # saver while tracing (confirm against `load_graph`).
  graph_outputs = [None]
  trace_fn = functools.partial(self.load_graph, graph_outputs, meta_graph_def)
  wrapped = wrap_function.wrap_function(trace_fn, signature=[])
  (saver,) = graph_outputs
  self.restore_variables(wrapped, saver)

  with wrapped.graph.as_default():
    init_op = loader_impl.get_init_op(meta_graph_def)
  if init_op is not None:
    # TODO(allenl): Deal with assets
    run_init = wrapped.prune(
        feeds=[], fetches=[wrapped.graph.as_graph_element(init_op)])
    run_init()

  signature_functions = self._extract_signatures(wrapped, meta_graph_def)
  root = tracking.AutoCheckpointable()
  root.signatures = signature_serialization.create_signature_map(
      signature_functions)
  root.variables = list(wrapped.graph.variables)
  return root
def testPruneCaptures(self):
  """Pruning keeps captured variables, with or without explicit feeds."""
  outer = variables.Variable(2.)

  def f():
    inner = variables.Variable(3.)
    product = outer * inner * constant_op.constant(1.)
    return array_ops.identity(product, 'fetch')

  f_wrapped = wrap_function.wrap_function(f, [])
  self.assertAllEqual(6.0, f_wrapped())

  # Test pruning directly on the inputs
  pruned_on_inputs = f_wrapped.prune(
      feeds=f_wrapped.inputs,
      fetches=f_wrapped.graph.get_tensor_by_name('fetch:0'))
  self.assertAllEqual(6.0, pruned_on_inputs())

  # Test pruning with no inputs
  pruned_without_feeds = f_wrapped.prune(
      feeds=(), fetches=f_wrapped.graph.get_tensor_by_name('fetch:0'))
  self.assertAllEqual(6.0, pruned_without_feeds())
def testImportedFunctionsRegistered(self):
  """Wraps an imported GraphDef containing a dataset reduce and runs it."""
  if test_util.is_gpu_available():
    self.skipTest('not a GPU test')

  with ops.Graph().as_default() as graph:
    variant_in = array_ops.placeholder(dtypes.variant, shape=[], name='foo')
    element_spec = tensor_spec.TensorSpec([], dtypes.int32)
    dataset = dataset_ops.from_variant(variant_in, structure=element_spec)
    total = dataset.reduce(
        array_ops.zeros([], dtype=dtypes.int32), lambda p, q: p + q)

  graph_def = graph.as_graph_def()

  def fn_to_wrap(a):
    # Re-import the graph with the placeholder remapped to the argument.
    imported = graph_def_importer.import_graph_def(
        graph_def,
        input_map={variant_in.name: a},
        return_elements=[total.name])
    return imported[0]

  variant_spec = tensor_spec.TensorSpec((), dtypes.variant)
  wrapped_fn = wrap_function.wrap_function(fn_to_wrap, [variant_spec])

  source = dataset_ops.Dataset.from_tensor_slices([10, 20])
  variant = dataset_ops.to_variant(source)
  self.evaluate(wrapped_fn(variant))
def testWrapFuncDatasetDevice(self, device_type, dataset_reduce_fn):
  """Wraps an imported dataset-reduce graph under a device scope and runs it.

  Args:
    device_type: Passed to `config.list_logical_devices`; the test is skipped
      when no such device is listed.
    dataset_reduce_fn: Callable applied to `Dataset.range(10)`.
  """
  devices = config.list_logical_devices(device_type=device_type)
  if not devices:
    self.skipTest(
        'Skip when {} is not detected by TF'.format(device_type))

  @def_function.function
  def comp():
    return dataset_reduce_fn(dataset_ops.Dataset.range(10))

  graph = comp.get_concrete_function().graph

  def function_to_wrap():
    with ops.device(devices[0].name):
      graph_def = graph.as_graph_def()
      return graph_def_importer.import_graph_def(graph_def)

  with ops.device(devices[0].name):
    wrapped_noarg_fn = wrap_function.wrap_function(
        function_to_wrap, signature=[])

  wrapped_noarg_fn()
def testPruneRagged(self):
  """Prunes a function whose first argument is a RaggedTensor."""
  captured_inputs = []
  captured_squares = []

  def f(x, y):
    captured_inputs.append(x)
    xx = x * x
    captured_squares.append(xx)
    return xx, y * y

  x_spec = ragged_tensor.RaggedTensorSpec([None, None], dtypes.float32)
  y_spec = tensor_spec.TensorSpec((), dtypes.float32)
  f_wrapped = wrap_function.wrap_function(f, [x_spec, y_spec])

  f_pruned = f_wrapped.prune(captured_inputs[0], captured_squares[0])
  ragged_in = ragged_factory_ops.constant([[1.0, 2.0], [3.0]])
  expected = ragged_factory_ops.constant_value([[1.0, 4.0], [9.0]])

  # Note: when we call f_pruned, we must pass the RaggedTensor in using
  # its components, since that's the current convention for how concrete
  # functions handle structured inputs.
  actual = f_pruned(ragged_in.values, ragged_in.row_splits)
  self.assertAllEqual(actual, expected)
def load(self, tags):
  """Creates an object from the MetaGraph identified by `tags`.

  Args:
    tags: Forwarded to `self.get_meta_graph_def_from_tags`.

  Returns:
    A `tracking.AutoTrackable` carrying the initializer, asset paths,
    signatures, variables, TensorFlow version strings, the wrapped graph and
    its `prune` method.
  """
  meta_graph_def = self.get_meta_graph_def_from_tags(tags)

  # Suffix built from `ops.uid()`, applied when loading the function library
  # and when fixing up each node below.
  load_shared_name_suffix = "_load_{}".format(ops.uid())
  functions = function_deserialization.load_function_def_library(
      meta_graph_def.graph_def.library,
      load_shared_name_suffix=load_shared_name_suffix)
  # Replace existing functions in the MetaGraphDef with renamed functions so
  # we don't have duplicates or name collisions.
  meta_graph_def.graph_def.library.Clear()
  for function in functions.values():
    new_entry = meta_graph_def.graph_def.library.function.add()
    new_entry.CopyFrom(function.function_def)
  # We've renamed functions and shared names. We need the same operation on
  # the GraphDef itself for consistency.
  for node_def in meta_graph_def.graph_def.node:
    function_deserialization.fix_node_def(
        node_def,
        functions,
        load_shared_name_suffix,
        debug_name="MetaGraph import")

  # One-element list handed to `load_graph`; presumably filled in with the
  # saver while tracing (confirm against `load_graph`).
  graph_outputs = [None]
  trace_fn = functools.partial(self.load_graph, graph_outputs, meta_graph_def)
  wrapped = wrap_function.wrap_function(trace_fn, signature=[])
  (saver,) = graph_outputs
  self.restore_variables(wrapped, saver)

  with wrapped.graph.as_default():
    init_op = loader_impl.get_init_op(meta_graph_def)
    if not init_op:
      init_op = monitored_session.Scaffold.default_local_init_op()
    # Add a dummy Tensor we know we can fetch to add control dependencies to.
    init_anchor = constant_op.constant(0., name="dummy_fetch")

  root = tracking.AutoTrackable()
  asset_feed_tensors = []
  asset_paths = []
  asset_tensors = loader_impl.get_asset_tensors(
      self._export_dir, meta_graph_def)
  for tensor_name, value in asset_tensors.items():
    asset_feed_tensors.append(wrapped.graph.as_graph_element(tensor_name))
    asset_paths.append(tracking.Asset(value))

  init_fn = wrapped.prune(
      feeds=asset_feed_tensors,
      fetches=[init_anchor, wrapped.graph.as_graph_element(init_op)])
  initializer = _Initializer(init_fn, asset_paths)
  # pylint: disable=protected-access
  local_init_op, _ = initializer._initialize()
  # pylint: enable=protected-access

  with ops.init_scope():
    if not context.executing_eagerly():
      ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, local_init_op)
      local_variables = wrapped.graph.get_collection_ref(
          ops.GraphKeys.LOCAL_VARIABLES)
      for variable in local_variables:
        # pylint: disable=protected-access
        variable._initializer_op = local_init_op
        # pylint: enable=protected-access

  root.initializer = initializer
  root.asset_paths = asset_paths
  signature_functions = self._extract_signatures(wrapped, meta_graph_def)
  root.signatures = signature_serialization.create_signature_map(
      signature_functions)
  root.variables = list(wrapped.graph.variables)

  meta_info = meta_graph_def.meta_info_def
  root.tensorflow_version = meta_info.tensorflow_version
  root.tensorflow_git_version = meta_info.tensorflow_git_version
  root.graph = wrapped.graph
  root.prune = wrapped.prune
  return root