Example no. 1
0
 def __init__(self, **kwargs: Dict[str, Any]):
     """
     Build one training and one validation `Metric` per keyword argument.

     Each keyword maps a metric name to its configuration dict. The
     configuration is consumed twice, so the training copy is built from a
     deep copy while the validation copy uses the dict directly.
     """
     self.training_metrics = {}
     self.validation_metrics = {}
     for metric_name, config in kwargs.items():
         # The vocabulary must not be deep-copied and cannot travel inside
         # Params, so it is pulled out and passed as an explicit kwarg.
         vocabulary = config.pop("vocabulary", None)
         extra = {} if vocabulary is None else {"vocabulary": vocabulary}
         self.training_metrics[metric_name] = Metric.from_params(
             Params(copy.deepcopy(config)), **extra)
         self.validation_metrics[metric_name] = Metric.from_params(
             Params(config), **extra)
 def test_span_f1_can_build_from_params(self):
     """A non-BIO span-F1 metric built via from_params picks up its config."""
     config = {"type": "non_bio_span_f1", "tag_namespace": "tags", "ignore_classes": ["V"]}
     metric = Metric.from_params(Params(config), self.vocab)
     assert metric._ignore_classes == ["V"]
     expected_vocab = self.vocab.get_index_to_token_vocabulary("tags")
     assert metric._label_vocabulary == expected_vocab
Example no. 3
0
 def test_span_f1_can_build_from_params(self, device: str):
     """A span_f1 metric built from Params exposes its ignore list and tag vocab."""
     config = {
         "type": "span_f1",
         "tag_namespace": "tags",
         "ignore_classes": ["V"],
     }
     metric = Metric.from_params(params=Params(config), vocabulary=self.vocab)
     assert metric._ignore_classes == ["V"]  # type: ignore
     index_to_token = self.vocab.get_index_to_token_vocabulary("tags")
     assert metric._label_vocabulary == index_to_token  # type: ignore
Example no. 4
0
def get_metric_name_value_pairs(metric: Metric, default_name: str, reset: bool = False) -> Iterable[Tuple[str, float]]:
    """
    Return the metric as in `Metric.get_metric` but as an iterable of string-float pairs.
    """
    result = metric.get_metric(reset)
    if isinstance(result, collections.abc.Mapping):
        # Flatten nested iterables with an index suffix; scalars keep their key.
        for key, entry in result.items():
            if isinstance(entry, collections.abc.Iterable):
                yield from ((f"{key}_{j}", item) for j, item in enumerate(entry))
            else:
                yield key, entry
    elif isinstance(result, collections.abc.Iterable):
        yield from ((f"{default_name}_{j}", item) for j, item in enumerate(result))  # type: ignore
    else:
        yield default_name, result
Example no. 5
0
 def from_params(cls, vocab, params):
     """
     Construct the model from a `Params` configuration.

     Pops every recognized key (so `assert_empty` can flag leftovers) and
     delegates sub-component construction to their own `from_params`.
     """
     embedder_params = params.pop('text_field_embedder')
     text_field_embedder = TextFieldEmbedder.from_params(vocab, embedder_params)
     hidden_size = params.pop('hidden_size', 128)
     num_layers = params.pop('num_layers', 2)
     dropout = params.pop('dropout', 0.5)
     tag_namespace = params.pop('tag_namespace', 'tags')
     # Optional components: only built when their config is present.
     initializer_params = params.pop('initializer', None)
     initializer = (Initializer.from_params(initializer_params)
                    if initializer_params is not None else None)
     metric_params = params.pop('metric', None)
     metric = (Metric.from_params(metric_params)
               if metric_params is not None else None)
     params.assert_empty(cls.__name__)
     return cls(vocab, text_field_embedder,
                hidden_size=hidden_size,
                num_layers=num_layers,
                dropout=dropout,
                tag_namespace=tag_namespace,
                initializer=initializer,
                metric=metric)
def global_distributed_metric(
    global_rank: int,
    world_size: int,
    gpu_id: Union[int, torch.device],
    metric: Metric,
    metric_kwargs: Dict[str, List[Any]],
    desired_values: Dict[str, Any],
    exact: Union[bool, Tuple[float, float]] = True,
    number_of_runs: int = 1,
):
    """
    Run `metric` as the process with rank `global_rank` and compare its value
    against `desired_values`.

    `metric_kwargs` maps each argument name to a list with one entry per rank;
    only the `global_rank` slice is used here.  `exact` is either a bool
    (True: bitwise equality; False: default loose tolerances) or an explicit
    `(rtol, atol)` pair.

    NOTE(review): `world_size` and `gpu_id` are unused in this body —
    presumably consumed by the process spawner that invokes this function.
    """
    # Use the arguments meant for the process with rank `global_rank`.
    kwargs = {name: values[global_rank] for name, values in metric_kwargs.items()}

    for _ in range(number_of_runs):
        metric(**kwargs)

    metrics = metric.get_metric(False)
    # Wrap scalar results so the comparison below always works on dicts.
    # (Fixed: isinstance against the builtin `dict`, not the `typing.Dict`
    # alias, which is deprecated in isinstance checks.)
    if not isinstance(metrics, dict) and not isinstance(desired_values, dict):
        metrics = {"metric_value": metrics}
        desired_values = {"metric_value": desired_values}

    # Map `exact` onto relative/absolute tolerances.
    if isinstance(exact, bool):
        rtol, atol = (0.0, 0.0) if exact else (0.0001, 1e-05)
    else:
        rtol, atol = exact[0], exact[1]

    assert_metrics_values(metrics, desired_values, rtol, atol)  # type: ignore
