Example #1
0
def gen_tf_records_files(
        vocab_file, category_map_file, mentions_file, valid_idxs_file,
        tok_sents_file, max_seq_len, train_output_file, valid_output_file):
    """Serialize tokenized mention sentences into train/valid TFRecord files.

    Line ``i`` of ``tok_sents_file`` is paired with ``mentions[i]``. Each
    pair is converted to a ``tf.train.Example`` and written to
    ``valid_output_file`` when ``i`` is in the validation index set, and to
    ``train_output_file`` otherwise.

    Args:
        vocab_file: Vocabulary file passed to ``tokenization.SpaceTokenizer``.
        category_map_file: File passed to ``datautils.load_category_mapping``;
            only the second returned mapping is used (for one-hot encoding).
        mentions_file: JSON-objects file with one mention per sentence line;
            each mention must provide ``'token_span'`` and ``'tcategory'``.
        valid_idxs_file: JSON-objects file whose first object is the
            collection of sentence indices that belong to the validation set.
        tok_sents_file: UTF-8 text file, one tokenized sentence per line.
        max_seq_len: Maximum sequence length forwarded to
            ``get_tfrec_example``.
        train_output_file: Destination path of the training TFRecord file.
        valid_output_file: Destination path of the validation TFRecord file.

    Raises:
        IndexError: If ``tok_sents_file`` has more lines than there are
            mentions.
    """
    # Function-scope import keeps this block self-contained.
    from contextlib import closing

    tokenizer = tokenization.SpaceTokenizer(vocab_file)

    # Only the category -> id mapping is needed here.
    _, category_id_dict = datautils.load_category_mapping(category_map_file)
    mentions = datautils.read_json_objs(mentions_file)
    valid_idxs = set(datautils.read_json_objs(valid_idxs_file)[0])

    # `closing` guarantees both writers (and the input file) are released
    # even when an exception is raised while processing a sentence.
    with closing(tf.python_io.TFRecordWriter(train_output_file)) as writer_train, \
            closing(tf.python_io.TFRecordWriter(valid_output_file)) as writer_valid, \
            open(tok_sents_file, encoding='utf-8') as f:
        for i, tokens_str in enumerate(f):
            mention = mentions[i]
            token_span = mention['token_span']
            categories = mention['tcategory']
            y_categories = onehot_encode(categories, category_id_dict)

            # The sentence is lower-cased before tokenization.
            input_ids, input_mask, segment_ids, tokens = get_tfrec_example(
                tokens_str.lower(), max_seq_len, tokenizer)
            features = __get_feature_dict(
                input_ids, input_mask, segment_ids, len(tokens), token_span, y_categories)
            tf_example = tf.train.Example(features=tf.train.Features(feature=features))

            writer = writer_valid if i in valid_idxs else writer_train
            writer.write(tf_example.SerializeToString())
Example #2
0
def load_hyp_preds(hypext_file,
                   verif_hypext_file,
                   hyp_verif_logits_file,
                   use_verif_logits=True):
    """Load hypernym-extraction results keyed by mention id.

    A result from ``hypext_file`` is kept only when its ``'mention_id'``
    also occurs in ``verif_hypext_file``. When ``use_verif_logits`` is true,
    the verification logit aligned with that entry must additionally be
    strictly positive.

    Args:
        hypext_file: JSON-objects file of extraction results.
        verif_hypext_file: JSON-objects file of the results that were verified.
        hyp_verif_logits_file: Text file with one float per line, aligned
            line-by-line with ``verif_hypext_file``.
        use_verif_logits: Whether to filter on the sign of the logit.

    Returns:
        dict mapping mention id -> ``(extraction_result, verification_logit)``.
    """
    from utils import datautils

    extracted = datautils.read_json_objs(hypext_file)
    verified = datautils.read_json_objs(verif_hypext_file)

    logits = []
    with open(hyp_verif_logits_file, encoding='utf-8') as logits_fp:
        for raw_line in logits_fp:
            logits.append(float(raw_line.strip()))

    assert len(verified) == len(logits)
    if len(extracted) != len(verified):
        print('len(hypext_results) != len(verif_hypext_results)')

    # mention id -> row index in the verified results (and thus in `logits`);
    # a later duplicate id overrides an earlier one.
    verif_pos_by_mid = {}
    for pos, verif_obj in enumerate(verified):
        verif_pos_by_mid[verif_obj['mention_id']] = pos

    selected = {}
    for result in extracted:
        mid = result['mention_id']
        pos = verif_pos_by_mid.get(mid)
        if pos is None:
            # This mention was never verified.
            continue
        logit = logits[pos]
        if use_verif_logits and not logit > 0:
            continue
        selected[mid] = (result, logit)

    return selected
