Example #1
def save_checkpoint(
    model: nn.Module = None,
    optimizer: optim.Optimizer = None,
    scheduler: sche._LRScheduler = None,
    amp=None,
    exp_name: str = "",
    current_epoch: int = 1,
    full_net_path: str = "",
    state_net_path: str = "",
):
    """
    保存完整参数模型(大)和状态参数模型(小)

    Args:
        model (nn.Module): model object
        optimizer (optim.Optimizer): optimizer object
        scheduler (sche._LRScheduler): scheduler object
        amp (): apex.amp
        exp_name (str): exp_name
        current_epoch (int): in the epoch, model **will** be trained
        full_net_path (str): the path for saving the full model parameters
        state_net_path (str): the path for saving the state dict.
    """

    state_dict = {
        "arch": exp_name,
        "epoch": current_epoch,
        "net_state": model.state_dict(),
        "opti_state": optimizer.state_dict(),
        "sche_state": scheduler.state_dict(),
        "amp_state": amp.state_dict() if amp else None,
    }
    torch.save(state_dict, full_net_path)
    torch.save(model.state_dict(), state_net_path)
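A minimal restore sketch for the checkpoint written above (the resume_checkpoint name and the path argument are placeholders, not part of the original code):

def resume_checkpoint(model, optimizer, scheduler, full_net_path):
    # Load the full checkpoint produced by save_checkpoint.
    checkpoint = torch.load(full_net_path, map_location="cpu")
    model.load_state_dict(checkpoint["net_state"])
    optimizer.load_state_dict(checkpoint["opti_state"])
    if checkpoint["sche_state"] is not None:
        scheduler.load_state_dict(checkpoint["sche_state"])
    # Return the epoch recorded at save time so training can resume from it.
    return checkpoint["epoch"]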
Example #2
 def snapshot(self,
              net: torch.nn.Module,
              opt: Optimizer,
              sched: _LRScheduler = None,
              epoch: int = None,
              subdir='.'):
     """
     Writes a snapshot of the training, i.e. network weights, optimizer state and scheduler state to a file
     in the log directory.
     :param net: the neural network
     :param opt: the optimizer used
     :param sched: the learning rate scheduler used
     :param epoch: the current epoch
     :param subdir: if given, creates a subdirectory in the log directory. The data is written to a file
         in this subdirectory instead.
     :return:
     """
     outfile = pt.join(self.dir, subdir, 'snapshot.pt')
     if not pt.exists(os.path.dirname(outfile)):
         os.makedirs(os.path.dirname(outfile))
     torch.save(
         {
             'net': net.state_dict(),
             'opt': opt.state_dict(),
             'sched': sched.state_dict() if sched is not None else None,
             'epoch': epoch
         }, outfile)
     return outfile
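A hypothetical load_snapshot counterpart to the method above (the method name and argument handling are illustrative only):

 def load_snapshot(self, net, opt, sched=None, subdir='.'):
     # Read back the file written by snapshot() and restore each state dict.
     infile = pt.join(self.dir, subdir, 'snapshot.pt')
     snap = torch.load(infile, map_location='cpu')
     net.load_state_dict(snap['net'])
     opt.load_state_dict(snap['opt'])
     if sched is not None and snap.get('sched') is not None:
         sched.load_state_dict(snap['sched'])
     return snap.get('epoch')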
Example #3
 def save(self, path_to_checkpoints_dir: str, step: int, optimizer: Optimizer, scheduler: _LRScheduler) -> str:
     path_to_checkpoint = os.path.join(path_to_checkpoints_dir, f'model-{step}.pth')
     checkpoint = {
         'state_dict': self.state_dict(),
         'step': step,
         'optimizer_state_dict': optimizer.state_dict(),
         'scheduler_state_dict': scheduler.state_dict()
     }
     torch.save(checkpoint, path_to_checkpoint)
     return path_to_checkpoint
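A hypothetical load counterpart for the save method above (the method name and optional arguments are assumptions):

 def load(self, path_to_checkpoint: str, optimizer: Optimizer = None, scheduler: _LRScheduler = None) -> int:
     # Restore the model weights and, when provided, the optimizer/scheduler states.
     checkpoint = torch.load(path_to_checkpoint, map_location='cpu')
     self.load_state_dict(checkpoint['state_dict'])
     if optimizer is not None:
         optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
     if scheduler is not None:
         scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
     return checkpoint['step']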
Example #4
 def _better_lr_sched_repr(lr_sched: _LRScheduler) -> str:
     return (
         lr_sched.__class__.__name__
         + "(\n    "
         + "\n    ".join(
             f"{k}: {v}"
             for k, v in lr_sched.state_dict().items()
             if not k.startswith("_")
         )
         + "\n)"
     )
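A small usage sketch (the optimizer and scheduler choices are arbitrary) showing the kind of string _better_lr_sched_repr produces:

import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR

model = nn.Linear(4, 2)
opt = optim.SGD(model.parameters(), lr=0.1)
sched = StepLR(opt, step_size=10, gamma=0.5)
print(_better_lr_sched_repr(sched))
# Prints something like:
# StepLR(
#     step_size: 10
#     gamma: 0.5
#     base_lrs: [0.1]
#     last_epoch: 0
# )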
Example #5
    def simulate_values(  # type: ignore[override]
            cls, num_events: int, lr_scheduler: _LRScheduler,
            **kwargs: Any) -> List[List[int]]:
        """Method to simulate scheduled values during num_events events.

        Args:
            num_events (int): number of events during the simulation.
            lr_scheduler (subclass of `torch.optim.lr_scheduler._LRScheduler`): lr_scheduler object to wrap.

        Returns:
            list of pairs: [event_index, value]

        """

        if not isinstance(lr_scheduler, _LRScheduler):
            raise TypeError(
                "Argument lr_scheduler should be a subclass of torch.optim.lr_scheduler._LRScheduler, "
                f"but given {type(lr_scheduler)}")

        # This scheduler uses `torch.optim.lr_scheduler._LRScheduler` which
        # should be replicated in order to simulate LR values and
        # not perturb original scheduler.
        with tempfile.TemporaryDirectory() as tmpdirname:
            cache_filepath = Path(tmpdirname) / "ignite_lr_scheduler_cache.pt"
            obj = {
                "lr_scheduler": lr_scheduler.state_dict(),
                "optimizer": lr_scheduler.optimizer.state_dict(
                ),  # type: ignore[attr-defined]
            }
            torch.save(obj, cache_filepath.as_posix())

            values = []
            scheduler = cls(save_history=False,
                            lr_scheduler=lr_scheduler,
                            **kwargs)  # type: ignore[call-arg]
            for i in range(num_events):
                params = [
                    p[scheduler.param_name]
                    for p in scheduler.optimizer_param_groups
                ]
                values.append([i] + params)
                scheduler(engine=None)

            obj = torch.load(cache_filepath.as_posix())
            lr_scheduler.load_state_dict(obj["lr_scheduler"])
            lr_scheduler.optimizer.load_state_dict(
                obj["optimizer"])  # type: ignore[attr-defined]

            return values
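A usage sketch against PyTorch-Ignite's LRScheduler wrapper, which exposes simulate_values as shown above (the import path varies across ignite versions; this assumes a recent release):

import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ExponentialLR
from ignite.handlers import LRScheduler  # older versions: ignite.contrib.handlers

model = nn.Linear(4, 2)
opt = optim.SGD(model.parameters(), lr=0.1)
torch_sched = ExponentialLR(opt, gamma=0.9)

# Simulate five events without disturbing torch_sched's internal state.
values = LRScheduler.simulate_values(num_events=5, lr_scheduler=torch_sched)
# values is a list of [event_index, lr] pairs, e.g. [[0, 0.1], [1, 0.09], ...]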
Example #6
    def collect_state_dict(
        self,
        iteration: Union[float, int],
        model: EmmentalModel,
        optimizer: Optimizer,
        lr_scheduler: _LRScheduler,
        metric_dict: Dict[str, float],
    ) -> Dict[str, Any]:
        r"""Collect the state dict of the model.

        Args:
          iteration(float or int): The current iteration.
          model(EmmentalModel): The model to checkpoint.
          optimizer(Optimizer): The optimizer used during training process.
          lr_scheduler(_LRScheduler): Learning rate scheduler.
          metric_dict(dict): the metric dict.

        Returns:
          dict: The state dict.
        """

        model_params = {
            "name": model.name,
            "module_pool": model.collect_state_dict(),
            # "task_names": model.task_names,
            # "task_flows": model.task_flows,
            # "loss_funcs": model.loss_funcs,
            # "output_funcs": model.output_funcs,
            # "scorers": model.scorers,
        }

        state_dict = {
            "iteration": iteration,
            "model": model_params,
            "optimizer": optimizer.state_dict(),
            "lr_scheduler":
            lr_scheduler.state_dict() if lr_scheduler else None,
            "metric_dict": metric_dict,
        }

        return state_dict
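A hypothetical snippet showing how the returned state dict could be persisted and the optimizer/scheduler later restored from it (the checkpointer object and file name are placeholders):

state = checkpointer.collect_state_dict(
    iteration, model, optimizer, lr_scheduler, metric_dict)
torch.save(state, "checkpoint.pth")

# Later: reload the file and push the states back into the live objects.
restored = torch.load("checkpoint.pth", map_location="cpu")
optimizer.load_state_dict(restored["optimizer"])
if restored["lr_scheduler"] is not None:
    lr_scheduler.load_state_dict(restored["lr_scheduler"])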
Example #7
    def fit_support(
        self,
        model,
        tasks: List[Task],
        dataloader: DataLoader,
        optimizer: Optimizer,
        scheduler: _LRScheduler,
        training_logger: ResultLogger,
    ):
        support_loss = 1.0
        support_epoch = 0

        # Don't change default optimizer and scheduler states
        optimizer_state_dict = deepcopy(optimizer.state_dict())
        scheduler_state_dict = deepcopy(scheduler.state_dict())

        # Reset tasks states
        for task in tasks:
            task.reset()

        model.freeze_weights()

        while (support_loss > self.support_min_loss
               and support_epoch < self.support_max_epochs):
            support_epoch += 1
            support_loss = self.fit_one(
                model,
                tasks,
                dataloader,
                optimizer,
                scheduler,
                training_logger.epoch(support_epoch, self.support_max_epochs),
                train_model=False,
            )

        optimizer.load_state_dict(optimizer_state_dict)
        scheduler.load_state_dict(scheduler_state_dict)
        model.defreeze_weights()
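The snapshot-and-restore pattern used above can be reduced to a self-contained sketch with plain PyTorch objects (no Task or ResultLogger machinery):

from copy import deepcopy
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR

model = nn.Linear(4, 2)
opt = optim.SGD(model.parameters(), lr=0.1)
sched = StepLR(opt, step_size=1, gamma=0.5)

# Snapshot the states before a temporary support phase.
opt_state = deepcopy(opt.state_dict())
sched_state = deepcopy(sched.state_dict())

# ... temporary optimization steps go here ...
opt.step()
sched.step()

# Roll both back so the outer training loop is unaffected.
opt.load_state_dict(opt_state)
sched.load_state_dict(sched_state)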
Example #8
    def checkpoint(
        self,
        iteration: Union[float, int],
        model: EmmentalModel,
        optimizer: Optimizer,
        lr_scheduler: _LRScheduler,
        metric_dict: Dict[str, float],
    ) -> None:
        """Checkpointing the checkpoint.

        Args:
          iteration: The current iteration.
          model: The model to checkpoint.
          optimizer: The optimizer used during training process.
          lr_scheduler: Learning rate scheduler.
          metric_dict: The metric dict.
        """
        # Check the checkpoint_runway condition is met
        if iteration < self.checkpoint_runway:
            return
        elif not self.checkpoint_condition_met and iteration >= self.checkpoint_runway:
            self.checkpoint_condition_met = True
            logger.info(
                "checkpoint_runway condition has been met. Start checkpoining."
            )

        # Save model state
        model_path = f"{self.checkpoint_path}/checkpoint_{iteration}.model.pth"
        model.save(model_path, verbose=False)
        logger.info(f"Save checkpoint of {iteration} {self.checkpoint_unit} "
                    f"at {model_path}.")

        # Save optimizer state
        optimizer_path = f"{self.checkpoint_path}/checkpoint_{iteration}.optimizer.pth"
        optimizer_dict = {
            "optimizer": optimizer.state_dict(),
        }
        torch.save(optimizer_dict, optimizer_path)

        # Save lr_scheduler state
        scheduler_path = f"{self.checkpoint_path}/checkpoint_{iteration}.scheduler.pth"
        scheduler_dict = {
            "lr_scheduler": lr_scheduler.state_dict() if lr_scheduler else None
        }
        torch.save(scheduler_dict, scheduler_path)

        if self.checkpoint_all is False:
            for path in self.checkpoint_paths:
                if os.path.exists(path):
                    os.remove(path)

        self.checkpoint_paths.extend(
            [model_path, optimizer_path, scheduler_path])

        if not set(self.checkpoint_all_metrics.keys()).isdisjoint(
                set(metric_dict.keys())):
            new_best_metrics = self.is_new_best(metric_dict)
            for metric in new_best_metrics:
                best_metric_model_path = (
                    f"{self.checkpoint_path}/best_model_"
                    f"{metric.replace('/', '_')}.model.pth")
                copyfile(
                    model_path,
                    best_metric_model_path,
                )
                logger.info(
                    f"Save best model of metric {metric} to {best_metric_model_path}"
                )

                best_metric_optimizer_path = (
                    f"{self.checkpoint_path}/best_model_"
                    f"{metric.replace('/', '_')}.optimizer.pth")
                copyfile(optimizer_path, best_metric_optimizer_path)

                best_metric_scheduler_path = (
                    f"{self.checkpoint_path}/best_model_"
                    f"{metric.replace('/', '_')}.scheduler.pth")
                copyfile(scheduler_path, best_metric_scheduler_path)
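A hypothetical restore path for the three files written above (it assumes EmmentalModel exposes a load counterpart to save, which should be verified against the library):

# Paths follow the naming convention used by checkpoint().
model_path = f"{checkpoint_path}/checkpoint_{iteration}.model.pth"
optimizer_path = f"{checkpoint_path}/checkpoint_{iteration}.optimizer.pth"
scheduler_path = f"{checkpoint_path}/checkpoint_{iteration}.scheduler.pth"

model.load(model_path)  # assumed counterpart to model.save()
optimizer.load_state_dict(torch.load(optimizer_path)["optimizer"])
scheduler_state = torch.load(scheduler_path)["lr_scheduler"]
if scheduler_state is not None:
    lr_scheduler.load_state_dict(scheduler_state)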