예제 #1
0
    def initialize_storage(self, output: Path):
        """Create a fresh, empty ``mos.zarr`` store inside *output*.

        Any existing store at that location is overwritten (``mode="w"``).
        The new store's attributes are a copy of this mosaic's dataset
        attributes, with the ``raw_meta`` entry wrapped in a one-element list.

        Parameters
        ----------
        output
            Directory in which ``mos.zarr`` is created.
        """
        logger.info("Preparing mosaic output")

        # Work on a private copy so the source attributes are not modified.
        store_attrs = dict(self._ds.attrs)
        store_attrs["raw_meta"] = [store_attrs["raw_meta"]]

        empty_store = xr.Dataset(attrs=store_attrs)
        empty_store.to_zarr(output / "mos.zarr", mode="w")
예제 #2
0
def main(
    *,
    root_dir: Path,
    output_dir: Path,
    recipes: List,
    sections: List,
    reset: bool,
    cof_dist: Dict = None,
) -> None:
    """Run the requested pipeline steps on one acquisition directory.

    Outputs are written below ``output_dir / root_dir.name``.

    Parameters
    ----------
    root_dir : Path
        Input directory handed to ``STPTMosaic``. Its final path component
        names the output sub-directory.
    output_dir : Path
        Parent directory for the outputs.
    recipes : List
        Steps to run. ``"mosaic"``, ``"downsample"`` and ``"beadreg"`` are
        recognised, and they always execute in that order.
    sections : List
        Section names to stitch in the ``"mosaic"`` step. An empty (falsy)
        value selects every section.
    reset : bool
        When true, ``initialize_storage`` is called on the mosaic before
        stitching. Only consulted by the ``"mosaic"`` step.
    cof_dist : Dict, optional
        When given, replaces ``Settings.cof_dist``.

    Raises
    ------
    FileNotFoundError
        If *root_dir* does not exist.
    """
    logger.info("Pipeline started")

    out_root = output_dir / root_dir.name
    # ``main.config`` is attached to this function elsewhere (presumably by
    # a CLI/config decorator); the output location is injected before setup.
    main.config["output_dir"] = f"{out_root}"
    owl_dev.setup(main.config)

    if cof_dist is not None:
        Settings.cof_dist = cof_dist

    if not root_dir.exists():
        raise FileNotFoundError(f"Directory {root_dir} does not exist.")

    # TODO: compute dark and flat

    mosaic = STPTMosaic(root_dir)

    if "mosaic" in recipes:
        if reset:
            mosaic.initialize_storage(out_root)

        selected = (
            sec for sec in mosaic.sections()
            if not sections or sec.name in sections
        )
        for section in selected:
            section.find_offsets()
            section.stitch(out_root)
            # Once a section is stitched its stage size is known; keep it on
            # the mosaic so it is fixed for the following sections.
            mosaic.stage_size = section.stage_size

    if "downsample" in recipes:
        mosaic.downsample(out_root)

    if "beadreg" in recipes:
        store = out_root / "mos.zarr"
        find_beads(store)
        register_slices(store)

    logger.info("Pipeline completed")
예제 #3
0
    def _compute_final_mosaic(self, z):
        """Combine the temporary per-slice/per-channel images into one array.

        For every slice group under ``mosaic/<name>`` in *z*, three planes are
        built per channel:

        * ``mosaic`` -- ``raw / overlap``, rescaled as
          ``(value - Settings.bzero) / Settings.bscale``;
        * ``conf`` -- ``overlap * 100``;
        * ``err`` -- ``pos_err * 100``.

        Slice and channel groups are stacked in numeric index order.

        Parameters
        ----------
        z
            zarr group holding the temporary mosaic.

        Returns
        -------
        xarray.DataArray
            Dask-backed ``uint16`` array with dims ``(type, z, channel, y, x)``,
            channels numbered from 1, and ``self._get_metadata()`` merged into
            its attrs.
        """

        def _ordered_children(group):
            # zarr yields child groups sorted by *name*, so "z=10" would come
            # before "z=2". Order "<prefix>=<int>" names (the format written by
            # _create_temporary_mosaic) by their integer part; any other
            # names follow, in name order.
            def _key(item):
                _, _, index = item[0].rpartition("=")
                if index.isdecimal():
                    return (0, int(index), "")
                return (1, 0, item[0])

            return [child for _, child in sorted(group.groups(), key=_key)]

        logger.info("Creating final mosaic")
        mos_raw = []
        mos_overlap = []
        mos_err = []

        for zgroup in _ordered_children(z[f"mosaic/{self.name}"]):
            channels = _ordered_children(zgroup)

            raw = da.stack([
                da.from_zarr(ch["raw"]) / da.from_zarr(ch["overlap"])
                for ch in channels
            ])
            raw = (raw - Settings.bzero) / Settings.bscale
            overlap = da.stack(
                [da.from_zarr(ch["overlap"]) for ch in channels]) * 100
            err = da.stack(
                [da.from_zarr(ch["pos_err"]) for ch in channels]) * 100

            mos_raw.append(raw)
            mos_overlap.append(overlap)
            mos_err.append(err)

        mos = da.stack(
            [da.stack(mos_raw), da.stack(mos_overlap), da.stack(mos_err)]
        ).rechunk((1, 1, 10, CHUNK_SIZE, CHUNK_SIZE))
        _, nz, nch, ny, nx = mos.shape

        result = xr.DataArray(
            mos.astype("uint16"),
            dims=("type", "z", "channel", "y", "x"),
            coords={
                "type": ["mosaic", "conf", "err"],
                "x": range(nx),
                "y": range(ny),
                "z": range(nz),
                "channel": range(1, nch + 1),
            },
        )
        result.attrs.update(self._get_metadata())
        return result
