Example #1
0
def beam_search_eval(agent, envs, writer=None):
    """Beam-search all `envs` and evaluate the selected top samples.

    Envs are visited in fixed order (no shuffling) in batches of
    FLAGS.eval_batch_size; each batch is decoded with `agent.beam_search`
    using FLAGS.eval_beam_size. `select_top` is then applied to the pooled
    beam samples and the result is scored with `agent.evaluate`.

    Args:
      agent: object providing `beam_search(envs, beam_size=...)` and
        `evaluate(samples, writer=..., true_n=...)`.
      envs: the evaluation environments.
      writer: optional summary writer forwarded to `agent.evaluate`.

    Returns:
      `(dev_avg_return, dev_samples, dev_samples_in_beam)`: the average return
      reported by `agent.evaluate`, the output of `select_top`, and the pooled
      beam samples from every batch.
    """
    batch_iter = data_utils.BatchIterator(
      dict(envs=envs), shuffle=False,
      batch_size=FLAGS.eval_batch_size)
    pooled_beam = []
    for batch_idx, batch in enumerate(batch_iter):
      started = time.time()
      batch_envs = batch['envs']
      tf.logging.info('=' * 50)
      tf.logging.info('eval, batch {}: {} envs'.format(batch_idx, len(batch_envs)))
      batch_beam = agent.beam_search(
        batch_envs, beam_size=FLAGS.eval_beam_size)
      pooled_beam += batch_beam
      tf.logging.info('{} samples in beam, batch {}.'.format(
        len(batch_beam), batch_idx))
      tf.logging.info('{} sec used in evaluator batch {}.'.format(
        time.time() - started, batch_idx))

    # Some beams may hold no error-free sample, so select_top can return
    # fewer samples than there are envs; true_n carries the real env count.
    n_envs = len(envs)
    top_samples = select_top(pooled_beam)
    avg_return, avg_len = agent.evaluate(
      top_samples, writer=writer, true_n=n_envs)
    for message in (
        '{} samples in non-empty beam.'.format(len(top_samples)),
        'true n is {}'.format(n_envs),
        '{} questions in dev set.'.format(n_envs),
        '{} dev avg return.'.format(avg_return),
        'dev: avg return: {}, avg length: {}.'.format(avg_return, avg_len)):
      tf.logging.info(message)

    return avg_return, top_samples, pooled_beam
Example #2
0
    def eval(envs):
        """Greedy-decode every env and list the envs the model fails on.

        Args:
            envs: serialized environments, converted via `json_to_envs`.

        Returns:
            A list of `(1.0, env_name)` tuples, one per env whose greedy
            sample's last reward is below 1.0, ordered by
            `np.random.permutation`.
        """
        # Materialize the real environments from their JSON form.
        envs = json_to_envs(envs)

        # Build a CPU-only agent from the saved graph config.
        config = get_saved_graph_config()
        config['use_gpu'] = False
        config['gpu_id'] = '0'
        agent = create_agent(config, get_init_model_path())

        # Greedy decode: one sample per env, error samples kept.
        greedy_samples = []
        batches = data_utils.BatchIterator(
            dict(envs=envs), shuffle=False, batch_size=FLAGS.eval_batch_size)
        for _, batch in tqdm(enumerate(batches)):
            greedy_samples += agent.generate_samples(batch['envs'],
                                                     n_samples=1,
                                                     greedy=True,
                                                     use_cache=False,
                                                     filter_error=False)

        # Envs and samples are paired positionally (assumes one sample per
        # env, in env order -- TODO confirm generate_samples guarantees this).
        failed = [env for env, sample in zip(envs, greedy_samples)
                  if sample.traj.rewards[-1] < 1.0]
        shuffled = list(np.random.permutation(failed))
        return [(1.0, env.name) for env in shuffled]
