def __call__(
    self,
    best_span_strings: Union[str, List[str]],
    answer_strings: Union[List[str], List[List[str]]],
):
    """
    Accumulate exact-match and F1 totals for one prediction or a batch of predictions.

    # Parameters

    best_span_strings : `Union[str, List[str]]`
        A single predicted answer string, or a list with one prediction per instance.
    answer_strings : `Union[List[str], List[List[str]]]`
        The gold answers for the single prediction, or one list of gold answers
        per instance when a batch is given.
    """
    # A lone prediction is promoted to a batch of one so both call styles share one path.
    if not isinstance(best_span_strings, list):
        best_span_strings = [best_span_strings]
        answer_strings = [answer_strings]  # type: ignore

    predicted = cast(List[str], best_span_strings)
    references = cast(List[List[str]], answer_strings)
    assert len(predicted) == len(references)

    em_total = 0
    f1_total = 0.0
    for guess, golds in zip(predicted, references):
        em_total += squad.metric_max_over_ground_truths(squad.compute_exact, guess, golds)
        f1_total += squad.metric_max_over_ground_truths(squad.compute_f1, guess, golds)

    # The exact-match tally is cast to int because it is a count of matches.
    self._total_em += dist_reduce_sum(int(em_total))
    self._total_f1 += dist_reduce_sum(f1_total)
    self._count += dist_reduce_sum(len(predicted))
def __call__(
    self,
    predictions: torch.Tensor,
    gold_labels: torch.Tensor,
    mask: Optional[torch.BoolTensor] = None,
):
    """
    Count instances whose prediction equals the gold label at every unmasked position.

    # Parameters

    predictions : `torch.Tensor`, required.
        A tensor of predictions of shape (batch_size, ...).
    gold_labels : `torch.Tensor`, required.
        A tensor of the same shape as `predictions`.
    mask : `torch.BoolTensor`, optional (default = `None`).
        A tensor of the same shape as `predictions`.
    """
    predictions, gold_labels, mask = self.detach_tensors(predictions, gold_labels, mask)

    # Shape validation: both gold_labels and mask must match predictions exactly.
    expected_shape = predictions.size()
    if gold_labels.size() != expected_shape:
        raise ValueError(
            f"gold_labels must have shape == predictions.size() but "
            f"found tensor of shape: {gold_labels.size()}"
        )
    if mask is not None and mask.size() != expected_shape:
        raise ValueError(
            f"mask must have shape == predictions.size() but "
            f"found tensor of shape: {mask.size()}"
        )

    batch_size = predictions.size(0)

    if mask is None:
        # Without a mask every instance takes part in the count.
        keep = torch.ones(batch_size, device=predictions.device).bool()
    else:
        # Multiplying both sides by the mask makes every masked position compare
        # equal, which is safe because only equality is checked below.
        predictions = predictions * mask
        gold_labels = gold_labels * mask
        # An instance is kept when at least one of its positions is unmasked.
        keep = mask.view(batch_size, -1).max(dim=1)[0]

    flat_predictions = predictions.view(batch_size, -1)
    flat_gold = gold_labels.view(batch_size, -1)

    # Per instance: 1.0 when every flattened element matches, 0.0 otherwise.
    # Masked positions always match because of the multiplication above.
    correct = flat_predictions.eq(flat_gold).prod(dim=1).float()

    # Fully masked instances look "correct", so they are removed through `keep`.
    _correct_count = (correct * keep).sum()
    _total_count = keep.sum()

    self._correct_count += dist_reduce_sum(_correct_count).item()
    self._total_count += dist_reduce_sum(_total_count).item()
def __call__(
    self,  # type: ignore
    logits: torch.Tensor,
    mask: Optional[torch.BoolTensor] = None,
):
    """
    Add the mean entropy of the unmasked positions of this batch to the running total.

    # Parameters

    logits : `torch.Tensor`, required.
        A tensor of unnormalized log probabilities of shape (batch_size, ..., num_classes).
    mask : `torch.BoolTensor`, optional (default = `None`).
        A masking tensor of shape (batch_size, ...).
    """
    logits, mask = self.detach_tensors(logits, mask)

    if mask is None:
        # Default mask covers every position (all dimensions except the class one).
        mask = torch.ones(logits.size()[:-1], device=logits.device).bool()

    log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
    # Zero the probabilities of masked positions so they add nothing to the sum.
    masked_probs = torch.exp(log_probs) * mask.unsqueeze(-1)
    per_position_entropy = (-log_probs * masked_probs).sum(-1)
    _entropy = per_position_entropy.sum() / mask.sum()

    self._entropy += dist_reduce_sum(_entropy).item()
    self._count += dist_reduce_sum(1)
