def get_ccm_labels(
    self,
    output_dict: Dict[str, torch.Tensor],
    partial_labels: Optional[List[List[Tuple[int, int]]]] = None
) -> List[List[int]]:
    """Decode tag sequences with the CCM decoder.

    Unwraps the model logits/mask and the CRF transition parameters,
    converts any tensors to numpy arrays, and delegates the constrained
    decoding to ``self._ccm_decoder.ccm_tags``.
    """
    def as_numpy(value):
        # Values that are still torch tensors after unwrapping are
        # converted; non-tensor values (e.g. None transitions) pass through.
        return value.numpy() if isinstance(value, torch.Tensor) else value

    # Some CRF implementations have no start/end transition parameters;
    # fall back to None so the decoder can handle their absence.
    crf = self.crf
    raw_start = getattr(crf, "start_transitions", None)
    raw_end = getattr(crf, "end_transitions", None)

    unwrapped = Metric.unwrap_to_tensors(
        output_dict["logits"], output_dict["mask"],
        crf.transitions, raw_start, raw_end)
    logits, mask, transitions, start_transitions, end_transitions = (
        as_numpy(value) for value in unwrapped)

    return self._ccm_decoder.ccm_tags(
        logits=logits,
        mask=mask,
        transitions=transitions,
        start_transitions=start_transitions,
        end_transitions=end_transitions,
        partial_labels=partial_labels,
        sentence_boundaries=output_dict["sentence_markers"])
# Example #2
    def update_confusion_matrices(self, predictions, gold_labels):
        """Accumulate chord, key, and chord-type confusion-matrix counts.

        Updates three matrices in place:
          * ``self.chord_cm`` -- full-vocabulary chord confusion counts,
          * ``self.key_cm``   -- key confusion counts, indexed via ``self.key_list``,
          * ``self.type_cm``  -- chord-type (form + figured bass) counts,
            indexed via ``self.type_list``.

        Parameters
        ----------
        predictions : torch.Tensor
            Class scores of shape ``(..., num_classes)``.
        gold_labels : torch.Tensor
            Integer label indices of shape ``predictions.shape[:-1]``.
            Index 0 is skipped — presumably padding; TODO confirm against
            the vocabulary construction.

        Raises
        ------
        ConfigurationError
            If the shapes are inconsistent, or a gold label id is
            ``>= num_classes``.
        """
        # NOTE: the old code also built `mask = gold_labels > 0` and unwrapped
        # it, but never used it afterward — dropped as dead work.
        predictions, gold_labels = Metric.unwrap_to_tensors(
            predictions, gold_labels)
        num_classes = predictions.size(-1)
        if gold_labels.dim() != predictions.dim() - 1:
            # Fixed message: it used to claim "dimension == predictions.size() - 1"
            # (size is a shape; dim() was meant) and reported only the
            # predictions shape. Report both shapes so the mismatch is visible.
            raise ConfigurationError(
                "gold_labels must have dimension == predictions.dim() - 1 but "
                "found shapes: predictions {}, gold_labels {}".format(
                    predictions.size(), gold_labels.size()))
        if (gold_labels >= num_classes).any():
            raise ConfigurationError(
                "A gold label passed to Categorical Accuracy contains an id >= {}, "
                "the number of classes.".format(num_classes))

        predictions = predictions.view((-1, num_classes))
        gold_labels = gold_labels.view(-1).long()
        # Argmax over the class dimension ("top_k" in the old code was top-1).
        predicted_indices = predictions.max(-1)[1]

        vocab = self.model.vocab

        for gold_label, pred in zip(gold_labels.tolist(),
                                    predicted_indices.tolist()):
            if gold_label == 0:
                continue  # skip padding positions
            self.chord_cm[gold_label][pred] += 1

            gold_token = vocab.get_token_from_index(gold_label)
            pred_token = vocab.get_token_from_index(pred)

            gold_key, gold_form, gold_figbass = parse_chord_name_core(
                gold_token)
            pred_key, pred_form, pred_figbass = parse_chord_name_core(
                pred_token)

            # The end-of-sequence sentinel does not parse as a chord name;
            # route it through as its own "key".
            if gold_key is None and gold_token == "@end@":
                gold_key = "@end@"
            if pred_key is None and pred_token == "@end@":
                pred_key = "@end@"

            if gold_key in self.key_list and pred_key in self.key_list:
                gold_key_idx = self.key_list.index(gold_key)
                pred_key_idx = self.key_list.index(pred_key)
                self.key_cm[gold_key_idx][pred_key_idx] += 1
            else:
                # Unrecognized key: surface for inspection rather than crash.
                print((gold_token, gold_key), (pred_token, pred_key))

            gold_type = self._chord_type(gold_key, gold_form, gold_figbass)
            pred_type = self._chord_type(pred_key, pred_form, pred_figbass)

            if gold_type in self.type_list and pred_type in self.type_list:
                gold_type_idx = self.type_list.index(gold_type)
                pred_type_idx = self.type_list.index(pred_type)
                self.type_cm[gold_type_idx][pred_type_idx] += 1
            else:
                # Unrecognized chord type: surface for inspection.
                print((gold_token, gold_type), (pred_token, pred_type))

    @staticmethod
    def _chord_type(key, form, figbass):
        """Combine form and figured bass into a chord-type label.

        The ``@end@`` sentinel has no form/figbass decomposition and is
        passed through unchanged; missing components default to "".
        """
        if key == "@end@":
            return "@end@"
        form = form if form is not None else ""
        figbass = figbass if figbass is not None else ""
        return form + figbass