def save_multilabel_img_as_multiple_files_endings(HP, img, affine, path):
    """Save each bundle channel of an endings segmentation as its own nifti file.

    Files are written to <path>/endings_segmentations/<bundle>.nii.gz.

    :param HP: config object (reads HP.CLASSES)
    :param img: 4D array (x, y, z, nr_of_bundles); channel per bundle
    :param affine: nifti affine used for every output image
    :param path: base output directory
    """
    bundles = ExpUtils.get_bundle_names(HP.CLASSES)[1:]  # drop background class
    # Create the output directory once, instead of once per bundle as before.
    ExpUtils.make_dir(join(path, "endings_segmentations"))
    for idx, bundle in enumerate(bundles):
        img_seg = nib.Nifti1Image(img[:, :, :, idx], affine)
        nib.save(img_seg, join(path, "endings_segmentations", bundle + ".nii.gz"))
def load_model(path):
    """Load Lasagne network weights from a numpy .npz archive.

    :param path: path to a .npz file whose arrays are named arr_0, arr_1, ...
                 (the standard `np.savez` naming for positional arrays)
    """
    # NOTE(review): `self` is referenced although the signature has no `self`
    # parameter — this looks like it was lifted out of a method; confirm callers.
    ExpUtils.print_verbose(self.HP, "Loading weights ... ({})".format(path))
    with np.load(
            path
    ) as f:  #if both pathes are absolute and beginning of pathes are the same, join will merge the beginning
        # Collect the stored arrays in index order to rebuild the parameter list.
        param_values = [f['arr_%d' % i] for i in range(len(f.files))]
    # Push the loaded parameters into the Lasagne graph rooted at
    # `output_layer_for_loss` (defined elsewhere in this module — TODO confirm).
    L.layers.set_all_param_values(output_layer_for_loss, param_values)
def save_fusion_nifti_as_npy():
    """Convert fused probability-map niftis and their label masks to .npy files.

    For every subject, loads <subject>_probmap.nii.gz from the network drive,
    cleans NaNs, rescales to the UNet input shape, crops one trailing voxel in
    x and z, and saves the result as <DIFFUSION_FOLDER>_xyz.npy. The matching
    bundle_masks labels are rescaled the same way and saved as bundle_masks.npy.
    """

    #Can leave this always the same (for 270g and 32g)
    class HP:
        # Minimal inline config consumed by the scaling helpers below.
        DATASET = "HCP"
        RESOLUTION = "1.25mm"
        FEATURES_FILENAME = "270g_125mm_peaks"
        LABELS_TYPE = np.int16
        LABELS_FILENAME = "bundle_masks"
        DATASET_FOLDER = "HCP"

    #change this for 270g and 32g
    DIFFUSION_FOLDER = "32g_25mm"

    subjects = get_all_subjects()
    # fold0 = ['687163', '685058', '683256', '680957', '679568', '677968', '673455', '672756', '665254', '654754', '645551', '644044', '638049', '627549', '623844', '622236', '620434', '613538', '601127', '599671', '599469']
    # fold1 = ['992774', '991267', '987983', '984472', '983773', '979984', '978578', '965771', '965367', '959574', '958976', '957974', '951457', '932554', '930449', '922854', '917255', '912447', '910241', '907656', '904044']
    # fold2 = ['901442', '901139', '901038', '899885', '898176', '896879', '896778', '894673', '889579', '887373', '877269', '877168', '872764', '872158', '871964', '871762', '865363', '861456', '859671', '857263', '856766']
    # fold3 = ['849971', '845458', '837964', '837560', '833249', '833148', '826454', '826353', '816653', '814649', '802844', '792766', '792564', '789373', '786569', '784565', '782561', '779370', '771354', '770352', '765056']
    # fold4 = ['761957', '759869', '756055', '753251', '751348', '749361', '748662', '748258', '742549', '734045', '732243', '729557', '729254', '715647', '715041', '709551', '705341', '704238', '702133', '695768', '690152']
    # subjects = fold2 + fold3 + fold4
    # subjects = ['654754', '645551', '644044', '638049', '627549', '623844', '622236', '620434', '613538', '601127', '599671', '599469']

    print("\n\nProcessing Data...")
    for s in subjects:
        print("processing data subject {}".format(s))
        start_time = time.time()
        data = nib.load(join(C.NETWORK_DRIVE, "HCP_fusion_" + DIFFUSION_FOLDER,
                             s + "_probmap.nii.gz")).get_data()
        print("Done Loading")
        data = np.nan_to_num(data)
        data = DatasetUtils.scale_input_to_unet_shape(data, HP.DATASET, HP.RESOLUTION)
        # cut one pixel at the end, because in scale_input_to_world_shape we
        # outputted 146 -> one too much at the end
        data = data[:-1, :, :-1, :]
        ExpUtils.make_dir(join(C.NETWORK_DRIVE, "HCP_fusion_npy_" + DIFFUSION_FOLDER, s))
        np.save(join(C.NETWORK_DRIVE, "HCP_fusion_npy_" + DIFFUSION_FOLDER, s,
                     DIFFUSION_FOLDER + "_xyz.npy"), data)
        print("Took {}s".format(time.time() - start_time))

        print("processing seg subject {}".format(s))
        start_time = time.time()
        # seg = ImgUtils.create_multilabel_mask(HP, s, labels_type=HP.LABELS_TYPE)
        seg = nib.load(join(C.NETWORK_DRIVE, "HCP_for_training_COPY", s,
                            HP.LABELS_FILENAME + ".nii.gz")).get_data()
        if HP.RESOLUTION == "2.5mm":
            # labels are stored at high resolution -> downsample to match the data
            seg = ImgUtils.resize_first_three_dims(seg, order=0, zoom=0.5)
        seg = DatasetUtils.scale_input_to_unet_shape(seg, HP.DATASET, HP.RESOLUTION)
        np.save(join(C.NETWORK_DRIVE, "HCP_fusion_npy_" + DIFFUSION_FOLDER, s,
                     "bundle_masks.npy"), seg)
        print("Took {}s".format(time.time() - start_time))
def precompute_batches(custom_type=None):
    '''
    Precompute training/validation/test batches and save them as nifti files.

    9000 slices per epoch -> 200 batches (batchsize=44) per epoch
    => 200-1000 batches needed

    270g_125mm_bundle_peaks_Y: no DAug, no Norm, only Y
    All_sizes_DAug_XYZ: 12g, 90g, 270g, DAug (no rotation, no elastic deform), Norm, XYZ
    270g_125mm_bundle_peaks_XYZ: no DAug, Norm, XYZ

    :param custom_type: if given ("train"/"validate"/"test"), only precompute
                        batches for that split; otherwise process all three
    '''
    class HP:
        # Minimal inline config consumed by the batch generator below.
        NORMALIZE_DATA = True
        DATA_AUGMENTATION = False
        CV_FOLD = 0
        INPUT_DIM = (144, 144)
        BATCH_SIZE = 44
        DATASET_FOLDER = "HCP"
        TYPE = "single_direction"
        EXP_PATH = "~"
        LABELS_FILENAME = "bundle_peaks"
        FEATURES_FILENAME = "270g_125mm_peaks"
        DATASET = "HCP"
        RESOLUTION = "1.25mm"
        LABELS_TYPE = np.float32

    # Subject split for the chosen cross-validation fold.
    HP.TRAIN_SUBJECTS, HP.VALIDATE_SUBJECTS, HP.TEST_SUBJECTS = ExpUtils.get_cv_fold(HP.CV_FOLD)

    num_batches_base = 5000
    # validation/test get a third of the training batch count
    num_batches = {
        "train": num_batches_base,
        "validate": int(num_batches_base / 3.),
        "test": int(num_batches_base / 3.),
    }

    if custom_type is None:
        types = ["train", "validate", "test"]
    else:
        types = [custom_type]

    for type in types:
        dataManager = DataManagerTrainingNiftiImgs(HP)
        # subjects for this split are looked up dynamically, e.g. HP.TRAIN_SUBJECTS
        batch_gen = dataManager.get_batches(batch_size=HP.BATCH_SIZE, type=type,
                                            subjects=getattr(HP, type.upper() + "_SUBJECTS"),
                                            num_batches=num_batches[type])
        for idx, batch in enumerate(batch_gen):
            print("Processing: {}".format(idx))

            # DATASET_DIR = "HCP_batches/270g_125mm_bundle_peaks_Y"
            # DATASET_DIR = "HCP_batches/All_sizes_DAug_XYZ"
            DATASET_DIR = "HCP_batches/270g_125mm_bundle_peaks_XYZ"
            ExpUtils.make_dir(join(C.HOME, DATASET_DIR, type))
            # Save each batch's data and seg as separate nifti files.
            data = nib.Nifti1Image(batch["data"], ImgUtils.get_dwi_affine(HP.DATASET, HP.RESOLUTION))
            nib.save(data, join(C.HOME, DATASET_DIR, type, "batch_" + str(idx) + "_data.nii.gz"))
            seg = nib.Nifti1Image(batch["seg"], ImgUtils.get_dwi_affine(HP.DATASET, HP.RESOLUTION))
            nib.save(seg, join(C.HOME, DATASET_DIR, type, "batch_" + str(idx) + "_seg.nii.gz"))
def save_multilabel_img_as_multiple_files_peaks(HP, img, affine, path):
    """Save each bundle's 3-channel peak map (TOM) as its own nifti file.

    Files are written to <path>/TOM/<bundle>.nii.gz (suffix "_f" when the
    z axis is flipped for MITK viewing).

    :param HP: config object (reads HP.CLASSES and HP.FLIP_OUTPUT_PEAKS)
    :param img: 4D array (x, y, z, nr_of_bundles * 3); 3 channels per bundle
    :param affine: nifti affine used for every output image
    :param path: base output directory
    """
    bundles = ExpUtils.get_bundle_names(HP.CLASSES)[1:]  # drop background class
    # Create the output directory once, instead of once per bundle as before.
    ExpUtils.make_dir(join(path, "TOM"))
    for idx, bundle in enumerate(bundles):
        data = img[:, :, :, (idx * 3):(idx * 3) + 3]
        if HP.FLIP_OUTPUT_PEAKS:
            # BUGFIX: `data` was a view into `img`, so flipping in place mutated
            # the caller's array (and double-flipped on repeated calls). Copy first.
            data = data.copy()
            data[:, :, :, 2] *= -1  # flip z Axis for correct view in MITK
            filename = bundle + "_f.nii.gz"
        else:
            filename = bundle + ".nii.gz"
        img_seg = nib.Nifti1Image(data, affine)
        nib.save(img_seg, join(path, "TOM", filename))
def save_multilabel_img_as_multiple_files_peaks(HP, img, affine, path):
    """Write one 3-channel peak (TOM) nifti per bundle to <path>/TOM/.

    :param HP: config object (reads HP.CLASSES and HP.FLIP_OUTPUT_PEAKS)
    :param img: 4D array (x, y, z, nr_of_bundles * 3)
    :param affine: nifti affine for the outputs
    :param path: base output directory
    """
    bundles = ExpUtils.get_bundle_names(HP.CLASSES)[1:]  # skip background
    ExpUtils.make_dir(join(path, "TOM"))  # hoisted: was re-created every iteration
    for idx, bundle in enumerate(bundles):
        data = img[:, :, :, (idx*3):(idx*3)+3]
        if HP.FLIP_OUTPUT_PEAKS:
            # BUGFIX: the slice is a view of `img`; flipping in place corrupted
            # the caller's array. Work on a copy instead.
            data = data.copy()
            data[:, :, :, 2] *= -1  # flip z Axis for correct view in MITK
            filename = bundle + "_f.nii.gz"
        else:
            filename = bundle + ".nii.gz"
        img_seg = nib.Nifti1Image(data, affine)
        nib.save(img_seg, join(path, "TOM", filename))
def calc_peak_dice_onlySeg(HP, y_pred, y_true):
    '''
    Create binary mask of peaks by simple thresholding. Then calculate Dice.

    :param y_pred: predicted peaks (x, y, z, nr_of_bundles * 3)
    :param y_true: groundtruth peaks, same layout
    :return: dict bundle name -> f1 score
    '''
    scores = {}
    for bundle_idx, bundle_name in enumerate(ExpUtils.get_bundle_names(HP.CLASSES)[1:]):
        offset = bundle_idx * 3
        pred_peaks = y_pred[:, :, :, offset:offset + 3]
        true_peaks = y_true[:, :, :, offset:offset + 3]  # [x,y,z,3]

        # 0.1 -> keep some outliers, but also some holes already; 0.2 also ok
        # (still looks like e.g. CST). Resulting dice for 0.1 and 0.2 very similar.
        pred_mask = np.abs(pred_peaks).sum(axis=-1) > 0.2
        true_mask = np.abs(true_peaks).sum(axis=-1) > 1e-3

        scores[bundle_name] = f1_score(true_mask.flatten(), pred_mask.flatten(),
                                       average="binary")
    return scores
def create_multilabel_mask(HP, subject, labels_type=np.int16, dataset_folder="HCP", labels_folder="bundle_masks"):
    '''
    One-hot encoding of all bundles in one big image

    :param subject:
    :return: image of shape (x, y, z, nr_of_bundles + 1)
    '''
    bundle_names = ExpUtils.get_bundle_names(HP.CLASSES)
    # Masks are always at HCP high resolution (downsampling happens later).
    shape_3d = (145, 174, 145)
    mask_ml = np.zeros(shape_3d + (len(bundle_names),))
    # Channel 0 is background: start all-ones and punch out every bundle voxel.
    background = np.ones(shape_3d)

    # The first bundle name is the background label itself -> skip it here.
    for channel, bundle in enumerate(bundle_names[1:], start=1):
        nii = nib.load(join(C.HOME, dataset_folder, subject, labels_folder,
                            bundle + ".nii.gz"))
        mask_data = nii.get_data()  # dtype: uint8
        mask_ml[:, :, :, channel] = mask_data
        background[mask_data == 1] = 0  # remove this bundle from background

    mask_ml[:, :, :, 0] = background
    return mask_ml.astype(labels_type)
def save_fusion_nifti_as_npy():
    """Convert fused probability-map niftis and their label masks to .npy files.

    Per subject: load <subject>_probmap.nii.gz from the network drive, clean
    NaNs, rescale to the UNet input shape, crop one trailing voxel in x and z,
    and save as <DIFFUSION_FOLDER>_xyz.npy; labels are rescaled the same way
    and saved as bundle_masks.npy.
    """

    #Can leave this always the same (for 270g and 32g)
    class HP:
        # Minimal inline config consumed by the scaling helpers below.
        DATASET = "HCP"
        RESOLUTION = "1.25mm"
        FEATURES_FILENAME = "270g_125mm_peaks"
        LABELS_TYPE = np.int16
        LABELS_FILENAME = "bundle_masks"
        DATASET_FOLDER = "HCP"

    #change this for 270g and 32g
    DIFFUSION_FOLDER = "32g_25mm"

    subjects = get_all_subjects()
    # fold0 = ['687163', '685058', '683256', '680957', '679568', '677968', '673455', '672756', '665254', '654754', '645551', '644044', '638049', '627549', '623844', '622236', '620434', '613538', '601127', '599671', '599469']
    # fold1 = ['992774', '991267', '987983', '984472', '983773', '979984', '978578', '965771', '965367', '959574', '958976', '957974', '951457', '932554', '930449', '922854', '917255', '912447', '910241', '907656', '904044']
    # fold2 = ['901442', '901139', '901038', '899885', '898176', '896879', '896778', '894673', '889579', '887373', '877269', '877168', '872764', '872158', '871964', '871762', '865363', '861456', '859671', '857263', '856766']
    # fold3 = ['849971', '845458', '837964', '837560', '833249', '833148', '826454', '826353', '816653', '814649', '802844', '792766', '792564', '789373', '786569', '784565', '782561', '779370', '771354', '770352', '765056']
    # fold4 = ['761957', '759869', '756055', '753251', '751348', '749361', '748662', '748258', '742549', '734045', '732243', '729557', '729254', '715647', '715041', '709551', '705341', '704238', '702133', '695768', '690152']
    # subjects = fold2 + fold3 + fold4
    # subjects = ['654754', '645551', '644044', '638049', '627549', '623844', '622236', '620434', '613538', '601127', '599671', '599469']

    print("\n\nProcessing Data...")
    for s in subjects:
        print("processing data subject {}".format(s))
        start_time = time.time()
        data = nib.load(join(C.NETWORK_DRIVE, "HCP_fusion_" + DIFFUSION_FOLDER,
                             s + "_probmap.nii.gz")).get_data()
        print("Done Loading")
        data = np.nan_to_num(data)
        data = DatasetUtils.scale_input_to_unet_shape(data, HP.DATASET, HP.RESOLUTION)
        # cut one pixel at the end, because in scale_input_to_world_shape we
        # outputted 146 -> one too much at the end
        data = data[:-1, :, :-1, :]
        ExpUtils.make_dir(join(C.NETWORK_DRIVE, "HCP_fusion_npy_" + DIFFUSION_FOLDER, s))
        np.save(join(C.NETWORK_DRIVE, "HCP_fusion_npy_" + DIFFUSION_FOLDER, s,
                     DIFFUSION_FOLDER + "_xyz.npy"), data)
        print("Took {}s".format(time.time() - start_time))

        print("processing seg subject {}".format(s))
        start_time = time.time()
        # seg = ImgUtils.create_multilabel_mask(HP, s, labels_type=HP.LABELS_TYPE)
        seg = nib.load(join(C.NETWORK_DRIVE, "HCP_for_training_COPY", s,
                            HP.LABELS_FILENAME + ".nii.gz")).get_data()
        if HP.RESOLUTION == "2.5mm":
            # labels are stored at high resolution -> downsample to match the data
            seg = ImgUtils.resize_first_three_dims(seg, order=0, zoom=0.5)
        seg = DatasetUtils.scale_input_to_unet_shape(seg, HP.DATASET, HP.RESOLUTION)
        np.save(join(C.NETWORK_DRIVE, "HCP_fusion_npy_" + DIFFUSION_FOLDER, s,
                     "bundle_masks.npy"), seg)
        print("Took {}s".format(time.time() - start_time))
def copy_training_files_to_ssd(HP, data_path):
    '''Stage per-subject feature files into a unique temp dir on the SSD.

    :param HP: config object (reads HP.DATASET_FOLDER, HP.FEATURES_FILENAME)
    :param data_path: directory whose immediate subdirectories are subjects
    :return: the freshly created target directory on /ssd/
    '''
    def _random_id(length=6, alphabet=string.ascii_uppercase + string.digits):
        # Uniqueness-ish is enough here; this is not security sensitive.
        return ''.join(random.choice(alphabet) for _ in range(length))

    target_data_path = join("/ssd/", "tmp_" + _random_id(), HP.DATASET_FOLDER)
    ExpUtils.make_dir(join(target_data_path))

    #get all folders in data_path directory
    for folder in glob(data_path + "/*/"):
        subject = os.path.basename(os.path.normpath(folder))
        src = join(data_path, subject, HP.FEATURES_FILENAME)
        target = join(target_data_path, subject, HP.FEATURES_FILENAME)
        print("cp: {} -> {}".format(src, target))
        # shutil.copyfile(src, target)

    return target_data_path
def soft_dice_paul(HP, idxs, marker, preds, ys):
    '''Soft-Dice loss (Theano): 1 minus the mean per-class soft dice.'''
    num_classes = len(ExpUtils.get_bundle_names(HP.CLASSES))
    dice_sum = T.constant(0)
    for class_idx in range(num_classes):
        pred_slice = preds[marker, class_idx, :, :]
        gt_slice = ys[marker, class_idx, :, :]
        overlap = T.sum(pred_slice * gt_slice)
        total = T.sum(pred_slice) + T.sum(gt_slice)
        # epsilon in the denominator avoids division by zero for empty masks
        dice_sum += T.constant(2) * overlap / (total + T.constant(1e-6))
    return 1 - (dice_sum / num_classes)
