Example #1
def generator(gan):
    with tf.variable_scope('generator'):

        with tf.variable_scope('deconv0'):
            f = tf.Variable(tf.truncated_normal([3, 3, gan.num_hidden, gan.num_latent], mean=0.0,
                                                stddev=0.02, dtype=tf.float32),
                            name='filter')
            b = tf.Variable(tf.zeros([gan.num_hidden], dtype=tf.float32), name='b')
            h0 = tf.nn.bias_add(tf.nn.conv2d_transpose(gan.z, f,
                                                       [gan.channel_size, 4, 4, gan.num_hidden],
                                                       strides=[1, 4, 4, 1]), b)
            h0 = batch_norm(h0, gan.num_hidden)
            h0 = tf.nn.relu(h0)

        with tf.variable_scope('deconv1'):
            f = tf.Variable(tf.truncated_normal([5, 5, gan.num_hidden // 2, gan.num_hidden], mean=0.0,
                                                stddev=0.02, dtype=tf.float32),
                            name='filter')
            b = tf.Variable(tf.zeros([gan.num_hidden // 2], dtype=tf.float32), name='b')
            h1 = tf.nn.bias_add(tf.nn.conv2d_transpose(h0, f,
                                                       [gan.channel_size, 8, 8, gan.num_hidden // 2],
                                                       strides=[1, 2, 2, 1]), b)
            h1 = batch_norm(h1, gan.num_hidden // 2)
            h1 = tf.nn.relu(h1)

        with tf.variable_scope('deconv2'):
            f = tf.Variable(tf.truncated_normal([5, 5, gan.num_hidden // 4, gan.num_hidden // 2], mean=0.0,
                                                stddev=0.02, dtype=tf.float32),
                            name='filter')
            b = tf.Variable(tf.zeros([gan.num_hidden // 4], dtype=tf.float32), name='b')
            h2 = tf.nn.bias_add(tf.nn.conv2d_transpose(h1, f,
                                                       [gan.channel_size, 16, 16, gan.num_hidden // 4],
                                                       strides=[1, 2, 2, 1]), b)
            h2 = batch_norm(h2, gan.num_hidden // 4)
            h2 = tf.nn.relu(h2)

        with tf.variable_scope('gen_images'):
            f = tf.Variable(tf.truncated_normal([5, 5, gan.num_channels, gan.num_hidden // 4], mean=0.0,
                                                stddev=0.02, dtype=tf.float32),
                            name='filter')
            b = tf.Variable(tf.zeros([gan.num_channels], dtype=tf.float32), name='b')
            gen_image = tf.nn.tanh(
                tf.nn.bias_add(tf.nn.conv2d_transpose(h2, f,
                                                      [gan.channel_size, gan.side, gan.side,
                                                       gan.num_channels],
                                                      strides=[1, 2, 2, 1]), b))
    return gen_image
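
The stack of transposed convolutions above upsamples a 1x1xnum_latent code through 4x4, 8x8 and 16x16 feature maps to a side x side image squashed into (-1, 1) by tanh. Below is a minimal, hypothetical wiring sketch: it assumes TensorFlow 1.x, a stand-in for the external batch_norm helper, and illustrative hyperparameters; none of these names or values come from the original snippet.

import types
import tensorflow as tf

def batch_norm(x, _num_features):
    # Stand-in for the external batch_norm helper the generator expects.
    return tf.layers.batch_normalization(x, training=True)

gan = types.SimpleNamespace(
    channel_size=64,                                # batch size
    num_latent=100,                                 # length of the latent code
    num_hidden=256,                                 # feature maps after the first deconv
    num_channels=3,                                 # RGB output
    side=32,                                        # 4 -> 8 -> 16 -> 32 with the strides above
    z=tf.placeholder(tf.float32, [64, 1, 1, 100]),  # latent input, NHWC
)
images = generator(gan)  # shape [64, 32, 32, 3]
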
Example #2
def discriminator(gan, image, reuse=False, name='Discriminator'):
    """
    Args:
        gan : instance of a generative adversarial network 
        reuse : Whether you want to reuse variables from previous 
        share_params : Whether weights are tied in initial layers

        gan.batch_size: The size of batch. Should be specified before training. [64]
        gan.output_size:  The resolution in pixels of the images. [64]
        gan.df_dim:  Dimension of gen filters in first conv layer. [64]
        gan.dfc_dim:  Dimension of gen units for for fully connected layer. [1024]
        gan.c_dim:  Dimension of image color. For grayscale input, set to 1. [3]        
        



    """

    # layers that don't share variable
    d_bn1 = batch_norm(name='d_bn1')
    d_bn2 = batch_norm(name='d_bn2')
    with tf.variable_scope(name):
        if reuse:
            tf.get_variable_scope().reuse_variables()

        h0 = prelu(conv2d(image, gan.c_dim, name='d_h0_conv', reuse=False),
                   name='d_h0_prelu',
                   reuse=False)
        h1 = prelu(d_bn1(conv2d(h0, gan.df_dim, name='d_h1_conv', reuse=False),
                         reuse=reuse),
                   name='d_h1_prelu',
                   reuse=False)
        h1 = tf.reshape(h1, [gan.batch_size, -1])

        # layers that share variables
        h2 = prelu(d_bn2(linear(h1, gan.dfc_dim, 'd_h2_lin', reuse=False),
                         reuse=False),
                   name='d_h2_prelu',
                   reuse=False)
        h3 = linear(h2, 1, 'd_h3_lin', reuse=False)

        return tf.nn.sigmoid(h3), h3
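
Because the same discriminator weights have to score both real and generated images in one graph, the function is typically built twice with variable sharing, which is what the reuse flag and tf.get_variable_scope().reuse_variables() are for. A hypothetical call pattern (real_images, fake_images and gan stand for your own tensors and config; they are not defined in the original snippet):

d_real, d_real_logits = discriminator(gan, real_images, reuse=False)  # creates the 'Discriminator' variables
d_fake, d_fake_logits = discriminator(gan, fake_images, reuse=True)   # reuses the same variables
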
Example #3
    def forward(self, x, noise=None, name=''):
        with tf.variable_scope("Encoder", reuse=self.constructed):
            if len(self.config['data_properties']['flat']) > 0:
                n_output_size = helper.list_sum([
                    distributions.DistributionsAKA[e['dist']].num_params(
                        e['size'][-1])
                    for e in self.config['data_properties']['flat']
                ])
                x_batched_inp_flat = tf.reshape(
                    x['flat'], [-1, *x['flat'].get_shape().as_list()[2:]])

                lay1_flat = tf.layers.dense(
                    inputs=x_batched_inp_flat,
                    units=self.config['n_flat'],
                    activation=self.activation_function)
                lay2_flat = tf.layers.dense(
                    inputs=lay1_flat,
                    units=self.config['n_flat'],
                    activation=self.activation_function)
                latent_flat = tf.layers.dense(inputs=lay2_flat,
                                              units=self.n_output,
                                              activation=None)
                z_flat = tf.reshape(
                    latent_flat,
                    [-1, x['flat'].get_shape().as_list()[1], self.n_output])
                z = z_flat

            if len(self.config['data_properties']['image']) > 0:
                image_shape = (
                    self.config['data_properties']['image'][0]['size'][-3:-1])
                n_image_size = np.prod(image_shape)
                n_output_channels = helper.list_sum([
                    distributions.DistributionsAKA[e['dist']].num_params(
                        e['size'][-1])
                    for e in self.config['data_properties']['image']
                ])
                x_batched_inp_image = tf.reshape(
                    x['image'], [-1, *x['image'].get_shape().as_list()[2:]])

                if self.config['encoder_mode'] in ('Deterministic', 'Gaussian', 'UnivApproxNoSpatial'):
                    image_input = x_batched_inp_image
                if self.config['encoder_mode'] in ('UnivApprox', 'UnivApproxSine'):
                    noise_spatial = tf.tile(
                        noise[:, np.newaxis, np.newaxis, :], [
                            1, *x_batched_inp_image.get_shape().as_list()[1:3],
                            1
                        ])
                    x_and_noise_image = tf.concat(
                        [x_batched_inp_image, noise_spatial], axis=-1)
                    image_input = x_and_noise_image

                # # 28x28xn_channels
                if image_shape == (28, 28):

                    lay1_image = tf.layers.conv2d(
                        inputs=image_input,
                        filters=self.config['n_filter'],
                        kernel_size=[5, 5],
                        strides=[2, 2],
                        padding="valid",
                        use_bias=True,
                        activation=None)
                    if self.normalization_mode == 'Layer Norm':
                        lay2_image = self.activation_function(
                            helper.conv_layer_norm_layer(lay1_image))
                    elif self.normalization_mode == 'Batch Norm':
                        lay2_image = self.activation_function(
                            helper.batch_norm()(lay1_image))
                    else:
                        lay2_image = self.activation_function(lay1_image)

                    lay3_image = tf.layers.conv2d(inputs=lay2_image,
                                                  filters=1 *
                                                  self.config['n_filter'],
                                                  kernel_size=[5, 5],
                                                  strides=[1, 1],
                                                  padding="valid",
                                                  use_bias=True,
                                                  activation=None)
                    if self.normalization_mode == 'Layer Norm':
                        lay4_image = self.activation_function(
                            helper.conv_layer_norm_layer(lay3_image))
                    elif self.normalization_mode == 'Batch Norm':
                        lay4_image = self.activation_function(
                            helper.batch_norm()(lay3_image))
                    else:
                        lay4_image = self.activation_function(lay3_image)

                    lay5_image = tf.layers.conv2d(inputs=lay4_image,
                                                  filters=2 *
                                                  self.config['n_filter'],
                                                  kernel_size=[4, 4],
                                                  strides=[1, 1],
                                                  padding="valid",
                                                  use_bias=True,
                                                  activation=None)
                    if self.normalization_mode == 'Layer Norm':
                        lay6_image = self.activation_function(
                            helper.conv_layer_norm_layer(lay5_image))
                    elif self.normalization_mode == 'Batch Norm':
                        lay6_image = self.activation_function(
                            helper.batch_norm()(lay5_image))
                    else:
                        lay6_image = self.activation_function(lay5_image)

                    latent_image = tf.layers.conv2d(
                        inputs=lay6_image,
                        filters=2 * self.config['n_filter'],
                        kernel_size=[3, 3],
                        strides=[1, 1],
                        padding="valid",
                        use_bias=True,
                        activation=self.activation_function)
                    latent_image_flat = tf.reshape(
                        latent_image,
                        [-1,
                         np.prod(latent_image.get_shape().as_list()[1:])])

                # # 32x32xn_channels
                if image_shape == (32, 32):

                    lay1_image = tf.layers.conv2d(
                        inputs=image_input,
                        filters=self.config['n_filter'],
                        kernel_size=[5, 5],
                        strides=[2, 2],
                        padding="valid",
                        use_bias=True,
                        activation=None)
                    if self.normalization_mode == 'Layer Norm':
                        lay2_image = self.activation_function(
                            helper.conv_layer_norm_layer(lay1_image))
                    elif self.normalization_mode == 'Batch Norm':
                        lay2_image = self.activation_function(
                            helper.batch_norm()(lay1_image))
                    else:
                        lay2_image = self.activation_function(lay1_image)

                    lay3_image = tf.layers.conv2d(inputs=lay2_image,
                                                  filters=1 *
                                                  self.config['n_filter'],
                                                  kernel_size=[5, 5],
                                                  strides=[1, 1],
                                                  padding="valid",
                                                  use_bias=True,
                                                  activation=None)
                    if self.normalization_mode == 'Layer Norm':
                        lay4_image = self.activation_function(
                            helper.conv_layer_norm_layer(lay3_image))
                    elif self.normalization_mode == 'Batch Norm':
                        lay4_image = self.activation_function(
                            helper.batch_norm()(lay3_image))
                    else:
                        lay4_image = self.activation_function(lay3_image)

                    lay5_image = tf.layers.conv2d(inputs=lay4_image,
                                                  filters=2 *
                                                  self.config['n_filter'],
                                                  kernel_size=[5, 5],
                                                  strides=[1, 1],
                                                  padding="valid",
                                                  use_bias=True,
                                                  activation=None)
                    if self.normalization_mode == 'Layer Norm':
                        lay6_image = self.activation_function(
                            helper.conv_layer_norm_layer(lay5_image))
                    elif self.normalization_mode == 'Batch Norm':
                        lay6_image = self.activation_function(
                            helper.batch_norm()(lay5_image))
                    else:
                        lay6_image = self.activation_function(lay5_image)

                    latent_image = tf.layers.conv2d(
                        inputs=lay6_image,
                        filters=2 * self.config['n_filter'],
                        kernel_size=[3, 3],
                        strides=[1, 1],
                        padding="valid",
                        use_bias=True,
                        activation=self.activation_function)
                    latent_image_flat = tf.reshape(
                        latent_image,
                        [-1,
                         np.prod(latent_image.get_shape().as_list()[1:])])

                # 64x64xn_channels
                if image_shape == (64, 64):

                    lay1_image = tf.layers.conv2d(
                        inputs=image_input,
                        filters=self.config['n_filter'],
                        kernel_size=[5, 5],
                        strides=[2, 2],
                        padding="valid",
                        use_bias=True,
                        activation=None)
                    if self.normalization_mode == 'Layer Norm':
                        lay2_image = self.activation_function(
                            helper.conv_layer_norm_layer(lay1_image))
                    elif self.normalization_mode == 'Batch Norm':
                        lay2_image = self.activation_function(
                            helper.batch_norm()(lay1_image))
                    else:
                        lay2_image = self.activation_function(lay1_image)

                    lay3_image = tf.layers.conv2d(
                        inputs=lay2_image,
                        filters=self.config['n_filter'],
                        kernel_size=[5, 5],
                        strides=[2, 2],
                        padding="valid",
                        use_bias=True,
                        activation=None)
                    if self.normalization_mode == 'Layer Norm':
                        lay4_image = self.activation_function(
                            helper.conv_layer_norm_layer(lay3_image))
                    elif self.normalization_mode == 'Batch Norm':
                        lay4_image = self.activation_function(
                            helper.batch_norm()(lay3_image))
                    else:
                        lay4_image = self.activation_function(lay3_image)

                    lay5_image = tf.layers.conv2d(inputs=lay4_image,
                                                  filters=2 *
                                                  self.config['n_filter'],
                                                  kernel_size=[5, 5],
                                                  strides=[1, 1],
                                                  padding="valid",
                                                  use_bias=True,
                                                  activation=None)
                    if self.normalization_mode == 'Layer Norm':
                        lay6_image = self.activation_function(
                            helper.conv_layer_norm_layer(lay5_image))
                    elif self.normalization_mode == 'Batch Norm':
                        lay6_image = self.activation_function(
                            helper.batch_norm()(lay5_image))
                    else:
                        lay6_image = self.activation_function(lay5_image)

                    lay7_image = tf.layers.conv2d(inputs=lay6_image,
                                                  filters=3 *
                                                  self.config['n_filter'],
                                                  kernel_size=[5, 5],
                                                  strides=[1, 1],
                                                  padding="valid",
                                                  use_bias=True,
                                                  activation=None)
                    if self.normalization_mode == 'Layer Norm':
                        lay8_image = self.activation_function(
                            helper.conv_layer_norm_layer(lay7_image))
                    elif self.normalization_mode == 'Batch Norm':
                        lay8_image = self.activation_function(
                            helper.batch_norm()(lay7_image))
                    else:
                        lay8_image = self.activation_function(lay7_image)

                    latent_image = tf.layers.conv2d(
                        inputs=lay8_image,
                        filters=3 * self.config['n_filter'],
                        kernel_size=[3, 3],
                        strides=[1, 1],
                        padding="valid",
                        use_bias=True,
                        activation=self.activation_function)
                    latent_image_flat = tf.reshape(
                        latent_image,
                        [-1,
                         np.prod(latent_image.get_shape().as_list()[1:])])

                if self.config['encoder_mode'] == 'Deterministic':
                    lay1_flat = tf.layers.dense(
                        inputs=latent_image_flat,
                        units=self.config['n_flat'],
                        use_bias=True,
                        activation=self.activation_function)
                    latent_flat = tf.layers.dense(inputs=lay1_flat,
                                                  units=self.n_output,
                                                  use_bias=True,
                                                  activation=None)
                if self.config['encoder_mode'] == 'Gaussian':
                    lay1_flat = tf.layers.dense(
                        inputs=latent_image_flat,
                        units=self.config['n_flat'],
                        use_bias=True,
                        activation=self.activation_function)
                    latent_mu = tf.layers.dense(inputs=lay1_flat,
                                                units=self.n_output,
                                                use_bias=True,
                                                activation=None)
                    latent_log_sig = tf.layers.dense(inputs=lay1_flat,
                                                     units=self.n_output,
                                                     use_bias=True,
                                                     activation=None)
                    latent_flat = latent_mu + tf.nn.softplus(
                        latent_log_sig) * noise
                if self.config['encoder_mode'] in ('UnivApprox', 'UnivApproxNoSpatial'):
                    lay1_concat = tf.layers.dense(
                        inputs=tf.concat([latent_image_flat, noise], axis=-1),
                        units=self.config['n_flat'],
                        use_bias=True,
                        activation=self.activation_function)
                    latent_flat = tf.layers.dense(inputs=lay1_concat,
                                                  units=self.n_output,
                                                  use_bias=True,
                                                  activation=None)
                if self.config['encoder_mode'] == 'UnivApproxSine':
                    lay1_concat = tf.layers.dense(
                        inputs=tf.concat([latent_image_flat, noise], axis=-1),
                        units=self.config['n_flat'],
                        use_bias=True,
                        activation=self.activation_function)
                    latent_correction = tf.layers.dense(inputs=lay1_concat,
                                                        units=self.n_output,
                                                        use_bias=True,
                                                        activation=None)
                    latent_output = tf.layers.dense(inputs=lay1_concat,
                                                    units=self.n_output,
                                                    use_bias=True,
                                                    activation=None)
                    latent_flat = latent_output + tf.sin(
                        self.config['enc_sine_freq'] *
                        noise) - latent_correction

                z_flat = tf.reshape(
                    latent_flat,
                    [-1, x['image'].get_shape().as_list()[1], self.n_output])
            self.constructed = True
            return z_flat
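
This encoder and the two modules that follow (Examples #4 and #5) share the same construction idiom: the first forward pass opens the variable scope with reuse=False, builds the weights, and then sets self.constructed to True, so every later call re-enters the scope with reuse=True and shares them. A stripped-down sketch of that pattern, with illustrative class and layer names that are not part of the original code:

import tensorflow as tf

class TinyEncoder(object):
    def __init__(self, n_output=64):
        self.n_output = n_output
        self.constructed = False  # flips to True once the variables exist

    def forward(self, x):
        # First call: reuse=False creates the variables; later calls: reuse=True shares them.
        with tf.variable_scope("Encoder", reuse=self.constructed):
            z = tf.layers.dense(x, self.n_output, name="latent")
        self.constructed = True
        return z

enc = TinyEncoder()
z_a = enc.forward(tf.placeholder(tf.float32, [None, 784]))  # creates Encoder/latent/*
z_b = enc.forward(tf.placeholder(tf.float32, [None, 784]))  # reuses the same weights
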
Example #4
	def forward(self, x, name = ''):
		with tf.variable_scope("Critic", reuse=self.constructed):
			outputs = []
			if x['flat'] is not None:
				x_batched_inp_flat = tf.reshape(x['flat'], [-1,  *x['flat'].get_shape().as_list()[2:]])
				lay1_flat = tf.layers.dense(inputs = x_batched_inp_flat, units = self.config['n_flat'], activation = self.activation_function)
				lay2_flat = tf.layers.dense(inputs = lay1_flat, units = self.config['n_flat'], activation = self.activation_function)
				lay3_flat = tf.layers.dense(inputs = lay2_flat, units = self.config['n_flat'], activation = self.activation_function)
				lay4_flat = tf.layers.dense(inputs = lay3_flat, units = 1, activation = None)
				outputs.append(tf.reshape(lay4_flat, [-1, 1, 1]))

			if x['image'] is not None: 
				image_shape = (self.config['data_properties']['image'][0]['size'][-3:-1])
				x_batched_inp_image = tf.reshape(x['image'], [-1,  *x['image'].get_shape().as_list()[2:]])

				# # 28x28xn_channels
				if image_shape == (28, 28):

					lay1_image = tf.layers.conv2d(inputs=x_batched_inp_image, filters=self.config['n_filter'], kernel_size=[5, 5], strides=[2, 2], padding="same", use_bias=True, activation=self.activation_function)
					lay2_image = tf.layers.conv2d(inputs=lay1_image, filters=2*self.config['n_filter'], kernel_size=[5, 5], strides=[2, 2], padding="same", use_bias=True, activation=None)
					if self.normalization_mode == 'Layer Norm': 
						lay3_image = self.activation_function(helper.conv_layer_norm_layer(lay2_image))
					elif self.normalization_mode == 'Batch Norm': 
						lay3_image = self.activation_function(helper.batch_norm()(lay2_image))
					else: lay3_image = self.activation_function(lay2_image)

					lay4_image = tf.layers.conv2d(inputs=lay3_image, filters=4*self.config['n_filter'], kernel_size=[5, 5], strides=[2, 2], padding="same", use_bias=True, activation=None)
					if self.normalization_mode == 'Layer Norm': 
						lay5_image = self.activation_function(helper.conv_layer_norm_layer(lay4_image))
					elif self.normalization_mode == 'Batch Norm': 
						lay5_image = self.activation_function(helper.batch_norm()(lay4_image))
					else: lay5_image = self.activation_function(lay4_image)

					critic_image = tf.layers.conv2d(inputs=lay5_image, filters=1, kernel_size=[4, 4], strides=[1, 1], padding="valid", use_bias=True, activation=None)

				# # 32x32xn_channels
				if image_shape == (32, 32):

					lay1_image = tf.layers.conv2d(inputs=x_batched_inp_image, filters=self.config['n_filter'], kernel_size=[5, 5], strides=[2, 2], padding="same", use_bias=True, activation=self.activation_function)
					lay2_image = tf.layers.conv2d(inputs=lay1_image, filters=2*self.config['n_filter'], kernel_size=[5, 5], strides=[2, 2], padding="same", use_bias=True, activation=None)
					if self.normalization_mode == 'Layer Norm': 
						lay3_image = self.activation_function(helper.conv_layer_norm_layer(lay2_image))
					elif self.normalization_mode == 'Batch Norm': 
						lay3_image = self.activation_function(helper.batch_norm()(lay2_image))
					else: lay3_image = self.activation_function(lay2_image)
					
					lay4_image = tf.layers.conv2d(inputs=lay3_image, filters=4*self.config['n_filter'], kernel_size=[5, 5], strides=[1, 1], padding="valid", use_bias=True, activation=None)
					if self.normalization_mode == 'Layer Norm': 
						lay5_image = self.activation_function(helper.conv_layer_norm_layer(lay4_image))
					elif self.normalization_mode == 'Batch Norm': 
						lay5_image = self.activation_function(helper.batch_norm()(lay4_image))
					else: lay5_image = self.activation_function(lay4_image)
					
					critic_image = tf.layers.conv2d(inputs=lay5_image, filters=1, kernel_size=[4, 4], strides=[1, 1], padding="valid", use_bias=True, activation=None)
					
				# 64x64xn_channels
				if image_shape == (64, 64):

					lay1_image = tf.layers.conv2d(inputs=x_batched_inp_image, filters=self.config['n_filter'], kernel_size=[5, 5], strides=[2, 2], padding="same", use_bias=True, activation=self.activation_function)
					lay2_image = tf.layers.conv2d(inputs=lay1_image, filters=2*self.config['n_filter'], kernel_size=[5, 5], strides=[2, 2], padding="same", use_bias=True, activation=None)
					if self.normalization_mode == 'Layer Norm': 
						lay3_image = self.activation_function(helper.conv_layer_norm_layer(lay2_image))
					elif self.normalization_mode == 'Batch Norm': 
						lay3_image = self.activation_function(helper.batch_norm()(lay2_image))
					else: lay3_image = self.activation_function(lay2_image)

					lay4_image = tf.layers.conv2d(inputs=lay3_image, filters=4*self.config['n_filter'], kernel_size=[5, 5], strides=[2, 2], padding="same", use_bias=True, activation=None)
					if self.normalization_mode == 'Layer Norm': 
						lay5_image = self.activation_function(helper.conv_layer_norm_layer(lay4_image))
					elif self.normalization_mode == 'Batch Norm': 
						lay5_image = self.activation_function(helper.batch_norm()(lay4_image))
					else: lay5_image = self.activation_function(lay4_image)

					lay6_image = tf.layers.conv2d(inputs=lay5_image, filters=8*self.config['n_filter'], kernel_size=[5, 5], strides=[2, 2], padding="same", use_bias=True, activation=None)
					if self.normalization_mode == 'Layer Norm': 
						lay7_image = self.activation_function(helper.conv_layer_norm_layer(lay6_image))
					elif self.normalization_mode == 'Batch Norm': 
						lay7_image = self.activation_function(helper.batch_norm()(lay6_image))
					else: lay7_image = self.activation_function(lay6_image)

					critic_image = tf.layers.conv2d(inputs=lay7_image, filters=1, kernel_size=[4, 4], strides=[1, 1], padding="valid", use_bias=True, activation=None)

				critic = tf.reshape(critic_image, [-1, x['image'].get_shape().as_list()[1], 1])
				outputs.append(critic)

			if len(outputs) > 1: 
				pdb.set_trace()
				merged_input = tf.concat(outputs, axis=-1)
				x_batched_inp_image = tf.reshape(x['image'], [-1,  *x['image'].get_shape().as_list()[2:]])
				input_merged = tf.reshape(merged_input, [-1, merged_input.get_shape().as_list()[-1]])
				lay1_merged = tf.layers.dense(inputs = input_merged, units = 1, activation = None)
				enc = tf.reshape(lay1_merged, [-1, x['flat'].get_shape().as_list()[1], 1])
			else: enc = outputs[0]
			self.constructed = True
			return enc
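
Unlike the discriminator in Example #2, this critic ends in a plain linear convolution and returns an unbounded score, reshaped to one scalar per inner step rather than a probability. A hypothetical usage sketch, assuming a constructed critic instance and real/generated image batches shaped like the inputs above; the Wasserstein-style objective is purely illustrative and not part of the original snippet:

real_scores = critic.forward({'flat': None, 'image': real_images})  # shape [batch, steps, 1]
fake_scores = critic.forward({'flat': None, 'image': fake_images})  # second call reuses the variables
critic_loss = tf.reduce_mean(fake_scores) - tf.reduce_mean(real_scores)
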
Example #5
	def forward(self, x, name = ''):
		with tf.variable_scope("Generator", reuse=self.constructed):
			out_dict = {'flat': None, 'image': None}
			if len(self.config['data_properties']['flat']) > 0:
				n_output_size = helper.list_sum([distributions.DistributionsAKA[e['dist']].num_params(e['size'][-1]) for e in self.config['data_properties']['flat']])
				x_batched_inp_flat = tf.reshape(x, [-1,  x.get_shape().as_list()[-1]])
				
				lay1_flat = tf.layers.dense(inputs = x_batched_inp_flat, units = self.config['n_flat'], activation = self.activation_function)
				lay2_flat = tf.layers.dense(inputs = lay1_flat, units = self.config['n_flat'], activation = self.activation_function)
				lay3_flat = tf.layers.dense(inputs = lay2_flat, units = self.config['n_flat'], activation = self.activation_function)
				flat_param = tf.layers.dense(inputs = lay3_flat, units = n_output_size, activation = None)
				out_dict['flat'] = tf.reshape(flat_param, [-1, x.get_shape().as_list()[1], n_output_size])

			if len(self.config['data_properties']['image']) > 0:
				image_shape = (self.config['data_properties']['image'][0]['size'][-3:-1])
				n_image_size = np.prod(image_shape)
				n_output_channels = helper.list_sum([distributions.DistributionsAKA[e['dist']].num_params(e['size'][-1]) for e in self.config['data_properties']['image']])
				x_batched_inp_flat = tf.reshape(x, [-1,  x.get_shape().as_list()[-1]])
				
				# # 28x28xn_channels
				if image_shape == (28, 28):
					print('CHECK ME I AM FULLY CONNECTED~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~')				
					lay1_image = tf.layers.dense(inputs = x_batched_inp_flat, units = self.config['n_flat'], activation = self.activation_function)
					lay2_image = tf.layers.dense(inputs = lay1_image, units = self.config['n_flat'], activation = self.activation_function)
					lay3_image = tf.layers.dense(inputs = lay2_image, units = n_image_size*n_output_channels, activation = self.activation_function)
					lay4_image = tf.nn.sigmoid(lay3_image)
					image_param = tf.reshape(lay4_image, [-1, *image_shape, n_output_channels])

					# lay1_image = tf.layers.dense(inputs = x_batched_inp_flat, units = 8*self.config['n_filter']*4*4, activation = None)
					# lay2_image = tf.reshape(lay1_image, [-1, 4, 4, 8*self.config['n_filter']])
					# if self.normalization_mode == 'Layer Norm': 
					# 	lay3_image = self.activation_function(helper.conv_layer_norm_layer(lay2_image))
					# elif self.normalization_mode == 'Batch Norm': 
					# 	lay3_image = self.activation_function(helper.batch_norm()(lay2_image))
					# else: lay3_image = self.activation_function(lay2_image) #h0

					# lay4_image = tf.layers.conv2d_transpose(inputs=lay3_image, filters=4*self.config['n_filter'], kernel_size=[5, 5], strides=[2, 2], activation=None)
					# lay5_image = helper.tf_center_crop_image(lay4_image, resize_ratios=[8,8])
					# if self.normalization_mode == 'Layer Norm': 
					# 	lay6_image = self.activation_function(helper.conv_layer_norm_layer(lay5_image))
					# elif self.normalization_mode == 'Batch Norm': 
					# 	lay6_image = self.activation_function(helper.batch_norm()(lay5_image))
					# else: lay6_image = self.activation_function(lay5_image) #h1

					# lay7_image = tf.layers.conv2d_transpose(inputs=lay6_image, filters=2*self.config['n_filter'], kernel_size=[5, 5], strides=[2, 2], activation=None)
					# lay8_image = helper.tf_center_crop_image(lay7_image, resize_ratios=[16,16])
					# if self.normalization_mode == 'Layer Norm': 
					# 	lay9_image = self.activation_function(helper.conv_layer_norm_layer(lay8_image))
					# elif self.normalization_mode == 'Batch Norm': 
					# 	lay9_image = self.activation_function(helper.batch_norm()(lay8_image))
					# else: lay9_image = self.activation_function(lay8_image) #h2

					# lay10_image = tf.layers.conv2d_transpose(inputs=lay9_image, filters=1*self.config['n_filter'], kernel_size=[5, 5], strides=[2, 2], activation=None)
					# lay11_image = helper.tf_center_crop_image(lay10_image, resize_ratios=[28,28])
					# if self.normalization_mode == 'Layer Norm': 
					# 	lay12_image = self.activation_function(helper.conv_layer_norm_layer(lay11_image))
					# elif self.normalization_mode == 'Batch Norm': 
					# 	lay12_image = self.activation_function(helper.batch_norm()(lay11_image))
					# else: lay12_image = self.activation_function(lay11_image) #h3

					# lay13_image = tf.layers.conv2d_transpose(inputs=lay12_image, filters=n_output_channels, kernel_size=[3, 3], strides=[1, 1], activation=tf.nn.sigmoid)
					# image_param = helper.tf_center_crop_image(lay13_image, resize_ratios=[28,28])

				# # 32x32xn_channels
				if image_shape == (32, 32):

					lay1_image = tf.layers.dense(inputs = x_batched_inp_flat, units = 8*self.config['n_filter']*4*4, activation = None)
					lay2_image = tf.reshape(lay1_image, [-1, 4, 4, 8*self.config['n_filter']])
					if self.normalization_mode == 'Layer Norm': 
						lay3_image = self.activation_function(helper.conv_layer_norm_layer(lay2_image))
					elif self.normalization_mode == 'Batch Norm': 
						lay3_image = self.activation_function(helper.batch_norm()(lay2_image))
					else: lay3_image = self.activation_function(lay2_image) #h0

					lay4_image = tf.layers.conv2d_transpose(inputs=lay3_image, filters=4*self.config['n_filter'], kernel_size=[5, 5], strides=[2, 2], activation=None)
					lay5_image = helper.tf_center_crop_image(lay4_image, resize_ratios=[8,8])
					if self.normalization_mode == 'Layer Norm': 
						lay6_image = self.activation_function(helper.conv_layer_norm_layer(lay5_image))
					elif self.normalization_mode == 'Batch Norm': 
						lay6_image = self.activation_function(helper.batch_norm()(lay5_image))
					else: lay6_image = self.activation_function(lay5_image) #h1

					lay7_image = tf.layers.conv2d_transpose(inputs=lay6_image, filters=2*self.config['n_filter'], kernel_size=[5, 5], strides=[2, 2], activation=None)
					lay8_image = helper.tf_center_crop_image(lay7_image, resize_ratios=[16,16])
					if self.normalization_mode == 'Layer Norm': 
						lay9_image = self.activation_function(helper.conv_layer_norm_layer(lay8_image))
					elif self.normalization_mode == 'Batch Norm': 
						lay9_image = self.activation_function(helper.batch_norm()(lay8_image))
					else: lay9_image = self.activation_function(lay8_image) #h2

					lay10_image = tf.layers.conv2d_transpose(inputs=lay9_image, filters=1*self.config['n_filter'], kernel_size=[5, 5], strides=[2, 2], activation=None)
					lay11_image = helper.tf_center_crop_image(lay10_image, resize_ratios=[32,32])
					if self.normalization_mode == 'Layer Norm': 
						lay12_image = self.activation_function(helper.conv_layer_norm_layer(lay11_image))
					elif self.normalization_mode == 'Batch Norm': 
						lay12_image = self.activation_function(helper.batch_norm()(lay11_image))
					else: lay12_image = self.activation_function(lay11_image) #h3

					lay13_image = tf.layers.conv2d_transpose(inputs=lay12_image, filters=n_output_channels, kernel_size=[3, 3], strides=[1, 1], activation=tf.nn.sigmoid)
					image_param = helper.tf_center_crop_image(lay13_image, resize_ratios=[32,32])

				# 64x64xn_channels
				if image_shape == (64, 64):

					lay1_image = tf.layers.dense(inputs = x_batched_inp_flat, units = 8*self.config['n_filter']*4*4, activation = None)
					lay2_image = tf.reshape(lay1_image, [-1, 4, 4, 8*self.config['n_filter']])
					if self.normalization_mode == 'Layer Norm': 
						lay3_image = self.activation_function(helper.conv_layer_norm_layer(lay2_image))
					elif self.normalization_mode == 'Batch Norm': 
						lay3_image = self.activation_function(helper.batch_norm()(lay2_image))
					else: lay3_image = self.activation_function(lay2_image) #h0

					lay4_image = tf.layers.conv2d_transpose(inputs=lay3_image, filters=4*self.config['n_filter'], kernel_size=[5, 5], strides=[2, 2], activation=None)
					lay5_image = helper.tf_center_crop_image(lay4_image, resize_ratios=[8,8])
					if self.normalization_mode == 'Layer Norm': 
						lay6_image = self.activation_function(helper.conv_layer_norm_layer(lay5_image))
					elif self.normalization_mode == 'Batch Norm': 
						lay6_image = self.activation_function(helper.batch_norm()(lay5_image))
					else: lay6_image = self.activation_function(lay5_image) #h1

					lay7_image = tf.layers.conv2d_transpose(inputs=lay6_image, filters=2*self.config['n_filter'], kernel_size=[5, 5], strides=[2, 2], activation=None)
					lay8_image = helper.tf_center_crop_image(lay7_image, resize_ratios=[16,16])
					if self.normalization_mode == 'Layer Norm': 
						lay9_image = self.activation_function(helper.conv_layer_norm_layer(lay8_image))
					elif self.normalization_mode == 'Batch Norm': 
						lay9_image = self.activation_function(helper.batch_norm()(lay8_image))
					else: lay9_image = self.activation_function(lay8_image) #h2
					
					lay10_image = tf.layers.conv2d_transpose(inputs=lay9_image, filters=1*self.config['n_filter'], kernel_size=[5, 5], strides=[2, 2], activation=None)
					lay11_image = helper.tf_center_crop_image(lay10_image, resize_ratios=[32,32])
					if self.normalization_mode == 'Layer Norm': 
						lay12_image = self.activation_function(helper.conv_layer_norm_layer(lay11_image))
					elif self.normalization_mode == 'Batch Norm': 
						lay12_image = self.activation_function(helper.batch_norm()(lay11_image))
					else: lay12_image = self.activation_function(lay11_image) #h3

					lay13_image = tf.layers.conv2d_transpose(inputs=lay12_image, filters=n_output_channels, kernel_size=[5, 5], strides=[2, 2], activation=tf.nn.sigmoid)
					image_param = helper.tf_center_crop_image(lay13_image, resize_ratios=[64,64])

				# 128x128xn_channels
				if image_shape == (128, 128):

					lay1_image = tf.layers.dense(inputs = x_batched_inp_flat, units = 8*self.config['n_filter']*4*4, activation = None)
					lay2_image = tf.reshape(lay1_image, [-1, 4, 4, 8*self.config['n_filter']])
					if self.normalization_mode == 'Layer Norm': 
						lay3_image = self.activation_function(helper.conv_layer_norm_layer(lay2_image))
					elif self.normalization_mode == 'Batch Norm': 
						lay3_image = self.activation_function(helper.batch_norm()(lay2_image))
					else: lay3_image = self.activation_function(lay2_image) #h0

					lay4_image = tf.layers.conv2d_transpose(inputs=lay3_image, filters=4*self.config['n_filter'], kernel_size=[5, 5], strides=[1, 1], activation=None)
					if self.normalization_mode == 'Layer Norm': 
						lay5_image = self.activation_function(helper.conv_layer_norm_layer(lay4_image))
					elif self.normalization_mode == 'Batch Norm': 
						lay5_image = self.activation_function(helper.batch_norm()(lay4_image))
					else: lay5_image = self.activation_function(lay4_image) #h1

					lay6_image = tf.layers.conv2d_transpose(inputs=lay5_image, filters=2*self.config['n_filter'], kernel_size=[5, 5], strides=[2, 2], activation=None)
					lay7_image = helper.tf_center_crop_image(lay6_image, resize_ratios=[16,16])
					if self.normalization_mode == 'Layer Norm': 
						lay8_image = self.activation_function(helper.conv_layer_norm_layer(lay7_image))
					elif self.normalization_mode == 'Batch Norm': 
						lay8_image = self.activation_function(helper.batch_norm()(lay7_image))
					else: lay8_image = self.activation_function(lay7_image) #h2

					lay9_image = tf.layers.conv2d_transpose(inputs=lay8_image, filters=2*self.config['n_filter'], kernel_size=[5, 5], strides=[2, 2], activation=None)
					lay10_image = helper.tf_center_crop_image(lay9_image, resize_ratios=[32,32])
					if self.normalization_mode == 'Layer Norm': 
						lay11_image = self.activation_function(helper.conv_layer_norm_layer(lay10_image))
					elif self.normalization_mode == 'Batch Norm': 
						lay11_image = self.activation_function(helper.batch_norm()(lay10_image))
					else: lay11_image = self.activation_function(lay10_image) #h3

					lay12_image = tf.layers.conv2d_transpose(inputs=lay11_image, filters=1*self.config['n_filter'], kernel_size=[5, 5], strides=[2, 2], activation=None)
					lay13_image = helper.tf_center_crop_image(lay12_image, resize_ratios=[64,64])
					if self.normalization_mode == 'Layer Norm': 
						lay14_image = self.activation_function(helper.conv_layer_norm_layer(lay13_image))
					elif self.normalization_mode == 'Batch Norm': 
						lay14_image = self.activation_function(helper.batch_norm()(lay13_image))
					else: lay14_image = self.activation_function(lay13_image) #h3

					lay15_image = tf.layers.conv2d_transpose(inputs=lay14_image, filters=n_output_channels, kernel_size=[5, 5], strides=[2, 2], activation=tf.nn.sigmoid)
					image_param = helper.tf_center_crop_image(lay15_image, resize_ratios=[128,128])

				out_dict['image'] = tf.reshape(image_param, [-1, x.get_shape().as_list()[1], *image_shape, n_output_channels])

			self.constructed = True
			return out_dict