Example 1
def load_optim(optimizer: torch.optim, checkpoint_path: str,
               device: torch.device) -> torch.optim:
    """
    Load optimizer to continue training
        Args:
            optimizer      : initialized optimizer
            checkpoint_path: path to the checkpoint
            device         : device to send optimizer to (must be the same as in the model)
            
        Note: must be called after initializing the model    

        Output: optimizer with the loaded state
    """
    checkpoint = torch.load(checkpoint_path)
    optimizer.load_state_dict(checkpoint['optimizer'])
    for state in optimizer.state.values():
        for k, v in state.items():
            if torch.is_tensor(v):
                state[k] = v.to(device)

    for param_group in optimizer.param_groups:
        print('learning_rate: {}'.format(param_group['lr']))

    print('Loaded optimizer {} state from {}'.format(optimizer,
                                                     checkpoint_path))

    return optimizer
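
A minimal usage sketch (not from the original source): it assumes the checkpoint was saved with an 'optimizer' key and that the model has already been built and moved to device, as the docstring's note requires. The model, optimizer, and 'checkpoint.pth' path below are placeholders.

import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = torch.nn.Linear(10, 2).to(device)                    # stand-in for the real model
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
optimizer = load_optim(optimizer, 'checkpoint.pth', device)  # placeholder path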
Example 2
def load_model_checkpoint(model: torch.nn.Module,
                          filename: str,
                          inference: bool,
                          map_location=None,
                          optimizer: torch.optim = None):
    """
    Load a model checkpoint
    :param model:
    :param filename:
    :param inference:
    :param optimizer:
    :return:
    """
    checkpoint = torch.load(filename, map_location=map_location)

    if optimizer:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    # epoch = checkpoint['epoch']
    # loss = checkpoint['loss']

    if inference:
        model.eval()
    else:
        model.train()
    model.load_state_dict(checkpoint['model_state_dict'])
    return model
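
A possible call site for this helper (a sketch, not from the original project). MyNet is a hypothetical module class and 'best_model.pt' a placeholder path; the checkpoint is expected to contain 'model_state_dict' and, optionally, 'optimizer_state_dict'.

model = MyNet()  # hypothetical model class, stands in for the real network
model = load_model_checkpoint(model, 'best_model.pt', inference=True,
                              map_location='cpu')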
Example 3
def loadCheckpoint(checkpoint_path: str, model: nn.Module, optimizer: optim, scheduler: optim.lr_scheduler.MultiStepLR):
    """
    Load a training checkpoint from a .pth file

    Parameters
    ----------
    checkpoint_path : str
        path to the saved checkpoint (.pth) file

    model, optimizer, scheduler :
        the model, optimizer and scheduler whose saved states will be restored

    Returns
    -------
    model, optimizer, resume_epoch, resume_iteration, scheduler
    """
    state = torch.load(checkpoint_path)

    resume_epoch = state['epoch']
    resume_iteration = state['iteration']
    model.load_state_dict(state['state_dict'])
    optimizer.load_state_dict(state['optimizer'])
    scheduler.load_state_dict(state['scheduler'])

    return model, optimizer, resume_epoch, resume_iteration, scheduler
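
For context, a checkpoint this loader can read would be written with the same dictionary keys. The helper below is a hypothetical sketch inferred from those keys, not code from the original project.

import torch

def save_checkpoint(checkpoint_path, model, optimizer, scheduler, epoch, iteration):
    # Write the dictionary layout that loadCheckpoint expects to read back.
    torch.save({
        'epoch': epoch,
        'iteration': iteration,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict(),
    }, checkpoint_path)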
Example 4
def loadCheckpoint(checkpoint_path: str, model: nn.Module, optimizer: optim,
                   scheduler: optim.lr_scheduler.MultiStepLR):
    state = torch.load(checkpoint_path)
    resume_epoch = state['epoch']
    model.load_state_dict(state['state_dict'])
    optimizer.load_state_dict(state['optimizer'])
    scheduler.load_state_dict(state['scheduler'])

    return model, optimizer, resume_epoch, scheduler
Example 5
def load_checkpoint(self, model: torch.nn.Module, optimizer: torch.optim):
    # Requires `from collections import OrderedDict` at module level.
    state = torch.load(self.state_dir)
    try:
        model.load_state_dict(state['model_state_dict'])
    except RuntimeError:
        # The checkpoint was saved from a DataParallel-wrapped model:
        # strip the leading 'module.' (7 characters) from every key.
        new_state_dict = OrderedDict()
        for k, v in state['model_state_dict'].items():
            name = k[7:]
            new_state_dict[name] = v
        model.load_state_dict(new_state_dict)
    optimizer.load_state_dict(state['optimizer_state_dict'])
    return model, optimizer
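
The except branch handles checkpoints saved from a model wrapped in torch.nn.DataParallel, which stores the network under a `module` attribute and therefore prefixes every state-dict key with 'module.' (7 characters). A small illustration of the effect, not from the original code:

import torch

net = torch.nn.Linear(4, 2)
wrapped = torch.nn.DataParallel(net)
print(list(net.state_dict().keys()))      # ['weight', 'bias']
print(list(wrapped.state_dict().keys()))  # ['module.weight', 'module.bias']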
Example 6
def load_model_checkpoint(model: torch.nn.Module,
                          filename: str,
                          inference: bool,
                          map_location=None,
                          optimizer: torch.optim = None):
    checkpoint = torch.load(filename, map_location=map_location)
    if optimizer:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    if inference:
        model.eval()
    else:
        model.train()
    model.load_state_dict(checkpoint['model_state_dict'])
    return model
Example 7
def find_lr(dataloader,
            model,
            optimizer: torch.optim,
            criterion,
            device,
            num_steps,
            lr_min: float = 1e-7,
            lr_max: float = 10,
            beta: float = 0.98):
    model.to(device)
    # Deep copies are needed to restore the states afterwards: a shallow .copy()
    # keeps references to tensors that optimizer.step() updates in place.
    # (Requires the standard-library `copy` module.)
    optim_dict = copy.deepcopy(optimizer.state_dict())
    optimizer.param_groups[0]['lr'] = lr_min
    #     num_steps = len(dataloader) - 1
    scheduler = LrSchedulerFinder(optimizer, lr_min, lr_max, num_steps)
    model_dict = copy.deepcopy(model.state_dict())
    losses = list()
    lrs = list()
    avg_loss = 0
    best_loss = 0
    for idx_batch, (data, label) in tqdm(enumerate(dataloader, 1),
                                         total=num_steps):
        if idx_batch == num_steps:
            break
        y, kl = model(data.to(device))
        loss = criterion(y, label, kl, 0)
        if np.isnan(loss.item()):
            print(loss.item())
        # Exponentially smoothed, bias-corrected loss.
        avg_loss = beta * avg_loss + (1 - beta) * loss.item()
        smooth_loss = avg_loss / (1 - beta**idx_batch)
        # Stop early once the loss explodes.
        if idx_batch > 1 and smooth_loss > 4 * best_loss:
            break
        if smooth_loss < best_loss or idx_batch == 1:
            best_loss = smooth_loss
        losses.append(smooth_loss)
        lrs.append(scheduler.get_lr()[0])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()
    # Restore the original weights and optimizer state.
    model.load_state_dict(model_dict)
    optimizer.load_state_dict(optim_dict)
    return np.array(lrs), np.array(losses)
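
The returned arrays are typically plotted with the learning rate on a log scale so a value just below the point where the loss starts to climb can be picked. The snippet below is only an illustrative sketch and assumes dataloader, model, optimizer, criterion and device already exist from the training setup.

import matplotlib.pyplot as plt

lrs, losses = find_lr(dataloader, model, optimizer, criterion, device, num_steps=100)
plt.plot(lrs, losses)
plt.xscale('log')            # learning rates span several orders of magnitude
plt.xlabel('learning rate')
plt.ylabel('smoothed loss')
plt.show()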
Example 8
def resume_checkpoint(resume_path: str, model: nn.Module,
                      optimizer: module_optimizer,
                      config: dict) -> (nn.Module, module_optimizer, int):
    """ Resume from saved checkpoint. """
    if not resume_path:
        return model, optimizer, 0

    log.info(f'Loading checkpoint: {resume_path}')
    checkpoint = torch.load(resume_path)
    model.load_state_dict(checkpoint['state_dict'])

    # load optimizer state from checkpoint only when optimizer type is not changed.
    if checkpoint['config']['optimizer']['type'] != config['optimizer']['type']:
        log.warning(
            "Warning: Optimizer type given in config file is different from "
            "that of checkpoint. Optimizer parameters not being resumed.")
    else:
        optimizer.load_state_dict(checkpoint['optimizer'])

    log.info(f'Checkpoint "{resume_path}" loaded')
    return model, optimizer, checkpoint['epoch']
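
Only config['optimizer']['type'] is consulted when deciding whether to restore the optimizer state, so a minimal call could look like the following hypothetical sketch (the config value and checkpoint path are placeholders):

config = {'optimizer': {'type': 'Adam'}}   # only the 'type' field is compared
model, optimizer, start_epoch = resume_checkpoint('saved/checkpoint.pth',
                                                  model, optimizer, config)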
Example 9
def run(data_loader: BaseDataLoader, encoder: Encoder, criterion, optimizer: optim,
        scheduler: optim.lr_scheduler, similarity_measure: Similarity, save_model,
        args: argparse.Namespace):
    encoder.to(device)
    best_accs = {"encode_acc": float('-inf'), "pivot_acc": float('-inf')}
    last_update = 0
    dev_arg_dict = {
        "use_mid": args.use_mid,
        "topk": args.val_topk,
        "trg_encoding_num": args.trg_encoding_num,
        "mid_encoding_num": args.mid_encoding_num
    }
    # lr_decay = scheduler is not None
    # if lr_decay:
    #     print("[INFO] using learning rate decay")
    for ep in range(args.max_epoch):
        encoder.train()
        train_loss = 0.0
        start_time = time.time()
        # if not args.mega:
        train_batches = data_loader.create_batches("train")
        # else:
        #     if ep <= 30:
        #         train_batches = data_loader.create_batches("train")
        #     else:
        #         train_batches = data_loader.create_megabatch(encoder)
        batch_num = 0
        t = 0
        for idx, batch in enumerate(train_batches):
            optimizer.zero_grad()
            cur_loss = calc_batch_loss(encoder, criterion, batch, args.mid_proportion, args.trg_encoding_num, args.mid_encoding_num)
            train_loss += cur_loss.item()
            cur_loss.backward()
            # optimizer.step()

            for p in list(filter(lambda p: p.grad is not None, encoder.parameters())):
                t += p.grad.data.norm(2).item()

            torch.nn.utils.clip_grad_norm_(encoder.parameters(), max_norm=5)
            optimizer.step()

            if encoder.name == "bilstm":
                # set all but forget gate bias to 0
                reset_bias(encoder.src_lstm)
                reset_bias(encoder.trg_lstm)
                # pass
            batch_num += 1
        print("[INFO] epoch {:d}: train loss={:.8f}, time={:.2f}".format(ep, train_loss / batch_num,
                                                                         time.time()-start_time))
        # print(t)

        if (ep + 1) % EPOCH_CHECK == 0:
            with torch.no_grad():
                encoder.eval()
                # eval
                train_batches = data_loader.create_batches("train")
                dev_batches = data_loader.create_batches("dev")
                start_time = time.time()

                recall, tot = eval_data(encoder, train_batches, dev_batches, similarity_measure, dev_arg_dict)
                dev_pivot_acc = recall[0] / float(tot)
                dev_encode_acc = recall[1] / float(tot)
                if dev_encode_acc > best_accs["encode_acc"]:
                    best_accs["encode_acc"] = dev_encode_acc
                    best_accs["pivot_acc"] = dev_pivot_acc
                    last_update = ep + 1
                    save_model(encoder, ep + 1, train_loss / batch_num, optimizer, args.model_path + "_" + "best" + ".tar")
                save_model(encoder, ep + 1, train_loss / batch_num, optimizer, args.model_path + "_" + "last" + ".tar")
                print("[INFO] epoch {:d}: encoding/pivoting dev acc={:.4f}/{:.4f}, time={:.2f}".format(
                                                                                            ep, dev_encode_acc, dev_pivot_acc,
                                                                                            time.time()-start_time))
                if args.lr_decay and ep + 1 - last_update > UPDATE_PATIENT:
                    new_lr = optimizer.param_groups[0]['lr'] * args.lr_scaler
                    best_info  = torch.load(args.model_path + "_" + "best" + ".tar")
                    encoder.load_state_dict(best_info["model_state_dict"])
                    optimizer.load_state_dict(best_info["optimizer_state_dict"])
                    optimizer.param_groups[0]['lr'] = new_lr
                    print("[INFO] reload best model ..")

                if ep + 1 - last_update > PATIENT:
                    print("[FINAL] in epoch {}, the best develop encoding/pivoting accuracy = {:.4f}/{:.4f}".format(ep + 1,
                                                                                                                    best_accs["encode_acc"],
                                                                                                                    best_accs["pivot_acc"]))
                    break
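
The save_model callback is invoked above as save_model(encoder, epoch, loss, optimizer, path), and the reload branch expects the resulting .tar file to contain 'model_state_dict' and 'optimizer_state_dict' keys. A compatible implementation might look like the following sketch, which is an assumption rather than the project's actual code:

import torch

def save_model(model, epoch, loss, optimizer, path):
    # Write a .tar checkpoint with the keys the reload branch reads back.
    torch.save({
        'epoch': epoch,
        'loss': loss,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }, path)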