예제 #1
0
 def setup_method(self):
     """Build the env, Q-function and argmax policy used by the tests.

     Runs the parent fixture first, then wraps ``DummyDiscreteEnv`` in a
     ``GymEnv``, builds a ``SimpleQFunction`` and a
     ``DiscreteQFArgmaxPolicy`` from the env spec, runs the global
     variable initializer in ``self.sess`` and resets the env.
     """
     super().setup_method()
     env = GymEnv(DummyDiscreteEnv())
     self.env = env
     self.qf = SimpleQFunction(env.spec)
     self.policy = DiscreteQFArgmaxPolicy(qf=self.qf, env_spec=env.spec)
     init_op = tf.compat.v1.global_variables_initializer()
     self.sess.run(init_op)
     env.reset()
예제 #2
0
 def setup_method(self):
     """Build an image env, a CNN Q-function and the argmax policy.

     Runs the parent fixture first, then wraps
     ``DummyDiscretePixelEnvBaselines`` in ``AtariEnv`` and an image
     ``GymEnv``, builds a non-dueling ``DiscreteCNNQFunction`` with a
     single 1x1 filter and a ``DiscreteQFArgmaxPolicy`` on top of it,
     runs the global variable initializer in ``self.sess`` and resets
     the env.
     """
     super().setup_method()
     pixel_env = AtariEnv(DummyDiscretePixelEnvBaselines())
     self.env = GymEnv(pixel_env, is_image=True)
     cnn_kwargs = dict(filters=((1, (1, 1)), ),
                       strides=(1, ),
                       dueling=False)
     self.qf = DiscreteCNNQFunction(env_spec=self.env.spec, **cnn_kwargs)
     self.policy = DiscreteQFArgmaxPolicy(qf=self.qf,
                                          env_spec=self.env.spec)
     init_op = tf.compat.v1.global_variables_initializer()
     self.sess.run(init_op)
     self.env.reset()
예제 #3
0
class TestQfDerivedPolicy(TfGraphTestCase):
    """Tests for DiscreteQFArgmaxPolicy built on a SimpleQFunction."""

    def setup_method(self):
        """Create the env, Q-function and policy shared by every test."""
        super().setup_method()
        self.env = GymEnv(DummyDiscreteEnv())
        self.qf = SimpleQFunction(self.env.spec)
        self.policy = DiscreteQFArgmaxPolicy(env_spec=self.env.spec,
                                             qf=self.qf)
        self.sess.run(tf.compat.v1.global_variables_initializer())
        self.env.reset()

    def test_discrete_qf_argmax_policy(self):
        """Test that single and batched actions lie in the action space."""
        obs = self.env.step(1).observation
        action, _ = self.policy.get_action(obs)
        assert self.env.action_space.contains(action)
        actions, _ = self.policy.get_actions([obs])
        # Use a distinct loop name so the single `action` checked above
        # is not shadowed.
        for batched_action in actions:
            assert self.env.action_space.contains(batched_action)

    def test_get_param(self):
        """Test that policy params match the Q-function's `return_var`."""
        with tf.compat.v1.variable_scope('SimpleQFunction', reuse=True):
            return_var = tf.compat.v1.get_variable('return_var')
        assert self.policy.get_param_values() == return_var.eval()

    def test_is_pickleable(self):
        """Test that a pickled-and-restored policy picks the same action."""
        with tf.compat.v1.variable_scope('SimpleQFunction', reuse=True):
            return_var = tf.compat.v1.get_variable('return_var')
        # Overwrite the variable with ones before pickling.
        return_var.load(tf.ones_like(return_var).eval())
        obs = self.env.step(1).observation
        action1, _ = self.policy.get_action(obs)

        p = pickle.dumps(self.policy)
        # Restore under a fresh graph and session so nothing is shared
        # with the graph the policy was originally built in.
        with tf.compat.v1.Session(graph=tf.Graph()):
            policy_pickled = pickle.loads(p)
            action2, _ = policy_pickled.get_action(obs)
            assert action1 == action2

    def test_does_not_support_dict_obs_space(self):
        """Test that policy raises error if passed a dict obs space."""
        env = GymEnv(DummyDictEnv(act_space_type='discrete'))
        with pytest.raises(ValueError):
            qf = SimpleQFunction(env.spec,
                                 name='does_not_support_dict_obs_space')
            DiscreteQFArgmaxPolicy(env_spec=env.spec, qf=qf)

    def test_invalid_action_spaces(self):
        """Test that policy raises error if passed a 'box' action space."""
        env = GymEnv(DummyDictEnv(act_space_type='box'))
        with pytest.raises(ValueError):
            qf = SimpleQFunction(env.spec)
            DiscreteQFArgmaxPolicy(env_spec=env.spec, qf=qf)
