Example #1
0
    def test_loss_on_layer(self):
        """A loss registered with ``add_loss`` inside ``call`` is tracked and trained on."""

        class MyLayer(layers.Layer):
            def call(self, inputs):
                # Identity layer that registers sum(inputs) as an extra loss term.
                self.add_loss(math_ops.reduce_sum(inputs))
                return inputs

        model_in = Input((3, ))
        model_out = MyLayer()(model_in)
        model = Model(model_in, model_out)
        self.assertEqual(len(model.losses), 1)

        model.compile(
            'sgd',
            'mse',
            run_eagerly=testing_utils.should_run_eagerly(),
        )
        features = np.ones((2, 3))
        targets = np.ones((2, 3))
        # The layer is an identity, so mse(ones, ones) == 0 and the reported
        # loss is just the added term reduce_sum(ones((2, 3))) == 6.
        loss = model.train_on_batch(features, targets)
        self.assertEqual(loss, 2 * 3)
Example #2
0
class BaseKerasModel(BaseModel):
    """Keras implementation of ``BaseModel``: a dropout-regularised dense MLP.

    The build methods communicate through instance attributes:
    ``create_input_layer`` sets ``self.inputs``; ``create_hidden_layers``
    reads it and sets ``self.hidden_layer``; ``create_output_layer`` reads
    that and builds ``self.model``; ``finalize_model`` compiles ``self.model``
    and creates ``self.tensorboard``, which ``fit`` writes summaries to.
    """
    model = None
    tensorboard = None
    train_names = ['train_loss', 'train_mse', 'train_mae']
    val_names = ['val_loss', 'val_mse', 'val_mae']
    counter = 0  # number of fit() calls; becomes an instance attribute on first fit()
    inputs = None
    hidden_layer = None
    outputs = None

    # (dropout rate, Dense units) of each hidden block, in build order.
    _HIDDEN_BLOCKS = ((0.3, 128), (0.4, 64), (0.3, 32))
    # Dropout applied after the last hidden Dense layer.
    _FINAL_DROPOUT = 0.1

    def __init__(self,
                 use_default_dense=True,
                 activation='relu',
                 kernel_regularizer=tf.keras.regularizers.l1(0.001)):
        """
        :param use_default_dense: If True, store ``activation`` and
            ``kernel_regularizer`` for ``create_hidden_layers``. If False they
            are not set here, so ``create_hidden_layers`` fails unless a
            subclass (or ``BaseModel``) provides them.
        :param activation: Activation of the hidden Dense layers.
        :param kernel_regularizer: Kernel regularizer of the hidden Dense
            layers. NOTE: the default regularizer object is created once, when
            the class is defined, and is shared by every instance using it.
        """
        super().__init__()
        if use_default_dense:
            self.activation = activation
            self.kernel_regularizer = kernel_regularizer

    def create_input_layer(self, input_placeholder: BaseInputFormatter):
        """Create the Keras input layer and store it as ``self.inputs``.

        The shape is ``input_placeholder.get_input_state_dimension()``.

        NOTE(review): this stores a ``tf.keras.layers.InputLayer`` *layer*,
        which later flows into ``Dropout(...)(...)`` and ``Model(inputs=...)``.
        Those usually expect a tensor (e.g. ``tf.keras.Input``); confirm this
        works with the TF version in use.
        """
        self.inputs = tf.keras.layers.InputLayer(
            input_shape=input_placeholder.get_input_state_dimension())
        return self.inputs

    def create_hidden_layers(self, input_layer=None):
        """Stack Dropout -> Dense blocks (128, 64, 32 units), then a final Dropout.

        :param input_layer: Layer output to build on; defaults to ``self.inputs``.
        :return: Output of the final Dropout, also stored as ``self.hidden_layer``.
        """
        if input_layer is None:
            input_layer = self.inputs
        hidden_layer = input_layer
        for rate, units in self._HIDDEN_BLOCKS:
            hidden_layer = tf.keras.layers.Dropout(rate)(hidden_layer)
            hidden_layer = tf.keras.layers.Dense(
                units,
                kernel_regularizer=self.kernel_regularizer,
                activation=self.activation)(hidden_layer)
        self.hidden_layer = tf.keras.layers.Dropout(self._FINAL_DROPOUT)(hidden_layer)
        return self.hidden_layer

    def create_output_layer(self,
                            output_formatter: BaseOutputFormatter,
                            hidden_layer=None):
        """Add a tanh Dense output head and build ``self.model``.

        :param output_formatter: Supplies the output width via
            ``get_model_output_dimension()[0]``.
        :param hidden_layer: Layer output to build on; defaults to ``self.hidden_layer``.
        :return: The output tensor, also stored as ``self.outputs``.
        """
        if hidden_layer is None:
            hidden_layer = self.hidden_layer
        # tanh keeps outputs in (-1, 1); subclasses may override this head.
        self.outputs = tf.keras.layers.Dense(
            output_formatter.get_model_output_dimension()[0],
            activation='tanh')(hidden_layer)
        self.model = Model(inputs=self.inputs, outputs=self.outputs)
        return self.outputs

    def write_log(self, callback, names, logs, batch_no, eval=False):
        """Write one TF1-style scalar summary per ``(name, value)`` pair.

        :param callback: Object exposing a summary ``writer`` (the TensorBoard callback).
        :param names: Tag names; pairs are formed with ``zip``, so surplus
            entries in ``names`` or ``logs`` are ignored.
        :param logs: Scalar values aligned with ``names``.
        :param batch_no: Step recorded with each summary.
        :param eval: If True, every tag is prefixed with ``"eval_"``.
        """
        prefix = 'eval_' if eval else ''
        for name, value in zip(names, logs):
            summary = tf.Summary()
            summary_value = summary.value.add()
            summary_value.simple_value = value
            summary_value.tag = prefix + name
            callback.writer.add_summary(summary, batch_no)
        # Flush once after writing every summary of this step.
        callback.writer.flush()

    def finalize_model(self, logname=None):
        """Compile ``self.model`` (Nadam, loss from ``create_loss``) and attach TensorBoard.

        :param logname: Sub-directory of ``./logs`` for TensorBoard. When
            omitted, a random number in [0, 1000) is drawn *at call time*. The
            previous default was evaluated once, at class definition, so every
            unnamed model in a process shared one log directory.
        """
        if logname is None:
            logname = str(int(random() * 1000))

        loss, loss_weights = self.create_loss()
        self.model.compile(tf.keras.optimizers.Nadam(lr=0.001),
                           loss=loss,
                           loss_weights=loss_weights,
                           metrics=[
                               tf.keras.metrics.mean_absolute_error,
                               tf.keras.metrics.binary_accuracy
                           ])
        log_name = './logs/' + logname
        self.logger.info("log_name: " + log_name)
        self.tensorboard = tf.keras.callbacks.TensorBoard(
            log_dir=log_name,
            histogram_freq=1,
            write_images=False,
            batch_size=1000,
        )
        self.tensorboard.set_model(self.model)
        self.logger.info("Model has been finalized")

    def fit(self, x, y, batch_size=1):
        """Train on one batch, or evaluate it every 200th call, and log the metrics.

        NOTE(review): on calls where ``counter % 200 == 0`` (including the very
        first call) the batch is only *evaluated*, not trained on. Confirm that
        skipping those batches is intended.

        :param batch_size: Only used by the periodic ``evaluate`` call.
        """
        if self.counter % 200 == 0:
            logs = self.model.evaluate(x, y, batch_size=batch_size, verbose=1)
            self.write_log(self.tensorboard,
                           self.model.metrics_names,
                           logs,
                           self.counter,
                           eval=True)
            print('step:', self.counter)
        else:
            logs = self.model.train_on_batch(x, y)
            self.write_log(self.tensorboard, self.model.metrics_names, logs,
                           self.counter)
        self.counter += 1

    def predict(self, arr):
        """Return ``self.model.predict(arr)``."""
        return self.model.predict(arr)

    def save(self, file_path):
        """Save the model weights to ``file_path``, overwriting any existing file."""
        self.model.save_weights(filepath=file_path, overwrite=True)

    def load(self, file_path):
        """Load model weights from ``file_path`` (resolved to an absolute path)."""
        self.model.load_weights(filepath=os.path.abspath(file_path))

    def create_loss(self):
        """Return ``(loss, loss_weights)`` for ``compile``; subclasses may override."""
        return 'mean_absolute_error', None
