import math
import os
import sys
import time
from typing import Dict

import torch
import torch as th  # some of the snippets below refer to torch as `th`

# Project-specific helpers (Vocab, MTDataset, MTNoisyDataset, NMT, read_corpus,
# batch_iter, device) are assumed to be importable from the assignment's own modules.


def load_data(src_lang, tgt_lang, base_folder, bpe=False):
    """Load the parallel training/dev data and build a vocabulary.

    With `bpe=True` the BPE-segmented files are used; otherwise the raw files are
    used and the vocabulary is built with `min_freq=2`.
    """
    if bpe:
        train_prefix = os.path.join(
            base_folder, f"{src_lang}{tgt_lang}_parallel.bpe.train")
        dev_prefix = os.path.join(
            base_folder, f"{src_lang}{tgt_lang}_parallel.bpe.dev")
        print("loading", train_prefix, dev_prefix)
        vocab = Vocab.from_data_files(
            f"{train_prefix}.{src_lang}",
            f"{train_prefix}.{tgt_lang}",
        )
    else:
        train_prefix = os.path.join(
            base_folder, f"{src_lang}{tgt_lang}_parallel.train")
        dev_prefix = os.path.join(
            base_folder, f"{src_lang}{tgt_lang}_parallel.dev")
        vocab = Vocab.from_data_files(
            f"{train_prefix}.{src_lang}",
            f"{train_prefix}.{tgt_lang}",
            min_freq=2,
        )
        print("loading", train_prefix, dev_prefix)
    train = MTNoisyDataset(vocab, train_prefix, src_lang=src_lang, tgt_lang=tgt_lang)
    valid = MTNoisyDataset(vocab, dev_prefix, src_lang=src_lang, tgt_lang=tgt_lang)
    return vocab, train, valid


def load_data(src_lang, tgt_lang, cached_folder="assignment2/data", overwrite=False):
    """Load data (and cache to file)."""
    cached_file = os.path.join(cached_folder, f"{src_lang}-{tgt_lang}.pt")
    if not os.path.isfile(cached_file) or overwrite:
        base_folder = os.path.join("assignment2", "data", f"{src_lang}_{tgt_lang}")
        train_prefix = os.path.join(
            base_folder, f"{src_lang}{tgt_lang}_parallel.bpe.train")
        dev_prefix = os.path.join(
            base_folder, f"{src_lang}{tgt_lang}_parallel.bpe.dev")
        vocab = Vocab.from_data_files(
            f"{train_prefix}.{src_lang}",
            f"{train_prefix}.{tgt_lang}",
        )
        train = MTDataset(vocab, train_prefix, src_lang=src_lang, tgt_lang=tgt_lang)
        valid = MTDataset(vocab, dev_prefix, src_lang=src_lang, tgt_lang=tgt_lang)
        th.save([vocab, train, valid], cached_file)
    # Load cached dataset
    return th.load(cached_file)


# Simplified variant of load_data with hard-coded French->English paths and caching.
def load_data(cached_data="data/cached.pt", overwrite=False):
    if not os.path.isfile(cached_data) or overwrite:
        vocab = Vocab.from_data_files("data/train.bpe.fr", "data/train.bpe.en")
        train = MTDataset(vocab, "data/train.bpe", src_lang="fr", tgt_lang="en")
        valid = MTDataset(vocab, "data/valid.bpe", src_lang="fr", tgt_lang="en")
        th.save([vocab, train, valid], cached_data)
    # Load cached dataset
    return th.load(cached_data)
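

# Usage sketch (illustrative, not part of the original code): the three `load_data`
# definitions above share a name and presumably come from different files/experiments.
# Assuming the last (cached, fr->en) variant is the one in scope, a driver could call
# it like this; the cache path shown is just its default value.
def _load_data_example():
    """Illustrative helper only; the names `train_set`/`valid_set` are local."""
    vocab, train_set, valid_set = load_data("data/cached.pt", overwrite=False)
    print("vocab size = %d" % len(vocab))
    return vocab, train_set, valid_set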


def train(args: Dict[str, str]):
    train_data = read_corpus(args['--train-src'], source='src')
    dev_data = read_corpus(args['--dev-src'], source='src')

    train_batch_size = int(args['--batch-size'])  # train_batch_size = 64
    clip_grad = float(args['--clip-grad'])
    valid_niter = int(args['--valid-niter'])  # valid_niter = 100
    log_every = int(args['--log-every'])  # log_every = 1
    dropout_rate = float(args['--dropout'])
    model_save_path = args['--save-to']
    optim_save_path = "work_dir/optim.bin"

    # vocab = pickle.load(open(args['--vocab'], 'rb'))
    # initialize vocab
    vocab = Vocab.from_data_files(args['--vocab'])
    print("vocab size = %d" % len(vocab))

    model = NMT(embed_size=int(args['--embed-size']),
                hidden_size=int(args['--hidden-size']),
                dropout_rate=dropout_rate,
                vocab=vocab).to(device)
    # model = torch.load(model_save_path)

    lr = float(args['--lr'])
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    num_trial = 0
    train_iter = patience = cum_loss = report_loss = cumulative_tgt_words = report_tgt_words = 0
    cumulative_examples = report_examples = epoch = valid_num = 0
    hist_valid_scores = []
    train_time = begin_time = time.time()
    print('begin Maximum Likelihood training')

    while True:
        epoch += 1

        for sents in batch_iter(train_data, batch_size=train_batch_size, shuffle=True):
            train_iter += 1
            batch_size = len(sents)

            optimizer.zero_grad()
            loss = model(sents)
            report_loss += loss.item()
            cum_loss += loss.item()

            loss.backward()
            if clip_grad > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad)
            optimizer.step()

            tgt_words_num_to_predict = sum(len(s[1:]) for s in sents)  # omitting leading `<s>`
            report_tgt_words += tgt_words_num_to_predict
            cumulative_tgt_words += tgt_words_num_to_predict
            report_examples += batch_size
            cumulative_examples += batch_size

            if train_iter % log_every == 0:
                print('epoch %d, iter %d, avg. loss %.2f, avg. ppl %.2f '
                      'cum. examples %d, speed %.2f words/sec, time elapsed %.2f sec' %
                      (epoch, train_iter,
                       report_loss / report_examples,
                       math.exp(report_loss / report_tgt_words),
                       cumulative_examples,
                       report_tgt_words / (time.time() - train_time),
                       time.time() - begin_time),
                      file=sys.stderr)

                train_time = time.time()
                report_loss = report_tgt_words = report_examples = 0.

            # The following code performs validation on the dev set and controls the
            # learning-rate schedule: if the dev score is better than the last checkpoint,
            # the current model is saved; otherwise we tolerate the degradation for up to
            # `--patience` validations. If the dev score still has not improved, we reload
            # the previously saved best model (and the optimizer state), decay the learning
            # rate and continue training. This repeats for up to `--max-num-trial` times.
            if train_iter % valid_niter == 0:
                print('epoch %d, iter %d, cum. loss %.2f, cum. ppl %.2f cum. examples %d' %
                      (epoch, train_iter,
                       cum_loss / cumulative_examples,
                       math.exp(cum_loss / cumulative_tgt_words),
                       cumulative_examples),
                      file=sys.stderr)

                cum_loss = cumulative_examples = cumulative_tgt_words = 0.
                valid_num += 1

                print('begin validation ...', file=sys.stderr)

                # compute dev. ppl
                dev_ppl = model.evaluate_ppl(dev_data, batch_size=128)  # dev batch size can be a bit larger
                valid_metric = -dev_ppl

                print('validation: iter %d, dev. ppl %f' % (train_iter, dev_ppl), file=sys.stderr)

                is_better = len(hist_valid_scores) == 0 or valid_metric > max(hist_valid_scores)
                hist_valid_scores.append(valid_metric)

                if is_better:
                    patience = 0
                    print('save currently the best model to [%s]' % model_save_path, file=sys.stderr)
                    model.save(model_save_path)
                    # also save the optimizer's state so it can be restored alongside the model
                    torch.save(optimizer.state_dict(), optim_save_path)
                elif patience < int(args['--patience']):
                    patience += 1
                    print('hit patience %d' % patience, file=sys.stderr)

                    if patience == int(args['--patience']):
                        num_trial += 1
                        print('hit #%d trial' % num_trial, file=sys.stderr)
                        if num_trial == int(args['--max-num-trial']):
                            print('early stop!', file=sys.stderr)
                            exit(0)

                        # decay learning rate, and restore from previously best checkpoint
                        lr = lr * float(args['--lr-decay'])
                        print('load previously best model and decay learning rate to %f' % lr, file=sys.stderr)

                        # reload the best model and rebind the optimizer to its parameters
                        model = NMT.load(model_save_path).to(device)

                        print('restore parameters of the optimizers', file=sys.stderr)
                        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
                        optimizer.load_state_dict(torch.load(optim_save_path))
                        # apply the decayed learning rate to the restored optimizer state
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr

                        # reset patience
                        patience = 0

                if epoch == int(args['--max-epoch']):
                    print('reached maximum number of epochs!', file=sys.stderr)
                    exit(0)
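

# Entry-point sketch (an assumption, not from the original code): the string keys used
# in `train` ('--train-src', '--batch-size', '--save-to', ...) follow the docopt
# convention, so a minimal driver could look roughly like the following. The usage
# string, flag defaults, and the file name `run.py` are illustrative only.
USAGE = """
Usage:
    run.py train --train-src=<file> --dev-src=<file> --vocab=<file> [options]

Options:
    --batch-size=<int>      training batch size [default: 64]
    --embed-size=<int>      word embedding size [default: 256]
    --hidden-size=<int>     hidden state size [default: 512]
    --dropout=<float>       dropout rate [default: 0.2]
    --lr=<float>            initial learning rate [default: 0.001]
    --lr-decay=<float>      learning rate decay factor [default: 0.5]
    --clip-grad=<float>     gradient clipping threshold [default: 5.0]
    --log-every=<int>       log every n iterations [default: 10]
    --valid-niter=<int>     validate every n iterations [default: 2000]
    --patience=<int>        validations to wait before decaying the lr [default: 5]
    --max-num-trial=<int>   lr decays before early stopping [default: 5]
    --max-epoch=<int>       maximum number of epochs [default: 30]
    --save-to=<file>        model checkpoint path [default: work_dir/model.bin]
"""

if __name__ == "__main__":
    from docopt import docopt  # assumes the docopt package is available
    train(docopt(USAGE))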