Example #3
0
def samples_from_man_labeled(token_id_dict, unknown_token_id, type_id_dict,
                             mentions_file, sents_file, dep_tags_file,
                             srl_results_file, man_label_file):
    """Build SRL-based samples for the manually labeled mentions.

    Mentions without a manual label, or for which
    ``utils.match_srl_to_mentions_all`` yields no matched tag, are skipped.
    Every matched tag of a remaining mention produces one sample.

    Args:
        token_id_dict: Token -> id mapping for sentence tokens.
        unknown_token_id: Id used for tokens missing from ``token_id_dict``.
        type_id_dict: Type label -> id mapping.
        mentions_file: JSON-objects file of mentions.
        sents_file: JSON-objects file of sentences.
        dep_tags_file: Dependency tag file, or ``None`` to skip dependency tags.
        srl_results_file: File read by ``datautils.read_srl_results``.
        man_label_file: Manual label file.

    Returns:
        list of tuples ``(mention_id, mention_str, span_begin, span_end, None,
        type_ids, sent_token_ids, srl_info)`` where ``srl_info`` holds the
        spans for ``V``, ``ARG0``, ``ARG1``, ``ARG2`` followed by the integer
        parsed from the last character of the matched tag.
    """
    labeled_samples = datautils.real_manual_label_file(man_label_file, None)
    mentions = datautils.read_json_objs(mentions_file)
    sents = datautils.read_json_objs(sents_file)

    # sent_id -> (position in `sents`, sentence object)
    sent_dict = {
        sent_obj['sent_id']: (sent_pos, sent_obj)
        for sent_pos, sent_obj in enumerate(sents)
    }

    dep_tag_seq_list = None
    if dep_tags_file is not None:
        with open(dep_tags_file, encoding='utf-8') as dep_fp:
            # One dependency record is consumed per sentence, in order.
            dep_tag_seq_list = [
                datautils.next_sent_dependency(dep_fp) for _ in sents
            ]

    srl_results_list = datautils.read_srl_results(srl_results_file)
    print(len(sents), len(srl_results_list))

    # mention id -> manual labels (fields 0 and 2 of each labeled sample)
    mid_manual_label_dict = {entry[0]: entry[2] for entry in labeled_samples}

    srl_roles = ('V', 'ARG0', 'ARG1', 'ARG2')
    samples = []
    for mention in mentions:
        labels = mid_manual_label_dict.get(mention['mention_id'])
        if labels is None:
            continue

        span = mention['span']
        sent_idx, sent_obj = sent_dict[mention['sent_id']]
        words = sent_obj['text'].split(' ')
        if dep_tag_seq_list is None:
            dep_tags = None
        else:
            dep_tags = dep_tag_seq_list[sent_idx]

        tags, tag_spans_list = utils.match_srl_to_mentions_all(
            words, srl_results_list[sent_idx], span, dep_tags)
        if not tags:
            continue

        type_ids = [type_id_dict[label] for label in labels]
        for tag, tag_spans in zip(tags, tag_spans_list):
            tag_pos = int(tag[-1:])
            role_spans = tuple(
                utils.get_srl_tag_span(tag_spans, role) for role in srl_roles)
            srl_info = role_spans + (tag_pos,)
            # Built anew for every sample so samples do not share one list.
            token_ids = [
                token_id_dict.get(word, unknown_token_id) for word in words
            ]
            samples.append(
                (mention['mention_id'], mention['str'], span[0], span[1],
                 None, type_ids, token_ids, srl_info))
    return samples