def calc_peak_length_dice_pytorch(HP, y_pred, y_true, max_angle_error=[0.9], max_length_error=0.1):
    '''
    Peak dice (pytorch) that also checks vector length: a voxel counts as
    correct only if both its direction error and its length error are within
    the given bounds.

    :param y_pred: predicted peaks, (batch, channels, x, y); permuted below
    :param y_true: groundtruth peaks, same layout
    :param max_angle_error: 0.7 -> angle error of 45° or less; 0.9 -> angle error of 23° or less
                            Can be list with several values -> calculate for several thresholds
                            (NOTE(review): only element [0] is used here)
    :param max_length_error: allowed relative length deviation from groundtruth
    :return: dict bundle name -> f1 score (torch scalar)
    '''
    import torch
    from tractseg.libs.PytorchEinsum import einsum
    from tractseg.libs.PytorchUtils import PytorchUtils

    # move channels to the last dimension: (batch, x, y, channels)
    y_true = y_true.permute(0, 2, 3, 1)
    y_pred = y_pred.permute(0, 2, 3, 1)

    def angle_last_dim(a, b):
        '''
        Calculate the angle between two nd-arrays (array of vectors) along the last dimension
        without anything further: 1->0°, 0.9->23°, 0.7->45°, 0->90°
        np.arccos -> returns degree in pi (90°: 0.5*pi)

        return: one dimension less than input
        '''
        return torch.abs(einsum('abcd,abcd->abc', a, b) /
                         (torch.norm(a, 2., -1) * torch.norm(b, 2, -1) + 1e-7))

    #Single threshold
    score_per_bundle = {}
    bundles = ExpUtils.get_bundle_names(HP.CLASSES)[1:]
    for idx, bundle in enumerate(bundles):
        # if bundle == "CST_right":
        # 3 consecutive channels per bundle
        y_pred_bund = y_pred[:, :, :, (idx * 3):(idx * 3) + 3].contiguous()
        y_true_bund = y_true[:, :, :, (idx * 3):(idx * 3) + 3].contiguous()  # [x,y,z,3]

        angles = angle_last_dim(y_pred_bund, y_true_bund)

        # relative length error must be below max_length_error
        lenghts_pred = torch.norm(y_pred_bund, 2., -1)
        lengths_true = torch.norm(y_true_bund, 2, -1)
        lengths_binary = torch.abs(lenghts_pred - lengths_true) < (max_length_error * lengths_true)
        lengths_binary = lengths_binary.view(-1)

        gt_binary = y_true_bund.sum(dim=-1) > 0
        gt_binary = gt_binary.view(-1)  # [bs*x*y]

        angles_binary = angles > max_angle_error[0]
        angles_binary = angles_binary.view(-1)

        # logical AND of the two criteria (elementwise product of bool masks)
        combined = lengths_binary * angles_binary

        f1 = PytorchUtils.f1_score_binary(gt_binary, combined)
        score_per_bundle[bundle] = f1

    return score_per_bundle
def copy_training_files_to_ssd(HP, data_path):
    '''Copy (stage) each subject's feature file to a fresh temp dir on /ssd/.

    Returns the created target directory.
    '''
    def id_generator(size=6, chars=string.ascii_uppercase + string.digits):
        # short random suffix so concurrent runs get distinct temp dirs
        return ''.join(random.choice(chars) for _ in range(size))

    target_data_path = join("/ssd/", "tmp_" + id_generator(), HP.DATASET_FOLDER)
    ExpUtils.make_dir(join(target_data_path))

    # every directory directly below data_path is one subject
    subject_dirs = glob(data_path + "/*/")
    subject_names = [os.path.basename(os.path.normpath(p)) for p in subject_dirs]

    for name in subject_names:
        source_file = join(data_path, name, HP.FEATURES_FILENAME)
        target_file = join(target_data_path, name, HP.FEATURES_FILENAME)
        print("cp: {} -> {}".format(source_file, target_file))
        # shutil.copyfile(source_file, target_file)

    return target_data_path
def calc_peak_length_dice(HP, y_pred, y_true, max_angle_error=[0.9], max_length_error=0.1):
    '''
    Peak dice that also checks vector length per voxel.

    :param y_pred:
    :param y_true:
    :param max_angle_error: 0.7 -> angle error of 45° or less; 0.9 -> angle error of 23° or less
    :return: dict bundle name -> f1 score
    '''
    def _cos_angle(a, b):
        # |cosine| of the angle along the last dim: 1->0°, 0.9->23°, 0.7->45°, 0->90°
        # (np.arccos would return radians; 90°: 0.5*pi)
        return abs(np.einsum('...i,...i', a, b) /
                   (np.linalg.norm(a, axis=-1) * np.linalg.norm(b, axis=-1) + 1e-7))

    scores = {}
    for i, bundle in enumerate(ExpUtils.get_bundle_names(HP.CLASSES)[1:]):
        pred_peaks = y_pred[:, :, :, (i * 3):(i * 3) + 3]
        true_peaks = y_true[:, :, :, (i * 3):(i * 3) + 3]  # [x,y,z,3]

        angle_ok = (_cos_angle(pred_peaks, true_peaks) > max_angle_error[0]).flatten()

        pred_len = np.linalg.norm(pred_peaks, axis=-1)
        true_len = np.linalg.norm(true_peaks, axis=-1)
        length_ok = (abs(pred_len - true_len) < (max_length_error * true_len)).flatten()

        gt_mask = (true_peaks.sum(axis=-1) > 0).flatten()  # [bs*x*y]

        scores[bundle] = MetricUtils.my_f1_score(gt_mask, length_ok * angle_ok)
    return scores
def save_multilabel_img_as_multiple_files_endings_OLD(HP, img, affine, path, multilabel=True):
    '''
    Save bundle start/end region masks, one nifti file per bundle, to <path>/endings/.

    multilabel True: save as 1 and 2 without fourth dimension
    multilabel False: save with beginnings and endings combined

    :param img: shape (x, y, z, nr_of_bundles * 2); channels alternate
                beginning/ending per bundle
    '''
    # bundles = ExpUtils.get_bundle_names("20")[1:]
    bundles = ExpUtils.get_bundle_names(HP.CLASSES)[1:]
    # Create the output directory once, instead of once per bundle as before.
    ExpUtils.make_dir(join(path, "endings"))
    for idx, bundle in enumerate(bundles):
        data = img[:, :, :, (idx * 2):(idx * 2) + 2] > 0
        multilabel_img = np.zeros(data.shape[:3])

        if multilabel:
            multilabel_img[data[:, :, :, 0]] = 1  # beginning region
            multilabel_img[data[:, :, :, 1]] = 2  # ending region
        else:
            multilabel_img[data[:, :, :, 0]] = 1
            multilabel_img[data[:, :, :, 1]] = 1

        img_seg = nib.Nifti1Image(multilabel_img, affine)
        nib.save(img_seg, join(path, "endings", bundle + ".nii.gz"))
def theano_f1_score_OLD(HP, idxs, marker, preds, ys):
    '''
    From Paul: mean hard dice over all classes (predictions binarized at 0.5).
    '''
    num_classes = len(ExpUtils.get_bundle_names(HP.CLASSES))
    total = T.constant(0)
    for c in range(num_classes):
        prediction = T.gt(preds[marker, c, :, :], T.constant(0.5))  # binarize
        target = ys[marker, c, :, :]
        numerator = T.constant(2) * T.sum(prediction * target)
        # epsilon prevents division by zero for empty masks
        denominator = T.sum(prediction) + T.sum(target) + T.constant(1e-6)
        total += numerator / denominator
    return total / num_classes
def calc_peak_dice(HP, y_pred, y_true, max_angle_error=[0.9]):
    '''
    Dice over the voxel masks where predicted and groundtruth peak directions agree.

    :param y_pred: predicted peaks (x, y, z, nr_of_bundles * 3)
    :param y_true: groundtruth peaks, same layout
    :param max_angle_error: 0.7 -> angle error of 45° or less; 0.9 -> angle error of 23° or less
                            (only element [0] is used)
    :return: dict bundle name -> f1 score
    '''
    # NOTE: removed an unused single-vector `angle(a, b)` helper (dead code).
    def angle_last_dim(a, b):
        '''
        Calculate the angle between two nd-arrays (array of vectors) along the last dimension
        without anything further: 1->0°, 0.9->23°, 0.7->45°, 0->90°
        np.arccos -> returns degree in pi (90°: 0.5*pi)

        return: one dimension less than input
        '''
        return abs(np.einsum('...i,...i', a, b) /
                   (np.linalg.norm(a, axis=-1) * np.linalg.norm(b, axis=-1) + 1e-7))

    score_per_bundle = {}
    bundles = ExpUtils.get_bundle_names(HP.CLASSES)[1:]  # skip background
    for idx, bundle in enumerate(bundles):
        y_pred_bund = y_pred[:, :, :, (idx * 3):(idx * 3) + 3]
        y_true_bund = y_true[:, :, :, (idx * 3):(idx * 3) + 3]  # [x,y,z,3]

        angles = angle_last_dim(y_pred_bund, y_true_bund)
        angles_binary = angles > max_angle_error[0]  # True where direction is close enough
        gt_binary = y_true_bund.sum(axis=-1) > 0  # voxels inside the groundtruth tract

        f1 = f1_score(gt_binary.flatten(), angles_binary.flatten(), average="binary")
        score_per_bundle[bundle] = f1

    return score_per_bundle
def calc_peak_dice(HP, y_pred, y_true, max_angle_error=[0.9]):
    '''
    Per-bundle dice between the groundtruth tract mask and the mask of voxels
    whose predicted peak direction is within the allowed angle error.

    :param y_pred: predicted peaks (x, y, z, nr_of_bundles * 3)
    :param y_true: groundtruth peaks, same layout
    :param max_angle_error: 0.7 -> angle error of 45° or less; 0.9 -> angle error of 23° or less
                            (only element [0] is used)
    :return: dict bundle name -> f1 score
    '''
    # NOTE: the unused single-vector `angle(a, b)` helper was removed (dead code).
    def angle_last_dim(a, b):
        '''
        Calculate the angle between two nd-arrays (array of vectors) along the last dimension
        without anything further: 1->0°, 0.9->23°, 0.7->45°, 0->90°
        np.arccos -> returns degree in pi (90°: 0.5*pi)

        return: one dimension less than input
        '''
        return abs(np.einsum('...i,...i', a, b) /
                   (np.linalg.norm(a, axis=-1) * np.linalg.norm(b, axis=-1) + 1e-7))

    score_per_bundle = {}
    bundles = ExpUtils.get_bundle_names(HP.CLASSES)[1:]  # skip background
    for idx, bundle in enumerate(bundles):
        y_pred_bund = y_pred[:, :, :, (idx * 3):(idx * 3) + 3]
        y_true_bund = y_true[:, :, :, (idx * 3):(idx * 3) + 3]  # [x,y,z,3]

        angles = angle_last_dim(y_pred_bund, y_true_bund)
        angles_binary = angles > max_angle_error[0]
        gt_binary = y_true_bund.sum(axis=-1) > 0

        f1 = f1_score(gt_binary.flatten(), angles_binary.flatten(), average="binary")
        score_per_bundle[bundle] = f1

    return score_per_bundle
def calc_peak_length_dice(HP, y_pred, y_true, max_angle_error=[0.9], max_length_error=0.1):
    '''
    Peak dice taking both direction and vector length into account.

    :param y_pred:
    :param y_true:
    :param max_angle_error: 0.7 -> angle error of 45° or less; 0.9 -> angle error of 23° or less
    :return: dict bundle name -> f1 score
    '''
    def angle_last_dim(a, b):
        # |cos| of the angle along the last axis: 1->0°, 0.9->23°, 0.7->45°, 0->90°
        # (np.arccos would give radians; 90°: 0.5*pi); returns one dim less than input
        return abs(np.einsum('...i,...i', a, b) /
                   (np.linalg.norm(a, axis=-1) * np.linalg.norm(b, axis=-1) + 1e-7))

    score_per_bundle = {}
    for idx, bundle in enumerate(ExpUtils.get_bundle_names(HP.CLASSES)[1:]):
        start = idx * 3
        pred_vecs = y_pred[:, :, :, start:start + 3]
        true_vecs = y_true[:, :, :, start:start + 3]  # [x,y,z,3]

        good_angle = (angle_last_dim(pred_vecs, true_vecs) > max_angle_error[0]).flatten()

        norm_pred = np.linalg.norm(pred_vecs, axis=-1)
        norm_true = np.linalg.norm(true_vecs, axis=-1)
        good_length = (abs(norm_pred - norm_true) < (max_length_error * norm_true)).flatten()

        inside_gt = (true_vecs.sum(axis=-1) > 0).flatten()  # [bs*x*y]

        score_per_bundle[bundle] = MetricUtils.my_f1_score(inside_gt, good_length * good_angle)
    return score_per_bundle
def create_multilabel_mask(HP, subject, labels_type=np.int16, dataset_folder="HCP", labels_folder="bundle_masks"):
    '''
    One-hot encoding of all bundles in one big image

    :param subject:
    :return: image of shape (x, y, z, nr_of_bundles + 1)
    '''
    all_bundles = ExpUtils.get_bundle_names(HP.CLASSES)
    # Masks are always in HCP high resolution (downsampling happens later).
    result = np.zeros((145, 174, 145, len(all_bundles)))
    bg = np.ones((145, 174, 145))  # everything that contains no bundle

    # The first entry is the background class -> represented by the bg array.
    for pos, name in enumerate(all_bundles[1:]):
        nii = nib.load(join(C.HOME, dataset_folder, subject, labels_folder,
                            name + ".nii.gz"))
        voxels = nii.get_data()  # dtype: uint8
        result[:, :, :, pos + 1] = voxels
        bg[voxels == 1] = 0  # remove this bundle from background

    result[:, :, :, 0] = bg
    return result.astype(labels_type)
def precompute_batches(custom_type=None):
    '''
    Precompute batches for the given split(s) and save them as nifti files.

    9000 slices per epoch -> 200 batches (batchsize=44) per epoch
    => 200-1000 batches needed

    270g_125mm_bundle_peaks_Y: no DAug, no Norm, only Y
    All_sizes_DAug_XYZ: 12g, 90g, 270g, DAug (no rotation, no elastic deform), Norm, XYZ
    270g_125mm_bundle_peaks_XYZ: no DAug, Norm, XYZ

    :param custom_type: "train"/"validate"/"test" to process only one split;
                        None processes all three
    '''
    class HP:
        # Minimal inline config consumed by the batch generator below.
        NORMALIZE_DATA = True
        DATA_AUGMENTATION = False
        CV_FOLD = 0
        INPUT_DIM = (144, 144)
        BATCH_SIZE = 44
        DATASET_FOLDER = "HCP"
        TYPE = "single_direction"
        EXP_PATH = "~"
        LABELS_FILENAME = "bundle_peaks"
        FEATURES_FILENAME = "270g_125mm_peaks"
        DATASET = "HCP"
        RESOLUTION = "1.25mm"
        LABELS_TYPE = np.float32

    # Subject split for the selected cross-validation fold.
    HP.TRAIN_SUBJECTS, HP.VALIDATE_SUBJECTS, HP.TEST_SUBJECTS = ExpUtils.get_cv_fold(
        HP.CV_FOLD)

    num_batches_base = 5000
    # validation/test use a third of the training batch count
    num_batches = {
        "train": num_batches_base,
        "validate": int(num_batches_base / 3.),
        "test": int(num_batches_base / 3.),
    }

    if custom_type is None:
        types = ["train", "validate", "test"]
    else:
        types = [custom_type]

    for type in types:
        dataManager = DataManagerTrainingNiftiImgs(HP)
        # split-specific subject list is looked up dynamically (HP.TRAIN_SUBJECTS etc.)
        batch_gen = dataManager.get_batches(
            batch_size=HP.BATCH_SIZE,
            type=type,
            subjects=getattr(HP, type.upper() + "_SUBJECTS"),
            num_batches=num_batches[type])
        for idx, batch in enumerate(batch_gen):
            print("Processing: {}".format(idx))

            # DATASET_DIR = "HCP_batches/270g_125mm_bundle_peaks_Y"
            # DATASET_DIR = "HCP_batches/All_sizes_DAug_XYZ"
            DATASET_DIR = "HCP_batches/270g_125mm_bundle_peaks_XYZ"
            ExpUtils.make_dir(join(C.HOME, DATASET_DIR, type))
            # Each batch's data and seg arrays are saved as separate niftis.
            data = nib.Nifti1Image(
                batch["data"], ImgUtils.get_dwi_affine(HP.DATASET, HP.RESOLUTION))
            nib.save(
                data,
                join(C.HOME, DATASET_DIR, type,
                     "batch_" + str(idx) + "_data.nii.gz"))
            seg = nib.Nifti1Image(
                batch["seg"], ImgUtils.get_dwi_affine(HP.DATASET, HP.RESOLUTION))
            nib.save(
                seg,
                join(C.HOME, DATASET_DIR, type,
                     "batch_" + str(idx) + "_seg.nii.gz"))
def calc_peak_dice_pytorch(HP, y_pred, y_true, max_angle_error=[0.9]):
    '''
    Calculate angle between groundtruth and prediction and keep the voxels where
    angle is smaller than MAX_ANGLE_ERROR.

    From groundtruth generate a binary mask by selecting all voxels with len > 0.

    Calculate Dice from these 2 masks.

    -> Penalty on peaks outside of tract or if predicted peak=0
    -> no penalty on very very small with right direction -> bad
    => Peak_dice can be high even if peaks inside of tract almost missing (almost 0)

    :param y_pred:
    :param y_true:
    :param max_angle_error: 0.7 -> angle error of 45° or less; 0.9 -> angle error of 23° or less
                            Can be list with several values -> calculate for several thresholds
    :return: dict bundle name -> f1 score (or list of f1 scores, one per threshold)
    '''
    import torch
    from tractseg.libs.PytorchEinsum import einsum
    from tractseg.libs.PytorchUtils import PytorchUtils

    # move channels to the last dimension: (batch, x, y, channels)
    y_true = y_true.permute(0, 2, 3, 1)
    y_pred = y_pred.permute(0, 2, 3, 1)

    def angle_last_dim(a, b):
        '''
        Calculate the angle between two nd-arrays (array of vectors) along the last dimension
        without anything further: 1->0°, 0.9->23°, 0.7->45°, 0->90°
        np.arccos -> returns degree in pi (90°: 0.5*pi)

        return: one dimension less than input
        '''
        return torch.abs(einsum('abcd,abcd->abc', a, b) /
                         (torch.norm(a, 2., -1) * torch.norm(b, 2, -1) + 1e-7))

    #Single threshold
    if len(max_angle_error) == 1:
        score_per_bundle = {}
        bundles = ExpUtils.get_bundle_names(HP.CLASSES)[1:]
        for idx, bundle in enumerate(bundles):
            # if bundle == "CST_right":
            # 3 consecutive channels per bundle
            y_pred_bund = y_pred[:, :, :, (idx * 3):(idx * 3) + 3].contiguous()
            y_true_bund = y_true[:, :, :, (idx * 3):(idx * 3) + 3].contiguous()  # [x,y,z,3]

            angles = angle_last_dim(y_pred_bund, y_true_bund)
            gt_binary = y_true_bund.sum(dim=-1) > 0
            gt_binary = gt_binary.view(-1)  # [bs*x*y]

            angles_binary = angles > max_angle_error[0]
            angles_binary = angles_binary.view(-1)

            f1 = PytorchUtils.f1_score_binary(gt_binary, angles_binary)
            score_per_bundle[bundle] = f1

        return score_per_bundle

    #multiple thresholds
    else:
        score_per_bundle = {}
        bundles = ExpUtils.get_bundle_names(HP.CLASSES)[1:]
        for idx, bundle in enumerate(bundles):
            # if bundle == "CST_right":
            y_pred_bund = y_pred[:, :, :, (idx * 3):(idx * 3) + 3].contiguous()
            y_true_bund = y_true[:, :, :, (idx * 3):(idx * 3) + 3].contiguous()  # [x,y,z,3]

            angles = angle_last_dim(y_pred_bund, y_true_bund)
            gt_binary = y_true_bund.sum(dim=-1) > 0
            gt_binary = gt_binary.view(-1)  # [bs*x*y]

            # one f1 score per requested angle threshold
            score_per_bundle[bundle] = []
            for threshold in max_angle_error:
                angles_binary = angles > threshold
                angles_binary = angles_binary.view(-1)

                f1 = PytorchUtils.f1_score_binary(gt_binary, angles_binary)
                score_per_bundle[bundle].append(f1)

        return score_per_bundle
