def tf_record_iterator(path, options=None):
  """An iterator that reads the records from a TFRecords file.

  Args:
    path: The path to the TFRecords file.
    options: (optional) A TFRecordOptions object.

  Yields:
    Strings.

  Raises:
    IOError: If `path` cannot be opened for reading.
  """
  compression_type = TFRecordOptions.get_compression_type_string(options)
  with errors.raise_exception_on_not_ok_status() as status:
    reader = pywrap_tensorflow.PyRecordReader_New(
        compat.as_bytes(path), 0, compat.as_bytes(compression_type), status)

  if reader is None:
    raise IOError("Could not open %s." % path)
  try:
    while True:
      try:
        with errors.raise_exception_on_not_ok_status() as status:
          reader.GetNext(status)
      except errors.OutOfRangeError:
        # OutOfRangeError is how the loop learns there are no more records.
        break
      yield reader.record()
  finally:
    # Close the reader on normal exhaustion, on any other error, and when the
    # consumer abandons the generator before it is exhausted.
    reader.Close()
def _initialize_handle_and_devices(self):
  """Initialize handle and devices.

  Runs under `self._initialize_lock` and returns immediately when
  `self._context_handle` is already set. Otherwise creates the context handle,
  then fills `self._context_devices` with the canonical name of every listed
  device and counts the devices whose type is "GPU" into `self._num_gpus`.
  """
  with self._initialize_lock:
    if self._context_handle is not None:
      return
    assert self._context_devices is None
    opts = pywrap_tensorflow.TF_NewSessionOptions(
        target=compat.as_bytes(""), config=self._config)
    try:
      with errors.raise_exception_on_not_ok_status() as status:
        self._context_handle = pywrap_tensorflow.TFE_NewContext(opts, status)
    finally:
      # Release the options even when context creation fails.
      pywrap_tensorflow.TF_DeleteSessionOptions(opts)
    # Store list of devices
    self._context_devices = []
    with errors.raise_exception_on_not_ok_status() as status:
      device_list = pywrap_tensorflow.TFE_ContextListDevices(
          self._context_handle, status)
    try:
      self._num_gpus = 0
      for i in range(pywrap_tensorflow.TF_DeviceListCount(device_list)):
        with errors.raise_exception_on_not_ok_status() as status:
          dev_name = pywrap_tensorflow.TF_DeviceListName(
              device_list, i, status)
        self._context_devices.append(pydev.canonical_name(dev_name))
        with errors.raise_exception_on_not_ok_status() as status:
          dev_type = pywrap_tensorflow.TF_DeviceListType(
              device_list, i, status)
        if dev_type == "GPU":
          self._num_gpus += 1
    finally:
      pywrap_tensorflow.TF_DeleteDeviceList(device_list)
def call_cpp_shape_fn(op):
  """Computes output shapes of `op` via the registered C++ shape function.

  Args:
    op: the node in the graph for which to compute output shapes.

  Returns:
    A list of TensorShape objects, one per serialized shape returned by the
    C++ shape inference call.

  Raises:
    ValueError: If the C++ shape function reported an InvalidArgumentError
      (e.g. because the input shapes have the wrong rank or are otherwise
      incompatible according to the shape function).
  """
  serialized_node_def = op.node_def.SerializeToString()
  serialized_input_shapes = []
  for input_tensor in op.inputs:
    serialized_input_shapes.append(
        input_tensor.get_shape().as_proto().SerializeToString())

  try:
    with errors.raise_exception_on_not_ok_status() as status:
      serialized_outputs = pywrap_tensorflow.RunCppShapeInference(
          serialized_node_def, serialized_input_shapes, status)
  except errors.InvalidArgumentError as err:
    raise ValueError(err.message)

  # Each returned element is a serialized TensorShapeProto.
  converted = []
  for serialized_shape in serialized_outputs:
    shape_proto = tensor_shape_pb2.TensorShapeProto.FromString(
        serialized_shape)
    converted.append(tensor_shape.TensorShape(shape_proto))
  return converted
def list_directory(dirname):
  """Lists the entries contained in a directory.

  The list is in arbitrary order. It does not contain the special entries "."
  and "..".

  Args:
    dirname: string, path to a directory

  Returns:
    [filename1, filename2, ... filenameN] as strings

  Raises:
    errors.NotFoundError if directory doesn't exist
  """
  if not is_directory(dirname):
    raise errors.NotFoundError(None, None, "Could not find directory")
  with errors.raise_exception_on_not_ok_status() as status:
    entries = []
    children = pywrap_tensorflow.GetChildren(compat.as_bytes(dirname), status)
    for child in children:
      # The wrapped call returns a vector of string; its elements should be
      # interpreted as strings, not bytes.
      entries.append(compat.as_str_any(child))
    return entries
def smart_constant_value(pred):
  """Return the bool value for `pred`, or None if `pred` had a dynamic value.

  Arguments:
    pred: A scalar, either a Python bool or tensor.

  Returns:
    True or False if `pred` has a constant boolean value, None otherwise.

  Raises:
    TypeError: If `pred` is not a Tensor or bool.
  """
  # Type checks come before the `in {0, 1}` membership test: that test hashes
  # `pred` and compares it with `==`, which is not a safe operation to perform
  # on a Tensor, and it would also shadow the bool branch entirely.
  if isinstance(pred, bool):
    pred_value = pred
  elif isinstance(pred, ops.Tensor):
    pred_value = tensor_util.constant_value(pred)
    # TODO(skyewm): consider folding this into tensor_util.constant_value when
    # _USE_C_API is removed (there may be performance and correctness bugs, so I
    # wanted to limit the change hidden behind _USE_C_API).
    # pylint: disable=protected-access
    if pred_value is None and ops._USE_C_API:
      with errors.raise_exception_on_not_ok_status() as status:
        pred_value = c_api.TF_TryEvaluateConstant_wrapper(
            pred.graph._c_graph, pred._as_tf_output(), status)
    # pylint: enable=protected-access
  elif pred in {0, 1}:  # Accept 1/0 as valid boolean values
    pred_value = bool(pred)
  else:
    # Wrap in a 1-tuple so that a tuple-valued `pred` is formatted as a whole
    # instead of being unpacked as the format arguments.
    raise TypeError("`pred` must be a Tensor, or a Python bool, or 1 or 0. "
                    "Found instead: %s" % (pred,))
  return pred_value
def __init__(self,
             allow_soft_placement=True,
             disable_detailed_stats=True,
             disable_timeline=True,
             devices=None):
  """Creates a Cluster.

  Args:
    allow_soft_placement: If True, TF will automatically fix illegal
      placements instead of erroring out if the placement isn't legal.
    disable_detailed_stats: If True, detailed statistics will not be
      available.
    disable_timeline: If True, the timeline information will not be reported.
    devices: A list of devices of type device_properties_pb2.NamedDevice.
      If None, a device list will be created based on the spec of
      the local machine.
  """
  self._tf_cluster = None
  self._generate_timeline = not disable_timeline
  with errors.raise_exception_on_not_ok_status() as status:
    if devices is not None:
      # Explicit devices: build a virtual cluster from their serialized form.
      serialized_devices = []
      for named_device in devices:
        serialized_devices.append(named_device.SerializeToString())
      self._tf_cluster = tf_cluster.TF_NewVirtualCluster(
          serialized_devices, status)
    else:
      self._tf_cluster = tf_cluster.TF_NewCluster(
          allow_soft_placement, disable_detailed_stats, status)
def capture_value(tensor_map, value, dtype, name):
  """Capture a value from outside the function, to pass in as an extra arg.

  `tensor_map` maps `ops.tensor_id(value)` to a `(value, placeholder)` pair.
  On a miss, a new graph placeholder is created (propagating resource handle
  data when the placeholder has resource dtype) and recorded in the map; on a
  hit, the stored placeholder is reused. Either way the capture is recorded on
  the tape and the placeholder is returned.
  """
  cached_entry = tensor_map.get(ops.tensor_id(value), None)
  if cached_entry is not None:
    placeholder = cached_entry[1]
  else:
    placeholder = graph_placeholder(
        dtype=dtype or value.dtype, shape=value.shape, name=name)
    if placeholder.dtype == dtypes_module.resource:
      handle_data = value._handle_data  # pylint: disable=protected-access
      placeholder._handle_data = handle_data  # pylint: disable=protected-access
      if handle_data is not None and handle_data.is_set:
        # Ensure that shapes and dtypes are propagated.
        shape_protos, types = zip(*[(pair.shape, pair.dtype)
                                    for pair in handle_data.shape_and_type])
        ranks = []
        shapes = []
        for shape_proto in shape_protos:
          if shape_proto.unknown_rank:
            ranks.append(-1)
            shapes.append(None)
          else:
            ranks.append(len(shape_proto.dim))
            shapes.append([d.size for d in shape_proto.dim])
        with errors.raise_exception_on_not_ok_status() as status:
          pywrap_tensorflow.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
              placeholder._op._graph._c_graph,  # pylint: disable=protected-access
              placeholder._as_tf_output(),  # pylint: disable=protected-access
              shapes, ranks, types, status)
    tensor_map[ops.tensor_id(value)] = (value, placeholder)
  tape.record_operation("captured_value", [placeholder], [value],
                        lambda x: [x])
  return placeholder
