Example #1
    def get_attention_mask(self, encoder_lengths: torch.LongTensor,
                           decoder_length: int):
        """
        Returns causal mask to apply for self-attention layer.

        Args:
            self_attn_inputs: Inputs to self attention layer to determine mask shape
        """
        # indices that are attended to
        attend_step = torch.arange(decoder_length, device=self.device)
        # indices for which predictions are made
        predict_step = torch.arange(0, decoder_length,
                                    device=self.device)[:, None]
        # do not attend to the step itself or to steps after the prediction
        # todo: there is potential value in attending to future forecasts if they are made with knowledge currently
        #   available
        #   one possibility here is to use a second attention layer for future attention (assuming different effects
        #   matter in the future than in the past)
        #   or alternatively using the same layer but allowing forward attention - i.e. only masking out non-available
        #   data and self
        decoder_mask = attend_step >= predict_step
        # do not attend to steps where data is padded
        encoder_mask = create_mask(encoder_lengths.max(), encoder_lengths)
        # combine masks along attended time - first encoder and then decoder
        mask = torch.cat(
            (
                encoder_mask.unsqueeze(1).expand(-1, decoder_length, -1),
                decoder_mask.unsqueeze(0).expand(encoder_lengths.size(0), -1,
                                                 -1),
            ),
            dim=2,
        )
        return mask
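
The mask construction above (and the length masking in the later examples) relies on a create_mask helper that is not shown in the excerpt. A minimal sketch consistent with how it is used - True marks padded positions, inverse=True flips that to mark valid positions - could look like this (an assumption, not the library's actual implementation):

import torch


def create_mask(size: int, lengths: torch.LongTensor,
                inverse: bool = False) -> torch.BoolTensor:
    """Sketch of the assumed helper: boolean mask of shape (len(lengths), size)."""
    positions = torch.arange(size, device=lengths.device).unsqueeze(0)  # (1, size)
    if inverse:
        # True at valid time steps (position < length)
        return positions < lengths.unsqueeze(-1)
    # True at padded time steps (position >= length), i.e. positions to mask out
    return positions >= lengths.unsqueeze(-1)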
Example #2
    def _convert(self, y_pred: torch.Tensor,
                 target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # unpack target into target and weights
        if isinstance(target, (list, tuple)) and not isinstance(
                target, rnn.PackedSequence):
            target, weight = target
            if weight is not None:
                raise NotImplementedError(
                    "Weighting is not supported for pure torchmetrics - "
                    "implement a custom version or use pytorch-forecasting metrics"
                )

        # convert to point prediction - limits applications of class
        y_pred = self.to_prediction(y_pred)

        # unpack target if it is PackedSequence
        if isinstance(target, rnn.PackedSequence):
            target, lengths = unpack_sequence(target)
            # create mask for different lengths
            length_mask = create_mask(target.size(1), lengths, inverse=True)
            target = target.masked_select(length_mask)
            y_pred = y_pred.masked_select(length_mask)

        y_pred = y_pred.flatten()
        target = target.flatten()
        return y_pred, target
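
Example #2 also assumes an unpack_sequence helper. A plausible sketch, given that the caller expects both the padded tensor and the per-sample lengths back (assumed here, not confirmed by the excerpt):

import torch
from torch.nn.utils import rnn


def unpack_sequence(sequence):
    """Sketch of the assumed helper: unpack a PackedSequence into a padded tensor plus lengths."""
    if isinstance(sequence, rnn.PackedSequence):
        sequence, lengths = rnn.pad_packed_sequence(sequence, batch_first=True)
        # pad_packed_sequence returns lengths on the CPU -> move them next to the data
        lengths = lengths.to(sequence.device)
    else:
        # plain tensor: every sample uses the full time dimension
        lengths = torch.full((sequence.size(0),), sequence.size(1),
                             dtype=torch.long, device=sequence.device)
    return sequence, lengths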
Example #3
    def update(self, y_pred: torch.Tensor,
               y_actual: torch.Tensor) -> None:
        """
        Update the composite metric with a new batch of predictions and actuals.

        Args:
            y_pred: network output
            y_actual: actual values
        """
        # extract target and weight
        if isinstance(y_actual, (tuple, list)) and not isinstance(
                y_actual, rnn.PackedSequence):
            target, weight = y_actual
        else:
            target = y_actual
            weight = None

        # handle rnn sequence as target
        if isinstance(target, rnn.PackedSequence):
            target, lengths = rnn.pad_packed_sequence(target, batch_first=True)
            # lengths returned by pad_packed_sequence reside on the CPU -> move them to the target's device
            lengths = lengths.to(target.device)

            # calculate mask for time steps
            length_mask = create_mask(target.size(1), lengths, inverse=True)

            # modify weight
            if weight is None:
                weight = length_mask
            else:
                weight = weight * length_mask

        if weight is None:
            y_mean = target.mean(0)
            y_pred_mean = y_pred.mean(0)
        else:
            # calculate weighted means
            y_mean = (target *
                      unsqueeze_like(weight, y_pred)).sum(0) / weight.sum(0)

            y_pred_sum = (y_pred * unsqueeze_like(weight, y_pred)).sum(0)
            y_pred_mean = y_pred_sum / unsqueeze_like(weight.sum(0),
                                                      y_pred_sum)

        # update metric. unsqueeze first batch dimension (as batches are collapsed)
        self.metric.update(y_pred_mean.unsqueeze(0), y_mean.unsqueeze(0))
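
The weighted means in Examples #3 and #4 use an unsqueeze_like helper that is likewise not defined in the excerpt. It only needs to append trailing singleton dimensions until the weight broadcasts against the prediction, so a minimal sketch (an assumption) could be:

import torch


def unsqueeze_like(tensor: torch.Tensor, like: torch.Tensor) -> torch.Tensor:
    """Sketch of the assumed helper: add trailing singleton dims so `tensor` broadcasts like `like`."""
    n_unsqueezes = like.ndim - tensor.ndim
    if n_unsqueezes < 0:
        raise ValueError(f"tensor has more dims ({tensor.ndim}) than target ({like.ndim})")
    if n_unsqueezes == 0:
        return tensor
    return tensor[(...,) + (None,) * n_unsqueezes]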
Example #4
    @staticmethod
    def _calculate_mean(
            y_pred: torch.Tensor,
            y_actual: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # extract target and weight
        if isinstance(y_actual, (tuple, list)) and not isinstance(
                y_actual, rnn.PackedSequence):
            target, weight = y_actual
        else:
            target = y_actual
            weight = None

        # handle rnn sequence as target
        if isinstance(target, rnn.PackedSequence):
            target, lengths = rnn.pad_packed_sequence(target, batch_first=True)
            # lengths returned by pad_packed_sequence reside on the CPU -> move them to the target's device
            lengths = lengths.to(target.device)

            # calculate mask for time steps
            length_mask = create_mask(target.size(1), lengths, inverse=True)

            # modify weight
            if weight is None:
                weight = length_mask
            else:
                weight = weight * length_mask

        if weight is None:
            y_mean = target.mean(0)
            y_pred_mean = y_pred.mean(0)
        else:
            # calculate weighted means
            y_mean = (target *
                      unsqueeze_like(weight, y_pred)).sum(0) / weight.sum(0)

            y_pred_sum = (y_pred * unsqueeze_like(weight, y_pred)).sum(0)
            y_pred_mean = y_pred_sum / unsqueeze_like(weight.sum(0),
                                                      y_pred_sum)
        return y_pred_mean.unsqueeze(0), y_mean.unsqueeze(0)
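
For intuition, the core of the weighted branch can be reproduced on a toy batch of 2 samples and 3 time steps, where the second sample is padded at the last step (the tensors below are made up purely for illustration):

import torch

target = torch.tensor([[1.0, 2.0, 3.0],
                       [4.0, 5.0, 0.0]])   # second sample padded at t=2
y_pred = torch.tensor([[1.5, 2.5, 3.5],
                       [3.5, 4.5, 0.0]])
lengths = torch.tensor([3, 2])

# weight is 1 at valid steps and 0 at padded steps - what the length mask produces
weight = (torch.arange(target.size(1)).unsqueeze(0) < lengths.unsqueeze(-1)).float()

# weighted mean over the batch dimension, computed per time step
y_mean = (target * weight).sum(0) / weight.sum(0)
y_pred_mean = (y_pred * weight).sum(0) / weight.sum(0)
print(y_mean)       # tensor([2.5000, 3.5000, 3.0000])
print(y_pred_mean)  # tensor([2.5000, 3.5000, 3.5000])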
Example #5
    def interpret_output(
        self,
        out: Dict[str, torch.Tensor],
        reduction: str = "none",
        attention_prediction_horizon: int = 0,
        attention_as_autocorrelation: bool = False,
    ) -> Dict[str, torch.Tensor]:
        """
        Interpret output of model.

        Args:
            out: output as produced by ``forward()``
            reduction: "none" for no averaging over batches, "sum" for summing attentions, "mean" for
                normalizing by encoder lengths
            attention_prediction_horizon: which prediction horizon to use for attention
            attention_as_autocorrelation: whether to record attention as autocorrelation - this should be set to
                True if ``reduction != "none"`` and samples have differing prediction times. Defaults to False.

        Returns:
            interpretations that can be plotted with ``plot_interpretation()``
        """

        # histogram of decode and encode lengths
        encoder_length_histogram = integer_histogram(
            out["encoder_lengths"], min=0, max=self.hparams.max_encoder_length)
        decoder_length_histogram = integer_histogram(
            out["decoder_lengths"],
            min=1,
            max=out["decoder_variables"].size(1))

        # mask where decoder and encoder were not applied when averaging variable selection weights
        encoder_variables = out["encoder_variables"].squeeze(-2)
        encode_mask = create_mask(encoder_variables.size(1),
                                  out["encoder_lengths"])
        encoder_variables = encoder_variables.masked_fill(
            encode_mask.unsqueeze(-1), 0.0).sum(dim=1)
        encoder_variables /= (out["encoder_lengths"].where(
            out["encoder_lengths"] > 0,
            torch.ones_like(out["encoder_lengths"])).unsqueeze(-1))

        decoder_variables = out["decoder_variables"].squeeze(-2)
        decode_mask = create_mask(decoder_variables.size(1),
                                  out["decoder_lengths"])
        decoder_variables = decoder_variables.masked_fill(
            decode_mask.unsqueeze(-1), 0.0).sum(dim=1)
        decoder_variables /= out["decoder_lengths"].unsqueeze(-1)

        # static variables need no masking
        static_variables = out["static_variables"].squeeze(1)
        # attention is batch x time x heads x time_to_attend
        # average over heads + only keep prediction attention and attention on observed timesteps
        attention = out["attention"][:, attention_prediction_horizon, :, :
                                     out["encoder_lengths"].max() +
                                     attention_prediction_horizon].mean(1)

        if reduction != "none":  # average over batches
            static_variables = static_variables.sum(dim=0)
            encoder_variables = encoder_variables.sum(dim=0)
            decoder_variables = decoder_variables.sum(dim=0)

            # reorder attention for averaging
            for i in range(
                    len(attention)):  # very inefficient but does the trick
                if 0 < out["encoder_lengths"][i] < attention.size(
                        1) - attention_prediction_horizon - 1:
                    relevant_attention = attention[
                        i, :out["encoder_lengths"][i] +
                        attention_prediction_horizon].clone()
                    if attention_as_autocorrelation:
                        relevant_attention = autocorrelation(
                            relevant_attention)
                    attention[
                        i, -out["encoder_lengths"][i] -
                        attention_prediction_horizon:] = relevant_attention
                    attention[i, :attention.size(1) -
                              out["encoder_lengths"][i] -
                              attention_prediction_horizon] = 0.0
                elif attention_as_autocorrelation:
                    attention[i] = autocorrelation(attention[i])

            attention = attention.sum(dim=0)
            if reduction == "mean":
                attention = attention / encoder_length_histogram[1:].flip(
                    0).cumsum(0).clamp(1)
                attention = attention / attention.sum(-1).unsqueeze(
                    -1)  # renormalize
            elif reduction == "sum":
                pass
            else:
                raise ValueError(f"Unknown reduction {reduction}")

            attention = torch.zeros(
                self.hparams.max_encoder_length + attention_prediction_horizon,
                device=self.device).scatter(
                    dim=0,
                    index=torch.arange(
                        self.hparams.max_encoder_length +
                        attention_prediction_horizon - attention.size(-1),
                        self.hparams.max_encoder_length +
                        attention_prediction_horizon,
                        device=self.device,
                    ),
                    src=attention,
                )
        else:
            attention = attention / attention.sum(-1).unsqueeze(
                -1)  # renormalize

        interpretation = dict(
            attention=attention,
            static_variables=static_variables,
            encoder_variables=encoder_variables,
            decoder_variables=decoder_variables,
            encoder_length_histogram=encoder_length_histogram,
            decoder_length_histogram=decoder_length_histogram,
        )
        return interpretation
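
The integer_histogram helper used for the encoder/decoder length histograms is also assumed from the same utilities. A sketch that matches its use here - counts of each integer value between min and max, inclusive (parameter names mirror the keyword arguments at the call sites) - could be:

import torch


def integer_histogram(data: torch.LongTensor, min: int, max: int) -> torch.Tensor:
    """Sketch of the assumed helper: counts of each integer value in [min, max]."""
    uniques, counts = torch.unique(data, return_counts=True)
    hist = torch.zeros(max - min + 1, dtype=torch.long, device=data.device)
    # place each count at its value's offset from `min`
    hist.scatter_(dim=0, index=uniques - min, src=counts)
    return hist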
Example #6
    def interpret_output(
        self,
        out: Dict[str, torch.Tensor],
        reduction: str = "none",
        attention_prediction_horizon: int = 0,
    ) -> Dict[str, torch.Tensor]:
        """
        Interpret output of model.

        Args:
            out: output as produced by ``forward()``
            reduction: "none" for no averaging over batches, "sum" for summing attentions, "mean" for
                normalizing by encoder lengths
            attention_prediction_horizon: which prediction horizon to use for attention

        Returns:
            interpretations that can be plotted with ``plot_interpretation()``
        """
        # take attention and, if it comes as a list, concatenate it into a proper attention tensor
        batch_size = len(out["decoder_attention"])
        if isinstance(out["decoder_attention"], (list, tuple)):
            # start with decoder attention
            # the list entries may differ in their last dimension, so find the maximum length
            max_last_dimension = max(
                x.size(-1) for x in out["decoder_attention"])
            first_elm = out["decoder_attention"][0]
            # create new attention tensor into which we will scatter
            decoder_attention = torch.full(
                (batch_size, *first_elm.shape[:-1], max_last_dimension),
                float("nan"),
                dtype=first_elm.dtype,
                device=first_elm.device,
            )
            # scatter into tensor
            for idx, x in enumerate(out["decoder_attention"]):
                decoder_length = out["decoder_lengths"][idx]
                decoder_attention[idx, :, :, :decoder_length] = x[
                    ..., :decoder_length]
        else:
            decoder_attention = out["decoder_attention"]
            decoder_mask = create_mask(out["decoder_attention"].size(1),
                                       out["decoder_lengths"])
            decoder_attention[decoder_mask[..., None, None].expand_as(
                decoder_attention)] = float("nan")

        if isinstance(out["encoder_attention"], (list, tuple)):
            # same game for encoder attention
            # create new attention tensor into which we will scatter
            first_elm = out["encoder_attention"][0]
            encoder_attention = torch.full(
                (batch_size, *first_elm.shape[:-1],
                 self.hparams.max_encoder_length),
                float("nan"),
                dtype=first_elm.dtype,
                device=first_elm.device,
            )
            # scatter into tensor
            for idx, x in enumerate(out["encoder_attention"]):
                encoder_length = out["encoder_lengths"][idx]
                encoder_attention[idx, :, :, self.hparams.max_encoder_length -
                                  encoder_length:] = x[..., :encoder_length]
        else:
            # roll encoder attention (so that the last encoder value is on the right)
            encoder_attention = out["encoder_attention"]
            shifts = encoder_attention.size(3) - out["encoder_lengths"]
            new_index = (
                torch.arange(encoder_attention.size(3),
                             device=encoder_attention.device)
                [None, None, None].expand_as(encoder_attention) -
                shifts[:, None, None, None]) % encoder_attention.size(3)
            encoder_attention = torch.gather(encoder_attention,
                                             dim=3,
                                             index=new_index)
            # expand encoder_attention to full size
            if encoder_attention.size(-1) < self.hparams.max_encoder_length:
                encoder_attention = torch.concat(
                    [
                        torch.full(
                            (
                                *encoder_attention.shape[:-1],
                                self.hparams.max_encoder_length -
                                out["encoder_lengths"].max(),
                            ),
                            float("nan"),
                            dtype=encoder_attention.dtype,
                            device=encoder_attention.device,
                        ),
                        encoder_attention,
                    ],
                    dim=-1,
                )

        # combine attention vector
        attention = torch.concat([encoder_attention, decoder_attention],
                                 dim=-1)
        attention[attention < 1e-5] = float("nan")

        # histogram of decode and encode lengths
        encoder_length_histogram = integer_histogram(
            out["encoder_lengths"], min=0, max=self.hparams.max_encoder_length)
        decoder_length_histogram = integer_histogram(
            out["decoder_lengths"],
            min=1,
            max=out["decoder_variables"].size(1))

        # mask where decoder and encoder were not applied when averaging variable selection weights
        encoder_variables = out["encoder_variables"].squeeze(-2)
        encode_mask = create_mask(encoder_variables.size(1),
                                  out["encoder_lengths"])
        encoder_variables = encoder_variables.masked_fill(
            encode_mask.unsqueeze(-1), 0.0).sum(dim=1)
        encoder_variables /= (out["encoder_lengths"].where(
            out["encoder_lengths"] > 0,
            torch.ones_like(out["encoder_lengths"])).unsqueeze(-1))

        decoder_variables = out["decoder_variables"].squeeze(-2)
        decode_mask = create_mask(decoder_variables.size(1),
                                  out["decoder_lengths"])
        decoder_variables = decoder_variables.masked_fill(
            decode_mask.unsqueeze(-1), 0.0).sum(dim=1)
        decoder_variables /= out["decoder_lengths"].unsqueeze(-1)

        # static variables need no masking
        static_variables = out["static_variables"].squeeze(1)
        # attention is batch x time x heads x time_to_attend
        # average over heads + only keep prediction attention and attention on observed timesteps
        attention = masked_op(
            attention[:, attention_prediction_horizon, :, :self.hparams.
                      max_encoder_length + attention_prediction_horizon],
            op="mean",
            dim=1,
        )

        if reduction != "none":  # average over batches
            static_variables = static_variables.sum(dim=0)
            encoder_variables = encoder_variables.sum(dim=0)
            decoder_variables = decoder_variables.sum(dim=0)

            attention = masked_op(attention, dim=0, op=reduction)
        else:
            attention = attention / masked_op(
                attention, dim=1, op="sum").unsqueeze(-1)  # renormalize

        interpretation = dict(
            attention=attention.masked_fill(torch.isnan(attention), 0.0),
            static_variables=static_variables,
            encoder_variables=encoder_variables,
            decoder_variables=decoder_variables,
            encoder_length_histogram=encoder_length_histogram,
            decoder_length_histogram=decoder_length_histogram,
        )
        return interpretation
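
Relative to Example #5, the main new helper is masked_op, which aggregates while treating the NaN entries that mark unobserved positions as missing. A minimal sketch consistent with the two ops used above ("mean" and "sum"), again an assumption rather than the library's actual code:

import torch


def masked_op(tensor: torch.Tensor, op: str = "mean", dim: int = 0) -> torch.Tensor:
    """Sketch of the assumed helper: sum or mean along `dim`, ignoring NaN entries."""
    mask = ~torch.isnan(tensor)
    summed = tensor.masked_fill(~mask, 0.0).sum(dim=dim)
    if op == "sum":
        return summed
    if op == "mean":
        # avoid division by zero where every entry along `dim` is NaN
        return summed / mask.sum(dim=dim).clamp(min=1)
    raise ValueError(f"unknown op {op}")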