def generate_train_batch(self):
    """Sample one training batch of 2D slices from a randomly chosen subject.

    Loads the subject's feature and label volumes (with retry on IOError),
    rescales them to the UNet input shape, picks BATCH_SIZE random slices
    along a (possibly random) axis, and returns
    {"data": x, "seg": y} with shape (batch_size, channels, x, y).
    """
    subjects = self._data[0]
    subject_idx = int(random.uniform(0, len(subjects)))  # len(subjects)-1 not needed because int always rounds to floor

    # Retry loading up to 20 times; network storage may fail transiently.
    for i in range(20):
        try:
            if self.HP.FEATURES_FILENAME == "12g90g270g":
                # if np.random.random() < 0.5:
                #     data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "270g_125mm_peaks.nii.gz")).get_data()
                # else:
                #     data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "90g_125mm_peaks.nii.gz")).get_data()

                # randomly pick one of the three gradient schemes (~1/3 each)
                rnd_choice = np.random.random()
                if rnd_choice < 0.33:
                    data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "270g_125mm_peaks.nii.gz")).get_data()
                elif rnd_choice < 0.66:
                    data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "90g_125mm_peaks.nii.gz")).get_data()
                else:
                    data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "12g_125mm_peaks.nii.gz")).get_data()
            elif self.HP.FEATURES_FILENAME == "T1_Peaks270g":
                # concatenate peaks and T1 along the channel axis
                peaks = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "270g_125mm_peaks.nii.gz")).get_data()
                t1 = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "T1.nii.gz")).get_data()
                data = np.concatenate((peaks, t1), axis=3)
            elif self.HP.FEATURES_FILENAME == "T1_Peaks12g90g270g":
                # random gradient scheme + T1, concatenated along channels
                rnd_choice = np.random.random()
                if rnd_choice < 0.33:
                    peaks = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "270g_125mm_peaks.nii.gz")).get_data()
                elif rnd_choice < 0.66:
                    peaks = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "90g_125mm_peaks.nii.gz")).get_data()
                else:
                    peaks = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "12g_125mm_peaks.nii.gz")).get_data()
                t1 = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "T1.nii.gz")).get_data()
                data = np.concatenate((peaks, t1), axis=3)
            else:
                data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], self.HP.FEATURES_FILENAME + ".nii.gz")).get_data()

            seg = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], self.HP.LABELS_FILENAME + ".nii.gz")).get_data()
            break
        except IOError:
            ExpUtils.print_and_save(self.HP, "\n\nWARNING: Could not load file. Trying again in 20s (Try number: " + str(i) + ").\n\n")
            ExpUtils.print_and_save(self.HP, "Sleeping 20s")
            sleep(20)
    # ExpUtils.print_and_save(self.HP, "Successfully loaded input.")

    data = np.nan_to_num(data)  # Needed otherwise not working
    seg = np.nan_to_num(seg)

    data = DatasetUtils.scale_input_to_unet_shape(data, self.HP.DATASET, self.HP.RESOLUTION)  # (x, y, z, channels)

    # 808080-sized label files are already at the UNet shape -> no rescaling
    if self.HP.LABELS_FILENAME not in ["bundle_peaks_11_808080", "bundle_peaks_20_808080", "bundle_peaks_808080", "bundle_masks_20_808080", "bundle_masks_72_808080"]:
        if self.HP.DATASET in ["HCP_2mm", "HCP_2.5mm", "HCP_32g"]:
            # By using "HCP" but lower resolution scale_input_to_unet_shape will automatically downsample the HCP sized seg_mask to the lower resolution
            seg = DatasetUtils.scale_input_to_unet_shape(seg, "HCP", self.HP.RESOLUTION)
        else:
            seg = DatasetUtils.scale_input_to_unet_shape(seg, self.HP.DATASET, self.HP.RESOLUTION)  # (x, y, z, classes)

    # BATCH_SIZE distinct slice indices along the chosen axis
    slice_idxs = np.random.choice(data.shape[0], self.BATCH_SIZE, False, None)

    # Randomly sample slice orientation
    if self.HP.TRAINING_SLICE_DIRECTION == "xyz":
        slice_direction = int(round(random.uniform(0, 2)))
    else:
        slice_direction = 1  #always use Y

    if slice_direction == 0:
        x = data[slice_idxs, :, :].astype(np.float32)  # (batch_size, y, z, channels)
        y = seg[slice_idxs, :, :].astype(self.HP.LABELS_TYPE)
        x = np.array(x).transpose(0, 3, 1, 2)  # depth-channel has to be before width and height for Unet (but after batches)
        y = np.array(y).transpose(0, 3, 1, 2)  # nr_classes channel has to be before width and height for DataAugmentation (bs, nr_of_classes, x, y)
    elif slice_direction == 1:
        x = data[:, slice_idxs, :].astype(np.float32)  # (x, batch_size, z, channels)
        y = seg[:, slice_idxs, :].astype(self.HP.LABELS_TYPE)
        x = np.array(x).transpose(1, 3, 0, 2)
        y = np.array(y).transpose(1, 3, 0, 2)
    elif slice_direction == 2:
        x = data[:, :, slice_idxs].astype(np.float32)  # (x, y, batch_size, channels)
        y = seg[:, :, slice_idxs].astype(self.HP.LABELS_TYPE)
        x = np.array(x).transpose(2, 3, 0, 1)
        y = np.array(y).transpose(2, 3, 0, 1)

    data_dict = {"data": x,  # (batch_size, channels, x, y, [z])
                 "seg": y}  # (batch_size, channels, x, y, [z])
    return data_dict
def generate_train_batch(self):
    """Sample one training batch from a random subject.

    Picks a random subject, loads its feature image (peaks and/or T1,
    depending on HP.FEATURES_FILENAME) and label image from disk, scales
    both to U-Net input shape and cuts out self.BATCH_SIZE random 2D
    slices along one (possibly randomly chosen) axis.

    Returns:
        dict with "data" and "seg", each (batch_size, channels, x, y).
    """
    subjects = self._data[0]
    subject_idx = int(random.uniform(0, len(subjects)))  # len(subjects)-1 not needed because int always rounds to floor

    # Loading from a network drive can fail transiently -> retry up to 20 times.
    # NOTE(review): if all 20 tries fail, `data`/`seg` stay unbound and the code
    # below raises NameError instead of a clean error — confirm this is acceptable.
    for i in range(20):
        try:
            if self.HP.FEATURES_FILENAME == "12g90g270g":
                # Randomly pick one of three gradient schemes (roughly equal probability)
                rnd_choice = np.random.random()
                if rnd_choice < 0.33:
                    data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx],
                                         "270g_125mm_peaks.nii.gz")).get_data()
                elif rnd_choice < 0.66:
                    data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx],
                                         "90g_125mm_peaks.nii.gz")).get_data()
                else:
                    data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx],
                                         "12g_125mm_peaks.nii.gz")).get_data()
            elif self.HP.FEATURES_FILENAME == "T1_Peaks270g":
                # Peaks + T1 concatenated along the channel axis
                peaks = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx],
                                      "270g_125mm_peaks.nii.gz")).get_data()
                t1 = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx],
                                   "T1.nii.gz")).get_data()
                data = np.concatenate((peaks, t1), axis=3)
            elif self.HP.FEATURES_FILENAME == "T1_Peaks12g90g270g":
                # Random gradient scheme for the peaks, plus T1
                rnd_choice = np.random.random()
                if rnd_choice < 0.33:
                    peaks = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx],
                                          "270g_125mm_peaks.nii.gz")).get_data()
                elif rnd_choice < 0.66:
                    peaks = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx],
                                          "90g_125mm_peaks.nii.gz")).get_data()
                else:
                    peaks = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx],
                                          "12g_125mm_peaks.nii.gz")).get_data()
                t1 = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx],
                                   "T1.nii.gz")).get_data()
                data = np.concatenate((peaks, t1), axis=3)
            else:
                data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx],
                                     self.HP.FEATURES_FILENAME + ".nii.gz")).get_data()
            seg = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx],
                                self.HP.LABELS_FILENAME + ".nii.gz")).get_data()
            break
        except IOError:
            ExpUtils.print_and_save(self.HP, "\n\nWARNING: Could not load file. Trying again in 20s (Try number: " + str(i) + ").\n\n")
            ExpUtils.print_and_save(self.HP, "Sleeping 20s")
            sleep(20)
    # ExpUtils.print_and_save(self.HP, "Successfully loaded input.")

    data = np.nan_to_num(data)  # Needed otherwise not working
    seg = np.nan_to_num(seg)

    data = DatasetUtils.scale_input_to_unet_shape(data, self.HP.DATASET, self.HP.RESOLUTION)  # (x, y, z, channels)

    # The *_808080 label files are already in U-Net shape and need no rescaling
    if self.HP.LABELS_FILENAME not in ["bundle_peaks_11_808080", "bundle_peaks_20_808080", "bundle_peaks_808080",
                                       "bundle_masks_20_808080", "bundle_masks_72_808080"]:
        if self.HP.DATASET in ["HCP_2mm", "HCP_2.5mm", "HCP_32g"]:
            # By using "HCP" but lower resolution scale_input_to_unet_shape will automatically
            # downsample the HCP sized seg_mask to the lower resolution
            seg = DatasetUtils.scale_input_to_unet_shape(seg, "HCP", self.HP.RESOLUTION)
        else:
            seg = DatasetUtils.scale_input_to_unet_shape(seg, self.HP.DATASET, self.HP.RESOLUTION)  # (x, y, z, classes)

    # Sample BATCH_SIZE distinct slice indices (without replacement)
    slice_idxs = np.random.choice(data.shape[0], self.BATCH_SIZE, False, None)

    # Randomly sample slice orientation
    if self.HP.TRAINING_SLICE_DIRECTION == "xyz":
        slice_direction = int(round(random.uniform(0, 2)))
    else:
        slice_direction = 1  # always use Y

    if slice_direction == 0:
        x = data[slice_idxs, :, :].astype(np.float32)  # (batch_size, y, z, channels)
        y = seg[slice_idxs, :, :].astype(self.HP.LABELS_TYPE)
        # depth-channel has to be before width and height for Unet (but after batches)
        x = np.array(x).transpose(0, 3, 1, 2)
        # nr_classes channel has to be before width and height for DataAugmentation (bs, nr_of_classes, x, y)
        y = np.array(y).transpose(0, 3, 1, 2)
    elif slice_direction == 1:
        x = data[:, slice_idxs, :].astype(np.float32)  # (x, batch_size, z, channels)
        y = seg[:, slice_idxs, :].astype(self.HP.LABELS_TYPE)
        x = np.array(x).transpose(1, 3, 0, 2)
        y = np.array(y).transpose(1, 3, 0, 2)
    elif slice_direction == 2:
        x = data[:, :, slice_idxs].astype(np.float32)  # (x, y, batch_size, channels)
        y = seg[:, :, slice_idxs].astype(self.HP.LABELS_TYPE)
        x = np.array(x).transpose(2, 3, 0, 1)
        y = np.array(y).transpose(2, 3, 0, 1)

    data_dict = {"data": x,  # (batch_size, channels, x, y, [z])
                 "seg": y}   # (batch_size, channels, x, y, [z])
    return data_dict
def save_multilabel_img_as_multiple_files_endings(HP, img, affine, path):
    """Save each channel of a multi-label endings segmentation as its own nifti file.

    Args:
        HP: config object; HP.CLASSES selects which bundle name set is used.
        img: 4D array (x, y, z, nr_of_bundles); one endings mask per channel.
        affine: affine matrix written into every output nifti header.
        path: output directory; files go to <path>/endings_segmentations/<bundle>.nii.gz
    """
    out_dir = join(path, "endings_segmentations")
    # Fix: the original called make_dir once per bundle inside the loop;
    # the directory only needs to be created once.
    ExpUtils.make_dir(out_dir)
    bundles = ExpUtils.get_bundle_names(HP.CLASSES)[1:]  # [0] is background -> skip it
    for idx, bundle in enumerate(bundles):
        img_seg = nib.Nifti1Image(img[:, :, :, idx], affine)
        nib.save(img_seg, join(out_dir, bundle + ".nii.gz"))
def create_network(self):
    """Build the U-Net, loss and optimizer, and bind train/test/predict/save/load
    closures onto self. The closures share `net`, `criterion` and `optimizer`
    from this enclosing scope.
    """
    def train(X, y):
        # One optimization step on a single batch; returns (loss, probs, f1).
        X = torch.from_numpy(X.astype(np.float32))
        y = torch.from_numpy(y.astype(np.float32))
        if torch.cuda.is_available():
            X, y = Variable(X.cuda()), Variable(y.cuda())  # X: (bs, features, x, y)   y: (bs, classes, x, y)
        else:
            X, y = Variable(X), Variable(y)
        optimizer.zero_grad()
        net.train()
        outputs = net(X)  # forward
        # outputs: (bs, classes, x, y)
        loss = criterion(outputs, y)
        loss.backward()  # backward
        optimizer.step()  # optimise
        f1 = PytorchUtils.f1_score_macro(y.data, outputs.data, per_class=True)
        # probs = outputs.data.cpu().numpy().transpose(0,2,3,1)  # (bs, x, y, classes)
        probs = None  # faster
        return loss.data[0], probs, f1

    def test(X, y):
        # Evaluation pass without gradient tracking (old-PyTorch `volatile` API).
        X = torch.from_numpy(X.astype(np.float32))
        y = torch.from_numpy(y.astype(np.float32))
        if torch.cuda.is_available():
            X, y = Variable(X.cuda(), volatile=True), Variable(y.cuda(), volatile=True)
        else:
            X, y = Variable(X, volatile=True), Variable(y, volatile=True)
        net.train(False)
        outputs = net(X)  # forward
        loss = criterion(outputs, y)
        f1 = PytorchUtils.f1_score_macro(y.data, outputs.data, per_class=True)
        # probs = outputs.data.cpu().numpy().transpose(0,2,3,1)  # (bs, x, y, classes)
        probs = None  # faster
        return loss.data[0], probs, f1

    def predict(X):
        # Pure forward pass; returns per-voxel probabilities (bs, x, y, classes).
        X = torch.from_numpy(X.astype(np.float32))
        if torch.cuda.is_available():
            X = Variable(X.cuda(), volatile=True)
        else:
            X = Variable(X, volatile=True)
        net.train(False)
        outputs = net(X)  # forward
        probs = outputs.data.cpu().numpy().transpose(0, 2, 3, 1)  # (bs, x, y, classes)
        return probs

    def save_model(metrics, epoch_nr):
        # Save weights only when this epoch is the best validation f1 so far.
        max_f1_idx = np.argmax(metrics["f1_macro_validate"])
        max_f1 = np.max(metrics["f1_macro_validate"])
        if epoch_nr == max_f1_idx and max_f1 > 0.01:
            # saving to network drives takes 5s (to local only 0.5s) -> do not save so often
            print(" Saving weights...")
            for fl in glob.glob(join(self.HP.EXP_PATH, "best_weights_ep*")):  # remove weights from previous epochs
                os.remove(fl)
            try:
                # Actually is a pkl not a npz
                PytorchUtils.save_checkpoint(join(self.HP.EXP_PATH, "best_weights_ep" + str(epoch_nr) + ".npz"), unet=net)
            except IOError:
                print("\nERROR: Could not save weights because of IO Error\n")
            self.HP.BEST_EPOCH = epoch_nr

    def load_model(path):
        PytorchUtils.load_checkpoint(path, unet=net)

    # Number of input channels depends on input representation
    if self.HP.SEG_INPUT == "Peaks" and self.HP.TYPE == "single_direction":
        NR_OF_GRADIENTS = 9
    elif self.HP.SEG_INPUT == "Peaks" and self.HP.TYPE == "combined":
        NR_OF_GRADIENTS = 3*self.HP.NR_OF_CLASSES
    else:
        NR_OF_GRADIENTS = 33

    if torch.cuda.is_available():
        net = UNet(n_input_channels=NR_OF_GRADIENTS, n_classes=self.HP.NR_OF_CLASSES, n_filt=self.HP.UNET_NR_FILT).cuda()
        # net = UNet_Skip(n_input_channels=NR_OF_GRADIENTS, n_classes=self.HP.NR_OF_CLASSES, n_filt=self.HP.UNET_NR_FILT).cuda()
    else:
        net = UNet(n_input_channels=NR_OF_GRADIENTS, n_classes=self.HP.NR_OF_CLASSES, n_filt=self.HP.UNET_NR_FILT)
        # net = UNet_Skip(n_input_channels=NR_OF_GRADIENTS, n_classes=self.HP.NR_OF_CLASSES, n_filt=self.HP.UNET_NR_FILT)

    if self.HP.TRAIN:
        ExpUtils.print_and_save(self.HP, str(net), only_log=True)

    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adamax(net.parameters(), lr=self.HP.LEARNING_RATE)

    if self.HP.LOAD_WEIGHTS:
        ExpUtils.print_verbose(self.HP, "Loading weights ... ({})".format(join(self.HP.EXP_PATH, self.HP.WEIGHTS_PATH)))
        load_model(join(self.HP.EXP_PATH, self.HP.WEIGHTS_PATH))

    self.train = train
    self.predict = test  # NOTE: "predict" is bound to the loss-returning test closure
    self.get_probs = predict  # the pure forward pass is exposed as get_probs
    self.save_model = save_model
    self.load_model = load_model
def create_network(self):
    """Build the U-Net with a configurable loss/optimizer and bind
    train/test/predict/save/load/print_current_lr closures onto self.
    The closures share `net`, `criterion` and `optimizer` from this scope.
    """
    # torch.backends.cudnn.benchmark = True  # not faster

    def train(X, y, weight_factor=10):
        # One optimization step; `weight_factor` is accepted for interface
        # compatibility with the weighted variants but is unused here.
        X = torch.from_numpy(X.astype(np.float32))
        y = torch.from_numpy(y.astype(np.float32))
        if torch.cuda.is_available():
            X, y = Variable(X.cuda()), Variable(y.cuda())  # X: (bs, features, x, y)   y: (bs, classes, x, y)
        else:
            X, y = Variable(X), Variable(y)
        optimizer.zero_grad()
        net.train()
        outputs = net(X)  # forward
        # outputs: (bs, classes, x, y)
        loss = criterion(outputs, y)
        loss.backward()  # backward
        optimizer.step()  # optimise
        f1 = PytorchUtils.f1_score_macro(y.data, outputs.data, per_class=True)
        if self.HP.USE_VISLOGGER:
            probs = outputs.data.cpu().numpy().transpose(0, 2, 3, 1)  # (bs, x, y, classes)
        else:
            probs = None  # faster
        return loss.data[0], probs, f1

    def test(X, y, weight_factor=10):
        # Evaluation pass without gradient tracking (old-PyTorch `volatile` API).
        X = torch.from_numpy(X.astype(np.float32))
        y = torch.from_numpy(y.astype(np.float32))
        if torch.cuda.is_available():
            X, y = Variable(X.cuda(), volatile=True), Variable(y.cuda(), volatile=True)
        else:
            X, y = Variable(X, volatile=True), Variable(y, volatile=True)
        net.train(False)
        outputs = net(X)  # forward
        loss = criterion(outputs, y)
        f1 = PytorchUtils.f1_score_macro(y.data, outputs.data, per_class=True)
        # probs = outputs.data.cpu().numpy().transpose(0,2,3,1)  # (bs, x, y, classes)
        probs = None  # faster
        return loss.data[0], probs, f1

    def predict(X):
        # Pure forward pass; returns per-voxel probabilities (bs, x, y, classes).
        X = torch.from_numpy(X.astype(np.float32))
        if torch.cuda.is_available():
            X = Variable(X.cuda(), volatile=True)
        else:
            X = Variable(X, volatile=True)
        net.train(False)
        outputs = net(X)  # forward
        probs = outputs.data.cpu().numpy().transpose(0, 2, 3, 1)  # (bs, x, y, classes)
        return probs

    def save_model(metrics, epoch_nr):
        # Save weights only when this epoch is the best validation f1 so far.
        max_f1_idx = np.argmax(metrics["f1_macro_validate"])
        max_f1 = np.max(metrics["f1_macro_validate"])
        if epoch_nr == max_f1_idx and max_f1 > 0.01:
            # saving to network drives takes 5s (to local only 0.5s) -> do not save so often
            print(" Saving weights...")
            for fl in glob.glob(join(self.HP.EXP_PATH, "best_weights_ep*")):  # remove weights from previous epochs
                os.remove(fl)
            try:
                # Actually is a pkl not a npz
                PytorchUtils.save_checkpoint(join(self.HP.EXP_PATH, "best_weights_ep" + str(epoch_nr) + ".npz"), unet=net)
            except IOError:
                print("\nERROR: Could not save weights because of IO Error\n")
            self.HP.BEST_EPOCH = epoch_nr

    def load_model(path):
        PytorchUtils.load_checkpoint(path, unet=net)

    def print_current_lr():
        for param_group in optimizer.param_groups:
            ExpUtils.print_and_save(self.HP, "current learning rate: {}".format(param_group['lr']))

    if self.HP.SEG_INPUT == "Peaks" and self.HP.TYPE == "single_direction":
        NR_OF_GRADIENTS = self.HP.NR_OF_GRADIENTS
    elif self.HP.SEG_INPUT == "Peaks" and self.HP.TYPE == "combined":
        # BUGFIX: the original assigned only self.HP.NR_OF_GRADIENTS in this branch
        # (and in the else branch), leaving the local NR_OF_GRADIENTS unbound and
        # crashing with NameError at UNet(...) below. Keep the HP side effect and
        # also bind the local used for network construction.
        self.HP.NR_OF_GRADIENTS = 3 * self.HP.NR_OF_CLASSES
        NR_OF_GRADIENTS = self.HP.NR_OF_GRADIENTS
    else:
        self.HP.NR_OF_GRADIENTS = 33
        NR_OF_GRADIENTS = self.HP.NR_OF_GRADIENTS

    # Dice-style losses need a sigmoid inside the network; BCEWithLogits does not.
    if self.HP.LOSS_FUNCTION == "soft_sample_dice":
        criterion = PytorchUtils.soft_sample_dice
        final_activation = "sigmoid"
    elif self.HP.LOSS_FUNCTION == "soft_batch_dice":
        criterion = PytorchUtils.soft_batch_dice
        final_activation = "sigmoid"
    else:
        criterion = nn.BCEWithLogitsLoss()
        final_activation = None

    net = UNet(n_input_channels=NR_OF_GRADIENTS, n_classes=self.HP.NR_OF_CLASSES,
               n_filt=self.HP.UNET_NR_FILT, batchnorm=self.HP.BATCH_NORM,
               final_activation=final_activation)
    if torch.cuda.is_available():
        net = net.cuda()

    if self.HP.OPTIMIZER == "Adamax":
        optimizer = Adamax(net.parameters(), lr=self.HP.LEARNING_RATE)
    elif self.HP.OPTIMIZER == "Adam":
        optimizer = Adam(net.parameters(), lr=self.HP.LEARNING_RATE, weight_decay=self.HP.WEIGHT_DECAY)
    else:
        raise ValueError("Optimizer not defined")

    if self.HP.LOAD_WEIGHTS:
        ExpUtils.print_verbose(self.HP, "Loading weights ... ({})".format(join(self.HP.EXP_PATH, self.HP.WEIGHTS_PATH)))
        load_model(join(self.HP.EXP_PATH, self.HP.WEIGHTS_PATH))

    self.train = train
    self.predict = test  # NOTE: "predict" is bound to the loss-returning test closure
    self.get_probs = predict  # the pure forward pass is exposed as get_probs
    self.save_model = save_model
    self.load_model = load_model
    self.print_current_lr = print_current_lr
