def store_attention_plots(attentions, targets, sources, output_prefix, idx):
    """
    Save attention heat-map plots, one PDF per selected example.

    :param attentions: attention scores, indexable per example; each entry
        is a matrix that is transposed before plotting
    :param targets: list of tokenized target sentences (column labels)
    :param sources: list of tokenized source sentences (row labels)
    :param output_prefix: prefix for the output files
        ("<prefix>.<i>.pdf" per example)
    :param idx: indices of the examples to plot
    :return: None
    """
    for i in idx:
        plot_file = "{}.{}.pdf".format(output_prefix, i)
        src = sources[i]
        trg = targets[i]
        attention_scores = attentions[i].T
        try:
            plot_heatmap(scores=attention_scores, column_labels=trg,
                         row_labels=src, output_path=plot_file)
        # narrowed from bare `except:` so KeyboardInterrupt/SystemExit are
        # not swallowed; plotting stays best-effort (log and skip)
        except Exception:
            print("Couldn't plot example {}: src len {}, trg len {}, "
                  "attention scores shape {}".format(i, len(src), len(trg),
                                                     attention_scores.shape))
            continue
def store_attention_plots(attentions: np.array,
                          targets: List[List[str]],
                          sources: List[List[str]],
                          output_prefix: str,
                          indices: List[int],
                          tb_writer: Optional[SummaryWriter] = None,
                          steps: int = 0) -> None:
    """
    Save attention heat-map plots and optionally log them to Tensorboard.

    :param attentions: attention scores, one matrix per example
        (transposed before plotting)
    :param targets: list of tokenized targets (column labels)
    :param sources: list of tokenized sources (row labels)
    :param output_prefix: prefix for attention plot files
        ("<prefix>.<i>.pdf")
    :param indices: indices selected for plotting; out-of-range indices
        are silently skipped
    :param tb_writer: Tensorboard summary writer (optional)
    :param steps: current training steps, needed for tb_writer
    :return: None
    """
    # hoisted out of the loop: the original re-materialized `sources` on
    # every iteration, which is wasted work after the first pass
    sources = list(sources)
    for i in indices:
        if i >= len(sources):
            continue
        plot_file = "{}.{}.pdf".format(output_prefix, i)
        src = sources[i]
        trg = targets[i]
        attention_scores = attentions[i].T
        try:
            fig = plot_heatmap(scores=attention_scores, column_labels=trg,
                               row_labels=src, output_path=plot_file,
                               dpi=100)
            if tb_writer is not None:
                # re-render at lower resolution for tensorboard
                fig = plot_heatmap(scores=attention_scores,
                                   column_labels=trg, row_labels=src,
                                   output_path=None, dpi=50)
                tb_writer.add_figure("attention/{}.".format(i), fig,
                                     global_step=steps)
        # narrowed from bare `except:`; keep the best-effort behavior but
        # do not mask KeyboardInterrupt/SystemExit
        except Exception:
            print("Couldn't plot example {}: src len {}, trg len {}, "
                  "attention scores shape {}".format(i, len(src), len(trg),
                                                     attention_scores.shape))
            continue
