def test_output_values(self, kernel_sizes, hidden_channels, strides,
                       paddings):
    """Test output values from CNNBaseModule.

    Args:
        kernel_sizes (tuple[int]): Kernel sizes.
        hidden_channels (tuple[int]): hidden channels.
        strides (tuple[int]): strides.
        paddings (tuple[int]): value of zero-padding.

    """
    module_with_nonlinear_function_and_module = CNNModule(
        self.input_spec,
        image_format='NCHW',
        hidden_channels=hidden_channels,
        kernel_sizes=kernel_sizes,
        strides=strides,
        paddings=paddings,
        padding_mode='zeros',
        hidden_nonlinearity=torch.relu,
        hidden_w_init=nn.init.xavier_uniform_)
    module_with_nonlinear_module_instance_and_function = CNNModule(
        self.input_spec,
        image_format='NCHW',
        hidden_channels=hidden_channels,
        kernel_sizes=kernel_sizes,
        strides=strides,
        paddings=paddings,
        padding_mode='zeros',
        hidden_nonlinearity=nn.ReLU(),
        hidden_w_init=nn.init.xavier_uniform_)

    output1 = module_with_nonlinear_function_and_module(self.input)
    output2 = module_with_nonlinear_module_instance_and_function(self.input)

    # Trace the spatial size through each conv layer using the standard
    # formula: out = ((W - F + 2P) / S) + 1 (square input assumed).
    current_size = self.input_width
    for (filter_size, stride, padding) in zip(kernel_sizes, strides,
                                              paddings):
        current_size = int(
            (current_size - filter_size + padding * 2) / stride) + 1
    flatten_shape = current_size * current_size * hidden_channels[-1]
    expected_output = torch.zeros((self.batch_size, flatten_shape))

    # torch.equal checks shape AND values; the previous
    # np.array_equal(torch.all(torch.eq(...)), True) indirection did not
    # catch shape mismatches hidden by broadcasting.
    assert torch.equal(output1, expected_output)
    assert torch.equal(output2, expected_output)
def test_output_values_with_unequal_stride_with_padding(
        self, hidden_channels, kernel_sizes, strides, paddings):
    """Test output values with unequal stride and padding from CNNModule.

    Args:
        hidden_channels (tuple[int]): hidden channels.
        kernel_sizes (tuple[int]): Kernel sizes.
        strides (tuple[int]): strides.
        paddings (tuple[int]): value of zero-padding.

    """
    model = CNNModule(input_var=self.input,
                      hidden_channels=hidden_channels,
                      kernel_sizes=kernel_sizes,
                      strides=strides,
                      paddings=paddings,
                      padding_mode='zeros',
                      hidden_nonlinearity=torch.relu,
                      hidden_w_init=nn.init.xavier_uniform_)
    output = model(self.input)

    # Per-layer conv output size: ((W - F + 2P) / S) + 1
    # (square input assumed).
    current_size = self.input_width
    for (filter_size, stride, padding) in zip(kernel_sizes, strides,
                                              paddings):
        current_size = int(
            (current_size - filter_size + padding * 2) / stride) + 1
    flatten_shape = current_size * current_size * hidden_channels[-1]
    expected_output = torch.zeros((self.batch_size, flatten_shape))

    # torch.equal verifies shape and values in one call, replacing the
    # convoluted np.array_equal(torch.all(torch.eq(...)), True) pattern.
    assert torch.equal(output, expected_output)
