# Next-sentence-prediction (NS) fine-tuning task. `NextSentenceTask`,
# `generate_next_sentence_data`, `train`, `evaluate`, and `print_loss_log`
# are defined elsewhere in this example.
import math
import os
import time

import torch
import torch.nn as nn


def run_main(args, rank=None):
    # Set the random seed manually for reproducibility.
    torch.manual_seed(args.seed)
    if args.parallel == 'DDP':
        n = torch.cuda.device_count() // args.world_size
        device = list(range(rank * n, (rank + 1) * n))
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    ###################################################################
    # Load data
    ###################################################################
    vocab = torch.load(args.save_vocab)
    pad_id = vocab.stoi['<pad>']
    sep_id = vocab.stoi['<sep>']

    if args.dataset == 'WikiText103':
        from data import WikiText103
        train_dataset, valid_dataset, test_dataset = WikiText103(
            vocab=vocab, single_line=False)
    elif args.dataset == 'BookCorpus':
        from data import BookCorpus
        train_dataset, valid_dataset, test_dataset = BookCorpus(
            vocab=vocab, min_sentence_len=args.min_sentence_len)

    if rank is not None:
        # Chunk the training data by rank so each DDP process sees a
        # distinct shard.
        chunk_len = len(train_dataset.data) // args.world_size
        train_dataset.data = \
            train_dataset.data[(rank * chunk_len):((rank + 1) * chunk_len)]

    train_dataset.data = generate_next_sentence_data(train_dataset.data, args)
    valid_dataset.data = generate_next_sentence_data(valid_dataset.data, args)
    test_dataset.data = generate_next_sentence_data(test_dataset.data, args)

    ###################################################################
    # Build the model
    ###################################################################
    pretrained_bert = torch.load(args.bert_model)
    model = NextSentenceTask(pretrained_bert)
    if args.checkpoint != 'None':
        model = torch.load(args.checkpoint)

    if args.parallel == 'DDP':
        model = model.to(device[0])
        from torch.nn.parallel import DistributedDataParallel as DDP
        model = DDP(model, device_ids=device)
    else:
        model = model.to(device)
    criterion = nn.CrossEntropyLoss()

    ###################################################################
    # Loop over epochs.
    ###################################################################
    lr = args.lr
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)
    best_val_loss = None
    train_loss_log, val_loss_log = [], []

    for epoch in range(1, args.epochs + 1):
        epoch_start_time = time.time()
        train(train_dataset, model, train_loss_log, device, optimizer,
              criterion, epoch, scheduler, sep_id, pad_id, args, rank)
        val_loss = evaluate(valid_dataset, model, device, criterion,
                            sep_id, pad_id, args)
        val_loss_log.append(val_loss)
        if (rank is None) or (rank == 0):
            print('-' * 89)
            print('| end of epoch {:3d} | time: {:5.2f}s '
                  '| valid loss {:8.5f} | '.format(
                      epoch, (time.time() - epoch_start_time), val_loss))
            print('-' * 89)
        # Save the model if the validation loss is the best we've seen so far.
        if not best_val_loss or val_loss < best_val_loss:
            if rank is None:
                with open(args.save, 'wb') as f:
                    torch.save(model, f)
            elif rank == 0:
                with open(os.environ['SLURM_JOB_ID'] + '_' + args.save,
                          'wb') as f:
                    torch.save(model.state_dict(), f)
            best_val_loss = val_loss
        else:
            scheduler.step()

    ###################################################################
    # Load the best saved model and run on test data
    ###################################################################
    if args.parallel == 'DDP':
        # [TODO] put dist.barrier() back
        # dist.barrier()
        # configure map_location properly
        rank0_devices = [x - rank * len(device) for x in device]
        device_pairs = zip(rank0_devices, device)
        map_location = {'cuda:%d' % x: 'cuda:%d' % y for x, y in device_pairs}
        model.load_state_dict(
            torch.load(os.environ['SLURM_JOB_ID'] + '_' + args.save,
                       map_location=map_location))
        test_loss = evaluate(test_dataset, model, device, criterion,
                             sep_id, pad_id, args)
        if rank == 0:
            print('=' * 89)
            print('| End of training | test loss {:8.5f} | test ppl {:8.5f}'.
                  format(test_loss, math.exp(test_loss)))
            print('=' * 89)
            print_loss_log(os.environ['SLURM_JOB_ID'] + '_ns_loss.txt',
                           train_loss_log, val_loss_log, test_loss, args)

            ###########################################################
            # Save the bert model layer
            ###########################################################
            with open(os.environ['SLURM_JOB_ID'] + '_' + args.save,
                      'wb') as f:
                torch.save(model.module.bert_model, f)
            with open(os.environ['SLURM_JOB_ID'] + '_' + 'full_ns_model.pt',
                      'wb') as f:
                torch.save(model.module, f)
    else:
        with open(args.save, 'rb') as f:
            model = torch.load(f)
        # `args` was missing from these two calls in the original; added for
        # consistency with the DDP branch above.
        test_loss = evaluate(test_dataset, model, device, criterion,
                             sep_id, pad_id, args)
        print('=' * 89)
        print('| End of training | test loss {:8.5f} | test ppl {:8.5f}'.
              format(test_loss, math.exp(test_loss)))
        print('=' * 89)
        print_loss_log('ns_loss.txt', train_loss_log, val_loss_log,
                       test_loss, args)
        with open(args.save, 'wb') as f:
            torch.save(model.bert_model, f)
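# ---------------------------------------------------------------------------
# run_main() above only consumes `rank`; it does not create the process group
# itself. Below is a minimal, hypothetical sketch of a launcher that could
# supply it, assuming one process per rank spawned with torch.multiprocessing
# and an env:// rendezvous (MASTER_ADDR/MASTER_PORT must be set). The real
# entry point in this example may differ; `parse_args` is a stand-in for the
# argument parser that defines the flags used above.
# ---------------------------------------------------------------------------
import torch.distributed as dist
import torch.multiprocessing as mp


def ddp_worker(rank, args):
    # Join the process group, run one training replica, then tear down.
    dist.init_process_group(backend='nccl', init_method='env://',
                            world_size=args.world_size, rank=rank)
    run_main(args, rank)
    dist.destroy_process_group()


if __name__ == "__main__":
    args = parse_args()  # hypothetical; not shown in this file
    if args.parallel == 'DDP':
        mp.spawn(ddp_worker, args=(args,), nprocs=args.world_size)
    else:
        run_main(args)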
# Question-answering (QA) fine-tuning task: training-loop fragment. `train`,
# `evaluate`, `model`, `scheduler`, and the loss logs are set up earlier in
# the script.
best_f1 = None  # assumed initialization; this fragment begins mid-script
for epoch in range(1, args.epochs + 1):
    epoch_start_time = time.time()
    train()
    val_loss, val_exact, val_f1 = evaluate(dev_dataset, vocab)
    val_loss_log.append(val_loss)
    print('-' * 89)
    print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
          'exact {:8.3f}% | '
          'f1 {:8.3f}%'.format(epoch, (time.time() - epoch_start_time),
                               val_loss, val_exact, val_f1))
    print('-' * 89)
    # Save the model if the validation F1 is the best we've seen so far.
    if best_f1 is None or val_f1 > best_f1:
        with open(args.save, 'wb') as f:
            torch.save(model, f)
        best_f1 = val_f1
    else:
        scheduler.step()

# Load the best saved model and run on the dev data.
with open(args.save, 'rb') as f:
    model = torch.load(f)
test_loss, test_exact, test_f1 = evaluate(dev_dataset, vocab)
print('=' * 89)
print('| End of training | test loss {:5.2f} | exact {:8.3f}% | f1 {:8.3f}%'.
      format(test_loss, test_exact, test_f1))
print('=' * 89)
print_loss_log('qa_loss.txt', train_loss_log, val_loss_log, test_loss)
with open(args.save, 'wb') as f:
    torch.save(model, f)
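# ---------------------------------------------------------------------------
# The evaluate() call above reports exact-match and F1 percentages. As a
# point of reference, here is a minimal sketch of the standard SQuAD-style
# token-level metrics for a single predicted answer span; the helper actually
# used by this example may normalize tokens differently.
# ---------------------------------------------------------------------------
from collections import Counter


def exact_match_score(prediction_tokens, answer_tokens):
    # 1.0 only when the predicted span matches the gold span token-for-token.
    return float(prediction_tokens == answer_tokens)


def f1_score(prediction_tokens, answer_tokens):
    # Harmonic mean of token-overlap precision and recall between the spans.
    common = Counter(prediction_tokens) & Counter(answer_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(prediction_tokens)
    recall = num_same / len(answer_tokens)
    return 2 * precision * recall / (precision + recall)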
# Masked-language-model (MLM) pretraining task. `MLMTask`, `batchify`,
# `train`, `evaluate`, and `print_loss_log` are defined elsewhere in this
# example.
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


def run_main(args, rank=None):
    # Set the random seed manually for reproducibility.
    torch.manual_seed(args.seed)
    if args.parallel == 'DDP':
        n = torch.cuda.device_count() // args.world_size
        device = list(range(rank * n, (rank + 1) * n))
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    ###############################################################################
    # Import dataset
    ###############################################################################
    import torchtext
    if args.dataset == 'WikiText103':
        from torchtext.experimental.datasets import WikiText103 as WLMDataset
    elif args.dataset == 'WikiText2':
        from torchtext.experimental.datasets import WikiText2 as WLMDataset
    elif args.dataset == 'WMTNewsCrawl':
        from data import WMTNewsCrawl as WLMDataset
    elif args.dataset == 'EnWik9':
        from torchtext.datasets import EnWik9
    elif args.dataset == 'BookCorpus':
        from data import BookCorpus
    else:
        print("dataset for MLM task is not supported")

    try:
        vocab = torch.load(args.save_vocab)
    except FileNotFoundError:
        # No cached vocab yet: build one from the dataset and save it for
        # later runs. (The original used a bare `except:`; narrowed here.)
        train_dataset, test_dataset, valid_dataset = WLMDataset()
        old_vocab = train_dataset.vocab
        vocab = torchtext.vocab.Vocab(counter=old_vocab.freqs,
                                      specials=['<unk>', '<pad>', '<MASK>'])
        with open(args.save_vocab, 'wb') as f:
            torch.save(vocab, f)

    if args.dataset == 'WikiText103' or args.dataset == 'WikiText2':
        train_dataset, test_dataset, valid_dataset = WLMDataset(vocab=vocab)
    elif args.dataset == 'WMTNewsCrawl':
        test_dataset, valid_dataset = torchtext.experimental.datasets.WikiText2(
            vocab=vocab, data_select=('test', 'valid'))
        train_dataset, = WLMDataset(vocab=vocab, data_select='train')
    elif args.dataset == 'EnWik9':
        enwik9 = EnWik9()
        # 80/10/10 split of the raw character stream.
        idx1, idx2 = int(len(enwik9) * 0.8), int(len(enwik9) * 0.9)
        train_data = torch.tensor([vocab.stoi[_id]
                                   for _id in enwik9[0:idx1]]).long()
        val_data = torch.tensor([vocab.stoi[_id]
                                 for _id in enwik9[idx1:idx2]]).long()
        test_data = torch.tensor([vocab.stoi[_id]
                                  for _id in enwik9[idx2:]]).long()
        from torchtext.experimental.datasets import LanguageModelingDataset
        train_dataset = LanguageModelingDataset(train_data, vocab)
        valid_dataset = LanguageModelingDataset(val_data, vocab)
        test_dataset = LanguageModelingDataset(test_data, vocab)
    elif args.dataset == 'BookCorpus':
        train_dataset, test_dataset, valid_dataset = BookCorpus(vocab)

    train_data = batchify(train_dataset.data, args.batch_size, args)
    if rank is not None:
        # Chunk training data by rank for different gpus
        chunk_len = len(train_data) // args.world_size
        train_data = train_data[(rank * chunk_len):((rank + 1) * chunk_len)]
    val_data = batchify(valid_dataset.data, args.batch_size, args)
    test_data = batchify(test_dataset.data, args.batch_size, args)

    ###############################################################################
    # Build the model
    ###############################################################################
    ntokens = len(train_dataset.get_vocab())
    model = MLMTask(ntokens, args.emsize, args.nhead, args.nhid,
                    args.nlayers, args.dropout)
    if args.checkpoint != 'None':
        model.bert_model = torch.load(args.checkpoint)

    if args.parallel == 'DDP':
        model = model.to(device[0])
        model = DDP(model, device_ids=device)
    else:
        model = model.to(device)
    criterion = nn.CrossEntropyLoss()

    ###############################################################################
    # Loop over epochs.
    ###############################################################################
    lr = args.lr
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)
    best_val_loss = None
    train_loss_log, val_loss_log = [], []

    for epoch in range(1, args.epochs + 1):
        epoch_start_time = time.time()
        train(model, train_dataset.vocab, train_loss_log, train_data,
              optimizer, criterion, ntokens, epoch, scheduler, args,
              device, rank)
        val_loss = evaluate(val_data, model, train_dataset.vocab, ntokens,
                            criterion, args, device)
        if (rank is None) or (rank == 0):
            val_loss_log.append(val_loss)
            print('-' * 89)
            print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
                  'valid ppl {:8.2f}'.format(
                      epoch, (time.time() - epoch_start_time),
                      val_loss, math.exp(val_loss)))
            print('-' * 89)
        # Save the model if the validation loss is the best we've seen so far.
        if not best_val_loss or val_loss < best_val_loss:
            if rank is None:
                with open(args.save, 'wb') as f:
                    torch.save(model, f)
            elif rank == 0:
                with open(os.environ['SLURM_JOB_ID'] + '_' + args.save,
                          'wb') as f:
                    torch.save(model.state_dict(), f)
            best_val_loss = val_loss
        else:
            scheduler.step()

    ###############################################################################
    # Load the best saved model.
    ###############################################################################
    if args.parallel == 'DDP':
        dist.barrier()
        # configure map_location properly
        rank0_devices = [x - rank * len(device) for x in device]
        device_pairs = zip(rank0_devices, device)
        map_location = {'cuda:%d' % x: 'cuda:%d' % y for x, y in device_pairs}
        model.load_state_dict(
            torch.load(os.environ['SLURM_JOB_ID'] + '_' + args.save,
                       map_location=map_location))

        ###############################################################################
        # Run on test data.
        ###############################################################################
        test_loss = evaluate(test_data, model, train_dataset.vocab, ntokens,
                             criterion, args, device)
        if rank == 0:
            print('=' * 89)
            print('| End of training | test loss {:5.2f} | test ppl {:8.2f}'.
                  format(test_loss, math.exp(test_loss)))
            print('=' * 89)
            print_loss_log(os.environ['SLURM_JOB_ID'] + '_mlm_loss.txt',
                           train_loss_log, val_loss_log, test_loss, args)

            ###############################################################################
            # Save the bert model layer
            ###############################################################################
            with open(os.environ['SLURM_JOB_ID'] + '_' + args.save,
                      'wb') as f:
                torch.save(model.module.bert_model, f)
            with open(os.environ['SLURM_JOB_ID'] + '_mlm_model.pt',
                      'wb') as f:
                torch.save(model.module, f)
    else:
        with open(args.save, 'rb') as f:
            model = torch.load(f)
        test_loss = evaluate(test_data, model, train_dataset.vocab, ntokens,
                             criterion, args, device)
        print('=' * 89)
        print('| End of training | test loss {:5.2f} | test ppl {:8.2f}'.
              format(test_loss, math.exp(test_loss)))
        print('=' * 89)
        print_loss_log('mlm_loss.txt', train_loss_log, val_loss_log,
                       test_loss, args)

        ###############################################################################
        # Save the bert model layer
        ###############################################################################
        # In this branch the model was never wrapped in DDP, so there is no
        # `.module` attribute; the original's `model.module` would raise here.
        with open(args.save, 'wb') as f:
            torch.save(model.bert_model, f)
        with open('mlm_model.pt', 'wb') as f:
            torch.save(model, f)
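# ---------------------------------------------------------------------------
# batchify() is used above but not defined in this file. A minimal sketch in
# the spirit of the classic PyTorch word_language_model recipe is shown here,
# assuming `data` is a flat 1-D LongTensor of token ids; the actual helper in
# this example may also use `args` (e.g. to move the tensor to a device).
# ---------------------------------------------------------------------------
def batchify(data, bsz, args):
    # Number of full columns of height `bsz` that fit in the token stream.
    nbatch = data.size(0) // bsz
    # Trim off any trailing tokens that would leave an incomplete column.
    data = data.narrow(0, 0, nbatch * bsz)
    # Reshape to (bsz, nbatch) and transpose to (seq_len, bsz) so that each
    # column is a contiguous stretch of text.
    return data.view(bsz, -1).t().contiguous()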