def _capture_helper(self, tensor, name):
  """Captures `tensor` into this gradient graph.

  Tensors that do not belong to the forward graph, and forward-graph inputs
  and outputs, are captured normally by the parent class. Intermediates of
  the forward graph are handled specially:

  * Inside an XLA context they are captured directly; ones not already among
    `self.external_captures` are appended to `self.xla_intermediates` and
    `self.op_needs_rewrite` is set.
  * Resource tensors are traced back to the forward-graph input they come
    from, and that input is captured instead.
  * Any other intermediate is wrapped in an `OptionalFromValue` in the forward
    graph (reusing an existing wrapper that is already a forward-graph
    output); the optional is captured and unwrapped with `OptionalGetValue`.

  Non-XLA intermediate results are memoized in `self._indirect_captures`,
  keyed by `ops.tensor_id(tensor)`.

  Every tensor comparison in this method uses `is`, matching the identity
  checks at the top. A list membership test (`x in some_list`) falls back to
  `Tensor.__eq__` for non-identical elements, and that is not guaranteed to
  be an identity comparison.

  Args:
    tensor: The tensor to capture.
    name: Forwarded to the parent class's `_capture_helper`.

  Returns:
    The tensor in this graph that stands in for `tensor`.
  """
  if (tensor.graph is not self._forward_graph or
      any(tensor is t for t in self._forward_graph.inputs) or
      any(tensor is t for t in self._forward_graph.outputs)):
    return super(_CondGradFuncGraph, self)._capture_helper(tensor, name)

  if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
    # XLA does not yet support optionals, so capture intermediates directly.
    # TODO(skyewm,jpienaar): can XLA support optionals?
    # Identity check instead of `tensor not in self.external_captures`, which
    # would call `Tensor.__eq__` on every non-identical capture.
    if all(tensor is not capture for capture in self.external_captures):
      self.xla_intermediates.append(tensor)
      self.op_needs_rewrite = True
    return super(_CondGradFuncGraph, self)._capture_helper(tensor, name)

  tensor_id = ops.tensor_id(tensor)
  captured_tensor = self._indirect_captures.get(tensor_id)
  if captured_tensor is not None:
    return captured_tensor

  # 'tensor' is an uncaptured intermediate in the forward graph.
  # If it is not a resource, we wrap it in an optional in the forward graph
  # and capture the optional normally. We then unwrap the captured optional
  # value in the gradient graph to get the raw intermediate value.
  # If it is a resource, we trace the resource up to the input in the forward
  # graph and capture that.
  if tensor.dtype == dtypes.resource:
    # Index of the forward graph input corresponding to the resource tensor.
    index = util.resource_input_index(
        tensor.name, [t.name for t in self._forward_graph.inputs],
        {op.name: op.node_def for op in self._forward_graph.get_operations()},
        self._forward_graph._functions)
    # This gets mapped to the corresponding If op input in
    # `_resolve_grad_inputs`.
    captured_tensor = super(_CondGradFuncGraph, self)._capture_helper(
        self._forward_graph.inputs[index], name)
  else:
    if tensor_id not in self._wrapped_intermediates:
      # If the gradient has already been computed for this If op, 'tensor' may
      # already be wrapped.
      for consumer in tensor.consumers():
        # Identity check, consistent with the forward-graph checks above.
        if (consumer.type == "OptionalFromValue" and
            any(consumer.outputs[0] is output
                for output in self._forward_graph.outputs)):
          optional = consumer.outputs[0]
          break
      else:
        # 'tensor' hasn't been wrapped, do it now.
        with self._forward_graph.as_default():
          optional = gen_dataset_ops.optional_from_value([tensor])
        self.op_needs_rewrite = True
      self._wrapped_intermediates[tensor_id] = optional

    optional = self._wrapped_intermediates[tensor_id]
    captured_optional = super(_CondGradFuncGraph,
                              self)._capture_helper(optional, name)
    captured_tensor = gen_dataset_ops.optional_get_value(
        captured_optional, [tensor.dtype], [tensor.shape])[0]

  self._indirect_captures[tensor_id] = captured_tensor
  return captured_tensor
