def connect(self, data, generator_inputs):
    """Connects the components and returns the losses, outputs and debug ops.

    `generator_inputs` are passed through `utils.optimise_and_sample`, which
    returns samples together with optimised latents (presumably the result
    of optimising the latents against `data` — confirm in `utils`). The
    discriminator then scores both the real `data` and those samples:

      * discriminator loss: real data is targeted as class 1 and samples as
        class 0;
      * generator loss: samples are targeted as class 1.

    Args:
      data: a `tf.Tensor`: `[batch_size, ...]`. There are no constraints on
        the rank of this tensor, but it has to be compatible with the shapes
        expected by the discriminator.
      generator_inputs: a `tf.Tensor`: `[g_in_batch_size, ...]`. It does not
        have to have the same batch size as the `data` tensor. There are no
        constraints on the rank of this tensor, but it has to be compatible
        with the shapes the generator network supports as inputs.

    Returns:
      A `ModelOutputs` instance wrapping the optimization components and a
      dict of debug ops with keys 'z_step_size', 'disc_data_loss',
      'disc_sample_loss', 'disc_loss', 'gen_loss' and 'opt_cost'.
    """
    samples, optimised_z = utils.optimise_and_sample(
        generator_inputs, self, data, is_training=True)
    optimisation_cost = utils.get_optimisation_cost(
        generator_inputs, optimised_z)

    def constant_labels(logits, fill_fn):
        # One int32 class label per example. `logits[:, 0]` is used only for
        # its shape, so logits must be at least rank 2 (presumably
        # [batch, num_classes] — confirm against utils.cross_entropy_loss).
        return fill_fn(tf.shape(logits[:, 0]), dtype=tf.int32)

    # The discriminator is applied to real data and to samples; no labels
    # are passed to it.
    disc_data_logits = self._discriminator(data)
    disc_sample_logits = self._discriminator(samples)

    # Discriminator objective: real data -> class 1, samples -> class 0.
    disc_data_loss = utils.cross_entropy_loss(
        disc_data_logits, constant_labels(disc_data_logits, tf.ones))
    disc_sample_loss = utils.cross_entropy_loss(
        disc_sample_logits, constant_labels(disc_sample_logits, tf.zeros))
    disc_loss = disc_data_loss + disc_sample_loss

    # Generator objective: have the discriminator assign samples to class 1.
    generator_loss = utils.cross_entropy_loss(
        disc_sample_logits, constant_labels(disc_sample_logits, tf.ones))

    optimization_components = self._build_optimization_components(
        discriminator_loss=disc_loss,
        generator_loss=generator_loss,
        optimisation_cost=optimisation_cost)

    debug_ops = {
        'z_step_size': self.z_step_size,
        'disc_data_loss': disc_data_loss,
        'disc_sample_loss': disc_sample_loss,
        'disc_loss': disc_loss,
        'gen_loss': generator_loss,
        'opt_cost': optimisation_cost,
    }
    return utils.ModelOutputs(optimization_components, debug_ops)
def connect(self, data, generator_inputs):
    """Builds the training losses for this model and packages them with debug ops.

    `generator_inputs` are passed through `utils.optimise_and_sample`, which
    yields samples and optimised latents. The training objective is the
    batch-mean of `self.gen_loss_fn(data, samples)` plus a RIP penalty.
    That penalty is averaged over three pairs of tensors: optimised samples,
    samples generated directly from `generator_inputs`, and `data`.

    Args:
      data: a `tf.Tensor`: `[batch_size, ...]`. Its rank is unconstrained,
        but it must be compatible with the shapes consumed by
        `self.gen_loss_fn` and `self._get_rip_loss`.
      generator_inputs: a `tf.Tensor`: `[g_in_batch_size, ...]`. Its batch
        size need not match that of `data`, and its rank is unconstrained,
        but it must be a valid input for the generator network.

    Returns:
      A `ModelOutputs` instance holding the optimization components (built
      from `generator_loss + rip_loss`) and a dict of debug ops keyed
      'rip_loss', 'recons_loss', 'z_step_size', 'opt_cost' and 'gen_loss'.
    """
    samples, optimised_z = utils.optimise_and_sample(
        generator_inputs, self, data, is_training=True)
    optimisation_cost = utils.get_optimisation_cost(
        generator_inputs, optimised_z)

    # Samples generated directly from `generator_inputs` (not from
    # `optimised_z`); used as the third point of the RIP comparison.
    initial_samples = self.generator(generator_inputs, is_training=True)
    generator_loss = tf.reduce_mean(self.gen_loss_fn(data, samples))

    # RIP penalty over three pairs of tensors. Per the original author each
    # term is  (\sqrt{F(x_1 - x_2)^2} - \sqrt{(x_1 - x_2)^2})^2  with F the
    # measurement function — NOTE(review): confirm in _get_rip_loss.
    rip_pairs = (
        (samples, initial_samples),
        (samples, data),
        (initial_samples, data),
    )
    r1, r2, r3 = [self._get_rip_loss(x1, x2) for x1, x2 in rip_pairs]
    rip_loss = tf.reduce_mean((r1 + r2 + r3) / 3.0)

    optimization_components = self._build_optimization_components(
        generator_loss=generator_loss + rip_loss)

    # Debug-only metric: batch mean of the Euclidean distance between each
    # flattened sample and its corresponding flattened data point.
    flat_samples = snt.BatchFlatten()(samples)
    flat_data = snt.BatchFlatten()(data)
    recons_loss = tf.reduce_mean(tf.norm(flat_samples - flat_data, axis=-1))

    debug_ops = {
        'rip_loss': rip_loss,
        'recons_loss': recons_loss,
        'z_step_size': self.z_step_size,
        'opt_cost': optimisation_cost,
        'gen_loss': generator_loss,
    }
    return utils.ModelOutputs(optimization_components, debug_ops)