def __init__(self,
             server_or_cluster_def,
             job_name=None,
             task_index=None,
             protocol=None,
             config=None,
             start=True):
  """Creates a new server with the given definition.

  The `job_name`, `task_index`, and `protocol` arguments are optional, and
  override any information provided in `server_or_cluster_def`.

  Args:
    server_or_cluster_def: A `tf.train.ServerDef` or
      `tf.train.ClusterDef` protocol buffer, or a
      `tf.train.ClusterSpec` object, describing the server to be
      created and/or the cluster of which it is a member.
    job_name: (Optional.) Specifies the name of the job of which the server
      is a member. Defaults to the value in `server_or_cluster_def`, if
      specified.
    task_index: (Optional.) Specifies the task index of the server in its
      job. Defaults to the value in `server_or_cluster_def`, if specified.
      Otherwise defaults to 0 if the server's job has only one task.
    protocol: (Optional.) Specifies the protocol to be used by the server.
      Acceptable values include `"grpc"`. Defaults to the value in
      `server_or_cluster_def`, if specified. Otherwise defaults to `"grpc"`.
    config: (Optional.) A `tf.ConfigProto` that specifies default
      configuration options for all sessions that run on this server.
    start: (Optional.) Boolean, indicating whether to start the server
      after creating it. Defaults to `True`.

  Raises:
    tf.errors.OpError: Or one of its subclasses if an error occurs while
      creating the TensorFlow server.
  """
  self._server_def = _make_server_def(
      server_or_cluster_def, job_name, task_index, protocol, config)
  serialized_server_def = self._server_def.SerializeToString()
  with errors.raise_exception_on_not_ok_status() as new_server_status:
    self._server = pywrap_tensorflow.PyServer_New(
        serialized_server_def, new_server_status)
  if start:
    self.start()
def TransformGraph(input_graph_def, inputs, outputs, transforms):
  """Python wrapper for the Graph Transform Tool.

  Gives access to all graph transforms available through the command line tool.
  See documentation at https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/graph_transforms/README.md
  for full details of the options available.

  Args:
    input_graph_def: GraphDef object containing a model to be transformed.
    inputs: List of node names for the model inputs.
    outputs: List of node names for the model outputs.
    transforms: List of strings containing transform names and parameters.

  Returns:
    New GraphDef with transforms applied.
  """
  serialized_input = input_graph_def.SerializeToString()
  # Node names are comma-joined; transforms are space-joined.
  joined_inputs = compat.as_bytes(",".join(inputs))
  joined_outputs = compat.as_bytes(",".join(outputs))
  joined_transforms = compat.as_bytes(" ".join(transforms))

  with errors.raise_exception_on_not_ok_status() as status:
    serialized_output = TransformGraphWithStringInputs(
        serialized_input, joined_inputs, joined_outputs, joined_transforms,
        status)

  transformed = graph_pb2.GraphDef()
  transformed.ParseFromString(serialized_output)
  return transformed
def recursive_create_dir(dirname):
  """Creates `dirname` and every missing parent directory along the way.

  The path is split on '/' and each successively longer prefix that does not
  already exist is created in turn. Empty prefixes (such as the one produced
  by a leading '/') are skipped.

  Args:
    dirname: string, '/'-separated path of the directory to create.

  Raises:
    errors.OpError: If creating one of the directories fails. The error is
      raised for the first directory that fails; deeper levels are not
      attempted.
  """
  components = dirname.split('/')
  for depth in range(1, len(components) + 1):
    partial_dir = '/'.join(components[:depth])
    if partial_dir and not file_exists(partial_dir):
      # One status scope per directory: a failure surfaces immediately for the
      # directory that actually failed instead of after the whole loop.
      with errors.raise_exception_on_not_ok_status() as status:
        pywrap_tensorflow.CreateDir(compat.as_bytes(partial_dir), status)
def get_matching_files_v2(pattern):
  """Returns a list of files that match the given pattern(s).

  Args:
    pattern: string or iterable of strings. The glob pattern(s).

  Returns:
    A list of strings containing filenames that match the given pattern(s).

  Raises:
    errors.OpError: If there are filesystem / directory listing errors.
  """
  with errors.raise_exception_on_not_ok_status() as status:
    # Treat a single string as a one-element collection of patterns.
    if isinstance(pattern, six.string_types):
      patterns = [pattern]
    else:
      patterns = pattern
    matches = []
    for single_pattern in patterns:
      found = pywrap_tensorflow.GetMatchingFiles(
          compat.as_bytes(single_pattern), status)
      for matching_filename in found:
        # Convert the filenames to string from bytes.
        matches.append(compat.as_str_any(matching_filename))
    return matches
def GenerateCostReport(metagraph,
                       per_node_report=False,
                       verbose=False,
                       cluster=None):
  """Analyze the cost of each TensorFlow op and node in the provided metagraph.

  Args:
    metagraph: A TensorFlow MetaGraphDef.
    per_node_report: by default the report contains stats aggregated on a per op
      type basis, setting per_node_report to True adds results for each
      individual node to the report.
    verbose: Prints out the entire operation proto instead of a summary table.
    cluster: Analyze the costs using the specified cluster, or the local machine
      if no cluster was specified.

  Returns:
    A string of cost report.
  """
  if cluster is None:
    # No cluster supplied: build one with detailed stats enabled.
    cluster = gcluster.Cluster(disable_detailed_stats=False)

  with errors.raise_exception_on_not_ok_status():
    report = tf_wrap.GenerateCostReport(
        metagraph.SerializeToString(), per_node_report, verbose,
        cluster.tf_cluster)
  return report
def register_function_def(fdef):
  """Adds the serialized `fdef` to the default context.

  Args:
    fdef: object exposing `SerializeToString()`; its serialized bytes and
      their length are handed to `TFE_ContextAddFunctionDef`.
  """
  serialized_fdef = fdef.SerializeToString()
  serialized_len = len(serialized_fdef)
  with errors.raise_exception_on_not_ok_status() as status:
    # pylint: disable=protected-access
    context_handle = context.get_default_context()._handle
    # pylint: enable=protected-access
    pywrap_tensorflow.TFE_ContextAddFunctionDef(
        context_handle, serialized_fdef, serialized_len, status)
def _prereadline_check(self):
  """Lazily creates the buffered input stream used for reading.

  Does nothing when `self._read_buf` is already set.

  Raises:
    errors.PermissionDeniedError: If the read check has not passed.
  """
  if self._read_buf:
    return
  if not self._read_check_passed:
    raise errors.PermissionDeniedError(None, None,
                                       "File isn't open for reading")
  buffer_size = 1024 * 512
  with errors.raise_exception_on_not_ok_status() as status:
    self._read_buf = pywrap_tensorflow.CreateBufferedInputStream(
        compat.as_bytes(self.__name), buffer_size, status)
def testInvalidDeviceNumber(self):
  """Querying memory for an out-of-range device index must yield -1."""
  session_options = tf_session.TF_NewSessionOptions()
  with errors.raise_exception_on_not_ok_status() as status:
    session_handle = tf_session.TF_NewSession(
        ops.get_default_graph()._c_graph, session_options, status)
    device_list = tf_session.TF_SessionListDevices(session_handle, status)
  # The device count is one past the last valid index.
  out_of_range_index = tf_session.TF_DeviceListCount(device_list)
  # Test that invalid device numbers return -1 rather than a Swig-wrapped
  # pointer.
  non_raising_status = c_api_util.ScopedTFStatus()
  reported_memory = tf_session.TF_DeviceListMemoryBytes(
      device_list, out_of_range_index, non_raising_status)
  self.assertEqual(reported_memory, -1)
  tf_session.TF_DeleteDeviceList(device_list)
  with errors.raise_exception_on_not_ok_status() as status:
    tf_session.TF_CloseSession(session_handle, status)
def read(self):
  """Returns the contents of a file as a string.

  Raises:
    errors.PermissionDeniedError: If the read check has not passed.
  """
  if not self._read_check_passed:
    raise errors.PermissionDeniedError(None, None,
                                       "File isn't open for reading")
  with errors.raise_exception_on_not_ok_status() as read_status:
    contents = pywrap_tensorflow.ReadFileToString(
        compat.as_bytes(self.__name), read_status)
    return contents
def _run_fn(session, feed_dict, fetch_list, target_list, options,
            run_metadata):
  """Extends the graph, then forwards the arguments to `TF_Run`."""
  # Ensure any changes to the graph are reflected in the runtime.
  self._extend_graph()
  with errors.raise_exception_on_not_ok_status() as run_status:
    run_result = tf_session.TF_Run(session, options, feed_dict, fetch_list,
                                   target_list, run_status, run_metadata)
    return run_result
def TF_Reset(target, containers=None, config=None):
  """Calls `TF_Reset_wrapper` with freshly built session options.

  The session options created from `target` and `config` are always deleted
  afterwards, whether or not the reset succeeds.
  """
  from tensorflow.python.framework import errors
  session_options = TF_NewSessionOptions(target=target, config=config)
  try:
    with errors.raise_exception_on_not_ok_status() as reset_status:
      TF_Reset_wrapper(session_options, containers, reset_status)
  finally:
    TF_DeleteSessionOptions(session_options)
def write(self, record):
  """Write a string record to the file.

  Args:
    record: str
  """
  with errors.raise_exception_on_not_ok_status() as write_status:
    self._writer.WriteRecord(record, write_status)
