def stop():
  """Ends the active profiling session and hands back its serialized trace.

  Returns:
    The serialized `tensorflow.tpu.Trace` as a binary string. Callers may
    persist it to a file for later offline inspection in TensorBoard.

  Raises:
    ProfilerNotRunningError: when no profiling session is currently active.
  """
  global _profiler
  global _run_num
  with _profiler_lock:
    if _profiler is None:
      raise ProfilerNotRunningError(
          'Cannot stop profiling. No profiler is running.')
    eager_by_default = (
        context.default_execution_mode == context.EAGER_MODE)
    if eager_by_default:
      # Block until the eager executor is idle before serializing; presumably
      # so that still-pending ops are reflected in the trace.
      context.context().executor.wait()
    with c_api_util.tf_buffer() as trace_buf:
      pywrap_tensorflow.TFE_ProfilerSerializeToString(_profiler, trace_buf)
      # Copy the bytes out while the buffer is still alive.
      serialized_trace = pywrap_tensorflow.TF_GetBuffer(trace_buf)
    pywrap_tensorflow.TFE_DeleteProfiler(_profiler)
    _profiler = None
    _run_num += 1
  return serialized_trace
def make_function_def(name, graph, operations, inputs, outputs):
  """Builds a TF_Function from part of a graph, plus its FunctionDef proto.

  Args:
    name: the function name
    graph: the graph from which to build the function
    operations: the operations in the function body
    inputs: tensors to be used as function arguments
    outputs: tensors to be returned from the function

  Returns:
    A `(fdef, fn)` pair where `fdef` is the FunctionDef protocol buffer
    describing the function and `fn` is the wrapped TF_Function.
  """
  # pylint: disable=protected-access
  with errors.raise_exception_on_not_ok_status() as status:
    c_graph = graph._c_graph
    fn_name = compat.as_str(name)
    body_ops = [op._c_op for op in operations]
    arg_outputs = [tensor._as_tf_output() for tensor in inputs]
    ret_outputs = [tensor._as_tf_output() for tensor in outputs]
    fn = pywrap_tensorflow.TF_GraphToFunction_wrapper(
        c_graph, fn_name, False, body_ops, arg_outputs, ret_outputs,
        [], None, compat.as_str(""), status)
  # pylint: enable=protected-access

  # TODO(apassos) avoid creating a FunctionDef (especially to grab the
  # signature), but also in general it's nice not to depend on it.
  with c_api_util.tf_buffer() as fdef_buf:
    with errors.raise_exception_on_not_ok_status() as status:
      pywrap_tensorflow.TF_FunctionToFunctionDef(fn, fdef_buf, status)
    serialized_fdef = pywrap_tensorflow.TF_GetBuffer(fdef_buf)

  fdef = function_pb2.FunctionDef()
  fdef.ParseFromString(compat.as_bytes(serialized_fdef))
  return fdef, fn
def value(self):
  """Returns the string currently stored in this gauge cell (UTF-8 decoded)."""
  with c_api_util.tf_buffer() as cell_buf:
    pywrap_tensorflow.TFE_MonitoringStringGaugeCellValue(
        self._cell, cell_buf)
    raw_bytes = pywrap_tensorflow.TF_GetBuffer(cell_buf)
    current = raw_bytes.decode('utf-8')
  return current
def stop():
  """Ends the active profiling session and hands back its serialized trace.

  Returns:
    The serialized `tensorflow.tpu.Trace` as a binary string, suitable for
    writing to a file and analysing offline with TensorBoard.

  Raises:
    ProfilerNotRunningError: when no profiling session is currently active.
  """
  global _profiler
  global _run_num
  with _profiler_lock:
    if _profiler is None:
      raise ProfilerNotRunningError(
          'Cannot stop profiling. No profiler is running.')
    with c_api_util.tf_buffer() as trace_buf:
      ctx_handle = context.context()._handle  # pylint: disable=protected-access
      pywrap_tensorflow.TFE_ProfilerSerializeToString(
          ctx_handle, _profiler, trace_buf)
      trace = pywrap_tensorflow.TF_GetBuffer(trace_buf)
    # Release the native profiler and reset module state for the next run.
    pywrap_tensorflow.TFE_DeleteProfiler(_profiler)
    _profiler = None
    _run_num += 1
  return trace
def stop():
  """Stop current profiling session and return its result.

  Returns:
    A binary string of tensorflow.tpu.Trace. User can write the string
    to file for offline analysis by tensorboard.

  Raises:
    AssertionError: If there is no active profiling session.
  """
  global _profiler
  global _run_num
  # Hold the lock across the whole check -> serialize -> delete sequence.
  # Previously only the delete was locked. Two concurrent callers could then
  # both see a non-None `_profiler` and both call TFE_DeleteProfiler on the
  # same handle.
  with _profiler_lock:
    if _profiler is None:
      raise AssertionError('Cannot stop profiling. No profiler is running.')
    with c_api_util.tf_buffer() as buffer_:
      pywrap_tensorflow.TFE_ProfilerSerializeToString(
          context.context()._handle,  # pylint: disable=protected-access
          _profiler,
          buffer_)
      # Copy the serialized trace out before the buffer is freed.
      result = pywrap_tensorflow.TF_GetBuffer(buffer_)
    pywrap_tensorflow.TFE_DeleteProfiler(_profiler)
    _profiler = None
    _run_num += 1
  return result
def function_def_from_tf_function(c_func):
  """Serializes a SWIG-wrapped TF_Function* into a FunctionDef proto.

  Args:
    c_func: the wrapped TF_Function handle to convert.

  Returns:
    A `function_pb2.FunctionDef` parsed from the function's serialization.
  """
  with c_api_util.tf_buffer() as serialized_buf:
    c_api.TF_FunctionToFunctionDef(c_func, serialized_buf)
    serialized = c_api.TF_GetBuffer(serialized_buf)
  result = function_pb2.FunctionDef()
  result.ParseFromString(compat.as_bytes(serialized))
  return result
def definition(self):
  """Returns the FunctionDef proto describing this function.

  When a C-level function handle exists, the proto is regenerated from it;
  otherwise the cached `self._definition` is returned.
  """
  self._create_definition_if_needed()
  if not self._c_func:
    return self._definition
  with c_api_util.tf_buffer() as fdef_buf:
    c_api.TF_FunctionToFunctionDef(self._c_func.func, fdef_buf)
    fdef = function_pb2.FunctionDef()
    serialized = c_api.TF_GetBuffer(fdef_buf)
  fdef.ParseFromString(compat.as_bytes(serialized))
  return fdef
def get_resource_handle_data(graph_op):
  """Fetches the shape-and-type handle data recorded for a graph tensor.

  Args:
    graph_op: an object whose type is exactly `ops.Tensor` (subclasses are
      rejected by the assert below) and that belongs to a graph.

  Returns:
    A `cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData` proto.
  """
  assert ops._USE_C_SHAPES  # pylint: disable=protected-access
  assert type(graph_op) == ops.Tensor  # pylint: disable=unidiomatic-typecheck

  # pylint: disable=protected-access
  with c_api_util.tf_buffer() as handle_buf:
    pywrap_tensorflow.TFE_GetResourceHandleShapeAndType(
        graph_op.graph._c_graph, graph_op._as_tf_output(), handle_buf)
    serialized = pywrap_tensorflow.TF_GetBuffer(handle_buf)
  # pylint: enable=protected-access

  handle_data_cls = cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData
  return handle_data_cls.FromString(compat.as_bytes(serialized))
def definition(self):
  """Returns the FunctionDef proto describing this function.

  If a C function handle is present, the proto is rebuilt from it, and any
  non-OK TF_Status is raised as an exception. Otherwise the cached
  `self._definition` is returned.
  """
  self._create_definition_if_needed()
  if not self._c_func:
    return self._definition
  with c_api_util.tf_buffer() as fdef_buf:
    with errors.raise_exception_on_not_ok_status() as status:
      c_api.TF_FunctionToFunctionDef(self._c_func, fdef_buf, status)
    fdef = function_pb2.FunctionDef()
    serialized = c_api.TF_GetBuffer(fdef_buf)
  fdef.ParseFromString(compat.as_bytes(serialized))
  return fdef