Example #3
0
    def eval(envs):
        """Beam-search every env and score it by 1 minus its best beam prob.

        Args:
            envs: serialized environments, converted via `json_to_envs`.

        Returns:
            A list of `(1.0 - best_prob, env_name)` tuples, one per env name
            seen in the beam-search output, where `best_prob` is the highest
            `prob` among that env's beam samples.
        """
        # Materialize the real environments from their JSON form.
        envs = json_to_envs(envs)

        # Build a CPU-only agent from the saved graph config.
        config = get_saved_graph_config()
        config['use_gpu'] = False
        config['gpu_id'] = '0'
        agent = create_agent(config, get_init_model_path())

        # Beam search (beam size 5) over the envs in fixed-order batches.
        beam_samples = []
        batches = data_utils.BatchIterator(
            dict(envs=envs), shuffle=False, batch_size=FLAGS.eval_batch_size)
        for _, batch in tqdm(enumerate(batches)):
            beam_samples += agent.beam_search(batch['envs'], beam_size=5)

        # Samples arrive as one flat list; regroup them per env name.
        beams_by_env = {}
        for sample in beam_samples:
            beams_by_env.setdefault(sample.traj.env_name, []).append(sample)

        return [(1.0 - max(s.prob for s in beam), name)
                for name, beam in beams_by_env.items()]
