Example #1
0
def tf_record_iterator(path, options=None):
  """An iterator that reads the records from a TFRecords file.

  Args:
    path: The path to the TFRecords file.
    options: (optional) A TFRecordOptions object.

  Yields:
    Strings.

  Raises:
    IOError: If `path` cannot be opened for reading.
  """
  compression_type = TFRecordOptions.get_compression_type_string(options)
  with errors.raise_exception_on_not_ok_status() as status:
    reader = pywrap_tensorflow.PyRecordReader_New(
        compat.as_bytes(path), 0, compat.as_bytes(compression_type), status)

  if reader is None:
    raise IOError("Could not open %s." % path)
  # Close the reader in a `finally` so it is released even when the consumer
  # abandons the generator early or a read fails with an unexpected error.
  try:
    while True:
      try:
        with errors.raise_exception_on_not_ok_status() as status:
          reader.GetNext(status)
      except errors.OutOfRangeError:
        # Treated as the end of the record stream.
        break
      yield reader.record()
  finally:
    reader.Close()
Example #2
0
  def _initialize_handle_and_devices(self):
    """Initialize handle and devices.

    Does nothing if `self._context_handle` is already set. Otherwise creates
    the context handle, stores the canonical name of every listed device in
    `self._context_devices` and counts the devices of type "GPU" in
    `self._num_gpus`. All of this happens while holding
    `self._initialize_lock`.
    """
    with self._initialize_lock:
      if self._context_handle is not None:
        return
      assert self._context_devices is None
      opts = pywrap_tensorflow.TF_NewSessionOptions(
          target=compat.as_bytes(""), config=self._config)
      # Release the session options even if creating the context raises.
      try:
        with errors.raise_exception_on_not_ok_status() as status:
          self._context_handle = pywrap_tensorflow.TFE_NewContext(opts, status)
      finally:
        pywrap_tensorflow.TF_DeleteSessionOptions(opts)
      # Store list of devices
      self._context_devices = []
      with errors.raise_exception_on_not_ok_status() as status:
        device_list = pywrap_tensorflow.TFE_ContextListDevices(
            self._context_handle, status)
      # The device list is always deleted, whether or not enumeration succeeds.
      try:
        self._num_gpus = 0
        for i in range(pywrap_tensorflow.TF_DeviceListCount(device_list)):
          with errors.raise_exception_on_not_ok_status() as status:
            dev_name = pywrap_tensorflow.TF_DeviceListName(
                device_list, i, status)
          self._context_devices.append(pydev.canonical_name(dev_name))
          with errors.raise_exception_on_not_ok_status() as status:
            dev_type = pywrap_tensorflow.TF_DeviceListType(
                device_list, i, status)
          if dev_type == "GPU":
            self._num_gpus += 1

      finally:
        pywrap_tensorflow.TF_DeleteDeviceList(device_list)
Example #3
0
def call_cpp_shape_fn(op):
  """Infers the output shapes of `op` via its registered C++ shape function.

  Args:
    op: the node in the graph for which to compute output shapes.

  Returns:
    A list of TensorShape objects, one for each serialized shape returned by
    the C++ shape inference call.

  Raises:
    ValueError: If the C++ shape function reported an invalid-argument error
    (e.g. because the shapes of the inputs are of the wrong rank or otherwise
    incompatible according to the shape function).
  """
  serialized_node = op.node_def.SerializeToString()
  serialized_inputs = []
  for tensor in op.inputs:
    serialized_inputs.append(
        tensor.get_shape().as_proto().SerializeToString())

  try:
    with errors.raise_exception_on_not_ok_status() as status:
      serialized_outputs = pywrap_tensorflow.RunCppShapeInference(
          serialized_node, serialized_inputs, status)
  except errors.InvalidArgumentError as err:
    # Surface invalid-argument failures to callers as ValueError.
    raise ValueError(err.message)

  # Decode every serialized TensorShapeProto into a TensorShape.
  shapes = []
  for serialized in serialized_outputs:
    proto = tensor_shape_pb2.TensorShapeProto.FromString(serialized)
    shapes.append(tensor_shape.TensorShape(proto))
  return shapes
Example #4
0
def list_directory(dirname):
  """Lists the entries contained within a directory.

  The list is in arbitrary order. It does not contain the special entries "."
  and "..".

  Args:
    dirname: string, path to a directory

  Returns:
    [filename1, filename2, ... filenameN] as strings

  Raises:
    errors.NotFoundError if `is_directory(dirname)` is false
  """
  if not is_directory(dirname):
    raise errors.NotFoundError(None, None, "Could not find directory")
  with errors.raise_exception_on_not_ok_status() as status:
    children = pywrap_tensorflow.GetChildren(compat.as_bytes(dirname), status)
    # The entries come back from the wrapper as a vector of strings; convert
    # each one so callers receive str rather than bytes.
    return [compat.as_str_any(child) for child in children]
Example #5
0
def smart_constant_value(pred):
  """Return the bool value for `pred`, or None if `pred` had a dynamic value.

  Arguments:
    pred: A scalar, either a Python bool (or 1/0) or tensor.

  Returns:
    True or False if `pred` has a constant boolean value, None otherwise.

  Raises:
    TypeError: If `pred` is not a Tensor or bool.
  """
  if pred in {0, 1}:  # Accept 1/0 as valid boolean values
    return bool(pred)
  if isinstance(pred, bool):
    return pred
  if not isinstance(pred, ops.Tensor):
    raise TypeError("`pred` must be a Tensor, or a Python bool, or 1 or 0. "
                    "Found instead: %s" % pred)

  static_value = tensor_util.constant_value(pred)
  # TODO(skyewm): consider folding this into tensor_util.constant_value when
  # _USE_C_API is removed (there may be performance and correctness bugs, so I
  # wanted to limit the change hidden behind _USE_C_API).
  # pylint: disable=protected-access
  if static_value is None and ops._USE_C_API:
    # Second attempt through the C API when the Python-side lookup found
    # nothing.
    with errors.raise_exception_on_not_ok_status() as status:
      static_value = c_api.TF_TryEvaluateConstant_wrapper(
          pred.graph._c_graph, pred._as_tf_output(), status)
  # pylint: enable=protected-access
  return static_value
Example #6
0
  def __init__(self,
               allow_soft_placement=True,
               disable_detailed_stats=True,
               disable_timeline=True,
               devices=None):
    """Creates a Cluster.

    Args:
      allow_soft_placement: If True, TF will automatically fix illegal
        placements instead of erroring out if the placement isn't legal.
      disable_detailed_stats: If True, detailed statistics will not be
        available.
      disable_timeline: If True, the timeline information will not be reported.
      devices: A list of devices of type device_properties_pb2.NamedDevice.
        If None, a device list will be created based on the spec of
        the local machine.
    """
    self._tf_cluster = None
    self._generate_timeline = not disable_timeline
    with errors.raise_exception_on_not_ok_status() as status:
      if devices is not None:
        # Explicit device list: build a virtual cluster from the serialized
        # device protos.
        serialized_devices = []
        for named_device in devices:
          serialized_devices.append(named_device.SerializeToString())
        self._tf_cluster = tf_cluster.TF_NewVirtualCluster(
            serialized_devices, status)
      else:
        self._tf_cluster = tf_cluster.TF_NewCluster(
            allow_soft_placement, disable_detailed_stats, status)
Example #7
0
def capture_value(tensor_map, value, dtype, name):
  """Capture a value from outside the function, to pass in as an extra arg.

  Args:
    tensor_map: mapping from `ops.tensor_id(value)` to a
      `(value, placeholder)` pair; a new entry is added on first capture.
    value: the outside value being captured.
    dtype: dtype for a newly created placeholder; `value.dtype` is used when
      this is falsy.
    name: name given to a newly created placeholder.

  Returns:
    The placeholder that stands in for `value`.
  """
  entry = tensor_map.get(ops.tensor_id(value), None)
  if entry is not None:
    # Already captured: reuse the stored placeholder.
    placeholder = entry[1]
  else:
    placeholder = graph_placeholder(
        dtype=dtype or value.dtype, shape=value.shape, name=name)
    if placeholder.dtype == dtypes_module.resource:
      # pylint: disable=protected-access
      handle_data = value._handle_data
      placeholder._handle_data = handle_data
      if handle_data is not None and handle_data.is_set:
        # Ensure that shapes and dtypes are propagated.
        shape_protos, types = zip(*[(entry_pair.shape, entry_pair.dtype)
                                    for entry_pair in
                                    handle_data.shape_and_type])
        # Unknown rank is encoded as rank -1 with a None shape.
        ranks = [-1 if proto.unknown_rank else len(proto.dim)
                 for proto in shape_protos]
        dims = [None if proto.unknown_rank
                else [dim.size for dim in proto.dim]
                for proto in shape_protos]
        with errors.raise_exception_on_not_ok_status() as status:
          pywrap_tensorflow.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
              placeholder._op._graph._c_graph,
              placeholder._as_tf_output(),
              dims,
              ranks,
              types,
              status)
      # pylint: enable=protected-access

    tensor_map[ops.tensor_id(value)] = (value, placeholder)
  tape.record_operation("captured_value", [placeholder], [value],
                        lambda x: [x])
  return placeholder