def _prewrite_check(self):
  """Lazily creates the writable file used for writing.

  Does nothing when `self._writable_file` is already set.

  Raises:
    errors.PermissionDeniedError: If the write check has not passed.
  """
  if self._writable_file:
    return
  if not self._write_check_passed:
    raise errors.PermissionDeniedError(None, None,
                                       "File isn't open for writing")
  with errors.raise_exception_on_not_ok_status() as create_status:
    self._writable_file = pywrap_tensorflow.CreateWritableFile(
        compat.as_bytes(self.__name), create_status)
def close(self):
  """Closes FileIO. Should be called for the WritableFile to be flushed."""
  self._read_buf = None
  if self._writable_file:
    with errors.raise_exception_on_not_ok_status() as status:
      # Copy the status returned by Close() into `status` so that a failed
      # close is raised when the `with` block exits.
      close_result = self._writable_file.Close()
      pywrap_tensorflow.Set_TF_Status_from_Status(status, close_result)
  self._writable_file = None
def write(self, file_content):
  """Writes file_content to the file.

  Raises:
    errors.PermissionDeniedError: If the write check has not passed.
  """
  if not self._write_check_passed:
    raise errors.PermissionDeniedError(None, None,
                                       "File isn't open for writing")
  with errors.raise_exception_on_not_ok_status() as write_status:
    pywrap_tensorflow.WriteStringToFile(
        compat.as_bytes(self.__name),
        compat.as_bytes(file_content),
        write_status)
def recursive_create_dir(dirname):
  """Creates `dirname` and every missing parent directory along the way.

  The path is split on '/' and each successively longer prefix that does not
  already exist is created in turn. Empty prefixes (such as the one produced
  by a leading '/') are skipped.

  Args:
    dirname: string, '/'-separated path of the directory to create.

  Raises:
    errors.OpError: If creating one of the directories fails. The error is
      raised for the first directory that fails; deeper levels are not
      attempted.
  """
  from tensorflow.python.framework import errors
  from tensorflow.python.util import compat
  components = dirname.split('/')
  for depth in range(1, len(components) + 1):
    partial_dir = '/'.join(components[:depth])
    if partial_dir and not file_exists(partial_dir):
      # One status scope per directory: a failure surfaces immediately for the
      # directory that actually failed instead of after the whole loop.
      with errors.raise_exception_on_not_ok_status() as status:
        CreateDir(compat.as_bytes(partial_dir), status)
def __del__(self):
  """Deletes the context handle, if one was created."""
  try:
    if self._context_handle is None:
      return
    with errors.raise_exception_on_not_ok_status() as delete_status:
      pywrap_tensorflow.TFE_DeleteContext(self._context_handle, delete_status)
  except (AttributeError, TypeError):
    # Sometimes deletion during program shutdown throws exception as other
    # modules are no longer available.
    pass
def testStatusDoesNotLeak(self):
  """A failed call must not leave any ScopedTFStatus object alive."""
  try:
    with errors.raise_exception_on_not_ok_status() as status:
      pywrap_tensorflow.DeleteFile(
          compat.as_bytes("/DOES_NOT_EXIST/"), status)
  except errors.OpError:
    # The deletion is expected to fail; only the status error is swallowed so
    # that KeyboardInterrupt/SystemExit and unrelated bugs still surface.
    pass
  gc.collect()
  self.assertEqual(0, self._CountReferences(c_api_util.ScopedTFStatus))
def start(self):
  """Starts this server.

  Raises:
    tf.errors.OpError: Or one of its subclasses if an error occurs while
      starting the TensorFlow server.
  """
  with errors.raise_exception_on_not_ok_status() as start_status:
    pywrap_tensorflow.PyServer_Start(self._server, start_status)
def read(self):
  """Returns the contents of a file as a string.

  Starts reading from current position in file.
  """
  self._preread_check()
  with errors.raise_exception_on_not_ok_status() as read_status:
    # Everything between the current position and the end of the file.
    remaining = self.size() - self.tell()
    data = pywrap_tensorflow.ReadFromStream(
        self._read_buf, remaining, read_status)
    return data
def do_quantize_training_on_graphdef(input_graph, num_bits):
  """Runs `DoQuantizeTrainingOnGraphDefHelper` on a serialized `input_graph`.

  Args:
    input_graph: object exposing `SerializeToString()`.
    num_bits: forwarded unchanged to the helper.

  Returns:
    A new GraphDef parsed from the helper's serialized result.
  """
  from tensorflow.core.framework.graph_pb2 import GraphDef
  from tensorflow.python.framework import errors
  with errors.raise_exception_on_not_ok_status() as status:
    rewritten_graph = GraphDef()
    serialized_result = DoQuantizeTrainingOnGraphDefHelper(
        input_graph.SerializeToString(), num_bits, status)

  rewritten_graph.ParseFromString(serialized_result)
  return rewritten_graph
def flush(self):
  """Flushes the Writable file.

  This only ensures that the data has made its way out of the process without
  any guarantees on whether it's written to disk. This means that the
  data would survive an application crash but not necessarily an OS crash.
  """
  if not self._writable_file:
    return
  with errors.raise_exception_on_not_ok_status() as flush_status:
    pywrap_tensorflow.FlushWritableFile(self._writable_file, flush_status)
def op_attr_type(op_type, attr_name):
  """Returns the attr type for `(op_type, attr_name)`, caching the answer.

  A miss in `_op_attr_type_cache` is resolved through
  `TFE_OpNameGetAttrType` and the result is stored for later calls.
  """
  cache_key = (op_type, attr_name)
  try:
    return _op_attr_type_cache[cache_key]
  except KeyError:
    with errors.raise_exception_on_not_ok_status() as status:
      context_handle = context.context()._handle  # pylint: disable=protected-access
      resolved_type = pywrap_tensorflow.TFE_OpNameGetAttrType(
          context_handle, op_type, attr_name, status)
  _op_attr_type_cache[cache_key] = resolved_type
  return resolved_type
def __init__(self, target='', graph=None, config=None):
  """Constructs a new TensorFlow session.

  Args:
    target: (Optional) The TensorFlow execution engine to connect to.
    graph: (Optional) The graph to be used. If this argument is None,
      the default graph will be used.
    config: (Optional) ConfigProto proto used to configure the session.

  Raises:
    tf.errors.OpError: Or one of its subclasses if an error occurs while
      creating the TensorFlow session.
  """
  if graph is None:
    self._graph = ops.get_default_graph()
  else:
    self._graph = graph

  self._opened = False
  self._closed = False

  self._current_version = 0
  self._extend_lock = threading.Lock()
  self._target = target

  self._delete_lock = threading.Lock()
  self._dead_handles = []

  self._session = None
  self._config = config
  self._add_shapes = config.graph_options.infer_shapes if (
      config and config.graph_options) else False

  # Acquire the options before entering the `try`: if creating them raises,
  # the `finally` clause must not run with `opts` unbound, which would replace
  # the real error with an UnboundLocalError.
  opts = tf_session.TF_NewSessionOptions(target=target, config=config)
  try:
    with errors.raise_exception_on_not_ok_status() as status:
      self._session = tf_session.TF_NewSession(opts, status)
  finally:
    tf_session.TF_DeleteSessionOptions(opts)
def __init__(self,
             server_or_cluster_def,
             job_name=None,
             task_index=None,
             protocol=None,
             start=True):
  """Creates a new server with the given definition.

  The `job_name`, `task_index`, and `protocol` arguments are optional, and
  override any information provided in `server_or_cluster_def`.

  Args:
    server_or_cluster_def: A `tf.train.ServerDef` or
      `tf.train.ClusterDef` protocol buffer, or a
      `tf.train.ClusterSpec` object, describing the server to be
      created and/or the cluster of which it is a member.
    job_name: (Optional.) Specifies the name of the job of which the server
      is a member. Defaults to the value in `server_or_cluster_def`, if
      specified.
    task_index: (Optional.) Specifies the task index of the server in its
      job. Defaults to the value in `server_or_cluster_def`, if specified.
      Otherwise defaults to 0 if the server's job has only one task.
    protocol: (Optional.) Specifies the protocol to be used by the server.
      Acceptable values include `"grpc"`. Defaults to the value in
      `server_or_cluster_def`, if specified. Otherwise defaults to `"grpc"`.
    start: (Optional.) Boolean, indicating whether to start the server
      after creating it. Defaults to `True`.

  Raises:
    tf.errors.OpError: Or one of its subclasses if an error occurs while
      creating the TensorFlow server.
  """
  resolved_server_def = _make_server_def(
      server_or_cluster_def, job_name, task_index, protocol)
  with errors.raise_exception_on_not_ok_status() as new_server_status:
    self._server = pywrap_tensorflow.PyServer_New(
        resolved_server_def.SerializeToString(), new_server_status)
  if start:
    self.start()
def Load(self):
  """Loads all new values from disk.

  Calling Load multiple times in a row will not 'drop' events as long as the
  return value is not iterated over.

  Yields:
    All values that were written to disk that have not been yielded yet.
  """

  def _advance():
    """Moves the reader to the next record; False when none can be read."""
    try:
      with errors.raise_exception_on_not_ok_status() as status:
        self._reader.GetNext(status)
    except (errors.DataLossError, errors.OutOfRangeError):
      # We ignore partial read exceptions, because a record may be truncated.
      # PyRecordReader holds the offset prior to the failed read, so retrying
      # will succeed.
      return False
    return True

  while _advance():
    event = event_pb2.Event()
    event.ParseFromString(self._reader.record())
    yield event
  logging.debug('No more events in %s', self._file_path)
