def initialize_storage(self, output: Path):
    """Create an empty ``mos.zarr`` store that later sections are appended to.

    The new dataset is given the attributes of ``self._ds``; its
    ``raw_meta`` attribute is wrapped in a one-element list before writing.

    Parameters
    ----------
    output
        Directory in which ``mos.zarr`` is written. An existing store at that
        path is replaced (``mode="w"``).
    """
    logger.info("Preparing mosaic output")
    skeleton = xr.Dataset()
    skeleton.attrs = self._ds.attrs
    skeleton.attrs["raw_meta"] = [skeleton.attrs["raw_meta"]]
    target = output / "mos.zarr"
    skeleton.to_zarr(target, mode="w")
def main(
    *,
    root_dir: Path,
    output_dir: Path,
    recipes: List,
    sections: List,
    reset: bool,
    cof_dist: Dict = None,
) -> None:
    """Run the requested stages of the mosaicing pipeline.

    ``main.config`` must have been set (a dict) before calling; its
    ``"output_dir"`` entry is overwritten and it is passed to
    ``owl_dev.setup``.

    Parameters
    ----------
    root_dir : Path
        Input directory handed to ``STPTMosaic``; must exist.
    output_dir : Path
        Parent output directory; results go to ``output_dir / root_dir.name``.
    recipes : List
        Stage names to run. Recognised values are ``"mosaic"``,
        ``"downsample"`` and ``"beadreg"``; they always run in that order.
    sections : List
        If non-empty, only sections whose name is in this list are stitched.
    reset : bool
        When true (and ``"mosaic"`` is requested) ``mos.zarr`` is re-created
        with ``initialize_storage`` before stitching.
    cof_dist : Dict, optional
        When given, replaces ``Settings.cof_dist``.

    Raises
    ------
    FileNotFoundError
        If ``root_dir`` does not exist.
    """
    logger.info("Pipeline started")
    out_root = output_dir / root_dir.name
    main.config["output_dir"] = f"{out_root}"
    owl_dev.setup(main.config)

    if cof_dist is not None:
        Settings.cof_dist = cof_dist

    if not root_dir.exists():
        raise FileNotFoundError(f"Directory {root_dir} does not exist.")

    # TODO: compute dark and flat
    mosaic = STPTMosaic(root_dir)

    if "mosaic" in recipes:
        if reset:
            mosaic.initialize_storage(out_root)
        for section in mosaic.sections():
            wanted = not sections or section.name in sections
            if not wanted:
                continue
            section.find_offsets()
            section.stitch(out_root)
            # after the first section, stage size is fixed: keep it on the
            # mosaic object
            mosaic.stage_size = section.stage_size

    if "downsample" in recipes:
        mosaic.downsample(out_root)

    if "beadreg" in recipes:
        store_path = out_root / "mos.zarr"
        find_beads(store_path)
        register_slices(store_path)

    logger.info("Pipeline completed")
def _compute_final_mosaic(self, z):
    """Assemble the temporary per-channel mosaics into one ``uint16`` array.

    Reads the groups ``mosaic/<name>/z=<n>/channel=<m>`` (each with ``raw``,
    ``overlap`` and ``pos_err`` arrays) and stacks them into a DataArray with
    dims ``("type", "z", "channel", "y", "x")``. The ``type`` axis holds:

    * ``"mosaic"``: ``raw / overlap``, rescaled as
      ``(value - Settings.bzero) / Settings.bscale``; 0 where ``overlap`` is 0;
    * ``"conf"``: ``overlap * 100``;
    * ``"err"``: ``pos_err * 100``.

    Values are clipped to the ``uint16`` range before casting.

    Parameters
    ----------
    z
        Zarr group holding the temporary mosaic.

    Returns
    -------
    xr.DataArray
        Dask-backed array with ``self._get_metadata()`` in its ``attrs``.
        Channel coordinates are numbered from 1.
    """
    logger.info("Creating final mosaic")
    uint16_max = np.iinfo(np.uint16).max

    def _ordered(group):
        # zarr yields subgroups sorted by *name* ("z=10" before "z=2");
        # sort on the integer after "=" so slices/channels keep numeric order.
        def _key(item):
            name = item[0]
            try:
                return (0, int(name.rsplit("=", 1)[-1]), name)
            except ValueError:
                return (1, 0, name)

        return [grp for _, grp in sorted(group.groups(), key=_key)]

    mos_raw = []
    mos_overlap = []
    mos_err = []
    for optical in _ordered(z[f"mosaic/{self.name}"]):
        channels = _ordered(optical)
        raw_sum = da.stack([da.from_zarr(ch["raw"]) for ch in channels])
        overlap = da.stack([da.from_zarr(ch["overlap"]) for ch in channels])
        err = da.stack([da.from_zarr(ch["pos_err"]) for ch in channels])

        covered = overlap > 0
        # Only divide where at least one tile contributed; uncovered pixels
        # would otherwise be 0/0 = NaN, whose integer cast is undefined.
        raw = raw_sum / da.where(covered, overlap, 1)
        raw = (raw - Settings.bzero) / Settings.bscale
        raw = da.where(covered, raw, 0)

        mos_raw.append(raw)
        mos_overlap.append(overlap * 100)
        mos_err.append(err * 100)

    mos = da.stack([da.stack(mos_raw), da.stack(mos_overlap), da.stack(mos_err)])
    mos = mos.rechunk((1, 1, 10, CHUNK_SIZE, CHUNK_SIZE))
    # Float -> uint16 conversion of out-of-range values is undefined in
    # numpy (typically wraps around); saturate explicitly instead.
    mos = mos.clip(0, uint16_max)
    _, nz, nch, ny, nx = mos.shape

    arr = xr.DataArray(
        mos.astype("uint16"),
        dims=("type", "z", "channel", "y", "x"),
        coords={
            "type": ["mosaic", "conf", "err"],
            "x": range(nx),
            "y": range(ny),
            "z": range(nz),
            "channel": range(1, nch + 1),
        },
    )
    arr.attrs.update(self._get_metadata())
    return arr