Example #8
0
    def __init__(self,
                 server_or_cluster_def,
                 job_name=None,
                 task_index=None,
                 protocol=None,
                 config=None,
                 start=True):
        """Creates a new server with the given definition.

        The `job_name`, `task_index`, and `protocol` arguments are optional,
        and override any information provided in `server_or_cluster_def`.

        Args:
          server_or_cluster_def: A `tf.train.ServerDef` or
            `tf.train.ClusterDef` protocol buffer, or a
            `tf.train.ClusterSpec` object, describing the server to be
            created and/or the cluster of which it is a member.
          job_name: (Optional.) Specifies the name of the job of which the
            server is a member. Defaults to the value in
            `server_or_cluster_def`, if specified.
          task_index: (Optional.) Specifies the task index of the server in
            its job. Defaults to the value in `server_or_cluster_def`, if
            specified. Otherwise defaults to 0 if the server's job has only
            one task.
          protocol: (Optional.) Specifies the protocol to be used by the
            server. Acceptable values include `"grpc"`. Defaults to the value
            in `server_or_cluster_def`, if specified. Otherwise defaults to
            `"grpc"`.
          config: (Optional.) A `tf.ConfigProto` that specifies default
            configuration options for all sessions that run on this server.
          start: (Optional.) Boolean, indicating whether to start the server
            after creating it. Defaults to `True`.

        Raises:
          tf.errors.OpError: Or one of its subclasses if an error occurs while
            creating the TensorFlow server.
        """
        self._server_def = _make_server_def(
            server_or_cluster_def, job_name, task_index, protocol, config)
        serialized_def = self._server_def.SerializeToString()
        with errors.raise_exception_on_not_ok_status() as status:
            self._server = pywrap_tensorflow.PyServer_New(
                serialized_def, status)
        if start:
            self.start()
Example #9
0
def TransformGraph(input_graph_def, inputs, outputs, transforms):
  """Python wrapper for the Graph Transform Tool.

  Gives access to all graph transforms available through the command line tool.
  See documentation at https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/graph_transforms/README.md
  for full details of the options available.

  Args:
    input_graph_def: GraphDef object containing a model to be transformed.
    inputs: List of node names for the model inputs.
    outputs: List of node names for the model outputs.
    transforms: List of strings containing transform names and parameters.

  Returns:
    New GraphDef with transforms applied.
  """
  serialized_input = input_graph_def.SerializeToString()
  # Node names are comma-joined; transforms are space-joined.
  joined_inputs = compat.as_bytes(",".join(inputs))
  joined_outputs = compat.as_bytes(",".join(outputs))
  joined_transforms = compat.as_bytes(" ".join(transforms))
  with errors.raise_exception_on_not_ok_status() as status:
    serialized_output = TransformGraphWithStringInputs(
        serialized_input, joined_inputs, joined_outputs, joined_transforms,
        status)
  transformed = graph_pb2.GraphDef()
  transformed.ParseFromString(serialized_output)
  return transformed
Example #10
0
def recursive_create_dir(dirname):
  """Creates `dirname` together with any missing parent directories.

  The path is split on '/' and every successively longer prefix that is
  non-empty and does not yet exist (according to `file_exists`) is created in
  turn.

  Args:
    dirname: string, '/'-separated path of the directory to create.

  Raises:
    Whatever `errors.raise_exception_on_not_ok_status` raises when creating
    one of the path components reports a non-OK status.
  """
  components = dirname.split('/')
  for depth in range(1, len(components) + 1):
    partial_dir = '/'.join(components[:depth])
    if partial_dir and not file_exists(partial_dir):
      # Use a fresh status per component so that a failure is raised right
      # away rather than being carried through the CreateDir calls for the
      # deeper components.
      with errors.raise_exception_on_not_ok_status() as status:
        pywrap_tensorflow.CreateDir(compat.as_bytes(partial_dir), status)
Example #11
0
def get_matching_files_v2(pattern):
  """Returns a list of files that match the given pattern(s).

  Args:
    pattern: string or iterable of strings. The glob pattern(s).

  Returns:
    A list of strings containing filenames that match the given pattern(s).

  Raises:
    errors.OpError: If there are filesystem / directory listing errors.
  """
  # Treat a single pattern as a one-element collection so both input forms
  # share the same lookup code.
  if isinstance(pattern, six.string_types):
    patterns = [pattern]
  else:
    patterns = pattern
  with errors.raise_exception_on_not_ok_status() as status:
    matches = []
    for single_pattern in patterns:
      found = pywrap_tensorflow.GetMatchingFiles(
          compat.as_bytes(single_pattern), status)
      for matching_filename in found:
        # Convert the filenames to string from bytes.
        matches.append(compat.as_str_any(matching_filename))
    return matches
Example #12
0
def GenerateCostReport(metagraph,
                       per_node_report=False,
                       verbose=False,
                       cluster=None):
  """Analyze the cost of each TensorFlow op and node in the provided metagraph.

  Args:
    metagraph: A TensorFlow MetaGraphDef.
    per_node_report: by default the report contains stats aggregated on a per op
      type basis, setting per_node_report to True adds results for each
      individual node to the report.
    verbose: Prints out the entire operation proto instead of a summary table.
    cluster: Analyze the costs using the specified cluster, or the local machine
      if no cluster was specified.

  Returns:
    A string of cost report.
  """
  target_cluster = cluster
  if target_cluster is None:
    # No cluster supplied: create one with detailed stats enabled.
    target_cluster = gcluster.Cluster(disable_detailed_stats=False)

  serialized_metagraph = metagraph.SerializeToString()
  # NOTE: no status object is bound or passed to the wrapper here.
  with errors.raise_exception_on_not_ok_status():
    report = tf_wrap.GenerateCostReport(
        serialized_metagraph, per_node_report, verbose,
        target_cluster.tf_cluster)
  return report
Example #13
0
def register_function_def(fdef):
  """Adds a serialized function definition to the default context.

  Args:
    fdef: object exposing `SerializeToString()`; its serialized form and
      that form's length are handed to `TFE_ContextAddFunctionDef`.
  """
  serialized = fdef.SerializeToString()
  # pylint: disable=protected-access
  context_handle = context.get_default_context()._handle
  # pylint: enable=protected-access
  with errors.raise_exception_on_not_ok_status() as status:
    pywrap_tensorflow.TFE_ContextAddFunctionDef(
        context_handle, serialized, len(serialized), status)
Example #14
0
 def _prereadline_check(self):
   """Creates the buffered input stream on first use.

   Does nothing when `self._read_buf` is already set.

   Raises:
     errors.PermissionDeniedError: if `self._read_check_passed` is false.
   """
   if self._read_buf:
     return
   if not self._read_check_passed:
     raise errors.PermissionDeniedError(None, None,
                                        "File isn't open for reading")
   buffer_size = 1024 * 512
   with errors.raise_exception_on_not_ok_status() as status:
     self._read_buf = pywrap_tensorflow.CreateBufferedInputStream(
         compat.as_bytes(self.__name), buffer_size, status)
 def testInvalidDeviceNumber(self):
   """Checks that an invalid device index yields -1 for its memory bytes."""
   session_opts = tf_session.TF_NewSessionOptions()
   with errors.raise_exception_on_not_ok_status() as status:
     session_handle = tf_session.TF_NewSession(
         ops.get_default_graph()._c_graph, session_opts, status)
     device_list = tf_session.TF_SessionListDevices(session_handle, status)
   # The device count itself is used as the (invalid) device number.
   invalid_index = tf_session.TF_DeviceListCount(device_list)
   # Test that invalid device numbers return -1 rather than a Swig-wrapped
   # pointer.
   plain_status = c_api_util.ScopedTFStatus()
   memory_bytes = tf_session.TF_DeviceListMemoryBytes(
       device_list, invalid_index, plain_status)
   self.assertEqual(memory_bytes, -1)
   tf_session.TF_DeleteDeviceList(device_list)
   with errors.raise_exception_on_not_ok_status() as status:
     tf_session.TF_CloseSession(session_handle, status)
Example #16
0
 def read(self):
   """Returns the contents of a file as a string.

   Raises:
     errors.PermissionDeniedError: if `self._read_check_passed` is false.
   """
   if not self._read_check_passed:
     raise errors.PermissionDeniedError(None, None,
                                        "File isn't open for reading")
   with errors.raise_exception_on_not_ok_status() as status:
     contents = pywrap_tensorflow.ReadFileToString(
         compat.as_bytes(self.__name), status)
   return contents
Example #17
0
 def _run_fn(session, feed_dict, fetch_list, target_list, options,
             run_metadata):
   """Extends the graph, then forwards the arguments to `TF_Run`."""
   # Ensure any changes to the graph are reflected in the runtime.
   self._extend_graph()
   with errors.raise_exception_on_not_ok_status() as status:
     run_result = tf_session.TF_Run(
         session, options, feed_dict, fetch_list, target_list, status,
         run_metadata)
   return run_result
def TF_Reset(target, containers=None, config=None):
  """Calls `TF_Reset_wrapper` with session options built from the arguments.

  Args:
    target: passed to `TF_NewSessionOptions` as `target`.
    containers: passed through to `TF_Reset_wrapper`.
    config: passed to `TF_NewSessionOptions` as `config`.
  """
  from tensorflow.python.framework import errors
  session_options = TF_NewSessionOptions(target=target, config=config)
  try:
    with errors.raise_exception_on_not_ok_status() as reset_status:
      TF_Reset_wrapper(session_options, containers, reset_status)
  finally:
    # The options are deleted whether or not the reset succeeded.
    TF_DeleteSessionOptions(session_options)
Example #19
0
  def write(self, record):
    """Write a string record to the file.

    Args:
      record: str
    """
    record_writer = self._writer
    with errors.raise_exception_on_not_ok_status() as write_status:
      record_writer.WriteRecord(record, write_status)
