示例#1
0
def train():
    """
    Train a FastSurferCNN network with the options returned by setup_options().

    Side effects:
        * creates ``args.log_dir`` if needed and snapshots this script,
          ./models/networks.py and ./models/sub_module.py into it, so every run
          records the code that produced it;
        * logs to stdout and to ``<log_dir>/log.txt``;
        * hands ``<log_dir>/logs`` (log directory) and ``<log_dir>/ckpts``
          (checkpoint directory) to Solver.train.

    :return: None
    """
    args = setup_options()

    # logger and snapshot current code
    os.makedirs(args.log_dir, exist_ok=True)
    shutil.copy2(__file__, os.path.join(args.log_dir, "train.py"))
    shutil.copy2("./models/networks.py",
                 os.path.join(args.log_dir, "networks.py"))
    shutil.copy2("./models/sub_module.py",
                 os.path.join(args.log_dir, "sub_module.py"))

    logger = logging.getLogger("train")
    logger.setLevel(logging.DEBUG)
    # Close and drop handlers left over from an earlier call in the same process,
    # otherwise messages would be duplicated and old log files kept open.
    for old_handler in list(logger.handlers):
        old_handler.close()
    logger.handlers = []
    ch = logging.StreamHandler()
    logger.addHandler(ch)
    fh = logging.FileHandler(os.path.join(args.log_dir, "log.txt"))
    logger.addHandler(fh)

    logger.info("%s", repr(args))

    # Define augmentations (padding + random crop back to 256 for training only)
    transform_train = transforms.Compose([
        AugmentationPadImage(pad_size=8),
        AugmentationRandomCrop(output_size=256),
        ToTensor()
    ])
    transform_test = transforms.Compose([ToTensor()])

    # Prepare and load data
    params_dataset_train = {
        'dataset_name': args.hdf5_name_train,
        'plane': args.plane
    }
    params_dataset_test = {
        'dataset_name': args.hdf5_name_val,
        'plane': args.plane
    }

    dataset_train = AsegDatasetWithAugmentation(params_dataset_train,
                                                transforms=transform_train)
    dataset_validation = AsegDatasetWithAugmentation(params_dataset_test,
                                                     transforms=transform_test)

    train_dataloader = DataLoader(dataset=dataset_train,
                                  batch_size=args.batch_size,
                                  shuffle=True)
    validation_dataloader = DataLoader(dataset=dataset_validation,
                                       batch_size=args.test_batch_size,
                                       shuffle=True)

    # Set up network
    params_network = {
        'num_channels': args.num_channels,
        'num_filters': args.num_filters,
        'kernel_h': args.kernel_height,
        'kernel_w': args.kernel_width,
        'stride_conv': args.stride,
        'pool': args.pool,
        'stride_pool': args.stride_pool,
        'num_classes': args.num_classes,
        'kernel_c': 1,
        'kernel_d': 1,
        'batch_size': args.batch_size,
        'height': 256,
        'width': 256
    }

    # Select the model
    model = FastSurferCNN(params_network)

    # Put model on GPU (wrapped in DataParallel when several GPUs are visible)
    if torch.cuda.is_available():
        if torch.cuda.device_count() > 1:
            model = nn.DataParallel(model)
        model.cuda()

    # Define labels: the sagittal network uses its own class list
    if args.plane == "sagittal":
        curr_labels = CLASS_NAMES_SAG
    else:
        curr_labels = CLASS_NAMES

    # optimizer selection
    if args.optim == "sgd":
        optim = torch.optim.SGD

        # Global optimization parameters for SGD
        default_optim_args = {
            "lr": args.lr,
            "momentum": args.momentum,
            "dampening": 0,
            "weight_decay": 0,
            "nesterov": args.nesterov
        }
    else:
        optim = torch.optim.Adam

        # Global optimization parameters for adam
        default_optim_args = {
            "lr": args.lr,
            "betas": (0.9, 0.999),
            "eps": 1e-8,
            "weight_decay": 0.01
        }

    # Run network. log_dir is a directory (it is created above and files are
    # copied into it), so sub-directories are joined rather than concatenated.
    solver = Solver(num_classes=params_network["num_classes"],
                    optimizer_args=default_optim_args,
                    optimizer=optim)
    solver.train(model,
                 train_dataloader,
                 None,
                 validation_dataloader,
                 class_names=curr_labels,
                 num_epochs=args.epochs,
                 log_params={
                     'logdir': os.path.join(args.log_dir, "logs"),
                     'log_iter': args.log_interval,
                     'logger': logger
                 },
                 expdir=os.path.join(args.log_dir, "ckpts"),
                 scheduler_type=args.scheduler,
                 torch_v11=args.torchv11,
                 resume=args.resume)
