"""Train REPS on an LQR environment, then sweep a state-action grid to
collect (prev_state, action, reward, new_state) observations."""
import random

import numpy as np  # was missing: np.arange is used below

from agents.agent import Agent  # was missing: Agent is instantiated below
from environments.lqr import LQR
from models.simple import Simple
from policies.rand import RandomPolicy
from policies.normal import NormalPolicy
from utils.data import SARSDataset

# Seed for reproducibility.
random.seed(42)

environment = LQR(-1, 1)

# Two candidate policies; the normal policy is the one actually used.
policy_model_random = RandomPolicy(-2, 2)
policy_model_normal = NormalPolicy([1, 9, 1], [1.])
policy_model = policy_model_normal

value_model = Simple()

agent = Agent(environment, policy_model, value_model, verbose=True)
agent.run_reps(iterations=3)

# Dense grid over [-2, 2] for both states and actions (step 0.1).
state_space = np.arange(-2, 2.1, 0.1)
action_space = np.arange(-2, 2.1, 0.1)

# loop over state action pairs
observations = []
for state in state_space:
    for action in action_space:
        # Force the environment into the grid state before stepping.
        environment.state = state
        new_state, reward, _, _ = environment.step(action)
        observations.append({
            'prev_state': state,
            'action': action,
            'reward': reward,
            'new_state': new_state})
"""Run REPS on Atari Breakout with a normal policy and an MLP value model."""
import os
import sys

# Make the project root importable when running this script directly.
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))

from agents.agent import Agent
from environments.atari import AtariBreakout
from policies.rand import RandomPolicy
from policies.normal import NormalPolicy
from models.mlp import MLP

environment = AtariBreakout()

# Alternative: uniform random discrete policy over Breakout's actions.
# policy_model = RandomPolicy(0, 3, is_discrete=True)
policy_model = NormalPolicy([128, 64, 1], 1.)
value_model = MLP([128, 64, 1])

agent = Agent(environment, policy_model, value_model, verbose=True)
# Render the exploration rollouts while training.
agent.run_reps(exp_render=True)
"""Train REPS on an LQR environment with a deeper normal policy and an
MLP value model, with all RNGs seeded for reproducibility."""
import random

import numpy as np  # was missing: np.random.seed is used below
import torch
import gym
from torch.nn import functional as F

from agents.agent import Agent
from environments.lqr import LQR
from models.simple import Simple
from models.mlp import MLP
from policies.rand import RandomPolicy
from policies.normal import NormalPolicy
from utils.data import SARSDataset

environment = LQR(-1, 1)

# Normal policy over one state dim with two hidden layers and a single
# fixed variance of 4 (the original spelled [4] as a one-element
# comprehension).
policy_model_normal = NormalPolicy([1, 20, 45, 1], [4], activation=F.tanh)
policy_model = policy_model_normal
value_model = MLP([1, 20, 45, 1])

# Seed every RNG in play; `random` was used but never imported in the
# original.
random.seed(42)
torch.manual_seed(42)
np.random.seed(42)
# environment.seed(42)  # NOTE(review): disabled in original; enable if LQR supports seeding

agent = Agent(environment, policy_model, value_model, verbose=True)
agent.run_reps(100, exp_timesteps=1000, exp_episodes=10, exp_render=False,
               val_epsilon=0.1, pol_lr=1e-2)