def build(self, input_shape):
    """Create the trainable parameters of the latent mixture distribution.

    Three weights of shape ``[FLAGS.hidden_dims, FLAGS.categorical_dims]``
    are registered: mixture weights (normalised with a softmax), component
    locations, and component scales (passed through a softplus). They
    parameterise a ``MixtureSameFamily`` with ``Normal`` components that is
    stored on ``self._distribution``.

    Note that the weight names mention "logistic" although the components
    built here are ``Normal``.

    Args:
        input_shape: Unused; every shape is taken from ``FLAGS``.
    """
    distributions = tfp.distributions
    weight_shape = [FLAGS.hidden_dims, FLAGS.categorical_dims]

    with summary.SummaryScope('latent_distributions') as scope:
        mixture_logits = self.add_weight(
            name='categorical_distribution',
            shape=list(weight_shape),
            trainable=True,
        )
        mixture_probs = tf.nn.softmax(mixture_logits)

        component_loc = self.add_weight(
            name='logistic_loc_variables',
            shape=list(weight_shape),
            trainable=True,
        )

        raw_scale = self.add_weight(
            name='logistic_scale_variables',
            shape=list(weight_shape),
            trainable=True,
        )
        component_scale = tf.nn.softplus(raw_scale)

        # Record the transformed parameters in the summary scope.
        for key, tensor in (
                ('categorical', mixture_probs),
                ('loc', component_loc),
                ('scale', component_scale),
        ):
            scope[key] = tensor

    # Holds the transformed tensors (probabilities / loc / scale), not the
    # raw mixture and scale variables.
    self.vars = [mixture_probs, component_loc, component_scale]

    components = distributions.Normal(
        loc=component_loc,
        scale=component_scale,
    )
    self._distribution = distributions.MixtureSameFamily(
        mixture_distribution=distributions.Categorical(probs=mixture_probs),
        components_distribution=components,
    )
    self.built = True
def call(self, input):
    """Apply ``self.model_layers`` and also return a 2x bilinear upsampling.

    Args:
        input: Tensor passed to the first entry of ``self.model_layers``.

    Returns:
        A tuple ``(output, upsize)`` where ``output`` is the last element
        returned by ``scope.sequential`` and ``upsize`` is ``output``
        resized bilinearly to twice its size along every axis between the
        first and the last one.
    """
    with summary.SummaryScope(self.name) as scope:
        scope['input'] = input

        # interior_layers=True presumably makes `sequential` return every
        # intermediate output rather than only the final one; only the
        # final one is used here -- TODO confirm against summary module.
        layers = scope.sequential(
            input,
            self.model_layers,
            interior_layers=True,
        )
        output = layers[-1]

        # Requires statically known sizes: int() fails on an unknown
        # (None) dimension.
        doubled_size = [2 * int(dim) for dim in output.shape[1:-1]]
        upsize = tf.image.resize_bilinear(output, size=doubled_size)
        scope['upsize'] = upsize

    return output, upsize
def build(self, input_shape):
    """Create ``self.kernel`` with shape ``[k, k, channels, 1]``.

    ``k`` is ``2 * self.size + 1`` and ``channels`` is the last entry of
    ``input_shape``.

    * With ``FLAGS.gaussian_downsample`` unset, the kernel itself is a
      trainable weight.
    * With it set, only one standard deviation per channel is trainable
      (``self.std``); the kernel is the outer product of a 1-D Gaussian
      density evaluated at the integer offsets ``-size..size``, normalised
      so that each channel's ``k x k`` slice sums to one.

    Args:
        input_shape: Shape of the layer input; only its last entry is read.
    """
    if not FLAGS.gaussian_downsample:
        width = 2 * self.size + 1
        self.kernel = self.add_weight(
            'kernel',
            shape=[width, width, int(input_shape[-1]), 1],
            trainable=True,
        )
        return

    with summary.SummaryScope(self.name) as scope:
        self.std = self.add_weight(
            'std',
            shape=[int(input_shape[-1])],
            trainable=True,
        )

        # Clamp at zero and add a small epsilon so the scale stays positive.
        positive_std = tf.nn.relu(self.std) + 1e-6
        gaussian = tfp.distributions.Normal(
            loc=tf.zeros_like(positive_std),
            scale=positive_std,
        )

        offsets = tf.range(
            start=-self.size,
            limit=self.size + 1,
            dtype=tf.float32,
        )
        # Broadcasting offsets [k, 1] against the per-channel scale gives
        # a [k, channels] table of densities.
        density = gaussian.prob(offsets[:, tf.newaxis])

        # Separable 2-D kernel: outer product of the 1-D profile per channel.
        outer = density[:, tf.newaxis, :] * density[tf.newaxis, :, :]
        normalizer = tf.reduce_sum(outer, axis=[0, 1])
        outer = outer / normalizer

        self.kernel = outer[:, :, :, tf.newaxis]
        scope['std'] = self.std
