Exemplo n.º 1
0
def real_manual_label_file(filename, type_id_dict):
    """Parse a manually labeled mention file into sample tuples.

    The file is read as repeating records of four lines: mention id,
    sentence, label string (comma-separated type paths such as "/person"),
    and a separator line.  Records whose label string is empty or '////'
    are skipped.

    :param filename: path of the labeled samples file (utf-8 encoded)
    :param type_id_dict: maps full type path -> type id; pass None to skip
        the id lookup
    :return: list of (mention_id, sent_str, labels, label_ids) tuples;
        label_ids is None when type_id_dict is None
    """
    from utils import fetutils

    samples = list()
    # Context manager guarantees the file is closed even when exit() fires
    # below or a parse error raises (the original bare open/close leaked it).
    with open(filename, encoding='utf-8') as f:
        for i, line in enumerate(f):
            mention_id = int(line.strip())
            sent_str = next(f).strip()
            labels_str = next(f).strip()
            # Some records carry an extra annotated "name: ..." line before
            # the actual label line — presumably an annotator note; skip it.
            if ': ' in labels_str:
                labels_str = next(f).strip()
            next(f)  # separator line between records
            if not labels_str or labels_str == '////':
                continue

            labels = labels_str.split(',')
            for t in labels:
                if not t.startswith('/'):
                    # Print context before the assert so bad records are
                    # identifiable in the output.
                    print(i, mention_id, t, line)
                assert t.startswith('/')
            labels = fetutils.get_full_types(labels)

            label_ids = None
            if type_id_dict is not None:
                try:
                    label_ids = [type_id_dict[t] for t in labels]
                except KeyError:
                    print(i, mention_id, labels, line)
                    exit()

            samples.append((mention_id, sent_str, labels, label_ids))
    return samples
Exemplo n.º 2
0
def load_labeled_samples(type_id_dict, child_type_vecs, labeled_samples_file,
                         pred_file_tup, use_vr, use_hr):
    """Load labeled mentions together with each labeler's prediction logits.

    Reads the same 4-line record format as real_manual_label_file, then joins
    every mention against the collected prediction results.  Mentions without
    any SRL or hypernym prediction are dropped (their count is printed);
    mentions missing a labeler that is required by use_vr / use_hr are also
    dropped.

    :param type_id_dict: maps full type path -> type id
    :param child_type_vecs: child-type vectors forwarded to get_pred_results
    :param labeled_samples_file: path of the labeled samples file (utf-8)
    :param pred_file_tup: (base_preds_file, srl_preds_file, hyp_file,
        verif_hyp_file, hypext_logits_file)
    :param use_vr: require a hypernym (verification) prediction per mention
    :param use_hr: require an SRL prediction per mention
    :return: list of (mention_id, base_logits, srl_logits, hyp_logits,
        hyp_verif_logit, label_ids) tuples
    """
    base_preds_file, srl_preds_file, hyp_file, verif_hyp_file, hypext_logits_file = pred_file_tup
    prc = PredResultCollect(base_preds_file, srl_preds_file, hyp_file,
                            verif_hyp_file, hypext_logits_file)
    n_types = len(type_id_dict)
    samples = list()
    cnt = 0
    # Context manager guarantees the file is closed even when exit() fires
    # below or a parse error raises (the original bare open/close leaked it).
    with open(labeled_samples_file, encoding='utf-8') as f:
        for i, line in enumerate(f):
            mention_id = int(line.strip())
            sent_str = next(f).strip()
            labels_str = next(f).strip()
            # Skip an optional annotated "name: ..." line before the labels.
            if ': ' in labels_str:
                labels_str = next(f).strip()
            next(f)  # separator line between records
            if not labels_str or labels_str == '////':
                continue

            srl_pred_obj, hyp_pred_obj = prc.srl_preds_dict.get(
                mention_id), prc.hyp_preds_dict.get(mention_id)
            if srl_pred_obj is None and hyp_pred_obj is None:
                cnt += 1
                continue
            if not use_vr and hyp_pred_obj is None:
                continue
            if not use_hr and srl_pred_obj is None:
                continue

            base_logits, srl_logits, hyp_logits, hyp_verif_logit = get_pred_results(
                prc, n_types, type_id_dict, child_type_vecs, mention_id)

            labels = labels_str.split(',')
            for t in labels:
                if not t.startswith('/'):
                    print(i, mention_id, t, line)
                assert t.startswith('/')
            labels = fetutils.get_full_types(labels)
            try:
                label_ids = [type_id_dict[t] for t in labels]
            except KeyError:
                print(i, mention_id, labels, line)
                exit()

            sample = (mention_id, base_logits, srl_logits, hyp_logits,
                      hyp_verif_logit, label_ids)
            samples.append(sample)
    # Number of mentions dropped for having no predictions at all.
    print(cnt)
    return samples