Example #4
0
    def eval(envs):
        """Beam-search every env and report envs whose top hypothesis fails.

        An env's top hypothesis is the beam sample with the highest `prob`.
        On ties the later sample in the beam wins, as in the original `reduce`.
        Envs whose top hypothesis ends with reward 0.0 are reported.

        Args:
            envs: serialized environments, converted via `json_to_envs`.

        Returns:
            A list of `(prob, env_name)` tuples for the failing envs, where
            `prob` is the failing top hypothesis' probability.
        """
        # `reduce` is only a builtin on Python 2; functools works on both.
        from functools import reduce

        # Materialize the real environments from their JSON form.
        envs = json_to_envs(envs)

        # Build a CPU-only agent from the saved graph config.
        graph_config = get_saved_graph_config()
        graph_config['use_gpu'] = False
        graph_config['gpu_id'] = '0'
        init_model_path = get_init_model_path()
        agent = create_agent(graph_config, init_model_path)

        # Beam search (beam size 5) over the envs in fixed-order batches.
        beam_samples = []
        env_iterator = data_utils.BatchIterator(
            dict(envs=envs), shuffle=False, batch_size=FLAGS.eval_batch_size)
        for _, batch_dict in tqdm(enumerate(env_iterator)):
            batch_envs = batch_dict['envs']
            beam_samples += agent.beam_search(batch_envs, beam_size=5)

        # Samples arrive as one flat list; regroup them per env name.
        env_beam_dict = dict()
        for sample in beam_samples:
            env_beam_dict.setdefault(sample.traj.env_name, []).append(sample)

        # Explicit unpacking replaces the Python-2-only tuple-parameter
        # lambdas (PEP 3113), and building a list keeps the return type a list
        # on Python 3, where map/filter are lazy.
        conf_envs = []
        for env_name, beam in env_beam_dict.items():
            top_hyp = reduce(
                lambda s1, s2: s1 if s1.prob > s2.prob else s2, beam)
            if top_hyp.traj.rewards[-1] == 0.0:
                conf_envs.append((top_hyp.prob, env_name))

        return conf_envs
  def run(self):
    """Evaluate checkpoints on the dev envs until training finishes.

    Each iteration does the following:
      - Beam-searches every env with FLAGS.eval_beam_size.
      - Scores `self.select_top` of the beam samples with `agent.evaluate`.
      - Whenever the average return beats the best seen so far, saves the
        graph under FLAGS.best_model_dir and writes `best_model_info.json`.

    With FLAGS.eval_only, a single iteration runs and it also dumps the beam
    programs and top samples into the experiment dir. Otherwise it polls
    FLAGS.saved_model_dir once per second for a checkpoint different from
    the last one evaluated, restores it, and repeats. It stops once the
    global step reaches FLAGS.n_steps.
    """
    agent, envs = init_experiment(self.fns, FLAGS.eval_use_gpu, gpu_id=str(FLAGS.eval_gpu_id))
    for env in envs:
      env.punish_extra_work = False
    graph = agent.model.graph
    dev_writer = tf.summary.FileWriter(os.path.join(
      get_experiment_dir(), FLAGS.tb_log_dir, 'dev'))
    best_dev_avg_return = 0.0
    best_model_path = ''
    best_model_dir = os.path.join(get_experiment_dir(), FLAGS.best_model_dir)
    if not tf.gfile.Exists(best_model_dir):
      tf.gfile.MkDir(best_model_dir)
    i = 0
    current_ckpt = get_init_model_path()
    env_dict = dict([(env.name, env) for env in envs])
    while True:
      t1 = time.time()
      tf.logging.info('dev: iteration {}, evaluating {}.'.format(i, current_ckpt))
      env_iterator = data_utils.BatchIterator(
        dict(envs=envs), shuffle=False,
        batch_size=FLAGS.eval_batch_size)
      dev_samples_in_beam = []
      for j, batch_dict in enumerate(env_iterator):
        t3 = time.time()
        batch_envs = batch_dict['envs']
        tf.logging.info('=' * 50)
        tf.logging.info('{} iteration {}, batch {}: {} envs'.format(
          self.name, i, j, len(batch_envs)))
        new_samples_in_beam = agent.beam_search(
          batch_envs, beam_size=FLAGS.eval_beam_size)
        dev_samples_in_beam += new_samples_in_beam
        tf.logging.info('{} samples in beam, batch {}.'.format(
          len(new_samples_in_beam), j))
        t4 = time.time()
        tf.logging.info('{} sec used in evaluator batch {}.'.format(t4 - t3, j))

      # Account for beam search where the beam doesn't
      # contain any examples without error, which will make
      # len(dev_samples) smaller than len(envs).
      dev_samples = self.select_top(dev_samples_in_beam)
      dev_avg_return, dev_avg_len = agent.evaluate(
        dev_samples, writer=dev_writer, true_n=len(envs))
      tf.logging.info('{} samples in non-empty beam.'.format(len(dev_samples)))
      tf.logging.info('true n is {}'.format(len(envs)))
      tf.logging.info('{} questions in dev set.'.format(len(envs)))
      tf.logging.info('{} dev avg return.'.format(dev_avg_return))

      tf.logging.info('dev: avg return: {}, avg length: {}.'.format(
        dev_avg_return, dev_avg_len))
      if dev_avg_return > best_dev_avg_return:
        best_model_path = graph.save(
          os.path.join(best_model_dir, 'model'),
          agent.model.get_global_step())
        best_dev_avg_return = dev_avg_return
        tf.logging.info('New best dev avg returns is {}'.format(best_dev_avg_return))
        tf.logging.info('New best model is saved in {}'.format(best_model_path))
        with open(os.path.join(get_experiment_dir(), 'best_model_info.json'), 'w') as f:
          result = {'best_model_path': best_model_path}
          if FLAGS.eval_only:
            result['best_eval_avg_return'] = best_dev_avg_return
          else:
            result['best_dev_avg_return'] = best_dev_avg_return
          json.dump(result, f)

      if FLAGS.eval_only:
        # Save the decoding results for further analysis, grouped by env name.
        # Uses envs[0].de_vocab (as the sample dump below does) rather than
        # the `batch_envs` variable leaked from the batch loop above; assumes
        # all envs share one vocab -- TODO confirm.
        dev_programs_in_beam_dict = {}
        for sample in dev_samples_in_beam:
          program = agent_factory.traj_to_program(sample.traj, envs[0].de_vocab)
          dev_programs_in_beam_dict.setdefault(sample.traj.env_name, []).append(
            (program, sample.traj.answer, sample.prob))

        t3 = time.time()
        with open(
            os.path.join(get_experiment_dir(), 'dev_programs_in_beam_{}.json'.format(i)),
            'w') as f:
          json.dump(dev_programs_in_beam_dict, f)
        t4 = time.time()
        tf.logging.info('{} sec used dumping programs in beam in eval iteration {}.'.format(
          t4 - t3, i))

        t3 = time.time()
        with codecs.open(
            os.path.join(
              get_experiment_dir(), 'dev_samples_{}.txt'.format(i)),
            'w', encoding='utf-8') as f:
          for sample in dev_samples:
            f.write(show_samples([sample], envs[0].de_vocab, env_dict))
        t4 = time.time()
        tf.logging.info('{} sec used logging dev samples in eval iteration {}.'.format(
          t4 - t3, i))

      t2 = time.time()
      tf.logging.info('{} sec used in eval iteration {}.'.format(
        t2 - t1, i))

      if FLAGS.eval_only or agent.model.get_global_step() >= FLAGS.n_steps:
        tf.logging.info('{} finished'.format(self.name))
        if FLAGS.eval_only:
          print('Eval average return (accuracy) of the best model is {}'.format(
            best_dev_avg_return))
        else:
          print('Best dev average return (accuracy) is {}'.format(best_dev_avg_return))
          print('Best model is saved in {}'.format(best_model_path))
        return

      # Block until a checkpoint different from the one just evaluated appears.
      new_ckpt = None
      t1 = time.time()
      while new_ckpt is None or new_ckpt == current_ckpt:
        time.sleep(1)
        new_ckpt = tf.train.latest_checkpoint(
          os.path.join(get_experiment_dir(), FLAGS.saved_model_dir))
      t2 = time.time()
      tf.logging.info('{} sec used waiting for new checkpoint in evaluator.'.format(
        t2-t1))

      tf.logging.info('lastest ckpt to evaluate is {}.'.format(new_ckpt))
      tf.logging.info('{} loading ckpt {}'.format(self.name, new_ckpt))
      t1 = time.time()
      graph.restore(new_ckpt)
      t2 = time.time()
      tf.logging.info('{} sec used {} loading ckpt {}'.format(
        t2-t1, self.name, new_ckpt))
      current_ckpt = new_ckpt
      # Advance the iteration counter; it was never incremented before, so
      # every iteration was logged as iteration 0.
      i += 1
  def run(self):
    """Actor loop: explore, collect replay/on-policy samples, feed trainer.

    Each epoch shuffles this actor's train envs into batches of
    FLAGS.batch_size // n_train_samples envs. For every batch it does the
    following:
      1. Draws FLAGS.n_explore_samples exploration samples and saves them to
         the replay buffer. Envs without a found solution get
         FLAGS.n_extra_explore_for_hard extra exploration samples.
      2. Selects replay samples and on-policy samples. The on-policy samples
         come from greedy decoding, beam search or sampling, depending on
         the flags.
      3. Puts the on-policy samples on `eval_queue` and the replay samples on
         `replay_queue`.
      4. Reweights and merges them into training samples, then puts
         `(train_samples, step_logprobs, clip_frac)` on `train_queue`.
      5. Takes the next checkpoint path from `ckpt_queue` and restores it if
         it differs from the current one.
    Returns once the global step reaches FLAGS.n_steps at the end of an
    epoch.

    Raises:
      ValueError: if any of the following holds:
        - both on-policy and nonreplay samples are requested;
        - no training samples are configured;
        - FLAGS.batch_size cannot hold the samples of a single env.
    """
    agent, envs = init_experiment(
      [get_train_shard_path(i) for i in self.shard_ids],
      use_gpu=FLAGS.actor_use_gpu,
      gpu_id=str(self.actor_id + FLAGS.actor_gpu_start_id))
    graph = agent.model.graph
    current_ckpt = get_init_model_path()

    env_dict = dict([(env.name, env) for env in envs])
    replay_buffer = agent_factory.AllGoodReplayBuffer(agent, envs[0].de_vocab)

    # Load saved programs to warm start the replay buffer.
    if FLAGS.load_saved_programs:
      load_programs(
        envs, replay_buffer, FLAGS.saved_program_file)

    # The sample-count flags do not change between epochs, so validate them
    # and size the env batches once, before any log files get opened.
    n_train_samples = 0
    if FLAGS.use_replay_samples_in_train:
      n_train_samples += FLAGS.n_replay_samples

    if FLAGS.use_policy_samples_in_train and FLAGS.use_nonreplay_samples_in_train:
      raise ValueError(
        'Cannot use both on-policy samples and nonreplay samples for training!')

    if FLAGS.use_policy_samples_in_train or FLAGS.use_nonreplay_samples_in_train:
      # Note that nonreplay samples are drawn by rejection
      # sampling from on-policy samples.
      n_train_samples += FLAGS.n_policy_samples

    # Previously this case crashed below with a bare ZeroDivisionError.
    if n_train_samples <= 0:
      raise ValueError(
        'No training samples configured: enable replay, on-policy or '
        'nonreplay samples in train.')

    # Make sure that all the samples from the env batch
    # fits into one batch for training.
    if FLAGS.batch_size < n_train_samples:
      raise ValueError(
          'One batch have to at least contain samples from one environment.')

    # Floor division: `/` yields a float batch size on Python 3.
    env_batch_size = FLAGS.batch_size // n_train_samples

    i = 0
    while True:
      log_this_epoch = (FLAGS.log_samples_every_n_epoch > 0 and
                        i % FLAGS.log_samples_every_n_epoch == 0)
      # Create the logging files.
      if log_this_epoch:
        f_replay = codecs.open(os.path.join(
          get_experiment_dir(), 'replay_samples_{}_{}.txt'.format(self.name, i)),
                               'w', encoding='utf-8')
        f_policy = codecs.open(os.path.join(
          get_experiment_dir(), 'policy_samples_{}_{}.txt'.format(self.name, i)),
                               'w', encoding='utf-8')
        f_train = codecs.open(os.path.join(
          get_experiment_dir(), 'train_samples_{}_{}.txt'.format(self.name, i)),
                              'w', encoding='utf-8')

      env_iterator = data_utils.BatchIterator(
        dict(envs=envs), shuffle=True,
        batch_size=env_batch_size)

      for j, batch_dict in enumerate(env_iterator):
        batch_envs = batch_dict['envs']
        tf.logging.info('=' * 50)
        tf.logging.info('{} iteration {}, batch {}: {} envs'.format(
            self.name, i, j, len(batch_envs)))

        t1 = time.time()
        # Generate samples with cache and save to replay buffer.
        t3 = time.time()
        n_explore = 0
        for _ in xrange(FLAGS.n_explore_samples):
          explore_samples = agent.generate_samples(
            batch_envs, n_samples=1, use_cache=FLAGS.use_cache,
            greedy=FLAGS.greedy_exploration)
          replay_buffer.save(explore_samples)
          n_explore += len(explore_samples)

        # Extra exploration for envs with no solution in the buffer yet.
        if FLAGS.n_extra_explore_for_hard > 0:
          hard_envs = [env for env in batch_envs
                       if not replay_buffer.has_found_solution(env.name)]
          if hard_envs:
            for _ in xrange(FLAGS.n_extra_explore_for_hard):
              explore_samples = agent.generate_samples(
                hard_envs, n_samples=1, use_cache=FLAGS.use_cache,
                greedy=FLAGS.greedy_exploration)
              replay_buffer.save(explore_samples)
              n_explore += len(explore_samples)

        t4 = time.time()
        tf.logging.info('{} sec used generating {} exploration samples.'.format(
          t4 - t3, n_explore))

        tf.logging.info('{} samples saved in the replay buffer.'.format(
          replay_buffer.size))

        t3 = time.time()
        replay_samples = replay_buffer.replay(
          batch_envs, FLAGS.n_replay_samples,
          use_top_k=FLAGS.use_top_k_replay_samples,
          agent=None if FLAGS.random_replay_samples else agent,
          truncate_at_n=FLAGS.truncate_replay_buffer_at_n)
        t4 = time.time()
        tf.logging.info('{} sec used selecting {} replay samples.'.format(
          t4 - t3, len(replay_samples)))

        # On-policy samples: greedy (1 sample), beam search (top k), or
        # stochastic sampling.
        t3 = time.time()
        if FLAGS.use_top_k_policy_samples:
          if FLAGS.n_policy_samples == 1:
            policy_samples = agent.generate_samples(
              batch_envs, n_samples=FLAGS.n_policy_samples,
              greedy=True)
          else:
            policy_samples = agent.beam_search(
              batch_envs, beam_size=FLAGS.n_policy_samples)
        else:
          policy_samples = agent.generate_samples(
            batch_envs, n_samples=FLAGS.n_policy_samples,
            greedy=False)
        t4 = time.time()
        tf.logging.info('{} sec used generating {} on-policy samples'.format(
          t4-t3, len(policy_samples)))

        t2 = time.time()
        tf.logging.info(
          ('{} sec used generating replay and on-policy samples,'
           ' {} iteration {}, batch {}: {} envs').format(
            t2-t1, self.name, i, j, len(batch_envs)))

        t1 = time.time()
        self.eval_queue.put((policy_samples, len(batch_envs)))
        self.replay_queue.put((replay_samples, len(batch_envs)))

        assert (FLAGS.fixed_replay_weight >= 0.0 and FLAGS.fixed_replay_weight <= 1.0)

        # Replay samples are scaled by the replay weight (per-env clipped
        # prob mass, or the fixed weight); policy samples get 1 - weight below.
        if FLAGS.use_replay_prob_as_weight:
          new_samples = []
          for sample in replay_samples:
            name = sample.traj.env_name
            if name in replay_buffer.prob_sum_dict:
              replay_prob = max(
                replay_buffer.prob_sum_dict[name], FLAGS.min_replay_weight)
            else:
              replay_prob = 0.0
            scale = replay_prob
            new_samples.append(
              agent_factory.Sample(
                traj=sample.traj,
                prob=sample.prob * scale))
          replay_samples = new_samples
        else:
          replay_samples = agent_factory.scale_probs(
            replay_samples, FLAGS.fixed_replay_weight)

        replay_samples = sorted(
          replay_samples, key=lambda x: x.traj.env_name)

        policy_samples = sorted(
          policy_samples, key=lambda x: x.traj.env_name)

        # Must run before policy_samples are saved into the buffer below.
        if FLAGS.use_nonreplay_samples_in_train:
          nonreplay_samples = []
          for sample in policy_samples:
            if not replay_buffer.contain(sample.traj):
              nonreplay_samples.append(sample)

        replay_buffer.save(policy_samples)

        def weight_samples(samples):
          """Scale sample probs by 1 - replay weight (complement of replay)."""
          if FLAGS.use_replay_prob_as_weight:
            new_samples = []
            for sample in samples:
              name = sample.traj.env_name
              if name in replay_buffer.prob_sum_dict:
                replay_prob = max(
                  replay_buffer.prob_sum_dict[name],
                  FLAGS.min_replay_weight)
              else:
                replay_prob = 0.0
              scale = 1.0 - replay_prob
              new_samples.append(
                agent_factory.Sample(
                  traj=sample.traj,
                  prob=sample.prob * scale))
          else:
            new_samples = agent_factory.scale_probs(
              samples, 1 - FLAGS.fixed_replay_weight)
          return new_samples

        train_samples = []
        if FLAGS.use_replay_samples_in_train:
          if FLAGS.use_trainer_prob:
            replay_samples = [
              sample._replace(prob=None) for sample in replay_samples]
          train_samples += replay_samples

        if FLAGS.use_policy_samples_in_train:
          train_samples += weight_samples(policy_samples)

        if FLAGS.use_nonreplay_samples_in_train:
          train_samples += weight_samples(nonreplay_samples)

        train_samples = sorted(train_samples, key=lambda x: x.traj.env_name)
        tf.logging.info('{} train samples'.format(len(train_samples)))

        if FLAGS.use_importance_sampling:
          step_logprobs = agent.compute_step_logprobs(
            [s.traj for s in train_samples])
        else:
          step_logprobs = None

        # Fraction of batch envs whose replay prob mass was raised to the
        # FLAGS.min_replay_weight floor.
        if FLAGS.use_replay_prob_as_weight:
          n_clip = 0
          for env in batch_envs:
            name = env.name
            if (name in replay_buffer.prob_sum_dict and
                replay_buffer.prob_sum_dict[name] < FLAGS.min_replay_weight):
              n_clip += 1
          clip_frac = float(n_clip) / len(batch_envs)
        else:
          clip_frac = 0.0

        self.train_queue.put((train_samples, step_logprobs, clip_frac))
        t2 = time.time()
        tf.logging.info(
          ('{} sec used preparing and enqueuing samples, {}'
           ' iteration {}, batch {}: {} envs').format(
             t2-t1, self.name, i, j, len(batch_envs)))

        t1 = time.time()
        # Wait for a ckpt that still exist or it is the same
        # ckpt (no need to load anything).
        while True:
          new_ckpt = self.ckpt_queue.get()
          new_ckpt_file = new_ckpt + '.meta'
          if new_ckpt == current_ckpt or tf.gfile.Exists(new_ckpt_file):
            break
        t2 = time.time()
        tf.logging.info('{} sec waiting {} iteration {}, batch {}'.format(
          t2-t1, self.name, i, j))

        if new_ckpt != current_ckpt:
          # If the ckpt is not the same, then restore the new
          # ckpt.
          tf.logging.info('{} loading ckpt {}'.format(self.name, new_ckpt))
          t1 = time.time()
          graph.restore(new_ckpt)
          t2 = time.time()
          tf.logging.info('{} sec used {} restoring ckpt {}'.format(
            t2-t1, self.name, new_ckpt))
          current_ckpt = new_ckpt

        if log_this_epoch:
          f_replay.write(show_samples(replay_samples, envs[0].de_vocab, env_dict))
          f_policy.write(show_samples(policy_samples, envs[0].de_vocab, env_dict))
          f_train.write(show_samples(train_samples, envs[0].de_vocab, env_dict))

      if log_this_epoch:
        f_replay.close()
        f_policy.close()
        f_train.close()

      if agent.model.get_global_step() >= FLAGS.n_steps:
        tf.logging.info('{} finished'.format(self.name))
        return
      i += 1