def call(self, input):
    """Record ``input`` in a summary scope and run ``self.model_layers``.

    Args:
        input: Tensor passed to ``scope.sequential`` together with
            ``self.model_layers``.

    Returns:
        Whatever ``scope.sequential(input, self.model_layers)`` returns.
    """
    with summary.SummaryScope(self.name) as scope:
        scope['input'] = input
        return scope.sequential(input, self.model_layers)
def log_sampling_information(latent, distribution):
    """Record flattened latents and matching prior samples under 'samples'.

    Args:
        latent: Tensor that is flattened with ``tf.layers.flatten``.
        distribution: Object exposing ``sample(n)``; ``n`` is the size of
            the last axis of the flattened latent.
    """
    with summary.SummaryScope('samples') as scope:
        flat_latent = tf.layers.flatten(latent)
        scope['latent_samples'] = flat_latent

        # Draw as many samples as there are values per flattened latent.
        sample_count = flat_latent.get_shape()[-1]
        scope['distribution_samples'] = distribution.sample(sample_count)
def build_losses(self):
    """Build the rate/distortion losses, summaries and the training op.

    Reads ``self.outputs`` (a sequence of dicts with ``'likelihood'``,
    ``'input'`` and ``'output'`` entries), ``self.original_dim`` and
    ``self.images``.

    Sets on ``self``: ``output_expected_bits``, ``output_image``,
    ``output_original``, ``train_loss``, ``train_op``, ``merged``,
    ``images_summary``, ``train_bpp``, ``train_mse``, ``train_psnr`` and
    ``train_ssim``.

    The training loss is ``0.05 * distortion + bpp``, where the distortion
    is the SSIM term when ``FLAGS.use_ssim`` is set and the MSE term
    otherwise.
    """
    num_pixels = np.prod(self.original_dim[:2])
    final_layer = self.outputs[-1]
    reconstruction = final_layer['output']
    original = final_layer['input']

    with summary.SummaryScope('losses') as scope:
        # Rate: sum each layer's 'likelihood' over axes 1 and 2, then sum
        # across layers. Dividing by -ln(2) * num_pixels suggests these
        # are log-likelihoods in nats -- TODO confirm with the producer.
        per_layer_bits = [
            tf.reduce_sum(layer['likelihood'], axis=[1, 2])
            for layer in self.outputs
        ]
        expected_bits_per_image = tf.reduce_sum(per_layer_bits, axis=[0])
        self.output_expected_bits = (
            expected_bits_per_image / (-np.log(2) * num_pixels))
        train_bpp = tf.reduce_mean(self.output_expected_bits)

        self.output_image = reconstruction
        self.output_original = original

        # Structural dissimilarity: (1 - SSIM) / 2.
        train_ssim = tf.reduce_mean(
            1 - tf.image.ssim(reconstruction, original, 1.0)) / 2.0

        # With helper_mse_loss every layer contributes to the MSE term,
        # otherwise only the final one does.
        mse_layers = (
            self.outputs if FLAGS.helper_mse_loss else self.outputs[-1:])
        train_mse = tf.reduce_sum([
            tf.reduce_mean(
                tf.squared_difference(layer['input'], layer['output']))
            for layer in mse_layers
        ])
        # NOTE(review): reduce_mean already averages over all elements, so
        # the additional division by num_pixels looks like a double
        # normalisation. Kept unchanged to preserve the loss weighting --
        # confirm intent.
        train_mse *= 255**2 / num_pixels

        distortion = train_ssim if FLAGS.use_ssim else train_mse
        train_loss = distortion * 0.05 + train_bpp

        train_psnr = tf.reduce_mean(
            tf.image.psnr(reconstruction, original, 1.0))

        scope['bpp'] = train_bpp
        scope['mse'] = train_mse
        scope['loss'] = train_loss
        scope['psnr'] = train_psnr
        scope['ssim'] = train_ssim

    optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate)
    train_op = optimizer.minimize(train_loss)

    merged = tf.summary.merge_all()
    images_summary = tf.summary.merge([
        tf.summary.image(
            f'comparison_{idx}',
            image,
            max_outputs=FLAGS.batch_size,
        ) for idx, image in enumerate(self.images)
    ])

    self.train_loss = train_loss
    self.train_op = train_op
    self.merged = merged
    self.images_summary = images_summary
    self.train_bpp = train_bpp
    self.train_mse = train_mse
    self.train_psnr = train_psnr
    self.train_ssim = train_ssim