def batch_backward_end(self, t, batch, epoch_losses, batch_losses,
                       batch_outputs):
    """Periodically render validation translations as an HTML sample page.

    Every ``self.interval`` steps, decode the current batch's source,
    hypothesis (greedy argmax over decoder logits) and target sequences,
    format them as HTML rows and both log them to the experiment tracker
    and write them to ``samples.html`` in the experiment output dir.

    Args:
        t: trainer object exposing ``step``, ``valid_loader`` and ``exp``.
        batch: batch tuple; index 1 is assumed to hold source ids and
            index 4 target ids — TODO confirm against the dataset collate.
        epoch_losses, batch_losses: unused here (callback signature).
        batch_outputs: dict with ``'model_outputs'`` = (encoder, decoder).
    """
    # Only emit samples every `interval` steps.
    if t.step % self.interval != 0:
        return

    enc, dec = batch_outputs['model_outputs']
    src_vocab = t.valid_loader.dataset.src.vocab
    trg_vocab = t.valid_loader.dataset.trg.vocab

    def row(x, z, y):
        # One HTML paragraph per (input, hypothesis, target) triple.
        return f"<p>INP: {self.norm(x)} <br/>" \
               f"HYP: {self.norm(z)} <br/>" \
               f"TRG: {self.norm(y)}</p>"

    def devec(tensor, v):
        # Renamed lambda param (was `t`, shadowing the trainer argument).
        # Trailing True presumably strips the EOS token — verify against
        # devectorize's signature.
        return devectorize(tensor.tolist(), v.id2tok, v.tok2id[v.EOS], True)

    src = devec(batch[1], src_vocab)
    # Greedy decoding: argmax over the vocabulary dimension.
    hyp = devec(dec["logits"].max(dim=2)[1], trg_vocab)
    trg = devec(batch[4], trg_vocab)

    src = [src_vocab.detokenize(x) for x in src]
    hyp = [trg_vocab.detokenize(x) for x in hyp]
    trg = [trg_vocab.detokenize(x) for x in trg]

    samples = [row(x, z, y) for x, z, y in zip(src, hyp, trg)]
    html_samples = f'<style> body, p {{ font-family: "Dejavu Sans Mono", ' \
                   f'serif; font-size: 12px; }}</style>{"".join(samples)}'

    t.exp.text("samples", html_samples, "Samples")
    with open(os.path.join(t.exp.output_dir, "samples.html"), "w") as f:
        f.write(html_samples)
