コード例 #1
0
    def load_dataset(self,
                     dataset_name="CT",
                     domain_A_folder="output8",
                     domain_B_folder="output5_x_128",
                     data_root="/home/lulin/na4/src/output"):
        """Set up the data source(s) for ``dataset_name``.

        * ``"MNIST"``: creates ``self.data_loader`` (MNIST and MNIST-M) at
          ``(self.img_rows, self.img_cols)``.
        * ``"CT"``: creates ``self.Dataset_A`` / ``self.Dataset_B`` from
          ``<data_root>/<folder>/train/bodys.npy`` and ``liver_masks.npy``,
          with ``self.batch_size``, no augmentation and seed 17.
        * any other name: only ``self.dataset_name`` is set.

        :param dataset_name: dataset identifier, stored as ``self.dataset_name``.
        :param domain_A_folder: folder under ``data_root`` for domain A.
        :param domain_B_folder: folder under ``data_root`` for domain B.
        :param data_root: parent directory of the domain folders; the default
            is the location that used to be hard-coded.
        """
        self.dataset_name = dataset_name

        if self.dataset_name == "MNIST":
            # Configure MNIST and MNIST-M data loader
            self.data_loader = DataLoader(img_res=(self.img_rows,
                                                   self.img_cols))
        elif self.dataset_name == "CT":
            root = data_root.rstrip("/")

            def make_ct_dataset(folder, domain):
                # Both domains share the same file layout and loader settings.
                train_dir = "{}/{}/train".format(root, folder)
                return MyDataset(
                    paths=["{}/bodys.npy".format(train_dir),
                           "{}/liver_masks.npy".format(train_dir)],
                    batch_size=self.batch_size,
                    augment=False,
                    seed=17,
                    domain=domain)

            self.Dataset_A = make_ct_dataset(domain_A_folder, "A")
            self.Dataset_B = make_ct_dataset(domain_B_folder, "B")
        # Any other dataset name is accepted silently (no loader is created),
        # exactly as before.
コード例 #2
0
def main(args):
    """
    Main function for the script: build a ProGAN from the configuration
    file, optionally restore generator/discriminator weights, then train.

    :param args: parsed command line arguments; reads ``config``,
        ``generator_file``, ``discriminator_file`` and ``start_depth``
    :return: None
    """

    from pro_gan_pytorch.PRO_GAN import ProGAN

    print(args.config)
    config = get_config(args.config)
    print("Current Configuration:", config)

    # create the dataset for training
    dataset = dl.FoldersDistributedDataset(
        data_dir=config.images_dir,
        transform=dl.get_transform(config.img_dims)
    )

    print("total examples in training: ", len(dataset))

    pro_gan = ProGAN(
        depth=config.depth,
        latent_size=config.latent_size,
        learning_rate=config.learning_rate,
        beta_1=config.beta_1,
        beta_2=config.beta_2,
        eps=config.eps,
        drift=config.drift,
        n_critic=config.n_critic,
        use_eql=config.use_eql,
        loss=config.loss_function,
        use_ema=config.use_ema,
        ema_decay=config.ema_decay,
        device=device
    )

    # map_location lets a checkpoint saved on another device (e.g. a GPU) be
    # restored on the device chosen for this run instead of failing to load.
    if args.generator_file is not None:
        print("Loading generator from:", args.generator_file)
        pro_gan.gen.load_state_dict(
            th.load(args.generator_file, map_location=device))

    if args.discriminator_file is not None:
        print("Loading discriminator from:", args.discriminator_file)
        pro_gan.dis.load_state_dict(
            th.load(args.discriminator_file, map_location=device))

    # train all the networks
    train_networks(
        pro_gan=pro_gan,
        dataset=dataset,
        epochs=config.epochs,
        fade_in_percentage=config.fade_in_percentage,
        start_depth=args.start_depth,
        batch_sizes=config.batch_sizes,
        num_workers=config.num_workers,
        feedback_factor=config.feedback_factor,
        log_dir=config.log_dir,
        sample_dir=config.sample_dir,
        checkpoint_factor=config.checkpoint_factor,
        save_dir=config.save_dir,
    )
def main():
    """Train the text-classification ensemble configured in ``config.json``
    and print the validation texts it misclassifies.

    Reads ``configs['data']['model_dir']``, ``configs['data']['train_dir']``
    and ``configs['training']['by_sentence']``. If no training directory is
    configured, a message is printed and the function returns early.
    """
    # Read the config inside ``with`` so the file handle is closed promptly.
    with open('config.json', 'r') as config_file:
        configs = json.load(config_file)
    model_dir = configs['data']['model_dir']
    train_dir = configs['data']['train_dir']

    # NOTE(review): model_dir is parsed but not used below.
    if model_dir is not None:
        model_dir = pathlib.Path(model_dir)
    if train_dir is None:
        print('Please provide training directory!')
        # Without a training directory there is nothing to train on; the
        # previous code fell through and passed None to read_data().
        return
    train_dir = pathlib.Path(train_dir)

    data = DataLoader(nlp, configs)
    train_texts, train_labels, val_texts, val_labels = data.read_data(
        configs, train_dir)

    print("Parsing texts...")

    train_docs = list(nlp.pipe(train_texts))
    val_docs = list(nlp.pipe(val_texts))
    if configs['training']['by_sentence']:
        train_docs, train_labels = data.get_labelled_sentences(
            train_docs, train_labels)
        val_docs, val_labels = data.get_labelled_sentences(
            val_docs, val_labels)

    train_vec = data.get_vectors(train_docs)
    val_vec = data.get_vectors(val_docs)
    # Passed to Model, which may fill it in during training.
    predictions = []
    model = Model(nlp, configs, predictions, val_vec)

    model.train_model(train_vec, train_labels, val_vec, val_labels)

    ensemble_prediction = model.model_evaluation(val_labels)

    # One-hot validation labels -> class indices.
    val_labels = np.argmax(val_labels, axis=1)
    misclassified = ensemble_prediction != val_labels

    print('We got ', np.sum(misclassified), 'out of ',
          val_labels.shape[0], 'misclassified texts')
    print('Here is the list of misclassified texts:\n')

    val_texts = np.array(val_texts).reshape(-1)

    # Integer indices of the misclassified samples (1-D result).
    print(val_texts[np.flatnonzero(misclassified)])
コード例 #4
0
def main(epochs=15,
         save_weights_path="./Weights/mnist_weights.hdf5",
         mode="train",
         num_classes=NUM_CLASSES,
         useCNN=False):
    """Train a classifier on MNIST, or evaluate saved weights on MNIST-M.

    :param epochs: number of training epochs (``mode == "train"``).
    :param save_weights_path: checkpoint written during training (best
        ``val_acc`` only) and read back during testing; its parent directory
        is created when missing.
    :param mode: ``"train"`` or ``"test"``.
    :param num_classes: number of output classes; also used for the one-hot
        label encoding.
    :param useCNN: forwarded to ``NN_model``.
    :raises ValueError: if ``mode`` is neither ``"train"`` nor ``"test"``.
    """
    # Validate up front so an invalid mode fails before any data is loaded
    # (the old check ran last and misspelled ValueError, raising NameError).
    if mode not in ("train", "test"):
        raise ValueError("'mode' should be 'train' or 'test'.")

    dirname = os.path.dirname(save_weights_path)
    # A bare file name has no parent directory; os.makedirs("") would raise.
    if dirname and not os.path.exists(dirname):
        os.makedirs(dirname)

    img_rows = 32
    img_cols = 32
    data_loader = DataLoader(img_res=(img_rows, img_cols))
    input_shape = (img_rows, img_cols, 3)

    model = NN_model(input_shape, num_classes, useCNN=useCNN)

    if mode == "train":
        checkpointer = ModelCheckpoint(filepath=save_weights_path,
                                       verbose=1,
                                       save_best_only=True,
                                       save_weights_only=True,
                                       monitor='val_acc')
        model.fit(data_loader.mnist_X,
                  keras.utils.to_categorical(data_loader.mnist_y, num_classes),
                  epochs=epochs,
                  shuffle=True,
                  validation_split=0.05,
                  batch_size=BATCH_SIZE,
                  callbacks=[checkpointer])
        print("All done.")
    else:
        model.load_weights(save_weights_path, by_name=True)
        score = model.evaluate(
            data_loader.mnistm_X,
            keras.utils.to_categorical(data_loader.mnistm_y, num_classes))
        print("Accuracy on test set: {}".format(score[1] * 100))
        print("All done.")
コード例 #5
0
                   learning_rate=config.learning_rate,
                   beta_1=config.beta_1,
                   beta_2=config.beta_2,
                   eps=config.eps,
                   drift=config.drift,
                   n_critic=config.n_critic,
                   device=device)

# Checkpoint directory of training run 11. Forward slashes are accepted by
# Python's file APIs on every OS (Windows included), whereas the previous
# "\\"-separated literals only resolved on Windows.
SAVED_MODELS_DIR = "training_runs/11/saved_models"

# map_location restores tensors on the current device even if the checkpoint
# was saved on a different one.
c_pro_gan.gen.load_state_dict(
    th.load(SAVED_MODELS_DIR + "/GAN_GEN_4.pth", map_location=device))

###################################################################################
# Load the pretrained text encoder and conditioning augmentor

dataset = dl.Face2TextDataset(pro_pick_file=config.processed_text_file,
                              img_dir=config.images_dir,
                              img_transform=dl.get_transform(config.img_dims),
                              captions_len=config.captions_length)

text_encoder = Encoder(embedding_size=config.embedding_size,
                       vocab_size=dataset.vocab_size,
                       hidden_size=config.hidden_size,
                       num_layers=config.num_layers,
                       device=device)
text_encoder.load_state_dict(
    th.load(SAVED_MODELS_DIR + "/Encoder_3_20.pth", map_location=device))

condition_augmenter = ConditionAugmentor(input_size=config.hidden_size,
                                         latent_size=config.ca_out_size,
                                         device=device)
condition_augmenter.load_state_dict(
    th.load(SAVED_MODELS_DIR + "/Condition_Augmentor_3_20.pth",
            map_location=device))
コード例 #6
0
class PixelDA(object):
    """
    Pixel-level domain adaptation GAN (Keras implementation).

    Training paradigm followed by this class:

    1. Construct D
        1a) Compile D
    2. Construct G
    3. Set D.trainable = False
    4. Stack G and D, to construct GAN
        4a) Compile GAN

    Approved by fchollet: "the process you describe is in fact correct."

    See issue #4674 keras: https://github.com/keras-team/keras/issues/4674

    ``build_all_model()`` must run before ``summary``/``train``/``deploy_*``,
    and ``load_dataset()`` before any method that reads ``self.data_loader``.
    """

    def __init__(self):
        """Set the image/noise geometry; no models are built here."""
        # Input shape
        self.img_rows = 32
        self.img_cols = 32
        self.channels = 3
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.num_classes = 10
        self.noise_size = (10, )

        # Calculate output shape of D (PatchGAN)
        # NOTE(review): build_discriminator ends in a Dense(1) head, so this
        # patch shape appears unused within this class.
        patch = int(self.img_rows / 2**4)
        self.disc_patch = (patch, patch, 1)

        # Number of residual blocks in the generator (previously 6)
        self.residual_blocks = 3

    @staticmethod
    def _ensure_parent_dir(path):
        """Create the parent directory of ``path`` if missing; return it.

        A bare file name has no parent (``""``), which is returned as-is
        instead of calling ``os.makedirs("")`` (which raises).
        """
        dirpath = os.path.dirname(path)
        if dirpath and not os.path.exists(dirpath):
            os.makedirs(dirpath)
        return dirpath

    @staticmethod
    def _second_best_path(path):
        """Return ``path`` with ``_bis`` inserted before its extension."""
        root, ext = os.path.splitext(path)
        return root + "_bis" + ext

    def build_all_model(self):
        """Build D, G and the classifier; compile D and the combined model.

        D is compiled on its own first and then frozen
        (``trainable = False``) before the combined G+classifier model is
        compiled, so training ``self.combined`` updates only G and the
        classifier.
        """
        # Loss weights
        lambda_adv = 10
        lambda_clf = 1
        optimizer = RMSprop(lr=1e-5)

        # Number of filters in first layer of discriminator and classifier
        self.df = 64
        self.cf = 64

        # Build and compile the discriminators
        self.discriminator = self.build_discriminator()
        self.discriminator.name = "Discriminator"
        self.discriminator.compile(loss='mse',
                                   optimizer=optimizer,
                                   metrics=['accuracy'])

        # Build the generator
        self.generator = self.build_generator()
        self.generator.name = "Generator"
        # Build the task (classification) network
        self.clf = self.build_classifier()
        self.clf.name = "Classifier"
        # Input images from both domains
        img_A = Input(shape=self.img_shape, name='source_image')

        # Input noise
        noise = Input(shape=self.noise_size, name='noise_input')

        # Translate images from domain A to domain B
        fake_B = self.generator([img_A, noise])

        # Classify the translated image
        class_pred = self.clf(fake_B)

        # For the combined model we will only train the generator and classifier
        self.discriminator.trainable = False

        # Discriminator determines validity of translated images
        valid = self.discriminator(fake_B)

        self.combined = Model([img_A, noise], [valid, class_pred])
        self.combined.compile(loss=['mse', 'categorical_crossentropy'],
                              loss_weights=[lambda_adv, lambda_clf],
                              optimizer=optimizer,
                              metrics=['accuracy'])

    def load_dataset(self):
        """Create the MNIST / MNIST-M data loader at the model resolution."""
        self.data_loader = DataLoader(img_res=(self.img_rows, self.img_cols))

    def build_generator(self):
        """ResNet generator: (image, noise) -> translated image in [-1, 1]."""
        def residual_block(layer_input):
            """Residual block described in paper"""
            d = Conv2D(64, kernel_size=3, strides=1,
                       padding='same')(layer_input)
            d = BatchNormalization(momentum=0.8)(d)
            d = Activation('relu')(d)
            d = Conv2D(64, kernel_size=3, strides=1, padding='same')(d)
            d = BatchNormalization(momentum=0.8)(d)
            d = Add()([d, layer_input])
            return d

        # Image input
        img = Input(shape=self.img_shape, name='image_input')

        # Noise input, projected to one extra image-sized channel. The Dense
        # width must equal rows * cols for the Reshape (1024 only for 32x32).
        noise = Input(shape=self.noise_size, name='noise_input')
        noise_layer = Dense(self.img_rows * self.img_cols,
                            activation="relu")(noise)
        noise_layer = Reshape((self.img_rows, self.img_cols, 1))(noise_layer)
        conditioned_img = keras.layers.concatenate([img, noise_layer])

        l1 = Conv2D(64,
                    kernel_size=3,
                    strides=1,
                    padding='same',
                    activation='relu')(conditioned_img)

        # Propagate signal through residual blocks
        r = residual_block(l1)
        for _ in range(self.residual_blocks - 1):
            r = residual_block(r)

        output_img = Conv2D(self.channels,
                            kernel_size=3,
                            strides=1,
                            padding='same',
                            activation='tanh')(r)

        return Model([img, noise], output_img)

    def build_discriminator(self):
        """Five stride-2 conv layers followed by a sigmoid Dense(1) head."""
        def d_layer(layer_input, filters, f_size=4, normalization=True):
            """Discriminator layer"""
            d = Conv2D(filters, kernel_size=f_size, strides=2,
                       padding='same')(layer_input)
            d = LeakyReLU(alpha=0.2)(d)
            if normalization:
                d = InstanceNormalization()(d)
            return d

        img = Input(shape=self.img_shape)

        d1 = d_layer(img, self.df, normalization=False)
        d2 = d_layer(d1, self.df * 2, normalization=False)
        d3 = d_layer(d2, self.df * 4, normalization=False)
        d4 = d_layer(d3, self.df * 8, normalization=False)
        d5 = d_layer(d4, self.df * 16, normalization=False)

        validity = Dense(1, activation='sigmoid')(Flatten()(d5))

        return Model(img, validity)

    def build_classifier(self):
        """Five stride-2 conv layers followed by a softmax over the classes."""
        def clf_layer(layer_input, filters, f_size=4, normalization=True):
            """Classifier layer"""
            d = Conv2D(filters, kernel_size=f_size, strides=2,
                       padding='same')(layer_input)
            d = LeakyReLU(alpha=0.2)(d)
            if normalization:
                d = InstanceNormalization()(d)
            return d

        img = Input(shape=self.img_shape, name='image_input')

        c1 = clf_layer(img, self.cf, normalization=False)
        c2 = clf_layer(c1, self.cf * 2)
        c3 = clf_layer(c2, self.cf * 4)
        c4 = clf_layer(c3, self.cf * 8)
        c5 = clf_layer(c4, self.cf * 8)

        class_pred = Dense(self.num_classes,
                           activation='softmax')(Flatten()(c5))

        return Model(img, class_pred)

    def load_pretrained_weights(self,
                                weights_path="../Weights/all_weights.h5"):
        """Load weights into ``self.combined`` (matched by layer name)."""
        print("Loading pretrained weights from path: {} ...".format(
            weights_path))

        self.combined.load_weights(weights_path, by_name=True)
        print("+ Done.")

    def summary(self):
        """Print the Keras summary of D, G, the classifier and the stack."""
        print("=" * 50)
        print("Discriminator summary:")
        self.discriminator.summary()
        print("=" * 50)
        print("Generator summary:")
        self.generator.summary()
        print("=" * 50)
        print("Classifier summary:")
        self.clf.summary()
        print("=" * 50)
        print("All summary:")
        self.combined.summary()

    def train(self,
              epochs,
              batch_size=32,
              sample_interval=50,
              save_sample2dir="../samples/exp0",
              save_weights_path='../Weights/all_weights.h5',
              save_model=False):
        """Alternate discriminator and generator/classifier updates.

        Each "epoch" is a single iteration: one D step on half a batch of
        real domain-B images plus half a batch of translated domain-A
        images, then one combined G+classifier step on ``batch_size``
        domain-A images. Every 10 iterations progress is printed; when the
        mean domain-B accuracy over the last 100 iterations beats the best
        (or second-best) value so far, ``self.combined`` is saved to
        ``save_weights_path`` (or to its ``_bis`` variant).

        :param epochs: number of training iterations.
        :param batch_size: generator/classifier batch size.
        :param sample_interval: save a sample grid every this many iterations.
        :param save_sample2dir: directory for the sample grids.
        :param save_weights_path: file for the best checkpoint; the second
            best uses the same name with ``_bis`` before the extension.
        :param save_model: save the whole model (``save``) instead of the
            weights only (``save_weights``).
        """
        self._ensure_parent_dir(save_weights_path)
        second_best_path = self._second_best_path(save_weights_path)

        half_batch = int(batch_size / 2)

        # Classification accuracy on the 100 most recent batches of domain B
        test_accs = []
        self.summary()

        # Mean test accuracies (%) of the best / second-best checkpoints
        best_test_cls_acc = 0
        second_best_cls_acc = -1
        for epoch in range(epochs):

            # ---------------------
            #  Train Discriminator
            # ---------------------
            imgs_A, _ = self.data_loader.load_data(domain="A",
                                                   batch_size=half_batch)
            imgs_B, _ = self.data_loader.load_data(domain="B",
                                                   batch_size=half_batch)

            noise_prior = np.random.normal(
                0, 1, (half_batch, self.noise_size[0]))

            # Translate images from domain A to domain B
            fake_B = self.generator.predict([imgs_A, noise_prior])

            # One mixed batch: real domain-B images -> 1, translated -> 0
            valid = np.ones((half_batch, 1))
            fake = np.zeros((half_batch, 1))
            D_train_label = np.vstack([valid, fake])
            D_train_images = np.vstack([imgs_B, fake_B])

            d_loss = self.discriminator.train_on_batch(D_train_images,
                                                       D_train_label)

            # --------------------------------
            #  Train Generator and Classifier
            # --------------------------------
            imgs_A, labels_A = self.data_loader.load_data(
                domain="A", batch_size=batch_size)
            imgs_B, labels_B = self.data_loader.load_data(
                domain="B", batch_size=batch_size)

            # One-hot encoding of labels
            labels_A = to_categorical(labels_A, num_classes=self.num_classes)

            # The generator wants D to label the translated images as real
            valid = np.ones((batch_size, 1))

            noise_prior = np.random.normal(
                0, 1, (batch_size, self.noise_size[0]))

            g_loss = self.combined.train_on_batch([imgs_A, noise_prior],
                                                  [valid, labels_A])

            #-----------------------
            # Evaluation (domain B)
            #-----------------------
            pred_B = self.clf.predict(imgs_B)
            test_acc = np.mean(np.argmax(pred_B, axis=1) == labels_B)

            test_accs.append(test_acc)
            if len(test_accs) > 100:
                test_accs.pop(0)

            if epoch % 10 == 0:
                d_train_acc = 100 * float(d_loss[1])
                gen_loss = g_loss[1]
                clf_train_acc = 100 * float(g_loss[-1])
                clf_train_loss = g_loss[2]
                current_test_acc = 100 * float(test_acc)
                test_mean_acc = 100 * float(np.mean(test_accs))

                if test_mean_acc > best_test_cls_acc:
                    second_best_cls_acc = best_test_cls_acc
                    best_test_cls_acc = test_mean_acc
                    checkpoint_path = save_weights_path
                    suffix = " (latest)"
                elif test_mean_acc > second_best_cls_acc:
                    second_best_cls_acc = test_mean_acc
                    # Always the _bis file: previously save_model=True wrote
                    # the second-best model over the best checkpoint.
                    checkpoint_path = second_best_path
                    suffix = " (before latest)"
                else:
                    checkpoint_path = None
                    suffix = ""

                if checkpoint_path is not None:
                    if save_model:
                        self.combined.save(checkpoint_path)
                    else:
                        self.combined.save_weights(checkpoint_path)

                print(
                    "{} : [D - loss: {:.5f}, acc: {:.2f}%], [G - loss: {:.5f}], [clf - loss: {:.5f}, acc: {:.2f}%, test_acc: {:.2f}% ({:.2f}%)]"
                    .format(epoch, d_loss[0], d_train_acc, gen_loss,
                            clf_train_loss, clf_train_acc,
                            current_test_acc, test_mean_acc) + suffix)

            # If at save interval => save generated image samples
            if epoch % sample_interval == 0:
                self.sample_images(epoch, save2dir=save_sample2dir)

    def sample_images(self, epoch, save2dir="../samples"):
        """Save a 2 x 10 grid: domain-A images (top), translations (bottom).

        Pixels are mapped with ``0.5 * x + 0.5`` before plotting (assumes
        images lie in [-1, 1]); the figure goes to ``<save2dir>/<epoch>.png``.
        """
        if not os.path.exists(save2dir):
            os.makedirs(save2dir)

        r, c = 2, 10

        imgs_A, _ = self.data_loader.load_data(domain="A", batch_size=c)

        n_sample = imgs_A.shape[0]
        noise_prior = np.random.normal(0, 1, (n_sample, self.noise_size[0]))

        # Translate images to the other domain
        fake_B = self.generator.predict([imgs_A, noise_prior])

        gen_imgs = np.concatenate([imgs_A, fake_B])

        # Rescale images 0 - 1
        gen_imgs = 0.5 * gen_imgs + 0.5

        fig, axs = plt.subplots(r, c, figsize=(20, 4))
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i, j].imshow(gen_imgs[cnt])
                axs[i, j].axis('off')
                cnt += 1
        fig.savefig(os.path.join(save2dir, "{}.png".format(epoch)))
        plt.close(fig)

    def deploy_transform(self,
                         save2file="../domain_adapted/generated.npy",
                         stop_after=None):
        """Translate MNIST images with one shared noise vector and save them.

        Writes the translated images to ``save2file`` and the noise vector
        to ``noise_vectors.npy`` in the same directory.

        :param save2file: output ``.npy`` file for the translated images.
        :param stop_after: translate only the first
            ``32 * int(stop_after / 32)`` images; ``None`` translates all of
            ``self.data_loader.mnist_X``.
        """
        dirname = self._ensure_parent_dir(save2file)

        noise_vec = np.random.normal(0, 1, self.noise_size[0])
        # NOTE(review): always fails unless Python runs with -O, which appears
        # to deliberately disable this method -- confirm before removing.
        assert 1 == 2

        if stop_after is not None:
            n_images = 32 * int(stop_after / 32)
        else:
            # Previously ``32 * None`` raised TypeError here.
            n_images = len(self.data_loader.mnist_X)

        print("Performing Pixel-level domain adaptation on original images...")
        adapted_images = self.generator.predict(
            [
                self.data_loader.mnist_X[:n_images],
                np.tile(noise_vec, (n_images, 1))
            ],
            batch_size=32)
        print("+ Done.")
        print("Saving transformed images to file {}".format(save2file))
        np.save(save2file, adapted_images)

        noise_vec_filepath = os.path.join(dirname, "noise_vectors.npy")
        print(
            "Saving random noise (seed) to file {}".format(noise_vec_filepath))
        np.save(noise_vec_filepath, noise_vec)

        print("+ All done.")

    def deploy_debug(self,
                     save2file="../domain_adapted/debug.npy",
                     sample_size=9,
                     seed=0):
        """Translate the first 15 MNIST images under ``sample_size`` noises.

        Seeds NumPy's global RNG with ``seed`` and saves the stacked results
        (one batch of 15 per noise vector) to ``save2file``.
        """
        self._ensure_parent_dir(save2file)

        np.random.seed(seed=seed)

        noise_vec = np.random.normal(0, 1, (sample_size, self.noise_size[0]))
        # NOTE(review): always fails unless Python runs with -O, which appears
        # to deliberately disable this method -- confirm before removing.
        assert 1 == 2

        print("Performing Pixel-level domain adaptation on original images...")
        collections = []
        for i in range(sample_size):
            adapted_images = self.generator.predict(
                [
                    self.data_loader.mnist_X[:15],
                    np.tile(noise_vec[i], (15, 1))
                ],
                batch_size=15)
            collections.append(adapted_images)
        print("+ Done.")

        print("Saving transformed images to file {}".format(save2file))
        np.save(save2file, np.stack(collections))
        print("+ All done.")

    def deploy_classification(self, batch_size=32):
        """Print MNIST-M accuracy of ``self.clf`` with a 99% interval.

        The interval is the normal approximation
        ``mean +/- 2.576 * std / sqrt(N)`` over per-sample correctness.
        """
        print("Predicting ... ")
        pred_B = self.clf.predict(self.data_loader.mnistm_X,
                                  batch_size=batch_size)
        print("+ Done.")
        N_samples = len(pred_B)
        correct = (np.argmax(pred_B, axis=1) == self.data_loader.mnistm_y)
        mean_acc = np.mean(correct)
        std_acc = np.std(correct)

        half_width = 2.576 * std_acc / np.sqrt(N_samples)
        lower_bound = mean_acc - half_width
        upper_bound = mean_acc + half_width
        print("=" * 50)
        print("Unsupervised MNIST-M classification accuracy : {}".format(
            mean_acc))
        print("Confidence interval (99%) [{}, {}]".format(
            lower_bound, upper_bound))
        print("Length of confidence interval 99%: {}".format(upper_bound -
                                                             lower_bound))
        print("=" * 50)
        print("+ All done.")