def __call__(
    self,
    predictions: torch.Tensor,
    gold_labels: torch.Tensor,
    mask: Optional[torch.BoolTensor] = None,
) -> None:
    """
    Accumulate the summed absolute error and the number of scored elements.

    # Parameters

    predictions : `torch.Tensor`, required.
        A tensor of predictions of shape (batch_size, ...).
    gold_labels : `torch.Tensor`, required.
        A tensor of the same shape as `predictions`.
    mask : `torch.BoolTensor`, optional (default = `None`).
        A tensor of the same shape as `predictions`.
    """
    predictions, gold_labels, mask = self.detach_tensors(predictions, gold_labels, mask)

    absolute_errors = torch.abs(predictions - gold_labels)

    if mask is None:
        # Every element counts when no mask is supplied.
        _total_count = gold_labels.numel()
    else:
        # Masked-out elements contribute zero error and are not counted.
        absolute_errors *= mask
        _total_count = torch.sum(mask)

    _absolute_error = torch.sum(absolute_errors)

    self._absolute_error += float(dist_reduce_sum(_absolute_error))
    self._total_count += int(dist_reduce_sum(_total_count))
def __call__(
    self,
    predictions: torch.Tensor,
    gold_labels: torch.Tensor,
    mask: Optional[torch.BoolTensor] = None,
    end_index: int = sys.maxsize,
):
    """
    Accumulate, per instance, the fraction of gold tokens found in any of the k beams.

    # Parameters

    predictions : `torch.Tensor`, required.
        A tensor of predictions of shape (batch_size, k, sequence_length).
    gold_labels : `torch.Tensor`, required.
        A tensor of integer class label of shape (batch_size, sequence_length).
    mask : `torch.BoolTensor`, optional (default = `None`).
        A masking tensor the same size as `gold_labels`.
    end_index : `int`, optional (default = `sys.maxsize`).
        Gold tokens equal to this value (or to `0`) are dropped before scoring.

    # Raises

    ConfigurationError
        If `gold_labels` does not have one dimension fewer than `predictions`,
        or if `mask` does not have the same size as `gold_labels`.
    """
    predictions, gold_labels, mask = self.detach_tensors(predictions, gold_labels, mask)

    # Some sanity checks.
    if gold_labels.dim() != predictions.dim() - 1:
        raise ConfigurationError(
            "gold_labels must have dimension == predictions.dim() - 1 but "
            "found tensor of shape: {}".format(gold_labels.size())
        )
    if mask is not None and mask.size() != gold_labels.size():
        # The mask is compared against gold_labels, so the message names gold_labels.
        raise ConfigurationError(
            "mask must have the same size as gold_labels but "
            "found tensor of shape: {}".format(mask.size())
        )

    correct = 0.0
    for i, beams in enumerate(predictions):
        cur_gold = gold_labels[i]
        masked_gold = cur_gold if mask is None else cur_gold * mask[i]
        # Masked positions become 0 above, so dropping 0 also drops them.
        cleaned_gold = [x for x in masked_gold if x not in (0, end_index)]

        retval = 0.0
        for word in cleaned_gold:
            # `word` is never 0 or end_index, so those values need not be
            # stripped from the beams. `any` stops at the first matching beam.
            if any(word in beam for beam in beams):
                retval += 1 / len(cleaned_gold)
        correct += retval

    _correct_count = correct
    _total_count = predictions.size()[0]

    self.correct_count += dist_reduce_sum(_correct_count)
    self.total_count += dist_reduce_sum(_total_count)
def __call__(self, value):
    """
    Record one more value for the running average.

    # Parameters

    value : `float`
        The value to average.
    """
    self._count += dist_reduce_sum(1)

    # `detach_tensors` yields the detached inputs; only one was passed in.
    detached_values = list(self.detach_tensors(value))
    scalar = float(detached_values[0])
    self._total_value += dist_reduce_sum(scalar)
def update(self, predicted, gold, mention_to_predicted, mention_to_gold):
    """
    Add the precision and recall numerators/denominators for one document.

    When the configured metric is `self.ceafe` it is called once with both
    clusterings and returns all four values. Any other metric is called twice:
    once for precision (predicted vs. `mention_to_gold`) and once for recall
    (gold vs. `mention_to_predicted`).
    """
    if self.metric != self.ceafe:
        precision_num, precision_den = self.metric(predicted, mention_to_gold)
        recall_num, recall_den = self.metric(gold, mention_to_predicted)
    else:
        precision_num, precision_den, recall_num, recall_den = self.metric(predicted, gold)

    self.precision_numerator += dist_reduce_sum(precision_num)
    self.precision_denominator += dist_reduce_sum(precision_den)
    self.recall_numerator += dist_reduce_sum(recall_num)
    self.recall_denominator += dist_reduce_sum(recall_den)
