def __init__(self,
             env_spec,
             kernel_sizes,
             hidden_channels,
             strides,
             hidden_sizes=(32, 32),
             cnn_hidden_nonlinearity=torch.nn.ReLU,
             mlp_hidden_nonlinearity=torch.nn.ReLU,
             hidden_w_init=nn.init.xavier_uniform_,
             hidden_b_init=nn.init.zeros_,
             paddings=0,
             padding_mode='zeros',
             max_pool=False,
             pool_shape=None,
             pool_stride=1,
             output_nonlinearity=None,
             output_w_init=nn.init.xavier_uniform_,
             output_b_init=nn.init.zeros_,
             layer_normalization=False,
             name='DiscreteCNNPolicy'):
    """Initialise the policy and the DiscreteCNNModule it wraps.

    The input shape and output dimension of the module are read from
    ``env_spec``; every other network argument is handed to
    ``DiscreteCNNModule`` unchanged.

    Args:
        env_spec: Environment specification. Its
            ``observation_space.shape`` becomes the module input shape and
            its ``action_space.flat_dim`` the module output dimension. It
            is also passed to the base-class initialiser.
        kernel_sizes: Forwarded to ``DiscreteCNNModule``.
        hidden_channels: Forwarded to ``DiscreteCNNModule``.
        strides: Forwarded to ``DiscreteCNNModule``.
        hidden_sizes: Forwarded to ``DiscreteCNNModule``.
        cnn_hidden_nonlinearity: Forwarded to ``DiscreteCNNModule``.
        mlp_hidden_nonlinearity: Forwarded to ``DiscreteCNNModule``.
        hidden_w_init: Forwarded to ``DiscreteCNNModule``.
        hidden_b_init: Forwarded to ``DiscreteCNNModule``.
        paddings: Forwarded to ``DiscreteCNNModule``.
        padding_mode: Forwarded to ``DiscreteCNNModule``.
        max_pool: Forwarded to ``DiscreteCNNModule``.
        pool_shape: Forwarded to ``DiscreteCNNModule``.
        pool_stride: Forwarded to ``DiscreteCNNModule``.
        output_nonlinearity: Forwarded to ``DiscreteCNNModule``.
        output_w_init: Forwarded to ``DiscreteCNNModule``.
        output_b_init: Forwarded to ``DiscreteCNNModule``.
        layer_normalization: Forwarded to ``DiscreteCNNModule``.
        name: Passed to the base-class initialiser together with
            ``env_spec``.

    """
    super().__init__(env_spec, name)

    observation_space = env_spec.observation_space

    self._env_spec = env_spec
    self._input_shape = observation_space.shape
    self._output_dim = env_spec.action_space.flat_dim
    # The module is told whether observations come from an akro.Image space.
    self._is_image = isinstance(observation_space, akro.Image)

    # DiscreteCNNModule is called positionally here, so the order of this
    # tuple is significant and must not be rearranged.
    module_args = (
        self._input_shape,
        self._output_dim,
        kernel_sizes,
        hidden_channels,
        strides,
        hidden_sizes,
        cnn_hidden_nonlinearity,
        mlp_hidden_nonlinearity,
        hidden_w_init,
        hidden_b_init,
        paddings,
        padding_mode,
        max_pool,
        pool_shape,
        pool_stride,
        output_nonlinearity,
        output_w_init,
        output_b_init,
        layer_normalization,
        self._is_image,
    )
    self._cnn_module = DiscreteCNNModule(*module_args)
