def _eager_safe_variable_handle(shape, dtype, shared_name, name, graph_mode):
  """Creates a variable handle with information to do shape inference.

  In graph mode the raw handle op output is returned directly. In eager mode
  the handle is additionally annotated with `_handle_data` — harvested from a
  temporary graph copy of the same handle op — because shape inference does
  not run eagerly.

  Args:
    shape: Shape passed to `var_handle_op` (shape of the pointed-to variable).
    dtype: Dtype passed to `var_handle_op`.
    shared_name: Shared name passed to `var_handle_op`.
    name: Op name for the created handle op.
    graph_mode: True when building a graph; False under eager execution.

  Returns:
    A resource handle tensor produced by `resource_variable_ops.var_handle_op`.
  """
  container = ops.get_default_graph()._container  # pylint: disable=protected-access
  if container is None:
    container = ""
  handle = resource_variable_ops.var_handle_op(shape=shape, dtype=dtype,
                                               shared_name=shared_name,
                                               name=name,
                                               container=container)
  if graph_mode:
    return handle

  # Build the identical handle op inside a throwaway graph so that graph-mode
  # shape inference computes handle data we can copy onto the eager handle.
  with context.graph_mode(), ops.Graph().as_default() as graph:
    h = resource_variable_ops.var_handle_op(shape=shape, dtype=dtype,
                                            shared_name=shared_name,
                                            name=name,
                                            container=container)

    # Tensor._handle_data contains information for the shape-inference code to
    # know the shape and dtype of the variable pointed to by a handle. Since
    # shape inference doesn't run in eager mode we copy this data here for when
    # the handle is captured by an eager mode function.
    # pylint: disable=protected-access
    handle._handle_data = resource_variable_ops.get_resource_handle_data(h)
    # pylint: enable=protected-access
  # Clean up op->graph->op reference cycles.
  ops.dismantle_graph(graph)
  return handle
def _eager_safe_variable_handle(shape, dtype, shared_name, name, graph_mode):
  """Creates a variable handle, attaching handle data when executing eagerly.

  Graph mode returns the raw handle op output unchanged. Eager mode also
  copies `_handle_data` (the shape/dtype of the variable the handle points
  to) onto the handle, computed via a temporary graph, since shape inference
  does not run under eager execution.

  Args:
    shape: Shape forwarded to `var_handle_op`.
    dtype: Dtype forwarded to `var_handle_op`.
    shared_name: Shared name forwarded to `var_handle_op`.
    name: Name for the handle op.
    graph_mode: Whether a graph is currently being built.

  Returns:
    A resource handle tensor.
  """
  # pylint: disable=protected-access
  container = ops.get_default_graph()._container
  # pylint: enable=protected-access
  if container is None:
    container = ""
  handle = resource_variable_ops.var_handle_op(
      shape=shape, dtype=dtype, shared_name=shared_name, name=name,
      container=container)
  if graph_mode:
    return handle

  # Eager path: recreate the same handle op inside a scratch graph so that
  # graph-mode shape inference produces the handle data, then copy that data
  # onto the eager handle for later capture by eager-mode functions.
  with context.graph_mode(), ops.Graph().as_default() as scratch_graph:
    shadow_handle = resource_variable_ops.var_handle_op(
        shape=shape, dtype=dtype, shared_name=shared_name, name=name,
        container=container)
    handle._handle_data = resource_variable_ops.get_resource_handle_data(  # pylint: disable=protected-access
        shadow_handle)
  # Break the op->graph->op reference cycles the scratch graph created.
  ops.dismantle_graph(scratch_graph)
  return handle
def copy_handle_data(source_t, target_t):
  """Copies HandleData for variant and resource type tensors if available.

  The CppShapeInferenceResult::HandleData proto carries the shapes and dtypes
  of the element tensors behind resource/variant tensors. It must be copied
  across function boundaries (capturing a placeholder, returning a function
  tensor as output); otherwise element tensors end up with unknown shapes —
  e.g. elements popped from a captured TensorList would lose their shape.

  Args:
    source_t: The tensor to copy HandleData from.
    target_t: The tensor to copy HandleData to.
  """
  # Only resource/variant tensors carry handle data.
  if target_t.dtype != dtypes.resource and target_t.dtype != dtypes.variant:
    return
  if isinstance(source_t, ops.EagerTensor):
    handle_data = source_t._handle_data  # pylint: disable=protected-access
  else:
    handle_data = resource_variable_ops.get_resource_handle_data(source_t)
  if (handle_data is None or not handle_data.is_set
      or not handle_data.shape_and_type):
    return
  # pylint: disable=protected-access
  if isinstance(target_t, ops.EagerTensor):
    # Eager tensors hold their handle data directly on the Python object.
    target_t._handle_data = handle_data
    return
  pywrap_tf_session.SetHandleShapeAndType(target_t.graph._c_graph,
                                          target_t._as_tf_output(),
                                          handle_data.SerializeToString())
  # pylint: enable=protected-access