コード例 #7
0
 def load_dataset(self):
     """Attach an MNIST / MNIST-M ``DataLoader`` sized to this model's input."""
     target_resolution = (self.img_rows, self.img_cols)
     self.data_loader = DataLoader(img_res=target_resolution)
コード例 #8
0
def train_networks(pro_gan, dataset, epochs,
                   fade_in_percentage, batch_sizes,
                   start_depth, num_workers, feedback_factor,
                   log_dir, sample_dir, checkpoint_factor,
                   save_dir):
    """Progressively train ``pro_gan`` from ``start_depth`` to its full depth.

    :param pro_gan: ProGAN wrapper exposing ``gen``, ``dis``, ``depth``,
        ``device``, ``optimize_discriminator`` and ``optimize_generator``.
    :param dataset: dataset handed to ``dl.get_data_loader``.
    :param epochs: per-depth epoch counts (indexed by depth).
    :param fade_in_percentage: per-depth percentage of that depth's total
        iterations over which ``alpha`` ramps up to 1.
    :param batch_sizes: per-depth batch sizes; length must equal
        ``pro_gan.depth`` (checked with ``assert``).
    :param start_depth: first depth to train.
    :param num_workers: data-loader worker count.
    :param feedback_factor: roughly how many loss reports / sample grids
        are produced per epoch.
    :param log_dir: directory for ``loss_<depth>.log`` files.
    :param sample_dir: directory for generated sample grids.
    :param checkpoint_factor: save G and D every this many epochs.
    :param save_dir: directory for ``GAN_GEN_<depth>.pth`` / ``GAN_DIS_<depth>.pth``.
    :return: None
    """
    assert pro_gan.depth == len(batch_sizes), "batch_sizes not compatible with depth"

    print("Starting the training process ... ")
    for current_depth in range(start_depth, pro_gan.depth):

        print("\n\nCurrently working on Depth: ", current_depth)
        current_res = np.power(2, current_depth + 2)
        print("Current resolution: %d x %d" % (current_res, current_res))

        data = dl.get_data_loader(dataset, batch_sizes[current_depth], num_workers)
        # Iteration counter across all epochs of this depth; drives alpha.
        ticker = 1

        for epoch in range(1, epochs[current_depth] + 1):
            start = timeit.default_timer()  # record time at the start of epoch

            print("\nEpoch: %d" % epoch)
            # Ask the loader directly; building an iterator just to measure
            # it may start worker processes for nothing.
            total_batches = len(data)

            fader_point = int((fade_in_percentage[current_depth] / 100)
                              * epochs[current_depth] * total_batches)
            # int(total / factor) is 0 when an epoch has fewer batches than
            # feedback_factor, which made the modulo below divide by zero.
            feedback_interval = max(1, int(total_batches / feedback_factor))

            for (i, batch) in enumerate(data, 1):
                # calculate the alpha for fading in the layers
                alpha = ticker / fader_point if ticker <= fader_point else 1

                # extract current batch of data for training (same device as
                # the latent inputs below)
                images = batch.to(pro_gan.device)

                gan_input = th.randn(images.shape[0], pro_gan.gen.latent_size).to(pro_gan.device)

                # optimize the discriminator:
                dis_loss = pro_gan.optimize_discriminator(gan_input, images,
                                                          current_depth, alpha)

                # optimize the generator:
                gan_input = th.randn(images.shape[0], pro_gan.gen.latent_size).to(pro_gan.device)
                gen_loss = pro_gan.optimize_generator(gan_input, current_depth, alpha)

                # provide a loss feedback
                if i % feedback_interval == 0 or i == 1:
                    print("batch: %d  d_loss: %f  g_loss: %f" % (i, dis_loss, gen_loss))

                    # also write the losses to the log file:
                    log_file = os.path.join(log_dir, "loss_" + str(current_depth) + ".log")
                    with open(log_file, "a") as log:
                        log.write(str(dis_loss) + "\t" + str(gen_loss) + "\n")

                    # create a grid of samples and save it; no autograd graph
                    # is needed just to visualise samples
                    gen_img_file = os.path.join(sample_dir, "gen_" + str(current_depth) +
                                                "_" + str(epoch) + "_" +
                                                str(i) + ".png")
                    with th.no_grad():
                        samples = pro_gan.gen(gan_input, current_depth, alpha)
                    create_grid(
                        samples=samples,
                        scale_factor=int(np.power(2, pro_gan.depth - current_depth - 1)),
                        img_file=gen_img_file,
                        width=int(np.sqrt(batch_sizes[current_depth])),
                    )

                # increment the alpha ticker
                ticker += 1

            stop = timeit.default_timer()
            print("Time taken for epoch: %.3f secs" % (stop - start))

            # epochs are counted from 1, so the former "or epoch == 0" clause
            # could never fire and has been dropped
            if epoch % checkpoint_factor == 0:
                gen_save_file = os.path.join(save_dir, "GAN_GEN_" + str(current_depth) + ".pth")
                dis_save_file = os.path.join(save_dir, "GAN_DIS_" + str(current_depth) + ".pth")

                th.save(pro_gan.gen.state_dict(), gen_save_file, pickle)
                th.save(pro_gan.dis.state_dict(), dis_save_file, pickle)

    print("Training completed ...")
コード例 #9
0
def train_networks(encoder,
                   ca,
                   c_pro_gan,
                   dataset,
                   validation_dataset,
                   epochs,
                   encoder_optim,
                   ca_optim,
                   fade_in_percentage,
                   batch_sizes,
                   start_depth,
                   num_workers,
                   feedback_factor,
                   log_dir,
                   sample_dir,
                   checkpoint_factor,
                   save_dir,
                   comment,
                   use_matching_aware_dis=True):
    """Jointly train the text encoder, Conditioning Augmentor and conditional ProGAN.

    Training runs depth by depth from ``start_depth`` to ``c_pro_gan.depth - 1``.
    New layers fade in over the first ``fade_in_percentage[depth]`` percent of
    that depth's iterations. Batch losses, epoch means and sample grids go to
    TensorBoard. Loss lines are appended to ``log_dir/loss_<depth>.log``, and
    sample images are written to ``sample_dir``. Every 100 batches one
    validation batch is scored with ``trainable=False``. Checkpoints are saved
    to ``save_dir`` every ``checkpoint_factor`` epochs.

    :param encoder: text encoder called on a batch of captions; its output may
                    be a numpy array or a torch tensor
    :param ca: Conditioning Augmentor returning ``(c_not_hats, mus, sigmas)``
    :param c_pro_gan: conditional ProGAN (``gen``, ``dis``, ``depth``,
                      ``latent_size``, ``optimize_discriminator``,
                      ``optimize_generator``)
    :param dataset: training dataset yielding ``(captions, images)`` batches
    :param validation_dataset: validation dataset with the same batch layout
    :param epochs: list, number of epochs per depth
    :param encoder_optim: optimizer for the encoder, or None to keep it frozen
    :param ca_optim: optimizer for the Conditioning Augmentor
    :param fade_in_percentage: list, fade-in percentage per depth
    :param batch_sizes: list, batch size per depth
    :param start_depth: depth to start training from
    :param num_workers: number of DataLoader workers
    :param feedback_factor: number of loss/sample reports per epoch
    :param log_dir: directory for the loss log files
    :param sample_dir: directory for generated sample images
    :param checkpoint_factor: save the models every this many epochs
    :param save_dir: directory for model checkpoints
    :param comment: suffix for the TensorBoard run name
    :param use_matching_aware_dis: forwarded to ``optimize_discriminator``
    :return: None
    """
    import datetime
    import os
    import pickle
    import time
    import timeit

    import numpy as np
    import torch as th

    # required only for type checking
    from networks.TextEncoder import PretrainedEncoder

    # input assertions
    assert c_pro_gan.depth == len(
        batch_sizes), "batch_sizes not compatible with depth"
    assert c_pro_gan.depth == len(
        epochs), "epochs_sizes not compatible with depth"
    assert c_pro_gan.depth == len(
        fade_in_percentage), "fip_sizes not compatible with depth"

    def encode(captions):
        """Embed ``captions`` and return the embeddings as a tensor on ``device``."""
        if encoder_optim is not None:
            # a trainable encoder receives caption tensors that must be on device
            captions = captions.to(device)
        embeddings = encoder(captions)
        if not isinstance(embeddings, th.Tensor):
            # numpy output is converted exactly as before
            embeddings = th.from_numpy(embeddings)
        return embeddings.to(device)

    def sample_grid(gan_input, depth, alpha, img_file=None):
        """Run the generator on ``gan_input`` (without autograd) and grid the samples.

        With ``img_file=None`` the image grid is returned (used for TensorBoard);
        otherwise the grid is written to ``img_file``.
        """
        with th.no_grad():
            samples = c_pro_gan.gen(gan_input, depth, alpha)
        return create_grid(
            samples=samples,
            scale_factor=int(np.power(2, c_pro_gan.depth - depth - 1)),
            img_file=img_file,
        )

    def run_validation(depth, alpha):
        """Score one validation batch at ``depth`` without training.

        :return: (dis_loss, gen_loss, captions, gan_input) for that batch
        """
        # Use the current depth's batch size; this used to be
        # batch_sizes[start_depth] regardless of the depth being trained.
        v_loader = dl.get_data_loader(validation_dataset,
                                      batch_sizes[depth],
                                      num_workers=num_workers)
        v_captions, v_real_images = next(iter(v_loader))
        del v_loader  # release the loader (and its workers) right away

        v_embeddings = encode(v_captions)
        if encoder_optim is None:
            v_embeddings = v_embeddings.detach()
        v_c_not_hats, _, _ = ca(v_embeddings)
        v_noise = th.randn(
            len(v_captions),
            c_pro_gan.latent_size - v_c_not_hats.shape[-1]).to(device)
        v_gan_input = th.cat((v_c_not_hats, v_noise), dim=-1)

        # Score the *validation* images and embeddings. Previously the
        # current training batch was passed here by mistake.
        v_dis_loss = c_pro_gan.optimize_discriminator(
            v_gan_input,
            v_real_images.to(device),
            v_embeddings.detach(),
            depth,
            alpha,
            use_matching_aware_dis,
            trainable=False)
        v_gen_loss = c_pro_gan.optimize_generator(v_gan_input,
                                                  v_embeddings,
                                                  depth,
                                                  alpha,
                                                  trainable=False)
        return v_dis_loss, v_gen_loss, v_captions, v_gan_input

    # Writer will output to ./runs/ directory by default
    writer = SummaryWriter(comment="_{}_{}".format(batch_sizes[0], comment))

    # put all the Networks in training mode:
    ca.train()
    c_pro_gan.gen.train()
    c_pro_gan.dis.train()

    if not isinstance(encoder, PretrainedEncoder):
        encoder.train()

    print("Starting the training process ... ")

    # create fixed_input for debugging
    temp_data = dl.get_data_loader(dataset,
                                   batch_sizes[start_depth],
                                   num_workers=num_workers)
    fixed_captions, fixed_real_images = next(iter(temp_data))
    del temp_data  # delete temp data loader

    # The fixed input is only used for visualisation, so no autograd graph
    # needs to be kept alive for the whole run.
    with th.no_grad():
        fixed_embeddings = encode(fixed_captions)
        fixed_c_not_hats, _, _ = ca(fixed_embeddings)
        fixed_noise = th.randn(
            len(fixed_captions),
            c_pro_gan.latent_size - fixed_c_not_hats.shape[-1]).to(device)
        fixed_gan_input = th.cat((fixed_c_not_hats, fixed_noise), dim=-1)

    # save the fixed_images once:
    fixed_save_dir = os.path.join(sample_dir, "__Real_Info")
    os.makedirs(fixed_save_dir, exist_ok=True)
    create_grid(
        fixed_real_images,
        None,  # scale factor is not required here
        os.path.join(fixed_save_dir, "real_samples.png"),
        real_imgs=True)
    create_descriptions_file(os.path.join(fixed_save_dir, "real_captions.txt"),
                             fixed_captions, dataset)

    # create a global time counter
    global_time = time.time()

    for current_depth in range(start_depth, c_pro_gan.depth):

        print("\n\nCurrently working on Depth: ", current_depth)
        current_res = np.power(2, current_depth + 2)
        print("Current resolution: %d x %d" % (current_res, current_res))

        data = dl.get_data_loader(dataset, batch_sizes[current_depth],
                                  num_workers)

        ticker = 1

        for epoch in range(1, epochs[current_depth] + 1):
            start = timeit.default_timer()  # record time at the start of epoch

            print("\nEpoch: %d" % epoch)

            # Per-epoch accumulators for the "Epoch/..." scalars. They used to
            # be reset only per depth, which logged a running mean instead.
            gen_losses = []
            dis_losses = []
            kl_losses = []
            val_gen_losses = []
            val_dis_losses = []

            # len(data) does not start worker processes like len(iter(data))
            total_batches = len(data)
            fader_point = int((fade_in_percentage[current_depth] / 100) *
                              epochs[current_depth] * total_batches)
            # At least 1, so the modulo below can never divide by zero.
            # (Was int(total_batches + 1 / feedback_factor), a precedence bug.)
            feedback_interval = max(1, int(total_batches / feedback_factor))

            for (i, batch) in enumerate(data, 1):
                # calculate the alpha for fading in the layers
                alpha = ticker / fader_point if ticker <= fader_point else 1

                # extract current batch of data for training
                captions, images = batch
                images = images.to(device)
                batch_size = len(captions)

                # perform text_work:
                embeddings = encode(captions)
                if encoder_optim is None:
                    # detach the LSTM from backpropagation
                    embeddings = embeddings.detach()
                c_not_hats, mus, sigmas = ca(embeddings)
                noise_size = c_pro_gan.latent_size - c_not_hats.shape[-1]

                z = th.randn(batch_size, noise_size).to(device)
                gan_input = th.cat((c_not_hats, z), dim=-1)

                # optimize the discriminator:
                dis_loss = c_pro_gan.optimize_discriminator(
                    gan_input, images, embeddings.detach(), current_depth,
                    alpha, use_matching_aware_dis)

                dis_losses.append(dis_loss)
                writer.add_scalar(
                    f"Batch/Discriminator_Loss/{current_depth}/{epoch}",
                    dis_loss, i)

                # optimize the generator (with fresh noise):
                z = th.randn(batch_size, noise_size).to(device)
                gan_input = th.cat((c_not_hats, z), dim=-1)

                if encoder_optim is not None:
                    encoder_optim.zero_grad()

                ca_optim.zero_grad()
                gen_loss = c_pro_gan.optimize_generator(
                    gan_input, embeddings, current_depth, alpha)
                gen_losses.append(gen_loss)
                writer.add_scalar(
                    f"Batch/Generator_Loss/{current_depth}/{epoch}", gen_loss,
                    i)
                # optimize_generator also sends gradients to the Conditioning
                # Augmenter (and the TextEncoder), hence the zero_grad calls
                # above. Add the KL-divergence term of the CA and step both.
                kl_loss = th.mean(0.5 * th.sum(
                    (mus**2) + (sigmas**2) - th.log((sigmas**2)) - 1, dim=1))
                writer.add_scalar(f"Batch/KL_Loss/{current_depth}/{epoch}",
                                  kl_loss.item(), i)
                kl_losses.append(kl_loss.item())
                kl_loss.backward()
                ca_optim.step()
                if encoder_optim is not None:
                    encoder_optim.step()

                writer.add_image(
                    f"Batch/{current_depth}/{epoch}",
                    sample_grid(fixed_gan_input, current_depth, alpha), i)

                # evaluation every 100 batches
                if i % 100 == 0:
                    v_dis_loss, v_gen_loss, v_captions, v_gan_input = \
                        run_validation(current_depth, alpha)

                    val_dis_losses.append(v_dis_loss)
                    val_gen_losses.append(v_gen_loss)  # was v_dis_loss

                    writer.add_scalar(
                        f"Batch/Val/Discriminator_Loss/{current_depth}/{epoch}",
                        v_dis_loss, i)
                    writer.add_scalar(
                        f"Batch/Val/Generator_Loss/{current_depth}/{epoch}",
                        v_gen_loss, i)
                    writer.add_text(
                        f"Batch/Val/Captions/{current_depth}/{epoch}",
                        str(v_captions), i)
                    elapsed = str(
                        datetime.timedelta(seconds=time.time() - global_time))
                    print(
                        "Validation [%s]  batch: %d  d_loss: %f  g_loss: %f  kl_los: %f"
                        % (elapsed, i, v_dis_loss, v_gen_loss, kl_loss.item()))

                    # also write the losses to the log file:
                    os.makedirs(log_dir, exist_ok=True)
                    log_file = os.path.join(
                        log_dir, "val_loss_" + str(current_depth) + ".log")
                    with open(log_file, "a") as log:
                        log.write(
                            str(v_dis_loss) + "\t" + str(v_gen_loss) + "\t" +
                            str(kl_loss.item()) + "\n")

                    writer.add_image(
                        f'Batch/Val/{current_depth}/{epoch}',
                        sample_grid(v_gan_input, current_depth, alpha), i)

                # provide a loss feedback
                if i % feedback_interval == 0 or i == 1:
                    elapsed = str(
                        datetime.timedelta(seconds=time.time() - global_time))
                    print(
                        "Elapsed [%s]  batch: %d  d_loss: %f  g_loss: %f  kl_los: %f"
                        % (elapsed, i, dis_loss, gen_loss, kl_loss.item()))

                    # also write the losses to the log file:
                    os.makedirs(log_dir, exist_ok=True)
                    log_file = os.path.join(
                        log_dir, "loss_" + str(current_depth) + ".log")
                    with open(log_file, "a") as log:
                        log.write(
                            str(dis_loss) + "\t" + str(gen_loss) + "\t" +
                            str(kl_loss.item()) + "\n")

                    # create a grid of samples and save it
                    gen_img_file = os.path.join(
                        sample_dir, "gen_" + str(current_depth) + "_" +
                        str(epoch) + "_" + str(i) + ".png")
                    sample_grid(fixed_gan_input, current_depth, alpha,
                                img_file=gen_img_file)

                # increment the ticker:
                ticker += 1

            writer.add_scalar(f"Epoch/Generator_Loss/{current_depth}",
                              np.mean(gen_losses), epoch)
            writer.add_scalar(f"Epoch/Discriminator_Loss/{current_depth}",
                              np.mean(dis_losses), epoch)
            writer.add_scalar(f"Epoch/KL_Loss/{current_depth}",
                              np.mean(kl_losses), epoch)

            # validation only runs every 100 batches: skip empty means (NaN)
            if val_gen_losses:
                writer.add_scalar(f"Epoch/Val/Generator_Loss/{current_depth}",
                                  np.mean(val_gen_losses), epoch)
                writer.add_scalar(
                    f"Epoch/Val/Discriminator_Loss/{current_depth}",
                    np.mean(val_dis_losses), epoch)
            writer.add_image(f'Epoch/{current_depth}',
                             sample_grid(fixed_gan_input, current_depth, alpha),
                             epoch)

            stop = timeit.default_timer()
            print("Time taken for epoch: %.3f secs" % (stop - start))

            if epoch % checkpoint_factor == 0:
                # save the Model
                encoder_save_file = os.path.join(
                    save_dir, "Encoder_" + str(current_depth) + ".pth")
                ca_save_file = os.path.join(
                    save_dir,
                    "Condition_Augmentor_" + str(current_depth) + ".pth")
                gen_save_file = os.path.join(
                    save_dir, "GAN_GEN_" + str(current_depth) + ".pth")
                dis_save_file = os.path.join(
                    save_dir, "GAN_DIS_" + str(current_depth) + ".pth")

                os.makedirs(save_dir, exist_ok=True)

                if encoder_optim is not None:
                    th.save(encoder.state_dict(), encoder_save_file, pickle)
                th.save(ca.state_dict(), ca_save_file, pickle)
                th.save(c_pro_gan.gen.state_dict(), gen_save_file, pickle)
                th.save(c_pro_gan.dis.state_dict(), dis_save_file, pickle)

    # Close the writer once, at the end. It used to be closed after every
    # epoch, which made it reopen a new event file each time.
    writer.close()
    print("Training completed ...")
