def setup_method(self):
    """Build a plain env and a Resize-wrapped env before each test.

    ``self.env_r`` wraps its own, separately constructed
    ``DummyDiscrete2DEnv`` instance rather than ``self.env``.
    """
    self.width, self.height = 16, 16
    self.env = DummyDiscrete2DEnv()
    inner_env = DummyDiscrete2DEnv()
    self.env_r = Resize(inner_env, width=self.width, height=self.height)
def test_resize_invalid_environment_shape(self):
    """Resize must raise ValueError for a 1-D Box observation space."""
    # Replace the space outside the assertRaises block so that the expected
    # ValueError can only come from the Resize constructor, not from the
    # fixture mutation itself.
    self.env.observation_space = Box(
        low=0, high=255, shape=(4, ), dtype=np.uint8)
    with self.assertRaises(ValueError):
        Resize(self.env, width=self.width, height=self.height)
def setUp(self):
    """Wrap a mocked 50x50 uint8 Box environment in Resize and reset both.

    The mock's ``reset`` returns an all-zero array of ``self.shape`` and its
    ``step`` is routed to ``self._step``.
    """
    self.shape = (50, 50)
    fake_env = mock.Mock()
    self.env = fake_env
    fake_env.observation_space = Box(
        low=0, high=255, shape=self.shape, dtype=np.uint8)
    fake_env.reset.return_value = np.zeros(self.shape)
    fake_env.step.side_effect = self._step
    self._width = self._height = 16
    self.env_r = Resize(fake_env, width=self._width, height=self._height)
    self.obs = fake_env.reset()
    self.obs_r = self.env_r.reset()
class TestResize(unittest.TestCase):
    """Unit tests for the Resize wrapper against a mocked 50x50 Box env."""

    @overrides
    def setUp(self):
        """Wrap a mocked 50x50 uint8 Box env in Resize and reset both."""
        self.shape = (50, 50)
        self.env = mock.Mock()
        self.env.observation_space = Box(
            low=0, high=255, shape=self.shape, dtype=np.uint8)
        self.env.reset.return_value = np.zeros(self.shape)
        self.env.step.side_effect = self._step
        self._width = 16
        self._height = 16
        self.env_r = Resize(self.env, width=self._width, height=self._height)
        self.obs = self.env.reset()
        self.obs_r = self.env_r.reset()

    def _step(self, action):
        """Stand-in for ``env.step``: a constant 125-valued observation."""
        return np.full(self.shape, 125), 0, False, {}

    def test_resize_invalid_environment_type(self):
        """Resize must raise ValueError for a Discrete observation space."""
        # Mutate the fixture outside assertRaises so only the constructor
        # can satisfy the expected ValueError.
        self.env.observation_space = Discrete(64)
        with self.assertRaises(ValueError):
            Resize(self.env, width=self._width, height=self._height)

    def test_resize_invalid_environment_shape(self):
        """Resize must raise ValueError for a 1-D Box observation space."""
        self.env.observation_space = Box(
            low=0, high=255, shape=(4, ), dtype=np.uint8)
        with self.assertRaises(ValueError):
            Resize(self.env, width=self._width, height=self._height)

    def test_resize_output_observation_space(self):
        """The wrapped observation space has shape (width, height)."""
        self.assertEqual(self.env_r.observation_space.shape,
                         (self._width, self._height))

    def test_resize_output_reset(self):
        """reset() on the wrapper yields a (width, height) observation."""
        self.assertEqual(self.obs_r.shape, (self._width, self._height))

    def test_resize_output_step(self):
        """step() on the wrapper yields a (width, height) observation."""
        obs_r, _, _, _ = self.env_r.step(0)
        self.assertEqual(obs_r.shape, (self._width, self._height))
class TestResize:
    """pytest tests for the Resize wrapper around DummyDiscrete2DEnv."""

    def setup_method(self):
        """Create a plain env and a separately constructed resized env."""
        self.width = 16
        self.height = 16
        self.env = DummyDiscrete2DEnv()
        self.env_r = Resize(
            DummyDiscrete2DEnv(), width=self.width, height=self.height)

    def teardown_method(self):
        """Close both environments created in ``setup_method``."""
        self.env.close()
        self.env_r.close()

    def test_resize_invalid_environment_type(self):
        """Resize must raise ValueError for a Discrete observation space."""
        # Swap the space outside pytest.raises so the expected ValueError can
        # only come from the Resize constructor.
        self.env.observation_space = Discrete(64)
        with pytest.raises(ValueError):
            Resize(self.env, width=self.width, height=self.height)

    def test_resize_invalid_environment_shape(self):
        """Resize must raise ValueError for a 1-D Box observation space."""
        self.env.observation_space = Box(
            low=0, high=255, shape=(4, ), dtype=np.uint8)
        with pytest.raises(ValueError):
            Resize(self.env, width=self.width, height=self.height)

    def test_resize_output_observation_space(self):
        """The wrapped observation space has shape (width, height)."""
        assert self.env_r.observation_space.shape == (self.width, self.height)

    def test_resize_output_reset(self):
        """reset() on the wrapper yields a (width, height) observation."""
        assert self.env_r.reset().shape == (self.width, self.height)

    def test_resize_output_step(self):
        """step() after reset() yields a (width, height) observation."""
        self.env_r.reset()
        obs_r, _, _, _ = self.env_r.step(1)
        assert obs_r.shape == (self.width, self.height)
def test_baseline(self):
    """Test the baseline initialization.

    Builds deterministic and Gaussian MLP baselines on a Box env plus a
    Gaussian conv baseline on a 64x64 resized discrete env, initializes all
    TF variables, then fetches each baseline's trainable parameter values.
    """
    box_env = TfEnv(DummyBoxEnv())
    mlp_baselines = (
        DeterministicMLPBaseline(env_spec=box_env),
        GaussianMLPBaseline(env_spec=box_env),
    )
    image_env = TfEnv(Resize(DummyDiscrete2DEnv(), width=64, height=64))
    conv_regressor_args = {
        "conv_filters": [32, 32],
        "conv_filter_sizes": [1, 1],
        "conv_strides": [1, 1],
        "conv_pads": ["VALID", "VALID"],
        "hidden_sizes": (32, 32),
    }
    conv_baseline = GaussianConvBaseline(
        env_spec=image_env, regressor_args=conv_regressor_args)
    self.sess.run(tf.global_variables_initializer())
    for baseline in mlp_baselines + (conv_baseline, ):
        baseline.get_param_values(trainable=True)
def test_resize_invalid_environment_type(self):
    """Resize must raise ValueError for a Discrete observation space."""
    # Replace the space outside the assertRaises block so that the expected
    # ValueError can only come from the Resize constructor, not from the
    # fixture mutation itself.
    self.env.observation_space = Discrete(64)
    with self.assertRaises(ValueError):
        Resize(self.env, width=self.width, height=self.height)
def setUp(self):
    """Create a TfEnv-wrapped env and a TfEnv-wrapped, resized env.

    The resized env wraps its own ``DummyDiscrete2DEnv`` instance, separate
    from the one behind ``self.env``.
    """
    self.width, self.height = 16, 16
    self.env = TfEnv(DummyDiscrete2DEnv())
    resized = Resize(
        DummyDiscrete2DEnv(), width=self.width, height=self.height)
    self.env_r = TfEnv(resized)