Example #1
def train(worker_id, agent, hparams, checkpoint):
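  """Single-step training loop: takes one environment transition per
  iteration, evaluating and checkpointing (worker 0 only) every
  hparams.eval_interval local steps."""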
  env = get_env(hparams)

  state = env.reset()
  while hparams.global_step < hparams.train_steps:
    hparams.mode[worker_id] = ModeKeys.TRAIN

    last_state = state

    action, reward, done, state = step(
        hparams, agent, last_state, env, worker_id)

    agent.observe(last_state, action, reward, done, state, worker_id)

    if done:
      hparams.local_episode[worker_id] += 1
      log_scalar('episodes/worker_%d' % worker_id,
                 hparams.local_episode[worker_id])

    hparams.global_step += 1
    hparams.total_step += 1
    hparams.local_step[worker_id] += 1
    update_learning_rate(hparams)

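    # Periodically evaluate; only worker 0 writes checkpoints.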
    if hparams.local_step[worker_id] % hparams.eval_interval == 0:
      agent.reset(worker_id)
      evaluate(worker_id, agent, env, hparams)
      if worker_id == 0:
        checkpoint.save()
      state = env.reset()
      agent.reset(worker_id)

  env.close()
Example #2
def train(worker_id, agent, hparams, checkpoint):
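  """Episode-based training loop: collects n-step rollouts until the episode
  ends, running test() and checkpointing (worker 0 only) every
  hparams.test_interval local steps."""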
  env = get_env(hparams)

  while hparams.global_step < hparams.train_steps:
    hparams.mode[worker_id] = ModeKeys.TRAIN

    state = env.reset()
    done = False

    while not done:

      last_state = state

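      # Roll the environment forward hparams.n_steps steps before the agent observes.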
      action, reward, done, state = steps(hparams, agent, last_state, env,
                                          hparams.n_steps, worker_id)

      agent.observe(last_state, action, reward, done, state, worker_id)

      if done:
        hparams.local_episode[worker_id] += 1
        log_scalar('episodes/worker_%d' % worker_id,
                   hparams.local_episode[worker_id])

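      # Periodic testing; ending the episode forces a fresh env.reset() at the
      # top of the outer loop.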
      if hparams.local_step[worker_id] % hparams.test_interval == 0:
        test(worker_id, agent, env, hparams)
        if worker_id == 0:
          checkpoint.save()
        done = True

      hparams.global_step += 1
      hparams.total_step += 1
      hparams.local_step[worker_id] += 1

  env.close()
Example #3
def test(worker_id, agent, env, hparams):
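  """Runs hparams.test_episodes episodes in TEST mode and logs the mean
  episode reward for this worker."""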
  hparams.mode[worker_id] = ModeKeys.TEST
  rewards = []

  for _ in range(hparams.test_episodes):
    state = env.reset()
    done = False
    episode_reward = 0
    while not done:
      last_state = state
      action, reward, done, state = steps(
          hparams, agent, last_state, env, n=1, worker_id=worker_id)
      episode_reward += reward
      hparams.total_step += 1
    rewards.append(episode_reward)

  log_scalar('rewards/worker_%d' % worker_id, np.mean(rewards))
Example #4
def evaluate(worker_id, agent, env, hparams):
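  """Runs hparams.eval_episodes episodes in EVAL mode, threading a recurrent
  state through step(), and logs the mean episode reward for this worker."""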
  hparams.mode[worker_id] = ModeKeys.EVAL
  rewards = []

  for _ in range(hparams.eval_episodes):
    state = env.reset()
    recurrent_state = np.zeros((512,))
    done = False
    episode_reward = 0
    while not done:
      last_recurrent_state = recurrent_state  # unused: evaluate() never calls agent.observe()
      last_state = state
      action, reward, done, state, recurrent_state = step(
          hparams, agent, last_state, env, worker_id, recurrent_state)
      episode_reward += reward
      hparams.total_step += 1
    rewards.append(episode_reward)

  log_scalar('rewards/worker_%d' % worker_id, np.mean(rewards))
Example #5
def train(worker_id, agent, hparams, checkpoint):
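  """Training loop variant that threads a recurrent hidden state through
  step() and agent.observe() when the environment is NavRLEnv."""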
  env = get_env(hparams)

  recurrent_state = None
  last_recurrent_state = None
  state = env.reset()
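  # Only NavRLEnv carries a recurrent hidden state through step().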
  if type(env).__name__ == 'NavRLEnv':
    recurrent_state = np.zeros((512,))
  while hparams.global_step < hparams.train_steps:
    hparams.mode[worker_id] = ModeKeys.TRAIN

    last_state = state
    if type(env).__name__ == 'NavRLEnv':
      last_recurrent_state = recurrent_state
    action, reward, done, state, recurrent_state = step(hparams, agent,
                                                        last_state, env,
                                                        worker_id,
                                                        recurrent_state)

    agent.observe(last_state, action, reward, done, state, worker_id,
                  last_recurrent_state, recurrent_state)

    if done:
      hparams.local_episode[worker_id] += 1
      log_scalar('episodes/worker_%d' % worker_id,
                 hparams.local_episode[worker_id])

    if hparams.local_step[worker_id] % hparams.eval_interval == 0:
      agent.reset(worker_id)
      evaluate(worker_id, agent, env, hparams)
      if worker_id == 0:
        checkpoint.save()
      state = env.reset()
      agent.reset(worker_id)

    hparams.global_step += 1
    hparams.total_step += 1
    hparams.local_step[worker_id] += 1
    update_learning_rate(hparams)

  env.close()