Ejemplo n.º 1
0
    def interpret_output(
        self,
        out: Dict[str, torch.Tensor],
        reduction: str = "none",
        attention_prediction_horizon: int = 0,
        attention_as_autocorrelation: bool = False,
    ) -> Dict[str, torch.Tensor]:
        """
        Interpret the output of the model.

        Args:
            out: output as produced by ``forward()``
            reduction: "none" for no averaging over batches, "sum" for summing attentions, "mean" for
                normalizing by encode lengths
            attention_prediction_horizon: which prediction horizon to use for attention
            attention_as_autocorrelation: if to record attention as autocorrelation - this should be set to true in
                case of ``reduction != "none"`` and differing prediction times of the samples. Defaults to False

        Returns:
            interpretations that can be plotted with ``plot_interpretation()``: a dictionary with the keys
            ``attention``, ``static_variables``, ``encoder_variables``, ``decoder_variables``,
            ``encoder_length_histogram`` and ``decoder_length_histogram``. With ``reduction="none"`` the
            attention and variable weights are per sample (first dimension is the batch) and each attention
            row is renormalized to sum to one. Otherwise the variable weights are summed over the batch and
            ``attention`` is a 1-d tensor of length ``max_encoder_length + attention_prediction_horizon``,
            right-aligned and zero-padded on the left.

        Raises:
            ValueError: if ``reduction`` is not one of "none", "sum" or "mean"
        """
        # validate up front so an invalid value fails before any tensor work is done
        if reduction not in ("none", "sum", "mean"):
            raise ValueError(f"Unknown reduction {reduction}")

        encoder_lengths = out["encoder_lengths"]
        decoder_lengths = out["decoder_lengths"]

        # histogram of decode and encode lengths
        encoder_length_histogram = integer_histogram(
            encoder_lengths, min=0, max=self.hparams.max_encoder_length)
        decoder_length_histogram = integer_histogram(
            decoder_lengths, min=1, max=out["decoder_variables"].size(1))

        # average variable selection weights only over steps where encoder / decoder were applied:
        # positions flagged by ``create_mask`` (presumably padding beyond each sample's length) are zeroed
        encoder_variables = out["encoder_variables"].squeeze(-2)
        encode_mask = create_mask(encoder_variables.size(1), encoder_lengths)
        encoder_variables = encoder_variables.masked_fill(
            encode_mask.unsqueeze(-1), 0.0).sum(dim=1)
        # samples with an encoder length of 0 are divided by 1 instead of 0
        encoder_variables /= encoder_lengths.where(
            encoder_lengths > 0, torch.ones_like(encoder_lengths)).unsqueeze(-1)

        decoder_variables = out["decoder_variables"].squeeze(-2)
        decode_mask = create_mask(decoder_variables.size(1), decoder_lengths)
        decoder_variables = decoder_variables.masked_fill(
            decode_mask.unsqueeze(-1), 0.0).sum(dim=1)
        # NOTE(review): assumes every decoder length is >= 1 (the histogram above starts at 1)
        decoder_variables /= decoder_lengths.unsqueeze(-1)

        # static variables need no masking
        static_variables = out["static_variables"].squeeze(1)

        # attention is batch x time x heads x time_to_attend
        # average over heads + only keep prediction attention and attention on observed timesteps
        attention = out["attention"][
            :, attention_prediction_horizon, :,
            :encoder_lengths.max() + attention_prediction_horizon].mean(1)

        if reduction == "none":
            attention = attention / attention.sum(-1, keepdim=True)  # renormalize
        else:  # "sum" or "mean": aggregate over the batch
            static_variables = static_variables.sum(dim=0)
            encoder_variables = encoder_variables.sum(dim=0)
            decoder_variables = decoder_variables.sum(dim=0)

            # Move the first ``length + horizon`` entries of each row to the row's right end and zero
            # the rest, so rows of different encoder lengths line up on the right before summing.
            # Rows that already span the full width need no shift. Very inefficient but does the trick.
            width = attention.size(1)
            for i, length in enumerate(encoder_lengths.tolist()):
                relevant_width = length + attention_prediction_horizon
                if length > 0 and relevant_width < width:
                    # clone: source and destination slices of the same row overlap
                    relevant_attention = attention[i, :relevant_width].clone()
                    if attention_as_autocorrelation:
                        relevant_attention = autocorrelation(relevant_attention)
                    attention[i, -relevant_width:] = relevant_attention
                    attention[i, :width - relevant_width] = 0.0
                elif attention_as_autocorrelation:
                    attention[i] = autocorrelation(attention[i])

            attention = attention.sum(dim=0)
            if reduction == "mean":
                # divide each lag by the number of contributing samples: presumably
                # ``encoder_length_histogram[1:]`` counts samples per encoder length 1..max, so its
                # flipped cumulative sum counts samples at least that long.
                # NOTE(review): broadcasting presumably requires ``attention`` to have exactly
                # ``max_encoder_length`` entries (horizon 0 and a full-length sample in the batch) - confirm
                attention = attention / encoder_length_histogram[1:].flip(0).cumsum(0).clamp(1)
                attention = attention / attention.sum(-1, keepdim=True)  # renormalize

            # right-align into a fixed length of ``max_encoder_length + attention_prediction_horizon``,
            # zero-padding on the left (presumably so batches with different maximum encoder lengths
            # yield tensors of the same shape)
            total_width = self.hparams.max_encoder_length + attention_prediction_horizon
            padded_attention = attention.new_zeros(total_width)
            padded_attention[total_width - attention.size(-1):] = attention
            attention = padded_attention

        interpretation = dict(
            attention=attention,
            static_variables=static_variables,
            encoder_variables=encoder_variables,
            decoder_variables=decoder_variables,
            encoder_length_histogram=encoder_length_histogram,
            decoder_length_histogram=decoder_length_histogram,
        )
        return interpretation
def test_autocorrelation():
    """Check ``autocorrelation`` on a sine wave sampled over two full periods.

    201 points on ``[0, 4*pi]`` give a sample spacing of ``pi / 50``: a lag of 50 samples is a
    phase shift of ``pi`` (correlation near -1) and a lag of 100 samples a shift of ``2*pi``
    (correlation near 1).
    """
    n_samples = 201
    x = torch.sin(torch.linspace(0, 2 * 2 * math.pi, n_samples))
    corr = autocorrelation(x, dim=-1)
    # number of samples covering a phase shift of pi (4*pi spread over n_samples - 1 intervals)
    samples_per_pi = (n_samples - 1) // 4
    assert math.isclose(float(corr[0]), 1.0, rel_tol=1e-6), "Autocorrelation of first element should be 1."
    assert corr[2 * samples_per_pi] > 0.99, "Autocorrelation should be near 1 for sin(2*pi)"
    assert corr[samples_per_pi] < -0.99, "Autocorrelation should be near -1 for sin(pi)"