def __call__(self, z, train: bool = True):
    # Assumes: import jax.numpy as jnp; import flax.linen as nn;
    #          from flax.linen.initializers import he_normal
    # Arguments shared by the intermediate transposed-conv layers.
    conv_kwargs = {
        'kernel_size': (4, 4),
        'strides': (2, 2),
        'padding': 'SAME',
        'use_bias': False,
        'kernel_init': he_normal(),
    }
    norm_kwargs = {
        'use_running_average': not train,
        'momentum': 0.99,
        'epsilon': 0.001,
        'use_scale': True,
        'use_bias': True,
    }

    # Treat the latent vector as a 1x1 "image" with zdim channels.
    z = jnp.reshape(z, (1, 1, self.zdim))

    # Layer 1: 1x1 -> 4x4 ('VALID' padding, stride 1).
    z = nn.ConvTranspose(features=512, kernel_size=(4, 4), strides=(1, 1),
                         padding='VALID', use_bias=False,
                         kernel_init=he_normal())(z)
    z = nn.BatchNorm(**norm_kwargs)(z)
    z = nn.leaky_relu(z, 0.2)

    # Layer 2: 4x4 -> 8x8.
    z = nn.ConvTranspose(features=256, **conv_kwargs)(z)
    z = nn.BatchNorm(**norm_kwargs)(z)
    z = nn.leaky_relu(z, 0.2)

    # Layer 3: 8x8 -> 16x16.
    z = nn.ConvTranspose(features=128, **conv_kwargs)(z)
    z = nn.BatchNorm(**norm_kwargs)(z)
    z = nn.leaky_relu(z, 0.2)

    # Layer 4: 16x16 -> 32x32.
    z = nn.ConvTranspose(features=64, **conv_kwargs)(z)
    z = nn.BatchNorm(**norm_kwargs)(z)
    z = nn.leaky_relu(z, 0.2)

    # Layer 5: 32x32 -> 64x64, single output channel.
    z = nn.ConvTranspose(features=1, kernel_size=(4, 4), strides=(2, 2),
                         padding='SAME', use_bias=False,
                         kernel_init=nn.initializers.xavier_normal())(z)
    # x = nn.sigmoid(z)
    x = nn.softplus(z)
    return jnp.rot90(jnp.squeeze(x), k=2)  # Rotate to match TF output
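# Hedged usage sketch (not from the source): the __call__ above is assumed to
# sit on a Flax nn.Module with a `zdim` attribute. The miniature module below
# mirrors only the first layer, to show the init/apply pattern; passing
# train=False keeps BatchNorm in inference mode, so no mutable 'batch_stats'
# collection is needed at apply time.
import jax
import jax.numpy as jnp
import flax.linen as nn

class TinyGen(nn.Module):  # hypothetical, illustrative only
    zdim: int = 16

    @nn.compact
    def __call__(self, z, train: bool = True):
        z = jnp.reshape(z, (1, 1, self.zdim))
        z = nn.ConvTranspose(features=8, kernel_size=(4, 4), strides=(1, 1),
                             padding='VALID', use_bias=False)(z)
        z = nn.BatchNorm(use_running_average=not train)(z)
        return nn.leaky_relu(z, 0.2)

rng = jax.random.PRNGKey(0)
z = jax.random.normal(rng, (16,))
variables = TinyGen().init(rng, z, train=False)
y = TinyGen().apply(variables, z, train=False)
print(y.shape)  # (4, 4, 8): a 'VALID' stride-1 transpose turns 1x1 into 4x4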
def __call__(self, z):
    shape_before_flattening, flatten_out_size = self.flatten_enc_shape()
    x = nn.Dense(flatten_out_size, name='fc1')(z)
    x = x.reshape((x.shape[0], *shape_before_flattening[1:]))

    # Build decoder: mirror the encoder by walking hidden_dims in reverse.
    hidden_dims = self.hidden_dims[::-1]
    for h_dim in hidden_dims[:-1]:
        x = nn.ConvTranspose(features=h_dim, kernel_size=(3, 3),
                             strides=(2, 2))(x)
        x = nn.GroupNorm()(x)
        x = nn.gelu(x)

    # Final up-sampling layer maps to 3 (RGB) channels squashed into [0, 1].
    x = nn.ConvTranspose(features=3, kernel_size=(3, 3), strides=(2, 2))(x)
    x = nn.sigmoid(x)
    return x
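# Hedged, self-contained sketch of how this decoder might be hosted (all
# concrete values are assumptions): `flatten_enc_shape()` is taken to return
# the encoder's pre-flatten shape and flattened size, and hidden_dims is a
# made-up (32, 64, 128).
import jax
import jax.numpy as jnp
import flax.linen as nn

class Decoder(nn.Module):  # hypothetical host module
    hidden_dims: tuple = (32, 64, 128)

    def flatten_enc_shape(self):
        return (1, 4, 4, 128), 4 * 4 * 128  # assumed encoder output shape

    @nn.compact
    def __call__(self, z):
        shape_before_flattening, flatten_out_size = self.flatten_enc_shape()
        x = nn.Dense(flatten_out_size, name='fc1')(z)
        x = x.reshape((x.shape[0], *shape_before_flattening[1:]))
        for h_dim in self.hidden_dims[::-1][:-1]:
            x = nn.ConvTranspose(features=h_dim, kernel_size=(3, 3),
                                 strides=(2, 2))(x)
            x = nn.GroupNorm()(x)
            x = nn.gelu(x)
        x = nn.ConvTranspose(features=3, kernel_size=(3, 3), strides=(2, 2))(x)
        return nn.sigmoid(x)

z = jnp.ones((2, 16))
y, _ = Decoder().init_with_output(jax.random.PRNGKey(0), z)
print(y.shape)  # (2, 32, 32, 3): 4 -> 8 -> 16 -> 32 via three stride-2 layers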
def __call__(self, z):
    shape_before_flattening, flatten_out_size = self.flatten_enc_shape()
    x = nn.Dense(flatten_out_size, name='fc1')(z)
    x = nn.gelu(x)
    x = x.reshape((x.shape[0], *shape_before_flattening[1:]))

    # Three stride-2 transposed convolutions, each doubling H and W.
    x = nn.ConvTranspose(features=32, kernel_size=(3, 3), strides=(2, 2))(x)
    x = nn.GroupNorm(32)(x)
    x = nn.gelu(x)
    x = nn.ConvTranspose(features=28, kernel_size=(3, 3), strides=(2, 2))(x)
    x = nn.GroupNorm(28)(x)
    x = nn.gelu(x)
    # Final layer maps to a single-channel (grayscale) output, no activation.
    x = nn.ConvTranspose(features=1, kernel_size=(3, 3), strides=(2, 2))(x)
    return x
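# A note on the GroupNorm calls above (general Flax behaviour, not specific to
# this source): the first positional argument is `num_groups`, so GroupNorm(32)
# over a 32-channel map and GroupNorm(28) over a 28-channel map each put every
# channel in its own group, i.e. instance-norm-style per-channel statistics.
# A quick shape check:
import jax
import jax.numpy as jnp
import flax.linen as nn

x = jnp.ones((2, 7, 7, 32))
y, _ = nn.GroupNorm(num_groups=32).init_with_output(jax.random.PRNGKey(0), x)
print(y.shape)  # (2, 7, 7, 32); statistics are computed per example and channel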
def basic_module(self, x):
    # `jp` is assumed to be jax.numpy; float64 dtypes require
    # jax.config.update('jax_enable_x64', True), otherwise JAX downcasts.
    # 9x9 'VALID' convolution, then 2x2 max-pooling.
    x = nn.Conv(features=90, kernel_size=(9, 9), padding='VALID',
                dtype=jp.float64)(x)
    x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
    # A 2x2 stride-2 transposed convolution undoes the pooling's downsampling.
    x = nn.ConvTranspose(features=1, kernel_size=(2, 2), strides=(2, 2),
                         dtype=jp.float64)(x)
    # Flatten and take the product of all activations per example.
    x = x.reshape(x.shape[0], -1)
    x = jp.prod(x, axis=1)
    return x
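# Hedged usage sketch showing the 64-bit configuration the dtypes above rely
# on; the host module name `BasicModule` and the input size are assumptions.
import jax
jax.config.update('jax_enable_x64', True)  # must be set before arrays are made
import jax.numpy as jp
import flax.linen as nn

class BasicModule(nn.Module):  # hypothetical host for basic_module above
    @nn.compact
    def __call__(self, x):
        x = nn.Conv(features=90, kernel_size=(9, 9), padding='VALID',
                    dtype=jp.float64)(x)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.ConvTranspose(features=1, kernel_size=(2, 2), strides=(2, 2),
                             dtype=jp.float64)(x)
        x = x.reshape(x.shape[0], -1)
        return jp.prod(x, axis=1)

x = jp.ones((2, 18, 18, 1))  # (N, H, W, C)
y, _ = BasicModule().init_with_output(jax.random.PRNGKey(0), x)
print(y.shape, y.dtype)  # (2,) float64: 18 -> 10 (conv) -> 5 (pool) -> 10 (transpose)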
def __call__(self, x):
    """Forward function.

    Args:
      x: Input 4D tensor of shape `(N, H, W, in_chans)`.

    Returns:
      jnp.array: Output tensor of shape `(N, H*2, W*2, out_chans)`.
    """
    x = nn.ConvTranspose(
        self.out_chans, kernel_size=(2, 2), strides=(2, 2), use_bias=False)(
            x)
    x = _simple_instance_norm2d(x, (1, 2))
    x = jax.nn.leaky_relu(x, negative_slope=0.2)
    return x
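# The helper `_simple_instance_norm2d` is not shown in this snippet. A minimal
# sketch consistent with its call site (normalize over the spatial axes (1, 2),
# independently per example and channel) might look like this; the epsilon
# value is an assumption.
import jax
import jax.numpy as jnp

def _simple_instance_norm2d(x, axes, eps=1e-6):
    # Zero-mean, unit-variance normalization over `axes` (here H and W).
    mean = jnp.mean(x, axis=axes, keepdims=True)
    var = jnp.var(x, axis=axes, keepdims=True)
    return (x - mean) * jax.lax.rsqrt(var + eps)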
def __call__(self, x, train):
    del train  # unused; kept for interface compatibility
    encoder_keys = [
        'filter_sizes',
        'kernel_sizes',
        'kernel_paddings',
        'window_sizes',
        'window_paddings',
        'strides',
        'activations',
    ]
    if len(set(len(self.encoder[k]) for k in encoder_keys)) > 1:
        raise ValueError(
            'The elements in encoder dict do not have the same length.')

    decoder_keys = [
        'filter_sizes',
        'kernel_sizes',
        'window_sizes',
        'paddings',
        'activations',
    ]
    if len(set(len(self.decoder[k]) for k in decoder_keys)) > 1:
        raise ValueError(
            'The elements in decoder dict do not have the same length.')

    # Encoder: conv -> activation -> max-pool at each stage.
    for i in range(len(self.encoder['filter_sizes'])):
        x = nn.Conv(self.encoder['filter_sizes'][i],
                    self.encoder['kernel_sizes'][i],
                    padding=self.encoder['kernel_paddings'][i])(x)
        x = model_utils.ACTIVATIONS[self.encoder['activations'][i]](x)
        x = nn.max_pool(x,
                        self.encoder['window_sizes'][i],
                        strides=self.encoder['strides'][i],
                        padding=self.encoder['window_paddings'][i])

    # Decoder: transposed conv -> activation at each stage. Note that the
    # decoder's 'window_sizes' are passed as the transposed conv's strides.
    for i in range(len(self.decoder['filter_sizes'])):
        x = nn.ConvTranspose(self.decoder['filter_sizes'][i],
                             self.decoder['kernel_sizes'][i],
                             strides=self.decoder['window_sizes'][i],
                             padding=self.decoder['paddings'][i])(x)
        x = model_utils.ACTIVATIONS[self.decoder['activations'][i]](x)
    return x
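# Hedged sketch of config dicts this module expects, inferred from the keys it
# reads. Every value below is an illustrative assumption, and the activation
# names must be keys of model_utils.ACTIVATIONS ('relu' / 'sigmoid' here are
# guesses).
encoder = {
    'filter_sizes': [16, 32],
    'kernel_sizes': [(3, 3), (3, 3)],
    'kernel_paddings': ['SAME', 'SAME'],
    'window_sizes': [(2, 2), (2, 2)],
    'window_paddings': ['VALID', 'VALID'],
    'strides': [(2, 2), (2, 2)],
    'activations': ['relu', 'relu'],
}
decoder = {
    'filter_sizes': [16, 1],
    'kernel_sizes': [(3, 3), (3, 3)],
    'window_sizes': [(2, 2), (2, 2)],  # consumed as transposed-conv strides
    'paddings': ['SAME', 'SAME'],
    'activations': ['relu', 'sigmoid'],
}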
def test_single_input_conv_transpose(self):
    rng = dict(params=random.PRNGKey(0))
    x = jnp.ones((8, 3))
    conv_transpose_module = nn.ConvTranspose(
        features=4,
        kernel_size=(3,),
        padding='VALID',
        kernel_init=initializers.ones,
        bias_init=initializers.ones,
    )
    y, initial_params = conv_transpose_module.init_with_output(rng, x)
    self.assertEqual(initial_params['params']['kernel'].shape, (3, 3, 4))
    correct_ans = np.array([[4., 4., 4., 4.],
                            [7., 7., 7., 7.],
                            [10., 10., 10., 10.],
                            [10., 10., 10., 10.],
                            [10., 10., 10., 10.],
                            [10., 10., 10., 10.],
                            [10., 10., 10., 10.],
                            [10., 10., 10., 10.],
                            [7., 7., 7., 7.],
                            [4., 4., 4., 4.]])
    np.testing.assert_allclose(y, correct_ans)
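# A worked check of the expected values (derivation, not from the source):
# with all-ones input, kernel, and bias, a 'VALID' 1-D transposed convolution
# of a length-8 input with kernel width 3 produces length 8 + 3 - 1 = 10, and
# each output element is (number of overlapping taps) * (3 input channels) + 1
# bias: edges 1*3+1 = 4, next 2*3+1 = 7, interior 3*3+1 = 10.
import numpy as np

taps = np.convolve(np.ones(8), np.ones(3))  # overlap counts: 1, 2, 3, ..., 3, 2, 1
print(taps * 3 + 1)                          # [ 4.  7. 10. ... 10.  7.  4.]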
def setup(self):
    # If bilinear, upsample with a fixed resize and use the normal
    # convolutions to reduce the number of channels.
    if self.bilinear:
        self.up = lambda x: jax.image.resize(
            x, [x.shape[0], x.shape[1] * 2, x.shape[2] * 2, x.shape[3]],
            method='bilinear')
        self.conv = DoubleConv(self.in_channels, self.out_channels,
                               self.in_channels // 2, self.test,
                               self.group_norm, self.num_groups,
                               self.activation)
    else:
        # Learned 2x upsampling. The keyword is `strides` (the original
        # `stride=` would raise a TypeError in Flax).
        self.up = nn.ConvTranspose(self.in_channels // 2,
                                   kernel_size=(2, 2),
                                   strides=(2, 2))
        self.conv = DoubleConv(self.in_channels, self.out_channels,
                               self.out_channels, self.test,
                               self.group_norm, self.num_groups,
                               self.activation)
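# Hedged standalone comparison of the two upsampling paths above; `DoubleConv`
# and the module's attributes belong to the surrounding U-Net-style code, so
# this sketch mirrors only the `self.up` step, with assumed shapes.
import jax
import jax.numpy as jnp
import flax.linen as nn

x = jnp.ones((1, 8, 8, 16))

# Parameter-free bilinear path: doubles H and W, keeps the channel count.
up_bilinear = jax.image.resize(x, (1, 16, 16, 16), method='bilinear')

# Learned path: stride-2 transposed convolution that halves the channel count.
deconv = nn.ConvTranspose(features=8, kernel_size=(2, 2), strides=(2, 2))
up_learned, _ = deconv.init_with_output(jax.random.PRNGKey(0), x)
print(up_bilinear.shape, up_learned.shape)  # (1, 16, 16, 16) (1, 16, 16, 8)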