def _batch_backward_end(self, t, batch, epoch_losses, batch_losses,
                        batch_outputs):
    """Periodically render tagged token-level samples via ``samples2html``.

    Every ``self.interval`` steps, decode source/hypothesis/target token
    lists and build per-sentence sample rows (SRC/HYP/[TAU]/TRG dicts)
    for HTML rendering. When the decoder learns a temperature ("tau"),
    a TAU row with per-token opacities is inserted between HYP and TRG.

    Args:
        t: trainer object exposing ``step``, ``valid_loader``, ``config``
            and ``exp``.
        batch: batch tuple; index 1 holds source ids, index 4 target ids
            — TODO confirm against the dataset collate.
        epoch_losses, batch_losses: unused here (callback signature).
        batch_outputs: dict with ``'model_outputs'`` = (encoder, decoder).
    """
    # Only emit samples every `interval` steps.
    if t.step % self.interval != 0:
        return

    enc, dec = batch_outputs['model_outputs']
    src_vocab = t.valid_loader.dataset.src.vocab
    trg_vocab = t.valid_loader.dataset.trg.vocab

    def devec(tensor, v):
        # Renamed lambda param (was `t`, shadowing the trainer argument).
        return devectorize(tensor.tolist(), v.id2tok, v.tok2id[v.EOS], True)

    src = devec(batch[1], src_vocab)
    hyp = devec(dec["logits"].max(dim=2)[1], trg_vocab)
    trg = devec(batch[4], trg_vocab)

    if t.config["model"]["decoder"]["learn_tau"]:
        # Round the learned temperatures for display.
        tau = numpy.around(numpy.array(dec["tau"].tolist()), 2)
        t0 = t.config["model"]["decoder"]["tau_0"]
        # NOTE(review): log(1 + exp(-100)) is ~3.7e-44, so tau_bound is
        # effectively round(1 / t0, 2) — confirm the -100 constant is
        # intentional and not a stand-in for a learned parameter.
        tau_bound = round(1 / (math.log(1 + math.exp(-100)) + t0), 2)
        # Opacity in [0, 1]-ish, proportional to tau relative to its bound.
        tau_opacity = (tau / tau_bound).tolist()
        tau = tau.tolist()
    else:
        tau = None

    samples = []
    for i in range(len(src)):
        sample = [
            {"tag": "SRC", "tokens": src[i], "color": "0, 0, 0"},
            {"tag": "HYP", "tokens": hyp[i], "color": "0, 0, 0"},
        ]
        if t.config["model"]["decoder"]["learn_tau"]:
            # Insert the TAU row between HYP and TRG.
            sample.insert(2, {
                "tag": "TAU",
                "tokens": list(map(str, tau[i])),
                "opacity": tau_opacity[i],
                "color": "255, 100, 100",
                "normalize": False
            })
        sample.append({"tag": "TRG", "tokens": trg[i], "color": "0, 0, 0"})
        samples.append(sample)

    html_samples = samples2html(samples)
    t.exp.text("samples", html_samples, "Samples", pre=False)
def seq2seq_output_ids_to_file(output_ids, trg_vocab, out_file):
    """Devectorize and detokenize translated token ids and write the
    translations to a text file, one sentence per line.

    Args:
        output_ids: tensor of target-side token ids (has ``.tolist()``).
        trg_vocab: vocabulary exposing ``id2tok``, ``EOS_id`` and
            ``detokenize``.
        out_file: path of the text file to (over)write.
    """
    output_tokens = devectorize(output_ids.tolist(),
                                trg_vocab.id2tok,
                                trg_vocab.EOS_id,
                                strip_eos=True,
                                pp=True)
    with open(out_file, "w") as fo:
        for sent in output_tokens:
            text = trg_vocab.detokenize(sent)
            fo.write(text + "\n")
def batch_backward_end(self, t, batch, epoch_losses, batch_losses,
                       batch_outputs):
    """Periodically plot decoder attention for a few validation samples.

    Every ``self.interval`` steps, decode the first 5 source sentences
    and greedy hypotheses of the batch and write their attention maps to
    ``attentions.pdf`` in the experiment output dir via
    ``seq2seq_attentions``.

    Args:
        t: trainer object exposing ``step``, ``valid_loader`` and ``exp``.
        batch: batch tuple; index 1 is assumed to hold source ids —
            TODO confirm against the dataset collate.
        epoch_losses, batch_losses: unused here (callback signature).
        batch_outputs: dict with ``'model_outputs'`` = (encoder, decoder).
    """
    # Only plot every `interval` steps.
    if t.step % self.interval != 0:
        return

    enc, dec = batch_outputs['model_outputs']
    src_vocab = t.valid_loader.dataset.src.vocab
    trg_vocab = t.valid_loader.dataset.trg.vocab

    def devec(tensor, v):
        # Renamed lambda param (was `t`, shadowing the trainer argument).
        return devectorize(tensor.tolist(), v.id2tok, v.tok2id[v.EOS])

    src = devec(batch[1], src_vocab)
    # Greedy decoding: argmax over the vocabulary dimension.
    hyp = devec(dec["logits"].max(dim=2)[1], trg_vocab)

    file = os.path.join(t.exp.output_dir, "attentions.pdf")
    # Limit to the first 5 samples to keep the plot readable.
    seq2seq_attentions(src[:5], hyp[:5], dec["attention"][:5].tolist(), file)