def _create_temporary_mosaic(self, conf, abs_pos, abs_err, output):
    """Paint every tile of the section into a temporary zarr store.

    For each optical slice and channel a group
    ``/mosaic/<name>/z=<slice>/channel=<channel>`` is created in
    ``<output>/<name>.zarr``. The group holds three float32 images (``raw``,
    ``pos_err``, ``overlap``) obtained from ``_get_image``.

    Each tile is divided by ``Settings.norm_val`` and passed through
    ``apply_geometric_transform``, together with the channel's dark and flat
    frames and ``Settings.cof_dist``. It is then handed to ``_mosaic`` at an
    integer origin equal to its absolute position minus the minimum position
    over all tiles. All ``_mosaic`` results of one slice are evaluated
    together with ``dask.compute``.

    Parameters
    ----------
    conf
        Passed through unchanged to ``_mosaic``.
    abs_pos
        Array of per-tile positions; column 0 is used as the y origin and
        column 1 as the x origin.
    abs_err
        Passed through unchanged to ``_mosaic``.
    output
        Directory receiving ``<name>.zarr``; an existing store is overwritten.

    Returns
    -------
    The root zarr group of the temporary store.
    """
    logger.info("Creating temporary mosaic")
    root = zarr.open(f"{output}/{self.name}.zarr", mode="w")
    flat_frames = read_calib(Settings.flat_file)
    dark_frames = read_calib(Settings.dark_file) / Settings.norm_val
    distortion = Settings.cof_dist
    y_min, x_min = np.min(abs_pos, axis=0)

    def _origin(idx):
        # integer (y, x) origin of tile ``idx`` inside the mosaic canvas
        return int(abs_pos[idx, 0] - y_min), int(abs_pos[idx, 1] - x_min)

    for sl in range(self.slices):
        pending = []
        for ch in range(self.channels):
            tiles = self.get_img_section(sl, ch)
            group = root.create_group(f"/mosaic/{self.name}/z={sl}/channel={ch}")
            corrected = [
                apply_geometric_transform(
                    tile.data / Settings.norm_val,
                    dark_frames[ch],
                    flat_frames[ch],
                    distortion,
                )
                for tile in tiles
            ]
            for imgtype in ("raw", "pos_err", "overlap"):
                canvas = _get_image(group, imgtype, self.stage_size, dtype="float32")
                for idx, tile_img in enumerate(corrected):
                    pending.append(
                        _mosaic(tile_img, ch, conf, _origin(idx), abs_err, imgtype, canvas)
                    )
        dask.compute(pending)
        logger.debug("Mosaic %s Slice %d done", self.name, sl)
    return root
