Example #1
def parse_data_for_seq2seq(data_file='data/ifttt.freq3.bin'):
    train_data, dev_data, test_data = deserialize_from_file(data_file)
    prefix = 'data/seq2seq/'

    for dataset, output in [(train_data, prefix + 'ifttt.train'),
                            (dev_data, prefix + 'ifttt.dev'),
                            (test_data, prefix + 'ifttt.test')]:
        f_source = open(output + '.desc', 'w')
        f_target = open(output + '.code', 'w')

        if 'test' in output:
            raw_ids = [int(i.strip()) for i in open('data/ifff.test_data.gold.id')]
            eids = [i for i, e in enumerate(test_data.examples) if e.raw_id in raw_ids]
            dataset = test_data.get_dataset_by_ids(eids, test_data.name + '.subset')

        for e in dataset.examples:
            query_tokens = e.query
            trigger = e.parse_tree['TRIGGER'].children[0].type + ' . ' + e.parse_tree['TRIGGER'].children[0].children[0].type
            action = e.parse_tree['ACTION'].children[0].type + ' . ' + e.parse_tree['ACTION'].children[0].children[0].type
            code = 'IF ' + trigger + ' THEN ' + action

            f_source.write(' '.join(query_tokens) + '\n')
            f_target.write(code + '\n')

        f_source.close()
        f_target.close()
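Every example on this page relies on deserialize_from_file to load a (train_data, dev_data, test_data) tuple from a serialized dataset file. Its implementation is not part of these snippets; a minimal pickle-based stand-in, assuming the file is simply a pickled 3-tuple, might look like this:

import pickle

def deserialize_from_file(path):
    # Hypothetical stand-in: assumes the dataset file holds a pickled
    # (train_data, dev_data, test_data) tuple; the real io_utils helper
    # may use cPickle or a different serialization protocol.
    with open(path, 'rb') as f:
        return pickle.load(f)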
Example #2
def dump_data_for_evaluation(data_type='django', data_file='', max_query_length=70):
    train_data, dev_data, test_data = deserialize_from_file(data_file)
    prefix = '/Users/yinpengcheng/Projects/dl4mt-tutorial/codegen_data/'
    for dataset, output in [(train_data, prefix + '%s.train' % data_type),
                            (dev_data, prefix + '%s.dev' % data_type),
                            (test_data, prefix + '%s.test' % data_type)]:
        f_source = open(output + '.desc', 'w')
        f_target = open(output + '.code', 'w')

        for e in dataset.examples:
            query_tokens = e.query[:max_query_length]
            code = e.code
            if data_type == 'django':
                target_code = de_canonicalize_code_for_seq2seq(code, e.meta_data['raw_code'])
            else:
                target_code = code

            # tokenize code
            target_code = target_code.strip()
            tokenized_target = tokenize_code_adv(target_code, breakCamelStr=False if data_type=='django' else True)
            tokenized_target = [tk.replace('\n', '#NEWLINE#') for tk in tokenized_target]
            tokenized_target = [tk for tk in tokenized_target if tk is not None]

            while tokenized_target[-1] == '#INDENT#':
                tokenized_target = tokenized_target[:-1]

            f_source.write(' '.join(query_tokens) + '\n')
            f_target.write(' '.join(tokenized_target) + '\n')

        f_source.close()
        f_target.close()
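The function above writes parallel, line-aligned .desc/.code files. A small hypothetical consumer that reads such a pair back into (source tokens, target tokens) examples, reusing the prefix from the snippet, could be:

prefix = '/Users/yinpengcheng/Projects/dl4mt-tutorial/codegen_data/'
# Read the line-aligned description/code files written above.
with open(prefix + 'django.train.desc') as f_src, \
        open(prefix + 'django.train.code') as f_tgt:
    pairs = [(src.split(), tgt.split()) for src, tgt in zip(f_src, f_tgt)]
print(len(pairs))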
Example #3
def dump_data_for_evaluation(data_type='django',
                             data_file='',
                             max_query_length=70):
    train_data, dev_data, test_data = deserialize_from_file(data_file)
    prefix = '/Users/yinpengcheng/Projects/dl4mt-tutorial/codegen_data/'
    for dataset, output in [(train_data, prefix + '%s.train' % data_type),
                            (dev_data, prefix + '%s.dev' % data_type),
                            (test_data, prefix + '%s.test' % data_type)]:
        f_source = open(output + '.desc', 'w')
        f_target = open(output + '.code', 'w')

        for e in dataset.examples:
            query_tokens = e.query[:max_query_length]
            code = e.code
            if data_type == 'django':
                target_code = de_canonicalize_code_for_seq2seq(
                    code, e.meta_data['raw_code'])
            else:
                target_code = code

            # tokenize code
            target_code = target_code.strip()
            tokenized_target = tokenize_code_adv(
                target_code,
                breakCamelStr=False if data_type == 'django' else True)
            tokenized_target = [
                tk.replace('\n', '#NEWLINE#') for tk in tokenized_target
            ]
            tokenized_target = [
                tk for tk in tokenized_target if tk is not None
            ]

            while tokenized_target[-1] == '#INDENT#':
                tokenized_target = tokenized_target[:-1]

            f_source.write(' '.join(query_tokens) + '\n')
            f_target.write(' '.join(tokenized_target) + '\n')

        f_source.close()
        f_target.close()
Example #4
def main():
	'''
		Read file from Django data set
	'''
	train_data, dev_data, test_data = deserialize_from_file("data/django.cleaned.dataset.freq5.par_info.refact.space_only.bin")

	#uncomment below for Hearthstone data set
	#train_data, dev_data, test_data = deserialize_from_file("hs.freq3.pre_suf.unary_closure.bin")

	print("----- TRAIN -----")
	train_length = len(train_data.examples)
	print(train_length) #16000 instances for django, 533 for hs
	write_to_file("train_.txt", train_data, train_length)

	print("----- DEV -----")
	dev_length = len(dev_data.examples)
	print(dev_length) #1000 instances for django, 66 for hs
	write_to_file("dev_.txt", dev_data, dev_length)

	print("----- TEST -----")
	test_length = len(test_data.examples)
	print(test_length) #1801 instances for django, 66 for hs
	write_to_file("test_.txt", test_data, test_length)
Example #5
def parse_data_for_seq2seq(data_file='data/ifttt.freq3.bin'):
    train_data, dev_data, test_data = deserialize_from_file(data_file)
    prefix = 'data/seq2seq/'

    for dataset, output in [(train_data, prefix + 'ifttt.train'),
                            (dev_data, prefix + 'ifttt.dev'),
                            (test_data, prefix + 'ifttt.test')]:
        f_source = open(output + '.desc', 'w')
        f_target = open(output + '.code', 'w')

        if 'test' in output:
            raw_ids = [
                int(i.strip()) for i in open('data/ifff.test_data.gold.id')
            ]
            eids = [
                i for i, e in enumerate(test_data.examples)
                if e.raw_id in raw_ids
            ]
            dataset = test_data.get_dataset_by_ids(eids,
                                                   test_data.name + '.subset')

        for e in dataset.examples:
            query_tokens = e.query
            trigger = e.parse_tree['TRIGGER'].children[0].type + ' . ' + e.parse_tree['TRIGGER'].children[0].children[0].type
            action = e.parse_tree['ACTION'].children[0].type + ' . ' + e.parse_tree['ACTION'].children[0].children[0].type
            code = 'IF ' + trigger + ' THEN ' + action

            f_source.write(' '.join(query_tokens) + '\n')
            f_target.write(code + '\n')

        f_source.close()
        f_target.close()
Example #6
# interactive operation
interactive_parser.add_argument('-mode', default='dataset')

if __name__ == '__main__':
    args = parser.parse_args()

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    np.random.seed(args.random_seed)
    init_logging(os.path.join(args.output_dir, 'parser.log'), logging.INFO)
    logging.info('command line: %s', ' '.join(sys.argv))

    logging.info('loading dataset [%s]', args.data)
    train_data, dev_data, test_data = deserialize_from_file(args.data)

    if not args.source_vocab_size:
        args.source_vocab_size = train_data.annot_vocab.size
    if not args.target_vocab_size:
        args.target_vocab_size = train_data.terminal_vocab.size
    if not args.rule_num:
        args.rule_num = len(train_data.grammar.rules)
    if not args.node_num:
        args.node_num = len(train_data.grammar.node_type_to_id)

    logging.info('current config: %s', args)
    config_module = sys.modules['config']
    for name, value in vars(args).iteritems():
        setattr(config_module, name, value)
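The final loop copies every parsed argument onto the already-imported config module, so any code that later does `import config` sees the command-line values as module attributes. A self-contained illustration of that pattern (hypothetical attribute name) follows:

import sys
import types

# Register a module object under the name 'config'; any later `import config`
# returns this same object, so attributes set here are visible everywhere.
config = types.ModuleType('config')
sys.modules['config'] = config
setattr(sys.modules['config'], 'beam_size', 15)

import config as cfg
print(cfg.beam_size)  # 15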
Example #7
                if verbose:
                    print(previous_keys[i], new_value, new_flag)
                index = self.indexes_per_last_value[(previous_keys[i],
                                                     new_value, new_flag)]
                new_keys.append(index)
        except:
            while len(new_keys) <= self.max_ngrams:
                new_keys.append(None)
        return new_keys

    def __call__(self, keys):
        l = list()
        for k in reversed(keys[:-1]):
            if k is not None:
                for j in self.ngram_follows[k]:
                    l.append((self.ngrams_lastelt_id[j], self.ngrams_score[j],
                              self.ngrams_lastelt_flag[j]))
            # if len(l) > 0:
            #    break
        return l


if __name__ == "__main__":
    train_data, dev_data, test_data = deserialize_from_file(
        '../../files/aligned_hs.bin')
    for ex in test_data.examples:
        input_sentence = ex.query
        l = retrieve_translation_pieces(train_data, input_sentence)
        # print len(l[4])
        del l
Example #8
# decoding
parser.add_argument('-beam_size', default=15, type=int)
parser.add_argument('-max_query_length', default=70, type=int)
parser.add_argument('-decode_max_time_step', default=100, type=int)
parser.add_argument('-head_nt_constraint', dest='head_nt_constraint', action='store_true')
parser.add_argument('-no_head_nt_constraint', dest='head_nt_constraint', action='store_false')
parser.set_defaults(head_nt_constraint=True)

args = parser.parse_args(args=['-data_type', 'django', '-data', 'data/django.cleaned.dataset.freq5.par_info.refact.space_only.bin',
                               '-model', 'models/model.django_word128_encoder256_rule128_node64.beam15.adam.simple_trans.no_unary_closure.8e39832.run3.best_acc.npz'])
if args.data_type == 'hs':
    args.decode_max_time_step = 350

logging.info('loading dataset [%s]', args.data)
train_data, dev_data, test_data = deserialize_from_file(args.data)

if not args.source_vocab_size:
    args.source_vocab_size = train_data.annot_vocab.size
if not args.target_vocab_size:
    args.target_vocab_size = train_data.terminal_vocab.size
if not args.rule_num:
    args.rule_num = len(train_data.grammar.rules)
if not args.node_num:
    args.node_num = len(train_data.grammar.node_type_to_id)

config_module = sys.modules['config']
for name, value in vars(args).iteritems():
    setattr(config_module, name, value)

# build the model
Example #9
parser.add_argument('-no_head_nt_constraint',
                    dest='head_nt_constraint',
                    action='store_false')
parser.set_defaults(head_nt_constraint=True)

args = parser.parse_args(args=[
    '-data_type', 'django', '-data',
    'data/django.cleaned.dataset.freq5.par_info.refact.space_only.bin',
    '-model',
    'models/model.django_word128_encoder256_rule128_node64.beam15.adam.simple_trans.no_unary_closure.8e39832.run3.best_acc.npz'
])
if args.data_type == 'hs':
    args.decode_max_time_step = 350

logging.info('loading dataset [%s]', args.data)
train_data, dev_data, test_data = deserialize_from_file(args.data)

if not args.source_vocab_size:
    args.source_vocab_size = train_data.annot_vocab.size
if not args.target_vocab_size:
    args.target_vocab_size = train_data.terminal_vocab.size
if not args.rule_num:
    args.rule_num = len(train_data.grammar.rules)
if not args.node_num:
    args.node_num = len(train_data.grammar.node_type_to_id)

config_module = sys.modules['config']
for name, value in vars(args).iteritems():
    setattr(config_module, name, value)

# build the model
Example #10
from nn.utils.generic_utils import init_logging
from nn.utils.io_utils import deserialize_from_file, serialize_to_file
print('Hi')
print('Hello')

train_data, dev_data, test_data = deserialize_from_file(
    '/content/NL2code/data/hs.freq3.pre_suf.unary_closure.bin')

print('Total Grammar Rules: ' + str(len(train_data.grammar.rules)))
print('Total Annotation Vocabs: ' +
      str(len(train_data.annot_vocab.token_id_map)))
print('Total Terminal Vocabs: ' +
      str(len(train_data.terminal_vocab.token_id_map)))
print('Total examples: ' + str(len(train_data.examples)))

print('-' * 100)
for i, grammar_rule in enumerate(train_data.grammar.rules):
    print(grammar_rule)
    if i == 10:
        break

print('-' * 100)
for i, annot_vocab in enumerate(train_data.annot_vocab.token_id_map):
    print(annot_vocab)
    if i == 50:
        break

print('-' * 100)
for i, annot_vocab in enumerate(train_data.annot_vocab.token_id_map):
    print(annot_vocab)
    if i == 50:
        break
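token_id_map appears to be a dict-style token-to-id mapping, so iterating over it directly yields only the token strings. A sketch that prints both token and id, assuming a standard dict interface, would be:

for i, (token, token_id) in enumerate(train_data.terminal_vocab.token_id_map.items()):
    # Print both sides of the mapping instead of just the key.
    print(token, token_id)
    if i == 10:
        break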
Example #11
# interactive operation
interactive_parser.add_argument('-mode', default='dataset')

if __name__ == '__main__':
    args = parser.parse_args()

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    np.random.seed(args.random_seed)
    init_logging(os.path.join(args.output_dir, 'parser.log'), logging.INFO)
    logging.info('command line: %s', ' '.join(sys.argv))

    logging.info('loading dataset [%s]', args.data)
    train_data, dev_data, test_data = deserialize_from_file(args.data)

    if not args.source_vocab_size:
        args.source_vocab_size = train_data.annot_vocab.size
    if not args.target_vocab_size:
        args.target_vocab_size = train_data.terminal_vocab.size
    if not args.rule_num:
        args.rule_num = len(train_data.grammar.rules)
    if not args.node_num:
        args.node_num = len(train_data.grammar.node_type_to_id)

    logging.info('current config: %s', args)
    config_module = sys.modules['config']
    for name, value in vars(args).iteritems():
        setattr(config_module, name, value)