def _swig_call(self, method, request, response):
  """Calls method, serializing and deserializing inputs and outputs.

  Note that this does not check the types of request and response.

  This can throw a variety of Python errors, based upon the underlying
  tensorflow error returned in MetadataStore.
  See _CODE_TO_EXCEPTION_CLASS in tensorflow/python/framework/errors_impl.py
  for the mapping.

  Args:
    method: the method to call in SWIG.
    request: a protobuf message, serialized and sent to the method.
    response: a protobuf message, filled from the return value of the method.

  Raises:
    Error: whatever tensorflow error is returned by the method.
  """
  with errors.raise_exception_on_not_ok_status() as call_status:
    serialized_response = method(
        self._metadata_store, request.SerializeToString(), call_status)
  # Only reached when the call above left the status OK.
  response.ParseFromString(serialized_response)
def _generic_iterator(self, file_path):
  """A helper method that makes an iterator given a debug-events file path.

  Repeated calls to this method create iterators that remember the last
  successful reading position (offset) for each given `file_path`. So the
  iterators are meant for incremental reading of the file.

  Args:
    file_path: Path to the file to create the iterator for.

  Yields:
    A tuple of (offset, debug_event_proto) on each `next()` call.
  """
  # The following code uses the double-checked locking pattern to optimize
  # the common case (where the reader is already initialized).
  if file_path not in self._readers:  # 1st check, without lock.
    with self._readers_lock:
      if file_path not in self._readers:  # 2nd check, with lock.
        with errors.raise_exception_on_not_ok_status() as status:
          # TODO(b/136474806): Use tf_record.tf_record_iterator() once it
          # supports offset.
          new_reader = pywrap_tensorflow.PyRecordReader_New(
              compat.as_bytes(file_path), 0, b"", status)
          self._readers[file_path] = new_reader
  reader = self._readers[file_path]
  while True:
    # Capture the offset before advancing so that it belongs to the record
    # about to be read.
    record_offset = reader.offset()
    try:
      reader.GetNext()
    except (errors.DataLossError, errors.OutOfRangeError):
      # We ignore partial read exceptions, because a record may be truncated.
      # PyRecordReader holds the offset prior to the failed read, so retrying
      # will succeed.
      break
    parsed_event = debug_event_pb2.DebugEvent.FromString(reader.record())
    yield DebugEventWithOffset(debug_event=parsed_event, offset=record_offset)
def MeasureCosts(self, item):
  """Returns the cost of running the specified item.

  Args:
    item: The item for which to measure the costs.

  Returns:
    The triplet op_perfs, runtime, step_stats, or None when the underlying
    call returns None.
  """
  with errors.raise_exception_on_not_ok_status() as status:
    measured = tf_cluster.TF_MeasureCosts(
        item.tf_item, self._tf_cluster, self._generate_timeline, status)

  if measured is None:
    return None

  serialized_op_perfs, run_time, serialized_step_stats = measured
  op_perfs = []
  for serialized_op_perf in serialized_op_perfs:
    op_perfs.append(
        op_performance_data_pb2.OpPerformance.FromString(serialized_op_perf))
  step_stats = step_stats_pb2.StepStats.FromString(serialized_step_stats)
  return (op_perfs, run_time, step_stats)
def tf_record_iterator(path, options=None):
  """An iterator that reads the records from a TFRecords file.

  Args:
    path: The path to the TFRecords file.
    options: (optional) A TFRecordOptions object.

  Yields:
    Strings.

  Raises:
    IOError: If `path` cannot be opened for reading.
  """
  compression_type = TFRecordOptions.get_compression_type_string(options)
  with errors.raise_exception_on_not_ok_status() as status:
    reader = pywrap_tensorflow.PyRecordReader_New(
        compat.as_bytes(path), 0, compat.as_bytes(compression_type), status)

  if reader is None:
    raise IOError("Could not open %s." % path)
  try:
    while reader.GetNext():
      yield reader.record()
  finally:
    # Close the reader on normal exhaustion, on error, and when the consumer
    # abandons the generator before it is exhausted.
    reader.Close()
def _call_cpp_shape_fn_impl(op, input_tensors_needed,
                            input_tensors_as_shapes_needed, require_shape_fn):
  """Core implementation of call_cpp_shape_fn.

  Args:
    op: the node for which to compute output shapes.
    input_tensors_needed: indices into `op.inputs` whose constant values (when
      available) are passed to shape inference.
    input_tensors_as_shapes_needed: indices into `op.inputs` whose constant
      values, interpreted as shapes (when available), are passed to shape
      inference.
    require_shape_fn: if True, a missing C++ shape function raises
      RuntimeError; otherwise `unknown_shape(op)` is returned.

  Returns:
    A dict with keys "shapes", "handle_data" and "inputs_needed", or the
    result of `unknown_shape(op)` when no shape function exists and
    `require_shape_fn` is False.

  Raises:
    ValueError: If shape inference reports any other InvalidArgumentError.
    RuntimeError: If no shape function exists and `require_shape_fn` is True.
  """
  graph_def_version = op.graph.graph_def_versions.producer
  node_def_str = op.node_def.SerializeToString()

  def tensor_to_inference_result(t):
    inference_result = cpp_shape_inference_pb2.CppShapeInferenceResult()
    inference_result.shape.CopyFrom(t.get_shape().as_proto())
    # pylint: disable=protected-access
    if t._handle_data is not None:
      inference_result.handle_data.CopyFrom(t._handle_data)
    # pylint: enable=protected-access
    return inference_result.SerializeToString()

  input_shapes = [tensor_to_inference_result(i) for i in op.inputs]
  num_inputs = len(input_shapes)

  # Slots stay None unless a constant value is available for that input.
  input_tensors = [None] * num_inputs
  for idx in input_tensors_needed:
    constant = tensor_util.constant_value(op.inputs[idx])
    if constant is not None:
      input_tensors[idx] = np.asarray(constant)

  # Slots default to a serialized unknown shape.
  serialized_unknown_shape = (
      tensor_shape.TensorShape(None).as_proto().SerializeToString())
  input_tensors_as_shapes = [serialized_unknown_shape] * num_inputs
  for idx in input_tensors_as_shapes_needed:
    constant_shape = tensor_util.constant_value_as_shape(op.inputs[idx])
    if constant_shape is not None:
      input_tensors_as_shapes[idx] = (
          constant_shape.as_proto().SerializeToString())

  missing_shape_fn = False
  try:
    with errors.raise_exception_on_not_ok_status() as status:
      output = pywrap_tensorflow.RunCppShapeInference(
          graph_def_version, node_def_str, input_shapes, input_tensors,
          input_tensors_as_shapes, status)
  except errors.InvalidArgumentError as err:
    if not err.message.startswith(
        "No shape inference function exists for op"):
      raise ValueError(err.message)
    missing_shape_fn = True

  if missing_shape_fn:
    if require_shape_fn:
      raise RuntimeError(
          "No C++ shape function registered for standard op: %s" % op.type)
    return unknown_shape(op)

  # The last element of `output` is reported as "inputs_needed"; everything
  # before it is a serialized CppShapeInferenceResult.
  output_shapes = output[:-1]

  # Convert TensorShapeProto values in output_shapes.
  result_protos = [
      cpp_shape_inference_pb2.CppShapeInferenceResult().FromString(s)
      for s in output_shapes
  ]
  shapes = []
  handle_data = []
  for result_proto in result_protos:
    shapes.append(result_proto.shape)
    handle_data.append(
        result_proto.handle_data if result_proto.handle_data.is_set else None)

  return {
      "shapes": shapes,
      "handle_data": handle_data,
      "inputs_needed": output[-1]
  }
def create_dir(dirname):
  """Creates the directory `dirname` through `pywrap_tensorflow.CreateDir`.

  Args:
    dirname: path of the directory to create; converted to bytes first.
  """
  with errors.raise_exception_on_not_ok_status() as create_status:
    pywrap_tensorflow.CreateDir(compat.as_bytes(dirname), create_status)
def rename(oldname, newname, overwrite=False):
  """Renames `oldname` to `newname` through `pywrap_tensorflow.RenameFile`.

  Args:
    oldname: current path; converted to bytes first.
    newname: new path; converted to bytes first.
    overwrite: forwarded unchanged to `RenameFile`.

  Returns:
    Whatever `pywrap_tensorflow.RenameFile` returns.
  """
  with errors.raise_exception_on_not_ok_status() as rename_status:
    outcome = pywrap_tensorflow.RenameFile(
        compat.as_bytes(oldname), compat.as_bytes(newname), overwrite,
        rename_status)
    return outcome
def copy(oldpath, newpath, overwrite=False):
  """Copies `oldpath` to `newpath` through `pywrap_tensorflow.CopyFile`.

  Args:
    oldpath: source path; converted to bytes first.
    newpath: destination path; converted to bytes first.
    overwrite: forwarded unchanged to `CopyFile`.
  """
  with errors.raise_exception_on_not_ok_status() as copy_status:
    pywrap_tensorflow.CopyFile(
        compat.as_bytes(oldpath), compat.as_bytes(newpath), overwrite,
        copy_status)
