def forward(self, y_pred: torch.Tensor, y_actual: torch.Tensor) -> torch.Tensor:
    """
    Score aggregated predictions against aggregated targets with ``self.metric``.

    The prediction is averaged over its first dimension. A plain tensor target
    is averaged the same way. A packed-sequence target is unpacked, its padded
    positions are zeroed, and it is summed over the first dimension and divided
    by the sum of all sequence lengths; a per-time-step weight is then attached
    in the last dimension.

    Args:
        y_pred: network output
        y_actual: actual values, either a tensor or a ``rnn.PackedSequence``

    Returns:
        torch.Tensor: whatever ``self.metric`` returns for the aggregated
        prediction and target
    """
    pred_mean = y_pred.mean(0).unsqueeze(0)

    if not isinstance(y_actual, rnn.PackedSequence):
        # plain tensor target: aggregate it exactly like the prediction
        return self.metric(pred_mean, y_actual.mean(0).unsqueeze(0))

    padded, seq_lengths = rnn.pad_packed_sequence(y_actual, batch_first=True)
    # the lengths reside on the CPU by default -> move them next to the data
    seq_lengths = seq_lengths.to(padded.device)

    # True wherever a time step lies at or beyond the sequence's own length
    time_index = torch.arange(padded.size(1), device=padded.device)
    is_padding = time_index.unsqueeze(0) >= seq_lengths.unsqueeze(-1)
    if padded.ndim > 2:
        # add a trailing axis so mask and lengths broadcast over the last dimension
        is_padding = is_padding.unsqueeze(-1)
        seq_lengths = seq_lengths.unsqueeze(-1)

    padded = padded.masked_fill(is_padding, 0.0)
    target_mean = padded.sum(0).unsqueeze(0) / seq_lengths.sum()

    # weight per time step: reversed cumulative sum of the length histogram
    # (presumably the number of sequences that reach that step - confirm
    # against ``integer_histogram``)
    length_counts = integer_histogram(seq_lengths, min=1, max=padded.size(1))
    weight = length_counts.flip(0).cumsum(0).flip(0).float().unsqueeze(0)

    if target_mean.ndim == 3:
        # index 1 of the last dimension is rescaled by the weight
        # (presumably an existing weight channel - TODO confirm)
        target_mean[..., 1] = target_mean[..., 1] * weight
    else:
        target_mean = torch.stack((target_mean, weight), dim=-1)

    return self.metric(pred_mean, target_mean)