def __call__(
    self,
    predictions: torch.Tensor,
    gold_labels: torch.Tensor,
    mask: Optional[torch.BoolTensor] = None,
):
    """
    Count instances for which at least one of the k predicted sequences equals
    the gold sequence at every unmasked position.

    # Parameters

    predictions : `torch.Tensor`, required.
        A tensor of predictions of shape (batch_size, k, sequence_length).
    gold_labels : `torch.Tensor`, required.
        A tensor of integer class label of shape (batch_size, sequence_length).
    mask : `torch.BoolTensor`, optional (default = `None`).
        A masking tensor the same size as `gold_labels`.

    # Raises

    ConfigurationError
        If `gold_labels` does not have one dimension fewer than `predictions`,
        or if `mask` does not have the same size as `gold_labels`.
    """
    predictions, gold_labels, mask = self.detach_tensors(predictions, gold_labels, mask)

    # Some sanity checks.
    if gold_labels.dim() != predictions.dim() - 1:
        raise ConfigurationError(
            "gold_labels must have dimension == predictions.dim() - 1 but "
            "found tensor of shape: {}".format(gold_labels.size())
        )
    if mask is not None and mask.size() != gold_labels.size():
        # The mask is compared against gold_labels, so the message names gold_labels.
        raise ConfigurationError(
            "mask must have the same size as gold_labels but "
            "found tensor of shape: {}".format(mask.size())
        )

    # Broadcast the gold sequence across the k dimension of the predictions.
    k = predictions.size()[1]
    expanded_size = list(gold_labels.size())
    expanded_size.insert(1, k)
    expanded_gold = gold_labels.unsqueeze(1).expand(expanded_size)

    if mask is not None:
        # Masked positions are zeroed on both sides so they always compare equal.
        expanded_mask = mask.unsqueeze(1).expand(expanded_size)
        masked_gold = expanded_mask * expanded_gold
        masked_predictions = expanded_mask * predictions
    else:
        masked_gold = expanded_gold
        masked_predictions = predictions

    eqs = masked_gold.eq(masked_predictions)
    # min over positions: a candidate matches only if every position matches.
    matches_per_question = eqs.min(dim=2)[0]
    # max over candidates: the instance is correct if any candidate matches.
    some_match = matches_per_question.max(dim=1)[0]
    correct = some_match.sum().item()

    _total_count = predictions.size()[0]
    _correct_count = correct

    self.correct_count += dist_reduce_sum(_correct_count)
    self.total_count += dist_reduce_sum(_total_count)
def _get_rouge_l_score(
    self, predicted_tokens: torch.LongTensor, reference_tokens: torch.LongTensor
) -> float:
    """
    Compute sum of F1 scores given batch of predictions and references.

    Each predicted/reference pair contributes an F1 built from its longest
    common subsequence (LCS) length; pairs with an LCS of 0 contribute nothing.
    The batch sum is passed through `dist_reduce_sum` before being returned.
    """
    # Kept at function scope as in the original, but resolved once per call
    # instead of once per sequence.
    from allennlp.training.util import get_valid_tokens_mask

    total_f1 = 0.0

    for predicted_seq, reference_seq in zip(predicted_tokens, reference_tokens):
        # m / n: number of non-excluded tokens in the reference / prediction.
        m = get_valid_tokens_mask(reference_seq, self._exclude_indices).sum().item()
        n = get_valid_tokens_mask(predicted_seq, self._exclude_indices).sum().item()

        lcs = self._longest_common_subsequence(reference_seq, predicted_seq)

        # This also rules out the case that m or n are 0, so we don't worry about it later
        if lcs == 0:
            continue

        recall_lcs = lcs / m
        precision_lcs = lcs / n

        f1 = 2 * recall_lcs * precision_lcs / (recall_lcs + precision_lcs)
        total_f1 += f1

    return dist_reduce_sum(total_f1)
def __call__(
    self,  # type: ignore
    predictions: torch.LongTensor,
    gold_targets: torch.LongTensor,
) -> None:
    """
    Update the running ROUGE-N and ROUGE-L totals with one batch.

    # Parameters

    predictions : `torch.LongTensor`
        Batched predicted tokens of shape `(batch_size, max_sequence_length)`.
    gold_targets : `torch.LongTensor`
        Batched reference (gold) sequences with shape `(batch_size, max_gold_sequence_length)`.

    # Returns

    None
    """
    predictions, gold_targets = self.detach_tensors(predictions, gold_targets)

    # ROUGE-N for every n-gram order from 1 up to the configured size.
    for order in range(1, self._ngram_size + 1):
        batch_recall, batch_precision, batch_f1 = self._get_rouge_n_stats(
            predictions, gold_targets, order
        )
        self._total_rouge_n_recalls[order] += batch_recall
        self._total_rouge_n_precisions[order] += batch_precision
        self._total_rouge_n_f1s[order] += batch_f1

    # ROUGE-L
    self._total_rouge_l_f1 += self._get_rouge_l_score(predictions, gold_targets)

    self._total_sequence_count += dist_reduce_sum(len(predictions))