def stitch(self, output: Path):
    """Stitch and save all images of this section into ``mos.zarr``.

    Skips the section when ``<output>/mos.zarr`` already contains it.
    Otherwise the section is processed in these steps:

    1. Compute the absolute tile positions.
    2. Build a temporary ``<output>/<name>.zarr`` mosaic.
    3. Reduce it with ``_compute_final_mosaic``.
    4. Append the result to ``mos.zarr`` under the section name.
    5. Delete the temporary store.

    Parameters
    ----------
    output
        Output directory; ``mos.zarr`` must already exist there.
    """
    store_path = output / "mos.zarr"
    existing = xr.open_zarr(store_path)
    # NOTE(review): the skip check uses self.name while the data is written
    # under self._section.name — presumably identical; confirm.
    if self.name in existing:
        logger.info("Section %s already done. Skipping.", self.name)
        return

    abs_pos, abs_err = self.compute_abspos()
    conf = self.get_distconf()
    if self.stage_size is None:
        # no canvas size known yet: derive it from this section's positions
        self.stage_size = self._stage_size(abs_pos)

    temp_store = self._create_temporary_mosaic(conf, abs_pos, abs_err, output)
    section_arr = self._compute_final_mosaic(temp_store)
    xr.Dataset({self._section.name: section_arr}).to_zarr(store_path, mode="a")

    # the temporary per-section mosaic is no longer needed
    shutil.rmtree(f"{output}/{self.name}.zarr", ignore_errors=True)
    logger.info("Mosaic saved %s", store_path)
def downsample(self, output: Path):
    """Downsample mosaic into a pyramid of groups inside ``mos.zarr``.

    For every factor in ``Settings.scales`` a group ``l.<factor>`` is
    (re)created. Its variables are ``ops.downsample`` applied to the
    previous level; the root (full-resolution) group is the input for the
    first factor.

    The root attributes then receive:

    * a ``multiscale`` description listing the root plus exactly the levels
      that were written;
    * the ``bscale``/``bzero`` values from ``Settings``.

    Parameters
    ----------
    output
        Location of output directory containing ``mos.zarr``.
    """
    logger.info("Downsampling mosaics")
    store_name = output / "mos.zarr"
    store = zarr.DirectoryStore(store_name)

    up = ""
    for factor in Settings.scales:
        logger.debug("Downsampling factor %d", factor)
        ds = xr.open_zarr(f"{store_name}", group=up)
        down = f"l.{factor}"
        # start every level from an empty group
        xr.Dataset().to_zarr(store, mode="w", group=down)
        for s in list(ds):
            logger.debug("Downsampling mos%d [%s]", factor, s)
            nds = xr.Dataset()
            nds[s] = ops.downsample(ds[s])
            nds.to_zarr(store, mode="a", group=down)
        logger.info("Downsampled mosaic saved %s:%s", store_name, down)
        up = down

    # Describe the levels actually written (previously a hard-coded list of
    # 2..32, which was wrong whenever Settings.scales differed).
    datasets = [{"path": "", "level": 1}]
    datasets.extend(
        {"path": f"l.{factor}", "level": factor} for factor in Settings.scales
    )

    arr = zarr.open(f"{store_name}", mode="r+")
    arr.attrs["multiscale"] = {
        "datasets": datasets,
        "metadata": {
            "method": "cv2.pyrDown",
            "version": cv2.__version__
        },
    }
    arr.attrs["bscale"] = Settings.bscale
    arr.attrs["bzero"] = Settings.bzero
