def __init__(self, base_path, contrast_max, percent_val_max, list_res_max, training_csv, multi_gpu, patch=64,
             first_discriminator_kernel=32, first_generator_kernel=16, lamb_rec=1, lamb_adv=0.001, lamb_gp=10,
             lr_dis_model=0.0001, lr_gen_model=0.0001, u_net_gen=False, is_conditional=False, is_residual=True,
             fit_mask=False, nb_classe_mask=0, loss_name="charbonnier"):
    """Build the SegSRGAN generator and discriminator models used for training.

    :param base_path: prefix prepended to every image path read from ``training_csv``
    :param contrast_max: contrast augmentation amplitude (factors lie in ``[1 - contrast_max, 1 + contrast_max]``)
    :param percent_val_max: forwarded to the patch-creation helper as ``per_cent_val_max``
    :param list_res_max: ``[lower_bounds, upper_bounds]`` of the simulated resolution, one value per axis
    :param training_csv: CSV file listing the training/test images
    :param multi_gpu: forwarded to ``SegSRGAN``
    :param patch: edge length of the cubic patches the networks are built for
    :param first_discriminator_kernel: forwarded to ``SegSRGAN``
    :param first_generator_kernel: forwarded to ``SegSRGAN``
    :param lamb_rec: forwarded to ``SegSRGAN`` (reconstruction loss weight, judging by the name)
    :param lamb_adv: forwarded to ``SegSRGAN`` (adversarial loss weight, judging by the name)
    :param lamb_gp: forwarded to ``SegSRGAN`` (gradient-penalty weight, judging by the name)
    :param lr_dis_model: forwarded to ``SegSRGAN``
    :param lr_gen_model: forwarded to ``SegSRGAN``
    :param u_net_gen: forwarded to ``SegSRGAN``
    :param is_conditional: if True the networks also receive a resolution map
    :param is_residual: forwarded to ``SegSRGAN``
    :param fit_mask: if True the generator also predicts ``nb_classe_mask`` mask channels
    :param nb_classe_mask: number of mask classes (used only when ``fit_mask`` is True)
    :param loss_name: forwarded to ``SegSRGAN``
    """
    # ``is_residual`` is forwarded so that the flag actually changes the networks that are built,
    # as the test-time constructor already does.
    self.SegSRGAN = SegSRGAN(image_row=patch, image_column=patch, image_depth=patch,
                             first_discriminator_kernel=first_discriminator_kernel,
                             first_generator_kernel=first_generator_kernel,
                             lamb_rec=lamb_rec, lamb_adv=lamb_adv, lamb_gp=lamb_gp,
                             lr_dis_model=lr_dis_model, lr_gen_model=lr_gen_model,
                             u_net_gen=u_net_gen, multi_gpu=multi_gpu,
                             is_conditional=is_conditional, is_residual=is_residual,
                             fit_mask=fit_mask, nb_classe_mask=nb_classe_mask, loss_name=loss_name)
    # Model-building calls are kept in their original order. Layer auto-naming may depend on it,
    # and weights are later loaded by layer name.
    self.generator = self.SegSRGAN.generator()
    self.training_csv = training_csv
    self.DiscriminatorModel, self.DiscriminatorModel_multi_gpu = self.SegSRGAN.discriminator_model()
    self.GeneratorModel, self.GeneratorModel_multi_gpu = self.SegSRGAN.generator_model()
    self.base_path = base_path
    self.contrast_max = contrast_max
    self.percent_val_max = percent_val_max
    self.list_res_max = list_res_max
    self.multi_gpu = multi_gpu
    self.is_conditional = is_conditional
    self.is_residual = is_residual
    self.fit_mask = fit_mask
    self.nb_classe_mask = nb_classe_mask
    print("initialization completed")