예제 #4
0
    def _create_temporary_mosaic(self, conf, abs_pos, abs_err, output):
        """Paint every tile of this section into a scratch zarr store.

        The store ``<output>/<name>.zarr`` is (re)created (``mode="w"``). For
        each optical slice and channel a group
        ``/mosaic/<name>/z=<slice>/channel=<channel>`` receives three float32
        canvases (``raw``, ``pos_err`` and ``overlap``), filled by ``_mosaic``.
        Each tile is divided by ``Settings.norm_val`` and passed through
        ``apply_geometric_transform`` with that channel's dark/flat frames and
        ``Settings.cof_dist``. It is then placed at its absolute position,
        shifted so the per-axis minimum position maps to 0.

        All ``_mosaic`` calls of one optical slice are evaluated together
        with ``dask.compute`` before moving on to the next slice.

        Returns
        -------
        The opened scratch zarr group.
        """
        logger.info("Creating temporary mosaic")
        store = zarr.open(f"{output}/{self.name}.zarr", mode="w")

        flat = read_calib(Settings.flat_file)
        dark = read_calib(Settings.dark_file) / Settings.norm_val
        distortion = Settings.cof_dist
        origin_y, origin_x = np.min(abs_pos, axis=0)

        for sl in range(self.slices):
            pending = []
            for ch in range(self.channels):
                tiles = self.get_img_section(sl, ch)
                group = store.create_group(
                    f"/mosaic/{self.name}/z={sl}/channel={ch}")

                corrected = [
                    apply_geometric_transform(tile.data / Settings.norm_val,
                                              dark[ch], flat[ch], distortion)
                    for tile in tiles
                ]

                for imgtype in ("raw", "pos_err", "overlap"):
                    canvas = _get_image(group, imgtype, self.stage_size,
                                        dtype="float32")
                    for idx, tile_img in enumerate(corrected):
                        corner = (
                            int(abs_pos[idx, 0] - origin_y),
                            int(abs_pos[idx, 1] - origin_x),
                        )
                        pending.append(
                            _mosaic(tile_img, ch, conf, corner, abs_err,
                                    imgtype, canvas))
            dask.compute(pending)
            logger.debug("Mosaic %s Slice %d done", self.name, sl)
        return store
예제 #5
0
    def stitch(self, output: Path):
        """Stitch this section and append it to ``output / "mos.zarr"``.

        Sections already present in the store are skipped. Otherwise the
        absolute tile positions are computed and a temporary per-section
        mosaic is painted. The resulting array is appended to the shared
        store under the section's name, and the temporary store is then
        removed (best effort).

        Parameters
        ----------
        output
            Output directory containing ``mos.zarr``.
        """
        store_path = output / "mos.zarr"

        # NOTE(review): presence is checked with self.name but the data is
        # written under self._section.name; presumably identical -- confirm.
        if self.name in xr.open_zarr(store_path):
            logger.info("Section %s already done. Skipping.", self.name)
            return

        abs_pos, abs_err = self.compute_abspos()
        distconf = self.get_distconf()

        # Only derive the canvas size when none was provided beforehand.
        if self.stage_size is None:
            self.stage_size = self._stage_size(abs_pos)

        scratch = self._create_temporary_mosaic(distconf, abs_pos, abs_err,
                                                output)
        final = self._compute_final_mosaic(scratch)
        xr.Dataset({self._section.name: final}).to_zarr(store_path, mode="a")

        # Drop the temporary mosaic; failures here are deliberately ignored.
        shutil.rmtree(f"{output}/{self.name}.zarr", ignore_errors=True)
        logger.info("Mosaic saved %s", store_path)