def compute_abspos(self):  # noqa: C901
    """Compute absolute positions of images in the field of view.

    Relies on ``offset_mode``, ``absolute_ref_img`` and ``cube_means``,
    which are set by ``find_offsets``.

    * ``offset_mode == "default"``: positions are the stage coordinates
      (``self["XPos"]``, ``self["YPos"]``) relative to ``absolute_ref_img``,
      divided by ``Settings.mosaic_scale``. Every error is 15.
    * otherwise: ``compute_abspos_ref`` is run once for each image whose mean
      flux is above the median. Each solution is re-expressed relative to
      ``absolute_ref_img``. The per-image median across solutions is the
      position and the standard deviation is the error.

    The results are also stored in ``self._section.attrs`` as ``abs_pos``
    and ``abs_err``.

    Returns
    -------
    abs_pos, abs_err : np.ndarray
        One coordinate pair per image.
    """
    logger.info("Processing section %s", self._section.name)
    absolute_ref_img = self.absolute_ref_img

    if self.offset_mode == "default":
        dx = np.array(self["XPos"])
        dy = np.array(self["YPos"])
        abs_pos = []
        abs_err = []
        for i in range(len(dx)):
            abs_pos.append([
                (dx[i] - dx[absolute_ref_img]) / Settings.mosaic_scale,
                (dy[i] - dy[absolute_ref_img]) / Settings.mosaic_scale,
            ])
            abs_err.append([15.0, 15.0])
        # NOTE(review): default mode is also chosen for faint sections, so
        # this message is not always accurate.
        logger.info("Displacements too large, resorting to default grid")
        abs_pos = np.array(abs_pos)
        abs_err = np.array(abs_err)
        self._section.attrs["abs_pos"] = abs_pos.tolist()
        self._section.attrs["abs_err"] = abs_err.tolist()
        self._section.attrs["default_displacements"] = [{
            "default_x": 0.0,
            "dev_x": 0.0,
            "default_y": 0.0,
            "dev_y": 0.0
        }]
        self._x_scale = Settings.mosaic_scale
        self._y_scale = Settings.mosaic_scale
        return abs_pos, abs_err

    self.compute_pairs()
    scale_x, scale_y = self.scale
    dx0, dy0, delta_x, delta_y = self.find_grid()
    default_x, dev_x, default_y, dev_y = self.get_default_displacement()

    # Error thresholds: compare the measured pixel displacements with the
    # scaled micron (stage) displacements.
    keys = list(self._px.keys())
    px_x = np.array([self._px[k][0] for k in keys])
    px_y = np.array([self._px[k][1] for k in keys])
    mu_x = np.array([self._mu[k][0] / scale_x for k in keys])
    mu_y = np.array([self._mu[k][1] / scale_y for k in keys])
    avg_f = np.array([self._avg[k] for k in keys])

    # Pairs with an above-average x displacement and enough flux give the
    # long-axis error in x and the short-axis error in y (1.48 * MAD is
    # presumably meant as a robust sigma estimate).
    ind = np.where((np.abs(px_x) > np.mean(np.abs(px_x))) & (avg_f > 0.05))[0]
    if len(ind) > 3:
        elongx = 1.48 * mad(np.sqrt((px_x - mu_x)[ind]**2))
        eshorty = 1.48 * mad(np.sqrt((px_y - mu_y)[ind]**2))
    else:
        elongx = 10
        eshorty = 10
    # ... and likewise for pairs displaced mostly along y.
    ind = np.where((np.abs(px_y) > np.mean(np.abs(px_y))) & (avg_f > 0.05))[0]
    if len(ind) > 3:
        elongy = 1.48 * mad(np.sqrt((px_y - mu_y)[ind]**2))
        eshortx = 1.48 * mad(np.sqrt((px_x - mu_x)[ind]**2))
    else:
        elongy = 10.0
        eshortx = 10.0

    error_long_threshold = 3.0 * np.max([elongx, elongy]).clip(2, 30)
    error_short_threshold = 3.0 * np.max([eshortx, eshorty]).clip(2, 30)
    logger.debug("Scaling error threshold: l:{0:.3f} s:{1:.3f}".format(
        error_long_threshold, error_short_threshold))

    ref_imgs = np.where(self.cube_means > np.median(self.cube_means))[0]
    if len(ref_imgs) == 0:
        # Nothing is strictly brighter than the median (max == median); fall
        # back to the brightest image instead of taking a median of nothing.
        ref_imgs = [absolute_ref_img]

    accumulated_pos = []
    for this_ref in ref_imgs:
        temp = self.compute_abspos_ref(
            dx0,
            dy0,
            default_x,
            dev_x,
            default_y,
            dev_y,
            scale_x,
            scale_y,
            this_ref,
            error_long_threshold,
            error_short_threshold,
        )
        pos = np.array(temp[0])
        # express every solution relative to the common reference image
        accumulated_pos.append(pos - pos[absolute_ref_img])

    accumulated_pos = np.array(accumulated_pos)
    abs_pos = np.median(accumulated_pos, 0)
    abs_err = np.std(accumulated_pos, 0)

    self._section.attrs["abs_pos"] = abs_pos.tolist()
    self._section.attrs["abs_err"] = abs_err.tolist()
    return abs_pos, abs_err