def interpret_output(
    self,
    out: Dict[str, torch.Tensor],
    reduction: str = "none",
    attention_prediction_horizon: int = 0,
    attention_as_autocorrelation: bool = False,
) -> Dict[str, torch.Tensor]:
    """
    Condense raw model output into variable weights, attention and length histograms.

    Args:
        out: output as produced by ``forward()``. The keys read here are
            ``encoder_lengths``, ``decoder_lengths``, ``encoder_variables``,
            ``decoder_variables``, ``static_variables`` and ``attention``
        reduction: "none" for no averaging over batches, "sum" for summing
            attentions, "mean" for normalizing by encode lengths
        attention_prediction_horizon: which prediction horizon to use for attention
        attention_as_autocorrelation: if to record attention as autocorrelation -
            this should be set to true in case of ``reduction != "none"`` and
            differing prediction times of the samples. It is only consulted
            when ``reduction != "none"``. Defaults to False

    Returns:
        dictionary with the keys ``attention``, ``static_variables``,
        ``encoder_variables``, ``decoder_variables``,
        ``encoder_length_histogram`` and ``decoder_length_histogram``; it can be
        plotted with ``plot_interpretation()``

    Raises:
        ValueError: if ``reduction`` is none of "none", "sum" or "mean"
    """
    horizon = attention_prediction_horizon
    enc_lengths = out["encoder_lengths"]

    # histograms of the encoder and decoder lengths
    encoder_length_histogram = integer_histogram(
        enc_lengths, min=0, max=self.hparams.max_encoder_length
    )
    dec_lengths = out["decoder_lengths"]
    decoder_length_histogram = integer_histogram(
        dec_lengths, min=1, max=out["decoder_variables"].size(1)
    )

    # encoder variable selection weights: zero the masked time steps, sum over
    # time and divide by the encoder length (a length of 0 is replaced by 1)
    encoder_variables = out["encoder_variables"].squeeze(-2)
    encoder_padding = create_mask(encoder_variables.size(1), enc_lengths)
    encoder_variables = encoder_variables.masked_fill(
        encoder_padding.unsqueeze(-1), 0.0
    ).sum(dim=1)
    safe_enc_lengths = enc_lengths.where(enc_lengths > 0, torch.ones_like(enc_lengths))
    encoder_variables /= safe_enc_lengths.unsqueeze(-1)

    # same treatment for the decoder weights (no guard against a zero length here)
    decoder_variables = out["decoder_variables"].squeeze(-2)
    decoder_padding = create_mask(decoder_variables.size(1), dec_lengths)
    decoder_variables = decoder_variables.masked_fill(
        decoder_padding.unsqueeze(-1), 0.0
    ).sum(dim=1)
    decoder_variables /= dec_lengths.unsqueeze(-1)

    # static variables need no masking
    static_variables = out["static_variables"].squeeze(1)

    # attention is batch x time x heads x time_to_attend
    # pick the requested horizon, cut to the longest encoder length plus the
    # horizon and average over heads
    observed_width = enc_lengths.max() + horizon
    attention = out["attention"][:, horizon, :, :observed_width].mean(1)

    if reduction == "none":
        # keep one row per sample, each renormalized to sum to one
        attention = attention / attention.sum(-1).unsqueeze(-1)
    else:
        # collapse the batch dimension
        static_variables = static_variables.sum(dim=0)
        encoder_variables = encoder_variables.sum(dim=0)
        decoder_variables = decoder_variables.sum(dim=0)

        # move the used part of every shorter row to the right edge and zero
        # what is left in front of it, so that rows line up before summing
        # (row-by-row: very inefficient but does the trick)
        row_width = attention.size(1)
        shift_limit = row_width - horizon - 1
        for row in range(len(attention)):
            enc_len = enc_lengths[row]
            if 0 < enc_len < shift_limit:
                used = enc_len + horizon
                # clone because source and destination slices may overlap
                relevant_attention = attention[row, :used].clone()
                if attention_as_autocorrelation:
                    relevant_attention = autocorrelation(relevant_attention)
                attention[row, -used:] = relevant_attention
                attention[row, : row_width - used] = 0.0
            elif attention_as_autocorrelation:
                attention[row] = autocorrelation(attention[row])

        attention = attention.sum(dim=0)

        if reduction == "mean":
            # divide by the reversed cumulative encoder length histogram
            # (clamped to at least 1), then renormalize
            divisor = encoder_length_histogram[1:].flip(0).cumsum(0).clamp(1)
            attention = attention / divisor
            attention = attention / attention.sum(-1).unsqueeze(-1)
        elif reduction != "sum":
            raise ValueError(f"Unknown reduction {reduction}")

        # right-align the result in a zero vector of the full length
        full_width = self.hparams.max_encoder_length + horizon
        target_positions = torch.arange(
            full_width - attention.size(-1),
            full_width,
            device=self.device,
        )
        attention = torch.zeros(full_width, device=self.device).scatter(
            dim=0,
            index=target_positions,
            src=attention,
        )

    return {
        "attention": attention,
        "static_variables": static_variables,
        "encoder_variables": encoder_variables,
        "decoder_variables": decoder_variables,
        "encoder_length_histogram": encoder_length_histogram,
        "decoder_length_histogram": decoder_length_histogram,
    }
