Exemple #1
0
  def _setup_function_captures(self, concrete_function_name, nodes):
    """Setup captures and variables in a restored function.

    Marks the named concrete function as restored, resolves every entry of its
    proto's `bound_inputs` through `self._get_tensor_from_node`, and pairs the
    resolved values positionally with the trailing inputs of the concrete
    function. The resulting list is finally handed to
    `concrete_function.set_external_captures`.

    Args:
      concrete_function_name: Key used to look the function up in both
        `self._concrete_functions` and `self._proto.concrete_functions`.
      nodes: Collection of restored objects, indexed by the node ids listed in
        the function proto's `bound_inputs`.
    """
    # Record that this function has been processed.
    self._restored_concrete_functions.add(concrete_function_name)
    concrete_function = self._concrete_functions[concrete_function_name]
    proto = self._proto.concrete_functions[concrete_function_name]
    # One capturable value per bound input, in the order the proto lists them.
    bound_inputs = [
        self._get_tensor_from_node(nodes[node_id])
        for node_id in proto.bound_inputs]
    # The restored objects (not their resolved values) for those bound inputs
    # whose proto node kind is "variable".
    bound_variables = [
        nodes[node_id] for node_id in proto.bound_inputs
        if self._proto.nodes[node_id].WhichOneof("kind") == "variable"
    ]
    # TODO(b/205010575): This is only injecting the captured inputs into the
    # concrete function, note that we did not modify the FuncGraph
    # itself.
    captured_inputs_list = []
    concrete_function._func_graph.variables = bound_variables  # pylint: disable=protected-access
    if bound_inputs:
      # The last len(bound_inputs) inputs of the function are matched one to one
      # with the bound inputs.
      for bound_input, internal_capture in zip(
          bound_inputs, concrete_function.inputs[-len(bound_inputs):]):
        if distribute_utils.is_distributed_variable(bound_input):
          # Distributed variables are captured through the graph's dedicated
          # method and recorded as-is in the external captures list.
          concrete_function.graph.capture_distributed_variable(
              bound_input, internal_capture)
          captured_inputs_list.append(bound_input)
        elif distribute_utils.is_distributed_table(bound_input):
          # Distributed tables swap the capture of the coordinator instance's
          # resource handle for a deferred capture built from `closure` and
          # `spec`; the same handle is also passed as the default value.
          closure, spec = bound_input.resource_handle_call_time_value()
          concrete_function.graph.replace_capture_with_deferred_capture(
              bound_input._coordinator_instance.resource_handle,  # pylint: disable=protected-access
              closure,
              spec,
              default_value=bound_input._coordinator_instance.resource_handle,  # pylint: disable=protected-access
              placeholder=internal_capture)
          # NOTE(review): takes the last deferred external capture, presumably
          # the one registered by the call just above -- confirm against the
          # FuncGraph implementation.
          captured_inputs_list.append(
              concrete_function.graph.deferred_external_captures[-1])

        else:
          # Every other bound input is recorded directly and mapped onto its
          # placeholder in the graph.
          captured_inputs_list.append(bound_input)
          concrete_function.graph.replace_capture(bound_input,
                                                  internal_capture)
          # Only resource-typed placeholders get handle data copied onto them.
          if internal_capture.dtype == dtypes.resource:
            if resource_variable_ops.is_resource_variable(bound_input):
              try:
                handle = bound_input.handle
              except ValueError:
                # For mirrored variables we'll copy handle data for components
                # as they get captured.
                pass
              else:
                # Runs only when reading `.handle` succeeded, so a ValueError
                # raised by the copy itself is not swallowed.
                handle_data_util.copy_handle_data(handle, internal_capture)
            else:
              handle_data_util.copy_handle_data(bound_input, internal_capture)
          # Setting "captures" first means "capture" won't create a new
          # placeholder for this input.
          concrete_function.graph.capture(bound_input)

    # Called even when there are no bound inputs, in which case the list is
    # empty.
    concrete_function.set_external_captures(captured_inputs_list)
Exemple #2
0
 def _get_tensor_from_node(self, node):
   """Returns the value a function should capture for a restored object.

   Args:
     node: The restored object to resolve.

   Returns:
     `node` itself for distributed variables, distributed tables and TF-typed
     values; `node.handle` for resource variables; `node.asset_path` for
     assets; `node.resource_handle` for capturable resources.

   Raises:
     ValueError: If `node` matches none of the supported kinds.
   """
   with ops.init_scope():
     # The checks run in a fixed order; the first one that matches wins.
     if (distribute_utils.is_distributed_variable(node) or
         distribute_utils.is_distributed_table(node)):
       return node
     if resource_variable_ops.is_resource_variable(node):
       return node.handle
     if isinstance(node, tracking.Asset):
       return node.asset_path
     if tensor_util.is_tf_type(node):
       return node
     if isinstance(node, tracking.CapturableResource):
       # Note: this executes restored functions in the CapturableResource.
       return node.resource_handle
     raise ValueError(f"Cannot convert node {node} to tensor.")
Exemple #3
0
    def _get_tensor_from_node(self, node_id, fn_name):
        """Looks up a loaded node by id and returns the value to capture.

        Args:
          node_id: Index of the object in `self._nodes`.
          fn_name: Name of the function needing the capture; it is only used
            to build the error message for a node that was not loaded.

        Returns:
          The object itself for distributed variables, distributed tables and
          TF-typed values; its `handle` for resource variables; its
          `asset_path` for assets; its `resource_handle` for capturable
          resources.

        Raises:
          ValueError: If node filters are set and the entry for `node_id` is
            None, or if the object matches none of the supported kinds.
        """
        filters_active = self._node_filters is not None
        if filters_active and self._nodes[node_id] is None:
            raise ValueError(
                f"Error when processing nodes_to_load. Function '{fn_name}' requires "
                "inputs/variables that are not loaded when nodes_to_load="
                f"{self._node_filters}.")

        with ops.init_scope():
            loaded = self._nodes[node_id]
            # The checks run in a fixed order; the first one that matches wins.
            if (distribute_utils.is_distributed_variable(loaded)
                    or distribute_utils.is_distributed_table(loaded)):
                return loaded
            if resource_variable_ops.is_resource_variable(loaded):
                return loaded.handle
            if isinstance(loaded, tracking.Asset):
                return loaded.asset_path
            if tensor_util.is_tf_type(loaded):
                return loaded
            if isinstance(loaded, tracking.CapturableResource):
                # Note: this executes restored functions in the CapturableResource.
                return loaded.resource_handle
            raise ValueError(f"Cannot convert node {loaded} to tensor.")