Example #20
0
 def _prewrite_check(self):
   """Creates the writable file on first use.

   Does nothing when `self._writable_file` is already set.

   Raises:
     errors.PermissionDeniedError: if `self._write_check_passed` is false.
   """
   if self._writable_file:
     return
   if not self._write_check_passed:
     raise errors.PermissionDeniedError(None, None,
                                        "File isn't open for writing")
   with errors.raise_exception_on_not_ok_status() as status:
     self._writable_file = pywrap_tensorflow.CreateWritableFile(
         compat.as_bytes(self.__name), status)
Example #21
0
 def close(self):
   """Closes FileIO. Should be called for the WritableFile to be flushed."""
   self._read_buf = None
   writable = self._writable_file
   if writable:
     with errors.raise_exception_on_not_ok_status() as status:
       # Close() hands back a status; copy it into `status` so the context
       # manager can act on it.
       close_result = writable.Close()
       pywrap_tensorflow.Set_TF_Status_from_Status(status, close_result)
   self._writable_file = None
Example #22
0
 def write(self, file_content):
   """Writes file_content to the file.

   Raises:
     errors.PermissionDeniedError: if `self._write_check_passed` is false.
   """
   if not self._write_check_passed:
     raise errors.PermissionDeniedError(None, None,
                                        "File isn't open for writing")
   with errors.raise_exception_on_not_ok_status() as write_status:
     pywrap_tensorflow.WriteStringToFile(
         compat.as_bytes(self.__name),
         compat.as_bytes(file_content),
         write_status)
def recursive_create_dir(dirname):
  """Creates `dirname` together with any missing parent directories.

  The path is split on '/' and every successively longer prefix that is
  non-empty and does not yet exist (according to `file_exists`) is created in
  turn.

  Args:
    dirname: string, '/'-separated path of the directory to create.

  Raises:
    Whatever `errors.raise_exception_on_not_ok_status` raises when creating
    one of the path components reports a non-OK status.
  """
  from tensorflow.python.framework import errors
  from tensorflow.python.util import compat
  components = dirname.split('/')
  for depth in range(1, len(components) + 1):
    partial_dir = '/'.join(components[:depth])
    if partial_dir and not file_exists(partial_dir):
      # Use a fresh status per component so that a failure is raised right
      # away rather than being carried through the CreateDir calls for the
      # deeper components.
      with errors.raise_exception_on_not_ok_status() as status:
        CreateDir(compat.as_bytes(partial_dir), status)
Example #24
0
 def __del__(self):
   """Deletes the context handle, if one is set, when the object is collected."""
   try:
     handle = self._context_handle
     if handle is not None:
       with errors.raise_exception_on_not_ok_status() as status:
         pywrap_tensorflow.TFE_DeleteContext(handle, status)
   except (AttributeError, TypeError):
     # Sometimes deletion during program shutdown throws exception as other
     # modules are no longer available.
     pass
Example #25
0
 def testStatusDoesNotLeak(self):
   """Checks that no ScopedTFStatus stays referenced after a failed call."""
   try:
     with errors.raise_exception_on_not_ok_status() as status:
       pywrap_tensorflow.DeleteFile(
           compat.as_bytes("/DOES_NOT_EXIST/"), status)
   except Exception:  # pylint: disable=broad-except
     # The failure itself is not under test; only ordinary exceptions are
     # swallowed so KeyboardInterrupt/SystemExit still propagate.
     pass
   gc.collect()
   self.assertEqual(0, self._CountReferences(c_api_util.ScopedTFStatus))
Example #26
0
  def start(self):
    """Starts this server.

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        starting the TensorFlow server.
    """
    server_handle = self._server
    with errors.raise_exception_on_not_ok_status() as start_status:
      pywrap_tensorflow.PyServer_Start(server_handle, start_status)
Example #27
0
  def read(self):
    """Returns the contents of a file as a string.

    Starts reading from current position in file.
    """
    self._preread_check()
    # Everything between the current position and the end of the file.
    remaining = self.size() - self.tell()
    with errors.raise_exception_on_not_ok_status() as status:
      contents = pywrap_tensorflow.ReadFromStream(
          self._read_buf, remaining, status)
    return contents
def do_quantize_training_on_graphdef(input_graph, num_bits):
  """Runs the quantize-training helper on a graph and parses its result.

  Args:
    input_graph: object exposing `SerializeToString()`; its serialized form
      is handed to `DoQuantizeTrainingOnGraphDefHelper`.
    num_bits: passed through to the helper unchanged.

  Returns:
    A new GraphDef parsed from the string returned by the helper.
  """
  from tensorflow.core.framework.graph_pb2 import GraphDef
  from tensorflow.python.framework import errors
  rewritten_graph = GraphDef()
  serialized_input = input_graph.SerializeToString()
  with errors.raise_exception_on_not_ok_status() as status:
    serialized_result = DoQuantizeTrainingOnGraphDefHelper(
        serialized_input, num_bits, status)

  rewritten_graph.ParseFromString(serialized_result)
  return rewritten_graph
Example #29
0
  def flush(self):
    """Flushes the Writable file.

    This only ensures that the data has made its way out of the process without
    any guarantees on whether it's written to disk. This means that the
    data would survive an application crash but not necessarily an OS crash.
    """
    if not self._writable_file:
      # No writable file has been created, so there is nothing to flush.
      return
    with errors.raise_exception_on_not_ok_status() as flush_status:
      pywrap_tensorflow.FlushWritableFile(self._writable_file, flush_status)
Example #30
0
def op_attr_type(op_type, attr_name):
  """Returns the attr type for `(op_type, attr_name)`, caching the result.

  Args:
    op_type: op type name passed to `TFE_OpNameGetAttrType`.
    attr_name: attribute name passed to `TFE_OpNameGetAttrType`.

  Returns:
    The cached value when present in `_op_attr_type_cache`, otherwise the
    value returned by `TFE_OpNameGetAttrType` (which is then cached).
  """
  cache_key = (op_type, attr_name)
  try:
    return _op_attr_type_cache[cache_key]
  except KeyError:
    pass
  # Cache miss: ask the runtime, then remember the answer.
  with errors.raise_exception_on_not_ok_status() as status:
    ctx_handle = context.context()._handle  # pylint: disable=protected-access
    looked_up = pywrap_tensorflow.TFE_OpNameGetAttrType(
        ctx_handle, op_type, attr_name, status)
  _op_attr_type_cache[cache_key] = looked_up
  return looked_up
Example #31
0
  def __init__(self, target='', graph=None, config=None):
    """Constructs a new TensorFlow session.

    Args:
      target: (Optional) The TensorFlow execution engine to connect to.
      graph: (Optional) The graph to be used. If this argument is None,
        the default graph will be used.
      config: (Optional) ConfigProto proto used to configure the session.

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        creating the TensorFlow session.
    """
    if graph is None:
      self._graph = ops.get_default_graph()
    else:
      self._graph = graph

    self._opened = False
    self._closed = False

    self._current_version = 0
    self._extend_lock = threading.Lock()
    self._target = target

    self._delete_lock = threading.Lock()
    self._dead_handles = []

    self._session = None
    self._config = config
    self._add_shapes = config.graph_options.infer_shapes if (
        config and config.graph_options) else False

    # Create the options before entering the `try`: if this call raised inside
    # it, the `finally` clause would reference an unbound `opts` and mask the
    # real error with an UnboundLocalError.
    opts = tf_session.TF_NewSessionOptions(target=target, config=config)
    try:
      with errors.raise_exception_on_not_ok_status() as status:
        self._session = tf_session.TF_NewSession(opts, status)
    finally:
      tf_session.TF_DeleteSessionOptions(opts)
Example #32
0
    def __init__(self,
                 server_or_cluster_def,
                 job_name=None,
                 task_index=None,
                 protocol=None,
                 start=True):
        """Creates a new server with the given definition.

        The `job_name`, `task_index`, and `protocol` arguments are optional,
        and override any information provided in `server_or_cluster_def`.

        Args:
          server_or_cluster_def: A `tf.train.ServerDef` or
            `tf.train.ClusterDef` protocol buffer, or a
            `tf.train.ClusterSpec` object, describing the server to be
            created and/or the cluster of which it is a member.
          job_name: (Optional.) Specifies the name of the job of which the
            server is a member. Defaults to the value in
            `server_or_cluster_def`, if specified.
          task_index: (Optional.) Specifies the task index of the server in
            its job. Defaults to the value in `server_or_cluster_def`, if
            specified. Otherwise defaults to 0 if the server's job has only
            one task.
          protocol: (Optional.) Specifies the protocol to be used by the
            server. Acceptable values include `"grpc"`. Defaults to the value
            in `server_or_cluster_def`, if specified. Otherwise defaults to
            `"grpc"`.
          start: (Optional.) Boolean, indicating whether to start the server
            after creating it. Defaults to `True`.

        Raises:
          tf.errors.OpError: Or one of its subclasses if an error occurs while
            creating the TensorFlow server.
        """
        definition = _make_server_def(
            server_or_cluster_def, job_name, task_index, protocol)
        serialized_definition = definition.SerializeToString()
        with errors.raise_exception_on_not_ok_status() as status:
            self._server = pywrap_tensorflow.PyServer_New(
                serialized_definition, status)
        if start:
            self.start()
Example #33
0
  def Load(self):
    """Loads all new values from disk.

    Calling Load multiple times in a row will not 'drop' events as long as the
    return value is not iterated over.

    Yields:
      All values that were written to disk that have not been yielded yet.
    """
    record_reader = self._reader
    while True:
      try:
        with errors.raise_exception_on_not_ok_status() as read_status:
          record_reader.GetNext(read_status)
      except (errors.DataLossError, errors.OutOfRangeError):
        # We ignore partial read exceptions, because a record may be truncated.
        # PyRecordReader holds the offset prior to the failed read, so retrying
        # will succeed.
        break
      parsed_event = event_pb2.Event()
      parsed_event.ParseFromString(record_reader.record())
      yield parsed_event
    logging.debug('No more events in %s', self._file_path)
