import math
from typing import Dict

import torch

from pytorch_forecasting.utils import autocorrelation, create_mask, integer_histogram


def interpret_output(
    self,
    out: Dict[str, torch.Tensor],
    reduction: str = "none",
    attention_prediction_horizon: int = 0,
    attention_as_autocorrelation: bool = False,
) -> Dict[str, torch.Tensor]:
    """
    Interpret output of model.

    Args:
        out: output as produced by ``forward()``
        reduction: "none" for no averaging over batches, "sum" for summing attentions,
            "mean" for normalizing by encoder lengths
        attention_prediction_horizon: which prediction horizon to use for attention
        attention_as_autocorrelation: whether to record attention as autocorrelation -
            this should be set to True if ``reduction != "none"`` and the samples have
            differing prediction times. Defaults to False.

    Returns:
        interpretations that can be plotted with ``plot_interpretation()``
    """
    # histogram of decode and encode lengths
    encoder_length_histogram = integer_histogram(
        out["encoder_lengths"], min=0, max=self.hparams.max_encoder_length
    )
    decoder_length_histogram = integer_histogram(
        out["decoder_lengths"], min=1, max=out["decoder_variables"].size(1)
    )

    # mask positions where encoder and decoder were not applied when averaging
    # variable selection weights
    encoder_variables = out["encoder_variables"].squeeze(-2)
    encode_mask = create_mask(encoder_variables.size(1), out["encoder_lengths"])
    encoder_variables = encoder_variables.masked_fill(encode_mask.unsqueeze(-1), 0.0).sum(dim=1)
    encoder_variables /= (
        out["encoder_lengths"]
        .where(out["encoder_lengths"] > 0, torch.ones_like(out["encoder_lengths"]))
        .unsqueeze(-1)
    )

    decoder_variables = out["decoder_variables"].squeeze(-2)
    decode_mask = create_mask(decoder_variables.size(1), out["decoder_lengths"])
    decoder_variables = decoder_variables.masked_fill(decode_mask.unsqueeze(-1), 0.0).sum(dim=1)
    decoder_variables /= out["decoder_lengths"].unsqueeze(-1)

    # static variables need no masking
    static_variables = out["static_variables"].squeeze(1)

    # attention is batch x time x heads x time_to_attend
    # average over heads + only keep prediction attention and attention on observed timesteps
    attention = out["attention"][
        :, attention_prediction_horizon, :, : out["encoder_lengths"].max() + attention_prediction_horizon
    ].mean(1)

    if reduction != "none":  # if to average over batches
        static_variables = static_variables.sum(dim=0)
        encoder_variables = encoder_variables.sum(dim=0)
        decoder_variables = decoder_variables.sum(dim=0)

        # right-align attention before averaging so that the last position always
        # corresponds to the timestep immediately before the prediction point
        for i in range(len(attention)):  # very inefficient but does the trick
            if 0 < out["encoder_lengths"][i] < attention.size(1) - attention_prediction_horizon - 1:
                relevant_attention = attention[
                    i, : out["encoder_lengths"][i] + attention_prediction_horizon
                ].clone()
                if attention_as_autocorrelation:
                    relevant_attention = autocorrelation(relevant_attention)
                attention[i, -out["encoder_lengths"][i] - attention_prediction_horizon :] = relevant_attention
                attention[i, : attention.size(1) - out["encoder_lengths"][i] - attention_prediction_horizon] = 0.0
            elif attention_as_autocorrelation:
                attention[i] = autocorrelation(attention[i])

        attention = attention.sum(dim=0)
        if reduction == "mean":
            # normalize each (right-aligned) position by the number of samples covering it
            attention = attention / encoder_length_histogram[1:].flip(0).cumsum(0).clamp(1)
            attention = attention / attention.sum(-1).unsqueeze(-1)  # renormalize
        elif reduction == "sum":
            pass  # attention is already summed
        else:
            raise ValueError(f"Unknown reduction {reduction}")

        # pad attention on the left up to the maximum encoder length
        attention = torch.zeros(
            self.hparams.max_encoder_length + attention_prediction_horizon, device=self.device
        ).scatter(
            dim=0,
            index=torch.arange(
                self.hparams.max_encoder_length + attention_prediction_horizon - attention.size(-1),
                self.hparams.max_encoder_length + attention_prediction_horizon,
                device=self.device,
            ),
            src=attention,
        )
    else:
        attention = attention / attention.sum(-1).unsqueeze(-1)  # renormalize

    interpretation = dict(
        attention=attention,
        static_variables=static_variables,
        encoder_variables=encoder_variables,
        decoder_variables=decoder_variables,
        encoder_length_histogram=encoder_length_histogram,
        decoder_length_histogram=decoder_length_histogram,
    )
    return interpretation
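# A minimal usage sketch, assuming a trained model ``tft`` that defines
# ``interpret_output`` (e.g. a TemporalFusionTransformer from pytorch-forecasting)
# and a ``val_dataloader``; both names are hypothetical. ``mode="raw"`` makes
# ``predict`` return the ``forward()`` output dict that ``interpret_output`` expects.
raw_output = tft.predict(val_dataloader, mode="raw")
interpretation = tft.interpret_output(raw_output, reduction="sum")
tft.plot_interpretation(interpretation)  # plots attention and variable importances, per the docstring above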
def test_autocorrelation():
    # two full sine periods sampled at 201 points: a lag of 100 steps equals 2*pi
    x = torch.sin(torch.linspace(0, 2 * 2 * math.pi, 201))
    corr = autocorrelation(x, dim=-1)
    assert corr[0] == 1, "Autocorrelation of first element should be 1."
    assert corr[100] > 0.99, "Autocorrelation should be near 1 for sin(2*pi)"
    assert corr[50] < -0.99, "Autocorrelation should be near -1 for sin(pi)"
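# For intuition about the assertions above, a naive O(n^2) sketch of the quantity an
# ``autocorrelation`` utility is commonly expected to return (an assumption based on
# the standard definition: lag-wise mean of products of the centered signal,
# normalized so that lag 0 equals 1). ``naive_autocorrelation`` is a hypothetical
# helper for cross-checking, not part of the library.
def naive_autocorrelation(x: torch.Tensor) -> torch.Tensor:
    x_centered = x - x.mean()
    n = x.size(0)
    corr = torch.stack(
        [(x_centered[: n - lag] * x_centered[lag:]).sum() / (n - lag) for lag in range(n)]
    )
    return corr / corr[0]  # corr[0] is the variance estimate, so the result starts at 1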