def __call__(
    self,  # type: ignore
    predictions: torch.LongTensor,
    gold_targets: torch.LongTensor,
) -> None:
    """
    Update the modified-precision counts and sequence-length totals with one batch.

    # Parameters

    predictions : `torch.LongTensor`, required
        Batched predicted tokens of shape `(batch_size, max_sequence_length)`.
    gold_targets : `torch.LongTensor`, required
        Batched reference (gold) translations with shape `(batch_size, max_gold_sequence_length)`.

    # Returns

    None
    """
    predictions, gold_targets = self.detach_tensors(predictions, gold_targets)

    world_size = dist.get_world_size() if is_distributed() else 1

    for ngram_size, _ in enumerate(self._ngram_weights, start=1):
        matches, totals = self._get_modified_precision_counts(
            predictions, gold_targets, ngram_size
        )
        # The reduced counts are divided by the world size; the length totals
        # further down are accumulated without that division.
        self._precision_matches[ngram_size] += dist_reduce_sum(matches) / world_size
        self._precision_totals[ngram_size] += dist_reduce_sum(totals) / world_size

    if self._exclude_indices:
        # Only tokens that are not excluded count towards the lengths.
        from allennlp.training.util import get_valid_tokens_mask

        valid_predictions_mask = get_valid_tokens_mask(predictions, self._exclude_indices)
        valid_gold_targets_mask = get_valid_tokens_mask(gold_targets, self._exclude_indices)
        _prediction_lengths = valid_predictions_mask.sum().item()
        _reference_lengths = valid_gold_targets_mask.sum().item()
    else:
        # Nothing is excluded: every slot of the first two dimensions counts.
        _prediction_lengths = predictions.size(0) * predictions.size(1)
        _reference_lengths = gold_targets.size(0) * gold_targets.size(1)

    self._prediction_lengths += dist_reduce_sum(_prediction_lengths)
    self._reference_lengths += dist_reduce_sum(_reference_lengths)
def _get_rouge_n_stats(
    self,
    predicted_tokens: torch.LongTensor,
    reference_tokens: torch.LongTensor,
    ngram_size: int,
) -> Tuple[float, float, float]:
    """
    Compare the predicted tokens to the reference (gold) tokens at the desired
    ngram size and compute recall, precision and f1 sums.

    # Returns

    `Tuple[float, float, float]`
        The batch sums of recall, precision and F1 (in that order), each passed
        through `dist_reduce_sum`. Pairs with no n-grams on either side, or
        with no matching n-grams, contribute nothing.
    """
    # Kept at function scope as in the original, but resolved once per call
    # instead of once per sequence.
    from allennlp.training.util import ngrams

    total_recall = 0.0
    total_precision = 0.0
    total_f1 = 0.0

    for predicted_seq, reference_seq in zip(predicted_tokens, reference_tokens):
        predicted_ngram_counts = ngrams(predicted_seq, ngram_size, self._exclude_indices)
        reference_ngram_counts = ngrams(reference_seq, ngram_size, self._exclude_indices)

        # Clipped matches: a reference n-gram is matched at most as many
        # times as it appears in the prediction.
        matches = sum(
            min(predicted_ngram_counts[ngram], count)
            for ngram, count in reference_ngram_counts.items()
        )
        total_reference_ngrams = sum(reference_ngram_counts.values())
        total_predicted_ngrams = sum(predicted_ngram_counts.values())

        # Skipping here also avoids dividing by zero below.
        if total_reference_ngrams == 0 or total_predicted_ngrams == 0 or matches == 0:
            continue

        recall = matches / total_reference_ngrams
        precision = matches / total_predicted_ngrams
        f1 = 2.0 * recall * precision / (recall + precision)

        # Accumulate stats
        total_recall += recall
        total_precision += precision
        total_f1 += f1

    total_recall = dist_reduce_sum(total_recall)
    total_precision = dist_reduce_sum(total_precision)
    total_f1 = dist_reduce_sum(total_f1)

    return total_recall, total_precision, total_f1
