import numpy as np
import torch

# calculate_all_attentions lives in espnet2; Dummy and Dummy2 are test
# fixtures defined elsewhere in the test module.
from espnet2.torch_utils.calculate_all_attentions import calculate_all_attentions


def test_calculate_all_attentions_MultiHeadedAttention():
    model = Dummy()
    bs = 2
    batch = {
        "x": torch.randn(bs, 3, 10),
        "x_lengths": torch.tensor([3, 2], dtype=torch.long),
        "y": torch.randn(bs, 2, 10),
        "y_lengths": torch.tensor([4, 4], dtype=torch.long),
    }
    t = calculate_all_attentions(model, batch)
    print(t)
    for k in model.desired:
        for i in range(bs):
            np.testing.assert_array_equal(t[k][i].numpy(), model.desired[k][i].numpy())
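
Dummy above is a test fixture defined elsewhere in the test module. For orientation, here is a minimal sketch of the shape such a fixture needs, inferred from the assertions in the test (the attention sizes and the desired bookkeeping are assumptions, not the actual ESPnet fixture):

from collections import defaultdict

from espnet.nets.pytorch_backend.transformer.attention import MultiHeadedAttention


class DummySketch(torch.nn.Module):
    """Runs one attention module and records its weights, cropped to the
    true lengths, so a test can compare them against calculate_all_attentions.
    """

    def __init__(self):
        super().__init__()
        self.att = MultiHeadedAttention(n_head=2, n_feat=10, dropout_rate=0.0)
        self.desired = defaultdict(list)

    def forward(self, x, x_lengths, y, y_lengths):
        self.att(y, x, x, None)  # cross-attention: query=y, key/value=x
        # espnet's MultiHeadedAttention caches its last weights in .attn
        for i in range(x.size(0)):
            self.desired["att"].append(
                self.att.attn[i, :, : y_lengths[i], : x_lengths[i]]
            )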


def test_calculate_all_attentions(atype):
    # atype is supplied via pytest parametrization in the original test suite.
    model = Dummy2(atype)
    bs = 2
    batch = {
        "x": torch.randn(bs, 20, 128),
        "x_lengths": torch.tensor([20, 17], dtype=torch.long),
        "y": torch.randint(0, 50, [bs, 7]),
        "y_lengths": torch.tensor([7, 5], dtype=torch.long),
    }
    t = calculate_all_attentions(model, batch)
    for k, o in t.items():
        for i, att in enumerate(o):
            print(att.shape)
            if att.dim() == 2:
                att = att[None]
            for a in att:
                assert a.shape == (batch["y_lengths"][i], batch["x_lengths"][i])
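
Both tests exercise the same contract: calculate_all_attentions takes a model and a padded batch whose *_lengths entries give the true lengths, runs a single forward pass, and returns a dict that maps each attention module's name to a list with one attention tensor per utterance. A minimal consumption sketch, reusing the model and batch built above (the shape comments restate the assertions in the tests):

att_dict = calculate_all_attentions(model, batch)  # Dict[str, List[torch.Tensor]]
for name, per_utt in att_dict.items():
    for i, att in enumerate(per_utt):
        # att is (y_len_i, x_len_i) or, for multi-head modules,
        # (n_heads, y_len_i, x_len_i), cropped to utterance i's true lengths
        print(name, i, tuple(att.shape))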
Example #3
    def plot_attention(
        cls,
        model: torch.nn.Module,
        output_dir: Optional[Path],
        summary_writer,
        iterator: Iterable[Tuple[List[str], Dict[str, torch.Tensor]]],
        reporter: SubReporter,
        options: TrainerOptions,
    ) -> None:
        assert check_argument_types()
        import matplotlib

        ngpu = options.ngpu
        no_forward_run = options.no_forward_run

        matplotlib.use("Agg")
        import matplotlib.pyplot as plt
        from matplotlib.ticker import MaxNLocator

        model.eval()
        for ids, batch in iterator:
            assert isinstance(batch, dict), type(batch)
            assert len(next(iter(batch.values()))) == len(ids), (
                len(next(iter(batch.values()))),
                len(ids),
            )

            batch["utt_id"] = ids

            batch = to_device(batch, "cuda" if ngpu > 0 else "cpu")
            if no_forward_run:
                continue

            # 1. Forwarding model and gathering all attentions
            #    calculate_all_attentions() uses single gpu only.
            att_dict = calculate_all_attentions(model, batch)

            # 2. Plot attentions: This part is slow due to matplotlib
            for k, att_list in att_dict.items():
                assert len(att_list) == len(ids), (len(att_list), len(ids))
                for id_, att_w in zip(ids, att_list):

                    if isinstance(att_w, torch.Tensor):
                        att_w = att_w.detach().cpu().numpy()

                    if att_w.ndim == 2:
                        att_w = att_w[None]
                    elif att_w.ndim > 3 or att_w.ndim == 1:
                        raise RuntimeError(
                            f"Must be 2 or 3 dimensions: {att_w.ndim}")

                    w, h = plt.figaspect(1.0 / len(att_w))
                    fig = plt.Figure(figsize=(w * 1.3, h * 1.3))
                    axes = fig.subplots(1, len(att_w))
                    if len(att_w) == 1:
                        axes = [axes]

                    for ax, aw in zip(axes, att_w):
                        ax.imshow(aw.astype(np.float32), aspect="auto")
                        ax.set_title(f"{k}_{id_}")
                        ax.set_xlabel("Input")
                        ax.set_ylabel("Output")
                        ax.xaxis.set_major_locator(MaxNLocator(integer=True))
                        ax.yaxis.set_major_locator(MaxNLocator(integer=True))

                    if output_dir is not None:
                        p = output_dir / id_ / f"{k}.{reporter.get_epoch()}ep.png"
                        p.parent.mkdir(parents=True, exist_ok=True)
                        fig.savefig(p)

                    if summary_writer is not None:
                        summary_writer.add_figure(f"{k}_{id_}", fig,
                                                  reporter.get_epoch())

                    if options.use_wandb:
                        import wandb

                        wandb.log(
                            {f"attention plot/{k}_{id_}": wandb.Image(fig)})
            reporter.next()
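
The matplotlib pattern above also works outside the Trainer. A stripped-down sketch for saving one attention array as a PNG (the helper name is hypothetical; the plotting calls mirror the loop above):

import numpy as np
import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator


def save_attention_png(att_w: np.ndarray, path: str, title: str = "att") -> None:
    """Plot a (n_heads, out_len, in_len) or (out_len, in_len) attention array."""
    if att_w.ndim == 2:
        att_w = att_w[None]  # promote a single matrix to one "head"
    w, h = plt.figaspect(1.0 / len(att_w))
    fig = plt.Figure(figsize=(w * 1.3, h * 1.3))
    axes = fig.subplots(1, len(att_w))
    if len(att_w) == 1:
        axes = [axes]
    for ax, aw in zip(axes, att_w):
        ax.imshow(aw.astype(np.float32), aspect="auto")
        ax.set_title(title)
        ax.set_xlabel("Input")
        ax.set_ylabel("Output")
        ax.xaxis.set_major_locator(MaxNLocator(integer=True))
        ax.yaxis.set_major_locator(MaxNLocator(integer=True))
    fig.savefig(path)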
Example #4
    def recognize(self, audiofile: Union[Path, str, bytes]) -> Result:
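        """Transcribe audiofile and return a populated Result.

        Accepts a filesystem path or raw bytes; audio is resampled to 16 kHz.
        """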

        result = Result()

        if isinstance(audiofile, (str, Path)):
            audio_samples, rate = librosa.load(audiofile, sr=16000)
        elif isinstance(audiofile, bytes):
            audio_samples, rate = librosa.load(io.BytesIO(audiofile), sr=16000)
        else:
            raise ValueError(f"Unsupported audio input type: {type(audiofile)}")

        result.audio_samples = copy.deepcopy(audio_samples)

        # the model expects a torch.Tensor as input
        if isinstance(audio_samples, np.ndarray):
            audio_samples = torch.tensor(audio_samples)
        audio_samples = audio_samples.unsqueeze(0).to(torch.float32)

        lengths = audio_samples.new_full([1],
                                         dtype=torch.long,
                                         fill_value=audio_samples.size(1))
        batch = {"speech": audio_samples, "speech_lengths": lengths}
        batch = to_device(batch, device=self.device)

        # run the model encoder
        enc, _ = self.model.encode(**batch)

        # decode with beam search
        nbest_hyps = self.beam_search(x=enc[0])

        # keep only the best hypothesis
        best_hyps = nbest_hyps[0]

        # convert training token ids back to text
        token_int = best_hyps.yseq[1:-1].tolist()
        token_int = list(filter(lambda x: x != 0, token_int))
        token = self.converter.ids2tokens(token_int)
        text = self.tokenizer.tokens2text(token)

        # populate the result object
        result.text = text
        result.encoded_vector = enc[0]  # [0] removes the batch dimension

        # compute all attention matrices for the decoded text
        text_tensor = torch.Tensor(token_int).unsqueeze(0).to(torch.long)
        batch["text"] = text_tensor
        batch["text_lengths"] = text_tensor.new_full(
            [1], dtype=torch.long, fill_value=text_tensor.size(1))

        result.attention_weights = calculate_all_attentions(self.model, batch)
        result.tokens_txt = token

        # CTC posteriors (enc is already batched, so no extra unsqueeze is needed)
        logp = self.model.ctc.log_softmax(enc)[0]
        result.ctc_posteriors = logp.exp().detach().cpu().numpy()
        result.tokens_int = best_hyps.yseq
        result.mel_features, _ = self.frontend(audiofile, normalize=False)
        return result
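
A usage sketch, assuming recognize is a method of a recognizer object whose construction is outside this snippet (the variable name is hypothetical):

result = recognizer.recognize("utterance.wav")  # a path, or raw bytes
print(result.text)                              # best hypothesis as plain text
print(result.tokens_txt)                        # the token sequence
for name, att_list in result.attention_weights.items():
    print(name, att_list[0].shape)              # one attention map per utterance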