def value(self):
  """Reads this sampler cell's current distribution of samples.

  Returns:
    A HistogramProto describing the distribution of samples.
  """
  with c_api_util.tf_buffer() as histogram_buf:
    pywrap_tensorflow.TFE_MonitoringSamplerCellValue(self._cell, histogram_buf)
    serialized = pywrap_tensorflow.TF_GetBuffer(histogram_buf)
  result = summary_pb2.HistogramProto()
  result.ParseFromString(compat.as_bytes(serialized))
  return result
def value(self):
  """Reads this sampler cell's current distribution of samples.

  Returns:
    A HistogramProto describing the distribution of samples.
  """
  histogram = summary_pb2.HistogramProto()
  with c_api_util.tf_buffer() as cell_buf:
    pywrap_tfe.TFE_MonitoringSamplerCellValue(self._cell, cell_buf)
    raw = pywrap_tf_session.TF_GetBuffer(cell_buf)
  histogram.ParseFromString(compat.as_bytes(raw))
  return histogram
def definition(self):
  """Returns the FunctionDef proto describing this function.

  When a C function handle exists, the proto is rebuilt from it. If the
  code is executing eagerly under `ops.init_scope()`, the function is also
  added to the eager context, and a `_DefinedFunctionDeleter` keyed by the
  signature name is stored on `self`. Without a C handle, the cached
  `self._definition` is returned.
  """
  self._create_definition_if_needed()
  if not self._c_func:
    return self._definition

  with c_api_util.tf_buffer() as fdef_buf:
    c_api.TF_FunctionToFunctionDef(self._c_func.func, fdef_buf)
    fdef = function_pb2.FunctionDef()
    serialized = c_api.TF_GetBuffer(fdef_buf)
  fdef.ParseFromString(compat.as_bytes(serialized))

  # NOTE(review): this runs on every call. Each eager-mode access re-adds the
  # function and replaces the previous deleter. Confirm that dropping the old
  # deleter does not unregister the function that was just re-added.
  with ops.init_scope():
    if context.executing_eagerly():
      context.add_function(self._c_func.func)
      self._function_deleter = _DefinedFunctionDeleter(fdef.signature.name)
  return fdef
def __init__(self, name, graph, operations, inputs, outputs, attrs):
  """Initializes an eager defined function.

  Args:
    name: str, the name for the created function.
    graph: Graph, the graph containing the operations in the function
    operations: list of Operation; the subset of operations in the graph
      which will be in the function
    inputs: the tensors in the graph to be used as inputs to the function
    outputs: the tensors in the graph which will be outputs to the function
    attrs: dict mapping names of attributes to their AttrValue values
  """
  fn = pywrap_tensorflow.TF_GraphToFunction_wrapper(
      graph._c_graph,  # pylint: disable=protected-access
      compat.as_str(name),
      False,
      [o._c_op for o in operations],  # pylint: disable=protected-access
      [t._as_tf_output() for t in inputs],  # pylint: disable=protected-access
      [t._as_tf_output() for t in outputs],  # pylint: disable=protected-access
      [],
      None,
      compat.as_str(""))
  # Hand `fn` to its owning wrapper immediately rather than at the end of
  # __init__. If setting attrs or building the FunctionDef below raises, the
  # wrapper is dropped with the frame instead of leaving a bare TF_Function
  # that nothing frees. (Assumes ScopedTFFunction deletes the function when
  # collected; TODO confirm in c_api_util.)
  scoped_fn = c_api_util.ScopedTFFunction(fn)

  # Use a dedicated loop name so the `name` argument is not shadowed.
  for attr_name, attr_value in attrs.items():
    serialized = attr_value.SerializeToString()
    pywrap_tensorflow.TF_FunctionSetAttrValueProto(
        fn, compat.as_str(attr_name), serialized)

  # TODO(apassos) avoid creating a FunctionDef (especially to grab the
  # signature), but also in general it's nice not to depend on it.
  with c_api_util.tf_buffer() as buffer_:
    pywrap_tensorflow.TF_FunctionToFunctionDef(fn, buffer_)
    proto_data = pywrap_tensorflow.TF_GetBuffer(buffer_)
  function_def = function_pb2.FunctionDef()
  function_def.ParseFromString(compat.as_bytes(proto_data))
  if context.executing_eagerly():
    _register(fn)
  self.definition = function_def
  self.name = function_def.signature.name
  self.signature = function_def.signature
  self.grad_func_name = None
  self.python_grad_func = None
  self._c_func = scoped_fn
  self._grad_func = None
def export_run_metadata(self):
  """Collects the run metadata accumulated by this context.

  The proto covers everything recorded since the most recent call to
  either enable_run_metadata or export_run_metadata.

  Returns:
    A RunMetadata protocol buffer.
  """
  with c_api_util.tf_buffer() as metadata_buf:
    with errors.raise_exception_on_not_ok_status() as status:
      pywrap_tensorflow.TFE_ContextExportRunMetadata(
          self._context_handle, metadata_buf, status)
    serialized = pywrap_tensorflow.TF_GetBuffer(metadata_buf)
  metadata = config_pb2.RunMetadata()
  metadata.ParseFromString(compat.as_bytes(serialized))
  return metadata
def export_run_metadata(self):
  """Collects the run metadata accumulated by this context.

  The proto covers everything recorded since the most recent call to
  either enable_run_metadata or export_run_metadata.

  Returns:
    A RunMetadata protocol buffer, or None when this context has no
    underlying handle (falsy `self._context_handle`).
  """
  handle = self._context_handle
  if not handle:
    return None
  with c_api_util.tf_buffer() as metadata_buf:
    pywrap_tensorflow.TFE_ContextExportRunMetadata(handle, metadata_buf)
    serialized = pywrap_tensorflow.TF_GetBuffer(metadata_buf)
  metadata = config_pb2.RunMetadata()
  metadata.ParseFromString(compat.as_bytes(serialized))
  return metadata
def monitor(service_addr, duration_ms, level=1, display_timestamp=True):
  """Sends grpc requests to profiler server to perform on-demand monitoring.

  This method will block the caller thread until it receives the monitoring
  result.

  Args:
    service_addr: Address of profiler service e.g. localhost:6009.
    duration_ms: Duration of monitoring in ms.
    level: Choose a monitoring level between 1 and 2 to monitor your job.
      Level 2 is more verbose than level 1 and shows more metrics.
    display_timestamp: Set to True to display timestamps in the monitoring
      result. Defaults to True, the value this function always passed before
      the flag was exposed.

  Returns:
    A string of monitoring output.
  """
  with c_api_util.tf_buffer() as buffer_:
    # Positional order expected by the binding: address, duration, level,
    # display_timestamp, output buffer.
    pywrap_tfe.TFE_ProfilerClientMonitor(service_addr, duration_ms, level,
                                         display_timestamp, buffer_)
    return pywrap_tf_session.TF_GetBuffer(buffer_)
def __init__(self, name, graph, operations, inputs, outputs):
  """Wraps a subset of a graph as an eager-callable defined function.

  Args:
    name: str, the requested name for the created function.
    graph: Graph that contains `operations`, `inputs` and `outputs`.
    operations: list of Operation forming the function body; a subset of
      `graph`'s operations.
    inputs: graph tensors that become the function's arguments.
    outputs: graph tensors that become the function's results.
  """
  # pylint: disable=protected-access
  with errors.raise_exception_on_not_ok_status() as status:
    c_graph = graph._c_graph
    fn_name = compat.as_str(name)
    body_ops = [op._c_op for op in operations]
    arg_outputs = [tensor._as_tf_output() for tensor in inputs]
    ret_outputs = [tensor._as_tf_output() for tensor in outputs]
    fn = pywrap_tensorflow.TF_GraphToFunction_wrapper(
        c_graph, fn_name, False, body_ops, arg_outputs, ret_outputs,
        [], None, compat.as_str(""), status)
  # pylint: enable=protected-access

  # TODO(apassos) avoid creating a FunctionDef (especially to grab the
  # signature), but also in general it's nice not to depend on it.
  with c_api_util.tf_buffer() as fdef_buf:
    with errors.raise_exception_on_not_ok_status() as status:
      pywrap_tensorflow.TF_FunctionToFunctionDef(fn, fdef_buf, status)
    serialized_fdef = pywrap_tensorflow.TF_GetBuffer(fdef_buf)
  function_def = function_pb2.FunctionDef()
  function_def.ParseFromString(compat.as_bytes(serialized_fdef))

  if context.in_eager_mode():
    _register(fn)

  signature = function_def.signature
  self.definition = function_def
  self.name = signature.name
  self.signature = signature
  self.grad_func_name = None
  self.python_grad_func = None
  self._c_func = fn
  self._grad_func = None