Example #34
0
    def _swig_call(self, method, request, response):
        """Calls method, serializing and deserializing inputs and outputs.

        Note that this does not check the types of request and response.

        This can throw a variety of Python errors, based upon the underlying
        tensorflow error returned in MetadataStore.
        See _CODE_TO_EXCEPTION_CLASS in
        tensorflow/python/framework/errors_impl.py for the mapping.

        Args:
          method: the method to call in SWIG.
          request: a protobuf message, serialized and sent to the method.
          response: a protobuf message, filled from the return value of the
            method.

        Raises:
          Error: whatever tensorflow error is returned by the method.
        """
        serialized_request = request.SerializeToString()
        with errors.raise_exception_on_not_ok_status() as status:
            serialized_response = method(
                self._metadata_store, serialized_request, status)
        # Only reached when the call did not raise.
        response.ParseFromString(serialized_response)
Example #35
0
    def _generic_iterator(self, file_path):
        """A helper method that makes an iterator given a debug-events file path.

        Repeated calls to this method create iterators that remember the last
        successful reading position (offset) for each given `file_path`. So
        the iterators are meant for incremental reading of the file.

        Args:
          file_path: Path to the file to create the iterator for.

        Yields:
          A tuple of (offset, debug_event_proto) on each `next()` call.
        """
        # The following code uses the double-checked locking pattern to
        # optimize the common case (where the reader is already initialized).
        readers = self._readers
        if file_path not in readers:  # 1st check, without lock.
            with self._readers_lock:
                if file_path not in readers:  # 2nd check, with lock.
                    with errors.raise_exception_on_not_ok_status() as status:
                        # TODO(b/136474806): Use tf_record.tf_record_iterator()
                        # once it supports offset.
                        new_reader = pywrap_tensorflow.PyRecordReader_New(
                            compat.as_bytes(file_path), 0, b"", status)
                        readers[file_path] = new_reader
        record_reader = readers[file_path]
        while True:
            # Capture the offset before the read so it is paired with the
            # record that the read produces.
            record_offset = record_reader.offset()
            try:
                record_reader.GetNext()
            except (errors.DataLossError, errors.OutOfRangeError):
                # We ignore partial read exceptions, because a record may be
                # truncated. PyRecordReader holds the offset prior to the
                # failed read, so retrying will succeed.
                break
            parsed_event = debug_event_pb2.DebugEvent.FromString(
                record_reader.record())
            yield DebugEventWithOffset(
                debug_event=parsed_event, offset=record_offset)
Example #36
0
    def MeasureCosts(self, item):
        """Returns the cost of running the specified item.

        Args:
          item: The item for which to measure the costs.

        Returns:
          The triplet op_perfs, runtime, step_stats, or None when the
          underlying call returns None.
        """
        with errors.raise_exception_on_not_ok_status() as status:
            measured = tf_cluster.TF_MeasureCosts(
                item.tf_item, self._tf_cluster, self._generate_timeline,
                status)

        if measured is None:
            return None

        serialized_op_perfs, run_time, serialized_step_stats = measured
        # Decode the serialized protos returned by the wrapper.
        op_perfs = []
        for serialized in serialized_op_perfs:
            op_perfs.append(
                op_performance_data_pb2.OpPerformance.FromString(serialized))
        step_stats = step_stats_pb2.StepStats.FromString(
            serialized_step_stats)
        return (op_perfs, run_time, step_stats)
Example #37
0
def tf_record_iterator(path, options=None):
    """An iterator that reads the records from a TFRecords file.

    Args:
      path: The path to the TFRecords file.
      options: (optional) A TFRecordOptions object.

    Yields:
      Strings.

    Raises:
      IOError: If `path` cannot be opened for reading.
    """
    compression_type = TFRecordOptions.get_compression_type_string(options)
    with errors.raise_exception_on_not_ok_status() as status:
        reader = pywrap_tensorflow.PyRecordReader_New(
            compat.as_bytes(path), 0, compat.as_bytes(compression_type),
            status)

    if reader is None:
        raise IOError("Could not open %s." % path)
    # Close the reader in a `finally` so it is released even when the consumer
    # abandons the generator early or a read raises.
    try:
        # The loop ends when GetNext() returns a falsy value.
        while reader.GetNext():
            yield reader.record()
    finally:
        reader.Close()
Example #38
0
def _call_cpp_shape_fn_impl(op, input_tensors_needed,
                            input_tensors_as_shapes_needed, require_shape_fn):
    """Core implementation of call_cpp_shape_fn.

    Args:
      op: the node whose output shapes are inferred.
      input_tensors_needed: indices into `op.inputs` whose constant values,
        when available, are passed to shape inference.
      input_tensors_as_shapes_needed: indices into `op.inputs` whose values,
        when they can be read as constant shapes, are passed to shape
        inference.
      require_shape_fn: if true, a missing C++ shape function raises
        RuntimeError; otherwise `unknown_shape(op)` is returned.

    Returns:
      A dict with keys "shapes", "handle_data" and "inputs_needed", or the
      result of `unknown_shape(op)` when no shape function exists and
      `require_shape_fn` is false.

    Raises:
      ValueError: if shape inference reports any other invalid-argument
        error.
      RuntimeError: if no shape function exists and `require_shape_fn` is
        true.
    """
    graph_def_version = op.graph.graph_def_versions.producer
    node_def_str = op.node_def.SerializeToString()

    def serialize_input(tensor):
        """Serializes a tensor's shape (and handle data, if any)."""
        result = cpp_shape_inference_pb2.CppShapeInferenceResult()
        result.shape.CopyFrom(tensor.get_shape().as_proto())
        # pylint: disable=protected-access
        handle_data = tensor._handle_data
        # pylint: enable=protected-access
        if handle_data is not None:
            result.handle_data.CopyFrom(handle_data)
        return result.SerializeToString()

    input_shapes = [serialize_input(tensor) for tensor in op.inputs]
    num_inputs = len(input_shapes)

    # Constant values for the requested inputs; None where unavailable.
    input_tensors = [None] * num_inputs
    for idx in input_tensors_needed:
        constant = tensor_util.constant_value(op.inputs[idx])
        if constant is not None:
            input_tensors[idx] = np.asarray(constant)

    # Every entry starts as a serialized unknown shape; requested inputs are
    # overwritten when a constant shape can be derived.
    unknown_shape_str = (
        tensor_shape.TensorShape(None).as_proto().SerializeToString())
    input_tensors_as_shapes = [unknown_shape_str] * num_inputs
    for idx in input_tensors_as_shapes_needed:
        shape = tensor_util.constant_value_as_shape(op.inputs[idx])
        if shape is not None:
            input_tensors_as_shapes[idx] = shape.as_proto().SerializeToString()

    missing_shape_fn = False
    try:
        with errors.raise_exception_on_not_ok_status() as status:
            output = pywrap_tensorflow.RunCppShapeInference(
                graph_def_version, node_def_str, input_shapes, input_tensors,
                input_tensors_as_shapes, status)
    except errors.InvalidArgumentError as err:
        if not err.message.startswith(
                "No shape inference function exists for op"):
            raise ValueError(err.message)
        missing_shape_fn = True

    if missing_shape_fn:
        if require_shape_fn:
            raise RuntimeError(
                "No C++ shape function registered for standard op: %s" %
                op.type)
        return unknown_shape(op)

    # The last element of `output` is returned as "inputs_needed"; all the
    # preceding elements are serialized inference results.
    result_protos = [
        cpp_shape_inference_pb2.CppShapeInferenceResult().FromString(encoded)
        for encoded in output[:-1]
    ]
    shapes = [proto.shape for proto in result_protos]
    handle_data_list = [
        proto.handle_data if proto.handle_data.is_set else None
        for proto in result_protos
    ]

    return {
        "shapes": shapes,
        "handle_data": handle_data_list,
        "inputs_needed": output[-1]
    }
Example #39
0
def create_dir(dirname):
  """Create `dirname` by delegating to `pywrap_tensorflow.CreateDir`.

  Args:
    dirname: Directory path; converted with `compat.as_bytes` before the call.

  Raises:
    Whatever `errors.raise_exception_on_not_ok_status` raises when the native
    call leaves a non-OK status.
  """
  with errors.raise_exception_on_not_ok_status() as tf_status:
    encoded_dir = compat.as_bytes(dirname)
    pywrap_tensorflow.CreateDir(encoded_dir, tf_status)
Example #40
0
def rename(oldname, newname, overwrite=False):
  """Rename `oldname` to `newname` through `pywrap_tensorflow.RenameFile`.

  Args:
    oldname: Source path; converted with `compat.as_bytes`.
    newname: Destination path; converted with `compat.as_bytes`.
    overwrite: Forwarded unchanged to the native call.

  Returns:
    The value returned by `pywrap_tensorflow.RenameFile`.

  Raises:
    Whatever `errors.raise_exception_on_not_ok_status` raises when the native
    call leaves a non-OK status.
  """
  with errors.raise_exception_on_not_ok_status() as tf_status:
    source = compat.as_bytes(oldname)
    target = compat.as_bytes(newname)
    return pywrap_tensorflow.RenameFile(source, target, overwrite, tf_status)