예제 #6
0
    def downsample(self, output: Path):
        """Downsample mosaic.

        For each factor in ``Settings.scales`` a group ``l.<factor>`` is
        created (or wiped) in ``output / "mos.zarr"``. Every variable of the
        previous level is passed through ``ops.downsample`` and written into
        it. The previous level is the root (full-resolution) group for the
        first factor. Afterwards the root group's attrs receive a
        ``multiscale`` description of the levels that were written, plus
        ``bscale`` and ``bzero``.

        Parameters
        ----------
        output
            Location of output directory
        """
        logger.info("Downsampling mosaics")
        store_name = output / "mos.zarr"
        store = zarr.DirectoryStore(store_name)

        # Level 1 is the full-resolution root group.
        datasets = [{"path": "", "level": 1}]
        up = ""
        for factor in Settings.scales:
            logger.debug("Downsampling factor %d", factor)
            ds = xr.open_zarr(f"{store_name}", group=up)
            down = f"l.{factor}"
            # Create (or reset) the target group before filling it.
            xr.Dataset().to_zarr(store, mode="w", group=down)

            for s in list(ds):
                logger.debug("Downsampling mos%d [%s]", factor, s)
                nds = xr.Dataset()
                nds[s] = ops.downsample(ds[s])
                nds.to_zarr(store, mode="a", group=down)
            logger.info("Downsampled mosaic saved %s:%s", store_name, down)
            # int() keeps the attrs JSON-serialisable for numpy integers.
            datasets.append({"path": down, "level": int(factor)})
            up = down

        root = zarr.open(f"{store_name}", mode="r+")
        # Describe the levels actually written instead of a fixed list, so
        # the metadata stays correct whatever Settings.scales contains.
        root.attrs["multiscale"] = {
            "datasets": datasets,
            "metadata": {
                "method": "cv2.pyrDown",
                "version": cv2.__version__
            },
        }
        root.attrs["bscale"] = Settings.bscale
        root.attrs["bzero"] = Settings.bzero
