Example #1
0
def gen_full_embed_title_wid_file(mrel_title_wid_file, core_title_wids_file,
                                  all_title_wid_file, output_file):
    """Write the (title, wid) rows of every entity that should get an embedding.

    The output is the union of:
      * the wids in the second column of ``mrel_title_wid_file``, titled via
        the wid -> title mapping loaded from ``all_title_wid_file``;
      * the (title, wid) rows of ``core_title_wids_file``, whose title wins
        when a wid appears in both files.

    Rows are saved as CSV with header ['title', 'wid'] to ``output_file``.
    A KeyError is raised if an mrel wid is absent from the full mapping
    (assuming ``load_title_wid_file`` returns a plain dict).
    """
    all_wid_to_title = datautils.load_title_wid_file(all_title_wid_file,
                                                     to_wid_title_dict=True)

    wid_to_title = dict()
    mrel_rows = datautils.load_csv(mrel_title_wid_file).itertuples(False, None)
    for _, mrel_wid in mrel_rows:
        # Indexing (not .get) so unknown wids fail loudly.
        wid_to_title[mrel_wid] = all_wid_to_title[mrel_wid]

    core_rows = datautils.load_csv(core_title_wids_file).itertuples(False, None)
    # Core titles override the full-mapping titles for shared wids.
    wid_to_title.update((core_wid, core_title)
                        for core_title, core_wid in core_rows)

    title_wid_rows = [(title, wid) for wid, title in wid_to_title.items()]
    datautils.save_csv(title_wid_rows, ['title', 'wid'], output_file)
Example #2
0
def filter_wiki_for_training(required_title_wid_file, word_freq_vec_pkl,
                             articles_file, anchor_cxt_file,
                             output_articles_file, output_anchor_cxt_file):
    """Filter wiki articles and anchor contexts down to the required entities.

    Only two inputs are used:
      * the second (wid) column of ``required_title_wid_file``;
      * the vocabulary, i.e. the first element of the
        (vocab, freqs, vecs) triple in ``word_freq_vec_pkl``.

    Both are passed to ``filter_wiki_articles`` and ``filter_anchor_cxt``,
    which presumably write ``output_articles_file`` and
    ``output_anchor_cxt_file`` respectively.
    """
    required_rows = datautils.load_csv(required_title_wid_file).itertuples(
        False, None)
    wids = set()
    for _, required_wid in required_rows:
        wids.add(required_wid)

    # The triple is fully unpacked (validating its length); only the vocab is used.
    word_vocab, _freqs, _word_vecs = datautils.load_pickle_data(
        word_freq_vec_pkl)
    word_to_id_dict = dict((word, idx) for idx, word in enumerate(word_vocab))

    filter_wiki_articles(wids, word_to_id_dict, articles_file,
                         output_articles_file)
    filter_anchor_cxt(wids, word_to_id_dict, anchor_cxt_file,
                      output_anchor_cxt_file)
Example #3
0
def gen_candidates(wcwy_pem_file, title_wid_file, mstr='Japan', top_k=10):
    """Print the top candidate entities for one mention string (debug helper).

    Args:
        wcwy_pem_file: pickle file holding a dict keyed by mention string
            (presumably mention -> candidate entity probabilities).
        title_wid_file: CSV of (title, wid) rows, used to show entity titles.
        mstr: mention string to inspect; defaults to 'Japan' (the value that
            used to be hard-coded).
        top_k: number of candidates to print; defaults to 10 (previously
            hard-coded).

    The candidates stored for ``mstr`` may be either a sequence of
    (wid, prob) pairs, printed in stored order, or a dict mapping
    wid -> prob, which is sorted by descending probability first.

    Raises:
        KeyError: if ``mstr`` is not a key of the pem dict.
    """
    print(f'loading {title_wid_file} ...')
    wid_title_dict = {
        wid: title
        for title, wid in datautils.load_csv(title_wid_file).itertuples(
            False, None)
    }
    print(f'loading {wcwy_pem_file} ...')
    mstr_pem_dict = datautils.load_pickle_data(wcwy_pem_file)

    wid_probs = mstr_pem_dict[mstr]
    if isinstance(wid_probs, dict):
        # Dicts cannot be sliced; rank their (wid, prob) items by probability.
        wid_probs = sorted(wid_probs.items(), key=lambda x: -x[1])

    print(len(wid_probs))
    for wid, prob in wid_probs[:top_k]:
        print(wid, wid_title_dict.get(wid), prob)