Example #3
0
class RetroCycleGAN:
    """CycleGAN-style model between plain (A) and retrofitted (B) word vectors.

    Sub-models:
      * ``g_AB`` / ``g_BA``: MLP generators A->B and B->A.
      * ``d_A`` / ``d_B``: per-domain discriminators (sigmoid validity score).
      * ``d_ABBA`` / ``d_BAAB``: cycle-conditional discriminators that score a
        concatenated pair of word vectors.
      * ``combined``: trains the generators against the frozen ``d_A``/``d_B``
        plus cycle (MAE and max-margin) and identity losses.
    """

    def __init__(self, save_index="0", save_folder="./", generator_size=32,
                 discriminator_size=64, word_vector_dimensions=300,
                 discriminator_lr=0.0001, generator_lr=0.0001,
                 lambda_cycle=1, lambda_id_weight=0.01, one_way_mm=True,
                 cycle_mm=True,
                 cycle_dis=True,
                 id_loss=True,
                 cycle_mm_w=2,
                 cycle_loss=True):
        """Build every sub-model and the combined generator-training model.

        :param save_index: Stored as ``self.save_index`` (not read elsewhere in this class).
        :param save_folder: Folder used by ``save_model`` and ``load_weights``.
        :param generator_size: Stored as ``self.gf`` (not read by ``build_generator``).
        :param discriminator_size: Stored as ``self.df`` (not read by the discriminator builders).
        :param word_vector_dimensions: Length of each word vector.
        :param discriminator_lr: Learning rate of every discriminator optimizer.
        :param generator_lr: Learning rate of the generator and combined optimizers.
        :param lambda_cycle: Weight of the cycle MAE losses (0 if ``cycle_loss`` is False).
        :param lambda_id_weight: Weight of the identity losses (0 if ``id_loss`` is False).
        :param one_way_mm: If True, ``train`` also fits each generator directly
            on paired vectors with the max-margin loss.
        :param cycle_mm: If False, the cycle max-margin weight is forced to 0.
        :param cycle_dis: If True, ``train`` also trains the cycle-conditional discriminators.
        :param id_loss: Enables the identity-mapping loss.
        :param cycle_mm_w: Weight of the cycle max-margin losses.
        :param cycle_loss: Enables the cycle MAE loss.
        """
        self.cycle_mm = cycle_mm
        self.cycle_dis = cycle_dis
        self.cycle_mae = cycle_loss
        self.id_loss = id_loss
        self.one_way_mm = one_way_mm
        self.cycle_mm_w = cycle_mm_w if self.cycle_mm else 0
        self.save_folder = save_folder

        # Every model input is a single flat word vector.
        self.word_vector_dimensions = word_vector_dimensions
        self.embeddings_dimensionality = (self.word_vector_dimensions,)
        self.save_index = save_index

        self.gf = generator_size
        self.df = discriminator_size

        # Loss weights; a disabled loss keeps its output but gets weight 0.
        self.lambda_cycle = lambda_cycle if self.cycle_mae else 0  # Cycle-consistency loss
        self.lambda_id = lambda_id_weight if self.id_loss else 0  # Identity loss

        self.d_lr = discriminator_lr
        self.g_lr = generator_lr

        self.d_A = self.build_discriminator(name="word_vector_discriminator")
        self.d_B = self.build_discriminator(name="retrofitted_word_vector_discriminator")
        self.d_ABBA = self.build_c_discriminator(name="cycle_cond_discriminator_unfit")
        self.d_BAAB = self.build_c_discriminator(name="cycle_cond_discriminator_fit")

        self.g_AB = self.build_generator(name="to_retro_generator")
        self.g_BA = self.build_generator(name="from_retro_generator")

        unfit_wv = Input(shape=self.embeddings_dimensionality, name="plain_word_vector")
        fit_wv = Input(shape=self.embeddings_dimensionality, name="retrofitted_word_vector")

        # Translate to the other domain, then back to the original one.
        fake_B = self.g_AB(unfit_wv)
        fake_A = self.g_BA(fit_wv)
        reconstr_A = self.g_BA(fake_B)
        reconstr_B = self.g_AB(fake_A)

        print("Building recon model")
        print("Done")
        # Identity mapping: each generator applied to its own target domain.
        unfit_wv_id = self.g_BA(unfit_wv)
        fit_wv_id = self.g_AB(fit_wv)

        # The per-domain discriminators score the translated vectors.
        valid_A = self.d_A(fake_A)
        valid_B = self.d_B(fake_B)

        # The combined model must only train the generators. compile_all
        # re-enables these flags while compiling d_A/d_B themselves.
        self.d_A.trainable = False
        self.d_B.trainable = False

        # Output order must match the loss list in compile_all and the targets
        # passed to combined.train_on_batch in train.
        self.combined = Model(inputs=[unfit_wv, fit_wv],
                              outputs=[valid_A, valid_B,  # adversarial (bce)
                                       reconstr_A, reconstr_B,  # cycle reconstruction (mae)
                                       reconstr_A, reconstr_B,  # cycle reconstruction (max margin)
                                       unfit_wv_id, fit_wv_id,  # identity (mae)
                                       ],
                              name="combinedmodel")

        log_path = './logs'
        callback = keras.callbacks.TensorBoard(log_dir=log_path)
        callback.set_model(self.combined)
        self.combined_callback = callback

    def compile_all(self, optimizer="sgd"):
        """Compile every sub-model with a fresh optimizer of the requested type.

        :param optimizer: ``"adam"`` or ``"sgd"``. Previously only ``"adam"``
            was recognised, so calling with the default always failed.
        :raises KeyError: For any other optimizer name.
        """

        def max_margin_loss(y_true, y_pred):
            # Hinge loss: the similarity of y_pred to y_true must beat the
            # similarity of y_true to a randomly shuffled copy of itself by
            # ``sim_margin``, averaged over ``sim_neg`` shuffles.
            # tf.nn.l2_normalize is called without ``axis``, so each tensor is
            # normalised as a whole, not row by row.
            sim_neg = 25
            sim_margin = 1
            # Loop-invariant; only the shuffled negatives change per iteration.
            normalize_a = tf.nn.l2_normalize(y_true)
            normalize_b = tf.nn.l2_normalize(y_pred)
            minimize = tf.reduce_sum(tf.multiply(normalize_a, normalize_b))
            cost = 0
            for _ in range(sim_neg):
                new_true = tf.random.shuffle(y_true)
                normalize_c = tf.nn.l2_normalize(new_true)
                maximize = tf.reduce_sum(tf.multiply(normalize_a, normalize_c))
                mg = sim_margin - minimize + maximize
                cost += tf.keras.backend.clip(mg, 0, 1000)
            return cost / (sim_neg * 1.0)

        def create_opt(lr=0.1):
            if optimizer == "adam":
                return tf.optimizers.Adam(learning_rate=lr, epsilon=1e-10)
            if optimizer == "sgd":
                return tf.optimizers.SGD(learning_rate=lr)
            raise KeyError("Could not find the optimizer: %s" % optimizer)

        # Keras captures each model's trainable state when it is compiled.
        # d_A/d_B were frozen in __init__ for the combined model. Unfreeze them
        # while compiling them on their own, or their train_on_batch calls would
        # never update any weights. Then refreeze them for the combined model.
        self.d_A.trainable = True
        self.d_B.trainable = True
        for discriminator in (self.d_A, self.d_ABBA, self.d_BAAB, self.d_B):
            discriminator.compile(loss='binary_crossentropy',
                                  optimizer=create_opt(self.d_lr),
                                  metrics=['accuracy'])
        self.d_A.trainable = False
        self.d_B.trainable = False

        for generator in (self.g_AB, self.g_BA):
            generator.compile(loss=max_margin_loss,
                              optimizer=create_opt(self.g_lr),
                              )

        # One loss/weight per combined output, in the same order as in __init__.
        self.combined.compile(loss=['binary_crossentropy', 'binary_crossentropy',
                                    'mae', 'mae',
                                    max_margin_loss, max_margin_loss,
                                    'mae', 'mae',
                                    ],
                              loss_weights=[1, 1,
                                            self.lambda_cycle * 1, self.lambda_cycle * 1,
                                            self.cycle_mm_w, self.cycle_mm_w,
                                            self.lambda_id, self.lambda_id,
                                            ],
                              optimizer=create_opt(self.g_lr))
        self.g_AB.summary()
        self.d_A.summary()
        self.combined.summary()

    @staticmethod
    def _dense_block(layer_input, hidden_dim, normalization, dropout, dropout_percentage):
        """Apply Dense(relu), then optional BatchNormalization and Dropout."""
        d = Dense(hidden_dim, activation="relu")(layer_input)
        if normalization:
            d = BatchNormalization()(d)
        if dropout:
            d = Dropout(dropout_percentage)(d)
        return d

    def _build_classifier(self, name, input_shape, hidden_dim):
        """Build a two-hidden-layer MLP ending in a float32 sigmoid validity score."""
        inpt = Input(shape=input_shape)
        d1 = self._dense_block(inpt, hidden_dim, normalization=False, dropout=True, dropout_percentage=0.3)
        d1 = self._dense_block(d1, hidden_dim, normalization=True, dropout=True, dropout_percentage=0.3)
        validity = Dense(1, activation="sigmoid", dtype='float32')(d1)
        return Model(inpt, validity, name=name)

    def build_generator(self, name, hidden_dim=2048):
        """Build an MLP generator (two relu+dropout hidden layers, linear output).

        :param name: Keras model name.
        :param hidden_dim: Units in each hidden layer.
        :return: A Model mapping one word vector to another of the same length.
        """
        inpt = Input(shape=self.embeddings_dimensionality)
        encoder = self._dense_block(inpt, hidden_dim, normalization=False, dropout=True, dropout_percentage=0.2)
        decoder = self._dense_block(encoder, hidden_dim, normalization=False, dropout=True, dropout_percentage=0.2)
        output = Dense(self.word_vector_dimensions)(decoder)
        return Model(inpt, output, name=name)

    def build_discriminator(self, name, hidden_dim=2048):
        """Build a discriminator that scores a single word vector."""
        return self._build_classifier(name, self.embeddings_dimensionality, hidden_dim)

    def build_c_discriminator(self, name, hidden_dim=2048):
        """Build a discriminator that scores two concatenated word vectors.

        The input width is derived from ``word_vector_dimensions``. It used to
        be hard-coded to 600, which only matched 300-dimensional vectors.
        """
        return self._build_classifier(name, (2 * self.word_vector_dimensions,), hidden_dim)

    def _model_files(self):
        """Return ``(model, filename suffix)`` pairs used for saving and loading."""
        return ((self.d_A, "fromretrodis.h5"),
                (self.d_B, "toretrodis.h5"),
                (self.g_AB, "toretrogen.h5"),
                (self.g_BA, "fromretrogen.h5"),
                (self.combined, "combined_model.h5"))

    def load_weights(self, preface="", folder=None):
        """Best-effort load of every sub-model's weights from ``folder``.

        :param preface: Filename prefix (e.g. ``"checkpoint"`` or ``"final"``).
        :param folder: Defaults to ``self.save_folder``.

        Any exception is printed and swallowed. Models loaded before the
        failure keep their new weights.
        """
        if folder is None:
            folder = self.save_folder
        try:
            for model in (self.g_AB, self.g_BA, self.combined, self.d_B, self.d_A):
                model.reset_states()
            for model, suffix in self._model_files():
                model.load_weights(os.path.join(folder, preface + suffix))
        except Exception as e:
            print(e)

    def train(self, epochs, dataset, save_folder, name, batch_size=1, cache=False, epochs_per_checkpoint=4,
              dis_train_amount=3):
        """Train all sub-models and log progress to Weights & Biases.

        :param epochs: Number of passes over the data.
        :param dataset: Mapping with at least ``"original"`` and
            ``"retrofitted"`` entries. It is passed to
            ``tools.load_all_words_dataset_final`` and to ``test``.
        :param save_folder: wandb run directory and dataset cache folder. Models
            are saved to ``self.save_folder``.
        :param name: wandb run name.
        :param batch_size: Rows per batch.
        :param cache: Forwarded to ``tools.load_all_words_dataset_final``.
        :param epochs_per_checkpoint: Save a ``"checkpoint"`` model every this
            many epochs (never at epoch 0).
        :param dis_train_amount: Discriminator update rounds per batch.
        :return: List of ``(simlex, simverb, card)`` scores, one per epoch plus
            a final entry.
        """
        wandb.init(project="retrogan", dir=save_folder)
        wandb.run.name = name
        wandb.run.save()
        self.name = name
        start_time = datetime.datetime.now()
        res = []
        X_train, Y_train = tools.load_all_words_dataset_final(dataset["original"], dataset["retrofitted"],
                                                              save_folder=save_folder, cache=cache)
        print("Shapes of training data:",
              X_train.shape,
              Y_train.shape)
        print(X_train)
        print(Y_train)
        print("*" * 100)

        def load_batch(batch_size=32, always_random=False):
            """Yield float32 (A, B) batches; batches containing NaNs are dropped."""

            def _int_load():
                iterable = list(Y_train.index)
                shuffle(iterable)
                batches = []
                print("Prefetching batches")
                for ndx in tqdm(range(0, len(iterable), batch_size)):
                    try:
                        ixs = iterable[ndx:min(ndx + batch_size, len(iterable))]
                        if always_random:
                            ixs = list(np.array(iterable)[random.sample(range(0, len(iterable)), batch_size)])
                        imgs_A = X_train.loc[ixs]
                        imgs_B = Y_train.loc[ixs]
                        if np.isnan(imgs_A).any().any() or np.isnan(imgs_B).any().any():
                            continue
                        batches.append((imgs_A, imgs_B))
                    except Exception:
                        # Best-effort. Presumably this is for indices missing
                        # from X_train (TODO confirm).
                        print("Skipping batch")
                return batches

            batches = _int_load()

            print("Beginning iteration")
            for i in tqdm(range(0, len(batches)), ncols=30):
                imgs_A, imgs_B = batches[i]
                yield np.array(imgs_A.values, dtype=np.float32), np.array(imgs_B.values, dtype=np.float32)

        def cycle_cond_loss(d_ground, d_approx):
            # Binary cross-entropy with d_ground labelled real and d_approx labelled fake.
            l = tf.math.log(d_ground) + tf.math.log(1 - d_approx)
            return -1 * tf.reduce_mean(l)

        self.compile_all("adam")

        def train_(training_epochs, always_random=False):
            global_step = 0
            for epoch in range(training_epochs):
                for batch_i, (imgs_A, imgs_B) in enumerate(load_batch(batch_size, always_random=always_random)):
                    global_step += 1

                    fake_B = self.g_AB.predict(imgs_A)
                    fake_A = self.g_BA.predict(imgs_B)
                    fake_ABBA = self.g_BA.predict(fake_B)
                    fake_BAAB = self.g_AB.predict(fake_A)

                    # Train the discriminators (original vectors = real / translated = fake).
                    dA_loss = None
                    dB_loss = None
                    valid = np.ones((imgs_A.shape[0],))
                    fake = np.zeros((imgs_A.shape[0],))

                    for _ in range(int(dis_train_amount)):
                        dA_loss_real = self.d_A.train_on_batch(imgs_A, valid)
                        dA_loss_fake = self.d_A.train_on_batch(fake_A, fake)
                        if dA_loss is None:
                            dA_loss = 0.5 * np.add(dA_loss_real, dA_loss_fake)
                        else:
                            dA_loss += 0.5 * np.add(dA_loss_real, dA_loss_fake)
                        dB_loss_real = self.d_B.train_on_batch(imgs_B, valid)
                        dB_loss_fake = self.d_B.train_on_batch(fake_B, fake)
                        if dB_loss is None:
                            dB_loss = 0.5 * np.add(dB_loss_real, dB_loss_fake)
                        else:
                            dB_loss += 0.5 * np.add(dB_loss_real, dB_loss_fake)
                    # [loss, accuracy] averaged over rounds and over both discriminators.
                    d_loss = (1.0 / dis_train_amount) * 0.5 * np.add(dA_loss, dB_loss)

                    d_cycle_dis = 0
                    g_cycle_dis = 0
                    if self.cycle_dis:
                        # Cycle-conditional discriminators: (translation, original)
                        # is "real", (translation, reconstruction) is "fake".
                        with tf.GradientTape() as tape:
                            dA = self.d_ABBA(tf.concat([fake_B, imgs_A], 1))
                            dA_r = self.d_ABBA(tf.concat([fake_B, fake_ABBA], 1))
                            la = cycle_cond_loss(dA, dA_r)
                        tga = tape.gradient(la, self.d_ABBA.trainable_variables)
                        self.d_ABBA.optimizer.apply_gradients(zip(tga, self.d_ABBA.trainable_variables))
                        d_cycle_dis += la

                        with tf.GradientTape() as tape:
                            dB = self.d_BAAB(tf.concat([fake_A, imgs_B], 1))
                            dB_r = self.d_BAAB(tf.concat([fake_A, fake_BAAB], 1))
                            lb = cycle_cond_loss(dB, dB_r)
                        tgb = tape.gradient(lb, self.d_BAAB.trainable_variables)
                        self.d_BAAB.optimizer.apply_gradients(zip(tgb, self.d_BAAB.trainable_variables))
                        d_cycle_dis += lb

                        # Generator step through the cycle-conditional discriminators.
                        # Only self.combined's trainable variables (the generators)
                        # are updated.
                        # NOTE(review): the generators minimise the same loss the
                        # discriminators minimise, so they help rather than fool
                        # them. Confirm this is intended.
                        with tf.GradientTape() as tape:
                            fake_B = self.g_AB(imgs_A)
                            fake_A = self.g_BA(imgs_B)
                            fake_ABBA = self.g_BA(fake_B)
                            fake_BAAB = self.g_AB(fake_A)
                            dB = self.d_BAAB(tf.concat([fake_A, imgs_B], 1))
                            dB_r = self.d_BAAB(tf.concat([fake_A, fake_BAAB], 1))
                            dA = self.d_ABBA(tf.concat([fake_B, imgs_A], 1))
                            dA_r = self.d_ABBA(tf.concat([fake_B, fake_ABBA], 1))
                            la = cycle_cond_loss(dA, dA_r)
                            lb = cycle_cond_loss(dB, dB_r)
                            g_cond_loss = (la + lb) / 2.0
                        tga = tape.gradient(g_cond_loss, self.combined.trainable_variables)
                        self.combined.optimizer.apply_gradients(zip(tga, self.combined.trainable_variables))
                        g_cycle_dis += g_cond_loss

                    # Direct (one-way) max-margin training for A->B and B->A.
                    mm_b_loss = 0
                    mm_a_loss = 0
                    if self.one_way_mm:
                        mm_a_loss = self.g_AB.train_on_batch(imgs_A, imgs_B)
                        mm_b_loss = self.g_BA.train_on_batch(imgs_B, imgs_A)

                    # Targets in the same order as the combined model's outputs.
                    g_loss = self.combined.train_on_batch([imgs_A, imgs_B],
                                                          [valid, valid,
                                                           imgs_A, imgs_B,
                                                           imgs_A, imgs_B,
                                                           imgs_A, imgs_B,
                                                           ])

                    elapsed_time = datetime.datetime.now() - start_time
                    if batch_i % 50 == 0 and batch_i != 0:
                        # g_loss = [total, bce_A, bce_B, mae_A, mae_B, mm_A, mm_B, id_A, id_B]
                        print(
                            "\n[Epoch %d/%d] [Batch %d] [D loss: %f, acc: %3d%%] "
                            "[G loss: %05f, adv: %05f, recon: %05f, recon_mm: %05f,id: %05f][mma:%05f,mmb:%05f]time: %s "
                            % (epoch, training_epochs,
                               batch_i,
                               d_loss[0], 100 * d_loss[1],
                               g_loss[0],
                               np.mean(g_loss[1:3]),
                               np.mean(g_loss[3:5]),
                               np.mean(g_loss[5:7]),
                               np.mean(g_loss[7:9]),
                               mm_a_loss,
                               mm_b_loss,
                               elapsed_time))

                        scalars = {
                            "epoch": epoch,
                            "global_step": global_step,
                            "discriminator_loss": d_loss[0],
                            "discriminator_acc": d_loss[1],
                            "combined_loss": g_loss[0] + g_cycle_dis + d_cycle_dis,
                            "loss": g_loss[0] + d_loss[0],
                            "cycle_da": g_loss[1],
                            "cycle_db": g_loss[2],
                            "cycle_dis": d_cycle_dis,
                            "cycle_gen_condis": g_cycle_dis,
                            "MM_ABBA_CYCLE": g_loss[5],
                            "MM_BAAB_CYCLE": g_loss[6],
                            "abba_mae": g_loss[3],
                            "baab_mae": g_loss[4],
                            "idloss_ab": g_loss[7],
                            "idloss_ba": g_loss[8],
                            "mm_ab_loss": mm_a_loss,
                            "mm_ba_loss": mm_b_loss,
                        }
                        wandb.log(scalars, step=global_step)

                print("\n")
                sl, sv, c = self.test(dataset)
                if epoch % epochs_per_checkpoint == 0 and epoch != 0:
                    self.save_model(name="checkpoint")

                res.append((sl, sv, c))
                wandb.log({"simlex": sl, "simverb": sv, "card": c, "epoch": epoch})

                print(res)
                print("\n")

        print("Actual training")
        train_(epochs)
        print("Final performance")
        sl, sv, c = self.test(dataset)
        res.append((sl, sv, c))

        self.save_model(name="final")
        return res

    def test(self, dataset, simlex="testing/SimLex-999.txt", simverb="testing/SimVerb-3500.txt",
             card="testing/card660.tsv",
             fasttext="fasttext_model/cc.en.300.bin",
             prefix="en_"):
        """Score ``g_AB`` on three word-similarity benchmarks.

        :return: ``(simlex, simverb, card)``, i.e. element 0 of each
            ``tools.test_sem`` result.
        """
        scores = [
            tools.test_sem(self.g_AB, dataset, dataset_location=location,
                           fast_text_location=fasttext, prefix=prefix, pt=False)[0]
            for location in (simlex, simverb, card)
        ]
        sl, sv, c = scores
        return sl, sv, c

    def save_model(self, name=""):
        """Save every sub-model (without optimizer state) to ``self.save_folder``.

        :param name: Filename prefix (e.g. ``"checkpoint"`` or ``"final"``).
        """
        for model, suffix in self._model_files():
            model.save(os.path.join(self.save_folder, name + suffix), include_optimizer=False)
