Ejemplo n.º 1
0
def optimize_concrete_function(
    concrete_function: function.ConcreteFunction,
    strip_control_dependencies: bool) -> wrap_function.WrappedFunction:
  """Prunes `concrete_function` into a new function with the same signature.

  The graph of `concrete_function` is wrapped in a `WrappedFunction` backed by
  a variable-sharing `VariableHolder`. That wrapper is then pruned so it maps
  `concrete_function.inputs` to the function's structured outputs.

  Args:
    concrete_function: Function whose graph is wrapped and pruned.
    strip_control_dependencies: If true, the flattened structured outputs are
      passed through `tf2_utils.strip_and_get_tensors_and_control_dependencies`.
      The stripped tensors, repacked into the original structure, are used as
      fetches. The control dependencies returned alongside them are discarded.

  Returns:
    The pruned `WrappedFunction`, built with
    `concrete_function.structured_input_signature` as its input signature.
  """
  source_graph = concrete_function.graph
  shared_holder = wrap_function.VariableHolder(share_variables=True)
  wrapper = wrap_function.WrappedFunction(
      source_graph, variable_holder=shared_holder)

  targets = concrete_function.structured_outputs
  if strip_control_dependencies:
    flat_targets = tf.nest.flatten(targets, expand_composites=True)
    stripped_targets, _ = (
        tf2_utils.strip_and_get_tensors_and_control_dependencies(flat_targets))
    # Restore the original nesting so `prune` receives the same structure.
    targets = tf.nest.pack_sequence_as(
        concrete_function.structured_outputs,
        stripped_targets,
        expand_composites=True)

  pruned_fn = wrapper.prune(
      feeds=concrete_function.inputs,
      fetches=targets,
      input_signature=concrete_function.structured_input_signature)

  # Copy static shapes from the original outputs onto the pruned outputs.
  # NOTE(review): `zip` stops at the shorter sequence; the two output lists
  # are presumably the same length — confirm.
  # TODO(b/163329414): Remove once `prune` retains shape information for all
  # components.
  for before, after in zip(concrete_function.outputs, pruned_fn.outputs):
    after.set_shape(before.get_shape())
  return pruned_fn
Ejemplo n.º 2
0
def _optimize_concrete_function(
    concrete_function: function.ConcreteFunction
) -> wrap_function.WrappedFunction:
    """Prunes `concrete_function` into a new function with the same signature.

    Args:
        concrete_function: Function whose graph is wrapped (using a
            variable-sharing `VariableHolder`) and then pruned.

    Returns:
        A `WrappedFunction` pruned to map `concrete_function.inputs` to
        `concrete_function.structured_outputs`, with
        `concrete_function.structured_input_signature` as its input
        signature.
    """
    source_graph = concrete_function.graph
    shared_holder = wrap_function.VariableHolder(share_variables=True)
    wrapper = wrap_function.WrappedFunction(
        source_graph, variable_holder=shared_holder)

    pruned_fn = wrapper.prune(
        feeds=concrete_function.inputs,
        fetches=concrete_function.structured_outputs,
        input_signature=concrete_function.structured_input_signature)

    # Re-apply the original outputs' static shapes to the pruned outputs.
    # TODO(b/178837353): Remove once `prune` retains shape information for all
    # components.
    output_pairs = zip(concrete_function.outputs, pruned_fn.outputs)
    for before, after in output_pairs:
        after.set_shape(before.get_shape())
    return pruned_fn