Example #1
 def save_multilabel_img_as_multiple_files_endings(HP, img, affine, path):
     bundles = ExpUtils.get_bundle_names(HP.CLASSES)[1:]
     for idx, bundle in enumerate(bundles):
         img_seg = nib.Nifti1Image(img[:, :, :, idx], affine)
         ExpUtils.make_dir(join(path, "endings_segmentations"))
         nib.save(img_seg,
                  join(path, "endings_segmentations", bundle + ".nii.gz"))
Example #2
 def load_model(path):
     ExpUtils.print_verbose(self.HP,
                            "Loading weights ... ({})".format(path))
     with np.load(path) as f:  # if both paths are absolute and the beginnings of the paths are the same, join will merge the beginning
         param_values = [f['arr_%d' % i] for i in range(len(f.files))]
     L.layers.set_all_param_values(output_layer_for_loss, param_values)
Example #3
    def save_fusion_nifti_as_npy():

        #Can leave this always the same (for 270g and 32g)
        class HP:
            DATASET = "HCP"
            RESOLUTION = "1.25mm"
            FEATURES_FILENAME = "270g_125mm_peaks"
            LABELS_TYPE = np.int16
            LABELS_FILENAME = "bundle_masks"
            DATASET_FOLDER = "HCP"

        #change this for 270g and 32g
        DIFFUSION_FOLDER = "32g_25mm"

        subjects = get_all_subjects()
        # fold0 = ['687163', '685058', '683256', '680957', '679568', '677968', '673455', '672756', '665254', '654754', '645551', '644044', '638049', '627549', '623844', '622236', '620434', '613538', '601127', '599671', '599469']
        # fold1 = ['992774', '991267', '987983', '984472', '983773', '979984', '978578', '965771', '965367', '959574', '958976', '957974', '951457', '932554', '930449', '922854', '917255', '912447', '910241', '907656', '904044']
        # fold2 = ['901442', '901139', '901038', '899885', '898176', '896879', '896778', '894673', '889579', '887373', '877269', '877168', '872764', '872158', '871964', '871762', '865363', '861456', '859671', '857263', '856766']
        # fold3 = ['849971', '845458', '837964', '837560', '833249', '833148', '826454', '826353', '816653', '814649', '802844', '792766', '792564', '789373', '786569', '784565', '782561', '779370', '771354', '770352', '765056']
        # fold4 = ['761957', '759869', '756055', '753251', '751348', '749361', '748662', '748258', '742549', '734045', '732243', '729557', '729254', '715647', '715041', '709551', '705341', '704238', '702133', '695768', '690152']
        # subjects = fold2 + fold3 + fold4

        # subjects = ['654754', '645551', '644044', '638049', '627549', '623844', '622236', '620434', '613538', '601127', '599671', '599469']

        print("\n\nProcessing Data...")
        for s in subjects:
            print("processing data subject {}".format(s))
            start_time = time.time()
            data = nib.load(
                join(C.NETWORK_DRIVE, "HCP_fusion_" + DIFFUSION_FOLDER,
                     s + "_probmap.nii.gz")).get_data()
            print("Done Loading")
            data = np.nan_to_num(data)
            data = DatasetUtils.scale_input_to_unet_shape(
                data, HP.DATASET, HP.RESOLUTION)
            data = data[:-1, :, :-1, :]  # cut one pixel at the end, because scale_input_to_world_shape output 146 -> one too many at the end
            ExpUtils.make_dir(
                join(C.NETWORK_DRIVE, "HCP_fusion_npy_" + DIFFUSION_FOLDER, s))
            np.save(
                join(C.NETWORK_DRIVE, "HCP_fusion_npy_" + DIFFUSION_FOLDER, s,
                     DIFFUSION_FOLDER + "_xyz.npy"), data)
            print("Took {}s".format(time.time() - start_time))

            print("processing seg subject {}".format(s))
            start_time = time.time()
            # seg = ImgUtils.create_multilabel_mask(HP, s, labels_type=HP.LABELS_TYPE)
            seg = nib.load(
                join(C.NETWORK_DRIVE, "HCP_for_training_COPY", s,
                     HP.LABELS_FILENAME + ".nii.gz")).get_data()
            if HP.RESOLUTION == "2.5mm":
                seg = ImgUtils.resize_first_three_dims(seg, order=0, zoom=0.5)
            seg = DatasetUtils.scale_input_to_unet_shape(
                seg, HP.DATASET, HP.RESOLUTION)
            np.save(
                join(C.NETWORK_DRIVE, "HCP_fusion_npy_" + DIFFUSION_FOLDER, s,
                     "bundle_masks.npy"), seg)
            print("Took {}s".format(time.time() - start_time))
Example #4
    def precompute_batches(custom_type=None):
        '''
        9000 slices per epoch -> 200 batches (batchsize=44) per epoch
        => 200-1000 batches needed


        270g_125mm_bundle_peaks_Y: no DAug, no Norm, only Y
        All_sizes_DAug_XYZ: 12g, 90g, 270g, DAug (no rotation, no elastic deform), Norm, XYZ
        270g_125mm_bundle_peaks_XYZ: no DAug, Norm, XYZ
        '''

        class HP:
            NORMALIZE_DATA = True
            DATA_AUGMENTATION = False
            CV_FOLD = 0
            INPUT_DIM = (144, 144)
            BATCH_SIZE = 44
            DATASET_FOLDER = "HCP"
            TYPE = "single_direction"
            EXP_PATH = "~"
            LABELS_FILENAME = "bundle_peaks"
            FEATURES_FILENAME = "270g_125mm_peaks"
            DATASET = "HCP"
            RESOLUTION = "1.25mm"
            LABELS_TYPE = np.float32

        HP.TRAIN_SUBJECTS, HP.VALIDATE_SUBJECTS, HP.TEST_SUBJECTS = ExpUtils.get_cv_fold(HP.CV_FOLD)

        num_batches_base = 5000
        num_batches = {
            "train": num_batches_base,
            "validate": int(num_batches_base / 3.),
            "test": int(num_batches_base / 3.),
        }

        if custom_type is None:
            types = ["train", "validate", "test"]
        else:
            types = [custom_type]

        for type in types:
            dataManager = DataManagerTrainingNiftiImgs(HP)
            batch_gen = dataManager.get_batches(batch_size=HP.BATCH_SIZE, type=type,
                                                subjects=getattr(HP, type.upper() + "_SUBJECTS"), num_batches=num_batches[type])

            for idx, batch in enumerate(batch_gen):
                print("Processing: {}".format(idx))

                # DATASET_DIR = "HCP_batches/270g_125mm_bundle_peaks_Y"
                # DATASET_DIR = "HCP_batches/All_sizes_DAug_XYZ"
                DATASET_DIR = "HCP_batches/270g_125mm_bundle_peaks_XYZ"
                ExpUtils.make_dir(join(C.HOME, DATASET_DIR, type))

                data = nib.Nifti1Image(batch["data"], ImgUtils.get_dwi_affine(HP.DATASET, HP.RESOLUTION))
                nib.save(data, join(C.HOME, DATASET_DIR, type, "batch_" + str(idx) + "_data.nii.gz"))

                seg = nib.Nifti1Image(batch["seg"], ImgUtils.get_dwi_affine(HP.DATASET, HP.RESOLUTION))
                nib.save(seg, join(C.HOME, DATASET_DIR, type, "batch_" + str(idx) + "_seg.nii.gz"))
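For orientation, a hedged sketch of reading one precomputed batch back from disk; the directory, batch index, and printed shapes are assumptions for illustration, not part of the original code.

import nibabel as nib
from os.path import join

# Hypothetical location and index, mirroring the save paths used above.
batch_dir = join("HCP_batches", "270g_125mm_bundle_peaks_XYZ", "train")
idx = 0
data = nib.load(join(batch_dir, "batch_" + str(idx) + "_data.nii.gz")).get_data()
seg = nib.load(join(batch_dir, "batch_" + str(idx) + "_seg.nii.gz")).get_data()
print(data.shape, seg.shape)   # shapes depend on the batch generator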
Example #5
    def save_multilabel_img_as_multiple_files_peaks(HP, img, affine, path):
        bundles = ExpUtils.get_bundle_names(HP.CLASSES)[1:]
        for idx, bundle in enumerate(bundles):
            data = img[:, :, :, (idx * 3):(idx * 3) + 3]

            if HP.FLIP_OUTPUT_PEAKS:
                data[:, :, :, 2] *= -1  # flip z Axis for correct view in MITK
                filename = bundle + "_f.nii.gz"
            else:
                filename = bundle + ".nii.gz"

            img_seg = nib.Nifti1Image(data, affine)
            ExpUtils.make_dir(join(path, "TOM"))
            nib.save(img_seg, join(path, "TOM", filename))
Example #6
    def save_multilabel_img_as_multiple_files_peaks(HP, img, affine, path):
        bundles = ExpUtils.get_bundle_names(HP.CLASSES)[1:]
        for idx, bundle in enumerate(bundles):
            data = img[:, :, :, (idx*3):(idx*3)+3]

            if HP.FLIP_OUTPUT_PEAKS:
                data[:, :, :, 2] *= -1  # flip z Axis for correct view in MITK
                filename = bundle + "_f.nii.gz"
            else:
                filename = bundle + ".nii.gz"

            img_seg = nib.Nifti1Image(data, affine)
            ExpUtils.make_dir(join(path, "TOM"))
            nib.save(img_seg, join(path, "TOM", filename))
Example #7
    def calc_peak_dice_onlySeg(HP, y_pred, y_true):
        '''
        Create binary mask of peaks by simple thresholding. Then calculate Dice.

        :param y_pred:
        :param y_true:
        :return:
        '''

        score_per_bundle = {}
        bundles = ExpUtils.get_bundle_names(HP.CLASSES)[1:]
        for idx, bundle in enumerate(bundles):
            y_pred_bund = y_pred[:, :, :, (idx * 3):(idx * 3) + 3]
            y_true_bund = y_true[:, :, :, (idx * 3):(idx * 3) + 3]  # [x,y,z,3]

            # 0.1 -> keep some outliers, but also some holes already; 0.2 also ok (still looks like e.g. CST)
            #  Resulting dice for 0.1 and 0.2 very similar
            y_pred_binary = np.abs(y_pred_bund).sum(axis=-1) > 0.2
            y_true_binary = np.abs(y_true_bund).sum(axis=-1) > 1e-3

            f1 = f1_score(y_true_binary.flatten(),
                          y_pred_binary.flatten(),
                          average="binary")
            score_per_bundle[bundle] = f1

        return score_per_bundle
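A minimal self-contained sketch of the thresholding idea used above, with made-up shapes and synthetic peaks (not the TractSeg pipeline); it binarises a noisy peak map and scores it with sklearn's f1_score.

import numpy as np
from sklearn.metrics import f1_score

# Synthetic single-bundle peak maps with assumed shape (x, y, z, 3).
y_true_bund = np.random.rand(10, 10, 10, 3) * (np.random.rand(10, 10, 10, 1) > 0.5)
y_pred_bund = y_true_bund + np.random.normal(0, 0.05, y_true_bund.shape)

y_true_binary = np.abs(y_true_bund).sum(axis=-1) > 1e-3   # any nonzero peak counts as "inside tract"
y_pred_binary = np.abs(y_pred_bund).sum(axis=-1) > 0.2    # same threshold as in the function above

print(f1_score(y_true_binary.flatten(), y_pred_binary.flatten(), average="binary"))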
Example #8
    def create_multilabel_mask(HP,
                               subject,
                               labels_type=np.int16,
                               dataset_folder="HCP",
                               labels_folder="bundle_masks"):
        '''
        One-hot encoding of all bundles in one big image
        :param subject:
        :return: image of shape (x, y, z, nr_of_bundles + 1)
        '''
        bundles = ExpUtils.get_bundle_names(HP.CLASSES)

        # Masks are always HCP_highRes (downsample only afterwards)
        mask_ml = np.zeros((145, 174, 145, len(bundles)))
        background = np.ones((145, 174, 145))  # everything that contains no bundle

        for idx, bundle in enumerate(bundles[1:]):  # first bundle is background -> already considered by setting np.ones in the beginning
            mask = nib.load(join(C.HOME, dataset_folder, subject, labels_folder, bundle + ".nii.gz"))
            mask_data = mask.get_data()  # dtype: uint8
            mask_ml[:, :, :, idx + 1] = mask_data
            background[mask_data == 1] = 0  # remove this bundle from background

        mask_ml[:, :, :, 0] = background
        return mask_ml.astype(labels_type)
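A tiny toy example of the one-hot layout this produces (channel 0 is background); the shapes and bundle count are made up and much smaller than the real 145x174x145 masks.

import numpy as np

n_bundles = 2
mask_ml = np.zeros((1, 1, 2, n_bundles + 1))   # (x, y, z, nr_of_bundles + 1)
background = np.ones((1, 1, 2))

bundle_1 = np.array([[[1, 0]]])                # first bundle covers voxel 0 only
mask_ml[:, :, :, 1] = bundle_1
background[bundle_1 == 1] = 0

mask_ml[:, :, :, 0] = background
print(mask_ml[0, 0])   # voxel 0 -> [0, 1, 0] (bundle 1), voxel 1 -> [1, 0, 0] (background)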
Example #9
    def save_fusion_nifti_as_npy():

        #Can leave this always the same (for 270g and 32g)
        class HP:
            DATASET = "HCP"
            RESOLUTION = "1.25mm"
            FEATURES_FILENAME = "270g_125mm_peaks"
            LABELS_TYPE = np.int16
            LABELS_FILENAME = "bundle_masks"
            DATASET_FOLDER = "HCP"

        #change this for 270g and 32g
        DIFFUSION_FOLDER = "32g_25mm"

        subjects = get_all_subjects()
        # fold0 = ['687163', '685058', '683256', '680957', '679568', '677968', '673455', '672756', '665254', '654754', '645551', '644044', '638049', '627549', '623844', '622236', '620434', '613538', '601127', '599671', '599469']
        # fold1 = ['992774', '991267', '987983', '984472', '983773', '979984', '978578', '965771', '965367', '959574', '958976', '957974', '951457', '932554', '930449', '922854', '917255', '912447', '910241', '907656', '904044']
        # fold2 = ['901442', '901139', '901038', '899885', '898176', '896879', '896778', '894673', '889579', '887373', '877269', '877168', '872764', '872158', '871964', '871762', '865363', '861456', '859671', '857263', '856766']
        # fold3 = ['849971', '845458', '837964', '837560', '833249', '833148', '826454', '826353', '816653', '814649', '802844', '792766', '792564', '789373', '786569', '784565', '782561', '779370', '771354', '770352', '765056']
        # fold4 = ['761957', '759869', '756055', '753251', '751348', '749361', '748662', '748258', '742549', '734045', '732243', '729557', '729254', '715647', '715041', '709551', '705341', '704238', '702133', '695768', '690152']
        # subjects = fold2 + fold3 + fold4

        # subjects = ['654754', '645551', '644044', '638049', '627549', '623844', '622236', '620434', '613538', '601127', '599671', '599469']

        print("\n\nProcessing Data...")
        for s in subjects:
            print("processing data subject {}".format(s))
            start_time = time.time()
            data = nib.load(join(C.NETWORK_DRIVE, "HCP_fusion_" + DIFFUSION_FOLDER, s + "_probmap.nii.gz")).get_data()
            print("Done Loading")
            data = np.nan_to_num(data)
            data = DatasetUtils.scale_input_to_unet_shape(data, HP.DATASET, HP.RESOLUTION)
            data = data[:-1, :, :-1, :]  # cut one pixel at the end, because scale_input_to_world_shape output 146 -> one too many at the end
            ExpUtils.make_dir(join(C.NETWORK_DRIVE, "HCP_fusion_npy_" + DIFFUSION_FOLDER, s))
            np.save(join(C.NETWORK_DRIVE, "HCP_fusion_npy_" + DIFFUSION_FOLDER, s, DIFFUSION_FOLDER + "_xyz.npy"), data)
            print("Took {}s".format(time.time() - start_time))

            print("processing seg subject {}".format(s))
            start_time = time.time()
            # seg = ImgUtils.create_multilabel_mask(HP, s, labels_type=HP.LABELS_TYPE)
            seg = nib.load(join(C.NETWORK_DRIVE, "HCP_for_training_COPY", s, HP.LABELS_FILENAME + ".nii.gz")).get_data()
            if HP.RESOLUTION == "2.5mm":
                seg = ImgUtils.resize_first_three_dims(seg, order=0, zoom=0.5)
            seg = DatasetUtils.scale_input_to_unet_shape(seg, HP.DATASET, HP.RESOLUTION)
            np.save(join(C.NETWORK_DRIVE, "HCP_fusion_npy_" + DIFFUSION_FOLDER, s, "bundle_masks.npy"), seg)
            print("Took {}s".format(time.time() - start_time))
Example #10
    def copy_training_files_to_ssd(HP, data_path):

        def id_generator(size=6, chars=string.ascii_uppercase + string.digits):
            return ''.join(random.choice(chars) for _ in range(size))

        target_data_path = join("/ssd/", "tmp_" + id_generator(), HP.DATASET_FOLDER)
        ExpUtils.make_dir(join(target_data_path))

        #get all folders in data_path directory
        subjects = [os.path.basename(os.path.normpath(d)) for d in glob(data_path + "/*/")]

        for subject in subjects:
            src = join(data_path, subject, HP.FEATURES_FILENAME)
            target = join(target_data_path, subject, HP.FEATURES_FILENAME)
            print("cp: {} -> {}".format(src, target))
            # shutil.copyfile(src, target)

        return target_data_path
Example #11
def soft_dice_paul(HP, idxs, marker, preds, ys):
    n_classes = len(ExpUtils.get_bundle_names(HP.CLASSES))
    dice = T.constant(0)
    for cl in range(n_classes):
        pred = preds[marker, cl, :, :]
        y = ys[marker, cl, :, :]
        intersect = T.sum(pred * y)
        denominator = T.sum(pred) + T.sum(y)
        dice += T.constant(2) * intersect / (denominator + T.constant(1e-6))
    return 1 - (dice / n_classes)
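The loss above is a soft Dice averaged over classes. A NumPy restatement of the same formula, with assumed array shapes (a sketch for reference, not the Theano graph):

import numpy as np

def soft_dice_numpy(preds, ys, eps=1e-6):
    # preds, ys: (n_classes, x, y), values in [0, 1]
    n_classes = preds.shape[0]
    dice = 0.0
    for cl in range(n_classes):
        intersect = np.sum(preds[cl] * ys[cl])
        denominator = np.sum(preds[cl]) + np.sum(ys[cl])
        dice += 2.0 * intersect / (denominator + eps)
    return 1 - dice / n_classes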
Example #12
def soft_dice_paul(HP, idxs, marker, preds, ys):
    n_classes = len(ExpUtils.get_bundle_names(HP.CLASSES))
    dice = T.constant(0)
    for cl in range(n_classes):
        pred = preds[marker, cl, :, :]
        y = ys[marker, cl, :, :]
        intersect = T.sum(pred * y)
        denominator = T.sum(pred) + T.sum(y)
        dice += T.constant(2) * intersect / (denominator + T.constant(1e-6))
    return 1 - (dice / n_classes)
Example #13
    def calc_peak_length_dice_pytorch(HP, y_pred, y_true, max_angle_error=[0.9], max_length_error=0.1):
        '''
        Calculate angle and length error between prediction and ground truth peaks, keep the voxels
        where both are within the given thresholds, and compute a binary Dice (f1 score) per bundle.

        :param y_pred:
        :param y_true:
        :param max_angle_error:  0.7 ->  angle error of 45° or less; 0.9 ->  angle error of 23° or less
                                 Can be list with several values -> calculate for several thresholds
        :return:
        '''
        import torch
        from tractseg.libs.PytorchEinsum import einsum
        from tractseg.libs.PytorchUtils import PytorchUtils

        y_true = y_true.permute(0, 2, 3, 1)
        y_pred = y_pred.permute(0, 2, 3, 1)

        def angle_last_dim(a, b):
            '''
            Calculate the angle between two nd-arrays (array of vectors) along the last dimension

            without applying arccos: 1->0°, 0.9->23°, 0.7->45°, 0->90°
            np.arccos -> returns the angle in radians (90°: 0.5*pi)

            return: one dimension less than input
            '''
            return torch.abs(einsum('abcd,abcd->abc', a, b) / (torch.norm(a, 2., -1) * torch.norm(b, 2, -1) + 1e-7))

        #Single threshold
        score_per_bundle = {}
        bundles = ExpUtils.get_bundle_names(HP.CLASSES)[1:]
        for idx, bundle in enumerate(bundles):
            # if bundle == "CST_right":
            y_pred_bund = y_pred[:, :, :, (idx * 3):(idx * 3) + 3].contiguous()
            y_true_bund = y_true[:, :, :, (idx * 3):(idx * 3) + 3].contiguous()      # [x,y,z,3]

            angles = angle_last_dim(y_pred_bund, y_true_bund)

            lengths_pred = torch.norm(y_pred_bund, 2, -1)
            lengths_true = torch.norm(y_true_bund, 2, -1)
            lengths_binary = torch.abs(lengths_pred - lengths_true) < (max_length_error * lengths_true)
            lengths_binary = lengths_binary.view(-1)

            gt_binary = y_true_bund.sum(dim=-1) > 0
            gt_binary = gt_binary.view(-1)  # [bs*x*y]

            angles_binary = angles > max_angle_error[0]
            angles_binary = angles_binary.view(-1)

            combined = lengths_binary * angles_binary

            f1 = PytorchUtils.f1_score_binary(gt_binary, combined)
            score_per_bundle[bundle] = f1
        return score_per_bundle
Example #14
    def copy_training_files_to_ssd(HP, data_path):
        def id_generator(size=6, chars=string.ascii_uppercase + string.digits):
            return ''.join(random.choice(chars) for _ in range(size))

        target_data_path = join("/ssd/", "tmp_" + id_generator(),
                                HP.DATASET_FOLDER)
        ExpUtils.make_dir(join(target_data_path))

        #get all folders in data_path directory
        subjects = [
            os.path.basename(os.path.normpath(d))
            for d in glob(data_path + "/*/")
        ]

        for subject in subjects:
            src = join(data_path, subject, HP.FEATURES_FILENAME)
            target = join(target_data_path, subject, HP.FEATURES_FILENAME)
            print("cp: {} -> {}".format(src, target))
            # shutil.copyfile(src, target)

        return target_data_path
Example #15
    def calc_peak_length_dice(HP,
                              y_pred,
                              y_true,
                              max_angle_error=[0.9],
                              max_length_error=0.1):
        '''

        :param y_pred:
        :param y_true:
        :param max_angle_error:  0.7 ->  angle error of 45° or less; 0.9 ->  angle error of 23° or less
        :return:
        '''
        def angle_last_dim(a, b):
            '''
            Calculate the angle between two nd-arrays (array of vectors) along the last dimension

            without applying arccos: 1->0°, 0.9->23°, 0.7->45°, 0->90°
            np.arccos -> returns the angle in radians (90°: 0.5*pi)

            return: one dimension less than input
            '''
            # print(np.linalg.norm(a, axis=-1) * np.linalg.norm(b, axis=-1))
            return abs(
                np.einsum('...i,...i', a, b) /
                (np.linalg.norm(a, axis=-1) * np.linalg.norm(b, axis=-1) +
                 1e-7))

        score_per_bundle = {}
        bundles = ExpUtils.get_bundle_names(HP.CLASSES)[1:]
        for idx, bundle in enumerate(bundles):
            y_pred_bund = y_pred[:, :, :, (idx * 3):(idx * 3) + 3]
            y_true_bund = y_true[:, :, :, (idx * 3):(idx * 3) + 3]  # [x,y,z,3]

            angles = angle_last_dim(y_pred_bund, y_true_bund)

            lengths_pred = np.linalg.norm(y_pred_bund, axis=-1)
            lengths_true = np.linalg.norm(y_true_bund, axis=-1)
            lengths_binary = abs(lengths_pred - lengths_true) < (max_length_error * lengths_true)
            lengths_binary = lengths_binary.flatten()

            gt_binary = y_true_bund.sum(axis=-1) > 0
            gt_binary = gt_binary.flatten()  # [bs*x*y]

            angles_binary = angles > max_angle_error[0]
            angles_binary = angles_binary.flatten()

            combined = lengths_binary * angles_binary

            f1 = MetricUtils.my_f1_score(gt_binary, combined)
            score_per_bundle[bundle] = f1
        return score_per_bundle
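A quick NumPy sanity check of the cosine values that serve as angle thresholds in angle_last_dim (independent of TractSeg):

import numpy as np

# cos(angle) for the thresholds mentioned in the docstrings: 1 -> 0°, ~0.92 -> 23°, ~0.71 -> 45°, 0 -> 90°
for deg in [0, 23, 45, 90]:
    print(deg, round(np.cos(np.deg2rad(deg)), 2))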
Example #16
    def save_multilabel_img_as_multiple_files_endings_OLD(HP, img, affine, path, multilabel=True):
        '''
        multilabel True:    save as 1 and 2 without fourth dimension
        multilabel False:   save with beginnings and endings combined
        '''
        # bundles = ExpUtils.get_bundle_names("20")[1:]
        bundles = ExpUtils.get_bundle_names(HP.CLASSES)[1:]
        for idx, bundle in enumerate(bundles):
            data = img[:, :, :, (idx * 2):(idx * 2) + 2] > 0

            multilabel_img = np.zeros(data.shape[:3])

            if multilabel:
                multilabel_img[data[:, :, :, 0]] = 1
                multilabel_img[data[:, :, :, 1]] = 2
            else:
                multilabel_img[data[:, :, :, 0]] = 1
                multilabel_img[data[:, :, :, 1]] = 1

            img_seg = nib.Nifti1Image(multilabel_img, affine)
            ExpUtils.make_dir(join(path, "endings"))
            nib.save(img_seg, join(path, "endings", bundle + ".nii.gz"))
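A toy illustration of the 1/2 encoding produced when multilabel=True, using made-up shapes (not part of the function above):

import numpy as np

# One "beginning" mask and one "ending" mask combined into a single labelmap (1 = beginning, 2 = ending).
data = np.zeros((1, 1, 3, 2), dtype=bool)
data[0, 0, 0, 0] = True   # voxel 0 belongs to the beginning region
data[0, 0, 2, 1] = True   # voxel 2 belongs to the ending region

multilabel_img = np.zeros(data.shape[:3])
multilabel_img[data[:, :, :, 0]] = 1
multilabel_img[data[:, :, :, 1]] = 2
print(multilabel_img)   # [[[1. 0. 2.]]]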
Example #17
    def save_multilabel_img_as_multiple_files_endings_OLD(HP, img, affine, path, multilabel=True):
        '''
        multilabel True:    save as 1 and 2 without fourth dimension
        multilabel False:   save with beginnings and endings combined
        '''
        # bundles = ExpUtils.get_bundle_names("20")[1:]
        bundles = ExpUtils.get_bundle_names(HP.CLASSES)[1:]
        for idx, bundle in enumerate(bundles):
            data = img[:, :, :, (idx * 2):(idx * 2) + 2] > 0

            multilabel_img = np.zeros(data.shape[:3])

            if multilabel:
                multilabel_img[data[:, :, :, 0]] = 1
                multilabel_img[data[:, :, :, 1]] = 2
            else:
                multilabel_img[data[:, :, :, 0]] = 1
                multilabel_img[data[:, :, :, 1]] = 1

            img_seg = nib.Nifti1Image(multilabel_img, affine)
            ExpUtils.make_dir(join(path, "endings"))
            nib.save(img_seg, join(path, "endings", bundle + ".nii.gz"))
Example #18
def theano_f1_score_OLD(HP, idxs, marker, preds, ys):
    '''
    From Paul
    '''
    n_classes = len(ExpUtils.get_bundle_names(HP.CLASSES))
    dice = T.constant(0)
    for cl in range(n_classes):
        pred = preds[marker, cl, :, :]
        y = ys[marker, cl, :, :]
        pred = T.gt(pred, T.constant(0.5))
        intersect = T.sum(pred * y)
        denominator = T.sum(pred) + T.sum(y)
        dice += T.constant(2) * intersect / (denominator + T.constant(1e-6))
    return dice / n_classes
Example #19
def theano_f1_score_OLD(HP, idxs, marker, preds, ys):
    '''
    From Paul
    '''
    n_classes = len(ExpUtils.get_bundle_names(HP.CLASSES))
    dice = T.constant(0)
    for cl in range(n_classes):
        pred = preds[marker, cl, :, :]
        y = ys[marker, cl, :, :]
        pred = T.gt(pred, T.constant(0.5))
        intersect = T.sum(pred * y)
        denominator = T.sum(pred) + T.sum(y)
        dice += T.constant(2) * intersect / (denominator + T.constant(1e-6))
    return dice / n_classes
Example #20
    def calc_peak_dice(HP, y_pred, y_true, max_angle_error=[0.9]):
        '''

        :param y_pred:
        :param y_true:
        :param max_angle_error:  0.7 ->  angle error of 45° or less; 0.9 ->  angle error of 23° or less
        :return:
        '''
        def angle(a, b):
            '''
            Calculate the angle between two 1d-arrays (2 vectors) along the last dimension

            without applying arccos: 1->0°, 0.7->45°, 0->90°
            np.arccos -> returns the angle in radians (90°: 0.5*pi)
            '''
            return abs(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

        def angle_last_dim(a, b):
            '''
            Calculate the angle between two nd-arrays (array of vectors) along the last dimension

            without applying arccos: 1->0°, 0.9->23°, 0.7->45°, 0->90°
            np.arccos -> returns the angle in radians (90°: 0.5*pi)

            return: one dimension less than input
            '''
            # print(np.linalg.norm(a, axis=-1) * np.linalg.norm(b, axis=-1))
            return abs(
                np.einsum('...i,...i', a, b) /
                (np.linalg.norm(a, axis=-1) * np.linalg.norm(b, axis=-1) +
                 1e-7))

        score_per_bundle = {}
        bundles = ExpUtils.get_bundle_names(HP.CLASSES)[1:]
        for idx, bundle in enumerate(bundles):
            y_pred_bund = y_pred[:, :, :, (idx * 3):(idx * 3) + 3]
            y_true_bund = y_true[:, :, :, (idx * 3):(idx * 3) + 3]  # [x,y,z,3]

            angles = angle_last_dim(y_pred_bund, y_true_bund)
            angles_binary = angles > max_angle_error[0]

            gt_binary = y_true_bund.sum(axis=-1) > 0

            f1 = f1_score(gt_binary.flatten(),
                          angles_binary.flatten(),
                          average="binary")
            score_per_bundle[bundle] = f1

        return score_per_bundle
Example #21
    def calc_peak_dice(HP, y_pred, y_true, max_angle_error=[0.9]):
        '''

        :param y_pred:
        :param y_true:
        :param max_angle_error:  0.7 ->  angle error of 45° or less; 0.9 ->  angle error of 23° or less
        :return:
        '''

        def angle(a, b):
            '''
            Calculate the angle between two 1d-arrays (2 vectors) along the last dimension

            without applying arccos: 1->0°, 0.7->45°, 0->90°
            np.arccos -> returns the angle in radians (90°: 0.5*pi)
            '''
            return abs(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

        def angle_last_dim(a, b):
            '''
            Calculate the angle between two nd-arrays (array of vectors) along the last dimension

            without applying arccos: 1->0°, 0.9->23°, 0.7->45°, 0->90°
            np.arccos -> returns the angle in radians (90°: 0.5*pi)

            return: one dimension less than input
            '''
            # print(np.linalg.norm(a, axis=-1) * np.linalg.norm(b, axis=-1))
            return abs(np.einsum('...i,...i', a, b) / (np.linalg.norm(a, axis=-1) * np.linalg.norm(b, axis=-1) + 1e-7))


        score_per_bundle = {}
        bundles = ExpUtils.get_bundle_names(HP.CLASSES)[1:]
        for idx, bundle in enumerate(bundles):
            y_pred_bund = y_pred[:, :, :, (idx * 3):(idx * 3) + 3]
            y_true_bund = y_true[:, :, :, (idx * 3):(idx * 3) + 3]      # [x,y,z,3]

            angles = angle_last_dim(y_pred_bund, y_true_bund)
            angles_binary = angles > max_angle_error[0]

            gt_binary = y_true_bund.sum(axis=-1) > 0

            f1 = f1_score(gt_binary.flatten(), angles_binary.flatten(), average="binary")
            score_per_bundle[bundle] = f1

        return score_per_bundle
Example #22
    def calc_peak_length_dice(HP, y_pred, y_true, max_angle_error=[0.9], max_length_error=0.1):
        '''

        :param y_pred:
        :param y_true:
        :param max_angle_error:  0.7 ->  angle error of 45° or less; 0.9 ->  angle error of 23° or less
        :return:
        '''

        def angle_last_dim(a, b):
            '''
            Calculate the angle between two nd-arrays (array of vectors) along the last dimension

            without applying arccos: 1->0°, 0.9->23°, 0.7->45°, 0->90°
            np.arccos -> returns the angle in radians (90°: 0.5*pi)

            return: one dimension less than input
            '''
            # print(np.linalg.norm(a, axis=-1) * np.linalg.norm(b, axis=-1))
            return abs(np.einsum('...i,...i', a, b) / (np.linalg.norm(a, axis=-1) * np.linalg.norm(b, axis=-1) + 1e-7))


        score_per_bundle = {}
        bundles = ExpUtils.get_bundle_names(HP.CLASSES)[1:]
        for idx, bundle in enumerate(bundles):
            y_pred_bund = y_pred[:, :, :, (idx * 3):(idx * 3) + 3]
            y_true_bund = y_true[:, :, :, (idx * 3):(idx * 3) + 3]      # [x,y,z,3]

            angles = angle_last_dim(y_pred_bund, y_true_bund)

            lengths_pred = np.linalg.norm(y_pred_bund, axis=-1)
            lengths_true = np.linalg.norm(y_true_bund, axis=-1)
            lengths_binary = abs(lengths_pred - lengths_true) < (max_length_error * lengths_true)
            lengths_binary = lengths_binary.flatten()

            gt_binary = y_true_bund.sum(axis=-1) > 0
            gt_binary = gt_binary.flatten()  # [bs*x*y]

            angles_binary = angles > max_angle_error[0]
            angles_binary = angles_binary.flatten()

            combined = lengths_binary * angles_binary

            f1 = MetricUtils.my_f1_score(gt_binary, combined)
            score_per_bundle[bundle] = f1
        return score_per_bundle
Example #23
    def create_multilabel_mask(HP, subject, labels_type=np.int16, dataset_folder="HCP", labels_folder="bundle_masks"):
        '''
        One-hot encoding of all bundles in one big image
        :param subject:
        :return: image of shape (x, y, z, nr_of_bundles + 1)
        '''
        bundles = ExpUtils.get_bundle_names(HP.CLASSES)

        # Masks are always HCP_highRes (downsample only afterwards)
        mask_ml = np.zeros((145, 174, 145, len(bundles)))
        background = np.ones((145, 174, 145))   # everything that contains no bundle

        for idx, bundle in enumerate(bundles[1:]):   #first bundle is background -> already considered by setting np.ones in the beginning
            mask = nib.load(join(C.HOME, dataset_folder, subject, labels_folder, bundle + ".nii.gz"))
            mask_data = mask.get_data()     # dtype: uint8
            mask_ml[:, :, :, idx+1] = mask_data
            background[mask_data == 1] = 0    # remove this bundle from background

        mask_ml[:, :, :, 0] = background
        return mask_ml.astype(labels_type)
Example #24
    def calc_peak_dice_onlySeg(HP, y_pred, y_true):
        '''
        Create binary mask of peaks by simple thresholding. Then calculate Dice.

        :param y_pred:
        :param y_true:
        :return:
        '''

        score_per_bundle = {}
        bundles = ExpUtils.get_bundle_names(HP.CLASSES)[1:]
        for idx, bundle in enumerate(bundles):
            y_pred_bund = y_pred[:, :, :, (idx * 3):(idx * 3) + 3]
            y_true_bund = y_true[:, :, :, (idx * 3):(idx * 3) + 3]      # [x,y,z,3]

            # 0.1 -> keep some outliers, but also some holes already; 0.2 also ok (still looks like e.g. CST)
            #  Resulting dice for 0.1 and 0.2 very similar
            y_pred_binary = np.abs(y_pred_bund).sum(axis=-1) > 0.2
            y_true_binary = np.abs(y_true_bund).sum(axis=-1) > 1e-3

            f1 = f1_score(y_true_binary.flatten(), y_pred_binary.flatten(), average="binary")
            score_per_bundle[bundle] = f1

        return score_per_bundle
Example #25
    def precompute_batches(custom_type=None):
        '''
        9000 slices per epoch -> 200 batches (batchsize=44) per epoch
        => 200-1000 batches needed


        270g_125mm_bundle_peaks_Y: no DAug, no Norm, only Y
        All_sizes_DAug_XYZ: 12g, 90g, 270g, DAug (no rotation, no elastic deform), Norm, XYZ
        270g_125mm_bundle_peaks_XYZ: no DAug, Norm, XYZ
        '''
        class HP:
            NORMALIZE_DATA = True
            DATA_AUGMENTATION = False
            CV_FOLD = 0
            INPUT_DIM = (144, 144)
            BATCH_SIZE = 44
            DATASET_FOLDER = "HCP"
            TYPE = "single_direction"
            EXP_PATH = "~"
            LABELS_FILENAME = "bundle_peaks"
            FEATURES_FILENAME = "270g_125mm_peaks"
            DATASET = "HCP"
            RESOLUTION = "1.25mm"
            LABELS_TYPE = np.float32

        HP.TRAIN_SUBJECTS, HP.VALIDATE_SUBJECTS, HP.TEST_SUBJECTS = ExpUtils.get_cv_fold(
            HP.CV_FOLD)

        num_batches_base = 5000
        num_batches = {
            "train": num_batches_base,
            "validate": int(num_batches_base / 3.),
            "test": int(num_batches_base / 3.),
        }

        if custom_type is None:
            types = ["train", "validate", "test"]
        else:
            types = [custom_type]

        for type in types:
            dataManager = DataManagerTrainingNiftiImgs(HP)
            batch_gen = dataManager.get_batches(
                batch_size=HP.BATCH_SIZE,
                type=type,
                subjects=getattr(HP,
                                 type.upper() + "_SUBJECTS"),
                num_batches=num_batches[type])

            for idx, batch in enumerate(batch_gen):
                print("Processing: {}".format(idx))

                # DATASET_DIR = "HCP_batches/270g_125mm_bundle_peaks_Y"
                # DATASET_DIR = "HCP_batches/All_sizes_DAug_XYZ"
                DATASET_DIR = "HCP_batches/270g_125mm_bundle_peaks_XYZ"
                ExpUtils.make_dir(join(C.HOME, DATASET_DIR, type))

                data = nib.Nifti1Image(
                    batch["data"],
                    ImgUtils.get_dwi_affine(HP.DATASET, HP.RESOLUTION))
                nib.save(
                    data,
                    join(C.HOME, DATASET_DIR, type,
                         "batch_" + str(idx) + "_data.nii.gz"))

                seg = nib.Nifti1Image(
                    batch["seg"],
                    ImgUtils.get_dwi_affine(HP.DATASET, HP.RESOLUTION))
                nib.save(
                    seg,
                    join(C.HOME, DATASET_DIR, type,
                         "batch_" + str(idx) + "_seg.nii.gz"))
Example #26
    def calc_peak_dice_pytorch(HP, y_pred, y_true, max_angle_error=[0.9]):
        '''
        Calculate angle between groundtruth and prediction and keep the voxels where
        angle is smaller than MAX_ANGLE_ERROR.

        From groundtruth generate a binary mask by selecting all voxels with len > 0.

        Calculate Dice from these 2 masks.

        -> Penalty on peaks outside of tract or if predicted peak=0
        -> no penalty on very small peaks with the right direction -> bad
        => Peak_dice can be high even if peaks inside of tract almost missing (almost 0)

        :param y_pred:
        :param y_true:
        :param max_angle_error:  0.7 ->  angle error of 45° or less; 0.9 ->  angle error of 23° or less
                                 Can be list with several values -> calculate for several thresholds
        :return:
        '''
        import torch
        from tractseg.libs.PytorchEinsum import einsum
        from tractseg.libs.PytorchUtils import PytorchUtils

        y_true = y_true.permute(0, 2, 3, 1)
        y_pred = y_pred.permute(0, 2, 3, 1)

        def angle_last_dim(a, b):
            '''
            Calculate the angle between two nd-arrays (array of vectors) along the last dimension

            without applying arccos: 1->0°, 0.9->23°, 0.7->45°, 0->90°
            np.arccos -> returns the angle in radians (90°: 0.5*pi)

            return: one dimension less than input
            '''
            return torch.abs(einsum('abcd,abcd->abc', a, b) / (torch.norm(a, 2., -1) * torch.norm(b, 2, -1) + 1e-7))

        #Single threshold
        if len(max_angle_error) == 1:
            score_per_bundle = {}
            bundles = ExpUtils.get_bundle_names(HP.CLASSES)[1:]
            for idx, bundle in enumerate(bundles):
                # if bundle == "CST_right":
                y_pred_bund = y_pred[:, :, :, (idx * 3):(idx * 3) + 3].contiguous()
                y_true_bund = y_true[:, :, :, (idx * 3):(idx * 3) + 3].contiguous()      # [x,y,z,3]

                angles = angle_last_dim(y_pred_bund, y_true_bund)
                gt_binary = y_true_bund.sum(dim=-1) > 0
                gt_binary = gt_binary.view(-1)  # [bs*x*y]

                angles_binary = angles > max_angle_error[0]
                angles_binary = angles_binary.view(-1)

                f1 = PytorchUtils.f1_score_binary(gt_binary, angles_binary)
                score_per_bundle[bundle] = f1

            return score_per_bundle

        #multiple thresholds
        else:
            score_per_bundle = {}
            bundles = ExpUtils.get_bundle_names(HP.CLASSES)[1:]
            for idx, bundle in enumerate(bundles):
                # if bundle == "CST_right":
                y_pred_bund = y_pred[:, :, :, (idx * 3):(idx * 3) + 3].contiguous()
                y_true_bund = y_true[:, :, :, (idx * 3):(idx * 3) + 3].contiguous()  # [x,y,z,3]

                angles = angle_last_dim(y_pred_bund, y_true_bund)
                gt_binary = y_true_bund.sum(dim=-1) > 0
                gt_binary = gt_binary.view(-1)  # [bs*x*y]

                score_per_bundle[bundle] = []
                for threshold in max_angle_error:
                    angles_binary = angles > threshold
                    angles_binary = angles_binary.view(-1)

                    f1 = PytorchUtils.f1_score_binary(gt_binary, angles_binary)
                    score_per_bundle[bundle].append(f1)

            return score_per_bundle
Example #27
    def generate_train_batch(self):
        subjects = self._data[0]
        subject_idx = int(random.uniform(0, len(subjects)))     # len(subjects)-1 not needed because int always rounds to floor

        for i in range(20):
            try:
                if self.HP.FEATURES_FILENAME == "12g90g270g":
                    # if np.random.random() < 0.5:
                    #     data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "270g_125mm_peaks.nii.gz")).get_data()
                    # else:
                    #     data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "90g_125mm_peaks.nii.gz")).get_data()

                    rnd_choice = np.random.random()
                    if rnd_choice < 0.33:
                        data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "270g_125mm_peaks.nii.gz")).get_data()
                    elif rnd_choice < 0.66:
                        data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "90g_125mm_peaks.nii.gz")).get_data()
                    else:
                        data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "12g_125mm_peaks.nii.gz")).get_data()
                elif self.HP.FEATURES_FILENAME == "T1_Peaks270g":
                    peaks = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "270g_125mm_peaks.nii.gz")).get_data()
                    t1 = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "T1.nii.gz")).get_data()
                    data = np.concatenate((peaks, t1), axis=3)
                elif self.HP.FEATURES_FILENAME == "T1_Peaks12g90g270g":
                    rnd_choice = np.random.random()
                    if rnd_choice < 0.33:
                        peaks = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "270g_125mm_peaks.nii.gz")).get_data()
                    elif rnd_choice < 0.66:
                        peaks = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "90g_125mm_peaks.nii.gz")).get_data()
                    else:
                        peaks = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "12g_125mm_peaks.nii.gz")).get_data()
                    t1 = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "T1.nii.gz")).get_data()
                    data = np.concatenate((peaks, t1), axis=3)
                else:
                    data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], self.HP.FEATURES_FILENAME + ".nii.gz")).get_data()

                seg = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], self.HP.LABELS_FILENAME + ".nii.gz")).get_data()
                break
            except IOError:
                ExpUtils.print_and_save(self.HP, "\n\nWARNING: Could not load file. Trying again in 20s (Try number: " + str(i) + ").\n\n")
            ExpUtils.print_and_save(self.HP, "Sleeping 20s")
            sleep(20)
        # ExpUtils.print_and_save(self.HP, "Successfully loaded input.")

        data = np.nan_to_num(data)    # Needed otherwise not working
        seg = np.nan_to_num(seg)

        data = DatasetUtils.scale_input_to_unet_shape(data, self.HP.DATASET, self.HP.RESOLUTION)    # (x, y, z, channels)

        if self.HP.LABELS_FILENAME not in ["bundle_peaks_11_808080", "bundle_peaks_20_808080", "bundle_peaks_808080",
                                           "bundle_masks_20_808080", "bundle_masks_72_808080"]:
            if self.HP.DATASET in ["HCP_2mm", "HCP_2.5mm", "HCP_32g"]:
                # By using "HCP" but lower resolution scale_input_to_unet_shape will automatically downsample the HCP sized seg_mask to the lower resolution
                seg = DatasetUtils.scale_input_to_unet_shape(seg, "HCP", self.HP.RESOLUTION)
            else:
                seg = DatasetUtils.scale_input_to_unet_shape(seg, self.HP.DATASET, self.HP.RESOLUTION)  # (x, y, z, classes)

        slice_idxs = np.random.choice(data.shape[0], self.BATCH_SIZE, False, None)

        # Randomly sample slice orientation
        if self.HP.TRAINING_SLICE_DIRECTION == "xyz":
            slice_direction = int(round(random.uniform(0,2)))
        else:
            slice_direction = 1 #always use Y

        if slice_direction == 0:
            x = data[slice_idxs, :, :].astype(np.float32)      # (batch_size, y, z, channels)
            y = seg[slice_idxs, :, :].astype(self.HP.LABELS_TYPE)
            x = np.array(x).transpose(0, 3, 1, 2)  # depth-channel has to be before width and height for Unet (but after batches)
            y = np.array(y).transpose(0, 3, 1, 2)  # nr_classes channel has to be before width and height for DataAugmentation (bs, nr_of_classes, x, y)
        elif slice_direction == 1:
            x = data[:, slice_idxs, :].astype(np.float32)      # (x, batch_size, z, channels)
            y = seg[:, slice_idxs, :].astype(self.HP.LABELS_TYPE)
            x = np.array(x).transpose(1, 3, 0, 2)
            y = np.array(y).transpose(1, 3, 0, 2)
        elif slice_direction == 2:
            x = data[:, :, slice_idxs].astype(np.float32)      # (x, y, batch_size, channels)
            y = seg[:, :, slice_idxs].astype(self.HP.LABELS_TYPE)
            x = np.array(x).transpose(2, 3, 0, 1)
            y = np.array(y).transpose(2, 3, 0, 1)

        data_dict = {"data": x,     # (batch_size, channels, x, y, [z])
                     "seg": y}      # (batch_size, channels, x, y, [z])
        return data_dict
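A small shape check of the slicing and transpose logic above for slice_direction == 1, with made-up dimensions (not part of the original code):

import numpy as np

data = np.zeros((144, 144, 144, 9))                            # (x, y, z, channels), assumed dims
slice_idxs = np.random.choice(data.shape[0], 44, False, None)

x = data[:, slice_idxs, :]                                     # (x, batch_size, z, channels)
x = x.transpose(1, 3, 0, 2)                                    # (batch_size, channels, x, z)
print(x.shape)                                                 # (44, 9, 144, 144)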
Example #28
    def generate_train_batch(self):
        subjects = self._data[0]
        subject_idx = int(random.uniform(0, len(subjects)))  # len(subjects)-1 not needed because int always rounds to floor

        for i in range(20):
            try:
                if self.HP.FEATURES_FILENAME == "12g90g270g":
                    # if np.random.random() < 0.5:
                    #     data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "270g_125mm_peaks.nii.gz")).get_data()
                    # else:
                    #     data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "90g_125mm_peaks.nii.gz")).get_data()

                    rnd_choice = np.random.random()
                    if rnd_choice < 0.33:
                        data = nib.load(
                            join(C.DATA_PATH, self.HP.DATASET_FOLDER,
                                 subjects[subject_idx],
                                 "270g_125mm_peaks.nii.gz")).get_data()
                    elif rnd_choice < 0.66:
                        data = nib.load(
                            join(C.DATA_PATH, self.HP.DATASET_FOLDER,
                                 subjects[subject_idx],
                                 "90g_125mm_peaks.nii.gz")).get_data()
                    else:
                        data = nib.load(
                            join(C.DATA_PATH, self.HP.DATASET_FOLDER,
                                 subjects[subject_idx],
                                 "12g_125mm_peaks.nii.gz")).get_data()
                elif self.HP.FEATURES_FILENAME == "T1_Peaks270g":
                    peaks = nib.load(
                        join(C.DATA_PATH, self.HP.DATASET_FOLDER,
                             subjects[subject_idx],
                             "270g_125mm_peaks.nii.gz")).get_data()
                    t1 = nib.load(
                        join(C.DATA_PATH, self.HP.DATASET_FOLDER,
                             subjects[subject_idx], "T1.nii.gz")).get_data()
                    data = np.concatenate((peaks, t1), axis=3)
                elif self.HP.FEATURES_FILENAME == "T1_Peaks12g90g270g":
                    rnd_choice = np.random.random()
                    if rnd_choice < 0.33:
                        peaks = nib.load(
                            join(C.DATA_PATH, self.HP.DATASET_FOLDER,
                                 subjects[subject_idx],
                                 "270g_125mm_peaks.nii.gz")).get_data()
                    elif rnd_choice < 0.66:
                        peaks = nib.load(
                            join(C.DATA_PATH, self.HP.DATASET_FOLDER,
                                 subjects[subject_idx],
                                 "90g_125mm_peaks.nii.gz")).get_data()
                    else:
                        peaks = nib.load(
                            join(C.DATA_PATH, self.HP.DATASET_FOLDER,
                                 subjects[subject_idx],
                                 "12g_125mm_peaks.nii.gz")).get_data()
                    t1 = nib.load(
                        join(C.DATA_PATH, self.HP.DATASET_FOLDER,
                             subjects[subject_idx], "T1.nii.gz")).get_data()
                    data = np.concatenate((peaks, t1), axis=3)
                else:
                    data = nib.load(
                        join(C.DATA_PATH, self.HP.DATASET_FOLDER,
                             subjects[subject_idx], self.HP.FEATURES_FILENAME +
                             ".nii.gz")).get_data()

                seg = nib.load(
                    join(C.DATA_PATH, self.HP.DATASET_FOLDER,
                         subjects[subject_idx],
                         self.HP.LABELS_FILENAME + ".nii.gz")).get_data()
                break
            except IOError:
                ExpUtils.print_and_save(
                    self.HP,
                    "\n\nWARNING: Could not load file. Trying again in 20s (Try number: "
                    + str(i) + ").\n\n")
            ExpUtils.print_and_save(self.HP, "Sleeping 20s")
            sleep(20)
        # ExpUtils.print_and_save(self.HP, "Successfully loaded input.")

        data = np.nan_to_num(data)  # Needed otherwise not working
        seg = np.nan_to_num(seg)

        data = DatasetUtils.scale_input_to_unet_shape(
            data, self.HP.DATASET, self.HP.RESOLUTION)  # (x, y, z, channels)

        if self.HP.LABELS_FILENAME not in [
                "bundle_peaks_11_808080", "bundle_peaks_20_808080",
                "bundle_peaks_808080", "bundle_masks_20_808080",
                "bundle_masks_72_808080"
        ]:
            if self.HP.DATASET in ["HCP_2mm", "HCP_2.5mm", "HCP_32g"]:
                # By using "HCP" but lower resolution scale_input_to_unet_shape will automatically downsample the HCP sized seg_mask to the lower resolution
                seg = DatasetUtils.scale_input_to_unet_shape(
                    seg, "HCP", self.HP.RESOLUTION)
            else:
                seg = DatasetUtils.scale_input_to_unet_shape(
                    seg, self.HP.DATASET,
                    self.HP.RESOLUTION)  # (x, y, z, classes)

        slice_idxs = np.random.choice(data.shape[0], self.BATCH_SIZE, False,
                                      None)

        # Randomly sample slice orientation
        if self.HP.TRAINING_SLICE_DIRECTION == "xyz":
            slice_direction = int(round(random.uniform(0, 2)))
        else:
            slice_direction = 1  #always use Y

        if slice_direction == 0:
            x = data[slice_idxs, :, :].astype(np.float32)  # (batch_size, y, z, channels)
            y = seg[slice_idxs, :, :].astype(self.HP.LABELS_TYPE)
            x = np.array(x).transpose(0, 3, 1, 2)  # depth-channel has to be before width and height for Unet (but after batches)
            y = np.array(y).transpose(0, 3, 1, 2)  # nr_classes channel has to be before width and height for DataAugmentation (bs, nr_of_classes, x, y)
        elif slice_direction == 1:
            x = data[:, slice_idxs, :].astype(
                np.float32)  # (x, batch_size, z, channels)
            y = seg[:, slice_idxs, :].astype(self.HP.LABELS_TYPE)
            x = np.array(x).transpose(1, 3, 0, 2)
            y = np.array(y).transpose(1, 3, 0, 2)
        elif slice_direction == 2:
            x = data[:, :, slice_idxs].astype(
                np.float32)  # (x, y, batch_size, channels)
            y = seg[:, :, slice_idxs].astype(self.HP.LABELS_TYPE)
            x = np.array(x).transpose(2, 3, 0, 1)
            y = np.array(y).transpose(2, 3, 0, 1)

        data_dict = {
            "data": x,  # (batch_size, channels, x, y, [z])
            "seg": y
        }  # (batch_size, channels, x, y, [z])
        return data_dict
Example #29
 def save_multilabel_img_as_multiple_files_endings(HP, img, affine, path):
     bundles = ExpUtils.get_bundle_names(HP.CLASSES)[1:]
     for idx, bundle in enumerate(bundles):
         img_seg = nib.Nifti1Image(img[:,:,:,idx], affine)
         ExpUtils.make_dir(join(path, "endings_segmentations"))
         nib.save(img_seg, join(path, "endings_segmentations", bundle + ".nii.gz"))
Example #30
    def create_network(self):

        def train(X, y):
            X = torch.from_numpy(X.astype(np.float32))
            y = torch.from_numpy(y.astype(np.float32))
            if torch.cuda.is_available():
                X, y = Variable(X.cuda()), Variable(y.cuda())  # X: (bs, features, x, y)   y: (bs, classes, x, y)
            else:
                X, y = Variable(X), Variable(y)
            optimizer.zero_grad()
            net.train()
            outputs = net(X)  # forward     # outputs: (bs, classes, x, y)
            loss = criterion(outputs, y)
            loss.backward()  # backward
            optimizer.step()  # optimise
            f1 = PytorchUtils.f1_score_macro(y.data, outputs.data, per_class=True)
            # probs = outputs.data.cpu().numpy().transpose(0,2,3,1)   # (bs, x, y, classes)
            probs = None    #faster
            return loss.data[0], probs, f1

        def test(X, y):
            X = torch.from_numpy(X.astype(np.float32))
            y = torch.from_numpy(y.astype(np.float32))
            if torch.cuda.is_available():
                X, y = Variable(X.cuda(), volatile=True), Variable(y.cuda(), volatile=True)
            else:
                X, y = Variable(X, volatile=True), Variable(y, volatile=True)
            net.train(False)
            outputs = net(X)  # forward
            loss = criterion(outputs, y)
            f1 = PytorchUtils.f1_score_macro(y.data, outputs.data, per_class=True)
            # probs = outputs.data.cpu().numpy().transpose(0,2,3,1)   # (bs, x, y, classes)
            probs = None  # faster
            return loss.data[0], probs, f1

        def predict(X):
            X = torch.from_numpy(X.astype(np.float32))
            if torch.cuda.is_available():
                X = Variable(X.cuda(), volatile=True)
            else:
                X = Variable(X, volatile=True)
            net.train(False)
            outputs = net(X)  # forward
            probs = outputs.data.cpu().numpy().transpose(0,2,3,1)   # (bs, x, y, classes)
            return probs

        def save_model(metrics, epoch_nr):
            max_f1_idx = np.argmax(metrics["f1_macro_validate"])
            max_f1 = np.max(metrics["f1_macro_validate"])
            if epoch_nr == max_f1_idx and max_f1 > 0.01:  # saving to network drives takes 5s (to local only 0.5s) -> do not save so often
                print("  Saving weights...")
                for fl in glob.glob(join(self.HP.EXP_PATH, "best_weights_ep*")):  # remove weights from previous epochs
                    os.remove(fl)
                try:
                    #Actually is a pkl not a npz
                    PytorchUtils.save_checkpoint(join(self.HP.EXP_PATH, "best_weights_ep" + str(epoch_nr) + ".npz"), unet=net)
                except IOError:
                    print("\nERROR: Could not save weights because of IO Error\n")
                self.HP.BEST_EPOCH = epoch_nr

        def load_model(path):
            PytorchUtils.load_checkpoint(path, unet=net)


        if self.HP.SEG_INPUT == "Peaks" and self.HP.TYPE == "single_direction":
            NR_OF_GRADIENTS = 9
        elif self.HP.SEG_INPUT == "Peaks" and self.HP.TYPE == "combined":
            NR_OF_GRADIENTS = 3*self.HP.NR_OF_CLASSES
        else:
            NR_OF_GRADIENTS = 33

        if torch.cuda.is_available():
            net = UNet(n_input_channels=NR_OF_GRADIENTS, n_classes=self.HP.NR_OF_CLASSES, n_filt=self.HP.UNET_NR_FILT).cuda()
            # net = UNet_Skip(n_input_channels=NR_OF_GRADIENTS, n_classes=self.HP.NR_OF_CLASSES, n_filt=self.HP.UNET_NR_FILT).cuda()
        else:
            net = UNet(n_input_channels=NR_OF_GRADIENTS, n_classes=self.HP.NR_OF_CLASSES, n_filt=self.HP.UNET_NR_FILT)
            # net = UNet_Skip(n_input_channels=NR_OF_GRADIENTS, n_classes=self.HP.NR_OF_CLASSES, n_filt=self.HP.UNET_NR_FILT)

        if self.HP.TRAIN:
            ExpUtils.print_and_save(self.HP, str(net), only_log=True)

        criterion = nn.BCEWithLogitsLoss()
        optimizer = optim.Adamax(net.parameters(), lr=self.HP.LEARNING_RATE)

        if self.HP.LOAD_WEIGHTS:
            ExpUtils.print_verbose(self.HP, "Loading weights ... ({})".format(join(self.HP.EXP_PATH, self.HP.WEIGHTS_PATH)))
            load_model(join(self.HP.EXP_PATH, self.HP.WEIGHTS_PATH))

        self.train = train
        self.predict = test
        self.get_probs = predict
        self.save_model = save_model
        self.load_model = load_model
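A stripped-down sketch of the closure pattern used here, with a dummy model and modern PyTorch calls instead of the deprecated Variable/volatile API; everything below is illustrative and not the original TractSeg code.

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

def create_network_sketch(n_in=9, n_classes=3, lr=0.001):
    net = nn.Conv2d(n_in, n_classes, kernel_size=1)    # stand-in for the real UNet
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adamax(net.parameters(), lr=lr)

    def train(X, y):
        X = torch.from_numpy(X.astype(np.float32))
        y = torch.from_numpy(y.astype(np.float32))
        optimizer.zero_grad()
        net.train()
        outputs = net(X)                                # (bs, classes, x, y)
        loss = criterion(outputs, y)
        loss.backward()
        optimizer.step()
        return loss.item()

    def predict(X):
        net.eval()
        with torch.no_grad():
            out = net(torch.from_numpy(X.astype(np.float32)))
        return torch.sigmoid(out).numpy().transpose(0, 2, 3, 1)   # (bs, x, y, classes)

    return train, predict

train, predict = create_network_sketch()
loss = train(np.random.rand(2, 9, 16, 16), np.random.randint(0, 2, (2, 3, 16, 16)))
probs = predict(np.random.rand(2, 9, 16, 16))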
Example #31
    def create_network(self):
        # torch.backends.cudnn.benchmark = True     #not faster

        def train(X, y, weight_factor=10):
            X = torch.from_numpy(X.astype(np.float32))
            y = torch.from_numpy(y.astype(np.float32))
            if torch.cuda.is_available():
                X, y = Variable(X.cuda()), Variable(y.cuda())  # X: (bs, features, x, y)   y: (bs, classes, x, y)
            else:
                X, y = Variable(X), Variable(y)
            optimizer.zero_grad()
            net.train()
            outputs = net(X)  # forward     # outputs: (bs, classes, x, y)
            loss = criterion(outputs, y)
            loss.backward()  # backward
            optimizer.step()  # optimise
            f1 = PytorchUtils.f1_score_macro(y.data,
                                             outputs.data,
                                             per_class=True)

            if self.HP.USE_VISLOGGER:
                probs = outputs.data.cpu().numpy().transpose(
                    0, 2, 3, 1)  # (bs, x, y, classes)
            else:
                probs = None  #faster

            return loss.data[0], probs, f1

        def test(X, y, weight_factor=10):
            X = torch.from_numpy(X.astype(np.float32))
            y = torch.from_numpy(y.astype(np.float32))
            if torch.cuda.is_available():
                X, y = Variable(X.cuda(),
                                volatile=True), Variable(y.cuda(),
                                                         volatile=True)
            else:
                X, y = Variable(X, volatile=True), Variable(y, volatile=True)
            net.train(False)
            outputs = net(X)  # forward
            loss = criterion(outputs, y)
            f1 = PytorchUtils.f1_score_macro(y.data,
                                             outputs.data,
                                             per_class=True)
            # probs = outputs.data.cpu().numpy().transpose(0,2,3,1)   # (bs, x, y, classes)
            probs = None  # faster
            return loss.data[0], probs, f1

        def predict(X):
            X = torch.from_numpy(X.astype(np.float32))
            if torch.cuda.is_available():
                X = Variable(X.cuda(), volatile=True)
            else:
                X = Variable(X, volatile=True)
            net.train(False)
            outputs = net(X)  # forward
            probs = outputs.data.cpu().numpy().transpose(
                0, 2, 3, 1)  # (bs, x, y, classes)
            return probs

        def save_model(metrics, epoch_nr):
            max_f1_idx = np.argmax(metrics["f1_macro_validate"])
            max_f1 = np.max(metrics["f1_macro_validate"])
            if epoch_nr == max_f1_idx and max_f1 > 0.01:  # saving to network drives takes 5s (to local only 0.5s) -> do not save so often
                print("  Saving weights...")
                for fl in glob.glob(join(self.HP.EXP_PATH, "best_weights_ep*")
                                    ):  # remove weights from previous epochs
                    os.remove(fl)
                try:
                    # Actually a pkl, not an npz
                    PytorchUtils.save_checkpoint(join(
                        self.HP.EXP_PATH,
                        "best_weights_ep" + str(epoch_nr) + ".npz"),
                                                 unet=net)
                except IOError:
                    print(
                        "\nERROR: Could not save weights because of IO Error\n"
                    )
                self.HP.BEST_EPOCH = epoch_nr

        def load_model(path):
            PytorchUtils.load_checkpoint(path, unet=net)

        def print_current_lr():
            for param_group in optimizer.param_groups:
                ExpUtils.print_and_save(
                    self.HP,
                    "current learning rate: {}".format(param_group['lr']))

        if self.HP.SEG_INPUT == "Peaks" and self.HP.TYPE == "single_direction":
            NR_OF_GRADIENTS = self.HP.NR_OF_GRADIENTS
            # NR_OF_GRADIENTS = 9
            # NR_OF_GRADIENTS = 9 * 5
            # NR_OF_GRADIENTS = 9 * 9
            # NR_OF_GRADIENTS = 33
        elif self.HP.SEG_INPUT == "Peaks" and self.HP.TYPE == "combined":
            NR_OF_GRADIENTS = self.HP.NR_OF_GRADIENTS = 3 * self.HP.NR_OF_CLASSES
        else:
            NR_OF_GRADIENTS = self.HP.NR_OF_GRADIENTS = 33

        if self.HP.LOSS_FUNCTION == "soft_sample_dice":
            criterion = PytorchUtils.soft_sample_dice
            final_activation = "sigmoid"
        elif self.HP.LOSS_FUNCTION == "soft_batch_dice":
            criterion = PytorchUtils.soft_batch_dice
            final_activation = "sigmoid"
        else:
            # weights = torch.ones((self.HP.BATCH_SIZE, self.HP.NR_OF_CLASSES, self.HP.INPUT_DIM[0], self.HP.INPUT_DIM[1])).cuda()
            # weights[:, 5, :, :] *= 10     #CA
            # weights[:, 21, :, :] *= 10    #FX_left
            # weights[:, 22, :, :] *= 10    #FX_right
            # criterion = nn.BCEWithLogitsLoss(weight=weights)
            criterion = nn.BCEWithLogitsLoss()
            final_activation = None

        net = UNet(n_input_channels=NR_OF_GRADIENTS,
                   n_classes=self.HP.NR_OF_CLASSES,
                   n_filt=self.HP.UNET_NR_FILT,
                   batchnorm=self.HP.BATCH_NORM,
                   final_activation=final_activation)

        if torch.cuda.is_available():
            net = net.cuda()
        # else:
        #     net = UNet(n_input_channels=NR_OF_GRADIENTS, n_classes=self.HP.NR_OF_CLASSES, n_filt=self.HP.UNET_NR_FILT,
        #                batchnorm=self.HP.BATCH_NORM)

        # net = nn.DataParallel(net, device_ids=[0,1])

        # if self.HP.TRAIN:
        #     ExpUtils.print_and_save(self.HP, str(net), only_log=True)

        if self.HP.OPTIMIZER == "Adamax":
            optimizer = Adamax(net.parameters(), lr=self.HP.LEARNING_RATE)
        elif self.HP.OPTIMIZER == "Adam":
            #todo important: change
            # optimizer = Adam(net.parameters(), lr=self.HP.LEARNING_RATE)
            optimizer = Adam(net.parameters(),
                             lr=self.HP.LEARNING_RATE,
                             weight_decay=self.HP.WEIGHT_DECAY)
        else:
            raise ValueError("Optimizer not defined")
        # scheduler = lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.1)
        # scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode="max")

        if self.HP.LOAD_WEIGHTS:
            ExpUtils.print_verbose(
                self.HP, "Loading weights ... ({})".format(
                    join(self.HP.EXP_PATH, self.HP.WEIGHTS_PATH)))
            load_model(join(self.HP.EXP_PATH, self.HP.WEIGHTS_PATH))

        self.train = train
        self.predict = test
        self.get_probs = predict
        self.save_model = save_model
        self.load_model = load_model
        self.print_current_lr = print_current_lr
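
A minimal, hypothetical usage sketch for the closures bound above: one training step with random arrays whose shapes follow the comments in train() (X: (bs, features, x, y), y: (bs, classes, x, y)). The `model` object and the concrete sizes are assumptions for illustration, not part of the original code.

# Hypothetical sketch: drive one training step of the model created above with random data.
# Shapes follow the comments in train(): X (bs, features, x, y), y (bs, classes, x, y).
import numpy as np

bs, n_features, n_classes, dim = 4, 9, 72, 144        # assumed sizes for illustration
X = np.random.randn(bs, n_features, dim, dim).astype(np.float32)
y = (np.random.rand(bs, n_classes, dim, dim) > 0.5).astype(np.float32)

# loss, probs, f1 = model.train(X, y)                  # model: instance whose create_network() ran above
# print("training loss:", loss)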
Example #32
0
    def plot_tracts(HP,
                    bundle_segmentations,
                    affine,
                    out_dir,
                    brain_mask=None):
        '''
        By default this does not work over a remote server connection (ssh -X), because -X does not support OpenGL.
        On the remote server you can do 'export DISPLAY=":0"' (use the value that 'echo $DISPLAY' prints when you
        log in locally on that server). All graphics will then be rendered locally on the server and not via -X.
        (important: a graphical session needs to be running on the remote server, e.g. via a local login)
        (important: an actual login is needed, not just staying at the login screen)
        '''
        from dipy.viz import window
        from tractseg.libs.VtkUtils import VtkUtils

        SMOOTHING = 10
        WINDOW_SIZE = (800, 800)
        bundles = ["CST_right", "CA", "IFO_right"]

        renderer = window.Renderer()
        renderer.projection('parallel')

        rows = len(bundles)
        X, Y, Z = bundle_segmentations.shape[:3]
        for j, bundle in enumerate(bundles):
            i = 0  #only one method

            bundle_idx = ExpUtils.get_bundle_names(
                HP.CLASSES)[1:].index(bundle)
            mask_data = bundle_segmentations[:, :, :, bundle_idx]

            if bundle == "CST_right":
                orientation = "axial"
            elif bundle == "CA":
                orientation = "axial"
            elif bundle == "IFO_right":
                orientation = "sagittal"
            else:
                orientation = "axial"

            #bigger: more border
            if orientation == "axial":
                border_y = -100  #-60
            else:
                border_y = -100

            x_current = X * i  # column (width)
            y_current = rows * (Y * 2 + border_y) - (
                Y * 2 + border_y) * j  # row (height)  (starts from bottom?)

            PlotUtils.plot_mask(renderer,
                                mask_data,
                                affine,
                                x_current,
                                y_current,
                                orientation=orientation,
                                smoothing=SMOOTHING,
                                brain_mask=brain_mask)

            #Bundle label
            text_offset_top = -50  # 60
            text_offset_side = -100  # -30
            position = (0 - int(X) + text_offset_side,
                        y_current + text_offset_top, 50)
            text_actor = VtkUtils.label(text=bundle,
                                        pos=position,
                                        scale=(6, 6, 6),
                                        color=(1, 1, 1))
            renderer.add(text_actor)

        renderer.reset_camera()
        window.record(renderer,
                      out_path=join(out_dir, "preview.png"),
                      size=(WINDOW_SIZE[0], WINDOW_SIZE[1]),
                      reset_camera=False,
                      magnification=2)
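
A hedged sketch of the DISPLAY workaround described in the docstring above; the call at the end is commented out because the exact enclosing class and input arrays depend on the rest of the pipeline (names below are illustrative).

# Illustrative sketch of the docstring's DISPLAY workaround for headless/ssh use.
import os

# Use the value that 'echo $DISPLAY' prints in a local graphical session on the remote server.
os.environ["DISPLAY"] = ":0"

# bundle_segmentations: 4D array (x, y, z, nr_of_bundles); affine taken from the corresponding NIfTI.
# plot_tracts(HP, bundle_segmentations, affine, out_dir="preview_dir")   # writes preview.png into out_dir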
Example #33
0
def run_tractseg(data, output_type="tract_segmentation", input_type="peaks",
                 single_orientation=False, verbose=False, dropout_sampling=False, threshold=0.5,
                 bundle_specific_threshold=False, get_probs=False, peak_threshold=0.1):
    '''
    Run TractSeg

    :param data: input peaks (4D numpy array with shape [x,y,z,9])
    :param output_type: "tract_segmentation" | "endings_segmentation" | "TOM" | "dm_regression"
    :param input_type: "peaks"
    :param single_orientation: only predict on a single slice orientation (mainly for testing; lower RAM requirements)
    :param verbose: show debugging info
    :param dropout_sampling: create uncertainty map by Monte Carlo dropout (https://arxiv.org/abs/1506.02142)
    :param threshold: threshold for converting the probability map to a binary map
    :param bundle_specific_threshold: use a lower threshold for bundles which need more sensitivity (CA, CST, FX)
    :param get_probs: output the raw probability map instead of a binary map
    :param peak_threshold: all peaks shorter than peak_threshold will be set to zero
    :return: 4D numpy array with the output of tractseg
        for tract_segmentation:     [x,y,z,nr_of_bundles]
        for endings_segmentation:   [x,y,z,2*nr_of_bundles]
        for TOM:                    [x,y,z,3*nr_of_bundles]
    '''
    start_time = time.time()

    config = get_config_name(input_type, output_type)
    HP = getattr(importlib.import_module("tractseg.config.PretrainedModels." + config), "HP")()
    HP.VERBOSE = verbose
    HP.TRAIN = False
    HP.TEST = False
    HP.SEGMENT = False
    HP.GET_PROBS = get_probs
    HP.LOAD_WEIGHTS = True
    HP.DROPOUT_SAMPLING = dropout_sampling
    HP.THRESHOLD = threshold

    if bundle_specific_threshold:
        HP.GET_PROBS = True

    if input_type == "peaks":
        if HP.EXPERIMENT_TYPE == "tract_segmentation" and HP.DROPOUT_SAMPLING:
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_tract_segmentation_dropout_v1.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "TractSeg_12g90g270g_125mm_DAugAll_Dropout", "best_weights_ep114.npz")
        elif HP.EXPERIMENT_TYPE == "tract_segmentation":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_tract_segmentation_v1.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "x_Pretrained_TractSeg_Models/TractSeg_T1_12g90g270g_125mm_DAugAll", "best_weights_ep392.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "TractSeg72_888", "best_weights_ep247.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "TractSeg72_888_SchizoFineT_lr001", "best_weights_ep186.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "TractSeg_12g90g270g_125mm_DS_DAugAll_RotMir", "best_weights_ep200.npz")
        elif HP.EXPERIMENT_TYPE == "endings_segmentation":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_endings_segmentation_v3.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "EndingsSeg_12g90g270g_125mm_DS_DAugAll", "best_weights_ep234.npz")
        elif HP.EXPERIMENT_TYPE == "peak_regression":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_peak_regression_v1.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "x_Pretrained_TractSeg_Models/Peaks20_12g90g270g_125mm_DAugSimp_constW5", "best_weights_ep441.npz")  #more oversegmentation with DAug
        elif HP.EXPERIMENT_TYPE == "dm_regression":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_dm_regression_v1.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "DmReg_12g90g270g_125mm_DAugAll_Ubuntu", "best_weights_ep80.npz")
    elif input_type == "T1":
        if HP.EXPERIMENT_TYPE == "tract_segmentation":
            # HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_tract_segmentation_v1.npz")
            HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes/x_Pretrained_TractSeg_Models", "TractSeg_T1_125mm_DAugAll", "best_weights_ep142.npz")
        elif HP.EXPERIMENT_TYPE == "endings_segmentation":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_endings_segmentation_v1.npz")
        elif HP.EXPERIMENT_TYPE == "peak_regression":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_peak_regression_v1.npz")
    print("Loading weights from: {}".format(HP.WEIGHTS_PATH))

    if HP.EXPERIMENT_TYPE == "peak_regression":
        HP.NR_OF_CLASSES = 3*len(ExpUtils.get_bundle_names(HP.CLASSES)[1:])
    else:
        HP.NR_OF_CLASSES = len(ExpUtils.get_bundle_names(HP.CLASSES)[1:])

    if HP.VERBOSE:
        print("Hyperparameters:")
        ExpUtils.print_HPs(HP)

    Utils.download_pretrained_weights(experiment_type=HP.EXPERIMENT_TYPE, dropout_sampling=HP.DROPOUT_SAMPLING)

    data = np.nan_to_num(data)
    # brain_mask = ImgUtils.simple_brain_mask(data)
    # if HP.VERBOSE:
    #     nib.save(nib.Nifti1Image(brain_mask, np.eye(4)), "otsu_brain_mask_DEBUG.nii.gz")

    if input_type == "T1":
        data = np.reshape(data, (data.shape[0], data.shape[1], data.shape[2], 1))
    data, seg_None, bbox, original_shape = DatasetUtils.crop_to_nonzero(data)
    data, transformation = DatasetUtils.pad_and_scale_img_to_square_img(data, target_size=HP.INPUT_DIM[0])

    model = BaseModel(HP)

    if HP.EXPERIMENT_TYPE == "tract_segmentation" or HP.EXPERIMENT_TYPE == "endings_segmentation" or HP.EXPERIMENT_TYPE == "dm_regression":
        if single_orientation:     # mainly needed for testing because of lower RAM requirements
            dataManagerSingle = DataManagerSingleSubjectByFile(HP, data=data)
            trainerSingle = Trainer(model, dataManagerSingle)
            if HP.DROPOUT_SAMPLING or HP.EXPERIMENT_TYPE == "dm_regression" or HP.GET_PROBS:
                seg, img_y = trainerSingle.get_seg_single_img(HP, probs=True, scale_to_world_shape=False, only_prediction=True)
            else:
                seg, img_y = trainerSingle.get_seg_single_img(HP, probs=False, scale_to_world_shape=False, only_prediction=True)
        else:
            seg_xyz, gt = DirectionMerger.get_seg_single_img_3_directions(HP, model, data=data, scale_to_world_shape=False, only_prediction=True)
            if HP.DROPOUT_SAMPLING or HP.EXPERIMENT_TYPE == "dm_regression" or HP.GET_PROBS:
                seg = DirectionMerger.mean_fusion(HP.THRESHOLD, seg_xyz, probs=True)
            else:
                seg = DirectionMerger.mean_fusion(HP.THRESHOLD, seg_xyz, probs=False)

    elif HP.EXPERIMENT_TYPE == "peak_regression":
        dataManagerSingle = DataManagerSingleSubjectByFile(HP, data=data)
        trainerSingle = Trainer(model, dataManagerSingle)
        seg, img_y = trainerSingle.get_seg_single_img(HP, probs=True, scale_to_world_shape=False, only_prediction=True)
        if bundle_specific_threshold:
            seg = ImgUtils.remove_small_peaks_bundle_specific(seg, ExpUtils.get_bundle_names(HP.CLASSES)[1:], len_thr=0.3)
        else:
            seg = ImgUtils.remove_small_peaks(seg, len_thr=peak_threshold)
        #3 dir for Peaks -> not working (?)
        # seg_xyz, gt = DirectionMerger.get_seg_single_img_3_directions(HP, model, data=data, scale_to_world_shape=False, only_prediction=True)
        # seg = DirectionMerger.mean_fusion(HP.THRESHOLD, seg_xyz, probs=True)

    if bundle_specific_threshold and HP.EXPERIMENT_TYPE == "tract_segmentation":
        seg = ImgUtils.probs_to_binary_bundle_specific(seg, ExpUtils.get_bundle_names(HP.CLASSES)[1:])

    #remove following two lines to keep super resolution
    seg = DatasetUtils.cut_and_scale_img_back_to_original_img(seg, transformation)
    seg = DatasetUtils.add_original_zero_padding_again(seg, bbox, original_shape, HP.NR_OF_CLASSES)
    ExpUtils.print_verbose(HP, "Took {}s".format(round(time.time() - start_time, 2)))

    return seg
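
A minimal usage sketch for run_tractseg, assuming a peaks NIfTI with shape [x, y, z, 9] as described in the docstring; the file names are illustrative only.

# Hypothetical usage sketch for run_tractseg (file names are illustrative).
import nibabel as nib
import numpy as np

img = nib.load("peaks.nii.gz")                       # 4D peaks image, shape (x, y, z, 9)
seg = run_tractseg(img.get_data(), output_type="tract_segmentation", input_type="peaks")
nib.save(nib.Nifti1Image(seg.astype(np.uint8), img.affine), "bundle_segmentations.nii.gz")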
Example #34
0
    def create_network(self):
        def train(X, y):
            X = torch.from_numpy(X.astype(np.float32))
            y = torch.from_numpy(y.astype(np.float32))
            if torch.cuda.is_available():
                X, y = Variable(X.cuda()), Variable(y.cuda(
                ))  # X: (bs, features, x, y)   y: (bs, classes, x, y)
            else:
                X, y = Variable(X), Variable(y)
            optimizer.zero_grad()
            net.train()
            outputs = net(X)  # forward     # outputs: (bs, classes, x, y)
            loss = criterion(outputs, y)
            loss.backward()  # backward
            optimizer.step()  # optimise
            f1 = PytorchUtils.f1_score_macro(y.data,
                                             outputs.data,
                                             per_class=True)
            # probs = outputs.data.cpu().numpy().transpose(0,2,3,1)   # (bs, x, y, classes)
            probs = None  #faster
            return loss.data[0], probs, f1

        def test(X, y):
            X = torch.from_numpy(X.astype(np.float32))
            y = torch.from_numpy(y.astype(np.float32))
            if torch.cuda.is_available():
                X, y = Variable(X.cuda(),
                                volatile=True), Variable(y.cuda(),
                                                         volatile=True)
            else:
                X, y = Variable(X, volatile=True), Variable(y, volatile=True)
            net.train(False)
            outputs = net(X)  # forward
            loss = criterion(outputs, y)
            f1 = PytorchUtils.f1_score_macro(y.data,
                                             outputs.data,
                                             per_class=True)
            # probs = outputs.data.cpu().numpy().transpose(0,2,3,1)   # (bs, x, y, classes)
            probs = None  # faster
            return loss.data[0], probs, f1

        def predict(X):
            X = torch.from_numpy(X.astype(np.float32))
            if torch.cuda.is_available():
                X = Variable(X.cuda(), volatile=True)
            else:
                X = Variable(X, volatile=True)
            net.train(False)
            outputs = net(X)  # forward
            probs = outputs.data.cpu().numpy().transpose(
                0, 2, 3, 1)  # (bs, x, y, classes)
            return probs

        def save_model(metrics, epoch_nr):
            max_f1_idx = np.argmax(metrics["f1_macro_validate"])
            max_f1 = np.max(metrics["f1_macro_validate"])
            if epoch_nr == max_f1_idx and max_f1 > 0.01:  # saving to network drives takes 5s (to local only 0.5s) -> do not save so often
                print("  Saving weights...")
                for fl in glob.glob(join(self.HP.EXP_PATH, "best_weights_ep*")
                                    ):  # remove weights from previous epochs
                    os.remove(fl)
                try:
                    # Actually a pkl, not an npz
                    PytorchUtils.save_checkpoint(join(
                        self.HP.EXP_PATH,
                        "best_weights_ep" + str(epoch_nr) + ".npz"),
                                                 unet=net)
                except IOError:
                    print(
                        "\nERROR: Could not save weights because of IO Error\n"
                    )
                self.HP.BEST_EPOCH = epoch_nr

        def load_model(path):
            PytorchUtils.load_checkpoint(path, unet=net)

        if self.HP.SEG_INPUT == "Peaks" and self.HP.TYPE == "single_direction":
            NR_OF_GRADIENTS = 9
        elif self.HP.SEG_INPUT == "Peaks" and self.HP.TYPE == "combined":
            NR_OF_GRADIENTS = 3 * self.HP.NR_OF_CLASSES
        else:
            NR_OF_GRADIENTS = 33

        if torch.cuda.is_available():
            net = UNet(n_input_channels=NR_OF_GRADIENTS,
                       n_classes=self.HP.NR_OF_CLASSES,
                       n_filt=self.HP.UNET_NR_FILT).cuda()
            # net = UNet_Skip(n_input_channels=NR_OF_GRADIENTS, n_classes=self.HP.NR_OF_CLASSES, n_filt=self.HP.UNET_NR_FILT).cuda()
        else:
            net = UNet(n_input_channels=NR_OF_GRADIENTS,
                       n_classes=self.HP.NR_OF_CLASSES,
                       n_filt=self.HP.UNET_NR_FILT)
            # net = UNet_Skip(n_input_channels=NR_OF_GRADIENTS, n_classes=self.HP.NR_OF_CLASSES, n_filt=self.HP.UNET_NR_FILT)

        if self.HP.TRAIN:
            ExpUtils.print_and_save(self.HP, str(net), only_log=True)

        criterion = nn.BCEWithLogitsLoss()
        optimizer = optim.Adamax(net.parameters(), lr=self.HP.LEARNING_RATE)

        if self.HP.LOAD_WEIGHTS:
            ExpUtils.print_verbose(
                self.HP, "Loading weights ... ({})".format(
                    join(self.HP.EXP_PATH, self.HP.WEIGHTS_PATH)))
            load_model(join(self.HP.EXP_PATH, self.HP.WEIGHTS_PATH))

        self.train = train
        self.predict = test
        self.get_probs = predict
        self.save_model = save_model
        self.load_model = load_model
Example #35
0
 def test_bundle_names(self):
     bundles = ExpUtils.get_bundle_names("CST_right")
     self.assertListEqual(bundles, ["BG", "CST_right"], "Error in list of bundle names")
Example #36
0
    def create_network(self):
        # torch.backends.cudnn.benchmark = True     #not faster

        def train(X, y, weight_factor=10):
            X = torch.from_numpy(X.astype(np.float32))
            y = torch.from_numpy(y.astype(np.float32))
            if torch.cuda.is_available():
                X, y = Variable(X.cuda()), Variable(y.cuda(
                ))  # X: (bs, features, x, y)   y: (bs, classes, x, y)
            else:
                X, y = Variable(X), Variable(y)
            optimizer.zero_grad()
            net.train()
            outputs = net(X)  # forward     # outputs: (bs, classes, x, y)

            weights = torch.ones(
                (self.HP.BATCH_SIZE, self.HP.NR_OF_CLASSES,
                 self.HP.INPUT_DIM[0], self.HP.INPUT_DIM[1])).cuda()
            bundle_mask = y > 0
            weights[bundle_mask.data] *= weight_factor  #10

            loss = criterion(outputs, y, Variable(weights))
            # loss = criterion1(outputs, y, Variable(weights)) + criterion2(outputs, y, Variable(weights))

            loss.backward()  # backward
            optimizer.step()  # optimise

            if self.HP.CALC_F1:
                # f1 = PytorchUtils.f1_score_macro(y.data, outputs.data, per_class=True)
                # f1_a = MetricUtils.calc_peak_dice_pytorch(self.HP, outputs.data, y.data, max_angle_error=self.HP.PEAK_DICE_THR)
                f1 = MetricUtils.calc_peak_length_dice_pytorch(
                    self.HP,
                    outputs.data,
                    y.data,
                    max_angle_error=self.HP.PEAK_DICE_THR,
                    max_length_error=self.HP.PEAK_DICE_LEN_THR)
                # f1 = (f1_a, f1_b)
            else:
                f1 = np.ones(outputs.shape[3])

            if self.HP.USE_VISLOGGER:
                probs = outputs.data.cpu().numpy().transpose(
                    0, 2, 3, 1)  # (bs, x, y, classes)
            else:
                # probs = outputs.data.cpu().numpy().transpose(0,2,3,1)  # (bs, x, y, classes)
                probs = None  #faster

            return loss.data[0], probs, f1

        def test(X, y, weight_factor=10):
            X = torch.from_numpy(X.astype(np.float32))
            y = torch.from_numpy(y.astype(np.float32))
            if torch.cuda.is_available():
                X, y = Variable(X.cuda(),
                                volatile=True), Variable(y.cuda(),
                                                         volatile=True)
            else:
                X, y = Variable(X, volatile=True), Variable(y, volatile=True)
            net.train(False)
            outputs = net(X)  # forward

            weights = torch.ones(
                (self.HP.BATCH_SIZE, self.HP.NR_OF_CLASSES,
                 self.HP.INPUT_DIM[0], self.HP.INPUT_DIM[1])).cuda()
            bundle_mask = y > 0
            weights[bundle_mask.data] *= weight_factor  #10

            loss = criterion(outputs, y, Variable(weights))
            # loss = criterion1(outputs, y, Variable(weights)) + criterion2(outputs, y, Variable(weights))

            if self.HP.CALC_F1:
                # f1 = PytorchUtils.f1_score_macro(y.data, outputs.data, per_class=True)
                # f1_a = MetricUtils.calc_peak_dice_pytorch(self.HP, outputs.data, y.data, max_angle_error=self.HP.PEAK_DICE_THR)
                f1 = MetricUtils.calc_peak_length_dice_pytorch(
                    self.HP,
                    outputs.data,
                    y.data,
                    max_angle_error=self.HP.PEAK_DICE_THR,
                    max_length_error=self.HP.PEAK_DICE_LEN_THR)
                # f1 = (f1_a, f1_b)
            else:
                f1 = np.ones(outputs.shape[3])

            # probs = outputs.data.cpu().numpy().transpose(0,2,3,1)   # (bs, x, y, classes)
            probs = None  # faster
            return loss.data[0], probs, f1

        def predict(X):
            X = torch.from_numpy(X.astype(np.float32))
            if torch.cuda.is_available():
                X = Variable(X.cuda(), volatile=True)
            else:
                X = Variable(X, volatile=True)
            net.train(False)
            outputs = net(X)  # forward
            probs = outputs.data.cpu().numpy().transpose(
                0, 2, 3, 1)  # (bs, x, y, classes)
            return probs

        def save_model(metrics, epoch_nr):
            max_f1_idx = np.argmax(metrics["f1_macro_validate"])
            max_f1 = np.max(metrics["f1_macro_validate"])
            if epoch_nr == max_f1_idx and max_f1 > 0.01:  # saving to network drives takes 5s (to local only 0.5s) -> do not save so often
                print("  Saving weights...")
                for fl in glob.glob(join(self.HP.EXP_PATH, "best_weights_ep*")
                                    ):  # remove weights from previous epochs
                    os.remove(fl)
                try:
                    # Actually a pkl, not an npz
                    PytorchUtils.save_checkpoint(join(
                        self.HP.EXP_PATH,
                        "best_weights_ep" + str(epoch_nr) + ".npz"),
                                                 unet=net)
                except IOError:
                    print(
                        "\nERROR: Could not save weights because of IO Error\n"
                    )
                self.HP.BEST_EPOCH = epoch_nr

            #Saving Last Epoch:
            # print("  Saving weights last epoch...")
            # for fl in glob.glob(join(self.HP.EXP_PATH, "weights_ep*")):  # remove weights from previous epochs
            #     os.remove(fl)
            # try:
            #     # Actually is a pkl not a npz
            #     PytorchUtils.save_checkpoint(join(self.HP.EXP_PATH, "weights_ep" + str(epoch_nr) + ".npz"), unet=net)
            # except IOError:
            #     print("\nERROR: Could not save weights because of IO Error\n")
            # self.HP.BEST_EPOCH = epoch_nr

        def load_model(path):
            PytorchUtils.load_checkpoint(path, unet=net)

        def print_current_lr():
            for param_group in optimizer.param_groups:
                ExpUtils.print_and_save(
                    self.HP,
                    "current learning rate: {}".format(param_group['lr']))

        if self.HP.SEG_INPUT == "Peaks" and self.HP.TYPE == "single_direction":
            NR_OF_GRADIENTS = self.HP.NR_OF_GRADIENTS
            # NR_OF_GRADIENTS = 9 * 5
            # NR_OF_GRADIENTS = 9 * 9
            # NR_OF_GRADIENTS = 33
        elif self.HP.SEG_INPUT == "Peaks" and self.HP.TYPE == "combined":
            NR_OF_GRADIENTS = 3 * self.HP.NR_OF_CLASSES
        else:
            NR_OF_GRADIENTS = 33

        if torch.cuda.is_available():
            net = UNet(n_input_channels=NR_OF_GRADIENTS,
                       n_classes=self.HP.NR_OF_CLASSES,
                       n_filt=self.HP.UNET_NR_FILT).cuda()
        else:
            net = UNet(n_input_channels=NR_OF_GRADIENTS,
                       n_classes=self.HP.NR_OF_CLASSES,
                       n_filt=self.HP.UNET_NR_FILT)

        # if self.HP.TRAIN:
        #     ExpUtils.print_and_save(self.HP, str(net), only_log=True)

        # criterion1 = PytorchUtils.MSE_weighted
        # criterion2 = PytorchUtils.angle_loss

        # criterion = PytorchUtils.MSE_weighted
        # criterion = PytorchUtils.angle_loss
        criterion = PytorchUtils.angle_length_loss

        optimizer = Adamax(net.parameters(), lr=self.HP.LEARNING_RATE)

        if self.HP.LOAD_WEIGHTS:
            ExpUtils.print_verbose(
                self.HP, "Loading weights ... ({})".format(
                    join(self.HP.EXP_PATH, self.HP.WEIGHTS_PATH)))
            load_model(join(self.HP.EXP_PATH, self.HP.WEIGHTS_PATH))

        self.train = train
        self.predict = test
        self.get_probs = predict
        self.save_model = save_model
        self.load_model = load_model
        self.print_current_lr = print_current_lr
Example #37
0
 def __init__(self, HP, data):
     self.data = data
     self.HP = HP
     ExpUtils.print_verbose(self.HP, "Loading data from PREDICT_IMG input file")
Example #38
0
    def create_network(self):
        # torch.backends.cudnn.benchmark = True     #not faster

        def train(X, y, weight_factor=10):
            X = torch.tensor(X, dtype=torch.float32).to(device)   # X: (bs, features, x, y)   y: (bs, classes, x, y)
            y = torch.tensor(y, dtype=torch.float32).to(device)

            optimizer.zero_grad()
            net.train()
            outputs, outputs_sigmoid = net(X)  # forward     # outputs: (bs, classes, x, y)

            if weight_factor > 1:
                # weights = torch.ones((self.HP.BATCH_SIZE, self.HP.NR_OF_CLASSES, self.HP.INPUT_DIM[0], self.HP.INPUT_DIM[1])).cuda()
                weights = torch.ones((self.HP.BATCH_SIZE, self.HP.NR_OF_CLASSES, y.shape[2], y.shape[3])).cuda()
                bundle_mask = y > 0
                weights[bundle_mask.data] *= weight_factor  # 10
                if self.HP.EXPERIMENT_TYPE == "peak_regression":
                    loss = criterion(outputs, y, weights)
                else:
                    loss = nn.BCEWithLogitsLoss(weight=weights)(outputs, y)
            else:
                if self.HP.LOSS_FUNCTION == "soft_sample_dice" or self.HP.LOSS_FUNCTION == "soft_batch_dice":
                    loss = criterion(outputs_sigmoid, y)
                    # loss = criterion(outputs_sigmoid, y) + nn.BCEWithLogitsLoss()(outputs, y)
                else:
                    loss = criterion(outputs, y)

            loss.backward()  # backward
            optimizer.step()  # optimise

            if self.HP.EXPERIMENT_TYPE == "peak_regression":
                # f1 = PytorchUtils.f1_score_macro(y.data, outputs.data, per_class=True)
                # f1_a = MetricUtils.calc_peak_dice_pytorch(self.HP, outputs.data, y.data, max_angle_error=self.HP.PEAK_DICE_THR)
                f1 = MetricUtils.calc_peak_length_dice_pytorch(self.HP, outputs.detach(), y.detach(),
                                                               max_angle_error=self.HP.PEAK_DICE_THR, max_length_error=self.HP.PEAK_DICE_LEN_THR)
                # f1 = (f1_a, f1_b)
            elif self.HP.EXPERIMENT_TYPE == "dm_regression":   #density map regression
                f1 = PytorchUtils.f1_score_macro(y.detach()>0.5, outputs.detach(), per_class=True)
            else:
                f1 = PytorchUtils.f1_score_macro(y.detach(), outputs_sigmoid.detach(), per_class=True, threshold=self.HP.THRESHOLD)

            if self.HP.USE_VISLOGGER:
                # probs = outputs_sigmoid.detach().cpu().numpy().transpose(0,2,3,1)   # (bs, x, y, classes)
                probs = outputs_sigmoid
            else:
                probs = None    #faster

            return loss.item(), probs, f1


        def test(X, y, weight_factor=10):
            with torch.no_grad():
                X = torch.tensor(X, dtype=torch.float32).to(device)
                y = torch.tensor(y, dtype=torch.float32).to(device)

            if self.HP.DROPOUT_SAMPLING:
                net.train()
            else:
                net.train(False)
            outputs, outputs_sigmoid = net(X)  # forward

            if weight_factor > 1:
                # weights = torch.ones((self.HP.BATCH_SIZE, self.HP.NR_OF_CLASSES, self.HP.INPUT_DIM[0], self.HP.INPUT_DIM[1])).cuda()
                weights = torch.ones((self.HP.BATCH_SIZE, self.HP.NR_OF_CLASSES, y.shape[2], y.shape[3])).cuda()
                bundle_mask = y > 0
                weights[bundle_mask.data] *= weight_factor  # 10
                if self.HP.EXPERIMENT_TYPE == "peak_regression":
                    loss = criterion(outputs, y, weights)
                else:
                    loss = nn.BCEWithLogitsLoss(weight=weights)(outputs, y)
            else:
                if self.HP.LOSS_FUNCTION == "soft_sample_dice" or self.HP.LOSS_FUNCTION == "soft_batch_dice":
                    loss = criterion(outputs_sigmoid, y)
                    # loss = criterion(outputs_sigmoid, y) + nn.BCEWithLogitsLoss()(outputs, y)
                else:
                    loss = criterion(outputs, y)

            if self.HP.EXPERIMENT_TYPE == "peak_regression":
                # f1 = PytorchUtils.f1_score_macro(y.data, outputs.data, per_class=True)
                # f1_a = MetricUtils.calc_peak_dice_pytorch(self.HP, outputs.data, y.data, max_angle_error=self.HP.PEAK_DICE_THR)
                f1 = MetricUtils.calc_peak_length_dice_pytorch(self.HP, outputs.detach(), y.detach(),
                                                               max_angle_error=self.HP.PEAK_DICE_THR, max_length_error=self.HP.PEAK_DICE_LEN_THR)
                # f1 = (f1_a, f1_b)
            elif self.HP.EXPERIMENT_TYPE == "dm_regression":   #density map regression
                f1 = PytorchUtils.f1_score_macro(y.detach()>0.5, outputs.detach(), per_class=True)
            else:
                f1 = PytorchUtils.f1_score_macro(y.detach(), outputs_sigmoid.detach(), per_class=True, threshold=self.HP.THRESHOLD)

            if self.HP.USE_VISLOGGER:
                # probs = outputs_sigmoid.detach().cpu().numpy().transpose(0,2,3,1)   # (bs, x, y, classes)
                probs = outputs_sigmoid
            else:
                probs = None  # faster

            return loss.item(), probs, f1


        def predict(X):
            with torch.no_grad():
                X = torch.tensor(X, dtype=torch.float32).to(device)

            if self.HP.DROPOUT_SAMPLING:
                net.train()
            else:
                net.train(False)
            outputs, outputs_sigmoid = net(X)  # forward
            if self.HP.EXPERIMENT_TYPE == "peak_regression" or self.HP.EXPERIMENT_TYPE == "dm_regression":
                probs = outputs.detach().cpu().numpy().transpose(0,2,3,1)   # (bs, x, y, classes)
            else:
                probs = outputs_sigmoid.detach().cpu().numpy().transpose(0, 2, 3, 1)  # (bs, x, y, classes)
            return probs


        def save_model(metrics, epoch_nr):
            max_f1_idx = np.argmax(metrics["f1_macro_validate"])
            max_f1 = np.max(metrics["f1_macro_validate"])
            if epoch_nr == max_f1_idx and max_f1 > 0.01:  # saving to network drives takes 5s (to local only 0.5s) -> do not save so often
                print("  Saving weights...")
                for fl in glob.glob(join(self.HP.EXP_PATH, "best_weights_ep*")):  # remove weights from previous epochs
                    os.remove(fl)
                try:
                    # Actually a pkl, not an npz
                    PytorchUtils.save_checkpoint(join(self.HP.EXP_PATH, "best_weights_ep" + str(epoch_nr) + ".npz"), unet=net)
                except IOError:
                    print("\nERROR: Could not save weights because of IO Error\n")
                self.HP.BEST_EPOCH = epoch_nr

        def load_model(path):
            PytorchUtils.load_checkpoint(path, unet=net)

        def print_current_lr():
            for param_group in optimizer.param_groups:
                ExpUtils.print_and_save(self.HP, "current learning rate: {}".format(param_group['lr']))


        if self.HP.SEG_INPUT == "Peaks" and self.HP.TYPE == "single_direction":
            NR_OF_GRADIENTS = self.HP.NR_OF_GRADIENTS
            # NR_OF_GRADIENTS = 9
            # NR_OF_GRADIENTS = 9 * 5
            # NR_OF_GRADIENTS = 9 * 9
            # NR_OF_GRADIENTS = 33
        elif self.HP.SEG_INPUT == "Peaks" and self.HP.TYPE == "combined":
            NR_OF_GRADIENTS = self.HP.NR_OF_GRADIENTS = 3 * self.HP.NR_OF_CLASSES
        else:
            NR_OF_GRADIENTS = self.HP.NR_OF_GRADIENTS = 33

        if self.HP.LOSS_FUNCTION == "soft_sample_dice":
            criterion = PytorchUtils.soft_sample_dice
        elif self.HP.LOSS_FUNCTION == "soft_batch_dice":
            criterion = PytorchUtils.soft_batch_dice
        elif self.HP.EXPERIMENT_TYPE == "peak_regression":
            criterion = PytorchUtils.angle_length_loss
        else:
            # weights = torch.ones((self.HP.BATCH_SIZE, self.HP.NR_OF_CLASSES, self.HP.INPUT_DIM[0], self.HP.INPUT_DIM[1])).cuda()
            # weights[:, 5, :, :] *= 10     #CA
            # weights[:, 21, :, :] *= 10    #FX_left
            # weights[:, 22, :, :] *= 10    #FX_right
            # criterion = nn.BCEWithLogitsLoss(weight=weights)
            criterion = nn.BCEWithLogitsLoss()

        NetworkClass = getattr(importlib.import_module("tractseg.models." + self.HP.MODEL), self.HP.MODEL)
        net = NetworkClass(n_input_channels=NR_OF_GRADIENTS, n_classes=self.HP.NR_OF_CLASSES, n_filt=self.HP.UNET_NR_FILT,
                   batchnorm=self.HP.BATCH_NORM, dropout=self.HP.USE_DROPOUT)

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        net = net.to(device)

        # if self.HP.TRAIN:
        #     ExpUtils.print_and_save(self.HP, str(net), only_log=True)

        if self.HP.OPTIMIZER == "Adamax":
            optimizer = Adamax(net.parameters(), lr=self.HP.LEARNING_RATE)
        elif self.HP.OPTIMIZER == "Adam":
            optimizer = Adam(net.parameters(), lr=self.HP.LEARNING_RATE)
            # optimizer = Adam(net.parameters(), lr=self.HP.LEARNING_RATE, weight_decay=self.HP.WEIGHT_DECAY)
        else:
            raise ValueError("Optimizer not defined")

        if self.HP.LR_SCHEDULE:
            scheduler = lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
            # scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode="max")
            self.scheduler = scheduler

        if self.HP.LOAD_WEIGHTS:
            ExpUtils.print_verbose(self.HP, "Loading weights ... ({})".format(join(self.HP.EXP_PATH, self.HP.WEIGHTS_PATH)))
            load_model(join(self.HP.EXP_PATH, self.HP.WEIGHTS_PATH))

        if self.HP.RESET_LAST_LAYER:
            # net.conv_5 = conv2d(self.HP.UNET_NR_FILT, self.HP.NR_OF_CLASSES, kernel_size=1, stride=1, padding=0, bias=True).to(device)
            net.conv_5 = nn.Conv2d(self.HP.UNET_NR_FILT, self.HP.NR_OF_CLASSES, kernel_size=1, stride=1, padding=0, bias=True).to(device)

        self.train = train
        self.predict = test
        self.get_probs = predict
        self.save_model = save_model
        self.load_model = load_model
        self.print_current_lr = print_current_lr
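
The weight_factor branch in train()/test() above simply up-weights voxels that lie inside a bundle before applying BCEWithLogitsLoss. A small standalone sketch of that weighting idea (toy shapes, CPU only, values chosen for illustration):

# Standalone sketch of the voxel weighting used in train()/test() above (toy shapes, CPU only).
import torch
import torch.nn as nn

weight_factor = 10
outputs = torch.randn(2, 3, 8, 8)                  # logits: (bs, classes, x, y)
y = (torch.rand(2, 3, 8, 8) > 0.7).float()         # binary targets

weights = torch.ones_like(y)
weights[y > 0] *= weight_factor                    # bundle voxels contribute 10x to the loss
loss = nn.BCEWithLogitsLoss(weight=weights)(outputs, y)
print(loss.item())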
Example #39
0
    def generate_train_batch(self):
        subjects = self._data[0]
        subject_idx = int(random.uniform(0, len(subjects)))     # len(subjects)-1 not needed because int always rounds to floor

        for i in range(20):
            try:
                if np.random.random() < 0.5:
                    data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "270g_125mm_peaks.nii.gz")).get_data()
                else:
                    data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "90g_125mm_peaks.nii.gz")).get_data()

                # rnd_choice = np.random.random()
                # if rnd_choice < 0.33:
                #     data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "270g_125mm_peaks.nii.gz")).get_data()
                # elif rnd_choice < 0.66:
                #     data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "90g_125mm_peaks.nii.gz")).get_data()
                # else:
                #     data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "12g_125mm_peaks.nii.gz")).get_data()

                seg = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], self.HP.LABELS_FILENAME + ".nii.gz")).get_data()
                break
            except IOError:
                ExpUtils.print_and_save(self.HP, "\n\nWARNING: Could not load file. Trying again in 20s (Try number: " + str(i) + ").\n\n")
            ExpUtils.print_and_save(self.HP, "Sleeping 20s")
            sleep(20)
        # ExpUtils.print_and_save(self.HP, "Successfully loaded input.")

        data = np.nan_to_num(data)    # replace NaNs, otherwise training does not work
        seg = np.nan_to_num(seg)

        data = DatasetUtils.scale_input_to_unet_shape(data, self.HP.DATASET, self.HP.RESOLUTION)    # (x, y, z, channels)
        if self.HP.DATASET in ["HCP_2mm", "HCP_2.5mm", "HCP_32g"]:
            # By using "HCP" but lower resolution scale_input_to_unet_shape will automatically downsample the HCP sized seg_mask to the lower resolution
            seg = DatasetUtils.scale_input_to_unet_shape(seg, "HCP", self.HP.RESOLUTION)
        else:
            seg = DatasetUtils.scale_input_to_unet_shape(seg, self.HP.DATASET, self.HP.RESOLUTION)  # (x, y, z, classes)

        slice_idxs = np.random.choice(data.shape[0], self.BATCH_SIZE, False, None)

        # Randomly sample slice orientation
        slice_direction = int(round(random.uniform(0,2)))

        if slice_direction == 0:
            y = seg[slice_idxs, :, :].astype(self.HP.LABELS_TYPE)
            y = np.array(y).transpose(0, 3, 1, 2)  # nr_classes channel has to be before width and height for DataAugmentation (bs, nr_of_classes, x, y)
        elif slice_direction == 1:
            y = seg[:, slice_idxs, :].astype(self.HP.LABELS_TYPE)
            y = np.array(y).transpose(1, 3, 0, 2)
        elif slice_direction == 2:
            y = seg[:, :, slice_idxs].astype(self.HP.LABELS_TYPE)
            y = np.array(y).transpose(2, 3, 0, 1)


        sw = 5 #slice_window (only odd numbers allowed)
        pad = int((sw-1) / 2)

        data_pad = np.zeros((data.shape[0]+sw-1, data.shape[1]+sw-1, data.shape[2]+sw-1, data.shape[3])).astype(data.dtype)
        data_pad[pad:-pad, pad:-pad, pad:-pad, :] = data   #padded with two slices of zeros on all sides
        batch=[]
        for s_idx in slice_idxs:
            if slice_direction == 0:
                #(s_idx+2)-2:(s_idx+2)+3 = s_idx:s_idx+5
                x = data_pad[s_idx:s_idx+sw, pad:-pad, pad:-pad, :].astype(np.float32)      # (5, y, z, channels)
                x = np.array(x).transpose(0, 3, 1, 2)  # channels dim has to be before width and height for Unet (but after batches)
                x = np.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3]))  # (5*channels, y, z)
                batch.append(x)
            elif slice_direction == 1:
                x = data_pad[pad:-pad, s_idx:s_idx+sw, pad:-pad, :].astype(np.float32)  # (5, y, z, channels)
                x = np.array(x).transpose(1, 3, 0, 2)  # channels dim has to be before width and height for Unet (but after batches)
                x = np.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3]))  # (5*channels, y, z)
                batch.append(x)
            elif slice_direction == 2:
                x = data_pad[pad:-pad, pad:-pad, s_idx:s_idx+sw, :].astype(np.float32)  # (5, y, z, channels)
                x = np.array(x).transpose(2, 3, 0, 1)  # channels dim has to be before width and height for Unet (but after batches)
                x = np.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3]))  # (5*channels, y, z)
                batch.append(x)
        data_dict = {"data": np.array(batch),     # (batch_size, channels, x, y, [z])
                     "seg": y}                    # (batch_size, channels, x, y, [z])

        return data_dict
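
The slice-window stacking above (sw adjacent slices folded into the channel dimension) can be illustrated with a small NumPy-only sketch; the array sizes below are arbitrary toy values.

# NumPy-only illustration of the slice-window stacking used above (arbitrary toy shapes).
import numpy as np

sw, pad = 5, 2                                     # slice window of 5 -> pad 2 slices of zeros per side
data = np.random.rand(10, 12, 14, 9)               # (x, y, z, channels)

data_pad = np.zeros((data.shape[0] + sw - 1, data.shape[1] + sw - 1,
                     data.shape[2] + sw - 1, data.shape[3]), dtype=data.dtype)
data_pad[pad:-pad, pad:-pad, pad:-pad, :] = data

s_idx = 3                                          # slice along the x axis (slice_direction == 0)
x = data_pad[s_idx:s_idx + sw, pad:-pad, pad:-pad, :]            # (5, y, z, channels)
x = x.transpose(0, 3, 1, 2)                                       # (5, channels, y, z)
x = x.reshape(x.shape[0] * x.shape[1], x.shape[2], x.shape[3])    # (5*channels, y, z)
print(x.shape)                                     # (45, 12, 14)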
Example #40
0
class HP:
    EXP_MULTI_NAME = ""  #CV Parent Dir name # leave empty for Single Bundle Experiment
    EXP_NAME = "HCP_TEST"  # HCP_TEST
    MODEL = "UNet_Pytorch"  # UNet_Lasagne / UNet_Pytorch
    EXPERIMENT_TYPE = "tract_segmentation"  # tract_segmentation / endings_segmentation / dm_regression / peak_regression

    NUM_EPOCHS = 250
    DATA_AUGMENTATION = False
    DAUG_SCALE = True
    DAUG_NOISE = True
    DAUG_ELASTIC_DEFORM = True
    DAUG_RESAMPLE = True
    DAUG_ROTATE = False
    DAUG_MIRROR = False
    DAUG_FLIP_PEAKS = False
    DAUG_INFO = "Elastic(90,120)(9,11) - Scale(0.9, 1.5) - CenterDist60 - DownsampScipy(0.5,1) - Gaussian(0,0.05) - Rotate(-0.8,0.8)"
    DATASET = "HCP"  # HCP / HCP_32g / Schizo
    RESOLUTION = "1.25mm"  # 1.25mm (/ 2.5mm)
    FEATURES_FILENAME = "12g90g270g"  # 12g90g270g / 270g_125mm_xyz / 270g_125mm_peaks / 90g_125mm_peaks / 32g_25mm_peaks / 32g_25mm_xyz
    LABELS_FILENAME = ""  # autofilled      #"bundle_peaks/CA"  #IMPORTANT: Adapt BatchGen if 808080              # bundle_masks / bundle_masks_72 / bundle_masks_dm / bundle_peaks      #Only used when using DataManagerNifti
    LOSS_FUNCTION = "default"  # default / soft_batch_dice
    OPTIMIZER = "Adamax"
    CLASSES = "All"  # All / 11 / 20 / CST_right
    NR_OF_GRADIENTS = 9
    NR_OF_CLASSES = len(ExpUtils.get_bundle_names(CLASSES)[1:])
    # NR_OF_CLASSES = 3 * len(ExpUtils.get_bundle_names(CLASSES)[1:])

    INPUT_DIM = (144, 144)  # (80, 80) / (144, 144)
    LOSS_WEIGHT = 1  # 1: no weighting
    LOSS_WEIGHT_LEN = -1  # -1: constant over all epochs
    SLICE_DIRECTION = "y"  # x, y, z  (combined needs z)
    TRAINING_SLICE_DIRECTION = "xyz"  # y / xyz
    INFO = "-"  # Dropout, Deconv, 11bundles, LeakyRelu, PeakDiceThres=0.9
    BATCH_NORM = False
    WEIGHT_DECAY = 0
    USE_DROPOUT = False
    DROPOUT_SAMPLING = False

    # DATASET_FOLDER = "HCP_batches/270g_125mm_bundle_peaks_Y_subset"  # HCP / HCP_batches/XXX / TRACED / HCP_fusion_npy_270g_125mm / HCP_fusion_npy_32g_25mm
    # DATASET_FOLDER = "HCP_batches/270g_125mm_bundle_peaks_XYZ"
    DATASET_FOLDER = "HCP"  # HCP / Schizo
    LABELS_FOLDER = "bundle_masks"  # bundle_masks / bundle_masks_dm
    MULTI_PARENT_PATH = join(C.EXP_PATH, EXP_MULTI_NAME)
    EXP_PATH = join(C.EXP_PATH, EXP_MULTI_NAME, EXP_NAME)  # default path
    BATCH_SIZE = 47  #30/44  #max: #Peak Prediction: 44 #Pytorch: 50  #Lasagne: 56  #Lasagne combined: 42  #Pytorch UpSample: 56   #Pytorch_SE_r16: 45    #Pytorch_SE_r64: 45
    LEARNING_RATE = 0.001  # 0.002 #LR find: 0.000143 ?  # 0.001
    LR_SCHEDULE = False
    UNET_NR_FILT = 64
    LOAD_WEIGHTS = False
    # WEIGHTS_PATH = join(C.EXP_PATH, "HCP100_45B_UNet_x_DM_lr002_slope2_dec992_ep800/best_weights_ep64.npz")    # Can be absolute path or relative like "exp_folder/weights.npz"
    WEIGHTS_PATH = ""  # if empty string: autoloading the best_weights in get_best_weights_path()
    TYPE = "single_direction"  # single_direction / combined
    CV_FOLD = 0
    VALIDATE_SUBJECTS = []
    TRAIN_SUBJECTS = []
    TEST_SUBJECTS = []
    TRAIN = True
    TEST = True
    SEGMENT = False
    GET_PROBS = False
    OUTPUT_MULTIPLE_FILES = False
    RESET_LAST_LAYER = False

    # Peak_regression specific
    PEAK_DICE_THR = [0.95]
    PEAK_DICE_LEN_THR = 0.05
    FLIP_OUTPUT_PEAKS = True  # flip peaks along z axis to make them compatible with MITK

    # For TractSeg.py application
    PREDICT_IMG = False
    PREDICT_IMG_OUTPUT = None
    TRACTSEG_DIR = "tractseg_output"
    KEEP_INTERMEDIATE_FILES = False
    CSD_RESOLUTION = "LOW"  # HIGH / LOW

    #Unimportant / rarely changed:
    LABELS_TYPE = np.int16  # Binary: np.int16, Regression: np.float32
    THRESHOLD = 0.5  # Binary: 0.5, Regression: 0.01 ?
    TEST_TIME_DAUG = False
    USE_VISLOGGER = False  #only works with Python 3
    SAVE_WEIGHTS = True
    SEG_INPUT = "Peaks"  # Gradients/ Peaks
    NR_SLICES = 1  # adapt manually: NR_OF_GRADIENTS in UNet.py and get_batch... in train() and in get_seg_prediction()
    PRINT_FREQ = 20  #20
    NORMALIZE_DATA = True
    NORMALIZE_PER_CHANNEL = False
    BEST_EPOCH = 0
    VERBOSE = True
    CALC_F1 = True
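
A brief, hypothetical sketch of how such an HP object is typically consumed: override a few fields for the experiment at hand and pass the instance on (mirroring `model = BaseModel(HP)` in the run_tractseg example above); the weights path below is illustrative.

# Hypothetical sketch: adapt a few HP fields and hand the config to the model.
hp = HP()
hp.TRAIN = False
hp.LOAD_WEIGHTS = True
hp.WEIGHTS_PATH = "my_experiment/best_weights_ep100.npz"   # illustrative path
# model = BaseModel(hp)   # as done in run_tractseg above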
Example #41
0
 def __init__(self, HP, data):
     self.data = data
     self.HP = HP
     ExpUtils.print_verbose(self.HP,
                            "Loading data from PREDICT_IMG input file")
Example #42
0
 def print_current_lr():
     for param_group in optimizer.param_groups:
         ExpUtils.print_and_save(
             self.HP,
             "current learning rate: {}".format(param_group['lr']))
Example #43
0
    def create_network(self):
        # torch.backends.cudnn.benchmark = True     #not faster

        def train(X, y):
            X = torch.from_numpy(X.astype(np.float32))
            y = torch.from_numpy(y.astype(np.float32))
            if torch.cuda.is_available():
                X, y = Variable(X.cuda()), Variable(y.cuda())  # X: (bs, features, x, y)   y: (bs, classes, x, y)
            else:
                X, y = Variable(X), Variable(y)
            optimizer.zero_grad()
            net.train()

            outputs, intermediate = net(X)  # forward     # outputs: (bs, classes, x, y)

            loss = criterion(outputs, y)
            # loss = PytorchUtils.soft_dice(outputs, y)
            loss.backward()  # backward
            optimizer.step()  # optimise
            f1 = PytorchUtils.f1_score_macro(y.data, outputs.data, per_class=True)

            if self.HP.USE_VISLOGGER:
                probs = outputs.data.cpu().numpy().transpose(0,2,3,1)   # (bs, x, y, classes)
            else:
                probs = None    #faster

            return loss.data[0], probs, f1, intermediate

        def test(X, y):
            X = torch.from_numpy(X.astype(np.float32))
            y = torch.from_numpy(y.astype(np.float32))
            if torch.cuda.is_available():
                X, y = Variable(X.cuda(), volatile=True), Variable(y.cuda(), volatile=True)
            else:
                X, y = Variable(X, volatile=True), Variable(y, volatile=True)
            net.train(False)
            outputs = net(X)  # forward
            loss = criterion(outputs, y)
            # loss = PytorchUtils.soft_dice(outputs, y)
            f1 = PytorchUtils.f1_score_macro(y.data, outputs.data, per_class=True)
            # probs = outputs.data.cpu().numpy().transpose(0,2,3,1)   # (bs, x, y, classes)
            probs = None  # faster
            return loss.data[0], probs, f1

        def predict(X):
            X = torch.from_numpy(X.astype(np.float32))
            if torch.cuda.is_available():
                X = Variable(X.cuda(), volatile=True)
            else:
                X = Variable(X, volatile=True)
            net.train(False)
            outputs = net(X)  # forward
            probs = outputs.data.cpu().numpy().transpose(0,2,3,1)   # (bs, x, y, classes)
            return probs

        def save_model(metrics, epoch_nr):
            max_f1_idx = np.argmax(metrics["f1_macro_validate"])
            max_f1 = np.max(metrics["f1_macro_validate"])
            if epoch_nr == max_f1_idx and max_f1 > 0.01:  # saving to network drives takes 5s (to local only 0.5s) -> do not save so often
                print("  Saving weights...")
                for fl in glob.glob(join(self.HP.EXP_PATH, "best_weights_ep*")):  # remove weights from previous epochs
                    os.remove(fl)
                try:
                    # Actually a pkl, not an npz
                    PytorchUtils.save_checkpoint(join(self.HP.EXP_PATH, "best_weights_ep" + str(epoch_nr) + ".npz"), unet=net)
                except IOError:
                    print("\nERROR: Could not save weights because of IO Error\n")
                self.HP.BEST_EPOCH = epoch_nr

        def load_model(path):
            PytorchUtils.load_checkpoint(path, unet=net)

        def print_current_lr():
            for param_group in optimizer.param_groups:
                ExpUtils.print_and_save(self.HP, "current learning rate: {}".format(param_group['lr']))


        if self.HP.SEG_INPUT == "Peaks" and self.HP.TYPE == "single_direction":
            NR_OF_GRADIENTS = 9
            # NR_OF_GRADIENTS = 9 * 5
        elif self.HP.SEG_INPUT == "Peaks" and self.HP.TYPE == "combined":
            NR_OF_GRADIENTS = 3*self.HP.NR_OF_CLASSES
        else:
            NR_OF_GRADIENTS = 33

        if torch.cuda.is_available():
            net = UNet(n_input_channels=NR_OF_GRADIENTS, n_classes=self.HP.NR_OF_CLASSES, n_filt=self.HP.UNET_NR_FILT).cuda()
        else:
            net = UNet(n_input_channels=NR_OF_GRADIENTS, n_classes=self.HP.NR_OF_CLASSES, n_filt=self.HP.UNET_NR_FILT)

        # net = nn.DataParallel(net, device_ids=[0,1])

        if self.HP.TRAIN:
            ExpUtils.print_and_save(self.HP, str(net), only_log=True)

        criterion = nn.BCEWithLogitsLoss()
        optimizer = Adamax(net.parameters(), lr=self.HP.LEARNING_RATE)
        # optimizer = Adam(net.parameters(), lr=self.HP.LEARNING_RATE)  #very slow (half speed of Adamax) -> strange
        # scheduler = lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.1)
        # scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode="max")

        if self.HP.LOAD_WEIGHTS:
            ExpUtils.print_verbose(self.HP, "Loading weights ... ({})".format(join(self.HP.EXP_PATH, self.HP.WEIGHTS_PATH)))
            load_model(join(self.HP.EXP_PATH, self.HP.WEIGHTS_PATH))

        #plot feature weights
        # weights = list(list(net.children())[0].children())[0].weight.cpu().data.numpy()   # sequential -> conv2d   # (64, 9, 3, 3)
        # weights = weights[:, 0:1, :, :]  # select one input channel to plot       # (64, 1, 3, 3)
        # weights = (weights*100).astype(np.uint8) # can not plot negative values (and if float only 0-1 allowed) -> not good: we remove negatives
        # plot_kernels(weights)

        self.train = train
        self.predict = test
        self.get_probs = predict
        self.save_model = save_model
        self.load_model = load_model
        self.print_current_lr = print_current_lr
        # self.scheduler = scheduler
Example #44
0
    def train(self, HP):

        if HP.USE_VISLOGGER:
            nvl = Nvl(name="Training")

        ExpUtils.print_and_save(HP, socket.gethostname())

        epoch_times = []
        nr_of_updates = 0

        metrics = {}
        for type in ["train", "test", "validate"]:
            metrics_new = {
                "loss_" + type: [0],
                "f1_macro_" + type: [0],
            }
            metrics = dict(list(metrics.items()) + list(metrics_new.items()))

        for epoch_nr in range(HP.NUM_EPOCHS):
            start_time = time.time()
            # current_lr = HP.LEARNING_RATE * (HP.LR_DECAY ** epoch_nr)
            # current_lr = HP.LEARNING_RATE

            batch_gen_time = 0
            data_preparation_time = 0
            network_time = 0
            metrics_time = 0
            saving_time = 0
            plotting_time = 0

            batch_nr = {"train": 0, "test": 0, "validate": 0}

            if HP.LOSS_WEIGHT_LEN == -1:
                weight_factor = float(HP.LOSS_WEIGHT)
            else:
                if epoch_nr < HP.LOSS_WEIGHT_LEN:
                    # weight_factor = -(9./100.) * epoch_nr + 10.   #ep0: 10 -> linear decrease -> ep100: 1
                    weight_factor = -((HP.LOSS_WEIGHT - 1) / float(
                        HP.LOSS_WEIGHT_LEN)) * epoch_nr + float(HP.LOSS_WEIGHT)
                    # weight_factor = -((HP.LOSS_WEIGHT-5)/float(HP.LOSS_WEIGHT_LEN)) * epoch_nr + float(HP.LOSS_WEIGHT)
                else:
                    weight_factor = 1.
                    # weight_factor = 5.

            for type in ["train", "test", "validate"]:
                print_loss = []
                start_time_batch_gen = time.time()

                batch_generator = self.dataManager.get_batches(
                    batch_size=HP.BATCH_SIZE,
                    type=type,
                    subjects=getattr(HP,
                                     type.upper() + "_SUBJECTS"))
                batch_gen_time = time.time() - start_time_batch_gen
                # print("batch_gen_time: {}s".format(batch_gen_time))

                print("Start looping batches...")
                start_time_batch_part = time.time()
                for batch in batch_generator:  #getting next batch takes around 0.14s -> second largest Time part after UNet!

                    start_time_data_preparation = time.time()
                    batch_nr[type] += 1

                    x = batch["data"]  # (bs, nr_of_channels, x, y)
                    y = batch["seg"]  # (bs, nr_of_classes, x, y)
                    # since using new BatchGenerator y is not int anymore but float -> would be good for Pytorch but not Lasagne
                    # y = y.astype(HP.LABELS_TYPE)  #for bundle_peaks regression: is already float -> saves 0.2s/batch if left out

                    data_preparation_time += time.time() - start_time_data_preparation
                    # self.model.learning_rate.set_value(np.float32(current_lr))
                    start_time_network = time.time()
                    if type == "train":
                        nr_of_updates += 1
                        loss, probs, f1 = self.model.train(
                            x, y, weight_factor=weight_factor
                        )  # probs: # (bs, x, y, nrClasses)
                        # loss, probs, f1, intermediate = self.model.train(x, y)
                    elif type == "validate":
                        loss, probs, f1 = self.model.predict(
                            x, y, weight_factor=weight_factor)
                    elif type == "test":
                        loss, probs, f1 = self.model.predict(
                            x, y, weight_factor=weight_factor)
                    network_time += time.time() - start_time_network

                    start_time_metrics = time.time()

                    if HP.CALC_F1:
                        if HP.LABELS_TYPE == np.int16:
                            metrics = MetricUtils.calculate_metrics(
                                metrics,
                                None,
                                None,
                                loss,
                                f1=np.mean(f1),
                                type=type,
                                threshold=HP.THRESHOLD)

                        else:  #Regression
                            # The following (commented-out) lines increase metrics_time by ~30s (vs. <1s without them) and the time per batch by ~1.5s
                            # y_flat = y.transpose(0, 2, 3, 1)  # (bs, x, y, nr_of_classes)
                            # y_flat = np.reshape(y_flat, (-1, y_flat.shape[-1]))  # (bs*x*y, nr_of_classes)
                            # metrics = MetricUtils.calculate_metrics(metrics, y_flat, probs, loss, f1=np.mean(f1), type=type, threshold=HP.THRESHOLD,
                            #                                         f1_per_bundle={"CA": f1[5], "FX_left": f1[23], "FX_right": f1[24]})

                            #Numpy
                            # y_right_order = y.transpose(0, 2, 3, 1)  # (bs, x, y, nr_of_classes)
                            # peak_f1 = MetricUtils.calc_peak_dice(HP, probs, y_right_order)
                            # peak_f1_mean = np.array([s for s in peak_f1.values()]).mean()

                            #Pytorch
                            peak_f1_mean = np.array([
                                s for s in list(f1.values())
                            ]).mean()  #if f1 for multiple bundles
                            metrics = MetricUtils.calculate_metrics(
                                metrics,
                                None,
                                None,
                                loss,
                                f1=peak_f1_mean,
                                type=type,
                                threshold=HP.THRESHOLD)

                            #Pytorch 2 F1
                            # peak_f1_mean_a = np.array([s for s in f1[0].values()]).mean()
                            # peak_f1_mean_b = np.array([s for s in f1[1].values()]).mean()
                            # metrics = MetricUtils.calculate_metrics(metrics, None, None, loss, f1=peak_f1_mean_a, type=type, threshold=HP.THRESHOLD,
                            #                                         f1_per_bundle={"LenF1": peak_f1_mean_b})

                            #Single Bundle
                            # metrics = MetricUtils.calculate_metrics(metrics, None, None, loss, f1=f1["CST_right"][0], type=type, threshold=HP.THRESHOLD,
                            #                                         f1_per_bundle={"Thr1": f1["CST_right"][1], "Thr2": f1["CST_right"][2]})
                            # metrics = MetricUtils.calculate_metrics(metrics, None, None, loss, f1=f1["CST_right"], type=type, threshold=HP.THRESHOLD)
                    else:
                        metrics = MetricUtils.calculate_metrics_onlyLoss(
                            metrics, loss, type=type)

                    metrics_time += time.time() - start_time_metrics

                    print_loss.append(loss)
                    if batch_nr[type] % HP.PRINT_FREQ == 0:
                        time_batch_part = time.time() - start_time_batch_part
                        start_time_batch_part = time.time()
                        ExpUtils.print_and_save(
                            HP,
                            "{} Ep {}, Sp {}, loss {}, t print {}s, t batch {}s"
                            .format(type, epoch_nr,
                                    batch_nr[type] * HP.BATCH_SIZE,
                                    round(np.array(print_loss).mean(), 6),
                                    round(time_batch_part, 3),
                                    round(time_batch_part / HP.PRINT_FREQ, 3)))
                        print_loss = []

                    if HP.USE_VISLOGGER:
                        x_norm = (x - x.min()) / (x.max() - x.min())
                        nvl.show_images(
                            x_norm[0:1, :, :, :].transpose((1, 0, 2, 3)),
                            name="input batch",
                            title="Input batch")  #all channels of one batch
                        probs_shaped = probs[:, :, :, 15:16].transpose(
                            (0, 3, 1, 2))  # (bs, 1, x, y)
                        probs_shaped_bin = (probs_shaped > 0.5).astype(
                            np.int16)
                        nvl.show_images(probs_shaped,
                                        name="predictions",
                                        title="Predictions Probmap")
                        # nvl.show_images(probs_shaped_bin, name="predictions_binary", title="Predictions Binary")

                        # Show GT and Prediction in one image  (bundle: CST)
                        # GREEN: GT; RED: prediction (FP); YELLOW: prediction (TP)
                        combined = np.zeros(
                            (y.shape[0], 3, y.shape[2], y.shape[3]))
                        combined[:, 0:1, :, :] = probs_shaped_bin  #Red
                        combined[:, 1:2, :, :] = y[:, 15:16, :, :]  #Green
                        nvl.show_images(combined,
                                        name="predictions_combined",
                                        title="Combined")

                        # Show feature activations (requires the alternative self.model.train call above that also returns the intermediate activations)
                        contr_1_2 = intermediate[2].data.cpu().numpy(
                        )  # (bs, nr_feature_channels=64, x, y)
                        contr_1_2 = contr_1_2[0:1, :, :, :].transpose(
                            (1, 0, 2, 3))  # (nr_feature_channels=64, 1, x, y)
                        contr_1_2 = (contr_1_2 - contr_1_2.min()) / (
                            contr_1_2.max() - contr_1_2.min())
                        nvl.show_images(contr_1_2,
                                        name="contr_1_2",
                                        title="contr_1_2")

                        # Show feature activations
                        contr_3_2 = intermediate[1].data.cpu().numpy(
                        )  # (bs, nr_feature_channels=64, x, y)
                        contr_3_2 = contr_3_2[0:1, :, :, :].transpose(
                            (1, 0, 2, 3))  # (nr_feature_channels=64, 1, x, y)
                        contr_3_2 = (contr_3_2 - contr_3_2.min()) / (
                            contr_3_2.max() - contr_3_2.min())
                        nvl.show_images(contr_3_2,
                                        name="contr_3_2",
                                        title="contr_3_2")

                        # Show feature activations
                        deconv_2 = intermediate[0].data.cpu().numpy(
                        )  # (bs, nr_feature_channels=64, x, y)
                        deconv_2 = deconv_2[0:1, :, :, :].transpose(
                            (1, 0, 2, 3))  # (nr_feature_channels=64, 1, x, y)
                        deconv_2 = (deconv_2 - deconv_2.min()) / (
                            deconv_2.max() - deconv_2.min())
                        nvl.show_images(deconv_2,
                                        name="deconv_2",
                                        title="deconv_2")

                        nvl.show_value(float(loss), name="loss")
                        nvl.show_value(float(np.mean(f1)), name="f1")

            ###################################
            # Post Training tasks (each epoch)
            ###################################

            #Adapt LR
            # self.model.scheduler.step()
            # self.model.scheduler.step(np.mean(f1))
            # self.model.print_current_lr()

            # Average loss per batch over entire epoch
            metrics = MetricUtils.normalize_last_element(metrics,
                                                         batch_nr["train"],
                                                         type="train")
            metrics = MetricUtils.normalize_last_element(metrics,
                                                         batch_nr["validate"],
                                                         type="validate")
            metrics = MetricUtils.normalize_last_element(metrics,
                                                         batch_nr["test"],
                                                         type="test")

            print("  Epoch {}, Average Epoch loss = {}".format(
                epoch_nr, metrics["loss_train"][-1]))
            print("  Epoch {}, nr_of_updates {}".format(
                epoch_nr, nr_of_updates))

            # Save Weights
            start_time_saving = time.time()
            if HP.SAVE_WEIGHTS:
                self.model.save_model(metrics, epoch_nr)
            saving_time += time.time() - start_time_saving

            # Create Plots
            start_time_plotting = time.time()
            pickle.dump(
                metrics, open(join(HP.EXP_PATH, "metrics.pkl"), "wb")
            )  # wb -> write (overwrite) in binary mode (binary only needed on Windows, on Unix it also works without)  # for loading: pickle.load(open("metrics.pkl", "rb"))
            ExpUtils.create_exp_plot(metrics, HP.EXP_PATH, HP.EXP_NAME)
            ExpUtils.create_exp_plot(metrics,
                                     HP.EXP_PATH,
                                     HP.EXP_NAME,
                                     without_first_epochs=True)
            plotting_time += time.time() - start_time_plotting

            epoch_time = time.time() - start_time
            epoch_times.append(epoch_time)

            ExpUtils.print_and_save(
                HP, "  Epoch {}, time total {}s".format(epoch_nr, epoch_time))
            ExpUtils.print_and_save(
                HP,
                "  Epoch {}, time UNet: {}s".format(epoch_nr, network_time))
            ExpUtils.print_and_save(
                HP,
                "  Epoch {}, time metrics: {}s".format(epoch_nr, metrics_time))
            ExpUtils.print_and_save(
                HP, "  Epoch {}, time saving files: {}s".format(
                    epoch_nr, saving_time))
            ExpUtils.print_and_save(HP, str(datetime.datetime.now()))

            # Adding next Epoch
            if epoch_nr < HP.NUM_EPOCHS - 1:
                metrics = MetricUtils.add_empty_element(metrics)

        ####################################
        # After all epochs
        ###################################
        with open(join(HP.EXP_PATH, "Hyperparameters.txt"),
                  "a") as f:  # a for append
            f.write("\n\n")
            f.write("Average Epoch time: {}s".format(
                sum(epoch_times) / float(len(epoch_times))))

        return metrics
Example #45
0
 def print_current_lr():
     for param_group in optimizer.param_groups:
         ExpUtils.print_and_save(self.HP, "current learning rate: {}".format(param_group['lr']))
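
A short note on what the loop above iterates over (standard PyTorch behaviour; values shown are illustrative):

     # optimizer.param_groups is a list with one dict per parameter group, e.g.
     # [{'params': [...], 'lr': 0.001, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0}]
     # so print_current_lr logs the learning rate of every group.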
Example #46
0
    def calc_peak_length_dice_pytorch(HP,
                                      y_pred,
                                      y_true,
                                      max_angle_error=[0.9],
                                      max_length_error=0.1):
        '''
        Calculate the angle and the length error between groundtruth and predicted peaks and keep the
        voxels where the angle error is below max_angle_error and the relative length error is below
        max_length_error.

        From the groundtruth generate a binary mask by selecting all voxels with len > 0.

        Calculate Dice from these 2 masks.

        :param y_pred:
        :param y_true:
        :param max_angle_error:  0.7 ->  angle error of 45° or less; 0.9 ->  angle error of 23° or less
                                 Can be list with several values -> calculate for several thresholds
        :param max_length_error: maximum allowed relative length error (e.g. 0.1 -> 10% of the groundtruth length)
        :return: dict with one f1 score per bundle
        '''
        import torch
        from tractseg.libs.PytorchEinsum import einsum
        from tractseg.libs.PytorchUtils import PytorchUtils

        y_true = y_true.permute(0, 2, 3, 1)
        y_pred = y_pred.permute(0, 2, 3, 1)

        def angle_last_dim(a, b):
            '''
            Calculate the angle between two nd-arrays (arrays of vectors) along the last dimension.

            Returns the absolute cosine similarity: 1 -> 0°, 0.9 -> ~23°, 0.7 -> ~45°, 0 -> 90°
            (np.arccos would convert this to radians; 90° corresponds to 0.5*pi)

            return: one dimension less than the input
            '''
            return torch.abs(
                einsum('abcd,abcd->abc', a, b) /
                (torch.norm(a, 2., -1) * torch.norm(b, 2, -1) + 1e-7))

        #Single threshold
        score_per_bundle = {}
        bundles = ExpUtils.get_bundle_names(HP.CLASSES)[1:]
        for idx, bundle in enumerate(bundles):
            # if bundle == "CST_right":
            y_pred_bund = y_pred[:, :, :, (idx * 3):(idx * 3) + 3].contiguous()
            y_true_bund = y_true[:, :, :, (idx * 3):(idx * 3) +
                                 3].contiguous()  # [x,y,z,3]

            angles = angle_last_dim(y_pred_bund, y_true_bund)

            lengths_pred = torch.norm(y_pred_bund, 2., -1)
            lengths_true = torch.norm(y_true_bund, 2, -1)
            lengths_binary = torch.abs(lengths_pred - lengths_true) < (
                max_length_error * lengths_true)
            lengths_binary = lengths_binary.view(-1)

            gt_binary = y_true_bund.sum(dim=-1) > 0
            gt_binary = gt_binary.view(-1)  # [bs*x*y]

            angles_binary = angles > max_angle_error[0]
            angles_binary = angles_binary.view(-1)

            combined = lengths_binary * angles_binary

            f1 = PytorchUtils.f1_score_binary(gt_binary, combined)
            score_per_bundle[bundle] = f1
        return score_per_bundle
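
A minimal usage sketch (hypothetical shapes; assumes this function, ExpUtils and a configured HP object
are in scope, and that torch and numpy are installed):

    import numpy as np
    import torch

    nr_bundles = len(ExpUtils.get_bundle_names(HP.CLASSES)[1:])
    y_pred = torch.rand(2, 3 * nr_bundles, 144, 144)   # (bs, 3*nr_bundles, x, y), random dummy peaks
    y_true = torch.rand(2, 3 * nr_bundles, 144, 144)
    scores = calc_peak_length_dice_pytorch(HP, y_pred, y_true, max_angle_error=[0.9], max_length_error=0.1)
    mean_f1 = np.array([float(v) for v in scores.values()]).mean()   # average f1 over all bundles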
Example #47
0
def run_tractseg(data, output_type="tract_segmentation", input_type="peaks",
                 single_orientation=False, verbose=False, dropout_sampling=False, threshold=0.5,
                 bundle_specific_threshold=False, get_probs=False):
    '''
    Run TractSeg

    :param data: input peaks (4D numpy array with shape [x,y,z,9])
    :param output_type: "tract_segmentation" | "endings_segmentation" | "TOM" | "dm_regression"
    :param input_type: "peaks"
    :param verbose: show debugging info
    :param dropout_sampling: create uncertainty map by monte carlo dropout (https://arxiv.org/abs/1506.02142)
    :param threshold: Threshold for converting probability map to binary map
    :param bundle_specific_threshold: Threshold is lower for some bundles which need more sensitivity (CA, CST, FX)
    :param get_probs: Output raw probability map instead of binary map
    :return: 4D numpy array with the output of tractseg
        for tract_segmentation:     [x,y,z,nr_of_bundles]
        for endings_segmentation:   [x,y,z,2*nr_of_bundles]
        for TOM:                    [x,y,z,3*nr_of_bundles]
    '''
    start_time = time.time()

    config = get_config_name(input_type, output_type)
    HP = getattr(importlib.import_module("tractseg.config.PretrainedModels." + config), "HP")()
    HP.VERBOSE = verbose
    HP.TRAIN = False
    HP.TEST = False
    HP.SEGMENT = False
    HP.GET_PROBS = get_probs
    HP.LOAD_WEIGHTS = True
    HP.DROPOUT_SAMPLING = dropout_sampling
    HP.THRESHOLD = threshold

    if bundle_specific_threshold:
        HP.GET_PROBS = True

    if input_type == "peaks":
        if HP.EXPERIMENT_TYPE == "tract_segmentation" and HP.DROPOUT_SAMPLING:
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_tract_segmentation_dropout_v1.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "TractSeg_12g90g270g_125mm_DAugAll_Dropout", "best_weights_ep114.npz")
        elif HP.EXPERIMENT_TYPE == "tract_segmentation":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_tract_segmentation_v1.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "TractSeg_T1_12g90g270g_125mm_DAugAll", "best_weights_ep126.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "TractSeg72_888", "best_weights_ep247.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "TractSeg72_888_SchizoFineT_lr001", "best_weights_ep186.npz")
        elif HP.EXPERIMENT_TYPE == "endings_segmentation":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_endings_segmentation_v2.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "EndingsSeg_12g90g270g_125mm_DAugAll", "best_weights_ep16.npz")
        elif HP.EXPERIMENT_TYPE == "peak_regression":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_peak_regression_v1.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "x_Pretrained_TractSeg_Models/Peaks20_12g90g270g_125mm_DAugSimp_constW5", "best_weights_ep441.npz")  #more oversegmentation with DAug
        elif HP.EXPERIMENT_TYPE == "dm_regression":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_dm_regression_v1.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "DmReg_12g90g270g_125mm_DAugAll_Ubuntu", "best_weights_ep80.npz")
    elif input_type == "T1":
        if HP.EXPERIMENT_TYPE == "tract_segmentation":
            # HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_tract_segmentation_v1.npz")
            HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "TractSeg_T1_125mm_DAugAll", "best_weights_ep142.npz")
        elif HP.EXPERIMENT_TYPE == "endings_segmentation":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_endings_segmentation_v1.npz")
        elif HP.EXPERIMENT_TYPE == "peak_regression":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_peak_regression_v1.npz")
    print("Loading weights from: {}".format(HP.WEIGHTS_PATH))

    if HP.EXPERIMENT_TYPE == "peak_regression":
        HP.NR_OF_CLASSES = 3*len(ExpUtils.get_bundle_names(HP.CLASSES)[1:])
    else:
        HP.NR_OF_CLASSES = len(ExpUtils.get_bundle_names(HP.CLASSES)[1:])
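    # e.g. with 72 bundles (illustrative count): peak_regression -> 3 * 72 = 216 output classes
    # (one x/y/z peak component per bundle), all other experiment types -> 72 output classes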

    if HP.VERBOSE:
        print("Hyperparameters:")
        ExpUtils.print_HPs(HP)

    Utils.download_pretrained_weights(experiment_type=HP.EXPERIMENT_TYPE, dropout_sampling=HP.DROPOUT_SAMPLING)

    data = np.nan_to_num(data)
    # brain_mask = ImgUtils.simple_brain_mask(data)
    # if HP.VERBOSE:
    #     nib.save(nib.Nifti1Image(brain_mask, np.eye(4)), "otsu_brain_mask_DEBUG.nii.gz")

    if input_type == "T1":
        data = np.reshape(data, (data.shape[0], data.shape[1], data.shape[2], 1))
    data, seg_None, bbox, original_shape = DatasetUtils.crop_to_nonzero(data)
    data, transformation = DatasetUtils.pad_and_scale_img_to_square_img(data, target_size=HP.INPUT_DIM[0])

    model = BaseModel(HP)

    if HP.EXPERIMENT_TYPE == "tract_segmentation" or HP.EXPERIMENT_TYPE == "endings_segmentation" or HP.EXPERIMENT_TYPE == "dm_regression":
        if single_orientation:     # mainly needed for testing because of lower RAM requirements
            dataManagerSingle = DataManagerSingleSubjectByFile(HP, data=data)
            trainerSingle = Trainer(model, dataManagerSingle)
            if HP.DROPOUT_SAMPLING or HP.EXPERIMENT_TYPE == "dm_regression" or HP.GET_PROBS:
                seg, img_y = trainerSingle.get_seg_single_img(HP, probs=True, scale_to_world_shape=False, only_prediction=True)
            else:
                seg, img_y = trainerSingle.get_seg_single_img(HP, probs=False, scale_to_world_shape=False, only_prediction=True)
        else:
            seg_xyz, gt = DirectionMerger.get_seg_single_img_3_directions(HP, model, data=data, scale_to_world_shape=False, only_prediction=True)
            if HP.DROPOUT_SAMPLING or HP.EXPERIMENT_TYPE == "dm_regression" or HP.GET_PROBS:
                seg = DirectionMerger.mean_fusion(HP.THRESHOLD, seg_xyz, probs=True)
            else:
                seg = DirectionMerger.mean_fusion(HP.THRESHOLD, seg_xyz, probs=False)

    elif HP.EXPERIMENT_TYPE == "peak_regression":
        dataManagerSingle = DataManagerSingleSubjectByFile(HP, data=data)
        trainerSingle = Trainer(model, dataManagerSingle)
        seg, img_y = trainerSingle.get_seg_single_img(HP, probs=True, scale_to_world_shape=False, only_prediction=True)
        if bundle_specific_threshold:
            seg = ImgUtils.remove_small_peaks_bundle_specific(seg, ExpUtils.get_bundle_names(HP.CLASSES)[1:], len_thr=0.3)
        else:
            seg = ImgUtils.remove_small_peaks(seg, len_thr=0.3)  # set lower for more sensitivity
        # 3-direction fusion for peaks -> not working (?)
        # seg_xyz, gt = DirectionMerger.get_seg_single_img_3_directions(HP, model, data=data, scale_to_world_shape=False, only_prediction=True)
        # seg = DirectionMerger.mean_fusion(HP.THRESHOLD, seg_xyz, probs=True)

    if bundle_specific_threshold and HP.EXPERIMENT_TYPE == "tract_segmentation":
        seg = ImgUtils.probs_to_binary_bundle_specific(seg, ExpUtils.get_bundle_names(HP.CLASSES)[1:])

    #remove following two lines to keep super resolution
    seg = DatasetUtils.cut_and_scale_img_back_to_original_img(seg, transformation)
    seg = DatasetUtils.add_original_zero_padding_again(seg, bbox, original_shape, HP.NR_OF_CLASSES)
    ExpUtils.print_verbose(HP, "Took {}s".format(round(time.time() - start_time, 2)))

    return seg
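
A minimal calling sketch (hypothetical file name; assumes run_tractseg from above is in scope and
nibabel/numpy are installed):

    import nibabel as nib
    import numpy as np

    peaks = nib.load("peaks.nii.gz").get_data()    # (x, y, z, 9) peak image; path is a placeholder
    peaks = np.nan_to_num(peaks)
    seg = run_tractseg(peaks, output_type="tract_segmentation", input_type="peaks")
    print(seg.shape)                               # (x, y, z, nr_of_bundles)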
Example #48
0
    def train(self, HP):

        if HP.USE_VISLOGGER:
            try:
                from trixi.logger.visdom import PytorchVisdomLogger
            except ImportError:
                raise ImportError("HP.USE_VISLOGGER is set, but trixi is not installed")
            trixi = PytorchVisdomLogger(port=8080, auto_start=True)

        ExpUtils.print_and_save(HP, socket.gethostname())

        epoch_times = []
        nr_of_updates = 0

        metrics = {}
        for type in ["train", "test", "validate"]:
            metrics_new = {
                "loss_" + type: [0],
                "f1_macro_" + type: [0],
            }
            metrics = dict(list(metrics.items()) + list(metrics_new.items()))

        for epoch_nr in range(HP.NUM_EPOCHS):
            start_time = time.time()
            # current_lr = HP.LEARNING_RATE * (HP.LR_DECAY ** epoch_nr)
            # current_lr = HP.LEARNING_RATE

            batch_gen_time = 0
            data_preparation_time = 0
            network_time = 0
            metrics_time = 0
            saving_time = 0
            plotting_time = 0

            batch_nr = {
                "train": 0,
                "test": 0,
                "validate": 0
            }

            if HP.LOSS_WEIGHT_LEN == -1:
                weight_factor = float(HP.LOSS_WEIGHT)
            else:
                if epoch_nr < HP.LOSS_WEIGHT_LEN:
                    # weight_factor = -(9./100.) * epoch_nr + 10.   #ep0: 10 -> linear decrease -> ep100: 1
                    weight_factor = -((HP.LOSS_WEIGHT-1)/float(HP.LOSS_WEIGHT_LEN)) * epoch_nr + float(HP.LOSS_WEIGHT)
                    # weight_factor = -((HP.LOSS_WEIGHT-5)/float(HP.LOSS_WEIGHT_LEN)) * epoch_nr + float(HP.LOSS_WEIGHT)
                else:
                    weight_factor = 1.
                    # weight_factor = 5.

            for type in ["train", "test", "validate"]:
                print_loss = []
                start_time_batch_gen = time.time()

                batch_generator = self.dataManager.get_batches(batch_size=HP.BATCH_SIZE,
                                                               type=type, subjects=getattr(HP, type.upper() + "_SUBJECTS"))
                batch_gen_time = time.time() - start_time_batch_gen
                # print("batch_gen_time: {}s".format(batch_gen_time))

                print("Start looping batches...")
                start_time_batch_part = time.time()
                for batch in batch_generator:                   # getting the next batch takes around 0.14s -> second largest time share after the model itself

                    start_time_data_preparation = time.time()
                    batch_nr[type] += 1

                    x = batch["data"] # (bs, nr_of_channels, x, y)
                    y = batch["seg"]  # (bs, nr_of_classes, x, y)
                    # since using new BatchGenerator y is not int anymore but float -> would be good for Pytorch but not Lasagne
                    # y = y.astype(HP.LABELS_TYPE)  #for bundle_peaks regression: is already float -> saves 0.2s/batch if left out

                    data_preparation_time += time.time() - start_time_data_preparation
                    # self.model.learning_rate.set_value(np.float32(current_lr))
                    start_time_network = time.time()
                    if type == "train":
                        nr_of_updates += 1
                        loss, probs, f1 = self.model.train(x, y, weight_factor=weight_factor)    # probs: # (bs, x, y, nrClasses)
                        # loss, probs, f1, intermediate = self.model.train(x, y)
                    elif type == "validate":
                        loss, probs, f1 = self.model.predict(x, y, weight_factor=weight_factor)
                    elif type == "test":
                        loss, probs, f1 = self.model.predict(x, y, weight_factor=weight_factor)
                    network_time += time.time() - start_time_network

                    start_time_metrics = time.time()

                    if HP.CALC_F1:
                        if HP.EXPERIMENT_TYPE == "peak_regression":
                            # The following (commented-out) lines increase metrics_time by ~30s (vs. <1s without them) and the time per batch by ~1.5s
                            # y_flat = y.transpose(0, 2, 3, 1)  # (bs, x, y, nr_of_classes)
                            # y_flat = np.reshape(y_flat, (-1, y_flat.shape[-1]))  # (bs*x*y, nr_of_classes)
                            # metrics = MetricUtils.calculate_metrics(metrics, y_flat, probs, loss, f1=np.mean(f1), type=type, threshold=HP.THRESHOLD,
                            #                                         f1_per_bundle={"CA": f1[5], "FX_left": f1[23], "FX_right": f1[24]})

                            #Numpy
                            # y_right_order = y.transpose(0, 2, 3, 1)  # (bs, x, y, nr_of_classes)
                            # peak_f1 = MetricUtils.calc_peak_dice(HP, probs, y_right_order)
                            # peak_f1_mean = np.array([s for s in peak_f1.values()]).mean()

                            #Pytorch
                            peak_f1_mean = np.array([s for s in list(f1.values())]).mean()  #if f1 for multiple bundles
                            metrics = MetricUtils.calculate_metrics(metrics, None, None, loss, f1=peak_f1_mean, type=type, threshold=HP.THRESHOLD)

                            #Pytorch 2 F1
                            # peak_f1_mean_a = np.array([s for s in f1[0].values()]).mean()
                            # peak_f1_mean_b = np.array([s for s in f1[1].values()]).mean()
                            # metrics = MetricUtils.calculate_metrics(metrics, None, None, loss, f1=peak_f1_mean_a, type=type, threshold=HP.THRESHOLD,
                            #                                         f1_per_bundle={"LenF1": peak_f1_mean_b})

                            #Single Bundle
                            # metrics = MetricUtils.calculate_metrics(metrics, None, None, loss, f1=f1["CST_right"][0], type=type, threshold=HP.THRESHOLD,
                            #                                         f1_per_bundle={"Thr1": f1["CST_right"][1], "Thr2": f1["CST_right"][2]})
                            # metrics = MetricUtils.calculate_metrics(metrics, None, None, loss, f1=f1["CST_right"], type=type, threshold=HP.THRESHOLD)
                        else:
                            metrics = MetricUtils.calculate_metrics(metrics, None, None, loss, f1=np.mean(f1), type=type, threshold=HP.THRESHOLD)

                    else:
                        metrics = MetricUtils.calculate_metrics_onlyLoss(metrics, loss, type=type)

                    metrics_time += time.time() - start_time_metrics

                    print_loss.append(loss)
                    if batch_nr[type] % HP.PRINT_FREQ == 0:
                        time_batch_part = time.time() - start_time_batch_part
                        start_time_batch_part = time.time()
                        ExpUtils.print_and_save(HP, "{} Ep {}, Sp {}, loss {}, t print {}s, t batch {}s".format(type, epoch_nr,
                                                                batch_nr[type] * HP.BATCH_SIZE,
                                                                round(np.array(print_loss).mean(), 6), round(time_batch_part, 3),
                                                                round(time_batch_part / HP.PRINT_FREQ, 3)))
                        print_loss = []

                    if HP.USE_VISLOGGER:
                        ExpUtils.plot_result_trixi(trixi, x, y, probs, loss, f1, epoch_nr)


            ###################################
            # Post Training tasks (each epoch)
            ###################################

            #Adapt LR
            if HP.LR_SCHEDULE:
                self.model.scheduler.step()
                # self.model.scheduler.step(np.mean(f1))
                self.model.print_current_lr()

            # Average loss per batch over entire epoch
            metrics = MetricUtils.normalize_last_element(metrics, batch_nr["train"], type="train")
            metrics = MetricUtils.normalize_last_element(metrics, batch_nr["validate"], type="validate")
            metrics = MetricUtils.normalize_last_element(metrics, batch_nr["test"], type="test")

            print("  Epoch {}, Average Epoch loss = {}".format(epoch_nr, metrics["loss_train"][-1]))
            print("  Epoch {}, nr_of_updates {}".format(epoch_nr, nr_of_updates))

            # Save Weights
            start_time_saving = time.time()
            if HP.SAVE_WEIGHTS:
                self.model.save_model(metrics, epoch_nr)
            saving_time += time.time() - start_time_saving

            # Create Plots
            start_time_plotting = time.time()
            pickle.dump(metrics, open(join(HP.EXP_PATH, "metrics.pkl"), "wb"))  # wb -> write (overwrite) in binary mode (binary only needed on Windows, on Unix it also works without)  # for loading: pickle.load(open("metrics.pkl", "rb"))
            ExpUtils.create_exp_plot(metrics, HP.EXP_PATH, HP.EXP_NAME)
            ExpUtils.create_exp_plot(metrics, HP.EXP_PATH, HP.EXP_NAME, without_first_epochs=True)
            plotting_time += time.time() - start_time_plotting

            epoch_time = time.time() - start_time
            epoch_times.append(epoch_time)

            ExpUtils.print_and_save(HP, "  Epoch {}, time total {}s".format(epoch_nr, epoch_time))
            ExpUtils.print_and_save(HP, "  Epoch {}, time UNet: {}s".format(epoch_nr, network_time))
            ExpUtils.print_and_save(HP, "  Epoch {}, time metrics: {}s".format(epoch_nr, metrics_time))
            ExpUtils.print_and_save(HP, "  Epoch {}, time saving files: {}s".format(epoch_nr, saving_time))
            ExpUtils.print_and_save(HP, str(datetime.datetime.now()))

            # Adding next Epoch
            if epoch_nr < HP.NUM_EPOCHS-1:
                metrics = MetricUtils.add_empty_element(metrics)


        ####################################
        # After all epochs
        ###################################
        with open(join(HP.EXP_PATH, "Hyperparameters.txt"), "a") as f:  # a for append
            f.write("\n\n")
            f.write("Average Epoch time: {}s".format(sum(epoch_times) / float(len(epoch_times))))

        return metrics
Example #49
0
    def calc_peak_dice_pytorch(HP, y_pred, y_true, max_angle_error=[0.9]):
        '''
        Calculate angle between groundtruth and prediction and keep the voxels where
        angle is smaller than MAX_ANGLE_ERROR.

        From groundtruth generate a binary mask by selecting all voxels with len > 0.

        Calculate Dice from these 2 masks.

        -> Penalises peaks outside of the tract and predicted peaks of length 0
        -> No penalty for very small peaks that point in the right direction -> problematic
        => Peak dice can be high even if the peaks inside the tract are almost missing (close to 0)

        :param y_pred:
        :param y_true:
        :param max_angle_error:  0.7 ->  angle error of 45° or less; 0.9 ->  angle error of 23° or less
                                 Can be list with several values -> calculate for several thresholds
        :return:
        '''
        import torch
        from tractseg.libs.PytorchEinsum import einsum
        from tractseg.libs.PytorchUtils import PytorchUtils

        y_true = y_true.permute(0, 2, 3, 1)
        y_pred = y_pred.permute(0, 2, 3, 1)

        def angle_last_dim(a, b):
            '''
            Calculate the angle between two nd-arrays (arrays of vectors) along the last dimension.

            Returns the absolute cosine similarity: 1 -> 0°, 0.9 -> ~23°, 0.7 -> ~45°, 0 -> 90°
            (np.arccos would convert this to radians; 90° corresponds to 0.5*pi)

            return: one dimension less than the input
            '''
            return torch.abs(
                einsum('abcd,abcd->abc', a, b) /
                (torch.norm(a, 2., -1) * torch.norm(b, 2, -1) + 1e-7))

        #Single threshold
        if len(max_angle_error) == 1:
            score_per_bundle = {}
            bundles = ExpUtils.get_bundle_names(HP.CLASSES)[1:]
            for idx, bundle in enumerate(bundles):
                # if bundle == "CST_right":
                y_pred_bund = y_pred[:, :, :,
                                     (idx * 3):(idx * 3) + 3].contiguous()
                y_true_bund = y_true[:, :, :, (idx * 3):(idx * 3) +
                                     3].contiguous()  # [x,y,z,3]

                angles = angle_last_dim(y_pred_bund, y_true_bund)
                gt_binary = y_true_bund.sum(dim=-1) > 0
                gt_binary = gt_binary.view(-1)  # [bs*x*y]

                angles_binary = angles > max_angle_error[0]
                angles_binary = angles_binary.view(-1)

                f1 = PytorchUtils.f1_score_binary(gt_binary, angles_binary)
                score_per_bundle[bundle] = f1

            return score_per_bundle

        #multiple thresholds
        else:
            score_per_bundle = {}
            bundles = ExpUtils.get_bundle_names(HP.CLASSES)[1:]
            for idx, bundle in enumerate(bundles):
                # if bundle == "CST_right":
                y_pred_bund = y_pred[:, :, :,
                                     (idx * 3):(idx * 3) + 3].contiguous()
                y_true_bund = y_true[:, :, :, (idx * 3):(idx * 3) +
                                     3].contiguous()  # [x,y,z,3]

                angles = angle_last_dim(y_pred_bund, y_true_bund)
                gt_binary = y_true_bund.sum(dim=-1) > 0
                gt_binary = gt_binary.view(-1)  # [bs*x*y]

                score_per_bundle[bundle] = []
                for threshold in max_angle_error:
                    angles_binary = angles > threshold
                    angles_binary = angles_binary.view(-1)

                    f1 = PytorchUtils.f1_score_binary(gt_binary, angles_binary)
                    score_per_bundle[bundle].append(f1)

            return score_per_bundle
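
A small sanity check of the angle measure used above (illustrative values; (a * b).sum(-1) is
equivalent to the einsum 'abcd,abcd->abc' used in angle_last_dim):

    import torch

    a = torch.tensor([[[[1., 0., 0.]]]])   # (bs, x, y, 3)
    b = torch.tensor([[[[0., 1., 0.]]]])
    same = torch.abs((a * a).sum(-1) / (torch.norm(a, 2, -1) * torch.norm(a, 2, -1) + 1e-7))  # ~1.0 (identical directions)
    orth = torch.abs((a * b).sum(-1) / (torch.norm(a, 2, -1) * torch.norm(b, 2, -1) + 1e-7))  # ~0.0 (orthogonal directions)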
Example #50
0
    def generate_train_batch(self):
        subjects = self._data[0]
        subject_idx = int(
            random.uniform(0, len(subjects))
        )  # len(subjects)-1 not needed because int always rounds to floor

        for i in range(20):
            try:
                if np.random.random() < 0.5:
                    data = nib.load(
                        join(C.DATA_PATH, self.HP.DATASET_FOLDER,
                             subjects[subject_idx],
                             "270g_125mm_peaks.nii.gz")).get_data()
                else:
                    data = nib.load(
                        join(C.DATA_PATH, self.HP.DATASET_FOLDER,
                             subjects[subject_idx],
                             "90g_125mm_peaks.nii.gz")).get_data()

                # rnd_choice = np.random.random()
                # if rnd_choice < 0.33:
                #     data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "270g_125mm_peaks.nii.gz")).get_data()
                # elif rnd_choice < 0.66:
                #     data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "90g_125mm_peaks.nii.gz")).get_data()
                # else:
                #     data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "12g_125mm_peaks.nii.gz")).get_data()

                seg = nib.load(
                    join(C.DATA_PATH, self.HP.DATASET_FOLDER,
                         subjects[subject_idx],
                         self.HP.LABELS_FILENAME + ".nii.gz")).get_data()
                break
            except IOError:
                ExpUtils.print_and_save(
                    self.HP,
                    "\n\nWARNING: Could not load file. Trying again in 20s (Try number: "
                    + str(i) + ").\n\n")
            ExpUtils.print_and_save(self.HP, "Sleeping 20s")
            sleep(20)
        # ExpUtils.print_and_save(self.HP, "Successfully loaded input.")

        data = np.nan_to_num(data)  # Replace NaNs; needed, otherwise training does not work
        seg = np.nan_to_num(seg)

        data = DatasetUtils.scale_input_to_unet_shape(
            data, self.HP.DATASET, self.HP.RESOLUTION)  # (x, y, z, channels)
        if self.HP.DATASET in ["HCP_2mm", "HCP_2.5mm", "HCP_32g"]:
            # By using "HCP" but lower resolution scale_input_to_unet_shape will automatically downsample the HCP sized seg_mask to the lower resolution
            seg = DatasetUtils.scale_input_to_unet_shape(
                seg, "HCP", self.HP.RESOLUTION)
        else:
            seg = DatasetUtils.scale_input_to_unet_shape(
                seg, self.HP.DATASET, self.HP.RESOLUTION)  # (x, y, z, classes)

        slice_idxs = np.random.choice(data.shape[0], self.BATCH_SIZE, False,
                                      None)

        # Randomly sample slice orientation
        slice_direction = int(round(random.uniform(0, 2)))

        if slice_direction == 0:
            y = seg[slice_idxs, :, :].astype(self.HP.LABELS_TYPE)
            y = np.array(y).transpose(
                0, 3, 1, 2
            )  # nr_classes channel has to be before width and height for DataAugmentation (bs, nr_of_classes, x, y)
        elif slice_direction == 1:
            y = seg[:, slice_idxs, :].astype(self.HP.LABELS_TYPE)
            y = np.array(y).transpose(1, 3, 0, 2)
        elif slice_direction == 2:
            y = seg[:, :, slice_idxs].astype(self.HP.LABELS_TYPE)
            y = np.array(y).transpose(2, 3, 0, 1)

        sw = 5  #slice_window (only odd numbers allowed)
        pad = int((sw - 1) / 2)

        data_pad = np.zeros(
            (data.shape[0] + sw - 1, data.shape[1] + sw - 1,
             data.shape[2] + sw - 1, data.shape[3])).astype(data.dtype)
        data_pad[
            pad:-pad, pad:-pad,
            pad:-pad, :] = data  #padded with two slices of zeros on all sides
        batch = []
        for s_idx in slice_idxs:
            if slice_direction == 0:
                #(s_idx+2)-2:(s_idx+2)+3 = s_idx:s_idx+5
                x = data_pad[s_idx:s_idx + sw:, pad:-pad, pad:-pad, :].astype(
                    np.float32)  # (5, y, z, channels)
                x = np.array(x).transpose(
                    0, 3, 1, 2
                )  # channels dim has to be before width and height for Unet (but after batches)
                x = np.reshape(x, (x.shape[0] * x.shape[1], x.shape[2],
                                   x.shape[3]))  # (5*channels, y, z)
                batch.append(x)
            elif slice_direction == 1:
                x = data_pad[pad:-pad, s_idx:s_idx + sw, pad:-pad, :].astype(
                    np.float32)  # (5, y, z, channels)
                x = np.array(x).transpose(
                    1, 3, 0, 2
                )  # channels dim has to be before width and height for Unet (but after batches)
                x = np.reshape(x, (x.shape[0] * x.shape[1], x.shape[2],
                                   x.shape[3]))  # (5*channels, y, z)
                batch.append(x)
            elif slice_direction == 2:
                x = data_pad[pad:-pad, pad:-pad, s_idx:s_idx + sw, :].astype(
                    np.float32)  # (5, y, z, channels)
                x = np.array(x).transpose(
                    2, 3, 0, 1
                )  # channels dim has to be before width and height for Unet (but after batches)
                x = np.reshape(x, (x.shape[0] * x.shape[1], x.shape[2],
                                   x.shape[3]))  # (5*channels, y, z)
                batch.append(x)
        data_dict = {
            "data": np.array(batch),  # (batch_size, channels, x, y, [z])
            "seg": y
        }  # (batch_size, channels, x, y, [z])

        return data_dict
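
A shape-only sketch of the slice-window stacking above (dummy sizes; real volumes are larger):

    import numpy as np

    data = np.zeros((10, 10, 10, 9), dtype=np.float32)    # (x, y, z, channels), dummy volume
    sw, pad = 5, 2
    data_pad = np.zeros((14, 14, 14, 9), dtype=np.float32)
    data_pad[pad:-pad, pad:-pad, pad:-pad, :] = data
    x = data_pad[3:3 + sw, pad:-pad, pad:-pad, :]          # (5, 10, 10, 9) for slice_direction == 0
    x = x.transpose(0, 3, 1, 2).reshape(5 * 9, 10, 10)     # (45, 10, 10): 5 neighbouring slices stacked into the channel dim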
Example #51
0
    def create_network(self):
        # torch.backends.cudnn.benchmark = True     #not faster

        def train(X, y):
            X = torch.from_numpy(X.astype(np.float32))
            y = torch.from_numpy(y.astype(np.float32))
            if torch.cuda.is_available():
                X, y = Variable(X.cuda()), Variable(y.cuda(
                ))  # X: (bs, features, x, y)   y: (bs, classes, x, y)
            else:
                X, y = Variable(X), Variable(y)
            optimizer.zero_grad()
            net.train()
            outputs = net(X)  # forward     # outputs: (bs, classes, x, y)
            loss = criterion(outputs, y)
            # loss = PytorchUtils.soft_dice(outputs, y)
            loss.backward()  # backward
            optimizer.step()  # optimise
            f1 = PytorchUtils.f1_score_macro(y.data,
                                             outputs.data,
                                             per_class=True)

            if self.HP.USE_VISLOGGER:
                probs = outputs.data.cpu().numpy().transpose(
                    0, 2, 3, 1)  # (bs, x, y, classes)
            else:
                probs = None  #faster

            return loss.data[0], probs, f1

        def test(X, y):
            X = torch.from_numpy(X.astype(np.float32))
            y = torch.from_numpy(y.astype(np.float32))
            if torch.cuda.is_available():
                X, y = Variable(X.cuda(),
                                volatile=True), Variable(y.cuda(),
                                                         volatile=True)
            else:
                X, y = Variable(X, volatile=True), Variable(y, volatile=True)
            net.train(False)
            outputs = net(X)  # forward
            loss = criterion(outputs, y)
            # loss = PytorchUtils.soft_dice(outputs, y)
            f1 = PytorchUtils.f1_score_macro(y.data,
                                             outputs.data,
                                             per_class=True)
            # probs = outputs.data.cpu().numpy().transpose(0,2,3,1)   # (bs, x, y, classes)
            probs = None  # faster
            return loss.data[0], probs, f1

        def predict(X):
            X = torch.from_numpy(X.astype(np.float32))
            if torch.cuda.is_available():
                X = Variable(X.cuda(), volatile=True)
            else:
                X = Variable(X, volatile=True)
            net.train(False)
            outputs = net(X)  # forward
            probs = outputs.data.cpu().numpy().transpose(
                0, 2, 3, 1)  # (bs, x, y, classes)
            return probs

        def save_model(metrics, epoch_nr):
            max_f1_idx = np.argmax(metrics["f1_macro_validate"])
            max_f1 = np.max(metrics["f1_macro_validate"])
            if epoch_nr == max_f1_idx and max_f1 > 0.01:  # saving to network drives takes 5s (to local only 0.5s) -> do not save so often
                print("  Saving weights...")
                for fl in glob.glob(join(self.HP.EXP_PATH, "best_weights_ep*")
                                    ):  # remove weights from previous epochs
                    os.remove(fl)
                try:
                    # Actually a pickle file, not an npz
                    PytorchUtils.save_checkpoint(join(
                        self.HP.EXP_PATH,
                        "best_weights_ep" + str(epoch_nr) + ".npz"),
                                                 unet=net)
                except IOError:
                    print(
                        "\nERROR: Could not save weights because of IO Error\n"
                    )
                self.HP.BEST_EPOCH = epoch_nr

        def load_model(path):
            PytorchUtils.load_checkpoint(path, unet=net)

        def print_current_lr():
            for param_group in optimizer.param_groups:
                ExpUtils.print_and_save(
                    self.HP,
                    "current learning rate: {}".format(param_group['lr']))

        if self.HP.SEG_INPUT == "Peaks" and self.HP.TYPE == "single_direction":
            NR_OF_GRADIENTS = 9
            # NR_OF_GRADIENTS = 9 * 5
            # NR_OF_GRADIENTS = 9 * 9
            # NR_OF_GRADIENTS = 33
        elif self.HP.SEG_INPUT == "Peaks" and self.HP.TYPE == "combined":
            NR_OF_GRADIENTS = 3 * self.HP.NR_OF_CLASSES
        else:
            NR_OF_GRADIENTS = 33

        if torch.cuda.is_available():
            net = UNet(n_input_channels=NR_OF_GRADIENTS,
                       n_classes=self.HP.NR_OF_CLASSES,
                       n_filt=self.HP.UNET_NR_FILT).cuda()
        else:
            net = UNet(n_input_channels=NR_OF_GRADIENTS,
                       n_classes=self.HP.NR_OF_CLASSES,
                       n_filt=self.HP.UNET_NR_FILT)

        #Initialisation from U-Net Paper
        def weights_init(m):
            classname = m.__class__.__name__
            # Do not use with batchnorm -> has to be adapted for batchnorm
            if classname.find('Conv') != -1:
                N = m.in_channels * m.kernel_size[0] * m.kernel_size[0]
                std = math.sqrt(2. / N)
                m.weight.data.normal_(0.0, std)
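                # e.g. for a conv layer with 9 input channels and a 3x3 kernel (illustrative values):
                # N = 9 * 3 * 3 = 81, std = sqrt(2 / 81) ≈ 0.157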

        net.apply(weights_init)

        # net = nn.DataParallel(net, device_ids=[0,1])

        if self.HP.TRAIN:
            ExpUtils.print_and_save(self.HP, str(net), only_log=True)

        criterion = nn.BCEWithLogitsLoss()
        optimizer = Adamax(net.parameters(), lr=self.HP.LEARNING_RATE)
        # optimizer = Adam(net.parameters(), lr=self.HP.LEARNING_RATE)  #very slow (half speed of Adamax) -> strange
        # scheduler = lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.1)
        # scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode="max")

        if self.HP.LOAD_WEIGHTS:
            ExpUtils.print_verbose(
                self.HP, "Loading weights ... ({})".format(
                    join(self.HP.EXP_PATH, self.HP.WEIGHTS_PATH)))
            load_model(join(self.HP.EXP_PATH, self.HP.WEIGHTS_PATH))

        self.train = train
        self.predict = test
        self.get_probs = predict
        self.save_model = save_model
        self.load_model = load_model
        self.print_current_lr = print_current_lr
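
A minimal smoke-test sketch for the functions bound above (hypothetical shapes; assumes model is an
instance on which create_network() has already run and that the input size matches HP.INPUT_DIM):

    import numpy as np

    X = np.random.randn(2, 9, 144, 144).astype(np.float32)                              # (bs, features, x, y)
    Y = (np.random.rand(2, model.HP.NR_OF_CLASSES, 144, 144) > 0.5).astype(np.float32)  # (bs, classes, x, y)
    loss, probs, f1 = model.train(X, Y)     # one optimisation step
    probs = model.get_probs(X)              # (bs, x, y, classes)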