def question_1i_sanity_check():
    """ Sanity check for nmt_model.py: basic shape check. """
    print("-" * 80)
    print("Running Sanity Check for Question 1i: NMT")
    print("-" * 80)

    # Build a dummy vocab; an empty VocabEntry still carries the special
    # tokens, so unseen words map to <unk>
    src_vocab_entry = VocabEntry()
    tgt_vocab_entry = VocabEntry()
    dummy_vocab = Vocab(src_vocab_entry, tgt_vocab_entry)

    word_embed_size = 5
    hidden_size = 10
    nmt = NMT(word_embed_size, hidden_size, dummy_vocab)

    # forward() expects tokenized sentences: a list of sentences,
    # each a list of words (not a single untokenized string)
    source = [["Hello", "my", "friend"], ["How", "are", "you"]]
    target = [["Bonjour", "mon", "ami"], ["Comment", "vas", "tu"]]

    output = nmt.forward(source, target)
    print(output)
    # output_expected_size = [sentence_length, BATCH_SIZE, EMBED_SIZE]
    # assert list(output.size()) == output_expected_size, \
    #     "output shape is incorrect: it should be:\n{} but is:\n{}".format(
    #         output_expected_size, list(output.size()))

    print("Sanity Check Passed for Question 1i: NMT!")
    print("-" * 80)
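# `train` below restores optimizer state from a checkpoint and then calls
# `optimizer_to(optimizer, device)`, which is not defined in this file. A
# minimal sketch of that helper is given here, assuming the standard
# torch.optim layout where per-parameter state tensors live in
# `optimizer.state`; swap in the project's own implementation if one exists.
def optimizer_to(optimizer, device):
    """Move all tensors in an optimizer's state to `device` in place."""
    for state in optimizer.state.values():
        for key, value in state.items():
            if torch.is_tensor(value):
                state[key] = value.to(device)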
def train(args: Dict):
    """ Train the NMT Model.
    @param args (Dict): args from cmd line
    """
    train_data_src = read_corpus(args['--train-src'], source='src')
    train_data_tgt = read_corpus(args['--train-tgt'], source='tgt')

    dev_data_src = read_corpus(args['--dev-src'], source='src')
    dev_data_tgt = read_corpus(args['--dev-tgt'], source='tgt')

    train_data = list(zip(train_data_src, train_data_tgt))
    dev_data = list(zip(dev_data_src, dev_data_tgt))

    train_batch_size = int(args['--batch-size'])
    clip_grad = float(args['--clip-grad'])
    valid_niter = int(args['--valid-niter'])
    log_every = int(args['--log-every'])
    model_save_path = args['--save-to']

    vocab = Vocab.load(args['--vocab'])

    model = NMT(embed_size=int(args['--embed-size']),
                hidden_size=int(args['--hidden-size']),
                dropout_rate=float(args['--dropout']),
                vocab=vocab)
    model.train()

    uniform_init = float(args['--uniform-init'])
    if np.abs(uniform_init) > 0.:
        print('uniformly initialize parameters [-%f, +%f]' % (uniform_init, uniform_init),
              file=sys.stderr)
        for p in model.parameters():
            p.data.uniform_(-uniform_init, uniform_init)

    vocab_mask = torch.ones(len(vocab.tgt))
    vocab_mask[vocab.tgt['<pad>']] = 0

    device = torch.device("cuda:0" if args['--cuda'] else "cpu")
    print('use device: %s' % device, file=sys.stderr)

    model = model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=float(args['--lr']))

    # Set counters
    num_trial = 0
    train_iter = patience = cum_loss = report_loss = cum_tgt_words = report_tgt_words = 0
    cum_examples = report_examples = epoch = valid_num = 0
    hist_valid_scores = []
    fwd_time = train_time = begin_time = time.time()

    # Begin training
    print('begin Maximum Likelihood training')

    while True:
        epoch += 1

        # Loop over all training data in shuffled batches
        for src_sents, tgt_sents in batch_iter(train_data, batch_size=train_batch_size, shuffle=True):
            # Sort the batch by source length (number of words), longest first,
            # keeping each source sentence aligned with its target (sorting the
            # two lists independently would break the pairing)
            batch = sorted(zip(src_sents, tgt_sents), key=lambda e: len(e[0]), reverse=True)
            src_sents, tgt_sents = [list(t) for t in zip(*batch)]

            train_iter += 1

            # Zero out gradients; PyTorch accumulates them
            optimizer.zero_grad()

            # Get loss: forward() returns per-sentence log-likelihoods
            train_batch_losses = -model.forward(src_sents, tgt_sents)
            batch_loss = train_batch_losses.sum()
            loss = batch_loss / train_batch_size

            # Get gradients
            loss.backward()

            # Clip gradient
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad)

            # Step
            optimizer.step()

            # Report progress
            batch_losses_val = batch_loss.item()
            report_loss += batch_losses_val
            cum_loss += batch_losses_val

            # Count target words to predict, omitting the leading `<s>`
            tgt_words_num_to_predict = sum(len(s[1:]) for s in tgt_sents)
            report_tgt_words += tgt_words_num_to_predict
            cum_tgt_words += tgt_words_num_to_predict
            report_examples += train_batch_size
            cum_examples += train_batch_size

            if train_iter % log_every == 0:
                print('epoch %d, iter %d, avg. loss %.2f, avg. ppl %.2f '
                      'cum. examples %d, speed %.2f words/sec, time elapsed %.2f sec'
                      % (epoch, train_iter,
                         report_loss / report_examples,
                         math.exp(report_loss / report_tgt_words),
                         cum_examples,
                         report_tgt_words / (time.time() - train_time),
                         time.time() - begin_time),
                      file=sys.stderr)

                train_time = time.time()
                report_loss = report_tgt_words = report_examples = 0.

            # Test saving and loading the model
            # test_save_load_model(model=model, optimizer=optimizer)

            # Perform validation
            if train_iter % valid_niter == 0:
                print('epoch %d, iter %d, cum. loss %.2f, cum. ppl %.2f, cum. examples %d'
                      % (epoch, train_iter,
                         cum_loss / cum_examples,
                         np.exp(cum_loss / cum_tgt_words),
                         cum_examples),
                      file=sys.stderr)

                cum_loss = cum_examples = cum_tgt_words = 0.
                valid_num += 1

                print('begin validation ...', file=sys.stderr)

                # Compute dev. ppl; the dev batch size can be a bit larger
                # dev_ppl = evaluate_ppl(model, dev_data, batch_size=128)
                dev_ppl = evaluate_ppl(model, dev_data, batch_size=train_batch_size * 2)
                valid_metric = -dev_ppl

                print('validation: iter %d, dev. ppl %f' % (train_iter, dev_ppl),
                      file=sys.stderr)

                is_better = len(hist_valid_scores) == 0 or valid_metric > max(hist_valid_scores)
                hist_valid_scores.append(valid_metric)

                if is_better:
                    patience = 0
                    print('save currently the best model to [%s]' % model_save_path,
                          file=sys.stderr)
                    model.save(model_save_path)

                    # Also save the optimizer's state
                    torch.save(optimizer.state_dict(), model_save_path + '.optim')
                elif patience < int(args['--patience']):
                    patience += 1
                    print('hit patience %d' % patience, file=sys.stderr)

                    if patience == int(args['--patience']):
                        num_trial += 1
                        print('hit #%d trial' % num_trial, file=sys.stderr)
                        if num_trial == int(args['--max-num-trial']):
                            print('early stop!', file=sys.stderr)
                            exit(0)

                        # Decay lr, and restore from previously best checkpoint
                        lr = optimizer.param_groups[0]['lr'] * float(args['--lr-decay'])
                        print('load previously best model and decay learning rate to %f' % lr,
                              file=sys.stderr)

                        # Load the checkpoint onto the CPU first, then move it
                        # to the target device.
                        # See https://github.com/pytorch/pytorch/issues/7415 and
                        # https://discuss.pytorch.org/t/on-a-cpu-device-how-to-load-checkpoint-saved-on-gpu-device/349 and
                        # https://github.com/pytorch/pytorch/issues/9139
                        params = torch.load(model_save_path, map_location='cpu')
                        model.load_state_dict(params['state_dict'])
                        model = model.to(device)

                        print('restore parameters of the optimizers', file=sys.stderr)
                        optimizer.load_state_dict(
                            torch.load(model_save_path + '.optim', map_location='cpu'))
                        optimizer_to(optimizer, device)

                        # Set new lr
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr

                        # Reset patience
                        patience = 0

        if epoch == int(args['--max-epoch']):
            print('reached maximum number of epochs!', file=sys.stderr)
            exit(0)
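# `train` calls `evaluate_ppl`, which is not shown in this section. A minimal
# sketch is given here, assuming (as in the training loop above) that
# model.forward returns per-sentence log-likelihoods and that perplexity is
# exp(total loss / total target words predicted, excluding the leading <s>).
def evaluate_ppl(model, dev_data, batch_size=32):
    """ Evaluate perplexity on dev sentences.
    @param model (NMT): NMT model
    @param dev_data (list of (src_sent, tgt_sent)): source/target sentence pairs
    @param batch_size (int): dev batch size
    @returns ppl (float): perplexity on dev sentences
    """
    was_training = model.training
    model.eval()

    cum_loss = 0.
    cum_tgt_words = 0.

    # Disable gradient tracking during evaluation
    with torch.no_grad():
        for src_sents, tgt_sents in batch_iter(dev_data, batch_size):
            # Mirror the training loop: sort pairs by source length, longest first
            batch = sorted(zip(src_sents, tgt_sents), key=lambda e: len(e[0]), reverse=True)
            src_sents, tgt_sents = [list(t) for t in zip(*batch)]

            loss = -model.forward(src_sents, tgt_sents).sum()
            cum_loss += loss.item()
            cum_tgt_words += sum(len(s[1:]) for s in tgt_sents)  # omit leading `<s>`

        ppl = np.exp(cum_loss / cum_tgt_words)

    if was_training:
        model.train()

    return ppl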