Example #4
0
def replace_with_provided_ent_embed(provided_ent_vocab_file,
                                    provided_ent_embed_file,
                                    mrel_title_to_wid_file,
                                    self_trained_ent_embed_pkl, output_file):
    """Overwrite self-trained entity embeddings with externally provided ones.

    Each provided entity name is mapped to a wid through
    ``mrel_title_to_wid_file``; names without a wid are skipped.

    - Wids already present in the self-trained vocabulary have their row
      replaced in place.
    - Unseen wids are appended, once each, to the vocabulary and as new rows.
    - When several provided entities map to the same wid, the last one wins
      in both cases.

    The result is pickled to ``output_file`` as (wid_vocab, float32 matrix).

    Raises:
        ValueError: if the provided embedding matrix and vocabulary file
            disagree in length.
    """
    df = datautils.load_csv(mrel_title_to_wid_file)
    mrel_title_wid_dict = {
        title: wid
        for title, wid in df.itertuples(False, None)
    }
    mrel_entity_names = __read_mrel_ent_vocab_file(provided_ent_vocab_file)
    mrel_entity_vecs = np.load(provided_ent_embed_file)
    # Explicit check instead of assert so it also runs under `python -O`.
    if mrel_entity_vecs.shape[0] != len(mrel_entity_names):
        raise ValueError(
            'provided embeddings have {} rows but the vocab has {} names'.format(
                mrel_entity_vecs.shape[0], len(mrel_entity_names)))

    # -1 marks provided names that have no known wid; they are skipped below.
    mrel_wids = [
        mrel_title_wid_dict.get(name, -1) for name in mrel_entity_names
    ]
    wid_vocab, entity_vecs = datautils.load_pickle_data(
        self_trained_ent_embed_pkl)
    n_base = entity_vecs.shape[0]
    wid_eid_dict = {wid: i for i, wid in enumerate(wid_vocab)}
    extra_entity_vecs = list()
    for wid, mrel_entity_vec in zip(mrel_wids, mrel_entity_vecs):
        if wid <= -1:
            continue
        eid = wid_eid_dict.get(wid)
        if eid is None:
            # New entity: record its future row index so a later duplicate wid
            # overwrites this row instead of appending a second copy.
            wid_eid_dict[wid] = n_base + len(extra_entity_vecs)
            extra_entity_vecs.append(mrel_entity_vec)
            wid_vocab.append(wid)
        elif eid < n_base:
            entity_vecs[eid] = mrel_entity_vec
        else:
            extra_entity_vecs[eid - n_base] = mrel_entity_vec

    new_entity_vecs = np.zeros(
        (n_base + len(extra_entity_vecs), entity_vecs.shape[1]), np.float32)
    new_entity_vecs[:n_base] = entity_vecs
    # Row-by-row copy handles an empty extra list without shape issues.
    for offset, vec in enumerate(extra_entity_vecs):
        new_entity_vecs[n_base + offset] = vec
    datautils.save_pickle_data((wid_vocab, new_entity_vecs), output_file)