def __call__(self, nli_probabilities: torch.Tensor) -> None:
    """
    Accumulate statistics about the neutral-label probabilities of one batch.

    # Parameters

    !!! Note
        In the examples below, we treat gender identity as binary, which does not accurately
        characterize gender in real life.

    nli_probabilities : `torch.Tensor`, required.
        A tensor of size (batch_size, ..., 3) containing natural language inference
        (i.e. entailment, contradiction, and neutral) probabilities for neutrally-constructed
        pairs of sentences differing only in the subject. For example, if the concept is gender,
        nli_probabilities could contain the natural language inference probabilities of:

        - "The driver owns a cabinet." -> "The man owns a cabinet."

        - "The driver owns a cabinet." -> "The woman owns a cabinet."

        - "The doctor eats an apple." -> "The man eats an apple."

        - "The doctor eats an apple." -> "The woman eats an apple."
    """
    nli_probabilities = nli_probabilities.detach()

    # Some sanity checks
    shape = nli_probabilities.size()
    if nli_probabilities.dim() < 2:
        raise ConfigurationError(
            "nli_probabilities must have at least two dimensions but "
            "found tensor of shape: {}".format(shape)
        )
    if nli_probabilities.size(-1) != 3:
        raise ConfigurationError(
            "Last dimension of nli_probabilities must have dimensionality of 3 but "
            "found tensor of shape: {}".format(shape)
        )

    _nli_neutral_probs = nli_probabilities[..., self.neutral_label]

    # Sum of the probabilities assigned to the neutral label.
    neutral_prob_sum = _nli_neutral_probs.sum().item()
    self._nli_probs_sum += dist_reduce_sum(neutral_prob_sum)

    # Number of pairs whose most probable label is the neutral one.
    predicted_neutral = nli_probabilities.argmax(dim=-1) == self.neutral_label
    self._num_neutral_predictions += dist_reduce_sum(predicted_neutral.float().sum().item())

    # Number of pairs whose neutral probability exceeds each threshold.
    for tau in self.taus:
        above_tau = (_nli_neutral_probs > tau).float().sum().item()
        self._num_neutral_above_taus[tau] += dist_reduce_sum(above_tau)

    self._total_predictions += dist_reduce_sum(_nli_neutral_probs.numel())
def __call__(
    self,  # type: ignore
    batched_top_spans: torch.Tensor,
    batched_metadata: List[Dict[str, Any]],
):
    """
    Accumulate how many gold mentions exist and how many appear among the top spans.

    For each instance the gold mentions are the union of all mentions in
    `metadata["clusters"]`; a gold mention is recalled when the same
    (first, second) pair appears in that instance's top spans.
    """
    gold_total = 0
    recalled_total = 0

    paired = zip(batched_top_spans.tolist(), batched_metadata)
    for span_list, instance_metadata in paired:
        gold_mentions: Set[Tuple[int, int]] = set()
        for cluster in instance_metadata["clusters"]:
            for mention in cluster:
                gold_mentions.add(mention)

        predicted_spans: Set[Tuple[int, int]] = set()
        for span in span_list:
            predicted_spans.add((span[0], span[1]))

        gold_total += len(gold_mentions)
        recalled_total += len(gold_mentions & predicted_spans)

    self._num_gold_mentions += dist_reduce_sum(gold_total)
    self._num_recalled_mentions += dist_reduce_sum(recalled_total)
def __call__(  # type: ignore
    self,
    predicted_indices: torch.Tensor,
    predicted_labels: torch.Tensor,
    gold_indices: torch.Tensor,
    gold_labels: torch.Tensor,
    mask: Optional[torch.BoolTensor] = None,
):
    """
    Accumulate labeled/unlabeled attachment counts and exact-match counts.

    # Parameters

    predicted_indices : `torch.Tensor`, required.
        A tensor of head index predictions of shape (batch_size, timesteps).
    predicted_labels : `torch.Tensor`, required.
        A tensor of arc label predictions of shape (batch_size, timesteps).
    gold_indices : `torch.Tensor`, required.
        A tensor of the same shape as `predicted_indices`.
    gold_labels : `torch.Tensor`, required.
        A tensor of the same shape as `predicted_labels`.
    mask : `torch.BoolTensor`, optional (default = `None`).
        A tensor of the same shape as `predicted_indices`.
    """
    predicted_indices, predicted_labels, gold_indices, gold_labels, mask = self.detach_tensors(
        predicted_indices, predicted_labels, gold_indices, gold_labels, mask
    )

    if mask is None:
        mask = torch.ones_like(predicted_indices).bool()

    predicted_indices = predicted_indices.long()
    predicted_labels = predicted_labels.long()
    gold_indices = gold_indices.long()
    gold_labels = gold_labels.long()

    # Positions whose gold label is one of the ignored classes are masked out too.
    for ignored_class in self._ignore_classes:
        mask = mask & ~gold_labels.eq(ignored_class)

    masked_out = ~mask

    correct_indices = predicted_indices.eq(gold_indices).long() * mask
    # Adding the inverted mask turns masked positions into 1 so they cannot
    # zero out the per-sentence product.
    unlabeled_exact_match = (correct_indices + masked_out).prod(dim=-1)

    correct_labels = predicted_labels.eq(gold_labels).long() * mask
    correct_labels_and_indices = correct_indices * correct_labels
    labeled_exact_match = (correct_labels_and_indices + masked_out).prod(dim=-1)

    total_sentences = correct_indices.size(0)
    total_words = correct_indices.numel() - masked_out.sum()

    self._unlabeled_correct += dist_reduce_sum(correct_indices).sum()
    self._exact_unlabeled_correct += dist_reduce_sum(unlabeled_exact_match).sum()
    self._labeled_correct += dist_reduce_sum(correct_labels_and_indices).sum()
    self._exact_labeled_correct += dist_reduce_sum(labeled_exact_match).sum()
    self._total_sentences += dist_reduce_sum(total_sentences)
    self._total_words += dist_reduce_sum(total_words)
