Example #1
import json
import os
import time

import numpy as np
import tensorflow as tf

# FLAGS is assumed to be defined via tf.app.flags in the surrounding script,
# and random_explore / get_experiment_dir are defined elsewhere in it.
# The project-specific modules below come from the surrounding codebase; the
# import paths are assumptions and may need adjusting:
# from nsm import computer_factory, data_utils, env_factory, executor_factory, utils
def create_envs(table_dict, data_set, en_vocab, embedding_model):
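    """Build one QAProgrammingEnv per example, wiring up a table-specific
    executor, a typed Lisp interpreter, and constant embeddings."""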
    all_envs = []
    t1 = time.time()
    if FLAGS.executor == 'wtq':
        score_fn = utils.wtq_score
        process_answer_fn = lambda x: x
        executor_fn = executor_factory.WikiTableExecutor
    elif FLAGS.executor == 'wikisql':
        score_fn = utils.wikisql_score
        process_answer_fn = utils.wikisql_process_answer
        executor_fn = executor_factory.WikiSQLExecutor
    else:
        raise ValueError('Unknown executor {}'.format(FLAGS.executor))

    for i, example in enumerate(data_set):
        if i % 100 == 0:
            tf.logging.info('creating environment #{}'.format(i))
        kg_info = table_dict[example['context']]
        executor = executor_fn(kg_info)
        api = executor.get_api()
        type_hierarchy = api['type_hierarchy']
        func_dict = api['func_dict']
        constant_dict = api['constant_dict']
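        # The interpreter is "assisted": during decoding it only exposes
        # tokens that are valid under the type hierarchy, with at most
        # max_n_mem memory slots and max_n_exp expressions per program.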
        interpreter = computer_factory.LispInterpreter(
            type_hierarchy=type_hierarchy,
            max_mem=FLAGS.max_n_mem,
            max_n_exp=FLAGS.max_n_exp,
            assisted=True)
        for v in func_dict.values():
            interpreter.add_function(**v)

        interpreter.add_constant(value=kg_info['row_ents'],
                                 type='entity_list',
                                 name='all_rows')

        de_vocab = interpreter.get_vocab()

        constant_value_embedding_fn = lambda x: get_embedding_for_constant(
            x, embedding_model, embedding_size=FLAGS.pretrained_embedding_size)
        env = env_factory.QAProgrammingEnv(
            en_vocab,
            de_vocab,
            question_annotation=example,
            answer=process_answer_fn(example['answer']),
            constants=list(constant_dict.values()),
            interpreter=interpreter,
            constant_value_embedding_fn=constant_value_embedding_fn,
            score_fn=score_fn,
            name=example['id'])
        all_envs.append(env)
    t2 = time.time()
    tf.logging.info('{} sec used to create {} environments.'.format(
        t2 - t1, len(all_envs)))
    return all_envs
def run_random_exploration(shard_id):
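    """Randomly sample programs for every example in one training shard and
    periodically dump the programs found so far to JSON."""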
    experiment_dir = get_experiment_dir()
    if not tf.gfile.Exists(experiment_dir):
        tf.gfile.MkDir(experiment_dir)

    if FLAGS.trigger_word_file:
        with open(FLAGS.trigger_word_file, 'r') as f:
            trigger_dict = json.load(f)
            tf.logging.info(
                'use trigger words in {}'.format(FLAGS.trigger_word_file))
    else:
        trigger_dict = None

    # Load dataset.
    train_set = []
    with open(FLAGS.train_file_tmpl.format(shard_id), 'r') as f:
        for line in f:
            example = json.loads(line)
            train_set.append(example)
    tf.logging.info('{} examples in training set.'.format(len(train_set)))

    table_dict = {}
    with open(FLAGS.table_file) as f:
        for line in f:
            table = json.loads(line)
            table_dict[table['name']] = table
    tf.logging.info('{} tables.'.format(len(table_dict)))

    if FLAGS.executor == 'wtq':
        score_fn = utils.wtq_score
        process_answer_fn = lambda x: x
        executor_fn = executor_factory.WikiTableExecutor
    elif FLAGS.executor == 'wikisql':
        score_fn = utils.wikisql_score
        process_answer_fn = utils.wikisql_process_answer
        executor_fn = executor_factory.WikiSQLExecutor
    else:
        raise ValueError('Unknown executor {}'.format(FLAGS.executor))

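    # The environment construction below mirrors create_envs() above, except
    # that random exploration never encodes the question, so the input vocab
    # is empty and constant embeddings are skipped (lambda x: None).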
    all_envs = []
    t1 = time.time()
    for i, example in enumerate(train_set):
        if i % 100 == 0:
            tf.logging.info('creating environment #{}'.format(i))
        kg_info = table_dict[example['context']]
        executor = executor_fn(kg_info)
        api = executor.get_api()
        type_hierarchy = api['type_hierarchy']
        func_dict = api['func_dict']
        constant_dict = api['constant_dict']
        interpreter = computer_factory.LispInterpreter(
            type_hierarchy=type_hierarchy,
            max_mem=FLAGS.max_n_mem,
            max_n_exp=FLAGS.max_n_exp,
            assisted=True)
        for v in func_dict.values():
            interpreter.add_function(**v)

        interpreter.add_constant(value=kg_info['row_ents'],
                                 type='entity_list',
                                 name='all_rows')

        de_vocab = interpreter.get_vocab()
        env = env_factory.QAProgrammingEnv(
            data_utils.Vocab([]),
            de_vocab,
            question_annotation=example,
            answer=process_answer_fn(example['answer']),
            constants=list(constant_dict.values()),
            interpreter=interpreter,
            constant_value_embedding_fn=lambda x: None,
            score_fn=score_fn,
            max_cache_size=FLAGS.n_explore_samples * FLAGS.n_epoch * 10,
            name=example['id'])
        all_envs.append(env)
    tf.logging.info('{} sec used to create {} environments.'.format(
        time.time() - t1, len(all_envs)))

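    # Each epoch draws n_explore_samples random programs per environment;
    # random_explore returns None when a sample is rejected, so only
    # accepted programs are accumulated per example.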
    program_dict = {env.name: [] for env in all_envs}
    for i in range(1, FLAGS.n_epoch + 1):
        tf.logging.info('iteration {}'.format(i))
        t1 = time.time()
        for env in all_envs:
            for _ in range(FLAGS.n_explore_samples):
                program = random_explore(env, trigger_dict=trigger_dict)
                if program is not None:
                    program_dict[env.name].append(program)
        t2 = time.time()
        tf.logging.info('{} sec used in iteration {}'.format(t2 - t1, i))

        if i % FLAGS.save_every_n == 0:
            tf.logging.info('saving programs in iteration {}'.format(i))
            t1 = time.time()
            with open(
                    os.path.join(
                        get_experiment_dir(),
                        'program_shard_{}-{}.json'.format(shard_id, i)),
                    'w') as f:
                program_str_dict = {k: [' '.join(p) for p in v]
                                    for k, v in program_dict.items()}
                json.dump(program_str_dict, f, sort_keys=True, indent=2)

            # Cache saving is currently disabled; uncomment to also persist
            # each environment's exploration cache:
            # cache_dict = dict([(env.name, list(env.cache._set)) for env in all_envs])
            t2 = time.time()
            tf.logging.info(
                '{} sec used saving programs in iteration {}'.format(t2 - t1, i))

        n = len(all_envs)
        solution_ratio = len(
            [env for env in all_envs if program_dict[env.name]]) * 1.0 / n
        tf.logging.info(
            'At least one solution found ratio: {}'.format(solution_ratio))
        n_programs_per_env = np.array(
            [len(program_dict[env.name]) for env in all_envs])
        tf.logging.info(
            'number of solutions found per example: max {}, min {}, avg {}, std {}'
            .format(n_programs_per_env.max(), n_programs_per_env.min(),
                    n_programs_per_env.mean(), n_programs_per_env.std()))

        # Macro average length.
        mean_length = np.mean([
            np.mean([len(p) for p in program_dict[env.name]])
            for env in all_envs if program_dict[env.name]
        ])
        tf.logging.info('macro average program length: {}'.format(mean_length))
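
A minimal driver sketch for this shard-level entry point, assuming FLAGS is set up with tf.app.flags as in the rest of the script; the shard_id flag is hypothetical and only illustrates how one process would handle one shard:

def main(unused_argv):
    # One process explores one shard; shards can be launched in parallel.
    run_random_exploration(FLAGS.shard_id)

if __name__ == '__main__':
    tf.app.run()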