예제 #1
0
def get_checkpoint_state(checkpoint_dir, latest_filename=None):
    """Returns CheckpointState proto from the "checkpoint" file.

    If the "checkpoint" file contains a valid CheckpointState proto, returns
    it. Any relative entry in `model_checkpoint_path` or
    `all_model_checkpoint_paths` is rewritten by joining it onto
    `checkpoint_dir`; absolute entries are left unchanged.

    Args:
      checkpoint_dir: The directory of checkpoints. An `os.PathLike` value is
        converted with `os.fspath` first.
      latest_filename: Optional name of the checkpoint file.  Default to
        'checkpoint'.

    Returns:
      A CheckpointState if the state was available, None otherwise. None is
      also returned when reading the file raises `errors.OpError` or its
      contents fail to parse (`text_format.ParseError`); both are logged.

    Raises:
      ValueError: if the checkpoint read doesn't have model_checkpoint_path set.
    """
    if isinstance(checkpoint_dir, os.PathLike):
        checkpoint_dir = os.fspath(checkpoint_dir)
    ckpt = None
    coord_checkpoint_filename = _GetCheckpointFilename(checkpoint_dir,
                                                       latest_filename)
    try:
        # Check that the file exists before opening it to avoid
        # many lines of errors from colossus in the logs.
        if file_io.file_exists(coord_checkpoint_filename):
            file_content = file_io.read_file_to_string(
                coord_checkpoint_filename)
            ckpt = CheckpointState()
            text_format.Merge(file_content, ckpt)
            if not ckpt.model_checkpoint_path:
                raise ValueError("Invalid checkpoint state loaded from " +
                                 checkpoint_dir)
            # For relative model_checkpoint_path and all_model_checkpoint_paths,
            # prepend checkpoint_dir.
            if not os.path.isabs(ckpt.model_checkpoint_path):
                ckpt.model_checkpoint_path = os.path.join(
                    checkpoint_dir, ckpt.model_checkpoint_path)
            for i, p in enumerate(ckpt.all_model_checkpoint_paths):
                if not os.path.isabs(p):
                    ckpt.all_model_checkpoint_paths[i] = os.path.join(
                        checkpoint_dir, p)
    except (errors.OpError, text_format.ParseError) as e:
        # An unreadable or unparseable "checkpoint" file is not fatal: log it
        # and report "no state". (The ValueError above is not caught here.)
        logging.warning("%s: %s", type(e).__name__, e)
        logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename)
        return None
    return ckpt
def get_checkpoint_state(checkpoint_dir, latest_filename=None):
  """Returns CheckpointState proto from the "checkpoint" file.

  If the "checkpoint" file contains a valid CheckpointState proto, returns
  it. Any relative entry in `model_checkpoint_path` or
  `all_model_checkpoint_paths` is rewritten by joining it onto
  `checkpoint_dir`; absolute entries are left unchanged.

  Args:
    checkpoint_dir: The directory of checkpoints.
    latest_filename: Optional name of the checkpoint file.  Default to
      'checkpoint'.

  Returns:
    A CheckpointState if the state was available, None otherwise. None is
    also returned when reading the file raises `errors.OpError` or its
    contents fail to parse (`text_format.ParseError`); both are logged.

  Raises:
    ValueError: if the checkpoint read doesn't have model_checkpoint_path set.
  """
  ckpt = None
  coord_checkpoint_filename = _GetCheckpointFilename(checkpoint_dir,
                                                     latest_filename)
  try:
    # Check that the file exists before opening it to avoid
    # many lines of errors from colossus in the logs.
    if file_io.file_exists(coord_checkpoint_filename):
      file_content = file_io.read_file_to_string(
          coord_checkpoint_filename)
      ckpt = CheckpointState()
      text_format.Merge(file_content, ckpt)
      if not ckpt.model_checkpoint_path:
        raise ValueError("Invalid checkpoint state loaded from "
                         + checkpoint_dir)
      # For relative model_checkpoint_path and all_model_checkpoint_paths,
      # prepend checkpoint_dir.
      if not os.path.isabs(ckpt.model_checkpoint_path):
        ckpt.model_checkpoint_path = os.path.join(checkpoint_dir,
                                                  ckpt.model_checkpoint_path)
      for i, p in enumerate(ckpt.all_model_checkpoint_paths):
        if not os.path.isabs(p):
          ckpt.all_model_checkpoint_paths[i] = os.path.join(checkpoint_dir, p)
  except (errors.OpError, text_format.ParseError) as e:
    # An unreadable or unparseable "checkpoint" file is not fatal: log it and
    # report "no state". (The ValueError above is not caught here.)
    logging.warning("%s: %s", type(e).__name__, e)
    logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename)
    return None
  return ckpt
