def test_initializer_variance(self, num_spatial_dims, kernel_shape, in_channels, output_channels, data_format):
    """Checks the default weight init produces roughly the expected stddev.

    Builds a ConvNDTranspose, runs a forward pass to create the parameters,
    then compares the empirical stddev of the kernel against the analytic
    1/sqrt(fan_in) target.
    """
    module = conv.ConvNDTranspose(
        num_spatial_dims=num_spatial_dims,
        kernel_shape=kernel_shape,
        output_channels=output_channels,
        data_format=data_format)
    # Forward pass with a dummy batch so the parameters get created.
    dummy_batch = jnp.ones([16] + ([32] * num_spatial_dims) + [in_channels])
    module(dummy_batch)
    kernel = module.params_dict()["conv_nd_transpose/w"]
    measured_std = kernel.std()
    fan_in = np.prod(kernel_shape + (in_channels,))
    target_std = 1 / np.sqrt(fan_in)
    # This ratio of the error compared to the expected std might be somewhere
    # around 0.15 normally. We check it is not > 0.5, as that would indicate
    # something seriously wrong (ie the previous buggy initialization).
    relative_error = np.abs(measured_std - target_std) / target_std
    self.assertLess(relative_error, 0.5)
def f():
    """Applies ConvNDTranspose to a zero batch laid out channels-first."""
    # NCHW-style layout: batch, channels, then n spatial dims of size 16.
    x = jnp.zeros([2, 4] + [16] * n)
    module = conv.ConvNDTranspose(
        n,
        output_channels=3,
        kernel_shape=3,
        data_format="channels_first")
    return module(x)
def f():
    """Applies a stride-3 ConvNDTranspose to a zero channels-last batch."""
    # Channels-last layout: batch, n spatial dims of size 8, then channels.
    x = jnp.zeros([2] + [8] * n + [4])
    module = conv.ConvNDTranspose(
        n,
        output_channels=3,
        kernel_shape=3,
        stride=3)
    return module(x)
def f():
    """Applies a VALID-padded ConvNDTranspose to a zero channels-last batch."""
    # Channels-last layout: batch, n spatial dims of size 16, then channels.
    x = jnp.zeros([2] + [16] * n + [4])
    module = conv.ConvNDTranspose(
        n,
        output_channels=3,
        kernel_shape=3,
        padding="VALID")
    return module(x)