def get_unused_handle(x):
    """Swap distributed variables for `_unused_handle()`; pass anything else through.

    Args:
      x: Any value. Only values recognised by
        `distribute_utils.is_distributed_variable` are replaced.

    Returns:
      The result of `_unused_handle()` when `x` is a distributed variable,
      otherwise `x` unchanged.
    """
    if not distribute_utils.is_distributed_variable(x):
        return x
    return _unused_handle()
def get_handle(x):
    """Return `x.handle` for a distributed variable, otherwise `x` itself.

    Args:
      x: Any value. Only values recognised by
        `distribute_utils.is_distributed_variable` are unwrapped.

    Returns:
      `x.handle` when `x` is a distributed variable; `x` unchanged otherwise.
    """
    if distribute_utils.is_distributed_variable(x):
        return x.handle
    return x
def _setup_functions_captures(self):
    """Setup captures and variables in restored functions.

    Iterates over ``self._proto.concrete_functions`` sorted by function name.
    For each entry it looks up the restored function in
    ``self._concrete_functions`` and then:

    * resolves every node id in ``proto.bound_inputs`` via
      ``self._get_tensor_from_node``;
    * assigns the restored nodes whose proto kind is ``"variable"`` to the
      function graph's ``variables``;
    * pairs each bound input positionally with one of the trailing entries of
      ``concrete_function.inputs`` (the placeholders for these captures) and
      binds them, with separate handling for distributed variables,
      distributed tables, and all other inputs;
    * sets the function's external captures to the collected list.

    The concrete functions are mutated in place; nothing is returned.
    """
    concrete_functions = sorted(self._proto.concrete_functions.items())
    for name, proto in concrete_functions:
        concrete_function = self._concrete_functions[name]
        # Values to bind to this function's trailing inputs, in proto order.
        bound_inputs = [
            self._get_tensor_from_node(node_id, name)
            for node_id in proto.bound_inputs
        ]
        # Only the bound inputs whose proto node kind is "variable". These are
        # the restored node objects (self._nodes), not _get_tensor_from_node
        # results.
        bound_variables = [
            self._nodes[node_id]
            for node_id in proto.bound_inputs
            if self._proto.nodes[node_id].WhichOneof("kind") == "variable"
        ]
        # TODO(b/205010575): This is only injecting the captured inputs into the
        # concrete function, note that we did not modify the FuncGraph
        # itself.
        # External captures for this function: exactly one entry is appended
        # per bound input, so the list stays in bound-input order.
        captured_inputs_list = []
        concrete_function._func_graph.variables = bound_variables  # pylint: disable=protected-access
        # Bound inputs map positionally onto the last len(bound_inputs) entries
        # of concrete_function.inputs. `inputs[-0:]` would be the whole list,
        # so the slice is only meaningful when bound_inputs is non-empty.
        # NOTE(review): zip() silently truncates if inputs has fewer entries
        # than bound_inputs -- confirm that cannot happen.
        if bound_inputs:
            for bound_input, internal_capture in zip(
                    bound_inputs, concrete_function.inputs[-len(bound_inputs):]):
                if distribute_utils.is_distributed_variable(bound_input):
                    # Distributed variables go through the graph's dedicated
                    # capture API and are recorded as-is.
                    concrete_function.graph.capture_distributed_variable(
                        bound_input, internal_capture)
                    captured_inputs_list.append(bound_input)
                elif distribute_utils.is_distributed_table(bound_input):
                    # Distributed tables get a deferred capture. The coordinator
                    # instance's resource handle is replaced by a capture whose
                    # value comes from `closure` (presumably evaluated at call
                    # time, per resource_handle_call_time_value), and that
                    # handle is kept as the default value.
                    closure, spec = bound_input.resource_handle_call_time_value()
                    concrete_function.graph.replace_capture_with_deferred_capture(
                        bound_input._coordinator_instance.resource_handle,  # pylint: disable=protected-access
                        closure,
                        spec,
                        default_value=bound_input._coordinator_instance.resource_handle,  # pylint: disable=protected-access
                        placeholder=internal_capture)
                    # Record the deferred capture just registered above.
                    captured_inputs_list.append(
                        concrete_function.graph.deferred_external_captures[-1])
                else:
                    # Regular inputs: bind to the existing placeholder, then
                    # propagate resource handle data onto it.
                    captured_inputs_list.append(bound_input)
                    concrete_function.graph.replace_capture(
                        bound_input, internal_capture)
                    if internal_capture.dtype == dtypes.resource:
                        # Copy handle data onto the resource-typed placeholder.
                        # For resource variables the source is `.handle`;
                        # otherwise it is the bound input itself.
                        if resource_variable_ops.is_resource_variable(
                                bound_input):
                            try:
                                handle = bound_input.handle
                            except ValueError:
                                # For mirrored variables we'll copy handle data for components
                                # as they get captured.
                                pass
                            else:
                                handle_data_util.copy_handle_data(
                                    handle, internal_capture)
                        else:
                            handle_data_util.copy_handle_data(
                                bound_input, internal_capture)
                    # Setting "captures" first means "capture" won't create a new
                    # placeholder for this input.
                    concrete_function.graph.capture(bound_input)
        # Runs once per concrete function, after all its bound inputs are bound.
        concrete_function.set_external_captures(captured_inputs_list)