def main():
    args = get_arguments()
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    ## FOR REPRODUCIBILITY OF RESULTS
    seed = 1777777
    utils.reproducibility(args, seed)

    utils.make_dirs(args.save)
    utils.save_arguments(args, args.save)

    training_generator, val_generator, full_volume, affine = medical_loaders.generate_datasets(
        args, path='.././datasets')
    model, optimizer = medzoo.create_model(args)
    # criterion = create_loss('CrossEntropyLoss')  # unused: overridden by the weighted Dice loss below
    # Weighted Dice loss; the background class (index 0) is down-weighted.
    criterion = DiceLoss(classes=args.classes,
                         weight=torch.tensor([0.1, 1, 1, 1]).cuda())

    if args.cuda:
        model = model.cuda()
        print("Model transferred in GPU.....")

    trainer = Trainer(args,
                      model,
                      criterion,
                      optimizer,
                      train_data_loader=training_generator,
                      valid_data_loader=val_generator,
                      lr_scheduler=None)
    print("START TRAINING...")
    trainer.training()
Example #2
def train():
    # args = brats2019_arguments()
    # NOTE: `args` and `seed` are assumed to be module-level globals in this snippet.
    utils.reproducibility(args, seed)
    utils.make_dirs(args.save)

    (
        training_generator,
        val_generator,
        full_volume,
        affine,
    ) = medical_loaders.generate_datasets(args)
    model, optimizer = medzoo.create_model(args)
    val_criterion = DiceLoss(classes=11, skip_index_after=args.classes)

    # criterion = DiceLoss(classes=3, skip_index_after=args.classes)
    # criterion = DiceLoss(classes=args.classes)
    criterion = torch.nn.CrossEntropyLoss()

    if args.cuda:
        model = model.cuda()
        print("Model transferred in GPU.....")

    trainer = train_module.Trainer(
        args,
        model,
        criterion,
        optimizer,
        val_criterion=val_criterion,
        train_data_loader=training_generator,
        valid_data_loader=val_generator,
        lr_scheduler=None,
    )
    print("START TRAINING...")
    trainer.training()
Example #3
def main():
    args = get_arguments()

    utils.reproducibility(args, seed)
    utils.make_dirs(args.save)

    training_generator, val_generator, full_volume, affine = medical_loaders.generate_datasets(
        args, path='.././datasets')
    model, optimizer = medzoo.create_model(args)
    criterion = DiceLoss(classes=args.classes)

    if args.cuda:
        model = model.cuda()
        print("Model transferred in GPU.....")

    trainer = train.Trainer(args,
                            model,
                            criterion,
                            optimizer,
                            train_data_loader=training_generator,
                            valid_data_loader=val_generator,
                            lr_scheduler=None)
    print("START TRAINING...")
    trainer.training()

    visualize_3D_no_overlap_new(args, full_volume, affine, model, 10, args.dim)
Example #4
def main():
    args = get_arguments()

    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    ## FOR REPRODUCIBILITY OF RESULTS
    seed = 1777777
    utils.reproducibility(args, seed)
    utils.make_dirs(args.save)
    name_model = args.model + "_" + args.dataset_name + "_" + utils.datestr()

    # TODO visual3D_temp.Basewriter package
    writer = SummaryWriter(log_dir='./runs/' + name_model, comment=name_model)

    training_generator, val_generator, full_volume, affine = medical_loaders.generate_datasets(
        args, path='.././datasets')
    model, optimizer = medzoo.create_model(args)

    if args.cuda:
        model = model.cuda()
        print("Model transferred in GPU.....")

    print("START TRAINING...")
    for epoch in range(1, args.nEpochs + 1):
        train(args, model, training_generator, optimizer, epoch, writer)
        val_metrics, confusion_matrix = validation(args, model, val_generator,
                                                   epoch, writer)
Example #5
    def __init__(self, mode, sub_task='lung', split=0.2, fold=0, n_classes=3, samples=10, dataset_path='../datasets',
                 crop_dim=(32, 32, 32)):
        print("COVID SEGMENTATION DATASET")
        self.CLASSES = n_classes
        self.fold = int(fold)
        self.crop_size = crop_dim
        self.full_vol_dim = (512, 512, 301)  # width, height, slice
        self.mode = mode
        self.full_volume = None
        self.affine = None
        self.list = []
        self.samples = samples
        subvol = '_vol_' + str(crop_dim[0]) + 'x' + str(crop_dim[1]) + 'x' + str(crop_dim[2])

        self.sub_vol_path = dataset_path + '/covid_segmap_dataset/generated/' + mode + subvol + '/'
        utils.make_dirs(self.sub_vol_path)

        self.train_images, self.train_labels, self.val_labels, self.val_images = [], [], [], []
        list_images = sorted(
            glob.glob(os.path.join(dataset_path, 'covid_segmap_dataset/COVID-19-CT-Seg_20cases/*')))

        if sub_task == 'lung':
            list_labels = sorted(glob.glob(os.path.join(dataset_path, 'covid_segmap_dataset/Lung_Mask/*')))
        elif sub_task == 'infection':
            list_labels = sorted(glob.glob(os.path.join(dataset_path, 'covid_segmap_dataset/Infection_Mask/*')))
        else:
            list_labels = sorted(
                glob.glob(os.path.join(dataset_path, 'covid_segmap_dataset/Lung_and_Infection_Mask/*')))
        len_of_data = len(list_images)

        fold_size = int(split * len_of_data)
        for i in range(len_of_data):
            if self.fold * fold_size <= i < (self.fold + 1) * fold_size:
                self.train_images.append(list_images[i])
                self.train_labels.append(list_labels[i])
            else:
                self.val_images.append(list_images[i])
                self.val_labels.append(list_labels[i])

        if (mode == 'train'):
            self.list_IDs = self.train_images
            self.list_labels = self.train_labels

        elif (mode == 'val'):
            self.list_IDs = self.val_images
            self.list_labels = self.val_labels

        self.list = create_sub_volumes(self.list_IDs, self.list_labels, dataset_name='covid19seg', mode=mode,
                                       samples=samples, full_vol_dim=self.full_vol_dim, crop_size=self.crop_size,
                                       sub_vol_path=self.sub_vol_path)
        print("{} SAMPLES =  {}".format(mode, len(self.list)))
Example #6
    def __init__(self, args):

        name_model = args.model + "_" + args.dataset_name + "_" + utils.datestr()
        self.writer = SummaryWriter(log_dir=args.log_dir + name_model,
                                    comment=name_model)

        utils.make_dirs(args.save)
        self.csv_train, self.csv_val = self.create_stats_files(args.save)
        self.dataset_name = args.dataset_name
        self.classes = args.classes
        self.label_names = dict_class_names[args.dataset_name]

        self.data = self.create_data_structure()
Example #7
def main():
    args = get_arguments()
    utils.reproducibility(args, seed)
    utils.make_dirs(args.save)

    training_generator, val_generator, full_volume, affine = medical_loaders.generate_datasets(args,
                                                                                               path='.././datasets')
    model, optimizer = medzoo.create_model(args)
    criterion = DiceLoss(classes=args.classes)

    if args.cuda:
        model = model.cuda()

    trainer = train.Trainer(args, model, criterion, optimizer, train_data_loader=training_generator,
                            valid_data_loader=val_generator)
    trainer.training()
Example #8
    def __init__(self, args, mode, dataset_path='.././datasets', split_idx=150, crop_dim=(512, 512), samples=100,
                 classes=7,
                 save=True):
        """
        :param mode: 'train','val'
        :param image_paths: image dataset paths
        :param label_paths: label dataset paths
        :param crop_dim: 2 element tuple to decide crop values
        :param samples: number of sub-grids to create(patches of the input img)
        """
        image_paths = sorted(glob.glob(dataset_path + "/MICCAI_2019_pathology_challenge/Train Imgs/Train Imgs/*.jpg"))
        label_paths = sorted(glob.glob(dataset_path + "/MICCAI_2019_pathology_challenge/Labels/*.png"))

        image_paths, label_paths = utils.shuffle_lists(image_paths, label_paths, seed=17)
        self.full_volume = None
        self.affine = None

        self.slices = 244  # dataset instances
        self.mode = mode
        self.crop_dim = crop_dim
        self.sample_list = []
        self.samples = samples
        self.save = save
        self.root = dataset_path
        self.per_image_sample = int(self.samples / self.slices)
        if self.per_image_sample < 1:
            self.per_image_sample = 1

        print("per image sampleeeeee", self.per_image_sample)

        sub_grid = '_2dgrid_' + str(crop_dim[0]) + 'x' + str(crop_dim[1])

        if self.save:
            self.sub_vol_path = self.root + '/MICCAI_2019_pathology_challenge/generated/' + mode + sub_grid + '/'
            utils.make_dirs(self.sub_vol_path)

        if self.mode == 'train':
            self.list_imgs = image_paths[0:split_idx]
            self.list_labels = label_paths[0:split_idx]
        elif self.mode == 'val':
            self.list_imgs = image_paths[split_idx:]
            self.list_labels = label_paths[split_idx:]

        self.generate_samples()
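
# --- Illustrative sketch, not part of the example above ---
# A plausible way to cut a 2D image into fixed-size patches, showing what the
# `crop_dim` and `samples` parameters control; `grid_patches_2d` is a hypothetical
# helper, not the repo's generate_samples() implementation.
import numpy as np

def grid_patches_2d(image, crop_dim=(512, 512)):
    """Return all non-overlapping crop_dim patches that fit inside the image."""
    h, w = image.shape[:2]
    ph, pw = crop_dim
    patches = []
    for y in range(0, h - ph + 1, ph):
        for x in range(0, w - pw + 1, pw):
            patches.append(image[y:y + ph, x:x + pw])
    return patches

# A fake 1024x1536 RGB slide yields a 2x3 grid of 512x512 patches.
demo = np.zeros((1024, 1536, 3), dtype=np.uint8)
print(len(grid_patches_2d(demo)))  # -> 6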
Example #9
    def __init__(self,
                 args,
                 dataset_path='./data',
                 voxels_space=(2, 2, 2),
                 modalities=2,
                 to_canonical=False,
                 save=True):
        """
        :param dataset_path: the extracted path that contains the desired images
        :param voxels_space: for reshampling the voxel space
        :param modalities: 1 for T1 only, 2 for T1 and T2
        :param to_canonical: If you want to convert the coordinates to RAS
        for more info on this advice here https://www.slicer.org/wiki/Coordinate_systems
        :param save: to save the generated data offline for faster reading
        and not load RAM
        """
        self.root = str(dataset_path)
        self.modalities = modalities
        self.pathT1 = self.root + '/ixi/T1/'
        self.pathT2 = self.root + '/ixi/T2/'
        self.save = save
        self.CLASSES = 4
        self.full_vol_dim = (150, 256, 256)  # slice, width, height
        self.voxels_space = voxels_space
        self.modalities = str(modalities)
        self.list = []
        self.full_volume = None
        self.to_canonical = to_canonical
        self.affine = None

        subvol = '_vol_' + str(self.voxels_space[0]) + 'x' + str(
            self.voxels_space[1]) + 'x' + str(self.voxels_space[2])

        if self.save:
            self.sub_vol_path = self.root + '/ixi/generated/' + subvol + '/'
            utils.make_dirs(self.sub_vol_path)
        print(self.pathT1)
        self.list_IDsT1 = sorted(
            glob.glob(os.path.join(self.pathT1, '*T1.nii.gz')))
        self.list_IDsT2 = sorted(
            glob.glob(os.path.join(self.pathT2, '*T2.nii.gz')))
        self.affine = img_loader.load_affine_matrix(self.list_IDsT1[0])
        self.create_input_data()
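
# --- Illustrative sketch, not part of the example above ---
# How the `to_canonical` and `voxels_space` options are commonly realised with
# nibabel; this is an assumption about the general technique, not the repo's
# loader code. A synthetic in-memory image keeps the snippet self-contained.
import numpy as np
import nibabel as nib
from nibabel.processing import resample_to_output

vol = nib.Nifti1Image(np.random.rand(32, 64, 64).astype(np.float32), affine=np.eye(4))

# Reorient to the closest canonical (RAS) orientation.
vol_ras = nib.as_closest_canonical(vol)

# Resample to an isotropic 2x2x2 mm voxel grid.
vol_2mm = resample_to_output(vol_ras, voxel_sizes=(2, 2, 2))
print(vol_2mm.shape)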
Example #10
def main():
    args = get_arguments()
    utils.reproducibility(args, seed)
    utils.make_dirs(args.save)

    training_generator, val_generator, full_volume, affine = medical_loaders.generate_datasets(
        args, path='.././datasets')
    model, optimizer = medzoo.create_model(args)
    # Alternative: DiceLoss(classes=args.classes, skip_index_after=2,
    #                       weight=torch.tensor([0.00001, 1, 1, 1]).cuda())
    criterion = DiceLoss(classes=args.classes)

    if args.cuda:
        model = model.cuda()
        print("Model transferred in GPU.....")

    trainer = train.Trainer(args,
                            model,
                            criterion,
                            optimizer,
                            train_data_loader=training_generator,
                            valid_data_loader=val_generator)
    print("START TRAINING...")
    trainer.training()
Example #11
    def __init__(self,
                 args,
                 mode,
                 dataset_path='./datasets',
                 crop_dim=(32, 32, 32),
                 split_id=1,
                 samples=1000,
                 load=False):
        """
        :param mode: 'train','val','test'
        :param dataset_path: root dataset folder
        :param crop_dim: subvolume tuple
        :param split_id: index that splits the subject lists into train/val
        :param samples: number of sub-volumes that you want to create
        (A small augmentation sketch follows this example.)
        """
        load = False  # NOTE: hard-coded override, so pre-generated data is never loaded here
        self.mode = mode
        self.root = str(dataset_path)
        self.training_path = self.root + '/iseg_2017/iSeg-2017-Training/'
        self.testing_path = self.root + '/iseg_2017/iSeg-2017-Testing/'
        self.CLASSES = 4
        self.full_vol_dim = (144, 192, 256)  # slice, width, height
        self.threshold = args.threshold
        self.normalization = args.normalization
        self.augmentation = args.augmentation
        self.crop_size = crop_dim
        self.list = []
        self.samples = samples
        self.full_volume = None
        # self.save_name = self.root + '/iseg_2017/iSeg-2017-Training/iseg2017-list-' + mode + '-samples-' + str(
        self.save_name = self.root + '/iseg2017-list-' + mode + '-samples-' + str(
            samples) + '.txt'
        if self.augmentation:
            self.transform = augment3D.RandomChoice(
                transforms=[
                    augment3D.GaussianNoise(mean=0, std=0.01),
                    augment3D.RandomFlip(),
                    augment3D.ElasticTransform(),
                ],
                p=0.5,
            )
        if load:
            ## load pre-generated data
            self.list = utils.load_list(self.save_name)
            list_IDsT1 = sorted(
                glob.glob(os.path.join(self.training_path, '*T1.img')))
            self.affine = img_loader.load_affine_matrix(list_IDsT1[0])
            return

        subvol = '_vol_' + str(crop_dim[0]) + 'x' + str(
            crop_dim[1]) + 'x' + str(crop_dim[2])
        self.sub_vol_path = self.root + '/iseg_2017/generated/' + mode + subvol + '/'

        utils.make_dirs(self.sub_vol_path)
        list_IDsT1 = sorted(
            glob.glob(os.path.join(self.training_path, '*T1.img')))
        list_IDsT2 = sorted(
            glob.glob(os.path.join(self.training_path, '*T2.img')))
        labels = sorted(
            glob.glob(os.path.join(self.training_path, '*label.img')))
        self.affine = img_loader.load_affine_matrix(list_IDsT1[0])

        if self.mode == 'train':

            list_IDsT1 = list_IDsT1[:split_id]
            list_IDsT2 = list_IDsT2[:split_id]
            labels = labels[:split_id]

            self.list = create_sub_volumes(list_IDsT1,
                                           list_IDsT2,
                                           labels,
                                           dataset_name="iseg2017",
                                           mode=mode,
                                           samples=samples,
                                           full_vol_dim=self.full_vol_dim,
                                           crop_size=self.crop_size,
                                           sub_vol_path=self.sub_vol_path,
                                           th_percent=self.threshold,
                                           normalization=args.normalization)

        elif self.mode == 'val':
            utils.make_dirs(self.sub_vol_path)
            list_IDsT1 = list_IDsT1[split_id:]
            list_IDsT2 = list_IDsT2[split_id:]
            labels = labels[split_id:]
            self.list = create_sub_volumes(list_IDsT1,
                                           list_IDsT2,
                                           labels,
                                           dataset_name="iseg2017",
                                           mode=mode,
                                           samples=samples,
                                           full_vol_dim=self.full_vol_dim,
                                           crop_size=self.crop_size,
                                           sub_vol_path=self.sub_vol_path,
                                           th_percent=self.threshold,
                                           normalization=args.normalization)

            self.full_volume = get_viz_set(list_IDsT1,
                                           list_IDsT2,
                                           labels,
                                           dataset_name="iseg2017")

        elif self.mode == 'test':
            self.list_IDsT1 = sorted(
                glob.glob(os.path.join(self.testing_path, '*T1.img')))
            self.list_IDsT2 = sorted(
                glob.glob(os.path.join(self.testing_path, '*T2.img')))
            self.labels = None
        elif self.mode == 'viz':
            list_IDsT1 = list_IDsT1[split_id:]
            list_IDsT2 = list_IDsT2[split_id:]
            labels = labels[split_id:]
            self.full_volume = get_viz_set(list_IDsT1,
                                           list_IDsT2,
                                           labels,
                                           dataset_name="iseg2017")
            self.list = []
        utils.save_list(self.save_name, self.list)
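
# --- Illustrative sketch, not part of the example above ---
# A minimal stand-in for the RandomChoice-style augmentation wrapper used above:
# with probability p, apply one randomly chosen transform, otherwise return the
# input unchanged. `RandomChoiceSketch` is hypothetical; the actual
# augment3D.RandomChoice implementation is not shown in this example.
import random
import numpy as np

class RandomChoiceSketch:
    def __init__(self, transforms, p=0.5):
        self.transforms = transforms
        self.p = p

    def __call__(self, volume):
        if random.random() < self.p:
            transform = random.choice(self.transforms)
            return transform(volume)
        return volume

add_noise = lambda v: v + np.random.normal(0.0, 0.01, v.shape)
random_flip = lambda v: np.flip(v, axis=random.randrange(v.ndim)).copy()

aug = RandomChoiceSketch(transforms=[add_noise, random_flip], p=0.5)
augmented = aug(np.zeros((32, 32, 32), dtype=np.float32))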
Example #12
    def __init__(
        self,
        args,
        mode,
        dataset_path="./datasets",
        classes=5,
        crop_dim=(200, 200, 150),
        split_idx=260,
        samples=10,
        load=False,
    ):
        """
        :param mode: 'train','val','test'
        :param dataset_path: root dataset folder
        :param crop_dim: subvolume tuple
        :param split_idx: 1 to 10 values
        :param samples: number of sub-volumes that you want to create
        """
        self.mode = mode
        self.root = str(dataset_path)
        self.training_path = self.root + "/brats2019/MICCAI_BraTS_2019_Data_Training/"
        self.testing_path = self.root + "/brats2019/MICCAI_BraTS_2019_Data_Validation/"
        self.full_vol_dim = (240, 240, 155)  # slice, width, height
        self.crop_size = crop_dim
        self.threshold = args.threshold
        self.normalization = args.normalization
        self.augmentation = args.augmentation
        self.list = []
        self.samples = samples
        self.full_volume = None
        self.classes = classes
        if self.augmentation:
            self.transform = augment3D.RandomChoice(
                transforms=[
                    augment3D.GaussianNoise(mean=0, std=0.01),
                    augment3D.RandomFlip(),
                    augment3D.ElasticTransform(),
                ],
                p=0.5,
            )
        self.save_name = os.path.join(
            self.root, "brats2019", f"brats2019-list-{mode}-samples-{samples}.txt"
        )

        if load:
            ## load pre-generated data
            self.list = utils.load_list(self.save_name)
            list_IDsT1 = sorted(
                glob.glob(os.path.join(self.training_path, "*GG/*/*t1.nii.gz"))
            )
            self.affine = img_loader.load_affine_matrix(list_IDsT1[0])
            return

        subvol = (
            "_vol_" + str(crop_dim[0]) + "x" + str(crop_dim[1]) + "x" + str(crop_dim[2])
        )
        self.sub_vol_path = (
            self.root
            + "/brats2019/MICCAI_BraTS_2019_Data_Training/generated/"
            + mode
            + subvol
            + "/"
        )
        utils.make_dirs(self.sub_vol_path)

        # split HGG and LGG
        HGG_IDsT1 = sorted(
            glob.glob(os.path.join(self.training_path, "HGG/*/*t1.nii.gz"))
        )
        HGG_IDsT1ce = sorted(
            glob.glob(os.path.join(self.training_path, "HGG/*/*t1ce.nii.gz"))
        )
        HGG_IDsT2 = sorted(
            glob.glob(os.path.join(self.training_path, "HGG/*/*t2.nii.gz"))
        )
        HGG_IDsFlair = sorted(
            glob.glob(os.path.join(self.training_path, "HGG/*/*_flair.nii.gz"))
        )
        HGG_labels = sorted(
            glob.glob(os.path.join(self.training_path, "HGG/*/*_seg.nii.gz"))
        )

        LGG_IDsT1 = sorted(
            glob.glob(os.path.join(self.training_path, "LGG/*/*t1.nii.gz"))
        )
        LGG_IDsT1ce = sorted(
            glob.glob(os.path.join(self.training_path, "LGG/*/*t1ce.nii.gz"))
        )
        LGG_IDsT2 = sorted(
            glob.glob(os.path.join(self.training_path, "LGG/*/*t2.nii.gz"))
        )
        LGG_IDsFlair = sorted(
            glob.glob(os.path.join(self.training_path, "LGG/*/*_flair.nii.gz"))
        )
        LGG_labels = sorted(
            glob.glob(os.path.join(self.training_path, "LGG/*/*_seg.nii.gz"))
        )

        (
            HGG_IDsT1,
            HGG_IDsT1ce,
            HGG_IDsT2,
            HGG_IDsFlair,
            HGG_labels,
        ) = utils.shuffle_lists(
            HGG_IDsT1, HGG_IDsT1ce, HGG_IDsT2, HGG_IDsFlair, HGG_labels, seed=17
        )

        (
            LGG_IDsT1,
            LGG_IDsT1ce,
            LGG_IDsT2,
            LGG_IDsFlair,
            LGG_labels,
        ) = utils.shuffle_lists(
            LGG_IDsT1, LGG_IDsT1ce, LGG_IDsT2, LGG_IDsFlair, LGG_labels, seed=17
        )

        self.affine = img_loader.load_affine_matrix((HGG_IDsT1 + LGG_IDsT1)[0])

        hgg_len = len(HGG_IDsT1)
        lgg_len = len(LGG_IDsT1)
        print("Brats2019, Training HGG:", hgg_len)
        print("Brats2019, Training LGG:", lgg_len)
        print("Brats2019, Training total:", hgg_len + lgg_len)

        hgg_split = int(hgg_len * 0.8)
        lgg_split = int(lgg_len * 0.8)

        if self.mode == "train":
            list_IDsT1 = HGG_IDsT1[:hgg_split] + LGG_IDsT1[:lgg_split]
            list_IDsT1ce = HGG_IDsT1ce[:hgg_split] + LGG_IDsT1ce[:lgg_split]
            list_IDsT2 = HGG_IDsT2[:hgg_split] + LGG_IDsT2[:lgg_split]
            list_IDsFlair = HGG_IDsFlair[:hgg_split] + LGG_IDsFlair[:lgg_split]
            labels = HGG_labels[:hgg_split] + LGG_labels[:lgg_split]
            self.list = create_sub_volumes(
                list_IDsT1,
                list_IDsT1ce,
                list_IDsT2,
                list_IDsFlair,
                labels,
                dataset_name="brats2019",
                mode=mode,
                samples=samples,
                full_vol_dim=self.full_vol_dim,
                crop_size=self.crop_size,
                sub_vol_path=self.sub_vol_path,
                th_percent=self.threshold,
            )

        elif self.mode == "val":
            list_IDsT1 = HGG_IDsT1[hgg_split:] + LGG_IDsT1[lgg_split:]
            list_IDsT1ce = HGG_IDsT1ce[hgg_split:] + LGG_IDsT1ce[lgg_split:]
            list_IDsT2 = HGG_IDsT2[hgg_split:] + LGG_IDsT2[lgg_split:]
            list_IDsFlair = HGG_IDsFlair[hgg_split:] + LGG_IDsFlair[lgg_split:]
            labels = HGG_labels[hgg_split:] + LGG_labels[lgg_split:]
            self.list = create_sub_volumes(
                list_IDsT1,
                list_IDsT1ce,
                list_IDsT2,
                list_IDsFlair,
                labels,
                dataset_name="brats2019",
                mode=mode,
                samples=samples,
                full_vol_dim=self.full_vol_dim,
                crop_size=self.crop_size,
                sub_vol_path=self.sub_vol_path,
                th_percent=self.threshold,
            )

        elif self.mode == "test":
            # self.list_IDsT1 = sorted(glob.glob(os.path.join(self.testing_path, '*GG/*/*t1.nii.gz')))
            # self.list_IDsT1ce = sorted(glob.glob(os.path.join(self.testing_path, '*GG/*/*t1ce.nii.gz')))
            # self.list_IDsT2 = sorted(glob.glob(os.path.join(self.testing_path, '*GG/*/*t2.nii.gz')))
            # self.list_IDsFlair = sorted(glob.glob(os.path.join(self.testing_path, '*GG/*/*_flair.nii.gz')))
            # self.labels = None

            list_IDsT1 = HGG_IDsT1[hgg_split:] + LGG_IDsT1[lgg_split:]
            list_IDsT1ce = HGG_IDsT1ce[hgg_split:] + LGG_IDsT1ce[lgg_split:]
            list_IDsT2 = HGG_IDsT2[hgg_split:] + LGG_IDsT2[lgg_split:]
            list_IDsFlair = HGG_IDsFlair[hgg_split:] + LGG_IDsFlair[lgg_split:]
            labels = HGG_labels[hgg_split:] + LGG_labels[lgg_split:]

            self.list = create_non_overlapping_sub_volumes(
                list_IDsT1,
                list_IDsT1ce,
                list_IDsT2,
                list_IDsFlair,
                labels,
                dataset_name="brats2019",
                mode=mode,
                samples=samples,
                full_vol_dim=self.full_vol_dim,
                crop_size=self.crop_size,
                sub_vol_path=self.sub_vol_path,
                th_percent=self.threshold,
            )

        utils.save_list(self.save_name, self.list)
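
# --- Illustrative sketch, not part of the example above ---
# The general idea behind create_sub_volumes() with a `th_percent` threshold, as an
# assumption about the technique: draw random crops and keep only those whose label
# crop has at least th_percent non-background voxels. `sample_subvolumes` is a
# hypothetical helper, not the repo's implementation.
import numpy as np

def sample_subvolumes(volume, label, crop_size=(64, 64, 64), samples=4,
                      th_percent=0.1, max_attempts=10000):
    crops = []
    attempts = 0
    while len(crops) < samples and attempts < max_attempts:
        attempts += 1
        start = [np.random.randint(0, d - c + 1) for d, c in zip(volume.shape, crop_size)]
        sl = tuple(slice(s, s + c) for s, c in zip(start, crop_size))
        label_crop = label[sl]
        if (label_crop > 0).mean() >= th_percent:
            crops.append((volume[sl], label_crop))
    return crops

# Toy volume with a foreground cube in the centre of the label map.
vol = np.random.rand(128, 128, 128).astype(np.float32)
lab = np.zeros_like(vol, dtype=np.int64)
lab[32:96, 32:96, 32:96] = 1
print(len(sample_subvolumes(vol, lab)))  # -> 4 (with overwhelming probability)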
Example #13
def test():
    # args = mrbrains9_arguments(loadData=True)

    utils.reproducibility(args, seed)
    utils.make_dirs(args.save)

    params = {"batch_size": args.batchSz, "shuffle": True, "num_workers": 2}
    samples_train = args.samples_train
    samples_val = args.samples_val
    test_loader = MRIDatasetMRBRAINS2018(
        args,
        "test",
        dataset_path=dataset_dir,
        dim=args.dim,
        split_id=0,
        samples=samples_train,
        load=args.loadData,
    )

    model_name = args.model
    lr = args.lr
    in_channels = args.inChannels
    num_classes = args.classes
    weight_decay = 0.0000000001
    print("Building Model . . . . . . . ." + model_name)
    model = UNet3D(in_channels=in_channels, n_classes=num_classes, base_n_filter=8)
    print(
        model_name,
        "Number of params: {}".format(
            sum([p.data.nelement() for p in model.parameters()])
        ),
    )

    model.restore_checkpoint(
        "/home/kyle/results/UNET3D/mrbrains9_148_09-08_17-46/mrbrains9_148_09-08_17-46_BEST.pth"
    )

    # model = model.cuda()
    # print("Model transferred in GPU.....")

    print("TESTING...")

    # [[37507023, 290552, 0, 25074, 30, 1323040, 134, 20823, 10884, 0],
    #  [256417, 16475613, 1592518, 1920243, 2259, 2491037, 67078, 61665, 497, 0],
    #  [60651, 3655997, 594069, 16095428, 2494541, 102916, 456685, 4495, 4150, 0],
    #  [1225472, 1183528, 24771, 74524, 614, 13215040, 2649646, 104672, 72727, 0],
    #  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    #  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    #  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    #  [91439, 131011, 492, 8209, 0, 784077, 1, 6204313, 249132, 0],
    #  [9185, 3047, 0, 13415, 0, 62690, 0, 12931, 1513371, 0],
    #  [1471, 46677, 0, 5422, 10476, 32999, 1025, 12, 0, 0],

    model.eval()

    confusion_matrix = [[0] * (num_classes * 2) for i in range(num_classes * 2)]

    for batch_idx, input_tuple in enumerate(test_loader):
        with torch.no_grad():
            img_t1, img_t2, img_t3, target = input_tuple

            target = torch.reshape(torch.from_numpy(target), (-1, 1, 48, 48, 48))
            img_t1 = torch.reshape(torch.from_numpy(img_t1), (-1, 1, 48, 48, 48))
            img_t2 = torch.reshape(torch.from_numpy(img_t2), (-1, 1, 48, 48, 48))
            img_t3 = torch.reshape(torch.from_numpy(img_t3), (-1, 1, 48, 48, 48))

            input_tensor = torch.cat((img_t1, img_t2, img_t3), dim=1)

            input_tensor.requires_grad = False

            output = model(input_tensor)

            output = torch.argmax(output, dim=1)
            output = torch.reshape(output, (-1, 1, 48, 48, 48))

            assert target.size() == output.size()

            output = torch.reshape(output, (-1,)).tolist()
            target = torch.reshape(target, (-1,)).tolist()

            assert len(output) == len(target)

            for gt, pred in zip(target, output):
                confusion_matrix[int(gt)][int(pred)] += 1

    pprint(confusion_matrix)
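
# --- Illustrative sketch, not part of the example above ---
# Per-class Dice scores can be derived from the accumulated confusion matrix
# (rows = ground truth, columns = prediction, as built above) via
# Dice_c = 2*TP_c / (2*TP_c + FP_c + FN_c). `dice_from_confusion` is a hypothetical
# helper, shown here on a tiny synthetic 3-class matrix.
def dice_from_confusion(cm):
    n = len(cm)
    scores = []
    for c in range(n):
        tp = cm[c][c]
        fp = sum(cm[r][c] for r in range(n)) - tp
        fn = sum(cm[c]) - tp
        denom = 2 * tp + fp + fn
        scores.append(2.0 * tp / denom if denom > 0 else 0.0)
    return scores

example_cm = [[90, 5, 5],
              [10, 80, 10],
              [0, 20, 80]]
print(dice_from_confusion(example_cm))  # -> [0.9, 0.780..., 0.820...]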
Example #14
    def __init__(
            self,
            args,
            mode,
            dataset_path="./datasets",
            crop_dim=(32, 32, 32),
            split_id=1,
            samples=1000,
            load=False,
    ):
        """
        :param mode: 'train','val','test'
        :param dataset_path: root dataset folder
        :param crop_dim: subvolume tuple
        :param split_id: unused here; the train/val split is driven by args.fold_id
        :param samples: number of sub-volumes that you want to create
        (A small __getitem__ sketch follows this example.)
        """
        # split_id = int(split_id)
        fold_id = int(args.fold_id)
        print(f"using fold_id {fold_id}")
        self.mode = mode
        self.root = str(dataset_path)
        self.training_path = self.root + "/iseg_2019/iSeg-2019-Training/"
        self.testing_path = self.root + "/iseg_2019/iSeg-2019-Validation/"
        self.CLASSES = 4
        self.full_vol_dim = (144, 192, 256)  # slice, width, height
        self.crop_size = crop_dim
        self.threshold = args.threshold
        self.normalization = args.normalization
        self.augmentation = args.augmentation
        self.list = []
        self.samples = int(samples)
        self.full_volume = None
        self.save_name = (self.root + "/iseg_2019/iseg2019-list-" + mode +
                          "-samples-" + str(samples) + ".txt")
        if self.augmentation:
            self.transform = augment3D.RandomChoice(
                transforms=[
                    augment3D.GaussianNoise(mean=0, std=0.01),
                    augment3D.RandomFlip(),
                    augment3D.ElasticTransform(),
                ],
                p=0.5,
            )
        if load:
            ## load pre-generated data
            self.list = utils.load_list(self.save_name)
            list_IDsT1 = sorted(
                glob.glob(os.path.join(self.training_path, "*T1.img")))
            self.affine = img_loader.load_affine_matrix(list_IDsT1[0])
            return

        subvol = ("_vol_" + str(crop_dim[0]) + "x" + str(crop_dim[1]) + "x" +
                  str(crop_dim[2]))
        self.sub_vol_path = self.root + "/iseg_2019/generated/" + mode + subvol + "/"
        utils.make_dirs(self.sub_vol_path)

        list_IDsT1 = sorted(
            glob.glob(os.path.join(self.training_path, "*T1.img")))
        list_IDsT2 = sorted(
            glob.glob(os.path.join(self.training_path, "*T2.img")))
        labels = sorted(
            glob.glob(os.path.join(self.training_path, "*label.img")))
        print(self.training_path)
        self.affine = img_loader.load_affine_matrix(list_IDsT1[0])

        if self.mode == "train":
            # custom code
            # list_IDsT1 = list_IDsT1[:split_id]
            # list_IDsT2 = list_IDsT2[:split_id]
            # labels = labels[:split_id]

            list_IDsT1 = [x for x in list_IDsT1 if f"-{fold_id}-" not in x]
            list_IDsT2 = [x for x in list_IDsT2 if f"-{fold_id}-" not in x]
            labels = [x for x in labels if f"-{fold_id}-" not in x]

            assert len(labels) == len(list_IDsT1)
            assert len(labels) == len(list_IDsT2)
            assert len(labels) == 9

            self.list = create_sub_volumes(
                list_IDsT1,
                list_IDsT2,
                labels,
                dataset_name="iseg2019",
                mode=mode,
                samples=samples,
                full_vol_dim=self.full_vol_dim,
                crop_size=self.crop_size,
                sub_vol_path=self.sub_vol_path,
                th_percent=self.threshold,
            )

        elif self.mode == "val":
            # list_IDsT1 = list_IDsT1[split_id:]
            # list_IDsT2 = list_IDsT2[split_id:]
            # labels = labels[split_id:]

            list_IDsT1 = [x for x in list_IDsT1 if f"-{fold_id}-" in x]
            list_IDsT2 = [x for x in list_IDsT2 if f"-{fold_id}-" in x]
            labels = [x for x in labels if f"-{fold_id}-" in x]
            assert len(labels) == len(list_IDsT1)
            assert len(labels) == len(list_IDsT2)
            assert len(labels) == 1

            self.list = create_sub_volumes(
                list_IDsT1,
                list_IDsT2,
                labels,
                dataset_name="iseg2019",
                mode=mode,
                samples=samples,
                full_vol_dim=self.full_vol_dim,
                crop_size=self.crop_size,
                sub_vol_path=self.sub_vol_path,
                th_percent=self.threshold,
            )

            self.full_volume = get_viz_set(list_IDsT1,
                                           list_IDsT2,
                                           labels,
                                           dataset_name="iseg2019")

        elif self.mode == "test":
            # self.list_IDsT1 = sorted(glob.glob(os.path.join(self.testing_path, '*T1.img')))
            # self.list_IDsT2 = sorted(glob.glob(os.path.join(self.testing_path, '*T2.img')))
            # self.labels = None
            # todo inference here

            list_IDsT1 = [x for x in list_IDsT1 if f"-{fold_id}-" in x]
            list_IDsT2 = [x for x in list_IDsT2 if f"-{fold_id}-" in x]
            labels = [x for x in labels if f"-{fold_id}-" in x]
            assert len(labels) == len(list_IDsT1)
            assert len(labels) == len(list_IDsT2)
            assert len(labels) == 1

            self.list = create_non_overlapping_sub_volumes(
                list_IDsT1,
                list_IDsT2,
                labels,
                dataset_name="iseg2019",
                mode=mode,
                samples=samples,
                full_vol_dim=self.full_vol_dim,
                crop_size=self.crop_size,
                sub_vol_path=self.sub_vol_path,
                th_percent=self.threshold,
            )

            self.full_volume = get_viz_set(list_IDsT1,
                                           list_IDsT2,
                                           labels,
                                           dataset_name="iseg2019")

        utils.save_list(self.save_name, self.list)
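
# --- Illustrative sketch, not part of the example above ---
# One plausible __getitem__ for a loader whose `self.list` holds paths to
# pre-saved sub-volumes (as the sub_vol_path logic above suggests). The
# (T1 path, T2 path, label path) tuple layout and the `get_item` helper are
# assumptions; a temporary directory keeps the snippet self-contained.
import os
import tempfile
import numpy as np

tmp_dir = tempfile.mkdtemp()
paths = [os.path.join(tmp_dir, name) for name in ("sub_t1.npy", "sub_t2.npy", "sub_seg.npy")]
np.save(paths[0], np.random.rand(32, 32, 32).astype(np.float32))
np.save(paths[1], np.random.rand(32, 32, 32).astype(np.float32))
np.save(paths[2], np.zeros((32, 32, 32), dtype=np.int64))
sample_list = [tuple(paths)]

def get_item(index):
    t1_path, t2_path, seg_path = sample_list[index]
    image = np.stack([np.load(t1_path), np.load(t2_path)], axis=0)  # channels-first (T1, T2)
    return image, np.load(seg_path)

image, segmentation = get_item(0)
print(image.shape, segmentation.shape)  # (2, 32, 32, 32) (32, 32, 32)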
Example #15
    def __init__(self,
                 args,
                 mode,
                 dataset_path='./datasets',
                 classes=5,
                 crop_dim=(32, 32, 32),
                 split_idx=10,
                 samples=10,
                 load=False):
        """
        :param mode: 'train','val','test'
        :param dataset_path: root dataset folder
        :param crop_dim: subvolume tuple
        :param split_idx: 1 to 10 values
        :param samples: number of sub-volumes that you want to create
        """
        self.mode = mode
        self.root = str(dataset_path)
        self.training_path = self.root + '/MICCAI_BraTS_2018_Data_Training/'
        self.testing_path = self.root + ' '  # placeholder; point this at the BraTS 2018 validation/testing folder
        self.CLASSES = 4
        self.full_vol_dim = (240, 240, 155)  # slice, width, height
        self.crop_size = crop_dim
        self.threshold = args.threshold
        self.normalization = args.normalization
        self.augmentation = args.augmentation
        self.list = []
        self.samples = samples
        self.full_volume = None
        self.classes = classes
        self.save_name = self.root + '/MICCAI_BraTS_2018_Data_Training/brats2018-list-' + mode + '-samples-' + str(
            samples) + '.txt'
        if self.augmentation:
            self.transform = augment3D.RandomChoice(
                transforms=[
                    augment3D.GaussianNoise(mean=0, std=0.01),
                    augment3D.RandomFlip(),
                    augment3D.ElasticTransform(),
                ],
                p=0.5,
            )
        if load:
            ## load pre-generated data
            list_IDsT1 = sorted(
                glob.glob(os.path.join(self.training_path,
                                       '*GG/*/*t1.nii.gz')))
            self.affine = img_loader.load_affine_matrix(list_IDsT1[0])
            self.list = utils.load_list(self.save_name)
            return

        subvol = '_vol_' + str(crop_dim[0]) + 'x' + str(
            crop_dim[1]) + 'x' + str(crop_dim[2])
        self.sub_vol_path = self.root + '/MICCAI_BraTS_2018_Data_Training/generated/' + mode + subvol + '/'
        utils.make_dirs(self.sub_vol_path)

        list_IDsT1 = sorted(
            glob.glob(os.path.join(self.training_path, '*GG/*/*t1.nii.gz')))
        list_IDsT1ce = sorted(
            glob.glob(os.path.join(self.training_path, '*GG/*/*t1ce.nii.gz')))
        list_IDsT2 = sorted(
            glob.glob(os.path.join(self.training_path, '*GG/*/*t2.nii.gz')))
        list_IDsFlair = sorted(
            glob.glob(os.path.join(self.training_path,
                                   '*GG/*/*_flair.nii.gz')))
        labels = sorted(
            glob.glob(os.path.join(self.training_path, '*GG/*/*_seg.nii.gz')))
        # print(len(list_IDsT1),len(list_IDsT2),len(list_IDsFlair),len(labels))

        self.affine = img_loader.load_affine_matrix(list_IDsT1[0])

        if self.mode == 'train':
            list_IDsT1 = list_IDsT1[:split_idx]
            list_IDsT1ce = list_IDsT1ce[:split_idx]
            list_IDsT2 = list_IDsT2[:split_idx]
            list_IDsFlair = list_IDsFlair[:split_idx]
            labels = labels[:split_idx]

            self.list = create_sub_volumes(list_IDsT1,
                                           list_IDsT1ce,
                                           list_IDsT2,
                                           list_IDsFlair,
                                           labels,
                                           dataset_name="brats2018",
                                           mode=mode,
                                           samples=samples,
                                           full_vol_dim=self.full_vol_dim,
                                           crop_size=self.crop_size,
                                           sub_vol_path=self.sub_vol_path,
                                           normalization=self.normalization,
                                           th_percent=self.threshold)
        elif self.mode == 'val':
            list_IDsT1 = list_IDsT1[split_idx:]
            list_IDsT1ce = list_IDsT1ce[split_idx:]
            list_IDsT2 = list_IDsT2[split_idx:]
            list_IDsFlair = list_IDsFlair[split_idx:]
            labels = labels[split_idx:]
            self.list = create_sub_volumes(list_IDsT1,
                                           list_IDsT1ce,
                                           list_IDsT2,
                                           list_IDsFlair,
                                           labels,
                                           dataset_name="brats2018",
                                           mode=mode,
                                           samples=samples,
                                           full_vol_dim=self.full_vol_dim,
                                           crop_size=self.crop_size,
                                           sub_vol_path=self.sub_vol_path,
                                           normalization=self.normalization,
                                           th_percent=self.threshold)

        elif self.mode == 'test':
            self.list_IDsT1 = sorted(
                glob.glob(os.path.join(self.testing_path, '*GG/*/*t1.nii.gz')))
            self.list_IDsT1ce = sorted(
                glob.glob(os.path.join(self.testing_path,
                                       '*GG/*/*t1ce.nii.gz')))
            self.list_IDsT2 = sorted(
                glob.glob(os.path.join(self.testing_path, '*GG/*/*t2.nii.gz')))
            self.list_IDsFlair = sorted(
                glob.glob(
                    os.path.join(self.testing_path, '*GG/*/*_flair.nii.gz')))
            self.labels = None

        utils.save_list(self.save_name, self.list)
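
# --- Illustrative sketch, not part of the example above ---
# BraTS segmentation maps use the label set {0, 1, 2, 4}; many pipelines remap
# label 4 to 3 so class indices are contiguous for a 4-class loss. Whether this
# repo remaps labels is not shown here, so treat this as a general assumption.
import numpy as np

def remap_brats_labels(seg):
    seg = seg.copy()
    seg[seg == 4] = 3  # enhancing tumour: 4 -> 3
    return seg

toy_seg = np.array([0, 1, 2, 4, 4, 0])
print(remap_brats_labels(toy_seg))  # -> [0 1 2 3 3 0]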
Example #16
    def __init__(self,
                 args,
                 mode,
                 dataset_path='../datasets',
                 classes=4,
                 dim=(32, 32, 32),
                 split_id=0,
                 samples=1000,
                 load=False):

        fold_id = args.fold_id  # one of 070  1  14  148  4  5  7
        print(f'using fold_id {fold_id}')

        self.mode = mode
        self.root = dataset_path
        self.classes = classes
        dataset_name = f'mrbrains{classes}'
        self.training_path = os.path.join(self.root, 'mrbrains_2018',
                                          'training')
        self.dirs = os.listdir(self.training_path)
        self.samples = samples
        self.list = []
        self.full_vol_size = (240, 240, 48)
        self.threshold = 0.1
        self.crop_dim = dim
        self.list_flair = []
        self.list_ir = []
        self.list_reg_ir = []
        self.list_reg_t1 = []
        self.labels = []
        self.full_volume = None
        self.save_name = os.path.join(
            self.training_path,
            f'mrbrains_2018-classes-{classes}-list-{mode}-samples-{samples}.txt'
        )

        list_reg_t1 = sorted(
            glob.glob(os.path.join(self.training_path, '*/pr*/*g_T1.nii.gz')))
        list_reg_ir = sorted(
            glob.glob(os.path.join(self.training_path, '*/pr*/*g_IR.nii.gz')))
        list_flair = sorted(
            glob.glob(os.path.join(self.training_path, '*/pr*/*AIR.nii.gz')))
        labels = sorted(
            glob.glob(os.path.join(self.training_path, '*/*egm.nii.gz')))

        self.affine = img_loader.load_affine_matrix(list_reg_t1[0])

        if load:
            ## load pre-generated data
            self.list = utils.load_list(self.save_name)
            return

        self.sub_vol_path = os.path.join(
            self.root, 'mrbrains_2018', 'generated',
            f'{mode}_vol_{dim[0]}x{dim[1]}x{dim[2]}') + '/'
        utils.make_dirs(self.sub_vol_path)

        split_id = int(split_id)
        if mode == 'val':
            # labels = [labels[split_id]]
            # list_reg_t1 = [list_reg_t1[split_id]]
            # list_reg_ir = [list_reg_ir[split_id]]
            # list_flair = [list_flair[split_id]]

            labels = [x for x in labels if f'/{fold_id}/' in x]
            list_reg_t1 = [x for x in list_reg_t1 if f'/{fold_id}/' in x]
            list_reg_ir = [x for x in list_reg_ir if f'/{fold_id}/' in x]
            list_flair = [x for x in list_flair if f'/{fold_id}/' in x]

            assert len(labels) == len(list_reg_t1)
            assert len(labels) == len(list_reg_ir)
            assert len(labels) == len(list_flair)
            assert len(labels) == 1

            self.full_volume = get_viz_set(list_reg_t1,
                                           list_reg_ir,
                                           list_flair,
                                           labels,
                                           dataset_name=dataset_name)
        elif mode == 'train':
            # labels.pop(split_id)
            # list_reg_t1.pop(split_id)
            # list_reg_ir.pop(split_id)
            # list_flair.pop(split_id)

            labels = [x for x in labels if f'/{fold_id}/' not in x]
            list_reg_t1 = [x for x in list_reg_t1 if f'/{fold_id}/' not in x]
            list_reg_ir = [x for x in list_reg_ir if f'/{fold_id}/' not in x]
            list_flair = [x for x in list_flair if f'/{fold_id}/' not in x]

            assert len(labels) == len(list_reg_t1)
            assert len(labels) == len(list_reg_ir)
            assert len(labels) == len(list_flair)
            assert len(labels) == 6
        else:
            labels = [x for x in labels if f'/{fold_id}/' in x]
            list_reg_t1 = [x for x in list_reg_t1 if f'/{fold_id}/' in x]
            list_reg_ir = [x for x in list_reg_ir if f'/{fold_id}/' in x]
            list_flair = [x for x in list_flair if f'/{fold_id}/' in x]

            assert len(labels) == len(list_reg_t1)
            assert len(labels) == len(list_reg_ir)
            assert len(labels) == len(list_flair)
            assert len(labels) == 1

        if mode == 'test':
            self.list = create_non_overlapping_sub_volumes(
                list_reg_t1,
                list_reg_ir,
                list_flair,
                labels,
                dataset_name=dataset_name,
                mode=mode,
                samples=samples,
                full_vol_dim=self.full_vol_size,
                crop_size=self.crop_dim,
                sub_vol_path=self.sub_vol_path,
                th_percent=self.threshold)
        else:
            self.list = create_sub_volumes(list_reg_t1,
                                           list_reg_ir,
                                           list_flair,
                                           labels,
                                           dataset_name=dataset_name,
                                           mode=mode,
                                           samples=samples,
                                           full_vol_dim=self.full_vol_size,
                                           crop_size=self.crop_dim,
                                           sub_vol_path=self.sub_vol_path,
                                           th_percent=self.threshold)

        utils.save_list(self.save_name, self.list)
Example #17
def test():
    # args = iseg2019_arguments()
    print(args)

    utils.reproducibility(args, seed)
    utils.make_dirs(args.save)

    params = {"batch_size": args.batchSz, "shuffle": True, "num_workers": 2}
    print(params)
    samples_train = args.samples_train
    samples_val = args.samples_val
    test_loader = MRIDatasetISEG2019(
        args,
        "test",
        dataset_path=dataset_dir,
        crop_dim=args.dim,
        split_id=0,
        samples=samples_train,
        load=args.loadData,
    )

    model_name = args.model
    lr = args.lr
    in_channels = args.inChannels
    num_classes = args.classes
    weight_decay = 0.0000000001
    print("Building Model . . . . . . . ." + model_name)
    model = UNet3D(in_channels=in_channels, n_classes=num_classes, base_n_filter=8)
    print(
        model_name,
        "Number of params: {}".format(
            sum([p.data.nelement() for p in model.parameters()])
        ),
    )

    model.restore_checkpoint(
        "/home/kyle/results/UNET3D/iseg2019_9_06-08_21-25/iseg2019_9_06-08_21-25_BEST.pth"
    )
    criterion = DiceLoss(classes=args.classes)

    # model = model.cuda()
    # print("Model transferred in GPU.....")

    print("TESTING...")

    model.eval()

    confusion_matrix = [[0] * 4 for i in range(4)]

    for batch_idx, input_tuple in enumerate(test_loader):
        with torch.no_grad():
            img_t1, img_t2, target = input_tuple

            target = torch.reshape(target, (-1, 1, 64, 64, 64))
            img_t1 = torch.reshape(img_t1, (-1, 1, 64, 64, 64))
            img_t2 = torch.reshape(img_t2, (-1, 1, 64, 64, 64))

            input_tensor = torch.cat((img_t1, img_t2), dim=1)
            # print(input_tensor.size())

            input_tensor.requires_grad = False

            output = model(input_tensor)

            output = torch.argmax(output, dim=1)
            output = torch.reshape(output, (-1, 1, 64, 64, 64))

            assert target.size() == output.size()

            output = torch.reshape(output, (-1,)).tolist()
            target = torch.reshape(target, (-1,)).tolist()

            assert len(output) == len(target)

            for gt, pred in zip(target, output):
                confusion_matrix[int(gt)][int(pred)] += 1

    pprint(confusion_matrix)