def __init__(self,
             env_spec,
             image_format,
             *,
             kernel_sizes,
             hidden_channels,
             strides,
             hidden_sizes=(32, 32),
             cnn_hidden_nonlinearity=torch.nn.ReLU,
             mlp_hidden_nonlinearity=torch.nn.ReLU,
             hidden_w_init=nn.init.xavier_uniform_,
             hidden_b_init=nn.init.zeros_,
             paddings=0,
             padding_mode='zeros',
             max_pool=False,
             pool_shape=None,
             pool_stride=1,
             output_nonlinearity=None,
             output_w_init=nn.init.xavier_uniform_,
             output_b_init=nn.init.zeros_,
             layer_normalization=False):
    """Initialise the object and the DiscreteCNNModule it wraps.

    The module's input/output spec is derived from ``env_spec``; every
    other argument is passed to ``DiscreteCNNModule`` by keyword under the
    same name.

    Args:
        env_spec: Environment specification. Its ``observation_space`` and
            ``action_space`` become the input and output spaces of the
            ``InOutSpec`` given to the module.
        image_format: Forwarded to ``DiscreteCNNModule`` (the tests in this
            file pass ``'NCHW'``).
        kernel_sizes: Keyword-only; forwarded to ``DiscreteCNNModule``.
        hidden_channels: Keyword-only; forwarded to ``DiscreteCNNModule``.
        strides: Keyword-only; forwarded to ``DiscreteCNNModule``.
        hidden_sizes: Forwarded to ``DiscreteCNNModule``.
        cnn_hidden_nonlinearity: Forwarded to ``DiscreteCNNModule``.
        mlp_hidden_nonlinearity: Forwarded to ``DiscreteCNNModule``.
        hidden_w_init: Forwarded to ``DiscreteCNNModule``.
        hidden_b_init: Forwarded to ``DiscreteCNNModule``.
        paddings: Forwarded to ``DiscreteCNNModule``.
        padding_mode: Forwarded to ``DiscreteCNNModule``.
        max_pool: Forwarded to ``DiscreteCNNModule``.
        pool_shape: Forwarded to ``DiscreteCNNModule``.
        pool_stride: Forwarded to ``DiscreteCNNModule``.
        output_nonlinearity: Forwarded to ``DiscreteCNNModule``.
        output_w_init: Forwarded to ``DiscreteCNNModule``.
        output_b_init: Forwarded to ``DiscreteCNNModule``.
        layer_normalization: Forwarded to ``DiscreteCNNModule``.

    """
    super().__init__()
    self._env_spec = env_spec

    # Observations are the module input, actions are the module output.
    module_spec = InOutSpec(input_space=env_spec.observation_space,
                            output_space=env_spec.action_space)

    self._cnn_module = DiscreteCNNModule(
        spec=module_spec,
        image_format=image_format,
        kernel_sizes=kernel_sizes,
        hidden_channels=hidden_channels,
        strides=strides,
        hidden_sizes=hidden_sizes,
        cnn_hidden_nonlinearity=cnn_hidden_nonlinearity,
        mlp_hidden_nonlinearity=mlp_hidden_nonlinearity,
        hidden_w_init=hidden_w_init,
        hidden_b_init=hidden_b_init,
        paddings=paddings,
        padding_mode=padding_mode,
        max_pool=max_pool,
        pool_shape=pool_shape,
        pool_stride=pool_stride,
        output_nonlinearity=output_nonlinearity,
        output_w_init=output_w_init,
        output_b_init=output_b_init,
        layer_normalization=layer_normalization,
    )
def test_output_values(output_dim, kernel_sizes, hidden_channels, strides,
                       paddings):
    """Compare DiscreteCNNModule with a hand-built CNNModule + MLPModule.

    A ``DiscreteCNNModule`` and a separately constructed ``CNNModule``
    followed by an ``MLPModule`` are given the same layer configuration and
    ``nn.init.ones_`` weight initialisers, then fed the same random
    observation. The test passes when both paths yield element-wise equal
    outputs.

    Args:
        output_dim: Size of the output space / MLP output.
        kernel_sizes: Kernel sizes given to both CNN-bearing modules.
        hidden_channels: Channels of the CNN layers; also reused as the MLP
            hidden sizes on both sides of the comparison.
        strides: Strides given to both CNN-bearing modules.
        paddings: Paddings given to both CNN-bearing modules.

    """
    height, width, channels = 32, 32, 3
    input_shape = (channels, height, width)

    def unbounded_box(shape):
        # Box space of the given shape with no lower or upper bound.
        return akro.Box(shape=shape, low=-np.inf, high=np.inf)

    spec = InOutSpec(unbounded_box(input_shape),
                     unbounded_box((output_dim, )))
    obs = torch.rand(input_shape)

    # Settings that the combined module and the stand-alone CNN share.
    conv_kwargs = dict(image_format='NCHW',
                       hidden_channels=hidden_channels,
                       kernel_sizes=kernel_sizes,
                       strides=strides,
                       paddings=paddings,
                       padding_mode='zeros',
                       hidden_w_init=nn.init.ones_)

    module = DiscreteCNNModule(spec=spec,
                               hidden_sizes=hidden_channels,
                               output_w_init=nn.init.ones_,
                               **conv_kwargs)

    cnn = CNNModule(spec=InOutSpec(unbounded_box(input_shape), None),
                    **conv_kwargs)

    # Probe the CNN once to learn how many features the MLP receives.
    flat_dim = torch.flatten(cnn(obs).detach(), start_dim=1).shape[1]

    mlp = MLPModule(
        flat_dim,
        output_dim,
        hidden_channels,
        hidden_w_init=nn.init.ones_,
        output_w_init=nn.init.ones_,
    )

    expected = mlp(torch.flatten(cnn(obs), start_dim=1))
    actual = module(obs)

    assert torch.all(torch.eq(expected.detach(), actual.detach()))