예제 #3
0
    def _run():
        """Rewrites the checkpoint file at `ckpt_path` in place.

        `ckpt_path` comes from the enclosing scope. The file is parsed as a
        text-format CheckpointState, and a fresh state is written back whose
        only field is `model_checkpoint_path`, set to "./" plus the basename
        of the original path. All other fields of the original state (for
        example `all_model_checkpoint_paths`) are not carried over.
        """
        from google.protobuf import text_format
        from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState

        with open(ckpt_path, "rb") as src:
            raw_text = src.read()

        old_state = CheckpointState()
        text_format.Merge(raw_text, old_state)

        rewritten = CheckpointState()
        file_name = os.path.basename(old_state.model_checkpoint_path)
        rewritten.model_checkpoint_path = "./{}".format(file_name)

        serialized = text_format.MessageToBytes
        with open(ckpt_path, "wb") as dst:
            dst.write(serialized(rewritten))
예제 #4
0
def get_checkpoint_state(checkpoint_dir, latest_filename=None):
    """Returns CheckpointState proto from the "checkpoint" file.

    If the "checkpoint" file contains a valid CheckpointState proto, returns
    it. Relative entries in `model_checkpoint_path` and
    `all_model_checkpoint_paths` are joined onto `checkpoint_dir` only when
    `checkpoint_dir` is itself relative. With an absolute `checkpoint_dir`,
    entries are returned exactly as stored.

    Args:
      checkpoint_dir: The directory of checkpoints.
      latest_filename: Optional name of the checkpoint file.  Default to
        'checkpoint'.

    Returns:
      A CheckpointState if the state was available, None otherwise. None is
      also returned, with a logged warning, when the file cannot be read
      (IOError) or cannot be parsed (`text_format.ParseError`).
    """
    ckpt = None
    coord_checkpoint_filename = _GetCheckpointFilename(checkpoint_dir,
                                                       latest_filename)
    f = None
    try:
        # Check that the file exists before opening it to avoid
        # many lines of errors from colossus in the logs.
        if gfile.Exists(coord_checkpoint_filename):
            f = gfile.FastGFile(coord_checkpoint_filename, mode="r")
            ckpt = CheckpointState()
            text_format.Merge(f.read(), ckpt)
            # Relative entries are prefixed only when checkpoint_dir is
            # relative too. NOTE(review): presumably this mirrors how the
            # writer stores paths in this version -- confirm before changing.
            if not os.path.isabs(checkpoint_dir):
                if not os.path.isabs(ckpt.model_checkpoint_path):
                    ckpt.model_checkpoint_path = os.path.join(
                        checkpoint_dir, ckpt.model_checkpoint_path)
                for i, p in enumerate(ckpt.all_model_checkpoint_paths):
                    if not os.path.isabs(p):
                        ckpt.all_model_checkpoint_paths[i] = os.path.join(
                            checkpoint_dir, p)
    except (IOError, text_format.ParseError) as e:
        # It's ok if the file cannot be read or parsed, but say so rather
        # than failing silently.
        logging.warning("%s: %s", type(e).__name__, e)
        logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename)
        return None
    finally:
        # Close the handle whether or not reading/parsing succeeded.
        if f is not None:
            f.close()
    return ckpt
예제 #5
0
파일: saver.py 프로젝트: smrowe4/tensorflow
def get_checkpoint_state(checkpoint_dir, latest_filename=None):
  """Returns CheckpointState proto from the "checkpoint" file.

  If the "checkpoint" file contains a valid CheckpointState proto, returns
  it. Relative entries in `model_checkpoint_path` and
  `all_model_checkpoint_paths` are joined onto `checkpoint_dir` only when
  `checkpoint_dir` is itself relative. With an absolute `checkpoint_dir`,
  entries are returned exactly as stored.

  Args:
    checkpoint_dir: The directory of checkpoints.
    latest_filename: Optional name of the checkpoint file.  Default to
      'checkpoint'.

  Returns:
    A CheckpointState if the state was available, None otherwise. None is
    also returned, with a logged warning, when the file cannot be read
    (IOError), is not valid UTF-8, or cannot be parsed
    (`text_format.ParseError`).
  """
  ckpt = None
  coord_checkpoint_filename = _GetCheckpointFilename(
      checkpoint_dir, latest_filename)
  try:
    # Check that the file exists before opening it to avoid
    # many lines of errors from colossus in the logs.
    if pywrap_tensorflow.file_exists(coord_checkpoint_filename):
      file_content = pywrap_tensorflow.read_file_to_string(
          coord_checkpoint_filename).decode("utf-8")
      ckpt = CheckpointState()
      text_format.Merge(file_content, ckpt)
      # Relative entries are prefixed only when checkpoint_dir is relative
      # too. NOTE(review): presumably this mirrors how the writer stores
      # paths in this version -- confirm before changing.
      if not os.path.isabs(checkpoint_dir):
        if not os.path.isabs(ckpt.model_checkpoint_path):
          ckpt.model_checkpoint_path = os.path.join(checkpoint_dir,
                                                    ckpt.model_checkpoint_path)
        for i, p in enumerate(ckpt.all_model_checkpoint_paths):
          if not os.path.isabs(p):
            ckpt.all_model_checkpoint_paths[i] = os.path.join(checkpoint_dir, p)
  except (IOError, UnicodeDecodeError, text_format.ParseError) as e:
    # An unreadable, undecodable, or unparseable "checkpoint" file is not
    # fatal: log it and report "no state" instead of failing silently.
    logging.warning("%s: %s", type(e).__name__, e)
    logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename)
    return None
  return ckpt