def __init__(self, weights, patch1, patch2, patch3, is_conditional, u_net_gen, is_residual, first_generator_kernel,
             first_discriminator_kernel, resolution=0):
    """Create the inference generator and load its weights.

    :param weights: weight file, loaded with ``by_name=True`` into the prediction model
    :param patch1: patch size along the first axis
    :param patch2: patch size along the second axis
    :param patch3: patch size along the third axis
    :param is_conditional: whether the generator also takes a resolution map as input
    :param u_net_gen: forwarded to ``SegSRGAN``
    :param is_residual: forwarded to ``SegSRGAN``
    :param first_generator_kernel: forwarded to ``SegSRGAN``
    :param first_discriminator_kernel: forwarded to ``SegSRGAN``
    :param resolution: constant value filling ``self.res_tensor``
    """
    # Plain configuration attributes.
    self.patch1, self.patch2, self.patch3 = patch1, patch2, patch3
    self.prediction = None
    self.is_conditional = is_conditional
    self.is_residual = is_residual
    self.resolution = resolution

    # Network construction, in the same order as before: prediction model first, then generator.
    builder = SegSRGAN(first_generator_kernel=first_generator_kernel,
                       first_discriminator_kernel=first_discriminator_kernel,
                       u_net_gen=u_net_gen,
                       image_row=patch1,
                       image_column=patch2,
                       image_depth=patch3,
                       is_conditional=is_conditional,
                       is_residual=is_residual)
    self.SegSRGAN = builder
    prediction_model = builder.generator_model_for_pred()
    self.generator_model = prediction_model
    prediction_model.load_weights(weights, by_name=True)
    self.generator = builder.generator()

    # Constant map with two leading singleton axes: shape (1, 1, patch1, patch2, patch3).
    constant_map = np.ones([patch1, patch2, patch3]) * resolution
    self.res_tensor = constant_map[np.newaxis, np.newaxis, ...]
