Code Example #1: training routine for a project-specific LoggedSAC agent with two DiscreteMLPQFunction critics
# Note: a nested method snippet; self, gym_env, n_eps, MyGymEnv, LoggedSAC and
# global_device come from the enclosing project, not from garage itself.
def train_sac(ctxt=None):
    trainer = Trainer(ctxt)
    env = MyGymEnv(gym_env, max_episode_length=100)
    policy = CategoricalGRUPolicy(name='policy',
                                  env_spec=env.spec,
                                  state_include_action=False).to(
                                      global_device())
    qf1 = DiscreteMLPQFunction(env_spec=env.spec, hidden_sizes=(8, 5))
    qf2 = DiscreteMLPQFunction(env_spec=env.spec, hidden_sizes=(8, 5))
    replay_buffer = PathBuffer(capacity_in_transitions=int(1e6))
    sampler = LocalSampler(
        agents=policy,
        envs=env,
        max_episode_length=env.spec.max_episode_length,
        worker_class=FragmentWorker)
    self.algo = LoggedSAC(env=env,
                          env_spec=env.spec,
                          policy=policy,
                          qf1=qf1,
                          qf2=qf2,
                          sampler=sampler,
                          gradient_steps_per_itr=1000,
                          max_episode_length_eval=100,
                          replay_buffer=replay_buffer,
                          min_buffer_size=1e4,
                          target_update_tau=5e-3,
                          discount=0.99,
                          buffer_batch_size=256,
                          reward_scale=1.,
                          steps_per_epoch=1)
    trainer.setup(self.algo, env)
    trainer.train(n_epochs=n_eps, batch_size=4000)
    return self.algo.rew_chkpts
Code Example #2: test fixture that assembles a DQN agent for CartPole-v0
def setup():
    set_seed(24)
    n_epochs = 11
    steps_per_epoch = 10
    sampler_batch_size = 512
    num_timesteps = 100 * steps_per_epoch * sampler_batch_size

    env = GymEnv('CartPole-v0')

    replay_buffer = PathBuffer(capacity_in_transitions=int(1e6))

    qf = DiscreteMLPQFunction(env_spec=env.spec, hidden_sizes=(8, 5))

    policy = DiscreteQFArgmaxPolicy(env_spec=env.spec, qf=qf)
    exploration_policy = EpsilonGreedyPolicy(env_spec=env.spec,
                                             policy=policy,
                                             total_timesteps=num_timesteps,
                                             max_epsilon=1.0,
                                             min_epsilon=0.01,
                                             decay_ratio=0.4)
    algo = DQN(env_spec=env.spec,
               policy=policy,
               qf=qf,
               exploration_policy=exploration_policy,
               replay_buffer=replay_buffer,
               steps_per_epoch=steps_per_epoch,
               qf_lr=5e-5,
               discount=0.9,
               min_buffer_size=int(1e4),
               n_train_steps=500,
               target_update_freq=30,
               buffer_batch_size=64)

    return algo, env, replay_buffer, n_epochs, sampler_batch_size
Code Example #3: forward-pass test for DiscreteQFArgmaxPolicy on a batch of observations
def test_forward(batch_size):
    env_spec = GymEnv(DummyBoxEnv()).spec
    obs_dim = env_spec.observation_space.flat_dim
    obs = torch.ones([batch_size, obs_dim], dtype=torch.float32)
    qf = DiscreteMLPQFunction(env_spec=env_spec,
                              hidden_nonlinearity=None,
                              hidden_sizes=(2, 2))
    qvals = qf(obs)
    policy = DiscreteQFArgmaxPolicy(qf, env_spec)
    assert (policy(obs) == torch.argmax(qvals, dim=1)).all()
    assert policy(obs).shape == (batch_size, )
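Examples #3 through #6, #8 and #9 are pytest-style unit tests shown without their imports. The following is a minimal sketch of what they assume; the DummyBoxEnv path follows the garage test fixtures and, like the other module paths, is an assumption that may differ between versions.

# Imports assumed by the surrounding test snippets (a sketch; paths follow the
# garage source tree and its test fixtures, and may vary by version).
import pickle

import numpy as np
import torch
from torch import nn

from garage.envs import GymEnv
from garage.torch.policies import DiscreteQFArgmaxPolicy
from garage.torch.q_functions import DiscreteMLPQFunction
from tests.fixtures.envs.dummy import DummyBoxEnv  # garage test fixture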
Code Example #4: output-shape test for DiscreteMLPQFunction
def test_output_shape(batch_size, hidden_sizes):
    env_spec = GymEnv(DummyBoxEnv()).spec
    obs_dim = env_spec.observation_space.flat_dim
    obs = torch.ones(batch_size, obs_dim, dtype=torch.float32)

    qf = DiscreteMLPQFunction(env_spec=env_spec,
                              hidden_nonlinearity=None,
                              hidden_sizes=hidden_sizes,
                              hidden_w_init=nn.init.ones_,
                              output_w_init=nn.init.ones_)
    output = qf(obs)

    assert output.shape == (batch_size, env_spec.action_space.flat_dim)
Code Example #5: get_action test for DiscreteQFArgmaxPolicy on a single observation
def test_get_action():
    env_spec = GymEnv(DummyBoxEnv()).spec
    obs_dim = env_spec.observation_space.flat_dim
    obs = torch.ones([obs_dim], dtype=torch.float32)
    qf = DiscreteMLPQFunction(env_spec=env_spec,
                              hidden_nonlinearity=None,
                              hidden_sizes=(2, 2))
    qvals = qf(obs.unsqueeze(0))
    policy = DiscreteQFArgmaxPolicy(qf, env_spec)
    action, _ = policy.get_action(obs.numpy())
    assert action == torch.argmax(qvals, dim=1).numpy()
    assert action.shape == ()
Code Example #6: pickling test for DiscreteQFArgmaxPolicy
def test_is_pickleable(batch_size):
    env_spec = GymEnv(DummyBoxEnv()).spec
    obs_dim = env_spec.observation_space.flat_dim
    obs = torch.ones([batch_size, obs_dim], dtype=torch.float32)
    qf = DiscreteMLPQFunction(env_spec=env_spec,
                              hidden_nonlinearity=None,
                              hidden_sizes=(2, 2))
    policy = DiscreteQFArgmaxPolicy(qf, env_spec)

    output1 = policy.get_actions(obs.numpy())[0]

    p = pickle.dumps(policy)
    policy_pickled = pickle.loads(p)
    output2 = policy_pickled.get_actions(obs.numpy())[0]
    assert np.array_equal(output1, output2)
Code Example #7: standalone DQN training script for CartPole-v0
def dqn_cartpole(ctxt=None, seed=24):
    """Train DQN with CartPole-v0 environment.

    Args:
        ctxt (garage.experiment.ExperimentContext): The experiment
            configuration used by Trainer to create the snapshotter.
        seed (int): Used to seed the random number generator to produce
            determinism.
    """
    set_seed(seed)
    runner = Trainer(ctxt)

    n_epochs = 100
    steps_per_epoch = 10
    sampler_batch_size = 512
    num_timesteps = n_epochs * steps_per_epoch * sampler_batch_size
    env = GymEnv('CartPole-v0')
    replay_buffer = PathBuffer(capacity_in_transitions=int(1e6))
    qf = DiscreteMLPQFunction(env_spec=env.spec, hidden_sizes=(8, 5))
    policy = DiscreteQFArgmaxPolicy(env_spec=env.spec, qf=qf)
    exploration_policy = EpsilonGreedyPolicy(env_spec=env.spec,
                                             policy=policy,
                                             total_timesteps=num_timesteps,
                                             max_epsilon=1.0,
                                             min_epsilon=0.01,
                                             decay_ratio=0.4)
    sampler = LocalSampler(agents=exploration_policy,
                           envs=env,
                           max_episode_length=env.spec.max_episode_length,
                           worker_class=FragmentWorker)
    algo = DQN(env_spec=env.spec,
               policy=policy,
               qf=qf,
               exploration_policy=exploration_policy,
               replay_buffer=replay_buffer,
               sampler=sampler,
               steps_per_epoch=steps_per_epoch,
               qf_lr=5e-5,
               discount=0.9,
               min_buffer_size=int(1e4),
               n_train_steps=500,
               target_update_freq=30,
               buffer_batch_size=64)

    runner.setup(algo, env)
    runner.train(n_epochs=n_epochs, batch_size=sampler_batch_size)

    env.close()
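Example #7 is self-contained apart from its imports and the experiment wrapper that supplies ctxt. The sketch below is modeled on the garage PyTorch examples and shows what the function assumes; exact module paths may differ between garage releases.

# Imports assumed by dqn_cartpole (a sketch based on the garage PyTorch
# examples; module paths may vary by garage version).
from garage import wrap_experiment
from garage.envs import GymEnv
from garage.experiment.deterministic import set_seed
from garage.np.exploration_policies import EpsilonGreedyPolicy
from garage.replay_buffer import PathBuffer
from garage.sampler import FragmentWorker, LocalSampler
from garage.torch.algos import DQN
from garage.torch.policies import DiscreteQFArgmaxPolicy
from garage.torch.q_functions import DiscreteMLPQFunction
from garage.trainer import Trainer

# ctxt is normally supplied by wrapping the function, e.g.:
#     dqn_cartpole = wrap_experiment(dqn_cartpole)
#     dqn_cartpole(seed=24)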
Code Example #8: forward-pass value test for DiscreteMLPQFunction with all-ones weights
def test_forward(hidden_sizes):
    env_spec = GymEnv(DummyBoxEnv()).spec
    obs_dim = env_spec.observation_space.flat_dim
    obs = torch.ones(obs_dim, dtype=torch.float32).unsqueeze(0)

    qf = DiscreteMLPQFunction(env_spec=env_spec,
                              hidden_nonlinearity=None,
                              hidden_sizes=hidden_sizes,
                              hidden_w_init=nn.init.ones_,
                              output_w_init=nn.init.ones_)

    output = qf(obs)

    # With all-ones inputs and weights, zero-initialized biases (the default
    # here) and no nonlinearity, each layer outputs the sum of the previous
    # layer, so the Q-value is obs_dim * prod(hidden_sizes).
    expected_output = torch.full([1, 1],
                                 fill_value=obs_dim * np.prod(hidden_sizes),
                                 dtype=torch.float32)
    assert torch.eq(output, expected_output).all()
Code Example #9: pickling test for DiscreteMLPQFunction
def test_is_pickleable(hidden_sizes):
    env_spec = GymEnv(DummyBoxEnv()).spec
    obs_dim = env_spec.observation_space.flat_dim
    obs = torch.ones(obs_dim, dtype=torch.float32).unsqueeze(0)

    qf = DiscreteMLPQFunction(env_spec=env_spec,
                              hidden_nonlinearity=None,
                              hidden_sizes=hidden_sizes,
                              hidden_w_init=nn.init.ones_,
                              output_w_init=nn.init.ones_)

    output1 = qf(obs)

    p = pickle.dumps(qf)
    qf_pickled = pickle.loads(p)
    output2 = qf_pickled(obs)

    assert torch.eq(output1, output2).all()
Code Example #10: DQN training routine using a project-specific LoggedDQN
File: DQNTutor.py    Project: ManavR123/cs_285_project
# Note: a nested method snippet from the project above; self, seed, gym_env,
# n_eps, MyGymEnv and LoggedDQN come from the enclosing scope.
def train_dqn(ctxt=None):
    set_seed(seed)
    trainer = Trainer(ctxt)
    env = MyGymEnv(gym_env, max_episode_length=100)
    steps_per_epoch = 10
    sampler_batch_size = 4000
    num_timesteps = n_eps * steps_per_epoch * sampler_batch_size
    replay_buffer = PathBuffer(capacity_in_transitions=int(1e6))
    qf = DiscreteMLPQFunction(env_spec=env.spec, hidden_sizes=(8, 5))
    policy = DiscreteQFArgmaxPolicy(env_spec=env.spec, qf=qf)
    exploration_policy = EpsilonGreedyPolicy(
        env_spec=env.spec,
        policy=policy,
        total_timesteps=num_timesteps,
        max_epsilon=1.0,
        min_epsilon=0.01,
        decay_ratio=0.4,
    )
    sampler = LocalSampler(
        agents=exploration_policy,
        envs=env,
        max_episode_length=env.spec.max_episode_length,
        worker_class=FragmentWorker,
    )
    self.algo = LoggedDQN(
        env=env,
        env_spec=env.spec,
        policy=policy,
        qf=qf,
        exploration_policy=exploration_policy,
        replay_buffer=replay_buffer,
        sampler=sampler,
        steps_per_epoch=steps_per_epoch,
        qf_lr=5e-5,
        discount=0.99,
        min_buffer_size=int(1e4),
        n_train_steps=500,
        target_update_freq=30,
        buffer_batch_size=64,
    )
    trainer.setup(self.algo, env)
    trainer.train(n_epochs=n_eps, batch_size=sampler_batch_size)

    return self.algo.rew_chkpts