def predict_each_datapart(args, net, network_config, input_batch, datapart_idx, batch_size, patch_size, predict_transform_space):
    """Warp each target slice of one data part back and save the momentum residual.

    Loads the moving images, target images and momentum tensor stored at
    index ``datapart_idx`` of ``args.moving_image_dataset``,
    ``args.target_image_dataset`` and ``args.deformation_parameter``.  For
    every slice along the first dimension the momentum is predicted with
    ``net``, a geodesic shooting is run with it, and

    * the target slice is replaced by the ``'I1_inv'`` result of the shooting,
    * the predicted momentum (``'prediction_space'``) is subtracted from the
      loaded momentum.

    The updated tensors are written to
    ``args.warped_back_target_output[datapart_idx]`` and
    ``args.momentum_residual[datapart_idx]``.

    When ``predict_transform_space`` is true the slices are converted with
    ``util.convert_to_registration_space`` before prediction and the warped
    result is converted back with ``util.convert_to_predict_space``.
    """
    moving_stack = torch.load(args.moving_image_dataset[datapart_idx])
    target_stack = torch.load(args.target_image_dataset[datapart_idx])
    residual_momentum = torch.load(args.deformation_parameter[datapart_idx])

    slice_count = moving_stack.size()[0]
    for k in range(slice_count):
        print(k)
        src_np = moving_stack[k].numpy()
        tgt_np = target_stack[k].numpy()
        if predict_transform_space:
            src_np = util.convert_to_registration_space(src_np)
            tgt_np = util.convert_to_registration_space(tgt_np)

        prediction = util.predict_momentum(
            src_np, tgt_np, input_batch, batch_size, patch_size, net,
            predict_transform_space)

        # Everything handed to the shooting lives in device memory.
        momentum_field = common.FieldFromNPArr(prediction['image_space'],
                                               ca.MEM_DEVICE)
        src_ca = common.ImFromNPArr(src_np, ca.MEM_DEVICE)
        tgt_ca = common.ImFromNPArr(tgt_np, ca.MEM_DEVICE)
        shooting = registration_methods.geodesic_shooting(
            src_ca, tgt_ca, momentum_field, args.shoot_steps, ca.MEM_DEVICE,
            network_config)

        warped_back = common.AsNPCopy(shooting['I1_inv'])
        print(warped_back.shape)
        if predict_transform_space:
            warped_back = util.convert_to_predict_space(warped_back)
        print(warped_back.shape)

        # Overwrite the target slice in place with its warped-back version.
        target_stack[k] = torch.from_numpy(warped_back)

        # What remains of the momentum after removing the prediction.
        residual_momentum[k] = residual_momentum[k] - torch.from_numpy(
            prediction['prediction_space'])

    torch.save(target_stack, args.warped_back_target_output[datapart_idx])
    torch.save(residual_momentum, args.momentum_residual[datapart_idx])
Пример #2
0
 def test_ResampleInterp(self, disp=False):
     """Check that nearest-neighbour upsampling adds no new intensity values.

     A 10x10 image of random integers in ``[0, 5)`` is resampled to 50x50
     with ``INTERP_NN``; the number of distinct values in the result must
     equal the number of possible input levels.  ``disp`` is not used.
     """
     levelCount = 5
     # small integer-valued source image
     smallArr = (np.random.rand(10, 10) * levelCount).astype(int)
     smallIm = common.ImFromNPArr(smallArr)

     upsampled = Image3D(50, 50, 1)
     Resample(upsampled, smallIm, BACKGROUND_STRATEGY_CLAMP, INTERP_NN)

     distinctValues = np.unique(upsampled.asnp())
     self.assertEqual(len(distinctValues), levelCount)
Пример #3
0
 def randMaskSetUp(self):
     """Build a random binary mask on the host and a copy moved to the device.

     Each pixel is 1.0 where a uniform random draw exceeds 0.5 and 0.0
     otherwise.  The host image is stored in ``self.hRandMask`` and its
     device-side copy in ``self.dRandMask``; both use spacing ``self.imSp``.
     """
     noise = np.random.rand(self.sz[0], self.sz[1])
     binaryArr = np.where(noise > 0.5, 1.0, 0.0)

     self.hRandMask = common.ImFromNPArr(binaryArr, mType=MEM_HOST,
                                         sp=self.imSp)
     # Copy first, then move the copy so the host version stays untouched.
     self.dRandMask = self.hRandMask.copy()
     self.dRandMask.toType(MEM_DEVICE)