class SegSrganTrain(object):
    """Train a SegSRGAN model (joint super-resolution and segmentation GAN).

    The constructor builds the networks through ``SegSRGAN``. :meth:`train` alternates several
    discriminator updates with one generator update on patches extracted from the images listed
    in ``training_csv``. After every epoch it writes the test-set MSE and Dice score(s) to CSV.
    """

    def __init__(self, base_path, contrast_max, percent_val_max, list_res_max, training_csv, multi_gpu, patch=64,
                 first_discriminator_kernel=32, first_generator_kernel=16, lamb_rec=1, lamb_adv=0.001, lamb_gp=10,
                 lr_dis_model=0.0001, lr_gen_model=0.0001, u_net_gen=False, is_conditional=False, is_residual=True,
                 fit_mask=False, nb_classe_mask=0, loss_name="charbonnier"):
        """Build the generator and discriminator models.

        :param base_path: prefix prepended to every image path read from ``training_csv``
        :param contrast_max: contrast augmentation amplitude (factors lie in ``[1 - contrast_max, 1 + contrast_max]``)
        :param percent_val_max: forwarded to ``create_patch_from_df_hr`` as ``per_cent_val_max``
        :param list_res_max: ``[lower_bounds, upper_bounds]`` of the simulated resolution, one value per axis
        :param training_csv: CSV with columns ``HR_image``, ``Label_image``, ``Base`` (``"Train"``/``"Test"``)
            and, when masks are used, ``Mask_image``
        :param multi_gpu: forwarded to ``SegSRGAN``
        :param patch: edge length of the cubic patches the networks are built for
        :param first_discriminator_kernel: forwarded to ``SegSRGAN``
        :param first_generator_kernel: forwarded to ``SegSRGAN``
        :param lamb_rec: forwarded to ``SegSRGAN``
        :param lamb_adv: forwarded to ``SegSRGAN``
        :param lamb_gp: forwarded to ``SegSRGAN``
        :param lr_dis_model: forwarded to ``SegSRGAN``
        :param lr_gen_model: forwarded to ``SegSRGAN``
        :param u_net_gen: forwarded to ``SegSRGAN``
        :param is_conditional: if True the networks also receive a resolution map
        :param is_residual: forwarded to ``SegSRGAN``
        :param fit_mask: if True the generator also predicts ``nb_classe_mask`` mask channels
        :param nb_classe_mask: number of mask classes (used only when ``fit_mask`` is True)
        :param loss_name: forwarded to ``SegSRGAN``
        """
        # ``is_residual`` is forwarded so the flag actually changes the built networks.
        # Previously it was only stored on ``self``.
        self.SegSRGAN = SegSRGAN(image_row=patch, image_column=patch, image_depth=patch,
                                 first_discriminator_kernel=first_discriminator_kernel,
                                 first_generator_kernel=first_generator_kernel,
                                 lamb_rec=lamb_rec, lamb_adv=lamb_adv, lamb_gp=lamb_gp,
                                 lr_dis_model=lr_dis_model, lr_gen_model=lr_gen_model,
                                 u_net_gen=u_net_gen, multi_gpu=multi_gpu,
                                 is_conditional=is_conditional, is_residual=is_residual,
                                 fit_mask=fit_mask, nb_classe_mask=nb_classe_mask, loss_name=loss_name)
        # Model-building calls are kept in their original order. Layer auto-naming may depend on it,
        # and weights are loaded by layer name.
        self.generator = self.SegSRGAN.generator()
        self.training_csv = training_csv
        self.DiscriminatorModel, self.DiscriminatorModel_multi_gpu = self.SegSRGAN.discriminator_model()
        self.GeneratorModel, self.GeneratorModel_multi_gpu = self.SegSRGAN.generator_model()
        self.base_path = base_path
        self.contrast_max = contrast_max
        self.percent_val_max = percent_val_max
        self.list_res_max = list_res_max
        self.multi_gpu = multi_gpu
        self.is_conditional = is_conditional
        self.is_residual = is_residual
        self.fit_mask = fit_mask
        self.nb_classe_mask = nb_classe_mask
        print("initialization completed")

    def train(self, snapshot_folder, dice_file, mse_file, folder_training_data, patch_size, training_epoch=200,
              batch_size=16, snapshot_epoch=1, initialize_epoch=1, number_of_disciminator_iteration=5, resuming=None,
              interp='scipy', interpolation_type='Spline', image_cropping_method="bounding_box"):
        """Run the adversarial training loop.

        :param snapshot_folder: folder receiving the generator weights (created if missing)
        :param dice_file: CSV file receiving the per-epoch test Dice score(s)
        :param mse_file: CSV file receiving the per-epoch test MSE
        :param folder_training_data: folder where the train/test mini-batches are written as ``.npy`` files
        :param patch_size: edge length of the cubic patches
        :param training_epoch: index of the last epoch (inclusive)
        :param batch_size: training mini-batch size
        :param snapshot_epoch: generator weights are saved every ``snapshot_epoch`` epochs
        :param initialize_epoch: index of the first epoch; a value > 1 means resuming and requires ``resuming``
        :param number_of_disciminator_iteration: discriminator updates per generator update
        :param resuming: optional weight file loaded (by layer name) into the generator before training
        :param interp: interpolation backend (scipy or sitk)
        :param interpolation_type: forwarded to ``create_patch_from_df_hr``
        :param image_cropping_method: forwarded to ``create_patch_from_df_hr``. ``'overlapping_with_mask'``
            also requires the ``Mask_image`` column.
        :raises AssertionError: if ``initialize_epoch`` / ``training_epoch`` / ``resuming`` are inconsistent
        """
        print("train begin")
        snapshot_prefix = os.path.join(snapshot_folder, "SegSRGAN_epoch")
        print("Generator metrics name :" + str(self.GeneratorModel_multi_gpu.metrics_names))
        print("Disciminator metrics name :" + str(self.DiscriminatorModel_multi_gpu.metrics_names))

        # Print the "patches left out of each epoch" message only once.
        never_print = True

        if not os.path.exists(snapshot_folder):
            os.makedirs(snapshot_folder)

        # Targets for the discriminator outputs: real = -1 and fake = +1.
        # ``dummy`` is the target for the interpolated-sample output (presumably the gradient penalty term).
        real = -np.ones([batch_size, 1], dtype=np.float32)
        fake = -real
        dummy = np.zeros([batch_size, 1], dtype=np.float32)

        data_train, data_test = self._read_training_csv(image_cropping_method)
        self._load_initial_weights(initialize_epoch, training_epoch, resuming)
        iteration = 0

        # Test patches: contrast and resolution are spread deterministically over the test images.
        t1 = time.time()
        n_test = data_test.shape[0]
        test_contrast_list = np.linspace(1 - self.contrast_max, 1 + self.contrast_max, n_test)
        # list_res_max[0] holds the lower bounds and list_res_max[1] the upper bounds, indexed by axis.
        lin_res = [np.linspace(self.list_res_max[0][axis], self.list_res_max[1][axis], n_test) for axis in range(3)]
        res_test = list(zip(*lin_res))
        test_path_save_npy, test_Path_Datas_mini_batch, test_Labels_mini_batch, test_remaining_patch = \
            create_patch_from_df_hr(df=data_test, per_cent_val_max=self.percent_val_max,
                                    contrast_list=test_contrast_list, list_res=res_test, order=3,
                                    thresholdvalue=0, patch_size=patch_size,
                                    batch_size=1,  # 1 to keep all data
                                    path_save_npy=os.path.join(folder_training_data, "test_mini_batch"),
                                    stride=20, is_conditional=self.is_conditional, interp=interp,
                                    interpolation_type=interpolation_type, fit_mask=self.fit_mask,
                                    image_cropping_method=image_cropping_method,
                                    nb_classe_mask=self.nb_classe_mask)
        print("time for making test npy :" + str(time.time() - t1))

        if self.fit_mask:
            colunms_dice_mask = ["Dice_mask" + str(i) for i in range(self.nb_classe_mask)]
            colunms_dice = ["Dice_label_1"] + colunms_dice_mask
        else:
            colunms_dice_mask = []
            colunms_dice = ["Dice"]
        epochs = np.arange(initialize_epoch, training_epoch + 1)
        df_dice = pd.DataFrame(index=epochs, columns=colunms_dice)
        df_MSE = pd.DataFrame(index=epochs, columns=["MSE"])

        # Training patches use a random contrast and resolution per image.
        # They are generated once and reused at every epoch.
        t1 = time.time()
        n_train = data_train.shape[0]
        train_contrast_list = np.random.uniform(1 - self.contrast_max, 1 + self.contrast_max, n_train)
        res_train = [tuple(np.random.uniform(self.list_res_max[0][axis], self.list_res_max[1][axis])
                           for axis in range(3))
                     for _ in range(n_train)]
        train_path_save_npy, train_Path_Datas_mini_batch, train_Labels_mini_batch, train_remaining_patch = \
            create_patch_from_df_hr(df=data_train, per_cent_val_max=self.percent_val_max,
                                    contrast_list=train_contrast_list, list_res=res_train, order=3,
                                    thresholdvalue=0, patch_size=patch_size, batch_size=batch_size,
                                    path_save_npy=os.path.join(folder_training_data, "train_mini_batch"),
                                    stride=20, is_conditional=self.is_conditional, interp=interp,
                                    interpolation_type=interpolation_type, fit_mask=self.fit_mask,
                                    image_cropping_method=image_cropping_method,
                                    nb_classe_mask=self.nb_classe_mask)
        print("time for making train npy :" + str(time.time() - t1))
        iteration_per_epoch = len(train_Path_Datas_mini_batch)

        # Training phase
        for EpochIndex in range(initialize_epoch, training_epoch + 1):
            if never_print:
                print("At each epoch " + str(train_remaining_patch) + " patches will not be in the training data for "
                      "this epoch")
                never_print = False
            print("Processing epoch : " + str(EpochIndex))

            for iters in range(iteration_per_epoch):
                iteration += 1

                # Discriminator updates, each on a randomly chosen training mini-batch.
                for cidx in range(number_of_disciminator_iteration):
                    t1 = time.time()
                    random_number = np.random.randint(0, iteration_per_epoch)
                    dis_loss = self._discriminator_step(train_Path_Datas_mini_batch[random_number],
                                                        train_Labels_mini_batch[random_number],
                                                        batch_size, real, fake, dummy)
                    print("time for one uptade of discriminator :" + str(time.time() - t1))
                    print("Update " + str(cidx) + ": [D loss : " + str(dis_loss) + "]")

                # Generator update, visiting the mini-batches sequentially.
                t1 = time.time()
                gen_loss = self._generator_step(train_Path_Datas_mini_batch[iters],
                                                train_Labels_mini_batch[iters], real)
                print("Iter " + str(iteration) + " [A loss : " + str(gen_loss) + "]")
                print("time for one uptade of generator :" + str(time.time() - t1))

            if EpochIndex % snapshot_epoch == 0:
                # Save weights:
                self.GeneratorModel.save_weights(snapshot_prefix + '_' + str(EpochIndex))
                print("Snapshot :" + snapshot_prefix + '_' + str(EpochIndex))

            self._log_test_metrics(EpochIndex, test_Path_Datas_mini_batch, test_Labels_mini_batch, patch_size,
                                   df_dice, df_MSE, colunms_dice_mask)
            df_dice.to_csv(dice_file)
            df_MSE.to_csv(mse_file)

    def _read_training_csv(self, image_cropping_method):
        """Read ``training_csv``, prefix image paths with ``base_path`` and return the (train, test) rows."""
        data = pd.read_csv(self.training_csv)
        data["HR_image"] = self.base_path + data["HR_image"]
        data["Label_image"] = self.base_path + data["Label_image"]
        if self.fit_mask or image_cropping_method == 'overlapping_with_mask':
            data["Mask_image"] = self.base_path + data["Mask_image"]
        return data[data['Base'] == "Train"], data[data['Base'] == "Test"]

    def _load_initial_weights(self, initialize_epoch, training_epoch, resuming):
        """Validate the resume settings and load ``resuming`` into the generator when given."""
        if initialize_epoch == 1:
            if resuming is None:
                print("Training from scratch")
            else:
                print("Training from the pretrained model (names of layers must be identical): ", resuming)
                self.GeneratorModel.load_weights(resuming, by_name=True)
        elif initialize_epoch < 1:
            raise AssertionError('Resumming needs a positive epoch')
        elif training_epoch < initialize_epoch:
            raise AssertionError('initialize epoch need to be smaller than the total number of training epoch ')
        elif resuming is None:
            raise AssertionError('We need pretrained weights')
        else:
            print('TRAINING IS RESUMING')
            print('Continue training from : ', resuming)
            self.GeneratorModel.load_weights(resuming, by_name=True)

    def _discriminator_step(self, data_path, label_path, batch_size, real, fake, dummy):
        """Run one discriminator update on the mini-batch stored in ``data_path`` / ``label_path``; return its loss."""
        batch = np.load(data_path)
        # Channel 0 is the generator input image.
        # Channel 1 (conditional mode only) is the resolution input.
        train_input = batch[:, 0:1, :, :, :]
        train_output = np.load(label_path)
        if self.is_conditional:
            train_res = batch[:, 1:2, :, :, :]
            fake_images = self.GeneratorModel_multi_gpu.predict([train_input, train_res])[1]
        else:
            fake_images = self.GeneratorModel_multi_gpu.predict(train_input)[1]

        # One mixing coefficient per sample and per output channel. The channels are the SR image
        # and the label, plus the mask classes when fit_mask is set. The conditional branch used
        # to hard-code 3 channels, which only broadcast for a single mask class.
        n_channels = 2 + self.nb_classe_mask if self.fit_mask else 2
        epsilon = np.random.uniform(0, 1, size=(batch_size, n_channels, 1, 1, 1))
        interpolation = epsilon * train_output + (1 - epsilon) * fake_images

        if self.is_conditional:
            inputs = [train_output, fake_images, interpolation, train_res]
        else:
            inputs = [train_output, fake_images, interpolation]
        return self.DiscriminatorModel_multi_gpu.train_on_batch(inputs, [real, fake, dummy])

    def _generator_step(self, data_path, label_path, real):
        """Run one generator update on the mini-batch stored in ``data_path`` / ``label_path``; return its loss."""
        batch = np.load(data_path)
        train_input_gen = batch[:, 0:1, :, :, :]
        train_output_gen = np.load(label_path)
        if self.is_conditional:
            inputs = [train_input_gen, batch[:, 1:2, :, :, :]]
        else:
            inputs = [train_input_gen]
        # Targets: ``real`` for the first output (presumably the discriminator score of the generated
        # image) and the ground truth for the second output.
        return self.GeneratorModel_multi_gpu.train_on_batch(inputs, [real, train_output_gen])

    def _log_test_metrics(self, epoch, test_data_paths, test_label_paths, patch_size, df_dice, df_MSE,
                          colunms_dice_mask):
        """Evaluate the generator on every test patch and store the MSE / Dice for ``epoch`` in the data frames."""
        MSE_list = []
        VP = []  # true positives of the label channel, one entry per test patch
        Pos_pred = []
        Pos_label = []
        # Per mask class (outer list), one entry per test patch (inner list).
        VP_mask_all_label = [[] for _ in range(self.nb_classe_mask)]
        Pos_pred_mask_all_label = [[] for _ in range(self.nb_classe_mask)]
        Pos_label_mask_all_label = [[] for _ in range(self.nb_classe_mask)]

        t1 = time.time()
        for data_path, label_path in zip(test_data_paths, test_label_paths):
            TestLabels = np.load(label_path)
            test_batch = np.load(data_path)
            TestDatas = test_batch[:, 0:1, :, :, :]
            if self.is_conditional:
                pred = self.generator.predict([TestDatas, test_batch[:, 1:2, :, :, :]])
            else:
                pred = self.generator.predict([TestDatas])

            # Negative SR intensities are clipped to 0 before computing the squared error.
            sr = np.maximum(pred[:, 0, :, :, :], 0)
            MSE_list.append(np.sum((sr - TestLabels[:, 0, :, :, :]) ** 2))

            seg_positive = pred[:, 1, :, :, :] > 0.5
            VP.append(np.sum(seg_positive & (TestLabels[:, 1, :, :, :] == 1)))
            Pos_pred.append(np.sum(seg_positive))
            Pos_label.append(np.sum(TestLabels[:, 1, :, :, :]))

            if self.fit_mask:
                estimated_mask_discretized = np.argmax(pred[:, 2:, :, :, :], axis=1)
                for i in range(self.nb_classe_mask):
                    predicted_i = estimated_mask_discretized == i
                    label_i = TestLabels[:, 2 + i, :, :, :] == 1
                    VP_mask_all_label[i].append(np.sum(predicted_i & label_i))
                    Pos_pred_mask_all_label[i].append(np.sum(predicted_i))
                    Pos_label_mask_all_label[i].append(np.sum(label_i))
        print("Evaluation on test data time : " + str(time.time() - t1))

        # Compare the weight lists element-wise. Wrapping them in np.array builds a ragged array,
        # which NumPy >= 1.24 rejects.
        gen_weights = self.GeneratorModel.get_weights()
        gen_weights_multi = self.GeneratorModel_multi_gpu.get_weights()
        weights_idem = len(gen_weights) == len(gen_weights_multi) and all(
            np.array_equal(w, w_multi) for w, w_multi in zip(gen_weights, gen_weights_multi))
        if weights_idem:
            print("Model multi_gpu and base Model have the same weights")
        else:
            print("Model multi_gpu and base Model haven't the same weights")

        Dice = (2 * np.sum(VP)) / (np.sum(Pos_pred) + np.sum(Pos_label))
        MSE = np.sum(MSE_list) / (patch_size ** 3 * len(MSE_list))
        print("Iter " + str(epoch) + " [Test MSE : " + str(MSE) + "]")
        df_MSE.loc[epoch, "MSE"] = MSE

        if self.fit_mask:
            Dice_Mask = (2 * np.sum(VP_mask_all_label, axis=1)) / (
                np.sum(Pos_pred_mask_all_label, axis=1) + np.sum(Pos_label_mask_all_label, axis=1))
            df_dice.loc[epoch, "Dice_label_1"] = Dice
            df_dice.loc[epoch, colunms_dice_mask] = Dice_Mask
            print("Iter " + str(epoch) + " [Test Dice label 1 : " + str(Dice) + "]")
            print("Iter " + str(epoch) + " [Test Dice Mask : " + str(Dice_Mask) + "]")
        else:
            df_dice.loc[epoch, "Dice"] = Dice
            print("Iter " + str(epoch) + " [Test Dice : " + str(Dice) + "]")
