Example #1
    def testCaptureOrdering(self):
        v1 = resource_variable_ops.ResourceVariable(1.0)
        v2 = resource_variable_ops.ResourceVariable(2.0)
        v3 = resource_variable_ops.ResourceVariable(3.0)

        @def_function.function
        def fn():
            return v1 + v2 + v3

        concrete_fn = fn.get_concrete_function()
        original_captures = concrete_fn.graph.captures
        outputs = concrete_fn.graph.outputs

        for _ in range(100):
            g = func_graph.FuncGraph('lifted')

            lift_to_graph.lift_to_graph(outputs,
                                        g,
                                        add_sources=True,
                                        handle_captures=True)
            lifted_captures = g.captures
            self.assertLen(lifted_captures, 3)
            for original_capture, lifted_capture in zip(
                    original_captures.values(), lifted_captures.values()):
                self.assertEqual(original_capture.name, lifted_capture.name)
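
The same lifting pattern can be exercised outside a test harness. Below is a minimal, hedged sketch (module paths are those used by the TF version these snippets come from; `op_map` is simply the dictionary returned by `lift_to_graph`):

from tensorflow.python.eager import def_function
from tensorflow.python.eager import lift_to_graph
from tensorflow.python.framework import func_graph
from tensorflow.python.ops import resource_variable_ops

v = resource_variable_ops.ResourceVariable(4.0)

@def_function.function
def triple():
    return 3.0 * v

concrete = triple.get_concrete_function()
target = func_graph.FuncGraph('lifted_copy')
# Copies every op feeding the outputs into `target`; the captured variable
# handle is re-created as a capture of `target` because handle_captures=True.
op_map = lift_to_graph.lift_to_graph(concrete.graph.outputs, target,
                                     add_sources=True, handle_captures=True)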
Example #2
    def prune(self, feeds, fetches, name=None, input_signature=None):
        # TODO(b/129646028): Add support for CompositeTensors.
        name = name or "pruned"
        feeds = nest.map_structure(self.graph.as_graph_element, feeds)
        fetches = nest.map_structure(self.graph.as_graph_element, fetches)
        flat_feeds, flat_fetches = nest.flatten(feeds), nest.flatten(fetches)
        for f in flat_feeds:
            if not isinstance(f, ops.Tensor):
                raise ValueError("Feeds must be tensors.")

        # Ignoring all feeds that are captures allows prune to be called
        # using wrapped_func.inputs even when it uses variables
        internal_captures = self.graph.internal_captures
        flat_feeds = [f for f in flat_feeds if f not in internal_captures]

        operation_fetches = []
        for f in flat_fetches:
            if isinstance(f, ops.Operation):
                operation_fetches.append(f)
            elif not isinstance(f, ops.Tensor):
                raise ValueError("Fetches must be tensors or operations.")
        for f in flat_feeds + flat_fetches:
            if f.graph is not self._func_graph:
                raise ValueError(
                    "Can only prune function whose feeds and fetches "
                    "are from this graph (%s). Tensor %s from graph %s" %
                    (self._func_graph, f, f.graph))
        with self._func_graph.as_default():
            pruned_graph = func_graph.FuncGraph(name)
        lift_map = lift_to_graph.lift_to_graph(flat_fetches,
                                               pruned_graph,
                                               sources=flat_feeds +
                                               internal_captures)
        pruned_graph.outputs.extend(lift_map[x] for x in flat_fetches
                                    if isinstance(x, ops.Tensor))
        pruned_graph.control_outputs.extend(
            [lift_map[operation] for operation in operation_fetches])
        for external_capture, internal_capture in self.graph.captures.items():
            pruned_graph.captures[external_capture] = lift_map[
                internal_capture]
        pruned_graph.inputs.extend(lift_map[x] for x in flat_feeds)
        pruned_graph.inputs.extend(pruned_graph.captures.values())

        pruned_graph.variables = self.graph.variables

        def _structured_output_mapping(fetched):
            lifted = lift_map[fetched]
            if isinstance(lifted, ops.Operation):
                return None
            return lifted

        pruned_graph.structured_outputs = nest.map_structure(
            _structured_output_mapping, fetches)
        pruned_graph.structured_input_signature = input_signature
        pruned_fn = WrappedFunction(pruned_graph,
                                    variable_holder=self._variable_holder)
        pruned_fn._num_positional_args = len(flat_feeds)  # pylint: disable=protected-access
        # TODO(kathywu): Enable keyword arguments if an input signature is specified
        pruned_fn._arg_keywords = [tensor.op.name for tensor in flat_feeds]  # pylint: disable=protected-access
        return pruned_fn
Example #3
 def prune(self, feeds, fetches):
     flat_feeds, flat_fetches = nest.flatten(feeds), nest.flatten(fetches)
     for f in flat_feeds + flat_fetches:
         if not isinstance(f, ops.Tensor):
             raise ValueError("Feeds and fetches must be tensors.")
         if f.graph is not self._func_graph:
             raise ValueError(
                 "Can only prune function whose feeds and fetches "
                 "are from this graph (%s). Tensor %s from graph %s" %
                 (self._func_graph, f, f.graph))
     with self._func_graph.as_default():
         pruned_graph = func_graph.FuncGraph("pruned")
         sink_tensor = array_ops.identity_n(flat_fetches)[0]
     lift_map = lift_to_graph.lift_to_graph(sink_tensor,
                                            pruned_graph,
                                            sources=flat_feeds +
                                            self.graph.internal_captures)
     pruned_graph.outputs.extend(lift_map[x] for x in flat_fetches)
     for external_capture, internal_capture in self.graph.captures.items():
         pruned_graph.captures[external_capture] = lift_map[
             internal_capture]
     pruned_graph.inputs.extend(lift_map[x] for x in flat_feeds)
     pruned_graph.inputs.extend(pruned_graph.captures.values())
     pruned_graph.structured_outputs = nest.map_structure(
         lambda node: lift_map[node], fetches)
     pruned_fn = WrappedFunction(pruned_graph,
                                 variable_holder=self._variable_holder)
     pruned_fn._num_positional_args = len(flat_feeds)  # pylint: disable=protected-access
     pruned_fn._arg_keywords = []  # pylint: disable=protected-access
     return pruned_fn
Example #4
def wrap_cached_variables(concrete_function):
    """Wraps the concrete function if it uses cached read tensors.

  This function creates a new concrete function that captures variables
  instead of the cached read tensors.

  Args:
    concrete_function: A ConcreteFunction that may capture cached read
      tensors.

  Returns:
    A concrete function that wraps the original concrete function, which
    captures variables instead. If the original function did not capture any
    cached values, then the function is not wrapped and the original object is
    returned.
  """
    outer_graph = func_graph_module.FuncGraph("{}_no_cache".format(
        concrete_function.graph.name))
    captures = concrete_function.graph._captures  # pylint: disable=protected-access
    mapped_captures = None
    remapped_captures = {}

    # Update the external captures to use read tensors generated in the outer
    # graph.
    with outer_graph.as_default():
        for capture, placeholder in concrete_function.graph.captures:
            cached_variable = getattr(capture, "_cached_variable", None)
            if cached_variable is None:
                continue
            cached_variable = cached_variable()
            new_cached_value = cached_variable.read_value()
            remapped_captures[id(capture)] = captures[id(capture)]
            captures[id(capture)] = (new_cached_value, placeholder)
            mapped_captures = True

    if not mapped_captures:
        return concrete_function

    inner_concrete = defun.ConcreteFunction(concrete_function.graph)

    def wrap_function(*args):
        return inner_concrete._call_flat(args, inner_concrete.captured_inputs)  # pylint:disable=protected-access

    args = nest.flatten(concrete_function.structured_input_signature,
                        expand_composites=True)
    func_graph_module.func_graph_from_py_func(None,
                                              wrap_function,
                                              args=tuple(args),
                                              kwargs={},
                                              func_graph=outer_graph)
    fn = defun.ConcreteFunction(outer_graph,
                                function_spec=concrete_function._function_spec)  # pylint: disable=protected-access
    fn._arg_keywords = concrete_function._arg_keywords  # pylint: disable=protected-access
    fn._num_positional_args = concrete_function._num_positional_args  # pylint: disable=protected-access

    # Return the captures to their original values
    for key, capture in remapped_captures.items():
        captures[key] = capture
    return fn
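
A hedged usage note for the helper above: when the concrete function's graph captures no cached read tensors, it is returned unchanged. A small sketch (assuming `wrap_cached_variables` is imported from the module shown above):

from tensorflow.python.eager import def_function
from tensorflow.python.ops import variables

w = variables.Variable(1.0)

@def_function.function
def fn():
    return w + 1.0

concrete = fn.get_concrete_function()
# `fn` captures the variable's handle rather than a cached read tensor carrying
# a `_cached_variable` attribute, so no remapping happens.
assert wrap_cached_variables(concrete) is concrete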
Example #5
 def test_to_placeholder(self, shape, batch_size, ragged_rank):
     inp = layers.Input(shape=shape, batch_size=batch_size, ragged=True)
     self.assertEqual(inp.ragged_rank, ragged_rank)
     self.assertAllEqual(inp.shape, [batch_size] + list(shape))
     with func_graph.FuncGraph('test').as_default():
         placeholder = inp._to_placeholder()
         self.assertEqual(placeholder.ragged_rank, ragged_rank)
         self.assertAllEqual(placeholder.shape, [batch_size] + list(shape))
Example #6
def _try_handling_undefineds(body, get_state, set_state, init_vars, nulls,
                             symbol_names):
    """Makes a best-effort attempt to substitute undefineds with placeholders.

  Note: this substitution requires two things to happen:
   1. the types of loop variables could be inferred (usually by staging one
       iteration)
   2. these types could be replaced by placeholders (e.g. zero values for
       tensors).

  Args:
    body: a function representing the loop body. See while_stmt.
    get_state: state getter for the loop statement. See while_stmt.
    set_state: state setter for the loop statement. See while_stmt.
    init_vars: loop variables before entering the loop. See while_stmt.
    nulls: list of boolean flags indicating whether the corresponding loop
        var is None or undefined.
    symbol_names: list of loop variable names. See while_stmt.
  Returns:
    A tuple (success, new_init_vars). success is a boolean flag indicating
    whether types could be successfully inferred (step 1 above). new_init_vars
    contains the loop vars, with None or undefined values replaced by
    placeholders, where possible (step 2 above).
  """
    state_modified = False

    try:
        # Stage an iteration of the loop body in a temporary graph.
        with func_graph.FuncGraph('tmp').as_default():
            # This call to set_state helps report nicer error messages when symbols
            # are inconsistently used.
            set_state(init_vars)
            state_modified = True

            body()
            first_iter_vars = get_state()
    except (UnboundLocalError, TypeError, ValueError, KeyError):
        # Fall back to the old functionality. It will likely result in an input
        # validation failure.
        first_iter_vars = None
    finally:
        if state_modified:
            set_state(init_vars)

    if first_iter_vars is not None:
        # Note: the actual placeholder value doesn't matter, because as the staging
        # proved, it will be replaced by an actual value before being read.
        init_vars = tuple(
            (_placeholder_value(iv, v) if n else v)
            for v, n, iv in zip(init_vars, nulls, first_iter_vars))
        success = True
    else:
        success = False

    # This check runs regardless, in case we captured non-Tensor inputs.
    _verify_loop_init_vars(init_vars, symbol_names, first_iter_vars)

    return success, init_vars
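
The core of the helper above is using a throwaway `FuncGraph` as a sandbox: one staged call of `body` reveals the loop variables' types, while any ops that call creates stay confined to the temporary graph. A minimal, hedged sketch of just that trick (the helper name is illustrative):

def stage_one_iteration(body, get_state, set_state, init_vars):
    """Runs `body` once inside a scratch FuncGraph and returns the traced state."""
    with func_graph.FuncGraph('tmp').as_default():
        set_state(init_vars)
        try:
            body()
            return get_state()
        finally:
            # Hand the caller's original values back, as the code above does.
            set_state(init_vars)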
Example #7
    def prune(self, feeds, fetches):
        flat_feeds, flat_fetches = nest.flatten(feeds), nest.flatten(fetches)
        for f in flat_feeds:
            if not isinstance(f, ops.Tensor):
                raise ValueError("Feeds must be tensors.")
        tensor_fetches = []
        operation_fetches = []
        for f in flat_fetches:
            if isinstance(f, ops.Tensor):
                tensor_fetches.append(f)
            elif isinstance(f, ops.Operation):
                operation_fetches.append(f)
            else:
                raise ValueError("Fetches must be tensors or operations.")
        for f in flat_feeds + flat_fetches:
            if f.graph is not self._func_graph:
                raise ValueError(
                    "Can only prune function whose feeds and fetches "
                    "are from this graph (%s). Tensor %s from graph %s" %
                    (self._func_graph, f, f.graph))
        with self._func_graph.as_default():
            pruned_graph = func_graph.FuncGraph("pruned")
            with ops.control_dependencies(operation_fetches):
                if tensor_fetches:
                    identity_fetches = array_ops.identity_n(tensor_fetches)
                    sink_tensor = identity_fetches[0]
                else:
                    identity_fetches = []
                    sink_tensor = control_flow_ops.no_op()
        lift_map = lift_to_graph.lift_to_graph(sink_tensor,
                                               pruned_graph,
                                               sources=flat_feeds +
                                               self.graph.internal_captures)
        for original_fetch, identity_fetch in zip(tensor_fetches,
                                                  identity_fetches):
            lift_map[original_fetch] = lift_map[identity_fetch]
        pruned_graph.outputs.extend(lift_map[x] for x in flat_fetches
                                    if isinstance(x, ops.Tensor))
        for external_capture, internal_capture in self.graph.captures.items():
            pruned_graph.captures[external_capture] = lift_map[
                internal_capture]
        pruned_graph.inputs.extend(lift_map[x] for x in flat_feeds)
        pruned_graph.inputs.extend(pruned_graph.captures.values())

        def _structured_output_mapping(fetched):
            lifted = lift_map[fetched]
            if isinstance(lifted, ops.Operation):
                return None
            return lifted

        pruned_graph.structured_outputs = nest.map_structure(
            _structured_output_mapping, fetches)
        pruned_fn = WrappedFunction(pruned_graph,
                                    variable_holder=self._variable_holder)
        pruned_fn._num_positional_args = len(flat_feeds)  # pylint: disable=protected-access
        pruned_fn._arg_keywords = []  # pylint: disable=protected-access
        return pruned_fn
Example #8
def _construct_concrete_function(input_func, graph_def, converted_handles):
    """Creates a ConcreteFunction from the input function and frozen graph.

  Args:
    input_func: ConcreteFunction.
    graph_def: TensorFlow GraphDef.
    converted_handles: a set of handles of the variables in input_func that were
      converted to constants in `graph_def`.

  Returns:
    ConcreteFunction containing the graph_def.
  """
    captured_inputs = input_func.inputs[-len(input_func.captured_inputs):]
    captured_input_indices = [
        input_func.captured_inputs.index(handle)
        for handle in converted_handles
    ]
    converted_inputs = set(
        [captured_inputs[index] for index in captured_input_indices])
    not_converted_inputs = set(input_func.inputs).difference(converted_inputs)

    output_graph = func_graph.FuncGraph(input_func.graph.name)
    with output_graph.as_default():
        importer.import_graph_def(graph_def, name="")
        output_graph.inputs = _get_tensors_from_graph(output_graph,
                                                      not_converted_inputs)
        output_graph.outputs = _get_tensors_from_graph(output_graph,
                                                       input_func.outputs)

    output_graph.structured_outputs = input_func.graph.structured_outputs
    output_graph.structured_input_signature = (
        input_func.graph.structured_input_signature)

    # pylint: disable=protected-access
    # Create the ConcreteFunction and add it to the global context.
    output_func = function.ConcreteFunction(output_graph,
                                            attrs=input_func._attrs,
                                            signature=input_func._signature)
    output_func.add_to_graph()

    # Inject the captured inputs into the ConcreteFunction.
    output_func._captured_inputs = [
        handle for handle in input_func.captured_inputs
        if handle not in converted_handles
    ]
    output_func.graph.variables = [
        var for var in input_func.graph.variables
        if var.handle not in converted_handles
    ]
    output_func._arg_keywords = input_func._arg_keywords
    output_func._num_positional_args = input_func._num_positional_args
    # pylint: enable=protected-access

    # Register the gradients in the current root context.
    with ops.init_scope():
        output_func._register_gradient()  # pylint: disable=protected-access
    return output_func
Example #9
def _try_convert_to_tensor_spec(arg, dtype_hint):
  """Returns None or TensorSpec obtained if `arg` is converted to tensor."""
  try:
    # Note: try conversion in a FuncGraph to avoid polluting current context.
    with func_graph_lib.FuncGraph(name="guess_conversion").as_default():
      result = ops.convert_to_tensor(arg, dtype_hint=dtype_hint)
      return tensor_spec.TensorSpec(shape=result.shape, dtype=result.dtype)
  except (TypeError, ValueError):
    return None
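
Hedged usage examples for the helper above: a value that `convert_to_tensor` accepts yields a `TensorSpec`, while an inconvertible value yields `None` (the `dtypes` import is assumed, as elsewhere in these snippets):

spec = _try_convert_to_tensor_spec([1.0, 2.0], dtype_hint=dtypes.float32)
# spec describes a float32 tensor of shape (2,).

no_spec = _try_convert_to_tensor_spec(object(), dtype_hint=None)
# no_spec is None: the conversion raises TypeError, which is swallowed above.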
Example #10
 def add_op_to_graph(num_ops):
     with func_graph.FuncGraph("resource").as_default():
         handle = resource_variable_ops.var_handle_op(
             dtype=dtypes.int32, shape=[])
         resource_variable_ops.assign_variable_op(
             handle, constant_op.constant(1, dtype=dtypes.int32))
         for _ in range(num_ops):
             gen_resource_variable_ops.read_variable_op(
                 handle, dtype=dtypes.int32)
Example #11
def while_stmt(
    test,
    body,
    get_state,
    set_state,
    init_vars,
    basic_symbol_names=None,
    composite_symbol_names=None,
    opts=None,
):
    """Functional form of a while statement.

  The loop operates on a so-called state, which includes all symbols that are
  variant across loop iterations. In what follows we refer to state as either
  a tuple of entities that represent an actual state, or a list of arguments
  of the corresponding types.

  Args:
    test: Callable with the state as arguments, and boolean return type. The
      loop condition.
    body: Callable with the state as arguments, and state as return type. The
      actual loop body.
    get_state: Additional callable which can capture additional state (such as
      the values of composite symbols). This is only useful when staging the
      loop.
    set_state: Additional callable which saves values captured by get_state back
      into the Python environment. This is only useful when staging the loop.
    init_vars: Tuple containing the initial state.
    basic_symbol_names: Tuple containing basic loop var names.
    composite_symbol_names: Tuple containing composite loop var names.
    opts: Optional dict of extra loop parameters.

  Returns:
    Tuple containing the final state.
  """

    # Evaluate the initial test once in order to do the dispatch. The evaluation
    # is isolated to minimize unwanted side effects.
    # TODO(mdan): Do a full iteration - some state types might lower to Tensor.
    with func_graph.FuncGraph('tmp').as_default():
        init_test = test(*init_vars)

    # TensorFlow: Multiple evaluations are acceptable in this case, so we're fine
    # with the re-evaluation of `test` that `_tf_while_stmt` will make.
    if tensors.is_dense_tensor(init_test):
        return _tf_while_stmt(test, body, get_state, set_state, init_vars,
                              basic_symbol_names, composite_symbol_names, opts)

    # Normal Python: We already consumed one evaluation of `test`; consistently,
    # unroll one iteration before dispatching to a normal loop.
    # TODO(mdan): Push the "init_test" value via opts into _py_while_stmt?
    if not init_test:
        return init_vars
    init_vars = body(*init_vars)

    return _py_while_stmt(test, body, get_state, set_state, init_vars, opts)
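
A hedged, pure-Python usage sketch of the functional form above: with plain ints the dispatch test evaluates to a Python bool, so after the unrolled first iteration the loop runs as an ordinary Python while loop (the symbol-name tuples are illustrative):

def _test(i, total):
    return i < 3

def _body(i, total):
    return i + 1, total + i

final_i, final_total = while_stmt(
    _test, _body,
    get_state=lambda: (),
    set_state=lambda _: None,
    init_vars=(0, 0),
    basic_symbol_names=('i', 'total'),
    composite_symbol_names=(),
    opts={})
# final_i == 3 and final_total == 3 (0 + 1 + 2)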
Example #12
 def _basic_cond(self, true_value, false_value):
     # Eager cond had different semantics, we don't test those here.
     with func_graph.FuncGraph('tmp').as_default():
         return control_flow.if_stmt(cond=constant_op.constant(True),
                                     body=true_value,
                                     orelse=false_value,
                                     get_state=lambda: (),
                                     set_state=lambda _: None,
                                     basic_symbol_names=('s', ),
                                     composite_symbol_names=())
Example #13
    def testVariableInFuncGraph(self, distribution):
        def model_fn():
            v = variable_scope.variable(2.0, name="bar")
            ds_context.get_replica_context().merge_call(lambda _: _)
            return v

        with func_graph.FuncGraph("fg").as_default(), distribution.scope():
            v1 = variable_scope.variable(1.0, name="foo")
            v2 = distribution.extended.call_for_each_replica(model_fn)

        self._test_mv_properties(v1, "foo:0", distribution)
        self._test_mv_properties(v2, "bar:0", distribution)
Example #14
  def __init__(self, variable_holder=None, **kwargs):
    self._variable_holder = (
        variable_holder or VariableHolder(share_variables=True))

    name = kwargs.pop("name", "wrapped_function_graph")
    # Always start with empty collections, unless otherwise specified. Setting
    # `collections=None` will copy the collections from the outer graph.
    collections = kwargs.pop("collections", {})
    self.graph = func_graph.FuncGraph(name, collections=collections, **kwargs)

    self._wrapped_function = WrappedFunction(self.graph, self._variable_holder)
    self._functions = {}
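
This initializer builds the same VariableHolder + FuncGraph + WrappedFunction combination that the public `tf.compat.v1.wrap_function` entry point produces. A short usage sketch of that public API (values are illustrative):

import tensorflow as tf

def build(x):
    v = tf.Variable(2.0)
    return v * x

wrapped = tf.compat.v1.wrap_function(
    build, signature=[tf.TensorSpec(shape=[], dtype=tf.float32)])
print(wrapped(tf.constant(3.0)).numpy())  # 6.0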
Example #15
    def _convert_saved_model_v2(self):
        """Convert the input SavedModel in 2.0 format."""
        self._saved_model = load.load(self._input_saved_model_dir,
                                      self._input_saved_model_tags)
        func = self._saved_model.signatures[
            self._input_saved_model_signature_key]
        frozen_func = convert_to_constants.convert_variables_to_constants_v2(
            func)
        self._grappler_meta_graph_def = saver.export_meta_graph(
            graph_def=frozen_func.graph.as_graph_def(),
            graph=frozen_func.graph)

        # Add a collection 'train_op' so that Grappler knows the outputs.
        fetch_collection = meta_graph_pb2.CollectionDef()
        for array in func.inputs + func.outputs:
            fetch_collection.node_list.value.append(array.name)
        self._grappler_meta_graph_def.collection_def["train_op"].CopyFrom(
            fetch_collection)

        # Run TRT optimizer in Grappler to convert the graph.
        self._run_conversion()

        def _get_tensor(graph, tensors):
            new_tensors = []
            for tensor in tensors:
                new_tensor = graph.get_tensor_by_name(tensor.name)
                new_tensor.set_shape(tensor.shape)
                new_tensors.append(new_tensor)
            return new_tensors

        # TODO(laigd): do we need to use different name e.g. "trt_func_graph"?
        converted_graph = func_graph.FuncGraph(func.graph.name)
        with converted_graph.as_default():
            importer.import_graph_def(self._converted_graph_def, name="")

        converted_graph.inputs = _get_tensor(converted_graph,
                                             func.graph.inputs)
        converted_graph.outputs = _get_tensor(converted_graph,
                                              func.graph.outputs)
        converted_graph.structured_outputs = func.graph.structured_outputs
        converted_graph.structured_input_signature = (
            func.graph.structured_input_signature)

        # pylint: disable=protected-access
        # TODO(laigd): should we set up the signature as well?
        self._converted_func = function.ConcreteFunction(converted_graph,
                                                         attrs=None,
                                                         signature=None)
        self._converted_func.add_to_graph()
        self._converted_func._arg_keywords = func._arg_keywords
        self._converted_func._num_positional_args = func._num_positional_args
        self._converted_func._captured_inputs = func._captured_inputs
        self._converted_func.graph.variables = func.graph.variables
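
For comparison, a hedged sketch of the 2.x user-facing flow that wraps this kind of conversion (it requires a TF build with TensorRT support; the paths are placeholders):

from tensorflow.python.compiler.tensorrt import trt_convert

converter = trt_convert.TrtGraphConverterV2(
    input_saved_model_dir='/tmp/my_saved_model')
converter.convert()
converter.save('/tmp/my_trt_saved_model')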
Example #16
    def test_unique_graph_func_graph(self):
        """Test for get_unique_graph with FuncGraph."""
        outer = ops_lib.Graph()
        with outer.as_default():
            k1 = constant_op.constant(1)
            inner = func_graph.FuncGraph("inner")
            inner._graph_key = outer._graph_key
            with inner.as_default():
                k2 = constant_op.constant(2)

        unique_graph = op_selector.get_unique_graph([k1, k2])
        self.assertEqual(unique_graph._graph_key, inner._graph_key)
Example #17
    def _initialize(self, args, kwds, add_initializers_to=None):
        """Initializes, on the first call.

    Creates two `Function`s, one that will allow creation of variables
    and one that won't.

    Additionally runs a trace for the `Function` that allows creation
    of variables.

    Args:
      args: Arguments to the underlying python callable.
      kwds: Keyword arguments to the python callable.
      add_initializers_to: Where to collect variable initializers, if not None.
    """

        created_variables = []
        lifted_initializer_graph = func_graph_module.FuncGraph("initializer")
        lifted_all_initializers = [True]
        lifted_placeholders = []

        def variable_capturing_scope(unused_next_creator, **kwds):
            """Creates UnliftedInitializerVariables and saves references to them."""
            v = UnliftedInitializerVariable(
                add_initializers_to=add_initializers_to,
                lifted_initializer_graph=lifted_initializer_graph,
                lifted_all_initializers=lifted_all_initializers,
                lifted_placeholders=lifted_placeholders,
                **kwds)
            created_variables.append(weakref.ref(v))
            return v

        self._created_variables = created_variables
        self._stateful_fn = self._defun_with_scope(variable_capturing_scope)
        self._stateful_fn._name = self._name  # pylint: disable=protected-access
        # Force the definition of the function for these arguments
        self._lifted_initializer_graph = lifted_initializer_graph
        self._graph_deleter = FunctionDeleter(self._lifted_initializer_graph)
        self._lifted_placeholders = lifted_placeholders
        self._concrete_stateful_fn = (
            self._stateful_fn.
            _get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
                *args, **kwds))
        self._lifted_all_initializers = lifted_all_initializers[0]

        def invalid_creator_scope(*unused_args, **unused_kwds):
            """Disables variable creation."""
            raise ValueError("tf.function-decorated function tried to create "
                             "variables on non-first call.")

        self._stateless_fn = self._defun_with_scope(invalid_creator_scope)
        self._stateless_fn._name = self._name  # pylint: disable=protected-access
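
From the user's side, the two `Function`s created above (one that may create variables and one that may not) surface as the rule that a `tf.function` can only create variables during its first trace. A small sketch mirroring the pattern from the `tf.function` documentation:

import tensorflow as tf

class Holder(object):
    pass

obj = Holder()
obj.v = None

@tf.function
def accumulate(x):
    if obj.v is None:
        obj.v = tf.Variable(1.0)   # created during the first trace only
    return obj.v.assign_add(x)

print(accumulate(tf.constant(1.0)).numpy())  # 2.0
print(accumulate(tf.constant(2.0)).numpy())  # 4.0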
Example #18
  def testExecutingEagerlyOutsideFunction(self, distribution):
    """Verify we preserve the value of executing_eagerly_outside_functions()."""
    def model_fn():
      return ops.executing_eagerly_outside_functions()

    originally = ops.executing_eagerly_outside_functions()
    with distribution.scope():
      in_scope = ops.executing_eagerly_outside_functions()
      in_model_fn = distribution.extended.call_for_each_replica(model_fn)
      unwrapped = distribution.experimental_local_results(in_model_fn)
      self.assertEqual(in_scope, unwrapped[0])
      self.assertEqual(in_scope, originally)

    # Verify this all again, but this time in a FuncGraph.
    with func_graph.FuncGraph("fg").as_default(), distribution.scope():
      in_scope = ops.executing_eagerly_outside_functions()
      in_model_fn = distribution.extended.call_for_each_replica(model_fn)
      unwrapped = distribution.experimental_local_results(in_model_fn)
      self.assertEqual(in_scope, unwrapped[0])
      self.assertEqual(in_scope, originally)
Example #19
def _construct_concrete_function(input_func, graph_def):
    """Creates a ConcreteFunction from the input function and frozen graph.

  Args:
    input_func: ConcreteFunction.
    graph_def: TensorFlow GraphDef.

  Returns:
    ConcreteFunction containing the graph_def.
  """
    output_graph = func_graph.FuncGraph(input_func.graph.name)
    with output_graph.as_default():
        importer.import_graph_def(graph_def, name="")
        output_graph.inputs = _get_tensors_from_graph(output_graph,
                                                      input_func.inputs)
        output_graph.outputs = _get_tensors_from_graph(output_graph,
                                                       input_func.outputs)

    output_graph.structured_outputs = input_func.graph.structured_outputs
    output_graph.structured_input_signature = (
        input_func.graph.structured_input_signature)

    # pylint: disable=protected-access
    # Create the ConcreteFunction and add it to the global context.
    output_func = function.ConcreteFunction(output_graph,
                                            attrs=input_func._attrs,
                                            signature=input_func._signature)
    output_func.add_to_graph()

    # Inject the captured inputs into the ConcreteFunction.
    output_func._captured_inputs = input_func.captured_inputs
    output_func.graph.variables = input_func.graph.variables
    output_func._arg_keywords = input_func._arg_keywords
    output_func._num_positional_args = input_func._num_positional_args
    # pylint: enable=protected-access

    # Register the gradients in the current root context.
    with ops.init_scope():
        output_func._register_gradient()  # pylint: disable=protected-access
    return output_func
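
A helper like this is what `convert_variables_to_constants_v2` relies on: after freezing, the captured variable reads become constants baked into the imported `GraphDef`. A hedged end-to-end sketch (module paths assumed; the weight value is illustrative):

from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import convert_to_constants
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import variables

weight = variables.Variable(3.0)

@def_function.function
def scale(x):
    return weight * x

concrete = scale.get_concrete_function(
    tensor_spec.TensorSpec(shape=[], dtype=dtypes.float32))
frozen = convert_to_constants.convert_variables_to_constants_v2(concrete)
# The variable read is now a constant in `frozen`'s graph; no handle is captured.
print(frozen(constant_op.constant(2.0)).numpy())  # 6.0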
Example #20
    def testClassAttrsRemoved(self):
        """Tests that _class attrs (from colocate_with()) are removed."""
        @def_function.function
        def fn():
            two = constant_op.constant(2.0, name='two')
            ten = constant_op.constant(10.0, name='ten')
            twenty = math_ops.multiply(two, ten, name='twenty')
            three = constant_op.constant(3.0, name='three')
            with framework_ops.colocate_with(twenty):
                thirty = math_ops.multiply(three, ten, name='thirty')
            return ten, twenty, thirty

        concrete_fn = fn.get_concrete_function()
        self.assertItemsEqual(  # Before lifting, 'fn' has colocation attrs.
            concrete_fn.graph.get_operation_by_name(
                'thirty').colocation_groups(),
            [compat.as_bytes('loc:@twenty')])
        thirty_out = concrete_fn.graph.outputs[2]

        g = func_graph.FuncGraph('lifted')
        lift_to_graph.lift_to_graph([thirty_out], g)

        # After lifting, colocation attrs are gone.
        ops = g.get_operations()
        self.assertItemsEqual(
            [op.name for op in ops],
            [
                'three',
                'ten',
                'thirty',  # Lifted from `fn` body.
                thirty_out.op.name
            ])  # Wrapper for output.
        for op in ops:
            with self.assertRaises(ValueError):
                class_attr = op.get_attr('_class')  # Expected not to exist.
                print('Unexpected class_attr', class_attr, 'on', op.name)
            self.assertItemsEqual(
                op.colocation_groups(),  # Expect default self-ref.
                [compat.as_bytes('loc:@%s' % op.name)])
Example #21
    def _basic_cond(self, body_fn, else_fn):
        def body():
            nonlocal x
            x = body_fn()

        def orelse():
            nonlocal x
            x = else_fn()

        def set_state(cond_vars):
            nonlocal x
            x, = cond_vars

        x = 0
        # Eager cond had different semantics, we don't test those here.
        with func_graph.FuncGraph('tmp').as_default():
            control_flow.if_stmt(cond=constant_op.constant(True),
                                 body=body,
                                 orelse=orelse,
                                 get_state=lambda: (x, ),
                                 set_state=set_state,
                                 symbol_names=('x', ),
                                 nouts=1)
        return x
Example #22
    def prune(self, feeds, fetches, name=None, input_signature=None):
        """Extract a subgraph of this function's underlying graph.

    Wraps the subgraph in a new `WrappedFunction` object.

    Args:
      feeds: Input tensors to the subgraph to extract, as `Tensor` objects.
      fetches: Possibly-nested Python data structure containing information
        about outputs of the target subgraph. Each entry can either be a
        `Tensor` object (for data outputs), an `Operation` object (for control
        outputs), or a `TensorInfo` proto. Any additional shape/dtype
        information provided in a `TensorInfo` and not present in the original
        graph will be added to the returned subgraph.
      name: (optional) Name to give to the underlying `FuncGraph` of the
        returned object. If no name is provided, the graph's name will be
        `"pruned"`.
      input_signature: (optional) possibly-nested Python data structure
        containing `TensorSpec` objects, with which to populate the returned
        function's `FuncGraph`'s `structured_input_signature` field.

    Returns:
      A new `WrappedFunction` object containing a copy of the portion of this
        object's graph that goes from `feeds` to `fetches`.
    """
        # TODO(b/129646028): Add support for CompositeTensors.
        name = name or "pruned"
        flat_feeds = nest.flatten(feeds, expand_composites=True)
        flat_feeds = [self.graph.as_graph_element(t) for t in flat_feeds]
        for f in flat_feeds:
            if not isinstance(f, ops.Tensor):
                raise ValueError("Feeds must be tensors.")

        # Ignoring all feeds that are captures allows prune to be called
        # using wrapped_func.inputs even when it uses variables
        internal_captures = object_identity.ObjectIdentitySet(
            self.graph.internal_captures)
        flat_feeds = [f for f in flat_feeds if f not in internal_captures]

        operation_fetches = []
        tensor_fetches = []
        tensor_infos = []

        def _fetch_preprocesing_callback(fetch):
            """Extract out lists of ops, tensors, and tensor type info.

      Turns TensorInfos into Tensors in the original `fetches` structure.
      Also extracts ops from `fetches`.

      Args:
        fetch: The fetch to preprocess: Tensor, TensorInfo, or Operation, or
          string identifying a Tensor or Operation.

      Returns:
        `fetch` converted to a Tensor.
      """
            if isinstance(fetch, ops.Operation):
                operation_fetches.append(fetch)
                return fetch
            elif isinstance(fetch, meta_graph_pb2.TensorInfo):
                tensor_infos.append(fetch)
                decoded = _get_element_from_tensor_info(
                    fetch, self._func_graph)
                if (tensor_util.is_tensor(decoded) or isinstance(
                        decoded, composite_tensor.CompositeTensor)):
                    tensor_fetches.append(decoded)
                else:
                    operation_fetches.append(decoded)
                return decoded
            elif isinstance(fetch,
                            (ops.Tensor, composite_tensor.CompositeTensor)):
                tensor_fetches.append(fetch)
                return fetch
            else:
                graph_element = self.graph.as_graph_element(fetch)
                return _fetch_preprocesing_callback(graph_element)

        fetches = nest.map_structure(_fetch_preprocesing_callback, fetches)

        # Expand composite tensors into their component dense Tensors.
        tensor_fetches = nest.flatten(tensor_fetches, expand_composites=True)

        for f in (flat_feeds + tensor_fetches + operation_fetches):
            if f.graph is not self._func_graph:
                raise ValueError(
                    "Can only prune function whose feeds and fetches "
                    "are from this graph (%s). Input %s is from graph %s" %
                    (self._func_graph, f, f.graph))
        with self._func_graph.as_default():
            pruned_graph = func_graph.FuncGraph(name)
        lift_map = lift_to_graph.lift_to_graph(
            operation_fetches + tensor_fetches,
            pruned_graph,
            sources=flat_feeds + self.graph.internal_captures)

        # Note that we add the component tensors of any composite tensors to the
        # returned function's outputs list; the list must contain these component
        # tensors, or the function's sparse outputs won't work properly.
        pruned_graph.outputs.extend(lift_map[x] for x in tensor_fetches)
        pruned_graph.control_outputs.extend(
            [lift_map[operation] for operation in operation_fetches])
        pruned_graph.inputs.extend(lift_map[x] for x in flat_feeds)
        for external_capture, internal_capture in self.graph.captures:
            pruned_graph.add_capture(external_capture,
                                     lift_map[internal_capture])
        for ti in tensor_infos:
            if ti.WhichOneof("encoding") == "name":  # Dense tensors only
                t = pruned_graph.as_graph_element(ti.name)
                if tensor_util.is_tensor(t):
                    t.set_shape(tensor_shape.TensorShape(ti.tensor_shape))
        # pylint: disable=protected-access
        for f in self.graph._functions.values():
            pruned_graph._add_function(f)
        # pylint: enable=protected-access

        pruned_graph.variables = self.graph.variables

        def _structured_output_mapping(fetched):
            """callback for `nest.map_structure()`"""
            lifted = lift_map[fetched]
            if isinstance(lifted, ops.Operation):
                return None
            return lifted

        # expand_composites=True here causes composite tensors to be expanded
        # into their component dense Tensors, mapped to the new graph, and then
        # reconstituted into their original composite form.
        pruned_graph.structured_outputs = nest.map_structure(
            _structured_output_mapping, fetches, expand_composites=True)
        pruned_graph.structured_input_signature = input_signature
        pruned_fn = WrappedFunction(pruned_graph,
                                    variable_holder=self._variable_holder)
        pruned_fn._num_positional_args = len(flat_feeds)  # pylint: disable=protected-access
        # TODO(kathywu): Enable keyword arguments if an input signature is specified
        pruned_fn._arg_keywords = [tensor.op.name for tensor in flat_feeds]  # pylint: disable=protected-access
        return pruned_fn
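
A hedged usage sketch for `prune`, following the pattern from the `tf.compat.v1.wrap_function` documentation: wrap a graph-building function, then carve out a callable from a chosen feed tensor to a chosen fetch tensor (names and values below are illustrative):

import tensorflow as tf

def build():
    x = tf.compat.v1.placeholder(tf.float32, shape=[], name='x')
    y = tf.square(x, name='y')
    tf.identity(y + 1.0, name='z')  # extra op, pruned away below

wrapped = tf.compat.v1.wrap_function(build, signature=[])
square = wrapped.prune(feeds=wrapped.graph.get_tensor_by_name('x:0'),
                       fetches=wrapped.graph.get_tensor_by_name('y:0'))
print(square(tf.constant(4.0)).numpy())  # 16.0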
Example #23
    def prune(self, feeds, fetches, name=None):
        name = name or "pruned"
        flat_feeds, flat_fetches = nest.flatten(feeds), nest.flatten(fetches)
        for f in flat_feeds:
            if not isinstance(f, ops.Tensor):
                raise ValueError("Feeds must be tensors.")

        # Ignoring all feeds that are captures allows prune to be called
        # using wrapped_func.inputs even when it uses variables
        internal_captures = self.graph.internal_captures
        flat_feeds = [f for f in flat_feeds if f not in internal_captures]

        tensor_fetches = []
        operation_fetches = []
        for f in flat_fetches:
            if isinstance(f, ops.Tensor):
                tensor_fetches.append(f)
            elif isinstance(f, ops.Operation):
                operation_fetches.append(f)
            else:
                raise ValueError("Fetches must be tensors or operations.")
        for f in flat_feeds + flat_fetches:
            if f.graph is not self._func_graph:
                raise ValueError(
                    "Can only prune function whose feeds and fetches "
                    "are from this graph (%s). Tensor %s from graph %s" %
                    (self._func_graph, f, f.graph))
        with self._func_graph.as_default():
            pruned_graph = func_graph.FuncGraph(name)
            with ops.control_dependencies(operation_fetches):
                if tensor_fetches:
                    identity_fetches = array_ops.identity_n(tensor_fetches)
                    sink_tensor = identity_fetches[0]
                else:
                    identity_fetches = []
                    sink_tensor = array_ops.zeros([])
        lift_map = lift_to_graph.lift_to_graph([sink_tensor],
                                               pruned_graph,
                                               sources=flat_feeds +
                                               internal_captures)
        for original_fetch, identity_fetch in zip(tensor_fetches,
                                                  identity_fetches):
            lift_map[original_fetch] = lift_map[identity_fetch]
        pruned_graph.outputs.extend(lift_map[x] for x in flat_fetches
                                    if isinstance(x, ops.Tensor))
        pruned_graph.control_outputs.extend(
            [lift_map[operation] for operation in operation_fetches])
        if not tensor_fetches:
            pruned_graph.outputs.append(lift_map[sink_tensor])
        for external_capture, internal_capture in self.graph.captures.items():
            pruned_graph.captures[external_capture] = lift_map[
                internal_capture]
        pruned_graph.inputs.extend(lift_map[x] for x in flat_feeds)
        pruned_graph.inputs.extend(pruned_graph.captures.values())

        pruned_graph.variables = self.graph.variables

        def _structured_output_mapping(fetched):
            lifted = lift_map[fetched]
            if isinstance(lifted, ops.Operation):
                return None
            return lifted

        pruned_graph.structured_outputs = nest.map_structure(
            _structured_output_mapping, fetches)
        pruned_fn = WrappedFunction(pruned_graph,
                                    variable_holder=self._variable_holder)
        pruned_fn._num_positional_args = len(flat_feeds)  # pylint: disable=protected-access
        pruned_fn._arg_keywords = []  # pylint: disable=protected-access
        return pruned_fn
Example #24
 def add_op_to_graph(num_ops):
     with func_graph.FuncGraph("add").as_default():
         a = gen_array_ops.placeholder(dtypes.float32)
         b = gen_array_ops.placeholder(dtypes.float32)
         for _ in range(num_ops):
             gen_math_ops.add(a, b)
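
Like Example #10, this benchmark-style helper relies on one property: while a `FuncGraph` is the default graph, newly created ops are recorded in it rather than in the surrounding graph. A minimal sketch (imports assumed as in the other snippets):

g = func_graph.FuncGraph("scratch")
with g.as_default():
    c = constant_op.constant(1.0)
# The constant op was recorded in `g`, not in the outer graph.
assert c.graph is g
assert c.op in g.get_operations()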
Example #25
def _try_handling_undefineds(body, get_state, set_state, init_vars, nulls,
                             shape_invariants, symbol_names):
  """Makes a best-effort attempt to substitute undefineds with placeholders.

  Note: this substitution requires two things to happen:
   1. the types of loop variables could be inferred (usually by staging one
       iteration)
   2. these types could be replaced by placeholders (e.g. zero values for
       tensors).

  Args:
    body: a function representing the loop body. See while_stmt.
    get_state: state getter for the loop statement. See while_stmt.
    set_state: state setter for the loop statement. See while_stmt.
    init_vars: loop variables before entering the loop. See while_stmt.
    nulls: list of boolean flags indicating whether the corresponding loop var
      is None or undefined.
    shape_invariants: user-specified shape invariant for each loop variable.
    symbol_names: list of loop variable names. See while_stmt.

  Returns:
    A tuple (success, new_init_vars, extra_shape_invariants):
     * success is a boolean flag indicating
       whether types could be successfully inferred (step 1 above)
     * new_init_vars contains the loop vars, with None or undefined values
       replaced by default values, where possible (step 2 above)
     * extra_shape_invariants contains shape invariants that would be needed
       by while_stmt, for instance if the placeholder values had a shape
       different from the corresponding loop outputs
  """
  state_modified = False
  failure_message = None

  try:
    # Stage an iteration of the loop body in a temporary graph.
    with func_graph.FuncGraph('tmp').as_default():
      # This call to set_state helps report nicer error messages when symbols
      # are inconsistently used.
      # Another complication is that non-Tensor values will be autocast to
      # Tensor by while_loop, and their static value lost. So we need to account
      # for that here.
      def autocast_to_tensor(v):
        if isinstance(
            v, (int, float, bool, str, list, tuple, np.ndarray, np.generic)):
          init_val = ops.convert_to_tensor_v2(v)
          return array_ops.placeholder(init_val.dtype, init_val.shape)
        return v
      autocast_init_vars = nest.map_structure(autocast_to_tensor, init_vars)
      set_state(autocast_init_vars)
      state_modified = True

      body()
      first_iter_vars = get_state()
  except (UnboundLocalError, TypeError, ValueError, KeyError):
    ag_logging.log(1, 'Caught error while staging loop body', exc_info=True)
    # Fall back to the old functionality. It will likely result in an input
    # validation failure.
    failure_message = ('Note: AutoGraph tried to determine initial values, but '
                       'ran into an error and gave up:\n\t' +
                       '\t'.join(traceback.format_exception(*sys.exc_info())))
    first_iter_vars = None
  finally:
    if state_modified:
      set_state(init_vars)

  if first_iter_vars is not None:
    # Note: the actual placeholder value doesn't matter, because as the staging
    # proved, it will be replaced by an actual value before being read.
    inits_and_invariants = tuple(
        (_placeholder_value(iv, i, v) if n else (v, None))
        for v, n, iv, i in zip(init_vars, nulls, first_iter_vars,
                               shape_invariants))
    init_vars, extra_shape_invariants = zip(*inits_and_invariants)
    success = True
  else:
    success = False
    # Keep the third return value defined even when type inference fails.
    extra_shape_invariants = ()

  # This check runs regardless, in case we captured non-Tensor inputs.
  _verify_loop_init_vars(
      init_vars, symbol_names, first_iter_vars, extra_message=failure_message)

  return success, init_vars, extra_shape_invariants