def forward(self, output, target):
    """Compute the soft-target cross entropy between logits and targets.

    For N examples and C classes:

    Args:
        output: N x C tensor of raw outputs (without softmax/sigmoid).
        target: N x C tensor of targets, or a length-N tensor of class
            integers, which is expanded to one-hot before use.

    Target elements set to ignore_index contribute 0 loss.

    Samples where all entries are ignore_index do not contribute to the
    "mean" loss reduction.

    Returns:
        A scalar tensor when the reduction is "mean", or a length-N tensor
        of per-sample losses when the reduction is "none".

    Raises:
        ValueError: if ``self._reduction`` is neither "mean" nor "none".
    """
    # check if targets are inputted as class integers
    if target.ndim == 1:
        assert (
            output.shape[0] == target.shape[0]
        ), "SoftTargetCrossEntropyLoss requires output and target to have same batch size"
        target = convert_to_one_hot(target.view(-1, 1), output.shape[1])
    assert output.shape == target.shape, (
        "SoftTargetCrossEntropyLoss requires output and target to be same "
        f"shape: {output.shape} != {target.shape}"
    )

    # Zero out ignored entries so they add nothing to the loss.
    valid_mask = target != self._ignore_index
    valid_targets = target.float() * valid_mask.float()
    if self._normalize_targets:
        # eps keeps the division finite for rows whose targets sum to zero
        # (e.g. rows where every entry was ignored).
        valid_targets /= self._eps + valid_targets.sum(dim=1, keepdim=True)

    per_sample_per_target_loss = -valid_targets * F.log_softmax(output, -1)
    per_sample_loss = torch.sum(per_sample_per_target_loss, -1)

    # perform reduction
    if self._reduction == "mean":
        # normalize based on the number of samples with > 0 non-ignored
        # targets; clamp avoids dividing by zero when every sample is ignored
        num_valid_samples = torch.sum(torch.sum(valid_mask, -1) > 0).clamp(min=1)
        return per_sample_loss.sum() / num_valid_samples
    if self._reduction == "none":
        return per_sample_loss

    # Previously this fell through and failed with an UnboundLocalError.
    raise ValueError(
        f"Unsupported reduction: {self._reduction!r} (expected 'mean' or 'none')"
    )
def update(self, model_output, target, **kwargs):
    """Accumulate top-k correct-prediction counts for one batch.

    Args:
        model_output: tensor of shape (B, C) where each value is either
            logit or class probability.
        target: tensor of shape (B, C), one-hot encoded or integer encoded
            or tensor of shape (B), integer encoded.
        **kwargs: accepted but not used by this method.

    Note:
        For binary classification, C=2.
        For integer encoded target, C=1.

    An empty batch (``model_output.shape[0] == 0``) leaves all counters
    untouched.
    """
    # Due to dummy samples, in some corner cases, the whole batch could
    # be dummy samples, in that case we want to not update meters on that
    # process. This must run before the validation below, because reducing
    # an empty tensor with torch.min/torch.all-style checks can raise.
    if model_output.shape[0] == 0:
        return

    if self._target_is_one_hot is False:
        assert target.dim() == 1 or (
            target.dim() == 2 and target.size(1) == 1
        ), "Integer encoded target must be single labeled"
        target = convert_to_one_hot(target.view(-1, 1), self._num_classes)

    assert torch.all(
        target.eq(0) | target.eq(1)
    ), "Target must be one-hot encoded vector"

    _, pred_classes = model_output.topk(
        max(self._topk), dim=1, largest=True, sorted=True
    )
    # Keep the mask, the scatter indices and the target on one device so the
    # scatter/min below also work when the tensors do not live on the CPU.
    pred_classes = pred_classes.to(target.device)
    target_float = target.float()
    pred_mask_tensor = torch.zeros(target.size(), device=target.device)
    for i, k in enumerate(self._topk):
        # Rebuild the top-k prediction mask from scratch for each k.
        pred_mask_tensor.zero_()
        pred_mask_tensor.scatter_(1, pred_classes[:, :k], 1.0)
        # torch.min is used to simulate AND between binary
        # tensors. If tensors are not binary, this will fail.
        self._curr_correct_predictions_k[i] += torch.sum(
            torch.min(pred_mask_tensor, target_float)
        ).item()
    self._curr_correct_targets += target.sum().item()
def compute_valid_targets(self, target, classes):
    """
    Build float target vectors with every ignore_index entry zeroed out.

    Targets of shape (N,) or (N, 1) are treated as class integers: ignored
    entries are first replaced by 0, the result is expanded with
    convert_to_one_hot(..., classes), and the rows of ignored entries are
    then zeroed. Any other target is returned as a float tensor of the same
    shape with ignored elements set to 0.
    """
    keep = target != self._ignore_index
    masked = target.float() * keep.float()

    rank = target.dim()
    index_encoded = rank == 1 or (rank == 2 and target.size(1) == 1)
    if not index_encoded:
        return masked

    # Ignored entries are 0 at this point, so the one-hot expansion sees an
    # index of 0 for them; the mask below clears those rows again.
    one_hot = convert_to_one_hot(masked.view(-1, 1), classes)
    return one_hot.float() * keep.view(-1, 1).float()
def __call__(self, sample: Dict[str, Any]) -> Dict[str, Any]:
    """
    Blend every example in the batch with another example of the same batch.

    A single weight is drawn from Beta(alpha, alpha) and one random
    permutation of the batch is drawn; both "input" and "target" are
    replaced by ``w * x + (1 - w) * x[permutation]``.

    Args:
        sample: the batch data. A 1D "target" is expanded with
            convert_to_one_hot first, which requires ``self.num_classes``.

    Returns:
        The same dict, with "input" and "target" overwritten.
    """
    labels = sample["target"]
    if labels.ndim == 1:
        assert self.num_classes is not None, "num_classes is expected for 1D target"
        labels = convert_to_one_hot(labels.view(-1, 1), self.num_classes)
        sample["target"] = labels
    else:
        assert labels.ndim == 2, "target tensor shape must be 1D or 2D"

    # Draw the weight before the permutation so the RNG is consumed in the
    # same order on every call.
    mix_weight = Beta(self.alpha, self.alpha).sample().to(device=labels.device)
    shuffle = torch.randperm(labels.shape[0])

    for key in ("input", "target"):
        current = sample[key]
        sample[key] = mix_weight * current + (1.0 - mix_weight) * current[shuffle, :]
    return sample
def test_two(self):
    """Two (N, 1) class indices expand to the expected 3-class one-hot rows."""
    class_indices = torch.tensor([[0], [1]])
    expected = torch.tensor([[1, 0, 0], [0, 1, 0]])

    actual = convert_to_one_hot(class_indices, 3)

    self.assertTrue(torch.allclose(actual, expected))
def test_single(self):
    """A single class index (the last of 5 classes) expands to one one-hot row."""
    class_indices = torch.tensor([[4]])
    expected = torch.tensor([[0, 0, 0, 0, 1]])

    actual = convert_to_one_hot(class_indices, 5)

    self.assertTrue(torch.allclose(actual, expected))