Example #1
    def test_output_values(self, kernel_sizes, hidden_channels, strides,
                           paddings):
        """Test output values from CNNBaseModule.

        Args:
            kernel_sizes (tuple[int]): Kernel sizes.
            hidden_channels (tuple[int]): Hidden channels.
            strides (tuple[int]): Strides.
            paddings (tuple[int]): Value of zero-padding.

        """
        module_with_nonlinear_function_and_module = CNNModule(
            self.input_spec,
            image_format='NCHW',
            hidden_channels=hidden_channels,
            kernel_sizes=kernel_sizes,
            strides=strides,
            paddings=paddings,
            padding_mode='zeros',
            hidden_nonlinearity=torch.relu,
            hidden_w_init=nn.init.xavier_uniform_)

        module_with_nonlinear_module_instance_and_function = CNNModule(
            self.input_spec,
            image_format='NCHW',
            hidden_channels=hidden_channels,
            kernel_sizes=kernel_sizes,
            strides=strides,
            paddings=paddings,
            padding_mode='zeros',
            hidden_nonlinearity=nn.ReLU(),
            hidden_w_init=nn.init.xavier_uniform_)

        output1 = module_with_nonlinear_function_and_module(self.input)
        output2 = module_with_nonlinear_module_instance_and_function(
            self.input)

        current_size = self.input_width
        for (filter_size, stride, padding) in zip(kernel_sizes, strides,
                                                  paddings):
            # padding = float((filter_size - 1) / 2) # P = (F - 1) /2
            current_size = int(
                (current_size - filter_size + padding * 2) /
                stride) + 1  # conv formula = ((W - F + 2P) / S) + 1
        flatten_shape = current_size * current_size * hidden_channels[-1]

        expected_output = torch.zeros((self.batch_size, flatten_shape))

        assert np.array_equal(torch.all(torch.eq(output1, expected_output)),
                              True)
        assert np.array_equal(torch.all(torch.eq(output2, expected_output)),
                              True)
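The loop above applies the standard convolution output-size formula, out = ((W - F + 2P) / S) + 1, once per layer. A minimal standalone sketch of that calculation; the helper name conv_output_size is purely illustrative and not part of the module under test:

def conv_output_size(size, kernel_size, stride, padding):
    """Spatial size after one conv layer: ((W - F + 2P) / S) + 1."""
    return (size - kernel_size + 2 * padding) // stride + 1

# A 32x32 input with a 3x3 kernel, stride 1 and padding 1 keeps its size:
assert conv_output_size(32, kernel_size=3, stride=1, padding=1) == 32
# Stride 2 roughly halves it: ((32 - 3 + 2) // 2) + 1 = 16
assert conv_output_size(32, kernel_size=3, stride=2, padding=1) == 16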
Example #2
    def test_output_values_with_unequal_stride_with_padding(
            self, hidden_channels, kernel_sizes, strides, paddings):
        """Test output values with unequal stride and padding from CNNModule.

        Args:
            kernel_sizes (tuple[int]): Kernel sizes.
            hidden_channels (tuple[int]): Hidden channels.
            strides (tuple[int]): Strides.
            paddings (tuple[int]): Value of zero-padding.

        """
        model = CNNModule(input_var=self.input,
                          hidden_channels=hidden_channels,
                          kernel_sizes=kernel_sizes,
                          strides=strides,
                          paddings=paddings,
                          padding_mode='zeros',
                          hidden_nonlinearity=torch.relu,
                          hidden_w_init=nn.init.xavier_uniform_)
        output = model(self.input)

        current_size = self.input_width
        for (filter_size, stride, padding) in zip(kernel_sizes, strides,
                                                  paddings):
            # padding = float((filter_size - 1) / 2) # P = (F - 1) /2
            current_size = int(
                (current_size - filter_size + padding * 2) /
                stride) + 1  # conv formula = ((W - F + 2P) / S) + 1
        flatten_shape = current_size * current_size * hidden_channels[-1]

        expected_output = torch.zeros((self.batch_size, flatten_shape))
        assert np.array_equal(torch.all(torch.eq(output, expected_output)),
                              True)
Example #3
    def __init__(self,
                 input_shape,
                 output_dim,
                 kernel_sizes,
                 hidden_channels,
                 strides,
                 hidden_sizes=(32, 32),
                 cnn_hidden_nonlinearity=nn.ReLU,
                 mlp_hidden_nonlinearity=nn.ReLU,
                 hidden_w_init=nn.init.xavier_uniform_,
                 hidden_b_init=nn.init.zeros_,
                 paddings=0,
                 padding_mode='zeros',
                 max_pool=False,
                 pool_shape=None,
                 pool_stride=1,
                 output_nonlinearity=None,
                 output_w_init=nn.init.xavier_uniform_,
                 output_b_init=nn.init.zeros_,
                 layer_normalization=False,
                 is_image=True):

        super().__init__()

        input_var = torch.zeros(input_shape)
        cnn_module = CNNModule(input_var=input_var,
                               kernel_sizes=kernel_sizes,
                               strides=strides,
                               hidden_w_init=hidden_w_init,
                               hidden_b_init=hidden_b_init,
                               hidden_channels=hidden_channels,
                               hidden_nonlinearity=cnn_hidden_nonlinearity,
                               paddings=paddings,
                               padding_mode=padding_mode,
                               max_pool=max_pool,
                               layer_normalization=layer_normalization,
                               pool_shape=pool_shape,
                               pool_stride=pool_stride,
                               is_image=is_image)

        with torch.no_grad():
            cnn_out = cnn_module(input_var)
        flat_dim = torch.flatten(cnn_out, start_dim=1).shape[1]

        mlp_module = MLPModule(flat_dim,
                               output_dim,
                               hidden_sizes,
                               hidden_nonlinearity=mlp_hidden_nonlinearity,
                               hidden_w_init=hidden_w_init,
                               hidden_b_init=hidden_b_init,
                               output_nonlinearity=output_nonlinearity,
                               output_w_init=output_w_init,
                               output_b_init=output_b_init,
                               layer_normalization=layer_normalization)

        if mlp_hidden_nonlinearity is None:
            self._module = nn.Sequential(cnn_module, nn.Flatten(), mlp_module)
        else:
            self._module = nn.Sequential(cnn_module, mlp_hidden_nonlinearity(),
                                         nn.Flatten(), mlp_module)
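Note how flat_dim is not derived analytically: a zero tensor of the input shape is pushed through the CNN under torch.no_grad() and the flattened result is measured. A minimal sketch of that shape-probing pattern with a plain nn.Conv2d stack (the layer sizes here are made up for illustration):

import torch
from torch import nn

cnn = nn.Sequential(nn.Conv2d(3, 16, kernel_size=3, stride=2), nn.ReLU(),
                    nn.Conv2d(16, 32, kernel_size=3, stride=2))

dummy = torch.zeros((1, 3, 32, 32))   # one fake NCHW image
with torch.no_grad():                 # gradients are not needed to probe shapes
    flat_dim = torch.flatten(cnn(dummy), start_dim=1).shape[1]

head = nn.Linear(flat_dim, 10)        # MLP head sized from the probed dimension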
Example #4
    def __init__(self,
                 spec,
                 image_format,
                 *,
                 kernel_sizes,
                 hidden_channels,
                 strides,
                 hidden_sizes=(32, 32),
                 cnn_hidden_nonlinearity=nn.ReLU,
                 mlp_hidden_nonlinearity=nn.ReLU,
                 hidden_w_init=nn.init.xavier_uniform_,
                 hidden_b_init=nn.init.zeros_,
                 paddings=0,
                 padding_mode='zeros',
                 max_pool=False,
                 pool_shape=None,
                 pool_stride=1,
                 output_nonlinearity=None,
                 output_w_init=nn.init.xavier_uniform_,
                 output_b_init=nn.init.zeros_,
                 layer_normalization=False):

        super().__init__()

        cnn_spec = InOutSpec(input_space=spec.input_space, output_space=None)
        cnn_module = CNNModule(spec=cnn_spec,
                               image_format=image_format,
                               kernel_sizes=kernel_sizes,
                               strides=strides,
                               hidden_w_init=hidden_w_init,
                               hidden_b_init=hidden_b_init,
                               hidden_channels=hidden_channels,
                               hidden_nonlinearity=cnn_hidden_nonlinearity,
                               paddings=paddings,
                               padding_mode=padding_mode,
                               max_pool=max_pool,
                               layer_normalization=layer_normalization,
                               pool_shape=pool_shape,
                               pool_stride=pool_stride)
        flat_dim = cnn_module.spec.output_space.flat_dim

        output_dim = spec.output_space.flat_dim
        mlp_module = MLPModule(flat_dim,
                               output_dim,
                               hidden_sizes,
                               hidden_nonlinearity=mlp_hidden_nonlinearity,
                               hidden_w_init=hidden_w_init,
                               hidden_b_init=hidden_b_init,
                               output_nonlinearity=output_nonlinearity,
                               output_w_init=output_w_init,
                               output_b_init=output_b_init,
                               layer_normalization=layer_normalization)

        if mlp_hidden_nonlinearity is None:
            self._module = nn.Sequential(cnn_module, nn.Flatten(), mlp_module)
        else:
            self._module = nn.Sequential(cnn_module, mlp_hidden_nonlinearity(),
                                         nn.Flatten(), mlp_module)
Example #5
    def __init__(self,
                 env_spec,
                 image_format,
                 kernel_sizes,
                 *,
                 hidden_channels,
                 strides=1,
                 hidden_sizes=(32, 32),
                 hidden_nonlinearity=torch.tanh,
                 hidden_w_init=nn.init.xavier_uniform_,
                 hidden_b_init=nn.init.zeros_,
                 paddings=0,
                 padding_mode='zeros',
                 max_pool=False,
                 pool_shape=None,
                 pool_stride=1,
                 output_w_init=nn.init.xavier_uniform_,
                 output_b_init=nn.init.zeros_,
                 layer_normalization=False,
                 name='CategoricalCNNPolicy'):

        if not isinstance(env_spec.action_space, akro.Discrete):
            raise ValueError('CategoricalCNNPolicy only works '
                             'with akro.Discrete action spaces.')
        if isinstance(env_spec.observation_space, akro.Dict):
            raise ValueError('CNN policies do not support '
                             'akro.Dict observation spaces.')

        super().__init__(env_spec, name)

        self._cnn_module = CNNModule(InOutSpec(
            self._env_spec.observation_space, None),
                                     image_format=image_format,
                                     kernel_sizes=kernel_sizes,
                                     strides=strides,
                                     hidden_channels=hidden_channels,
                                     hidden_w_init=hidden_w_init,
                                     hidden_b_init=hidden_b_init,
                                     hidden_nonlinearity=hidden_nonlinearity,
                                     paddings=paddings,
                                     padding_mode=padding_mode,
                                     max_pool=max_pool,
                                     pool_shape=pool_shape,
                                     pool_stride=pool_stride,
                                     layer_normalization=layer_normalization)
        self._mlp_module = MultiHeadedMLPModule(
            n_heads=1,
            input_dim=self._cnn_module.spec.output_space.flat_dim,
            output_dims=[self._env_spec.action_space.flat_dim],
            hidden_sizes=hidden_sizes,
            hidden_w_init=hidden_w_init,
            hidden_b_init=hidden_b_init,
            hidden_nonlinearity=hidden_nonlinearity,
            output_w_inits=output_w_init,
            output_b_inits=output_b_init)
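The MultiHeadedMLPModule is configured with a single head whose output size equals the number of discrete actions, so it produces one logit per action. A categorical policy would typically wrap those logits in a torch.distributions.Categorical to sample actions; a hedged sketch of that final step (the logit values below are made up):

import torch
from torch.distributions import Categorical

logits = torch.tensor([[0.2, 1.5, -0.3]])  # one batch row of per-action logits
dist = Categorical(logits=logits)          # softmax over the last dimension
action = dist.sample()                     # sampled discrete action index
log_prob = dist.log_prob(action)           # used by policy-gradient style losses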
Example #6
def test_dueling_output_values(output_dim, kernel_sizes, hidden_channels,
                               strides, paddings):

    batch_size = 64
    input_width = 32
    input_height = 32
    in_channel = 3
    input_shape = (batch_size, in_channel, input_height, input_width)
    obs = torch.rand(input_shape)

    module = DiscreteDuelingCNNModule(input_shape=input_shape,
                                      output_dim=output_dim,
                                      hidden_channels=hidden_channels,
                                      hidden_sizes=hidden_channels,
                                      kernel_sizes=kernel_sizes,
                                      strides=strides,
                                      paddings=paddings,
                                      padding_mode='zeros',
                                      hidden_w_init=nn.init.ones_,
                                      output_w_init=nn.init.ones_,
                                      is_image=False)

    cnn = CNNModule(input_var=obs,
                    hidden_channels=hidden_channels,
                    kernel_sizes=kernel_sizes,
                    strides=strides,
                    paddings=paddings,
                    padding_mode='zeros',
                    hidden_w_init=nn.init.ones_,
                    is_image=False)
    flat_dim = torch.flatten(cnn(obs).detach(), start_dim=1).shape[1]

    mlp_adv = MLPModule(
        flat_dim,
        output_dim,
        hidden_channels,
        hidden_w_init=nn.init.ones_,
        output_w_init=nn.init.ones_,
    )

    mlp_val = MLPModule(
        flat_dim,
        1,
        hidden_channels,
        hidden_w_init=nn.init.ones_,
        output_w_init=nn.init.ones_,
    )

    cnn_out = cnn(obs)
    val = mlp_val(torch.flatten(cnn_out, start_dim=1))
    adv = mlp_adv(torch.flatten(cnn_out, start_dim=1))
    output = val + (adv - adv.mean(1).unsqueeze(1))

    assert torch.all(torch.eq(output.detach(), module(obs).detach()))
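The reference output is assembled with the dueling aggregation Q(s, a) = V(s) + (A(s, a) - mean_a' A(s, a')), which is exactly what the val + (adv - adv.mean(1).unsqueeze(1)) line computes. A tiny numeric sketch of that aggregation (values chosen arbitrarily):

import torch

val = torch.tensor([[2.0]])                  # V(s) for a single state
adv = torch.tensor([[1.0, 3.0, 5.0]])        # A(s, a) for three actions, mean 3.0
q = val + (adv - adv.mean(1, keepdim=True))  # Q(s, a) = V + (A - mean(A))
assert torch.equal(q, torch.tensor([[0.0, 2.0, 4.0]]))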
Example #7
    def test_output_with_max_pooling(self, kernel_sizes, hidden_channels,
                                     strides, pool_shape, pool_stride):
        model = CNNModule(input_var=self.input,
                          hidden_channels=hidden_channels,
                          kernel_sizes=kernel_sizes,
                          strides=strides,
                          max_pool=True,
                          pool_shape=(pool_shape, pool_shape),
                          pool_stride=(pool_stride, pool_stride))
        x = model(self.input)
        fc_w = torch.zeros((x.shape[1], 10))
        fc_b = torch.zeros(10)
        result = x.mm(fc_w) + fc_b
        assert result.size() == torch.Size([64, 10])
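With max_pool=True, each convolution is followed by a pooling step that shrinks the spatial size by the same ((W - F) / S) + 1 rule, with the pool shape playing the role of the filter. A quick standalone check with nn.MaxPool2d (the shapes are illustrative only):

import torch
from torch import nn

pool = nn.MaxPool2d(kernel_size=2, stride=2)
feature_map = torch.rand(64, 32, 16, 16)          # (batch, channels, H, W)
assert pool(feature_map).shape == (64, 32, 8, 8)  # 16 -> ((16 - 2) // 2) + 1 = 8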
Example #8
def test_set_output_size(kernel_sizes, hidden_channels, strides, pool_shape,
                         pool_stride):
    spec = InOutSpec(akro.Box(shape=[3, 19, 15], high=np.inf, low=-np.inf),
                     akro.Box(shape=[200], high=np.inf, low=-np.inf))
    model = CNNModule(spec,
                      image_format='NCHW',
                      hidden_channels=hidden_channels,
                      kernel_sizes=kernel_sizes,
                      strides=strides,
                      pool_shape=[(pool_shape, pool_shape)],
                      pool_stride=[(pool_stride, pool_stride)],
                      layer_normalization=True)
    images = torch.ones(10, 3, 19, 15)
    x = model(images)
    assert x.shape == (10, 200)
Example #9
    def test_no_head_invalid_settings(self, hidden_nonlinear):
        """Check CNNModule throws exception with invalid non-linear functions.

        Args:
            hidden_nonlinear (callable or torch.nn.Module): Non-linear
                functions for hidden layers.

        """
        expected_msg = 'Non linear function .* is not supported'
        with pytest.raises(ValueError, match=expected_msg):
            CNNModule(input_var=self.input,
                      hidden_channels=(32, ),
                      kernel_sizes=(3, ),
                      strides=(1, ),
                      hidden_nonlinearity=hidden_nonlinear)
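The matched ValueError implies the module validates its hidden_nonlinearity argument before building layers. A hedged sketch of what such a check could look like; the helper _validate_nonlinearity is hypothetical and not the module's actual implementation:

from torch import nn

def _validate_nonlinearity(nonlinearity):
    """Hypothetical check: accept None, callables, or nn.Module instances."""
    if nonlinearity is None or callable(nonlinearity) or isinstance(
            nonlinearity, nn.Module):
        return nonlinearity
    raise ValueError(f'Non linear function {nonlinearity} is not supported')

_validate_nonlinearity(nn.ReLU())  # fine: an nn.Module instance
# _validate_nonlinearity('relu')   # would raise ValueError, as the test expects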
Example #10
def test_output_values(output_dim, kernel_sizes, hidden_channels, strides,
                       paddings):

    input_width = 32
    input_height = 32
    in_channel = 3
    input_shape = (in_channel, input_height, input_width)
    spec = InOutSpec(akro.Box(shape=input_shape, low=-np.inf, high=np.inf),
                     akro.Box(shape=(output_dim, ), low=-np.inf, high=np.inf))
    obs = torch.rand(input_shape)

    module = DiscreteCNNModule(spec=spec,
                               image_format='NCHW',
                               hidden_channels=hidden_channels,
                               hidden_sizes=hidden_channels,
                               kernel_sizes=kernel_sizes,
                               strides=strides,
                               paddings=paddings,
                               padding_mode='zeros',
                               hidden_w_init=nn.init.ones_,
                               output_w_init=nn.init.ones_)

    cnn = CNNModule(spec=InOutSpec(
        akro.Box(shape=input_shape, low=-np.inf, high=np.inf), None),
                    image_format='NCHW',
                    hidden_channels=hidden_channels,
                    kernel_sizes=kernel_sizes,
                    strides=strides,
                    paddings=paddings,
                    padding_mode='zeros',
                    hidden_w_init=nn.init.ones_)
    flat_dim = torch.flatten(cnn(obs).detach(), start_dim=1).shape[1]

    mlp = MLPModule(
        flat_dim,
        output_dim,
        hidden_channels,
        hidden_w_init=nn.init.ones_,
        output_w_init=nn.init.ones_,
    )

    cnn_out = cnn(obs)
    output = mlp(torch.flatten(cnn_out, start_dim=1))

    assert torch.all(torch.eq(output.detach(), module(obs).detach()))
Example #11
    def test_is_pickleable(self, hidden_channels, kernel_sizes, strides):
        """Check CNNModule is pickeable.

        Args:
            hidden_channels (tuple[int]): Hidden channels.
            kernel_sizes (tuple[int]): Kernel sizes.
            strides (tuple[int]): Strides.

        """
        model = CNNModule(input_var=self.input,
                          hidden_channels=hidden_channels,
                          kernel_sizes=kernel_sizes,
                          strides=strides)
        output1 = model(self.input)

        h = pickle.dumps(model)
        model_pickled = pickle.loads(h)
        output2 = model_pickled(self.input)

        assert np.array_equal(torch.all(torch.eq(output1, output2)), True)
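The pickle round trip above implies the restored model carries identical parameters; an equivalent, hedged way to confirm that is to compare the state_dicts of the original and unpickled modules entry by entry (shown here with a stand-in nn.Linear rather than the CNN under test):

import pickle

import torch
from torch import nn

model = nn.Linear(4, 2)                       # stand-in for the pickled module
restored = pickle.loads(pickle.dumps(model))

for key, tensor in model.state_dict().items():
    assert torch.equal(tensor, restored.state_dict()[key])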