def generator(gan):
    """Build the transposed-convolution generator graph.

    Four ``tf.nn.conv2d_transpose`` stages map ``gan.z`` to an image:
    4x4 -> 8x8 -> 16x16 -> ``gan.side`` x ``gan.side``.  The first three
    stages are followed by ``batch_norm`` and ReLU; the last one by tanh.

    Args:
        gan: object providing ``z`` (input tensor), ``num_latent``,
            ``num_hidden``, ``num_channels``, ``side`` and ``channel_size``.
            ``channel_size`` is used as the leading dimension of every
            output shape.  NOTE(review): presumably this is the batch size
            despite its name -- confirm against the caller.

    Returns:
        The tanh-activated output tensor of the last transposed convolution.
    """
    # Floor division keeps the channel counts integral.  With true division
    # (Python 3) ``num_hidden / 2`` is a float, which is not a valid
    # dimension for the filter, bias and output-shape arguments below.
    half_hidden = gan.num_hidden // 2
    quarter_hidden = gan.num_hidden // 4

    with tf.variable_scope('generator'):
        with tf.variable_scope('deconv0'):
            # conv2d_transpose filters are [height, width, out_ch, in_ch].
            f = tf.Variable(
                tf.truncated_normal(
                    [3, 3, gan.num_hidden, gan.num_latent],
                    mean=0.0, stddev=0.02, dtype=tf.float32),
                name='filter')
            b = tf.Variable(
                tf.zeros([gan.num_hidden], dtype=tf.float32), name='b')
            h0 = tf.nn.bias_add(
                tf.nn.conv2d_transpose(
                    gan.z, f,
                    [gan.channel_size, 4, 4, gan.num_hidden],
                    strides=[1, 4, 4, 1]),
                b)
            h0 = batch_norm(h0, gan.num_hidden)
            h0 = tf.nn.relu(h0)

        with tf.variable_scope('deconv1'):
            f = tf.Variable(
                tf.truncated_normal(
                    [5, 5, half_hidden, gan.num_hidden],
                    mean=0.0, stddev=0.02, dtype=tf.float32),
                name='filter')
            b = tf.Variable(
                tf.zeros([half_hidden], dtype=tf.float32), name='b')
            h1 = tf.nn.bias_add(
                tf.nn.conv2d_transpose(
                    h0, f,
                    [gan.channel_size, 8, 8, half_hidden],
                    strides=[1, 2, 2, 1]),
                b)
            h1 = batch_norm(h1, half_hidden)
            h1 = tf.nn.relu(h1)

        with tf.variable_scope('deconv2'):
            f = tf.Variable(
                tf.truncated_normal(
                    [5, 5, quarter_hidden, half_hidden],
                    mean=0.0, stddev=0.02, dtype=tf.float32),
                name='filter')
            b = tf.Variable(
                tf.zeros([quarter_hidden], dtype=tf.float32), name='b')
            h2 = tf.nn.bias_add(
                tf.nn.conv2d_transpose(
                    h1, f,
                    [gan.channel_size, 16, 16, quarter_hidden],
                    strides=[1, 2, 2, 1]),
                b)
            h2 = batch_norm(h2, quarter_hidden)
            h2 = tf.nn.relu(h2)

        with tf.variable_scope('gen_images'):
            f = tf.Variable(
                tf.truncated_normal(
                    [5, 5, gan.num_channels, quarter_hidden],
                    mean=0.0, stddev=0.02, dtype=tf.float32),
                name='filter')
            b = tf.Variable(
                tf.zeros([gan.num_channels], dtype=tf.float32), name='b')
            gen_image = tf.nn.tanh(
                tf.nn.bias_add(
                    tf.nn.conv2d_transpose(
                        h2, f,
                        [gan.channel_size, gan.side, gan.side,
                         gan.num_channels],
                        strides=[1, 2, 2, 1]),
                    b))

    return gen_image
