def _capture_tensor_as_extra_input(self, tensor, name=None):
  """Replaces a captured tensor with a new placeholder input.

  The captured tensor is recorded in `self.extra_inputs`, and a placeholder
  with the same dtype and static shape is created to stand in for it. Any
  handle shape/type data found on the captured tensor is copied onto the
  placeholder. The placeholder is then recorded in `self.inputs`,
  `self._captured` (keyed by the captured tensor) and `self.extra_args`.

  Args:
    tensor: The tensor being captured.
    name: Optional name passed through to the placeholder op.

  Returns:
    The placeholder, wrapped in `array_ops.guarantee_const` when
    `_is_guaranteed_const(tensor)` is true, otherwise the placeholder itself.
  """
  self.extra_inputs.append(tensor)

  # Create the substitute input outside of whatever control flow context is
  # currently active, so the new placeholder is hoisted out of it.
  with ops.control_dependencies(None):
    placeholder = array_ops.placeholder(
        tensor.dtype, shape=tensor.get_shape(), name=name)

  # pylint: disable=protected-access
  if not isinstance(tensor, ops.EagerTensor):
    # Graph tensors: read the handle data through the C API.
    serialized_handle = c_api.GetHandleShapeAndType(
        tensor.graph._c_graph, tensor._as_tf_output())
  else:
    # Eager tensors carry the handle data directly; serialize it if present.
    serialized_handle = tensor._handle_data
    if serialized_handle:
      serialized_handle = serialized_handle.SerializeToString()

  if serialized_handle:
    c_api.SetHandleShapeAndType(
        placeholder.graph._c_graph,
        placeholder._as_tf_output(),
        compat.as_bytes(serialized_handle))
  # pylint: enable=protected-access

  # Register the placeholder as a function input and remember the mapping
  # from the captured tensor to its stand-in.
  self.inputs.append(placeholder)
  self._captured[tensor] = placeholder
  self.extra_args.append(placeholder)

  if not _is_guaranteed_const(tensor):
    return placeholder

  with ops.control_dependencies(None):
    return array_ops.guarantee_const(placeholder)
def copy_handle_data(source_t, target_t):
  """Copies HandleData for variant and resource type tensors if available.

  The CppShapeInferenceResult::HandleData proto describes the shapes and
  types of the element tensors held by a resource/variant tensor. It has to
  be carried across function boundaries (for example when a tensor is
  captured as a placeholder, or when a function tensor is returned as an
  output); otherwise the element tensors end up with unknown shapes. For
  instance, elements popped from a TensorList variant that was captured as a
  placeholder would have unknown shape.

  Nothing is done when `target_t` is neither a resource nor a variant
  tensor, or when `source_t` has no set, non-empty handle data.

  Args:
    source_t: The tensor to copy HandleData from.
    target_t: The tensor to copy HandleData to.
  """
  is_handle_dtype = (target_t.dtype == dtypes.resource or
                     target_t.dtype == dtypes.variant)
  if not is_handle_dtype:
    return

  if isinstance(source_t, ops.EagerTensor):
    source_handle = source_t._handle_data  # pylint: disable=protected-access
  else:
    source_handle = resource_variable_ops.get_resource_handle_data(source_t)

  # Bail out unless there is set, non-empty handle data to copy.
  if (source_handle is None or
      not source_handle.is_set or
      not source_handle.shape_and_type):
    return

  # pylint: disable=protected-access
  pywrap_tensorflow.SetHandleShapeAndType(
      target_t.graph._c_graph,
      target_t._as_tf_output(),
      source_handle.SerializeToString())
  # pylint: enable=protected-access

  # Ensure that shapes and dtypes are propagated: collect, per element, its
  # dims (None when the rank is unknown), its rank (-1 when unknown) and its
  # dtype.
  element_shapes = []
  element_ranks = []
  element_dtypes = []
  for entry in source_handle.shape_and_type:
    shape_proto = entry.shape
    element_dtypes.append(entry.dtype)
    if shape_proto.unknown_rank:
      element_ranks.append(-1)
      element_shapes.append(None)
    else:
      element_ranks.append(len(shape_proto.dim))
      element_shapes.append([dim.size for dim in shape_proto.dim])

  # pylint: disable=protected-access
  pywrap_tensorflow.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
      target_t._op._graph._c_graph,
      target_t._as_tf_output(),
      element_shapes,
      element_ranks,
      tuple(element_dtypes))
  # pylint: enable=protected-access