def _create_substitute_placeholder(value, name=None, dtype=None):
  """Creates a placeholder for `value` and propagates shape info to it."""
  # Note: setting ops.control_dependencies(None) ensures we always put
  # capturing placeholders outside of any control flow context.
  with ops.control_dependencies(None):
    placeholder = graph_placeholder(
        dtype=dtype or value.dtype, shape=value.shape, name=name)
  custom_gradient.copy_handle_data(value, placeholder)
  return placeholder
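# Why the copy_handle_data call above matters -- a minimal, self-contained
# sketch, assuming TF 2.x where `copy_handle_data` lives in the internal
# module `tensorflow.python.ops.custom_gradient`. A bare resource placeholder
# (like the one _create_substitute_placeholder builds) knows nothing about the
# tensor behind the handle until handle data is copied onto it.
import tensorflow as tf
from tensorflow.python.ops import custom_gradient  # internal API, assumed path

v = tf.Variable(tf.zeros([2, 3]))

g = tf.Graph()
with g.as_default():
  ph = tf.compat.v1.placeholder(tf.resource, shape=[])
  before = tf.raw_ops.ReadVariableOp(resource=ph, dtype=tf.float32)
  print(before.shape)  # <unknown>: no handle data on `ph` yet.

  # After copying handle data from the variable's handle, shape inference
  # can recover the [2, 3] shape for ops created afterwards.
  custom_gradient.copy_handle_data(v.handle, ph)
  after = tf.raw_ops.ReadVariableOp(resource=ph, dtype=tf.float32)
  print(after.shape)  # (2, 3)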
def _get_func_graph_for_branch(name_attr_list):
  """Generates and returns a FuncGraph for the given branch."""
  # Note: this helper is a closure (nested in cond_v2's `get_func_graphs`);
  # `op`, the forward If op, is supplied by the enclosing scope.
  inputs = op.inputs[1:]  # First input is pred.
  input_shapes = [t.shape for t in inputs]
  func_graph = util.get_func_graph(op, input_shapes, name_attr_list.name)
  for external_t, internal_t in zip(inputs, func_graph.inputs):
    custom_gradient.copy_handle_data(external_t, internal_t)
  func_graph.reset_captures(zip(inputs, func_graph.inputs))
  # Link the op so that the gradient code can use it.
  func_graph._forward_cond = op
  return func_graph
def _setup_functions_captures(self):
  """Setup captures and variables in restored functions."""
  concrete_functions = sorted(self._proto.concrete_functions.items())
  for name, proto in concrete_functions:
    concrete_function = self._concrete_functions[name]
    bound_inputs = [
        self._get_tensor_from_node(node_id)
        for node_id in proto.bound_inputs]
    bound_variables = [
        self._nodes[node_id]
        for node_id in proto.bound_inputs
        if self._proto.nodes[node_id].WhichOneof("kind") == "variable"
    ]
    # TODO(andresp): This is only injecting the captured inputs into the
    # concrete function, note that we did not modify the FuncGraph
    # itself.
    concrete_function._captured_inputs = bound_inputs  # pylint: disable=protected-access
    concrete_function._func_graph.variables = bound_variables  # pylint: disable=protected-access
    if bound_inputs:
      for bound_input, internal_capture in zip(
          bound_inputs, concrete_function.inputs[-len(bound_inputs):]):
        if ds_values.is_distributed_variable(bound_input):
          concrete_function.graph.capture_distributed_variable(
              bound_input, internal_capture)
        else:
          concrete_function.graph._captures[ops.tensor_id(bound_input)] = (  # pylint: disable=protected-access
              bound_input, internal_capture)
          if internal_capture.dtype == dtypes.resource:
            if resource_variable_ops.is_resource_variable(bound_input):
              try:
                handle = bound_input.handle
              except ValueError:
                # For mirrored variables we'll copy handle data for components
                # as they get captured.
                pass
              else:
                custom_gradient.copy_handle_data(handle, internal_capture)
            else:
              custom_gradient.copy_handle_data(bound_input, internal_capture)
          # Setting "captures" first means "capture" won't create a new
          # placeholder for this input.
          concrete_function.graph.capture(bound_input)
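# Context sketch for the method above: it runs during `tf.saved_model.load`,
# rebinding each restored ConcreteFunction's placeholder captures to the
# freshly restored variables. The export path and the attribute name `f`
# below are hypothetical, not taken from the source above.
import tensorflow as tf

loaded = tf.saved_model.load("/tmp/saved_model")  # hypothetical export dir
# Calling a restored function now resolves its captures (e.g. variable
# handles) through the inputs rebound by _setup_functions_captures.
print(loaded.f(tf.constant(1.0)))  # assumes the model exported a function `f`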
def _get_func_graph_for_branch(name_attr_list):
  """Generates and returns a FuncGraph for the given branch."""
  # Note: this helper is a closure (nested in cond_v2's `get_func_graphs`);
  # `op`, the forward If op, is supplied by the enclosing scope.
  inputs = op.inputs[1:]  # First input is pred.
  input_shapes = [t.shape for t in inputs]
  fdef = op.graph._get_function(name_attr_list.name).definition
  # `op.graph` may not be the same as `ops.get_default_graph()` e.g.
  # in the case of nested if ops or when the gradient is being computed
  # from inside a Defun. We build the `func_graph` with `op.graph` as its
  # `outer_graph`. This resembles how the `FuncGraph` was built in the
  # forward pass. We need this so that we can resolve references to tensors
  # in `func_graph` from its gradient graph in `_resolve_grad_inputs`.
  with op.graph.as_default():
    func_graph = function_def_to_graph.function_def_to_graph(
        fdef, input_shapes)
  for external_t, internal_t in zip(inputs, func_graph.inputs):
    custom_gradient.copy_handle_data(external_t, internal_t)
  func_graph.reset_captures(zip(inputs, func_graph.inputs))
  # Link the op so that the gradient code can use it.
  func_graph._forward_cond = op
  return func_graph
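# Exercise-path sketch: differentiating through `tf.cond` inside a
# `tf.function` is what drives `_get_func_graph_for_branch` -- the If op's
# gradient rebuilds each branch FuncGraph, and the handle-data copy above
# keeps resource inputs (like a variable's handle) shape-aware in the rebuilt
# graphs. A minimal, assumed-typical driver:
import tensorflow as tf

v = tf.Variable(2.0)

@tf.function
def f(pred, x):
  # Each lambda becomes its own branch FuncGraph of an If/StatelessIf op.
  return tf.cond(pred, lambda: v * x, lambda: x * x)

with tf.GradientTape() as tape:
  y = f(tf.constant(True), tf.constant(3.0))
print(tape.gradient(y, v))  # dy/dv = x = 3.0 on the taken branch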
def _copy_handle_data(src_tensors, tgt_tensors):
  """Copies handle data pairwise between two equal-length tensor lists."""
  for src_t, tgt_t in zip(src_tensors, tgt_tensors):
    custom_gradient.copy_handle_data(src_t, tgt_t)
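# Hypothetical call site for the helper above (both names are illustrative):
# after rebuilding a control-flow body graph for a gradient, mirror handle
# data from the forward op's inputs onto the rebuilt graph's placeholders
# in one pass.
_copy_handle_data(forward_op.inputs, body_graph.inputs)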