def test_get_actions(self, action_dim, kernel_sizes, hidden_channels,
                     strides, paddings):
    """Test get_actions function.

    Builds a DiscreteCNNPolicy for a dummy discrete env, queries it with a
    list of three references to the same observation and checks that every
    returned action is contained in the env's action space.

    Args:
        action_dim: Number of discrete actions given to the dummy env.
        kernel_sizes: Forwarded to DiscreteCNNPolicy as ``kernel_sizes``.
        hidden_channels: Forwarded to DiscreteCNNPolicy as both
            ``hidden_channels`` and ``hidden_sizes``.
        strides: Forwarded to DiscreteCNNPolicy as ``strides``.
        paddings: Forwarded to DiscreteCNNPolicy as ``paddings``.

    """
    batch_size = 64
    input_width = 32
    input_height = 32
    in_channel = 3
    input_shape = (batch_size, in_channel, input_height, input_width)
    env = GymEnv(
        DummyDiscreteEnv(obs_dim=input_shape, action_dim=action_dim))

    # Helper is defined outside this block; the env it returns is the one
    # used for the rest of the test.
    env = self._initialize_obs_env(env)
    policy = DiscreteCNNPolicy(env_spec=env.spec,
                               hidden_channels=hidden_channels,
                               hidden_sizes=hidden_channels,
                               kernel_sizes=kernel_sizes,
                               strides=strides,
                               paddings=paddings,
                               padding_mode='zeros',
                               hidden_w_init=nn.init.ones_,
                               output_w_init=nn.init.ones_)
    env.reset()
    obs = env.step(1).observation

    # This check does not depend on any sampled action, so run it exactly
    # once here rather than per action; it is then still exercised even if
    # get_actions were to return nothing to iterate over.
    assert env.action_space.n == action_dim

    actions, _ = policy.get_actions([obs, obs, obs])
    for action in actions:
        assert env.action_space.contains(int(action[0]))
def test_obs_unflattened(self, action_dim, kernel_sizes, hidden_channels,
                         strides, paddings):
    """Test that get_action accepts a flattened image observation.

    A sample drawn from the env's observation space is flattened with the
    space's own ``flatten`` and handed to ``policy.get_action``; the
    resulting action is then used to step the env. The test has no explicit
    assertions: it passes when none of these calls raises.

    Args:
        action_dim: Number of discrete actions given to the dummy env.
        kernel_sizes: Forwarded to DiscreteCNNPolicy as ``kernel_sizes``.
        hidden_channels: Forwarded to DiscreteCNNPolicy as both
            ``hidden_channels`` and ``hidden_sizes``.
        strides: Forwarded to DiscreteCNNPolicy as ``strides``.
        paddings: Forwarded to DiscreteCNNPolicy as ``paddings``.

    """
    n_batch, n_channels, height, width = 64, 3, 32, 32
    dummy_env = DummyDiscreteEnv(
        obs_dim=(n_batch, n_channels, height, width),
        action_dim=action_dim)
    # Helper is defined outside this block.
    env = self._initialize_obs_env(GymEnv(dummy_env))
    env.reset()

    policy_kwargs = {
        'env_spec': env.spec,
        'hidden_channels': hidden_channels,
        'hidden_sizes': hidden_channels,
        'kernel_sizes': kernel_sizes,
        'strides': strides,
        'paddings': paddings,
        'padding_mode': 'zeros',
        'hidden_w_init': nn.init.ones_,
        'output_w_init': nn.init.ones_,
    }
    policy = DiscreteCNNPolicy(**policy_kwargs)

    sampled_obs = env.observation_space.sample()
    flat_obs = env.observation_space.flatten(sampled_obs)
    chosen_action, _ = policy.get_action(flat_obs)
    env.step(chosen_action)
def test_is_pickleable(self, action_dim, kernel_sizes, hidden_channels,
                       strides, paddings):
    """Test that the policy survives a cloudpickle round trip.

    The original policy is queried with a flattened observation and the
    unpickled copy with the unflattened one. Both actions must lie in the
    env's action space and have the same shape; the action values
    themselves are not compared.

    Args:
        action_dim: Number of discrete actions given to the dummy env.
        kernel_sizes: Forwarded to DiscreteCNNPolicy as ``kernel_sizes``.
        hidden_channels: Forwarded to DiscreteCNNPolicy as both
            ``hidden_channels`` and ``hidden_sizes``.
        strides: Forwarded to DiscreteCNNPolicy as ``strides``.
        paddings: Forwarded to DiscreteCNNPolicy as ``paddings``.

    """
    n_batch, n_channels, height, width = 64, 3, 32, 32
    dummy_env = DummyDiscreteEnv(
        obs_dim=(n_batch, n_channels, height, width),
        action_dim=action_dim)
    # Helper is defined outside this block.
    env = self._initialize_obs_env(GymEnv(dummy_env))

    policy_kwargs = {
        'env_spec': env.spec,
        'hidden_channels': hidden_channels,
        'hidden_sizes': hidden_channels,
        'kernel_sizes': kernel_sizes,
        'strides': strides,
        'paddings': paddings,
        'padding_mode': 'zeros',
        'hidden_w_init': nn.init.ones_,
        'output_w_init': nn.init.ones_,
    }
    policy = DiscreteCNNPolicy(**policy_kwargs)

    env.reset()
    observation = env.step(1).observation

    action_before, _ = policy.get_action(observation.flatten())
    restored_policy = cloudpickle.loads(cloudpickle.dumps(policy))
    action_after, _ = restored_policy.get_action(observation)

    assert env.action_space.contains(int(action_before[0]))
    assert env.action_space.contains(int(action_after[0]))
    assert action_before.shape == action_after.shape