def _create_definition_if_needed_impl(self):
  """This is not what you want, see _create_definition_if_needed.

  Builds the function definition once: returns immediately if either
  `self._definition` or `self._c_func` is already set. Otherwise traces
  `self._func` inside a temporary `_FuncGraph` and then takes one of two
  paths depending on whether the temporary graph has a `_c_graph`:

  * without it, a FunctionDef is built in Python, hashed, and named;
  * with it, the function is created via `TF_GraphToFunction_wrapper` and the
    cached `_op_def` / `_func_name` are read back from `self.definition`.

  Raises:
    ValueError: If the traced function returns a list/tuple containing None.
  """
  if self._definition is not None or self._c_func is not None:
    return

  # Create the func_def object.
  temp_graph = _FuncGraph(capture_by_value=self._capture_by_value)
  with temp_graph.as_default():
    # List of placeholders for the function_def.
    inputs = []
    for (argname, argtype) in self._args:
      argholder = array_ops.placeholder(argtype, name=argname)
      inputs.append(argholder)
    # Call func and gather the output tensors.
    with vs.variable_scope("", custom_getter=temp_graph.getvar):
      outputs = self._func(*inputs)

    # There is no way of distinguishing between a function not returning
    # anything and a function returning None in Python.
    # We need to allow the former and ideally want to forbid the latter as
    # it is most likely user error.
    # TODO(iga): Consider adding a @NoOutput decorator on top of @Defun to
    # allow users to explicitly mark the function as not returning anything.
    # For now, we allow a single None return and interpret it as a function
    # with no output.
    if outputs is None:
      outputs = []
    else:
      # If func only returned one value, make it a tuple.
      if not isinstance(outputs, (list, tuple)):
        outputs = (outputs,)
      if any([_ is None for _ in outputs]):
        raise ValueError("Function can not return None.")
    # Ensures each output is a Tensor.
    outputs = [ops.convert_to_tensor(_) for _ in outputs]
  # Inputs captured by the temporary graph become extra function arguments.
  self._extra_inputs = temp_graph.extra_inputs
  inputs.extend(temp_graph.extra_args)
  # pylint: disable=protected-access
  self._sub_functions = temp_graph._functions
  # pylint: enable=protected-access

  # Extra kwargs are treated as attrs on the function def.
  base_func_name = self._func_name or _get_func_name(self._func)
  kwargs_attr = _parse_kwargs_as_attrs(base_func_name,
                                       **self._extra_kwargs)

  if not temp_graph._c_graph:  # pylint: disable=protected-access
    # Build the FunctionDef
    self._definition = graph_to_function_def.graph_to_function_def(
        temp_graph,
        temp_graph.get_operations(),
        inputs,
        outputs,
        out_names=self._out_names)

    for k in kwargs_attr:
      self._definition.attr[k].CopyFrom(kwargs_attr[k])

    # Hash the definition and its dependencies.
    self._hash_str = self._create_hash_str(
        self._definition.signature.input_arg,
        self._definition.signature.output_arg, self._definition.node_def)

    # Finally, we decide the function name to use.  If not specified,
    # make up something which is almost certainly unique (but deterministic).
    if not self._func_name:
      self._func_name = "_".join([base_func_name, self._hash_str])
    self._definition.signature.name = self._func_name
    if self._func.__doc__:
      self._definition.signature.description = self._func.__doc__

    self._op_def = self._definition.signature
  else:  # C API is enabled
    output_names = ([compat.as_bytes(x) for x in self._out_names]
                    if self._out_names else [])
    description = self._func.__doc__ or None
    # pylint: disable=protected-access
    with errors.raise_exception_on_not_ok_status() as status:
      self._c_func = c_api.TF_GraphToFunction_wrapper(
          temp_graph._c_graph,
          base_func_name,
          self._func_name is None,  # append_hash_to_fn_name
          None,  # opers
          [t._as_tf_output() for t in inputs],
          [t._as_tf_output() for t in outputs],
          output_names,
          None,  # opts
          description,
          status)
    # pylint: enable=protected-access
    self._set_c_attrs(kwargs_attr)

    # Set cached fields: _op_def and _func_name (if not already set)
    self._op_def = self.definition.signature
    if self._func_name:
      # An explicitly requested name must match the name on the signature.
      assert self._func_name == self._op_def.name
    else:
      self._func_name = compat.as_str(self._op_def.name)
def __del__(self):
  """Closes the session and deletes its handle, if one exists."""
  self.close()
  if self._session is None:
    return
  with errors.raise_exception_on_not_ok_status() as delete_status:
    tf_session.TF_DeleteSession(self._session, delete_status)
  # Only reached when the deletion above left the status OK.
  self._session = None
def _create_definition_if_needed_impl(self):
  """Builds `self._definition` once; use _create_definition_if_needed instead.

  Traces `self._func` inside a temporary `_FuncGraph`, converts that graph to
  a FunctionDef, attaches the extra kwargs as attrs, hashes the definition and
  settles on the function name. When the temporary graph has a `_c_graph`, a
  C function is also created from it via `TF_GraphToFunction_wrapper`.

  Returns immediately if `self._definition` is already set.

  Raises:
    ValueError: If any output returned by `self._func` is None.
  """
  if self._definition is not None:
    return

  # Trace the user function into a scratch graph.
  temp_graph = _FuncGraph(capture_by_value=self._capture_by_value)
  with temp_graph.as_default():
    # One placeholder per declared (name, dtype) argument.
    inputs = [
        array_ops.placeholder(arg_dtype, name=arg_name)
        for (arg_name, arg_dtype) in self._args
    ]
    # Run the function and collect what it returns.
    with vs.variable_scope("", custom_getter=temp_graph.getvar):
      outputs = self._func(*inputs)
    # Normalize a single return value into a 1-tuple.
    if not isinstance(outputs, (list, tuple)):
      outputs = (outputs,)
    if any(out is None for out in outputs):
      raise ValueError("Function can not return None.")
    # Every output must be a Tensor.
    outputs = [ops.convert_to_tensor(out) for out in outputs]

  self._extra_inputs = temp_graph.extra_inputs
  inputs.extend(temp_graph.extra_args)
  self._sub_functions = temp_graph._functions  # pylint: disable=protected-access

  # Convert the traced graph into a FunctionDef.
  self._definition = graph_to_function_def.graph_to_function_def(
      temp_graph,
      temp_graph.get_operations(),
      inputs,
      outputs,
      out_names=self._out_names)

  # Extra kwargs become attrs on the function definition.
  base_name = self._func_name or _get_func_name(self._func)
  kwargs_attr = _parse_kwargs_as_attrs(base_name, **self._extra_kwargs)
  for attr_key in kwargs_attr:
    self._definition.attr[attr_key].CopyFrom(kwargs_attr[attr_key])

  # Hash the definition together with its dependencies.
  self._hash_str = self._create_hash_str(
      self._definition.signature.input_arg,
      self._definition.signature.output_arg,
      self._definition.node_def)

  # Choose the final name. Without an explicit name, derive one from the
  # Python function name plus the hash (deterministic, almost surely unique).
  if not self._func_name:
    self._func_name = "_".join([_get_func_name(self._func), self._hash_str])
  self._definition.signature.name = self._func_name
  if self._func.__doc__:
    self._definition.signature.description = self._func.__doc__

  # pylint: disable=protected-access
  if temp_graph._c_graph:
    if self._out_names:
      output_names = [compat.as_bytes(x) for x in self._out_names]
    else:
      output_names = []
    description = self._func.__doc__ or None
    with errors.raise_exception_on_not_ok_status() as status:
      self._c_func = c_api.TF_GraphToFunction_wrapper(
          temp_graph._c_graph,
          self._func_name,
          False,  # append_hash_to_fn_name
          None,  # opers
          [t._as_tf_output() for t in inputs],
          [t._as_tf_output() for t in outputs],
          output_names,
          None,  # opts
          description,
          status)
    self._set_c_attrs(kwargs_attr)
  # pylint: enable=protected-access
def read_file_to_string(filename):
  """Returns the result of `ReadFileToString` for `filename`.

  Args:
    filename: Path of the file to read; converted with `compat.as_bytes`.

  Returns:
    Whatever `pywrap_tensorflow.ReadFileToString` returns for the file.

  Raises:
    An exception derived from the TF status if the call reports a non-OK
    status.
  """
  path_bytes = compat.as_bytes(filename)
  with errors.raise_exception_on_not_ok_status() as status:
    contents = pywrap_tensorflow.ReadFileToString(path_bytes, status)
  return contents
def __del__(self):
  """Deletes the eager context behind `self._handle`, if there is one.

  Raises:
    An exception derived from the TF status if `TFE_DeleteContext` reports a
    non-OK status.
  """
  handle = self._handle
  if handle is None:
    # No context was created, so there is nothing to delete.
    return
  with errors.raise_exception_on_not_ok_status() as status:
    pywrap_tensorflow.TFE_DeleteContext(handle, status)
def _BuildTFItem(self):
  """Creates `self._tf_item` from the serialized metagraph.

  Passes the serialized `self._metagraph` together with the
  `_ignore_colocation` and `_ignore_user_placement` flags to `TF_NewItem`.

  Raises:
    An exception derived from the TF status if `TF_NewItem` reports a non-OK
    status.
  """
  serialized_metagraph = self._metagraph.SerializeToString()
  with errors.raise_exception_on_not_ok_status() as status:
    new_item = tf_item.TF_NewItem(
        serialized_metagraph,
        self._ignore_colocation,
        self._ignore_user_placement,
        status)
    self._tf_item = new_item
