예제 #1
0
def create_dLDP_illustration():
    """Plot the (d)LDP encodings of one SHG/TPEF image pair as a 2x3 figure.

    Skips the first ``skip`` pairs yielded by ``getNextMPMPair`` (called with
    ``rotate=True, quantize=32``) and then processes ``limit`` pair(s). For
    each processed pair a ``Register`` is set up and initialised only so that
    its last distance object can build the pattern codes of both images.

    Row one shows the reference (SHG) image and two 8-bit groups of its code
    (panels a-c). Row two shows the same for the floating (TPEF) image
    (panels d-f).
    """
    import os  # the sibling examples use os the same way; keeps this block self-contained

    id_trans = transforms.CompositeTransform(
        2, [transforms.ScalingTransform(2, uniform=True),
            transforms.Rigid2DTransform()])

    skip = 24
    limit = 1
    norm = True
    blur = 3.0
    modelname = 'dLDP'  # Choose from ['alphaAMD', 'MI', 'MSE', 'dLDP']
    modelparams = {'interpolation': 'nearest'}
    optname = 'gridsearch'  # Choose from ['gd', 'adam', 'gridsearch', 'bfgs']
    optparams = {
        'bounds': gridBounds(id_trans, 0, 0, 0),
        'steps': [1, 1, 1, 1]
    }
    output_dir = "./tmp/"
    # 8-bit groups of the pattern code to display. Original author's notes:
    # group 0 = 0 degrees for LDP (group 1 = 0/90 degrees for dLDP);
    # group 2 = 90 degrees for LDP (group 4 = 45/135 degrees for dLDP).
    bit_groups = (0, 2)

    def bits_as_image(dist, codes, group):
        """Turn bits [8*group, 8*group + 8) of ``codes`` (last axis) into an image."""
        return dist.dLDP_as_image(codes[..., 8 * group:8 * (group + 1)])

    def show_panel(position, image, label):
        """Draw ``image`` in greyscale, unframed, at ``position`` of the 2x3 grid."""
        plt.subplot(2, 3, position)
        plt.imshow(image, cmap='gray', vmin=0, vmax=1)
        plt.title(label, loc='left')
        plt.axis('off')

    for slide, roi_idx, ref_im, flo_im in getNextMPMPair(verbose=True,
                                                         server=False,
                                                         norm=norm,
                                                         blur=blur,
                                                         rotate=True,
                                                         quantize=32):
        if skip > 0:
            skip -= 1
            continue
        limit -= 1
        if limit < 0:
            break

        reg = Register(2)
        reg.set_model(modelname, **modelparams)
        reg.set_optimizer(optname, **optparams)
        reg.set_image_data(ref_im,
                           flo_im,
                           ref_mask=np.ones(ref_im.shape, 'bool'),
                           flo_mask=np.ones(flo_im.shape, 'bool'),
                           ref_weights=None,
                           flo_weights=None)
        reg.add_pyramid_levels(factors=[1], sigmas=[0.0])
        reg.add_initial_transform(id_trans)

        # The other examples create this directory before handing it to
        # Register.initialize; do the same so a fresh checkout does not fail.
        os.makedirs(output_dir, exist_ok=True)
        # Start the pre-processing
        reg.initialize(output_dir)

        # The initialization sets up the distance measure; reuse it to build
        # the pattern codes of both images (the returned masks are not used).
        dist = reg.distances[-1]
        shg_dLDP, _ = dist.create_LDP(ref_im)
        tpef_dLDP, _ = dist.create_LDP(flo_im)

        fig = plt.figure(figsize=(15, 10))
        rows = (
            (ref_im, shg_dLDP, ('a)', 'b)', 'c)')),   # reference (SHG) image
            (flo_im, tpef_dLDP, ('d)', 'e)', 'f)')),  # floating (TPEF) image
        )
        for row, (image, codes, labels) in enumerate(rows):
            first = 3 * row + 1  # subplot index of the first panel in this row
            show_panel(first, image, labels[0])
            for offset, group in enumerate(bit_groups, start=1):
                show_panel(first + offset, bits_as_image(dist, codes, group),
                           labels[offset])

        fig.subplots_adjust(hspace=0.1, wspace=0.05)
        plt.show()
