示例#1
0
    def _convert_directory_checkpoint_to_sync_if_needed(
            self, checkpoint: Checkpoint) -> Checkpoint:
        """Swap a directory checkpoint for a lightweight "sync" dict checkpoint.

        The returned dict checkpoint records only the current node's IP
        address, the absolute on-node path of the directory and the Tune
        checkpoint ID. It is used to sync the directory later; using the
        directory checkpoint itself would cause it to be deep-copied and
        serialized unnecessarily.

        Args:
            checkpoint: Either a directory checkpoint that contains a
                ``TUNE_CHECKPOINT_ID`` file, or an already-converted sync
                checkpoint.

        Returns:
            ``checkpoint`` itself when it is already a sync checkpoint;
            otherwise a new dict checkpoint holding the node IP, the on-node
            path and the integer Tune checkpoint ID.

        Raises:
            ValueError: If there is no ``TUNE_CHECKPOINT_ID`` file and the
                checkpoint's dict form lacks the node-IP or on-node-path key.
        """
        with checkpoint.as_directory() as local_dir:
            local_dir = Path(local_dir).expanduser().absolute()
            id_file = local_dir.joinpath(TUNE_CHECKPOINT_ID)

            if id_file.exists():
                with open(id_file, "r") as handle:
                    checkpoint_id = int(handle.read())

                # NOTE(review): ``local_dir`` is captured inside
                # ``as_directory()``; confirm it still exists after the
                # context exits (e.g. if it was a temporary directory).
                return Checkpoint.from_dict({
                    NODE_IP_KEY: get_node_ip_address(),
                    CHECKPOINT_PATH_ON_NODE_KEY: str(local_dir),
                    TUNE_CHECKPOINT_ID: checkpoint_id,
                })

            # No ID file: assume this is already a sync checkpoint, but make
            # sure it carries both keys the sync logic relies on.
            as_dict = checkpoint.to_dict()
            required_keys = (NODE_IP_KEY, CHECKPOINT_PATH_ON_NODE_KEY)
            if not all(key in as_dict for key in required_keys):
                raise ValueError(
                    "Wrong checkpoint format. Ensure the checkpoint is a "
                    "result of `HuggingFaceTrainer`.")
            return checkpoint
示例#2
0
 def from_checkpoint(cls, checkpoint: Checkpoint,
                     **kwargs) -> "DummyPredictor":
     """Create a predictor from ``checkpoint``.

     The checkpoint is opened as a directory and the call sleeps for one
     second to simulate reading it. The predictor is then built from the
     checkpoint's dict contents, passed as keyword arguments.

     Args:
         checkpoint: Checkpoint whose ``to_dict()`` entries become the
             constructor's keyword arguments.
         **kwargs: Accepted for interface compatibility; unused.

     Returns:
         An instance of ``cls``.
     """
     with checkpoint.as_directory():
         # Simulate the latency of reading a checkpoint from storage.
         time.sleep(1)
     checkpoint_data = checkpoint.to_dict()
     # Use ``cls`` rather than a hard-coded class name so that subclasses
     # calling this alternate constructor get an instance of their own type.
     return cls(**checkpoint_data)
def _load_checkpoint(
    checkpoint: Checkpoint, trainer_name: str
) -> Tuple[Any, Optional["Preprocessor"]]:
    """Extract the model (or weights) and preprocessor from a Ray Train Checkpoint.

    This is a private API.

    Args:
        checkpoint: Checkpoint whose dict form holds the model under
            ``MODEL_KEY`` and, optionally, a preprocessor under
            ``PREPROCESSOR_KEY``.
        trainer_name: Trainer class name, quoted in the error message.

    Returns:
        A ``(model, preprocessor)`` tuple. ``preprocessor`` is ``None`` when
        the checkpoint has no ``PREPROCESSOR_KEY`` entry.

    Raises:
        RuntimeError: If the checkpoint has no ``MODEL_KEY`` entry.
    """
    contents = checkpoint.to_dict()
    if MODEL_KEY not in contents:
        raise RuntimeError(
            f"No item with key: {MODEL_KEY} is found in the "
            f"Checkpoint. Make sure this key exists when saving the "
            f"checkpoint in ``{trainer_name}``."
        )
    return contents[MODEL_KEY], contents.get(PREPROCESSOR_KEY)