Пример #4
0
    def __init__(self, methodName='runTest'):
        """Set up tolerances, grid and host/device test images.

        ``self.cudaEnabled`` records whether at least one CUDA device is
        reported.  The remaining attributes (tolerances, sizes, grid and
        the blurred ellipse images on host and device) are only created
        when CUDA is available.
        """
        super(CpuGpuTestCase, self).__init__(methodName)

        self.cudaEnabled = (GetNumberOfCUDADevices() > 0)
        if not self.cudaEnabled:
            # Without a CUDA device none of the fixtures below are built.
            return

        # allowable average / max absolute difference
        self.AvEps = 1e-6
        self.MaxEps = 1e-4
        # image size and spacing
        self.sz = np.array([127, 119])
        self.sp = np.array([1.5, 2.1])
        # fluid parameters
        self.fluidParams = [1.0, 1.0, 0.0]

        self.vsz = np.append(self.sz, 2)
        self.imSz = Vec3Di(int(self.sz[0]), int(self.sz[1]), 1)
        self.imSp = Vec3Df(float(self.sp[0]), float(self.sp[1]), 1.0)
        # set up grid
        self.grid = GridInfo(self.imSz, self.imSp)

        # Two ellipses sharing a centre but with swapped axis ratios.
        self.I0Arr = common.DrawEllipse(self.sz, self.sz / 2,
                                        self.sz[0] / 4, self.sz[1] / 3)
        self.I1Arr = common.DrawEllipse(self.sz, self.sz / 2,
                                        self.sz[0] / 3, self.sz[1] / 4)

        blurSigma = 1.5
        self.I0Arr = common.GaussianBlur(self.I0Arr, blurSigma)
        self.I1Arr = common.GaussianBlur(self.I1Arr, blurSigma)

        # Same arrays wrapped once in host memory and once in device memory.
        hostKw = dict(mType=MEM_HOST, sp=self.imSp)
        deviceKw = dict(mType=MEM_DEVICE, sp=self.imSp)
        self.hI0Orig = common.ImFromNPArr(self.I0Arr, **hostKw)
        self.hI1Orig = common.ImFromNPArr(self.I1Arr, **hostKw)
        self.dI0Orig = common.ImFromNPArr(self.I0Arr, **deviceKw)
        self.dI1Orig = common.ImFromNPArr(self.I1Arr, **deviceKw)
def intensity_normalization_histeq(args):
    """Normalise (and optionally histogram-equalise) the output images in place.

    One image is processed per entry of ``args.input_images``, but the data
    is read from AND written back to ``args.output_images[i]``; the input
    list is only used for its length.  NaNs are replaced by 0, intensities
    are divided by the maximum, and when ``args.histeq`` is set the non-zero
    voxels are passed through ``exposure.equalize_hist``.  The original grid
    is restored on the result before saving.
    """
    for idx in range(len(args.input_images)):
        path = args.output_images[idx]
        source = common.LoadITKImage(path, ca.MEM_HOST)
        source_grid = source.grid()

        data = common.AsNPCopy(source)
        data[np.isnan(data)] = 0
        data /= np.amax(data)

        # perform histogram equalization if needed (zero voxels are left alone)
        if args.histeq:
            nonzero = data != 0
            data[nonzero] = exposure.equalize_hist(data[nonzero])

        normalized = common.ImFromNPArr(data, ca.MEM_HOST)
        normalized.setGrid(source_grid)
        common.SaveITKImage(normalized, path)
