Ejemplo n.º 1
0
def gather_file(args, slice_shape=(15, 15, 15, 1, 3)):
    """Stack the per-sample files in ``args.files`` into one tensor and save it.

    Parameters
    ----------
    args : namespace
        Must provide ``files`` (list of paths), ``momentum`` (bool) and
        ``output`` (destination passed to ``torch.save``).
    slice_shape : tuple of int, optional
        Shape of one sample in the non-momentum (NIfTI) branch.  The default
        reproduces the previously hard-coded ``(15, 15, 15, 1, 3)``.  In the
        momentum branch it is ignored; the shape of the first field is used.

    Behaviour
    ---------
    * ``args.momentum`` true: every file is read with ``common.LoadITKField``
      into host memory, the samples are stacked, and channel axis 4 is moved
      to position 1 (``permute(0, 4, 1, 2, 3)``) before saving.
    * otherwise: every file is read with nibabel.  A file whose data has
      exactly 3375 (= 15**3) elements is stored as an all-zero sample
      (presumably a single-channel placeholder volume -- confirm).

    Raises
    ------
    ValueError
        If ``args.files`` is empty.
    """
    if not args.files:
        raise ValueError("gather_file: args.files is empty")

    first_field = None
    if args.momentum:
        # Size the stack from the actual field shape: the 5-D NIfTI default
        # would make the stack 6-D and break the 5-axis permute below.
        first_field = common.AsNPCopy(
            common.LoadITKField(args.files[0], ca.MEM_HOST))
        slice_shape = first_field.shape

    data = torch.zeros((len(args.files), ) + tuple(slice_shape))
    for i, path in enumerate(args.files):
        if args.momentum:
            if i == 0:
                cur_slice = first_field  # already loaded above
            else:
                cur_slice = common.AsNPCopy(
                    common.LoadITKField(path, ca.MEM_HOST))
        else:
            # get_data() was removed in nibabel 5; np.asanyarray(dataobj) is
            # what it returned.
            cur_slice = np.asanyarray(nib.load(path).dataobj)
            if cur_slice.size == 3375:
                continue  # placeholder: leave this sample as zeros
        # Cast to native float32 up front (also handles big-endian NIfTI data,
        # which torch.from_numpy rejects).
        data[i] = torch.from_numpy(np.asarray(cur_slice, dtype=np.float32))

    if args.momentum:
        # transpose the dataset to fit the training format
        data = data.permute(0, 4, 1, 2, 3)

    torch.save(data, args.output)
def predict_each_datapart(args, net, network_config, input_batch, datapart_idx, batch_size, patch_size, predict_transform_space):
    """Warp every target slice of one data part back with the predicted momentum.

    For data part ``datapart_idx`` this ``torch.load``s the moving images,
    target images and optimisation momenta. For every slice it:

    * predicts a momentum with ``util.predict_momentum``;
    * shoots that momentum with ``registration_methods.geodesic_shooting``.
      The images and momentum are placed in PyCA ``ca.MEM_DEVICE`` memory.

    It then saves two tensors:

    * the target stack, each slice replaced by the shooting result's
      ``'I1_inv'`` image, to ``args.warped_back_target_output[datapart_idx]``;
    * the optimisation momenta minus the predicted momenta (prediction space),
      to ``args.momentum_residual[datapart_idx]``.

    When ``predict_transform_space`` is true, slices are converted to
    registration space before prediction/shooting. The warped-back target is
    then converted back to prediction space before it is stored. The slice
    index and the ``'I1_inv'`` shapes are printed as progress output.
    """
    moving_stack = torch.load(args.moving_image_dataset[datapart_idx])
    target_stack = torch.load(args.target_image_dataset[datapart_idx])
    momentum_stack = torch.load(args.deformation_parameter[datapart_idx])

    n_slices = moving_stack.size(0)
    for idx in range(n_slices):
        print(idx)
        src = moving_stack[idx].numpy()
        dst = target_stack[idx].numpy()
        if predict_transform_space:
            # tuple evaluation is left to right: moving converted first
            src, dst = (util.convert_to_registration_space(src),
                        util.convert_to_registration_space(dst))

        prediction = util.predict_momentum(
            src, dst, input_batch, batch_size, patch_size, net,
            predict_transform_space)
        momentum_field = common.FieldFromNPArr(prediction['image_space'],
                                               ca.MEM_DEVICE)

        src_ca = common.ImFromNPArr(src, ca.MEM_DEVICE)
        dst_ca = common.ImFromNPArr(dst, ca.MEM_DEVICE)
        shot = registration_methods.geodesic_shooting(
            src_ca, dst_ca, momentum_field, args.shoot_steps, ca.MEM_DEVICE,
            network_config)

        warped_back = common.AsNPCopy(shot['I1_inv'])
        print(warped_back.shape)
        if predict_transform_space:
            warped_back = util.convert_to_predict_space(warped_back)
        print(warped_back.shape)
        target_stack[idx] = torch.from_numpy(warped_back)

        # residual = what the optimisation found minus what the net predicted
        residual = momentum_stack[idx] - torch.from_numpy(
            prediction['prediction_space'])
        momentum_stack[idx] = residual

    torch.save(target_stack, args.warped_back_target_output[datapart_idx])
    torch.save(momentum_stack, args.momentum_residual[datapart_idx])