示例#2
0
def _fsc_remap_state_dict(state_dict, model_parallel):
    """
    Rename checkpoint keys so they match a (non-)DataParallel model.

    Keys prefixed with "module." (the naming used by nn.DataParallel-wrapped
    models) are stripped for a plain model; unprefixed keys get the prefix added
    for a wrapped model; all other keys are kept as they are.

    :param dict state_dict: parameter mapping as stored in the checkpoint
    :param bool model_parallel: True if the target model is an nn.DataParallel
    :return OrderedDict: state dict with keys matching the target model
    """
    prefix = "module."
    remapped = OrderedDict()
    for key, value in state_dict.items():
        has_prefix = key.startswith(prefix)
        if has_prefix and not model_parallel:
            remapped[key[len(prefix):]] = value
        elif not has_prefix and model_parallel:
            remapped[prefix + key] = value
        else:
            remapped[key] = value
    return remapped


def _fsc_predict_view(dataset, view_name, checkpoint_path, num_classes,
                      batch_size, device, use_cuda):
    """
    Build one view-specific FastSurferCNN, load its checkpoint and run it over *dataset*.

    :param dataset: thick-slice dataset for this view (OrigDataThickSlices)
    :param str view_name: "Axial", "Coronal" or "Sagittal" (used in messages)
    :param str checkpoint_path: stored pretrained network for this view
    :param int num_classes: number of classes the network predicts
    :param int batch_size: inference batch size
    :param torch.device device: device the network runs on
    :param bool use_cuda: whether CUDA is used (DataParallel on more than one GPU)
    :return torch.Tensor: CPU tensor of network outputs with shape
        (256, num_classes, 256, 256); assumes the conformed volume yields 256
        slices of 256x256 (the buffer size is hard-coded) - TODO confirm
    """
    params_network = {
        'num_channels': 7,
        'num_filters': 64,
        'kernel_h': 5,
        'kernel_w': 5,
        'stride_conv': 1,
        'pool': 2,
        'stride_pool': 2,
        'num_classes': num_classes,
        'kernel_c': 1,
        'kernel_d': 1
    }
    model = FastSurferCNN(params_network)

    model_parallel = use_cuda and torch.cuda.device_count() > 1
    if model_parallel:
        model = nn.DataParallel(model)
    model.to(device)

    print("Loading {} Net from {}".format(view_name, checkpoint_path))
    model_state = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(
        _fsc_remap_state_dict(model_state["model_state_dict"], model_parallel))
    model.eval()
    print("{} model loaded.".format(view_name))

    data_loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=False)
    prediction = torch.zeros((256, num_classes, 256, 256), dtype=torch.float)

    start_index = 0
    with torch.no_grad():
        for batch_idx, sample_batch in enumerate(data_loader):
            images_batch = sample_batch["image"].to(device)
            output = model(images_batch)
            end_index = start_index + output.shape[0]
            prediction[start_index:end_index] = output.cpu()
            start_index = end_index
            print("--->Batch {} {} Testing Done.".format(batch_idx, view_name))

    return prediction