예제 #4
0
 def test_does_not_support_dict_obs_space(self):
     """Check that a dict observation space is rejected with ValueError.

     Both the Q-function and the policy are constructed inside the
     ``pytest.raises`` block, so a ``ValueError`` from either one
     satisfies the test.
     """
     dict_env = GymEnv(DummyDictEnv(act_space_type='discrete'))
     qf_name = 'does_not_support_dict_obs_space'
     with pytest.raises(ValueError):
         q_function = SimpleQFunction(dict_env.spec, name=qf_name)
         DiscreteQFArgmaxPolicy(qf=q_function, env_spec=dict_env.spec)
예제 #5
0
    def test_dqn_cartpole_pickle(self):
        """Test that DQN on CartPole keeps its weights across pickling.

        The ``bias`` variable under the
        ``DiscreteMLPQFunction/mlp/hidden_0`` scope is overwritten with
        ones, the algorithm is pickled, then unpickled under a fresh
        graph and session, and the restored bias is compared with the
        value that was written.
        """
        deterministic.set_seed(100)
        with TFTrainer(snapshot_config, sess=self.sess) as trainer:
            n_epochs = 10
            steps_per_epoch = 10
            sampler_batch_size = 500
            num_timesteps = n_epochs * steps_per_epoch * sampler_batch_size
            bias_scope = 'DiscreteMLPQFunction/mlp/hidden_0'
            env = GymEnv('CartPole-v0')
            # try/finally so the env is closed even when setup, pickling
            # or the assertion below raises.
            try:
                replay_buffer = PathBuffer(capacity_in_transitions=int(1e4))
                qf = DiscreteMLPQFunction(env_spec=env.spec,
                                          hidden_sizes=(64, 64))
                policy = DiscreteQFArgmaxPolicy(env_spec=env.spec, qf=qf)
                epsilon_greedy_policy = EpsilonGreedyPolicy(
                    env_spec=env.spec,
                    policy=policy,
                    total_timesteps=num_timesteps,
                    max_epsilon=1.0,
                    min_epsilon=0.02,
                    decay_ratio=0.1)
                sampler = LocalSampler(
                    agents=epsilon_greedy_policy,
                    envs=env,
                    max_episode_length=env.spec.max_episode_length,
                    is_tf_worker=True,
                    worker_class=FragmentWorker)
                algo = DQN(env_spec=env.spec,
                           policy=policy,
                           qf=qf,
                           exploration_policy=epsilon_greedy_policy,
                           replay_buffer=replay_buffer,
                           sampler=sampler,
                           qf_lr=1e-4,
                           discount=1.0,
                           min_buffer_size=int(1e3),
                           double_q=False,
                           n_train_steps=500,
                           grad_norm_clipping=5.0,
                           steps_per_epoch=steps_per_epoch,
                           target_network_update_freq=1,
                           buffer_batch_size=32)
                trainer.setup(algo, env)
                with tf.compat.v1.variable_scope(bias_scope, reuse=True):
                    bias = tf.compat.v1.get_variable('bias')
                    # Overwrite the bias with ones so the restored value
                    # can be compared against a known array.
                    old_bias = tf.ones_like(bias).eval()
                    bias.load(old_bias)
                    h = pickle.dumps(algo)

                # Restore under a fresh graph and session.
                with tf.compat.v1.Session(graph=tf.Graph()):
                    pickle.loads(h)
                    with tf.compat.v1.variable_scope(bias_scope, reuse=True):
                        new_bias = tf.compat.v1.get_variable('bias')
                        new_bias = new_bias.eval()
                        assert np.array_equal(old_bias, new_bias)
            finally:
                env.close()