def __init__(self, name, graph, operations, inputs, outputs):
  """Wraps a subset of a graph as an eager-callable defined function.

  Args:
    name: str, the requested name for the created function.
    graph: Graph that contains `operations`, `inputs` and `outputs`.
    operations: list of Operation forming the function body; a subset of
      `graph`'s operations.
    inputs: graph tensors that become the function's arguments.
    outputs: graph tensors that become the function's results.
  """
  # pylint: disable=protected-access
  with errors.raise_exception_on_not_ok_status() as status:
    fn = pywrap_tensorflow.TF_GraphToFunction_wrapper(
        graph._c_graph,
        compat.as_str(name),
        False,
        list(map(lambda op: op._c_op, operations)),
        list(map(lambda tensor: tensor._as_tf_output(), inputs)),
        list(map(lambda tensor: tensor._as_tf_output(), outputs)),
        [],
        None,
        compat.as_str(""),
        status)
  # pylint: enable=protected-access

  # TODO(apassos) avoid creating a FunctionDef (especially to grab the
  # signature), but also in general it's nice not to depend on it.
  with c_api_util.tf_buffer() as fdef_buf:
    with errors.raise_exception_on_not_ok_status() as status:
      pywrap_tensorflow.TF_FunctionToFunctionDef(fn, fdef_buf, status)
    raw_fdef = pywrap_tensorflow.TF_GetBuffer(fdef_buf)
  fdef_proto = function_pb2.FunctionDef()
  fdef_proto.ParseFromString(compat.as_bytes(raw_fdef))

  if context.executing_eagerly():
    _register(fn)

  self.definition = fdef_proto
  self.name = fdef_proto.signature.name
  self.signature = fdef_proto.signature
  self.grad_func_name = None
  self.python_grad_func = None
  self._c_func = fn
  self._grad_func = None
def monitor(service_addr, duration_ms, monitoring_level=1,
            display_timestamp=False):
  """Asks a profiler server for an on-demand monitoring report (blocking).

  The calling thread blocks until the monitoring result arrives.

  Args:
    service_addr: Address of profiler service e.g. localhost:6009.
    duration_ms: Duration of tracing or monitoring in ms.
    monitoring_level: Monitoring level, 1 or 2. Level 2 is more verbose than
      level 1 and shows more metrics.
    display_timestamp: Set to true to display timestamp in monitoring result.

  Returns:
    A string of monitoring output.
  """
  with c_api_util.tf_buffer() as report_buf:
    pywrap_tensorflow.TFE_ProfilerClientMonitor(
        service_addr, duration_ms, monitoring_level, display_timestamp,
        report_buf)
    report = pywrap_tensorflow.TF_GetBuffer(report_buf)
  return report
