Ejemplo n.º 1
0
    def update(self, y_pred: torch.Tensor,
               y_actual: torch.Tensor) -> None:
        """
        Accumulate the batch-aggregated prediction and target into the metric.

        ``y_pred`` and the target are each reduced to their (optionally
        weighted) mean over the first dimension. Both means get a new leading
        dimension of size one and are passed to ``self.metric.update``.

        Args:
            y_pred: network output. It may have more dimensions than the
                target; the weight is broadcast against each tensor separately
                so that this is handled correctly.
            y_actual: actual values -- a tensor or ``rnn.PackedSequence``, or a
                ``(target, weight)`` tuple/list.

        Returns:
            None. The aggregated values are accumulated in ``self.metric``.
        """
        # extract target and weight. PackedSequence is itself a tuple subclass,
        # so it has to be excluded explicitly before unpacking (target, weight).
        if isinstance(
                y_actual,
            (tuple, list)) and not isinstance(y_actual, rnn.PackedSequence):
            target, weight = y_actual
        else:
            target = y_actual
            weight = None

        # handle rnn sequence as target: pad to a batch-first tensor
        if isinstance(target, rnn.PackedSequence):
            target, lengths = rnn.pad_packed_sequence(target, batch_first=True)
            # batch sizes reside on the CPU by default -> we need to bring them to GPU
            lengths = lengths.to(target.device)

            # mask over time steps derived from the sequence lengths; it is used
            # as a weight, so presumably nonzero only at valid (non-padded) steps
            length_mask = create_mask(target.size(1), lengths, inverse=True)

            # combine with any user-supplied weight
            if weight is None:
                weight = length_mask
            else:
                weight = weight * length_mask

        if weight is None:
            y_mean = target.mean(0)
            y_pred_mean = y_pred.mean(0)
        else:
            # Broadcast the weight against each tensor on its own. Shaping the
            # weight like y_pred and multiplying it into target mis-aligns the
            # batch/time axes (or raises) whenever y_pred has extra dimensions.
            weight_sum = weight.sum(0)

            y_sum = (target * unsqueeze_like(weight, target)).sum(0)
            y_mean = y_sum / unsqueeze_like(weight_sum, y_sum)

            y_pred_sum = (y_pred * unsqueeze_like(weight, y_pred)).sum(0)
            y_pred_mean = y_pred_sum / unsqueeze_like(weight_sum, y_pred_sum)

        # update metric. unsqueeze first batch dimension (as batches are collapsed)
        self.metric.update(y_pred_mean.unsqueeze(0), y_mean.unsqueeze(0))
Ejemplo n.º 2
0
    def update(self, y_pred, target) -> None:
        """
        Update method of metric that handles masking of values.

        Do not override this method but :py:meth:`~loss` instead.

        Processing steps:

        * Split an optional weight off ``target``.
        * Turn a packed-sequence target into a padded tensor plus per-sample
          lengths. Otherwise, use the size of the target's second dimension
          as every sample's length.
        * Compute ``self.loss(y_pred, target)`` and optionally scale the
          losses by the weight.
        * Pass losses and lengths to ``self._update_losses_and_lengths``.

        Args:
            y_pred: network output, passed unchanged to ``self.loss``.
            target (Union[torch.Tensor, rnn.PackedSequence, tuple, list]):
                actual values, optionally given as a ``(target, weight)`` pair.

        Returns:
            None. Losses are accumulated via
            ``self._update_losses_and_lengths``; nothing is returned.
        """
        # unpack weight. PackedSequence is itself a tuple subclass, so it has
        # to be excluded explicitly before unpacking (target, weight).
        if isinstance(
                target,
            (list, tuple)) and not isinstance(target, rnn.PackedSequence):
            target, weight = target
        else:
            weight = None

        # unpack target
        if isinstance(target, rnn.PackedSequence):
            target, lengths = unpack_sequence(target)
        else:
            # no packing: every sample spans the full second dimension
            lengths = torch.full((target.size(0), ),
                                 fill_value=target.size(1),
                                 dtype=torch.long,
                                 device=target.device)

        losses = self.loss(y_pred, target)
        # weight samples; unsqueeze_like (defined elsewhere) is assumed to add
        # trailing singleton dims so the weight broadcasts against losses
        if weight is not None:
            losses = losses * unsqueeze_like(weight, losses)
        self._update_losses_and_lengths(losses, lengths)
Ejemplo n.º 3
0
    def _calculate_mean(
            y_pred: torch.Tensor,
            y_actual: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Reduce prediction and target to their (weighted) mean over dim 0.

        NOTE(review): this takes neither ``self`` nor ``cls``; presumably it is
        decorated ``@staticmethod`` outside the visible lines -- confirm.

        Args:
            y_pred: network output. It may have more dimensions than the
                target; the weight is broadcast against each tensor separately
                so that this is handled correctly.
            y_actual: actual values -- a tensor or ``rnn.PackedSequence``, or a
                ``(target, weight)`` tuple/list.

        Returns:
            ``(y_pred_mean, y_mean)``, each with a new leading dimension of
            size one (the collapsed batch dimension).
        """
        # extract target and weight. PackedSequence is itself a tuple subclass,
        # so it has to be excluded explicitly before unpacking (target, weight).
        if isinstance(
                y_actual,
            (tuple, list)) and not isinstance(y_actual, rnn.PackedSequence):
            target, weight = y_actual
        else:
            target = y_actual
            weight = None

        # handle rnn sequence as target: pad to a batch-first tensor
        if isinstance(target, rnn.PackedSequence):
            target, lengths = rnn.pad_packed_sequence(target, batch_first=True)
            # batch sizes reside on the CPU by default -> we need to bring them to GPU
            lengths = lengths.to(target.device)

            # mask over time steps derived from the sequence lengths; it is used
            # as a weight, so presumably nonzero only at valid (non-padded) steps
            length_mask = create_mask(target.size(1), lengths, inverse=True)

            # combine with any user-supplied weight
            if weight is None:
                weight = length_mask
            else:
                weight = weight * length_mask

        if weight is None:
            y_mean = target.mean(0)
            y_pred_mean = y_pred.mean(0)
        else:
            # Broadcast the weight against each tensor on its own. Shaping the
            # weight like y_pred and multiplying it into target mis-aligns the
            # batch/time axes (or raises) whenever y_pred has extra dimensions.
            weight_sum = weight.sum(0)

            y_sum = (target * unsqueeze_like(weight, target)).sum(0)
            y_mean = y_sum / unsqueeze_like(weight_sum, y_sum)

            y_pred_sum = (y_pred * unsqueeze_like(weight, y_pred)).sum(0)
            y_pred_mean = y_pred_sum / unsqueeze_like(weight_sum, y_pred_sum)
        return y_pred_mean.unsqueeze(0), y_mean.unsqueeze(0)