def find_offsets(self):  # noqa: C901
    """Calculate offsets between all pairs of overlapping images.

    Sets ``self.cube_means``, ``self.absolute_ref_img`` (index of the image
    with the largest mean), ``self.mean_ref_img``, ``self.offset_mode`` and
    ``self._offsets``.

    ``offset_mode`` becomes ``"default"`` in either of these cases:

    * the grid spacing from ``find_grid`` divided by
      ``Settings.mosaic_scale`` exceeds 1950 in x or y;
    * the largest mean flux is below 0.05.

    In default mode the expected stage displacements are used directly.
    Otherwise (``"sampled"``) every pair of direct grid neighbours is
    registered with ``find_overlap_conf`` on the current dask client.

    Each entry of ``self._offsets`` is ``[ref_img, obj_img, res, sign]``.
    ``res`` is the ``find_overlap_conf`` result in sampled mode and
    ``[dx, dy, 0.0, 0.0]`` in default mode.
    """
    client = Client.current()
    results = []
    logger.info("Processing section %s", self._section.name)

    # NOTE(review): the original marks this as a debug setting — offsets are
    # measured on channel index -1 while the flat/dark frames below use
    # Settings.channel_to_use - 1; confirm which channel is intended.
    img_cube = self.get_img_section(0, -1)

    # Calculate confidence map. Only needs to be done once per section
    dist_conf = self.get_distconf()
    flat = read_calib(Settings.flat_file)[Settings.channel_to_use - 1]
    dark = (read_calib(Settings.dark_file)[Settings.channel_to_use - 1] /
            Settings.norm_val)
    flat = flat.persist()
    dark = dark.persist()

    dx0, dy0, delta_x, delta_y = self.find_grid()
    dx_mos, dy_mos = self.get_mos_pos()

    # The reference image is the one with the max integrated flux. It is
    # computed here so that slices with (almost) no data can skip the
    # pairwise registration below.
    img_cube_stack = da.stack(img_cube)
    cube_totals = img_cube_stack.sum(axis=(1, 2)).compute()
    # mean per pixel, using the real tile size (was hard-coded as 2080**2)
    n_pixels = img_cube_stack.shape[1] * img_cube_stack.shape[2]
    self.cube_means = cube_totals / n_pixels
    self.absolute_ref_img = self.cube_means.argmax()
    self.mean_ref_img = self.cube_means.max()
    logger.info("Max mean: {0:.5f}".format(self.mean_ref_img))

    # If the default displacements are too large for overlaps, or the
    # section is too faint, stick with the default scale.
    self.offset_mode = "sampled"
    if delta_x / Settings.mosaic_scale > 1950.0:
        logger.info("Displacement in X too large: {0:.1f}".format(
            delta_x / Settings.mosaic_scale))
        self.offset_mode = "default"
    if delta_y / Settings.mosaic_scale > 1950.0:
        logger.info("Displacement in Y too large: {0:.1f}".format(
            delta_y / Settings.mosaic_scale))
        self.offset_mode = "default"
    if self.mean_ref_img < 0.05:
        logger.info("Avg. flux too low: {0:.3f}<0.05".format(
            self.mean_ref_img))
        self.offset_mode = "default"

    for i, img in enumerate(img_cube):
        r = np.sqrt((dx0 - dx0[i])**2 + (dy0 - dy0[i])**2)
        # direct neighbours only (grid distance 1, no diagonals)
        neighbours = np.where((r <= 1.0) & (r > 0))[0].tolist()
        for this_img in neighbours:
            if i > this_img:
                continue  # each pair is handled once, from its lower index
            desp = [
                (dx_mos[this_img] - dx_mos[i]) / Settings.mosaic_scale,
                (dy_mos[this_img] - dy_mos[i]) / Settings.mosaic_scale,
            ]
            # trying to do always positive displacements
            if (desp[0] < -100) or (desp[1] < -100):
                desp[0] *= -1
                desp[1] *= -1
                obj_img, im_obj = i, img
                ref_img, im_ref = this_img, img_cube[this_img]
                sign = -1
            else:
                obj_img, im_obj = this_img, img_cube[this_img]
                ref_img, im_ref = i, img
                sign = 1

            if self.offset_mode == "sampled":
                im1 = apply_geometric_transform(
                    im_ref.data / Settings.norm_val, dark, flat,
                    Settings.cof_dist)
                im2 = apply_geometric_transform(
                    im_obj.data / Settings.norm_val, dark, flat,
                    Settings.cof_dist)
                res = delayed(find_overlap_conf)(im1, im2, dist_conf,
                                                 dist_conf, desp)
                results.append(_sink(ref_img, obj_img, res, sign))
            else:
                logger.debug(
                    "Initial offsets i: %d j: %d dx: %d dy: %d",
                    ref_img,
                    obj_img,
                    desp[0],
                    desp[1],
                )
                results.append([ref_img, obj_img, desp[0], desp[1], sign])

    offsets = []
    if self.offset_mode == "sampled":
        futures = client.compute(results)
        for fut in as_completed(futures):
            i, j, res, sign = fut.result()
            dx, dy, mi, avf = res.x, res.y, res.mi, res.avg_flux
            offsets.append([i, j, res, sign])
            logger.debug(
                "Section %s offsets i: %d j: %d dx: %d dy: %d mi: %f avg_f: %f sign: %d",
                self._section.name,
                i,
                j,
                dx,
                dy,
                mi,
                avf,
                sign,
            )
    else:
        for i, j, dx, dy, sign in results:
            offsets.append([i, j, [dx, dy, 0.0, 0.0], sign])
            logger.debug(
                "Section %s offsets i: %d j: %d dx: %d dy: %d mi: %f avg_f: %f sign: %d",
                self._section.name,
                i,
                j,
                dx,
                dy,
                0.0,
                0.0,
                sign,
            )
    self._offsets = offsets