def plot_tracts(HP, bundle_segmentations, affine, out_dir, brain_mask=None):
    '''
    Render a preview image of three representative bundles to <out_dir>/preview.png.

    By default this does not work on a remote server connection (ssh -X) because -X
    does not support OpenGL. On the remote Server you can do 'export DISPLAY=":0"'
    (you should set the value you get if you do 'echo $DISPLAY' if you login locally
    on the remote server). Then all graphics will get rendered locally and not via -X.
    (important: graphical session needs to be running on remote server (e.g. via login locally))
    (important: login needed, not just stay at login screen)
    '''
    # Local imports: dipy/vtk only needed for this optional preview feature
    from dipy.viz import window
    from tractseg.libs.VtkUtils import VtkUtils

    SMOOTHING = 10
    WINDOW_SIZE = (800, 800)
    bundles = ["CST_right", "CA", "IFO_right"]  # only these three are previewed

    renderer = window.Renderer()
    renderer.projection('parallel')

    rows = len(bundles)
    X, Y, Z = bundle_segmentations.shape[:3]
    for j, bundle in enumerate(bundles):
        i = 0  # only one method
        # Look up this bundle's channel index ([0] of get_bundle_names is background)
        bundle_idx = ExpUtils.get_bundle_names(HP.CLASSES)[1:].index(bundle)
        mask_data = bundle_segmentations[:, :, :, bundle_idx]

        # Per-bundle view orientation
        if bundle == "CST_right":
            orientation = "axial"
        elif bundle == "CA":
            orientation = "axial"
        elif bundle == "IFO_right":
            orientation = "sagittal"
        else:
            orientation = "axial"

        # bigger: more border
        if orientation == "axial":
            border_y = -100  # -60
        else:
            border_y = -100

        x_current = X * i  # column (width)
        y_current = rows * (Y * 2 + border_y) - (Y * 2 + border_y) * j  # row (height) (starts from bottom?)

        PlotUtils.plot_mask(renderer, mask_data, affine, x_current, y_current,
                            orientation=orientation, smoothing=SMOOTHING, brain_mask=brain_mask)

        # Bundle label
        text_offset_top = -50  # 60
        text_offset_side = -100  # -30
        position = (0 - int(X) + text_offset_side, y_current + text_offset_top, 50)
        text_actor = VtkUtils.label(text=bundle, pos=position, scale=(6, 6, 6), color=(1, 1, 1))
        renderer.add(text_actor)

    renderer.reset_camera()
    window.record(renderer, out_path=join(out_dir, "preview.png"),
                  size=(WINDOW_SIZE[0], WINDOW_SIZE[1]), reset_camera=False, magnification=2)
def run_tractseg(data, output_type="tract_segmentation", input_type="peaks",
                 single_orientation=False, verbose=False, dropout_sampling=False, threshold=0.5,
                 bundle_specific_threshold=False, get_probs=False, peak_threshold=0.1):
    '''
    Run TractSeg

    :param data: input peaks (4D numpy array with shape [x,y,z,9])
    :param output_type: "tract_segmentation" | "endings_segmentation" | "TOM" | "dm_regression"
    :param input_type: "peaks"
    :param single_orientation: run only one slice orientation instead of fusing
        three directions (mainly needed for testing because of less RAM requirements)
    :param verbose: show debugging infos
    :param dropout_sampling: create uncertainty map by monte carlo dropout (https://arxiv.org/abs/1506.02142)
    :param threshold: Threshold for converting probability map to binary map
    :param bundle_specific_threshold: Threshold is lower for some bundles which need more sensitivity (CA, CST, FX)
    :param get_probs: Output raw probability map instead of binary map
    :param peak_threshold: all peaks shorter than peak_threshold will be set to zero
    :return: 4D numpy array with the output of tractseg
        for tract_segmentation:   [x, y, z, nr_of_bundles]
        for endings_segmentation: [x, y, z, 2*nr_of_bundles]
        for TOM:                  [x, y, z, 3*nr_of_bundles]
    '''
    start_time = time.time()

    # Load the pretrained-model config class matching input/output type
    config = get_config_name(input_type, output_type)
    HP = getattr(importlib.import_module("tractseg.config.PretrainedModels." + config), "HP")()
    HP.VERBOSE = verbose
    HP.TRAIN = False
    HP.TEST = False
    HP.SEGMENT = False
    HP.GET_PROBS = get_probs
    HP.LOAD_WEIGHTS = True
    HP.DROPOUT_SAMPLING = dropout_sampling
    HP.THRESHOLD = threshold

    if bundle_specific_threshold:
        # Bundle-specific thresholding is applied later on the raw probabilities
        HP.GET_PROBS = True

    # Select the pretrained weights file matching experiment type and input modality
    if input_type == "peaks":
        if HP.EXPERIMENT_TYPE == "tract_segmentation" and HP.DROPOUT_SAMPLING:
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_tract_segmentation_dropout_v1.npz")
        elif HP.EXPERIMENT_TYPE == "tract_segmentation":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_tract_segmentation_v1.npz")
        elif HP.EXPERIMENT_TYPE == "endings_segmentation":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_endings_segmentation_v3.npz")
        elif HP.EXPERIMENT_TYPE == "peak_regression":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_peak_regression_v1.npz")
        elif HP.EXPERIMENT_TYPE == "dm_regression":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_dm_regression_v1.npz")
    elif input_type == "T1":
        if HP.EXPERIMENT_TYPE == "tract_segmentation":
            HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes/x_Pretrained_TractSeg_Models",
                                   "TractSeg_T1_125mm_DAugAll", "best_weights_ep142.npz")
        elif HP.EXPERIMENT_TYPE == "endings_segmentation":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_endings_segmentation_v1.npz")
        elif HP.EXPERIMENT_TYPE == "peak_regression":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_peak_regression_v1.npz")
    print("Loading weights from: {}".format(HP.WEIGHTS_PATH))

    if HP.EXPERIMENT_TYPE == "peak_regression":
        # 3 peak components (x, y, z) per bundle
        HP.NR_OF_CLASSES = 3*len(ExpUtils.get_bundle_names(HP.CLASSES)[1:])
    else:
        HP.NR_OF_CLASSES = len(ExpUtils.get_bundle_names(HP.CLASSES)[1:])

    if HP.VERBOSE:
        print("Hyperparameters:")
        ExpUtils.print_HPs(HP)

    Utils.download_pretrained_weights(experiment_type=HP.EXPERIMENT_TYPE, dropout_sampling=HP.DROPOUT_SAMPLING)

    data = np.nan_to_num(data)
    # brain_mask = ImgUtils.simple_brain_mask(data)
    # if HP.VERBOSE:
    #     nib.save(nib.Nifti1Image(brain_mask, np.eye(4)), "otsu_brain_mask_DEBUG.nii.gz")

    if input_type == "T1":
        # Add a singleton channel axis so T1 looks like a 1-channel 4D volume
        data = np.reshape(data, (data.shape[0], data.shape[1], data.shape[2], 1))
    # Crop to brain bounding box and scale to the network's square input size;
    # bbox/original_shape/transformation are kept to undo this at the end.
    data, seg_None, bbox, original_shape = DatasetUtils.crop_to_nonzero(data)
    data, transformation = DatasetUtils.pad_and_scale_img_to_square_img(data, target_size=HP.INPUT_DIM[0])

    model = BaseModel(HP)

    if HP.EXPERIMENT_TYPE == "tract_segmentation" or HP.EXPERIMENT_TYPE == "endings_segmentation" or HP.EXPERIMENT_TYPE == "dm_regression":
        if single_orientation:  # mainly needed for testing because of less RAM requirements
            dataManagerSingle = DataManagerSingleSubjectByFile(HP, data=data)
            trainerSingle = Trainer(model, dataManagerSingle)
            if HP.DROPOUT_SAMPLING or HP.EXPERIMENT_TYPE == "dm_regression" or HP.GET_PROBS:
                seg, img_y = trainerSingle.get_seg_single_img(HP, probs=True, scale_to_world_shape=False, only_prediction=True)
            else:
                seg, img_y = trainerSingle.get_seg_single_img(HP, probs=False, scale_to_world_shape=False, only_prediction=True)
        else:
            # Predict along all three slice orientations and fuse by averaging
            seg_xyz, gt = DirectionMerger.get_seg_single_img_3_directions(HP, model, data=data, scale_to_world_shape=False, only_prediction=True)
            if HP.DROPOUT_SAMPLING or HP.EXPERIMENT_TYPE == "dm_regression" or HP.GET_PROBS:
                seg = DirectionMerger.mean_fusion(HP.THRESHOLD, seg_xyz, probs=True)
            else:
                seg = DirectionMerger.mean_fusion(HP.THRESHOLD, seg_xyz, probs=False)

    elif HP.EXPERIMENT_TYPE == "peak_regression":
        dataManagerSingle = DataManagerSingleSubjectByFile(HP, data=data)
        trainerSingle = Trainer(model, dataManagerSingle)
        seg, img_y = trainerSingle.get_seg_single_img(HP, probs=True, scale_to_world_shape=False, only_prediction=True)
        if bundle_specific_threshold:
            seg = ImgUtils.remove_small_peaks_bundle_specific(seg, ExpUtils.get_bundle_names(HP.CLASSES)[1:], len_thr=0.3)
        else:
            seg = ImgUtils.remove_small_peaks(seg, len_thr=peak_threshold)
        # 3 dir for Peaks -> not working (?)
        # seg_xyz, gt = DirectionMerger.get_seg_single_img_3_directions(HP, model, data=data, scale_to_world_shape=False, only_prediction=True)
        # seg = DirectionMerger.mean_fusion(HP.THRESHOLD, seg_xyz, probs=True)

    if bundle_specific_threshold and HP.EXPERIMENT_TYPE == "tract_segmentation":
        seg = ImgUtils.probs_to_binary_bundle_specific(seg, ExpUtils.get_bundle_names(HP.CLASSES)[1:])

    # Undo the square-scaling and cropping done above
    # (remove following two lines to keep super resolution)
    seg = DatasetUtils.cut_and_scale_img_back_to_original_img(seg, transformation)
    seg = DatasetUtils.add_original_zero_padding_again(seg, bbox, original_shape, HP.NR_OF_CLASSES)

    ExpUtils.print_verbose(HP, "Took {}s".format(round(time.time() - start_time, 2)))
    return seg
def create_network(self):
    """Build the U-Net, BCE loss and Adamax optimizer, and bind
    train/test/predict/save/load closures onto self. The closures share
    `net`, `criterion` and `optimizer` from this enclosing scope.
    """
    def train(X, y):
        # One optimization step on a single batch; returns (loss, probs, f1).
        X = torch.from_numpy(X.astype(np.float32))
        y = torch.from_numpy(y.astype(np.float32))
        if torch.cuda.is_available():
            X, y = Variable(X.cuda()), Variable(y.cuda())  # X: (bs, features, x, y)   y: (bs, classes, x, y)
        else:
            X, y = Variable(X), Variable(y)
        optimizer.zero_grad()
        net.train()
        outputs = net(X)  # forward
        # outputs: (bs, classes, x, y)
        loss = criterion(outputs, y)
        loss.backward()  # backward
        optimizer.step()  # optimise
        f1 = PytorchUtils.f1_score_macro(y.data, outputs.data, per_class=True)
        # probs = outputs.data.cpu().numpy().transpose(0,2,3,1)  # (bs, x, y, classes)
        probs = None  # faster
        return loss.data[0], probs, f1

    def test(X, y):
        # Evaluation pass without gradient tracking (old-PyTorch `volatile` API).
        X = torch.from_numpy(X.astype(np.float32))
        y = torch.from_numpy(y.astype(np.float32))
        if torch.cuda.is_available():
            X, y = Variable(X.cuda(), volatile=True), Variable(y.cuda(), volatile=True)
        else:
            X, y = Variable(X, volatile=True), Variable(y, volatile=True)
        net.train(False)
        outputs = net(X)  # forward
        loss = criterion(outputs, y)
        f1 = PytorchUtils.f1_score_macro(y.data, outputs.data, per_class=True)
        # probs = outputs.data.cpu().numpy().transpose(0,2,3,1)  # (bs, x, y, classes)
        probs = None  # faster
        return loss.data[0], probs, f1

    def predict(X):
        # Pure forward pass; returns per-voxel probabilities (bs, x, y, classes).
        X = torch.from_numpy(X.astype(np.float32))
        if torch.cuda.is_available():
            X = Variable(X.cuda(), volatile=True)
        else:
            X = Variable(X, volatile=True)
        net.train(False)
        outputs = net(X)  # forward
        probs = outputs.data.cpu().numpy().transpose(0, 2, 3, 1)  # (bs, x, y, classes)
        return probs

    def save_model(metrics, epoch_nr):
        # Save weights only when this epoch is the best validation f1 so far.
        max_f1_idx = np.argmax(metrics["f1_macro_validate"])
        max_f1 = np.max(metrics["f1_macro_validate"])
        if epoch_nr == max_f1_idx and max_f1 > 0.01:
            # saving to network drives takes 5s (to local only 0.5s) -> do not save so often
            print(" Saving weights...")
            for fl in glob.glob(join(self.HP.EXP_PATH, "best_weights_ep*")):  # remove weights from previous epochs
                os.remove(fl)
            try:
                # Actually is a pkl not a npz
                PytorchUtils.save_checkpoint(join(self.HP.EXP_PATH, "best_weights_ep" + str(epoch_nr) + ".npz"), unet=net)
            except IOError:
                print("\nERROR: Could not save weights because of IO Error\n")
            self.HP.BEST_EPOCH = epoch_nr

    def load_model(path):
        PytorchUtils.load_checkpoint(path, unet=net)

    # Number of input channels depends on input representation
    if self.HP.SEG_INPUT == "Peaks" and self.HP.TYPE == "single_direction":
        NR_OF_GRADIENTS = 9
    elif self.HP.SEG_INPUT == "Peaks" and self.HP.TYPE == "combined":
        NR_OF_GRADIENTS = 3 * self.HP.NR_OF_CLASSES
    else:
        NR_OF_GRADIENTS = 33

    if torch.cuda.is_available():
        net = UNet(n_input_channels=NR_OF_GRADIENTS, n_classes=self.HP.NR_OF_CLASSES, n_filt=self.HP.UNET_NR_FILT).cuda()
        # net = UNet_Skip(n_input_channels=NR_OF_GRADIENTS, n_classes=self.HP.NR_OF_CLASSES, n_filt=self.HP.UNET_NR_FILT).cuda()
    else:
        net = UNet(n_input_channels=NR_OF_GRADIENTS, n_classes=self.HP.NR_OF_CLASSES, n_filt=self.HP.UNET_NR_FILT)
        # net = UNet_Skip(n_input_channels=NR_OF_GRADIENTS, n_classes=self.HP.NR_OF_CLASSES, n_filt=self.HP.UNET_NR_FILT)

    if self.HP.TRAIN:
        ExpUtils.print_and_save(self.HP, str(net), only_log=True)

    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adamax(net.parameters(), lr=self.HP.LEARNING_RATE)

    if self.HP.LOAD_WEIGHTS:
        ExpUtils.print_verbose(self.HP, "Loading weights ... ({})".format(join(self.HP.EXP_PATH, self.HP.WEIGHTS_PATH)))
        load_model(join(self.HP.EXP_PATH, self.HP.WEIGHTS_PATH))

    self.train = train
    self.predict = test  # NOTE: "predict" is bound to the loss-returning test closure
    self.get_probs = predict  # the pure forward pass is exposed as get_probs
    self.save_model = save_model
    self.load_model = load_model
def test_bundle_names(self):
    """A single-bundle class set should resolve to background ("BG") plus that bundle."""
    expected = ["BG", "CST_right"]
    result = ExpUtils.get_bundle_names("CST_right")
    self.assertListEqual(result, expected, "Error in list of bundle names")
