def test_dqn_doesnt_store_invalid_transitions(self):
    """DQN's replay buffer must only hold transitions where state_next follows state."""
    rollout_steps = 55
    train_env = DummyEnv()
    eval_env = DummyEnv()

    # Tiny linear Q-network; learning_rate=0 keeps the weights (and thus the
    # policy) frozen for the whole run.
    q_net = tf.keras.Sequential([tf.keras.layers.Dense(2, input_dim=2)])
    q_net.compile(tf.keras.optimizers.SGD(learning_rate=0.), loss="mse")

    agent = DQN.from_environment(train_env, q_net, discount_gamma=0., use_target_network=False)

    rollout = Rolling(agent, train_env)
    test_rollout = Trajectory(agent, eval_env)
    rollout.fit(epochs=10, updates_per_epoch=12, steps_per_update=rollout_steps,
                update_batch_size=8, testing_rollout=test_rollout, buffer_warmup=False)

    # Sample the entire buffer (-1 = everything) and check every stored pair.
    buffer = agent.memory_sampler.sample(-1)
    # DummyEnv states stay below 10, and a valid successor differs from its
    # predecessor by exactly 1 in total; anything else is an invalid transition.
    np.testing.assert_array_less(buffer["state"], 10)
    np.testing.assert_equal((buffer["state_next"] - buffer["state"]).sum(axis=1), 1)
def test_td3_doesnt_store_invalid_transitions(self):
    """TD3's replay buffer must only hold transitions where state_next follows state."""
    rollout_steps = 55
    train_env = DummyEnv(action_space="continuous")
    eval_env = DummyEnv(action_space="continuous")

    # Constant-output stub networks: the actor always emits [0., 0.] and both
    # critics always score 1., so the learning dynamics are inert.
    zero_action = tf.convert_to_tensor([[0., 0.]])
    unit_value = tf.convert_to_tensor([1.])
    actor = arch.TestingModel(zero_action)
    actor_target = arch.TestingModel(zero_action)
    critic = arch.TestingModel(unit_value)
    critic_target = arch.TestingModel(unit_value)
    critic2 = arch.TestingModel(unit_value)
    critic2_target = arch.TestingModel(unit_value)

    # lr=0 optimizers keep every trainable weight frozen throughout fitting.
    for network in (actor, critic, critic2):
        network.optimizer = tf.keras.optimizers.SGD(0)

    agent = TD3(actor, actor_target, critic, critic_target, critic2, critic2_target,
                discount_gamma=0., polyak_tau=0., action_minima=-1., action_maxima=1.,
                update_actor_every=1)

    rollout = Rolling(agent, train_env)
    test_rollout = Trajectory(agent, eval_env)
    rollout.fit(epochs=10, updates_per_epoch=12, steps_per_update=rollout_steps,
                update_batch_size=8, testing_rollout=test_rollout, buffer_warmup=False)

    # Sample the entire buffer (-1 = everything) and check every stored pair.
    buffer = agent.memory_sampler.sample(-1)
    # DummyEnv states stay below 10, and a valid successor differs from its
    # predecessor by exactly 1 in total; anything else is an invalid transition.
    np.testing.assert_array_less(buffer["state"], 10)
    np.testing.assert_equal((buffer["state_next"] - buffer["state"]).sum(axis=1), 1)
from trickster.agent import DoubleDQN
from trickster.rollout import Rolling, Trajectory, RolloutConfig
from trickster.experience import Experience
from trickster.model import mlp
from trickster.utility import gymic

# Train a Double DQN agent on reward-scaled CartPole; separate env instances
# are used for training and evaluation.
env = gymic.rwd_scaled_env("CartPole-v1")
test_env = gymic.rwd_scaled_env("CartPole-v1")

input_shape = env.observation_space.shape
num_actions = env.action_space.n

# Wide MLP Q-network; exploration follows an epsilon-greedy schedule annealed
# from 1.0 down to 0.1.
ann = mlp.wide_mlp_critic_network(input_shape, num_actions, adam_lr=1e-3)
agent = DoubleDQN(ann,
                  action_space=env.action_space,
                  memory=Experience(max_length=10000),
                  epsilon=1.,
                  epsilon_decay=0.99995,
                  epsilon_min=0.1,
                  discount_factor_gamma=0.98)

# Training episodes are capped at 300 steps.
rollout = Rolling(agent, env, config=RolloutConfig(max_steps=300))
test_rollout = Trajectory(agent, test_env)

rollout.fit(episodes=500, updates_per_episode=32, step_per_update=2,
            update_batch_size=32, testing_rollout=test_rollout, plot_curves=True)
test_rollout.render(repeats=10)
# NOTE: `env`, `input_shape` and the TD3/mlp/spaces/Experience/RolloutConfig
# imports are defined earlier in this file (outside this excerpt).
num_actions = env.action_space.shape[0]

# Actor plus two critics (twin critics are what TD3 requires).
actor, critics = mlp.wide_ddpg_actor_critic(input_shape,
                                            output_dim=num_actions,
                                            action_range=2,
                                            num_critics=2,
                                            actor_lr=5e-4,
                                            critic_lr=5e-4)

agent = TD3(actor, critics,
            action_space=spaces.CONTINUOUS,
            memory=Experience(max_length=int(1e4)),
            discount_factor_gamma=0.99,
            action_noise_sigma=0.1,
            action_noise_sigma_decay=1.,  # 1.0 => exploration noise never decays
            action_minima=-2,
            action_maxima=2,
            target_noise_sigma=0.2,
            target_noise_clip=0.5)

rollout = Rolling(agent, env)
# NOTE(review): evaluation reuses the training env instance rather than a
# fresh one — confirm this is intended.
test_rollout = Trajectory(agent, env, RolloutConfig(testing_rollout=True))

rollout.fit(episodes=1000, updates_per_episode=64, step_per_update=1,
            update_batch_size=32, testing_rollout=test_rollout)
test_rollout.render(repeats=10)
from trickster.agent import A2C
from trickster.rollout import Rolling, Trajectory, RolloutConfig
from trickster.model import mlp
from trickster.utility import gymic

# Train an A2C agent on the default reward-scaled environment.
env = gymic.rwd_scaled_env()
input_shape = env.observation_space.shape
num_actions = env.action_space.n

# Separate policy (actor) and value (critic) MLPs; a small entropy bonus
# discourages premature policy collapse.
actor, critic = mlp.wide_pg_actor_critic(input_shape, num_actions)
agent = A2C(actor, critic,
            action_space=env.action_space,
            discount_factor_gamma=0.98,
            entropy_penalty_coef=0.01)

# Training episodes are capped at 300 steps; evaluation runs on a fresh env.
rollout = Rolling(agent, env, config=RolloutConfig(max_steps=300))
test_rollout = Trajectory(agent, gymic.rwd_scaled_env())

rollout.fit(episodes=1000, updates_per_episode=64, step_per_update=1,
            testing_rollout=test_rollout, plot_curves=True)
test_rollout.render(repeats=10)