Example #1
def check_qz_scores():
    # also relies on the project-local config and datautils modules
    import os
    import json
    import numpy as np

    qemb_file = os.path.join(config.DATA_DIR, 'realm_output/uf_test_qembs.pkl')
    results_file = os.path.join(config.DATA_DIR, 'realm_output/uf_wia_results_200.txt')
    wia_block_emb_file = os.path.join(config.DATA_DIR, 'ultrafine/zoutput/webisa_full_uffilter_emb.pkl')

    bids_list = list()
    f = open(results_file, encoding='utf-8')
    for i, line in enumerate(f):
        x = json.loads(line)
        # print(x)
        bids = x['block_ids']
        bids_list.append(bids)
    f.close()

    qembs = datautils.load_pickle_data(qemb_file)
    block_embs = datautils.load_pickle_data(wia_block_emb_file)
    print(qembs.shape)
    print(block_embs.shape)
    scores = list()
    qemb = qembs[0]
    for i, block_emb in enumerate(block_embs):
        scores.append(np.sum(qemb * block_emb))
        if i % 100000 == 0:
            print(i)
    bids = bids_list[0]
    scores = np.array(scores)
    print(scores[:100])
    print(bids)
    print([scores[bid] for bid in bids])
    idxs = np.argsort(-scores)[:10]
    print(idxs)
    print(scores[idxs])
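A minimal sketch of a vectorized alternative to the scoring loop above, assuming block_embs is loaded as an (n_blocks, dim) NumPy array and qemb as a (dim,) vector:

import numpy as np

def score_blocks(qemb, block_embs):
    # dot product of the query embedding with every block embedding at once
    scores = block_embs @ qemb
    top_idxs = np.argsort(-scores)[:10]  # indices of the ten highest scores
    return scores, top_idxs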
Example #2
def slim_pickles():
    """
    format of "samples":
    0: mention_id
    1: mstr_token_seqs, list of int : token of mention, using hldai's tokenization method
    2: context_token, list of int : token of context, used [MASK] to substitute the mention token.
        tokenized with bert's tokenization method
    3: mention_token_idx, list of int : index of the position of the mention token, used in fet_bert model
        to retrieve context representation
    4: labels, list of int : not full label, needs to call utils.get_full_types() to obtain full label
    """
    job = 'train'
    # job = 'dev'
    data_pkl = f"/data/cleeag/fetel/Wiki/enwiki20151002anchor-fetwiki-0_1-bert-{job}.pkl"
    # data_pkl = f"/data/cleeag/fetel/Wiki/enwiki20151002anchor-fetwiki-0_1-{job}.pkl"
    print('loading training data {} ...'.format(data_pkl), end=' ', flush=True)
    samples = datautils.load_pickle_data(data_pkl)
    print('done', flush=True)
    out_list = []
    for sample in tqdm(samples):
        new_s = [
            sample[0], sample[6][sample[2]:sample[3]], sample[7], sample[8],
            sample[5]
        ]
        # new_s = (sample[0], sample[6][sample[2]:sample[3]], sample[6], sample[2], sample[5])
        out_list.append(new_s)

    output_file = f"/data/cleeag/fetel/Wiki/enwiki20151002anchor-fetwiki-0_1-bert-{job}-slim.pkl"
    # output_file = f"/data/cleeag/fetel/Wiki/enwiki20151002anchor-fetwiki-0_1-{job}-slim.pkl"
    with open(output_file, 'wb') as f:
        pkl.dump(out_list, f)
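A minimal usage sketch for reading the slim pickle back, unpacking fields in the order documented in the docstring above (the path is the one slim_pickles writes for job = 'train'):

import pickle as pkl

with open('/data/cleeag/fetel/Wiki/enwiki20151002anchor-fetwiki-0_1-bert-train-slim.pkl', 'rb') as f:
    slim_samples = pkl.load(f)
# field order per the docstring of slim_pickles
mention_id, mstr_token_seq, context_tokens, mention_token_idx, labels = slim_samples[0]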
Example #3
def save_ragged_tensors_dataset_tmp():
    blocks_list = datautils.load_pickle_data(
        os.path.join(config.DATA_DIR, 'realm_data/blocks_tok_id_seqs.pkl'))
    print('blocks list loaded', len(blocks_list))

    # blocks_list = [[3, 4], [5, 3, 8], [8]]
    # for i, v in enumerate(blocks_list):
    #     print(v)
    #     if i > 3:
    #         break
    # blocks_list = blocks_list[:10000]

    def data_gen():
        yield tf.ragged.constant(blocks_list, dtype=tf.int32)
        # for i, vals in enumerate(blocks_list):
        #     if i % 10000 == 0:
        #         print(i)
        #     yield tf.ragged.constant([vals], dtype=tf.int32)

    dataset = tf.data.Dataset.from_generator(
        data_gen,
        output_signature=tf.RaggedTensorSpec(ragged_rank=1, dtype=tf.int32))

    # for v in dataset:
    #     print(v)
    print('saving')
    tf.data.experimental.save(
        dataset,
        '/data/hldai/data/tmp/blocks_tok_id_seqs_l128.tfdata',
        shard_func=lambda x: np.int64(0))
    print('saved')
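Presumably the saved dataset is read back with the matching element spec; a sketch assuming the same TF 2.x experimental API used for saving:

import tensorflow as tf

dataset = tf.data.experimental.load(
    '/data/hldai/data/tmp/blocks_tok_id_seqs_l128.tfdata',
    element_spec=tf.RaggedTensorSpec(ragged_rank=1, dtype=tf.int32))
for ragged in dataset:
    # a single ragged tensor holding all token id sequences, as yielded by data_gen
    print(ragged.shape)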
Example #4
def load_blocks_from_pkl():
    blocks_list = datautils.load_pickle_data(
        os.path.join(config.DATA_DIR, 'realm_data/blocks_tok_id_seqs.pkl'))
    blocks_list = blocks_list[:10000000]
    print('blocks list loaded', len(blocks_list))
    blocks = tf.ragged.constant(blocks_list, dtype=tf.int32)
    print('blocks list to ragged')
    return blocks
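datautils.load_pickle_data is a project-local helper that is not shown on this page; a plausible minimal implementation (an assumption, not the project's actual code) is a thin pickle wrapper:

import pickle

def load_pickle_data(filename):
    # load a single pickled object from file
    with open(filename, 'rb') as f:
        return pickle.load(f)

def save_pickle_data(obj, filename):
    # pickle a single object to file
    with open(filename, 'wb') as f:
        pickle.dump(obj, f, protocol=pickle.HIGHEST_PROTOCOL)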
Example #5
def save_doc_tok_id_seqs_singleall():
    n_parts = 5

    output_path = os.path.join(config.DATA_DIR, 'tmp/blocks_tok_id_seqs_l128_sa4m.tfd')
    blocks_list = datautils.load_pickle_data(os.path.join(config.DATA_DIR, 'realm_data/blocks_tok_id_seqs.pkl'))
    n_blocks = len(blocks_list)
    print('blocks list loaded', n_blocks)
    blocks_list = blocks_list[:4000000]
    save_ragged_vals_to_dataset(blocks_list, output_path, concat_all=False)
Example #6
def save_ragged_tensors_dataset():
    blocks_list = datautils.load_pickle_data(
        os.path.join(config.DATA_DIR, 'realm_data/blocks_tok_id_seqs.pkl'))
    print('blocks list loaded', len(blocks_list))
    # # blocks_list = [[3, 4], [5, 3, 8], [8]]
    # for i, v in enumerate(blocks_list):
    #     print(v)
    #     if i > 3:
    #         break
    blocks_list = blocks_list[:10]
    save_ragged_vals_to_dataset(blocks_list, '/data/hldai/data/tmp/tmp.tfdata')
Example #7
def filter_wiki_for_training(required_title_wid_file, word_freq_vec_pkl,
                             articles_file, anchor_cxt_file,
                             output_articles_file, output_anchor_cxt_file):
    df = datautils.load_csv(required_title_wid_file)
    wids = {wid for _, wid in df.itertuples(False, None)}

    word_vocab, freqs, word_vecs = datautils.load_pickle_data(
        word_freq_vec_pkl)
    word_to_id_dict = {w: i for i, w in enumerate(word_vocab)}

    filter_wiki_articles(wids, word_to_id_dict, articles_file,
                         output_articles_file)
    filter_anchor_cxt(wids, word_to_id_dict, anchor_cxt_file,
                      output_anchor_cxt_file)
Example #8
    def __init__(self, type_vocab_file, word_vecs_file):
        self.type_vocab, self.type_id_dict = datautils.load_type_vocab(type_vocab_file)
        self.parent_type_ids_dict = utils.get_parent_type_ids_dict(self.type_id_dict)
        self.n_types = len(self.type_vocab)

        print('loading {} ...'.format(word_vecs_file), end=' ', flush=True)
        self.token_vocab, self.token_vecs = datautils.load_pickle_data(word_vecs_file)
        self.token_id_dict = {t: i for i, t in enumerate(self.token_vocab)}
        print('done', flush=True)
        self.zero_pad_token_id = self.token_id_dict[config.TOKEN_ZERO_PAD]
        self.mention_token_id = self.token_id_dict[config.TOKEN_MENTION]
        self.unknown_token_id = self.token_id_dict[config.TOKEN_UNK]
        self.embedding_layer = nn.Embedding.from_pretrained(torch.from_numpy(self.token_vecs))
        self.embedding_layer.padding_idx = self.zero_pad_token_id
        self.embedding_layer.weight.requires_grad = False
        self.embedding_layer.share_memory()
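A toy usage sketch of the frozen embedding layer built in this __init__; the vocabulary size, embedding dimension, and pad id below are made-up stand-ins:

import torch

token_vecs = torch.randn(10, 4)  # stands in for the loaded (vocab_size, dim) token_vecs
embedding_layer = torch.nn.Embedding.from_pretrained(token_vecs)
embedding_layer.weight.requires_grad = False
token_ids = torch.tensor([[5, 2, 0], [7, 3, 0]])  # 0 plays the zero-pad token id
vecs = embedding_layer(token_ids)  # -> shape (2, 3, 4), no gradients flow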
Example #9
def gen_candidates(wcwy_pem_file, title_wid_file):
    print(f'loading {title_wid_file} ...')
    wid_title_dict = {
        wid: title
        for title, wid in datautils.load_csv(title_wid_file).itertuples(
            False, None)
    }
    print(f'loading {wcwy_pem_file} ...')
    mstr_pem_dict = datautils.load_pickle_data(wcwy_pem_file)

    wid_probs = mstr_pem_dict['Japan']

    # wid_prob_tups = [(wid, prob) for wid, prob in wid_probs.items()]
    # wid_prob_tups.sort(key=lambda x: -x[1])
    print(len(wid_probs))
    for wid, prob in wid_probs[:10]:
        print(wid, wid_title_dict.get(wid), prob)
Example #10
    def __init__(self, type_vocab_file, word_vecs_file):
        self.type_vocab, self.type_id_dict = datautils.load_vocab_file(
            type_vocab_file)
        self.parent_type_ids_dict = fetutils.get_parent_type_ids_dict(
            self.type_id_dict)
        self.n_types = len(self.type_vocab)

        if word_vecs_file is not None:
            import config

            print('loading {} ...'.format(word_vecs_file), end=' ', flush=True)
            self.token_vocab, self.token_vecs = datautils.load_pickle_data(
                word_vecs_file)
            self.token_id_dict = {t: i for i, t in enumerate(self.token_vocab)}
            print('done', flush=True)
            self.zero_pad_token_id = self.token_id_dict[config.TOKEN_ZERO_PAD]
            self.mention_token_id = self.token_id_dict[config.TOKEN_MENTION]
            self.unknown_token_id = self.token_id_dict[config.TOKEN_UNK]
Example #11
def init_pre_load_data(block_emb_file, block_labels_file, type_id_dict):
    # num_block_records = 13353718
    num_block_records = 2000000
    n_block_rec_parts = [2670743, 5341486, 8012229, 10682972, 13353718]
    var_name = "block_emb"
    # checkpoint_path = os.path.join(retriever_module_path, "encoded", "encoded.ckpt")
    # np_db = tf.train.load_checkpoint(checkpoint_path).get_tensor(var_name)[:4000000]
    # block_emb_file = os.path.join(config.DATA_DIR, 'realm_data/realm_blocks/block_emb_2m.pkl')
    # block_emb_file = os.path.join(config.DATA_DIR, 'ultrafine/rlm_fet/enwiki-20151002-type-sents-2m-emb.pkl')
    # block_records_path = os.path.join(data_dir, 'realm_data/blocks.tfr')
    pre_load_data['np_db'] = datautils.load_pickle_data(block_emb_file)
    pre_load_data['labels'] = None
    if block_labels_file is not None:
        z_labels = list()
        with open(block_labels_file, encoding='utf-8') as f:
            for i, line in enumerate(f):
                label = line.strip()
                tid = type_id_dict.get(label, 0)
                z_labels.append(tid)
        pre_load_data['labels'] = np.array(z_labels, np.int32)
Example #12
def save_doc_tok_id_seqs_to_parts():
    n_parts = 5

    output_path_prefix = os.path.join(config.DATA_DIR, 'realm_data/blocks_tok_id_seqs_l128/blocks_tok_id_seqs_l128_p')
    blocks_list = datautils.load_pickle_data(os.path.join(config.DATA_DIR, 'realm_data/blocks_tok_id_seqs.pkl'))
    n_blocks = len(blocks_list)
    print('blocks list loaded', n_blocks)

    # # blocks_list = [[3, 4], [5, 3, 8], [8]]
    # for i, v in enumerate(blocks_list):
    #     print(v)
    #     if i > 3:
    #         break
    # blocks_list = blocks_list[:10000]
    n_blocks_per_part = n_blocks // n_parts
    for i in range(n_parts):
        pbeg = i * n_blocks_per_part
        pend = (i + 1) * n_blocks_per_part if i < n_parts - 1 else n_blocks
        print(i, pbeg, pend)
        output_path = '{}{}.tfd'.format(output_path_prefix, i)
        save_ragged_vals_to_dataset(blocks_list[pbeg:pend], output_path, concat_all=True)
Example #13
def replace_with_provided_ent_embed(provided_ent_vocab_file,
                                    provided_ent_embed_file,
                                    mrel_title_to_wid_file,
                                    self_trained_ent_embed_pkl, output_file):
    df = datautils.load_csv(mrel_title_to_wid_file)
    mrel_title_wid_dict = {
        title: wid
        for title, wid in df.itertuples(False, None)
    }
    mrel_entity_names = __read_mrel_ent_vocab_file(provided_ent_vocab_file)
    mrel_entity_vecs = np.load(provided_ent_embed_file)
    assert mrel_entity_vecs.shape[0] == len(mrel_entity_names)

    mrel_wids = [
        mrel_title_wid_dict.get(name, -1) for name in mrel_entity_names
    ]
    wid_vocab, entity_vecs = datautils.load_pickle_data(
        self_trained_ent_embed_pkl)
    wid_eid_dict = {wid: i for i, wid in enumerate(wid_vocab)}
    extra_entity_vecs = list()
    for wid, mrel_entity_vec in zip(mrel_wids, mrel_entity_vecs):
        if wid > -1:
            eid = wid_eid_dict.get(wid, None)
            if eid is not None:
                entity_vecs[eid] = mrel_entity_vec
            else:
                extra_entity_vecs.append(mrel_entity_vec)
                wid_vocab.append(wid)
    new_entity_vecs = np.zeros(
        (entity_vecs.shape[0] + len(extra_entity_vecs), entity_vecs.shape[1]),
        np.float32)
    for i, vec in enumerate(entity_vecs):
        new_entity_vecs[i] = vec
    for i, vec in enumerate(extra_entity_vecs):
        new_entity_vecs[i + entity_vecs.shape[0]] = vec
    datautils.save_pickle_data((wid_vocab, new_entity_vecs), output_file)
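The two copy loops that build new_entity_vecs amount to stacking the original matrix on top of the extra vectors; a sketch of the equivalent construction (the empty case is guarded because np.vstack rejects a zero-row block built from an empty list):

import numpy as np

def merge_entity_vecs(entity_vecs, extra_entity_vecs):
    if not extra_entity_vecs:
        return np.array(entity_vecs, np.float32)
    return np.vstack([entity_vecs, np.array(extra_entity_vecs, np.float32)])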
Example #14
def gen_training_data_from_wiki(typed_mentions_file, sents_file, word_vecs_pkl, sample_rate,
                                n_dev_samples, output_files_name_prefix, core_title_wid_file=None):
    np.random.seed(config.RANDOM_SEED)

    core_wids = None
    if core_title_wid_file is not None:
        df = datautils.load_csv(core_title_wid_file)
        core_wids = {wid for _, wid in df.itertuples(False, None)}

    token_vocab, token_vecs = datautils.load_pickle_data(word_vecs_pkl)
    token_id_dict = {t: i for i, t in enumerate(token_vocab)}
    unknown_token_id = token_id_dict[config.TOKEN_UNK]

    f_mention = open(typed_mentions_file, encoding='utf-8')
    f_sent = open(sents_file, encoding='utf-8')
    all_samples = list()
    cur_sent = json.loads(next(f_sent))
    mention_id = 0
    for i, line in enumerate(f_mention):
        if (i + 1) % 1000000 == 0:
            print(i + 1)
        # if i > 400000:
        #     break

        v = np.random.uniform()
        if v > sample_rate:
            continue

        (wid, mention_str, sent_id, pos_beg, pos_end, target_wid, type_ids
         ) = datautils.parse_typed_mention_file_line(line)
        if core_wids is not None and target_wid not in core_wids:
            continue

        mention_str = mention_str.replace('-LRB-', '(').replace('-RRB-', ')')
        while not (cur_sent['wid'] == wid and cur_sent['sent_id'] == sent_id):
            cur_sent = json.loads(next(f_sent))
        sent_tokens = cur_sent['tokens'].split(' ')
        sent_token_ids = [token_id_dict.get(token, unknown_token_id) for token in sent_tokens]

        sample = (mention_id, mention_str, pos_beg, pos_end, target_wid, type_ids, sent_token_ids)
        mention_id += 1
        all_samples.append(sample)
        # print(i, mention_str)
        # print(sent_token_ids)
        # print()
    f_mention.close()
    f_sent.close()

    dev_samples = all_samples[:n_dev_samples]
    train_samples = all_samples[n_dev_samples:]

    print('shuffling ...', end=' ', flush=True)
    rand_perm = np.random.permutation(len(train_samples))
    train_samples_shuffled = list()
    for idx in rand_perm:
        train_samples_shuffled.append(train_samples[idx])
    train_samples = train_samples_shuffled
    print('done')

    dev_mentions, dev_sents = list(), list()
    for i, sample in enumerate(dev_samples):
        mention_id, mention_str, pos_beg, pos_end, target_wid, type_ids, sent_token_ids = sample
        mention = {'mention_id': mention_id, 'span': [pos_beg, pos_end], 'str': mention_str, 'sent_id': i}
        sent = {'sent_id': i, 'text': ' '.join([token_vocab[token_id] for token_id in sent_token_ids]),
                'afet-senid': 0, 'file_id': 'null'}
        dev_mentions.append(mention)
        dev_sents.append(sent)
    datautils.save_json_objs(dev_mentions, output_files_name_prefix + '-dev-mentions.txt')
    datautils.save_json_objs(dev_sents, output_files_name_prefix + '-dev-sents.txt')

    datautils.save_pickle_data(dev_samples, output_files_name_prefix + '-dev.pkl')
    datautils.save_pickle_data(train_samples, output_files_name_prefix + '-train.pkl')
Example #15
from utils import exp_utils, datautils, utils
import pickle
import config
from tqdm import tqdm

if __name__ == "__main__":
    task = input('input task: ')
    if task == 'prep':
        # job = 'train'
        job = 'dev'
        # data_pkl = f"/data/cleeag/fetel/Wiki/enwiki20151002anchor-fetwiki-0_1-bert-{job}.pkl"
        data_pkl = f"/data/cleeag/fetel/Wiki/enwiki20151002anchor-fetwiki-0_1-{job}.pkl"
        print('loading training data {} ...'.format(data_pkl), end=' ', flush=True)
        samples = datautils.load_pickle_data(data_pkl)
        print('done', flush=True)
        out_list = []
        for sample in tqdm(samples):
            # new_s = [sample[0], sample[6][sample[2]:sample[3]], sample[7], sample[8], sample[5]]
            new_s = (sample[0], sample[6][sample[2]:sample[3]], sample[6], sample[2], sample[5])
            out_list.append(new_s)

        # output_file = f"/data/cleeag/fetel/Wiki/enwiki20151002anchor-fetwiki-0_1-bert-{job}-slim.pkl"
        output_file = f"/data/cleeag/fetel/Wiki/enwiki20151002anchor-fetwiki-0_1-{job}-slim.pkl"
        with open(output_file, 'wb') as f:
            pickle.dump(out_list, f)

    elif task == 'r':
        train_data_pkl = "/data/cleeag/fetel/Wiki/enwiki20151002anchor-fetwiki-0_1-bert-dev.pkl"
        samples = datautils.load_pickle_data(train_data_pkl)
        print(samples[0])

    elif task == 'ib':
Example #16
def train_model(test=False):
    if config.use_gpu and torch.cuda.device_count() > 0:
        device = torch.device('cuda:0')
        device_name = torch.cuda.get_device_name(device)
    else:
        # fall back to CPU; torch.cuda.get_device_name would fail here
        device = torch.device('cpu')
        device_name = 'cpu'

    logging.info(f'running on device: {device_name}')
    dataset = 'figer'
    datafiles = config.FIGER_FILES
    word_vecs_file = config.WIKI_FETEL_WORDVEC_FILE
    save_model_file = os.path.join(config.DATA_DIR, 'models', 'test')

    if config.use_bert:
        data_prefix = datafiles['anchor-train-data-prefix-bert']
    else:
        data_prefix = datafiles['anchor-train-data-prefix']
    # dev_data_pkl = data_prefix + '-dev.pkl'
    # train_data_pkl = data_prefix + '-train.pkl'
    dev_data_pkl = data_prefix + '-dev-slim.pkl'
    if test:
        train_data_pkl = data_prefix + '-dev-slim.pkl'
    else:
        train_data_pkl = data_prefix + '-train-slim.pkl'
    test_results_file = os.path.join(
        config.DATA_DIR, 'Wiki/fetel-deep-results-{}.txt'.format(dataset))

    gres = exp_utils.GlobalRes(datafiles['type-vocab'], word_vecs_file)
    logging.info('dataset={}'.format(dataset))

    logging.info('use_bert = {}, use_lstm = {}, use_mlp={}'.format(
        config.use_bert, config.use_lstm, config.use_mlp))
    logging.info(
        'type_embed_dim={} cxt_lstm_hidden_dim={} pmlp_hdim={}'.format(
            config.type_embed_dim, config.lstm_hidden_dim,
            config.pred_mlp_hdim))
    logging.info('rand_per={} per_pen={}'.format(config.rand_per,
                                                 config.per_penalty))

    print('loading training data {} ...'.format(train_data_pkl),
          end=' ',
          flush=True)
    training_samples = datautils.load_pickle_data(train_data_pkl)
    print('done', flush=True)
    # training_samples = exp_utils.anchor_samples_to_model_samples_bert(config, samples, gres.parent_type_ids_dict)

    print('loading dev data {} ...'.format(dev_data_pkl), end=' ', flush=True)
    dev_samples = datautils.load_pickle_data(dev_data_pkl)
    print('done', flush=True)
    dev_true_labels_dict = {
        s[0]: [
            gres.type_vocab[l]
            for l in utils.get_full_type_ids(s[4], gres.parent_type_ids_dict)
        ]
        for s in dev_samples
    }

    test_samples = exp_utils.model_samples_from_json(
        config, gres.token_id_dict, gres.unknown_token_id, gres.type_id_dict,
        datafiles['fetel-test-mentions'], datafiles['fetel-test-sents'])
    test_true_labels_dict = {
        s[0]: [gres.type_vocab[l] for l in s[4]]
        for s in test_samples
    }

    logging.info('building model...')
    model = fet_model(config, device, gres.type_vocab, gres.type_id_dict,
                      gres.embedding_layer)
    model.to(device)

    logging.info('{}'.format(model.__class__.__name__))
    logging.info('training batch size: {}'.format(config.batch_size))

    # get person penalty vector
    person_type_id = gres.type_id_dict.get('/person')
    l2_person_type_ids = None
    person_loss_vec = None
    if person_type_id is not None:
        l2_person_type_ids = exp_utils.get_l2_person_type_ids(gres.type_vocab)
        person_loss_vec = np.ones(gres.n_types, np.float32)
        for tid in l2_person_type_ids:
            person_loss_vec[tid] = config.per_penalty
        person_loss_vec = torch.tensor(person_loss_vec,
                                       dtype=torch.float32,
                                       device=device)

    n_batches = (len(training_samples) + config.batch_size -
                 1) // config.batch_size
    n_steps = config.n_iter * n_batches
    if config.use_lstm:
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=config.learning_rate)
    elif config.use_bert:
        from pytorch_pretrained_bert.optimization import BertAdam
        param_optimizer = list(model.named_parameters())
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in param_optimizer
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            0.01
        }, {
            'params':
            [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
            'weight_decay':
            0.0
        }]
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=config.learning_rate,
                             warmup=config.bert_adam_warmup,
                             t_total=n_steps)
        # optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)

    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=n_batches,
                                                gamma=config.lr_gamma)
    losses = list()
    best_dev_acc = -1

    # start training
    logging.info(
        '{} steps, {} steps per iter, learning rate={}, lr_decay={}, start training ...'
        .format(config.n_iter * n_batches, n_batches, config.learning_rate,
                config.lr_gamma))
    step = 0
    while True:
        if step == n_steps:
            break

        batch_idx = step % n_batches
        batch_beg, batch_end = batch_idx * config.batch_size, min(
            (batch_idx + 1) * config.batch_size, len(training_samples))
        context_token_list, mention_token_idxs, mstr_token_seqs, type_vecs \
            = exp_utils.samples_to_tensor(
            config, device, gres, training_samples[batch_beg:batch_end],
            person_type_id, l2_person_type_ids)
        model.train()
        logits = model(context_token_list, mention_token_idxs, mstr_token_seqs)
        loss = model.get_loss(type_vecs,
                              logits,
                              person_loss_vec=person_loss_vec)
        optimizer.zero_grad()
        loss.backward()
        if config.use_lstm:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 10.0,
                                           float('inf'))
        optimizer.step()
        # step the LR scheduler after the optimizer step (PyTorch >= 1.1 order)
        scheduler.step()
        losses.append(loss.data.cpu().numpy())
        # logging.info('step={}/{} accumulated loss = {:.4f}, loss = {:.4f}'.format(step, n_steps, sum(losses), loss))

        step += 1

        eval_cycle = 1 if config.test else 100
        if step % eval_cycle == 0:
            l_v, acc_v, pacc_v, _, _, dev_results = \
                model_utils.eval_fetel(config, gres, model, dev_samples, dev_true_labels_dict)

            _, acc_t, pacc_t, maf1, mif1, test_results = \
                model_utils.eval_fetel(config, gres, model, test_samples, test_true_labels_dict)

            best_tag = '*' if acc_v > best_dev_acc else ''
            # logging.info(
            #     'step={}/{} l={:.4f} l_v={:.4f} acc_v={:.4f} paccv={:.4f}{}\n'.format(
            #         step, n_steps, loss, l_v, acc_v, pacc_v, best_tag))
            logging.info('step={}/{}, learning rate={}, losses={:.4f}'.format(
                step, n_steps, optimizer.param_groups[0]['lr'], sum(losses)))
            logging.info(
                'evaluation result: '
                'l_v={:.4f} acc_v={:.4f} paccv={:.4f} acc_t={:.4f} macro_f1={:.4f} micro_f1={:.4f}{}\n'
                .format(l_v, acc_v, pacc_v, acc_t, maf1, mif1, best_tag))
            if acc_v > best_dev_acc and save_model_file:
                # torch.save(model.state_dict(), save_model_file)
                logging.info('model saved to {}'.format(save_model_file))

            if test_results_file is not None and acc_v > best_dev_acc:
                datautils.save_json_objs(dev_results, test_results_file)
                logging.info('dev results saved to {}'.format(test_results_file))

            if acc_v > best_dev_acc:
                best_dev_acc = acc_v
            losses = list()
            # if config.test:
            #     input('proceed? ')

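For reference, a self-contained toy loop showing the update order used above (zero_grad, backward, optimizer step, then scheduler step, per PyTorch >= 1.1):

import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.5)
x, y = torch.randn(8, 4), torch.randn(8, 2)
for _ in range(3):
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()
    scheduler.step()  # LR schedule advances after the parameter update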
Example #17
def gen_training_data_from_wiki(typed_mentions_file,
                                sents_file,
                                word_vecs_pkl,
                                sample_rate,
                                n_dev_samples,
                                output_files_name_prefix,
                                core_title_wid_file=None,
                                do_bert=False):
    np.random.seed(config.RANDOM_SEED)
    print('output file destination: {}'.format(output_files_name_prefix))

    if do_bert:
        tokenizer = BertTokenizer.from_pretrained('bert-base-cased',
                                                  do_lower_case=False)
        # tokenizer.add_special_tokens({'mention': '[MASK]'})

    core_wids = None
    if core_title_wid_file is not None:
        df = datautils.load_csv(core_title_wid_file)
        core_wids = {wid for _, wid in df.itertuples(False, None)}

    print('loading word vec...')
    token_vocab, token_vecs = datautils.load_pickle_data(word_vecs_pkl)
    token_id_dict = {t: i for i, t in enumerate(token_vocab)}
    unknown_token_id = token_id_dict[config.TOKEN_UNK]

    f_mention = open(typed_mentions_file, encoding='utf-8')
    f_sent = open(sents_file, encoding='utf-8')
    all_samples = list()
    cur_sent = json.loads(next(f_sent))
    mention_id = 0
    for i, line in enumerate(f_mention):
        if (i + 1) % 100000 == 0:
            print(i + 1)
        # if i > 40000:
        #     break

        v = np.random.uniform()
        if v > sample_rate:
            continue

        (wid, mention_str, sent_id, pos_beg, pos_end, target_wid,
         type_ids) = datautils.parse_typed_mention_file_line(line)
        if core_wids is not None and target_wid not in core_wids:
            continue

        mention_str = mention_str.replace('-LRB-', '(').replace('-RRB-', ')')
        while not (cur_sent['wid'] == wid and cur_sent['sent_id'] == sent_id):
            cur_sent = json.loads(next(f_sent))
        sent_tokens = cur_sent['tokens'].split(' ')
        sent_token_ids = [
            token_id_dict.get(token, unknown_token_id) for token in sent_tokens
        ]

        if not do_bert:
            sample = (mention_id, mention_str, pos_beg, pos_end, target_wid,
                      type_ids, sent_token_ids)
        else:
            sent_tokens = sent_tokens[:pos_beg] + ['[MASK]'] + sent_tokens[pos_end:]
            full_sent = ' '.join(sent_tokens)
            tokens = ["[CLS]"]
            t = tokenizer.tokenize(full_sent)
            tokens.extend(t)
            mention_token_idx_bert = 0
            # use a separate loop variable so the outer line counter i is not clobbered
            for j, x in enumerate(tokens):
                if x == '[MASK]':
                    mention_token_idx_bert = j
                    break
            tokens.append("[SEP]")
            sent_token_bert_ids = tokenizer.convert_tokens_to_ids(tokens)

            sample = (mention_id, mention_str, pos_beg, pos_end, target_wid,
                      type_ids, sent_token_ids, sent_token_bert_ids,
                      mention_token_idx_bert)

        mention_id += 1
        all_samples.append(sample)
        # print(i, mention_str)
        # print(sent_token_ids)
        # print()
        if (i + 1) % 100000 == 0:
            print(i + 1, mention_str)
            print(sent_token_ids)
            print()
            if do_bert:
                print(sent_token_bert_ids)

    f_mention.close()
    f_sent.close()

    dev_samples = all_samples[:n_dev_samples]
    train_samples = all_samples[n_dev_samples:]

    print('shuffling ...', end=' ', flush=True)
    rand_perm = np.random.permutation(len(train_samples))
    train_samples_shuffled = list()
    for idx in rand_perm:
        train_samples_shuffled.append(train_samples[idx])
    train_samples = train_samples_shuffled
    print('done')

    dev_mentions, dev_sents = list(), list()
    for i, sample in enumerate(dev_samples):
        if do_bert:
            mention_id, mention_str, pos_beg, pos_end, target_wid, type_ids, sent_token_ids, \
            sent_token_bert_ids, mention_token_idx_bert = sample
        else:
            mention_id, mention_str, pos_beg, pos_end, target_wid, type_ids, sent_token_ids = sample
        mention = {
            'mention_id': mention_id,
            'span': [pos_beg, pos_end],
            'str': mention_str,
            'sent_id': i
        }
        sent = {
            'sent_id': i,
            'text': ' '.join([token_vocab[token_id] for token_id in sent_token_ids]),
            'afet-senid': 0,
            'file_id': 'null'
        }
        dev_mentions.append(mention)
        dev_sents.append(sent)
    datautils.save_json_objs(dev_mentions,
                             output_files_name_prefix + '-dev-mentions.txt')
    datautils.save_json_objs(dev_sents,
                             output_files_name_prefix + '-dev-sents.txt')
    print('saving pickle data...')
    datautils.save_pickle_data(dev_samples,
                               output_files_name_prefix + '-dev.pkl')
    datautils.save_pickle_data(train_samples,
                               output_files_name_prefix + '-train.pkl')
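A toy illustration of the [MASK] substitution performed in the do_bert branch above, with made-up values: the mention span [pos_beg, pos_end) is cut out and a single mask token takes its place before BERT tokenization.

sent_tokens = ['Obama', 'visited', 'Tokyo', '.']
pos_beg, pos_end = 2, 3
masked = sent_tokens[:pos_beg] + ['[MASK]'] + sent_tokens[pos_end:]
print(masked)  # ['Obama', 'visited', '[MASK]', '.']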
Example #18
def traindeschyp(device,
                 word_freq_vec_pkl,
                 title_wid_file,
                 articles_file,
                 anchor_cxt_file,
                 n_words_per_ent,
                 batch_size,
                 lr,
                 output_file,
                 hyp_ctxt_len=10,
                 n_desc_iter=200,
                 n_neg_words=5,
                 unig_power=0.6):
    word_vocab, freqs, word_vecs = datautils.load_pickle_data(
        word_freq_vec_pkl)
    word_to_id_dict = {w: i for i, w in enumerate(word_vocab)}
    n_words = word_vecs.shape[0]
    title_wid_df = datautils.load_csv(title_wid_file, False)
    wid_vocab = [wid for _, wid in title_wid_df.itertuples(False, None)]
    entity_titles = [
        title for title, _ in title_wid_df.itertuples(False, None)
    ]
    entity_title_word_ids = get_entity_title_word_ids(entity_titles,
                                                      word_to_id_dict)
    wid_idx_dict = {wid: i for i, wid in enumerate(wid_vocab)}
    n_entities = len(wid_vocab)

    logging.info('{} entities, {} words, word_ves: {}'.format(
        n_entities, n_words, word_vecs.shape))
    neg_word_gen = RandNegWordGen(word_vocab, freqs, unig_power)

    init_entity_vecs = get_init_entity_vecs(entity_title_word_ids, word_vecs)
    model = deschypembed.DescHypEmbed(word_vecs, n_entities, init_entity_vecs)
    if device.type == 'cuda':
        model = model.cuda(device.index)
    optimizer = torch.optim.Adagrad(model.parameters(), lr=lr)

    desc_data_producer = DescDataProducer(neg_word_gen, word_to_id_dict,
                                          wid_idx_dict, articles_file,
                                          entity_title_word_ids,
                                          n_words_per_ent, batch_size,
                                          n_desc_iter)
    anchor_data_producer = AnchorDataProducer(neg_word_gen, anchor_cxt_file,
                                              wid_idx_dict,
                                              entity_title_word_ids,
                                              hyp_ctxt_len, n_words_per_ent,
                                              batch_size, 15)
    use_desc_data = True
    num_batches_per_epoch = 8000
    step = 0
    losses = list()
    while True:
        model.train()
        if use_desc_data:
            entity_ids, pos_word_ids_batch = desc_data_producer.get_pos_words_batch()
            if len(entity_ids) == 0:
                use_desc_data = False
                logging.info('start using anchor data ...')
                continue
        else:
            entity_ids, pos_word_ids_batch = anchor_data_producer.get_pos_words_batch()
            # print(pos_word_ids_batch)
            if len(entity_ids) == 0:
                break

        word_ids_batch, pos_word_idxs_batch, ttn = fill_neg_words(
            neg_word_gen, pos_word_ids_batch, n_neg_words)

        cur_batch_size = len(entity_ids)
        word_ids = torch.tensor(word_ids_batch,
                                dtype=torch.long,
                                device=device)
        entity_ids_tt = torch.tensor(entity_ids,
                                     dtype=torch.long,
                                     device=device)
        target_idxs = torch.tensor(pos_word_idxs_batch,
                                   dtype=torch.long,
                                   device=device)
        scores = model(cur_batch_size, word_ids, entity_ids_tt)
        scores = scores.view(-1, n_neg_words)
        target_scores = scores[list(range(cur_batch_size * n_words_per_ent)),
                               target_idxs.view(-1)]
        # target_scores = get_target_scores_with_try()
        loss = torch.mean(F.relu(0.1 - target_scores.view(-1, 1) + scores))
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 10.0, float('inf'))
        optimizer.step()
        losses.append(loss.data.cpu().numpy())
        step += 1
        if step % num_batches_per_epoch == 0:
            logging.info('i={}, loss={}'.format(step, sum(losses)))
            # logging.info('i={}, loss={}, t0={:.4f}, t1={:.4f}, t2={:.4f}, t3={:.4f}, tn={:.4f}'.format(
            #     step, sum(losses), t0, t1, t2, t3, tn))
            losses = list()

            if output_file:
                entity_vecs = F.normalize(model.entity_embed_layer.weight)
                datautils.save_pickle_data(
                    (wid_vocab, entity_vecs.data.cpu().numpy()), output_file)
                logging.info('saved model to {}'.format(output_file))
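A toy illustration of the max-margin objective used above, with made-up scores: every candidate score that comes within the 0.1 margin of its row's positive score contributes to the loss.

import torch
import torch.nn.functional as F

scores = torch.tensor([[0.9, 0.2, 0.85],
                       [0.3, 0.6, 0.1]])  # each row: one positive word and its candidates
target_scores = scores[:, 0]              # the positive word's score sits in column 0 here
loss = torch.mean(F.relu(0.1 - target_scores.view(-1, 1) + scores))
print(loss)  # grows when negative scores crowd the positive's margin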
Example #19
def save_doc_tok_id_seqs_whole():
    n_docs = 40000
    output_path = os.path.join(config.DATA_DIR, 'realm_data/blocks_tok_id_seqs_l128_4k.tfd')
    blocks_list = datautils.load_pickle_data(os.path.join(config.DATA_DIR, 'realm_data/blocks_tok_id_seqs.pkl'))
    blocks_list = blocks_list[:n_docs]
    save_ragged_vals_to_dataset(blocks_list, output_path)
Example #20
    def __init__(self, config):

        self.only_general_types = config.only_general_types
        self.without_general_types = config.without_general_types
        if config.dataset == 'ufet':
            self.ANSWER_NUM_DICT = {
                "open": 10331,
                "onto": 89,
                "wiki": 4600,
                "kb": 130,
                "gen": 9
            }
            self.n_types = self.ANSWER_NUM_DICT[config.dataset_type]
            with open(config.UFET_FILES['ufet_training_type_set'], 'r') as r:
                self.type_id2type_dict = {
                    i: x.strip()
                    for i, x in enumerate(r.readlines()) if i < self.n_types
                }
            self.type2type_id_dict = {
                tp: id
                for id, tp in self.type_id2type_dict.items()
            }
            self.general_type_set = set([
                v for k, v in self.type_id2type_dict.items()
                if k < self.ANSWER_NUM_DICT['gen']
            ])

        elif config.GENERAL_TYPES_MAPPING:
            self.types2general_types_mapping = datautils.load_pickle_data(
                config.GENERAL_TYPES_MAPPING)
            self.general_type_set = set([
                v for k, vs in self.types2general_types_mapping.items()
                for v in vs
            ])
            if self.only_general_types:
                self.n_types = len(self.general_type_set)
                self.type_id2type_dict = {
                    i: x
                    for i, x in enumerate(self.general_type_set)
                }

            else:
                if self.without_general_types:
                    self.types2general_types_mapping = {
                        k: v
                        for k, v in self.types2general_types_mapping.items()
                    }
                    self.type_id2type_dict = {
                        i: x
                        for i, x in enumerate(self.types2general_types_mapping)
                    }
                else:
                    self.type_id2type_dict = {
                        i: x
                        for i, x in enumerate(self.types2general_types_mapping)
                    }
                self.n_types = len(self.type_id2type_dict)

            self.type2type_id_dict = {
                tp: id
                for id, tp in self.type_id2type_dict.items()
            }
            # print(self.type2type_id_dict['其他'])
            self.gen_idxs = [
                k for k, v in self.type_id2type_dict.items()
                if v in self.general_type_set
            ]
            self.fine_idxs = [
                k for k, v in self.type_id2type_dict.items()
                if v not in self.general_type_set
            ]

        tic = time.time()
        WORD_VECS_FILE = f'/data/cleeag/word_embeddings/{config.mention_tokenizer_name}/' \
                         f'{config.mention_tokenizer_name}_tokenizer&vecs.pkl'
        print('loading {} ...'.format(WORD_VECS_FILE), end=' ', flush=True)
        self.token2token_id_dict, self.token_vecs = datautils.load_pickle_data(
            WORD_VECS_FILE)
        print(f'done, {time.time() - tic :.2f} secs taken.', flush=True)
        self.zero_pad_token_id = self.token2token_id_dict[
            config.TOKEN_ZERO_PAD]
        self.mention_token_id = self.token2token_id_dict[config.TOKEN_MENTION]
        self.unknown_token_id = self.token2token_id_dict[config.TOKEN_UNK]
        self.embedding_layer = nn.Embedding.from_pretrained(
            torch.from_numpy(self.token_vecs))
        self.embedding_layer.padding_idx = self.token2token_id_dict[
            config.TOKEN_ZERO_PAD]
        self.embedding_layer.weight.requires_grad = False
        self.embedding_layer.share_memory()

        self.tokenizer = BertTokenizer.from_pretrained('bert-base-chinese',
                                                       do_lower_case=False)
Example #21
def check_block_emb():
    emb_file = os.path.join(
        config.DATA_DIR,
        'ultrafine/rlm_fet/enwiki-20151002-type-sents-2m-emb.pkl')
    block_emb = datautils.load_pickle_data(emb_file)
    print(block_emb.shape)
Example #22
def train_model():
    if config.dataset == 'cufe':
        data_prefix = config.CUFE_FILES['training_data_prefix']
        if config.train_on_crowd:
            data_prefix = config.CUFE_FILES['crowd_training_data_prefix']

    elif config.dataset == 'ufet':
        data_prefix = config.UFET_FILES['training_data_prefix']

    data_prefix += f'-{config.seq_tokenizer_name}'
    if config.dir_suffix:
        data_prefix += '-' + config.dir_suffix
    if not os.path.isdir(data_prefix): os.mkdir(data_prefix)

    str_today = datetime.datetime.now().strftime('%m_%d_%H%M')
    if not os.path.isdir(join(data_prefix, 'log')):
        os.mkdir(join(data_prefix, 'log'))
    if not config.test:
        log_file = os.path.join(
            join(data_prefix, 'log'),
            '{}-{}.log'.format(str_today, config.model_name))
        print(log_file)
    else:
        log_file = os.path.join(
            config.LOG_DIR, '{}-{}-test.log'.format(str_today,
                                                    config.model_name))
    # if not os.path.isdir(log_file) and not config.test: os.mkdir(log_file)
    init_universal_logging(log_file, mode='a', to_stdout=True)

    save_model_dir = join(data_prefix, 'models')
    if not os.path.isdir(save_model_dir): os.mkdir(save_model_dir)

    gres = exp_utils.GlobalRes(config)

    run_name = f'{config.dataset}-{config.model_name}-{config.seq_tokenizer_name}-{str_today}'
    logging.info(f'run_name: {run_name}')
    logging.info(
        f'/data/cleeag/word_embeddings/{config.mention_tokenizer_name}/{config.mention_tokenizer_name}_tokenizer&vecs.pkl -- loaded'
    )
    logging.info(f'training on {config.dataset}')
    logging.info(
        f'total type count: {len(gres.type2type_id_dict)}, '
        f'general type count: {0 if config.without_general_types else len(gres.general_type_set)}'
    )

    if config.dataset == 'ufet':
        crowd_training_samples = f'{config.CROWD_TRAIN_DATA_PREFIX}-{config.seq_tokenizer_name}.pkl'
        if config.test:

            train_data_pkl = join(data_prefix, 'dev.pkl')
            training_samples = datautils.load_pickle_data(train_data_pkl)
            crowd_training_samples = datautils.load_pickle_data(
                crowd_training_samples)

        else:
            train_data_pkl = join(data_prefix, 'train.pkl')
            print('loading training data {} ...'.format(train_data_pkl),
                  end=' ',
                  flush=True)
            training_samples = datautils.load_pickle_data(train_data_pkl)
            print('done', flush=True)
            logging.info('training data {} -- loaded'.format(train_data_pkl))

            crowd_training_samples = datautils.load_pickle_data(
                crowd_training_samples)

            if config.fine_tune and config.use_bert:
                # training_samples = random.choices(len(training_samples) // 10, training_samples)
                random.shuffle(training_samples)
                training_samples = training_samples[:len(training_samples) //
                                                    10]
                logging.info(
                    f'fining tuning with {len(training_samples)} samples')

        # dev_data_pkl = join(data_prefix, 'dev.pkl')
        # dev_samples = datautils.load_pickle_data(dev_data_pkl)
        dev_json_path = join(data_prefix, 'dev.json')
        dev_samples = exp_utils.model_samples_from_json(dev_json_path)
        dev_true_labels_dict = {
            s['mention_id']: [
                gres.type2type_id_dict.get(x) for x in s['types']
                if x in gres.type2type_id_dict
            ]
            for s in dev_samples
        }

        test_data_pkl = join(data_prefix, 'test.pkl')
        test_samples = datautils.load_pickle_data(test_data_pkl)
        test_true_labels_dict = {
            s['mention_id']: [
                gres.type2type_id_dict.get(x) for x in s['types']
                if x in gres.type2type_id_dict
            ]
            for s in test_samples
        }

    else:
        dev_data_pkl = join(data_prefix, 'dev.pkl')
        test_data_pkl = join(data_prefix, 'test.pkl')
        if config.test:
            train_data_pkl = join(data_prefix, 'dev.pkl')
        else:
            train_data_pkl = join(data_prefix, 'train.pkl')
        # test_data_pkl = config.CUFE_FILES['test_data_file_prefix'] + f'-{config.seq_tokenizer_name}/test.pkl'

        print('loading training data {} ...'.format(train_data_pkl),
              end=' ',
              flush=True)
        training_samples = datautils.load_pickle_data(train_data_pkl)
        print('done', flush=True)
        logging.info('training data {} -- loaded'.format(train_data_pkl))

        if not config.train_on_crowd:
            # crowd_training_samples = f'{config.CROWD_TRAIN_DATA_PREFIX}-{config.seq_tokenizer_name}/train.pkl'
            crowd_training_samples = datautils.load_pickle_data(
                join(data_prefix, 'crowd-train.pkl'))

        print('loading dev data {} ...'.format(dev_data_pkl),
              end=' ',
              flush=True)
        dev_samples = datautils.load_pickle_data(dev_data_pkl)
        print('done', flush=True)
        dev_true_labels_dict = {
            s['mention_id']: [
                gres.type2type_id_dict.get(x)
                for x in exp_utils.general_mapping(s['types'], gres)
            ]
            for s in dev_samples
        }

        test_samples = datautils.load_pickle_data(test_data_pkl)
        test_true_labels_dict = {
            s['mention_id']: [
                gres.type2type_id_dict.get(x)
                for x in exp_utils.general_mapping(s['types'], gres)
            ]
            for s in test_samples
        }

    logging.info(
        f'total training samples: {len(training_samples)}, '
        f'dev samples: {len(dev_samples)}, testing samples: {len(test_samples)}'
    )

    if not config.test:
        result_dir = join(data_prefix, f'{str_today}-results')
        if config.dataset == 'cufe':
            type_scope = 'general_types' if config.only_general_types else 'all_types'
        else:
            type_scope = config.dataset_type
        dev_results_file = join(
            result_dir,
            f'dev-{config.model_name}-{type_scope}-results-{config.inference_threshhold}.txt'
        )
        dev_incorrect_results_file = join(
            result_dir,
            f'dev-{config.model_name}-{type_scope}-incorrect_results-{config.inference_threshhold}.txt'
        )
        test_results_file = join(
            result_dir,
            f'test-{config.model_name}-{type_scope}-results-{config.inference_threshhold}.txt'
        )
        test_incorrect_results_file = join(
            result_dir,
            f'test-{config.model_name}-{type_scope}-incorrect_results-{config.inference_threshhold}.txt'
        )
    else:
        result_dir = join(data_prefix, 'test-results')
        dev_results_file = join(result_dir, 'dev-results.txt')
        dev_incorrect_results_file = join(result_dir,
                                          'dev-incorrect_results.txt')
        test_results_file = join(result_dir, 'test-results.txt')
        test_incorrect_results_file = join(result_dir,
                                           'test-incorrect_results.txt')

    if not os.path.isdir(result_dir): os.mkdir(result_dir)

    logging.info(
        'use_bert = {}, use_lstm = {}, use_mlp={}, bert_param_frozen={}, bert_fine_tune={}'
        .format(config.use_bert, config.use_lstm, config.use_mlp,
                config.freeze_bert, config.fine_tune))
    logging.info(
        'type_embed_dim={} context_lstm_hidden_dim={} pmlp_hdim={}'.format(
            config.type_embed_dim, config.lstm_hidden_dim,
            config.pred_mlp_hdim))

    # setup training
    if torch.cuda.device_count() > 0:
        device = torch.device(f'cuda:{config.gpu_ids[0]}')
        device_name = torch.cuda.get_device_name(config.gpu_ids[0])
    else:
        # fall back to CPU; torch.cuda.get_device_name would fail here
        device, device_name = torch.device('cpu'), 'cpu'

    logging.info(f'running on device: {device_name}')
    logging.info('building model...')

    model = fet_model(config, device, gres)
    logging.info(f'transfer={config.transfer}')

    if config.continue_train:
        model_path = config.CONTINUE_TRAINING_PATH[config.continue_train]
        logging.info(f'loading checkpoint from {model_path}')
        trained_weights = torch.load(model_path)
        trained_weights = {
            '.'.join(k.split('.')[1:]): v
            for k, v in trained_weights.items()
        }
        cur_model_dict = model.state_dict()
        cur_model_dict.update(trained_weights)
        model.load_state_dict(cur_model_dict)

    if config.transfer and config.use_lstm:
        logging.info(f'loading checkpoint from {config.TRANSFER_MODEL_PATH}')
        cur_model_dict = model.state_dict()
        trained_weights = torch.load(config.TRANSFER_MODEL_PATH)
        trained_weights_bilstm = {
            '.'.join(k.split('.')[1:]): v
            for k, v in trained_weights.items() if 'bi_lstm' in k
        }
        cur_model_dict.update(trained_weights_bilstm)
        model.load_state_dict(cur_model_dict)

    model.to(device)
    model = torch.nn.DataParallel(model, device_ids=config.gpu_ids)

    batch_size = 32 if config.dataset == 'cufe' and config.train_on_crowd else config.batch_size
    n_iter = 150 if config.dataset == 'cufe' and config.train_on_crowd else config.n_iter
    n_batches = (len(training_samples) + batch_size - 1) // batch_size
    n_steps = n_iter * n_batches
    eval_cycle = config.eval_cycle

    optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)
    losses = list()
    best_dev_acc = -1
    best_maf1_v = -1
    step = 0
    steps_since_last_best = 0

    # start training
    logging.info('{}'.format(model.__class__.__name__))
    logging.info('training batch size: {}'.format(batch_size))
    logging.info(
        '{} epochs, {} steps, {} steps per iter, learning rate={}, lr_decay={}, start training ...'
        .format(n_iter, n_steps, n_batches, config.learning_rate,
                config.lr_gamma))

    while True:
        if step == n_steps:
            break

        batch_idx = step % n_batches
        batch_beg, batch_end = batch_idx * batch_size, min(
            (batch_idx + 1) * batch_size, len(training_samples))
        if config.dataset == 'ufet':
            batch_samples = training_samples[batch_beg:batch_end - batch_size * 2 // 3] \
                            + random.choices(crowd_training_samples, k=batch_size * 1 // 3)
        elif config.dataset == 'cufe':
            if not config.train_on_crowd:
                batch_samples = training_samples[batch_beg:batch_end - batch_size * 2 // 3] \
                                + random.choices(crowd_training_samples, k=batch_size * 1 // 3)
            else:
                batch_samples = training_samples[batch_beg:batch_end]

        try:
            input_dataset, type_vecs = exp_utils.samples_to_tensor(
                config, gres, batch_samples)

            input_dataset = tuple(x.to(device) for x in input_dataset)
            type_vecs = type_vecs.to(device)
            model.module.train()
            logits = model(input_dataset, gres)
        except Exception:
            # skip batches that fail to build or run rather than abort training
            step += 1
            continue

        if config.dataset == 'ufet':
            loss = model.module.define_loss(logits, type_vecs,
                                            config.dataset_type)
        elif config.GENERAL_TYPES_MAPPING and not config.only_general_types:
            loss = model.module.get_uw_loss(logits, type_vecs, gres)
        else:
            loss = model.module.get_loss(logits, type_vecs)
        optimizer.zero_grad()

        loss.backward()
        if config.use_lstm:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 10.0,
                                           float('inf'))
        optimizer.step()
        losses.append(loss.data.cpu().numpy())
        step += 1
        if step % eval_cycle == 0 and step > 0:
            print('\nevaluating...')
            l_v, acc_v, pacc_v, maf1_v, ma_p_v, ma_r_v, mif1_v, dev_results, incorrect_dev_result_objs = \
                model_utils.eval_fetel(config, gres, model, dev_samples, dev_true_labels_dict)

            l_t, acc_t, pacc_t, maf1_t, ma_p_t, ma_r_t, mif1_t, test_results, incorrect_test_result_objs = \
                model_utils.eval_fetel(config, gres, model, test_samples, test_true_labels_dict)

            if maf1_v > best_maf1_v:
                steps_since_last_best = 0
            best_tag = '*' if maf1_v > best_maf1_v else ''
            logging.info(
                'run name={}, step={}/{}, learning rate={}, losses={:.4f}, steps_since_last_best={}'
                .format(run_name,
                        step, n_steps, optimizer.param_groups[0]['lr'],
                        sum(losses), steps_since_last_best))
            logging.info(
                'dev evaluation result: '
                'l_v={:.4f} acc_v={:.4f} pacc_v={:.4f} macro_f1_v={:.4f} micro_f1_v={:.4f}{}'
                .format(l_v, acc_v, pacc_v, maf1_v, mif1_v, best_tag))
            logging.info(
                f'dev evaluation result: macro_p={ma_p_v:.4f}, macro_r={ma_r_v:.4f}'
            )

            logging.info(
                'test evaluation result: '
                'l_v={:.4f} acc_t={:.4f} pacc_t={:.4f} macro_f1_t={:.4f} micro_f1_t={:.4f}{}'
                .format(l_t, acc_t, pacc_t, maf1_t, mif1_t, best_tag))
            logging.info(
                f'test evaluation result: macro_p={ma_p_t:.4f}, macro_r={ma_r_t:.4f}'
            )

            if maf1_v > best_maf1_v:
                if save_model_dir and not config.test:
                    save_model_file = join(
                        save_model_dir, f'{config.model_name}-{str_today}.pt')
                    torch.save(model.state_dict(), save_model_file)
                    logging.info('model saved to {}'.format(save_model_file))

                logging.info(
                    'prediction result saved to {}'.format(result_dir))
                datautils.save_json_objs(dev_results, dev_results_file)
                datautils.save_json_objs(incorrect_dev_result_objs,
                                         dev_incorrect_results_file)
                datautils.save_json_objs(test_results, test_results_file)
                datautils.save_json_objs(incorrect_test_result_objs,
                                         test_incorrect_results_file)
                # best_dev_acc = acc_v
                best_maf1_v = maf1_v

            losses = list()

        steps_since_last_best += 1
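A toy illustration of the batch mixing above, assuming batch_size = 6: about a third of each batch is the slice of regular training samples and another third is drawn from the crowd-annotated samples.

import random

batch_size = 6
training_samples = list(range(100))
crowd_training_samples = list(range(1000, 1100))
batch_beg, batch_end = 0, batch_size
batch = training_samples[batch_beg:batch_end - batch_size * 2 // 3] \
        + random.choices(crowd_training_samples, k=batch_size * 1 // 3)
print(batch)  # two regular samples followed by two crowd samples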
Example #23
def train_srlfet(device,
                 gres: expdata.ResData,
                 train_pkl,
                 dev_pkl,
                 manual_val_file_tup,
                 test_file_tup,
                 lstm_dim,
                 mlp_hidden_dim,
                 type_embed_dim,
                 train_config: TrainConfig,
                 single_type_path,
                 save_model_file_prefix=None):
    train_samples = datautils.load_pickle_data(train_pkl)
    dev_samples = datautils.load_pickle_data(dev_pkl)
    print(len(train_samples))

    if manual_val_file_tup is not None:
        val_mentions_file, val_sents_file, val_dep_file, val_srl_file, val_manual_label_file = manual_val_file_tup
        dev_samples = samples_from_man_labeled(
            gres.token_id_dict, gres.unknown_token_id, gres.type_id_dict,
            val_mentions_file, val_sents_file, val_dep_file, val_srl_file,
            val_manual_label_file)
        print('{} dev from manual'.format(len(dev_samples)))
    # exit()

    # loss_obj = exputils.BinMaxMarginLoss()
    if train_config.loss_name == 'mm':
        loss_obj = exputils.BinMaxMarginLoss(
            pos_margin=train_config.pos_margin,
            neg_margin=train_config.neg_margin,
            pos_scale=train_config.pos_scale,
            neg_scale=train_config.neg_scale)
    else:
        loss_obj = exputils.FocalLoss(gamma=4.0)

    batch_size = train_config.batch_size
    learning_rate = train_config.learning_rate
    n_iter = train_config.n_iter

    train_samples_list = __split_samples_by_arg_idx(train_samples)
    print([len(samples) for samples in train_samples_list], 'train samples')
    print([len(samples) // batch_size for samples in train_samples_list],
          'batches per iter')

    dev_samples_list = __split_samples_by_arg_idx(dev_samples)
    dev_sample_type_ids_list = __get_full_type_ids_of_samples(
        gres.parent_type_ids_dict, dev_samples)
    dev_true_labels_dict = {
        s[0]: [gres.type_vocab[tid] for tid in type_ids]
        for type_ids, s in zip(dev_sample_type_ids_list, dev_samples)
    }
    # dev_true_labels_dict = {s.mention_id: [gres.type_vocab[l] for l in s.labels] for s in dev_samples}
    print([len(samples) for samples in dev_samples_list], 'validation samples')
    logging.info(' '.join(
        ['{}={}'.format(k, v) for k, v in vars(train_config).items()]))

    test_samples_list, test_true_labels_dict = None, None
    if test_file_tup is not None:
        mentions_file, sents_file, dep_tags_file, srl_results_file = test_file_tup
        all_test_samples = samples_from_txt(gres.token_id_dict,
                                            gres.unknown_token_id,
                                            gres.type_id_dict,
                                            mentions_file,
                                            sents_file,
                                            dep_tags_file,
                                            srl_results_file,
                                            use_all=True)
        test_samples_list = __split_samples_by_arg_idx(all_test_samples)
        test_sample_type_ids_list = __get_full_type_ids_of_samples(
            gres.parent_type_ids_dict, all_test_samples)
        test_true_labels_dict = {
            s[0]: [gres.type_vocab[tid] for tid in type_ids]
            for type_ids, s in zip(test_sample_type_ids_list, all_test_samples)
        }
        print([len(samples) for samples in test_samples_list], 'test samples')

    word_vec_dim = gres.token_vecs.shape[1]
    models, optimizers = list(), list()
    lr_schedulers = list() if train_config.schedule_lr else None
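    # One model, optimizer and (optionally) StepLR scheduler per mention
    # argument position; the three models are trained and evaluated together.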
    for i in range(3):
        model = SRLFET(device, gres.type_vocab, gres.type_id_dict,
                       word_vec_dim, lstm_dim, mlp_hidden_dim, type_embed_dim)
        if device.type == 'cuda':
            model = model.cuda(device.index)
        models.append(model)
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
        optimizers.append(optimizer)
        if lr_schedulers is not None:
            lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                           step_size=1000,
                                                           gamma=0.7)
            lr_schedulers.append(lr_scheduler)
    print('start training ...')
    losses = list()
    best_dev_acc = -1
    n_steps_per_iter = len(train_samples) // batch_size
    n_steps = n_iter * n_steps_per_iter
    if train_config.n_steps > -1:
        n_steps = train_config.n_steps
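    # Each outer step performs one batch update for each of the three
    # per-argument-position models.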
    for i in range(n_steps):
        for mention_arg_idx, samples in enumerate(train_samples_list):
            model, optimizer = models[mention_arg_idx], optimizers[
                mention_arg_idx]
            model.train()
            lr_scheduler = None if lr_schedulers is None else lr_schedulers[
                mention_arg_idx]
            loss_val = __train_step(model, gres.parent_type_ids_dict,
                                    gres.token_vecs, samples, mention_arg_idx,
                                    i, batch_size, loss_obj, optimizer,
                                    lr_scheduler)
            losses.append(loss_val)

        if (i + 1) % 1000 == 0:
            acc_v, maf1, mif1, _ = __eval(gres, models, dev_samples_list,
                                          dev_true_labels_dict)

            if test_samples_list is not None:
                acc_t, maf1_t, mif1_t, _ = __eval(
                    gres,
                    models,
                    test_samples_list,
                    test_true_labels_dict,
                    single_type_path=single_type_path)
                logging.info(
                    '{} {:.4f} {:.4f} {:.4f} {:.4f} acct={:.4f} maf1t={:.4f} mif1t={:.4f}'
                    .format(i + 1, sum(losses), acc_v, maf1, mif1, acc_t,
                            maf1_t, mif1_t))
            else:
                logging.info('{} {:.4f} {:.4f} {:.4f} {:.4f}'.format(
                    i + 1, sum(losses), acc_v, maf1, mif1))
            losses = list()

            if acc_v > best_dev_acc and save_model_file_prefix:
                __save_srl_models(models, save_model_file_prefix)
                best_dev_acc = acc_v
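train_srlfet expands each sample's direct labels to full label sets via __get_full_type_ids_of_samples before evaluation. A plausible minimal sketch of that expansion, assuming parent_type_ids_dict maps a type id to the ids of its ancestor types (the helper's actual signature is not shown here):

def get_full_type_ids(label_ids, parent_type_ids_dict):
    # A fine-grained label such as /person/artist also implies its ancestor
    # /person, so every direct label is unioned with its ancestors.
    full_ids = set(label_ids)
    for tid in label_ids:
        full_ids.update(parent_type_ids_dict.get(tid, ()))
    return sorted(full_ids)

Applied per sample, this would yield the type_ids lists from which dev_true_labels_dict and test_true_labels_dict are built.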
Exemple #24
0
def train_fetel(device, gres: exputils.GlobalRes, el_entityvec: ELDirectEntityVec, train_samples_pkl,
                dev_samples_pkl, test_mentions_file, test_sents_file, test_noel_preds_file, type_embed_dim,
                context_lstm_hidden_dim, learning_rate, batch_size, n_iter, dropout, rand_per, per_penalty,
                use_mlp=False, pred_mlp_hdim=None, save_model_file=None, nil_rate=0.5,
                single_type_path=False, stack_lstm=False, concat_lstm=False, results_file=None):
    logging.info('results_file={}'.format(results_file))
    logging.info(
        'type_embed_dim={} cxt_lstm_hidden_dim={} pmlp_hdim={} nil_rate={} single_type_path={}'.format(
            type_embed_dim, context_lstm_hidden_dim, pred_mlp_hdim, nil_rate, single_type_path))
    logging.info('rand_per={} per_pen={}'.format(rand_per, per_penalty))
    logging.info('stack_lstm={} cat_lstm={}'.format(stack_lstm, concat_lstm))

    if stack_lstm:
        model = FETELStack(
            device, gres.type_vocab, gres.type_id_dict, gres.embedding_layer, context_lstm_hidden_dim,
            type_embed_dim=type_embed_dim, dropout=dropout, use_mlp=use_mlp, mlp_hidden_dim=pred_mlp_hdim,
            concat_lstm=concat_lstm)
    else:
        # Only the stacked-LSTM variant is implemented here; model would
        # otherwise be None and crash on the .cuda() call below.
        raise ValueError('train_fetel requires stack_lstm=True')
    if device.type == 'cuda':
        model = model.cuda(device.index)

    train_samples = datautils.load_pickle_data(train_samples_pkl)

    dev_samples = datautils.load_pickle_data(dev_samples_pkl)
    dev_samples = anchor_samples_to_model_samples(dev_samples, gres.mention_token_id, gres.parent_type_ids_dict)

    lr_gamma = 0.7
    eval_batch_size = 32
    logging.info('{}'.format(model.__class__.__name__))
    dev_true_labels_dict = {s.mention_id: [gres.type_vocab[l] for l in s.labels] for s in dev_samples}
    dev_entity_vecs, dev_el_sgns, dev_el_probs = __get_entity_vecs_for_samples(el_entityvec, dev_samples, None)

    test_samples = model_samples_from_json(gres.token_id_dict, gres.unknown_token_id, gres.mention_token_id,
                                           gres.type_id_dict, test_mentions_file, test_sents_file)
    test_noel_pred_results = datautils.read_pred_results_file(test_noel_preds_file)

    test_mentions = datautils.read_json_objs(test_mentions_file)
    test_entity_vecs, test_el_sgns, test_el_probs = __get_entity_vecs_for_mentions(
        el_entityvec, test_mentions, test_noel_pred_results, gres.n_types)

    test_true_labels_dict = {m['mention_id']: m['labels'] for m in test_mentions} if (
            'labels' in next(iter(test_mentions))) else None

    person_type_id = gres.type_id_dict.get('/person')
    l2_person_type_ids, person_loss_vec = None, None
    if person_type_id is not None:
        l2_person_type_ids = __get_l2_person_type_ids(gres.type_vocab)
        person_loss_vec = exputils.get_person_type_loss_vec(
            l2_person_type_ids, gres.n_types, per_penalty, model.device)

    dev_results_file = None  # set to a path to also save best dev predictions
    n_batches = (len(train_samples) + batch_size - 1) // batch_size
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=n_batches, gamma=lr_gamma)
    losses = list()
    best_dev_acc = -1
    logging.info('{} steps, {} steps per iter, lr_decay={}, start training ...'.format(
        n_iter * n_batches, n_batches, lr_gamma))
    step = 0
    n_steps = n_iter * n_batches
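    # batch_idx wraps around via modulo, so the loop makes n_iter full
    # passes over the training data.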
    while step < n_steps:
        batch_idx = step % n_batches
        batch_beg, batch_end = batch_idx * batch_size, min((batch_idx + 1) * batch_size, len(train_samples))
        batch_samples = anchor_samples_to_model_samples(
            train_samples[batch_beg:batch_end], gres.mention_token_id, gres.parent_type_ids_dict)
        if rand_per:
            entity_vecs, el_sgns, el_probs = __get_entity_vecs_for_samples(
                el_entityvec, batch_samples, None, True, person_type_id, l2_person_type_ids, gres.type_vocab)
        else:
            entity_vecs, el_sgns, el_probs = __get_entity_vecs_for_samples(el_entityvec, batch_samples, None, True)

        use_entity_vecs = True
        model.train()

        (context_token_seqs, mention_token_idxs, mstrs, mstr_token_seqs, y_true
         ) = exputils.get_mstr_cxt_label_batch_input(model.device, gres.n_types, batch_samples)

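        # With probability nil_rate, replace an entity vector with zeros so the
        # model also learns to type mentions without entity-linking evidence.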
        if use_entity_vecs:
            for i in range(entity_vecs.shape[0]):
                if np.random.uniform() < nil_rate:
                    entity_vecs[i] = np.zeros(entity_vecs.shape[1], np.float32)
            el_probs = torch.tensor(el_probs, dtype=torch.float32, device=model.device)
            entity_vecs = torch.tensor(entity_vecs, dtype=torch.float32, device=model.device)
        else:
            entity_vecs = None
        logits = model(context_token_seqs, mention_token_idxs, mstr_token_seqs, entity_vecs, el_probs)
        loss = model.get_loss(y_true, logits, person_loss_vec=person_loss_vec)
        optimizer.zero_grad()
        loss.backward()
        # Clip by the infinity norm: no gradient component may exceed 10.0.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 10.0, float('inf'))
        optimizer.step()
        # Step the LR scheduler after the optimizer update (PyTorch >= 1.1).
        scheduler.step()
        losses.append(loss.item())

        step += 1
        if step % 1000 == 0:
            acc_v, pacc_v, _, _, dev_results = eval_fetel(
                gres, model, dev_samples, dev_entity_vecs, dev_el_probs, eval_batch_size,
                use_entity_vecs=use_entity_vecs, single_type_path=single_type_path,
                true_labels_dict=dev_true_labels_dict)
            acc_t, _, maf1, mif1, test_results = eval_fetel(
                gres, model, test_samples, test_entity_vecs, test_el_probs, eval_batch_size,
                use_entity_vecs=use_entity_vecs, single_type_path=single_type_path,
                true_labels_dict=test_true_labels_dict)

            best_tag = '*' if acc_v > best_dev_acc else ''
            logging.info(
                'i={} l={:.4f} accv={:.4f} paccv={:.4f} acct={:.4f} maf1={:.4f} mif1={:.4f}{}'.format(
                    step, sum(losses), acc_v, pacc_v, acc_t, maf1, mif1, best_tag))
            if acc_v > best_dev_acc and save_model_file:
                torch.save(model.state_dict(), save_model_file)
                logging.info('model saved to {}'.format(save_model_file))

            if dev_results_file is not None and acc_v > best_dev_acc:
                datautils.save_json_objs(dev_results, dev_results_file)
                logging.info('dev results saved to {}'.format(dev_results_file))
            if results_file is not None and acc_v > best_dev_acc:
                datautils.save_json_objs(test_results, results_file)
                logging.info('test results saved to {}'.format(results_file))

            if acc_v > best_dev_acc:
                best_dev_acc = acc_v
            losses = list()
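The per-row loop in train_fetel that zeroes entity vectors can also be expressed as a single vectorized mask; a minimal sketch, assuming entity_vecs is a 2-D float32 numpy array (an equivalent rewrite, not the code used above):

import numpy as np

def apply_nil_dropout(entity_vecs, nil_rate):
    # Zero out whole rows with probability nil_rate, simulating mentions for
    # which entity linking produced no usable candidate.
    keep = np.random.uniform(size=entity_vecs.shape[0]) >= nil_rate
    return entity_vecs * keep[:, None].astype(entity_vecs.dtype)

In expectation this matches the in-place loop above, while avoiding a Python-level iteration over the batch.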