def _copy_source(s, graph, op_map, handle_captures, inverse_captures,
                 base_graph):
  """Create a source in a graph based on a Tensor from a different graph.

  Produces a placeholder analog of `s` in `graph`:
  1) A captured Tensor/Variable is re-captured directly when
     `handle_captures` is set.
  2) A PlaceholderWithDefault whose default is constant keeps that default.
  3) Resource handle metadata on `s` is copied to the new placeholder when
     present.

  Args:
    s: The source of interest.
    graph: The destination graph.
    op_map: A dict mapping ops and tensors in the old graph to the new one.
    handle_captures: A boolean indicating whether to re-capture s in the new
      graph or simply create a vanilla placeholder.
    inverse_captures: A dict mapping s back to the Tensor or Variable that it
      captures.
    base_graph: The graph being copied from.

  Raises:
    AssertionError: If the constant default of a PlaceholderWithDefault could
      not be copied because it has inputs.
  """
  if handle_captures and s in inverse_captures:
    # Re-capture the original external value instead of making a new input.
    new_placeholder = graph.capture(inverse_captures[s], name=s.op.name)
  elif s.op.type == "PlaceholderWithDefault" and _constant_inputs(s):
    # Carry the constant default value over into the destination graph.
    default_value = s.op.inputs[0]
    missing_inputs, missing_control_inputs = _copy_non_source(
        op=default_value.op, graph=graph, op_map=op_map,
        base_graph=base_graph)
    if missing_inputs or missing_control_inputs:
      raise AssertionError(
          "Could not copy source node {} because it has inputs."
          .format(default_value))
    with ops.device(s.op.device):
      new_placeholder = array_ops.placeholder_with_default(
          input=op_map[default_value], shape=s.shape, name=s.op.name)
  else:
    with ops.device(s.op.device):
      new_placeholder = array_ops.placeholder(
          dtype=s.dtype, shape=s.shape, name=s.op.name)
    handle_data = resource_variable_ops.get_resource_handle_data(s)
    if handle_data.shape_and_type:
      # Preserve resource-variable shape/dtype metadata on the placeholder.
      resource_variable_ops._set_handle_shapes_and_types(  # pylint: disable=protected-access
          new_placeholder, handle_data, graph_mode=True)
  op_map[s] = new_placeholder
  # Map the producing op as well so control-dependency edges on the source op
  # resolve to the new placeholder's op.
  op_map[s.op] = new_placeholder.op
def _copy_source(s, graph, op_map, handle_captures, inverse_captures):
  """Create a source in a graph based on a Tensor from a different graph.

  This function creates a placeholder analog of `s` in a graph with the
  following behavior:

  1) If s is a captured Tensor or Variable and handle_captures is set to True,
     simply capture it in the new graph as well.
  2) If s is a PlaceholderWithDefault whose default is a constant, preserve
     said default in the new graph.
  3) When applicable, copy resource variable metadata from `s` to the newly
     created placeholder.

  Args:
    s: The source of interest.
    graph: The destination graph.
    op_map: A dict mapping ops and tensors in the old graph to the new one.
    handle_captures: A boolean indicating whether to re-capture s in the new
      graph or simply create a vanilla placeholder.
    inverse_captures: A dict mapping s back to the Tensor or Variable that it
      captures.

  Raises:
    AssertionError: If the default of a PlaceholderWithDefault could not be
      copied because it has inputs.
  """
  if handle_captures and s in inverse_captures:
    # Case 1: re-capture the original external value in the new graph.
    copied_placeholder = graph.capture(inverse_captures[s], name=s.op.name)
  elif s.op.type == "PlaceholderWithDefault" and _constant_inputs(s):
    # Case 2: copy the default value to the graph.
    default_value = s.op.inputs[0]
    unavailable_inputs, unavailable_control_inputs = _copy_non_source(
        op=default_value.op, graph=graph, op_map=op_map)
    if unavailable_inputs or unavailable_control_inputs:
      raise AssertionError(
          "Could not copy source node {} because it has inputs."
          .format(default_value))
    with ops.device(s.op.device):
      copied_placeholder = array_ops.placeholder_with_default(
          input=op_map[default_value], shape=s.shape, name=s.op.name)
  else:
    # Case 3: vanilla placeholder, carrying over resource handle metadata
    # when the source has any.
    with ops.device(s.op.device):
      copied_placeholder = array_ops.placeholder(
          dtype=s.dtype, shape=s.shape, name=s.op.name)
    base_handle = resource_variable_ops.get_resource_handle_data(s)
    if base_handle.shape_and_type:
      resource_variable_ops._set_handle_shapes_and_types(  # pylint: disable=protected-access
          copied_placeholder, base_handle, graph_mode=True)
  op_map[s] = copied_placeholder
  # Add an entry for the op of the source tensor so that if there are any nodes
  # depending on that op via control dependencies it can work correctly.
  op_map[s.op] = copied_placeholder.op
