Example #1
def train_mnli_meta(**kwargs):
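    """Meta-train on four MultiNLI genres, then evaluate zero-shot and optional
    k-shot transfer to the held-out 'fiction' genre."""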
    train, dev_matched_train, test, dev_matched_test, dev_mismatched_test, vocab = prepare_mnli_split(root='datasets/data',
                                                                                                      urls=['https://www.nyu.edu/projects/bowman/multinli/multinli_1.0.zip'],
                                                                                                      dir='MultiNLI',
                                                                                                      name='MultiNLI',
                                                                                                      data_path='datasets/data/MultiNLI/multinli_1.0',
                                                                                                      train_genres=[['government'], ['telephone'], ['slate'], ['travel']],
                                                                                                      test_genres=[['fiction']],
                                                                                                      max_len=60)

    weight_matrix = prepare_glove(glove_path="datasets/GloVe/glove.840B.300d.txt",
                                  vocab=vocab)

    train_loaders = [DataLoader(
        MultiNLIDataset(dataset=t),
        batch_size=kwargs['batch_size'],
        shuffle=True,
        num_workers=1,
        pin_memory=torch.cuda.is_available()) for t in train]

    val_matched_loaders = [DataLoader(
                           MultiNLIDataset(dataset=t),
                           batch_size=2000,
                           shuffle=True,
                           num_workers=1,
                           pin_memory=torch.cuda.is_available()) for t in dev_matched_train]

    model = construct_model(model_type=kwargs['type'],
                            weight_matrix=weight_matrix)

    num_parameters = sum(p.numel() for p in model.parameters())
    print(f'Number of model parameters: {num_parameters}')

    cudnn.benchmark = True

    if torch.cuda.is_available():
        torch.cuda.set_device(kwargs['device'])
        model = model.cuda()
        loss_function = torch.nn.CrossEntropyLoss().cuda()
    else:
        loss_function = torch.nn.CrossEntropyLoss()

    optimizer = construct_optimizer(optimizer=kwargs['optim'],
                                    model=model,
                                    lr=kwargs['lr_outer_meta'])

    meta_model = MetaTrainWrapper(module=model,
                                  inner_lr=kwargs['lr_inner_meta'],
                                  use_maml=kwargs['use_maml'],
                                  optim=optimizer,
                                  second_order=True,
                                  sample_task=True)

    train_batcher = Batcher(loaders=train_loaders,
                            batch_size=kwargs['num_inner_iterations'])
    meta_model.train()
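    # Outer meta-training loop: each step builds one classification task per genre
    # and updates the shared weights (second-order MAML when use_maml is set).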
    for epoch in range(kwargs['epochs']):
        for train_batch in tqdm(train_batcher):
            meta_model(tasks=[ClassifierTask() for _ in range(len(train_loaders))],
                       train_batch=train_batch,
                       val_loaders=train_loaders)

        print(f'Epoch {epoch + 1} Validation')
        prec = []
        for loader in val_matched_loaders:
            prec.append(validate(val_loader=loader,
                                 model=model,
                                 criterion=loss_function,
                                 epoch=epoch,
                                 print_freq=kwargs['print_freq'],
                                 writer=None))
        print(f'Average Matched Precision is {np.mean(prec)}')

    train_loader = DataLoader(
        MultiNLIDataset(dataset=test[0]),
        batch_size=kwargs['batch_size'],
        shuffle=True,
        num_workers=1,
        pin_memory=torch.cuda.is_available())

    val_matched_loader = DataLoader(
        MultiNLIDataset(dataset=dev_matched_test[0]),
        batch_size=kwargs['batch_size'],
        shuffle=True,
        num_workers=1,
        pin_memory=torch.cuda.is_available())

    val_mismatched_loader = DataLoader(
        MultiNLIDataset(dataset=dev_mismatched_test[0]),
        batch_size=kwargs['batch_size'],
        shuffle=True,
        num_workers=1,
        pin_memory=torch.cuda.is_available())

    print('Zero Shot Performance')
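    # Zero-shot: evaluate the meta-trained model on the held-out genre's splits
    # without any gradient updates.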

    validate(val_loader=train_loader,
             model=model,
             criterion=loss_function,
             epoch=0,
             print_freq=kwargs['print_freq'],
             writer=None)

    validate(val_loader=val_matched_loader,
             model=model,
             criterion=loss_function,
             epoch=0,
             print_freq=kwargs['print_freq'],
             writer=None)

    validate(val_loader=val_mismatched_loader,
             model=model,
             criterion=loss_function,
             epoch=0,
             print_freq=kwargs['print_freq'],
             writer=None)

    if kwargs['k'] > 0:
        print(f"{kwargs['k']}-Shot Performance")
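        # k-shot adaptation: take a few gradient steps on target-genre batches,
        # validating after each step.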
        optimizer = construct_optimizer(optimizer=kwargs['optim'],
                                        model=model,
                                        lr=kwargs['lr_kshot'])

        train_batcher = Batcher(loaders=[train_loader],
                                batch_size=1)

        for i, train_batch in enumerate(train_batcher):
            train_single_epoch(train_loader=train_batch[0],
                               model=model,
                               criterion=loss_function,
                               optimizer=optimizer,
                               epoch=i,
                               total_steps=0,
                               print_freq=kwargs['print_freq'],
                               num_batches=1,
                               writer=None)

            validate(val_loader=val_matched_loader,
                     model=model,
                     criterion=loss_function,
                     epoch=0,
                     print_freq=kwargs['print_freq'],
                     writer=None)

            validate(val_loader=val_mismatched_loader,
                     model=model,
                     criterion=loss_function,
                     epoch=0,
                     print_freq=kwargs['print_freq'],
                     writer=None)

            if i >= kwargs['k']:
                break
Example #2
def train_mnli(**kwargs):
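    """Train on the full MultiNLI training set and validate on the matched and
    mismatched dev sets."""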
    directory = set_directory(name=kwargs['type'], type_net=kwargs['type'])
    writer = SummaryWriter(directory)

    train, dev_matched, dev_mismatched, vocab = prepare_mnli(root='datasets/data',
                                                             urls=['https://www.nyu.edu/projects/bowman/multinli/multinli_1.0.zip'],
                                                             dir='MultiNLI',
                                                             name='MultiNLI',
                                                             data_path='datasets/data/MultiNLI/multinli_1.0',
                                                             max_len=60)

    weight_matrix = prepare_glove(glove_path="datasets/GloVe/glove.840B.300d.txt",
                                  vocab=vocab)

    train_loader = DataLoader(
        MultiNLIDataset(dataset=train),
        batch_size=kwargs['batch_size'],
        shuffle=True,
        num_workers=1,
        pin_memory=torch.cuda.is_available())

    val_loader = [DataLoader(
        MultiNLIDataset(dataset=loader),
        batch_size=kwargs['batch_size'],
        shuffle=True,
        num_workers=1,
        pin_memory=torch.cuda.is_available()) for loader in [dev_matched, dev_mismatched]]

    model = construct_model(model_type=kwargs['type'],
                            weight_matrix=weight_matrix)

    num_parameters = sum(p.numel() for p in model.parameters())
    print(f'Number of model parameters: {num_parameters}')

    if torch.cuda.is_available():
        torch.cuda.set_device(kwargs['device'])
        model = model.cuda()
        loss_function = torch.nn.CrossEntropyLoss().cuda()
    else:
        loss_function = torch.nn.CrossEntropyLoss()

    optimizer = construct_optimizer(optimizer=kwargs['optim'],
                                    model=model,
                                    lr=kwargs['lr'])

    total_steps = 0

    cudnn.benchmark = True

    for epoch in tqdm(range(kwargs['epochs'])):
        total_steps = train_single_epoch(train_loader=train_loader,
                                         model=model,
                                         criterion=loss_function,
                                         optimizer=optimizer,
                                         epoch=epoch,
                                         total_steps=total_steps,
                                         print_freq=kwargs['print_freq'],
                                         writer=writer)

        for loader in val_loader:
            validate(val_loader=loader,
                     model=model,
                     criterion=loss_function,
                     epoch=epoch,
                     print_freq=kwargs['print_freq'],
                     writer=writer)
Example #3
def train_mnli_gradient_reversal(**kwargs):
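    """Train on four MultiNLI genres with gradient reversal against the held-out
    'fiction' genre, then evaluate zero-shot transfer to it."""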
    directory = set_directory(name=kwargs['type'], type_net=kwargs['type'])
    writer = SummaryWriter(directory)

    train, dev_matched_train, test, dev_matched_test, dev_mismatched_test, vocab = prepare_mnli_split(root='datasets/data',
                                                                                                      urls=['https://www.nyu.edu/projects/bowman/multinli/multinli_1.0.zip'],
                                                                                                      dir='MultiNLI',
                                                                                                      name='MultiNLI',
                                                                                                      data_path='datasets/data/MultiNLI/multinli_1.0',
                                                                                                      train_genres=[['government', 'telephone', 'slate', 'travel']],
                                                                                                      test_genres=[['fiction']],
                                                                                                      max_len=60)

    weight_matrix = prepare_glove(glove_path="datasets/GloVe/glove.840B.300d.txt",
                                  vocab=vocab)

    train_loader = DataLoader(
        MultiNLIDataset(dataset=train[0]),
        batch_size=kwargs['batch_size'],
        shuffle=True,
        num_workers=1,
        pin_memory=torch.cuda.is_available())

    test_loader = DataLoader(
        MultiNLIDataset(dataset=test[0]),
        batch_size=kwargs['batch_size'],
        shuffle=True,
        num_workers=1,
        pin_memory=torch.cuda.is_available())

    val_loader = DataLoader(
        MultiNLIDataset(dataset=dev_matched_train[0]),
        batch_size=kwargs['batch_size'],
        shuffle=True,
        num_workers=1,
        pin_memory=torch.cuda.is_available())

    model = construct_model_r(model_type=kwargs['type'],
                              weight_matrix=weight_matrix)

    num_parameters = sum(p.numel() for p in model.parameters())
    print(f'Number of model parameters: {num_parameters}')

    if torch.cuda.is_available():
        torch.cuda.set_device(kwargs['device'])
        model = model.cuda()
        loss_function = torch.nn.CrossEntropyLoss().cuda()
    else:
        loss_function = torch.nn.CrossEntropyLoss()

    optimizer = construct_optimizer(optimizer=kwargs['optim'],
                                    model=model,
                                    lr=kwargs['lr'])

    cudnn.benchmark = True

    total_steps = 0

    for epoch in tqdm(range(kwargs['epochs'])):
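        # alpha sets the weight of the gradient-reversal term; batches from the
        # held-out test genre serve as the target domain.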
        total_steps = train_single_epoch_with_gradient_reversal(train_loader=train_loader,
                                                                val_loader=test_loader,
                                                                model=model,
                                                                criterion=loss_function,
                                                                optimizer=optimizer,
                                                                epoch=epoch,
                                                                alpha=1e-2,
                                                                total_steps=total_steps,
                                                                print_freq=kwargs['print_freq'],
                                                                writer=writer)

        validate(val_loader=val_loader,
                 model=model,
                 criterion=loss_function,
                 epoch=epoch,
                 print_freq=kwargs['print_freq'],
                 writer=writer)

    print('Zero Shot Performance')

    train_loader = DataLoader(
        MultiNLIDataset(dataset=test[0]),
        batch_size=kwargs['batch_size'],
        shuffle=True,
        num_workers=1,
        pin_memory=torch.cuda.is_available())

    val_loader = [DataLoader(
        MultiNLIDataset(dataset=dataset[0]),
        batch_size=kwargs['batch_size'],
        shuffle=True,
        num_workers=1,
        pin_memory=torch.cuda.is_available()) for dataset in [dev_matched_test, dev_mismatched_test]]

    validate(val_loader=train_loader,
             model=model,
             criterion=loss_function,
             epoch=epoch,
             print_freq=kwargs['print_freq'],
             writer=writer)

    for loader in val_loader:
        validate(val_loader=loader,
                 model=model,
                 criterion=loss_function,
                 epoch=epoch,
                 print_freq=kwargs['print_freq'],
                 writer=writer)
Example #4
def train_basecnn(**kwargs):
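    """Train BaseCNN on CIFAR-10 with an annealed KL penalty, checkpointing every
    epoch and tracking the best validation accuracy."""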
    if kwargs['tensorboard']:
        name, directory = set_directory(name=kwargs['name'],
                                        type_net=kwargs['type_net'],
                                        dof=kwargs['dof'])
        writer = SummaryWriter(directory)
    else:
        writer = None

    train_loader, val_loader, iter_per_epoch = load_cifar10(
        batch_size=kwargs['batch_size'])

    model = BaseCNN(num_classes=10,
                    model_size=1,
                    type_net=kwargs['type_net'],
                    N=50000,
                    beta_ema=kwargs['beta_ema'])

    num_parameters = sum(p.numel() for p in model.parameters())
    print(f'Number of model parameters: {num_parameters}')

    if torch.cuda.is_available():
        torch.cuda.set_device(kwargs['device'])

    # for training on multiple GPUs.
    # Use CUDA_VISIBLE_DEVICES=0,1 to specify which GPUs to use
    if kwargs['multi_gpu']:
        model = torch.nn.DataParallel(model).cuda()
    else:
        if torch.cuda.is_available():
            model = model.cuda()

    optimizer = construct_optimizer(optimizer=kwargs['optim'],
                                    model=model,
                                    lr=kwargs['lr'])

    if kwargs['resume'] != '':
        (kwargs['start_epoch'], best_prec1, total_steps, model,
         optimizer) = resume_from_checkpoint(resume_path=kwargs['resume'],
                                             model=model,
                                             optimizer=optimizer)
    else:
        total_steps = 0
        best_prec1 = 0.

    cudnn.benchmark = True

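    # Cross-entropy plus a KL term whose weight is annealed over training steps.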
    loss_function = CrossEntropyLossWithAnnealing(
        iter_per_epoch=iter_per_epoch,
        total_steps=total_steps,
        anneal_type=kwargs['anneal_type'],
        anneal_kl=kwargs['anneal_kl'],
        epzero=kwargs['epzero'],
        epmax=kwargs['epmax'],
        anneal_maxval=kwargs['anneal_maxval'],
        writer=writer)

    for epoch in range(kwargs['start_epoch'], kwargs['epochs']):
        total_steps = train_single_epoch(train_loader=train_loader,
                                         model=model,
                                         criterion=loss_function,
                                         optimizer=optimizer,
                                         epoch=epoch,
                                         clip_var=kwargs['clip_var'],
                                         total_steps=total_steps,
                                         print_freq=kwargs['print_freq'],
                                         writer=writer,
                                         thres_stds=kwargs['thres_std'])

        prec1 = validate(val_loader=val_loader,
                         model=model,
                         criterion=loss_function,
                         epoch=epoch,
                         print_freq=kwargs['print_freq'],
                         writer=writer)

        if kwargs['restart'] and epoch % kwargs['restart_interval'] == 0:
            print('Restarting optimizer...')
            optimizer = construct_optimizer(optimizer=kwargs['restart_optim'],
                                            model=model,
                                            lr=kwargs['restart_lr'])

        is_best = prec1 > best_prec1
        if is_best:
            best_prec1 = prec1
        # Unwrap DataParallel so the EMA attributes are read from the underlying module.
        module = model.module if isinstance(model, torch.nn.DataParallel) else model
        state = {
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
            'beta_ema': module.beta_ema,
            'optimizer': optimizer.state_dict(),
            'total_steps': total_steps
        }
        if module.beta_ema > 0:
            state['avg_params'] = module.avg_param
            state['steps_ema'] = module.steps_ema

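        # Keep a separate snapshot at the epochs listed in save_at; otherwise
        # overwrite the rolling checkpoint.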
        if epoch in kwargs['save_at']:
            name = f'checkpoint_{epoch}.pth.tar'
        else:
            name = 'checkpoint.pth.tar'

        save_checkpoint(state=state, is_best=is_best, name=name)
    print(f'Best accuracy: {best_prec1}')

    if writer is not None:
        writer.close()