Ejemplo n.º 3
0
def preprocess_image(image_pyca, histeq):
    """Return a NumPy copy of an image with NaNs zeroed and scaled to max 1.

    Parameters
    ----------
    image_pyca : PyCA image or numpy.ndarray
        Source image.
        * PyCA images are copied with ``common.AsNPCopy``.
        * NumPy arrays are copied, and promoted to float64 if they are not
          already floating point.
        The caller's data is never modified.
    histeq : bool
        If true, histogram-equalize the non-zero voxels with
        ``exposure.equalize_hist`` after scaling.

    Returns
    -------
    numpy.ndarray
        The preprocessed copy.  An image whose maximum is 0 (e.g. all zeros)
        is left unscaled instead of being divided by zero, which previously
        turned it into NaNs.  An empty image is returned unchanged.
    """
    if isinstance(image_pyca, np.ndarray):
        image_np = np.array(image_pyca, copy=True)
        if not np.issubdtype(image_np.dtype, np.floating):
            image_np = image_np.astype(np.float64)
    else:
        image_np = common.AsNPCopy(image_pyca)

    image_np[np.isnan(image_np)] = 0
    if image_np.size == 0:
        return image_np  # np.amax has no identity for empty arrays

    max_val = np.amax(image_np)
    if max_val != 0:
        image_np /= max_val

    # perform histogram equalization if needed
    if histeq:
        foreground = image_np != 0
        if foreground.any():
            image_np[foreground] = exposure.equalize_hist(image_np[foreground])

    return image_np
def intensity_normalization_histeq(args):
    """Normalize intensities (optionally histogram-equalize) a list of images.

    For each index ``i`` in ``range(len(args.input_images))``, this:

    1. loads ``args.output_images[i]`` into host memory;
    2. applies ``preprocess_image``: NaN -> 0, scale by the maximum, and
       histogram equalization of non-zero voxels when ``args.histeq``;
    3. restores the loaded image's grid on the result;
    4. writes the result back to ``args.output_images[i]``, overwriting it.
    """
    for i in range(len(args.input_images)):
        # NOTE(review): this reads output_images[i], not input_images[i], so the
        # output files must already exist (presumably written by an earlier
        # step) and are processed in place.  If that is not intended, read
        # args.input_images[i] here -- confirm with the caller.
        out_path = args.output_images[i]
        image = common.LoadITKImage(out_path, ca.MEM_HOST)
        grid = image.grid()

        # Same NaN/scale/histeq pipeline as the prediction path.
        image_np = preprocess_image(image, args.histeq)

        image_result = common.ImFromNPArr(image_np, ca.MEM_HOST)
        image_result.setGrid(grid)
        common.SaveITKImage(image_result, out_path)