def __call__(self, prediction: Union[str, List], ground_truths: List):  # type: ignore
    """
    Score one prediction against all of its annotations and update the totals.

    Parameters
    ----------
    prediction: ``Union[str, List]``
        The predicted answer from the model evaluated. This could be a string, or a list of
        string when multiple spans are predicted as answer.
    ground_truths: ``List``
        All the ground truth answer annotations.
    """
    # Only element [0] of each converted annotation is kept; element [1] could
    # be used instead to split the metric out by answer type.
    gold_answer_strings = []
    for annotation in ground_truths:
        gold_answer_strings.append(answer_json_to_strings(annotation)[0])

    best_em, best_f1 = metric_max_over_ground_truths(
        drop_em_and_f1, prediction, gold_answer_strings
    )

    # The exact-match value is cast to int because it is a count of matches.
    self._total_em += dist_reduce_sum(int(best_em))
    self._total_f1 += dist_reduce_sum(best_f1)
    self._count += dist_reduce_sum(1)
def __call__(
    self,
    predictions: torch.Tensor,
    gold_labels: torch.Tensor,
    mask: Optional[torch.BoolTensor] = None,
):
    """
    Accumulate the number of correct (top-k or tie-broken) predictions and the
    number of scored positions.

    # Parameters

    predictions : `torch.Tensor`, required.
        A tensor of predictions of shape (batch_size, ..., num_classes).
    gold_labels : `torch.Tensor`, required.
        A tensor of integer class label of shape (batch_size, ...). It must be the same
        shape as the `predictions` tensor without the `num_classes` dimension.
    mask : `torch.BoolTensor`, optional (default = `None`).
        A masking tensor the same size as `gold_labels`.

    # Raises

    ConfigurationError
        If `gold_labels` does not have one dimension fewer than `predictions`,
        or if any gold label is `>=` the number of classes.
    """
    predictions, gold_labels, mask = self.detach_tensors(predictions, gold_labels, mask)

    # Some sanity checks.
    num_classes = predictions.size(-1)
    if gold_labels.dim() != predictions.dim() - 1:
        # The check is on the number of dimensions and the offending tensor is
        # gold_labels, so the message reports exactly that.
        raise ConfigurationError(
            "gold_labels must have dimension == predictions.dim() - 1 but "
            "found tensor of shape: {}".format(gold_labels.size())
        )
    if (gold_labels >= num_classes).any():
        raise ConfigurationError(
            "A gold label passed to Categorical Accuracy contains an id >= {}, "
            "the number of classes.".format(num_classes)
        )

    # Flatten to (rows, num_classes) and (rows,).
    predictions = predictions.view((-1, num_classes))
    gold_labels = gold_labels.view(-1).long()

    if not self._tie_break:
        # Top K indexes of the predictions (or fewer, if there aren't K of them).
        # Special case topk == 1, because it's common and .max() is much faster than .topk().
        if self._top_k == 1:
            top_k = predictions.max(-1)[1].unsqueeze(-1)
        else:
            _, sorted_indices = predictions.sort(dim=-1, descending=True)
            top_k = sorted_indices[..., : min(self._top_k, predictions.shape[-1])]

        # Shape (rows, top_k): 1.0 where a top-k index equals the gold label.
        correct = top_k.eq(gold_labels.unsqueeze(-1)).float()
    else:
        # prediction is correct if gold label falls on any of the max scores.
        # distribute score by tie_counts
        max_predictions = predictions.max(-1)[0]
        max_predictions_mask = predictions.eq(max_predictions.unsqueeze(-1))
        # max_predictions_mask is (rows X num_classes) and gold_labels is (rows).
        # The ith entry in gold_labels is the class index for the ith row, so
        # this picks, per row, whether the gold class is among the max-scored classes.
        row_indices = torch.arange(gold_labels.numel(), device=gold_labels.device).long()
        correct = max_predictions_mask[row_indices, gold_labels].float()
        tie_counts = max_predictions_mask.sum(-1)
        correct /= tie_counts.float()
        correct.unsqueeze_(-1)

    if mask is not None:
        correct *= mask.view(-1, 1)
        _total_count = mask.sum()
    else:
        _total_count = torch.tensor(gold_labels.numel())
    _correct_count = correct.sum()

    self.correct_count += dist_reduce_sum(_correct_count).item()
    self.total_count += dist_reduce_sum(_total_count).item()