def register_slices(mos_zarr: Path):  # noqa: C901
    """Register every (physical, optical) slice to the first one using beads.

    Works on consecutive slice pairs in two passes:

    1. All detected beads are cross-matched with ``_match_cats`` to get a
       first slice-to-slice shift. The cumulative shifts place every bead in
       a common frame, where ``bead_collection`` groups repeated detections.
    2. Only beads seen at least ``max(optical z) + 2`` times are
       cross-matched again, giving the final shifts and their errors.

    A pair where either catalogue is empty gets a zero shift; in the second
    pass it also gets an error of 100.

    The result is written to the store's root attributes as ``cube_reg``. It
    holds per-slice cumulative (``abs_*``) and slice-to-slice (``rel_*``)
    shifts and errors.

    Parameters
    ----------
    mos_zarr
        Path to ``mos.zarr``; the bead attributes from ``find_beads`` must
        already be present.
    """
    # this is to store the beads later
    zarr_store = zarr.open(f"{mos_zarr}", mode="a")
    mos_full = xr.open_zarr(f"{mos_zarr}", group="")

    # putting all (physical, optical) slices on a single list
    physical_slices = []
    optical_slices = []
    for this_slice in list(mos_full):
        for this_optical in mos_full[this_slice].z.values:
            physical_slices.append(this_slice)
            optical_slices.append(this_optical)

    def _label(idx):
        # e.g. "<physical>_Z003"
        return physical_slices[idx] + "_Z{0:03d}".format(optical_slices[idx])

    # first pass crossmatch, taking into account all beads
    dx = [0.0]  # these store the slice to slice offset
    dy = [0.0]
    logger.info("1st pass slice shifts")
    for i in range(1, len(physical_slices)):
        # We compare each slice (_t) with the previous one (_r)
        _, x_t, y_t, _, e_t = _get_beads(mos_full[physical_slices[i]].attrs,
                                         optical_slices[i])
        _, x_r, y_r, _, e_r = _get_beads(
            mos_full[physical_slices[i - 1]].attrs, optical_slices[i - 1])
        if len(x_t) > 0 and len(x_r) > 0:
            dxt, dyt, i_rt, i_tr = _match_cats(x_r, y_r, e_r, x_t, y_t, e_t)
            dx.append(dxt)
            dy.append(dyt)
            dr = np.sqrt((y_r[i_rt] - y_t[i_tr] - dyt)**2 +
                         (x_r[i_rt] - x_t[i_tr] - dxt)**2)
            dr0 = np.sqrt((y_r[i_rt] - y_t[i_tr])**2 +
                          (x_r[i_rt] - x_t[i_tr])**2)
            n_matched = len(i_tr)
        else:
            dx.append(0.0)
            dy.append(0.0)
            dxt = dyt = dr = dr0 = 0
            n_matched = 0
        logger.info(_label(i - 1) + ":" + _label(i) + ":" +
                    " {0:d} ".format(n_matched) +
                    "{0:.1f} {1:.1f} ".format(dxt, dyt) +
                    "{0:.1f} {1:.1f} ".format(np.median(dr), np.median(dr0)))

    # now that we know the slice to slice offset, we construct the catalogue
    # of all the beads
    bb = bead_collection()
    for i in range(len(physical_slices)):
        ref_slice = _label(i)
        # because dx,dy are slice to slice, the total displacement
        # is the sum of all the previous
        dx_t = np.sum(dx[0:i + 1])
        dy_t = np.sum(dy[0:i + 1])
        ind_r, x_r, y_r, _, _ = _get_beads(mos_full[physical_slices[i]].attrs,
                                           optical_slices[i])
        id_str = np.array(
            [ref_slice + ":{0:05d}".format(int(this_id)) for this_id in ind_r])
        for j in range(len(x_r)):
            bb.add_bead(x_r[j] + dx_t, y_r[j] + dy_t, x_r[j], y_r[j],
                        id_str[j])
        # at each slice we check matching beads and update avg coords
        bb.update_coords()

    # Now that we know which objects can be seen in more than one slice,
    # we re-compute the offsets but only using the beads that appear
    # in at least all the optical slices and 2 physical
    min_num_dets = np.max(np.array(optical_slices)) + 2
    good_beads = []
    for i in range(len(bb.x)):
        if bb.n[i] >= min_num_dets:
            good_beads.extend(bb.id_list[i])
    # cleaning out repeated ids
    good_beads = list(set(good_beads))

    # here we store the new displacements, and a measurement of error
    dx2 = [0.0]
    dy2 = [0.0]
    dd2 = [0.0]
    logger.info("2nd pass slice shifts")
    for i in range(1, len(physical_slices)):
        this_slice = _label(i)
        ref_slice = _label(i - 1)
        # only beads that have high reps
        _, x_t, y_t, e_t = _get_good_beads(mos_full, physical_slices[i],
                                           optical_slices[i], good_beads)
        _, x_r, y_r, e_r = _get_good_beads(mos_full, physical_slices[i - 1],
                                           optical_slices[i - 1], good_beads)
        # both catalogues must be non-empty, as in the first pass
        if len(x_t) > 0 and len(x_r) > 0:
            dxt, dyt, edt, i_rt, i_tr = _match_cats(x_r,
                                                    y_r,
                                                    e_r,
                                                    x_t,
                                                    y_t,
                                                    e_t,
                                                    errors=True)
            dx2.append(dxt)
            dy2.append(dyt)
            dd2.append(edt)
            dr = np.sqrt((y_r[i_rt] - y_t[i_tr] - dyt)**2 +
                         (x_r[i_rt] - x_t[i_tr] - dxt)**2)
            dr0 = np.sqrt((y_r[i_rt] - y_t[i_tr])**2 +
                          (x_r[i_rt] - x_t[i_tr])**2)
            logger.info(
                ref_slice + ":" + this_slice + " {0:d} ".format(len(i_tr)) +
                "{0:.1f} {1:.1f} ".format(dxt, dyt) +
                "{0:.1f} {1:.1f} ".format(np.median(dr), np.median(dr0)))
        else:
            # no usable beads: zero shift with a large error. The match
            # count is 0 (previously a stale value from an earlier pair).
            dx2.append(0.0)
            dy2.append(0.0)
            dd2.append(100.0)
            logger.info(ref_slice + ":" + this_slice +
                        " {0:d} ".format(0) +
                        "{0:.1f} {1:.1f} ".format(0.0, 0.0) +
                        "{0:.1f} {1:.1f} ".format(0.0, 0.0))

    # Now we store all the displacements as attrs
    cube_reg = {
        "abs_dx": [],
        "abs_dy": [],
        "abs_err": [],
        "rel_dx": [],
        "rel_dy": [],
        "rel_err": [],
        "slice": [],
        "opt_z": [],
    }
    for i in range(len(physical_slices)):
        cube_reg["slice"].append(physical_slices[i])
        cube_reg["opt_z"].append(float(optical_slices[i]))
        # because dx,dy are slice to slice, the total displacement
        # is the sum of all the previous
        cube_reg["abs_dx"].append(np.sum(dx2[0:i + 1]))
        cube_reg["abs_dy"].append(np.sum(dy2[0:i + 1]))
        cube_reg["abs_err"].append(np.sqrt(np.sum(np.array(dd2[0:i + 1])**2)))
        cube_reg["rel_dx"].append(dx2[i])
        cube_reg["rel_dy"].append(dy2[i])
        cube_reg["rel_err"].append(dd2[i])

    zarr_store.attrs["cube_reg"] = cube_reg
