def get_checkpoint_state(checkpoint_dir, latest_filename=None):
  """Returns CheckpointState proto from the "checkpoint" file.

  If the "checkpoint" file contains a valid CheckpointState
  proto, returns it.

  Args:
    checkpoint_dir: The directory of checkpoints.
    latest_filename: Optional name of the checkpoint file.  Default to
      'checkpoint'.

  Returns:
    A CheckpointState if the state was available, None
    otherwise.

  Raises:
    ValueError: if the checkpoint read doesn't have model_checkpoint_path set.
  """
  ckpt = None
  coord_checkpoint_filename = _GetCheckpointFilename(checkpoint_dir,
                                                     latest_filename)
  # NOTE: an earlier revision opened a file handle here and closed it in a
  # `finally` clause; `file_io.read_file_to_string` manages its own handle,
  # so the dead `f = None` / `f.close()` bookkeeping has been removed.
  try:
    # Check that the file exists before opening it to avoid
    # many lines of errors from colossus in the logs.
    if file_io.file_exists(coord_checkpoint_filename):
      file_content = file_io.read_file_to_string(
          coord_checkpoint_filename)
      ckpt = CheckpointState()
      text_format.Merge(file_content, ckpt)
      if not ckpt.model_checkpoint_path:
        raise ValueError("Invalid checkpoint state loaded from "
                         + checkpoint_dir)
      # For relative model_checkpoint_path and all_model_checkpoint_paths,
      # prepend checkpoint_dir.
      if not os.path.isabs(ckpt.model_checkpoint_path):
        ckpt.model_checkpoint_path = os.path.join(checkpoint_dir,
                                                  ckpt.model_checkpoint_path)
      for i, p in enumerate(ckpt.all_model_checkpoint_paths):
        if not os.path.isabs(p):
          ckpt.all_model_checkpoint_paths[i] = os.path.join(checkpoint_dir, p)
  except errors.OpError as e:
    # It's ok if the file cannot be read
    logging.warning("%s: %s", type(e).__name__, e)
    logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename)
    return None
  except text_format.ParseError as e:
    logging.warning("%s: %s", type(e).__name__, e)
    logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename)
    return None
  return ckpt
Beispiel #2
0
def get_checkpoint_state(checkpoint_dir, latest_filename=None):
  """Returns CheckpointState proto from the "checkpoint" file.

  If the "checkpoint" file contains a valid CheckpointState
  proto, returns it.

  Args:
    checkpoint_dir: The directory of checkpoints.
    latest_filename: Optional name of the checkpoint file.  Default to
      'checkpoint'.

  Returns:
    A CheckpointState if the state was available, None
    otherwise.
  """
  ckpt = None
  coord_checkpoint_filename = _GetCheckpointFilename(
      checkpoint_dir, latest_filename)
  # NOTE: the previous `f = None` / `finally: f.close()` scaffolding was dead
  # code (`f` was never assigned a file object) and has been removed;
  # `pywrap_tensorflow.read_file_to_string` handles its own resources.
  try:
    # Check that the file exists before opening it to avoid
    # many lines of errors from colossus in the logs.
    if pywrap_tensorflow.file_exists(coord_checkpoint_filename):
      file_content = pywrap_tensorflow.read_file_to_string(
          coord_checkpoint_filename).decode("utf-8")
      ckpt = CheckpointState()
      text_format.Merge(file_content, ckpt)
      # For relative model_checkpoint_path and all_model_checkpoint_paths,
      # prepend checkpoint_dir.
      if not os.path.isabs(checkpoint_dir):
        if not os.path.isabs(ckpt.model_checkpoint_path):
          ckpt.model_checkpoint_path = os.path.join(checkpoint_dir,
                                                    ckpt.model_checkpoint_path)
        for i, p in enumerate(ckpt.all_model_checkpoint_paths):
          if not os.path.isabs(p):
            ckpt.all_model_checkpoint_paths[i] = os.path.join(checkpoint_dir, p)
  except IOError:
    # It's ok if the file cannot be read: treat as "no checkpoint".
    return None
  except text_format.ParseError as e:
    logging.warning(str(e))
    logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename)
    return None
  return ckpt