def __call__(self, predicted_trees: List[Tree], gold_trees: List[Tree]) -> None:  # type: ignore
    """
    Score predicted trees against gold trees with the external EVALB program
    and accumulate the bracket counts it reports.

    # Parameters

    predicted_trees : `List[Tree]`
        A list of predicted NLTK Trees to compute score for.
    gold_trees : `List[Tree]`
        A list of gold NLTK Trees to use as a reference.

    # Raises

    ConfigurationError
        If the EVALB executable is missing and is still missing after an
        attempt to compile it.
    subprocess.CalledProcessError
        If EVALB exits with a non-zero status (the process is run with `check=True`).
    """
    if not os.path.exists(self._evalb_program_path):
        logger.warning(
            f"EVALB not found at {self._evalb_program_path}. Attempting to compile it."
        )
        EvalbBracketingScorer.compile_evalb(self._evalb_directory_path)

        # If EVALB executable still doesn't exist, raise an error.
        if not os.path.exists(self._evalb_program_path):
            compile_command = (
                f"python -c 'from allennlp.training.metrics import EvalbBracketingScorer; "
                f'EvalbBracketingScorer.compile_evalb("{self._evalb_directory_path}")\''
            )
            # 'make' has to be run in the EVALB directory, so the message
            # shows the directory path rather than the executable path.
            raise ConfigurationError(
                f"EVALB still not found at {self._evalb_program_path}. "
                "You must compile the EVALB scorer before using it."
                " Run 'make' in the '{}' directory or run: {}".format(
                    self._evalb_directory_path, compile_command
                )
            )

    tempdir = tempfile.mkdtemp()
    try:
        gold_path = os.path.join(tempdir, "gold.txt")
        predicted_path = os.path.join(tempdir, "predicted.txt")

        # One tree per line; the huge margin keeps pformat from wrapping a tree.
        with open(gold_path, "w") as gold_file:
            for tree in gold_trees:
                gold_file.write(f"{tree.pformat(margin=1000000)}\n")

        with open(predicted_path, "w") as predicted_file:
            for tree in predicted_trees:
                predicted_file.write(f"{tree.pformat(margin=1000000)}\n")

        command = [
            self._evalb_program_path,
            "-p",
            self._evalb_param_path,
            "-e",
            str(self._evalb_num_errors_to_kill),
            gold_path,
            predicted_path,
        ]
        completed_process = subprocess.run(
            command, stdout=subprocess.PIPE, universal_newlines=True, check=True
        )
    finally:
        # Always remove the scratch directory, including when writing the
        # files or running EVALB raises; only the captured stdout is needed.
        shutil.rmtree(tempdir)

    _correct_predicted_brackets = 0.0
    _gold_brackets = 0.0
    _predicted_brackets = 0.0

    for line in completed_process.stdout.split("\n"):
        stripped = line.strip().split()
        if len(stripped) == 12 and stripped != self._header_line:
            # This line contains results for a single tree.
            numeric_line = [float(x) for x in stripped]
            _correct_predicted_brackets += numeric_line[5]
            _gold_brackets += numeric_line[6]
            _predicted_brackets += numeric_line[7]

    self._correct_predicted_brackets += dist_reduce_sum(_correct_predicted_brackets)
    self._gold_brackets += dist_reduce_sum(_gold_brackets)
    self._predicted_brackets += dist_reduce_sum(_predicted_brackets)