def capture_value(tensor_map, value, dtype, name):
  """Capture a value from outside the function, to pass in as an extra arg.

  Args:
    tensor_map: Dict keyed by `ops.tensor_id(value)` holding
      `(value, placeholder)` pairs for values already captured; mutated to
      record new captures.
    value: The external tensor to capture.
    dtype: Dtype for the created placeholder; falls back to `value.dtype`
      when falsy.
    name: Name for the created placeholder.

  Returns:
    The placeholder tensor standing in for `value` inside the function.
  """
  captured_value = tensor_map.get(ops.tensor_id(value), None)
  if captured_value is None:
    captured_value = graph_placeholder(dtype=dtype or value.dtype,
                                       shape=value.shape, name=name)
    if captured_value.dtype == dtypes_module.resource:
      # Resource handles carry "handle data" (shape/dtype of the variable the
      # handle points to); copy it onto the placeholder so shape inference
      # inside the function still works.
      if ops._USE_C_SHAPES:  # pylint: disable=protected-access
        if isinstance(value, ops.EagerTensor):
          handle_data = value._handle_data  # pylint: disable=protected-access
        else:
          handle_data = resource_variable_ops.get_resource_handle_data(value)
      else:
        handle_data = value._handle_data  # pylint: disable=protected-access
      if handle_data is not None and handle_data.is_set:
        # pylint: disable=protected-access
        if ops._USE_C_SHAPES:
          pywrap_tensorflow.TFE_SetResourceHandleShapeAndType(
              captured_value.graph._c_graph, captured_value._as_tf_output(),
              handle_data.SerializeToString())
        else:
          captured_value._handle_data = handle_data
        # pylint: enable=protected-access
        # Ensure that shapes and dtypes are propagated.
        shapes, types = zip(*[(pair.shape, pair.dtype)
                              for pair in handle_data.shape_and_type])
        # Rank -1 encodes "unknown rank"; None encodes an unknown shape.
        ranks = [len(s.dim) if not s.unknown_rank else -1 for s in shapes]
        shapes = [[d.size for d in s.dim] if not s.unknown_rank else None
                  for s in shapes]
        pywrap_tensorflow.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
            captured_value._op._graph._c_graph,  # pylint: disable=protected-access
            captured_value._as_tf_output(),  # pylint: disable=protected-access
            shapes, ranks, types)
    tensor_map[ops.tensor_id(value)] = (value, captured_value)
  else:
    # Cache hit: reuse the placeholder from the stored (value, placeholder)
    # pair.
    captured_value = captured_value[1]
  # Record the capture on the tape with an identity backward function.
  tape.record_operation("captured_value", [captured_value], [value],
                        lambda x: [x])
  return captured_value
def copy_handle_data(source_t, target_t):
  """Copies HandleData for variant and resource type tensors if available.

  The CppShapeInferenceResult::HandleData proto records the shapes and dtypes
  of the element tensors behind resource/variant tensors. It has to travel
  across function boundaries (capturing a placeholder, returning a function
  tensor as output), or the element tensors come back with unknown shapes —
  e.g. popping from a captured TensorList would yield unknown-shape elements.

  Args:
    source_t: The tensor to copy HandleData from.
    target_t: The tensor to copy HandleData to.
  """
  # Only resource/variant tensors carry handle data.
  if target_t.dtype != dtypes.resource and target_t.dtype != dtypes.variant:
    return
  if isinstance(source_t, ops.EagerTensor):
    handle_data = source_t._handle_data  # pylint: disable=protected-access
  else:
    handle_data = resource_variable_ops.get_resource_handle_data(source_t)
  if (handle_data is None or not handle_data.is_set
      or not handle_data.shape_and_type):
    return
  # pylint: disable=protected-access
  pywrap_tf_session.SetHandleShapeAndType(target_t.graph._c_graph,
                                          target_t._as_tf_output(),
                                          handle_data.SerializeToString())
  # pylint: enable=protected-access
  # Ensure that shapes and dtypes are propagated.
  shape_protos, types = zip(*[(pair.shape, pair.dtype)
                              for pair in handle_data.shape_and_type])
  # Rank -1 encodes "unknown rank"; None encodes an unknown shape.
  ranks = [-1 if proto.unknown_rank else len(proto.dim)
           for proto in shape_protos]
  shapes = [None if proto.unknown_rank else [d.size for d in proto.dim]
            for proto in shape_protos]
  pywrap_tf_session.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
      target_t._op._graph._c_graph,  # pylint: disable=protected-access
      target_t._as_tf_output(),  # pylint: disable=protected-access
      shapes, ranks, types)