예제 #7
0
    def compute_abspos(self):  # noqa: C901
        """Compute absolute positions of images in the field of view.

        Reads ``self.absolute_ref_img``, ``self.offset_mode`` and (in sampled
        mode) ``self.cube_means``, all of which are set by ``find_offsets``.

        * ``offset_mode == "default"``: positions are the stage coordinates
          ``XPos``/``YPos`` relative to the reference image, divided by
          ``Settings.mosaic_scale``. Every image gets an error of 15.0.
        * otherwise: pairwise offsets are turned into error thresholds, and
          ``compute_abspos_ref`` is solved once per reference image whose mean
          flux is above the median. If there is no such image, only the
          absolute reference image is used. Each solution is re-expressed
          relative to the absolute reference image. The per-image median and
          standard deviation across solutions give the position and error.

        The results are also stored in ``self._section.attrs``.

        Returns
        -------
        abs_pos, abs_err : numpy.ndarray
            Per-image positions and their uncertainties.
        """
        logger.info("Processing section %s", self._section.name)

        # Index of the reference image (largest mean flux), set in
        # find_offsets.
        absolute_ref_img = self.absolute_ref_img

        if self.offset_mode == "default":
            # Offsets were not measured: use the stage positions directly.
            dx = np.array(self["XPos"])
            dy = np.array(self["YPos"])

            abs_pos = np.array([
                [
                    (dx[i] - dx[absolute_ref_img]) / Settings.mosaic_scale,
                    (dy[i] - dy[absolute_ref_img]) / Settings.mosaic_scale,
                ]
                for i in range(len(dx))
            ])
            abs_err = np.array([[15.0, 15.0] for _ in range(len(dx))])

            logger.info("Displacements too large, resorting to default grid")

            self._section.attrs["abs_pos"] = abs_pos.tolist()
            self._section.attrs["abs_err"] = abs_err.tolist()
            self._section.attrs["default_displacements"] = [{
                "default_x": 0.0,
                "dev_x": 0.0,
                "default_y": 0.0,
                "dev_y": 0.0
            }]
            self._x_scale = Settings.mosaic_scale
            self._y_scale = Settings.mosaic_scale

            return abs_pos, abs_err

        self.compute_pairs()
        scale_x, scale_y = self.scale
        dx0, dy0, _, _ = self.find_grid()
        default_x, dev_x, default_y, dev_y = self.get_default_displacement()

        # get error threshold comparing scaled micron displacements
        # with measurements
        keys = list(self._px.keys())
        px_x_temp = np.array([self._px[k][0] for k in keys])
        px_y_temp = np.array([self._px[k][1] for k in keys])
        mu_x_temp = np.array([self._mu[k][0] / scale_x for k in keys])
        mu_y_temp = np.array([self._mu[k][1] / scale_y for k in keys])
        avg_f_temp = np.array([self._avg[k] for k in keys])

        # Pairs with a large x offset (and enough flux): the x residual gives
        # the "long" error, the y residual the "short" one. 1.48 * mad is a
        # robust sigma, assuming ``mad`` is the median absolute deviation.
        ind_temp = np.where((np.abs(px_x_temp) > np.mean(np.abs(px_x_temp)))
                            & (avg_f_temp > 0.05))[0]
        if len(ind_temp) > 3:
            elongx = 1.48 * mad(np.sqrt((px_x_temp - mu_x_temp)[ind_temp]**2))
            eshorty = 1.48 * mad(np.sqrt((px_y_temp - mu_y_temp)[ind_temp]**2))
        else:
            elongx = 10.0
            eshorty = 10.0

        # Same for pairs with a large y offset.
        ind_temp = np.where((np.abs(px_y_temp) > np.mean(np.abs(px_y_temp)))
                            & (avg_f_temp > 0.05))[0]
        if len(ind_temp) > 3:
            elongy = 1.48 * mad(np.sqrt((px_y_temp - mu_y_temp)[ind_temp]**2))
            eshortx = 1.48 * mad(np.sqrt((px_x_temp - mu_x_temp)[ind_temp]**2))
        else:
            elongy = 10.0
            eshortx = 10.0

        error_long_threshold = 3.0 * np.max([elongx, elongy]).clip(2, 30)
        error_short_threshold = 3.0 * np.max([eshortx, eshorty]).clip(2, 30)
        logger.debug("Scaling error threshold: l:{0:.3f} s:{1:.3f}".format(
            error_long_threshold, error_short_threshold))

        # Solve once per bright reference image (mean flux above the median).
        ref_candidates = np.where(
            self.cube_means > np.median(self.cube_means))[0]
        if len(ref_candidates) == 0:
            # Nothing is strictly above the median (e.g. a single image or
            # identical fluxes). Fall back to the absolute reference so the
            # median/std below are taken over at least one solution.
            ref_candidates = [absolute_ref_img]

        accumulated_pos = []
        for this_ref in ref_candidates:
            temp = self.compute_abspos_ref(
                dx0,
                dy0,
                default_x,
                dev_x,
                default_y,
                dev_y,
                scale_x,
                scale_y,
                this_ref,
                error_long_threshold,
                error_short_threshold,
            )
            # using a common reference: the absolute reference image
            pos = np.array(temp[0])
            accumulated_pos.append(pos - pos[absolute_ref_img])

        accumulated_pos = np.array(accumulated_pos)
        abs_pos = np.median(accumulated_pos, 0)
        abs_err = np.std(accumulated_pos, 0)

        self._section.attrs["abs_pos"] = abs_pos.tolist()
        self._section.attrs["abs_err"] = abs_err.tolist()

        return abs_pos, abs_err
