Ejemplo n.º 1
0
def main():
    """Benchmark TPS-RPM registration variants on source/target point-cloud pairs.

    Command-line arguments:
      input_file         HDF5 file with one group per example, named "0", "1", ...;
                         each group holds a ``source_cloud`` dataset and a
                         ``target_clouds`` group of datasets.
      --output_folder    if non-empty, registration plots are saved there with
                         names like ``3_<cloud_key>_rpm_iter0.png``.
      --i_start/--i_end  half-open range of example indices to use
                         (``--i_end -1`` means up to the number of groups in the file).
      --regtypes         registration variants to run.
      --plot_color, --proj, --visual_prior, --plotting
                         integer flags forwarded to the plotting callbacks and
                         registration calls.

    Every source cloud and its target clouds are downsampled with
    ``clouds.downsample(..., DS_SIZE)``. Each source cloud is then registered
    onto each of its target clouds with every requested variant. The script
    prints the elapsed time per variant, then two tables with one row per
    source/target pair and one column per variant: the TPS cost
    ``f._cost`` and ``registration.tps_reg_cost(f)``.
    """
    import argparse
    import os
    import time

    import h5py
    import matplotlib.pyplot as plt
    from rapprentice import clouds, plotting_plt
    import registration

    parser = argparse.ArgumentParser()
    parser.add_argument("input_file", type=str)
    parser.add_argument("--output_folder", type=str, default="")
    parser.add_argument("--i_start", type=int, default=0)
    parser.add_argument("--i_end", type=int, default=-1)
    regtype_choices = ['rpm', 'rpm-cheap', 'rpm-bij', 'rpm-bij-cheap']
    parser.add_argument("--regtypes", type=str, nargs='*', choices=regtype_choices, default=regtype_choices)
    parser.add_argument("--plot_color", type=int, default=1)
    parser.add_argument("--proj", type=int, default=1, help="project 3d visualization into 2d")
    parser.add_argument("--visual_prior", type=int, default=1)
    parser.add_argument("--plotting", type=int, default=1)

    args = parser.parse_args()

    def make_output_prefix(i, cloud_key, suffix):
        """Return the path prefix for the plots of one registration, or None if plots are not saved."""
        if not args.output_folder:
            return None
        return os.path.join(args.output_folder, str(i) + "_" + cloud_key + suffix)

    def save_plot(output_prefix, iteration):
        """Save the current matplotlib figure for one iteration when an output prefix is set."""
        if output_prefix is not None:
            plt.savefig(output_prefix + "_iter" + str(iteration) + '.png')

    def plot_cb_gen(output_prefix, args, x_color, y_color):
        """Build the per-iteration plotting callback passed to tps_rpm."""
        def plot_cb(x_nd, y_md, corr_nm, f, iteration):
            if args.plot_color:
                plotting_plt.plot_tps_registration(x_nd, y_md, f, x_color = x_color, y_color = y_color, proj_2d=args.proj)
            else:
                plotting_plt.plot_tps_registration(x_nd, y_md, f, proj_2d=args.proj)
            save_plot(output_prefix, iteration)
        return plot_cb

    def plot_cb_bij_gen(output_prefix, args, x_color, y_color):
        """Build the plotting callback passed to registration.tps_rpm_bij."""
        # This callback receives no iteration number, so count the calls here.
        # A one-element list is used because Python 2 has no `nonlocal`.
        n_calls = [0]
        def plot_cb_bij(x_nd, y_md, xtarg_nd, corr_nm, wt_n, f):
            # NOTE(review): the two branches use different `res` (.3 vs .4); confirm this is intentional.
            if args.plot_color:
                plotting_plt.plot_tps_registration(x_nd, y_md, f, res = (.3, .3, .12), x_color = x_color, y_color = y_color, proj_2d=args.proj)
            else:
                plotting_plt.plot_tps_registration(x_nd, y_md, f, res = (.4, .3, .12), proj_2d=args.proj)
            save_plot(output_prefix, n_calls[0])
            n_calls[0] += 1
        return plot_cb_bij

    # Preprocess: downsample every source cloud and its target clouds.
    DS_SIZE = 0.025
    source_clouds = {}  # example index -> downsampled source cloud
    target_clouds = {}  # example index -> list of (cloud_key, downsampled target cloud)
    with h5py.File(args.input_file, 'r') as infile:
        i_end = len(infile) if args.i_end == -1 else args.i_end
        for i in range(args.i_start, i_end):
            group = infile[str(i)]
            source_clouds[i] = clouds.downsample(group['source_cloud'][()], DS_SIZE)
            # Keep each target's key next to its cloud, so plot filenames name the right target.
            target_clouds[i] = [(cloud_key, clouds.downsample(target_cloud[()], DS_SIZE))
                                for (cloud_key, target_cloud) in group['target_clouds'].items()]

    def register(regtype, i, cloud_key, source_cloud, target_cloud):
        """Register source_cloud onto target_cloud with one variant and return the fitted TPS f."""
        vis_cost_xy = ab_cost(source_cloud, target_cloud) if args.visual_prior else None
        # The last three columns are handed to the plots as colors.
        x_color = source_cloud[:, -3:]
        y_color = target_cloud[:, -3:]
        if regtype == 'rpm':
            f, _ = tps_rpm(source_cloud[:, :-3], target_cloud[:, :-3],
                           vis_cost_xy = vis_cost_xy,
                           plotting=args.plotting,
                           plot_cb = plot_cb_gen(make_output_prefix(i, cloud_key, "_rpm"), args, x_color, y_color))
        elif regtype == 'rpm-cheap':
            f, _ = tps_rpm(source_cloud[:, :-3], target_cloud[:, :-3],
                           vis_cost_xy = vis_cost_xy, n_iter = N_ITER_CHEAP, em_iter = EM_ITER_CHEAP,
                           plotting=args.plotting,
                           plot_cb = plot_cb_gen(make_output_prefix(i, cloud_key, "_rpm_cheap"), args, x_color, y_color))
        elif regtype in ('rpm-bij', 'rpm-bij-cheap'):
            # The bijective variants register the unit-boxified first three columns.
            scaled_x_nd, _ = registration.unit_boxify(source_cloud[:, :3])
            scaled_y_md, _ = registration.unit_boxify(target_cloud[:, :3])
            if regtype == 'rpm-bij':
                schedule = dict(n_iter=50, reg_init=10, reg_final=.1)
                suffix = "_rpm_bij"
            else:
                # Note registration_cost_cheap in rope_qlearn has a different rot_reg and outlierfrac
                schedule = dict(n_iter=10)
                suffix = "_rpm_bij_cheap"
            f, _ = registration.tps_rpm_bij(scaled_x_nd, scaled_y_md, rot_reg=np.r_[1e-4, 1e-4, 1e-1],
                                            outlierfrac=1e-2, vis_cost_xy=vis_cost_xy,
                                            plotting=args.plotting,
                                            plot_cb=plot_cb_bij_gen(make_output_prefix(i, cloud_key, suffix), args, x_color, y_color),
                                            **schedule)
        else:
            raise ValueError("unknown regtype: %r" % (regtype,))
        return f

    tps_costs = []      # one list of TPS costs per regtype
    tps_reg_costs = []  # one list of TPS regularization costs per regtype
    for regtype in args.regtypes:
        start_time = time.time()
        costs = []
        reg_costs = []
        # Iterate over exactly the indices that were loaded. Bounding by
        # len(source_clouds) would drop the last i_start examples.
        for i in sorted(source_clouds):
            for cloud_key, target_cloud in target_clouds[i]:
                f = register(regtype, i, cloud_key, source_clouds[i], target_cloud)
                costs.append(f._cost)
                reg_costs.append(registration.tps_reg_cost(f))
        tps_costs.append(costs)
        tps_reg_costs.append(reg_costs)
        # Single-argument print behaves the same on Python 2 and 3.
        print("%s time elapsed %s" % (regtype, time.time() - start_time))

    np.set_printoptions(suppress=True)
    for title, table in (("tps_costs", tps_costs), ("tps_reg_costs", tps_reg_costs)):
        print("")
        print(title)
        print(args.regtypes)
        print(np.array(table).T)
Ejemplo n.º 2
0
def _unpack_rope_crossing_info(rope_nodes_or_crossing_info):
    """Return (rope_nodes, crossings, crossings_links_inds, cross_pairs, rope_closed).

    A tuple argument is returned as-is. Bare rope nodes get their crossing
    information from knot_classifier.calculateCrossings.
    """
    if isinstance(rope_nodes_or_crossing_info, tuple):
        return rope_nodes_or_crossing_info
    rope_nodes = rope_nodes_or_crossing_info
    return (rope_nodes,) + tuple(knot_classifier.calculateCrossings(rope_nodes))


def _closed_rope_variants(rope_nodes, crossings, crossings_links_inds, cross_pairs, rope_closed):
    """List the crossing-info tuples of the closed version(s) of a rope.

    A closed rope yields itself. An open rope yields one variant per end
    (0 and -1), closed with knot_classifier.close_rope.
    """
    if rope_closed:
        return [(rope_nodes, crossings, crossings_links_inds, cross_pairs, True)]
    return [(rope_nodes,) + tuple(knot_classifier.close_rope(crossings, crossings_links_inds, cross_pairs, end)) + (True,)
            for end in (0, -1)]


def tps_segment_registration(rope_nodes_or_crossing_info0, rope_nodes_or_crossing_info1, cloud0 = None, cloud1 = None, corr_tile_pattern = np.array([[1]]), rev_perm = None,
                             x_weights = None, reg = .1, rot_reg = np.r_[1e-4, 1e-4, 1e-1], plotting = False, plot_cb = None):
    """Find a TPS registration by assigning correspondences from the rope topology.

    Each rope's nodes are split into segments at its crossings. If the two
    ropes have the same topology, correspondences are given by linearly
    interpolating matching segments of both node sequences. The topology may
    match directly, with rope_nodes1 reversed, or after closing an open rope
    at either end. A thin plate spline is then fit to those correspondences.
    If no variant has matching topology, (None, None) is returned.

    Parameters
    ----------
    rope_nodes_or_crossing_info0, rope_nodes_or_crossing_info1 :
        Either rope nodes (an ordered sequence of points, i.e. the backbone of
        the rope, indexed as an (n, d) array) or a tuple
        (rope_nodes, crossings, crossings_links_inds, cross_pairs, rope_closed),
        where the last four items are what knot_classifier.calculateCrossings
        returns.
    cloud0, cloud1 : optional
        Point sets the spline is fit between; they default to the rope nodes.
        The tiled correspondence matrix must have shape
        (len(cloud0), len(cloud1)).
    corr_tile_pattern :
        Pattern passed to ``tile`` to expand the node correspondence matrix.
    rev_perm : optional
        Permutation matrix describing how corr_tile_pattern changes when the
        rope nodes are reversed. A default is built when None.
    x_weights, reg, rot_reg :
        Forwarded to fit_ThinPlateSpline_corr.
    plotting, plot_cb :
        When plotting is true, plot_cb is called with the nodes, clouds,
        correspondences, the selected spline and the segmentation indices.

    Returns
    -------
    (f, corr_nm)
        The selected spline and its node correspondence matrix, or
        (None, None). Among valid candidates, non-reflecting splines
        (det(f.lin_ag) >= 0) are preferred, then lower
        registration.tps_reg_cost; remaining ties keep the first candidate.
    """
    rope_nodes0, crossings0, crossings_links_inds0, cross_pairs0, rope_closed0 = \
        _unpack_rope_crossing_info(rope_nodes_or_crossing_info0)
    rope_nodes1, crossings1, crossings_links_inds1, cross_pairs1, rope_closed1 = \
        _unpack_rope_crossing_info(rope_nodes_or_crossing_info1)

    n, _ = rope_nodes0.shape
    m, _ = rope_nodes1.shape

    # Candidate registrations, stored as (f, corr_nm) pairs. Keeping each f
    # next to its own corr_nm means a failed fit cannot misalign the others.
    candidates = []

    # Add registrations for the closed versions of any open rope.
    if not rope_closed0 or not rope_closed1:
        variants0 = _closed_rope_variants(rope_nodes0, crossings0, crossings_links_inds0, cross_pairs0, rope_closed0)
        variants1 = _closed_rope_variants(rope_nodes1, crossings1, crossings_links_inds1, cross_pairs1, rope_closed1)
        for info0 in variants0:
            for info1 in variants1:
                # corr_tile_pattern is unchanged, so the caller's rev_perm still applies.
                candidates.append(tps_segment_registration(info0, info1, cloud0 = cloud0, cloud1 = cloud1, corr_tile_pattern = corr_tile_pattern, rev_perm = rev_perm,
                                                           x_weights = x_weights, reg = reg, rot_reg = rot_reg, plotting = False, plot_cb = None))

    crossings0 = np.array(crossings0)
    crossings1 = np.array(crossings1)
    # dtype=int keeps the segmentation indices integral even when a rope has
    # no crossings (np.array([]) would otherwise be float64).
    crossings_links_inds0 = np.array(crossings_links_inds0, dtype=int)
    crossings_links_inds1 = np.array(crossings_links_inds1, dtype=int)

    pts_segmentation_inds0 = np.r_[0, crossings_links_inds0 + 1, n]
    pts_segmentation_inds1 = np.r_[0, crossings_links_inds1 + 1, m]

    if cross_pairs0 == cross_pairs1: # same topology
        # Try the registration for rope_nodes1 and/or the reversed rope_nodes1.
        # array_equal (unlike np.all(a == b)) neither broadcasts nor raises
        # when the two crossing sequences differ in length.
        reversed_rope_points1_variations = []
        if np.array_equal(crossings0, crossings1):
            reversed_rope_points1_variations.append(False)
        # Could happen when (1) rope_nodes1 is in reverse order compared to
        # rope_nodes0, (2) crossings1 is a palindrome, or (3) both.
        if np.array_equal(crossings0, crossings1[::-1]):
            reversed_rope_points1_variations.append(True)

        cloud0_var = cloud0 if cloud0 is not None else rope_nodes0
        cloud1_var = cloud1 if cloud1 is not None else rope_nodes1
        for reversed_rope_points1 in reversed_rope_points1_variations:
            if reversed_rope_points1:
                corr_nm_var = calc_segment_corr(rope_nodes1[::-1], pts_segmentation_inds0, m - pts_segmentation_inds1[::-1])
                corr_nm_var = corr_nm_var[:, ::-1]
                if rev_perm is None:
                    # Default: the row-reversed identity, rotated cyclically so
                    # that row len//2 - 1 comes first.
                    rev_perm = np.eye(len(corr_tile_pattern))[::-1]
                    half = len(rev_perm) // 2
                    rev_perm = np.r_[rev_perm[half - 1:, :], rev_perm[:half - 1, :]]
                corr_nm_var_aug = tile(corr_nm_var, rev_perm.dot(corr_tile_pattern))
            else:
                corr_nm_var = calc_segment_corr(rope_nodes1, pts_segmentation_inds0, pts_segmentation_inds1)
                corr_nm_var_aug = tile(corr_nm_var, corr_tile_pattern)

            assert corr_nm_var_aug.shape == (len(cloud0_var), len(cloud1_var))

            f_var = fit_ThinPlateSpline_corr(cloud0_var, cloud1_var, corr_nm_var_aug, reg, rot_reg, x_weights)
            candidates.append((f_var, corr_nm_var))

    # Filter out invalid registrations, keeping each f paired with its corr_nm.
    candidates = [(f_var, corr_nm_var) for (f_var, corr_nm_var) in candidates
                  if f_var is not None and corr_nm_var is not None]

    if not candidates:
        f, corr_nm = None, None
    elif len(candidates) == 1:
        f, corr_nm = candidates[0]
    else:
        # Prefer non-reflections, then lower bending cost. Only the key is
        # compared, never the spline objects or arrays themselves.
        f, corr_nm = min(candidates,
                         key=lambda cand: (np.linalg.det(cand[0].lin_ag) < 0, registration.tps_reg_cost(cand[0])))

    # TODO plot correct pts_segmentation_inds
    # NOTE(review): corr_nm_aug is tiled with corr_tile_pattern even when the
    # selected candidate used the reversed (rev_perm) pattern.
    if plotting:
        corr_nm_aug = tile(corr_nm, corr_tile_pattern) if corr_nm is not None else None
        plot_cb(rope_nodes0, rope_nodes1, cloud0, cloud1, corr_nm, corr_nm_aug, f, pts_segmentation_inds0, pts_segmentation_inds1)

    return f, corr_nm