def test_is_pickleable(self, kernel_sizes, hidden_channels, strides,
                       paddings):
    """Test that the policy survives a cloudpickle round trip.

    Uses a DummyDiscretePixelEnv and a policy built with
    ``image_format='NHWC'``. The original policy is queried with a
    flattened observation and the unpickled copy with the unflattened one.
    Both actions must lie in the env's action space and have the same
    shape; the action values themselves are not compared.

    Args:
        kernel_sizes: Forwarded to DiscreteCNNPolicy as ``kernel_sizes``.
        hidden_channels: Forwarded to DiscreteCNNPolicy as
            ``hidden_channels``.
        strides: Forwarded to DiscreteCNNPolicy as ``strides``.
        paddings: Forwarded to DiscreteCNNPolicy as ``paddings``.

    """
    env = GymEnv(DummyDiscretePixelEnv())

    policy_kwargs = {
        'env_spec': env.spec,
        'image_format': 'NHWC',
        'hidden_channels': hidden_channels,
        'kernel_sizes': kernel_sizes,
        'strides': strides,
        'paddings': paddings,
        'padding_mode': 'zeros',
        'hidden_w_init': nn.init.ones_,
        'output_w_init': nn.init.ones_,
    }
    policy = DiscreteCNNPolicy(**policy_kwargs)

    env.reset()
    observation = env.step(1).observation

    action_before, _ = policy.get_action(observation.flatten())
    restored_policy = cloudpickle.loads(cloudpickle.dumps(policy))
    action_after, _ = restored_policy.get_action(observation)

    assert env.action_space.contains(int(action_before[0]))
    assert env.action_space.contains(int(action_after[0]))
    assert action_before.shape == action_after.shape
def test_get_action(self, kernel_sizes, hidden_channels, strides, paddings):
    """Test that get_action returns an action inside the action space.

    Uses a DummyDiscretePixelEnv and a policy built with
    ``image_format='NHWC'``; the observation from one env step is flattened
    before being passed to ``get_action``.

    Args:
        kernel_sizes: Forwarded to DiscreteCNNPolicy as ``kernel_sizes``.
        hidden_channels: Forwarded to DiscreteCNNPolicy as
            ``hidden_channels``.
        strides: Forwarded to DiscreteCNNPolicy as ``strides``.
        paddings: Forwarded to DiscreteCNNPolicy as ``paddings``.

    """
    env = GymEnv(DummyDiscretePixelEnv())

    policy_kwargs = {
        'env_spec': env.spec,
        'image_format': 'NHWC',
        'hidden_channels': hidden_channels,
        'kernel_sizes': kernel_sizes,
        'strides': strides,
        'paddings': paddings,
        'padding_mode': 'zeros',
        'hidden_w_init': nn.init.ones_,
        'output_w_init': nn.init.ones_,
    }
    policy = DiscreteCNNPolicy(**policy_kwargs)

    env.reset()
    flat_obs = env.step(1).observation.flatten()

    chosen_action, _ = policy.get_action(flat_obs)
    assert env.action_space.contains(int(chosen_action[0]))
def test_obs_unflattened(self, kernel_sizes, hidden_channels, strides,
                         paddings):
    """Test that get_action accepts a flattened image observation.

    Uses a DummyDiscretePixelEnv and a policy built with
    ``image_format='NHWC'``. A sample drawn from the env's observation
    space is flattened with the space's own ``flatten`` and handed to
    ``policy.get_action``; the first element of the resulting action is
    then used to step the env. The test has no explicit assertions: it
    passes when none of these calls raises.

    Args:
        kernel_sizes: Forwarded to DiscreteCNNPolicy as ``kernel_sizes``.
        hidden_channels: Forwarded to DiscreteCNNPolicy as
            ``hidden_channels``.
        strides: Forwarded to DiscreteCNNPolicy as ``strides``.
        paddings: Forwarded to DiscreteCNNPolicy as ``paddings``.

    """
    env = GymEnv(DummyDiscretePixelEnv())
    env.reset()

    policy_kwargs = {
        'env_spec': env.spec,
        'image_format': 'NHWC',
        'hidden_channels': hidden_channels,
        'kernel_sizes': kernel_sizes,
        'strides': strides,
        'paddings': paddings,
        'padding_mode': 'zeros',
        'hidden_w_init': nn.init.ones_,
        'output_w_init': nn.init.ones_,
    }
    policy = DiscreteCNNPolicy(**policy_kwargs)

    sampled_obs = env.observation_space.sample()
    flat_obs = env.observation_space.flatten(sampled_obs)
    chosen_action, _ = policy.get_action(flat_obs)
    env.step(chosen_action[0])