def discriminator(gan, image, reuse=False, name='Discriminator'):
    """Build the discriminator graph for ``image``.

    Two convolution + PReLU stages are followed by two linear layers; the
    last linear layer has a single output unit.

    Args:
        gan: instance of a generative adversarial network.  This function
            reads:
              gan.batch_size: leading dimension used when flattening the
                  convolutional features.
              gan.c_dim: passed as the second argument of the first
                  ``conv2d`` call.
              gan.df_dim: passed as the second argument of the second
                  ``conv2d`` call.
              gan.dfc_dim: passed as the second argument of the first
                  ``linear`` call.
        image: input tensor fed to the first convolution.
        reuse: when true, ``reuse_variables()`` is called on the enclosing
            variable scope and the flag is forwarded to ``d_bn1``.
        name: name of the variable scope wrapping all layers.

    Returns:
        A ``(probability, logits)`` tuple: the sigmoid of the final linear
        layer's output, and that raw output.
    """
    # layers that don't share variable
    d_bn1 = batch_norm(name='d_bn1')
    d_bn2 = batch_norm(name='d_bn2')

    with tf.variable_scope(name):
        if reuse:
            tf.get_variable_scope().reuse_variables()

        h0 = prelu(
            conv2d(image, gan.c_dim, name='d_h0_conv', reuse=False),
            name='d_h0_prelu', reuse=False)
        h1 = prelu(
            d_bn1(conv2d(h0, gan.df_dim, name='d_h1_conv', reuse=False),
                  reuse=reuse),
            name='d_h1_prelu', reuse=False)
        # This is a free function, so there is no ``self``; the batch size
        # lives on ``gan`` (the original ``self.batch_size`` raised
        # NameError).
        h1 = tf.reshape(h1, [gan.batch_size, -1])

        # layers that share variables
        h2 = prelu(
            d_bn2(linear(h1, gan.dfc_dim, 'd_h2_lin', reuse=False),
                  reuse=False),
            name='d_h2_prelu', reuse=False)
        h3 = linear(h2, 1, 'd_h3_lin', reuse=False)

    return tf.nn.sigmoid(h3), h3