예제 #8
0
    def find_offsets(self):  # noqa: C901
        """Calculate offsets between all pairs of overlapping images.

        Attributes set on ``self``:

        * ``cube_means`` -- per-tile sum of optical slice 0 divided by
          ``2080.0**2``;
        * ``absolute_ref_img`` / ``mean_ref_img`` -- index and value of the
          largest ``cube_means`` entry;
        * ``offset_mode`` -- ``"sampled"``, or ``"default"`` when
          ``delta_x`` or ``delta_y`` (from ``find_grid``) divided by
          ``Settings.mosaic_scale`` exceeds 1950, or when ``mean_ref_img`` is
          below 0.05;
        * ``_offsets`` -- one ``[ref_img, obj_img, res, sign]`` entry per pair
          of grid neighbours. In ``"sampled"`` mode ``res`` is the
          ``find_overlap_conf`` result (with ``x``, ``y``, ``mi`` and
          ``avg_flux``). In ``"default"`` mode it is ``[dx, dy, 0.0, 0.0]``
          built from the stage positions.

        ``Client.current()`` is called unconditionally. Per dask.distributed
        it raises when no client is active, even if ``"default"`` mode ends
        up being used.
        """
        client = Client.current()
        # convert to find_shifts
        results = []
        logger.info("Processing section %s", self._section.name)
        # debug
        # img_cube = self.get_img_section(0, Settings.channel_to_use - 1)
        # NOTE(review): tiles are read from channel index -1 here, while the
        # flat/dark frames below use ``Settings.channel_to_use - 1``. The
        # commented-out line suggests a debugging leftover -- confirm which
        # channel is intended.
        img_cube = self.get_img_section(0, -1)

        # Calculate confidence map. Only needs to be done once per section
        dist_conf = self.get_distconf()

        flat = read_calib(Settings.flat_file)[Settings.channel_to_use - 1]
        dark = (read_calib(Settings.dark_file)[Settings.channel_to_use - 1] /
                Settings.norm_val)

        # Keep the calibration frames materialised; they are reused for
        # every pair below.
        flat = flat.persist()
        dark = dark.persist()

        dx0, dy0, delta_x, delta_y = self.find_grid()
        dx_mos, dy_mos = self.get_mos_pos()

        # We calculate the ref_img here, so that if the slice
        # has no data (very low max. avg flux), we skip calculations too
        img_cube_stack = da.stack(img_cube)
        # ref image is the one with the max integrated flux
        cube_totals = img_cube_stack.sum(axis=(1, 2))
        cube_totals = cube_totals.compute()
        # NOTE(review): 2080.0**2 presumably is the tile area in pixels.
        self.cube_means = cube_totals / 2080.0**2

        self.absolute_ref_img = self.cube_means.argmax()
        self.mean_ref_img = self.cube_means.max()

        logger.info("Max mean: {0:.5f}".format(self.mean_ref_img))

        # If the default displacements are too large
        # to allow overlaps, stick with the default scale
        self.offset_mode = "sampled"

        if delta_x / Settings.mosaic_scale > 1950.0:
            logger.info("Displacement in X too large: {0:.1f}".format(
                delta_x / Settings.mosaic_scale))
            self.offset_mode = "default"
        if delta_y / Settings.mosaic_scale > 1950.0:
            logger.info("Displacement in Y too large: {0:.1f}".format(
                delta_y / Settings.mosaic_scale))
            self.offset_mode = "default"

        # Too little signal to measure offsets reliably.
        if self.mean_ref_img < 0.05:
            logger.info("Avg. flux too low: {0:.3f}<0.05".format(
                self.mean_ref_img))
            self.offset_mode = "default"

        for i, img in enumerate(img_cube):
            # grid distance from tile i to every tile
            r = np.sqrt((dx0 - dx0[i])**2 + (dy0 - dy0[i])**2)

            # including no diagonals: direct neighbours (distance <= 1) only

            i_t = np.where((r <= np.sqrt(1)) & (r > 0))[0].tolist()

            im_i = img

            for this_img in i_t:
                # each unordered pair is processed once, from its lower index
                if i > this_img:
                    continue

                # expected displacement from the mosaic (stage) positions
                desp = [
                    (dx_mos[this_img] - dx_mos[i]) / Settings.mosaic_scale,
                    (dy_mos[this_img] - dy_mos[i]) / Settings.mosaic_scale,
                ]

                # trying to do always positive displacements: when either
                # component is below -100 the roles of reference and object
                # are swapped and ``sign`` records the swap

                if (desp[0] < -100) or (desp[1] < -100):
                    desp[0] *= -1
                    desp[1] *= -1
                    obj_img = i
                    im_obj = im_i
                    ref_img = this_img
                    im_ref = img_cube[this_img]
                    sign = -1
                else:
                    obj_img = this_img
                    im_obj = img_cube[this_img]
                    ref_img = i
                    im_ref = im_i
                    sign = 1

                if self.offset_mode == "sampled":
                    # Transform both tiles and queue a delayed overlap fit;
                    # the queued item is later unpacked as (i, j, res, sign).
                    im1 = apply_geometric_transform(
                        im_ref.data / Settings.norm_val, dark, flat,
                        Settings.cof_dist)
                    im2 = apply_geometric_transform(
                        im_obj.data / Settings.norm_val, dark, flat,
                        Settings.cof_dist)
                    res = delayed(find_overlap_conf)(im1, im2, dist_conf,
                                                     dist_conf, desp)
                    results.append(_sink(ref_img, obj_img, res, sign))

                else:
                    logger.debug(
                        "Initial offsets i: %d j: %d dx: %d dy: %d",
                        ref_img,
                        obj_img,
                        desp[0],
                        desp[1],
                    )
                    results.append([ref_img, obj_img, desp[0], desp[1], sign])

        if self.offset_mode == "sampled":
            # Offsets are collected in completion order, not pair order.
            futures = client.compute(results)
            offsets = []
            for fut in as_completed(futures):
                i, j, res, sign = fut.result()
                dx, dy, mi, avf = res.x, res.y, res.mi, res.avg_flux

                offsets.append([i, j, res, sign])

                logger.debug(
                    "Section %s offsets i: %d j: %d dx: %d dy: %d mi: %f avg_f: %f sign: %d",
                    self._section.name,
                    i,
                    j,
                    dx,
                    dy,
                    mi,
                    avf,
                    sign,
                )
        else:
            # Stage displacements, with 0.0 placeholders in the last two
            # slots (logged below as mi and avg_f).
            offsets = []
            for t in results:
                i, j, dx, dy, sign = t

                offsets.append([i, j, [dx, dy, 0.0, 0.0], sign])

                logger.debug(
                    "Section %s offsets i: %d j: %d dx: %d dy: %d mi: %f avg_f: %f sign: %d",
                    self._section.name,
                    i,
                    j,
                    dx,
                    dy,
                    0.0,
                    0.0,
                    sign,
                )
        self._offsets = offsets
