Example 1
  def testCopy(self):
    gfile.MkDir(self.tmp + "dir1")
    gfile.MkDir(self.tmp + "dir2")
    with gfile.GFile(self.tmp + "dir1/file1", "w"):
      pass  # Create file
    with gfile.GFile(self.tmp + "dir2/file2", "w"):
      pass  # Create file

    # Dest file already exists, overwrite=False (default).
    self.assertRaises(
        OSError, lambda: gfile.Copy(self.tmp + "dir1/file1",
                                    self.tmp + "dir2/file2"))
    # Overwrite succeeds
    gfile.Copy(self.tmp + "dir1/file1", self.tmp + "dir2/file2",
               overwrite=True)
    self.assertTrue(gfile.Exists(self.tmp + "dir2/file2"))

    # Normal rename.
    gfile.Rename(self.tmp + "dir1/file1", self.tmp + "dir2/file1")
    self.assertTrue(gfile.Exists(self.tmp + "dir2/file1"))

    # Rename fails: the source was moved above and newdir does not exist.
    self.assertRaises(OSError,
                      lambda: gfile.Rename(self.tmp + "dir1/file1",
                                           self.tmp + "newdir/file1"))
Example 2
    def export_fn(estimator, export_dir_base, checkpoint_path=None):
        """Exports the given Estimator as a SavedModel and invokes post_export_fn.

        Args:
          estimator: the Estimator to export.
          export_dir_base: A string containing a directory to write the exported
            graphs and checkpoint.
          checkpoint_path: The checkpoint path to export. If None (the default),
            the most recent checkpoint found within the model directory is chosen.

        Returns:
          The string path to the SavedModel indicated by post_export_fn.

        Raises:
          ValueError: If `estimator` is a ${tf.estimator.Estimator} instance
            and `default_output_alternative_key` was specified or if post_export_fn
            does not return a valid directory.
        """
        tmp_base_export_dir = tempfile.mkdtemp()
        tmp_base_export = base_export_strategy.export(estimator,
                                                      tmp_base_export_dir,
                                                      checkpoint_path)
        tmp_post_export_dir = tempfile.mkdtemp()
        tmp_post_export = post_export_fn(tmp_base_export, tmp_post_export_dir)

        if not tmp_post_export.startswith(tmp_post_export_dir):
            raise ValueError(
                'post_export_fn must return a sub-directory of {}'.format(
                    tmp_post_export_dir))
        export_relpath = os.path.relpath(tmp_post_export, tmp_post_export_dir)

        gfile.Rename(os.path.join(tmp_post_export_dir, export_relpath),
                     os.path.join(export_dir_base, export_relpath))
        return os.path.join(export_dir_base, export_relpath)
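The startswith() check above is the contract export_fn imposes: post_export_fn must stage its output inside the temporary directory it is handed and return that path. A minimal conforming callback might look like the following sketch; the function name and the marker file are illustrative, not part of the original, and gfile is assumed to be tensorflow.python.platform.gfile as in the snippet above.

import os

from tensorflow.python.platform import gfile


def example_post_export_fn(tmp_base_export, tmp_post_export_dir):
    # Stage output under the directory export_fn handed us, so the
    # startswith() contract above holds.
    out_dir = os.path.join(tmp_post_export_dir, 'rewritten')
    gfile.MakeDirs(out_dir)
    # A real callback would transform the SavedModel in tmp_base_export;
    # here we only drop a marker file to keep the sketch self-contained.
    with gfile.GFile(os.path.join(out_dir, 'MARKER'), 'w') as f:
        f.write('post-processed from %s' % tmp_base_export)
    return out_dir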
Example 3
def shuffle_records(fname):
    """Shuffle records in a single file."""
    print("Shuffling records in file %s" % fname)

    # Rename file prior to shuffling
    tmp_fname = fname + ".unshuffled"
    gfile.Rename(fname, tmp_fname)

    reader = python_io.tf_record_iterator(tmp_fname)
    records = []
    for record in reader:
        records.append(record)
        if len(records) % 100000 == 0:
            print("\tRead: %d", len(records))

    random.shuffle(records)

    # Write shuffled records to original file name
    with python_io.TFRecordWriter(fname) as w:
        for count, record in enumerate(records):
            w.write(record)
            if count > 0 and count % 100000 == 0:
                print("\tWriting record: %d" % count)

    gfile.Remove(tmp_fname)
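The rename-shuffle-rewrite-remove sequence above is a variant of a common gfile idiom: never write the final filename directly. A generic sketch of that idiom, assuming the same TensorFlow gfile module (atomicity of Rename depends on the underlying filesystem):

from tensorflow.python.platform import gfile


def atomic_write(path, data):
    # Write to a temporary name first, then Rename into place, so that a
    # crash mid-write never leaves a truncated file at `path`.
    tmp_path = path + '.incomplete'
    with gfile.GFile(tmp_path, 'w') as f:
        f.write(data)
    gfile.Rename(tmp_path, path, overwrite=True)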
Example 4
    def test_file_operations(self):
        """Test file operations"""

        f = get_oss_path("test_file_operations")
        self.assertFalse(gfile.Exists(f))

        fh = gfile.Open(f, mode="w")
        content = "file content"
        fh.write(content)
        fh.close()
        self.assertTrue(gfile.Exists(f))

        fh = gfile.Open(f)
        self.assertEqual(fh.read(), content)

        self.assertEqual(gfile.Stat(f).length, len(content))

        f2 = get_oss_path("test_file_2")
        gfile.Rename(f, f2)
        self.assertFalse(gfile.Exists(f))
        self.assertTrue(gfile.Exists(f2))

        f3 = get_oss_path("test_file_3")
        gfile.Copy(f2, f3, overwrite=True)
        self.assertTrue(gfile.Exists(f3))
Example 5
    def finish_example_id_dumper(self):
        self._tf_record_writer.close()
        self._tf_record_writer = None
        if self.dumped_example_number() > 0:
            fpath = self._get_dumped_fpath()
            gfile.Rename(self._tmp_fpath, fpath)
            return ExampleIdMeta(self._start_index, self._end_index, fpath)
        assert self._start_index == self._end_index
        gfile.Remove(self._tmp_fpath)
        return None
Example 6
    def finish_data_block(self):
        assert self._example_num == len(self._data_block_meta.example_ids)
        self._tf_record_writer.close()
        self._tf_record_writer = None
        if len(self._data_block_meta.example_ids) > 0:
            data_block_id = self._generate_data_block_id()
            data_block_path = os.path.join(self._get_data_block_dir(),
                                           data_block_id + DataBlockSuffix)
            gfile.Rename(self._tmp_fpath, data_block_path)
            self._data_block_meta.start_time = self._start_time
            self._data_block_meta.end_time = self._end_time
            self._data_block_meta.block_id = data_block_id
            meta_tmp_fpath = self._get_tmp_fpath()
            with tf.io.TFRecordWriter(meta_tmp_fpath) as meta_writer:
                meta_writer.write(self._data_block_meta.SerializeToString())
            meta_path = os.path.join(self._get_data_block_dir(),
                                     data_block_id + DataBlockMetaSuffix)
            gfile.Rename(meta_tmp_fpath, meta_path)
        else:
            gfile.Remove(self._tmp_fpath)
Example 7
  def testRename(self):
    gfile.MkDir(self.tmp + "dir1")
    gfile.MkDir(self.tmp + "dir2")
    with gfile.GFile(self.tmp + "file1", "w"):
      pass  # Create file
    with gfile.GFile(self.tmp + "file2", "w"):
      pass  # Create file

    # Dest file already exists, overwrite=False (default).
    self.assertRaises(
        OSError, lambda: gfile.Rename(self.tmp + "file1", self.tmp + "file2"))
    gfile.Rename(self.tmp + "file1", self.tmp + "file2", overwrite=True)
    self.assertFalse(gfile.Exists(self.tmp + "file1"))
    gfile.Rename(self.tmp + "file2", self.tmp + "newfile")
    self.assertTrue(gfile.Exists(self.tmp + "newfile"))

    gfile.Rename(self.tmp + "dir1", self.tmp + "dir2")
    self.assertFalse(gfile.Exists(self.tmp + "dir1"))
    gfile.Rename(self.tmp + "dir2", self.tmp + "newdir")
    self.assertTrue(gfile.Exists(self.tmp + "newdir"))
Example 8
    def test_dir_operations(self):
        """ Test directory operations"""

        d = get_oss_path("d1/d2/d3/d4")
        gfile.MakeDirs(d)
        self.assertTrue(gfile.Stat(d).is_directory)

        # Test listing bucket directory with and without trailing '/'
        content = gfile.ListDirectory(
            "oss://%s\x01id=%s\x02key=%s\x02host=%s" %
            (bucket, access_id, access_key, host))
        content_s = gfile.ListDirectory(
            "oss://%s\x01id=%s\x02key=%s\x02host=%s/" %
            (bucket, access_id, access_key, host))
        self.assertEqual(content, content_s)
        self.assertIn("oss_fs_test", content)
        self.assertIn("oss_fs_test/d1", content)
        self.assertIn("oss_fs_test/d1/d2", content)

        # Test listing test directory with and without trailing '/'
        content = gfile.ListDirectory(
            "oss://%s\x01id=%s\x02key=%s\x02host=%s" %
            (bucket, access_id, access_key, host) + "/oss_fs_test")
        content_s = gfile.ListDirectory(
            "oss://%s\x01id=%s\x02key=%s\x02host=%s" %
            (bucket, access_id, access_key, host) + "/oss_fs_test/")
        self.assertEqual(content, content_s)
        self.assertIn("d1", content)
        self.assertIn("d1/d2", content)

        # Test listing sub directories.
        content = gfile.ListDirectory(get_oss_path("d1"))
        content_s = gfile.ListDirectory(get_oss_path("d1/"))
        self.assertEqual(content, content_s)
        self.assertIn("d2", content)

        content = gfile.ListDirectory(get_oss_path("d1/d2/d3/d4"))
        content_s = gfile.ListDirectory(get_oss_path("d1/d2/d3/d4/"))
        self.assertEqual(content, content_s)
        self.assertEqual([], content)

        # Test Rename directories
        self.assertTrue(gfile.Exists(get_oss_path("d1")))
        gfile.Rename(get_oss_path("d1"),
                     get_oss_path("rename_d1"),
                     overwrite=True)
        self.assertTrue(gfile.Exists(get_oss_path("rename_d1")))
        self.assertFalse(gfile.Exists(get_oss_path("d1")))

        content = gfile.ListDirectory(get_oss_path("rename_d1"))
        content_s = gfile.ListDirectory(get_oss_path("rename_d1/"))
        self.assertEqual(content, content_s)
        self.assertIn("d2", content)
Example 9
def encode_and_save_files(subtokenizer, data_dir, raw_files, tag,
                          total_shards):
    """Save data from files as encoded Examples in TFRecord format.

    Args:
      subtokenizer: Subtokenizer object that will be used to encode the strings.
      data_dir: The directory in which to write the examples.
      raw_files: A tuple of (input, target) data files. Each line in the input
        and the corresponding line in the target file will be saved in a
        tf.Example.
      tag: String that will be added onto the file names.
      total_shards: Number of files to divide the data into.

    Returns:
      List of all files produced.
    """
    # Create a file for each shard.
    filepaths = [
        shard_filename(data_dir, tag, n + 1, total_shards)
        for n in range(total_shards)
    ]

    if all_exist(filepaths):
        print("Files with tag %s already exist." % tag)
        return filepaths

    print("Saving files with tag %s." % tag)
    input_file = raw_files[0]
    target_file = raw_files[1]

    # Write examples to each shard in round robin order.
    tmp_filepaths = [fname + ".incomplete" for fname in filepaths]
    writers = [python_io.TFRecordWriter(fname) for fname in tmp_filepaths]
    counter, shard = 0, 0
    for counter, (input_line, target_line) in enumerate(
            zip(txt_line_iterator(input_file),
                txt_line_iterator(target_file))):
        if counter > 0 and counter % 100000 == 0:
            print("\tSaving case %d." % counter)
        example = dict_to_example({
            "inputs":
            subtokenizer.encode(input_line, add_eos=True),
            "targets":
            subtokenizer.encode(target_line, add_eos=True)
        })
        writers[shard].write(example.SerializeToString())
        shard = (shard + 1) % total_shards
    for writer in writers:
        writer.close()

    for tmp_name, final_name in zip(tmp_filepaths, filepaths):
        gfile.Rename(tmp_name, final_name)

    print("Saved %d Examples", counter)
    return filepaths
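The helpers this example calls (shard_filename, all_exist, txt_line_iterator, dict_to_example) are not shown. Plausible minimal versions of the first two, under the assumption that shards are named <tag>-NNNNN-of-NNNNN (the naming scheme is a guess, not taken from the original):

import os

from tensorflow.python.platform import gfile


def shard_filename(path, tag, shard_num, total_shards):
    # Assumed naming scheme, e.g. data_dir/train-00001-of-00100.
    return os.path.join(path, '%s-%.5d-of-%.5d' % (tag, shard_num, total_shards))


def all_exist(filepaths):
    # True only if every shard file is already on disk.
    return all(gfile.Exists(fname) for fname in filepaths)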
Example 10
    def export_fn(estimator, export_dir_base, checkpoint_path=None):
        """Exports the given Estimator as a SavedModel and invokes post_export_fn.

        Args:
          estimator: the Estimator to export.
          export_dir_base: A string containing a directory to write the exported
            graphs and checkpoint.
          checkpoint_path: The checkpoint path to export. If None (the default),
            the most recent checkpoint found within the model directory is chosen.

        Returns:
          The string path to the SavedModel indicated by post_export_fn.

        Raises:
          ValueError: If `estimator` is a ${tf.estimator.Estimator} instance
            and `default_output_alternative_key` was specified or if post_export_fn
            does not return a valid directory.
          RuntimeError: If unable to create temporary or final export directory.
        """
        tmp_base_export_folder = 'temp-base-export-' + str(int(time.time()))
        tmp_base_export_dir = os.path.join(export_dir_base,
                                           tmp_base_export_folder)
        if gfile.Exists(tmp_base_export_dir):
            raise RuntimeError('Failed to obtain base export directory')
        gfile.MakeDirs(tmp_base_export_dir)
        tmp_base_export = base_export_strategy.export(estimator,
                                                      tmp_base_export_dir,
                                                      checkpoint_path)

        tmp_post_export_folder = 'temp-post-export-' + str(int(time.time()))
        tmp_post_export_dir = os.path.join(export_dir_base,
                                           tmp_post_export_folder)
        if gfile.Exists(tmp_post_export_dir):
            raise RuntimeError('Failed to obtain temp export directory')

        gfile.MakeDirs(tmp_post_export_dir)
        tmp_post_export = post_export_fn(tmp_base_export, tmp_post_export_dir)

        if not tmp_post_export.startswith(tmp_post_export_dir):
            raise ValueError(
                'post_export_fn must return a sub-directory of {}'.format(
                    tmp_post_export_dir))
        post_export_relpath = os.path.relpath(tmp_post_export,
                                              tmp_post_export_dir)
        post_export = os.path.join(export_dir_base, post_export_relpath)
        if gfile.Exists(post_export):
            raise RuntimeError('Failed to obtain final export directory')
        gfile.Rename(tmp_post_export, post_export)

        gfile.DeleteRecursively(tmp_base_export_dir)
        gfile.DeleteRecursively(tmp_post_export_dir)
        return post_export
Example 11
    def test_rename_dir(self):
        """Test rename dir."""
        # Setup and check preconditions.
        src_dir_name = "igfs:///test_rename_dir/1"
        dst_dir_name = "igfs:///test_rename_dir/2"
        gfile.MkDir(src_dir_name)
        # Rename directory.
        gfile.Rename(src_dir_name, dst_dir_name)
        # Check that only new name of directory is available.
        self.assertFalse(gfile.Exists(src_dir_name))
        self.assertTrue(gfile.Exists(dst_dir_name))
        self.assertTrue(gfile.IsDirectory(dst_dir_name))
Example 12
    def test_rename_file(self):
        """Test rename file."""
        # Setup and check preconditions.
        src_file_name = "igfs:///test_rename_file/1"
        dst_file_name = "igfs:///test_rename_file/2"
        with gfile.Open(src_file_name, mode="w") as w:
            w.write("42")
        self.assertTrue(gfile.Exists(src_file_name))
        # Rename file.
        gfile.Rename(src_file_name, dst_file_name)
        # Check that only new name of file is available.
        self.assertFalse(gfile.Exists(src_file_name))
        self.assertTrue(gfile.Exists(dst_file_name))
        with gfile.Open(dst_file_name, mode="r") as r:
            data = r.read()
        self.assertEqual("42", data)
Example 13
def export_saved_model(estimator,
                       export_dir_base,
                       checkpoint_path,
                       serving_input_receiver_fn,
                       as_text=False):
    with context.graph_mode():
        export_dir = export_helpers.get_timestamped_export_dir(export_dir_base)
        temp_export_dir = export_helpers.get_temp_export_dir(export_dir)

        builder = saved_model_builder.SavedModelBuilder(temp_export_dir)

        save_variables = True
        _add_meta_graph_for_mode(estimator, builder, serving_input_receiver_fn,
                                 checkpoint_path, save_variables)
        save_variables = False

        # Validate before saving: if no mode was exported, fail instead of
        # writing an empty SavedModel.
        if save_variables:
            raise ValueError('No valid modes for exporting found.')
        builder.save(as_text)

    gfile.Rename(temp_export_dir, export_dir)
    return export_dir
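Several of the export examples here stage into a directory from get_temp_export_dir and then Rename it to the path from get_timestamped_export_dir. A simplified sketch of what those helpers compute (the real TensorFlow helpers also handle bytes paths and retry on timestamp collisions):

import os
import time


def get_timestamped_export_dir(export_dir_base):
    # Final directory is named by export time, e.g. base/1546300800.
    return os.path.join(export_dir_base, str(int(time.time())))


def get_temp_export_dir(timestamped_export_dir):
    # Sibling staging directory, e.g. base/temp-1546300800; it is renamed
    # into place only after the SavedModel is fully written.
    dirname, basename = os.path.split(timestamped_export_dir)
    return os.path.join(dirname, 'temp-' + basename)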
Example 14
def download_from_url(path, url):
    """Download content from a url.

    Args:
      path: string directory where the file will be downloaded
      url: string url

    Returns:
      Full path to the downloaded file.
    """
    filename = url.split("/")[-1]
    found_file = find_file(path, filename, max_depth=0)
    if found_file is None:
        filename = os.path.join(path, filename)
        print("Downloading from %s to %s." % (url, filename))
        inprogress_filepath = filename + ".incomplete"
        inprogress_filepath, _ = urlretrieve(url,
                                             inprogress_filepath,
                                             reporthook=download_report_hook)
        # Print newline to clear the carriage return from the download progress.
        print()
        gfile.Rename(inprogress_filepath, filename)
        return filename
    else:
        print("Already downloaded: %s (at %s)." % (url, found_file))
        return found_file
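download_report_hook is referenced but not shown; urlretrieve calls its hook with (block_count, block_size, total_size). A plausible progress printer (the body is an assumption, not the original helper):

import sys


def download_report_hook(count, block_size, total_size):
    # Overwrite the same line with the download percentage; the print()
    # after urlretrieve above emits the final newline.
    percent = int(count * block_size * 100 / total_size)
    sys.stdout.write('\r%d%% completed' % percent)
    sys.stdout.flush()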
Example 15
  def export(self,
             export_dir_base,
             global_step_tensor,
             sess=None,
             exports_to_keep=None):
    """Exports the model.

    Args:
      export_dir_base: A string path to the base export dir.
      global_step_tensor: A Tensor or tensor name providing the
        global step counter to append to the export directory path and set
        in the manifest version.
      sess: A Session to use to save the parameters.
      exports_to_keep: a gc.Path filter function used to determine the set of
        exports to keep. If set to None, all versions will be kept.

    Raises:
      RuntimeError: if init is not called.
      RuntimeError: if the export would overwrite an existing directory.
    """
    if not self._has_init:
      raise RuntimeError("init must be called first")

    global_step = training_util.global_step(sess, global_step_tensor)
    export_dir = os.path.join(export_dir_base,
                              VERSION_FORMAT_SPECIFIER % global_step)

    # Prevent overwriting on existing exports which could lead to bad/corrupt
    # storage and loading of models. This is an important check that must be
    # done before any output files or directories are created.
    if gfile.Exists(export_dir):
      raise RuntimeError("Overwriting exports can cause corruption and are "
                         "not allowed. Duplicate export dir: %s" % export_dir)

    # Output to a temporary directory which is atomically renamed to the final
    # directory when complete.
    tmp_export_dir = export_dir + "-tmp"
    gfile.MakeDirs(tmp_export_dir)

    self._saver.save(sess,
                     os.path.join(tmp_export_dir, EXPORT_BASE_NAME),
                     meta_graph_suffix=EXPORT_SUFFIX_NAME)

    # Run the asset callback.
    if self._assets_callback:
      assets_dir = os.path.join(tmp_export_dir, ASSETS_DIRECTORY)
      gfile.MakeDirs(assets_dir)
      self._assets_callback(assets_dir)

    # TODO(b/27794910): Delete *checkpoint* file before rename.
    gfile.Rename(tmp_export_dir, export_dir)

    if exports_to_keep:
      # create a simple parser that pulls the export_version from the directory.
      def parser(path):
        match = re.match("^" + export_dir_base + "/(\\d{8})$", path.path)
        if not match:
          return None
        return path._replace(export_version=int(match.group(1)))

      paths_to_delete = gc.negation(exports_to_keep)
      for p in paths_to_delete(gc.get_paths(export_dir_base, parser=parser)):
        gfile.DeleteRecursively(p.path)
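exports_to_keep composes with the gc module used above (assumed to be tensorflow.contrib.session_bundle.gc, where these gc.* calls live): gc.negation inverts a filter, so passing a keep-the-newest filter makes the loop delete everything else. A small sketch:

from tensorflow.contrib.session_bundle import gc

# largest_export_versions(3) keeps the three highest export versions;
# negation() (as in the deletion loop above) selects everything it rejects.
keep = gc.largest_export_versions(3)
delete = gc.negation(keep)

paths = [gc.Path('/exports/%08d' % v, v) for v in (1, 5, 7, 8, 9)]
to_delete = delete(paths)  # the Paths for versions 1 and 5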
Example 16
  def export_savedmodel(
      self, export_dir_base, serving_input_receiver_fn,
      assets_extra=None,
      as_text=False,
      checkpoint_path=None):
    """Exports inference graph as a SavedModel into given dir.

    This method builds a new graph by first calling the
    serving_input_receiver_fn to obtain feature `Tensor`s, and then calling
    this `Estimator`'s model_fn to generate the model graph based on those
    features. It restores the given checkpoint (or, lacking that, the most
    recent checkpoint) into this graph in a fresh session.  Finally it creates
    a timestamped export directory below the given export_dir_base, and writes
    a `SavedModel` into it containing a single `MetaGraphDef` saved from this
    session.

    The exported `MetaGraphDef` will provide one `SignatureDef` for each
    element of the export_outputs dict returned from the model_fn, named using
    the same keys.  One of these keys is always
    signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY, indicating which
    signature will be served when a serving request does not specify one.
    For each signature, the outputs are provided by the corresponding
    `ExportOutput`s, and the inputs are always the input receivers provided by
    the serving_input_receiver_fn.

    Extra assets may be written into the SavedModel via the assets_extra
    argument.  This should be a dict, where each key gives a destination path
    (including the filename) relative to the assets.extra directory.  The
    corresponding value gives the full path of the source file to be copied.
    For example, the simple case of copying a single file without renaming it
    is specified as `{'my_asset_file.txt': '/path/to/my_asset_file.txt'}`.

    Args:
      export_dir_base: A string containing a directory in which to create
        timestamped subdirectories containing exported SavedModels.
      serving_input_receiver_fn: A function that takes no argument and
        returns a `ServingInputReceiver`.
      assets_extra: A dict specifying how to populate the assets.extra directory
        within the exported SavedModel, or `None` if no extra assets are needed.
      as_text: whether to write the SavedModel proto in text format.
      checkpoint_path: The checkpoint path to export.  If `None` (the default),
        the most recent checkpoint found within the model directory is chosen.

    Returns:
      The string path to the exported directory.

    Raises:
      ValueError: if no serving_input_receiver_fn is provided, no export_outputs
          are provided, or no checkpoint can be found.
    """
    if serving_input_receiver_fn is None:
      raise ValueError('serving_input_receiver_fn must be defined.')

    with ops.Graph().as_default() as g:
      self._create_and_assert_global_step(g)
      random_seed.set_random_seed(self._config.tf_random_seed)
      serving_input_receiver = serving_input_receiver_fn()

      # Call the model_fn and collect the export_outputs.
      estimator_spec = self._call_model_fn(
          features=serving_input_receiver.features,
          labels=None,
          mode=model_fn_lib.ModeKeys.PREDICT,
          config=self.config)

      # Build the SignatureDefs from receivers and all outputs
      signature_def_map = build_all_signature_defs(
          serving_input_receiver.receiver_tensors,
          estimator_spec.export_outputs,
          serving_input_receiver.receiver_tensors_alternatives)

      if not checkpoint_path:
        # Locate the latest checkpoint
        checkpoint_path = saver.latest_checkpoint(self._model_dir)
      if not checkpoint_path:
        raise ValueError("Couldn't find trained model at %s." % self._model_dir)

      export_dir = get_timestamped_export_dir(export_dir_base)
      temp_export_dir = get_temp_export_dir(export_dir)

      # TODO(soergel): Consider whether MonitoredSession makes sense here
      with tf_session.Session() as session:

        saver_for_restore = estimator_spec.scaffold.saver or saver.Saver(
            sharded=True)
        saver_for_restore.restore(session, checkpoint_path)

        # TODO(b/36111876): replace legacy_init_op with main_op mechanism
        # pylint: disable=protected-access
        local_init_op = (
            estimator_spec.scaffold.local_init_op or
            monitored_session.Scaffold._default_local_init_op())
        # pylint: enable=protected-access

        # Perform the export
        builder = saved_model_builder.SavedModelBuilder(temp_export_dir)
        builder.add_meta_graph_and_variables(
            session, [tag_constants.SERVING],
            signature_def_map=signature_def_map,
            assets_collection=ops.get_collection(
                ops.GraphKeys.ASSET_FILEPATHS),
            legacy_init_op=local_init_op)
        builder.save(as_text)

      # Add the extra assets
      if assets_extra:
        assets_extra_path = os.path.join(compat.as_bytes(temp_export_dir),
                                         compat.as_bytes('assets.extra'))
        for dest_relative, source in assets_extra.items():
          dest_absolute = os.path.join(compat.as_bytes(assets_extra_path),
                                       compat.as_bytes(dest_relative))
          dest_path = os.path.dirname(dest_absolute)
          gfile.MakeDirs(dest_path)
          gfile.Copy(source, dest_absolute)

      gfile.Rename(temp_export_dir, export_dir)
      return export_dir
Example 17
def export_eval_savedmodel(
    estimator,
    export_dir_base,
    eval_input_receiver_fn,
    checkpoint_path=None):
  """Export a EvalSavedModel for the given estimator.

  Args:
    estimator: Estimator to export the graph for.
    export_dir_base: Base path for export. Graph will be exported into a
      subdirectory of this base path.
    eval_input_receiver_fn: Eval input receiver function.
    checkpoint_path: Path to a specific checkpoint to export. If set to None,
      exports the latest checkpoint.

  Returns:
    Path to the directory where the eval graph was exported.

  Raises:
    ValueError: Could not find a checkpoint to export.
  """
  with tf.Graph().as_default() as g:
    eval_input_receiver = eval_input_receiver_fn()
    tf.train.create_global_step(g)
    tf.set_random_seed(estimator.config.tf_random_seed)

    # Workaround for TensorFlow issue #17568. Note that we pass the
    # identity-wrapped features and labels to model_fn, but we have to feed
    # the non-identity wrapped Tensors during evaluation.
    #
    # Also note that we can't wrap predictions, so metrics that have control
    # dependencies on predictions will cause the predictions to be recomputed
    # during their evaluation.
    wrapped_features = util.wrap_tensor_or_dict_of_tensors_in_identity(
        eval_input_receiver.features)
    wrapped_labels = util.wrap_tensor_or_dict_of_tensors_in_identity(
        eval_input_receiver.labels)

    if isinstance(estimator, tf.estimator.Estimator):
      # This is a core estimator
      estimator_spec = estimator.model_fn(
          features=wrapped_features,
          labels=wrapped_labels,
          mode=tf.estimator.ModeKeys.EVAL,
          config=estimator.config)
    else:
      # This is a contrib estimator
      model_fn_ops = estimator._call_model_fn(  # pylint: disable=protected-access
          features=wrapped_features,
          labels=wrapped_labels,
          mode=tf.estimator.ModeKeys.EVAL)
      # estimator_spec here is just an attribute bag: a function object is
      # created so predictions/eval_metric_ops/scaffold can be set on it.
      estimator_spec = lambda x: None
      estimator_spec.predictions = model_fn_ops.predictions
      estimator_spec.eval_metric_ops = model_fn_ops.eval_metric_ops
      estimator_spec.scaffold = model_fn_ops.scaffold

    # Save metric using eval_metric_ops.
    for user_metric_key, (value_op, update_op) in (
        estimator_spec.eval_metric_ops.items()):
      tf.add_to_collection('%s/%s' % (encoding.METRICS_COLLECTION,
                                      encoding.KEY_SUFFIX),
                           encoding.encode_key(user_metric_key))
      tf.add_to_collection('%s/%s' % (encoding.METRICS_COLLECTION,
                                      encoding.VALUE_OP_SUFFIX),
                           encoding.encode_tensor_node(value_op))
      tf.add_to_collection('%s/%s' % (encoding.METRICS_COLLECTION,
                                      encoding.UPDATE_OP_SUFFIX),
                           encoding.encode_tensor_node(update_op))

    # Save all prediction nodes.
    # Predictions can either be a Tensor, or a dict of Tensors.
    predictions = estimator_spec.predictions
    if not isinstance(predictions, dict):
      predictions = {encoding.DEFAULT_PREDICTIONS_DICT_KEY: predictions}

    for prediction_key, prediction_node in predictions.items():
      _encode_and_add_to_node_collection(encoding.PREDICTIONS_COLLECTION,
                                         prediction_key, prediction_node)

    ############################################################
    ## Features, label (and weight) graph

    # Placeholder for input example to label graph.
    tf.add_to_collection(encoding.INPUT_EXAMPLE_COLLECTION,
                         encoding.encode_tensor_node(
                             eval_input_receiver.receiver_tensors['examples']))

    # Save all label nodes.
    # Labels can either be a Tensor, or a dict of Tensors.
    labels = eval_input_receiver.labels
    if not isinstance(labels, dict):
      labels = {encoding.DEFAULT_LABELS_DICT_KEY: labels}

    for label_key, label_node in labels.items():
      _encode_and_add_to_node_collection(encoding.LABELS_COLLECTION, label_key,
                                         label_node)

    # Save features.
    for feature_name, feature_node in eval_input_receiver.features.items():
      _encode_and_add_to_node_collection(encoding.FEATURES_COLLECTION,
                                         feature_name, feature_node)

    ############################################################
    ## Export as normal

    if not checkpoint_path:
      checkpoint_path = tf.train.latest_checkpoint(estimator.model_dir)
      if not checkpoint_path:
        raise ValueError(
            'Could not find trained model at %s.' % estimator.model_dir)

    export_dir = _get_timestamped_export_dir(export_dir_base)
    temp_export_dir = _get_temp_export_dir(export_dir)

    if estimator.config.session_config is None:
      session_config = config_pb2.ConfigProto(allow_soft_placement=True)
    else:
      session_config = estimator.config.session_config

    with tf.Session(config=session_config) as session:
      if estimator_spec.scaffold and estimator_spec.scaffold.saver:
        saver_for_restore = estimator_spec.scaffold.saver
      else:
        saver_for_restore = tf.train.Saver(sharded=True)
      saver_for_restore.restore(session, checkpoint_path)

      if estimator_spec.scaffold and estimator_spec.scaffold.local_init_op:
        local_init_op = estimator_spec.scaffold.local_init_op
      else:
        # pylint: disable=protected-access
        local_init_op = tf.train.Scaffold._default_local_init_op()
      # pylint: enable=protected-access

      # Perform the export
      builder = tf.saved_model.builder.SavedModelBuilder(temp_export_dir)
      builder.add_meta_graph_and_variables(
          session,
          [tf.saved_model.tag_constants.SERVING],
          # Don't export any signatures, since this graph is not actually
          # meant for serving.
          signature_def_map=None,
          assets_collection=tf.get_collection(tf.GraphKeys.ASSET_FILEPATHS),
          legacy_init_op=local_init_op)
      builder.save(False)

      gfile.Rename(temp_export_dir, export_dir)
      return export_dir
Example 18
def save_keras_model(model,
                     saved_model_path,
                     custom_objects=None,
                     as_text=None):
    """Save a `tf.keras.Model` into Tensorflow SavedModel format.

  `save_model` generates new files/folders under the `saved_model_path` folder:
  1) an asset folder containing the json string of the model's
     configuration (topology).
  2) a checkpoint containing the model weights.
  3) a saved_model.pb file containing the model's MetaGraphs. The prediction
     graph is always exported. The evaluaton and training graphs are exported
     if the following conditions are met:
     - Evaluation: model loss is defined.
     - Training: model is compiled with an optimizer defined under `tf.train`.
       This is because `tf.keras.optimizers.Optimizer` instances cannot be
       saved to checkpoints.

  Model Requirements:
  - Model must be a sequential model or functional model. Subclassed models can
    not be saved via this function, unless you provide an implementation for
    get_config() and from_config().
  - All variables must be saveable by the model. In general, this condition is
    met through the use of layers defined in the keras library. However,
    there is currently a bug with variables created in Lambda layer functions
    not being saved correctly (see
    https://github.com/keras-team/keras/issues/9740).

  Note that each mode is exported in separate graphs, so different modes do not
  share variables. To use the train graph with evaluation or prediction graphs,
  create a new checkpoint if variable values have been updated.

  Args:
    model: A `tf.keras.Model` to be saved.
    saved_model_path: a string specifying the path to the SavedModel directory.
      The SavedModel will be saved to a timestamped folder created within this
      directory.
    custom_objects: Optional dictionary mapping string names to custom classes
      or functions (e.g. custom loss functions).
    as_text: whether to write the `SavedModel` proto in text format.

  Returns:
    String path to the SavedModel folder, a subdirectory of `saved_model_path`.

  Raises:
    NotImplementedError: If the model is a subclassed model.
    ValueError: If a Sequential model does not have input shapes defined by the
      user, and is not built.
  """
    if not model._is_graph_network:
        if isinstance(model, sequential.Sequential):
            # If input shape is not directly set in the model, the exported
            # model will assume that the inputs have the same shape as the
            # shape with which the model was built.
            if not model.built:
                raise ValueError(
                    'Sequential model must be built before it can be exported.'
                )
        else:
            raise NotImplementedError(
                'Exporting subclassed models is not yet supported.')

    export_dir = export_helpers.get_timestamped_export_dir(saved_model_path)
    temp_export_dir = export_helpers.get_temp_export_dir(export_dir)

    builder = saved_model_builder.SavedModelBuilder(temp_export_dir)

    # Manually save variables to export them in an object-based checkpoint. This
    # skips the `builder.add_meta_graph_and_variables()` step, which saves a
    # named-based checkpoint.
    # TODO(b/113134168): Add fn to Builder to save with object-based saver.
    # TODO(b/113178242): This should only export the model json structure. Only
    # one save is needed once the weights can be copied from the model to clone.
    checkpoint_path = _export_model_json_and_variables(model, temp_export_dir)

    # Export each mode. Use ModeKeys enums defined for `Estimator` to ensure that
    # Keras models and `Estimator`s are exported with the same format.
    # Every time a mode is exported, the code checks to see if new variables have
    # been created (e.g. optimizer slot variables). If that is the case, the
    # checkpoint is re-saved to include the new variables.
    export_args = {
        'builder': builder,
        'model': model,
        'custom_objects': custom_objects,
        'checkpoint_path': checkpoint_path
    }

    has_saved_vars = False
    if model.optimizer:
        if isinstance(model.optimizer, optimizers.TFOptimizer):
            _export_mode(model_fn_lib.ModeKeys.TRAIN, has_saved_vars,
                         **export_args)
            has_saved_vars = True
            _export_mode(model_fn_lib.ModeKeys.EVAL, has_saved_vars,
                         **export_args)
        else:
            logging.warning(
                'Model was compiled with an optimizer, but the optimizer is not from '
                '`tf.train` (e.g. `tf.train.AdagradOptimizer`). Only the serving '
                'graph was exported. The train and evaluate graphs were not added to '
                'the SavedModel.')
    _export_mode(model_fn_lib.ModeKeys.PREDICT, has_saved_vars, **export_args)

    builder.save(as_text)

    gfile.Rename(temp_export_dir, export_dir)
    return export_dir
Example 19
def _write_with_backup(filename, content):
    if gfile.Exists(filename):
        gfile.Rename(filename, filename + '.old', overwrite=True)
    with gfile.Open(filename, 'w') as f:
        f.write(content)
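A short usage sketch of the backup behavior (the path is a placeholder): the previous contents survive exactly one generation as <filename>.old.

_write_with_backup('/tmp/report.txt', 'first draft')
_write_with_backup('/tmp/report.txt', 'second draft')
# /tmp/report.txt now contains 'second draft';
# /tmp/report.txt.old contains 'first draft'.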
Example 20
    def test_dir_operations(self):
        """Test directory operations"""

        d = get_oss_path("d1/d2/d3/d4")
        gfile.MakeDirs(d)
        self.assertTrue(gfile.Stat(d).is_directory)

        # Test listing bucket directory with and without trailing '/'
        content = gfile.ListDirectory(
            "oss://%s\x01id=%s\x02key=%s\x02host=%s"
            % (bucket, access_id, access_key, host)
        )
        content_s = gfile.ListDirectory(
            "oss://%s\x01id=%s\x02key=%s\x02host=%s/"
            % (bucket, access_id, access_key, host)
        )
        self.assertEqual(content, content_s)
        self.assertIn("oss_fs_test", content)
        self.assertIn("oss_fs_test/d1", content)
        self.assertIn("oss_fs_test/d1/d2", content)

        # Test listing test directory with and without trailing '/'
        content = gfile.ListDirectory(
            "oss://%s\x01id=%s\x02key=%s\x02host=%s"
            % (bucket, access_id, access_key, host)
            + "/oss_fs_test"
        )
        content_s = gfile.ListDirectory(
            "oss://%s\x01id=%s\x02key=%s\x02host=%s"
            % (bucket, access_id, access_key, host)
            + "/oss_fs_test/"
        )
        self.assertEqual(content, content_s)
        self.assertIn("d1", content)
        self.assertIn("d1/d2", content)

        # Test listing sub directories.
        content = gfile.ListDirectory(get_oss_path("d1"))
        content_s = gfile.ListDirectory(get_oss_path("d1/"))
        self.assertEqual(content, content_s)
        self.assertIn("d2", content)

        content = gfile.ListDirectory(get_oss_path("d1/d2/d3/d4"))
        content_s = gfile.ListDirectory(get_oss_path("d1/d2/d3/d4/"))
        self.assertEqual(content, content_s)
        self.assertEqual([], content)

        # Test Rename directories
        self.assertTrue(gfile.Exists(get_oss_path("d1")))
        gfile.Rename(get_oss_path("d1"), get_oss_path("rename_d1"), overwrite=True)
        self.assertTrue(gfile.Exists(get_oss_path("rename_d1")))
        self.assertFalse(gfile.Exists(get_oss_path("d1")))

        content = gfile.ListDirectory(get_oss_path("rename_d1"))
        content_s = gfile.ListDirectory(get_oss_path("rename_d1/"))
        self.assertEqual(content, content_s)
        self.assertIn("d2", content)

        # Test Rename non-empty directories
        not_empty_dir = get_oss_path("not_empty_dir/")
        rename_not_empty_dir = get_oss_path("rename_not_empty_dir/")
        gfile.MakeDirs(not_empty_dir)
        not_empty_file = get_oss_path("not_empty_dir/not_empty_file")
        rename_not_empty_file = get_oss_path("rename_not_empty_dir/not_empty_file")
        with gfile.Open(not_empty_file, mode="w") as fh:
            content = "file content"
            fh.write(content)
        self.assertTrue(gfile.Exists(not_empty_dir))
        self.assertTrue(gfile.Exists(not_empty_file))
        gfile.Rename(not_empty_dir, rename_not_empty_dir, overwrite=True)
        self.assertFalse(gfile.Exists(not_empty_dir))
        self.assertFalse(gfile.Exists(not_empty_file))
        self.assertTrue(gfile.Exists(rename_not_empty_dir))
        self.assertTrue(gfile.Exists(rename_not_empty_file))