# Example 1
class AutoGAN:
    """Adversarially trained autoencoder ("AutoGAN").

    Pairs an externally defined ``Autoencoder`` (the generator) with a small
    strided-convolution discriminator.  Training alternates between updating
    the combined model (autoencoder weights only, discriminator frozen) and
    updating the discriminator alone on a mix of real samples and
    reconstructions.

    This class relies on names imported elsewhere in the file:
    ``Autoencoder``, ``layers``, ``keras``, ``tf``, ``np``, ``datetime``,
    ``shuffle``, ``shuffle_with_labels``, ``normalize_tanh`` and
    ``normalize_pos``.
    """

    def __init__(self, width, height, train=None, test=None, training=True):
        """Store dimensions and data splits and construct the autoencoder.

        Args:
            width: input image width in pixels.
            height: input image height in pixels.
            train: optional ``(positive, negative)`` pair of training arrays.
            test: optional ``(positive_test, negative_test)`` pair.
            training: forwarded to the wrapped ``Autoencoder`` and to the
                discriminator's Dropout layers.
        """

        self.width = width
        self.height = height

        self.training = training

        # Data splits are optional so the class can also be used for
        # inference only (see load_trained()).
        if train is not None:
            (self.positive, self.negative) = train

        if test is not None:
            (self.positive_test, self.negative_test) = test

        # NOTE(review): despite its name, kernel_size is passed as the conv
        # *strides* in discriminator() below — confirm that is intended.
        self.kernel_size = (3, 3)
        self.model = None

        self.model_created = False

        # Generator half of the GAN; built now, compiled in create_model().
        self.ae = Autoencoder(width,
                              height,
                              train,
                              test,
                              training=self.training)
        self.d_net = None
        print("gan initiated")

    def discriminator(self, input, training=True):
        """Build the discriminator graph on top of the tensor ``input``.

        Four Conv2D / BatchNorm / LeakyReLU / Dropout stages followed by a
        single-unit dense head.

        Args:
            input: Keras tensor (e.g. from ``tf.keras.Input``).
            training: passed to the Dropout layers so dropout can be forced
                on or off independently of the learning phase.

        Returns:
            Output tensor of the discriminator head.
        """

        # NOTE(review): strides=self.kernel_size means stride (3, 3), so each
        # conv downsamples by 3x instead of the commented-out MaxPooling2D
        # layers — confirm this aggressive striding is deliberate.
        conv1 = layers.Conv2D(64, (3, 3),
                              padding='same',
                              strides=self.kernel_size,
                              name='discriminator_input')(input)
        bn1 = layers.BatchNormalization(-1)(conv1)  # -1 = channels axis (the Keras default)
        leaky1 = layers.LeakyReLU()(bn1)
        dropout1 = layers.Dropout(rate=0.2)(leaky1, training=training)

        #pool1 = layers.MaxPooling2D(2,2)(leaky1)
        conv2 = layers.Conv2D(32, (3, 3),
                              padding='same',
                              strides=self.kernel_size)(dropout1)
        bn2 = layers.BatchNormalization(-1)(conv2)
        leaky2 = layers.LeakyReLU()(bn2)
        dropout2 = layers.Dropout(rate=0.2)(leaky2, training=training)

        #pool2 = layers.MaxPooling2D(2,2)(leaky2)
        conv3 = layers.Conv2D(16, (3, 3),
                              padding='same',
                              strides=self.kernel_size)(dropout2)
        bn3 = layers.BatchNormalization(-1)(conv3)
        leaky3 = layers.LeakyReLU()(bn3)
        dropout3 = layers.Dropout(rate=0.2)(leaky3, training=training)

        #pool3 = layers.MaxPooling2D(2,2)(leaky3)
        conv4 = layers.Conv2D(4, (3, 3),
                              padding='same',
                              strides=self.kernel_size)(dropout3)
        bn4 = layers.BatchNormalization(-1)(conv4)
        leaky4 = layers.LeakyReLU()(bn4)
        dropout4 = layers.Dropout(rate=0.2)(leaky4, training=training)

        flatten = layers.Flatten()(dropout4)
        output = layers.Dense(1, name='discriminator_output')(flatten)
        # NOTE(review): a LeakyReLU head is unbounded, but create_model()
        # compiles this model with binary_crossentropy, which expects
        # probabilities in [0, 1] — a sigmoid would be conventional; confirm.
        output_activation = layers.LeakyReLU()(output)

        print("discriminator layers created")
        return output_activation

    def create_model(self):
        """Build and compile the discriminator and the combined model.

        Produces three models:
          * ``self.ae_model``    — the autoencoder (built by ``self.ae``),
          * ``self.d_model``     — the standalone, trainable discriminator
                                   (compiled with binary cross-entropy),
          * ``self.autogan_model`` — autoencoder followed by the frozen
                                   discriminator; only the autoencoder learns
                                   from the adversarial signal.
        """

        self.model_created = True

        self.ae.create_model()
        self.ae_model = self.ae.get_model()
        self.ae_model.summary()

        # Shared input: RGB image of the configured size.
        input_layer = tf.keras.Input(shape=(self.width, self.height, 3))

        self.d_net = self.discriminator(input_layer, training=self.training)

        #compile discriminator model
        self.d_model = keras.Model(inputs=input_layer,
                                   outputs=[self.d_net],
                                   name="discriminator")
        self.d_model.compile(optimizer='adam',
                             loss='binary_crossentropy',
                             metrics=['accuracy'])
        self.d_model.summary()

        #compile autogan model
        # Freezing here only affects models compiled afterwards, so the
        # already-compiled d_model still trains standalone while the
        # combined model below treats it as fixed.
        self.d_model.trainable = False

        self.ae_out = self.ae_model(input_layer)
        self.disc_out = self.d_model(self.ae_out)

        print("ae_out ", self.ae_out)
        print("disc_out", self.disc_out)

        self.autogan_model = keras.Model(inputs=input_layer,
                                         outputs=[self.ae_out, self.disc_out],
                                         name="autogan")

        optimizer = tf.keras.optimizers.Adam(learning_rate=0.0002)
        # Loss keys must match the sub-model names ("autoencoder" is
        # presumably the name given inside Autoencoder — verify).
        # loss_weights=[5, 1] weights reconstruction 5x over the adversarial
        # term; NOTE(review): mixing a list of weights with a dict of losses
        # relies on output ordering — a dict of weights would be safer.
        self.autogan_model.compile(optimizer=optimizer,
                                   loss={
                                       "autoencoder":
                                       self.ae_loss,
                                       "discriminator":
                                       keras.losses.binary_crossentropy
                                   },
                                   loss_weights=[5, 1],
                                   metrics=['accuracy'])
        self.autogan_model.summary()
        print("autogan created")

    def autogan_loss(self, y_true, y_pred):
        """Combined reconstruction (MSE) + adversarial (BCE) loss.

        NOTE(review): appears unused — create_model() wires ae_loss and
        binary_crossentropy per output instead.  The %f formatting below
        would also fail on non-scalar tensors; confirm before reviving.
        """

        ae_in, label = y_true
        ae_out, disc_out = y_pred

        mse = keras.losses.MSE(ae_in, ae_out)
        bce = keras.losses.binary_crossentropy(label, disc_out)

        print("Calculating autoGAN loss: MSE: %f, BCE: %f" % (mse, bce))
        return mse + bce

    def ae_loss(self, y_true, y_pred):
        """Mean-squared-error reconstruction loss, reduced to a scalar."""

        loss = keras.losses.mse(y_true, y_pred)
        loss = tf.math.reduce_mean(loss)

        return loss

    def train(self, epochs, batch_size):
        """Alternating GAN training loop with TensorBoard logging.

        Even epochs train the combined model (reconstructions labelled 1 so
        the autoencoder is pushed toward the discriminator's "positive"
        class); odd epochs train the discriminator alone on real negatives
        (label 0), real positives (label 1) and reconstructions (label 0).
        Scalars are logged every batch; weights and sample reconstructions
        are saved every even epoch.

        Args:
            epochs: nominal epoch count; NOTE(review): range(0, epochs + 1)
                actually runs epochs + 1 iterations — confirm intended.
            batch_size: samples per batch; discriminator batches are built
                from three sub-batches of batch_size // 3 each.
        """

        current_time = datetime.datetime.now().strftime(
            "%Y%m%d-%H%M%S")  #copied from tensorflow documentation
        summary_writer = tf.summary.create_file_writer("../logs/" +
                                                       current_time)

        # Epoch length is driven by the negative split only.
        batches_per_epoch = self.negative.shape[0] // batch_size

        loss = 0.0
        ae_loss = 0.0
        disc_loss = 0.0
        disc_acc = 0.0
        ae_acc = 0.0
        for epoch_num in range(0, epochs + 1):

            # Assumes shuffle() permutes in place (return value is ignored)
            # — confirm against the helper's definition.
            shuffle(self.negative)
            shuffle(self.positive)

            # NOTE(review): re-compiling at epoch 100 resets the Adam
            # optimizer state and replaces the initial lr=2e-4 optimizer
            # with the default-lr 'adam' — confirm deliberate.
            if epoch_num == 100:
                self.autogan_model.compile(optimizer='adam',
                                           loss={
                                               "autoencoder":
                                               self.ae_loss,
                                               "discriminator":
                                               keras.losses.binary_crossentropy
                                           },
                                           loss_weights=[5, 1],
                                           metrics=['accuracy'])

            for batch in range(0, batches_per_epoch):

                curr_batch_start = batch * batch_size

                negative_batch = self.negative[curr_batch_start:(
                    curr_batch_start + batch_size)]

                # every n epochs train whole net
                if epoch_num % 2 == 0:

                    print("train all weights")
                    # Label 1 ("positive") for reconstructions: the frozen
                    # discriminator provides the adversarial gradient.
                    labels = np.ones(shape=batch_size)

                    # Metric keys follow Keras' "<output_name>_<metric>"
                    # convention for multi-output models.
                    metrics = self.autogan_model.train_on_batch(
                        negative_batch, {
                            "autoencoder": negative_batch,
                            "discriminator": labels
                        },
                        return_dict=True)
                    ae_loss = metrics['autoencoder_loss']
                    disc_loss = metrics['discriminator_loss']
                    disc_acc = metrics['discriminator_accuracy']
                    ae_acc = metrics['autoencoder_accuracy']

                    print("autogan metrics ", metrics)
                    loss = metrics['loss']

                # else train discriminator
                else:

                    print("train discriminator only")

                    positive_batch = self.positive[curr_batch_start:(
                        curr_batch_start + batch_size)]

                    # Reconstructions of negatives serve as "fake" samples;
                    # normalize_tanh presumably rescales them in place —
                    # confirm helper semantics.
                    ae_pred_batch = self.ae_model.predict(negative_batch)
                    normalize_tanh(ae_pred_batch)

                    # Equal thirds: real negative (0), real positive (1),
                    # reconstruction (0).
                    subbatch_size = batch_size // 3
                    positive_batch = positive_batch[:subbatch_size]
                    negative_batch = negative_batch[:subbatch_size]
                    ae_pred_batch = ae_pred_batch[:subbatch_size]

                    mixed_batch = np.concatenate(
                        (negative_batch, positive_batch, ae_pred_batch))
                    labels = np.concatenate((np.zeros(shape=(subbatch_size)),
                                             np.ones(shape=(subbatch_size)),
                                             np.zeros(shape=subbatch_size)))
                    # Joint shuffle keeps samples aligned with their labels.
                    shuffle_with_labels(mixed_batch, labels)

                    loss, accuracy = self.d_model.train_on_batch(
                        mixed_batch, labels)

                    disc_loss = loss
                    disc_acc = accuracy

                # Global step for TensorBoard; ae_* values are carried over
                # from the most recent even epoch on odd epochs.
                step = epoch_num * batches_per_epoch + batch

                with summary_writer.as_default():
                    tf.summary.scalar("accuracy/autoencoder",
                                      ae_acc,
                                      step=step)
                    tf.summary.scalar("loss/autoencoder", ae_loss, step=step)
                    tf.summary.scalar("accuracy/discriminator",
                                      disc_acc,
                                      step=step)
                    tf.summary.scalar("loss/discriminator",
                                      disc_loss,
                                      step=step)

                print("epoch: %d, batch: %d loss: %f" %
                      (epoch_num, batch, loss))

            # with summary_writer.as_default():

            #     step = epoch_num*batches_per_epoch + batch
            #     tf.summary.scalar("discriminator accuracy", disc_acc, step=step)
            #     tf.summary.scalar("autoencoder accuracy", ae_acc, step=step)
            #     tf.summary.scalar("discriminator loss", disc_loss, step=step)
            #     tf.summary.scalar("autoencoder loss", ae_loss, step=step)
            # Checkpoint weights and log 20 sample reconstructions every
            # even (combined-training) epoch.
            if epoch_num % 2 == 0:

                self.autogan_model.save_weights(
                    filepath="../checkpoints_autogan/epoch" + str(epoch_num))
                with summary_writer.as_default():

                    inp = self.negative[0:20]
                    out = self.ae_model.predict(inp)
                    # Presumably rescales to a displayable range in place —
                    # confirm; note this mutates self.negative[0:20] if
                    # slicing returns a view.
                    normalize_pos(inp)
                    normalize_pos(out)

                    tf.summary.image("input/ae",
                                     inp,
                                     step=epoch_num,
                                     max_outputs=20)
                    tf.summary.image("output/ae",
                                     out,
                                     step=epoch_num,
                                     max_outputs=20)
                #print("VALIDATE DISCRIMINATOR")
                #visualize.test_discriminator("../checkpoints_autogan/epoch" + str(epoch_num), self.negative_test, self.positive_test)

    def load_trained(self, filename):
        """Build the model graph and load weights from ``filename``.

        NOTE(review): always rebuilds via create_model(), even when
        model_created is already True — confirm double-build is harmless.
        """

        self.create_model()
        self.autogan_model.load_weights(filename)

    def get_discriminator(self):
        """Return the discriminator output tensor (None before create_model)."""

        return self.d_net

    def get_discriminator_model(self):
        """Return the compiled discriminator model.

        NOTE(review): d_model is only assigned in create_model(), so calling
        this earlier raises AttributeError rather than returning None.
        """

        return self.d_model

    def get_autogan_model(self):
        """Return the combined model, lazily building it on first access."""

        if not self.model_created:
            self.create_model()
        return self.autogan_model