def update(
        self,
        y_pred,
        target,
        encoder_target,
        encoder_lengths=None,
    ) -> torch.Tensor:
        """
        Update metric that handles masking of values.

        The loss is computed per element by :py:meth:`~loss` using a scaling
        derived from the encoder (historic) values, optionally multiplied by
        per-sample weights, and then handed to ``_update_losses_and_lengths``.

        Args:
            y_pred (Dict[str, torch.Tensor]): network output
            target (Tuple[Union[torch.Tensor, rnn.PackedSequence], torch.Tensor]): tuple of actual values
                and weights. A bare tensor or ``rnn.PackedSequence`` (without weights) is accepted as well.
            encoder_target (Union[torch.Tensor, rnn.PackedSequence]): historic actual values
            encoder_lengths (torch.Tensor): optional encoder lengths, not necessary if encoder_target
                is rnn.PackedSequence. If given, encoder_target must be a torch.Tensor.

        Returns:
            Nothing is returned explicitly; losses and lengths are passed on to
            ``_update_losses_and_lengths``.
        """
        # unpack weight - a PackedSequence is itself a (named) tuple and must
        # not be mistaken for a (target, weight) pair
        if isinstance(target, (list, tuple)) and not isinstance(
                target, rnn.PackedSequence):
            weight = target[1]
            target = target[0]
        else:
            weight = None

        # unpack target
        if isinstance(target, rnn.PackedSequence):
            target, lengths = unpack_sequence(target)
        else:
            # plain tensor: every sample spans the full second dimension
            lengths = torch.full((target.size(0), ),
                                 fill_value=target.size(1),
                                 dtype=torch.long,
                                 device=target.device)

        # determine lengths for encoder
        if encoder_lengths is None:
            # unpack the *encoder* target; unpacking ``target`` here would
            # replace the history with the decoder values and their lengths
            encoder_target, encoder_lengths = unpack_sequence(encoder_target)
        else:
            assert isinstance(encoder_target, torch.Tensor)
        assert not target.requires_grad

        # calculate loss with "none" reduction
        scaling = self.calculate_scaling(target, lengths, encoder_target,
                                         encoder_lengths)
        losses = self.loss(y_pred, target, scaling)

        # weight samples - one weight per sample, broadcast over the last
        # dimension of the losses
        if weight is not None:
            losses = losses * weight.unsqueeze(-1)

        self._update_losses_and_lengths(losses, lengths)
示例#2
0
    def update(self, y_pred: Dict[str, torch.Tensor],
               target: Union[torch.Tensor, rnn.PackedSequence]):
        """
        Accumulate element-wise losses for a batch while respecting sequence lengths.

        Do not override this method but :py:meth:`~loss` instead

        Args:
            y_pred (Dict[str, torch.Tensor]): network output
            target (Union[torch.Tensor, rnn.PackedSequence]): actual values. When the
                unpacked target is 3-dimensional, index 0 of the last dimension is used as
                the actual values and index 1 as the weights.

        Returns:
            Nothing is returned explicitly; losses and lengths are passed on to
            ``_update_losses_and_lengths``.
        """
        values, lengths = unpack_sequence(target)
        assert not values.requires_grad

        # a 3-dimensional target bundles the weights with the actual values
        has_weight = values.ndim == 3
        sample_weight = values[..., 1] if has_weight else None
        actuals = values[..., 0] if has_weight else values

        # calculate loss with "none" reduction
        losses = self.loss(y_pred, actuals)

        # weight samples
        if sample_weight is not None:
            losses = losses * sample_weight.unsqueeze(-1)

        self._update_losses_and_lengths(losses, lengths)