def test_output_values(output_dim, kernel_sizes, hidden_channels, strides,
                       paddings):
    """Compare DiscreteCNNModule with a hand-built CNNModule + MLPModule.

    A ``DiscreteCNNModule`` and a separately constructed ``CNNModule``
    followed by an ``MLPModule`` are given the same layer configuration and
    ``nn.init.ones_`` weight initialisers, then fed the same random batch.
    The test passes when both paths yield element-wise equal outputs.

    Args:
        output_dim: Output dimension of the module and of the MLP.
        kernel_sizes: Kernel sizes given to both CNN-bearing modules.
        hidden_channels: Channels of the CNN layers; also reused as the MLP
            hidden sizes on both sides of the comparison.
        strides: Strides given to both CNN-bearing modules.
        paddings: Paddings given to both CNN-bearing modules.

    """
    batch_size = 64
    height, width, channels = 32, 32, 3
    input_shape = (batch_size, channels, height, width)
    obs = torch.rand(input_shape)

    # Settings that the combined module and the stand-alone CNN share.
    conv_kwargs = dict(hidden_channels=hidden_channels,
                       kernel_sizes=kernel_sizes,
                       strides=strides,
                       paddings=paddings,
                       padding_mode='zeros',
                       hidden_w_init=nn.init.ones_,
                       is_image=False)

    module = DiscreteCNNModule(input_shape=input_shape,
                               output_dim=output_dim,
                               hidden_sizes=hidden_channels,
                               output_w_init=nn.init.ones_,
                               **conv_kwargs)

    cnn = CNNModule(input_var=obs, **conv_kwargs)

    # Probe the CNN once to learn how many features the MLP receives.
    flat_dim = torch.flatten(cnn(obs).detach(), start_dim=1).shape[1]

    mlp = MLPModule(
        flat_dim,
        output_dim,
        hidden_channels,
        hidden_w_init=nn.init.ones_,
        output_w_init=nn.init.ones_,
    )

    expected = mlp(torch.flatten(cnn(obs), start_dim=1))
    actual = module(obs)

    assert torch.all(torch.eq(expected.detach(), actual.detach()))
def test_without_nonlinearity(output_dim, hidden_channels, kernel_sizes,
                              strides):
    """Check the sub-module count when both nonlinearities are disabled.

    Builds a ``DiscreteCNNModule`` with ``cnn_hidden_nonlinearity`` and
    ``mlp_hidden_nonlinearity`` set to ``None`` and asserts that its
    ``_module`` container holds exactly three entries.

    Args:
        output_dim: Size of the output space.
        hidden_channels: Channels of the CNN layers; also reused as the MLP
            hidden sizes.
        kernel_sizes: Kernel sizes of the CNN layers.
        strides: Strides of the CNN layers.

    """
    height, width, channels = 32, 32, 3
    input_shape = (channels, height, width)

    observation_box = akro.Box(shape=input_shape, low=-np.inf, high=np.inf)
    output_box = akro.Box(shape=(output_dim, ), low=-np.inf, high=np.inf)
    spec = InOutSpec(observation_box, output_box)

    module = DiscreteCNNModule(spec=spec,
                               image_format='NCHW',
                               hidden_channels=hidden_channels,
                               hidden_sizes=hidden_channels,
                               kernel_sizes=kernel_sizes,
                               strides=strides,
                               mlp_hidden_nonlinearity=None,
                               cnn_hidden_nonlinearity=None,
                               hidden_w_init=nn.init.ones_,
                               output_w_init=nn.init.ones_)

    # NOTE(review): the three entries are presumably the CNN, a flatten
    # step and the MLP - confirm against DiscreteCNNModule.
    num_submodules = len(module._module)
    assert num_submodules == 3