def __init__(self,
             input_shape,
             output_dim,
             kernel_sizes,
             hidden_channels,
             strides,
             hidden_sizes=(32, 32),
             cnn_hidden_nonlinearity=nn.ReLU,
             mlp_hidden_nonlinearity=nn.ReLU,
             hidden_w_init=nn.init.xavier_uniform_,
             hidden_b_init=nn.init.zeros_,
             paddings=0,
             padding_mode='zeros',
             max_pool=False,
             pool_shape=None,
             pool_stride=1,
             output_nonlinearity=None,
             output_w_init=nn.init.xavier_uniform_,
             output_b_init=nn.init.zeros_,
             layer_normalization=False,
             is_image=True):
    """Compose a CNN feature extractor with an MLP head.

    A zero tensor of ``input_shape`` is pushed through the CNN once
    (without tracking gradients) to discover the flattened feature
    width that the MLP must accept.
    """
    super().__init__()

    # Dummy input serves double duty: it parameterizes the CNN and is
    # used to probe the CNN's flattened output dimension.
    dummy_input = torch.zeros(input_shape)
    cnn = CNNModule(input_var=dummy_input,
                    kernel_sizes=kernel_sizes,
                    strides=strides,
                    hidden_w_init=hidden_w_init,
                    hidden_b_init=hidden_b_init,
                    hidden_channels=hidden_channels,
                    hidden_nonlinearity=cnn_hidden_nonlinearity,
                    paddings=paddings,
                    padding_mode=padding_mode,
                    max_pool=max_pool,
                    layer_normalization=layer_normalization,
                    pool_shape=pool_shape,
                    pool_stride=pool_stride,
                    is_image=is_image)

    with torch.no_grad():
        flat_dim = torch.flatten(cnn(dummy_input), start_dim=1).shape[1]

    mlp = MLPModule(flat_dim,
                    output_dim,
                    hidden_sizes,
                    hidden_nonlinearity=mlp_hidden_nonlinearity,
                    hidden_w_init=hidden_w_init,
                    hidden_b_init=hidden_b_init,
                    output_nonlinearity=output_nonlinearity,
                    output_w_init=output_w_init,
                    output_b_init=output_b_init,
                    layer_normalization=layer_normalization)

    # Only interpose the extra non-linearity between CNN and flatten
    # when one was requested.
    layers = [cnn]
    if mlp_hidden_nonlinearity is not None:
        layers.append(mlp_hidden_nonlinearity())
    layers.extend([nn.Flatten(), mlp])
    self._module = nn.Sequential(*layers)
def __init__(self,
             spec,
             image_format,
             *,
             kernel_sizes,
             hidden_channels,
             strides,
             hidden_sizes=(32, 32),
             cnn_hidden_nonlinearity=nn.ReLU,
             mlp_hidden_nonlinearity=nn.ReLU,
             hidden_w_init=nn.init.xavier_uniform_,
             hidden_b_init=nn.init.zeros_,
             paddings=0,
             padding_mode='zeros',
             max_pool=False,
             pool_shape=None,
             pool_stride=1,
             output_nonlinearity=None,
             output_w_init=nn.init.xavier_uniform_,
             output_b_init=nn.init.zeros_,
             layer_normalization=False):
    """Build a CNN followed by an MLP from an input/output spec.

    The CNN receives ``spec.input_space`` and infers its own output
    space; the MLP then maps those flattened features to
    ``spec.output_space``.
    """
    super().__init__()

    # Give the CNN only the input space; it derives its output space.
    cnn = CNNModule(spec=InOutSpec(input_space=spec.input_space,
                                   output_space=None),
                    image_format=image_format,
                    kernel_sizes=kernel_sizes,
                    strides=strides,
                    hidden_w_init=hidden_w_init,
                    hidden_b_init=hidden_b_init,
                    hidden_channels=hidden_channels,
                    hidden_nonlinearity=cnn_hidden_nonlinearity,
                    paddings=paddings,
                    padding_mode=padding_mode,
                    max_pool=max_pool,
                    layer_normalization=layer_normalization,
                    pool_shape=pool_shape,
                    pool_stride=pool_stride)

    # Flattened CNN feature width feeds the MLP; the MLP's output width
    # comes from the overall spec.
    mlp = MLPModule(cnn.spec.output_space.flat_dim,
                    spec.output_space.flat_dim,
                    hidden_sizes,
                    hidden_nonlinearity=mlp_hidden_nonlinearity,
                    hidden_w_init=hidden_w_init,
                    hidden_b_init=hidden_b_init,
                    output_nonlinearity=output_nonlinearity,
                    output_w_init=output_w_init,
                    output_b_init=output_b_init,
                    layer_normalization=layer_normalization)

    # Only interpose the extra non-linearity when one was requested.
    layers = [cnn]
    if mlp_hidden_nonlinearity is not None:
        layers.append(mlp_hidden_nonlinearity())
    layers.extend([nn.Flatten(), mlp])
    self._module = nn.Sequential(*layers)
