Code example #1
    def _load_most_recent_checkpoint(self):
        block_num = 0
        meshes = {}
        links = {}
        norm_weights = {}
        structural_meshes = {}

        possible_fnames = sorted(
            glob.glob(
                os.path.join(self._checkpoints_dir, 'checkpoint_block_*.pkl')))
        # Find the latest valid checkpoint file
        cp_fname = None
        for fname in reversed(possible_fnames):
            if re.search(r'checkpoint_block_([0-9]+)\.pkl',
                         os.path.basename(fname)):
                cp_fname = fname
                break

        if cp_fname is not None:
            # Found a valid checkpoint file
            logger.report_event(
                "Loading checkpoint block data from {}".format(cp_fname),
                log_level=logging.INFO)
            with open(cp_fname, 'rb') as in_file:
                block_num, meshes, links, norm_weights, structural_meshes = pickle.load(
                    in_file)
            block_num += 1  # increase block num (because the next iteration starts from the next block)
        else:
            logger.report_event("No checkpoint block data found in {}".format(
                self._checkpoints_dir),
                                log_level=logging.INFO)

        return block_num, meshes, links, norm_weights, structural_meshes
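The loader above pairs with _save_checkpoint_data (code example #19). As a rough sketch of how the two fit together, a resume loop could look like the following; the method name self._process_block and the surrounding driver are assumptions for illustration, not part of the original code:

    def _optimize_all_blocks(self, num_blocks):
        # Resume from the block after the last completed checkpoint (0 if none).
        block_num, meshes, links, norm_weights, structural_meshes = \
            self._load_most_recent_checkpoint()
        for cur_block in range(block_num, num_blocks):
            # Hypothetical per-block worker that returns the section indices it touched.
            relevant_sec_idxs = self._process_block(
                cur_block, meshes, links, norm_weights, structural_meshes)
            self._save_checkpoint_data(cur_block, relevant_sec_idxs, meshes,
                                       links, norm_weights, structural_meshes)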
Code example #2
 def _compute_features(img, i):
     global _detector
     result = _detector.detect(img)
     logger.report_event("Img {}, found {} features.".format(
         i, len(result[0])),
                         log_level=logging.INFO)
     return result
Code example #3
File: common.py Project: Gilhirith/mb_aligner
def parse_workflows_folder(cur_fs, workflows_folder):
    '''
    Parses a folder which has at least one workflow folder.
    Each workflow folder will have the pattern [name]_[date]_[time],
    where [name] can be anything, [date] will be of the format YYYYMMDD,
    and [time] will be HH-MM-SS.
    '''
    sub_folders = cur_fs.glob("*/")
    all_workflow_folders = []
    dir_to_time = {}
    for folder_glob in sub_folders:
        if folder_glob.info.is_dir:
            folder = folder_glob.path
            m = re.match('.*_([0-9]{8})_([0-9]{2})-([0-9]{2})-([0-9]{2})/$', folder)
            if m is not None:
                dir_to_time[folder] = "{}_{}-{}-{}".format(m.group(1), m.group(2), m.group(3), m.group(4))
                all_workflow_folders.append(folder)

    full_result = {}
    for sub_folder in sorted(all_workflow_folders, key=lambda folder: dir_to_time[folder]):
        if cur_fs.isdir(sub_folder):
            logger.report_event("Parsing sections from subfolder: {}".format(sub_folder), log_level=logging.INFO)
            full_result.update(parse_workflow_folder(cur_fs, sub_folder))

    # For debug
    # first_keys = sorted(list(full_result.keys()))[:20]
    # full_result = {k:full_result[k] for k in first_keys}
    return full_result
Code example #4
def add_transformation(in_file, out_file, transform, deltas):
    # load the current json file
    try:
        with open(in_file, 'r') as f:
            data = json.load(f)
    except:
        logger.report_event("Error when reading {} - Exiting".format(in_file),
                            log_level=logging.ERROR)
        raise

    if deltas[0] != 0.0 and deltas[1] != 0.0:
        for tile in data:
            # Update the transformation
            if "transforms" not in tile.keys():
                tile["transforms"] = []
            tile["transforms"].append(transform)

            # Update the bbox
            if "bbox" in tile.keys():
                bbox = tile["bbox"]
                bbox_new = [
                    bbox[0] - deltas[0], bbox[1] - deltas[0],
                    bbox[2] - deltas[1], bbox[3] - deltas[1]
                ]
                tile["bbox"] = bbox_new

    with open(out_file, 'w') as f:
        json.dump(data, f, indent=4)
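A hedged usage sketch for add_transformation: the transform dictionary below follows a tilespec-style convention, and both the filenames and the values are made up for illustration rather than taken from the original project.

# Hypothetical call: append a translation transform and shift every tile's bbox.
transform = {  # assumed tilespec-style transform entry
    "className": "mpicbg.trakem2.transform.TranslationModel2D",
    "dataString": "12.5 -7.0"
}
add_transformation("tilespec_in.json", "tilespec_out.json", transform, (12.5, -7.0))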
Code example #5
    def __init__(self, **kwargs):

        self._matcher_kwargs = kwargs
        self._mesh_spacing = kwargs.get("mesh_spacing", 1500)

        #         self._scaling = kwargs.get("scaling", 0.2)
        #         self._template_size = kwargs.get("template_size", 200)
        #         self._search_window_size = kwargs.get("search_window_size", 8 * template_size)
        #         logger.report_event("Actual template size: {} and window search size: {} (after scaling)".format(template_size * scaling, search_window_size * scaling), log_level=logging.INFO)
        #
        #         # Parameters for PMCC filtering
        #         self._min_corr = kwargs.get("min_correlation", 0.2)
        #         self._max_curvature = kwargs.get("maximal_curvature_ratio", 10)
        #         self._max_rod = kwargs.get("maximal_ROD", 0.9)
        #         self._use_clahe = kwargs.get("use_clahe", False)

        self._debug_dir = kwargs.get("debug_dir", None)
        if self._debug_dir is not None:
            logger.report_event("Debug mode - on", log_level=logging.INFO)
            # Create a debug directory
            import datetime
            self._debug_dir = os.path.join(
                self._debug_dir,
                'debug_matches_{}'.format(datetime.datetime.now().isoformat()))
            os.makedirs(self._debug_dir)
Code example #6
    def __init__(self, section, mesh_spacing, refined_mesh_spacing):

        self._section = section
        self._mesh_spacing = mesh_spacing
        self._refined_mesh_spacing = refined_mesh_spacing
        self._section_mesh_refiner = SectionMeshRefiner(
            self._section, self._mesh_spacing, self._refined_mesh_spacing)

        # load the mesh
        #self.orig_pts = np.array(points, dtype=FLOAT_TYPE).reshape((-1, 2)).copy()
        self.pts = np.array(
            self._section_mesh_refiner.get_refined_mesh_points(),
            dtype=FLOAT_TYPE).reshape((-1, 2)).copy()
        #        center = self.pts.mean(axis=0)
        #        self.pts -= center
        #        self.pts *= 1.1
        #        self.pts += center
        self.orig_pts = self.pts.copy()

        self.pts_neg_mask = np.zeros((len(self.pts), ), dtype=bool)

        logger.report_event("# points in base mesh {}".format(
            self.pts.shape[0]),
                            log_level=logging.DEBUG)

        # for neighbor searching and internal mesh
        self.triangulation = Delaunay(self.pts)
Code example #7
    def _create_mesh(self, sec_idx, section, meshes, structural_meshes):
        if sec_idx not in meshes:
            logger.report_event("Creating mesh for section: {}".format(section.layer), log_level=logging.DEBUG)
            meshes[sec_idx] = Mesh(utils.generate_hexagonal_grid(section.bbox, self._mesh_spacing))

            # Build internal structural mesh
            # (edge_indices, edge_lengths, face_indices, face_areas)
            structural_meshes[sec_idx] = meshes[sec_idx].internal_structural_mesh()
Code example #8
        def match_sec2_to_sec1_mfov(self, sec2_pts):
            """
            sec2_pts will be in the original space (before scaling)
            """
            valid_matches = [[], [], []]
            invalid_matches = [[], []]
            if len(sec2_pts) == 0:
                return valid_matches, invalid_matches

            # Assume that only sec1 renderer was transformed and not sec2 (and both scaled)
            sec2_pts = np.atleast_2d(sec2_pts)

            #mat = self._sec1_to_sec2_transform.get_matrix()
            #inverse_mat = np.linalg.inv(mat)
 
            sec2_pts_on_sec1 = self._inverse_model.apply(sec2_pts)

            for sec2_pt, sec1_pt_estimated in zip(sec2_pts, sec2_pts_on_sec1):

                # Fetch the template around sec2_pt
                from_x2, from_y2 = sec2_pt - self._template_side
                to_x2, to_y2 = sec2_pt + self._template_side
                sec2_pt_features_kps, sec2_pt_features_descs = FeaturesBlockMatcherDispatcher.FeaturesBlockMatcher._fetch_sec_features(self._sec2, self._sec2_tiles_rtree, self._sec2_cache_features, tuple(np.array([from_x2, to_x2, from_y2, to_y2]) * self._scaling))
            
                if len(sec2_pt_features_kps) <= 1:
                    continue

                # Fetch a large sub-image around sec1_pt_estimated (after transformation, using search_window_scaled_size)
                from_x1, from_y1 = sec1_pt_estimated - self._search_window_side
                to_x1, to_y1 = sec1_pt_estimated + self._search_window_side
                sec1_pt_est_features_kps, sec1_pt_est_features_descs = FeaturesBlockMatcherDispatcher.FeaturesBlockMatcher._fetch_sec_features(self._sec1, self._sec1_tiles_rtree, self._sec1_cache_features, tuple(np.array([from_x1, to_x1, from_y1, to_y1]) * self._scaling))

                if len(sec1_pt_est_features_kps) <= 1:
                    continue

                # apply the inverse transformation on sec2 feature points locations (after upscaling and then downscaling again)
                sec2_pt_features_kps = self._inverse_model.apply(sec2_pt_features_kps / self._scaling) * self._scaling 
                # Match the features
                transform_model, filtered_matches = self._matcher.match_and_filter(sec2_pt_features_kps, sec2_pt_features_descs, sec1_pt_est_features_kps, sec1_pt_est_features_descs)
                if transform_model is None:
                    invalid_matches[0].append(sec2_pt)
                    invalid_matches[1].append(1)
                else:
                    # the transform model needs to be scaled
                    transform_matrix = transform_model.get_matrix()
                    transform_model.set(transform_matrix[:2, 2].T / self._scaling)

                    # Compute the location of the matched point on sec2
                    sec1_pt = transform_model.apply(sec1_pt_estimated)# + np.array([from_x1, from_y1]) + self._template_side
                    logger.report_event("{}: match found: {} and {} (orig assumption: {})".format(os.getpid(), sec2_pt, sec1_pt, sec1_pt_estimated), log_level=logging.DEBUG)
#                     if self._debug_save_matches:
#                         # TODO
                    valid_matches[0].append(sec2_pt)
                    valid_matches[1].append(sec1_pt)
                    valid_matches[2].append(len(filtered_matches[0]) / len(sec2_pt_features_kps))


            return valid_matches, invalid_matches
Code example #9
File: common.py Project: lichtman-lab/mb_aligner
def parse_workflow_folder(workflow_folder):
    '''
    Parses a single folder which has a coordinates text file for the section or multiple coordinates text files for its mfovs.
    The section coordinates filename will be of the format: full_image_coordinates.txt (or full_image_coordinates_corrected.txt),
    and the per-mfov image coordinates file will be: image_coordinates.txt
    Each workflow folder will have the following format:
    [N]_S[S]R1
    Where [N] is a 3-digit number (irrelevant), and [S] is the section number.
    Returns a map between a section number and the coordinates txt filename (or filenames in case no full section coordinates file was found).
    '''
    result = {}
    all_sec_folders = sorted(glob.glob(os.path.join(workflow_folder,
                                                    '*_S*R1')))
    for sec_folder in all_sec_folders:
        m = re.match('([0-9]{3})_S([0-9]+)R1$', os.path.basename(sec_folder))
        if m is not None:
            # make sure the txt file is there
            image_coordinates_files = None
            if os.path.exists(
                    os.path.join(sec_folder,
                                 'full_image_coordinates_corrected.txt')):
                image_coordinates_files = os.path.join(
                    sec_folder, 'full_image_coordinates_corrected.txt')
            elif os.path.exists(
                    os.path.join(sec_folder, 'full_image_coordinates.txt')):
                image_coordinates_files = os.path.join(
                    sec_folder, 'full_image_coordinates.txt')
            else:
                # look for all mfov folders
                mfov_folders = sorted(glob.glob(os.path.join(sec_folder,
                                                             "0*")))
                if len(mfov_folders) == 0:
                    logger.report_event(
                        "Could not detect coordinate/mfov files for sec: {} - skipping"
                        .format(sec_folder),
                        log_level=logging.WARN)
                    continue
                all_mfov_folders_have_coordinates = True
                for mfov_folder in mfov_folders:
                    if not os.path.exists(
                            os.path.join(mfov_folder,
                                         "image_coordinates.txt")):
                        logger.report_event(
                            "Could not detect mfov coordinates file for mfov: {} - skipping"
                            .format(mfov_folder),
                            log_level=logging.WARN)
                        all_mfov_folders_have_coordinates = False
                if not all_mfov_folders_have_coordinates:
                    continue
                # Take the mfovs folders into account
                image_coordinates_files = [
                    os.path.join(mfov_folder, "image_coordinates.txt")
                    for mfov_folder in mfov_folders
                ]
            sec_num = int(m.group(2))
            result[sec_num] = image_coordinates_files
    return result
Code example #10
    def detect_mfov_blobs(blob_detector_args, mfov):
        """
        Receives a tilespec of an mfov (all the tiles in that mfov),
        detects the blobs on each of the thumbnails of the mfov tiles,
        and returns the locations of the blobs (in stitched global coordinates), and their
        descriptors.
        """
        thread_local_store = ThreadLocalStorageLRU()
        if 'blob_detector' not in thread_local_store.keys():
            # Initialize the blob_detector, and store it in the local thread storage
            blob_detector = BlobDetector2D.create_detector(
                **blob_detector_args)
            thread_local_store['blob_detector'] = blob_detector
        else:
            blob_detector = thread_local_store['blob_detector']

#         blob_detector = getattr(threadLocal, 'blob_detector', None)
#         if blob_detector is None:
#             # Initialize the blob_detector, and store it in the local thread storage
#             blob_detector = BlobDetector2D.create_detector(**blob_detector_args)
#             threadLocal.blob_detector = blob_detector

        all_kps_descs = [[], []]
        for tile in mfov.tiles():
            thumb_img_fname = "thumbnail_{}.jpg".format(
                os.path.splitext(os.path.basename(tile.img_fname))[0])
            thumb_img_fname = os.path.join(os.path.dirname(tile.img_fname),
                                           thumb_img_fname)
            # Read the tile
            thumb_img = mb_aligner.dal.common.read_image_file(thumb_img_fname)
            #thumb_img = cv2.imread(thumb_img_fname, 0)
            kps, descs = blob_detector.detectAndCompute(thumb_img)

            if len(kps) == 0:
                continue

            kps_pts = np.empty((len(kps), 2), dtype=np.float64)
            for kp_i, kp in enumerate(kps):
                kps_pts[kp_i][:] = kp.pt
            # upsample the thumbnail coordinates to original tile coordinates
            us_x = tile.width / thumb_img.shape[1]
            us_y = tile.height / thumb_img.shape[0]
            kps_pts[:, 0] *= us_x
            kps_pts[:, 1] *= us_y

            # Apply the transformation to the points
            assert (len(tile.transforms) == 1)
            model = tile.transforms[0]
            kps_pts = model.apply(kps_pts)

            all_kps_descs[0].extend(kps_pts)
            all_kps_descs[1].extend(descs)

        logger.report_event("Found {} blobs in section {}, mfov {}".format(
            len(all_kps_descs[0]), mfov.layer, mfov.mfov_index),
                            log_level=logging.INFO)
        return mfov.mfov_index, all_kps_descs
Code example #11
 def _get_transform_matrix(self, pts1, pts2):
     model = self._assumed_model
     if model == 1:
         return align_rigid(pts1, pts2)
     elif model == 3:
         return Haffine_from_points(pts1, pts2)
     else:
         logger.report_event("Unsupported transformation model type", log_level=logging.ERROR)
         return None
Code example #12
File: aligner.py Project: Gilhirith/mb_aligner
def update_section_post_optimization_and_save(section, orig_pts, new_pts, mesh_spacing, out_dir):
    logger.report_event("Exporting section {}".format(section.canonical_section_name), log_level=logging.INFO)
    exporter = MeshPointsModelExporter()
    exporter.update_section_points_model_transform(section, orig_pts, new_pts, mesh_spacing)

    # TODO - should also save the mesh as h5s

    # save the output section
    out_fname = os.path.join(out_dir, '{}.json'.format(section.canonical_section_name))
    print('Writing output to: {}'.format(out_fname))
    section.save_as_json(out_fname)
Code example #13
        def match_sec1_to_sec2_mfov(self, sec1_pts):
            # Apply the mfov transformation to compute estimated location on sec2
            sec1_mfov_pts_on_sec2 = self._sec1_to_sec2_transform.apply(np.atleast_2d(sec1_pts)) * self._scaling

            valid_matches = [[], [], []]
            invalid_matches = [[], []]
            for sec1_pt, sec2_pt_estimated in zip(sec1_pts, sec1_mfov_pts_on_sec2):

                # Fetch the template around img1_point (after transformation)
                from_x1, from_y1 = sec2_pt_estimated - self._template_scaled_side
                to_x1, to_y1 = sec2_pt_estimated + self._template_scaled_side
                sec1_template, sec1_template_start_point = self._sec1_scaled_renderer.crop(from_x1, from_y1, to_x1, to_y1)
            
                # Fetch a large sub-image around img2_point (using search_window_scaled_size)
                from_x2, from_y2 = sec2_pt_estimated - self._search_window_scaled_side
                to_x2, to_y2 = sec2_pt_estimated + self._search_window_scaled_side
                sec2_search_window, sec2_search_window_start_point = self._sec2_scaled_renderer.crop(from_x2, from_y2, to_x2, to_y2)
        
                # execute the PMCC match
                # Do template matching
                if np.any(np.array(sec2_search_window.shape) == 0) or np.any(np.array(sec1_template.shape) == 0):
                    continue
                if sec1_template.shape[0] >= sec2_search_window.shape[0] or sec1_template.shape[1] >= sec2_search_window.shape[1]:
                    continue
                if self._use_clahe:
                    sec2_search_window_clahe = self._clahe.apply(sec2_search_window)
                    sec1_template_clahe = self._clahe.apply(sec1_template)
                    pmcc_result, reason, match_val = PMCC_filter.PMCC_match(sec2_search_window_clahe, sec1_template_clahe, min_correlation=self._min_corr, maximal_curvature_ratio=self._max_curvature, maximal_ROD=self._max_rod)
                else:
                    pmcc_result, reason, match_val = PMCC_filter.PMCC_match(sec2_search_window, sec1_template, min_correlation=self._min_corr, maximal_curvature_ratio=self._max_curvature, maximal_ROD=self._max_rod)

                if pmcc_result is None:
                    invalid_matches[0].append(sec1_pt)
                    invalid_matches[1].append(reason)
#                     debug_out_fname1 = "temp_debug/debug_match_sec1{}-{}_template.png".format(int(sec1_pt[0]), int(sec1_pt[1]), int(sec2_pt_estimated[0]), int(sec2_pt_estimated[1]))
#                     debug_out_fname2 = "temp_debug/debug_match_sec1{}-{}_search_window.png".format(int(sec1_pt[0]), int(sec1_pt[1]), int(sec2_pt_estimated[0]), int(sec2_pt_estimated[1]))
#                     cv2.imwrite(debug_out_fname1, sec1_template)
#                     cv2.imwrite(debug_out_fname2, sec2_search_window)
                else:
                    # Compute the location of the matched point on img2 in non-scaled coordinates
                    matched_location_scaled = np.array([reason[1], reason[0]]) + np.array([from_x2, from_y2]) + self._template_scaled_side
                    sec2_pt = matched_location_scaled / self._scaling 
                    logger.report_event("{}: match found: {} and {} (orig assumption: {})".format(os.getpid(), sec1_pt, sec2_pt, sec2_pt_estimated / self._scaling), log_level=logging.DEBUG)
                    if self._debug_save_matches:
                        debug_out_fname1 = os.path.join(self._debug_dir, "debug_match_sec1_{}-{}_sec2_{}-{}_image1.png".format(int(sec1_pt[0]), int(sec1_pt[1]), int(sec2_pt[0]), int(sec2_pt[1])))
                        debug_out_fname2 = os.path.join(self._debug_dir, "debug_match_sec1_{}-{}_sec2_{}-{}_image2.png".format(int(sec1_pt[0]), int(sec1_pt[1]), int(sec2_pt[0]), int(sec2_pt[1])))
                        cv2.imwrite(debug_out_fname1, sec1_template)
                        sec2_cut_out = sec2_search_window[int(reason[0]):int(reason[0] + 2 * self._template_scaled_side), int(reason[1]):int(reason[1] + 2 * self._template_scaled_side)]
                        cv2.imwrite(debug_out_fname2, sec2_cut_out)
                    valid_matches[0].append(np.array(sec1_pt))
                    valid_matches[1].append(sec2_pt)
                    valid_matches[2].append(match_val)
            return valid_matches, invalid_matches
Code example #14
    def _perform_matching(sec1_mfov_tile_idx, sec1, sec2,
                          sec1_to_sec2_mfov_transform, sec1_mfov_mesh_pts,
                          sec2_mfov_mesh_pts, debug_dir, matcher_args):
        #         fine_matcher_key = "block_matcher_{},{},{}".format(sec1.canonical_section_name, sec2.canonical_section_name, sec1_mfov_tile_idx[0])
        #         fine_matcher = getattr(threadLocal, fine_matcher_key, None)
        #         if fine_matcher is None:
        #             fine_matcher = BlockMatcherPMCCDispatcher.BlockMatcherPMCC(sec1, sec2, sec1_to_sec2_mfov_transform, **matcher_args)
        #             if debug_dir is not None:
        #                 fine_matcher.set_debug_dir(debug_dir)
        #
        #             setattr(threadLocal, fine_matcher_key, fine_matcher)

        fine_matcher = BlockMatcherPMCCDispatcher.BlockMatcherPMCC(
            sec1, sec2, sec1_to_sec2_mfov_transform, **matcher_args)
        if debug_dir is not None:
            fine_matcher.set_debug_dir(debug_dir)

        logger.report_event(
            "Block-Matching+PMCC layers: {} with {} (mfov1 {}) {} mesh points1, {} mesh points2"
            .format(sec1.canonical_section_name, sec2.canonical_section_name,
                    sec1_mfov_tile_idx, len(sec1_mfov_mesh_pts),
                    len(sec2_mfov_mesh_pts)),
            log_level=logging.INFO)
        logger.report_event("Block-Matching+PMCC layers: {} -> {}".format(
            sec1.canonical_section_name, sec2.canonical_section_name),
                            log_level=logging.INFO)
        valid_matches1, invalid_matches1 = fine_matcher.match_sec1_to_sec2_mfov(
            sec1_mfov_mesh_pts)
        logger.report_event(
            "Block-Matching+PMCC layers: {} -> {} valid matches: {}, invalid_matches: {} {}"
            .format(
                sec1.canonical_section_name, sec2.canonical_section_name,
                len(valid_matches1[0]), len(invalid_matches1[0]),
                BlockMatcherPMCCDispatcher.sum_invalid_matches(
                    invalid_matches1)),
            log_level=logging.INFO)

        logger.report_event("Block-Matching+PMCC layers: {} <- {}".format(
            sec1.canonical_section_name, sec2.canonical_section_name),
                            log_level=logging.INFO)
        valid_matches2, invalid_matches2 = fine_matcher.match_sec2_to_sec1_mfov(
            sec2_mfov_mesh_pts)
        logger.report_event(
            "Block-Matching+PMCC layers: {} <- {} valid matches: {}, invalid_matches: {} {}"
            .format(
                sec1.canonical_section_name, sec2.canonical_section_name,
                len(valid_matches2[0]), len(invalid_matches2[0]),
                BlockMatcherPMCCDispatcher.sum_invalid_matches(
                    invalid_matches2)),
            log_level=logging.INFO)

        return sec1_mfov_tile_idx, valid_matches1, valid_matches2
Code example #15
File: aligner.py Project: Gilhirith/mb_aligner
 def load_conf_from_file(conf_fname):
     '''
     Loads the alignment configuration from a given yaml file
     '''
     if conf_fname is None:
         return {}
     print("Using config file: {}.".format(conf_fname))
     with open(conf_fname, 'r') as stream:
         conf = yaml.safe_load(stream)
         conf = conf['alignment']
     
     logger.report_event("loaded configuration: {}".format(conf), log_level=logging.INFO)
     return conf
Code example #16
    def __init__(self, points):
        # load the mesh
        #self.orig_pts = np.array(points, dtype=FLOAT_TYPE).reshape((-1, 2)).copy()
        self.pts = np.array(points, dtype=FLOAT_TYPE).reshape((-1, 2)).copy()
#        center = self.pts.mean(axis=0)
#        self.pts -= center
#        self.pts *= 1.1
#        self.pts += center
        self.orig_pts = self.pts.copy()

        logger.report_event("# points in base mesh {}".format(self.pts.shape[0]), log_level=logging.DEBUG)

        # for neighbor searching and internal mesh
        self.triangulation = Delaunay(self.pts)
Code example #17
def run_stitcher(args):

    #common.fs_create_dir(args.output_dir)

    conf = None
    if args.conf_fname is not None:
        conf = Stitcher.load_conf_from_file(args.conf_fname)
    stitcher = Stitcher(conf)

    # read the input tilespecs
    in_fs = fs.open_fs(args.ts_dir)
    in_ts_fnames = get_ts_files(
        in_fs,
        args.ts_dir)  #sorted(glob.glob(os.path.join(args.ts_dir, "*.json")))

    out_fs = fs.open_fs(args.output_dir)
    logger.report_event("Stitching {} sections".format(len(in_ts_fnames)),
                        log_level=logging.INFO)
    for in_ts_fname in in_ts_fnames:
        logger.report_event("Stitching {}".format(in_ts_fname),
                            log_level=logging.DEBUG)
        out_ts_fname = args.output_dir + "/" + fs.path.basename(in_ts_fname)
        if out_fs.exists(out_ts_fname):
            continue

        print("Stitching {}".format(in_ts_fname))
        with in_fs.open(fs.path.basename(in_ts_fname), 'rt') as in_f:
            in_ts = ujson.load(in_f)

        wafer_num = int(
            fs.path.basename(in_ts_fname).split('_')[0].split('W')[1])
        sec_num = int(
            fs.path.basename(in_ts_fname).split('.')[0].split('_')[1].split(
                'Sec')[1])
        section = Section.create_from_tilespec(in_ts,
                                               wafer_section=(wafer_num,
                                                              sec_num))
        stitcher.stitch_section(section)

        # Save the tilespec
        section.save_as_json(out_ts_fname)


#         out_tilespec = section.tilespec
#         import json
#         with open(out_ts_fname, 'wt') as out_f:
#             json.dump(out_tilespec, out_f, sort_keys=True, indent=4)

    del stitcher
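The wafer and section numbers are parsed from the tilespec filename with plain string splits. Assuming a name such as W05_Sec123_montaged.json (the exact naming convention is an assumption here), the parsing works out as follows:

# Standalone illustration of the filename parsing used above (hypothetical name).
fname = "W05_Sec123_montaged.json"
wafer_num = int(fname.split('_')[0].split('W')[1])                # "W05" -> 5
sec_num = int(fname.split('.')[0].split('_')[1].split('Sec')[1])  # "Sec123" -> 123
print(wafer_num, sec_num)  # 5 123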
Code example #18
File: stitcher.py Project: lichtman-lab/mb_aligner
 def _match_features(self, features_result1, features_result2, i, j):
     transform_model, filtered_matches = self._matcher.match_and_filter(
         *features_result1, *features_result2)
     assert (transform_model is not None)
     transform_matrix = transform_model.get_matrix()
     logger.report_event(
         "Imgs {} -> {}, found the following transformations\n{}\nAnd the average displacement: {} px"
         .format(
             i, j, transform_matrix,
             np.mean(
                 Stitcher._compute_l2_distance(
                     transform_model.apply(filtered_matches[1]),
                     filtered_matches[0]))),
         log_level=logging.INFO)
     return transform_matrix
Code example #19
    def _save_checkpoint_data(self, block_num, relevant_sec_idxs, meshes, links, norm_weights, structural_meshes):
        cp_fname = os.path.join(self._checkpoints_dir, "checkpoint_block_{}.pkl".format(str(block_num).zfill(5)))
        cp_fname_partial = "{}.partial".format(cp_fname)
        logger.report_event("Saving checkpoint block data to {}".format(cp_fname), log_level=logging.INFO)
        #print("saving tilespecs: {}".format(sorted(list(relevant_sec_idxs))))

        #  only store the relevant_sec_idxs data
        meshes = {sec_idx: m for sec_idx, m in meshes.items() if sec_idx in relevant_sec_idxs}
        structural_meshes = {sec_idx: m for sec_idx, m in structural_meshes.items() if sec_idx in relevant_sec_idxs}
        links = {(sec1_idx, sec2_idx): v for (sec1_idx, sec2_idx), v in links.items() if sec1_idx in relevant_sec_idxs or sec2_idx in relevant_sec_idxs}
        norm_weights = {(sec1_idx, sec2_idx): v for (sec1_idx, sec2_idx), v in norm_weights.items() if sec1_idx in relevant_sec_idxs or sec2_idx in relevant_sec_idxs}

        with open(cp_fname_partial, 'wb') as out:
            #pickle.dump([block_num, meshes, links, norm_weights, structural_meshes], out)
            pickle.dump([block_num, meshes, links, norm_weights, structural_meshes], out, pickle.HIGHEST_PROTOCOL)
        os.rename(cp_fname_partial, cp_fname)
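Writing to a .partial file and then calling os.rename is an atomic-write pattern: a crash mid-write never leaves a truncated checkpoint_block_*.pkl for _load_most_recent_checkpoint (code example #1) to pick up later. A generic standalone sketch of the same pattern (the helper name is illustrative, not from the original):

import os
import pickle

def atomic_pickle_dump(obj, fname):
    # Write to a temporary sibling file first, then rename it into place; on POSIX
    # the rename is atomic when both paths are on the same filesystem.
    tmp_fname = "{}.partial".format(fname)
    with open(tmp_fname, 'wb') as out:
        pickle.dump(obj, out, pickle.HIGHEST_PROTOCOL)
    os.rename(tmp_fname, fname)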
Code example #20
        def __init__(self, sec1, sec2, sec1_to_sec2_transform, **kwargs):
            self._scaling = kwargs.get("scaling", 0.2)
            self._template_size = kwargs.get("template_size", 200)
            self._search_window_size = kwargs.get("search_window_size",
                                                  8 * self._template_size)
            logger.report_event(
                "Actual template size: {} and window search size: {} (after scaling)"
                .format(self._template_size * self._scaling,
                        self._search_window_size * self._scaling),
                log_level=logging.INFO)

            # Parameters for PMCC filtering
            self._min_corr = kwargs.get("min_correlation", 0.2)
            self._max_curvature = kwargs.get("maximal_curvature_ratio", 10)
            self._max_rod = kwargs.get("maximal_ROD", 0.9)
            self._use_clahe = kwargs.get("use_clahe", False)
            if self._use_clahe:
                self._clahe = cv2.createCLAHE(clipLimit=2.0,
                                              tileGridSize=(8, 8))

            #self._debug_dir = kwargs.get("debug_dir", None)
            self._debug_save_matches = None

            self._template_scaled_side = self._template_size * self._scaling / 2
            self._search_window_scaled_side = self._search_window_size * self._scaling / 2

            self._sec1 = sec1
            self._sec2 = sec2
            self._sec1_to_sec2_transform = sec1_to_sec2_transform

            self._scale_transformation = np.array([[self._scaling, 0., 0.],
                                                   [0., self._scaling, 0.]])
            # For section1 there will be a single renderer with transformation and scaling
            self._sec1_scaled_renderer = TilespecAffineRenderer(
                self._sec1.tilespec)
            self._sec1_scaled_renderer.add_transformation(
                self._sec1_to_sec2_transform.get_matrix())
            self._sec1_scaled_renderer.add_transformation(
                self._scale_transformation)

            # for section2 there will only be a single renderer (no need to transform back to sec1)
            self._sec2_scaled_renderer = TilespecAffineRenderer(
                self._sec2.tilespec)
            self._sec2_scaled_renderer.add_transformation(
                self._scale_transformation)
Code example #21
    def _pre_optimize(self, layout, block_lo, block_hi, meshes, links, norm_weights, structural_meshes):
        # Compute the initial affine pre-alignment for each of the relevant sections
        pre_alignment_block_lo = max(1, block_lo)
        # for all blocks after the first, avoid applying the affine pre-alignment to anything that was already pre-aligned
        if block_lo > 0:
            pre_alignment_block_lo = min(block_lo + (self._block_size - self._block_step), block_hi)
        for active_sec_idx in range(pre_alignment_block_lo, block_hi):
            active_sec_name = layout['sections'][active_sec_idx].canonical_section_name
            logger.report_event("Before affine (sec {}): {}".format(active_sec_name, ts_mean_offsets(meshes, links, active_sec_idx, plot=False)), log_level=logging.INFO)
            rot = 0
            tran = 0
            count = 0

            #all_H = np.zeros((3,3))
            new_active_idx_mesh_pts = np.zeros_like(meshes[active_sec_idx].pts)
            sum_weights = 0
            for neighbor_sec_idx in layout['neighbors'][active_sec_idx]:
                if neighbor_sec_idx < active_sec_idx:
                    # take both (active_sec, neighbor_sec) and (neighbor_sec, active_sec) into account
                    for (sec1_idx, sec2_idx), ((idx1, w1), (idx2, w2)) in [((active_sec_idx, neighbor_sec_idx), links[active_sec_idx, neighbor_sec_idx]), ((neighbor_sec_idx, active_sec_idx), links[neighbor_sec_idx, active_sec_idx])]:
                        #if active_ts in (ts1, ts2) and (layers[ts1] <= layers[active_ts]) and (layers[ts2] <= layers[active_ts]):
                        pts1 = np.einsum('ijk,ij->ik', meshes[sec1_idx].pts[idx1], w1)
                        pts2 = np.einsum('ijk,ij->ik', meshes[sec2_idx].pts[idx2], w2)
                        logger.report_event("Matches # (sections {}->{}): {}.".format(layout['sections'][sec1_idx].canonical_section_name, layout['sections'][sec2_idx].canonical_section_name, pts1.shape[0]), log_level=logging.INFO)#DEBUG)
                        if sec1_idx == active_sec_idx:
                            #cur_rot, cur_tran = self._get_transform_matrix(pts1, pts2)
                            cur_norm_weights = norm_weights[sec1_idx, sec2_idx]
                            new_active_idx_mesh_pts += cur_norm_weights * self._pre_opt_local_align(meshes, active_sec_idx, neighbor_sec_idx, pts1, pts2)
                            sum_weights += cur_norm_weights
                        else: # sec2_idx == active_sec_idx
                            #cur_rot, cur_tran = self._get_transform_matrix(pts2, pts1)
                            cur_norm_weights = norm_weights[sec2_idx, sec1_idx]
                            new_active_idx_mesh_pts += cur_norm_weights * self._pre_opt_local_align(meshes, active_sec_idx, neighbor_sec_idx, pts2, pts1)
                            sum_weights += cur_norm_weights

                        # Average the affine transformation by the number of matches between the two sections
                        #rot += pts1.shape[0] * cur_norm_weights * cur_rot
                        #tran += pts1.shape[0] * cur_norm_weights * cur_tran
                        count += pts1.shape[0] * cur_norm_weights

            if count == 0:
                logger.report_event("Error: no matches found for section {}.".format(active_sec_name), log_level=logging.ERROR)
                sys.exit(1)

#             # normalize the transformation
#             rot = rot * (1.0 / count)
#             tran = tran * (1.0 / count)
#             #print("rot:\n{}\ntran:\n{}".format(rot, tran))
#             # transform the points
#             meshes[active_sec_idx].pts = np.dot(meshes[active_sec_idx].pts, rot) + tran
            # normalize the new mesh points by the weights sum
            meshes[active_sec_idx].pts = new_active_idx_mesh_pts / sum_weights
            logger.report_event("After affine (sec {}): {}".format(active_sec_name, ts_mean_offsets(meshes, links, active_sec_idx, plot=False)), log_level=logging.INFO)
Code example #22
    def fix_missing_matches(self, cur_matches):
        # If there are no missing matches, return an empty map
        if len(self._missing_matches) == 0:
            return {}

        # Add missing matches
        new_matches = {}

        for missing_match_k, missing_match_v in self._missing_matches.items():
            # missing_match_k = (tile1_unique_idx, tile2_unique_idx)
            # missing_match_v = (tile1, tile2)
            tile1, tile2 = missing_match_v


            if not self._intra_mfov_only or tile1.mfov_index == tile2.mfov_index:
            #if filtered_matches is None:
                logger.report_event("Adding fake matches between: {} and {}".format((tile1.mfov_index, tile1.tile_index), (tile2.mfov_index, tile2.tile_index)), log_level=logging.INFO)
                bbox1 = tile1.bbox
                bbox2 = tile2.bbox
                intersection = [max(bbox1[0], bbox2[0]),
                                min(bbox1[1], bbox2[1]),
                                max(bbox1[2], bbox2[2]),
                                min(bbox1[3], bbox2[3])]
                intersection_center = np.array([intersection[0] + intersection[1], intersection[2] + intersection[3]]) * 0.5
                fake_match_points_global = np.array([
                        [intersection_center[0] + intersection[0] - 2, intersection_center[1] + intersection[2] - 2],
                        [intersection_center[0] + intersection[1] + 4, intersection_center[1] + intersection[2] - 4],
                        [intersection_center[0] + intersection[0] + 2, intersection_center[1] + intersection[3] - 2],
                        [intersection_center[0] + intersection[1] - 4, intersection_center[1] + intersection[3] - 6]
                    ]) * 0.5
                fake_new_matches = np.array([
                        fake_match_points_global - np.array([bbox1[0], bbox1[2]]),
                        fake_match_points_global - np.array([bbox2[0], bbox2[2]])
                    ])
                new_matches[missing_match_k] = fake_new_matches

        return new_matches
Code example #23
def run_stitcher(args):

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    conf = None
    if args.conf_fname is not None:
        conf = Stitcher.load_conf_from_file(args.conf_fname)
    stitcher = Stitcher(conf)

    # read the input tilespecs
    in_ts_fnames = sorted(glob.glob(os.path.join(args.ts_dir, "*.json")))

    logger.report_event("Stitching {} sections".format(len(in_ts_fnames)),
                        log_level=logging.INFO)
    for in_ts_fname in in_ts_fnames:
        logger.report_event("Stitching {}".format(in_ts_fname),
                            log_level=logging.DEBUG)
        out_ts_fname = os.path.join(args.output_dir,
                                    os.path.basename(in_ts_fname))
        if os.path.exists(out_ts_fname):
            continue

        print("Stitching {}".format(in_ts_fname))
        with open(in_ts_fname, 'rt') as in_f:
            in_ts = ujson.load(in_f)
            section = Section.create_from_tilespec(in_ts)
            stitcher.stitch_section(section)

            # Save the tilespec
            section.save_as_json(out_ts_fname)
    #         out_tilespec = section.tilespec
    #         import json
    #         with open(out_ts_fname, 'wt') as out_f:
    #             json.dump(out_tilespec, out_f, sort_keys=True, indent=4)

    del stitcher
Code example #24
    def _transform_mesh_tps(self, pts_src_lists, pts_dst_lists, weights_list, mesh_pts):
        transformed_mesh_pts = np.zeros_like(mesh_pts)
        # Apply each transformation by the given matched points (weighted), and sum it all into transformed_mesh_pts
        print("1")
        logger.report_event("starting thinplate splines 1", log_level=logging.DEBUG)
        for pts_src, pts_dst, w in zip(pts_src_lists, pts_dst_lists, weights_list):
            #model = models.PointsTransformModel((pts_src, pts_dst))
            #transformed_mesh_pts += model.apply(mesh_pts) * w
            logger.report_event("Creating thinplate splines 1.1: {} control pts".format(len(pts_src)), log_level=logging.DEBUG)
            model = ThinPlateSplines(pts_src, pts_dst)
            print("1.1")
            logger.report_event("Applying thinplate splines 1.2: {} mesh pts".format(len(mesh_pts)), log_level=logging.DEBUG)
            transformed_mesh_pts += model.apply(mesh_pts, 20000) * w
            logger.report_event("Done applying thinplate splines 1.2: {} mesh pts".format(len(mesh_pts)), log_level=logging.DEBUG)
            print("1.2")

        # Normalize the transformed points
        transformed_mesh_pts /= np.sum(weights_list)
        print("2")
        return transformed_mesh_pts
Code example #25
 def remove_unneeded_points(self, pts1, pts2):
     """Removes points that cause query_cross_barycentrics to fail"""
     p1 = pts1.copy()
     p1[p1 < 0] = 0.01
     simplex_indices = self.triangulation.find_simplex(p1)
     if np.any(simplex_indices == -1):
         locs = np.where(simplex_indices == -1)
         logger.report_event("locations: {}".format(locs), log_level=logging.DEBUG)
         logger.report_event("points: {}".format(pts1[locs]), log_level=logging.DEBUG)
         logger.report_event("removing the above points", log_level=logging.DEBUG)
         pts1 = np.delete(pts1, locs, 0)
         pts2 = np.delete(pts2, locs, 0)
     return pts1, pts2
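For context, scipy's Delaunay.find_simplex returns -1 for query points that fall outside the triangulation, which is exactly the condition remove_unneeded_points filters on. A minimal standalone illustration (the sample points are made up):

import numpy as np
from scipy.spatial import Delaunay

tri = Delaunay(np.array([[0., 0.], [1., 0.], [0., 1.], [1., 1.]]))
queries = np.array([[0.5, 0.5], [2.0, 2.0]])
print(tri.find_simplex(queries))  # the second point lies outside the hull -> -1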
Code example #26
File: stitcher.py Project: Gilhirith/mb_aligner
    def stitch_section(self, section, match_results_map=None):
        '''
        Receives a single section (assumes no transformations), stitches all its tiles, and updates the section tiles' transformations.
        '''

        logger.report_event("stitch_section starting.", log_level=logging.INFO)
        # Compute features

        if match_results_map is None:
            match_results_map, missing_matches_map = self._compute_match_features(
                section)
            if match_results_map is None:
                return

        logger.report_event("Starting optimization", log_level=logging.INFO)
        # Generate a map between tile and its original estimated location
        orig_locations = {}
        for tile in section.tiles():
            tile_unique_idx = (tile.layer, tile.mfov_index, tile.tile_index)
            orig_locations[tile_unique_idx] = [tile.bbox[0], tile.bbox[2]]

        optimized_transforms_map = self._optimizer.optimize(
            orig_locations, match_results_map)

        #if self._filter_inter_mfov_matches:
        #    # find a "seam" of tiles between

        logger.report_event("Done optimizing, updating tiles transforms",
                            log_level=logging.INFO)

        for tile in section.tiles():
            tile_unique_idx = (tile.layer, tile.mfov_index, tile.tile_index)
            if tile_unique_idx not in optimized_transforms_map:
                # TODO - should remove the tile
                logger.report_event(
                    "Could not find a transformation for tile {} in the optimization result, skipping tile"
                    .format(tile_unique_idx),
                    log_level=logging.WARNING)
            else:
                tile.set_transform(optimized_transforms_map[tile_unique_idx])
Code example #27
    def _pre_opt_local_align(self, meshes, active_sec_idx, neighbor_sec_idx, active_sec_matches_merged, neighbor_sec_matches_merged):
        new_mesh_pts = np.empty_like(meshes[active_sec_idx].pts)
        groups_masks, outlier_mask = self._affine_grouper.group_matches(active_sec_matches_merged, neighbor_sec_matches_merged)
        if len(groups_masks) == 0:
            # none found.... use all matches
            logger.report_event("No valid grouping found, using a single transform", log_level=logging.DEBUG)
            cur_rot, cur_tran = self._get_transform_matrix(active_sec_matches_merged, neighbor_sec_matches_merged)
            new_mesh_pts = np.dot(meshes[active_sec_idx].pts, cur_rot) + cur_tran
        else:
            logger.report_event("Found {} groups masks".format(len(groups_masks)), log_level=logging.DEBUG)
            logger.report_event("Found {} outlier matches".format(np.sum(outlier_mask)), log_level=logging.DEBUG)
            # TODO - attach a group to each of the mesh points of the active_sec_idx
            matches_affine_mats_idxs = np.zeros((len(active_sec_matches_merged), ), dtype=int)
            matches_affine_mats_idxs[outlier_mask] = -1 # set the outliers to idx -1
            non_outlier_active_sec_matches_merged = active_sec_matches_merged[~outlier_mask]
            kdtree = KDTree(non_outlier_active_sec_matches_merged)
            _, closest_matches_idxs = kdtree.query(meshes[active_sec_idx].pts, k=1)
            # TODO -stopped here !
            groups_matrices_rots = np.empty((len(groups_masks), 2, 2), dtype=float)
            groups_matrices_trans = np.empty((len(groups_masks), 2), dtype=float)
            for gm_idx, gm in enumerate(groups_masks):
                logger.report_event("Group masks {}: size:{}".format(gm_idx, np.sum(gm)), log_level=logging.DEBUG)
                matches_affine_mats_idxs[gm] = gm_idx # set the idx of these points
                # for each of the masks apply an affine transformation to get rid of large displacements/rotations/scale
                cur_rot, cur_tran = self._get_transform_matrix(active_sec_matches_merged[gm], neighbor_sec_matches_merged[gm])
                groups_matrices_rots[gm_idx] = cur_rot
                groups_matrices_trans[gm_idx] = cur_tran
            
            
            mesh_pts_rots = groups_matrices_rots[matches_affine_mats_idxs[closest_matches_idxs]]
            mesh_pts_trans = groups_matrices_trans[matches_affine_mats_idxs[closest_matches_idxs]]
            # Update the mesh points
            new_mesh_pts = np.vstack((np.sum(meshes[active_sec_idx].pts * mesh_pts_rots[:, :, 0], axis=1),
                                      np.sum(meshes[active_sec_idx].pts * mesh_pts_rots[:, :, 1], axis=1))).T + \
                           mesh_pts_trans



        return new_mesh_pts
Code example #28
def run_aligner(args):

    # Make a list of all the relevant sections
    with open(args.sections_list_file, 'rt') as in_f:
        secs_ts_fnames = in_f.readlines()
    secs_ts_fnames = [fname.strip() for fname in secs_ts_fnames]

    # Make sure the tilespecs exist
    all_files_exist = True
    for sec_ts_fname in secs_ts_fnames:
        if not os.path.exists(sec_ts_fname):
            print("Cannot find tilespec file: {}".format(sec_ts_fname))
            all_files_exist = False

    if not all_files_exist:
        print("One or more tilespecs could not be found, exiting!")
        return

    out_folder = './output_aligned_ECS_test9_cropped'
    conf_fname = '../../conf/conf_example.yaml'

    conf = StackAligner.load_conf_from_file(args.conf_fname)
    logger.report_event("Loading sections", log_level=logging.INFO)
    sections = []
    # TODO - Should be done in a parallel fashion
    for ts_fname in secs_ts_fnames:
        with open(ts_fname, 'rt') as in_f:
            tilespec = ujson.load(in_f)

        wafer_num = int(os.path.basename(ts_fname).split('_')[0].split('W')[1])
        sec_num = int(
            os.path.basename(ts_fname).split('.')[0].split('_')[1].split('Sec')
            [1])
        sections.append(
            Section.create_from_tilespec(tilespec,
                                         wafer_section=(wafer_num, sec_num)))

    logger.report_event("Initializing aligner", log_level=logging.INFO)
    aligner = StackAligner(conf)
    logger.report_event("Aligning sections", log_level=logging.INFO)
    aligner.align_sections(
        sections)  # will align and update the section tiles' transformations

    del aligner

    logger.end_process('main ending', rh_logger.ExitCode(0))
Code example #29
    def _gradient_descent(self, optimize_func, p0, grad_F_huber, args=None):
        
        def compute_cost_huber(optimize_func, cur_p, params, huber_delta):
            residuals = optimize_func(cur_p, *params)
            cost = np.empty_like(residuals)
            residuals_huber_mask = residuals <= huber_delta
            cost[residuals_huber_mask] = 0.5 * residuals[residuals_huber_mask]**2
            cost[~residuals_huber_mask] = huber_delta * residuals[~residuals_huber_mask] - (0.5 * huber_delta**2)
            return np.sum(cost)

        cur_p = p0
        #cur_cost = np.sum(optimize_func(cur_p, *args))
        cur_cost = compute_cost_huber(optimize_func, cur_p, args, self._huber_delta)
        logger.report_event("Initial cost: {}".format(cur_cost), log_level=logging.INFO)
        gamma = self._init_gamma

        for it in range(self._max_iterations):
            #print("Iteration {}".format(it))
            prev_p = cur_p
            prev_cost = cur_cost
            cur_p = prev_p - gamma * grad_F_huber(self._huber_delta, prev_p, *args)
            #print("New params: {}".format(cur_p))
            #cur_cost = np.sum(optimize_func(cur_p, *args))
            cur_cost = compute_cost_huber(optimize_func, cur_p, args, self._huber_delta)
            #print("New cost: {}".format(cur_cost))
            if it % 100 == 0:
                logger.report_event("iter {}: C: {}".format(it, cur_cost), log_level=logging.INFO)
            if cur_cost > prev_cost: # we took a bad step: undo it, scale down gamma, and start over
                #print("Backtracking step")
                cur_p = prev_p
                cur_cost = prev_cost
                gamma *= 0.5
            elif np.all(np.abs(cur_p - prev_p) <= self._eps): # We took a good step, but the change to the parameters vector is negligible
                break
            else: # We took a good step, try to increase the step size a bit
                gamma *= 1.1
            if gamma < self._min_gamma:
                break

        #print("The local minimum occurs at", cur_p)
        logger.report_event("Post-opt cost: {}".format(cur_cost), log_level=logging.INFO)
        return cur_p
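For a concrete sense of the Huber cost used above: residuals at or below huber_delta contribute quadratically while larger ones contribute only linearly, so a few bad matches cannot dominate the total. A standalone numeric check (the residual values are illustrative):

import numpy as np

def huber_cost(residuals, huber_delta):
    cost = np.empty_like(residuals)
    mask = residuals <= huber_delta
    cost[mask] = 0.5 * residuals[mask] ** 2
    cost[~mask] = huber_delta * residuals[~mask] - 0.5 * huber_delta ** 2
    return np.sum(cost)

# 0.5*0.5**2 + 0.5*1.0**2 + (1.0*10.0 - 0.5) = 0.125 + 0.5 + 9.5 = 10.125
print(huber_cost(np.array([0.5, 1.0, 10.0]), huber_delta=1.0))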
Code example #30
    def _gradient_descent(self):

        cur_p = self._cur_params_vector
        #cur_cost = np.sum(optimize_func(cur_p, *args))
        cur_cost = self._compute_cost_huber(cur_p, self._huber_delta)
        logger.report_event("Initial cost: {}".format(cur_cost),
                            log_level=logging.INFO)
        gamma = self._init_gamma

        for it in range(self._max_iterations):
            #print("Iteration {}".format(it))
            prev_p = cur_p
            prev_cost = cur_cost
            cur_p = prev_p - gamma * self._grad_F_huber(
                prev_p, self._huber_delta)
            #print("New params: {}".format(cur_p))
            #cur_cost = np.sum(optimize_func(cur_p, *args))
            cur_cost = self._compute_cost_huber(cur_p, self._huber_delta)
            #print("New cost: {}".format(cur_cost))
            if it % 100 == 0:
                logger.report_event("iter {}: C: {}".format(it, cur_cost),
                                    log_level=logging.INFO)
            if cur_cost > prev_cost:  # we took a bad step: undo it, scale down gamma, and start over
                #print("Backtracking step")
                cur_p = prev_p
                cur_cost = prev_cost
                gamma *= 0.5
            elif np.all(
                    np.abs(cur_p - prev_p) <= self._eps
            ):  # We took a good step, but the change to the parameters vector is negligible
                break
            else:  # We took a good step, try to increase the step size a bit
                gamma *= 1.1
            if gamma < self._min_gamma:
                break

        #print("The local minimum occurs at", cur_p)
        logger.report_event("Post-opt cost: {}".format(cur_cost),
                            log_level=logging.INFO)
        return cur_p