예제 #2
0
def create_surface(server=False):
    """Evaluate and plot a scale/rotation cost surface for one MPM image pair.

    Pairs from ``getNextMPMPair`` are skipped while ``skip`` is positive; then
    ``limit`` pair(s) are processed. For each one a grid search is run over
    the composite scale + rigid transform (41 x 41 steps over the first two
    parameters, a single step over the others). The recorded values are then
    shown as a heat map with the minimum annotated.

    Args:
        server: Forwarded to ``getNextMPMPair``.
    """
    init_t = transforms.CompositeTransform(
        2,
        [transforms.ScalingTransform(2, uniform=True),
         transforms.Rigid2DTransform()])
    # Alternative start: a Rigid2DTransform with set_params([0.35, 0.5, 0.5])
    # composed after the scaling transform.

    # ---- Running parameters to update each time ----
    modelname = 'dLDP'  # one of 'alphaAMD', 'MI', 'MSE', 'dLDP'
    # Other model options: {} ; {'alpha_levels': 7, 'symmetric_measure': True,
    # 'squared_measure': False} ; {'mutual_info_fun': 'norm'}
    modelparams = {'version': 'dLDP_48'}  # dLDP versions: dLDP_8, dLDP_48, LDP

    optname = 'gridsearch'  # one of 'gd', 'adam', 'gridsearch', 'bfgs'
    search_bounds = gridBounds(init_t, 0.05, 5, 0)
    optparams = {'bounds': search_bounds, 'steps': [41, 41, 1, 1]}

    norm, blur = True, 0.0
    skip, limit = 24, 1
    folder = local_sr_folder  # only needed by the getNextSRPair alternative
    # ---- End running parameters ----

    # Alternative pair source:
    # getNextSRPair(folder, order=True, verbose=True, server=server, norm=norm, blur=blur)
    pairs = getNextMPMPair(verbose=True, server=server, norm=norm, blur=blur)
    for slide, roi_idx, ref_im, flo_im in pairs:
        if skip > 0:
            print("Skipping %s_%s" % (slide, roi_idx))
            skip -= 1
            continue

        limit -= 1
        if limit < 0:
            break

        print("Creating %s surface for sample %s, region %s" %
              (modelname, slide, roi_idx))

        # Model and optimizer with the basic parameters chosen above.
        reg = Register(2)
        reg.set_model(modelname, **modelparams)
        reg.set_optimizer(optname, **optparams)

        reg.set_image_data(ref_im, flo_im,
                           ref_mask=np.ones(ref_im.shape, 'bool'),
                           flo_mask=np.ones(flo_im.shape, 'bool'),
                           ref_weights=None, flo_weights=None)
        reg.add_pyramid_levels(factors=[1], sigmas=[0.0])
        reg.set_sampling_fraction(0.25)
        reg.add_initial_transform(init_t,
                                  param_scaling=[1 / 500., 1 / 500., 1., 1.])
        reg.set_report_freq(500)

        # Make sure the output directory exists.
        out_dir = os.path.dirname("./tmp/")
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)

        reg.initialize("./tmp/")  # pre-processing
        reg.run()                 # the grid search itself

        _show_scale_rotation_surface(reg.get_value_history(0, 0), optparams,
                                     modelname, blur)


