def restore_captures(concrete_function, inputs):
  """Restore captures for the concrete function.

  Used at deserialization time.  For functions that are being deserialized, the
  SavedModel loader restores the objects that tensors were captured from, but
  functions only know about their tensors -- object information is destroyed by
  tracing.  This additional logic extracts the tensors that the function
  originally captured.

  Args:
    concrete_function: the concrete function for which to restore captures
    inputs: a list of tensors or other Python objects (such as variables) which
      contain tensors that were originally captured by the function
  """
  bound_inputs = [get_tensor_from_node(obj) for obj in inputs]
  bound_variables = [
      obj for obj in inputs
      if isinstance(obj, (variables_lib.Variable,
                          resource_variable_ops.BaseResourceVariable))
  ]
  # TODO(b/205010575): This only injects the captured inputs into the concrete
  # function; note that the FuncGraph itself is not modified.
  captured_inputs_list = []
  concrete_function.set_variables(bound_variables)
  if bound_inputs:
    for bound_input, internal_capture in zip(
        bound_inputs, concrete_function.inputs[-len(bound_inputs):]):
      # Distributed inputs have special logic for capturing, so we call their
      # custom restoration methods
      if hasattr(bound_input, "__tf_experimental_restore_capture__"):
        captured_inputs_list.append(
            bound_input.__tf_experimental_restore_capture__(
                concrete_function, internal_capture))
      else:
        captured_inputs_list.append(bound_input)
        concrete_function.graph.replace_capture(bound_input, internal_capture)
        if internal_capture.dtype == dtypes.resource:
          if resource_variable_ops.is_resource_variable(bound_input):
            try:
              handle = bound_input.handle
            except ValueError:
              # For mirrored variables we'll copy handle data for components
              # as they get captured.
              pass
            else:
              handle_data_util.copy_handle_data(handle, internal_capture)
          else:
            # TODO(b/213451747): Remove need to call copy_handle_data
            handle_data_util.copy_handle_data(bound_input, internal_capture)
        # Setting "captures" first means "capture" won't create a new
        # placeholder for this input.
        concrete_function.graph.capture(bound_input)

  if any([inp is None for inp in captured_inputs_list]):
    warnings.warn("Trying to load ShardedVariables using tf.saved_model.load. "
                  "This won't work if using a tf.distribute.Strategy, and may "
                  "use excess memory if not using a Strategy. Ignore this "
                  "warning if using tf.keras.models.load_model.")
  concrete_function.set_external_captures(captured_inputs_list)
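
A quick way to see where this capture restoration fits is a round trip through the
public SavedModel API: a tf.function that reads a tf.Variable captures that
variable's handle, and tf.saved_model.load has to reattach the restored variable to
the restored concrete function. A minimal sketch (the Scaler class and the /tmp
path are illustrative, not part of the code above):

import tensorflow as tf

class Scaler(tf.Module):

  def __init__(self):
    self.v = tf.Variable(2.0)

  @tf.function(input_signature=[tf.TensorSpec([], tf.float32)])
  def scale(self, x):
    return self.v * x  # self.v is captured by the traced function

m = Scaler()
tf.saved_model.save(m, "/tmp/capture_demo")

# Loading re-creates the variable and re-attaches it as an external capture of
# the restored concrete function, which is the job of the code above.
restored = tf.saved_model.load("/tmp/capture_demo")
print(restored.scale(tf.constant(3.0)))  # tf.Tensor(6.0, shape=(), dtype=float32)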
Example #2
  def _setup_function_captures(self, concrete_function_name, nodes):
    """Setup captures and variables in a restored function."""
    self._restored_concrete_functions.add(concrete_function_name)
    concrete_function = self._concrete_functions[concrete_function_name]
    proto = self._proto.concrete_functions[concrete_function_name]
    bound_inputs = [
        self._get_tensor_from_node(nodes[node_id])
        for node_id in proto.bound_inputs]
    bound_variables = [
        nodes[node_id] for node_id in proto.bound_inputs
        if self._proto.nodes[node_id].WhichOneof("kind") == "variable"
    ]
    # TODO(b/205010575): This only injects the captured inputs into the concrete
    # function; note that the FuncGraph itself is not modified.
    captured_inputs_list = []
    concrete_function._func_graph.variables = bound_variables  # pylint: disable=protected-access
    if bound_inputs:
      for bound_input, internal_capture in zip(
          bound_inputs, concrete_function.inputs[-len(bound_inputs):]):
        if distribute_utils.is_distributed_variable(bound_input):
          concrete_function.graph.capture_distributed_variable(
              bound_input, internal_capture)
          captured_inputs_list.append(bound_input)
        elif distribute_utils.is_distributed_table(bound_input):
          closure, spec = bound_input.resource_handle_call_time_value()
          concrete_function.graph.replace_capture_with_deferred_capture(
              bound_input._coordinator_instance.resource_handle,  # pylint: disable=protected-access
              closure,
              spec,
              default_value=bound_input._coordinator_instance.resource_handle,  # pylint: disable=protected-access
              placeholder=internal_capture)
          captured_inputs_list.append(
              concrete_function.graph.deferred_external_captures[-1])

        else:
          captured_inputs_list.append(bound_input)
          concrete_function.graph.replace_capture(bound_input,
                                                  internal_capture)
          if internal_capture.dtype == dtypes.resource:
            if resource_variable_ops.is_resource_variable(bound_input):
              try:
                handle = bound_input.handle
              except ValueError:
                # For mirrored variables we'll copy handle data for components
                # as they get captured.
                pass
              else:
                handle_data_util.copy_handle_data(handle, internal_capture)
            else:
              handle_data_util.copy_handle_data(bound_input, internal_capture)
          # Setting "captures" first means "capture" won't create a new
          # placeholder for this input.
          concrete_function.graph.capture(bound_input)

    concrete_function.set_external_captures(captured_inputs_list)
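
The extra branches here (capture_distributed_variable and the deferred capture for
distributed tables) come into play when a SavedModel is loaded under a tf.distribute
strategy, where the restored captures are distributed variables rather than plain
resource variables. A hedged sketch, reusing the /tmp/capture_demo model saved in
the earlier sketch; a single-device MirroredStrategy is enough to illustrate the
idea:

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

# Loading inside the strategy scope restores the captured variables as
# distributed variables, so the distributed-variable branch above is taken.
with strategy.scope():
  restored = tf.saved_model.load("/tmp/capture_demo")

print(strategy.run(restored.scale, args=(tf.constant(3.0),)))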
Example #3
 def _setup_functions_captures(self):
     """Setup captures and variables in restored functions."""
     concrete_functions = sorted(self._proto.concrete_functions.items())
     for name, proto in concrete_functions:
         concrete_function = self._concrete_functions[name]
         bound_inputs = [
             self._get_tensor_from_node(node_id, name)
             for node_id in proto.bound_inputs
         ]
         bound_variables = [
             self._nodes[node_id] for node_id in proto.bound_inputs
             if self._proto.nodes[node_id].WhichOneof("kind") == "variable"
         ]
         # TODO(andresp): This only injects the captured inputs into the
         # concrete function; note that the FuncGraph itself is not modified.
         concrete_function._captured_inputs = bound_inputs  # pylint: disable=protected-access
         concrete_function._func_graph.variables = bound_variables  # pylint: disable=protected-access
         if bound_inputs:
             for bound_input, internal_capture in zip(
                     bound_inputs,
                     concrete_function.inputs[-len(bound_inputs):]):
                 if distribute_utils.is_distributed_variable(bound_input):
                     concrete_function.graph.capture_distributed_variable(
                         bound_input, internal_capture)
                 else:
                     concrete_function.graph.replace_capture(
                         bound_input, internal_capture)
                     if internal_capture.dtype == dtypes.resource:
                         if resource_variable_ops.is_resource_variable(
                                 bound_input):
                             try:
                                 handle = bound_input.handle
                             except ValueError:
                                 # For mirrored variables we'll copy handle data for components
                                 # as they get captured.
                                 pass
                             else:
                                 handle_data_util.copy_handle_data(
                                     handle, internal_capture)
                         else:
                             handle_data_util.copy_handle_data(
                                 bound_input, internal_capture)
                     # Setting "captures" first means "capture" won't create a new
                     # placeholder for this input.
                     concrete_function.graph.capture(bound_input)
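
All of these variants end in handle_data_util.copy_handle_data, which transfers the
metadata attached to a resource tensor (the dtype and shape of the variable it
points to) onto the new capture placeholder so graph-mode shape inference keeps
working. A rough illustration of that metadata through public APIs; the exact
propagation path differs from the loader code above:

import tensorflow as tf

v = tf.Variable(tf.zeros([3, 4]))

@tf.function
def read_v():
  # v.handle is a dtypes.resource tensor; its handle data (the variable's
  # dtype and shape) travels with the capture, which is what gives the read
  # below a static shape inside the graph.
  return tf.raw_ops.ReadVariableOp(resource=v.handle, dtype=tf.float32)

print(read_v.get_concrete_function().output_shapes)  # expected: (3, 4)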
Example #4
def tensor_list_set_item(input_handle,
                         index,
                         item,
                         resize_if_index_out_of_bounds=False,
                         name=None):
    """Sets `item` at `index` in input list."""
    if resize_if_index_out_of_bounds:
        input_list_size = gen_list_ops.tensor_list_length(input_handle)
        # TODO(srbs): This could cause some slowdown. Consider fusing resize
        # functionality in the SetItem op.
        input_handle = control_flow_ops.cond(
            index >= input_list_size,
            lambda: gen_list_ops.tensor_list_resize(  # pylint: disable=g-long-lambda
                input_handle, index + 1),
            lambda: input_handle)
    output_handle = gen_list_ops.tensor_list_set_item(
        input_handle=input_handle, index=index, item=item, name=name)
    handle_data_util.copy_handle_data(input_handle, output_handle)
    return output_handle
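
tensor_list_set_item itself is internal, but the underlying TensorList kernels are
reachable through tf.raw_ops, which makes the set-item semantics easy to try out.
A small sketch using the raw ops directly (not the wrapper above):

import tensorflow as tf

# Build a TensorList of three scalars, overwrite index 1, and read it back.
scalar_shape = tf.constant([], dtype=tf.int32)  # empty shape vector == scalar elements
handle = tf.raw_ops.TensorListFromTensor(
    tensor=tf.constant([1.0, 2.0, 3.0]), element_shape=scalar_shape)
handle = tf.raw_ops.TensorListSetItem(
    input_handle=handle, index=1, item=tf.constant(7.0))
print(tf.raw_ops.TensorListStack(
    input_handle=handle, element_shape=scalar_shape,
    element_dtype=tf.float32))  # [1., 7., 3.]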
Example #5
def tensor_list_scatter(tensor,
                        indices,
                        element_shape=None,
                        input_handle=None,
                        name=None):
    """Returns a TensorList created or updated by scattering `tensor`."""
    tensor = ops.convert_to_tensor(tensor)
    if input_handle is not None:
        output_handle = gen_list_ops.tensor_list_scatter_into_existing_list(
            input_handle=input_handle,
            tensor=tensor,
            indices=indices,
            name=name)
        handle_data_util.copy_handle_data(input_handle, output_handle)
        return output_handle
    else:
        output_handle = gen_list_ops.tensor_list_scatter_v2(
            tensor=tensor,
            indices=indices,
            element_shape=_build_element_shape(element_shape),
            num_elements=-1,
            name=name)
        _set_handle_data(output_handle, element_shape, tensor.dtype)
        return output_handle
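
The scatter path can be exercised the same way: list element i receives the row of
`tensor` whose entry in `indices` equals i. A short raw-op sketch (again, not the
wrapper above):

import tensorflow as tf

scalar_shape = tf.constant([], dtype=tf.int32)
handle = tf.raw_ops.TensorListScatterV2(
    tensor=tf.constant([10.0, 20.0, 30.0]),
    indices=tf.constant([2, 0, 1]),
    element_shape=scalar_shape,
    num_elements=-1)  # -1: size the list from the largest index
print(tf.raw_ops.TensorListStack(
    input_handle=handle, element_shape=scalar_shape,
    element_dtype=tf.float32))  # [20., 30., 10.]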
Example #6
def _graph_mode_decorator(f, args, kwargs):
    """Implement custom gradient decorator for graph mode."""
    # TODO(rsepassi): Add support for kwargs
    if kwargs:
        raise ValueError(
            "The custom_gradient decorator currently supports keyword "
            "arguments only when eager execution is enabled.")
    name = generate_name()
    args = nest.map_structure(ops.convert_to_tensor, args)

    # Checking global and local variables attempts to ensure that no non-resource
    # Variables are added to the graph.
    current_var_scope = variable_scope.get_variable_scope()
    before_vars = set([
        v.ref() for v in current_var_scope.global_variables() +
        current_var_scope.local_variables()
    ])
    with tape_lib.VariableWatcher() as variable_watcher:
        result, grad_fn = f(*args)

    args = nest.flatten(args)
    flat_result = nest.flatten(result)
    flat_result_len = len(flat_result)

    after_vars = set([
        v.ref() for v in current_var_scope.global_variables() +
        current_var_scope.local_variables()
    ])
    new_vars = after_vars - before_vars
    new_vars_list = [v.deref() for v in new_vars]
    for v in new_vars_list:
        if not resource_variable_ops.is_resource_variable(v):
            raise TypeError(
                "All variables used by a function wrapped with @custom_gradient must "
                "be `ResourceVariable`s. Ensure that no `variable_scope` is created "
                "with `use_resource=False`.")

    # The variables that grad_fn needs to return gradients for are the set of
    # variables used that are *not* part of the inputs.
    variables_in_tape = frozenset(
        [v.ref() for v in variable_watcher.watched_variables()])

    graphs = {getattr(o, "graph", None) for o in flat_result}
    # Not all results may be tensors. However, we want to ensure all tensor
    # outputs are from the same graph and get a list of captured inputs for
    # variable search
    graphs.discard(None)  # Discard non-graph outputs
    if graphs:
        if len(graphs) > 1:
            raise ValueError(
                "All custom_gradient outputs should be from the same graph")
        output_graph = graphs.pop()
        filtered_input_tensors = []
        for i in args:
            if i.graph == output_graph:
                filtered_input_tensors.append(i)
    else:
        filtered_input_tensors = args

    variables_in_subgraph = frozenset([
        v.ref()
        for v in _get_dependent_variables(input_ops=filtered_input_tensors,
                                          output_ops=flat_result)
    ])
    variables = sorted(
        [v.deref() for v in variables_in_subgraph.union(variables_in_tape)],
        key=lambda v: v.name)

    grad_argspec = tf_inspect.getfullargspec(grad_fn)
    variables_in_signature = ("variables" in grad_argspec.args
                              or "variables" in grad_argspec.kwonlyargs
                              or grad_argspec.varkw)
    if variables and not variables_in_signature:
        raise TypeError(
            "@tf.custom_gradient grad_fn must accept keyword argument 'variables', "
            "since the function uses variables: {}".format(variables))
    if variables_in_signature and not variables:
        # User seems to intend to use variables but none were captured.
        logging.warning(
            "@custom_gradient grad_fn has 'variables' in signature, but "
            "no ResourceVariables were used on the forward pass.")

    all_tensors = flat_result + args + variables

    def tape_grad_fn(*result_grads):
        """Custom grad fn wrapper."""
        result_grads = result_grads[:flat_result_len]
        if variables:
            input_grads, variable_grads = grad_fn(*result_grads,
                                                  variables=variables)
            if len(variable_grads) != len(variables):
                raise ValueError("Must return gradient for each variable from "
                                 "@custom_gradient grad_fn.")
        else:
            input_grads = grad_fn(*result_grads)
            variable_grads = []

        # Need to return one value per input to the IdentityN, so pad the
        # gradients of the inputs of the custom_gradient function with the
        # gradients of the outputs as well.
        input_grads = nest.flatten(input_grads)
        return ([None] * flat_result_len) + input_grads + variable_grads

    @ops.RegisterGradient(name)
    def internal_grad_fn(unused_op, *result_grads):  # pylint: disable=unused-variable
        """Custom grad fn wrapper."""
        return tape_grad_fn(*result_grads)

    original_tensors = all_tensors
    with ops.get_default_graph().gradient_override_map({"IdentityN": name}):
        all_tensors = array_ops.identity_n(all_tensors)

    original_tensors = [ops.convert_to_tensor(x) for x in original_tensors]

    # Propagate handle data for happier shape inference for resource variables.
    for i, t in enumerate(original_tensors):
        if t.dtype == dtypes.resource and hasattr(t, "_handle_data"):
            all_tensors[i]._handle_data = t._handle_data  # pylint: disable=protected-access
    tape_lib.record_operation(f.__name__, all_tensors, original_tensors,
                              tape_grad_fn)
    for ot, t in zip(original_tensors, all_tensors):
        handle_data_util.copy_handle_data(ot, t)
    return nest.pack_sequence_as(structure=result,
                                 flat_sequence=all_tensors[:flat_result_len])
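
This is the graph-mode half of tf.custom_gradient; the 'variables' handling above
matches the public contract, where grad_fn returns one gradient per input plus a
list with one gradient per captured variable. A minimal sketch against the public
API (shown in eager mode, but the same contract applies to the decorator above):

import tensorflow as tf

w = tf.Variable(2.0)

@tf.custom_gradient
def scale(x):
    y = w * x

    def grad(upstream, variables=None):
        # One gradient per input to `scale`, plus one per captured variable.
        return upstream * w, [upstream * x]

    return y, grad

with tf.GradientTape() as tape:
    out = scale(tf.constant(3.0))
print(tape.gradient(out, w))  # tf.Tensor(3.0, shape=(), dtype=float32)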