Exemplo n.º 1
0
 def setUp(self):
     """Build a plain env and a frame-stacked env for each test.

     Both wrap a ``DummyDiscrete2DEnv(random=False)`` in ``TfEnv``; the
     second additionally stacks ``n_frames`` observations via ``StackFrames``.
     """
     self.n_frames = 4
     self.env = TfEnv(DummyDiscrete2DEnv(random=False))
     stacked_env = StackFrames(DummyDiscrete2DEnv(random=False),
                               n_frames=self.n_frames)
     self.env_s = TfEnv(stacked_env)
     # The observation shape is unpacked into two values, so it is 2-D.
     obs_shape = self.env.observation_space.shape
     self.width, self.height = obs_shape
Exemplo n.º 2
0
 def setup_method(self):
     """Create a raw env and a copy resized to ``width`` x ``height`` (16x16)."""
     self.width = self.height = 16
     self.env = DummyDiscrete2DEnv()
     resized_env = Resize(DummyDiscrete2DEnv(),
                          width=self.width,
                          height=self.height)
     self.env_r = resized_env
Exemplo n.º 3
0
    def test_stack_frames_axis(self):
        """Frames are stacked along the axis passed to ``StackFrames``.

        For axis 0 and then axis 2, wrap a fresh non-random
        ``DummyDiscrete2DEnv`` in ``StackFrames``, reset it, take one step
        with action 1, and check that the observation's size along that axis
        equals ``n_frames``.
        """
        for stack_axis in (0, 2):
            stacked_env = StackFrames(DummyDiscrete2DEnv(random=False),
                                      n_frames=self.n_frames,
                                      axis=stack_axis)
            stacked_env.reset()
            # Unpack all four values so a step() with a different arity fails.
            observation, _, _, _ = stacked_env.step(1)
            assert observation.shape[stack_axis] == self.n_frames
Exemplo n.º 4
0
    def test_baseline(self):
        """Test the baseline initialization.

        Build two MLP baselines on a box env and a convolutional baseline on
        a discrete 2D env resized to 64x64, initialize all TF variables, then
        fetch each baseline's trainable parameter values.
        """
        box_env = TfEnv(DummyBoxEnv())
        # NOTE(review): the env object itself is passed as ``env_spec`` here
        # and for the conv baseline below; confirm this is intended rather
        # than a spec attribute of the env.
        mlp_baselines = [
            DeterministicMLPBaseline(env_spec=box_env),
            GaussianMLPBaseline(env_spec=box_env),
        ]

        conv_regressor_args = dict(
            conv_filters=[32, 32],
            conv_filter_sizes=[1, 1],
            conv_strides=[1, 1],
            conv_pads=["VALID", "VALID"],
            hidden_sizes=(32, 32),
        )
        discrete_env = TfEnv(Resize(DummyDiscrete2DEnv(), width=64, height=64))
        conv_baseline = GaussianConvBaseline(
            env_spec=discrete_env, regressor_args=conv_regressor_args)

        self.sess.run(tf.global_variables_initializer())
        # Same order as construction: deterministic MLP, Gaussian MLP, conv.
        for baseline in mlp_baselines + [conv_baseline]:
            baseline.get_param_values(trainable=True)
Exemplo n.º 5
0
 def setUp(self):
     """Create a TfEnv-wrapped raw env and a TfEnv-wrapped 16x16 resized env."""
     self.width = self.height = 16
     self.env = TfEnv(DummyDiscrete2DEnv())
     resized_env = Resize(DummyDiscrete2DEnv(),
                          width=self.width,
                          height=self.height)
     self.env_r = TfEnv(resized_env)
Exemplo n.º 6
0
 def setup_method(self):
     """Create a non-random env and a ``StackFrames`` wrapper of ``n_frames``."""
     self.n_frames = 4
     self.env = DummyDiscrete2DEnv(random=False)
     stacked_env = StackFrames(DummyDiscrete2DEnv(random=False),
                               n_frames=self.n_frames)
     self.env_s = stacked_env
     # The observation shape is unpacked into two values, so it is 2-D.
     obs_shape = self.env.observation_space.shape
     self.width, self.height = obs_shape
Exemplo n.º 7
0
 def test_invalid_axis_raises_error(self):
     """``StackFrames`` must raise ``ValueError`` when given ``axis=5``."""
     invalid_axis = 5
     # Construction stays inside the context so any ValueError it raises counts.
     with pytest.raises(ValueError):
         base_env = DummyDiscrete2DEnv(random=False)
         StackFrames(base_env, n_frames=self.n_frames, axis=invalid_axis)