예제 #6
0
def dqn_cartpole(ctxt=None, seed=1):
    """Train DQN on the CartPole-v0 environment.

    Builds a double-Q DQN with a two-layer MLP Q-function and an
    epsilon-greedy exploration policy, then trains it for 10 epochs.

    Args:
        ctxt (garage.experiment.ExperimentContext): The experiment
            configuration used by Trainer to create the snapshotter.
        seed (int): Used to seed the random number generator to produce
            determinism.

    """
    set_seed(seed)
    with TFTrainer(ctxt) as trainer:
        n_epochs = 10
        steps_per_epoch = 10
        sampler_batch_size = 500
        # Total env steps over the run; passed to the exploration policy
        # as `total_timesteps`.
        num_timesteps = n_epochs * steps_per_epoch * sampler_batch_size
        env = GymEnv('CartPole-v0')
        replay_buffer = PathBuffer(capacity_in_transitions=int(1e4))
        qf = DiscreteMLPQFunction(env_spec=env.spec, hidden_sizes=(64, 64))
        policy = DiscreteQFArgmaxPolicy(env_spec=env.spec, qf=qf)
        exploration_policy = EpsilonGreedyPolicy(env_spec=env.spec,
                                                 policy=policy,
                                                 total_timesteps=num_timesteps,
                                                 max_epsilon=1.0,
                                                 min_epsilon=0.02,
                                                 decay_ratio=0.1)

        sampler = LocalSampler(agents=exploration_policy,
                               envs=env,
                               max_episode_length=env.spec.max_episode_length,
                               is_tf_worker=True,
                               worker_class=FragmentWorker)

        algo = DQN(env_spec=env.spec,
                   policy=policy,
                   qf=qf,
                   exploration_policy=exploration_policy,
                   replay_buffer=replay_buffer,
                   sampler=sampler,
                   steps_per_epoch=steps_per_epoch,
                   qf_lr=1e-4,
                   discount=1.0,
                   min_buffer_size=int(1e3),
                   double_q=True,
                   n_train_steps=500,
                   target_network_update_freq=1,
                   buffer_batch_size=32)

        trainer.setup(algo, env)
        trainer.train(n_epochs=n_epochs, batch_size=sampler_batch_size)
예제 #7
0
    def test_dqn_cartpole_grad_clip(self):
        """Test DQN on CartPole with gradient-norm clipping enabled.

        Trains with ``grad_norm_clipping=5.0`` and asserts that the value
        returned by ``trainer.train`` exceeds 8.8.
        """
        deterministic.set_seed(100)
        with TFTrainer(snapshot_config, sess=self.sess) as trainer:
            n_epochs = 10
            steps_per_epoch = 10
            sampler_batch_size = 500
            num_timesteps = n_epochs * steps_per_epoch * sampler_batch_size
            env = GymEnv('CartPole-v0')
            # try/finally so the env is closed even when training or the
            # return assertion fails.
            try:
                replay_buffer = PathBuffer(capacity_in_transitions=int(1e4))
                qf = DiscreteMLPQFunction(env_spec=env.spec,
                                          hidden_sizes=(64, 64))
                policy = DiscreteQFArgmaxPolicy(env_spec=env.spec, qf=qf)
                epsilon_greedy_policy = EpsilonGreedyPolicy(
                    env_spec=env.spec,
                    policy=policy,
                    total_timesteps=num_timesteps,
                    max_epsilon=1.0,
                    min_epsilon=0.02,
                    decay_ratio=0.1)
                sampler = LocalSampler(
                    agents=epsilon_greedy_policy,
                    envs=env,
                    max_episode_length=env.spec.max_episode_length,
                    is_tf_worker=True,
                    worker_class=FragmentWorker)
                algo = DQN(env_spec=env.spec,
                           policy=policy,
                           qf=qf,
                           exploration_policy=epsilon_greedy_policy,
                           replay_buffer=replay_buffer,
                           sampler=sampler,
                           qf_lr=1e-4,
                           discount=1.0,
                           min_buffer_size=int(1e3),
                           double_q=False,
                           n_train_steps=500,
                           grad_norm_clipping=5.0,
                           steps_per_epoch=steps_per_epoch,
                           target_network_update_freq=1,
                           buffer_batch_size=32)

                trainer.setup(algo, env)
                last_avg_ret = trainer.train(n_epochs=n_epochs,
                                             batch_size=sampler_batch_size)
                assert last_avg_ret > 8.8
            finally:
                env.close()
