def main():
    """Build the configuration, data and model, then run training.

    Reads the run configuration, loads the train/dev data and both
    vocabularies, creates the model together with its train/validate
    functions, moves the model to the configured device, initialises its
    parameters and finally hands everything to ``Trainer``.
    """
    config = setup_config()
    train_data, dev_data, vocab_src, vocab_tgt = load_dataset_joey(config)
    model, train_fn, validate_fn = create_model(vocab_src, vocab_tgt, config)

    # Move the model before initialising it, as the original code does.
    target_device = torch.device(config["device"])
    model.to(target_device)

    # Padding indices are looked up in each vocabulary's ``stoi`` table.
    pad_token = config["pad"]
    init_model(model, vocab_src.stoi[pad_token], vocab_tgt.stoi[pad_token],
               config)

    Trainer(model, train_fn, validate_fn, vocab_src, vocab_tgt, train_data,
            dev_data, config).train_model()
def train(model_xy, model_yx, bi_train_fn, mono_train_fn, validate_fn,
          bucketing_dl_xy, dev_data, cycle_iterate_dl_x, cycle_iterate_dl_y,
          vocab_src, vocab_tgt, config):
    """Jointly train the x->y and y->x models, with checkpoint/resume.

    Each epoch walks through ``config["curriculum"]`` (a whitespace-separated
    sequence of ``x``, ``y``, ``xy`` / ``yx`` batch types) until
    ``num_batches`` steps have been counted, then evaluates both directions,
    steps the schedulers and, once the bilingual warmup is over, tracks the
    product of the two validation scores for checkpointing and patience.

    Args:
        model_xy, model_yx: The two translation models. They must provide
            ``generative_parameters()``, ``inference_parameters()``,
            ``train()``, ``state_dict()`` and ``load_state_dict()``.
        bi_train_fn: Forwarded to ``bilingual_step``.
        mono_train_fn: Forwarded to ``monolingual_step``.
        validate_fn: Forwarded to ``evaluate``.
        bucketing_dl_xy: Iterable yielding ``(sentences_x, sentences_y)``
            pairs. It is iterated once to count the batches per epoch and
            then cycled endlessly.
        dev_data: Forwarded to ``evaluate``.
        cycle_iterate_dl_x, cycle_iterate_dl_y: Monolingual data sources,
            forwarded to ``monolingual_step``.
        vocab_src, vocab_tgt: Vocabularies, indexable by ``PAD_TOKEN``.
        config: Mapping. Keys read here: ``out_dir``, ``session``,
            ``model_type``, ``curriculum``, ``device``, ``num_epochs``,
            ``bilingual_warmup``, ``reset_opt``, ``src``, ``tgt``,
            ``patience``.

    Raises:
        ValueError: If no entry of the curriculum can advance the step
            counter in the current epoch (the loop would otherwise never
            terminate).
    """
    print("Training...")
    optimizers_xy, schedulers_xy = create_optimizers(
        model_xy.generative_parameters(), model_xy.inference_parameters(),
        config)
    optimizers_yx, schedulers_yx = create_optimizers(
        model_yx.generative_parameters(), model_yx.inference_parameters(),
        config)

    saved_epoch = 0
    patience_counter = 0
    max_bleu = 0.0
    converged_counter = 0
    num_batches = sum(1 for _ in iter(bucketing_dl_xy))

    device = torch.device(
        "cpu") if config["device"] == "cpu" else torch.device("cuda:0")
    # Only this model type stores separate "inf" optimizers/schedulers.
    has_inference = config["model_type"] == "coaevnmt"

    checkpoints_path = "{}/{}/checkpoints".format(config["out_dir"],
                                                  config["session"])
    checkpoint_file = '{}/{}'.format(checkpoints_path, config["session"])

    # The checkpoint is a single file named after the session.
    checkpoints = []
    if os.path.exists(checkpoints_path):
        checkpoints = [
            cp for cp in sorted(os.listdir(checkpoints_path))
            if cp == config["session"]
        ]

    if checkpoints:
        # map_location lets a checkpoint written on a GPU be resumed on a
        # host without one (and vice versa).
        state = torch.load('{}/{}'.format(checkpoints_path, checkpoints[-1]),
                           map_location=device)
        saved_epoch = state['epoch']
        patience_counter = state['patience_counter']
        max_bleu = state['max_bleu']
        model_xy.load_state_dict(state['state_dict_xy'])
        model_yx.load_state_dict(state['state_dict_yx'])
        optimizers_xy["gen"].load_state_dict(state['optimizer_xy_gen'])
        optimizers_yx["gen"].load_state_dict(state['optimizer_yx_gen'])
        schedulers_xy["gen"].load_state_dict(state['scheduler_xy_gen'])
        schedulers_yx["gen"].load_state_dict(state['scheduler_yx_gen'])
        if has_inference:
            optimizers_xy["inf"].load_state_dict(state['optimizer_xy_inf'])
            optimizers_yx["inf"].load_state_dict(state['optimizer_yx_inf'])
            schedulers_xy["inf"].load_state_dict(state['scheduler_xy_inf'])
            schedulers_yx["inf"].load_state_dict(state['scheduler_yx_inf'])
    else:
        # Nothing to resume from: start from freshly initialised models.
        init_model(model_xy, vocab_src[PAD_TOKEN], vocab_tgt[PAD_TOKEN],
                   config)
        init_model(model_yx, vocab_tgt[PAD_TOKEN], vocab_src[PAD_TOKEN],
                   config)

    curriculum = config["curriculum"].split()
    cycle_iterate_dl_xy = cycle(bucketing_dl_xy)
    cycle_curriculum = cycle(curriculum)

    has_bilingual = "xy" in curriculum or "yx" in curriculum
    has_monolingual = "x" in curriculum or "y" in curriculum
    # Monolingual batches only count as steps when the curriculum has no
    # "xy" entry; otherwise the bilingual batches define the epoch length.
    mono_counts_as_step = "xy" not in curriculum

    for epoch in range(saved_epoch, config["num_epochs"]):
        after_warmup = epoch >= config["bilingual_warmup"]

        # Monolingual entries are skipped during warmup, so without a
        # bilingual entry the step counter could never reach num_batches.
        can_advance = has_bilingual or (has_monolingual and after_warmup)
        if num_batches > 0 and not can_advance:
            raise ValueError(
                "curriculum {!r} cannot advance training in epoch {}".format(
                    config["curriculum"], epoch + 1))

        # Reset optimizers after bilingual warmup
        if epoch == config["bilingual_warmup"] and config["reset_opt"]:
            optimizers_xy, schedulers_xy = create_optimizers(
                model_xy.generative_parameters(),
                model_xy.inference_parameters(), config)
            optimizers_yx, schedulers_yx = create_optimizers(
                model_yx.generative_parameters(),
                model_yx.inference_parameters(), config)

        step = 0
        while step < num_batches:
            batch_type = next(cycle_curriculum)
            model_xy.train()
            model_yx.train()
            loss = None
            # Remember the 1-based batch number before ``step`` is advanced
            # so the progress line never reads "N+1/N".
            batch_number = step + 1

            if batch_type == 'y' and after_warmup:
                loss = monolingual_step(model_xy, model_yx,
                                        cycle_iterate_dl_y, mono_train_fn,
                                        optimizers_xy, vocab_src, vocab_tgt,
                                        config, step, device)
                if mono_counts_as_step:
                    step += 1
            elif batch_type == 'x' and after_warmup:
                loss = monolingual_step(model_yx, model_xy,
                                        cycle_iterate_dl_x, mono_train_fn,
                                        optimizers_yx, vocab_tgt, vocab_src,
                                        config, step, device)
                if mono_counts_as_step:
                    step += 1
            elif batch_type == 'xy' or batch_type == 'yx':
                sentences_x, sentences_y = next(cycle_iterate_dl_xy)
                loss = bilingual_step(model_xy, model_yx, sentences_x,
                                      sentences_y, bi_train_fn, optimizers_xy,
                                      optimizers_yx, vocab_src, vocab_tgt,
                                      config, step, device)
                step += 1

            # Print progress and loss; ``None`` means no step was taken, a
            # loss of exactly zero is still worth reporting.
            if loss is not None:
                print(
                    "Epoch: {:03d}/{:03d}, Batch {:05d}/{:05d}, {}-Loss: {:.2f}"
                    .format(epoch + 1, config["num_epochs"], batch_number,
                            num_batches, batch_type, loss))

        val_bleu_xy = evaluate(model_xy, validate_fn, dev_data, vocab_src,
                               vocab_tgt, epoch, config, direction="xy")
        val_bleu_yx = evaluate(model_yx, validate_fn, dev_data, vocab_tgt,
                               vocab_src, epoch, config, direction="yx")

        scheduler_step(schedulers_xy, val_bleu_xy)
        scheduler_step(schedulers_yx, val_bleu_yx)

        print("Blue scores: {}-{}: {}, {}-{}: {}".format(
            config["src"], config["tgt"], val_bleu_xy, config["tgt"],
            config["src"], val_bleu_yx))

        # Model selection only starts once the bilingual warmup is over.
        if after_warmup:
            combined_bleu = float(val_bleu_xy * val_bleu_yx)
            if combined_bleu > max_bleu:
                max_bleu = combined_bleu
                patience_counter = 0

                # Save checkpoint
                os.makedirs(checkpoints_path, exist_ok=True)
                state = {
                    'epoch': epoch + 1,
                    'patience_counter': patience_counter,
                    'max_bleu': max_bleu,
                    'state_dict_xy': model_xy.state_dict(),
                    'state_dict_yx': model_yx.state_dict(),
                    'optimizer_xy_gen': optimizers_xy["gen"].state_dict(),
                    'optimizer_yx_gen': optimizers_yx["gen"].state_dict(),
                    'scheduler_xy_gen': schedulers_xy["gen"].state_dict(),
                    'scheduler_yx_gen': schedulers_yx["gen"].state_dict(),
                }
                if has_inference:
                    state['optimizer_xy_inf'] = optimizers_xy[
                        "inf"].state_dict()
                    state['optimizer_yx_inf'] = optimizers_yx[
                        "inf"].state_dict()
                    state['scheduler_xy_inf'] = schedulers_xy[
                        "inf"].state_dict()
                    state['scheduler_yx_inf'] = schedulers_yx[
                        "inf"].state_dict()
                torch.save(state, checkpoint_file)
            else:
                patience_counter += 1
                if patience_counter >= config["patience"]:
                    # Out of patience: count a convergence, forget the best
                    # score and restart with fresh optimizers. Training
                    # stops the second time this happens.
                    max_bleu = 0
                    patience_counter = 0
                    converged_counter += 1
                    optimizers_xy, schedulers_xy = create_optimizers(
                        model_xy.generative_parameters(),
                        model_xy.inference_parameters(), config)
                    optimizers_yx, schedulers_yx = create_optimizers(
                        model_yx.generative_parameters(),
                        model_yx.inference_parameters(), config)
                    print("Times converged: {}".format(converged_counter))
                    if converged_counter >= 2:
                        break
