def train():
    """
    Function to train a network with the given input parameters
    :return:
    """
    args = setup_options()

    # logger and snapshot current code
    if not os.path.exists(args.log_dir):
        os.makedirs(args.log_dir)
    shutil.copy2(__file__, os.path.join(args.log_dir, "train.py"))
    shutil.copy2("./models/networks.py", os.path.join(args.log_dir, "networks.py"))
    shutil.copy2("./models/sub_module.py", os.path.join(args.log_dir, "sub_module.py"))

    logger = logging.getLogger("train")
    logger.setLevel(logging.DEBUG)
    logger.handlers = []
    ch = logging.StreamHandler()
    logger.addHandler(ch)
    fh = logging.FileHandler(os.path.join(args.log_dir, "log.txt"))
    logger.addHandler(fh)
    logger.info("%s", repr(args))

    # Define augmentations
    transform_train = transforms.Compose([AugmentationPadImage(pad_size=8),
                                          AugmentationRandomCrop(output_size=256),
                                          ToTensor()])
    transform_test = transforms.Compose([ToTensor()])

    # Prepare and load data
    params_dataset_train = {'dataset_name': args.hdf5_name_train, 'plane': args.plane}
    params_dataset_test = {'dataset_name': args.hdf5_name_val, 'plane': args.plane}

    dataset_train = AsegDatasetWithAugmentation(params_dataset_train, transforms=transform_train)
    dataset_validation = AsegDatasetWithAugmentation(params_dataset_test, transforms=transform_test)

    train_dataloader = DataLoader(dataset=dataset_train, batch_size=args.batch_size, shuffle=True)
    validation_dataloader = DataLoader(dataset=dataset_validation, batch_size=args.test_batch_size, shuffle=True)

    # Set up network
    params_network = {'num_channels': args.num_channels, 'num_filters': args.num_filters,
                      'kernel_h': args.kernel_height, 'kernel_w': args.kernel_width,
                      'stride_conv': args.stride, 'pool': args.pool,
                      'stride_pool': args.stride_pool, 'num_classes': args.num_classes,
                      'kernel_c': 1, 'kernel_d': 1,
                      'batch_size': args.batch_size, 'height': 256, 'width': 256}

    # Select the model
    model = FastSurferCNN(params_network)

    # Put model on GPU
    if torch.cuda.is_available():
        if torch.cuda.device_count() > 1:
            model = nn.DataParallel(model)
        model.cuda()

    # Define labels
    if args.plane == "sagittal":
        curr_labels = CLASS_NAMES_SAG
    else:
        curr_labels = CLASS_NAMES

    # Optimizer selection
    if args.optim == "sgd":
        optim = torch.optim.SGD
        # Global optimization parameters for SGD
        default_optim_args = {"lr": args.lr, "momentum": args.momentum, "dampening": 0,
                              "weight_decay": 0, "nesterov": args.nesterov}
    else:
        optim = torch.optim.Adam
        # Global optimization parameters for Adam
        default_optim_args = {"lr": args.lr, "betas": (0.9, 0.999),
                              "eps": 1e-8, "weight_decay": 0.01}

    # Run network
    solver = Solver(num_classes=params_network["num_classes"],
                    optimizer_args=default_optim_args, optimizer=optim)

    solver.train(model, train_dataloader, None, validation_dataloader,
                 class_names=curr_labels, num_epochs=args.epochs,
                 log_params={'logdir': args.log_dir + "logs",
                             'log_iter': args.log_interval, 'logger': logger},
                 expdir=args.log_dir + "ckpts",
                 scheduler_type=args.scheduler,
                 torch_v11=args.torchv11,
                 resume=args.resume)
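
# Hypothetical usage sketch (not part of the original file): train() reads all of its
# settings from setup_options(), so a command-line entry point only needs to call it.
if __name__ == "__main__":
    train()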
def fast_surfer_cnn(img_filename, save_as, args):
    """
    Cortical parcellation of a single image
    :param str img_filename: name of image file
    :param parser.Options args: Arguments (passed via command line) to set up networks
            * args.network_sagittal_path: path to sagittal checkpoint (stored pretrained network)
            * args.network_coronal_path: path to coronal checkpoint (stored pretrained network)
            * args.network_axial_path: path to axial checkpoint (stored pretrained network)
            * args.cleanup: Whether to clean up the segmentation (median filter on certain labels)
            * args.no_cuda: Whether to use CUDA (GPU) or not (CPU)
    :param str save_as: name under which to save prediction.
    :return None: saves prediction to save_as
    """
    print("Reading volume {}".format(img_filename))

    header_info, affine_info, orig_data = load_and_conform_image(img_filename, interpol=args.order)

    transform_test = transforms.Compose([ToTensorTest()])

    test_dataset_axial = OrigDataThickSlices(img_filename, orig_data, transforms=transform_test, plane='Axial')
    test_dataset_sagittal = OrigDataThickSlices(img_filename, orig_data, transforms=transform_test, plane='Sagittal')
    test_dataset_coronal = OrigDataThickSlices(img_filename, orig_data, transforms=transform_test, plane='Coronal')

    start = time.time()

    test_data_loader = DataLoader(dataset=test_dataset_axial, batch_size=args.batch_size, shuffle=False)

    # Axial View Testing
    params_network = {'num_channels': 7, 'num_filters': 64, 'kernel_h': 5, 'kernel_w': 5,
                      'stride_conv': 1, 'pool': 2, 'stride_pool': 2, 'num_classes': 79,
                      'kernel_c': 1, 'kernel_d': 1}

    # Select the model
    model = FastSurferCNN(params_network)

    # Put it onto the GPU or CPU
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    print("Cuda available: {}, "
          "# Available GPUS: {}, "
          "Cuda user disabled (--no_cuda flag): {}, "
          "--> Using device: {}".format(torch.cuda.is_available(), torch.cuda.device_count(),
                                        args.no_cuda, device))

    if use_cuda and torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
        model_parallel = True
    else:
        model_parallel = False

    model.to(device)

    # Set up state dict (remapping of names, if not multiple GPUs/CPUs)
    print("Loading Axial Net from {}".format(args.network_axial_path))

    model_state = torch.load(args.network_axial_path, map_location=device)
    new_state_dict = OrderedDict()

    for k, v in model_state["model_state_dict"].items():
        if k[:7] == "module." and not model_parallel:
            new_state_dict[k[7:]] = v
        elif k[:7] != "module." and model_parallel:
            new_state_dict["module." + k] = v
        else:
            new_state_dict[k] = v

    model.load_state_dict(new_state_dict)

    model.eval()
    prediction_probability_axial = torch.zeros((256, params_network["num_classes"], 256, 256), dtype=torch.float)

    print("Axial model loaded.")
    with torch.no_grad():
        start_index = 0
        for batch_idx, sample_batch in enumerate(test_data_loader):
            images_batch = Variable(sample_batch["image"])

            if use_cuda:
                images_batch = images_batch.cuda()

            temp = model(images_batch)

            prediction_probability_axial[start_index:start_index + temp.shape[0]] = temp.cpu()
            start_index += temp.shape[0]
            print("--->Batch {} Axial Testing Done.".format(batch_idx))

    end = time.time() - start
    print("Axial View Tested in {:0.4f} seconds".format(end))

    # Coronal View Testing
    start = time.time()

    test_data_loader = DataLoader(dataset=test_dataset_coronal, batch_size=args.batch_size, shuffle=False)

    params_network = {'num_channels': 7, 'num_filters': 64, 'kernel_h': 5, 'kernel_w': 5,
                      'stride_conv': 1, 'pool': 2, 'stride_pool': 2, 'num_classes': 79,
                      'kernel_c': 1, 'kernel_d': 1}

    # Select the model
    model = FastSurferCNN(params_network)

    # Put it onto the GPU or CPU
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    if use_cuda and torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
        model_parallel = True
    else:
        model_parallel = False

    model.to(device)

    # Set up new state dict (remapping of names, if not multiple GPUs/CPUs)
    print("Loading Coronal Net from {}".format(args.network_coronal_path))

    model_state = torch.load(args.network_coronal_path, map_location=device)
    new_state_dict = OrderedDict()

    for k, v in model_state["model_state_dict"].items():
        if k[:7] == "module." and not model_parallel:
            new_state_dict[k[7:]] = v
        elif k[:7] != "module." and model_parallel:
            new_state_dict["module." + k] = v
        else:
            new_state_dict[k] = v

    model.load_state_dict(new_state_dict)

    model.eval()
    prediction_probability_coronal = torch.zeros((256, params_network["num_classes"], 256, 256), dtype=torch.float)

    print("Coronal model loaded.")
    start_index = 0
    with torch.no_grad():
        for batch_idx, sample_batch in enumerate(test_data_loader):
            images_batch = Variable(sample_batch["image"])

            if use_cuda:
                images_batch = images_batch.cuda()

            temp = model(images_batch)

            prediction_probability_coronal[start_index:start_index + temp.shape[0]] = temp.cpu()
            start_index += temp.shape[0]
            print("--->Batch {} Coronal Testing Done.".format(batch_idx))

    end = time.time() - start
    print("Coronal View Tested in {:0.4f} seconds".format(end))

    # Sagittal View Testing
    start = time.time()

    test_data_loader = DataLoader(dataset=test_dataset_sagittal, batch_size=args.batch_size, shuffle=False)

    params_network = {'num_channels': 7, 'num_filters': 64, 'kernel_h': 5, 'kernel_w': 5,
                      'stride_conv': 1, 'pool': 2, 'stride_pool': 2, 'num_classes': 51,
                      'kernel_c': 1, 'kernel_d': 1}

    # Select the model
    model = FastSurferCNN(params_network)

    # Put it onto the GPU or CPU
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    if use_cuda and torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
        model_parallel = True
    else:
        model_parallel = False

    model.to(device)

    # Set up new state dict (remapping of names, if not multiple GPUs/CPUs)
    print("Loading Sagittal Net from {}".format(args.network_sagittal_path))

    model_state = torch.load(args.network_sagittal_path, map_location=device)
    new_state_dict = OrderedDict()

    for k, v in model_state["model_state_dict"].items():
        if k[:7] == "module." and not model_parallel:
            new_state_dict[k[7:]] = v
        elif k[:7] != "module." and model_parallel:
            new_state_dict["module." + k] = v
        else:
            new_state_dict[k] = v

    model.load_state_dict(new_state_dict)

    model.eval()
    prediction_probability_sagittal = torch.zeros((256, params_network["num_classes"], 256, 256), dtype=torch.float)

    start_index = 0
    with torch.no_grad():
        for batch_idx, sample_batch in enumerate(test_data_loader):
            images_batch = Variable(sample_batch["image"])

            if use_cuda:
                images_batch = images_batch.cuda()

            temp = model(images_batch)

            prediction_probability_sagittal[start_index:start_index + temp.shape[0]] = temp.cpu()
            start_index += temp.shape[0]
            print("--->Batch {} Sagittal Testing Done.".format(batch_idx))

    prediction_probability_sagittal = map_prediction_sagittal2full(prediction_probability_sagittal)

    end = time.time() - start
    print("Sagittal View Tested in {:0.4f} seconds".format(end))

    del model, test_dataset_axial, test_dataset_coronal, test_dataset_sagittal, test_data_loader

    start = time.time()

    # Start View Aggregation: change from N,C,X,Y to coronal view with C in last dimension = H,W,D,C
    prediction_probability_axial = prediction_probability_axial.permute(3, 0, 2, 1)
    prediction_probability_coronal = prediction_probability_coronal.permute(2, 3, 0, 1)
    prediction_probability_sagittal = prediction_probability_sagittal.permute(0, 3, 2, 1)

    _, prediction_image = torch.max(torch.add(torch.mul(torch.add(prediction_probability_axial,
                                                                  prediction_probability_coronal), 0.4),
                                              torch.mul(prediction_probability_sagittal, 0.2)), 3)

    prediction_image = prediction_image.numpy()

    end = time.time() - start
    print("View Aggregation finished in {:0.4f} seconds.".format(end))

    prediction_image = map_label2aparc_aseg(prediction_image)

    # Quick Fix for 2026 vs 1026; 2029 vs. 1029; 2025 vs. 1025
    rh_wm = get_largest_cc(prediction_image == 41)
    lh_wm = get_largest_cc(prediction_image == 2)

    rh_wm = regionprops(label(rh_wm, background=0))
    lh_wm = regionprops(label(lh_wm, background=0))

    centroid_rh = np.asarray(rh_wm[0].centroid)
    centroid_lh = np.asarray(lh_wm[0].centroid)

    labels_list = np.array([1003, 1006, 1007, 1008, 1009, 1011, 1015, 1018, 1019, 1020, 1025, 1026,
                            1027, 1028, 1029, 1030, 1031, 1034, 1035])

    for label_current in labels_list:

        label_img = label(prediction_image == label_current, connectivity=3, background=0)

        for region in regionprops(label_img):

            if region.label != 0:  # To avoid background

                if np.linalg.norm(np.asarray(region.centroid) - centroid_rh) < np.linalg.norm(
                        np.asarray(region.centroid) - centroid_lh):
                    mask = label_img == region.label
                    prediction_image[mask] = label_current + 1000

    # Quick Fixes for overlapping classes
    aseg_lh = gaussian_filter(1000 * np.asarray(prediction_image == 2, dtype=np.float), sigma=3)
    aseg_rh = gaussian_filter(1000 * np.asarray(prediction_image == 41, dtype=np.float), sigma=3)

    lh_rh_split = np.argmax(np.concatenate((np.expand_dims(aseg_lh, axis=3),
                                            np.expand_dims(aseg_rh, axis=3)), axis=3), axis=3)

    # Problematic classes: 1026, 1011, 1029, 1019
    for prob_class_lh in [1011, 1019, 1026, 1029]:
        prob_class_rh = prob_class_lh + 1000
        mask_lh = ((prediction_image == prob_class_lh) | (prediction_image == prob_class_rh)) & (lh_rh_split == 0)
        mask_rh = ((prediction_image == prob_class_lh) | (prediction_image == prob_class_rh)) & (lh_rh_split == 1)

        prediction_image[mask_lh] = prob_class_lh
        prediction_image[mask_rh] = prob_class_rh

    # Clean-Up
    if args.cleanup is True:

        labels = [2, 4, 5, 7, 8, 10, 11, 12, 13, 14, 15, 16, 17, 18, 24, 26, 28, 31, 41, 43, 44,
                  46, 47, 49, 50, 51, 52, 53, 54, 58, 60, 63, 77, 1026, 2026]

        start = time.time()
        prediction_image_medfilt = median_filter(prediction_image, size=(3, 3, 3))
        mask = np.zeros_like(prediction_image)
        tolerance = 25

        for current_label in labels:

            current_class = (prediction_image == current_label)
            label_image = label(current_class, connectivity=3)

            for region in regionprops(label_image):

                if region.area <= tolerance:
                    mask_label = (label_image == region.label)
                    mask[mask_label] = 1

        prediction_image[mask == 1] = prediction_image_medfilt[mask == 1]

        end = time.time() - start
        print("Segmentation Cleaned up in {:0.4f} seconds.".format(end))

    # Saving image
    header_info.set_data_dtype(np.int16)
    mapped_aseg_img = nib.MGHImage(prediction_image, affine_info, header_info)
    mapped_aseg_img.to_filename(save_as)
    print("Saving Segmentation to {}".format(save_as))
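
# The checkpoint-loading block is repeated verbatim above for the axial, coronal and sagittal
# nets. The following sketch shows how it could be factored into a helper; load_checkpoint_into
# is a hypothetical name, not part of this repository. It assumes, as above, that checkpoints
# store their weights under "model_state_dict" and that keys may or may not carry the "module."
# prefix added by nn.DataParallel. It relies on the torch / OrderedDict imports already used above.
def load_checkpoint_into(model, checkpoint_path, device, model_parallel):
    model_state = torch.load(checkpoint_path, map_location=device)
    new_state_dict = OrderedDict()

    for k, v in model_state["model_state_dict"].items():
        if k.startswith("module.") and not model_parallel:
            # Checkpoint was saved from a DataParallel model; strip the prefix
            new_state_dict[k[7:]] = v
        elif not k.startswith("module.") and model_parallel:
            # Current model is wrapped in DataParallel; add the prefix
            new_state_dict["module." + k] = v
        else:
            new_state_dict[k] = v

    model.load_state_dict(new_state_dict)
    model.eval()
    return model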
def fastsurfercnn(img_filename, save_as, logger, args):
    """
    Cortical parcellation of a single image with FastSurferCNN.
    :param str img_filename: name of image file
    :param parser.Argparse args: Arguments (passed via command line) to set up networks
            * args.network_sagittal_path: path to sagittal checkpoint (stored pretrained network)
            * args.network_coronal_path: path to coronal checkpoint (stored pretrained network)
            * args.network_axial_path: path to axial checkpoint (stored pretrained network)
            * args.cleanup: Whether to clean up the segmentation (median filter on certain labels)
            * args.no_cuda: Whether to use CUDA (GPU) or not (CPU)
            * args.batch_size: Input batch size for inference (Default=8)
            * args.num_classes_ax_cor: Number of classes to predict in axial/coronal net (Default=79)
            * args.num_classes_sag: Number of classes to predict in sagittal net (Default=51)
            * args.num_channels: Number of input channels (Default=7, thick slices)
            * args.num_filters: Number of filter dimensions for DenseNet (Default=64)
            * args.kernel_height and args.kernel_width: Height and width of Kernel (Default=5)
            * args.stride: Stride during convolution (Default=1)
            * args.stride_pool: Stride during pooling (Default=2)
            * args.pool: Size of pooling filter (Default=2)
    :param logging.logger logger: Logging instance info messages will be written to
    :param str save_as: name under which to save prediction.
    :return None: saves prediction to save_as
    """
    start_total = time.time()
    logger.info("Reading volume {}".format(img_filename))

    header_info, affine_info, orig_data = load_and_conform_image(img_filename, interpol=1)

    # Set up model for axial and coronal networks
    params_network = {'num_channels': args.num_channels, 'num_filters': args.num_filters,
                      'kernel_h': args.kernel_height, 'kernel_w': args.kernel_width,
                      'stride_conv': args.stride, 'pool': args.pool,
                      'stride_pool': args.stride_pool, 'num_classes': args.num_classes_ax_cor,
                      'kernel_c': 1, 'kernel_d': 1}

    # Select the model
    model = FastSurferCNN(params_network)

    # Put it onto the GPU or CPU
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    logger.info("Cuda available: {}, # Available GPUS: {}, "
                "Cuda user disabled (--no_cuda flag): {}, "
                "--> Using device: {}".format(torch.cuda.is_available(), torch.cuda.device_count(),
                                              args.no_cuda, device))

    if use_cuda and torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
        model_parallel = True
    else:
        model_parallel = False

    model.to(device)

    params_model = {'device': device, "use_cuda": use_cuda, "batch_size": args.batch_size,
                    "model_parallel": model_parallel}

    # Set up tensor to hold probabilities
    pred_prob = torch.zeros((256, 256, 256, args.num_classes_ax_cor), dtype=torch.float)

    # Axial Prediction
    start = time.time()
    pred_prob = run_network(img_filename, orig_data, pred_prob, "Axial",
                            args.network_axial_path, params_model, model, logger)
    logger.info("Axial View Tested in {:0.4f} seconds".format(time.time() - start))

    # Coronal Prediction
    start = time.time()
    pred_prob = run_network(img_filename, orig_data, pred_prob, "Coronal",
                            args.network_coronal_path, params_model, model, logger)
    logger.info("Coronal View Tested in {:0.4f} seconds".format(time.time() - start))

    # Sagittal Prediction
    start = time.time()
    params_network["num_classes"] = args.num_classes_sag
    params_network["num_channels"] = args.num_channels
    model = FastSurferCNN(params_network)

    if model_parallel:
        model = nn.DataParallel(model)

    model.to(device)

    pred_prob = run_network(img_filename, orig_data, pred_prob, "Sagittal",
                            args.network_sagittal_path, params_model, model, logger)
    logger.info("Sagittal View Tested in {:0.4f} seconds".format(time.time() - start))

    # Get predictions and map to freesurfer label space
    _, pred_prob = torch.max(pred_prob, 3)
    pred_prob = pred_prob.numpy()
    pred_prob = map_label2aparc_aseg(pred_prob)

    # Post processing - Splitting classes
    # Quick Fix for 2026 vs 1026; 2029 vs. 1029; 2025 vs. 1025
    rh_wm = get_largest_cc(pred_prob == 41)
    lh_wm = get_largest_cc(pred_prob == 2)

    rh_wm = regionprops(label(rh_wm, background=0))
    lh_wm = regionprops(label(lh_wm, background=0))

    centroid_rh = np.asarray(rh_wm[0].centroid)
    centroid_lh = np.asarray(lh_wm[0].centroid)

    labels_list = np.array([1003, 1006, 1007, 1008, 1009, 1011, 1015, 1018, 1019, 1020, 1025, 1026,
                            1027, 1028, 1029, 1030, 1031, 1034, 1035])

    for label_current in labels_list:

        label_img = label(pred_prob == label_current, connectivity=3, background=0)

        for region in regionprops(label_img):

            if region.label != 0:  # To avoid background

                if np.linalg.norm(np.asarray(region.centroid) - centroid_rh) < np.linalg.norm(
                        np.asarray(region.centroid) - centroid_lh):
                    mask = label_img == region.label
                    pred_prob[mask] = label_current + 1000

    # Quick Fixes for overlapping classes
    aseg_lh = gaussian_filter(1000 * np.asarray(pred_prob == 2, dtype=np.float), sigma=3)
    aseg_rh = gaussian_filter(1000 * np.asarray(pred_prob == 41, dtype=np.float), sigma=3)

    lh_rh_split = np.argmax(np.concatenate((np.expand_dims(aseg_lh, axis=3),
                                            np.expand_dims(aseg_rh, axis=3)), axis=3), axis=3)

    # Problematic classes: 1026, 1011, 1029, 1019
    for prob_class_lh in [1011, 1019, 1026, 1029]:
        prob_class_rh = prob_class_lh + 1000
        mask_lh = ((pred_prob == prob_class_lh) | (pred_prob == prob_class_rh)) & (lh_rh_split == 0)
        mask_rh = ((pred_prob == prob_class_lh) | (pred_prob == prob_class_rh)) & (lh_rh_split == 1)

        pred_prob[mask_lh] = prob_class_lh
        pred_prob[mask_rh] = prob_class_rh

    # Clean-Up
    if args.cleanup is True:

        labels = [2, 4, 5, 7, 8, 10, 11, 12, 13, 14, 15, 16, 17, 18, 24, 26, 28, 31, 41, 43, 44,
                  46, 47, 49, 50, 51, 52, 53, 54, 58, 60, 63, 77, 1026, 2026]

        start = time.time()
        pred_prob_medfilt = median_filter(pred_prob, size=(3, 3, 3))
        mask = np.zeros_like(pred_prob)
        tolerance = 25

        for current_label in labels:

            current_class = (pred_prob == current_label)
            label_image = label(current_class, connectivity=3)

            for region in regionprops(label_image):

                if region.area <= tolerance:
                    mask_label = (label_image == region.label)
                    mask[mask_label] = 1

        pred_prob[mask == 1] = pred_prob_medfilt[mask == 1]
        logger.info("Segmentation Cleaned up in {:0.4f} seconds.".format(time.time() - start))

    # Saving image
    header_info.set_data_dtype(np.int16)
    mapped_aseg_img = nib.MGHImage(pred_prob, affine_info, header_info)
    mapped_aseg_img.to_filename(save_as)
    logger.info("Saving Segmentation to {}".format(save_as))

    logger.info("Total processing time: {:0.4f} seconds.".format(time.time() - start_total))