Example #5
0
def gen_training_data_from_wiki(typed_mentions_file, sents_file, word_vecs_pkl, sample_rate,
                                n_dev_samples, output_files_name_prefix, core_title_wid_file=None):
    """Sample typed mentions from Wikipedia and write dev/train data files.

    Each line of ``typed_mentions_file``:
      * is kept with probability ``sample_rate`` (a uniform draw <= rate);
      * when ``core_title_wid_file`` is given, is also kept only if its
        target wid is listed there.

    Matching sentences are found by scanning ``sents_file`` forward until
    both 'wid' and 'sent_id' agree. Both files must therefore be in the same
    order; a missing sentence raises StopIteration from ``next``. Sentence
    tokens are mapped to ids, with ``config.TOKEN_UNK`` as the fallback.

    Each sample is the tuple
        (mention_id, mention_str, pos_beg, pos_end, target_wid, type_ids,
         sent_token_ids).
    The first ``n_dev_samples`` samples (file order) form the dev set; the
    rest are shuffled into the train set. Written files, using
    ``output_files_name_prefix``:
        '-dev-mentions.txt', '-dev-sents.txt'  (JSON objects)
        '-dev.pkl', '-train.pkl'               (pickled sample lists)
    Randomness is seeded with ``config.RANDOM_SEED``.
    """
    np.random.seed(config.RANDOM_SEED)

    core_wids = None
    if core_title_wid_file is not None:
        df = datautils.load_csv(core_title_wid_file)
        core_wids = {wid for _, wid in df.itertuples(False, None)}

    token_vocab, _token_vecs = datautils.load_pickle_data(word_vecs_pkl)
    token_id_dict = {t: i for i, t in enumerate(token_vocab)}
    unknown_token_id = token_id_dict[config.TOKEN_UNK]

    all_samples = list()
    mention_id = 0
    # `with` closes both files even if parsing fails part-way through.
    with open(typed_mentions_file, encoding='utf-8') as f_mention, \
            open(sents_file, encoding='utf-8') as f_sent:
        cur_sent = json.loads(next(f_sent))
        for i, line in enumerate(f_mention):
            if (i + 1) % 1000000 == 0:
                print(i + 1)

            # Draw for every line (before filtering) so the RNG stream only
            # depends on the line count.
            v = np.random.uniform()
            if v > sample_rate:
                continue

            (wid, mention_str, sent_id, pos_beg, pos_end, target_wid, type_ids
             ) = datautils.parse_typed_mention_file_line(line)
            if core_wids is not None and target_wid not in core_wids:
                continue

            mention_str = mention_str.replace('-LRB-', '(').replace('-RRB-', ')')
            while not (cur_sent['wid'] == wid and cur_sent['sent_id'] == sent_id):
                cur_sent = json.loads(next(f_sent))
            sent_tokens = cur_sent['tokens'].split(' ')
            sent_token_ids = [token_id_dict.get(token, unknown_token_id) for token in sent_tokens]

            sample = (mention_id, mention_str, pos_beg, pos_end, target_wid, type_ids, sent_token_ids)
            mention_id += 1
            all_samples.append(sample)

    dev_samples = all_samples[:n_dev_samples]
    train_samples = all_samples[n_dev_samples:]

    print('shuffling ...', end=' ', flush=True)
    rand_perm = np.random.permutation(len(train_samples))
    train_samples = [train_samples[idx] for idx in rand_perm]
    print('done')

    dev_mentions, dev_sents = list(), list()
    for i, sample in enumerate(dev_samples):
        mention_id, mention_str, pos_beg, pos_end, target_wid, type_ids, sent_token_ids = sample
        mention = {'mention_id': mention_id, 'span': [pos_beg, pos_end], 'str': mention_str, 'sent_id': i}
        sent = {'sent_id': i, 'text': ' '.join([token_vocab[token_id] for token_id in sent_token_ids]),
                'afet-senid': 0, 'file_id': 'null'}
        dev_mentions.append(mention)
        dev_sents.append(sent)
    datautils.save_json_objs(dev_mentions, output_files_name_prefix + '-dev-mentions.txt')
    datautils.save_json_objs(dev_sents, output_files_name_prefix + '-dev-sents.txt')

    datautils.save_pickle_data(dev_samples, output_files_name_prefix + '-dev.pkl')
    datautils.save_pickle_data(train_samples, output_files_name_prefix + '-train.pkl')
