def create_model(optimizer, target_size=(224, 224, 3)):
    """
    Build and return the model.
    Various values are hard-coded here, which is messy.
    Uses a model defined in keras.applications as the shared feature extractor.
    The extracted features are concatenated and passed through fully connected
    layers to compute the final prediction.
    """
    base_model = InceptionV3(include_top=False, weights='imagenet')
    shared_layers = Network(base_model.inputs, base_model.output, name='shared_layers')
    shared_layers.trainable = False

    input_x1 = Input(shape=target_size, name='input_x1')
    x1 = shared_layers(input_x1)
    x1 = GlobalAveragePooling2D(name='x1_gap')(x1)
    x1 = BatchNormalization()(x1)

    input_x2 = Input(shape=target_size, name='input_x2')
    x2 = shared_layers(input_x2)
    x2 = GlobalAveragePooling2D(name='x2_gap')(x2)
    x2 = BatchNormalization()(x2)

    input_x3 = Input(shape=target_size, name='input_x3')
    x3 = shared_layers(input_x3)
    x3 = GlobalAveragePooling2D(name='x3_gap')(x3)
    x3 = BatchNormalization()(x3)

    x = Concatenate(name='concat_triple_image')([x1, x2, x3])
    x = Dense(1024, activation='relu')(x)
    # x = Dense(256, activation='relu', kernel_regularizer=regularizers.l2(0.001))(x)
    x = Dropout(0.5)(x)
    predictions = Dense(2, activation='softmax')(x)

    model = Model([input_x1, input_x2, input_x3], predictions)
    model.compile(optimizer=optimizer,
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])
    return model
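# Minimal usage sketch for create_model above. This is illustrative only: it assumes
# the Keras layer/model imports used by create_model are already in scope, and the
# Adam settings and dummy data are placeholders, not part of the original code.
import numpy as np
from keras.optimizers import Adam

model = create_model(Adam(1e-4), target_size=(224, 224, 3))

# Three aligned image batches, one per input branch.
x1 = np.random.rand(2, 224, 224, 3).astype('float32')
x2 = np.random.rand(2, 224, 224, 3).astype('float32')
x3 = np.random.rand(2, 224, 224, 3).astype('float32')

probs = model.predict([x1, x2, x3])  # shape (2, 2): softmax over the two classes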
def build_discriminators(self, filters=64):
    """
    Build the discriminator network according to the description in the paper.

    :param int filters: How many filters to use in the first conv layer
    :return: the discriminator model and a frozen copy that shares its weights
    """

    def conv2d_block(input, filters, strides=1, bn=True):
        d = Conv2D(filters, kernel_size=3, strides=strides, padding='same')(input)
        d = LeakyReLU(alpha=0.2)(d)
        if bn:
            d = BatchNormalization(momentum=0.8)(d)
        return d

    # Input high resolution image
    img = Input(shape=self.shape_hr)
    x = conv2d_block(img, filters, bn=False)
    x = conv2d_block(x, filters, strides=2)
    x = conv2d_block(x, filters * 2)
    x = conv2d_block(x, filters * 2, strides=2)
    x = conv2d_block(x, filters * 4)
    x = conv2d_block(x, filters * 4, strides=2)
    x = conv2d_block(x, filters * 8)
    x = conv2d_block(x, filters * 8, strides=2)
    x = Dense(filters * 16)(x)
    x = LeakyReLU(alpha=0.2)(x)
    x = Dense(1, activation='sigmoid')(x)

    # Create model
    model = Model(inputs=img, outputs=x)

    # Build "frozen discriminator": same weights, but never trainable
    frozen_discriminator = Network(inputs=img, outputs=x, name='frozen_discriminator')
    frozen_discriminator.trainable = False

    return model, frozen_discriminator
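# Self-contained sketch of the frozen-discriminator pattern the snippet above relies on,
# assuming old-style Keras where Network lives at keras.engine.network. The tiny toy
# generator/discriminator, layer sizes, and optimizer settings are illustrative
# placeholders, not the original models.
from keras.engine.network import Network
from keras.layers import Input, Dense, Flatten, Conv2D, UpSampling2D
from keras.models import Model
from keras.optimizers import Adam

# Toy discriminator, compiled for its own training step
d_in = Input(shape=(32, 32, 3))
d = Conv2D(8, 3, padding='same', activation='relu')(d_in)
d_out = Dense(1, activation='sigmoid')(Flatten()(d))
discriminator = Model(d_in, d_out)
discriminator.compile(optimizer=Adam(0.0002, 0.5), loss='binary_crossentropy')

# Frozen copy sharing the same weights, used only inside the generator-training graph
frozen_discriminator = Network(d_in, d_out, name='frozen_discriminator')
frozen_discriminator.trainable = False

# Toy generator (upscales 16x16 -> 32x32)
g_in = Input(shape=(16, 16, 3))
g_out = Conv2D(3, 3, padding='same', activation='sigmoid')(UpSampling2D()(g_in))
generator = Model(g_in, g_out)

# Combined model: gradients from the adversarial loss update only the generator,
# while `discriminator` stays trainable when trained directly on real/fake batches.
combined = Model(g_in, frozen_discriminator(g_out))
combined.compile(optimizer=Adam(0.0002, 0.5), loss='binary_crossentropy')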
def main(overwrite=False):
    # convert input images into an hdf5 file
    if overwrite or not os.path.exists(config["data_file"]):
        create_data_file(config)

    data_file_opened = open_data_file(config["data_file"])

    seg_loss_func = getattr(fetal_net.metrics, config['loss'])
    dis_loss_func = getattr(fetal_net.metrics, config['dis_loss'])

    # instantiate new model
    seg_model_func = getattr(fetal_net.model, config['model_name'])
    gen_model = seg_model_func(
        input_shape=config["input_shape"],
        initial_learning_rate=config["initial_learning_rate"],
        **{
            'dropout_rate': config['dropout_rate'],
            'loss_function': seg_loss_func,
            'mask_shape': None if config["weight_mask"] is None else config["input_shape"],
            'old_model_path': config['old_model']
        })

    dis_model_func = getattr(fetal_net.model, config['dis_model_name'])
    dis_model = dis_model_func(
        input_shape=[config["input_shape"][0] + config["n_labels"]] + config["input_shape"][1:],
        initial_learning_rate=config["initial_learning_rate"],
        **{
            'dropout_rate': config['dropout_rate'],
            'loss_function': dis_loss_func
        })

    if not overwrite \
            and len(glob.glob(config["model_file"] + 'g_*.h5')) > 0:
        # dis_model_path = get_last_model_path(config["model_file"] + 'dis_')
        gen_model_path = get_last_model_path(config["model_file"] + 'g_')
        # print('Loading dis model from: {}'.format(dis_model_path))
        print('Loading gen model from: {}'.format(gen_model_path))
        # dis_model = load_old_model(dis_model_path)
        # gen_model = load_old_model(gen_model_path)
        # dis_model.load_weights(dis_model_path)
        gen_model.load_weights(gen_model_path)

    gen_model.summary()
    dis_model.summary()

    # Build "frozen discriminator"
    frozen_dis_model = Network(dis_model.inputs,
                               dis_model.outputs,
                               name='frozen_discriminator')
    frozen_dis_model.trainable = False

    inputs_real = Input(shape=config["input_shape"])
    inputs_fake = Input(shape=config["input_shape"])
    segs_real = Activation(None, name='seg_real')(gen_model(inputs_real))
    segs_fake = Activation(None, name='seg_fake')(gen_model(inputs_fake))
    valid = Activation(None, name='dis')(frozen_dis_model(
        Concatenate(axis=1)([segs_fake, inputs_fake])))
    combined_model = Model(inputs=[inputs_real, inputs_fake],
                           outputs=[segs_real, valid])
    combined_model.compile(loss=[seg_loss_func, 'binary_crossentropy'],
                           loss_weights=[1, config["gd_loss_ratio"]],
                           optimizer=Adam(config["initial_learning_rate"]))
    combined_model.summary()

    # get training and testing generators
    train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
        data_file_opened,
        batch_size=config["batch_size"],
        data_split=config["validation_split"],
        overwrite=overwrite,
        validation_keys_file=config["validation_file"],
        training_keys_file=config["training_file"],
        test_keys_file=config["test_file"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        patch_shape=(*config["patch_shape"], config["patch_depth"]),
        validation_batch_size=config["validation_batch_size"],
        augment=config["augment"],
        skip_blank_train=config["skip_blank_train"],
        skip_blank_val=config["skip_blank_val"],
        truth_index=config["truth_index"],
        truth_size=config["truth_size"],
        prev_truth_index=config["prev_truth_index"],
        prev_truth_size=config["prev_truth_size"],
        truth_downsample=config["truth_downsample"],
        truth_crop=config["truth_crop"],
        patches_per_epoch=config["patches_per_epoch"],
        categorical=config["categorical"],
        is3d=config["3D"],
        drop_easy_patches_train=config["drop_easy_patches_train"],
        drop_easy_patches_val=config["drop_easy_patches_val"])

    # get additional generator for the semi-supervised (unlabeled) patches
    _, semi_generator, _, _ = get_training_and_validation_generators(
        data_file_opened,
        batch_size=config["batch_size"],
        data_split=config["validation_split"],
        overwrite=overwrite,
        validation_keys_file=config["validation_file"],
        training_keys_file=config["training_file"],
        test_keys_file=config["test_file"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        patch_shape=(*config["patch_shape"], config["patch_depth"]),
        validation_batch_size=config["validation_batch_size"],
        val_augment=config["augment"],
        skip_blank_train=config["skip_blank_train"],
        skip_blank_val=config["skip_blank_val"],
        truth_index=config["truth_index"],
        truth_size=config["truth_size"],
        prev_truth_index=config["prev_truth_index"],
        prev_truth_size=config["prev_truth_size"],
        truth_downsample=config["truth_downsample"],
        truth_crop=config["truth_crop"],
        patches_per_epoch=config["patches_per_epoch"],
        categorical=config["categorical"],
        is3d=config["3D"],
        drop_easy_patches_train=config["drop_easy_patches_train"],
        drop_easy_patches_val=config["drop_easy_patches_val"])

    # start training
    scheduler = Scheduler(config["dis_steps"], config["gen_steps"],
                          init_lr=config["initial_learning_rate"],
                          lr_patience=config["patience"],
                          lr_decay=config["learning_rate_drop"])

    best_loss = np.inf
    for epoch in range(config["n_epochs"]):
        postfix = {'g': None, 'd': None}  # , 'val_g': None, 'val_d': None}
        with tqdm(range(n_train_steps // config["gen_steps"]),
                  dynamic_ncols=True,
                  postfix={'gen': None, 'dis': None,
                           'val_gen': None, 'val_dis': None, None: None}) as pbar:
            for n_round in pbar:
                # train D
                outputs = np.zeros(dis_model.metrics_names.__len__())
                for i in range(scheduler.get_dsteps()):
                    real_patches, real_segs = next(train_generator)
                    semi_patches, _ = next(semi_generator)
                    d_x_batch, d_y_batch = input2discriminator(
                        real_patches, real_segs,
                        semi_patches,
                        gen_model.predict(semi_patches, batch_size=config["batch_size"]),
                        dis_model.output_shape)
                    outputs += dis_model.train_on_batch(d_x_batch, d_y_batch)
                if scheduler.get_dsteps():
                    outputs /= scheduler.get_dsteps()
                    postfix['d'] = build_dsc(dis_model.metrics_names, outputs)
                    pbar.set_postfix(**postfix)

                # train G (freeze discriminator)
                outputs = np.zeros(combined_model.metrics_names.__len__())
                for i in range(scheduler.get_gsteps()):
                    real_patches, real_segs = next(train_generator)
                    semi_patches, _ = next(validation_generator)
                    g_x_batch, g_y_batch = input2gan(real_patches, real_segs,
                                                     semi_patches, dis_model.output_shape)
                    outputs += combined_model.train_on_batch(g_x_batch, g_y_batch)
                outputs /= scheduler.get_gsteps()
                postfix['g'] = build_dsc(combined_model.metrics_names, outputs)
                pbar.set_postfix(**postfix)

            # evaluate on validation set
            dis_metrics = np.zeros(dis_model.metrics_names.__len__(), dtype=float)
            gen_metrics = np.zeros(gen_model.metrics_names.__len__(), dtype=float)
            evaluation_rounds = n_validation_steps
            for n_round in range(evaluation_rounds):  # rounds_for_evaluation:
                val_patches, val_segs = next(validation_generator)

                # D
                if scheduler.get_dsteps() > 0:
                    d_x_test, d_y_test = input2discriminator(
                        val_patches, val_segs,
                        val_patches,
                        gen_model.predict(val_patches,
                                          batch_size=config["validation_batch_size"]),
                        dis_model.output_shape)
                    dis_metrics += dis_model.evaluate(
                        d_x_test, d_y_test,
                        batch_size=config["validation_batch_size"],
                        verbose=0)

                # G
                # gen_x_test, gen_y_test = input2gan(val_patches, val_segs, dis_model.output_shape)
                gen_metrics += gen_model.evaluate(
                    val_patches, val_segs,
                    batch_size=config["validation_batch_size"],
                    verbose=0)

            dis_metrics /= float(evaluation_rounds)
            gen_metrics /= float(evaluation_rounds)

            # save the model and weights with the best validation loss
            if gen_metrics[0] < best_loss:
                best_loss = gen_metrics[0]
                print('Saving Model...')
                with open(os.path.join(config["base_dir"],
                                       "g_{}_{:.3f}.json".format(epoch, gen_metrics[0])),
                          'w') as f:
                    f.write(gen_model.to_json())
                gen_model.save_weights(
                    os.path.join(config["base_dir"],
                                 "g_{}_{:.3f}.h5".format(epoch, gen_metrics[0])))

            postfix['val_d'] = build_dsc(dis_model.metrics_names, dis_metrics)
            postfix['val_g'] = build_dsc(gen_model.metrics_names, gen_metrics)
            # pbar.set_postfix(**postfix)
            print('val_d: ' + postfix['val_d'], end=' | ')
            print('val_g: ' + postfix['val_g'])
            # pbar.refresh()

        # update step sizes, learning rates
        scheduler.update_steps(epoch, gen_metrics[0])
        K.set_value(dis_model.optimizer.lr, scheduler.get_lr())
        K.set_value(combined_model.optimizer.lr, scheduler.get_lr())

    data_file_opened.close()
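# The training loop above only uses Scheduler through get_dsteps(), get_gsteps(),
# get_lr() and update_steps(epoch, loss); the class itself is not shown in this
# snippet. Below is a hypothetical minimal sketch of that interface, assuming a
# plateau-style learning-rate decay. The internals are illustrative, not the
# original implementation.
class Scheduler:
    def __init__(self, dis_steps, gen_steps, init_lr, lr_patience, lr_decay):
        self.dis_steps = dis_steps
        self.gen_steps = gen_steps
        self.lr = init_lr
        self.lr_patience = lr_patience
        self.lr_decay = lr_decay
        self.best_loss = float('inf')
        self.bad_epochs = 0

    def get_dsteps(self):
        return self.dis_steps

    def get_gsteps(self):
        return self.gen_steps

    def get_lr(self):
        return self.lr

    def update_steps(self, epoch, loss):
        # Plateau-style decay: shrink the LR when the generator loss stops improving.
        if loss < self.best_loss:
            self.best_loss = loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
            if self.bad_epochs >= self.lr_patience:
                self.lr *= self.lr_decay
                self.bad_epochs = 0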
discriminator_input, discriminator_output = build_discriminator(image_shape, dropout_rate)
discriminator = Model(discriminator_input, discriminator_output, name='discriminator')
discriminator.compile(optimizer, loss='binary_crossentropy', metrics=['accuracy'])
assert (len(discriminator._collected_trainable_weights) > 0)

frozen_discriminator = Network(discriminator_input,
                               discriminator_output,
                               name='frozen_discriminator')
frozen_discriminator.trainable = False

generator_input, generator_output = build_generator()
generator = Model(generator_input, generator_output, name='generator')

adversarial_model = Model(generator_input,
                          frozen_discriminator(generator_output),
                          name='adversarial_model')
adversarial_model.compile(optimizer, loss='binary_crossentropy')
assert (len(adversarial_model._collected_trainable_weights) ==
        len(generator.trainable_weights))

batch_size = 64

config = configparser.ConfigParser()
def GAN_Main(result_dir="output", data_dir="data"):
    alpha = 0.0004
    input_shape = (256, 256, 3)
    local_shape = (128, 128, 3)
    batchSize = 4
    Epochs = 150
    l1 = int(Epochs * 0.18)
    l2 = int(Epochs * 0.02)

    train_datagen = Data(data_dir, input_shape[:2], local_shape[:2])

    Gen = Generative(input_shape)
    Dis = Discriminative(input_shape, local_shape)
    optimizer = Adadelta()

    #######
    orgVal = Input(shape=input_shape)
    mask = Input(shape=(input_shape[0], input_shape[1], 1))

    imgContent = Lambda(lambda x: x[0] * (1 - x[1]),
                        output_shape=input_shape)([orgVal, mask])
    mimic = Gen(imgContent)
    completion = Lambda(lambda x: x[0] * x[2] + x[1] * (1 - x[2]),
                        output_shape=input_shape)([mimic, orgVal, mask])

    Gen_container = Network([orgVal, mask], completion)
    Gen_out = Gen_container([orgVal, mask])
    Gen_model = Model([orgVal, mask], Gen_out)
    Gen_model.compile(loss='mse', optimizer=optimizer)

    inputLayer = Input(shape=(4,), dtype='int32')
    Dis_container = Network([orgVal, inputLayer], Dis([orgVal, inputLayer]))
    Dis_model = Model([orgVal, inputLayer], Dis_container([orgVal, inputLayer]))
    Dis_model.compile(loss='binary_crossentropy', optimizer=optimizer)

    Dis_container.trainable = False
    totalModel = Model([orgVal, mask, inputLayer],
                       [Gen_out, Dis_container([Gen_out, inputLayer])])
    totalModel.compile(loss=['mse', 'binary_crossentropy'],
                       loss_weights=[1.0, alpha],
                       optimizer=optimizer)

    for n in range(Epochs):
        progress = generic_utils.Progbar(len(train_datagen))
        for inputs, points, masks in train_datagen.flow(batchSize):
            Gen_image = Gen_model.predict([inputs, masks])
            real = np.ones((batchSize, 1))
            unreal = np.zeros((batchSize, 1))

            generatorLoss = 0.0
            discriminatorLoss = 0.0
            if n < l1:
                # Phase 1: train the completion (generator) network alone with MSE
                generatorLoss = Gen_model.train_on_batch([inputs, masks], inputs)
            else:
                # Phase 2: train the discriminator on real and completed images
                discriminatorLoss_real = Dis_model.train_on_batch([inputs, points], real)
                discriminatorLoss_unreal = Dis_model.train_on_batch([Gen_image, points], unreal)
                discriminatorLoss = 0.5 * np.add(discriminatorLoss_real,
                                                 discriminatorLoss_unreal)
                if n >= l1 + l2:
                    # Phase 3: train the generator adversarially through the frozen discriminator
                    generatorLoss = totalModel.train_on_batch([inputs, masks, points],
                                                              [inputs, real])
                    generatorLoss = generatorLoss[0] + alpha * generatorLoss[1]
            progress.add(inputs.shape[0])

        imgs = min(5, batchSize)
        Display, Axis = mplt.subplots(imgs, 3)
        Axis[0, 0].set_title('Input Image')
        Axis[0, 1].set_title('Output Image')
        Axis[0, 2].set_title('Original Image')
        for i in range(imgs):
            Axis[i, 0].imshow(inputs[i] * (1 - masks[i]))
            Axis[i, 0].axis('off')
            Axis[i, 1].imshow(Gen_image[i])
            Axis[i, 1].axis('off')
            Axis[i, 2].imshow(inputs[i])
            Axis[i, 2].axis('off')
        Display.savefig(os.path.join(result_dir, "Batch_%d.png" % n))
        mplt.close()

    # Trained model files...
    Gen.save(os.path.join(result_dir, "generator.h5"))
    Dis.save(os.path.join(result_dir, "discriminator.h5"))
def __init__(self,
             dataset_path: str,
             num_of_upscales: int,
             gen_mod_name: str,
             disc_mod_name: str,
             training_progress_save_path: str,
             dataset_augmentation_settings: Union[AugmentationSettings, None] = None,
             generator_optimizer: Optimizer = Adam(0.0001, 0.9),
             discriminator_optimizer: Optimizer = Adam(0.0001, 0.9),
             gen_loss="mae",
             disc_loss="binary_crossentropy",
             feature_loss="mae",
             gen_loss_weight: float = 1.0,
             disc_loss_weight: float = 0.003,
             feature_loss_weights: Union[list, float, None] = None,
             feature_extractor_layers: Union[list, None] = None,
             generator_lr_decay_interval: Union[int, None] = None,
             discriminator_lr_decay_interval: Union[int, None] = None,
             generator_lr_decay_factor: Union[float, None] = None,
             discriminator_lr_decay_factor: Union[float, None] = None,
             generator_min_lr: Union[float, None] = None,
             discriminator_min_lr: Union[float, None] = None,
             discriminator_label_noise: Union[float, None] = None,
             discriminator_label_noise_decay: Union[float, None] = None,
             discriminator_label_noise_min: Union[float, None] = 0.001,
             batch_size: int = 4,
             buffered_batches: int = 20,
             generator_weights: Union[str, None] = None,
             discriminator_weights: Union[str, None] = None,
             load_from_checkpoint: bool = False,
             custom_hr_test_images_paths: Union[list, None] = None,
             check_dataset: bool = True,
             num_of_loading_workers: int = 8):

    # Save params to inner variables
    self.__disc_mod_name = disc_mod_name
    self.__gen_mod_name = gen_mod_name
    self.__num_of_upscales = num_of_upscales
    assert self.__num_of_upscales >= 0, Fore.RED + "Invalid number of upscales" + Fore.RESET

    self.__discriminator_label_noise = discriminator_label_noise
    self.__discriminator_label_noise_decay = discriminator_label_noise_decay
    self.__discriminator_label_noise_min = discriminator_label_noise_min
    if self.__discriminator_label_noise_min is None:
        self.__discriminator_label_noise_min = 0

    self.__batch_size = batch_size
    assert self.__batch_size > 0, Fore.RED + "Invalid batch size" + Fore.RESET

    self.__episode_counter = 0

    # Insert empty lists if feature extractor settings are empty
    if feature_extractor_layers is None:
        feature_extractor_layers = []
    if feature_loss_weights is None:
        feature_loss_weights = []

    # If feature_loss_weights is a float then create a list of weights from it
    if isinstance(feature_loss_weights, float) and len(feature_extractor_layers) > 0:
        feature_loss_weights = [feature_loss_weights / len(feature_extractor_layers)] * len(feature_extractor_layers)

    assert len(feature_extractor_layers) == len(feature_loss_weights), \
        Fore.RED + "Number of extractor layers and feature loss weights must match!" + Fore.RESET

    # Create array of input image paths
    self.__train_data = get_paths_of_files_from_path(dataset_path, only_files=True)
    assert self.__train_data, Fore.RED + "Training dataset is not loaded" + Fore.RESET

    # Load one image to get its shape
    self.__target_image_shape = cv.imread(self.__train_data[0]).shape

    # Check image size validity
    if self.__target_image_shape[0] < 4 or self.__target_image_shape[1] < 4:
        raise Exception("Images too small, min size (4, 4)")

    # Calculate starting image size
    self.__start_image_shape = count_upscaling_start_size(
        self.__target_image_shape, self.__num_of_upscales)

    # Check validity of whole dataset
    if check_dataset:
        self.__validate_dataset()

    # Initialize training data folder and logging
    self.__training_progress_save_path = training_progress_save_path
    self.__training_progress_save_path = os.path.join(
        self.__training_progress_save_path,
        f"{self.__gen_mod_name}__{self.__disc_mod_name}__{self.__start_image_shape}_to_{self.__target_image_shape}")
    self.__tensorboard = TensorBoardCustom(
        log_dir=os.path.join(self.__training_progress_save_path, "logs"))
    self.__stat_logger = StatLogger(self.__tensorboard)

    # Define static vars
    self.kernel_initializer = RandomNormal(stddev=0.02)
    self.__custom_loading_failed = False
    self.__custom_test_images = True if custom_hr_test_images_paths else False
    if custom_hr_test_images_paths:
        self.__progress_test_images_paths = custom_hr_test_images_paths
        for idx, image_path in enumerate(self.__progress_test_images_paths):
            if not os.path.exists(image_path):
                self.__custom_loading_failed = True
                self.__progress_test_images_paths[idx] = random.choice(self.__train_data)
    else:
        self.__progress_test_images_paths = [random.choice(self.__train_data)]

    # Create batchmaker and start it
    self.__batch_maker = BatchMaker(self.__train_data,
                                    self.__batch_size,
                                    buffered_batches=buffered_batches,
                                    secondary_size=self.__start_image_shape,
                                    num_of_loading_workers=num_of_loading_workers,
                                    augmentation_settings=dataset_augmentation_settings)

    # Create LR schedulers for both optimizers
    self.__gen_lr_scheduler = LearningRateScheduler(
        start_lr=float(K.get_value(generator_optimizer.lr)),
        lr_decay_factor=generator_lr_decay_factor,
        lr_decay_interval=generator_lr_decay_interval,
        min_lr=generator_min_lr)
    self.__disc_lr_scheduler = LearningRateScheduler(
        start_lr=float(K.get_value(discriminator_optimizer.lr)),
        lr_decay_factor=discriminator_lr_decay_factor,
        lr_decay_interval=discriminator_lr_decay_interval,
        min_lr=discriminator_min_lr)

    #####################################
    ###     Create discriminator     ###
    #####################################
    self.__discriminator = self.__build_discriminator(disc_mod_name)
    self.__discriminator.compile(loss=disc_loss, optimizer=discriminator_optimizer)

    #####################################
    ###       Create generator       ###
    #####################################
    self.__generator = self.__build_generator(gen_mod_name)
    if self.__generator.output_shape[1:] != self.__target_image_shape:
        raise Exception(
            f"Invalid image input size for this generator model\n"
            f"Generator shape: {self.__generator.output_shape[1:]}, Target shape: {self.__target_image_shape}")
    self.__generator.compile(loss=gen_loss,
                             optimizer=generator_optimizer,
                             metrics=[PSNR_Y, PSNR, SSIM])

    #####################################
    ###      Create vgg network      ###
    #####################################
    self.__vgg = create_feature_extractor(self.__target_image_shape,
                                          feature_extractor_layers)

    #####################################
    ###   Create combined generator  ###
    #####################################
    small_image_input_generator = Input(shape=self.__start_image_shape,
                                        name="small_image_input")

    # Images upscaled by generator
    gen_images = self.__generator(small_image_input_generator)

    # Discriminator takes images and determines validity
    frozen_discriminator = Network(self.__discriminator.inputs,
                                   self.__discriminator.outputs,
                                   name="frozen_discriminator")
    frozen_discriminator.trainable = False
    validity = frozen_discriminator(gen_images)

    # Extract features from generated images
    generated_features = self.__vgg(preprocess_vgg(gen_images))

    # Combine models
    # Train generator to fool discriminator
    self.__combined_generator_model = Model(
        inputs=small_image_input_generator,
        outputs=[gen_images, validity] + [*generated_features],
        name="srgan")
    self.__combined_generator_model.compile(
        loss=[gen_loss, disc_loss] + ([feature_loss] * len(generated_features)),
        loss_weights=[gen_loss_weight, disc_loss_weight] + feature_loss_weights,
        optimizer=generator_optimizer,
        metrics={"generator": [PSNR_Y, PSNR, SSIM]})

    # Print all summaries
    print("\nDiscriminator Summary:")
    self.__discriminator.summary()
    print("\nGenerator Summary:")
    self.__generator.summary()

    # Load checkpoint
    self.__initiated = False
    if load_from_checkpoint:
        self.__load_checkpoint()

    # Load weights from params and override checkpoint weights
    if generator_weights:
        self.__generator.load_weights(generator_weights)
    if discriminator_weights:
        self.__discriminator.load_weights(discriminator_weights)

    # Set LR
    self.__gen_lr_scheduler.set_lr(self.__combined_generator_model, self.__episode_counter)
    self.__disc_lr_scheduler.set_lr(self.__discriminator, self.__episode_counter)
def __init__(self,
             dataset_path: str,
             gen_mod_name: str,
             disc_mod_name: str,
             latent_dim: int,
             training_progress_save_path: str,
             testing_dataset_path: str = None,
             generator_optimizer: Optimizer = Adam(0.0002, 0.5),
             discriminator_optimizer: Optimizer = Adam(0.0002, 0.5),
             discriminator_label_noise: float = None,
             discriminator_label_noise_decay: float = None,
             discriminator_label_noise_min: float = 0.001,
             batch_size: int = 32,
             buffered_batches: int = 20,
             generator_weights: Union[str, None] = None,
             discriminator_weights: Union[str, None] = None,
             start_episode: int = 0,
             load_from_checkpoint: bool = False,
             check_dataset: bool = True,
             num_of_loading_workers: int = 8):

    self.disc_mod_name = disc_mod_name
    self.gen_mod_name = gen_mod_name
    self.generator_optimizer = generator_optimizer

    self.latent_dim = latent_dim
    assert self.latent_dim > 0, Fore.RED + "Invalid latent dim" + Fore.RESET

    self.batch_size = batch_size
    assert self.batch_size > 0, Fore.RED + "Invalid batch size" + Fore.RESET

    self.discriminator_label_noise = discriminator_label_noise
    self.discriminator_label_noise_decay = discriminator_label_noise_decay
    self.discriminator_label_noise_min = discriminator_label_noise_min

    self.progress_image_dim = (16, 9)

    if start_episode < 0:
        start_episode = 0
    self.episode_counter = start_episode

    # Initialize training data folder and logging
    self.training_progress_save_path = training_progress_save_path
    self.training_progress_save_path = os.path.join(
        self.training_progress_save_path, f"{self.gen_mod_name}__{self.disc_mod_name}")
    self.tensorboard = TensorBoardCustom(
        log_dir=os.path.join(self.training_progress_save_path, "logs"))

    # Create array of input image paths
    self.train_data = get_paths_of_files_from_path(dataset_path, only_files=True)
    assert self.train_data, Fore.RED + "Training dataset is not loaded" + Fore.RESET

    self.testing_data = None
    if testing_dataset_path:
        self.testing_data = get_paths_of_files_from_path(testing_dataset_path)
        assert self.testing_data, Fore.RED + "Testing dataset is not loaded" + Fore.RESET

    # Load one image to get its shape
    tmp_image = cv.imread(self.train_data[0])
    self.image_shape = tmp_image.shape
    self.image_channels = self.image_shape[2]

    # Check image size validity
    if self.image_shape[0] < 4 or self.image_shape[1] < 4:
        raise Exception("Images too small, min size (4, 4)")

    # Check validity of whole dataset
    if check_dataset:
        self.validate_dataset()

    # Define static vars
    if os.path.exists(f"{self.training_progress_save_path}/static_noise.npy"):
        self.static_noise = np.load(f"{self.training_progress_save_path}/static_noise.npy")
        if self.static_noise.shape[0] != (self.progress_image_dim[0] * self.progress_image_dim[1]):
            print(Fore.YELLOW + "Progress image dim changed, restarting static noise!" + Fore.RESET)
            os.remove(f"{self.training_progress_save_path}/static_noise.npy")
            self.static_noise = np.random.normal(
                0.0, 1.0,
                size=(self.progress_image_dim[0] * self.progress_image_dim[1], self.latent_dim))
    else:
        self.static_noise = np.random.normal(
            0.0, 1.0,
            size=(self.progress_image_dim[0] * self.progress_image_dim[1], self.latent_dim))
    self.kernel_initializer = RandomNormal(stddev=0.02)

    # Load checkpoint
    self.initiated = False
    loaded_gen_weights_path = None
    loaded_disc_weights_path = None
    if load_from_checkpoint:
        loaded_gen_weights_path, loaded_disc_weights_path = self.load_checkpoint()

    # Create batchmaker and start it
    self.batch_maker = BatchMaker(self.train_data,
                                  self.batch_size,
                                  buffered_batches=buffered_batches,
                                  num_of_loading_workers=num_of_loading_workers)

    self.testing_batchmaker = None
    if self.testing_data:
        self.testing_batchmaker = BatchMaker(self.testing_data,
                                             self.batch_size,
                                             buffered_batches=buffered_batches,
                                             num_of_loading_workers=num_of_loading_workers)
        self.testing_batchmaker.start()

    #################################
    ###   Create discriminator   ###
    #################################
    self.discriminator = self.build_discriminator(disc_mod_name)
    self.discriminator.compile(loss="binary_crossentropy",
                               optimizer=discriminator_optimizer)

    #################################
    ###     Create generator     ###
    #################################
    self.generator = self.build_generator(gen_mod_name)
    if self.generator.output_shape[1:] != self.image_shape:
        raise Exception("Invalid image input size for this generator model")

    #################################
    ### Create combined generator ###
    #################################
    noise_input = Input(shape=(self.latent_dim,), name="noise_input")
    gen_images = self.generator(noise_input)

    # Create frozen version of discriminator
    frozen_discriminator = Network(self.discriminator.inputs,
                                   self.discriminator.outputs,
                                   name="frozen_discriminator")
    frozen_discriminator.trainable = False

    # Discriminator takes images and determines validity
    valid = frozen_discriminator(gen_images)

    # Combine models
    # Train generator to fool discriminator
    self.combined_generator_model = Model(noise_input, valid, name="dcgan_model")
    self.combined_generator_model.compile(loss="binary_crossentropy",
                                          optimizer=self.generator_optimizer)

    # Print all summaries
    print("\nDiscriminator Summary:")
    self.discriminator.summary()
    print("\nGenerator Summary:")
    self.generator.summary()
    print("\nGAN Summary")
    self.combined_generator_model.summary()

    # Load weights from checkpoint
    try:
        if loaded_gen_weights_path:
            self.generator.load_weights(loaded_gen_weights_path)
    except:
        print(Fore.YELLOW + "Failed to load generator weights from checkpoint" + Fore.RESET)

    try:
        if loaded_disc_weights_path:
            self.discriminator.load_weights(loaded_disc_weights_path)
    except:
        print(Fore.YELLOW + "Failed to load discriminator weights from checkpoint" + Fore.RESET)

    # Load weights from params and override checkpoint weights
    if generator_weights:
        self.generator.load_weights(generator_weights)
    if discriminator_weights:
        self.discriminator.load_weights(discriminator_weights)