def get_attention_mask(self, encoder_lengths: torch.LongTensor, decoder_length: int):
    """
    Returns causal mask to apply for self-attention layer.

    Args:
        encoder_lengths: lengths of the encoder sequences in the batch
        decoder_length: length of the decoder (prediction) sequence
    """
    # indices to which is attended
    attend_step = torch.arange(decoder_length, device=self.device)
    # indices for which is predicted
    predict_step = torch.arange(0, decoder_length, device=self.device)[:, None]
    # do not attend to self or to steps after the prediction step
    # todo: there is potential value in attending to future forecasts if they are made with knowledge currently
    #   available. One possibility is to use a second attention layer for future attention
    #   (assuming different effects matter in the future than in the past), or alternatively to use the same
    #   layer but allow forward attention, i.e. only mask out non-available data and self
    decoder_mask = attend_step >= predict_step
    # do not attend to steps where data is padded
    encoder_mask = create_mask(encoder_lengths.max(), encoder_lengths)
    # combine masks along attended time - first encoder, then decoder
    mask = torch.cat(
        (
            encoder_mask.unsqueeze(1).expand(-1, decoder_length, -1),
            decoder_mask.unsqueeze(0).expand(encoder_lengths.size(0), -1, -1),
        ),
        dim=2,
    )
    return mask
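# A minimal, self-contained sketch of how the combined mask comes together.
# ``create_mask_sketch`` below is a hypothetical stand-in for the library's
# ``create_mask`` (assumed semantics: True marks padded, non-attendable steps).
import torch

def create_mask_sketch(size: int, lengths: torch.LongTensor) -> torch.BoolTensor:
    # True where the time index is at or beyond the sequence length
    return torch.arange(size)[None, :] >= lengths[:, None]

encoder_lengths = torch.tensor([3, 5])  # two samples, padded to length 5
decoder_length = 2

attend_step = torch.arange(decoder_length)
predict_step = torch.arange(decoder_length)[:, None]
decoder_mask = attend_step >= predict_step  # masks self and future steps
encoder_mask = create_mask_sketch(int(encoder_lengths.max()), encoder_lengths)

mask = torch.cat(
    (
        encoder_mask.unsqueeze(1).expand(-1, decoder_length, -1),
        decoder_mask.unsqueeze(0).expand(encoder_lengths.size(0), -1, -1),
    ),
    dim=2,
)
print(mask.shape)  # torch.Size([2, 2, 7]): batch x decoder steps x (encoder + decoder) steps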
def _convert(self, y_pred: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    # unpack target into target and weights
    if isinstance(target, (list, tuple)) and not isinstance(target, rnn.PackedSequence):
        target, weight = target
        if weight is not None:
            raise NotImplementedError(
                "Weighting is not supported for pure torchmetrics - "
                "implement a custom version or use pytorch-forecasting metrics"
            )

    # convert to point prediction - limits applications of class
    y_pred = self.to_prediction(y_pred)

    # unpack target if it is a PackedSequence
    if isinstance(target, rnn.PackedSequence):
        target, lengths = unpack_sequence(target)
        # create mask for different lengths
        length_mask = create_mask(target.size(1), lengths, inverse=True)
        target = target.masked_select(length_mask)
        y_pred = y_pred.masked_select(length_mask)

    y_pred = y_pred.flatten()
    target = target.flatten()
    return y_pred, target
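# A self-contained sketch of the PackedSequence handling above: pad the packed
# target, build an inverse length mask (True = valid step), and select only
# valid positions before flattening. The inline mask stands in for
# ``create_mask(..., inverse=True)`` under assumed semantics.
import torch
from torch.nn.utils import rnn

sequences = [torch.tensor([1.0, 2.0, 3.0]), torch.tensor([4.0, 5.0])]
packed = rnn.pack_sequence(sequences)  # sequences sorted by decreasing length

target, lengths = rnn.pad_packed_sequence(packed, batch_first=True)  # 2 x 3
length_mask = torch.arange(target.size(1))[None, :] < lengths[:, None]

y_pred = torch.zeros_like(target)
y_pred = y_pred.masked_select(length_mask).flatten()
target = target.masked_select(length_mask).flatten()
print(target)  # tensor([1., 2., 3., 4., 5.]) - padding removed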
def update(self, y_pred: torch.Tensor, y_actual: torch.Tensor) -> None:
    """
    Update the wrapped metric with means calculated over the batch dimension.

    Args:
        y_pred: network output
        y_actual: actual values
    """
    # extract target and weight
    if isinstance(y_actual, (tuple, list)) and not isinstance(y_actual, rnn.PackedSequence):
        target, weight = y_actual
    else:
        target = y_actual
        weight = None

    # handle rnn sequence as target
    if isinstance(target, rnn.PackedSequence):
        target, lengths = rnn.pad_packed_sequence(target, batch_first=True)
        # batch sizes reside on the CPU by default -> we need to bring them to GPU
        lengths = lengths.to(target.device)

        # calculate mask for time steps
        length_mask = create_mask(target.size(1), lengths, inverse=True)

        # modify weight
        if weight is None:
            weight = length_mask
        else:
            weight = weight * length_mask

    if weight is None:
        y_mean = target.mean(0)
        y_pred_mean = y_pred.mean(0)
    else:
        # calculate weighted means
        y_mean = (target * unsqueeze_like(weight, y_pred)).sum(0) / weight.sum(0)
        y_pred_sum = (y_pred * unsqueeze_like(weight, y_pred)).sum(0)
        y_pred_mean = y_pred_sum / unsqueeze_like(weight.sum(0), y_pred_sum)

    # update metric: unsqueeze first batch dimension (as batches are collapsed)
    self.metric.update(y_pred_mean.unsqueeze(0), y_mean.unsqueeze(0))
def _calculate_mean(y_pred: torch.Tensor, y_actual: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    # extract target and weight
    if isinstance(y_actual, (tuple, list)) and not isinstance(y_actual, rnn.PackedSequence):
        target, weight = y_actual
    else:
        target = y_actual
        weight = None

    # handle rnn sequence as target
    if isinstance(target, rnn.PackedSequence):
        target, lengths = rnn.pad_packed_sequence(target, batch_first=True)
        # batch sizes reside on the CPU by default -> we need to bring them to GPU
        lengths = lengths.to(target.device)

        # calculate mask for time steps
        length_mask = create_mask(target.size(1), lengths, inverse=True)

        # modify weight
        if weight is None:
            weight = length_mask
        else:
            weight = weight * length_mask

    if weight is None:
        y_mean = target.mean(0)
        y_pred_mean = y_pred.mean(0)
    else:
        # calculate weighted means
        y_mean = (target * unsqueeze_like(weight, y_pred)).sum(0) / weight.sum(0)
        y_pred_sum = (y_pred * unsqueeze_like(weight, y_pred)).sum(0)
        y_pred_mean = y_pred_sum / unsqueeze_like(weight.sum(0), y_pred_sum)

    return y_pred_mean.unsqueeze(0), y_mean.unsqueeze(0)
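# A toy example of the weighted-mean logic shared by ``update`` and
# ``_calculate_mean``: padded steps receive weight 0 and thus do not drag the
# per-timestep means towards zero. Plain broadcasting here stands in for
# ``unsqueeze_like``, which is an assumption about its behavior.
import torch

target = torch.tensor([[1.0, 2.0, 0.0],   # last step of first sample is padding
                       [3.0, 4.0, 5.0]])
weight = torch.tensor([[1.0, 1.0, 0.0],
                       [1.0, 1.0, 1.0]])

y_mean = (target * weight).sum(0) / weight.sum(0)
print(y_mean)  # tensor([2., 3., 5.]) - the padded entry is ignored, not averaged in as 0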
def interpret_output(
    self,
    out: Dict[str, torch.Tensor],
    reduction: str = "none",
    attention_prediction_horizon: int = 0,
    attention_as_autocorrelation: bool = False,
) -> Dict[str, torch.Tensor]:
    """
    Interpret output of model.

    Args:
        out: output as produced by ``forward()``
        reduction: "none" for no averaging over batches, "sum" for summing attentions,
            "mean" for normalizing by encoder lengths
        attention_prediction_horizon: which prediction horizon to use for attention
        attention_as_autocorrelation: whether to record attention as autocorrelation -
            this should be set to True in case of ``reduction != "none"`` and differing
            prediction times of the samples. Defaults to False.

    Returns:
        interpretations that can be plotted with ``plot_interpretation()``
    """
    # histogram of decode and encode lengths
    encoder_length_histogram = integer_histogram(
        out["encoder_lengths"], min=0, max=self.hparams.max_encoder_length
    )
    decoder_length_histogram = integer_histogram(
        out["decoder_lengths"], min=1, max=out["decoder_variables"].size(1)
    )

    # mask where decoder and encoder were not applied when averaging variable selection weights
    encoder_variables = out["encoder_variables"].squeeze(-2)
    encode_mask = create_mask(encoder_variables.size(1), out["encoder_lengths"])
    encoder_variables = encoder_variables.masked_fill(encode_mask.unsqueeze(-1), 0.0).sum(dim=1)
    encoder_variables /= (
        out["encoder_lengths"]
        .where(out["encoder_lengths"] > 0, torch.ones_like(out["encoder_lengths"]))
        .unsqueeze(-1)
    )

    decoder_variables = out["decoder_variables"].squeeze(-2)
    decode_mask = create_mask(decoder_variables.size(1), out["decoder_lengths"])
    decoder_variables = decoder_variables.masked_fill(decode_mask.unsqueeze(-1), 0.0).sum(dim=1)
    decoder_variables /= out["decoder_lengths"].unsqueeze(-1)

    # static variables need no masking
    static_variables = out["static_variables"].squeeze(1)
    # attention is batch x time x heads x time_to_attend
    # average over heads + only keep prediction attention and attention on observed timesteps
    attention = out["attention"][
        :, attention_prediction_horizon, :, : out["encoder_lengths"].max() + attention_prediction_horizon
    ].mean(1)

    if reduction != "none":  # whether to average over batches
        static_variables = static_variables.sum(dim=0)
        encoder_variables = encoder_variables.sum(dim=0)
        decoder_variables = decoder_variables.sum(dim=0)

        # reorder attention for averaging
        for i in range(len(attention)):  # very inefficient but does the trick
            if 0 < out["encoder_lengths"][i] < attention.size(1) - attention_prediction_horizon - 1:
                relevant_attention = attention[
                    i, : out["encoder_lengths"][i] + attention_prediction_horizon
                ].clone()
                if attention_as_autocorrelation:
                    relevant_attention = autocorrelation(relevant_attention)
                attention[i, -out["encoder_lengths"][i] - attention_prediction_horizon :] = relevant_attention
                attention[i, : attention.size(1) - out["encoder_lengths"][i] - attention_prediction_horizon] = 0.0
            elif attention_as_autocorrelation:
                attention[i] = autocorrelation(attention[i])

        attention = attention.sum(dim=0)
        if reduction == "mean":
            attention = attention / encoder_length_histogram[1:].flip(0).cumsum(0).clamp(1)
            attention = attention / attention.sum(-1).unsqueeze(-1)  # renormalize
        elif reduction == "sum":
            pass
        else:
            raise ValueError(f"Unknown reduction {reduction}")

        attention = torch.zeros(
            self.hparams.max_encoder_length + attention_prediction_horizon, device=self.device
        ).scatter(
            dim=0,
            index=torch.arange(
                self.hparams.max_encoder_length + attention_prediction_horizon - attention.size(-1),
                self.hparams.max_encoder_length + attention_prediction_horizon,
                device=self.device,
            ),
            src=attention,
        )
    else:
        attention = attention / attention.sum(-1).unsqueeze(-1)  # renormalize

    interpretation = dict(
        attention=attention,
        static_variables=static_variables,
        encoder_variables=encoder_variables,
        decoder_variables=decoder_variables,
        encoder_length_histogram=encoder_length_histogram,
        decoder_length_histogram=decoder_length_histogram,
    )
    return interpretation
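# A hedged usage sketch: ``model`` (a fitted TemporalFusionTransformer) and
# ``val_dataloader`` are assumed to exist, and the exact return type of
# ``predict`` varies between library versions.
raw_output = model.predict(val_dataloader, mode="raw")
interpretation = model.interpret_output(
    raw_output,
    reduction="sum",  # aggregate attention over the batch
    attention_prediction_horizon=0,  # inspect attention for the first forecast step
)
model.plot_interpretation(interpretation)  # variable importances + attention curve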
def interpret_output(
    self,
    out: Dict[str, torch.Tensor],
    reduction: str = "none",
    attention_prediction_horizon: int = 0,
) -> Dict[str, torch.Tensor]:
    """
    Interpret output of model.

    Args:
        out: output as produced by ``forward()``
        reduction: "none" for no averaging over batches, "sum" for summing attentions,
            "mean" for normalizing by encoder lengths
        attention_prediction_horizon: which prediction horizon to use for attention

    Returns:
        interpretations that can be plotted with ``plot_interpretation()``
    """
    # concatenate attention into a proper tensor if it comes as a list
    batch_size = len(out["decoder_attention"])
    if isinstance(out["decoder_attention"], (list, tuple)):
        # start with decoder attention
        # sizes can differ in the last dimension - find the max
        max_last_dimension = max(x.size(-1) for x in out["decoder_attention"])
        first_elm = out["decoder_attention"][0]
        # create new attention tensor into which we will scatter
        decoder_attention = torch.full(
            (batch_size, *first_elm.shape[:-1], max_last_dimension),
            float("nan"),
            dtype=first_elm.dtype,
            device=first_elm.device,
        )
        # scatter into tensor
        for idx, x in enumerate(out["decoder_attention"]):
            decoder_length = out["decoder_lengths"][idx]
            decoder_attention[idx, :, :, :decoder_length] = x[..., :decoder_length]
    else:
        decoder_attention = out["decoder_attention"]
        decoder_mask = create_mask(out["decoder_attention"].size(1), out["decoder_lengths"])
        decoder_attention[decoder_mask[..., None, None].expand_as(decoder_attention)] = float("nan")

    if isinstance(out["encoder_attention"], (list, tuple)):
        # same game for encoder attention
        # create new attention tensor into which we will scatter
        first_elm = out["encoder_attention"][0]
        encoder_attention = torch.full(
            (batch_size, *first_elm.shape[:-1], self.hparams.max_encoder_length),
            float("nan"),
            dtype=first_elm.dtype,
            device=first_elm.device,
        )
        # scatter into tensor
        for idx, x in enumerate(out["encoder_attention"]):
            encoder_length = out["encoder_lengths"][idx]
            encoder_attention[idx, :, :, self.hparams.max_encoder_length - encoder_length :] = x[
                ..., :encoder_length
            ]
    else:
        # roll encoder attention (so that the last encoder value is on the right)
        encoder_attention = out["encoder_attention"]
        shifts = encoder_attention.size(3) - out["encoder_lengths"]
        new_index = (
            torch.arange(encoder_attention.size(3), device=encoder_attention.device)[
                None, None, None
            ].expand_as(encoder_attention)
            - shifts[:, None, None, None]
        ) % encoder_attention.size(3)
        encoder_attention = torch.gather(encoder_attention, dim=3, index=new_index)
        # expand encoder_attention to full size
        if encoder_attention.size(-1) < self.hparams.max_encoder_length:
            encoder_attention = torch.concat(
                [
                    torch.full(
                        (
                            *encoder_attention.shape[:-1],
                            self.hparams.max_encoder_length - out["encoder_lengths"].max(),
                        ),
                        float("nan"),
                        dtype=encoder_attention.dtype,
                        device=encoder_attention.device,
                    ),
                    encoder_attention,
                ],
                dim=-1,
            )

    # combine attention vector
    attention = torch.concat([encoder_attention, decoder_attention], dim=-1)
    attention[attention < 1e-5] = float("nan")

    # histogram of decode and encode lengths
    encoder_length_histogram = integer_histogram(
        out["encoder_lengths"], min=0, max=self.hparams.max_encoder_length
    )
    decoder_length_histogram = integer_histogram(
        out["decoder_lengths"], min=1, max=out["decoder_variables"].size(1)
    )

    # mask where decoder and encoder were not applied when averaging variable selection weights
    encoder_variables = out["encoder_variables"].squeeze(-2)
    encode_mask = create_mask(encoder_variables.size(1), out["encoder_lengths"])
    encoder_variables = encoder_variables.masked_fill(encode_mask.unsqueeze(-1), 0.0).sum(dim=1)
    encoder_variables /= (
        out["encoder_lengths"]
        .where(out["encoder_lengths"] > 0, torch.ones_like(out["encoder_lengths"]))
        .unsqueeze(-1)
    )

    decoder_variables = out["decoder_variables"].squeeze(-2)
    decode_mask = create_mask(decoder_variables.size(1), out["decoder_lengths"])
    decoder_variables = decoder_variables.masked_fill(decode_mask.unsqueeze(-1), 0.0).sum(dim=1)
    decoder_variables /= out["decoder_lengths"].unsqueeze(-1)

    # static variables need no masking
    static_variables = out["static_variables"].squeeze(1)
    # attention is batch x time x heads x time_to_attend
    # average over heads + only keep prediction attention and attention on observed timesteps
    attention = masked_op(
        attention[
            :, attention_prediction_horizon, :, : self.hparams.max_encoder_length + attention_prediction_horizon
        ],
        op="mean",
        dim=1,
    )

    if reduction != "none":  # whether to average over batches
        static_variables = static_variables.sum(dim=0)
        encoder_variables = encoder_variables.sum(dim=0)
        decoder_variables = decoder_variables.sum(dim=0)

        attention = masked_op(attention, dim=0, op=reduction)
    else:
        attention = attention / masked_op(attention, dim=1, op="sum").unsqueeze(-1)  # renormalize

    interpretation = dict(
        attention=attention.masked_fill(torch.isnan(attention), 0.0),
        static_variables=static_variables,
        encoder_variables=encoder_variables,
        decoder_variables=decoder_variables,
        encoder_length_histogram=encoder_length_histogram,
        decoder_length_histogram=decoder_length_histogram,
    )
    return interpretation
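# This variant relies on ``masked_op`` from the library's utils, which reduces
# while ignoring NaN entries. Below is a rough standalone equivalent of
# ``op="mean"``, shown only to illustrate why attention is filled with NaN
# rather than zero before averaging:
import torch

def masked_mean_sketch(tensor: torch.Tensor, dim: int) -> torch.Tensor:
    valid = ~torch.isnan(tensor)
    filled = tensor.masked_fill(~valid, 0.0)
    return filled.sum(dim) / valid.sum(dim).clamp(min=1)

attention = torch.tensor([[0.2, float("nan")], [0.4, 0.6]])
print(masked_mean_sketch(attention, dim=1))  # tensor([0.2000, 0.5000])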