示例#1
0
 def setup_method(self):
     """Create a raw discrete 2D env and a separately resized copy.

     ``self.env`` and ``self.env_r`` wrap two distinct
     ``DummyDiscrete2DEnv`` instances, so a test that mutates
     ``self.env`` leaves ``self.env_r`` untouched.
     """
     self.width, self.height = 16, 16
     self.env = DummyDiscrete2DEnv()
     inner_env = DummyDiscrete2DEnv()
     self.env_r = Resize(inner_env, width=self.width, height=self.height)
示例#2
0
 def test_resize_invalid_environment_shape(self):
     """Resize must raise ValueError for a 1-D Box space of shape ``(4,)``."""
     # Prepare the invalid observation space *before* entering assertRaises:
     # only the Resize constructor should be able to satisfy the expected
     # ValueError, otherwise a failure during setup would pass silently.
     self.env.observation_space = Box(low=0,
                                      high=255,
                                      shape=(4, ),
                                      dtype=np.uint8)
     with self.assertRaises(ValueError):
         Resize(self.env, width=self.width, height=self.height)
示例#3
0
    def setUp(self):
        """Wrap a mocked 50x50 uint8-Box environment in a 16x16 Resize.

        The mock's ``reset`` returns an all-zero array and its ``step``
        delegates to ``self._step``. The raw and resized observations from
        an initial reset are cached as ``self.obs`` / ``self.obs_r``.
        """
        self.shape = (50, 50)

        fake_env = mock.Mock()
        fake_env.observation_space = Box(low=0, high=255, shape=self.shape,
                                         dtype=np.uint8)
        fake_env.reset.return_value = np.zeros(self.shape)
        fake_env.step.side_effect = self._step
        self.env = fake_env

        self._width = self._height = 16
        self.env_r = Resize(fake_env, width=self._width, height=self._height)

        # Left-to-right evaluation: the raw env is reset before the wrapper.
        self.obs, self.obs_r = fake_env.reset(), self.env_r.reset()
示例#4
0
class TestResize(unittest.TestCase):
    """Unit tests for the ``Resize`` observation wrapper.

    A ``mock.Mock`` stands in for the wrapped environment. It exposes a
    50x50 ``uint8`` ``Box`` observation space, returns zeros from ``reset``
    and a constant-125 frame from ``step``.
    """

    @overrides
    def setUp(self):
        """Build the mocked env, wrap it in ``Resize`` and reset both."""
        self.shape = (50, 50)
        self.env = mock.Mock()
        self.env.observation_space = Box(low=0,
                                         high=255,
                                         shape=self.shape,
                                         dtype=np.uint8)
        self.env.reset.return_value = np.zeros(self.shape)
        self.env.step.side_effect = self._step

        self._width = 16
        self._height = 16
        self.env_r = Resize(self.env, width=self._width, height=self._height)

        self.obs = self.env.reset()
        self.obs_r = self.env_r.reset()

    def _step(self, action):
        """Stand-in for ``env.step``: constant frame, zero reward, not done.

        ``action`` is ignored.
        """
        return np.full(self.shape, 125), 0, False, {}

    def test_resize_invalid_environment_type(self):
        """Resize must raise ValueError for a ``Discrete`` observation space."""
        # Only the constructor under test belongs inside assertRaises; a
        # ValueError raised while preparing the env would otherwise make the
        # test pass without exercising Resize at all.
        self.env.observation_space = Discrete(64)
        with self.assertRaises(ValueError):
            Resize(self.env, width=self._width, height=self._height)

    def test_resize_invalid_environment_shape(self):
        """Resize must raise ValueError for a 1-D Box space of shape ``(4,)``."""
        self.env.observation_space = Box(low=0,
                                         high=255,
                                         shape=(4, ),
                                         dtype=np.uint8)
        with self.assertRaises(ValueError):
            Resize(self.env, width=self._width, height=self._height)

    # The checks below use assertEqual rather than bare ``assert``: under
    # ``python -O`` assert statements are stripped and the tests would pass
    # vacuously when run by the plain unittest runner.
    def test_resize_output_observation_space(self):
        """The wrapped observation space has the requested 2D shape."""
        self.assertEqual(self.env_r.observation_space.shape,
                         (self._width, self._height))

    def test_resize_output_reset(self):
        """``reset`` through the wrapper yields a resized observation."""
        self.assertEqual(self.obs_r.shape, (self._width, self._height))

    def test_resize_output_step(self):
        """``step`` through the wrapper yields a resized observation."""
        obs_r, _, _, _ = self.env_r.step(0)

        self.assertEqual(obs_r.shape, (self._width, self._height))