def _capture_helper(self, tensor, name):
  """Captures `tensor` into this gradient graph.

  Anything that is not an intermediate of the forward graph (a tensor from
  another graph, or a forward-graph input/output) is handed to the parent
  class unchanged. Forward-graph intermediates are handled as follows:

  * Under XLA they are captured directly; ones not yet in `self.captures`
    are appended to `self.xla_intermediates` and `if_op_needs_rewrite` is
    set.
  * Resource tensors are traced back to the forward-graph input they come
    from, and that input is captured instead.
  * Every other intermediate is wrapped in an `OptionalFromValue` in the
    forward graph (an existing wrapper that is already a forward-graph output
    is reused); the optional is captured and unwrapped via
    `OptionalGetValue`.

  Non-XLA intermediate results are memoized in `self._indirect_captures`.

  Args:
    tensor: The tensor to capture.
    name: Forwarded to the parent class's `_capture_helper`.

  Returns:
    The tensor in this graph that stands in for `tensor`.
  """
  forward = self._forward_graph

  def capture_in_parent(t):
    # Delegates to the base class's capture logic.
    return super(_CondGradFuncGraph, self)._capture_helper(t, name)

  is_forward_intermediate = (tensor.graph is forward and
                             tensor not in forward.inputs and
                             tensor not in forward.outputs)
  if not is_forward_intermediate:
    return capture_in_parent(tensor)

  if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
    # Optionals are not supported under XLA yet, so the intermediate is
    # captured as-is and the op is flagged for rewriting.
    # TODO(skyewm,jpienaar): can XLA support optionals?
    newly_seen = tensor not in self.captures
    if newly_seen:
      self.xla_intermediates.append(tensor)
      self.if_op_needs_rewrite = True
    return capture_in_parent(tensor)

  memoized = self._indirect_captures.get(tensor)
  if memoized is not None:
    return memoized

  if tensor.dtype == dtypes.resource:
    # Resources are not wrapped. Find the forward-graph input the resource
    # traces back to and capture that; it is mapped to the matching If op
    # input in `_resolve_grad_inputs`.
    input_names = [t.name for t in forward.inputs]
    node_defs = {op.name: op.node_def for op in forward.get_operations()}
    input_index = util.resource_input_index(
        tensor.name, input_names, node_defs, forward._functions)
    result = capture_in_parent(forward.inputs[input_index])
  else:
    if tensor not in self._wrapped_intermediates:
      # An earlier gradient of this If op may already have wrapped `tensor`
      # in an optional that is a forward-graph output; reuse it if so.
      optional = next(
          (consumer.outputs[0]
           for consumer in tensor.consumers()
           if (consumer.type == "OptionalFromValue" and
               consumer.outputs[0] in forward.outputs)),
          None)
      if optional is None:
        # Not wrapped yet: add the wrapper to the forward graph now.
        with forward.as_default():
          optional = gen_dataset_ops.optional_from_value([tensor])
        self.if_op_needs_rewrite = True
      self._wrapped_intermediates[tensor] = optional

    captured_optional = capture_in_parent(self._wrapped_intermediates[tensor])
    result = gen_dataset_ops.optional_get_value(
        captured_optional, [tensor.dtype], [tensor.shape])[0]

  self._indirect_captures[tensor] = result
  return result
def get_value(self, name=None):
  """Returns the value wrapped by the optional `self._variant_tensor`.

  Emits an `OptionalGetValue` op whose output dtypes and shapes are taken
  from `self._value_structure`, then reassembles the op's flat outputs into
  that structure.

  Args:
    name: (Optional.) Name scope for the created op; when None the scope
      defaults to "OptionalGetValue".

  Returns:
    The value rebuilt by `self._value_structure._from_tensor_list`.
  """
  # TODO(b/110122868): Consolidate the restructuring logic with similar logic
  # in `Iterator.get_next()` and `StructuredFunctionWrapper`.
  variant = self._variant_tensor
  with ops.name_scope(name, "OptionalGetValue", [variant]) as scope:
    value_structure = self._value_structure
    # pylint: disable=protected-access
    flat_values = gen_dataset_ops.optional_get_value(
        variant,
        name=scope,
        output_types=value_structure._flat_types,
        output_shapes=value_structure._flat_shapes)
    return value_structure._from_tensor_list(flat_values)
