Example #1
File: loss.py  Project: rqhappy/FGEC
# Assumed imports for these snippets (not shown on the source page):
# torch, torch.nn.functional as F, pickle, plus the projects' own
# config, util, and e (evaluation-metric) modules.
def customized_bce_loss(model, input_fx, target, opt, hierarchical_types):

    gold = target
    # Binarize sigmoid scores at the configured threshold.
    pred = input_fx.sigmoid().ge(config.PRED_THRESHOLD).long()
    print(f"input_fx[0] = {input_fx[0]}")  # debug output
    device = torch.device(config.CUDA) if torch.cuda.is_available() and opt.cuda else "cpu"
    top_level_type_indicator = torch.zeros(target.shape, dtype=torch.float, device=device)
    top_level_type_indicator[:, 0:46] = 1.0  # columns 0-45 mark the top-level types
    penalty_mat = input_fx.clone()
    idx0 = torch.arange(input_fx.shape[0], dtype=torch.long, device=device)
    idx1 = (input_fx * top_level_type_indicator).argmax(dim=1)

    # Numerically stable elementwise BCE-with-logits.
    max_val = (-input_fx).clamp(min=0)
    los = input_fx - input_fx * target.float() + max_val + ((-max_val).exp() + (-input_fx - max_val).exp()).log()

    # Weight each example by exp(1 - number of gold top-level types).
    weighted_bce_loss = ((1 - (top_level_type_indicator * target.float()).sum(1)).exp() * los.sum(1)).sum()

    penalty_mat[idx0, idx1] = torch.tensor(float('-inf'), device=device)  # note: penalty_mat is never read afterwards
    top_type = (input_fx.sigmoid() * top_level_type_indicator).ge(config.PRED_THRESHOLD).float() * input_fx
    top_type_count_penalty = -F.softmax(top_type, dim=1).max(dim=1)[0].log().sum()

    # Full customized objective; note the return below uses the plain BCE sum instead.
    loss = weighted_bce_loss + model.get_struct_loss() + top_type_count_penalty

    pma, rema = e.loose_macro_PR(gold, pred, opt)
    pmi, remi = e.loose_micro_PR(gold, pred, opt)
    pstr, restr = e.strict_PR(gold, pred, opt)
    print(f"\nloss_val = {los.sum()}\n"
          f"\nmacro-F1 = {e.f1_score(pma, rema)} precision = {pma}, recall = {rema}"
          f"\nmicro-F1 = {e.f1_score(pmi, remi)} precision = {pmi}, recall = {remi}"
          f"\nstrict-F1 = {e.f1_score(pstr, restr)} precision = {pstr}, recall = {restr}")

    return los.sum(), gold, pred
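
The `los` expression above is the numerically stable form of binary cross-entropy with logits. A minimal standalone check (not part of the project) that it matches PyTorch's built-in implementation:

import torch
import torch.nn.functional as F

x = torch.randn(4, 6)                     # logits
y = torch.randint(0, 2, (4, 6)).float()   # binary targets

# Stable form: x - x*y + max(-x, 0) + log(exp(-max(-x,0)) + exp(-x - max(-x,0)))
max_val = (-x).clamp(min=0)
manual = x - x * y + max_val + ((-max_val).exp() + (-x - max_val).exp()).log()

builtin = F.binary_cross_entropy_with_logits(x, y, reduction="none")
print(torch.allclose(manual, builtin, atol=1e-6))  # True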
Example #2
File: test.py  Project: sxrczh/FETHI
def test(opt,
         model,
         test_dataloader,
         threshold=config.PRED_THRESHOLD,
         record_result=False,
         analysis_result=False,
         mode=config.TEST):
    device = torch.device(config.CUDA) if torch.cuda.is_available() else "cpu"
    model.eval()
    bgn = 0
    test_iter = iter(test_dataloader)

    gold_all, pred_all = [], []
    with open(config.DATA_ROOT + opt.corpus_dir + "hierarchical_types.pkl",
              'rb') as f:
        hierarchical_types = pickle.load(f)

    for batch in test_iter:
        # mention, mention_len, mention_neighbor, lcontext, rcontext, y = batch
        mention, mention_len, lcontext, rcontext, mention_char, y = batch

        mention = mention.to(device)
        mention_len = mention_len.to(device)
        lcontext = lcontext.to(device)
        rcontext = rcontext.to(device)
        mention_char = mention_char.to(device)
        y = y.to(device)

        model_output = model(
            [mention, mention_len, lcontext, rcontext, mention_char])

        loss, gold, pred, prob = bce_loss(model, model_output, y, opt,
                                          hierarchical_types, threshold)

        if record_result:
            util.record_result(gold, pred, prob, opt, bgn, mode)

        gold_all.append(gold)
        pred_all.append(pred)

        bgn += opt.batch_size

    gold_all = torch.cat(gold_all)
    pred_all = torch.cat(pred_all)
    if analysis_result:
        util.analysis_result(gold_all, pred_all)

    pmacro, remacro = e.loose_macro_PR(gold_all, pred_all, opt)
    pmicro, remicro = e.loose_micro_PR(gold_all, pred_all, opt)
    pstrict, restrict = e.strict_PR(gold_all, pred_all, opt)
    macro_F1 = e.f1_score(pmacro, remacro)
    micro_F1 = e.f1_score(pmicro, remicro)
    strict_F1 = e.f1_score(pstrict, restrict)

    return (macro_F1, pmacro, remacro), \
           (micro_F1, pmicro, remicro), \
           (strict_F1, pstrict, restrict)
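
Unlike the per-batch averaging in Examples #4 and #6, this version concatenates all gold/pred tensors and scores them once, which keeps corpus-level metrics exact. A small illustration of the accumulation pattern (shapes are made up):

import torch

gold_all, pred_all = [], []
for _ in range(3):                         # stands in for the batch loop
    gold_all.append(torch.randint(0, 2, (8, 5)))
    pred_all.append(torch.randint(0, 2, (8, 5)))

gold = torch.cat(gold_all)                 # (24, 5): all batches stacked
pred = torch.cat(pred_all)
tp = (gold * pred).sum().item()            # true positives over everything
micro_p = tp / max(pred.sum().item(), 1)   # exact corpus-level micro precision
print(gold.shape, micro_p)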
Example #3
File: loss.py  Project: sxrczh/FETHI
def hier_loss(model, output, target, opt, tune, prior, mask):
    device = torch.device(
        config.CUDA) if torch.cuda.is_available() and opt.cuda else "cpu"

    proba = output.softmax(dim=1)
    # Adjust class probabilities with the hierarchy-derived tune matrix.
    adjust_proba = torch.matmul(proba, tune.t())

    # print(f"proba = {proba[0]}")
    # print(f"adjust_proba = {adjust_proba[0]}")
    p_caret = torch.argmax(adjust_proba, dim=1)
    # print(f"p_caret = {p_caret[0]}")

    # loss

    gold = target.clone()  # untouched copy of the gold labels
    # for i, t in enumerate(target):
    #     target[i] = mask[t.nonzero().squeeze()].prod(0) * target[i]

    # print(f"target = {target}")
    # Select the highest-scoring gold type per row as the training target.
    tgt = torch.argmax(adjust_proba * target.float(), dim=1)

    tgt_idx = torch.zeros(target.shape, dtype=torch.float,
                          device=device).scatter_(1, tgt.unsqueeze(1), 1)
    loss = -(adjust_proba.log() * tgt_idx).sum(dim=1).mean()

    pred = F.embedding(p_caret, prior)  # look up each predicted type's label vector in the prior matrix

    pma, rema = e.loose_macro_PR(gold, pred, opt)
    pmi, remi = e.loose_micro_PR(gold, pred, opt)
    pstr, restr = e.strict_PR(gold, pred, opt)
    print(
        f"\nloss_val = {loss}\n"
        f"\nmacro-F1 = {e.f1_score(pma, rema)} precision = {pma}, recall = {rema}"
        f"\nmicro-F1 = {e.f1_score(pmi, remi)} precision = {pmi}, recall = {remi}"
        f"\nstrict-F1 = {e.f1_score(pstr, restr)} precision = {pstr}, recall = {restr}"
    )

    return loss, gold, pred
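
The `scatter_` call above is the standard way to turn per-row target indices into a one-hot mask for selecting log-probabilities. A tiny standalone demonstration:

import torch

adjust_proba = torch.tensor([[0.1, 0.7, 0.2],
                             [0.5, 0.3, 0.2]])
tgt = adjust_proba.argmax(dim=1)                          # tensor([1, 0])
tgt_idx = torch.zeros_like(adjust_proba).scatter_(1, tgt.unsqueeze(1), 1.0)
print(tgt_idx)                                            # [[0,1,0], [1,0,0]]
loss = -(adjust_proba.log() * tgt_idx).sum(dim=1).mean()  # mean NLL of the selected entries
print(loss)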
Example #4
File: train.py  Project: rqhappy/FGEC
def test(opt, model, test_dataloader, record_result=False):
    device = torch.device(config.CUDA) if torch.cuda.is_available() else "cpu"
    model.eval()

    macro_F1, micro_F1, strict_F1 = 0, 0, 0
    pmacro, remacro = 0, 0
    pmicro, remicro = 0, 0
    pstrict, restrict = 0, 0

    bgn = 0

    total = len(test_dataloader)
    test_iter = iter(test_dataloader)

    p = config.DATA_ROOT + opt.corpus_dir + "hierarchical_types.pkl"
    # prior/tune/mask are only consumed by the commented-out hier_loss call below.
    prior = torch.tensor(util.create_prior(p),
                         requires_grad=False,
                         dtype=torch.long).to(device)
    tune = torch.tensor(util.create_prior(p, config.BETA),
                        requires_grad=False,
                        dtype=torch.float).to(device)
    mask = torch.tensor(util.create_mask(p),
                        requires_grad=False,
                        dtype=torch.long).to(device)

    for batch in test_iter:
        mention, mention_len, mention_neighbor, lcontext, rcontext, y = batch

        mention = mention.to(device)
        mention_len = mention_len.to(device)
        mention_neighbor = mention_neighbor.to(device)
        lcontext = lcontext.to(device)
        rcontext = rcontext.to(device)
        y = y.to(device)

        model_output = model(
            [mention, mention_len, mention_neighbor, lcontext, rcontext])

        # loss, gold, pred = customized_bce_loss(model, model_output, y, opt, hierarchical_types)
        loss, gold, pred = bce_loss(model, model_output, y, opt, "test")
        # loss, gold, pred = hier_loss(model, model_output, y, opt, tune, prior, mask)

        if record_result:
            util.record_result(gold, pred, opt, bgn)

        bgn += opt.batch_size
        pma, rema = e.loose_macro_PR(gold, pred, opt)
        macro_F1 += e.f1_score(pma, rema)
        pmacro += pma
        remacro += rema

        pmi, remi = e.loose_micro_PR(gold, pred, opt)
        micro_F1 += e.f1_score(pmi, remi)
        pmicro += pmi
        remicro += remi

        pstr, restr = e.strict_PR(gold, pred, opt)
        strict_F1 += e.f1_score(pstr, restr)
        pstrict += pstr
        restrict += restr

    return (macro_F1/total, pmacro/total, remacro/total), \
           (micro_F1/total, pmicro/total, remicro/total), \
           (strict_F1/total, pstrict/total, restrict/total)
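
The scores above are averaged over `total = len(test_dataloader)` batches, so they approximate rather than reproduce corpus-level metrics (compare Example #2). For reference, `e.f1_score` is presumably the usual harmonic mean; a guarded sketch:

def f1_score(p, r):
    # Harmonic mean of precision and recall; 0 when both are 0.
    return 0.0 if p + r == 0 else 2 * p * r / (p + r)

print(f1_score(0.8, 0.6))  # 0.6857...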
Example #5
File: loss.py  Project: rqhappy/FGEC
def one_path_loss(model, input_fx, target, opt, hierarchical_types):
    bceloss = model.get_bceloss()  # fetched but never used in this variant
    device = torch.device(config.CUDA) if torch.cuda.is_available() and opt.cuda else "cpu"
    p_caret = F.softmax(input_fx, dim=1)
    y_caret = torch.argmax(p_caret, dim=1, keepdim=True)

    def get_pred(tgt, y_c, hier_types):
        # Expand each predicted type to include its hierarchy-linked types.
        ls = torch.zeros_like(tgt, device=device, dtype=torch.long)
        for i, row in enumerate(y_c):
            ls[i, row.item()] = 1
            if hier_types.get(row.item()) is not None:
                ls[i, hier_types[row.item()]] = 1
        return ls

    def get_yt(tgt, hier_t):
        # Zero hierarchy-linked positions of each gold type so the argmax
        # below lands on a single, most specific type.
        # yt = torch.zeros(tgt.shape, dtype=torch.long, requires_grad=False, device=device)
        # yt.copy_(tgt)
        yt = tgt.clone()
        for i, row in enumerate(yt):
            for j, ele in enumerate(row):
                if yt[i][j] == 1 and hier_t.get(j) is not None:
                    yt[i][hier_t[j]] = 0
        return yt
        # torch.where(hier_t.get(tgt) is None, tgt, zero)

    gold = target
    # yt = get_yt(target, hierarchical_types)
    #
    # # gold, yt = get_gold(target, classes, hierarchical_types)
    # y_star_caret = torch.argmax((p_caret * yt.float()), dim=1, keepdim=True)
    pred = get_pred(target, y_caret, hierarchical_types)

    yt = get_yt(target, hierarchical_types)
    y_star_caret = torch.argmax((p_caret * yt.float()), dim=1, keepdim=True)
    loss = -torch.gather(p_caret, 1, y_star_caret).log().mean()
    # def get_ysc_w(ysc, hier_types):
    #     ys = torch.zeros([input_fx.shape[0], config.NUM_OF_CLS],
    #                      dtype=torch.float, requires_grad=False, device=device).scatter_(1, ysc, 1)
    #     yscw = torch.zeros([input_fx.shape[0], config.NUM_OF_CLS],
    #                        dtype=torch.float, requires_grad=False, device=device).scatter_(1, ysc, 1)
    #     for i, ele in enumerate(ysc):
    #         if hier_types.get(ele.item()) is not None:
    #             yscw[i][hier_types[ele.item()][0]] = config.BETA
    #     return ys, yscw

    # ys1, ysc_w = get_ysc_w(y_star_caret, hierarchical_types)

    # print((ysc_w*p_caret).sum(-1))
    # loss = -(torch.tanh(ysc_w)*p_caret).sum(-1).log().mean()  # hierarchical one-path loss
    # print(torch.gather(p_caret, 1, y_star_caret))
    # loss = -torch.gather(p_caret, 1, y_star_caret).log().mean()
    # loss = -(ys1*p_caret).sum(-1).log().mean()

    pma, rema = e.loose_macro_PR(gold, pred, opt)
    pmi, remi = e.loose_micro_PR(gold, pred, opt)
    pstr, restr = e.strict_PR(gold, pred, opt)
    # print(f"\nloss_val = {loss}\n"
    #       f"\nmacro-F1 = {e.f1_score(pma, rema)} precision = {pma}, recall = {rema}"
    #       f"\nmicro-F1 = {e.f1_score(pmi, remi)} precision = {pmi}, recall = {remi}"
    #       f"\nstrict-F1 = {e.f1_score(pstr, restr)} precision = {pstr}, recall = {restr}")
    return loss, gold, pred
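
`get_yt` strips hierarchy-linked positions from the gold vector so the argmax for `y_star_caret` lands on one path-specific type, while `get_pred` expands a predicted type back out. A toy run of the stripping step, assuming `hierarchical_types` maps a type index to its related (ancestor) indices:

import torch

hier = {2: [0]}                     # toy hierarchy: type 2 descends from type 0
target = torch.tensor([[1, 0, 1]])  # gold labels cover the path {0, 2}

yt = target.clone()
for i, row in enumerate(yt):
    for j in range(row.numel()):
        if yt[i][j] == 1 and hier.get(j) is not None:
            yt[i][hier[j]] = 0      # zero the linked ancestor positions
print(yt)                           # tensor([[0, 0, 1]]): only the leaf survives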
Example #6
def test(opt,
         model,
         test_dataloader,
         record_result=False,
         analysis_result=False,
         mode=config.TEST):
    device = torch.device(config.CUDA) if torch.cuda.is_available() else "cpu"
    model.eval()

    macro_F1, micro_F1, strict_F1 = 0, 0, 0
    pmacro, remacro = 0, 0
    pmicro, remicro = 0, 0
    pstrict, restrict = 0, 0

    bgn = 0

    total = len(test_dataloader)
    test_iter = iter(test_dataloader)

    gold_all, pred_all = [], []

    p = config.DATA_ROOT + opt.corpus_dir + "hierarchical_types.pkl"

    with open(p, 'rb') as f:
        hierarchical_types = pickle.load(f)

    for batch in test_iter:
        # mention, mention_len, mention_neighbor, lcontext, rcontext, y = batch
        mention, mention_len, lcontext, rcontext, mention_char, y = batch

        mention = mention.to(device)
        mention_len = mention_len.to(device)
        # mention_neighbor = mention_neighbor.to(device)
        lcontext = lcontext.to(device)
        rcontext = rcontext.to(device)
        mention_char = mention_char.to(device)
        # feature = feature.to(device)
        y = y.to(device)

        model_output = model(
            [mention, mention_len, lcontext, rcontext, mention_char])

        loss, gold, pred, prob = bce_loss(model, model_output, y, opt,
                                          hierarchical_types)

        if record_result:
            util.record_result(gold, pred, prob, opt, bgn)

        if analysis_result:
            gold_all.append(gold)
            pred_all.append(pred)

        bgn += opt.batch_size
        pma, rema = e.loose_macro_PR(gold, pred, opt)
        macro_F1 += e.f1_score(pma, rema)
        pmacro += pma
        remacro += rema

        pmi, remi = e.loose_micro_PR(gold, pred, opt)
        micro_F1 += e.f1_score(pmi, remi)
        pmicro += pmi
        remicro += remi

        pstr, restr = e.strict_PR(gold, pred, opt)
        strict_F1 += e.f1_score(pstr, restr)
        pstrict += pstr
        restrict += restr

    if analysis_result:
        util.analysis_result(torch.cat(gold_all), torch.cat(pred_all))

    return (macro_F1/total, pmacro/total, remacro/total), \
           (micro_F1/total, pmicro/total, remicro/total), \
           (strict_F1/total, pstrict/total, restrict/total)
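
One caveat that applies to all three `test` functions above: they call `model.eval()` but never disable autograd, so every forward pass still builds a graph. Wrapping the loop in `torch.no_grad()` (not done in these snippets) saves memory during evaluation:

import torch

model = torch.nn.Linear(4, 2)
model.eval()                        # eval() switches layers like dropout/batch-norm to inference behavior
with torch.no_grad():               # additionally skips graph construction
    out = model(torch.randn(3, 4))
print(out.requires_grad)            # False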