Example #1
 def snapshot(self,
              net: torch.nn.Module,
              opt: Optimizer,
              sched: _LRScheduler = None,
              epoch: int = None,
              subdir='.'):
     """
     Writes a snapshot of the training state, i.e. the network weights, optimizer state and scheduler state,
     to a file in the log directory.
     :param net: the neural network
     :param opt: the optimizer used
     :param sched: the learning rate scheduler used
     :param epoch: the current epoch
     :param subdir: if given, creates a subdirectory in the log directory. The data is written to a file
         in this subdirectory instead.
     :return: the path of the written snapshot file
     """
     outfile = pt.join(self.dir, subdir, 'snapshot.pt')
     if not pt.exists(os.path.dirname(outfile)):
         os.makedirs(os.path.dirname(outfile))
     torch.save(
         {
             'net': net.state_dict(),
             'opt': opt.state_dict(),
             'sched': sched.state_dict() if sched is not None else None,
             'epoch': epoch
         }, outfile)
     return outfile
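A possible counterpart for restoring such a snapshot. This is only a sketch: it assumes the same pt (os.path) alias and log-directory attribute (self.dir) used above, and it tolerates a missing scheduler, since sched may be None at save time.

 def load_snapshot(self, net, opt, sched=None, subdir='.'):
     """Sketch of the inverse of snapshot(): restores the states saved above."""
     infile = pt.join(self.dir, subdir, 'snapshot.pt')
     state = torch.load(infile, map_location='cpu')
     net.load_state_dict(state['net'])
     opt.load_state_dict(state['opt'])
     if sched is not None and state['sched'] is not None:
         sched.load_state_dict(state['sched'])
     # Return the epoch stored at save time so training can resume from it.
     return state['epoch']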
Example #2
def save_ckpt(model: nn.Module, optim: optimizer.Optimizer, epoch_id: int, best_score: float, model_save_path: str):
    torch.save({
        _MODEL_STATE_DICT: model.state_dict(),
        _OPTIMIZER_STATE_DICT: optim.state_dict(),
        _EPOCH: epoch_id,
        _BEST_SCORE: best_score
    }, model_save_path)
Example #3
def log_checkpoints(
    checkpoint_dir: Path,
    model: Union[nn.Module, nn.DataParallel],
    optimizer: Optimizer,
    scheduler: optim.lr_scheduler._LRScheduler,
    epoch: int,
) -> None:
    """
    Serialize a PyTorch model in the `checkpoint_dir`.

    Args:
        checkpoint_dir: the directory to store checkpoints
        model: the model to serialize
        optimizer: the optimizer to be saved
        scheduler: the LR scheduler to be saved
        epoch: the epoch number
    """
    checkpoint_file = 'checkpoint_{}.pt'.format(epoch)
    checkpoint_dir.mkdir(exist_ok=True, parents=True)
    file_path = checkpoint_dir / checkpoint_file

    if isinstance(model, nn.DataParallel):
        model_state_dict = model.module.state_dict()
    else:
        model_state_dict = model.state_dict()

    torch.save(  # type: ignore
        {
            'epoch': epoch,
            'model_state_dict': model_state_dict,
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
        },
        file_path,
    )
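Because log_checkpoints stores model.module.state_dict() when the model is wrapped in nn.DataParallel, the saved weights carry no "module." prefix. A minimal restore sketch under that assumption (the function name and the epoch argument are placeholders):

def load_checkpoint(checkpoint_dir: Path, model, optimizer, scheduler, epoch: int):
    # Sketch only: read back a file written by log_checkpoints above.
    checkpoint = torch.load(checkpoint_dir / 'checkpoint_{}.pt'.format(epoch),
                            map_location='cpu')
    # Load into the unwrapped module if the live model is a DataParallel wrapper.
    target = model.module if isinstance(model, nn.DataParallel) else model
    target.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    return checkpoint['epoch']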
Example #4
def save_checkpoint(model: torch.nn.Module,
                    optimizer: Optimizer,
                    epoch: int,
                    args: ModelConfigBase,
                    mean_teacher_model: bool = False) -> None:
    """
    Saves a checkpoint of the current model and optimizer_type parameters in the specified folder
    and uploads it to the output blob storage of the current run context.
    The checkpoint's name for epoch 123 would be 123_checkpoint.pth.tar.

    :param model: The model to save; it may be wrapped in DataParallel.
    :param optimizer: The optimizer_type used for training.
    :param epoch: The last epoch used to train the model.
    :param args: The model configuration; used to determine the checkpoint file path.
    :param mean_teacher_model: If True save to the mean teacher model checkpoint path.
    """
    logging.getLogger().disabled = True

    model_state_dict = model.module.state_dict() if isinstance(
        model, torch.nn.DataParallel) else model.state_dict()
    checkpoint_file_path = args.get_path_to_checkpoint(epoch,
                                                       mean_teacher_model)
    info_to_store = {
        'epoch': epoch,
        'state_dict': model_state_dict,
        'opt_dict': optimizer.state_dict()
    }
    torch.save(info_to_store, checkpoint_file_path)
    logging.getLogger().disabled = False
    logging.info("Saved model checkpoint for epoch {} to {}".format(
        epoch, checkpoint_file_path))
Example #5
def save_checkpoint(
    model: nn.Module = None,
    optimizer: optim.Optimizer = None,
    scheduler: sche._LRScheduler = None,
    amp=None,
    exp_name: str = "",
    current_epoch: int = 1,
    full_net_path: str = "",
    state_net_path: str = "",
):
    """
    Save both the full checkpoint (large) and the model state dict alone (small).

    Args:
        model (nn.Module): model object
        optimizer (optim.Optimizer): optimizer object
        scheduler (sche._LRScheduler): scheduler object
        amp: the apex.amp module, or None if mixed precision is not used
        exp_name (str): the experiment name (stored under the "arch" key)
        current_epoch (int): the epoch in which the model **will** be trained next
        full_net_path (str): the path for saving the full model parameters
        state_net_path (str): the path for saving the state dict.
    """

    state_dict = {
        "arch": exp_name,
        "epoch": current_epoch,
        "net_state": model.state_dict(),
        "opti_state": optimizer.state_dict(),
        "sche_state": scheduler.state_dict(),
        "amp_state": amp.state_dict() if amp else None,
    }
    torch.save(state_dict, full_net_path)
    torch.save(model.state_dict(), state_net_path)
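The two files written above are read back differently: the full checkpoint is a dictionary of states, while the second file is a bare model state dict. A hedged reload sketch, reusing the names from the function above:

# Sketch: reload a "full" checkpoint written by save_checkpoint.
full = torch.load(full_net_path, map_location="cpu")
model.load_state_dict(full["net_state"])
optimizer.load_state_dict(full["opti_state"])
scheduler.load_state_dict(full["sche_state"])
start_epoch = full["epoch"]

# The state-only file holds just the model weights.
model.load_state_dict(torch.load(state_net_path, map_location="cpu"))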
Example #6
    def __init__(
            self,
            network: torch.nn.Module,
            criterion: torch.nn.modules.loss,
            optimizer: Optimizer,
            data_type: torch.dtype = torch.float32,
            batch_size: int = 256,
            shuffle_training_examples: bool = False,
            scheduler: Optional[lr._LRScheduler] = None,
    ):
        self.network = network
        if network:
            self.network_initial_state_dict = network.state_dict()
        self.criterion = criterion
        self.optimizer = optimizer
        if optimizer:
            self.optimizer_initial_state_dict = optimizer.state_dict()
        self.data_type = data_type
        self.batch_size = batch_size
        self.shuffle_training_examples = shuffle_training_examples
        if scheduler:
            self.scheduler = scheduler
        else:
            self.scheduler = lr.StepLR(self.optimizer, step_size=1, gamma=0.99)
        self.scheduler_initial_state_dict = self.scheduler.state_dict()

        # placeholders until functions which assign them are called
        self.network_output = None
        self.training_average_loss = None
        self.validation_average_loss = None
        self.epochs_trained = 0
        self.maximum_epochs_allowed = None
        self.training_dataloader = None
        self.validation_dataloader = None
        self.learned_params = None
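One caveat about the initial state dicts stored above: state_dict() returns references to the live parameter and buffer tensors, so these "initial" dictionaries will silently track the values as training updates them in place. If the intent is to restore the pre-training values later, a deep copy is needed; a minimal sketch of that variant, assuming the same __init__ context:

import copy

# Inside __init__: deep-copied snapshots keep the initial values even after
# training mutates the parameters and optimizer state in place.
self.network_initial_state_dict = copy.deepcopy(network.state_dict())
self.optimizer_initial_state_dict = copy.deepcopy(optimizer.state_dict())
self.scheduler_initial_state_dict = copy.deepcopy(self.scheduler.state_dict())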
Example #7
def lr_find(model: UNet,
            data_loader,
            optimizer: Optimizer,
            criterion,
            use_gpu,
            min_lr=0.0001,
            max_lr=0.1):
    # Save deep copies of the model and optimizer states to revert to afterwards;
    # state_dict() returns references to the live tensors, so without deepcopy
    # (standard-library copy module) the "saved" state would track the training.
    model_state = copy.deepcopy(model.state_dict())
    optimizer_state = copy.deepcopy(optimizer.state_dict())

    losses = []
    lrs = []
    scheduler = CyclicExpLR(optimizer,
                            min_lr,
                            max_lr,
                            step_size_up=100,
                            mode='triangular',
                            cycle_momentum=True)
    model.train()
    for i, (data, target, class_ids) in enumerate(data_loader):

        if use_gpu:
            data = data.cuda()
            target = target.cuda()

        optimizer.zero_grad()
        output_raw = model(data)
        # This step is specific for this project
        output = torch.zeros(output_raw.shape[0], 1, output_raw.shape[2],
                             output_raw.shape[3])

        if use_gpu:
            output = output.cuda()

        # This step is specific for this project
        for idx, (raw_o, class_id) in enumerate(zip(output_raw, class_ids)):
            output[idx] = raw_o[class_id - 1]

        loss = criterion(output, target)
        loss.backward()
        current_lr = optimizer.param_groups[0]['lr']
        # Stop if lr stopped increasing
        if len(lrs) > 0 and current_lr < lrs[-1]:
            break
        lrs.append(current_lr)
        losses.append(loss.item())
        optimizer.step()
        scheduler.step()

    # Plot in log scale
    plt.plot(lrs, losses)
    plt.xscale('log')

    plt.show()

    model.load_state_dict(model_state)
    optimizer.load_state_dict(optimizer_state)
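A purely illustrative call of lr_find. The UNet constructor arguments, dataset and data loader below are placeholders for whatever the surrounding project defines; the optimizer needs a momentum term because the scheduler is created with cycle_momentum=True.

model = UNet(in_channels=3, n_classes=4)        # placeholder constructor arguments
loader = torch.utils.data.DataLoader(train_dataset, batch_size=8, shuffle=True)
optimizer = torch.optim.SGD(model.parameters(), lr=0.0001, momentum=0.9)
criterion = torch.nn.BCEWithLogitsLoss()

lr_find(model, loader, optimizer, criterion, use_gpu=torch.cuda.is_available())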
Example #8
def _get_optimizer_kwargs(optimizer: Optimizer) -> Mapping[str, Any]:
    optimizer_kwargs = optimizer.state_dict()
    optimizer_kwargs = {
        key: value
        for key, value in optimizer_kwargs['param_groups'][0].items()
        if key != 'params'
    }
    return optimizer_kwargs
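A hedged usage sketch: the helper strips the 'params' entry from the first parameter group, leaving hyperparameters such as lr, betas and weight_decay that can be passed back to the optimizer's constructor. Newer PyTorch releases add implementation flags (e.g. maximize, foreach) to the group, but Adam's constructor accepts those as keyword arguments as well.

import torch
from torch import nn

src_model = nn.Linear(4, 2)
src_optimizer = torch.optim.Adam(src_model.parameters(), lr=3e-4, weight_decay=1e-5)

# e.g. {'lr': 0.0003, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 1e-05, ...}
kwargs = _get_optimizer_kwargs(src_optimizer)

# Rebuild an equivalent optimizer for a second (placeholder) model.
new_model = nn.Linear(4, 2)
new_optimizer = torch.optim.Adam(new_model.parameters(), **kwargs)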
Example #9
def save_checkpoint(model: nn.Module, optim: optimizer.Optimizer,
                    epoch_id: int, step: int, best_score: float):
    torch.save(
        {
            _MODEL_STATE_DICT: model.state_dict(),
            _OPTIMIZER_STATE_DICT: optim.state_dict(),
            _EPOCH: epoch_id,
            _STEP: step,
            _BEST_SCORE: best_score
        }, "checkpoint.tar")
Example #10
def save_checkpoint(model: nn.Module, optim: optimizer.Optimizer,
                    epoch_id: int, step: int, best_score: float, loss: float,
                    save_path="./result/fr_en/checkpoint.tar"):
    torch.save({
        _MODEL_STATE_DICT: model.state_dict(),
        _OPTIMIZER_STATE_DICT: optim.state_dict(),
        _EPOCH: epoch_id,
        _STEP: step,
        _BEST_SCORE: best_score,
        _LOSS: loss,
    }, save_path)
Example #11
def save_model(
    path: str,
    model: Module,
    optimizer: Optimizer = None,
    epoch: Union[int, None] = None,
    use_zipfile_serialization_if_available: bool = True,
    include_modifiers: bool = False,
):
    """
    Save a model's state dict into a file at the given path.
    Additionally can save an optimizer's state and the current epoch.

    :param path: the path to save the file the states to
    :param model: the model to save state for
    :param optimizer: the optimizer, if any, to save state for
    :param epoch: the epoch to save
    :param use_zipfile_serialization_if_available: for torch >= 1.6.0 only,
        exports the saved states using the new zipfile serialization format
    :param include_modifiers: if True, and a ScheduledOptimizer is provided
        as the optimizer, the associated ScheduledModifierManager and its
        Modifiers will be exported under the 'manager' key. Default is False
    """
    create_parent_dirs(path)

    if is_parallel_model(model):
        model = model.module

    save_dict = {"state_dict": OrderedDict()}

    # make sure we have the model state_dict on cpu
    for key, state in model.state_dict().items():
        copy = torch.zeros(state.shape)
        copy.copy_(state)
        save_dict["state_dict"][key] = copy

    if optimizer:
        save_dict["optimizer"] = optimizer.state_dict()

    if epoch is not None:  # 'is not None' so that epoch 0 is still recorded
        save_dict["epoch"] = epoch

    if include_modifiers and optimizer and hasattr(optimizer,
                                                   "manager_state_dict"):
        save_dict["manager"] = optimizer.manager_state_dict()

    if torch.__version__ < "1.6":
        torch.save(save_dict, path)
    else:
        torch.save(
            save_dict,
            path,
            _use_new_zipfile_serialization=use_zipfile_serialization_if_available,
        )
Example #12
    def get_save_info(self, gen_optim: Optimizer,
                      dis_optim: Optimizer) -> Dict[str, Any]:

        if self.device == torch.device("cpu"):
            generator_save_info = self.gen.get_save_info()
            discriminator_save_info = self.dis.get_save_info()
        else:
            generator_save_info = self.gen.module.get_save_info()
            discriminator_save_info = self.dis.module.get_save_info()
        save_info = {
            "generator": generator_save_info,
            "discriminator": discriminator_save_info,
            "gen_optim": gen_optim.state_dict(),
            "dis_optim": dis_optim.state_dict(),
        }
        if self.use_ema:
            save_info["shadow_generator"] = (
                self.gen_shadow.get_save_info()
                if self.device == torch.device("cpu") else
                self.gen_shadow.module.get_save_info())
        return save_info
Example #13
def save_checkpoint(model: nn.Module, optim: Optimizer, best_top1: float,
                    epoch: int, is_best: bool, ckpt_dir: str) -> None:
    state = {
        'epoch': epoch + 1,
        'model': model.state_dict(),
        'best_top1': best_top1,
        'optim': optim.state_dict(),
    }
    filename = os.path.join(ckpt_dir, 'checkpoint.pth.tar')
    torch.save(state, filename)
    if is_best:
        best_filename = os.path.join(ckpt_dir, 'model_best.pth.tar')
        shutil.copyfile(filename, best_filename)
Example #14
def save_checkpoint(checkpoint_dir: str, model: nn.Module, optim: optimizer.Optimizer, epoch_id: int, step: int, best_score: float):
    if not os.path.exists(checkpoint_dir):
        os.mkdir(checkpoint_dir)

    checkpoint_path = os.path.join(checkpoint_dir, 'checkpoint.tar') 
    
    torch.save({
        _MODEL_STATE_DICT: model.state_dict(),
        _OPTIMIZER_STATE_DICT: optim.state_dict(),
        _EPOCH: epoch_id,
        _STEP: step,
        _BEST_SCORE: best_score
    }, checkpoint_path)
Example #15
    def collect_state_dict(
        self,
        iteration: Union[float, int],
        model: EmmentalModel,
        optimizer: Optimizer,
        lr_scheduler: _LRScheduler,
        metric_dict: Dict[str, float],
    ) -> Dict[str, Any]:
        r"""Collect the state dict of the model.

        Args:
          iteration(float or int): The current iteration.
          model(EmmentalModel): The model to checkpoint.
          optimizer(Optimizer): The optimizer used during training process.
          lr_scheduler(_LRScheduler): Learning rate scheduler.
          metric_dict(dict): the metric dict.

        Returns:
          dict: The state dict.
        """

        model_params = {
            "name": model.name,
            "module_pool": model.collect_state_dict(),
            # "task_names": model.task_names,
            # "task_flows": model.task_flows,
            # "loss_funcs": model.loss_funcs,
            # "output_funcs": model.output_funcs,
            # "scorers": model.scorers,
        }

        state_dict = {
            "iteration": iteration,
            "model": model_params,
            "optimizer": optimizer.state_dict(),
            "lr_scheduler":
            lr_scheduler.state_dict() if lr_scheduler else None,
            "metric_dict": metric_dict,
        }

        return state_dict
Example #16
    def checkpoint(
        self,
        iteration: Union[float, int],
        model: EmmentalModel,
        optimizer: Optimizer,
        lr_scheduler: _LRScheduler,
        metric_dict: Dict[str, float],
    ) -> None:
        """Checkpointing the checkpoint.

        Args:
          iteration: The current iteration.
          model: The model to checkpoint.
          optimizer: The optimizer used during training process.
          lr_scheduler: Learning rate scheduler.
          metric_dict: The metric dict.
        """
        # Check the checkpoint_runway condition is met
        if iteration < self.checkpoint_runway:
            return
        elif not self.checkpoint_condition_met and iteration >= self.checkpoint_runway:
            self.checkpoint_condition_met = True
            logger.info(
                "checkpoint_runway condition has been met. Start checkpoining."
            )

        # Save model state
        model_path = f"{self.checkpoint_path}/checkpoint_{iteration}.model.pth"
        model.save(model_path, verbose=False)
        logger.info(f"Save checkpoint of {iteration} {self.checkpoint_unit} "
                    f"at {model_path}.")

        # Save optimizer state
        optimizer_path = f"{self.checkpoint_path}/checkpoint_{iteration}.optimizer.pth"
        optimizer_dict = {
            "optimizer": optimizer.state_dict(),
        }
        torch.save(optimizer_dict, optimizer_path)

        # Save lr_scheduler state
        scheduler_path = f"{self.checkpoint_path}/checkpoint_{iteration}.scheduler.pth"
        scheduler_dict = {
            "lr_scheduler": lr_scheduler.state_dict() if lr_scheduler else None
        }
        torch.save(scheduler_dict, scheduler_path)

        if self.checkpoint_all is False:
            for path in self.checkpoint_paths:
                if os.path.exists(path):
                    os.remove(path)

        self.checkpoint_paths.extend(
            [model_path, optimizer_path, scheduler_path])

        if not set(self.checkpoint_all_metrics.keys()).isdisjoint(
                set(metric_dict.keys())):
            new_best_metrics = self.is_new_best(metric_dict)
            for metric in new_best_metrics:
                best_metric_model_path = (
                    f"{self.checkpoint_path}/best_model_"
                    f"{metric.replace('/', '_')}.model.pth")
                copyfile(
                    model_path,
                    best_metric_model_path,
                )
                logger.info(
                    f"Save best model of metric {metric} to {best_metric_model_path}"
                )

                best_metric_optimizer_path = (
                    f"{self.checkpoint_path}/best_model_"
                    f"{metric.replace('/', '_')}.optimizer.pth")
                copyfile(optimizer_path, best_metric_optimizer_path)

                best_metric_scheduler_path = (
                    f"{self.checkpoint_path}/best_model_"
                    f"{metric.replace('/', '_')}.scheduler.pth")
                copyfile(scheduler_path, best_metric_scheduler_path)
Example #17
def save_torch_state(model: nn.Module, optimizer: Optimizer, path):
    torch.save(
        {
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        }, path)
Example #18
def train(model: nn.Module,
          train_dataset: Dataset,
          metrics: TrainingMetrics,
          config: TrainingConfig,
          val_dataset: Dataset = None,
          optimizer: Optimizer = None,
          lr_scheduler=None) -> nn.Module:
    if optimizer is None:
        optimizer = get_optimizer_from_model_and_config(model, config)

    lr_scheduler_interval = None
    if config.lr_scheduler is not None:
        lr_scheduler_interval = config.lr_scheduler[1]

    train_loader = DataLoader(train_dataset,
                              batch_size=config.batch_size,
                              shuffle=config.shuffle_batches,
                              drop_last=config.drop_last_batch)

    for epoch in range(config.epochs):
        if (isinstance(config, TeleportationTrainingConfig)
                and epoch in get_teleportation_epochs(config)):
            model = config.teleport_fn(model=model,
                                       train_dataset=train_dataset,
                                       metrics=metrics,
                                       config=config)
            # Force a new optimizer in case the model was swapped as a result of the teleportation.
            # Recreate the optimizer with the new model's parameters and update it with the previous
            # optimizer's state; otherwise any changes made to the old optimizer would be lost.
            old_optimizer_state = optimizer.state_dict()
            optimizer = get_optimizer_from_model_and_config(model, config)
            if lr_scheduler:
                # Similar to the optimizer, the lr scheduler needs to be updated after its recreation.
                old_scheduler_state = lr_scheduler.state_dict()
                lr_scheduler = get_lr_scheduler_from_optimizer_and_config(
                    optimizer, config)
                lr_scheduler.load_state_dict(old_scheduler_state)
            # Update the optimizer because certain LR schedulers (cf. OneCycleLR) overwrite
            # the parameters previously set in the optimizer when they are recreated.
            optimizer = update_optimizer_params(optimizer, old_optimizer_state)
        if lr_scheduler:
            print("Current LR: ", get_optimizer_lr(optimizer))
        train_epoch(model,
                    metrics,
                    optimizer,
                    train_loader,
                    epoch,
                    device=config.device,
                    config=config,
                    lr_scheduler=lr_scheduler)

        if val_dataset:
            if config.logger:
                with config.logger.validate():
                    val_res = test(model, val_dataset, metrics, config)
            else:
                val_res = test(model, val_dataset, metrics, config)

            print("Validation: {}".format(val_res))
            if np.isnan(val_res["loss"]) or np.isnan(val_res["accuracy"]):
                print("Stopping: Loss NaN!")
                if config.logger:
                    config.logger.add_text("Info", "Stopped due to Loss NaN.")
                break
            if config.logger is not None:
                config.logger.add_scalar("val_loss", val_res["loss"], epoch)
                config.logger.add_scalar("val_accuracy", val_res["accuracy"],
                                         epoch)
        if lr_scheduler and lr_scheduler_interval == "epoch":
            if isinstance(lr_scheduler, ReduceLROnPlateau):
                lr_scheduler.step(metrics=val_res["accuracy"])
            else:
                lr_scheduler.step()

    if config.logger is not None:
        config.logger.flush()

    return model
Example #19
 def update(self, network: nn.Module, optimizer: Optimizer, loss: float,
            count: int) -> None:
     self.losses.append(loss)
     self.counter.append(count)
     torch.save(network.state_dict(), self.model_pth)
     torch.save(optimizer.state_dict(), self.optimizer_pth)
Example #20
 def set_states(self, model: Module, optimizer: Optimizer):
     self.model_state = (model.module if isinstance(
         model, DistributedDataParallel) else model).state_dict()
     self.optimizer_state = optimizer.state_dict()
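A sketch of the symmetric restore, assuming the same attributes and imports as above. Because set_states stored the unwrapped module's state dict, the load also goes through model.module when the live model is wrapped in DistributedDataParallel:

 def restore_states(self, model: Module, optimizer: Optimizer):
     # Counterpart sketch to set_states(): push the stored states back.
     target = model.module if isinstance(model, DistributedDataParallel) else model
     target.load_state_dict(self.model_state)
     optimizer.load_state_dict(self.optimizer_state)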