示例#4
0
 def from_checkpoint(cls, checkpoint: Checkpoint,
                     **kwargs) -> "DummyPredictor":
     """Create a predictor from the checkpoint's ``"factor"`` and preprocessor.

     Opening the checkpoint as a directory and sleeping for one second
     mimics the cost of reading a real checkpoint; the directory path itself
     is not used.

     Args:
         checkpoint: Checkpoint whose dict form contains ``"factor"``.
         **kwargs: Accepted for interface compatibility; unused.

     Returns:
         ``cls(factor, preprocessor=...)``.
     """
     with checkpoint.as_directory():
         time.sleep(1)  # stand-in for real checkpoint I/O
     state = checkpoint.to_dict()
     preproc = checkpoint.get_preprocessor()
     return cls(state["factor"], preprocessor=preproc)
示例#5
0
 def from_checkpoint(cls,
                     checkpoint: Checkpoint,
                     use_gpu: bool = False,
                     **kwargs) -> "DummyPredictor":
     """Create a predictor from the checkpoint's ``"factor"`` and preprocessor.

     Args:
         checkpoint: Checkpoint whose dict form contains ``"factor"``.
         use_gpu: Forwarded unchanged to the constructor.
         **kwargs: Accepted for interface compatibility; unused.

     Returns:
         ``cls(factor, preprocessor=..., use_gpu=use_gpu)``.
     """
     payload = checkpoint.to_dict()
     preproc = checkpoint.get_preprocessor()
     factor = payload["factor"]
     return cls(factor, preprocessor=preproc, use_gpu=use_gpu)
示例#6
0
 def from_checkpoint(cls,
                     checkpoint: Checkpoint,
                     do_double: bool = False) -> "AdderPredictor":
     """Create a predictor from the checkpoint's ``"increment"`` entry.

     Args:
         checkpoint: Checkpoint whose dict form contains ``"increment"``.
         do_double: Forwarded positionally to the constructor.

     Returns:
         ``cls(increment, do_double)``.
     """
     increment = checkpoint.to_dict()["increment"]
     return cls(increment, do_double)
示例#7
0
 def from_checkpoint(
         cls,
         checkpoint: Checkpoint) -> "TakeArrayReturnDataFramePredictor":
     """Create a predictor from the checkpoint's ``"increment"`` entry.

     Args:
         checkpoint: Checkpoint whose dict form contains ``"increment"``.

     Returns:
         ``cls(increment)``.
     """
     stored = checkpoint.to_dict()
     increment = stored["increment"]
     return cls(increment)
示例#8
0
 def from_checkpoint(cls, checkpoint: Checkpoint, **kwargs) -> "DummyPredictor":
     """Create a predictor whose constructor kwargs are the checkpoint's dict.

     Args:
         checkpoint: Checkpoint whose ``to_dict()`` entries become the
             constructor's keyword arguments.
         **kwargs: Accepted for interface compatibility; unused.

     Returns:
         An instance of ``cls``.
     """
     checkpoint_data = checkpoint.to_dict()
     # Use ``cls`` rather than a hard-coded class name so that subclasses
     # calling this alternate constructor get an instance of their own type.
     return cls(**checkpoint_data)
示例#9
0
 def from_checkpoint(cls, checkpoint: Checkpoint,
                     **kwargs) -> "DummyPredictor":
     """Create a predictor from the checkpoint's ``"factor"`` and preprocessor.

     Args:
         checkpoint: Checkpoint whose dict form contains ``"factor"``.
         **kwargs: Accepted for interface compatibility; unused.

     Returns:
         ``cls(factor, preprocessor)``, both passed positionally.
     """
     state = checkpoint.to_dict()
     preproc = checkpoint.get_preprocessor()
     factor = state["factor"]
     return cls(factor, preproc)
示例#10
0
 def from_checkpoint(cls, checkpoint: Checkpoint) -> "AdderPredictor":
     """Create a predictor from the checkpoint's ``"increment"`` entry.

     Args:
         checkpoint: Checkpoint whose dict form contains ``"increment"``.

     Returns:
         ``cls(increment)``.
     """
     increment = checkpoint.to_dict()["increment"]
     return cls(increment)