def tok2text(self, inputs, preds, vocab):
    """Convert token-id sequences into an HTML string of samples.

    With ``inputs`` set to None only the predictions are rendered, one
    ``<p>`` per sentence; otherwise each row pairs the input (INP) with
    its reconstruction (REC). Token ids are devectorized with EOS
    stripped, detokenized (BPE-aware) and HTML-unescaped.

    Args:
        inputs: tensor of input ids, or None for predictions-only output.
        preds: list of predicted id sequences — presumably already a
            plain list, since it is not ``.tolist()``-ed; verify caller.
        vocab: vocabulary exposing ``id2tok``, ``tok2id``, ``EOS`` and
            ``subword``.

    Returns:
        A single concatenated HTML string.
    """
    def _t2i(x):
        return devectorize(x, vocab.id2tok, vocab.tok2id[vocab.EOS],
                           strip_eos=True)

    # BPE detokenization is needed only when a subword model is attached.
    bpe = vocab.subword is not None

    if inputs is None:
        def row(x):
            return f"<p>{self.norm(x)}</p>"

        samples = [row(html.unescape(detokenize(y, bpe)))
                   for y in _t2i(preds)]
    else:
        def row(x, y):
            return f"<p>INP: {self.norm(x)} <br/>" \
                   f"REC: {self.norm(y)}</p>"

        # x[1:] drops the leading token — presumably SOS; confirm.
        src = [html.unescape(detokenize(x[1:], bpe))
               for x in _t2i(inputs.tolist())]
        y_toks = [html.unescape(detokenize(x, bpe)) for x in _t2i(preds)]
        samples = [row(x, y) for x, y in zip(src, y_toks)]

    return ''.join(samples)
