class CVAEGAN(Image2ImageGenerativeModel):
    def __init__(self, x_dim, y_dim, z_dim, enc_architecture, gen_architecture,
                 adversarial_architecture, folder="./CVAEGAN", is_patchgan=False,
                 is_wasserstein=False):
        super(CVAEGAN, self).__init__(
            x_dim, y_dim, [enc_architecture, gen_architecture, adversarial_architecture], folder)
        self._z_dim = z_dim
        with tf.name_scope("Inputs"):
            self._Z_input = tf.placeholder(tf.float32, shape=[None, z_dim], name="z")

        self._enc_architecture = self._architectures[0]
        self._gen_architecture = self._architectures[1]
        self._adv_architecture = self._architectures[2]

        self._is_patchgan = is_patchgan
        self._is_wasserstein = is_wasserstein
        self._is_feature_matching = False

        ################# Define architecture
        if self._is_patchgan:
            f_xy = self._adv_architecture[-1][-1]["filters"]
            assert f_xy == 1, (
                "If is PatchGAN, last layer of adversarial_XY needs 1 filter. Given: {}.".format(f_xy))
            a_xy = self._adv_architecture[-1][-1]["activation"]
            if self._is_wasserstein:
                assert a_xy == tf.identity, (
                    "If is PatchGAN, last layer of adversarial needs tf.identity. Given: {}.".format(a_xy))
            else:
                assert a_xy == tf.nn.sigmoid, (
                    "If is PatchGAN, last layer of adversarial needs tf.nn.sigmoid. Given: {}.".format(a_xy))
        else:
            self._adv_architecture.append([tf.layers.flatten, {"name": "Flatten"}])
            if self._is_wasserstein:
                self._adv_architecture.append(
                    [logged_dense, {"units": 1, "activation": tf.identity, "name": "Output"}])
            else:
                self._adv_architecture.append(
                    [logged_dense, {"units": 1, "activation": tf.nn.sigmoid, "name": "Output"}])

        last_layers_mean = [
            [tf.layers.flatten, {"name": "flatten"}],
            [logged_dense, {"units": z_dim, "activation": tf.identity, "name": "Mean"}]
        ]
        self._encoder_mean = Encoder(self._enc_architecture + last_layers_mean, name="Encoder")

        last_layers_std = [
            [tf.layers.flatten, {"name": "flatten"}],
            [logged_dense, {"units": z_dim, "activation": tf.identity, "name": "Std"}]
        ]
        self._encoder_std = Encoder(self._enc_architecture + last_layers_std, name="Encoder")

        self._gen_architecture[-1][1]["name"] = "Output"
        self._generator = Generator(self._gen_architecture, name="Generator")
        self._adversarial = Discriminator(self._adv_architecture, name="Adversarial")

        self._nets = [self._encoder_mean, self._generator, self._adversarial]

        ################# Connect inputs and networks
        self._mean_layer = self._encoder_mean.generate_net(self._Y_input)
        self._std_layer = self._encoder_std.generate_net(self._Y_input)
        # Reparameterization trick; the "Std" head is treated as a log-variance.
        self._output_enc_with_noise = self._mean_layer + tf.exp(0.5 * self._std_layer) * self._Z_input

        with tf.name_scope("Inputs"):
            self._gen_input = image_condition_concat(
                inputs=self._X_input, condition=self._output_enc_with_noise, name="mod_z_real")
            self._gen_input_from_encoding = image_condition_concat(
                inputs=self._X_input, condition=self._Z_input, name="mod_z")

        self._output_gen = self._generator.generate_net(self._gen_input)
        self._output_gen_from_encoding = self._generator.generate_net(self._gen_input_from_encoding)
        self._generator._input_dim = z_dim
        assert self._output_gen.get_shape()[1:] == y_dim, (
            "Generator output must have shape of y_dim. Given: {}. Expected: {}.".format(
                self._output_gen.get_shape(), y_dim))

        with tf.name_scope("InputsAdversarial"):
            self._input_real = tf.concat(values=[self._Y_input, self._X_input], axis=3)
            self._input_fake_from_real = tf.concat(values=[self._output_gen, self._X_input], axis=3)
            self._input_fake_from_latent = tf.concat(
                values=[self._output_gen_from_encoding, self._X_input], axis=3)

        self._output_adv_real = self._adversarial.generate_net(self._input_real)
        self._output_adv_fake_from_real = self._adversarial.generate_net(self._input_fake_from_real)
        self._output_adv_fake_from_latent = self._adversarial.generate_net(self._input_fake_from_latent)

        ################# Finalize
        self._init_folders()
        self._verify_init()

        self._output_label_real = tf.placeholder(
            tf.float32, shape=self._output_adv_real.shape, name="label_real")
        self._output_label_fake = tf.placeholder(
            tf.float32, shape=self._output_adv_fake_from_real.shape, name="label_fake")

        if self._is_patchgan:
            print("PATCHGAN chosen with output: {}.".format(self._output_adv_real.shape))

    def compile(self, loss, optimizer, learning_rate=None, learning_rate_enc=None,
                learning_rate_gen=None, learning_rate_adv=None, label_smoothing=1,
                lmbda_kl=0.1, lmbda_y=1, feature_matching=False, random_labeling=0):
        if self._is_wasserstein and loss != "wasserstein":
            raise ValueError("If is_wasserstein is true in Constructor, loss needs to be wasserstein.")
        if not self._is_wasserstein and loss == "wasserstein":
            raise ValueError("If loss is wasserstein, is_wasserstein needs to be true in constructor.")
        if np.all([lr is None for lr in
                   [learning_rate, learning_rate_enc, learning_rate_gen, learning_rate_adv]]):
            raise ValueError("Need learning_rate.")
        if learning_rate is not None and learning_rate_enc is None:
            learning_rate_enc = learning_rate
        if learning_rate is not None and learning_rate_gen is None:
            learning_rate_gen = learning_rate
        if learning_rate is not None and learning_rate_adv is None:
            learning_rate_adv = learning_rate

        self._define_loss(loss=loss, label_smoothing=label_smoothing, lmbda_kl=lmbda_kl,
                          lmbda_y=lmbda_y, feature_matching=feature_matching,
                          random_labeling=random_labeling)
        with tf.name_scope("Optimizer"):
            self._enc_optimizer = optimizer(learning_rate=learning_rate_enc)
            self._enc_optimizer_op = self._enc_optimizer.minimize(
                self._enc_loss, var_list=self._get_vars("Encoder"), name="Encoder")
            self._gen_optimizer = optimizer(learning_rate=learning_rate_gen)
            self._gen_optimizer_op = self._gen_optimizer.minimize(
                self._gen_loss, var_list=self._get_vars("Generator"), name="Generator")
            self._adv_optimizer = optimizer(learning_rate=learning_rate_adv)
            self._adv_optimizer_op = self._adv_optimizer.minimize(
                self._adv_loss, var_list=self._get_vars("Adversarial"), name="Adversarial")
        self._summarise()

    def _define_loss(self, loss, label_smoothing, lmbda_kl, lmbda_y, feature_matching, random_labeling):
        possible_losses = ["cross-entropy", "L2", "wasserstein", "KL"]

        def get_labels_one():
            return tf.math.multiply(self._output_label_real, label_smoothing)

        def get_labels_zero():
            return self._output_label_fake

        eps = 1e-6
        self._label_smoothing = label_smoothing
        self._random_labeling = random_labeling

        ## Kullback-Leibler divergence
        self._KLdiv = 0.5 * (tf.square(self._mean_layer) + tf.exp(self._std_layer) - self._std_layer - 1)
        self._KLdiv = lmbda_kl * tf.reduce_mean(self._KLdiv)

        ## L1 reconstruction loss
        self._recon_loss = lmbda_y * tf.reduce_mean(tf.abs(self._Y_input - self._output_gen))

        ## Adversarial loss
        if loss == "cross-entropy":
            self._logits_real = tf.math.log(
                self._output_adv_real / (1 + eps - self._output_adv_real) + eps)
            self._logits_fake_from_real = tf.math.log(
                self._output_adv_fake_from_real / (1 + eps - self._output_adv_fake_from_real) + eps)
            self._logits_fake_from_latent = tf.math.log(
                self._output_adv_fake_from_latent / (1 + eps - self._output_adv_fake_from_latent) + eps)
            self._generator_loss = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(
                    labels=tf.ones_like(self._logits_fake_from_real), logits=self._logits_fake_from_real)
                + tf.nn.sigmoid_cross_entropy_with_logits(
                    labels=tf.ones_like(self._logits_fake_from_latent), logits=self._logits_fake_from_latent))
            self._adversarial_loss = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(
                    labels=get_labels_one(), logits=self._logits_real)
                + tf.nn.sigmoid_cross_entropy_with_logits(
                    labels=get_labels_zero(), logits=self._logits_fake_from_real)
                + tf.nn.sigmoid_cross_entropy_with_logits(
                    labels=get_labels_zero(), logits=self._logits_fake_from_latent))
        elif loss == "L2":
            self._generator_loss = tf.reduce_mean(
                tf.square(self._output_adv_fake_from_real - tf.ones_like(self._output_adv_fake_from_real))
                + tf.square(self._output_adv_fake_from_latent - tf.ones_like(self._output_adv_fake_from_latent))) / 2
            self._adversarial_loss = (tf.reduce_mean(
                tf.square(self._output_adv_real - get_labels_one())
                + tf.square(self._output_adv_fake_from_real - get_labels_zero())
                + tf.square(self._output_adv_fake_from_latent - get_labels_zero()))) / 3.0
        elif loss == "wasserstein":
            self._generator_loss = (-tf.reduce_mean(self._output_adv_fake_from_real)
                                    - tf.reduce_mean(self._output_adv_fake_from_latent))
            self._adversarial_loss = (
                -(tf.reduce_mean(self._output_adv_real)
                  - tf.reduce_mean(self._output_adv_fake_from_real)
                  - tf.reduce_mean(self._output_adv_fake_from_latent))
                + 10 * self._define_gradient_penalty())
        elif loss == "KL":
            self._logits_real = tf.math.log(
                self._output_adv_real / (1 + eps - self._output_adv_real) + eps)
            self._logits_fake_from_real = tf.math.log(
                self._output_adv_fake_from_real / (1 + eps - self._output_adv_fake_from_real) + eps)
            self._logits_fake_from_latent = tf.math.log(
                self._output_adv_fake_from_latent / (1 + eps - self._output_adv_fake_from_latent) + eps)
            self._generator_loss = (-tf.reduce_mean(self._logits_fake_from_real)
                                    - tf.reduce_mean(self._logits_fake_from_latent)) / 2
            self._adversarial_loss = tf.reduce_mean(
                0.5 * tf.nn.sigmoid_cross_entropy_with_logits(
                    labels=get_labels_one(), logits=self._logits_real)
                + 0.25 * tf.nn.sigmoid_cross_entropy_with_logits(
                    labels=get_labels_zero(), logits=self._logits_fake_from_real)
                + 0.25 * tf.nn.sigmoid_cross_entropy_with_logits(
                    labels=get_labels_zero(), logits=self._logits_fake_from_latent))
        else:
            raise ValueError("Loss not implemented. Choose from {}. Given: {}.".format(possible_losses, loss))

        if feature_matching:
            self._is_feature_matching = True
            otp_adv_real = self._adversarial.generate_net(
                self._input_real, tf_trainflag=self._is_training, return_idx=-2)
            otp_adv_fake = self._adversarial.generate_net(
                self._input_fake_from_real, tf_trainflag=self._is_training, return_idx=-2)
            self._generator_loss = tf.reduce_mean(tf.square(otp_adv_real - otp_adv_fake))

        with tf.name_scope("Loss") as scope:
            self._enc_loss = self._KLdiv + self._recon_loss + self._generator_loss
            self._gen_loss = self._recon_loss + self._generator_loss
            self._adv_loss = self._adversarial_loss
            tf.summary.scalar("Kullback-Leibler", self._KLdiv)
            tf.summary.scalar("Reconstruction", self._recon_loss)
            tf.summary.scalar("Vanilla_Generator", self._generator_loss)
            tf.summary.scalar("Encoder", self._enc_loss)
            tf.summary.scalar("Generator", self._gen_loss)
            tf.summary.scalar("Adversarial", self._adv_loss)

    def _define_gradient_penalty(self):
        # Gradient penalty on points interpolated between real and generated inputs (WGAN-GP style).
        alpha = tf.random_uniform(shape=tf.shape(self._input_real), minval=0., maxval=1.)
        differences = self._input_fake_from_real - self._input_real
        interpolates = self._input_real + (alpha * differences)
        gradients = tf.gradients(self._adversarial.generate_net(interpolates), [interpolates])[0]
        slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients)))
        with tf.name_scope("Loss") as scope:
            self._gradient_penalty = tf.reduce_mean((slopes - 1.)**2)
            tf.summary.scalar("Gradient_penalty", self._gradient_penalty)
        return self._gradient_penalty

    def train(self, x_train, y_train, x_test, y_test, epochs=100, batch_size=64, adv_steps=5,
              gen_steps=1, log_step=3, gpu_options=None, batch_log_step=None):
        self._set_up_training(log_step=log_step, gpu_options=gpu_options)
        self._set_up_test_train_sample(x_train=x_train, y_train=y_train, x_test=x_test, y_test=y_test)
        self._z_test = self._generator.sample_noise(n=len(self._x_test))
        nr_batches = np.floor(len(x_train) / batch_size)
        self.batch_size = batch_size
        self._prepare_monitoring()
        self._log_results(epoch=0, epoch_time=0)
        for epoch in range(epochs):
            adv_loss_epoch = 0
            gen_loss_epoch = 0
            enc_loss_epoch = 0
            start = time.clock()
            trained_examples = 0
            batch_nr = 0
            while trained_examples < len(x_train):
                batch_train_start = time.clock()
                adv_loss_batch, gen_loss_batch, enc_loss_batch = self._optimize(
                    self._trainset, adv_steps, gen_steps)
                trained_examples += self.batch_size
                adv_loss_epoch += adv_loss_batch
                gen_loss_epoch += gen_loss_batch
                enc_loss_epoch += enc_loss_batch
                self._total_train_time += (time.clock() - batch_train_start)
                if (batch_log_step is not None) and (batch_nr % batch_log_step == 0):
                    self._count_batches += batch_log_step
                    batch_train_time = (time.clock() - start) / 60
                    self._log(self._count_batches, batch_train_time)
                batch_nr += 1

            epoch_train_time = (time.clock() - start) / 60
            adv_loss_epoch = np.round(adv_loss_epoch, 2)
            gen_loss_epoch = np.round(gen_loss_epoch, 2)
            enc_loss_epoch = np.round(enc_loss_epoch, 2)
            print("\nEpoch {}: D: {}; G: {}; E: {}.".format(
                epoch, adv_loss_epoch, gen_loss_epoch, enc_loss_epoch))

            if batch_log_step is None and (log_step is not None) and (epoch % log_step == 0):
                self._log(epoch + 1, epoch_train_time)

    def _optimize(self, dataset, adv_steps, gen_steps):
        for i in range(adv_steps):
            current_batch_x, current_batch_y = dataset.get_next_batch(self.batch_size)
            Z_noise = self._generator.sample_noise(n=len(current_batch_x))
            _, adv_loss_batch = self._sess.run(
                [self._adv_optimizer_op, self._adv_loss],
                feed_dict={
                    self._X_input: current_batch_x,
                    self._Y_input: current_batch_y,
                    self._Z_input: Z_noise,
                    self._is_training: True,
                    self._output_label_real: self.get_random_label(is_real=True),
                    self._output_label_fake: self.get_random_label(is_real=False)
                })

        for i in range(gen_steps):
            Z_noise = self._generator.sample_noise(n=len(current_batch_x))
            _, gen_loss_batch = self._sess.run(
                [self._gen_optimizer_op, self._gen_loss],
                feed_dict={
                    self._X_input: current_batch_x,
                    self._Y_input: current_batch_y,
                    self._Z_input: Z_noise,
                    self._is_training: True
                })

        _, enc_loss_batch = self._sess.run(
            [self._enc_optimizer_op, self._enc_loss],
            feed_dict={
                self._X_input: current_batch_x,
                self._Y_input: current_batch_y,
                self._Z_input: Z_noise,
                self._is_training: True
            })
        return adv_loss_batch, gen_loss_batch, enc_loss_batch

    def _log_results(self, epoch, epoch_time):
        summary = self._sess.run(
            self._merged_summaries,
            feed_dict={
                self._X_input: self._x_test,
                self._Y_input: self._y_test,
                self._Z_input: self._z_test,
                self._epoch_time: epoch_time,
                self._is_training: False,
                self._epoch_nr: epoch,
                self._output_label_real: self.get_random_label(is_real=True, size=self._nr_test),
                self._output_label_fake: self.get_random_label(is_real=False, size=self._nr_test)
            })
        self._writer1.add_summary(summary, epoch)

        nr_test = len(self._x_test)
        summary = self._sess.run(
            self._merged_summaries,
            feed_dict={
                self._X_input: self._trainset.get_xdata()[:nr_test],
                self._Z_input: self._z_test,
                self._Y_input: self._trainset.get_ydata()[:nr_test],
                self._epoch_time: epoch_time,
                self._is_training: False,
                self._epoch_nr: epoch,
                self._output_label_real: self.get_random_label(is_real=True, size=self._nr_test),
                self._output_label_fake: self.get_random_label(is_real=False, size=self._nr_test)
            })
        self._writer2.add_summary(summary, epoch)

        if self._image_shape is not None:
            self.plot_samples(inpt_x=self._x_test[:10], inpt_y=self._y_test[:10], sess=self._sess,
                              image_shape=self._image_shape, epoch=epoch,
                              path="{}/GeneratedSamples/result_{}.png".format(self._folder, epoch))
        self.save_model(epoch)

        additional_log = getattr(self, "evaluate", None)
        if callable(additional_log):
            self.evaluate(true=self._x_test, condition=self._y_test, epoch=epoch)
        print("Logged.")

    def plot_samples(self, inpt_x, inpt_y, sess, image_shape, epoch, path):
        outpt_xy = sess.run(self._output_gen_from_encoding,
                            feed_dict={
                                self._X_input: inpt_x,
                                self._Z_input: self._z_test[:len(inpt_x)],
                                self._is_training: False
                            })
        image_matrix = np.array([
            [x.reshape(self._x_dim[0], self._x_dim[1]),
             y.reshape(self._y_dim[0], self._y_dim[1]),
             np.zeros(shape=(self._x_dim[0], self._x_dim[1])),
             xy.reshape(self._y_dim[0], self._y_dim[1])]
            for x, y, xy in zip(inpt_x, inpt_y, outpt_xy)
        ])
        self._generator.build_generated_samples(
            image_matrix, column_titles=["True X", "True Y", "", "Gen_XY"], epoch=epoch, path=path)

    def _prepare_monitoring(self):
        self._total_train_time = 0
        self._total_log_time = 0
        self._count_batches = 0
        self._batches = []
        self._max_allowed_failed_checks = 20

        self._enc_grads_and_vars = self._enc_optimizer.compute_gradients(
            self._enc_loss, var_list=self._get_vars("Encoder"))
        self._gen_grads_and_vars = self._gen_optimizer.compute_gradients(
            self._gen_loss, var_list=self._get_vars("Generator"))
        self._adv_grads_and_vars = self._adv_optimizer.compute_gradients(
            self._adv_loss, var_list=self._get_vars("Adversarial"))

        # Each entry: [tensors, plot labels, history lists (, optional aggregation op)].
        self._monitor_dict = {
            "Gradients": [
                [self._enc_grads_and_vars, self._gen_grads_and_vars, self._adv_grads_and_vars],
                ["Encoder", "Generator", "Adversarial"],
                [[] for i in range(9)]
            ],
            "Losses": [
                [self._enc_loss, self._gen_loss, self._adversarial_loss,
                 self._generator_loss, self._recon_loss, self._KLdiv],
                ["Encoder (V+R+K)", "Generator (V+R)", "Adversarial",
                 "Vanilla_Generator", "Reconstruction", "Kullback-Leibler"],
                [[] for i in range(6)]
            ],
            "Output Adversarial": [
                [self._output_adv_fake_from_real, self._output_adv_fake_from_latent, self._output_adv_real],
                ["Fake_from_real", "Fake_from_latent", "Real"],
                [[] for i in range(3)],
                [np.mean]
            ]
        }
        self._check_dict = {
            "Dominating Discriminator": {
                "Tensors": [self._output_adv_real, self._output_adv_fake_from_real],
                "OPonTensors": [np.mean, np.mean],
                "Relation": [">", "<"],
                "Threshold": [self._label_smoothing * 0.95, (1 - self._label_smoothing) * 1.05],
                "TensorRelation": np.logical_and
            },
            "Generator outputs zeros": {
                "Tensors": [self._output_gen_from_encoding, self._output_gen_from_encoding],
                "OPonTensors": [np.max, np.min],
                "Relation": ["<", ">"],
                "Threshold": [0.05, 0.95],
                "TensorRelation": np.logical_or
            }
        }
        self._check_count = [0 for key in self._check_dict]

        if not os.path.exists(self._folder + "/Evaluation"):
            os.mkdir(self._folder + "/Evaluation")
            os.mkdir(self._folder + "/Evaluation/Cells")
            os.mkdir(self._folder + "/Evaluation/CenterOfMassX")
            os.mkdir(self._folder + "/Evaluation/CenterOfMassY")
            os.mkdir(self._folder + "/Evaluation/Energy")
            os.mkdir(self._folder + "/Evaluation/MaxEnergy")
            os.mkdir(self._folder + "/Evaluation/StdEnergy")

    def evaluate(self, true, condition, epoch):
        print("Batch ", epoch)
        log_start = time.clock()
        self._batches.append(epoch)
        fake = self._sess.run(self._output_gen_from_encoding,
                              feed_dict={
                                  self._X_input: self._x_test,
                                  self._Z_input: self._z_test,
                                  self._is_training: False
                              })
        true = self._y_test.reshape([-1, self._image_shape[0], self._image_shape[1]])
        fake = fake.reshape([-1, self._image_shape[0], self._image_shape[1]])
        build_histogram(true=true, fake=fake, function=get_energies, name="Energy",
                        epoch=epoch, folder=self._folder)
        build_histogram(true=true, fake=fake, function=get_number_of_activated_cells, name="Cells",
                        epoch=epoch, folder=self._folder, threshold=6 / 6120)
        build_histogram(true=true, fake=fake, function=get_max_energy, name="MaxEnergy",
                        epoch=epoch, folder=self._folder)
        build_histogram(true=true, fake=fake, function=get_center_of_mass_x, name="CenterOfMassX",
                        epoch=epoch, folder=self._folder, image_shape=self._image_shape)
        build_histogram(true=true, fake=fake, function=get_center_of_mass_y, name="CenterOfMassY",
                        epoch=epoch, folder=self._folder, image_shape=self._image_shape)
        build_histogram(true=true, fake=fake, function=get_std_energy, name="StdEnergy",
                        epoch=epoch, folder=self._folder)

        fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(18, 10))
        axs = np.ravel(axs)

        if "Gradients" in self._monitor_dict:
            colors = ["green", "blue", "red"]
            axy_min = np.inf
            axy_max = -np.inf
            for go, gradient_ops in enumerate(self._monitor_dict["Gradients"][0]):
                grads = [
                    self._sess.run(gv[0],
                                   feed_dict={
                                       self._X_input: self._x_test,
                                       self._Y_input: self._y_test,
                                       self._Z_input: self._z_test,
                                       self._is_training: False,
                                       self._output_label_real:
                                           self.get_random_label(is_real=True, size=self._nr_test),
                                       self._output_label_fake:
                                           self.get_random_label(is_real=False, size=self._nr_test)
                                   }) for gv in gradient_ops
                ]
                for op_idx, op in enumerate([np.mean, np.max, np.min]):
                    self._monitor_dict["Gradients"][2][go * 3 + op_idx].append(
                        op([op(grad) for grad in grads]))
                    vals = self._monitor_dict["Gradients"][2][go * 3 + op_idx]
                    if op_idx == 0:
                        axs[0].plot(self._batches, vals,
                                    label=self._monitor_dict["Gradients"][1][go], color=colors[go])
                    else:
                        axs[0].plot(self._batches, vals, linewidth=0.5, linestyle="--", color=colors[go])
                    upper = np.mean(vals)
                    lower = np.mean(vals)
                    if upper > axy_max:
                        axy_max = upper
                    if lower < axy_min:
                        axy_min = lower
            axs[0].set_title("Gradients")
            axs[0].legend()
            axs[0].set_ylim([axy_min, axy_max])

        current_batch_x, current_batch_y = self._trainset.get_next_batch(self.batch_size)
        Z_noise = self._generator.sample_noise(n=len(current_batch_x))
        colors = ["green", "blue", "red", "orange", "purple", "brown", "gray", "pink", "cyan", "olive"]
        for k, key in enumerate(self._monitor_dict):
            if key == "Gradients":
                continue
            key_results = self._sess.run(
                self._monitor_dict[key][0],
                feed_dict={
                    self._X_input: current_batch_x,
                    self._Y_input: current_batch_y,
                    self._Z_input: Z_noise,
                    self._is_training: True,
                    self._output_label_real: self.get_random_label(is_real=True),
                    self._output_label_fake: self.get_random_label(is_real=False)
                })
            for kr, key_result in enumerate(key_results):
                try:
                    self._monitor_dict[key][2][kr].append(self._monitor_dict[key][3][0](key_result))
                except IndexError:
                    self._monitor_dict[key][2][kr].append(key_result)
                axs[k].plot(self._batches, self._monitor_dict[key][2][kr],
                            label=self._monitor_dict[key][1][kr], color=colors[kr])
            axs[k].legend()
            axs[k].set_title(key)
            print("; ".join([
                "{}: {}".format(name, round(float(val[-1]), 5))
                for name, val in zip(self._monitor_dict[key][1], self._monitor_dict[key][2])
            ]))

        gen_samples = self._sess.run([self._output_gen_from_encoding],
                                     feed_dict={
                                         self._X_input: current_batch_x,
                                         self._Z_input: Z_noise,
                                         self._is_training: False
                                     })
        axs[-1].hist([np.ravel(gen_samples), np.ravel(current_batch_y)], label=["Generated", "True"])
        axs[-1].set_title("Pixel distribution")
        axs[-1].legend()

        for check_idx, check_key in enumerate(self._check_dict):
            result_bools_of_check = []
            check = self._check_dict[check_key]
            for tensor_idx in range(len(check["Tensors"])):
                tensor_ = self._sess.run(check["Tensors"][tensor_idx],
                                         feed_dict={
                                             self._X_input: self._x_test,
                                             self._Y_input: self._y_test,
                                             self._Z_input: self._z_test,
                                             self._is_training: False
                                         })
                tensor_op = check["OPonTensors"][tensor_idx](tensor_)
                if eval(str(tensor_op) + check["Relation"][tensor_idx] + str(check["Threshold"][tensor_idx])):
                    result_bools_of_check.append(True)
                else:
                    result_bools_of_check.append(False)
            if (tensor_idx > 0 and check["TensorRelation"](*result_bools_of_check)) or (result_bools_of_check[0]):
                self._check_count[check_idx] += 1
                if self._check_count[check_idx] == self._max_allowed_failed_checks:
                    raise GeneratorExit(check_key)
            else:
                self._check_count[check_idx] = 0

        self._total_log_time += (time.clock() - log_start)
        fig.suptitle("Train {} / Log {} / Fails {}".format(
            np.round(self._total_train_time, 2), np.round(self._total_log_time, 2), self._check_count))
        plt.savefig(self._folder + "/TrainStatistics.png")
        plt.close("all")

    def get_random_label(self, is_real, size=None):
        # With random_labeling > 0 a random fraction of the labels is flipped on purpose.
        if size is None:
            size = self.batch_size
        labels_shape = [size, *self._output_adv_real.shape.as_list()[1:]]
        labels = np.ones(shape=labels_shape)
        if self._random_labeling > 0:
            relabel_mask = np.random.binomial(n=1, p=self._random_labeling, size=labels_shape) == 1
            labels[relabel_mask] = 0
        if not is_real:
            labels = 1 - labels
        return labels
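
# Illustrative usage sketch (not part of the original module): one way a CVAEGAN could be
# wired together. The [callable, kwargs] layer specifications below follow the convention
# used throughout this file, but the concrete filter counts, strides, image dimensions and
# the x_/y_ arrays are assumptions; they have to be tuned so that the generator output
# actually matches y_dim for the assertion in the constructor to pass.
def _example_cvaegan_usage(x_train, y_train, x_test, y_test):
    x_dim = y_dim = (32, 32, 1)   # assumed shape of input and target images
    z_dim = 64

    def conv(filters, strides=1, activation=tf.nn.leaky_relu):
        return [tf.layers.conv2d, {"filters": filters, "kernel_size": 4, "strides": strides,
                                   "padding": "same", "activation": activation}]

    enc_architecture = [conv(32, strides=2), conv(64, strides=2)]
    # Stride-1 convolutions keep the spatial size, so the last layer can map to one channel.
    gen_architecture = [conv(32), conv(16), conv(1, activation=tf.nn.sigmoid)]
    adversarial_architecture = [conv(32, strides=2), conv(64, strides=2)]

    model = CVAEGAN(x_dim=x_dim, y_dim=y_dim, z_dim=z_dim,
                    enc_architecture=enc_architecture,
                    gen_architecture=gen_architecture,
                    adversarial_architecture=adversarial_architecture)
    model.compile(loss="cross-entropy", optimizer=tf.train.AdamOptimizer,
                  learning_rate=1e-4, lmbda_kl=0.1, lmbda_y=1)
    model.train(x_train, y_train, x_test, y_test, epochs=10, batch_size=64,
                adv_steps=1, gen_steps=1, log_step=1)
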
class VAEGAN(GenerativeModel):
    def __init__(self, x_dim, z_dim, enc_architecture, gen_architecture, disc_architecture,
                 folder="./VAEGAN"):
        super(VAEGAN, self).__init__(
            x_dim, z_dim, [enc_architecture, gen_architecture, disc_architecture], folder)

        self._enc_architecture = self._architectures[0]
        self._gen_architecture = self._architectures[1]
        self._disc_architecture = self._architectures[2]

        ################# Define architecture
        last_layer_mean = [logged_dense, {"units": z_dim, "activation": tf.identity, "name": "Mean"}]
        self._encoder_mean = Encoder(self._enc_architecture + [last_layer_mean], name="Encoder")

        last_layer_std = [logged_dense, {"units": z_dim, "activation": tf.identity, "name": "Std"}]
        self._encoder_std = Encoder(self._enc_architecture + [last_layer_std], name="Encoder")

        self._gen_architecture[-1][1]["name"] = "Output"
        self._generator = Generator(self._gen_architecture, name="Generator")

        self._disc_architecture.append([tf.layers.flatten, {"name": "Flatten"}])
        self._disc_architecture.append(
            [logged_dense, {"units": 1, "activation": tf.nn.sigmoid, "name": "Output"}])
        self._discriminator = Discriminator(self._disc_architecture, name="Discriminator")

        self._nets = [self._encoder_mean, self._generator, self._discriminator]

        ################# Connect inputs and networks
        self._mean_layer = self._encoder_mean.generate_net(self._X_input)
        self._std_layer = self._encoder_std.generate_net(self._X_input)
        self._output_enc_with_noise = self._mean_layer + tf.exp(0.5 * self._std_layer) * self._Z_input

        self._output_gen = self._generator.generate_net(self._output_enc_with_noise)
        self._output_gen_from_encoding = self._generator.generate_net(self._Z_input)
        assert self._output_gen.get_shape()[1:] == x_dim, (
            "Generator output must have shape of x_dim. Given: {}. Expected: {}.".format(
                self._output_gen.get_shape(), x_dim))

        self._output_disc_real = self._discriminator.generate_net(self._X_input)
        self._output_disc_fake_from_real = self._discriminator.generate_net(self._output_gen)
        self._output_disc_fake_from_latent = self._discriminator.generate_net(self._output_gen_from_encoding)

        ################# Finalize
        self._init_folders()
        self._verify_init()

    def compile(self, learning_rate=0.0001, optimizer=tf.train.AdamOptimizer, label_smoothing=1, gamma=1):
        self._define_loss(label_smoothing=label_smoothing, gamma=gamma)
        with tf.name_scope("Optimizer"):
            enc_optimizer = optimizer(learning_rate=learning_rate)
            self._enc_optimizer = enc_optimizer.minimize(
                self._enc_loss, var_list=self._get_vars("Encoder"), name="Encoder")
            gen_optimizer = optimizer(learning_rate=learning_rate)
            self._gen_optimizer = gen_optimizer.minimize(
                self._gen_loss, var_list=self._get_vars("Generator"), name="Generator")
            disc_optimizer = optimizer(learning_rate=learning_rate)
            self._disc_optimizer = disc_optimizer.minimize(
                self._disc_loss, var_list=self._get_vars("Discriminator"), name="Discriminator")
        self._summarise()

    def _define_loss(self, label_smoothing, gamma):
        def get_labels_one(tensor):
            return tf.ones_like(tensor) * label_smoothing

        eps = 1e-7

        ## Kullback-Leibler divergence
        self._KLdiv = 0.5 * (tf.square(self._mean_layer) + tf.exp(self._std_layer) - self._std_layer - 1)
        self._KLdiv = tf.reduce_mean(self._KLdiv)

        ## Feature matching loss
        otp_disc_real = self._discriminator.generate_net(
            self._X_input, tf_trainflag=self._is_training, return_idx=-2)
        otp_disc_fake = self._discriminator.generate_net(
            self._output_gen, tf_trainflag=self._is_training, return_idx=-2)
        self._feature_loss = tf.reduce_mean(tf.square(otp_disc_real - otp_disc_fake))

        ## Discriminator loss
        self._logits_real = tf.math.log(self._output_disc_real / (1 + eps - self._output_disc_real) + eps)
        self._logits_fake_from_real = tf.math.log(
            self._output_disc_fake_from_real / (1 + eps - self._output_disc_fake_from_real) + eps)
        self._logits_fake_from_latent = tf.math.log(
            self._output_disc_fake_from_latent / (1 + eps - self._output_disc_fake_from_latent) + eps)
        self._generator_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                labels=get_labels_one(self._logits_fake_from_real), logits=self._logits_fake_from_real)
            + tf.nn.sigmoid_cross_entropy_with_logits(
                labels=get_labels_one(self._logits_fake_from_latent), logits=self._logits_fake_from_latent))
        self._discriminator_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                labels=get_labels_one(self._logits_real), logits=self._logits_real)
            + tf.nn.sigmoid_cross_entropy_with_logits(
                labels=tf.zeros_like(self._logits_fake_from_real), logits=self._logits_fake_from_real)
            + tf.nn.sigmoid_cross_entropy_with_logits(
                labels=tf.zeros_like(self._logits_fake_from_latent), logits=self._logits_fake_from_latent))

        with tf.name_scope("Loss") as scope:
            self._enc_loss = self._KLdiv + self._feature_loss
            self._gen_loss = self._feature_loss + self._generator_loss
            self._disc_loss = self._discriminator_loss
            tf.summary.scalar("Encoder", self._enc_loss)
            tf.summary.scalar("Generator", self._gen_loss)
            tf.summary.scalar("Discriminator", self._disc_loss)

    def train(self, x_train, x_test, epochs=100, batch_size=64, disc_steps=5, gen_steps=1, log_step=3):
        self._set_up_training(log_step=log_step)
        self._set_up_test_train_sample(x_train, x_test)
        for epoch in range(epochs):
            batch_nr = 0
            disc_loss_epoch = 0
            gen_loss_epoch = 0
            enc_loss_epoch = 0
            start = time.clock()
            trained_examples = 0
            while trained_examples < len(x_train):
                disc_loss_batch, gen_loss_batch, enc_loss_batch = self._optimize(
                    self._trainset, batch_size, disc_steps, gen_steps)
                trained_examples += batch_size
                disc_loss_epoch += disc_loss_batch
                gen_loss_epoch += gen_loss_batch
                enc_loss_epoch += enc_loss_batch

            epoch_train_time = (time.clock() - start) / 60
            disc_loss_epoch = np.round(disc_loss_epoch, 2)
            gen_loss_epoch = np.round(gen_loss_epoch, 2)
            enc_loss_epoch = np.round(enc_loss_epoch, 2)
            print("Epoch {}: D: {}; G: {}; E: {}.".format(
                epoch, disc_loss_epoch, gen_loss_epoch, enc_loss_epoch))

            if log_step is not None:
                self._log(epoch, epoch_train_time)

    def _optimize(self, dataset, batch_size, disc_steps, gen_steps):
        for i in range(disc_steps):
            current_batch_x = dataset.get_next_batch(batch_size)
            Z_noise = self._generator.sample_noise(n=len(current_batch_x))
            _, disc_loss_batch = self._sess.run(
                [self._disc_optimizer, self._disc_loss],
                feed_dict={self._X_input: current_batch_x, self._Z_input: Z_noise})

        for i in range(gen_steps):
            Z_noise = self._generator.sample_noise(n=len(current_batch_x))
            _, gen_loss_batch = self._sess.run(
                [self._gen_optimizer, self._gen_loss],
                feed_dict={self._X_input: current_batch_x, self._Z_input: Z_noise})

        _, enc_loss_batch = self._sess.run(
            [self._enc_optimizer, self._enc_loss],
            feed_dict={self._X_input: current_batch_x, self._Z_input: Z_noise})
        return disc_loss_batch, gen_loss_batch, enc_loss_batch
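
# Illustrative usage sketch (not part of the original module): a VAEGAN on flat feature
# vectors. The dense layer lists and the x_train/x_test arrays are assumptions; the final
# generator layer has to reproduce x_dim for the shape assertion in the constructor.
def _example_vaegan_usage(x_train, x_test):
    x_dim = (784,)   # assumed flattened image size
    z_dim = 32

    def dense(units, activation=tf.nn.leaky_relu):
        return [logged_dense, {"units": units, "activation": activation, "name": "Dense{}".format(units)}]

    enc_architecture = [dense(256), dense(128)]
    gen_architecture = [dense(128), dense(256),
                        [logged_dense, {"units": x_dim[0], "activation": tf.nn.sigmoid, "name": "Output"}]]
    disc_architecture = [dense(256), dense(128)]

    model = VAEGAN(x_dim=x_dim, z_dim=z_dim, enc_architecture=enc_architecture,
                   gen_architecture=gen_architecture, disc_architecture=disc_architecture)
    model.compile(learning_rate=0.0001, optimizer=tf.train.AdamOptimizer, label_smoothing=0.9)
    model.train(x_train, x_test, epochs=10, batch_size=64, disc_steps=1, gen_steps=1, log_step=1)
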
class CGAN(ConditionalGenerativeModel):
    def __init__(self, x_dim, y_dim, z_dim, gen_architecture, adversarial_architecture,
                 folder="./CGAN", append_y_at_every_layer=None, is_patchgan=False,
                 is_wasserstein=False, aux_architecture=None):
        architectures = [gen_architecture, adversarial_architecture]
        self._is_cycle_consistent = False
        if aux_architecture is not None:
            architectures.append(aux_architecture)
            self._is_cycle_consistent = True
        super(CGAN, self).__init__(x_dim=x_dim, y_dim=y_dim, z_dim=z_dim,
                                   architectures=architectures, folder=folder,
                                   append_y_at_every_layer=append_y_at_every_layer)
        self._gen_architecture = self._architectures[0]
        self._adversarial_architecture = self._architectures[1]

        self._is_patchgan = is_patchgan
        self._is_wasserstein = is_wasserstein
        self._is_feature_matching = False

        ################# Define architecture
        if self._is_patchgan:
            f_xy = self._adversarial_architecture[-1][-1]["filters"]
            assert f_xy == 1, (
                "If is PatchGAN, last layer of adversarial_XY needs 1 filter. Given: {}.".format(f_xy))
            a_xy = self._adversarial_architecture[-1][-1]["activation"]
            if self._is_wasserstein:
                assert a_xy == tf.identity, (
                    "If is PatchGAN, last layer of adversarial needs tf.identity. Given: {}.".format(a_xy))
            else:
                assert a_xy == tf.nn.sigmoid, (
                    "If is PatchGAN, last layer of adversarial needs tf.nn.sigmoid. Given: {}.".format(a_xy))
        else:
            self._adversarial_architecture.append([tf.layers.flatten, {"name": "Flatten"}])
            if self._is_wasserstein:
                self._adversarial_architecture.append(
                    [logged_dense, {"units": 1, "activation": tf.identity, "name": "Output"}])
            else:
                self._adversarial_architecture.append(
                    [logged_dense, {"units": 1, "activation": tf.nn.sigmoid, "name": "Output"}])

        self._gen_architecture[-1][1]["name"] = "Output"
        self._generator = ConditionalGenerator(self._gen_architecture, name="Generator")
        self._adversarial = Critic(self._adversarial_architecture, name="Adversarial")
        self._nets = [self._generator, self._adversarial]

        ################# Connect inputs and networks
        self._output_gen = self._generator.generate_net(
            self._mod_Z_input,
            append_elements_at_every_layer=self._append_at_every_layer,
            tf_trainflag=self._is_training)

        with tf.name_scope("InputsAdversarial"):
            if len(self._x_dim) == 1:
                self._input_real = tf.concat(axis=1, values=[self._X_input, self._Y_input], name="real")
                self._input_fake = tf.concat(axis=1, values=[self._output_gen, self._Y_input], name="fake")
            else:
                self._input_real = image_condition_concat(
                    inputs=self._X_input, condition=self._Y_input, name="real")
                self._input_fake = image_condition_concat(
                    inputs=self._output_gen, condition=self._Y_input, name="fake")

        self._output_adversarial_real = self._adversarial.generate_net(
            self._input_real, tf_trainflag=self._is_training)
        self._output_adversarial_fake = self._adversarial.generate_net(
            self._input_fake, tf_trainflag=self._is_training)

        assert self._output_gen.get_shape()[1:] == x_dim, (
            "Output of generator is {}, but x_dim is {}.".format(self._output_gen.get_shape(), x_dim))

        ################# Auxiliary network for cycle consistency
        if self._is_cycle_consistent:
            self._auxiliary = Encoder(self._architectures[2], name="Auxiliary")
            self._output_auxiliary = self._auxiliary.generate_net(
                self._output_gen, tf_trainflag=self._is_training)
            assert self._output_auxiliary.get_shape().as_list() == self._mod_Z_input.get_shape().as_list(), (
                "Wrong shape for auxiliary vs. mod Z: {} vs {}.".format(
                    self._output_auxiliary.get_shape(), self._mod_Z_input.get_shape()))
            self._nets.append(self._auxiliary)

        ################# Finalize
        self._init_folders()
        self._verify_init()

        if self._is_patchgan:
            print("PATCHGAN chosen with output: {}.".format(self._output_adversarial_real.shape))

    def compile(self, loss, logged_images=None, logged_labels=None, learning_rate=0.0005,
                learning_rate_gen=None, learning_rate_adversarial=None,
                optimizer=tf.train.RMSPropOptimizer, feature_matching=False, label_smoothing=1):
        if self._is_wasserstein and loss != "wasserstein":
            raise ValueError("If is_wasserstein is true in Constructor, loss needs to be wasserstein.")
        if not self._is_wasserstein and loss == "wasserstein":
            raise ValueError("If loss is wasserstein, is_wasserstein needs to be true in constructor.")
        if learning_rate_gen is None:
            learning_rate_gen = learning_rate
        if learning_rate_adversarial is None:
            learning_rate_adversarial = learning_rate

        self._define_loss(loss, feature_matching, label_smoothing)
        with tf.name_scope("Optimizer"):
            gen_optimizer = optimizer(learning_rate=learning_rate_gen)
            self._gen_optimizer = gen_optimizer.minimize(
                self._gen_loss, var_list=self._get_vars("Generator"), name="Generator")
            adversarial_optimizer = optimizer(learning_rate=learning_rate_adversarial)
            self._adversarial_optimizer = adversarial_optimizer.minimize(
                self._adversarial_loss, var_list=self._get_vars("Adversarial"), name="Adversarial")
            if self._is_cycle_consistent:
                aux_optimizer = optimizer(learning_rate=learning_rate_gen)
                self._aux_optimizer = aux_optimizer.minimize(
                    self._aux_loss,
                    var_list=self._get_vars(scope="Generator") + self._get_vars(scope="Auxiliary"),
                    name="Auxiliary")
        self._gen_grads_and_vars = gen_optimizer.compute_gradients(self._gen_loss)
        self._adversarial_grads_and_vars = adversarial_optimizer.compute_gradients(self._adversarial_loss)
        self._summarise(logged_images=logged_images, logged_labels=logged_labels)

    def _define_loss(self, loss, feature_matching, label_smoothing):
        possible_losses = ["cross-entropy", "L1", "L2", "wasserstein", "KL"]

        def get_labels_one(tensor):
            return tf.ones_like(tensor) * label_smoothing

        eps = 1e-7
        if loss == "cross-entropy":
            self._logits_real = tf.math.log(
                self._output_adversarial_real / (1 + eps - self._output_adversarial_real) + eps)
            self._logits_fake = tf.math.log(
                self._output_adversarial_fake / (1 + eps - self._output_adversarial_fake) + eps)
            self._gen_loss = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(
                    labels=get_labels_one(self._logits_fake), logits=self._logits_fake))
            self._adversarial_loss = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(
                    labels=get_labels_one(self._logits_real), logits=self._logits_real)
                + tf.nn.sigmoid_cross_entropy_with_logits(
                    labels=tf.zeros_like(self._logits_fake), logits=self._logits_fake))
        elif loss == "L1":
            self._gen_loss = tf.reduce_mean(
                tf.abs(self._output_adversarial_fake - get_labels_one(self._output_adversarial_fake)))
            self._adversarial_loss = (tf.reduce_mean(
                tf.abs(self._output_adversarial_real - get_labels_one(self._output_adversarial_real))
                + tf.abs(self._output_adversarial_fake))) / 2.0
        elif loss == "L2":
            self._gen_loss = tf.reduce_mean(
                tf.square(self._output_adversarial_fake - get_labels_one(self._output_adversarial_fake)))
            self._adversarial_loss = (tf.reduce_mean(
                tf.square(self._output_adversarial_real - get_labels_one(self._output_adversarial_real))
                + tf.square(self._output_adversarial_fake))) / 2.0
        elif loss == "wasserstein":
            self._gen_loss = -tf.reduce_mean(self._output_adversarial_fake)
            self._adversarial_loss = (
                -(tf.reduce_mean(self._output_adversarial_real)
                  - tf.reduce_mean(self._output_adversarial_fake))
                + 10 * self._define_gradient_penalty())
        elif loss == "KL":
            self._logits_real = tf.math.log(
                self._output_adversarial_real / (1 + eps - self._output_adversarial_real) + eps)
            self._logits_fake = tf.math.log(
                self._output_adversarial_fake / (1 + eps - self._output_adversarial_fake) + eps)
            self._gen_loss = -tf.reduce_mean(self._logits_fake)
            self._adversarial_loss = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(
                    labels=get_labels_one(self._logits_real), logits=self._logits_real)
                + tf.nn.sigmoid_cross_entropy_with_logits(
                    labels=tf.zeros_like(self._logits_fake), logits=self._logits_fake))
        else:
            raise ValueError("Loss not implemented. Choose from {}. Given: {}.".format(possible_losses, loss))

        if feature_matching:
            self._is_feature_matching = True
            otp_adv_real = self._adversarial.generate_net(
                self._input_real, tf_trainflag=self._is_training, return_idx=-2)
            otp_adv_fake = self._adversarial.generate_net(
                self._input_fake, tf_trainflag=self._is_training, return_idx=-2)
            self._gen_loss = tf.reduce_mean(tf.square(otp_adv_real - otp_adv_fake))

        if self._is_cycle_consistent:
            self._aux_loss = tf.reduce_mean(tf.abs(self._mod_Z_input - self._output_auxiliary))
            self._gen_loss += self._aux_loss

        with tf.name_scope("Loss") as scope:
            tf.summary.scalar("Generator_Loss", self._gen_loss)
            tf.summary.scalar("Adversarial_Loss", self._adversarial_loss)
            if self._is_cycle_consistent:
                tf.summary.scalar("Auxiliary_Loss", self._aux_loss)

    def _define_gradient_penalty(self):
        # Gradient penalty on points interpolated between real and generated inputs (WGAN-GP style).
        alpha = tf.random_uniform(shape=tf.shape(self._input_real), minval=0., maxval=1.)
        differences = self._input_fake - self._input_real
        interpolates = self._input_real + (alpha * differences)
        gradients = tf.gradients(self._adversarial.generate_net(interpolates), [interpolates])[0]
        slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients)))
        with tf.name_scope("Loss") as scope:
            self._gradient_penalty = tf.reduce_mean((slopes - 1.)**2)
            tf.summary.scalar("Gradient_penalty", self._gradient_penalty)
        return self._gradient_penalty

    def train(self, x_train, y_train, x_test=None, y_test=None, epochs=100, batch_size=64,
              gen_steps=1, adversarial_steps=5, log_step=3, batch_log_step=None, steps=None,
              gpu_options=None):
        if steps is not None:
            gen_steps = 1
            adversarial_steps = steps
        self._set_up_training(log_step=log_step, gpu_options=gpu_options)
        self._set_up_test_train_sample(x_train, y_train, x_test, y_test)
        self._log_results(epoch=0, epoch_time=0)
        nr_batches = np.floor(len(x_train) / batch_size)
        self._dominating_adversarial = 0
        self._gen_out_zero = 0
        for epoch in range(epochs):
            batch_nr = 0
            adversarial_loss_epoch = 0
            gen_loss_epoch = 0
            aux_loss_epoch = 0
            start = time.clock()
            trained_examples = 0
            ii = 0
            while trained_examples < len(x_train):
                adversarial_loss_batch, gen_loss_batch, aux_loss_batch = self._optimize(
                    self._trainset, batch_size, adversarial_steps, gen_steps)
                trained_examples += batch_size

                if np.isnan(adversarial_loss_batch) or np.isnan(gen_loss_batch):
                    print("adversarialLoss / GenLoss: ", adversarial_loss_batch, gen_loss_batch)
                    oar, oaf = self._sess.run(
                        [self._output_adversarial_real, self._output_adversarial_fake],
                        feed_dict={
                            self._X_input: self.current_batch_x,
                            self._Y_input: self.current_batch_y,
                            self._Z_input: self._Z_noise,
                            self._is_training: True
                        })
                    print(oar)
                    print(oaf)
                    print(np.max(oar))
                    print(np.max(oaf))
                    # self._check_tf_variables(ii, nr_batches)
                    raise GeneratorExit("Nan found.")

                if (batch_log_step is not None) and (ii % batch_log_step == 0):
                    batch_train_time = (time.clock() - start) / 60
                    self._log(int(epoch * nr_batches + ii), batch_train_time)

                adversarial_loss_epoch += adversarial_loss_batch
                gen_loss_epoch += gen_loss_batch
                aux_loss_epoch += aux_loss_batch
                ii += 1

            epoch_train_time = (time.clock() - start) / 60
            adversarial_loss_epoch = np.round(adversarial_loss_epoch, 2)
            gen_loss_epoch = np.round(gen_loss_epoch, 2)
            print("Epoch {}: Adversarial: {}.".format(epoch + 1, adversarial_loss_epoch))
            print("\t\t\tGenerator: {}.".format(gen_loss_epoch))
            print("\t\t\tEncoder: {}.".format(aux_loss_epoch))

            if self._log_step is not None:
                self._log(epoch + 1, epoch_train_time)
            # self._check_tf_variables(epoch, epochs)

    def _optimize(self, dataset, batch_size, adversarial_steps, gen_steps):
        for i in range(adversarial_steps):
            current_batch_x, current_batch_y = dataset.get_next_batch(batch_size)
            # Stored for the NaN debugging branch in train().
            self.current_batch_x, self.current_batch_y = current_batch_x, current_batch_y
            self._Z_noise = self.sample_noise(n=len(current_batch_x))
            _, adversarial_loss_batch = self._sess.run(
                [self._adversarial_optimizer, self._adversarial_loss],
                feed_dict={
                    self._X_input: current_batch_x,
                    self._Y_input: current_batch_y,
                    self._Z_input: self._Z_noise,
                    self._is_training: True
                })

        aux_loss_batch = 0
        for _ in range(gen_steps):
            Z_noise = self._generator.sample_noise(n=len(current_batch_x))
            if not self._is_feature_matching:
                _, gen_loss_batch = self._sess.run(
                    [self._gen_optimizer, self._gen_loss],
                    feed_dict={
                        self._Z_input: Z_noise,
                        self._Y_input: current_batch_y,
                        self._is_training: True
                    })
            else:
                _, gen_loss_batch = self._sess.run(
                    [self._gen_optimizer, self._gen_loss],
                    feed_dict={
                        self._X_input: current_batch_x,
                        self._Y_input: current_batch_y,
                        self._Z_input: self._Z_noise,
                        self._is_training: True
                    })
            if self._is_cycle_consistent:
                _, aux_loss_batch = self._sess.run(
                    [self._aux_optimizer, self._aux_loss],
                    feed_dict={
                        self._Z_input: Z_noise,
                        self._Y_input: current_batch_y,
                        self._is_training: True
                    })
        return adversarial_loss_batch, gen_loss_batch, aux_loss_batch

    def predict(self, inpt_x, inpt_y):
        inpt = self._sess.run(self._input_real,
                              feed_dict={
                                  self._X_input: inpt_x,
                                  self._Y_input: inpt_y,
                                  self._is_training: True
                              })
        return self._adversarial.predict(inpt, self._sess)

    def _check_tf_variables(self, batch_nr, nr_batches):
        Z_noise = self._generator.sample_noise(n=len(self._x_test))
        gen_grads = [
            self._sess.run(gen_gv[0],
                           feed_dict={
                               self._X_input: self._x_test,
                               self._Y_input: self._y_test,
                               self._Z_input: Z_noise,
                               self._is_training: False
                           }) for gen_gv in self._gen_grads_and_vars
        ]
        adversarial_grads = [
            self._sess.run(adversarial_gv[0],
                           feed_dict={
                               self._X_input: self._x_test,
                               self._Y_input: self._y_test,
                               self._Z_input: Z_noise,
                               self._is_training: False
                           }) for adversarial_gv in self._adversarial_grads_and_vars
        ]
        gen_grads_maxis = [np.max(gv) for gv in gen_grads]
        gen_grads_means = [np.mean(gv) for gv in gen_grads]
        gen_grads_minis = [np.min(gv) for gv in gen_grads]
        adversarial_grads_maxis = [np.max(dv) for dv in adversarial_grads]
        adversarial_grads_means = [np.mean(dv) for dv in adversarial_grads]
        adversarial_grads_minis = [np.min(dv) for dv in adversarial_grads]

        real_logits, fake_logits, gen_out = self._sess.run(
            [self._output_adversarial_real, self._output_adversarial_fake, self._output_gen],
            feed_dict={
                self._X_input: self._x_test,
                self._Y_input: self._y_test,
                self._Z_input: Z_noise,
                self._is_training: False
            })
        real_logits = np.mean(real_logits)
        fake_logits = np.mean(fake_logits)

        gen_varsis = np.array([x.eval(session=self._sess) for x in self._generator.get_network_params()])
        adversarial_varsis = np.array([x.eval(session=self._sess) for x in self._adversarial.get_network_params()])
        gen_maxis = np.array([np.max(x) for x in gen_varsis])
        adversarial_maxis = np.array([np.max(x) for x in adversarial_varsis])
        gen_means = np.array([np.mean(x) for x in gen_varsis])
        adversarial_means = np.array([np.mean(x) for x in adversarial_varsis])
        gen_minis = np.array([np.min(x) for x in gen_varsis])
        adversarial_minis = np.array([np.min(x) for x in adversarial_varsis])

        print(batch_nr, "/", nr_batches, ":")
        print("adversarialReal / adversarialFake: ", real_logits, fake_logits)
        print("GenWeight Max / Mean / Min: ", np.max(gen_maxis), np.mean(gen_means), np.min(gen_minis))
        print("GenGrads Max / Mean / Min: ",
              np.max(gen_grads_maxis), np.mean(gen_grads_means), np.min(gen_grads_minis))
        print("adversarialWeight Max / Mean / Min: ",
              np.max(adversarial_maxis), np.mean(adversarial_means), np.min(adversarial_minis))
        print("adversarialGrads Max / Mean / Min: ",
              np.max(adversarial_grads_maxis), np.mean(adversarial_grads_means), np.min(adversarial_grads_minis))
        print("GenOut Max / Mean / Min: ", np.max(gen_out), np.mean(gen_out), np.min(gen_out))
        print("\n")

        if real_logits > 0.99 and fake_logits < 0.01:
            self._dominating_adversarial += 1
            if self._dominating_adversarial == 5:
                raise GeneratorExit("Dominating adversarial network!")
        else:
            self._dominating_adversarial = 0

        print(np.max(gen_out))
        print(np.max(gen_out) < 0.05)
        if np.max(gen_out) < 0.05:
            self._gen_out_zero += 1
            print(self._gen_out_zero)
            if self._gen_out_zero == 50:
                raise GeneratorExit("Generator outputs zeros")
        else:
            self._gen_out_zero = 0
            print(self._gen_out_zero)
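
# Illustrative usage sketch (not part of the original module): a CGAN generating flat
# vectors conditioned on one-hot labels. Layer contents, dimensions and the training
# arrays are assumptions; with a one-dimensional x_dim the condition is concatenated
# directly onto the input vectors.
def _example_cgan_usage(x_train, y_train, x_test, y_test):
    x_dim = (784,)   # assumed flattened image size
    y_dim = 10       # assumed number of one-hot label classes
    z_dim = 64

    def dense(units, activation=tf.nn.leaky_relu):
        return [logged_dense, {"units": units, "activation": activation, "name": "Dense{}".format(units)}]

    gen_architecture = [dense(128), dense(256),
                        [logged_dense, {"units": x_dim[0], "activation": tf.nn.sigmoid, "name": "Output"}]]
    adversarial_architecture = [dense(256), dense(128)]

    model = CGAN(x_dim=x_dim, y_dim=y_dim, z_dim=z_dim,
                 gen_architecture=gen_architecture,
                 adversarial_architecture=adversarial_architecture)
    model.compile(loss="cross-entropy", learning_rate=0.0005,
                  optimizer=tf.train.RMSPropOptimizer, label_smoothing=0.9)
    model.train(x_train, y_train, x_test, y_test, epochs=10, batch_size=64,
                gen_steps=1, adversarial_steps=1, log_step=1)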