Example #4
0
    def __init__(self, base_preds_file, srl_preds_file, hypext_file,
                 verif_hypext_file, hypext_logits_file):
        """Load base, SRL and hypernym prediction results.

        Args:
            base_preds_file: JSON-objects file of base predictions; each
                object must have a ``'mention_id'`` key.
            srl_preds_file: JSON-objects file of SRL predictions.
            hypext_file: Hypernym extraction results file.
            verif_hypext_file: Verified hypernym extraction results file.
            hypext_logits_file: Verification logits file.
        """
        # mention id -> base prediction object
        base_by_mid = dict()
        for pred_obj in datautils.read_json_objs(base_preds_file):
            base_by_mid[pred_obj['mention_id']] = pred_obj
        self.base_preds_dict = base_by_mid

        self.srl_preds_dict = get_srl_pred_dict(
            datautils.read_json_objs(srl_preds_file))

        self.hyp_preds_dict = load_hyp_preds(
            hypext_file, verif_hypext_file, hypext_logits_file)
Example #5
0
def gen_mrel_title_to_wid_file(cands_files, output_file):
    """Collect title -> wid pairs from candidate files and save them as CSV.

    Args:
        cands_files: Iterable of JSON-objects files; every object must have
            parallel ``'wids'`` and ``'titles'`` sequences.
        output_file: Destination CSV path, written with the header
            ``title, wid``.

    When a title occurs more than once, the last wid seen wins.
    """
    wid_by_title = {}
    for path in cands_files:
        for cands in datautils.read_json_objs(path):
            wid_by_title.update(
                (title, wid)
                for wid, title in zip(cands['wids'], cands['titles']))

    rows = list(wid_by_title.items())
    datautils.save_csv(rows, ['title', 'wid'], output_file)
Example #6
0
def samples_from_test(gres: ResData, child_type_vecs, test_file_tup):
    """Build evaluation samples that combine the prediction sources.

    Args:
        gres: Global resources; ``n_types`` and ``type_id_dict`` are used.
        child_type_vecs: Forwarded unchanged to ``get_pred_results``.
        test_file_tup: 7-tuple ``(mentions_file, sents_file, base_preds_file,
            srl_preds_file, hypext_file, verif_hypext_file,
            hypext_logits_file)``. ``sents_file`` is kept in the tuple for
            interface compatibility but is not read.

    Returns:
        ``(samples, true_labels_dict)`` where every sample is
        ``(mention_id, base_logits, srl_logits, hyp_logits, hyp_verif_logit,
        label_ids)`` and ``true_labels_dict`` maps mention id -> label list.
    """
    # The sentences file was previously parsed into a list that nothing used;
    # it is no longer loaded.
    (mentions_file, _sents_file, base_preds_file, srl_preds_file, hypext_file,
     verif_hypext_file, hypext_logits_file) = test_file_tup
    prc = PredResultCollect(base_preds_file, srl_preds_file, hypext_file,
                            verif_hypext_file, hypext_logits_file)
    mentions = datautils.read_json_objs(mentions_file)
    true_labels_dict = {m['mention_id']: m['labels'] for m in mentions}

    samples = list()
    for m in mentions:
        mention_id = m['mention_id']
        base_logits, srl_logits, hyp_logits, hyp_verif_logit = get_pred_results(
            prc, gres.n_types, gres.type_id_dict, child_type_vecs, mention_id)

        label_ids = [gres.type_id_dict[t] for t in m['labels']]

        samples.append((mention_id, base_logits, srl_logits, hyp_logits,
                        hyp_verif_logit, label_ids))
    return samples, true_labels_dict
