Example #1
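
# Imports this example needs (hedged): os and torch are standard; the GPT-2
# classes are assumed to come from the HuggingFace transformers package. The
# project-internal helpers (Vocab, load_checkpoint, nmt_dataloaders, the
# Seq2Seq* models, the callbacks and NmtPriorTrainer) live in the repository's
# own modules, whose import paths this snippet does not show.
import os

import torch.nn as nn
from transformers import GPT2LMHeadModel, GPT2Tokenizer
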
def run(config):
    # -------------------------------------------------------------------------
    # Load pretrained models
    # -------------------------------------------------------------------------
    vocab_src = None
    vocab_trg = None

    # Load the pretrained LM, which will be used for LM fusion or as the LM prior
    if config["data"]["prior_path"] is not None:
        if "gpt2" in config["data"]["prior_path"]:
            _gpt_model = os.path.split(config["data"]["prior_path"])[1]
            tokenizer = GPT2Tokenizer.from_pretrained(_gpt_model)
            vocab_trg = Vocab()
            vocab_trg.from_gpt2(tokenizer)
            _checkp_prior = GPT2LMHeadModel.from_pretrained(_gpt_model)
            config["model"]["dec_padding_idx"] = None
        else:
            _checkp_prior = load_checkpoint(config["data"]["prior_path"])
            vocab_trg = _checkp_prior["vocab"]

            if _checkp_prior["config"]["data"]["subword_path"] is not None:
                sub_path = _checkp_prior["config"]["data"]["subword_path"]
                config["data"]["trg"]["subword_path"] = sub_path

    # -------------------------------------------------------------------------
    # Data Loading and Preprocessing
    # -------------------------------------------------------------------------
    train_loader, val_loader = nmt_dataloaders(config, vocab_src, vocab_trg)

    # -------------------------------------------------------------------------
    # Initialize Model and Priors
    # -------------------------------------------------------------------------
    model_type = config["model"].get("type", "rnn")
    src_ntokens = len(val_loader.dataset.src.vocab)
    trg_ntokens = len(val_loader.dataset.trg.vocab)

    # Initialize Model
    if model_type == "rnn":
        model = Seq2SeqRNN(src_ntokens, trg_ntokens, **config["model"])
    elif model_type == "transformer":
        model = Seq2SeqTransformer(src_ntokens, trg_ntokens, **config["model"])
    else:
        raise NotImplementedError

    model_init(model, **config.get("init", {}))

    # Initialize prior LM
    _has_lm_prior = "prior" in config["losses"]
    _has_lm_fusion = config["model"]["decoding"].get("fusion") is not None
    if _has_lm_prior or _has_lm_fusion:
        # assumes prior_path was set above, so _checkp_prior is defined
        if "gpt2" in config["data"]["prior_path"]:
            prior = _checkp_prior
            prior.to(config["device"])
            freeze_module(prior)
            # zero out dropout so the frozen prior scores deterministically
            for name, module in prior.named_modules():
                if isinstance(module, nn.Dropout):
                    module.p = 0
        else:
            prior = prior_model_from_checkpoint(_checkp_prior)
            prior.to(config["device"])
            freeze_module(prior)
    else:
        prior = None

    # tie the decoder's input embedding and output projection weights
    model.tie_weights()

    # -------------------------------------------------------------------------
    # Training Pipeline
    # -------------------------------------------------------------------------
    callbacks = [
        LossCallback(config["logging"]["log_interval"]),
        GradientCallback(config["logging"]["log_interval"]),
        ModuleGradientCallback(["encoder"], config["logging"]["log_interval"]),
        SamplesCallback(config["logging"]["log_interval"]),
        EvalCallback(config["logging"]["eval_interval"],
                     keep_best=True,
                     early_stop=config["optim"]["early_stop"])
    ]
    if model_type == "rnn":
        callbacks.append(AttentionCallback(config["logging"]["eval_interval"]))

    eval_interval = config["logging"]["eval_interval"]
    full_eval_interval = config["logging"].get("full_eval_interval",
                                               15 * eval_interval)
    callbacks.append(FunctionCallback(eval_best, full_eval_interval))

    trainer = NmtPriorTrainer(model,
                              train_loader,
                              val_loader,
                              config,
                              config["device"],
                              prior=prior,
                              callbacks=callbacks,
                              src_dirs=config["src_dirs"],
                              resume_state_id=config["resume_state_id"])

    if trainer.exp.has_finished():
        return trainer

    # -------------------------------------------------------------------------
    # Training Loop
    # -------------------------------------------------------------------------
    for epoch in range(config["epochs"]):
        train_loss = trainer.train_epoch()
        val_loss = trainer.eval_epoch()
        print("\n" * 3)

        if trainer.early_stop:
            print("Stopping early ...")
            break

    trainer.exp.finalize()
    return trainer
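
# A hedged sketch of a config that run() could consume. The keys mirror the
# lookups inside run() above; every value is an illustrative placeholder, not
# a setting taken from the repository.
config = {
    "device": "cuda",
    "epochs": 50,
    "batch_tokens": 4000,
    "src_dirs": [],
    "resume_state_id": None,
    "data": {"prior_path": None, "src": {}, "trg": {}},
    "model": {"type": "transformer", "decoding": {"fusion": None}},
    "losses": {"mle": 1.0},  # add a "prior" entry to enable the LM-prior loss
    "optim": {"early_stop": 10},
    "logging": {"log_interval": 100, "eval_interval": 1000},
    "init": {},
}
trainer = run(config)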
Example #2
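
# Imports assumed by this example: DataLoader comes from PyTorch; the
# project-internal helpers (load_checkpoint, fix_paths, the Seq2Seq* models,
# SequenceDataset, BucketTokensSampler, LMCollate, prior_model_from_checkpoint
# and the seq2seq_* decoding utilities) come from the repository's own modules.
from torch.utils.data import DataLoader
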
def seq2seq_translate(checkpoint, src_file, out_file, beam_size,
                      length_penalty, lm, fusion, fusion_a, batch_tokens,
                      device):
    # --------------------------------------
    # load checkpoint
    # --------------------------------------
    if isinstance(checkpoint, str):
        cp = load_checkpoint(checkpoint)
    else:
        cp = checkpoint
    src_vocab, trg_vocab = cp["vocab"]

    # --------------------------------------
    # load model
    # --------------------------------------
    model_type = cp["config"]["model"].get("type", "rnn")
    src_ntokens = len(src_vocab)
    trg_ntokens = len(trg_vocab)

    if model_type == "rnn":
        model = Seq2SeqRNN(src_ntokens, trg_ntokens, **cp["config"]["model"])
    elif model_type == "transformer":
        model = Seq2SeqTransformer(src_ntokens, trg_ntokens,
                                   **cp["config"]["model"])
    else:
        raise NotImplementedError

    model.load_state_dict(cp["model"])
    model.to(device)
    model.eval()

    # --------------------------------------
    # load prior
    # --------------------------------------
    if lm is not None:
        lm_cp = load_checkpoint(lm)
    elif fusion:
        lm_cp = load_checkpoint(
            fix_paths(cp["config"]["data"]["prior_path"], "checkpoints"))
    else:
        lm_cp = None

    if lm_cp is not None:
        lm = prior_model_from_checkpoint(lm_cp)
        lm.to(device)
        lm.eval()
    else:
        lm = None

    # Build the test set by merging the checkpoint's global data config with
    # its src-side config, with subsampling disabled (later keys override).
    test_set = SequenceDataset(src_file,
                               vocab=src_vocab,
                               **{**cp["config"]["data"],
                                  "subsample": 0,
                                  **cp["config"]["data"]["src"]})
    print(test_set)

    if batch_tokens is None:
        batch_tokens = cp["config"]["batch_tokens"]

    # bucket batches by token count; lengths are doubled, presumably to budget
    # for the decoded target side as well
    sampler = BucketTokensSampler(test_set.lengths * 2, batch_tokens)
    data_loader = DataLoader(test_set,
                             num_workers=cp["config"].get("cores", 4),
                             pin_memory=True,
                             batch_sampler=sampler,
                             collate_fn=LMCollate())

    # translate the data
    output_ids = seq2seq_translate_ids(model,
                                       data_loader,
                                       trg_vocab,
                                       beam_size=beam_size,
                                       length_penalty=length_penalty,
                                       lm=lm,
                                       fusion=fusion,
                                       fusion_a=fusion_a)

    # restore the original sentence order, which the bucket sampler permuted
    output_ids = output_ids[data_loader.batch_sampler.reverse_ids]
    seq2seq_output_ids_to_file(output_ids, trg_vocab, out_file)
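
# Hypothetical invocation (hedged): translate de.txt with beam search and no
# LM fusion. The checkpoint and file names are illustrative, borrowed from
# Examples #3 and #4; batch_tokens=None falls back to the checkpoint's own
# setting.
seq2seq_translate(checkpoint="final.trans.deen_base_best.pt",
                  src_file="de.txt",
                  out_file="hyps.de-en.txt",
                  beam_size=5,
                  length_penalty=1.0,
                  lm=None,
                  fusion=None,
                  fusion_a=None,
                  batch_tokens=None,
                  device="cuda")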
Example #3


device = "cuda"
# --------------------------------------
# load model
# --------------------------------------
# _load_model is a helper defined earlier in the original script (elided from
# this snippet); judging by the unpacking below, it returns
# (model, src_vocab, trg_vocab, config).
model_prior, _, _, _ = _load_model("final.trans.deen_prior_3M_kl_best.pt")
model_postnorm, _, _, _ = _load_model("final.trans.deen_postnorm_best.pt")
model_base, src_vocab, trg_vocab, cnf = _load_model(
    "final.trans.deen_base_best.pt")
# --------------------------------------
# load lm
# --------------------------------------
lm_cp = "../checkpoints/prior.lm_news_en_30M_trans_best.pt"
lm_cp = load_checkpoint(lm_cp)
lm = prior_model_from_checkpoint(lm_cp)
lm.to(device)
lm.eval()

# --------------------------------------
# dataset
# --------------------------------------
src_path = "de.txt"
trg_path = "en.txt"
val_src = SequenceDataset(src_path,
                          vocab=src_vocab,
                          **{**cnf["data"], **cnf["data"]["src"]})
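
# Hedged follow-up mirroring Example #2: wrap the dataset in a token-bucketed
# DataLoader before decoding. Reading batch_tokens and cores from cnf is an
# assumption about the loaded config's contents.
sampler = BucketTokensSampler(val_src.lengths * 2, cnf["batch_tokens"])
data_loader = DataLoader(val_src,
                         num_workers=cnf.get("cores", 4),
                         pin_memory=True,
                         batch_sampler=sampler,
                         collate_fn=LMCollate())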
Example #4
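
# Imports assumed by this example; the project-internal helpers
# (load_checkpoint, fix_paths, compute_bleu_score, seq2seq_translate) are
# taken to come from the repository's own modules.
import json
import os
from typing import Tuple

import pandas
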
def eval_nmt_checkpoint(checkpoint,
                        device,
                        beams=None,
                        lm=None,
                        fusion_a=None,
                        results=None,
                        results_low=None):
    if beams is None:
        beams = [1, 5, 10]

    _base, _file = os.path.split(checkpoint)
    cp = load_checkpoint(checkpoint)

    def score(dataset, beam_size) -> Tuple[float, float]:
        hyp_file = os.path.join(_base, f"hyps_{dataset}_beam-{beam_size}.txt")
        src_file = cp["config"]["data"]["src"][f"{dataset}_path"]
        ref_file = cp["config"]["data"]["trg"][f"{dataset}_path"]

        src_file = fix_paths(src_file, "datasets")
        ref_file = fix_paths(ref_file, "datasets")

        fusion = cp["config"]["model"]["decoding"].get("fusion")
        batch_tokens = max(10000 // beam_size, 1000)

        # default to shallow fusion when an external LM and a fusion weight
        # are given but the checkpoint itself specifies no fusion method
        if fusion is None and lm is not None and fusion_a is not None:
            fusion = "shallow"

        seq2seq_translate(checkpoint=cp,
                          src_file=src_file,
                          out_file=hyp_file,
                          beam_size=beam_size,
                          length_penalty=1,
                          lm=lm,
                          fusion=fusion,
                          fusion_a=fusion_a,
                          batch_tokens=batch_tokens,
                          device=device)
        _mixed = compute_bleu_score(hyp_file, ref_file)
        _lower = compute_bleu_score(hyp_file, ref_file, True)  # lowercased BLEU
        return _mixed, _lower

    if results is None:
        results = {d: {k: None for k in beams} for d in ["val", "test"]}
    if results_low is None:
        results_low = {d: {k: None for k in beams} for d in ["val", "test"]}

    for d in ["val", "test"]:
        for k in beams:
            try:
                mixed, lower = score(d, k)
                results[d][k] = mixed
                results_low[d][k] = lower
            except Exception as e:
                print(e)
                results[d][k] = None
                results_low[d][k] = None

    text = pandas.DataFrame.from_dict(results).to_string()
    name = "BLEU"
    if fusion_a is not None:
        name += f"_shallow_{fusion_a}_{lm.split('.')[-2]}"
    with open(os.path.join(_base, f"{name}.txt"), "w") as f:
        f.write(text)
    with open(os.path.join(_base, f"{name}.json"), "w") as f:
        json.dump(results, f, indent=4)
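
# Hypothetical invocation (hedged): the checkpoint path reuses the file name
# from Example #3 and is illustrative only.
eval_nmt_checkpoint("checkpoints/final.trans.deen_base_best.pt",
                    device="cuda",
                    beams=[1, 5])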