Пример #1
0
 def weight_samples(samples):
     """Scale the probabilities of non-replay samples.

     When FLAGS.use_replay_prob_as_weight is set, each sample's prob is
     multiplied by (1 - replay weight of its environment). The replay weight
     is replay_buffer.prob_sum_dict[env_name] clipped from below at
     FLAGS.min_replay_weight, or 0 when the environment has no entry (so the
     prob is left unchanged). Otherwise every prob is scaled by
     1 - FLAGS.fixed_replay_weight via agent_factory.scale_probs.

     NOTE(review): `replay_buffer`, `FLAGS` and `agent_factory` are free
     names here; this copy appears to have been lifted out of a method where
     `replay_buffer` is a local -- confirm they resolve in this scope.
     """
     if not FLAGS.use_replay_prob_as_weight:
         return agent_factory.scale_probs(
             samples, 1 - FLAGS.fixed_replay_weight)

     def _non_replay_scale(env_name):
         # Environments unknown to the buffer keep their full probability.
         if env_name not in replay_buffer.prob_sum_dict:
             return 1.0
         return 1.0 - max(replay_buffer.prob_sum_dict[env_name],
                          FLAGS.min_replay_weight)

     return [agent_factory.Sample(
                 traj=s.traj,
                 prob=s.prob * _non_replay_scale(s.traj.env_name))
             for s in samples]
  def run(self):
    """Run this actor's sampling loop until the model reaches FLAGS.n_steps.

    Every outer iteration `i` is one pass over this actor's training shards
    (`self.shard_ids`). For each batch of environments the actor:
      1. generates exploration samples (plus extra ones for environments
         where `replay_buffer.has_found_solution` is false) and saves them
         to the replay buffer;
      2. draws replay samples and on-policy samples;
      3. puts the on-policy samples on `self.eval_queue` and the replay
         samples on `self.replay_queue`;
      4. re-weights the samples, assembles the training set and puts
         `(train_samples, step_logprobs, clip_frac)` on `self.train_queue`;
      5. takes checkpoints from `self.ckpt_queue` until one equals the
         current checkpoint or has an existing `.meta` file, and restores it
         if it differs from the current one.
    When `FLAGS.log_samples_every_n_epoch > 0`, every that-many passes the
    replay, policy and training samples are also written to text files in
    the experiment directory.

    Returns when `agent.model.get_global_step() >= FLAGS.n_steps`, checked
    after each full pass.

    Raises:
      ValueError: if both on-policy and nonreplay samples are requested for
        training, if the configuration yields no training samples per
        environment, if `FLAGS.batch_size` is smaller than the samples from
        one environment, or if `FLAGS.fixed_replay_weight` is outside [0, 1].
    """
    # Validate the flag configuration up front: these checks depend only on
    # FLAGS, so there is no need to load the model first or to repeat them
    # on every pass / batch.
    if FLAGS.use_policy_samples_in_train and FLAGS.use_nonreplay_samples_in_train:
      raise ValueError(
        'Cannot use both on-policy samples and nonreplay samples for training!')

    n_train_samples = 0
    if FLAGS.use_replay_samples_in_train:
      n_train_samples += FLAGS.n_replay_samples
    if FLAGS.use_policy_samples_in_train or FLAGS.use_nonreplay_samples_in_train:
      # Note that nonreplay samples are drawn by rejection
      # sampling from on-policy samples.
      n_train_samples += FLAGS.n_policy_samples

    # Without this check such a config crashes with ZeroDivisionError when
    # env_batch_size is computed below.
    if n_train_samples <= 0:
      raise ValueError(
        'At least one training sample per environment is required; enable '
        'replay, policy or nonreplay samples with a positive count.')

    # Make sure that all the samples from the env batch
    # fit into one batch for training.
    if FLAGS.batch_size < n_train_samples:
      raise ValueError(
          'One batch have to at least contain samples from one environment.')

    # Use `raise` rather than `assert`: asserts are stripped under `python -O`.
    if not 0.0 <= FLAGS.fixed_replay_weight <= 1.0:
      raise ValueError(
        'fixed_replay_weight must be in [0, 1], got {}.'.format(
          FLAGS.fixed_replay_weight))

    # Number of environments per batch. Floor division keeps it an integer
    # (plain `/` yields a float on Python 3).
    env_batch_size = FLAGS.batch_size // n_train_samples

    agent, envs = init_experiment(
      [get_train_shard_path(i) for i in self.shard_ids],
      use_gpu=FLAGS.actor_use_gpu,
      gpu_id=str(self.actor_id + FLAGS.actor_gpu_start_id))
    graph = agent.model.graph
    current_ckpt = get_init_model_path()

    env_dict = {env.name: env for env in envs}
    replay_buffer = agent_factory.AllGoodReplayBuffer(agent, envs[0].de_vocab)

    # Load saved programs to warm start the replay buffer.
    if FLAGS.load_saved_programs:
      load_programs(
        envs, replay_buffer, FLAGS.saved_program_file)

    def replay_weight(env_name):
      """Return the replay weight of `env_name`.

      This is `replay_buffer.prob_sum_dict[env_name]` clipped from below at
      `FLAGS.min_replay_weight`, or 0.0 if the environment has no entry.
      The dict is read at call time, so it reflects the latest buffer state.
      """
      if env_name in replay_buffer.prob_sum_dict:
        return max(
          replay_buffer.prob_sum_dict[env_name], FLAGS.min_replay_weight)
      return 0.0

    def weight_samples(samples):
      """Scale on-policy / nonreplay sample probs by (1 - replay weight)."""
      if FLAGS.use_replay_prob_as_weight:
        return [
          agent_factory.Sample(
            traj=s.traj,
            prob=s.prob * (1.0 - replay_weight(s.traj.env_name)))
          for s in samples]
      return agent_factory.scale_probs(
        samples, 1 - FLAGS.fixed_replay_weight)

    i = 0
    while True:
      log_this_epoch = (FLAGS.log_samples_every_n_epoch > 0 and
                        i % FLAGS.log_samples_every_n_epoch == 0)
      log_files = {}
      try:
        # Create the logging files.
        if log_this_epoch:
          for kind in ('replay', 'policy', 'train'):
            log_files[kind] = codecs.open(
              os.path.join(
                get_experiment_dir(),
                '{}_samples_{}_{}.txt'.format(kind, self.name, i)),
              'w', encoding='utf-8')

        env_iterator = data_utils.BatchIterator(
          dict(envs=envs), shuffle=True,
          batch_size=env_batch_size)

        for j, batch_dict in enumerate(env_iterator):
          batch_envs = batch_dict['envs']
          tf.logging.info('=' * 50)
          tf.logging.info('{} iteration {}, batch {}: {} envs'.format(
              self.name, i, j, len(batch_envs)))

          t1 = time.time()
          # Generate samples with cache and save to replay buffer.
          t3 = time.time()
          n_explore = 0
          for _ in xrange(FLAGS.n_explore_samples):
            explore_samples = agent.generate_samples(
              batch_envs, n_samples=1, use_cache=FLAGS.use_cache,
              greedy=FLAGS.greedy_exploration)
            replay_buffer.save(explore_samples)
            n_explore += len(explore_samples)

          # Extra exploration for envs the buffer has no solution for yet.
          if FLAGS.n_extra_explore_for_hard > 0:
            hard_envs = [env for env in batch_envs
                         if not replay_buffer.has_found_solution(env.name)]
            if hard_envs:
              for _ in xrange(FLAGS.n_extra_explore_for_hard):
                explore_samples = agent.generate_samples(
                  hard_envs, n_samples=1, use_cache=FLAGS.use_cache,
                  greedy=FLAGS.greedy_exploration)
                replay_buffer.save(explore_samples)
                n_explore += len(explore_samples)

          t4 = time.time()
          tf.logging.info('{} sec used generating {} exploration samples.'.format(
            t4 - t3, n_explore))

          tf.logging.info('{} samples saved in the replay buffer.'.format(
            replay_buffer.size))

          t3 = time.time()
          replay_samples = replay_buffer.replay(
            batch_envs, FLAGS.n_replay_samples,
            use_top_k=FLAGS.use_top_k_replay_samples,
            agent=None if FLAGS.random_replay_samples else agent,
            truncate_at_n=FLAGS.truncate_replay_buffer_at_n)
          t4 = time.time()
          tf.logging.info('{} sec used selecting {} replay samples.'.format(
            t4 - t3, len(replay_samples)))

          t3 = time.time()
          if FLAGS.use_top_k_policy_samples:
            if FLAGS.n_policy_samples == 1:
              policy_samples = agent.generate_samples(
                batch_envs, n_samples=FLAGS.n_policy_samples,
                greedy=True)
            else:
              policy_samples = agent.beam_search(
                batch_envs, beam_size=FLAGS.n_policy_samples)
          else:
            policy_samples = agent.generate_samples(
              batch_envs, n_samples=FLAGS.n_policy_samples,
              greedy=False)
          t4 = time.time()
          tf.logging.info('{} sec used generating {} on-policy samples'.format(
            t4-t3, len(policy_samples)))

          t2 = time.time()
          tf.logging.info(
            ('{} sec used generating replay and on-policy samples,'
             ' {} iteration {}, batch {}: {} envs').format(
              t2-t1, self.name, i, j, len(batch_envs)))

          t1 = time.time()
          self.eval_queue.put((policy_samples, len(batch_envs)))
          self.replay_queue.put((replay_samples, len(batch_envs)))

          # Scale replay sample probs by the replay weight: the per-env
          # replay probability mass, or the fixed weight.
          if FLAGS.use_replay_prob_as_weight:
            replay_samples = [
              agent_factory.Sample(
                traj=s.traj,
                prob=s.prob * replay_weight(s.traj.env_name))
              for s in replay_samples]
          else:
            replay_samples = agent_factory.scale_probs(
              replay_samples, FLAGS.fixed_replay_weight)

          replay_samples = sorted(
            replay_samples, key=lambda x: x.traj.env_name)

          policy_samples = sorted(
            policy_samples, key=lambda x: x.traj.env_name)

          # Done before `replay_buffer.save(policy_samples)` below, so
          # membership is checked against the buffer as it was before these
          # samples were added.
          if FLAGS.use_nonreplay_samples_in_train:
            nonreplay_samples = [
              sample for sample in policy_samples
              if not replay_buffer.contain(sample.traj)]

          replay_buffer.save(policy_samples)

          train_samples = []
          if FLAGS.use_replay_samples_in_train:
            if FLAGS.use_trainer_prob:
              replay_samples = [
                sample._replace(prob=None) for sample in replay_samples]
            train_samples += replay_samples

          if FLAGS.use_policy_samples_in_train:
            train_samples += weight_samples(policy_samples)

          if FLAGS.use_nonreplay_samples_in_train:
            train_samples += weight_samples(nonreplay_samples)

          train_samples = sorted(train_samples, key=lambda x: x.traj.env_name)
          tf.logging.info('{} train samples'.format(len(train_samples)))

          if FLAGS.use_importance_sampling:
            step_logprobs = agent.compute_step_logprobs(
              [s.traj for s in train_samples])
          else:
            step_logprobs = None

          # Fraction of batch envs whose replay mass was clipped up to
          # FLAGS.min_replay_weight.
          if FLAGS.use_replay_prob_as_weight:
            n_clip = 0
            for env in batch_envs:
              name = env.name
              if (name in replay_buffer.prob_sum_dict and
                  replay_buffer.prob_sum_dict[name] < FLAGS.min_replay_weight):
                n_clip += 1
            clip_frac = float(n_clip) / len(batch_envs)
          else:
            clip_frac = 0.0

          self.train_queue.put((train_samples, step_logprobs, clip_frac))
          t2 = time.time()
          tf.logging.info(
            ('{} sec used preparing and enqueuing samples, {}'
             ' iteration {}, batch {}: {} envs').format(
               t2-t1, self.name, i, j, len(batch_envs)))

          t1 = time.time()
          # Wait for a ckpt that still exist or it is the same
          # ckpt (no need to load anything).
          while True:
            new_ckpt = self.ckpt_queue.get()
            new_ckpt_file = new_ckpt + '.meta'
            if new_ckpt == current_ckpt or tf.gfile.Exists(new_ckpt_file):
              break
          t2 = time.time()
          tf.logging.info('{} sec waiting {} iteration {}, batch {}'.format(
            t2-t1, self.name, i, j))

          if new_ckpt != current_ckpt:
            # If the ckpt is not the same, then restore the new
            # ckpt.
            tf.logging.info('{} loading ckpt {}'.format(self.name, new_ckpt))
            t1 = time.time()
            graph.restore(new_ckpt)
            t2 = time.time()
            tf.logging.info('{} sec used {} restoring ckpt {}'.format(
              t2-t1, self.name, new_ckpt))
            current_ckpt = new_ckpt

          if log_this_epoch:
            for kind, samples in (('replay', replay_samples),
                                  ('policy', policy_samples),
                                  ('train', train_samples)):
              log_files[kind].write(
                show_samples(samples, envs[0].de_vocab, env_dict))
      finally:
        # Close the log files even if sampling or training raised.
        for log_file in log_files.values():
          log_file.close()

      if agent.model.get_global_step() >= FLAGS.n_steps:
        tf.logging.info('{} finished'.format(self.name))
        return
      i += 1