def write(self, file_content):
  """Appends `file_content` to the end of the file.

  Args:
    file_content: Data to append; converted with `compat.as_bytes`.

  Raises:
    An exception derived from the TF status if `AppendToFile` reports a
    non-OK status.
  """
  self._prewrite_check()
  payload = compat.as_bytes(file_content)
  with errors.raise_exception_on_not_ok_status() as status:
    pywrap_tensorflow.AppendToFile(payload, self._writable_file, status)
def _create_offset_reader(self, file_path, offset):
  """Creates a record reader for `file_path` starting at `offset`.

  Args:
    file_path: Path passed unchanged to `PyRecordReader_New`.
    offset: Start offset passed unchanged to `PyRecordReader_New`.

  Returns:
    The reader object returned by `PyRecordReader_New`, created with an empty
    compression-type string.

  Raises:
    An exception derived from the TF status if reader creation reports a
    non-OK status.
  """
  # TODO(b/136474806): Use tf_record.tf_record_iterator() once it
  # supports offset.
  with errors.raise_exception_on_not_ok_status() as status:
    reader = pywrap_tensorflow.PyRecordReader_New(
        file_path, offset, b"", status)
  return reader
def seek(self, position):
  """Moves the read buffer to `position` in the file.

  Args:
    position: Target position passed to the read buffer's `Seek`.

  Raises:
    An exception derived from the TF status if the status returned by `Seek`
    is not OK.
  """
  self._preread_check()
  with errors.raise_exception_on_not_ok_status() as status:
    # Seek returns its own status object; copy it into the TF_Status that is
    # checked when this block exits.
    seek_result = self._read_buf.Seek(position)
    pywrap_tensorflow.Set_TF_Status_from_Status(status, seek_result)
def get_matching_files(filename):
  """Returns the result of `GetMatchingFiles` for `filename`.

  Args:
    filename: Path or pattern to match; converted with `compat.as_bytes`.

  Returns:
    Whatever `pywrap_tensorflow.GetMatchingFiles` returns.

  Raises:
    An exception derived from the TF status if the call reports a non-OK
    status.
  """
  pattern_bytes = compat.as_bytes(filename)
  with errors.raise_exception_on_not_ok_status() as status:
    matches = pywrap_tensorflow.GetMatchingFiles(pattern_bytes, status)
  return matches
# setup low level args for TF_Run call session = sess._session options=None feed_dict = {} # uncomment lines below if you want to fetch things fetch_list = [b'MatMul_2:0'] target_list = [] if len(sys.argv)>1 and 'nofetch' in sys.argv[1]: fetch_list=[] target_list=[b'MatMul_2'] run_metadata = None status_orig = errors.raise_exception_on_not_ok_status() status = pywrap_tensorflow.TF_NewStatus() def fast_tf(): return tf_session.TF_Run(session, options, feed_dict, fetch_list, target_list, status, run_metadata) num_iters = 5000 warmup_iters = 2 iter_times = np.zeros((num_iters+warmup_iters,)) y = create_graph() for i in range(num_iters+warmup_iters): iter_start = time.time() if i == warmup_iters: start_time = time.time()
def write_string_to_file(filename, file_content):
  """Writes `file_content` to `filename` via `WriteStringToFile`.

  Args:
    filename: Destination path; converted with `compat.as_bytes`.
    file_content: Data to write; converted with `compat.as_bytes`.

  Raises:
    An exception derived from the TF status if the call reports a non-OK
    status.
  """
  path_bytes = compat.as_bytes(filename)
  content_bytes = compat.as_bytes(file_content)
  with errors.raise_exception_on_not_ok_status() as status:
    pywrap_tensorflow.WriteStringToFile(path_bytes, content_bytes, status)
def flush(self):
  """Flushes the underlying writer.

  Raises:
    An exception derived from the TF status if `Flush` reports a non-OK
    status.
  """
  writer = self._writer
  with errors.raise_exception_on_not_ok_status() as status:
    writer.Flush(status)
def op_attr_type(op_type, attr_name):
  """Looks up the type of attr `attr_name` on ops of type `op_type`.

  A temporary eager op of type `op_type` is created on the current context's
  handle and queried with `TFE_OpGetAttrType`.

  Each C call gets its own status check. The status is only inspected when
  its `with` block exits, so with a single shared status a failed `TFE_NewOp`
  would still hand its result to `TFE_OpGetAttrType`, and the second call
  could overwrite the first error.

  Args:
    op_type: Name of the op type, passed to `TFE_NewOp`.
    attr_name: Name of the attr, passed to `TFE_OpGetAttrType`.

  Returns:
    The value returned by `TFE_OpGetAttrType`.

  Raises:
    An exception derived from the TF status if either `TFE_NewOp` or
    `TFE_OpGetAttrType` reports a non-OK status.
  """
  h = context.context()._handle  # pylint: disable=protected-access
  with errors.raise_exception_on_not_ok_status() as status:
    op = pywrap_tensorflow.TFE_NewOp(h, op_type, status)
  # Only reached when the op was created successfully.
  with errors.raise_exception_on_not_ok_status() as status:
    attr_type = pywrap_tensorflow.TFE_OpGetAttrType(op, attr_name, status)
  return attr_type
