def from_checkpoint(
    cls,
    checkpoint: Checkpoint,
    model_definition: Union[Callable[[], tf.keras.Model], Type[tf.keras.Model]],
) -> "TensorflowPredictor":
    """Instantiate the predictor from a Checkpoint.

    The checkpoint is expected to be a result of ``TensorflowTrainer``.

    Args:
        checkpoint: The checkpoint to load the model and preprocessor from.
            It is expected to be from the result of a ``TensorflowTrainer``
            run.
        model_definition: A callable that returns a TensorFlow Keras model
            to use. Model weights will be loaded from the checkpoint.

    Raises:
        RuntimeError: If the checkpoint dict does not contain ``MODEL_KEY``.
    """
    checkpoint_dict = checkpoint.to_dict()
    preprocessor = checkpoint_dict.get(PREPROCESSOR_KEY, None)
    if MODEL_KEY not in checkpoint_dict:
        raise RuntimeError(
            f"No item with key: {MODEL_KEY} is found in the "
            f"Checkpoint. Make sure this key exists when saving the "
            f"checkpoint in ``TensorflowTrainer``.")
    model_weights = checkpoint_dict[MODEL_KEY]
    # Use ``cls`` instead of the hard-coded class name so that subclasses
    # of TensorflowPredictor are instantiated as the subclass, not the base.
    return cls(
        model_definition=model_definition,
        model_weights=model_weights,
        preprocessor=preprocessor,
    )
def _convert_directory_checkpoint_to_sync_if_needed(
        self, checkpoint: Checkpoint) -> Checkpoint:
    """Replace a directory checkpoint with a node-ip & path dict checkpoint.

    The dict checkpoint is what gets used to sync the directory. Passing a
    directory checkpoint around directly would cause it to be deep-copied
    and serialized unnecessarily, so only the node ip and on-node path are
    carried.
    """
    with checkpoint.as_directory() as raw_path:
        resolved_path = Path(raw_path).expanduser().absolute()
        id_file = resolved_path.joinpath(TUNE_CHECKPOINT_ID)

        if not id_file.exists():
            # No ID file present: assume this checkpoint is already in the
            # sync (dict) format and validate it has the expected keys.
            as_dict = checkpoint.to_dict()
            is_sync_format = (
                NODE_IP_KEY in as_dict
                and CHECKPOINT_PATH_ON_NODE_KEY in as_dict)
            if not is_sync_format:
                raise ValueError(
                    "Wrong checkpoint format. Ensure the checkpoint is a "
                    "result of `HuggingFaceTrainer`.")
            return checkpoint

        with open(id_file, "r") as f:
            tune_checkpoint_id = int(f.read())

        return Checkpoint.from_dict({
            NODE_IP_KEY: get_node_ip_address(),
            CHECKPOINT_PATH_ON_NODE_KEY: str(resolved_path),
            TUNE_CHECKPOINT_ID: tune_checkpoint_id,
        })
def from_checkpoint(
    cls, checkpoint: Checkpoint, model: Optional[torch.nn.Module] = None
) -> "TorchPredictor":
    """Instantiate the predictor from a Checkpoint.

    The checkpoint is expected to be a result of ``TorchTrainer``.

    Args:
        checkpoint: The checkpoint to load the model and preprocessor from.
            It is expected to be from the result of a ``TorchTrainer`` run.
        model: If the checkpoint contains a model state dict, and not
            the model itself, then the state dict will be loaded to this
            ``model``.

    Raises:
        RuntimeError: If the checkpoint dict does not contain ``MODEL_KEY``.
    """
    checkpoint_dict = checkpoint.to_dict()
    preprocessor = checkpoint_dict.get(PREPROCESSOR_KEY, None)
    if MODEL_KEY not in checkpoint_dict:
        raise RuntimeError(
            f"No item with key: {MODEL_KEY} is found in the "
            f"Checkpoint. Make sure this key exists when saving the "
            f"checkpoint in ``TorchTrainer``."
        )
    # Keep the ``model`` argument intact instead of shadowing it; it is
    # still needed by ``load_torch_model`` as the state-dict target.
    loaded_model = load_torch_model(
        saved_model=checkpoint_dict[MODEL_KEY], model_definition=model
    )
    # Use ``cls`` instead of the hard-coded class name so that subclasses
    # of TorchPredictor are instantiated as the subclass, not the base.
    return cls(model=loaded_model, preprocessor=preprocessor)
def _load_checkpoint(
    checkpoint: Checkpoint, trainer_name: str
) -> Tuple[Any, Optional[Preprocessor]]:
    """Load a Ray Train Checkpoint.

    This is a private API.

    Args:
        checkpoint: The checkpoint to load the weights and preprocessor from.
        trainer_name: Trainer class name to use in error message.

    Returns:
        The model or weights and AIR preprocessor contained within.
    """
    data = checkpoint.to_dict()
    if MODEL_KEY not in data:
        raise RuntimeError(
            f"No item with key: {MODEL_KEY} is found in the "
            f"Checkpoint. Make sure this key exists when saving the "
            f"checkpoint in ``{trainer_name}``."
        )
    return data[MODEL_KEY], data.get(PREPROCESSOR_KEY, None)
def from_checkpoint(cls, checkpoint: Checkpoint, **kwargs) -> "DummyPredictor":
    """Instantiate the predictor from the checkpoint's dict contents.

    Args:
        checkpoint: Checkpoint whose dict form supplies the constructor
            keyword arguments.
        **kwargs: Accepted for interface compatibility; not used here.
    """
    checkpoint_data = checkpoint.to_dict()
    # Use ``cls`` instead of the hard-coded class name so that subclasses
    # of DummyPredictor are instantiated as the subclass, not the base.
    return cls(**checkpoint_data)