def create_network(self):
    """Peak-regression variant: build the U-Net with an angle+length loss and
    bundle-voxel upweighting, and bind train/test/predict/save/load/print_current_lr
    closures onto self. The closures share `net`, `criterion` and `optimizer`.
    """
    # torch.backends.cudnn.benchmark = True  # not faster

    def train(X, y, weight_factor=10):
        # One optimization step; voxels inside any bundle are upweighted by
        # weight_factor to counter the background/bundle class imbalance.
        X = torch.from_numpy(X.astype(np.float32))
        y = torch.from_numpy(y.astype(np.float32))
        if torch.cuda.is_available():
            X, y = Variable(X.cuda()), Variable(y.cuda())  # X: (bs, features, x, y)   y: (bs, classes, x, y)
        else:
            X, y = Variable(X), Variable(y)
        optimizer.zero_grad()
        net.train()
        outputs = net(X)  # forward
        # outputs: (bs, classes, x, y)
        weights = torch.ones((self.HP.BATCH_SIZE, self.HP.NR_OF_CLASSES, self.HP.INPUT_DIM[0], self.HP.INPUT_DIM[1])).cuda()
        bundle_mask = y > 0
        weights[bundle_mask.data] *= weight_factor  # 10
        loss = criterion(outputs, y, Variable(weights))
        # loss = criterion1(outputs, y, Variable(weights)) + criterion2(outputs, y, Variable(weights))
        loss.backward()  # backward
        optimizer.step()  # optimise
        if self.HP.CALC_F1:
            # f1 = PytorchUtils.f1_score_macro(y.data, outputs.data, per_class=True)
            # f1_a = MetricUtils.calc_peak_dice_pytorch(self.HP, outputs.data, y.data, max_angle_error=self.HP.PEAK_DICE_THR)
            f1 = MetricUtils.calc_peak_length_dice_pytorch(self.HP, outputs.data, y.data,
                                                           max_angle_error=self.HP.PEAK_DICE_THR,
                                                           max_length_error=self.HP.PEAK_DICE_LEN_THR)
            # f1 = (f1_a, f1_b)
        else:
            f1 = np.ones(outputs.shape[3])  # dummy f1 when metric calculation is disabled
        if self.HP.USE_VISLOGGER:
            probs = outputs.data.cpu().numpy().transpose(0, 2, 3, 1)  # (bs, x, y, classes)
        else:
            probs = None  # faster
        return loss.data[0], probs, f1

    def test(X, y, weight_factor=10):
        # Evaluation pass without gradient tracking (old-PyTorch `volatile` API);
        # applies the same bundle upweighting so train/validate losses are comparable.
        X = torch.from_numpy(X.astype(np.float32))
        y = torch.from_numpy(y.astype(np.float32))
        if torch.cuda.is_available():
            X, y = Variable(X.cuda(), volatile=True), Variable(y.cuda(), volatile=True)
        else:
            X, y = Variable(X, volatile=True), Variable(y, volatile=True)
        net.train(False)
        outputs = net(X)  # forward
        weights = torch.ones((self.HP.BATCH_SIZE, self.HP.NR_OF_CLASSES, self.HP.INPUT_DIM[0], self.HP.INPUT_DIM[1])).cuda()
        bundle_mask = y > 0
        weights[bundle_mask.data] *= weight_factor  # 10
        loss = criterion(outputs, y, Variable(weights))
        # loss = criterion1(outputs, y, Variable(weights)) + criterion2(outputs, y, Variable(weights))
        if self.HP.CALC_F1:
            f1 = MetricUtils.calc_peak_length_dice_pytorch(self.HP, outputs.data, y.data,
                                                           max_angle_error=self.HP.PEAK_DICE_THR,
                                                           max_length_error=self.HP.PEAK_DICE_LEN_THR)
        else:
            f1 = np.ones(outputs.shape[3])
        # probs = outputs.data.cpu().numpy().transpose(0,2,3,1)  # (bs, x, y, classes)
        probs = None  # faster
        return loss.data[0], probs, f1

    def predict(X):
        # Pure forward pass; returns raw predictions (bs, x, y, classes).
        X = torch.from_numpy(X.astype(np.float32))
        if torch.cuda.is_available():
            X = Variable(X.cuda(), volatile=True)
        else:
            X = Variable(X, volatile=True)
        net.train(False)
        outputs = net(X)  # forward
        probs = outputs.data.cpu().numpy().transpose(0, 2, 3, 1)  # (bs, x, y, classes)
        return probs

    def save_model(metrics, epoch_nr):
        # Save weights only when this epoch is the best validation f1 so far.
        max_f1_idx = np.argmax(metrics["f1_macro_validate"])
        max_f1 = np.max(metrics["f1_macro_validate"])
        if epoch_nr == max_f1_idx and max_f1 > 0.01:
            # saving to network drives takes 5s (to local only 0.5s) -> do not save so often
            print(" Saving weights...")
            for fl in glob.glob(join(self.HP.EXP_PATH, "best_weights_ep*")):  # remove weights from previous epochs
                os.remove(fl)
            try:
                # Actually is a pkl not a npz
                PytorchUtils.save_checkpoint(join(self.HP.EXP_PATH, "best_weights_ep" + str(epoch_nr) + ".npz"), unet=net)
            except IOError:
                print("\nERROR: Could not save weights because of IO Error\n")
            self.HP.BEST_EPOCH = epoch_nr
        # Saving Last Epoch (disabled):
        # print(" Saving weights last epoch...")
        # for fl in glob.glob(join(self.HP.EXP_PATH, "weights_ep*")):  # remove weights from previous epochs
        #     os.remove(fl)
        # try:
        #     # Actually is a pkl not a npz
        #     PytorchUtils.save_checkpoint(join(self.HP.EXP_PATH, "weights_ep" + str(epoch_nr) + ".npz"), unet=net)
        # except IOError:
        #     print("\nERROR: Could not save weights because of IO Error\n")

    def load_model(path):
        PytorchUtils.load_checkpoint(path, unet=net)

    def print_current_lr():
        for param_group in optimizer.param_groups:
            ExpUtils.print_and_save(self.HP, "current learning rate: {}".format(param_group['lr']))

    # Number of input channels depends on input representation
    if self.HP.SEG_INPUT == "Peaks" and self.HP.TYPE == "single_direction":
        NR_OF_GRADIENTS = self.HP.NR_OF_GRADIENTS
        # NR_OF_GRADIENTS = 9 * 5
        # NR_OF_GRADIENTS = 9 * 9
        # NR_OF_GRADIENTS = 33
    elif self.HP.SEG_INPUT == "Peaks" and self.HP.TYPE == "combined":
        NR_OF_GRADIENTS = 3 * self.HP.NR_OF_CLASSES
    else:
        NR_OF_GRADIENTS = 33

    if torch.cuda.is_available():
        net = UNet(n_input_channels=NR_OF_GRADIENTS, n_classes=self.HP.NR_OF_CLASSES, n_filt=self.HP.UNET_NR_FILT).cuda()
    else:
        net = UNet(n_input_channels=NR_OF_GRADIENTS, n_classes=self.HP.NR_OF_CLASSES, n_filt=self.HP.UNET_NR_FILT)

    # if self.HP.TRAIN:
    #     ExpUtils.print_and_save(self.HP, str(net), only_log=True)

    # criterion takes (outputs, y, weights) — see PytorchUtils.angle_length_loss
    # criterion1 = PytorchUtils.MSE_weighted
    # criterion2 = PytorchUtils.angle_loss
    # criterion = PytorchUtils.MSE_weighted
    # criterion = PytorchUtils.angle_loss
    criterion = PytorchUtils.angle_length_loss
    optimizer = Adamax(net.parameters(), lr=self.HP.LEARNING_RATE)

    if self.HP.LOAD_WEIGHTS:
        ExpUtils.print_verbose(self.HP, "Loading weights ... ({})".format(join(self.HP.EXP_PATH, self.HP.WEIGHTS_PATH)))
        load_model(join(self.HP.EXP_PATH, self.HP.WEIGHTS_PATH))

    self.train = train
    self.predict = test  # NOTE: "predict" is bound to the loss-returning test closure
    self.get_probs = predict  # the pure forward pass is exposed as get_probs
    self.save_model = save_model
    self.load_model = load_model
    self.print_current_lr = print_current_lr
def __init__(self, HP, data):
    """Data manager fed directly with an in-memory image instead of a file path."""
    self.HP = HP
    self.data = data
    ExpUtils.print_verbose(self.HP, "Loading data from PREDICT_IMG input file")
def create_network(self):
    """Build the network, optimizer and loss, and expose train/test/predict closures on self.

    Side effects: sets self.train, self.predict (mapped to the inner ``test``),
    self.get_probs (mapped to the inner ``predict``), self.save_model,
    self.load_model, self.print_current_lr and optionally self.scheduler.
    """
    # torch.backends.cudnn.benchmark = True     # not faster

    def train(X, y, weight_factor=10):
        # X: (bs, features, x, y)   y: (bs, classes, x, y)
        X = torch.tensor(X, dtype=torch.float32).to(device)
        y = torch.tensor(y, dtype=torch.float32).to(device)

        optimizer.zero_grad()
        net.train()
        outputs, outputs_sigmoid = net(X)  # forward; outputs: (bs, classes, x, y)

        if weight_factor > 1:
            # Upweight voxels inside any bundle mask by weight_factor.
            # BUGFIX: use .to(device) instead of .cuda() so CPU-only runs do not crash
            # (net itself is already moved with .to(device) below).
            weights = torch.ones((self.HP.BATCH_SIZE, self.HP.NR_OF_CLASSES, y.shape[2], y.shape[3])).to(device)
            bundle_mask = y > 0
            weights[bundle_mask.data] *= weight_factor  # e.g. 10
            if self.HP.EXPERIMENT_TYPE == "peak_regression":
                loss = criterion(outputs, y, weights)
            else:
                loss = nn.BCEWithLogitsLoss(weight=weights)(outputs, y)
        else:
            if self.HP.LOSS_FUNCTION == "soft_sample_dice" or self.HP.LOSS_FUNCTION == "soft_batch_dice":
                loss = criterion(outputs_sigmoid, y)
                # loss = criterion(outputs_sigmoid, y) + nn.BCEWithLogitsLoss()(outputs, y)
            else:
                loss = criterion(outputs, y)

        loss.backward()   # backward
        optimizer.step()  # optimise

        if self.HP.EXPERIMENT_TYPE == "peak_regression":
            f1 = MetricUtils.calc_peak_length_dice_pytorch(self.HP, outputs.detach(), y.detach(),
                                                           max_angle_error=self.HP.PEAK_DICE_THR,
                                                           max_length_error=self.HP.PEAK_DICE_LEN_THR)
        elif self.HP.EXPERIMENT_TYPE == "dm_regression":  # density map regression
            f1 = PytorchUtils.f1_score_macro(y.detach() > 0.5, outputs.detach(), per_class=True)
        else:
            f1 = PytorchUtils.f1_score_macro(y.detach(), outputs_sigmoid.detach(), per_class=True,
                                             threshold=self.HP.THRESHOLD)

        if self.HP.USE_VISLOGGER:
            # probs = outputs_sigmoid.detach().cpu().numpy().transpose(0,2,3,1)  # (bs, x, y, classes)
            probs = outputs_sigmoid
        else:
            probs = None  # faster

        return loss.item(), probs, f1

    def test(X, y, weight_factor=10):
        """Forward pass without gradient tracking; same loss/f1 bookkeeping as train()."""
        with torch.no_grad():
            X = torch.tensor(X, dtype=torch.float32).to(device)
            y = torch.tensor(y, dtype=torch.float32).to(device)

            if self.HP.DROPOUT_SAMPLING:
                net.train()  # keep dropout active for monte-carlo sampling
            else:
                net.train(False)
            outputs, outputs_sigmoid = net(X)  # forward

            if weight_factor > 1:
                # BUGFIX: .to(device) instead of .cuda() (see train())
                weights = torch.ones((self.HP.BATCH_SIZE, self.HP.NR_OF_CLASSES, y.shape[2], y.shape[3])).to(device)
                bundle_mask = y > 0
                weights[bundle_mask.data] *= weight_factor
                if self.HP.EXPERIMENT_TYPE == "peak_regression":
                    loss = criterion(outputs, y, weights)
                else:
                    loss = nn.BCEWithLogitsLoss(weight=weights)(outputs, y)
            else:
                if self.HP.LOSS_FUNCTION == "soft_sample_dice" or self.HP.LOSS_FUNCTION == "soft_batch_dice":
                    loss = criterion(outputs_sigmoid, y)
                else:
                    loss = criterion(outputs, y)

            if self.HP.EXPERIMENT_TYPE == "peak_regression":
                f1 = MetricUtils.calc_peak_length_dice_pytorch(self.HP, outputs.detach(), y.detach(),
                                                               max_angle_error=self.HP.PEAK_DICE_THR,
                                                               max_length_error=self.HP.PEAK_DICE_LEN_THR)
            elif self.HP.EXPERIMENT_TYPE == "dm_regression":  # density map regression
                f1 = PytorchUtils.f1_score_macro(y.detach() > 0.5, outputs.detach(), per_class=True)
            else:
                f1 = PytorchUtils.f1_score_macro(y.detach(), outputs_sigmoid.detach(), per_class=True,
                                                 threshold=self.HP.THRESHOLD)

            if self.HP.USE_VISLOGGER:
                probs = outputs_sigmoid
            else:
                probs = None  # faster

            return loss.item(), probs, f1

    def predict(X):
        """Return probability maps as numpy (bs, x, y, classes); no labels needed."""
        with torch.no_grad():
            X = torch.tensor(X, dtype=torch.float32).to(device)

            if self.HP.DROPOUT_SAMPLING:
                net.train()
            else:
                net.train(False)
            outputs, outputs_sigmoid = net(X)  # forward

            if self.HP.EXPERIMENT_TYPE == "peak_regression" or self.HP.EXPERIMENT_TYPE == "dm_regression":
                probs = outputs.detach().cpu().numpy().transpose(0, 2, 3, 1)  # (bs, x, y, classes)
            else:
                probs = outputs_sigmoid.detach().cpu().numpy().transpose(0, 2, 3, 1)  # (bs, x, y, classes)
            return probs

    def save_model(metrics, epoch_nr):
        """Persist weights when this epoch has the best validation f1 so far."""
        max_f1_idx = np.argmax(metrics["f1_macro_validate"])
        max_f1 = np.max(metrics["f1_macro_validate"])
        # saving to network drives takes 5s (to local only 0.5s) -> do not save so often
        if epoch_nr == max_f1_idx and max_f1 > 0.01:
            print("  Saving weights...")
            for fl in glob.glob(join(self.HP.EXP_PATH, "best_weights_ep*")):  # remove weights from previous epochs
                os.remove(fl)
            try:
                # Actually is a pkl not a npz
                PytorchUtils.save_checkpoint(join(self.HP.EXP_PATH, "best_weights_ep" + str(epoch_nr) + ".npz"), unet=net)
            except IOError:
                print("\nERROR: Could not save weights because of IO Error\n")
            self.HP.BEST_EPOCH = epoch_nr

    def load_model(path):
        PytorchUtils.load_checkpoint(path, unet=net)

    def print_current_lr():
        for param_group in optimizer.param_groups:
            ExpUtils.print_and_save(self.HP, "current learning rate: {}".format(param_group['lr']))

    if self.HP.SEG_INPUT == "Peaks" and self.HP.TYPE == "single_direction":
        NR_OF_GRADIENTS = self.HP.NR_OF_GRADIENTS
        # NR_OF_GRADIENTS = 9 / 9*5 / 9*9 / 33
    elif self.HP.SEG_INPUT == "Peaks" and self.HP.TYPE == "combined":
        # BUGFIX: also bind the local NR_OF_GRADIENTS -- the original only set it on
        # self.HP, leaving the local name unbound for the NetworkClass(...) call below.
        self.HP.NR_OF_GRADIENTS = 3 * self.HP.NR_OF_CLASSES
        NR_OF_GRADIENTS = self.HP.NR_OF_GRADIENTS
    else:
        self.HP.NR_OF_GRADIENTS = 33
        NR_OF_GRADIENTS = self.HP.NR_OF_GRADIENTS

    if self.HP.LOSS_FUNCTION == "soft_sample_dice":
        criterion = PytorchUtils.soft_sample_dice
    elif self.HP.LOSS_FUNCTION == "soft_batch_dice":
        criterion = PytorchUtils.soft_batch_dice
    elif self.HP.EXPERIMENT_TYPE == "peak_regression":
        criterion = PytorchUtils.angle_length_loss
    else:
        # weights = torch.ones((self.HP.BATCH_SIZE, self.HP.NR_OF_CLASSES, self.HP.INPUT_DIM[0], self.HP.INPUT_DIM[1])).cuda()
        # weights[:, 5, :, :] *= 10   # CA
        # weights[:, 21, :, :] *= 10  # FX_left
        # weights[:, 22, :, :] *= 10  # FX_right
        # criterion = nn.BCEWithLogitsLoss(weight=weights)
        criterion = nn.BCEWithLogitsLoss()

    # Dynamically import the model class named by HP.MODEL from tractseg.models.
    NetworkClass = getattr(importlib.import_module("tractseg.models." + self.HP.MODEL), self.HP.MODEL)
    net = NetworkClass(n_input_channels=NR_OF_GRADIENTS, n_classes=self.HP.NR_OF_CLASSES,
                       n_filt=self.HP.UNET_NR_FILT, batchnorm=self.HP.BATCH_NORM,
                       dropout=self.HP.USE_DROPOUT)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = net.to(device)

    # if self.HP.TRAIN:
    #     ExpUtils.print_and_save(self.HP, str(net), only_log=True)

    if self.HP.OPTIMIZER == "Adamax":
        optimizer = Adamax(net.parameters(), lr=self.HP.LEARNING_RATE)
    elif self.HP.OPTIMIZER == "Adam":
        optimizer = Adam(net.parameters(), lr=self.HP.LEARNING_RATE)
        # optimizer = Adam(net.parameters(), lr=self.HP.LEARNING_RATE, weight_decay=self.HP.WEIGHT_DECAY)
    else:
        raise ValueError("Optimizer not defined")

    if self.HP.LR_SCHEDULE:
        scheduler = lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
        # scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode="max")
        self.scheduler = scheduler

    if self.HP.LOAD_WEIGHTS:
        ExpUtils.print_verbose(self.HP, "Loading weights ... ({})".format(join(self.HP.EXP_PATH, self.HP.WEIGHTS_PATH)))
        load_model(join(self.HP.EXP_PATH, self.HP.WEIGHTS_PATH))

    if self.HP.RESET_LAST_LAYER:
        net.conv_5 = nn.Conv2d(self.HP.UNET_NR_FILT, self.HP.NR_OF_CLASSES,
                               kernel_size=1, stride=1, padding=0, bias=True).to(device)

    # NOTE: self.predict runs the inner test() (loss + f1), self.get_probs runs predict().
    self.train = train
    self.predict = test
    self.get_probs = predict
    self.save_model = save_model
    self.load_model = load_model
    self.print_current_lr = print_current_lr
def generate_train_batch(self):
    # Build one training batch from a randomly chosen subject:
    # loads peak data + segmentation, rescales to the UNet input shape, samples a random
    # slice orientation, and stacks a 5-slice window per sampled slice into the channel dim.
    # Returns: {"data": (batch_size, 5*channels, x, y), "seg": (batch_size, nr_of_classes, x, y)}
    subjects = self._data[0]
    subject_idx = int(random.uniform(0, len(subjects)))  # len(subjects)-1 not needed because int always rounds to floor

    # Retry loop: network-drive reads occasionally fail with IOError; wait 20s and retry (max 20 tries).
    for i in range(20):
        try:
            # Randomly alternate between high- and low-gradient peak images (implicit augmentation).
            if np.random.random() < 0.5:
                data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "270g_125mm_peaks.nii.gz")).get_data()
            else:
                data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "90g_125mm_peaks.nii.gz")).get_data()

            # rnd_choice = np.random.random()
            # if rnd_choice < 0.33:
            #     data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "270g_125mm_peaks.nii.gz")).get_data()
            # elif rnd_choice < 0.66:
            #     data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "90g_125mm_peaks.nii.gz")).get_data()
            # else:
            #     data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "12g_125mm_peaks.nii.gz")).get_data()

            seg = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], self.HP.LABELS_FILENAME + ".nii.gz")).get_data()
            break
        except IOError:
            ExpUtils.print_and_save(self.HP, "\n\nWARNING: Could not load file. Trying again in 20s (Try number: " + str(i) + ").\n\n")
            ExpUtils.print_and_save(self.HP, "Sleeping 20s")
            sleep(20)
    # ExpUtils.print_and_save(self.HP, "Successfully loaded input.")

    data = np.nan_to_num(data)  # Needed otherwise not working
    seg = np.nan_to_num(seg)

    data = DatasetUtils.scale_input_to_unet_shape(data, self.HP.DATASET, self.HP.RESOLUTION)  # (x, y, z, channels)
    if self.HP.DATASET in ["HCP_2mm", "HCP_2.5mm", "HCP_32g"]:
        # By using "HCP" but lower resolution scale_input_to_unet_shape will automatically downsample the HCP sized seg_mask to the lower resolution
        seg = DatasetUtils.scale_input_to_unet_shape(seg, "HCP", self.HP.RESOLUTION)
    else:
        seg = DatasetUtils.scale_input_to_unet_shape(seg, self.HP.DATASET, self.HP.RESOLUTION)  # (x, y, z, classes)

    # Sample BATCH_SIZE distinct slice indices (without replacement) along the first axis.
    # NOTE(review): indices are taken from axis 0 regardless of slice_direction -- assumes
    # all three spatial dims have equal size after scaling (HCP UNet shape) -- TODO confirm.
    slice_idxs = np.random.choice(data.shape[0], self.BATCH_SIZE, False, None)

    # Randomly sample slice orientation
    slice_direction = int(round(random.uniform(0,2)))

    if slice_direction == 0:
        y = seg[slice_idxs, :, :].astype(self.HP.LABELS_TYPE)
        y = np.array(y).transpose(0, 3, 1, 2)  # nr_classes channel has to be before with and height for DataAugmentation (bs, nr_of_classes, x, y)
    elif slice_direction == 1:
        y = seg[:, slice_idxs, :].astype(self.HP.LABELS_TYPE)
        y = np.array(y).transpose(1, 3, 0, 2)
    elif slice_direction == 2:
        y = seg[:, :, slice_idxs].astype(self.HP.LABELS_TYPE)
        y = np.array(y).transpose(2, 3, 0, 1)

    sw = 5  # slice_window (only odd numbers allowed)
    pad = int((sw-1) / 2)

    # Zero-pad the volume so a full window fits even at the first/last slice.
    data_pad = np.zeros((data.shape[0]+sw-1, data.shape[1]+sw-1, data.shape[2]+sw-1, data.shape[3])).astype(data.dtype)
    data_pad[pad:-pad, pad:-pad, pad:-pad, :] = data  # padded with two slices of zeros on all sides

    batch=[]
    for s_idx in slice_idxs:
        if slice_direction == 0:
            # (s_idx+2)-2:(s_idx+2)+3 = s_idx:s_idx+5  (window centered on the slice in padded coords)
            x = data_pad[s_idx:s_idx+sw:, pad:-pad, pad:-pad, :].astype(np.float32)  # (5, y, z, channels)
            x = np.array(x).transpose(0, 3, 1, 2)  # channels dim has to be before width and height for Unet (but after batches)
            x = np.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3]))  # (5*channels, y, z)
            batch.append(x)
        elif slice_direction == 1:
            x = data_pad[pad:-pad, s_idx:s_idx+sw, pad:-pad, :].astype(np.float32)  # (5, y, z, channels)
            x = np.array(x).transpose(1, 3, 0, 2)  # channels dim has to be before width and height for Unet (but after batches)
            x = np.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3]))  # (5*channels, y, z)
            batch.append(x)
        elif slice_direction == 2:
            x = data_pad[pad:-pad, pad:-pad, s_idx:s_idx+sw, :].astype(np.float32)  # (5, y, z, channels)
            x = np.array(x).transpose(2, 3, 0, 1)  # channels dim has to be before width and height for Unet (but after batches)
            x = np.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3]))  # (5*channels, y, z)
            batch.append(x)

    data_dict = {"data": np.array(batch),  # (batch_size, channels, x, y, [z])
                 "seg": y}                 # (batch_size, channels, x, y, [z])
    return data_dict