def store_attention_plots(attentions: dict,
                          targets: List[List[str]],
                          sources: List[List[str]],
                          model_dir: str,
                          indices: List[int],
                          tb_writer: Optional[SummaryWriter] = None,
                          steps: int = 0,
                          data_set_name: str = "att") -> None:
    """
    Save one heat-map plot per selected example and per attention type.

    :param attentions: maps an attention name (e.g. "src_trg", "src_src",
        "trg_trg") to its scores, indexable per example
    :param targets: list of tokenized targets
    :param sources: list of tokenized sources
    :param model_dir: directory the plot files are written into
    :param indices: indices selected for plotting; all must be in range
    :param tb_writer: Tensorboard summary writer (optional)
    :param steps: current training steps, used in file names and tb_writer
    :param data_set_name: tag used in the output file names
    :raises AssertionError: if any index is out of range
    :return: None
    """
    assert all(i < len(sources) for i in indices)
    # hoisted: the file prefix only depends on the attention name, not on
    # the example index, so compute it once per attention type
    prefixes = {
        name: join(model_dir, "{}.{}.{}".format(data_set_name, name, steps))
        for name in attentions
    }
    for i in indices:
        for name, attn in attentions.items():
            output_prefix = prefixes[name]
            # self-attention labels both axes with its own sequence;
            # cross-attention uses targets (rows) x sources (columns)
            col_labels = targets if name == "trg_trg" else sources
            row_labels = sources if name == "src_src" else targets
            plot_file = "{}.{}.pdf".format(output_prefix, i)
            attn_matrix = attn[i]
            cols = col_labels[i]
            rows = row_labels[i]
            fig = plot_heatmap(scores=attn_matrix, column_labels=cols,
                               row_labels=rows, output_path=plot_file,
                               dpi=100)
            if tb_writer is not None:
                # lower resolution for tensorboard
                fig = plot_heatmap(scores=attn_matrix, column_labels=cols,
                                   row_labels=rows, output_path=None,
                                   dpi=50)
                # the tensorboard thing will need fixing
                tb_writer.add_figure("attention_{}/{}.".format(name, i),
                                     fig, global_step=steps)
