"""Train a policy-gradient agent on a Gym environment, then render it.

The algorithm class is picked by name via ``ALGO`` and built with its
``from_environment`` constructor. All tunables are module-level constants so
that a run can be reconfigured by editing this header only.
"""

import gym

from trickster.agent import REINFORCE, A2C, PPO
from trickster.rollout import Trajectory
from trickster import callbacks

ENV_NAME = "LunarLanderContinuous-v2"
ALGO = "REINFORCE"  # one of the keys of the lookup table below
TRAJECTORY_MAX_STEPS = 100
EPOCHS = 1000
ROLLOUTS_PER_EPOCH = 4

env = gym.make(ENV_NAME)

# Name -> agent class lookup; an unknown ALGO raises KeyError right here,
# before any agent or rollout is constructed.
algo = {"REINFORCE": REINFORCE, "A2C": A2C, "PPO": PPO}[ALGO]
agent = algo.from_environment(env)

rollout = Trajectory(agent, env, TRAJECTORY_MAX_STEPS)

cbs = [callbacks.ProgressPrinter(keys=rollout.progress_keys)]

# With a single update per epoch, rollouts-per-update is the same quantity as
# rollouts-per-epoch, so the constant is passed instead of a duplicated
# literal that could drift out of sync with it.
rollout.fit(
    epochs=EPOCHS,
    updates_per_epoch=1,
    rollouts_per_update=ROLLOUTS_PER_EPOCH,
    callbacks=cbs,
)

rollout.render(repeats=100)
# Train a REINFORCE agent with a categorical MLP policy, then render it.

from trickster.agent import REINFORCE
from trickster.rollout import Trajectory, RolloutConfig
from trickster.utility import gymic
from trickster.model import mlp

# Environment comes from the library helper with its default arguments.
# NOTE(review): which task it wraps and how rewards are scaled is not visible
# from this script -- confirm against trickster.utility.gymic.
env = gymic.rwd_scaled_env()

# The action count is read from `.n`, so a discrete action space is assumed.
input_shape = env.observation_space.shape
num_actions = env.action_space.n

policy = mlp.wide_mlp_actor_categorical(
    input_shape,
    num_actions,
    adam_lr=1e-4,
)
agent = REINFORCE(policy, action_space=num_actions)

# Each rollout is limited to 300 steps through the config object.
rollout_config = RolloutConfig(max_steps=300)
rollout = Trajectory(agent, env, config=rollout_config)

# NOTE(review): update_batch_size=-1 presumably means "use everything that was
# collected" -- verify against Trajectory.fit.
rollout.fit(
    episodes=500,
    rollouts_per_update=1,
    update_batch_size=-1,
)

rollout.render(repeats=10)