Example #41
0
def copy(oldpath, newpath, overwrite=False):
  """Copy `oldpath` to `newpath` through `pywrap_tensorflow.CopyFile`.

  Args:
    oldpath: Source path; converted with `compat.as_bytes`.
    newpath: Destination path; converted with `compat.as_bytes`.
    overwrite: Forwarded unchanged to the native call.

  Raises:
    Whatever `errors.raise_exception_on_not_ok_status` raises when the native
    call leaves a non-OK status.
  """
  with errors.raise_exception_on_not_ok_status() as tf_status:
    source = compat.as_bytes(oldpath)
    target = compat.as_bytes(newpath)
    pywrap_tensorflow.CopyFile(source, target, overwrite, tf_status)
Example #42
0
  def _create_definition_if_needed_impl(self):
    """Build the function definition once; prefer _create_definition_if_needed.

    Traces `self._func` into a temporary `_FuncGraph`, then either builds a
    Python `FunctionDef` (no C graph) or a C function via
    `c_api.TF_GraphToFunction_wrapper` (C graph present). Does nothing when a
    definition or C function already exists.

    Raises:
      ValueError: If the traced function returns a sequence containing None.
    """
    already_built = self._definition is not None or self._c_func is not None
    if already_built:
      return

    # Trace the user function inside a scratch graph.
    func_graph = _FuncGraph(capture_by_value=self._capture_by_value)
    with func_graph.as_default():
      # One placeholder per declared argument; these become the inputs of the
      # function_def.
      inputs = [
          array_ops.placeholder(arg_dtype, name=arg_name)
          for (arg_name, arg_dtype) in self._args
      ]
      # Invoke the function and collect whatever it returns.
      with vs.variable_scope("", custom_getter=func_graph.getvar):
        outputs = self._func(*inputs)

      # Python cannot tell "returned nothing" apart from "returned None".
      # The former must be allowed; the latter is most likely a user error
      # and would ideally be rejected.
      # TODO(iga): Consider adding a @NoOutput decorator on top of @Defun to
      # allow users to explicitly mark the function as not returning anything.
      # Until then a bare None return means "function with no output".
      if outputs is None:
        outputs = []
      else:
        # Wrap a single returned value so it can be handled like a sequence.
        if not isinstance(outputs, (list, tuple)):
          outputs = (outputs,)
        if any(out is None for out in outputs):
          raise ValueError("Function can not return None.")
      # Every output must end up being a Tensor.
      outputs = [ops.convert_to_tensor(out) for out in outputs]
    self._extra_inputs = func_graph.extra_inputs
    inputs.extend(func_graph.extra_args)
    self._sub_functions = func_graph._functions  # pylint: disable=protected-access

    # Extra kwargs become attrs on the function def.
    base_func_name = self._func_name or _get_func_name(self._func)
    kwargs_attr = _parse_kwargs_as_attrs(base_func_name,
                                         **self._extra_kwargs)

    # pylint: disable=protected-access
    if func_graph._c_graph:
      # C API is enabled: let the C layer build the function.
      if self._out_names:
        output_names = [compat.as_bytes(out_name)
                        for out_name in self._out_names]
      else:
        output_names = []
      description = self._func.__doc__ or None
      with errors.raise_exception_on_not_ok_status() as tf_status:
        self._c_func = c_api.TF_GraphToFunction_wrapper(
            func_graph._c_graph,
            base_func_name,
            self._func_name is None,  # append_hash_to_fn_name
            None,  # opers
            [tensor._as_tf_output() for tensor in inputs],
            [tensor._as_tf_output() for tensor in outputs],
            output_names,
            None,  # opts
            description,
            tf_status)
      self._set_c_attrs(kwargs_attr)

      # Populate the cached fields: _op_def, and _func_name when it is unset.
      self._op_def = self.definition.signature
      if not self._func_name:
        self._func_name = compat.as_str(self._op_def.name)
      else:
        assert self._func_name == self._op_def.name
    else:
      # No C graph: build the FunctionDef in Python.
      self._definition = graph_to_function_def.graph_to_function_def(
          func_graph,
          func_graph.get_operations(),
          inputs,
          outputs,
          out_names=self._out_names)

      for attr_key in kwargs_attr:
        self._definition.attr[attr_key].CopyFrom(kwargs_attr[attr_key])

      # Hash the definition together with its dependencies.
      signature = self._definition.signature
      self._hash_str = self._create_hash_str(
          signature.input_arg, signature.output_arg,
          self._definition.node_def)

      # Pick the final function name. When none was given, derive one that is
      # almost certainly unique (but deterministic).
      if not self._func_name:
        self._func_name = "_".join([base_func_name, self._hash_str])
      self._definition.signature.name = self._func_name
      if self._func.__doc__:
        self._definition.signature.description = self._func.__doc__

      self._op_def = self._definition.signature
    # pylint: enable=protected-access
Example #43
0
 def __del__(self):
     """Close the session and release its native handle on finalization.

     The native session is deleted even when `close()` raises; otherwise a
     failing `close()` would skip `TF_DeleteSession` and leak the handle. Any
     exception from `close()` still propagates after the cleanup attempt.
     """
     try:
         self.close()
     finally:
         # Runs whether or not close() succeeded, so the handle is always freed.
         if self._session is not None:
             with errors.raise_exception_on_not_ok_status() as status:
                 tf_session.TF_DeleteSession(self._session, status)
             # Drop the reference so the handle cannot be deleted twice.
             self._session = None
Example #44
0
  def _create_definition_if_needed_impl(self):
    """Build the function definition once; prefer _create_definition_if_needed.

    Traces `self._func` into a temporary `_FuncGraph`, builds the Python
    `FunctionDef`, names it, and additionally creates a C function when the
    temporary graph has a C graph. Does nothing when a definition exists.

    Raises:
      ValueError: If the traced function returns None (alone or in a sequence).
    """
    if self._definition is not None:
      return

    # Trace the user function inside a scratch graph.
    func_graph = _FuncGraph(capture_by_value=self._capture_by_value)
    with func_graph.as_default():
      # One placeholder per declared argument; these become the inputs of the
      # function_def.
      inputs = [
          array_ops.placeholder(arg_dtype, name=arg_name)
          for (arg_name, arg_dtype) in self._args
      ]
      # Invoke the function and collect whatever it returns.
      with vs.variable_scope("", custom_getter=func_graph.getvar):
        outputs = self._func(*inputs)
      # Wrap a single returned value so it can be handled like a sequence.
      if not isinstance(outputs, (list, tuple)):
        outputs = (outputs,)
      if any(out is None for out in outputs):
        raise ValueError("Function can not return None.")
      # Every output must end up being a Tensor.
      outputs = [ops.convert_to_tensor(out) for out in outputs]
    self._extra_inputs = func_graph.extra_inputs
    inputs.extend(func_graph.extra_args)
    self._sub_functions = func_graph._functions  # pylint: disable=protected-access

    # Build the FunctionDef from the traced graph.
    self._definition = graph_to_function_def.graph_to_function_def(
        func_graph,
        func_graph.get_operations(),
        inputs,
        outputs,
        out_names=self._out_names)

    # Extra kwargs become attrs on the function def.
    sig_pre_func_name = self._func_name or _get_func_name(self._func)
    kwargs_attr = _parse_kwargs_as_attrs(sig_pre_func_name,
                                         **self._extra_kwargs)
    for attr_key in kwargs_attr:
      self._definition.attr[attr_key].CopyFrom(kwargs_attr[attr_key])

    # Hash the definition together with its dependencies.
    signature = self._definition.signature
    self._hash_str = self._create_hash_str(
        signature.input_arg, signature.output_arg, self._definition.node_def)

    # Pick the final function name. When none was given, derive one that is
    # almost certainly unique (but deterministic).
    if not self._func_name:
      self._func_name = "_".join([_get_func_name(self._func), self._hash_str])
    self._definition.signature.name = self._func_name
    if self._func.__doc__:
      self._definition.signature.description = self._func.__doc__

    # pylint: disable=protected-access
    if not func_graph._c_graph:
      return

    # A C graph is available: also create the C function.
    if self._out_names:
      output_names = [compat.as_bytes(out_name)
                      for out_name in self._out_names]
    else:
      output_names = []
    description = self._func.__doc__ or None
    with errors.raise_exception_on_not_ok_status() as tf_status:
      self._c_func = c_api.TF_GraphToFunction_wrapper(
          func_graph._c_graph,
          self._func_name,
          False,  # append_hash_to_fn_name
          None,  # opers
          [tensor._as_tf_output() for tensor in inputs],
          [tensor._as_tf_output() for tensor in outputs],
          output_names,
          None,  # opts
          description,
          tf_status)
    self._set_c_attrs(kwargs_attr)
Example #45
0
def read_file_to_string(filename):
  """Return the result of `pywrap_tensorflow.ReadFileToString` for `filename`.

  Args:
    filename: Path of the file; converted with `compat.as_bytes`.

  Returns:
    The value returned by the native `ReadFileToString` call.

  Raises:
    Whatever `errors.raise_exception_on_not_ok_status` raises when the native
    call leaves a non-OK status.
  """
  with errors.raise_exception_on_not_ok_status() as tf_status:
    encoded_name = compat.as_bytes(filename)
    return pywrap_tensorflow.ReadFileToString(encoded_name, tf_status)
Example #46
0
 def __del__(self):
     """Delete the native context held in `self._handle`, if there is one."""
     if self._handle is None:
         return
     with errors.raise_exception_on_not_ok_status() as tf_status:
         pywrap_tensorflow.TFE_DeleteContext(self._handle, tf_status)
