Exemplo n.º 1
0
def main():
    """Encode the ReVerb event splits as GloVe ids and save them to disk.

    Steps:
    1. Read the train/dev/test splits with ``read_reverb_data`` and merge
       them into one dict. On duplicate keys, later splits win.
    2. Encode the merged data with ``event_to_id`` against the GloVe
       vocabulary, then print the subject/verb/object OOV counts it reports.
    3. As a sanity check, decode the first subject/verb/object phrase stored
       for one sample date back into words.
    4. Save the encoded data to ``data/cooked/reverb.pt1`` with
       ``torch.save``, and pickle the ``event_d`` object returned by
       ``event_to_id`` into the file ``event_data``.
    """
    # Imported locally: neither module is imported at the top of this file,
    # so referencing them as globals would raise NameError.
    import pickle

    import torch

    # Alternative ZPar-parsed inputs/outputs (kept for reference):
    # output_file = 'data/cooked/reverb_zpar.pt'
    # train_data = read_reverb_zpar_data('data/cooked/reverb_zpar_train.txt')
    # dev_data = read_reverb_zpar_data('data/cooked/reverb_zpar_dev.txt')
    # test_data = read_reverb_zpar_data('data/cooked/reverb_zpar_test.txt')
    # NOTE(review): the '.pt1' extension looks like a typo for '.pt'; confirm
    # with downstream loaders before renaming.
    output_file = 'data/cooked/reverb.pt1'

    train_data = read_reverb_data('data/cooked/reverb_train.txt')
    dev_data = read_reverb_data('data/cooked/reverb_dev.txt')
    test_data = read_reverb_data('data/cooked/reverb_test.txt')
    # Merge all splits into a single mapping; later updates overwrite
    # earlier entries that share a key.
    data_dict = {}
    data_dict.update(train_data)
    data_dict.update(dev_data)
    data_dict.update(test_data)

    # NOTE(review): absolute, user-specific path; consider making it
    # configurable.
    glove = Glove(
        '/users4/bwchen/CommonsenseERL_EMNLP_2019/data/glove.6B.100d.ext.txt')
    id_data, subj_oov_count, verb_oov_count, obj_oov_count, event_d = event_to_id(
        data_dict, glove)

    print('subj oov:', subj_oov_count)
    print('verb oov:', verb_oov_count)
    print('obj oov :', obj_oov_count)

    # Sanity check: decode the first phrase of each role for one sample key.
    # This is guarded so that a missing key does not abort the run before
    # the results below are saved.
    id2word = glove.reverse_dict()
    sample_date = '2006-10-20'
    try:
        subj_id, _, verb_id, _, obj_id, _ = id_data[sample_date]
    except KeyError:
        print('sample date not found:', sample_date)
    else:
        print('subj:', ' '.join([id2word[int(i)] for i in subj_id[0]]))
        print('verb:', ' '.join([id2word[int(i)] for i in verb_id[0]]))
        print('obj: ', ' '.join([id2word[int(i)] for i in obj_id[0]]))

    print(event_d)
    torch.save(id_data, output_file)
    # `with` guarantees the handle is closed even if pickling fails.
    with open('event_data', 'wb') as f:
        pickle.dump(event_d, f)
import sys
sys.path.insert(0, '.')
from event_tensors.train_utils import RandomizedQueuedInstances
from event_tensors.glove_utils import Glove

if __name__ == '__main__':
    svo_file = 'data/svo_small.txt'
    output_file = 'data/word_prediction_small.txt'

    emb_file = 'data/glove.6B.100d.ext.txt'
    num_queues = 256
    batch_size = 128
    max_phrase_size = 10
    embeddings = Glove(emb_file)
    id2word = embeddings.reverse_dict()
    # remove None at the last
    data = list(
        iter(
            RandomizedQueuedInstances(svo_file, embeddings, num_queues,
                                      batch_size, max_phrase_size)))[:-1]

    output_file = open(output_file, 'w')
    for subj, verb, obj, word_id in data:
        subj_id, _ = subj
        verb_id, _ = verb
        obj_id, _ = obj
        subj = [id2word[i] for i in subj_id if i != 1]
        verb = [id2word[i] for i in verb_id if i != 1]
        obj = [id2word[i] for i in obj_id if i != 1]
        word = id2word[word_id]
        if len(subj_id) == 0: