def test_get_actions(self, action_dim, kernel_sizes, hidden_channels,
                         strides, paddings):
        """Test get_actions function.

        Builds a DiscreteCNNPolicy for a dummy discrete env (constructed with
        ``obs_dim=(batch_size, in_channel, input_height, input_width)``),
        queries it with a list of three identical observations and checks
        that one valid action comes back per observation.
        """
        batch_size = 64
        input_width = 32
        input_height = 32
        in_channel = 3
        input_shape = (batch_size, in_channel, input_height, input_width)
        env = GymEnv(
            DummyDiscreteEnv(obs_dim=input_shape, action_dim=action_dim))

        env = self._initialize_obs_env(env)
        policy = DiscreteCNNPolicy(env_spec=env.spec,
                                   hidden_channels=hidden_channels,
                                   hidden_sizes=hidden_channels,
                                   kernel_sizes=kernel_sizes,
                                   strides=strides,
                                   paddings=paddings,
                                   padding_mode='zeros',
                                   hidden_w_init=nn.init.ones_,
                                   output_w_init=nn.init.ones_)

        env.reset()
        obs = env.step(1).observation

        observations = [obs, obs, obs]
        actions, _ = policy.get_actions(observations)

        # The action-space size does not depend on the sampled actions, so
        # check it once (and even if no actions were returned) instead of on
        # every loop iteration.
        assert env.action_space.n == action_dim
        # One action must be produced per observation passed in.
        assert len(actions) == len(observations)
        for action in actions:
            # Each action is indexed with [0] to unwrap a length-1 action,
            # as elsewhere in this module.
            assert env.action_space.contains(int(action[0]))
    def test_obs_unflattened(self, action_dim, kernel_sizes, hidden_channels,
                             strides, paddings):
        """Test that a flattened image obs passed to get_action is unflattened.

        A sampled observation is flattened with the env's observation space
        before being handed to the policy; the policy must still return a
        valid action that can be stepped in the environment.
        """
        batch_size = 64
        input_width = 32
        input_height = 32
        in_channel = 3
        input_shape = (batch_size, in_channel, input_height, input_width)
        env = GymEnv(
            DummyDiscreteEnv(obs_dim=input_shape, action_dim=action_dim))
        env = self._initialize_obs_env(env)

        env.reset()
        policy = DiscreteCNNPolicy(env_spec=env.spec,
                                   hidden_channels=hidden_channels,
                                   hidden_sizes=hidden_channels,
                                   kernel_sizes=kernel_sizes,
                                   strides=strides,
                                   paddings=paddings,
                                   padding_mode='zeros',
                                   hidden_w_init=nn.init.ones_,
                                   output_w_init=nn.init.ones_)

        obs = env.observation_space.sample()
        action, _ = policy.get_action(env.observation_space.flatten(obs))
        # get_action's result is indexed with [0] by the other tests in this
        # module; unwrap it the same way before validating and stepping.
        assert env.action_space.contains(int(action[0]))
        env.step(action[0])
    def test_is_pickleable(self, action_dim, kernel_sizes, hidden_channels,
                           strides, paddings):
        """Test that the policy survives a cloudpickle round trip.

        Both the original policy (fed a flattened observation) and the
        restored policy (fed the raw observation) must return valid actions,
        and those actions must have matching shapes.
        """
        batch_size, in_channel, input_height, input_width = 64, 3, 32, 32
        dummy_env = DummyDiscreteEnv(
            obs_dim=(batch_size, in_channel, input_height, input_width),
            action_dim=action_dim)
        env = self._initialize_obs_env(GymEnv(dummy_env))

        policy_kwargs = dict(hidden_channels=hidden_channels,
                             hidden_sizes=hidden_channels,
                             kernel_sizes=kernel_sizes,
                             strides=strides,
                             paddings=paddings,
                             padding_mode='zeros',
                             hidden_w_init=nn.init.ones_,
                             output_w_init=nn.init.ones_)
        policy = DiscreteCNNPolicy(env_spec=env.spec, **policy_kwargs)
        env.reset()
        observation = env.step(1).observation

        action_before, _ = policy.get_action(observation.flatten())

        restored_policy = cloudpickle.loads(cloudpickle.dumps(policy))
        action_after, _ = restored_policy.get_action(observation)

        for act in (action_before, action_after):
            assert env.action_space.contains(int(act[0]))
        assert action_before.shape == action_after.shape
    def test_is_pickleable(self, kernel_sizes, hidden_channels, strides,
                           paddings):
        """Test that an NHWC policy survives a cloudpickle round trip.

        The original policy receives a flattened pixel observation and the
        restored one receives it unflattened; both must yield valid actions
        of the same shape.
        """
        env = GymEnv(DummyDiscretePixelEnv())
        policy_kwargs = dict(image_format='NHWC',
                             hidden_channels=hidden_channels,
                             kernel_sizes=kernel_sizes,
                             strides=strides,
                             paddings=paddings,
                             padding_mode='zeros',
                             hidden_w_init=nn.init.ones_,
                             output_w_init=nn.init.ones_)
        policy = DiscreteCNNPolicy(env_spec=env.spec, **policy_kwargs)
        env.reset()
        pixels = env.step(1).observation

        action_before, _ = policy.get_action(pixels.flatten())

        restored_policy = cloudpickle.loads(cloudpickle.dumps(policy))
        action_after, _ = restored_policy.get_action(pixels)

        for act in (action_before, action_after):
            assert env.action_space.contains(int(act[0]))
        assert action_before.shape == action_after.shape
    def test_get_action(self, kernel_sizes, hidden_channels, strides,
                        paddings):
        """Test that get_action returns a valid action for a pixel obs.

        The observation from one env step is flattened before being passed
        to an NHWC DiscreteCNNPolicy.
        """
        env = GymEnv(DummyDiscretePixelEnv())
        policy_kwargs = dict(image_format='NHWC',
                             hidden_channels=hidden_channels,
                             kernel_sizes=kernel_sizes,
                             strides=strides,
                             paddings=paddings,
                             padding_mode='zeros',
                             hidden_w_init=nn.init.ones_,
                             output_w_init=nn.init.ones_)
        policy = DiscreteCNNPolicy(env_spec=env.spec, **policy_kwargs)
        env.reset()
        flat_pixels = env.step(1).observation.flatten()

        chosen, _ = policy.get_action(flat_pixels)
        assert env.action_space.contains(int(chosen[0]))
    def test_obs_unflattened(self, kernel_sizes, hidden_channels, strides,
                             paddings):
        """Test that a flattened image obs passed to get_action is unflattened.

        A sampled pixel observation is flattened with the env's observation
        space before being handed to an NHWC policy; the policy must still
        return a valid action that can be stepped in the environment.
        """
        env = GymEnv(DummyDiscretePixelEnv())
        env.reset()
        policy = DiscreteCNNPolicy(env_spec=env.spec,
                                   image_format='NHWC',
                                   hidden_channels=hidden_channels,
                                   kernel_sizes=kernel_sizes,
                                   strides=strides,
                                   paddings=paddings,
                                   padding_mode='zeros',
                                   hidden_w_init=nn.init.ones_,
                                   output_w_init=nn.init.ones_)

        obs = env.observation_space.sample()
        action, _ = policy.get_action(env.observation_space.flatten(obs))
        # Previously only "no exception" was checked; also verify that the
        # unwrapped action is a member of the discrete action space.
        assert env.action_space.contains(int(action[0]))
        env.step(action[0])