def register_slices(mos_zarr: Path):  # noqa: C901
    """Align every (physical, optical) slice of a mosaic to the first one.

    Bead catalogues stored in each physical slice's attrs (as written by
    ``find_beads``) are cross-matched between consecutive slices in two
    passes:

    1. All beads are matched to get rough slice-to-slice shifts. These are
       accumulated to place every bead in a common frame via
       ``bead_collection``.
    2. Only beads detected at least ``max(optical index) + 2`` times are
       re-matched to refine the shifts and estimate their error. A pair of
       slices without usable beads gets a shift of 0 and an error of 100.

    The cumulative (``abs_*``) and slice-to-slice (``rel_*``) shifts are
    stored in the ``cube_reg`` attribute of the root group of *mos_zarr*.

    Parameters
    ----------
    mos_zarr : Path
        Path to the ``mos.zarr`` store.
    """
    # this is to store the beads later
    zarr_store = zarr.open(f"{mos_zarr}", mode="a")

    mos_full = xr.open_zarr(f"{mos_zarr}", group="")

    # putting all (physical, optical) slices on a single ordered list
    optical_slices = []
    physical_slices = []
    for this_slice in list(mos_full):
        for this_optical in mos_full[this_slice].z.values:
            physical_slices.append(this_slice)
            optical_slices.append(this_optical)

    def _label(k):
        # "<physical slice>_Z<optical slice, zero-padded to 3 digits>"
        return physical_slices[k] + "_Z{0:03d}".format(optical_slices[k])

    # first pass crossmatch, taking into account all beads.
    # dx/dy store the slice-to-slice offsets; the first slice is the origin.
    dx = [0.0]
    dy = [0.0]
    logger.info("1st pass slice shifts")
    for i in range(1, len(physical_slices)):

        # We compare each slice (_t) with the previous one (_r)
        _, x_t, y_t, _, e_t = _get_beads(mos_full[physical_slices[i]].attrs,
                                         optical_slices[i])

        _, x_r, y_r, _, e_r = _get_beads(
            mos_full[physical_slices[i - 1]].attrs, optical_slices[i - 1])

        if len(x_t) > 0 and len(x_r) > 0:
            dxt, dyt, i_rt, i_tr = _match_cats(x_r, y_r, e_r, x_t, y_t, e_t)

            # residuals after / before applying the fitted shift
            dr = np.sqrt((y_r[i_rt] - y_t[i_tr] - dyt)**2 +
                         (x_r[i_rt] - x_t[i_tr] - dxt)**2)
            dr0 = np.sqrt((y_r[i_rt] - y_t[i_tr])**2 +
                          (x_r[i_rt] - x_t[i_tr])**2)
            n_matched = len(i_tr)
            med_dr, med_dr0 = np.median(dr), np.median(dr0)
        else:
            dxt = dyt = 0.0
            n_matched = 0
            med_dr = med_dr0 = 0.0

        dx.append(dxt)
        dy.append(dyt)

        logger.info(_label(i - 1) + ":" + _label(i) + ":" +
                    " {0:d} ".format(n_matched) +
                    "{0:.1f} {1:.1f} ".format(dxt, dyt) +
                    "{0:.1f} {1:.1f} ".format(med_dr, med_dr0))

    # now that we know the slice to slice offset, we construct the catalogue
    # of all the beads
    bb = bead_collection()

    for i in range(len(physical_slices)):
        ref_slice = _label(i)

        # because dx,dy are slice to slice, the total displacement
        # is the sum of all the previous
        dx_t = np.sum(dx[0:i + 1])
        dy_t = np.sum(dy[0:i + 1])

        ind_r, x_r, y_r, _, _ = _get_beads(mos_full[physical_slices[i]].attrs,
                                           optical_slices[i])
        id_str = np.array([
            ref_slice + ":{0:05d}".format(int(this_id)) for this_id in ind_r
        ])

        for j in range(len(x_r)):
            bb.add_bead(x_r[j] + dx_t, y_r[j] + dy_t, x_r[j], y_r[j],
                        id_str[j])

        # at each slice we check matching beads and update avg coords
        bb.update_coords()

    # Now that we know which objects can be seen in more than one slice,
    # we re-compute the offsets but only using the beads that appear
    # in at least all the optical slices and 2 physical

    min_num_dets = np.max(np.array(optical_slices)) + 2

    good_beads = []
    for i in range(len(bb.x)):
        if bb.n[i] >= min_num_dets:
            good_beads.extend(bb.id_list[i])
    # cleaning out repeated ids
    good_beads = list(set(good_beads))

    # here we store the new displacements, and a measurement of error
    dx2 = [0.0]
    dy2 = [0.0]
    dd2 = [0.0]
    logger.info("2nd pass slice shifts")
    for i in range(1, len(physical_slices)):
        this_slice = _label(i)
        ref_slice = _label(i - 1)

        # only beads that have high reps
        _, x_t, y_t, e_t = _get_good_beads(mos_full, physical_slices[i],
                                           optical_slices[i], good_beads)
        _, x_r, y_r, e_r = _get_good_beads(mos_full, physical_slices[i - 1],
                                           optical_slices[i - 1], good_beads)

        # Both catalogues must be non-empty to cross-match (as in pass 1).
        if len(x_t) > 0 and len(x_r) > 0:
            dxt, dyt, edt, i_rt, i_tr = _match_cats(x_r,
                                                    y_r,
                                                    e_r,
                                                    x_t,
                                                    y_t,
                                                    e_t,
                                                    errors=True)

            dr = np.sqrt((y_r[i_rt] - y_t[i_tr] - dyt)**2 +
                         (x_r[i_rt] - x_t[i_tr] - dxt)**2)
            dr0 = np.sqrt((y_r[i_rt] - y_t[i_tr])**2 +
                          (x_r[i_rt] - x_t[i_tr])**2)
            n_matched = len(i_tr)
            med_dr, med_dr0 = np.median(dr), np.median(dr0)
        else:
            # NO usable beads in this pair of slices: no shift, large error
            dxt = dyt = 0.0
            edt = 100.0
            n_matched = 0
            med_dr = med_dr0 = 0.0

        dx2.append(dxt)
        dy2.append(dyt)
        dd2.append(edt)

        logger.info(ref_slice + ":" + this_slice +
                    " {0:d} ".format(n_matched) +
                    "{0:.1f} {1:.1f} ".format(dxt, dyt) +
                    "{0:.1f} {1:.1f} ".format(med_dr, med_dr0))

    # Now we store all the displacements as attrs
    cube_reg = {
        "abs_dx": [],
        "abs_dy": [],
        "abs_err": [],
        "rel_dx": [],
        "rel_dy": [],
        "rel_err": [],
        "slice": [],
        "opt_z": [],
    }
    for i in range(len(physical_slices)):
        cube_reg["slice"].append(physical_slices[i])
        cube_reg["opt_z"].append(float(optical_slices[i]))

        # because dx,dy are slice to slice, the total displacement
        # is the sum of all the previous; errors add in quadrature
        dx_t = np.sum(dx2[0:i + 1])
        dy_t = np.sum(dy2[0:i + 1])
        de_t = np.sqrt(np.sum(np.array(dd2[0:i + 1])**2))

        cube_reg["abs_dx"].append(dx_t)
        cube_reg["abs_dy"].append(dy_t)
        cube_reg["abs_err"].append(de_t)

        cube_reg["rel_dx"].append(dx2[i])
        cube_reg["rel_dy"].append(dy2[i])
        cube_reg["rel_err"].append(dd2[i])

    zarr_store.attrs["cube_reg"] = cube_reg
