def evaluate_model(config: Config, model: NNCRF, batch_insts_ids, name: str, insts: List[Instance]):
    """Decode every batch with ``model`` and print/return precision, recall and F1.

    Args:
        config: supplies ``batch_size`` (used to slice ``insts`` per batch) and
            ``idx2labels`` (forwarded to ``eval.evaluate_num``).
        model: object whose ``decode(batch)`` returns a pair; the second element
            (the decoded ids) is scored.
        batch_insts_ids: iterable of pre-built batches. Batch ``i`` is assumed to
            correspond to ``insts[i * batch_size:(i + 1) * batch_size]``.
        name: dataset name shown in the printed summary line.
        insts: the instances the batches were built from, in the same order.

    Returns:
        ``[precision, recall, fscore]``, all expressed as percentages (0-100).
        A metric is 0 when its denominator is 0.
    """
    # Running totals of the three counts returned by eval.evaluate_num:
    # [matched, total_predict, total_entity].
    metrics = np.asarray([0, 0, 0], dtype=int)
    batch_size = config.batch_size
    for batch_id, batch in enumerate(batch_insts_ids):
        one_batch_insts = insts[batch_id * batch_size:(batch_id + 1) * batch_size]
        # Longest-first ordering; presumably mirrors how the batch tensors were
        # built, so instances line up with decoded rows — TODO confirm.
        sorted_batch_insts = sorted(one_batch_insts,
                                    key=lambda inst: len(inst.input.words),
                                    reverse=True)
        _, batch_max_ids = model.decode(batch)
        # NOTE(review): `eval` is presumably a project module (not the builtin);
        # batch[-1] / batch[1] look like gold labels / sequence lengths — confirm.
        metrics += eval.evaluate_num(sorted_batch_insts, batch_max_ids,
                                     batch[-1], batch[1], config.idx2labels)
    p, total_predict, total_entity = metrics[0], metrics[1], metrics[2]
    precision = p * 1.0 / total_predict * 100 if total_predict != 0 else 0
    recall = p * 1.0 / total_entity * 100 if total_entity != 0 else 0
    # Harmonic mean of precision and recall (guarded against 0 / 0).
    fscore = (2.0 * precision * recall / (precision + recall)
              if precision != 0 or recall != 0 else 0)
    print("[%s set] Precision: %.2f, Recall: %.2f, F1: %.2f" % (name, precision, recall, fscore), flush=True)
    return [precision, recall, fscore]
class MT_LSTMCRF(nn.Module):
    """Multi-task tagger composed of three ``NNCRF`` sub-models.

    A "base" model is run first. Its hidden states are then passed to two
    task-specific models (CoNLL and OntoNotes). The training objective is the
    unweighted sum of the three models' ``neg_log_obj`` losses.
    """

    def __init__(self, config_base, config_conll, config_ontonotes):
        """Build one ``NNCRF`` per config.

        Args:
            config_base: config for the base model whose hidden states are shared.
            config_conll: config for the CoNLL model.
            config_ontonotes: config for the OntoNotes model.
        """
        super().__init__()
        self.config_base = config_base
        self.config_conll = config_conll
        self.config_ontonotes = config_ontonotes
        self.lstmcrf_base = NNCRF(config_base)
        self.lstmcrf_conll = NNCRF(config_conll)
        self.lstmcrf_ontonotes = NNCRF(config_ontonotes)

    def neg_log_obj_total(self, words, word_seq_lens, batch_context_emb, chars, char_seq_lens,
                          prefix_label, conll_label, notes_label,
                          mask_base, mask_conll, mask_ontonotes):
        """Return the summed loss of the base, CoNLL and OntoNotes models.

        The word/char inputs are forwarded unchanged to every sub-model. Each
        sub-model receives its own label argument (``prefix_label``,
        ``conll_label``, ``notes_label``) and mask (``mask_base``,
        ``mask_conll``, ``mask_ontonotes``). The base model's hidden states (the
        second value returned by its ``neg_log_obj``) are passed as an extra
        argument to the two task models.
        """
        loss_base, hiddens_base = self.lstmcrf_base.neg_log_obj(
            words, word_seq_lens, batch_context_emb, chars, char_seq_lens,
            prefix_label, mask_base)
        loss_conll, _ = self.lstmcrf_conll.neg_log_obj(
            words, word_seq_lens, batch_context_emb, chars, char_seq_lens,
            conll_label, mask_conll, hiddens_base)
        loss_ontonotes, _ = self.lstmcrf_ontonotes.neg_log_obj(
            words, word_seq_lens, batch_context_emb, chars, char_seq_lens,
            notes_label, mask_ontonotes, hiddens_base)
        # All three task losses are weighted equally.
        loss_total = loss_base + loss_conll + loss_ontonotes
        return loss_total

    def decode(self, batchinput):
        """Decode a batch with both the CoNLL and OntoNotes models.

        Args:
            batchinput: 11-element sequence unpacked as ``(words, word_seq_lens,
                batch_context_emb, chars, char_seq_lens, prefix_label,
                conll_label, notes_label, mask_base, mask_conll,
                mask_ontonotes)``; any other length raises ``ValueError``.

        Returns:
            ``(bestScores_conll, decodeIdx_conll, bestScores_notes,
            decodeIdx_notes, mask_conll, mask_ontonotes)``.
        """
        (words, word_seq_lens, batch_context_emb, chars, char_seq_lens,
         prefix_label, _conll_label, _notes_label,
         mask_base, mask_conll, mask_ontonotes) = batchinput
        # Only the base hidden states are needed; the base loss is discarded.
        # NOTE(review): this still needs gold `prefix_label` and presumably runs
        # the base CRF's full loss computation — a hidden-state-only forward on
        # NNCRF would be cheaper if one exists.
        _, hiddens_base = self.lstmcrf_base.neg_log_obj(
            words, word_seq_lens, batch_context_emb, chars, char_seq_lens,
            prefix_label, mask_base)
        # The whole batchinput is forwarded; presumably NNCRF.decode selects the
        # fields it needs — confirm.
        bestScores_conll, decodeIdx_conll = self.lstmcrf_conll.decode(
            batchinput, hiddens_base)
        bestScores_notes, decodeIdx_notes = self.lstmcrf_ontonotes.decode(
            batchinput, hiddens_base)
        return (bestScores_conll, decodeIdx_conll,
                bestScores_notes, decodeIdx_notes,
                mask_conll, mask_ontonotes)