Example #1
    def forward(self, y_pred: torch.Tensor,
                y_actual: torch.Tensor) -> torch.Tensor:
        """
        Calculate composite metric

        Args:
            y_pred: network output
            y_actual: actual values

        Returns:
            torch.Tensor: metric value on which backpropagation can be applied
        """
        y_pred_mean = y_pred.mean(0).unsqueeze(0)
        if isinstance(y_actual, rnn.PackedSequence):
            target, lengths = rnn.pad_packed_sequence(y_actual,
                                                      batch_first=True)
            # lengths reside on the CPU by default -> move them to the target's device
            lengths = lengths.to(target.device)

            # calculate mean for all time steps
            tmask = torch.arange(
                target.size(1),
                device=target.device).unsqueeze(0) >= lengths.unsqueeze(-1)
            if target.ndim > 2:
                tmask = tmask.unsqueeze(-1)
                lengths = lengths.unsqueeze(-1)
            target = target.masked_fill(tmask, 0.0)
            y_mean = target.sum(0).unsqueeze(0) / lengths.sum()

            # calculate weight: number of sequences still active at each time step
            # (reverse cumulative histogram of the lengths)
            decoder_length_histogram = integer_histogram(
                lengths, min=1, max=target.size(1))
            weight = decoder_length_histogram.flip(0).cumsum(0).flip(0).float().unsqueeze(0)

            # incorporate the weight: scale the existing second channel or stack it as a new last dimension
            if y_mean.ndim == 3:
                y_mean[..., 1] = y_mean[..., 1] * weight
            else:
                y_mean = torch.stack((y_mean, weight), dim=-1)

        else:
            y_mean = y_actual.mean(0).unsqueeze(0)
        out = self.metric(y_pred_mean, y_mean)
        return out
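The heart of Example #1 is the length-masking pattern for padded sequences: a boolean mask built from torch.arange(T) >= lengths zeroes out padded time steps, and the reverse cumulative histogram of the lengths counts how many sequences are still active at each step. Below is a minimal, self-contained sketch of that pattern on toy data; the local integer_histogram is a hypothetical stand-in for the helper of the same name from pytorch_forecasting.utils, not its actual source.

import torch


def integer_histogram(data: torch.Tensor, min: int, max: int) -> torch.Tensor:
    # stand-in (assumption): count how often each integer in [min, max] occurs in `data`
    uniques, counts = torch.unique(data, return_counts=True)
    hist = torch.zeros(max - min + 1, dtype=counts.dtype, device=data.device)
    hist.scatter_(0, (uniques - min).long(), counts)
    return hist


# toy batch: 3 padded sequences with a maximum length of 4
target = torch.arange(12, dtype=torch.float32).reshape(3, 4)
lengths = torch.tensor([4, 2, 3])

# mask padded positions exactly as in Example #1
tmask = torch.arange(target.size(1)).unsqueeze(0) >= lengths.unsqueeze(-1)
target = target.masked_fill(tmask, 0.0)

# reverse cumulative histogram = number of sequences still active at each step
hist = integer_histogram(lengths, min=1, max=target.size(1))
weight = hist.flip(0).cumsum(0).flip(0).float()
print(weight)  # tensor([3., 3., 2., 1.])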
Example #2
    def interpret_output(
        self,
        out: Dict[str, torch.Tensor],
        reduction: str = "none",
        attention_prediction_horizon: int = 0,
        attention_as_autocorrelation: bool = False,
    ) -> Dict[str, torch.Tensor]:
        """
        Interpret output of the model.

        Args:
            out: output as produced by ``forward()``
            reduction: "none" for no averaging over batches, "sum" for summing attentions, "mean" for
                normalizing by encoder lengths
            attention_prediction_horizon: which prediction horizon to use for attention
            attention_as_autocorrelation: whether to record attention as autocorrelation - this should be set to
                True if ``reduction != "none"`` and the samples have differing prediction times. Defaults to False.

        Returns:
            interpretations that can be plotted with ``plot_interpretation()``
        """

        # histogram of decoder and encoder lengths
        encoder_length_histogram = integer_histogram(
            out["encoder_lengths"], min=0, max=self.hparams.max_encoder_length)
        decoder_length_histogram = integer_histogram(
            out["decoder_lengths"],
            min=1,
            max=out["decoder_variables"].size(1))

        # mask positions where decoder and encoder were not applied when averaging variable selection weights
        encoder_variables = out["encoder_variables"].squeeze(-2)
        encode_mask = create_mask(encoder_variables.size(1),
                                  out["encoder_lengths"])
        encoder_variables = encoder_variables.masked_fill(
            encode_mask.unsqueeze(-1), 0.0).sum(dim=1)
        encoder_variables /= (out["encoder_lengths"].where(
            out["encoder_lengths"] > 0,
            torch.ones_like(out["encoder_lengths"])).unsqueeze(-1))

        decoder_variables = out["decoder_variables"].squeeze(-2)
        decode_mask = create_mask(decoder_variables.size(1),
                                  out["decoder_lengths"])
        decoder_variables = decoder_variables.masked_fill(
            decode_mask.unsqueeze(-1), 0.0).sum(dim=1)
        decoder_variables /= out["decoder_lengths"].unsqueeze(-1)

        # static variables need no masking
        static_variables = out["static_variables"].squeeze(1)
        # attention is batch x time x heads x time_to_attend
        # average over heads + only keep prediction attention and attention on observed timesteps
        attention = out["attention"][:, attention_prediction_horizon, :, :
                                     out["encoder_lengths"].max() +
                                     attention_prediction_horizon].mean(1)

        if reduction != "none":  # whether to average over batches
            static_variables = static_variables.sum(dim=0)
            encoder_variables = encoder_variables.sum(dim=0)
            decoder_variables = decoder_variables.sum(dim=0)

            # reorder attention for averaging
            for i in range(len(attention)):  # very inefficient but does the trick
                if 0 < out["encoder_lengths"][i] < attention.size(1) - attention_prediction_horizon - 1:
                    relevant_attention = attention[
                        i, : out["encoder_lengths"][i] + attention_prediction_horizon
                    ].clone()
                    if attention_as_autocorrelation:
                        relevant_attention = autocorrelation(relevant_attention)
                    attention[i, -out["encoder_lengths"][i] - attention_prediction_horizon:] = relevant_attention
                    attention[i, : attention.size(1) - out["encoder_lengths"][i] - attention_prediction_horizon] = 0.0
                elif attention_as_autocorrelation:
                    attention[i] = autocorrelation(attention[i])

            attention = attention.sum(dim=0)
            if reduction == "mean":
                attention = attention / encoder_length_histogram[1:].flip(0).cumsum(0).clamp(1)
                attention = attention / attention.sum(-1).unsqueeze(-1)  # renormalize
            elif reduction == "sum":
                pass
            else:
                raise ValueError(f"Unknown reduction {reduction}")

            attention = torch.zeros(
                self.hparams.max_encoder_length + attention_prediction_horizon,
                device=self.device).scatter(
                    dim=0,
                    index=torch.arange(
                        self.hparams.max_encoder_length +
                        attention_prediction_horizon - attention.size(-1),
                        self.hparams.max_encoder_length +
                        attention_prediction_horizon,
                        device=self.device,
                    ),
                    src=attention,
                )
        else:
            attention = attention / attention.sum(-1).unsqueeze(-1)  # renormalize

        interpretation = dict(
            attention=attention,
            static_variables=static_variables,
            encoder_variables=encoder_variables,
            decoder_variables=decoder_variables,
            encoder_length_histogram=encoder_length_histogram,
            decoder_length_histogram=decoder_length_histogram,
        )
        return interpretation
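For context, interpret_output() is normally fed the raw forward output of a trained pytorch-forecasting TemporalFusionTransformer and its result handed to plot_interpretation(). The sketch below shows that flow under explicit assumptions: a trained model tft and a dataloader val_dataloader already exist, and the exact return type of predict() (plain tuple vs. named result object) differs between library versions.

# usage sketch - assumes a trained TemporalFusionTransformer `tft` and a
# dataloader `val_dataloader`; predict()'s return type varies across
# pytorch-forecasting versions (newer ones return a named result object)
raw_predictions = tft.predict(val_dataloader, mode="raw", return_x=True)

# sum attention and variable-selection weights over the batch
interpretation = tft.interpret_output(raw_predictions.output, reduction="sum")

# plot attention and variable importances from the dictionary built above
tft.plot_interpretation(interpretation)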
Example #3
    def interpret_output(
        self,
        out: Dict[str, torch.Tensor],
        reduction: str = "none",
        attention_prediction_horizon: int = 0,
    ) -> Dict[str, torch.Tensor]:
        """
        Interpret output of the model.

        Args:
            out: output as produced by ``forward()``
            reduction: "none" for no averaging over batches, "sum" for summing attentions, "mean" for
                normalizing by encoder lengths
            attention_prediction_horizon: which prediction horizon to use for attention

        Returns:
            interpretations that can be plotted with ``plot_interpretation()``
        """
        # take attention and, if it is a list, concatenate it into a proper attention tensor
        batch_size = len(out["decoder_attention"])
        if isinstance(out["decoder_attention"], (list, tuple)):
            # start with decoder attention
            # the tensors may differ in the last dimension -> find the maximum size
            max_last_dimension = max(
                x.size(-1) for x in out["decoder_attention"])
            first_elm = out["decoder_attention"][0]
            # create new attention tensor into which we will scatter
            decoder_attention = torch.full(
                (batch_size, *first_elm.shape[:-1], max_last_dimension),
                float("nan"),
                dtype=first_elm.dtype,
                device=first_elm.device,
            )
            # scatter into tensor
            for idx, x in enumerate(out["decoder_attention"]):
                decoder_length = out["decoder_lengths"][idx]
                decoder_attention[idx, :, :, :decoder_length] = x[
                    ..., :decoder_length]
        else:
            decoder_attention = out["decoder_attention"]
            decoder_mask = create_mask(out["decoder_attention"].size(1),
                                       out["decoder_lengths"])
            decoder_attention[decoder_mask[..., None, None].expand_as(
                decoder_attention)] = float("nan")

        if isinstance(out["encoder_attention"], (list, tuple)):
            # same game for encoder attention
            # create new attention tensor into which we will scatter
            first_elm = out["encoder_attention"][0]
            encoder_attention = torch.full(
                (batch_size, *first_elm.shape[:-1],
                 self.hparams.max_encoder_length),
                float("nan"),
                dtype=first_elm.dtype,
                device=first_elm.device,
            )
            # scatter into tensor
            for idx, x in enumerate(out["encoder_attention"]):
                encoder_length = out["encoder_lengths"][idx]
                encoder_attention[idx, :, :, self.hparams.max_encoder_length -
                                  encoder_length:] = x[..., :encoder_length]
        else:
            # roll encoder attention (so that the last encoder value is on the right)
            encoder_attention = out["encoder_attention"]
            shifts = encoder_attention.size(3) - out["encoder_lengths"]
            new_index = (
                torch.arange(encoder_attention.size(3), device=encoder_attention.device)[
                    None, None, None
                ].expand_as(encoder_attention)
                - shifts[:, None, None, None]
            ) % encoder_attention.size(3)
            encoder_attention = torch.gather(encoder_attention, dim=3, index=new_index)
            # expand encoder_attention to full size
            if encoder_attention.size(-1) < self.hparams.max_encoder_length:
                encoder_attention = torch.concat(
                    [
                        torch.full(
                            (
                                *encoder_attention.shape[:-1],
                                self.hparams.max_encoder_length -
                                out["encoder_lengths"].max(),
                            ),
                            float("nan"),
                            dtype=encoder_attention.dtype,
                            device=encoder_attention.device,
                        ),
                        encoder_attention,
                    ],
                    dim=-1,
                )

        # combine encoder and decoder attention into one vector
        attention = torch.concat([encoder_attention, decoder_attention],
                                 dim=-1)
        attention[attention < 1e-5] = float("nan")

        # histogram of decoder and encoder lengths
        encoder_length_histogram = integer_histogram(
            out["encoder_lengths"], min=0, max=self.hparams.max_encoder_length)
        decoder_length_histogram = integer_histogram(
            out["decoder_lengths"],
            min=1,
            max=out["decoder_variables"].size(1))

        # mask positions where decoder and encoder were not applied when averaging variable selection weights
        encoder_variables = out["encoder_variables"].squeeze(-2)
        encode_mask = create_mask(encoder_variables.size(1),
                                  out["encoder_lengths"])
        encoder_variables = encoder_variables.masked_fill(
            encode_mask.unsqueeze(-1), 0.0).sum(dim=1)
        encoder_variables /= (out["encoder_lengths"].where(
            out["encoder_lengths"] > 0,
            torch.ones_like(out["encoder_lengths"])).unsqueeze(-1))

        decoder_variables = out["decoder_variables"].squeeze(-2)
        decode_mask = create_mask(decoder_variables.size(1),
                                  out["decoder_lengths"])
        decoder_variables = decoder_variables.masked_fill(
            decode_mask.unsqueeze(-1), 0.0).sum(dim=1)
        decoder_variables /= out["decoder_lengths"].unsqueeze(-1)

        # static variables need no masking
        static_variables = out["static_variables"].squeeze(1)
        # attention is batch x time x heads x time_to_attend
        # average over heads + only keep prediction attention and attention on observed timesteps
        attention = masked_op(
            attention[
                :, attention_prediction_horizon, :,
                : self.hparams.max_encoder_length + attention_prediction_horizon
            ],
            op="mean",
            dim=1,
        )

        if reduction != "none":  # whether to average over batches
            static_variables = static_variables.sum(dim=0)
            encoder_variables = encoder_variables.sum(dim=0)
            decoder_variables = decoder_variables.sum(dim=0)

            attention = masked_op(attention, dim=0, op=reduction)
        else:
            attention = attention / masked_op(attention, dim=1, op="sum").unsqueeze(-1)  # renormalize

        interpretation = dict(
            attention=attention.masked_fill(torch.isnan(attention), 0.0),
            static_variables=static_variables,
            encoder_variables=encoder_variables,
            decoder_variables=decoder_variables,
            encoder_length_histogram=encoder_length_histogram,
            decoder_length_histogram=decoder_length_histogram,
        )
        return interpretation
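Both interpret_output() variants rely on two small helpers, create_mask() and masked_op(), imported from pytorch_forecasting.utils. The sketch below gives plausible equivalents so the masking logic can be read in isolation; these are assumptions about the helpers' behavior, not the library's source.

import torch


def create_mask(size: int, lengths: torch.Tensor) -> torch.Tensor:
    # plausible equivalent (assumption): True marks padded positions,
    # i.e. indices at or beyond the corresponding sequence length
    return torch.arange(size, device=lengths.device).unsqueeze(0) >= lengths.unsqueeze(-1)


def masked_op(tensor: torch.Tensor, op: str = "mean", dim: int = 0) -> torch.Tensor:
    # plausible equivalent (assumption): reduce along `dim` while ignoring NaN entries
    mask = torch.isnan(tensor)
    summed = tensor.masked_fill(mask, 0.0).sum(dim=dim)
    if op == "sum":
        return summed
    if op == "mean":
        return summed / (~mask).sum(dim=dim).clamp(min=1)
    raise ValueError(f"unknown operation {op}")


lengths = torch.tensor([1, 3])
print(create_mask(3, lengths))
# tensor([[False,  True,  True],
#         [False, False, False]])
print(masked_op(torch.tensor([[1.0, float("nan")], [3.0, 5.0]]), op="mean", dim=0))
# tensor([2., 5.])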