def find_beads(mos_zarr: Path):  # noqa: C901
    """Detect and fit beads in every physical and optical slice of a mosaic.

    Candidates are detected on the ``l.<Settings.zoom_level>`` group (mosaic
    averaged over channels) and screened with a first fit there. The
    survivors are refitted on the full-resolution root group. Duplicates and
    bad fits are dropped with ``_check_bead``. For each physical slice, the
    remaining beads of all its optical slices are written as list-valued
    attrs (``bead_x``, ``bead_y``, ``bead_z``, ...) on that slice's array
    in *mos_zarr*.

    Parameters
    ----------
    mos_zarr : Path
        Path to the ``mos.zarr`` store. The ``l.<Settings.zoom_level>``
        group must already exist.
    """
    # this is to store the beads later
    zarr_store = zarr.open(f"{mos_zarr}", mode="a")

    # conversion from dict headers to more informative
    # metadata names
    bead_par_to_attr_name = {
        "bead_id": "bead_id",
        "conv": "bead_conv",
        "corner": "bead_cutout_corner",
        "err": "bead_centre_err",
        "fit": "bead_fit_pars",
        "rad": "bead_rad",
        "x": "bead_x",
        "y": "bead_y",
        "z": "bead_z",
    }

    mos_full = xr.open_zarr(f"{mos_zarr}", group="")
    mos_zoom = xr.open_zarr(f"{mos_zarr}", group=f"l.{Settings.zoom_level}")

    # ``Dataset.sizes`` maps dimension name -> length. Indexing ``dims`` like
    # a mapping is deprecated in recent xarray.
    full_shape = (mos_full.sizes["y"], mos_full.sizes["x"])

    for this_slice in list(mos_zoom):

        # catalogue for this physical slice, accumulated over optical slices
        bead_cat = {}

        for this_optical in list(mos_full.z.values):

            logger.info("Analysing beads in " + this_slice +
                        " Z{0:03d}".format(this_optical))

            im = (mos_zoom[this_slice].sel(z=this_optical,
                                           type="mosaic").mean(dim="channel"))

            conf = mos_zoom[this_slice].sel(z=this_optical,
                                            channel=Settings.channel_to_use,
                                            type="conf")

            # Img stats
            pedestal, im_std = image_stats(im.data, conf.data).compute()

            # Detection of all features with bead size
            labels, good_objects, good_cx, good_cy = det_features(
                im.data, pedestal, im_std)
            da_labels = da.from_array(labels).persist()

            logger.debug("Found {0:d} preliminary beads".format(
                len(good_objects)))

            logger.debug("Filtering preliminary detections")

            temp = [
                _fit_bead_1stpass(im.data, da_labels, cx, cy, pedestal,
                                  im_std, obj)
                for obj, cx, cy in zip(good_objects, good_cx, good_cy)
            ]
            beads_1st_iter = dask.compute(temp)[0]
            logger.debug("First pass completed")

            # resampling the label image to the full-resolution grid
            full_labels = delayed(ndi.zoom)(labels,
                                            Settings.zoom_level,
                                            order=0)
            full_labels = da.from_delayed(full_labels,
                                          shape=full_shape,
                                          dtype="int")

            full_im = (mos_full[this_slice].sel(
                z=this_optical, type="mosaic").mean(dim="channel").data)

            # conf and error are the same across channels
            full_conf = (mos_full[this_slice].sel(
                z=this_optical, channel=Settings.channel_to_use,
                type="conf").data)

            full_err = (mos_full[this_slice].sel(
                z=this_optical, channel=Settings.channel_to_use,
                type="err").data)
            logger.debug("""
                    labels: {0:d},{1:d}
                    im: {2:d},{3:d}
                    conf: {4:d},{5:d}
                    err: {6:d},{7:d}
                """.format(
                *full_labels.shape,
                *full_im.shape,
                *full_conf.shape,
                *full_err.shape,
            ))
            temp = []
            logger.debug("Fitting all beads...")
            for bead_1st, obj in zip(beads_1st_iter, good_objects):
                # NOTE(review): identity test -- a numpy bool False would not
                # be skipped; confirm what _fit_bead_1stpass returns.
                if bead_1st[-1] is False:
                    continue
                # presumably a size from the zoomed fit, rescaled to full res
                if (bead_1st[2] * Settings.zoom_level >
                        Settings.feature_size[1]):
                    continue

                temp.append(
                    _fit_bead_2ndpass(
                        full_im,
                        full_conf,
                        full_err,
                        full_labels,
                        bead_1st,
                        pedestal,
                        im_std,
                        obj,
                    ))
            all_beads = dask.compute(temp)[0]

            # now we store all good beads in a dictionary, removing
            # duplicates and bad fits
            #
            # We'll store beads in the attrs for the physical slice
            # so we need to add the optical slice
            #
            done_x = [0]
            done_y = [0]

            for this_bead, bead_err in all_beads:
                this_bead["err"] = bead_err

                if _check_bead(this_bead, done_x, done_y, full_shape) is False:
                    continue

                done_x.append(this_bead["x"])
                done_y.append(this_bead["y"])

                for this_key, value in this_bead.items():
                    bead_cat.setdefault(this_key, []).append(value)
                bead_cat.setdefault("z", []).append(float(this_optical))

        # Store results as attrs in the full res slice
        for this_key, values in bead_cat.items():
            zarr_store[this_slice].attrs[
                bead_par_to_attr_name[this_key]] = values