Example #1
0
def main(_):
    """Run td3/ddpg training.

    Builds the gym environment and a DDPG model (the TD3 variant when
    ``FLAGS.algo == 'td3'``), resumes from the newest checkpoint found in
    ``FLAGS.save_dir`` if there is one, and then alternates environment
    rollouts with model updates until ``FLAGS.training_steps`` timesteps
    have been collected. Full checkpoints (model, replay buffer and the
    gym/numpy/python random states) and evaluation checkpoints (actor
    variables plus ``reward_scale``) are written at their own intervals.

    Args:
        _: Unused positional argument.
    """
    contrib_eager_python_tfe.enable_eager_execution()

    if FLAGS.use_gpu:
        # The device scope is entered by hand and not exited anywhere in
        # this function.
        tf.device('/device:GPU:0').__enter__()

    tf.gfile.MakeDirs(FLAGS.log_dir)
    summary_writer = contrib_summary.create_file_writer(
        FLAGS.log_dir, flush_millis=10000)

    # Seed TensorFlow, numpy and python from the same flag.
    tf.set_random_seed(FLAGS.seed)
    np.random.seed(FLAGS.seed)
    random.seed(FLAGS.seed)

    env = gym.make(FLAGS.env)
    env.seed(FLAGS.seed)

    # These two environments get ten times more `rand_actions`.
    rand_actions = (
        int(1e4) if FLAGS.env in ['HalfCheetah-v2', 'Ant-v1'] else int(1e3))

    observation_shape = env.observation_space.shape
    action_shape = env.action_space.shape

    # TD3 and plain DDPG differ only in these three constructor arguments.
    use_td3 = FLAGS.algo == 'td3'
    model = ddpg_td3.DDPG(
        observation_shape[0],
        action_shape[0],
        use_td3=use_td3,
        policy_update_freq=2 if use_td3 else 1,
        actor_lr=1e-3 if use_td3 else 1e-4)

    # String-valued variables that carry pickled python state through the
    # TensorFlow checkpoint.
    (replay_buffer_var, gym_random_state_var, np_random_state_var,
     py_random_state_var) = [
         contrib_eager_python_tfe.Variable('', name=var_name)
         for var_name in ('replay_buffer', 'gym_random_state',
                          'np_random_state', 'py_random_state')
     ]

    saver = contrib_eager_python_tfe.Saver(
        model.variables + [
            replay_buffer_var,
            gym_random_state_var,
            np_random_state_var,
            py_random_state_var,
        ])
    tf.gfile.MakeDirs(FLAGS.save_dir)

    reward_scale = contrib_eager_python_tfe.Variable(1, name='reward_scale')
    eval_saver = contrib_eager_python_tfe.Saver(
        model.actor.variables + [reward_scale])
    tf.gfile.MakeDirs(FLAGS.eval_save_dir)

    last_checkpoint = tf.train.latest_checkpoint(FLAGS.save_dir)
    if last_checkpoint is not None:
        # Resume: restore variables, then rebuild the python-side state
        # from the pickled payloads they hold.
        saver.restore(last_checkpoint)
        replay_buffer = pickle.loads(
            zlib.decompress(replay_buffer_var.numpy()))
        # The step count is the suffix after the last '-' in the
        # checkpoint path.
        total_numsteps = int(last_checkpoint.split('-')[-1])
        assert len(replay_buffer) == total_numsteps
        prev_save_timestep = prev_eval_save_timestep = total_numsteps
        env.unwrapped.np_random.set_state(
            pickle.loads(gym_random_state_var.numpy()))
        np.random.set_state(pickle.loads(np_random_state_var.numpy()))
        random.setstate(pickle.loads(py_random_state_var.numpy()))
    else:
        # Fresh run.
        replay_buffer = ReplayBuffer()
        total_numsteps = 0
        prev_save_timestep = 0
        prev_eval_save_timestep = 0

    with summary_writer.as_default():
        while total_numsteps < FLAGS.training_steps:
            rollout_reward, rollout_timesteps = do_rollout(
                env,
                model.actor,
                replay_buffer,
                noise_scale=FLAGS.exploration_noise,
                rand_actions=rand_actions)
            total_numsteps += rollout_timesteps

            logging.info('Training: total timesteps %d, episode reward %f',
                         total_numsteps, rollout_reward)

            print('Training: total timesteps {}, episode reward {}'.format(
                total_numsteps, rollout_reward))

            with contrib_summary.always_record_summaries():
                contrib_summary.scalar(
                    'reward', rollout_reward, step=total_numsteps)
                contrib_summary.scalar(
                    'length', rollout_timesteps, step=total_numsteps)

            # No updates and no checkpoints until the buffer is big enough.
            if len(replay_buffer) < FLAGS.min_samples_to_start:
                continue

            # One model update per timestep collected in this rollout.
            for _ in range(rollout_timesteps):
                transitions = replay_buffer.sample(
                    batch_size=FLAGS.batch_size)
                model.update(TimeStep(*zip(*transitions)))

            if total_numsteps - prev_save_timestep >= FLAGS.save_interval:
                # Refresh the pickled payloads before writing the full
                # checkpoint.
                replay_buffer_var.assign(
                    zlib.compress(pickle.dumps(replay_buffer)))
                gym_random_state_var.assign(
                    pickle.dumps(env.unwrapped.np_random.get_state()))
                np_random_state_var.assign(
                    pickle.dumps(np.random.get_state()))
                py_random_state_var.assign(pickle.dumps(random.getstate()))

                checkpoint_prefix = os.path.join(FLAGS.save_dir, 'checkpoint')
                saver.save(checkpoint_prefix, global_step=total_numsteps)
                prev_save_timestep = total_numsteps

            if (total_numsteps - prev_eval_save_timestep >=
                    FLAGS.eval_save_interval):
                eval_prefix = os.path.join(FLAGS.eval_save_dir, 'checkpoint')
                eval_saver.save(eval_prefix, global_step=total_numsteps)
                prev_eval_save_timestep = total_numsteps