Example #6
0
def traindeschyp(device,
                 word_freq_vec_pkl,
                 title_wid_file,
                 articles_file,
                 anchor_cxt_file,
                 n_words_per_ent,
                 batch_size,
                 lr,
                 output_file,
                 hyp_ctxt_len=10,
                 n_desc_iter=200,
                 n_neg_words=5,
                 unig_power=0.6):
    """Train entity embeddings on description data, then on anchor contexts.

    Training runs in two phases, each driven by ``get_pos_words_batch``:
      1. ``DescDataProducer`` (built from ``articles_file``), until it
         returns an empty batch;
      2. ``AnchorDataProducer`` (built from ``anchor_cxt_file``), until it
         also returns an empty batch, at which point the loop ends.

    Args:
        device: torch device. The model is moved to ``device.index`` when
            ``device.type == 'cuda'``.
        word_freq_vec_pkl: pickle holding a (word_vocab, freqs, word_vecs)
            triple.
        title_wid_file: CSV of (title, wid) rows. Row order defines each
            entity's index (``wid_idx_dict``) and the row order of the saved
            embeddings.
        articles_file: description data passed to DescDataProducer.
        anchor_cxt_file: anchor-context data passed to AnchorDataProducer.
        n_words_per_ent: passed to both producers. Also used below to
            assume ``cur_batch_size * n_words_per_ent`` score rows per batch.
        batch_size: passed to both producers (presumably entities per batch).
        lr: Adagrad learning rate.
        output_file: if truthy, L2-normalized entity embeddings are pickled
            here as (wid_vocab, ndarray) every 8000 steps.
        hyp_ctxt_len: passed to AnchorDataProducer (presumably a context
            length — TODO confirm).
        n_desc_iter: passed to DescDataProducer (presumably the number of
            passes over the descriptions — TODO confirm).
        n_neg_words: passed to fill_neg_words; also the width of each score
            row (see ``scores.view`` below).
        unig_power: passed to RandNegWordGen (presumably the exponent applied
            to word frequencies for negative sampling — TODO confirm).

    Returns None.

    NOTE(review): embeddings are saved only when ``step`` is a multiple of
    8000. Updates made after the last such step are not written when the
    loop ends.
    """
    word_vocab, freqs, word_vecs = datautils.load_pickle_data(
        word_freq_vec_pkl)
    word_to_id_dict = {w: i for i, w in enumerate(word_vocab)}
    n_words = word_vecs.shape[0]
    # NOTE(review): meaning of load_csv's second argument is defined in datautils — TODO confirm.
    title_wid_df = datautils.load_csv(title_wid_file, False)
    wid_vocab = [wid for _, wid in title_wid_df.itertuples(False, None)]
    entity_titles = [
        title for title, _ in title_wid_df.itertuples(False, None)
    ]
    entity_title_word_ids = get_entity_title_word_ids(entity_titles,
                                                      word_to_id_dict)
    wid_idx_dict = {wid: i for i, wid in enumerate(wid_vocab)}
    n_entities = len(wid_vocab)

    logging.info('{} entities, {} words, word_ves: {}'.format(
        n_entities, n_words, word_vecs.shape))
    neg_word_gen = RandNegWordGen(word_vocab, freqs, unig_power)

    # Initial entity vectors are derived from title word ids and word vectors
    # (presumably an aggregate of title-word vectors — see get_init_entity_vecs).
    init_entity_vecs = get_init_entity_vecs(entity_title_word_ids, word_vecs)
    model = deschypembed.DescHypEmbed(word_vecs, n_entities, init_entity_vecs)
    if device.type == 'cuda':
        model = model.cuda(device.index)
    optimizer = torch.optim.Adagrad(model.parameters(), lr=lr)

    desc_data_producer = DescDataProducer(neg_word_gen, word_to_id_dict,
                                          wid_idx_dict, articles_file,
                                          entity_title_word_ids,
                                          n_words_per_ent, batch_size,
                                          n_desc_iter)
    # NOTE(review): the meaning of the literal 15 is defined by AnchorDataProducer — consider naming it.
    anchor_data_producer = AnchorDataProducer(neg_word_gen, anchor_cxt_file,
                                              wid_idx_dict,
                                              entity_title_word_ids,
                                              hyp_ctxt_len, n_words_per_ent,
                                              batch_size, 15)
    use_desc_data = True
    # Logging/checkpoint interval, in steps (batches).
    num_batches_per_epoch = 8000
    step = 0
    losses = list()
    while True:
        model.train()
        # Phase 1: description batches; an empty batch switches to phase 2.
        if use_desc_data:
            entity_ids, pos_word_ids_batch = desc_data_producer.get_pos_words_batch(
            )
            if len(entity_ids) == 0:
                use_desc_data = False
                logging.info('start using anchor data ...')
                continue
        else:
            # Phase 2: anchor-context batches; an empty batch ends training.
            entity_ids, pos_word_ids_batch = anchor_data_producer.get_pos_words_batch(
            )
            # print(pos_word_ids_batch)
            if len(entity_ids) == 0:
                break

        # Mix negative words into the positives. pos_word_idxs_batch gives each
        # positive word's position; ttn is unused here.
        word_ids_batch, pos_word_idxs_batch, ttn = fill_neg_words(
            neg_word_gen, pos_word_ids_batch, n_neg_words)

        cur_batch_size = len(entity_ids)
        word_ids = torch.tensor(word_ids_batch,
                                dtype=torch.long,
                                device=device)
        entity_ids_tt = torch.tensor(entity_ids,
                                     dtype=torch.long,
                                     device=device)
        target_idxs = torch.tensor(pos_word_idxs_batch,
                                   dtype=torch.long,
                                   device=device)
        scores = model(cur_batch_size, word_ids, entity_ids_tt)
        # Rows of n_neg_words candidate scores. target_idxs selects the positive
        # word's column in each row; this assumes fill_neg_words uses that
        # layout — TODO confirm.
        scores = scores.view(-1, n_neg_words)
        # Assumes exactly cur_batch_size * n_words_per_ent rows.
        target_scores = scores[list(range(cur_batch_size * n_words_per_ent)),
                               target_idxs.view(-1)]
        # target_scores = get_target_scores_with_try()
        # Hinge loss with margin 0.1 against every candidate in the row. The
        # positive's own column always contributes the constant relu(0.1).
        loss = torch.mean(F.relu(0.1 - target_scores.view(-1, 1) + scores))
        optimizer.zero_grad()
        loss.backward()
        # norm_type=inf: clip so the largest absolute gradient entry is <= 10.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 10.0, float('inf'))
        optimizer.step()
        # Per-batch losses are summed (not averaged) when logged below.
        losses.append(loss.data.cpu().numpy())
        step += 1
        if step % num_batches_per_epoch == 0:
            logging.info('i={}, loss={}'.format(step, sum(losses)))
            # logging.info('i={}, loss={}, t0={:.4f}, t1={:.4f}, t2={:.4f}, t3={:.4f}, tn={:.4f}'.format(
            #     step, sum(losses), t0, t1, t2, t3, tn))
            losses = list()

            if output_file:
                # F.normalize defaults to L2 over dim 1, i.e. unit-length rows.
                entity_vecs = F.normalize(model.entity_embed_layer.weight)
                datautils.save_pickle_data(
                    (wid_vocab, entity_vecs.data.cpu().numpy()), output_file)
                logging.info('saved model to {}'.format(output_file))
