Пример #1
0
    def from_checkpoint(
        cls,
        checkpoint: Checkpoint,
        model_definition: Union[Callable[[], tf.keras.Model],
                                Type[tf.keras.Model]],
    ) -> "TensorflowPredictor":
        """Instantiate the predictor from a Checkpoint.

        The checkpoint is expected to be a result of ``TensorflowTrainer``.

        Args:
            checkpoint: The checkpoint to load the model weights and
                preprocessor from. It is expected to be from the result of a
                ``TensorflowTrainer`` run.
            model_definition: A zero-argument callable returning a TensorFlow
                Keras model, or a ``tf.keras.Model`` subclass. Model weights
                will be loaded from the checkpoint.

        Returns:
            An instance of ``cls`` constructed from ``model_definition``, the
            checkpoint's model weights and its preprocessor (``None`` when the
            checkpoint does not contain one).

        Raises:
            RuntimeError: If the checkpoint has no ``MODEL_KEY`` entry.
        """
        checkpoint_dict = checkpoint.to_dict()
        # The preprocessor is optional; ``.get`` yields ``None`` if absent.
        preprocessor = checkpoint_dict.get(PREPROCESSOR_KEY)
        if MODEL_KEY not in checkpoint_dict:
            raise RuntimeError(
                f"No item with key: {MODEL_KEY} is found in the "
                f"Checkpoint. Make sure this key exists when saving the "
                f"checkpoint in ``TensorflowTrainer``.")
        model_weights = checkpoint_dict[MODEL_KEY]
        # Construct via ``cls`` (not the hard-coded class name) so that a
        # subclass calling this alternate constructor gets its own type back.
        return cls(
            model_definition=model_definition,
            model_weights=model_weights,
            preprocessor=preprocessor,
        )
Пример #2
0
    def _convert_directory_checkpoint_to_sync_if_needed(
        self, checkpoint: Checkpoint
    ) -> Checkpoint:
        """Replace a directory checkpoint with a node-IP-and-path dict checkpoint.

        The returned dict checkpoint records where the checkpoint directory
        lives (the node IP from ``get_node_ip_address()`` and the absolute
        path on that node) so the directory can be synced later. Using the
        directory checkpoint directly would cause it to be deepcopied and
        serialized unnecessarily.

        If the materialized directory contains no ``TUNE_CHECKPOINT_ID``
        file, the checkpoint is assumed to already be such a sync checkpoint.
        In that case it is validated and returned unchanged.

        Args:
            checkpoint: The checkpoint to inspect and possibly convert.

        Returns:
            Either a new dict checkpoint holding ``NODE_IP_KEY``,
            ``CHECKPOINT_PATH_ON_NODE_KEY`` and ``TUNE_CHECKPOINT_ID`` entries,
            or ``checkpoint`` itself.

        Raises:
            ValueError: If the ID file is missing and the checkpoint's dict
                form lacks ``NODE_IP_KEY`` or ``CHECKPOINT_PATH_ON_NODE_KEY``.
        """
        with checkpoint.as_directory() as local_dir:
            resolved_dir = Path(local_dir).expanduser().absolute()
            id_file = resolved_dir.joinpath(TUNE_CHECKPOINT_ID)

            if id_file.exists():
                # A real directory checkpoint: read its id and describe it by
                # location instead of by contents.
                with open(id_file, "r") as id_handle:
                    checkpoint_id = int(id_handle.read())
                return Checkpoint.from_dict(
                    {
                        NODE_IP_KEY: get_node_ip_address(),
                        CHECKPOINT_PATH_ON_NODE_KEY: str(resolved_dir),
                        TUNE_CHECKPOINT_ID: checkpoint_id,
                    }
                )

            # No ID file: this should already be a sync checkpoint, so make
            # sure it carries both location keys before passing it through.
            as_dict = checkpoint.to_dict()
            required_keys = (NODE_IP_KEY, CHECKPOINT_PATH_ON_NODE_KEY)
            if not all(key in as_dict for key in required_keys):
                raise ValueError(
                    "Wrong checkpoint format. Ensure the checkpoint is a "
                    "result of `HuggingFaceTrainer`.")
            return checkpoint