Exemplo n.º 3
0
def __eval(gres: expdata.ResData,
           models: List[SRLFET],
           samples_list,
           true_labels_dict,
           batch_size=16,
           single_type_path=False):
    """Run each SRL-FET model over its sample list and score the predictions.

    models[k] is evaluated on samples_list[k].  Per-mention predicted labels
    and raw logits are collected, then scored against true_labels_dict.

    :return: (strict_acc, macro_f1, micro_f1, result_objs)
    """
    pred_labels_dict = dict()
    result_objs = list()
    for arg_idx, cur_samples in enumerate(samples_list):
        cur_model = models[arg_idx]
        cur_model.eval()
        dev = cur_model.device
        total = len(cur_samples)
        for beg in range(0, total, batch_size):
            end = min(beg + batch_size, total)
            batch = cur_samples[beg:end]

            (mstr_vec_seqs, verb_vec_seqs, arg1_vec_seqs,
             arg2_vec_seqs) = __get_sample_batch_srl_inputs(
                dev, cur_model.n_types, gres.token_vecs, batch, arg_idx)
            # Inference only — no gradients needed.
            with torch.no_grad():
                logits = cur_model(mstr_vec_seqs, verb_vec_seqs,
                                   arg1_vec_seqs, arg2_vec_seqs)

            if single_type_path:
                preds = cur_model.inference(logits)
            else:
                preds = cur_model.inference_full(logits, extra_label_thres=0.0)

            logits_arr = logits.data.cpu().numpy()
            for sample, type_ids_pred, sample_logits in zip(batch, preds,
                                                            logits_arr):
                labels = fetutils.get_full_types(
                    [gres.type_vocab[tid] for tid in type_ids_pred])
                pred_labels_dict[sample[0]] = labels
                result_objs.append({
                    'mention_id': sample[0],
                    'labels': labels,
                    'logits': [float(v) for v in sample_logits]
                })

    strict_acc = fetutils.strict_acc(true_labels_dict, pred_labels_dict)
    maf1 = fetutils.macrof1(true_labels_dict, pred_labels_dict)
    mif1 = fetutils.microf1(true_labels_dict, pred_labels_dict)
    return strict_acc, maf1, mif1, result_objs
Exemplo n.º 4
0
def __eval_ens(device, loss_obj, type_vocab, model, type_infer: fetutils.TypeInfer, use_vr, use_hr,
               samples, true_labels_dict):
    """Evaluate the ensemble weighting model on `samples`.

    Batches the samples, lets `model` produce per-labeler mixing logits,
    combines them with the labelers' prediction logits via
    ens_labeler_logits, and scores the final type predictions against
    `true_labels_dict`.

    :return: (strict_acc, macro_f1, micro_f1, result_objs, total_loss,
        mean_weights) where mean_weights is the average softmax weight
        assigned to each active labeler across all samples.
    """
    batch_size = 16
    n_types = len(type_vocab)
    model.eval()
    pred_labels_dict = dict()
    result_objs = list()
    n_steps = (len(samples) + batch_size - 1) // batch_size
    # One weight per labeler; the vr / hr labelers are dropped from the
    # ensemble when disabled, shrinking the weight vector accordingly.
    n_weights = 3
    if not use_vr:
        n_weights -= 1
    if not use_hr:
        n_weights -= 1
    weight_sums = np.zeros(n_weights, np.float32)
    losses = list()
    for step in range(n_steps):
        bbeg, bend = step * batch_size, min((step + 1) * batch_size, len(samples))
        samples_batch = samples[bbeg:bend]
        cur_batch_size = bend - bbeg
        (pred_logits_tensor, max_logits_tensor, true_label_vecs
         ) = __get_batch_input(device, n_types, samples_batch, use_vr, use_hr)
        feats = max_logits_tensor

        # Inference only — no gradients needed.
        with torch.no_grad():
            ens_logits = model(feats)

        # Softmax over labelers gives the mixing weights; accumulate their
        # sums so the caller can report the average weight per labeler.
        weights = torch.nn.functional.softmax(ens_logits, dim=1)
        weights = weights.data.cpu().numpy()
        weight_sums += np.sum(weights, axis=0)

        final_logits = ens_labeler_logits(ens_logits, pred_logits_tensor, cur_batch_size)
        # loss_obj is optional; when absent, total_loss is sum([]) == 0.
        if loss_obj is not None:
            loss = loss_obj.loss(true_label_vecs, final_logits)
            losses.append(loss.data.cpu().numpy())

        preds = type_infer.inference(final_logits)
        for j, (sample, type_ids_pred, sample_logits) in enumerate(
                zip(samples_batch, preds, final_logits.data.cpu().numpy())):
            labels = fetutils.get_full_types([type_vocab[tid] for tid in type_ids_pred])
            # sample[0] is the mention id.
            pred_labels_dict[sample[0]] = labels
            result_objs.append({'mention_id': sample[0], 'labels': labels,
                                'logits': [float(v) for v in sample_logits]})

    strict_acc, maf1, mif1 = fetutils.eval_fet_performance(true_labels_dict, pred_labels_dict)
    return strict_acc, maf1, mif1, result_objs, sum(losses), weight_sums / len(samples)
