Example #1
def train(worker_id, agent, hparams, checkpoint):
  """Single-step training loop for one worker."""
  env = get_env(hparams)

  state = env.reset()
  while hparams.global_step < hparams.train_steps:
    hparams.mode[worker_id] = ModeKeys.TRAIN

    last_state = state

    # Take one transition in the environment.
    action, reward, done, state = step(
        hparams, agent, last_state, env, worker_id)

    # Hand the transition to the agent (e.g. for its replay buffer / learner).
    agent.observe(last_state, action, reward, done, state, worker_id)

    if done:
      hparams.local_episode[worker_id] += 1
      log_scalar('episodes/worker_%d' % worker_id,
                 hparams.local_episode[worker_id])

    hparams.global_step += 1
    hparams.total_step += 1
    hparams.local_step[worker_id] += 1
    update_learning_rate(hparams)

    # Periodically evaluate; only worker 0 writes checkpoints.
    if hparams.local_step[worker_id] % hparams.eval_interval == 0:
      agent.reset(worker_id)
      evaluate(worker_id, agent, env, hparams)
      if worker_id == 0:
        checkpoint.save()
      # Start a fresh episode after evaluation has advanced the env.
      state = env.reset()
      agent.reset(worker_id)

  env.close()
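
The `step` helper called above is not shown in these examples. Judging from the call sites, it draws one action from the agent and advances the environment by a single transition. A minimal sketch, assuming an `agent.act` method and a Gym-style `env.step` return tuple (neither name is confirmed by the source):

def step(hparams, agent, last_state, env, worker_id):
  # Assumed agent API: pick an action for this worker's observation.
  action = agent.act(last_state, worker_id)
  # Gym-style transition: (observation, reward, done, info).
  state, reward, done, _ = env.step(action)
  return action, reward, done, state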
Example #2
def init_agent(sess, hparams):
  # Construct and close a throwaway environment so get_env can populate
  # environment-dependent hparams (e.g. observation and action shapes).
  env = get_env(hparams)
  env.close()
  agent = get_agent(sess, hparams)
  checkpoint = Checkpoint(sess, hparams)
  return agent, checkpoint
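
Taken together, Examples 1 and 2 suggest an asynchronous setup where several worker threads share one agent and one TF1-style session. A hypothetical wiring sketch; `hparams.num_workers`, the use of `tf.Session`, and the explicit variable initialization are all assumptions, not the source's confirmed entry point:

import threading
import tensorflow as tf

def main(hparams):
  with tf.Session() as sess:  # TF1-style session, inferred from init_agent
    agent, checkpoint = init_agent(sess, hparams)
    # Assumed: variables are initialized here (or restored via the checkpoint).
    sess.run(tf.global_variables_initializer())
    workers = [
        threading.Thread(target=train, args=(i, agent, hparams, checkpoint))
        for i in range(hparams.num_workers)  # hypothetical hparam
    ]
    for w in workers:
      w.start()
    for w in workers:
      w.join()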
Example #3
def train(worker_id, agent, hparams, checkpoint):
  """N-step training loop for one worker."""
  env = get_env(hparams)

  while hparams.global_step < hparams.train_steps:
    hparams.mode[worker_id] = ModeKeys.TRAIN

    state = env.reset()
    done = False

    while not done:

      last_state = state

      # Roll out up to n_steps transitions before handing them to the agent.
      action, reward, done, state = steps(hparams, agent, last_state, env,
                                          hparams.n_steps, worker_id)

      agent.observe(last_state, action, reward, done, state, worker_id)

      if done:
        hparams.local_episode[worker_id] += 1
        log_scalar('episodes/worker_%d' % worker_id,
                   hparams.local_episode[worker_id])

      if hparams.local_step[worker_id] % hparams.test_interval == 0:
        test(worker_id, agent, env, hparams)
        if worker_id == 0:
          checkpoint.save()
        # Force the inner loop to end so a fresh episode starts after testing.
        done = True

      hparams.global_step += 1
      hparams.total_step += 1
      hparams.local_step[worker_id] += 1

  env.close()
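
The `steps` helper used here is also not shown. From its arguments it presumably performs an n-step rollout. A sketch under the assumption that it returns the collected actions and rewards plus the final done flag and state; the exact return format, and `agent.act`, are not confirmed by the source:

def steps(hparams, agent, last_state, env, n_steps, worker_id):
  actions, rewards = [], []
  state, done = last_state, False
  for _ in range(n_steps):
    action = agent.act(state, worker_id)  # assumed agent API
    state, reward, done, _ = env.step(action)
    actions.append(action)
    rewards.append(reward)
    if done:
      break  # stop the rollout early at episode end
  return actions, rewards, done, state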
Example #4
def test(hparams, agent):
  """Run greedy test episodes and print per-episode returns."""
  hparams.mode[0] = ModeKeys.TEST
  env = get_env(hparams)

  for i in range(hparams.test_episodes):
    state = env.reset()
    done = False
    episode_reward = 0
    while not done:
      if hparams.render:
        env.render()
      last_state = state
      action, reward, done, state = step(
          hparams, agent, last_state, env, worker_id=0)
      episode_reward += reward
    # Use a float format: rewards are not necessarily integers.
    print('episode %d\treward %.2f' % (i, episode_reward))

  env.close()
Example #5
def train(worker_id, agent, hparams, checkpoint):
  """Training loop variant that threads a recurrent state through the agent."""
  env = get_env(hparams)

  recurrent_state = None
  last_recurrent_state = None
  state = env.reset()
  # NavRLEnv uses a recurrent agent; 512 is the hidden-state size.
  if type(env).__name__ == 'NavRLEnv':
    recurrent_state = np.zeros((512,))
  while hparams.global_step < hparams.train_steps:
    hparams.mode[worker_id] = ModeKeys.TRAIN

    last_state = state
    if type(env).__name__ == 'NavRLEnv':
      last_recurrent_state = recurrent_state
    action, reward, done, state, recurrent_state = step(hparams, agent,
                                                        last_state, env,
                                                        worker_id,
                                                        recurrent_state)

    # Both the pre- and post-step recurrent states are stored with the
    # transition, so the learner can re-run the recurrent core.
    agent.observe(last_state, action, reward, done, state, worker_id,
                  last_recurrent_state, recurrent_state)

    if done:
      hparams.local_episode[worker_id] += 1
      log_scalar('episodes/worker_%d' % worker_id,
                 hparams.local_episode[worker_id])

    if hparams.local_step[worker_id] % hparams.eval_interval == 0:
      agent.reset(worker_id)
      evaluate(worker_id, agent, env, hparams)
      if worker_id == 0:
        checkpoint.save()
      state = env.reset()
      agent.reset(worker_id)

    hparams.global_step += 1
    hparams.total_step += 1
    hparams.local_step[worker_id] += 1
    update_learning_rate(hparams)

  env.close()
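
The recurrent `step` variant called in this example presumably returns the next hidden state alongside the transition. A sketch with the same caveats as before; `agent.act` and its recurrent signature are assumptions:

def step(hparams, agent, last_state, env, worker_id, recurrent_state=None):
  # Assumed recurrent agent API: the hidden state goes in and comes back out.
  action, next_recurrent_state = agent.act(
      last_state, worker_id, recurrent_state)
  state, reward, done, _ = env.step(action)
  return action, reward, done, state, next_recurrent_state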