def get_vocab(self):
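  """Builds the decoder vocabulary from the interpreter's namespace.

  The vocabulary covers every function/constant name in the namespace, the
  special tokens, and one memory-slot token ('v0', 'v1', ...) per slot.
  """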
  mem_tokens = []
  for i in range(self.max_mem):
    mem_tokens.append('v{}'.format(i))
  vocab = data_utils.Vocab(
      self.namespace.get_all_names() + SPECIAL_TKS + mem_tokens)
  return vocab


def run_random_exploration(shard_id):
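  """Runs random exploration on one shard of the training data.

  Builds an interpreter-backed environment per example, samples
  FLAGS.n_explore_samples programs from each environment per epoch, and
  periodically saves the programs found so far to the experiment directory.
  """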
  experiment_dir = get_experiment_dir()
  if not tf.gfile.Exists(experiment_dir):
    tf.gfile.MkDir(experiment_dir)

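  # Optionally load trigger words; they are passed to random_explore below
  # to guide program sampling.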
  if FLAGS.trigger_word_file:
    with gfile.Open(FLAGS.trigger_word_file, 'r') as f:
      trigger_dict = json.load(f)
      tf.logging.info('use trigger words in {}'.format(FLAGS.trigger_word_file))
  else:
    trigger_dict = None

  # Load dataset.
  train_set = []
  with gfile.Open(FLAGS.train_file_tmpl.format(shard_id), 'r') as f:
    for line in f:
      example = json.loads(line)
      train_set.append(example)
  tf.logging.info('{} examples in training set.'.format(len(train_set)))

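  # Load tables (one JSON object per line), keyed by table name.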
  table_dict = {}
  with gfile.Open(FLAGS.table_file) as f:
    for line in f:
      table = json.loads(line)
      table_dict[table['name']] = table
  tf.logging.info('{} tables.'.format(len(table_dict)))

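  # Select the answer scorer, answer preprocessor, and executor for the
  # chosen dataset.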
  if FLAGS.executor == 'wtq':
    score_fn = utils.wtq_score
    process_answer_fn = lambda x: x
    executor_fn = executor_factory.WikiTableExecutor
  elif FLAGS.executor == 'wikisql':
    score_fn = utils.wikisql_score
    process_answer_fn = utils.wikisql_process_answer
    executor_fn = executor_factory.WikiSQLExecutor
  else:
    raise ValueError('Unknown executor {}'.format(FLAGS.executor))

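  # Build one QA environment per training example, each backed by a Lisp
  # interpreter loaded with the executor's API for the example's table.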
  all_envs = []
  t1 = time.time()
  for i, example in enumerate(train_set):
    if i % 100 == 0:
      tf.logging.info('creating environment #{}'.format(i))
    kg_info = table_dict[example['context']]
    executor = executor_fn(kg_info)
    api = executor.get_api()
    type_hierarchy = api['type_hierarchy']
    func_dict = api['func_dict']
    constant_dict = api['constant_dict']
    interpreter = computer_factory.LispInterpreter(
        type_hierarchy=type_hierarchy,
        max_mem=FLAGS.max_n_mem,
        max_n_exp=FLAGS.max_n_exp,
        assisted=True)
    for v in func_dict.values():
      interpreter.add_function(**v)

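    # Make the full set of table rows available as the constant `all_rows`.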
    interpreter.add_constant(
        value=kg_info['row_ents'], type='entity_list', name='all_rows')

    de_vocab = interpreter.get_vocab()
    env = env_factory.QAProgrammingEnv(
        data_utils.Vocab([]),
        de_vocab,
        question_annotation=example,
        answer=process_answer_fn(example['answer']),
        constants=constant_dict.values(),
        interpreter=interpreter,
        constant_value_embedding_fn=lambda x: None,
        score_fn=score_fn,
        max_cache_size=FLAGS.n_explore_samples * FLAGS.n_epoch * 10,
        name=example['id'])
    all_envs.append(env)
  t2 = time.time()
  tf.logging.info('{} sec used to create {} environments.'.format(
      t2 - t1, len(all_envs)))

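  # Exploration loop: sample programs from every environment each epoch and
  # keep whatever random_explore returns.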
  program_dict = {env.name: [] for env in all_envs}
  for i in xrange(1, FLAGS.n_epoch + 1):
    tf.logging.info('iteration {}'.format(i))
    t1 = time.time()
    for env in all_envs:
      for _ in xrange(FLAGS.n_explore_samples):
        program = random_explore(env, trigger_dict=trigger_dict)
        if program is not None:
          program_dict[env.name].append(program)
    t2 = time.time()
    tf.logging.info('{} sec used in iteration {}'.format(t2 - t1, i))

    if i % FLAGS.save_every_n == 0:
      tf.logging.info('saving programs in iteration {}'.format(i))
      t1 = time.time()
      with gfile.Open(
          os.path.join(get_experiment_dir(), 'program_shard_{}-{}.json'.format(
              shard_id, i)), 'w') as f:
        program_str_dict = {
            k: [' '.join(p) for p in v] for k, v in program_dict.iteritems()}
        json.dump(program_str_dict, f, sort_keys=True, indent=2)

      t2 = time.time()
      tf.logging.info(
          '{} sec used saving programs in iteration {}'.format(t2 - t1, i))

    n = len(all_envs)
    solution_ratio = len(
        [env for env in all_envs if program_dict[env.name]]) * 1.0 / n
    tf.logging.info(
        'At least one solution found ratio: {}'.format(solution_ratio))
    n_programs_per_env = np.array(
        [len(program_dict[env.name]) for env in all_envs])
    tf.logging.info(
        'number of solutions found per example: max {}, min {}, avg {}, std {}'
        .format(n_programs_per_env.max(), n_programs_per_env.min(),
                n_programs_per_env.mean(), n_programs_per_env.std()))

    # Macro-averaged program length: the mean, over examples with at least
    # one program, of each example's mean program length.
    mean_length = np.mean([
        np.mean([len(p)
                 for p in program_dict[env.name]])
        for env in all_envs
        if program_dict[env.name]
    ])
    tf.logging.info('macro average program length: {}'.format(mean_length))
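
# A minimal sketch of how run_random_exploration might be invoked
# (hypothetical: `main`, FLAGS.n_shards, and the tf.app entry point are
# assumptions, not part of this module):
#
#   def main(unused_argv):
#     for shard_id in xrange(FLAGS.n_shards):
#       run_random_exploration(shard_id)
#
#   if __name__ == '__main__':
#     tf.app.run()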