def capture_value(tensor_map, value, dtype, name):
  """Capture a value from outside the function, to pass in as an extra arg.

  Args:
    tensor_map: Dict keyed by `ops.tensor_id(value)` of
      `(value, placeholder)` pairs already captured; mutated for new captures.
    value: The external tensor to capture.
    dtype: Dtype for the placeholder; `value.dtype` is used when falsy.
    name: Name for the created placeholder.

  Returns:
    The placeholder that stands in for `value` inside the function.
  """
  cache_key = ops.tensor_id(value)
  cached = tensor_map.get(cache_key, None)
  if cached is not None:
    # Already captured once; reuse the stored placeholder.
    captured_value = cached[1]
  else:
    captured_value = graph_placeholder(
        dtype=dtype or value.dtype, shape=value.shape, name=name)
    if captured_value.dtype == dtypes_module.resource:
      # Copy the handle data (shape/dtype of the pointed-to variable) onto
      # the placeholder so shape inference inside the function still works.
      # pylint: disable=protected-access
      if ops._USE_C_SHAPES:
        if isinstance(value, ops.EagerTensor):
          handle_data = value._handle_data
        else:
          handle_data = resource_variable_ops.get_resource_handle_data(value)
      else:
        handle_data = value._handle_data
      if handle_data is not None and handle_data.is_set:
        if ops._USE_C_SHAPES:
          pywrap_tensorflow.SetResourceHandleShapeAndType(
              captured_value.graph._c_graph, captured_value._as_tf_output(),
              handle_data.SerializeToString())
        else:
          captured_value._handle_data = handle_data
        # Ensure that shapes and dtypes are propagated.
        shape_protos, types = zip(*[(pair.shape, pair.dtype)
                                    for pair in handle_data.shape_and_type])
        # Rank -1 encodes "unknown rank"; None encodes an unknown shape.
        ranks = [-1 if proto.unknown_rank else len(proto.dim)
                 for proto in shape_protos]
        shapes = [None if proto.unknown_rank else
                  [d.size for d in proto.dim] for proto in shape_protos]
        pywrap_tensorflow.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
            captured_value._op._graph._c_graph,
            captured_value._as_tf_output(), shapes, ranks, types)
      # pylint: enable=protected-access
    tensor_map[cache_key] = (value, captured_value)
  # Record the capture on the tape with an identity backward function.
  tape.record_operation("captured_value", [captured_value], [value],
                        lambda x: [x])
  return captured_value
def copy_handle_data(source_t, target_t):
  """Copies HandleData for variant and resource type tensors if available.

  The CppShapeInferenceResult::HandleData proto contains information about the
  shapes and types of the element tensors of resource/variant type tensors.

  We need to copy this across function boundaries, i.e., when capturing a
  placeholder or when returning a function tensor as output. If we don't do
  this the element tensors will have unknown shapes, e.g., if a TensorList
  variant tensor is captured as a placeholder, elements popped from that list
  would have unknown shape.

  Args:
    source_t: The tensor to copy HandleData from.
    target_t: The tensor to copy HandleData to.
  """
  if (target_t.dtype == dtypes.resource or
      target_t.dtype == dtypes.variant):
    if isinstance(source_t, ops.EagerTensor):
      handle_data = source_t._handle_data  # pylint: disable=protected-access
    else:
      handle_data = resource_variable_ops.get_resource_handle_data(source_t)
    if (handle_data is not None
        and handle_data.is_set
        and handle_data.shape_and_type):
      # pylint: disable=protected-access
      pywrap_tensorflow.SetHandleShapeAndType(target_t.graph._c_graph,
                                              target_t._as_tf_output(),
                                              handle_data.SerializeToString())
      # pylint: enable=protected-access
      # Ensure that shapes and dtypes are propagated.
      shapes, types = zip(*[(pair.shape, pair.dtype)
                            for pair in handle_data.shape_and_type])
      # Rank -1 encodes "unknown rank"; None encodes an unknown shape.
      ranks = [len(s.dim) if not s.unknown_rank else -1 for s in shapes]
      shapes = [[d.size for d in s.dim] if not s.unknown_rank else None
                for s in shapes]
      pywrap_tensorflow.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
          target_t._op._graph._c_graph,  # pylint: disable=protected-access
          target_t._as_tf_output(),  # pylint: disable=protected-access
          shapes, ranks, types)