def _get_rouge_l_score(
    self, predicted_tokens: torch.LongTensor, reference_tokens: torch.LongTensor
) -> float:
    """
    Compute the sum of ROUGE-L F1 scores over a batch of predictions and references.

    # Parameters

    predicted_tokens : `torch.LongTensor`
        Batched predicted token sequences.
    reference_tokens : `torch.LongTensor`
        Batched reference (gold) token sequences.

    # Returns

    The sum (not the mean) of the per-sequence ROUGE-L F1 scores.
    """
    # Hoisted out of the loop: the import is loop-invariant. It is kept
    # function-local (rather than at module top) — presumably to avoid a
    # circular import at module load time, matching the style used elsewhere
    # in this file.
    from allennlp.training.util import get_valid_tokens_mask

    total_f1 = 0.0
    for predicted_seq, reference_seq in zip(predicted_tokens, reference_tokens):
        # m / n are the counts of valid (non-excluded) tokens in the
        # reference and the prediction, respectively.
        m = get_valid_tokens_mask(reference_seq, self._exclude_indices).sum().item()
        n = get_valid_tokens_mask(predicted_seq, self._exclude_indices).sum().item()

        lcs = self._longest_common_subsequence(reference_seq, predicted_seq)

        # This also rules out the case that m or n are 0, so the divisions
        # below are safe.
        if lcs == 0:
            continue

        recall_lcs = lcs / m
        precision_lcs = lcs / n
        f1 = 2 * recall_lcs * precision_lcs / (recall_lcs + precision_lcs)
        total_f1 += f1

    return total_f1
def __call__(
    self,  # type: ignore
    predictions: torch.LongTensor,
    gold_targets: torch.LongTensor,
) -> None:
    """
    Update precision counts.

    # Parameters

    predictions : `torch.LongTensor`, required
        Batched predicted tokens of shape `(batch_size, max_sequence_length)`.
    gold_targets : `torch.LongTensor`, required
        Batched reference (gold) translations with shape
        `(batch_size, max_gold_sequence_length)`.

    # Returns

    None
    """
    predictions, gold_targets = self.detach_tensors(predictions, gold_targets)
    device = gold_targets.device
    if is_distributed():
        world_size = dist.get_world_size()

    # Accumulate modified n-gram precision counts for each n-gram order.
    for ngram_size, _ in enumerate(self._ngram_weights, start=1):
        precision_matches, precision_totals = self._get_modified_precision_counts(
            predictions, gold_targets, ngram_size
        )
        if is_distributed():
            # Sum this batch's counts across workers, then divide by world
            # size so each worker accumulates the average per-worker count.
            # NOTE(review): averaging (rather than keeping the summed total)
            # looks intentional — confirm against how the final metric is
            # aggregated at compute time.
            _precision_matches = torch.tensor(precision_matches).to(device)
            _precision_totals = torch.tensor(precision_totals).to(device)
            dist.all_reduce(_precision_matches, op=dist.ReduceOp.SUM)
            dist.all_reduce(_precision_totals, op=dist.ReduceOp.SUM)
            precision_matches = _precision_matches.item() / world_size
            precision_totals = _precision_totals.item() / world_size

        self._precision_matches[ngram_size] += precision_matches
        self._precision_totals[ngram_size] += precision_totals

    if not self._exclude_indices:
        # No excluded token ids: every position counts toward the lengths.
        self._prediction_lengths += predictions.size(0) * predictions.size(1)
        self._reference_lengths += gold_targets.size(0) * gold_targets.size(1)
    else:
        # Count only tokens whose ids are not in the excluded set
        # (presumably padding / start / end tokens — see self._exclude_indices).
        from allennlp.training.util import get_valid_tokens_mask

        valid_predictions_mask = get_valid_tokens_mask(predictions, self._exclude_indices)
        self._prediction_lengths += valid_predictions_mask.sum().item()
        valid_gold_targets_mask = get_valid_tokens_mask(gold_targets, self._exclude_indices)
        self._reference_lengths += valid_gold_targets_mask.sum().item()

    if is_distributed():
        # NOTE(review): this all_reduces the *accumulated* running totals on
        # every call, not just this batch's increment — verify that repeated
        # calls do not over-count lengths across workers.
        _prediction_lengths = torch.tensor(self._prediction_lengths).to(device)
        _reference_lengths = torch.tensor(self._reference_lengths).to(device)
        dist.all_reduce(_prediction_lengths, op=dist.ReduceOp.SUM)
        dist.all_reduce(_reference_lengths, op=dist.ReduceOp.SUM)
        self._prediction_lengths = _prediction_lengths.item()
        self._reference_lengths = _reference_lengths.item()
def __call__(self, predictions, gold_targets):
    """
    Accumulate modified n-gram precision counts and running token-length
    totals for a batch of predictions against their gold targets.
    """
    # Zero out everything after the stop token before counting n-grams.
    predictions = mask_after_stop(predictions, stop_token=2)

    # Update the per-order precision counters.
    for order, _ in enumerate(self._ngram_weights, start=1):
        matches, totals = self._get_modified_precision_counts(
            predictions, gold_targets, order
        )
        self._precision_matches[order] += matches
        self._precision_totals[order] += totals

    if self._exclude_indices:
        # Only tokens outside the excluded id set count toward the lengths.
        pred_mask = get_valid_tokens_mask(predictions, self._exclude_indices)
        gold_mask = get_valid_tokens_mask(gold_targets, self._exclude_indices)
        pred_length = pred_mask.sum().item()
        gold_length = gold_mask.sum().item()
    else:
        # Nothing excluded: every position in the batch counts.
        pred_length = predictions.size(0) * predictions.size(1)
        gold_length = gold_targets.size(0) * gold_targets.size(1)

    self._prediction_lengths += pred_length
    self._reference_lengths += gold_length
def test_get_valid_tokens_mask(self, device: str):
    """The valid-token mask should be 0 exactly at excluded token ids."""
    source = torch.tensor([[1, 2, 3, 0], [0, 1, 1, 0]], device=device)
    mask = get_valid_tokens_mask(source, self.metric._exclude_indices).long()
    expected = torch.tensor([[1, 1, 1, 0], [0, 1, 1, 0]], device=device)
    assert_allclose(mask, expected)