def test_box(self):
    """Preprocess a CartPole observation on the default (CPU) device.

    Checks that the preprocessed state carries only float features, stored
    as a float32 CPU tensor with a leading dimension of size 1, whose
    contents match the raw observation returned by ``reset()``.
    """
    gym_env = Gym(env_name="CartPole-v0")
    preprocess = gym_env.get_obs_preprocessor()
    raw_obs = gym_env.reset()
    processed = preprocess(raw_obs)
    features = processed.float_features

    self.assertTrue(processed.has_float_features_only)
    # The flat observation gains a leading dimension of size 1.
    expected_shape = (1, raw_obs.shape[0])
    self.assertEqual(features.shape, expected_shape)
    self.assertEqual(features.dtype, torch.float32)
    self.assertEqual(features.device, torch.device("cpu"))
    # Dropping the leading dimension should recover the raw values.
    npt.assert_array_almost_equal(raw_obs, features.squeeze(0))
def test_box_cuda(self):
    """Preprocess a CartPole observation onto a CUDA device.

    Checks that the preprocessed state carries only float features, stored
    as a float32 tensor with a leading dimension of size 1 on the CUDA
    device, whose contents match the raw observation returned by
    ``reset()``. Skipped when CUDA is not available, instead of erroring
    on CPU-only machines.
    """
    # Bail out before creating the env or touching any CUDA API.
    if not torch.cuda.is_available():
        self.skipTest("CUDA is not available")
    env = Gym(env_name="CartPole-v0")
    device = torch.device("cuda")
    obs_preprocessor = env.get_obs_preprocessor(device=device)
    obs = env.reset()
    state = obs_preprocessor(obs)
    self.assertTrue(state.has_float_features_only)
    self.assertEqual(state.float_features.shape, (1, obs.shape[0]))
    self.assertEqual(state.float_features.dtype, torch.float32)
    # `device` is plain "cuda" with no index, while a tensor placed on it
    # reports a concrete index (e.g. "cuda:0"), so the two would not compare
    # equal directly. Compare against the device of a freshly allocated
    # tensor instead.
    x = torch.zeros(1, device=device)
    self.assertEqual(state.float_features.device, x.device)
    # Move back to host memory before comparing with the raw observation.
    npt.assert_array_almost_equal(obs, state.float_features.cpu().squeeze(0))