def import_graph_def(graph_def, input_map=None, return_elements=None,
                     name=None, op_dict=None, producer_op_list=None):
  """Imports the graph from `graph_def` into the current default `Graph`.

  This function provides a way to import a serialized TensorFlow
  [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
  protocol buffer, and extract individual objects in the `GraphDef` as
  @{tf.Tensor} and @{tf.Operation} objects. Once extracted,
  these objects are placed into the current default `Graph`. See
  @{tf.Graph.as_graph_def} for a way to create a `GraphDef`
  proto.

  Two code paths exist. If the default graph has a C graph
  (`graph._c_graph`), the import is delegated to the C API. Otherwise the
  legacy pure-Python importer below rebuilds each node with `create_op`.

  Args:
    graph_def: A `GraphDef` proto containing operations to be imported into
      the default graph.
    input_map: A dictionary mapping input names (as strings) in `graph_def`
      to `Tensor` objects. The values of the named input tensors in the
      imported graph will be re-mapped to the respective `Tensor` values.
    return_elements: A list of strings containing operation names in
      `graph_def` that will be returned as `Operation` objects; and/or
      tensor names in `graph_def` that will be returned as `Tensor` objects.
    name: (Optional.) A prefix that will be prepended to the names in
      `graph_def`. Note that this does not apply to imported function names.
      Defaults to `"import"`.
    op_dict: (Optional.) Deprecated, do not use. Any value passed here is
      ignored: it is overwritten with the registered ops before use.
    producer_op_list: (Optional.) An `OpList` proto with the (possibly stripped)
      list of `OpDef`s used by the producer of the graph. If provided,
      unrecognized attrs for ops in `graph_def` that have their default value
      according to `producer_op_list` will be removed. This will allow some more
      `GraphDef`s produced by later binaries to be accepted by earlier binaries.

  Returns:
    A list of `Operation` and/or `Tensor` objects from the imported graph,
    corresponding to the names in `return_elements`.

  Raises:
    TypeError: If `graph_def` is not a `GraphDef` proto,
      `input_map` is not a dictionary mapping strings to `Tensor` objects,
      or `return_elements` is not a list of strings.
    ValueError: If `input_map`, or `return_elements` contains names that
      do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
      it refers to an unknown tensor).
  """
  graph_def = _ProcessGraphDefParam(graph_def)
  input_map = _ProcessInputMapParam(input_map)
  return_elements = _ProcessReturnElementsParam(return_elements)

  # The caller-supplied `op_dict` is discarded here in favour of the
  # registered op set.
  op_dict = op_def_registry.get_registered_ops()

  if producer_op_list is not None:
    # TODO(skyewm): make a copy of graph_def so we're not mutating the argument?
    _RemoveDefaultAttrs(op_dict, producer_op_list, graph_def)

  graph = ops.get_default_graph()

  if graph._c_graph:  # pylint: disable=protected-access
    # ---- C API import path ----
    with ops.name_scope(name, 'import', input_map.values()) as scope:
      # Save unique prefix generated by name_scope
      if scope:
        assert scope.endswith('/')
        prefix = scope[:-1]
      else:
        prefix = ''

      # Generate any input map tensors inside name scope
      input_map = _ConvertInputMapValues(name, input_map)

    scoped_options = c_api_util.ScopedTFImportGraphDefOptions()
    options = scoped_options.options
    _PopulateTFImportGraphDefOptions(options, prefix, input_map,
                                     return_elements)

    with c_api_util.tf_buffer(graph_def.SerializeToString()) as serialized:
      try:
        with errors.raise_exception_on_not_ok_status() as status:
          # NOTE(review): `results` is a raw handle that is not wrapped in a
          # scoped owner here. The underlying import-results object may never
          # be freed. Confirm whether this version of c_api_util offers a
          # scoped wrapper for it.
          results = c_api.TF_GraphImportGraphDefWithResults(
              graph._c_graph, serialized, options, status)  # pylint: disable=protected-access
      except errors.InvalidArgumentError as e:
        # Convert to ValueError for backwards compatibility.
        raise ValueError(str(e))

    _ProcessNewOps(graph)

    # Create _DefinedFunctions for any imported functions.
    #
    # We do this by creating _DefinedFunctions directly from `graph_def`, and
    # adding them to `graph`. Adding an existing function to a TF_Graph is a
    # no-op, so this only has the effect of updating the Python state (usually
    # _DefinedFunction.add_to_graph also adds the function to the TF_Graph).
    #
    # TODO(skyewm): fetch the TF_Functions directly from the TF_Graph
    # TODO(skyewm): avoid sending serialized FunctionDefs back to the TF_Graph
    if graph_def.library and graph_def.library.function:
      # pylint: disable=protected-access
      functions = function._from_library(graph_def.library)
      for f in functions:
        f.add_to_graph(graph)
      # pylint: enable=protected-access

    # Treat input mappings that don't appear in the graph as an error, because
    # they are likely to be due to a typo.
    missing_unused_input_keys = (
        c_api.TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper(
            results))
    if missing_unused_input_keys:
      missing_unused_input_keys = [compat.as_str(s)
                                   for s in missing_unused_input_keys]
      raise ValueError(
          'Attempted to map inputs that were not found in graph_def: [%s]'
          % ', '.join(missing_unused_input_keys))

    if return_elements is None:
      return None
    else:
      return _GatherReturnElements(return_elements, graph, results)

  else:
    # ---- Legacy pure-Python import path ----
    g = graph

    # Use a canonical representation for all tensor names.
    input_map = {_CanonicalInputName(k): v for k, v in input_map.items()}
    # Keys of `input_map` actually consumed while wiring inputs.
    used_input_keys = set()
    # Original (un-prefixed) node name -> newly created Operation.
    name_to_op = {}

    # Add any functions defined in `graph_def` to `g`
    if graph_def.library and graph_def.library.function:
      # Copy op_dict so we don't clobber the original
      op_dict = copy.copy(op_dict)
      # pylint: disable=protected-access
      # Note that we do not prepend `name` to the function name. The reasoning
      # is that function names are similar to op definition names, which
      # currently do not have a scoped name or namespace scheme.
      functions = function._from_library(graph_def.library)
      for f in functions:
        f.add_to_graph(g)
        op_dict[f.name] = f.definition.signature
      # pylint: enable=protected-access

    # LINT.IfChange
    with ops.name_scope(name, 'import', input_map.values()) as scope:
      # TODO(ashankar): Should this just copy over or should it do some
      # more nuanced merging? For example, the graph may already have some
      # marked "bad versions" and we don't want to lose those because of
      # what's in graph_def.versions? The C++ ImportGraphDef does something
      # more nuanced.
      g.graph_def_versions.CopyFrom(graph_def.versions)

      input_map = _ConvertInputMapValues(name, input_map)

      # NOTE(mrry): We do this in two passes, because there may be a cycle in
      # `graph_def`.

      # 1. Add operations without their inputs.
      for node in graph_def.node:
        # Check to see if this op's name matches a previously seen op
        if node.name in name_to_op:
          raise ValueError('Duplicate name \'%s\' in GraphDef.' % node.name)
        # Set any default attr values that aren't present.
        if node.op not in op_dict:
          raise ValueError('No op named %s in defined operations.' % node.op)
        op_def = op_dict[node.op]
        for attr_def in op_def.attr:
          key = attr_def.name
          if attr_def.HasField('default_value'):
            value = node.attr[key]
            if value is None or value.WhichOneof('value') is None:
              # Note: this writes the default into `graph_def` itself.
              node.attr[key].CopyFrom(attr_def.default_value)

        output_types = _OutputTypes(node, op_dict)
        name_to_op[node.name] = g.create_op(
            node.op, [], output_types, name=node.name, attrs=node.attr,
            compute_shapes=False, compute_device=False,
            op_def=op_def)

      # Maps from a node to the ops it is colocated with, if colocation
      # is specified in the attributes.
      colocation_pairs = collections.defaultdict(list)

      # 2. Add inputs to the operations.
      for node in graph_def.node:
        op = name_to_op[node.name]
        input_types = _InputTypes(node, op_dict)
        apply_device_function = True

        # Rewrite the colocation attributes in the graph, since the
        # names of new ops may have changed.
        for key, value in op.node_def.attr.items():
          if key == '_class':
            class_values = value.list
            new_class_values = []
            for class_value in class_values.s:
              if class_value.startswith(b'loc:@'):
                op_to_bind_to = class_value[5:].decode()
                # Find the op by its original name.
                if op_to_bind_to not in name_to_op:
                  raise ValueError('Specified colocation to an op that '
                                   'does not exist during import: %s in %s' % (
                                       op_to_bind_to, node.name))
                original_op = name_to_op[op_to_bind_to]
                new_class_values.append(compat.as_bytes(
                    'loc:@' + original_op.name))
                if op_to_bind_to != node.name:
                  # Keep track of this mapping for a later phase.
                  colocation_pairs[op].append(original_op)
                  # Don't apply this op's device function,
                  # the colocation constraint will ensure
                  # the proper device gets assigned at runtime.
                  apply_device_function = False

              else:
                new_class_values.append(class_value)
            value.list.CopyFrom(attr_value_pb2.AttrValue.ListValue(
                s=new_class_values))

        # NOTE(mrry): We cannot use zip here because control inputs do not
        # appear in the list of input_types.
        for i, input_name in enumerate(
            [_CanonicalInputName(x) for x in node.input]):

          if _IsControlInput(input_name):
            # (a) Input is a control input that should be taken from an op
            #     in "graph_def".
            try:
              source_op = name_to_op[input_name[1:]]
            except KeyError:
              raise ValueError(
                  _InvalidNodeMessage(
                      node,
                      'Control input %r not found in graph_def.'
                      % (input_name,)))
            # pylint: disable=protected-access
            op._add_control_input(source_op)
            # pylint: enable=protected-access

          else:
            try:
              input_type = input_types[i]
            except IndexError:
              raise ValueError(_InvalidNodeMessage(
                  node, 'More inputs specified (%r) than the op expects.'
                  % (input_name,)))

            if input_name in input_map:
              # (b) Input should be replaced by a tensor from the caller.
              source_tensor = input_map[input_name]
              used_input_keys.add(input_name)

            else:
              # (c) Input should be taken from an op in `graph_def`.
              operation_name, output_index = _ParseTensorName(input_name)
              try:
                source_op = name_to_op[operation_name]
                source_tensor = list(source_op.values())[output_index]
              except (KeyError, IndexError):
                raise ValueError(
                    _InvalidNodeMessage(
                        node,
                        'Input tensor %r not found in graph_def.'
                        % (input_name,)))

            try:
              # pylint: disable=protected-access
              op._add_input(source_tensor, dtype=input_type)
              # pylint: enable=protected-access
            except TypeError as te:
              raise ValueError(_InvalidNodeMessage(
                  node, 'Input tensor %r %s' % (input_name, te)))

        # pylint: disable=protected-access
        if op._input_types != input_types:
          raise ValueError(
              _InvalidNodeMessage(
                  node,
                  'Input types mismatch (expected %r but got %r)'
                  % (', '.join(dtypes.as_dtype(x).name for x in input_types),
                     ', '.join(x.name for x in op._input_types))))
        # pylint: enable=protected-access

        if not g._is_function(op.type):  # pylint: disable=protected-access
          # Execute shape inference for this op.
          # NOTE(mrry): If the graph contains a cycle, the full shape
          # information may not be available for this op's inputs.
          ops.set_shapes_for_outputs(op)
        # For nodes with _output_shapes set, set the output shapes.
        if '_output_shapes' in op.node_def.attr:
          for i, output in enumerate(op.outputs):
            dims = op.node_def.attr['_output_shapes'].list.shape[i]
            output_shape = tensor_shape.TensorShape(
                None if dims.unknown_rank else
                [dim.size if dim.size >= 0 else None for dim in dims.dim])

            try:
              output.set_shape(output_shape)
            except ValueError as e:
              # If the output shape is incompatible with what is inferred
              # by the graph for a very specific whitelist of ops, then we
              # ignore this output shape.  This can happen if there is a
              # bug in the shape function for some operation, and the
              # serialized graph def has the incorrect shape set when
              # running on a newer binary with the fixed shape function.
              # This is an escape hatch that allows us to correct shape
              # functions that are not critical to correct execution but
              # would cause graphs to fail if imported after correcting.
              #
              # This can be removed after 2017/03/08.
              if op.type in ['RandomShuffleQueue', 'PaddingFIFOQueue',
                             'FIFOQueue', 'PriorityQueue', 'QueueSize',
                             'Stack', 'Barrier', 'BarrierReadySize',
                             'BarrierIncompleteSize', 'HashTable',
                             'MutableHashTable',
                             'MutableHashTableOfTensors', 'Mutex',
                             'CuckooTable', 'IndexTable',
                             'WholeFileReader', 'TextLineReader',
                             'FixedLengthRecordReader',
                             'TFRecordReader', 'IdentityReader',
                             'LMDBReader',
                             'RefSwitch', 'RefEnter', 'RefNextIteration',
                             'RefMerge', 'RefIdentity']:
                pass
              elif op.type in [
                  'ConditionalAccumulator', 'SparseConditionalAccumulator',
                  'Table'
              ]:
                # This can be removed after 2017/04/24.
                pass
              else:
                raise e

          # The serialized hint has been applied; drop it from the NodeDef.
          del op.node_def.attr['_output_shapes']

        # NOTE(mrry): We do this after configuring the inputs, because
        # the result of the device functions may depend on the inputs.
        if apply_device_function:
          with _MaybeDevice(node.device):
            g._apply_device_functions(op)  # pylint: disable=protected-access

      # The following loop populates the device field of ops that are
      # colocated with another op.  This is implied by the colocation
      # attribute, but we propagate the device field for completeness.
      for op, coloc_op_list in colocation_pairs.items():
        coloc_device = None
        # Find any device in the list of colocated ops that have a
        # device, if it exists.  We assume that if multiple ops
        # have devices, they refer to the same device.  Otherwise, a
        # runtime error will occur since the colocation property
        # cannot be guaranteed.
        #
        # One possible improvement is to try to check for compatibility
        # of all devices in this list at import time here, which would
        # require implementing a compatibility function for device specs
        # in python.
        for coloc_op in coloc_op_list:
          if coloc_op.device:
            coloc_device = pydev.DeviceSpec.from_string(coloc_op.device)
            break
        if coloc_device:
          op._set_device(coloc_device)  # pylint: disable=protected-access

      # Treat input mappings that don't appear in the graph as an error,
      # because they are likely to be due to a typo.
      def _IsImportedNodeOutput(tensor_name):
        # True if `tensor_name` names an existing output of an imported node.
        operation_name, output_index = _ParseTensorName(tensor_name)
        try:
          return output_index < len(name_to_op[operation_name].outputs)
        except KeyError:
          return False
      absent_input_keys = [
          k for k in frozenset(input_map.keys()).difference(used_input_keys)
          if not _IsImportedNodeOutput(k)]
      if absent_input_keys:
        raise ValueError(
            'Attempted to map inputs that were not found in graph_def: [%s]'
            % ', '.join(absent_input_keys))

      if return_elements is None:
        return None
      else:
        ret = []
        # NOTE(review): this loop rebinds the `name` argument; harmless here
        # because `name` is not read afterwards.
        for name in return_elements:
          name = compat.as_str(name)
          if ':' in name:
            # "op:index" form -> return the Tensor.
            try:
              operation_name, output_index = _ParseTensorName(name)
              ret.append(name_to_op[operation_name].outputs[output_index])
            except (ValueError, KeyError, IndexError):
              raise ValueError(
                  'Requested return_element %r not found in graph_def.' % name)
          else:
            # Bare op name -> return the Operation.
            try:
              ret.append(name_to_op[name])
            except KeyError:
              raise ValueError(
                  'Requested return_element %r not found in graph_def.' % name)
        return ret