예제 #8
0
class TestQfDerivedPolicyImageObs(TfGraphTestCase):
    """Tests for DiscreteQFArgmaxPolicy on an image-observation env."""

    def setup_method(self):
        """Build the image env, CNN Q-function and argmax policy."""
        super().setup_method()
        pixel_env = AtariEnv(DummyDiscretePixelEnvBaselines())
        self.env = GymEnv(pixel_env, is_image=True)
        cnn_kwargs = dict(filters=((1, (1, 1)), ),
                          strides=(1, ),
                          dueling=False)
        self.qf = DiscreteCNNQFunction(env_spec=self.env.spec, **cnn_kwargs)
        self.policy = DiscreteQFArgmaxPolicy(qf=self.qf,
                                             env_spec=self.env.spec)
        init_op = tf.compat.v1.global_variables_initializer()
        self.sess.run(init_op)
        self.env.reset()

    def test_obs_unflattened(self):
        """Test that get_action accepts a flattened image observation.

        A sampled observation is flattened before being handed to the
        policy; the returned action is then stepped through the env.
        """
        sampled_obs = self.env.observation_space.sample()
        flat_obs = self.env.observation_space.flatten(sampled_obs)
        chosen_action, _ = self.policy.get_action(flat_obs)
        self.env.step(chosen_action)
0
 def test_invalid_action_spaces(self):
     """Test that policy raises error if passed a 'box' action space.

     Both the Q-function and the policy are constructed inside the
     ``pytest.raises`` block, so a ``ValueError`` from either one
     satisfies the test.
     """
     env = GymEnv(DummyDictEnv(act_space_type='box'))
     with pytest.raises(ValueError):
         qf = SimpleQFunction(env.spec)
         DiscreteQFArgmaxPolicy(env_spec=env.spec, qf=qf)
예제 #10
0
def dqn_pong(ctxt=None, seed=1, buffer_size=int(5e4), max_episode_length=500):
    """Run DQN training on the PongNoFrameskip-v4 environment.

    Args:
        ctxt (garage.experiment.ExperimentContext): Experiment
            configuration that Trainer uses to create the snapshotter.
        seed (int): Seed for the random number generator, for
            determinism.
        buffer_size (int): How many timesteps the replay buffer holds.
        max_episode_length (int): Length after which an episode is
            treated as complete. Used during testing to limit the memory
            needed to store a single episode.

    """
    set_seed(seed)
    with TFTrainer(ctxt) as trainer:
        n_epochs = 100
        steps_per_epoch = 20
        sampler_batch_size = 500
        num_timesteps = n_epochs * steps_per_epoch * sampler_batch_size

        # Apply the preprocessing wrappers to the raw env, innermost
        # first, in the same order throughout.
        wrapped = gym.make('PongNoFrameskip-v4').unwrapped
        wrapped = Noop(wrapped, noop_max=30)
        wrapped = MaxAndSkip(wrapped, skip=4)
        wrapped = EpisodicLife(wrapped)
        action_meanings = wrapped.unwrapped.get_action_meanings()
        if 'FIRE' in action_meanings:
            wrapped = FireReset(wrapped)
        wrapped = Grayscale(wrapped)
        wrapped = Resize(wrapped, 84, 84)
        wrapped = ClipReward(wrapped)
        wrapped = StackFrames(wrapped, 4)

        env = GymEnv(wrapped,
                     is_image=True,
                     max_episode_length=max_episode_length)

        replay_buffer = PathBuffer(capacity_in_transitions=buffer_size)

        # One (num_filters, kernel_shape) pair per conv layer, with a
        # matching stride for each.
        conv_filters = ((32, (8, 8)), (64, (4, 4)), (64, (3, 3)))
        conv_strides = (4, 2, 1)
        qf = DiscreteCNNQFunction(env_spec=env.spec,
                                  filters=conv_filters,
                                  strides=conv_strides,
                                  dueling=False)

        policy = DiscreteQFArgmaxPolicy(env_spec=env.spec, qf=qf)
        exploration_policy = EpsilonGreedyPolicy(
            env_spec=env.spec,
            policy=policy,
            total_timesteps=num_timesteps,
            max_epsilon=1.0,
            min_epsilon=0.02,
            decay_ratio=0.1)

        dqn_settings = dict(qf_lr=1e-4,
                            discount=0.99,
                            min_buffer_size=int(1e4),
                            double_q=False,
                            n_train_steps=500,
                            steps_per_epoch=steps_per_epoch,
                            target_network_update_freq=2,
                            buffer_batch_size=32)
        algo = DQN(env_spec=env.spec,
                   policy=policy,
                   qf=qf,
                   exploration_policy=exploration_policy,
                   replay_buffer=replay_buffer,
                   **dqn_settings)

        trainer.setup(algo, env)
        trainer.train(n_epochs=n_epochs, batch_size=sampler_batch_size)