def get_value(self, name=None):
  """Returns the value wrapped by the optional `self._variant_tensor`.

  Emits an `OptionalGetValue` op whose output dtypes and shapes are taken
  from `self._value_structure`, then reassembles the op's flat outputs into
  that structure.

  Args:
    name: (Optional.) Name scope for the created op; when None the scope
      defaults to "OptionalGetValue".

  Returns:
    The value rebuilt by `self._value_structure._from_tensor_list`.
  """
  # TODO(b/110122868): Consolidate the restructuring logic with similar logic
  # in `Iterator.get_next()` and `StructuredFunctionWrapper`.
  variant = self._variant_tensor
  with ops.name_scope(name, "OptionalGetValue", [variant]) as scope:
    value_structure = self._value_structure
    # pylint: disable=protected-access
    flat_values = gen_dataset_ops.optional_get_value(
        variant,
        name=scope,
        output_types=value_structure._flat_types,
        output_shapes=value_structure._flat_shapes)
    return value_structure._from_tensor_list(flat_values)
def get_value(self, name=None):
  """Returns the value wrapped by the optional `self._variant_tensor`.

  Emits an `OptionalGetValue` op typed and shaped by the flat tensor
  types/shapes of `self._value_structure`, and rebuilds that structure from
  the op's flat outputs with `structure.from_tensor_list`.

  Args:
    name: (Optional.) Name scope for the created op; when None the scope
      defaults to "OptionalGetValue".

  Returns:
    The value structured according to `self._value_structure`.
  """
  # TODO(b/110122868): Consolidate the restructuring logic with similar logic
  # in `Iterator.get_next()` and `StructuredFunctionWrapper`.
  variant = self._variant_tensor
  with ops.name_scope(name, "OptionalGetValue", [variant]) as scope:
    spec = self._value_structure
    flat_tensors = gen_dataset_ops.optional_get_value(
        variant,
        name=scope,
        output_types=structure.get_flat_tensor_types(spec),
        output_shapes=structure.get_flat_tensor_shapes(spec))
    return structure.from_tensor_list(spec, flat_tensors)
def get_value(self, name=None):
  """Returns the value wrapped by the optional `self._variant_tensor`.

  The `OptionalGetValue` op is colocated with the variant tensor. The
  rebuilding of the `self._element_spec` structure from the op's flat outputs
  happens outside that colocation scope.

  Args:
    name: (Optional.) Name scope for the created ops; when None the scope
      defaults to "OptionalGetValue".

  Returns:
    The value structured according to `self._element_spec`.
  """
  # TODO(b/110122868): Consolidate the restructuring logic with similar logic
  # in `Iterator.get_next()` and `StructuredFunctionWrapper`.
  variant = self._variant_tensor
  with ops.name_scope(name, "OptionalGetValue", [variant]) as scope:
    spec = self._element_spec
    with ops.colocate_with(variant):
      flat_components = gen_dataset_ops.optional_get_value(
          variant,
          name=scope,
          output_types=structure.get_flat_tensor_types(spec),
          output_shapes=structure.get_flat_tensor_shapes(spec))
    # NOTE: We do not colocate the deserialization of composite tensors
    # because not all ops are guaranteed to have non-GPU kernels.
    return structure.from_tensor_list(spec, flat_components)