class SegSRGAN_test(object):
    """Sliding-window inference with a trained SegSRGAN generator.

    The image is cut into overlapping patches of size ``(patch1, patch2, patch3)``. Each patch is
    passed through the generator, and the overlapping predictions are averaged. Output channel 0
    gives the super-resolved image and channel 1 gives the segmentation.
    """

    def __init__(self, weights, patch1, patch2, patch3, is_conditional, u_net_gen, is_residual,
                 first_generator_kernel, first_discriminator_kernel, resolution=0):
        """Build the generator and load ``weights`` (matched by layer name).

        :param weights: weight file passed to ``load_weights(..., by_name=True)``
        :param patch1: patch size along the first axis
        :param patch2: patch size along the second axis
        :param patch3: patch size along the third axis
        :param is_conditional: if True the generator also receives a constant resolution map
        :param u_net_gen: forwarded to ``SegSRGAN``
        :param is_residual: forwarded to ``SegSRGAN``
        :param first_generator_kernel: forwarded to ``SegSRGAN``
        :param first_discriminator_kernel: forwarded to ``SegSRGAN``
        :param resolution: value filling the resolution map used in conditional mode
        """
        self.patch1 = patch1
        self.patch2 = patch2
        self.patch3 = patch3
        self.prediction = None
        self.SegSRGAN = SegSRGAN(first_generator_kernel=first_generator_kernel,
                                 first_discriminator_kernel=first_discriminator_kernel,
                                 u_net_gen=u_net_gen,
                                 image_row=patch1,
                                 image_column=patch2,
                                 image_depth=patch3,
                                 is_conditional=is_conditional,
                                 is_residual=is_residual)
        self.generator_model = self.SegSRGAN.generator_model_for_pred()
        self.generator_model.load_weights(weights, by_name=True)
        self.generator = self.SegSRGAN.generator()
        self.is_conditional = is_conditional
        self.resolution = resolution
        self.is_residual = is_residual
        # Constant resolution map of shape (1, 1, patch1, patch2, patch3).
        self.res_tensor = np.expand_dims(np.expand_dims(
            np.ones([patch1, patch2, patch3]) * self.resolution, axis=0), axis=0)

    def get_patch(self):
        """Return the patch size as a ``(patch1, patch2, patch3)`` tuple.

        The previous implementation returned ``self.patch``, which is never set (AttributeError).
        """
        return self.patch1, self.patch2, self.patch3

    def _patch_origins(self, shape, step):
        """List the (x, y, z) corner of every window, with the last axis varying fastest."""
        height, width, depth = shape
        return [(idx, idy, idz)
                for idx in range(0, height - self.patch1 + 1, step)
                for idy in range(0, width - self.patch2 + 1, step)
                for idz in range(0, depth - self.patch3 + 1, step)]

    def _predict(self, image_tensor, batch_size):
        """Run the generator on an ``(n, 1, patch1, patch2, patch3)`` batch.

        In conditional mode the resolution map is repeated once per patch. The batch path
        previously omitted it.
        """
        if self.is_conditional:
            res_tensor = np.repeat(self.res_tensor, image_tensor.shape[0], axis=0)
            return self.generator.predict([image_tensor, res_tensor], batch_size=batch_size)
        return self.generator.predict(image_tensor, batch_size=batch_size)

    def test_by_patch(self, test_image, step=1, by_batch=False):
        """Predict the whole image patch by patch and average the overlapping outputs.

        :param test_image: 3-D image to be processed
        :param step: stride between consecutive windows along every axis
        :param by_batch: if True, all patches are sent to the generator in one ``predict`` call.
            This is faster, but every patch is held in memory. Otherwise patches are processed
            one at a time with a progress bar.
        :return: ``(estimated_hr, estimated_segmentation)``, both shaped like ``test_image``
        :raises ValueError: if ``test_image`` is smaller than the patch along some axis

        NOTE(review): a voxel covered by no window (possible when ``(dim - patch) % step != 0``)
        has a zero weight and comes out as NaN. Pad the input accordingly.
        """
        test_image = np.asarray(test_image)
        patch1, patch2, patch3 = self.patch1, self.patch2, self.patch3
        origins = self._patch_origins(np.shape(test_image), step)
        if not origins:
            raise ValueError("test_image of shape %s is smaller than the patch size %s"
                             % (np.shape(test_image), self.get_patch()))

        # Floating-point accumulators. Float inputs keep their dtype; integer inputs no longer
        # truncate the accumulated predictions.
        acc_dtype = np.result_type(test_image.dtype, np.float32)
        temp_hr_image = np.zeros(test_image.shape, dtype=acc_dtype)
        temp_seg = np.zeros(test_image.shape, dtype=acc_dtype)
        weighted_image = np.zeros(test_image.shape, dtype=acc_dtype)

        def accumulate(origin, prediction):
            # ``prediction`` is one generator output of shape (channels, patch1, patch2, patch3).
            x, y, z = origin
            window = (slice(x, x + patch1), slice(y, y + patch2), slice(z, z + patch3))
            temp_hr_image[window] += prediction[0]
            temp_seg[window] += prediction[1]
            weighted_image[window] += 1

        if by_batch:
            patches = np.array([[test_image[x:x + patch1, y:y + patch2, z:z + patch3]]
                                for x, y, z in origins]).astype(np.float32)
            pred = self._predict(patches, batch_size=patches.shape[0])
            for origin, prediction in zip(origins, pred):
                accumulate(origin, prediction)
        else:
            bar = progressbar.ProgressBar(maxval=len(origins)).start()
            print('Patch=', self.patch1)
            print('Step=', step)
            for count, (x, y, z) in enumerate(origins, 1):
                image_tensor = test_image[x:x + patch1, y:y + patch2, z:z + patch3].reshape(
                    1, 1, patch1, patch2, patch3).astype(np.float32)
                accumulate((x, y, z), self._predict(image_tensor, batch_size=1)[0])
                bar.update(count)
            bar.finish()

        # weight sum of patches
        print(GREEN + start + '\nDone !' + end + RESET)
        estimated_hr = temp_hr_image / weighted_image
        estimated_segmentation = temp_seg / weighted_image
        return estimated_hr, estimated_segmentation