def forward(self, output, target):
    """Compute the soft-target cross-entropy between raw scores and targets.

    For N examples and C classes:

    - output: N x C raw scores (without softmax/sigmoid); log-softmax is
      applied over the last dimension here.
    - target: either N x C targets, or N integer class indices, which are
      expanded to N x C with ``convert_to_one_hot``.

    Target elements equal to ``self._ignore_index`` contribute 0 loss. For
    integer targets, a sample whose index equals ``self._ignore_index`` is
    treated as a row made entirely of ignore_index.

    Samples where all entries are ignore_index do not contribute to the loss
    reduction: with ``self._reduction == "mean"`` the summed loss is divided
    by the number of samples that have at least one non-ignored entry
    (clamped to at least 1).

    Returns:
        A 0-d tensor for ``"mean"``, or the length-N per-sample losses for
        ``"none"``.

    Raises:
        AssertionError: if batch sizes or shapes of output/target disagree.
        ValueError: if ``self._reduction`` is neither ``"mean"`` nor
            ``"none"``.
    """
    # Integer class targets are expanded to N x C before the shared logic.
    if target.ndim == 1:
        assert (
            output.shape[0] == target.shape[0]
        ), "SoftTargetCrossEntropyLoss requires output and target to have same batch size"
        # ignore_index values are not class indices, so keep them out of
        # convert_to_one_hot: encode ignored samples as class 0 for the
        # conversion, then overwrite their whole row with ignore_index so the
        # masking below drops them exactly like an all-ignored N x C row.
        ignored = target == self._ignore_index
        target = convert_to_one_hot(
            target.masked_fill(ignored, 0).view(-1, 1), output.shape[1]
        )
        target = target.masked_fill(ignored.view(-1, 1), self._ignore_index)
    assert output.shape == target.shape, (
        "SoftTargetCrossEntropyLoss requires output and target to be same "
        f"shape: {output.shape} != {target.shape}")
    valid_mask = target != self._ignore_index
    # Zero out ignored entries so they contribute nothing to the loss.
    valid_targets = target.float() * valid_mask.float()
    if self._normalize_targets:
        # Rescale each row by its sum; _eps guards rows that sum to zero.
        valid_targets /= self._eps + valid_targets.sum(dim=1, keepdim=True)
    per_sample_per_target_loss = -valid_targets * F.log_softmax(output, -1)
    per_sample_loss = torch.sum(per_sample_per_target_loss, -1)
    if self._reduction == "mean":
        # Normalize by the number of samples with > 0 non-ignored targets;
        # the clamp avoids dividing by zero when every sample is ignored.
        num_valid_samples = valid_mask.any(dim=-1).sum().clamp(min=1)
        return per_sample_loss.sum() / num_valid_samples
    if self._reduction == "none":
        return per_sample_loss
    raise ValueError(
        f"Unsupported reduction {self._reduction!r}; expected 'mean' or 'none'"
    )
Exemplo n.º 2
0
    def update(self, model_output, target, **kwargs):
        """Accumulate top-k hit counts for one batch.

        args:
            model_output: tensor of shape (B, C) where each value is
                          either logit or class probability.
            target:       tensor of shape (B, C), one-hot encoded
                          or integer encoded or tensor of shape (B),
                          integer encoded.

        For each ``k`` in ``self._topk`` (position ``i``), adds to
        ``self._curr_correct_predictions_k[i]`` the number of positive target
        entries that fall within the top-k predicted classes. Adds the total
        number of positive target entries to ``self._curr_correct_targets``.
        An empty batch leaves all counters unchanged.

        Raises:
            AssertionError: if an integer-encoded target is not single
                labeled, or if the (converted) target is not made of 0/1.

        Note:

            For binary classification, C=2. For integer encoded target, C=1.
        """
        # Due to dummy samples, in some corner cases, the whole batch could
        # be dummy samples; in that case we do not update meters on this
        # process. This must come before validation: reducing an empty
        # tensor with torch.min raises.
        if model_output.shape[0] == 0:
            return

        if self._target_is_one_hot is False:
            assert target.dim() == 1 or (
                target.dim() == 2 and target.size(1) == 1
            ), "Integer encoded target must be single labeled"
            target = convert_to_one_hot(target.view(-1, 1), self._num_classes)

        assert torch.all(
            (target == 0) | (target == 1)
        ), "Target must be one-hot encoded vector"

        _, pred_classes = model_output.topk(
            max(self._topk), dim=1, largest=True, sorted=True
        )
        target_float = target.float()
        # Allocate on the target's device so scatter_/min work off-CPU too.
        pred_mask = torch.zeros(target.size(), device=target.device)
        for i, k in enumerate(self._topk):
            pred_mask.zero_()
            pred_mask.scatter_(1, pred_classes[:, :k], 1.0)
            # torch.min simulates AND between binary tensors. If tensors
            # are not binary, this will fail.
            self._curr_correct_predictions_k[i] += torch.sum(
                torch.min(pred_mask, target_float)
            ).item()
        self._curr_correct_targets += target.sum().item()
    def compute_valid_targets(self, target, classes):
        """Return float targets with every ``self._ignore_index`` entry zeroed.

        Args:
            target: either an index target of shape (N,) or (N, 1), or a
                target with more than one column that is used as-is.
            classes: number of classes used when expanding index targets
                through ``convert_to_one_hot``.

        Returns:
            A float tensor. For index targets, each row is the one-hot
            encoding of the class, or all zeros when the index equals
            ``self._ignore_index``. Otherwise it is ``target`` with
            ignore_index entries replaced by 0.
        """
        keep = (target != self._ignore_index).float()
        masked = target.float() * keep

        is_index_target = target.dim() == 1 or (
            target.dim() == 2 and target.size(1) == 1
        )
        if not is_index_target:
            return masked

        # Ignored indices were zeroed above so they encode safely as class 0;
        # multiplying by the row mask then erases those rows again.
        one_hot = convert_to_one_hot(masked.view(-1, 1), classes)
        return one_hot.float() * keep.view(-1, 1)
Exemplo n.º 4
0
    def __call__(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """
        Args:
            sample: the batch data.
        """
        if sample["target"].ndim == 1:
            assert self.num_classes is not None, "num_classes is expected for 1D target"
            sample["target"] = convert_to_one_hot(sample["target"].view(-1, 1),
                                                  self.num_classes)
        else:
            assert sample[
                "target"].ndim == 2, "target tensor shape must be 1D or 2D"

        c = Beta(self.alpha,
                 self.alpha).sample().to(device=sample["target"].device)
        permuted_indices = torch.randperm(sample["target"].shape[0])
        for key in ["input", "target"]:
            sample[key] = c * sample[key] + (
                1.0 - c) * sample[key][permuted_indices, :]

        return sample
 def test_two(self):
     """Class indices 0 and 1 expand to the first two rows of a 3-class one-hot."""
     expected = torch.tensor([[1, 0, 0], [0, 1, 0]])
     actual = convert_to_one_hot(torch.tensor([[0], [1]]), 3)
     self.assertTrue(torch.allclose(actual, expected))
 def test_single(self):
     """A single index equal to the last class sets only the final column."""
     expected = torch.tensor([[0, 0, 0, 0, 1]])
     actual = convert_to_one_hot(torch.tensor([[4]]), 5)
     self.assertTrue(torch.allclose(actual, expected))