def fast_surfer_cnn(img_filename, save_as, args):
    """
    Cortical parcellation of single image

    The conformed image is segmented by an axial, a coronal and a sagittal network.
    The three outputs are combined as 0.4 * (axial + coronal) + 0.2 * sagittal and
    the arg-max label map is mapped to aparc+aseg labels. Cortical labels are then
    reassigned to the hemisphere whose white-matter centroid is closer, and the
    classes 1011/1019/1026/1029 are split by smoothed left/right white-matter masks.
    Small islands can optionally be cleaned up.

    :param str img_filename: name of image file
    :param str save_as: name under which to save prediction.
    :param parser.Options args: Arguments (passed via command line) to set up networks
            * args.network_sagittal_path: path to sagittal checkpoint (stored pretrained network)
            * args.network_coronal_path: path to coronal checkpoint (stored pretrained network)
            * args.network_axial_path: path to axial checkpoint (stored pretrained network)
            * args.cleanup: Whether to clean up the segmentation (median filter on certain labels)
            * args.no_cuda: Whether to use CUDA (GPU) or not (CPU)
            * args.order: interpolation order used when conforming the image
            * args.batch_size: inference batch size
    :return None: saves prediction to save_as
    """
    print("Reading volume {}".format(img_filename))

    header_info, affine_info, orig_data = load_and_conform_image(
        img_filename, interpol=args.order)

    transform_test = transforms.Compose([ToTensorTest()])

    # Put the networks onto the GPU or CPU
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    print("Cuda available: {}, "
          "# Available GPUS: {}, "
          "Cuda user disabled (--no_cuda flag): {}, "
          "--> Using device: {}".format(torch.cuda.is_available(),
                                        torch.cuda.device_count(),
                                        args.no_cuda, device))

    # Axial View Testing
    start = time.time()
    dataset = OrigDataThickSlices(img_filename, orig_data,
                                  transforms=transform_test, plane='Axial')
    prob_axial = _fsc_predict_view(dataset, "Axial", args.network_axial_path, 79,
                                   args.batch_size, device, use_cuda)
    print("Axial View Tested in {:0.4f} seconds".format(time.time() - start))

    # Coronal View Testing
    start = time.time()
    dataset = OrigDataThickSlices(img_filename, orig_data,
                                  transforms=transform_test, plane='Coronal')
    prob_coronal = _fsc_predict_view(dataset, "Coronal", args.network_coronal_path, 79,
                                     args.batch_size, device, use_cuda)
    print("Coronal View Tested in {:0.4f} seconds".format(time.time() - start))

    # Sagittal View Testing
    start = time.time()
    dataset = OrigDataThickSlices(img_filename, orig_data,
                                  transforms=transform_test, plane='Sagittal')
    prob_sagittal = _fsc_predict_view(dataset, "Sagittal", args.network_sagittal_path, 51,
                                      args.batch_size, device, use_cuda)
    # The sagittal net has 51 output classes; map them onto the full label set
    # before aggregation.
    prob_sagittal = map_prediction_sagittal2full(prob_sagittal)
    print("Sagittal View Tested in {:0.4f} seconds".format(time.time() - start))
    del dataset

    start = time.time()

    # Start View Aggregation: change from N,C,X,Y to coronal view with C in last dimension = H,W,D,C
    prob_axial = prob_axial.permute(3, 0, 2, 1)
    prob_coronal = prob_coronal.permute(2, 3, 0, 1)
    prob_sagittal = prob_sagittal.permute(0, 3, 2, 1)

    # Weighted view aggregation: 0.4 * (axial + coronal) + 0.2 * sagittal
    aggregated = torch.add(torch.mul(torch.add(prob_axial, prob_coronal), 0.4),
                           torch.mul(prob_sagittal, 0.2))
    _, prediction_image = torch.max(aggregated, 3)
    del prob_axial, prob_coronal, prob_sagittal, aggregated

    prediction_image = prediction_image.numpy()
    print("View Aggregation finished in {:0.4f} seconds.".format(time.time() - start))

    prediction_image = map_label2aparc_aseg(prediction_image)

    # Quick Fix for 2026 vs 1026; 2029 vs. 1029; 2025 vs. 1025:
    # reassign each connected cortical component to the hemisphere whose largest
    # white-matter component (41 = rh, 2 = lh) has the closer centroid.
    rh_wm = regionprops(label(get_largest_cc(prediction_image == 41), background=0))
    lh_wm = regionprops(label(get_largest_cc(prediction_image == 2), background=0))
    centroid_rh = np.asarray(rh_wm[0].centroid)
    centroid_lh = np.asarray(lh_wm[0].centroid)

    labels_list = np.array([
        1003, 1006, 1007, 1008, 1009, 1011, 1015, 1018, 1019, 1020, 1025, 1026,
        1027, 1028, 1029, 1030, 1031, 1034, 1035
    ])

    for label_current in labels_list:
        label_img = label(prediction_image == label_current,
                          connectivity=3,
                          background=0)

        for region in regionprops(label_img):
            if region.label == 0:  # To avoid background
                continue
            centroid = np.asarray(region.centroid)
            if np.linalg.norm(centroid - centroid_rh) < np.linalg.norm(centroid - centroid_lh):
                prediction_image[label_img == region.label] = label_current + 1000

    # Quick Fixes for overlapping classes. The builtin float is used as dtype
    # because the np.float alias was removed in NumPy 1.24.
    aseg_lh = gaussian_filter(
        1000 * np.asarray(prediction_image == 2, dtype=float), sigma=3)
    aseg_rh = gaussian_filter(
        1000 * np.asarray(prediction_image == 41, dtype=float), sigma=3)

    # True where the smoothed right-hemisphere WM dominates (ties go to the left,
    # as with an argmax over [lh, rh]).
    rh_dominant = aseg_rh > aseg_lh

    # Problematic classes: 1026, 1011, 1029, 1019
    for prob_class_lh in [1011, 1019, 1026, 1029]:
        prob_class_rh = prob_class_lh + 1000
        in_class = (prediction_image == prob_class_lh) | (prediction_image == prob_class_rh)
        mask_lh = in_class & ~rh_dominant
        mask_rh = in_class & rh_dominant

        prediction_image[mask_lh] = prob_class_lh
        prediction_image[mask_rh] = prob_class_rh

    # Clean-Up: replace connected components of at most `tolerance` voxels
    # by the 3x3x3 median-filtered labels.
    if args.cleanup is True:

        labels = [
            2, 4, 5, 7, 8, 10, 11, 12, 13, 14, 15, 16, 17, 18, 24, 26, 28, 31,
            41, 43, 44, 46, 47, 49, 50, 51, 52, 53, 54, 58, 60, 63, 77, 1026,
            2026
        ]

        start = time.time()
        prediction_image_medfilt = median_filter(prediction_image,
                                                 size=(3, 3, 3))
        mask = np.zeros(prediction_image.shape, dtype=bool)
        tolerance = 25

        for current_label in labels:
            label_image = label(prediction_image == current_label, connectivity=3)

            for region in regionprops(label_image):
                if region.area <= tolerance:
                    mask[label_image == region.label] = True

        prediction_image[mask] = prediction_image_medfilt[mask]
        print("Segmentation Cleaned up in {:0.4f} seconds.".format(time.time() - start))

    # Saving image
    header_info.set_data_dtype(np.int16)
    mapped_aseg_img = nib.MGHImage(prediction_image, affine_info, header_info)
    mapped_aseg_img.to_filename(save_as)
    print("Saving Segmentation to {}".format(save_as))