def find_beads(mos_zarr: Path):  # noqa: C901
    """Find and fit beads in all physical and optical slices of ``mos.zarr``.

    For every physical slice and every optical ``z``:

    1. Detect candidates on the ``l.<Settings.zoom_level>`` level, using the
       channel-averaged ``mosaic`` image and the ``conf`` plane of
       ``Settings.channel_to_use``.
    2. Filter the candidates with ``_fit_bead_1stpass`` on that level.
    3. Refit the survivors with ``_fit_bead_2ndpass`` on full resolution.

    Beads accepted by ``_check_bead`` are collected per physical slice. They
    are written as attributes of that slice's array in the root zarr group,
    named via ``bead_par_to_attr_name`` (e.g. ``bead_x``, ``bead_z``).

    Parameters
    ----------
    mos_zarr
        Path to ``mos.zarr``; the downsampled level must already exist.
    """
    # this is to store the beads later
    zarr_store = zarr.open(f"{mos_zarr}", mode="a")

    # conversion from dict headers to more informative
    # metadata names
    bead_par_to_attr_name = {
        "bead_id": "bead_id",
        "conv": "bead_conv",
        "corner": "bead_cutout_corner",
        "err": "bead_centre_err",
        "fit": "bead_fit_pars",
        "rad": "bead_rad",
        "x": "bead_x",
        "y": "bead_y",
        "z": "bead_z",
    }

    mos_full = xr.open_zarr(f"{mos_zarr}", group="")
    mos_zoom = xr.open_zarr(f"{mos_zarr}", group=f"l.{Settings.zoom_level}")
    # Dataset.sizes maps dim name -> length (mapping access on
    # Dataset.dims is deprecated)
    full_shape = (mos_full.sizes["y"], mos_full.sizes["x"])

    for this_slice in list(mos_zoom):
        bead_cat = {}
        for this_optical in list(mos_full.z.values):
            logger.info("Analysing beads in " + this_slice +
                        " Z{0:03d}".format(this_optical))
            im = (mos_zoom[this_slice].sel(z=this_optical,
                                           type="mosaic").mean(dim="channel"))
            conf = mos_zoom[this_slice].sel(z=this_optical,
                                            channel=Settings.channel_to_use,
                                            type="conf")

            # Img stats
            pedestal, im_std = image_stats(im.data, conf.data).compute()

            # Detection of all features with bead size
            labels, good_objects, good_cx, good_cy = det_features(
                im.data, pedestal, im_std)
            da_labels = da.from_array(labels).persist()
            logger.debug("Found {0:d} preliminary beads".format(
                len(good_objects)))

            logger.debug("Filtering preliminary detections")
            first_pass = [
                _fit_bead_1stpass(im.data, da_labels, cx, cy, pedestal,
                                  im_std, obj)
                for cx, cy, obj in zip(good_cx, good_cy, good_objects)
            ]
            beads_1st_iter = dask.compute(first_pass)[0]
            logger.debug("First pass completed")

            # resampling to full arr
            full_labels = delayed(ndi.zoom)(labels, Settings.zoom_level,
                                            order=0)
            full_labels = da.from_delayed(full_labels,
                                          shape=full_shape,
                                          dtype="int")
            full_im = (mos_full[this_slice].sel(
                z=this_optical, type="mosaic").mean(dim="channel").data)
            # conf and error are the same across channels
            full_conf = (mos_full[this_slice].sel(
                z=this_optical, channel=Settings.channel_to_use,
                type="conf").data)
            full_err = (mos_full[this_slice].sel(
                z=this_optical, channel=Settings.channel_to_use,
                type="err").data)
            logger.debug(
                "labels: {0:d},{1:d} im: {2:d},{3:d} conf: {4:d},{5:d} "
                "err: {6:d},{7:d}".format(
                    *full_labels.shape,
                    *full_im.shape,
                    *full_conf.shape,
                    *full_err.shape,
                ))

            logger.debug("Fitting all beads...")
            second_pass = []
            for first, obj in zip(beads_1st_iter, good_objects):
                # presumably a success flag from the first fit; `not` also
                # handles numpy booleans, which `is False` would let through
                if not first[-1]:
                    continue
                # element 2 scaled to full resolution must not exceed
                # Settings.feature_size[1] (presumably a size limit)
                if first[2] * Settings.zoom_level > Settings.feature_size[1]:
                    continue
                second_pass.append(
                    _fit_bead_2ndpass(
                        full_im,
                        full_conf,
                        full_err,
                        full_labels,
                        first,
                        pedestal,
                        im_std,
                        obj,
                    ))
            all_beads = dask.compute(second_pass)[0]

            # Store all good beads, removing duplicates and bad fits. Beads
            # go in the attrs of the physical slice, so the optical slice is
            # recorded as "z".
            done_x = [0]
            done_y = [0]
            for this_bead, bead_err in all_beads:
                this_bead["err"] = bead_err
                if not _check_bead(this_bead, done_x, done_y, full_shape):
                    continue
                done_x.append(this_bead["x"])
                done_y.append(this_bead["y"])
                this_bead["z"] = float(this_optical)
                for key, value in this_bead.items():
                    bead_cat.setdefault(key, []).append(value)

        # Store results as attrs in the full res slice (single write)
        if bead_cat:
            zarr_store[this_slice].attrs.update({
                bead_par_to_attr_name[key]: values
                for key, values in bead_cat.items()
            })