def load(self, path_to_checkpoint: str, optimizer: Optional[Optimizer] = None,
         scheduler: Optional[_LRScheduler] = None) -> int:
    """Restore model (and optionally optimizer/scheduler) state from a checkpoint.

    Returns the training step stored in the checkpoint.
    """
    checkpoint = torch.load(path_to_checkpoint)
    self.load_state_dict(checkpoint['state_dict'])
    step = checkpoint['step']
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    if scheduler is not None:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    return step
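A minimal usage sketch, assuming a Model subclass that exposes this method and a
checkpoint saved with the matching 'state_dict', 'step', 'optimizer_state_dict',
and 'scheduler_state_dict' keys (make_model() and the file path are placeholders):

import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import StepLR

model = make_model()  # hypothetical factory returning a Model instance
optimizer = SGD(model.parameters(), lr=0.1)
scheduler = StepLR(optimizer, step_size=10)

# Resume training state; `step` is the iteration count stored in the checkpoint.
step = model.load('checkpoint.pt', optimizer=optimizer, scheduler=scheduler)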
Example #2
    @classmethod
    def simulate_values(  # type: ignore[override]
            cls, num_events: int, lr_scheduler: _LRScheduler,
            **kwargs: Any) -> List[List[int]]:
        """Method to simulate scheduled values during num_events events.

        Args:
            num_events (int): number of events during the simulation.
            lr_scheduler (subclass of `torch.optim.lr_scheduler._LRScheduler`): lr_scheduler object to wrap.

        Returns:
            list of pairs: [event_index, value]

        """

        if not isinstance(lr_scheduler, _LRScheduler):
            raise TypeError(
                "Argument lr_scheduler should be a subclass of torch.optim.lr_scheduler._LRScheduler, "
                f"but given {type(lr_scheduler)}")

        # Stepping the wrapped `torch.optim.lr_scheduler._LRScheduler` would
        # perturb its state, so the scheduler and its optimizer are snapshotted
        # to disk first and restored after the simulation.
        with tempfile.TemporaryDirectory() as tmpdirname:
            cache_filepath = Path(tmpdirname) / "ignite_lr_scheduler_cache.pt"
            obj = {
                "lr_scheduler": lr_scheduler.state_dict(),
                "optimizer": lr_scheduler.optimizer.state_dict(
                ),  # type: ignore[attr-defined]
            }
            torch.save(obj, cache_filepath.as_posix())

            values = []
            scheduler = cls(save_history=False,
                            lr_scheduler=lr_scheduler,
                            **kwargs)  # type: ignore[call-arg]
            for i in range(num_events):
                params = [
                    p[scheduler.param_name]
                    for p in scheduler.optimizer_param_groups
                ]
                values.append([i] + params)
                scheduler(engine=None)

            obj = torch.load(cache_filepath.as_posix())
            lr_scheduler.load_state_dict(obj["lr_scheduler"])
            lr_scheduler.optimizer.load_state_dict(
                obj["optimizer"])  # type: ignore[attr-defined]

            return values
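A short usage sketch, assuming this classmethod belongs to ignite's LRScheduler
wrapper (as the cache filename suggests); the import path has moved between
ignite versions, so treat it as an assumption:

import torch
from torch.optim.lr_scheduler import ExponentialLR
from ignite.handlers.param_scheduler import LRScheduler  # assumed import path

optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
torch_scheduler = ExponentialLR(optimizer, gamma=0.9)

# Preview the LR over 5 events without disturbing the real scheduler's state.
# With a single param group, each entry is [event_index, lr].
values = LRScheduler.simulate_values(num_events=5, lr_scheduler=torch_scheduler)
for event_index, lr in values:
    print(event_index, lr)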
Example #3
import os
from typing import Optional

import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as sche


# `construct_print` is a logging helper from the surrounding codebase.
def resume_checkpoint(
    model: Optional[nn.Module] = None,
    optimizer: Optional[optim.Optimizer] = None,
    scheduler: Optional[sche._LRScheduler] = None,
    exp_name: str = "",
    load_path: str = "",
    mode: str = "all",
):
    """
    从保存节点恢复模型

    Args:
        model (nn.Module): model object
        optimizer (optim.Optimizer): optimizer object
        scheduler (sche._LRScheduler): scheduler object
        exp_name (str): exp_name
        load_path (str): 模型存放路径
        mode (str): 选择哪种模型恢复模式:
            - 'all': 回复完整模型,包括训练中的的参数;
            - 'onlynet': 仅恢复模型权重参数
            
    Returns mode: 'all' start_epoch; 'onlynet' None
    """
    if os.path.exists(load_path) and os.path.isfile(load_path):
        construct_print(f"Loading checkpoint '{load_path}'")
        checkpoint = torch.load(load_path)
        if mode == "all":
            if exp_name == checkpoint["arch"]:
                start_epoch = checkpoint["epoch"]
                model.load_state_dict(checkpoint["net_state"])
                optimizer.load_state_dict(checkpoint["opti_state"])
                scheduler.load_state_dict(checkpoint["sche_state"])
                construct_print(f"Loaded '{load_path}' "
                                f"(will train at epoch"
                                f" {checkpoint['epoch']})")
                return start_epoch
            else:
                raise Exception(f"{load_path} does not match.")
        elif mode == "onlynet":
            model.load_state_dict(checkpoint)
            construct_print(f"Loaded checkpoint '{load_path}' "
                            f"(only has the model's weight params)")
        else:
            raise NotImplementedError
    else:
        raise Exception(f"{load_path}路径不正常,请检查")
Example #4
    def fit_support(
        self,
        model,
        tasks: List[Task],
        dataloader: DataLoader,
        optimizer: Optimizer,
        scheduler: _LRScheduler,
        training_logger: ResultLogger,
    ):
        support_loss = 1.0
        support_epoch = 0

        # Don't change default optimizer and scheduler states
        optimizer_state_dict = deepcopy(optimizer.state_dict())
        scheduler_state_dict = deepcopy(scheduler.state_dict())

        # Reset tasks states
        for task in tasks:
            task.reset()

        model.freeze_weights()

        while (support_loss > self.support_min_loss
               and support_epoch < self.support_max_epochs):
            support_epoch += 1
            support_loss = self.fit_one(
                model,
                tasks,
                dataloader,
                optimizer,
                scheduler,
                training_logger.epoch(support_epoch, self.support_max_epochs),
                train_model=False,
            )

        optimizer.load_state_dict(optimizer_state_dict)
        scheduler.load_state_dict(scheduler_state_dict)
        model.defreeze_weights()
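The deepcopy snapshot/restore around the support loop is the load-bearing trick:
the optimizer and scheduler can be stepped freely during support fitting and then
rolled back. A standalone sketch of the same pattern (the parameter and loss are
placeholders; note that parameter values themselves are not restored this way,
which is why fit_support also freezes the model's weights):

from copy import deepcopy

import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import StepLR

param = torch.nn.Parameter(torch.ones(1))
optimizer = SGD([param], lr=0.1, momentum=0.9)
scheduler = StepLR(optimizer, step_size=1)

# Snapshot states before the temporary fitting phase.
optimizer_state = deepcopy(optimizer.state_dict())
scheduler_state = deepcopy(scheduler.state_dict())

for _ in range(3):  # stand-in for the support epochs
    optimizer.zero_grad()
    (param ** 2).sum().backward()
    optimizer.step()
    scheduler.step()

# Roll back: momentum buffers and the LR schedule return to their snapshots.
optimizer.load_state_dict(optimizer_state)
scheduler.load_state_dict(scheduler_state)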
Example #5
def resume_checkpoint(
    model: Optional[nn.Module] = None,
    optimizer: Optional[optim.Optimizer] = None,
    scheduler: Optional[sche._LRScheduler] = None,
    amp=None,
    exp_name: str = "",
    load_path: str = "",
    mode: str = "all",
):
    """
    从保存节点恢复模型

    Args:
        model (nn.Module): model object
        optimizer (optim.Optimizer): optimizer object
        scheduler (sche._LRScheduler): scheduler object
        amp (): apex.amp
        exp_name (str): exp_name
        load_path (str): 模型存放路径
        mode (str): 选择哪种模型恢复模式:
            - 'all': 回复完整模型,包括训练中的的参数;
            - 'onlynet': 仅恢复模型权重参数

    Returns mode: 'all' start_epoch; 'onlynet' None
    """
    if os.path.exists(load_path) and os.path.isfile(load_path):
        construct_print(f"Loading checkpoint '{load_path}'")
        checkpoint = torch.load(load_path)
        if mode == "all":
            # If exp_name is given, it must match checkpoint["arch"];
            # otherwise no check is performed.
            if exp_name and exp_name != checkpoint["arch"]:
                raise Exception(
                    f"{exp_name} does not match the 'arch' field in {load_path}.")

            start_epoch = checkpoint["epoch"]
            if hasattr(model, "module"):
                model.module.load_state_dict(checkpoint["net_state"])
            else:
                model.load_state_dict(checkpoint["net_state"])
            optimizer.load_state_dict(checkpoint["opti_state"])
            scheduler.load_state_dict(checkpoint["sche_state"])
            if checkpoint.get("amp_state", None):
                if amp:
                    amp.load_state_dict(checkpoint["amp_state"])
                else:
                    construct_print("The checkpoint contains an amp state, "
                                    "but amp is not in use, so it is skipped.")
            else:
                construct_print("The checkpoint has no amp state to restore.")
            construct_print(f"Loaded '{load_path}' "
                            f"(will train at epoch"
                            f" {checkpoint['epoch']})")
            return start_epoch
        elif mode == "onlynet":
            if hasattr(model, "module"):
                model.module.load_state_dict(checkpoint)
            else:
                model.load_state_dict(checkpoint)
            construct_print(f"Loaded checkpoint '{load_path}' "
                            f"(only has the model's weight params)")
        else:
            raise NotImplementedError
    else:
        raise Exception(f"{load_path}路径不正常,请检查")