def init_optimizer(optimizer: optim.Optimizer, configs):
    if not configs.resume_optim:
        return
    else:
        print("Resuming optimizer from %s" % configs.ockpt_file)
        o = torch.load(configs.ockpt_file)['optim']
        optimizer.load_state_dict(o)
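A minimal save-side counterpart (a sketch; save_optimizer and the single-key layout are assumptions, not part of the original project): it writes the optimizer state under the same 'optim' key that init_optimizer() reads.

import torch

def save_optimizer(optimizer, ckpt_file):
    # store the state_dict under the 'optim' key expected by init_optimizer()
    torch.save({'optim': optimizer.state_dict()}, ckpt_file)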
Example #2
def update_optimizer_params(optimizer: Optimizer, new_state) -> Optimizer:
    optim_state = optimizer.state_dict()
    if "params" in new_state["param_groups"][0].keys():
        del new_state["param_groups"][0]["params"]
    optim_state["param_groups"][0].update(new_state["param_groups"][0])
    optimizer.load_state_dict(optim_state)
    return optimizer
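Usage sketch (hypothetical values): override just the learning rate of a live optimizer without discarding its accumulated state, e.g. momentum buffers.

# bump only the learning rate; 'params' is absent, so nothing is deleted
new_state = {"param_groups": [{"lr": 1e-4}]}
optimizer = update_optimizer_params(optimizer, new_state)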
Example #3
def load_optim(optimizer: Optimizer, checkpoint_path: str,
               device: torch.device) -> Optimizer:
    """
    Load optimizer state to continue training
        Args:
            optimizer      : initialized optimizer
            checkpoint_path: path to the checkpoint
            device         : device to send optimizer to (must be the same as in the model)

        Note: must be called after initializing the model

        Output: optimizer with the loaded state
    """
    checkpoint = torch.load(checkpoint_path)
    optimizer.load_state_dict(checkpoint["optimizer"])
    for state in optimizer.state.values():
        for k, v in state.items():
            if torch.is_tensor(v):
                state[k] = v.to(device)

    for param_group in optimizer.param_groups:
        print("learning_rate: {}".format(param_group["lr"]))
    print("Loaded optimizer {} state from {}".format(optimizer,
                                                     checkpoint_path))
    return optimizer
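Usage sketch (MyModel is a hypothetical placeholder): build the model and move it to the target device first, then restore the optimizer onto the same device.

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MyModel().to(device)  # hypothetical model class
optimizer = torch.optim.Adam(model.parameters())
optimizer = load_optim(optimizer, "checkpoint.pth", device)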
Example #4
    def load(self, model: Module, optimizer: Optimizer) -> dict:
        """
        Loads a PyTorch model from a TileDB array.
        :param model: Pytorch Module. A defined PyTorch model.
        :param optimizer: PyTorch Optimizer. A defined PyTorch optimizer.
        :return: Dict. A dictionary with attributes other than model or optimizer
        state_dict.
        """

        model_array = tiledb.open(self.uri)
        model_array_results = model_array[:]
        schema = model_array.schema

        model_state_dict = pickle.loads(
            model_array_results["model_state_dict"].item(0))
        optimizer_state_dict = pickle.loads(
            model_array_results["optimizer_state_dict"].item(0))

        # Load model's state and optimizer dictionaries
        model.load_state_dict(model_state_dict)
        optimizer.load_state_dict(optimizer_state_dict)

        # Get the rest of the attributes
        out_dict = {}
        for idx in range(schema.nattr):
            attr_name = schema.attr(idx).name
            if attr_name not in ("model_state_dict", "optimizer_state_dict"):
                out_dict[attr_name] = pickle.loads(
                    model_array_results[attr_name].item(0))
        return out_dict
Example #5
    def Load(self, filename, model: nn.Module, optimizer: optim.Optimizer, scheduler: optim.lr_scheduler._LRScheduler):
        checkpoint = th.load(os.path.join(self.directory, filename))
        model.load_state_dict(checkpoint["model_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_dict"])
        scheduler.load_state_dict(checkpoint["scheduler_dict"])

        return model, optimizer, scheduler
Example #6
def load_state(net: Connect4Network,
               optimizer: optim.Optimizer,
               scheduler,
               args: AlphaZeroArgs,
               iteration: int,
               new_optim_state=True) -> int:
    """ Loads saved model and optimizer states if exists
    @param net: Neural Network object
    @param optimizer: pytorch optimizer
    @param scheduler: pytorch scheduler
    @param args: AlphaZeroArgs
    @param iteration: current iteration
    @param new_optim_state:
    @return:
    """
    checkpoint_path = util.get_model_file_path(args.neural_net_name, iteration)
    start_epoch, checkpoint = 0, None
    if os.path.isfile(checkpoint_path):
        checkpoint = torch.load(checkpoint_path)
    if checkpoint is not None:
        if len(checkpoint) == 1 or new_optim_state:
            net.load_state_dict(checkpoint['state_dict'])
            print("Loaded checkpoint model {checkpoint_path}")
        else:
            start_epoch = checkpoint['epoch']
            net.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            scheduler.load_state_dict(checkpoint['scheduler'])
            print(
                f"Loaded checkpoint model {checkpoint_path}, and optimizer, scheduler."
            )
    return start_epoch
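A save-side sketch matching the keys load_state() reads; save_state itself is an assumption, while util.get_model_file_path and AlphaZeroArgs come from the original project.

def save_state(net, optimizer, scheduler, args, iteration, epoch):
    # write all four entries so a later resume restores the full training state
    torch.save({
        'epoch': epoch,
        'state_dict': net.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict(),
    }, util.get_model_file_path(args.neural_net_name, iteration))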
Example #7
def load_checkpoint(checkpoint_path: str, model: nn.Module, optimizer: optim.Optimizer = None):
    if not os.path.exists(checkpoint_path):
        raise ValueError('missing checkpoint file {}'.format(checkpoint_path))
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint.pop('state_dict'))
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint.pop('optimizer'))
    return checkpoint
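Usage sketch: because the state dicts are pop()ed off the checkpoint, whatever remains (epoch, loss, ...) is handed back to the caller as metadata.

extras = load_checkpoint("checkpoint.pth", model, optimizer)
start_epoch = extras.get("epoch", 0)  # keys beyond 'epoch' depend on what was saved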
Example #8
def load_checkpoint(filename: str, model: nn.Module,
                    optimizer: optim.Optimizer):
    if os.path.isfile(filename):
        checkpoint = torch.load(filename)
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])

        return start_epoch, model, optimizer  # not `optim`, which is the module
Example #9
def load_optimizer(optimizer: Optimizer, path: str):
    """
    Load optimizer state for resuming training

    :param optimizer: initialized optimizer to restore state into
    :param path: path to the saved optimizer state_dict
    """
    optimizer.load_state_dict(torch.load(path))
    print("Optimizer state loaded.")
Example #10
    def load(self, path_to_checkpoint: str, optimizer: Optimizer = None, scheduler: _LRScheduler = None) -> int:
        checkpoint = torch.load(path_to_checkpoint)
        self.load_state_dict(checkpoint['state_dict'])
        step = checkpoint['step']
        if optimizer is not None:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        if scheduler is not None:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        return step
Example #11
def resume(weights, optimizer: optim.Optimizer, ema: ModelEMA, device,
           results):
    start_epoch = 0
    pretrained = weights.endswith(".pth")
    if pretrained:
        checkpoint = torch.load(weights, map_location=device)
        if (ckpt_optim := checkpoint.get("optimizer", None)) is not None:
            optimizer.load_state_dict(ckpt_optim)

        if ema and (ckpt_ema := checkpoint.get("ema", False)):
            ema.ema.load_state_dict(ckpt_ema.state_dict())
            # the saved EMA is a module, so its update counter is assumed to be
            # stored beside it in the checkpoint (ckpt_ema["updates"] would fail)
            ema.updates = checkpoint["updates"]
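Checkpoint layout this resume() assumes (a sketch following the common convention of pickling the EMA model object next to its update counter):

torch.save({
    "optimizer": optimizer.state_dict(),
    "ema": ema.ema,        # the EMA module itself, so .state_dict() works on load
    "updates": ema.updates,
}, weights)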
Example #12
def resume_checkpoint(config: Config,
                      model: Module,
                      optimizer: Optimizer = None) -> int:
    """
    resume training process data from config.logs which generated by make_checkpoint()
    :return number of last epoch
    """
    last_epoch = -1
    temp_weight_path = config.temp_weight_path
    temp_optim_path = config.temp_optim_path
    if os.path.exists(config.train_record_file):
        try:
            import json
            with open(config.train_record_file, 'r') as f:
                last = f.readlines()[-1]
                info = json.loads(last)
                last_epoch = int(info["epoch"])
                last_init = str(info["init"])
                if not os.path.exists(temp_weight_path):
                    temp_weight_path = temp_weight_path.replace(
                        config.init_time, last_init)
                if not os.path.exists(temp_optim_path):
                    temp_optim_path = temp_optim_path.replace(
                        config.init_time, last_init)
            print("Continue train from last epoch %d" % last_epoch)
        except Exception:
            warn("Rename invalid train record file from {} to {}".format(
                config.train_record_file,
                config.train_record_file + '.badfile'))
            warn("Can't get last_epoch value, {} will be returned".format(
                last_epoch))
            os.rename(config.train_record_file,
                      config.train_record_file + '.badfile')
    if os.path.exists(temp_weight_path):
        try:
            model.load_state_dict(load(temp_weight_path))
            print("Resumed weight checkpoint from {}".format(temp_weight_path))
        except Exception:
            warn("Move invalid temp {} weights file from {} to {}".format(
                type(model), temp_weight_path, temp_weight_path + '.badfile'))
            os.rename(temp_weight_path, temp_weight_path + '.badfile')
    if optimizer is not None and os.path.exists(temp_optim_path):
        try:
            optimizer.load_state_dict(load(temp_optim_path))
            print(
                "Resumed optimizer checkpoint from {}".format(temp_optim_path))
        except Exception:
            warn("Move invalid temp {} optimizer file from {} to {}".format(
                type(optimizer), temp_optim_path,
                temp_optim_path + '.badfile'))
            os.rename(temp_optim_path, temp_optim_path + '.badfile')

    return last_epoch
Example #13
def load_checkpoint(model: nn.Module, optimizer: optim.Optimizer,
                    filepath: str) -> Tuple[int, float]:
    """Loads model and optimizer state from the provided .model file."""
    if not os.path.exists(filepath):
        raise ValueError(f'Filepath: {filepath} does not exist!')

    logging.info(f'Loading checkpoint: {filepath}...')
    checkpoint = torch.load(filepath)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    return (checkpoint['epoch'], checkpoint['loss'])
Example #14
def resume_training(checkpoint_path, model, model_optim: optim.Optimizer,
                    proxy_optim: optim.Optimizer):
    try:
        checkpoint = torch.load(checkpoint_path)
        model.module.load_state_dict(checkpoint['model'])
        model_optim.load_state_dict(checkpoint['model_optim'])
        proxy_optim.load_state_dict(checkpoint['proxy_optim'])
    except BaseException as err:
        print("? Failed to load models. Quitting...")
        print(err)
        quit()

    return model, model_optim, proxy_optim, checkpoint['epoch']
Example #15
def restore_snapshot(model: nn.Module, optimizer: Optimizer,
                     snapshot_file: str):
    checkpoint = torch.load(snapshot_file)
    start_epoch = checkpoint['epoch'] + 1
    best_loss = checkpoint['loss']
    model.load_state_dict(checkpoint['model'])

    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer'])

    train_history = pd.DataFrame.from_dict(checkpoint['train_history'])

    return start_epoch, train_history, best_loss
Example #16
def load_state(model: nn.Module, optimizer: Optimizer, path: pathlib.Path,
               callbacks: List[BaseCallback]):
    """ Loads model, optimizer and epoch states from path """
    checkpoint = torch.load(
        str(path), map_location={'cuda:0': f'cuda:{get_local_rank()}'})
    if isinstance(model, DistributedDataParallel):
        model.module.load_state_dict(checkpoint['state_dict'])
    else:
        model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    for callback in callbacks:
        callback.on_checkpoint_load(checkpoint)

    logging.info(f'Loaded checkpoint from {str(path)}')
    return checkpoint['epoch']
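A save-side sketch for the DDP case (assumed, not from the original project): write from one rank only and store the unwrapped module's weights, so the checkpoint loads into either a plain or a DistributedDataParallel model.

if get_local_rank() == 0:
    state_dict = (model.module.state_dict()
                  if isinstance(model, DistributedDataParallel)
                  else model.state_dict())
    torch.save({
        'state_dict': state_dict,
        'optimizer_state_dict': optimizer.state_dict(),
        'epoch': epoch,
    }, str(path))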
Example #17
def _load_checkpoint(model: nn.Module, optimizer: optim.Optimizer,
                     config: Config):
    best_statistic = 0
    start = datetime.now()
    if os.path.isfile(ct.L2R_TRAIN_PROGRESS.format(config.name)):
        with open(ct.L2R_MODEL.format(config.name), 'rb') as file:
            checkpoint = torch.load(file, map_location=ct.DEVICE)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        model.epochs_trained = checkpoint['epoch']

        best_statistic = checkpoint['best_statistic']
        helpers.log(
            f'Loaded checkpoint from {ct.L2R_MODEL.format(config.name)} in {datetime.now() - start}.'
        )
    return best_statistic
Example #18
    def load_checkpoint(self,
                        model: nn.Module,
                        optimizer: optim.Optimizer = None,
                        file_name: str = 'pytorch_model.pt'):
        filepath = os.path.join(self.directory, file_name)
        checkpoint = th.load(filepath)
        model.load_state_dict(checkpoint['model_dict'])
        if optimizer is not None:
            optimizer.load_state_dict(checkpoint['optimizer_dict'])

        # note: `k != a or k != b` is always true; `not in` actually filters
        hyperparam_dict = {
            k: v
            for k, v in checkpoint.items()
            if k not in ('model_dict', 'optimizer_dict')
        }

        return model, optimizer, hyperparam_dict
Example #19
def load_latest_model_parameters(
    model: torch.nn.Module,
    *,
    optimiser: Optimizer = None,
    model_name: str,
    model_directory: Path,
) -> Tuple[Union[torch.nn.Module, Tuple[torch.nn.Module, Optimizer]], bool]:
    """

    inplace but returns model

    :param optimiser:
    :param model:
    :type model:
    :param model_directory:
    :param model_name:
    :return:"""
    if model:
        model_path = model_directory / model_name
        list_of_files = list(model_path.glob(f"*{parameter_extension}"))
        if len(list_of_files) == 0:
            print(
                f"Found no previous models with extension {parameter_extension} in {model_path}"
            )
        else:
            latest_model_parameter_file = max(list_of_files,
                                              key=os.path.getctime)
            print(
                f"loading previous model parameters: {latest_model_parameter_file}"
            )

            model.load_state_dict(torch.load(str(latest_model_parameter_file)))

            if optimiser:
                opt_st_d_file = latest_model_parameter_file.with_suffix(
                    optimiser_extension)
                if opt_st_d_file.exists():
                    optimiser.load_state_dict(torch.load(str(opt_st_d_file)))
                    print(f"loading previous optimiser state: {opt_st_d_file}")
                return (model, optimiser), True
            else:
                return model, True
    if optimiser:
        return (model, optimiser), False
    return model, False
Example #20
def load_training_model(model: nn.Module,
                        optimizer: optim.Optimizer,
                        model_name="model",
                        path_state=PATH_STATE,
                        device=None):
    if device is None:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    state_file = model_name + STATE_EXT
    state_file = os.path.join(path_state, state_file)
    checkpoint = torch.load(state_file)
    model.load_state_dict(checkpoint[DEFAULT_STATE_DICT_MODEL])
    optimizer.load_state_dict(checkpoint[DEFAULT_STATE_DICT_OPT])
    epoch = checkpoint['epoch']
    model.to(device)
    optimizer.zero_grad()
    model.train()
    return model, optimizer, epoch
Example #21
def resume(
    resume_path: str, model: nn.Module, optimizer: optim.Optimizer
) -> Tuple[int, nn.Module, optim.Optimizer, float]:

    assert os.path.exists(resume_path), "there is no checkpoint at the result folder"

    print("loading checkpoint {}".format(resume_path))
    checkpoint = torch.load(resume_path, map_location=lambda storage, loc: storage)

    begin_epoch = checkpoint["epoch"]
    best_loss = checkpoint["best_loss"]
    model.load_state_dict(checkpoint["state_dict"])

    optimizer.load_state_dict(checkpoint["optimizer"])

    print("training will start from {} epoch".format(begin_epoch))

    return begin_epoch, model, optimizer, best_loss
Example #22
    def load_state_dict(self, state_dict):
        r"""Loads the optimizer state.

        Args:
            state_dict (dict): SWA optimizer state. Should be an object returned
                from a call to `state_dict`.
        """
        swa_state_dict = {
            "state": state_dict["swa_state"],
            "param_groups": state_dict["param_groups"],
        }
        opt_state_dict = {
            "state": state_dict["opt_state"],
            "param_groups": state_dict["param_groups"],
        }
        PT_Optimizer.load_state_dict(self, swa_state_dict)
        self.optimizer.load_state_dict(opt_state_dict)
        self.opt_state = self.optimizer.state
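A matching state_dict() sketch producing the 'swa_state'/'opt_state' keys this loader consumes, modeled on the torchcontrib-style SWA wrapper (an assumption about the surrounding class):

    def state_dict(self):
        # merge the SWA buffers and the inner optimizer state into one dict
        opt_state_dict = self.optimizer.state_dict()
        swa_state_dict = PT_Optimizer.state_dict(self)
        return {
            "swa_state": swa_state_dict["state"],
            "opt_state": opt_state_dict["state"],
            "param_groups": opt_state_dict["param_groups"],
        }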
Example #23
def _restore(
    mdl: nn.Module, optimizer: optim.Optimizer,
    scheduler: optim.lr_scheduler._LRScheduler, ckpt_loc: str
) -> t.Tuple[nn.Module, optim.Optimizer, optim.lr_scheduler._LRScheduler, int,
             float]:
    """Restore model training state

    Args:
        mdl (nn.Module):
            The randomly initialized model
        optimizer (optim.Optimizer):
            The optimizer
        scheduler (optim.lr_scheduler._LRScheduler):
            The scheduler for learning rate
        ckpt_loc (str):
            Location to store model checkpoints

    Returns:
        t.Tuple[nn.Module,
                optim.Optimizer,
                optim.lr_scheduler._LRScheduler,
                int, float]:
            The restored status
    """
    # Restore model checkpoint
    mdl.load_state_dict(torch.load(os.path.join(ckpt_loc, 'mdl.ckpt')))
    optimizer.load_state_dict(
        torch.load(os.path.join(ckpt_loc, 'optimizer.ckpt')))
    scheduler.load_state_dict(
        torch.load(os.path.join(ckpt_loc, 'scheduler.ckpt')))

    # Restore timer and step counter
    with open(os.path.join(ckpt_loc, 'log.out')) as f:
        records = f.readlines()
        if records[-1] != 'Training finished\n':
            final_record = records[-1]
        else:
            final_record = records[-2]
    global_counter, t_final = final_record.split('\t')[:2]
    global_counter = int(global_counter)
    t_final = float(t_final)
    t0 = time.time() - t_final * 60

    return mdl, optimizer, scheduler, global_counter, t0
Example #24
def resume(
    result_path: str,
    model: nn.Module,
    optimizer: optim.Optimizer,
) -> Tuple[int, nn.Module, optim.Optimizer, float]:

    resume_path = os.path.join(result_path, "checkpoint.pth")
    print("loading checkpoint {}".format(resume_path))

    checkpoint = torch.load(resume_path, map_location=lambda storage, loc: storage)

    begin_epoch = checkpoint["epoch"]
    best_loss = checkpoint["best_loss"]
    model.load_state_dict(checkpoint["state_dict"])

    # assumes the optimizer was constructed the same way as the checkpointed one
    optimizer.load_state_dict(checkpoint["optimizer"])

    return begin_epoch, model, optimizer, best_loss
Example #25
    def load_checkpoint_reconstructions(
        self,
        model: nn.Module,
        optimizer: optim.Optimizer = None,
        file_name: str = "pytorch_model.pt",
    ):
        filepath = os.path.join(self.directory, file_name)
        checkpoint = th.load(filepath)
        model.load_state_dict(checkpoint["model_dict"])
        if optimizer is not None:
            optimizer.load_state_dict(checkpoint["optimizer_dict"])

        hyperparam_dict = {
            k: v
            for k, v in checkpoint.items()
            if k not in ("model_dict", "optimizer_dict")
        }

        return model, optimizer, hyperparam_dict
Example #26
    def restore(self, module: nn.Module, optimizer: optim.Optimizer) -> Union[Tuple[None, None], Tuple[int, float]]:
        # getting files in the checkpoints folder
        files = self._get_files_in_dir()

        # if there are no files, there is nothing to restore
        if not files:
            return None, None

        # getting the latest checkpoint file
        checkpoint_file = files[0]

        # loading the data
        data = torch.load(checkpoint_file)

        # loading data to module and optimizer
        module.load_state_dict(data['state_dict'])
        optimizer.load_state_dict(data['optimizer'])

        # returning frame and epsilon
        return data['frame'], data['epsilon']
Example #27
def restore_snapshot(model: torch.nn.Module, optimizer: Optimizer, snapshot_file: str, multi_gpu=False):
    checkpoint = torch.load(snapshot_file)
    start_epoch = checkpoint['epoch'] + 1
    best_loss = checkpoint['loss']

    try:
        model.load_state_dict(checkpoint['model'])
    except RuntimeError:
        model.load_state_dict(checkpoint['model'], strict=False)
        print('Loaded model with strict=False mode')

    try:
        if optimizer is not None:
            optimizer.load_state_dict(checkpoint['optimizer'])
    except Exception:
        print('Optimizer state not loaded')

    train_history = pd.DataFrame.from_dict(checkpoint['train_history'])

    return start_epoch, train_history, best_loss
Example #28
    def _load_checkpoint(self, model: nn.Module, optimizer: optim.Optimizer,
                         state):
        # from .. import __version__

        checkpoint = torch.load(
            os.path.join(self.save_directory, self.filename))

        # checks
        if checkpoint['check_model_class'] != str(model.__class__):
            raise TypeError("Models do not match: %s and %s" %
                            (checkpoint['check_model_class'], model.__class__))
        if checkpoint['check_optimizer_class'] != str(optimizer.__class__):
            raise TypeError(
                "Optimizers do not match: %s and %s" %
                (checkpoint['check_optimizer_class'], optimizer.__class__))

        # checkpoint['check_trainer_version']   # TODO

        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        state.set(checkpoint['trainer_state'])
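A save-side sketch writing the class-check fields this loader validates; _save_checkpoint and state.get() are assumed counterparts, not confirmed API.

    def _save_checkpoint(self, model: nn.Module, optimizer: optim.Optimizer,
                         state):
        torch.save({
            'check_model_class': str(model.__class__),
            'check_optimizer_class': str(optimizer.__class__),
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'trainer_state': state.get(),  # assumed counterpart to state.set()
        }, os.path.join(self.save_directory, self.filename))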
Example #29
    def load(  # type: ignore
        self,
        *,
        timestamp: Optional[Timestamp] = None,
        model: torch.nn.Module,
        optimizer: Optimizer,
    ) -> Optional[Mapping[str, Any]]:
        """
        Load a PyTorch model from a TileDB array.

        :param timestamp: Range of timestamps to load fragments of the array which live
            in the specified time range.
        :param model: A defined PyTorch model.
        :param optimizer: A defined PyTorch optimizer.
        :return: A dictionary with attributes other than model or optimizer state_dict.
        """

        # TODO: Change timestamp when issue in core is resolved
        model_array = tiledb.open(self.uri, ctx=self.ctx, timestamp=timestamp)
        model_array_results = model_array[:]
        schema = model_array.schema

        model_state_dict = pickle.loads(
            model_array_results["model_state_dict"].item(0))
        optimizer_state_dict = pickle.loads(
            model_array_results["optimizer_state_dict"].item(0))

        # Load model's state and optimizer dictionaries
        model.load_state_dict(model_state_dict)
        optimizer.load_state_dict(optimizer_state_dict)

        # Get the rest of the attributes
        out_dict = {}
        for idx in range(schema.nattr):
            attr_name = schema.attr(idx).name
            if attr_name not in ("model_state_dict", "optimizer_state_dict"):
                out_dict[attr_name] = pickle.loads(
                    model_array_results[attr_name].item(0))
        return out_dict
Example #30
    def load_checkpoint(
        self,
        model_G: nn.Module,
        model_D: nn.Module,
        optimizer_G: optim.Optimizer = None,
        optimizer_R: optim.Optimizer = None,
        optimizer_D: optim.Optimizer = None,
        file_name: str = "pytorch_model.pt",
    ):
        filepath = os.path.join(self.directory, file_name)
        checkpoint = th.load(filepath)
        model_G.load_state_dict(checkpoint["model_G_dict"])
        model_D.load_state_dict(checkpoint["model_D_dict"])
        if optimizer_G is not None:
            optimizer_G.load_state_dict(checkpoint["optimizer_G_dict"])
        if optimizer_R is not None:
            optimizer_R.load_state_dict(checkpoint["optimizer_R_dict"])
        if optimizer_D is not None:
            optimizer_D.load_state_dict(checkpoint["optimizer_D_dict"])

        hyperparam_dict = {
            k: v
            for k, v in checkpoint.items()
            if k not in ("model_G_dict", "model_D_dict", "optimizer_G_dict",
                         "optimizer_R_dict", "optimizer_D_dict")
        }

        return model_G, model_D, optimizer_G, optimizer_R, optimizer_D, hyperparam_dict
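A save-side sketch matching the keys above (save_checkpoint is an assumption): extra keyword arguments become the hyperparameter entries that load_checkpoint() returns.

    def save_checkpoint(self, model_G, model_D, optimizer_G, optimizer_R,
                        optimizer_D, file_name="pytorch_model.pt",
                        **hyperparams):
        th.save({
            "model_G_dict": model_G.state_dict(),
            "model_D_dict": model_D.state_dict(),
            "optimizer_G_dict": optimizer_G.state_dict(),
            "optimizer_R_dict": optimizer_R.state_dict(),
            "optimizer_D_dict": optimizer_D.state_dict(),
            **hyperparams,
        }, os.path.join(self.directory, file_name))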