Exemplo n.º 1
0
def dagger_lstm(env, num_rollouts=1, epochs=1):
  """Train an LSTM policy with DAgger on the given environment.

  Alternates between training on the aggregated dataset and collecting
  fresh rollouts whose states are relabeled with expert actions.

  Returns:
    (policy_fn, rewards): the policy from the final training epoch and a
    list of per-epoch rollout returns.
  """
  expert_data = helper.run_expert(env, num_rollouts=num_rollouts)
  obs_dim = len(expert_data['observations'][0])
  act_dim = len(expert_data['actions'][0])
  model = helper.build_lstm_model(obs_dim, act_dim)

  sess = tf.get_default_session()
  sess.run(tf.global_variables_initializer())

  os.makedirs('checkpoints', exist_ok=True)

  # Aggregated dataset starts with the expert demonstrations alone.
  aggregated = [expert_data]
  rewards = []

  for epoch in range(epochs):
    # Only persist a checkpoint on the final training epoch.
    ckpt = (helper.checkpoint_path(env, 'dagger-lstm-')
            if epoch == epochs - 1 else None)
    policy_fn, initial_state, mean = train(
        sess, aggregated, model=model, curr_epoch=epoch, checkpoint_path=ckpt)

    rollout = run(sess, env, policy_fn, initial_state, mean,
                  num_rollouts=num_rollouts, stats=False)
    # DAgger relabeling: ask the expert what it would do in the states the
    # learner actually visited, then grow the training set with them.
    rollout['actions'] = helper.ask_expert_actions(env, rollout['observations'])
    rewards.append(rollout['returns'])
    aggregated.append(rollout)

  return policy_fn, rewards
Exemplo n.º 2
0
def main():
  """Roll out a trained behavioral-cloning policy and print return stats.

  Command-line arguments:
    env: short environment name (resolved via helper.envname).
    --model_checkpoint: explicit checkpoint path; defaults to the
      helper-derived path for the env.
    --render: whether to render each step (default True).
    --max_timesteps: cap on steps per rollout; defaults to the env limit.
    --num_rollouts: number of episodes to run.
  """
  def str2bool(s):
    # BUG FIX: `type=bool` in argparse is a trap — bool("False") is True
    # because any non-empty string is truthy, so `--render False` still
    # rendered. Parse the text explicitly while keeping the existing
    # `--render True/False` CLI usage working.
    return str(s).lower() in ('true', '1', 'yes', 'y')

  parser = argparse.ArgumentParser()
  parser.add_argument('env', type=str)
  parser.add_argument('--model_checkpoint', type=str)
  parser.add_argument('--render', type=str2bool, default=True)
  parser.add_argument('--max_timesteps', type=int)
  parser.add_argument('--num_rollouts', type=int, default=10)
  args = parser.parse_args()

  with tf.Session() as sess:
    # Variables were saved under a scope named after the env, so restore
    # inside the same scope to match the checkpoint's variable names.
    with tf.variable_scope(args.env):
      input_dim, output_dim = helper.input_output_shape(args.env)
      model = helper.build_model(input_dim, output_dim)
      input_ph, output_pred = model['input_ph'], model['output_pred']

      policy_fn = tf_util.function([input_ph], output_pred)

      if args.model_checkpoint:
        checkpoint_path = args.model_checkpoint
      else:
        checkpoint_path = helper.checkpoint_path(args.env)

      saver = tf.train.Saver()
      saver.restore(sess, checkpoint_path)

      env = gym.make(helper.envname(args.env))
      # NOTE(review): `env.spec.timestep_limit` is deprecated in newer gym
      # releases (`max_episode_steps`) — confirm against the pinned version.
      max_steps = args.max_timesteps or env.spec.timestep_limit

      returns = []
      observations = []
      actions = []
      for i in range(args.num_rollouts):
        print('iter', i)
        obs = env.reset()
        done = False
        totalr = 0
        steps = 0
        while not done:
          # obs[None, :] adds a batch dimension of 1 for the feed.
          action = policy_fn(obs[None, :])
          observations.append(obs)
          actions.append(action)
          obs, r, done, _ = env.step(action)
          totalr += r
          steps += 1
          if args.render:
            env.render()
          if steps >= max_steps:
            break
        returns.append(totalr)

      helper.print_returns_stats(returns)
Exemplo n.º 3
0
def compare_bc_on_multiple_envs(epochs=200, num_rollouts=10):
    """Train and evaluate a behavioral-cloning policy on each benchmark env.

    For every environment: collect expert rollouts, train BC inside an
    env-named variable scope (saving a 'bc-' checkpoint), then run the
    learned policy and print a divider between environments.
    """
    os.makedirs('checkpoints', exist_ok=True)

    for env in ["ant", "half_cheetah", "hopper", "humanoid", "reacher", "walker"]:
        print(colored("ENV: %s" % env, 'green'))
        expert_data = helper.run_expert(env, num_rollouts=num_rollouts)
        # Entering the Session context makes it the default session, so it
        # can be passed straight through to the train/run helpers.
        with tf.Session() as sess:
            # Scope variables per env so checkpoints don't collide.
            with tf.variable_scope(env):
                ckpt = helper.checkpoint_path(env, 'bc-')
                policy_fn = train_bc(sess,
                                     expert_data,
                                     epochs=epochs,
                                     checkpoint_path=ckpt)

                run_bc(sess, env, policy_fn, num_rollouts=num_rollouts)

                print(
                    "=============================================================="
                )
Exemplo n.º 4
0
def test_checkpoint_path():
  # The short env name "ant" must resolve to the versioned checkpoint file.
  expected = "checkpoints/Ant-v2.ckpt"
  assert helper.checkpoint_path("ant") == expected