def main(unused_arg):
    """Train and periodically evaluate a DQN agent on the Catch environment.

    Args:
      unused_arg: ignored. Presumably the argv list forwarded by absl's
        `app.run` -- confirm at the call site.
    """
    environment = catch.Catch(seed=FLAGS.seed)

    # Exploration schedule handed to the agent as `epsilon_cfg`.
    # NOTE(review): the keys look like optax.polynomial_schedule arguments
    # (power=1. would give a linear decay) -- confirm against how DQN
    # consumes epsilon_cfg.
    exploration_schedule = {
        'init_value': FLAGS.epsilon_begin,
        'end_value': FLAGS.epsilon_end,
        'transition_steps': FLAGS.epsilon_steps,
        'power': 1.,
    }

    obs_spec = environment.observation_spec()
    act_spec = environment.action_spec()
    dqn_agent = DQN(
        observation_spec=obs_spec,
        action_spec=act_spec,
        epsilon_cfg=exploration_schedule,
        target_period=FLAGS.target_period,
        learning_rate=FLAGS.learning_rate,
    )

    # Experience goes into a replay buffer of FLAGS.replay_capacity entries;
    # run_loop is given FLAGS.batch_size for training.
    replay = ReplayBuffer(FLAGS.replay_capacity)

    experiment.run_loop(
        agent=dqn_agent,
        environment=environment,
        accumulator=replay,
        seed=FLAGS.seed,
        batch_size=FLAGS.batch_size,
        train_episodes=FLAGS.train_episodes,
        evaluate_every=FLAGS.evaluate_every,
        eval_episodes=FLAGS.eval_episodes,
    )
def main(unused_arg):
    """Train and periodically evaluate an online Q-learning agent on Catch.

    Args:
      unused_arg: ignored. Presumably the argv list forwarded by absl's
        `app.run` -- confirm at the call site.
    """
    env = catch.Catch(seed=FLAGS.seed)
    # Pass every agent argument by keyword, as the other agent constructors
    # do, so each flag binds to the intended parameter instead of depending
    # on positional order (a positional call previously put the learning
    # rate in the third slot and supplied no hidden-layer size).
    # NOTE(review): assumes OnlineQ's signature matches its online siblings,
    # i.e. (observation_spec, action_spec, num_hidden_units, epsilon,
    # learning_rate), and that a --num_hidden_units flag is defined -- confirm.
    agent = OnlineQ(
        observation_spec=env.observation_spec(),
        action_spec=env.action_spec(),
        num_hidden_units=FLAGS.num_hidden_units,
        epsilon=FLAGS.epsilon,
        learning_rate=FLAGS.learning_rate,
    )
    accumulator = TransitionAccumulator()
    # Online learning: run_loop is given batch_size=1.
    experiment.run_loop(
        agent=agent,
        environment=env,
        accumulator=accumulator,
        seed=FLAGS.seed,
        batch_size=1,
        train_episodes=FLAGS.train_episodes,
        evaluate_every=FLAGS.evaluate_every,
        eval_episodes=FLAGS.eval_episodes,
    )
def main(unused_arg):
    """Train and periodically evaluate an online Q(lambda) agent on Catch.

    Args:
      unused_arg: ignored. Presumably the argv list forwarded by absl's
        `app.run` -- confirm at the call site.
    """
    environment = catch.Catch(seed=FLAGS.seed)
    observation_spec = environment.observation_spec()
    action_spec = environment.action_spec()

    q_lambda_agent = OnlineQLambda(
        observation_spec=observation_spec,
        action_spec=action_spec,
        num_hidden_units=FLAGS.num_hidden_units,
        epsilon=FLAGS.epsilon,
        lambda_=FLAGS.lambda_,
        learning_rate=FLAGS.learning_rate,
    )

    # Experience is accumulated in sequences of FLAGS.sequence_length steps
    # (presumably so the agent can form lambda-returns -- confirm in the
    # agent's update).
    sequence_accumulator = SequenceAccumulator(length=FLAGS.sequence_length)

    loop_config = dict(
        agent=q_lambda_agent,
        environment=environment,
        accumulator=sequence_accumulator,
        seed=FLAGS.seed,
        batch_size=1,  # Online learning: one sequence per update.
        train_episodes=FLAGS.train_episodes,
        evaluate_every=FLAGS.evaluate_every,
        eval_episodes=FLAGS.eval_episodes,
    )
    experiment.run_loop(**loop_config)
def main(unused_arg):
    """Train and periodically evaluate a PopArt agent on reward-scaled Catch.

    Args:
      unused_arg: ignored. Presumably the argv list forwarded by absl's
        `app.run` -- confirm at the call site.
    """
    base_env = catch.Catch(seed=FLAGS.seed)
    # Wrap Catch so its rewards are rescaled by FLAGS.reward_scale; presumably
    # this exercises the agent's PopArt target normalisation -- confirm.
    scaled_env = wrappers.RewardScale(base_env, reward_scale=FLAGS.reward_scale)

    observation_spec = scaled_env.observation_spec()
    action_spec = scaled_env.action_spec()
    pop_art_agent = PopArtAgent(
        observation_spec=observation_spec,
        action_spec=action_spec,
        num_hidden_units=FLAGS.num_hidden_units,
        epsilon=FLAGS.epsilon,
        learning_rate=FLAGS.learning_rate,
        pop_art_step_size=FLAGS.pop_art_step_size,
    )

    transitions = TransitionAccumulator()

    loop_config = dict(
        agent=pop_art_agent,
        environment=scaled_env,
        accumulator=transitions,
        seed=FLAGS.seed,
        batch_size=1,  # Online learning: one transition per update.
        train_episodes=FLAGS.train_episodes,
        evaluate_every=FLAGS.evaluate_every,
        eval_episodes=FLAGS.eval_episodes,
    )
    experiment.run_loop(**loop_config)