def capture_value(tensor_map, value, dtype, name):
    """Capture a value from outside the function, to pass in as an extra arg.

    On the first capture of `value` a placeholder is created and remembered in
    `tensor_map`; later captures of the same value reuse that placeholder.

    Args:
      tensor_map: Mapping keyed by `ops.tensor_id(value)` whose entries are
        `(value, placeholder)` pairs. Updated in place on first capture.
      value: The outside value to capture. Its `dtype` and `shape` are used
        for the placeholder.
      dtype: Dtype for the placeholder; `value.dtype` is used when this is
        falsy.
      name: Name passed through to `graph_placeholder`.

    Returns:
      The placeholder standing in for `value`.
    """
    captured_value = tensor_map.get(ops.tensor_id(value), None)
    if captured_value is None:
        captured_value = graph_placeholder(
            dtype=dtype or value.dtype, shape=value.shape, name=name)
        if captured_value.dtype == dtypes_module.resource:
            # pylint: disable=protected-access
            handle_data = value._handle_data
            captured_value._handle_data = handle_data
            # `shape_and_type` must be non-empty: unpacking `zip()` of an
            # empty sequence into two names raises ValueError.
            if (handle_data is not None and handle_data.is_set
                    and handle_data.shape_and_type):
                # Ensure that shapes and dtypes are propagated.
                shapes, types = zip(*[(pair.shape, pair.dtype)
                                      for pair in handle_data.shape_and_type])
                # An unknown rank is encoded as rank -1 with shape None.
                ranks = [len(s.dim) if not s.unknown_rank else -1
                         for s in shapes]
                shapes = [[d.size for d in s.dim]
                          if not s.unknown_rank else None
                          for s in shapes]
                with errors.raise_exception_on_not_ok_status() as status:
                    pywrap_tensorflow.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
                        captured_value._op._graph._c_graph,
                        captured_value._as_tf_output(),
                        shapes, ranks, types, status)
            # pylint: enable=protected-access
        tensor_map[ops.tensor_id(value)] = (value, captured_value)
    else:
        # Already captured: the stored entry is a (value, placeholder) pair.
        captured_value = captured_value[1]
    # NOTE(review): the trailing lambda is presumably the backward function
    # (passes its input through unchanged) -- confirm against
    # tape.record_operation.
    tape.record_operation("captured_value", [captured_value], [value],
                          lambda x: [x])
    return captured_value
def capture_value(tensor_map, value, dtype, name):
    """Capture a value from outside the function, to pass in as an extra arg.

    On the first capture of `value` a placeholder is created and remembered in
    `tensor_map`; later captures of the same value reuse that placeholder.
    For resource-dtype placeholders the handle data of `value` is attached to
    the placeholder so the shapes and dtypes behind the handle stay known.

    Args:
      tensor_map: Mapping keyed by `ops.tensor_id(value)` whose entries are
        `(value, placeholder)` pairs. Updated in place on first capture.
      value: The outside value to capture. Its `dtype` and `shape` are used
        for the placeholder.
      dtype: Dtype for the placeholder; `value.dtype` is used when this is
        falsy.
      name: Name passed through to `graph_placeholder`.

    Returns:
      The placeholder standing in for `value`.
    """
    captured_value = tensor_map.get(ops.tensor_id(value), None)
    if captured_value is None:
        captured_value = graph_placeholder(
            dtype=dtype or value.dtype, shape=value.shape, name=name)
        if captured_value.dtype == dtypes_module.resource:
            # pylint: disable=protected-access
            # With C shapes, only eager tensors carry `_handle_data`
            # directly; other values are looked up through
            # resource_variable_ops.
            if ops._USE_C_SHAPES and not isinstance(value, ops.EagerTensor):
                handle_data = (
                    resource_variable_ops.get_resource_handle_data(value))
            else:
                handle_data = value._handle_data
            if handle_data is not None and handle_data.is_set:
                if ops._USE_C_SHAPES:
                    pywrap_tensorflow.TFE_SetResourceHandleShapeAndType(
                        captured_value.graph._c_graph,
                        captured_value._as_tf_output(),
                        handle_data.SerializeToString())
                else:
                    captured_value._handle_data = handle_data
                # `shape_and_type` must be non-empty: unpacking `zip()` of an
                # empty sequence into two names raises ValueError.
                if handle_data.shape_and_type:
                    # Ensure that shapes and dtypes are propagated.
                    shapes, types = zip(
                        *[(pair.shape, pair.dtype)
                          for pair in handle_data.shape_and_type])
                    # An unknown rank is encoded as rank -1 with shape None.
                    ranks = [len(s.dim) if not s.unknown_rank else -1
                             for s in shapes]
                    shapes = [[d.size for d in s.dim]
                              if not s.unknown_rank else None
                              for s in shapes]
                    pywrap_tensorflow.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
                        captured_value._op._graph._c_graph,
                        captured_value._as_tf_output(),
                        shapes, ranks, types)
            # pylint: enable=protected-access
        tensor_map[ops.tensor_id(value)] = (value, captured_value)
    else:
        # Already captured: the stored entry is a (value, placeholder) pair.
        captured_value = captured_value[1]
    # NOTE(review): the trailing lambda is presumably the backward function
    # (passes its input through unchanged) -- confirm against
    # tape.record_operation.
    tape.record_operation("captured_value", [captured_value], [value],
                          lambda x: [x])
    return captured_value
def copy_handle_data(source_t, target_t):
    """Copies HandleData for variant and resource type tensors if available.

    The CppShapeInferenceResult::HandleData proto describes the shapes and
    types of the element tensors behind a resource/variant tensor. That
    information has to travel across function boundaries -- when a value is
    captured as a placeholder, or when a function tensor is returned as an
    output -- otherwise the element tensors end up with unknown shapes. For
    example, elements popped from a TensorList variant captured as a
    placeholder would have unknown shape.

    Nothing is done unless `target_t` has resource or variant dtype and the
    source carries handle data that is set and has at least one
    shape-and-type entry.

    Args:
      source_t: The tensor to copy HandleData from.
      target_t: The tensor to copy HandleData to.
    """
    wants_handle_data = (target_t.dtype == dtypes.resource or
                         target_t.dtype == dtypes.variant)
    if not wants_handle_data:
        return

    # pylint: disable=protected-access
    # Eager tensors carry their handle data directly; other tensors are
    # looked up through resource_variable_ops.
    if isinstance(source_t, ops.EagerTensor):
        source_data = source_t._handle_data
    else:
        source_data = resource_variable_ops.get_resource_handle_data(source_t)

    if source_data is None:
        return
    if not (source_data.is_set and source_data.shape_and_type):
        return

    pywrap_tensorflow.SetHandleShapeAndType(
        target_t.graph._c_graph,
        target_t._as_tf_output(),
        source_data.SerializeToString())

    # Ensure that shapes and dtypes are propagated. An unknown rank is
    # encoded as rank -1 with a shape of None.
    elem_ranks = []
    elem_dims = []
    elem_dtypes = []
    for entry in source_data.shape_and_type:
        shape_proto = entry.shape
        if shape_proto.unknown_rank:
            elem_ranks.append(-1)
            elem_dims.append(None)
        else:
            elem_ranks.append(len(shape_proto.dim))
            elem_dims.append([dim.size for dim in shape_proto.dim])
        elem_dtypes.append(entry.dtype)

    pywrap_tensorflow.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
        target_t._op._graph._c_graph,
        target_t._as_tf_output(),
        elem_dims, elem_ranks, tuple(elem_dtypes))
    # pylint: enable=protected-access