Example #47
0
 def _BuildTFItem(self):
     """Create `self._tf_item` from the serialized metagraph.

     Passes the serialized `self._metagraph` and the two colocation/placement
     flags stored on the instance to `tf_item.TF_NewItem`.
     """
     with errors.raise_exception_on_not_ok_status() as tf_status:
         serialized_metagraph = self._metagraph.SerializeToString()
         self._tf_item = tf_item.TF_NewItem(
             serialized_metagraph,
             self._ignore_colocation,
             self._ignore_user_placement,
             tf_status)
Example #48
0
 def write(self, file_content):
     """Append `file_content` to the end of the file.

     Args:
       file_content: Data to append; converted with `compat.as_bytes`.
     """
     self._prewrite_check()
     with errors.raise_exception_on_not_ok_status() as tf_status:
         payload = compat.as_bytes(file_content)
         pywrap_tensorflow.AppendToFile(payload, self._writable_file,
                                        tf_status)
Example #49
0
 def _create_offset_reader(self, file_path, offset):
     """Open a record reader on `file_path` starting at `offset`.

     Returns:
       The object returned by `pywrap_tensorflow.PyRecordReader_New`, called
       with an empty compression-type string.
     """
     # TODO(b/136474806): Use tf_record.tf_record_iterator() once it
     # supports offset.
     with errors.raise_exception_on_not_ok_status() as tf_status:
         record_reader = pywrap_tensorflow.PyRecordReader_New(
             file_path, offset, b"", tf_status)
         return record_reader
Example #50
0
 def seek(self, position):
     """Move the read buffer to `position`.

     The status returned by `self._read_buf.Seek` is copied into the guarded
     TF status so a failed seek surfaces as an exception.
     """
     self._preread_check()
     with errors.raise_exception_on_not_ok_status() as tf_status:
         seek_result = self._read_buf.Seek(position)
         pywrap_tensorflow.Set_TF_Status_from_Status(tf_status, seek_result)
Example #51
0
def get_matching_files(filename):
  """Return the result of `pywrap_tensorflow.GetMatchingFiles` for `filename`.

  Args:
    filename: Path or pattern to match; converted with `compat.as_bytes`.

  Returns:
    The value returned by the native `GetMatchingFiles` call.

  Raises:
    Whatever `errors.raise_exception_on_not_ok_status` raises when the native
    call leaves a non-OK status.
  """
  with errors.raise_exception_on_not_ok_status() as tf_status:
    encoded_pattern = compat.as_bytes(filename)
    return pywrap_tensorflow.GetMatchingFiles(encoded_pattern, tf_status)
Example #52
0
# Set up the low-level arguments for a direct TF_Run call.
session = sess._session
options=None
feed_dict = {}

# By default, fetch the tensor MatMul_2:0 and run no extra targets.
fetch_list = [b'MatMul_2:0']
target_list = []

# When the first command-line argument contains 'nofetch', run MatMul_2 as a
# target only and fetch nothing.
if len(sys.argv)>1 and 'nofetch' in sys.argv[1]:
    fetch_list=[]
    target_list=[b'MatMul_2']
    
run_metadata = None
# NOTE(review): status_orig is not referenced again in this fragment; TF_Run
# below is given the raw TF_NewStatus() object instead. Confirm it is needed.
status_orig = errors.raise_exception_on_not_ok_status()
status = pywrap_tensorflow.TF_NewStatus()

def fast_tf():
    """Call TF_Run directly with the module-level arguments prepared above."""
    return tf_session.TF_Run(session, options,
                             feed_dict, fetch_list, target_list,
                             status, run_metadata)

num_iters = 5000
warmup_iters = 2
# One timing slot per iteration, warm-up iterations included.
iter_times = np.zeros((num_iters+warmup_iters,))
y = create_graph()
for i in range(num_iters+warmup_iters):
    iter_start = time.time()
    # Start the overall clock once the warm-up iterations are done.
    if i == warmup_iters:
        start_time = time.time()
Example #53
0
def write_string_to_file(filename, file_content):
  """Write `file_content` to `filename` via the native `WriteStringToFile`.

  Args:
    filename: Destination path; converted with `compat.as_bytes`.
    file_content: Data to write; converted with `compat.as_bytes`.

  Raises:
    Whatever `errors.raise_exception_on_not_ok_status` raises when the native
    call leaves a non-OK status.
  """
  with errors.raise_exception_on_not_ok_status() as tf_status:
    encoded_name = compat.as_bytes(filename)
    payload = compat.as_bytes(file_content)
    pywrap_tensorflow.WriteStringToFile(encoded_name, payload, tf_status)
Example #54
0
 def flush(self):
     """Flush the underlying writer, raising if its status is not OK."""
     writer = self._writer
     with errors.raise_exception_on_not_ok_status() as tf_status:
         writer.Flush(tf_status)
Example #55
0
def op_attr_type(op_type, attr_name):
    """Look up the type of attribute `attr_name` on op type `op_type`.

    Creates an op of `op_type` on the current context handle and asks
    `pywrap_tensorflow.TFE_OpGetAttrType` for the attribute's type.

    Returns:
        The value returned by `TFE_OpGetAttrType`.
    """
    with errors.raise_exception_on_not_ok_status() as tf_status:
        ctx_handle = context.context()._handle  # pylint: disable=protected-access
        eager_op = pywrap_tensorflow.TFE_NewOp(ctx_handle, op_type, tf_status)
        return pywrap_tensorflow.TFE_OpGetAttrType(
            eager_op, attr_name, tf_status)