# Repro script: keep a raw C operation handle (`_c_op`) from a graph, drop all
# Python references to that graph, churn memory, and then query the handle
# through the C API.
# NOTE(review): presumably this demonstrates that `x_c_op` may dangle (a
# use-after-free) once its owning graph is collected. The C API calls at the
# end are expected to read freed memory or crash. Confirm against the
# associated bug report.
graph = tf.Graph()
with graph.as_default():
  for i in range(1000):  # such that we fill up the memory a bit
    #x = tf.placeholder(tf.float32)
    x = tf.constant(42)
    if i == 0:
      # Keep only the raw C handle of the very first op.
      x_c_op = x.op._c_op

# New graph, such that we remove traces to graph. Also to fill some memory.
with tf.Graph().as_default():
  for i in range(1000):  # such that we fill up the memory a bit
    x = tf.placeholder(tf.float32)

# Fill some more memory (10 buffers of 10,000,000 bytes each).
a = [bytes([255] * 10000000) for i in range(10)]

# Drop the remaining Python references to the first graph and force
# collection.
del graph
del x
gc.collect()
gc.collect()

# Query the retained handle after its graph may have been freed.
print(c_api.TF_OperationName(x_c_op))
print(c_api.TF_OperationOpType(x_c_op))
print(c_api.TF_OperationDevice(x_c_op))
print(c_api.TF_OperationNumOutputs(x_c_op))
with c_api_util.tf_buffer() as buf:
  c_api.TF_OperationToNodeDef(x_c_op, buf)
def value(self):
  """Returns this string gauge cell's current value as a `str` (UTF-8)."""
  with c_api_util.tf_buffer() as gauge_buf:
    cell = self._cell
    pywrap_tensorflow.TFE_MonitoringStringGaugeCellValue(cell, gauge_buf)
    decoded = pywrap_tensorflow.TF_GetBuffer(gauge_buf).decode('utf-8')
  return decoded
