Example #1
import torch

# Project-local helpers (setup_config, load_dataset_joey, create_model,
# init_model, Trainer) are assumed to be in scope.
def main():
    config = setup_config()

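    # Load training/dev data and the source/target vocabularies.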
    train_data, dev_data, vocab_src, vocab_tgt = load_dataset_joey(config)

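    # Build the model together with its training and validation step functions.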
    model, train_fn, validate_fn = create_model(vocab_src, vocab_tgt, config)
    model.to(torch.device(config["device"]))

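    # Initialize parameters, passing both vocabularies' padding indices.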
    init_model(model, vocab_src.stoi[config["pad"]],
               vocab_tgt.stoi[config["pad"]], config)

    trainer = Trainer(model, train_fn, validate_fn, vocab_src, vocab_tgt,
                      train_data, dev_data, config)
    trainer.train_model()
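
# Minimal entry-point sketch (not part of the original listing):
if __name__ == "__main__":
    main()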
Example #2
import os
from itertools import cycle

import torch

# Project-local helpers (PAD_TOKEN, create_optimizers, init_model,
# monolingual_step, bilingual_step, evaluate, scheduler_step) are assumed
# to be in scope.
def train(model_xy, model_yx, bi_train_fn, mono_train_fn, validate_fn,
          bucketing_dl_xy, dev_data, cycle_iterate_dl_x, cycle_iterate_dl_y,
          vocab_src, vocab_tgt, config):
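    """Jointly train translation models for both directions (xy and yx),
    alternating bilingual and monolingual batches according to the
    configured curriculum, with checkpointing and patience-based restarts."""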

    print("Training...")

    optimizers_xy, schedulers_xy = create_optimizers(
        model_xy.generative_parameters(), model_xy.inference_parameters(),
        config)

    optimizers_yx, schedulers_yx = create_optimizers(
        model_yx.generative_parameters(), model_yx.inference_parameters(),
        config)

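    # Bookkeeping for checkpoint resuming, early stopping, and restarts.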
    saved_epoch = 0
    patience_counter = 0
    max_bleu = 0.0
    converged_counter = 0

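    # One epoch is defined as one pass over the bilingual batches.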
    num_batches = sum(1 for _ in bucketing_dl_xy)

    checkpoints_path = "{}/{}/checkpoints".format(config["out_dir"],
                                                  config["session"])
    # Resume from a checkpoint matching the session name if one exists;
    # otherwise initialize both directions from scratch.
    checkpoints = []
    if os.path.exists(checkpoints_path):
        checkpoints = [
            cp for cp in sorted(os.listdir(checkpoints_path))
            if cp == config["session"]
        ]
    if checkpoints:
        state = torch.load('{}/{}'.format(checkpoints_path, checkpoints[-1]))
        saved_epoch = state['epoch']
        patience_counter = state['patience_counter']
        max_bleu = state['max_bleu']
        model_xy.load_state_dict(state['state_dict_xy'])
        model_yx.load_state_dict(state['state_dict_yx'])
        optimizers_xy["gen"].load_state_dict(state['optimizer_xy_gen'])
        optimizers_yx["gen"].load_state_dict(state['optimizer_yx_gen'])
        schedulers_xy["gen"].load_state_dict(state['scheduler_xy_gen'])
        schedulers_yx["gen"].load_state_dict(state['scheduler_yx_gen'])
        if config["model_type"] == "coaevnmt":
            optimizers_xy["inf"].load_state_dict(state['optimizer_xy_inf'])
            optimizers_yx["inf"].load_state_dict(state['optimizer_yx_inf'])
            schedulers_xy["inf"].load_state_dict(state['scheduler_xy_inf'])
            schedulers_yx["inf"].load_state_dict(state['scheduler_yx_inf'])
    else:
        init_model(model_xy, vocab_src[PAD_TOKEN], vocab_tgt[PAD_TOKEN],
                   config)
        init_model(model_yx, vocab_tgt[PAD_TOKEN], vocab_src[PAD_TOKEN],
                   config)

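    # The curriculum is a space-separated sequence of batch types
    # ('x', 'y', 'xy'/'yx') that is cycled through within each epoch.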
    curriculum = config["curriculum"].split()
    cycle_iterate_dl_xy = cycle(bucketing_dl_xy)
    cycle_curriculum = cycle(curriculum)
    device = torch.device("cpu" if config["device"] == "cpu" else "cuda:0")
    for epoch in range(saved_epoch, config["num_epochs"]):
        # Reset optimizers after bilingual warmup
        if epoch == config["bilingual_warmup"] and config["reset_opt"]:
            optimizers_xy, schedulers_xy = create_optimizers(
                model_xy.generative_parameters(),
                model_xy.inference_parameters(), config)

            optimizers_yx, schedulers_yx = create_optimizers(
                model_yx.generative_parameters(),
                model_yx.inference_parameters(), config)

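        # Walk the curriculum until one epoch's worth of steps is done.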
        step = 0
        while step < num_batches:
            batch_type = next(cycle_curriculum)
            model_xy.train()
            model_yx.train()
            loss = None
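            # Monolingual batches involve both directions, presumably for a
            # back-translation-style step; bilingual batches train directly.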
            if batch_type == 'y' and epoch >= config["bilingual_warmup"]:
                loss = monolingual_step(model_xy, model_yx, cycle_iterate_dl_y,
                                        mono_train_fn, optimizers_xy,
                                        vocab_src, vocab_tgt, config, step,
                                        device)
                if not "xy" in curriculum:
                    step += 1
            elif batch_type == 'x' and epoch >= config["bilingual_warmup"]:
                loss = monolingual_step(model_yx, model_xy, cycle_iterate_dl_x,
                                        mono_train_fn, optimizers_yx,
                                        vocab_tgt, vocab_src, config, step,
                                        device)
                if not "xy" in curriculum:
                    step += 1
            elif batch_type in ('xy', 'yx'):
                sentences_x, sentences_y = next(cycle_iterate_dl_xy)
                loss = bilingual_step(model_xy, model_yx, sentences_x,
                                      sentences_y, bi_train_fn, optimizers_xy,
                                      optimizers_yx, vocab_src, vocab_tgt,
                                      config, step, device)
                step += 1

            # Print progress and loss
            if loss is not None:
                print(
                    "Epoch: {:03d}/{:03d}, Batch {:05d}/{:05d}, {}-Loss: {:.2f}"
                    .format(epoch + 1, config["num_epochs"], step + 1,
                            num_batches, batch_type, loss))

        val_bleu_xy = evaluate(model_xy,
                               validate_fn,
                               dev_data,
                               vocab_src,
                               vocab_tgt,
                               epoch,
                               config,
                               direction="xy")
        val_bleu_yx = evaluate(model_yx,
                               validate_fn,
                               dev_data,
                               vocab_tgt,
                               vocab_src,
                               epoch,
                               config,
                               direction="yx")

        scheduler_step(schedulers_xy, val_bleu_xy)
        scheduler_step(schedulers_yx, val_bleu_yx)

        print("Blue scores: {}-{}: {}, {}-{}: {}".format(
            config["src"], config["tgt"], val_bleu_xy, config["tgt"],
            config["src"], val_bleu_yx))

        if epoch >= config["bilingual_warmup"]:
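            # Select on the product of both directions' validation BLEU.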
            if float(val_bleu_xy * val_bleu_yx) > max_bleu:
                max_bleu = float(val_bleu_xy * val_bleu_yx)
                patience_counter = 0

                # Save checkpoint
                if not os.path.exists(checkpoints_path):
                    os.makedirs(checkpoints_path)
                state = {
                    'epoch': epoch + 1,
                    'patience_counter': patience_counter,
                    'max_bleu': max_bleu,
                    'state_dict_xy': model_xy.state_dict(),
                    'state_dict_yx': model_yx.state_dict(),
                    'optimizer_xy_gen': optimizers_xy["gen"].state_dict(),
                    'optimizer_yx_gen': optimizers_yx["gen"].state_dict(),
                    'scheduler_xy_gen': schedulers_xy["gen"].state_dict(),
                    'scheduler_yx_gen': schedulers_yx["gen"].state_dict(),
                }
                if config["model_type"] == "coaevnmt":
                    state['optimizer_xy_inf'] = optimizers_xy[
                        "inf"].state_dict()
                    state['optimizer_yx_inf'] = optimizers_yx[
                        "inf"].state_dict()
                    state['scheduler_xy_inf'] = schedulers_xy[
                        "inf"].state_dict()
                    state['scheduler_yx_inf'] = schedulers_yx[
                        "inf"].state_dict()
                torch.save(state, '{}/{}'.format(checkpoints_path,
                                                 config["session"]))
            else:
                patience_counter += 1
                if patience_counter >= config["patience"]:
                    max_bleu = 0
                    patience_counter = 0
                    converged_counter += 1

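                    # Patience exhausted: reset optimizers and schedulers
                    # to restart training from the current weights.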
                    optimizers_xy, schedulers_xy = create_optimizers(
                        model_xy.generative_parameters(),
                        model_xy.inference_parameters(), config)

                    optimizers_yx, schedulers_yx = create_optimizers(
                        model_yx.generative_parameters(),
                        model_yx.inference_parameters(), config)
                    print("Times converged: {}".format(converged_counter))
            if converged_counter >= 2:
                break
Example #3
import os
from itertools import cycle

import torch

# Project-local helpers (PAD_TOKEN, create_optimizer, init_model,
# bilingual_step, evaluate) are assumed to be in scope.
def train(model,
          train_fn,
          validate_fn,
          bucketing_dl_xy,
          dev_data,
          vocab_src,
          vocab_tgt,
          config,
          cycle_iterate_dl_back=None):
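    """Train a single translation model on bilingual batches, optionally
    mixing in back-translated batches from cycle_iterate_dl_back."""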

    print("Training...")
    optimizer, scheduler = create_optimizer(model.parameters(), config)

    saved_epoch = 0
    patience_counter = 0
    max_bleu = 0.0

    num_batches = sum(1 for _ in bucketing_dl_xy)

    checkpoints_path = "{}/{}/checkpoints".format(config["out_dir"],
                                                  config["session"])
    # Resume from a checkpoint matching the session name if one exists;
    # otherwise initialize the model from scratch.
    checkpoints = []
    if os.path.exists(checkpoints_path):
        checkpoints = [
            cp for cp in sorted(os.listdir(checkpoints_path))
            if cp == config["session"]
        ]
    if checkpoints:
        state = torch.load('{}/{}'.format(checkpoints_path, checkpoints[-1]))
        saved_epoch = state['epoch']
        patience_counter = state['patience_counter']
        max_bleu = state['max_bleu']
        model.load_state_dict(state['state_dict'])
        optimizer.load_state_dict(state['optimizer'])
        scheduler.load_state_dict(state['scheduler'])
    else:
        init_model(model, vocab_src[PAD_TOKEN], vocab_tgt[PAD_TOKEN], config)

    cycle_iterate_dl_xy = cycle(bucketing_dl_xy)
    device = torch.device("cpu" if config["device"] == "cpu" else "cuda:0")
    for epoch in range(saved_epoch, config["num_epochs"]):

        step = 0
        while step < num_batches:
            model.train()

            # Back-translation data
            if cycle_iterate_dl_back is not None:
                sentences_x, sentences_y = next(cycle_iterate_dl_back)
                loss = bilingual_step(model, sentences_x, sentences_y,
                                      train_fn, optimizer, vocab_src,
                                      vocab_tgt, config, step, device)

                print(
                    "Epoch: {:03d}/{:03d}, Batch {:05d}/{:05d}, Back-Loss: {:.2f}"
                    .format(epoch + 1, config["num_epochs"], step + 1,
                            num_batches, loss))
                # step += 1

            # Bilingual data
            sentences_x, sentences_y = next(cycle_iterate_dl_xy)
            loss = bilingual_step(model, sentences_x, sentences_y, train_fn,
                                  optimizer, vocab_src, vocab_tgt, config,
                                  step, device)

            print("Epoch: {:03d}/{:03d}, Batch {:05d}/{:05d}, xy-Loss: {:.2f}".
                  format(epoch + 1, config["num_epochs"], step + 1,
                         num_batches, loss))
            step += 1

        # Validate once per epoch, after the batch loop.
        val_bleu = evaluate(model, validate_fn, dev_data, vocab_src,
                            vocab_tgt, epoch, config)
        scheduler.step(float(val_bleu))

        print("Blue score: {}".format(val_bleu))
        if float(val_bleu) > max_bleu:
            max_bleu = float(val_bleu)
            patience_counter = 0

            # Save checkpoint
            if not os.path.exists(checkpoints_path):
                os.makedirs(checkpoints_path)
            state = {
                'epoch': epoch + 1,
                'patience_counter': patience_counter,
                'max_bleu': max_bleu,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict()
            }
            torch.save(state, '{}/{}'.format(checkpoints_path,
                                             config["session"]))
        else:
            patience_counter += 1
            if patience_counter >= config["patience"]:
                break