class HP:
    # Central experiment configuration. All settings are plain class attributes that the
    # rest of the code reads (and sometimes mutates) via HP.<name>.

    EXP_MULTI_NAME = ""  # CV Parent Dir name # leave empty for Single Bundle Experiment
    EXP_NAME = "HCP_TEST"  # HCP_TEST
    MODEL = "UNet_Pytorch"  # UNet_Lasagne / UNet_Pytorch
    EXPERIMENT_TYPE = "tract_segmentation"  # tract_segmentation / endings_segmentation / dm_regression / peak_regression

    NUM_EPOCHS = 250
    DATA_AUGMENTATION = False
    DAUG_SCALE = True
    DAUG_NOISE = True
    DAUG_ELASTIC_DEFORM = True
    DAUG_RESAMPLE = True
    DAUG_ROTATE = False
    DAUG_MIRROR = False
    DAUG_FLIP_PEAKS = False
    DAUG_INFO = "Elastic(90,120)(9,11) - Scale(0.9, 1.5) - CenterDist60 - DownsampScipy(0.5,1) - Gaussian(0,0.05) - Rotate(-0.8,0.8)"
    DATASET = "HCP"  # HCP / HCP_32g / Schizo
    RESOLUTION = "1.25mm"  # 1.25mm (/ 2.5mm)
    FEATURES_FILENAME = "12g90g270g"  # 12g90g270g / 270g_125mm_xyz / 270g_125mm_peaks / 90g_125mm_peaks / 32g_25mm_peaks / 32g_25mm_xyz
    LABELS_FILENAME = ""  # autofilled #"bundle_peaks/CA" #IMPORTANT: Adapt BatchGen if 808080 # bundle_masks / bundle_masks_72 / bundle_masks_dm / bundle_peaks #Only used when using DataManagerNifti
    LOSS_FUNCTION = "default"  # default / soft_batch_dice
    OPTIMIZER = "Adamax"
    CLASSES = "All"  # All / 11 / 20 / CST_right
    NR_OF_GRADIENTS = 9
    NR_OF_CLASSES = len(ExpUtils.get_bundle_names(CLASSES)[1:])
    # NR_OF_CLASSES = 3 * len(ExpUtils.get_bundle_names(CLASSES)[1:])

    INPUT_DIM = (144, 144)  # (80, 80) / (144, 144)
    LOSS_WEIGHT = 1  # 1: no weighting
    LOSS_WEIGHT_LEN = -1  # -1: constant over all epochs
    SLICE_DIRECTION = "y"  # x, y, z  (combined needs z)
    TRAINING_SLICE_DIRECTION = "xyz"  # y / xyz
    INFO = "-"  # Dropout, Deconv, 11bundles, LeakyRelu, PeakDiceThres=0.9
    BATCH_NORM = False
    WEIGHT_DECAY = 0
    USE_DROPOUT = False
    DROPOUT_SAMPLING = False
    # DATASET_FOLDER = "HCP_batches/270g_125mm_bundle_peaks_Y_subset"  # HCP / HCP_batches/XXX / TRACED / HCP_fusion_npy_270g_125mm / HCP_fusion_npy_32g_25mm
    # DATASET_FOLDER = "HCP_batches/270g_125mm_bundle_peaks_XYZ"
    DATASET_FOLDER = "HCP"  # HCP / Schizo
    LABELS_FOLDER = "bundle_masks"  # bundle_masks / bundle_masks_dm

    MULTI_PARENT_PATH = join(C.EXP_PATH, EXP_MULTI_NAME)
    EXP_PATH = join(C.EXP_PATH, EXP_MULTI_NAME, EXP_NAME)  # default path
    BATCH_SIZE = 47  # 30/44 #max: #Peak Prediction: 44 #Pytorch: 50 #Lasagne: 56 #Lasagne combined: 42 #Pytorch UpSample: 56 #Pytorch_SE_r16: 45 #Pytorch_SE_r64: 45
    LEARNING_RATE = 0.001  # 0.002 #LR find: 0.000143 ?  # 0.001
    LR_SCHEDULE = False
    UNET_NR_FILT = 64
    LOAD_WEIGHTS = False
    # WEIGHTS_PATH = join(C.EXP_PATH, "HCP100_45B_UNet_x_DM_lr002_slope2_dec992_ep800/best_weights_ep64.npz")  # Can be absolute path or relative like "exp_folder/weights.npz"
    WEIGHTS_PATH = ""  # if empty string: autoloading the best_weights in get_best_weights_path()
    TYPE = "single_direction"  # single_direction / combined
    CV_FOLD = 0
    VALIDATE_SUBJECTS = []
    TRAIN_SUBJECTS = []
    TEST_SUBJECTS = []
    TRAIN = True
    TEST = True
    SEGMENT = False
    GET_PROBS = False
    OUTPUT_MULTIPLE_FILES = False
    RESET_LAST_LAYER = False

    # Peak_regression specific
    PEAK_DICE_THR = [0.95]
    PEAK_DICE_LEN_THR = 0.05
    FLIP_OUTPUT_PEAKS = True  # flip peaks along z axis to make them compatible with MITK

    # For TractSeg.py application
    PREDICT_IMG = False
    PREDICT_IMG_OUTPUT = None
    TRACTSEG_DIR = "tractseg_output"
    KEEP_INTERMEDIATE_FILES = False
    CSD_RESOLUTION = "LOW"  # HIGH / LOW

    # Unimportant / rarely changed:
    LABELS_TYPE = np.int16  # Binary: np.int16, Regression: np.float32
    THRESHOLD = 0.5  # Binary: 0.5, Regression: 0.01 ?
    TEST_TIME_DAUG = False
    USE_VISLOGGER = False  # only works with Python 3
    SAVE_WEIGHTS = True
    SEG_INPUT = "Peaks"  # Gradients/ Peaks
    NR_SLICES = 1  # adapt manually: NR_OF_GRADIENTS in UNet.py and get_batch... in train() and in get_seg_prediction()
    PRINT_FREQ = 20  # 20
    NORMALIZE_DATA = True
    NORMALIZE_PER_CHANNEL = False
    BEST_EPOCH = 0
    VERBOSE = True
    CALC_F1 = True
def print_current_lr():
    """Log the learning rate of every optimizer parameter group.

    NOTE: relies on ``optimizer`` and ``self`` from an enclosing scope.
    """
    for group in optimizer.param_groups:
        ExpUtils.print_and_save(self.HP, "current learning rate: {}".format(group['lr']))
def create_network(self):
    """Older (pre-0.4 PyTorch) network setup using Variable/volatile.

    Exposes train/test/predict closures on self, analogous to the newer variant.
    """
    # torch.backends.cudnn.benchmark = True     # not faster

    def train(X, y):
        X = torch.from_numpy(X.astype(np.float32))
        y = torch.from_numpy(y.astype(np.float32))
        if torch.cuda.is_available():
            X, y = Variable(X.cuda()), Variable(y.cuda())  # X: (bs, features, x, y)   y: (bs, classes, x, y)
        else:
            X, y = Variable(X), Variable(y)
        optimizer.zero_grad()
        net.train()
        outputs, intermediate = net(X)  # forward; outputs: (bs, classes, x, y)
        loss = criterion(outputs, y)
        # loss = PytorchUtils.soft_dice(outputs, y)
        loss.backward()   # backward
        optimizer.step()  # optimise
        f1 = PytorchUtils.f1_score_macro(y.data, outputs.data, per_class=True)
        if self.HP.USE_VISLOGGER:
            probs = outputs.data.cpu().numpy().transpose(0, 2, 3, 1)  # (bs, x, y, classes)
        else:
            probs = None  # faster
        return loss.data[0], probs, f1, intermediate

    def test(X, y):
        X = torch.from_numpy(X.astype(np.float32))
        y = torch.from_numpy(y.astype(np.float32))
        if torch.cuda.is_available():
            X, y = Variable(X.cuda(), volatile=True), Variable(y.cuda(), volatile=True)
        else:
            X, y = Variable(X, volatile=True), Variable(y, volatile=True)
        net.train(False)
        # BUGFIX: net(X) returns (outputs, intermediate) -- see train(). The original
        # passed the raw tuple into criterion(), which fails at runtime.
        outputs, intermediate = net(X)  # forward
        loss = criterion(outputs, y)
        # loss = PytorchUtils.soft_dice(outputs, y)
        f1 = PytorchUtils.f1_score_macro(y.data, outputs.data, per_class=True)
        # probs = outputs.data.cpu().numpy().transpose(0,2,3,1)  # (bs, x, y, classes)
        probs = None  # faster
        return loss.data[0], probs, f1

    def predict(X):
        X = torch.from_numpy(X.astype(np.float32))
        if torch.cuda.is_available():
            X = Variable(X.cuda(), volatile=True)
        else:
            X = Variable(X, volatile=True)
        net.train(False)
        # BUGFIX: unpack the (outputs, intermediate) tuple (same as in test()).
        outputs, intermediate = net(X)  # forward
        probs = outputs.data.cpu().numpy().transpose(0, 2, 3, 1)  # (bs, x, y, classes)
        return probs

    def save_model(metrics, epoch_nr):
        """Persist weights when this epoch has the best validation f1 so far."""
        max_f1_idx = np.argmax(metrics["f1_macro_validate"])
        max_f1 = np.max(metrics["f1_macro_validate"])
        # saving to network drives takes 5s (to local only 0.5s) -> do not save so often
        if epoch_nr == max_f1_idx and max_f1 > 0.01:
            print("  Saving weights...")
            for fl in glob.glob(join(self.HP.EXP_PATH, "best_weights_ep*")):  # remove weights from previous epochs
                os.remove(fl)
            try:
                # Actually is a pkl not a npz
                PytorchUtils.save_checkpoint(join(self.HP.EXP_PATH, "best_weights_ep" + str(epoch_nr) + ".npz"), unet=net)
            except IOError:
                print("\nERROR: Could not save weights because of IO Error\n")
            self.HP.BEST_EPOCH = epoch_nr

    def load_model(path):
        PytorchUtils.load_checkpoint(path, unet=net)

    def print_current_lr():
        for param_group in optimizer.param_groups:
            ExpUtils.print_and_save(self.HP, "current learning rate: {}".format(param_group['lr']))

    if self.HP.SEG_INPUT == "Peaks" and self.HP.TYPE == "single_direction":
        NR_OF_GRADIENTS = 9
        # NR_OF_GRADIENTS = 9 * 5
    elif self.HP.SEG_INPUT == "Peaks" and self.HP.TYPE == "combined":
        NR_OF_GRADIENTS = 3 * self.HP.NR_OF_CLASSES
    else:
        NR_OF_GRADIENTS = 33

    if torch.cuda.is_available():
        net = UNet(n_input_channels=NR_OF_GRADIENTS, n_classes=self.HP.NR_OF_CLASSES, n_filt=self.HP.UNET_NR_FILT).cuda()
    else:
        net = UNet(n_input_channels=NR_OF_GRADIENTS, n_classes=self.HP.NR_OF_CLASSES, n_filt=self.HP.UNET_NR_FILT)
    # net = nn.DataParallel(net, device_ids=[0,1])

    if self.HP.TRAIN:
        ExpUtils.print_and_save(self.HP, str(net), only_log=True)

    criterion = nn.BCEWithLogitsLoss()
    optimizer = Adamax(net.parameters(), lr=self.HP.LEARNING_RATE)
    # optimizer = Adam(net.parameters(), lr=self.HP.LEARNING_RATE)  # very slow (half speed of Adamax) -> strange
    # scheduler = lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.1)
    # scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode="max")

    if self.HP.LOAD_WEIGHTS:
        ExpUtils.print_verbose(self.HP, "Loading weights ... ({})".format(join(self.HP.EXP_PATH, self.HP.WEIGHTS_PATH)))
        load_model(join(self.HP.EXP_PATH, self.HP.WEIGHTS_PATH))

    # plot feature weights
    # weights = list(list(net.children())[0].children())[0].weight.cpu().data.numpy()  # sequential -> conv2d  # (64, 9, 3, 3)
    # weights = weights[:, 0:1, :, :]  # select one input channel to plot  # (64, 1, 3, 3)
    # weights = (weights*100).astype(np.uint8)  # can not plot negative values (and if float only 0-1 allowed) -> not good: we remove negatives
    # plot_kernels(weights)

    # NOTE: self.predict runs the inner test() (loss + f1), self.get_probs runs predict().
    self.train = train
    self.predict = test
    self.get_probs = predict
    self.save_model = save_model
    self.load_model = load_model
    self.print_current_lr = print_current_lr
    # self.scheduler = scheduler
def train(self, HP):
    # Main training loop: for every epoch run the train/test/validate phases,
    # accumulate metrics, save the best weights, dump metrics.pkl and plots.
    # Returns the accumulated metrics dict.
    if HP.USE_VISLOGGER:
        nvl = Nvl(name="Training")

    ExpUtils.print_and_save(HP, socket.gethostname())

    epoch_times = []
    nr_of_updates = 0

    # metrics holds one running list per phase, e.g. "loss_train", "f1_macro_validate".
    metrics = {}
    for type in ["train", "test", "validate"]:
        metrics_new = {
            "loss_" + type: [0],
            "f1_macro_" + type: [0],
        }
        metrics = dict(list(metrics.items()) + list(metrics_new.items()))

    for epoch_nr in range(HP.NUM_EPOCHS):
        start_time = time.time()
        # current_lr = HP.LEARNING_RATE * (HP.LR_DECAY ** epoch_nr)
        # current_lr = HP.LEARNING_RATE

        batch_gen_time = 0
        data_preparation_time = 0
        network_time = 0
        metrics_time = 0
        saving_time = 0
        plotting_time = 0

        batch_nr = {"train": 0, "test": 0, "validate": 0}

        # Loss weighting schedule: constant, or linearly decayed from LOSS_WEIGHT to 1 over LOSS_WEIGHT_LEN epochs.
        if HP.LOSS_WEIGHT_LEN == -1:
            weight_factor = float(HP.LOSS_WEIGHT)
        else:
            if epoch_nr < HP.LOSS_WEIGHT_LEN:
                # weight_factor = -(9./100.) * epoch_nr + 10.   # ep0: 10 -> linear decrease -> ep100: 1
                weight_factor = -((HP.LOSS_WEIGHT - 1) / float(HP.LOSS_WEIGHT_LEN)) * epoch_nr + float(HP.LOSS_WEIGHT)
            else:
                weight_factor = 1.

        for type in ["train", "test", "validate"]:
            print_loss = []
            start_time_batch_gen = time.time()

            batch_generator = self.dataManager.get_batches(batch_size=HP.BATCH_SIZE, type=type, subjects=getattr(HP, type.upper() + "_SUBJECTS"))
            batch_gen_time = time.time() - start_time_batch_gen
            # print("batch_gen_time: {}s".format(batch_gen_time))

            print("Start looping batches...")
            start_time_batch_part = time.time()
            for batch in batch_generator:  # getting next batch takes around 0.14s -> second largest time part after UNet!
                start_time_data_preparation = time.time()
                batch_nr[type] += 1

                x = batch["data"]  # (bs, nr_of_channels, x, y)
                y = batch["seg"]   # (bs, nr_of_classes, x, y)
                # since using new BatchGenerator y is not int anymore but float -> would be good for Pytorch but not Lasagne
                # y = y.astype(HP.LABELS_TYPE)  # for bundle_peaks regression: is already float -> saves 0.2s/batch if left out

                data_preparation_time += time.time() - start_time_data_preparation
                # self.model.learning_rate.set_value(np.float32(current_lr))
                start_time_network = time.time()
                if type == "train":
                    nr_of_updates += 1
                    loss, probs, f1 = self.model.train(x, y, weight_factor=weight_factor)  # probs: (bs, x, y, nrClasses)
                    # loss, probs, f1, intermediate = self.model.train(x, y)
                elif type == "validate":
                    loss, probs, f1 = self.model.predict(x, y, weight_factor=weight_factor)
                elif type == "test":
                    loss, probs, f1 = self.model.predict(x, y, weight_factor=weight_factor)
                network_time += time.time() - start_time_network

                start_time_metrics = time.time()
                if HP.CALC_F1:
                    if HP.LABELS_TYPE == np.int16:
                        metrics = MetricUtils.calculate_metrics(metrics, None, None, loss, f1=np.mean(f1), type=type, threshold=HP.THRESHOLD)
                    else:  # Regression
                        # (numpy-based per-bundle variants removed here were ~30s slower per epoch -- see git history)
                        # Pytorch: f1 is a dict of per-bundle scores -> average them
                        peak_f1_mean = np.array([s for s in list(f1.values())]).mean()  # if f1 for multiple bundles
                        metrics = MetricUtils.calculate_metrics(metrics, None, None, loss, f1=peak_f1_mean, type=type, threshold=HP.THRESHOLD)
                else:
                    metrics = MetricUtils.calculate_metrics_onlyLoss(metrics, loss, type=type)
                metrics_time += time.time() - start_time_metrics

                print_loss.append(loss)
                if batch_nr[type] % HP.PRINT_FREQ == 0:
                    time_batch_part = time.time() - start_time_batch_part
                    start_time_batch_part = time.time()
                    ExpUtils.print_and_save(HP, "{} Ep {}, Sp {}, loss {}, t print {}s, t batch {}s".format(type, epoch_nr, batch_nr[type] * HP.BATCH_SIZE, round(np.array(print_loss).mean(), 6), round(time_batch_part, 3), round(time_batch_part / HP.PRINT_FREQ, 3)))
                    print_loss = []

                if HP.USE_VISLOGGER:
                    x_norm = (x - x.min()) / (x.max() - x.min())
                    nvl.show_images(x_norm[0:1, :, :, :].transpose((1, 0, 2, 3)), name="input batch", title="Input batch")  # all channels of one batch
                    probs_shaped = probs[:, :, :, 15:16].transpose((0, 3, 1, 2))  # (bs, 1, x, y)
                    probs_shaped_bin = (probs_shaped > 0.5).astype(np.int16)
                    nvl.show_images(probs_shaped, name="predictions", title="Predictions Probmap")
                    # nvl.show_images(probs_shaped_bin, name="predictions_binary", title="Predictions Binary")

                    # Show GT and Prediction in one image (bundle: CST)
                    # GREEN: GT; RED: prediction (FP); YELLOW: prediction (TP)
                    combined = np.zeros((y.shape[0], 3, y.shape[2], y.shape[3]))
                    combined[:, 0:1, :, :] = probs_shaped_bin   # Red
                    combined[:, 1:2, :, :] = y[:, 15:16, :, :]  # Green
                    nvl.show_images(combined, name="predictions_combined", title="Combined")

                    # NOTE(review): `intermediate` below is never assigned in this function --
                    # only the commented-out 4-tuple self.model.train(x, y) call above returns it.
                    # With USE_VISLOGGER=True this raises NameError; verify before enabling.
                    # Show feature activations
                    contr_1_2 = intermediate[2].data.cpu().numpy()  # (bs, nr_feature_channels=64, x, y)
                    contr_1_2 = contr_1_2[0:1, :, :, :].transpose((1, 0, 2, 3))  # (nr_feature_channels=64, 1, x, y)
                    contr_1_2 = (contr_1_2 - contr_1_2.min()) / (contr_1_2.max() - contr_1_2.min())
                    nvl.show_images(contr_1_2, name="contr_1_2", title="contr_1_2")

                    # Show feature activations
                    contr_3_2 = intermediate[1].data.cpu().numpy()  # (bs, nr_feature_channels=64, x, y)
                    contr_3_2 = contr_3_2[0:1, :, :, :].transpose((1, 0, 2, 3))  # (nr_feature_channels=64, 1, x, y)
                    contr_3_2 = (contr_3_2 - contr_3_2.min()) / (contr_3_2.max() - contr_3_2.min())
                    nvl.show_images(contr_3_2, name="contr_3_2", title="contr_3_2")

                    # Show feature activations
                    deconv_2 = intermediate[0].data.cpu().numpy()  # (bs, nr_feature_channels=64, x, y)
                    deconv_2 = deconv_2[0:1, :, :, :].transpose((1, 0, 2, 3))  # (nr_feature_channels=64, 1, x, y)
                    deconv_2 = (deconv_2 - deconv_2.min()) / (deconv_2.max() - deconv_2.min())
                    nvl.show_images(deconv_2, name="deconv_2", title="deconv_2")

                    nvl.show_value(float(loss), name="loss")
                    nvl.show_value(float(np.mean(f1)), name="f1")

        ###################################
        # Post Training tasks (each epoch)
        ###################################
        # Adapt LR
        # self.model.scheduler.step()
        # self.model.scheduler.step(np.mean(f1))
        # self.model.print_current_lr()

        # Average loss per batch over entire epoch
        metrics = MetricUtils.normalize_last_element(metrics, batch_nr["train"], type="train")
        metrics = MetricUtils.normalize_last_element(metrics, batch_nr["validate"], type="validate")
        metrics = MetricUtils.normalize_last_element(metrics, batch_nr["test"], type="test")

        print(" Epoch {}, Average Epoch loss = {}".format(epoch_nr, metrics["loss_train"][-1]))
        print(" Epoch {}, nr_of_updates {}".format(epoch_nr, nr_of_updates))

        # Save Weights
        start_time_saving = time.time()
        if HP.SAVE_WEIGHTS:
            self.model.save_model(metrics, epoch_nr)
        saving_time += time.time() - start_time_saving

        # Create Plots
        start_time_plotting = time.time()
        pickle.dump(metrics, open(join(HP.EXP_PATH, "metrics.pkl"), "wb"))  # wb -> write (override) and binary (binary only needed on windows, on unix also works without)
        # for loading: pickle.load(open("metrics.pkl", "rb"))
        ExpUtils.create_exp_plot(metrics, HP.EXP_PATH, HP.EXP_NAME)
        ExpUtils.create_exp_plot(metrics, HP.EXP_PATH, HP.EXP_NAME, without_first_epochs=True)
        plotting_time += time.time() - start_time_plotting

        epoch_time = time.time() - start_time
        epoch_times.append(epoch_time)

        ExpUtils.print_and_save(HP, " Epoch {}, time total {}s".format(epoch_nr, epoch_time))
        ExpUtils.print_and_save(HP, " Epoch {}, time UNet: {}s".format(epoch_nr, network_time))
        ExpUtils.print_and_save(HP, " Epoch {}, time metrics: {}s".format(epoch_nr, metrics_time))
        ExpUtils.print_and_save(HP, " Epoch {}, time saving files: {}s".format(epoch_nr, saving_time))
        ExpUtils.print_and_save(HP, str(datetime.datetime.now()))

        # Adding next Epoch
        if epoch_nr < HP.NUM_EPOCHS - 1:
            metrics = MetricUtils.add_empty_element(metrics)

    ####################################
    # After all epochs
    ###################################
    with open(join(HP.EXP_PATH, "Hyperparameters.txt"), "a") as f:  # a for append
        f.write("\n\n")
        f.write("Average Epoch time: {}s".format(sum(epoch_times) / float(len(epoch_times))))

    return metrics
