Example #1
def store_attention_plots(attentions, targets, sources, output_prefix, idx):
    """
    Saves attention plots.

    :param attentions:
    :param targets:
    :param sources:
    :param output_prefix:
    :param idx:
    :return:
    """
    for i in idx:
        plot_file = "{}.{}.pdf".format(output_prefix, i)
        src = sources[i]
        trg = targets[i]
        attention_scores = attentions[i].T
        try:
            plot_heatmap(scores=attention_scores,
                         column_labels=trg,
                         row_labels=src,
                         output_path=plot_file)
        # pylint: disable=bare-except
        except:
            print("Couldn't plot example {}: src len {}, trg len {}, "
                  "attention scores shape {}".format(i, len(src), len(trg),
                                                     attention_scores.shape))
            continue
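A minimal usage sketch for this first variant, with hypothetical toy data (the attention matrices below are random stand-ins for real decoder scores, and the plots/ directory is assumed to exist):

import numpy as np

sources = [["das", "ist", "gut"], ["hallo", "welt"]]
targets = [["this", "is", "good"], ["hello", "world"]]
# one (trg_len x src_len) matrix per example; the function transposes it,
# so rows become source tokens and columns target tokens
attentions = [np.random.rand(len(t), len(s)) for s, t in zip(sources, targets)]

store_attention_plots(attentions, targets, sources,
                      output_prefix="plots/att", idx=[0, 1])
# -> plots/att.0.pdf and plots/att.1.pdf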
Example #2
from typing import List, Optional

import numpy as np
from torch.utils.tensorboard import SummaryWriter  # tensorboardX in older setups

from joeynmt.plotting import plot_heatmap


def store_attention_plots(attentions: np.ndarray,
                          targets: List[List[str]],
                          sources: List[List[str]],
                          output_prefix: str,
                          indices: List[int],
                          tb_writer: Optional[SummaryWriter] = None,
                          steps: int = 0) -> None:
    """
    Saves attention plots.

    :param attentions: attention scores
    :param targets: list of tokenized targets
    :param sources: list of tokenized sources
    :param output_prefix: prefix for attention plots
    :param indices: indices selected for plotting
    :param tb_writer: Tensorboard summary writer (optional)
    :param steps: current training steps, needed for tb_writer
    """
    sources = list(sources)
    for i in indices:
        if i >= len(sources):
            continue
        plot_file = "{}.{}.pdf".format(output_prefix, i)
        src = sources[i]
        trg = targets[i]
        attention_scores = attentions[i].T
        try:
            fig = plot_heatmap(scores=attention_scores,
                               column_labels=trg,
                               row_labels=src,
                               output_path=plot_file,
                               dpi=100)
            if tb_writer is not None:
                # lower resolution for tensorboard
                fig = plot_heatmap(scores=attention_scores,
                                   column_labels=trg,
                                   row_labels=src,
                                   output_path=None,
                                   dpi=50)
                tb_writer.add_figure("attention/{}.".format(i),
                                     fig,
                                     global_step=steps)
        # pylint: disable=bare-except
        except:
            print("Couldn't plot example {}: src len {}, trg len {}, "
                  "attention scores shape {}".format(i, len(src), len(trg),
                                                     attention_scores.shape))
            continue
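All variants delegate the actual drawing to plot_heatmap, which in JoeyNMT lives in joeynmt.plotting. A minimal sketch of a compatible implementation, assuming matplotlib (the real helper differs in detail):

from typing import List

import numpy as np
import matplotlib
matplotlib.use("Agg")  # headless backend, suitable for file and TensorBoard output
from matplotlib import pyplot as plt
from matplotlib.figure import Figure


def plot_heatmap(scores: np.ndarray,
                 column_labels: List[str],
                 row_labels: List[str],
                 output_path: str = None,
                 dpi: int = 100) -> Figure:
    """Plot an attention matrix and optionally save it as a PDF."""
    fig, ax = plt.subplots(figsize=(10, 10), dpi=dpi)
    ax.imshow(scores, cmap="viridis", aspect="auto")
    # one tick per token: column labels on the x-axis, row labels on the y-axis
    ax.set_xticks(np.arange(len(column_labels)))
    ax.set_xticklabels(column_labels, rotation=90)
    ax.set_yticks(np.arange(len(row_labels)))
    ax.set_yticklabels(row_labels)
    fig.tight_layout()
    if output_path is not None:
        fig.savefig(output_path)
    return fig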
Example #3
def store_attention_plots(attentions: dict,
                          targets: List[List[str]],
                          sources: List[List[str]],
                          model_dir: str,
                          indices: List[int],
                          tb_writer: Optional[SummaryWriter] = None,
                          steps: int = 0,
                          data_set_name: str = "att") -> None:
    """
    Saves attention plots.

    :param attentions: attention scores
    :param targets: list of tokenized targets
    :param sources: list of tokenized sources
    :param model_dir: directory in which the attention plots are stored
    :param indices: indices selected for plotting
    :param tb_writer: Tensorboard summary writer (optional)
    :param steps: current training steps, needed for tb_writer
    :param data_set_name: name of the data set, used in the plot file names
    """
    assert all(i < len(sources) for i in indices)
    for i in indices:
        for name, attn in attentions.items():
            output_prefix = join(model_dir,
                                 "{}.{}.{}".format(data_set_name, name, steps))
            col_labels = targets if name == "trg_trg" else sources
            row_labels = sources if name == "src_src" else targets
            plot_file = "{}.{}.pdf".format(output_prefix, i)
            attn_matrix = attn[i]
            cols = col_labels[i]
            rows = row_labels[i]

            fig = plot_heatmap(scores=attn_matrix,
                               column_labels=cols,
                               row_labels=rows,
                               output_path=plot_file,
                               dpi=100)
            if tb_writer is not None:
                # lower resolution for tensorboard
                fig = plot_heatmap(scores=attn_matrix,
                                   column_labels=cols,
                                   row_labels=rows,
                                   output_path=None,
                                   dpi=50)
                # the tensorboard thing will need fixing
                tb_writer.add_figure("attention_{}/{}.".format(name, i),
                                     fig,
                                     global_step=steps)
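A hypothetical call for this dict-based variant; it expects one matrix per attention type and builds file names from model_dir, data_set_name, the attention name, and the step count:

import numpy as np

sources = [["das", "ist", "gut"]]
targets = [["this", "is", "good", "!"]]
attentions = {
    "src_src": [np.random.rand(3, 3)],  # encoder self-attention: src x src
    "trg_trg": [np.random.rand(4, 4)],  # decoder self-attention: trg x trg
    "src_trg": [np.random.rand(4, 3)],  # cross-attention: trg rows, src columns
}

store_attention_plots(attentions, targets, sources, model_dir="model",
                      indices=[0], steps=1000, data_set_name="dev")
# -> model/dev.src_src.1000.0.pdf, model/dev.trg_trg.1000.0.pdf, ...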
Example #4
def store_attention_plots(attentions: np.ndarray,
                          targets: List[List[str]],
                          sources: List[List[str]],
                          output_prefix: str,
                          indices: List[int],
                          tb_writer: Optional[SummaryWriter] = None,
                          steps: int = 0,
                          kb_info: Optional[Tuple] = None,
                          on_the_fly_info: Optional[Tuple] = None) -> str:
    """
    Saves attention plots.

    :param attentions: attention scores
    :param targets: list of tokenized targets
    :param sources: list of tokenized sources
    :param output_prefix: prefix for attention plots
    :param indices: indices selected for plotting
    :param tb_writer: Tensorboard summary writer (optional)
    :param steps: current training steps, needed for tb_writer
    :param kb_info: tuple of the valid set's kb_lkp, kb_lens, kb_truvals
    :param on_the_fly_info: tuple containing valid_data.src field, valid_kb,
        canonization function, model.trg_vocab
    :return: summary string of the form "successes/number of indices"
    """
    success, failure = 0, 0
    for i in indices:
        # shift the (apparently 1-based) index down by one; wrap negative
        # results around the end of the index list
        i -= 1
        if i < 0:
            i = len(indices) + i
        if i >= len(sources):
            continue
        plot_file = "{}.{}.pdf".format(output_prefix, i)

        attention_scores = attentions[i].T  # => KB x UNROLL
        print(
            f"PLOTTING: shape of {i}th attention matrix from print_valid_sents: {attention_scores.shape}"
        )
        trg = targets[i]
        if kb_info is None:
            src = sources[i]
        else:
            kbkey = sources
            kb_lkp, kb_lens, kbtrv = kb_info
            kbtrv_fields = kbtrv.fields  # needed for on the fly creation below
            kbtrv = list(kbtrv)
            print(f"KB PLOTTING: kb_lens: {kb_lens}")

            # index calculation (find batch in valid/test files using kb lookup indices and length info)
            kb_num = kb_lkp[i]
            lower = sum(kb_lens[:kb_num])
            upper = lower + kb_lens[kb_num] + 1
            calcKbLen = upper - lower

            if calcKbLen == 1 and attention_scores.shape[0] > 1:
                # FIXME make this an option in the cfg
                # this is a scheduling KB created on the fly in data.batch_with_kb
                # TODO which fields are needed to recreate it on the fly here
                # valid_kb: has fields kbsrc, kbtrg; valid_kbtrv
                valid_src, valid_kb, canon_func, trg_vocab = on_the_fly_info
                v_src = list(valid_src)

                on_the_fly_kb, on_the_fly_kbtrv = create_KB_on_the_fly(
                    # FIXME perhaps matchup issues are due to generator to list issues?
                    v_src[i],
                    trg_vocab,
                    valid_kb.fields,
                    kbtrv_fields,
                    canon_func)

                keys = [entry.kbsrc for entry in on_the_fly_kb]
                vals = on_the_fly_kbtrv

                calcKbLen = len(keys)  # update with length of newly created KB

                print(f"KB PLOTTING: on the fly recreation:")
                print(keys, [v.kbtrv for v in vals])
                print(f"calcKbLen={calcKbLen}")
            else:

                keys = kbkey[lower:upper]
                vals = kbtrv[lower:upper]

                # in the normal case (non-empty KB),
                # the kb lengths (i) summed and (ii) looked up
                # should match up
                assert calcKbLen == kb_lens[kb_num] + 1, (calcKbLen,
                                                          kb_lens[kb_num] + 1)

            if len(kb_lens) > 3:
                assertion_str = f"plotting idx={i} with kb_num={kb_num} and kb_len={kb_lens[kb_num]+1},\n\
                    kb_before: {kb_lens[kb_num-1]+1}, kb_after: {kb_lens[kb_num+1]+1};\n\
                    att_scores.shape={attention_scores.shape};\n\
                    calcKbLen={calcKbLen};\n\
                    kb_lens[kb_num]+1={kb_lens[kb_num]+1};"

            else:
                # TODO
                assertion_str = ""
            """
            # make sure attention plots have the right shape
            if not calcKbLen == attention_scores.shape[0]:
                print(f"Couldnt plot example {i} because knowledgebase was created on the fly")
                print(f"actual shape mismatch: retrieved: {calcKbLen} vs att matrix: {attention_scores.shape[0]}")
                print(assertion_str)
                # FIXME FIXME FIXME FIXME im doing something wrong with the vocab lookup in the code above
                failure += 1
                continue
            """

            print(f"KB PLOTTING: calcKbLen: {calcKbLen}")
            print(
                f"KB PLOTTING: calcKbLen should be != 0 often!!: {assertion_str}"
            )

            # index application
            DUMMY = "@DUMMY=@DUMMY"

            src = [DUMMY] + [
                "+".join(key) + "=" + val.kbtrv[0]
                for key, val in zip(keys, vals)
            ]

        try:
            fig = plot_heatmap(scores=attention_scores,
                               column_labels=trg,
                               row_labels=src,
                               output_path=plot_file,
                               dpi=100)
            if tb_writer is not None:
                # lower resolution for tensorboard
                fig = plot_heatmap(scores=attention_scores,
                                   column_labels=trg,
                                   row_labels=src,
                                   output_path=None,
                                   dpi=50)
                tb_writer.add_figure("attention/{}.".format(i),
                                     fig,
                                     global_step=steps)
            print("plotted example {}: src len {}, trg len {}, "
                  "attention scores shape {}".format(i, len(src), len(trg),
                                                     attention_scores.shape))
            success += 1
        # pylint: disable=bare-except
        except:
            print("Couldn't plot example {}: src len {}, trg len {}, "
                  "attention scores shape {}".format(i, len(src), len(trg),
                                                     attention_scores.shape))
            failure += 1
            continue

    assert success + failure == len(
        indices
    ), f"plotting success:{success}, failure:{failure}, indices:{len(indices)}"
    return f"{success}/{len(indices)}"