Example #1
import math
import os
import time

import torch
import torch.nn as nn

# NextSentenceTask, generate_next_sentence_data, train, evaluate and
# print_loss_log are assumed to be defined elsewhere in this example's module.


def run_main(args, rank=None):

    # Set the random seed manually for reproducibility.
    torch.manual_seed(args.seed)
    if args.parallel == 'DDP':
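        # Each DDP rank gets an equal, contiguous slice of the visible GPUs.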
        n = torch.cuda.device_count() // args.world_size
        device = list(range(rank * n, (rank + 1) * n))
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    ###################################################################
    # Load data
    ###################################################################
    vocab = torch.load(args.save_vocab)
    pad_id = vocab.stoi['<pad>']
    sep_id = vocab.stoi['<sep>']

    if args.dataset == 'WikiText103':
        from data import WikiText103
        train_dataset, valid_dataset, test_dataset = WikiText103(
            vocab=vocab, single_line=False)
    elif args.dataset == 'BookCorpus':
        from data import BookCorpus
        train_dataset, valid_dataset, test_dataset = BookCorpus(
            vocab=vocab, min_sentence_len=args.min_sentence_len)

    if rank is not None:
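        # In DDP mode each process trains on its own contiguous shard of the data.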
        chunk_len = len(train_dataset.data) // args.world_size
        train_dataset.data = train_dataset.data[(rank *
                                                 chunk_len):((rank + 1) *
                                                             chunk_len)]
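    # generate_next_sentence_data (defined elsewhere in this example) prepares
    # each split for the next-sentence-prediction objective.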
    train_dataset.data = generate_next_sentence_data(train_dataset.data, args)
    valid_dataset.data = generate_next_sentence_data(valid_dataset.data, args)
    test_dataset.data = generate_next_sentence_data(test_dataset.data, args)

    ###################################################################
    # Build the model
    ###################################################################
    pretrained_bert = torch.load(args.bert_model)
    model = NextSentenceTask(pretrained_bert)
    if args.checkpoint != 'None':
        model = torch.load(args.checkpoint)

    if args.parallel == 'DDP':
        model = model.to(device[0])
        from torch.nn.parallel import DistributedDataParallel as DDP
        model = DDP(model, device_ids=device)
    else:
        model = model.to(device)

    criterion = nn.CrossEntropyLoss()

    ###################################################################
    # Loop over epochs.
    ###################################################################

    lr = args.lr
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)
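    # Note: scheduler.step() is only called below when an epoch fails to
    # improve the validation loss, so the 10x decay acts as a plateau cut.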
    best_val_loss = None
    train_loss_log, val_loss_log = [], []

    for epoch in range(1, args.epochs + 1):
        epoch_start_time = time.time()
        train(train_dataset, model, train_loss_log, device, optimizer,
              criterion, epoch, scheduler, sep_id, pad_id, args, rank)
        val_loss = evaluate(valid_dataset, model, device, criterion, sep_id,
                            pad_id, args)
        val_loss_log.append(val_loss)

        if (rank is None) or (rank == 0):
            print('-' * 89)
            print('| end of epoch {:3d} | time: {:5.2f}s '
                  '| valid loss {:8.5f} | '.format(
                      epoch, (time.time() - epoch_start_time), val_loss))
            print('-' * 89)
        # Save the model if the validation loss is the best we've seen so far.
        if not best_val_loss or val_loss < best_val_loss:
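            # Single-process runs save the whole model; under DDP only rank 0
            # writes a state_dict, tagged with the SLURM job id so it can be
            # reloaded with map_location below.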
            if rank is None:
                with open(args.save, 'wb') as f:
                    torch.save(model, f)
            elif rank == 0:
                with open(os.environ['SLURM_JOB_ID'] + '_' + args.save,
                          'wb') as f:
                    torch.save(model.state_dict(), f)
            best_val_loss = val_loss
        else:
            scheduler.step()
    ###################################################################
    # Load the best saved model and run on test data
    ###################################################################
    if args.parallel == 'DDP':
        # [TODO] put dist.barrier() back
        # dist.barrier()
        # configure map_location properly
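        # Rank 0 saved its tensors on cuda:0..cuda:n-1; remap them onto this
        # rank's own GPU slice when loading the checkpoint.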
        rank0_devices = [x - rank * len(device) for x in device]
        device_pairs = zip(rank0_devices, device)
        map_location = {'cuda:%d' % x: 'cuda:%d' % y for x, y in device_pairs}
        model.load_state_dict(
            torch.load(os.environ['SLURM_JOB_ID'] + '_' + args.save,
                       map_location=map_location))
        test_loss = evaluate(test_dataset, model, device, criterion, sep_id,
                             pad_id, args)
        if rank == 0:
            print('=' * 89)
            print('| End of training | test loss {:8.5f} | test ppl {:8.5f}'.
                  format(test_loss, math.exp(test_loss)))
            print('=' * 89)
            print_loss_log(os.environ['SLURM_JOB_ID'] + '_ns_loss.txt',
                           train_loss_log, val_loss_log, test_loss, args)
            ###############################################################################
            # Save the bert model layer
            ###############################################################################
            with open(os.environ['SLURM_JOB_ID'] + '_' + args.save, 'wb') as f:
                torch.save(model.module.bert_model, f)
            with open(os.environ['SLURM_JOB_ID'] + '_' + 'full_ns_model.pt',
                      'wb') as f:
                torch.save(model.module, f)
    else:
        with open(args.save, 'rb') as f:
            model = torch.load(f)

        test_loss = evaluate(test_dataset, model, device, criterion, sep_id,
                             pad_id, args)
        print('=' * 89)
        print(
            '| End of training | test loss {:8.5f} | test ppl {:8.5f}'.format(
                test_loss, math.exp(test_loss)))
        print('=' * 89)
        print_loss_log('ns_loss.txt', train_loss_log, val_loss_log, test_loss)

        with open(args.save, 'wb') as f:
            torch.save(model.bert_model, f)
Example #2
    for epoch in range(1, args.epochs + 1):
        epoch_start_time = time.time()
        train()
        val_loss, val_exact, val_f1 = evaluate(dev_dataset, vocab)
        val_loss_log.append(val_loss)
        print('-' * 89)
        print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
              'exact {:8.3f}% | '
              'f1 {:8.3f}%'.format(epoch, (time.time() - epoch_start_time),
                                   val_loss, val_exact, val_f1))
        print('-' * 89)
        if best_f1 is None or val_f1 > best_f1:
            with open(args.save, 'wb') as f:
                torch.save(model, f)
            best_f1 = val_f1
        else:
            scheduler.step()

    with open(args.save, 'rb') as f:
        model = torch.load(f)
    test_loss, test_exact, test_f1 = evaluate(dev_dataset, vocab)
    print('=' * 89)
    print(
        '| End of training | test loss {:5.2f} | exact {:8.3f}% | f1 {:8.3f}%'.
        format(test_loss, test_exact, test_f1))
    print('=' * 89)
    print_loss_log('qa_loss.txt', train_loss_log, val_loss_log, test_loss)
    with open(args.save, 'wb') as f:
        torch.save(model, f)
Example #3
import math
import os
import time

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

# MLMTask, batchify, train, evaluate and print_loss_log are assumed to be
# defined elsewhere in this example's module.


def run_main(args, rank=None):
    # Set the random seed manually for reproducibility.
    torch.manual_seed(args.seed)
    if args.parallel == 'DDP':
        n = torch.cuda.device_count() // args.world_size
        device = list(range(rank * n, (rank + 1) * n))
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    ###############################################################################
    # Import dataset
    ###############################################################################
    import torchtext
    if args.dataset == 'WikiText103':
        from torchtext.experimental.datasets import WikiText103 as WLMDataset
    elif args.dataset == 'WikiText2':
        from torchtext.experimental.datasets import WikiText2 as WLMDataset
    elif args.dataset == 'WMTNewsCrawl':
        from data import WMTNewsCrawl as WLMDataset
    elif args.dataset == 'EnWik9':
        from torchtext.datasets import EnWik9
    elif args.dataset == 'BookCorpus':
        from data import BookCorpus
    else:
        print("dataset for MLM task is not supported")

    try:
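        # Reuse a cached vocab if one exists; otherwise (except branch below)
        # rebuild it with the <pad> and <MASK> specials the MLM task needs.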
        vocab = torch.load(args.save_vocab)
    except FileNotFoundError:
        train_dataset, test_dataset, valid_dataset = WLMDataset()
        old_vocab = train_dataset.vocab
        vocab = torchtext.vocab.Vocab(counter=old_vocab.freqs,
                                      specials=['<unk>', '<pad>', '<MASK>'])
        with open(args.save_vocab, 'wb') as f:
            torch.save(vocab, f)

    if args.dataset == 'WikiText103' or args.dataset == 'WikiText2':
        train_dataset, test_dataset, valid_dataset = WLMDataset(vocab=vocab)
    elif args.dataset == 'WMTNewsCrawl':
        test_dataset, valid_dataset = torchtext.experimental.datasets.WikiText2(
            vocab=vocab, data_select=('test', 'valid'))
        train_dataset, = WLMDataset(vocab=vocab, data_select='train')
    elif args.dataset == 'EnWik9':
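        # EnWik9 comes as one flat token stream: numericalize it, split it
        # 80/10/10 into train/valid/test, and wrap each split as a
        # LanguageModelingDataset.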
        enwik9 = EnWik9()
        idx1, idx2 = int(len(enwik9) * 0.8), int(len(enwik9) * 0.9)
        train_data = torch.tensor([vocab.stoi[_id]
                                   for _id in enwik9[0:idx1]]).long()
        val_data = torch.tensor([vocab.stoi[_id]
                                 for _id in enwik9[idx1:idx2]]).long()
        test_data = torch.tensor([vocab.stoi[_id]
                                  for _id in enwik9[idx2:]]).long()
        from torchtext.experimental.datasets import LanguageModelingDataset
        train_dataset = LanguageModelingDataset(train_data, vocab)
        valid_dataset = LanguageModelingDataset(val_data, vocab)
        test_dataset = LanguageModelingDataset(test_data, vocab)
    elif args.dataset == 'BookCorpus':
        train_dataset, test_dataset, valid_dataset = BookCorpus(vocab)

    train_data = batchify(train_dataset.data, args.batch_size, args)

    if rank is not None:
        # Chunk the training data by rank so each process/GPU gets its own shard
        chunk_len = len(train_data) // args.world_size
        train_data = train_data[(rank * chunk_len):((rank + 1) * chunk_len)]

    val_data = batchify(valid_dataset.data, args.batch_size, args)
    test_data = batchify(test_dataset.data, args.batch_size, args)

    ###############################################################################
    # Build the model
    ###############################################################################
    ntokens = len(train_dataset.get_vocab())
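    # MLMTask (defined elsewhere in this example) builds the BERT encoder plus
    # a masked-language-model head over a vocabulary of ntokens entries.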
    model = MLMTask(ntokens, args.emsize, args.nhead, args.nhid, args.nlayers,
                    args.dropout)
    if args.checkpoint != 'None':
        model.bert_model = torch.load(args.checkpoint)

    if args.parallel == 'DDP':
        model = model.to(device[0])
        # model = nn.DataParallel(model)  # alternative: single-process multi-GPU wrapper
        model = DDP(model, device_ids=device)
    else:
        model = model.to(device)
    criterion = nn.CrossEntropyLoss()

    ###############################################################################
    # Loop over epochs.
    ###############################################################################
    lr = args.lr
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)
    best_val_loss = None
    train_loss_log, val_loss_log = [], []

    for epoch in range(1, args.epochs + 1):
        epoch_start_time = time.time()
        train(model, train_dataset.vocab, train_loss_log, train_data,
              optimizer, criterion, ntokens, epoch, scheduler, args, device,
              rank)
        val_loss = evaluate(val_data, model, train_dataset.vocab, ntokens,
                            criterion, args, device)

        if (rank is None) or (rank == 0):
            val_loss_log.append(val_loss)
            print('-' * 89)
            print(
                '| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
                'valid ppl {:8.2f}'.format(epoch,
                                           (time.time() - epoch_start_time),
                                           val_loss, math.exp(val_loss)))
            print('-' * 89)

        # Save the model if the validation loss is the best we've seen so far.
        if not best_val_loss or val_loss < best_val_loss:
            if rank is None:
                with open(args.save, 'wb') as f:
                    torch.save(model, f)
            elif rank == 0:
                with open(os.environ['SLURM_JOB_ID'] + '_' + args.save,
                          'wb') as f:
                    torch.save(model.state_dict(), f)
            best_val_loss = val_loss
        else:
            scheduler.step()

    ###############################################################################
    # Load the best saved model.
    ###############################################################################
    if args.parallel == 'DDP':
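        # Wait for every rank to finish training so rank 0's checkpoint exists
        # before it is read back.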
        dist.barrier()
        # configure map_location properly
        rank0_devices = [x - rank * len(device) for x in device]
        device_pairs = zip(rank0_devices, device)
        map_location = {'cuda:%d' % x: 'cuda:%d' % y for x, y in device_pairs}
        model.load_state_dict(
            torch.load(os.environ['SLURM_JOB_ID'] + '_' + args.save,
                       map_location=map_location))

        ###############################################################################
        # Run on test data.
        ###############################################################################
        test_loss = evaluate(test_data, model, train_dataset.vocab, ntokens,
                             criterion, args, device)
        if rank == 0:
            print('=' * 89)
            print('| End of training | test loss {:5.2f} | test ppl {:8.2f}'.
                  format(test_loss, math.exp(test_loss)))
            print('=' * 89)
            print_loss_log(os.environ['SLURM_JOB_ID'] + '_mlm_loss.txt',
                           train_loss_log, val_loss_log, test_loss, args)

            ###############################################################################
            # Save the bert model layer
            ###############################################################################
            with open(os.environ['SLURM_JOB_ID'] + '_' + args.save, 'wb') as f:
                torch.save(model.module.bert_model, f)
            with open(os.environ['SLURM_JOB_ID'] + '_mlm_model.pt', 'wb') as f:
                torch.save(model.module, f)
    else:
        with open(args.save, 'rb') as f:
            model = torch.load(f)
        test_loss = evaluate(test_data, model, train_dataset.vocab, ntokens,
                             criterion, args, device)
        print('=' * 89)
        print(
            '| End of training | test loss {:5.2f} | test ppl {:8.2f}'.format(
                test_loss, math.exp(test_loss)))
        print('=' * 89)
        print_loss_log('mlm_loss.txt', train_loss_log, val_loss_log, test_loss,
                       args)

        ###############################################################################
        # Save the bert model layer
        ###############################################################################
        with open(args.save, 'wb') as f:
            torch.save(model.bert_model, f)
        with open('mlm_model.pt', 'wb') as f:
            torch.save(model, f)
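
A minimal launcher sketch for context: run_main is written to be called either directly (rank=None) for single-process CPU/GPU training, or once per process for DDP after torch.distributed has been initialized. Everything below that is not in the examples above (the ddp_worker helper, the env:// rendezvous defaults, and get_args) is an assumption for illustration; the real entry point of the example scripts may differ.

import os

import torch.distributed as dist
import torch.multiprocessing as mp


def ddp_worker(rank, args):
    # One process per rank; the rendezvous settings are illustrative defaults.
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '29500')
    dist.init_process_group('nccl', rank=rank, world_size=args.world_size)
    run_main(args, rank)
    dist.destroy_process_group()


if __name__ == '__main__':
    args = get_args()  # hypothetical argparse helper
    if args.parallel == 'DDP':
        # mp.spawn passes the process index (the rank) as the first argument.
        mp.spawn(ddp_worker, args=(args,), nprocs=args.world_size, join=True)
    else:
        run_main(args)  # single CPU/GPU process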