Example #3
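This example defines the SegSrganTrain training loop and the SegSRGAN_test patch-wise prediction class from the SegSRGAN project.

# Imports assumed by this snippet. The SegSRGAN network class and the
# create_patch_from_df_hr helper come from project-specific modules, so their
# import paths below are left as placeholders:
import os
import time

import numpy as np
import pandas as pd
import progressbar

# from <project>.SegSRGAN import SegSRGAN
# from <project>.patches import create_patch_from_df_hr

# ANSI color constants, assumed to exist at module level in the original file
GREEN = "\033[92m"
RESET = "\033[0m"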
class SegSrganTrain(object):
    def __init__(self, base_path, contrast_max, percent_val_max, list_res_max, training_csv, multi_gpu, patch=64,
                 first_discriminator_kernel=32, first_generator_kernel=16, lamb_rec=1, lamb_adv=0.001, lamb_gp=10,
                 lr_dis_model=0.0001, lr_gen_model=0.0001, u_net_gen=False, is_conditional=False,
                 is_residual=True, fit_mask=False, nb_classe_mask=0, loss_name="charbonnier"):

        self.SegSRGAN = SegSRGAN(image_row=patch, image_column=patch, image_depth=patch,
                                 first_discriminator_kernel=first_discriminator_kernel,
                                 first_generator_kernel=first_generator_kernel,
                                 lamb_rec=lamb_rec, lamb_adv=lamb_adv, lamb_gp=lamb_gp,
                                 lr_dis_model=lr_dis_model, lr_gen_model=lr_gen_model, u_net_gen=u_net_gen,
                                 multi_gpu=multi_gpu, is_conditional=is_conditional, fit_mask=fit_mask,
                                 nb_classe_mask=nb_classe_mask, loss_name=loss_name)
        self.generator = self.SegSRGAN.generator()
        self.training_csv = training_csv
        self.DiscriminatorModel, self.DiscriminatorModel_multi_gpu = self.SegSRGAN.discriminator_model()
        self.GeneratorModel, self.GeneratorModel_multi_gpu = self.SegSRGAN.generator_model()
        self.base_path = base_path
        self.contrast_max = contrast_max
        self.percent_val_max = percent_val_max
        self.list_res_max = list_res_max
        self.multi_gpu = multi_gpu
        self.is_conditional = is_conditional
        self.is_residual = is_residual
        self.fit_mask = fit_mask
        self.nb_classe_mask = nb_classe_mask

        print("initialization completed")
    #@vectorize(['float64(float64, float64, float64)'], target ="cuda")   
    #def interp_func(epsilon, train_output, fake_images):
        #return epsilon * train_output + (1 - epsilon) * fake_images
    #@profile
    def train(self,
              snapshot_folder,
              dice_file,
              mse_file,
              folder_training_data, patch_size,
              training_epoch=200, batch_size=16, snapshot_epoch=1, initialize_epoch=1, number_of_discriminator_iteration=5,
              resuming=None, interp='scipy', interpolation_type='Spline', image_cropping_method="bounding_box"):
        """

        :param patch_size:
        :param snapshot_folder:
        :param dice_file:
        :param mse_file:
        :param folder_training_data:
        :param training_epoch:
        :param batch_size:
        :param snapshot_epoch:
        :param initialize_epoch:
        :param number_of_disciminator_iteration:
        :param resuming:
        :param interp: interpolation type (scipy or sitk)
        """
        print("train begin")
        snapshot_prefix = os.path.join(snapshot_folder, "SegSRGAN_epoch")

        print("Generator metrics name :"+str(self.GeneratorModel_multi_gpu.metrics_names))
        print("Disciminator metrics name :"+str(self.DiscriminatorModel_multi_gpu.metrics_names))

        # boolean used to print, only once, how many patches are left out of each epoch (batch_size mode)
        never_print = True
        if not os.path.exists(snapshot_folder):
            os.makedirs(snapshot_folder)

        # Initialization Parameters
        real = -np.ones([batch_size, 1], dtype=np.float32)
        fake = -real
        dummy = np.zeros([batch_size, 1], dtype=np.float32)
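        # The targets above follow the usual Keras WGAN-GP setup: the critic pushes
        # real patches toward -1 and generated ones toward +1, while `dummy` is the
        # placeholder target of the gradient-penalty output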

        # Data processing

        data = pd.read_csv(self.training_csv)

        data["HR_image"] = self.base_path + data["HR_image"]
        data["Label_image"] = self.base_path + data["Label_image"]

        if self.fit_mask or (image_cropping_method == 'overlapping_with_mask'):
            data["Mask_image"] = self.base_path + data["Mask_image"]

        data_train = data[data['Base'] == "Train"]
        data_test = data[data['Base'] == "Test"]

        # Resuming
        if initialize_epoch == 1:
            iteration = 0
            if resuming is None:
                print("Training from scratch")
            else:
                print("Training from the pretrained model (names of layers must be identical): ", resuming)
                self.GeneratorModel.load_weights(resuming, by_name=True)

        elif initialize_epoch < 1:
            raise AssertionError('Resuming needs a positive epoch')

        elif training_epoch < initialize_epoch:
            raise AssertionError('initialize_epoch must not be larger than the total number of training epochs')
        else:
            if resuming is None:
                raise AssertionError('We need pretrained weights')
            else:
                print('TRAINING IS RESUMING')
                print('Continue training from: ', resuming)
                self.GeneratorModel.load_weights(resuming, by_name=True)
                iteration = 0
        # Test patch creation:

        t1 = time.time()

        test_contrast_list = np.linspace(1 - self.contrast_max, 1 + self.contrast_max, data_test.shape[0])

        # list_res[0] = lower bound and list_res[1] = upper bound
        # list_res[0][0] = lower bound for the first coordinate

        lin_res_x = np.linspace(self.list_res_max[0][0], self.list_res_max[1][0], data_test.shape[0])
        lin_res_y = np.linspace(self.list_res_max[0][1], self.list_res_max[1][1], data_test.shape[0])
        lin_res_z = np.linspace(self.list_res_max[0][2], self.list_res_max[1][2], data_test.shape[0])

        res_test = [(lin_res_x[i],
                     lin_res_y[i],
                     lin_res_z[i]) for i in range(data_test.shape[0])]

        test_path_save_npy, test_Path_Datas_mini_batch, test_Labels_mini_batch, test_remaining_patch = \
            create_patch_from_df_hr(df=data_test, per_cent_val_max=self.percent_val_max,
                                    contrast_list=test_contrast_list, list_res=res_test, order=3,
                                    thresholdvalue=0, patch_size=patch_size, batch_size=1,
                                    # 1 to keep all data
                                    path_save_npy=os.path.join(folder_training_data,"test_mini_batch"), stride=20,
                                    is_conditional=self.is_conditional, interp =interp,
                                    interpolation_type=interpolation_type,
                                    fit_mask=self.fit_mask,
                                    image_cropping_method=image_cropping_method,
                                    nb_classe_mask = self.nb_classe_mask)

        t2 = time.time()

        print("time for making test npy :" + str(t2 - t1))

        if self.fit_mask:
            columns_dice = ["Dice_label_1"]
            columns_dice_mask = ["Dice_mask" + str(i) for i in range(self.nb_classe_mask)]
            columns_dice.extend(columns_dice_mask)
        else:
            columns_dice = ["Dice"]

        df_dice = pd.DataFrame(index=np.arange(initialize_epoch, training_epoch + 1), columns=columns_dice)
        df_MSE = pd.DataFrame(index=np.arange(initialize_epoch, training_epoch + 1), columns=["MSE"])


        # Training patches are generated once, before the epoch loop, with a random
        # contrast and a random resolution drawn for each training image
        t1 = time.time()
        train_contrast_list = np.random.uniform(1 - self.contrast_max, 1 + self.contrast_max, data_train.shape[0])
        res_train = [(np.random.uniform(self.list_res_max[0][0], self.list_res_max[1][0]),
                      np.random.uniform(self.list_res_max[0][1], self.list_res_max[1][1]),
                      np.random.uniform(self.list_res_max[0][2], self.list_res_max[1][2])) for i in
                     range(data_train.shape[0])]
        train_path_save_npy, train_Path_Datas_mini_batch, train_Labels_mini_batch, train_remaining_patch = \
            create_patch_from_df_hr(df=data_train, per_cent_val_max=self.percent_val_max,
                                    contrast_list=train_contrast_list, list_res=res_train, order=3,
                                    thresholdvalue=0, patch_size=patch_size, batch_size=batch_size,
                                    path_save_npy=os.path.join(folder_training_data,"train_mini_batch"), stride=20,
                                    is_conditional=self.is_conditional, interp=interp,
                                    interpolation_type=interpolation_type,
                                    fit_mask=self.fit_mask,
                                    image_cropping_method=image_cropping_method,
                                    nb_classe_mask = self.nb_classe_mask)
        t2 = time.time()

        print("time for making train npy: " + str(t2 - t1))

        iterationPerEpoch = len(train_Path_Datas_mini_batch)

        # Training phase
        for EpochIndex in range(initialize_epoch, training_epoch + 1):

            if never_print:
                print("At each epoch, " + str(train_remaining_patch) + " patches are left out of the training data")
                never_print = False

            print("Processing epoch : " + str(EpochIndex))
            for iters in range(0, iterationPerEpoch):

                iteration += 1

                # Training discriminator
                
                for cidx in range(number_of_discriminator_iteration):

                    t1 = time.time()

                    # Loading data randomly
                    randomNumber = int(np.random.randint(0, iterationPerEpoch))

                    #print("train on batch : ",train_Path_Datas_mini_batch[randomNumber])

                    train_input = np.load(train_Path_Datas_mini_batch[randomNumber])[:, 0, :, :, :][:, np.newaxis, :, :,
                                  :]
                    # select 0 coordoniate and add one axis at the same place

                    train_output = np.load(train_Labels_mini_batch[randomNumber])

                    if self.is_conditional:

                        train_res = np.load(train_Path_Datas_mini_batch[randomNumber])[:, 1, :, :, :][:, np.newaxis, :, :, :]

                        # Generating fake images
                        fake_images = self.GeneratorModel_multi_gpu.predict([train_input, train_res])[1]

                        # Random interpolation coefficients for the WGAN-GP gradient
                        # penalty, drawn per sample and per output channel (the channel
                        # count matches the non-conditional branch below)
                        if self.fit_mask:
                            epsilon = np.random.uniform(0, 1, size=(batch_size, 2 + self.nb_classe_mask, 1, 1, 1))
                        else:
                            epsilon = np.random.uniform(0, 1, size=(batch_size, 2, 1, 1, 1))

                        interpolation = epsilon * train_output + (1 - epsilon) * fake_images

                        # Training
                        dis_loss = self.DiscriminatorModel_multi_gpu.train_on_batch([train_output, fake_images,
                                                                                     interpolation, train_res],
                                                                                    [real, fake, dummy])
                    else:

                        # Generating fake images
                        fake_images = self.GeneratorModel_multi_gpu.predict(train_input)[1]

                        if self.fit_mask:
                            epsilon = np.random.uniform(0, 1, size=(batch_size, 2 + self.nb_classe_mask, 1, 1, 1))
                        else:
                            epsilon = np.random.uniform(0, 1, size=(batch_size, 2, 1, 1, 1))

                        interpolation = epsilon * train_output + (1 - epsilon) * fake_images
                        # Training
                        dis_loss = self.DiscriminatorModel_multi_gpu.train_on_batch([train_output, fake_images,
                                                                                     interpolation],
                                                                                    [real, fake, dummy])

                    t2 = time.time()

                    print("time for one uptade of discriminator :" + str(t2 - t1))
                    print("Update " + str(cidx) + ": [D loss : " + str(dis_loss) + "]")
                
                # Training generator
                # Loading data

                t1 = time.time()

                train_input_gen = np.load(train_Path_Datas_mini_batch[iters])[:, 0, :, :, :][:, np.newaxis, :, :, :]
                train_output_gen = np.load(train_Labels_mini_batch[iters])

                if self.is_conditional:

                    train_res_gen = np.load(train_Path_Datas_mini_batch[iters])[:, 1, :, :, :][:, np.newaxis, :, :, :]
                    # Training
                    gen_loss = self.GeneratorModel_multi_gpu.train_on_batch([train_input_gen, train_res_gen],
                                                                            [real, train_output_gen])
                else:
                    # Training
                    gen_loss = self.GeneratorModel_multi_gpu.train_on_batch([train_input_gen], [real, train_output_gen])

                print("Iter " + str(iteration) + " [A loss : " + str(gen_loss) + "]")

                t2 = time.time()

                print("time for one uptade of generator :" + str(t2 - t1))

            if EpochIndex % snapshot_epoch == 0:
                # Save weights:
                self.GeneratorModel.save_weights(snapshot_prefix + '_' + str(EpochIndex))
                print("Snapshot :" + snapshot_prefix + '_' + str(EpochIndex))

            MSE_list = []
            VP = []
            Pos_pred = []
            Pos_label = []
            # For the three following lists, the first dimension indexes the mask class
            # and the second dimension indexes the test patch
            VP_mask_all_label = [[] for i in range(self.nb_classe_mask)]
            Pos_pred_mask_all_label = [[] for i in range(self.nb_classe_mask)]
            Pos_label_mask_all_label = [[] for i in range(self.nb_classe_mask)]

            t1 = time.time()

            for test_iter in range(len(test_Labels_mini_batch)):

                TestLabels = np.load(test_Labels_mini_batch[test_iter])
                TestDatas = np.load(test_Path_Datas_mini_batch[test_iter])[:, 0, :, :, :][:, np.newaxis, :, :, :]

                if self.is_conditional:

                    TestRes = np.load(test_Path_Datas_mini_batch[test_iter])[:, 1, :, :, :][:, np.newaxis, :, :, :]

                    pred = self.generator.predict([TestDatas, TestRes])

                else:

                    pred = self.generator.predict([TestDatas])

                pred[:, 0, :, :, :][pred[:, 0, :, :, :] < 0] = 0

                MSE_list.append(np.sum((pred[:, 0, :, :, :] - TestLabels[:, 0, :, :, :]) ** 2))

                VP.append(np.sum((pred[:, 1, :, :, :] > 0.5) & (TestLabels[:, 1, :, :, :] == 1)))

                Pos_pred.append(np.sum(pred[:, 1, :, :, :] > 0.5))

                Pos_label.append(np.sum(TestLabels[:, 1, :, :, :]))

                if self.fit_mask:

                    estimated_mask_discretized = np.argmax(pred[:, 2:, :, :, :], axis=1)

                    for i in range(self.nb_classe_mask):
                        VP_mask_all_label[i].append(np.sum((estimated_mask_discretized == i) & (TestLabels[:, 2 + i, :, :, :] == 1)))
                        Pos_pred_mask_all_label[i].append(np.sum(estimated_mask_discretized == i))
                        Pos_label_mask_all_label[i].append(np.sum(TestLabels[:, 2 + i, :, :, :] == 1))

            t2 = time.time()

            print("Evaluation on test data time : " + str(t2 - t1))

            gen_weights = self.GeneratorModel.get_weights()
            gen_weights_multi = self.GeneratorModel_multi_gpu.get_weights()

            weights_idem = True

            for i in range(len(gen_weights)):
                idem = np.array_equal(gen_weights[i], gen_weights_multi[i])
                weights_idem = weights_idem and idem

            if weights_idem:

                print("Model multi_gpu and base Model have the same weights")

            else:
                print("Model multi_gpu and base Model do not have the same weights")

            # Dice over all test patches: 2 * TP / (positive predictions + positive labels)
            Dice = (2 * np.sum(VP)) / (np.sum(Pos_pred) + np.sum(Pos_label))

            # MSE averaged over every voxel of every test patch
            MSE = np.sum(MSE_list) / (patch_size ** 3 * len(MSE_list))

            print("Iter " + str(EpochIndex) + " [Test MSE: " + str(MSE) + "]")

            df_MSE.loc[EpochIndex, "MSE"] = MSE

            if self.fit_mask:

                Dice_Mask = (2 * np.sum(VP_mask_all_label, axis=1)) / (np.sum(Pos_pred_mask_all_label, axis=1) + np.sum(Pos_label_mask_all_label, axis=1))

                df_dice.loc[EpochIndex, "Dice_label_1"] = Dice
                df_dice.loc[EpochIndex, columns_dice_mask] = Dice_Mask

                print("Iter " + str(EpochIndex) + " [Test Dice label 1: " + str(Dice) + "]")
                print("Iter " + str(EpochIndex) + " [Test Dice Mask: " + str(Dice_Mask) + "]")

            else:

                df_dice.loc[EpochIndex, "Dice"] = Dice

                print("Iter " + str(EpochIndex) + " [Test Dice: " + str(Dice) + "]")

            df_dice.to_csv(dice_file)
            df_MSE.to_csv(mse_file)

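# A minimal usage sketch (hypothetical paths and values; it assumes a training
# CSV with the HR_image, Label_image and Base columns read in train() above):
#
# trainer = SegSrganTrain(base_path="/data/", contrast_max=0.5, percent_val_max=0.03,
#                         list_res_max=[(0.5, 0.5, 0.5), (1.0, 1.0, 2.0)],
#                         training_csv="training.csv", multi_gpu=False)
# trainer.train(snapshot_folder="weights", dice_file="dice.csv", mse_file="mse.csv",
#               folder_training_data="tmp_patches", patch_size=64)
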
class SegSRGAN_test(object):
    def __init__(self,
                 weights,
                 patch1,
                 patch2,
                 patch3,
                 is_conditional,
                 u_net_gen,
                 is_residual,
                 first_generator_kernel,
                 first_discriminator_kernel,
                 resolution=0):

        self.patch1 = patch1
        self.patch2 = patch2
        self.patch3 = patch3
        self.prediction = None
        self.SegSRGAN = SegSRGAN(
            first_generator_kernel=first_generator_kernel,
            first_discriminator_kernel=first_discriminator_kernel,
            u_net_gen=u_net_gen,
            image_row=patch1,
            image_column=patch2,
            image_depth=patch3,
            is_conditional=is_conditional,
            is_residual=is_residual)
        self.generator_model = self.SegSRGAN.generator_model_for_pred()
        self.generator_model.load_weights(weights, by_name=True)
        self.generator = self.SegSRGAN.generator()
        self.is_conditional = is_conditional
        self.resolution = resolution
        self.is_residual = is_residual
        # Resolution tensor of shape (1, 1, patch1, patch2, patch3), used as the
        # conditional input of the generator
        self.res_tensor = np.expand_dims(np.expand_dims(
            np.ones([patch1, patch2, patch3]) * self.resolution, axis=0), axis=0)

    def get_patch(self):
        """
        :return: the patch dimensions (patch1, patch2, patch3)
        """
        return self.patch1, self.patch2, self.patch3

    def test_by_patch(self, test_image, step=1, by_batch=False):
        """
        :param test_image: image to be tested
        :param step: stride between two consecutive patches
        :param by_batch: if True, predict all patches in a single batch
        :return: the estimated HR image and its segmentation
        """
        # Init temp
        height, width, depth = np.shape(test_image)

        temp_hr_image = np.zeros_like(test_image)
        temp_seg = np.zeros_like(test_image)
        weighted_image = np.zeros_like(test_image)

        # if is_conditional is set to True we predict on the image AND the resolution
        if self.is_conditional:
            if not by_batch:

                i = 0
                bar = progressbar.ProgressBar(maxval=len(np.arange(0, height - self.patch1 + 1, step)) * len(
                    np.arange(0, width - self.patch2 + 1, step)) * len(np.arange(0, depth - self.patch3 + 1, step))).\
                    start()
                print('Patch=', self.patch1)
                print('Step=', step)
                for idx in range(0, height - self.patch1 + 1, step):
                    for idy in range(0, width - self.patch2 + 1, step):
                        for idz in range(0, depth - self.patch3 + 1, step):
                            # Cropping image
                            test_patch = test_image[idx:idx + self.patch1,
                                                    idy:idy + self.patch2,
                                                    idz:idz + self.patch3]
                            image_tensor = test_patch.reshape(1, 1, self.patch1, self.patch2, self.patch3).\
                                astype(np.float32)
                            predict_patch = self.generator.predict(
                                [image_tensor, self.res_tensor], batch_size=1)

                            # Adding
                            temp_hr_image[idx:idx + self.patch1,
                                          idy:idy + self.patch2, idz:idz +
                                          self.patch3] += predict_patch[
                                              0, 0, :, :, :]
                            temp_seg[idx:idx + self.patch1, idy:idy + self.patch2, idz:idz + self.patch3] += \
                                predict_patch[0, 1, :, :, :]
                            weighted_image[idx:idx + self.patch1,
                                           idy:idy + self.patch2, idz:idz +
                                           self.patch3] += np.ones_like(
                                               predict_patch[0, 0, :, :, :])

                            i += 1

                            bar.update(i)
            else:

                height = test_image.shape[0]
                width = test_image.shape[1]
                depth = test_image.shape[2]

                patch1 = self.patch1
                patch2 = self.patch2
                patch3 = self.patch3

                patches = np.array(
                    [[
                        test_image[idx:idx + patch1, idy:idy + patch2,
                                   idz:idz + patch3]
                    ] for idx in range(0, height - patch1 + 1, step)
                     for idy in range(0, width - patch2 + 1, step)
                     for idz in range(0, depth - patch3 + 1, step)])

                indice_patch = np.array([
                    (idx, idy, idz)
                    for idx in range(0, height - patch1 + 1, step)
                    for idy in range(0, width - patch2 + 1, step)
                    for idz in range(0, depth - patch3 + 1, step)
                ])

                # The conditional generator also needs the resolution input,
                # repeated once per patch
                res = np.repeat(self.res_tensor, patches.shape[0], axis=0)
                pred = self.generator.predict([patches, res],
                                              batch_size=patches.shape[0])

                for i in range(indice_patch.shape[0]):
                    temp_hr_image[indice_patch[i][0]:indice_patch[i][0] + patch1,
                                  indice_patch[i][1]:indice_patch[i][1] + patch2,
                                  indice_patch[i][2]:indice_patch[i][2] + patch3] += pred[i, 0, :, :, :]
                    temp_seg[indice_patch[i][0]:indice_patch[i][0] + patch1,
                             indice_patch[i][1]:indice_patch[i][1] + patch2,
                             indice_patch[i][2]:indice_patch[i][2] + patch3] += pred[i, 1, :, :, :]
                    # Accumulate the per-voxel patch count for averaging
                    weighted_image[indice_patch[i][0]:indice_patch[i][0] + patch1,
                                   indice_patch[i][1]:indice_patch[i][1] + patch2,
                                   indice_patch[i][2]:indice_patch[i][2] + patch3] += 1
        else:
            if not by_batch:

                i = 0
                bar = progressbar.ProgressBar(
                    maxval=len(np.arange(0, height - self.patch1 + 1, step)) *
                    len(np.arange(0, width - self.patch2 + 1, step)) *
                    len(np.arange(0, depth - self.patch3 + 1, step))).start()
                print('Patch=', self.patch1)
                print('Step=', step)
                for idx in range(0, height - self.patch1 + 1, step):
                    for idy in range(0, width - self.patch2 + 1, step):
                        for idz in range(0, depth - self.patch3 + 1, step):
                            # Cropping image
                            test_patch = test_image[idx:idx + self.patch1,
                                                    idy:idy + self.patch2,
                                                    idz:idz + self.patch3]
                            image_tensor = test_patch.reshape(
                                1, 1, self.patch1, self.patch2,
                                self.patch3).astype(np.float32)
                            predict_patch = self.generator.predict(
                                image_tensor, batch_size=1)

                            # Adding
                            temp_hr_image[idx:idx + self.patch1,
                                          idy:idy + self.patch2, idz:idz +
                                          self.patch3] += predict_patch[
                                              0, 0, :, :, :]
                            temp_seg[idx:idx + self.patch1,
                                     idy:idy + self.patch2, idz:idz +
                                     self.patch3] += predict_patch[0,
                                                                   1, :, :, :]
                            weighted_image[idx:idx + self.patch1,
                                           idy:idy + self.patch2, idz:idz +
                                           self.patch3] += np.ones_like(
                                               predict_patch[0, 0, :, :, :])

                            i += 1

                            bar.update(i)
            else:

                height = test_image.shape[0]
                width = test_image.shape[1]
                depth = test_image.shape[2]

                patch1 = self.patch1
                patch2 = self.patch2
                patch3 = self.patch3

                patches = np.array(
                    [[
                        test_image[idx:idx + patch1, idy:idy + patch2,
                                   idz:idz + patch3]
                    ] for idx in range(0, height - patch1 + 1, step)
                     for idy in range(0, width - patch2 + 1, step)
                     for idz in range(0, depth - patch3 + 1, step)])

                indice_patch = np.array([
                    (idx, idy, idz)
                    for idx in range(0, height - patch1 + 1, step)
                    for idy in range(0, width - patch2 + 1, step)
                    for idz in range(0, depth - patch3 + 1, step)
                ])

                pred = self.generator.predict(patches,
                                              batch_size=patches.shape[0])

                for i in range(indice_patch.shape[0]):
                    temp_hr_image[indice_patch[i][0]:indice_patch[i][0] + patch1,
                                  indice_patch[i][1]:indice_patch[i][1] + patch2,
                                  indice_patch[i][2]:indice_patch[i][2] + patch3] += pred[i, 0, :, :, :]
                    temp_seg[indice_patch[i][0]:indice_patch[i][0] + patch1,
                             indice_patch[i][1]:indice_patch[i][1] + patch2,
                             indice_patch[i][2]:indice_patch[i][2] + patch3] += pred[i, 1, :, :, :]
                    # Accumulate the per-voxel patch count for averaging
                    weighted_image[indice_patch[i][0]:indice_patch[i][0] + patch1,
                                   indice_patch[i][1]:indice_patch[i][1] + patch2,
                                   indice_patch[i][2]:indice_patch[i][2] + patch3] += 1
        # Average the overlapping patch predictions by the per-voxel patch count
        print(GREEN + '\nDone !' + RESET)
        estimated_hr = temp_hr_image / weighted_image
        estimated_segmentation = temp_seg / weighted_image

        return estimated_hr, estimated_segmentation
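

# A minimal usage sketch (hypothetical weight path; SimpleITK is used here only
# to load a volume, and the arguments mirror the __init__ above):
#
# import SimpleITK as sitk
#
# tester = SegSRGAN_test(weights="weights/SegSRGAN_epoch_200",
#                        patch1=64, patch2=64, patch3=64,
#                        is_conditional=False, u_net_gen=False, is_residual=True,
#                        first_generator_kernel=16, first_discriminator_kernel=32)
# lr_image = sitk.GetArrayFromImage(sitk.ReadImage("lr_image.nii.gz"))
# hr, seg = tester.test_by_patch(lr_image, step=32)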