Example no. 7
0
    def get_ccm_labels(
        self,
        output_dict: Dict[str, torch.Tensor],
        partial_labels: Optional[List[List[Tuple[int, int]]]] = None
    ) -> List[List[int]]:
        """
        Decode CCM tag sequences from the model's `output_dict`.

        Detaches logits, mask, and the CRF transition parameters (start/end
        transitions may be absent on the CRF, in which case None is passed),
        converts tensors to numpy, and delegates to the CCM decoder.
        """
        # CRFs without constrained start/end simply lack these attributes.
        start_trans = getattr(self.crf, "start_transitions", None)
        end_trans = getattr(self.crf, "end_transitions", None)

        unwrapped = Metric.unwrap_to_tensors(
            output_dict["logits"], output_dict["mask"],
            self.crf.transitions, start_trans, end_trans)
        logits, mask, transitions, start_transitions, end_transitions = [
            t.numpy() if isinstance(t, torch.Tensor) else t for t in unwrapped
        ]

        return self._ccm_decoder.ccm_tags(
            logits=logits,
            mask=mask,
            transitions=transitions,
            start_transitions=start_transitions,
            end_transitions=end_transitions,
            partial_labels=partial_labels,
            sentence_boundaries=output_dict["sentence_markers"])
Example no. 8
0
    def update_confusion_matrices(self, predictions, gold_labels):
        """
        Update the chord, key, and chord-type confusion matrices from a batch.

        `predictions` has a trailing class dimension; `gold_labels` has one
        fewer dimension and holds class ids, with id 0 treated as padding and
        skipped.  Raises `ConfigurationError` on shape mismatch or
        out-of-range gold ids.
        """
        mask = gold_labels > 0
        # `mask` is detached alongside the other tensors; it is not used
        # further below but skipping padding is handled per-item in the loop.
        predictions, gold_labels, mask = Metric.unwrap_to_tensors(
            predictions, gold_labels, mask)
        num_classes = predictions.size(-1)
        if gold_labels.dim() != predictions.dim() - 1:
            # Fixed message: the check compares dimensions, not sizes.
            raise ConfigurationError(
                "gold_labels must have dimension == predictions.dim() - 1 but "
                "found tensor of shape: {}".format(predictions.size()))
        if (gold_labels >= num_classes).any():
            raise ConfigurationError(
                "A gold label passed to Categorical Accuracy contains an id >= {}, "
                "the number of classes.".format(num_classes))
        predictions = predictions.view((-1, num_classes))
        gold_labels = gold_labels.view(-1).long()

        # Top-1 prediction per position (argmax over the class dimension).
        top_k = predictions.max(-1)[1].unsqueeze(-1)
        gold_labels = gold_labels.unsqueeze(-1)

        vocab = self.model.vocab

        def chord_type_of(key, form, figbass):
            # The "@end@" sentinel gets its own type bucket; otherwise the
            # type is form + figured bass, each of which may be None.
            if key == "@end@":
                return "@end@"
            return ((form if form is not None else "")
                    + (figbass if figbass is not None else ""))

        for i, gold_label in enumerate(gold_labels):
            if gold_label == 0:
                continue  # padding position
            pred = top_k[i]
            gold_label, pred = gold_label.item(), pred.item()
            self.chord_cm[gold_label][pred] += 1

            gold_token = vocab.get_token_from_index(gold_label)
            pred_token = vocab.get_token_from_index(pred)

            gold_key, gold_form, gold_figbass = parse_chord_name_core(
                gold_token)
            pred_key, pred_form, pred_figbass = parse_chord_name_core(
                pred_token)

            # The end-of-sequence sentinel does not parse as a chord name.
            if gold_key is None and gold_token == "@end@":
                gold_key = "@end@"
            if pred_key is None and pred_token == "@end@":
                pred_key = "@end@"

            if gold_key in self.key_list and pred_key in self.key_list:
                gold_key_idx = self.key_list.index(gold_key)
                pred_key_idx = self.key_list.index(pred_key)
                self.key_cm[gold_key_idx][pred_key_idx] += 1
            else:
                # Best-effort diagnostics for unrecognized keys; non-fatal.
                print((gold_token, gold_key), (pred_token, pred_key))

            gold_type = chord_type_of(gold_key, gold_form, gold_figbass)
            pred_type = chord_type_of(pred_key, pred_form, pred_figbass)

            if gold_type in self.type_list and pred_type in self.type_list:
                gold_type_idx = self.type_list.index(gold_type)
                pred_type_idx = self.type_list.index(pred_type)
                self.type_cm[gold_type_idx][pred_type_idx] += 1
            else:
                print((gold_token, gold_type), (pred_token, pred_type))
 def test_span_f1_can_build_from_params(self):
     """span_f1 built via from_params exposes its ignore list and tag vocabulary."""
     config = {"type": "span_f1", "tag_namespace": "tags", "ignore_classes": ["V"]}
     metric = Metric.from_params(params=Params(config), vocabulary=self.vocab)
     assert metric._ignore_classes == ["V"]
     tag_vocab = self.vocab.get_index_to_token_vocabulary("tags")
     assert metric._label_vocabulary == tag_vocab
 def test_span_f1_can_build_from_params(self):
     """span_f1 built via from_params exposes its ignore list and tag vocabulary."""
     # Dropped the redundant Python-2-era u"" prefixes: in Python 3 these
     # literals are identical str values, so behavior is unchanged.
     params = Params({"type": "span_f1", "tag_namespace": "tags", "ignore_classes": ["V"]})
     metric = Metric.from_params(params=params, vocabulary=self.vocab)
     assert metric._ignore_classes == ["V"]
     assert metric._label_vocabulary == self.vocab.get_index_to_token_vocabulary("tags")