Пример #3
0
    def from_checkpoint(
        cls, checkpoint: Checkpoint, model: Optional[torch.nn.Module] = None
    ) -> "TorchPredictor":
        """Instantiate the predictor from a Checkpoint.

        The checkpoint is expected to be a result of ``TorchTrainer``.

        Args:
            checkpoint: The checkpoint to load the model and
                preprocessor from. It is expected to be from the result of a
                ``TorchTrainer`` run.
            model: If the checkpoint contains a model state dict, and not
                the model itself, then the state dict will be loaded to this
                ``model``.

        Returns:
            An instance of ``cls`` wrapping the model returned by
            ``load_torch_model`` and the checkpoint's preprocessor (``None``
            when the checkpoint does not contain one).

        Raises:
            RuntimeError: If the checkpoint has no ``MODEL_KEY`` entry.
        """
        checkpoint_dict = checkpoint.to_dict()
        # The preprocessor is optional; ``.get`` yields ``None`` if absent.
        preprocessor = checkpoint_dict.get(PREPROCESSOR_KEY)
        if MODEL_KEY not in checkpoint_dict:
            raise RuntimeError(
                f"No item with key: {MODEL_KEY} is found in the "
                f"Checkpoint. Make sure this key exists when saving the "
                f"checkpoint in ``TorchTrainer``."
            )
        # Bind to a new name instead of rebinding the ``model`` parameter, so
        # the caller-supplied definition and the loaded model stay distinct.
        loaded_model = load_torch_model(
            saved_model=checkpoint_dict[MODEL_KEY], model_definition=model
        )
        # Construct via ``cls`` so subclasses get an instance of themselves.
        return cls(model=loaded_model, preprocessor=preprocessor)
Пример #4
0
def _load_checkpoint(
    checkpoint: Checkpoint, trainer_name: str
) -> Tuple[Any, Optional[Preprocessor]]:
    """Pull the model payload and optional preprocessor out of a checkpoint.

    This is a private API.

    Args:
        checkpoint: Ray Train checkpoint to read; it is converted to a dict
            and its ``MODEL_KEY`` / ``PREPROCESSOR_KEY`` entries are returned.
        trainer_name: Trainer class name quoted in the error message when
            the model entry is missing.

    Returns:
        A ``(model, preprocessor)`` pair. ``model`` is whatever was stored
        under ``MODEL_KEY`` (a model or its weights). ``preprocessor`` is
        ``None`` when no preprocessor was saved.

    Raises:
        RuntimeError: If the checkpoint has no ``MODEL_KEY`` entry.
    """
    checkpoint_contents = checkpoint.to_dict()
    preprocessor = checkpoint_contents.get(PREPROCESSOR_KEY, None)

    if MODEL_KEY in checkpoint_contents:
        return checkpoint_contents[MODEL_KEY], preprocessor

    raise RuntimeError(
        f"No item with key: {MODEL_KEY} is found in the "
        f"Checkpoint. Make sure this key exists when saving the "
        f"checkpoint in ``{trainer_name}``."
    )
Пример #5
0
 def from_checkpoint(cls, checkpoint: Checkpoint, **kwargs) -> "DummyPredictor":
     """Build a predictor from the checkpoint's dict contents.

     Args:
         checkpoint: Checkpoint whose ``to_dict()`` entries are forwarded,
             key by key, as keyword arguments to the constructor.
         **kwargs: Unused here (presumably kept to match the base
             ``from_checkpoint`` signature — confirm).

     Returns:
         An instance of ``cls`` built from the checkpoint data.
     """
     checkpoint_data = checkpoint.to_dict()
     # Construct via ``cls`` rather than the hard-coded class name so that
     # subclasses using this alternate constructor get their own type back.
     return cls(**checkpoint_data)