def _resource_capture_helper(self, tensor):
  """Captures, in this graph, the external resource behind `tensor`.

  Resource tensors are never accumulated. A resource tensor seen inside the
  loop body has to be either a loop input or the output of a nested While op
  in the body that itself captured the external resource. Either way it is
  traced back to a forward-graph input position, and the forward While op's
  input at that same position is captured instead.

  Args:
    tensor: the external resource Tensor to be captured.

  Returns:
    Tensor in this graph.
  """
  assert tensor.dtype == dtypes.resource

  fwd_graph = self._forward_graph
  input_names = [inp.name for inp in fwd_graph.inputs]
  node_defs_by_name = {
      fwd_op.name: fwd_op.node_def for fwd_op in fwd_graph.get_operations()
  }
  position = util.resource_input_index(tensor.name, input_names,
                                       node_defs_by_name, fwd_graph._functions)

  placeholder = fwd_graph.inputs[position]
  outer_tensor = fwd_graph._while.inputs[position]
  assert placeholder.dtype == dtypes.resource
  assert outer_tensor.dtype == dtypes.resource
  # Loop invariance: the body must hand the very same placeholder back out at
  # the same position. The message is only formatted if the check fails.
  assert placeholder is fwd_graph.outputs[position], (
      "Resource tensors must be loop invariants %s." % outer_tensor)

  cache_key = ops.tensor_id(tensor)
  captured = self.capture(outer_tensor, whitelisted=True)
  self._indirect_captures[cache_key] = captured
  return captured
def _resource_capture_helper(self, tensor):
  """Returns the captured resource tensor.

  Resource-type tensors are not accumulated. If a resource tensor exists in
  the loop body it must either be a loop input or an output of a nested While
  op inside the loop body which had captured the external resource.

  Args:
    tensor: the external resource Tensor to be captured.

  Returns:
    Tensor in this graph.
  """
  assert tensor.dtype == dtypes.resource

  # Trace the resource back to the forward-graph input it originates from.
  index = util.resource_input_index(
      tensor.name, [t.name for t in self._forward_graph.inputs],
      {op.name: op.node_def for op in self._forward_graph.get_operations()},
      self._forward_graph._functions)

  input_placeholder = self._forward_graph.inputs[index]
  # The forward While op's input at the same position is the outer tensor.
  tensor_in_outer_graph = self._forward_graph._while.inputs[index]
  assert input_placeholder.dtype == dtypes.resource
  assert tensor_in_outer_graph.dtype == dtypes.resource
  # This must be a loop invariant: the body returns the input placeholder
  # unchanged at the same position. Compare by identity: with tensor equality
  # enabled, `==` on Tensors yields an elementwise-equality Tensor, which
  # cannot serve as the truth value of an `assert`. (In legacy graph mode `==`
  # was identity anyway, so this does not change that behavior.)
  assert input_placeholder is self._forward_graph.outputs[index], (
      "Resource tensors must be loop invariants %s." % tensor_in_outer_graph)

  # NOTE(review): the cache is keyed by the Tensor object itself, which relies
  # on Tensor being hashable; any other reader of `_indirect_captures` in this
  # class must use the same key, so the key is left unchanged here.
  self._indirect_captures[tensor] = self.capture(
      tensor_in_outer_graph, whitelisted=True)
  return self._indirect_captures[tensor]
def _capture_helper(self, tensor, name):
  """Captures `tensor` into this cond gradient graph.

  Tensors that are not intermediates of the forward graph (they belong to a
  different graph, or are forward-graph inputs/outputs) are captured through
  the base class directly. Forward intermediates are handled indirectly:

  * In an XLA context the raw intermediate is captured directly and, the first
    time, appended to `self.xla_intermediates`.
  * A resource intermediate is traced back (via `util.resource_input_index`)
    to the forward-graph input it comes from, and that input is captured.
  * Any other intermediate is wrapped in an optional in the forward graph
    (reusing an existing `OptionalFromValue` wrapping that is already a
    forward output); the optional is captured and unwrapped here with
    `optional_get_value`.

  Indirect results are cached in `self._indirect_captures`, keyed by
  `ops.tensor_id(tensor)`.

  Tensors are compared by identity (`is`) throughout, as the first check
  below already does, rather than with `in`/`==`. With TF2 tensor equality
  enabled, `Tensor.__eq__` returns a Tensor instead of a Python bool.

  Args:
    tensor: the Tensor to capture.
    name: name forwarded to the base class `_capture_helper`.

  Returns:
    The Tensor in this graph that stands in for `tensor`.
  """
  if (tensor.graph is not self._forward_graph or
      any(tensor is t for t in self._forward_graph.inputs) or
      any(tensor is t for t in self._forward_graph.outputs)):
    return super(_CondGradFuncGraph, self)._capture_helper(tensor, name)

  if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
    # XLA does not yet support optionals, so capture intermediates directly.
    # TODO(skyewm,jpienaar): can XLA support optionals?
    # Identity test instead of `tensor not in ...`, which would fall back to
    # `Tensor.__eq__` for every non-identical element.
    if all(tensor is not capture for capture in self.external_captures):
      self.xla_intermediates.append(tensor)
      self.op_needs_rewrite = True
    return super(_CondGradFuncGraph, self)._capture_helper(tensor, name)

  tensor_id = ops.tensor_id(tensor)
  captured_tensor = self._indirect_captures.get(tensor_id)
  if captured_tensor is not None:
    return captured_tensor

  # 'tensor' is an uncaptured intermediate in the forward graph.
  # If it is not a resource, we wrap it in an optional in the forward graph
  # and capture the optional normally. We then unwrap the captured optional
  # value in the gradient graph to get the raw intermediate value.
  # If it is a resource, we trace the resource up to the input in the forward
  # graph and capture that.
  if tensor.dtype == dtypes.resource:
    # Index of the forward graph input corresponding to the resource tensor.
    index = util.resource_input_index(
        tensor.name, [t.name for t in self._forward_graph.inputs],
        {op.name: op.node_def for op in self._forward_graph.get_operations()},
        self._forward_graph._functions)
    # This gets mapped to the corresponding If op input in
    # `_resolve_grad_inputs`.
    captured_tensor = super(_CondGradFuncGraph, self)._capture_helper(
        self._forward_graph.inputs[index], name)
  else:
    if tensor_id not in self._wrapped_intermediates:
      # If the gradient has already been computed for this If op, 'tensor' may
      # already be wrapped.
      for consumer in tensor.consumers():
        if (consumer.type == "OptionalFromValue" and
            any(consumer.outputs[0] is output
                for output in self._forward_graph.outputs)):
          optional = consumer.outputs[0]
          break
      else:
        # 'tensor' hasn't been wrapped, do it now.
        with self._forward_graph.as_default():
          optional = gen_dataset_ops.optional_from_value([tensor])
        self.op_needs_rewrite = True
      self._wrapped_intermediates[tensor_id] = optional

    optional = self._wrapped_intermediates[tensor_id]
    captured_optional = super(_CondGradFuncGraph,
                              self)._capture_helper(optional, name)
    captured_tensor = gen_dataset_ops.optional_get_value(
        captured_optional, [tensor.dtype], [tensor.shape])[0]

  self._indirect_captures[tensor_id] = captured_tensor
  return captured_tensor
