Example #1
def doc_texts_to_token_ids():
    max_seq_len = 128
    output_file = os.path.join(config.DATA_DIR, 'realm_data/blocks_tok_id_seqs.pkl')
    tfr_text_docs_file = os.path.join(config.DATA_DIR, 'realm_data/blocks.tfr')
    reader_module_path = '/data/hldai/data/realm_data/cc_news_pretrained/bert'
    vocab_file = os.path.join(reader_module_path, 'assets/vocab.txt')
    tokenizer = tokenization.FullTokenizer(vocab_file, do_lower_case=True)

    blocks_dataset = tf.data.TFRecordDataset(tfr_text_docs_file, buffer_size=512 * 1024 * 1024)

    tok_id_seqs = list()
    for i, v in enumerate(blocks_dataset):
        # each record holds the raw UTF-8 text of one block
        v = v.numpy().decode('utf-8')
        tokens = tokenizer.tokenize(v)
        token_ids = np.array(tokenizer.convert_tokens_to_ids(tokens), dtype=np.int32)
        # truncate to the maximum sequence length
        if len(token_ids) > max_seq_len:
            token_ids = token_ids[:max_seq_len]
        tok_id_seqs.append(token_ids)
        if i % 10000 == 0:
            print(i)

    datautils.save_pickle_data(tok_id_seqs, output_file)
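
A minimal sketch of inspecting the output, assuming `datautils.save_pickle_data` writes an ordinary pickle file; the path below is illustrative:

import pickle

# Illustrative path; the real one is built from config.DATA_DIR above.
with open('blocks_tok_id_seqs.pkl', 'rb') as f:
    tok_id_seqs = pickle.load(f)
print(len(tok_id_seqs), tok_id_seqs[0][:10])  # number of blocks, first token ids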
Example #2
def gen_word_freq_vec_pkl(word_freq_file, w2v_file, dim, filter_stop_words,
                          output_file):
    stop_words = None
    if filter_stop_words:
        from utils.stopwords import deeped_stop_words
        stop_words = deeped_stop_words

    with open(word_freq_file, encoding='utf-8') as f:
        df = pd.read_csv(f)
    word_freq_dict = {w: cnt for w, cnt in df.itertuples(False, None)}
    vocab, vecs = list(), list()
    with open(w2v_file, encoding='utf-8') as f:
        for line in f:
            parts = line.strip().split(' ')
            if len(parts) != dim + 1:
                # malformed line: report it and skip
                print(line)
                continue
            if parts[0] not in word_freq_dict or len(parts[0]) < 2:
                continue
            if stop_words is not None and parts[0].lower() in stop_words:
                continue
            vocab.append(parts[0])
            vec = np.asarray([float(v) for v in parts[1:]], np.float32)
            vecs.append(vec)

    freqs = np.asarray([word_freq_dict[w] for w in vocab], np.int32)
    vecs = np.asarray(vecs, np.float32)
    datautils.save_pickle_data((vocab, freqs, vecs), output_file)
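
A hedged usage sketch; the file names and the vector dimension below are placeholders, and the frequency file is assumed to be a two-column (word, count) CSV as the code above implies:

# Illustrative call; all arguments are placeholders.
gen_word_freq_vec_pkl(word_freq_file='word_freq.csv',
                      w2v_file='word_vectors_300d.txt',
                      dim=300,
                      filter_stop_words=True,
                      output_file='word_freq_vecs.pkl')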
Example #3
def predict_results(estimator, input_fn):
    results_file = os.path.join(config.DATA_DIR, 'tmp/uf_wia_results_200.txt')
    qemb_file = os.path.join(config.DATA_DIR, 'realm_output/uf_test_qembs.pkl')
    fout = open(results_file, 'w', encoding='utf-8')
    qembs = list()
    for i, pred in enumerate(estimator.predict(input_fn)):
        x = {
            'text_id': int(pred['text_ids']),
            'block_ids': [int(v) for v in pred['block_ids']]
        }
        qembs.append(pred['qemb'])
        fout.write('{}\n'.format(json.dumps(x)))
        if i > 2:
            # debugging limit: only the first few predictions are written;
            # remove this break to process the full prediction stream
            break
        if i % 100 == 0:
            print(i)
    fout.close()

    qembs = np.array(qembs)
    datautils.save_pickle_data(qembs, qemb_file)
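
A small sketch for reading the saved query embeddings back, assuming the pickle simply holds the `np.array` built above; the path is illustrative:

import pickle

with open('uf_test_qembs.pkl', 'rb') as f:  # illustrative path
    qembs = pickle.load(f)
print(qembs.shape)  # (n_queries, emb_dim)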
Example #4
def sample_from_realm_blocks():
    output_block_records_path = os.path.join(config.DATA_DIR, 'realm_data/blocks_2m.tfr')
    output_block_emb_path = os.path.join(config.DATA_DIR, 'realm_data/block_emb_2m.pkl')
    block_records_path = os.path.join(config.DATA_DIR, 'realm_data/blocks.tfr')
    retriever_module_path = os.path.join(config.DATA_DIR, 'realm_data/cc_news_pretrained/embedder')
    var_name = "block_emb"
    checkpoint_path = os.path.join(retriever_module_path, "encoded", "encoded.ckpt")
    np_db = tf.train.load_checkpoint(checkpoint_path).get_tensor(var_name)
    n_docs = np_db.shape[0]
    n_keep = 2000000
    idxs = np.random.permutation(np.arange(n_docs))[:n_keep]

    idxs = np.sort(idxs)
    datautils.save_pickle_data(np_db[idxs], output_block_emb_path)
    print(idxs[:10])

    idxs = set(idxs)

    blocks_dataset = tf.data.TFRecordDataset(
        block_records_path, buffer_size=512 * 1024 * 1024)
    with tf.io.TFRecordWriter(output_block_records_path) as file_writer:
        for i, v in enumerate(blocks_dataset):
            if i not in idxs:
                continue
            # write the raw serialized record of each sampled block
            file_writer.write(v.numpy())
            if i % 100000 == 0:
                print(i)
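
A quick sanity check of the sampled TFRecord file, assuming each record stores the raw block bytes written above; the path is illustrative:

import tensorflow as tf

ds = tf.data.TFRecordDataset('blocks_2m.tfr')  # illustrative path
for rec in ds.take(3):
    print(rec.numpy()[:80])  # first bytes of each sampled block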
Example #5
def replace_with_provided_ent_embed(provided_ent_vocab_file,
                                    provided_ent_embed_file,
                                    mrel_title_to_wid_file,
                                    self_trained_ent_embed_pkl, output_file):
    df = datautils.load_csv(mrel_title_to_wid_file)
    mrel_title_wid_dict = {
        title: wid
        for title, wid in df.itertuples(False, None)
    }
    mrel_entity_names = __read_mrel_ent_vocab_file(provided_ent_vocab_file)
    mrel_entity_vecs = np.load(provided_ent_embed_file)
    assert mrel_entity_vecs.shape[0] == len(mrel_entity_names)

    mrel_wids = [
        mrel_title_wid_dict.get(name, -1) for name in mrel_entity_names
    ]
    wid_vocab, entity_vecs = datautils.load_pickle_data(
        self_trained_ent_embed_pkl)
    wid_eid_dict = {wid: i for i, wid in enumerate(wid_vocab)}
    extra_entity_vecs = list()
    for wid, mrel_entity_vec in zip(mrel_wids, mrel_entity_vecs):
        if wid > -1:
            eid = wid_eid_dict.get(wid, None)
            if eid is not None:
                # overwrite the self-trained vector with the provided one
                entity_vecs[eid] = mrel_entity_vec
            else:
                # entity not in the self-trained vocab: append it as a new entry
                extra_entity_vecs.append(mrel_entity_vec)
                wid_vocab.append(wid)
    new_entity_vecs = np.zeros(
        (entity_vecs.shape[0] + len(extra_entity_vecs), entity_vecs.shape[1]),
        np.float32)
    for i, vec in enumerate(entity_vecs):
        new_entity_vecs[i] = vec
    for i, vec in enumerate(extra_entity_vecs):
        new_entity_vecs[i + entity_vecs.shape[0]] = vec
    datautils.save_pickle_data((wid_vocab, new_entity_vecs), output_file)
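
An illustrative call with placeholder file names; the vocab file and embedding matrix are assumed to follow the formats read by `__read_mrel_ent_vocab_file` and `np.load` above:

# All file names are placeholders.
replace_with_provided_ent_embed(provided_ent_vocab_file='mrel_ent_vocab.txt',
                                provided_ent_embed_file='mrel_ent_vecs.npy',
                                mrel_title_to_wid_file='mrel_title_wid.csv',
                                self_trained_ent_embed_pkl='self_ent_embed.pkl',
                                output_file='merged_ent_embed.pkl')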
Example #6
def gen_training_data_from_wiki(typed_mentions_file, sents_file, word_vecs_pkl, sample_rate,
                                n_dev_samples, output_files_name_prefix, core_title_wid_file=None):
    np.random.seed(config.RANDOM_SEED)

    core_wids = None
    if core_title_wid_file is not None:
        df = datautils.load_csv(core_title_wid_file)
        core_wids = {wid for _, wid in df.itertuples(False, None)}

    token_vocab, token_vecs = datautils.load_pickle_data(word_vecs_pkl)
    token_id_dict = {t: i for i, t in enumerate(token_vocab)}
    unknown_token_id = token_id_dict[config.TOKEN_UNK]

    f_mention = open(typed_mentions_file, encoding='utf-8')
    f_sent = open(sents_file, encoding='utf-8')
    all_samples = list()
    cur_sent = json.loads(next(f_sent))
    mention_id = 0
    for i, line in enumerate(f_mention):
        if (i + 1) % 1000000 == 0:
            print(i + 1)

        # subsample mentions at the given rate
        v = np.random.uniform()
        if v > sample_rate:
            continue

        (wid, mention_str, sent_id, pos_beg, pos_end, target_wid, type_ids
         ) = datautils.parse_typed_mention_file_line(line)
        if core_wids is not None and target_wid not in core_wids:
            continue

        mention_str = mention_str.replace('-LRB-', '(').replace('-RRB-', ')')
        # advance the sentence file until it matches the current mention
        while not (cur_sent['wid'] == wid and cur_sent['sent_id'] == sent_id):
            cur_sent = json.loads(next(f_sent))
        sent_tokens = cur_sent['tokens'].split(' ')
        sent_token_ids = [token_id_dict.get(token, unknown_token_id) for token in sent_tokens]

        sample = (mention_id, mention_str, pos_beg, pos_end, target_wid, type_ids, sent_token_ids)
        mention_id += 1
        all_samples.append(sample)
    f_mention.close()
    f_sent.close()

    dev_samples = all_samples[:n_dev_samples]
    train_samples = all_samples[n_dev_samples:]

    print('shuffling ...', end=' ', flush=True)
    rand_perm = np.random.permutation(len(train_samples))
    train_samples = [train_samples[idx] for idx in rand_perm]
    print('done')

    dev_mentions, dev_sents = list(), list()
    for i, sample in enumerate(dev_samples):
        mention_id, mention_str, pos_beg, pos_end, target_wid, type_ids, sent_token_ids = sample
        mention = {'mention_id': mention_id, 'span': [pos_beg, pos_end], 'str': mention_str, 'sent_id': i}
        sent = {'sent_id': i, 'text': ' '.join([token_vocab[token_id] for token_id in sent_token_ids]),
                'afet-senid': 0, 'file_id': 'null'}
        dev_mentions.append(mention)
        dev_sents.append(sent)
    datautils.save_json_objs(dev_mentions, output_files_name_prefix + '-dev-mentions.txt')
    datautils.save_json_objs(dev_sents, output_files_name_prefix + '-dev-sents.txt')

    datautils.save_pickle_data(dev_samples, output_files_name_prefix + '-dev.pkl')
    datautils.save_pickle_data(train_samples, output_files_name_prefix + '-train.pkl')
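
A hedged usage sketch; the paths are placeholders and the sample rate and dev-split size are arbitrary values for illustration:

gen_training_data_from_wiki(typed_mentions_file='typed_mentions.txt',
                            sents_file='wiki_sents.txt',
                            word_vecs_pkl='word_vecs.pkl',
                            sample_rate=0.1,
                            n_dev_samples=2000,
                            output_files_name_prefix='wiki-fet',
                            core_title_wid_file=None)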
Example #7
def traindeschyp(device,
                 word_freq_vec_pkl,
                 title_wid_file,
                 articles_file,
                 anchor_cxt_file,
                 n_words_per_ent,
                 batch_size,
                 lr,
                 output_file,
                 hyp_ctxt_len=10,
                 n_desc_iter=200,
                 n_neg_words=5,
                 unig_power=0.6):
    word_vocab, freqs, word_vecs = datautils.load_pickle_data(
        word_freq_vec_pkl)
    word_to_id_dict = {w: i for i, w in enumerate(word_vocab)}
    n_words = word_vecs.shape[0]
    title_wid_df = datautils.load_csv(title_wid_file, False)
    wid_vocab = [wid for _, wid in title_wid_df.itertuples(False, None)]
    entity_titles = [
        title for title, _ in title_wid_df.itertuples(False, None)
    ]
    entity_title_word_ids = get_entity_title_word_ids(entity_titles,
                                                      word_to_id_dict)
    wid_idx_dict = {wid: i for i, wid in enumerate(wid_vocab)}
    n_entities = len(wid_vocab)

    logging.info('{} entities, {} words, word_vecs: {}'.format(
        n_entities, n_words, word_vecs.shape))
    neg_word_gen = RandNegWordGen(word_vocab, freqs, unig_power)

    init_entity_vecs = get_init_entity_vecs(entity_title_word_ids, word_vecs)
    model = deschypembed.DescHypEmbed(word_vecs, n_entities, init_entity_vecs)
    if device.type == 'cuda':
        model = model.cuda(device.index)
    optimizer = torch.optim.Adagrad(model.parameters(), lr=lr)

    desc_data_producer = DescDataProducer(neg_word_gen, word_to_id_dict,
                                          wid_idx_dict, articles_file,
                                          entity_title_word_ids,
                                          n_words_per_ent, batch_size,
                                          n_desc_iter)
    anchor_data_producer = AnchorDataProducer(neg_word_gen, anchor_cxt_file,
                                              wid_idx_dict,
                                              entity_title_word_ids,
                                              hyp_ctxt_len, n_words_per_ent,
                                              batch_size, 15)
    use_desc_data = True
    num_batches_per_epoch = 8000
    step = 0
    losses = list()
    while True:
        model.train()
        if use_desc_data:
            entity_ids, pos_word_ids_batch = desc_data_producer.get_pos_words_batch()
            if len(entity_ids) == 0:
                use_desc_data = False
                logging.info('start using anchor data ...')
                continue
        else:
            entity_ids, pos_word_ids_batch = anchor_data_producer.get_pos_words_batch()
            if len(entity_ids) == 0:
                break

        word_ids_batch, pos_word_idxs_batch, ttn = fill_neg_words(
            neg_word_gen, pos_word_ids_batch, n_neg_words)

        cur_batch_size = len(entity_ids)
        word_ids = torch.tensor(word_ids_batch,
                                dtype=torch.long,
                                device=device)
        entity_ids_tt = torch.tensor(entity_ids,
                                     dtype=torch.long,
                                     device=device)
        target_idxs = torch.tensor(pos_word_idxs_batch,
                                   dtype=torch.long,
                                   device=device)
        scores = model(cur_batch_size, word_ids, entity_ids_tt)
        scores = scores.view(-1, n_neg_words)
        target_scores = scores[list(range(cur_batch_size * n_words_per_ent)),
                               target_idxs.view(-1)]
        # margin ranking loss with a margin of 0.1
        loss = torch.mean(F.relu(0.1 - target_scores.view(-1, 1) + scores))
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 10.0, float('inf'))
        optimizer.step()
        losses.append(loss.data.cpu().numpy())
        step += 1
        if step % num_batches_per_epoch == 0:
            logging.info('i={}, loss={}'.format(step, sum(losses)))
            losses = list()

            if output_file:
                entity_vecs = F.normalize(model.entity_embed_layer.weight)
                datautils.save_pickle_data(
                    (wid_vocab, entity_vecs.data.cpu().numpy()), output_file)
                logging.info('saved model to {}'.format(output_file))
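
An illustrative training call, assuming a standard PyTorch device; all file names and hyperparameter values below are placeholders, not values from the source:

import torch

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
traindeschyp(device,
             word_freq_vec_pkl='word_freq_vecs.pkl',
             title_wid_file='title_wid.csv',
             articles_file='articles.txt',
             anchor_cxt_file='anchor_contexts.txt',
             n_words_per_ent=20,
             batch_size=128,
             lr=0.05,
             output_file='entity_vecs.pkl')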
Example #8
def gen_training_data_from_wiki(typed_mentions_file,
                                sents_file,
                                word_vecs_pkl,
                                sample_rate,
                                n_dev_samples,
                                output_files_name_prefix,
                                core_title_wid_file=None,
                                do_bert=False):
    np.random.seed(config.RANDOM_SEED)
    print('output file destination: {}'.format(output_files_name_prefix))

    if do_bert:
        tokenizer = BertTokenizer.from_pretrained('bert-base-cased',
                                                  do_lower_case=False)

    core_wids = None
    if core_title_wid_file is not None:
        df = datautils.load_csv(core_title_wid_file)
        core_wids = {wid for _, wid in df.itertuples(False, None)}

    print('loading word vec...')
    token_vocab, token_vecs = datautils.load_pickle_data(word_vecs_pkl)
    token_id_dict = {t: i for i, t in enumerate(token_vocab)}
    unknown_token_id = token_id_dict[config.TOKEN_UNK]

    f_mention = open(typed_mentions_file, encoding='utf-8')
    f_sent = open(sents_file, encoding='utf-8')
    all_samples = list()
    cur_sent = json.loads(next(f_sent))
    mention_id = 0
    for i, line in enumerate(f_mention):
        if (i + 1) % 100000 == 0:
            print(i + 1)

        # subsample mentions at the given rate
        v = np.random.uniform()
        if v > sample_rate:
            continue

        (wid, mention_str, sent_id, pos_beg, pos_end, target_wid,
         type_ids) = datautils.parse_typed_mention_file_line(line)
        if core_wids is not None and target_wid not in core_wids:
            continue

        mention_str = mention_str.replace('-LRB-', '(').replace('-RRB-', ')')
        while not (cur_sent['wid'] == wid and cur_sent['sent_id'] == sent_id):
            cur_sent = json.loads(next(f_sent))
        sent_tokens = cur_sent['tokens'].split(' ')
        sent_token_ids = [
            token_id_dict.get(token, unknown_token_id) for token in sent_tokens
        ]

        if not do_bert:
            sample = (mention_id, mention_str, pos_beg, pos_end, target_wid,
                      type_ids, sent_token_ids)
        else:
            # replace the mention span with a single [MASK] token
            sent_tokens = sent_tokens[:pos_beg] + ['[MASK]'] + sent_tokens[pos_end:]
            full_sent = ' '.join(sent_tokens)
            tokens = ["[CLS]"]
            tokens.extend(tokenizer.tokenize(full_sent))
            # locate the [MASK] token in the BERT-tokenized sequence
            mention_token_idx_bert = 0
            for j, x in enumerate(tokens):
                if x == '[MASK]':
                    mention_token_idx_bert = j
                    break
            tokens.append("[SEP]")
            sent_token_bert_ids = tokenizer.convert_tokens_to_ids(tokens)

            sample = (mention_id, mention_str, pos_beg, pos_end, target_wid,
                      type_ids, sent_token_ids, sent_token_bert_ids,
                      mention_token_idx_bert)

        mention_id += 1
        all_samples.append(sample)
        if (i + 1) % 100000 == 0:
            print(i + 1, mention_str)
            print(sent_token_ids)
            if do_bert:
                print(sent_token_bert_ids)
            print()

    f_mention.close()
    f_sent.close()

    dev_samples = all_samples[:n_dev_samples]
    train_samples = all_samples[n_dev_samples:]

    print('shuffling ...', end=' ', flush=True)
    rand_perm = np.random.permutation(len(train_samples))
    train_samples = [train_samples[idx] for idx in rand_perm]
    print('done')

    dev_mentions, dev_sents = list(), list()
    for i, sample in enumerate(dev_samples):
        if do_bert:
            mention_id, mention_str, pos_beg, pos_end, target_wid, type_ids, sent_token_ids, \
            sent_token_bert_ids, mention_token_idx_bert = sample
        else:
            mention_id, mention_str, pos_beg, pos_end, target_wid, type_ids, sent_token_ids = sample
        mention = {
            'mention_id': mention_id,
            'span': [pos_beg, pos_end],
            'str': mention_str,
            'sent_id': i
        }
        sent = {
            'sent_id': i,
            'text': ' '.join([token_vocab[token_id] for token_id in sent_token_ids]),
            'afet-senid': 0,
            'file_id': 'null'
        }
        dev_mentions.append(mention)
        dev_sents.append(sent)
    datautils.save_json_objs(dev_mentions,
                             output_files_name_prefix + '-dev-mentions.txt')
    datautils.save_json_objs(dev_sents,
                             output_files_name_prefix + '-dev-sents.txt')
    print('saving pickle data...')
    datautils.save_pickle_data(dev_samples,
                               output_files_name_prefix + '-dev.pkl')
    datautils.save_pickle_data(train_samples,
                               output_files_name_prefix + '-train.pkl')
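
The same kind of call as in Example #6, but with BERT token ids included; the paths and values are again placeholders:

gen_training_data_from_wiki(typed_mentions_file='typed_mentions.txt',
                            sents_file='wiki_sents.txt',
                            word_vecs_pkl='word_vecs.pkl',
                            sample_rate=0.1,
                            n_dev_samples=2000,
                            output_files_name_prefix='wiki-fet-bert',
                            core_title_wid_file=None,
                            do_bert=True)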
Example #9
def gen_embs():
    import numpy as np

    embedder_module_path = os.path.join(
        config.DATA_DIR, 'realm_data/cc_news_pretrained/embedder')
    reader_module_path = os.path.join(config.DATA_DIR,
                                      'realm_data/cc_news_pretrained/bert')

    output_emb_file = os.path.join(config.DATA_DIR, 'tmp/blocks10.pkl')
    block_records_path = os.path.join(config.DATA_DIR, 'tmp/blocks10.tfr')
    params = {
        'n_blocks': 12,
        'block_records_path': block_records_path,
        'reader_module_path': reader_module_path,
        'embedder_module_path': embedder_module_path
    }

    # output_emb_file = os.path.join(config.DATA_DIR, 'ultrafine/rlm_fet/enwiki-20151002-type-sents-2m-emb.pkl')
    # block_records_path = os.path.join(config.DATA_DIR, 'ultrafine/rlm_fet/enwiki-20151002-type-sents-2m.tfr')
    # params = {'n_blocks': 2000007, 'block_records_path': block_records_path, 'reader_module_path': reader_module_path,
    #           'embedder_module_path': embedder_module_path}

    # output_emb_file = os.path.join(config.DATA_DIR, 'ultrafine/zoutput/webisa_full_uffilter_emb.pkl')
    # block_records_path = os.path.join(config.DATA_DIR, 'ultrafine/zoutput/webisa_full_uffilter.tfr')
    # params = {'n_blocks': 1671143, 'block_records_path': block_records_path, 'reader_module_path': reader_module_path,
    #           'embedder_module_path': embedder_module_path}

    logger = tf.get_logger()
    logger.setLevel('INFO')
    logger.propagate = False

    def input_fn():
        batch_size = 14

        def data_gen():
            block_ids = list()
            for block_id in range(params['n_blocks']):
                block_ids.append(block_id)
                if len(block_ids) >= batch_size:
                    yield {'block_ids': block_ids}
                    block_ids = list()
            if len(block_ids) > 0:
                yield {'block_ids': block_ids}

        dataset = tf.data.Dataset.from_generator(
            data_gen,
            output_signature={
                'block_ids': tf.TensorSpec(shape=None, dtype=tf.int32),
            })
        return dataset

    run_config = tf.estimator.RunConfig(model_dir=None,
                                        log_step_count_steps=100000,
                                        save_checkpoints_steps=None,
                                        save_checkpoints_secs=None,
                                        tf_random_seed=1973)
    estimator = tf.estimator.Estimator(config=run_config,
                                       model_fn=gen_emb_model_fn,
                                       params=params,
                                       model_dir=None)

    embs_list = list()
    for i, v in enumerate(estimator.predict(input_fn)):
        embs_list.append(v)
        if i % 1000 == 0:
            print(i)

    datautils.save_pickle_data(np.array(embs_list, dtype=np.float32),
                               output_emb_file)
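
A minimal check of the saved block embeddings, assuming the pickle holds the float32 array built above; the path is illustrative:

import pickle

with open('blocks10.pkl', 'rb') as f:  # illustrative path
    block_embs = pickle.load(f)
print(block_embs.shape, block_embs.dtype)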