Ejemplo n.º 1
0
def evaluate_model(config: Config, model: NNCRF, batch_insts_ids, name: str,
                   insts: List[Instance]):
    """Decode every batch and report precision, recall and F1 as percentages.

    Args:
        config: Supplies ``batch_size`` (used to slice ``insts``) and
            ``idx2labels`` (forwarded to ``eval.evaluate_num``).
        model: Object whose ``decode(batch)`` returns a ``(scores, ids)`` pair.
        batch_insts_ids: Iterable of batches; each batch is passed to
            ``model.decode`` and its last and second elements are forwarded
            to ``eval.evaluate_num``.
        name: Label of the data split, used only in the printed summary.
        insts: Instances in the same order the batches were built from;
            the i-th batch is matched with the i-th ``batch_size`` slice.

    Returns:
        ``[precision, recall, fscore]``, each scaled to 0-100 (0 when the
        corresponding denominator is zero).
    """
    # Running totals in the order (p, total_predict, total_entity);
    # p is presumably the number of correctly predicted entities.
    counts = np.asarray([0, 0, 0], dtype=int)
    step = config.batch_size

    for index, batch in enumerate(batch_insts_ids):
        start = index * step
        # Longest-first ordering; presumably mirrors how the batch tensors
        # were sorted when they were built -- confirm against the batcher.
        ordered_insts = list(insts[start:start + step])
        ordered_insts.sort(key=lambda inst: len(inst.input.words),
                           reverse=True)

        # Only the decoded label ids are needed; the scores are discarded.
        _, predicted_ids = model.decode(batch)
        counts += eval.evaluate_num(ordered_insts, predicted_ids, batch[-1],
                                    batch[1], config.idx2labels)

    p, total_predict, total_entity = counts

    precision = 0 if total_predict == 0 else p * 1.0 / total_predict * 100
    recall = 0 if total_entity == 0 else p * 1.0 / total_entity * 100

    if precision == 0 and recall == 0:
        fscore = 0
    else:
        fscore = 2.0 * precision * recall / (precision + recall)

    summary = "[%s set] Precision: %.2f, Recall: %.2f, F1: %.2f" % (
        name, precision, recall, fscore)
    print(summary, flush=True)
    return [precision, recall, fscore]
Ejemplo n.º 2
0
class MT_LSTMCRF(nn.Module):
    """Multi-task wrapper around three ``NNCRF`` models.

    A "base" model is run first; the second value it returns from
    ``neg_log_obj`` (its hidden output) is passed as an extra argument to
    the "conll" and "ontonotes" models.
    """

    def __init__(self, config_base, config_conll, config_ontonotes):
        """Build one ``NNCRF`` per task from its own config."""
        super().__init__()
        self.config_base = config_base
        self.config_conll = config_conll
        self.config_ontonotes = config_ontonotes
        # Sub-models are created in base -> conll -> ontonotes order.
        self.lstmcrf_base = NNCRF(config_base)
        self.lstmcrf_conll = NNCRF(config_conll)
        self.lstmcrf_ontonotes = NNCRF(config_ontonotes)

    def neg_log_obj_total(self, words, word_seq_lens, batch_context_emb, chars,
                          char_seq_lens, prefix_label, conll_label,
                          notes_label, mask_base, mask_conll, mask_ontonotes):
        """Return the sum of the three sub-models' ``neg_log_obj`` losses.

        The five input arguments are shared by all sub-models; each sub-model
        gets its own label and mask. The conll and ontonotes models
        additionally receive the base model's hidden output.
        """
        shared_inputs = (words, word_seq_lens, batch_context_emb, chars,
                         char_seq_lens)

        base_loss, base_hiddens = self.lstmcrf_base.neg_log_obj(
            *shared_inputs, prefix_label, mask_base)
        conll_loss, _ = self.lstmcrf_conll.neg_log_obj(
            *shared_inputs, conll_label, mask_conll, base_hiddens)
        notes_loss, _ = self.lstmcrf_ontonotes.neg_log_obj(
            *shared_inputs, notes_label, mask_ontonotes, base_hiddens)

        return base_loss + conll_loss + notes_loss

    def decode(self, batchinput):
        """Decode a batch with the conll and ontonotes models.

        ``batchinput`` must unpack into exactly eleven items, in the same
        order as the parameters of ``neg_log_obj_total``.

        Returns:
            ``(conll_scores, conll_ids, notes_scores, notes_ids,
            mask_conll, mask_ontonotes)``.
        """
        (words, word_seq_lens, batch_context_emb, chars, char_seq_lens,
         prefix_label, conll_label, notes_label,
         mask_base, mask_conll, mask_ontonotes) = batchinput

        # The base model is run only for its hidden output; its loss is
        # discarded.
        base_outputs = self.lstmcrf_base.neg_log_obj(
            words, word_seq_lens, batch_context_emb, chars, char_seq_lens,
            prefix_label, mask_base)
        _, base_hiddens = base_outputs

        conll_scores, conll_ids = self.lstmcrf_conll.decode(
            batchinput, base_hiddens)
        notes_scores, notes_ids = self.lstmcrf_ontonotes.decode(
            batchinput, base_hiddens)

        return (conll_scores, conll_ids, notes_scores, notes_ids,
                mask_conll, mask_ontonotes)