コード例 #10
0
def train_networks(encoder,
                   ca,
                   c_pro_gan,
                   dataset,
                   epochs,
                   encoder_optim,
                   ca_optim,
                   fade_in_percentage,
                   batch_sizes,
                   start_depth,
                   num_workers,
                   feedback_factor,
                   log_dir,
                   sample_dir,
                   checkpoint_factor,
                   save_dir,
                   use_matching_aware_dis=True):
    """Jointly train the text encoder, Conditioning Augmentor and conditional ProGAN.

    Training runs depth by depth from ``start_depth`` to ``c_pro_gan.depth - 1``.
    New layers fade in over the first ``fade_in_percentage[depth]`` percent of
    that depth's iterations. ``feedback_factor`` times per epoch, losses are
    appended to ``log_dir/loss_<depth>.log``. At the same points, generated
    samples, the real images and their captions are written to ``sample_dir``.
    Checkpoints are saved to ``save_dir`` every ``checkpoint_factor`` epochs.

    :param encoder: text encoder called on a batch of captions; its output may
                    be a torch tensor or anything ``th.tensor`` accepts
    :param ca: Conditioning Augmentor returning ``(c_not_hats, mus, sigmas)``
    :param c_pro_gan: conditional ProGAN (``gen``, ``dis``, ``depth``,
                      ``latent_size``, ``optimize_discriminator``,
                      ``optimize_generator``)
    :param dataset: training dataset yielding ``(captions, images)`` batches
    :param epochs: list, number of epochs per depth
    :param encoder_optim: optimizer for the encoder, or None to keep it frozen
    :param ca_optim: optimizer for the Conditioning Augmentor
    :param fade_in_percentage: list, fade-in percentage per depth
    :param batch_sizes: list, batch size per depth
    :param start_depth: depth to start training from
    :param num_workers: number of DataLoader workers
    :param feedback_factor: number of loss/sample reports per epoch
    :param log_dir: directory for the loss log files
    :param sample_dir: directory for sample images and caption files
    :param checkpoint_factor: save the models every this many epochs
    :param save_dir: directory for model checkpoints
    :param use_matching_aware_dis: forwarded to ``optimize_discriminator``
    :return: None
    """
    import os
    import pickle
    import timeit

    import numpy as np
    import torch as th

    assert c_pro_gan.depth == len(
        batch_sizes), "batch_sizes not compatible with depth"
    assert c_pro_gan.depth == len(
        epochs), "epochs_sizes not compatible with depth"
    assert c_pro_gan.depth == len(
        fade_in_percentage), "fip_sizes not compatible with depth"

    # make sure the output locations exist before anything is written there
    for directory in (log_dir, sample_dir, save_dir):
        if directory:
            os.makedirs(directory, exist_ok=True)

    print("Starting the training process ... ")
    for current_depth in range(start_depth, c_pro_gan.depth):

        print("\n\nCurrently working on Depth: ", current_depth)
        current_res = np.power(2, current_depth + 2)
        print("Current resolution: %d x %d" % (current_res, current_res))

        data = dl.get_data_loader(dataset, batch_sizes[current_depth],
                                  num_workers)

        # settings of the sample grids saved at this depth
        scale_factor = int(np.power(2, c_pro_gan.depth - current_depth - 1))
        grid_width = int(np.sqrt(batch_sizes[current_depth]))

        ticker = 1

        for epoch in range(1, epochs[current_depth] + 1):
            start = timeit.default_timer()  # record time at the start of epoch

            print("\nEpoch: %d" % epoch)
            # len(data) does not start worker processes like len(iter(data))
            total_batches = len(data)
            fader_point = int((fade_in_percentage[current_depth] / 100) *
                              epochs[current_depth] * total_batches)
            # at least 1: fewer batches than feedback_factor used to make
            # this 0 and crash the modulo below with ZeroDivisionError
            feedback_interval = max(1, int(total_batches / feedback_factor))

            for (i, batch) in enumerate(data, 1):
                # calculate the alpha for fading in the layers
                alpha = ticker / fader_point if ticker <= fader_point else 1

                # extract current batch of data for training
                captions, images = batch

                if encoder_optim is not None:
                    captions = captions.to(device)

                images = images.to(device)
                batch_size = len(captions)

                # perform text_work:
                embeddings = encoder(captions)
                if not isinstance(embeddings, th.Tensor):
                    embeddings = th.tensor(embeddings).to(device)
                c_not_hats, mus, sigmas = ca(embeddings)
                noise_size = c_pro_gan.latent_size - c_not_hats.shape[-1]

                z = th.randn(batch_size, noise_size).to(device)
                gan_input = th.cat((c_not_hats, z), dim=-1)

                # optimize the discriminator. The embeddings are detached, as
                # in the validated variant of this loop: otherwise the
                # discriminator's backward pass runs through the encoder graph
                # that the generator step below needs again.
                dis_loss = c_pro_gan.optimize_discriminator(
                    gan_input, images, embeddings.detach(), current_depth,
                    alpha, use_matching_aware_dis)

                # optimize the generator (with fresh noise):
                z = th.randn(batch_size, noise_size).to(device)
                gan_input = th.cat((c_not_hats, z), dim=-1)

                if encoder_optim is not None:
                    encoder_optim.zero_grad()

                ca_optim.zero_grad()
                gen_loss = c_pro_gan.optimize_generator(
                    gan_input, embeddings, current_depth, alpha)

                # optimize_generator also sends gradients to the Conditioning
                # Augmenter (and the TextEncoder), hence the zero_grad calls
                # above. Add the KL-divergence term of the CA and step both.
                kl_loss = th.mean(0.5 * th.sum(
                    (mus**2) + (sigmas**2) - th.log((sigmas**2)) - 1, dim=1))
                kl_loss.backward()
                ca_optim.step()
                if encoder_optim is not None:
                    encoder_optim.step()

                # provide a loss feedback
                if i % feedback_interval == 0 or i == 1:
                    print("batch: %d  d_loss: %f  g_loss: %f  kl_los: %f" %
                          (i, dis_loss, gen_loss, kl_loss.item()))

                    # also write the losses to the log file:
                    log_file = os.path.join(
                        log_dir, "loss_" + str(current_depth) + ".log")
                    with open(log_file, "a") as log:
                        log.write(
                            str(dis_loss) + "\t" + str(gen_loss) + "\t" +
                            str(kl_loss.item()) + "\n")

                    # create a grid of samples and save it
                    gen_img_file = os.path.join(
                        sample_dir, "gen_" + str(current_depth) + "_" +
                        str(epoch) + "_" + str(i) + ".png")
                    orig_img_file = os.path.join(
                        sample_dir, "orig_" + str(current_depth) + "_" +
                        str(epoch) + "_" + str(i) + ".png")
                    description_file = os.path.join(
                        sample_dir, "desc_" + str(current_depth) + "_" +
                        str(epoch) + "_" + str(i) + ".txt")

                    # visualisation only: no autograd graph needed
                    with th.no_grad():
                        samples = c_pro_gan.gen(gan_input, current_depth,
                                                alpha)
                    create_grid(
                        samples=samples,
                        scale_factor=scale_factor,
                        img_file=gen_img_file,
                        width=grid_width,
                    )

                    create_grid(samples=images,
                                scale_factor=scale_factor,
                                img_file=orig_img_file,
                                width=grid_width,
                                real_imgs=True)

                    create_descriptions_file(description_file, captions,
                                             dataset)

                # increment the ticker:
                ticker += 1

            stop = timeit.default_timer()
            print("Time taken for epoch: %.3f secs" % (stop - start))

            if epoch % checkpoint_factor == 0:
                # save the Model
                encoder_save_file = os.path.join(
                    save_dir, "Encoder_" + str(current_depth) + ".pth")
                ca_save_file = os.path.join(
                    save_dir,
                    "Condition_Augmentor_" + str(current_depth) + ".pth")
                gen_save_file = os.path.join(
                    save_dir, "GAN_GEN_" + str(current_depth) + ".pth")
                dis_save_file = os.path.join(
                    save_dir, "GAN_DIS_" + str(current_depth) + ".pth")

                if encoder_optim is not None:
                    th.save(encoder.state_dict(), encoder_save_file, pickle)
                th.save(ca.state_dict(), ca_save_file, pickle)
                th.save(c_pro_gan.gen.state_dict(), gen_save_file, pickle)
                th.save(c_pro_gan.dis.state_dict(), dis_save_file, pickle)

    print("Training completed ...")
