Code example #1
    @classmethod
    def _load_checkpoint(cls,
                         path,
                         data_silo,
                         model,
                         optimizer,
                         local_rank=-1):
        """
        Load the train checkpoint at the given path.

        :param path: Path to the checkpoint directory, a subdirectory under checkpoint_root_dir. The individual
               checkpoint dirs follow the default naming convention "epoch_{epoch_num}_step_{step_num}".
        :type path: Path
        :param data_silo: A DataSilo object that will contain the train, dev and test datasets as PyTorch DataLoaders
        :type data_silo: DataSilo
        """
        if not path.exists():
            raise Exception(f"The checkpoint path {path} does not exists.")

        # In distributed mode, we save the model only once from process 0 (using cuda:0)
        # At loading time, we need to load the model to the current cuda device (instead of back to cuda:0)
        # Note: This assumes exactly one GPU per process (as recommended by PyTorch)
        if local_rank == -1:
            map_location = None
        else:
            device = torch.device(f"cuda:{local_rank}")
            map_location = {'cuda:0': f'cuda:{local_rank}'}

        trainer_checkpoint = torch.load(path / "trainer",
                                        map_location=map_location)
        trainer_state_dict = trainer_checkpoint["trainer_state"]
        if local_rank != -1:
            trainer_state_dict["device"] = device
            trainer_state_dict["local_rank"] = local_rank

        # Just setting seeds is not sufficient to have deterministic results when resuming
        # training from a checkpoint. Additionally, the previous states of Random Number
        # Generators also need to be restored from the saved checkpoint.
        numpy_rng_state = trainer_checkpoint["numpy_rng_state"]
        numpy.random.set_state(numpy_rng_state)
        rng_state = trainer_checkpoint["rng_state"]
        cuda_rng_state = trainer_checkpoint["cuda_rng_state"]
        torch.set_rng_state(rng_state)
        torch.cuda.set_rng_state(cuda_rng_state)

        model.load_state_dict(trainer_checkpoint["model_state"], strict=True)
        optimizer.load_state_dict(trainer_checkpoint["optimizer_state"])

        scheduler_state_dict = trainer_checkpoint["scheduler_state"]
        scheduler_opts = trainer_checkpoint["scheduler_opts"]
        scheduler = get_scheduler(optimizer, scheduler_opts)
        scheduler.load_state_dict(scheduler_state_dict)

        trainer = Trainer(data_silo=data_silo,
                          model=model,
                          optimizer=optimizer,
                          lr_schedule=scheduler,
                          **trainer_state_dict)

        logger.info(f"Loaded a train checkpoint from {path}")
        return trainer
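
For reference, here is a minimal save-side sketch. It is an assumption, not FARM's actual _save_checkpoint: the function name, its parameters, and the trainer_state_dict argument are hypothetical, but it writes exactly the keys that example #1 reads back from path / "trainer".

    import numpy
    import torch

    def _save_checkpoint_sketch(checkpoint_dir, trainer_state_dict, model,
                                optimizer, scheduler, scheduler_opts):
        # Hypothetical sketch: persist every key the loader above expects,
        # including the RNG states needed for a deterministic resume.
        torch.save({
            "trainer_state": trainer_state_dict,
            "model_state": model.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "scheduler_state": scheduler.state_dict(),
            "scheduler_opts": scheduler_opts,
            "numpy_rng_state": numpy.random.get_state(),
            "rng_state": torch.get_rng_state(),
            "cuda_rng_state": torch.cuda.get_rng_state(),
        }, checkpoint_dir / "trainer")
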
Code example #2
    @classmethod
    def _load_checkpoint(cls, path, data_silo):
        """
        Load the train checkpoint at the given path.

        :param path: Path to the checkpoint directory, a subdirectory under checkpoint_root_dir. The individual
               checkpoint dirs follow the default naming convention "epoch_{epoch_num}_step_{step_num}".
        :type path: Path
        :param data_silo: A DataSilo object that will contain the train, dev and test datasets as PyTorch DataLoaders
        :type data_silo: DataSilo
        """
        if not path.exists():
            raise Exception(f"The checkpoint path {path} does not exists.")

        trainer_checkpoint = torch.load(path / "trainer")
        trainer_state_dict = trainer_checkpoint["trainer_state_dict"]

        # Just setting seeds is not sufficient to have deterministic results when resuming
        # training from a checkpoint. Additionally, the previous states of Random Number
        # Generators also need to be restored from the saved checkpoint.
        numpy_rng_state = trainer_checkpoint["numpy_rng_state"]
        numpy.random.set_state(numpy_rng_state)
        rng_state = trainer_checkpoint["rng_state"]
        cuda_rng_state = trainer_checkpoint["cuda_rng_state"]
        torch.set_rng_state(rng_state)
        torch.cuda.set_rng_state(cuda_rng_state)

        model = trainer_checkpoint["model"]

        optimizer = trainer_checkpoint["optimizer"]

        scheduler_state_dict = trainer_checkpoint["scheduler_state"]
        scheduler_opts = trainer_checkpoint["scheduler_opts"]
        scheduler_opts["last_epoch"] = scheduler_state_dict["last_epoch"]
        scheduler = get_scheduler(optimizer, scheduler_opts)
        scheduler.load_state_dict(scheduler_state_dict)

        trainer = Trainer(
            data_silo=data_silo,
            model=model,
            optimizer=optimizer,
            lr_schedule=scheduler,
            **trainer_state_dict
        )

        logger.info(f"Loaded a train checkpoint from {path}")
        return trainer
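
A hypothetical usage sketch, assuming _load_checkpoint is exposed as a classmethod on Trainer (it takes cls) and that the checkpoint directory follows the naming convention from the docstring; the concrete path, the data_silo variable, and the final train() call are placeholders rather than lines from the source.

    from pathlib import Path

    # Resume training from a previously saved checkpoint directory
    checkpoint_path = Path("checkpoints/epoch_1_step_500")
    trainer = Trainer._load_checkpoint(checkpoint_path, data_silo)
    trainer.train()  # continues with the restored model, optimizer, scheduler and RNG states
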