Example #1
  def testDeviceBeforeCond(self):
    with ops.Graph().as_default() as g:
      with self.session(graph=g):

        def fn():
          self.assertEqual("", constant_op.constant(3.0).op.device)
          return test_ops.device_placement_op()

        with ops.device("/device:CPU:0"):
          self.assertIn(
              compat.as_bytes("CPU:0"),
              self.evaluate(cond_v2.cond_v2(constant_op.constant(True),
                                            fn, fn)))

        def fn2():
          self.assertEqual("", constant_op.constant(3.0).op.device)
          return test_ops.device_placement_op()

        if test_util.is_gpu_available():
          with ops.device("/device:GPU:0"):
            self.assertIn(
                compat.as_bytes("GPU:0"),
                self.evaluate(cond_v2.cond_v2(constant_op.constant(True),
                                              fn2, fn2)))
        else:
          self.skipTest("Test requires a GPU to check GPU device placement.")
Example #2
  def testGkeEnvironmentForPod(self):
    os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'] = ('grpc://10.120.27.5:8470,'
                                                     'grpc://10.120.27.6:8470,'
                                                     'grpc://10.120.27.7:8470,'
                                                     'grpc://10.120.27.8:8470')

    self.assertIn('KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS', os.environ)
    self.assertTrue(cluster_resolver.TPUClusterResolver._inGke())
    self.assertEqual(
        compat.as_bytes('grpc://10.120.27.5:8470,'
                        'grpc://10.120.27.6:8470,'
                        'grpc://10.120.27.7:8470,'
                        'grpc://10.120.27.8:8470'),
        compat.as_bytes(cluster_resolver.TPUClusterResolver._gkeEndpoints()))

    resolver = cluster_resolver.TPUClusterResolver()
    self.assertEqual(
        compat.as_bytes('grpc://10.120.27.5:8470'),
        compat.as_bytes(resolver.master()))
    actual_cluster_spec = resolver.cluster_spec()
    expected_proto = """
    job {
      name: 'worker'
      tasks { key: 0 value: '10.120.27.5:8470' }
      tasks { key: 1 value: '10.120.27.6:8470' }
      tasks { key: 2 value: '10.120.27.7:8470' }
      tasks { key: 3 value: '10.120.27.8:8470' }
    }
    """
    self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)

    del os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS']
Example #3
  def _save_and_write_assets(self, assets_collection_to_add=None):
    """Saves asset to the meta graph and writes asset files to disk.

    Args:
      assets_collection_to_add: The collection where the asset paths are setup.
    """
    asset_source_filepath_list = _maybe_save_assets(assets_collection_to_add)

    # Return if there are no assets to write.
    if not asset_source_filepath_list:
      tf_logging.info("No assets to write.")
      return

    assets_destination_dir = os.path.join(
        compat.as_bytes(self._export_dir),
        compat.as_bytes(constants.ASSETS_DIRECTORY))

    if not file_io.file_exists(assets_destination_dir):
      file_io.recursive_create_dir(assets_destination_dir)

    # Copy each asset from source path to destination path.
    for asset_source_filepath in asset_source_filepath_list:
      asset_source_filename = os.path.basename(asset_source_filepath)

      asset_destination_filepath = os.path.join(
          compat.as_bytes(assets_destination_dir),
          compat.as_bytes(asset_source_filename))

      # Only copy the asset file to the destination if it does not already
      # exist. This is to ensure that an asset with the same name defined as
      # part of multiple graphs is only copied the first time.
      if not file_io.file_exists(asset_destination_filepath):
        file_io.copy(asset_source_filepath, asset_destination_filepath)

    tf_logging.info("Assets written to: %s", assets_destination_dir)
Example #4
  def _save_and_write_assets(self, assets_collection_to_add=None):
    """Saves asset to the meta graph and writes asset files to disk.

    Args:
      assets_collection_to_add: The collection where the asset paths are setup.
    """
    asset_source_filepath_list = self._save_assets(assets_collection_to_add)

    # Return if there are no assets to write.
    if not asset_source_filepath_list:
      tf_logging.info("No assets to write.")
      return

    assets_destination_dir = os.path.join(
        compat.as_bytes(self._export_dir),
        compat.as_bytes(constants.ASSETS_DIRECTORY))

    if not file_io.file_exists(assets_destination_dir):
      file_io.recursive_create_dir(assets_destination_dir)

    # Copy each asset from source path to destination path.
    for asset_source_filepath in asset_source_filepath_list:
      asset_source_filename = os.path.basename(asset_source_filepath)

      asset_destination_filepath = os.path.join(
          compat.as_bytes(assets_destination_dir),
          compat.as_bytes(asset_source_filename))
      file_io.copy(
          asset_source_filepath, asset_destination_filepath, overwrite=True)

    tf_logging.info("Assets written to: %s", assets_destination_dir)
Example #5
def get_timestamped_dir(dir_base):
  """Builds a path to a new subdirectory within the base directory.

  The subdirectory will be named using the current time.
  This guarantees monotonically increasing directory numbers even across
  multiple runs of the pipeline.
  The timestamp used is the number of seconds since epoch UTC.

  Args:
    dir_base: A string containing a directory to create the subdirectory under.

  Returns:
    The full path of the new subdirectory (which is not actually created yet).

  Raises:
    RuntimeError: if repeated attempts fail to obtain a unique timestamped
      directory name.
  """
  attempts = 0
  while attempts < MAX_DIRECTORY_CREATION_ATTEMPTS:
    timestamp = int(time.time())

    result_dir = os.path.join(
        compat.as_bytes(dir_base), compat.as_bytes(str(timestamp)))
    if not gfile.Exists(result_dir):
      # Collisions are still possible (though extremely unlikely): this
      # directory is not actually created yet, but it will be almost
      # instantly on return from this function.
      return result_dir
    time.sleep(1)
    attempts += 1
    logging.warn('Directory {} already exists; retrying (attempt {}/{})'.format(
        result_dir, attempts, MAX_DIRECTORY_CREATION_ATTEMPTS))
  raise RuntimeError('Failed to obtain a unique export directory name after '
                     '{} attempts.'.format(MAX_DIRECTORY_CREATION_ATTEMPTS))
Example #6
  def reset(target, containers=None, config=None):
    """Resets resource containers on `target`, and close all connected sessions.

    A resource container is distributed across all workers in the
    same cluster as `target`.  When a resource container on `target`
    is reset, resources associated with that container will be cleared.
    In particular, all Variables in the container will become undefined:
    they lose their values and shapes.

    NOTE:
    (i) reset() is currently only implemented for distributed sessions.
    (ii) Any sessions on the master named by `target` will be closed.

    If no resource containers are provided, all containers are reset.

    Args:
      target: The execution engine to connect to.
      containers: A list of resource container name strings, or `None` if all
        the containers are to be reset.
      config: (Optional.) Protocol buffer with configuration options.

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        resetting containers.
    """
    if target is not None:
      target = compat.as_bytes(target)
    if containers is not None:
      containers = [compat.as_bytes(c) for c in containers]
    else:
      containers = []
    tf_session.TF_Reset(target, containers, config)
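If this reset is the Session.reset static method from the TF 1.x client API (the docstring and the call into tf_session.TF_Reset suggest it is), a typical call, with a hypothetical master address and container name, might look like:

import tensorflow as tf  # TF 1.x-style API assumed

# Clear the (hypothetical) "experiment0" resource container on a remote master,
# closing any sessions currently connected to it.
tf.compat.v1.Session.reset("grpc://worker0.example.com:2222",
                           containers=["experiment0"])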
Example #7
  def save(self, as_text=False):
    """Writes a `SavedModel` protocol buffer to disk.

    The function writes the SavedModel protocol buffer to the export directory
    in serialized format.

    Args:
      as_text: Writes the SavedModel protocol buffer in text format to disk.

    Returns:
      The path to which the SavedModel protocol buffer was written.
    """
    if not file_io.file_exists(self._export_dir):
      file_io.recursive_create_dir(self._export_dir)

    if as_text:
      path = os.path.join(
          compat.as_bytes(self._export_dir),
          compat.as_bytes(constants.SAVED_MODEL_FILENAME_PBTXT))
      file_io.write_string_to_file(path, str(self._saved_model))
    else:
      path = os.path.join(
          compat.as_bytes(self._export_dir),
          compat.as_bytes(constants.SAVED_MODEL_FILENAME_PB))
      file_io.write_string_to_file(path, self._saved_model.SerializeToString())
    tf_logging.info("SavedModel written to: %s", path)

    return path
Example #8
def _get_timestamped_export_dir(export_dir_base):
  # When we create a timestamped directory, there is a small chance that the
  # directory already exists because another worker is also writing exports.
  # In this case we just wait one second to get a new timestamp and try again.
  # If this fails several times in a row, then something is seriously wrong.
  max_directory_creation_attempts = 10

  attempts = 0
  while attempts < max_directory_creation_attempts:
    export_timestamp = int(time.time())

    export_dir = os.path.join(
        compat.as_bytes(export_dir_base), compat.as_bytes(
            str(export_timestamp)))
    if not gfile.Exists(export_dir):
      # Collisions are still possible (though extremely unlikely): this
      # directory is not actually created yet, but it will be almost
      # instantly on return from this function.
      return export_dir
    time.sleep(1)
    attempts += 1
    logging.warn(
        "Export directory {} already exists; retrying (attempt {}/{})".format(
            export_dir, attempts, max_directory_creation_attempts))
  raise RuntimeError("Failed to obtain a unique export directory name after "
                     "{} attempts.".format(max_directory_creation_attempts))
Example #9
def _get_asset_tensors(export_dir, meta_graph_def_to_load):
  """Gets the asset tensors, if defined in the meta graph def to load.

  Args:
    export_dir: Directory where the SavedModel is located.
    meta_graph_def_to_load: The meta graph def from the SavedModel to be loaded.

  Returns:
    A dictionary of asset tensors, keyed by the name of the asset tensor. The
    value in the map corresponds to the absolute path of the asset file.
  """
  # Collection-def that may contain the assets key.
  collection_def = meta_graph_def_to_load.collection_def

  asset_tensor_dict = {}
  if constants.ASSETS_KEY in collection_def:
    # Location of the assets for SavedModel.
    assets_directory = os.path.join(
        compat.as_bytes(export_dir),
        compat.as_bytes(constants.ASSETS_DIRECTORY))
    assets_any_proto = collection_def[constants.ASSETS_KEY].any_list.value
    # Process each asset and add it to the asset tensor dictionary.
    for asset_any_proto in assets_any_proto:
      asset_proto = meta_graph_pb2.AssetFileDef()
      asset_any_proto.Unpack(asset_proto)
      asset_tensor_dict[asset_proto.tensor_info.name] = os.path.join(
          compat.as_bytes(assets_directory),
          compat.as_bytes(asset_proto.filename))
  return asset_tensor_dict
Example #10
def _convert_fn(dtec, sorted_feature_names, num_dense, num_sparse_float,
                num_sparse_int, export_dir, unused_eval_result):
  universal_format = custom_export_strategy.convert_to_universal_format(
      dtec, sorted_feature_names, num_dense, num_sparse_float, num_sparse_int)
  with tf.gfile.GFile(os.path.join(
      compat.as_bytes(export_dir), compat.as_bytes("tree_proto")), "w") as f:
    f.write(str(universal_format))
Example #11
  def to_proto(self, export_scope=None):  # pylint: disable=unused-argument
    """Converts a `HParams` object to a `HParamDef` protocol buffer.

    Args:
      export_scope: Optional `string`. Name scope to remove.

    Returns:
      A `HParamDef` protocol buffer.
    """
    hparam_proto = hparam_pb2.HParamDef()
    for name in self._hparam_types:
      # Parse the values.
      param_type, is_list = self._hparam_types.get(name, (None, None))
      kind = HParams._get_kind_name(param_type, is_list)

      if is_list:
        if kind.startswith('bytes'):
          v_list = [compat.as_bytes(v) for v in getattr(self, name)]
        else:
          v_list = [v for v in getattr(self, name)]
        getattr(hparam_proto.hparam[name], kind).value.extend(v_list)
      else:
        v = getattr(self, name)
        if kind.startswith('bytes'):
          v = compat.as_bytes(getattr(self, name))
        setattr(hparam_proto.hparam[name], kind, v)

    return hparam_proto
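Assuming this to_proto belongs to tf.contrib.training.HParams from TF 1.x, a small hypothetical usage that exercises the bytes handling above might look like:

import tensorflow as tf  # TF 1.x with contrib assumed

hp = tf.contrib.training.HParams(learning_rate=0.1, activations=["relu", "tanh"])
hparam_def = hp.to_proto()  # HParamDef proto; str values are stored via compat.as_bytes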
Example #12
def tf_record_iterator(path, options=None):
  """An iterator that read the records from a TFRecords file.

  Args:
    path: The path to the TFRecords file.
    options: (optional) A TFRecordOptions object.

  Yields:
    Strings.

  Raises:
    IOError: If `path` cannot be opened for reading.
  """
  compression_type = TFRecordOptions.get_compression_type_string(options)
  with errors.raise_exception_on_not_ok_status() as status:
    reader = pywrap_tensorflow.PyRecordReader_New(
        compat.as_bytes(path), 0, compat.as_bytes(compression_type), status)

  if reader is None:
    raise IOError("Could not open %s." % path)
  while True:
    try:
      with errors.raise_exception_on_not_ok_status() as status:
        reader.GetNext(status)
    except errors.OutOfRangeError:
      break
    yield reader.record()
  reader.Close()
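Assuming this is tf.python_io.tf_record_iterator from TF 1.x (the signature matches), iterating a file could look like the following, with a made-up path:

import tensorflow as tf  # TF 1.x assumed

for raw_record in tf.python_io.tf_record_iterator("/tmp/data.tfrecord"):
  # Each yielded item is one serialized record (bytes), e.g. a tf.train.Example.
  print(len(raw_record))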
Example #13
  def testShardsRunOnRequestedDevices(self):
    config = config_pb2.ConfigProto(device_count={"CPU": 4})

    @function.Defun()
    def Body():
      # Serialize DT_RESOURCE handles as DT_STRINGs, which encode the device on
      # which the resource was created, so that we can verify that ops were
      # actually run on the requested devices.
      #
      # TODO(akshayka): Provide a cleaner, more idiomatic API for obtaining the
      # name of the device on which a resource lives / for determining the
      # device on which an op ran.
      with ops.device("/cpu:0"):
        s1 = iterator_ops.Iterator.from_structure(
            (dtypes.float32,)).string_handle()
      with ops.device("/cpu:1"):
        s2 = iterator_ops.Iterator.from_structure(
            (dtypes.float32,)).string_handle()
      with ops.device("/cpu:2"):
        s3 = iterator_ops.Iterator.from_structure(
            (dtypes.float32,)).string_handle()
      return s1, s2, s3

    with self.test_session(config=config, use_gpu=True) as sess:
      outputs = sess.run(functional_ops.partitioned_call(args=[], f=Body))
    self.assertIn(compat.as_bytes("CPU:0"), outputs[0])
    self.assertIn(compat.as_bytes("CPU:1"), outputs[1])
    self.assertIn(compat.as_bytes("CPU:2"), outputs[2])
Example #14
  def _shouldResolve(self):
    if (self._tpu == compat.as_bytes('') or
        self._tpu == compat.as_bytes('local') or
        self._tpu.startswith(compat.as_bytes('/bns')) or
        self._tpu.startswith(compat.as_bytes('grpc://'))):
      return False
    return True
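The compat.as_bytes wrapping of the literals above keeps the comparisons type-consistent: self._tpu is bytes, and on Python 3 bytes never compare equal to str, while bytes.startswith rejects a str prefix outright. A small standard-library illustration with a made-up address:

tpu = b"grpc://10.0.0.1:8470"   # e.g. the result of compat.as_bytes(tpu_address)

print(tpu == "grpc://10.0.0.1:8470")   # False: bytes and str are never equal
print(tpu.startswith(b"grpc://"))      # True
# tpu.startswith("grpc://") would raise TypeError, since the prefix must be bytes.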
Example #15
def gfile_copy_callback(files_to_copy, export_dir_path):
  """Callback to copy files using `gfile.Copy` to an export directory.

  This method is used as the default `assets_callback` in `Exporter.init` to
  copy assets from the `assets_collection`. It can also be invoked directly to
  copy additional supplementary files into the export directory (in which case
  it is not a callback).

  Args:
    files_to_copy: A dictionary that maps original file paths to desired
      basename in the export directory.
    export_dir_path: Directory to copy the files to.
  """
  logging.info("Write assest into: %s using gfile_copy.", export_dir_path)
  gfile.MakeDirs(export_dir_path)
  for source_filepath, basename in files_to_copy.items():
    new_path = os.path.join(
        compat.as_bytes(export_dir_path), compat.as_bytes(basename))
    logging.info("Copying asset %s to path %s.", source_filepath, new_path)

    if gfile.Exists(new_path):
      # Guard against being restarted while copying assets, and the file
      # existing and being in an unknown state.
      # TODO(b/28676216): Do some file checks before deleting.
      logging.info("Removing file %s.", new_path)
      gfile.Remove(new_path)
    gfile.Copy(source_filepath, new_path)
Example #16
def _write_object_graph(saveable_view, export_dir, asset_file_def_index):
  """Save a SavedObjectGraph proto for `root`."""
  # SavedObjectGraph is similar to the CheckpointableObjectGraph proto in the
  # checkpoint. It will eventually go into the SavedModel.
  proto = saved_object_graph_pb2.SavedObjectGraph()
  saveable_view.fill_object_graph_proto(proto)

  node_ids = util.ObjectIdentityDictionary()
  for i, obj in enumerate(saveable_view.nodes):
    node_ids[obj] = i
    if resource_variable_ops.is_resource_variable(obj):
      node_ids[obj.handle] = i
    elif isinstance(obj, tracking.TrackableAsset):
      node_ids[obj.asset_path.handle] = i

  for obj, obj_proto in zip(saveable_view.nodes, proto.nodes):
    _write_object_proto(obj, obj_proto, asset_file_def_index, node_ids)

  extra_asset_dir = os.path.join(
      compat.as_bytes(export_dir),
      compat.as_bytes(constants.EXTRA_ASSETS_DIRECTORY))
  file_io.recursive_create_dir(extra_asset_dir)
  object_graph_filename = os.path.join(
      extra_asset_dir, compat.as_bytes("object_graph.pb"))
  file_io.write_string_to_file(object_graph_filename, proto.SerializeToString())
Example #17
  def _save_and_write_assets(self, assets_collection_to_add=None):
    """Saves asset to the meta graph and writes asset files to disk.

    Args:
      assets_collection_to_add: The collection where the asset paths are setup.
    """
    asset_filename_map = _maybe_save_assets(assets_collection_to_add)

    # Return if there are no assets to write.
    if not asset_filename_map:
      tf_logging.info("No assets to write.")
      return

    assets_destination_dir = saved_model_utils.get_or_create_assets_dir(
        self._export_dir)

    # Copy each asset from source path to destination path.
    for asset_basename, asset_source_filepath in asset_filename_map.items():
      asset_destination_filepath = os.path.join(
          compat.as_bytes(assets_destination_dir),
          compat.as_bytes(asset_basename))

      # Only copy the asset file to the destination if it does not already
      # exist. This is to ensure that an asset with the same name defined as
      # part of multiple graphs is only copied the first time.
      if not file_io.file_exists(asset_destination_filepath):
        file_io.copy(asset_source_filepath, asset_destination_filepath)

    tf_logging.info("Assets written to: %s",
                    compat.as_text(assets_destination_dir))
Example #18
  def testDebugString(self):
    save_path = os.path.join(self.get_temp_dir(), "ckpt_for_debug_string")
    with self.test_session() as sess:
      # Builds a graph.
      v0 = tf.Variable([[1, 2, 3], [4, 5, 6]], dtype=tf.float32, name="v0")
      v1 = tf.Variable([[[1], [2]], [[3], [4]], [[5], [6]]], dtype=tf.float32,
                       name="v1")
      save = tf.train.Saver({"v0": v0, "v1": v1})
      tf.initialize_all_variables().run()
      # Saves a checkpoint.
      save.save(sess, save_path)

      # Creates a reader.
      reader = tf.train.NewCheckpointReader(save_path)
      # Verifies that the tensors exist.
      self.assertTrue(reader.has_tensor("v0"))
      self.assertTrue(reader.has_tensor("v1"))
      debug_string = reader.debug_string()
      # Verifies that debug string contains the right strings.
      self.assertTrue(compat.as_bytes("v0 (DT_FLOAT) [2,3]") in debug_string)
      self.assertTrue(compat.as_bytes("v1 (DT_FLOAT) [3,2,1]") in debug_string)
      # Verifies get_variable_to_shape_map() returns the correct information.
      var_map = reader.get_variable_to_shape_map()
      self.assertEquals([2, 3], var_map["v0"])
      self.assertEquals([3, 2, 1], var_map["v1"])
      # Verifies get_tensor() returns the tensor value.
      v0_tensor = reader.get_tensor("v0")
      v1_tensor = reader.get_tensor("v1")
      self.assertAllEqual(v0.eval(), v0_tensor)
      self.assertAllEqual(v1.eval(), v1_tensor)
      # Verifies get_tensor() fails for non-existent tensors.
      with self.assertRaisesRegexp(pywrap_tensorflow.StatusNotOK,
                                   "Not found"):
        reader.get_tensor("v3")
Example #19
  def testBasic(self):
    base_path = tf.test.test_src_dir_path(
        "contrib/session_bundle/example/half_plus_two/00000123")
    tf.reset_default_graph()
    sess, meta_graph_def = session_bundle.LoadSessionBundleFromPath(
        base_path, target="", config=tf.ConfigProto(device_count={"CPU": 2}))

    self.assertTrue(sess)
    asset_path = os.path.join(base_path, constants.ASSETS_DIRECTORY)
    with sess.as_default():
      path1, path2 = sess.run(["filename1:0", "filename2:0"])
      self.assertEqual(
          compat.as_bytes(os.path.join(asset_path, "hello1.txt")), path1)
      self.assertEqual(
          compat.as_bytes(os.path.join(asset_path, "hello2.txt")), path2)

      collection_def = meta_graph_def.collection_def

      signatures_any = collection_def[constants.SIGNATURES_KEY].any_list.value
      self.assertEquals(len(signatures_any), 1)

      signatures = manifest_pb2.Signatures()
      signatures_any[0].Unpack(signatures)
      default_signature = signatures.default_signature
      input_name = default_signature.regression_signature.input.tensor_name
      output_name = default_signature.regression_signature.output.tensor_name
      y = sess.run([output_name], {input_name: np.array([[0], [1], [2], [3]])})
      # The operation is y = 0.5 * x + 2
      self.assertEqual(y[0][0], 2)
      self.assertEqual(y[0][1], 2.5)
      self.assertEqual(y[0][2], 3)
      self.assertEqual(y[0][3], 3.5)
Example #20
 def export_fn(estimator, export_dir, checkpoint_path=None, eval_result=None):
   """A wrapper to export to SavedModel, and convert it to other formats."""
   result_dir = base_strategy.export(estimator, export_dir,
                                     checkpoint_path,
                                     eval_result)
   with ops.Graph().as_default() as graph:
     with tf_session.Session(graph=graph) as sess:
       saved_model_loader.load(
           sess, [tag_constants.SERVING], result_dir)
       # Note: This is GTFlow internal API and might change.
       ensemble_model = graph.get_operation_by_name(
           "ensemble_model/TreeEnsembleSerialize")
       _, dfec_str = sess.run(ensemble_model.outputs)
       dtec = tree_config_pb2.DecisionTreeEnsembleConfig()
       dtec.ParseFromString(dfec_str)
       # Export the result in the same folder as the saved model.
       if convert_fn:
         convert_fn(dtec, sorted_feature_names,
                    len(dense_floats),
                    len(sparse_float_indices),
                    len(sparse_int_indices), result_dir, eval_result)
       feature_importances = _get_feature_importances(
           dtec, sorted_feature_names,
           len(dense_floats),
           len(sparse_float_indices), len(sparse_int_indices))
       sorted_by_importance = sorted(
           feature_importances.items(), key=lambda x: -x[1])
       assets_dir = os.path.join(
           compat.as_bytes(result_dir), compat.as_bytes("assets.extra"))
       gfile.MakeDirs(assets_dir)
       with gfile.GFile(os.path.join(
           compat.as_bytes(assets_dir),
           compat.as_bytes("feature_importances")), "w") as f:
         f.write("\n".join("%s, %f" % (k, v) for k, v in sorted_by_importance))
   return result_dir
Example #21
 def testWalkPostOrder(self):
   dir_path = os.path.join(self._base_dir, "test_dir")
   self._setupWalkDirectories(dir_path)
   # Now test the walk (in_order = False)
   all_dirs = []
   all_subdirs = []
   all_files = []
   for (w_dir, w_subdirs, w_files) in file_io.walk(dir_path, in_order=False):
     all_dirs.append(w_dir)
     all_subdirs.append(w_subdirs)
     all_files.append(w_files)
   self.assertItemsEqual(all_dirs, [
       compat.as_bytes(os.path.join(dir_path, item))
       for item in ["subdir1_1", "subdir1_2/subdir2", "subdir1_2", "subdir1_3"]
   ] + [compat.as_bytes(dir_path)])
   self.assertEqual(compat.as_bytes(dir_path), all_dirs[4])
   self.assertLess(
       all_dirs.index(
           compat.as_bytes(os.path.join(dir_path, "subdir1_2/subdir2"))),
       all_dirs.index(compat.as_bytes(os.path.join(dir_path, "subdir1_2"))))
   self.assertItemsEqual(all_subdirs[0:4], [[], [], [b"subdir2"], []])
   self.assertItemsEqual(all_subdirs[4],
                         [b"subdir1_1", b"subdir1_2", b"subdir1_3"])
   self.assertItemsEqual(all_files, [[b"file2.txt"], [], [], [],
                                     [b"file1.txt"]])
   self.assertLess(
       all_files.index([b"file2.txt"]), all_files.index([b"file1.txt"]))
Example #22
  def testReadFileIgnoreError(self):
    def write_string_to_file(value, filename):
      with open(filename, "w") as f:
        f.write(value)
    filenames = [os.path.join(self.get_temp_dir(), "file_%d.txt" % i)
                 for i in range(5)]
    for filename in filenames:
      write_string_to_file(filename, filename)

    dataset = (dataset_ops.Dataset.from_tensor_slices(filenames)
               .map(io_ops.read_file, num_threads=2, output_buffer_size=2)
               .ignore_errors())
    iterator = dataset.make_initializable_iterator()
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.test_session() as sess:
      # All of the files are present.
      sess.run(init_op)
      for filename in filenames:
        self.assertEqual(compat.as_bytes(filename), sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

      # Delete one of the files.
      os.remove(filenames[0])

      # Attempting to read filenames[0] will fail, but ignore_errors()
      # will catch the error.
      sess.run(init_op)
      for filename in filenames[1:]:
        self.assertEqual(compat.as_bytes(filename), sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)
Example #23
  def testComplexCodeView(self):
    ops.reset_default_graph()
    outfile = os.path.join(test.get_temp_dir(), 'dump')
    opts = (builder(builder.trainable_variables_parameter())
            .with_file_output(outfile)
            .with_accounted_types(['.*'])
            .with_node_names(show_name_regexes=
                             ['.*model_analyzer_testlib.py.*'])
            .account_displayed_op_only(False)
            .select(['params', 'float_ops']).build())

    with profile_context.ProfileContext(test.get_temp_dir(),
                                        trace_steps=[],
                                        dump_steps=[]) as pctx:
      with session.Session() as sess:
        x = lib.BuildFullModel()

        sess.run(variables.global_variables_initializer())
        pctx.trace_next_step()
        _ = sess.run(x)
        tfprof_node = pctx.profiler.profile_python(options=opts)

        # pylint: disable=line-too-long
        with gfile.Open(outfile, 'r') as f:
          lines = f.read().split('\n')
          result = '\n'.join([l[:min(len(l), 80)] for l in lines])
          self.assertEqual(
              compat.as_bytes(
                  'node name | # parameters | # float_ops\n_TFProfRoot (--/2.84k params, --/168.86k flops)\n  model_analyzer_testlib.py:63:BuildFullModel (0/1.80k params, 0/45.37k flops)\n    model_analyzer_testlib.py:40:BuildSmallModel (0/0 params, 0/0 flops)\n    model_analyzer_testlib.py:44:BuildSmallModel (0/4 params, 0/8 flops)\n    model_analyzer_testlib.py:48:BuildSmallModel (0/648 params, 0/1.30k flops)\n    model_analyzer_testlib.py:49:BuildSmallModel (0/0 params, 0/23.33k flops)\n    model_analyzer_testlib.py:53:BuildSmallModel (0/1.15k params, 0/2.30k flops)\n    model_analyzer_testlib.py:54:BuildSmallModel (0/0 params, 0/18.43k flops)\n  model_analyzer_testlib.py:63:BuildFullModel (gradient) (0/0 params, 0/67.39k f\n    model_analyzer_testlib.py:49:BuildSmallModel (gradient) (0/0 params, 0/46.66\n    model_analyzer_testlib.py:54:BuildSmallModel (gradient) (0/0 params, 0/20.74\n  model_analyzer_testlib.py:67:BuildFullModel (0/1.04k params, 0/18.58k flops)\n  model_analyzer_testlib.py:67:BuildFullModel (gradient) (0/0 params, 0/37.00k f\n  model_analyzer_testlib.py:69:BuildFullModel (0/0 params, 0/0 flops)\n  model_analyzer_testlib.py:70:BuildFullModel (0/0 params, 0/258 flops)\n  model_analyzer_testlib.py:70:BuildFullModel (gradient) (0/0 params, 0/129 flop\n  model_analyzer_testlib.py:72:BuildFullModel (0/0 params, 0/141 flops)\n'
              ), compat.as_bytes(lib.CheckAndRemoveDoc(result)))

        self.assertLess(0, tfprof_node.total_exec_micros)
        self.assertEqual(2844, tfprof_node.total_parameters)
        self.assertEqual(168863, tfprof_node.total_float_ops)
        self.assertEqual(8, len(tfprof_node.children))
        self.assertEqual('_TFProfRoot', tfprof_node.name)
        self.assertEqual(
            'model_analyzer_testlib.py:63:BuildFullModel',
            tfprof_node.children[0].name)
        self.assertEqual(
            'model_analyzer_testlib.py:63:BuildFullModel (gradient)',
            tfprof_node.children[1].name)
        self.assertEqual(
            'model_analyzer_testlib.py:67:BuildFullModel',
            tfprof_node.children[2].name)
        self.assertEqual(
            'model_analyzer_testlib.py:67:BuildFullModel (gradient)',
            tfprof_node.children[3].name)
        self.assertEqual(
            'model_analyzer_testlib.py:69:BuildFullModel',
            tfprof_node.children[4].name)
        self.assertEqual(
            'model_analyzer_testlib.py:70:BuildFullModel',
            tfprof_node.children[5].name)
        self.assertEqual(
            'model_analyzer_testlib.py:70:BuildFullModel (gradient)',
            tfprof_node.children[6].name)
        self.assertEqual(
            'model_analyzer_testlib.py:72:BuildFullModel',
            tfprof_node.children[7].name)
Example #24
  def testNoShuffle(self):
    filenames = ['a', 'b', 'c']
    self._touchTempFiles(filenames)

    # Repeat the list twice and ensure that the order is the same each time.
    # NOTE(mrry): This depends on an implementation detail of `list_files()`,
    # which is that the list of files is captured when the iterator is
    # initialized. Otherwise, or if e.g. the iterator were initialized more than
    # once, it's possible that the non-determinism of `tf.matching_files()`
    # would cause this test to fail. However, it serves as a useful confirmation
    # that the `shuffle=False` argument is working as intended.
    # TODO(b/73959787): Provide some ordering guarantees so that this test is
    # more meaningful.
    dataset = dataset_ops.Dataset.list_files(
        path.join(self.tmp_dir, '*'), shuffle=False).repeat(2)
    with self.cached_session() as sess:
      itr = dataset.make_one_shot_iterator()
      next_element = itr.get_next()

      full_filenames = []
      produced_filenames = []
      for filename in filenames * 2:
        full_filenames.append(
            compat.as_bytes(path.join(self.tmp_dir, filename)))
        produced_filenames.append(compat.as_bytes(sess.run(next_element)))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(itr.get_next())
      self.assertItemsEqual(full_filenames, produced_filenames)
      self.assertEqual(produced_filenames[:len(filenames)],
                       produced_filenames[len(filenames):])
Example #25
  def testSaveAsText(self):
    export_dir = os.path.join(
        compat.as_bytes(tf.test.get_temp_dir()), compat.as_bytes("astext"))
    builder = saved_model_builder.SavedModelBuilder(export_dir)

    # Graph with a single variable. SavedModel invoked to:
    # - add with weights.
    with self.test_session(graph=tf.Graph()) as sess:
      v = tf.Variable(42, name="v")
      sess.run(tf.initialize_all_variables())
      self.assertEqual(42, v.eval())
      builder.add_meta_graph_and_variables(sess, ["foo"])

    # Graph with the same single variable. SavedModel invoked to:
    # - simply add the model (weights are not updated).
    with self.test_session(graph=tf.Graph()) as sess:
      v = tf.Variable(43, name="v")
      sess.run(tf.initialize_all_variables())
      self.assertEqual(43, v.eval())
      builder.add_meta_graph(["bar"])

    # Save the SavedModel to disk in text format.
    builder.save(as_text=True)

    # Restore the graph with tag "foo", whose variables were saved.
    with self.test_session(graph=tf.Graph()) as sess:
      loader.load(sess, ["foo"], export_dir)
      self.assertEqual(42, tf.get_collection(tf.GraphKeys.VARIABLES)[0].eval())

    # Restore the graph with tag "bar", whose variables were not saved.
    with self.test_session(graph=tf.Graph()) as sess:
      loader.load(sess, ["bar"], export_dir)
      self.assertEqual(42, tf.get_collection(tf.GraphKeys.VARIABLES)[0].eval())
Example #26
  def _AddOpInternal(self, op):
    # pylint: disable=protected-access
    if op.type in _BLACKLISTED_OPS:
      logging.error("Operation of type %s (%s) is not supported on the TPU. "
                    "Execution will fail if this op is used in the graph. " %
                    (op.type, op.name))

    if op.type in _NOT_IMPLEMENTED_OPS:
      self._unsupported_ops.append(op)

    if any(x.dtype._is_ref_dtype for x in op.inputs):
      raise NotImplementedError(
          "Non-resource Variables are not supported inside TPU computations "
          "(operator name: %s)" % op.name)
    if _TPU_REPLICATE_ATTR in op.node_def.attr:
      raise ValueError("TPU computations cannot be nested")
    op._set_attr(_TPU_REPLICATE_ATTR,
                 attr_value_pb2.AttrValue(s=compat.as_bytes(self._name)))
    if self._outside_compilation_cluster:
      op._set_attr(
          _OUTSIDE_COMPILATION_ATTR,
          attr_value_pb2.AttrValue(
              s=compat.as_bytes(self._outside_compilation_cluster)))
    if self._num_replicas > 1 or not self._outside_compilation_cluster:
      # Prevent feeding or fetching anything that is being compiled,
      # and any replicated outside_compilation Op.
      op.graph.prevent_feeding(op)
      op.graph.prevent_fetching(op)
Example #27
  def testFileMiddles(self):
    filenames = ['a.txt', 'b.py', 'c.pyc']
    self._touchTempFiles(filenames)

    filename_placeholder = array_ops.placeholder(dtypes.string, shape=[])
    dataset = dataset_ops.Dataset.list_files(filename_placeholder)

    with self.cached_session() as sess:
      itr = dataset.make_initializable_iterator()
      next_element = itr.get_next()
      sess.run(
          itr.initializer,
          feed_dict={filename_placeholder: path.join(self.tmp_dir, '*.py*')})

      full_filenames = []
      produced_filenames = []
      for filename in filenames[1:]:
        full_filenames.append(
            compat.as_bytes(path.join(self.tmp_dir, filename)))
        produced_filenames.append(compat.as_bytes(sess.run(next_element)))

      self.assertItemsEqual(full_filenames, produced_filenames)

      with self.assertRaises(errors.OutOfRangeError):
        sess.run(itr.get_next())
Example #28
  def testAssets(self):
    export_dir = self._get_export_dir("test_assets")
    builder = saved_model_builder.SavedModelBuilder(export_dir)

    with self.test_session(graph=ops.Graph()) as sess:
      self._init_and_validate_variable(sess, "v", 42)

      # Build an asset collection.
      ignored_filepath = os.path.join(
          compat.as_bytes(test.get_temp_dir()), compat.as_bytes("ignored.txt"))
      file_io.write_string_to_file(ignored_filepath, "will be ignored")

      asset_collection = self._build_asset_collection("hello42.txt",
                                                      "foo bar baz",
                                                      "asset_file_tensor")

      builder.add_meta_graph_and_variables(
          sess, ["foo"], assets_collection=asset_collection)

    # Save the SavedModel to disk.
    builder.save()

    with self.test_session(graph=ops.Graph()) as sess:
      foo_graph = loader.load(sess, ["foo"], export_dir)
      self._validate_asset_collection(export_dir, foo_graph.collection_def,
                                      "hello42.txt", "foo bar baz",
                                      "asset_file_tensor:0")
      ignored_asset_path = os.path.join(
          compat.as_bytes(export_dir),
          compat.as_bytes(constants.ASSETS_DIRECTORY),
          compat.as_bytes("ignored.txt"))
      self.assertFalse(file_io.file_exists(ignored_asset_path))
Example #29
def get_matching_files_v2(pattern):
  """Returns a list of files that match the given pattern(s).

  Args:
    pattern: string or iterable of strings. The glob pattern(s).

  Returns:
    A list of strings containing filenames that match the given pattern(s).

  Raises:
    errors.OpError: If there are filesystem / directory listing errors.
  """
  with errors.raise_exception_on_not_ok_status() as status:
    if isinstance(pattern, six.string_types):
      return [
          # Convert the filenames to string from bytes.
          compat.as_str_any(matching_filename)
          for matching_filename in pywrap_tensorflow.GetMatchingFiles(
              compat.as_bytes(pattern), status)
      ]
    else:
      return [
          # Convert the filenames to string from bytes.
          compat.as_str_any(matching_filename)
          for single_filename in pattern
          for matching_filename in pywrap_tensorflow.GetMatchingFiles(
              compat.as_bytes(single_filename), status)
      ]
Example #30
  def testBasic(self):
    base_path = tf.test.test_src_dir_path(
        "contrib/session_bundle/example/half_plus_two/00000123")
    tf.reset_default_graph()
    sess, meta_graph_def = session_bundle.load_session_bundle_from_path(
        base_path, target="", config=tf.ConfigProto(device_count={"CPU": 2}))

    self.assertTrue(sess)
    asset_path = os.path.join(base_path, constants.ASSETS_DIRECTORY)
    with sess.as_default():
      path1, path2 = sess.run(["filename1:0", "filename2:0"])
      self.assertEqual(
          compat.as_bytes(os.path.join(asset_path, "hello1.txt")), path1)
      self.assertEqual(
          compat.as_bytes(os.path.join(asset_path, "hello2.txt")), path2)

      collection_def = meta_graph_def.collection_def

      signatures_any = collection_def[constants.SIGNATURES_KEY].any_list.value
      self.assertEquals(len(signatures_any), 1)

      signatures = manifest_pb2.Signatures()
      signatures_any[0].Unpack(signatures)
      self._checkRegressionSignature(signatures, sess)
      self._checkNamedSigantures(signatures, sess)
Example #31
    def _create_definition_if_needed(self):
        """Creates the function definition if it's not created yet."""

        if self._definition is not None:
            return

        # Create the func_def object.
        temp_graph = _FuncGraph()
        with temp_graph.as_default():
            # List of placeholders for the function_def.
            inputs = []
            for (argname, argtype) in self._args:
                argholder = array_ops.placeholder(argtype, name=argname)
                inputs.append(argholder)
            # Call func and gather the output tensors.
            with vs.variable_scope("", custom_getter=temp_graph.getvar):
                outputs = self._func(*inputs)
            # If func only returned one value, make it a tuple.
            if not isinstance(outputs, (list, tuple)):
                outputs = (outputs, )
            if any([_ is None for _ in outputs]):
                raise ValueError("Function can not return None.")
            # Ensures each output is a Tensor.
            outputs = [ops.convert_to_tensor(_) for _ in outputs]
        self._extra_inputs = temp_graph.extra_inputs
        inputs.extend(temp_graph.extra_args)

        # Build the FunctionDef
        self._definition = _graph_to_function_def(temp_graph,
                                                  inputs,
                                                  outputs,
                                                  out_names=self._out_names)

        # Extra kwargs are treated as attrs on the function def.
        kwargs_attr = _parse_kwargs_as_attrs(**self._extra_kwargs)
        for k in kwargs_attr:
            self._definition.attr[k].CopyFrom(kwargs_attr[k])

        # Hash the definition and its dependencies.
        hasher = hashlib.sha1()

        def _hash_func_def():
            """Hash the function definition agnostic to node/map ordering."""
            def update_num(n):
                hasher.update(compat.as_bytes("%x" % n))

            def update_str(s):
                update_num(len(s))
                hasher.update(compat.as_bytes(s))

            def update_strs(slist):
                update_num(len(slist))
                for s in slist:
                    update_str(s)

            for adef in self._definition.signature.input_arg:
                update_str(adef.SerializeToString())

            for adef in self._definition.signature.output_arg:
                update_str(adef.SerializeToString())

            for n in sorted(self._definition.node_def, key=lambda n: n.name):
                update_str(n.name)
                update_str(n.op)
                update_strs(n.input)
                update_num(len(n.attr))
                # NOTE: protobuf map serialization does not guarantee ordering.
                for k in sorted(n.attr):
                    update_str(k)
                    update_str(n.attr[k].SerializeToString())

        _hash_func_def()
        # pylint: disable=protected-access
        self._sub_functions = temp_graph._functions
        for subname in sorted(self._sub_functions.keys()):
            hasher.update(
                compat.as_bytes(self._sub_functions[subname]._hash_str))
        # pylint: enable=protected-access

        # Uses the first 8 bytes sha1 hash digest as the __hash__.
        self._hash_str = hasher.hexdigest()[:8]
        self._hash = int(self._hash_str, 16)

        # Finally, we decide the function name to use.  If not specified,
        # make up something which is almost certainly unique.
        if not self._func_name:
            self._func_name = "_".join(
                [_get_func_name(self._func), self._hash_str])
        self._definition.signature.name = self._func_name
        if self._func.__doc__:
            self._definition.signature.description = self._func.__doc__
Example #32
    def testTags(self):
        export_dir = os.path.join(compat.as_bytes(tf.test.get_temp_dir()),
                                  compat.as_bytes("tags"))
        builder = saved_model_builder.SavedModelBuilder(export_dir)

        # Graph with a single variable. SavedModel invoked to:
        # - add with weights.
        # - a single tag (from predefined constants).
        with self.test_session(graph=tf.Graph()) as sess:
            v = tf.Variable(42, name="v")
            sess.run(tf.initialize_all_variables())
            self.assertEqual(42, v.eval())
            builder.add_meta_graph_and_variables(sess,
                                                 [constants.TAG_TRAINING])

        # Graph that updates the single variable. SavedModel invoked to:
        # - simply add the model (weights are not updated).
        # - a single tag (from predefined constants).
        with self.test_session(graph=tf.Graph()) as sess:
            v = tf.Variable(43, name="v")
            sess.run(tf.initialize_all_variables())
            self.assertEqual(43, v.eval())
            builder.add_meta_graph([constants.TAG_SERVING])

        # Graph that updates the single variable. SavedModel is invoked:
        # - to add the model (weights are not updated).
        # - multiple custom tags.
        with self.test_session(graph=tf.Graph()) as sess:
            v = tf.Variable(44, name="v")
            sess.run(tf.initialize_all_variables())
            self.assertEqual(44, v.eval())
            builder.add_meta_graph(["foo", "bar"])

        # Save the SavedModel to disk.
        builder.save()

        # Restore the graph with a single predefined tag whose variables were saved.
        with self.test_session(graph=tf.Graph()) as sess:
            loader.load(sess, [constants.TAG_TRAINING], export_dir)
            self.assertEqual(
                42,
                tf.get_collection(tf.GraphKeys.VARIABLES)[0].eval())

        # Restore the graph with a single predefined tag whose variables were not
        # saved.
        with self.test_session(graph=tf.Graph()) as sess:
            loader.load(sess, [constants.TAG_SERVING], export_dir)
            self.assertEqual(
                42,
                tf.get_collection(tf.GraphKeys.VARIABLES)[0].eval())

        # Restore the graph with multiple tags. Provide duplicate tags to test set
        # semantics.
        with self.test_session(graph=tf.Graph()) as sess:
            loader.load(sess, ["foo", "bar", "foo"], export_dir)
            self.assertEqual(
                42,
                tf.get_collection(tf.GraphKeys.VARIABLES)[0].eval())

        # Try restoring a graph with a non-existent tag. This should yield a runtime
        # error.
        with self.test_session(graph=tf.Graph()) as sess:
            self.assertRaises(RuntimeError, loader.load, sess, ["INVALID"],
                              export_dir)

        # Try restoring a graph where a subset of the tags match. Since tag matching
        # for meta graph defs follows "all" semantics, this should yield a runtime
        # error.
        with self.test_session(graph=tf.Graph()) as sess:
            self.assertRaises(RuntimeError, loader.load, sess, ["foo", "baz"],
                              export_dir)
Example #33
    def testAssets(self):
        export_dir = os.path.join(compat.as_bytes(tf.test.get_temp_dir()),
                                  compat.as_bytes("with-assets"))
        builder = saved_model_builder.SavedModelBuilder(export_dir)

        with self.test_session(graph=tf.Graph()) as sess:
            v = tf.Variable(42, name="v")
            sess.run(tf.initialize_all_variables())
            self.assertEqual(42, v.eval())

            # Build an asset collection.
            asset_filepath = os.path.join(
                compat.as_bytes(tf.test.get_temp_dir()),
                compat.as_bytes("hello42.txt"))
            file_io.write_string_to_file(asset_filepath, "foo bar baz")
            asset_file_tensor = tf.constant(asset_filepath,
                                            name="asset_file_tensor")
            tf.add_to_collection(tf.GraphKeys.ASSET_FILEPATHS,
                                 asset_file_tensor)

            ignored_filepath = os.path.join(
                compat.as_bytes(tf.test.get_temp_dir()),
                compat.as_bytes("ignored.txt"))
            file_io.write_string_to_file(ignored_filepath, "will be ignored")

            asset_collection = tf.get_collection(tf.GraphKeys.ASSET_FILEPATHS)

            builder.add_meta_graph_and_variables(
                sess, ["foo"], assets_collection=asset_collection)

        # Save the SavedModel to disk.
        builder.save()

        with self.test_session(graph=tf.Graph()) as sess:
            foo_graph = loader.load(sess, ["foo"], export_dir)

            # Validate the assets.
            collection_def = foo_graph.collection_def
            assets_any = collection_def[constants.ASSETS_KEY].any_list.value
            self.assertEqual(len(assets_any), 1)
            asset = manifest_pb2.AssetFile()
            assets_any[0].Unpack(asset)
            assets_path = os.path.join(
                compat.as_bytes(export_dir),
                compat.as_bytes(constants.ASSETS_DIRECTORY),
                compat.as_bytes("hello42.txt"))
            asset_contents = file_io.read_file_to_string(assets_path)
            self.assertEqual("foo bar baz", compat.as_text(asset_contents))
            self.assertEqual("hello42.txt", asset.filename)
            self.assertEqual("asset_file_tensor:0",
                             asset.tensor_binding.tensor_name)
            ignored_asset_path = os.path.join(
                compat.as_bytes(export_dir),
                compat.as_bytes(constants.ASSETS_DIRECTORY),
                compat.as_bytes("ignored.txt"))
            self.assertFalse(file_io.file_exists(ignored_asset_path))
Example #34
    def testSignatureDefs(self):
        export_dir = os.path.join(compat.as_bytes(tf.test.get_temp_dir()),
                                  compat.as_bytes("signature_defs"))
        builder = saved_model_builder.SavedModelBuilder(export_dir)

        # Graph with a single variable and a single entry in the signature def map.
        # SavedModel is invoked to add with weights.
        with self.test_session(graph=tf.Graph()) as sess:
            v = tf.Variable(42, name="v")
            sess.run(tf.initialize_all_variables())
            self.assertEqual(42, v.eval())
            # Build and populate an empty SignatureDef for testing.
            foo_signature = utils.build_signature_def(dict(), dict(), "foo")
            builder.add_meta_graph_and_variables(
                sess, ["foo"], signature_def_map={"foo_key": foo_signature})

        # Graph with the same single variable and multiple entries in the signature
        # def map. No weights are saved by SavedModel.
        with self.test_session(graph=tf.Graph()) as sess:
            v = tf.Variable(43, name="v")
            sess.run(tf.initialize_all_variables())
            self.assertEqual(43, v.eval())

            # Build and populate a different SignatureDef for testing.
            bar_signature = utils.build_signature_def(dict(), dict(), "bar")
            # Also, build a different SignatureDef corresponding to "foo_key" defined
            # in the previous graph.
            foo_new_signature = utils.build_signature_def(
                dict(), dict(), "foo_new")
            builder.add_meta_graph(["bar"],
                                   signature_def_map={
                                       "bar_key": bar_signature,
                                       "foo_key": foo_new_signature
                                   })

        # Save the SavedModel to disk.
        builder.save()

        # Restore the graph with tag "foo". The single entry in the SignatureDef map
        # corresponding to "foo_key" should exist.
        with self.test_session(graph=tf.Graph()) as sess:
            foo_graph = loader.load(sess, ["foo"], export_dir)
            self.assertEqual(
                42,
                tf.get_collection(tf.GraphKeys.VARIABLES)[0].eval())

            foo_signature = foo_graph.signature_def
            self.assertEqual(len(foo_signature), 1)
            self.assertEqual("foo", foo_signature["foo_key"].method_name)

        # Restore the graph with tag "bar". The SignatureDef map should have two
        # entries. One corresponding to "bar_key" and another corresponding to the
        # new value of "foo_key".
        with self.test_session(graph=tf.Graph()) as sess:
            bar_graph = loader.load(sess, ["bar"], export_dir)
            self.assertEqual(
                42,
                tf.get_collection(tf.GraphKeys.VARIABLES)[0].eval())

            bar_signature = bar_graph.signature_def
            self.assertEqual(len(bar_signature), 2)
            self.assertEqual("bar", bar_signature["bar_key"].method_name)
            self.assertEqual("foo_new", bar_signature["foo_key"].method_name)
Example #35
    def _init_from_args(self,
                        initial_value=None,
                        trainable=True,
                        collections=None,
                        validate_shape=True,
                        caching_device=None,
                        name=None,
                        dtype=None,
                        constraint=None):
        """Creates a variable.

    Args:
      initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
        which is the initial value for the Variable. The initial value must have
        a shape specified unless `validate_shape` is set to False. Can also be a
        callable with no argument that returns the initial value when called.
        (Note that initializer functions from init_ops.py must first be bound
         to a shape before being used here.)
      trainable: If `True`, the default, also adds the variable to the graph
        collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
        the default list of variables to use by the `Optimizer` classes.
      collections: List of graph collections keys. The new variable is added to
        these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
      validate_shape: Ignored. Provided for compatibility with tf.Variable.
      caching_device: Optional device string or function describing where the
        Variable should be cached for reading.  Defaults to the Variable's
        device.  If not `None`, caches on another device.  Typical use is to
        cache on the device where the Ops using the Variable reside, to
        deduplicate copying through `Switch` and other conditional statements.
      name: Optional name for the variable. Defaults to `'Variable'` and gets
        uniquified automatically.
      dtype: If set, initial_value will be converted to the given type.
        If None, either the datatype will be kept (if initial_value is
        a Tensor) or float32 will be used (if it is a Python object convertible
        to a Tensor).
      constraint: An optional projection function to be applied to the variable
        after being updated by an `Optimizer` (e.g. used to implement norm
        constraints or value constraints for layer weights). The function must
        take as input the unprojected Tensor representing the value of the
        variable and return the Tensor for the projected value
        (which must have the same shape). Constraints are not safe to
        use when doing asynchronous distributed training.

    Raises:
      ValueError: If the initial value is not specified, or does not have a
        shape and `validate_shape` is `True`.

    @compatibility(eager)
    When Eager Execution is enabled, variables are never added to collections.
    It is not implicitly added to the `GLOBAL_VARIABLES` or
    `TRAINABLE_VARIABLES` collections, and the `collections` argument is
    ignored.
    @end_compatibility
    """
        if initial_value is None:
            raise ValueError("initial_value must be specified.")
        init_from_fn = callable(initial_value)

        if collections is None:
            collections = [ops.GraphKeys.GLOBAL_VARIABLES]
        if not isinstance(collections, (list, tuple, set)):
            raise ValueError(
                "collections argument to Variable constructor must be a list, tuple, "
                "or set. Got %s of type %s" % (collections, type(collections)))
        if constraint is not None and not callable(constraint):
            raise ValueError("The `constraint` argument must be a callable.")

        self._trainable = trainable
        if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
            collections = list(collections) + [
                ops.GraphKeys.TRAINABLE_VARIABLES
            ]
        self._save_slice_info = None
        self._in_graph_mode = context.in_graph_mode()
        # Save the graph's container prefix for error checking. Reading the value of
        # the ResourceVariable from another Graph in Eager mode is an error.
        self._container_prefix = ops.get_default_graph()._container_prefix  # pylint: disable=protected-access
        if not self._in_graph_mode and not name:
            # TODO(ashankar,josh11b): make this unnecessary using the same
            # logic as in layer
            raise ValueError(
                "Variables need to have explicit names when eager "
                "execution is enabled")

        with ops.control_dependencies(None):
            with ops.name_scope(
                    name, "Variable",
                [] if init_from_fn else [initial_value]) as name:
                # pylint: disable=protected-access
                handle_name = ops._name_from_scope_name(name)
                if init_from_fn:
                    # Use attr_scope and device(None) to simulate the behavior of
                    # colocate_with when the variable we want to colocate with doesn't
                    # yet exist.
                    if self._in_graph_mode:
                        attr = attr_value_pb2.AttrValue(
                            list=attr_value_pb2.AttrValue.ListValue(
                                s=[compat.as_bytes("loc:@%s" % handle_name)]))
                        with ops.get_default_graph()._attr_scope(
                            {"_class": attr}):
                            with ops.name_scope("Initializer"), ops.device(
                                    None):
                                initial_value = ops.convert_to_tensor(
                                    initial_value(),
                                    name="initial_value",
                                    dtype=dtype)
                            self._handle = _eager_safe_variable_handle(
                                shape=initial_value.get_shape(),
                                dtype=initial_value.dtype.base_dtype,
                                shared_name=handle_name,
                                name=name,
                                graph_mode=self._in_graph_mode)
                            self._handle_device = (
                                self._handle.device if self._in_graph_mode else
                                context.get_default_context().device_name)
                            self._shape = initial_value.get_shape()
                    else:
                        initial_value = initial_value()
                        with ops.name_scope("Initializer"):
                            initial_value = ops.convert_to_tensor(
                                initial_value,
                                name="initial_value",
                                dtype=dtype)
                        self._handle = _eager_safe_variable_handle(
                            shape=initial_value.get_shape(),
                            dtype=initial_value.dtype.base_dtype,
                            shared_name=handle_name,
                            name=name,
                            graph_mode=False)
                        self._handle_device = (
                            self._handle.device if self._in_graph_mode else
                            context.get_default_context().device_name)
                        self._shape = initial_value.get_shape()
                # pylint: enable=protected-access

                # Or get the initial value from a Tensor or Python object.
                else:
                    with ops.name_scope("Initializer"):
                        initial_value = ops.convert_to_tensor(
                            initial_value, name="initial_value", dtype=dtype)
                    # pylint: disable=protected-access
                    if (self._in_graph_mode and initial_value is not None
                            and initial_value.op._get_control_flow_context()
                            is not None):
                        raise ValueError(
                            "Initializer for variable %s is from inside a control-flow "
                            "construct, such as a loop or conditional. When creating a "
                            "variable inside a loop or conditional, use a lambda as the "
                            "initializer." % name)
                    # pylint: enable=protected-access
                    self._handle = _eager_safe_variable_handle(
                        shape=initial_value.get_shape(),
                        dtype=initial_value.dtype.base_dtype,
                        shared_name=handle_name,
                        name=name,
                        graph_mode=self._in_graph_mode)
                    self._handle_device = (
                        self._handle.device if self._in_graph_mode else
                        context.get_default_context().device_name)
                    self._shape = initial_value.get_shape()

                self._initial_value = initial_value if self._in_graph_mode else None
                self._handle_name = handle_name + ":0"
                self._dtype = initial_value.dtype.base_dtype
                self._constraint = constraint

                if self._in_graph_mode:
                    with ops.name_scope("IsInitialized"):
                        self._is_initialized_op = (
                            gen_resource_variable_ops.var_is_initialized_op(
                                self._handle))
                    if initial_value is not None:
                        with ops.name_scope("Assign") as n, ops.colocate_with(
                                self._handle):
                            self._initializer_op = (
                                gen_resource_variable_ops.assign_variable_op(
                                    self._handle,
                                    self._build_initializer_expr(
                                        initial_value),
                                    name=n))
                    with ops.name_scope("Read"), ops.colocate_with(
                            self._handle):
                        # Manually assign reads to the handle's device to avoid log
                        # messages.
                        with ops.device(self._handle_device):
                            value = self._read_variable_op()
                        self._graph_element = value
                        if caching_device is not None:
                            # Variables may be created in a tf.device() or ops.colocate_with()
                            # context. At the same time, users would expect caching device to
                            # be independent of this context, and/or would not expect the
                            # current device context to be merged with the caching device
                            # spec.  Therefore we reset the colocation stack before creating
                            # the cached value. Note that resetting the colocation stack will
                            # also reset the device stack.
                            with ops.colocate_with(None, ignore_existing=True):
                                with ops.device(caching_device):
                                    self._cached_value = array_ops.identity(
                                        value)
                        else:
                            self._cached_value = None
                else:
                    gen_resource_variable_ops.assign_variable_op(
                        self._handle, initial_value)
                    self._is_initialized_op = None
                    self._initializer_op = None
                    self._graph_element = None
                    if caching_device:
                        with ops.device(caching_device):
                            self._cached_value = self._read_variable_op()
                    else:
                        self._cached_value = None
                if context.in_graph_mode():
                    ops.add_to_collections(collections, self)
                elif ops.GraphKeys.GLOBAL_STEP in collections:
                    ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, self)
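The constructor above documents the constraint and caching_device arguments, but the snippet never shows them in use. Below is a minimal sketch, assuming a TF 1.x-style graph and tf.compat.v1.get_variable with use_resource=True (which creates a resource variable like the class shown here); the variable name and shape are illustrative.

import tensorflow as tf

def unit_norm_constraint(t):
    # Projection applied after each `Optimizer` update; it must return a
    # tensor with the same shape as its input.
    return tf.clip_by_norm(t, 1.0)

with tf.Graph().as_default():
    v = tf.compat.v1.get_variable(
        "weights",
        shape=[3, 4],
        initializer=tf.compat.v1.zeros_initializer(),  # callable initializer, bound to a shape by get_variable
        constraint=unit_norm_constraint,
        use_resource=True)  # request a resource variable rather than a ref variable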
示例#36
0
 def update_str(s):
     update_num(len(s))
     hasher.update(compat.as_bytes(s))
示例#37
0
 def update_num(n):
     hasher.update(compat.as_bytes("%x" % n))
示例#38
0
 def testString(self):
     self._testCpu(
         np.array([compat.as_bytes(str(x))
                   for x in np.arange(-15, 15)]).reshape([2, 3, 5]))
     self._testCpu(np.empty((2, 0, 5)).astype(np.str_))
示例#39
0
 def testNoCallComputeMetadata(self):
     tpu_cluster_resolver = TPUClusterResolver(tpu='/bns/foo/bar')
     self.assertEqual(compat.as_bytes('/bns/foo/bar'),
                      tpu_cluster_resolver.master())
     self.assertEqual(None, tpu_cluster_resolver.cluster_spec())
示例#40
0
 def eagerly_executed_grad(*dy):
     tape, eager_inputs, eager_outputs = tape_cache.pop(
         compat.as_bytes(token))
     return tape.gradient(eager_outputs, eager_inputs, output_gradients=dy)
示例#41
0
def create_image_lists(image_dir, testing_percentage, validation_percentage):
    """Builds a list of training images from the file system.
    Analyzes the sub folders in the image directory, splits them into stable
    training, testing, and validation sets, and returns a data structure
    describing the lists of images for each label and their paths.
    Args:
      image_dir: String path to a folder containing subfolders of images.
      testing_percentage: Integer percentage of the images to reserve for tests.
      validation_percentage: Integer percentage of images reserved for validation.
    Returns:
      A dictionary containing an entry for each label subfolder, with images split
      into training, testing, and validation sets within each label.
    """
    if not gfile.Exists(image_dir):
        print("Image directory '" + image_dir + "' not found.")
        return None
    result = {}
    sub_dirs = [x[0] for x in gfile.Walk(image_dir)]
    # The root directory comes first, so skip it.
    is_root_dir = True
    for sub_dir in sub_dirs:
        if is_root_dir:
            is_root_dir = False
            continue
        extensions = ['jpg', 'jpeg', 'JPG', 'JPEG']
        file_list = []
        dir_name = os.path.basename(sub_dir)
        if dir_name == image_dir:
            continue
        print("Looking for images in '" + dir_name + "'")
        for extension in extensions:
            file_glob = os.path.join(image_dir, dir_name, '*.' + extension)
            file_list.extend(gfile.Glob(file_glob))
        if not file_list:
            print('No files found')
            continue
        if len(file_list) < 20:
            print('WARNING: Folder has less than 20 images, which may cause issues.')
        elif len(file_list) > MAX_NUM_IMAGES_PER_CLASS:
            print('WARNING: Folder {} has more than {} images. Some images will '
                  'never be selected.'.format(dir_name, MAX_NUM_IMAGES_PER_CLASS))
        label_name = re.sub(r'[^a-z0-9]+', ' ', dir_name.lower())
        training_images = []
        testing_images = []
        validation_images = []
        for file_name in file_list:
            base_name = os.path.basename(file_name)
            # We want to ignore anything after '_nohash_' in the file name when
            # deciding which set to put an image in, so that the data set creator
            # has a way of grouping photos that are close variations of each other.
            # For example, this is used in the plant disease data set to group
            # multiple pictures of the same leaf.
            hash_name = re.sub(r'_nohash_.*$', '', file_name)
            # This looks a bit magical, but we need to decide whether this file should
            # go into the training, testing, or validation sets, and we want to keep
            # existing files in the same set even if more files are subsequently
            # added.
            # To do that, we need a stable way of deciding based on just the file name
            # itself, so we do a hash of that and then use that to generate a
            # probability value that we use to assign it.
            hash_name_hashed = hashlib.sha1(compat.as_bytes(hash_name)).hexdigest()
            percentage_hash = ((int(hash_name_hashed, 16) %
                                (MAX_NUM_IMAGES_PER_CLASS + 1)) *
                               (100.0 / MAX_NUM_IMAGES_PER_CLASS))
            if percentage_hash < validation_percentage:
                validation_images.append(base_name)
            elif percentage_hash < (testing_percentage + validation_percentage):
                testing_images.append(base_name)
            else:
                training_images.append(base_name)
        result[label_name] = {
            'dir': dir_name,
            'training': training_images,
            'testing': testing_images,
            'validation': validation_images,
        }
    return result
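The comments in the loop above describe a stable, name-based split. As a standalone sketch (assuming MAX_NUM_IMAGES_PER_CLASS is the module-level constant used above, 2 ** 27 - 1 in the reference retraining script), the bucketing decision can be isolated like this:

import hashlib
import re

MAX_NUM_IMAGES_PER_CLASS = 2 ** 27 - 1  # assumed to match the module constant

def which_set(file_name, testing_percentage, validation_percentage):
    """Returns 'training', 'testing', or 'validation' for a file, stably."""
    hash_name = re.sub(r'_nohash_.*$', '', file_name)
    hash_name_hashed = hashlib.sha1(hash_name.encode('utf-8')).hexdigest()
    percentage_hash = ((int(hash_name_hashed, 16) %
                        (MAX_NUM_IMAGES_PER_CLASS + 1)) *
                       (100.0 / MAX_NUM_IMAGES_PER_CLASS))
    if percentage_hash < validation_percentage:
        return 'validation'
    if percentage_hash < testing_percentage + validation_percentage:
        return 'testing'
    return 'training'

# The assignment depends only on the file name, so re-running on a grown data
# set keeps previously seen files in their original split. Files that differ
# only after '_nohash_' always land in the same split:
assert which_set('leaf_nohash_1.jpg', 10, 10) == which_set('leaf_nohash_2.jpg', 10, 10)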
示例#42
0
def _MakeStr(v, arg_name):
    if not isinstance(v, compat.bytes_or_text_types):
        raise TypeError("Expected string for argument '%s' not %s." %
                        (arg_name, repr(v)))
    return compat.as_bytes(v)  # Convert unicode strings to bytes.
示例#43
0
    def __init__(self,
                 tpu=None,
                 zone=None,
                 project=None,
                 job_name='worker',
                 coordinator_name=None,
                 coordinator_address=None,
                 credentials='default',
                 service=None,
                 discovery_url=None):
        """Creates a new TPUClusterResolver object.

    The ClusterResolver will then use the parameters to query the Cloud TPU APIs
    for the IP addresses and ports of each Cloud TPU listed.

    Args:
      tpu: Either a string, or a list of strings corresponding to the TPUs to
        use. If the single string is the empty string, the string 'local', or a
        string that begins with 'grpc://' or '/bns', then it is assumed to not
        correspond with a Cloud TPU and will instead be passed as the session
        master and no ClusterSpec propagation will be done.
      zone: Zone where the TPUs are located. If omitted or empty, we will assume
        that the zone of the TPU is the same as the zone of the GCE VM, which we
        will try to discover from the GCE metadata service.
      project: Name of the GCP project containing Cloud TPUs. If omitted or
        empty, we will try to discover the project name of the GCE VM from the
        GCE metadata service.
      job_name: Name of the TensorFlow job the TPUs belong to.
      coordinator_name: The name to use for the coordinator. Set to None if the
        coordinator should not be included in the computed ClusterSpec.
      coordinator_address: The address of the coordinator (typically an ip:port
        pair). If set to None, a TF server will be started. If coordinator_name
        is None, a TF server will not be started even if coordinator_address is
        None.
      credentials: GCE Credentials. If None, then we use default credentials
        from the oauth2client library.
      service: The GCE API object returned by the googleapiclient.discovery
        function. If you specify a custom service object, then the credentials
        parameter will be ignored.
      discovery_url: A URL template that points to the location of
        the discovery service. It should have two parameters {api} and
        {apiVersion} that when filled in produce an absolute URL to the
        discovery document for that service. The environment variable
        'TPU_API_DISCOVERY_URL' will override this.

    Raises:
      ImportError: If the googleapiclient is not installed.
      ValueError: If no TPUs are specified.
    """
        if isinstance(tpu, list):
            if not tpu:
                raise ValueError('At least one TPU must be specified.')
            if len(tpu) != 1:
                raise NotImplementedError(
                    'Using multiple TPUs in a single session is not yet implemented'
                )
            tpu = tpu[0]

        in_gke = self._inGke()
        # When using GKE with Cloud TPUs, the env variable will be set.
        if tpu is None:
            if in_gke:
                tpu = self._gkeMaster()
            else:
                tpu = self._envVarFallback()

        self._tpu = compat.as_bytes(tpu)  # self._tpu is always bytes
        self._job_name = job_name
        self._credentials = credentials

        should_resolve = self._shouldResolve()

        if not project and should_resolve:
            project = compat.as_str(
                self._requestComputeMetadata('project/project-id'))

        if not zone and should_resolve:
            zone_path = compat.as_str(
                self._requestComputeMetadata('instance/zone'))
            zone = zone_path.split('/')[-1]

        self._project = project
        self._zone = zone

        if credentials == 'default' and should_resolve:
            if _GOOGLE_API_CLIENT_INSTALLED:
                self._credentials = GoogleCredentials.get_application_default()

        if service is None and should_resolve:
            if not _GOOGLE_API_CLIENT_INSTALLED:
                raise ImportError(
                    'googleapiclient and oauth2client must be installed '
                    'before using the TPU cluster resolver. Execute: '
                    '`pip install --upgrade google-api-python-client` '
                    'and `pip install --upgrade oauth2client` to '
                    'install with pip.')

            final_discovery_url = self._discoveryUrl() or discovery_url
            if final_discovery_url:
                self._service = discovery.build(
                    'tpu',
                    'v1alpha1',
                    credentials=self._credentials,
                    discoveryServiceUrl=final_discovery_url)
            else:
                self._service = discovery.build('tpu',
                                                'v1alpha1',
                                                credentials=self._credentials)
        else:
            self._service = service

        self._coordinator_name = coordinator_name
        if coordinator_name and not coordinator_address and (should_resolve
                                                             or in_gke):
            self._start_local_server()
        else:
            self._coordinator_address = coordinator_address
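A short usage sketch of this constructor: passing an explicit 'grpc://' address skips Cloud API resolution entirely (see the cluster_spec() example later on this page), so no GCP credentials are needed. The import path is an assumption (it moved between 1.x releases), and the address is hypothetical.

from tensorflow.contrib.cluster_resolver import TPUClusterResolver

resolver = TPUClusterResolver(tpu='grpc://10.240.1.2:8470')
print(resolver.master())        # the address is passed through as the session master
print(resolver.cluster_spec())  # a single-entry ClusterSpec for the 'worker' job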
示例#44
0
def create_image_lists(image_dir, testing_percentage, validation_percentage):
  """file system으로부터 training 이미지들의 list를 만듭니다.
  이미지 디렉토리로부터 sub folder들을 분석하고, 그들을 training, testing, validation sets으로 나눕니다.
  그리고 각각의 label을 위한 이미지 list와 그들의 경로(path)를 나타내는 자료구조(data structure)를 반환합니다.
  인자들(Args):
    image_dir: 이미지들의 subfolder들을 포함한 folder의 String path.
    testing_percentage: 전체 이미지중 테스트를 위해 사용되는 비율을 나타내는 Integer.
    validation_percentage: 전체 이미지중 validation을 위해 사용되는 비율을 나타내는 Integer.
  반환값들(Returns):
    각각의 label subfolder를 위한 enrtry를 포함한 dictionary
    (각각의 label에서 이미지들은 training, testing, validation sets으로 나뉘어져 있습니다.)
  """
  if not gfile.Exists(image_dir):
    print("Image directory '" + image_dir + "' not found.")
    return None
  result = {}
  sub_dirs = [x[0] for x in gfile.Walk(image_dir)]
  # The root directory comes first, so skip it.
  is_root_dir = True
  for sub_dir in sub_dirs:
    if is_root_dir:
      is_root_dir = False
      continue
    extensions = ['jpg', 'jpeg', 'JPG', 'JPEG']
    file_list = []
    dir_name = os.path.basename(sub_dir)
    if dir_name == image_dir:
      continue
    print("Looking for images in '" + dir_name + "'")
    for extension in extensions:
      file_glob = os.path.join(image_dir, dir_name, '*.' + extension)
      file_list.extend(gfile.Glob(file_glob))
    if not file_list:
      print('No files found')
      continue
    if len(file_list) < 20:
      print('WARNING: Folder has less than 20 images, which may cause issues.')
    elif len(file_list) > MAX_NUM_IMAGES_PER_CLASS:
      print('WARNING: Folder {} has more than {} images. Some images will '
            'never be selected.'.format(dir_name, MAX_NUM_IMAGES_PER_CLASS))
    label_name = re.sub(r'[^a-z0-9]+', ' ', dir_name.lower())
    training_images = []
    testing_images = []
    validation_images = []
    for file_name in file_list:
      base_name = os.path.basename(file_name)
      # When deciding which list an image goes into, we can ignore anything in
      # the file name after "_nohash_". This lets the data set creator group
      # photos that are close variations of each other; for example, the plant
      # disease data set groups multiple pictures of the same leaf.
      hash_name = re.sub(r'_nohash_.*$', '', file_name)
      # This may look a bit magical, but we need to decide whether this file
      # goes into the training, testing, or validation set, and we want files
      # already in a set to stay there even when more files are added later.
      # To do that, we need a stable way of deciding based on the file name
      # alone, so we hash the name and use the hash to generate the probability
      # value used for the assignment.
      hash_name_hashed = hashlib.sha1(compat.as_bytes(hash_name)).hexdigest()
      percentage_hash = ((int(hash_name_hashed, 16) %
                          (MAX_NUM_IMAGES_PER_CLASS + 1)) *
                         (100.0 / MAX_NUM_IMAGES_PER_CLASS))
      if percentage_hash < validation_percentage:
        validation_images.append(base_name)
      elif percentage_hash < (testing_percentage + validation_percentage):
        testing_images.append(base_name)
      else:
        training_images.append(base_name)
    result[label_name] = {
        'dir': dir_name,
        'training': training_images,
        'testing': testing_images,
        'validation': validation_images,
    }
  return result
示例#45
0
    def cluster_spec(self):
        """Returns a ClusterSpec object based on the latest TPU information.

    We retrieve the information from the GCE APIs every time this method is
    called.

    Returns:
      A ClusterSpec containing host information returned from Cloud TPUs.

    Raises:
      RuntimeError: If the provided TPU is not healthy.
    """
        ############################################################################
        # There are 5 potential cases this code must handle:
        #  1. [Normal case.] We should resolve the TPU name to a set of tasks, and
        #      a. Create a ClusterSpec that includes the coordinator job
        #      b. Create a ClusterSpec without the coordinator job.
        #  2. [GKE / No API Access.] We should not resolve the TPU name to a set of
        #     tasks and
        #      a. Create a ClusterSpec with the coordinator
        #      b. Create a ClusterSpec without the coordinator
        #  3. [Other (legacy non-gRPC).] We should return an empty ClusterSpec.
        ############################################################################

        if self._shouldResolve():
            # Case 1.
            full_name = 'projects/%s/locations/%s/nodes/%s' % (
                self._project, self._zone, compat.as_text(self._tpu))
            request = self._service.projects().locations().nodes().get(
                name=full_name)
            response = request.execute()

            if 'state' in response and response['state'] != 'READY':
                raise RuntimeError('TPU "%s" is not yet ready; state: "%s"' %
                                   (self._tpu, response['state']))

            if 'health' in response and response['health'] != 'HEALTHY':
                raise RuntimeError('TPU "%s" is unhealthy: "%s"' %
                                   (self._tpu, response['health']))

            if 'networkEndpoints' in response:
                worker_list = [
                    '%s:%s' % (endpoint['ipAddress'], endpoint['port'])
                    for endpoint in response['networkEndpoints']
                ]
            else:
                # Fall back to the deprecated response format
                instance_url = '%s:%s' % (response['ipAddress'],
                                          response['port'])
                worker_list = [instance_url]

            cluster_spec = {self._job_name: worker_list}
        else:
            if not self._tpu.startswith(compat.as_bytes('grpc://')):
                # Case 3.
                return None
            # Case 2.
            cluster_spec = {
                self._job_name: [self._tpu[len(compat.as_bytes('grpc://')):]]
            }

        if self._coordinator_address:
            # {1, 2}.a
            cluster_spec[self._coordinator_name] = [self._coordinator_address]

        return server_lib.ClusterSpec(cluster_spec)
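For case 2 above the worker entry is simply the TPU address with the scheme stripped. A tiny illustration with a hypothetical address (note that self._tpu is stored as bytes):

tpu = b'grpc://10.240.1.2:8470'
assert tpu.startswith(b'grpc://')
assert tpu[len(b'grpc://'):] == b'10.240.1.2:8470'  # becomes the 'worker' task address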
示例#46
0
def load(sess, tags, export_dir, **saver_kwargs):
  """Loads the model from a SavedModel as specified by tags.

  Args:
    sess: The TensorFlow session to restore the variables.
    tags: Set of string tags to identify the required MetaGraphDef. These should
        correspond to the tags used when saving the variables using the
        SavedModel `save()` API.
    export_dir: Directory in which the SavedModel protocol buffer and variables
        to be loaded are located.
    **saver_kwargs: Optional keyword arguments passed through to Saver.

  Returns:
    The `MetaGraphDef` protocol buffer loaded in the provided session. This
    can be used to further extract signature-defs, collection-defs, etc.

  Raises:
    RuntimeError: MetaGraphDef associated with the tags cannot be found.
  """
  # Build the SavedModel protocol buffer and find the requested meta graph def.
  saved_model = _parse_saved_model(export_dir)
  found_match = False
  for meta_graph_def in saved_model.meta_graphs:
    if set(meta_graph_def.meta_info_def.tags) == set(tags):
      meta_graph_def_to_load = meta_graph_def
      found_match = True
      break

  if not found_match:
    raise RuntimeError("MetaGraphDef associated with tags " + str(tags).strip(
        "[]") + " could not be found in SavedModel")

  # Build a saver by importing the meta graph def to load.
  saver = tf_saver.import_meta_graph(meta_graph_def_to_load, **saver_kwargs)

  if saver:
    # Build the checkpoint path where the variables are located.
    variables_path = os.path.join(
        compat.as_bytes(export_dir),
        compat.as_bytes(constants.VARIABLES_DIRECTORY),
        compat.as_bytes(constants.VARIABLES_FILENAME))

    # Restore the variables using the built saver in the provided session.
    saver.restore(sess, variables_path)
  else:
    tf_logging.info("The specified SavedModel has no variables; no "
                    "checkpoints were restored.")

  # Get asset tensors, if any.
  asset_tensors_dictionary = _get_asset_tensors(export_dir,
                                                meta_graph_def_to_load)

  main_op_tensor = _get_main_op_tensor(meta_graph_def_to_load)
  if main_op_tensor is not None:
    sess.run(fetches=[main_op_tensor], feed_dict=asset_tensors_dictionary)
  else:
    legacy_init_op_tensor = _get_legacy_init_op_tensor(meta_graph_def_to_load)
    if legacy_init_op_tensor is not None:
      sess.run(fetches=[legacy_init_op_tensor],
               feed_dict=asset_tensors_dictionary)

  return meta_graph_def_to_load
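A minimal sketch of calling load(), mirroring the usage in the test_export_savedmodel example further down this page; the export path is hypothetical and the SavedModel is assumed to have been written with the SERVING tag.

import tensorflow as tf
from tensorflow.python.saved_model import loader
from tensorflow.python.saved_model import tag_constants

export_dir = '/tmp/my_saved_model'  # hypothetical path to an existing SavedModel

with tf.Graph().as_default() as graph:
    with tf.compat.v1.Session(graph=graph) as sess:
        meta_graph_def = loader.load(sess, [tag_constants.SERVING], export_dir)
        # Serving signatures can now be read from the returned MetaGraphDef.
        print(list(meta_graph_def.signature_def.keys()))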
示例#47
0
    def export(self,
               export_dir_base,
               global_step_tensor,
               sess=None,
               exports_to_keep=None):
        """Exports the model.

    Args:
      export_dir_base: A string path to the base export dir.
      global_step_tensor: A `Tensor` or tensor name providing the
        global step counter to append to the export directory path and set
        in the manifest version.
      sess: A Session to use to save the parameters.
      exports_to_keep: A gc.Path filter function used to determine the set of
        exports to keep. If set to None, all versions will be kept.

    Returns:
      The string path to the exported directory.

    Raises:
      RuntimeError: if init is not called.
      RuntimeError: if the export would overwrite an existing directory.
    """
        if not self._has_init:
            raise RuntimeError("init must be called first")

        # Export dir must not end with / or it will break exports to keep. Strip /.
        if export_dir_base.endswith("/"):
            export_dir_base = export_dir_base[:-1]

        global_step = training_util.global_step(sess, global_step_tensor)
        export_dir = os.path.join(
            compat.as_bytes(export_dir_base),
            compat.as_bytes(constants.VERSION_FORMAT_SPECIFIER % global_step))

        # Prevent overwriting on existing exports which could lead to bad/corrupt
        # storage and loading of models. This is an important check that must be
        # done before any output files or directories are created.
        if gfile.Exists(export_dir):
            raise RuntimeError(
                "Overwriting exports can cause corruption and are "
                "not allowed. Duplicate export dir: %s" % export_dir)

        # Output to a temporary directory which is atomically renamed to the final
        # directory when complete.
        tmp_export_dir = compat.as_text(export_dir) + "-tmp"
        gfile.MakeDirs(tmp_export_dir)

        self._saver.save(sess,
                         os.path.join(
                             compat.as_text(tmp_export_dir),
                             compat.as_text(constants.EXPORT_BASE_NAME)),
                         meta_graph_suffix=constants.EXPORT_SUFFIX_NAME)

        # Run the asset callback.
        if self._assets_callback and self._assets_to_copy:
            assets_dir = os.path.join(
                compat.as_bytes(tmp_export_dir),
                compat.as_bytes(constants.ASSETS_DIRECTORY))
            gfile.MakeDirs(assets_dir)
            self._assets_callback(self._assets_to_copy, assets_dir)

        # TODO (b/27794910): Delete *checkpoint* file before rename. id:2165 gh:2166
        gfile.Rename(tmp_export_dir, export_dir)

        if exports_to_keep:
            # create a simple parser that pulls the export_version from the directory.
            def parser(path):
                if os.name == 'nt':
                    match = re.match(
                        "^" + export_dir_base.replace('\\', '/') +
                        "/(\\d{8})$", path.path.replace('\\', '/'))
                else:
                    match = re.match("^" + export_dir_base + "/(\\d{8})$",
                                     path.path)
                if not match:
                    return None
                return path._replace(export_version=int(match.group(1)))

            paths_to_delete = gc.negation(exports_to_keep)
            for p in paths_to_delete(
                    gc.get_paths(export_dir_base, parser=parser)):
                gfile.DeleteRecursively(p.path)

        return export_dir
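A hedged sketch of an exports_to_keep filter for the export() call above, assuming the session_bundle gc module used in the snippet also provides largest_export_versions (as it does in the 1.x contrib tree this appears to come from); the versioned paths are hypothetical.

from tensorflow.contrib.session_bundle import gc

# Keep only the five most recent versioned export directories; anything the
# filter rejects is deleted by export() via gc.negation(), as in the code above.
exports_to_keep = gc.largest_export_versions(5)

paths = [gc.Path('/tmp/exports/%08d' % v, v) for v in range(10)]
print(sorted(p.export_version for p in exports_to_keep(paths)))  # [5, 6, 7, 8, 9]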
示例#48
0
 def _requestComputeMetadata(self, path):
     req = Request('http://metadata/computeMetadata/v1/%s' % path,
                   headers={'Metadata-Flavor': 'Google'})
     resp = urlopen(req)
     return compat.as_bytes(resp.read())
示例#49
0
    def test_export_savedmodel(self):
        tmpdir = tempfile.mkdtemp()
        est, export_input_fn = _build_estimator_for_export_tests(tmpdir)

        extra_file_name = os.path.join(compat.as_bytes(tmpdir),
                                       compat.as_bytes('my_extra_file'))
        extra_file = gfile.GFile(extra_file_name, mode='w')
        extra_file.write(EXTRA_FILE_CONTENT)
        extra_file.close()
        assets_extra = {'some/sub/directory/my_extra_file': extra_file_name}

        export_dir_base = os.path.join(compat.as_bytes(tmpdir),
                                       compat.as_bytes('export'))
        export_dir = est.export_savedmodel(export_dir_base,
                                           export_input_fn,
                                           assets_extra=assets_extra)

        self.assertTrue(gfile.Exists(export_dir_base))
        self.assertTrue(gfile.Exists(export_dir))
        self.assertTrue(
            gfile.Exists(
                os.path.join(compat.as_bytes(export_dir),
                             compat.as_bytes('saved_model.pb'))))
        self.assertTrue(
            gfile.Exists(
                os.path.join(compat.as_bytes(export_dir),
                             compat.as_bytes('variables'))))
        self.assertTrue(
            gfile.Exists(
                os.path.join(compat.as_bytes(export_dir),
                             compat.as_bytes('variables/variables.index'))))
        self.assertTrue(
            gfile.Exists(
                os.path.join(
                    compat.as_bytes(export_dir),
                    compat.as_bytes(
                        'variables/variables.data-00000-of-00001'))))

        self.assertTrue(
            gfile.Exists(
                os.path.join(compat.as_bytes(export_dir),
                             compat.as_bytes('assets'))))
        self.assertTrue(
            gfile.Exists(
                os.path.join(compat.as_bytes(export_dir),
                             compat.as_bytes('assets/my_vocab_file'))))
        self.assertEqual(
            compat.as_bytes(VOCAB_FILE_CONTENT),
            compat.as_bytes(
                gfile.GFile(
                    os.path.join(
                        compat.as_bytes(export_dir),
                        compat.as_bytes('assets/my_vocab_file'))).read()))

        expected_extra_path = os.path.join(
            compat.as_bytes(export_dir),
            compat.as_bytes('assets.extra/some/sub/directory/my_extra_file'))
        self.assertTrue(
            gfile.Exists(
                os.path.join(compat.as_bytes(export_dir),
                             compat.as_bytes('assets.extra'))))
        self.assertTrue(gfile.Exists(expected_extra_path))
        self.assertEqual(
            compat.as_bytes(EXTRA_FILE_CONTENT),
            compat.as_bytes(gfile.GFile(expected_extra_path).read()))

        expected_vocab_file = os.path.join(compat.as_bytes(tmpdir),
                                           compat.as_bytes('my_vocab_file'))
        # Restore, to validate that the export was well-formed.
        with ops.Graph().as_default() as graph:
            with session_lib.Session(graph=graph) as sess:
                loader.load(sess, [tag_constants.SERVING], export_dir)
                assets = [
                    x.eval() for x in graph.get_collection(
                        ops.GraphKeys.ASSET_FILEPATHS)
                ]
                self.assertItemsEqual([expected_vocab_file], assets)
                graph_ops = [x.name for x in graph.get_operations()]
                self.assertTrue('input_example_tensor' in graph_ops)
                self.assertTrue('ParseExample/ParseExample' in graph_ops)
                self.assertTrue('linear/linear/feature/matmul' in graph_ops)

        # cleanup
        gfile.DeleteRecursively(tmpdir)
示例#50
0
    def _init_from_args(self,
                        initial_value=None,
                        trainable=True,
                        collections=None,
                        validate_shape=True,
                        caching_device=None,
                        name=None,
                        dtype=None):
        """Creates a variable.

    Args:
      initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
        which is the initial value for the Variable. The initial value must have
        a shape specified unless `validate_shape` is set to False. Can also be a
        callable with no argument that returns the initial value when called.
        (Note that initializer functions from init_ops.py must first be bound
         to a shape before being used here.)
      trainable: If `True`, the default, also adds the variable to the graph
        collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
        the default list of variables to use by the `Optimizer` classes.
      collections: List of graph collections keys. The new variable is added to
        these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
      validate_shape: Ignored. Provided for compatibility with tf.Variable.
      caching_device: Optional device string or function describing where the
        Variable should be cached for reading.  Defaults to the Variable's
        device.  If not `None`, caches on another device.  Typical use is to
        cache on the device where the Ops using the Variable reside, to
        deduplicate copying through `Switch` and other conditional statements.
      name: Optional name for the variable. Defaults to `'Variable'` and gets
        uniquified automatically.
      dtype: If set, initial_value will be converted to the given type.
        If None, either the datatype will be kept (if initial_value is
        a Tensor) or float32 will be used (if it is a Python object convertible
        to a Tensor).

    Raises:
      ValueError: If the initial value is not specified, or does not have a
        shape and `validate_shape` is `True`.
    """
        if initial_value is None:
            raise ValueError("initial_value must be specified.")
        init_from_fn = callable(initial_value)

        if collections is None:
            collections = [ops.GraphKeys.GLOBAL_VARIABLES]
        if not isinstance(collections, (list, tuple, set)):
            raise ValueError(
                "collections argument to Variable constructor must be a list, tuple, "
                "or set. Got %s of type %s" % (collections, type(collections)))
        if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
            collections = list(collections) + [
                ops.GraphKeys.TRAINABLE_VARIABLES
            ]
        self._save_slice_info = None
        with ops.control_dependencies(None):
            with ops.name_scope(
                    name, "Variable",
                [] if init_from_fn else [initial_value]) as name:
                # pylint: disable=protected-access
                true_name = ops._name_from_scope_name(name)
                if init_from_fn:
                    # Use attr_scope and device(None) to simulate the behavior of
                    # colocate_with when the variable we want to colocate with doesn't
                    # yet exist.
                    attr = attr_value_pb2.AttrValue(
                        list=attr_value_pb2.AttrValue.ListValue(
                            s=[compat.as_bytes("loc:@%s" % true_name)]))
                    with ops.get_default_graph()._attr_scope({"_class": attr}):
                        with ops.name_scope("Initializer"), ops.device(None):
                            self._initial_value = ops.convert_to_tensor(
                                initial_value(),
                                name="initial_value",
                                dtype=dtype)
                        self._handle = gen_resource_variable_ops.var_handle_op(
                            shape=self._initial_value.get_shape(),
                            dtype=self._initial_value.dtype.base_dtype,
                            shared_name=true_name,
                            name=name)
                # pylint: enable=protected-access

                # Or get the initial value from a Tensor or Python object.
                else:
                    self._initial_value = ops.convert_to_tensor(
                        initial_value, name="initial_value", dtype=dtype)
                    self._handle = gen_resource_variable_ops.var_handle_op(
                        shape=self._initial_value.get_shape(),
                        dtype=self._initial_value.dtype.base_dtype,
                        shared_name=true_name,
                        name=name)

                self._dtype = self._initial_value.dtype.base_dtype

                with ops.name_scope("IsInitialized"):
                    self._is_initialized_op = (
                        gen_resource_variable_ops.var_is_initialized_op(
                            self._handle))
                if initial_value is not None:
                    with ops.name_scope("Assign") as n, ops.colocate_with(
                            self._handle):
                        self._initialize_op = gen_resource_variable_ops.assign_variable_op(
                            self._handle, self._initial_value, name=n)
                with ops.name_scope("Read"), ops.colocate_with(self._handle):
                    value = gen_resource_variable_ops.read_variable_op(
                        self._handle, dtype=self._dtype)
                    self._graph_element = value
                    if caching_device is not None:
                        # Variables may be created in a tf.device() or ops.colocate_with()
                        # context. At the same time, users would expect caching device to be
                        # independent of this context, and/or would not expect the current
                        # device context to be merged with the caching device spec.
                        # Therefore we reset the colocation stack before creating the cached
                        # value. Note that resetting the colocation stack will also reset
                        # the device stack.
                        with ops.colocate_with(None, ignore_existing=True):
                            with ops.device(caching_device):
                                self._cached_value = array_ops.identity(value)
                    else:
                        self._cached_value = None
                    ops.add_to_collections(collections, self)
def load_function_def_library(library,
                              load_shared_name_suffix=None,
                              wrapper_function=None):
  """Load a set of functions as concrete functions without captured inputs.

  Function names are manipulated during load so that they do not overlap
  with previously created ones.

  Gradients are re-registered under new names. Ops that reference the gradients
  are updated to reflect the new registered names.

  Args:
    library: FunctionDefLibrary proto message.
    load_shared_name_suffix: If specified, used to uniquify shared
      names. Otherwise, a unique name is generated.
    wrapper_function: An optional callable used to wrap each newly created
      function.

  Returns:
    Map of original function names in the library to instances of
    `ConcreteFunction` without captured inputs.

  Raises:
    ValueError: If the function dependencies contain a cycle.
  """
  library_function_names = set(fdef.signature.name for fdef in library.function)
  functions = {}
  renamed_functions = {}

  # Our graph building code currently requires functions to be registered with
  # some tf.Graph in order to import functions using the
  # op-name-is-function-name calling convention. To avoid leaking memory into
  # the global default graph when executing eagerly, we create a temporary
  # Graph.
  #
  # TODO(allenl): Make this Graph creation unnecessary when executing eagerly by
  # fixing function_def_to_graph_def.
  if ops.executing_eagerly_outside_functions():
    graph = ops.Graph()
  else:
    graph = ops.get_default_graph()

  if load_shared_name_suffix is None:
    load_shared_name_suffix = "_load_{}".format(ops.uid())

  # Custom gradient functions must be re-registered under new UIDs.
  library_gradient_names = {}  # Maps old op type to old function name
  new_gradient_op_types = {}  # Maps old gradient op type to new op type.
  gradients_to_register = {}  # Maps old function name to new op type
  for gdef in library.registered_gradients:
    if gdef.registered_op_type:
      new_op_type = custom_gradient.generate_name()
      old_op_type = compat.as_bytes(gdef.registered_op_type)

      library_gradient_names[old_op_type] = gdef.gradient_func
      new_gradient_op_types[old_op_type] = new_op_type
      gradients_to_register[gdef.gradient_func] = new_op_type

  function_deps = {}
  for fdef in library.function:
    function_deps[fdef.signature.name] = _list_function_deps(
        fdef, library_function_names, library_gradient_names)

  loaded_gradients = {}
  for fdef in _sort_function_defs(library, function_deps):
    copy = _fix_fdef(fdef, functions, load_shared_name_suffix,
                     new_gradient_op_types)

    # There is no need to copy all functions into the function def graph. It
    # leads to an O(n^2) increase in memory when importing functions, and the
    # extra function definitions are a no-op since they were already imported
    # as functions earlier and passed in explicitly (due to the topological
    # sort import).
    with graph.as_default():
      func_graph = function_def_lib.function_def_to_graph(copy)
    # Restores gradients for function-call ops (not the same as ops that use
    # custom gradients)
    _restore_gradient_functions(func_graph, renamed_functions, loaded_gradients)

    for dep in function_deps[fdef.signature.name]:
      functions[dep].add_to_graph(func_graph)

    # We do not initialize the new ConcreteFunction's function_spec and/or
    # arg_keywords here (which are used to parse the structured and flat
    # signatures, respectively). A ConcreteFunction that is part of a saved
    # function is set up later by recreate_function(), and a bare
    # ConcreteFunction is set up by setup_bare_concrete_function().
    # However, we copy the FunctionDef attributes to the new ConcreteFunction,
    # excluding "_input_shapes", which may cause an error during input shape
    # initialization at a later stage.
    if "_input_shapes" in copy.attr:
      del copy.attr["_input_shapes"]
    func = function_lib.ConcreteFunction(func_graph, attrs=copy.attr)
    if wrapper_function:
      func = wrapper_function(func)
    func.add_to_graph(graph)

    functions[fdef.signature.name] = func
    renamed_functions[func.name] = func
    if any(op.type == "TRTEngineOp" for op in func_graph.get_operations()):
      # TODO(b/150708051): Remove this hack once TensorRT SavedModel integration
      # is fixed. Currently it's leaking memory to maintain bug compatibility
      # with previous behavior.
      func.add_to_graph(ops.get_default_graph())

    if fdef.signature.name in gradients_to_register:
      gradient_op_type = gradients_to_register[fdef.signature.name]
      loaded_gradients[compat.as_bytes(gradient_op_type)] = func
      ops.RegisterGradient(gradient_op_type)(_gen_gradient_func(func))

  return functions
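The docstring above promises a ValueError when function dependencies form a cycle, and the loop relies on _sort_function_defs to process callees before their callers. Below is a standalone sketch of that ordering idea; it is illustrative only, not the library's internal implementation.

def topological_order(deps):
    """deps maps each function name to the names of the functions it calls."""
    order = []
    done = set()
    visiting = set()

    def visit(name):
        if name in done:
            return
        if name in visiting:
            raise ValueError("Function dependencies contain a cycle.")
        visiting.add(name)
        for dep in deps.get(name, ()):
            visit(dep)
        visiting.discard(name)
        done.add(name)
        order.append(name)

    for name in deps:
        visit(name)
    return order

print(topological_order({'caller': ['helper'], 'helper': []}))  # ['helper', 'caller']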
示例#52
0
def save(obj, export_dir, signatures=None):
    # pylint: disable=line-too-long
    """Exports the Trackable object `obj` to [SavedModel format](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md).

  Example usage:

  ```python
  class Adder(tf.train.Checkpoint):

    @tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
    def add(self, x):
      return x + x + 1.

  to_export = Adder()
  tf.saved_model.save(to_export, '/tmp/adder')
  ```

  The resulting SavedModel is then servable with an input named "x", its value
  having any shape and dtype float32.

  The optional `signatures` argument controls which methods in `obj` will be
  available to programs which consume `SavedModel`s, for example serving
  APIs. Python functions may be decorated with
  `@tf.function(input_signature=...)` and passed as signatures directly, or
  lazily with a call to `get_concrete_function` on the method decorated with
  `@tf.function`.

  If the `signatures` argument is omitted, `obj` will be searched for
  `@tf.function`-decorated methods. If exactly one `@tf.function` is found, that
  method will be used as the default signature for the SavedModel. This behavior
  is expected to change in the future, when a corresponding
  `tf.saved_model.load` symbol is added. At that point signatures will be
  completely optional, and any `@tf.function` attached to `obj` or its
  dependencies will be exported for use with `load`.

  When invoking a signature in an exported SavedModel, `Tensor` arguments are
  identified by name. These names will come from the Python function's argument
  names by default. They may be overridden by specifying a `name=...` argument
  in the corresponding `tf.TensorSpec` object. Explicit naming is required if
  multiple `Tensor`s are passed through a single argument to the Python
  function.

  The outputs of functions used as `signatures` must either be flat lists, in
  which case outputs will be numbered, or a dictionary mapping string keys to
  `Tensor`, in which case the keys will be used to name outputs.

  Signatures are available in objects returned by `tf.saved_model.load` as a
  `.signatures` attribute. This is a reserved attribute: `tf.saved_model.save`
  on an object with a custom `.signatures` attribute will raise an exception.

  Since `tf.keras.Model` objects are also Trackable, this function can be
  used to export Keras models. For example, exporting with a signature
  specified:

  ```python
  class Model(tf.keras.Model):

    @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.string)])
    def serve(self, serialized):
      ...

  m = Model()
  tf.saved_model.save(m, '/tmp/saved_model/')
  ```

  Exporting from a function without a fixed signature:

  ```python
  class Model(tf.keras.Model):

    @tf.function
    def call(self, x):
      ...

  m = Model()
  tf.saved_model.save(
      m, '/tmp/saved_model/',
      signatures=m.call.get_concrete_function(
          tf.TensorSpec(shape=[None, 3], dtype=tf.float32, name="inp")))
  ```

  `tf.keras.Model` instances constructed from inputs and outputs already have a
  signature and so do not require a `@tf.function` decorator or a `signatures`
  argument. If neither are specified, the model's forward pass is exported.

  ```python
  x = input_layer.Input((4,), name="x")
  y = core.Dense(5, name="out")(x)
  model = training.Model(x, y)
  tf.saved_model.save(model, '/tmp/saved_model/')
  # The exported SavedModel takes "x" with shape [None, 4] and returns "out"
  # with shape [None, 5]
  ```

  Variables must be tracked by assigning them to an attribute of a tracked
  object or to an attribute of `obj` directly. TensorFlow objects (e.g. layers
  from `tf.keras.layers`, optimizers from `tf.train`) track their variables
  automatically. This is the same tracking scheme that `tf.train.Checkpoint`
  uses, and an exported `Checkpoint` object may be restored as a training
  checkpoint by pointing `tf.train.Checkpoint.restore` to the SavedModel's
  "variables/" subdirectory. Currently variables are the only stateful objects
  supported by `tf.saved_model.save`, but others (e.g. tables) will be supported
  in the future.

  `tf.function` does not hard-code device annotations from outside the function
  body, instead using the calling context's device. This means for example that
  exporting a model which runs on a GPU and serving it on a CPU will generally
  work, with some exceptions. `tf.device` annotations inside the body of the
  function will be hard-coded in the exported model; this type of annotation is
  discouraged. Device-specific operations, e.g. with "cuDNN" in the name or with
  device-specific layouts, may cause issues. Currently a `DistributionStrategy`
  is another exception: active distribution strategies will cause device
  placements to be hard-coded in a function. Exporting a single-device
  computation and importing under a `DistributionStrategy` is not currently
  supported, but may be in the future.

  SavedModels exported with `tf.saved_model.save` [strip default-valued
  attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes)
  automatically, which removes one source of incompatibilities when the consumer
  of a SavedModel is running an older TensorFlow version than the
  producer. There are however other sources of incompatibilities which are not
  handled automatically, such as when the exported model contains operations
  which the consumer does not have definitions for.

  The current implementation of `tf.saved_model.save` targets serving use-cases,
  but omits information which will be necessary for the planned future
  implementation of `tf.saved_model.load`. Exported models using the current
  `save` implementation, and other existing SavedModels, will not be compatible
  with `tf.saved_model.load` when it is implemented. Further, `save` will in the
  future attempt to export `@tf.function`-decorated methods which it does not
  currently inspect, so some objects which are exportable today will raise
  exceptions on export in the future (e.g. due to complex/non-serializable
  default arguments). Such backwards-incompatible API changes are expected only
  prior to the TensorFlow 2.0 release.

  Args:
    obj: A trackable object to export.
    export_dir: A directory in which to write the SavedModel.
    signatures: Optional, either a `tf.function` with an input signature
      specified or the result of `f.get_concrete_function` on a
      `@tf.function`-decorated function `f`, in which case `f` will be used to
      generate a signature for the SavedModel under the default serving
      signature key. `signatures` may also be a dictionary, in which case it
      maps from signature keys to either `tf.function` instances with input
      signatures or concrete functions. The keys of such a dictionary may be
      arbitrary strings, but will typically be from the
      `tf.saved_model.signature_constants` module.

  Raises:
    ValueError: If `obj` is not trackable.

  @compatibility(eager)
  Not well supported when graph building. From TensorFlow 1.x,
  `tf.enable_eager_execution()` should run first. Calling tf.saved_model.save in
  a loop when graph building from TensorFlow 1.x will add new save operations to
  the default graph each iteration.

  May not be called from within a function body.
  @end_compatibility
  """
    if ops.inside_function():
        raise AssertionError(
            "tf.saved_model.save is not supported inside a traced "
            "@tf.function. Move the call to the outer eagerly-executed "
            "context.")
    # pylint: enable=line-too-long
    if not isinstance(obj, base.Trackable):
        raise ValueError(
            "Expected a Trackable object for export, got {}.".format(obj))

    checkpoint_graph_view = _AugmentedGraphView(obj)
    if signatures is None:
        signatures = signature_serialization.find_function_to_export(
            checkpoint_graph_view)

    signatures = signature_serialization.canonicalize_signatures(signatures)
    signature_serialization.validate_saveable_view(checkpoint_graph_view)
    signature_map = signature_serialization.create_signature_map(signatures)
    checkpoint_graph_view.add_object(
        parent_node=checkpoint_graph_view.root,
        name_in_parent=signature_serialization.SIGNATURE_ATTRIBUTE_NAME,
        subgraph_root=signature_map)

    # Use _SaveableView to provide a frozen listing of properties and functions.
    # Note we run this twice since, while constructing the view the first time
    # there can be side effects of creating variables.
    _ = _SaveableView(checkpoint_graph_view)
    saveable_view = _SaveableView(checkpoint_graph_view)

    # TODO(allenl): Factor out some subset of SavedModelBuilder which is 2.x
    # compatible (no sessions) and share it with this export API rather than
    # making a SavedModel proto and writing it directly.
    saved_model = saved_model_pb2.SavedModel()
    meta_graph_def = saved_model.meta_graphs.add()
    object_saver = util.TrackableSaver(checkpoint_graph_view)
    asset_info, exported_graph = _fill_meta_graph_def(meta_graph_def,
                                                      saveable_view,
                                                      signatures)
    saved_model.saved_model_schema_version = (
        constants.SAVED_MODEL_SCHEMA_VERSION)
    # So far we've just been generating protocol buffers with no I/O. Now we write
    # the checkpoint, copy assets into the assets directory, and write out the
    # SavedModel proto itself.
    utils_impl.get_or_create_variables_dir(export_dir)
    object_saver.save(utils_impl.get_variables_path(export_dir))
    builder_impl.copy_assets_to_destination_dir(asset_info.asset_filename_map,
                                                export_dir)
    path = os.path.join(compat.as_bytes(export_dir),
                        compat.as_bytes(constants.SAVED_MODEL_FILENAME_PB))
    object_graph_proto = _serialize_object_graph(saveable_view,
                                                 asset_info.asset_index)
    meta_graph_def.object_graph_def.CopyFrom(object_graph_proto)
    file_io.write_string_to_file(path, saved_model.SerializeToString())
    # Clean reference cycles so repeated export()s don't make work for the garbage
    # collector. Before this point we need to keep references to captured
    # constants in the saved graph.
    ops.dismantle_graph(exported_graph)
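The docstring's examples cover single-signature exports; the dictionary form of signatures described in the Args section can be sketched as follows, assuming the 1.x-era tf.saved_model.signature_constants module the docstring refers to. The class and export path are illustrative.

import tensorflow as tf

class Squarer(tf.train.Checkpoint):

    @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
    def square(self, x):
        return x * x

m = Squarer()
tf.saved_model.save(
    m, '/tmp/squarer',
    signatures={
        tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
            m.square.get_concrete_function(),
        'square': m.square,  # a tf.function with an input_signature also works
    })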
示例#53
0
 def _lineText(self, f, l):
     return compat.as_bytes("%d: %d" % (f, l))
示例#54
0
        def add(self,
                arg,
                tag=None,
                name=None,
                aggregate=None,
                index_override=None):
            """Return a wrapped tensor of an input tensor as an argument.

      Args:
        arg: A TensorFlow tensor that should be considered an argument.
        tag: String tag to identify arguments that should be packed.
        name: Name of argument. This is included in the Identity hint op names.
        aggregate: Strategy to aggregate.
          Acceptable values are OpHint.AGGREGATE_FIRST, OpHint.AGGREGATE_LAST,
          and OpHint.AGGREGATE_STACK.
          Note that aggregate is only valid if tag is specified.
        index_override: Specify what input/output index this should be in the
          final stub. i.e. add(arg0, index=1); add(arg1, index=0) will make the
          final stub be stub_func(inputs=[arg1, arg0], outputs=[]) rather than
          the default call-order based ordering.

      Returns:
        A tensor representing the wrapped argument.

      Raises:
        ValueError: When indices are not consistent.
      """

            # Find the appropriate index
            if tag is None:
                if aggregate is not None:
                    raise ValueError(
                        "You must specify `tag` if using aggregate.")
                global_index = self._get_new_global_index(index_override)
                sort_index = None
            else:
                if aggregate is None:
                    raise ValueError(
                        "You must specify `aggregate` if using tag.")
                if tag not in self._tag_to_global_index:
                    self._tag_to_global_index[tag] = (
                        self._get_new_global_index(index_override))
                    self._tag_to_next_sort_index[tag] = 0
                elif (index_override
                      and index_override != self._tag_to_global_index[tag]):
                    raise ValueError(
                        "Tag %r was called with two indices %r and %r" %
                        (tag, index_override, self._tag_to_global_index[tag]))
                global_index = self._tag_to_global_index[tag]
                sort_index = self._tag_to_next_sort_index[tag]
                self._tag_to_next_sort_index[tag] += 1

            uuid = self._unique_function_id
            name = "%s-%s-%s-%r-%r-%s" % (self._node_name_prefix,
                                          self._function_name, uuid,
                                          global_index, sort_index, name)
            identity_op = _array_ops.identity(arg, name=name)

            # pylint: disable=protected-access
            identity_op.op._set_attr(
                OpHint.FUNCTION_NAME_ATTR,
                _attr_value_pb2.AttrValue(
                    s=_compat.as_bytes(self._function_name)))
            identity_op.op._set_attr(
                OpHint.FUNCTION_UUID_ATTR,
                _attr_value_pb2.AttrValue(
                    s=_compat.as_bytes(self._unique_function_id)))
            identity_op.op._set_attr(self._attr_name,
                                     _attr_value_pb2.AttrValue(i=global_index))
            if sort_index is not None:
                identity_op.op._set_attr(
                    OpHint.FUNCTION_SORT_INDEX_ATTR,
                    _attr_value_pb2.AttrValue(i=sort_index))
            if aggregate is not None:
                identity_op.op._set_attr(
                    OpHint.FUNCTION_AGGREGATE_ATTR,
                    _attr_value_pb2.AttrValue(s=_compat.as_bytes(aggregate)))
            # pylint: enable=protected-access
            return identity_op
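For orientation, a hedged usage sketch of how this tracker is normally reached through the OpHint wrapper (assuming the TF 1.x-era tf.lite.OpHint export; the tensor names, shapes, and the "xs" tag are illustrative):

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

hint = tf.lite.OpHint("stacked_add")
a = tf.placeholder(tf.float32, [2], name="a")
b = tf.placeholder(tf.float32, [2], name="b")
# add_input()/add_output() forward to the tracker's add() above; sharing a tag
# with AGGREGATE_STACK packs both tensors into one stub argument.
a = hint.add_input(a, tag="xs", aggregate=tf.lite.OpHint.AGGREGATE_STACK)
b = hint.add_input(b, tag="xs", aggregate=tf.lite.OpHint.AGGREGATE_STACK)
out = hint.add_output(a + b)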
Example #55
 def _record(self, f, r):
     return compat.as_bytes("Record %d of file %d" % (r, f))
Example #56
def _GradientsHelper(ys,
                     xs,
                     grad_ys=None,
                     name="gradients",
                     colocate_gradients_with_ops=False,
                     gate_gradients=False,
                     aggregation_method=None,
                     stop_gradients=None,
                     unconnected_gradients=UnconnectedGradients.NONE,
                     src_graph=None):
    """Implementation of gradients()."""
    if context.executing_eagerly():
        raise RuntimeError(
            "tf.gradients is not supported when eager execution "
            "is enabled. Use tf.GradientTape instead.")
    if src_graph is None:
        src_graph = ops.get_default_graph()
    try:
        unconnected_gradients = UnconnectedGradients(unconnected_gradients)
    except ValueError:
        raise ValueError(
            f"Unknown value for unconnected_gradients: '{unconnected_gradients}'"
        )

    # If src_graph is a _FuncGraph (i.e. a function body), gather it and all
    # ancestor graphs. This is necessary for correctly handling captured values.
    func_graphs = []
    curr_graph = src_graph
    while _IsFunction(curr_graph):
        func_graphs.append(curr_graph)
        if isinstance(curr_graph, FuncGraph):
            curr_graph = curr_graph.outer_graph
        else:
            assert isinstance(curr_graph, framework_function._FuncGraph)  # pylint: disable=protected-access
            curr_graph = curr_graph._outer_graph  # pylint: disable=protected-access

    ys = _AsList(ys)
    xs = _AsList(xs)
    stop_gradients = [] if stop_gradients is None else _AsList(stop_gradients)
    if grad_ys is None:
        grad_ys = [None] * len(ys)
    else:
        grad_ys = _AsList(grad_ys)

    with ops.name_scope(
            name, "gradients",
            list(ys) + list(xs) + list(stop_gradients) +
            list(grad_ys)) as grad_scope:
        # Get a uid for this call to gradients that can be used to help
        # cluster ops for compilation.
        gradient_uid = ops.get_default_graph().unique_name("uid")
        ys = ops.convert_n_to_tensor_or_indexed_slices(ys, name="y")
        xs = [
            x.handle if resource_variable_ops.is_resource_variable(x) else x
            for x in xs
        ]
        xs = ops.internal_convert_n_to_tensor_or_indexed_slices(xs,
                                                                name="x",
                                                                as_ref=True)
        xs_set = object_identity.ObjectIdentitySet(xs)
        grad_ys = _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops,
                                 gradient_uid)

        # The approach we take here is as follows: Create a list of all ops in the
        # subgraph between the ys and xs.  Visit these ops in reverse order of ids
        # to ensure that when we visit an op the gradients w.r.t its outputs have
        # been collected.  Then aggregate these gradients if needed, call the op's
        # gradient function, and add the generated gradients to the gradients for
        # its input.

        # Initialize the pending count for ops in the connected subgraph from ys
        # to the xs.
        to_ops = [t.op for t in ys]
        from_ops = [t.op for t in xs]
        stop_gradient_ops = [t.op for t in stop_gradients]
        reachable_to_ops, pending_count, loop_state = _PendingCount(
            to_ops, from_ops, colocate_gradients_with_ops, func_graphs, xs_set)

        # Iterate over the collected ops.
        #
        # grads: op => list of gradients received on each output endpoint of the
        # op.  The gradients for each endpoint are initially collected as a list.
        # When it is time to call the op's gradient function, for each endpoint we
        # aggregate the list of received gradients into an Add() Operation if there
        # is more than one.
        grads = {}

        # Add the initial gradients for the ys.
        for y, grad_y in zip(ys, grad_ys):
            _SetGrad(grads, y, grad_y)

        # Initialize queue with to_ops.
        queue = collections.deque()
        # Add the ops in 'to_ops' into the queue.
        to_ops_set = set()
        for op in to_ops:
            # 'ready' handles the case where one output gradient relies on
            # another output's gradient.
            ready = (pending_count[op] == 0)
            if ready and op not in to_ops_set and op in reachable_to_ops:
                to_ops_set.add(op)
                queue.append(op)

        if loop_state:
            loop_exits = loop_state.ProcessUnusedLoopExits(
                pending_count, to_ops_set)
            for y in loop_exits:
                if backprop_util.IsTrainable(y):
                    _SetGrad(grads, y, loop_state.ZerosLikeForExit(y))
                    queue.append(y.op)

        stop_ops = _StopOps(from_ops, stop_gradient_ops, pending_count, xs_set)
        while queue:
            # generate gradient subgraph for op.
            op = queue.popleft()
            with _maybe_colocate_with(op, gradient_uid,
                                      colocate_gradients_with_ops):
                if loop_state:
                    loop_state.EnterGradWhileContext(op, before=True)
                out_grads = _AggregatedGrads(grads, op, gradient_uid,
                                             loop_state, aggregation_method)
                if loop_state:
                    loop_state.ExitGradWhileContext(op, before=True)

                grad_fn = None
                func_call = None
                is_partitioned_call = _IsPartitionedCall(op)
                # pylint: disable=protected-access
                is_func_call = (src_graph._is_function(op.type)
                                or is_partitioned_call)
                # pylint: enable=protected-access
                has_out_grads = any(
                    isinstance(g, ops.Tensor) or g for g in out_grads)
                if has_out_grads and (op not in stop_ops):
                    try:
                        grad_fn = ops.get_gradient_function(op)
                    except LookupError:
                        if is_func_call:
                            if is_partitioned_call:
                                func_name = compat.as_bytes(
                                    op.get_attr("f").name)
                                func_call = src_graph._get_function(  # pylint: disable=protected-access
                                    func_name)
                                # When a graph is imported, the FunctionDefs are not copied over
                                # to each sub-graph so we recursively search the outer graphs
                                # for the FunctionDef.
                                if not func_call and hasattr(
                                        src_graph, "outer_graph"):
                                    graph = src_graph.outer_graph
                                    while graph is not None:
                                        func_call = graph._get_function(
                                            func_name)  # pylint: disable=protected-access
                                        if func_call is not None:
                                            break
                                        if hasattr(graph, "outer_graph"):
                                            graph = graph.outer_graph
                                        else:
                                            break
                            else:
                                func_call = src_graph._get_function(op.type)  # pylint: disable=protected-access
                            # Note that __defun is not set if the graph is
                            # imported. If it's set, we prefer to access the original
                            # defun.
                            func_call = getattr(op, "__defun", func_call)
                            grad_fn = func_call.python_grad_func
                        else:
                            raise LookupError(
                                "No gradient defined for operation"
                                f"'{op.name}' (op type: {op.type}). "
                                "In general every operation must have an associated "
                                "`@tf.RegisterGradient` for correct autodiff, which this "
                                "op is lacking. If you want to pretend this "
                                "operation is a constant in your program, you may insert "
                                "`tf.stop_gradient`. This can be useful to silence the "
                                "error in cases where you know gradients are not needed, "
                                "e.g. the forward pass of tf.custom_gradient. "
                                "Please see more details in "
                                "https://www.tensorflow.org/api_docs/python/tf/custom_gradient.")  # pylint: disable=line-too-long
                if loop_state:
                    loop_state.EnterGradWhileContext(op, before=False)

                # NOTE(skyewm): We don't support computing gradients wrt a loop variable
                # unless it's within the context of a single iteration (i.e. the
                # gradient is wrt the loop parameter in the body function, not wrt or
                # through the initial value). This means if we're in a while loop
                # context, we should never see a switch node from this context.
                # pylint: disable=protected-access
                if (control_flow_util.IsSwitch(op)
                        and op._control_flow_context is not None
                        and op._control_flow_context.IsWhileContext()
                        and op._control_flow_context ==
                        ops.get_default_graph()._get_control_flow_context()):
                    _RaiseNoGradWrtInitialLoopValError(op, from_ops, xs_set)
                # pylint: enable=protected-access

                if (grad_fn or is_func_call) and has_out_grads:
                    # NOTE: If _AggregatedGrads didn't compute a value for the i'th
                    # output, it means that the cost does not depend on output[i],
                    # therefore dC/doutput[i] is 0.
                    for i, out_grad in enumerate(out_grads):
                        if (not isinstance(out_grad, ops.Tensor)
                                and not out_grad) and (
                                    (not grad_fn and is_func_call) or
                                    backprop_util.IsTrainable(op.outputs[i])):
                            # Only trainable outputs or outputs for a function call that
                            # will use SymbolicGradient get a zero gradient. Gradient
                            # functions should ignore the gradient for other outputs.
                            # TODO(apassos) gradients of resource handles might be an
                            # issue here because of zeros.
                            if loop_state:
                                out_grads[i] = loop_state.ZerosLikeV1WhileLoop(
                                    op, i)
                            elif default_gradient.supports_default_grad(
                                    op.outputs[i]):
                                # TODO(b/143286622): The supports_default_grad check is needed
                                # because While op emits non-differentiable resource tensors
                                # as outputs. Remove this check when that is not the case.
                                out_grads[i] = control_flow_state.ZerosLike(
                                    op, i)
                    with ops.name_scope(op.name + "_grad"):
                        # pylint: disable=protected-access
                        with src_graph._original_op(op):
                            # pylint: enable=protected-access
                            if grad_fn:
                                # If grad_fn was found, do not use SymbolicGradient even for
                                # functions.
                                in_grads = _MaybeCompile(
                                    grad_scope, op, func_call,
                                    lambda: grad_fn(op, *out_grads))
                            else:
                                # For function call ops, we add a 'SymbolicGradient'
                                # node to the graph to compute gradients.
                                in_grads = _MaybeCompile(
                                    grad_scope, op, func_call,
                                    lambda: _SymGrad(op, out_grads))
                            in_grads = _AsList(in_grads)
                            _VerifyGeneratedGradients(in_grads, op)
                            if gate_gradients and len(
                                [x for x in in_grads if x is not None]) > 1:
                                with ops.device(None):
                                    with ops._colocate_with_for_gradient(  # pylint: disable=protected-access
                                            None,
                                            gradient_uid,
                                            ignore_existing=True):
                                        in_grads = control_flow_ops.tuple(
                                            in_grads)
                    _LogOpGradients(op, out_grads, in_grads)
                else:
                    # If no grad_fn is defined or none of out_grads is available,
                    # just propagate a list of None backwards.
                    in_grads = [None] * len(_Inputs(op, xs_set))
                # Note: we don't filter out eager inputs here because the inputs need to
                # line up with in_grads.
                for i, (t_in, in_grad) in enumerate(
                        zip(_Inputs(op, xs_set), in_grads)):
                    if in_grad is not None:
                        if (isinstance(in_grad, ops.Tensor)
                                and t_in.dtype != dtypes.resource):
                            try:
                                in_grad.set_shape(t_in.get_shape())
                            except ValueError:
                                raise ValueError(
                                    "Incompatible shapes between op input and calculated "
                                    f"input gradient. Forward operation: {op.name}. Input "
                                    f"index: {i}. Original input shape: {t_in.shape}. "
                                    f"Calculated input gradient shape: {in_grad.shape}"
                                )
                        if not isinstance(t_in, ops.EagerTensor):
                            _SetGrad(grads, t_in, in_grad)
                if loop_state:
                    loop_state.ExitGradWhileContext(op, before=False)

            # Update pending count for the inputs of op and enqueue ready ops.
            _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count,
                                          loop_state, xs_set)

    if loop_state:
        loop_state.PostProcessing()
    return [_GetGrad(grads, x, unconnected_gradients) for x in xs]
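A small usage sketch of the public wrapper around this helper (graph mode only, as the eager check at the top enforces; the expression and the numbers in the comments are just worked arithmetic):

import tensorflow as tf

g = tf.Graph()
with g.as_default():
    x = tf.constant(3.0)
    w = tf.constant(2.0)
    y = w * x * x
    # dy/dx = 2*w*x = 12.0 and dy/dw = x*x = 9.0 for the constants above.
    dy_dx, dy_dw = tf.compat.v1.gradients(y, [x, w])
    with tf.compat.v1.Session(graph=g) as sess:
        print(sess.run([dy_dx, dy_dw]))  # [12.0, 9.0]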
Example #57
    def _GetLayerMatch(match_result):
        """Populates a layer match object containing ops/tensors for folding BNs.

    Args:
      match_result: Matched result from graph matcher

    Returns:
      layer_op: Matching conv/fc op prior to batch norm
      BatchNormMatch: _BatchNormMatch containing all required batch norm
      parameters.
    """
        moving_mean_tensor = None
        moving_variance_tensor = None
        bn_decay_mean_tensor = None
        bn_decay_var_tensor = None
        batch_to_space_op = None
        layer_op = match_result.get_op(layer_pattern)
        layer_tensor = match_result.get_tensor(layer_pattern)
        bn_id_op = match_result.get_op(batch_norm_identity_pattern)
        bn_op = match_result.get_op(batch_norm_pattern)
        if bn_id_op is None:
            bn_id_op = bn_op

        batch_epsilon = bn_op.get_attr('epsilon')

        # In the MatMul case, the output of batch norm is reshaped back into a
        # 2D tensor, so the output_tensor is the output of the Reshape op.
        output_tensor = bn_op.outputs[0]
        if layer_op.type == 'MatMul':
            output_reshape_op = match_result.get_op(
                matmul_bn_output_reshape_pattern)
            # If the matcher didn't match matmul_bn_output_reshape, there will be
            # another match for this 'MatMul' later, so we can skip this one.
            if output_reshape_op is None:
                return None, None
            output_tensor = output_reshape_op.outputs[0]

        # Ensure that the output tensor has consumers, otherwise this is a dangling
        # node and not a match.
        if not output_tensor.consumers():
            return None, None

        batch_to_space_op = match_result.get_op(batch_to_space_pattern)
        input_tensor = match_result.get_tensor(input_pattern)
        weight_tensor = match_result.get_tensor(weight_pattern)
        gamma_tensor = match_result.get_tensor(gamma_pattern)
        beta_tensor = match_result.get_tensor(beta_pattern)
        # FusedBatchNorm in training is different from that in inference. It takes
        # empty 'mean' and empty 'variance', and produces the mean and the variance
        # of the batch. Therefore, when is_training is true, mean_tensor and
        # variance_tensor point to 1st and 2nd (0-based) output of bn_op,
        # respectively; when is_training is false, they point to bn_op's inputs.
        is_training = bn_op.get_attr('is_training')
        if is_training:
            # FusedBatchNormGrad doesn't compute gradients of the batch_mean and
            # batch_variance outputs, so we need to substitute our own custom
            # gradient.
            # TODO(suharshs, raghuramank): Find a way to avoid needing this hack.
            # pylint: disable=protected-access
            bn_op._set_attr(
                '_gradient_op_type',
                attr_value_pb2.AttrValue(
                    s=compat.as_bytes('FoldFusedBatchNormGrad')))
            # pylint: enable=protected-access
            mean_tensor = bn_op.outputs[1]
            # The batch variance used during forward and backward prop is biased,
            # i.e. it is calculated as: V = sum((x(k) - mu)^2) / N. For the moving
            # average calculation, the variance is corrected by the term N/(N-1)
            # (Bessel's correction). The variance tensor read from FusedBatchNorm
            # has Bessel's correction applied, so we undo it here.
            scope, sep, _ = bn_op.name.rpartition('/')
            g = ops.get_default_graph()
            with g.as_default(), g.name_scope(scope + sep):
                n = math_ops.cast(
                    array_ops.size(layer_tensor) / array_ops.size(mean_tensor),
                    dtypes.float32)
                variance_tensor = math_ops.multiply(
                    bn_op.outputs[2], (n - 1) / n,
                    name='Undo_Bessel_Correction')
            # TODO(suharshs): Find a way to get rid of this inner match.
            for mul_match_result in moving_avg_mul_matcher.match_graph(graph):
                sub_op = mul_match_result.get_op(moving_average_sub_pattern)
                if sub_op.inputs[1].name == bn_op.outputs[1].name:
                    # During training: Batch Mean is bn_op.outputs[1]
                    moving_mean_tensor = sub_op.inputs[0]
                    bn_decay_mean_tensor = mul_match_result.get_tensor(
                        bn_decay_pattern)
                if sub_op.inputs[1].name == bn_op.outputs[2].name:
                    # During training: Batch Var is bn_op.outputs[2]
                    moving_variance_tensor = sub_op.inputs[0]
                    bn_decay_var_tensor = mul_match_result.get_tensor(
                        bn_decay_pattern)
        else:
            mean_tensor = match_result.get_tensor(mean_pattern)
            variance_tensor = match_result.get_tensor(variance_pattern)

        return layer_op, _BatchNormMatch(
            layer_op=layer_op,
            bn_op=bn_op,
            output_tensor=output_tensor,
            input_tensor=input_tensor,
            weight_tensor=weight_tensor,
            gamma_tensor=gamma_tensor,
            beta_tensor=beta_tensor,
            mean_tensor=mean_tensor,
            variance_tensor=variance_tensor,
            moving_mean_tensor=moving_mean_tensor,
            moving_variance_tensor=moving_variance_tensor,
            bn_decay_mean_tensor=bn_decay_mean_tensor,
            bn_decay_var_tensor=bn_decay_var_tensor,
            batch_epsilon=batch_epsilon,
            batch_to_space_op=batch_to_space_op)
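As a sanity check on the Bessel-correction undo performed above, a tiny self-contained NumPy sketch (the sample values are arbitrary):

import numpy as np

x = np.array([1.0, 2.0, 4.0, 7.0], dtype=np.float32)
n = x.size
biased = x.var()          # sum((x - mean)**2) / n, the batch variance used in backprop
unbiased = x.var(ddof=1)  # Bessel-corrected variance, as reported by FusedBatchNorm
# Undoing Bessel's correction recovers the biased batch variance: V * (n - 1) / n.
np.testing.assert_allclose(unbiased * (n - 1) / n, biased, rtol=1e-6)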
Example #58
 def _get_keywords(self, f, r):
     num_keywords = 1 + (f + r) % 2
     keywords = []
     for index in range(num_keywords):
         keywords.append(compat.as_bytes("keyword%d" % index))
     return keywords
Example #59
 def test_seq_ex_in_sequence_categorical_column_with_vocabulary_list(self):
     self._test_parsed_sequence_example(
         'bytes_list', sfc.sequence_categorical_column_with_vocabulary_list,
         list(string.ascii_lowercase), [3, 4],
         [compat.as_bytes(x) for x in 'acg'])
Example #60
 def _record(self, f, r):
     return compat.as_bytes(str(f * 2 + r) * self._record_bytes)