示例#3
0
    def update(self, y_pred, target):
        """
        Accumulate element-wise, optionally weighted losses for a batch.

        Do not override this method but :py:meth:`~loss` instead

        Args:
            y_pred (Dict[str, torch.Tensor]): network output
            target (Union[torch.Tensor, rnn.PackedSequence]): actual values, either on
                their own or as a ``(actual values, weight)`` list/tuple

        Returns:
            Nothing is returned explicitly; losses and lengths are passed on to
            ``_update_losses_and_lengths``.
        """
        # split off the weight; a PackedSequence is excluded explicitly so it
        # is never unpacked as a (target, weight) pair
        weight = None
        is_sequence_type = isinstance(target, (list, tuple))
        if is_sequence_type and not isinstance(target, rnn.PackedSequence):
            target, weight = target

        # determine the length of every sample
        if isinstance(target, rnn.PackedSequence):
            target, lengths = unpack_sequence(target)
        else:
            # plain tensor: every sample spans the full second dimension
            n_samples = target.size(0)
            n_steps = target.size(1)
            lengths = torch.full((n_samples, ),
                                 fill_value=n_steps,
                                 dtype=torch.long,
                                 device=target.device)

        losses = self.loss(y_pred, target)

        # weight samples
        if weight is not None:
            broadcast_weight = unsqueeze_like(weight, losses)
            losses = losses * broadcast_weight

        self._update_losses_and_lengths(losses, lengths)
示例#4
0
    def _convert(self, y_pred: torch.Tensor,
                 target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Turn network output and target into two flat tensors.

        Args:
            y_pred (torch.Tensor): network output, converted via ``to_prediction``
            target (torch.Tensor): actual values; may also be a ``rnn.PackedSequence`` or a
                ``(actual values, weight)`` list/tuple whose weight is ``None``

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: flattened prediction and flattened target

        Raises:
            NotImplementedError: if a weight other than ``None`` is supplied
        """
        # a (target, weight) pair is accepted only when no weight is set
        is_pair = isinstance(target, (list, tuple)) and not isinstance(
            target, rnn.PackedSequence)
        if is_pair:
            target, weight = target
            if weight is not None:
                raise NotImplementedError(
                    "Weighting is not supported for pure torchmetrics - "
                    "implement a custom version or use pytorch-forecasting metrics"
                )

        # convert to point prediction - limits applications of class
        prediction = self.to_prediction(y_pred)

        if isinstance(target, rnn.PackedSequence):
            padded, lengths = unpack_sequence(target)
            # select entries via the mask built from the individual lengths
            keep = create_mask(padded.size(1), lengths, inverse=True)
            target = padded.masked_select(keep)
            prediction = prediction.masked_select(keep)

        return prediction.flatten(), target.flatten()
示例#5
0
    def update(
        self,
        y_pred: Dict[str, torch.Tensor],
        target: Union[torch.Tensor, rnn.PackedSequence],
        encoder_target: Union[torch.Tensor, rnn.PackedSequence],
        encoder_lengths: torch.Tensor = None,
    ) -> torch.Tensor:
        """
        Update metric that handles masking of values.

        The loss is computed per element by :py:meth:`~loss` using a scaling
        derived from the encoder (historic) values, optionally multiplied by
        weights, and then handed to ``_update_losses_and_lengths``.

        Args:
            y_pred (Dict[str, torch.Tensor]): network output
            target (Union[torch.Tensor, rnn.PackedSequence]): actual values. When the unpacked
                target is 3-dimensional, index 0 of the last dimension is used as the actual
                values and index 1 as the weights.
            encoder_target (Union[torch.Tensor, rnn.PackedSequence]): historic actual values
            encoder_lengths (torch.Tensor): optional encoder lengths, not necessary if encoder_target
                is rnn.PackedSequence. If given, encoder_target must be a torch.Tensor.

        Returns:
            Nothing is returned explicitly; losses and lengths are passed on to
            ``_update_losses_and_lengths``.
        """
        target, lengths = unpack_sequence(target)

        # determine lengths for encoder
        if encoder_lengths is None:
            # unpack the *encoder* target; unpacking ``target`` here would
            # replace the history with the decoder values and their lengths
            encoder_target, encoder_lengths = unpack_sequence(encoder_target)
        else:
            assert isinstance(encoder_target, torch.Tensor)
        assert not target.requires_grad

        # split a 3-dimensional target into actual values and weights
        if target.ndim == 3:
            weight = target[..., 1]
            target = target[..., 0]
        else:
            weight = None

        # calculate loss with "none" reduction
        scaling = self.calculate_scaling(target, lengths, encoder_target,
                                         encoder_lengths)
        losses = self.loss(y_pred, target, scaling)

        # weight samples
        if weight is not None:
            losses = losses * weight.unsqueeze(-1)

        self._update_losses_and_lengths(losses, lengths)