def _capture_helper(self, tensor, name):
  """Captures `tensor` into this cond gradient graph.

  Tensors that are not intermediates of the forward graph (they belong to a
  different graph, or are forward-graph inputs/outputs) are captured through
  the base class directly. Forward intermediates are handled indirectly:

  * In an XLA context the raw intermediate is captured directly and, the first
    time, appended to `self.xla_intermediates`.
  * A resource intermediate is traced back (via `util.resource_input_index`)
    to the forward-graph input it comes from, and that input is captured.
  * Any other intermediate is wrapped in an optional in the forward graph
    (reusing an existing `OptionalFromValue` wrapping that is already a
    forward output); the optional is captured and unwrapped here with
    `optional_get_value`.

  Tensors are compared by identity (`is`) and the two caches in this method
  are keyed by `ops.tensor_id(tensor)` instead of the Tensor itself. With TF2
  tensor equality enabled, `in`/`==` on Tensors builds equality ops rather
  than returning a bool, and Tensors are not hashable.

  Args:
    tensor: the Tensor to capture.
    name: name forwarded to the base class `_capture_helper`.

  Returns:
    The Tensor in this graph that stands in for `tensor`.
  """
  if (tensor.graph is not self._forward_graph or
      any(tensor is t for t in self._forward_graph.inputs) or
      any(tensor is t for t in self._forward_graph.outputs)):
    return super(_CondGradFuncGraph, self)._capture_helper(tensor, name)

  if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
    # XLA does not yet support optionals, so capture intermediates directly.
    # TODO(skyewm,jpienaar): can XLA support optionals?
    # NOTE(review): assumes iterating `self.captures` yields the captured
    # external tensors (true if it is a mapping keyed by them, which is what
    # the previous `tensor not in self.captures` test checked) — confirm.
    if all(tensor is not capture for capture in self.captures):
      self.xla_intermediates.append(tensor)
      self.if_op_needs_rewrite = True
    return super(_CondGradFuncGraph, self)._capture_helper(tensor, name)

  tensor_id = ops.tensor_id(tensor)
  captured_tensor = self._indirect_captures.get(tensor_id)
  if captured_tensor is not None:
    return captured_tensor

  # 'tensor' is an uncaptured intermediate in the forward graph.
  # If it is not a resource, we wrap it in an optional in the forward graph
  # and capture the optional normally. We then unwrap the captured optional
  # value in the gradient graph to get the raw intermediate value.
  # If it is a resource, we trace the resource up to the input in the forward
  # graph and capture that.
  if tensor.dtype == dtypes.resource:
    # Index of the forward graph input corresponding to the resource tensor.
    index = util.resource_input_index(
        tensor.name, [t.name for t in self._forward_graph.inputs],
        {op.name: op.node_def for op in self._forward_graph.get_operations()},
        self._forward_graph._functions)
    # This gets mapped to the corresponding If op input in
    # `_resolve_grad_inputs`.
    captured_tensor = super(_CondGradFuncGraph, self)._capture_helper(
        self._forward_graph.inputs[index], name)
  else:
    if tensor_id not in self._wrapped_intermediates:
      # If the gradient has already been computed for this If op, 'tensor' may
      # already be wrapped.
      for consumer in tensor.consumers():
        if (consumer.type == "OptionalFromValue" and
            any(consumer.outputs[0] is output
                for output in self._forward_graph.outputs)):
          optional = consumer.outputs[0]
          break
      else:
        # 'tensor' hasn't been wrapped, do it now.
        with self._forward_graph.as_default():
          optional = gen_dataset_ops.optional_from_value([tensor])
        self.if_op_needs_rewrite = True
      self._wrapped_intermediates[tensor_id] = optional

    optional = self._wrapped_intermediates[tensor_id]
    captured_optional = super(_CondGradFuncGraph,
                              self)._capture_helper(optional, name)
    captured_tensor = gen_dataset_ops.optional_get_value(
        captured_optional, [tensor.dtype], [tensor.shape])[0]

  self._indirect_captures[tensor_id] = captured_tensor
  return captured_tensor