Example #1
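This snippet converts the raw WikiSQL splits into the binary form used by the parser: it loads the SQL ASDL grammar, builds a transition system over it, pickles each processed split, generates the source and primitive vocabularies, and prints statistics over the target action sequences.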
import os
import pickle
import sys
from itertools import chain

import numpy as np

from components.vocab import Vocab, VocabEntry

# ASDLGrammar, SqlTransitionSystem and load_dataset are assumed to come from
# the surrounding project (a tranX-style layout).


def prepare_dataset(data_path):
    # load the SQL ASDL grammar and wrap it in a transition system
    grammar = ASDLGrammar.from_text(
        open('../../asdl/lang/sql/sql_asdl.txt').read())
    transition_system = SqlTransitionSystem(grammar)

    datasets = []
    for split in ('dev', 'test', 'train'):
        print('processing %s' % split, file=sys.stderr)
        dataset_path = os.path.join(data_path, split + '.jsonl')
        table_path = os.path.join(data_path, split + '.tables.jsonl')
        dataset = load_dataset(transition_system, dataset_path, table_path)
        # serialize the processed split next to the other binary artifacts
        with open('../../data/wikisql1/%s.bin' % split, 'wb') as f:
            pickle.dump(dataset, f)

        datasets.append(dataset)

    # splits were processed in dev, test, train order
    dev_set, test_set, train_set = datasets
    # generate vocabulary
    src_vocab = VocabEntry.from_corpus([e.src_sent for e in train_set],
                                       size=100000,
                                       freq_cutoff=2)
    primitive_vocab = VocabEntry()
    primitive_vocab.add('</primitive>')

    vocab = Vocab(source=src_vocab, primitive=primitive_vocab)
    print('generated vocabulary %s' % repr(vocab), file=sys.stderr)

    with open('../../data/wikisql1/vocab.bin', 'wb') as f:
        pickle.dump(vocab, f)

    # report statistics over the target action sequence lengths
    action_len = [
        len(e.tgt_actions) for e in chain(train_set, dev_set, test_set)
    ]
    print('Max action len: %d' % max(action_len), file=sys.stderr)
    print('Avg action len: %.1f' % np.average(action_len), file=sys.stderr)
    print('Actions larger than 100: %d' %
          sum(1 for x in action_len if x > 100),
          file=sys.stderr)
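
A minimal usage sketch (not part of the original snippet), assuming the raw WikiSQL train/dev/test .jsonl and .tables.jsonl files sit in one directory; the path below is hypothetical:

if __name__ == '__main__':
    # hypothetical directory holding {train,dev,test}.jsonl and the
    # matching *.tables.jsonl files from the WikiSQL release
    prepare_dataset('../../data/wikisql')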
Example #2
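This snippet builds the vocabulary for a paraphrase model by reserving ten <unk_i> placeholder tokens and then merging in the source and code vocabularies of an existing pickled Vocab.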
import pickle


def generate_vocab_for_paraphrase_model(vocab_path, save_path):
    from components.vocab import VocabEntry, Vocab

    # pickled files must be opened in binary mode
    with open(vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    para_vocab = VocabEntry()
    # reserve ten <unk_i> placeholder tokens
    for i in range(10):
        para_vocab.add('<unk_%d>' % i)
    # merge the source and code vocabularies into the paraphrase vocabulary
    for word in vocab.source.word2id:
        para_vocab.add(word)
    for word in vocab.code.word2id:
        para_vocab.add(word)

    with open(save_path, 'wb') as f:
        pickle.dump(para_vocab, f)
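
A short usage sketch, assuming vocab_path points to a pickled Vocab carrying both source and code entries; both paths below are hypothetical:

generate_vocab_for_paraphrase_model(
    '../../data/vocab.bin',       # hypothetical: pickled Vocab with .source and .code
    '../../data/para_vocab.bin')  # hypothetical output path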