Beispiel #3
0
def generate_checkpoint_state_proto(save_dir,
                                    model_checkpoint_path,
                                    all_model_checkpoint_paths=None,
                                    all_model_checkpoint_timestamps=None,
                                    last_preserved_timestamp=None):
    """Generates a checkpoint state proto.

    Args:
      save_dir: Directory where the model was saved.
      model_checkpoint_path: The checkpoint file.
      all_model_checkpoint_paths: List of strings.  Paths to all not-yet-deleted
        checkpoints, sorted from oldest to newest.  If this is a non-empty list,
        the last element must be equal to model_checkpoint_path.  These paths
        are also saved in the CheckpointState proto.
      all_model_checkpoint_timestamps: A list of floats, indicating the number
        of seconds since the Epoch when each checkpoint was generated.
      last_preserved_timestamp: A float, indicating the number of seconds since
        the Epoch when the last preserved checkpoint was written, e.g. due to a
        `keep_checkpoint_every_n_hours` parameter (see
        `tf.train.CheckpointManager` for an implementation).

    Returns:
      CheckpointState proto with model_checkpoint_path and
      all_model_checkpoint_paths updated to either absolute paths or
      relative paths to the current save_dir.

    Raises:
      ValueError: If `all_model_checkpoint_timestamps` was provided but its
        length does not match `all_model_checkpoint_paths`.
    """
    # Copy the incoming list so the append/rewrite below does not mutate the
    # caller's list as a hidden side effect.
    if all_model_checkpoint_paths is None:
        all_model_checkpoint_paths = []
    else:
        all_model_checkpoint_paths = list(all_model_checkpoint_paths)

    if (not all_model_checkpoint_paths
            or all_model_checkpoint_paths[-1] != model_checkpoint_path):
        logging.info(
            "%s is not in all_model_checkpoint_paths. Manually adding it.",
            model_checkpoint_path)
        all_model_checkpoint_paths.append(model_checkpoint_path)

    if (all_model_checkpoint_timestamps
            and (len(all_model_checkpoint_timestamps) !=
                 len(all_model_checkpoint_paths))):
        raise ValueError((
            "Checkpoint timestamps, if provided, must match checkpoint paths (got "
            "paths %s and timestamps %s)") % (all_model_checkpoint_paths,
                                              all_model_checkpoint_timestamps))

    # Relative paths need to be rewritten to be relative to the "save_dir"
    # if model_checkpoint_path already contains "save_dir".
    if not os.path.isabs(save_dir):
        if not os.path.isabs(model_checkpoint_path):
            model_checkpoint_path = os.path.relpath(model_checkpoint_path,
                                                    save_dir)
        for i, p in enumerate(all_model_checkpoint_paths):
            if not os.path.isabs(p):
                all_model_checkpoint_paths[i] = os.path.relpath(p, save_dir)

    coord_checkpoint_proto = CheckpointState(
        model_checkpoint_path=model_checkpoint_path,
        all_model_checkpoint_paths=all_model_checkpoint_paths,
        all_model_checkpoint_timestamps=all_model_checkpoint_timestamps,
        last_preserved_timestamp=last_preserved_timestamp)

    return coord_checkpoint_proto
Beispiel #4
0
        evaluator_task.evaluate(mock_experiment)

        assert len(mock_exporter.export.call_args_list) == len(ckpt_to_export)
        assert len(mock_experiment.estimator.evaluate.call_args_list) == len(
            ckpt_to_export)
        export_path = os.path.join(mock_experiment.estimator.model_dir,
                                   mock_exporter.name)
        if len(ckpt_to_export) > 0:
            for ckpt in ckpt_to_export:
                mock_exporter.export.assert_any_call(ANY, export_path, ckpt,
                                                     ANY, ANY)
                mock_experiment.estimator.evaluate(ANY,
                                                   steps=ANY,
                                                   hooks=ANY,
                                                   name=ANY,
                                                   checkpoint_path=ckpt)


@pytest.mark.parametrize("checkpoint_state,checkpoints", [
    (CheckpointState(
        all_model_checkpoint_paths=["/path/to/model/dir/model.ckpt-300"]),
     ["/path/to/model/dir/model.ckpt-300"]),
    (None, []),
])
def test__get_all_checkpoints(checkpoint_state, checkpoints):
    """_get_all_checkpoints yields the paths held by the mocked state proto."""
    target = "tf_yarn.tensorflow.tasks.evaluator_task.tf.train.get_checkpoint_state"
    with mock.patch(target) as get_checkpoint_state_mock:
        get_checkpoint_state_mock.return_value = checkpoint_state
        assert evaluator_task._get_all_checkpoints("dir") == checkpoints