Exemplo n.º 5
0
def train_stacking(device, gres: expdata.ResData, use_vr, use_hr, type_infer_train, all_train_samples,
                   gres_test: expdata.ResData, type_infer_test, test_samples, test_true_labels_dict):
    """Train the stacking ensemble n_runs times and print averaged metrics.

    Each run reshuffles the training samples, holds out the first n_dev as a
    dev set, trains via __do_train, and records the test accuracy / macro-F1
    / micro-F1.  The per-run values and their averages are printed.

    :param all_train_samples: training samples; each sample's last element is
        assumed to be its list of type ids (used to build dev labels)
    :param use_vr: whether the verification labeler participates
    :param use_hr: whether the hypernym labeler participates
    """
    # Hyperparameters for the stacking MLP and its training schedule.
    n_runs = 5
    n_iter = 200
    n_dev = 100
    batch_size = 32
    n_mlp_layers = 2
    mlp_hdim = 5
    learning_rate = 0.01
    lr_gamma = 0.9
    margin = 0.5
    dropout = 0.5
    n_labelers = 3
    if not use_vr:
        n_labelers -= 1
    if not use_hr:
        n_labelers -= 1

    print('{} train samples'.format(len(all_train_samples)))

    test_acc_list, test_maf1_list, test_mif1_list = list(), list(), list()
    for i in range(n_runs):
        print('run', i)
        # Shuffle a copy so the caller's list is not reordered as a hidden
        # side effect (the original shuffled all_train_samples in place).
        shuffled_samples = list(all_train_samples)
        random.shuffle(shuffled_samples)
        train_samples = shuffled_samples
        dev_samples, dev_true_labels_dict = None, None
        if n_dev > 0:
            dev_samples = shuffled_samples[:n_dev]
            dev_true_labels_dict = {s[0]: fetutils.get_full_types(
                gres.type_vocab[tid] for tid in s[-1]) for s in dev_samples}
            train_samples = shuffled_samples[n_dev:]

        test_acc, test_maf1, test_mif1 = __do_train(
            device, train_samples, dev_samples, dev_true_labels_dict, test_samples, test_true_labels_dict,
            gres.type_vocab, gres_test.type_vocab, type_infer_train, type_infer_test, use_vr, use_hr, n_labelers,
            n_mlp_layers, mlp_hdim, dropout, margin, learning_rate, batch_size, n_iter, lr_gamma)
        test_acc_list.append(test_acc)
        test_maf1_list.append(test_maf1)
        test_mif1_list.append(test_mif1)
    (avg_acc, avg_maf1, avg_mif1
     ) = sum(test_acc_list) / n_runs, sum(test_maf1_list) / n_runs, sum(test_mif1_list) / n_runs
    print('acc', ' '.join(['{:.4f}'.format(v) for v in test_acc_list]), '*', '{:.4f}'.format(avg_acc))
    print('maf1', ' '.join(['{:.4f}'.format(v) for v in test_maf1_list]), '*', '{:.4f}'.format(avg_maf1))
    print('mif1', ' '.join(['{:.4f}'.format(v) for v in test_mif1_list]), '*', '{:.4f}'.format(avg_mif1))
Exemplo n.º 6
0
def get_pred_results(preds_collect: PredResultCollect, n_types, type_id_dict,
                     child_type_vecs, mention_id):
    """Gather the three labelers' logits for one mention.

    Labelers without a prediction for mention_id contribute an all-zero
    logit vector (and a zero verification logit for the hypernym labeler).

    :return: (base_logits, srl_logits, hyp_logits, hyp_verif_logit)
    """
    base_result = preds_collect.base_preds_dict.get(mention_id)
    base_logits = [0.0] * n_types if base_result is None else base_result['logits']

    srl_result = preds_collect.srl_preds_dict.get(mention_id)
    srl_logits = [0.0] * n_types if srl_result is None else srl_result['logits']

    hyp_result = preds_collect.hyp_preds_dict.get(mention_id)
    if hyp_result is not None:
        full_types = fetutils.get_full_types(hyp_result[0]['dtype'])
        hyp_logits = get_hyp_logits(full_types, type_id_dict, n_types,
                                    child_type_vecs)
        hyp_verif_logit = hyp_result[1]
    else:
        hyp_logits = [0.0] * n_types
        hyp_verif_logit = 0

    return base_logits, srl_logits, hyp_logits, hyp_verif_logit