Example #4
0
(train_x1, train_x2, train_fov, train_y, test_x1, test_x2, test_fov, test_y,
 images) = get_dataset()
train_batch_size = 64


def _sample_batch(x1, x2, fov, labels, batch_size):
    """Draw a random batch, with replacement, from one data split.

    ``x1`` and ``x2`` hold indices into the shared ``images`` array. The chosen
    images are divided by 255 (presumably 8-bit pixels scaled to [0, 1]; verify
    against ``get_dataset``). ``fov`` and ``labels`` are indexed with the same
    random indices, so every input comes from the same split as its labels.

    :return: ``([images_x1, images_x2, fov], labels)`` for ``train_on_batch``/``test_on_batch``.
    """
    idx = np.random.randint(0, len(x1), batch_size)
    images_x1 = images[x1[idx]] / 255.
    images_x2 = images[x2[idx]] / 255.
    return [images_x1, images_x2, fov[idx]], labels[idx]


# train
# NOTE(review): sum_logs grows by one entry per step and is never cleared here.
# Over this many iterations it will exhaust memory unless it is reset elsewhere.
sum_logs = []
for batch in range(50000001):
    batch_x, batch_y = _sample_batch(train_x1, train_x2, train_fov, train_y,
                                     train_batch_size)
    logs = model.train_on_batch(x=batch_x, y=batch_y)
    sum_logs.append(logs)

    if batch % 200 == 0 and batch > 0:
        # Check the model on a validation batch. This previously used
        # train_fov with test indices, which mixed splits and could go out of range.
        valid_x, valid_y = _sample_batch(test_x1, test_x2, test_fov, test_y,
                                         train_batch_size)
        v_loss = model.test_on_batch(x=valid_x, y=valid_y)