def __init__(self,
             env_spec,
             image_format,
             kernel_sizes,
             *,
             hidden_channels,
             strides=1,
             hidden_sizes=(32, 32),
             hidden_nonlinearity=torch.tanh,
             hidden_w_init=nn.init.xavier_uniform_,
             hidden_b_init=nn.init.zeros_,
             paddings=0,
             padding_mode='zeros',
             max_pool=False,
             pool_shape=None,
             pool_stride=1,
             output_w_init=nn.init.xavier_uniform_,
             output_b_init=nn.init.zeros_,
             layer_normalization=False,
             name='CategoricalCNNPolicy'):
    """Initialize CategoricalCNNPolicy: a CNN feature extractor followed
    by a single-headed MLP producing logits over a discrete action space.

    Raises:
        ValueError: If the action space is not akro.Discrete, or if the
            observation space is an akro.Dict.
    """
    if not isinstance(env_spec.action_space, akro.Discrete):
        # Fixed: the message previously named CategoricalMLPPolicy,
        # which is a different class.
        raise ValueError('CategoricalCNNPolicy only works '
                         'with akro.Discrete action space.')
    if isinstance(env_spec.observation_space, akro.Dict):
        # Fixed grammar: was "do not support with akro.Dict ...".
        raise ValueError('CNN policies do not support '
                         'akro.Dict observation spaces.')
    super().__init__(env_spec, name)

    # CNN derives its output space from the observation space alone.
    self._cnn_module = CNNModule(InOutSpec(
        self._env_spec.observation_space, None),
                                 image_format=image_format,
                                 kernel_sizes=kernel_sizes,
                                 strides=strides,
                                 hidden_channels=hidden_channels,
                                 hidden_w_init=hidden_w_init,
                                 hidden_b_init=hidden_b_init,
                                 hidden_nonlinearity=hidden_nonlinearity,
                                 paddings=paddings,
                                 padding_mode=padding_mode,
                                 max_pool=max_pool,
                                 pool_shape=pool_shape,
                                 pool_stride=pool_stride,
                                 layer_normalization=layer_normalization)
    # Single head: one set of logits sized to the discrete action space.
    self._mlp_module = MultiHeadedMLPModule(
        n_heads=1,
        input_dim=self._cnn_module.spec.output_space.flat_dim,
        output_dims=[self._env_spec.action_space.flat_dim],
        hidden_sizes=hidden_sizes,
        hidden_w_init=hidden_w_init,
        hidden_b_init=hidden_b_init,
        hidden_nonlinearity=hidden_nonlinearity,
        output_w_inits=output_w_init,
        output_b_inits=output_b_init)
def test_dueling_output_values(output_dim, kernel_sizes, hidden_channels,
                               strides, paddings):
    """Dueling module output must equal a hand-built value/advantage net."""
    batch_size = 64
    in_channel, input_height, input_width = 3, 32, 32
    input_shape = (batch_size, in_channel, input_height, input_width)
    obs = torch.rand(input_shape)

    dueling = DiscreteDuelingCNNModule(input_shape=input_shape,
                                       output_dim=output_dim,
                                       hidden_channels=hidden_channels,
                                       hidden_sizes=hidden_channels,
                                       kernel_sizes=kernel_sizes,
                                       strides=strides,
                                       paddings=paddings,
                                       padding_mode='zeros',
                                       hidden_w_init=nn.init.ones_,
                                       output_w_init=nn.init.ones_,
                                       is_image=False)

    # Rebuild the same network by hand: shared CNN trunk plus separate
    # advantage and value MLP heads, all with identical (ones) init.
    cnn = CNNModule(input_var=obs,
                    hidden_channels=hidden_channels,
                    kernel_sizes=kernel_sizes,
                    strides=strides,
                    paddings=paddings,
                    padding_mode='zeros',
                    hidden_w_init=nn.init.ones_,
                    is_image=False)
    flat_dim = torch.flatten(cnn(obs).detach(), start_dim=1).shape[1]
    adv_head = MLPModule(flat_dim,
                         output_dim,
                         hidden_channels,
                         hidden_w_init=nn.init.ones_,
                         output_w_init=nn.init.ones_)
    val_head = MLPModule(flat_dim,
                         1,
                         hidden_channels,
                         hidden_w_init=nn.init.ones_,
                         output_w_init=nn.init.ones_)

    features = torch.flatten(cnn(obs), start_dim=1)
    val = val_head(features)
    adv = adv_head(features)
    # Dueling aggregation: Q = V + (A - mean(A)).
    expected = val + (adv - adv.mean(1).unsqueeze(1))

    assert torch.all(torch.eq(expected.detach(), dueling(obs).detach()))