def train(model, train_fn, validate_fn, bucketing_dl_xy, dev_data, vocab_src,
          vocab_tgt, config, cycle_iterate_dl_back=None):
    """Train a single translation model with checkpointing and early stopping.

    Every epoch runs ``num_batches`` bilingual steps (optionally preceded by
    one back-translation step each), evaluates on the dev data, steps the
    scheduler with the validation score and saves a checkpoint whenever the
    score improves. Training stops early after ``config["patience"]``
    consecutive epochs without improvement.

    Args:
        model: Model to train. Must provide ``parameters()``, ``train()``,
            ``state_dict()`` and ``load_state_dict()``.
        train_fn: Forwarded to ``bilingual_step``.
        validate_fn: Forwarded to ``evaluate``.
        bucketing_dl_xy: Iterable yielding ``(sentences_x, sentences_y)``
            pairs. It is iterated once to count the batches per epoch and
            then cycled endlessly.
        dev_data: Forwarded to ``evaluate``.
        vocab_src, vocab_tgt: Vocabularies, indexable by ``PAD_TOKEN``.
        config: Mapping. Keys read here: ``out_dir``, ``session``,
            ``device``, ``num_epochs``, ``patience``.
        cycle_iterate_dl_back: Optional iterator yielding
            ``(sentences_x, sentences_y)`` back-translation batches. When
            ``None`` no back-translation steps are taken.
    """
    print("Training...")
    optimizer, scheduler = create_optimizer(model.parameters(), config)

    saved_epoch = 0
    patience_counter = 0
    max_bleu = 0.0
    num_batches = sum(1 for _ in iter(bucketing_dl_xy))

    device = torch.device(
        "cpu") if config["device"] == "cpu" else torch.device("cuda:0")

    checkpoints_path = "{}/{}/checkpoints".format(config["out_dir"],
                                                  config["session"])
    checkpoint_file = '{}/{}'.format(checkpoints_path, config["session"])

    # The checkpoint is a single file named after the session.
    checkpoints = []
    if os.path.exists(checkpoints_path):
        checkpoints = [
            cp for cp in sorted(os.listdir(checkpoints_path))
            if cp == config["session"]
        ]

    if checkpoints:
        # map_location lets a checkpoint written on a GPU be resumed on a
        # host without one (and vice versa).
        state = torch.load('{}/{}'.format(checkpoints_path, checkpoints[-1]),
                           map_location=device)
        saved_epoch = state['epoch']
        patience_counter = state['patience_counter']
        max_bleu = state['max_bleu']
        model.load_state_dict(state['state_dict'])
        optimizer.load_state_dict(state['optimizer'])
        scheduler.load_state_dict(state['scheduler'])
    else:
        # Nothing to resume from: start from a freshly initialised model.
        init_model(model, vocab_src[PAD_TOKEN], vocab_tgt[PAD_TOKEN], config)

    cycle_iterate_dl_xy = cycle(bucketing_dl_xy)
    use_back_translation = cycle_iterate_dl_back is not None

    for epoch in range(saved_epoch, config["num_epochs"]):
        for step in range(num_batches):
            model.train()

            # Back-translation data. This extra batch deliberately does not
            # advance ``step``: the epoch length is defined by the bilingual
            # batches alone.
            if use_back_translation:
                sentences_x, sentences_y = next(cycle_iterate_dl_back)
                loss = bilingual_step(model, sentences_x, sentences_y,
                                      train_fn, optimizer, vocab_src,
                                      vocab_tgt, config, step, device)
                print(
                    "Epoch: {:03d}/{:03d}, Batch {:05d}/{:05d}, Back-Loss: {:.2f}"
                    .format(epoch + 1, config["num_epochs"], step + 1,
                            num_batches, loss))

            # Bilingual data
            sentences_x, sentences_y = next(cycle_iterate_dl_xy)
            loss = bilingual_step(model, sentences_x, sentences_y, train_fn,
                                  optimizer, vocab_src, vocab_tgt, config,
                                  step, device)
            print("Epoch: {:03d}/{:03d}, Batch {:05d}/{:05d}, xy-Loss: {:.2f}".
                  format(epoch + 1, config["num_epochs"], step + 1,
                         num_batches, loss))

        val_bleu = evaluate(model, validate_fn, dev_data, vocab_src,
                            vocab_tgt, epoch, config)
        scheduler.step(float(val_bleu))
        print("Blue score: {}".format(val_bleu))

        if float(val_bleu) > max_bleu:
            max_bleu = float(val_bleu)
            patience_counter = 0

            # Save checkpoint
            os.makedirs(checkpoints_path, exist_ok=True)
            state = {
                'epoch': epoch + 1,
                'patience_counter': patience_counter,
                'max_bleu': max_bleu,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict()
            }
            torch.save(state, checkpoint_file)
        else:
            patience_counter += 1
            if patience_counter >= config["patience"]:
                break