Beispiel #1
0
def main(unused_arg):
    """Build a DQN agent for the Catch environment and run the experiment loop.

    All hyper-parameters are read from the module-level ``FLAGS`` object.

    Args:
      unused_arg: Positional argument supplied by the caller; ignored.
    """
    environment = catch.Catch(seed=FLAGS.seed)

    # Exploration settings handed to DQN as `epsilon_cfg`.
    # NOTE(review): these keys look like polynomial-schedule kwargs
    # (power=1.0 would make it linear) -- confirm against DQN.
    exploration_schedule = {
        "init_value": FLAGS.epsilon_begin,
        "end_value": FLAGS.epsilon_end,
        "transition_steps": FLAGS.epsilon_steps,
        "power": 1.0,
    }
    obs_spec = environment.observation_spec()
    act_spec = environment.action_spec()
    dqn_agent = DQN(
        observation_spec=obs_spec,
        action_spec=act_spec,
        epsilon_cfg=exploration_schedule,
        target_period=FLAGS.target_period,
        learning_rate=FLAGS.learning_rate,
    )

    replay = ReplayBuffer(FLAGS.replay_capacity)
    loop_kwargs = dict(
        agent=dqn_agent,
        environment=environment,
        accumulator=replay,
        seed=FLAGS.seed,
        batch_size=FLAGS.batch_size,
        train_episodes=FLAGS.train_episodes,
        evaluate_every=FLAGS.evaluate_every,
        eval_episodes=FLAGS.eval_episodes,
    )
    experiment.run_loop(**loop_kwargs)
Beispiel #2
0
def main(unused_arg):
    """Build an online Q-learning agent for Catch and run the experiment loop.

    All hyper-parameters are read from the module-level ``FLAGS`` object.

    Args:
      unused_arg: Positional argument supplied by the caller; ignored.
    """
    env = catch.Catch(seed=FLAGS.seed)
    # NOTE(review): OnlineQ is still called positionally because its parameter
    # names are not visible here; confirm the order (learning rate, then
    # epsilon) matches its signature.
    agent = OnlineQ(env.observation_spec(), env.action_spec(),
                    FLAGS.learning_rate, FLAGS.epsilon)

    accumulator = TransitionAccumulator()
    # Pass every argument by keyword so each value's role is explicit --
    # in particular the bare `1`, which is the batch size for this online agent.
    experiment.run_loop(
        agent=agent,
        environment=env,
        accumulator=accumulator,
        seed=FLAGS.seed,
        batch_size=1,
        train_episodes=FLAGS.train_episodes,
        evaluate_every=FLAGS.evaluate_every,
        eval_episodes=FLAGS.eval_episodes,
    )
Beispiel #3
0
def main(unused_arg):
    """Build an online Q(lambda) agent for Catch and run the experiment loop.

    All hyper-parameters are read from the module-level ``FLAGS`` object.

    Args:
      unused_arg: Positional argument supplied by the caller; ignored.
    """
    environment = catch.Catch(seed=FLAGS.seed)

    agent_kwargs = dict(
        observation_spec=environment.observation_spec(),
        action_spec=environment.action_spec(),
        num_hidden_units=FLAGS.num_hidden_units,
        epsilon=FLAGS.epsilon,
        lambda_=FLAGS.lambda_,
        learning_rate=FLAGS.learning_rate,
    )
    q_lambda_agent = OnlineQLambda(**agent_kwargs)

    # Sequences of FLAGS.sequence_length are collected for the agent.
    sequences = SequenceAccumulator(length=FLAGS.sequence_length)
    experiment.run_loop(
        agent=q_lambda_agent,
        environment=environment,
        accumulator=sequences,
        seed=FLAGS.seed,
        batch_size=1,  # Online agent: one item per update.
        train_episodes=FLAGS.train_episodes,
        evaluate_every=FLAGS.evaluate_every,
        eval_episodes=FLAGS.eval_episodes,
    )
Beispiel #4
0
def main(unused_arg):
    """Build a PopArt agent for reward-scaled Catch and run the experiment loop.

    All hyper-parameters are read from the module-level ``FLAGS`` object.

    Args:
      unused_arg: Positional argument supplied by the caller; ignored.
    """
    raw_env = catch.Catch(seed=FLAGS.seed)
    # The agent and the loop both see the wrapped environment, so specs are
    # taken from the wrapper rather than from `raw_env`.
    scaled_env = wrappers.RewardScale(raw_env, reward_scale=FLAGS.reward_scale)

    pop_art_agent = PopArtAgent(
        observation_spec=scaled_env.observation_spec(),
        action_spec=scaled_env.action_spec(),
        num_hidden_units=FLAGS.num_hidden_units,
        epsilon=FLAGS.epsilon,
        learning_rate=FLAGS.learning_rate,
        pop_art_step_size=FLAGS.pop_art_step_size,
    )

    transitions = TransitionAccumulator()
    loop_kwargs = {
        "agent": pop_art_agent,
        "environment": scaled_env,
        "accumulator": transitions,
        "seed": FLAGS.seed,
        "batch_size": 1,
        "train_episodes": FLAGS.train_episodes,
        "evaluate_every": FLAGS.evaluate_every,
        "eval_episodes": FLAGS.eval_episodes,
    }
    experiment.run_loop(**loop_kwargs)