Code example #1
    def save_to_buffer(self, examples, replay_buffer):
        filtered_examples = []

        # for examples whose env already has an annotation, screen out the
        # ones whose program does not fit it
        for example in examples:
            annotation = self.env_annotation_dict.get(example.traj.env_name,
                                                      None)
            if annotation is not None and len(annotation) > 0:
                explored_program = agent_factory.traj_to_program(
                    example.traj, self.decode_vocab)
                if not annotation.verify_program(explored_program):
                    continue
            filtered_examples.append(example)

        replay_buffer.save(filtered_examples)
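
The filtering above depends on an Annotation interface with a verify_program method that this listing never defines. Below is a minimal sketch of what the base class might look like, inferred from the call sites here and from Annotation.__init__(self, env.name) in code example #9; the method name comes from the call sites, but its exact semantics are an assumption:

class Annotation(object):
    """Hypothetical base class inferred from the call sites in this listing."""

    def __init__(self, env_name):
        self.env_name = env_name

    def verify_program(self, program):
        # Subclasses decide whether an explored program is consistent with
        # this annotation; accept everything by default (assumed behavior).
        return True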
Code example #2
    def annotate_example(self, envs):
        # first create envs from the JSON specs
        envs = json_to_envs(envs)

        if FLAGS.executor == 'wtq':
            oracle_envs, oracle_trajs = get_wtq_annotations(envs)
        else:
            oracle_envs, oracle_trajs = get_env_trajs(envs)
        oracle_env_programs = [(env.name,
                                traj_to_program(traj, envs[0].de_vocab))
                               for env, traj in zip(oracle_envs, oracle_trajs)]
        env_name_program_dict = dict(oracle_env_programs)

        env_name_annotation_dict = dict()
        for env in envs:
            program = env_name_program_dict.get(env.name, None)
            if program is not None:
                annotation = OracleAnnotation(env, program)
                env_name_annotation_dict[env.name] = annotation
            else:
                env_name_annotation_dict[env.name] = None

        return env_name_annotation_dict
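
OracleAnnotation is constructed above, but its body is not part of this listing. A plausible sketch follows, building on the hypothetical Annotation base class sketched after code example #1 and assuming verify_program simply compares the candidate token sequence against the stored oracle program (the equality semantics is an assumption):

class OracleAnnotation(Annotation):
    """Hypothetical: wraps the single oracle program of an environment."""

    def __init__(self, env, program):
        Annotation.__init__(self, env.name)
        self.program = program

    def verify_program(self, program):
        # Accept only the exact oracle token sequence (assumed semantics).
        return program == self.program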
Code example #3
    def annotate_example_exploration(self, envs):
        # first create envs from the JSON specs
        envs = json_to_envs(envs)

        if FLAGS.executor == 'wtq':
            oracle_envs, oracle_trajs = get_wtq_annotations(envs)
        else:
            oracle_envs, oracle_trajs = get_env_trajs(envs)

        oracle_env_programs = [(env.name,
                                traj_to_program(traj, envs[0].de_vocab))
                               for env, traj in zip(oracle_envs, oracle_trajs)]
        env_name_program_dict = dict(oracle_env_programs)
        env_name_annotation_dict = dict()

        for env in envs:
            program = env_name_program_dict.get(env.name, None)

            if program is None:
                env_name_annotation_dict[env.name] = None
            else:
                sketch = get_sketch(program)
                explored_programs = SketchAnnotator.explore_sketch_programs(
                    env, sketch)
                oracle_trajs = [
                    collect_traj_for_program(env, explored_program)
                    for explored_program in explored_programs
                ]
                samples = [
                    agent_factory.Sample(oracle_traj, prob=1.0)
                    for oracle_traj in oracle_trajs
                ]

                annotation = SketchAnnotation(env, sketch, samples)
                env_name_annotation_dict[env.name] = annotation

        return env_name_annotation_dict
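
get_sketch is used here and in code examples #7 and #8 but is not shown. Judging from how explore_sketch_programs in example #8 consumes a sketch (one '(' plus a command head per statement), a sketch is the sequence of operator heads of a program. A minimal sketch under that assumption:

def get_sketch(program):
    # Hypothetical: keep only the operator head of each statement,
    # i.e. the token immediately following each '('.
    return [program[k + 1] for k, tok in enumerate(program)
            if tok == '(' and k + 1 < len(program)]

For example, get_sketch(['(', 'filter_eq', 'v0', 'c1', ')', '(', 'count', 'v1', ')', '<END>']) would return ['filter_eq', 'count'].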
Code example #4
  def run(self):
    agent, envs = init_experiment(self.fns, FLAGS.eval_use_gpu, gpu_id=str(FLAGS.eval_gpu_id))
    for env in envs:
      env.punish_extra_work = False
    graph = agent.model.graph
    dev_writer = tf.summary.FileWriter(os.path.join(
      get_experiment_dir(), FLAGS.tb_log_dir, 'dev'))
    best_dev_avg_return = 0.0
    best_model_path = ''
    best_model_dir = os.path.join(get_experiment_dir(), FLAGS.best_model_dir)
    if not tf.gfile.Exists(best_model_dir):
      tf.gfile.MkDir(best_model_dir)
    i = 0
    current_ckpt = get_init_model_path()
    env_dict = dict([(env.name, env) for env in envs])
    while True:
      t1 = time.time()
      tf.logging.info('dev: iteration {}, evaluating {}.'.format(i, current_ckpt))
      env_batch_size = FLAGS.eval_batch_size
      env_iterator = data_utils.BatchIterator(
        dict(envs=envs), shuffle=False,
        batch_size=env_batch_size)
      dev_samples = []
      dev_samples_in_beam = []
      for j, batch_dict in enumerate(env_iterator):
        t3 = time.time()
        batch_envs = batch_dict['envs']
        tf.logging.info('=' * 50)
        tf.logging.info('{} iteration {}, batch {}: {} envs'.format(
          self.name, i, j, len(batch_envs)))
        new_samples_in_beam = agent.beam_search(
          batch_envs, beam_size=FLAGS.eval_beam_size)
        dev_samples_in_beam += new_samples_in_beam
        tf.logging.info('{} samples in beam, batch {}.'.format(
          len(new_samples_in_beam), j))
        t4 = time.time()
        tf.logging.info('{} sec used in evaluator batch {}.'.format(t4 - t3, j))

      # Account for beam search where the beam doesn't
      # contain any examples without error, which will make
      # len(dev_samples) smaller than len(envs).
      dev_samples = self.select_top(dev_samples_in_beam)
      dev_avg_return, dev_avg_len = agent.evaluate(
        dev_samples, writer=dev_writer, true_n=len(envs))
      tf.logging.info('{} samples in non-empty beam.'.format(len(dev_samples)))
      tf.logging.info('{} questions in dev set.'.format(len(envs)))

      tf.logging.info('dev: avg return: {}, avg length: {}.'.format(
        dev_avg_return, dev_avg_len))
      if dev_avg_return > best_dev_avg_return:
        best_model_path = graph.save(
          os.path.join(best_model_dir, 'model'),
          agent.model.get_global_step())
        best_dev_avg_return = dev_avg_return
        tf.logging.info('New best dev avg return is {}'.format(best_dev_avg_return))
        tf.logging.info('New best model is saved in {}'.format(best_model_path))
        with open(os.path.join(get_experiment_dir(), 'best_model_info.json'), 'w') as f:
          result = {'best_model_path': best_model_path}
          if FLAGS.eval_only:
            result['best_eval_avg_return'] = best_dev_avg_return
          else:
            result['best_dev_avg_return'] = best_dev_avg_return
          json.dump(result, f)

      if FLAGS.eval_only:
        # Save the decoding results for further analysis.
        dev_programs_in_beam_dict = {}
        for sample in dev_samples_in_beam:
          name = sample.traj.env_name
          program = agent_factory.traj_to_program(sample.traj, envs[0].de_vocab)
          answer = sample.traj.answer
          if name in dev_programs_in_beam_dict:
            dev_programs_in_beam_dict[name].append((program, answer, sample.prob))
          else:
            dev_programs_in_beam_dict[name] = [(program, answer, sample.prob)]

        t3 = time.time()
        with open(
            os.path.join(get_experiment_dir(), 'dev_programs_in_beam_{}.json'.format(i)),
            'w') as f:
          json.dump(dev_programs_in_beam_dict, f)
        t4 = time.time()
        tf.logging.info('{} sec used dumping programs in beam in eval iteration {}.'.format(
          t4 - t3, i))

        t3 = time.time()
        with codecs.open(
            os.path.join(
              get_experiment_dir(), 'dev_samples_{}.txt'.format(i)),
            'w', encoding='utf-8') as f:
          for sample in dev_samples:
            f.write(show_samples([sample], envs[0].de_vocab, env_dict))
        t4 = time.time()
        tf.logging.info('{} sec used logging dev samples in eval iteration {}.'.format(
          t4 - t3, i))

      t2 = time.time()
      tf.logging.info('{} sec used in eval iteration {}.'.format(
        t2 - t1, i))

      if FLAGS.eval_only or agent.model.get_global_step() >= FLAGS.n_steps:
        tf.logging.info('{} finished'.format(self.name))
        if FLAGS.eval_only:
          print('Eval average return (accuracy) of the best model is {}'.format(
            best_dev_avg_return))
        else:
          print('Best dev average return (accuracy) is {}'.format(best_dev_avg_return))
          print('Best model is saved in {}'.format(best_model_path))
        return

      # Reload the latest model.
      new_ckpt = None
      t1 = time.time()
      while new_ckpt is None or new_ckpt == current_ckpt:
        time.sleep(1)
        new_ckpt = tf.train.latest_checkpoint(
          os.path.join(get_experiment_dir(), FLAGS.saved_model_dir))
      t2 = time.time()
      tf.logging.info('{} sec used waiting for new checkpoint in evaluator.'.format(
        t2-t1))
      
      tf.logging.info('latest ckpt to evaluate is {}.'.format(new_ckpt))
      tf.logging.info('{} loading ckpt {}'.format(self.name, new_ckpt))
      t1 = time.time()
      graph.restore(new_ckpt)
      t2 = time.time()
      tf.logging.info('{} sec used {} loading ckpt {}'.format(
        t2-t1, self.name, new_ckpt))
      current_ckpt = new_ckpt
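
select_top, called above to collapse the beam, is not included in this listing. A plausible sketch, assuming it keeps the single highest-probability sample per environment (which would explain why len(dev_samples) can be smaller than len(envs) when some beams are empty):

  def select_top(self, samples):
    # Hypothetical: keep the most probable sample for each environment.
    best = {}
    for sample in samples:
      name = sample.traj.env_name
      if name not in best or sample.prob > best[name].prob:
        best[name] = sample
    return list(best.values())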
Code example #5
  def run(self):
    agent, envs = init_experiment(self.fns, FLAGS.eval_use_gpu, gpu_id=str(FLAGS.eval_gpu_id))
    for env in envs:
      env.punish_extra_work = False
    graph = agent.model.graph
    dev_writer = tf.summary.FileWriter(os.path.join(
      get_experiment_dir(), FLAGS.tb_log_dir, 'dev'))
    best_dev_avg_return = 0.0
    best_model_path = ''
    best_model_dir = os.path.join(get_experiment_dir(), FLAGS.best_model_dir)
    if not tf.gfile.Exists(best_model_dir):
      tf.gfile.MkDir(best_model_dir)
    i = 0
    current_ckpt = get_init_model_path()
    env_dict = dict([(env.name, env) for env in envs])
    while True:
      t1 = time.time()
      tf.logging.info('dev: iteration {}, evaluating {}.'.format(i, current_ckpt))

      dev_avg_return, dev_samples, dev_samples_in_beam = beam_search_eval(
        agent, envs, writer=dev_writer)
      
      if dev_avg_return > best_dev_avg_return:
        best_model_path = graph.save(
          os.path.join(best_model_dir, 'model'),
          agent.model.get_global_step())
        best_dev_avg_return = dev_avg_return
        tf.logging.info('New best dev avg return is {}'.format(best_dev_avg_return))
        tf.logging.info('New best model is saved in {}'.format(best_model_path))
        with open(os.path.join(get_experiment_dir(), 'best_model_info.json'), 'w') as f:
          result = {'best_model_path': compress_home_path(best_model_path)}
          if FLAGS.eval_only:
            result['best_eval_avg_return'] = best_dev_avg_return
          else:
            result['best_dev_avg_return'] = best_dev_avg_return
          json.dump(result, f)

      if FLAGS.eval_only:
        # Save the decoding results for further analysis.
        dev_programs_in_beam_dict = {}
        for sample in dev_samples_in_beam:
          name = sample.traj.env_name
          program = agent_factory.traj_to_program(sample.traj, envs[0].de_vocab)
          answer = sample.traj.answer
          if name in dev_programs_in_beam_dict:
            dev_programs_in_beam_dict[name].append((program, answer, sample.prob))
          else:
            dev_programs_in_beam_dict[name] = [(program, answer, sample.prob)]

        t3 = time.time()
        with open(
            os.path.join(get_experiment_dir(), 'dev_programs_in_beam_{}.json'.format(i)),
            'w') as f:
          json.dump(dev_programs_in_beam_dict, f)
        t4 = time.time()
        tf.logging.info('{} sec used dumping programs in beam in eval iteration {}.'.format(
          t4 - t3, i))

        t3 = time.time()
        with codecs.open(
            os.path.join(
              get_experiment_dir(), 'dev_samples_{}.txt'.format(i)),
            'w', encoding='utf-8') as f:
          for sample in dev_samples:
            f.write(show_samples([sample], envs[0].de_vocab, env_dict))
        t4 = time.time()
        tf.logging.info('{} sec used logging dev samples in eval iteration {}.'.format(
          t4 - t3, i))

      t2 = time.time()
      tf.logging.info('{} sec used in eval iteration {}.'.format(
        t2 - t1, i))

      if FLAGS.eval_only or agent.model.get_global_step() >= FLAGS.n_steps:
        tf.logging.info('{} finished'.format(self.name))
        if FLAGS.eval_only:
          print('Eval average return (accuracy) of the best model is {}'.format(
            best_dev_avg_return))
        else:
          print('Best dev average return (accuracy) is {}'.format(best_dev_avg_return))
          print('Best model is saved in {}'.format(best_model_path))
        return

      # Reload the latest model.
      new_ckpt = None
      t1 = time.time()
      while new_ckpt is None or new_ckpt == current_ckpt:
        time.sleep(1)
        new_ckpt = tf.train.latest_checkpoint(
          os.path.join(get_experiment_dir(), FLAGS.saved_model_dir))
      t2 = time.time()
      tf.logging.info('{} sec used waiting for new checkpoint in evaluator.'.format(
        t2-t1))
      
      tf.logging.info('latest ckpt to evaluate is {}.'.format(new_ckpt))
      tf.logging.info('{} loading ckpt {}'.format(self.name, new_ckpt))
      t1 = time.time()
      graph.restore(new_ckpt)
      t2 = time.time()
      tf.logging.info('{} sec used {} loading ckpt {}'.format(
        t2-t1, self.name, new_ckpt))
      current_ckpt = new_ckpt
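
Code example #5 is a refactoring of code example #4: the inline batching-and-evaluation loop has been extracted into a beam_search_eval helper whose body is not shown. Here is a reconstruction consistent with the loop in example #4; the exact signature and return order are assumptions:

def beam_search_eval(agent, envs, writer=None):
  # Hypothetical reconstruction of the loop factored out of code example #4;
  # assumes a select_top helper like the method sketched after example #4.
  env_iterator = data_utils.BatchIterator(
    dict(envs=envs), shuffle=False, batch_size=FLAGS.eval_batch_size)
  samples_in_beam = []
  for batch_dict in env_iterator:
    samples_in_beam += agent.beam_search(
      batch_dict['envs'], beam_size=FLAGS.eval_beam_size)
  # Beams containing only erroneous programs contribute no sample, so
  # len(samples) can be smaller than len(envs).
  samples = select_top(samples_in_beam)
  avg_return, avg_len = agent.evaluate(
    samples, writer=writer, true_n=len(envs))
  return avg_return, samples, samples_in_beam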
Code example #6
File: error_analysis.py    Project: niansong1996/wassp
def wikisql_error_analysis():
    env_file = '/Users/ansongni/projects/data/wikisql/processed_input/preprocess_4/test_split.jsonl'
    #train_env_file_prefix = '/Users/ansongni/projects/data/wikisql/processed_input/preprocess_4/train_split_shard_30-'
    decoded_beam_file = '/Users/ansongni/projects/data/wikisql/output/eval_imp_baseline/dev_programs_in_beam_0.json'
    #decoded_beam_file = '/Users/ansongni/projects/data/wikisql/output/train_eval_imp_baseline/dev_programs_in_beam_0.json'

    # first load the test environments and get oracle programs
    #test_envs = get_envs([(train_env_file_prefix+str(i)+'.jsonl') for i in range(0, 30)])
    test_envs = get_envs([env_file])
    envs, trajs = get_env_trajs(test_envs)

    oracle_env_programs = [(env, traj_to_program(traj, envs[0].de_vocab))
                           for env, traj in zip(envs, trajs)]

    # then load decoded results in the beam
    with open(decoded_beam_file) as f:
        decoded_beam = json.load(f)

    example_results = []

    # generate the detailed example result for each env that got oracle
    for env, oracle_program in oracle_env_programs:
        id = env.name
        table = env.question_annotation['context']
        question = env.question_annotation['question']
        oracle_answer = env.question_annotation['answer']

        # take care of missing example
        hyps = decoded_beam.get(id, None)
        if hyps is None:
            continue

        beam = []
        for hyp in hyps:
            prob = hyp[2]
            predicted_program = hyp[0]
            predicted_answer = hyp[1]

            program_hyp = program_result(prob, predicted_program,
                                         predicted_answer)
            beam.append(program_hyp)

        result = detailed_example_result(id, table, question, oracle_program,
                                         oracle_answer, beam)
        example_results.append(result)

    # 1. now we do overall error analysis
    print('%d test examples, %d have oracle program, %d gets evaluated' %
          (len(test_envs), len(envs), len(example_results)))
    overall_analysis(example_results)
    confidence_analysis(example_results)

    # 2. now we analyze the failed cases
    # keep the examples whose top-of-beam answer does not get full score
    failed_example_results = [
        result for result in example_results
        if wikisql_score(result.hyps_in_beam[0].predicted_answer,
                         result.oracle_answer) < 1.0
    ]
    print('%d failed examples' % len(failed_example_results))
    overall_analysis(failed_example_results)
    confidence_analysis(failed_example_results)
    failed_words_analysis(example_results, failed_example_results)
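
program_result and detailed_example_result are constructed positionally above and accessed by field name (result.oracle_answer, result.hyps_in_beam[0].predicted_answer). The record types below match that usage, assuming they are plain namedtuples (the real definitions are not part of this listing):

import collections

program_result = collections.namedtuple(
    'program_result', ['prob', 'predicted_program', 'predicted_answer'])
detailed_example_result = collections.namedtuple(
    'detailed_example_result',
    ['id', 'table', 'question', 'oracle_program', 'oracle_answer',
     'hyps_in_beam'])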
Code example #7
    def decode_sketch_program(envs):
        # first create the real envs from the JSON specs and add constraints to them
        envs = json_to_envs(envs)
        env_name_dict = {env.name: env for env in envs}

        if FLAGS.executor == 'wtq':
            oracle_envs, oracle_trajs = get_wtq_annotations(envs)
        else:
            oracle_envs, oracle_trajs = get_env_trajs(envs)

        env_sketch_dict = dict([
            (env.name, get_sketch(traj_to_program(traj, envs[0].de_vocab)))
            for env, traj in zip(oracle_envs, oracle_trajs)
        ])
        for env in envs:
            sketch = env_sketch_dict.get(env.name, None)
            if sketch is not None:
                env.set_sketch_constraint(sketch[:])

        # create the agent
        graph_config = get_saved_graph_config()
        graph_config['use_gpu'] = False
        graph_config['gpu_id'] = '0'
        init_model_path = get_init_model_path()
        agent = create_agent(graph_config, init_model_path)

        # beam search
        beam_samples = []
        env_iterator = data_utils.BatchIterator(
            dict(envs=envs), shuffle=False, batch_size=FLAGS.eval_batch_size)
        for j, batch_dict in tqdm(enumerate(env_iterator)):
            batch_envs = batch_dict['envs']
            beam_samples += agent.beam_search(batch_envs, beam_size=50)

        # group the samples into per-environment beams
        env_beam_dict = dict()
        for sample in beam_samples:
            env_beam_dict[sample.traj.env_name] = env_beam_dict.get(
                sample.traj.env_name, []) + [sample]

        # get the trajs with 1.0 reward for each example and re-weight the prob
        env_name_annotation_dict = dict()
        for env_name, env in env_name_dict.items():
            beam = env_beam_dict.get(env_name, [])
            success_beam = [x for x in beam if x.traj.rewards[-1] == 1.0]
            if len(success_beam) > 0:
                # retrieve the sketch result from previous steps
                sketch = env_sketch_dict.get(env_name, None)

                if sketch is None:
                    env_name_annotation_dict[env_name] = None
                else:
                    # re-weight the examples in the beam so their
                    # probabilities sum to 1
                    prob_sum = sum(sample.prob for sample in success_beam)
                    success_beam = [
                        agent_factory.Sample(traj=sample.traj,
                                             prob=sample.prob / prob_sum)
                        for sample in success_beam
                    ]
                    if len(success_beam) > 10:
                        success_beam = sorted(success_beam,
                                              key=lambda sample: sample.prob,
                                              reverse=True)
                        success_beam = success_beam[:10]

                    annotation = SketchAnnotation(env, sketch, success_beam)
                    env_name_annotation_dict[env_name] = annotation
            else:
                env_name_annotation_dict[env_name] = None

        return env_name_annotation_dict
Code example #8
    def explore_sketch_programs(env, sketch, max_hyp=1000):
        '''Given a sketch, find all the executable programs that fit it.'''
        env = env.clone()
        env.use_cache = False

        all_hyps = [(env, env.start_ob)]

        for cmd in sketch:
            # first add the left bracket and the cmd head
            new_all_hyps = []
            for env, ob in all_hyps:
                try:
                    action_index = list(ob[0].valid_indices).index(
                        env.de_vocab.lookup('('))
                    ob, _, _, _ = env.step(action_index)
                    action_index = list(ob[0].valid_indices).index(
                        env.de_vocab.lookup(cmd))
                    ob, _, _, _ = env.step(action_index)

                    new_all_hyps.append((env, ob))
                except ValueError:
                    raise ValueError('This should not happen')
            all_hyps = new_all_hyps

            # then, for every possible action, step a cloned env forward
            # until this stmt is closed
            stmt_done = False
            while not stmt_done:
                new_all_hyps = []
                for env, ob in all_hyps:
                    valid_indices = list(ob[0].valid_indices)
                    if len(valid_indices) == 0:
                        # this is a dead end for the current env
                        continue
                    for action_index in range(len(valid_indices)):
                        new_env = env.clone()
                        new_env.use_cache = False
                        # use a fresh name so the outer loop's ob is not
                        # clobbered
                        new_ob, _, _, _ = new_env.step(action_index)
                        new_all_hyps.append((new_env, new_ob))
                    if valid_indices == [env.de_vocab.lookup(')')]:
                        # only ')' remains valid: maximum stmt length reached
                        stmt_done = True
                all_hyps = list(np.random.permutation(new_all_hyps))[:max_hyp]

        # then add the <END> token
        new_all_hyps = []
        for env, ob in all_hyps:
            try:
                action_index = list(ob[0].valid_indices).index(
                    env.de_vocab.lookup('<END>'))
                ob, _, _, _ = env.step(action_index)

                new_all_hyps.append((env, ob))
            except ValueError:
                raise ValueError('This should not happen')
        all_hyps = new_all_hyps

        # pruning based on the result
        explored_programs = []
        for env, _ in all_hyps:
            if env.rewards[-1] == 1.0:
                traj = agent_factory.Traj(obs=env.obs,
                                          actions=env.actions,
                                          rewards=env.rewards,
                                          context=env.get_context(),
                                          env_name=env.name,
                                          answer=env.interpreter.result)
                explored_programs.append(traj_to_program(traj, env.de_vocab))

        explored_programs = list(np.random.permutation(explored_programs))

        return explored_programs
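
collect_traj_for_program, used in code example #3 to turn each explored program back into a trajectory, is not shown in this listing. The sketch below mirrors the stepping pattern and the Traj construction of explore_sketch_programs above; treating the program as a token list that ends with '<END>' is an assumption:

def collect_traj_for_program(env, program):
    # Hypothetical: replay a token sequence in a cloned environment and
    # package the resulting trajectory, as explore_sketch_programs does.
    env = env.clone()
    env.use_cache = False
    ob = env.start_ob
    for token in program:
        action_index = list(ob[0].valid_indices).index(
            env.de_vocab.lookup(token))
        ob, _, _, _ = env.step(action_index)
    return agent_factory.Traj(obs=env.obs,
                              actions=env.actions,
                              rewards=env.rewards,
                              context=env.get_context(),
                              env_name=env.name,
                              answer=env.interpreter.result)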
Code example #9
    def __init__(self, env, sketch, samples):
        Annotation.__init__(self, env.name)
        self.sketch = sketch
        self.sketch_programs = [(traj_to_program(sample.traj,
                                                 env.de_vocab), sample.prob)
                                for sample in samples]
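
Code examples #1 and #10 call annotation.verify_program(...), and code example #10 calls annotation.get_samples(env); neither method is shown for SketchAnnotation. Plausible sketches follow: verifying by sketch equality (so any program realizing the annotated sketch passes) and rebuilding weighted samples via the collect_traj_for_program helper sketched after code example #8. Both semantics are assumptions:

    def verify_program(self, program):
        # Hypothetical: a program is consistent with this annotation iff
        # it realizes the same sketch.
        return get_sketch(program) == self.sketch

    def get_samples(self, env):
        # Hypothetical: replay each stored program in the given env and
        # rebuild a weighted Sample from it.
        return [
            agent_factory.Sample(collect_traj_for_program(env, program),
                                 prob=prob)
            for program, prob in self.sketch_programs
        ]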
Code example #10
    def run(self):
        agent, envs = init_experiment(
            [get_train_shard_path(i) for i in self.shard_ids],
            use_gpu=FLAGS.actor_use_gpu,
            gpu_id=str(self.actor_id + FLAGS.actor_gpu_start_id))
        self.decode_vocab = envs[0].de_vocab

        graph = agent.model.graph
        current_ckpt = get_init_model_path()

        env_dict = dict([(env.name, env) for env in envs])
        replay_buffer = agent_factory.AllGoodReplayBuffer(
            agent, envs[0].de_vocab)

        # Load saved programs to warm start the replay buffer.
        if FLAGS.load_saved_programs:
            load_programs(envs, replay_buffer, FLAGS.saved_program_file)

        if FLAGS.save_replay_buffer_at_end:
            replay_buffer_copy = agent_factory.AllGoodReplayBuffer(
                de_vocab=envs[0].de_vocab)
            replay_buffer_copy.program_prob_dict = copy.deepcopy(
                replay_buffer.program_prob_dict)

        # shrink the annotation dict to the envs needed
        small_env_annotation_dict = dict()
        for env in envs:
            annotation = self.env_annotation_dict.get(env.name, None)
            if annotation is not None:
                small_env_annotation_dict[env.name] = annotation
        self.env_annotation_dict = small_env_annotation_dict
        print('Actor %d, total %d envs, %d has been annotated.' %
              (self.actor_id, len(envs), len(self.env_annotation_dict)))

        # get samples from the annotations and put them into the buffer
        env_name_dict = dict([(env.name, env) for env in envs])
        if len(self.env_annotation_dict) > 0:
            annotated_samples = []
            for env_name, annotation in self.env_annotation_dict.items():
                samples_from_annotation = annotation.get_samples(
                    env_name_dict[env_name])
                annotated_samples += samples_from_annotation
            self.save_to_buffer(annotated_samples, replay_buffer)

        i = 0
        while True:
            # Create the logging files.
            if FLAGS.log_samples_every_n_epoch > 0 and i % FLAGS.log_samples_every_n_epoch == 0:
                f_replay = codecs.open(os.path.join(
                    get_experiment_dir(),
                    'replay_samples_{}_{}.txt'.format(self.name, i)),
                                       'w',
                                       encoding='utf-8')
                f_policy = codecs.open(os.path.join(
                    get_experiment_dir(),
                    'policy_samples_{}_{}.txt'.format(self.name, i)),
                                       'w',
                                       encoding='utf-8')
                f_train = codecs.open(os.path.join(
                    get_experiment_dir(),
                    'train_samples_{}_{}.txt'.format(self.name, i)),
                                      'w',
                                      encoding='utf-8')

            n_train_samples = 0
            if FLAGS.use_replay_samples_in_train:
                n_train_samples += FLAGS.n_replay_samples

            if FLAGS.use_policy_samples_in_train and FLAGS.use_nonreplay_samples_in_train:
                raise ValueError(
                    'Cannot use both on-policy samples and nonreplay samples for training!'
                )

            if FLAGS.use_policy_samples_in_train or FLAGS.use_nonreplay_samples_in_train:
                # Note that nonreplay samples are drawn by rejection
                # sampling from on-policy samples.
                n_train_samples += FLAGS.n_policy_samples

            # Make sure that all the samples from the env batch
            # fit into one batch for training.
            if FLAGS.batch_size < n_train_samples:
                raise ValueError(
                    'One batch has to at least contain the samples from one environment.'
                )

            env_batch_size = FLAGS.batch_size // n_train_samples

            env_iterator = data_utils.BatchIterator(dict(envs=envs),
                                                    shuffle=True,
                                                    batch_size=env_batch_size)

            for j, batch_dict in enumerate(env_iterator):
                batch_envs = batch_dict['envs']
                tf.logging.info('=' * 50)
                tf.logging.info('{} iteration {}, batch {}: {} envs'.format(
                    self.name, i, j, len(batch_envs)))

                t1 = time.time()
                # Generate samples with cache and save to replay buffer.
                t3 = time.time()
                n_explore = 0
                for _ in range(FLAGS.n_explore_samples):
                    explore_samples = agent.generate_samples(
                        batch_envs,
                        n_samples=1,
                        use_cache=FLAGS.use_cache,
                        greedy=FLAGS.greedy_exploration)
                    self.save_to_buffer(explore_samples, replay_buffer)
                    n_explore += len(explore_samples)

                if FLAGS.n_extra_explore_for_hard > 0:
                    hard_envs = [
                        env for env in batch_envs
                        if not replay_buffer.has_found_solution(env.name)
                    ]
                    if hard_envs:
                        for _ in range(FLAGS.n_extra_explore_for_hard):
                            explore_samples = agent.generate_samples(
                                hard_envs,
                                n_samples=1,
                                use_cache=FLAGS.use_cache,
                                greedy=FLAGS.greedy_exploration)
                            self.save_to_buffer(explore_samples, replay_buffer)
                            n_explore += len(explore_samples)

                t4 = time.time()
                tf.logging.info(
                    '{} sec used generating {} exploration samples.'.format(
                        t4 - t3, n_explore))

                tf.logging.info(
                    '{} samples saved in the replay buffer.'.format(
                        replay_buffer.size))

                t3 = time.time()
                replay_samples = replay_buffer.replay(
                    batch_envs,
                    FLAGS.n_replay_samples,
                    use_top_k=FLAGS.use_top_k_replay_samples,
                    agent=None if FLAGS.random_replay_samples else agent,
                    truncate_at_n=FLAGS.truncate_replay_buffer_at_n)
                t4 = time.time()
                tf.logging.info(
                    '{} sec used selecting {} replay samples.'.format(
                        t4 - t3, len(replay_samples)))

                t3 = time.time()
                if FLAGS.use_top_k_policy_samples:
                    if FLAGS.n_policy_samples == 1:
                        policy_samples = agent.generate_samples(
                            batch_envs,
                            n_samples=FLAGS.n_policy_samples,
                            greedy=True)
                    else:
                        policy_samples = agent.beam_search(
                            batch_envs, beam_size=FLAGS.n_policy_samples)
                else:
                    policy_samples = agent.generate_samples(
                        batch_envs,
                        n_samples=FLAGS.n_policy_samples,
                        greedy=False)
                t4 = time.time()
                tf.logging.info(
                    '{} sec used generating {} on-policy samples'.format(
                        t4 - t3, len(policy_samples)))

                t2 = time.time()
                tf.logging.info(
                    ('{} sec used generating replay and on-policy samples,'
                     ' {} iteration {}, batch {}: {} envs').format(
                         t2 - t1, self.name, i, j, len(batch_envs)))

                t1 = time.time()
                self.eval_queue.put((policy_samples, len(batch_envs)))
                self.replay_queue.put((replay_samples, len(batch_envs)))

                assert (FLAGS.fixed_replay_weight >= 0.0
                        and FLAGS.fixed_replay_weight <= 1.0)

                if FLAGS.use_replay_prob_as_weight:
                    new_samples = []
                    for sample in replay_samples:
                        name = sample.traj.env_name
                        if name in replay_buffer.prob_sum_dict:
                            replay_prob = max(
                                replay_buffer.prob_sum_dict[name],
                                FLAGS.min_replay_weight)
                        else:
                            replay_prob = 0.0
                        scale = replay_prob
                        new_samples.append(
                            agent_factory.Sample(traj=sample.traj,
                                                 prob=sample.prob * scale))
                    replay_samples = new_samples
                else:
                    replay_samples = agent_factory.scale_probs(
                        replay_samples, FLAGS.fixed_replay_weight)

                replay_samples = sorted(replay_samples,
                                        key=lambda x: x.traj.env_name)

                policy_samples = sorted(policy_samples,
                                        key=lambda x: x.traj.env_name)

                if FLAGS.use_nonreplay_samples_in_train:
                    nonreplay_samples = []
                    for sample in policy_samples:
                        if not replay_buffer.contain(sample.traj):
                            nonreplay_samples.append(sample)

                self.save_to_buffer(policy_samples, replay_buffer)

                def weight_samples(samples):
                    if FLAGS.use_replay_prob_as_weight:
                        new_samples = []
                        for sample in samples:
                            name = sample.traj.env_name
                            if name in replay_buffer.prob_sum_dict:
                                replay_prob = max(
                                    replay_buffer.prob_sum_dict[name],
                                    FLAGS.min_replay_weight)
                            else:
                                replay_prob = 0.0
                            scale = 1.0 - replay_prob
                            new_samples.append(
                                agent_factory.Sample(traj=sample.traj,
                                                     prob=sample.prob * scale))
                    else:
                        new_samples = agent_factory.scale_probs(
                            samples, 1 - FLAGS.fixed_replay_weight)
                    return new_samples

                train_samples = []
                if FLAGS.use_replay_samples_in_train:
                    if FLAGS.use_trainer_prob:
                        replay_samples = [
                            sample._replace(prob=None)
                            for sample in replay_samples
                        ]
                    train_samples += replay_samples

                if FLAGS.use_policy_samples_in_train:
                    train_samples += weight_samples(policy_samples)

                if FLAGS.use_nonreplay_samples_in_train:
                    train_samples += weight_samples(nonreplay_samples)

                train_samples = sorted(train_samples,
                                       key=lambda x: x.traj.env_name)
                tf.logging.info('{} train samples'.format(len(train_samples)))

                if FLAGS.use_importance_sampling:
                    step_logprobs = agent.compute_step_logprobs(
                        [s.traj for s in train_samples])
                else:
                    step_logprobs = None

                if FLAGS.use_replay_prob_as_weight:
                    n_clip = 0
                    for env in batch_envs:
                        name = env.name
                        if (name in replay_buffer.prob_sum_dict
                                and replay_buffer.prob_sum_dict[name] <
                                FLAGS.min_replay_weight):
                            n_clip += 1
                    clip_frac = float(n_clip) / len(batch_envs)
                else:
                    clip_frac = 0.0

                # put all the weight on annotated examples at first, then
                # gradually ramp up the weight of explored examples
                al_scale_factor = min(
                    1.0,
                    (agent.model.get_global_step() - FLAGS.active_start_step) /
                    float(FLAGS.active_scale_steps))
                assert (al_scale_factor >= 0.0 and al_scale_factor <= 1.0)
                # note: use k, not i, so the epoch counter i above is not
                # clobbered by this loop
                for k, sample in enumerate(train_samples):
                    if sample.traj.env_name not in self.env_annotation_dict:
                        train_samples[k] = agent_factory.Sample(
                            traj=sample.traj,
                            prob=sample.prob * al_scale_factor)
                    else:
                        annotation = self.env_annotation_dict[
                            sample.traj.env_name]
                        explored_program = agent_factory.traj_to_program(
                            sample.traj, self.decode_vocab)
                        if not annotation.verify_program(explored_program):
                            train_samples[k] = agent_factory.Sample(
                                traj=sample.traj,
                                prob=sample.prob * al_scale_factor)

                self.train_queue.put((train_samples, step_logprobs, clip_frac))
                t2 = time.time()
                tf.logging.info(
                    ('{} sec used preparing and enqueuing samples, {}'
                     ' iteration {}, batch {}: {} envs').format(
                         t2 - t1, self.name, i, j, len(batch_envs)))

                t1 = time.time()
                # Wait for a ckpt that still exists or is the same
                # ckpt (no need to load anything).
                while True:
                    new_ckpt = self.ckpt_queue.get()
                    new_ckpt_file = new_ckpt + '.meta'
                    if new_ckpt == current_ckpt or tf.gfile.Exists(
                            new_ckpt_file):
                        break
                t2 = time.time()
                tf.logging.info(
                    '{} sec waiting {} iteration {}, batch {}'.format(
                        t2 - t1, self.name, i, j))

                if new_ckpt != current_ckpt:
                    # If the ckpt is not the same, then restore the new
                    # ckpt.
                    tf.logging.info('{} loading ckpt {}'.format(
                        self.name, new_ckpt))
                    t1 = time.time()
                    graph.restore(new_ckpt)
                    t2 = time.time()
                    tf.logging.info('{} sec used {} restoring ckpt {}'.format(
                        t2 - t1, self.name, new_ckpt))
                    current_ckpt = new_ckpt

                if FLAGS.log_samples_every_n_epoch > 0 and i % FLAGS.log_samples_every_n_epoch == 0:
                    f_replay.write(
                        show_samples(replay_samples, envs[0].de_vocab,
                                     env_dict))
                    f_policy.write(
                        show_samples(policy_samples, envs[0].de_vocab,
                                     env_dict))
                    f_train.write(
                        show_samples(train_samples, envs[0].de_vocab,
                                     env_dict))

            if FLAGS.log_samples_every_n_epoch > 0 and i % FLAGS.log_samples_every_n_epoch == 0:
                f_replay.close()
                f_policy.close()
                f_train.close()

            if agent.model.get_global_step() >= FLAGS.n_steps:
                if FLAGS.save_replay_buffer_at_end:
                    all_replay = os.path.join(
                        get_experiment_dir(),
                        'all_replay_samples_{}.txt'.format(self.name))
                    # dump only the samples discovered during training,
                    # i.e. those not already in the warm-start buffer copy
                    with codecs.open(all_replay, 'w', encoding='utf-8') as f:
                        samples = replay_buffer.all_samples(envs, agent=None)
                        samples = [
                            s for s in samples
                            if not replay_buffer_copy.contain(s.traj)
                        ]
                        f.write(
                            show_samples(samples, envs[0].de_vocab, None))

                tf.logging.info('{} finished'.format(self.name))
                return
            i += 1
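
A small worked example of the sample weighting used above: with use_replay_prob_as_weight set, a replay sample for an environment is scaled by max(prob_sum_dict[name], min_replay_weight), while weight_samples scales on-policy and non-replay samples for the same environment by one minus that quantity, so the two shares sum to 1; environments whose replay mass gets clipped up to the minimum are counted in clip_frac. Standalone, with toy numbers:

min_replay_weight = 0.1
prob_sum_dict = {'env_a': 0.6, 'env_b': 0.02}  # toy replay probability mass

for name in ('env_a', 'env_b'):
    replay_w = max(prob_sum_dict.get(name, 0.0), min_replay_weight)
    policy_w = 1.0 - replay_w
    print(name, replay_w, policy_w)
# env_a 0.6 0.4   -> replay mass used as-is
# env_b 0.1 0.9   -> clipped up to min_replay_weight; counted in clip_frac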