Example #56
0
def import_graph_def(graph_def,
                     input_map=None,
                     return_elements=None,
                     name=None,
                     op_dict=None,
                     producer_op_list=None):
    """Imports the graph from `graph_def` into the current default `Graph`.

  This function provides a way to import a serialized TensorFlow
  [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
  protocol buffer, and extract individual objects in the `GraphDef` as
  @{tf.Tensor} and @{tf.Operation} objects. Once extracted,
  these objects are placed into the current default `Graph`. See
  @{tf.Graph.as_graph_def} for a way to create a `GraphDef`
  proto.

  Two implementations live in this function: when the default graph has a
  C graph, the import is delegated to the C API; otherwise the nodes are
  recreated one by one in Python.

  Args:
    graph_def: A `GraphDef` proto containing operations to be imported into
      the default graph.
    input_map: A dictionary mapping input names (as strings) in `graph_def`
      to `Tensor` objects. The values of the named input tensors in the
      imported graph will be re-mapped to the respective `Tensor` values.
    return_elements: A list of strings containing operation names in
      `graph_def` that will be returned as `Operation` objects; and/or
      tensor names in `graph_def` that will be returned as `Tensor` objects.
    name: (Optional.) A prefix that will be prepended to the names in
      `graph_def`. Note that this does not apply to imported function names.
      Defaults to `"import"`.
    op_dict: (Optional.) Deprecated, do not use. The passed value is ignored;
      it is replaced by the registered ops on entry.
    producer_op_list: (Optional.) An `OpList` proto with the (possibly stripped)
      list of `OpDef`s used by the producer of the graph. If provided,
      unrecognized attrs for ops in `graph_def` that have their default value
      according to `producer_op_list` will be removed. This will allow some more
      `GraphDef`s produced by later binaries to be accepted by earlier binaries.

  Returns:
    A list of `Operation` and/or `Tensor` objects from the imported graph,
    corresponding to the names in `return_elements`, or `None` when
    `return_elements` is `None`.

  Raises:
    TypeError: If `graph_def` is not a `GraphDef` proto,
      `input_map` is not a dictionary mapping strings to `Tensor` objects,
      or `return_elements` is not a list of strings.
    ValueError: If `input_map`, or `return_elements` contains names that
      do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
      it refers to an unknown tensor).
  """
    # The caller-supplied op_dict is deliberately discarded (deprecated).
    op_dict = op_def_registry.get_registered_ops()

    graph_def = _ProcessGraphDefParam(graph_def, op_dict)
    input_map = _ProcessInputMapParam(input_map)
    return_elements = _ProcessReturnElementsParam(return_elements)

    if producer_op_list is not None:
        # TODO(skyewm): make a copy of graph_def so we're not mutating the argument?
        _RemoveDefaultAttrs(op_dict, producer_op_list, graph_def)

    graph = ops.get_default_graph()

    # C API path: the import itself is performed by the C layer.
    if graph._c_graph:  # pylint: disable=protected-access
        with ops.name_scope(name, 'import', input_map.values()) as scope:
            # Save unique prefix generated by name_scope
            if scope:
                assert scope.endswith('/')
                prefix = scope[:-1]
            else:
                prefix = ''

            # Generate any input map tensors inside name scope
            input_map = _ConvertInputMapValues(name, input_map)

        scoped_options = c_api_util.ScopedTFImportGraphDefOptions()
        options = scoped_options.options
        _PopulateTFImportGraphDefOptions(options, prefix, input_map,
                                         return_elements)

        # _ProcessNewOps mutates the new operations. _lock ensures a Session.run
        # call cannot occur between creating the TF_Operations in the
        # TF_GraphImportGraphDefWithResults call and mutating them in
        # _ProcessNewOps.
        with graph._lock:  # pylint: disable=protected-access
            with c_api_util.tf_buffer(
                    graph_def.SerializeToString()) as serialized:
                try:
                    with errors.raise_exception_on_not_ok_status() as status:
                        results = c_api.TF_GraphImportGraphDefWithResults(
                            graph._c_graph, serialized, options, status)  # pylint: disable=protected-access
                except errors.InvalidArgumentError as e:
                    # Convert to ValueError for backwards compatibility.
                    raise ValueError(str(e))

            _ProcessNewOps(graph)

        # Create _DefinedFunctions for any imported functions.
        #
        # We do this by creating _DefinedFunctions directly from `graph_def`, and
        # adding them to `graph`. Adding an existing function to a TF_Graph is a
        # no-op, so this only has the effect of updating the Python state (usually
        # _DefinedFunction.add_to_graph also adds the function to the TF_Graph).
        #
        # TODO(skyewm): fetch the TF_Functions directly from the TF_Graph
        # TODO(skyewm): avoid sending serialized FunctionDefs back to the TF_Graph
        if graph_def.library and graph_def.library.function:
            # pylint: disable=protected-access
            functions = function._from_library(graph_def.library)
            for f in functions:
                f.add_to_graph(graph)
            # pylint: enable=protected-access

        # Treat input mappings that don't appear in the graph as an error, because
        # they are likely to be due to a typo.
        missing_unused_input_keys = (
            c_api.TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper(
                results))
        if missing_unused_input_keys:
            missing_unused_input_keys = [
                compat.as_str(s) for s in missing_unused_input_keys
            ]
            raise ValueError(
                'Attempted to map inputs that were not found in graph_def: [%s]'
                % ', '.join(missing_unused_input_keys))

        if return_elements is None:
            return None
        else:
            return _GatherReturnElements(return_elements, graph, results)

    # Python path: recreate every node of `graph_def` in the default graph.
    else:
        g = graph

        # Use a canonical representation for all tensor names.
        input_map = {_CanonicalInputName(k): v for k, v in input_map.items()}
        used_input_keys = set()
        name_to_op = {}

        # Add any functions defined in `graph_def` to `g`
        if graph_def.library and graph_def.library.function:
            # Copy op_dict so we don't clobber the original
            op_dict = copy.copy(op_dict)
            # pylint: disable=protected-access
            # Note that we do not prepend `name` to the function name. The reasoning
            # is that function names are similar to op definition names, which
            # currently do not have a scoped name or namespace scheme.
            functions = function._from_library(graph_def.library)
            for f in functions:
                f.add_to_graph(g)
                op_dict[f.name] = f.definition.signature
            # pylint: enable=protected-access

        # LINT.IfChange
        with ops.name_scope(name, 'import', input_map.values()) as scope:
            # TODO(ashankar): Should this just copy over or should it do some
            # more nuanced merging? For example, the graph may already have some
            # marked "bad versions" and we don't want to lose those because of
            # what's in graph_def.versions? The C++ ImportGraphDef does something
            # more nuanced.
            g.graph_def_versions.CopyFrom(graph_def.versions)

            input_map = _ConvertInputMapValues(name, input_map)

            # NOTE(mrry): We do this in two passes, because there may be a cycle in
            # `graph_def`.

            # 1. Add operations without their inputs.
            for node in graph_def.node:
                # Check to see if this op's name matches a previously seen op
                if node.name in name_to_op:
                    raise ValueError('Duplicate name \'%s\' in GraphDef.' %
                                     node.name)
                if node.op not in op_dict:
                    raise ValueError('No op named %s in defined operations.' %
                                     node.op)
                op_def = op_dict[node.op]

                output_types = _OutputTypes(node, op_dict)
                name_to_op[node.name] = g.create_op(node.op, [],
                                                    output_types,
                                                    name=node.name,
                                                    attrs=node.attr,
                                                    compute_shapes=False,
                                                    compute_device=False,
                                                    op_def=op_def)

            # Maps from a node to the ops it is colocated with, if colocation
            # is specified in the attributes.
            colocation_pairs = collections.defaultdict(list)

            # 2. Add inputs to the operations.
            for node in graph_def.node:
                op = name_to_op[node.name]
                input_types = _InputTypes(node, op_dict)
                apply_device_function = True

                # Rewrite the colocation attributes in the graph, since the
                # names of new ops may have changed.
                for key, value in op.node_def.attr.items():
                    if key == '_class':
                        class_values = value.list
                        new_class_values = []
                        for class_value in class_values.s:
                            if class_value.startswith(b'loc:@'):
                                # Strip the 5-byte 'loc:@' prefix to get the op name.
                                op_to_bind_to = class_value[5:].decode()
                                # Find the op by its original name.
                                if op_to_bind_to not in name_to_op:
                                    raise ValueError(
                                        'Specified colocation to an op that '
                                        'does not exist during import: %s in %s'
                                        % (op_to_bind_to, node.name))
                                original_op = name_to_op[op_to_bind_to]
                                new_class_values.append(
                                    compat.as_bytes('loc:@' +
                                                    original_op.name))
                                if op_to_bind_to != node.name:
                                    # Keep track of this mapping for a later phase.
                                    colocation_pairs[op].append(original_op)
                                    # Don't apply this op's device function,
                                    # the colocation constraint will ensure
                                    # the proper device gets assigned at runtime.
                                    apply_device_function = False

                            else:
                                new_class_values.append(class_value)
                        value.list.CopyFrom(
                            attr_value_pb2.AttrValue.ListValue(
                                s=new_class_values))

                # NOTE(mrry): We cannot use zip here because control inputs do not
                # appear in the list of input_types.
                for i, input_name in enumerate(
                    [_CanonicalInputName(x) for x in node.input]):

                    if _IsControlInput(input_name):
                        # (a) Input is a control input that should be taken from an op
                        #     in "graph_def".
                        try:
                            # Drop the leading marker character of the control input name.
                            source_op = name_to_op[input_name[1:]]
                        except KeyError:
                            raise ValueError(
                                _InvalidNodeMessage(
                                    node,
                                    'Control input %r not found in graph_def.'
                                    % (input_name, )))
                        # pylint: disable=protected-access
                        op._add_control_input(source_op)
                        # pylint: enable=protected-access

                    else:
                        try:
                            input_type = input_types[i]
                        except IndexError:
                            raise ValueError(
                                _InvalidNodeMessage(
                                    node,
                                    'More inputs specified (%r) than the op expects.'
                                    % (input_name, )))

                        if input_name in input_map:
                            # (b) Input should be replaced by a tensor from the caller.
                            source_tensor = input_map[input_name]
                            used_input_keys.add(input_name)

                        else:
                            # (c) Input should be taken from an op in `graph_def`.
                            operation_name, output_index = _ParseTensorName(
                                input_name)
                            try:
                                source_op = name_to_op[operation_name]
                                source_tensor = list(
                                    source_op.values())[output_index]
                            except (KeyError, IndexError):
                                raise ValueError(
                                    _InvalidNodeMessage(
                                        node,
                                        'Input tensor %r not found in graph_def.'
                                        % (input_name, )))

                        try:
                            # pylint: disable=protected-access
                            op._add_input(source_tensor, dtype=input_type)
                            # pylint: enable=protected-access
                        except TypeError as te:
                            raise ValueError(
                                _InvalidNodeMessage(
                                    node,
                                    'Input tensor %r %s' % (input_name, te)))

                # pylint: disable=protected-access
                if op._input_types != input_types:
                    raise ValueError(
                        _InvalidNodeMessage(
                            node,
                            'Input types mismatch (expected %r but got %r)' %
                            (', '.join(
                                dtypes.as_dtype(x).name
                                for x in input_types), ', '.join(
                                    x.name for x in op._input_types))))
                # pylint: enable=protected-access

                if not g._is_function(op.type):  # pylint: disable=protected-access
                    # Execute shape inference for this op.
                    # NOTE(mrry): If the graph contains a cycle, the full shape
                    # information may not be available for this op's inputs.
                    ops.set_shapes_for_outputs(op)
                # For nodes with _output_shapes set, set the output shapes.
                if '_output_shapes' in op.node_def.attr:
                    for i, output in enumerate(op.outputs):
                        dims = op.node_def.attr['_output_shapes'].list.shape[i]
                        # Negative dimension sizes are mapped to None (unknown).
                        output_shape = tensor_shape.TensorShape(
                            None if dims.unknown_rank else [
                                dim.size if dim.size >= 0 else None
                                for dim in dims.dim
                            ])

                        try:
                            output.set_shape(output_shape)
                        except ValueError as e:
                            # If the output shape is incompatible with what is inferred
                            # by the graph for a very specific whitelist of ops, then we
                            # ignore this output shape.  This can happen if there is a
                            # bug in the shape function for some operation, and the
                            # serialized graph def has the incorrect shape set when
                            # running on a newer binary with the fixed shape function.
                            # This is an escape hatch that allows us to correct shape
                            # functions that are not critical to correct execution but
                            # would cause graphs to fail if imported after correcting.
                            #
                            # This can be removed after 2017/03/08.
                            if op.type in [
                                    'RandomShuffleQueue', 'PaddingFIFOQueue',
                                    'FIFOQueue', 'PriorityQueue', 'QueueSize',
                                    'Stack', 'Barrier', 'BarrierReadySize',
                                    'BarrierIncompleteSize', 'HashTable',
                                    'MutableHashTable',
                                    'MutableHashTableOfTensors', 'Mutex',
                                    'CuckooTable', 'IndexTable',
                                    'WholeFileReader', 'TextLineReader',
                                    'FixedLengthRecordReader',
                                    'TFRecordReader', 'IdentityReader',
                                    'LMDBReader', 'RefSwitch', 'RefEnter',
                                    'RefNextIteration', 'RefMerge',
                                    'RefIdentity'
                            ]:
                                pass
                            elif op.type in [
                                    'ConditionalAccumulator',
                                    'SparseConditionalAccumulator', 'Table'
                            ]:
                                # This can be removed after 2017/04/24.
                                pass
                            else:
                                raise e

                    # The attr has been consumed; remove it from the imported node.
                    del op.node_def.attr['_output_shapes']

                # NOTE(mrry): We do this after configuring the inputs, because
                # the result of the device functions may depend on the inputs.
                if apply_device_function:
                    with _MaybeDevice(node.device):
                        g._apply_device_functions(op)  # pylint: disable=protected-access

            # The following loop populates the device field of ops that are
            # colocated with another op.  This is implied by the colocation
            # attribute, but we propagate the device field for completeness.
            for op, coloc_op_list in colocation_pairs.items():
                coloc_device = None
                # Find any device in the list of colocated ops that have a
                # device, if it exists.  We assume that if multiple ops
                # have devices, they refer to the same device.  Otherwise, a
                # runtime error will occur since the colocation property
                # cannot be guaranteed.
                #
                # One possible improvement is to try to check for compatibility
                # of all devices in this list at import time here, which would
                # require implementing a compatibility function for device specs
                # in python.
                for coloc_op in coloc_op_list:
                    if coloc_op.device:
                        coloc_device = pydev.DeviceSpec.from_string(
                            coloc_op.device)
                        break
                if coloc_device:
                    op._set_device(coloc_device)  # pylint: disable=protected-access

            # Treat input mappings that don't appear in the graph as an error,
            # because they are likely to be due to a typo.
            def _IsImportedNodeOutput(tensor_name):
                """Return True if `tensor_name` names an output of an imported op."""
                operation_name, output_index = _ParseTensorName(tensor_name)
                try:
                    return output_index < len(
                        name_to_op[operation_name].outputs)
                except KeyError:
                    return False

            absent_input_keys = [
                k for k in frozenset(input_map.keys()).difference(
                    used_input_keys) if not _IsImportedNodeOutput(k)
            ]
            if absent_input_keys:
                raise ValueError(
                    'Attempted to map inputs that were not found in graph_def: [%s]'
                    % ', '.join(absent_input_keys))

            if return_elements is None:
                return None
            else:
                ret = []
                # NOTE: the loop variable shadows the `name` parameter from here on.
                for name in return_elements:
                    name = compat.as_str(name)
                    # A ':' marks a tensor name; anything else is an operation name.
                    if ':' in name:
                        try:
                            operation_name, output_index = _ParseTensorName(
                                name)
                            ret.append(name_to_op[operation_name].
                                       outputs[output_index])
                        except (ValueError, KeyError, IndexError):
                            raise ValueError(
                                'Requested return_element %r not found in graph_def.'
                                % name)
                    else:
                        try:
                            ret.append(name_to_op[name])
                        except KeyError:
                            raise ValueError(
                                'Requested return_element %r not found in graph_def.'
                                % name)
                return ret
