def train(opts):
    """Train a model and save its weights to ``opts.model_path``.

    After every training epoch the model is scored with average BLEU on the
    dev set. Training stops in one of two ways:

    * ``opts.patience is None``: run exactly ``opts.epochs`` epochs.
    * otherwise: run until ``opts.patience`` consecutive epochs have failed
      to strictly improve on the best dev BLEU seen so far.

    Args:
        opts: parsed options. Attributes read here: ``seed``,
            ``french_vocab``, ``english_vocab``, ``train_prefixes`` and
            ``dev_prefixes`` (objects with a ``read()`` method yielding
            newline-separated prefixes), ``training_dir``, ``source_lang``,
            ``batch_size``, ``device``, ``patience``, ``epochs`` and
            ``model_path``. ``opts`` is also forwarded to ``init``.
    """
    torch.manual_seed(opts.seed)

    french_word2id = a2_dataloader.read_word2id_from_file(opts.french_vocab)
    english_word2id = a2_dataloader.read_word2id_from_file(opts.english_vocab)
    pin_memory = opts.device.type == 'cuda'

    train_prefixes = opts.train_prefixes.read().strip().split('\n')
    train_dataloader = a2_dataloader.HansardDataLoader(
        opts.training_dir,
        french_word2id,
        english_word2id,
        opts.source_lang,
        train_prefixes,
        batch_size=opts.batch_size,
        shuffle=True,
        pin_memory=pin_memory,
        num_workers=1,
    )
    del train_prefixes

    # The dev split is read from the same directory as the training split;
    # only the file prefixes differ.
    dev_prefixes = opts.dev_prefixes.read().strip().split('\n')
    dev_dataloader = a2_dataloader.HansardDataLoader(
        opts.training_dir,
        french_word2id,
        english_word2id,
        opts.source_lang,
        dev_prefixes,
        batch_size=opts.batch_size,
        pin_memory=pin_memory,
        num_workers=1,
    )
    # Drop the local references to the vocab maps and prefix list; they are
    # not used again in this function.
    del dev_prefixes, french_word2id, english_word2id

    model = init(opts, train_dataloader)
    model.to(opts.device)
    optimizer = torch.optim.Adam(model.parameters())

    # Exactly one of the two stopping criteria is finite.
    if opts.patience is None:
        max_epochs = opts.epochs
        patience = float('inf')
    else:
        max_epochs = float('inf')
        patience = opts.patience

    best_bleu = 0.
    num_poor = 0
    epoch = 1
    while epoch <= max_epochs and num_poor < patience:
        model.train()
        loss = a2_training_and_testing.train_for_epoch(
            model, train_dataloader, optimizer, opts.device)

        model.eval()
        bleu = a2_training_and_testing.compute_average_bleu_over_dataset(
            model,
            dev_dataloader,
            dev_dataloader.dataset.target_sos,
            dev_dataloader.dataset.target_eos,
            opts.device,
        )
        print(f'Epoch {epoch}: loss={loss}, BLEU={bleu}')

        # A tie is not an improvement. Counting it as one would let a model
        # stuck at a constant BLEU (e.g. 0.0) train forever in patience
        # mode, where max_epochs is infinite.
        if bleu <= best_bleu:
            num_poor += 1
        else:
            num_poor = 0
            best_bleu = bleu
        epoch += 1

    if epoch > max_epochs:
        print(f'Finished {max_epochs} epochs')
    else:
        print(f'BLEU did not improve after {patience} epochs. Done.')

    # Move to CPU before saving so the stored state dict holds CPU tensors.
    model.cpu()
    torch.save(model.state_dict(), opts.model_path)
def test(opts):
    """Load saved weights and report average BLEU over the test set.

    Args:
        opts: parsed options. Attributes read here: ``french_vocab``,
            ``english_vocab``, ``testing_dir``, ``source_lang``,
            ``batch_size``, ``device`` and ``model_path``. ``opts`` is also
            forwarded to ``init``.
    """
    french_word2id = a2_dataloader.read_word2id_from_file(opts.french_vocab)
    english_word2id = a2_dataloader.read_word2id_from_file(opts.english_vocab)
    dataloader = a2_dataloader.HansardDataLoader(
        opts.testing_dir,
        french_word2id,
        english_word2id,
        opts.source_lang,
        batch_size=opts.batch_size,
        pin_memory=(opts.device.type == 'cuda'),
    )
    del french_word2id, english_word2id

    model = init(opts, dataloader)
    # Map every tensor to CPU while deserializing so that a checkpoint
    # containing CUDA tensors can still be loaded on a CPU-only host; the
    # model is moved to the requested device right after.
    state_dict = torch.load(opts.model_path, map_location='cpu')
    model.load_state_dict(state_dict)
    del state_dict
    model.to(opts.device)

    model.eval()
    bleu = a2_training_and_testing.compute_average_bleu_over_dataset(
        model,
        dataloader,
        dataloader.dataset.target_sos,
        dataloader.dataset.target_eos,
        opts.device,
    )
    print(f'The average BLEU score over the test set was {bleu}')