示例#1
0
    def __init__(self,
                 env_spec,
                 kernel_sizes,
                 hidden_channels,
                 strides,
                 hidden_sizes=(32, 32),
                 cnn_hidden_nonlinearity=torch.nn.ReLU,
                 mlp_hidden_nonlinearity=torch.nn.ReLU,
                 hidden_w_init=nn.init.xavier_uniform_,
                 hidden_b_init=nn.init.zeros_,
                 paddings=0,
                 padding_mode='zeros',
                 max_pool=False,
                 pool_shape=None,
                 pool_stride=1,
                 output_nonlinearity=None,
                 output_w_init=nn.init.xavier_uniform_,
                 output_b_init=nn.init.zeros_,
                 layer_normalization=False,
                 name='DiscreteCNNPolicy'):
        """Initialise the policy and build its DiscreteCNNModule.

        The module's input shape comes from
        ``env_spec.observation_space.shape`` and its output size from
        ``env_spec.action_space.flat_dim``.

        Args:
            env_spec: Environment specification providing the observation
                and action spaces.
            kernel_sizes, hidden_channels, strides, hidden_sizes,
            cnn_hidden_nonlinearity, mlp_hidden_nonlinearity, hidden_w_init,
            hidden_b_init, paddings, padding_mode, max_pool, pool_shape,
            pool_stride, output_nonlinearity, output_w_init, output_b_init,
            layer_normalization: Forwarded unchanged (positionally) to
                ``DiscreteCNNModule``; see that class for their meaning.
            name (str): Policy name, passed to the parent constructor.
        """
        super().__init__(env_spec, name)
        self._env_spec = env_spec

        obs_space = env_spec.observation_space
        self._input_shape = obs_space.shape
        self._output_dim = env_spec.action_space.flat_dim
        # Passed to the module as its last argument; True only when the
        # observation space is an akro.Image.
        self._is_image = isinstance(obs_space, akro.Image)

        # DiscreteCNNModule is called positionally, so the order of this
        # tuple must track the module's parameter order exactly.
        module_args = (
            self._input_shape,
            self._output_dim,
            kernel_sizes,
            hidden_channels,
            strides,
            hidden_sizes,
            cnn_hidden_nonlinearity,
            mlp_hidden_nonlinearity,
            hidden_w_init,
            hidden_b_init,
            paddings,
            padding_mode,
            max_pool,
            pool_shape,
            pool_stride,
            output_nonlinearity,
            output_w_init,
            output_b_init,
            layer_normalization,
            self._is_image,
        )
        self._cnn_module = DiscreteCNNModule(*module_args)
示例#2
0
    def __init__(self,
                 env_spec,
                 image_format,
                 *,
                 kernel_sizes,
                 hidden_channels,
                 strides,
                 hidden_sizes=(32, 32),
                 cnn_hidden_nonlinearity=torch.nn.ReLU,
                 mlp_hidden_nonlinearity=torch.nn.ReLU,
                 hidden_w_init=nn.init.xavier_uniform_,
                 hidden_b_init=nn.init.zeros_,
                 paddings=0,
                 padding_mode='zeros',
                 max_pool=False,
                 pool_shape=None,
                 pool_stride=1,
                 output_nonlinearity=None,
                 output_w_init=nn.init.xavier_uniform_,
                 output_b_init=nn.init.zeros_,
                 layer_normalization=False):
        """Initialise the object and build its DiscreteCNNModule.

        Args:
            env_spec: Environment specification. Its observation space
                becomes the module's input space and its action space the
                module's output space (via ``InOutSpec``).
            image_format: Forwarded unchanged to ``DiscreteCNNModule``.
            kernel_sizes, hidden_channels, strides, hidden_sizes,
            cnn_hidden_nonlinearity, mlp_hidden_nonlinearity, hidden_w_init,
            hidden_b_init, paddings, padding_mode, max_pool, pool_shape,
            pool_stride, output_nonlinearity, output_w_init, output_b_init,
            layer_normalization: Keyword-only; forwarded unchanged, by
                name, to ``DiscreteCNNModule``.
        """
        super().__init__()

        self._env_spec = env_spec

        # Observations are the module's input space, actions its output.
        io_spec = InOutSpec(input_space=env_spec.observation_space,
                            output_space=env_spec.action_space)

        module_kwargs = dict(
            image_format=image_format,
            # convolutional part
            kernel_sizes=kernel_sizes,
            hidden_channels=hidden_channels,
            strides=strides,
            paddings=paddings,
            padding_mode=padding_mode,
            max_pool=max_pool,
            pool_shape=pool_shape,
            pool_stride=pool_stride,
            cnn_hidden_nonlinearity=cnn_hidden_nonlinearity,
            # fully-connected part
            hidden_sizes=hidden_sizes,
            mlp_hidden_nonlinearity=mlp_hidden_nonlinearity,
            output_nonlinearity=output_nonlinearity,
            layer_normalization=layer_normalization,
            # initialisers
            hidden_w_init=hidden_w_init,
            hidden_b_init=hidden_b_init,
            output_w_init=output_w_init,
            output_b_init=output_b_init,
        )
        self._cnn_module = DiscreteCNNModule(spec=io_spec, **module_kwargs)