Ejemplo n.º 5
0
def predict_image(args, moving_images, target_images, output_prefixes):
    """Predict a registration for each (moving, target) pair and write it out.

    For every index ``i``:

    1. create the directory part of ``output_prefixes[i]`` (if it has one);
    2. if ``args.affine_align``, affinely register both images to
       ``args.atlas`` with NiftyReg ``reg_aladin`` and load the resampled
       ``<prefix>moving_affine.nii`` / ``<prefix>target_affine.nii``;
       otherwise load the given files;
    3. preprocess both images with ``preprocess_image``;
    4. predict the initial momentum with the prediction network;
    5. if ``args.use_correction``, do the correction step:
       * shoot the prediction on the preprocessed images;
       * predict a correction from the preprocessed moving image and the
         shooting result's ``'I1_inv'``;
       * add that correction to the momentum;
    6. shoot the final momentum from the loaded (not preprocessed) images and
       hand the result to ``write_result``.

    Raises
    ------
    RuntimeError
        If ``reg_aladin`` reports failure.  Previously its status was ignored,
        so a stale or missing ``*_affine.nii`` could be loaded silently.
    """
    mType = ca.MEM_HOST if args.use_CPU_for_shooting else ca.MEM_DEVICE

    # load the prediction network
    predict_network_config = torch.load(args.prediction_parameter)
    prediction_net = create_net(args, predict_network_config)
    # A 'matlab_t7' key marks the old Neuroimage parameter files (.t7 from
    # MATLAB), which need prediction in the transformed space.  Loop-invariant.
    predict_transform_space = 'matlab_t7' in predict_network_config

    batch_size = args.batch_size
    patch_size = predict_network_config['patch_size']
    input_batch = torch.zeros(batch_size, 2, patch_size, patch_size,
                              patch_size).cuda()

    # use correction network if required
    if args.use_correction:
        correction_network_config = torch.load(args.correction_parameter)
        correction_net = create_net(args, correction_network_config)
        correct_transform_space = 'matlab_t7' in correction_network_config
    else:
        correction_net = None

    # start prediction
    for i in range(len(moving_images)):
        prefix = output_prefixes[i]
        out_dir = os.path.dirname(prefix)
        if out_dir:  # a bare prefix has no directory to create
            common.Mkdir_p(out_dir)

        if args.affine_align:
            # Affinely register moving and target to the atlas (ICBM152 per the
            # original comment) with NiftyReg; moving first, then target.
            aligned_paths = []
            for label, src in (("moving", moving_images[i]),
                               ("target", target_images[i])):
                res_path = prefix + label + "_affine.nii"
                status = call(["reg_aladin",
                               "-noSym", "-speeeeed", "-ref", args.atlas,
                               "-flo", src,
                               "-res", res_path,
                               "-aff", prefix + label + "_affine_transform.txt"])
                # presumably subprocess.call: non-zero exit status means failure
                if status:
                    raise RuntimeError("reg_aladin failed (status %s) for %s"
                                       % (status, src))
                aligned_paths.append(res_path)
            moving_image = common.LoadITKImage(aligned_paths[0], mType)
            target_image = common.LoadITKImage(aligned_paths[1], mType)
        else:
            moving_image = common.LoadITKImage(moving_images[i], mType)
            target_image = common.LoadITKImage(target_images[i], mType)

        # preprocessing of the image
        moving_image_np = preprocess_image(moving_image, args.histeq)
        target_image_np = preprocess_image(target_image, args.histeq)

        grid = moving_image.grid()
        moving_image_processed = common.ImFromNPArr(moving_image_np, mType)
        target_image_processed = common.ImFromNPArr(target_image_np, mType)
        # NOTE(review): grid came from moving_image, so the first call is a
        # no-op and the second puts target_image on the moving grid; the
        # *_processed images keep whatever grid ImFromNPArr gives -- confirm.
        moving_image.setGrid(grid)
        target_image.setGrid(grid)

        # run actual prediction
        prediction_result = util.predict_momentum(
            moving_image_np, target_image_np, input_batch, batch_size,
            patch_size, prediction_net, predict_transform_space)
        m0 = prediction_result['image_space']
        # convert to registration space and perform registration
        m0_reg = common.FieldFromNPArr(m0, mType)

        # perform correction
        if args.use_correction:
            registration_result = registration_methods.geodesic_shooting(
                moving_image_processed, target_image_processed, m0_reg,
                args.shoot_steps, mType, predict_network_config)
            target_inv_np = common.AsNPCopy(registration_result['I1_inv'])

            correction_result = util.predict_momentum(
                moving_image_np, target_inv_np, input_batch, batch_size,
                patch_size, correction_net, correct_transform_space)
            m0 += correction_result['image_space']
            m0_reg = common.FieldFromNPArr(m0, mType)

        # Final shooting uses the loaded images, not the preprocessed copies.
        registration_result = registration_methods.geodesic_shooting(
            moving_image, target_image, m0_reg, args.shoot_steps, mType,
            predict_network_config)

        write_result(registration_result, prefix)