def forward(self, x, noise=None, name=''):
    """Encode ``x`` into latent codes inside the "Encoder" variable scope.

    Input tensors have their first two axes merged into one batch axis
    before the layers are applied; the result is reshaped back to
    ``[-1, <size of axis 1>, self.n_output]``.

    Args:
        x: mapping with a ``'flat'`` and/or ``'image'`` tensor, depending on
            which entries of ``self.config['data_properties']`` are
            non-empty.
        noise: tensor required by the ``'Gaussian'``, ``'UnivApprox'``,
            ``'UnivApproxNoSpatial'`` and ``'UnivApproxSine'`` encoder
            modes.  It is indexed as a rank-2 tensor; presumably its first
            axis matches the merged batch axis -- TODO confirm.
        name: unused.

    Returns:
        The latent tensor.  NOTE(review): when both flat and image data are
        configured, only the image encoding is returned (the flat encoding
        is built but discarded) -- confirm this is intended.

    Raises:
        ValueError: for an unsupported image size or encoder mode, or when
            neither flat nor image data is configured.  These cases used to
            surface as an ``UnboundLocalError``.
    """

    def _norm_act(tensor):
        """Apply the configured normalization (if any), then the activation."""
        if self.normalization_mode == 'Layer Norm':
            tensor = helper.conv_layer_norm_layer(tensor)
        elif self.normalization_mode == 'Batch Norm':
            tensor = helper.batch_norm()(tensor)
        return self.activation_function(tensor)

    def _conv(tensor, filters, kernel, stride, activation=None):
        """Square-kernel, valid-padded convolution with bias."""
        return tf.layers.conv2d(
            inputs=tensor, filters=filters,
            kernel_size=[kernel, kernel], strides=[stride, stride],
            padding="valid", use_bias=True, activation=activation)

    z_flat = None
    with tf.variable_scope("Encoder", reuse=self.constructed):
        data_properties = self.config['data_properties']

        if len(data_properties['flat']) > 0:
            x_batched_inp_flat = tf.reshape(
                x['flat'], [-1, *x['flat'].get_shape().as_list()[2:]])
            lay1_flat = tf.layers.dense(
                inputs=x_batched_inp_flat, units=self.config['n_flat'],
                activation=self.activation_function)
            lay2_flat = tf.layers.dense(
                inputs=lay1_flat, units=self.config['n_flat'],
                activation=self.activation_function)
            latent_flat = tf.layers.dense(
                inputs=lay2_flat, units=self.n_output, activation=None)
            z_flat = tf.reshape(
                latent_flat,
                [-1, x['flat'].get_shape().as_list()[1], self.n_output])

        if len(data_properties['image']) > 0:
            image_shape = data_properties['image'][0]['size'][-3:-1]
            x_batched_inp_image = tf.reshape(
                x['image'], [-1, *x['image'].get_shape().as_list()[2:]])

            mode = self.config['encoder_mode']
            if mode in ('Deterministic', 'Gaussian', 'UnivApproxNoSpatial'):
                image_input = x_batched_inp_image
            elif mode in ('UnivApprox', 'UnivApproxSine'):
                # Broadcast the noise vector over every spatial position
                # and append it to the image as extra channels.
                noise_spatial = tf.tile(
                    noise[:, np.newaxis, np.newaxis, :],
                    [1, *x_batched_inp_image.get_shape().as_list()[1:3], 1])
                image_input = tf.concat(
                    [x_batched_inp_image, noise_spatial], axis=-1)
            else:
                raise ValueError(
                    'Unsupported encoder_mode: {!r}'.format(mode))

            # Per image size: (filter multiplier, kernel, stride) of each
            # normalized hidden convolution, then the filter multiplier of
            # the final 3x3 convolution.
            if image_shape == (28, 28):
                hidden_specs = [(1, 5, 2), (1, 5, 1), (2, 4, 1)]
                final_mult = 2
            elif image_shape == (32, 32):
                hidden_specs = [(1, 5, 2), (1, 5, 1), (2, 5, 1)]
                final_mult = 2
            elif image_shape == (64, 64):
                hidden_specs = [(1, 5, 2), (1, 5, 2), (2, 5, 1), (3, 5, 1)]
                final_mult = 3
            else:
                raise ValueError(
                    'Unsupported image size: {!r}'.format(image_shape))

            n_filter = self.config['n_filter']
            features = image_input
            for mult, kernel, stride in hidden_specs:
                features = _norm_act(
                    _conv(features, mult * n_filter, kernel, stride))
            latent_image = _conv(
                features, final_mult * n_filter, 3, 1,
                activation=self.activation_function)
            latent_image_flat = tf.reshape(
                latent_image,
                [-1, np.prod(latent_image.get_shape().as_list()[1:])])

            if mode == 'Deterministic':
                lay1_flat = tf.layers.dense(
                    inputs=latent_image_flat, units=self.config['n_flat'],
                    use_bias=True, activation=self.activation_function)
                latent_flat = tf.layers.dense(
                    inputs=lay1_flat, units=self.n_output,
                    use_bias=True, activation=None)
            elif mode == 'Gaussian':
                lay1_flat = tf.layers.dense(
                    inputs=latent_image_flat, units=self.config['n_flat'],
                    use_bias=True, activation=self.activation_function)
                latent_mu = tf.layers.dense(
                    inputs=lay1_flat, units=self.n_output,
                    use_bias=True, activation=None)
                latent_log_sig = tf.layers.dense(
                    inputs=lay1_flat, units=self.n_output,
                    use_bias=True, activation=None)
                # Reparameterized sample; softplus keeps the scale positive.
                latent_flat = (
                    latent_mu + tf.nn.softplus(latent_log_sig) * noise)
            elif mode in ('UnivApprox', 'UnivApproxNoSpatial'):
                lay1_concat = tf.layers.dense(
                    inputs=tf.concat([latent_image_flat, noise], axis=-1),
                    units=self.config['n_flat'],
                    use_bias=True, activation=self.activation_function)
                latent_flat = tf.layers.dense(
                    inputs=lay1_concat, units=self.n_output,
                    use_bias=True, activation=None)
            else:  # 'UnivApproxSine' (mode was validated above)
                lay1_concat = tf.layers.dense(
                    inputs=tf.concat([latent_image_flat, noise], axis=-1),
                    units=self.config['n_flat'],
                    use_bias=True, activation=self.activation_function)
                # The correction head is created before the output head so
                # the auto-generated layer names stay unchanged.
                latent_correction = tf.layers.dense(
                    inputs=lay1_concat, units=self.n_output,
                    use_bias=True, activation=None)
                latent_output = tf.layers.dense(
                    inputs=lay1_concat, units=self.n_output,
                    use_bias=True, activation=None)
                latent_flat = (
                    latent_output
                    + tf.sin(self.config['enc_sine_freq'] * noise)
                    - latent_correction)

            z_flat = tf.reshape(
                latent_flat,
                [-1, x['image'].get_shape().as_list()[1], self.n_output])

    if z_flat is None:
        raise ValueError(
            "Neither 'flat' nor 'image' data is configured for the encoder")

    # Later calls reuse the variables created by this first call.
    self.constructed = True
    return z_flat
