def create_envs(table_dict, data_set, en_vocab, embedding_model):
  """Creates one QAProgrammingEnv per example, backed by a Lisp interpreter."""
  all_envs = []
  # Pick the scoring, answer-processing and executor functions for the dataset.
  if FLAGS.executor == 'wtq':
    score_fn = wtq_score
    process_answer_fn = lambda x: x
    executor_fn = executor_factory.WikiTableExecutor
  elif FLAGS.executor == 'wikisql':
    score_fn = wikisql_score
    process_answer_fn = wikisql_process_answer
    executor_fn = executor_factory.WikiSQLExecutor
  else:
    raise ValueError('Unknown executor {}'.format(FLAGS.executor))
  # The embedding function does not depend on the example, so define it once.
  constant_value_embedding_fn = lambda x: get_embedding_for_constant(
      x, embedding_model, embedding_size=FLAGS.pretrained_embedding_size)
  for i, example in enumerate(data_set):
    if i % 100 == 0:
      tf.logging.info('creating environment #{}'.format(i))
    kg_info = table_dict[example['context']]
    executor = executor_fn(kg_info)
    api = executor.get_api()
    type_hierarchy = api['type_hierarchy']
    func_dict = api['func_dict']
    constant_dict = api['constant_dict']
    interpreter = computer_factory.LispInterpreter(
        type_hierarchy=type_hierarchy,
        max_mem=FLAGS.max_n_mem,
        max_n_exp=FLAGS.max_n_exp,
        assisted=True)
    for v in func_dict.values():
      interpreter.add_function(**v)
    # Expose the full table as the built-in constant `all_rows`.
    interpreter.add_constant(
        value=kg_info['row_ents'], type='entity_list', name='all_rows')
    de_vocab = interpreter.get_vocab()
    env = env_factory.QAProgrammingEnv(
        en_vocab, de_vocab,
        question_annotation=example,
        answer=process_answer_fn(example['answer']),
        constants=constant_dict.values(),
        interpreter=interpreter,
        constant_value_embedding_fn=constant_value_embedding_fn,
        score_fn=score_fn,
        name=example['id'])
    all_envs.append(env)
  return all_envs
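# A minimal usage sketch (hypothetical helper, not part of the original
# pipeline): builds environments straight from two JSONL files. The helper
# name and path arguments are assumptions for illustration; the files are
# expected to follow the same schema as FLAGS.table_file and
# FLAGS.train_file_tmpl used below.
def _create_envs_from_files(tables_path, data_path, en_vocab, embedding_model):
  table_dict = {}
  with open(tables_path, 'r') as f:
    for line in f:
      table = json.loads(line)
      table_dict[table['name']] = table
  with open(data_path, 'r') as f:
    data_set = [json.loads(line) for line in f]
  return create_envs(table_dict, data_set, en_vocab, embedding_model)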
def run_random_exploration(shard_id):
  """Randomly explores one training shard and saves the programs found."""
  experiment_dir = get_experiment_dir()
  if not tf.gfile.Exists(experiment_dir):
    tf.gfile.MkDir(experiment_dir)

  # Optionally restrict exploration with trigger words.
  if FLAGS.trigger_word_file:
    with open(FLAGS.trigger_word_file, 'r') as f:
      trigger_dict = json.load(f)
    tf.logging.info('use trigger words in {}'.format(FLAGS.trigger_word_file))
  else:
    trigger_dict = None

  # Load dataset.
  train_set = []
  with open(FLAGS.train_file_tmpl.format(shard_id), 'r') as f:
    for line in f:
      example = json.loads(line)
      train_set.append(example)
  tf.logging.info('{} examples in training set.'.format(len(train_set)))

  table_dict = {}
  with open(FLAGS.table_file) as f:
    for line in f:
      table = json.loads(line)
      table_dict[table['name']] = table
  tf.logging.info('{} tables.'.format(len(table_dict)))

  if FLAGS.executor == 'wtq':
    score_fn = utils.wtq_score
    process_answer_fn = lambda x: x
    executor_fn = executor_factory.WikiTableExecutor
  elif FLAGS.executor == 'wikisql':
    score_fn = utils.wikisql_score
    process_answer_fn = utils.wikisql_process_answer
    executor_fn = executor_factory.WikiSQLExecutor
  else:
    raise ValueError('Unknown executor {}'.format(FLAGS.executor))

  # Build one environment per example. Random exploration needs no encoder
  # vocabulary or constant embeddings, so placeholders are passed instead.
  all_envs = []
  for i, example in enumerate(train_set):
    if i % 100 == 0:
      tf.logging.info('creating environment #{}'.format(i))
    kg_info = table_dict[example['context']]
    executor = executor_fn(kg_info)
    api = executor.get_api()
    type_hierarchy = api['type_hierarchy']
    func_dict = api['func_dict']
    constant_dict = api['constant_dict']
    interpreter = computer_factory.LispInterpreter(
        type_hierarchy=type_hierarchy,
        max_mem=FLAGS.max_n_mem,
        max_n_exp=FLAGS.max_n_exp,
        assisted=True)
    for v in func_dict.values():
      interpreter.add_function(**v)
    interpreter.add_constant(
        value=kg_info['row_ents'], type='entity_list', name='all_rows')
    de_vocab = interpreter.get_vocab()
    env = env_factory.QAProgrammingEnv(
        data_utils.Vocab([]), de_vocab,
        question_annotation=example,
        answer=process_answer_fn(example['answer']),
        constants=constant_dict.values(),
        interpreter=interpreter,
        constant_value_embedding_fn=lambda x: None,
        score_fn=score_fn,
        max_cache_size=FLAGS.n_explore_samples * FLAGS.n_epoch * 10,
        name=example['id'])
    all_envs.append(env)

  # Sample programs in every environment for n_epoch iterations, keeping
  # the ones that reach the correct answer.
  program_dict = dict([(env.name, []) for env in all_envs])
  for i in xrange(1, FLAGS.n_epoch + 1):
    tf.logging.info('iteration {}'.format(i))
    t1 = time.time()
    for env in all_envs:
      for _ in xrange(FLAGS.n_explore_samples):
        program = random_explore(env, trigger_dict=trigger_dict)
        if program is not None:
          program_dict[env.name].append(program)
    t2 = time.time()
    tf.logging.info('{} sec used in iteration {}'.format(t2 - t1, i))

    if i % FLAGS.save_every_n == 0:
      tf.logging.info('saving programs and cache in iteration {}'.format(i))
      t1 = time.time()
      with open(
          os.path.join(get_experiment_dir(),
                       'program_shard_{}-{}.json'.format(shard_id, i)),
          'w') as f:
        program_str_dict = dict([(k, [' '.join(p) for p in v])
                                 for k, v in program_dict.iteritems()])
        json.dump(program_str_dict, f, sort_keys=True, indent=2)
      # Note: cache saving is disabled; only the programs are written out.
      # cache_dict = dict(
      #     [(env.name, list(env.cache._set)) for env in all_envs])
      t2 = time.time()
      tf.logging.info(
          '{} sec used saving programs and cache in iteration {}'.format(
              t2 - t1, i))

    # Log exploration statistics for this iteration.
    n = len(all_envs)
    solution_ratio = len(
        [env for env in all_envs if program_dict[env.name]]) * 1.0 / n
    tf.logging.info(
        'At least one solution found ratio: {}'.format(solution_ratio))
    n_programs_per_env = np.array(
        [len(program_dict[env.name]) for env in all_envs])
    tf.logging.info(
        'number of solutions found per example: '
        'max {}, min {}, avg {}, std {}'.format(
            n_programs_per_env.max(), n_programs_per_env.min(),
            n_programs_per_env.mean(), n_programs_per_env.std()))
    # Macro average program length.
    mean_length = np.mean(
        [np.mean([len(p) for p in program_dict[env.name]])
         for env in all_envs if program_dict[env.name]])
    tf.logging.info('macro average program length: {}'.format(mean_length))
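# Sketch of the inverse of the save step above (hypothetical helper, not in
# the original file): reload every 'program_shard_*.json' written by
# run_random_exploration into a single {env name: [program token lists]}
# dict. Programs were saved as whitespace-joined token strings, so split()
# recovers the token lists.
def _load_saved_programs(experiment_dir):
  program_dict = {}
  for fn in tf.gfile.ListDirectory(experiment_dir):
    if fn.startswith('program_shard_') and fn.endswith('.json'):
      with open(os.path.join(experiment_dir, fn), 'r') as f:
        shard_programs = json.load(f)
      for name, programs in shard_programs.iteritems():
        program_dict.setdefault(name, []).extend(
            [p.split() for p in programs])
  return program_dict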