def print_current_lr():
    # Log the current learning rate of every optimizer parameter group.
    # NOTE(review): relies on `optimizer` and `self` from an enclosing scope --
    # they are not defined at module level here; verify where this is bound.
    for param_group in optimizer.param_groups:
        ExpUtils.print_and_save(self.HP, "current learning rate: {}".format(param_group['lr']))
def calc_peak_length_dice_pytorch(HP, y_pred, y_true, max_angle_error=[0.9], max_length_error=0.1):
    '''
    Per-bundle f1 score for peak predictions, counting a voxel as correct only
    if both its peak direction and its peak length are close enough to ground truth.

    :param y_pred: predicted peaks, (bs, nr_classes*3, x, y)
    :param y_true: ground-truth peaks, same layout
    :param max_angle_error: 0.7 -> angle error of 45 deg or less; 0.9 -> 23 deg or less.
                            Only the first list element is used here.
    :param max_length_error: allowed relative length deviation
    :return: dict bundle_name -> f1 score
    '''
    import torch
    from tractseg.libs.PytorchEinsum import einsum
    from tractseg.libs.PytorchUtils import PytorchUtils

    # Move the peak channels to the last dim: (bs, x, y, nr_classes*3)
    y_true = y_true.permute(0, 2, 3, 1)
    y_pred = y_pred.permute(0, 2, 3, 1)

    def cos_angle(a, b):
        '''
        Absolute cosine of the angle between two arrays of vectors along the
        last dimension: 1 -> 0 deg, 0.9 -> 23 deg, 0.7 -> 45 deg, 0 -> 90 deg.
        Returns one dimension less than the input.
        '''
        dots = einsum('abcd,abcd->abc', a, b)
        norms = torch.norm(a, 2, -1) * torch.norm(b, 2, -1) + 1e-7
        return torch.abs(dots / norms)

    # Single threshold
    score_per_bundle = {}
    for idx, bundle in enumerate(ExpUtils.get_bundle_names(HP.CLASSES)[1:]):
        # if bundle == "CST_right":
        pred_vecs = y_pred[:, :, :, (idx * 3):(idx * 3) + 3].contiguous()
        true_vecs = y_true[:, :, :, (idx * 3):(idx * 3) + 3].contiguous()  # [x, y, z, 3]

        angles = cos_angle(pred_vecs, true_vecs)

        pred_len = torch.norm(pred_vecs, 2, -1)
        true_len = torch.norm(true_vecs, 2, -1)
        lengths_ok = (torch.abs(pred_len - true_len) < (max_length_error * true_len)).view(-1)

        gt_binary = (true_vecs.sum(dim=-1) > 0).view(-1)  # [bs*x*y]
        angles_ok = (angles > max_angle_error[0]).view(-1)

        # Correct only where both the direction and the length are within tolerance.
        combined = lengths_ok * angles_ok
        score_per_bundle[bundle] = PytorchUtils.f1_score_binary(gt_binary, combined)

    return score_per_bundle
def run_tractseg(data, output_type="tract_segmentation", input_type="peaks",
                 single_orientation=False, verbose=False, dropout_sampling=False, threshold=0.5,
                 bundle_specific_threshold=False, get_probs=False):
    '''
    Run TractSeg

    :param data: input peaks (4D numpy array with shape [x,y,z,9])
    :param output_type: "tract_segmentation" | "endings_segmentation" | "TOM" | "dm_regression"
    :param input_type: "peaks"
    :param single_orientation: predict only one slice orientation instead of fusing three
                               (lower RAM requirements, mainly for testing)
    :param verbose: show debugging infos
    :param dropout_sampling: create uncertainty map by monte carlo dropout (https://arxiv.org/abs/1506.02142)
    :param threshold: Threshold for converting probability map to binary map
    :param bundle_specific_threshold: Threshold is lower for some bundles which need more sensitivity (CA, CST, FX)
    :param get_probs: Output raw probability map instead of binary map
    :return: 4D numpy array with the output of tractseg
        for tract_segmentation:     [x,y,z,nr_of_bundles]
        for endings_segmentation:   [x,y,z,2*nr_of_bundles]
        for TOM:                    [x,y,z,3*nr_of_bundles]
    '''
    start_time = time.time()

    # Resolve and instantiate the pretrained-model hyperparameter class for this
    # input/output combination (imported dynamically by config name).
    config = get_config_name(input_type, output_type)
    HP = getattr(importlib.import_module("tractseg.config.PretrainedModels." + config), "HP")()
    HP.VERBOSE = verbose
    HP.TRAIN = False
    HP.TEST = False
    HP.SEGMENT = False
    HP.GET_PROBS = get_probs
    HP.LOAD_WEIGHTS = True
    HP.DROPOUT_SAMPLING = dropout_sampling
    HP.THRESHOLD = threshold

    # Bundle-specific thresholding needs raw probabilities; binarization happens later.
    if bundle_specific_threshold:
        HP.GET_PROBS = True

    # Pick the pretrained weights file matching input type + experiment type.
    if input_type == "peaks":
        if HP.EXPERIMENT_TYPE == "tract_segmentation" and HP.DROPOUT_SAMPLING:
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_tract_segmentation_dropout_v1.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "TractSeg_12g90g270g_125mm_DAugAll_Dropout", "best_weights_ep114.npz")
        elif HP.EXPERIMENT_TYPE == "tract_segmentation":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_tract_segmentation_v1.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "TractSeg_T1_12g90g270g_125mm_DAugAll", "best_weights_ep126.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "TractSeg72_888", "best_weights_ep247.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "TractSeg72_888_SchizoFineT_lr001", "best_weights_ep186.npz")
        elif HP.EXPERIMENT_TYPE == "endings_segmentation":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_endings_segmentation_v2.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "EndingsSeg_12g90g270g_125mm_DAugAll", "best_weights_ep16.npz")
        elif HP.EXPERIMENT_TYPE == "peak_regression":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_peak_regression_v1.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "x_Pretrained_TractSeg_Models/Peaks20_12g90g270g_125mm_DAugSimp_constW5", "best_weights_ep441.npz")  #more oversegmentation with DAug
        elif HP.EXPERIMENT_TYPE == "dm_regression":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_dm_regression_v1.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "DmReg_12g90g270g_125mm_DAugAll_Ubuntu", "best_weights_ep80.npz")
    elif input_type == "T1":
        if HP.EXPERIMENT_TYPE == "tract_segmentation":
            # HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_tract_segmentation_v1.npz")
            HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "TractSeg_T1_125mm_DAugAll", "best_weights_ep142.npz")
        elif HP.EXPERIMENT_TYPE == "endings_segmentation":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_endings_segmentation_v1.npz")
        elif HP.EXPERIMENT_TYPE == "peak_regression":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_peak_regression_v1.npz")
    print("Loading weights from: {}".format(HP.WEIGHTS_PATH))

    # Peak regression predicts a 3D vector per bundle -> 3x the class count.
    if HP.EXPERIMENT_TYPE == "peak_regression":
        HP.NR_OF_CLASSES = 3*len(ExpUtils.get_bundle_names(HP.CLASSES)[1:])
    else:
        HP.NR_OF_CLASSES = len(ExpUtils.get_bundle_names(HP.CLASSES)[1:])

    if HP.VERBOSE:
        print("Hyperparameters:")
        ExpUtils.print_HPs(HP)

    # Fetches the weights file if not already present locally.
    Utils.download_pretrained_weights(experiment_type=HP.EXPERIMENT_TYPE, dropout_sampling=HP.DROPOUT_SAMPLING)

    data = np.nan_to_num(data)
    # brain_mask = ImgUtils.simple_brain_mask(data)
    # if HP.VERBOSE:
    #     nib.save(nib.Nifti1Image(brain_mask, np.eye(4)), "otsu_brain_mask_DEBUG.nii.gz")

    # T1 input is a single-channel volume -> add an explicit channel dimension.
    if input_type == "T1":
        data = np.reshape(data, (data.shape[0], data.shape[1], data.shape[2], 1))
    # Crop to brain bbox and scale to the network's square input size; keep the
    # transformations so the output can be mapped back to the original space below.
    data, seg_None, bbox, original_shape = DatasetUtils.crop_to_nonzero(data)
    data, transformation = DatasetUtils.pad_and_scale_img_to_square_img(data, target_size=HP.INPUT_DIM[0])

    model = BaseModel(HP)

    if HP.EXPERIMENT_TYPE == "tract_segmentation" or HP.EXPERIMENT_TYPE == "endings_segmentation" or HP.EXPERIMENT_TYPE == "dm_regression":
        if single_orientation:     # mainly needed for testing because of less RAM requirements
            dataManagerSingle = DataManagerSingleSubjectByFile(HP, data=data)
            trainerSingle = Trainer(model, dataManagerSingle)
            # probs=True when raw probabilities are needed (uncertainty, dm regression, get_probs)
            if HP.DROPOUT_SAMPLING or HP.EXPERIMENT_TYPE == "dm_regression" or HP.GET_PROBS:
                seg, img_y = trainerSingle.get_seg_single_img(HP, probs=True, scale_to_world_shape=False, only_prediction=True)
            else:
                seg, img_y = trainerSingle.get_seg_single_img(HP, probs=False, scale_to_world_shape=False, only_prediction=True)
        else:
            # Predict along all three slice orientations and fuse by mean.
            seg_xyz, gt = DirectionMerger.get_seg_single_img_3_directions(HP, model, data=data, scale_to_world_shape=False, only_prediction=True)
            if HP.DROPOUT_SAMPLING or HP.EXPERIMENT_TYPE == "dm_regression" or HP.GET_PROBS:
                seg = DirectionMerger.mean_fusion(HP.THRESHOLD, seg_xyz, probs=True)
            else:
                seg = DirectionMerger.mean_fusion(HP.THRESHOLD, seg_xyz, probs=False)

    elif HP.EXPERIMENT_TYPE == "peak_regression":
        dataManagerSingle = DataManagerSingleSubjectByFile(HP, data=data)
        trainerSingle = Trainer(model, dataManagerSingle)
        seg, img_y = trainerSingle.get_seg_single_img(HP, probs=True, scale_to_world_shape=False, only_prediction=True)
        # Suppress spurious tiny peaks in the regression output.
        if bundle_specific_threshold:
            seg = ImgUtils.remove_small_peaks_bundle_specific(seg, ExpUtils.get_bundle_names(HP.CLASSES)[1:], len_thr=0.3)
        else:
            seg = ImgUtils.remove_small_peaks(seg, len_thr=0.3)  # set lower for more sensitivity

        #3 dir for Peaks -> not working (?)
        # seg_xyz, gt = DirectionMerger.get_seg_single_img_3_directions(HP, model, data=data, scale_to_world_shape=False, only_prediction=True)
        # seg = DirectionMerger.mean_fusion(HP.THRESHOLD, seg_xyz, probs=True)

    # Binarize probability maps with per-bundle thresholds (CA, CST, FX get lower ones).
    if bundle_specific_threshold and HP.EXPERIMENT_TYPE == "tract_segmentation":
        seg = ImgUtils.probs_to_binary_bundle_specific(seg, ExpUtils.get_bundle_names(HP.CLASSES)[1:])

    #remove following two lines to keep super resolution
    seg = DatasetUtils.cut_and_scale_img_back_to_original_img(seg, transformation)
    seg = DatasetUtils.add_original_zero_padding_again(seg, bbox, original_shape, HP.NR_OF_CLASSES)

    ExpUtils.print_verbose(HP, "Took {}s".format(round(time.time() - start_time, 2)))
    return seg
def train(self, HP):
    '''
    Run the full training loop over HP.NUM_EPOCHS epochs.

    For every epoch, iterates the "train", "test" and "validate" splits, accumulates
    per-phase timings, collects loss/f1 metrics, saves weights and plots, and finally
    writes the average epoch time to the Hyperparameters.txt log.

    :param HP: hyperparameter object driving the whole run (epochs, batch size,
               loss weighting schedule, paths, flags)
    :return: metrics dict with per-epoch lists keyed e.g. "loss_train", "f1_macro_validate"
    '''
    if HP.USE_VISLOGGER:
        try:
            from trixi.logger.visdom import PytorchVisdomLogger
        except ImportError:
            # NOTE(review): if this import fails, the next line raises NameError anyway
            # because PytorchVisdomLogger is undefined — the except probably should
            # re-raise or disable USE_VISLOGGER; confirm intended behavior.
            pass
        trixi = PytorchVisdomLogger(port=8080, auto_start=True)

    ExpUtils.print_and_save(HP, socket.gethostname())

    epoch_times = []
    nr_of_updates = 0

    # Initialize one loss and one macro-f1 series per split (note: "type" shadows
    # the builtin here and throughout this method).
    metrics = {}
    for type in ["train", "test", "validate"]:
        metrics_new = {
            "loss_" + type: [0],
            "f1_macro_" + type: [0],
        }
        metrics = dict(list(metrics.items()) + list(metrics_new.items()))

    for epoch_nr in range(HP.NUM_EPOCHS):
        start_time = time.time()
        # current_lr = HP.LEARNING_RATE * (HP.LR_DECAY ** epoch_nr)
        # current_lr = HP.LEARNING_RATE

        # Per-epoch timing accumulators for profiling the different phases.
        batch_gen_time = 0
        data_preparation_time = 0
        network_time = 0
        metrics_time = 0
        saving_time = 0
        plotting_time = 0

        batch_nr = {
            "train": 0,
            "test": 0,
            "validate": 0
        }

        # Loss weight factor: either constant, or linearly decayed from HP.LOSS_WEIGHT
        # down to 1 over the first HP.LOSS_WEIGHT_LEN epochs.
        if HP.LOSS_WEIGHT_LEN == -1:
            weight_factor = float(HP.LOSS_WEIGHT)
        else:
            if epoch_nr < HP.LOSS_WEIGHT_LEN:
                # weight_factor = -(9./100.) * epoch_nr + 10.   #ep0: 10 -> linear decrease -> ep100: 1
                weight_factor = -((HP.LOSS_WEIGHT-1)/float(HP.LOSS_WEIGHT_LEN)) * epoch_nr + float(HP.LOSS_WEIGHT)
                # weight_factor = -((HP.LOSS_WEIGHT-5)/float(HP.LOSS_WEIGHT_LEN)) * epoch_nr + float(HP.LOSS_WEIGHT)
            else:
                weight_factor = 1.
                # weight_factor = 5.

        for type in ["train", "test", "validate"]:
            print_loss = []
            start_time_batch_gen = time.time()
            batch_generator = self.dataManager.get_batches(batch_size=HP.BATCH_SIZE, type=type, subjects=getattr(HP, type.upper() + "_SUBJECTS"))
            batch_gen_time = time.time() - start_time_batch_gen
            # print("batch_gen_time: {}s".format(batch_gen_time))

            print("Start looping batches...")
            start_time_batch_part = time.time()
            for batch in batch_generator:   #getting next batch takes around 0.14s -> second largest Time part after mode!
                start_time_data_preparation = time.time()
                batch_nr[type] += 1

                x = batch["data"]  # (bs, nr_of_channels, x, y)
                y = batch["seg"]  # (bs, nr_of_classes, x, y)
                # since using new BatchGenerator y is not int anymore but float -> would be good for Pytorch but not Lasagne
                # y = y.astype(HP.LABELS_TYPE)  #for bundle_peaks regression: is already float -> saves 0.2s/batch if left out

                data_preparation_time += time.time() - start_time_data_preparation
                # self.model.learning_rate.set_value(np.float32(current_lr))
                start_time_network = time.time()
                if type == "train":
                    nr_of_updates += 1
                    loss, probs, f1 = self.model.train(x, y, weight_factor=weight_factor)    # probs: # (bs, x, y, nrClasses)
                    # loss, probs, f1, intermediate = self.model.train(x, y)
                elif type == "validate":
                    loss, probs, f1 = self.model.predict(x, y, weight_factor=weight_factor)
                elif type == "test":
                    loss, probs, f1 = self.model.predict(x, y, weight_factor=weight_factor)
                network_time += time.time() - start_time_network

                start_time_metrics = time.time()

                if HP.CALC_F1:
                    if HP.EXPERIMENT_TYPE == "peak_regression":
                        #Following two lines increase metrics_time by 30s (without < 1s); time per batch increases by 1.5s by these lines
                        # y_flat = y.transpose(0, 2, 3, 1)  # (bs, x, y, nr_of_classes)
                        # y_flat = np.reshape(y_flat, (-1, y_flat.shape[-1]))  # (bs*x*y, nr_of_classes)
                        # metrics = MetricUtils.calculate_metrics(metrics, y_flat, probs, loss, f1=np.mean(f1), type=type, threshold=HP.THRESHOLD,
                        #                                         f1_per_bundle={"CA": f1[5], "FX_left": f1[23], "FX_right": f1[24]})

                        #Numpy
                        # y_right_order = y.transpose(0, 2, 3, 1)  # (bs, x, y, nr_of_classes)
                        # peak_f1 = MetricUtils.calc_peak_dice(HP, probs, y_right_order)
                        # peak_f1_mean = np.array([s for s in peak_f1.values()]).mean()

                        #Pytorch
                        # f1 is a per-bundle dict here -> reduce to the mean over bundles.
                        peak_f1_mean = np.array([s for s in list(f1.values())]).mean()  #if f1 for multiple bundles
                        metrics = MetricUtils.calculate_metrics(metrics, None, None, loss, f1=peak_f1_mean, type=type, threshold=HP.THRESHOLD)

                        #Pytorch 2 F1
                        # peak_f1_mean_a = np.array([s for s in f1[0].values()]).mean()
                        # peak_f1_mean_b = np.array([s for s in f1[1].values()]).mean()
                        # metrics = MetricUtils.calculate_metrics(metrics, None, None, loss, f1=peak_f1_mean_a, type=type, threshold=HP.THRESHOLD,
                        #                                         f1_per_bundle={"LenF1": peak_f1_mean_b})

                        #Single Bundle
                        # metrics = MetricUtils.calculate_metrics(metrics, None, None, loss, f1=f1["CST_right"][0], type=type, threshold=HP.THRESHOLD,
                        #                                         f1_per_bundle={"Thr1": f1["CST_right"][1], "Thr2": f1["CST_right"][2]})
                        # metrics = MetricUtils.calculate_metrics(metrics, None, None, loss, f1=f1["CST_right"], type=type, threshold=HP.THRESHOLD)
                    else:
                        metrics = MetricUtils.calculate_metrics(metrics, None, None, loss, f1=np.mean(f1), type=type, threshold=HP.THRESHOLD)
                else:
                    metrics = MetricUtils.calculate_metrics_onlyLoss(metrics, loss, type=type)

                metrics_time += time.time() - start_time_metrics

                print_loss.append(loss)
                if batch_nr[type] % HP.PRINT_FREQ == 0:
                    time_batch_part = time.time() - start_time_batch_part
                    start_time_batch_part = time.time()
                    ExpUtils.print_and_save(HP, "{} Ep {}, Sp {}, loss {}, t print {}s, t batch {}s".format(type, epoch_nr, batch_nr[type] * HP.BATCH_SIZE,
                                                                                                            round(np.array(print_loss).mean(), 6), round(time_batch_part, 3),
                                                                                                            round(time_batch_part / HP.PRINT_FREQ, 3)))
                    print_loss = []

                if HP.USE_VISLOGGER:
                    ExpUtils.plot_result_trixi(trixi, x, y, probs, loss, f1, epoch_nr)

        ###################################
        # Post Training tasks (each epoch)
        ###################################

        #Adapt LR
        if HP.LR_SCHEDULE:
            self.model.scheduler.step()
            # self.model.scheduler.step(np.mean(f1))
            self.model.print_current_lr()

        # Average loss per batch over entire epoch
        metrics = MetricUtils.normalize_last_element(metrics, batch_nr["train"], type="train")
        metrics = MetricUtils.normalize_last_element(metrics, batch_nr["validate"], type="validate")
        metrics = MetricUtils.normalize_last_element(metrics, batch_nr["test"], type="test")

        print("  Epoch {}, Average Epoch loss = {}".format(epoch_nr, metrics["loss_train"][-1]))
        print("  Epoch {}, nr_of_updates {}".format(epoch_nr, nr_of_updates))

        # Save Weights
        start_time_saving = time.time()
        if HP.SAVE_WEIGHTS:
            self.model.save_model(metrics, epoch_nr)
        saving_time += time.time() - start_time_saving

        # Create Plots
        start_time_plotting = time.time()
        pickle.dump(metrics, open(join(HP.EXP_PATH, "metrics.pkl"), "wb"))  # wb -> write (override) and binary (binary only needed on windows, on unix also works without)
        # for loading: pickle.load(open("metrics.pkl", "rb"))
        ExpUtils.create_exp_plot(metrics, HP.EXP_PATH, HP.EXP_NAME)
        ExpUtils.create_exp_plot(metrics, HP.EXP_PATH, HP.EXP_NAME, without_first_epochs=True)
        plotting_time += time.time() - start_time_plotting

        epoch_time = time.time() - start_time
        epoch_times.append(epoch_time)

        ExpUtils.print_and_save(HP, "  Epoch {}, time total {}s".format(epoch_nr, epoch_time))
        ExpUtils.print_and_save(HP, "  Epoch {}, time UNet: {}s".format(epoch_nr, network_time))
        ExpUtils.print_and_save(HP, "  Epoch {}, time metrics: {}s".format(epoch_nr, metrics_time))
        ExpUtils.print_and_save(HP, "  Epoch {}, time saving files: {}s".format(epoch_nr, saving_time))
        ExpUtils.print_and_save(HP, str(datetime.datetime.now()))

        # Adding next Epoch
        if epoch_nr < HP.NUM_EPOCHS-1:
            metrics = MetricUtils.add_empty_element(metrics)

    ####################################
    # After all epochs
    ###################################
    with open(join(HP.EXP_PATH, "Hyperparameters.txt"), "a") as f:  # a for append
        f.write("\n\n")
        f.write("Average Epoch time: {}s".format(sum(epoch_times) / float(len(epoch_times))))

    return metrics