def _show_scale_rotation_surface(history, optparams, modelname, blur):
    """Show a grid-search value ``history`` as a heat map over grid axes 0 and 1.

    ``history`` is turned into an array and resized in place to the step
    counts of axes 0 (rows, scale) and 1 (columns, rotation). The rotation
    grid is converted from radians to degrees for display, and the location
    of the minimum value is annotated.
    """
    values = np.array(history)
    row_axis, col_axis = 0, 1
    values.resize([optparams['steps'][row_axis], optparams['steps'][col_axis]])

    grid = []
    for bounds, count in zip(optparams['bounds'], optparams['steps']):
        grid.append(np.linspace(*bounds, count))

    # Convert radians to degrees for display
    grid[1] *= 180 / np.pi

    y_pts, x_pts = grid[row_axis], grid[col_axis]
    # Aspect ratio that makes the plotted image square
    aspect = (x_pts[-1] - x_pts[0]) / (y_pts[-1] - y_pts[0])

    best_row, best_col = np.unravel_index(np.argmin(values), values.shape)
    min_xy = [x_pts[best_col], y_pts[best_row]]

    plt.figure(figsize=(10, 10))
    ax = plt.gca()
    extent = [x_pts[0], x_pts[-1], y_pts[0], y_pts[-1]]
    im = ax.imshow(values, extent=extent, origin='lower', aspect=aspect,
                   cmap='inferno_r')
    plt.colorbar(im, fraction=0.046, pad=0.04)
    smoothed = '' if blur else ' not'
    plt.title(f"{modelname} surface as scale and rotation change\n"
              f"(translation fixed at zero, image{smoothed} smoothed)")
    plt.annotate('Min=%.4f at (%.1f, %.2f)' % (np.min(values), *min_xy),
                 xy=min_xy,
                 xycoords='data',
                 xytext=(0.6, 0.04),
                 textcoords='figure fraction',
                 arrowprops=dict(arrowstyle="->"))

    axis_labels = ('Scale (%)', 'Rotation (degrees)', 'x translation (px)',
                   'y translation (px)')
    plt.xlabel(axis_labels[col_axis])
    plt.ylabel(axis_labels[row_axis])
    plt.show()