Example #5
0
class FMatrixGanModel:
    """
    Defines the complete model with generator, regressor and discriminator.
    This includes the low level training and prediction methods for this model, like the GAN training.
    """

    @staticmethod
    def _adam(lr):
        """Creates an Adam optimizer with the settings shared by all sub models of this class."""
        return Adam(lr=lr, beta_1=0.9, beta_2=0.999, epsilon=1e-08)

    def __init__(self, params, model_folder, img_size):
        """
        Inits the model.

        :param params: Hyperparameters (a Params instance, or a value that is wrapped with Params()).
        :param model_folder: Folder path, in which all results and temporary data of the model is stored.
        :param img_size: (image_width, image_height), defining the size of the input images.
        """
        if not isinstance(params, Params):
            params = Params(params)
        self.params = params

        self.model_folder = model_folder

        # inputs: single-channel images
        input_shape = (img_size[0], img_size[1], 1)
        img_A, img_B = Input(shape=input_shape), Input(shape=input_shape)

        # --- build models
        discriminator_model, frozen_discriminator_model = build_discriminator_models(
            img_size, params)

        generator_with_regressor_model, generator_model, generator_with_output_model, regressor_model = \
            build_generator_with_regressor_models(img_size, params)

        # --- models
        self.discriminator = discriminator_model
        self.regressor = regressor_model
        self.generator = generator_model
        self.generator_with_output = generator_with_output_model
        self.generator_with_regressor = generator_with_regressor_model

        # model: GAN without regressor and without output
        fake_B = generator_model(img_A)
        gan_out = frozen_discriminator_model(fake_B)
        self.gan = Model(inputs=img_A, outputs=gan_out)

        # model: GAN with regressor; if params['use_images'] the regressor also receives the image pair
        gan_inputs = [img_A, img_B] if params['use_images'] else img_A
        fake_B, fmatrix = generator_with_regressor_model(gan_inputs)
        gan_out = frozen_discriminator_model(fake_B)
        self.gan_with_regressor = Model(inputs=gan_inputs,
                                        outputs=[gan_out, fmatrix])

        # --- compile models
        self.discriminator.compile(loss='binary_crossentropy',
                                   optimizer=self._adam(params['lr_D']),
                                   metrics=['accuracy'])
        self.regressor.compile(loss='mean_squared_error',
                               optimizer=self._adam(params['lr_R']))

        # generators do not need to be compiled as they are compiled within the GANs

        # Keras applies `trainable` at compile time: the discriminator compiled above
        # stays trainable, the GANs compiled below use the (optionally) frozen copy.
        if params['freeze_discriminator']:
            frozen_discriminator_model.trainable = False

        self.gan.compile(loss='binary_crossentropy',
                         optimizer=self._adam(params['lr_G']))

        # one weight for the adversarial loss, one for the fundamental matrix loss
        loss_weights = params['generator_loss_weights']
        assert len(loss_weights) == 2
        self.gan_with_regressor.compile(
            loss=['binary_crossentropy', 'mean_squared_error'],
            loss_weights=loss_weights,
            optimizer=self._adam(params['lr_G']))

        # sub models whose weights are saved/loaded; the list position is the weight file index
        self.__models_with_weights = [
            self.generator, self.regressor, self.discriminator
        ]

    def generate_img(self, img_A):
        """
        Generates an image from img_A using the generator and its current weights.
        :param img_A: Input image to the generator. Dimension: (img_width, img_height)
        :return: The generated image
        """
        img_A = _img_to_img_batch(img_A)
        img_B = self.generator.predict(img_A)
        return img_B[0]  # only 1 sample in batch

    def generate_regression_input(self, img_pair):
        """
        Generates the regression input for the given image pair using the current weights.

        :param img_pair: Input image pair. Dimension: (img_width, img_height, 2)
        :return: The regression input which can be passed into the regressor.
            This is a list of inputs which may include the image pair, the bottleneck and the derived feature layers.
        """
        img_A, _ = _img_pair_to_img_batches(img_pair)
        generator_output, *regression_input = self.generator_with_output.predict(
            img_A)
        # for each of the elements in regression input only select the first sample (there is only 1 sample)
        # because batches are returned
        regression_input = [batch[0] for batch in regression_input]
        if self.params['use_images']:
            regression_input.append(img_pair)

        return regression_input

    def regress_from_img(self, img_pair):
        """
        Regresses the fundamental matrix from the given image pair using the current weights.

        :param img_pair: Input image pair. Dimension: (img_width, img_height, 2)
        :return: Fundamental matrix. Dimension: (3, 3)
        """
        img_A, img_B = _img_pair_to_img_batches(img_pair)

        if self.params['use_images']:
            # as the regressor also uses images, imgA and imgB are needed
            _, F = self.generator_with_regressor.predict([img_A, img_B])
        else:
            # the regressor uses no images, only input from generator and the generator only needs imgA
            _, F = self.generator_with_regressor.predict(img_A)
        return F[0]  # only 1 sample in batch

    def regress_from_regression_input(self, regression_input):
        """
        Regresses the fundamental matrix from the given regression inputs using the current weights (of the regressor)

        :param regression_input: Regression input which was generated using generate_regression_input.
        :return: Fundamental matrix. Dimension: (3, 3)
        """
        regression_input = _regression_input_to_batch(regression_input)
        F_batch = self.regress_from_regression_input_batch(regression_input)
        return F_batch[0]  # only 1 sample in batch

    def regress_from_regression_input_batch(self, regression_input_batch):
        """
        Regresses the fundamental matrices for a batch of regression inputs using the current weights (of the regressor)

        :param regression_input_batch: Batch of regression inputs which were generated using generate_regression_input.
        :return: Batch of fundamental matrices. Dimension: (None, 3, 3)
        """
        return self.regressor.predict(regression_input_batch)

    def train_GAN(self,
                  img_pair,
                  epochs,
                  discr_iterations,
                  plot_interval=None,
                  img_path_prefix=None,
                  check_img_mse=False,
                  verbose=0):
        """
        Trains the GAN for the given image pair.

        :param img_pair: Image pair to train the GAN.
        :param epochs: Number of training epochs.
        :param discr_iterations: Number of discriminator iterations in each epoch.
        :param plot_interval: How often is the generated image plot.
            - None: no plotting
            - positive integer: plot every nth epoch
            - -1: plot only after the last epoch.
        :param img_path_prefix: Prefix for plotted image files. If None: image is only plotted but not saved.
        :param check_img_mse: bool - Check and store the image mean squared error of the generated image in the history.
        :param verbose: Verbosity level: 0 to 2
        :return: History of the GAN training, dictionary of lists
        """
        return self.__do_train_gan_for_sample(
            img_pair,
            epochs=epochs,
            discr_iterations=discr_iterations,
            plot_interval=plot_interval,
            img_path_prefix=img_path_prefix,
            check_img_mse=check_img_mse,
            verbose=verbose)

    def train_GAN_and_regressor(self,
                                img_pair,
                                F_true,
                                epochs,
                                discr_iterations,
                                plot_interval=None,
                                img_path_prefix=None,
                                check_img_mse=False,
                                verbose=0):
        """
        Train the GAN and the regressor together using a combined loss for generator and regressor.

        :param img_pair: Image pair to train.
        :param F_true: Ground truth fundamental matrix.
        :param epochs: Number of training epochs.
        :param discr_iterations: Number of discriminator iterations in each epoch.
        :param plot_interval: How often is the generated image plot.
            - None: no plotting
            - positive integer: plot every nth epoch
            - -1: plot only after the last epoch.
        :param img_path_prefix: Prefix for plotted image files. If None: image is only plotted but not saved.
        :param check_img_mse: bool - Check and store the image mean squared error of the generated image in the history.
        :param verbose: Verbosity level: 0 to 2
        :return: History of the training, dictionary of lists
        """
        return self.__do_train_gan_for_sample(
            img_pair,
            F_true=F_true,
            epochs=epochs,
            discr_iterations=discr_iterations,
            plot_interval=plot_interval,
            img_path_prefix=img_path_prefix,
            check_img_mse=check_img_mse,
            verbose=verbose)

    def train_regressor(self, regression_input, F_true):
        """
        Trains the regressor for the given regression input and F_true.

        The regressor is only trained for one epoch on that single sample.

        :param regression_input: Regression input to train for.
        :param F_true: Ground truth fundamental matrix.
        :return: Value returned by the regressor's train_on_batch (the training loss).
        """
        return self.train_regressor_batch(
            _regression_input_to_batch(regression_input),
            _F_to_F_batch(F_true))

    def train_regressor_batch(self, regression_input_batch, F_true_batch):
        """
        Trains the regressor for the given regression input and F_true batch.

        The regressor is only trained for one epoch on that batch.

        :param regression_input_batch: Batch of regression inputs to train for.
        :param F_true_batch: Batch of Ground truth fundamental matrices.
        :return: Value returned by the regressor's train_on_batch (the training loss).
        """
        return self.regressor.train_on_batch(regression_input_batch,
                                             F_true_batch)

    def update_regressor_lr(self, update_fn):
        """
        Updates the current learning rate of the regressor using the given update function.
        The update function gets the old lr as input and should return the new lr.
        This new lr is then set as the regressor lr.

        :param update_fn: Function applied to compute new lr: update_fn(old_lr: float) -> new_lr: float
        :return: New learning rate which was set.
        """
        old_lr = float(K.get_value(self.regressor.optimizer.lr))
        new_lr = update_fn(old_lr)
        K.set_value(self.regressor.optimizer.lr, new_lr)
        return new_lr

    def _weights_file(self, file_prefix, index):
        """Returns the weight file path of the sub model at position `index` (shared by save/load)."""
        return os.path.join(self.model_folder, '%s_%d.h5' % (file_prefix, index))

    def save_weights(self, file_prefix=None):
        """
        Save all model weights.

        Multiple files will be stored, as this model has multiple sub models.
        All files are stored within the model folder.

        :param file_prefix: If defined, use this as prefix for the weight file names.
            If None: store temporary weights.
        """
        if file_prefix is None:
            file_prefix = TMP_WEIGHTS_FILE_PREFIX
        for i, model in enumerate(self.__models_with_weights):
            model.save_weights(self._weights_file(file_prefix, i))

    def load_weights(self, file_prefix=None, remove=False):
        """
        Loads all model weights.

        All files are loaded from within the model folder.

        :param file_prefix: If defined, use this as prefix for the weight file names.
            If None: load temporary weights.
        :param remove: If True, remove the loaded weight files.
        """
        if file_prefix is None:
            file_prefix = TMP_WEIGHTS_FILE_PREFIX
        for i, model in enumerate(self.__models_with_weights):
            weights_file = self._weights_file(file_prefix, i)
            model.load_weights(weights_file)
            if remove:
                os.remove(weights_file)

    def plot_models(self, file_prefix):
        """
        Plots all sub models to PNG files.

        :param file_prefix: Prefix prepended to every output file name.
        """
        print('Plotting model with file prefix %s' % file_prefix)
        plot_model(self.generator,
                   to_file=file_prefix + 'generator.png',
                   show_shapes=True)
        plot_model(self.generator_with_output,
                   to_file=file_prefix + 'generator_with_output.png',
                   show_shapes=True)
        plot_model(self.generator_with_regressor,
                   to_file=file_prefix + 'generator_with_regressor.png',
                   show_shapes=True)
        # NOTE(review): unlike the other files this name has a leading underscore
        plot_model(self.discriminator,
                   to_file=file_prefix + '_discriminator.png',
                   show_shapes=True)
        plot_model(self.regressor,
                   to_file=file_prefix + 'regressor.png',
                   show_shapes=True)
        plot_model(self.gan, to_file=file_prefix + 'gan.png', show_shapes=True)
        plot_model(self.gan_with_regressor,
                   to_file=file_prefix + 'gan_with_regressor.png',
                   show_shapes=True)

    def __do_train_gan_for_sample(self,
                                  img_pair_sample,
                                  F_true=None,
                                  epochs=1,
                                  discr_iterations=1,
                                  plot_interval=None,
                                  img_path_prefix=None,
                                  check_img_mse=False,
                                  verbose=0):
        """
        Shared training loop for train_GAN (F_true is None) and train_GAN_and_regressor.

        verbose: 0 -> no logging, 1 -> only show current epoch, 2 -> show epoch results and details.
        plot_interval: None -> disabled, >0 every i epochs, -1 only at the end of all epochs.

        :return: History dictionary of lists; the generator_loss_gen / generator_F_loss
            keys are only present when F_true is given.
        """
        img_A, img_B = _img_pair_to_img_batches(img_pair_sample)
        if F_true is not None:
            F_true = _F_to_F_batch(F_true)

        if plot_interval is not None and img_path_prefix is not None:
            # Save original images for later debugging
            _save_imgs(img_A, img_B, img_path_prefix + 'img_A.png',
                       img_path_prefix + 'img_B.png')

        valid = np.array([1])
        fake = np.array([0])

        generator_loss_history = []
        generator_gen_loss_history = []
        generator_F_history = []
        discriminator_history = []
        discriminator_real_history = []
        discriminator_fake_history = []
        img_mse_history = []

        fake_B = self.generator.predict(img_A)  # generate fake B for 1st epoch

        for epoch in range(1, epochs + 1):
            if verbose == 1:
                print(('--> GAN epoch %d/%d' % (epoch, epochs)).ljust(100),
                      end='\r')
            elif verbose > 1:
                print('--> GAN epoch %d/%d' % (epoch, epochs))

            # --- train discriminator on one batch holding the real and the fake sample
            if verbose >= 2:
                print('-----> Train D...'.ljust(100), end='\r')
            discr_input = np.concatenate([img_B, fake_B])
            discr_target = np.concatenate([valid, fake])
            discriminator_loss = float('nan')  # stays NaN if discr_iterations < 1
            for it in range(1, discr_iterations + 1):
                # the discriminator is compiled with metrics=['accuracy'],
                # so train_on_batch returns [loss, accuracy]
                discriminator_loss, discriminator_acc = self.discriminator.train_on_batch(
                    discr_input, discr_target)
                if verbose >= 2:
                    print(('-----> D iteration %d/%d [loss: %f, acc: %f]' %
                           (it, discr_iterations, discriminator_loss,
                            discriminator_acc)).ljust(100),
                          end='\r')
            # The combined training batch only yields one loss, so measure the
            # real and fake losses separately (evaluation only, no weight update).
            discriminator_loss_real = self.discriminator.test_on_batch(img_B, valid)[0]
            discriminator_loss_fake = self.discriminator.test_on_batch(fake_B, fake)[0]
            discriminator_history.append(discriminator_loss)
            discriminator_real_history.append(discriminator_loss_real)
            discriminator_fake_history.append(discriminator_loss_fake)

            # --- train generator
            if verbose >= 2:
                print('-----> Train G...'.ljust(100), end='\r')
            if F_true is None:
                generator_loss = self.gan.train_on_batch(img_A, valid)
                generator_loss_history.append(generator_loss)
                if verbose == 1:
                    print(
                        ('--> GAN epoch %d/%d [D - loss: %f] [G - loss: %f]' %
                         (epoch, epochs, discriminator_loss,
                          generator_loss)).ljust(100),
                        end='\r')
                elif verbose > 1:
                    print(('---> [D - loss: %f] [G - loss: %f]' %
                           (discriminator_loss, generator_loss)).ljust(100))
            else:
                gan_input = [img_A, img_B] if self.params['use_images'] else img_A
                # returns [total loss, adversarial loss, fundamental matrix loss]
                loss, generator_loss, fmatrix_loss = self.gan_with_regressor.train_on_batch(
                    gan_input, [valid, F_true])
                generator_loss_history.append(loss)
                generator_gen_loss_history.append(generator_loss)
                generator_F_history.append(fmatrix_loss)
                if verbose == 1:
                    print((
                        '--> GAN epoch %d/%d [D - loss: %f] [G - loss: %f, gen_loss: %f, F_loss: %f]'
                        % (epoch, epochs, discriminator_loss, loss,
                           generator_loss, fmatrix_loss)).ljust(100),
                          end='\r')
                elif verbose > 1:
                    print((
                        '--->  [D - loss: %f] [G - loss: %f, gen_loss: %f, F_loss: %f]'
                        % (discriminator_loss, loss, generator_loss,
                           fmatrix_loss)).ljust(100))

            # Generate for next epoch and for results checking (so that the img has not to be generated twice)
            fake_B = self.generator.predict(img_A)

            if check_img_mse:
                img_mse = _calc_image_mse(img_A, img_B, fake_B)
                img_mse_history.append(img_mse)
                if verbose >= 2:
                    print('---> [image_mse: %f]' % img_mse)

            if plot_interval is not None and plot_interval != -1 and epoch % plot_interval == 0:
                if img_path_prefix is not None:
                    img_path = img_path_prefix + ('generated_B_%04d.png' %
                                                  epoch)
                else:
                    img_path = None
                _plot_img(img_A, img_B, fake_B, img_path)

        if plot_interval == -1:
            if img_path_prefix is not None:
                img_path = img_path_prefix + 'generated_B.png'
            else:
                img_path = None
            _plot_img(img_A, img_B, fake_B, img_path)

        if F_true is None:
            return {
                'discriminator_loss': discriminator_history,
                'discriminator_loss_real': discriminator_real_history,
                'discriminator_loss_fake': discriminator_fake_history,
                'generator_loss': generator_loss_history,
                'img_mse': img_mse_history
            }
        return {
            'discriminator_loss': discriminator_history,
            'discriminator_loss_real': discriminator_real_history,
            'discriminator_loss_fake': discriminator_fake_history,
            'generator_loss': generator_loss_history,
            'generator_loss_gen': generator_gen_loss_history,
            'generator_F_loss': generator_F_history,
            'img_mse': img_mse_history
        }
