コード例 #1
0
ファイル: load.py プロジェクト: wwjiang007/tensorflow
 def get_unused_handle(x):
     """Return `_unused_handle()` for a distributed variable, otherwise `x`.

     Args:
       x: Any object; it is tested with
         `distribute_utils.is_distributed_variable`.

     Returns:
       The result of `_unused_handle()` when `x` is a distributed variable;
       `x` itself, unchanged, in every other case.
     """
     if distribute_utils.is_distributed_variable(x):
         return _unused_handle()
     return x
コード例 #2
0
 def get_handle(x):
     """Return `x.handle` for a distributed variable, otherwise `x`.

     Args:
       x: Any object; it is tested with
         `distribute_utils.is_distributed_variable`.

     Returns:
       The `handle` attribute of `x` when `x` is a distributed variable;
       `x` itself, unchanged, in every other case.
     """
     if distribute_utils.is_distributed_variable(x):
         return x.handle
     return x
コード例 #3
0
ファイル: load.py プロジェクト: waterdrops/tensorflow
    def _setup_functions_captures(self):
        """Set up captures and variables in restored functions.

        For every entry in `self._proto.concrete_functions` (visited in sorted
        order), this:

        * resolves each id in the proto's `bound_inputs` to an object via
          `self._get_tensor_from_node`;
        * stores the nodes whose proto kind is "variable" on the function's
          `_func_graph.variables`;
        * pairs each bound input with one of the trailing entries of
          `concrete_function.inputs` and rewires the graph capture for it,
          using a different graph API for distributed variables, distributed
          tables and everything else;
        * passes the collected captures to
          `concrete_function.set_external_captures`.

        The method returns nothing; it works by mutating the concrete functions
        held in `self._concrete_functions`.
        """
        # Sorting the (name, proto) pairs gives a deterministic processing order.
        concrete_functions = sorted(self._proto.concrete_functions.items())
        for name, proto in concrete_functions:
            concrete_function = self._concrete_functions[name]
            # One resolved object per bound-input node id, in proto order.
            bound_inputs = [
                self._get_tensor_from_node(node_id, name)
                for node_id in proto.bound_inputs
            ]
            # Only those bound inputs whose saved node kind is "variable"; these
            # are the raw `self._nodes` entries, not the resolved objects above.
            bound_variables = [
                self._nodes[node_id] for node_id in proto.bound_inputs
                if self._proto.nodes[node_id].WhichOneof("kind") == "variable"
            ]
            # TODO(b/205010575): This is only injecting the captured inputs into the
            # concrete function, note that we did not modify the FuncGraph
            # itself.
            captured_inputs_list = []
            concrete_function._func_graph.variables = bound_variables  # pylint: disable=protected-access
            # The guard matters: with no bound inputs, `inputs[-0:]` would select
            # every input instead of none.
            if bound_inputs:
                # Pair each bound input with one of the last `len(bound_inputs)`
                # function inputs (presumably the capture placeholders -- confirm
                # against how the function was traced).
                for bound_input, internal_capture in zip(
                        bound_inputs,
                        concrete_function.inputs[-len(bound_inputs):]):
                    if distribute_utils.is_distributed_variable(bound_input):
                        # Distributed variables go through the graph's dedicated
                        # capture method and are recorded as-is.
                        concrete_function.graph.capture_distributed_variable(
                            bound_input, internal_capture)
                        captured_inputs_list.append(bound_input)
                    elif distribute_utils.is_distributed_table(bound_input):
                        # Distributed tables are captured lazily: the closure and
                        # spec come from the table, while the coordinator
                        # instance's resource handle serves as both the capture
                        # being replaced and the default value.
                        closure, spec = bound_input.resource_handle_call_time_value(
                        )
                        concrete_function.graph.replace_capture_with_deferred_capture(
                            bound_input._coordinator_instance.resource_handle,  # pylint: disable=protected-access
                            closure,
                            spec,
                            default_value=bound_input._coordinator_instance.
                            resource_handle,  # pylint: disable=protected-access
                            placeholder=internal_capture)
                        # Record the last deferred external capture on the graph
                        # (presumably the one registered just above -- confirm).
                        captured_inputs_list.append(
                            concrete_function.graph.
                            deferred_external_captures[-1])

                    else:
                        # Everything else: record the input and point the
                        # existing placeholder at it.
                        captured_inputs_list.append(bound_input)
                        concrete_function.graph.replace_capture(
                            bound_input, internal_capture)
                        # For resource-typed placeholders, copy handle data onto
                        # the placeholder from the variable's handle or, for
                        # non-variables, from the bound input itself.
                        if internal_capture.dtype == dtypes.resource:
                            if resource_variable_ops.is_resource_variable(
                                    bound_input):
                                try:
                                    handle = bound_input.handle
                                except ValueError:
                                    # For mirrored variables we'll copy handle data for components
                                    # as they get captured.
                                    pass
                                else:
                                    handle_data_util.copy_handle_data(
                                        handle, internal_capture)
                            else:
                                handle_data_util.copy_handle_data(
                                    bound_input, internal_capture)
                        # Setting "captures" first means "capture" won't create a new
                        # placeholder for this input.
                        concrete_function.graph.capture(bound_input)

            # Runs for every function, even when there were no bound inputs (the
            # list is then empty).
            concrete_function.set_external_captures(captured_inputs_list)