def import_graph_def(graph_def, input_map=None, return_elements=None,
                     name=None, op_dict=None, producer_op_list=None):
  """Imports the graph from `graph_def` into the current default `Graph`.

  This function provides a way to import a serialized TensorFlow
  [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
  protocol buffer, and extract individual objects in the `GraphDef` as
  `tf.Tensor` and `tf.Operation` objects. Once extracted,
  these objects are placed into the current default `Graph`. See
  `tf.Graph.as_graph_def` for a way to create a `GraphDef`
  proto.

  Args:
    graph_def: A `GraphDef` proto containing operations to be imported into
      the default graph.
    input_map: A dictionary mapping input names (as strings) in `graph_def`
      to `Tensor` objects. The values of the named input tensors in the
      imported graph will be re-mapped to the respective `Tensor` values.
    return_elements: A list of strings containing operation names in
      `graph_def` that will be returned as `Operation` objects; and/or
      tensor names in `graph_def` that will be returned as `Tensor` objects.
    name: (Optional.) A prefix that will be prepended to the names in
      `graph_def`. Note that this does not apply to imported function names.
      Defaults to `"import"`.
    op_dict: (Optional.) Deprecated, do not use. Any value passed here is
      ignored: it is overwritten with the registered ops before use.
    producer_op_list: (Optional.) An `OpList` proto with the (possibly stripped)
      list of `OpDef`s used by the producer of the graph. If provided,
      unrecognized attrs for ops in `graph_def` that have their default value
      according to `producer_op_list` will be removed. This will allow some more
      `GraphDef`s produced by later binaries to be accepted by earlier binaries.

  Returns:
    A list of `Operation` and/or `Tensor` objects from the imported graph,
    corresponding to the names in `return_elements`.

  Raises:
    TypeError: If `graph_def` is not a `GraphDef` proto,
      `input_map` is not a dictionary mapping strings to `Tensor` objects,
      or `return_elements` is not a list of strings.
    ValueError: If `input_map`, or `return_elements` contains names that
      do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
      it refers to an unknown tensor).
  """
  # The caller-supplied `op_dict` is discarded in favour of the registered
  # op set.
  op_dict = op_def_registry.get_registered_ops()

  graph_def = _ProcessGraphDefParam(graph_def, op_dict)
  input_map = _ProcessInputMapParam(input_map)
  return_elements = _ProcessReturnElementsParam(return_elements)

  if producer_op_list is not None:
    # TODO(skyewm): make a copy of graph_def so we're not mutating the argument?
    _RemoveDefaultAttrs(op_dict, producer_op_list, graph_def)

  graph = ops.get_default_graph()
  with ops.name_scope(name, 'import', input_map.values()) as scope:
    # Save unique prefix generated by name_scope
    if scope:
      assert scope.endswith('/')
      prefix = scope[:-1]
    else:
      prefix = ''

    # Generate any input map tensors inside name scope
    input_map = _ConvertInputMapValues(name, input_map)

  scoped_options = c_api_util.ScopedTFImportGraphDefOptions()
  options = scoped_options.options
  _PopulateTFImportGraphDefOptions(options, prefix, input_map,
                                   return_elements)

  # _ProcessNewOps mutates the new operations. _mutation_lock ensures a
  # Session.run call cannot occur between creating the TF_Operations in the
  # TF_GraphImportGraphDefWithResults call and mutating them in
  # _ProcessNewOps.
  with graph._mutation_lock():  # pylint: disable=protected-access
    with c_api_util.tf_buffer(graph_def.SerializeToString()) as serialized:
      try:
        results = c_api.TF_GraphImportGraphDefWithResults(
            graph._c_graph, serialized, options)  # pylint: disable=protected-access
        # Wrap the raw results handle so it is owned by a Python object; the
        # raw handle is reached through `results.results` below.
        results = c_api_util.ScopedTFImportGraphDefResults(results)
      except errors.InvalidArgumentError as e:
        # Convert to ValueError for backwards compatibility.
        raise ValueError(str(e))

    # Create _DefinedFunctions for any imported functions.
    #
    # We do this by creating _DefinedFunctions directly from `graph_def`, and
    # adding them to `graph`. Adding an existing function to a TF_Graph is a
    # no-op, so this only has the effect of updating the Python state (usually
    # _DefinedFunction.add_to_graph also adds the function to the TF_Graph).
    #
    # TODO(skyewm): fetch the TF_Functions directly from the TF_Graph
    # TODO(skyewm): avoid sending serialized FunctionDefs back to the TF_Graph
    # TODO(b/74620627): move this after _ProcessNewOps outside the lock once
    # _USE_C_SHAPES is removed.
    if graph_def.library and graph_def.library.function:
      # pylint: disable=protected-access
      functions = function._from_library(graph_def.library)
      for f in functions:
        f.add_to_graph(graph)
      # pylint: enable=protected-access

    _ProcessNewOps(graph)

  # Treat input mappings that don't appear in the graph as an error, because
  # they are likely to be due to a typo.
  missing_unused_input_keys = (
      c_api.TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper(
          results.results))
  if missing_unused_input_keys:
    missing_unused_input_keys = [
        compat.as_str(s) for s in missing_unused_input_keys
    ]
    raise ValueError(
        'Attempted to map inputs that were not found in graph_def: [%s]' %
        ', '.join(missing_unused_input_keys))

  if return_elements is None:
    return None
  else:
    return _GatherReturnElements(return_elements, graph, results.results)
def import_graph_def(graph_def, input_map=None, return_elements=None,
                     name=None, op_dict=None, producer_op_list=None):
  """Loads the contents of a serialized `GraphDef` into the default `Graph`.

  The operations described by a TensorFlow
  [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
  protocol buffer are added to the current default `Graph`, and selected
  elements can be handed back to the caller as @{tf.Tensor} and
  @{tf.Operation} objects. @{tf.Graph.as_graph_def} produces a suitable
  `GraphDef` proto.

  Args:
    graph_def: `GraphDef` proto whose operations should be imported into the
      default graph.
    input_map: Optional dict from input names (strings) used in `graph_def`
      to `Tensor` objects; the matching inputs of the imported graph are
      rewired to read from these tensors instead.
    return_elements: Optional list of names in `graph_def`. Operation names
      are returned as `Operation` objects, tensor names as `Tensor` objects.
    name: (Optional.) Prefix prepended to the names in `graph_def`. It is not
      applied to imported function names. Defaults to `"import"`.
    op_dict: (Optional.) Deprecated, do not use. The value is ignored; the
      registered op definitions are always used.
    producer_op_list: (Optional.) An `OpList` proto with the (possibly
      stripped) list of `OpDef`s used by the producer of the graph. When
      given, unrecognized attrs of ops in `graph_def` that hold their default
      value according to this list are removed, which lets earlier binaries
      accept some `GraphDef`s written by later ones.

  Returns:
    `None` when `return_elements` is `None`; otherwise a list of `Operation`
    and/or `Tensor` objects from the imported graph, in the order of the
    names in `return_elements`.

  Raises:
    TypeError: If `graph_def` is not a `GraphDef` proto, `input_map` is not a
      dictionary mapping strings to `Tensor` objects, or `return_elements` is
      not a list of strings.
    ValueError: If `input_map` or `return_elements` contains names that do
      not appear in `graph_def`, or `graph_def` is not well-formed (e.g. it
      refers to an unknown tensor).
  """
  # The caller-supplied `op_dict` is deliberately discarded (deprecated arg).
  op_dict = op_def_registry.get_registered_ops()

  graph_def = _ProcessGraphDefParam(graph_def, op_dict)
  input_map = _ProcessInputMapParam(input_map)
  return_elements = _ProcessReturnElementsParam(return_elements)

  if producer_op_list is not None:
    # TODO(skyewm): make a copy of graph_def so we're not mutating the argument?
    _RemoveDefaultAttrs(op_dict, producer_op_list, graph_def)

  target_graph = ops.get_default_graph()
  with ops.name_scope(name, 'import', input_map.values()) as scope:
    # Remember the unique prefix picked by name_scope, without its trailing
    # '/'; an empty scope means no prefix at all.
    prefix = ''
    if scope:
      assert scope.endswith('/')
      prefix = scope[:-1]

    # Tensors for the input map have to be created inside the name scope.
    input_map = _ConvertInputMapValues(name, input_map)

  # Keep the wrapper referenced for the rest of the call while its raw
  # options handle is in use (it presumably frees the handle when collected).
  options_holder = c_api_util.ScopedTFImportGraphDefOptions()
  c_options = options_holder.options
  _PopulateTFImportGraphDefOptions(c_options, prefix, input_map,
                                   return_elements)

  # _ProcessNewOps mutates the freshly imported operations. Holding the graph's
  # mutation lock keeps a Session.run call from running between their creation
  # in TF_GraphImportGraphDefWithResults and that mutation.
  with target_graph._mutation_lock():  # pylint: disable=protected-access
    serialized_graph_def = graph_def.SerializeToString()
    with c_api_util.tf_buffer(serialized_graph_def) as graph_def_buffer:
      try:
        import_results = c_api_util.ScopedTFImportGraphDefResults(
            c_api.TF_GraphImportGraphDefWithResults(
                target_graph._c_graph,  # pylint: disable=protected-access
                graph_def_buffer,
                c_options))
      except errors.InvalidArgumentError as err:
        # Import failures surface as ValueError for backwards compatibility.
        raise ValueError(str(err))

    # Build Python-side _DefinedFunctions straight from `graph_def` and attach
    # them to the graph. The import above already placed the functions in the
    # TF_Graph, where re-adding an existing function is a no-op, so this only
    # brings the Python state up to date.
    #
    # TODO(skyewm): fetch the TF_Functions directly from the TF_Graph
    # TODO(skyewm): avoid sending serialized FunctionDefs back to the TF_Graph
    # TODO(b/74620627): move this after _ProcessNewOps outside the lock once
    # _USE_C_SHAPES is removed.
    library = graph_def.library
    if library and library.function:
      # pylint: disable=protected-access
      for defined_fn in function._from_library(library):
        defined_fn.add_to_graph(target_graph)
      # pylint: enable=protected-access

    _ProcessNewOps(target_graph)

  # An input mapping that matches nothing in the graph is most likely a typo,
  # so it is reported instead of being silently ignored.
  unused_keys = (
      c_api.TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper(
          import_results.results))
  if unused_keys:
    raise ValueError(
        'Attempted to map inputs that were not found in graph_def: [%s]' %
        ', '.join(compat.as_str(key) for key in unused_keys))

  if return_elements is None:
    return None
  return _GatherReturnElements(return_elements, target_graph,
                               import_results.results)
