# NOTE: the method snippets below assume their enclosing modules import
# os, math, html, and numpy, and that the project helpers (devectorize,
# detokenize, samples2html, seq2seq_attentions) are in scope.

    def batch_backward_end(self, t, batch, epoch_losses, batch_losses,
                           batch_outputs):

        # run only every `self.interval` training steps
        if t.step % self.interval != 0:
            return

        enc, dec = batch_outputs['model_outputs']
        src_vocab = t.valid_loader.dataset.src.vocab
        trg_vocab = t.valid_loader.dataset.trg.vocab
        row = lambda x, z, y: f"<p>INP: {self.norm(x)} <br/>" \
                              f"HYP: {self.norm(z)} <br/>" \
                              f"TRG: {self.norm(y)}</p>"

        devec = lambda ids, v: devectorize(ids.tolist(), v.id2tok,
                                           v.tok2id[v.EOS], strip_eos=True)
        src = devec(batch[1], src_vocab)
        hyp = devec(dec["logits"].max(dim=2)[1], trg_vocab)
        trg = devec(batch[4], trg_vocab)

        src = [src_vocab.detokenize(x) for x in src]
        hyp = [trg_vocab.detokenize(x) for x in hyp]
        trg = [trg_vocab.detokenize(x) for x in trg]

        samples = [row(x, z, y) for x, z, y in zip(src, hyp, trg)]

        html_samples = f'<style> body, p {{ font-family: "Dejavu Sans Mono", ' \
                       f'serif; font-size: 12px; }}</style>{"".join(samples)}'

        t.exp.text("samples", html_samples, "Samples")

        with open(os.path.join(t.exp.output_dir, "samples.html"), "w") as f:
            f.write(html_samples)
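
# `devectorize` is a project helper, not defined in these snippets. A
# minimal sketch of the assumed behavior (ids -> tokens, optionally cut at
# the first EOS); the real implementation may differ, e.g. in the extra
# post-processing implied by its `pp` parameter:
def devectorize_sketch(id_lists, id2tok, eos_id, strip_eos=True):
    """Map lists of token ids back to tokens, truncating at the first EOS."""
    out = []
    for ids in id_lists:
        if strip_eos and eos_id in ids:
            ids = ids[:ids.index(eos_id)]
        out.append([id2tok[i] for i in ids])
    return out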

    def _batch_backward_end(self, t, batch, epoch_losses, batch_losses,
                            batch_outputs):

        # run only every `self.interval` training steps
        if t.step % self.interval != 0:
            return

        enc, dec = batch_outputs['model_outputs']
        src_vocab = t.valid_loader.dataset.src.vocab
        trg_vocab = t.valid_loader.dataset.trg.vocab

        devec = lambda ids, v: devectorize(ids.tolist(), v.id2tok,
                                           v.tok2id[v.EOS], strip_eos=True)
        src = devec(batch[1], src_vocab)
        hyp = devec(dec["logits"].max(dim=2)[1], trg_vocab)
        trg = devec(batch[4], trg_vocab)

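        # Presumably tau is computed as 1 / (softplus(w) + tau_0), so its
        # upper bound is approached as softplus(w) -> 0; softplus(-100)
        # underflows to 0, making tau_bound effectively round(1 / tau_0, 2).
        # tau / tau_bound then gives a per-step opacity in [0, 1].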
        if t.config["model"]["decoder"]["learn_tau"]:
            tau = numpy.around(numpy.array(dec["tau"].tolist()), 2)
            t0 = t.config["model"]["decoder"]["tau_0"]
            tau_bound = round(1 / (math.log(1 + math.exp(-100)) + t0), 2)
            tau_opacity = (tau / tau_bound).tolist()
            tau = tau.tolist()
        else:
            tau = None

        samples = []
        for i in range(len(src)):
            sample = []

            _src = {"tag": "SRC", "tokens": src[i], "color": "0, 0, 0"}
            sample.append(_src)

            _hyp = {"tag": "HYP", "tokens": hyp[i], "color": "0, 0, 0"}
            sample.append(_hyp)

            if tau is not None:
                _tau = {
                    "tag": "TAU",
                    "tokens": list(map(str, tau[i])),
                    "opacity": tau_opacity[i],
                    "color": "255, 100, 100",
                    "normalize": False
                }
                sample.insert(2, _tau)

            _trg = {"tag": "TRG", "tokens": trg[i], "color": "0, 0, 0"}
            sample.append(_trg)

            samples.append(sample)

        html_samples = samples2html(samples)
        t.exp.text("samples", html_samples, "Samples", pre=False)
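
# `samples2html` is a project helper; a minimal sketch of the interface the
# sample dicts above imply (a tag plus per-token spans, with an optional
# per-token opacity used for the TAU row). The project's renderer is
# presumably richer (e.g. it also honors the "normalize" flag):
def samples2html_sketch(samples):
    rows = []
    for sample in samples:
        for part in sample:
            opacity = part.get("opacity", [1.0] * len(part["tokens"]))
            spans = "".join(
                f'<span style="background: rgba({part["color"]}, {o:.2f})">'
                f"{tok} </span>"
                for tok, o in zip(part["tokens"], opacity))
            rows.append(f'<p>{part["tag"]}: {spans}</p>')
        rows.append("<hr/>")
    return "".join(rows)
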
def seq2seq_output_ids_to_file(output_ids, trg_vocab, out_file):
    """
    Devectorize and Detokenize the translated token ids and write the
    translations to a text file
    """
    output_tokens = devectorize(output_ids.tolist(),
                                trg_vocab.id2tok,
                                trg_vocab.EOS_id,
                                strip_eos=True,
                                pp=True)

    with open(out_file, "w") as fo:
        for sent in output_tokens:
            text = trg_vocab.detokenize(sent)
            fo.write(text + "\n")
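
# Hypothetical usage, assuming greedy decoding over a decoder output dict
# shaped like the ones in the callbacks above (`dec["logits"]` being a
# [batch, length, vocab] tensor); names are illustrative only:
# output_ids = dec["logits"].max(dim=2)[1]
# seq2seq_output_ids_to_file(output_ids, trg_vocab, "hypotheses.txt")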

    def batch_backward_end(self, t, batch, epoch_losses, batch_losses,
                           batch_outputs):
        # run only every `self.interval` training steps
        if t.step % self.interval != 0:
            return

        enc, dec = batch_outputs['model_outputs']
        src_vocab = t.valid_loader.dataset.src.vocab
        trg_vocab = t.valid_loader.dataset.trg.vocab
        devec = lambda ids, v: devectorize(ids.tolist(), v.id2tok,
                                           v.tok2id[v.EOS])
        src = devec(batch[1], src_vocab)
        hyp = devec(dec["logits"].max(dim=2)[1], trg_vocab)

        file = os.path.join(t.exp.output_dir, "attentions.pdf")
        seq2seq_attentions(src[:5], hyp[:5], dec["attention"][:5].tolist(),
                           file)
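
# `seq2seq_attentions` is a project helper; a minimal sketch of what it is
# assumed to do (one attention heatmap per sentence pair, saved to a PDF),
# assuming each attention matrix has target steps as rows and source
# positions as columns:
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages

def seq2seq_attentions_sketch(srcs, hyps, attentions, path):
    with PdfPages(path) as pdf:
        for src_toks, hyp_toks, att in zip(srcs, hyps, attentions):
            fig, ax = plt.subplots()
            ax.imshow([row[:len(src_toks)] for row in att[:len(hyp_toks)]],
                      cmap="viridis", aspect="auto")
            ax.set_xticks(range(len(src_toks)))
            ax.set_xticklabels(src_toks, rotation=90, fontsize=6)
            ax.set_yticks(range(len(hyp_toks)))
            ax.set_yticklabels(hyp_toks, fontsize=6)
            pdf.savefig(fig)
            plt.close(fig)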

    def tok2text(self, inputs, preds, vocab):
        # ids -> tokens (cut at EOS), then tokens -> detokenized text
        _i2t = lambda x: devectorize(
            x, vocab.id2tok, vocab.tok2id[vocab.EOS], strip_eos=True)
        bpe = vocab.subword is not None
        if inputs is None:
            row = lambda x: f"<p>{self.norm(x)}</p>"
            samples = [
                row(html.unescape(detokenize(y, bpe))) for y in _i2t(preds)
            ]
        else:
            row = lambda x, y: f"<p>INP: {self.norm(x)} <br/>" \
                               f"REC: {self.norm(y)}</p>"
            # x[1:] presumably drops the leading SOS token
            src = [
                html.unescape(detokenize(x[1:], bpe))
                for x in _i2t(inputs.tolist())
            ]
            y_toks = [html.unescape(detokenize(x, bpe)) for x in _i2t(preds)]
            samples = [row(x, y) for x, y in zip(src, y_toks)]
        return ''.join(samples)
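
# `detokenize` is also project-specific; a minimal sketch assuming
# subword-nmt-style "@@ " BPE continuation markers (the project may use a
# different subword scheme, e.g. SentencePiece):
def detokenize_sketch(tokens, bpe=False):
    text = " ".join(tokens)
    if bpe:
        text = text.replace("@@ ", "").replace("@@", "")
    return text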

def _batch_forward(batch):
    # move all batch tensors to the target device
    batch = [x.to(device) for x in batch]
    x_sos, x_eos, x_len, y_sos, y_eos, y_len = batch

    # prior
    _, dec_prior = model_prior(x_eos, y_sos, x_len, y_len)
    dec_prior["lm"] = lm(y_sos, y_len)["logits"]

    # shallow
    _, dec_shallow = model_base(x_eos, y_sos, x_len, y_len,
                                fusion="shallow", fusion_a=0.1, lm=lm)

    # postnorm
    _, dec_postnorm = model_postnorm(x_eos, y_sos, x_len, y_len,
                                     fusion="postnorm", lm=lm)
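
    # Presumably: "shallow" fusion interpolates the decoder's log-probs with
    # the LM's (fusion_a scales the LM term), while "postnorm" multiplies the
    # two distributions and renormalizes; the prior model is run without
    # fusion and its LM logits are kept only for comparison.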

    # --------------------------------------------------------------------
    _inputs = devectorize(x_eos.tolist(),
                          src_vocab.id2tok,
                          src_vocab.EOS_id,
                          strip_eos=True)
    _targets = devectorize(y_eos.tolist(),
                           trg_vocab.id2tok,
                           trg_vocab.EOS_id,
                           strip_eos=True)

    # --------------------------------------------------------------------
    # prior
    # --------------------------------------------------------------------
    _prior_ids = dec_prior["logits"].max(2)[1].tolist()
    _prior_lm_ids = dec_prior["lm"].max(2)[1].tolist()

    _prior_tokens = devectorize(_prior_ids, trg_vocab.id2tok)
    _prior_lm_tokens = devectorize(_prior_lm_ids, trg_vocab.id2tok)

    # --------------------------------------------------------------------
    # shallow
    # --------------------------------------------------------------------
    _shallow_ids = dec_shallow["logits"].max(2)[1].tolist()
    _shallow_tm_ids = dec_shallow["dec"].max(2)[1].tolist()
    _shallow_lm_ids = dec_shallow["lm"].max(2)[1].tolist()

    _shallow_fails = [
        _is_failed(y_eos[i], _shallow_ids[i], _shallow_lm_ids[i],
                   _shallow_tm_ids[i]) for i in range(x_eos.size(0))
    ]

    _shallow_tokens = devectorize(_shallow_ids, trg_vocab.id2tok)
    _shallow_tm_tokens = devectorize(_shallow_tm_ids, trg_vocab.id2tok)
    _shallow_lm_tokens = devectorize(_shallow_lm_ids, trg_vocab.id2tok)

    # --------------------------------------------------------------------
    # postnorm
    # --------------------------------------------------------------------
    _postnorm_ids = dec_postnorm["logits"].max(2)[1].tolist()
    _postnorm_tm_ids = dec_postnorm["dec"].max(2)[1].tolist()
    _postnorm_lm_ids = dec_postnorm["lm"].max(2)[1].tolist()

    _postnorm_fails = [
        _is_failed(y_eos[i], _postnorm_ids[i], _postnorm_lm_ids[i],
                   _postnorm_tm_ids[i]) for i in range(x_eos.size(0))
    ]

    _postnorm_tokens = devectorize(_postnorm_ids, trg_vocab.id2tok)
    _postnorm_tm_tokens = devectorize(_postnorm_tm_ids, trg_vocab.id2tok)
    _postnorm_lm_tokens = devectorize(_postnorm_lm_ids, trg_vocab.id2tok)
    # --------------------------------------------------------------------

    for i in range(x_eos.size(0)):

        # keep only reasonably short target sequences
        if y_len[i].item() > 20:
            continue

        row = {
            "source": _inputs[i][:x_len[i]],
            "target": _targets[i][:y_len[i]],
            "prior_toks": _prior_tokens[i][:y_len[i]],
            "prior_toks_lm": _prior_lm_tokens[i][:y_len[i]],
            "prior_dist": _logits2dist(
                dec_prior["logits"][i], trg_vocab)[:y_len[i]],
            "prior_dist_lm": _logits2dist(
                dec_prior["lm"][i], trg_vocab)[:y_len[i]],
            "postnorm_toks": _postnorm_tokens[i][:y_len[i]],
            "postnorm_toks_lm": _postnorm_lm_tokens[i][:y_len[i]],
            "postnorm_toks_tm": _postnorm_tm_tokens[i][:y_len[i]],
            "postnorm_dist": _logits2dist(
                dec_postnorm["logits"][i], trg_vocab)[:y_len[i]],
            "postnorm_dist_tm": _logits2dist(
                dec_postnorm["dec"][i], trg_vocab)[:y_len[i]],
            "postnorm_dist_lm": _logits2dist(
                dec_postnorm["lm"][i], trg_vocab)[:y_len[i]],
            "postnorm_fails": _postnorm_fails[i],
            "shallow_toks": _shallow_tokens[i][:y_len[i]],
            "shallow_toks_lm": _shallow_lm_tokens[i][:y_len[i]],
            "shallow_toks_tm": _shallow_tm_tokens[i][:y_len[i]],
            "shallow_dist": _logits2dist(
                dec_shallow["logits"][i], trg_vocab)[:y_len[i]],
            "shallow_dist_tm": _logits2dist(
                dec_shallow["dec"][i], trg_vocab)[:y_len[i]],
            "shallow_dist_lm": _logits2dist(
                dec_shallow["lm"][i], trg_vocab)[:y_len[i]],
            "shallow_fails": _shallow_fails[i],
        }
        # keep only samples where at least one fusion method failed
        if any(_postnorm_fails[i]) or any(_shallow_fails[i]):
            yield row

    del batch
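
# Hypothetical driver: collect the failure cases `_batch_forward` yields
# over a validation loader and dump them to JSON. `val_loader` and the
# output path are illustrative names, not part of the original code.
import json

def collect_failures(val_loader, path="fusion_failures.json"):
    rows = []
    for batch in val_loader:
        rows.extend(_batch_forward(batch))
    with open(path, "w") as f:
        json.dump(rows, f, indent=2)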