コード例 #1
0
ファイル: tagging.py プロジェクト: yehuangcn/nlp-architect
    def evaluate_predictions(self, logits, label_ids):
        """
        Evaluate given logits on truth labels

        Token positions whose truth label id is 0 are treated as inactive and
        are dropped from both the predictions and the truth labels before
        scoring.

        Args:
            logits: logits of model; the first two dimensions are used as
                (batch, sequence) for the CRF mask, and in the non-CRF path
                the last dimension must be len(self.label_id_str) + 1
                (presumably (batch, seq_len, num_labels + 1) -- TODO confirm)
            label_ids: truth label ids, one per token position of ``logits``

        Returns:
            dict: dictionary containing P/R/F1 metrics under the keys
            "p", "r" and "f1"
        """
        # reshape (unlike view) also works on non-contiguous tensors
        flat_labels = label_ids.reshape(-1)
        active_positions = flat_labels != 0.0
        active_labels = flat_labels[active_positions]
        if self.use_crf:
            batch_size, seq_len = logits.size(0), logits.size(1)
            # active_positions is already boolean; only its shape must change
            decode_mask = active_positions.reshape(batch_size, seq_len)
            if self.n_gpus > 1:
                # multi-GPU: the CRF is wrapped and exposes the real module
                # via ``.module``
                decode_fn = self.crf.module.decode
            else:
                decode_fn = self.crf.decode
            decoded = decode_fn(logits.to(self.device),
                                mask=decode_mask.to(self.device))
            # flatten the per-sequence decoded tags into a single list
            predictions = [tag for seq_tags in decoded for tag in seq_tags]
        else:
            # the "+ 1" presumably reserves column 0 for the padding/ignored
            # label id 0 -- TODO confirm
            num_classes = len(self.label_id_str) + 1
            active_logits = logits.reshape(-1, num_classes)[active_positions]
            # argmax is unaffected by log_softmax (it only subtracts a per-row
            # constant), so take it directly on the raw logits
            predictions = torch.argmax(active_logits, dim=1)
            predictions = predictions.detach().cpu().numpy()
        out_label_ids = active_labels.detach().cpu().numpy()
        y_true, y_pred = self.extract_labels(out_label_ids, predictions)
        p, r, f1 = tagging(y_pred, y_true)
        return {"p": p, "r": r, "f1": f1}
コード例 #2
0
 def extract_labels(label_ids, label_map, logits):
     """Map predicted and true label ids to label strings and score them.

     Args:
         label_ids: truth label ids, one per active token.
         label_map: mapping from label id to label string; any id not in
             the mapping is converted to "O".
         logits: predicted label ids, aligned position-by-position with
             ``label_ids``.

     Returns:
         The result of ``tagging(y_pred, y_true)`` on the mapped label
         strings.

     Raises:
         ValueError: if ``logits`` and ``label_ids`` differ in length.
     """
     predicted_ids = list(logits)
     true_ids = list(label_ids)
     # zip() would silently truncate to the shorter input, and an assert on
     # the *output* lists could never fire, so validate the inputs instead.
     if len(predicted_ids) != len(true_ids):
         raise ValueError(
             "logits and label_ids must have the same length, got {} and {}".format(
                 len(predicted_ids), len(true_ids)))
     y_pred = [label_map.get(p, "O") for p in predicted_ids]
     y_true = [label_map.get(y, "O") for y in true_ids]
     return tagging(y_pred, y_true)