Example #57
0
 def _prun_fn(session, handle, feed_dict, fetch_list):
   """Run one partial-run step through `tf_session.TF_PRun`.

   `target_list` is not a parameter; it is read from the enclosing scope and
   must be empty.

   Raises:
     RuntimeError: If `target_list` is non-empty.
   """
   if target_list:
     raise RuntimeError('partial_run() requires empty target_list.')
   with errors.raise_exception_on_not_ok_status() as tf_status:
     prun_result = tf_session.TF_PRun(session, handle, feed_dict,
                                      fetch_list, tf_status)
     return prun_result
Example #58
0
 def _setup_fn(session, feed_list, fetch_list, target_list):
   """Extend the graph, then set up a partial run via `TF_PRunSetup`.

   `self` is not a parameter; it is read from the enclosing scope.

   Returns:
     The value returned by `tf_session.TF_PRunSetup`.
   """
   self._extend_graph()
   with errors.raise_exception_on_not_ok_status() as tf_status:
     setup_result = tf_session.TF_PRunSetup(session, feed_list, fetch_list,
                                            target_list, tf_status)
     return setup_result
Example #59
0
 def close(self):
     """Close the underlying writer, raising if its status is not OK."""
     writer = self._writer
     with errors.raise_exception_on_not_ok_status() as tf_status:
         writer.Close(tf_status)
Example #60
0
def call_cpp_shape_fn(op,
                      input_tensors_needed=None,
                      input_tensors_as_shapes_needed=None,
                      debug_python_shape_fn=None,
                      require_shape_fn=True):
    """Compute the output shapes of `op` via its registered C++ shape function.

    Args:
      op: The graph node whose output shapes should be inferred.
      input_tensors_needed: Optional list of input indices. For each index
        the constant value of that input (when one can be determined) is
        passed along to the C++ shape function.
      input_tensors_as_shapes_needed: Optional list of input indices. For
        each index the input's `constant_value_as_shape` (when one can be
        determined) is passed along to the C++ shape function.
      debug_python_shape_fn: Testing aid for the migration to
        `call_cpp_shape_fn`; do not submit calls that set it because the
        comparison is slow. When set, this Python shape function is also
        run on `op` and its result is compared with the C++ result.
      require_shape_fn: When true, a missing C++ shape function is an error.
        When false, `unknown_shape(op)` is returned in that situation.

    Returns:
      A dictionary with these keys:
        shapes: One shape per output of `op`, as produced by the C++ shape
          inference function.
        handle_shapes: One handle shape per output.
        handle_dtypes: One handle dtype per output.

    Raises:
      ValueError: If the C++ shape function reports an invalid-argument
        error, or if `debug_python_shape_fn` disagrees with the C++ result.
      RuntimeError: If no C++ shape function is registered for the op and
        `require_shape_fn` is true.
      AssertionError: If `debug_python_shape_fn` raises while the C++ shape
        function succeeded.
    """
    if op.type == "Const":
        # Constants are special-cased even though a C++ shape function
        # exists for them: this avoids serializing large constant values.
        # Once Python calls the C API directly this can go away.
        const_shape = tensor_shape.TensorShape(
            op.get_attr("value").tensor_shape)
        return {
            "shapes": [const_shape],
            "handle_shapes": [tensor_shape.TensorShape(None).as_proto()],
            "handle_dtypes": [types_pb2.DT_INVALID]
        }

    serialized_node_def = op.node_def.SerializeToString()

    # One serialized CppShapeInferenceResult per input, carrying the input's
    # static shape together with its handle shape and handle dtype.
    input_shapes = []
    for tensor in op.inputs:
        inference_input = cpp_shape_inference_pb2.CppShapeInferenceResult()
        inference_input.shape.CopyFrom(tensor.get_shape().as_proto())
        # pylint: disable=protected-access
        inference_input.handle_shape.CopyFrom(tensor._handle_shape)
        inference_input.handle_dtype = tensor._handle_dtype
        # pylint: enable=protected-access
        input_shapes.append(inference_input.SerializeToString())
    num_inputs = len(input_shapes)

    # Constant input values, only for the requested indices; None elsewhere.
    input_tensors = [None] * num_inputs
    for idx in input_tensors_needed or ():
        constant = tensor_util.constant_value(op.inputs[idx])
        if constant is not None:
            input_tensors[idx] = np.asarray(constant)

    # Inputs interpreted as shapes, only for the requested indices; every
    # other slot holds a serialized unknown shape.
    unknown_shape_serialized = (
        tensor_shape.TensorShape(None).as_proto().SerializeToString())
    input_tensors_as_shapes = [unknown_shape_serialized] * num_inputs
    for idx in input_tensors_as_shapes_needed or ():
        as_shape = tensor_util.constant_value_as_shape(op.inputs[idx])
        if as_shape is not None:
            input_tensors_as_shapes[idx] = (
                as_shape.as_proto().SerializeToString())

    try:
        with errors.raise_exception_on_not_ok_status() as status:
            serialized_outputs = pywrap_tensorflow.RunCppShapeInference(
                serialized_node_def, input_shapes, input_tensors,
                input_tensors_as_shapes, status)
    except errors.InvalidArgumentError as err:
        # Only the "no shape function" case is tolerated here; any other
        # invalid-argument error is surfaced to the caller as a ValueError.
        if not err.message.startswith(
                "No shape inference function exists for op"):
            raise ValueError(err.message)
        shape_fn_found = False
    else:
        shape_fn_found = True

    if not shape_fn_found:
        if require_shape_fn:
            raise RuntimeError(
                "No C++ shape function registered for standard op: %s" %
                op.type)
        return unknown_shape(op)

    # Parse the serialized results and split them into per-field lists.
    inferred = [
        cpp_shape_inference_pb2.CppShapeInferenceResult().FromString(blob)
        for blob in serialized_outputs
    ]
    outcome = {
        "shapes": [proto.shape for proto in inferred],
        "handle_shapes": [proto.handle_shape for proto in inferred],
        "handle_dtypes": [proto.handle_dtype for proto in inferred]
    }

    if debug_python_shape_fn:
        try:
            python_shapes = [
                tensor_shape.as_shape(s) for s in debug_python_shape_fn(op)
            ]
        except Exception as err:
            raise AssertionError("Python shape function return error but "
                                 "C++ shape functon did not: %s" % str(err))
        cpp_shapes = [tensor_shape.as_shape(s) for s in outcome["shapes"]]
        # The two results are compared through their string forms.
        cpp_text = str(cpp_shapes)
        python_text = str(python_shapes)
        if cpp_text != python_text:
            node_def_text = str(op.node_def)
            input_shapes_text = ",".join(
                [str(tensor.get_shape()) for tensor in op.inputs])
            raise ValueError(
                ("Python vs CPP shape mismatch.  "
                 "CPP: %s vs python: %s on node %s "
                 "with input shapes %s") %
                (cpp_text, python_text, node_def_text, input_shapes_text))

    return outcome