def __call__(
    self,
    predictions: torch.Tensor,
    gold_labels: torch.Tensor,
    mask: Optional[torch.BoolTensor] = None,
):
    """
    Accumulate per-class true-positive, predicted and gold counts for
    thresholded multi-label predictions.

    # Parameters

    predictions : `torch.Tensor`, required.
        A tensor of predictions of shape (batch_size, ..., num_classes).
    gold_labels : `torch.Tensor`, required.
        A tensor of boolean labels of shape (batch_size, ..., num_classes). It must be the same
        shape as the `predictions`.
    mask : `torch.BoolTensor`, optional (default = `None`).
        A masking tensor the same size as `gold_labels`.
    """
    predictions, gold_labels, mask = self.detach_tensors(predictions, gold_labels, mask)

    num_classes = predictions.size(-1)
    device = predictions.device

    def _count_per_class(bins: torch.Tensor) -> torch.Tensor:
        # An empty selection (e.g. when the mask is all 0) is answered with
        # zeros directly instead of being handed to bincount.
        if bins.shape[0] == 0:
            return torch.zeros(num_classes, device=device)
        return torch.bincount(bins.long(), minlength=num_classes).float()

    # The accumulators are created lazily on the first call, once the number
    # of classes is known.
    if self._true_positive_sum is None:
        self._true_positive_sum = torch.zeros(num_classes, device=device)
        self._true_sum = torch.zeros(num_classes, device=device)
        self._pred_sum = torch.zeros(num_classes, device=device)
        self._total_sum = torch.zeros(num_classes, device=device)

    if mask is None:
        mask = torch.ones_like(gold_labels, dtype=torch.bool)
    gold_labels = gold_labels.float()

    # If the prediction tensor is all zeros, the record is not classified to any of the labels.
    pred_mask = (predictions.sum(dim=-1) != 0).unsqueeze(-1)
    threshold_predictions = (predictions >= self._threshold).float()

    # One row of class ids per entry of the first dimension of gold_labels.
    class_ids = torch.arange(num_classes, device=device)
    class_indices = class_ids.unsqueeze(0).repeat(gold_labels.size(0), 1)

    valid_predictions = mask & pred_mask

    true_positives = (gold_labels * threshold_predictions).bool() & mask & pred_mask
    true_positive_sum = _count_per_class(class_indices[true_positives])

    pred_sum = _count_per_class(class_indices[threshold_predictions.bool() & valid_predictions])

    true_sum = _count_per_class(class_indices[gold_labels.bool() & mask])

    self._total_sum += mask.expand_as(gold_labels).sum().to(torch.float)

    self._true_positive_sum += dist_reduce_sum(true_positive_sum)
    self._pred_sum += dist_reduce_sum(pred_sum)
    self._true_sum += dist_reduce_sum(true_sum)
def __call__(
    self,
    predictions: torch.Tensor,
    gold_labels: torch.Tensor,
    mask: Optional[torch.BoolTensor] = None,
):
    """
    Accumulate per-class true-positive, predicted and gold counts from the
    argmax of the predictions.

    # Parameters

    predictions : `torch.Tensor`, required.
        A tensor of predictions of shape (batch_size, ..., num_classes).
    gold_labels : `torch.Tensor`, required.
        A tensor of integer class label of shape (batch_size, ...). It must be the same
        shape as the `predictions` tensor without the `num_classes` dimension.
    mask : `torch.BoolTensor`, optional (default = `None`).
        A masking tensor the same size as `gold_labels`.

    # Raises

    ConfigurationError
        If any gold label is `>=` the number of classes.
    """
    predictions, gold_labels, mask = self.detach_tensors(predictions, gold_labels, mask)

    # Calculate true_positive_sum, true_negative_sum, pred_sum, true_sum
    num_classes = predictions.size(-1)
    if (gold_labels >= num_classes).any():
        raise ConfigurationError(
            "A gold label passed to FBetaMeasure contains "
            f"an id >= {num_classes}, the number of classes."
        )

    # It means we call this metric at the first time
    # when `self._true_positive_sum` is None.
    if self._true_positive_sum is None:
        self._true_positive_sum = torch.zeros(num_classes, device=predictions.device)
        self._true_sum = torch.zeros(num_classes, device=predictions.device)
        self._pred_sum = torch.zeros(num_classes, device=predictions.device)
        self._total_sum = torch.zeros(num_classes, device=predictions.device)

    if mask is None:
        mask = torch.ones_like(gold_labels).bool()
    gold_labels = gold_labels.float()

    # If the prediction tensor is all zeros, the record is not classified to any of the labels.
    pred_mask = predictions.sum(dim=-1) != 0
    argmax_predictions = predictions.max(dim=-1)[1].float()

    true_positives = (gold_labels == argmax_predictions) & mask & pred_mask
    true_positives_bins = gold_labels[true_positives]

    # Watch it:
    # The total numbers of true positives under all _predicted_ classes are zeros.
    if true_positives_bins.shape[0] == 0:
        true_positive_sum = torch.zeros(num_classes, device=predictions.device)
    else:
        true_positive_sum = torch.bincount(
            true_positives_bins.long(), minlength=num_classes
        ).float()

    pred_bins = argmax_predictions[mask & pred_mask].long()
    # Watch it:
    # When the `mask` is all 0, we will get an _empty_ tensor.
    if pred_bins.shape[0] != 0:
        pred_sum = torch.bincount(pred_bins, minlength=num_classes).float()
    else:
        pred_sum = torch.zeros(num_classes, device=predictions.device)

    gold_labels_bins = gold_labels[mask].long()
    # Guard on the *selected* bins, like the two checks above: a non-empty
    # batch with an all-0 mask still selects nothing.
    if gold_labels_bins.shape[0] != 0:
        true_sum = torch.bincount(gold_labels_bins, minlength=num_classes).float()
    else:
        true_sum = torch.zeros(num_classes, device=predictions.device)

    self._total_sum += mask.sum().to(torch.float)

    self._true_positive_sum += dist_reduce_sum(true_positive_sum)
    self._pred_sum += dist_reduce_sum(pred_sum)
    self._true_sum += dist_reduce_sum(true_sum)