def test_without_nonlinearity(output_dim, hidden_channels, kernel_sizes,
                              strides):
    """Check the sub-module count when both nonlinearities are disabled.

    Builds a ``DiscreteCNNModule`` with ``cnn_hidden_nonlinearity`` and
    ``mlp_hidden_nonlinearity`` set to ``None`` and asserts that its
    ``_module`` container holds exactly three entries.

    Args:
        output_dim: Output dimension of the module.
        hidden_channels: Channels of the CNN layers; also reused as the MLP
            hidden sizes.
        kernel_sizes: Kernel sizes of the CNN layers.
        strides: Strides of the CNN layers.

    """
    batch_size = 64
    height, width, channels = 32, 32, 3
    input_shape = (batch_size, channels, height, width)

    module = DiscreteCNNModule(input_shape=input_shape,
                               output_dim=output_dim,
                               hidden_channels=hidden_channels,
                               hidden_sizes=hidden_channels,
                               kernel_sizes=kernel_sizes,
                               strides=strides,
                               mlp_hidden_nonlinearity=None,
                               cnn_hidden_nonlinearity=None,
                               hidden_w_init=nn.init.ones_,
                               output_w_init=nn.init.ones_,
                               is_image=False)

    # NOTE(review): the three entries are presumably the CNN, a flatten
    # step and the MLP - confirm against DiscreteCNNModule.
    num_submodules = len(module._module)
    assert num_submodules == 3
def test_is_pickleable(output_dim, hidden_channels, kernel_sizes, strides):
    """Check that a DiscreteCNNModule survives a pickle round trip.

    The module is evaluated on an all-ones input, serialised with
    ``pickle.dumps``, restored with ``pickle.loads`` and evaluated again on
    the same input. Both outputs must be equal.

    Args:
        output_dim: Output dimension of the module.
        hidden_channels: Channels of the CNN layers.
        kernel_sizes: Kernel sizes of the CNN layers.
        strides: Strides of the CNN layers.

    """
    batch_size = 64
    input_width = 32
    input_height = 32
    in_channel = 3
    input_shape = (batch_size, in_channel, input_height, input_width)
    input_a = torch.ones(input_shape)

    model = DiscreteCNNModule(input_shape=input_shape,
                              output_dim=output_dim,
                              hidden_channels=hidden_channels,
                              kernel_sizes=kernel_sizes,
                              mlp_hidden_nonlinearity=nn.ReLU,
                              cnn_hidden_nonlinearity=nn.ReLU,
                              strides=strides)
    output1 = model(input_a)

    # The payload is produced right here, so unpickling it is safe.
    model_pickled = pickle.loads(pickle.dumps(model))
    output2 = model_pickled(input_a)

    # torch.equal compares shape and every element directly, instead of
    # routing a torch boolean through NumPy to compare it with True.
    assert torch.equal(output1, output2)
def test_is_pickleable(output_dim, hidden_channels, kernel_sizes, strides):
    """Check that a DiscreteCNNModule survives a pickle round trip.

    The module is evaluated on an all-ones input, serialised with
    ``pickle.dumps``, restored with ``pickle.loads`` and evaluated again on
    the same input. Both outputs must be equal.

    Args:
        output_dim: Size of the output space.
        hidden_channels: Channels of the CNN layers.
        kernel_sizes: Kernel sizes of the CNN layers.
        strides: Strides of the CNN layers.

    """
    input_width = 32
    input_height = 32
    in_channel = 3
    input_shape = (in_channel, input_height, input_width)
    input_a = torch.ones(input_shape)

    spec = InOutSpec(
        akro.Box(shape=input_shape, low=-np.inf, high=np.inf),
        akro.Box(shape=(output_dim, ), low=-np.inf, high=np.inf))

    model = DiscreteCNNModule(spec=spec,
                              image_format='NCHW',
                              hidden_channels=hidden_channels,
                              kernel_sizes=kernel_sizes,
                              mlp_hidden_nonlinearity=nn.ReLU,
                              cnn_hidden_nonlinearity=nn.ReLU,
                              strides=strides)
    output1 = model(input_a)

    # The payload is produced right here, so unpickling it is safe.
    model_pickled = pickle.loads(pickle.dumps(model))
    output2 = model_pickled(input_a)

    # torch.equal compares shape and every element directly, instead of
    # routing a torch boolean through NumPy to compare it with True.
    assert torch.equal(output1, output2)