Example #7
0
def model_samples_from_json(token_id_dict, unknown_token_id, mention_token_id,
                            type_id_dict, mentions_file, sents_file):
    """Create one model sample per mention in ``mentions_file``.

    Args:
        token_id_dict: Token -> id mapping used to encode sentences.
        unknown_token_id: Id for tokens missing from ``token_id_dict``.
        mention_token_id: Id forwarded to ``get_model_sample``.
        type_id_dict: Not used by this function.
        mentions_file: JSON-objects file of mentions with ``'mention_id'``,
            ``'str'``, ``'span'`` and ``'sent_id'`` keys.
        sents_file: Sentences file read through
            ``datautils.read_sents_to_token_id_seq_dict``.

    Returns:
        list of the values returned by ``get_model_sample``, in mention order.
    """
    token_ids_by_sent = datautils.read_sents_to_token_id_seq_dict(
        sents_file, token_id_dict, unknown_token_id)

    return [
        get_model_sample(mention['mention_id'],
                         mention_str=mention['str'],
                         mention_span=mention['span'],
                         sent_tokens=token_ids_by_sent[mention['sent_id']],
                         mention_token_id=mention_token_id)
        for mention in datautils.read_json_objs(mentions_file)
    ]
Example #8
0
def model_samples_from_json(config, token_id_dict, unknown_token_id, type_id_dict,
                            mentions_file, sents_file):
    """Create model samples from mention and sentence JSON files.

    With ``config.use_bert`` set, the mention span in the sentence is
    replaced by a single ``[MASK]`` token, the sentence is tokenized by the
    BERT tokenizer and wrapped in ``[CLS]`` / ``[SEP]``; the mention index is
    the position of the first ``[MASK]`` token (0 when none is found).
    Otherwise the sentence is encoded with ``token_id_dict`` and the mention
    index is the span start.

    Args:
        config: Object exposing a boolean ``use_bert`` attribute.
        token_id_dict: Token -> id mapping.
        unknown_token_id: Id for tokens missing from ``token_id_dict``.
        type_id_dict: Type label -> id mapping.
        mentions_file: JSON-objects file of mentions.
        sents_file: UTF-8 file with one JSON sentence object per line.

    Returns:
        list of ``[mention_id, mention_token_ids, sentence_token_ids,
        mention_token_idx, label_ids]`` lists.
    """
    if config.use_bert:
        tokenizer = BertTokenizer.from_pretrained('bert-base-cased', do_lower_case=False)
        print('bert tokenizer loaded')

    ids_by_sent = {}    # sent_id -> vocabulary ids of the sentence tokens
    words_by_sent = {}  # sent_id -> raw sentence tokens
    with open(sents_file, encoding='utf-8') as sents_fp:
        for raw_line in sents_fp:
            sent_obj = json.loads(raw_line)
            words = sent_obj['text'].split(' ')
            ids_by_sent[sent_obj['sent_id']] = [
                token_id_dict.get(word, unknown_token_id) for word in words]
            words_by_sent[sent_obj['sent_id']] = list(words)

    samples = []
    for mention in datautils.read_json_objs(mentions_file):
        if not config.use_bert:
            sentence_token = ids_by_sent[mention['sent_id']]
            mention_token_idx = mention['span'][0]
        else:
            sent_words = words_by_sent[mention['sent_id']]
            masked_words = (sent_words[:mention['span'][0]]
                            + ['[MASK]']
                            + sent_words[mention['span'][1]:])
            bert_tokens = ["[CLS]", *tokenizer.tokenize(' '.join(masked_words))]
            # Position of the mask; falls back to 0 if it is absent.
            mention_token_idx = next(
                (pos for pos, tok in enumerate(bert_tokens) if tok == '[MASK]'), 0)
            bert_tokens.append("[SEP]")
            sentence_token = tokenizer.convert_tokens_to_ids(bert_tokens)

        label_ids = [type_id_dict[label] for label in mention['labels']]
        mention_ids = ids_by_sent[mention['sent_id']][
            mention['span'][0]:mention['span'][1]]
        samples.append([mention['mention_id'], mention_ids, sentence_token,
                        mention_token_idx, label_ids])
    return samples
