Example #1
    def __call__(self, z, train: bool = True):
        # Common arguments
        conv_kwargs = {
            'kernel_size': (4, 4),
            'strides': (2, 2),
            'padding': 'SAME',
            'use_bias': False,
            'kernel_init': he_normal()
        }
        norm_kwargs = {
            'use_running_average': not train,
            'momentum': 0.99,
            'epsilon': 0.001,
            'use_scale': True,
            'use_bias': True
        }

        # Reshape the latent vector to an unbatched 1x1 spatial map (HWC)
        z = jnp.reshape(z, (1, 1, self.zdim))

        # Layer 1
        z = nn.ConvTranspose(features=512,
                             kernel_size=(4, 4),
                             strides=(1, 1),
                             padding='VALID',
                             use_bias=False,
                             kernel_init=he_normal())(z)
        z = nn.BatchNorm(**norm_kwargs)(z)
        z = nn.leaky_relu(z, 0.2)

        # Layer 2
        z = nn.ConvTranspose(features=256, **conv_kwargs)(z)
        z = nn.BatchNorm(**norm_kwargs)(z)
        z = nn.leaky_relu(z, 0.2)

        # Layer 3
        z = nn.ConvTranspose(features=128, **conv_kwargs)(z)
        z = nn.BatchNorm(**norm_kwargs)(z)
        z = nn.leaky_relu(z, 0.2)

        # Layer 4
        z = nn.ConvTranspose(features=64, **conv_kwargs)(z)
        z = nn.BatchNorm(**norm_kwargs)(z)
        z = nn.leaky_relu(z, 0.2)

        # Layer 5
        z = nn.ConvTranspose(features=1,
                             kernel_size=(4, 4),
                             strides=(2, 2),
                             padding='SAME',
                             use_bias=False,
                             kernel_init=nn.initializers.xavier_normal())(z)
        # x = nn.sigmoid(z)
        x = nn.softplus(z)

        return jnp.rot90(jnp.squeeze(x), k=2)  # Rotate 180 degrees to match the TF output orientation
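A minimal usage sketch, assuming the method above belongs to a Flax module (called Generator here, with a zdim attribute; both names are assumptions). Because nn.BatchNorm keeps running statistics, training-mode calls must mark the batch_stats collection as mutable:

import jax
import jax.numpy as jnp

model = Generator(zdim=100)  # hypothetical wrapper module
z = jax.random.normal(jax.random.PRNGKey(0), (100,))
variables = model.init(jax.random.PRNGKey(1), z, train=False)
# Training mode: BatchNorm updates its running statistics.
x, updates = model.apply(variables, z, train=True, mutable=['batch_stats'])
# Inference mode: use the stored running averages.
x = model.apply(variables, z, train=False)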
Example #2
  def __call__(self, z):
    shape_before_flattening, flatten_out_size = self.flatten_enc_shape()

    x = nn.Dense(flatten_out_size, name='fc1')(z)
    x = x.reshape((x.shape[0], *shape_before_flattening[1:]))
    
    hidden_dims = self.hidden_dims[::-1]  # Reverse the encoder dims for decoding
    # Build decoder: upsample through all but the last hidden dim
    for h_dim in hidden_dims[:-1]:
      x = nn.ConvTranspose(features=h_dim, kernel_size=(3, 3), strides=(2, 2))(x)
      x = nn.GroupNorm()(x)
      x = nn.gelu(x)
    
    x = nn.ConvTranspose(features=3, kernel_size=(3, 3), strides=(2,2))(x)
    x = nn.sigmoid(x)
    return x
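With Flax's default padding='SAME', a stride-(2, 2) ConvTranspose doubles each spatial dimension, which is what lets this decoder undo the encoder's downsampling. A standalone shape check (the input shape is illustrative):

import jax
import jax.numpy as jnp
import flax.linen as nn

layer = nn.ConvTranspose(features=3, kernel_size=(3, 3), strides=(2, 2))
y, _ = layer.init_with_output(jax.random.PRNGKey(0), jnp.zeros((1, 7, 7, 16)))
print(y.shape)  # (1, 14, 14, 3)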
Example #3
 def __call__(self, z):
     shape_before_flattening, flatten_out_size = self.flatten_enc_shape()
     x = nn.Dense(flatten_out_size, name='fc1')(z)
     x = nn.gelu(x)
     x = x.reshape((x.shape[0], *shape_before_flattening[1:]))
     x = nn.ConvTranspose(features=32, kernel_size=(3, 3),
                          strides=(2, 2))(x)
     x = nn.GroupNorm(32)(x)
     x = nn.gelu(x)
     x = nn.ConvTranspose(features=28, kernel_size=(3, 3),
                          strides=(2, 2))(x)
     x = nn.GroupNorm(28)(x)
     x = nn.gelu(x)
     x = nn.ConvTranspose(features=1, kernel_size=(3, 3), strides=(2, 2))(x)
     return x
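In Flax, the first argument of nn.GroupNorm is num_groups, and it must divide the channel count evenly; nn.GroupNorm(28) over 28 channels therefore puts each channel in its own group, which behaves like instance normalization. A standalone check:

import jax
import jax.numpy as jnp
import flax.linen as nn

gn = nn.GroupNorm(num_groups=28)
y, _ = gn.init_with_output(jax.random.PRNGKey(0),
                           jax.random.normal(jax.random.PRNGKey(1), (1, 14, 14, 28)))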
Example #4
 def basic_module(self, x):
     x = nn.Conv(features=90,
                 kernel_size=(9, 9),
                 padding='VALID',
                 dtype=jp.float64)(x)
     x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
     x = nn.ConvTranspose(features=1,
                          kernel_size=(2, 2),
                          strides=(2, 2),
                          dtype=jp.float64)(x)
     x = x.reshape(x.shape[0], -1)  # Flatten each example
     x = jp.prod(x, 1)  # Product over all flattened outputs
     return x
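Note that the dtype=jp.float64 arguments (jp presumably aliases jax.numpy) only produce true 64-bit computation if JAX's x64 mode is enabled before any arrays are created; otherwise JAX falls back to float32:

from jax import config
config.update('jax_enable_x64', True)  # must run before any JAX arrays are created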
Example #5
  def __call__(self, x):
    """Forward function.

    Args:
        x: Input 4D tensor of shape `(N, H, W, in_chans)`.

    Returns:
        jnp.ndarray: Output tensor of shape `(N, H*2, W*2, out_chans)`.
    """
    x = nn.ConvTranspose(
        self.out_chans, kernel_size=(2, 2), strides=(2, 2), use_bias=False)(
            x)
    x = _simple_instance_norm2d(x, (1, 2))
    x = jax.nn.leaky_relu(x, negative_slope=0.2)

    return x
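_simple_instance_norm2d is a private helper not shown in this snippet; a minimal stand-in consistent with its (1, 2) axes argument (the spatial dimensions of an NHWC tensor) could look like:

import jax.numpy as jnp

def _simple_instance_norm2d(x, axes, eps=1e-5):
    # Normalize each (sample, channel) slice over the spatial axes.
    mean = jnp.mean(x, axis=axes, keepdims=True)
    var = jnp.var(x, axis=axes, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps)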
Example #6
    def __call__(self, x, train):
        del train
        encoder_keys = [
            'filter_sizes',
            'kernel_sizes',
            'kernel_paddings',
            'window_sizes',
            'window_paddings',
            'strides',
            'activations',
        ]
        if len(set(len(self.encoder[k]) for k in encoder_keys)) > 1:
            raise ValueError(
                'The elements in encoder dict do not have the same length.')

        decoder_keys = [
            'filter_sizes',
            'kernel_sizes',
            'window_sizes',
            'paddings',
            'activations',
        ]
        if len(set(len(self.decoder[k]) for k in decoder_keys)) > 1:
            raise ValueError(
                'The elements in decoder dict do not have the same length.')

        # encoder
        for i in range(len(self.encoder['filter_sizes'])):
            x = nn.Conv(self.encoder['filter_sizes'][i],
                        self.encoder['kernel_sizes'][i],
                        padding=self.encoder['kernel_paddings'][i])(x)
            x = model_utils.ACTIVATIONS[self.encoder['activations'][i]](x)
            x = nn.max_pool(x,
                            self.encoder['window_sizes'][i],
                            strides=self.encoder['strides'][i],
                            padding=self.encoder['window_paddings'][i])

        # decoder
        for i in range(len(self.decoder['filter_sizes'])):
            x = nn.ConvTranspose(self.decoder['filter_sizes'][i],
                                 self.decoder['kernel_sizes'][i],
                                 self.decoder['window_sizes'][i],  # 3rd positional arg: strides
                                 padding=self.decoder['paddings'][i])(x)
            x = model_utils.ACTIVATIONS[self.decoder['activations'][i]](x)
        return x
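A hypothetical decoder dict that would satisfy the loop above (all names and values are illustrative, not from the source); note that each window_sizes entry is passed as the third positional argument of nn.ConvTranspose, i.e. as its strides:

decoder = {
    'filter_sizes': [64, 1],
    'kernel_sizes': [(3, 3), (3, 3)],
    'window_sizes': [(2, 2), (2, 2)],   # becomes ConvTranspose strides
    'paddings': ['SAME', 'SAME'],
    'activations': ['relu', 'sigmoid'],  # keys into model_utils.ACTIVATIONS (assumed)
}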
Example #7
 def test_single_input_conv_transpose(self):
     rng = dict(params=random.PRNGKey(0))
     x = jnp.ones((8, 3))
     conv_transpose_module = nn.ConvTranspose(
         features=4,
         kernel_size=(3, ),
         padding='VALID',
         kernel_init=initializers.ones,
         bias_init=initializers.ones,
     )
     y, initial_params = conv_transpose_module.init_with_output(rng, x)
     self.assertEqual(initial_params['params']['kernel'].shape, (3, 3, 4))
     correct_ans = np.array([[4., 4., 4., 4.], [7., 7., 7., 7.],
                             [10., 10., 10., 10.], [10., 10., 10., 10.],
                             [10., 10., 10., 10.], [10., 10., 10., 10.],
                             [10., 10., 10., 10.], [10., 10., 10., 10.],
                             [7., 7., 7., 7.], [4., 4., 4., 4.]])
     np.testing.assert_allclose(y, correct_ans)
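The expected values can be derived by hand: with padding='VALID' and the default unit stride, a transposed convolution of a length-8 input produces 8 + 3 - 1 = 10 positions, and each output sums the overlapping all-ones kernel taps over the 3 input channels plus the all-ones bias: 1·3 + 1 = 4 at the two ends, 2·3 + 1 = 7 one step in, and 3·3 + 1 = 10 in the interior.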
Example #8
 def setup(self):
     # If bilinear, upsample by resizing and use ordinary convolutions to reduce the number of channels
     if self.bilinear:
         self.up = lambda x: jax.image.resize(
             x, [x.shape[0], x.shape[1] * 2, x.shape[2] * 2, x.shape[3]],
             method='bilinear')
         self.conv = DoubleConv(self.in_channels, self.out_channels,
                                self.in_channels // 2, self.test,
                                self.group_norm, self.num_groups,
                                self.activation)
     else:
          self.up = nn.ConvTranspose(self.in_channels // 2,
                                     kernel_size=(2, 2),
                                     strides=(2, 2))
         self.conv = DoubleConv(self.in_channels, self.out_channels,
                                self.out_channels, self.test,
                                self.group_norm, self.num_groups,
                                self.activation)
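A standalone shape check for the transposed-convolution branch (sizes are illustrative): with Flax's default padding='SAME', a (2, 2)-strided ConvTranspose doubles H and W while the features argument halves the channels, as in a typical U-Net up block:

import jax
import jax.numpy as jnp
import flax.linen as nn

up = nn.ConvTranspose(features=64, kernel_size=(2, 2), strides=(2, 2))
y, _ = up.init_with_output(jax.random.PRNGKey(0), jnp.zeros((1, 16, 16, 128)))
print(y.shape)  # (1, 32, 32, 64)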