Ejemplo n.º 6
0
def predict_image(args):
    """Estimate a mean deformation and its per-voxel variance by sampling.

    For every index ``i`` of ``args.moving_image``:

    1. create the directory part of ``args.output_prefix[i]`` (if any);
    2. optionally align moving and target to ``args.atlas`` with ``reg_aladin``
       (when ``args.affine_align``);
    3. load the images and preprocess them with ``preprocess_image``;
    4. for each of the ``args.samples`` samples:
       * predict a momentum with ``util.predict_momentum`` (presumably
         stochastic, e.g. dropout -- otherwise all samples coincide);
       * shoot that momentum on the preprocessed images;
       * collect the resulting ``'phiinv'`` field;
    5. shoot the mean momentum and write the outputs.

    Outputs written:

    * ``<prefix>I1.mhd`` -- the ``'I1'`` image from shooting the mean momentum;
    * ``<prefix>phiinv_mean.mhd`` -- the ``'phiinv'`` field from that shooting;
    * ``<prefix>phiinv_var.mhd`` -- the per-voxel population variance of the
      sampled ``'phiinv'`` fields.

    The variance is accumulated with Welford's algorithm in float64. The old
    float32 ``E[x^2] - E[x]^2`` form cancelled catastrophically for coordinate
    fields and could even turn negative.

    Raises
    ------
    ValueError
        If ``args.samples < 1`` (mean and variance would divide by zero).
    RuntimeError
        If ``reg_aladin`` reports failure.
    """
    if args.samples < 1:
        raise ValueError("args.samples must be >= 1, got %r" % (args.samples,))

    mType = ca.MEM_HOST if args.use_CPU_for_shooting else ca.MEM_DEVICE

    # load the prediction network
    predict_network_config = torch.load(args.prediction_parameter)
    prediction_net = create_net(args, predict_network_config)
    # 'matlab_t7' marks old MATLAB-converted parameter files.  Loop-invariant.
    predict_transform_space = 'matlab_t7' in predict_network_config

    batch_size = args.batch_size
    patch_size = predict_network_config['patch_size']
    input_batch = torch.zeros(batch_size, 2, patch_size, patch_size,
                              patch_size).cuda()

    def shoot(momentum_np, moving_proc, target_proc):
        """Shoot ``momentum_np`` on the preprocessed pair; return the result dict."""
        m0_field = common.FieldFromNPArr(momentum_np, mType)
        return registration_methods.geodesic_shooting(
            moving_proc, target_proc, m0_field, args.shoot_steps, mType,
            predict_network_config)

    # start prediction
    for i in range(len(args.moving_image)):
        prefix = args.output_prefix[i]
        out_dir = os.path.dirname(prefix)
        if out_dir:  # a bare prefix has no directory to create
            common.Mkdir_p(out_dir)

        if args.affine_align:
            # Affinely register moving and target to the atlas (ICBM152 per the
            # original comment) with NiftyReg; moving first, then target.
            aligned_paths = []
            for label, src in (("moving", args.moving_image[i]),
                               ("target", args.target_image[i])):
                res_path = prefix + label + "_affine.nii"
                status = call([
                    "reg_aladin", "-noSym", "-speeeeed", "-ref", args.atlas,
                    "-flo", src, "-res", res_path, "-aff",
                    prefix + label + "_affine_transform.txt"
                ])
                # presumably subprocess.call: non-zero exit status means failure
                if status:
                    raise RuntimeError("reg_aladin failed (status %s) for %s"
                                       % (status, src))
                aligned_paths.append(res_path)
            moving_image = common.LoadITKImage(aligned_paths[0], mType)
            target_image = common.LoadITKImage(aligned_paths[1], mType)
        else:
            moving_image = common.LoadITKImage(args.moving_image[i], mType)
            target_image = common.LoadITKImage(args.target_image[i], mType)

        # preprocessing of the image
        moving_image_np = preprocess_image(moving_image, args.histeq)
        target_image_np = preprocess_image(target_image, args.histeq)

        grid = moving_image.grid()
        moving_image_processed = common.ImFromNPArr(moving_image_np, mType)
        target_image_processed = common.ImFromNPArr(target_image_np, mType)
        moving_image.setGrid(grid)
        target_image.setGrid(grid)

        m0_sum = None
        phi_dtype = None
        phi_mean = None  # running mean of sampled phiinv (float64)
        phi_m2 = None    # running sum of squared deviations (float64)
        for sample_idx in range(args.samples):
            if sample_idx > 0:
                print(sample_idx)
            prediction_result = util.predict_momentum(
                moving_image_np, target_image_np, input_batch, batch_size,
                patch_size, prediction_net, predict_transform_space)
            m0_sample = prediction_result['image_space']
            phi = common.AsNPCopy(
                shoot(m0_sample, moving_image_processed,
                      target_image_processed)['phiinv'])

            if sample_idx == 0:
                # copy so later in-place sums never alias a prediction buffer
                m0_sum = np.array(m0_sample, copy=True)
                phi_dtype = phi.dtype
                phi_mean = phi.astype(np.float64)
                phi_m2 = np.zeros_like(phi_mean)
            else:
                m0_sum += m0_sample
                # Welford update
                delta = phi - phi_mean
                phi_mean += delta / (sample_idx + 1)
                phi_m2 += delta * (phi - phi_mean)

        m0_mean = np.divide(m0_sum, args.samples)
        registration_result = shoot(m0_mean, moving_image_processed,
                                    target_image_processed)
        phi_mean_field = registration_result['phiinv']
        # population variance (ddof=0), as before; clip rounding noise below 0
        phi_var = np.maximum(phi_m2 / args.samples, 0.0).astype(phi_dtype)

        # save result
        common.SaveITKImage(registration_result['I1'], prefix + "I1.mhd")
        common.SaveITKField(phi_mean_field, prefix + "phiinv_mean.mhd")
        common.SaveITKField(common.FieldFromNPArr(phi_var, mType),
                            prefix + "phiinv_var.mhd")