Example #1
0
def optimize_concrete_function(
    concrete_function: function.ConcreteFunction,
    strip_control_dependencies: bool) -> wrap_function.WrappedFunction:
  """Builds a pruned `WrappedFunction` equivalent to `concrete_function`.

  The concrete function's graph is wrapped with a variable holder that shares
  variables, then pruned down to the original feeds and fetches.

  Args:
    concrete_function: Function whose graph is wrapped and pruned.
    strip_control_dependencies: If True, control dependencies are stripped
      from the flattened outputs before pruning.

  Returns:
    A `WrappedFunction` with the same signature as `concrete_function`.
  """
  holder = wrap_function.VariableHolder(share_variables=True)
  wrapped = wrap_function.WrappedFunction(
      concrete_function.graph, variable_holder=holder)

  outputs = concrete_function.structured_outputs
  if strip_control_dependencies:
    flat = tf.nest.flatten(outputs, expand_composites=True)
    stripped, _ = tf2_utils.strip_and_get_tensors_and_control_dependencies(
        flat)
    # Repack the stripped tensors into the original output structure.
    outputs = tf.nest.pack_sequence_as(
        concrete_function.structured_outputs,
        stripped,
        expand_composites=True)

  pruned = wrapped.prune(
      feeds=concrete_function.inputs,
      fetches=outputs,
      input_signature=concrete_function.structured_input_signature)
  # TODO(b/163329414): Remove once `prune` retains shape information for all
  # components.
  for src, dst in zip(concrete_function.outputs, pruned.outputs):
    dst.set_shape(src.get_shape())
  return pruned
Example #2
0
    def __init__(self, config=None, params=None):
        """Stores config/params and sets up shared-variable tracking."""
        self._config = config
        self._params = params
        self._functions = {}

        self._variable_holder = wrap_function.VariableHolder(
            share_variables=True)
        # Keep a reference to the holder's name->variable mapping; it is a
        # trackable object, so referencing it here makes it checkpointable
        # from this object as well.
        self._variables_by_name = self._variable_holder.variables
Example #3
0
    def testShareVariablesDifferentGraphs(self):
        """Variables with the same name are shared across separately wrapped graphs.

        Each v1-style function below creates a variable named 'v'. Because all
        wrapped graphs use one shared `VariableHolder`, the first creation wins
        and later functions reuse that variable (and its initial value).
        """
        def add_v1(x):
            v = variables.Variable(3, name='v')
            return v + x

        def subtract_v1(x):
            v = variables.Variable(4, name='v')
            return v - x

        def different_variable_fn_v1(x):
            # Scoped name 'different_scope/v' does not collide with 'v', so
            # this function gets its own variable with initial value 5.
            with ops.name_scope('different_scope'):
                v = variables.Variable(5, name='v')
            return v * x

        def increment_variable_v1(x):
            v = variables.Variable(6, name='v')
            return v.assign_add(x)

        signature = [tensor_spec.TensorSpec([], dtypes.int32)]
        # One holder shared by every wrapped graph below.
        vh = wrap_function.VariableHolder(share_variables=True)
        new_graph = lambda: wrap_function.WrappedGraph(variable_holder=vh)

        add = new_graph().wrap_function(add_v1, signature)
        subtract = new_graph().wrap_function(subtract_v1, signature)
        different_variable_fn = new_graph().wrap_function(
            different_variable_fn_v1, signature)
        increment_variable = new_graph().wrap_function(increment_variable_v1,
                                                       signature)

        # v == 3, so add gives 3 + 7 = 10.
        self.assertEqual(10, add(constant_op.constant(7)).numpy())
        # different_scope/v == 5, so 5 * 7 = 35.
        self.assertEqual(
            35,
            different_variable_fn(constant_op.constant(7)).numpy())

        # Because the variable in add_v1 was created first, its starting value is 3
        # instead of the values defined in subtract_v1 or increment_variable_v1.
        self.assertEqual(-4, subtract(constant_op.constant(7)).numpy())
        self.assertEqual(10,
                         increment_variable(constant_op.constant(7)).numpy())

        # Check that the assign_add above updated the shared variable
        # (v is now 10): 10 + 7 = 17 and 10 - 7 = 3.
        self.assertEqual(17, add(constant_op.constant(7)).numpy())
        self.assertEqual(3, subtract(constant_op.constant(7)).numpy())

        # Sanity check - result from this function shouldn't change.
        self.assertEqual(
            35,
            different_variable_fn(constant_op.constant(7)).numpy())

        self.assertAllEqual({'v', 'different_scope/v'},
                            set(vh.variables.keys()))
Example #4
0
def _optimize_concrete_function(
    concrete_function: function.ConcreteFunction
) -> wrap_function.WrappedFunction:
    """Return an optimized function with the same signature as `concrete_function`.

    Wraps the concrete function's graph with a shared-variable holder, then
    prunes it down to the original feeds and fetches.
    """
    holder = wrap_function.VariableHolder(share_variables=True)
    wrapped = wrap_function.WrappedFunction(
        concrete_function.graph, variable_holder=holder)
    pruned = wrapped.prune(
        feeds=concrete_function.inputs,
        fetches=concrete_function.structured_outputs,
        input_signature=concrete_function.structured_input_signature)
    # TODO(b/178837353): Remove once `prune` retains shape information for all
    # components.
    for src, dst in zip(concrete_function.outputs, pruned.outputs):
        dst.set_shape(src.get_shape())
    return pruned