def load_model(saved_model_path):
  """Load a keras.Model from SavedModel.

  load_model reinstantiates model state by:
  1) loading model topology from json (this will eventually come
     from metagraph).
  2) loading model weights from checkpoint.

  Args:
    saved_model_path: a string specifying the path to an existing SavedModel.

  Returns:
    a keras.Model instance.
  """
  # restore model topology from json string
  model_json_filepath = os.path.join(
      compat.as_bytes(saved_model_path),
      compat.as_bytes(constants.ASSETS_DIRECTORY),
      compat.as_bytes(constants.SAVED_MODEL_FILENAME_JSON))
  model_json = file_io.read_file_to_string(model_json_filepath)
  model = model_from_json(model_json)

  # restore model weights
  checkpoint_prefix = os.path.join(
      compat.as_text(saved_model_path),
      compat.as_text(constants.VARIABLES_DIRECTORY),
      compat.as_text(constants.VARIABLES_FILENAME))
  model.load_weights(checkpoint_prefix)
  return model
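
For context, a minimal round-trip sketch pairing this `load_model` with the module-level `save_model` helper shown later in this listing; the model and path are illustrative, not part of the original snippet:

```python
import tensorflow as tf

# Build a small graph-network (non-subclassed) Keras model; subclassed models
# are rejected by this save path.
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=[10])])

# Hypothetical usage of the helpers defined in this listing.
save_model(model, '/tmp/simple_keras_saved_model')      # writes assets/ and variables/
restored = load_model('/tmp/simple_keras_saved_model')  # topology from json, weights from ckpt
restored.summary()
```
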
  def testFormatOneTensorOneDimVarySummarize(self):
    with self.test_session():
      tensor = math_ops.range(6)
      format_output = string_ops.string_format("{}", tensor, summarize=-1)
      out = self.evaluate(format_output)
      expected = "[0 1 2 3 4 5]"
      self.assertEqual(compat.as_text(out), expected)

    with self.test_session():
      tensor = math_ops.range(6)
      format_output = string_ops.string_format("{}", tensor, summarize=1)
      out = self.evaluate(format_output)
      expected = "[0 ... 5]"
      self.assertEqual(compat.as_text(out), expected)

    with self.test_session():
      tensor = math_ops.range(6)
      format_output = string_ops.string_format("{}", tensor, summarize=2)
      out = self.evaluate(format_output)
      expected = "[0 1 ... 4 5]"
      self.assertEqual(compat.as_text(out), expected)

    with self.test_session():
      tensor = math_ops.range(6)
      format_output = string_ops.string_format("{}", tensor, summarize=10)
      out = self.evaluate(format_output)
      expected = "[0 1 2 3 4 5]"
      self.assertEqual(compat.as_text(out), expected)
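
These assertions exercise the `summarize` argument of the string-format op; a quick sketch of the same behavior through the public `tf.strings.format` wrapper in TF 2.x eager mode:

```python
import tensorflow as tf

t = tf.range(6)
# summarize controls how many leading/trailing entries are shown per dimension;
# -1 prints the full tensor.
print(tf.strings.format("{}", t, summarize=2))   # b'[0 1 ... 4 5]'
print(tf.strings.format("{}", t, summarize=-1))  # b'[0 1 2 3 4 5]'
```
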
Example #3
  def add_meta_graph_and_variables(self,
                                   sess,
                                   tags,
                                   signature_def_map=None,
                                   assets_collection=None,
                                   legacy_init_op=None):
    """Adds the current meta graph to the SavedModel and saves variables.

    Creates a Saver to save the variables from the provided session. Exports the
    corresponding meta graph def. This function assumes that the variables to be
    saved have been initialized. For a given `SavedModelBuilder`, this API must
    be called exactly once and for the first meta graph to save. For subsequent
    meta graph defs to be added, the `add_meta_graph()` API must be used.

    Args:
      sess: The TensorFlow session from which to save the meta graph and
        variables.
      tags: The set of tags with which to save the meta graph.
      signature_def_map: The map of signature defs to add to the meta graph
        def.
      assets_collection: Assets collection to be saved with SavedModel.
      legacy_init_op: Op or group of ops to execute after the restore op upon a
        load.
    """
    if self._has_saved_variables:
      raise AssertionError("Variables and assets have already been saved. "
                           "Please invoke `add_meta_graph()` instead.")

    # Save asset files and write them to disk, if any.
    self._save_and_write_assets(assets_collection)

    # Create the variables sub-directory, if it does not exist.
    variables_dir = os.path.join(
        compat.as_text(self._export_dir),
        compat.as_text(constants.VARIABLES_DIRECTORY))
    if not file_io.file_exists(variables_dir):
      file_io.recursive_create_dir(variables_dir)

    variables_path = os.path.join(
        compat.as_text(variables_dir),
        compat.as_text(constants.VARIABLES_FILENAME))

    # Add legacy init op to the SavedModel.
    self._maybe_add_legacy_init_op(legacy_init_op)

    # Save the variables and export meta graph def.
    saver = tf_saver.Saver(
        variables.all_variables(),
        sharded=True,
        write_version=saver_pb2.SaverDef.V2)
    saver.save(sess, variables_path, write_meta_graph=False)
    meta_graph_def = saver.export_meta_graph()

    # Tag the meta graph def and add it to the SavedModel.
    self._tag_and_add_meta_graph(meta_graph_def, tags, signature_def_map)

    # Mark this instance of SavedModel as having saved variables, such that
    # subsequent attempts to save variables will fail.
    self._has_saved_variables = True
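
A minimal usage sketch of this method through the public TF1-style API (`tf.compat.v1.saved_model.builder.SavedModelBuilder`); the export directory and variable are illustrative:

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()
export_dir = '/tmp/savedmodel_builder_demo'  # illustrative path
builder = tf.saved_model.builder.SavedModelBuilder(export_dir)

with tf.Session() as sess:
  v = tf.get_variable('v', shape=[], initializer=tf.zeros_initializer())
  sess.run(tf.global_variables_initializer())
  # Must be the first meta graph added; later graphs use add_meta_graph().
  builder.add_meta_graph_and_variables(sess, [tf.saved_model.tag_constants.SERVING])

builder.save()  # writes saved_model.pb; variables/ was written by the call above
```
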
Example #4
  def _do_run(self, target_list, fetch_list, feed_dict):
    """Runs a step based on the given fetches and feeds.

    Args:
      target_list: A list of byte arrays corresponding to names of tensors
        or operations to be run, but not fetched.
      fetch_list: A list of byte arrays corresponding to names of tensors to
        be fetched and operations to be run.
      feed_dict: A dictionary that maps tensor names (as byte arrays) to
        numpy ndarrays.

    Returns:
      A list of numpy ndarrays, corresponding to the elements of
      `fetch_list`.  If the ith element of `fetch_list` contains the
      name of an operation, the first Tensor output of that operation
      will be returned for that element.
    """
    try:
      # Ensure any changes to the graph are reflected in the runtime.
      with self._extend_lock:
        if self._graph.version > self._current_version:
          graph_def = self._graph.as_graph_def(
              from_version=self._current_version)

          try:
            status = tf_session.TF_NewStatus()
            tf_session.TF_ExtendGraph(
                self._session, graph_def.SerializeToString(), status)
            if tf_session.TF_GetCode(status) != 0:
              raise RuntimeError(compat.as_text(tf_session.TF_Message(status)))
            self._opened = True
          finally:
            tf_session.TF_DeleteStatus(status)

          self._current_version = self._graph.version

      return tf_session.TF_Run(self._session, feed_dict, fetch_list,
                               target_list)

    except tf_session.StatusNotOK as e:
      e_type, e_value, e_traceback = sys.exc_info()
      error_message = compat.as_text(e.error_message)
      m = BaseSession._NODEDEF_NAME_RE.search(error_message)
      if m is not None:
        node_name = m.group(1)
        node_def = None
        try:
          op = self._graph.get_operation_by_name(node_name)
          node_def = op.node_def
        except KeyError:
          op = None
        # pylint: disable=protected-access
        raise errors._make_specific_exception(node_def, op, error_message,
                                              e.code)
        # pylint: enable=protected-access
      six.reraise(e_type, e_value, e_traceback)
Example #5
def load_from_saved_model(saved_model_path, custom_objects=None):
  """Loads a keras Model from a SavedModel created by `export_saved_model()`.

  This function reinstantiates model state by:
  1) loading model topology from json (this will eventually come
     from metagraph).
  2) loading model weights from checkpoint.

  Example:

  ```python
  import tensorflow as tf

  # Create a tf.keras model.
  model = tf.keras.Sequential()
  model.add(tf.keras.layers.Dense(1, input_shape=[10]))
  model.summary()

  # Save the tf.keras model in the SavedModel format.
  path = '/tmp/simple_keras_model'
  tf.keras.experimental.export_saved_model(model, path)

  # Load the saved keras model back.
  new_model = tf.keras.experimental.load_from_saved_model(path)
  new_model.summary()
  ```

  Args:
    saved_model_path: a string specifying the path to an existing SavedModel.
    custom_objects: Optional dictionary mapping names
        (strings) to custom classes or functions to be
        considered during deserialization.

  Returns:
    a keras.Model instance.
  """
  # restore model topology from json string
  model_json_filepath = os.path.join(
      compat.as_bytes(saved_model_path),
      compat.as_bytes(constants.ASSETS_DIRECTORY),
      compat.as_bytes(constants.SAVED_MODEL_FILENAME_JSON))
  model_json = file_io.read_file_to_string(model_json_filepath)
  model = model_from_json(model_json, custom_objects=custom_objects)

  # restore model weights
  checkpoint_prefix = os.path.join(
      compat.as_text(saved_model_path),
      compat.as_text(constants.VARIABLES_DIRECTORY),
      compat.as_text(constants.VARIABLES_FILENAME))
  model.load_weights(checkpoint_prefix)
  return model
 def _start_local_server(self):
   address = self._requestComputeMetadata('instance/network-interfaces/0/ip')
   self._server = server_lib.Server(
       {
           'local': ['0.0.0.0:0']
       }, protocol='grpc', config=None, start=True)
   # self._server.target is of the form: grpc://ipaddress:port
   target = compat.as_bytes(self._server.target)
   splits = target.split(compat.as_bytes(':'))
   assert len(splits) == 3, self._server.target
   assert splits[0] == compat.as_bytes('grpc'), self._server.target
   self._coordinator_port = compat.as_text(splits[2])
   self._coordinator_address = '%s:%s' % (
       address, compat.as_text(self._coordinator_port))
  def testFormatOneTensorOneDim(self):
    with self.test_session():
      tensor = math_ops.range(10)
      format_output = string_ops.string_format("{}", tensor)
      out = self.evaluate(format_output)
      expected = "[0 1 2 ... 7 8 9]"
      self.assertEqual(compat.as_text(out), expected)

    with self.test_session():
      tensor = math_ops.range(10)
      format_output = string_ops.string_format("{}", [tensor])
      out = self.evaluate(format_output)
      expected = "[0 1 2 ... 7 8 9]"
      self.assertEqual(compat.as_text(out), expected)
def save_model(model, saved_model_path):
  """Save a `tf.keras.Model` into Tensorflow SavedModel format.

  `save_model` generates the following files/folders under the
  `saved_model_path` directory:
  1) an assets folder containing the json string of the model's
     configuration (topology).
  2) a checkpoint containing the model weights.

  Note that subclassed models cannot be saved via this function unless you
  provide an implementation of `get_config()` and `from_config()`. Also note
  that `tf.keras.optimizers.Optimizer` instances cannot currently be saved to
  checkpoints; use optimizers from `tf.train` instead.

  Args:
    model: A `tf.keras.Model` to be saved.
    saved_model_path: a string specifying the path to the SavedModel directory.

  Raises:
    NotImplementedError: If the passed in model is a subclassed model.
  """
  if not model._is_graph_network:
    raise NotImplementedError

  # save model configuration as a json string under assets folder.
  model_json = model.to_json()
  assets_destination_dir = os.path.join(
      compat.as_bytes(saved_model_path),
      compat.as_bytes(constants.ASSETS_DIRECTORY))

  if not file_io.file_exists(assets_destination_dir):
    file_io.recursive_create_dir(assets_destination_dir)

  model_json_filepath = os.path.join(
      compat.as_bytes(assets_destination_dir),
      compat.as_bytes(constants.SAVED_MODEL_FILENAME_JSON))
  file_io.write_string_to_file(model_json_filepath, model_json)

  # save model weights in checkpoint format.
  checkpoint_destination_dir = os.path.join(
      compat.as_bytes(saved_model_path),
      compat.as_bytes(constants.VARIABLES_DIRECTORY))

  if not file_io.file_exists(checkpoint_destination_dir):
    file_io.recursive_create_dir(checkpoint_destination_dir)

  checkpoint_prefix = os.path.join(
      compat.as_text(checkpoint_destination_dir),
      compat.as_text(constants.VARIABLES_FILENAME))
  model.save_weights(checkpoint_prefix, save_format='tf', overwrite=True)
Example #9
def load_keras_model(saved_model_path):
  """Load a keras.Model from SavedModel.

  load_model reinstantiates model state by:
  1) loading model topology from json (this will eventually come
     from metagraph).
  2) loading model weights from checkpoint.

  Example:

  ```python
  import tensorflow as tf

  # Create a tf.keras model.
  model = tf.keras.Sequential()
  model.add(tf.keras.layers.Dense(1, input_shape=[10]))
  model.summary()

  # Save the tf.keras model in the SavedModel format.
  saved_to_path = tf.contrib.saved_model.save_keras_model(
        model, '/tmp/my_simple_tf_keras_saved_model')

  # Load the saved keras model back.
  model_prime = tf.contrib.saved_model.load_keras_model(saved_to_path)
  model_prime.summary()
  ```

  Args:
    saved_model_path: a string specifying the path to an existing SavedModel.

  Returns:
    a keras.Model instance.
  """
  # restore model topology from json string
  model_json_filepath = os.path.join(
      compat.as_bytes(saved_model_path),
      compat.as_bytes(constants.ASSETS_DIRECTORY),
      compat.as_bytes(constants.SAVED_MODEL_FILENAME_JSON))
  model_json = file_io.read_file_to_string(model_json_filepath)
  model = model_from_json(model_json)

  # restore model weights
  checkpoint_prefix = os.path.join(
      compat.as_text(saved_model_path),
      compat.as_text(constants.VARIABLES_DIRECTORY),
      compat.as_text(constants.VARIABLES_FILENAME))
  model.load_weights(checkpoint_prefix)
  return model
Example #10
  def testWriteEvents(self):
    file_prefix = os.path.join(self.get_temp_dir(), "events")
    writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(file_prefix))
    filename = compat.as_text(writer.FileName())
    event_written = event_pb2.Event(
        wall_time=123.45, step=67,
        summary=summary_pb2.Summary(
            value=[summary_pb2.Summary.Value(tag="foo", simple_value=89.0)]))
    writer.WriteEvent(event_written)
    writer.Flush()
    writer.Close()

    with self.assertRaises(IOError):
      for r in tf_record.tf_record_iterator(filename + "DOES_NOT_EXIST"):
        self.assertTrue(False)

    reader = tf_record.tf_record_iterator(filename)
    event_read = event_pb2.Event()

    event_read.ParseFromString(next(reader))
    self.assertTrue(event_read.HasField("file_version"))

    event_read.ParseFromString(next(reader))
    # Second event
    self.assertProtoEquals("""
    wall_time: 123.45 step: 67
    summary { value { tag: 'foo' simple_value: 89.0 } }
    """, event_read)

    with self.assertRaises(StopIteration):
      next(reader)
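
The events file written above is just a TFRecord file of `Event` protos; a small sketch of replaying one with the public `tf.compat.v1.train.summary_iterator` (the path is illustrative):

```python
import tensorflow as tf

# Iterate over the Event protos stored in an events file.
for event in tf.compat.v1.train.summary_iterator('/tmp/events.out.tfevents.demo'):
  for value in event.summary.value:
    print(event.step, value.tag, value.simple_value)
```
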
Example #11
def _export_model_json(model, saved_model_path):
  """Saves model configuration as a json string under assets folder."""
  model_json = model.to_json()
  model_json_filepath = os.path.join(
      saved_model_utils.get_or_create_assets_dir(saved_model_path),
      compat.as_text(constants.SAVED_MODEL_FILENAME_JSON))
  file_io.write_string_to_file(model_json_filepath, model_json)
Example #12
  def testZLibFlushRecord(self):
    fn = self._WriteRecordsToFile([b"small record"], "small_record")
    with open(fn, "rb") as h:
      buff = h.read()

    # creating more blocks and trailing blocks shouldn't break reads
    compressor = zlib.compressobj(9, zlib.DEFLATED, zlib.MAX_WBITS)

    output = b""
    for c in buff:
      if isinstance(c, int):
        c = six.int2byte(c)
      output += compressor.compress(c)
      output += compressor.flush(zlib.Z_FULL_FLUSH)

    output += compressor.flush(zlib.Z_FULL_FLUSH)
    output += compressor.flush(zlib.Z_FULL_FLUSH)
    output += compressor.flush(zlib.Z_FINISH)

    # overwrite the original file with the compressed data
    with open(fn, "wb") as h:
      h.write(output)

    with self.test_session() as sess:
      options = tf_record.TFRecordOptions(
          compression_type=TFRecordCompressionType.ZLIB)
      reader = io_ops.TFRecordReader(name="test_reader", options=options)
      queue = data_flow_ops.FIFOQueue(1, [dtypes.string], shapes=())
      key, value = reader.read(queue)
      queue.enqueue(fn).run()
      queue.close().run()
      k, v = sess.run([key, value])
      self.assertTrue(compat.as_text(k).startswith("%s:" % fn))
      self.assertAllEqual(b"small record", v)
Example #13
  def testReadGzipFiles(self):
    files = self._CreateFiles()
    gzip_files = []
    for i, fn in enumerate(files):
      with open(fn, "rb") as f:
        cdata = f.read()

        zfn = os.path.join(self.get_temp_dir(), "tfrecord_%s.gz" % i)
        with gzip.GzipFile(zfn, "wb") as f:
          f.write(cdata)
        gzip_files.append(zfn)

    with self.test_session() as sess:
      options = tf_record.TFRecordOptions(TFRecordCompressionType.GZIP)
      reader = io_ops.TFRecordReader(name="test_reader", options=options)
      queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
      key, value = reader.read(queue)

      queue.enqueue_many([gzip_files]).run()
      queue.close().run()
      for i in range(self._num_files):
        for j in range(self._num_records):
          k, v = sess.run([key, value])
          self.assertTrue(compat.as_text(k).startswith("%s:" % gzip_files[i]))
          self.assertAllEqual(self._Record(i, j), v)
Example #14
  def _TestOneEpochWithHopBytes(self,
                                files,
                                num_overlapped_records,
                                encoding=None):
    with self.test_session() as sess:
      reader = io_ops.FixedLengthRecordReader(
          header_bytes=self._header_bytes,
          record_bytes=self._record_bytes,
          footer_bytes=self._footer_bytes,
          hop_bytes=self._hop_bytes,
          encoding=encoding,
          name="test_reader")
      queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
      key, value = reader.read(queue)

      queue.enqueue_many([files]).run()
      queue.close().run()
      for i in range(self._num_files):
        for j in range(num_overlapped_records):
          k, v = sess.run([key, value])
          self.assertAllEqual("%s:%d" % (files[i], j), compat.as_text(k))
          self.assertAllEqual(self._OverlappedRecord(i, j), v)

      with self.assertRaisesOpError("is closed and has insufficient elements "
                                    "\\(requested 1, current size 0\\)"):
        k, v = sess.run([key, value])
Example #15
 def testFormatOneTensorOneDimFloat(self):
   with self.test_session():
     tensor = constant_op.constant([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7])
     format_output = string_ops.string_format("{}", tensor)
     out = self.evaluate(format_output)
     expected = "[0 0.1 0.2 ... 0.5 0.6 0.7]"
     self.assertEqual(compat.as_text(out), expected)
Example #16
 def testFormatOneTensorOneDimAlmostSummarize(self):
   with self.test_session():
     tensor = math_ops.range(5)
     format_output = string_ops.string_format("{}", tensor, summarize=3)
     out = self.evaluate(format_output)
     expected = "[0 1 2 3 4]"
     self.assertEqual(compat.as_text(out), expected)
  def testComplexCodeView(self):
    ops.reset_default_graph()
    outfile = os.path.join(test.get_temp_dir(), 'dump')
    opts = (builder(builder.trainable_variables_parameter())
            .with_file_output(outfile)
            .with_accounted_types(['.*'])
            .with_node_names(show_name_regexes=
                             ['.*model_analyzer_testlib.py.*'])
            .account_displayed_op_only(False)
            .select(['params', 'float_ops']).build())

    with profile_context.ProfileContext(test.get_temp_dir(),
                                        trace_steps=[],
                                        dump_steps=[]) as pctx:
      with session.Session() as sess:
        x = lib.BuildFullModel()

        sess.run(variables.global_variables_initializer())
        pctx.trace_next_step()
        _ = sess.run(x)
        tfprof_node = pctx.profiler.profile_python(options=opts)

        # pylint: disable=line-too-long
        with gfile.Open(outfile, 'r') as f:
          lines = f.read().split('\n')
          self.assertGreater(len(lines), 5)
          result = '\n'.join([l[:min(len(l), 80)] for l in lines])
          self.assertTrue(
              compat.as_text(lib.CheckAndRemoveDoc(result))
              .startswith('node name | # parameters | # float_ops'))

        self.assertLess(0, tfprof_node.total_exec_micros)
        self.assertEqual(2844, tfprof_node.total_parameters)
        self.assertLess(145660, tfprof_node.total_float_ops)
        self.assertEqual(8, len(tfprof_node.children))
        self.assertEqual('_TFProfRoot', tfprof_node.name)
        self.assertEqual(
            'model_analyzer_testlib.py:63:BuildFullModel',
            tfprof_node.children[0].name)
        self.assertEqual(
            'model_analyzer_testlib.py:63:BuildFullModel (gradient)',
            tfprof_node.children[1].name)
        self.assertEqual(
            'model_analyzer_testlib.py:67:BuildFullModel',
            tfprof_node.children[2].name)
        self.assertEqual(
            'model_analyzer_testlib.py:67:BuildFullModel (gradient)',
            tfprof_node.children[3].name)
        self.assertEqual(
            'model_analyzer_testlib.py:69:BuildFullModel',
            tfprof_node.children[4].name)
        self.assertEqual(
            'model_analyzer_testlib.py:70:BuildFullModel',
            tfprof_node.children[5].name)
        self.assertEqual(
            'model_analyzer_testlib.py:70:BuildFullModel (gradient)',
            tfprof_node.children[6].name)
        self.assertEqual(
            'model_analyzer_testlib.py:72:BuildFullModel',
            tfprof_node.children[7].name)
Example #18
def run_benchmark(sess, init_op, add_op):
  """Returns MB/s rate of addition."""


  logdir=FLAGS.logdir_prefix+'/'+FLAGS.name
  os.system('mkdir -p '+logdir)
  
  # TODO: make events follow same format as eager writer
  writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(logdir+'/events'))
  filename = compat.as_text(writer.FileName())
  training_util.get_or_create_global_step()

  sess.run(init_op)

  for step in range(FLAGS.iters):
    start_time = time.time()
    for i in range(FLAGS.iters_per_step):
      sess.run(add_op.op)

    elapsed_time = time.time() - start_time
    rate = float(FLAGS.iters)*FLAGS.data_mb/elapsed_time
    event = make_event('rate', rate, step)
    writer.WriteEvent(event)
    writer.Flush()
  writer.Close()
Example #19
  def save(self, as_text=False):
    """Writes a `SavedModel` protocol buffer to disk.

    The function writes the SavedModel protocol buffer to the export directory
    in serialized format.

    Args:
      as_text: Writes the SavedModel protocol buffer in text format to disk.

    Returns:
      The path to which the SavedModel protocol buffer was written.
    """
    if not file_io.file_exists(self._export_dir):
      file_io.recursive_create_dir(self._export_dir)

    if as_text:
      path = os.path.join(
          compat.as_bytes(self._export_dir),
          compat.as_bytes(constants.SAVED_MODEL_FILENAME_PBTXT))
      file_io.write_string_to_file(path, str(self._saved_model))
    else:
      path = os.path.join(
          compat.as_bytes(self._export_dir),
          compat.as_bytes(constants.SAVED_MODEL_FILENAME_PB))
      file_io.write_string_to_file(path, self._saved_model.SerializeToString())
    tf_logging.info("SavedModel written to: %s", compat.as_text(path))

    return path
Example #20
  def _save_and_write_assets(self, assets_collection_to_add=None):
    """Saves asset to the meta graph and writes asset files to disk.

    Args:
      assets_collection_to_add: The collection where the asset paths are setup.
    """
    asset_filename_map = _maybe_save_assets(assets_collection_to_add)

    # Return if there are no assets to write.
    if not asset_filename_map:
      tf_logging.info("No assets to write.")
      return

    assets_destination_dir = saved_model_utils.get_or_create_assets_dir(
        self._export_dir)

    # Copy each asset from source path to destination path.
    for asset_basename, asset_source_filepath in asset_filename_map.items():
      asset_destination_filepath = os.path.join(
          compat.as_bytes(assets_destination_dir),
          compat.as_bytes(asset_basename))

      # Only copy the asset file to the destination if it does not already
      # exist. This is to ensure that an asset with the same name defined as
      # part of multiple graphs is only copied the first time.
      if not file_io.file_exists(asset_destination_filepath):
        file_io.copy(asset_source_filepath, asset_destination_filepath)

    tf_logging.info("Assets written to: %s",
                    compat.as_text(assets_destination_dir))
Example #21
def load_file_system_library(library_filename):
  """Loads a TensorFlow plugin, containing file system implementation.

  Pass `library_filename` to a platform-specific mechanism for dynamically
  loading a library. The rules for determining the exact location of the
  library are platform-specific and are not documented here.

  Args:
    library_filename: Path to the plugin.
      Relative or absolute filesystem path to a dynamic library file.

  Returns:
    None.

  Raises:
    RuntimeError: when unable to load the library.
  """
  status = py_tf.TF_NewStatus()
  lib_handle = py_tf.TF_LoadLibrary(library_filename, status)
  try:
    error_code = py_tf.TF_GetCode(status)
    if error_code != 0:
      error_msg = compat.as_text(py_tf.TF_Message(status))
      # pylint: disable=protected-access
      raise errors_impl._make_specific_exception(
          None, None, error_msg, error_code)
      # pylint: enable=protected-access
  finally:
    py_tf.TF_DeleteStatus(status)
Example #22
def load_op_library(library_filename):
  """Loads a TensorFlow plugin, containing custom ops and kernels.

  Pass "library_filename" to a platform-specific mechanism for dynamically
  loading a library. The rules for determining the exact location of the
  library are platform-specific and are not documented here. When the
  library is loaded, ops and kernels registered in the library via the
  REGISTER_* macros are made available in the TensorFlow process. Note
  that ops with the same name as an existing op are rejected and not
  registered with the process.

  Args:
    library_filename: Path to the plugin.
      Relative or absolute filesystem path to a dynamic library file.

  Returns:
    A python module containing the Python wrappers for Ops defined in
    the plugin.

  Raises:
    RuntimeError: when unable to load the library or get the python wrappers.
  """
  status = py_tf.TF_NewStatus()

  lib_handle = py_tf.TF_LoadLibrary(library_filename, status)
  try:
    error_code = py_tf.TF_GetCode(status)
    if error_code != 0:
      error_msg = compat.as_text(py_tf.TF_Message(status))
      with _OP_LIBRARY_MAP_LOCK:
        if (error_code == error_codes_pb2.ALREADY_EXISTS and
            'has already been loaded' in error_msg and
            library_filename in _OP_LIBRARY_MAP):
          return _OP_LIBRARY_MAP[library_filename]
      # pylint: disable=protected-access
      raise errors._make_specific_exception(None, None, error_msg, error_code)
      # pylint: enable=protected-access
  finally:
    py_tf.TF_DeleteStatus(status)

  op_list_str = py_tf.TF_GetOpList(lib_handle)
  op_list = op_def_pb2.OpList()
  op_list.ParseFromString(compat.as_bytes(op_list_str))
  wrappers = py_tf.GetPythonWrappers(op_list_str)

  # Get a unique name for the module.
  module_name = hashlib.md5(wrappers).hexdigest()
  module = imp.new_module(module_name)
  # pylint: disable=exec-used
  exec(wrappers, module.__dict__)
  # Stash away the library handle for making calls into the dynamic library.
  module.LIB_HANDLE = lib_handle
  # OpDefs of the list of ops defined in the library.
  module.OP_LIST = op_list
  sys.modules[module_name] = module
  # Memoize the filename to module mapping.
  with _OP_LIBRARY_MAP_LOCK:
    _OP_LIBRARY_MAP[library_filename] = module
  return module
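
A short usage sketch of the public `tf.load_op_library` entry point; `zero_out.so` and its `ZeroOut` op follow the custom-op tutorial and are illustrative here:

```python
import tensorflow as tf

# Load a shared object built with REGISTER_OP("ZeroOut") plus its kernel.
zero_out_module = tf.load_op_library('./zero_out.so')

# Registered ops show up as snake_case wrappers on the returned module.
print(zero_out_module.zero_out([[1, 2], [3, 4]]))
```
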
  def cluster_spec(self):
    """Returns a ClusterSpec object based on the latest TPU information.

    We retrieve the information from the GCE APIs every time this method is
    called.

    Returns:
      A ClusterSpec containing host information returned from Cloud TPUs.

    Raises:
      RuntimeError: If the provided TPU is not healthy.
    """
    ############################################################################
    # There are 5 potential cases this code must handle:
    #  1. [Normal case.] We should resolve the TPU name to a set of tasks, and
    #      a. Create a ClusterSpec that includes the coordinator job
    #      b. Create a ClusterSpec without the coordinator job.
    #  2. [GKE / No API Access.] We should not resolve the TPU name to a set of
    #     tasks and
    #      a. Create a ClusterSpec with the coordinator
    #      b. Create a ClusterSpec without the coordinator
    #  3. [Other (legacy non-gRPC).] We should return an empty ClusterSpec.
    ############################################################################

    if self._shouldResolve():
      # Case 1.
      full_name = 'projects/%s/locations/%s/nodes/%s' % (
          self._project, self._zone, compat.as_text(self._tpu))
      request = self._service.projects().locations().nodes().get(name=full_name)
      response = request.execute()

      if 'health' in response and response['health'] != 'HEALTHY':
        raise RuntimeError('TPU "%s" is unhealthy: "%s"' % (self._tpu,
                                                            response['health']))

      if 'networkEndpoints' in response:
        worker_list = [
            '%s:%s' % (endpoint['ipAddress'], endpoint['port'])
            for endpoint in response['networkEndpoints']
        ]
      else:
        # Fall back to the deprecated response format
        instance_url = '%s:%s' % (response['ipAddress'], response['port'])
        worker_list = [instance_url]

      cluster_spec = {self._job_name: worker_list}
    else:
      if not self._tpu.startswith(compat.as_bytes('grpc://')):
        # Case 3.
        return None
      # Case 2.
      cluster_spec = {self._job_name: [self._tpu[len(
          compat.as_bytes('grpc://')):]]}

    if self._coordinator_address:
      # {1, 2}.a
      cluster_spec[self._coordinator_name] = [self._coordinator_address]

    return server_lib.ClusterSpec(cluster_spec)
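
A brief sketch of how this resolver is typically driven from user code via the public `tf.distribute.cluster_resolver.TPUClusterResolver`; the TPU name is illustrative:

```python
import tensorflow as tf

resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='my-tpu')
spec = resolver.cluster_spec()   # ClusterSpec built from the Cloud TPU endpoints
print(spec.as_dict())
print(resolver.master())         # grpc://<ip>:<port> of the first worker
```
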
Example #24
def raise_exception_on_not_ok_status():
  status = c_api_util.ScopedTFStatus()
  yield status.status
  if c_api.TF_GetCode(status.status) != 0:
    raise _make_specific_exception(
        None, None,
        compat.as_text(c_api.TF_Message(status.status)),
        c_api.TF_GetCode(status.status))
Example #25
 def testFormatOneTensorTwoDimLessThanSummarize(self):
   with self.test_session():
     tensor = array_ops.reshape(math_ops.range(4), [2, 2])
     format_output = string_ops.string_format("{}", tensor, summarize=3)
     out = self.evaluate(format_output)
     expected = ("[[0 1]\n"
                 " [2 3]]")
     self.assertEqual(compat.as_text(out), expected)
Example #26
 def testFormatOneVariableScalar(self):
   with self.test_session():
     var = variables.Variable(3.34)
     format_output = string_ops.string_format("{}", [var])
     if not context.executing_eagerly():
       variables.global_variables_initializer().run()
     out = self.evaluate(format_output)
     expected = "3.34"
     self.assertEqual(compat.as_text(out), expected)
Example #27
 def testFormatOneVariableOneDim(self):
   with self.test_session():
     var = variables.Variable(math_ops.range(10))
     format_output = string_ops.string_format("{}", [var])
     if not context.executing_eagerly():
       variables.global_variables_initializer().run()
     out = self.evaluate(format_output)
     expected = "[0 1 2 ... 7 8 9]"
     self.assertEqual(compat.as_text(out), expected)
Example #28
 def testFormatSummarizeOne(self):
   with self.test_session():
     tensor = array_ops.reshape(math_ops.range(100), [10, 10])
     format_output = string_ops.string_format("tensor summary: {}", tensor,
                                              summarize=1)
     out = self.evaluate(format_output)
     expected = ("tensor summary: [[0 ... 9]\n"
                 " ...\n"
                 " [90 ... 99]]")
     self.assertEqual(compat.as_text(out), expected)
Example #29
  def add_meta_graph_and_variables(self,
                                   sess,
                                   tags,
                                   signature_def_map=None,
                                   assets_collection=None):
    """Adds the current meta graph to the SavedModel and saves variables.

    Creates a Saver to save the variables from the provided session. Exports the
    corresponding meta graph def. This function assumes that the variables to be
    saved have been initialized. For a given `SavedModelBuilder`, this API must
    be called exactly once and for the first meta graph to save. For subsequent
    meta graph defs to be added, the `add_meta_graph()` API must be used.

    Args:
      sess: The TensorFlow session from which to save the meta graph and
        variables.
      tags: The set of tags with which to save the meta graph.
      signature_def_map: The map of signature defs to add to the meta graph
        def.
      assets_collection: Assets collection to be saved with SavedModel.
    """
    if self._has_saved_variables:
      raise AssertionError("Variables and assets have already been saved. "
                           "Please invoke `add_meta_graph()` instead.")

    # Save asset files and write them to disk, if any.
    self._save_and_write_assets(assets_collection)

    export_path = os.path.join(
        compat.as_text(self._export_dir),
        compat.as_text(constants.VARIABLES_FILENAME))

    # Save the variables and export meta graph def.
    saver = tf_saver.Saver(variables.all_variables())
    saver.save(sess, export_path, write_meta_graph=False)
    meta_graph_def = saver.export_meta_graph()

    # Tag the meta graph def and add it to the SavedModel.
    self._tag_and_add_meta_graph(meta_graph_def, tags, signature_def_map)

    # Mark this instance of SavedModel as having saved variables, such that
    # subsequent attempts to save variables will fail.
    self._has_saved_variables = True
Example #30
def _find_all_hints_in_graph_def(graphdef):
  """Look at the current default graph and return a list of LiteFuncCall objs.

  Args:
    graphdef: A TensorFlow graph_def to look for LiteFuncCalls.
  Returns:
    a dict mapping hint UUIDs to the `_LiteFuncCall` objects found in the graph.
  """
  func_calls = _collections.defaultdict(_LiteFuncCall)

  for node in graphdef.node:
    attr = node.attr
    # This is an op hint if it has a FUNCTION_UUID_ATTR, otherwise skip
    uuid = attr[OpHint.FUNCTION_UUID_ATTR].s
    if (OpHint.FUNCTION_UUID_ATTR not in attr
        or not attr[OpHint.FUNCTION_UUID_ATTR].s):
      continue

    # Start building function
    call_def = func_calls[uuid]
    call_def.uuid = uuid
    call_def.function_name = attr[OpHint.FUNCTION_NAME_ATTR].s
    # Get sorting and aggregation information

    sort = (attr[OpHint.FUNCTION_SORT_INDEX_ATTR].i
            if OpHint.FUNCTION_SORT_INDEX_ATTR in attr else None)
    if sort == -1: sort = None
    aggregation = None
    if OpHint.FUNCTION_AGGREGATE_ATTR in attr:
      aggregation = _compat.as_text(attr[OpHint.FUNCTION_AGGREGATE_ATTR].s)

    # Add the input or output
    def put_operand(stuff, index, sort, operand, aggregation):
      """Add a given index into the function structure."""
      if sort is None:
        stuff[index] = _LiteSingleOperand(operand)
      else:
        if index not in stuff:
          stuff[index] = _LiteAggregateOperand(aggregation)
        stuff[index].add(sort, operand)

    if OpHint.FUNCTION_INPUT_INDEX_ATTR in attr:
      put_operand(call_def.inputs, attr[OpHint.FUNCTION_INPUT_INDEX_ATTR].i,
                  sort, node, aggregation)
    if OpHint.FUNCTION_OUTPUT_INDEX_ATTR in attr:
      put_operand(call_def.outputs, attr[OpHint.FUNCTION_OUTPUT_INDEX_ATTR].i,
                  sort, node, aggregation)

    # Remember attributes
    for a in attr:
      if a.startswith("_tflite_attr_"):
        call_def.params[a.replace("_tflite_attr_,", "")] = attr[a].tensor

  return func_calls
Example #31
def copy_assets_to_destination_dir(asset_filename_map, destination_dir):
    """Copy all assets from source path to destination path."""
    assets_destination_dir = saved_model_utils.get_or_create_assets_dir(
        destination_dir)

    # Copy each asset from source path to destination path.
    for asset_basename, asset_source_filepath in asset_filename_map.items():
        asset_destination_filepath = os.path.join(
            compat.as_bytes(assets_destination_dir),
            compat.as_bytes(asset_basename))

        # Only copy the asset file to the destination if it does not already
        # exist. This is to ensure that an asset with the same name defined as
        # part of multiple graphs is only copied the first time.
        if not file_io.file_exists(asset_destination_filepath):
            file_io.copy(asset_source_filepath, asset_destination_filepath)

    tf_logging.info("Assets written to: %s",
                    compat.as_text(assets_destination_dir))
Example #32
    def deserialize(self, encoded_accumulator):
        """Deserialize an accumulator received from 'serialize()'."""
        accumulator_dict = json.loads(compat.as_text(encoded_accumulator))

        accumulator = self._create_accumulator()
        count_dict = dict(
            zip(accumulator_dict["vocab"], accumulator_dict["vocab_counts"]))
        accumulator.count_dict.update(count_dict)

        if self._compute_idf:
            accumulator.data = accumulator_dict["data"]
            create_dict = lambda x: {"count": x, "last_doc_id": -1}
            idf_count_dicts = [
                create_dict(count) for count in accumulator_dict["idf_counts"]
            ]
            idf_dict = dict(zip(accumulator_dict["idf_vocab"],
                                idf_count_dicts))
            accumulator.per_doc_count_dict.update(idf_dict)
        return accumulator
    def _testOneEpoch(self, files):
        with self.test_session() as sess:
            reader = io_ops.TextLineReader(name="test_reader")
            queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
            key, value = reader.read(queue)

            queue.enqueue_many([files]).run()
            queue.close().run()
            for i in range(self._num_files):
                for j in range(self._num_lines):
                    k, v = sess.run([key, value])
                    self.assertAllEqual("%s:%d" % (files[i], j + 1),
                                        compat.as_text(k))
                    self.assertAllEqual(self._LineText(i, j), v)

            with self.assertRaisesOpError(
                    "is closed and has insufficient elements "
                    "\\(requested 1, current size 0\\)"):
                k, v = sess.run([key, value])
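
The queue-based `TextLineReader` above keys each record as `filename:line_number`; a minimal modern counterpart with `tf.data.TextLineDataset` (file path illustrative):

```python
import tensorflow as tf

# Lines are yielded in order, one string tensor per line.
for line_number, line in enumerate(tf.data.TextLineDataset(['/tmp/lines.txt']), start=1):
  print(line_number, line.numpy())
```
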
Example #34
    def close(self):
        """Closes this session.

    Calling this method frees all resources associated with the session.

    Raises:
      RuntimeError: If an error occurs while closing the session.
    """
        with self._extend_lock:
            if self._opened and not self._closed:
                self._closed = True
                try:
                    status = tf_session.TF_NewStatus()
                    tf_session.TF_CloseSession(self._session, status)
                    if tf_session.TF_GetCode(status) != 0:
                        raise RuntimeError(
                            compat.as_text(tf_session.TF_Message(status)))
                finally:
                    tf_session.TF_DeleteStatus(status)
Example #35
    def __init__(self, target='', graph=None, config=None):
        """Constructs a new TensorFlow session.

    Args:
      target: (Optional) The TensorFlow execution engine to connect to.
      graph: (Optional) The graph to be used. If this argument is None,
        the default graph will be used.
      config: (Optional) ConfigProto proto used to configure the session.

    Raises:
      RuntimeError: If an error occurs while creating the TensorFlow
        session.
    """
        if graph is None:
            self._graph = ops.get_default_graph()
        else:
            self._graph = graph

        self._opened = False
        self._closed = False

        self._current_version = 0
        self._extend_lock = threading.Lock()
        self._target = target

        self._delete_lock = threading.Lock()
        self._dead_handles = []

        self._session = None

        opts = tf_session.TF_NewSessionOptions(target=target, config=config)
        try:
            status = tf_session.TF_NewStatus()
            try:
                self._session = tf_session.TF_NewSession(opts, status)
                if tf_session.TF_GetCode(status) != 0:
                    raise RuntimeError(
                        compat.as_text(tf_session.TF_Message(status)))
            finally:
                tf_session.TF_DeleteStatus(status)
        finally:
            tf_session.TF_DeleteSessionOptions(opts)
Example #36
    def testOneEpoch(self):
        files = self._CreateFiles()
        with self.cached_session() as sess:
            reader = io_ops.TFRecordReader(name="test_reader")
            queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
            key, value = reader.read(queue)

            queue.enqueue_many([files]).run()
            queue.close().run()
            for i in range(self._num_files):
                for j in range(self._num_records):
                    k, v = sess.run([key, value])
                    self.assertTrue(
                        compat.as_text(k).startswith("%s:" % files[i]))
                    self.assertAllEqual(self._Record(i, j), v)

            with self.assertRaisesOpError(
                    "is closed and has insufficient elements "
                    "\\(requested 1, current size 0\\)"):
                k, v = sess.run([key, value])
Example #37
    def _extend_graph(self):
        # Ensure any changes to the graph are reflected in the runtime.
        with self._extend_lock:
            if self._graph.version > self._current_version:
                graph_def = self._graph.as_graph_def(
                    from_version=self._current_version)

                try:
                    status = tf_session.TF_NewStatus()
                    tf_session.TF_ExtendGraph(self._session,
                                              graph_def.SerializeToString(),
                                              status)
                    if tf_session.TF_GetCode(status) != 0:
                        raise RuntimeError(
                            compat.as_text(tf_session.TF_Message(status)))
                    self._opened = True
                finally:
                    tf_session.TF_DeleteStatus(status)

                self._current_version = self._graph.version
Example #38
    def _write_value_event(self, event):
        value = event.summary.value[0]

        # Obtain the device name from the metadata.
        summary_metadata = event.summary.value[0].metadata
        if not summary_metadata.plugin_data:
            raise ValueError("The value lacks plugin data.")
        try:
            content = json.loads(
                compat.as_text(summary_metadata.plugin_data.content))
        except ValueError as err:
            raise ValueError("Could not parse content into JSON: %r, %r" %
                             (summary_metadata.plugin_data.content, err))
        device_name = content["device"]

        dump_full_path = _get_dump_file_path(self._dump_dir, device_name,
                                             value.node_name)
        self._try_makedirs(os.path.dirname(dump_full_path))
        with open(dump_full_path, "wb") as f:
            f.write(event.SerializeToString())
Example #39
def _serialize_object_graph(saveable_view, asset_file_def_index):
    """Save a SavedObjectGraph proto for `root`."""
    # SavedObjectGraph is similar to the TrackableObjectGraph proto in the
    # checkpoint. It will eventually go into the SavedModel.
    proto = saved_object_graph_pb2.SavedObjectGraph()
    saveable_view.fill_object_graph_proto(proto)

    coder = nested_structure_coder.StructureCoder()
    for concrete_function in saveable_view.concrete_functions:
        name = compat.as_text(concrete_function.name)
        name = saveable_view.function_name_map.get(name, name)
        serialized = function_serialization.serialize_concrete_function(
            concrete_function, saveable_view.captured_tensor_node_ids, coder)
        if serialized is not None:
            proto.concrete_functions[name].CopyFrom(serialized)

    for obj, obj_proto in zip(saveable_view.nodes, proto.nodes):
        _write_object_proto(obj, obj_proto, asset_file_def_index,
                            saveable_view.function_name_map)
    return proto
Example #40
 def _do_call(self, fn, *args):
     try:
         return fn(*args)
     except tf_session.StatusNotOK as e:
         e_type, e_value, e_traceback = sys.exc_info()
         error_message = compat.as_text(e.error_message)
         m = BaseSession._NODEDEF_NAME_RE.search(error_message)
         if m is not None:
             node_name = m.group(1)
             node_def = None
             try:
                 op = self._graph.get_operation_by_name(node_name)
                 node_def = op.node_def
             except KeyError:
                 op = None
             # pylint: disable=protected-access
             raise errors._make_specific_exception(node_def, op,
                                                   error_message, e.code)
             # pylint: enable=protected-access
         six.reraise(e_type, e_value, e_traceback)
    def cluster_spec(self):
        """Returns a ClusterSpec object based on the latest TPU information.

    We retrieve the information from the GCE APIs every time this method is
    called.

    Returns:
      A ClusterSpec containing host information returned from Cloud TPUs.

    Raises:
      RuntimeError: If the provided TPU is not healthy.
    """
        if not self._shouldResolve():
            return server_lib.ClusterSpec({})

        full_name = 'projects/%s/locations/%s/nodes/%s' % (
            self._project, self._zone, compat.as_text(self._tpu))
        request = self._service.projects().locations().nodes().get(
            name=full_name)
        response = request.execute()

        if 'health' in response and response['health'] != 'HEALTHY':
            raise RuntimeError('TPU "%s" is unhealthy: "%s"' %
                               (self._tpu, response['health']))

        if 'networkEndpoints' in response:
            worker_list = [
                '%s:%s' % (endpoint['ipAddress'], endpoint['port'])
                for endpoint in response['networkEndpoints']
            ]
        else:
            # Fall back to the deprecated response format
            instance_url = '%s:%s' % (response['ipAddress'], response['port'])
            worker_list = [instance_url]

        cluster_spec = {self._job_name: worker_list}

        if self._coordinator_address:
            cluster_spec[self._coordinator_name] = [self._coordinator_address]

        return server_lib.ClusterSpec(cluster_spec)
Example #42
def Cleanse(obj, encoding='utf-8'):
    """Makes Python object appropriate for JSON serialization.

  - Replaces instances of Infinity/-Infinity/NaN with strings.
  - Turns byte strings into unicode strings.
  - Turns sets into sorted lists.
  - Turns tuples into lists.

  Args:
    obj: Python data structure.
    encoding: Charset used to decode byte strings.

  Returns:
    Unicode JSON data structure.
  """
    if isinstance(obj, int):
        return obj
    elif isinstance(obj, float):
        if obj == _INFINITY:
            return 'Infinity'
        elif obj == _NEGATIVE_INFINITY:
            return '-Infinity'
        elif math.isnan(obj):
            return 'NaN'
        else:
            return obj
    elif isinstance(obj, bytes):
        return compat.as_text(obj, encoding)
    elif isinstance(obj, list) or isinstance(obj, tuple):
        return [Cleanse(i, encoding) for i in obj]
    elif isinstance(obj, set):
        return [Cleanse(i, encoding) for i in sorted(obj)]
    elif isinstance(obj, dict):
        return {
            Cleanse(k, encoding): Cleanse(v, encoding)
            for k, v in obj.items()
        }
    else:
        return obj
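
A small illustrative call, assuming `Cleanse` and its module-level constants from the snippet above are in scope:

```python
import json

# bytes keys become unicode, sets become sorted lists, tuples become lists,
# and NaN is replaced by the string 'NaN' so json.dumps succeeds.
cleansed = Cleanse({b'name': {'b', 'a'}, 'loss': float('nan'), 'shape': (3, 4)})
print(json.dumps(cleansed))
# {"name": ["a", "b"], "loss": "NaN", "shape": [3, 4]}
```
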
Example #43
 def testFormatMultiTensor(self):
   with self.test_session():
     tensor_one = array_ops.reshape(math_ops.range(100), [10, 10])
     tensor_two = tensor_one * 10
     format_output = string_ops.string_format("One: {},\nTwo: {}",
                                              (tensor_one, tensor_two))
     out = self.evaluate(format_output)
     expected = ("One: [[0 1 2 ... 7 8 9]\n"
                 " [10 11 12 ... 17 18 19]\n"
                 " [20 21 22 ... 27 28 29]\n"
                 " ...\n"
                 " [70 71 72 ... 77 78 79]\n"
                 " [80 81 82 ... 87 88 89]\n"
                 " [90 91 92 ... 97 98 99]],\n"
                 "Two: [[0 10 20 ... 70 80 90]\n"
                 " [100 110 120 ... 170 180 190]\n"
                 " [200 210 220 ... 270 280 290]\n"
                 " ...\n"
                 " [700 710 720 ... 770 780 790]\n"
                 " [800 810 820 ... 870 880 890]\n"
                 " [900 910 920 ... 970 980 990]]")
     self.assertEqual(compat.as_text(out), expected)
Example #44
  def _TestOneEpoch(self, files, num_records, gap_bytes, encoding=None):
    hop_bytes = 0 if gap_bytes == 0 else self._record_bytes + gap_bytes
    reader = io_ops.FixedLengthRecordReader(
        header_bytes=self._header_bytes,
        record_bytes=self._record_bytes,
        footer_bytes=self._footer_bytes,
        hop_bytes=hop_bytes,
        encoding=encoding,
        name="test_reader")
    queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
    key, value = reader.read(queue)

    self.evaluate(queue.enqueue_many([files]))
    self.evaluate(queue.close())
    for i in range(self._num_files):
      for j in range(num_records):
        k, v = self.evaluate([key, value])
        self.assertAllEqual("%s:%d" % (files[i], j), compat.as_text(k))
        self.assertAllEqual(self._Record(i, j), v)

    with self.assertRaisesOpError("is closed and has insufficient elements "
                                  "\\(requested 1, current size 0\\)"):
      k, v = self.evaluate([key, value])
Example #45
  def testOneEpoch(self):
    files = self._CreateFiles()
    with self.test_session() as sess:
      reader = io_ops.FixedLengthRecordReader(
          header_bytes=self._header_bytes,
          record_bytes=self._record_bytes,
          footer_bytes=self._footer_bytes,
          hop_bytes=0,
          name="test_reader")
      queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
      key, value = reader.read(queue)

      queue.enqueue_many([files]).run()
      queue.close().run()
      for i in range(self._num_files):
        for j in range(self._num_records):
          k, v = sess.run([key, value])
          self.assertAllEqual("%s:%d" % (files[i], j), compat.as_text(k))
          self.assertAllEqual(self._Record(i, j), v)

      with self.assertRaisesOpError("is closed and has insufficient elements "
                                    "\\(requested 1, current size 0\\)"):
        k, v = sess.run([key, value])
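
These fixed-length-record tests have a direct `tf.data` counterpart; a short sketch with `tf.data.FixedLengthRecordDataset` (the file name and byte sizes are illustrative):

```python
import tensorflow as tf

dataset = tf.data.FixedLengthRecordDataset(
    ['/tmp/fixed_records.bin'],  # illustrative file
    record_bytes=4,
    header_bytes=2,
    footer_bytes=2)

for record in dataset:
  print(record.numpy())  # each element is a raw 4-byte string
```
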
Example #46
    def _save_and_write_assets(self, assets_collection_to_add=None):
        """Saves asset to the meta graph and writes asset files to disk.

    Args:
      assets_collection_to_add: The collection where the asset paths are setup.
    """
        asset_source_filepath_list = _maybe_save_assets(
            assets_collection_to_add)

        # Return if there are no assets to write.
        if len(asset_source_filepath_list) == 0:
            tf_logging.info("No assets to write.")
            return

        assets_destination_dir = os.path.join(
            compat.as_bytes(self._export_dir),
            compat.as_bytes(constants.ASSETS_DIRECTORY))

        if not file_io.file_exists(assets_destination_dir):
            file_io.recursive_create_dir(assets_destination_dir)

        # Copy each asset from source path to destination path.
        for asset_source_filepath in asset_source_filepath_list:
            asset_source_filename = os.path.basename(asset_source_filepath)

            asset_destination_filepath = os.path.join(
                compat.as_bytes(assets_destination_dir),
                compat.as_bytes(asset_source_filename))

            # Only copy the asset file to the destination if it does not already
            # exist. This is to ensure that an asset with the same name defined as
            # part of multiple graphs is only copied the first time.
            if not file_io.file_exists(asset_destination_filepath):
                file_io.copy(asset_source_filepath, asset_destination_filepath)

        tf_logging.info("Assets written to: %s",
                        compat.as_text(assets_destination_dir))
Example #47
def load_file_system_library(library_filename):
    """Loads a TensorFlow plugin, containing file system implementation.

  Pass `library_filename` to a platform-specific mechanism for dynamically
  loading a library. The rules for determining the exact location of the
  library are platform-specific and are not documented here.

  Args:
    library_filename: Path to the plugin.
      Relative or absolute filesystem path to a dynamic library file.

  Returns:
    None.

  Raises:
    RuntimeError: when unable to load the library.
  """
    status = py_tf.TF_NewStatus()
    lib_handle = py_tf.TF_LoadLibrary(library_filename, status)
    try:
        error_code = py_tf.TF_GetCode(status)
        if error_code != 0:
            error_msg = compat.as_text(py_tf.TF_Message(status))
            with _FILE_SYSTEM_LIBRARY_MAP_LOCK:
                if (error_code == error_codes_pb2.ALREADY_EXISTS
                        and 'has already been loaded' in error_msg
                        and library_filename in _FILE_SYSTEM_LIBRARY_MAP):
                    return
            # pylint: disable=protected-access
            raise errors._make_specific_exception(None, None, error_msg,
                                                  error_code)
            # pylint: enable=protected-access
    finally:
        py_tf.TF_DeleteStatus(status)

    with _FILE_SYSTEM_LIBRARY_MAP_LOCK:
        _FILE_SYSTEM_LIBRARY_MAP[library_filename] = lib_handle
Example #48
 def deserialize(self, encoded_accumulator):
   """Deserialize an accumulator received from 'serialize()'."""
   value_dict = json.loads(compat.as_text(encoded_accumulator))
   return self._create_accumulator(
       np.array(value_dict[_COUNT_NAME]), np.array(value_dict[_MEAN_NAME]),
       np.array(value_dict[_VARIANCE_NAME]))
Example #49
def initialize_tpu_system(cluster_resolver=None):
  """Initialize the TPU devices.

  Args:
    cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver,
        which provides information about the TPU cluster.
  Returns:
    The tf.tpu.Topology object for the topology of the TPU cluster.

  Raises:
    RuntimeError: If no TPU devices found for eager execution.
  """
  if cluster_resolver is None:
    cluster_resolver = TPUClusterResolver("")
  assert isinstance(cluster_resolver, TPUClusterResolver)

  tpu_name = compat.as_text(cluster_resolver._tpu)  # pylint: disable=protected-access
  if tpu_name in _INITIALIZED_TPU_SYSTEMS:
    logging.warning("TPU system %s has already been initialized. "
                    "Reinitializing the TPU can cause previously created "
                    "variables on TPU to be lost.")

  logging.info("Initializing the TPU system.")

  if context.executing_eagerly():
    # This function is written this way for the following non-intuitive reasons.
    # tpu.initialize_system creates a dummy op whose sole purpose is to trigger
    # DistributedTPURewritePass. This pass actually adds real ops that
    # initialize the TPU system. Thus, we can't simply run tpu.initialize_system
    # eagerly. We need to wrap it in defun and trigger the rewrite passes on it.
    @function.defun
    def _tpu_init_fn():
      return tpu.initialize_system()

    tpu_devices = sorted(
        [x for x in context.list_devices() if "device:TPU:" in x])

    if not tpu_devices:
      raise RuntimeError("Could not find any TPU devices")

    # Replace the remote TPU device with the remote TPU_SYSTEM system device. As
    # in the remote TPU device case, we will try to compile it instead of
    # running through optimization passes and TF Executor, but TPU_SYSTEM should
    # work.
    tpu_system_device = tpu_devices[0].replace("TPU", "TPU_SYSTEM")

    with ops.device(tpu_system_device):
      output = _tpu_init_fn()
    serialized_topology = output.numpy()
  else:
    master = cluster_resolver.master()
    session_config = config_pb2.ConfigProto(allow_soft_placement=True)
    with ops.Graph().as_default():
      with session_lib.Session(config=session_config, target=master) as sess:
        serialized_topology = sess.run(tpu.initialize_system())

  logging.info("Finished initializing TPU system.")
  tpu_topology = topology.Topology(serialized=serialized_topology)
  _INITIALIZED_TPU_SYSTEMS[tpu_name] = tpu_topology

  return tpu_topology
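A hedged usage sketch for the initializer above; the TPU name handed to the
resolver is a placeholder, and mesh_shape is just one of the Topology
properties a caller might inspect:

resolver = TPUClusterResolver(tpu="my-tpu")  # "my-tpu" is illustrative
tpu_topology = initialize_tpu_system(resolver)
print("TPU mesh shape:", tpu_topology.mesh_shape)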
Example #50
  def apply_op(self, op_type_name, name=None, **keywords):
    # pylint: disable=g-doc-args
    """Add a node invoking a registered Op to a graph.

    Example usage:
       # input1 and input2 can be Tensors or anything ops.convert_to_tensor()
       # will convert to a Tensor.
       op_def_library.apply_op("op", input1=input1, input2=input2)
       # Can specify a node name.
       op_def_library.apply_op("op", input1=input1, name="node_name")
       # Must use keyword arguments, with the names specified in the OpDef.
       op_def_library.apply_op("op", input_name=input, attr_name=attr)

    All attrs must either be inferred from an input or specified.
    (If inferred, the attr must not be specified.)  If an attr has a default
    value specified in the Op's OpDef, then you may pass None as the value
    of that attr to get the default.

    Args:
      op_type_name: string. Must match the name field of a registered Op.
      name: string. Optional name of the created op.
      **keywords: input Tensor and attr arguments specified by name,
        and optional parameters to pass when constructing the Operation.

    Returns:
      The Tensor(s) representing the output of the operation, or the Operation
      itself if there are no outputs.

    Raises:
      RuntimeError: On some errors.
      TypeError: On some errors.
      ValueError: On some errors.
    """
    op_info = self._ops.get(op_type_name, None)
    if op_info is None:
      raise RuntimeError("Unrecognized Op name " + op_type_name)
    op_def = op_info.op_def

    # Determine the graph context.
    try:
      # Need to flatten all the arguments into a list.
      # pylint: disable=protected-access
      g = ops._get_graph_from_inputs(_Flatten(keywords.values()))
      # pylint: enable=protected-access
    except AssertionError as e:
      raise RuntimeError(
          "Cannot determine graph for Op '%s' due to: %s"
          % (op_type_name, e))

    # Default name if not specified.
    if name is None:
      name = op_type_name

    # Check for deprecation
    deprecation_version = op_def.deprecation.version
    if deprecation_version:
      producer = g.graph_def_versions.producer
      if producer >= deprecation_version:
        raise NotImplementedError(
            ("Op %s is not available in GraphDef version %d. "
             "It has been removed in version %d. %s.") %
            (op_type_name, producer, deprecation_version,
             op_def.deprecation.explanation))

    # Fill in the list of default types for all "type" attrs.  This
    # will be used to choose a preferred dtype to convert to in the
    # absence of input type information.
    #
    # TODO(b/31302892): Currently the defaults don't work in the right
    # way if you have two inputs, one of whose type resolution depends
    # on the other.  Handling this will require restructuring this code
    # significantly.
    default_type_attr_map = {}
    for attr_def in op_def.attr:
      if attr_def.type != "type":
        continue
      key = attr_def.name
      if attr_def.HasField("default_value"):
        default_type_attr_map[key] = dtypes.as_dtype(
            attr_def.default_value.type)

    # Requires that op_def has passed validation (using the C++
    # ValidateOpDef() from ../framework/op_def_util.h).
    attrs = {}
    inputs = []
    input_types = []
    with g.as_default(), ops.name_scope(name) as scope:

      # Perform input type inference
      inferred_from = {}
      for input_arg in op_def.input_arg:
        input_name = input_arg.name
        if input_name in keywords:
          values = keywords.pop(input_name)
        elif input_name + "_" in keywords:
          # Handle the case where the name is a keyword or built-in
          # for Python so we use the name + _ instead.
          input_name += "_"
          values = keywords.pop(input_name)
        else:
          raise TypeError("No argument for input " + input_name)

        # Goals:
        # * Convert values to Tensors if it contains constants.
        # * Verify that values is a list if that matches the input_arg's
        #   type.
        # * If the input_arg's type is determined by attrs, either set
        #   those attrs and validate those attr values are legal (if
        #   they have not yet been set) or validate the input matches
        #   the type indicated by the attrs (if they have already been
        #   inferred via an earlier input).
        # * If the input_arg has an explicit type, make sure the input
        #   conforms.

        if _IsListParameter(input_arg):
          if not _IsListValue(values):
            raise TypeError(
                "Expected list for '%s' argument to '%s' Op, not %s." %
                (input_name, op_type_name, values))
          # In cases where we expect all elements of the list to have the
          # same dtype, try to cast non-Tensor elements to that type.
          dtype = None
          default_dtype = None
          if input_arg.type != types_pb2.DT_INVALID:
            dtype = input_arg.type
          elif input_arg.number_attr:
            if input_arg.type_attr in attrs:
              dtype = attrs[input_arg.type_attr]
            else:
              for t in values:
                if isinstance(t, ops.Tensor):
                  dtype = t.dtype
                  break

            # dtype still not found, prefer using the default dtype
            # from the attr.
            if dtype is None and input_arg.type_attr in default_type_attr_map:
              default_dtype = default_type_attr_map[input_arg.type_attr]

          try:
            if not input_arg.is_ref and dtype:
              dtype = dtypes.as_dtype(dtype).base_dtype
            values = ops.convert_n_to_tensor(
                values,
                name=input_arg.name,
                dtype=dtype if dtype else None,
                preferred_dtype=default_dtype,
                as_ref=input_arg.is_ref)
            if input_arg.number_attr and len(
                set(v.dtype.base_dtype for v in values)) > 1:
              raise TypeError()  # All types should match.
          except (TypeError, ValueError):
            # What types does the conversion function think values have?
            observed_types = []
            for value in values:
              try:
                converted_value = ops.convert_to_tensor(
                    value, as_ref=input_arg.is_ref)
                observed_types.append(converted_value.dtype.base_dtype.name)
              except (TypeError, ValueError):
                observed_types.append("<NOT CONVERTIBLE TO TENSOR>")
            observed = ", ".join(observed_types)

            prefix = (
                "Tensors in list passed to '%s' of '%s' Op have types [%s]" %
                (input_name, op_type_name, observed))
            if input_arg.number_attr:
              if input_arg.type != types_pb2.DT_INVALID:
                raise TypeError("%s that do not match expected type %s." %
                                (prefix, dtype.name))
              elif input_arg.type_attr in attrs:
                raise TypeError("%s that do not match type %s inferred from "
                                "earlier arguments." %
                                (prefix, dtype.name))
              else:
                raise TypeError("%s that don't all match." % prefix)
            else:
              raise TypeError("%s that are invalid." % prefix)

          types = [x.dtype for x in values]
          inputs.extend(values)
        else:
          # In cases where we have an expected type, try to convert non-Tensor
          # arguments to that type.
          dtype = None
          default_dtype = None
          if input_arg.type != types_pb2.DT_INVALID:
            dtype = input_arg.type
          elif input_arg.type_attr in attrs:
            dtype = attrs[input_arg.type_attr]
          elif input_arg.type_attr in default_type_attr_map:
            # The dtype could not be inferred solely from the inputs,
            # so we prefer the attr's default, so code that adds a new attr
            # with a default is backwards compatible.
            default_dtype = default_type_attr_map[input_arg.type_attr]

          try:
            values = ops.convert_to_tensor(
                values,
                name=input_arg.name,
                dtype=dtype,
                as_ref=input_arg.is_ref,
                preferred_dtype=default_dtype)
          except ValueError:
            # What type does convert_to_tensor think it has?
            observed = ops.convert_to_tensor(values,
                                             as_ref=input_arg.is_ref).dtype.name
            prefix = ("Input '%s' of '%s' Op has type %s that does not match" %
                      (input_name, op_type_name, observed))
            if input_arg.type != types_pb2.DT_INVALID:
              raise TypeError("%s expected type of %s." %
                              (prefix, dtypes.as_dtype(input_arg.type).name))
            else:
              # Update the maps with the default, if needed.
              k = input_arg.type_attr
              if k in default_type_attr_map:
                if k not in attrs:
                  attrs[k] = default_type_attr_map[k]
                  if k not in inferred_from:
                    inferred_from[k] = "Default in OpDef"

              raise TypeError(
                  "%s type %s of argument '%s'." %
                  (prefix, dtypes.as_dtype(attrs[input_arg.type_attr]).name,
                   inferred_from[input_arg.type_attr]))

          types = [values.dtype]
          inputs.append(values)
        base_types = [x.base_dtype for x in types]

        if input_arg.number_attr:
          # <number-attr> * <type> or <number-attr> * <type-attr>
          if input_arg.number_attr in attrs:
            if len(values) != attrs[input_arg.number_attr]:
              raise ValueError(
                  "List argument '%s' to '%s' Op with length %d must match "
                  "length %d of argument '%s'." %
                  (input_name, op_type_name, len(values),
                   attrs[input_arg.number_attr],
                   inferred_from[input_arg.number_attr]))
          else:
            attrs[input_arg.number_attr] = len(values)
            inferred_from[input_arg.number_attr] = input_name
            num_attr = _Attr(op_def, input_arg.number_attr)
            if num_attr.has_minimum and len(values) < num_attr.minimum:
              raise ValueError(
                  "List argument '%s' to '%s' Op with length %d shorter "
                  "than minimum length %d." %
                  (input_name, op_type_name, len(values), num_attr.minimum))
          # All tensors must have the same base type.
          if any(bt != base_types[0] for bt in base_types):
            raise TypeError(
                "All tensors passed to '%s' of '%s' Op "
                "must have the same type." %
                (input_name, op_type_name))
          if input_arg.type != types_pb2.DT_INVALID:
            # <number-attr> * <type> case
            if base_types and base_types[0] != input_arg.type:
              assert False, "Unreachable"
          elif input_arg.type_attr in attrs:
            # <number-attr> * <type-attr> case, where <type-attr> already
            # has an inferred value.
            if base_types and base_types[0] != attrs[input_arg.type_attr]:
              assert False, "Unreachable"
          else:
            # <number-attr> * <type-attr> case, where we are now setting
            # the <type-attr> based on this input
            if not base_types:
              raise TypeError(
                  "Don't know how to infer type variable from empty input "
                  "list passed to input '%s' of '%s' Op." %
                  (input_name, op_type_name))
            attrs[input_arg.type_attr] = base_types[0]
            inferred_from[input_arg.type_attr] = input_name
            type_attr = _Attr(op_def, input_arg.type_attr)
            _SatisfiesTypeConstraint(base_types[0], type_attr)
        elif input_arg.type_attr:
          # <type-attr>
          attr_value = base_types[0]
          if input_arg.type_attr in attrs:
            if attrs[input_arg.type_attr] != attr_value:
              assert False, "Unreachable"
          else:
            for base_type in base_types:
              _SatisfiesTypeConstraint(base_type,
                                       _Attr(op_def, input_arg.type_attr))
            attrs[input_arg.type_attr] = attr_value
            inferred_from[input_arg.type_attr] = input_name
        elif input_arg.type_list_attr:
          # <type-list-attr>
          attr_value = base_types
          if input_arg.type_list_attr in attrs:
            if attrs[input_arg.type_list_attr] != attr_value:
              raise TypeError(
                  "Input '%s' of '%s' Op has type list of %s that does not "
                  "match type list %s of argument '%s'." %
                  (input_name, op_type_name,
                   ", ".join(dtypes.as_dtype(x).name for x in attr_value),
                   ", ".join(dtypes.as_dtype(x).name
                             for x in attrs[input_arg.type_list_attr]),
                   inferred_from[input_arg.type_list_attr]))
          else:
            for base_type in base_types:
              _SatisfiesTypeConstraint(base_type,
                                       _Attr(op_def, input_arg.type_list_attr))
            attrs[input_arg.type_list_attr] = attr_value
            inferred_from[input_arg.type_list_attr] = input_name
        else:
          # single Tensor with specified type
          if base_types[0] != input_arg.type:
            assert False, "Unreachable"

        if input_arg.is_ref:
          if not all(x.is_ref_dtype for x in types):
            raise TypeError(
                "Input '%s' of '%s' Op requires l-value input" %
                (input_name, op_type_name))
          input_types.extend(types)
        else:
          input_types.extend(base_types)

      # Process remaining attrs
      for attr in op_def.attr:
        # Skip attrs that have already had their values inferred
        if attr.name in attrs:
          if attr.name in keywords:
            raise TypeError(
                "Should not specify value for inferred attr '%s'." % attr.name)
          continue
        if attr.name in keywords:
          attrs[attr.name] = keywords.pop(attr.name)
        elif attr.name + "_" in keywords:
          # Attrs whose names match Python keywords have an extra '_'
          # appended, so we must check for that as well.
          attrs[attr.name] = keywords.pop(attr.name + "_")
        else:
          raise TypeError("No argument for attr " + attr.name)

      # Convert attr values to AttrValue protos.
      attr_protos = {}
      for attr_def in op_def.attr:
        key = attr_def.name
        value = attrs[key]
        attr_value = attr_value_pb2.AttrValue()
        if attr_def.HasField("default_value") and value is None:
          attr_value.CopyFrom(attr_def.default_value)
          attr_protos[key] = attr_value
          continue
        if attr_def.type.startswith("list("):
          if not _IsListValue(value):
            raise TypeError("Expected list for attr " + key)
          if attr_def.has_minimum:
            if len(value) < attr_def.minimum:
              raise ValueError("Attr '%s' of '%s' Op passed list of length %d "
                               "less than minimum %d." %
                               (key, op_type_name, len(value),
                                attr_def.minimum))
          attr_value.list.SetInParent()
        if attr_def.type == "string":
          attr_value.s = _MakeStr(value, key)
          if attr_def.HasField("allowed_values"):
            if attr_value.s not in attr_def.allowed_values.list.s:
              raise ValueError(
                  "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"." %
                  (key, op_type_name, compat.as_text(attr_value.s),
                   '", "'.join(map(compat.as_text,
                                   attr_def.allowed_values.list.s))))
        elif attr_def.type == "list(string)":
          attr_value.list.s.extend([_MakeStr(x, key) for x in value])
          if attr_def.HasField("allowed_values"):
            for x in attr_value.list.s:
              if x not in attr_def.allowed_values.list.s:
                raise ValueError(
                    "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"." %
                    (key, op_type_name, compat.as_text(x),
                     '", "'.join(map(compat.as_text,
                                     attr_def.allowed_values.list.s))))
        elif attr_def.type == "int":
          attr_value.i = _MakeInt(value, key)
          if attr_def.has_minimum:
            if attr_value.i < attr_def.minimum:
              raise ValueError(
                  "Attr '%s' of '%s' Op passed %d less than minimum %d." %
                  (key, op_type_name, attr_value.i, attr_def.minimum))
        elif attr_def.type == "list(int)":
          attr_value.list.i.extend([_MakeInt(x, key) for x in value])
        elif attr_def.type == "float":
          attr_value.f = _MakeFloat(value, key)
        elif attr_def.type == "list(float)":
          attr_value.list.f.extend([_MakeFloat(x, key) for x in value])
        elif attr_def.type == "bool":
          attr_value.b = _MakeBool(value, key)
        elif attr_def.type == "list(bool)":
          attr_value.list.b.extend([_MakeBool(x, key) for x in value])
        elif attr_def.type == "type":
          attr_value.type = _MakeType(value, attr_def)
        elif attr_def.type == "list(type)":
          attr_value.list.type.extend(
              [_MakeType(x, attr_def) for x in value])
        elif attr_def.type == "shape":
          attr_value.shape.CopyFrom(_MakeShape(value, key))
        elif attr_def.type == "list(shape)":
          attr_value.list.shape.extend(
              [_MakeShape(x, key) for x in value])
        elif attr_def.type == "tensor":
          attr_value.tensor.CopyFrom(_MakeTensor(value, key))
        elif attr_def.type == "list(tensor)":
          attr_value.list.tensor.extend(
              [_MakeTensor(x, key) for x in value])
        elif attr_def.type == "func":
          if isinstance(value, compat.bytes_or_text_types):
            attr_value.func.name = value
          else:
            value.add_to_graph(ops.get_default_graph())
            attr_value.func.name = value.name
        else:
          raise TypeError("Unrecognized Attr type " + attr_def.type)

        attr_protos[key] = attr_value
      del attrs  # attrs is no longer authoritative, use attr_protos instead

      # Determine output types (possibly using attrs)
      output_types = []
      output_structure = []
      for arg in op_def.output_arg:
        types = []
        if arg.number_attr:
          n = _AttrValue(attr_protos, arg.number_attr).i
          if arg.type_attr:
            types = [_AttrValue(attr_protos, arg.type_attr).type] * n
          else:
            types = [arg.type] * n
          output_structure.append(n)
        elif arg.type_attr:
          t = _AttrValue(attr_protos, arg.type_attr)
          types = [t.type]
          output_structure.append(None)
        elif arg.type_list_attr:
          t = _AttrValue(attr_protos, arg.type_list_attr)
          types = t.list.type
          output_structure.append(len(types))
        else:
          types = [arg.type]
          output_structure.append(None)
        if arg.is_ref:
          types = [dtypes.as_dtype(x).as_ref for x in types]
        output_types.extend(types)

      if keywords:
        raise TypeError("apply_op() got unexpected keyword arguments: " +
                        ", ".join(sorted(keywords.keys())))

      # NOTE(mrry): We add an explicit colocation constraint between
      # the newly created op and any of its reference-typed inputs.
      must_colocate_inputs = [val for arg, val in zip(op_def.input_arg, inputs)
                              if arg.is_ref]
      with _MaybeColocateWith(must_colocate_inputs):
        # Add Op to graph
        op = g.create_op(op_type_name, inputs, output_types, name=scope,
                         input_types=input_types, attrs=attr_protos,
                         op_def=op_def)
        if output_structure:
          outputs = op.outputs
          res = _Restructure(ops.convert_n_to_tensor(outputs), output_structure)
          if isinstance(res, list) and not res and op_def.is_stateful:
            return op
          else:
            return res
        else:
          return op
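A hedged sketch of the attr-inference path above, assuming `op_def_library` is
an OpDefLibrary instance with an "Add"-style OpDef registered whose single
type attr "T" is inferred from the inputs:

x = ops.convert_to_tensor([1.0, 2.0])
out = op_def_library.apply_op("Add", x=x, y=[3.0, 4.0], name="add_example")
# "T" is inferred as float32 from `x`, so the plain Python list passed as `y`
# is converted to that dtype. Passing T=... explicitly as well would raise
# "Should not specify value for inferred attr 'T'".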
Example #51
  def deserialize(self, encoded_accumulator):
    """Deserialize an accumulator received from 'serialize()'."""
    return json.loads(compat.as_text(encoded_accumulator))
Example #52
def initialize_tpu_system(cluster_resolver=None):
    """Initialize the TPU devices.

    Args:
      cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver,
        which provides information about the TPU cluster.

    Returns:
      The tf.tpu.Topology object for the topology of the TPU cluster.

    Raises:
      RuntimeError: If no TPU devices found for eager execution.
    """
    if cluster_resolver is None:
        cluster_resolver = TPUClusterResolver("")
    assert isinstance(cluster_resolver, TPUClusterResolver)

    tpu_name = compat.as_text(cluster_resolver._tpu)  # pylint: disable=protected-access
    if tpu_name in _INITIALIZED_TPU_SYSTEMS:
        logging.warning("TPU system %s has already been initialized. "
                        "Reinitializing the TPU can cause previously created "
                        "variables on TPU to be lost.")

    logging.info("Initializing the TPU system: %s", tpu_name)

    if context.executing_eagerly():
        # This function is written this way for the following non-intuitive
        # reasons: tpu.initialize_system creates a dummy op whose sole purpose
        # is to trigger DistributedTPURewritePass. This pass actually adds
        # real ops that initialize the TPU system. Thus, we can't simply run
        # tpu.initialize_system eagerly; we need to wrap it in defun and
        # trigger the rewrite passes on it.
        job = None
        if tpu_name not in _LOCAL_MASTERS:
            # Explicitly place the tpu.initialize_system in the first worker to
            # avoid the "output node matches multiple devices" error.
            job = "{}/replica:0/task:0".format(cluster_resolver.get_job_name())

        @function.defun
        def _tpu_init_fn():
            return tpu.initialize_system(job=job)

        # The TPU_SYSTEM device must match the device used in tpu.initialize_system
        # exactly, otherwise you can get errors if there are multiple TPU_SYSTEM
        # devices available.
        with ops.device(tpu._tpu_system_device_name(job)):  # pylint: disable=protected-access
            output = _tpu_init_fn()

        # Clear out the eager context caches since the memory is invalid now.
        logging.info("Clearing out eager caches")
        context.context()._clear_caches()  # pylint: disable=protected-access

        serialized_topology = output.numpy()
    else:
        master = cluster_resolver.master()
        cluster_spec = cluster_resolver.cluster_spec()

        session_config = config_pb2.ConfigProto(allow_soft_placement=True)
        if cluster_spec:
            session_config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())

        with ops.Graph().as_default():
            with session_lib.Session(config=session_config,
                                     target=master) as sess:
                serialized_topology = sess.run(tpu.initialize_system())

    logging.info("Finished initializing TPU system.")
    tpu_topology = topology.Topology(serialized=serialized_topology)
    _INITIALIZED_TPU_SYSTEMS[tpu_name] = tpu_topology

    return tpu_topology
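A small sketch of the graph-mode branch above: when the resolver reports a
cluster spec, it is folded into the session config so the initialization op
runs against the right cluster. The job name and address are made up:

cluster_spec = server_lib.ClusterSpec({"worker": ["10.2.3.4:8470"]})
session_config = config_pb2.ConfigProto(allow_soft_placement=True)
session_config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())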
Example #53
    def add_meta_graph_and_variables(self,
                                     sess,
                                     tags,
                                     signature_def_map=None,
                                     assets_collection=None,
                                     legacy_init_op=None,
                                     clear_devices=False,
                                     main_op=None):
        """Adds the current meta graph to the SavedModel and saves variables.

    Creates a Saver to save the variables from the provided session. Exports the
    corresponding meta graph def. This function assumes that the variables to be
    saved have been initialized. For a given `SavedModelBuilder`, this API must
    be called exactly once and for the first meta graph to save. For subsequent
    meta graph defs to be added, the `add_meta_graph()` API must be used.

    Args:
      sess: The TensorFlow session from which to save the meta graph and
        variables.
      tags: The set of tags with which to save the meta graph.
      signature_def_map: The map of signature def map to add to the meta graph
        def.
      assets_collection: Assets collection to be saved with SavedModel.
      legacy_init_op: Legacy support for op or group of ops to execute after the
          restore op upon a load.
      clear_devices: Set to true if the device info on the default graph should
          be cleared.
      main_op: Op or group of ops to execute when the graph is loaded.
    """
        if self._has_saved_variables:
            raise AssertionError(
                "Graph state including variables and assets has "
                "already been saved. Please invoke "
                "`add_meta_graph()` instead.")

        # Validate the signature def map to ensure all included TensorInfos are
        # properly populated.
        self._validate_signature_def_map(signature_def_map)

        # Save asset files and write them to disk, if any.
        self._save_and_write_assets(assets_collection)

        # Create the variables sub-directory, if it does not exist.
        variables_dir = os.path.join(
            compat.as_text(self._export_dir),
            compat.as_text(constants.VARIABLES_DIRECTORY))
        if not file_io.file_exists(variables_dir):
            file_io.recursive_create_dir(variables_dir)

        variables_path = os.path.join(
            compat.as_text(variables_dir),
            compat.as_text(constants.VARIABLES_FILENAME))

        if main_op is None:
            # Add legacy init op to the SavedModel.
            self._maybe_add_legacy_init_op(legacy_init_op)
        else:
            self._add_main_op(main_op)

        # Initialize a saver to generate a sharded output for all saveables in the
        # current scope.
        saver = tf_saver.Saver(
            variables._all_saveable_objects(),  # pylint: disable=protected-access
            sharded=True,
            write_version=saver_pb2.SaverDef.V2,
            allow_empty=True)

        # Save the variables. Also, disable writing the checkpoint state proto. The
        # file is not used during SavedModel loading. In addition, since a
        # SavedModel can be copied or moved, this avoids the checkpoint state to
        # become outdated.
        saver.save(sess,
                   variables_path,
                   write_meta_graph=False,
                   write_state=False)

        # Export the meta graph def.

        # The graph almost certainly previously contained at least one Saver, and
        # possibly several (e.g. one for loading a pretrained embedding, and another
        # for the model weights).  However, a *new* Saver was just created that
        # includes all of the variables.  In the context of the SavedModel, this
        # new Saver is the only one that needs to be retained.  The associated
        # checkpoint that was saved just above contains all of the variable values.
        # Thus, any preexisting Savers are redundant and useless at best, but worse
        # may break downstream graph-processing tools, and can be confusing during
        # debugging.  It is therefore safe and wise to set `clear_extraneous_savers`
        # to `True`, since it removes both the extraneous SaverDefs and their
        # associated Save/Restore Ops from the graph.
        meta_graph_def = saver.export_meta_graph(clear_devices=clear_devices,
                                                 clear_extraneous_savers=True)

        # Tag the meta graph def and add it to the SavedModel.
        self._tag_and_add_meta_graph(meta_graph_def, tags, signature_def_map)

        # Mark this instance of SavedModel as having saved variables, such that
        # subsequent attempts to save variables will fail.
        self._has_saved_variables = True
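A hedged end-to-end sketch of driving the builder method above, assuming
SavedModelBuilder and tag_constants are importable as in standard TF 1.x; the
export path is a placeholder and the model-building step is elided:

builder = SavedModelBuilder("/tmp/my_export")  # export path is illustrative
with ops.Graph().as_default(), session_lib.Session() as sess:
  # ... build the model and run its variable initializers here ...
  builder.add_meta_graph_and_variables(
      sess, tags=[tag_constants.SERVING], clear_devices=True)
builder.save()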
Example #54
    def cluster_spec(self):
        """Returns a ClusterSpec object based on the latest TPU information.

    We retrieve the information from the GCE APIs every time this method is
    called.

    Returns:
      A ClusterSpec containing host information returned from Cloud TPUs.

    Raises:
      RuntimeError: If the provided TPU is not healthy.
    """
        ############################################################################
        # There are 5 potential cases this code must handle:
        #  1. [Normal case.] We should resolve the TPU name to a set of tasks, and
        #      a. Create a ClusterSpec that includes the coordinator job
        #      b. Create a ClusterSpec without the coordinator job.
        #  2. [GKE / No API Access.] We should not resolve the TPU name to a set of
        #     tasks and
        #      a. Create a ClusterSpec with the coordinator
        #      b. Create a ClusterSpec without the coordinator
        #  3. [Other (legacy non-gRPC).] We should return an empty ClusterSpec.
        ############################################################################

        if self._shouldResolve():
            # Case 1.
            full_name = 'projects/%s/locations/%s/nodes/%s' % (
                self._project, self._zone, compat.as_text(self._tpu))
            request = self._service.projects().locations().nodes().get(
                name=full_name)
            response = request.execute()

            if 'state' in response and response['state'] != 'READY':
                raise RuntimeError(
                    'TPU "%s" is not yet ready; state: "%s"' %
                    (compat.as_text(self._tpu), response['state']))

            if 'health' in response and response['health'] != 'HEALTHY':
                raise RuntimeError(
                    'TPU "%s" is unhealthy: "%s"' %
                    (compat.as_text(self._tpu), response['health']))

            if 'networkEndpoints' in response:
                worker_list = [
                    '%s:%s' % (endpoint['ipAddress'], endpoint['port'])
                    for endpoint in response['networkEndpoints']
                ]
            else:
                # Fall back to the deprecated response format
                instance_url = '%s:%s' % (response['ipAddress'],
                                          response['port'])
                worker_list = [instance_url]

            cluster_spec = {self._job_name: worker_list}
        else:
            if not self._tpu.startswith(compat.as_bytes('grpc://')):
                # Case 3.
                return None
            # Case 2.
            cluster_spec = {
                self._job_name: [
                    x[len(compat.as_bytes('grpc://')):]
                    for x in self._tpu.split(
                        compat.as_bytes(_ENDPOINTS_SEPARATOR))
                ]
            }

        if self._coordinator_address:
            # {1, 2}.a
            cluster_spec[self._coordinator_name] = [self._coordinator_address]

        return server_lib.ClusterSpec(cluster_spec)
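Illustrative only: with two resolved network endpoints and a coordinator
configured, the method above returns a ClusterSpec roughly equivalent to the
following (job names and addresses are invented):

spec = server_lib.ClusterSpec({
    "worker": ["10.2.3.4:8470", "10.2.3.5:8470"],
    "coordinator": ["10.2.3.1:2222"],
})
spec.as_dict()  # {'worker': [...], 'coordinator': [...]}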
Example #55
    def cluster_spec(self):
        """Returns a ClusterSpec object based on the latest TPU information.

    We retrieve the information from the GCE APIs every time this method is
    called.

    Returns:
      A ClusterSpec containing host information returned from Cloud TPUs.

    Raises:
      RuntimeError: If the provided TPU is not healthy.
    """
        ############################################################################
        # There are 5 potential cases this code must handle:
        #  1. [Normal case.] We should resolve the TPU name to a set of tasks, and
        #      a. Create a ClusterSpec that includes the coordinator job
        #      b. Create a ClusterSpec without the coordinator job.
        #  2. [GKE / No API Access.] We should not resolve the TPU name to a set of
        #     tasks and
        #      a. Create a ClusterSpec with the coordinator
        #      b. Create a ClusterSpec without the coordinator
        #  3. [Other (legacy non-gRPC).] We should return an empty ClusterSpec.
        ############################################################################

        if self._should_resolve():
            # Case 1.
            response = self._fetch_cloud_tpu_metadata()  # pylint: disable=protected-access

            if 'state' in response and response['state'] != 'READY':
                raise RuntimeError(
                    'TPU "%s" is not yet ready; state: "%s"' %
                    (compat.as_text(self._tpu), response['state']))

            if 'networkEndpoints' in response:
                worker_list = [
                    '%s:%s' % (endpoint['ipAddress'], endpoint['port'])
                    for endpoint in response['networkEndpoints']
                ]
            else:
                # Fall back to the deprecated response format
                instance_url = '%s:%s' % (response['ipAddress'],
                                          response['port'])
                worker_list = [instance_url]

            cluster_spec = {self.task_type: worker_list}
        else:
            if self.rpc_layer is None:
                # Case 3.
                return None
            # Case 2.
            tpus = []
            for tpu in compat.as_text(self._tpu).split(_ENDPOINTS_SEPARATOR):
                # We are working around the fact that the GKE environment
                # variable supplied to us has the protocol string embedded in
                # it, but we want to strip it out for the ClusterSpec.
                if (self.rpc_layer is not None
                        and tpu.startswith(self.rpc_layer + '://')):
                    tpus.append(tpu[len(self.rpc_layer + '://'):])
                else:
                    tpus.append(tpu)
            cluster_spec = {self.task_type: tpus}

        if self._coordinator_address:
            # {1, 2}.a
            cluster_spec[self._coordinator_name] = [self._coordinator_address]

        return server_lib.ClusterSpec(cluster_spec)
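A sketch of the Case 2 branch above (GKE-style endpoints): the rpc-layer
prefix is stripped from each endpoint before the spec is built. The separator
and addresses below are illustrative stand-ins:

rpc_layer = "grpc"
raw_endpoints = "grpc://10.2.3.4:8470,grpc://10.2.3.5:8470"
tpus = []
for tpu in raw_endpoints.split(","):  # "," stands in for _ENDPOINTS_SEPARATOR
  prefix = rpc_layer + "://"
  tpus.append(tpu[len(prefix):] if tpu.startswith(prefix) else tpu)
# tpus == ["10.2.3.4:8470", "10.2.3.5:8470"]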
Example #56
def get_variables_dir(export_dir):
    """Return variables sub-directory in the SavedModel."""
    return os.path.join(compat.as_text(export_dir),
                        compat.as_text(constants.VARIABLES_DIRECTORY))
Example #57
def get_variables_path(export_dir):
    """Return the variables path, used as the prefix for checkpoint files."""
    return os.path.join(compat.as_text(get_variables_dir(export_dir)),
                        compat.as_text(constants.VARIABLES_FILENAME))
Example #58
def get_assets_dir(export_dir):
    """Return path to asset directory in the SavedModel."""
    return os.path.join(compat.as_text(export_dir),
                        compat.as_text(constants.ASSETS_DIRECTORY))
Example #59
def get_debug_dir(export_dir):
    """Returns path to the debug sub-directory in the SavedModel."""
    return os.path.join(compat.as_text(export_dir),
                        compat.as_text(constants.DEBUG_DIRECTORY))
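The directory layout implied by the four helpers above; the export directory
is a placeholder, and the commented results assume the standard SavedModel
constants ("variables", "assets", "debug"):

export_dir = "/tmp/saved_model"  # placeholder
get_variables_dir(export_dir)    # e.g. "/tmp/saved_model/variables"
get_variables_path(export_dir)   # e.g. "/tmp/saved_model/variables/variables"
get_assets_dir(export_dir)       # e.g. "/tmp/saved_model/assets"
get_debug_dir(export_dir)        # e.g. "/tmp/saved_model/debug"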
Example #60
  def testFormatNoTensor(self):
    with self.test_session():
      format_output = string_ops.string_format("No tensor.", ())
      out = self.evaluate(format_output)
      expected = "No tensor."
      self.assertEqual(compat.as_text(out), expected)
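A hypothetical companion test in the same style as the one above, with a
single tensor placeholder; the expected bytes assume the default
summarization:

  def testFormatOneTensorSketch(self):
    with self.test_session():
      tensor = math_ops.range(3)
      format_output = string_ops.string_format("tensor: {}", tensor)
      out = self.evaluate(format_output)
      expected = "tensor: [0 1 2]"
      self.assertEqual(compat.as_text(out), expected)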