def test_output_with_max_pooling(self, kernel_sizes, hidden_channels,
                                 strides, pool_shape, pool_stride):
    """Pooled CNN features times a zero FC layer keep the expected shape."""
    model = CNNModule(input_var=self.input,
                      hidden_channels=hidden_channels,
                      kernel_sizes=kernel_sizes,
                      strides=strides,
                      max_pool=True,
                      pool_shape=(pool_shape, pool_shape),
                      pool_stride=(pool_stride, pool_stride))
    features = model(self.input)

    # Project the flattened features through an all-zero fully-connected
    # layer; only the output shape matters here.
    weight = torch.zeros((features.shape[1], 10))
    bias = torch.zeros(10)
    projected = features.mm(weight) + bias

    assert projected.size() == torch.Size([64, 10])
def test_set_output_size(kernel_sizes, hidden_channels, strides, pool_shape,
                         pool_stride):
    """When the spec fixes an output space, the CNN must emit that size."""
    in_space = akro.Box(shape=[3, 19, 15], high=np.inf, low=-np.inf)
    out_space = akro.Box(shape=[200], high=np.inf, low=-np.inf)
    model = CNNModule(InOutSpec(in_space, out_space),
                      image_format='NCHW',
                      hidden_channels=hidden_channels,
                      kernel_sizes=kernel_sizes,
                      strides=strides,
                      pool_shape=[(pool_shape, pool_shape)],
                      pool_stride=[(pool_stride, pool_stride)],
                      layer_normalization=True)

    batch = torch.ones(10, 3, 19, 15)
    assert model(batch).shape == (10, 200)
def test_no_head_invalid_settings(self, hidden_nonlinear):
    """Check CNNModule throws exception with invalid non-linear functions.

    Args:
        hidden_nonlinear (callable or torch.nn.Module): Non-linear
            functions for hidden layers.

    """
    # Constructing the module with an unsupported non-linearity must
    # raise a ValueError whose message matches this pattern.
    with pytest.raises(ValueError,
                       match='Non linear function .* is not supported'):
        CNNModule(input_var=self.input,
                  hidden_channels=(32, ),
                  kernel_sizes=(3, ),
                  strides=(1, ),
                  hidden_nonlinearity=hidden_nonlinear)
def test_output_values(output_dim, kernel_sizes, hidden_channels, strides,
                       paddings):
    """DiscreteCNNModule output must equal a hand-wired CNN + MLP."""
    input_shape = (3, 32, 32)
    in_space = akro.Box(shape=input_shape, low=-np.inf, high=np.inf)
    out_space = akro.Box(shape=(output_dim, ), low=-np.inf, high=np.inf)
    obs = torch.rand(input_shape)

    module = DiscreteCNNModule(spec=InOutSpec(in_space, out_space),
                               image_format='NCHW',
                               hidden_channels=hidden_channels,
                               hidden_sizes=hidden_channels,
                               kernel_sizes=kernel_sizes,
                               strides=strides,
                               paddings=paddings,
                               padding_mode='zeros',
                               hidden_w_init=nn.init.ones_,
                               output_w_init=nn.init.ones_)

    # Equivalent manual pipeline with identical (all-ones) init: a CNN
    # trunk whose output space is inferred, flattened into an MLP.
    cnn = CNNModule(spec=InOutSpec(in_space, None),
                    image_format='NCHW',
                    hidden_channels=hidden_channels,
                    kernel_sizes=kernel_sizes,
                    strides=strides,
                    paddings=paddings,
                    padding_mode='zeros',
                    hidden_w_init=nn.init.ones_)
    flat_dim = torch.flatten(cnn(obs).detach(), start_dim=1).shape[1]
    mlp = MLPModule(flat_dim,
                    output_dim,
                    hidden_channels,
                    hidden_w_init=nn.init.ones_,
                    output_w_init=nn.init.ones_)
    expected = mlp(torch.flatten(cnn(obs), start_dim=1))

    assert torch.all(torch.eq(expected.detach(), module(obs).detach()))
def test_is_pickleable(self, hidden_channels, kernel_sizes, strides):
    """Check CNNModule is pickleable.

    Args:
        hidden_channels (tuple[int]): hidden channels.
        kernel_sizes (tuple[int]): Kernel sizes.
        strides (tuple[int]): strides.

    """
    model = CNNModule(input_var=self.input,
                      hidden_channels=hidden_channels,
                      kernel_sizes=kernel_sizes,
                      strides=strides)
    output1 = model(self.input)

    # Round-trip the model through pickle and run the clone on the same
    # input.
    model_pickled = pickle.loads(pickle.dumps(model))
    output2 = model_pickled(self.input)

    # torch.equal checks shape AND values; the previous
    # np.array_equal(torch.all(torch.eq(...)), True) indirection did not
    # catch shape mismatches hidden by broadcasting.
    assert torch.equal(output1, output2)