def import_graph_def(graph_def, input_map=None, return_elements=None,
                     name=None, op_dict=None, producer_op_list=None):
  """Imports the graph from `graph_def` into the current default `Graph`.

  This function provides a way to import a serialized TensorFlow
  [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
  protocol buffer, and extract individual objects in the `GraphDef` as
  @{tf.Tensor} and @{tf.Operation} objects. Once extracted,
  these objects are placed into the current default `Graph`. See
  @{tf.Graph.as_graph_def} for a way to create a `GraphDef`
  proto.

  Args:
    graph_def: A `GraphDef` proto containing operations to be imported into
      the default graph.
    input_map: A dictionary mapping input names (as strings) in `graph_def`
      to `Tensor` objects. The values of the named input tensors in the
      imported graph will be re-mapped to the respective `Tensor` values.
    return_elements: A list of strings containing operation names in
      `graph_def` that will be returned as `Operation` objects; and/or
      tensor names in `graph_def` that will be returned as `Tensor` objects.
    name: (Optional.) A prefix that will be prepended to the names in
      `graph_def`. Note that this does not apply to imported function names.
      Defaults to `"import"`.
    op_dict: (Optional.) Deprecated, do not use. The value is ignored.
    producer_op_list: (Optional.) An `OpList` proto with the (possibly stripped)
      list of `OpDef`s used by the producer of the graph. If provided,
      unrecognized attrs for ops in `graph_def` that have their default value
      according to `producer_op_list` will be removed. This will allow some more
      `GraphDef`s produced by later binaries to be accepted by earlier binaries.

  Returns:
    A list of `Operation` and/or `Tensor` objects from the imported graph,
    corresponding to the names in `return_elements`, or `None` if
    `return_elements` is `None`.

  Raises:
    TypeError: If `graph_def` is not a `GraphDef` proto,
      `input_map` is not a dictionary mapping strings to `Tensor` objects,
      or `return_elements` is not a list of strings.
    ValueError: If `input_map`, or `return_elements` contains names that
      do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
      it refers to an unknown tensor).
  """
  # The caller-supplied `op_dict` is deliberately discarded (deprecated arg).
  op_dict = op_def_registry.get_registered_ops()

  graph_def = _ProcessGraphDefParam(graph_def, op_dict)
  input_map = _ProcessInputMapParam(input_map)
  return_elements = _ProcessReturnElementsParam(return_elements)

  if producer_op_list is not None:
    # TODO(skyewm): make a copy of graph_def so we're not mutating the argument?
    _RemoveDefaultAttrs(op_dict, producer_op_list, graph_def)

  graph = ops.get_default_graph()

  # Graphs backed by a TF_Graph are imported through the C API; otherwise the
  # pure-Python importer in the `else` branch builds the ops one by one.
  if graph._c_graph:  # pylint: disable=protected-access
    with ops.name_scope(name, 'import', input_map.values()) as scope:
      # Save unique prefix generated by name_scope
      if scope:
        assert scope.endswith('/')
        prefix = scope[:-1]
      else:
        prefix = ''

      # Generate any input map tensors inside name scope
      input_map = _ConvertInputMapValues(name, input_map)

    scoped_options = c_api_util.ScopedTFImportGraphDefOptions()
    options = scoped_options.options
    _PopulateTFImportGraphDefOptions(options, prefix, input_map,
                                     return_elements)

    # `results` is a raw TF_ImportGraphDefResults handle and nothing else in
    # this function releases it, so it is freed in the `finally` below on every
    # exit path (success, missing input mappings, or any later error).
    results = None
    try:
      # _ProcessNewOps mutates the new operations. _lock ensures a Session.run
      # call cannot occur between creating the TF_Operations in the
      # TF_GraphImportGraphDefWithResults call and mutating them in
      # _ProcessNewOps.
      with graph._lock:  # pylint: disable=protected-access
        with c_api_util.tf_buffer(
            graph_def.SerializeToString()) as serialized:
          try:
            with errors.raise_exception_on_not_ok_status() as status:
              results = c_api.TF_GraphImportGraphDefWithResults(
                  graph._c_graph, serialized, options, status)  # pylint: disable=protected-access
          except errors.InvalidArgumentError as e:
            # Convert to ValueError for backwards compatibility.
            raise ValueError(str(e))

        _ProcessNewOps(graph)

      # Create _DefinedFunctions for any imported functions.
      #
      # We do this by creating _DefinedFunctions directly from `graph_def`, and
      # adding them to `graph`. Adding an existing function to a TF_Graph is a
      # no-op, so this only has the effect of updating the Python state
      # (usually _DefinedFunction.add_to_graph also adds the function to the
      # TF_Graph).
      #
      # TODO(skyewm): fetch the TF_Functions directly from the TF_Graph
      # TODO(skyewm): avoid sending serialized FunctionDefs back to the TF_Graph
      if graph_def.library and graph_def.library.function:
        # pylint: disable=protected-access
        functions = function._from_library(graph_def.library)
        for f in functions:
          f.add_to_graph(graph)
        # pylint: enable=protected-access

      # Treat input mappings that don't appear in the graph as an error,
      # because they are likely to be due to a typo.
      missing_unused_input_keys = (
          c_api.TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper(
              results))
      if missing_unused_input_keys:
        missing_unused_input_keys = [
            compat.as_str(s) for s in missing_unused_input_keys
        ]
        raise ValueError(
            'Attempted to map inputs that were not found in graph_def: [%s]' %
            ', '.join(missing_unused_input_keys))

      if return_elements is None:
        return None
      # The returned Operations/Tensors are looked up in `graph`, so they stay
      # valid after the results handle is deleted.
      return _GatherReturnElements(return_elements, graph, results)
    finally:
      if results is not None:
        c_api.TF_DeleteImportGraphDefResults(results)

  else:
    g = graph

    # Use a canonical representation for all tensor names.
    input_map = {_CanonicalInputName(k): v for k, v in input_map.items()}
    used_input_keys = set()
    name_to_op = {}

    # Add any functions defined in `graph_def` to `g`
    if graph_def.library and graph_def.library.function:
      # Copy op_dict so we don't clobber the original
      op_dict = copy.copy(op_dict)
      # pylint: disable=protected-access
      # Note that we do not prepend `name` to the function name. The reasoning
      # is that function names are similar to op definition names, which
      # currently do not have a scoped name or namespace scheme.
      functions = function._from_library(graph_def.library)
      for f in functions:
        f.add_to_graph(g)
        op_dict[f.name] = f.definition.signature
      # pylint: enable=protected-access

    # LINT.IfChange
    with ops.name_scope(name, 'import', input_map.values()) as scope:
      # TODO(ashankar): Should this just copy over or should it do some
      # more nuanced merging? For example, the graph may already have some
      # marked "bad versions" and we don't want to lose those because of
      # what's in graph_def.versions? The C++ ImportGraphDef does something
      # more nuanced.
      g.graph_def_versions.CopyFrom(graph_def.versions)

      input_map = _ConvertInputMapValues(name, input_map)

      # NOTE(mrry): We do this in two passes, because there may be a cycle in
      # `graph_def`.

      # 1. Add operations without their inputs.
      for node in graph_def.node:
        # Check to see if this op's name matches a previously seen op
        if node.name in name_to_op:
          raise ValueError('Duplicate name \'%s\' in GraphDef.' % node.name)
        if node.op not in op_dict:
          raise ValueError('No op named %s in defined operations.' % node.op)
        op_def = op_dict[node.op]

        output_types = _OutputTypes(node, op_dict)
        name_to_op[node.name] = g.create_op(
            node.op, [], output_types, name=node.name, attrs=node.attr,
            compute_shapes=False, compute_device=False,
            op_def=op_def)

      # Maps from a node to the ops it is colocated with, if colocation
      # is specified in the attributes.
      colocation_pairs = collections.defaultdict(list)

      # 2. Add inputs to the operations.
      for node in graph_def.node:
        op = name_to_op[node.name]
        input_types = _InputTypes(node, op_dict)
        apply_device_function = True

        # Rewrite the colocation attributes in the graph, since the
        # names of new ops may have changed.
        for key, value in op.node_def.attr.items():
          if key == '_class':
            class_values = value.list
            new_class_values = []
            for class_value in class_values.s:
              if class_value.startswith(b'loc:@'):
                op_to_bind_to = class_value[5:].decode()
                # Find the op by its original name.
                if op_to_bind_to not in name_to_op:
                  raise ValueError('Specified colocation to an op that '
                                   'does not exist during import: %s in %s' % (
                                       op_to_bind_to, node.name))
                original_op = name_to_op[op_to_bind_to]
                new_class_values.append(compat.as_bytes(
                    'loc:@' + original_op.name))
                if op_to_bind_to != node.name:
                  # Keep track of this mapping for a later phase.
                  colocation_pairs[op].append(original_op)
                  # Don't apply this op's device function,
                  # the colocation constraint will ensure
                  # the proper device gets assigned at runtime.
                  apply_device_function = False

              else:
                new_class_values.append(class_value)
            value.list.CopyFrom(attr_value_pb2.AttrValue.ListValue(
                s=new_class_values))

        # NOTE(mrry): We cannot use zip here because control inputs do not
        # appear in the list of input_types.
        for i, input_name in enumerate(
            [_CanonicalInputName(x) for x in node.input]):

          if _IsControlInput(input_name):
            # (a) Input is a control input that should be taken from an op
            #     in "graph_def".
            try:
              source_op = name_to_op[input_name[1:]]
            except KeyError:
              raise ValueError(
                  _InvalidNodeMessage(
                      node,
                      'Control input %r not found in graph_def.'
                      % (input_name,)))
            # pylint: disable=protected-access
            op._add_control_input(source_op)
            # pylint: enable=protected-access

          else:
            try:
              input_type = input_types[i]
            except IndexError:
              raise ValueError(_InvalidNodeMessage(
                  node, 'More inputs specified (%r) than the op expects.'
                  % (input_name,)))

            if input_name in input_map:
              # (b) Input should be replaced by a tensor from the caller.
              source_tensor = input_map[input_name]
              used_input_keys.add(input_name)

            else:
              # (c) Input should be taken from an op in `graph_def`.
              operation_name, output_index = _ParseTensorName(input_name)
              try:
                source_op = name_to_op[operation_name]
                source_tensor = list(source_op.values())[output_index]
              except (KeyError, IndexError):
                raise ValueError(
                    _InvalidNodeMessage(
                        node,
                        'Input tensor %r not found in graph_def.'
                        % (input_name,)))

            try:
              # pylint: disable=protected-access
              op._add_input(source_tensor, dtype=input_type)
              # pylint: enable=protected-access
            except TypeError as te:
              raise ValueError(_InvalidNodeMessage(
                  node, 'Input tensor %r %s' % (input_name, te)))

        # pylint: disable=protected-access
        if op._input_types != input_types:
          raise ValueError(
              _InvalidNodeMessage(
                  node,
                  'Input types mismatch (expected %r but got %r)'
                  % (', '.join(dtypes.as_dtype(x).name for x in input_types),
                     ', '.join(x.name for x in op._input_types))))
        # pylint: enable=protected-access

        if not g._is_function(op.type):  # pylint: disable=protected-access
          # Execute shape inference for this op.
          # NOTE(mrry): If the graph contains a cycle, the full shape
          # information may not be available for this op's inputs.
          ops.set_shapes_for_outputs(op)
        # For nodes with _output_shapes set, set the output shapes.
        if '_output_shapes' in op.node_def.attr:
          for i, output in enumerate(op.outputs):
            dims = op.node_def.attr['_output_shapes'].list.shape[i]
            output_shape = tensor_shape.TensorShape(
                None if dims.unknown_rank else
                [dim.size if dim.size >= 0 else None for dim in dims.dim])

            try:
              output.set_shape(output_shape)
            except ValueError:
              # If the output shape is incompatible with what is inferred
              # by the graph for a very specific whitelist of ops, then we
              # ignore this output shape.  This can happen if there is a
              # bug in the shape function for some operation, and the
              # serialized graph def has the incorrect shape set when
              # running on a newer binary with the fixed shape function.
              # This is an escape hatch that allows us to correct shape
              # functions that are not critical to correct execution but
              # would cause graphs to fail if imported after correcting.
              #
              # This can be removed after 2017/03/08.
              if op.type in ['RandomShuffleQueue', 'PaddingFIFOQueue',
                             'FIFOQueue', 'PriorityQueue', 'QueueSize',
                             'Stack', 'Barrier', 'BarrierReadySize',
                             'BarrierIncompleteSize', 'HashTable',
                             'MutableHashTable',
                             'MutableHashTableOfTensors', 'Mutex',
                             'CuckooTable', 'IndexTable',
                             'WholeFileReader', 'TextLineReader',
                             'FixedLengthRecordReader',
                             'TFRecordReader', 'IdentityReader',
                             'LMDBReader',
                             'RefSwitch', 'RefEnter', 'RefNextIteration',
                             'RefMerge', 'RefIdentity']:
                pass
              elif op.type in [
                  'ConditionalAccumulator', 'SparseConditionalAccumulator',
                  'Table'
              ]:
                # This can be removed after 2017/04/24.
                pass
              else:
                # Bare `raise` keeps the original traceback (`raise e` loses
                # it on Python 2).
                raise

          del op.node_def.attr['_output_shapes']

        # NOTE(mrry): We do this after configuring the inputs, because
        # the result of the device functions may depend on the inputs.
        if apply_device_function:
          with _MaybeDevice(node.device):
            g._apply_device_functions(op)  # pylint: disable=protected-access

      # The following loop populates the device field of ops that are
      # colocated with another op.  This is implied by the colocation
      # attribute, but we propagate the device field for completeness.
      for op, coloc_op_list in colocation_pairs.items():
        coloc_device = None
        # Find any device in the list of colocated ops that have a
        # device, if it exists.  We assume that if multiple ops
        # have devices, they refer to the same device.  Otherwise, a
        # runtime error will occur since the colocation property
        # cannot be guaranteed.
        #
        # One possible improvement is to try to check for compatibility
        # of all devices in this list at import time here, which would
        # require implementing a compatibility function for device specs
        # in python.
        for coloc_op in coloc_op_list:
          if coloc_op.device:
            coloc_device = pydev.DeviceSpec.from_string(coloc_op.device)
            break
        if coloc_device:
          op._set_device(coloc_device)  # pylint: disable=protected-access

      # Treat input mappings that don't appear in the graph as an error,
      # because they are likely to be due to a typo.
      def _IsImportedNodeOutput(tensor_name):
        operation_name, output_index = _ParseTensorName(tensor_name)
        try:
          return output_index < len(name_to_op[operation_name].outputs)
        except KeyError:
          return False
      absent_input_keys = [
          k for k in frozenset(input_map.keys()).difference(used_input_keys)
          if not _IsImportedNodeOutput(k)]
      if absent_input_keys:
        raise ValueError(
            'Attempted to map inputs that were not found in graph_def: [%s]'
            % ', '.join(absent_input_keys))

      if return_elements is None:
        return None
      else:
        ret = []
        # `element_name` (not `name`) avoids shadowing the `name` parameter.
        for element_name in return_elements:
          element_name = compat.as_str(element_name)
          if ':' in element_name:
            try:
              operation_name, output_index = _ParseTensorName(element_name)
              ret.append(name_to_op[operation_name].outputs[output_index])
            except (ValueError, KeyError, IndexError):
              raise ValueError(
                  'Requested return_element %r not found in graph_def.'
                  % element_name)
          else:
            try:
              ret.append(name_to_op[element_name])
            except KeyError:
              raise ValueError(
                  'Requested return_element %r not found in graph_def.'
                  % element_name)
        return ret
    # LINT.ThenChange(//tensorflow/core/graph/graph_constructor.cc)