Example #1
    def __init__(self,
                 name: str,
                 checkpoint_folder: str,
                 weight_folder: str,
                 logs_folder: str,
                 shape: tuple = (64, 64, 64),
                 lambda_rec: float = 1,
                 lambda_adv: float = 0.001,
                 lambda_gp: float = 10,
                 dis_kernel: int = 32,
                 gen_kernel: int = 16,
                 lr_dismodel: float = 0.0001,
                 lr_genmodel: float = 0.0001,
                 max_checkpoints_to_keep: int = 2,
                 *args,
                 **kwargs):

        self.patchsize = shape
        self.name = name
        self.generator = self.make_generator_model(shape, gen_kernel, *args,
                                                   **kwargs)
        self.generator.summary()

        self.discriminator = self.make_discriminator_model(
            shape, dis_kernel, *args, **kwargs)
        self.discriminator.summary()

        self.generator_trainer = self.make_generator_trainer(
            shape, lr_genmodel, lambda_adv, lambda_rec)
        self.discriminator_trainer = self.make_discriminator_trainer(
            shape, lr_dismodel, lambda_gp)

        if not isdir(checkpoint_folder):
            raise FileNotFoundError(
                f"Checkpoint folder unknown: {checkpoint_folder}")
        self.checkpoints_folder = get_and_create_dir(
            normpath(join(checkpoint_folder, name)))

        if not isdir(weight_folder):
            raise FileNotFoundError(f"Weight folder unknown: {weight_folder}")
        self.weight_folder = get_and_create_dir(
            normpath(join(weight_folder, name)))

        if not isdir(logs_folder):
            raise FileNotFoundError(f"Logs folder unknown: {logs_folder}")
        self.logs_folder = get_and_create_dir(
            normpath(join(logs_folder, name)))

        self.checkpoint = tf.train.Checkpoint(epoch=tf.Variable(0,
                                                                name='epoch'),
                                              generator=self.generator,
                                              discriminator=self.discriminator)

        self.checkpoint_manager = tf.train.CheckpointManager(
            checkpoint=self.checkpoint,
            directory=self.checkpoints_folder,
            # honour the constructor argument instead of a hard-coded 3
            max_to_keep=max_checkpoints_to_keep)
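
A minimal usage sketch for this constructor; the class name SRGanTrainer and the paths are assumptions, since only __init__ is shown:

    # Hedged usage sketch: 'SRGanTrainer' is a hypothetical name for the class
    # owning the __init__ above. The three folders must already exist (the
    # constructor raises otherwise); per-run subfolders are created inside.
    trainer = SRGanTrainer(name="srgan_64",
                           checkpoint_folder="/data/checkpoints",
                           weight_folder="/data/weights",
                           logs_folder="/data/logs",
                           shape=(64, 64, 64),
                           lambda_rec=1,      # reconstruction loss weight
                           lambda_adv=0.001,  # adversarial loss weight
                           lambda_gp=10)      # gradient penalty weight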
Example #2
    def __init__(self,
                 name: str,
                 checkpoint_folder: str,
                 weight_folder: str,
                 logs_folder: str,
                 make_generator_model=make_generator_model,
                 make_discriminator_model=None,
                 patchsize=(32, 32, 32),
                 *args,
                 **kwargs):

        self.name = name
        self.patchsize = patchsize

        if K.backend() == "tensorflow":
            from tensorflow.python.client import device_lib
            print(device_lib.list_local_devices())

        if not isdir(checkpoint_folder):
            raise FileNotFoundError(
                f"Checkpoint folder unknown: {checkpoint_folder}")
        self.checkpoint_folder = get_and_create_dir(
            normpath(join(checkpoint_folder, name)))

        if not isdir(weight_folder):
            raise FileNotFoundError(f"Weight folder unknown: {weight_folder}")
        self.weight_folder = get_and_create_dir(
            normpath(join(weight_folder, name)))

        if not isdir(logs_folder):
            raise FileNotFoundError(f"Logs folder unknown: {logs_folder}")
        self.logs_folder = get_and_create_dir(
            normpath(join(logs_folder, name)))

        self.optimizer_gen = keras.optimizers.Adam()

        self.generator = make_generator_model("gen", self.patchsize, 4)
        self.generator.summary()

        self.checkpoint = tf.train.Checkpoint(epoch=tf.Variable(0,
                                                                name='epoch'),
                                              optimizer_G=self.optimizer_gen,
                                              model=self.generator)
        self.checkpoint_manager = tf.train.CheckpointManager(
            checkpoint=self.checkpoint,
            directory=self.checkpoint_folder,
            max_to_keep=3)
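
The constructor calls make_generator_model("gen", self.patchsize, 4), so the injected factory must accept a model name, a patch size and an integer (presumably a depth or block count). A minimal stub under those assumptions:

    # Hedged stub of the expected factory; the real make_generator_model
    # builds the actual super-resolution network. The signature is inferred
    # from the call make_generator_model("gen", self.patchsize, 4) above.
    from tensorflow import keras

    def make_generator_model(name, patchsize, n_blocks):
        inputs = keras.Input(shape=patchsize + (1,))  # single-channel 3D patch
        x = inputs
        for _ in range(n_blocks):
            x = keras.layers.Conv3D(16, 3, padding="same", activation="relu")(x)
        outputs = keras.layers.Conv3D(1, 3, padding="same")(x)
        return keras.Model(inputs, outputs, name=name)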
Example #3
    def train(self,
              dataset: MRI_Dataset,
              n_epochs: int = 1,
              mri_to_visualize=None,
              output_dir=None,
              *args,
              **kwargs):

        if output_dir:
            output_dir = get_and_create_dir(join(output_dir, self.name))

        self.load_checkpoint()
        num_epoch = self.checkpoint.epoch.numpy()
        losses = []
        for epoch in range(num_epoch, n_epochs):
            print(f"Epoch {epoch+1} / {n_epochs} : ")
            last_losses = self._fit_one_epoch(dataset('Train'), *args,
                                              **kwargs)
            losses.append(last_losses)
            print("Discriminator loss mean : ",
                  np.mean(np.mean(last_losses[DIS_LOSSES], axis=0)))
            print("Generator loss mean : ", np.mean(last_losses[GEN_LOSSES]))
            # advance the stored epoch so a restored run resumes after this
            # one (assumes _fit_one_epoch does not already do it)
            self.checkpoint.epoch.assign_add(1)
            self.checkpoint_manager.save()

            if mri_to_visualize:
                if output_dir is None:
                    raise ValueError(
                        "An output directory is required when "
                        "mri_to_visualize is set")
                sr_mri = test_by_patch(mri_to_visualize, self)
                sr_mri.save_mri(
                    join(output_dir,
                         f"{self.name}_epoch_{epoch}_SR_"
                         f"{basename(mri_to_visualize.filepath)}"))
Example #4
    def train(self, dataset, n_epochs, mri_to_visualize=None, output_dir=None):
        if output_dir:
            output_dir = get_and_create_dir(join(output_dir, self.name))
        self.load_checkpoint()
        for epoch_index in range(self.checkpoint.epoch.numpy(), n_epochs):
            # reset the running losses so the printed means are per-epoch
            losses = []
            val_losses = []
            for lr, label in dataset('Train'):
                _, total_loss = self.train_step_generator(lr, label)
                losses.append(total_loss)
            for lr, label in dataset('Val'):
                _, val_total_loss = self.evaluation_step_generator(lr, label)
                val_losses.append(val_total_loss)
            print(f"Epoch : {epoch_index+1:04d}/{n_epochs}"
                  f" - mean total_loss : {np.mean(losses):04f}"
                  f" - mean val_total_loss : {np.mean(val_losses):04f}")
            self.checkpoint_manager.save()
            print("*saved checkpoint file at "
                  f"{self.checkpoint_manager.latest_checkpoint}\n")
            self.checkpoint.epoch.assign_add(1)
            if mri_to_visualize:
                if output_dir is None:
                    raise ValueError(
                        "An output directory is required when "
                        "mri_to_visualize is set")
                sr_mri = test_by_patch(mri_to_visualize, self)
                sr_mri.save_mri(
                    join(output_dir,
                         f"{self.name}_epoch_{epoch_index + 1}_SR_"
                         f"{basename(mri_to_visualize.filepath)}"))
        weights_path = join(self.weight_folder, self.name + ".h5")
        self.generator.save_weights(weights_path)
        print(f"\nSaved weights file at {weights_path}")
        print("Training done!")
Example #5
    def train(self,
              dataset: MRI_Dataset,
              n_epochs: int = 1,
              mri_to_visualize=None,
              output_dir=None,
              *args,
              **kwargs):
        if output_dir:
            output_dir = get_and_create_dir(join(output_dir, self.name))

        status = self.checkpoint.restore(
            self.checkpoint_manager.latest_checkpoint)

        for epoch in range(n_epochs):
            print(f"Epoch {epoch+1} / {n_epochs} : ")
            self._fit_one_epoch(dataset('Train'), *args, **kwargs)
            self.checkpoint_manager.save()

        if mri_to_visualize:
            if output_dir is None:
                raise ValueError(
                    "An output directory is required when "
                    "mri_to_visualize is set")
            sr_mri = test_by_patch(mri_to_visualize, self)
            # this runs once, after the final epoch
            sr_mri.save_mri(
                join(output_dir,
                     f"{self.name}_epoch_{n_epochs}_SR_"
                     f"{basename(mri_to_visualize.filepath)}"))
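
The status object returned by restore() is dropped above; it can be used to validate the restore. A hedged sketch:

    # expect_partial() silences warnings when only part of the checkpoint is
    # restored; assert_existing_objects_matched() fails fast on a bad restore.
    status = self.checkpoint.restore(self.checkpoint_manager.latest_checkpoint)
    if self.checkpoint_manager.latest_checkpoint:
        status.assert_existing_objects_matched()
    else:
        status.expect_partial()  # nothing to restore on a fresh run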
Example #6
    def __init__(self, config: ConfigParser, batch_folder: str, *args,
                 **kwargs):

        self.cfg = config

        self.batch_folder = batch_folder
        self.train_batch_folder_name = self.cfg.get('Batch_Path',
                                                    'Train_batch')
        # 'Validatation_batch' is misspelled but must match the key name in
        # the config file, so it is kept as-is
        self.val_batch_folder_name = self.cfg.get('Batch_Path',
                                                  'Validatation_batch')
        self.test_batch_folder_name = self.cfg.get('Batch_Path', 'Test_Batch')

        self.train_batch_folder_path = get_and_create_dir(
            normpath(join(self.batch_folder, self.train_batch_folder_name)))
        self.val_batch_folder_path = get_and_create_dir(
            normpath(join(self.batch_folder, self.val_batch_folder_name)))
        self.test_batch_folder_path = get_and_create_dir(
            normpath(join(self.batch_folder, self.test_batch_folder_name)))

        self.index = 0

        self.batchs_path_list = {
            self.cfg.get('Base_Header_Values', 'Train'): [],
            self.cfg.get('Base_Header_Values', 'Validation'): [],
            self.cfg.get('Base_Header_Values', 'Test'): []
        }

        self.list_batchs_folder = {
            self.cfg.get('Base_Header_Values', 'Train'):
            self.train_batch_folder_path,
            self.cfg.get('Base_Header_Values', 'Validation'):
            self.val_batch_folder_path,
            self.cfg.get('Base_Header_Values', 'Test'):
            self.test_batch_folder_path
        }

        self.initialize = False
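
The constructor assumes a config with [Batch_Path] and [Base_Header_Values] sections. A hedged sketch of a matching ConfigParser setup; the section and key names come from the cfg.get() calls above, while the values are assumptions (Examples #3 and #4 index the dataset with 'Train' and 'Val'):

    from configparser import ConfigParser

    config = ConfigParser()
    config["Batch_Path"] = {
        "Train_batch": "train_batches",
        "Validatation_batch": "val_batches",  # spelling matches the code above
        "Test_Batch": "test_batches",
    }
    config["Base_Header_Values"] = {
        "Train": "Train",
        "Validation": "Val",
        "Test": "Test",
    }
    dataset = MRI_Dataset(config, batch_folder="/data/batches")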
Example #7
def main():

    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset_name",
                        "-n",
                        help="the dataset name...",
                        required=True)
    parser.add_argument("--csv_name",
                        "-csv",
                        help="file path of the csv listing mri path",
                        required=True)
    parser.add_argument("--batchsize",
                        "-bs",
                        help="batchsize of the training",
                        default=128)
    parser.add_argument(
        "--downscale_factor",
        "-lr",
        help="downscaling factors for the HR image, one per axis, "
        "e.g. '2 2 2'",
        nargs=3,
        default=(2, 2, 2))
    parser.add_argument(
        "--patchsize",
        "-ps",
        help="3D patch size, three values, e.g. '64 64 64'",
        nargs=3,
        default=(64, 64, 64))
    parser.add_argument("--step",
                        '-st',
                        help="step/stride used when extracting patches",
                        default=4)
    parser.add_argument(
        '--percent_valmax',
        help="the network is trained on images with added Gaussian noise "
        "whose sigma equals this fraction of val_max",
        default=0.03)
    parser.add_argument('--save_lr',
                        help="also save the LR MRI",
                        action="store_true")
    parser.add_argument(
        '--segmentation',
        help="merge the HR image and its segmentation as the label",
        action="store_true")

    args = parser.parse_args()

    config = ConfigParser()
    if not isfile(CONFIG_INI_FILEPATH):
        raise FileNotFoundError(
            "You must run 'build_env.py -f <home_folder>' first")
    config.read(CONFIG_INI_FILEPATH)

    print(f"build_dataset.py -n {args.dataset_name} -csv {args.csv_name}"
          f" -bs {args.batchsize} -lr {args.downscale_factor}"
          f" -ps {args.patchsize} -st {args.step}"
          f" --percent_valmax {args.percent_valmax}")

    home_folder = config.get('Path', 'Home')

    print(f"workspace : {home_folder}")

    try:
        (home_folder, out_repo_path, training_repo_path, dataset_repo_path,
         batch_repo_path, checkpoint_repo_path, csv_repo_path,
         weights_repo_path, indices_repo_path,
         result_repo_path) = get_environment(home_folder, config)

    except Exception as e:
        raise Exception(
            "Home folder has not been set. You must run "
            "'build_env.py -f <home_folder>' before launching the training"
        ) from e

    csv_listfile_path = normpath(join(csv_repo_path, args.csv_name))

    if not isfile(csv_listfile_path):
        raise FileNotFoundError(
            f"{csv_listfile_path} unknown. You must put {args.csv_name} in "
            f"the {csv_repo_path} folder")

    dataset_name = args.dataset_name
    batchsize = int(args.batchsize)
    patchsize = (int(args.patchsize[0]), int(args.patchsize[1]),
                 int(args.patchsize[2]))
    lr_downscale_factor = (float(args.downscale_factor[0]),
                           float(args.downscale_factor[1]),
                           float(args.downscale_factor[2]))
    step = int(args.step)
    percent_valmax = float(args.percent_valmax)

    print("Dataset creation : preprocess and patches generation...")

    batch_repo_path = get_and_create_dir(join(batch_repo_path, dataset_name))

    dataset = MRI_Dataset(config, batch_folder=batch_repo_path)
    dataset.make_and_save_dataset_batchs(
        mri_folder=dataset_repo_path,
        csv_listfile_path=csv_listfile_path,
        batchsize=batchsize,
        lr_downscale_factor=lr_downscale_factor,
        patchsize=patchsize,
        step=step,
        percent_valmax=percent_valmax,
        save_lr=args.save_lr,
        segmentation=args.segmentation)

    print(f"Done !")
    print(
        f"Dataset create at : {batch_repo_path} with :\n*batchsize of : {batchsize}"
    )
    print(
        f"*patchsize of : {patchsize} by {step} step\n*gaussian noise of : {percent_valmax}\n*downscale factor of : {lr_downscale_factor}"
    )
    sys.exit(0)
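
For reference, a hedged example invocation using the defaults above; the script name build_dataset.py comes from the command echoed in main, while the dataset and CSV names are placeholders:

    python build_dataset.py -n my_dataset -csv mri_list.csv -bs 128 -ps 64 64 64 -lr 2 2 2 -st 4 --percent_valmax 0.03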