Beispiel #5
0
    def download_model(self, checkpoint_dir):
        """Downloads the latest model checkpoint from S3 into checkpoint_dir.

        Polls until no lock file is present in the bucket, then downloads the
        "checkpoint" state file and every object under the checkpoint's prefix.

        Args:
            checkpoint_dir: Local directory to download checkpoint files into
                (created if it does not exist).

        Returns:
            True once all checkpoint files were downloaded, False on error.
        """
        s3_client = self.get_client()
        logger.info(
            "Downloading pretrained model from %s/%s %s" %
            (self.bucket, self.model_checkpoints_prefix, checkpoint_dir))
        filename = "None"
        try:
            filename = os.path.abspath(
                os.path.join(checkpoint_dir, "checkpoint"))
            if not os.path.exists(checkpoint_dir):
                logger.info("Model folder %s does not exist, creating" %
                            filename)
                os.makedirs(checkpoint_dir)

            while True:
                response = s3_client.list_objects_v2(Bucket=self.bucket,
                                                     Prefix=self._get_s3_key(
                                                         self.lock_file))

                if "Contents" not in response:
                    # If no lock is found, try getting the checkpoint
                    try:
                        key = self._get_s3_key("checkpoint")
                        logger.info("Downloading %s" % key)
                        s3_client.download_file(Bucket=self.bucket,
                                                Key=key,
                                                Filename=filename)
                    except Exception as e:
                        logger.info(
                            "Something went wrong, will retry in 2 seconds %s"
                            % e)
                        time.sleep(2)
                        continue
                else:
                    logger.info("Found a lock file %s , waiting" %
                                self._get_s3_key(self.lock_file))
                    time.sleep(2)
                    continue

                ckpt = CheckpointState()
                if os.path.exists(filename):
                    # Use a context manager so the handle is closed promptly
                    # (the original leaked the file object from open().read()).
                    with open(filename, 'r') as ckpt_file:
                        contents = ckpt_file.read()
                    text_format.Merge(contents, ckpt)
                    rel_path = ckpt.model_checkpoint_path
                    # Validates the "<step>_Step..." naming; a malformed path
                    # raises ValueError, which the outer handler turns into
                    # a logged failure (return False).
                    checkpoint = int(rel_path.split('_Step')[0])

                    response = s3_client.list_objects_v2(
                        Bucket=self.bucket, Prefix=self._get_s3_key(rel_path))
                    if "Contents" in response:
                        num_files = 0
                        for obj in response["Contents"]:
                            filename = os.path.abspath(
                                os.path.join(
                                    checkpoint_dir, obj["Key"].replace(
                                        self.model_checkpoints_prefix, "")))

                            logger.info("Downloading model file %s" % filename)
                            s3_client.download_file(Bucket=self.bucket,
                                                    Key=obj["Key"],
                                                    Filename=filename)
                            num_files += 1
                        return True

        except Exception as e:
            util.json_format_logger(
                "{} while downloading the model {} from S3".format(
                    e, filename),
                **util.build_system_error_dict(
                    utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                    utils.SIMAPP_EVENT_ERROR_CODE_500))
            return False
    def download_model(self, checkpoint_dir):
        """Downloads the latest model checkpoint from S3 into checkpoint_dir.

        Polls until no lock file exists in the bucket, downloads the
        "checkpoint" state file, then fetches every object under the
        checkpoint's prefix. Exits the process gracefully on failure.

        Args:
            checkpoint_dir: Local directory to download checkpoint files into
                (created if it does not exist).
        """
        s3_client = self.get_client()
        filename = "None"
        try:
            filename = os.path.abspath(
                os.path.join(checkpoint_dir, "checkpoint"))
            if not os.path.exists(checkpoint_dir):
                os.makedirs(checkpoint_dir)

            while True:
                response = s3_client.list_objects_v2(Bucket=self.bucket,
                                                     Prefix=self._get_s3_key(
                                                         self.lock_file))

                if "Contents" not in response:
                    # If no lock is found, try getting the checkpoint
                    try:
                        s3_client.download_file(
                            Bucket=self.bucket,
                            Key=self._get_s3_key("checkpoint"),
                            Filename=filename)
                    except Exception:
                        # Best-effort retry loop: any download hiccup waits
                        # 2 seconds and tries again.
                        time.sleep(2)
                        continue
                else:
                    time.sleep(2)
                    continue

                ckpt = CheckpointState()
                if os.path.exists(filename):
                    # Use a context manager so the handle is closed promptly
                    # (the original leaked the file object from open().read()).
                    with open(filename, 'r') as ckpt_file:
                        contents = ckpt_file.read()
                    text_format.Merge(contents, ckpt)
                    rel_path = ckpt.model_checkpoint_path
                    # Validates the "<step>_Step..." naming; a malformed path
                    # raises ValueError, handled by the outer except below.
                    checkpoint = int(rel_path.split('_Step')[0])

                    response = s3_client.list_objects_v2(
                        Bucket=self.bucket, Prefix=self._get_s3_key(rel_path))
                    if "Contents" in response:
                        num_files = 0
                        for obj in response["Contents"]:
                            filename = os.path.abspath(
                                os.path.join(
                                    checkpoint_dir, obj["Key"].replace(
                                        self.model_checkpoints_prefix, "")))
                            s3_client.download_file(Bucket=self.bucket,
                                                    Key=obj["Key"],
                                                    Filename=filename)
                            num_files += 1
                        return
        except botocore.exceptions.ClientError as e:
            utils.json_format_logger(
                "Unable to download model {} from {}: {}".format(
                    filename, self.bucket, e.response['Error']['Code']),
                **utils.build_user_error_dict(
                    utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                    utils.SIMAPP_EVENT_ERROR_CODE_400))
            utils.simapp_exit_gracefully()
        except Exception as e:
            utils.json_format_logger(
                "Unable to download model {} from {}: {}".format(
                    filename, self.bucket, e),
                **utils.build_system_error_dict(
                    utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                    utils.SIMAPP_EVENT_ERROR_CODE_500))
            utils.simapp_exit_gracefully()