def store_attention_plots(attentions: np.array,
                          targets: List[List[str]],
                          sources: List[List[str]],
                          output_prefix: str,
                          indices: List[int],
                          tb_writer: Optional[SummaryWriter] = None,
                          steps: int = 0,
                          kb_info: Tuple[List[int]] = None,
                          on_the_fly_info: Tuple = None) -> str:
    """
    Saves attention plots.

    Plots each selected example's (transposed) attention matrix as a
    heatmap PDF. When ``kb_info`` is given, the row labels are rebuilt
    from the knowledgebase entries belonging to that example instead of
    the source tokens.

    :param attentions: attention scores, indexable per example
    :param targets: list of tokenized targets
    :param sources: list of tokenized sources (holds the KB keys for the
        whole data set when kb_info is given)
    :param output_prefix: prefix for attention plots
    :param indices: indices selected for plotting
    :param tb_writer: Tensorboard summary writer (optional)
    :param steps: current training steps, needed for tb_writer
    :param kb_info: tuple of the valid set's kb_lkp, kb_lens, kb_truvals
    :param on_the_fly_info: tuple containing valid_data.src field,
        valid_kb, canonization function, model.trg_vocab
    :return: "<success>/<total>" summary string of plotted examples
    """
    success, failure = 0, 0
    for i in indices:
        # NOTE(review): shifts every index down by one and wraps negatives
        # using len(indices) rather than len(sources) -- looks off; confirm
        # the intended index convention with the caller.
        i -= 1
        if i < 0:
            i = len(indices) + i
        # NOTE(review): skipping here counts neither success nor failure,
        # which would trip the bookkeeping assert at the end of the loop.
        if i >= len(sources):
            continue
        plot_file = "{}.{}.pdf".format(output_prefix, i)
        attention_scores = attentions[i].T  # => KB x UNROLL
        print(
            f"PLOTTING: shape of {i}th attention matrix from print_valid_sents: {attention_scores.shape}"
        )
        trg = targets[i]
        if kb_info is None:
            # plain case: label rows with the example's source tokens
            src = sources[i]
        else:
            # KB case: `sources` holds the KB keys of the whole data set;
            # slice out this example's KB via lookup index + length info
            kbkey = sources
            kb_lkp, kb_lens, kbtrv = kb_info
            kbtrv_fields = kbtrv.fields  # needed for on the fly creation below
            kbtrv = list(kbtrv)
            print(f"KB PLOTTING: kb_lens: {kb_lens}")
            # index calculation (find batch in valid/test files using kb
            # lookup indices and length info)
            kb_num = kb_lkp[i]
            lower = sum(kb_lens[:kb_num])
            upper = lower + kb_lens[kb_num] + 1
            calcKbLen = upper - lower
            if calcKbLen == 1 and attention_scores.shape[0] > 1:
                # FIXME make this an option in the cfg
                # this is a scheduling KB created on the fly in data.batch_with_kb
                # TODO which fields are needed to recreate it on the fly here
                # valid_kb: has fields kbsrc, kbtrg; valid_kbtrv
                valid_src, valid_kb, canon_func, trg_vocab = on_the_fly_info
                v_src = list(valid_src)
                on_the_fly_kb, on_the_fly_kbtrv = create_KB_on_the_fly(
                    # FIXME perhaps matchup issues are due to generator to list issues?
                    v_src[i], trg_vocab, valid_kb.fields, kbtrv_fields,
                    canon_func)
                keys = [entry.kbsrc for entry in on_the_fly_kb]
                vals = on_the_fly_kbtrv
                calcKbLen = len(keys)  # update with length of newly created KB
                print(f"KB PLOTTING: on the fly recreation:")
                print(keys, [v.kbtrv for v in vals])
                print(f"calcKbLen={calcKbLen}")
            else:
                keys = kbkey[lower:upper]
                vals = kbtrv[lower:upper]
                # in the normal case (non_empty KB),
                # the kb lengths (i) summed and (ii) looked up
                # should match up
                assert calcKbLen == kb_lens[kb_num] + 1, (calcKbLen,
                                                          kb_lens[kb_num] + 1)
            # debug string for the prints below; only filled in when there
            # are enough KBs that the neighbour lookups cannot go negative
            if len(kb_lens) > 3:
                assertion_str = f"plotting idx={i} with kb_num={kb_num} and kb_len={kb_lens[kb_num]+1},\n\
kb_before: {kb_lens[kb_num-1]+1}, kb_after: {kb_lens[kb_num+1]+1};\n\
att_scores.shape={attention_scores.shape};\n\
calcKbLen={calcKbLen};\n\
kb_lens[kb_num]+1={kb_lens[kb_num]+1};"
            else:
                # TODO
                assertion_str = ""
            """
            # make sure attention plots have the right shape
            if not calcKbLen == attention_scores.shape[0]:
                print(f"Couldnt plot example {i} because knowledgebase was created on the fly")
                print(f"actual shape mismatch: retrieved: {calcKbLen} vs att matrix: {attention_scores.shape[0]}")
                print(assertion_str)
                # FIXME FIXME FIXME FIXME im doing something wrong with the vocab lookup in the code above
                failure += 1
                continue
            """
            print(f"KB PLOTTING: calcKbLen: {calcKbLen}")
            print(
                f"KB PLOTTING: calcKbLen should be != 0 often!!: {assertion_str}"
            )
            # index application: build "key=value" row labels, with a dummy
            # first row (presumably the default/no-entry attention slot --
            # TODO confirm)
            DUMMY = "@DUMMY=@DUMMY"
            src = [DUMMY] + [
                "+".join(key) + "=" + val.kbtrv[0]
                for key, val in zip(keys, vals)
            ]
        try:
            fig = plot_heatmap(scores=attention_scores, column_labels=trg,
                               row_labels=src, output_path=plot_file,
                               dpi=100)
            if tb_writer is not None:
                # lower resolution for tensorboard
                fig = plot_heatmap(scores=attention_scores,
                                   column_labels=trg, row_labels=src,
                                   output_path=None, dpi=50)
                tb_writer.add_figure("attention/{}.".format(i), fig,
                                     global_step=steps)
            print("plotted example {}: src len {}, trg len {}, "
                  "attention scores shape {}".format(i, len(src), len(trg),
                                                     attention_scores.shape))
            success += 1
        # pylint: disable=bare-except
        except:
            print("Couldn't plot example {}: src len {}, trg len {}, "
                  "attention scores shape {}".format(i, len(src), len(trg),
                                                     attention_scores.shape))
            failure += 1
            continue
    # bookkeeping: every index must have been counted exactly once
    # (see the NOTE(review) on the skip above)
    assert success + failure == len(
        indices
    ), f"plotting success:{success}, failure:{failure}, indices:{len(indices)}"
    return f"{success}/{len(indices)}"