import torch
import torch.nn.functional as F


def evaluate_predictions(self, logits, label_ids):
    """Evaluate model logits against ground-truth labels.

    Args:
        logits: logits emitted by the model
        label_ids: ground-truth label ids

    Returns:
        dict: dictionary containing P/R/F1 metrics
    """
    # Label id 0 marks padded positions; keep only the real tokens.
    active_positions = label_ids.view(-1) != 0
    active_labels = label_ids.view(-1)[active_positions]
    if self.use_crf:
        # Reshape the flat token mask back to (batch, seq_len) for CRF decoding.
        logits_shape = logits.size()
        decode_mask = active_positions.view(logits_shape[0], logits_shape[1])
        # Under multi-GPU DataParallel the CRF layer is wrapped in `.module`.
        if self.n_gpus > 1:
            decode_fn = self.crf.module.decode
        else:
            decode_fn = self.crf.decode
        # CRF decoding returns one tag-id sequence per example; flatten them
        # so they align with the flat list of active gold labels.
        logits = decode_fn(logits.to(self.device), mask=decode_mask.to(self.device))
        logits = [l for ll in logits for l in ll]
    else:
        # Log-softmax over the label dimension, then take the most likely
        # label for each active (non-padding) token.
        active_logits = logits.view(-1, len(self.label_id_str) + 1)[active_positions]
        logits = torch.argmax(F.log_softmax(active_logits, dim=1), dim=1)
        logits = logits.detach().cpu().numpy()
    out_label_ids = active_labels.detach().cpu().numpy()
    y_true, y_pred = self.extract_labels(out_label_ids, logits)
    p, r, f1 = tagging(y_pred, y_true)
    return {"p": p, "r": r, "f1": f1}
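# The `tagging` helper called above is not defined in this section. A minimal
# sketch, assuming it computes token-level precision/recall/F1 over non-"O"
# tags; the project's real implementation may compute span-level (entity)
# metrics instead.
def tagging(y_pred, y_true):
    """Hypothetical token-level precision/recall/F1, ignoring "O" tags."""
    tp = sum(1 for p, t in zip(y_pred, y_true) if p == t and t != "O")
    n_pred = sum(1 for p in y_pred if p != "O")
    n_true = sum(1 for t in y_true if t != "O")
    p = tp / n_pred if n_pred else 0.0
    r = tp / n_true if n_true else 0.0
    f1 = 2 * p * r / (p + r) if (p + r) else 0.0
    return p, r, f1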
def extract_labels(self, label_ids, logits):
    """Map predicted and gold label ids to their tag strings.

    Ids missing from the label map fall back to the "O" (outside) tag.
    The signature and return value match the call site in
    `evaluate_predictions`, which passes two arguments and unpacks
    (y_true, y_pred); the label map is taken from `self.label_id_str`.
    """
    y_true = []
    y_pred = []
    for p, y in zip(logits, label_ids):
        y_pred.append(self.label_id_str.get(p, "O"))
        y_true.append(self.label_id_str.get(y, "O"))
    assert len(y_true) == len(y_pred)
    return y_true, y_pred
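# A hypothetical smoke test for the non-CRF path. The `_Stub` class, the
# label map, and the tensor shapes are illustrative assumptions, not part of
# the source.
if __name__ == "__main__":
    import torch

    class _Stub:
        use_crf = False
        label_id_str = {1: "B-PER", 2: "I-PER", 3: "O"}  # assumed id -> tag map
        evaluate_predictions = evaluate_predictions
        extract_labels = extract_labels

    stub = _Stub()
    # (batch=2, seq_len=4, n_labels=len(label_id_str) + 1, index 0 = padding)
    logits = torch.randn(2, 4, len(_Stub.label_id_str) + 1)
    label_ids = torch.tensor([[1, 2, 3, 0], [3, 3, 0, 0]])  # 0 marks padding
    print(stub.evaluate_predictions(logits, label_ids))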