def _capture_helper(self, tensor, name):
  """Captures `tensor` into this gradient graph.

  Anything that is not an intermediate of the forward graph (a tensor from
  another graph, or a forward-graph input/output) is handed to the parent
  class unchanged. Forward-graph intermediates are handled as follows:

  * If the current default graph is in an XLA context, they are captured
    directly; ones not yet in `self.captures` are appended to
    `self.xla_intermediates` and `if_op_needs_rewrite` is set.
  * Otherwise they are wrapped in an `OptionalFromValue` in the forward graph
    (an existing wrapper that is already a forward-graph output is reused);
    the optional is captured and unwrapped via `OptionalGetValue`.

  Non-XLA intermediate results are memoized in `self._indirect_captures`.

  Args:
    tensor: The tensor to capture.
    name: Forwarded to the parent class's `_capture_helper`.

  Returns:
    The tensor in this graph that stands in for `tensor`.
  """
  forward = self._forward_graph

  def capture_in_parent(t):
    # Delegates to the base class's capture logic.
    return super(_CondGradFuncGraph, self)._capture_helper(t, name)

  is_forward_intermediate = (tensor.graph is forward and
                             tensor not in forward.inputs and
                             tensor not in forward.outputs)
  if not is_forward_intermediate:
    return capture_in_parent(tensor)

  if control_flow_util.InXlaContext(ops.get_default_graph()):
    # Optionals are not supported under XLA yet, so the intermediate is
    # captured as-is and the op is flagged for rewriting.
    # TODO(skyewm,jpienaar): can XLA support optionals?
    newly_seen = tensor not in self.captures
    if newly_seen:
      self.xla_intermediates.append(tensor)
      self.if_op_needs_rewrite = True
    return capture_in_parent(tensor)

  memoized = self._indirect_captures.get(tensor)
  if memoized is not None:
    return memoized

  if tensor not in self._wrapped_intermediates:
    # An earlier gradient of this If op may already have wrapped `tensor` in
    # an optional that is a forward-graph output; reuse it if so.
    optional = next(
        (consumer.outputs[0]
         for consumer in tensor.consumers()
         if (consumer.type == "OptionalFromValue" and
             consumer.outputs[0] in forward.outputs)),
        None)
    if optional is None:
      # Not wrapped yet: add the wrapper to the forward graph now.
      with forward.as_default():
        optional = gen_dataset_ops.optional_from_value([tensor])
      self.if_op_needs_rewrite = True
    self._wrapped_intermediates[tensor] = optional

  captured_optional = capture_in_parent(self._wrapped_intermediates[tensor])
  result = gen_dataset_ops.optional_get_value(
      captured_optional, [tensor.dtype], [tensor.shape])[0]
  self._indirect_captures[tensor] = result
  return result
def _capture_helper(self, tensor, name):
  """Captures `tensor` into this gradient graph.

  Anything that is not an intermediate of the forward graph (a tensor from
  another graph, or a forward-graph input/output) is handed to the parent
  class unchanged. Forward-graph intermediates are handled as follows:

  * If the current default graph is in an XLA context, they are captured
    directly; ones not yet in `self.captures` are appended to
    `self.xla_intermediates` and `if_op_needs_rewrite` is set.
  * Otherwise they are wrapped in an `OptionalFromValue` in the forward graph
    (an existing wrapper that is already a forward-graph output is reused);
    the optional is captured and unwrapped via `OptionalGetValue`.

  Non-XLA intermediate results are memoized in `self._indirect_captures`.

  Args:
    tensor: The tensor to capture.
    name: Forwarded to the parent class's `_capture_helper`.

  Returns:
    The tensor in this graph that stands in for `tensor`.
  """
  forward = self._forward_graph

  def capture_in_parent(t):
    # Delegates to the base class's capture logic.
    return super(_CondGradFuncGraph, self)._capture_helper(t, name)

  is_forward_intermediate = (tensor.graph is forward and
                             tensor not in forward.inputs and
                             tensor not in forward.outputs)
  if not is_forward_intermediate:
    return capture_in_parent(tensor)

  if control_flow_util.InXlaContext(ops.get_default_graph()):
    # Optionals are not supported under XLA yet, so the intermediate is
    # captured as-is and the op is flagged for rewriting.
    # TODO(skyewm,jpienaar): can XLA support optionals?
    newly_seen = tensor not in self.captures
    if newly_seen:
      self.xla_intermediates.append(tensor)
      self.if_op_needs_rewrite = True
    return capture_in_parent(tensor)

  memoized = self._indirect_captures.get(tensor)
  if memoized is not None:
    return memoized

  if tensor not in self._wrapped_intermediates:
    # An earlier gradient of this If op may already have wrapped `tensor` in
    # an optional that is a forward-graph output; reuse it if so.
    optional = next(
        (consumer.outputs[0]
         for consumer in tensor.consumers()
         if (consumer.type == "OptionalFromValue" and
             consumer.outputs[0] in forward.outputs)),
        None)
    if optional is None:
      # Not wrapped yet: add the wrapper to the forward graph now.
      with forward.as_default():
        optional = gen_dataset_ops.optional_from_value([tensor])
      self.if_op_needs_rewrite = True
    self._wrapped_intermediates[tensor] = optional

  captured_optional = capture_in_parent(self._wrapped_intermediates[tensor])
  result = gen_dataset_ops.optional_get_value(
      captured_optional, [tensor.dtype], [tensor.shape])[0]
  self._indirect_captures[tensor] = result
  return result
