def run(v):
    """Run TRPO on ``Pendulum-v0`` with the baseline selected by ``v``.

    Args:
        v: Mapping of experiment settings. ``v['seed']`` is passed to
            ``np.random.seed``; ``v['baseline']`` selects the baseline and
            must be ``'mlp'``, ``'time_dependent'`` or ``'linear_feature'``.

    Raises:
        ValueError: If ``v['baseline']`` is not one of the supported names.
    """
    np.random.seed(v['seed'])

    # Reject an unknown baseline before the environment and policy are
    # built, and report both the offending value and the accepted ones.
    baseline_name = v['baseline']
    supported = ('mlp', 'time_dependent', 'linear_feature')
    if baseline_name not in supported:
        raise ValueError(
            "unknown baseline {!r}; expected one of: {}".format(
                baseline_name, ", ".join(supported)
            )
        )

    env_maker = EnvMaker('Pendulum-v0')
    env = env_maker.make()
    policy = GaussianMLPPolicy(
        observation_space=env.observation_space,
        action_space=env.action_space,
        env_spec=env.spec,
        hidden_sizes=(64, 64),
        hidden_nonlinearity=chainer.functions.tanh,
    )
    if baseline_name == 'mlp':
        baseline = MLPBaseline(
            observation_space=env.observation_space,
            action_space=env.action_space,
            env_spec=env.spec,
            hidden_sizes=(64, 64),
            hidden_nonlinearity=chainer.functions.tanh,
        )
    elif baseline_name == 'time_dependent':
        baseline = TimeDependentBaseline(
            observation_space=env.observation_space,
            action_space=env.action_space,
            env_spec=env.spec,
        )
    else:
        # Only 'linear_feature' remains after the validation above.
        baseline = LinearFeatureBaseline(
            observation_space=env.observation_space,
            action_space=env.action_space,
            env_spec=env.spec,
        )
    trpo(
        env=env,
        env_maker=env_maker,
        n_envs=16,
        policy=policy,
        baseline=baseline,
        batch_size=10000,
        n_iters=100,
        snapshot_saver=SnapshotSaver(logger.get_dir()),
    )
def run(v):
    """Train a categorical MLP policy on ``CartPole-v0`` with TRPO.

    ``v['seed']`` is passed to ``np.random.seed`` before anything else is
    built. The policy and the MLP baseline are both constructed from the
    environment's observation space, action space and spec. Snapshots are
    handed to a ``SnapshotSaver`` created from ``logger.get_dir()``.
    """
    np.random.seed(v['seed'])

    maker = EnvMaker('CartPole-v0')
    cartpole = maker.make()

    def env_kwargs():
        # Policy and baseline take the same three env-derived arguments;
        # build them fresh for each constructor call.
        return dict(
            observation_space=cartpole.observation_space,
            action_space=cartpole.action_space,
            env_spec=cartpole.spec,
        )

    actor = CategoricalMLPPolicy(**env_kwargs())
    value_baseline = MLPBaseline(**env_kwargs())
    saver = SnapshotSaver(logger.get_dir())

    trpo(
        env=cartpole,
        env_maker=maker,
        n_envs=16,
        policy=actor,
        baseline=value_baseline,
        batch_size=2000,
        n_iters=100,
        snapshot_saver=saver,
    )
# Stand-alone launcher: train a categorical MLP policy on CartPole-v0 with
# TRPO, logging and saving snapshots under ``log_dir``.
import shutil

log_dir = "data/local/trpo-cartpole"

np.random.seed(42)

# Clean up existing logs. shutil.rmtree removes the directory without
# spawning a shell (so it also works where ``rm`` is unavailable);
# ignore_errors=True keeps ``rm -rf``'s tolerance of a missing directory.
shutil.rmtree(log_dir, ignore_errors=True)

with logger.session(log_dir):
    env_maker = EnvMaker('CartPole-v0')
    env = env_maker.make()
    # Policy and baseline are both built from the same environment's
    # observation space, action space and spec.
    policy = CategoricalMLPPolicy(
        observation_space=env.observation_space,
        action_space=env.action_space,
        env_spec=env.spec
    )
    baseline = MLPBaseline(
        observation_space=env.observation_space,
        action_space=env.action_space,
        env_spec=env.spec
    )
    trpo(
        env=env,
        env_maker=env_maker,
        n_envs=16,
        policy=policy,
        baseline=baseline,
        batch_size=2000,
        n_iters=100,
        snapshot_saver=SnapshotSaver(log_dir)
    )
# Stand-alone launcher: train a categorical MLP policy on CartPole-v0 with
# TRPO, logging and saving snapshots under ``log_dir``.
import shutil

log_dir = "data/local/trpo-cartpole"

np.random.seed(42)

# Clean up existing logs. shutil.rmtree removes the directory without
# spawning a shell (so it also works where ``rm`` is unavailable);
# ignore_errors=True keeps ``rm -rf``'s tolerance of a missing directory.
shutil.rmtree(log_dir, ignore_errors=True)

with logger.session(log_dir):
    env_maker = EnvMaker('CartPole-v0')
    env = env_maker.make()
    # Policy and baseline are both built from the same environment's
    # observation space, action space and spec.
    policy = CategoricalMLPPolicy(
        observation_space=env.observation_space,
        action_space=env.action_space,
        env_spec=env.spec
    )
    baseline = MLPBaseline(
        observation_space=env.observation_space,
        action_space=env.action_space,
        env_spec=env.spec
    )
    trpo(
        env=env,
        env_maker=env_maker,
        n_envs=16,
        policy=policy,
        baseline=baseline,
        batch_size=2000,
        n_iters=100,
        snapshot_saver=SnapshotSaver(log_dir)
    )