Example #1
def Train_with_early_stop(loss_func, opt, model, train_dl, val_dl, target_one_hot=False):   
    
    model.train()
    
    # Accumulate loss values for plotting
    loss_values = []
    
    # Initialize early stopping
    es = EarlyStopping(patience=1)
    
    epoch = 0
    
    # Run the training loop until the early-stopping criterion is met
    while True:

        # Print epoch
        print(f'Starting epoch {epoch+1}')

        # Set current loss value
        current_loss = 0.0

        # Iterate over the DataLoader for training data
        for i, (inputs, targets) in enumerate(train_dl, 0):

            # Zero the gradients
            opt.zero_grad()

            # Perform forward pass
            outputs = model(inputs.float())

            # Turn target batch to one-hot for MSELoss
            if target_one_hot:
                targets = F.one_hot(targets, num_classes=10).float()
            
            # Compute loss
            loss = loss_func(outputs.float(), targets)

            # Perform backward pass
            loss.backward()

            # Perform optimization
            opt.step()

            # Print statistics
            current_loss += loss.item()
            
            if i % 100 == 99:
                # print(f'Loss after {i + 1} mini-batches: {current_loss / 100}')
                loss_values.append(current_loss / 100)  # average over the last 100 mini-batches
                current_loss = 0.0
                
        epoch += 1
        # evaluation on the dev set (i.e., holdout from training)
        metric = eval(model, val_dl, loss_func, target_one_hot)  
        if es.step(metric): break  # early stop criterion is met
            

    return loss_values
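Note: the EarlyStopping helper used in Example #1 is not shown in the snippet. Below is a minimal sketch of one way it could look, assuming patience counts epochs without improvement, that step(metric) returns True when training should stop, and that a lower metric (e.g. a validation loss) is better; min_delta is a hypothetical extra parameter.

class EarlyStopping:
    def __init__(self, patience=1, min_delta=0.0):
        self.patience = patience      # epochs to tolerate without improvement
        self.min_delta = min_delta    # minimum change that counts as improvement
        self.best = float('inf')
        self.num_bad_epochs = 0

    def step(self, metric):
        # Return True once the metric has not improved for more than `patience` epochs
        if metric < self.best - self.min_delta:
            self.best = metric
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1
        return self.num_bad_epochs > self.patience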
Example #2
          'loss_train: {:.4f}'.format(loss_train.item()),
          'acc_train: {:.4f}'.format(acc_train.item()),
          'loss_val: {:.4f}'.format(loss_val.item()),
          'acc_val: {:.4f}'.format(acc_val.item()))

    return loss_val.item(), acc_val.item()


def test():
    model.eval()
    output = model(features, adj)
    loss_test = F.nll_loss(output[idx_test], labels[idx_test])
    acc_test = accuracy(output[idx_test], labels[idx_test])
    print("Test set results:", "loss= {:.4f}".format(loss_test.item()),
          "accuracy= {:.4f}".format(acc_test.item()))


stopping_args = Stop_args(patience=args.patience, max_epochs=args.epochs)
early_stopping = EarlyStopping(model, **stopping_args)
for epoch in range(args.epochs):
    loss_val, acc_val = train(epoch)
    if early_stopping.check([acc_val, loss_val], epoch):
        break

print("Optimization Finished!")

# Restore best model
print('Loading {}th epoch'.format(early_stopping.best_epoch))
model.load_state_dict(early_stopping.best_state)
test()
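Note: Example #2 relies on Stop_args and an EarlyStopping class that caches the best validation state. A hedged sketch of one plausible implementation is given below, assuming check([acc_val, loss_val], epoch) returns True when neither metric has improved for patience consecutive epochs (or max_epochs is reached) and that the best weights are kept in best_state.

import copy

def Stop_args(patience=100, max_epochs=1000):
    return {'patience': patience, 'max_epochs': max_epochs}

class EarlyStopping:
    def __init__(self, model, patience=100, max_epochs=1000):
        self.model = model
        self.patience = patience
        self.max_epochs = max_epochs
        self.best_acc = float('-inf')
        self.best_loss = float('inf')
        self.best_epoch = 0
        self.best_state = copy.deepcopy(model.state_dict())
        self.wait = 0

    def check(self, values, epoch):
        acc, loss = values
        if acc > self.best_acc or loss < self.best_loss:
            # Either metric improved: remember this epoch and its weights
            self.best_acc = max(acc, self.best_acc)
            self.best_loss = min(loss, self.best_loss)
            self.best_epoch = epoch
            self.best_state = copy.deepcopy(self.model.state_dict())
            self.wait = 0
        else:
            self.wait += 1
        return self.wait >= self.patience or epoch + 1 >= self.max_epochs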
Example #3
def main(args):

    ts = time.strftime('%Y-%b-%d-%H:%M:%S', time.gmtime())

    splits = ['train', 'valid'] + (['test'] if args.test else [])

    datasets = OrderedDict()
    for split in splits:
        datasets[split] = PTB(data_dir=args.data_dir,
                              split=split,
                              create_data=args.create_data,
                              max_sequence_length=args.max_sequence_length,
                              min_occ=args.min_occ)

    model = SentenceRNN(vocab_size=datasets['train'].vocab_size,
                        sos_idx=datasets['train'].sos_idx,
                        eos_idx=datasets['train'].eos_idx,
                        pad_idx=datasets['train'].pad_idx,
                        unk_idx=datasets['train'].unk_idx,
                        max_sequence_length=args.max_sequence_length,
                        embedding_size=args.embedding_size,
                        rnn_type=args.rnn_type,
                        hidden_size=args.hidden_size,
                        word_dropout=args.word_dropout,
                        embedding_dropout=args.embedding_dropout,
                        latent_size=args.latent_size,
                        num_layers=args.num_layers,
                        bidirectional=args.bidirectional)

    if torch.cuda.is_available():
        model = model.cuda()

    print(model)

    if args.tensorboard_logging:
        writer = SummaryWriter(
            os.path.join(args.logdir, experiment_name_rnn(args, ts)))
        writer.add_text("model", str(model))
        writer.add_text("args", str(args))
        writer.add_text("ts", ts)

    save_model_path = os.path.join(args.save_model_path, ts)
    os.makedirs(save_model_path)

    NLL = torch.nn.NLLLoss(reduction='sum',  # size_average=False is deprecated
                           ignore_index=datasets['train'].pad_idx)

    def loss_fn(logp, target, length):

        # cut-off unnecessary padding from target, and flatten
        target = target[:, :torch.max(length).item()].contiguous().view(-1)
        logp = logp.view(-1, logp.size(2))

        # Negative Log Likelihood
        NLL_loss = NLL(logp, target)

        return NLL_loss

    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.Tensor
    step = 0
    early_stopping = EarlyStopping(history=10)
    for epoch in range(args.epochs):
        early_stopping_flag = False
        for split in splits:

            data_loader = DataLoader(dataset=datasets[split],
                                     batch_size=args.batch_size,
                                     shuffle=split == 'train',
                                     num_workers=cpu_count(),
                                     pin_memory=torch.cuda.is_available())

            # tracker = defaultdict(tensor)
            tracker = defaultdict(list)

            # Enable/Disable Dropout
            if split == 'train':
                model.train()
            else:
                model.eval()

            for iteration, batch in enumerate(data_loader):

                batch_size = batch['input'].size(0)

                for k, v in batch.items():
                    if torch.is_tensor(v):
                        batch[k] = to_var(v)

                # Forward pass
                logp = model(batch['input'], batch['length'])

                # loss calculation
                NLL_loss = loss_fn(logp, batch['target'], batch['length'])
                loss = (NLL_loss) / batch_size

                # backward + optimization
                if split == 'train':
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    step += 1

                # bookkeeping
                # tracker['Loss'] = torch.cat((tracker['Loss'], loss.data))
                tracker['Loss'].append(loss.item())

                if args.tensorboard_logging:
                    writer.add_scalar("%s/NLL_Loss" % split.upper(),
                                      NLL_loss.item() / batch_size,
                                      epoch * len(data_loader) + iteration)

                if iteration % args.print_every == 0 or iteration + 1 == len(
                        data_loader):
                    logger.info("%s Batch %04d/%i, Loss %9.4f" %
                                (split.upper(), iteration,
                                 len(data_loader) - 1, loss.item()))
            mean_loss = sum(tracker['Loss']) / len(tracker['Loss'])

            logger.info("%s Epoch %02d/%i, Mean Loss %9.4f" %
                        (split.upper(), epoch, args.epochs, mean_loss))

            if args.tensorboard_logging:
                writer.add_scalar("%s-Epoch/Loss" % split.upper(), mean_loss,
                                  epoch)

            if split == 'valid':
                if (args.early_stopping):
                    if (early_stopping.check(mean_loss)):
                        early_stopping_flag = True

            # save checkpoint
            if split == 'train':
                checkpoint_path = os.path.join(save_model_path,
                                               "E%i.pytorch" % (epoch))
                torch.save(model.state_dict(), checkpoint_path)
                logger.info("Model saved at %s" % checkpoint_path)

        if (early_stopping_flag):
            print("Early stopping trigerred. Training stopped...")
            break
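Note: Examples #3 and #5 construct EarlyStopping(history=10) and call check(mean_loss) on the epoch-mean validation loss. A hedged sketch of a window-based variant matching this interface, assuming training stops once the loss has not improved within the last history checks:

from collections import deque

class EarlyStopping:
    def __init__(self, history=10):
        self.history = history
        self.losses = deque(maxlen=history)

    def check(self, value):
        # Record the latest validation loss and stop once the oldest loss in the
        # window is still the best, i.e. no improvement in the last `history` checks.
        self.losses.append(value)
        if len(self.losses) < self.history:
            return False
        return min(self.losses) == self.losses[0]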
Example #4
def train():
    train_set=make_dataset('C1-P1_Train')
    train_loader=DataLoader(dataset=train_set,batch_size=opt.train_batch_size,shuffle=True,num_workers=opt.num_workers)

    dev_set=make_dataset('C1-P1_Dev')
    dev_loader=DataLoader(dataset=dev_set,batch_size=opt.dev_batch_size,shuffle=True,num_workers=opt.num_workers)

    net=get_net(opt.model)
    if opt.model[0:7] == 'resnest':
        # ResNeSt factory functions take the number of classes
        model=net(opt.num_classes)
    else:
        model=net
    model=model.cuda(opt.cuda_devices)

    best_model_params_acc = copy.deepcopy(model.state_dict())
    best_model_params_loss = copy.deepcopy(model.state_dict())

    best_acc=0.0
    best_loss = float('inf')

    training_loss_list = []
    training_acc_list = []
    dev_loss_list = []
    dev_acc_list = []

    criterion = nn.CrossEntropyLoss()
    
    # optimizer = adabound.AdaBound(model.parameters(), lr=opt.lr, final_lr=0.1)
    # optimizer = torch.optim.Adam(params=model.parameters(), lr=opt.lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.0001, amsgrad=True)
    optimizer = torch.optim.SGD(model.parameters(), lr=opt.lr, momentum=0.9, weight_decay=5e-4, nesterov=True)
    # optimizer = torch.optim.RMSprop(params=model.parameters(), lr=opt.lr, alpha=0.99, eps=1e-08, weight_decay=5e-4, momentum=0.9, centered=False)
    # optimizer = torch.optim.AdamW(params=model.parameters(), lr=opt.lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=5e-4, amsgrad=True)
    step = 0

    # scheduler = StepLR(optimizer, step_size=10, gamma=0.5, last_epoch=-1)
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=4, verbose=True, cooldown=1)
    record=open('record.txt','w')

    early_stopping = EarlyStopping(patience=20, verbose=True)

    for epoch in range(opt.epochs):
        print(f'Epoch: {epoch+1}/{opt.epochs}')
        print('-'*len(f'Epoch: {epoch+1}/{opt.epochs}'))

        training_loss = 0.0
        training_corrects = 0

        model.train()

        for i, (inputs,labels) in enumerate(tqdm(train_loader)):
            inputs=Variable(inputs.cuda(opt.cuda_devices))
            labels=Variable(labels.cuda(opt.cuda_devices))

            optimizer.zero_grad()
            outputs=model(inputs)

            _, preds = torch.max(outputs.data, 1)
            loss = criterion(outputs, labels)
            
            loss.backward()
            optimizer.step()

            training_loss += loss.item() * inputs.size(0)
            training_corrects += torch.sum(preds == labels.data)
        
        training_loss = training_loss / len(train_set)
        training_acc = float(training_corrects) / len(train_set)

        training_loss_list.append(training_loss)
        training_acc_list.append(training_acc)

        print(f'Training loss: {training_loss:.4f}\taccuracy: {training_acc:.4f}')

        model.eval()

        dev_loss=0.0
        dev_corrects=0

        for i,(inputs,labels) in enumerate(tqdm(dev_loader)):
            inputs=Variable(inputs.cuda(opt.cuda_devices))
            labels=Variable(labels.cuda(opt.cuda_devices))

            outputs=model(inputs)

            _,preds=torch.max(outputs.data,1)
            loss=criterion(outputs,labels)

            dev_loss+=loss.item()*inputs.size(0)
            dev_corrects+=torch.sum(preds==labels.data)

        dev_loss=dev_loss/len(dev_set)
        dev_acc=float(dev_corrects)/len(dev_set)

        dev_loss_list.append(dev_loss)
        dev_acc_list.append(dev_acc)

        print(f'Dev loss: {dev_loss:.4f}\taccuracy: {dev_acc:.4f}\n')

        scheduler.step(dev_loss)
        early_stopping(dev_loss, model)
        if early_stopping.early_stop:
            print("Early Stopping")
            break

        if dev_acc > best_acc:
            best_acc = dev_acc
            best_acc_dev_loss = dev_loss

            best_train_acc=training_acc
            best_train_loss=training_loss 

            best_model_params_acc = copy.deepcopy(model.state_dict())
        
        if dev_loss < best_loss:
            the_acc = dev_acc
            best_loss = dev_loss

            the_train_acc = training_acc
            the_train_loss =  training_loss

            best_model_params_loss = copy.deepcopy(model.state_dict())

        if (epoch+1)%50==0:
            model.load_state_dict(best_model_params_loss)
            weight_path=Path(opt.checkpoint_dir).joinpath(f'model-{epoch+1}epoch-{best_loss:.02f}-loss-{the_acc:.02f}-acc.pth')
            torch.save(model,str(weight_path))

            model.load_state_dict(best_model_params_acc)
            weight_path=Path(opt.checkpoint_dir).joinpath(f'model-{epoch+1}epoch-{best_acc:.02f}-acc.pth')
            torch.save(model,str(weight_path))

            record.write(f'{epoch+1}\n')
            record.write(f'Best training loss: {best_train_loss:.4f}\tBest training accuracy: {best_train_acc:.4f}\n')
            record.write(f'Best dev loss: {best_acc_dev_loss:.4f}\tBest dev accuracy: {best_acc:.4f}\n\n')
            visualization(training_loss_list, training_acc_list, dev_loss_list, dev_acc_list, epoch+1)


        """
        if (epoch+1) == 100:
            scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=4, verbose=True, cooldown=1)
            early_stopping = EarlyStopping(patience=10, verbose=True)"""
        
        """
        if (epoch+1) >= 50:
            early_stopping(dev_loss,model)
            if early_stopping.early_stop:
                print("Early Stoppping")
                break"""

    print('Based on best accuracy:')
    print(f'Best training loss: {best_train_loss:.4f}\t Best training accuracy: {best_train_acc:.4f}')
    print(f'Best dev loss: { best_acc_dev_loss:.4f}\t Best dev accuracy: {best_acc:.4f}\n')
        
    model.load_state_dict(best_model_params_acc)
    weight_path=Path(opt.checkpoint_dir).joinpath(f'model-{best_acc:.02f}-best_acc.pth')
    torch.save(model, str(weight_path))

    print('Based on best loss:')
    print(f'Best training loss: {the_train_loss:.4f}\t Best training accuracy: {the_train_acc:.4f}')
    print(f'Best dev loss: {best_loss:.4f}\t Best dev accuracy: {the_acc:.4f}\n')
        
    model.load_state_dict(best_model_params_loss)
    weight_path=Path(opt.checkpoint_dir).joinpath(f'model-{best_loss:.02f}-best_loss-{the_acc:.02f}-acc.pth')
    torch.save(model, str(weight_path))

    visualization(training_loss_list, training_acc_list, dev_loss_list, dev_acc_list, epoch+1)
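Note: Example #4 calls the stopper as early_stopping(dev_loss, model) and reads its early_stop flag, which matches the widely copied "pytorchtools"-style helper. A hedged sketch of that interface follows; the delta argument and checkpoint path are illustrative assumptions.

import numpy as np
import torch

class EarlyStopping:
    def __init__(self, patience=20, verbose=False, delta=0.0, path='checkpoint.pt'):
        self.patience = patience
        self.verbose = verbose
        self.delta = delta
        self.path = path
        self.counter = 0
        self.best_loss = np.inf
        self.early_stop = False

    def __call__(self, val_loss, model):
        if val_loss < self.best_loss - self.delta:
            # Improvement: save a checkpoint and reset the patience counter
            if self.verbose:
                print(f'Validation loss improved ({self.best_loss:.6f} --> {val_loss:.6f}). Saving model ...')
            self.best_loss = val_loss
            torch.save(model.state_dict(), self.path)
            self.counter = 0
        else:
            self.counter += 1
            if self.verbose:
                print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True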
Example #5
def main(args):

    ts = time.strftime('%Y-%b-%d-%H:%M:%S', time.gmtime())

    splits = ['train', 'valid'] + (['test'] if args.test else [])

    datasets = OrderedDict()
    for split in splits:
        datasets[split] = PTB(data_dir=args.data_dir,
                              split=split,
                              create_data=args.create_data,
                              max_sequence_length=args.max_sequence_length,
                              min_occ=args.min_occ)

    model = SentenceVAE(vocab_size=datasets['train'].vocab_size,
                        sos_idx=datasets['train'].sos_idx,
                        eos_idx=datasets['train'].eos_idx,
                        pad_idx=datasets['train'].pad_idx,
                        unk_idx=datasets['train'].unk_idx,
                        max_sequence_length=args.max_sequence_length,
                        embedding_size=args.embedding_size,
                        rnn_type=args.rnn_type,
                        hidden_size=args.hidden_size,
                        word_dropout=args.word_dropout,
                        embedding_dropout=args.embedding_dropout,
                        latent_size=args.latent_size,
                        num_layers=args.num_layers,
                        bidirectional=args.bidirectional)

    if torch.cuda.is_available():
        model = model.cuda()

    print(model)

    if args.tensorboard_logging:
        writer = SummaryWriter(
            os.path.join(args.logdir, experiment_name(args, ts)))
        writer.add_text("model", str(model))
        writer.add_text("args", str(args))
        writer.add_text("ts", ts)

    save_model_path = os.path.join(args.save_model_path, ts)
    os.makedirs(save_model_path)

    def kl_anneal_function(anneal_function, step, x1, x2):
        if anneal_function == 'identity':
            return 1
        elif anneal_function == 'linear':
            return min(1, step / x1)
        elif anneal_function == 'logistic':
            return float(1 / (1 + np.exp(-x2 * (step - x1))))
        elif anneal_function == 'cyclic_log':
            return float(1 / (1 + np.exp(-x2 * ((step % (3 * x1)) - x1))))
        elif anneal_function == 'cyclic_lin':
            return min(1, (step % (3 * x1)) / x1)

    ReconLoss = torch.nn.NLLLoss(reduction='sum',  # size_average=False is deprecated
                                 ignore_index=datasets['train'].pad_idx)

    def loss_fn(logp, target, length, mean, logv, anneal_function, step, x1,
                x2):

        # cut-off unnecessary padding from target, and flatten
        target = target[:, :torch.max(length).item()].contiguous().view(-1)
        logp = logp.view(-1, logp.size(2))

        # Negative Log Likelihood
        recon_loss = ReconLoss(logp, target)

        # KL Divergence
        KL_loss = -0.5 * torch.sum(1 + logv - mean.pow(2) - logv.exp())
        KL_weight = kl_anneal_function(anneal_function, step, x1, x2)

        return recon_loss, KL_loss, KL_weight

    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.Tensor
    step = 0

    early_stopping = EarlyStopping(history=10)
    for epoch in range(args.epochs):

        early_stopping_flag = False
        for split in splits:

            data_loader = DataLoader(dataset=datasets[split],
                                     batch_size=args.batch_size,
                                     shuffle=split == 'train',
                                     num_workers=cpu_count(),
                                     pin_memory=torch.cuda.is_available())

            # tracker = defaultdict(tensor)
            tracker = defaultdict(list)

            # Enable/Disable Dropout
            if split == 'train':
                model.train()
            else:
                model.eval()

            for iteration, batch in enumerate(data_loader):

                batch_size = batch['input'].size(0)

                for k, v in batch.items():
                    if torch.is_tensor(v):
                        batch[k] = to_var(v)

                # Forward pass
                logp, mean, logv, z = model(batch['input'], batch['length'])

                # loss calculation
                recon_loss, KL_loss, KL_weight = loss_fn(
                    logp, batch['target'], batch['length'], mean, logv,
                    args.anneal_function, step, args.x1, args.x2)

                if split == 'train':
                    loss = (recon_loss + KL_weight * KL_loss) / batch_size
                else:
                    # report complete elbo when validation
                    loss = (recon_loss + KL_loss) / batch_size

                # backward + optimization
                if split == 'train':
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    step += 1

                # bookkeeping
                tracker['negELBO'].append(loss.item())

                if args.tensorboard_logging:
                    writer.add_scalar("%s/Negative_ELBO" % split.upper(),
                                      loss.item(),
                                      epoch * len(data_loader) + iteration)
                    writer.add_scalar("%s/Recon_Loss" % split.upper(),
                                      recon_loss.item() / batch_size,
                                      epoch * len(data_loader) + iteration)
                    writer.add_scalar("%s/KL_Loss" % split.upper(),
                                      KL_loss.item() / batch_size,
                                      epoch * len(data_loader) + iteration)
                    writer.add_scalar("%s/KL_Weight" % split.upper(),
                                      KL_weight,
                                      epoch * len(data_loader) + iteration)

                if iteration % args.print_every == 0 or iteration + 1 == len(
                        data_loader):
                    # print(step)
                    # logger.info("Step = %d"%step)
                    logger.info(
                        "%s Batch %04d/%i, Loss %9.4f, Recon-Loss %9.4f, KL-Loss %9.4f, KL-Weight %6.3f"
                        % (split.upper(), iteration, len(data_loader) - 1,
                           loss.item(), recon_loss.item() / batch_size,
                           KL_loss.item() / batch_size, KL_weight))

                if split == 'valid':
                    if 'target_sents' not in tracker:
                        tracker['target_sents'] = list()
                    tracker['target_sents'] += idx2word(
                        batch['target'].data,
                        i2w=datasets['train'].get_i2w(),
                        pad_idx=datasets['train'].pad_idx)
                    # tracker['z'] = torch.cat((tracker['z'], z.data), dim=0)
                    # print(z.data.shape)
                    tracker['z'].append(z.data.tolist())
            mean_loss = sum(tracker['negELBO']) / len(tracker['negELBO'])

            logger.info("%s Epoch %02d/%i, Mean Negative ELBO %9.4f" %
                        (split.upper(), epoch, args.epochs, mean_loss))
            # print(mean_loss)

            if args.tensorboard_logging:
                writer.add_scalar("%s-Epoch/NegELBO" % split.upper(),
                                  mean_loss, epoch)

            # save a dump of all sentences and the encoded latent space
            if split == 'valid':
                dump = {
                    'target_sents': tracker['target_sents'],
                    'z': tracker['z']
                }
                if not os.path.exists(os.path.join('dumps', ts)):
                    os.makedirs('dumps/' + ts)
                with open(
                        os.path.join('dumps/' + ts +
                                     '/valid_E%i.json' % epoch),
                        'w') as dump_file:
                    json.dump(dump, dump_file)
                if (args.early_stopping):
                    if (early_stopping.check(mean_loss)):
                        early_stopping_flag = True

            # save checkpoint
            if split == 'train':
                checkpoint_path = os.path.join(save_model_path,
                                               "E%i.pytorch" % (epoch))
                torch.save(model.state_dict(), checkpoint_path)
                logger.info("Model saved at %s" % checkpoint_path)

        if (early_stopping_flag):
            print("Early stopping trigerred. Training stopped...")
            break
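Note: in Example #5 the kl_anneal_function scales the KL term of the ELBO during training. A small standalone check of the 'cyclic_lin' branch is given below, using x1 = 2500 purely as an illustrative value (in the example x1 comes from args.x1): the weight ramps linearly from 0 to 1 over the first x1 steps of each 3*x1-step cycle and then stays at 1 for the rest of the cycle.

for step in (0, 1250, 2500, 5000, 7499, 7500):
    weight = min(1, (step % (3 * 2500)) / 2500)
    print(step, weight)   # -> 0 0.0, 1250 0.5, 2500 1, 5000 1, 7499 1, 7500 0.0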
Example #6
                                           worker_init_fn = _init_fn)
n_imgs = len(train_dataset)
n_maps = 4

valid_dataset = PF_Pascal(args.valid_csv_path, args.valid_image_path, args.feature_h, args.feature_w, args.eval_type)
valid_loader = torch.utils.data.DataLoader(dataset=valid_dataset,
                                           batch_size=1,
                                           shuffle=False, num_workers=args.num_workers)

# Instantiate model
net = PMDNet(args.feature_h, args.feature_w, beta=args.beta, kernel_sigma=args.kernel_sigma)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net.to(device)

noise_model = NoiseModule(num_imgs=n_imgs, num_maps=n_maps, img_size=(args.feature_w, args.feature_h), device=device)
early_stopping = EarlyStopping(net, valid_loader, args.feature_w, args.feature_h, device=device)

if resume:
    print("Load pre-trained weights")
    best_weights = torch.load("./w1/best_checkpoint.pt")
    adap3_dict = best_weights['state_dict1']
    adap4_dict = best_weights['state_dict2']
    chn4_dict = best_weights['chn4_dict']
    net.adap_layer_feat3.load_state_dict(adap3_dict)
    net.adap_layer_feat4.load_state_dict(adap4_dict)
    net.chn4.load_state_dict(chn4_dict)

# Instantiate loss
criterion = loss_function(args).to(device)

# Instantiate optimizer