def get_value(self, name=None):
  """Returns the value wrapped by the optional `self._variant_tensor`.

  Sparse components are requested from `OptionalGetValue` in their dense
  (serialized) form, the flat op outputs are packed into the nesting of
  `self._output_types`, and sparse components are then deserialized.

  Args:
    name: (Optional.) Name scope for the created ops; when None the scope
      defaults to "OptionalGetValue".

  Returns:
    The value produced by `sparse.deserialize_sparse_tensors`.
  """
  # TODO(b/110122868): Consolidate the restructuring logic with similar logic
  # in `Iterator.get_next()` and `StructuredFunctionWrapper`.
  types = self._output_types
  shapes = self._output_shapes
  classes = self._output_classes
  with ops.name_scope(name, "OptionalGetValue",
                      [self._variant_tensor]) as scope:
    dense_types = nest.flatten(sparse.as_dense_types(types, classes))
    dense_shapes = nest.flatten(sparse.as_dense_shapes(shapes, classes))
    flat_values = gen_dataset_ops.optional_get_value(
        self._variant_tensor,
        name=scope,
        output_types=dense_types,
        output_shapes=dense_shapes)
    nested_values = nest.pack_sequence_as(types, flat_values)
    return sparse.deserialize_sparse_tensors(
        nested_values, types, shapes, classes)
def get_value(self, name=None):
  """Returns the value wrapped by the optional `self._variant_tensor`.

  Sparse components are requested from `OptionalGetValue` in their dense
  (serialized) form, the flat op outputs are packed into the nesting of
  `self._output_types`, and sparse components are then deserialized.

  Args:
    name: (Optional.) Name scope for the created ops; when None the scope
      defaults to "OptionalGetValue".

  Returns:
    The value produced by `sparse.deserialize_sparse_tensors`.
  """
  # TODO(b/110122868): Consolidate the restructuring logic with similar logic
  # in `Iterator.get_next()` and `StructuredFunctionWrapper`.
  types = self._output_types
  shapes = self._output_shapes
  classes = self._output_classes
  with ops.name_scope(name, "OptionalGetValue",
                      [self._variant_tensor]) as scope:
    dense_types = nest.flatten(sparse.as_dense_types(types, classes))
    dense_shapes = nest.flatten(sparse.as_dense_shapes(shapes, classes))
    flat_values = gen_dataset_ops.optional_get_value(
        self._variant_tensor,
        name=scope,
        output_types=dense_types,
        output_shapes=dense_shapes)
    nested_values = nest.pack_sequence_as(types, flat_values)
    return sparse.deserialize_sparse_tensors(
        nested_values, types, shapes, classes)
def _capture_helper(self, tensor, name):
  """Captures `tensor` into this gradient graph.

  Tensors outside the forward graph, and forward-graph inputs/outputs, are
  captured by the parent class. For a forward-graph intermediate, this
  method locates its `OptionalFromValue` wrapper among the forward graph's
  outputs, captures that optional, and unwraps it with `OptionalGetValue`.

  Args:
    tensor: The tensor to capture.
    name: Forwarded to the parent class's `_capture_helper`.

  Returns:
    The tensor in this graph that stands in for `tensor`.

  Raises:
    ValueError: If no `OptionalFromValue` consumer of `tensor` is a forward
      graph output.
  """
  forward = self._forward_graph

  def capture_in_parent(t):
    # Delegates to the base class's capture logic.
    return super(_CondGradFuncGraph, self)._capture_helper(t, name)

  is_forward_intermediate = (tensor.graph is forward and
                             tensor not in forward.inputs and
                             tensor not in forward.outputs)
  if not is_forward_intermediate:
    return capture_in_parent(tensor)

  # The intermediate is expected to have an OptionalFromValue consumer whose
  # output is also an output of the forward graph (and hence of the If op).
  optional = next(
      (consumer.outputs[0]
       for consumer in tensor.consumers()
       if (consumer.type == "OptionalFromValue" and
           consumer.outputs[0] in forward.outputs)),
      None)
  if optional is None:
    raise ValueError(
        "Couldn't find OptionalFromValue consumer for tensor '%s'.\n"
        "This is an internal bug, please report at "
        "https://github.com/tensorflow/tensorflow/issues." % tensor.name)

  captured_optional = capture_in_parent(optional)
  return gen_dataset_ops.optional_get_value(
      captured_optional, [tensor.dtype], [tensor.shape])[0]