Example #9
0
def check_retrieved_sents():
    """Dump retrieved blocks next to each test sample for manual inspection.

    For every line of the retrieval results file, the corresponding sample
    string and its ``'y_str'`` are written, followed by the text and the
    label of each retrieved block id, then a blank line. All paths are
    resolved relative to ``config.DATA_DIR``.
    """
    import tensorflow as tf

    output_file = os.path.join(config.DATA_DIR, 'tmp/uf_wia_results_200_ins.txt')
    results_file = os.path.join(config.DATA_DIR, 'realm_output/uf_wia_results_200.txt')
    samples_file = os.path.join(config.DATA_DIR, 'ultrafine/uf_data/crowd/test.json')
    wia_block_records_path = os.path.join(config.DATA_DIR, 'ultrafine/zoutput/webisa_full_uffilter.tfr')
    wia_block_labels_file = os.path.join(config.DATA_DIR, 'ultrafine/zoutput/webisa_full_uffilter_labels.txt')

    samples = datautils.read_json_objs(samples_file)

    blocks_dataset = tf.data.TFRecordDataset(
        wia_block_records_path, buffer_size=512 * 1024 * 1024)

    # Materialize every block so that it can be looked up by block id.
    sents = list()
    for i, sent in enumerate(blocks_dataset):
        sents.append(sent.numpy().decode('utf-8'))
        if i % 500000 == 0:
            print(i)  # progress indicator
    with open(wia_block_labels_file, encoding='utf-8') as f:
        wia_labels = [line.strip() for line in f]

    # Context managers close both files even if a line fails to parse or a
    # block id is out of range. The input is opened first so that a missing
    # results file does not leave an empty output file behind.
    with open(results_file, encoding='utf-8') as f, \
            open(output_file, 'w', encoding='utf-8') as fout:
        for i, line in enumerate(f):
            x = json.loads(line)
            bids = x['block_ids']
            # Result line i belongs to sample i.
            uf_sample_str = get_uf_sample_str(samples[i])
            fout.write('{}\n{}\n'.format(uf_sample_str, samples[i]['y_str']))
            for bid in bids:
                fout.write('{}\n'.format(sents[bid]))
                fout.write('{}\n'.format(wia_labels[bid]))
            fout.write('\n')