def main():
    """Deform one high-resolution region of a section, then its confocal slices.

    Command-line arguments (``sys.argv``):
        1: section number (``secNum``)
        2: ``mkyNum`` -- used as the ``M{0}`` part of every path
        3: region name

    Steps, as performed below:
        1. Load (or interactively ask for) the region bounding boxes stored
           in the section's ``*_regions.txt`` JSON file.
        2. Locate the region inside the low-resolution ``ssiSrc`` image and
           give the high-resolution region image a matching grid.
        3. Deform a sub-region-only image with ``bfi_df`` and, if not stored
           yet, pick the region's bounding box in the deformed image via
           ``pp.LandmarkPicker``; write the JSON file back.
        4. Crop the deformation, resample it onto a grid with the region
           image's spacing, apply it and save the result as nrrd and tiff.
        5. Apply the same upsampled deformation to every confocal slice of
           channels 0-3, saving each slice, a per-channel stack and (for
           channel 0) the stack grid.

    NOTE(review): this is Python 2 code -- it uses ``raw_input`` and indexes
    the result of ``map`` directly.
    """
    secNum = sys.argv[1]
    mkyNum = sys.argv[2]
    region = str(sys.argv[3])
    # channel = sys.argv[3]
    # NOTE(review): `ext` and `memT` are not referenced again in this function.
    ext = 'M{0}/section_{1}/{2}/'.format(mkyNum, secNum, region)
    ss_dir = '/home/sci/blakez/korenbergNAS/3D_database/Working/Microscopic/side_light_microscope/'
    conf_dir = '/home/sci/blakez/korenbergNAS/3D_database/Working/Microscopic/confocal/'
    memT = ca.MEM_DEVICE

    # Load the stored region description; if the file cannot be opened, ask
    # the user for the full image size and this region's bounding box.
    try:
        with open(
                ss_dir +
                'src_registration/M{0}/section_{1}/M{0}_01_section_{1}_regions.txt'
                .format(mkyNum, secNum), 'r') as f:
            region_dict = json.load(f)
            # Redundant: the `with` block already closes the file.
            f.close()
    except IOError:
        region_dict = {}
        region_dict[region] = {}
        region_dict['size'] = map(
            int,
            raw_input("What is the size of the full resolution image x,y? ").
            split(','))
        region_dict[region]['bbx'] = map(
            int,
            raw_input(
                "What are the x indicies of the bounding box (Matlab Format x_start,x_stop? "
            ).split(','))
        region_dict[region]['bby'] = map(
            int,
            raw_input(
                "What are the y indicies of the bounding box (Matlab Format y_start,y_stop? "
            ).split(','))

    # The file existed but has no entry for this region yet: ask for its
    # bounding box only (the image size is already stored under 'size').
    if region not in region_dict:
        region_dict[region] = {}
        region_dict[region]['bbx'] = map(
            int,
            raw_input(
                "What are the x indicies of the bounding box (Matlab Format x_start,x_stop? "
            ).split(','))
        region_dict[region]['bby'] = map(
            int,
            raw_input(
                "What are the y indicies of the bounding box (Matlab Format y_start,y_stop? "
            ).split(','))

    # High-resolution region image, low-resolution section image (both on the
    # host) and the deformation field (on the device).
    img_region = common.LoadITKImage(
        ss_dir +
        'src_registration/M{0}/section_{1}/M{0}_01_section_{1}_{2}.tiff'.
        format(mkyNum, secNum, region), ca.MEM_HOST)
    ssiSrc = common.LoadITKImage(
        ss_dir +
        'src_registration/M{0}/section_{1}/frag0/M{0}_01_ssi_section_{1}_frag0.nrrd'
        .format(mkyNum, secNum), ca.MEM_HOST)
    bfi_df = common.LoadITKField(
        ss_dir +
        'Blockface_registered/M{0}/section_{1}/frag0/M{0}_01_ssi_section_{1}_frag0_to_bfi_real.mha'
        .format(mkyNum, secNum), ca.MEM_DEVICE)

    # Figure out the same region in the low resolution image: There is a transpose from here to matlab so dimensions are flipped
    # ('bbx' scales into the y range and 'bby' into the x range below).
    low_sz = ssiSrc.size().tolist()
    yrng_raw = [(low_sz[1] * region_dict[region]['bbx'][0]) /
                np.float(region_dict['size'][0]),
                (low_sz[1] * region_dict[region]['bbx'][1]) /
                np.float(region_dict['size'][0])]
    xrng_raw = [(low_sz[0] * region_dict[region]['bby'][0]) /
                np.float(region_dict['size'][1]),
                (low_sz[0] * region_dict[region]['bby'][1]) /
                np.float(region_dict['size'][1])]
    # Round outwards (floor the start, ceil the stop) to whole pixels.
    yrng = [np.int(np.floor(yrng_raw[0])), np.int(np.ceil(yrng_raw[1]))]
    xrng = [np.int(np.floor(xrng_raw[0])), np.int(np.ceil(xrng_raw[1]))]
    low_sub = cc.SubVol(ssiSrc, xrng, yrng)

    # Figure out the grid for the sub region in relation to the sidescape
    originout = [
        ssiSrc.origin().x + ssiSrc.spacing().x * xrng[0],
        ssiSrc.origin().y + ssiSrc.spacing().y * yrng[0], 0
    ]
    # Spacing chosen so the high-res image spans the same physical extent as
    # the low-res sub-volume.
    spacingout = [
        (low_sub.size().x * ssiSrc.spacing().x) / (img_region.size().x),
        (low_sub.size().y * ssiSrc.spacing().y) / (img_region.size().y), 1
    ]

    gridout = cc.MakeGrid(img_region.size().tolist(), spacingout, originout)
    img_region.setGrid(gridout)

    # Full-size image that is zero everywhere except inside the sub region.
    only_sub = np.zeros(ssiSrc.size().tolist()[0:2])
    only_sub[xrng[0]:xrng[1], yrng[0]:yrng[1]] = np.squeeze(low_sub.asnp())
    only_sub = common.ImFromNPArr(only_sub)
    only_sub.setGrid(ssiSrc.grid())

    # Deform the sub-region-only image with bfi_df (result is on bfi_df's grid)
    only_sub.toType(ca.MEM_DEVICE)
    def_sub = ca.Image3D(bfi_df.grid(), bfi_df.memType())
    cc.ApplyHReal(def_sub, only_sub, bfi_df)
    def_sub.toType(ca.MEM_HOST)

    # Now have to find the bounding box in the deformation space (bfi space)
    # Only asked for once; afterwards the stored values are reused.
    if 'deformation_bbx' not in region_dict[region]:
        bb_def = np.squeeze(pp.LandmarkPicker([np.squeeze(def_sub.asnp())]))
        bb_def_y = [bb_def[0][0], bb_def[1][0]]
        bb_def_x = [bb_def[0][1], bb_def[1][1]]
        region_dict[region]['deformation_bbx'] = bb_def_x
        region_dict[region]['deformation_bby'] = bb_def_y

    # Persist the (possibly updated) region description.
    with open(
            ss_dir +
            'src_registration/M{0}/section_{1}/M{0}_01_section_{1}_regions.txt'
            .format(mkyNum, secNum), 'w') as f:
        json.dump(region_dict, f)
        # Redundant: the `with` block already closes the file.
        f.close()

    # Now need to extract the region and create a deformation and image that have the same resolution as the img_region
    deform_sub = cc.SubVol(bfi_df, region_dict[region]['deformation_bbx'],
                           region_dict[region]['deformation_bby'])

    # NOTE(review): presumably a debugging hook left in place -- confirm
    # whether it should stay.
    common.DebugHere()
    # Number of pixels needed to cover deform_sub's extent at img_region's spacing.
    sizeout = [
        int(
            np.ceil((deform_sub.size().x * deform_sub.spacing().x) /
                    img_region.spacing().x)),
        int(
            np.ceil((deform_sub.size().y * deform_sub.spacing().y) /
                    img_region.spacing().y)), 1
    ]

    region_grid = cc.MakeGrid(sizeout,
                              img_region.spacing().tolist(),
                              deform_sub.origin().tolist())

    def_im_region = ca.Image3D(region_grid, deform_sub.memType())
    up_deformation = ca.Field3D(region_grid, deform_sub.memType())

    # Upsample the cropped deformation onto region_grid, then warp the
    # high-resolution region image with it.
    img_region.toType(ca.MEM_DEVICE)
    cc.ResampleWorld(up_deformation, deform_sub,
                     ca.BACKGROUND_STRATEGY_PARTIAL_ZERO)
    cc.ApplyHReal(def_im_region, img_region, up_deformation)

    ss_out = ss_dir + 'Blockface_registered/M{0}/section_{1}/{2}/'.format(
        mkyNum, secNum, region)

    # os.mkdir creates only the last path component.
    if not pth.exists(pth.expanduser(ss_out)):
        os.mkdir(pth.expanduser(ss_out))

    # The deformed region is saved twice: once as nrrd, once as tiff.
    common.SaveITKImage(
        def_im_region,
        pth.expanduser(ss_out) +
        'M{0}_01_section_{1}_{2}_def_to_bfi.nrrd'.format(
            mkyNum, secNum, region))
    common.SaveITKImage(
        def_im_region,
        pth.expanduser(ss_out) +
        'M{0}_01_section_{1}_{2}_def_to_bfi.tiff'.format(
            mkyNum, secNum, region))
    # Drop the references that are no longer needed before the confocal pass.
    del img_region, def_im_region, ssiSrc, deform_sub

    # Now apply the same deformation to the confocal images
    conf_grid = cc.LoadGrid(
        conf_dir +
        'sidelight_registered/M{0}/section_{1}/{2}/affine_registration_grid.txt'
        .format(mkyNum, secNum, region))
    # NOTE(review): no directory creation for cf_out or its Ch* subfolders is
    # visible in this function.
    cf_out = conf_dir + 'blockface_registered/M{0}/section_{1}/{2}/'.format(
        mkyNum, secNum, region)
    # confocal.toType(ca.MEM_DEVICE)
    # def_conf = ca.Image3D(region_grid, deform_sub.memType())
    # cc.ApplyHReal(def_conf, confocal, up_deformation)

    for channel in range(0, 4):
        z_stack = []
        # The slice count is the number of tiff files in the channel folder.
        num_slices = len(
            glob.glob(conf_dir +
                      'sidelight_registered/M{0}/section_{1}/{3}/Ch{2}/*.tiff'.
                      format(mkyNum, secNum, channel, region)))
        for z in range(0, num_slices):
            # NOTE(review): the input file name hard-codes 'LGN_RHS'; `region`
            # is only used for the directory part of this path.
            src_im = common.LoadITKImage(
                conf_dir +
                'sidelight_registered/M{0}/section_{1}/{3}/Ch{2}/M{0}_01_section_{1}_LGN_RHS_Ch{2}_conf_aff_sidelight_z{4}.tiff'
                .format(mkyNum, secNum, channel, region,
                        str(z).zfill(2)))
            # Single-slice grid: x/y size from conf_grid with a z size of 1.
            src_im.setGrid(
                cc.MakeGrid(
                    ca.Vec3Di(conf_grid.size().x,
                              conf_grid.size().y, 1), conf_grid.spacing(),
                    conf_grid.origin()))
            src_im.toType(ca.MEM_DEVICE)
            def_im = ca.Image3D(region_grid, ca.MEM_DEVICE)
            cc.ApplyHReal(def_im, src_im, up_deformation)
            def_im.toType(ca.MEM_HOST)
            common.SaveITKImage(
                def_im, cf_out +
                'Ch{2}/M{0}_01_section_{1}_{3}_Ch{2}_conf_def_blockface_z{4}.tiff'
                .format(mkyNum, secNum, channel, region,
                        str(z).zfill(2)))
            # The first slice of each channel is additionally saved as nrrd.
            if z == 0:
                common.SaveITKImage(
                    def_im, cf_out +
                    'Ch{2}/M{0}_01_section_{1}_{3}_Ch{2}_conf_def_blockface_z{4}.nrrd'
                    .format(mkyNum, secNum, channel, region,
                            str(z).zfill(2)))
            z_stack.append(def_im)
            print('==> Done with Ch {0}: {1}/{2}'.format(
                channel, z, num_slices - 1))
        # Combine the deformed slices; in-plane spacing comes from
        # region_grid, z spacing from the confocal grid.
        stacked = cc.Imlist_to_Im(z_stack)
        stacked.setSpacing(
            ca.Vec3Df(region_grid.spacing().x,
                      region_grid.spacing().y,
                      conf_grid.spacing().z))
        common.SaveITKImage(
            stacked, cf_out +
            'Ch{2}/M{0}_01_section_{1}_{3}_Ch{2}_conf_def_blockface_stack.nrrd'
            .format(mkyNum, secNum, channel, region))
        # The stack grid is written once, taken from channel 0.  The format
        # string has no placeholders, so the .format() arguments are ignored.
        if channel == 0:
            cc.WriteGrid(
                stacked.grid(),
                cf_out + 'deformed_registration_grid.txt'.format(
                    mkyNum, secNum, region))