示例#3
0
def test_output_values(output_dim, kernel_sizes, hidden_channels, strides,
                       paddings):
    """DiscreteCNNModule must match a hand-built CNN -> flatten -> MLP stack.

    Both pipelines initialise their weights with ``nn.init.ones_`` so that
    their outputs can be compared for exact equality.
    """
    input_width = 32
    input_height = 32
    in_channel = 3
    input_shape = (in_channel, input_height, input_width)
    spec = InOutSpec(akro.Box(shape=input_shape, low=-np.inf, high=np.inf),
                     akro.Box(shape=(output_dim, ), low=-np.inf, high=np.inf))
    obs = torch.rand(input_shape)

    module = DiscreteCNNModule(spec=spec,
                               image_format='NCHW',
                               hidden_channels=hidden_channels,
                               hidden_sizes=hidden_channels,
                               kernel_sizes=kernel_sizes,
                               strides=strides,
                               paddings=paddings,
                               padding_mode='zeros',
                               hidden_w_init=nn.init.ones_,
                               output_w_init=nn.init.ones_)

    cnn = CNNModule(spec=InOutSpec(
        akro.Box(shape=input_shape, low=-np.inf, high=np.inf), None),
                    image_format='NCHW',
                    hidden_channels=hidden_channels,
                    kernel_sizes=kernel_sizes,
                    strides=strides,
                    paddings=paddings,
                    padding_mode='zeros',
                    hidden_w_init=nn.init.ones_)

    # Run the reference CNN once; its flattened output both sizes the MLP's
    # input layer and is fed to that MLP.
    cnn_flat = torch.flatten(cnn(obs), start_dim=1)
    flat_dim = cnn_flat.shape[1]

    mlp = MLPModule(
        flat_dim,
        output_dim,
        hidden_channels,
        hidden_w_init=nn.init.ones_,
        output_w_init=nn.init.ones_,
    )

    expected = mlp(cnn_flat).detach()
    actual = module(obs).detach()
    # torch.equal also requires identical shapes; torch.eq would broadcast
    # and could hide a shape mismatch between the two pipelines.
    assert torch.equal(expected, actual)
def test_output_values(output_dim, kernel_sizes, hidden_channels, strides,
                       paddings):
    """DiscreteCNNModule must match a hand-built CNN -> flatten -> MLP stack.

    Both pipelines initialise their weights with ``nn.init.ones_`` so that
    their outputs can be compared for exact equality.
    """
    batch_size = 64
    input_width = 32
    input_height = 32
    in_channel = 3
    input_shape = (batch_size, in_channel, input_height, input_width)
    obs = torch.rand(input_shape)

    module = DiscreteCNNModule(input_shape=input_shape,
                               output_dim=output_dim,
                               hidden_channels=hidden_channels,
                               hidden_sizes=hidden_channels,
                               kernel_sizes=kernel_sizes,
                               strides=strides,
                               paddings=paddings,
                               padding_mode='zeros',
                               hidden_w_init=nn.init.ones_,
                               output_w_init=nn.init.ones_,
                               is_image=False)

    cnn = CNNModule(input_var=obs,
                    hidden_channels=hidden_channels,
                    kernel_sizes=kernel_sizes,
                    strides=strides,
                    paddings=paddings,
                    padding_mode='zeros',
                    hidden_w_init=nn.init.ones_,
                    is_image=False)

    # Run the reference CNN once; its flattened output both sizes the MLP's
    # input layer and is fed to that MLP.
    cnn_flat = torch.flatten(cnn(obs), start_dim=1)
    flat_dim = cnn_flat.shape[1]

    mlp = MLPModule(
        flat_dim,
        output_dim,
        hidden_channels,
        hidden_w_init=nn.init.ones_,
        output_w_init=nn.init.ones_,
    )

    expected = mlp(cnn_flat).detach()
    actual = module(obs).detach()
    # torch.equal also requires identical shapes; torch.eq would broadcast
    # and could hide a shape mismatch between the two pipelines.
    assert torch.equal(expected, actual)