def _capture_helper(self, tensor, name):
  """Captures `tensor` into this gradient graph.

  Tensors outside the forward graph, and forward-graph inputs/outputs, are
  captured by the parent class. For a forward-graph intermediate, this
  method locates its `OptionalFromValue` wrapper among the forward graph's
  outputs, captures that optional, and unwraps it with `OptionalGetValue`.

  Args:
    tensor: The tensor to capture.
    name: Forwarded to the parent class's `_capture_helper`.

  Returns:
    The tensor in this graph that stands in for `tensor`.

  Raises:
    ValueError: If no `OptionalFromValue` consumer of `tensor` is a forward
      graph output.
  """
  forward = self._forward_graph

  def capture_in_parent(t):
    # Delegates to the base class's capture logic.
    return super(_CondGradFuncGraph, self)._capture_helper(t, name)

  is_forward_intermediate = (tensor.graph is forward and
                             tensor not in forward.inputs and
                             tensor not in forward.outputs)
  if not is_forward_intermediate:
    return capture_in_parent(tensor)

  # The intermediate is expected to have an OptionalFromValue consumer whose
  # output is also an output of the forward graph (and hence of the If op).
  optional = next(
      (consumer.outputs[0]
       for consumer in tensor.consumers()
       if (consumer.type == "OptionalFromValue" and
           consumer.outputs[0] in forward.outputs)),
      None)
  if optional is None:
    raise ValueError(
        "Couldn't find OptionalFromValue consumer for tensor '%s'.\n"
        "This is an internal bug, please report at "
        "https://github.com/tensorflow/tensorflow/issues." % tensor.name)

  captured_optional = capture_in_parent(optional)
  return gen_dataset_ops.optional_get_value(
      captured_optional, [tensor.dtype], [tensor.shape])[0]
def _OptionalFromValueGrad(op, grad):
  """Gradient function that unwraps the incoming optional gradient.

  Presumably registered for the `OptionalFromValue` op (the registration is
  not visible here). It unwraps `grad` with `OptionalGetValue`, requesting
  one output per input of `op` with that input's dtype and shape.

  Args:
    op: The forward op whose inputs determine the output dtypes/shapes.
    grad: The gradient with respect to `op`'s output.

  Returns:
    The outputs of `gen_dataset_ops.optional_get_value`, one per input of
    `op`.
  """
  forward_inputs = op.inputs
  component_types = [t.dtype for t in forward_inputs]
  component_shapes = [t.shape for t in forward_inputs]
  return gen_dataset_ops.optional_get_value(
      grad, component_types, component_shapes)
def _OptionalFromValueGrad(op, grad):
  """Gradient function that unwraps the incoming optional gradient.

  Presumably registered for the `OptionalFromValue` op (the registration is
  not visible here). It unwraps `grad` with `OptionalGetValue`, requesting
  one output per input of `op` with that input's dtype and shape.

  Args:
    op: The forward op whose inputs determine the output dtypes/shapes.
    grad: The gradient with respect to `op`'s output.

  Returns:
    The outputs of `gen_dataset_ops.optional_get_value`, one per input of
    `op`.
  """
  forward_inputs = op.inputs
  component_types = [t.dtype for t in forward_inputs]
  component_shapes = [t.shape for t in forward_inputs]
  return gen_dataset_ops.optional_get_value(
      grad, component_types, component_shapes)