def import_graph_def(graph_def, input_map=None, return_elements=None,
                     name=None, op_dict=None, producer_op_list=None):
  """Imports the graph from `graph_def` into the current default `Graph`.

  This function provides a way to import a serialized TensorFlow
  [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
  protocol buffer, and extract individual objects in the `GraphDef` as
  @{tf.Tensor} and @{tf.Operation} objects. Once extracted,
  these objects are placed into the current default `Graph`. See
  @{tf.Graph.as_graph_def} for a way to create a `GraphDef`
  proto.

  Two implementations live in this function: when the default graph has a
  `_c_graph`, the import is delegated to the C API
  (`TF_GraphImportGraphDefWithResults`); otherwise the nodes are re-created
  one by one in Python.

  Args:
    graph_def: A `GraphDef` proto containing operations to be imported into
      the default graph.
    input_map: A dictionary mapping input names (as strings) in `graph_def`
      to `Tensor` objects. The values of the named input tensors in the
      imported graph will be re-mapped to the respective `Tensor` values.
    return_elements: A list of strings containing operation names in
      `graph_def` that will be returned as `Operation` objects; and/or
      tensor names in `graph_def` that will be returned as `Tensor` objects.
    name: (Optional.) A prefix that will be prepended to the names in
      `graph_def`. Note that this does not apply to imported function names.
      Defaults to `"import"`.
    op_dict: (Optional.) Deprecated, do not use. Any value passed is ignored:
      it is overwritten with the registered op definitions on entry.
    producer_op_list: (Optional.) An `OpList` proto with the (possibly stripped)
      list of `OpDef`s used by the producer of the graph. If provided,
      unrecognized attrs for ops in `graph_def` that have their default value
      according to `producer_op_list` will be removed. This will allow some more
      `GraphDef`s produced by later binaries to be accepted by earlier binaries.

  Returns:
    A list of `Operation` and/or `Tensor` objects from the imported graph,
    corresponding to the names in `return_elements`. `None` is returned when
    `return_elements` (after parameter processing) is `None`.

  Raises:
    TypeError: If `graph_def` is not a `GraphDef` proto,
      `input_map` is not a dictionary mapping strings to `Tensor` objects,
      or `return_elements` is not a list of strings.
    ValueError: If `input_map`, or `return_elements` contains names that
      do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
      it refers to an unknown tensor).
  """
  # The caller-supplied `op_dict` is deliberately discarded here.
  op_dict = op_def_registry.get_registered_ops()

  graph_def = _ProcessGraphDefParam(graph_def, op_dict)
  input_map = _ProcessInputMapParam(input_map)
  return_elements = _ProcessReturnElementsParam(return_elements)

  if producer_op_list is not None:
    # TODO(skyewm): make a copy of graph_def so we're not mutating the argument?
    _RemoveDefaultAttrs(op_dict, producer_op_list, graph_def)

  graph = ops.get_default_graph()

  # --- C API import path ---
  if graph._c_graph:  # pylint: disable=protected-access
    with ops.name_scope(name, 'import', input_map.values()) as scope:
      # Save unique prefix generated by name_scope
      if scope:
        assert scope.endswith('/')
        prefix = scope[:-1]
      else:
        prefix = ''

      # Generate any input map tensors inside name scope
      input_map = _ConvertInputMapValues(name, input_map)

    scoped_options = c_api_util.ScopedTFImportGraphDefOptions()
    options = scoped_options.options
    _PopulateTFImportGraphDefOptions(options, prefix, input_map,
                                     return_elements)

    # _ProcessNewOps mutates the new operations. _lock ensures a Session.run
    # call cannot occur between creating the TF_Operations in the
    # TF_GraphImportGraphDefWithResults call and mutating them in
    # _ProcessNewOps.
    with graph._lock:  # pylint: disable=protected-access
      with c_api_util.tf_buffer(
          graph_def.SerializeToString()) as serialized:
        try:
          with errors.raise_exception_on_not_ok_status() as status:
            results = c_api.TF_GraphImportGraphDefWithResults(
                graph._c_graph, serialized, options, status)  # pylint: disable=protected-access
        except errors.InvalidArgumentError as e:
          # Convert to ValueError for backwards compatibility.
          raise ValueError(str(e))

      _ProcessNewOps(graph)

    # Create _DefinedFunctions for any imported functions.
    #
    # We do this by creating _DefinedFunctions directly from `graph_def`, and
    # adding them to `graph`. Adding an existing function to a TF_Graph is a
    # no-op, so this only has the effect of updating the Python state (usually
    # _DefinedFunction.add_to_graph also adds the function to the TF_Graph).
    #
    # TODO(skyewm): fetch the TF_Functions directly from the TF_Graph
    # TODO(skyewm): avoid sending serialized FunctionDefs back to the TF_Graph
    if graph_def.library and graph_def.library.function:
      # pylint: disable=protected-access
      functions = function._from_library(graph_def.library)
      for f in functions:
        f.add_to_graph(graph)
      # pylint: enable=protected-access

    # Treat input mappings that don't appear in the graph as an error, because
    # they are likely to be due to a typo.
    missing_unused_input_keys = (
        c_api.TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper(
            results))
    if missing_unused_input_keys:
      missing_unused_input_keys = [
          compat.as_str(s) for s in missing_unused_input_keys
      ]
      raise ValueError(
          'Attempted to map inputs that were not found in graph_def: [%s]' %
          ', '.join(missing_unused_input_keys))

    if return_elements is None:
      return None
    else:
      return _GatherReturnElements(return_elements, graph, results)

  # --- Pure-Python import path ---
  else:
    g = graph

    # Use a canonical representation for all tensor names.
    input_map = {_CanonicalInputName(k): v for k, v in input_map.items()}
    used_input_keys = set()
    # Maps each node name from `graph_def` to the Operation created for it.
    name_to_op = {}

    # Add any functions defined in `graph_def` to `g`
    if graph_def.library and graph_def.library.function:
      # Copy op_dict so we don't clobber the original
      op_dict = copy.copy(op_dict)
      # pylint: disable=protected-access
      # Note that we do not prepend `name` to the function name. The reasoning
      # is that function names are similar to op definition names, which
      # currently do not have a scoped name or namespace scheme.
      functions = function._from_library(graph_def.library)
      for f in functions:
        f.add_to_graph(g)
        op_dict[f.name] = f.definition.signature
      # pylint: enable=protected-access

    # LINT.IfChange
    with ops.name_scope(name, 'import', input_map.values()) as scope:
      # TODO(ashankar): Should this just copy over or should it do some
      # more nuanced merging? For example, the graph may already have some
      # marked "bad versions" and we don't want to lose those because of
      # what's in graph_def.versions? The C++ ImportGraphDef does something
      # more nuanced.
      g.graph_def_versions.CopyFrom(graph_def.versions)

      input_map = _ConvertInputMapValues(name, input_map)

      # NOTE(mrry): We do this in two passes, because there may be a cycle in
      # `graph_def`.

      # 1. Add operations without their inputs.
      for node in graph_def.node:
        # Check to see if this op's name matches a previously seen op
        if node.name in name_to_op:
          raise ValueError('Duplicate name \'%s\' in GraphDef.' % node.name)
        if node.op not in op_dict:
          raise ValueError('No op named %s in defined operations.' % node.op)
        op_def = op_dict[node.op]

        output_types = _OutputTypes(node, op_dict)
        name_to_op[node.name] = g.create_op(node.op, [], output_types,
                                            name=node.name, attrs=node.attr,
                                            compute_shapes=False,
                                            compute_device=False,
                                            op_def=op_def)

      # Maps from a node to the ops it is colocated with, if colocation
      # is specified in the attributes.
      colocation_pairs = collections.defaultdict(list)

      # 2. Add inputs to the operations.
      for node in graph_def.node:
        op = name_to_op[node.name]
        input_types = _InputTypes(node, op_dict)
        apply_device_function = True

        # Rewrite the colocation attributes in the graph, since the
        # names of new ops may have changed.
        for key, value in op.node_def.attr.items():
          if key == '_class':
            class_values = value.list
            new_class_values = []
            for class_value in class_values.s:
              if class_value.startswith(b'loc:@'):
                # Strip the 5-byte 'loc:@' prefix to get the op name.
                op_to_bind_to = class_value[5:].decode()
                # Find the op by its original name.
                if op_to_bind_to not in name_to_op:
                  raise ValueError(
                      'Specified colocation to an op that '
                      'does not exist during import: %s in %s' %
                      (op_to_bind_to, node.name))
                original_op = name_to_op[op_to_bind_to]
                new_class_values.append(
                    compat.as_bytes('loc:@' + original_op.name))
                if op_to_bind_to != node.name:
                  # Keep track of this mapping for a later phase.
                  colocation_pairs[op].append(original_op)
                  # Don't apply this op's device function,
                  # the colocation constraint will ensure
                  # the proper device gets assigned at runtime.
                  apply_device_function = False

              else:
                new_class_values.append(class_value)
            value.list.CopyFrom(
                attr_value_pb2.AttrValue.ListValue(
                    s=new_class_values))

        # NOTE(mrry): We cannot use zip here because control inputs do not
        # appear in the list of input_types.
        for i, input_name in enumerate(
            [_CanonicalInputName(x) for x in node.input]):

          if _IsControlInput(input_name):
            # (a) Input is a control input that should be taken from an op
            #     in "graph_def".
            try:
              # input_name[1:] drops the leading control-input marker.
              source_op = name_to_op[input_name[1:]]
            except KeyError:
              raise ValueError(
                  _InvalidNodeMessage(
                      node,
                      'Control input %r not found in graph_def.'
                      % (input_name, )))
            # pylint: disable=protected-access
            op._add_control_input(source_op)
            # pylint: enable=protected-access

          else:
            try:
              input_type = input_types[i]
            except IndexError:
              raise ValueError(
                  _InvalidNodeMessage(
                      node,
                      'More inputs specified (%r) than the op expects.'
                      % (input_name, )))

            if input_name in input_map:
              # (b) Input should be replaced by a tensor from the caller.
              source_tensor = input_map[input_name]
              used_input_keys.add(input_name)

            else:
              # (c) Input should be taken from an op in `graph_def`.
              operation_name, output_index = _ParseTensorName(
                  input_name)
              try:
                source_op = name_to_op[operation_name]
                source_tensor = list(
                    source_op.values())[output_index]
              except (KeyError, IndexError):
                raise ValueError(
                    _InvalidNodeMessage(
                        node,
                        'Input tensor %r not found in graph_def.'
                        % (input_name, )))

            try:
              # pylint: disable=protected-access
              op._add_input(source_tensor, dtype=input_type)
              # pylint: enable=protected-access
            except TypeError as te:
              raise ValueError(
                  _InvalidNodeMessage(
                      node, 'Input tensor %r %s' % (input_name, te)))

        # pylint: disable=protected-access
        if op._input_types != input_types:
          raise ValueError(
              _InvalidNodeMessage(
                  node,
                  'Input types mismatch (expected %r but got %r)'
                  % (', '.join(
                      dtypes.as_dtype(x).name for x in input_types),
                     ', '.join(
                         x.name for x in op._input_types))))
        # pylint: enable=protected-access

        if not g._is_function(op.type):  # pylint: disable=protected-access
          # Execute shape inference for this op.
          # NOTE(mrry): If the graph contains a cycle, the full shape
          # information may not be available for this op's inputs.
          ops.set_shapes_for_outputs(op)
        # For nodes with _output_shapes set, set the output shapes.
        if '_output_shapes' in op.node_def.attr:
          for i, output in enumerate(op.outputs):
            dims = op.node_def.attr['_output_shapes'].list.shape[i]
            # Negative dimension sizes are mapped to None.
            output_shape = tensor_shape.TensorShape(
                None if dims.unknown_rank else [
                    dim.size if dim.size >= 0 else None
                    for dim in dims.dim
                ])

            try:
              output.set_shape(output_shape)
            except ValueError as e:
              # If the output shape is incompatible with what is inferred
              # by the graph for a very specific whitelist of ops, then we
              # ignore this output shape. This can happen if there is a
              # bug in the shape function for some operation, and the
              # serialized graph def has the incorrect shape set when
              # running on a newer binary with the fixed shape function.
              # This is an escape hatch that allows us to correct shape
              # functions that are not critical to correct execution but
              # would cause graphs to fail if imported after correcting.
              #
              # This can be removed after 2017/03/08.
              if op.type in [
                  'RandomShuffleQueue', 'PaddingFIFOQueue', 'FIFOQueue',
                  'PriorityQueue', 'QueueSize', 'Stack', 'Barrier',
                  'BarrierReadySize', 'BarrierIncompleteSize', 'HashTable',
                  'MutableHashTable', 'MutableHashTableOfTensors', 'Mutex',
                  'CuckooTable', 'IndexTable', 'WholeFileReader',
                  'TextLineReader', 'FixedLengthRecordReader',
                  'TFRecordReader', 'IdentityReader', 'LMDBReader',
                  'RefSwitch', 'RefEnter', 'RefNextIteration', 'RefMerge',
                  'RefIdentity'
              ]:
                pass
              elif op.type in [
                  'ConditionalAccumulator', 'SparseConditionalAccumulator',
                  'Table'
              ]:
                # This can be removed after 2017/04/24.
                pass
              else:
                raise e

          del op.node_def.attr['_output_shapes']

        # NOTE(mrry): We do this after configuring the inputs, because
        # the result of the device functions may depend on the inputs.
        if apply_device_function:
          with _MaybeDevice(node.device):
            g._apply_device_functions(op)  # pylint: disable=protected-access

      # The following loop populates the device field of ops that are
      # colocated with another op. This is implied by the colocation
      # attribute, but we propagate the device field for completeness.
      for op, coloc_op_list in colocation_pairs.items():
        coloc_device = None
        # Find any device in the list of colocated ops that have a
        # device, if it exists. We assume that if multiple ops
        # have devices, they refer to the same device. Otherwise, a
        # runtime error will occur since the colocation property
        # cannot be guaranteed.
        #
        # One possible improvement is to try to check for compatibility
        # of all devices in this list at import time here, which would
        # require implementing a compatibility function for device specs
        # in python.
        for coloc_op in coloc_op_list:
          if coloc_op.device:
            coloc_device = pydev.DeviceSpec.from_string(
                coloc_op.device)
            break
        if coloc_device:
          op._set_device(coloc_device)  # pylint: disable=protected-access

      # Treat input mappings that don't appear in the graph as an error,
      # because they are likely to be due to a typo.
      def _IsImportedNodeOutput(tensor_name):
        # True when `tensor_name` names an existing output of an imported op.
        operation_name, output_index = _ParseTensorName(tensor_name)
        try:
          return output_index < len(
              name_to_op[operation_name].outputs)
        except KeyError:
          return False

      absent_input_keys = [
          k for k in frozenset(input_map.keys()).difference(
              used_input_keys) if not _IsImportedNodeOutput(k)
      ]
      if absent_input_keys:
        raise ValueError(
            'Attempted to map inputs that were not found in graph_def: [%s]'
            % ', '.join(absent_input_keys))

      if return_elements is None:
        return None
      else:
        ret = []
        for name in return_elements:
          name = compat.as_str(name)
          # A ':' marks a tensor name; anything else is an operation name.
          if ':' in name:
            try:
              operation_name, output_index = _ParseTensorName(
                  name)
              ret.append(name_to_op[operation_name].
                         outputs[output_index])
            except (ValueError, KeyError, IndexError):
              raise ValueError(
                  'Requested return_element %r not found in graph_def.'
                  % name)
          else:
            try:
              ret.append(name_to_op[name])
            except KeyError:
              raise ValueError(
                  'Requested return_element %r not found in graph_def.'
                  % name)
        return ret