def forward(self, x, name=''):
    """Score ``x`` with the critic inside the "Critic" variable scope.

    Args:
        x: mapping with ``'flat'`` and ``'image'`` entries; an entry that is
            ``None`` is skipped.  Tensors have their first two axes merged
            into one batch axis before the layers are applied.
        name: unused.

    Returns:
        The critic output tensor.  With only image input its shape is
        ``[-1, <size of axis 1 of x['image']>, 1]``; with only flat input it
        is ``[-1, 1, 1]``.  With both, the two scores are concatenated on
        the last axis and mixed by one linear unit.

    Raises:
        ValueError: for an unsupported image size, or when both inputs are
            ``None``.  These cases used to surface as ``UnboundLocalError``
            and ``IndexError`` respectively.
    """

    def _norm_act(tensor):
        """Apply the configured normalization (if any), then the activation."""
        if self.normalization_mode == 'Layer Norm':
            tensor = helper.conv_layer_norm_layer(tensor)
        elif self.normalization_mode == 'Batch Norm':
            tensor = helper.batch_norm()(tensor)
        return self.activation_function(tensor)

    def _conv(tensor, filters, kernel, stride, padding, activation=None):
        """Square-kernel convolution with bias."""
        return tf.layers.conv2d(
            inputs=tensor, filters=filters,
            kernel_size=[kernel, kernel], strides=[stride, stride],
            padding=padding, use_bias=True, activation=activation)

    with tf.variable_scope("Critic", reuse=self.constructed):
        outputs = []

        if x['flat'] is not None:
            x_batched_inp_flat = tf.reshape(
                x['flat'], [-1, *x['flat'].get_shape().as_list()[2:]])
            lay1_flat = tf.layers.dense(
                inputs=x_batched_inp_flat, units=self.config['n_flat'],
                activation=self.activation_function)
            lay2_flat = tf.layers.dense(
                inputs=lay1_flat, units=self.config['n_flat'],
                activation=self.activation_function)
            lay3_flat = tf.layers.dense(
                inputs=lay2_flat, units=self.config['n_flat'],
                activation=self.activation_function)
            lay4_flat = tf.layers.dense(
                inputs=lay3_flat, units=1, activation=None)
            outputs.append(tf.reshape(lay4_flat, [-1, 1, 1]))

        if x['image'] is not None:
            image_shape = (
                self.config['data_properties']['image'][0]['size'][-3:-1])
            x_batched_inp_image = tf.reshape(
                x['image'], [-1, *x['image'].get_shape().as_list()[2:]])

            # (filter multiplier, kernel, stride, padding) of the normalized
            # convolutions that follow the shared first layer.
            if image_shape == (28, 28):
                body_specs = [(2, 5, 2, "same"), (4, 5, 2, "same")]
            elif image_shape == (32, 32):
                body_specs = [(2, 5, 2, "same"), (4, 5, 1, "valid")]
            elif image_shape == (64, 64):
                body_specs = [(2, 5, 2, "same"), (4, 5, 2, "same"),
                              (8, 5, 2, "same")]
            else:
                raise ValueError(
                    'Unsupported image size: {!r}'.format(image_shape))

            n_filter = self.config['n_filter']
            # The first convolution is activated directly, without
            # normalization.
            features = _conv(
                x_batched_inp_image, n_filter, 5, 2, "same",
                activation=self.activation_function)
            for mult, kernel, stride, padding in body_specs:
                features = _norm_act(
                    _conv(features, mult * n_filter, kernel, stride, padding))
            # A valid 4x4 convolution with one filter produces the score.
            critic_image = _conv(features, 1, 4, 1, "valid")
            critic = tf.reshape(
                critic_image, [-1, x['image'].get_shape().as_list()[1], 1])
            outputs.append(critic)

        if not outputs:
            raise ValueError(
                "Critic needs at least one of x['flat'] or x['image']")

        if len(outputs) > 1:
            # A leftover ``pdb.set_trace()`` used to halt execution here.
            # NOTE(review): the flat score is shaped [-1, 1, 1] while the
            # image score is [-1, T, 1], so this concat presumably only
            # works when T == 1 -- confirm before relying on this path.
            merged_input = tf.concat(outputs, axis=-1)
            input_merged = tf.reshape(
                merged_input, [-1, merged_input.get_shape().as_list()[-1]])
            lay1_merged = tf.layers.dense(
                inputs=input_merged, units=1, activation=None)
            enc = tf.reshape(
                lay1_merged, [-1, x['flat'].get_shape().as_list()[1], 1])
        else:
            enc = outputs[0]

    # Later calls reuse the variables created by this first call.
    self.constructed = True
    return enc
