def optimize_concrete_function(
    concrete_function: function.ConcreteFunction,
    strip_control_dependencies: bool) -> wrap_function.WrappedFunction:
  """Returns optimized function with same signature as `concrete_function`.

  The function's graph is wrapped in a `WrappedFunction` that uses a
  variable-sharing `VariableHolder`. The wrapper is then pruned so it computes
  the (optionally control-dependency-stripped) structured outputs from the
  original inputs. The original structured input signature is kept.

  Args:
    concrete_function: The function whose graph is wrapped and pruned.
    strip_control_dependencies: When True, the flattened structured outputs
      are passed through
      `tf2_utils.strip_and_get_tensors_and_control_dependencies`. The returned
      tensors are repacked into the original output structure and used as the
      prune fetches instead of the original outputs.

  Returns:
    The `wrap_function.WrappedFunction` returned by `prune`. Each of its
    outputs has its static shape set from the corresponding entry of
    `concrete_function.outputs`.
  """
  source_graph = concrete_function.graph
  wrapper = wrap_function.WrappedFunction(
      source_graph,
      variable_holder=wrap_function.VariableHolder(share_variables=True))

  output_structure = concrete_function.structured_outputs
  if strip_control_dependencies:
    # Only the stripped tensors are needed; the second element of the helper's
    # result is discarded.
    stripped_tensors, _unused_second = (
        tf2_utils.strip_and_get_tensors_and_control_dependencies(
            tf.nest.flatten(output_structure, expand_composites=True)))
    fetches = tf.nest.pack_sequence_as(
        output_structure, stripped_tensors, expand_composites=True)
  else:
    fetches = output_structure

  pruned = wrapper.prune(
      feeds=concrete_function.inputs,
      fetches=fetches,
      input_signature=concrete_function.structured_input_signature)

  # TODO(b/163329414): Remove once `prune` retains shape information for all
  # components.
  for source_tensor, pruned_tensor in zip(concrete_function.outputs,
                                          pruned.outputs):
    pruned_tensor.set_shape(source_tensor.get_shape())
  return pruned
def __init__(self, config=None, params=None):
  """Stores the given config/params and creates a shared variable holder.

  Args:
    config: Stored unchanged as `self._config`; defaults to None.
    params: Stored unchanged as `self._params`; defaults to None.
  """
  # Tuple assignment sets `_config` before `_params`, the same order as
  # separate assignment statements.
  self._config, self._params = config, params
  # Empty mapping; what gets stored here is decided outside this method.
  self._functions = dict()
  self._variable_holder = wrap_function.VariableHolder(
      share_variables=True)
  # Keep a direct reference to the holder's variables mapping on this object;
  # the mapping is a trackable object, so tracking follows this attribute.
  self._variables_by_name = self._variable_holder.variables
def testShareVariablesDifferentGraphs(self):
  """Same-named variables are shared across graphs with one VariableHolder.

  Four v1-style functions, each creating a variable, are wrapped in separate
  `WrappedGraph`s that all use the same `VariableHolder(share_variables=True)`.
  Three of them create a variable named 'v'. These must all resolve to the
  first-created one, with initial value 3. The fourth creates 'v' under the
  'different_scope' name scope and must stay independent.
  """

  def add_v1(x):
    v = variables.Variable(3, name='v')
    return v + x

  def subtract_v1(x):
    v = variables.Variable(4, name='v')
    return v - x

  def different_variable_fn_v1(x):
    with ops.name_scope('different_scope'):
      v = variables.Variable(5, name='v')
    return v * x

  def increment_variable_v1(x):
    v = variables.Variable(6, name='v')
    return v.assign_add(x)

  signature = [tensor_spec.TensorSpec([], dtypes.int32)]
  vh = wrap_function.VariableHolder(share_variables=True)

  def new_graph():
    # Every graph uses the same holder, so variables are shared among them.
    return wrap_function.WrappedGraph(variable_holder=vh)

  add = new_graph().wrap_function(add_v1, signature)
  subtract = new_graph().wrap_function(subtract_v1, signature)
  different_variable_fn = new_graph().wrap_function(
      different_variable_fn_v1, signature)
  increment_variable = new_graph().wrap_function(increment_variable_v1,
                                                 signature)

  seven = constant_op.constant(7)

  self.assertEqual(10, add(seven).numpy())
  self.assertEqual(35, different_variable_fn(seven).numpy())
  # Because the variable in add_v1 was created first, its starting value is 3
  # instead of the values defined in subtract_v1 or increment_variable_v1.
  self.assertEqual(-4, subtract(seven).numpy())
  self.assertEqual(10, increment_variable(seven).numpy())

  # Check that variable updates (the shared 'v' is now 3 + 7 = 10).
  self.assertEqual(17, add(seven).numpy())
  self.assertEqual(3, subtract(seven).numpy())

  # Sanity check - result from this function shouldn't change.
  self.assertEqual(35, different_variable_fn(seven).numpy())

  # Compare as plain Python sets: assertEqual dispatches to assertSetEqual and
  # reports a readable set difference, unlike the array-oriented
  # assertAllEqual.
  self.assertEqual({'v', 'different_scope/v'}, set(vh.variables))
def _optimize_concrete_function(
    concrete_function: function.ConcreteFunction
) -> wrap_function.WrappedFunction:
  """Return an optimized function with the same signature as `concrete_function`.

  The graph of `concrete_function` is wrapped with a variable-sharing
  `VariableHolder`. It is then pruned to compute the function's structured
  outputs from its inputs, keeping the structured input signature.

  Args:
    concrete_function: The function to wrap and prune.

  Returns:
    The `wrap_function.WrappedFunction` returned by `prune`. Each of its
    outputs has its static shape copied from the matching entry of
    `concrete_function.outputs`.
  """
  graph = concrete_function.graph
  holder = wrap_function.VariableHolder(share_variables=True)
  wrapped = wrap_function.WrappedFunction(graph, variable_holder=holder)

  optimized = wrapped.prune(
      feeds=concrete_function.inputs,
      fetches=concrete_function.structured_outputs,
      input_signature=concrete_function.structured_input_signature)

  # TODO(b/178837353): Remove once `prune` retains shape information for all
  # components.
  shape_pairs = zip(concrete_function.outputs, optimized.outputs)
  for expected, actual in shape_pairs:
    actual.set_shape(expected.get_shape())
  return optimized