def _batch_forward(batch):
    """Run one batch through prior / shallow-fusion / postnorm-fusion
    models and yield rows for the samples where fusion "failed".

    For each sentence (target length <= 20) a dict is built with the
    source/target tokens plus, per model variant, the greedy tokens and
    per-step output distributions of the fused, translation-model-only
    and language-model-only heads. A row is yielded only if
    ``_is_failed`` flags at least one step for the shallow or postnorm
    variant.

    Args:
        batch: tuple of tensors
            (x_sos, x_eos, x_len, y_sos, y_eos, y_len).

    Yields:
        dict per failed sample — see the `row` literal for keys.
    """
    batch = list(map(lambda x: x.to(device), batch))
    x_sos, x_eos, x_len, y_sos, y_eos, y_len = batch

    # prior: base model outputs plus raw LM logits.
    _, dec_prior = model_prior(x_eos, y_sos, x_len, y_len)
    dec_prior["lm"] = lm(y_sos, y_len)["logits"]

    # shallow fusion with a fixed interpolation weight of 0.1.
    _, dec_shallow = model_base(
        x_eos, y_sos, x_len, y_len, **{
            "fusion": "shallow",
            "fusion_a": 0.1,
            "lm": lm
        })

    # postnorm fusion.
    _, dec_postnorm = model_postnorm(x_eos, y_sos, x_len, y_len, **{
        "fusion": "postnorm",
        "lm": lm
    })

    # --------------------------------------------------------------------
    _inputs = devectorize(x_eos.tolist(), src_vocab.id2tok,
                          src_vocab.EOS_id, strip_eos=True)
    _targets = devectorize(y_eos.tolist(), trg_vocab.id2tok,
                           trg_vocab.EOS_id, strip_eos=True)

    # --------------------------------------------------------------------
    # prior
    # --------------------------------------------------------------------
    _prior_ids = dec_prior["logits"].max(2)[1].tolist()
    _prior_lm_ids = dec_prior["lm"].max(2)[1].tolist()
    _prior_tokens = devectorize(_prior_ids, trg_vocab.id2tok)
    _prior_lm_tokens = devectorize(_prior_lm_ids, trg_vocab.id2tok)

    # --------------------------------------------------------------------
    # shallow
    # --------------------------------------------------------------------
    _shallow_ids = dec_shallow["logits"].max(2)[1].tolist()
    _shallow_tm_ids = dec_shallow["dec"].max(2)[1].tolist()
    _shallow_lm_ids = dec_shallow["lm"].max(2)[1].tolist()
    _shallow_fails = [
        _is_failed(y_eos[i], _shallow_ids[i], _shallow_lm_ids[i],
                   _shallow_tm_ids[i]) for i in range(x_eos.size(0))
    ]
    _shallow_tokens = devectorize(_shallow_ids, trg_vocab.id2tok)
    _shallow_tm_tokens = devectorize(_shallow_tm_ids, trg_vocab.id2tok)
    _shallow_lm_tokens = devectorize(_shallow_lm_ids, trg_vocab.id2tok)

    # --------------------------------------------------------------------
    # postnorm
    # --------------------------------------------------------------------
    _postnorm_ids = dec_postnorm["logits"].max(2)[1].tolist()
    _postnorm_tm_ids = dec_postnorm["dec"].max(2)[1].tolist()
    _postnorm_lm_ids = dec_postnorm["lm"].max(2)[1].tolist()
    _postnorm_fails = [
        _is_failed(y_eos[i], _postnorm_ids[i], _postnorm_lm_ids[i],
                   _postnorm_tm_ids[i]) for i in range(x_eos.size(0))
    ]
    _postnorm_tokens = devectorize(_postnorm_ids, trg_vocab.id2tok)
    _postnorm_tm_tokens = devectorize(_postnorm_tm_ids, trg_vocab.id2tok)
    _postnorm_lm_tokens = devectorize(_postnorm_lm_ids, trg_vocab.id2tok)

    # --------------------------------------------------------------------
    for i in range(x_eos.size(0)):
        # Skip long sentences to keep the visualization manageable.
        if y_len[i].item() > 20:
            continue
        row = {
            "source": _inputs[i][:x_len[i]],
            "target": _targets[i][:y_len[i]],

            "prior_toks": _prior_tokens[i][:y_len[i]],
            "prior_toks_lm": _prior_lm_tokens[i][:y_len[i]],
            "prior_dist":
                _logits2dist(dec_prior["logits"][i], trg_vocab)[:y_len[i]],
            "prior_dist_lm":
                _logits2dist(dec_prior["lm"][i], trg_vocab)[:y_len[i]],

            "postnorm_toks": _postnorm_tokens[i][:y_len[i]],
            "postnorm_toks_lm": _postnorm_lm_tokens[i][:y_len[i]],
            "postnorm_toks_tm": _postnorm_tm_tokens[i][:y_len[i]],
            "postnorm_dist":
                _logits2dist(dec_postnorm["logits"][i],
                             trg_vocab)[:y_len[i]],
            "postnorm_dist_tm":
                _logits2dist(dec_postnorm["dec"][i], trg_vocab)[:y_len[i]],
            "postnorm_dist_lm":
                _logits2dist(dec_postnorm["lm"][i], trg_vocab)[:y_len[i]],
            "postnorm_fails": _postnorm_fails[i],

            "shallow_toks": _shallow_tokens[i][:y_len[i]],
            "shallow_toks_lm": _shallow_lm_tokens[i][:y_len[i]],
            "shallow_toks_tm": _shallow_tm_tokens[i][:y_len[i]],
            "shallow_dist":
                _logits2dist(dec_shallow["logits"][i],
                             trg_vocab)[:y_len[i]],
            "shallow_dist_tm":
                _logits2dist(dec_shallow["dec"][i], trg_vocab)[:y_len[i]],
            "shallow_dist_lm":
                _logits2dist(dec_shallow["lm"][i], trg_vocab)[:y_len[i]],
            "shallow_fails": _shallow_fails[i],
        }
        # Keep only samples where at least one fusion variant failed.
        if any(_postnorm_fails[i]) or any(_shallow_fails[i]):
            yield row

    # Drop tensor references promptly to release (GPU) memory.
    del batch