Пример #7
0
# curDir = [x for x in folderList if '__E'+str(test) in x]

# for a in folderList:
#     print a

# Load in the T2 image and downsample it to the resolution of the DW images
####
# One DICOM file per slice; sorting fixes the slice order along the 3rd axis.
T2_list = sorted(glob.glob(indir + 'T2DICOM_scan26/*'))
# The first file supplies the in-plane size and the pixel dtype of the volume.
refIm = dicom.read_file(T2_list[0])
PixelDims = (int(refIm.Rows), int(refIm.Columns), len(T2_list))
# PixelSpacing = (0.5,0.5,0.5)
T2Array = np.zeros(PixelDims, dtype=refIm.pixel_array.dtype)
# enumerate gives the slice index directly; the previous
# T2_list.index(filename) rescanned the list for every file.
for slice_idx, filename in enumerate(T2_list):
    ds = dicom.read_file(filename)
    T2Array[:, :, slice_idx] = ds.pixel_array
T2MRI = common.ImFromNPArr(T2Array)
# Same size, spacing set to 0.5.
T2MRI.setGrid(cc.MakeGrid(T2MRI.grid().size(), 0.5))
T2MRI.toType(ca.MEM_DEVICE)

# Swap the axes of the image so they align with the gradient directions
T2MRI = cc.SwapAxes(T2MRI, 0, 1)
T2MRI = cc.SwapAxes(T2MRI, 0, 2)
T2MRI = cc.FlipDim(T2MRI, 2)
# T2MRI = cc.FlipDim(T2MRI,2)

