Example #1
class CVAEGAN(Image2ImageGenerativeModel):
    def __init__(self,
                 x_dim,
                 y_dim,
                 z_dim,
                 enc_architecture,
                 gen_architecture,
                 adversarial_architecture,
                 folder="./CVAEGAN",
                 is_patchgan=False,
                 is_wasserstein=False):
        super(CVAEGAN, self).__init__(
            x_dim, y_dim,
            [enc_architecture, gen_architecture, adversarial_architecture],
            folder)

        self._z_dim = z_dim
        with tf.name_scope("Inputs"):
            self._Z_input = tf.placeholder(tf.float32,
                                           shape=[None, z_dim],
                                           name="z")

        self._enc_architecture = self._architectures[0]
        self._gen_architecture = self._architectures[1]
        self._adv_architecture = self._architectures[2]

        self._is_patchgan = is_patchgan
        self._is_wasserstein = is_wasserstein
        self._is_feature_matching = False

        ################# Define architecture
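        # PatchGAN (pix2pix-style): the adversary keeps a convolutional head
        # with a single filter, scoring each local patch as real/fake instead
        # of emitting one scalar per image; otherwise a flatten + dense head
        # is appended below.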
        if self._is_patchgan:
            f_xy = self._adv_architecture[-1][-1]["filters"]
            assert f_xy == 1, (
                "For PatchGAN, the last adversarial layer needs exactly 1 "
                "filter. Given: {}.".format(f_xy))

            a_xy = self._adv_architecture[-1][-1]["activation"]
            if self._is_wasserstein:
                assert a_xy == tf.identity, (
                    "For PatchGAN with Wasserstein loss, the last activation "
                    "must be tf.identity. Given: {}.".format(a_xy))
            else:
                assert a_xy == tf.nn.sigmoid, (
                    "For PatchGAN, the last activation must be tf.nn.sigmoid. "
                    "Given: {}.".format(a_xy))
        else:
            self._adv_architecture.append(
                [tf.layers.flatten, {
                    "name": "Flatten"
                }])
            if self._is_wasserstein:
                self._adv_architecture.append([
                    logged_dense, {
                        "units": 1,
                        "activation": tf.identity,
                        "name": "Output"
                    }
                ])
            else:
                self._adv_architecture.append([
                    logged_dense, {
                        "units": 1,
                        "activation": tf.nn.sigmoid,
                        "name": "Output"
                    }
                ])

        last_layers_mean = [
            [tf.layers.flatten, {"name": "flatten"}],
            [logged_dense, {"units": z_dim, "activation": tf.identity,
                            "name": "Mean"}],
        ]
        self._encoder_mean = Encoder(self._enc_architecture + last_layers_mean,
                                     name="Encoder")
        last_layers_std = [
            [tf.layers.flatten, {"name": "flatten"}],
            [logged_dense, {"units": z_dim, "activation": tf.identity,
                            "name": "Std"}],
        ]
        self._encoder_std = Encoder(self._enc_architecture + last_layers_std,
                                    name="Encoder")

        self._gen_architecture[-1][1]["name"] = "Output"
        self._generator = Generator(self._gen_architecture, name="Generator")

        self._adversarial = Discriminator(self._adv_architecture,
                                          name="Adversarial")

        self._nets = [self._encoder_mean, self._generator, self._adversarial]

        ################# Connect inputs and networks
        self._mean_layer = self._encoder_mean.generate_net(self._Y_input)
        self._std_layer = self._encoder_std.generate_net(self._Y_input)

        self._output_enc_with_noise = self._mean_layer + tf.exp(
            0.5 * self._std_layer) * self._Z_input
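        # The line above is the VAE reparameterization trick: with _std_layer
        # read as log-variance, z = mu + exp(0.5 * log_var) * eps remains
        # differentiable with respect to the encoder parameters.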
        with tf.name_scope("Inputs"):
            self._gen_input = image_condition_concat(
                inputs=self._X_input,
                condition=self._output_enc_with_noise,
                name="mod_z_real")
            self._gen_input_from_encoding = image_condition_concat(
                inputs=self._X_input, condition=self._Z_input, name="mod_z")
        self._output_gen = self._generator.generate_net(self._gen_input)
        self._output_gen_from_encoding = self._generator.generate_net(
            self._gen_input_from_encoding)
        self._generator._input_dim = z_dim

        assert self._output_gen.get_shape()[1:] == y_dim, (
            "Generator output must have the shape of y_dim. Given: {}. Expected: {}."
            .format(self._output_gen.get_shape(), y_dim))

        with tf.name_scope("InputsAdversarial"):
            self._input_real = tf.concat(values=[self._Y_input, self._X_input],
                                         axis=3)
            self._input_fake_from_real = tf.concat(
                values=[self._output_gen, self._X_input], axis=3)
            self._input_fake_from_latent = tf.concat(
                values=[self._output_gen_from_encoding, self._X_input], axis=3)

        self._output_adv_real = self._adversarial.generate_net(
            self._input_real)
        self._output_adv_fake_from_real = self._adversarial.generate_net(
            self._input_fake_from_real)
        self._output_adv_fake_from_latent = self._adversarial.generate_net(
            self._input_fake_from_latent)

        ################# Finalize
        self._init_folders()
        self._verify_init()

        self._output_label_real = tf.placeholder(
            tf.float32, shape=self._output_adv_real.shape, name="label_real")
        self._output_label_fake = tf.placeholder(
            tf.float32,
            shape=self._output_adv_fake_from_real.shape,
            name="label_fake")

        if self._is_patchgan:
            print("PATCHGAN chosen with output: {}.".format(
                self._output_adv_real.shape))

    def compile(self,
                loss,
                optimizer,
                learning_rate=None,
                learning_rate_enc=None,
                learning_rate_gen=None,
                learning_rate_adv=None,
                label_smoothing=1,
                lmbda_kl=0.1,
                lmbda_y=1,
                feature_matching=False,
                random_labeling=0):

        if self._is_wasserstein and loss != "wasserstein":
            raise ValueError(
                "If is_wasserstein is true in Constructor, loss needs to be wasserstein."
            )
        if not self._is_wasserstein and loss == "wasserstein":
            raise ValueError(
                "If loss is wasserstein, is_wasserstein needs to be true in constructor."
            )

        if np.all([
                lr is None for lr in [
                    learning_rate, learning_rate_enc, learning_rate_gen,
                    learning_rate_adv
                ]
        ]):
            raise ValueError("Need learning_rate.")
        if learning_rate is not None and learning_rate_enc is None:
            learning_rate_enc = learning_rate
        if learning_rate is not None and learning_rate_gen is None:
            learning_rate_gen = learning_rate
        if learning_rate is not None and learning_rate_adv is None:
            learning_rate_adv = learning_rate

        self._define_loss(loss=loss,
                          label_smoothing=label_smoothing,
                          lmbda_kl=lmbda_kl,
                          lmbda_y=lmbda_y,
                          feature_matching=feature_matching,
                          random_labeling=random_labeling)
        with tf.name_scope("Optimizer"):
            self._enc_optimizer = optimizer(learning_rate=learning_rate_enc)
            self._enc_optimizer_op = self._enc_optimizer.minimize(
                self._enc_loss,
                var_list=self._get_vars("Encoder"),
                name="Encoder")
            self._gen_optimizer = optimizer(learning_rate=learning_rate_gen)
            self._gen_optimizer_op = self._gen_optimizer.minimize(
                self._gen_loss,
                var_list=self._get_vars("Generator"),
                name="Generator")
            self._adv_optimizer = optimizer(learning_rate=learning_rate_adv)
            self._adv_optimizer_op = self._adv_optimizer.minimize(
                self._adv_loss,
                var_list=self._get_vars("Adversarial"),
                name="Adversarial")
        self._summarise()

    def _define_loss(self, loss, label_smoothing, lmbda_kl, lmbda_y,
                     feature_matching, random_labeling):
        possible_losses = ["cross-entropy", "L2", "wasserstein", "KL"]
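        # One-sided label smoothing: real labels are scaled by label_smoothing
        # (e.g. 0.9) while fake labels stay at 0; get_random_label additionally
        # flips a random fraction of labels when random_labeling > 0.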

        def get_labels_one():
            return tf.math.multiply(self._output_label_real, label_smoothing)

        def get_labels_zero():
            return self._output_label_fake

        eps = 1e-6
        self._label_smoothing = label_smoothing
        self._random_labeling = random_labeling
        ## Kullback-Leibler divergence
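        ## Closed form of KL(N(mu, sigma^2) || N(0, 1)) with _std_layer = log(sigma^2):
        ## 0.5 * (mu^2 + exp(log_var) - log_var - 1), averaged over batch and latent dims.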
        self._KLdiv = 0.5 * (tf.square(self._mean_layer) +
                             tf.exp(self._std_layer) - self._std_layer - 1)
        self._KLdiv = lmbda_kl * tf.reduce_mean(self._KLdiv)

        ## L1 loss
        self._recon_loss = lmbda_y * tf.reduce_mean(
            tf.abs(self._Y_input - self._output_gen))

        ## Adversarial loss
        if loss == "cross-entropy":
            self._logits_real = tf.math.log(self._output_adv_real /
                                            (1 + eps - self._output_adv_real) +
                                            eps)
            self._logits_fake_from_real = tf.math.log(
                self._output_adv_fake_from_real /
                (1 + eps - self._output_adv_fake_from_real) + eps)
            self._logits_fake_from_latent = tf.math.log(
                self._output_adv_fake_from_latent /
                (1 + eps - self._output_adv_fake_from_latent) + eps)
            self._generator_loss = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(
                    labels=tf.ones_like(self._logits_fake_from_real),
                    logits=self._logits_fake_from_real) +
                tf.nn.sigmoid_cross_entropy_with_logits(
                    labels=tf.ones_like(self._logits_fake_from_latent),
                    logits=self._logits_fake_from_latent))
            self._adversarial_loss = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(
                    labels=get_labels_one(), logits=self._logits_real) +
                tf.nn.sigmoid_cross_entropy_with_logits(
                    labels=get_labels_zero(),
                    logits=self._logits_fake_from_real) +
                tf.nn.sigmoid_cross_entropy_with_logits(
                    labels=get_labels_zero(),
                    logits=self._logits_fake_from_latent))
        elif loss == "L2":
            self._generator_loss = tf.reduce_mean(
                tf.square(self._output_adv_fake_from_real -
                          tf.ones_like(self._output_adv_fake_from_real)) +
                tf.square(self._output_adv_fake_from_latent -
                          tf.ones_like(self._output_adv_fake_from_latent))) / 2
            self._adversarial_loss = (tf.reduce_mean(
                tf.square(self._output_adv_real - get_labels_one()) +
                tf.square(self._output_adv_fake_from_real -
                          get_labels_zero()) +
                tf.square(self._output_adv_fake_from_latent -
                          get_labels_zero()))) / 3.0
        elif loss == "wasserstein":
            self._generator_loss = -tf.reduce_mean(
                self._output_adv_fake_from_real) - tf.reduce_mean(
                    self._output_adv_fake_from_latent)
            self._adversarial_loss = (
                -(tf.reduce_mean(self._output_adv_real) -
                  tf.reduce_mean(self._output_adv_fake_from_real) -
                  tf.reduce_mean(self._output_adv_fake_from_latent)) +
                10 * self._define_gradient_penalty())
        elif loss == "KL":
            self._logits_real = tf.math.log(self._output_adv_real /
                                            (1 + eps - self._output_adv_real) +
                                            eps)
            self._logits_fake_from_real = tf.math.log(
                self._output_adv_fake_from_real /
                (1 + eps - self._output_adv_fake_from_real) + eps)
            self._logits_fake_from_latent = tf.math.log(
                self._output_adv_fake_from_latent /
                (1 + eps - self._output_adv_fake_from_latent) + eps)
            self._generator_loss = (
                -tf.reduce_mean(self._logits_fake_from_real) -
                tf.reduce_mean(self._logits_fake_from_latent)) / 2
            self._adversarial_loss = tf.reduce_mean(
                0.5 * tf.nn.sigmoid_cross_entropy_with_logits(
                    labels=get_labels_one(), logits=self._logits_real) +
                0.25 * tf.nn.sigmoid_cross_entropy_with_logits(
                    labels=get_labels_zero(),
                    logits=self._logits_fake_from_real) +
                0.25 * tf.nn.sigmoid_cross_entropy_with_logits(
                    labels=get_labels_zero(),
                    logits=self._logits_fake_from_latent))
        else:
            raise ValueError(
                "Loss not implemented. Choose from {}. Given: {}.".format(
                    possible_losses, loss))

        if feature_matching:
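            # Feature matching (Salimans et al., 2016): the generator loss is
            # replaced by an L2 distance between the adversary's penultimate
            # (return_idx=-2) activations on real and generated pairs.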
            self._is_feature_matching = True
            otp_adv_real = self._adversarial.generate_net(
                self._input_real,
                tf_trainflag=self._is_training,
                return_idx=-2)
            otp_adv_fake = self._adversarial.generate_net(
                self._input_fake_from_real,
                tf_trainflag=self._is_training,
                return_idx=-2)
            self._generator_loss = tf.reduce_mean(
                tf.square(otp_adv_real - otp_adv_fake))

        with tf.name_scope("Loss") as scope:

            self._enc_loss = self._KLdiv + self._recon_loss + self._generator_loss
            self._gen_loss = self._recon_loss + self._generator_loss
            self._adv_loss = self._adversarial_loss

            tf.summary.scalar("Kullback-Leibler", self._KLdiv)
            tf.summary.scalar("Reconstruction", self._recon_loss)
            tf.summary.scalar("Vanilla_Generator", self._generator_loss)

            tf.summary.scalar("Encoder", self._enc_loss)
            tf.summary.scalar("Generator", self._gen_loss)
            tf.summary.scalar("Adversarial", self._adv_loss)

    def _define_gradient_penalty(self):
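        # WGAN-GP (Gulrajani et al., 2017): sample random points on straight
        # lines between real and generated inputs and penalize deviations of
        # the critic's gradient norm from 1 at those points.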
        alpha = tf.random_uniform(shape=tf.shape(self._input_real),
                                  minval=0.,
                                  maxval=1.)
        differences = self._input_fake_from_real - self._input_real
        interpolates = self._input_real + (alpha * differences)
        gradients = tf.gradients(self._adversarial.generate_net(interpolates),
                                 [interpolates])[0]
        # Gradient norm per example: reduce over every axis except the batch
        # axis.
        slopes = tf.sqrt(
            tf.reduce_sum(tf.square(gradients),
                          axis=list(range(1, gradients.shape.ndims))))
        with tf.name_scope("Loss") as scope:
            self._gradient_penalty = tf.reduce_mean((slopes - 1.)**2)
            tf.summary.scalar("Gradient_penalty", self._gradient_penalty)
        return self._gradient_penalty

    def train(self,
              x_train,
              y_train,
              x_test,
              y_test,
              epochs=100,
              batch_size=64,
              adv_steps=5,
              gen_steps=1,
              log_step=3,
              gpu_options=None,
              batch_log_step=None):
        self._set_up_training(log_step=log_step, gpu_options=gpu_options)
        self._set_up_test_train_sample(x_train=x_train,
                                       y_train=y_train,
                                       x_test=x_test,
                                       y_test=y_test)
        self._z_test = self._generator.sample_noise(n=len(self._x_test))
        nr_batches = np.floor(len(x_train) / batch_size)
        self.batch_size = batch_size
        self._prepare_monitoring()
        self._log_results(epoch=0, epoch_time=0)

        for epoch in range(epochs):
            adv_loss_epoch = 0
            gen_loss_epoch = 0
            enc_loss_epoch = 0
            start = time.perf_counter()
            trained_examples = 0
            batch_nr = 0

            while trained_examples < len(x_train):
                batch_train_start = time.perf_counter()
                adv_loss_batch, gen_loss_batch, enc_loss_batch = self._optimize(
                    self._trainset, adv_steps, gen_steps)
                trained_examples += self.batch_size
                adv_loss_epoch += adv_loss_batch
                gen_loss_epoch += gen_loss_batch
                enc_loss_epoch += enc_loss_batch
                self._total_train_time += (time.perf_counter() -
                                           batch_train_start)

                if (batch_log_step is not None) and (batch_nr % batch_log_step
                                                     == 0):
                    self._count_batches += batch_log_step
                    batch_train_time = (time.perf_counter() - start) / 60
                    self._log(self._count_batches, batch_train_time)
                batch_nr += 1

            epoch_train_time = (time.perf_counter() - start) / 60
            adv_loss_epoch = np.round(adv_loss_epoch, 2)
            gen_loss_epoch = np.round(gen_loss_epoch, 2)
            enc_loss_epoch = np.round(enc_loss_epoch, 2)

            print("\nEpoch {}: D: {}; G: {}; E: {}.".format(
                epoch, adv_loss_epoch, gen_loss_epoch, enc_loss_epoch))

            if batch_log_step is None and (log_step
                                           is not None) and (epoch % log_step
                                                             == 0):
                self._log(epoch + 1, epoch_train_time)

    def _optimize(self, dataset, adv_steps, gen_steps):
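        # Standard GAN schedule: adv_steps adversary updates per call, then
        # gen_steps generator and encoder updates; the generator/encoder steps
        # reuse the last batch drawn for the adversary.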
        for i in range(adv_steps):
            current_batch_x, current_batch_y = dataset.get_next_batch(
                self.batch_size)
            Z_noise = self._generator.sample_noise(n=len(current_batch_x))
            _, adv_loss_batch = self._sess.run(
                [self._adv_optimizer_op, self._adv_loss],
                feed_dict={
                    self._X_input: current_batch_x,
                    self._Y_input: current_batch_y,
                    self._Z_input: Z_noise,
                    self._is_training: True,
                    self._output_label_real:
                    self.get_random_label(is_real=True),
                    self._output_label_fake:
                    self.get_random_label(is_real=False)
                })

        for i in range(gen_steps):
            Z_noise = self._generator.sample_noise(n=len(current_batch_x))
            _, gen_loss_batch = self._sess.run(
                [self._gen_optimizer_op, self._gen_loss],
                feed_dict={
                    self._X_input: current_batch_x,
                    self._Y_input: current_batch_y,
                    self._Z_input: Z_noise,
                    self._is_training: True
                })
            _, enc_loss_batch = self._sess.run(
                [self._enc_optimizer_op, self._enc_loss],
                feed_dict={
                    self._X_input: current_batch_x,
                    self._Y_input: current_batch_y,
                    self._Z_input: Z_noise,
                    self._is_training: True
                })

        return adv_loss_batch, gen_loss_batch, enc_loss_batch

    def _log_results(self, epoch, epoch_time):
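        # Summaries are evaluated twice: once on the test set (writer1) and
        # once on an equally sized slice of the training set (writer2), so
        # both curves can be compared side by side.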
        summary = self._sess.run(self._merged_summaries,
                                 feed_dict={
                                     self._X_input:
                                     self._x_test,
                                     self._Y_input:
                                     self._y_test,
                                     self._Z_input:
                                     self._z_test,
                                     self._epoch_time:
                                     epoch_time,
                                     self._is_training:
                                     False,
                                     self._epoch_nr:
                                     epoch,
                                     self._output_label_real:
                                     self.get_random_label(is_real=True,
                                                           size=self._nr_test),
                                     self._output_label_fake:
                                     self.get_random_label(is_real=False,
                                                           size=self._nr_test)
                                 })
        self._writer1.add_summary(summary, epoch)
        nr_test = len(self._x_test)
        summary = self._sess.run(self._merged_summaries,
                                 feed_dict={
                                     self._X_input:
                                     self._trainset.get_xdata()[:nr_test],
                                     self._Z_input:
                                     self._z_test,
                                     self._Y_input:
                                     self._trainset.get_ydata()[:nr_test],
                                     self._epoch_time:
                                     epoch_time,
                                     self._is_training:
                                     False,
                                     self._epoch_nr:
                                     epoch,
                                     self._output_label_real:
                                     self.get_random_label(is_real=True,
                                                           size=self._nr_test),
                                     self._output_label_fake:
                                     self.get_random_label(is_real=False,
                                                           size=self._nr_test)
                                 })
        self._writer2.add_summary(summary, epoch)
        if self._image_shape is not None:
            self.plot_samples(inpt_x=self._x_test[:10],
                              inpt_y=self._y_test[:10],
                              sess=self._sess,
                              image_shape=self._image_shape,
                              epoch=epoch,
                              path="{}/GeneratedSamples/result_{}.png".format(
                                  self._folder, epoch))
        self.save_model(epoch)
        additional_log = getattr(self, "evaluate", None)
        if callable(additional_log):
            self.evaluate(true=self._x_test,
                          condition=self._y_test,
                          epoch=epoch)
        print("Logged.")

    def plot_samples(self, inpt_x, inpt_y, sess, image_shape, epoch, path):
        outpt_xy = sess.run(self._output_gen_from_encoding,
                            feed_dict={
                                self._X_input: inpt_x,
                                self._Z_input: self._z_test[:len(inpt_x)],
                                self._is_training: False
                            })

        image_matrix = np.array([[
            x.reshape(self._x_dim[0], self._x_dim[1]),
            y.reshape(self._y_dim[0], self._y_dim[1]),
            np.zeros(shape=(self._x_dim[0], self._x_dim[1])),
            xy.reshape(self._y_dim[0], self._y_dim[1])
        ] for x, y, xy in zip(inpt_x, inpt_y, outpt_xy)])
        self._generator.build_generated_samples(
            image_matrix,
            column_titles=["True X", "True Y", "", "Gen_XY"],
            epoch=epoch,
            path=path)

    def _prepare_monitoring(self):
        self._total_train_time = 0
        self._total_log_time = 0
        self._count_batches = 0
        self._batches = []

        self._max_allowed_failed_checks = 20
        self._enc_grads_and_vars = self._enc_optimizer.compute_gradients(
            self._enc_loss, var_list=self._get_vars("Encoder"))
        self._gen_grads_and_vars = self._gen_optimizer.compute_gradients(
            self._gen_loss, var_list=self._get_vars("Generator"))
        self._adv_grads_and_vars = self._adv_optimizer.compute_gradients(
            self._adv_loss, var_list=self._get_vars("Adversarial"))

        self._monitor_dict = {
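            # Each entry: [tensors to fetch, display names, history lists,
            # optional reduction ops applied before storing].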
            "Gradients": [[
                self._enc_grads_and_vars, self._gen_grads_and_vars,
                self._adv_grads_and_vars
            ], ["Encoder", "Generator", "Adversarial"],
                          [[] for i in range(9)]],
            "Losses":
            [[
                self._enc_loss, self._gen_loss, self._adversarial_loss,
                self._generator_loss, self._recon_loss, self._KLdiv
            ],
             [
                 "Encoder (V+R+K)", "Generator (V+R)", "Adversarial",
                 "Vanilla_Generator", "Reconstruction", "Kullback-Leibler"
             ], [[] for i in range(6)]],
            "Output Adversarial": [[
                self._output_adv_fake_from_real,
                self._output_adv_fake_from_latent, self._output_adv_real
            ], ["Fake_from_real", "Fake_from_latent", "Real"],
                                   [[] for i in range(3)], [np.mean]]
        }

        self._check_dict = {
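            # Heuristic failure checks, evaluated in evaluate(): if a check
            # fires for _max_allowed_failed_checks consecutive logs, training
            # is aborted by raising GeneratorExit.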
            "Dominating Discriminator": {
                "Tensors":
                [self._output_adv_real, self._output_adv_fake_from_real],
                "OPonTensors": [np.mean, np.mean],
                "Relation": [">", "<"],
                "Threshold": [
                    self._label_smoothing * 0.95,
                    (1 - self._label_smoothing) * 1.05
                ],
                "TensorRelation":
                np.logical_and
            },
            "Generator outputs zeros": {
                "Tensors": [
                    self._output_gen_from_encoding,
                    self._output_gen_from_encoding
                ],
                "OPonTensors": [np.max, np.min],
                "Relation": ["<", ">"],
                "Threshold": [0.05, 0.95],
                "TensorRelation":
                np.logical_or
            }
        }
        self._check_count = [0 for key in self._check_dict]

        if not os.path.exists(self._folder + "/Evaluation"):
            pos.mkdir(self._folder + "/Evaluation")
        os.mkdir(self._folder + "/Evaluation/Cells")
        os.mkdir(self._folder + "/Evaluation/CenterOfMassX")
        os.mkdir(self._folder + "/Evaluation/CenterOfMassY")
        os.mkdir(self._folder + "/Evaluation/Energy")
        os.mkdir(self._folder + "/Evaluation/MaxEnergy")
        os.mkdir(self._folder + "/Evaluation/StdEnergy")

    def evaluate(self, true, condition, epoch):
        print("Batch ", epoch)
        log_start = time.perf_counter()
        self._batches.append(epoch)

        fake = self._sess.run(self._output_gen_from_encoding,
                              feed_dict={
                                  self._X_input: self._x_test,
                                  self._Z_input: self._z_test,
                                  self._is_training: False
                              })
        true = self._y_test.reshape(
            [-1, self._image_shape[0], self._image_shape[1]])
        fake = fake.reshape([-1, self._image_shape[0], self._image_shape[1]])
        build_histogram(true=true,
                        fake=fake,
                        function=get_energies,
                        name="Energy",
                        epoch=epoch,
                        folder=self._folder)
        build_histogram(true=true,
                        fake=fake,
                        function=get_number_of_activated_cells,
                        name="Cells",
                        epoch=epoch,
                        folder=self._folder,
                        threshold=6 / 6120)
        build_histogram(true=true,
                        fake=fake,
                        function=get_max_energy,
                        name="MaxEnergy",
                        epoch=epoch,
                        folder=self._folder)
        build_histogram(true=true,
                        fake=fake,
                        function=get_center_of_mass_x,
                        name="CenterOfMassX",
                        epoch=epoch,
                        folder=self._folder,
                        image_shape=self._image_shape)
        build_histogram(true=true,
                        fake=fake,
                        function=get_center_of_mass_y,
                        name="CenterOfMassY",
                        epoch=epoch,
                        folder=self._folder,
                        image_shape=self._image_shape)
        build_histogram(true=true,
                        fake=fake,
                        function=get_std_energy,
                        name="StdEnergy",
                        epoch=epoch,
                        folder=self._folder)

        fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(18, 10))
        axs = np.ravel(axs)
        if "Gradients" in self._monitor_dict:
            colors = ["green", "blue", "red"]
            axy_min = np.inf
            axy_max = -np.inf
            for go, gradient_ops in enumerate(
                    self._monitor_dict["Gradients"][0]):
                grads = [
                    self._sess.run(
                        gv[0],
                        feed_dict={
                            self._X_input:
                            self._x_test,
                            self._Y_input:
                            self._y_test,
                            self._Z_input:
                            self._z_test,
                            self._is_training:
                            False,
                            self._output_label_real:
                            self.get_random_label(is_real=True,
                                                  size=self._nr_test),
                            self._output_label_fake:
                            self.get_random_label(is_real=False,
                                                  size=self._nr_test)
                        }) for gv in gradient_ops
                ]

                for op_idx, op in enumerate([np.mean, np.max, np.min]):
                    self._monitor_dict["Gradients"][2][go * 3 + op_idx].append(
                        op([op(grad) for grad in grads]))
                    vals = self._monitor_dict["Gradients"][2][go * 3 + op_idx]
                    if op_idx == 0:
                        axs[0].plot(
                            self._batches,
                            vals,
                            label=self._monitor_dict["Gradients"][1][go],
                            color=colors[go])
                    else:
                        axs[0].plot(self._batches,
                                    vals,
                                    linewidth=0.5,
                                    linestyle="--",
                                    color=colors[go])
                        # Axis limits track the mean of the running max/min
                        # curves, which keeps single gradient spikes from
                        # blowing up the y-range.
                        upper = np.mean(vals)
                        lower = np.mean(vals)
                        if upper > axy_max:
                            axy_max = upper
                        if lower < axy_min:
                            axy_min = lower
        axs[0].set_title("Gradients")
        axs[0].legend()
        axs[0].set_ylim([axy_min, axy_max])

        current_batch_x, current_batch_y = self._trainset.get_next_batch(
            self.batch_size)
        Z_noise = self._generator.sample_noise(n=len(current_batch_x))

        colors = [
            "green", "blue", "red", "orange", "purple", "brown", "gray",
            "pink", "cyan", "olive"
        ]
        for k, key in enumerate(self._monitor_dict):
            if key == "Gradients":
                continue
            key_results = self._sess.run(
                self._monitor_dict[key][0],
                feed_dict={
                    self._X_input: current_batch_x,
                    self._Y_input: current_batch_y,
                    self._Z_input: Z_noise,
                    self._is_training: True,
                    self._output_label_real:
                    self.get_random_label(is_real=True),
                    self._output_label_fake:
                    self.get_random_label(is_real=False)
                })
            for kr, key_result in enumerate(key_results):
                try:
                    self._monitor_dict[key][2][kr].append(
                        self._monitor_dict[key][3][0](key_result))
                except IndexError:
                    self._monitor_dict[key][2][kr].append(key_result)
                axs[k].plot(self._batches,
                            self._monitor_dict[key][2][kr],
                            label=self._monitor_dict[key][1][kr],
                            color=colors[kr])
            axs[k].legend()
            axs[k].set_title(key)
            print("; ".join([
                "{}: {}".format(name, round(float(val[-1]), 5)) for name, val
                in zip(self._monitor_dict[key][1], self._monitor_dict[key][2])
            ]))

        gen_samples = self._sess.run(
            [self._output_gen_from_encoding],
            feed_dict={
                self._X_input: current_batch_x,
                self._Z_input: Z_noise,
                self._is_training: False
            })
        axs[-1].hist([np.ravel(gen_samples),
                      np.ravel(current_batch_y)],
                     label=["Generated", "True"])
        axs[-1].set_title("Pixel distribution")
        axs[-1].legend()

        for check_idx, check_key in enumerate(self._check_dict):
            result_bools_of_check = []
            check = self._check_dict[check_key]
            for tensor_idx in range(len(check["Tensors"])):
                tensor_ = self._sess.run(check["Tensors"][tensor_idx],
                                         feed_dict={
                                             self._X_input: self._x_test,
                                             self._Y_input: self._y_test,
                                             self._Z_input: self._z_test,
                                             self._is_training: False
                                         })
                tensor_op = check["OPonTensors"][tensor_idx](tensor_)
                # Compare the reduced tensor against its threshold with the
                # requested relation (no string eval needed).
                relation = np.greater if check["Relation"][
                    tensor_idx] == ">" else np.less
                result_bools_of_check.append(
                    bool(relation(tensor_op,
                                  check["Threshold"][tensor_idx])))
            if len(result_bools_of_check) > 1:
                check_failed = check["TensorRelation"](*result_bools_of_check)
            else:
                check_failed = result_bools_of_check[0]
            if check_failed:
                self._check_count[check_idx] += 1
                if self._check_count[
                        check_idx] == self._max_allowed_failed_checks:
                    raise GeneratorExit(check_key)
            else:
                self._check_count[check_idx] = 0

        self._total_log_time += (time.perf_counter() - log_start)
        fig.suptitle("Train {} / Log {} / Fails {}".format(
            np.round(self._total_train_time, 2),
            np.round(self._total_log_time, 2), self._check_count))

        plt.savefig(self._folder + "/TrainStatistics.png")
        plt.close("all")

    def get_random_label(self, is_real, size=None):
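        # Noisy adversary labels: with probability _random_labeling each label
        # is flipped, a common regularizer against an overconfident
        # discriminator; real labels are later scaled by label_smoothing.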
        if size is None:
            size = self.batch_size
        labels_shape = [size, *self._output_adv_real.shape.as_list()[1:]]
        labels = np.ones(shape=labels_shape)
        if self._random_labeling > 0:
            relabel_mask = np.random.binomial(n=1,
                                              p=self._random_labeling,
                                              size=labels_shape) == 1
            labels[relabel_mask] = 0
        if not is_real:
            labels = 1 - labels
        return labels
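

# A minimal usage sketch (hypothetical shapes, architectures and data arrays;
# enc_arch, gen_arch, adv_arch, x_train etc. are assumed to come from the
# surrounding library and experiment code, not from this class):
#
#   model = CVAEGAN(x_dim=(32, 32, 1), y_dim=(32, 32, 1), z_dim=64,
#                   enc_architecture=enc_arch, gen_architecture=gen_arch,
#                   adversarial_architecture=adv_arch)
#   model.compile(loss="cross-entropy", optimizer=tf.train.AdamOptimizer,
#                 learning_rate=1e-4, label_smoothing=0.9)
#   model.train(x_train, y_train, x_test, y_test, epochs=100, batch_size=64,
#               adv_steps=5, gen_steps=1)
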
Example #2
class VAEGAN(GenerativeModel):
    def __init__(self,
                 x_dim,
                 z_dim,
                 enc_architecture,
                 gen_architecture,
                 disc_architecture,
                 folder="./VAEGAN"):
        super(VAEGAN, self).__init__(
            x_dim, z_dim,
            [enc_architecture, gen_architecture, disc_architecture], folder)

        self._enc_architecture = self._architectures[0]
        self._gen_architecture = self._architectures[1]
        self._disc_architecture = self._architectures[2]

        ################# Define architecture
        last_layer_mean = [
            logged_dense, {
                "units": z_dim,
                "activation": tf.identity,
                "name": "Mean"
            }
        ]
        self._encoder_mean = Encoder(self._enc_architecture +
                                     [last_layer_mean],
                                     name="Encoder")
        last_layer_std = [
            logged_dense, {
                "units": z_dim,
                "activation": tf.identity,
                "name": "Std"
            }
        ]
        self._encoder_std = Encoder(self._enc_architecture + [last_layer_std],
                                    name="Encoder")

        self._gen_architecture[-1][1]["name"] = "Output"
        self._generator = Generator(self._gen_architecture, name="Generator")

        self._disc_architecture.append(
            [tf.layers.flatten, {
                "name": "Flatten"
            }])
        self._disc_architecture.append([
            logged_dense, {
                "units": 1,
                "activation": tf.nn.sigmoid,
                "name": "Output"
            }
        ])
        self._discriminator = Discriminator(self._disc_architecture,
                                            name="Discriminator")

        self._nets = [self._encoder_mean, self._generator, self._discriminator]

        ################# Connect inputs and networks
        self._mean_layer = self._encoder_mean.generate_net(self._X_input)
        self._std_layer = self._encoder_std.generate_net(self._X_input)

        self._output_enc_with_noise = self._mean_layer + tf.exp(
            0.5 * self._std_layer) * self._Z_input

        self._output_gen = self._generator.generate_net(
            self._output_enc_with_noise)
        self._output_gen_from_encoding = self._generator.generate_net(
            self._Z_input)

        assert self._output_gen.get_shape()[1:] == x_dim, (
            "Generator output must have shape of x_dim. Given: {}. Expected: {}."
            .format(self._output_gen.get_shape(), x_dim))

        self._output_disc_real = self._discriminator.generate_net(
            self._X_input)
        self._output_disc_fake_from_real = self._discriminator.generate_net(
            self._output_gen)
        self._output_disc_fake_from_latent = self._discriminator.generate_net(
            self._output_gen_from_encoding)

        ################# Finalize
        self._init_folders()
        self._verify_init()

    def compile(self,
                learning_rate=0.0001,
                optimizer=tf.train.AdamOptimizer,
                label_smoothing=1,
                gamma=1):
        self._define_loss(label_smoothing=label_smoothing, gamma=gamma)
        with tf.name_scope("Optimizer"):
            enc_optimizer = optimizer(learning_rate=learning_rate)
            self._enc_optimizer = enc_optimizer.minimize(
                self._enc_loss,
                var_list=self._get_vars("Encoder"),
                name="Encoder")
            gen_optimizer = optimizer(learning_rate=learning_rate)
            self._gen_optimizer = gen_optimizer.minimize(
                self._gen_loss,
                var_list=self._get_vars("Generator"),
                name="Generator")
            disc_optimizer = optimizer(learning_rate=learning_rate)
            self._disc_optimizer = disc_optimizer.minimize(
                self._disc_loss,
                var_list=self._get_vars("Discriminator"),
                name="Discriminator")
        self._summarise()

    def _define_loss(self, label_smoothing, gamma):
        def get_labels_one(tensor):
            return tf.ones_like(tensor) * label_smoothing

        eps = 1e-7
        ## Kullback-Leibler divergence
        self._KLdiv = 0.5 * (tf.square(self._mean_layer) +
                             tf.exp(self._std_layer) - self._std_layer - 1)
        self._KLdiv = tf.reduce_mean(self._KLdiv)

        ## Feature matching loss
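        ## As in VAE/GAN (Larsen et al., 2016), reconstruction quality is
        ## measured in the discriminator's penultimate feature space rather
        ## than in pixel space.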
        otp_disc_real = self._discriminator.generate_net(
            self._X_input, tf_trainflag=self._is_training, return_idx=-2)
        otp_disc_fake = self._discriminator.generate_net(
            self._output_gen, tf_trainflag=self._is_training, return_idx=-2)
        self._feature_loss = tf.reduce_mean(
            tf.square(otp_disc_real - otp_disc_fake))

        ## Discriminator loss
        self._logits_real = tf.math.log(self._output_disc_real /
                                        (1 + eps - self._output_disc_real) +
                                        eps)
        self._logits_fake_from_real = tf.math.log(
            self._output_disc_fake_from_real /
            (1 + eps - self._output_disc_fake_from_real) + eps)
        self._logits_fake_from_latent = tf.math.log(
            self._output_disc_fake_from_latent /
            (1 + eps - self._output_disc_fake_from_latent) + eps)
        self._generator_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                labels=get_labels_one(self._logits_fake_from_real),
                logits=self._logits_fake_from_real) +
            tf.nn.sigmoid_cross_entropy_with_logits(
                labels=get_labels_one(self._logits_fake_from_latent),
                logits=self._logits_fake_from_latent))
        self._discriminator_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(labels=get_labels_one(
                self._logits_real),
                                                    logits=self._logits_real) +
            tf.nn.sigmoid_cross_entropy_with_logits(
                labels=tf.zeros_like(self._logits_fake_from_real),
                logits=self._logits_fake_from_real) +
            tf.nn.sigmoid_cross_entropy_with_logits(
                labels=tf.zeros_like(self._logits_fake_from_latent),
                logits=self._logits_fake_from_latent))

        with tf.name_scope("Loss") as scope:

            self._enc_loss = self._KLdiv + self._feature_loss
            self._gen_loss = self._feature_loss + self._generator_loss
            self._disc_loss = self._discriminator_loss

            tf.summary.scalar("Encoder", self._enc_loss)
            tf.summary.scalar("Generator", self._gen_loss)
            tf.summary.scalar("Discriminator", self._disc_loss)

    def train(self,
              x_train,
              x_test,
              epochs=100,
              batch_size=64,
              disc_steps=5,
              gen_steps=1,
              log_step=3):
        self._set_up_training(log_step=log_step)
        self._set_up_test_train_sample(x_train, x_test)
        for epoch in range(epochs):
            batch_nr = 0
            disc_loss_epoch = 0
            gen_loss_epoch = 0
            enc_loss_epoch = 0
            start = time.perf_counter()
            trained_examples = 0
            while trained_examples < len(x_train):
                disc_loss_batch, gen_loss_batch, enc_loss_batch = self._optimize(
                    self._trainset, batch_size, disc_steps, gen_steps)
                trained_examples += batch_size
                disc_loss_epoch += disc_loss_batch
                gen_loss_epoch += gen_loss_batch
                enc_loss_epoch += enc_loss_batch

            epoch_train_time = (time.perf_counter() - start) / 60
            disc_loss_epoch = np.round(disc_loss_epoch, 2)
            gen_loss_epoch = np.round(gen_loss_epoch, 2)
            enc_loss_epoch = np.round(enc_loss_epoch, 2)

            print("Epoch {}: D: {}; G: {}; E: {}.".format(
                epoch, disc_loss_epoch, gen_loss_epoch, enc_loss_epoch))

            if log_step is not None and epoch % log_step == 0:
                self._log(epoch, epoch_train_time)

    def _optimize(self, dataset, batch_size, disc_steps, gen_steps):
        for i in range(disc_steps):
            current_batch_x = dataset.get_next_batch(batch_size)
            Z_noise = self._generator.sample_noise(n=len(current_batch_x))
            _, disc_loss_batch = self._sess.run(
                [self._disc_optimizer, self._disc_loss],
                feed_dict={
                    self._X_input: current_batch_x,
                    self._Z_input: Z_noise
                })

        for i in range(gen_steps):
            Z_noise = self._generator.sample_noise(n=len(current_batch_x))
            _, gen_loss_batch = self._sess.run(
                [self._gen_optimizer, self._gen_loss],
                feed_dict={
                    self._X_input: current_batch_x,
                    self._Z_input: Z_noise
                })
            _, enc_loss_batch = self._sess.run(
                [self._enc_optimizer, self._enc_loss],
                feed_dict={
                    self._X_input: current_batch_x,
                    self._Z_input: Z_noise
                })

        return disc_loss_batch, gen_loss_batch, enc_loss_batch
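

# A minimal usage sketch (hypothetical shapes and architectures, analogous to
# the CVAEGAN example above):
#
#   model = VAEGAN(x_dim=(28, 28, 1), z_dim=32, enc_architecture=enc_arch,
#                  gen_architecture=gen_arch, disc_architecture=disc_arch)
#   model.compile(learning_rate=0.0001, optimizer=tf.train.AdamOptimizer)
#   model.train(x_train, x_test, epochs=100, batch_size=64, disc_steps=5,
#               gen_steps=1)
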
Example #3
class CGAN(ConditionalGenerativeModel):
    def __init__(
        self,
        x_dim,
        y_dim,
        z_dim,
        gen_architecture,
        adversarial_architecture,
        folder="./CGAN",
        append_y_at_every_layer=None,
        is_patchgan=False,
        is_wasserstein=False,
        aux_architecture=None,
    ):
        architectures = [gen_architecture, adversarial_architecture]
        self._is_cycle_consistent = False
        if aux_architecture is not None:
            architectures.append(aux_architecture)
            self._is_cycle_consistent = True
        super(CGAN,
              self).__init__(x_dim=x_dim,
                             y_dim=y_dim,
                             z_dim=z_dim,
                             architectures=architectures,
                             folder=folder,
                             append_y_at_every_layer=append_y_at_every_layer)

        self._gen_architecture = self._architectures[0]
        self._adversarial_architecture = self._architectures[1]
        self._is_patchgan = is_patchgan
        self._is_wasserstein = is_wasserstein
        self._is_feature_matching = False

        ################# Define architecture
        if self._is_patchgan:
            f_xy = self._adversarial_architecture[-1][-1]["filters"]
            assert f_xy == 1, (
                "For PatchGAN, the last adversarial layer needs exactly 1 "
                "filter. Given: {}.".format(f_xy))

            a_xy = self._adversarial_architecture[-1][-1]["activation"]
            if self._is_wasserstein:
                assert a_xy == tf.identity, (
                    "For PatchGAN with Wasserstein loss, the last activation "
                    "must be tf.identity. Given: {}.".format(a_xy))
            else:
                assert a_xy == tf.nn.sigmoid, (
                    "For PatchGAN, the last activation must be tf.nn.sigmoid. "
                    "Given: {}.".format(a_xy))
        else:
            self._adversarial_architecture.append(
                [tf.layers.flatten, {
                    "name": "Flatten"
                }])
            if self._is_wasserstein:
                self._adversarial_architecture.append([
                    logged_dense, {
                        "units": 1,
                        "activation": tf.identity,
                        "name": "Output"
                    }
                ])
            else:
                self._adversarial_architecture.append([
                    logged_dense, {
                        "units": 1,
                        "activation": tf.nn.sigmoid,
                        "name": "Output"
                    }
                ])
        self._gen_architecture[-1][1]["name"] = "Output"

        self._generator = ConditionalGenerator(self._gen_architecture,
                                               name="Generator")
        self._adversarial = Critic(self._adversarial_architecture,
                                   name="Adversarial")

        self._nets = [self._generator, self._adversarial]

        ################# Connect inputs and networks
        self._output_gen = self._generator.generate_net(
            self._mod_Z_input,
            append_elements_at_every_layer=self._append_at_every_layer,
            tf_trainflag=self._is_training)

        with tf.name_scope("InputsAdversarial"):
            if len(self._x_dim) == 1:
                self._input_real = tf.concat(
                    axis=1, values=[self._X_input, self._Y_input], name="real")
                self._input_fake = tf.concat(
                    axis=1,
                    values=[self._output_gen, self._Y_input],
                    name="fake")
            else:
                self._input_real = image_condition_concat(
                    inputs=self._X_input, condition=self._Y_input, name="real")
                self._input_fake = image_condition_concat(
                    inputs=self._output_gen,
                    condition=self._Y_input,
                    name="fake")

        self._output_adversarial_real = self._adversarial.generate_net(
            self._input_real, tf_trainflag=self._is_training)
        self._output_adversarial_fake = self._adversarial.generate_net(
            self._input_fake, tf_trainflag=self._is_training)

        assert self._output_gen.get_shape()[1:] == x_dim, (
            "Output of generator is {}, but x_dim is {}.".format(
                self._output_gen.get_shape(), x_dim))

        ################# Auxiliary network for cycle consistency
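        # An auxiliary encoder maps generated samples back to the (noise +
        # condition) input; an L1 penalty between its output and _mod_Z_input
        # (see _define_loss) pushes the generator towards an invertible,
        # latent-regression-style mapping.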
        if self._is_cycle_consistent:
            self._auxiliary = Encoder(self._architectures[2], name="Auxiliary")
            self._output_auxiliary = self._auxiliary.generate_net(
                self._output_gen, tf_trainflag=self._is_training)
            assert self._output_auxiliary.get_shape().as_list(
            ) == self._mod_Z_input.get_shape().as_list(), (
                "Wrong shape for auxiliary vs. mod Z: {} vs {}.".format(
                    self._output_auxiliary.get_shape(),
                    self._mod_Z_input.get_shape()))
            self._nets.append(self._auxiliary)

        ################# Finalize
        self._init_folders()
        self._verify_init()

        if self._is_patchgan:
            print("PATCHGAN chosen with output: {}.".format(
                self._output_adversarial_real.shape))

    def compile(self,
                loss,
                logged_images=None,
                logged_labels=None,
                learning_rate=0.0005,
                learning_rate_gen=None,
                learning_rate_adversarial=None,
                optimizer=tf.train.RMSPropOptimizer,
                feature_matching=False,
                label_smoothing=1):
        if self._is_wasserstein and loss != "wasserstein":
            raise ValueError(
                "If is_wasserstein is true in Constructor, loss needs to be wasserstein."
            )
        if not self._is_wasserstein and loss == "wasserstein":
            raise ValueError(
                "If loss is wasserstein, is_wasserstein needs to be true in constructor."
            )

        if learning_rate_gen is None:
            learning_rate_gen = learning_rate
        if learning_rate_adversarial is None:
            learning_rate_adversarial = learning_rate
        self._define_loss(loss, feature_matching, label_smoothing)
        with tf.name_scope("Optimizer"):
            gen_optimizer = optimizer(learning_rate=learning_rate_gen)
            self._gen_optimizer = gen_optimizer.minimize(
                self._gen_loss,
                var_list=self._get_vars("Generator"),
                name="Generator")
            adversarial_optimizer = optimizer(
                learning_rate=learning_rate_adversarial)
            self._adversarial_optimizer = adversarial_optimizer.minimize(
                self._adversarial_loss,
                var_list=self._get_vars("Adversarial"),
                name="Adversarial")

            if self._is_cycle_consistent:
                aux_optimizer = optimizer(learning_rate=learning_rate_gen)
                self._aux_optimizer = aux_optimizer.minimize(
                    self._aux_loss,
                    var_list=self._get_vars(scope="Generator") +
                    self._get_vars(scope="Auxiliary"),
                    name="Auxiliary")

            self._gen_grads_and_vars = gen_optimizer.compute_gradients(
                self._gen_loss)
            self._adversarial_grads_and_vars = adversarial_optimizer.compute_gradients(
                self._adversarial_loss)
        self._summarise(logged_images=logged_images,
                        logged_labels=logged_labels)

    def _define_loss(self, loss, feature_matching, label_smoothing):
        possible_losses = ["cross-entropy", "L1", "L2", "wasserstein", "KL"]

        def get_labels_one(tensor):
            return tf.ones_like(tensor) * label_smoothing

        eps = 1e-7
        if loss == "cross-entropy":
            self._logits_real = tf.math.log(
                self._output_adversarial_real /
                (1 + eps - self._output_adversarial_real) + eps)
            self._logits_fake = tf.math.log(
                self._output_adversarial_fake /
                (1 + eps - self._output_adversarial_fake) + eps)

            self._gen_loss = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(
                    labels=get_labels_one(self._logits_fake),
                    logits=self._logits_fake))
            self._adversarial_loss = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(
                    labels=get_labels_one(self._logits_real),
                    logits=self._logits_real) +
                tf.nn.sigmoid_cross_entropy_with_logits(
                    labels=tf.zeros_like(self._logits_fake),
                    logits=self._logits_fake))

        elif loss == "L1":
            self._gen_loss = tf.reduce_mean(
                tf.abs(self._output_adversarial_fake -
                       get_labels_one(self._output_adversarial_fake)))
            self._adversarial_loss = (tf.reduce_mean(
                tf.abs(self._output_adversarial_real -
                       get_labels_one(self._output_adversarial_real)) +
                tf.abs(self._output_adversarial_fake))) / 2.0

        elif loss == "L2":
            self._gen_loss = tf.reduce_mean(
                tf.square(self._output_adversarial_fake -
                          get_labels_one(self._output_adversarial_fake)))
            self._adversarial_loss = (tf.reduce_mean(
                tf.square(self._output_adversarial_real -
                          get_labels_one(self._output_adversarial_real)) +
                tf.square(self._output_adversarial_fake))) / 2.0
        elif loss == "wasserstein":
            self._gen_loss = -tf.reduce_mean(self._output_adversarial_fake)
            self._adversarial_loss = (
                -(tf.reduce_mean(self._output_adversarial_real) -
                  tf.reduce_mean(self._output_adversarial_fake)) +
                10 * self._define_gradient_penalty())
        elif loss == "KL":
            self._logits_real = tf.math.log(
                self._output_adversarial_real /
                (1 + eps - self._output_adversarial_real) + eps)
            self._logits_fake = tf.math.log(
                self._output_adversarial_fake /
                (1 + eps - self._output_adversarial_fake) + eps)

            self._gen_loss = -tf.reduce_mean(self._logits_fake)
            self._adversarial_loss = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(
                    labels=get_labels_one(self._logits_real),
                    logits=self._logits_real) +
                tf.nn.sigmoid_cross_entropy_with_logits(
                    labels=tf.zeros_like(self._logits_fake),
                    logits=self._logits_fake))
        else:
            raise ValueError(
                "Loss not implemented. Choose from {}. Given: {}.".format(
                    possible_losses, loss))

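        # Feature matching (Salimans et al., 2016): instead of fooling the
        # final adversarial output, the generator matches the adversarial's
        # penultimate-layer activations (return_idx=-2) between real and
        # generated inputs. Note that this replaces the generator loss
        # chosen above.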
        if feature_matching:
            self._is_feature_matching = True
            otp_adv_real = self._adversarial.generate_net(
                self._input_real,
                tf_trainflag=self._is_training,
                return_idx=-2)
            otp_adv_fake = self._adversarial.generate_net(
                self._input_fake,
                tf_trainflag=self._is_training,
                return_idx=-2)
            self._gen_loss = tf.reduce_mean(
                tf.square(otp_adv_real - otp_adv_fake))

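        # Cycle consistency: the auxiliary network reconstructs the latent
        # code from the generated image, and its L1 reconstruction error is
        # added to the generator loss.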
        if self._is_cycle_consistent:
            self._aux_loss = tf.reduce_mean(
                tf.abs(self._mod_Z_input - self._output_auxiliary))
            self._gen_loss += self._aux_loss

        with tf.name_scope("Loss") as scope:
            tf.summary.scalar("Generator_Loss", self._gen_loss)
            tf.summary.scalar("Adversarial_Loss", self._adversarial_loss)
            if self._is_cycle_consistent:
                tf.summary.scalar("Auxiliary_Loss", self._aux_loss)

    def _define_gradient_penalty(self):
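        """WGAN-GP gradient penalty: penalizes the adversarial network's
        gradient norm on random real/fake interpolations for deviating
        from 1 (Gulrajani et al., 2017)."""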
        alpha = tf.random_uniform(shape=tf.shape(self._input_real),
                                  minval=0.,
                                  maxval=1.)
        differences = self._input_fake - self._input_real
        interpolates = self._input_real + (alpha * differences)
        gradients = tf.gradients(self._adversarial.generate_net(interpolates),
                                 [interpolates])[0]
        # Per-sample gradient norm: reduce over all non-batch axes so the
        # penalty averages over examples instead of collapsing the whole
        # batch into a single scalar norm.
        slopes = tf.sqrt(
            tf.reduce_sum(tf.square(gradients),
                          axis=list(range(1, gradients.shape.ndims))))
        with tf.name_scope("Loss"):
            self._gradient_penalty = tf.reduce_mean((slopes - 1.)**2)
            tf.summary.scalar("Gradient_penalty", self._gradient_penalty)
        return self._gradient_penalty

    def train(self,
              x_train,
              y_train,
              x_test=None,
              y_test=None,
              epochs=100,
              batch_size=64,
              gen_steps=1,
              adversarial_steps=5,
              log_step=3,
              batch_log_step=None,
              steps=None,
              gpu_options=None):
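        """Alternate adversarial_steps adversarial updates with gen_steps
        generator updates per batch. Passing steps overrides
        adversarial_steps and resets gen_steps to 1. Aborts when a loss
        turns NaN; log_step and batch_log_step control logging frequency."""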
        if steps is not None:
            gen_steps = 1
            adversarial_steps = steps
        self._set_up_training(log_step=log_step, gpu_options=gpu_options)
        self._set_up_test_train_sample(x_train, y_train, x_test, y_test)
        self._log_results(epoch=0, epoch_time=0)
        nr_batches = len(x_train) // batch_size

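        # Counters used by _check_tf_variables to detect degenerate training.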
        self._dominating_adversarial = 0
        self._gen_out_zero = 0
        for epoch in range(epochs):
            adversarial_loss_epoch = 0
            gen_loss_epoch = 0
            aux_loss_epoch = 0
            start = time.perf_counter()  # time.clock() was removed in Python 3.8
            trained_examples = 0
            ii = 0

            while trained_examples < len(x_train):
                adversarial_loss_batch, gen_loss_batch, aux_loss_batch = self._optimize(
                    self._trainset, batch_size, adversarial_steps, gen_steps)
                trained_examples += batch_size

                if np.isnan(adversarial_loss_batch) or np.isnan(
                        gen_loss_batch):
                    print("adversarialLoss / GenLoss: ",
                          adversarial_loss_batch, gen_loss_batch)
                    oar, oaf = self._sess.run(
                        [
                            self._output_adversarial_real,
                            self._output_adversarial_fake
                        ],
                        feed_dict={
                            self._X_input: self.current_batch_x,
                            self._Y_input: self.current_batch_y,
                            self._Z_input: self._Z_noise,
                            self._is_training: True
                        })
                    print("adversarial real output:", oar)
                    print("adversarial fake output:", oaf)
                    print("max real:", np.max(oar))
                    print("max fake:", np.max(oaf))

                    # self._check_tf_variables(ii, nr_batches)
                    raise GeneratorExit("Nan found.")

                if (batch_log_step is not None) and (ii % batch_log_step == 0):
                    batch_train_time = (time.perf_counter() - start) / 60
                    self._log(int(epoch * nr_batches + ii), batch_train_time)

                adversarial_loss_epoch += adversarial_loss_batch
                gen_loss_epoch += gen_loss_batch
                aux_loss_epoch += aux_loss_batch
                ii += 1

            epoch_train_time = (time.perf_counter() - start) / 60
            adversarial_loss_epoch = np.round(adversarial_loss_epoch, 2)
            gen_loss_epoch = np.round(gen_loss_epoch, 2)
            aux_loss_epoch = np.round(aux_loss_epoch, 2)

            print("Epoch {}: Adversarial: {}.".format(epoch + 1,
                                                      adversarial_loss_epoch))
            print("\t\t\tGenerator: {}.".format(gen_loss_epoch))
            print("\t\t\tAuxiliary: {}.".format(aux_loss_epoch))

            if self._log_step is not None:
                self._log(epoch + 1, epoch_train_time)

            # self._check_tf_variables(epoch, epochs)

    def _optimize(self, dataset, batch_size, adversarial_steps, gen_steps):
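        """Run adversarial_steps adversarial updates on fresh batches, then
        gen_steps generator (and, if cycle-consistent, auxiliary) updates
        reusing the last batch's labels. Returns the last update's losses."""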
        for i in range(adversarial_steps):
            current_batch_x, current_batch_y = dataset.get_next_batch(
                batch_size)
            # Keep a reference to the last batch so the NaN diagnostics in
            # train() can re-feed it.
            self.current_batch_x = current_batch_x
            self.current_batch_y = current_batch_y
            self._Z_noise = self.sample_noise(n=len(current_batch_x))
            _, adversarial_loss_batch = self._sess.run(
                [self._adversarial_optimizer, self._adversarial_loss],
                feed_dict={
                    self._X_input: current_batch_x,
                    self._Y_input: current_batch_y,
                    self._Z_input: self._Z_noise,
                    self._is_training: True
                })

        aux_loss_batch = 0
        for _ in range(gen_steps):
            Z_noise = self._generator.sample_noise(n=len(current_batch_x))
            if not self._is_feature_matching:
                _, gen_loss_batch = self._sess.run(
                    [self._gen_optimizer, self._gen_loss],
                    feed_dict={
                        self._Z_input: Z_noise,
                        self._Y_input: current_batch_y,
                        self._is_training: True
                    })
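            # Feature matching compares adversarial features of real and
            # generated inputs, so the real batch is fed here as well.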
            else:
                _, gen_loss_batch = self._sess.run(
                    [self._gen_optimizer, self._gen_loss],
                    feed_dict={
                        self._X_input: current_batch_x,
                        self._Y_input: current_batch_y,
                        self._Z_input: self._Z_noise,
                        self._is_training: True
                    })
            if self._is_cycle_consistent:
                _, aux_loss_batch = self._sess.run(
                    [self._aux_optimizer, self._aux_loss],
                    feed_dict={
                        self._Z_input: Z_noise,
                        self._Y_input: current_batch_y,
                        self._is_training: True
                    })

        return adversarial_loss_batch, gen_loss_batch, aux_loss_batch

    def predict(self, inpt_x, inpt_y):
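        """Evaluate the conditioned real-input node for (inpt_x, inpt_y) and
        return the adversarial network's prediction on it."""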
        inpt = self._sess.run(self._input_real,
                              feed_dict={
                                  self._X_input: inpt_x,
                                  self._Y_input: inpt_y,
                                  self._is_training: False  # inference mode
                              })
        return self._adversarial.predict(inpt, self._sess)

    def _check_tf_variables(self, batch_nr, nr_batches):
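        """Debugging aid: evaluate gradients, weights and outputs on the test
        split and abort when the adversarial network dominates or the
        generator collapses to near-zero output."""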
        Z_noise = self._generator.sample_noise(n=len(self._x_test))
        gen_grads = [
            self._sess.run(gen_gv[0],
                           feed_dict={
                               self._X_input: self._x_test,
                               self._Y_input: self._y_test,
                               self._Z_input: Z_noise,
                               self._is_training: False
                           }) for gen_gv in self._gen_grads_and_vars
        ]
        adversarial_grads = [
            self._sess.run(adversarial_gv[0],
                           feed_dict={
                               self._X_input: self._x_test,
                               self._Y_input: self._y_test,
                               self._Z_input: Z_noise,
                               self._is_training: False
                           })
            for adversarial_gv in self._adversarial_grads_and_vars
        ]
        gen_grads_maxis = [np.max(gv) for gv in gen_grads]
        gen_grads_means = [np.mean(gv) for gv in gen_grads]
        gen_grads_minis = [np.min(gv) for gv in gen_grads]
        adversarial_grads_maxis = [np.max(dv) for dv in adversarial_grads]
        adversarial_grads_means = [np.mean(dv) for dv in adversarial_grads]
        adversarial_grads_minis = [np.min(dv) for dv in adversarial_grads]

        # These are sigmoid outputs (probabilities), not raw logits.
        real_probs, fake_probs, gen_out = self._sess.run(
            [
                self._output_adversarial_real, self._output_adversarial_fake,
                self._output_gen
            ],
            feed_dict={
                self._X_input: self._x_test,
                self._Y_input: self._y_test,
                self._Z_input: Z_noise,
                self._is_training: False
            })
        real_probs = np.mean(real_probs)
        fake_probs = np.mean(fake_probs)

        gen_varsis = np.array([
            x.eval(session=self._sess)
            for x in self._generator.get_network_params()
        ])
        adversarial_varsis = np.array([
            x.eval(session=self._sess)
            for x in self._adversarial.get_network_params()
        ])
        gen_maxis = np.array([np.max(x) for x in gen_varsis])
        adversarial_maxis = np.array([np.max(x) for x in adversarial_varsis])
        gen_means = np.array([np.mean(x) for x in gen_varsis])
        adversarial_means = np.array([np.mean(x) for x in adversarial_varsis])
        gen_minis = np.array([np.min(x) for x in gen_varsis])
        adversarial_minis = np.array([np.min(x) for x in adversarial_varsis])

        print(batch_nr, "/", nr_batches, ":")
        print("adversarialReal / adversarialFake: ", real_logits, fake_logits)
        print("GenWeight Max / Mean / Min: ", np.max(gen_maxis),
              np.mean(gen_means), np.min(gen_minis))
        print("GenGrads Max / Mean / Min: ", np.max(gen_grads_maxis),
              np.mean(gen_grads_means), np.min(gen_grads_minis))
        print("adversarialWeight Max / Mean / Min: ",
              np.max(adversarial_maxis), np.mean(adversarial_means),
              np.min(adversarial_minis))
        print("adversarialGrads Max / Mean / Min: ",
              np.max(adversarial_grads_maxis),
              np.mean(adversarial_grads_means),
              np.min(adversarial_grads_minis))
        print("GenOut Max / Mean / Min: ", np.max(gen_out), np.mean(gen_out),
              np.min(gen_out))
        print("\n")

        if real_probs > 0.99 and fake_probs < 0.01:
            self._dominating_adversarial += 1
            if self._dominating_adversarial == 5:
                raise RuntimeError("Dominating adversarial network!")
        else:
            self._dominating_adversarial = 0

        if np.max(gen_out) < 0.05:
            self._gen_out_zero += 1
            print("Generator output near zero ({} consecutive checks).".format(
                self._gen_out_zero))
            if self._gen_out_zero == 50:
                raise RuntimeError("Generator outputs zeros.")
        else:
            self._gen_out_zero = 0