def interpret_output(
    self,
    out: Dict[str, torch.Tensor],
    reduction: str = "none",
    attention_prediction_horizon: int = 0,
) -> Dict[str, torch.Tensor]:
    """
    Condense raw model output into variable weights, attention and length histograms.

    ``out`` is not modified: attention tensors are copied before padded
    positions are overwritten with NaN.

    Args:
        out: output as produced by ``forward()``. The keys read here are
            ``decoder_attention``, ``encoder_attention``, ``decoder_lengths``,
            ``encoder_lengths``, ``encoder_variables``, ``decoder_variables``
            and ``static_variables``. The two attention entries may each be a
            single tensor or a list/tuple of per-sample tensors
        reduction: "none" for no averaging over batches, "sum" for summing
            attentions, "mean" for normalizing by encode lengths. Any value
            other than "none" is handed to ``masked_op`` as ``op``
        attention_prediction_horizon: which prediction horizon to use for attention

    Returns:
        dictionary with the keys ``attention`` (NaN entries replaced by 0),
        ``static_variables``, ``encoder_variables``, ``decoder_variables``,
        ``encoder_length_histogram`` and ``decoder_length_histogram``; it can be
        plotted with ``plot_interpretation()``
    """
    max_encoder_length = self.hparams.max_encoder_length

    # take attention and concatenate if a list to proper attention object
    batch_size = len(out["decoder_attention"])
    if isinstance(out["decoder_attention"], (list, tuple)):
        # per-sample tensors may differ in the last dimension -> allocate for
        # the largest one and leave everything not written as NaN
        max_last_dimension = max(x.size(-1) for x in out["decoder_attention"])
        first_elm = out["decoder_attention"][0]
        decoder_attention = torch.full(
            (batch_size, *first_elm.shape[:-1], max_last_dimension),
            float("nan"),
            dtype=first_elm.dtype,
            device=first_elm.device,
        )
        for idx, x in enumerate(out["decoder_attention"]):
            decoder_length = out["decoder_lengths"][idx]
            decoder_attention[idx, :, :, :decoder_length] = x[..., :decoder_length]
    else:
        # clone first: the NaN assignment below is in-place and would otherwise
        # write into the tensor stored in the caller's ``out`` dictionary
        decoder_attention = out["decoder_attention"].clone()
        decoder_mask = create_mask(decoder_attention.size(1), out["decoder_lengths"])
        decoder_attention[
            decoder_mask[..., None, None].expand_as(decoder_attention)
        ] = float("nan")

    if isinstance(out["encoder_attention"], (list, tuple)):
        # same game for encoder attention, but right-aligned: each sample's
        # values end at the last position of the full-length NaN tensor
        first_elm = out["encoder_attention"][0]
        encoder_attention = torch.full(
            (batch_size, *first_elm.shape[:-1], max_encoder_length),
            float("nan"),
            dtype=first_elm.dtype,
            device=first_elm.device,
        )
        for idx, x in enumerate(out["encoder_attention"]):
            encoder_length = out["encoder_lengths"][idx]
            encoder_attention[idx, :, :, max_encoder_length - encoder_length :] = x[
                ..., :encoder_length
            ]
    else:
        # roll encoder attention along the last dimension so that the last
        # encoder value of every sample ends up on the right
        encoder_attention = out["encoder_attention"]
        width = encoder_attention.size(3)
        shifts = width - out["encoder_lengths"]
        new_index = (
            torch.arange(width, device=encoder_attention.device)[
                None, None, None
            ].expand_as(encoder_attention)
            - shifts[:, None, None, None]
        ) % width
        # gather returns a new tensor, so ``out`` stays untouched here
        encoder_attention = torch.gather(encoder_attention, dim=3, index=new_index)

        # expand encoder_attention to full size by padding NaN on the left.
        # The pad width is derived from the tensor's actual width so that the
        # result always has exactly ``max_encoder_length`` entries.
        if encoder_attention.size(-1) < max_encoder_length:
            left_padding = torch.full(
                (
                    *encoder_attention.shape[:-1],
                    max_encoder_length - encoder_attention.size(-1),
                ),
                float("nan"),
                dtype=encoder_attention.dtype,
                device=encoder_attention.device,
            )
            encoder_attention = torch.concat([left_padding, encoder_attention], dim=-1)

    # combine attention vector; values below 1e-5 are treated as missing
    attention = torch.concat([encoder_attention, decoder_attention], dim=-1)
    attention[attention < 1e-5] = float("nan")

    # histogram of decode and encode lengths
    encoder_length_histogram = integer_histogram(
        out["encoder_lengths"], min=0, max=max_encoder_length
    )
    decoder_length_histogram = integer_histogram(
        out["decoder_lengths"], min=1, max=out["decoder_variables"].size(1)
    )

    # variable selection weights: zero the masked time steps, sum over time and
    # divide by the length (an encoder length of 0 is replaced by 1)
    encoder_variables = out["encoder_variables"].squeeze(-2)
    encode_mask = create_mask(encoder_variables.size(1), out["encoder_lengths"])
    encoder_variables = encoder_variables.masked_fill(
        encode_mask.unsqueeze(-1), 0.0
    ).sum(dim=1)
    encoder_variables /= (
        out["encoder_lengths"]
        .where(out["encoder_lengths"] > 0, torch.ones_like(out["encoder_lengths"]))
        .unsqueeze(-1)
    )

    decoder_variables = out["decoder_variables"].squeeze(-2)
    decode_mask = create_mask(decoder_variables.size(1), out["decoder_lengths"])
    decoder_variables = decoder_variables.masked_fill(
        decode_mask.unsqueeze(-1), 0.0
    ).sum(dim=1)
    decoder_variables /= out["decoder_lengths"].unsqueeze(-1)

    # static variables need no masking
    static_variables = out["static_variables"].squeeze(1)

    # attention is batch x time x heads x time_to_attend
    # average over heads + only keep prediction attention and attention on observed timesteps
    attention = masked_op(
        attention[
            :,
            attention_prediction_horizon,
            :,
            : max_encoder_length + attention_prediction_horizon,
        ],
        op="mean",
        dim=1,
    )

    if reduction != "none":  # if to average over batches
        static_variables = static_variables.sum(dim=0)
        encoder_variables = encoder_variables.sum(dim=0)
        decoder_variables = decoder_variables.sum(dim=0)
        attention = masked_op(attention, dim=0, op=reduction)
    else:
        # renormalize every sample by its own masked sum
        attention = attention / masked_op(attention, dim=1, op="sum").unsqueeze(-1)

    interpretation = dict(
        attention=attention.masked_fill(torch.isnan(attention), 0.0),
        static_variables=static_variables,
        encoder_variables=encoder_variables,
        decoder_variables=decoder_variables,
        encoder_length_histogram=encoder_length_histogram,
        decoder_length_histogram=decoder_length_histogram,
    )
    return interpretation