Example #7
0
def wid_vocab_from_title_wid_file(title_wid_file):
    """Return the second-column values (wids) of a two-column CSV, in file order."""
    title_wid_rows = datautils.load_csv(title_wid_file).itertuples(False, None)
    wid_vocab = [row_wid for _, row_wid in title_wid_rows]
    return wid_vocab
Example #8
0
def gen_training_data_from_wiki(typed_mentions_file,
                                sents_file,
                                word_vecs_pkl,
                                sample_rate,
                                n_dev_samples,
                                output_files_name_prefix,
                                core_title_wid_file=None,
                                do_bert=False):
    """Sample typed mentions from Wikipedia and write dev/train data files.

    Each line of ``typed_mentions_file``:
      * is kept with probability ``sample_rate`` (a uniform draw <= rate);
      * when ``core_title_wid_file`` is given, is also kept only if its
        target wid is listed there.

    Matching sentences are found by scanning ``sents_file`` forward until
    both 'wid' and 'sent_id' agree. Both files must therefore be in the same
    order; a missing sentence raises StopIteration from ``next``.

    Each sample is the tuple
        (mention_id, mention_str, pos_beg, pos_end, target_wid, type_ids,
         sent_token_ids)
    With ``do_bert`` two fields are appended:
      * ``sent_token_bert_ids``: the 'bert-base-cased' ids of
        [CLS] + sentence + [SEP], where the mention span is replaced by a
        single [MASK] token;
      * ``mention_token_idx_bert``: that [MASK]'s position, or 0 if absent.

    The first ``n_dev_samples`` samples (file order) form the dev set; the
    rest are shuffled into the train set. Written files, using
    ``output_files_name_prefix``:
        '-dev-mentions.txt', '-dev-sents.txt'  (JSON objects)
        '-dev.pkl', '-train.pkl'               (pickled sample lists)
    Randomness is seeded with ``config.RANDOM_SEED``.

    NOTE(review): BERT inputs are not truncated; long sentences may exceed
    the model's maximum sequence length.
    """
    np.random.seed(config.RANDOM_SEED)
    print('output file destination: {}'.format(output_files_name_prefix))

    tokenizer = None
    if do_bert:
        tokenizer = BertTokenizer.from_pretrained('bert-base-cased',
                                                  do_lower_case=False)

    core_wids = None
    if core_title_wid_file is not None:
        df = datautils.load_csv(core_title_wid_file)
        core_wids = {wid for _, wid in df.itertuples(False, None)}

    print('loading word vec...')
    token_vocab, _token_vecs = datautils.load_pickle_data(word_vecs_pkl)
    token_id_dict = {t: i for i, t in enumerate(token_vocab)}
    unknown_token_id = token_id_dict[config.TOKEN_UNK]

    all_samples = list()
    mention_id = 0
    # `with` closes both files even if parsing fails part-way through.
    with open(typed_mentions_file, encoding='utf-8') as f_mention, \
            open(sents_file, encoding='utf-8') as f_sent:
        cur_sent = json.loads(next(f_sent))
        for i, line in enumerate(f_mention):
            if (i + 1) % 100000 == 0:
                print(i + 1)

            # Draw for every line (before filtering) so the RNG stream only
            # depends on the line count.
            v = np.random.uniform()
            if v > sample_rate:
                continue

            (wid, mention_str, sent_id, pos_beg, pos_end, target_wid,
             type_ids) = datautils.parse_typed_mention_file_line(line)
            if core_wids is not None and target_wid not in core_wids:
                continue

            mention_str = mention_str.replace('-LRB-', '(').replace('-RRB-', ')')
            while not (cur_sent['wid'] == wid and cur_sent['sent_id'] == sent_id):
                cur_sent = json.loads(next(f_sent))
            sent_tokens = cur_sent['tokens'].split(' ')
            sent_token_ids = [
                token_id_dict.get(token, unknown_token_id) for token in sent_tokens
            ]

            sent_token_bert_ids = None
            if not do_bert:
                sample = (mention_id, mention_str, pos_beg, pos_end, target_wid,
                          type_ids, sent_token_ids)
            else:
                # Replace the mention span with one [MASK] token.
                masked_tokens = (sent_tokens[:pos_beg] + ['[MASK]'] +
                                 sent_tokens[pos_end:])
                tokens = ['[CLS]'] + tokenizer.tokenize(' '.join(masked_tokens))
                # list.index avoids a nested loop that reused (and clobbered)
                # the outer line counter `i`.
                mention_token_idx_bert = (tokens.index('[MASK]')
                                          if '[MASK]' in tokens else 0)
                tokens.append('[SEP]')
                sent_token_bert_ids = tokenizer.convert_tokens_to_ids(tokens)

                sample = (mention_id, mention_str, pos_beg, pos_end, target_wid,
                          type_ids, sent_token_ids, sent_token_bert_ids,
                          mention_token_idx_bert)

            mention_id += 1
            all_samples.append(sample)
            if (i + 1) % 100000 == 0:
                print(i + 1, mention_str)
                print(sent_token_ids)
                print()
                # BERT ids only exist in BERT mode (previously a NameError).
                if do_bert:
                    print(sent_token_bert_ids)

    dev_samples = all_samples[:n_dev_samples]
    train_samples = all_samples[n_dev_samples:]

    print('shuffling ...', end=' ', flush=True)
    rand_perm = np.random.permutation(len(train_samples))
    train_samples = [train_samples[idx] for idx in rand_perm]
    print('done')

    dev_mentions, dev_sents = list(), list()
    for i, sample in enumerate(dev_samples):
        # The first 7 fields are shared by the plain and BERT sample layouts.
        (mention_id, mention_str, pos_beg, pos_end, _target_wid, _type_ids,
         sent_token_ids) = sample[:7]
        mention = {
            'mention_id': mention_id,
            'span': [pos_beg, pos_end],
            'str': mention_str,
            'sent_id': i
        }
        sent = {
            'sent_id': i,
            'text': ' '.join([token_vocab[token_id] for token_id in sent_token_ids]),
            'afet-senid': 0,
            'file_id': 'null'
        }
        dev_mentions.append(mention)
        dev_sents.append(sent)
    datautils.save_json_objs(dev_mentions,
                             output_files_name_prefix + '-dev-mentions.txt')
    datautils.save_json_objs(dev_sents,
                             output_files_name_prefix + '-dev-sents.txt')
    print('saving pickle data...')
    datautils.save_pickle_data(dev_samples,
                               output_files_name_prefix + '-dev.pkl')
    datautils.save_pickle_data(train_samples,
                               output_files_name_prefix + '-train.pkl')