0
def train_fetel(device, gres: exputils.GlobalRes, el_entityvec: ELDirectEntityVec, train_samples_pkl,
                dev_samples_pkl, test_mentions_file, test_sents_file, test_noel_preds_file, type_embed_dim,
                context_lstm_hidden_dim, learning_rate, batch_size, n_iter, dropout, rand_per, per_penalty,
                use_mlp=False, pred_mlp_hdim=None, save_model_file=None, nil_rate=0.5,
                single_type_path=False, stack_lstm=False, concat_lstm=False, results_file=None):
    """Train a FETEL model, evaluating on dev and test data every 1000 steps.

    Training runs for ``n_iter`` passes over the pickled training samples
    using Adam and a StepLR schedule. After every 1000 steps the model is
    evaluated with ``eval_fetel`` on the dev and test samples; when the dev
    accuracy improves on the best value so far, the model state dict is
    saved (if ``save_model_file`` is set) and the test results are written
    (if ``results_file`` is set).

    Args:
        device: Torch device; the model is moved with ``.cuda(device.index)``
            when ``device.type == 'cuda'``.
        gres: Global resources (type vocabulary, id dictionaries, embedding
            layer, mention token id, ...).
        el_entityvec: Entity-linking helper passed to the
            ``__get_entity_vecs_for_*`` functions.
        train_samples_pkl: Pickle file with the training samples.
        dev_samples_pkl: Pickle file with the dev samples.
        test_mentions_file: JSON-objects file of test mentions.
        test_sents_file: Sentences file for the test mentions.
        test_noel_preds_file: Prediction results file read with
            ``datautils.read_pred_results_file``.
        type_embed_dim: Forwarded to the model constructor.
        context_lstm_hidden_dim: Forwarded to the model constructor.
        learning_rate: Initial learning rate for Adam.
        batch_size: Training batch size.
        n_iter: Number of passes over the training samples.
        dropout: Forwarded to the model constructor.
        rand_per: When truthy, the person-type arguments are passed to
            ``__get_entity_vecs_for_samples`` for training batches.
        per_penalty: Forwarded to ``exputils.get_person_type_loss_vec``.
        use_mlp: Forwarded to the model constructor.
        pred_mlp_hdim: Forwarded to the model constructor as
            ``mlp_hidden_dim``.
        save_model_file: Where to save the best model state dict, or a falsy
            value to skip saving.
        nil_rate: Probability with which each training entity vector is
            replaced by zeros.
        single_type_path: Forwarded to ``eval_fetel``.
        stack_lstm: Selects the ``FETELStack`` model.
        concat_lstm: Forwarded to the model constructor.
        results_file: Where to save the test results of the best dev
            accuracy, or ``None`` to skip.

    No return statement is visible in this part of the function.
    """
    logging.info('result_file={}'.format(results_file))
    logging.info(
        'type_embed_dim={} cxt_lstm_hidden_dim={} pmlp_hdim={} nil_rate={} single_type_path={}'.format(
            type_embed_dim, context_lstm_hidden_dim, pred_mlp_hdim, nil_rate, single_type_path))
    logging.info('rand_per={} per_pen={}'.format(rand_per, per_penalty))
    logging.info('stack_lstm={} cat_lstm={}'.format(stack_lstm, concat_lstm))

    if stack_lstm:
        model = FETELStack(
            device, gres.type_vocab, gres.type_id_dict, gres.embedding_layer, context_lstm_hidden_dim,
            type_embed_dim=type_embed_dim, dropout=dropout, use_mlp=use_mlp, mlp_hidden_dim=pred_mlp_hdim,
            concat_lstm=concat_lstm)
    else:
        # NOTE(review): with stack_lstm falsy (the default) the model stays
        # None, so `model.cuda(...)` / `model.parameters()` below raise
        # AttributeError. Presumably a non-stacked model class was meant to
        # be built here -- confirm.
        model = None
    if device.type == 'cuda':
        model = model.cuda(device.index)

    train_samples = datautils.load_pickle_data(train_samples_pkl)

    # Dev samples are converted once up front; training samples are
    # converted per batch inside the loop.
    dev_samples = datautils.load_pickle_data(dev_samples_pkl)
    dev_samples = anchor_samples_to_model_samples(dev_samples, gres.mention_token_id, gres.parent_type_ids_dict)

    lr_gamma = 0.7
    eval_batch_size = 32
    logging.info('{}'.format(model.__class__.__name__))
    # mention id -> type label strings (label ids mapped through the type vocab)
    dev_true_labels_dict = {s.mention_id: [gres.type_vocab[l] for l in s.labels] for s in dev_samples}
    dev_entity_vecs, dev_el_sgns, dev_el_probs = __get_entity_vecs_for_samples(el_entityvec, dev_samples, None)

    test_samples = model_samples_from_json(gres.token_id_dict, gres.unknown_token_id, gres.mention_token_id,
                                           gres.type_id_dict, test_mentions_file, test_sents_file)
    test_noel_pred_results = datautils.read_pred_results_file(test_noel_preds_file)

    test_mentions = datautils.read_json_objs(test_mentions_file)
    test_entity_vecs, test_el_sgns, test_el_probs = __get_entity_vecs_for_mentions(
        el_entityvec, test_mentions, test_noel_pred_results, gres.n_types)

    # True test labels are only available when the first mention carries a
    # 'labels' key; `next(iter(...))` raises StopIteration on an empty list.
    test_true_labels_dict = {m['mention_id']: m['labels'] for m in test_mentions} if (
            'labels' in next(iter(test_mentions))) else None

    # The person-specific loss vector is only built when the type set
    # contains '/person'.
    person_type_id = gres.type_id_dict.get('/person')
    l2_person_type_ids, person_loss_vec = None, None
    if person_type_id is not None:
        l2_person_type_ids = __get_l2_person_type_ids(gres.type_vocab)
        person_loss_vec = exputils.get_person_type_loss_vec(
            l2_person_type_ids, gres.n_types, per_penalty, model.device)

    # Never reassigned in the visible code, so the dev-results save branch
    # below is effectively disabled.
    dev_results_file = None
    # Ceiling division: number of batches per pass over the training data.
    n_batches = (len(train_samples) + batch_size - 1) // batch_size
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    # The learning rate is multiplied by lr_gamma every n_batches scheduler
    # steps, i.e. once per pass over the training data.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=n_batches, gamma=lr_gamma)
    losses = list()
    best_dev_acc = -1
    logging.info('{} steps, {} steps per iter, lr_decay={}, start training ...'.format(
        n_iter * n_batches, n_batches, lr_gamma))
    step = 0
    n_steps = n_iter * n_batches
    while step < n_steps:
        # Batches cycle through the training samples in their stored order.
        batch_idx = step % n_batches
        batch_beg, batch_end = batch_idx * batch_size, min((batch_idx + 1) * batch_size, len(train_samples))
        batch_samples = anchor_samples_to_model_samples(
            train_samples[batch_beg:batch_end], gres.mention_token_id, gres.parent_type_ids_dict)
        if rand_per:
            entity_vecs, el_sgns, el_probs = __get_entity_vecs_for_samples(
                el_entityvec, batch_samples, None, True, person_type_id, l2_person_type_ids, gres.type_vocab)
        else:
            entity_vecs, el_sgns, el_probs = __get_entity_vecs_for_samples(el_entityvec, batch_samples, None, True)

        # Always True here, so the `else` branch below is never taken.
        use_entity_vecs = True
        model.train()

        (context_token_seqs, mention_token_idxs, mstrs, mstr_token_seqs, y_true
         ) = exputils.get_mstr_cxt_label_batch_input(model.device, gres.n_types, batch_samples)

        if use_entity_vecs:
            # Zero out each entity vector with probability nil_rate.
            for i in range(entity_vecs.shape[0]):
                if np.random.uniform() < nil_rate:
                    entity_vecs[i] = np.zeros(entity_vecs.shape[1], np.float32)
            el_probs = torch.tensor(el_probs, dtype=torch.float32, device=model.device)
            entity_vecs = torch.tensor(entity_vecs, dtype=torch.float32, device=model.device)
        else:
            entity_vecs = None
        logits = model(context_token_seqs, mention_token_idxs, mstr_token_seqs, entity_vecs, el_probs)
        loss = model.get_loss(y_true, logits, person_loss_vec=person_loss_vec)
        # NOTE(review): the scheduler is stepped before optimizer.step();
        # PyTorch >= 1.1 documents the opposite order -- confirm against the
        # torch version in use.
        scheduler.step()
        optimizer.zero_grad()
        loss.backward()
        # Gradient clipping: max_norm=10.0 with the infinity norm.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 10.0, float('inf'))
        optimizer.step()
        losses.append(loss.data.cpu().numpy())

        step += 1
        if step % 1000 == 0:
            # logging.info('i={} l={:.4f}'.format(step + 1, sum(losses)))
            acc_v, pacc_v, _, _, dev_results = eval_fetel(
                gres, model, dev_samples, dev_entity_vecs, dev_el_probs, eval_batch_size,
                use_entity_vecs=use_entity_vecs, single_type_path=single_type_path,
                true_labels_dict=dev_true_labels_dict)
            acc_t, _, maf1, mif1, test_results = eval_fetel(
                gres, model, test_samples, test_entity_vecs, test_el_probs, eval_batch_size,
                use_entity_vecs=use_entity_vecs, single_type_path=single_type_path,
                true_labels_dict=test_true_labels_dict)

            # '*' marks an evaluation that improves on the best dev accuracy;
            # the logged loss is the sum since the previous evaluation.
            best_tag = '*' if acc_v > best_dev_acc else ''
            logging.info(
                'i={} l={:.4f} accv={:.4f} paccv={:.4f} acct={:.4f} maf1={:.4f} mif1={:.4f}{}'.format(
                    step, sum(losses), acc_v, pacc_v, acc_t, maf1, mif1, best_tag))
            if acc_v > best_dev_acc and save_model_file:
                torch.save(model.state_dict(), save_model_file)
                logging.info('model saved to {}'.format(save_model_file))

            if dev_results_file is not None and acc_v > best_dev_acc:
                datautils.save_json_objs(dev_results, dev_results_file)
                logging.info('dev reuslts saved {}'.format(dev_results_file))
            if results_file is not None and acc_v > best_dev_acc:
                datautils.save_json_objs(test_results, results_file)
                logging.info('test reuslts saved {}'.format(results_file))

            # best_dev_acc is updated last so that all checks above compare
            # against the previous best value.
            if acc_v > best_dev_acc:
                best_dev_acc = acc_v
            losses = list()