def _prun_fn(session, handle, feed_dict, fetch_list):
  """Runs one partial-run step through `TF_PRun`.

  Reads `target_list` from the enclosing scope and rejects a non-empty one.

  Returns:
    The value returned by `TF_PRun`.

  Raises:
    RuntimeError: If the enclosing `target_list` is non-empty.
    An exception derived from the TF status if `TF_PRun` reports a non-OK
    status.
  """
  if target_list:
    raise RuntimeError('partial_run() requires empty target_list.')
  with errors.raise_exception_on_not_ok_status() as status:
    fetched = tf_session.TF_PRun(session, handle, feed_dict, fetch_list,
                                 status)
  return fetched
def _setup_fn(session, feed_list, fetch_list, target_list):
  """Extends the graph, then sets up a partial run through `TF_PRunSetup`.

  Uses `self` from the enclosing scope to call `_extend_graph()` first.

  Returns:
    The value returned by `TF_PRunSetup`.

  Raises:
    An exception derived from the TF status if `TF_PRunSetup` reports a
    non-OK status.
  """
  self._extend_graph()
  with errors.raise_exception_on_not_ok_status() as status:
    setup_result = tf_session.TF_PRunSetup(session, feed_list, fetch_list,
                                           target_list, status)
  return setup_result
def close(self):
  """Closes the underlying writer.

  Raises:
    An exception derived from the TF status if `Close` reports a non-OK
    status.
  """
  writer = self._writer
  with errors.raise_exception_on_not_ok_status() as status:
    writer.Close(status)
def call_cpp_shape_fn(op, input_tensors_needed=None,
                      input_tensors_as_shapes_needed=None,
                      debug_python_shape_fn=None,
                      require_shape_fn=True):
  """Computes output shapes for `op` using its registered C++ shape function.

  Args:
    op: the node in the graph for which to compute output shapes.
    input_tensors_needed: a list of input tensor indices whose constant
      values (when they can be determined) are passed to the C++ shape
      function.
    input_tensors_as_shapes_needed: a list of input tensor indices whose
      constant_value_as_shape (when it can be determined) is passed to the
      C++ shape function.
    debug_python_shape_fn: For testing only during migration to
      call_cpp_shape_fn. Do not submit calls that set this, as the comparison
      is slow. If non-None, this python shape function is also called and its
      output is compared with that of the C++ shape function.
    require_shape_fn: If true and no C++ shape function is registered for the
      op, an exception is raised; if false, the result of `unknown_shape(op)`
      is returned in that case.

  Returns:
    A dictionary with the following keys:
      shapes: the output shapes of the op, as computed by the C++ shape
        inference function registered for the op.
      handle_shapes: the shapes for handle outputs, if any.
      handle_dtypes: the DataType enums for the handle outputs, if any.

  Raises:
    ValueError: If the C++ shape function returned an error (e.g. because the
      shapes of the inputs are of the wrong rank or otherwise incompatible
      according to the shape function), or if `debug_python_shape_fn` yields
      shapes that differ from the C++ result.
    RuntimeError: If the C++ shape function is not registered and
      <require_shape_fn> is True.
    AssertionError: If `debug_python_shape_fn` raises while the C++ shape
      function succeeded.
  """
  if op.type == "Const":
    # Constants are handled here, even though a C++ shape function exists,
    # so that large constant values do not have to be serialized. This can
    # go away once Python calls the C / C-API directly.
    const_shape = tensor_shape.TensorShape(op.get_attr("value").tensor_shape)
    unknown_handle_shape = tensor_shape.TensorShape(None).as_proto()
    return {
        "shapes": [const_shape],
        "handle_shapes": [unknown_handle_shape],
        "handle_dtypes": [types_pb2.DT_INVALID]
    }

  node_def_str = op.node_def.SerializeToString()

  def _serialize_input(tensor):
    # Packs the shape and handle data of one input into a serialized proto.
    proto = cpp_shape_inference_pb2.CppShapeInferenceResult()
    proto.shape.CopyFrom(tensor.get_shape().as_proto())
    # pylint: disable=protected-access
    proto.handle_shape.CopyFrom(tensor._handle_shape)
    proto.handle_dtype = tensor._handle_dtype
    # pylint: enable=protected-access
    return proto.SerializeToString()

  input_shapes = [_serialize_input(t) for t in op.inputs]
  num_inputs = len(input_shapes)

  # Constant values of the requested inputs; None where unknown/unrequested.
  input_tensors = [None] * num_inputs
  if input_tensors_needed:
    for idx in input_tensors_needed:
      const_value = tensor_util.constant_value(op.inputs[idx])
      if const_value is not None:
        input_tensors[idx] = np.asarray(const_value)

  # Shapes derived from the requested inputs; unknown shape by default.
  unknown_shape_bytes = (
      tensor_shape.TensorShape(None).as_proto().SerializeToString())
  input_tensors_as_shapes = [unknown_shape_bytes] * num_inputs
  if input_tensors_as_shapes_needed:
    for idx in input_tensors_as_shapes_needed:
      const_shape = tensor_util.constant_value_as_shape(op.inputs[idx])
      if const_shape is not None:
        input_tensors_as_shapes[idx] = (
            const_shape.as_proto().SerializeToString())

  missing_shape_fn = False
  try:
    with errors.raise_exception_on_not_ok_status() as status:
      output_shapes = pywrap_tensorflow.RunCppShapeInference(
          node_def_str, input_shapes, input_tensors,
          input_tensors_as_shapes, status)
  except errors.InvalidArgumentError as err:
    if not err.message.startswith(
        "No shape inference function exists for op"):
      raise ValueError(err.message)
    missing_shape_fn = True

  if missing_shape_fn:
    if not require_shape_fn:
      return unknown_shape(op)
    raise RuntimeError(
        "No C++ shape function registered for standard op: %s" % op.type)

  # Decode the serialized inference results returned by the C++ side.
  decoded = [
      cpp_shape_inference_pb2.CppShapeInferenceResult().FromString(raw)
      for raw in output_shapes
  ]
  result = [item.shape for item in decoded]
  result_handle_shapes = [item.handle_shape for item in decoded]
  result_handle_dtypes = [item.handle_dtype for item in decoded]

  if debug_python_shape_fn:
    try:
      python_result = [
          tensor_shape.as_shape(s) for s in debug_python_shape_fn(op)
      ]
    except Exception as err:
      raise AssertionError("Python shape function return error but "
                           "C++ shape functon did not: %s" % str(err))
    result_as_shapes = [tensor_shape.as_shape(s) for s in result]
    # The two results are compared through their string forms.
    if str(result_as_shapes) != str(python_result):
      input_shape_strs = ",".join([str(i.get_shape()) for i in op.inputs])
      raise ValueError(
          ("Python vs CPP shape mismatch. "
           "CPP: %s vs python: %s on node %s "
           "with input shapes %s") % (
               str(result_as_shapes), str(python_result),
               str(op.node_def), input_shape_strs))

  return {
      "shapes": result,
      "handle_shapes": result_handle_shapes,
      "handle_dtypes": result_handle_dtypes
  }