Code example #1
0
def train_maze(output_path):
    """Train a NatureQN agent on the maze environment.

    Args:
        output_path: Directory where this run's outputs (logs, checkpoints,
            results) are written via ``config.set_paths``.

    Returns:
        Tuple ``(evaluation_result_list, oos_evaluation_result_list)`` as
        produced by ``model.run``.
    """
    config = Config()
    config.set_paths(output_path)

    env = EnvMaze(n=config.maze_size, hard=config.hard)

    # Epsilon-greedy exploration, annealed linearly over eps_nsteps steps.
    exp_schedule = LinearExploration(env, config.eps_begin,
                                     config.eps_end, config.eps_nsteps, config.env_name)

    # Learning rate, annealed linearly over lr_nsteps steps.
    lr_schedule = LinearSchedule(config.lr_begin, config.lr_end,
                                 config.lr_nsteps, config.env_name)

    # Train the model; expose the BFS shortest-path length on the model so it
    # can be used as an evaluation baseline.
    print(config.output_path)
    model = NatureQN(env, config)
    model.bfs_len = env.get_bfs_length()
    # Fixed typo in local name: "evalution" -> "evaluation" (return unchanged).
    evaluation_result_list, oos_evaluation_result_list = model.run(exp_schedule, lr_schedule)
    return evaluation_result_list, oos_evaluation_result_list
Code example #2
0
File: run.py  Project: zachabarnes/slither-rl-agent
  # NOTE(review): this `elif` chain continues an `if` that begins before this
  # excerpt; the start of the FLAGS.network_type dispatch is not visible here.
  elif FLAGS.network_type == 'deep_q':
    # Rebinds the name `network` (previously the module) to a network
    # instance; safe only because exactly one branch of this chain executes.
    network = network.DeepQ(FLAGS)

  elif FLAGS.network_type == 'recurrent_q':
    network = network.RecurrentQ(FLAGS)

  elif FLAGS.network_type == 'transfer_q':
    network = network.TransferQ(FLAGS)

  elif FLAGS.network_type == 'deep_ac':
    network = network.DeepAC(FLAGS)

  # Unknown network types fail fast rather than silently defaulting.
  else: raise NotImplementedError

  # Initialize exploration strategy: epsilon annealed linearly from
  # FLAGS.epsilon to FLAGS.eps_end over FLAGS.eps_nsteps steps.
  exp_schedule = LinearExploration(env, FLAGS.epsilon, FLAGS.eps_end, FLAGS.eps_nsteps)

  # Initialize learning rate schedule (annealed linearly over FLAGS.lr_nsteps).
  lr_schedule  = LinearSchedule(FLAGS.learning_rate, FLAGS.lr_end, FLAGS.lr_nsteps)

  # Select the model wrapper matching the chosen algorithm family
  # ('q' = Q-learning, 'ac' = actor-critic).
  model = None
  if FLAGS.model_type == 'q':
    model = Model(env, record_env, network, FLAGS)
  elif FLAGS.model_type == 'ac':
    model = ModelAC(env, record_env, network, FLAGS)
  else:
    raise NotImplementedError

  # In record-only mode, replay videos from the saved checkpoint instead of
  # training. NOTE(review): training path presumably continues below this view.
  if FLAGS.record_only:
    model.record_videos(FLAGS.model_path+'checkpoint')
Code example #3
0
You'll find the results, logs and video recordings of your agent every 250k steps
under the corresponding file in the results folder. A good way to monitor the
progress of training is to use Tensorboard. The starter code writes summaries of
different variables.

To launch Tensorboard, open a Terminal window and run
    tensorboard --logdir=results/
Then connect remotely to
    address-ip-of-the-server:6006
(6006 is the default port used by Tensorboard.)
"""
if __name__ == '__main__':
    # Build the environment: frame skipping plus greyscale preprocessing
    # down to 80x80x1 observations.
    env = gym.make(config.env_name)
    env = MaxAndSkipEnv(env, skip=config.skip_frame)
    env = PreproWrapper(
        env,
        prepro=greyscale,
        shape=(80, 80, 1),
        overwrite_render=config.overwrite_render)

    # Epsilon-greedy exploration, annealed linearly over eps_nsteps steps.
    exp_schedule = LinearExploration(
        env, config.eps_begin, config.eps_end, config.eps_nsteps)

    # Learning rate, annealed linearly over lr_nsteps steps.
    lr_schedule = LinearSchedule(
        config.lr_begin, config.lr_end, config.lr_nsteps)

    # Train the Nature-DQN model with both schedules.
    model = NatureQN(env, config)
    model.run(exp_schedule, lr_schedule)
Code example #4
0
    # NOTE(review): this excerpt begins mid-function; the enclosing `def` and
    # the earlier student_config setup are not visible here.
    # Anneal the learning rate over the first half of training.
    # NOTE(review): `/` yields a float in Python 3 — if lr_nsteps must be an
    # int, this should probably be `//`; confirm against LinearSchedule.
    student_config.lr_nsteps = args.nsteps_train / 2
    student_config.exp_policy = args.exp_policy

    # make env
    env = gym.make(student_config.env_name)
    # Optional wrappers, applied only when the student config requests them.
    if hasattr(student_config, 'skip_frame'):
        env = MaxAndSkipEnv(env, skip=student_config.skip_frame)
    if hasattr(student_config, 'preprocess_state'
               ) and student_config.preprocess_state is not None:
        # Greyscale preprocessing down to 80x80x1 observations.
        env = PreproWrapper(env,
                            prepro=greyscale,
                            shape=(80, 80, 1),
                            overwrite_render=student_config.overwrite_render)

    # Exploration strategy: epsilon-greedy by default, otherwise a
    # linear-greedy variant; both anneal from eps_begin to eps_end.
    if student_config.exp_policy == 'egreedy':
        exp_schedule = LinearExploration(env, student_config.eps_begin,
                                         student_config.eps_end,
                                         student_config.eps_nsteps)
    else:
        exp_schedule = LinearGreedyExploration(env, student_config.eps_begin,
                                               student_config.eps_end,
                                               student_config.eps_nsteps)
    # Learning rate schedule, annealed linearly over lr_nsteps steps.
    lr_schedule = LinearSchedule(student_config.lr_begin,
                                 student_config.lr_end,
                                 student_config.lr_nsteps)

    # Train the distilled student network.
    model = DistilledQN(env, student_config)
    model.run(exp_schedule, lr_schedule)