Exemple #6
0
class DCGan:
    """
    DCGAN: a generator mapping latent noise vectors to images and a convolutional
    discriminator classifying images as real (1) or generated (0).
    """

    def __init__(self, input_shape, config):
        """
        :param input_shape: Shape of one image; the last axis is treated as the channel axis.
        :param config: Nested config dict (keys 'data', 'model', 'training').
            config['model']['generator']['z_size'] is overwritten with config['data']['z_size'].
        """
        self.config = config
        # the generator builder reads the latent size from its own config section
        self.config['model']['generator']['z_size'] = self.config['data']['z_size']
        self.input_shape = input_shape

        self._build_model()

    def _build_model(self):
        """Builds generator, discriminator and the stacked GAN model (no compilation)."""
        # Generator
        model_input = Input(self.config['data']['z_size'], name='gan_input')
        self.generator = DCGan._build_generator(model_input, self.input_shape, self.config['model']['generator'])

        # Discriminator
        self.discriminator = DCGan._build_discriminator(self.input_shape, self.config['model']['discriminator'])

        # GAN: noise -> generator -> discriminator
        model_output = self.discriminator(self.generator(model_input))
        self.gan = Model(model_input, model_output)

        # NOTE: compiling is intentionally disabled. The discriminator was meant to be
        # compiled (binary crossentropy, RMSprop) while trainable, then set
        # trainable=False before compiling the GAN, because `trainable` is only taken
        # into account at compile time. Compiling the nested discriminator triggers
        # https://github.com/keras-team/keras/issues/10074, see `train` / `_train`.

    # Already with a basic GAN setup we break the use of model.fit and related utilities
    # need to find a way to manage callbacks and validation

    # This still doesn't work for a problem possibly related to a bug with nested models https://github.com/keras-team/keras/issues/10074
    # Not compiling the discriminator in fact doesn't trigger the error anymore
    def train(self, train_ds, validation_ds, nb_epochs: int, log_dir, checkpoint_dir, is_tfdataset=False):
        """
        Keras train_on_batch based training loop (one random real batch per epoch).

        NOTE(review): the callbacks list is built but never passed to any training call.
        """
        callbacks = []

        # tensorboard
        tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)
        callbacks.append(tensorboard_callback)

        # checkpoints
        if checkpoint_dir:
            cp_callback = tf.keras.callbacks.ModelCheckpoint(checkpoint_dir,
                                                             save_weights_only=True,
                                                             verbose=1,
                                                             period=self.config['training']['checkpoint_steps'])
            callbacks.append(cp_callback)

        # plotting callback
        plot_callback = PlotData(validation_ds, self.generator, log_dir)
        callbacks.append(plot_callback)

        # training
        batch_size = self.config['training']['batch_size']
        z_dim = self.config['data']['z_size']
        for _ in range(nb_epochs):
            if is_tfdataset:
                # only the first batch of the dataset is used in each epoch
                for x in train_ds:
                    train_batch = x.numpy()
                    break
            else:
                idx = np.random.randint(0, train_ds.shape[0], batch_size)
                train_batch = train_ds[idx]

            self.train_discriminator(train_batch, batch_size, z_dim)
            self.train_generator(batch_size, z_dim)

            # TODO add validation step

    # Train with pure TF, because Keras doesn't work
    def _train(self, train_ds, validation_ds, nb_epochs: int, log_dir, checkpoint_dir, is_tfdataset=False,
               restore_latest_checkpoint=True):
        """
        Custom TensorFlow training loop.

        :param train_ds: Iterable of real image batches (e.g. from `setup_dataset`).
        :param validation_ds: Currently unused.
        :param nb_epochs: Number of passes over `train_ds`.
        :param log_dir: Path-like supporting `/`; summaries go to log_dir/'train' and log_dir/'plot'.
        :param checkpoint_dir: If truthy, checkpoints are managed in checkpoint_dir/ckpt.
        :param is_tfdataset: Currently unused.
        :param restore_latest_checkpoint: Resume from the latest checkpoint if one exists.
        """
        batch_size = self.config['training']['batch_size']
        z_dim = self.config['data']['z_size']

        # fixed noise so the plotted samples are comparable across epochs
        noise = tf.random.normal([batch_size, z_dim])
        plot_summary_writer = tf.summary.create_file_writer(str(log_dir / 'plot'))
        train_summary_writer = tf.summary.create_file_writer(str(log_dir / 'train'))

        # optimizers
        generator_optimizer = tf.keras.optimizers.Adam(self.config['training']['generator']['learning_rate'])
        discriminator_optimizer = tf.keras.optimizers.Adam(self.config['training']['discriminator']['learning_rate'])

        # checkpoints; `step` counts epochs and is stored so training can resume
        if checkpoint_dir:
            checkpoint = tf.train.Checkpoint(step=tf.Variable(0),
                                             generator_optimizer=generator_optimizer,
                                             discriminator_optimizer=discriminator_optimizer,
                                             generator=self.generator,
                                             discriminator=self.discriminator)
            ckpt_manager = tf.train.CheckpointManager(checkpoint, os.path.join(checkpoint_dir, "ckpt"),
                                                      max_to_keep=self.config['training']['checkpoints_to_keep'])
            if restore_latest_checkpoint and ckpt_manager.latest_checkpoint:
                checkpoint.restore(ckpt_manager.latest_checkpoint)
                print(f"Restored from {ckpt_manager.latest_checkpoint}")
            else:
                print("Initializing from scratch.")

        # train loop
        for epoch in tqdm(range(nb_epochs)):
            gen_losses = []
            disc_losses = []
            ds_batch = None  # stays None if train_ds yields no batches
            for ds_batch in train_ds:
                gen_loss, disc_loss = train_step(ds_batch, self.generator, self.discriminator,
                                                 generator_optimizer=generator_optimizer,
                                                 discriminator_optimizer=discriminator_optimizer,
                                                 batch_size=batch_size, noise_dim=z_dim)
                gen_losses.append(gen_loss)
                disc_losses.append(disc_loss)

            # Loss summary
            avg_gen_loss = np.mean(gen_losses)
            avg_disc_loss = np.mean(disc_losses)
            with train_summary_writer.as_default():
                tf.summary.scalar("Average Gen Loss", avg_gen_loss, step=epoch)
                tf.summary.scalar("Average Disc Loss", avg_disc_loss, step=epoch)

            # Plot data
            with plot_summary_writer.as_default():
                # Plot sample data
                predictions = self.generator(noise)
                tf.summary.image("Sample Generated", predictions, step=epoch)
                if ds_batch is not None:
                    tf.summary.image("Sample Input", [ds_batch[np.random.randint(len(ds_batch))]], step=epoch)

            # checkpoint
            if checkpoint_dir:
                checkpoint.step.assign_add(1)
                ckpt_step = int(checkpoint.step)
                if ckpt_step % self.config['training']['checkpoint_steps'] == 0:
                    save_path = ckpt_manager.save()
                    print(f"Saved checkpoint for step {ckpt_step}: {save_path}")

    @staticmethod
    def _build_generator(model_input, img_shape, config):
        """
        Builds the generator: latent vector -> dense -> reshape -> upscale blocks -> image.

        NOTE: `model_input` is not used; the generator creates its own input layer
        of size config['z_size'].

        :param img_shape: Target image shape; the last axis is the channel axis.
        :param config: Generator config section.
        :return: Keras Model mapping latent vectors to images (tanh output).
        """
        latent_vector = Input(config['z_size'], name='generator_input')
        init_shape = tuple([get_initial_size(d, config['num_conv_blocks'])
                            for d in img_shape[:-1]] + [config['init_filters']])

        x = Dense(np.prod(init_shape))(latent_vector)
        x = BatchNormalization()(x)
        x = LeakyReLU()(x)
        x = Reshape(init_shape)(x)

        # each block halves the filter count
        for i in range(config['num_conv_blocks'] - 1):
            x = upscale(filters=config['init_filters'] // 2 ** i,
                        kernel_size=config['kernel_size'], strides=config['strides'],
                        upscale_method=config['upscale_method'],
                        activation='relu')(x)

        # last upscale layer
        model_output = upscale(filters=config['n_channels'],
                               kernel_size=config['kernel_size'], strides=config['strides'],
                               upscale_method=config['upscale_method'],
                               activation='tanh')(x)

        return Model(latent_vector, model_output)

    @staticmethod
    def _build_discriminator(img_shape, config):
        """Builds the discriminator: conv blocks (filters doubling each block) -> flatten -> sigmoid score."""
        model_input = Input(shape=img_shape, name="discriminator_input")

        x = model_input
        for i in range(config['num_conv_blocks']):
            x = conv(filters=config['init_filters'] * (2 ** i), kernel_size=config['kernel_size'],
                     strides=config['strides'])(x)

        features = Flatten()(x)

        model_output = Dense(1, activation='sigmoid')(features)

        return Model(model_input, model_output)

    def train_discriminator(self, true_imgs, batch_size: int, z_dim: int):
        """One discriminator update on real images (label 1), then one on generated images (label 0)."""
        # Train on real image
        # [1,1,...,1] with real output since it is true and we want our generated examples to look like it
        self.discriminator.train_on_batch(true_imgs, np.ones((batch_size, 1)))

        # Train on generated images
        # [0,0,...,0] with generated images since they are fake
        noise = np.random.normal(0, 1, (batch_size, z_dim))
        gen_imgs = self.generator.predict(noise)
        self.discriminator.train_on_batch(gen_imgs, np.zeros((batch_size, 1)))

    def train_generator(self, batch_size: int, z_dim: int):
        """One GAN update on fresh noise, labelled as real."""
        # Train on noise input
        # [1,1,...,1] with generated output since we want the discriminator to believe these are real images
        noise = np.random.normal(0, 1, (batch_size, z_dim))
        self.gan.train_on_batch(noise, np.ones((batch_size, 1)))

    def setup_dataset(self, dataset):
        """Shuffles, batches (dropping the remainder) and prefetches a tf.data dataset."""
        # prefetch lets the dataset fetch batches in the background while the model is training
        dataset = dataset.shuffle(self.config['data']['buffer_size']) \
                            .batch(self.config['training']['batch_size'], drop_remainder=True) \
                            .prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
        return dataset