def main(_):
  """Run td3/ddpg training with rewards supplied by a GAIL model.

  Builds the gym environment (wrapped in `lfd_envs.AbsorbingWrapper` when
  `FLAGS.learn_absorbing` is set), a `gail.GAIL` instance and a DDPG/TD3
  model whose `get_reward` is `lfd.get_reward`. On a fresh run the expert
  replay buffer is loaded from `FLAGS.expert_dir` and subsampled; otherwise
  everything is restored from the newest checkpoint in `FLAGS.save_dir`.
  The loop then alternates rollouts, GAIL updates and model updates until
  `FLAGS.training_steps` timesteps have been collected, writing full and
  evaluation checkpoints at their own intervals.

  Args:
    _: Unused positional argument.
  """
  contrib_eager_python_tfe.enable_eager_execution()

  if FLAGS.use_gpu:
    # The device scope is entered by hand and not exited anywhere in this
    # function.
    tf.device('/device:GPU:0').__enter__()

  tf.gfile.MakeDirs(FLAGS.log_dir)
  summary_writer = contrib_summary.create_file_writer(
      FLAGS.log_dir, flush_millis=10000)

  # Seed TensorFlow, numpy and python from the same flag.
  tf.set_random_seed(FLAGS.seed)
  np.random.seed(FLAGS.seed)
  random.seed(FLAGS.seed)

  env = gym.make(FLAGS.env)
  env.seed(FLAGS.seed)
  if FLAGS.learn_absorbing:
    env = lfd_envs.AbsorbingWrapper(env)

  # These two environments get ten times more `rand_actions`.
  rand_actions = (
      int(1e4) if FLAGS.env in ['HalfCheetah-v2', 'Ant-v1'] else int(1e3))

  observation_shape = env.observation_space.shape
  action_shape = env.action_space.shape

  subsampling_rate = env._max_episode_steps // FLAGS.trajectory_size  # pylint: disable=protected-access
  lfd = gail.GAIL(
      observation_shape[0] + action_shape[0],
      subsampling_rate=subsampling_rate,
      gail_loss=FLAGS.gail_loss)

  # TD3 and plain DDPG differ only in `use_td3` and `policy_update_freq`.
  use_td3 = FLAGS.algo == 'td3'
  model = ddpg_td3.DDPG(
      observation_shape[0],
      action_shape[0],
      use_td3=use_td3,
      policy_update_freq=2 if use_td3 else 1,
      actor_lr=FLAGS.actor_lr,
      get_reward=lfd.get_reward,
      use_absorbing_state=FLAGS.learn_absorbing)

  # Reward of 10 `sample_random` trajectories; used below as the baseline
  # of the 'reward/scaled' summary.
  random_reward, _ = do_rollout(
      env, model.actor, None, num_trajectories=10, sample_random=True)

  make_variable = contrib_eager_python_tfe.Variable

  # String-valued variables that carry pickled buffers through the
  # TensorFlow checkpoint.
  replay_buffer_var = make_variable('', name='replay_buffer')
  expert_replay_buffer_var = make_variable('', name='expert_replay_buffer')

  # Save and restore random states of gym/numpy/python.
  # If the job is preempted, it guarantees that it won't affect the results.
  # And the results will be deterministic (on CPU) and reproducible.
  gym_random_state_var = make_variable('', name='gym_random_state')
  np_random_state_var = make_variable('', name='np_random_state')
  py_random_state_var = make_variable('', name='py_random_state')

  reward_scale = make_variable(1, name='reward_scale')

  saver = contrib_eager_python_tfe.Saver(
      model.variables + lfd.variables + [
          replay_buffer_var,
          expert_replay_buffer_var,
          reward_scale,
          gym_random_state_var,
          np_random_state_var,
          py_random_state_var,
      ])

  tf.gfile.MakeDirs(FLAGS.save_dir)

  eval_saver = contrib_eager_python_tfe.Saver(
      model.actor.variables + [reward_scale])
  tf.gfile.MakeDirs(FLAGS.eval_save_dir)

  last_checkpoint = tf.train.latest_checkpoint(FLAGS.save_dir)
  if last_checkpoint is not None:
    # Resume: restore variables, then rebuild the python-side state from
    # the pickled payloads they hold.
    saver.restore(last_checkpoint)
    replay_buffer = pickle.loads(zlib.decompress(replay_buffer_var.numpy()))
    expert_replay_buffer = pickle.loads(
        zlib.decompress(expert_replay_buffer_var.numpy()))
    # The step count is the suffix after the last '-' in the checkpoint path.
    total_numsteps = int(last_checkpoint.split('-')[-1])
    prev_save_timestep = prev_eval_save_timestep = total_numsteps
    env.unwrapped.np_random.set_state(
        pickle.loads(gym_random_state_var.numpy()))
    np.random.set_state(pickle.loads(np_random_state_var.numpy()))
    random.setstate(pickle.loads(py_random_state_var.numpy()))
  else:
    # Fresh run: fetch the expert demonstrations from FLAGS.expert_dir.
    expert_saver = contrib_eager_python_tfe.Saver([expert_replay_buffer_var])
    expert_checkpoint = os.path.join(FLAGS.expert_dir, 'expert_replay_buffer')
    expert_saver.restore(expert_checkpoint)
    # NOTE(review): unlike the resume path, no zlib.decompress is applied
    # here - presumably the expert file stores an uncompressed pickle;
    # confirm against the script that writes it.
    expert_replay_buffer = pickle.loads(expert_replay_buffer_var.numpy())
    expert_reward = expert_replay_buffer.get_average_reward()

    logging.info('Expert reward %f', expert_reward)
    print('Expert reward {}'.format(expert_reward))

    reward_scale.assign(expert_reward)
    expert_replay_buffer.subsample_trajectories(FLAGS.num_expert_trajectories)
    if FLAGS.learn_absorbing:
      expert_replay_buffer.add_absorbing_states(env)

    # Subsample after adding absorbing states, because otherwise we can lose
    # final states.
    print('Original dataset size {}'.format(len(expert_replay_buffer)))
    expert_replay_buffer.subsample_transitions(subsampling_rate)
    print('Subsampled dataset size {}'.format(len(expert_replay_buffer)))

    replay_buffer = ReplayBuffer()
    total_numsteps = 0
    prev_save_timestep = 0
    prev_eval_save_timestep = 0

  with summary_writer.as_default():
    while total_numsteps < FLAGS.training_steps:
      # Decay helps to make the model more stable.
      # TODO(agrawalk): Use tf.train.exponential_decay
      # The actor learning rate is halved once per 100000 timesteps.
      model.actor_lr.assign(
          model.initial_actor_lr * pow(0.5, total_numsteps // 100000))
      logging.info('Learning rate %f', model.actor_lr.numpy())

      # `sample_random` stays on while the actor step counter is still 0.
      actor_not_updated_yet = model.actor_step.numpy() == 0
      rollout_reward, rollout_timesteps = do_rollout(
          env,
          model.actor,
          replay_buffer,
          noise_scale=FLAGS.exploration_noise,
          rand_actions=rand_actions,
          sample_random=actor_not_updated_yet,
          add_absorbing_state=FLAGS.learn_absorbing)
      total_numsteps += rollout_timesteps

      logging.info('Training: total timesteps %d, episode reward %f',
                   total_numsteps, rollout_reward)

      print('Training: total timesteps {}, episode reward {}'.format(
          total_numsteps, rollout_reward))

      with contrib_summary.always_record_summaries():
        # Rollout reward rescaled relative to the random-rollout baseline
        # and the stored reward_scale.
        scaled_reward = ((rollout_reward - random_reward) /
                         (reward_scale.numpy() - random_reward))
        contrib_summary.scalar(
            'reward/scaled', scaled_reward, step=total_numsteps)
        contrib_summary.scalar('reward', rollout_reward, step=total_numsteps)
        contrib_summary.scalar('length', rollout_timesteps, step=total_numsteps)

      # No updates and no checkpoints until the buffer is big enough.
      if len(replay_buffer) < FLAGS.min_samples_to_start:
        continue

      # GAIL updates: one per timestep collected in this rollout, each on a
      # policy batch paired with an expert batch.
      for _ in range(rollout_timesteps):
        policy_transitions = replay_buffer.sample(batch_size=FLAGS.batch_size)
        batch = TimeStep(*zip(*policy_transitions))

        expert_transitions = expert_replay_buffer.sample(
            batch_size=FLAGS.batch_size)
        expert_batch = TimeStep(*zip(*expert_transitions))

        lfd.update(batch, expert_batch)

      # Model updates; the actor is only updated once the critic step
      # counter has reached FLAGS.policy_updates_delay.
      for _ in range(FLAGS.updates_per_step * rollout_timesteps):
        policy_transitions = replay_buffer.sample(batch_size=FLAGS.batch_size)
        batch = TimeStep(*zip(*policy_transitions))
        actor_enabled = (
            model.critic_step.numpy() >= FLAGS.policy_updates_delay)
        model.update(batch, update_actor=actor_enabled)

      if total_numsteps - prev_save_timestep >= FLAGS.save_interval:
        # Refresh the pickled payloads before writing the full checkpoint.
        replay_buffer_var.assign(zlib.compress(pickle.dumps(replay_buffer)))
        expert_replay_buffer_var.assign(
            zlib.compress(pickle.dumps(expert_replay_buffer)))
        gym_random_state_var.assign(
            pickle.dumps(env.unwrapped.np_random.get_state()))
        np_random_state_var.assign(pickle.dumps(np.random.get_state()))
        py_random_state_var.assign(pickle.dumps(random.getstate()))

        checkpoint_prefix = os.path.join(FLAGS.save_dir, 'checkpoint')
        saver.save(checkpoint_prefix, global_step=total_numsteps)
        prev_save_timestep = total_numsteps

      if total_numsteps - prev_eval_save_timestep >= FLAGS.eval_save_interval:
        eval_prefix = os.path.join(FLAGS.eval_save_dir, 'checkpoint')
        eval_saver.save(eval_prefix, global_step=total_numsteps)
        prev_eval_save_timestep = total_numsteps