Code example #1
def run_cartpole_reinforce(args, log_dir="./logs/reinforce"):
    os.makedirs(log_dir, exist_ok=True)
    env = CartPoleEnv()
    # Seed the env and torch before building the agent so that the initial
    # policy weights are reproducible.
    env.seed(args.seed)
    torch.manual_seed(args.seed)
    agent = CartPoleReinforceAgent(env.observation_space.shape[0], env.action_space.n)
    env = BenchMonitor(env, log_dir, allow_early_resets=True)
    env = CartPoleEnvSelfReset(env)

    exp_mem = build_experience_memory(agent, env, args.num_rollout_steps)
    w = World(env, agent, exp_mem)

    # Warm up: fill the experience memory with an initial rollout before training.
    with torch.no_grad():
        w.agent.eval()
        gather_exp_via_rollout(w.env, w.agent, w.exp_mem, args.num_rollout_steps)

    optimizer = torch.optim.Adam(agent.parameters(), args.lr)

    for k in tqdm(range(args.num_batches)):
        # Collect an on-policy batch without tracking gradients, then take one
        # policy-gradient update step on it.
        with torch.no_grad():
            agent.eval()
            batch = do_rollout(w, args)
        train_batch(agent, batch, optimizer)

    return agent, env
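A minimal invocation sketch for the function above, assuming the only hyperparameters needed are the fields the snippet reads (seed, num_rollout_steps, lr, num_batches); do_rollout may read additional fields from args in the actual project, so this namespace is illustrative only.

from argparse import Namespace

# Hypothetical hyperparameters; only the fields referenced in the snippet are set.
args = Namespace(
    seed=0,
    num_rollout_steps=200,  # environment steps gathered per rollout
    lr=1e-3,                # Adam learning rate
    num_batches=1000,       # number of training iterations
)

agent, env = run_cartpole_reinforce(args, log_dir="./logs/reinforce")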
Code example #2
def run_cartpole_reinforce(args, log_dir="./logs/reinforce"):
    os.makedirs(log_dir, exist_ok=True)
    env = CartPoleEnv()
    # Seed the env and torch before building the agent so that the initial
    # policy weights are reproducible.
    env.seed(args.seed)
    torch.manual_seed(args.seed)
    agent: PolicyAgent = PolicyAgent(env.observation_space.shape[0],
                                     env.action_space.n)
    env = BenchMonitor(env, log_dir, allow_early_resets=True)
    train(env, agent, args)
    return agent, env
Code example #3
def run_cartpole_dqn(num_batches=1000,
                     batch_size=32,
                     log_dir="./logs/dqn",
                     seed=0):
    os.makedirs(log_dir, exist_ok=True)
    env = CartPoleEnv()
    env.seed(seed)
    torch.manual_seed(seed)
    agent = CartPoleAgent(env.observation_space, env.action_space)

    # Wrap the env with the baselines Monitor so episode returns and lengths
    # are logged to log_dir.
    from baselines.bench import Monitor as BenchMonitor
    env = BenchMonitor(env, log_dir, allow_early_resets=True)
    train(agent, env, num_batches=num_batches, batch_size=batch_size)
    return agent, env
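A short usage sketch: BenchMonitor writes a monitor.csv to log_dir, which baselines' load_results can read back as a pandas DataFrame (columns r, l, t for episode return, length, and wall-clock time). This assumes the same baselines package imported in the snippet above.

from baselines.bench import load_results

agent, env = run_cartpole_dqn(num_batches=500, batch_size=32, log_dir="./logs/dqn", seed=0)
df = load_results("./logs/dqn")  # episode statistics logged by the Monitor wrapper
print(df["r"].tail())            # returns of the most recent episodes
env.close()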
Code example #4
class CartPoleDictEnvWrapper(gym.Env):
    def __init__(self, max_angle=12, max_num_steps=1000):
        self.env = CartPoleEnv()
        # max_angle is currently unused; the commented line below would tighten the
        # pole-angle termination threshold to max_angle degrees.
        # self.env.theta_threshold_radians = max_angle * 2 * math.pi / 360
        self.observation_space = self.env.observation_space
        self.action_space = self.env.action_space
        self.step_counter = 0
        self.max_num_steps = max_num_steps

    def step(self, action):
        if isinstance(action, numpy.ndarray):
            action = action[0]
        assert isinstance(action, numpy.int64)
        obs, _, done, _ = self.env.step(action)
        self.step_counter += 1
        # The counter is global (never reset), so a time-limit "done" is emitted
        # every max_num_steps environment steps regardless of episode boundaries.
        if self.step_counter % self.max_num_steps == 0:
            done = True
        if done:
            # Sparse reward scheme: -10 on termination, 0 otherwise. The wrapper
            # resets the underlying env itself, so callers never need to call reset().
            reward = -10.0
            obs = self.env.reset()
        else:
            reward = 0.0
        return {"observation": obs, "reward": reward, "done": int(done)}

    def reset(self):
        obs = self.env.reset()
        return {"observation": obs, "reward": 0.0, "done": int(False)}

    def render(self, mode="human"):
        return self.env.render(mode)

    def close(self):
        self.env.close()

    def seed(self, seed=None):
        return self.env.seed(seed)
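A minimal interaction sketch for the dict-returning wrapper; the sampled action is cast to numpy.int64 explicitly because step() asserts that exact type, and no manual reset is needed after an episode ends since the wrapper resets itself.

import numpy

env = CartPoleDictEnvWrapper(max_num_steps=1000)
env.seed(0)
out = env.reset()
for _ in range(50):
    # Random policy; the cast satisfies the isinstance(action, numpy.int64) assertion.
    action = numpy.int64(env.action_space.sample())
    out = env.step(action)
    # out["done"] == 1 marks an episode boundary; the underlying env has already been reset.
env.close()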
Code example #5
File: train_cartpole.py  Project: phate09/SafeDRL
import datetime
import os
from collections import deque

import matplotlib.pyplot as plt
import numpy as np
from gym.envs.classic_control import CartPoleEnv
from tensorboardX import SummaryWriter

from training.dqn.dqn_agent import Agent
from utility.Scheduler import Scheduler

currentDT = datetime.datetime.now()
print(f'Start at {currentDT.strftime("%Y-%m-%d %H:%M:%S")}')
seed = 5
# np.random.seed(seed)
env = CartPoleEnv()  # gym.make("CartPole-v0")
env.seed(seed)
np.random.seed(seed)
state_size = 4
action_size = 2
STARTING_BETA = 0.6  # importance-sampling exponent for prioritised replay: higher values reduce the influence of high-TD transitions
ALPHA = 0.6  # priority exponent: higher values sample more aggressively towards high-TD transitions
EPS_DECAY = 0.2
MIN_EPS = 0.01

current_time = currentDT.strftime('%b%d_%H-%M-%S')
comment = f"alpha={ALPHA}, min_eps={MIN_EPS}, eps_decay={EPS_DECAY}"
log_dir = os.path.join('../runs', current_time + '_' + comment)
os.mkdir(log_dir)
print(f"logging to {log_dir}")
writer = SummaryWriter(log_dir=log_dir)
agent = Agent(state_size=state_size, action_size=action_size, alpha=ALPHA)
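The file excerpt stops before the training loop, so the sketch below only illustrates one plausible way the constants above could drive an epsilon-greedy schedule and the prioritised-replay beta annealing, logged through the SummaryWriter just created; the project's actual Scheduler and Agent interfaces are not shown here, so the per-episode interaction is left as a placeholder.

NUM_EPISODES = 500  # hypothetical episode budget
eps = 1.0
for i_episode in range(NUM_EPISODES):
    # Multiplicative epsilon decay clipped at MIN_EPS (one possible reading of EPS_DECAY).
    eps = max(MIN_EPS, eps * (1.0 - EPS_DECAY))
    # Anneal the importance-sampling exponent from STARTING_BETA towards 1.0.
    beta = min(1.0, STARTING_BETA + (1.0 - STARTING_BETA) * i_episode / NUM_EPISODES)
    writer.add_scalar("epsilon", eps, i_episode)
    writer.add_scalar("beta", beta, i_episode)
    # ... interact with env and update the agent here ...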