Exemple #7
0
class FusionModel:
    """GAN that fuses instance-level and full-image features to predict ab channels.

    Three sub-networks are built by project helpers:

    * ``instance_network`` builds ``foreground_generator``.
    * ``fusion_network`` builds ``fusion_generator``.
    * ``discriminator_network`` builds ``fusion_discriminator``.

    They are wired into two trainable Keras models:

    * ``discriminator_model`` scores real ab maps, predicted ab maps, and
      ``RandomWeightedAverage`` mixes of the two. The mixes feed a
      gradient-penalty loss. It is compiled with the fusion generator frozen.
    * ``combined`` is the generator objective: ab regression (``ab_loss``),
      a KLD term on the class vector, and an adversarial term. It is
      compiled with the discriminator frozen.
    """

    def __init__(self, config, load_weight_path=None, ab_loss='mse'):
        """Build, optionally warm-start, and compile all sub-models.

        Args:
            config: Settings object. This constructor reads ``IMAGE_SIZE``,
                ``BATCH_SIZE`` and ``LOG_DIR``.
            load_weight_path: Optional path to a saved ChromaGAN model. Its
                layer weights are copied by layer name into the instance
                network (``fg_``-prefixed names) and the fusion network.
            ab_loss: Keras loss used for the ab-channel regression output.
        """
        img_shape = (config.IMAGE_SIZE, config.IMAGE_SIZE)

        # Creating generator and discriminator.
        # NOTE(review): a single optimizer instance is shared by all four
        # compiled models. Newer Keras optimizers may refuse reuse across
        # models with different variable sets; confirm for the TF version used.
        optimizer = Adam(0.00002, 0.5)

        self.foreground_generator = instance_network(img_shape)

        self.fusion_discriminator = discriminator_network(img_shape)
        self.fusion_discriminator.compile(loss=wasserstein_loss_dummy,
                                          optimizer=optimizer)
        self.fusion_generator = fusion_network(img_shape, config.BATCH_SIZE)
        self.fusion_generator.compile(loss=[ab_loss, 'kld'],
                                      optimizer=optimizer)

        if load_weight_path:
            chroma_gan = load_model(load_weight_path)

            print('Loading chroma GAN parameter to instance network...')
            # Instance-network layers are named 'fg_' + <ChromaGAN layer name>;
            # only those layers are transferred.
            self._transfer_weights(chroma_gan, self.foreground_generator,
                                   skip_name='fg_model_3', prefix='fg_')

            print('Loading chroma GAN parameter to fusion network...')
            self._transfer_weights(chroma_gan, self.fusion_generator,
                                   skip_name='model_3')

        # Fg=instance prediction
        fg_img_l = Input(shape=(*img_shape, 1, MAX_INSTANCES))

        # NOTE(review): foreground_generator is never frozen, so it is
        # presumably also updated when discriminator_model is trained below.
        # Confirm this is intended.
        fg_model_3, fg_conv2d_11, fg_conv2d_13, fg_conv2d_15, fg_conv2d_17 = self.foreground_generator(
            fg_img_l)

        # Fusion prediction
        fusion_img_l = Input(shape=(*img_shape, 1))
        fusion_img_real_ab = Input(shape=(*img_shape, 2))
        fg_bbox = Input(shape=(4, MAX_INSTANCES))
        fg_mask = Input(shape=(*img_shape, MAX_INSTANCES))

        # Freeze the fusion generator while discriminator_model is compiled.
        # This relies on Keras picking up trainable flags at compile time.
        self.fusion_generator.trainable = False
        fusion_img_pred_ab, fusion_class_vec = self.fusion_generator([
            fusion_img_l, fg_model_3, fg_conv2d_11, fg_conv2d_13, fg_conv2d_15,
            fg_conv2d_17, fg_bbox, fg_mask
        ])

        # The discriminator is conditioned on the L channel in every call.
        dis_pred_ab = self.fusion_discriminator(
            [fusion_img_pred_ab, fusion_img_l])
        dis_real_ab = self.fusion_discriminator(
            [fusion_img_real_ab, fusion_img_l])

        # Sample the gradient penalty on mixes of predicted and real ab.
        img_ab_interp_samples = RandomWeightedAverage()(
            [fusion_img_pred_ab, fusion_img_real_ab])
        dis_interp_ab = self.fusion_discriminator(
            [img_ab_interp_samples, fusion_img_l])
        partial_gp_loss = partial(
            gradient_penalty_loss,
            averaged_samples=img_ab_interp_samples,
            gradient_penalty_weight=GRADIENT_PENALTY_WEIGHT)
        # functools.partial objects have no __name__, which Keras uses to
        # name losses.
        partial_gp_loss.__name__ = 'gradient_penalty'

        # Compile D and G as well as combined
        self.discriminator_model = Model(
            inputs=[
                fusion_img_l, fusion_img_real_ab, fg_img_l, fg_bbox, fg_mask
            ],
            outputs=[dis_real_ab, dis_pred_ab, dis_interp_ab])

        self.discriminator_model.compile(optimizer=optimizer,
                                         loss=[
                                             wasserstein_loss_dummy,
                                             wasserstein_loss_dummy,
                                             partial_gp_loss
                                         ],
                                         loss_weights=[-1.0, 1.0, 1.0])

        # Swap which side is frozen before compiling the generator objective.
        self.fusion_generator.trainable = True
        self.fusion_discriminator.trainable = False
        self.combined = Model(
            inputs=[fusion_img_l, fg_img_l, fg_bbox, fg_mask],
            outputs=[fusion_img_pred_ab, fusion_class_vec, dis_pred_ab])
        self.combined.compile(loss=[ab_loss, 'kld', wasserstein_loss_dummy],
                              loss_weights=[1.0, 0.003, -0.1],
                              optimizer=optimizer)

        # Monitor stuff: these names label the loss vectors passed to
        # write_log during training.
        self.callback = TensorBoard(config.LOG_DIR)
        self.callback.set_model(self.combined)
        self.train_names = [
            'loss', 'mse_loss', 'kullback_loss', 'wasserstein_loss'
        ]
        self.disc_names = ['disc_loss', 'disc_valid', 'disc_fake', 'disc_gp']

        # Not read by this class's own methods.
        self.test_loss_array = []
        self.g_loss_array = []

    @staticmethod
    def _transfer_weights(source, target, skip_name, prefix=''):
        """Copy weights from ``source`` layers into same-named ``target`` layers.

        Only target layers whose name starts with ``prefix`` are considered.
        The prefix is stripped before looking the name up in ``source``. The
        layer named ``skip_name`` is skipped.

        The transfer is best-effort. A missing source layer and a failing
        weight copy are reported separately with ``print``, and the remaining
        layers are still processed.
        """
        # First occurrence wins, matching list.index() lookup semantics.
        source_layers = {}
        for src_layer in source.layers:
            source_layers.setdefault(src_layer.name, src_layer)

        for layer in target.layers:
            name = layer.name
            if name == skip_name:
                print('model 3 skip')
                continue
            if not name.startswith(prefix):
                continue
            src_layer = source_layers.get(name[len(prefix):])
            if src_layer is None:
                print(f'Layer {name} not found in chroma gan.')
                continue
            try:
                layer.set_weights(src_layer.get_weights())
            except Exception as e:  # best-effort: report (e.g. shape mismatch) and go on
                print(e)
            else:
                print(f'Successfully set weights for layer {name}')

    def train(self,
              data: 'Data',
              test_data,
              log,
              config,
              skip_to_after_epoch=None):
        """Run the alternating generator / discriminator training loop.

        Each step trains ``combined`` and then ``discriminator_model`` on one
        batch, and logs both loss vectors via ``write_log``. After every
        epoch it saves the combined model and the discriminator, then calls
        :meth:`sample_images`.

        Args:
            data: Training data source. ``size``, ``batch_size`` and
                ``generate_batch()`` are used.
            test_data: Data source forwarded to :meth:`sample_images`.
            log: Not used by this method.
            config: Settings. ``MODEL_DIR`` and ``NUM_EPOCHS`` are read here.
            skip_to_after_epoch: Index of a previously saved epoch to resume
                after, or ``None`` to train from scratch. ``0`` is valid and
                resumes after the first epoch.
        """
        # Load VGG network; its predictions are the target for the fusion
        # network's class-vector output.
        VGG_modelF = applications.vgg16.VGG16(weights='imagenet',
                                              include_top=True)

        # Real, Fake and Dummy targets for the discriminator model.
        positive_y = np.ones((data.batch_size, 1), dtype=np.float32)
        negative_y = -positive_y
        dummy_y = np.zeros((data.batch_size, 1), dtype=np.float32)

        # Total number of batches in one epoch (rounded down).
        total_batch = int(data.size / data.batch_size)
        print(f'batch_size={data.batch_size} * total_batch={total_batch}')

        def save_path(kind, epoch):
            """Checkpoint file for model ``kind`` saved after ``epoch``."""
            return os.path.join(config.MODEL_DIR,
                                f"fusion_{kind}Epoch{epoch}.h5")

        # Compare against None: epoch 0 is a real checkpoint and must be
        # resumable.
        if skip_to_after_epoch is not None:
            start_epoch = skip_to_after_epoch + 1
            print(f"Loading weights from epoch {skip_to_after_epoch}")
            self.combined.load_weights(
                save_path("combined", skip_to_after_epoch))
            self.fusion_discriminator.load_weights(
                save_path("discriminator", skip_to_after_epoch))
        else:
            start_epoch = 0

        for epoch in range(start_epoch, config.NUM_EPOCHS):
            for batch in tqdm(range(total_batch)):
                train_batch = data.generate_batch()
                resized_l = train_batch.resized_images.l
                resized_ab = train_batch.resized_images.ab

                # GT vgg: tile the single L channel to 3 channels for VGG16.
                predictVGG = VGG_modelF.predict(
                    np.tile(resized_l, [1, 1, 1, 3]))

                # train generator
                g_loss = self.combined.train_on_batch([
                    resized_l, train_batch.instances.l,
                    train_batch.instances.bbox, train_batch.instances.mask
                ], [resized_ab, predictVGG, positive_y])
                # train discriminator
                d_loss = self.discriminator_model.train_on_batch([
                    resized_l, resized_ab, train_batch.instances.l,
                    train_batch.instances.bbox, train_batch.instances.mask
                ], [positive_y, negative_y, dummy_y])

                # update log files (global step, 1-based)
                step = epoch * total_batch + batch + 1
                write_log(self.callback, self.train_names, g_loss, step)
                write_log(self.callback, self.disc_names, d_loss, step)

                if batch % 10 == 0:
                    print(
                        f"[Epoch {epoch}] [Batch {batch}/{total_batch}] [generator loss: {g_loss[0]:08f}] [discriminator loss: {d_loss[0]:08f}]"
                    )

            print('Saving models...')
            self.combined.save(save_path("combined", epoch))
            self.fusion_discriminator.save(save_path("discriminator", epoch))
            print('Models saved.')

            print('Sampling test images...')
            # sample images after each epoch
            self.sample_images(test_data, epoch, config)

    def sample_images(self, test_data: 'Data', epoch, config):
        """Colourise every test batch and save the results.

        The instance network's five outputs are fed into the fusion network,
        together with the resized L image, boxes and masks. Each predicted
        ab map is de-processed, resized to the original image's size, and
        passed to ``reconstruct_and_save`` under the name
        ``epoch{epoch}_{file_name}``.
        """
        total_batch = int(ceil(test_data.size / test_data.batch_size))
        for _ in range(total_batch):
            # load test data
            test_batch = test_data.generate_batch()

            # predict AB channels
            fg_model_3, fg_conv2d_11, fg_conv2d_13, fg_conv2d_15, fg_conv2d_17 = self.foreground_generator.predict(
                test_batch.instances.l)

            fusion_img_pred_ab, _ = self.fusion_generator.predict([
                test_batch.resized_images.l, fg_model_3, fg_conv2d_11,
                fg_conv2d_13, fg_conv2d_15, fg_conv2d_17,
                test_batch.instances.bbox, test_batch.instances.mask
            ])

            # print results
            for i in range(test_data.batch_size):
                original_full_img = test_batch.images.full[i]
                height, width, _ = original_full_img.shape
                # cv2.resize takes the target size as (width, height).
                pred_ab = cv2.resize(
                    deprocess_float2int(fusion_img_pred_ab[i]),
                    (width, height))
                reconstruct_and_save(
                    test_batch.images.l[i], pred_ab,
                    f'epoch{epoch}_{test_batch.file_names[i]}', config)