def forward(self, x, name=''):
    """Decode latent codes ``x`` inside the "Generator" variable scope.

    Args:
        x: latent tensor.  Its leading axes are merged into one batch axis
            (keeping the last axis) before the layers are applied, and the
            outputs are reshaped back using the size of axis 1.
        name: unused.

    Returns:
        A dict with keys ``'flat'`` and ``'image'``.  Each value is the
        generated parameter tensor for that data type, or ``None`` when
        ``self.config['data_properties']`` has no entries for it.

    Raises:
        ValueError: for an unsupported image size.  This used to surface as
            an ``UnboundLocalError``.
    """

    def _norm_act(tensor):
        """Apply the configured normalization (if any), then the activation."""
        if self.normalization_mode == 'Layer Norm':
            tensor = helper.conv_layer_norm_layer(tensor)
        elif self.normalization_mode == 'Batch Norm':
            tensor = helper.batch_norm()(tensor)
        return self.activation_function(tensor)

    def _upsample(tensor, filters, kernel, stride, crop, activation=None):
        """Transposed convolution, center-cropped to ``crop`` if given."""
        tensor = tf.layers.conv2d_transpose(
            inputs=tensor, filters=filters,
            kernel_size=[kernel, kernel], strides=[stride, stride],
            activation=activation)
        if crop is not None:
            tensor = helper.tf_center_crop_image(
                tensor, resize_ratios=[crop, crop])
        return tensor

    with tf.variable_scope("Generator", reuse=self.constructed):
        out_dict = {'flat': None, 'image': None}
        data_properties = self.config['data_properties']

        if len(data_properties['flat']) > 0:
            n_output_size = helper.list_sum([
                distributions.DistributionsAKA[e['dist']].num_params(
                    e['size'][-1])
                for e in data_properties['flat']
            ])
            x_batched_inp_flat = tf.reshape(
                x, [-1, x.get_shape().as_list()[-1]])
            lay1_flat = tf.layers.dense(
                inputs=x_batched_inp_flat, units=self.config['n_flat'],
                activation=self.activation_function)
            lay2_flat = tf.layers.dense(
                inputs=lay1_flat, units=self.config['n_flat'],
                activation=self.activation_function)
            lay3_flat = tf.layers.dense(
                inputs=lay2_flat, units=self.config['n_flat'],
                activation=self.activation_function)
            flat_param = tf.layers.dense(
                inputs=lay3_flat, units=n_output_size, activation=None)
            out_dict['flat'] = tf.reshape(
                flat_param, [-1, x.get_shape().as_list()[1], n_output_size])

        if len(data_properties['image']) > 0:
            image_shape = data_properties['image'][0]['size'][-3:-1]
            n_image_size = np.prod(image_shape)
            n_output_channels = helper.list_sum([
                distributions.DistributionsAKA[e['dist']].num_params(
                    e['size'][-1])
                for e in data_properties['image']
            ])
            x_batched_inp_flat = tf.reshape(
                x, [-1, x.get_shape().as_list()[-1]])

            if image_shape == (28, 28):
                # 28x28 images are produced by a fully connected decoder
                # rather than by transposed convolutions.
                print('CHECK ME I AM FULLY CONNECTED~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~')
                lay1_image = tf.layers.dense(
                    inputs=x_batched_inp_flat, units=self.config['n_flat'],
                    activation=self.activation_function)
                lay2_image = tf.layers.dense(
                    inputs=lay1_image, units=self.config['n_flat'],
                    activation=self.activation_function)
                lay3_image = tf.layers.dense(
                    inputs=lay2_image,
                    units=n_image_size * n_output_channels,
                    activation=self.activation_function)
                lay4_image = tf.nn.sigmoid(lay3_image)
                image_param = tf.reshape(
                    lay4_image, [-1, *image_shape, n_output_channels])
            else:
                # Hidden stages are (filter multiplier, kernel, stride,
                # crop size); the final stage is (kernel, stride, crop size)
                # and emits ``n_output_channels`` sigmoid channels.
                if image_shape == (32, 32):
                    hidden_specs = [(4, 5, 2, 8), (2, 5, 2, 16),
                                    (1, 5, 2, 32)]
                    final_spec = (3, 1, 32)
                elif image_shape == (64, 64):
                    hidden_specs = [(4, 5, 2, 8), (2, 5, 2, 16),
                                    (1, 5, 2, 32)]
                    final_spec = (5, 2, 64)
                elif image_shape == (128, 128):
                    # The first stage has stride 1 and is not cropped.
                    hidden_specs = [(4, 5, 1, None), (2, 5, 2, 16),
                                    (2, 5, 2, 32), (1, 5, 2, 64)]
                    final_spec = (5, 2, 128)
                else:
                    raise ValueError(
                        'Unsupported image size: {!r}'.format(image_shape))

                n_filter = self.config['n_filter']
                # Project the latent code onto a 4x4 feature map.
                stem = tf.layers.dense(
                    inputs=x_batched_inp_flat,
                    units=8 * n_filter * 4 * 4, activation=None)
                features = _norm_act(
                    tf.reshape(stem, [-1, 4, 4, 8 * n_filter]))
                for mult, kernel, stride, crop in hidden_specs:
                    features = _norm_act(
                        _upsample(features, mult * n_filter,
                                  kernel, stride, crop))
                final_kernel, final_stride, final_crop = final_spec
                image_param = _upsample(
                    features, n_output_channels,
                    final_kernel, final_stride, final_crop,
                    activation=tf.nn.sigmoid)

            out_dict['image'] = tf.reshape(
                image_param,
                [-1, x.get_shape().as_list()[1], *image_shape,
                 n_output_channels])

    # Later calls reuse the variables created by this first call.
    self.constructed = True
    return out_dict