Пример #1
0
def create_model(optimizer, target_size=(224, 224, 3)):
    """Build and compile a triple-input image classifier.

    A frozen ImageNet-pretrained InceptionV3 is shared as the feature
    extractor across three image inputs; the pooled features are
    concatenated and classified by a small fully-connected head.

    NOTE(review): the head is hard-coded (Dense 1024, dropout 0.5,
    2 softmax classes) — kept as-is to preserve the external contract.

    :param optimizer: Keras optimizer instance or identifier passed to
        ``Model.compile``.
    :param target_size: shape (height, width, channels) of each input.
    :return: compiled ``Model`` taking ``[input_x1, input_x2, input_x3]``
        and producing 2-class softmax probabilities.
    """
    base_model = InceptionV3(include_top=False, weights='imagenet')
    shared_layers = Network(base_model.inputs,
                            base_model.output,
                            name='shared_layers')
    # Freeze the extractor; only the head (and per-branch BN) trains.
    shared_layers.trainable = False

    # Build one identical branch per image. Extractor weights are shared;
    # each branch gets its own BatchNormalization.
    inputs = []
    branches = []
    for i in (1, 2, 3):
        inp = Input(shape=target_size, name='input_x%d' % i)
        feat = shared_layers(inp)
        feat = GlobalAveragePooling2D(name='x%d_gap' % i)(feat)
        feat = BatchNormalization()(feat)
        inputs.append(inp)
        branches.append(feat)

    x = Concatenate(name='concat_triple_image')(branches)
    x = Dense(1024, activation='relu')(x)
    x = Dropout(0.5)(x)
    predictions = Dense(2, activation='softmax')(x)

    model = Model(inputs, predictions)

    model.compile(optimizer=optimizer,
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])

    return model
Пример #2
0
    def build_discriminators(self, filters=64):
        """
        Build the SRGAN discriminator described in the paper.

        :param int filters: number of filters in the first conv layer
        :return: tuple ``(model, frozen_discriminator)`` — the trainable
            ``Model`` and a weight-sharing ``Network`` with
            ``trainable = False`` for use inside the combined GAN model
        """
        def disc_block(tensor, n_filters, strides=1, bn=True):
            # Conv -> LeakyReLU, then (optionally) BatchNorm, matching
            # the paper's ordering.
            out = Conv2D(n_filters, kernel_size=3, strides=strides,
                         padding='same')(tensor)
            out = LeakyReLU(alpha=0.2)(out)
            return BatchNormalization(momentum=0.8)(out) if bn else out

        # Input high-resolution image.
        img = Input(shape=self.shape_hr)

        # First block has no BN; the rest double the filter count every
        # two blocks while alternating stride 1 / stride 2.
        x = disc_block(img, filters, bn=False)
        for mult, strides in ((1, 2), (2, 1), (2, 2),
                              (4, 1), (4, 2), (8, 1), (8, 2)):
            x = disc_block(x, filters * mult, strides=strides)

        # Dense head ending in a single sigmoid validity score.
        x = Dense(filters * 16)(x)
        x = LeakyReLU(alpha=0.2)(x)
        x = Dense(1, activation='sigmoid')(x)

        model = Model(inputs=img, outputs=x)

        # Weight-sharing, non-trainable view of the same graph so the
        # discriminator stays fixed while the generator trains.
        frozen_discriminator = Network(inputs=img,
                                       outputs=x,
                                       name='frozen_discriminator')
        frozen_discriminator.trainable = False

        return model, frozen_discriminator