Example #7
0
    def decode_sketch_program(envs):
        """Beam-search sketch-constrained envs and build sketch annotations.

        Sketches come from each env's oracle trajectory. The WTQ annotations
        are used when FLAGS.executor == 'wtq'; otherwise `get_env_trajs` is
        used. An env with a sketch is constrained to it and then decoded
        with beam size 50.

        Args:
            envs: serialized environments, converted via `json_to_envs`.

        Returns:
            A dict mapping every env name to either a `SketchAnnotation` or
            None. An env gets a `SketchAnnotation` only when it has a sketch
            and its beam contains at least one sample with final reward 1.0.
            The annotation holds those successful samples, with probs
            renormalized to sum to 1 and capped at the 10 most probable.
        """
        # first create the real envs from the jsons and add constraints to them
        envs = json_to_envs(envs)
        env_name_dict = {env.name: env for env in envs}

        if FLAGS.executor == 'wtq':
            oracle_envs, oracle_trajs = get_wtq_annotations(envs)
        else:
            oracle_envs, oracle_trajs = get_env_trajs(envs)

        env_sketch_dict = dict([
            (env.name, get_sketch(traj_to_program(traj, envs[0].de_vocab)))
            for env, traj in zip(oracle_envs, oracle_trajs)
        ])
        for env in envs:
            sketch = env_sketch_dict.get(env.name, None)
            if sketch is not None:
                # Pass a copy so the env cannot mutate the stored sketch.
                env.set_sketch_constraint(sketch[:])

        # create the agent
        graph_config = get_saved_graph_config()
        graph_config['use_gpu'] = False
        graph_config['gpu_id'] = '0'
        init_model_path = get_init_model_path()
        agent = create_agent(graph_config, init_model_path)

        # beam search
        beam_samples = []
        env_iterator = data_utils.BatchIterator(
            dict(envs=envs), shuffle=False, batch_size=FLAGS.eval_batch_size)
        for _, batch_dict in tqdm(enumerate(env_iterator)):
            batch_envs = batch_dict['envs']
            beam_samples += agent.beam_search(batch_envs, beam_size=50)

        # Samples arrive as one flat list; regroup them per env name.
        env_beam_dict = dict()
        for sample in beam_samples:
            env_beam_dict.setdefault(sample.traj.env_name, []).append(sample)

        # get the trajs with 1.0 reward for each example and re-weight the prob.
        # Lists (not filter/map) and .items() (not .iteritems) keep this
        # working on Python 3, where len() of a filter/map object fails.
        env_name_annotation_dict = dict()
        for env_name, env in env_name_dict.items():
            beam = env_beam_dict.get(env_name, [])
            success_beam = [sample for sample in beam
                            if sample.traj.rewards[-1] == 1.0]
            # retrieve the sketch result from previous steps
            sketch = env_sketch_dict.get(env_name, None)
            if not success_beam or sketch is None:
                env_name_annotation_dict[env_name] = None
                continue

            # re-weight the successful samples so their probs sum to 1
            prob_sum = sum(sample.prob for sample in success_beam)
            success_beam = [
                agent_factory.Sample(traj=sample.traj,
                                     prob=sample.prob / prob_sum)
                for sample in success_beam
            ]
            if len(success_beam) > 10:
                success_beam = sorted(success_beam,
                                      key=lambda sample: sample.prob,
                                      reverse=True)[:10]

            env_name_annotation_dict[env_name] = SketchAnnotation(
                env, sketch, success_beam)

        return env_name_annotation_dict
