def get_ccm_labels(
    self,
    output_dict: Dict[str, torch.Tensor],
    partial_labels: Optional[List[List[Tuple[int, int]]]] = None,
) -> List[List[int]]:
    """Compute CCM labels by running ``self._ccm_decoder.ccm_tags``.

    The logits and mask from ``output_dict`` and the CRF transition
    parameters (``transitions`` plus the optional ``start_transitions`` /
    ``end_transitions``) are passed through ``Metric.unwrap_to_tensors``.
    Every resulting ``torch.Tensor`` is converted with ``.numpy()``; values
    that are not tensors (e.g. ``None`` for a missing boundary transition)
    are forwarded unchanged.

    Args:
        output_dict: model output; must provide ``"logits"``, ``"mask"`` and
            ``"sentence_markers"`` (the last is forwarded to the decoder as
            ``sentence_boundaries``).
        partial_labels: optional constraints forwarded unchanged to the
            decoder. The meaning of each ``(int, int)`` pair is defined by
            the decoder — TODO confirm.

    Returns:
        The value returned by ``ccm_tags``; annotated as one list of label
        indices per instance (presumably — verify against the decoder).
    """
    crf = self.crf
    # Not every CRF implementation carries boundary transitions; use None
    # for whichever one is absent.
    start_param = getattr(crf, "start_transitions", None)
    end_param = getattr(crf, "end_transitions", None)

    unwrapped = Metric.unwrap_to_tensors(
        output_dict["logits"],
        output_dict["mask"],
        crf.transitions,
        start_param,
        end_param,
    )

    # The decoder works on NumPy arrays, so convert tensors and keep
    # everything else as-is.
    arrays = []
    for value in unwrapped:
        if isinstance(value, torch.Tensor):
            value = value.numpy()
        arrays.append(value)
    logits, mask, transitions, start_transitions, end_transitions = arrays

    return self._ccm_decoder.ccm_tags(
        logits=logits,
        mask=mask,
        transitions=transitions,
        start_transitions=start_transitions,
        end_transitions=end_transitions,
        partial_labels=partial_labels,
        sentence_boundaries=output_dict["sentence_markers"],
    )
def update_confusion_matrices(self, predictions, gold_labels):
    """Accumulate chord, key and chord-type confusion counts for one batch.

    ``predictions`` holds per-class scores whose last dimension is the number
    of classes. ``gold_labels`` has exactly one dimension fewer and holds
    vocabulary indices. Positions whose gold index is 0 are skipped
    (presumably the padding index — TODO confirm against the vocabulary).

    For every remaining position, the arg-max prediction is compared with the
    gold label and three matrices are updated in place:

    * ``self.chord_cm[gold_index][pred_index]``: raw label confusion.
    * ``self.key_cm``: indexed through ``self.key_list`` with the key
      returned by ``parse_chord_name_core``.
    * ``self.type_cm``: indexed through ``self.type_list`` with the
      concatenation of form and figured bass (missing parts count as "").

    The ``"@end@"`` token is treated as its own key and its own type. When a
    key (or type) of either side is not in the corresponding list, that pair
    is not counted for that matrix; it is printed instead.

    Raises:
        ConfigurationError: if ``gold_labels`` does not have exactly one
            dimension fewer than ``predictions``, or if it contains an index
            >= the number of classes.
    """

    def _key_and_type(token):
        # Map a vocabulary token to the (key, type) pair used to index
        # key_cm / type_cm. "@end@" maps to itself for both.
        key, form, figbass = parse_chord_name_core(token)
        if key is None and token == "@end@":
            key = "@end@"
        if key == "@end@":
            return key, "@end@"
        form = form if form is not None else ""
        figbass = figbass if figbass is not None else ""
        return key, form + figbass

    predictions, gold_labels = Metric.unwrap_to_tensors(predictions, gold_labels)

    num_classes = predictions.size(-1)
    if gold_labels.dim() != predictions.dim() - 1:
        raise ConfigurationError(
            "gold_labels must have dimension == predictions.dim() - 1 but "
            "found gold_labels of shape {} and predictions of shape {}".format(
                gold_labels.size(), predictions.size()))
    if (gold_labels >= num_classes).any():
        raise ConfigurationError(
            "A gold label passed to update_confusion_matrices contains an id >= {}, "
            "the number of classes.".format(num_classes))

    # Flatten to one entry per position. reshape (unlike view) also accepts
    # non-contiguous inputs. Converting to Python ints once avoids per-element
    # tensor indexing and .item() calls inside the loop.
    pred_indices = predictions.reshape(-1, num_classes).max(-1)[1].tolist()
    gold_indices = gold_labels.reshape(-1).long().tolist()

    vocab = self.model.vocab
    for gold_index, pred_index in zip(gold_indices, pred_indices):
        if gold_index == 0:
            continue
        self.chord_cm[gold_index][pred_index] += 1

        gold_token = vocab.get_token_from_index(gold_index)
        pred_token = vocab.get_token_from_index(pred_index)
        gold_key, gold_type = _key_and_type(gold_token)
        pred_key, pred_type = _key_and_type(pred_token)

        if gold_key in self.key_list and pred_key in self.key_list:
            gold_key_idx = self.key_list.index(gold_key)
            pred_key_idx = self.key_list.index(pred_key)
            self.key_cm[gold_key_idx][pred_key_idx] += 1
        else:
            # Key not tracked in key_list: report it rather than counting it.
            print((gold_token, gold_key), (pred_token, pred_key))

        if gold_type in self.type_list and pred_type in self.type_list:
            gold_type_idx = self.type_list.index(gold_type)
            pred_type_idx = self.type_list.index(pred_type)
            self.type_cm[gold_type_idx][pred_type_idx] += 1
        else:
            # Type not tracked in type_list: report it rather than counting it.
            print((gold_token, gold_type), (pred_token, pred_type))