Пример #3
0
def main(overwrite=False):
    """Adversarially train a segmentation generator against a discriminator.

    Everything is driven by the module-level ``config`` dict: this builds
    the HDF5 data file, the generator (segmentation model) and the
    discriminator, a frozen weight-sharing copy of the discriminator, and
    a combined GAN model. Each epoch alternates discriminator/generator
    training steps as dictated by ``Scheduler``, evaluates on the
    validation split, and saves the generator (JSON + weights) whenever
    its validation loss improves.

    :param overwrite: if True, rebuild the HDF5 data file and do not
        resume from existing ``g_*.h5`` checkpoints.
    """
    # convert input images into an hdf5 file
    if overwrite or not os.path.exists(config["data_file"]):
        create_data_file(config)

    data_file_opened = open_data_file(config["data_file"])

    # Loss functions are resolved by name from fetal_net.metrics.
    seg_loss_func = getattr(fetal_net.metrics, config['loss'])
    dis_loss_func = getattr(fetal_net.metrics, config['dis_loss'])

    # instantiate new model (model builder resolved by name, like losses)
    seg_model_func = getattr(fetal_net.model, config['model_name'])
    gen_model = seg_model_func(
        input_shape=config["input_shape"],
        initial_learning_rate=config["initial_learning_rate"],
        **{
            'dropout_rate':
            config['dropout_rate'],
            'loss_function':
            seg_loss_func,
            'mask_shape':
            None if config["weight_mask"] is None else config["input_shape"],
            'old_model_path':
            config['old_model']
        })

    # Discriminator input stacks the image channels with the n_labels
    # segmentation channels on the first axis (presumably channels-first
    # — confirm against the data layout and Concatenate(axis=1) below).
    dis_model_func = getattr(fetal_net.model, config['dis_model_name'])
    dis_model = dis_model_func(
        input_shape=[config["input_shape"][0] + config["n_labels"]] +
        config["input_shape"][1:],
        initial_learning_rate=config["initial_learning_rate"],
        **{
            'dropout_rate': config['dropout_rate'],
            'loss_function': dis_loss_func
        })

    # Resume generator weights from the newest 'g_*.h5' checkpoint, if any.
    if not overwrite \
            and len(glob.glob(config["model_file"] + 'g_*.h5')) > 0:
        # dis_model_path = get_last_model_path(config["model_file"] + 'dis_')
        gen_model_path = get_last_model_path(config["model_file"] + 'g_')
        # print('Loading dis model from: {}'.format(dis_model_path))
        print('Loading gen model from: {}'.format(gen_model_path))
        # dis_model = load_old_model(dis_model_path)
        # gen_model = load_old_model(gen_model_path)
        # dis_model.load_weights(dis_model_path)
        gen_model.load_weights(gen_model_path)

    gen_model.summary()
    dis_model.summary()

    # Build "frozen discriminator": shares weights with dis_model but is
    # marked non-trainable, so the combined model updates only the
    # generator while dis_model itself still trains normally.
    frozen_dis_model = Network(dis_model.inputs,
                               dis_model.outputs,
                               name='frozen_discriminator')
    frozen_dis_model.trainable = False

    # Combined model: supervised segmentation loss on the 'real' stream,
    # adversarial 'valid' signal on the 'fake' (unlabeled) stream.
    inputs_real = Input(shape=config["input_shape"])
    inputs_fake = Input(shape=config["input_shape"])
    segs_real = Activation(None, name='seg_real')(gen_model(inputs_real))
    segs_fake = Activation(None, name='seg_fake')(gen_model(inputs_fake))
    valid = Activation(None, name='dis')(frozen_dis_model(
        Concatenate(axis=1)([segs_fake, inputs_fake])))
    combined_model = Model(inputs=[inputs_real, inputs_fake],
                           outputs=[segs_real, valid])
    combined_model.compile(loss=[seg_loss_func, 'binary_crossentropy'],
                           loss_weights=[1, config["gd_loss_ratio"]],
                           optimizer=Adam(config["initial_learning_rate"]))
    combined_model.summary()

    # get training and testing generators
    train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
        data_file_opened,
        batch_size=config["batch_size"],
        data_split=config["validation_split"],
        overwrite=overwrite,
        validation_keys_file=config["validation_file"],
        training_keys_file=config["training_file"],
        test_keys_file=config["test_file"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        patch_shape=(*config["patch_shape"], config["patch_depth"]),
        validation_batch_size=config["validation_batch_size"],
        augment=config["augment"],
        skip_blank_train=config["skip_blank_train"],
        skip_blank_val=config["skip_blank_val"],
        truth_index=config["truth_index"],
        truth_size=config["truth_size"],
        prev_truth_index=config["prev_truth_index"],
        prev_truth_size=config["prev_truth_size"],
        truth_downsample=config["truth_downsample"],
        truth_crop=config["truth_crop"],
        patches_per_epoch=config["patches_per_epoch"],
        categorical=config["categorical"],
        is3d=config["3D"],
        drop_easy_patches_train=config["drop_easy_patches_train"],
        drop_easy_patches_val=config["drop_easy_patches_val"])

    # Second generator set: only its validation stream is used, as the
    # unlabeled ('semi') patch source for the adversarial loss.
    # NOTE(review): this call passes 'val_augment' where the first call
    # passes 'augment' — verify that keyword is intentional and accepted.
    _, semi_generator, _, _ = get_training_and_validation_generators(
        data_file_opened,
        batch_size=config["batch_size"],
        data_split=config["validation_split"],
        overwrite=overwrite,
        validation_keys_file=config["validation_file"],
        training_keys_file=config["training_file"],
        test_keys_file=config["test_file"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        patch_shape=(*config["patch_shape"], config["patch_depth"]),
        validation_batch_size=config["validation_batch_size"],
        val_augment=config["augment"],
        skip_blank_train=config["skip_blank_train"],
        skip_blank_val=config["skip_blank_val"],
        truth_index=config["truth_index"],
        truth_size=config["truth_size"],
        prev_truth_index=config["prev_truth_index"],
        prev_truth_size=config["prev_truth_size"],
        truth_downsample=config["truth_downsample"],
        truth_crop=config["truth_crop"],
        patches_per_epoch=config["patches_per_epoch"],
        categorical=config["categorical"],
        is3d=config["3D"],
        drop_easy_patches_train=config["drop_easy_patches_train"],
        drop_easy_patches_val=config["drop_easy_patches_val"])

    # start training: scheduler controls D/G step ratio and LR decay
    scheduler = Scheduler(config["dis_steps"],
                          config["gen_steps"],
                          init_lr=config["initial_learning_rate"],
                          lr_patience=config["patience"],
                          lr_decay=config["learning_rate_drop"])

    best_loss = np.inf
    for epoch in range(config["n_epochs"]):
        # NOTE(review): postfix keys 'g'/'d' differ from the 'gen'/'dis'
        # keys seeded in tqdm below; set_postfix replaces them anyway.
        postfix = {'g': None, 'd': None}  # , 'val_g': None, 'val_d': None}
        with tqdm(range(n_train_steps // config["gen_steps"]),
                  dynamic_ncols=True,
                  postfix={
                      'gen': None,
                      'dis': None,
                      'val_gen': None,
                      'val_dis': None,
                      None: None
                  }) as pbar:
            for n_round in pbar:
                # train D: real labeled patches vs generator output on
                # unlabeled ('semi') patches; metrics averaged over steps.
                outputs = np.zeros(dis_model.metrics_names.__len__())
                for i in range(scheduler.get_dsteps()):
                    real_patches, real_segs = next(train_generator)
                    semi_patches, _ = next(semi_generator)
                    d_x_batch, d_y_batch = input2discriminator(
                        real_patches, real_segs, semi_patches,
                        gen_model.predict(semi_patches,
                                          batch_size=config["batch_size"]),
                        dis_model.output_shape)
                    outputs += dis_model.train_on_batch(d_x_batch, d_y_batch)
                if scheduler.get_dsteps():
                    outputs /= scheduler.get_dsteps()
                    postfix['d'] = build_dsc(dis_model.metrics_names, outputs)
                    pbar.set_postfix(**postfix)

                # train G (freeze discriminator): combined supervised +
                # adversarial update through combined_model.
                outputs = np.zeros(combined_model.metrics_names.__len__())
                for i in range(scheduler.get_gsteps()):
                    real_patches, real_segs = next(train_generator)
                    semi_patches, _ = next(validation_generator)
                    g_x_batch, g_y_batch = input2gan(real_patches, real_segs,
                                                     semi_patches,
                                                     dis_model.output_shape)
                    outputs += combined_model.train_on_batch(
                        g_x_batch, g_y_batch)
                outputs /= scheduler.get_gsteps()

                postfix['g'] = build_dsc(combined_model.metrics_names, outputs)
                pbar.set_postfix(**postfix)

            # evaluate on validation set
            dis_metrics = np.zeros(dis_model.metrics_names.__len__(),
                                   dtype=float)
            gen_metrics = np.zeros(gen_model.metrics_names.__len__(),
                                   dtype=float)
            evaluation_rounds = n_validation_steps
            for n_round in range(evaluation_rounds):  # rounds_for_evaluation:
                val_patches, val_segs = next(validation_generator)

                # D
                if scheduler.get_dsteps() > 0:
                    d_x_test, d_y_test = input2discriminator(
                        val_patches, val_segs, val_patches,
                        gen_model.predict(
                            val_patches,
                            batch_size=config["validation_batch_size"]),
                        dis_model.output_shape)
                    dis_metrics += dis_model.evaluate(
                        d_x_test,
                        d_y_test,
                        batch_size=config["validation_batch_size"],
                        verbose=0)

                # G
                # gen_x_test, gen_y_test = input2gan(val_patches, val_segs, dis_model.output_shape)
                gen_metrics += gen_model.evaluate(
                    val_patches,
                    val_segs,
                    batch_size=config["validation_batch_size"],
                    verbose=0)

            dis_metrics /= float(evaluation_rounds)
            gen_metrics /= float(evaluation_rounds)
            # save the model and weights with the best validation loss
            if gen_metrics[0] < best_loss:
                best_loss = gen_metrics[0]
                print('Saving Model...')
                with open(
                        os.path.join(
                            config["base_dir"],
                            "g_{}_{:.3f}.json".format(epoch, gen_metrics[0])),
                        'w') as f:
                    f.write(gen_model.to_json())
                gen_model.save_weights(
                    os.path.join(
                        config["base_dir"],
                        "g_{}_{:.3f}.h5".format(epoch, gen_metrics[0])))

            postfix['val_d'] = build_dsc(dis_model.metrics_names, dis_metrics)
            postfix['val_g'] = build_dsc(gen_model.metrics_names, gen_metrics)
            # pbar.set_postfix(**postfix)
            print('val_d: ' + postfix['val_d'], end=' | ')
            print('val_g: ' + postfix['val_g'])
            # pbar.refresh()

            # update step sizes, learning rates
            scheduler.update_steps(epoch, gen_metrics[0])
            K.set_value(dis_model.optimizer.lr, scheduler.get_lr())
            K.set_value(combined_model.optimizer.lr, scheduler.get_lr())

    data_file_opened.close()
Пример #4
0
        image_shape, dropout_rate)

    discriminator = Model(discriminator_input,
                          discriminator_output,
                          name='discriminator')

    discriminator.compile(optimizer,
                          loss='binary_crossentropy',
                          metrics=['accuracy'])

    assert (len(discriminator._collected_trainable_weights) > 0)

    frozen_discriminator = Network(discriminator_input,
                                   discriminator_output,
                                   name='frozen_discriminator')
    frozen_discriminator.trainable = False

    generator_input, generator_output = build_generator()
    generator = Model(generator_input, generator_output, name='generator')

    adversarial_model = Model(generator_input,
                              frozen_discriminator(generator_output),
                              name='adversarial_model')

    adversarial_model.compile(optimizer, loss='binary_crossentropy')

    assert (len(adversarial_model._collected_trainable_weights) == len(
        generator.trainable_weights))
    batch_size = 64

    config = configparser.ConfigParser()
Пример #5
0
def GAN_Main(result_dir="output", data_dir="data"):
    """Train an image-completion (inpainting) GAN.

    Training runs in three phases over ``Epochs``:
      1. first ``l1`` epochs: generator-only MSE reconstruction pretraining;
      2. next ``l2`` epochs: discriminator-only training (real vs completed);
      3. remaining epochs: joint update (MSE + adversarial) of the generator
         alongside continued discriminator training.
    After each epoch a 3-column progress figure (masked input / output /
    original) is saved; final generator and discriminator models are saved
    at the end.

    :param result_dir: directory for progress images and saved models.
    :param data_dir: directory containing the training images.
    """

    # Weight of the adversarial term in the joint generator loss.
    alpha = 0.0004
    input_shape = (256, 256, 3)
    local_shape = (128, 128, 3)

    batchSize = 4
    Epochs = 150
    # Phase boundaries: l1 epochs of generator pretraining, then l2 epochs
    # of discriminator-only training before joint updates begin.
    l1 = int(Epochs * 0.18)
    l2 = int(Epochs * 0.02)

    train_datagen = Data(data_dir, input_shape[:2], local_shape[:2])

    Gen = Generative(input_shape)
    Dis = Discriminative(input_shape, local_shape)

    optimizer = Adadelta()
    #######
    # Completion pipeline: zero out the masked region, let the generator
    # fill it, then paste generated pixels back only inside the mask.
    orgVal = Input(shape=input_shape)
    mask = Input(shape=(input_shape[0], input_shape[1], 1))

    imgContent = Lambda(lambda x: x[0] * (1 - x[1]),
                        output_shape=input_shape)([orgVal, mask])
    mimic = Gen(imgContent)
    completion = Lambda(lambda x: x[0] * x[2] + x[1] * (1 - x[2]),
                        output_shape=input_shape)([mimic, orgVal, mask])
    Gen_container = Network([orgVal, mask], completion)
    Gen_out = Gen_container([orgVal, mask])
    Gen_model = Model([orgVal, mask], Gen_out)
    Gen_model.compile(loss='mse', optimizer=optimizer)

    # NOTE(review): inputLayer presumably carries the four crop coordinates
    # of the local discriminator patch — confirm against Discriminative.
    inputLayer = Input(shape=(4, ), dtype='int32')
    Dis_container = Network([orgVal, inputLayer], Dis([orgVal, inputLayer]))
    Dis_model = Model([orgVal, inputLayer], Dis_container([orgVal,
                                                           inputLayer]))
    Dis_model.compile(loss='binary_crossentropy', optimizer=optimizer)

    # Freeze discriminator weights inside the combined model so the
    # adversarial loss updates only the generator.
    Dis_container.trainable = False
    totalModel = Model([orgVal, mask, inputLayer],
                       [Gen_out, Dis_container([Gen_out, inputLayer])])
    totalModel.compile(loss=['mse', 'binary_crossentropy'],
                       loss_weights=[1.0, alpha],
                       optimizer=optimizer)

    for n in range(Epochs):
        progress = generic_utils.Progbar(len(train_datagen))
        for inputs, points, masks in train_datagen.flow(batchSize):
            Gen_image = Gen_model.predict([inputs, masks])
            real = np.ones((batchSize, 1))
            unreal = np.zeros((batchSize, 1))

            generatorLoss = 0.0
            discriminatorLoss = 0.0

            if n < l1:
                # Phase 1: generator-only MSE reconstruction pretraining.
                generatorLoss = Gen_model.train_on_batch([inputs, masks],
                                                         inputs)
            else:
                # Phase 2+: discriminator on real vs completed images.
                discriminatorLoss_real = Dis_model.train_on_batch(
                    [inputs, points], real)
                discriminatorLoss_unreal = Dis_model.train_on_batch(
                    [Gen_image, points], unreal)
                discriminatorLoss = 0.5 * np.add(discriminatorLoss_real,
                                                 discriminatorLoss_unreal)
                if n >= l1 + l2:
                    # Phase 3: joint update — reconstruction + adversarial.
                    generatorLoss = totalModel.train_on_batch(
                        [inputs, masks, points], [inputs, real])
                    generatorLoss = generatorLoss[0] + alpha * generatorLoss[1]
            progress.add(inputs.shape[0])
        # Plot input/output/original triplets from the last batch of the epoch.
        imgs = min(5, batchSize)
        Display, Axis = mplt.subplots(imgs, 3)

        Axis[0, 0].set_title('Input Image')
        Axis[0, 1].set_title('Output Image')
        Axis[0, 2].set_title('Original Image')

        for i in range(imgs):
            Axis[i, 0].imshow(inputs[i] * (1 - masks[i]))
            Axis[i, 0].axis('off')
            Axis[i, 1].imshow(Gen_image[i])
            Axis[i, 1].axis('off')
            Axis[i, 2].imshow(inputs[i])
            Axis[i, 2].axis('off')

        Display.savefig(os.path.join(result_dir, "Batch_%d.png" % n))
        mplt.close()

    # Save the trained model files.
    Gen.save(os.path.join(result_dir, "generator.h5"))
    Dis.save(os.path.join(result_dir, "discriminator.h5"))
Пример #6
0
    def __init__(self,
                 dataset_path: str,
                 num_of_upscales: int,
                 gen_mod_name: str,
                 disc_mod_name: str,
                 training_progress_save_path: str,
                 dataset_augmentation_settings: Union[AugmentationSettings,
                                                      None] = None,
                 generator_optimizer: Optimizer = Adam(0.0001, 0.9),
                 discriminator_optimizer: Optimizer = Adam(0.0001, 0.9),
                 gen_loss="mae",
                 disc_loss="binary_crossentropy",
                 feature_loss="mae",
                 gen_loss_weight: float = 1.0,
                 disc_loss_weight: float = 0.003,
                 feature_loss_weights: Union[list, float, None] = None,
                 feature_extractor_layers: Union[list, None] = None,
                 generator_lr_decay_interval: Union[int, None] = None,
                 discriminator_lr_decay_interval: Union[int, None] = None,
                 generator_lr_decay_factor: Union[float, None] = None,
                 discriminator_lr_decay_factor: Union[float, None] = None,
                 generator_min_lr: Union[float, None] = None,
                 discriminator_min_lr: Union[float, None] = None,
                 discriminator_label_noise: Union[float, None] = None,
                 discriminator_label_noise_decay: Union[float, None] = None,
                 discriminator_label_noise_min: Union[float, None] = 0.001,
                 batch_size: int = 4,
                 buffered_batches: int = 20,
                 generator_weights: Union[str, None] = None,
                 discriminator_weights: Union[str, None] = None,
                 load_from_checkpoint: bool = False,
                 custom_hr_test_images_paths: Union[list, None] = None,
                 check_dataset: bool = True,
                 num_of_loading_workers: int = 8):
        """Initialize the super-resolution GAN trainer.

        Loads and validates the dataset, builds the discriminator,
        generator, VGG-based feature extractor and the combined
        (generator + frozen discriminator + feature losses) model,
        then optionally restores weights from a checkpoint and/or
        explicit weight files, and finally applies the LR schedulers.

        Raises if the dataset images are smaller than 4x4 or if the
        generator's output shape does not match the target image shape.
        """

        # Save params to inner variables
        self.__disc_mod_name = disc_mod_name
        self.__gen_mod_name = gen_mod_name
        self.__num_of_upscales = num_of_upscales
        assert self.__num_of_upscales >= 0, Fore.RED + "Invalid number of upscales" + Fore.RESET

        # Label-noise settings for discriminator training; a None minimum
        # is normalized to 0 so later comparisons are numeric.
        self.__discriminator_label_noise = discriminator_label_noise
        self.__discriminator_label_noise_decay = discriminator_label_noise_decay
        self.__discriminator_label_noise_min = discriminator_label_noise_min
        if self.__discriminator_label_noise_min is None:
            self.__discriminator_label_noise_min = 0

        self.__batch_size = batch_size
        assert self.__batch_size > 0, Fore.RED + "Invalid batch size" + Fore.RESET

        self.__episode_counter = 0

        # Insert empty lists if feature extractor settings are empty
        if feature_extractor_layers is None:
            feature_extractor_layers = []

        if feature_loss_weights is None:
            feature_loss_weights = []

        # If feature_loss_weights is float then create list of the weights from it
        # (the total weight is split evenly across the extractor layers).
        if isinstance(feature_loss_weights,
                      float) and len(feature_extractor_layers) > 0:
            feature_loss_weights = [
                feature_loss_weights / len(feature_extractor_layers)
            ] * len(feature_extractor_layers)

        assert len(feature_extractor_layers) == len(
            feature_loss_weights
        ), Fore.RED + "Number of extractor layers and feature loss weights must match!" + Fore.RESET

        # Create array of input image paths
        self.__train_data = get_paths_of_files_from_path(dataset_path,
                                                         only_files=True)
        assert self.__train_data, Fore.RED + "Training dataset is not loaded" + Fore.RESET

        # Load one image to get shape of it
        self.__target_image_shape = cv.imread(self.__train_data[0]).shape

        # Check image size validity
        if self.__target_image_shape[0] < 4 or self.__target_image_shape[1] < 4:
            raise Exception("Images too small, min size (4, 4)")

        # Starting image size calculate (low-res input shape derived by
        # downscaling the target shape num_of_upscales times)
        self.__start_image_shape = count_upscaling_start_size(
            self.__target_image_shape, self.__num_of_upscales)

        # Check validity of whole datasets
        if check_dataset:
            self.__validate_dataset()

        # Initialize training data folder and logging
        self.__training_progress_save_path = training_progress_save_path
        self.__training_progress_save_path = os.path.join(
            self.__training_progress_save_path,
            f"{self.__gen_mod_name}__{self.__disc_mod_name}__{self.__start_image_shape}_to_{self.__target_image_shape}"
        )
        self.__tensorboard = TensorBoardCustom(
            log_dir=os.path.join(self.__training_progress_save_path, "logs"))
        self.__stat_logger = StatLogger(self.__tensorboard)

        # Define static vars
        self.kernel_initializer = RandomNormal(stddev=0.02)
        self.__custom_loading_failed = False
        self.__custom_test_images = True if custom_hr_test_images_paths else False
        # Progress images: use the custom paths when given, replacing any
        # missing file with a random training image; otherwise pick one
        # random training image.
        if custom_hr_test_images_paths:
            self.__progress_test_images_paths = custom_hr_test_images_paths
            for idx, image_path in enumerate(
                    self.__progress_test_images_paths):
                if not os.path.exists(image_path):
                    self.__custom_loading_failed = True
                    self.__progress_test_images_paths[idx] = random.choice(
                        self.__train_data)
        else:
            self.__progress_test_images_paths = [
                random.choice(self.__train_data)
            ]

        # Create batchmaker and start it
        self.__batch_maker = BatchMaker(
            self.__train_data,
            self.__batch_size,
            buffered_batches=buffered_batches,
            secondary_size=self.__start_image_shape,
            num_of_loading_workers=num_of_loading_workers,
            augmentation_settings=dataset_augmentation_settings)

        # Create LR Schedulers for both "Optimizer"
        self.__gen_lr_scheduler = LearningRateScheduler(
            start_lr=float(K.get_value(generator_optimizer.lr)),
            lr_decay_factor=generator_lr_decay_factor,
            lr_decay_interval=generator_lr_decay_interval,
            min_lr=generator_min_lr)
        self.__disc_lr_scheduler = LearningRateScheduler(
            start_lr=float(K.get_value(discriminator_optimizer.lr)),
            lr_decay_factor=discriminator_lr_decay_factor,
            lr_decay_interval=discriminator_lr_decay_interval,
            min_lr=discriminator_min_lr)

        #####################################
        ###      Create discriminator     ###
        #####################################
        self.__discriminator = self.__build_discriminator(disc_mod_name)
        self.__discriminator.compile(loss=disc_loss,
                                     optimizer=discriminator_optimizer)

        #####################################
        ###       Create generator        ###
        #####################################
        self.__generator = self.__build_generator(gen_mod_name)
        if self.__generator.output_shape[1:] != self.__target_image_shape:
            raise Exception(
                f"Invalid image input size for this generator model\nGenerator shape: {self.__generator.output_shape[1:]}, Target shape: {self.__target_image_shape}"
            )
        self.__generator.compile(loss=gen_loss,
                                 optimizer=generator_optimizer,
                                 metrics=[PSNR_Y, PSNR, SSIM])

        #####################################
        ###      Create vgg network       ###
        #####################################
        self.__vgg = create_feature_extractor(self.__target_image_shape,
                                              feature_extractor_layers)

        #####################################
        ###   Create combined generator   ###
        #####################################
        small_image_input_generator = Input(shape=self.__start_image_shape,
                                            name="small_image_input")

        # Images upscaled by generator
        gen_images = self.__generator(small_image_input_generator)

        # Discriminator takes images and determinates validity
        # (weight-sharing, non-trainable view so only the generator is
        # updated when training the combined model)
        frozen_discriminator = Network(self.__discriminator.inputs,
                                       self.__discriminator.outputs,
                                       name="frozen_discriminator")
        frozen_discriminator.trainable = False

        validity = frozen_discriminator(gen_images)

        # Extracts features from generated images
        generated_features = self.__vgg(preprocess_vgg(gen_images))

        # Combine models
        # Train generator to fool discriminator
        self.__combined_generator_model = Model(
            inputs=small_image_input_generator,
            outputs=[gen_images, validity] + [*generated_features],
            name="srgan")
        self.__combined_generator_model.compile(
            loss=[gen_loss, disc_loss] +
            ([feature_loss] * len(generated_features)),
            loss_weights=[gen_loss_weight, disc_loss_weight] +
            feature_loss_weights,
            optimizer=generator_optimizer,
            metrics={"generator": [PSNR_Y, PSNR, SSIM]})

        # Print all summaries
        print("\nDiscriminator Summary:")
        self.__discriminator.summary()
        print("\nGenerator Summary:")
        self.__generator.summary()

        # Load checkpoint
        self.__initiated = False
        if load_from_checkpoint: self.__load_checkpoint()

        # Load weights from param and override checkpoint weights
        if generator_weights: self.__generator.load_weights(generator_weights)
        if discriminator_weights:
            self.__discriminator.load_weights(discriminator_weights)

        # Set LR (applied last so it wins over any LR restored above)
        self.__gen_lr_scheduler.set_lr(self.__combined_generator_model,
                                       self.__episode_counter)
        self.__disc_lr_scheduler.set_lr(self.__discriminator,
                                        self.__episode_counter)
Пример #7
0
    def __init__(self,
                 dataset_path: str,
                 gen_mod_name: str,
                 disc_mod_name: str,
                 latent_dim: int,
                 training_progress_save_path: str,
                 testing_dataset_path: str = None,
                 generator_optimizer: Optimizer = Adam(0.0002, 0.5),
                 discriminator_optimizer: Optimizer = Adam(0.0002, 0.5),
                 discriminator_label_noise: float = None,
                 discriminator_label_noise_decay: float = None,
                 discriminator_label_noise_min: float = 0.001,
                 batch_size: int = 32,
                 buffered_batches: int = 20,
                 generator_weights: Union[str, None] = None,
                 discriminator_weights: Union[str, None] = None,
                 start_episode: int = 0,
                 load_from_checkpoint: bool = False,
                 check_dataset: bool = True,
                 num_of_loading_workers: int = 8):

        # --- Identifiers for the generator/discriminator model variants and
        # --- the optimizer used for the generator (and combined) model.
        self.disc_mod_name = disc_mod_name
        self.gen_mod_name = gen_mod_name
        self.generator_optimizer = generator_optimizer

        # Size of the generator's input noise vector.
        # NOTE(review): `assert` is stripped under `python -O`; raising
        # ValueError would be safer for input validation (same below).
        self.latent_dim = latent_dim
        assert self.latent_dim > 0, Fore.RED + "Invalid latent dim" + Fore.RESET

        self.batch_size = batch_size
        assert self.batch_size > 0, Fore.RED + "Invalid batch size" + Fore.RESET

        # Label-noise settings — presumably applied to discriminator targets
        # during training and decayed toward the minimum; the actual use is
        # not visible in this chunk, confirm in the training loop.
        self.discriminator_label_noise = discriminator_label_noise
        self.discriminator_label_noise_decay = discriminator_label_noise_decay
        self.discriminator_label_noise_min = discriminator_label_noise_min

        # Grid dimensions (cols, rows) of the periodic progress image; the
        # static-noise array below holds one latent vector per grid cell.
        self.progress_image_dim = (16, 9)

        # Clamp negative start episodes to zero before seeding the counter.
        if start_episode < 0: start_episode = 0
        self.episode_counter = start_episode

        # Initialize training data folder and logging.
        # All progress artifacts go under
        # <training_progress_save_path>/<gen_mod_name>__<disc_mod_name>/.
        self.training_progress_save_path = training_progress_save_path
        self.training_progress_save_path = os.path.join(
            self.training_progress_save_path,
            f"{self.gen_mod_name}__{self.disc_mod_name}")
        self.tensorboard = TensorBoardCustom(
            log_dir=os.path.join(self.training_progress_save_path, "logs"))

        # Create array of input image paths.
        self.train_data = get_paths_of_files_from_path(dataset_path,
                                                       only_files=True)
        assert self.train_data, Fore.RED + "Training dataset is not loaded" + Fore.RESET

        # Optional held-out dataset used for evaluation.
        # NOTE(review): unlike the training data above, `only_files=True` is
        # not passed here — confirm directories cannot leak into this list.
        self.testing_data = None
        if testing_dataset_path:
            self.testing_data = get_paths_of_files_from_path(
                testing_dataset_path)
            assert self.testing_data, Fore.RED + "Testing dataset is not loaded" + Fore.RESET

        # Load one image to get shape of it.
        # NOTE(review): cv.imread returns None on an unreadable file, which
        # would make `.shape` raise AttributeError — consider checking.
        tmp_image = cv.imread(self.train_data[0])
        self.image_shape = tmp_image.shape
        self.image_channels = self.image_shape[2]

        # Check image size validity (height/width must be at least 4 px).
        if self.image_shape[0] < 4 or self.image_shape[1] < 4:
            raise Exception("Images too small, min size (4, 4)")

        # Optionally check validity of the whole dataset up front.
        if check_dataset:
            self.validate_dataset()

        # Define static vars: a fixed batch of latent vectors reused for the
        # progress images so successive snapshots are comparable. Reload it
        # from disk when present; if its row count no longer matches the
        # progress grid (cols * rows), discard the file and regenerate.
        # NOTE(review): the regenerated noise is not re-saved here —
        # presumably persisted elsewhere; confirm.
        if os.path.exists(
                f"{self.training_progress_save_path}/static_noise.npy"):
            self.static_noise = np.load(
                f"{self.training_progress_save_path}/static_noise.npy")
            if self.static_noise.shape[0] != (self.progress_image_dim[0] *
                                              self.progress_image_dim[1]):
                print(Fore.YELLOW +
                      "Progress image dim changed, restarting static noise!" +
                      Fore.RESET)
                os.remove(
                    f"{self.training_progress_save_path}/static_noise.npy")
                self.static_noise = np.random.normal(
                    0.0,
                    1.0,
                    size=(self.progress_image_dim[0] *
                          self.progress_image_dim[1], self.latent_dim))
        else:
            self.static_noise = np.random.normal(
                0.0,
                1.0,
                size=(self.progress_image_dim[0] * self.progress_image_dim[1],
                      self.latent_dim))
        # Weight initializer shared by the model builders (stddev 0.02).
        self.kernel_initializer = RandomNormal(stddev=0.02)

        # Load checkpoint: recover saved weight paths (applied further down,
        # after the models are built).
        self.initiated = False
        loaded_gen_weights_path = None
        loaded_disc_weights_path = None
        if load_from_checkpoint:
            loaded_gen_weights_path, loaded_disc_weights_path = self.load_checkpoint(
            )

        # Create the batch maker for the training data.
        # NOTE(review): unlike the testing batch maker below, `.start()` is
        # not called here — confirm it is started before training begins.
        self.batch_maker = BatchMaker(
            self.train_data,
            self.batch_size,
            buffered_batches=buffered_batches,
            num_of_loading_workers=num_of_loading_workers)

        # Separate batch maker for the testing data, started immediately.
        self.testing_batchmaker = None
        if self.testing_data:
            self.testing_batchmaker = BatchMaker(
                self.testing_data,
                self.batch_size,
                buffered_batches=buffered_batches,
                num_of_loading_workers=num_of_loading_workers)
            self.testing_batchmaker.start()

        #################################
        ###   Create discriminator    ###
        #################################
        self.discriminator = self.build_discriminator(disc_mod_name)
        self.discriminator.compile(loss="binary_crossentropy",
                                   optimizer=discriminator_optimizer)

        #################################
        ###     Create generator      ###
        #################################
        # The generator's output shape must match the dataset image shape.
        self.generator = self.build_generator(gen_mod_name)
        if self.generator.output_shape[1:] != self.image_shape:
            raise Exception(
                "Invalid image input size for this generator model")

        #################################
        ### Create combined generator ###
        #################################
        # Stacked model: noise -> generator -> (frozen) discriminator.
        noise_input = Input(shape=(self.latent_dim, ), name="noise_input")
        gen_images = self.generator(noise_input)

        # Create frozen version of discriminator — a Network wrapper over the
        # same inputs/outputs with trainable=False, so only the generator's
        # weights are updated when training the combined model.
        frozen_discriminator = Network(self.discriminator.inputs,
                                       self.discriminator.outputs,
                                       name="frozen_discriminator")
        frozen_discriminator.trainable = False

        # Discriminator takes generated images and outputs validity scores.
        valid = frozen_discriminator(gen_images)

        # Combine models:
        # train the generator to fool the (frozen) discriminator.
        self.combined_generator_model = Model(noise_input,
                                              valid,
                                              name="dcgan_model")
        self.combined_generator_model.compile(
            loss="binary_crossentropy", optimizer=self.generator_optimizer)

        # Print all summaries.
        print("\nDiscriminator Summary:")
        self.discriminator.summary()
        print("\nGenerator Summary:")
        self.generator.summary()
        print("\nGAN Summary")
        self.combined_generator_model.summary()

        # Load weights from checkpoint; failures are tolerated (best-effort).
        # NOTE(review): bare `except:` also swallows KeyboardInterrupt and
        # SystemExit — `except Exception:` would be safer (same below).
        try:
            if loaded_gen_weights_path:
                self.generator.load_weights(loaded_gen_weights_path)
        except:
            print(Fore.YELLOW +
                  "Failed to load generator weights from checkpoint" +
                  Fore.RESET)

        try:
            if loaded_disc_weights_path:
                self.discriminator.load_weights(loaded_disc_weights_path)
        except:
            print(Fore.YELLOW +
                  "Failed to load discriminator weights from checkpoint" +
                  Fore.RESET)

        # Explicit weight paths passed as parameters take precedence over
        # (override) any weights restored from the checkpoint above.
        # Note: these are NOT wrapped in try/except, so a bad path raises.
        if generator_weights: self.generator.load_weights(generator_weights)
        if discriminator_weights:
            self.discriminator.load_weights(discriminator_weights)