示例#3
0
def fastsurfercnn(img_filename, save_as, logger, args):
    """
    Cortical parcellation of single image with FastSurferCNN.

    The conformed image is passed through the axial, coronal and sagittal networks
    via run_network. run_network receives and returns one shared probability tensor
    of shape (256, 256, 256, args.num_classes_ax_cor); it presumably accumulates each
    view's output - confirm in run_network. The arg-max label map is then mapped to
    aparc+aseg labels. Cortical labels are reassigned to the hemisphere whose
    white-matter centroid is closer, and classes 1011/1019/1026/1029 are split by
    smoothed left/right white-matter masks. Small islands are optionally cleaned up.

    :param str img_filename: name of image file
    :param str save_as: name under which to save prediction.
    :param logging.Logger logger: Logging instance info messages will be written to
    :param argparse.Namespace args: Arguments (passed via command line) to set up networks
            * args.network_sagittal_path: path to sagittal checkpoint (stored pretrained network)
            * args.network_coronal_path: path to coronal checkpoint (stored pretrained network)
            * args.network_axial_path: path to axial checkpoint (stored pretrained network)
            * args.cleanup: Whether to clean up the segmentation (median filter on certain labels)
            * args.no_cuda: Whether to use CUDA (GPU) or not (CPU)
            * args.batch_size: Input batch size for inference (Default=8)
            * args.num_classes_ax_cor: Number of classes to predict in axial/coronal net (Default=79)
            * args.num_classes_sag: Number of classes to predict in sagittal net (Default=51)
            * args.num_channels: Number of input channels (Default=7, thick slices)
            * args.num_filters: Number of filter dimensions for DenseNet (Default=64)
            * args.kernel_height and args.kernel_width: Height and width of Kernel (Default=5)
            * args.stride: Stride during convolution (Default=1)
            * args.stride_pool: Stride during pooling (Default=2)
            * args.pool: Size of pooling filter (Default=2)

    :return None: saves prediction to save_as
    """
    start_total = time.time()
    logger.info("Reading volume {}".format(img_filename))

    header_info, affine_info, orig_data = load_and_conform_image(img_filename, interpol=1)

    # Set up model for axial and coronal networks
    params_network = {'num_channels': args.num_channels, 'num_filters': args.num_filters,
                      'kernel_h': args.kernel_height, 'kernel_w': args.kernel_width,
                      'stride_conv': args.stride, 'pool': args.pool,
                      'stride_pool': args.stride_pool, 'num_classes': args.num_classes_ax_cor,
                      'kernel_c': 1, 'kernel_d': 1}

    # Select the model
    model = FastSurferCNN(params_network)

    # Put it onto the GPU or CPU
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    logger.info("Cuda available: {}, # Available GPUS: {}, "
                "Cuda user disabled (--no_cuda flag): {}, "
                "--> Using device: {}".format(torch.cuda.is_available(),
                                              torch.cuda.device_count(),
                                              args.no_cuda, device))

    model_parallel = use_cuda and torch.cuda.device_count() > 1
    if model_parallel:
        model = nn.DataParallel(model)

    model.to(device)

    params_model = {'device': device, "use_cuda": use_cuda, "batch_size": args.batch_size,
                    "model_parallel": model_parallel}

    # Set up tensor to hold probabilities
    pred_prob = torch.zeros((256, 256, 256, args.num_classes_ax_cor), dtype=torch.float)

    # Axial Prediction
    start = time.time()
    pred_prob = run_network(img_filename,
                            orig_data, pred_prob, "Axial",
                            args.network_axial_path,
                            params_model, model, logger)

    logger.info("Axial View Tested in {:0.4f} seconds".format(time.time() - start))

    # Coronal Prediction (same architecture; run_network is given the coronal checkpoint)
    start = time.time()
    pred_prob = run_network(img_filename,
                            orig_data, pred_prob, "Coronal",
                            args.network_coronal_path,
                            params_model, model, logger)

    logger.info("Coronal View Tested in {:0.4f} seconds".format(time.time() - start))

    # Sagittal Prediction: the sagittal net has its own number of output classes.
    # A copy is used so the axial/coronal parameters stay untouched.
    start = time.time()
    params_network_sag = dict(params_network, num_classes=args.num_classes_sag)

    model = FastSurferCNN(params_network_sag)

    if model_parallel:
        model = nn.DataParallel(model)

    model.to(device)

    pred_prob = run_network(img_filename, orig_data, pred_prob, "Sagittal",
                            args.network_sagittal_path,
                            params_model, model, logger)

    logger.info("Sagittal View Tested in {:0.4f} seconds".format(time.time() - start))

    # Get predictions (arg-max over the class dimension) and map to freesurfer label space
    _, pred_labels = torch.max(pred_prob, 3)
    del pred_prob
    pred_labels = map_label2aparc_aseg(pred_labels.numpy())

    # Post processing - Splitting classes
    # Quick Fix for 2026 vs 1026; 2029 vs. 1029; 2025 vs. 1025:
    # reassign each connected cortical component to the hemisphere whose largest
    # white-matter component (41 = rh, 2 = lh) has the closer centroid.
    rh_wm = regionprops(label(get_largest_cc(pred_labels == 41), background=0))
    lh_wm = regionprops(label(get_largest_cc(pred_labels == 2), background=0))
    centroid_rh = np.asarray(rh_wm[0].centroid)
    centroid_lh = np.asarray(lh_wm[0].centroid)

    labels_list = np.array([1003, 1006, 1007, 1008, 1009, 1011,
                            1015, 1018, 1019, 1020, 1025, 1026, 1027, 1028, 1029, 1030, 1031, 1034, 1035])

    for label_current in labels_list:

        label_img = label(pred_labels == label_current, connectivity=3, background=0)

        for region in regionprops(label_img):
            if region.label == 0:  # To avoid background
                continue
            centroid = np.asarray(region.centroid)
            if np.linalg.norm(centroid - centroid_rh) < np.linalg.norm(centroid - centroid_lh):
                pred_labels[label_img == region.label] = label_current + 1000

    # Quick Fixes for overlapping classes. The builtin float is used as dtype
    # because the np.float alias was removed in NumPy 1.24.
    aseg_lh = gaussian_filter(1000 * np.asarray(pred_labels == 2, dtype=float), sigma=3)
    aseg_rh = gaussian_filter(1000 * np.asarray(pred_labels == 41, dtype=float), sigma=3)

    lh_rh_split = np.argmax(np.stack((aseg_lh, aseg_rh), axis=3), axis=3)

    # Problematic classes: 1026, 1011, 1029, 1019
    for prob_class_lh in [1011, 1019, 1026, 1029]:
        prob_class_rh = prob_class_lh + 1000
        in_class = (pred_labels == prob_class_lh) | (pred_labels == prob_class_rh)
        mask_lh = in_class & (lh_rh_split == 0)
        mask_rh = in_class & (lh_rh_split == 1)

        pred_labels[mask_lh] = prob_class_lh
        pred_labels[mask_rh] = prob_class_rh

    # Clean-Up: replace connected components of at most `tolerance` voxels
    # by the 3x3x3 median-filtered labels.
    if args.cleanup is True:

        labels = [2, 4, 5, 7, 8, 10, 11, 12, 13, 14,
                  15, 16, 17, 18, 24, 26, 28, 31, 41, 43, 44,
                  46, 47, 49, 50, 51, 52, 53, 54, 58, 60, 63,
                  77, 1026, 2026]

        start = time.time()
        pred_labels_medfilt = median_filter(pred_labels, size=(3, 3, 3))
        mask = np.zeros(pred_labels.shape, dtype=bool)
        tolerance = 25

        for current_label in labels:
            label_image = label(pred_labels == current_label, connectivity=3)

            for region in regionprops(label_image):
                if region.area <= tolerance:
                    mask[label_image == region.label] = True

        pred_labels[mask] = pred_labels_medfilt[mask]
        logger.info("Segmentation Cleaned up in {:0.4f} seconds.".format(time.time() - start))

    # Saving image
    header_info.set_data_dtype(np.int16)
    mapped_aseg_img = nib.MGHImage(pred_labels, affine_info, header_info)
    mapped_aseg_img.to_filename(save_as)
    logger.info("Saving Segmentation to {}".format(save_as))
    logger.info("Total processing time: {:0.4f} seconds.".format(time.time() - start_total))