예제 #3
0
def register_pairs(server=False):
    """Register MPM image pairs after applying known random transforms.

    The random transforms are generated reproducibly (seed 999), one per
    pair. Pairs come from ``getNextMPMPair``; the first ``skip`` are skipped
    and at most ``limit`` are processed. For each processed pair:

    - the reference image is warped by its random transform onto a canvas
      1.5x its size;
    - the floating image is registered to it with the configured model and
      optimizer;
    - the recovered transform is compared with the true one.

    One result row per pair is appended to the CSV file ``outfile`` and also
    collected in memory.

    Args:
        server: Use the server folders (and server data access in
            ``getNextMPMPair``) instead of the local ones.

    Returns:
        list of tuple: The rows written to the CSV file, in order. Each row is
        ``(slide, roi_idx, *true_params, true_value, *estimated_params,
        estimated_value, err, success_flag, timestamp)``.
    """
    results = []
    id_trans = transforms.CompositeTransform(
        2, [transforms.ScalingTransform(2, uniform=True),
            transforms.Rigid2DTransform()])

    ##### Running parameters to update each time #####
    modelname = 'dLDP'  # ['alphaAMD', 'MI', 'MSE', 'dLDP']
    # modelparams = {}  # mse
    modelparams = {'version': 'dLDP_48'}  # dLDP versions: dLDP_8, dLDP_48, LDP
    # modelparams = {'alpha_levels': 7, 'symmetric_measure': True, 'squared_measure': False}
    # modelparams = {'mutual_info_fun': 'norm'}

    optname = 'bfgs'  # ['gd', 'adam', 'gridsearch', 'bfgs']
    optparams = {'gradient_magnitude_threshold': 1e-9, 'epsilon': 0.05}  # bfgs
    # optparams = {'bounds': gridBounds(id_trans, 0.05, 5, 10), 'steps': 11}  # gridsearch
    # optparams = {'gradient_magnitude_threshold': 1e-6}  # adam, gd

    norm = True
    blur = 3.0
    skip = 0  # manual way to skip pairs that have already been processed
    results_file = 'PartIII_test5.10_48bit.csv'
    limit = 25 - skip
    show_images = False  # plot the images before and after registration
    ##### End running parameters #####

    # For testing, make sure we get the same transforms each time.
    np.random.seed(999)
    rndTransforms = [
        getRandomTransform(maxRot=5, maxScale=1.05, maxTrans=10)
        for _ in range(limit + skip + 1)
    ]
    # Reverse the list as we pop transforms from the far end: pair k always
    # gets the k-th generated transform, even if `limit` is changed later.
    rndTransforms.reverse()

    if server:
        folder = server_sr_folder
        outfile = server_separate_mpm_folder + results_file
    else:
        folder = local_sr_folder
        outfile = local_separate_mpm_folder + results_file

    # OPTION: start from a grid maximum already found (see the commented
    # initial-transform block further down):
    # grid_params = get_MI_gridmax(local_separate_mpm_folder + 'PartI_test4.csv')

    # Alternative pair source (this is what `folder` is for):
    # for slide, roi_idx, ref_im, flo_im in getNextSRPair(folder, order=True, verbose=True, server=server, norm=norm, blur=blur):
    for slide, roi_idx, ref_im, flo_im in getNextMPMPair(verbose=True,
                                                         server=server,
                                                         norm=norm,
                                                         blur=blur):
        # Take the next random transform. This is done before the skip check
        # so every pair receives the same transform regardless of `skip`.
        rndTrans = rndTransforms.pop()

        if skip > 0:
            print("Skipping %s_%s" % (slide, roi_idx))
            skip -= 1
            continue
        limit -= 1
        if limit < 0:
            break

        # Shape of the reference (SHG) image before the random transform is
        # applied; it sizes the enlarged canvas used for the image and the mask.
        orig_shape = ref_im.shape
        canvas_shape = np.rint(np.array(orig_shape) * 1.5).astype('int')

        print("Random transform applied: %r" % rndTrans.get_params())

        # Apply the transform to the reference image, increasing the canvas
        # size to avoid cutting off parts.
        ref_im = centreAndApplyTransform(ref_im, rndTrans, canvas_shape)

        print("Aligning images for sample %s, region %s" % (slide, roi_idx))
        if show_images:
            plt.figure(figsize=(12, 6))
            plt.subplot(121)
            plt.imshow(ref_im, cmap='gray', vmin=0, vmax=1)
            plt.title("Reference image")
            plt.subplot(122)
            plt.imshow(flo_im, cmap='gray', vmin=0, vmax=1)
            plt.title("Floating image")
            plt.show()

        # Choose a model and an optimizer, with the basic parameters above.
        reg = Register(2)
        reg.set_model(modelname, **modelparams)
        reg.set_optimizer(optname, **optparams)

        # Since we have warped the original reference image, create a mask so
        # that only the relevant pixels are considered. Use the same warping
        # function and canvas as above.
        ref_mask = centreAndApplyTransform(np.ones(orig_shape, 'bool'),
                                           rndTrans, canvas_shape)

        reg.set_image_data(ref_im,
                           flo_im,
                           ref_mask=ref_mask,
                           flo_mask=np.ones(flo_im.shape, 'bool'),
                           ref_weights=None,
                           flo_weights=None)

        # Add pyramid levels
        if modelname.lower() == 'alphaamd':
            # Learning rate / step lengths [[start1, end1], [start2, end2], ...]
            # (one row per pyramid level)
            step_lengths = np.array([[1., 1.], [1., 0.5], [0.5, 0.1]])
            reg.set_step_lengths(step_lengths)
            reg.add_pyramid_levels(factors=[4, 2, 1], sigmas=[5.0, 3.0, 0.0])
            reg.set_sampling_fraction(0.5)  # very patchy with 0.1, also tried 0.25
            reg.set_iterations(5000)
        else:
            # No evidence so far that pyramid levels lead the search towards
            # the MI maximum, so use a single full-resolution level.
            # Alternative, blurred full resolution first:
            # reg.add_pyramid_levels(factors=[1, 1], sigmas=[5.0, 0.0])
            reg.add_pyramid_levels(factors=[1], sigmas=[0.0])

        # Add initial transform(s), with parameter scaling if required
        if optname.lower() == 'gridsearch':
            reg.add_initial_transform(id_trans)
        else:
            # BFGS and alphaAMD: estimate a parameter scaling from the sizes
            # of the images (not used in grid search).
            diag = 2.0 / (transforms.image_diagonal(ref_im) +
                          transforms.image_diagonal(flo_im))
            p_scaling = np.array([diag * 2.0, diag * 2.0, 1.0, 1.0])
            reg.add_initial_transform(id_trans, param_scaling=p_scaling)

            # OPTION: in addition to the ID transform, add random starting points
            add_multiple_startpts(reg, count=20, p_scaling=p_scaling)

            # OPTION: start from the grid maximum already found (needs grid_params):
            # if (slide, roi_idx) not in grid_params:
            #     print(f'Grid results not found for slide {slide}, region {roi_idx}')
            #     continue
            # starting_params = grid_params[(slide, roi_idx)]
            # s_trans = transforms.ScalingTransform(2, uniform=True)
            # s_trans.set_params(starting_params[0])
            # r_trans = transforms.Rigid2DTransform()
            # r_trans.set_params(starting_params[1:4])
            # starting_trans = transforms.CompositeTransform(2, [s_trans, r_trans])
            # reg.add_initial_transform(starting_trans, param_scaling=p_scaling)

        reg.set_report_freq(250)

        # Create output directory
        directory = os.path.dirname("./tmp/")
        if not os.path.exists(directory):
            os.makedirs(directory)

        # Pre-processing, then the registration itself
        reg.initialize("./tmp/", norm=norm)
        reg.run(verbose=True)

        # Get the results and pick the best one (there may be more than one
        # starting point).
        out_transforms, values = reg.get_outputs()
        best = np.argmin(values)
        transform = out_transforms[best]
        value = np.min(values)
        successFlag = reg.get_flags()
        if len(successFlag) == 0:
            successFlag = 'N/A'
        else:
            # Optimizer flag of the best output transform. There is one flag
            # per pyramid level; take the last level.
            successFlag = successFlag[best][-1]

        # Warp the floating image into the reference image space.
        c = transforms.make_image_centered_transform(transform, ref_im, flo_im)
        im_warped = np.zeros(ref_im.shape)
        c.warp(In=flo_im, Out=im_warped, mode='nearest', bg_value=0.0)

        if show_images:
            print("Aligned images for sample %s, region %s" % (slide, roi_idx))
            plt.figure(figsize=(12, 6))
            plt.subplot(121)
            plt.imshow(ref_im, cmap='gray', vmin=0, vmax=1)
            plt.title("Reference image")
            plt.subplot(122)
            plt.imshow(im_warped, cmap='gray', vmin=0, vmax=1)
            plt.title("Floating image")
            plt.show()

        centred_gt_trans = transforms.make_image_centered_transform(
            rndTrans, ref_im, flo_im)
        gtVal = reg.get_value_at(rndTrans)
        err = get_transf_error(c, centred_gt_trans, flo_im.shape)

        est_params = c.get_params()
        gt_params = rndTrans.get_params()
        print("Estimated transform:\t [",
              ','.join('%.4f' % p for p in est_params) +
              "] with value %.4f" % value)
        print("True transform:\t\t [",
              ','.join('%.4f' % p for p in gt_params) +
              "] with value %.4f" % gtVal)
        # `err` is divided by 4 to report it per corner, as the message says.
        print("Average corner error: %.5f" % (err / 4))
        print("Value difference: %.5f" % (gtVal - value))

        resultLine = (slide, roi_idx, *gt_params,
                      gtVal,
                      *est_params,
                      value,
                      err, successFlag,
                      time.strftime('%Y-%m-%d %H:%M:%S'))
        results.append(resultLine)
        # The csv module expects files opened with newline=''; otherwise an
        # extra '\r' is written per row on platforms using '\r\n' endings.
        with open(outfile, 'a', newline='') as f:
            csv.writer(f, delimiter=',').writerow(resultLine)

    return results