def calc_peak_dice_pytorch(HP, y_pred, y_true, max_angle_error=[0.9]):
    '''
    Calculate angle between groundtruth and prediction and keep the voxels where
    angle is smaller than MAX_ANGLE_ERROR.

    From groundtruth generate a binary mask by selecting all voxels with len > 0.

    Calculate Dice from these 2 masks.

    -> Penalty on peaks outside of tract or if predicted peak=0
    -> no penalty on very very small with right direction -> bad
    => Peak_dice can be high even if peaks inside of tract almost missing (almost 0)

    :param HP: hyperparameter object (HP.CLASSES is used to resolve bundle names)
    :param y_pred: predicted peaks, [bs, 3*nr_of_bundles, x, y]
    :param y_true: ground truth peaks, [bs, 3*nr_of_bundles, x, y]
    :param max_angle_error: 0.7 -> angle error of 45° or less; 0.9 -> angle error of 23° or less
                            Can be list with several values -> calculate for several thresholds
                            (NOTE(review): mutable default argument — harmless here because it
                            is never mutated, but a tuple would be safer.)
    :return: dict {bundle_name: f1} for a single threshold, or
             dict {bundle_name: [f1 per threshold]} for multiple thresholds
    '''
    import torch
    from tractseg.libs.PytorchEinsum import einsum
    from tractseg.libs.PytorchUtils import PytorchUtils

    # move the peak/channel dim last: (bs, x, y, 3*nr_of_bundles)
    y_true = y_true.permute(0, 2, 3, 1)
    y_pred = y_pred.permute(0, 2, 3, 1)

    def angle_last_dim(a, b):
        '''
        Calculate the angle between two nd-arrays (array of vectors) along the last dimension
        without anything further: 1->0°, 0.9->23°, 0.7->45°, 0->90°
        np.arccos -> returns degree in pi (90°: 0.5*pi)

        return: one dimension less then input
        '''
        # abs(cos similarity); epsilon avoids division by zero for zero peaks
        return torch.abs(einsum('abcd,abcd->abc', a, b) / (torch.norm(a, 2., -1) * torch.norm(b, 2, -1) + 1e-7))

    #Single threshold
    if len(max_angle_error) == 1:
        score_per_bundle = {}
        bundles = ExpUtils.get_bundle_names(HP.CLASSES)[1:]
        for idx, bundle in enumerate(bundles):
            # if bundle == "CST_right":

            # each bundle occupies 3 consecutive channels (one 3D peak vector)
            y_pred_bund = y_pred[:, :, :, (idx * 3):(idx * 3) + 3].contiguous()
            y_true_bund = y_true[:, :, :, (idx * 3):(idx * 3) + 3].contiguous()  # [x,y,z,3]

            angles = angle_last_dim(y_pred_bund, y_true_bund)
            gt_binary = y_true_bund.sum(dim=-1) > 0
            gt_binary = gt_binary.view(-1)  # [bs*x*y]

            angles_binary = angles > max_angle_error[0]
            angles_binary = angles_binary.view(-1)

            f1 = PytorchUtils.f1_score_binary(gt_binary, angles_binary)
            score_per_bundle[bundle] = f1

        return score_per_bundle

    #multiple thresholds
    else:
        score_per_bundle = {}
        bundles = ExpUtils.get_bundle_names(HP.CLASSES)[1:]
        for idx, bundle in enumerate(bundles):
            # if bundle == "CST_right":

            y_pred_bund = y_pred[:, :, :, (idx * 3):(idx * 3) + 3].contiguous()
            y_true_bund = y_true[:, :, :, (idx * 3):(idx * 3) + 3].contiguous()  # [x,y,z,3]

            angles = angle_last_dim(y_pred_bund, y_true_bund)
            gt_binary = y_true_bund.sum(dim=-1) > 0
            gt_binary = gt_binary.view(-1)  # [bs*x*y]

            # one f1 per requested angle threshold, in input order
            score_per_bundle[bundle] = []
            for threshold in max_angle_error:
                angles_binary = angles > threshold
                angles_binary = angles_binary.view(-1)

                f1 = PytorchUtils.f1_score_binary(gt_binary, angles_binary)
                score_per_bundle[bundle].append(f1)

        return score_per_bundle
def generate_train_batch(self):
    '''
    Build one training batch from a randomly chosen subject.

    Randomly picks a subject, a peaks file (270g or 90g, 50/50), a slice orientation
    (x/y/z) and BATCH_SIZE slice indices; each sample stacks a 5-slice window
    (sw=5) along the channel dimension.

    :return: dict {"data": (batch_size, 5*channels, x, y), "seg": (batch_size, nr_of_classes, x, y)}
    '''
    subjects = self._data[0]
    subject_idx = int(random.uniform(0, len(subjects)))  # len(subjects)-1 not needed because int always rounds to floor

    # Retry loading up to 20 times — files live on a (possibly flaky) network share.
    for i in range(20):
        try:
            # 50/50 between the two gradient schemes as a form of data augmentation
            if np.random.random() < 0.5:
                data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "270g_125mm_peaks.nii.gz")).get_data()
            else:
                data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "90g_125mm_peaks.nii.gz")).get_data()

            # rnd_choice = np.random.random()
            # if rnd_choice < 0.33:
            #     data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "270g_125mm_peaks.nii.gz")).get_data()
            # elif rnd_choice < 0.66:
            #     data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "90g_125mm_peaks.nii.gz")).get_data()
            # else:
            #     data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "12g_125mm_peaks.nii.gz")).get_data()

            seg = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], self.HP.LABELS_FILENAME + ".nii.gz")).get_data()
            break
        except IOError:
            ExpUtils.print_and_save(self.HP, "\n\nWARNING: Could not load file. Trying again in 20s (Try number: " + str(i) + ").\n\n")
            ExpUtils.print_and_save(self.HP, "Sleeping 20s")
            sleep(20)
    # ExpUtils.print_and_save(self.HP, "Successfully loaded input.")

    data = np.nan_to_num(data)    # Needed otherwise not working
    seg = np.nan_to_num(seg)

    data = DatasetUtils.scale_input_to_unet_shape(data, self.HP.DATASET, self.HP.RESOLUTION)  # (x, y, z, channels)

    if self.HP.DATASET in ["HCP_2mm", "HCP_2.5mm", "HCP_32g"]:
        # By using "HCP" but lower resolution scale_input_to_unet_shape will automatically downsample the HCP sized seg_mask to the lower resolution
        seg = DatasetUtils.scale_input_to_unet_shape(seg, "HCP", self.HP.RESOLUTION)
    else:
        seg = DatasetUtils.scale_input_to_unet_shape(seg, self.HP.DATASET, self.HP.RESOLUTION)  # (x, y, z, classes)

    # BATCH_SIZE distinct slice indices (replace=False)
    slice_idxs = np.random.choice(data.shape[0], self.BATCH_SIZE, False, None)

    # Randomly sample slice orientation
    slice_direction = int(round(random.uniform(0, 2)))

    if slice_direction == 0:
        y = seg[slice_idxs, :, :].astype(self.HP.LABELS_TYPE)
        y = np.array(y).transpose(0, 3, 1, 2)  # nr_classes channel has to be before with and height for DataAugmentation (bs, nr_of_classes, x, y)
    elif slice_direction == 1:
        y = seg[:, slice_idxs, :].astype(self.HP.LABELS_TYPE)
        y = np.array(y).transpose(1, 3, 0, 2)
    elif slice_direction == 2:
        y = seg[:, :, slice_idxs].astype(self.HP.LABELS_TYPE)
        y = np.array(y).transpose(2, 3, 0, 1)

    sw = 5  #slice_window (only odd numbers allowed)
    pad = int((sw - 1) / 2)

    # Zero-pad all three spatial dims so that the sw-slice window around any
    # sampled slice index always exists.
    data_pad = np.zeros((data.shape[0] + sw - 1, data.shape[1] + sw - 1, data.shape[2] + sw - 1, data.shape[3])).astype(data.dtype)
    data_pad[pad:-pad, pad:-pad, pad:-pad, :] = data   #padded with two slices of zeros on all sides
    batch = []
    for s_idx in slice_idxs:
        if slice_direction == 0:
            #(s_idx+2)-2:(s_idx+2)+3 = s_idx:s_idx+5
            # (the trailing ':' after sw is a harmless empty slice step)
            x = data_pad[s_idx:s_idx + sw:, pad:-pad, pad:-pad, :].astype(np.float32)      # (5, y, z, channels)
            x = np.array(x).transpose(0, 3, 1, 2)  # channels dim has to be before width and height for Unet (but after batches)
            x = np.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3]))  # (5*channels, y, z)
            batch.append(x)
        elif slice_direction == 1:
            x = data_pad[pad:-pad, s_idx:s_idx + sw, pad:-pad, :].astype(np.float32)  # (5, y, z, channels)
            x = np.array(x).transpose(1, 3, 0, 2)  # channels dim has to be before width and height for Unet (but after batches)
            x = np.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3]))  # (5*channels, y, z)
            batch.append(x)
        elif slice_direction == 2:
            x = data_pad[pad:-pad, pad:-pad, s_idx:s_idx + sw, :].astype(np.float32)  # (5, y, z, channels)
            x = np.array(x).transpose(2, 3, 0, 1)  # channels dim has to be before width and height for Unet (but after batches)
            x = np.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3]))  # (5*channels, y, z)
            batch.append(x)
    data_dict = {"data": np.array(batch),  # (batch_size, channels, x, y, [z])
                 "seg": y}  # (batch_size, channels, x, y, [z])
    return data_dict
def create_network(self):
    '''
    Build the UNet, its loss and optimizer, and bind train/predict/save/load
    closures onto self (self.train, self.predict, self.get_probs, self.save_model,
    self.load_model, self.print_current_lr).

    The inner functions close over `net`, `criterion` and `optimizer` defined below
    (they are resolved at call time, so the definition order is safe).
    '''
    # torch.backends.cudnn.benchmark = True     #not faster

    def train(X, y):
        # One optimization step on a single batch; returns (loss, probs, f1).
        X = torch.from_numpy(X.astype(np.float32))
        y = torch.from_numpy(y.astype(np.float32))
        if torch.cuda.is_available():
            X, y = Variable(X.cuda()), Variable(y.cuda())  # X: (bs, features, x, y)   y: (bs, classes, x, y)
        else:
            X, y = Variable(X), Variable(y)
        optimizer.zero_grad()
        net.train()
        outputs = net(X)  # forward     # outputs: (bs, classes, x, y)
        loss = criterion(outputs, y)
        # loss = PytorchUtils.soft_dice(outputs, y)
        loss.backward()  # backward
        optimizer.step()  # optimise
        f1 = PytorchUtils.f1_score_macro(y.data, outputs.data, per_class=True)
        if self.HP.USE_VISLOGGER:
            probs = outputs.data.cpu().numpy().transpose(0, 2, 3, 1)  # (bs, x, y, classes)
        else:
            probs = None    #faster
        return loss.data[0], probs, f1

    def test(X, y):
        # Forward pass with loss/f1 but no gradient update (volatile Variables).
        X = torch.from_numpy(X.astype(np.float32))
        y = torch.from_numpy(y.astype(np.float32))
        if torch.cuda.is_available():
            X, y = Variable(X.cuda(), volatile=True), Variable(y.cuda(), volatile=True)
        else:
            X, y = Variable(X, volatile=True), Variable(y, volatile=True)
        net.train(False)
        outputs = net(X)  # forward
        loss = criterion(outputs, y)
        # loss = PytorchUtils.soft_dice(outputs, y)
        f1 = PytorchUtils.f1_score_macro(y.data, outputs.data, per_class=True)
        # probs = outputs.data.cpu().numpy().transpose(0,2,3,1)  # (bs, x, y, classes)
        probs = None  # faster
        return loss.data[0], probs, f1

    def predict(X):
        # Inference only: returns the raw network output as numpy (bs, x, y, classes).
        X = torch.from_numpy(X.astype(np.float32))
        if torch.cuda.is_available():
            X = Variable(X.cuda(), volatile=True)
        else:
            X = Variable(X, volatile=True)
        net.train(False)
        outputs = net(X)  # forward
        probs = outputs.data.cpu().numpy().transpose(0, 2, 3, 1)  # (bs, x, y, classes)
        return probs

    def save_model(metrics, epoch_nr):
        # Only save when this epoch has the best validation f1 so far.
        max_f1_idx = np.argmax(metrics["f1_macro_validate"])
        max_f1 = np.max(metrics["f1_macro_validate"])
        if epoch_nr == max_f1_idx and max_f1 > 0.01:  # saving to network drives takes 5s (to local only 0.5s) -> do not save so often
            print("  Saving weights...")
            for fl in glob.glob(join(self.HP.EXP_PATH, "best_weights_ep*")):  # remove weights from previous epochs
                os.remove(fl)
            try:
                #Actually is a pkl not a npz
                PytorchUtils.save_checkpoint(join(self.HP.EXP_PATH, "best_weights_ep" + str(epoch_nr) + ".npz"), unet=net)
            except IOError:
                print("\nERROR: Could not save weights because of IO Error\n")
            self.HP.BEST_EPOCH = epoch_nr

    def load_model(path):
        PytorchUtils.load_checkpoint(path, unet=net)

    def print_current_lr():
        # Log the learning rate of every optimizer parameter group.
        for param_group in optimizer.param_groups:
            ExpUtils.print_and_save(self.HP, "current learning rate: {}".format(param_group['lr']))

    # Number of input channels depends on input representation.
    if self.HP.SEG_INPUT == "Peaks" and self.HP.TYPE == "single_direction":
        NR_OF_GRADIENTS = 9
        # NR_OF_GRADIENTS = 9 * 5
        # NR_OF_GRADIENTS = 9 * 9
        # NR_OF_GRADIENTS = 33
    elif self.HP.SEG_INPUT == "Peaks" and self.HP.TYPE == "combined":
        NR_OF_GRADIENTS = 3 * self.HP.NR_OF_CLASSES
    else:
        NR_OF_GRADIENTS = 33

    if torch.cuda.is_available():
        net = UNet(n_input_channels=NR_OF_GRADIENTS, n_classes=self.HP.NR_OF_CLASSES, n_filt=self.HP.UNET_NR_FILT).cuda()
    else:
        net = UNet(n_input_channels=NR_OF_GRADIENTS, n_classes=self.HP.NR_OF_CLASSES, n_filt=self.HP.UNET_NR_FILT)

    #Initialisation from U-Net Paper
    def weights_init(m):
        classname = m.__class__.__name__
        # Do not use with batchnorm -> has to be adapted for batchnorm
        if classname.find('Conv') != -1:
            # std = sqrt(2/N) with N = fan-in of the conv kernel
            N = m.in_channels * m.kernel_size[0] * m.kernel_size[0]
            std = math.sqrt(2. / N)
            m.weight.data.normal_(0.0, std)

    net.apply(weights_init)

    # net = nn.DataParallel(net, device_ids=[0,1])

    if self.HP.TRAIN:
        ExpUtils.print_and_save(self.HP, str(net), only_log=True)
    criterion = nn.BCEWithLogitsLoss()
    optimizer = Adamax(net.parameters(), lr=self.HP.LEARNING_RATE)
    # optimizer = Adam(net.parameters(), lr=self.HP.LEARNING_RATE)  #very slow (half speed of Adamax) -> strange
    # scheduler = lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.1)
    # scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode="max")

    if self.HP.LOAD_WEIGHTS:
        ExpUtils.print_verbose(self.HP, "Loading weights ... ({})".format(join(self.HP.EXP_PATH, self.HP.WEIGHTS_PATH)))
        load_model(join(self.HP.EXP_PATH, self.HP.WEIGHTS_PATH))

    # Expose the closures as the model's public API.
    self.train = train
    self.predict = test
    self.get_probs = predict
    self.save_model = save_model
    self.load_model = load_model
    self.print_current_lr = print_current_lr