Example #8
0
    def run(self):
        """Oracle actor loop: feed each env's oracle trajectory to the trainer.

        Only envs for which `get_env_trajs` finds an oracle trajectory are
        kept. Each epoch shuffles them into batches of FLAGS.batch_size envs,
        with one oracle sample per env. For every batch:
          - The oracle samples (prob 1.0) are put on `eval_queue` and
            `replay_queue`.
          - The same samples, sorted by env name, go on `train_queue`.
          - The next checkpoint from `ckpt_queue` is restored if it differs
            from the current one.
        Returns once the global step reaches FLAGS.n_steps at the end of an
        epoch.

        Raises:
            ValueError: if FLAGS.batch_size cannot hold one env's sample.
        """
        agent, all_envs = init_experiment(
            [get_train_shard_path(i) for i in self.shard_ids],
            use_gpu=FLAGS.actor_use_gpu,
            gpu_id=str(self.actor_id + FLAGS.actor_gpu_start_id))
        graph = agent.model.graph
        current_ckpt = get_init_model_path()

        # obtain the oracle of the examples and delete the examples that can not obtain oracle
        envs, env_trajs = get_env_trajs(all_envs)

        # build a dict to store the oracle trajs, keyed by env name
        env_oracle_trajs_dict = {
            env.name: env_traj for env, env_traj in zip(envs, env_trajs)}
        # Arguments were previously swapped (total count printed as found).
        tf.logging.info(
            'Found oracle for {} envs out of total of {} for actor_{}'.format(
                len(envs), len(all_envs), self.actor_id))

        # Each env contributes exactly one (oracle) training sample.
        n_train_samples = 1

        # Make sure that all the samples from the env batch
        # fits into one batch for training.
        if FLAGS.batch_size < n_train_samples:
            raise ValueError(
                'One batch have to at least contain samples from one environment.'
            )

        # Floor division keeps the batch size an int on Python 3 as well.
        env_batch_size = FLAGS.batch_size // n_train_samples

        i = 0
        while True:
            env_iterator = data_utils.BatchIterator(dict(envs=envs),
                                                    shuffle=True,
                                                    batch_size=env_batch_size)

            for j, batch_dict in enumerate(env_iterator):
                batch_envs = batch_dict['envs']
                tf.logging.info('=' * 50)
                tf.logging.info('{} iteration {}, batch {}: {} envs'.format(
                    self.name, i, j, len(batch_envs)))

                t1 = time.time()

                # get the oracle samples
                oracle_samples = [
                    agent_factory.Sample(
                        traj=env_oracle_trajs_dict[batch_env.name], prob=1.0)
                    for batch_env in batch_envs
                ]

                self.eval_queue.put((oracle_samples, len(batch_envs)))
                self.replay_queue.put((oracle_samples, len(batch_envs)))

                assert (FLAGS.fixed_replay_weight >= 0.0
                        and FLAGS.fixed_replay_weight <= 1.0)

                train_samples = sorted(oracle_samples,
                                       key=lambda x: x.traj.env_name)
                tf.logging.info('{} train samples'.format(len(train_samples)))

                if FLAGS.use_importance_sampling:
                    step_logprobs = agent.compute_step_logprobs(
                        [s.traj for s in train_samples])
                else:
                    step_logprobs = None

                # TODO: the clip_factor may be wrong
                self.train_queue.put((train_samples, step_logprobs, 0.0))
                t2 = time.time()
                tf.logging.info(
                    ('{} sec used preparing and enqueuing samples, {}'
                     ' iteration {}, batch {}: {} envs').format(
                         t2 - t1, self.name, i, j, len(batch_envs)))

                t1 = time.time()
                # Wait for a ckpt that still exist or it is the same
                # ckpt (no need to load anything).
                while True:
                    new_ckpt = self.ckpt_queue.get()
                    new_ckpt_file = new_ckpt + '.meta'
                    if new_ckpt == current_ckpt or tf.gfile.Exists(
                            new_ckpt_file):
                        break
                t2 = time.time()
                tf.logging.info(
                    '{} sec waiting {} iteration {}, batch {}'.format(
                        t2 - t1, self.name, i, j))

                if new_ckpt != current_ckpt:
                    # If the ckpt is not the same, then restore the new
                    # ckpt.
                    tf.logging.info('{} loading ckpt {}'.format(
                        self.name, new_ckpt))
                    t1 = time.time()
                    graph.restore(new_ckpt)
                    t2 = time.time()
                    tf.logging.info('{} sec used {} restoring ckpt {}'.format(
                        t2 - t1, self.name, new_ckpt))
                    current_ckpt = new_ckpt

            if agent.model.get_global_step() >= FLAGS.n_steps:
                tf.logging.info('{} finished'.format(self.name))
                return
            i += 1