# Resample the T2 volume onto the 120x144x120 grid (spacing 0.5, origin 0).
DWIgrid = cc.MakeGrid([120, 144, 120], 0.5, [0, 0, 0])
down_T2 = ca.Image3D(DWIgrid, ca.MEM_DEVICE)
ca.Resample(down_T2, T2MRI)
####
#Display the list
Пример #8
0
    Run a test showing different interpolation methods used for
    upsampling and deformation.
    """

    import PyCA.Core as ca
    import PyCA.Common as common
    import PyCA.Display as display

    import numpy as np
    import matplotlib.pyplot as plt

    plt.ion()

    initMax = 5
    randArrSmall = (np.random.rand(10, 10) * initMax).astype(int)
    imSmall = common.ImFromNPArr(randArrSmall)
    imLargeNN = ca.Image3D(50, 50, 1)
    imLargeLinear = ca.Image3D(50, 50, 1)
    imLargeCubic = ca.Image3D(50, 50, 1)
    ca.Resample(imLargeNN, imSmall, ca.BACKGROUND_STRATEGY_CLAMP, ca.INTERP_NN)
    ca.Resample(imLargeLinear, imSmall, ca.BACKGROUND_STRATEGY_CLAMP,
                ca.INTERP_LINEAR)
    ca.Resample(imLargeCubic, imSmall, ca.BACKGROUND_STRATEGY_CLAMP,
                ca.INTERP_CUBIC)
    plt.figure('interp test')
    plt.subplot(2, 3, 1)
    display.DispImage(imLargeNN, 'NN', newFig=False)
    plt.subplot(2, 3, 2)
    display.DispImage(imLargeLinear, 'Linear', newFig=False)
    plt.subplot(2, 3, 3)
    display.DispImage(imLargeCubic, 'Cubic', newFig=False)
Пример #9
0
def RunTest():
    """Run ``RunROFTV`` on ``./Images/lena_orig.png`` and print the final energy.

    Uses device memory when at least one CUDA device is reported and host
    memory otherwise.  The optimisation starts from a copy of the loaded
    image and, when ``UseMask`` is set, uses a mask that is 1 inside a
    10-pixel border and 0 on the border itself.

    The printed line shows the last entries of ``energy[2]``, ``energy[0]``
    and ``energy[1]``, labelled total, image term and TV term respectively.
    """
    # number of iterations
    nIters = 2000
    #nIters = 0

    disp = True
    dispEvery = 1000

    if GetNumberOfCUDADevices() > 0:
        mType = MEM_DEVICE
    else:
        # print() with a single argument behaves the same on Python 2 and 3.
        print("No CUDA devices found, running on CPU")
        mType = MEM_HOST

    # data fidelity modifier
    DataFidC = 20.0
    # TV modifier
    TVC = 1.0

    TVPow = 1.0

    UseMask = True

    # regularization term to avoid zero denominator
    Beta = 1e-5

    stepI = 0.001

    imagedir = './Images/'

    #
    # Run lena images
    #

    Data = common.LoadPNGImage(imagedir + 'lena_orig.png', mType)

    imSz = Data.size()
    sz = imSz.tolist()[0:2]

    # Start from the data itself.  (An unreachable alternative used
    # common.RandImage(nSig=1.0, gSig=5.0, mType=mType) as the initial image.)
    I0 = Data.copy()

    Mask = None
    if UseMask:
        bdr = 10
        MaskArr = np.zeros(sz)
        MaskArr[bdr:-bdr, bdr:-bdr] = 1.0
        Mask = common.ImFromNPArr(MaskArr, mType)

    I, energy = RunROFTV(Data=Data,
                         I0=I0,
                         DataFidC=DataFidC,
                         TVC=TVC,
                         TVPow=TVPow,
                         stepI=stepI,
                         Beta=Beta,
                         nIters=nIters,
                         dispEvery=dispEvery,
                         disp=disp,
                         Mask=Mask)

    print('final energy: {ttl:n} = {im:n} + {tv:n}'.format(
        ttl=energy[2][-1], im=energy[0][-1], tv=energy[1][-1]))
Пример #10
0
def predict_image(args, moving_images, target_images, output_prefixes):
    """Predict a momentum for each image pair, shoot it and write the result.

    For every index ``i`` the moving/target images are (optionally) affinely
    aligned to ``args.atlas`` with ``reg_aladin``, preprocessed, and passed
    to the prediction network.  When ``args.use_correction`` is set, a first
    geodesic shooting on the preprocessed images yields ``'I1_inv'``, which
    the correction network uses to predict an additional momentum that is
    added to the first one.  The final shooting is run on the loaded
    (non-preprocessed) images and handed to ``write_result`` together with
    ``output_prefixes[i]``.
    """
    mType = ca.MEM_HOST if args.use_CPU_for_shooting else ca.MEM_DEVICE

    # load the prediction network
    predict_network_config = torch.load(args.prediction_parameter)
    prediction_net = create_net(args, predict_network_config)

    batch_size = args.batch_size
    patch_size = predict_network_config['patch_size']
    input_batch = torch.zeros(batch_size, 2, patch_size, patch_size, patch_size).cuda()

    # use correction network if required
    correction_net = None
    if args.use_correction:
        correction_network_config = torch.load(args.correction_parameter)
        correction_net = create_net(args, correction_network_config)

    # start prediction
    for idx in range(len(moving_images)):
        prefix = output_prefixes[idx]
        common.Mkdir_p(os.path.dirname(prefix))

        if args.affine_align:
            # Perform affine registration to both moving and target image to the ICBM152 atlas space.
            # Registration is done using Niftireg.
            affine_jobs = (
                (moving_images, "moving_affine.nii", 'moving_affine_transform.txt'),
                (target_images, "target_affine.nii", 'target_affine_transform.txt'),
            )
            for sources, result_name, transform_name in affine_jobs:
                call(["reg_aladin",
                      "-noSym", "-speeeeed", "-ref", args.atlas,
                      "-flo", sources[idx],
                      "-res", prefix + result_name,
                      "-aff", prefix + transform_name])

            moving_path = prefix + "moving_affine.nii"
            target_path = prefix + "target_affine.nii"
        else:
            moving_path = moving_images[idx]
            target_path = target_images[idx]
        moving_image = common.LoadITKImage(moving_path, mType)
        target_image = common.LoadITKImage(target_path, mType)

        # preprocessing of the image
        moving_image_np = preprocess_image(moving_image, args.histeq)
        target_image_np = preprocess_image(target_image, args.histeq)

        grid = moving_image.grid()
        moving_image_processed = common.ImFromNPArr(moving_image_np, mType)
        target_image_processed = common.ImFromNPArr(target_image_np, mType)
        # Both loaded images get the moving image's grid.
        moving_image.setGrid(grid)
        target_image.setGrid(grid)

        # Indicating whether we are using the old parameter files for the Neuroimage experiments (use .t7 files from matlab .h5 format)
        predict_transform_space = 'matlab_t7' in predict_network_config

        # run actual prediction
        prediction_result = util.predict_momentum(
            moving_image_np, target_image_np, input_batch, batch_size,
            patch_size, prediction_net, predict_transform_space)
        m0 = prediction_result['image_space']
        # convert to registration space and perform registration
        m0_reg = common.FieldFromNPArr(m0, mType)

        # perform correction
        if args.use_correction:
            first_pass = registration_methods.geodesic_shooting(
                moving_image_processed, target_image_processed, m0_reg,
                args.shoot_steps, mType, predict_network_config)
            target_inv_np = common.AsNPCopy(first_pass['I1_inv'])

            correct_transform_space = 'matlab_t7' in correction_network_config
            correction_result = util.predict_momentum(
                moving_image_np, target_inv_np, input_batch, batch_size,
                patch_size, correction_net, correct_transform_space)
            # Accumulate the correction into the predicted momentum in place.
            m0 += correction_result['image_space']
            m0_reg = common.FieldFromNPArr(m0, mType)

        registration_result = registration_methods.geodesic_shooting(
            moving_image, target_image, m0_reg, args.shoot_steps, mType,
            predict_network_config)

        write_result(registration_result, prefix)
Пример #11
0
def predict_image(args):
    """Sample several momentum predictions per image pair and save mean/variance.

    For every entry of ``args.moving_image`` the images are (optionally)
    affinely aligned to ``args.atlas`` with ``reg_aladin`` and preprocessed.
    The prediction network is then evaluated ``args.samples`` times; each
    predicted momentum is shot and its ``'phiinv'`` accumulated.  Saved
    outputs (all prefixed with ``args.output_prefix[i]``):

    * ``I1.mhd``          -- ``'I1'`` from shooting the mean momentum
    * ``phiinv_mean.mhd`` -- ``'phiinv'`` from shooting the mean momentum
    * ``phiinv_var.mhd``  -- mean of squared ``phiinv`` minus squared mean
    """
    mType = ca.MEM_HOST if args.use_CPU_for_shooting else ca.MEM_DEVICE

    # load the prediction network
    predict_network_config = torch.load(args.prediction_parameter)
    prediction_net = create_net(args, predict_network_config)

    batch_size = args.batch_size
    patch_size = predict_network_config['patch_size']
    input_batch = torch.zeros(batch_size, 2, patch_size, patch_size,
                              patch_size).cuda()

    # start prediction
    for idx in range(len(args.moving_image)):
        prefix = args.output_prefix[idx]
        common.Mkdir_p(os.path.dirname(prefix))

        if args.affine_align:
            # Perform affine registration to both moving and target image to the ICBM152 atlas space.
            # Registration is done using Niftireg.
            call([
                "reg_aladin", "-noSym", "-speeeeed", "-ref", args.atlas,
                "-flo", args.moving_image[idx],
                "-res", prefix + "moving_affine.nii",
                "-aff", prefix + 'moving_affine_transform.txt'
            ])
            call([
                "reg_aladin", "-noSym", "-speeeeed", "-ref", args.atlas,
                "-flo", args.target_image[idx],
                "-res", prefix + "target_affine.nii",
                "-aff", prefix + 'target_affine_transform.txt'
            ])
            moving_path = prefix + "moving_affine.nii"
            target_path = prefix + "target_affine.nii"
        else:
            moving_path = args.moving_image[idx]
            target_path = args.target_image[idx]
        moving_image = common.LoadITKImage(moving_path, mType)
        target_image = common.LoadITKImage(target_path, mType)

        # preprocessing of the image
        moving_image_np = preprocess_image(moving_image, args.histeq)
        target_image_np = preprocess_image(target_image, args.histeq)

        grid = moving_image.grid()
        moving_image_processed = common.ImFromNPArr(moving_image_np, mType)
        target_image_processed = common.ImFromNPArr(target_image_np, mType)
        # Both loaded images get the moving image's grid.
        moving_image.setGrid(grid)
        target_image.setGrid(grid)

        predict_transform_space = 'matlab_t7' in predict_network_config

        def sample_momentum():
            # One network evaluation; returns the image-space momentum array.
            result = util.predict_momentum(
                moving_image_np, target_image_np, input_batch, batch_size,
                patch_size, prediction_net, predict_transform_space)
            return result['image_space']

        def shoot(momentum_np):
            # Geodesic shooting of the preprocessed pair with this momentum.
            field = common.FieldFromNPArr(momentum_np, mType)
            return registration_methods.geodesic_shooting(
                moving_image_processed, target_image_processed, field,
                args.shoot_steps, mType, predict_network_config)

        # First sample initialises the running sums.
        m0 = sample_momentum()
        phi = common.AsNPCopy(shoot(m0)['phiinv'])
        phi_square = np.power(phi, 2)

        # Remaining samples are accumulated into the sums.
        for sample_iter in range(1, args.samples):
            print(sample_iter)
            sampled = sample_momentum()
            m0 += sampled
            sample_phi = common.AsNPCopy(shoot(sampled)['phiinv'])
            phi += sample_phi
            phi_square += np.power(sample_phi, 2)

        # Shoot the averaged momentum for the reported image and map.
        registration_result = shoot(np.divide(m0, args.samples))
        phi_mean = registration_result['phiinv']
        # Var = E[phi^2] - (E[phi])^2
        phi_var = np.divide(phi_square, args.samples) - np.power(
            np.divide(phi, args.samples), 2)

        # save result
        common.SaveITKImage(registration_result['I1'], prefix + "I1.mhd")
        common.SaveITKField(phi_mean, prefix + "phiinv_mean.mhd")
        common.SaveITKField(common.FieldFromNPArr(phi_var, mType),
                            prefix + "phiinv_var.mhd")
def Fragmenter():
    """Interactively cut one fragment out of a section and register it as config.

    Relies on the module-level names ``frgSpec``, ``secOb`` and ``frgNum``.

    Steps:
        1. Load the frag0 configuration of the section and its source/mask
           images (host memory).
        2. Label the connected components of both masks and let the user
           pick one seed point per image; the component under each seed
           becomes that image's fragment mask.
        3. Mask the images, let the user pick bounding-box corners, and crop
           images and masks to that box.
        4. Create the ``frag{frgNum}`` output folders if needed, save the
           cropped images/masks, and pass the resulting configuration to
           ``updateFragOb``.

    Returns:
        None
    """
    tmpOb = Config.Load(
        frgSpec,
        pth.expanduser(
            '~/korenbergNAS/3D_database/Working/configuration_files/SidescapeRelateBlockface/M{0}/section_{1}/section_{1}_frag0.yaml'
            .format(secOb.mkyNum, secOb.secNum)))
    dictBuild = {}
    # Load in the whole image so that the fragment can be cropped out
    ssiSrc, bfiSrc, ssiMsk, bfiMsk = Loader(tmpOb, ca.MEM_HOST)

    # Because some of the functions only work with gray images
    bfiGry = ca.Image3D(bfiSrc.grid(), bfiSrc.memType())
    ca.Copy(bfiGry, bfiSrc, 1)

    lblSsi, _ = ndimage.label(np.squeeze(ssiMsk.asnp()) > 0)
    lblBfi, _ = ndimage.label(np.squeeze(bfiMsk.asnp()) > 0)

    # Row 0 of seedPt is the point picked on the BFI labels, row 1 on the SSI.
    seedPt = np.squeeze(pp.LandmarkPicker([lblBfi, lblSsi]))
    bfiSeedLbl = lblBfi[seedPt[0, 0], seedPt[0, 1]]
    ssiSeedLbl = lblSsi[seedPt[1, 0], seedPt[1, 1]]
    # Cast the comparison result (the mask), not the label: casting the
    # label itself to int8 wraps for labels above 127 and would then select
    # the wrong component.
    subMskBfi = common.ImFromNPArr((lblBfi == bfiSeedLbl).astype('int8'),
                                   sp=bfiSrc.spacing(),
                                   orig=bfiSrc.origin())
    subMskSsi = common.ImFromNPArr((lblSsi == ssiSeedLbl).astype('int8'),
                                   sp=ssiSrc.spacing(),
                                   orig=ssiSrc.origin())

    bfiGry *= subMskBfi
    bfiSrc *= subMskBfi
    ssiSrc *= subMskSsi
    # Pick points that are the bounding box of the desired subvolume
    corners = np.array(
        pp.LandmarkPicker(
            [np.squeeze(bfiGry.asnp()),
             np.squeeze(ssiSrc.asnp())]))
    bfiCds = corners[:, 0]
    ssiCds = corners[:, 1]

    bfiXrng = [bfiCds[0, 0], bfiCds[1, 0]]
    bfiYrng = [bfiCds[0, 1], bfiCds[1, 1]]
    ssiXrng = [ssiCds[0, 0], ssiCds[1, 0]]
    ssiYrng = [ssiCds[0, 1], ssiCds[1, 1]]

    # Extract the region from the source images
    bfiRgn = cc.SubVol(bfiSrc, xrng=bfiXrng, yrng=bfiYrng)
    ssiRgn = cc.SubVol(ssiSrc, xrng=ssiXrng, yrng=ssiYrng)

    # Extract the region from the mask images
    rgnMskSsi = cc.SubVol(subMskSsi, xrng=ssiXrng, yrng=ssiYrng)
    rgnMskBfi = cc.SubVol(subMskBfi, xrng=bfiXrng, yrng=bfiYrng)

    # Corner coordinates stored as fractions of the full image size.
    dictBuild['rgnBfi'] = np.divide(
        bfiCds, np.array(bfiSrc.size().tolist()[0:2], 'float')).tolist()
    dictBuild['rgnSsi'] = np.divide(
        ssiCds, np.array(ssiSrc.size().tolist()[0:2], 'float')).tolist()

    # Make sure the fragment folder exists for source and mask files of both
    # modalities.
    for basePath in (secOb.ssiSrcPath, secOb.bfiSrcPath, secOb.ssiMskPath,
                     secOb.bfiMskPath):
        fragDir = pth.expanduser(basePath + 'frag{0}'.format(frgNum))
        if not pth.exists(fragDir):
            os.mkdir(fragDir)

    # NOTE(review): the folder uses frgNum but the file names hard-code
    # 'frag1' -- confirm whether that is intended.
    dictBuild[
        'ssiSrcName'] = 'frag{0}/M{1}_01_ssi_section_{2}_frag1.tif'.format(
            frgNum, secOb.mkyNum, secOb.secNum)
    dictBuild[
        'bfiSrcName'] = 'frag{0}/M{1}_01_bfi_section_{2}_frag1.mha'.format(
            frgNum, secOb.mkyNum, secOb.secNum)
    dictBuild[
        'ssiMskName'] = 'frag{0}/M{1}_01_ssi_section_{2}_frag1_mask.tif'.format(
            frgNum, secOb.mkyNum, secOb.secNum)
    dictBuild[
        'bfiMskName'] = 'frag{0}/M{1}_01_bfi_section_{2}_frag1_mask.tif'.format(
            frgNum, secOb.mkyNum, secOb.secNum)

    # Write out the masked and cropped images so that they can be loaded from the YAML file
    # The BFI region needs to be saved as color and mha format so that the grid information is carried over.
    common.SaveITKImage(
        ssiRgn, pth.expanduser(secOb.ssiSrcPath + dictBuild['ssiSrcName']))
    cc.WriteColorMHA(
        bfiRgn, pth.expanduser(secOb.bfiSrcPath + dictBuild['bfiSrcName']))
    common.SaveITKImage(
        rgnMskSsi, pth.expanduser(secOb.ssiMskPath + dictBuild['ssiMskName']))
    common.SaveITKImage(
        rgnMskBfi, pth.expanduser(secOb.bfiMskPath + dictBuild['bfiMskName']))

    frgOb = Config.MkConfig(dictBuild, frgSpec)
    updateFragOb(frgOb)

    return None