示例#5
0
class TestResize:
    """pytest tests for ``Resize`` wrapping a ``DummyDiscrete2DEnv``."""

    def setup_method(self):
        """Create a raw env and an independent resized env for each test."""
        self.width = 16
        self.height = 16
        self.env = DummyDiscrete2DEnv()
        self.env_r = Resize(
            DummyDiscrete2DEnv(), width=self.width, height=self.height)

    def teardown_method(self):
        """Close both environments created in ``setup_method``."""
        self.env.close()
        self.env_r.close()

    def test_resize_invalid_environment_type(self):
        """Resize must raise ValueError for a ``Discrete`` observation space."""
        # Keep only the call under test inside pytest.raises so a ValueError
        # during setup cannot masquerade as the expected failure.
        self.env.observation_space = Discrete(64)
        with pytest.raises(ValueError):
            Resize(self.env, width=self.width, height=self.height)

    def test_resize_invalid_environment_shape(self):
        """Resize must raise ValueError for a 1-D Box space of shape ``(4,)``."""
        self.env.observation_space = Box(
            low=0, high=255, shape=(4, ), dtype=np.uint8)
        with pytest.raises(ValueError):
            Resize(self.env, width=self.width, height=self.height)

    def test_resize_output_observation_space(self):
        """The wrapped observation space has the requested 2D shape."""
        assert self.env_r.observation_space.shape == (self.width, self.height)

    def test_resize_output_reset(self):
        """``reset`` through the wrapper yields a resized observation."""
        assert self.env_r.reset().shape == (self.width, self.height)

    def test_resize_output_step(self):
        """``step`` through the wrapper yields a resized observation."""
        self.env_r.reset()
        obs_r, _, _, _ = self.env_r.step(1)
        assert obs_r.shape == (self.width, self.height)
示例#6
0
    def test_baseline(self):
        """Smoke-test construction of MLP and conv baselines.

        Builds deterministic and Gaussian MLP baselines on a Box env plus a
        Gaussian conv baseline on a discrete 2D env resized to 64x64, runs
        the TF global variable initializer, then reads back each baseline's
        trainable parameters.
        """
        box_env = TfEnv(DummyBoxEnv())
        # List literal elements are evaluated in order, so construction order
        # (deterministic first, then Gaussian) matches the original.
        mlp_baselines = [
            DeterministicMLPBaseline(env_spec=box_env),
            GaussianMLPBaseline(env_spec=box_env),
        ]

        discrete_env = TfEnv(Resize(DummyDiscrete2DEnv(), width=64, height=64))
        conv_regressor_args = dict(
            conv_filters=[32, 32],
            conv_filter_sizes=[1, 1],
            conv_strides=[1, 1],
            conv_pads=["VALID", "VALID"],
            hidden_sizes=(32, 32),
        )
        conv_baseline = GaussianConvBaseline(env_spec=discrete_env,
                                             regressor_args=conv_regressor_args)

        # Variables must exist before initialization, hence this ordering.
        self.sess.run(tf.global_variables_initializer())
        for baseline in mlp_baselines + [conv_baseline]:
            baseline.get_param_values(trainable=True)
示例#7
0
 def test_resize_invalid_environment_type(self):
     """Resize must raise ValueError for a ``Discrete`` observation space."""
     # Prepare the invalid env before entering assertRaises so that only the
     # Resize constructor can produce the expected ValueError.
     self.env.observation_space = Discrete(64)
     with self.assertRaises(ValueError):
         Resize(self.env, width=self.width, height=self.height)
示例#8
0
 def setUp(self):
     """Create TfEnv-wrapped raw and resized discrete 2D environments.

     The raw env and the resized env wrap two distinct
     ``DummyDiscrete2DEnv`` instances; the resized one is 16x16.
     """
     self.width, self.height = 16, 16
     self.env = TfEnv(DummyDiscrete2DEnv())
     resized_env = Resize(DummyDiscrete2DEnv(),
                          width=self.width,
                          height=self.height)
     self.env_r = TfEnv(resized_env)