Exemplo n.º 1
0
    def test_initializer_variance(self, num_spatial_dims, kernel_shape,
                                  in_channels, output_channels, data_format):
        """Default ConvNDTranspose weights should have std close to 1/sqrt(fan_in).

        fan_in is taken as prod(kernel_shape) * in_channels. The module is
        built without an explicit weight initializer and applied once to an
        all-ones input. Its "conv_nd_transpose/w" parameter is then read via
        params_dict() and its empirical std is compared against that target.
        """
        module = conv.ConvNDTranspose(
            num_spatial_dims=num_spatial_dims,
            kernel_shape=kernel_shape,
            output_channels=output_channels,
            data_format=data_format)

        # Batch of 16, every spatial axis of size 32, in_channels on the last
        # axis. NOTE(review): this layout presumes a channels-last data_format;
        # confirm against the parameterization of this test.
        input_shape = [16] + [32] * num_spatial_dims + [in_channels]
        module(jnp.ones(input_shape))

        weights = module.params_dict()["conv_nd_transpose/w"]
        # kernel_shape is concatenated with a tuple, so it must be a tuple here.
        fan_in = np.prod(kernel_shape + (in_channels,))
        target_std = 1 / np.sqrt(fan_in)
        observed_std = weights.std()

        # A relative error of roughly 0.15 is typical. Anything above 0.5
        # indicates a seriously wrong initialization (e.g. the earlier buggy one).
        relative_error = np.abs(observed_std - target_std) / target_std
        self.assertLess(relative_error, 0.5)
Exemplo n.º 2
0
 def f():
   """Applies a channels_first ConvNDTranspose to an all-zeros batch.

   Uses ``n`` (the number of spatial dimensions) from the enclosing scope.
   The input has batch size 2, 4 on axis 1 (the channel axis under
   "channels_first"), and every spatial axis of size 16.
   """
   module = conv.ConvNDTranspose(
       n, output_channels=3, kernel_shape=3, data_format="channels_first")
   zeros = jnp.zeros([2, 4] + n * [16])
   return module(zeros)
Exemplo n.º 3
0
 def f():
   """Applies a ConvNDTranspose with stride 3 to an all-zeros batch.

   Uses ``n`` (the number of spatial dimensions) from the enclosing scope.
   The input has batch size 2, every spatial axis of size 8, and 4 on the
   last axis.
   """
   module = conv.ConvNDTranspose(
       n, output_channels=3, kernel_shape=3, stride=3)
   zeros = jnp.zeros([2] + n * [8] + [4])
   return module(zeros)
Exemplo n.º 4
0
 def f():
   """Applies a ConvNDTranspose with "VALID" padding to an all-zeros batch.

   Uses ``n`` (the number of spatial dimensions) from the enclosing scope.
   The input has batch size 2, every spatial axis of size 16, and 4 on the
   last axis.
   """
   module = conv.ConvNDTranspose(
       n, output_channels=3, kernel_shape=3, padding="VALID")
   zeros = jnp.zeros([2] + n * [16] + [4])
   return module(zeros)