コード例 #11
0
class PixelDA(_DLalgo):
    """
	Paradigm of GAN (keras implementation)

	1. Construct D
		1a) Compile D
	2. Construct G
	3. Set D.trainable = False
	4. Stack G and D, to construct GAN (combined model)
		 4a) Compile GAN
	
	Approved by fchollet: "the process you describe is in fact correct."

	See issue #4674 keras: https://github.com/keras-team/keras/issues/4674
	"""
    def __init__(self,
                 noise_size=(100, ),
                 use_PatchGAN=False,
                 use_Wasserstein=True,
                 batch_size=64,
                 **kwargs):
        """Store the PixelDA hyper-parameters; networks are built by build_all_model().

        :param noise_size: shape of the generator's noise input
        :param use_PatchGAN: if True the discriminator ends in a conv layer and
                             ``self.disc_patch`` is set to
                             (int(img_rows / 16), int(img_rows / 16), 1)
        :param use_Wasserstein: WGAN-GP mode; also sets ``critic_steps`` to 5
                                (1 otherwise)
        :param batch_size: batch size used by load_dataset
        :param kwargs: extra attributes set on the instance after all the
                       defaults below (so they override them). ``dataset_name``
                       ("CT" by default, or "MNIST") is consumed first because
                       it determines ``img_shape``.
        :raises ValueError: if ``dataset_name`` is neither "CT" nor "MNIST"
        """
        # Input shape. dataset_name used to be hard-coded to "CT" here. A
        # dataset_name kwarg was only applied after img_shape had been derived,
        # which left the two inconsistent. Read it first instead.
        self.dataset_name = kwargs.pop("dataset_name", "CT")  # "MNIST" or "CT"

        if self.dataset_name == "MNIST":
            self.img_shape = (32, 32, 3)
            self.num_classes = 10
        elif self.dataset_name == "CT":
            self.img_shape = (128, 128, 1)
        else:
            raise ValueError(
                "Only support two datasets for now. ('CT', 'MNIST')")

        self.img_rows, self.img_cols, self.channels = self.img_shape

        self.noise_size = noise_size  # (100,)
        self.batch_size = batch_size
        # Loss weights
        self.lambda_adv = 5  # before Exp5: 10 # Exp1: 20 # 17 MNIST-M
        self.lambda_seg = 1
        # Number of filters in first layer of discriminator and Segmenter
        self.df = 64
        self.sf = 64

        self.normalize_G = False
        self.normalize_D = False
        self.normalize_S = False

        # Number of residual blocks in the generator
        self.residual_blocks = 12  # before Exp5: 6 # 17 # 6 # NEW TODO 14/5/2018
        self.use_PatchGAN = use_PatchGAN  # False
        self.use_Wasserstein = use_Wasserstein
        self.use_He_initialization = False
        # Factory for kernel initializers: a fresh he_normal() when He
        # initialization is enabled, otherwise the "glorot_uniform" identifier.
        self.my_initializer = lambda: he_normal(
        ) if self.use_He_initialization else "glorot_uniform"  # TODO

        if self.use_PatchGAN:
            # Calculate output shape of D (PatchGAN)
            patch = int(self.img_rows / 2**4)
            self.disc_patch = (patch, patch, 1)

        if self.use_Wasserstein:
            self.critic_steps = 5  # 5 # 7 # 10
        else:
            self.critic_steps = 1

        self.GRADIENT_PENALTY_WEIGHT = 10  # 10 # 5 # 10 As the paper

        ##### Set up the other attributes (these override the defaults above)
        for key, value in kwargs.items():
            setattr(self, key, value)

    def build_all_model(self):
        """Build and compile every network used for training.

        Creates, in order:
        - the discriminator (``self.discriminator``);
        - in Wasserstein mode, the critic training model
          ``self.discriminator_model``. It takes the real (img_B) and fake
          inputs and outputs the ratings for real, fake and their random
          weighted average. It is compiled with two wasserstein_loss terms plus
          the gradient-penalty loss.
        - otherwise, the discriminator itself, compiled with MSE;
        - the generator (``self.generator``) and segmenter (``self.seg``);
        - the combined model ``self.combined``, [img_A, noise] ->
          [discriminator rating of fake_B, predicted mask], weighted by
          lambda_adv / lambda_seg.

        NOTE(review): load_pretrained_weights loads with by_name=True. Layer
        names that are not given explicitly are presumably auto-generated from
        construction order, so avoid reordering layer construction here.
        """

        # optimizer = Adam(0.0002, 0.5)
        # NOTE(review): this single Adam instance is passed to both compile()
        # calls below, so the two compiled models presumably share its state
        # (e.g. the iteration counter) — confirm this is intended.
        optimizer = Adam(0.0001, beta_1=0.5, beta_2=0.9)  # Exp4
        # optimizer = Adam(0.0001, beta_1=0.0, beta_2=0.9) # Exp3 of CT2XperCT

        # Build and compile the discriminators
        self.discriminator = self.build_discriminator()
        self.discriminator.name = "Discriminator"

        img_A = Input(shape=self.img_shape, name='source_image')  # real A
        img_B = Input(shape=self.img_shape, name='target_image')  # real B
        fake_img = Input(shape=self.img_shape, name="fake_image")  # fake B

        # We also need to generate weighted-averages of real and generated samples, to use for the gradient norm penalty.
        # (These tensors are built in both modes but only used in Wasserstein mode.)
        avg_img = RandomWeightedAverage()([img_B, fake_img])

        # The same discriminator (shared weights) rates real, fake and interpolated images.
        real_img_rating = self.discriminator(img_B)
        fake_img_rating = self.discriminator(fake_img)
        avg_img_output = self.discriminator(avg_img)

        # The gradient penalty loss function requires the input averaged samples to get gradients. However,
        # Keras loss functions can only have two arguments, y_true and y_pred. We get around this by making a partial()
        # of the function with the averaged samples here.
        partial_gp_loss = partial(
            gradient_penalty_loss,
            averaged_samples=avg_img,
            gradient_penalty_weight=self.GRADIENT_PENALTY_WEIGHT)
        partial_gp_loss.__name__ = 'gradient_penalty'  # Functions need names or Keras will throw an error

        if self.use_Wasserstein:
            # Critic training model: one output (and one loss) per rating above.
            self.discriminator_model = Model(
                inputs=[img_B, fake_img],  #, avg_img
                # loss_weights=[1,1,1], # useless, since we have multiply the penalization by GRADIENT_PENALTY_WEIGHT=10
                outputs=[real_img_rating, fake_img_rating, avg_img_output])
            self.discriminator_model.compile(
                loss=[wasserstein_loss, wasserstein_loss, partial_gp_loss],
                optimizer=optimizer,
                metrics=[my_critic_acc])
        else:
            self.discriminator.compile(loss='mse',
                                       optimizer=optimizer,
                                       metrics=['accuracy'])

        # For the combined model we will only train the generator and Segmenter.
        # This is set after the discriminator models were compiled above; the
        # class docstring describes this ordering as the intended Keras pattern.
        self.discriminator.trainable = False

        # Build the generator
        self.generator = self.build_generator()
        self.generator.name = "Generator"
        # Build the task (segmentation) network
        self.seg = self.build_segmenter()
        self.seg.name = "Segmenter"
        # Input images from both domains

        # Input noise
        noise = Input(shape=self.noise_size, name='noise_input')

        # Translate images from domain A to domain B
        fake_B = self.generator([img_A, noise])

        # Segment the translated image
        mask_pred = self.seg(fake_B)

        # Discriminator determines validity of translated images
        valid = self.discriminator(fake_B)  # fake_B_rating
        if self.use_Wasserstein:
            self.combined = Model(inputs=[img_A, noise],
                                  outputs=[valid, mask_pred])
            self.combined.compile(
                optimizer=optimizer,
                loss=[wasserstein_loss, dice_coef_loss],
                loss_weights=[self.lambda_adv, self.lambda_seg],
                metrics=['accuracy'])
        else:
            self.combined = Model([img_A, noise], [valid, mask_pred])
            self.combined.compile(
                loss=['mse', dice_coef_loss],
                loss_weights=[self.lambda_adv, self.lambda_seg],
                optimizer=optimizer,
                metrics=['accuracy'])

    def load_dataset(self,
                     dataset_name="CT",
                     domain_A_folder="output8",
                     domain_B_folder="output5_x_128",
                     data_root="/home/lulin/na4/src/output",
                     seed=17):
        """Create the data loader(s) for ``dataset_name``.

        "MNIST": sets ``self.data_loader`` (a DataLoader at img_rows x img_cols).
        "CT": sets ``self.Dataset_A`` and ``self.Dataset_B``. Each is a
        MyDataset over ``{data_root}/{folder}/train/bodys.npy`` and
        ``{data_root}/{folder}/train/liver_masks.npy``, without augmentation.

        :param dataset_name: "CT" or "MNIST"
        :param domain_A_folder: folder under ``data_root`` for the source domain
        :param domain_B_folder: folder under ``data_root`` for the target domain
        :param data_root: root directory of the CT outputs (previously a
                          hard-coded path; the default is unchanged)
        :param seed: seed forwarded to both MyDataset instances
        :raises ValueError: for any other ``dataset_name`` (this used to be
                            silently ignored, deferring the failure to training)
        """
        if dataset_name not in ("MNIST", "CT"):
            raise ValueError(
                "Only support two datasets for now. ('CT', 'MNIST')")
        self.dataset_name = dataset_name

        if self.dataset_name == "MNIST":
            # Configure MNIST and MNIST-M data loader
            self.data_loader = DataLoader(img_res=(self.img_rows,
                                                   self.img_cols))
            return

        def make_dataset(folder, domain):
            """Build a MyDataset over the train split stored in ``data_root/folder``."""
            train_dir = "{}/{}/train".format(data_root, folder)
            return MyDataset(
                paths=[train_dir + "/bodys.npy", train_dir + "/liver_masks.npy"],
                batch_size=self.batch_size,
                augment=False,
                seed=seed,
                domain=domain)

        self.Dataset_A = make_dataset(domain_A_folder, "A")
        self.Dataset_B = make_dataset(domain_B_folder, "B")

    def build_generator(self):
        """Build the ResNet generator mapping (image, noise) to a translated image.

        The noise vector is projected by a ReLU Dense layer to one extra
        img_rows x img_cols channel and concatenated with the image. The
        result then goes through:
        - a 64-filter ReLU conv;
        - max(1, self.residual_blocks) residual blocks;
        - a final tanh conv with ``self.channels`` filters.
        """
        use_norm = self.normalize_G

        def res_block(x):
            """conv -> [IN] -> ReLU -> conv -> [IN], plus the skip connection (64 filters)."""
            y = Conv2D(64,
                       kernel_size=3,
                       strides=1,
                       padding='same',
                       kernel_initializer=self.my_initializer())(x)
            if use_norm:
                y = InstanceNormalization()(y)
            y = Activation('relu')(y)
            y = Conv2D(64, kernel_size=3, strides=1, padding='same')(y)
            if use_norm:
                y = InstanceNormalization()(y)
            return Add()([y, x])

        img = Input(shape=self.img_shape, name='image_input')
        noise = Input(shape=self.noise_size, name='noise_input')

        # Turn the noise vector into one image-sized extra channel.
        noise_map = Dense(self.img_rows * self.img_cols,
                          activation="relu",
                          kernel_initializer=self.my_initializer())(noise)
        noise_map = Reshape((self.img_rows, self.img_cols, 1))(noise_map)
        features = keras.layers.concatenate([img, noise_map])

        features = Conv2D(64,
                          kernel_size=3,
                          padding='same',
                          activation='relu',
                          kernel_initializer=self.my_initializer())(features)

        # At least one residual block is always applied.
        for _ in range(max(1, self.residual_blocks)):
            features = res_block(features)

        translated = Conv2D(self.channels,
                            kernel_size=3,
                            padding='same',
                            activation='tanh')(features)
        return Model([img, noise], translated)

    def build_discriminator(self):
        """Build the discriminator / critic mapping an image to its validity.

        Body: four stride-2 4x4 conv + LeakyReLU(0.2) stages. They use
        self.df, 2x, 4x and 8x filters. InstanceNormalization follows stages
        2-4 when ``self.normalize_D`` is set.

        Head:
        - a 4x4 single-filter conv map when ``use_PatchGAN``;
        - otherwise one Dense unit on the flattened features, linear in
          Wasserstein mode and sigmoid otherwise.
        """
        def down(x, n_filters, normalize, kernel=4):
            """Stride-2 conv, LeakyReLU and optional InstanceNormalization."""
            x = Conv2D(n_filters, kernel_size=kernel, strides=2,
                       padding='same')(x)
            x = LeakyReLU(alpha=0.2)(x)
            return InstanceNormalization()(x) if normalize else x

        img = Input(shape=self.img_shape, name="image")

        hidden = img
        for stage, multiplier in enumerate((1, 2, 4, 8)):
            # the first stage is never normalized
            hidden = down(hidden, self.df * multiplier,
                          normalize=stage > 0 and self.normalize_D)

        if self.use_PatchGAN:
            validity = Conv2D(1, kernel_size=4, strides=1,
                              padding='same')(hidden)
        else:
            head_activation = None if self.use_Wasserstein else 'sigmoid'
            validity = Dense(1, activation=head_activation)(Flatten()(hidden))

        return Model(img, validity)

    def build_segmenter(self):
        """Return the U-Net segmenter that consumes ``self.img_shape`` images.

        Fixed architecture: depth 3, dropout 0.5, 32 filters in the first
        level, no transposed-conv upsampling; batch normalisation is enabled
        according to ``self.normalize_S``.
        """
        input_shape = self.img_shape
        unet_options = dict(depth=3,
                            dropout=0.5,
                            start_ch=32,
                            upconv=False,
                            batchnorm=self.normalize_S)
        return UNet(input_shape, **unet_options)

    def load_pretrained_weights(self,
                                weights_path="../Weights/all_weights.h5"):
        """Load saved weights into ``self.combined``, matching layers by name.

        :param weights_path: weights file passed to ``Model.load_weights``.
        """
        banner = "Loading pretrained weights from path: {} ...".format(
            weights_path)
        print(banner)

        combined_model = self.combined
        combined_model.load_weights(weights_path, by_name=True)
        print("+ Done.")

    def summary(self):
        """Print the Keras summary of each sub-model, separated by rules.

        Order: discriminator, generator, segmenter, the discriminator training
        model (only when ``self.use_Wasserstein``), then the combined model.
        """
        rule = "=" * 50

        def show(title, model_attr):
            # Attributes are looked up lazily so each header is printed
            # right before its model is accessed.
            print(rule)
            print(title)
            getattr(self, model_attr).summary()

        show("Discriminator summary:", "discriminator")
        show("Generator summary:", "generator")
        show("Segmenter summary:", "seg")
        if self.use_Wasserstein:
            show("Discriminator model summary:", "discriminator_model")
        show("Combined model summary:", "combined")

    def write_tensorboard_graph(self,
                                to_dir="../logs",
                                save_png2dir="../Model_graph"):
        """Write the discriminator graph to TensorBoard and plot model diagrams.

        The PNG diagrams (``Combined_model.png`` and
        ``Discriminator_model.png``) are best-effort: ``plot_model`` needs
        optional graph-drawing dependencies. A failing plot is reported and
        skipped instead of silently swallowed, and it no longer prevents the
        other plot from being attempted.

        :param to_dir: TensorBoard log directory.
        :param save_png2dir: directory for the PNG diagrams (created if missing).
        """
        if not os.path.exists(save_png2dir):
            os.makedirs(save_png2dir)

        # ``discriminator_model`` may not exist (``summary`` only touches it
        # when ``use_Wasserstein`` is set); fall back to the plain
        # discriminator instead of raising AttributeError.
        disc_model = getattr(self, "discriminator_model", None)
        if disc_model is None:
            disc_model = self.discriminator

        tensorboard = keras.callbacks.TensorBoard(log_dir=to_dir,
                                                  histogram_freq=0,
                                                  write_graph=True,
                                                  write_images=False)
        tensorboard.set_model(disc_model)

        plots = [(self.combined, "Combined_model.png"),
                 (disc_model, "Discriminator_model.png")]
        for model, filename in plots:
            try:
                plot_model(model, to_file=os.path.join(save_png2dir, filename))
            except Exception as err:
                # Diagrams are optional; report why one was skipped rather
                # than hiding the error (a bare except would also swallow
                # KeyboardInterrupt/SystemExit).
                print("Skipping model plot {}: {}".format(filename, err))

    def train(self,
              iterations,
              sample_interval=50,
              save_sample2dir="../samples/exp0",
              save_weights_path='../Weights/all_weights.h5',
              save_model=False,
              time_monitor=True):
        """Adversarially train G/D while training the segmenter on translated A.

        Each iteration does three things:
          * runs ``self.critic_steps`` discriminator updates on real domain-B
            vs. translated domain-A batches;
          * runs one update of ``self.combined`` with targets
            ``[valid, masks_A]``;
          * scores ``self.seg`` on a domain-B batch with ``dice_predict``.

        Every 10 iterations, it appends the ``d_loss`` / ``g_loss`` rows to
        ``D_Losses.csv`` / ``G_Losses.csv`` in the weights directory and
        prints a progress line. It also checkpoints when the mean test Dice of
        the last 100 iterations is the best (saved to ``save_weights_path``)
        or second best (saved to ``<root>_bis<ext>``) seen so far. After the
        loop, the weights are written to ``<root>_final<ext>``.

        :param iterations: number of training iterations.
        :param sample_interval: call ``sample_images`` every this many
            iterations.
        :param save_sample2dir: directory passed to ``sample_images``.
        :param save_weights_path: checkpoint path for the best model.
        :param save_model: if True, checkpoints use ``Model.save`` (full
            model); otherwise ``Model.save_weights``.
        :param time_monitor: append the wall time of the last 10 iterations
            to each progress line.
        """
        from collections import deque

        dirpath = os.path.dirname(save_weights_path)
        # A bare filename has no directory part; os.makedirs("") would raise.
        if dirpath and not os.path.exists(dirpath):
            os.makedirs(dirpath)
        self.save_config(save2path=os.path.join(dirpath, "config.dill"),
                         verbose=True)

        # Derive sibling checkpoint names from the actual extension instead
        # of assuming the path ends with ".h5".
        weights_root, weights_ext = os.path.splitext(save_weights_path)
        second_best_path = weights_root + "_bis" + weights_ext
        final_path = weights_root + "_final" + weights_ext

        # Segmentation Dice on the last 100 domain-B batches.
        test_accs = deque(maxlen=100)

        # Checkpoint bookkeeping (running-mean test Dice, in percent).
        best_test_cls_acc = 0
        second_best_cls_acc = -1

        st = time()
        elapsed_time = 0
        for iteration in range(iterations):
            if time_monitor and (iteration % 10 == 0) and (iteration > 0):
                et = time()
                elapsed_time = et - st
                st = et

            # ---------------------
            #  Train Discriminator
            # ---------------------
            for _ in range(self.critic_steps):
                if self.dataset_name == "MNIST":
                    imgs_A, _ = self.data_loader.load_data(
                        domain="A", batch_size=self.batch_size)
                    imgs_B, _ = self.data_loader.load_data(
                        domain="B", batch_size=self.batch_size)
                elif self.dataset_name == "CT":
                    imgs_A, _ = self.Dataset_A.next()
                    imgs_B, _ = self.Dataset_B.next()

                noise_prior = np.random.normal(
                    0, 1, (self.batch_size, self.noise_size[0]))

                # Translate images from domain A to domain B
                fake_B = self.generator.predict([imgs_A, noise_prior])
                # NOTE(review): with both use_PatchGAN and use_Wasserstein
                # set, dummy_y is never defined and the critic update below
                # raises NameError.
                if self.use_PatchGAN:
                    valid = np.ones((self.batch_size, ) + self.disc_patch)
                    fake = np.zeros((self.batch_size, ) + self.disc_patch)
                elif self.use_Wasserstein:
                    valid = np.ones((self.batch_size, 1))
                    # Critic targets are +1 for real and -1 for translated.
                    fake = -valid
                    # Placeholder target for the third critic output
                    # (presumably the gradient penalty, logged as GP_loss).
                    dummy_y = np.zeros((self.batch_size, 1))
                else:
                    valid = np.ones((self.batch_size, 1))
                    fake = np.zeros((self.batch_size, 1))

                if self.use_Wasserstein:
                    d_loss = self.discriminator_model.train_on_batch(
                        [imgs_B, fake_B], [valid, fake, dummy_y])
                else:
                    # Real and translated images in one stacked batch.
                    D_train_images = np.vstack([imgs_B, fake_B])
                    D_train_label = np.vstack([valid, fake])
                    d_loss = self.discriminator.train_on_batch(
                        D_train_images, D_train_label)

            # --------------------------------
            #  Train Generator and Segmenter
            # --------------------------------
            if self.dataset_name == "MNIST":
                imgs_A, _, masks_A = self.data_loader.load_data(
                    domain="A", batch_size=self.batch_size, return_mask=True)
                imgs_B, _, masks_B = self.data_loader.load_data(
                    domain="B", batch_size=self.batch_size, return_mask=True)
            elif self.dataset_name == "CT":
                imgs_A, masks_A = self.Dataset_A.next()
                imgs_B, masks_B = self.Dataset_B.next()

            # The generator wants the discriminator to label translations
            # as real.
            if self.use_PatchGAN:
                valid = np.ones((self.batch_size, ) + self.disc_patch)
            else:
                valid = np.ones((self.batch_size, 1))

            noise_prior = np.random.normal(
                0, 1, (self.batch_size, self.noise_size[0]))

            g_loss = self.combined.train_on_batch([imgs_A, noise_prior],
                                                  [valid, masks_A])

            # -----------------------
            #  Evaluation (domain B)
            # -----------------------
            pred_B = self.seg.predict(imgs_B)
            _, test_acc = dice_predict(masks_B, pred_B)
            test_accs.append(test_acc)  # deque drops the oldest beyond 100

            if iteration % 10 == 0:
                gen_loss = g_loss[1]
                clf_train_acc = 100 * float(g_loss[-1])
                clf_train_loss = g_loss[2]

                current_test_acc = 100 * float(test_acc)
                test_mean_acc = 100 * float(np.mean(test_accs))

                # Log the test scores alongside the generator losses.
                g_loss.append(current_test_acc)
                g_loss.append(test_mean_acc)

                with open(os.path.join(dirpath, "D_Losses.csv"),
                          "ab") as csv_file:
                    np.savetxt(csv_file,
                               np.array(d_loss).reshape(1, -1),
                               delimiter=",")
                with open(os.path.join(dirpath, "G_Losses.csv"),
                          "ab") as csv_file:
                    np.savetxt(csv_file,
                               np.array(g_loss).reshape(1, -1),
                               delimiter=",")

                if self.use_Wasserstein:
                    d_real_acc = 100 * float(d_loss[4])
                    d_fake_acc = 100 * float(d_loss[5])
                    d_train_acc = 100 * (float(d_loss[4]) +
                                         float(d_loss[5])) / 2
                    d_message = ("[D - loss: {:.5f}, GP_loss: {:.5f}, "
                                 "(+) acc: {:.2f}%, (-) acc: {:.2f}%, "
                                 "acc: {:.2f}%]").format(
                                     d_loss[0], d_loss[3], d_real_acc,
                                     d_fake_acc, d_train_acc)
                else:
                    # The plain discriminator's result is read as
                    # [loss, acc]; there is no GP term or per-class accuracy.
                    d_train_acc = 100 * float(d_loss[1])
                    d_message = "[D - loss: {:.5f}, acc: {:.2f}%]".format(
                        d_loss[0], d_train_acc)

                message = ("{} : {}, [G - loss: {:.5f}], [seg - loss: {:.5f}, "
                           "acc: {:.2f}%, test_dice: {:.2f}% ({:.2f}%)]").format(
                               iteration, d_message, gen_loss, clf_train_loss,
                               clf_train_acc, current_test_acc, test_mean_acc)

                if test_mean_acc > best_test_cls_acc:
                    second_best_cls_acc = best_test_cls_acc
                    best_test_cls_acc = test_mean_acc

                    if save_model:
                        self.combined.save(save_weights_path)
                    else:
                        self.combined.save_weights(save_weights_path)
                    message += "  (best)"

                elif test_mean_acc > second_best_cls_acc:
                    second_best_cls_acc = test_mean_acc

                    # Never overwrite the best checkpoint with the runner-up.
                    if save_model:
                        self.combined.save(second_best_path)
                    else:
                        self.combined.save_weights(second_best_path)
                    message += "  (second best)"

                if time_monitor:
                    message += "... {:.2f}s.".format(elapsed_time)
                print(message)

            # If at save interval => save generated image samples
            if iteration % sample_interval == 0:
                self.sample_images(iteration, save2dir=save_sample2dir)

        self.combined.save_weights(final_path)

    def sample_images(self, iterations, save2dir="../samples"):
        """Save a 4-row grid of source images and their translations.

        Row 0 shows the original domain-A images and row 1 the generator
        outputs. Row 2 shows those outputs after ``apply_adapt_hist()``, and
        row 3 repeats them with the ground-truth mask overlaid. The figure is
        written to ``<save2dir>/<iterations>.png``.

        NOTE(review): only the "CT" branch defines ``masks_A``; with
        ``dataset_name == "MNIST"`` the mask-overlay row raises NameError.
        """
        if not os.path.exists(save2dir):
            os.makedirs(save2dir)

        if self.dataset_name == "MNIST":
            n_rows, n_cols = 5, 10
            imgs_A, _ = self.data_loader.load_data(domain="A",
                                                   batch_size=n_cols)
        elif self.dataset_name == "CT":
            n_rows, n_cols = 2, 5
            assert n_rows == 2
            imgs_A, masks_A = self.Dataset_A.next()
            imgs_A = imgs_A[:n_cols]
            masks_A = np.squeeze(masks_A[:n_cols])

        n_sample = imgs_A.shape[0]  # == n_cols

        # Originals first, then (n_rows - 1) batches of translations.
        panels = [imgs_A]
        for _ in range(n_rows - 1):
            noise_prior = np.random.normal(0, 1,
                                           (n_sample, self.noise_size[0]))
            panels.append(self.generator.predict([imgs_A, noise_prior]))
        gen_imgs = np.concatenate(panels)

        if self.dataset_name == "MNIST":
            # Rescale images from (-1, 1) to (0, 1)
            gen_imgs = 0.5 * gen_imgs + 0.5
        elif self.dataset_name == "CT":
            gen_imgs = np.squeeze(gen_imgs)

        n_rows = 4
        fig, axs = plt.subplots(n_rows, n_cols,
                                figsize=(3 * n_cols, 3 * n_rows))

        # Rows 0-1: originals and first translations, in row-major order.
        for idx in range(2 * n_cols):
            ax = axs[idx // n_cols, idx % n_cols]
            ax.imshow(gen_imgs[idx], cmap="gray")
            ax.axis('off')

        # Rows 2-3: translated images with adaptive histogram equalisation,
        # the last row additionally overlaid with the ground-truth mask.
        for col in range(n_cols):
            translated = gen_imgs[n_cols + col]
            axs[2, col].imshow(apply_adapt_hist()(translated), cmap="gray")
            axs[2, col].axis('off')
            axs[3, col].imshow(apply_adapt_hist()(translated), cmap="gray")
            axs[3, col].imshow(masks_A[col],
                               aspect="equal",
                               cmap="Blues",
                               alpha=0.4)
            axs[3, col].axis('off')

        fig.savefig(os.path.join(save2dir, "{}.png".format(iterations)))
        plt.close()

    def train_segmenter(self,
                        iterations,
                        batch_size=32,
                        noise_range=5,
                        save_weights_path=None):
        """Fine-tune the segmenter on translated domain-A images with G frozen.

        Currently disabled: the first statement raises
        ``ValueError("Not modified yet.")``, so everything below is
        unreachable legacy code. It reads batches through
        ``self.data_loader`` (the loader that ``train`` only uses for
        "MNIST") and has not been ported to the CT ``Dataset_A``/``Dataset_B``
        pipeline.

        :param iterations: number of training steps.
        :param batch_size: batch size for training and evaluation.
        :param noise_range: noise is drawn uniformly from
            ``[-noise_range, noise_range)``.
        :param save_weights_path: if given, the best / second-best weights of
            the segmentation model are written here (the second best with a
            ``_bis.h5`` suffix replacing the last 3 characters).
        :raises ValueError: always, while the method is disabled.
        """
        raise ValueError("Not modified yet.")
        # --- Legacy code below is unreachable. ---
        if save_weights_path is not None:
            dirpath = "/".join(save_weights_path.split("/")[:-1])
            if not os.path.exists(dirpath):
                os.makedirs(dirpath)
        optimizer = Adam(0.000001, beta_1=0.0, beta_2=0.9)

        # Input noise
        noise = Input(shape=self.noise_size, name='noise_input_seg')
        img_A = Input(shape=self.img_shape, name='source_image_seg')
        # Translate images from domain A to domain B
        fake_B = self.generator([img_A, noise])

        # Segment the translated image
        mask_pred = self.seg(fake_B)

        # Freeze G (before compile) so this model only updates the segmenter.
        self.generator.trainable = False

        self.segmentation_model = Model(inputs=[img_A, noise],
                                        outputs=[mask_pred])
        self.segmentation_model.compile(loss=dice_coef_loss,
                                        optimizer=optimizer,
                                        metrics=["acc", dice_coef])

        self.segmentation_model.name = "U-net (freeze Generator)"
        self.segmentation_model.summary()

        best_test_dice = 0.0
        second_best_test_dice = -1.0
        # Test Dice of the last 100 evaluations (one evaluation per 100 steps).
        collections = []
        for e in range(iterations):
            # Uniform noise in [-noise_range, noise_range); note this rebinds
            # the name `noise` that held the Keras Input above.
            noise = (2 * np.random.random(
                (batch_size, self.noise_size[0])) - 1) * noise_range

            images_A, _, mask_A = self.data_loader.load_data(
                domain="A", batch_size=batch_size, return_mask=True)

            s_loss = self.segmentation_model.train_on_batch([images_A, noise],
                                                            mask_A)

            if e % 100 == 0:
                # Evaluate the segmenter directly on real domain-B images.
                images_B, _, mask_B = self.data_loader.load_data(
                    domain="B", batch_size=batch_size, return_mask=True)

                pred_mask_B = self.seg.predict(images_B)
                _, current_test_dice = dice_predict(mask_B, pred_mask_B)

                if len(collections) >= 100:
                    collections.pop(0)
                collections.append(current_test_dice)
                mean_dice = np.mean(collections)
                message = "{} dice loss: {:.3f}; acc: {:.5f}; mean dice (test): {:.3f}".format(
                    e, 100 * s_loss[0], s_loss[1], 100 * mean_dice)

                if mean_dice > best_test_dice:
                    best_test_dice = mean_dice
                    message += "  (best)"
                    if save_weights_path is not None:
                        self.segmentation_model.save_weights(save_weights_path)

                elif mean_dice > second_best_test_dice:
                    second_best_test_dice = mean_dice
                    message += "  (second best)"
                    if save_weights_path is not None:
                        self.segmentation_model.save_weights(
                            save_weights_path[:-3] + "_bis.h5")
                else:
                    pass
                print(message)

        return

    def deploy_transform(self,
                         save2file="../domain_adapted/generated.npy",
                         stop_after=None):
        """Translate domain-A images with one shared noise vector and save them.

        Currently disabled: raises ``ValueError("Not modified yet.")`` on
        entry, and the legacy body is additionally blocked by
        ``assert 1 == 2``. The body reads ``self.data_loader.mnist_X``
        (MNIST pipeline).

        :param save2file: ``.npy`` output for the translated images; the noise
            vector is saved as ``noise_vectors.npy`` in the same directory.
        :param stop_after: number of images to translate, truncated to a
            multiple of 32 (for positive values).
        :raises ValueError: always, while the method is disabled.
        """
        raise ValueError("Not modified yet.")
        # --- Legacy code below is unreachable. ---
        dirpath = "/".join(save2file.split("/")[:-1])
        if not os.path.exists(dirpath):
            os.makedirs(dirpath)

        # Same value as `dirpath` above.
        dirname = "/".join(save2file.split("/")[:-1])

        # NOTE(review): with stop_after=None, predict_steps stays None and
        # `32 * predict_steps` below would raise TypeError.
        if stop_after is not None:
            predict_steps = int(stop_after / 32)
        else:
            predict_steps = stop_after

        # A single noise vector, tiled across every image below.
        noise_vec = np.random.normal(0, 1, self.noise_size[0])
        # Unconditional failure: a second guard against running this path.
        assert 1 == 2
        # np.random.rand(n_sample, self.noise_size[0]) # TODO 6/5/2018

        print("Performing Pixel-level domain adaptation on original images...")
        adaptaed_images = self.generator.predict(
            [
                self.data_loader.mnist_X[:32 * predict_steps],
                np.tile(noise_vec, (32 * predict_steps, 1))
            ],
            batch_size=32)  #, steps=predict_steps
        # self.data_loader.mnistm_X[:stop_after]
        print("+ Done.")
        print("Saving transformed images to file {}".format(save2file))
        np.save(save2file, adaptaed_images)

        noise_vec_filepath = os.path.join(dirname, "noise_vectors.npy")
        print(
            "Saving random noise (seed) to file {}".format(noise_vec_filepath))
        np.save(noise_vec_filepath, noise_vec)

        print("+ All done.")

    def deploy_debug(self,
                     save2file="../domain_adapted/debug.npy",
                     sample_size=100,
                     noise_number=128,
                     use_sobol=False,
                     use_linear=False,
                     use_sphere=False,
                     use_uniform_linear=False,
                     use_zeros=False,
                     seed=17):
        """Translate each of ``sample_size`` domain-A images under many noises.

        Currently disabled: raises ``ValueError("Not modified yet.")`` on
        entry. The legacy body below is unreachable and still reads images
        through ``self.data_loader`` (MNIST pipeline).

        For every image, ``noise_number`` noise vectors are built according to
        the first enabled flag, checked in this order:

        * ``use_sobol``: Sobol quasi-random points mapped to ``[-5, 5]``
          (assuming ``i4_sobol_generate`` returns points in ``[0, 1]``);
        * ``use_linear``: each vector repeats one uniform scalar in ``[-3, 3)``;
        * ``use_sphere``: random directions normalised to unit length;
        * ``use_uniform_linear``: each vector repeats one of ``noise_number``
          evenly spaced scalars in ``[-10, 10]``;
        * ``use_zeros``: all-zero noise;
        * default: Gaussian noise with mean 0 and std 3.

        All translations are stacked into one array of shape
        ``(sample_size, noise_number, ...)`` and saved to ``save2file``.

        :param seed: seed applied to ``np.random`` before sampling.
        :raises ValueError: always, while the method is disabled.
        """
        raise ValueError("Not modified yet.")
        # --- Legacy code below is unreachable. ---
        dirpath = "/".join(save2file.split("/")[:-1])
        if not os.path.exists(dirpath):
            os.makedirs(dirpath)

        dirname = "/".join(save2file.split("/")[:-1])

        np.random.seed(seed=seed)

        # np.random.rand(n_sample, self.noise_size[0]) # TODO 6/5/2018

        print("Performing Pixel-level domain adaptation on original images...")
        # Earlier variant, kept for reference:
        # noise_vec = np.random.normal(0,1, (sample_size, self.noise_size[0]))
        # collections = []
        # for i in range(sample_size):
        #     adaptaed_images = self.generator.predict([self.data_loader.mnist_X[:15], np.tile(noise_vec[i], (15,1))], batch_size=15)

        #     collections.append(adaptaed_images)
        collections = []
        imgs_A, labels_A = self.data_loader.load_data(domain="A",
                                                      batch_size=sample_size)

        for i in tqdm(range(sample_size)):
            if use_sobol:
                # Skips i * noise_number points so every image gets a fresh
                # Sobol block (assumed skip semantics — TODO confirm).
                noise_vec = 5 * (2 * i4_sobol_generate(
                    self.noise_size[0], noise_number, i * noise_number).T - 1)
            elif use_linear:
                tangents = 3.0 * (2 * np.random.random((noise_number, 1)) - 1)
                noise_vec = np.ones(
                    (noise_number, self.noise_size[0])) * tangents
            elif use_sphere:
                # NOTE(review): 2 * (random - 1) lies in [-2, 0), so every
                # normalised direction falls in the negative orthant;
                # 2 * random - 1 was probably intended.
                noise_vec = 2 * (np.random.random(
                    (noise_number, self.noise_size[0])) - 1)
                norm_vec = np.linalg.norm(noise_vec, axis=-1)
                noise_vec = noise_vec / norm_vec[:, np.newaxis]
            elif use_uniform_linear:
                tangents = 10.0 * np.linspace(-1, 1, noise_number)[:,
                                                                   np.newaxis]
                noise_vec = np.ones(
                    (noise_number, self.noise_size[0])) * tangents
            elif use_zeros:

                noise_vec = np.zeros((noise_number, self.noise_size[0]))

            else:
                noise_vec = np.random.normal(
                    0, 3, (noise_number, self.noise_size[0]))
            # Same source image repeated once per noise vector.
            adaptaed_images = self.generator.predict(
                [np.tile(imgs_A[i], (noise_number, 1, 1, 1)), noise_vec],
                batch_size=32)
            collections.append(adaptaed_images)

        print("+ Done.")

        print("Saving transformed images to file {}".format(save2file))
        np.save(save2file, np.stack(collections))
        print("+ All done.")

    def deploy_segmentation(self, batch_size=32):
        """Evaluate the segmenter on the full domain-B set with a 99% CI.

        For "MNIST" the score is per-sample accuracy: the ``argmax`` of the
        ``self.seg`` predictions compared with ``data_loader.mnistm_y``. For
        "CT" it is the Dice score from ``dice_predict`` on
        ``Dataset_B.X_train`` / ``Y_train``. The interval is the normal
        approximation ``mean +/- 2.576 * std / sqrt(N)``.

        :param batch_size: batch size passed to ``self.seg.predict``.
        :return: ``(mean, lower_bound, upper_bound)``.
        :raises ValueError: if ``self.dataset_name`` is neither "MNIST" nor
            "CT".
        """
        print("Predicting ... ")
        if self.dataset_name == "MNIST":
            pred_B = self.seg.predict(self.data_loader.mnistm_X,
                                      batch_size=batch_size)
            precision = (np.argmax(pred_B,
                                   axis=1) == self.data_loader.mnistm_y)
            Moy = np.mean(precision)
            Std = np.std(precision)
            score_label = "Unsupervised MNIST-M segmentation accuracy"
        elif self.dataset_name == "CT":
            pred_B = self.seg.predict(self.Dataset_B.X_train,
                                      batch_size=batch_size)
            gt_B = self.Dataset_B.Y_train
            dice_all, dice_mean = dice_predict(gt_B, pred_B)
            Moy = dice_mean
            Std = np.std(dice_all)
            # The CT score is a Dice coefficient, not MNIST-M accuracy.
            score_label = "Unsupervised CT segmentation dice"
        else:
            # Previously fell through to a confusing NameError on pred_B.
            raise ValueError(
                "Unsupported dataset_name: {!r}".format(self.dataset_name))
        print("+ Done.")

        N_samples = len(pred_B)

        # 2.576 is the two-sided 99% quantile of the standard normal.
        half_width = 2.576 * Std / np.sqrt(N_samples)
        lower_bound = Moy - half_width
        upper_bound = Moy + half_width
        print("=" * 50)
        print("{} : {}".format(score_label, Moy))
        print("Confidence interval (99%) [{}, {}]".format(
            lower_bound, upper_bound))
        print("Length of confidence interval 99%: {}".format(upper_bound -
                                                             lower_bound))
        print("=" * 50)
        print("+ All done.")
        return Moy, lower_bound, upper_bound

    def deploy_demo_only(self,
                         save2file="../domain_adapted/WGAN_GP/Exp4/demo.npy",
                         sample_size=25,
                         noise_number=512,
                         linspace_size=10.0):
        """Sweep a 1-D line of noise vectors over one batch of domain-A images.

        Currently disabled: raises ``ValueError("Not modified yet.")`` on
        entry. The legacy body below is unreachable and still uses
        ``self.data_loader`` (MNIST pipeline).

        Noise vector ``k`` has every coordinate equal to the k-th point of
        ``linspace(-linspace_size, linspace_size, noise_number)``. The
        ``noise_number`` translated batches are stacked and saved to
        ``save2file``.

        :raises ValueError: always, while the method is disabled.
        """
        raise ValueError("Not modified yet.")
        # --- Legacy code below is unreachable. ---
        imgs_A, labels_A = self.data_loader.load_data(domain="A",
                                                      batch_size=sample_size)

        line = linspace_size * np.linspace(-1, 1, noise_number)
        noise_vec = np.ones(
            (noise_number, self.noise_size[0])) * line[:, np.newaxis]

        collections = [
            self.generator.predict(
                [imgs_A, np.tile(noise_vec[k], (sample_size, 1))],
                batch_size=sample_size)
            for k in tqdm(range(noise_number))
        ]
        print("+ Done.")

        print("Saving transformed images to file {}".format(save2file))
        np.save(save2file, np.stack(collections))
        print("+ All done.")

    def deploy_cherry_pick(
            self,
            save2file="../domain_adapted/WGAN_GP/Exp4/demo_cherry_picked.png",
            sample_size=25,
            noise_number=25,
            linspace_size=5.0):
        """Render a 5x5 grid of domain-A images, each with its own noise level.

        Currently disabled: raises ``ValueError("Not modified yet.")`` on
        entry. The legacy body below is unreachable and still uses
        ``self.data_loader`` (MNIST pipeline).

        The noise vectors are constant along every coordinate, evenly spaced
        in ``[-linspace_size, linspace_size]`` and shuffled. They are paired
        one-to-one with the images, so ``noise_number`` must equal
        ``sample_size``. The grid is saved to ``save2file``.

        :raises ValueError: always, while the method is disabled.
        """
        raise ValueError("Not modified yet.")
        # --- Legacy code below is unreachable. ---
        imgs_A, labels_A = self.data_loader.load_data(domain="A",
                                                      batch_size=sample_size)
        assert noise_number == sample_size

        line = linspace_size * np.linspace(-1, 1, noise_number)
        noise_vec = np.ones(
            (noise_number, self.noise_size[0])) * line[:, np.newaxis]

        np.random.shuffle(noise_vec)  # shuffle background color !

        adapted = self.generator.predict([imgs_A, noise_vec],
                                         batch_size=sample_size)
        # Generator ends in tanh: map (-1, 1) to (0, 1) for imshow.
        adapted = (adapted + 1) / 2
        print("+ Done.")

        print("Saving transformed images to file {}".format(save2file))
        n_rows = n_cols = 5
        fig, axs = plt.subplots(n_rows, n_cols,
                                figsize=(5 * n_cols, 5 * n_rows))
        for idx in range(n_rows * n_cols):
            ax = axs[idx // n_cols, idx % n_cols]
            ax.imshow(adapted[idx])
            ax.axis('off')
        plt.savefig(save2file)
        plt.close()
        print("+ All done.")
コード例 #12
0
class PixelDA(object):
    """Pixel-level domain adaptation GAN (Keras implementation).

    A generator translates source-domain (A) images, conditioned on a noise
    vector, into the target domain (B). A discriminator (or WGAN-GP critic)
    rates real-B versus translated images, and a classifier is trained on
    the translated images with the source labels.

    Keras GAN recipe followed by ``build_all_model``:

    1. Construct D.
       1a) Compile D.
    2. Construct G.
    3. Set ``D.trainable = False``.
    4. Stack G and D to construct the GAN (combined model).
       4a) Compile the GAN.

    Approved by fchollet: "the process you describe is in fact correct."
    See keras issue #4674: https://github.com/keras-team/keras/issues/4674
    """
    def __init__(self,
                 noise_size=100,
                 use_PatchGAN=False,
                 use_Wasserstein=True):
        """Configure shapes and training schedule (no models are built here).

        Args:
            noise_size: length of the generator's noise vector.
            use_PatchGAN: if True, D ends in a convolution producing a
                ``disc_patch``-shaped map instead of a single score.
            use_Wasserstein: if True, D is trained as a WGAN-GP critic.
        """
        # Input shape: 32x32 RGB images, 10 classes (MNIST / MNIST-M).
        self.img_rows = 32
        self.img_cols = 32
        self.channels = 3
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.num_classes = 10
        self.noise_size = (noise_size, )  # e.g. (100,)

        # Output shape of D when use_PatchGAN is set: build_discriminator
        # applies four stride-2 convolutions, hence the 2**4 reduction.
        patch = int(self.img_rows / 2**4)
        self.disc_patch = (patch, patch, 1)

        # Number of residual blocks in the generator
        self.residual_blocks = 6
        self.use_PatchGAN = use_PatchGAN
        self.use_Wasserstein = use_Wasserstein

        # A Wasserstein critic gets several updates per generator update.
        if self.use_Wasserstein:
            self.critic_steps = 10
        else:
            self.critic_steps = 1

        self.GRADIENT_PENALTY_WEIGHT = 10  # as in the WGAN-GP paper

    def build_all_model(self, batch_size=32):
        """Build and compile D (or its critic wrapper), G, the classifier and the GAN.

        Args:
            batch_size: stored as ``self.batch_size`` and passed to
                ``RandomWeightedAverage``; presumably the critic model must
                then be trained with exactly this batch size (TODO confirm).
        """
        self.batch_size = batch_size
        # Loss weights of the combined model: adversarial vs. classification
        lambda_adv = 7
        lambda_clf = 1
        optimizer = Adam(0.0001, beta_1=0.5, beta_2=0.9)

        # Number of filters in first layer of discriminator and classifier
        self.df = 128
        self.cf = 64

        # Build the discriminator
        self.discriminator = self.build_discriminator()
        self.discriminator.name = "Discriminator"

        img_A = Input(shape=self.img_shape, name='source_image')  # real A
        img_B = Input(shape=self.img_shape, name='target_image')  # real B
        fake_img = Input(shape=self.img_shape, name="fake_image")  # fake B

        # Weighted averages of real and generated samples, used by the
        # gradient norm penalty.
        avg_img = RandomWeightedAverage(batch_size=self.batch_size)(
            [img_B, fake_img])

        real_img_rating = self.discriminator(img_B)  # TODO: img_A?
        fake_img_rating = self.discriminator(fake_img)
        avg_img_output = self.discriminator(avg_img)

        # Keras losses only receive (y_true, y_pred), so the averaged samples
        # are bound into the gradient penalty loss with partial().
        partial_gp_loss = partial(
            gradient_penalty_loss,
            averaged_samples=avg_img,
            gradient_penalty_weight=self.GRADIENT_PENALTY_WEIGHT)
        # Keras throws an error for loss functions without a name.
        partial_gp_loss.__name__ = 'gradient_penalty'

        if self.use_Wasserstein:
            # Critic model with three outputs: real rating, fake rating and
            # the penalty output. The penalty is already scaled by
            # GRADIENT_PENALTY_WEIGHT, so no loss_weights are needed.
            self.discriminator_model = Model(
                inputs=[img_B, fake_img],
                outputs=[real_img_rating, fake_img_rating, avg_img_output])
            self.discriminator_model.compile(
                loss=[wasserstein_loss, wasserstein_loss, partial_gp_loss],
                optimizer=optimizer,
                metrics=[my_critic_acc])
        else:
            self.discriminator.compile(loss='mse',
                                       optimizer=optimizer,
                                       metrics=['accuracy'])

        # For the combined model we will only train the generator and classifier
        self.discriminator.trainable = False

        # Build the generator
        self.generator = self.build_generator()
        self.generator.name = "Generator"
        # Build the task (classification) network
        self.clf = self.build_classifier()
        self.clf.name = "Classifier"

        # Input noise
        noise = Input(shape=self.noise_size, name='noise_input')

        # Translate images from domain A to domain B
        fake_B = self.generator([img_A, noise])

        # Classify the translated image
        class_pred = self.clf(fake_B)

        # Discriminator determines validity of translated images
        valid = self.discriminator(fake_B)  # fake_B_rating
        if self.use_Wasserstein:
            self.combined = Model(inputs=[img_A, noise],
                                  outputs=[valid, class_pred])
            self.combined.compile(
                optimizer=optimizer,
                loss=[wasserstein_loss, 'categorical_crossentropy'],
                loss_weights=[lambda_adv, lambda_clf],
                metrics=['accuracy'])
        else:
            self.combined = Model([img_A, noise], [valid, class_pred])
            self.combined.compile(loss=['mse', 'categorical_crossentropy'],
                                  loss_weights=[lambda_adv, lambda_clf],
                                  optimizer=optimizer,
                                  metrics=['accuracy'])

    def load_dataset(self):
        """Attach the MNIST / MNIST-M DataLoader at the model's image size."""
        self.data_loader = DataLoader(img_res=(self.img_rows, self.img_cols))

    def build_generator(self):
        """Build the ResNet generator: (image, noise) -> translated image.

        The noise is projected to one extra ``img_rows x img_cols`` channel,
        concatenated to the image, passed through ``self.residual_blocks``
        residual blocks, and mapped back to ``self.channels`` with tanh.
        """
        def residual_block(layer_input):
            """Residual block described in paper"""
            d = Conv2D(64, kernel_size=3, strides=1,
                       padding='same')(layer_input)
            d = BatchNormalization(momentum=0.8)(d)
            d = Activation('relu')(d)
            d = Conv2D(64, kernel_size=3, strides=1, padding='same')(d)
            d = BatchNormalization(momentum=0.8)(d)
            d = Add()([d, layer_input])
            return d

        # Image input
        img = Input(shape=self.img_shape, name='image_input')

        # Noise input, projected to exactly one image-sized channel
        # (1024 units for the default 32x32 images).
        noise = Input(shape=self.noise_size, name='noise_input')
        noise_layer = Dense(self.img_rows * self.img_cols,
                            activation="relu")(noise)
        noise_layer = Reshape((self.img_rows, self.img_cols, 1))(noise_layer)
        conditioned_img = keras.layers.concatenate([img, noise_layer])

        l1 = Conv2D(64, kernel_size=3, padding='same',
                    activation='relu')(conditioned_img)

        # Propagate signal through residual blocks
        r = residual_block(l1)
        for _ in range(self.residual_blocks - 1):
            r = residual_block(r)

        output_img = Conv2D(self.channels,
                            kernel_size=3,
                            padding='same',
                            activation='tanh')(r)

        return Model([img, noise], output_img)

    def build_discriminator(self):
        """Build D: four stride-2 conv blocks, then a score head.

        The head is a 1-channel conv map (PatchGAN), a linear Dense(1)
        (Wasserstein critic) or a sigmoid Dense(1) (standard GAN).
        """
        def d_layer(layer_input, filters, f_size=4, normalization=True):
            """Discriminator layer"""
            d = Conv2D(filters, kernel_size=f_size, strides=2,
                       padding='same')(layer_input)
            d = LeakyReLU(alpha=0.2)(d)
            if normalization:
                d = InstanceNormalization()(d)
            return d

        img = Input(shape=self.img_shape, name="image")

        d1 = d_layer(img, self.df, normalization=False)
        d2 = d_layer(d1, self.df * 2, normalization=True)
        d3 = d_layer(d2, self.df * 4, normalization=True)
        d4 = d_layer(d3, self.df * 8, normalization=True)

        if self.use_PatchGAN:
            validity = Conv2D(1, kernel_size=4, strides=1, padding='same')(d4)
        elif self.use_Wasserstein:
            # Unbounded critic score (no activation).
            validity = Dense(1, activation=None)(Flatten()(d4))
        else:
            validity = Dense(1, activation='sigmoid')(Flatten()(d4))

        return Model(img, validity)

    def build_classifier(self):
        """Build the task classifier: five stride-2 conv blocks + softmax."""
        def clf_layer(layer_input, filters, f_size=4, normalization=True):
            """Classifier layer"""
            d = Conv2D(filters, kernel_size=f_size, strides=2,
                       padding='same')(layer_input)
            d = LeakyReLU(alpha=0.2)(d)
            if normalization:
                d = InstanceNormalization()(d)
            return d

        img = Input(shape=self.img_shape, name='image_input')

        c1 = clf_layer(img, self.cf, normalization=False)
        c2 = clf_layer(c1, self.cf * 2)
        c3 = clf_layer(c2, self.cf * 4)
        c4 = clf_layer(c3, self.cf * 8)
        c5 = clf_layer(c4, self.cf * 8)

        class_pred = Dense(self.num_classes,
                           activation='softmax')(Flatten()(c5))

        return Model(img, class_pred)

    def load_pretrained_weights(self,
                                weights_path="../Weights/all_weights.h5"):
        """Load weights into the combined model, matching layers by name."""
        print("Loading pretrained weights from path: {} ...".format(
            weights_path))

        self.combined.load_weights(weights_path, by_name=True)
        print("+ Done.")

    def summary(self):
        """Print the Keras summary of every built model."""
        print("=" * 50)
        print("Discriminator summary:")
        self.discriminator.summary()
        print("=" * 50)
        print("Generator summary:")
        self.generator.summary()
        print("=" * 50)
        print("Classifier summary:")
        self.clf.summary()

        if self.use_Wasserstein:
            print("=" * 50)
            print("Discriminator model summary:")
            self.discriminator_model.summary()
        print("=" * 50)
        print("Combined model summary:")
        self.combined.summary()

    def write_tensorboard_graph(self,
                                to_dir="../logs",
                                save_png2dir="../Model_graph"):
        """Write the critic graph for TensorBoard and plot model diagrams.

        Plotting is best-effort: a failing ``plot_model`` call (for example
        when optional plotting dependencies are missing) is reported and
        otherwise ignored.
        """
        if not os.path.exists(save_png2dir):
            os.makedirs(save_png2dir)
        # Without use_Wasserstein no separate critic model is built, so fall
        # back to the plain discriminator instead of raising AttributeError.
        critic = getattr(self, "discriminator_model", self.discriminator)
        tensorboard = keras.callbacks.TensorBoard(log_dir=to_dir,
                                                  histogram_freq=0,
                                                  write_graph=True,
                                                  write_images=False)
        tensorboard.set_model(critic)
        try:
            plot_model(self.combined,
                       to_file=os.path.join(save_png2dir,
                                            "Combined_model.png"))
            plot_model(critic,
                       to_file=os.path.join(save_png2dir,
                                            "Discriminator_model.png"))
        except Exception as err:
            print("Could not plot model graphs: {}".format(err))

    @staticmethod
    def _ensure_parent_dir(path):
        """Create the parent directory of ``path`` if missing and return it.

        Returns "" (and creates nothing) when ``path`` has no directory
        component, since ``os.makedirs("")`` would raise.
        """
        dirpath = os.path.dirname(path)
        if dirpath and not os.path.exists(dirpath):
            os.makedirs(dirpath)
        return dirpath

    @staticmethod
    def _bis_path(path):
        """Return ``path`` with "_bis" inserted before its extension.

        "w.h5" -> "w_bis.h5"; a path without an extension gets "_bis.h5".
        """
        root, ext = os.path.splitext(path)
        return root + "_bis" + (ext if ext else ".h5")

    def _save_checkpoint(self, path, save_model):
        """Save the whole combined model (save_model=True) or only its weights."""
        if save_model:
            self.combined.save(path)
        else:
            self.combined.save_weights(path)

    def train(self,
              epochs,
              batch_size=32,
              sample_interval=50,
              save_sample2dir="../samples/exp0",
              save_weights_path='../Weights/all_weights.h5',
              save_model=False):
        """Alternate discriminator and generator/classifier updates.

        Each "epoch" is ``self.critic_steps`` discriminator updates followed
        by one generator/classifier update. Every 10 epochs the loss rows are
        appended to ``D_Losses.csv`` / ``G_Losses.csv`` in the directory of
        ``save_weights_path``, and the model is checkpointed when the mean
        domain-B accuracy over the last 100 epochs beats the best (saved to
        ``save_weights_path``) or second-best (saved to the ``_bis`` path)
        value seen so far.

        Args:
            epochs: number of generator updates.
            batch_size: images per domain per update; in Wasserstein mode
                presumably must equal the ``build_all_model`` batch size.
            sample_interval: save a sample grid every this many epochs.
            save_sample2dir: directory for the sample grids.
            save_weights_path: checkpoint path of the best model.
            save_model: if True save the whole model, otherwise weights only.
        """
        dirpath = self._ensure_parent_dir(save_weights_path)
        runner_up_path = self._bis_path(save_weights_path)

        half_batch = batch_size  # the full batch is also used for D updates

        # Classification accuracy on the last 100 batches of domain B
        test_accs = []

        # Best and second-best mean test accuracy seen so far
        best_test_cls_acc = 0
        second_best_cls_acc = -1
        for epoch in range(epochs):

            # ---------------------
            #  Train Discriminator
            # ---------------------
            for _ in range(self.critic_steps):

                imgs_A, _ = self.data_loader.load_data(domain="A",
                                                       batch_size=half_batch)
                imgs_B, _ = self.data_loader.load_data(domain="B",
                                                       batch_size=half_batch)

                noise_prior = np.random.normal(
                    0, 1, (half_batch, self.noise_size[0]))

                # Translate images from domain A to domain B
                fake_B = self.generator.predict([imgs_A, noise_prior])
                if self.use_PatchGAN:
                    valid = np.ones((half_batch, ) + self.disc_patch)
                    fake = np.zeros((half_batch, ) + self.disc_patch)
                elif self.use_Wasserstein:
                    # Critic targets: +1 for real, -1 for translated images
                    valid = np.ones((half_batch, 1))
                    fake = -valid
                else:
                    valid = np.ones((half_batch, 1))
                    fake = np.zeros((half_batch, 1))

                if self.use_Wasserstein:
                    # Target for the gradient-penalty output (presumably
                    # ignored by gradient_penalty_loss). Built here so the
                    # PatchGAN + Wasserstein combination also has one.
                    dummy_y = np.zeros_like(valid)
                    d_loss = self.discriminator_model.train_on_batch(
                        [imgs_B, fake_B], [valid, fake, dummy_y])
                else:
                    # Real and translated images in a single batch
                    D_train_images = np.vstack([imgs_B, fake_B])
                    D_train_label = np.vstack([valid, fake])
                    d_loss = self.discriminator.train_on_batch(
                        D_train_images, D_train_label)

            # --------------------------------
            #  Train Generator and Classifier
            # --------------------------------

            # Sample a batch of images from both domains
            imgs_A, labels_A = self.data_loader.load_data(
                domain="A", batch_size=batch_size)
            imgs_B, labels_B = self.data_loader.load_data(
                domain="B", batch_size=batch_size)

            # One-hot encoding of labels
            labels_A = to_categorical(labels_A, num_classes=self.num_classes)

            # The generator wants D to rate the translated images as real
            if self.use_PatchGAN:
                valid = np.ones((batch_size, ) + self.disc_patch)
            else:
                valid = np.ones((batch_size, 1))

            noise_prior = np.random.normal(
                0, 1, (batch_size, self.noise_size[0]))

            # Train the generator and classifier
            g_loss = self.combined.train_on_batch([imgs_A, noise_prior],
                                                  [valid, labels_A])

            # -----------------------
            # Evaluation (domain B)
            # -----------------------
            pred_B = self.clf.predict(imgs_B)
            test_acc = np.mean(np.argmax(pred_B, axis=1) == labels_B)

            # Keep only the last 100 accuracy measurements
            test_accs.append(test_acc)
            if len(test_accs) > 100:
                test_accs.pop(0)

            if epoch % 10 == 0:
                with open(os.path.join(dirpath, "D_Losses.csv"),
                          "ab") as csv_file:
                    np.savetxt(csv_file,
                               np.array(d_loss).reshape(1, -1),
                               delimiter=",")
                with open(os.path.join(dirpath, "G_Losses.csv"),
                          "ab") as csv_file:
                    np.savetxt(csv_file,
                               np.array(g_loss).reshape(1, -1),
                               delimiter=",")

                if self.use_Wasserstein:
                    # Presumably my_critic_acc on the real and fake outputs
                    # (Keras metrics_names order) -- confirm.
                    d_train_acc = 100 * (float(d_loss[4]) +
                                         float(d_loss[5])) / 2
                else:
                    d_train_acc = 100 * float(d_loss[1])

                gen_loss = g_loss[1]

                clf_train_acc = 100 * float(g_loss[-1])
                clf_train_loss = g_loss[2]

                current_test_acc = 100 * float(test_acc)
                test_mean_acc = 100 * float(np.mean(test_accs))

                progress = (
                    "{} : [D - loss: {:.5f}, acc: {:.2f}%], [G - loss: {:.5f}], [clf - loss: {:.5f}, acc: {:.2f}%, test_acc: {:.2f}% ({:.2f}%)]"
                    .format(epoch, d_loss[0], d_train_acc, gen_loss,
                            clf_train_loss, clf_train_acc,
                            current_test_acc, test_mean_acc))

                if test_mean_acc > best_test_cls_acc:
                    second_best_cls_acc = best_test_cls_acc
                    best_test_cls_acc = test_mean_acc
                    self._save_checkpoint(save_weights_path, save_model)
                    print(progress + " (latest)")
                elif test_mean_acc > second_best_cls_acc:
                    second_best_cls_acc = test_mean_acc
                    # The runner-up goes to its own file in both save modes
                    # so it never overwrites the best checkpoint.
                    self._save_checkpoint(runner_up_path, save_model)
                    print(progress + " (before latest)")
                else:
                    print(progress)

            # If at save interval => save generated image samples
            if epoch % sample_interval == 0:
                self.sample_images(epoch, save2dir=save_sample2dir)

    def sample_images(self, epoch, save2dir="../samples"):
        """Save a 5 x 10 grid to ``<save2dir>/<epoch>.png``.

        The first row holds 10 domain-A images; each of the next four rows
        holds their translations under a fresh N(0, 1) noise draw.
        """
        if not os.path.exists(save2dir):
            os.makedirs(save2dir)

        r, c = 5, 10

        imgs_A, _ = self.data_loader.load_data(domain="A", batch_size=c)

        n_sample = imgs_A.shape[0]

        gen_imgs = imgs_A
        for i in range(r - 1):
            noise_prior = np.random.normal(
                0, 1, (n_sample, self.noise_size[0]))

            # Translate images to the other domain
            fake_B = self.generator.predict([imgs_A, noise_prior])
            gen_imgs = np.concatenate([gen_imgs, fake_B])

        # Rescale images from [-1, 1] to [0, 1]
        gen_imgs = 0.5 * gen_imgs + 0.5

        fig, axs = plt.subplots(r, c, figsize=(2 * c, 2 * r))

        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i, j].imshow(gen_imgs[cnt])
                axs[i, j].axis('off')
                cnt += 1
        fig.savefig(os.path.join(save2dir, "{}.png".format(epoch)))
        plt.close()

    def deploy_transform(self,
                         save2file="../domain_adapted/generated.npy",
                         stop_after=None):
        """Translate ``mnist_X`` images using one shared noise vector.

        Currently disabled: the ``assert 1 == 2`` guard fires before any
        prediction is made.

        Args:
            save2file: .npy output for the translated images; the noise vector
                is saved as ``noise_vectors.npy`` in the same directory.
            stop_after: translate only the first ``32 * int(stop_after / 32)``
                images; None translates all of ``mnist_X``.
        """
        dirname = self._ensure_parent_dir(save2file)

        if stop_after is not None:
            n_images = 32 * int(stop_after / 32)
        else:
            n_images = None  # slicing with [:None] keeps every image

        noise_vec = np.random.normal(0, 1, self.noise_size[0])
        assert 1 == 2  # deliberately disabled; see docstring

        print("Performing Pixel-level domain adaptation on original images...")
        source_imgs = self.data_loader.mnist_X[:n_images]
        adaptaed_images = self.generator.predict(
            [source_imgs, np.tile(noise_vec, (len(source_imgs), 1))],
            batch_size=32)
        print("+ Done.")
        print("Saving transformed images to file {}".format(save2file))
        np.save(save2file, adaptaed_images)

        noise_vec_filepath = os.path.join(dirname, "noise_vectors.npy")
        print(
            "Saving random noise (seed) to file {}".format(noise_vec_filepath))
        np.save(noise_vec_filepath, noise_vec)

        print("+ All done.")

    def deploy_debug(self,
                     save2file="../domain_adapted/debug.npy",
                     sample_size=100,
                     noise_number=128,
                     seed=17):
        """Translate each of ``sample_size`` domain-A images under many noises.

        Seeds the global NumPy RNG with ``seed``, then for each source image
        draws ``noise_number`` noise vectors (normal, std 3). It saves the
        stacked translations, with one leading entry per source image.
        """
        self._ensure_parent_dir(save2file)

        np.random.seed(seed=seed)

        print("Performing Pixel-level domain adaptation on original images...")
        collections = []
        imgs_A, _ = self.data_loader.load_data(domain="A",
                                               batch_size=sample_size)

        for i in range(sample_size):
            noise_vec = np.random.normal(0, 3,
                                         (noise_number, self.noise_size[0]))
            adaptaed_images = self.generator.predict(
                [np.tile(imgs_A[i], (noise_number, 1, 1, 1)), noise_vec],
                batch_size=32)
            collections.append(adaptaed_images)

        print("+ Done.")

        print("Saving transformed images to file {}".format(save2file))
        np.save(save2file, np.stack(collections))
        print("+ All done.")

    def deploy_classification(self, batch_size=32):
        """Report classifier accuracy on ``mnistm_X`` with a 99% interval.

        The interval is mean +/- 2.576 * std / sqrt(N) over the per-sample
        correctness indicators.
        """
        print("Predicting ... ")
        pred_B = self.clf.predict(self.data_loader.mnistm_X,
                                  batch_size=batch_size)
        print("+ Done.")
        N_samples = len(pred_B)
        precision = (np.argmax(pred_B, axis=1) == self.data_loader.mnistm_y)
        Moy = np.mean(precision)
        Std = np.std(precision)

        lower_bound = Moy - 2.576 * Std / np.sqrt(N_samples)
        upper_bound = Moy + 2.576 * Std / np.sqrt(N_samples)
        print("=" * 50)
        print("Unsupervised MNIST-M classification accuracy : {}".format(Moy))
        print("Confidence interval (99%) [{}, {}]".format(
            lower_bound, upper_bound))
        print("Length of confidence interval 99%: {}".format(upper_bound -
                                                             lower_bound))
        print("=" * 50)
        print("+ All done.")
コード例 #13
0
ファイル: run.py プロジェクト: East-early/git-first
    get_res = np.array(pd.read_csv('data/toPredict_noLabel.csv'))
    cur_road_index = 0
    cur_road = get_res[cur_road_index, 1]
    row_test_data = np.array(pd.read_csv('data/toPredict_train_TTI.csv'))
    row_train_data = np.array(pd.read_csv('data/train_TTI.csv'))

    for i in range(12):
        print("Predicting {0}th road, Num: {1} ".format(i+1, cur_road))
        test = row_test_data[row_test_data[:, 0] == cur_road]
        test = add_time(test)

        train = row_train_data[row_train_data[:, 0] == cur_road]
        train = add_time(train)

        data = DataLoader(train, test)
        x, y = data.get_train_data(
            seq_len=configs['data']['sequence_length'],
            normalise=configs['data']['normalise']
        )
        x_test = data.get_test_data(
            seq_len=configs['data']['sequence_length'],
            normalise=configs['data']['normalise']
        )
        cur_road_index += 21
        cur_road = get_res[cur_road_index, 1]
        totalPrediction.append(run_lstm(data, x_test))
    totalPrediction = np.array(totalPrediction)
    print(totalPrediction.shape)
    l, r = np.hsplit(totalPrediction, [21])
    res = l.reshape(1, -1).squeeze().tolist()
コード例 #14
0
class PixelDA(_DLalgo):
    """Pixel-level domain adaptation GAN (Keras implementation).

    A generator translates a domain-A image (conditioned on a noise vector)
    into domain B; a classifier is trained on the translated images with the
    domain-A labels and evaluated on real domain-B images.

    Paradigm of GAN (keras implementation)

    1. Construct D
        1a) Compile D
    2. Construct G
    3. Set D.trainable = False
    4. Stack G and D, to construct GAN (combined model)
        4a) Compile GAN

    Approved by fchollet: "the process you describe is in fact correct."

    See issue #4674 keras: https://github.com/keras-team/keras/issues/4674
    """

    def __init__(self, noise_size=(100,),
                 use_PatchGAN=False,
                 use_Wasserstein=True,
                 batch_size=64,
                 **kwargs):
        """Set the hyper-parameters of the model.

        :param noise_size: shape of the generator's noise input (1-tuple).
        :param use_PatchGAN: if True, D outputs a patch map instead of a scalar.
        :param use_Wasserstein: if True, D is trained as a WGAN-GP critic.
        :param batch_size: mini-batch size used by ``train``.
        :param kwargs: extra attributes set verbatim on the instance. They are
            applied last, so they override the defaults below (but not values
            already derived from those defaults, such as ``img_shape``).
        :raises ValueError: if ``dataset_name`` is not supported.
        """
        # Input shape
        self.dataset_name = "MNIST"  # or "CT"

        if self.dataset_name == "MNIST":
            self.img_shape = (32, 32, 3)
            self.num_classes = 10
        elif self.dataset_name == "CT":
            self.img_shape = (64, 64, 1)
        else:
            raise ValueError("Only support two datasets for now. ('CT', 'MNIST')")

        self.img_rows, self.img_cols, self.channels = self.img_shape

        self.noise_size = noise_size
        self.batch_size = batch_size
        # Loss weights of the combined model (adversarial, classification)
        self.lambda_adv = 7
        self.lambda_clf = 1
        # Number of filters in first layer of discriminator and classifier
        self.df = 64
        self.cf = 64

        self.normalize_G = False
        self.normalize_D = False
        self.normalize_C = True
        # If True, domain-A labels are shifted by one class (label permutation experiment)
        self.shift_label = False
        # Number of residual blocks in the generator
        self.residual_blocks = 17
        self.use_PatchGAN = use_PatchGAN
        self.use_Wasserstein = use_Wasserstein
        if self.use_PatchGAN:
            # D halves the resolution 4 times (4 stride-2 layers)
            patch = int(self.img_rows / 2**4)
            self.disc_patch = (patch, patch, 1)

        # Critic updates per generator update (WGAN trains the critic more often)
        if self.use_Wasserstein:
            self.critic_steps = 5
        else:
            self.critic_steps = 1

        self.GRADIENT_PENALTY_WEIGHT = 10  # As in the WGAN-GP paper

        # Set up the other attributes
        for key, value in kwargs.items():
            setattr(self, key, value)

    # ------------------------------------------------------------------
    # Small private helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _ensure_parent_dir(path):
        """Create the parent directory of ``path`` if needed; return it ('' for a bare filename)."""
        dirpath = os.path.dirname(path)
        if dirpath and not os.path.exists(dirpath):
            os.makedirs(dirpath)
        return dirpath

    @staticmethod
    def _second_best_path(path):
        """Return the runner-up checkpoint path, e.g. 'w.h5' -> 'w_bis.h5'."""
        root, ext = os.path.splitext(path)
        return root + "_bis" + (ext or ".h5")

    @staticmethod
    def _imshow(ax, img):
        """Draw one image on ``ax`` without axes; single-channel images are drawn in grayscale."""
        if img.ndim == 3 and img.shape[-1] == 1:
            ax.imshow(img[..., 0], cmap='gray')
        else:
            ax.imshow(img)
        ax.axis('off')

    def _save_combined(self, path, full_model):
        """Save the combined model to ``path``: the full model if ``full_model`` else weights only."""
        if full_model:
            self.combined.save(path)
        else:
            self.combined.save_weights(path)

    def _translate_with_uniform_noise(self, imgs, noise_range, stop_after, seed):
        """Translate the first ``stop_after`` images (all of them if None) to domain B.

        Noise vectors are drawn uniformly from [-noise_range, noise_range)
        after seeding numpy's global RNG with ``seed``.
        """
        num_samples = len(imgs)
        if stop_after is not None:
            # Never ask for more noise vectors than there are images.
            num_samples = min(int(stop_after), num_samples)
        np.random.seed(seed)
        noise_vec = (2 * np.random.random((num_samples, self.noise_size[0])) - 1) * noise_range
        return self.generator.predict([imgs[:num_samples], noise_vec], batch_size=32)

    # ------------------------------------------------------------------
    # Model construction
    # ------------------------------------------------------------------
    def freeze_layers_kernel(self, model):
        """Remove the first trainable weight of every Dense/Conv2D layer of ``model``.

        Intended to exclude each layer's kernel from training while leaving the
        remaining weights trainable. Must be called before the model is compiled.
        """
        valid_name = ["dense", "conv2d"]
        for layer in model.layers:
            # Layers are matched by their name prefix, e.g. "conv2d_3" -> "conv2d".
            if layer.name.split('_')[0] in valid_name:
                ## Remove 'kernel' from trainable weights before compile model !
                layer.trainable_weights = layer.trainable_weights[1:]

    def build_all_model(self):
        """Build and compile the discriminator, generator, classifier and combined model.

        Sets ``self.discriminator``, ``self.generator``, ``self.clf``,
        ``self.combined`` and, in Wasserstein mode, ``self.discriminator_model``
        (the critic training graph with gradient penalty).
        """
        optimizer = Adam(0.0001, beta_1=0.5, beta_2=0.9)

        # Build and compile the discriminator
        self.discriminator = self.build_discriminator()
        self.discriminator.name = "Discriminator"

        img_A = Input(shape=self.img_shape, name='source_image')  # real A
        img_B = Input(shape=self.img_shape, name='target_image')  # real B
        fake_img = Input(shape=self.img_shape, name="fake_image")  # fake B

        # Weighted averages of real and generated samples, used for the gradient norm penalty.
        avg_img = RandomWeightedAverage()([img_B, fake_img])

        real_img_rating = self.discriminator(img_B)
        fake_img_rating = self.discriminator(fake_img)
        avg_img_output = self.discriminator(avg_img)

        # The gradient penalty loss needs the averaged samples, but Keras loss
        # functions only receive (y_true, y_pred): bind them with partial().
        partial_gp_loss = partial(gradient_penalty_loss,
                                  averaged_samples=avg_img,
                                  gradient_penalty_weight=self.GRADIENT_PENALTY_WEIGHT)
        partial_gp_loss.__name__ = 'gradient_penalty'  # Functions need names or Keras will throw an error

        if self.use_Wasserstein:
            # Optional experiment: self.freeze_layers_kernel(self.discriminator)
            # No loss_weights: the penalty is already scaled by GRADIENT_PENALTY_WEIGHT.
            self.discriminator_model = Model(inputs=[img_B, fake_img],
                                             outputs=[real_img_rating, fake_img_rating, avg_img_output])
            self.discriminator_model.compile(loss=[wasserstein_loss, wasserstein_loss, partial_gp_loss],
                                             optimizer=optimizer,
                                             metrics=[my_critic_acc])
        else:
            self.discriminator.compile(loss='mse',
                                       optimizer=optimizer,
                                       metrics=['accuracy'])

        # For the combined model we will only train the generator and classifier
        self.discriminator.trainable = False

        # Build the generator
        self.generator = self.build_generator()
        self.generator.name = "Generator"
        # Build the task (classification) network
        self.clf = self.build_classifier()
        self.clf.name = "Classifier"

        # Input noise
        noise = Input(shape=self.noise_size, name='noise_input')

        # Translate images from domain A to domain B
        fake_B = self.generator([img_A, noise])

        # Classify the translated image
        class_pred = self.clf(fake_B)

        # Discriminator determines validity of translated images
        valid = self.discriminator(fake_B)

        adversarial_loss = wasserstein_loss if self.use_Wasserstein else 'mse'
        self.combined = Model(inputs=[img_A, noise], outputs=[valid, class_pred])
        self.combined.compile(optimizer=optimizer,
                              loss=[adversarial_loss, 'categorical_crossentropy'],
                              loss_weights=[self.lambda_adv, self.lambda_clf],
                              metrics=['accuracy'])

    def load_dataset(self):
        """Configure the MNIST (domain A) / MNIST-M (domain B) data loader."""
        self.data_loader = DataLoader(img_res=(self.img_rows, self.img_cols))

    def build_generator(self):
        """ResNet generator mapping (image, noise) to a translated image (tanh output)."""

        def residual_block(layer_input, normalization=self.normalize_G):
            """Residual block described in paper"""
            d = Conv2D(64, kernel_size=3, strides=1, padding='same')(layer_input)
            if normalization:
                d = InstanceNormalization()(d)
            d = Activation('relu')(d)
            d = Conv2D(64, kernel_size=3, strides=1, padding='same')(d)
            if normalization:
                d = InstanceNormalization()(d)
            d = Add()([d, layer_input])
            return d

        # Image input
        img = Input(shape=self.img_shape, name='image_input')

        # The noise is projected to one image-sized channel and appended to the image.
        # Sized from img_rows*img_cols so the reshape works for any image size
        # (1024 for the 32x32 MNIST setting).
        noise = Input(shape=self.noise_size, name='noise_input')
        noise_layer = Dense(self.img_rows * self.img_cols, activation="relu")(noise)
        noise_layer = Reshape((self.img_rows, self.img_cols, 1))(noise_layer)
        conditioned_img = keras.layers.concatenate([img, noise_layer])

        l1 = Conv2D(64, kernel_size=3, padding='same', activation='relu')(conditioned_img)

        # Propagate signal through residual blocks
        r = residual_block(l1)
        for _ in range(self.residual_blocks - 1):
            r = residual_block(r)

        output_img = Conv2D(self.channels, kernel_size=3, padding='same', activation='tanh')(r)

        return Model([img, noise], output_img)

    def build_discriminator(self):
        """Four stride-2 conv layers followed by a patch map or a single score."""

        def d_layer(layer_input, filters, f_size=4, normalization=self.normalize_D):
            """Discriminator layer"""
            d = Conv2D(filters, kernel_size=f_size, strides=2, padding='same')(layer_input)
            d = LeakyReLU(alpha=0.2)(d)
            if normalization:
                d = InstanceNormalization()(d)
            return d

        img = Input(shape=self.img_shape, name="image")

        d1 = d_layer(img, self.df, normalization=False)
        d2 = d_layer(d1, self.df * 2, normalization=self.normalize_D)
        d3 = d_layer(d2, self.df * 4, normalization=self.normalize_D)
        d4 = d_layer(d3, self.df * 8, normalization=self.normalize_D)

        if self.use_PatchGAN:
            validity = Conv2D(1, kernel_size=4, strides=1, padding='same')(d4)
        elif self.use_Wasserstein:
            # Unbounded critic score (no activation)
            validity = Dense(1, activation=None)(Flatten()(d4))
        else:
            validity = Dense(1, activation='sigmoid')(Flatten()(d4))

        return Model(img, validity)

    def build_classifier(self):
        """Five stride-2 conv layers followed by a softmax over ``num_classes``."""

        def clf_layer(layer_input, filters, f_size=4, normalization=self.normalize_C):
            """Classifier layer"""
            d = Conv2D(filters, kernel_size=f_size, strides=2, padding='same')(layer_input)
            d = LeakyReLU(alpha=0.2)(d)
            if normalization:
                d = InstanceNormalization()(d)
            return d

        img = Input(shape=self.img_shape, name='image_input')

        c1 = clf_layer(img, self.cf, normalization=False)
        c2 = clf_layer(c1, self.cf * 2)
        c3 = clf_layer(c2, self.cf * 4)
        c4 = clf_layer(c3, self.cf * 8)
        c5 = clf_layer(c4, self.cf * 8)

        class_pred = Dense(self.num_classes, activation='softmax')(Flatten()(c5))

        return Model(img, class_pred)

    def load_pretrained_weights(self, weights_path="../Weights/all_weights.h5", only_cls=False):
        """Load combined-model weights from ``weights_path``.

        :param only_cls: if True, keep only the classifier's weights: the whole
            graph is loaded, the classifier weights are copied out, the Keras
            session is cleared, all models are rebuilt and only the classifier
            weights are restored.
        """
        print("Loading pretrained weights from path: {} ...".format(weights_path))
        if only_cls:
            # self.clf.load_weights(weights_path, by_name=True) does not work,
            # see https://github.com/keras-team/keras/issues/5348
            self.combined.load_weights(weights_path, by_name=True)
            clf_weights = self.clf.get_weights()
            K.clear_session()
            self.build_all_model()
            self.clf.set_weights(clf_weights)
        else:
            self.combined.load_weights(weights_path, by_name=True)
        print("+ Done.")

    def summary(self):
        """Print the Keras summary of every sub-model."""
        print("=" * 50)
        print("Discriminator summary:")
        self.discriminator.summary()
        print("=" * 50)
        print("Generator summary:")
        self.generator.summary()
        print("=" * 50)
        print("Classifier summary:")
        self.clf.summary()

        if self.use_Wasserstein:
            print("=" * 50)
            print("Discriminator model summary:")
            self.discriminator_model.summary()
        print("=" * 50)
        print("Combined model summary:")
        self.combined.summary()

    def write_tensorboard_graph(self, to_dir="../logs", save_png2dir="../Model_graph"):
        """Write the critic graph for TensorBoard and try to save PNG diagrams of the models.

        PNG export is best effort: any failure is reported and otherwise ignored.
        """
        if not os.path.exists(save_png2dir):
            os.makedirs(save_png2dir)
        # The dedicated critic training graph only exists in Wasserstein mode.
        critic_model = self.discriminator_model if self.use_Wasserstein else self.discriminator
        tensorboard = keras.callbacks.TensorBoard(log_dir=to_dir, histogram_freq=0, write_graph=True, write_images=False)
        tensorboard.set_model(critic_model)
        try:
            plot_model(self.combined, to_file=os.path.join(save_png2dir, "Combined_model.png"))
            plot_model(critic_model, to_file=os.path.join(save_png2dir, "Discriminator_model.png"))
        except Exception as err:
            print("Could not plot model graphs: {}".format(err))

    # ------------------------------------------------------------------
    # Training
    # ------------------------------------------------------------------
    def train(self, epochs, sample_interval=50, save_sample2dir="../samples/exp0", save_weights_path='../Weights/all_weights.h5', save_model=False):
        """Alternate critic and generator/classifier updates for ``epochs`` iterations.

        Every 10 iterations the losses are appended to ``D_Losses.csv`` and
        ``G_Losses.csv`` next to ``save_weights_path``. The combined model is
        saved when the mean domain-B classification accuracy over the last 100
        iterations reaches a new best (``save_weights_path``) or second best
        (``<name>_bis<ext>``).

        :param epochs: number of training iterations (one batch each).
        :param sample_interval: save a grid of translated samples every N iterations.
        :param save_sample2dir: directory of the sample grids.
        :param save_weights_path: checkpoint path of the combined model.
        :param save_model: if True save the full model, otherwise only its weights.
        """
        dirpath = self._ensure_parent_dir(save_weights_path)
        self.save_config(save2path=os.path.join(dirpath, "config.dill"), verbose=True)
        second_best_path = self._second_best_path(save_weights_path)

        # Targets only depend on the batch size and D's output shape: build them once.
        if self.use_PatchGAN:
            label_shape = (self.batch_size,) + self.disc_patch
        else:
            label_shape = (self.batch_size, 1)
        valid = np.ones(label_shape)
        if self.use_Wasserstein:
            fake = -valid  # critic targets: +1 for real, -1 for fake
        else:
            fake = np.zeros(label_shape)
        # Placeholder target for the gradient-penalty output (Keras needs one per output)
        dummy_y = np.zeros(label_shape)

        # Classification accuracy on the last 100 batches of domain B
        test_accs = []

        # Monitor to save model weights
        best_test_cls_acc = 0
        second_best_cls_acc = -1
        for epoch in range(epochs):

            # ---------------------
            #  Train Discriminator
            # ---------------------
            for _ in range(self.critic_steps):
                imgs_A, _labels = self.data_loader.load_data(domain="A", batch_size=self.batch_size)
                imgs_B, _labels = self.data_loader.load_data(domain="B", batch_size=self.batch_size)

                noise_prior = np.random.normal(0, 1, (self.batch_size, self.noise_size[0]))

                # Translate images from domain A to domain B
                fake_B = self.generator.predict([imgs_A, noise_prior])

                # Train the discriminator (original images = real / translated = fake)
                if self.use_Wasserstein:
                    d_loss = self.discriminator_model.train_on_batch([imgs_B, fake_B], [valid, fake, dummy_y])
                else:
                    D_train_images = np.vstack([imgs_B, fake_B])
                    D_train_label = np.vstack([valid, fake])
                    d_loss = self.discriminator.train_on_batch(D_train_images, D_train_label)

            # --------------------------------
            #  Train Generator and Classifier
            # --------------------------------

            # Sample a batch of images from both domains
            imgs_A, labels_A = self.data_loader.load_data(domain="A", batch_size=self.batch_size)
            imgs_B, labels_B = self.data_loader.load_data(domain="B", batch_size=self.batch_size)
            if self.shift_label:
                labels_A = np.mod(labels_A + 1, self.num_classes)

            # One-hot encoding of labels
            labels_A = to_categorical(labels_A, num_classes=self.num_classes)

            noise_prior = np.random.normal(0, 1, (self.batch_size, self.noise_size[0]))

            # The generator wants the discriminator to label the translated images as real
            g_loss = self.combined.train_on_batch([imgs_A, noise_prior], [valid, labels_A])

            # -----------------------
            # Evaluation (domain B)
            # -----------------------
            pred_B = self.clf.predict(imgs_B)
            test_acc = np.mean(np.argmax(pred_B, axis=1) == labels_B)

            test_accs.append(test_acc)
            if len(test_accs) > 100:
                test_accs.pop(0)

            if epoch % 10 == 0:
                if self.use_Wasserstein:
                    # Indices 3, 4, 5 are read as (GP loss, real acc, fake acc);
                    # assumes Keras' [total, per-output losses, per-output metrics] order.
                    d_real_acc = 100 * float(d_loss[4])
                    d_fake_acc = 100 * float(d_loss[5])
                    d_train_acc = 100 * (float(d_loss[4]) + float(d_loss[5])) / 2
                    d_message = ("[D - loss: {:.5f}, GP_loss: {:.5f}, (+) acc: {:.2f}%, "
                                 "(-) acc: {:.2f}%, acc: {:.2f}%]").format(
                        d_loss[0], d_loss[3], d_real_acc, d_fake_acc, d_train_acc)
                else:
                    # Plain discriminator: d_loss = [mse, accuracy]
                    d_train_acc = 100 * float(d_loss[1])
                    d_message = "[D - loss: {:.5f}, acc: {:.2f}%]".format(d_loss[0], d_train_acc)

                gen_loss = g_loss[1]
                clf_train_loss = g_loss[2]
                clf_train_acc = 100 * float(g_loss[-1])

                current_test_acc = 100 * float(test_acc)
                test_mean_acc = 100 * float(np.mean(test_accs))

                g_row = list(g_loss) + [current_test_acc, test_mean_acc]
                with open(os.path.join(dirpath, "D_Losses.csv"), "ab") as csv_file:
                    np.savetxt(csv_file, np.array(d_loss).reshape(1, -1), delimiter=",")
                with open(os.path.join(dirpath, "G_Losses.csv"), "ab") as csv_file:
                    np.savetxt(csv_file, np.array(g_row).reshape(1, -1), delimiter=",")

                message = ("{} : {}, [G - loss: {:.5f}], [clf - loss: {:.5f}, acc: {:.2f}%, "
                           "test_acc: {:.2f}% ({:.2f}%)]").format(
                    epoch, d_message, gen_loss, clf_train_loss, clf_train_acc,
                    current_test_acc, test_mean_acc)

                if test_mean_acc > best_test_cls_acc:
                    second_best_cls_acc = best_test_cls_acc
                    best_test_cls_acc = test_mean_acc
                    self._save_combined(save_weights_path, save_model)
                    message += "  (best)"
                elif test_mean_acc > second_best_cls_acc:
                    second_best_cls_acc = test_mean_acc
                    # Never overwrite the best checkpoint with the runner-up.
                    self._save_combined(second_best_path, save_model)
                    message += "  (second best)"
                print(message)

            # If at save interval => save generated image samples
            if epoch % sample_interval == 0:
                self.sample_images(epoch, save2dir=save_sample2dir)

    def sample_images(self, epoch, save2dir="../samples"):
        """Save a 5x10 grid: row 0 = domain-A inputs, rows 1-4 = translations with fresh noise."""
        if not os.path.exists(save2dir):
            os.makedirs(save2dir)

        r, c = 5, 10

        imgs_A, _ = self.data_loader.load_data(domain="A", batch_size=c)
        n_sample = imgs_A.shape[0]

        gen_imgs = imgs_A
        for _ in range(r - 1):
            noise_prior = np.random.normal(0, 1, (n_sample, self.noise_size[0]))
            # Translate images to the other domain
            fake_B = self.generator.predict([imgs_A, noise_prior])
            gen_imgs = np.concatenate([gen_imgs, fake_B])

        # Map the generator's tanh range [-1, 1] to [0, 1] for display
        # (assumes the loader's images use the same range — TODO confirm)
        gen_imgs = 0.5 * gen_imgs + 0.5

        fig, axs = plt.subplots(r, c, figsize=(2 * c, 2 * r))
        for cnt, ax in enumerate(axs.flat):
            self._imshow(ax, gen_imgs[cnt])
        fig.savefig(os.path.join(save2dir, "{}.png".format(epoch)))
        plt.close(fig)

    def train_segmenter(self, iterations, batch_size=32, noise_range=5, save_weights_path=None):
        """Train ``self.seg`` on translated domain-A images while the generator is frozen.

        ``self.seg`` is not built by this class; it must be provided (for
        example through the constructor's ``**kwargs``). Every 100 iterations
        the Dice score on a domain-B batch is computed and the weights are saved
        when its running mean (last 100 evaluations) is best / second best.

        :param iterations: number of training batches.
        :param batch_size: batch size.
        :param noise_range: noise is drawn uniformly from [-noise_range, noise_range).
        :param save_weights_path: checkpoint path, or None to disable saving.
        """
        second_best_path = None
        if save_weights_path is not None:
            self._ensure_parent_dir(save_weights_path)
            second_best_path = self._second_best_path(save_weights_path)
        optimizer = Adam(0.000001, beta_1=0.0, beta_2=0.9)

        noise_input = Input(shape=self.noise_size, name='noise_input_seg')
        img_A = Input(shape=self.img_shape, name='source_image_seg')
        # Translate images from domain A to domain B
        fake_B = self.generator([img_A, noise_input])

        # Segment the translated image
        mask_pred = self.seg(fake_B)

        # Freeze the generator before compiling so only the segmenter is updated.
        self.generator.trainable = False

        self.segmentation_model = Model(inputs=[img_A, noise_input], outputs=[mask_pred])
        self.segmentation_model.compile(loss=dice_coef_loss, optimizer=optimizer, metrics=["acc", dice_coef])

        self.segmentation_model.name = "U-net (freeze Generator)"
        self.segmentation_model.summary()

        best_test_dice = 0.0
        second_best_test_dice = -1.0
        recent_dice = []
        for e in range(iterations):
            noise_batch = (2 * np.random.random((batch_size, self.noise_size[0])) - 1) * noise_range

            images_A, _, mask_A = self.data_loader.load_data(domain="A", batch_size=batch_size, return_mask=True)

            s_loss = self.segmentation_model.train_on_batch([images_A, noise_batch], mask_A)

            if e % 100 == 0:
                images_B, _, mask_B = self.data_loader.load_data(domain="B", batch_size=batch_size, return_mask=True)

                pred_mask_B = self.seg.predict(images_B)
                _, current_test_dice = dice_predict(mask_B, pred_mask_B)

                if len(recent_dice) >= 100:
                    recent_dice.pop(0)
                recent_dice.append(current_test_dice)
                mean_dice = np.mean(recent_dice)
                message = "{} dice loss: {:.3f}; acc: {:.5f}; mean dice (test): {:.3f}".format(e, 100 * s_loss[0], s_loss[1], 100 * mean_dice)

                if mean_dice > best_test_dice:
                    best_test_dice = mean_dice
                    message += "  (best)"
                    if save_weights_path is not None:
                        self.segmentation_model.save_weights(save_weights_path)
                elif mean_dice > second_best_test_dice:
                    second_best_test_dice = mean_dice
                    message += "  (second best)"
                    if second_best_path is not None:
                        self.segmentation_model.save_weights(second_best_path)
                print(message)

        return

    # ------------------------------------------------------------------
    # Deployment
    # ------------------------------------------------------------------
    def deploy_transform(self, save2file="../domain_adapted/generated.npy", noise_range=5.0, separate_cls=True, stop_after=None, seed=None):
        """Translate the MNIST images to domain B and save them as .npy file(s).

        :param save2file: output file; with ``separate_cls`` one file per class
            is written as ``<stem>_<cls>.npy``.
        :param noise_range: noise is drawn uniformly from [-noise_range, noise_range).
        :param separate_cls: write one file per digit class.
        :param stop_after: maximum number of images per file (None = all).
        :param seed: seed of numpy's global RNG (re-applied for each file).
        """
        self._ensure_parent_dir(save2file)

        print("Performing Pixel-level domain adaptation on original images...")
        if separate_cls:
            stem = os.path.splitext(save2file)[0]
            for mnist_cls in tqdm(range(10)):
                imgs_A = self.data_loader.mnist_X[self.data_loader.mnist_y == mnist_cls]
                adapted_images = self._translate_with_uniform_noise(imgs_A, noise_range, stop_after, seed)
                filename_by_cls = stem + "_{}.npy".format(mnist_cls)
                np.save(filename_by_cls, adapted_images)
        else:
            adapted_images = self._translate_with_uniform_noise(self.data_loader.mnist_X, noise_range, stop_after, seed)
            print("Saving transformed images to file {}".format(save2file))
            np.save(save2file, adapted_images)

        print("+ All done.")

    def deploy_debug(self, save2file="../domain_adapted/debug.npy", sample_size=100, noise_number=128,
                     use_sobol=False,
                     use_linear=False,
                     use_sphere=False,
                     use_uniform_linear=False,
                     use_zeros=False,
                     seed=17,
                     source_path="/home/lulin/na4/src/output/output13_32x32/test/liver_masks.npy",
                     start_index=100):
        """Translate each of ``sample_size`` source images with ``noise_number`` noise vectors.

        The result, of shape (sample_size, noise_number, ...), is saved to
        ``save2file``. The first enabled ``use_*`` flag selects how the noise
        is generated (default: Gaussian with std 3).

        :param source_path: .npy file of source images (default keeps the
            original hard-coded mask file).
        :param start_index: index of the first source image used.
        """
        self._ensure_parent_dir(save2file)

        np.random.seed(seed=seed)

        print("Performing Pixel-level domain adaptation on original images...")
        collected = []
        imgs_A = np.load(source_path)[start_index:start_index + sample_size]
        # Replicate along the channel axis up to the generator's channel count
        # (assumes the stored images have a single channel — TODO confirm).
        imgs_A = np.repeat(imgs_A, self.channels, axis=-1)

        for i in tqdm(range(sample_size)):
            if use_sobol:
                noise_vec = 5 * (2 * i4_sobol_generate(self.noise_size[0], noise_number, i * noise_number).T - 1)
            elif use_linear:
                # Constant vectors whose value is random in [-3, 3)
                tangents = 3.0 * (2 * np.random.random((noise_number, 1)) - 1)
                noise_vec = np.ones((noise_number, self.noise_size[0])) * tangents
            elif use_sphere:
                # Symmetric uniform samples in [-1, 1)^d projected onto the unit sphere
                noise_vec = 2 * np.random.random((noise_number, self.noise_size[0])) - 1
                norm_vec = np.linalg.norm(noise_vec, axis=-1)
                noise_vec = noise_vec / norm_vec[:, np.newaxis]
            elif use_uniform_linear:
                # Constant vectors whose value is spread linearly over [-10, 10]
                tangents = 10.0 * np.linspace(-1, 1, noise_number)[:, np.newaxis]
                noise_vec = np.ones((noise_number, self.noise_size[0])) * tangents
            elif use_zeros:
                noise_vec = np.zeros((noise_number, self.noise_size[0]))
            else:
                noise_vec = np.random.normal(0, 3, (noise_number, self.noise_size[0]))
            adapted_images = self.generator.predict([np.tile(imgs_A[i], (noise_number, 1, 1, 1)), noise_vec], batch_size=32)
            collected.append(adapted_images)

        print("+ Done.")

        print("Saving transformed images to file {}".format(save2file))
        np.save(save2file, np.stack(collected))
        print("+ All done.")

    def deploy_classification(self, batch_size=32):
        """Print the classifier's MNIST-M accuracy with a 99% normal-approximation confidence interval."""
        print("Predicting ... ")
        pred_B = self.clf.predict(self.data_loader.mnistm_X, batch_size=batch_size)
        print("+ Done.")
        N_samples = len(pred_B)
        precision = (np.argmax(pred_B, axis=1) == self.data_loader.mnistm_y)
        Moy = np.mean(precision)
        Std = np.std(precision)

        # 2.576 is the two-sided 99% quantile of the standard normal distribution.
        half_width = 2.576 * Std / np.sqrt(N_samples)
        lower_bound = Moy - half_width
        upper_bound = Moy + half_width
        print("=" * 50)
        print("Unsupervised MNIST-M classification accuracy : {}".format(Moy))
        print("Confidence interval (99%) [{}, {}]".format(lower_bound, upper_bound))
        print("Length of confidence interval 99%: {}".format(upper_bound - lower_bound))
        print("=" * 50)
        print("+ All done.")

    def deploy_demo_only(self, save2file="../domain_adapted/WGAN_GP/Exp4/demo.npy", sample_size=25, noise_number=512, linspace_size=10.0):
        """Translate the same domain-A batch with ``noise_number`` constant noise vectors.

        Noise values are spread linearly over [-linspace_size, linspace_size].
        The stacked result has shape (noise_number, sample_size, ...).
        """
        self._ensure_parent_dir(save2file)
        collected = []
        imgs_A, _ = self.data_loader.load_data(domain="A", batch_size=sample_size)

        tangents = linspace_size * np.linspace(-1, 1, noise_number)[:, np.newaxis]
        noise_vec = np.ones((noise_number, self.noise_size[0])) * tangents

        for i in tqdm(range(noise_number)):
            adapted_images = self.generator.predict([imgs_A, np.tile(noise_vec[i], (sample_size, 1))], batch_size=sample_size)
            collected.append(adapted_images)
        print("+ Done.")

        print("Saving transformed images to file {}".format(save2file))
        np.save(save2file, np.stack(collected))
        print("+ All done.")

    def deploy_cherry_pick(self, save2file="../domain_adapted/WGAN_GP/Exp4/demo_cherry_picked.png", sample_size=25, noise_number=25, linspace_size=5.0):
        """Translate ``sample_size`` domain-A images, one noise vector each, and save a 5-column grid.

        Noise vectors are constant vectors spread linearly over
        [-linspace_size, linspace_size], then shuffled across images.

        :raises ValueError: if ``noise_number != sample_size``.
        """
        if noise_number != sample_size:
            raise ValueError("noise_number ({}) must equal sample_size ({})".format(noise_number, sample_size))
        self._ensure_parent_dir(save2file)

        imgs_A, _ = self.data_loader.load_data(domain="A", batch_size=sample_size)

        tangents = linspace_size * np.linspace(-1, 1, noise_number)[:, np.newaxis]
        noise_vec = np.ones((noise_number, self.noise_size[0])) * tangents

        np.random.shuffle(noise_vec)  # shuffle background color !

        adapted_images = self.generator.predict([imgs_A, noise_vec], batch_size=sample_size)
        adapted_images = (adapted_images + 1) / 2  # tanh range [-1, 1] -> [0, 1]
        print("+ Done.")

        print("Saving transformed images to file {}".format(save2file))
        c = 5
        r = int(np.ceil(sample_size / float(c)))
        # squeeze=False keeps axs 2-D even for a single row.
        fig, axs = plt.subplots(r, c, figsize=(5 * c, 5 * r), squeeze=False)
        for k, ax in enumerate(axs.flat):
            if k < len(adapted_images):
                self._imshow(ax, adapted_images[k])
            else:
                ax.axis('off')
        fig.savefig(save2file)
        plt.close(fig)
        print("+ All done.")
コード例 #15
0
def main(args):
    """
    Main function for the script: build the text encoder, the condition
    augmenter and the conditional ProGAN, optionally restore each of them
    from a checkpoint, then train them jointly.
    :param args: parsed command line arguments (config path, optional
                 checkpoint files and start depth)
    :return: None
    """

    from networks.TextEncoder import Encoder
    from networks.ConditionAugmentation import ConditionAugmentor
    from networks.C_PRO_GAN import ProGAN

    def load_state(path):
        """Load a saved state dict with its tensors remapped onto ``device``."""
        # Without map_location, a checkpoint saved from a GPU cannot be
        # loaded on a CPU-only machine.
        return th.load(path, map_location=device)

    print(args.config)
    config = get_config(args.config)
    print("Current Configuration:", config)

    # create the dataset for training
    dataset = dl.Face2TextDataset(pro_pick_file=config.processed_text_file,
                                  img_dir=config.images_dir,
                                  img_transform=dl.get_transform(
                                      config.img_dims),
                                  captions_len=config.captions_length)

    # create the networks
    text_encoder = Encoder(embedding_size=config.embedding_size,
                           vocab_size=dataset.vocab_size,
                           hidden_size=config.hidden_size,
                           num_layers=config.num_layers,
                           device=device)

    if args.encoder_file is not None:
        print("Loading encoder from:", args.encoder_file)
        text_encoder.load_state_dict(load_state(args.encoder_file))

    condition_augmenter = ConditionAugmentor(input_size=config.hidden_size,
                                             latent_size=config.ca_out_size,
                                             device=device)

    if args.ca_file is not None:
        print("Loading conditioning augmenter from:", args.ca_file)
        condition_augmenter.load_state_dict(load_state(args.ca_file))

    c_pro_gan = ProGAN(embedding_size=config.hidden_size,
                       depth=config.depth,
                       latent_size=config.latent_size,
                       learning_rate=config.learning_rate,
                       beta_1=config.beta_1,
                       beta_2=config.beta_2,
                       eps=config.eps,
                       drift=config.drift,
                       n_critic=config.n_critic,
                       device=device)

    if args.generator_file is not None:
        print("Loading generator from:", args.generator_file)
        c_pro_gan.gen.load_state_dict(load_state(args.generator_file))

    if args.discriminator_file is not None:
        print("Loading discriminator from:", args.discriminator_file)
        c_pro_gan.dis.load_state_dict(load_state(args.discriminator_file))

    # The Encoder and the Condition Augmenter get separate optimizers, built
    # with the same hyper-parameters as the GAN.
    encoder_optim = th.optim.Adam(text_encoder.parameters(),
                                  lr=config.learning_rate,
                                  betas=(config.beta_1, config.beta_2),
                                  eps=config.eps)

    ca_optim = th.optim.Adam(condition_augmenter.parameters(),
                             lr=config.learning_rate,
                             betas=(config.beta_1, config.beta_2),
                             eps=config.eps)

    # train all the networks
    train_networks(
        encoder=text_encoder,
        ca=condition_augmenter,
        c_pro_gan=c_pro_gan,
        dataset=dataset,
        encoder_optim=encoder_optim,
        ca_optim=ca_optim,
        epochs=config.epochs,
        fade_in_percentage=config.fade_in_percentage,
        start_depth=args.start_depth,
        batch_sizes=config.batch_sizes,
        num_workers=config.num_workers,
        feedback_factor=config.feedback_factor,
        log_dir=config.log_dir,
        sample_dir=config.sample_dir,
        checkpoint_factor=config.checkpoint_factor,
        save_dir=config.save_dir,
        use_matching_aware_dis=config.use_matching_aware_discriminator)
コード例 #16
0
def homepage_result():
    """
    Generate face images for the caption posted in the ``des`` form field.

    The caption is split on whitespace and each word is mapped to its id with
    ``dataset.rev_vocab``. The id list is zero-padded to 100 entries; longer
    captions are not truncated. Sixteen copies of that sequence are encoded by
    the text encoder and passed through the conditioning augmenter. The result
    is concatenated with Gaussian noise and fed to the generator at
    ``current_depth``. The sample grid is saved under ``static`` as
    ``<caption>.png`` and shown through ``main.html``.

    Config, vocabulary/dataset and all three checkpoints (run 11) are
    reloaded from disk on every call.

    :return: the rendered ``main.html`` page
    :raises KeyError: if a caption word is missing from the vocabulary
    """
    caption = request.form["des"]
    current_depth = 4

    from torch.nn.functional import interpolate
    from torchvision.utils import save_image

    from networks.TextEncoder import Encoder
    from networks.ConditionAugmentation import ConditionAugmentor
    from networks.C_PRO_GAN import ProGAN

    # run on the GPU when one is available, otherwise on the CPU
    device = th.device("cuda" if th.cuda.is_available() else "cpu")

    def get_config(conf_file):
        """
        Parse a YAML configuration file.
        :param conf_file: path to the configuration file
        :return: the parsed configuration as an EasyDict
        """
        from easydict import EasyDict as edict

        with open(conf_file, "r") as file_descriptor:
            # safe_load: plain YAML only. yaml.load() without an explicit
            # Loader is unsafe and raises TypeError on PyYAML >= 6.
            data = yaml.safe_load(file_descriptor)

        return edict(data)

    def load_weights(path):
        # map_location lets checkpoints saved on a GPU load on a CPU-only host
        return th.load(path, map_location=device)

    config = get_config("configs\\11.conf")

    ############################################################################
    # generator
    c_pro_gan = ProGAN(embedding_size=config.hidden_size,
                       depth=config.depth,
                       latent_size=config.latent_size,
                       learning_rate=config.learning_rate,
                       beta_1=config.beta_1,
                       beta_2=config.beta_2,
                       eps=config.eps,
                       drift=config.drift,
                       n_critic=config.n_critic,
                       device=device)
    c_pro_gan.gen.load_state_dict(
        load_weights("training_runs\\11\\saved_models\\GAN_GEN_3_20.pth"))

    ############################################################################
    # text encoder and conditioning augmenter (the dataset provides the vocab)
    dataset = dl.Face2TextDataset(pro_pick_file=config.processed_text_file,
                                  img_dir=config.images_dir,
                                  img_transform=dl.get_transform(
                                      config.img_dims),
                                  captions_len=config.captions_length)

    text_encoder = Encoder(embedding_size=config.embedding_size,
                           vocab_size=dataset.vocab_size,
                           hidden_size=config.hidden_size,
                           num_layers=config.num_layers,
                           device=device)
    text_encoder.load_state_dict(
        load_weights("training_runs\\11\\saved_models\\Encoder_3_20.pth"))

    condition_augmenter = ConditionAugmentor(input_size=config.hidden_size,
                                             latent_size=config.ca_out_size,
                                             device=device)
    condition_augmenter.load_state_dict(
        load_weights(
            "training_runs\\11\\saved_models\\Condition_Augmentor_3_20.pth"))

    ############################################################################
    # caption -> token ids, zero-padded to 100 entries
    # NOTE(review): an out-of-vocabulary word raises KeyError (HTTP 500).
    seq = [dataset.rev_vocab[word] for word in caption.split()]
    seq.extend([0] * (100 - len(seq)))
    print('\nInput : ', caption)

    # the same caption repeated 16 times -> one batch of 16 samples
    seq = th.LongTensor(seq).to(device)
    list_seq = th.stack([seq] * 16)

    # fade-in alpha is hard-coded (equals 1/136); where this value comes
    # from is not visible here
    alpha = 0.007352941176470588

    # pure inference: no autograd graph needed
    with th.no_grad():
        embeddings = text_encoder(list_seq)
        c_not_hats, _, _ = condition_augmenter(embeddings)
        z = th.randn(list_seq.shape[0],
                     c_pro_gan.latent_size - c_not_hats.shape[-1]).to(device)
        gan_input = th.cat((c_not_hats, z), dim=-1)
        samples = c_pro_gan.gen(gan_input, current_depth, alpha)

    # presumably the generator emits values in [-1, 1]; shift/scale to [0, 1]
    samples = (samples / 2) + 0.5

    # outputs below the final depth are upsampled by 2**(depth - current_depth - 1)
    scale_factor = int(np.power(2, c_pro_gan.depth - current_depth - 1))
    if scale_factor > 1:
        samples = interpolate(samples, scale_factor=scale_factor)

    # save the grid to disk as <caption>.png
    # NOTE(review): the file name is built from user input; it is only safe
    # because every word must already be in the vocabulary.
    img_file = "static\\" + caption + '.png'
    save_image(samples, img_file, nrow=int(np.sqrt(len(samples))))

    result = "\\static\\" + caption + ".png"
    return render_template("main.html",
                           result_img=result,
                           result_caption=caption)
コード例 #17
0
from __future__ import print_function, division
import scipy

import datetime
import matplotlib.pyplot as plt
import sys
from data_processing import DataLoader
import numpy as np
import os

# Preview 25 random samples from each domain of the 32x32 MNIST (domain "A")
# / MNIST-M (domain "B") loader: one 5x5 grid per domain, written to
# 0.png and 1.png respectively.
data_loader = DataLoader(img_res=(32, 32))

mnist, _ = data_loader.load_data(domain="A", batch_size=25)
mnistm, _ = data_loader.load_data(domain="B", batch_size=25)

r, c = 5, 5

for img_i, imgs in enumerate((mnist, mnistm)):
    fig, axs = plt.subplots(r, c)
    # walk the grid row-major so that cell k shows image k
    for cnt in range(r * c):
        row, col = divmod(cnt, c)
        cell = axs[row, col]
        cell.imshow(imgs[cnt])
        cell.axis('off')
    fig.savefig("%d.png" % (img_i))
    plt.close()
コード例 #18
0
def main(args):
    """
    Build and train the text-conditioned ProGAN described by a config file.

    Steps, in order:

    1. Load the config named by ``args.config``.
    2. Build the caption/image dataset and the text encoder. When
       ``config.use_pretrained_encoder`` is set, this is a ``PretrainedEncoder``
       driven by a TensorFlow session with GPUs disabled. Otherwise it is a
       trainable ``Encoder`` with its own Adam optimizer.
    3. Build the conditioning augmenter and the ``ConditionalProGAN``.
    4. Restore any checkpoints passed on the command line.
    5. Create the conditioning augmenter's optimizer and call ``train_networks``.

    :param args: parsed command line arguments (``config``, ``encoder_file``,
                 ``ca_file``, ``generator_file``, ``discriminator_file``,
                 ``start_depth``)
    :return: None
    """
    from networks.TextEncoder import Encoder
    from networks.ConditionAugmentation import ConditionAugmentor
    from networks.PRO_GAN import ConditionalProGAN

    def restore(module, checkpoint, message):
        # load a saved state dict into `module` only when a path was given
        if checkpoint is None:
            return
        print(message, checkpoint)
        module.load_state_dict(th.load(checkpoint))

    print(args.config)
    config = get_config(args.config)
    print("Current Configuration:", config)

    # Adam settings shared by every optimizer created below
    adam_settings = {
        "lr": config.learning_rate,
        "betas": (config.beta_1, config.beta_2),
        "eps": config.eps,
    }
    transform = dl.get_transform(config.img_dims)

    if not config.use_pretrained_encoder:
        dataset = dl.Face2TextDataset(pro_pick_file=config.processed_text_file,
                                      img_dir=config.images_dir,
                                      img_transform=transform,
                                      captions_len=config.captions_length)
        text_encoder = Encoder(embedding_size=config.embedding_size,
                               vocab_size=dataset.vocab_size,
                               hidden_size=config.hidden_size,
                               num_layers=config.num_layers,
                               device=device)
        encoder_optim = th.optim.Adam(text_encoder.parameters(),
                                      **adam_settings)
    else:
        dataset = dl.RawTextFace2TextDataset(
            annots_file=config.annotations_file,
            img_dir=config.images_dir,
            img_transform=transform)
        from networks.TextEncoder import PretrainedEncoder
        # TensorFlow session with GPU devices disabled for the pretrained encoder
        session = tf.Session(config=tf.ConfigProto(device_count={"GPU": 0}))
        text_encoder = PretrainedEncoder(
            session=session,
            module_dir=config.pretrained_encoder_dir,
            download=config.download_pretrained_encoder)
        encoder_optim = None

    restore(text_encoder, args.encoder_file, "Loading encoder from:")

    condition_augmenter = ConditionAugmentor(input_size=config.hidden_size,
                                             latent_size=config.ca_out_size,
                                             use_eql=config.use_eql,
                                             device=device)
    restore(condition_augmenter, args.ca_file,
            "Loading conditioning augmenter from:")

    c_pro_gan = ConditionalProGAN(
        embedding_size=config.hidden_size,
        depth=config.depth,
        latent_size=config.latent_size,
        compressed_latent_size=config.compressed_latent_size,
        learning_rate=config.learning_rate,
        beta_1=config.beta_1,
        beta_2=config.beta_2,
        eps=config.eps,
        drift=config.drift,
        n_critic=config.n_critic,
        use_eql=config.use_eql,
        loss=config.loss_function,
        use_ema=config.use_ema,
        ema_decay=config.ema_decay,
        device=device)
    restore(c_pro_gan.gen, args.generator_file, "Loading generator from:")
    restore(c_pro_gan.dis, args.discriminator_file,
            "Loading discriminator from:")

    # the conditioning augmenter gets an optimizer of its own
    ca_optim = th.optim.Adam(condition_augmenter.parameters(), **adam_settings)

    train_networks(
        encoder=text_encoder,
        ca=condition_augmenter,
        c_pro_gan=c_pro_gan,
        dataset=dataset,
        encoder_optim=encoder_optim,
        ca_optim=ca_optim,
        epochs=config.epochs,
        fade_in_percentage=config.fade_in_percentage,
        start_depth=args.start_depth,
        batch_sizes=config.batch_sizes,
        num_workers=config.num_workers,
        feedback_factor=config.feedback_factor,
        log_dir=config.log_dir,
        sample_dir=config.sample_dir,
        checkpoint_factor=config.checkpoint_factor,
        save_dir=config.save_dir,
        use_matching_aware_dis=config.use_matching_aware_discriminator)
コード例 #19
0
def train_networks(encoder, ca, msg_gan, dataset, epochs,
                   encoder_optim, ca_optim, gen_optim, dis_optim, loss_fn, fade_in_percentage,
                   batch_sizes, start_depth, num_workers, feedback_factor,
                   log_dir, sample_dir, checkpoint_factor,
                   save_dir, use_matching_aware_dis=True):
    """
    Jointly train the text encoder, the conditioning augmenter and the MSG-GAN.

    For each depth from ``start_depth`` up to ``msg_gan.depth - 1`` the
    dataset is iterated ``epochs[depth]`` times with batch size
    ``batch_sizes[depth]``. Each batch goes through three steps:

    1. Encode the captions and pass the embeddings through ``ca`` to get
       ``(c_not_hats, mus, sigmas)``.
    2. Concatenate ``c_not_hats`` with Gaussian noise, then run one
       discriminator step and one generator step against the real images at
       every scale.
    3. Backpropagate the KL term built from ``mus`` / ``sigmas`` and step
       ``ca_optim``, plus ``encoder_optim`` when it is given.

    About ``feedback_factor`` times per epoch the losses are printed and
    appended to ``log_dir/loss_<depth>.log``, and sample grids from a fixed
    input are written under ``sample_dir``. State dicts are saved to
    ``save_dir`` every ``checkpoint_factor`` epochs.

    :param encoder: text encoder; a ``PretrainedEncoder`` is not switched to
                    train mode, and its output is detached when
                    ``encoder_optim`` is None
    :param ca: conditioning augmenter returning ``(c_not_hats, mus, sigmas)``
    :param msg_gan: GAN object exposing ``gen``, ``dis``, ``depth``,
                    ``latent_size``, ``optimize_discriminator``,
                    ``optimize_generator`` and (with EMA) ``gen_shadow``
    :param dataset: dataset yielding ``(captions, images)`` batches
    :param epochs: per-depth epoch counts (length ``msg_gan.depth``)
    :param encoder_optim: optimizer for the encoder, or None to keep it frozen
    :param ca_optim: optimizer for the conditioning augmenter
    :param gen_optim: generator optimizer
    :param dis_optim: discriminator optimizer
    :param loss_fn: loss object handed to the GAN's optimize_* methods
    :param fade_in_percentage: per-depth list; only its length is checked
                               because this loop applies no fade-in
    :param batch_sizes: per-depth batch sizes (length ``msg_gan.depth``)
    :param start_depth: depth to start (or resume) training at
    :param num_workers: DataLoader worker processes
    :param feedback_factor: number of log/sample points per epoch
    :param log_dir: directory for the loss logs
    :param sample_dir: directory for the sample grids
    :param checkpoint_factor: save a checkpoint every this many epochs
    :param save_dir: directory for the checkpoints
    :param use_matching_aware_dis: accepted for interface compatibility; unused
    :return: None
    """
    # required only for type checking
    from networks.TextEncoder import PretrainedEncoder

    # input assertions
    assert msg_gan.depth == len(batch_sizes), "batch_sizes not compatible with depth"
    assert msg_gan.depth == len(epochs), "epochs_sizes not compatible with depth"
    assert msg_gan.depth == len(fade_in_percentage), "fip_sizes not compatible with depth"

    # put all the Networks in training mode:
    ca.train()
    msg_gan.gen.train()
    msg_gan.dis.train()

    if not isinstance(encoder, PretrainedEncoder):
        encoder.train()

    # Sample from the EMA (shadow) generator only when the GAN keeps one.
    # Fall back to the shadow generator if the flag is missing.
    use_ema = getattr(msg_gan, "use_ema", True)

    print("Starting the training process ... ")

    # fixed input for the debugging sample grids ###############################
    temp_data = dl.get_data_loader(dataset, batch_sizes[start_depth],
                                   num_workers=num_workers)
    fixed_captions, _ = next(iter(temp_data))
    # token-id tensors go to the device; raw caption lists are left as they are
    if isinstance(fixed_captions, th.Tensor):
        fixed_captions = fixed_captions.to(device)

    # the fixed input is only used for sampling, so build it without a graph
    with th.no_grad():
        fixed_embeddings = encoder(fixed_captions).to(device)
        fixed_c_not_hats, _, _ = ca(fixed_embeddings)
        fixed_noise = th.randn(len(fixed_captions),
                               msg_gan.latent_size - fixed_c_not_hats.shape[-1]).to(device)
        fixed_gan_input = th.cat((fixed_c_not_hats, fixed_noise), dim=-1)

    # create a global time counter
    global_time = time.time()

    # delete temp data loader:
    del temp_data
    ############################################################################

    for current_depth in range(start_depth, msg_gan.depth):

        print("\n\nCurrently working on Depth: ", current_depth)
        current_res = np.power(2, current_depth + 2)
        print("Current resolution: %d x %d" % (current_res, current_res))

        data = dl.get_data_loader(dataset, batch_sizes[current_depth], num_workers)

        for epoch in range(1, epochs[current_depth] + 1):
            start = timeit.default_timer()  # record time at the start of epoch

            print("\nEpoch: %d" % epoch)
            # len() on the loader itself; iter(data) would spawn the workers
            total_batches = len(data)
            # at least 1, so a short epoch cannot trigger a modulo by zero
            feedback_interval = max(1, int(total_batches / feedback_factor))

            for i, (captions, images) in enumerate(data, 1):
                images = images.to(device)
                extracted_batch_size = images.shape[0]
                if encoder_optim is not None:
                    captions = captions.to(device)

                # real images at every scale (full resolution down by powers
                # of two), ordered lowest resolution first
                images = [images] + [avg_pool2d(images, int(np.power(2, scale)))
                                     for scale in range(1, msg_gan.depth)]
                images = list(reversed(images))

                # text work: keep the embeddings attached to the encoder's
                # graph so that encoder_optim actually receives gradients
                embeddings = encoder(captions).to(device)
                if encoder_optim is None:
                    # frozen encoder: keep it out of backpropagation
                    embeddings = embeddings.detach()
                c_not_hats, mus, sigmas = ca(embeddings)

                z = th.randn(
                    extracted_batch_size,
                    msg_gan.latent_size - c_not_hats.shape[-1]
                ).to(device)
                gan_input = th.cat((c_not_hats, z), dim=-1)

                # optimize the discriminator:
                dis_loss = msg_gan.optimize_discriminator(dis_optim, gan_input, images,
                                                          loss_fn)

                # optimize the generator with fresh noise:
                z = th.randn(
                    extracted_batch_size,
                    msg_gan.latent_size - c_not_hats.shape[-1]
                ).to(device)
                gan_input = th.cat((c_not_hats, z), dim=-1)

                # optimize_generator also sends gradients into the conditioning
                # augmenter (and the encoder when trainable), so clear theirs first
                if encoder_optim is not None:
                    encoder_optim.zero_grad()
                ca_optim.zero_grad()
                gen_loss = msg_gan.optimize_generator(gen_optim, gan_input, images,
                                                      loss_fn)

                # KL-divergence regulariser of the conditioning augmenter
                kl_loss = th.mean(0.5 * th.sum((mus ** 2) + (sigmas ** 2)
                                               - th.log((sigmas ** 2)) - 1, dim=1))
                kl_loss.backward(retain_graph=True)
                ca_optim.step()
                if encoder_optim is not None:
                    encoder_optim.step()

                # provide a loss feedback
                if i % feedback_interval == 0 or i == 1:
                    elapsed = time.time() - global_time
                    elapsed = str(datetime.timedelta(seconds=elapsed))
                    print("Elapsed [%s]  batch: %d  d_loss: %f  g_loss: %f  kl_los: %f"
                          % (elapsed, i, dis_loss, gen_loss, kl_loss.item()))

                    # also write the losses to the log file:
                    os.makedirs(log_dir, exist_ok=True)
                    log_file = os.path.join(log_dir, "loss_" + str(current_depth) + ".log")
                    with open(log_file, "a") as log:
                        log.write(str(dis_loss) + "\t" + str(gen_loss)
                                  + "\t" + str(kl_loss.item()) + "\n")

                    # one sample image per output resolution (4x4 upwards)
                    reses = [str(int(np.power(2, dep))) + "_x_"
                             + str(int(np.power(2, dep)))
                             for dep in range(2, msg_gan.depth + 2)]
                    gen_img_files = [os.path.join(sample_dir, res, "gen_" +
                                                  str(epoch) + "_" +
                                                  str(i) + ".png")
                                     for res in reses]

                    os.makedirs(sample_dir, exist_ok=True)
                    for gen_img_file in gen_img_files:
                        os.makedirs(os.path.dirname(gen_img_file), exist_ok=True)

                    dis_optim.zero_grad()
                    gen_optim.zero_grad()

                    with th.no_grad():
                        generator = msg_gan.gen_shadow if use_ema else msg_gan.gen
                        create_grid(samples=generator(fixed_gan_input),
                                    img_files=gen_img_files)

            stop = timeit.default_timer()
            print("Time taken for epoch: %.3f secs" % (stop - start))

            if epoch % checkpoint_factor == 0:
                # save the Model
                encoder_save_file = os.path.join(save_dir, "Encoder_" +
                                                 str(current_depth) + ".pth")
                ca_save_file = os.path.join(save_dir, "Condition_Augmentor_" +
                                            str(current_depth) + ".pth")
                gen_save_file = os.path.join(save_dir, "GAN_GEN_" +
                                             str(current_depth) + ".pth")
                dis_save_file = os.path.join(save_dir, "GAN_DIS_" +
                                             str(current_depth) + ".pth")

                os.makedirs(save_dir, exist_ok=True)

                if encoder_optim is not None:
                    th.save(encoder.state_dict(), encoder_save_file, pickle)
                th.save(ca.state_dict(), ca_save_file, pickle)
                th.save(msg_gan.gen.state_dict(), gen_save_file, pickle)
                th.save(msg_gan.dis.state_dict(), dis_save_file, pickle)

    print("Training completed ...")
コード例 #20
0
def main(args):
    """
    Build datasets, text encoder, conditioning augmenter and conditional
    ProGAN from the configuration file, then train them.

    With ``config.use_pretrained_encoder`` set, the training dataset, the
    validation dataset and the pretrained text encoder are cached as
    ``dataset_<c>.pickle``, ``val_dataset_<c>.pickle`` and
    ``text_encoder_<c>.pickle``, where ``<c>`` is
    ``config.tensorboard_comment``. They are reloaded from those files when
    all three exist. Otherwise a ``Face2TextDataset`` and a trainable
    ``Encoder`` with its own Adam optimizer are used, and no validation set
    is built.

    :param args: parsed command line arguments
    :return: None
    """
    from networks.TextEncoder import Encoder
    from networks.ConditionAugmentation import ConditionAugmentor
    from networks.PRO_GAN import ConditionalProGAN

    config = get_config(args.config)

    print("Create dataset...")
    # create the dataset for training
    if config.use_pretrained_encoder:
        print("Using PretrainedEncoder...")
        comment = config.tensorboard_comment
        dataset_cache = f"dataset_{comment}.pickle"
        val_dataset_cache = f"val_dataset_{comment}.pickle"
        encoder_cache = f"text_encoder_{comment}.pickle"

        # rebuild unless every cache file is present, so a missing dataset
        # pickle can no longer crash the loading branch
        if not all(os.path.exists(path) for path in
                   (dataset_cache, val_dataset_cache, encoder_cache)):

            print("Creating new vocab and dataset pickle files ...")
            dataset = dl.RawTextFace2TextDataset(
                data_path=config.data_path,
                img_dir=config.images_dir,
                img_transform=dl.get_transform(config.img_dims))
            val_dataset = dl.RawTextFace2TextDataset(
                data_path=config.data_path_val,
                img_dir=config.val_images_dir,  # unnecessary
                img_transform=dl.get_transform(config.img_dims))
            from networks.TextEncoder import PretrainedEncoder
            text_encoder = PretrainedEncoder(
                model_file=config.pretrained_encoder_file,
                embedding_file=config.pretrained_embedding_file,
                device=device)
            print("Pickling dataset, val_dataset and text_encoder....")
            for obj, path in ((dataset, dataset_cache),
                              (val_dataset, val_dataset_cache),
                              (text_encoder, encoder_cache)):
                with open(path, 'wb') as handle:
                    pickle.dump(obj, handle, protocol=pickle.HIGHEST_PROTOCOL)
        else:
            print("Loading dataset, val_dataset and text_encoder from file...")
            # NOTE: pickle.load can execute arbitrary code; these caches must
            # only ever be files written by this script.
            with open(val_dataset_cache, 'rb') as handle:
                val_dataset = pickle.load(handle)
            with open(dataset_cache, 'rb') as handle:
                dataset = pickle.load(handle)
            from networks.TextEncoder import PretrainedEncoder
            with open(encoder_cache, 'rb') as handle:
                text_encoder = pickle.load(handle)
        # the pretrained encoder is not optimised
        encoder_optim = None
    else:
        print("Using Face2TextDataset dataloader...")
        dataset = dl.Face2TextDataset(pro_pick_file=config.processed_text_file,
                                      img_dir=config.images_dir,
                                      img_transform=dl.get_transform(
                                          config.img_dims),
                                      captions_len=config.captions_length)
        text_encoder = Encoder(embedding_size=config.embedding_size,
                               vocab_size=dataset.vocab_size,
                               hidden_size=config.hidden_size,
                               num_layers=config.num_layers,
                               device=device)
        encoder_optim = th.optim.Adam(text_encoder.parameters(),
                                      lr=config.learning_rate,
                                      betas=(config.beta_1, config.beta_2),
                                      eps=config.eps)
        # No validation split is configured for this dataset type, but
        # val_dataset must be bound for the train_networks call below.
        # NOTE(review): confirm train_networks accepts validation_dataset=None.
        val_dataset = None

    # create the networks

    if args.encoder_file is not None:
        # Note this should not be used with the pretrained encoder file
        print("Loading encoder from:", args.encoder_file)
        text_encoder.load_state_dict(th.load(args.encoder_file))

    condition_augmenter = ConditionAugmentor(input_size=config.hidden_size,
                                             latent_size=config.ca_out_size,
                                             use_eql=config.use_eql,
                                             device=device)

    if args.ca_file is not None:
        print("Loading conditioning augmenter from:", args.ca_file)
        condition_augmenter.load_state_dict(th.load(args.ca_file))
    print("Create cprogan...")
    c_pro_gan = ConditionalProGAN(
        embedding_size=config.hidden_size,
        depth=config.depth,
        latent_size=config.latent_size,
        compressed_latent_size=config.compressed_latent_size,
        learning_rate=config.learning_rate,
        beta_1=config.beta_1,
        beta_2=config.beta_2,
        eps=config.eps,
        drift=config.drift,
        n_critic=config.n_critic,
        use_eql=config.use_eql,
        loss=config.loss_function,
        use_ema=config.use_ema,
        ema_decay=config.ema_decay,
        device=device)

    # show the generator architecture
    print(c_pro_gan.gen)

    if args.generator_file is not None:
        print("Loading generator from:", args.generator_file)
        c_pro_gan.gen.load_state_dict(th.load(args.generator_file))

    if args.discriminator_file is not None:
        print("Loading discriminator from:", args.discriminator_file)
        c_pro_gan.dis.load_state_dict(th.load(args.discriminator_file))

    print("Create optimizer...")
    # create the optimizer for Condition Augmenter separately
    ca_optim = th.optim.Adam(condition_augmenter.parameters(),
                             lr=config.learning_rate,
                             betas=(config.beta_1, config.beta_2),
                             eps=config.eps)

    # train all the networks
    train_networks(
        encoder=text_encoder,
        ca=condition_augmenter,
        c_pro_gan=c_pro_gan,
        dataset=dataset,
        validation_dataset=val_dataset,
        encoder_optim=encoder_optim,
        ca_optim=ca_optim,
        epochs=config.epochs,
        fade_in_percentage=config.fade_in_percentage,
        start_depth=args.start_depth,
        batch_sizes=config.batch_sizes,
        num_workers=config.num_workers,
        feedback_factor=config.feedback_factor,
        log_dir=config.log_dir,
        sample_dir=config.sample_dir,
        checkpoint_factor=config.checkpoint_factor,
        save_dir=config.save_dir,
        comment=config.tensorboard_comment,
        use_matching_aware_dis=config.use_matching_aware_discriminator)
コード例 #21
0
def main(args):
    """
    Build and train the text-conditioned MSG-GAN described by a config file.

    Objects are built in this order:

    1. The dataset and text encoder. This is a ``PretrainedEncoder`` when
       ``config.use_pretrained_encoder`` is set; otherwise a trainable
       ``Encoder`` with its own Adam optimizer.
    2. The ``MSG_GAN``, with separate Adam optimizers for the generator
       (``config.g_lr``) and the discriminator (``config.d_lr``).
    3. The conditioning augmenter.

    Checkpoints named on the command line are then restored and both
    architectures are printed. Finally ``train_networks`` is called with a
    ``RelativisticAverageHingeGAN`` loss bound to the discriminator.

    :param args: parsed command line arguments (``config``, ``encoder_file``,
                 ``ca_file``, ``generator_file``, ``discriminator_file``,
                 ``start_depth``)
    :return: None
    """
    from networks.TextEncoder import Encoder
    from networks.ConditionAugmentation import ConditionAugmentor
    from MSG_GAN.GAN import MSG_GAN
    from MSG_GAN import Losses as lses

    def restore(module, checkpoint, message):
        # load a saved state dict into `module` only when a path was given
        if checkpoint is None:
            return
        print(message, checkpoint)
        module.load_state_dict(th.load(checkpoint))

    print(args.config)
    config = get_config(args.config)
    print("Current Configuration:", config)

    transform = dl.get_transform(config.img_dims)

    if not config.use_pretrained_encoder:
        dataset = dl.Face2TextDataset(
            pro_pick_file=config.processed_text_file,
            img_dir=config.images_dir,
            img_transform=transform,
            captions_len=config.captions_length
        )
        text_encoder = Encoder(
            embedding_size=config.embedding_size,
            vocab_size=dataset.vocab_size,
            hidden_size=config.hidden_size,
            num_layers=config.num_layers,
            device=device
        )
        encoder_optim = th.optim.Adam(text_encoder.parameters(),
                                      lr=config.learning_rate,
                                      betas=(config.adam_beta1, config.adam_beta2),
                                      eps=config.eps)
    else:
        dataset = dl.RawTextFace2TextDataset(
            annots_file=config.annotations_file,
            img_dir=config.images_dir,
            img_transform=transform
        )
        from networks.TextEncoder import PretrainedEncoder
        text_encoder = PretrainedEncoder(
            model_file=config.pretrained_encoder_file,
            embedding_file=config.pretrained_embedding_file,
            device=device
        )
        encoder_optim = None

    msg_gan = MSG_GAN(
        depth=config.depth,
        latent_size=config.latent_size,
        use_eql=config.use_eql,
        use_ema=config.use_ema,
        ema_decay=config.ema_decay,
        device=device)

    # generator / discriminator optimizers, each with its own betas list
    genoptim = th.optim.Adam(msg_gan.gen.parameters(), lr=config.g_lr,
                             betas=[config.adam_beta1, config.adam_beta2])
    disoptim = th.optim.Adam(msg_gan.dis.parameters(), lr=config.d_lr,
                             betas=[config.adam_beta1, config.adam_beta2])

    loss_cls = lses.RelativisticAverageHingeGAN

    # Note: the encoder checkpoint should not be used with the pretrained encoder
    restore(text_encoder, args.encoder_file, "Loading encoder from:")

    condition_augmenter = ConditionAugmentor(
        input_size=config.hidden_size,
        latent_size=config.ca_out_size,
        use_eql=config.use_eql,
        device=device
    )

    restore(condition_augmenter, args.ca_file,
            "Loading conditioning augmenter from:")
    restore(msg_gan.gen, args.generator_file, "Loading generator from:")
    restore(msg_gan.dis, args.discriminator_file, "Loading discriminator from:")

    # the conditioning augmenter gets an optimizer of its own
    ca_optim = th.optim.Adam(condition_augmenter.parameters(),
                             lr=config.learning_rate,
                             betas=(config.adam_beta1, config.adam_beta2),
                             eps=config.eps)

    print("Generator Config:")
    print(msg_gan.gen)

    print("\nDiscriminator Config:")
    print(msg_gan.dis)

    train_networks(
        encoder=text_encoder,
        ca=condition_augmenter,
        msg_gan=msg_gan,
        dataset=dataset,
        encoder_optim=encoder_optim,
        ca_optim=ca_optim,
        gen_optim=genoptim,
        dis_optim=disoptim,
        loss_fn=loss_cls(msg_gan.dis),
        epochs=config.epochs,
        fade_in_percentage=config.fade_in_percentage,
        start_depth=args.start_depth,
        batch_sizes=config.batch_sizes,
        num_workers=config.num_workers,
        feedback_factor=config.feedback_factor,
        log_dir=config.log_dir,
        sample_dir=config.sample_dir,
        checkpoint_factor=config.checkpoint_factor,
        save_dir=config.save_dir,
    )