示例#5
0
def test_without_nonlinearity(output_dim, hidden_channels, kernel_sizes,
                              strides):
    """Disabling both hidden nonlinearities leaves three entries in ``_module``.

    Builds the module with ``cnn_hidden_nonlinearity=None`` and
    ``mlp_hidden_nonlinearity=None`` and checks the length of its private
    ``_module`` container.
    """
    channels, height, width = 3, 32, 32
    obs_shape = (channels, height, width)
    unbounded = dict(low=-np.inf, high=np.inf)
    spec = InOutSpec(akro.Box(shape=obs_shape, **unbounded),
                     akro.Box(shape=(output_dim, ), **unbounded))

    module = DiscreteCNNModule(
        spec=spec,
        image_format='NCHW',
        kernel_sizes=kernel_sizes,
        hidden_channels=hidden_channels,
        strides=strides,
        hidden_sizes=hidden_channels,
        cnn_hidden_nonlinearity=None,
        mlp_hidden_nonlinearity=None,
        hidden_w_init=nn.init.ones_,
        output_w_init=nn.init.ones_,
    )

    assert len(module._module) == 3
def test_without_nonlinearity(output_dim, hidden_channels, kernel_sizes,
                              strides):
    """Disabling both hidden nonlinearities leaves three entries in ``_module``.

    Builds the module (non-image input) with ``cnn_hidden_nonlinearity=None``
    and ``mlp_hidden_nonlinearity=None`` and checks the length of its
    private ``_module`` container.
    """
    batch, channels, height, width = 64, 3, 32, 32
    input_shape = (batch, channels, height, width)

    module = DiscreteCNNModule(
        input_shape=input_shape,
        output_dim=output_dim,
        kernel_sizes=kernel_sizes,
        hidden_channels=hidden_channels,
        strides=strides,
        hidden_sizes=hidden_channels,
        cnn_hidden_nonlinearity=None,
        mlp_hidden_nonlinearity=None,
        hidden_w_init=nn.init.ones_,
        output_w_init=nn.init.ones_,
        is_image=False,
    )

    assert len(module._module) == 3
def test_is_pickleable(output_dim, hidden_channels, kernel_sizes, strides):
    """A pickled-then-unpickled DiscreteCNNModule must give identical output."""
    batch_size = 64
    input_width = 32
    input_height = 32
    in_channel = 3
    input_shape = (batch_size, in_channel, input_height, input_width)
    input_a = torch.ones(input_shape)

    model = DiscreteCNNModule(input_shape=input_shape,
                              output_dim=output_dim,
                              hidden_channels=hidden_channels,
                              kernel_sizes=kernel_sizes,
                              mlp_hidden_nonlinearity=nn.ReLU,
                              cnn_hidden_nonlinearity=nn.ReLU,
                              strides=strides)
    output1 = model(input_a)

    # Round-trip through pickle; the bytes are produced right here, so
    # pickle.loads is not exposed to untrusted data.
    h = pickle.dumps(model)
    model_pickled = pickle.loads(h)
    output2 = model_pickled(input_a)

    # torch.equal checks shape and every element in one call, without the
    # broadcasting that torch.eq would apply.
    assert torch.equal(output1, output2)
示例#8
0
def test_is_pickleable(output_dim, hidden_channels, kernel_sizes, strides):
    """A pickled-then-unpickled DiscreteCNNModule must give identical output."""
    input_width = 32
    input_height = 32
    in_channel = 3
    input_shape = (in_channel, input_height, input_width)
    input_a = torch.ones(input_shape)
    spec = InOutSpec(akro.Box(shape=input_shape, low=-np.inf, high=np.inf),
                     akro.Box(shape=(output_dim, ), low=-np.inf, high=np.inf))

    model = DiscreteCNNModule(spec=spec,
                              image_format='NCHW',
                              hidden_channels=hidden_channels,
                              kernel_sizes=kernel_sizes,
                              mlp_hidden_nonlinearity=nn.ReLU,
                              cnn_hidden_nonlinearity=nn.ReLU,
                              strides=strides)
    output1 = model(input_a)

    # Round-trip through pickle; the bytes are produced right here, so
    # pickle.loads is not exposed to untrusted data.
    h = pickle.dumps(model)
    model_pickled = pickle.loads(h)
    output2 = model_pickled(input_a)

    # torch.equal checks shape and every element in one call, without the
    # broadcasting that torch.eq would apply.
    assert torch.equal(output1, output2)