Example 1
    def get_dataset(self):
        if self.config.dataset == DatasetChoice.from_path:
            assert self.dataset_part == DatasetPart.test

            tensor_infos = {
                self.config.pred_name:
                TensorInfo(
                    name=self.config.pred_name,
                    root=self.config.path,
                    location=self.config.pred_glob,
                    transforms=self.transforms_pipeline.sample_precache_trf,
                    datasets_per_file=1,  # todo: remove hard coded
                    samples_per_dataset=1,
                    remove_singleton_axes_at=(-1, ),
                    insert_singleton_axes_at=(0, 0),  # todo: remove hard coded
                    z_slice=None,
                    skip_indices=tuple(),
                    meta=None,
                ),
                self.config.trgt_name:
                TensorInfo(
                    name=self.config.trgt_name,
                    root=self.config.path,
                    location=self.config.trgt_glob,
                    transforms=self.transforms_pipeline.sample_precache_trf,
                    datasets_per_file=1,
                    samples_per_dataset=1,
                    remove_singleton_axes_at=(-1, ),  # todo: remove hard coded
                    insert_singleton_axes_at=(0, 0),  # todo: remove hard coded
                    z_slice=None,
                    skip_indices=tuple(),
                    meta=None,
                ),
            }
            dtst = ZipDataset({
                name: get_dataset_from_info(ti,
                                            cache=True,
                                            filters=[],
                                            indices=None)
                for name, ti in tensor_infos.items()
            })
            return ConcatDataset(
                [dtst],
                transform=self.transforms_pipeline.sample_preprocessing)

        else:
            return get_dataset(
                self.config.dataset,
                self.dataset_part,
                nnum=19,
                z_out=49,
                scale=self.scale,
                shrink=self.shrink,
                interpolation_order=self.config.interpolation_order,
                incl_pred_vol="pred_vol" in self.save_output_to_disk,
                load_lfd_and_care=self.load_lfd_and_care,
            )
Example 2
def get_tensor_info(tag: str, name: str, meta: dict):
    root = "GKRESHUK"
    if tag == "RC_LFD_n156to156_steps4":
        if name == "lfd":
            location = "LF_computed/LenseLeNet_Microscope/dualview_060918_added/RC_LFD_-156to156_steps4/Cam_Right_*.tif"
        elif name == "lf":
            location = "LF_computed/LenseLeNet_Microscope/dualview_060918_added/RC_rectified/Cam_Right_*.tif"
        else:
            raise NotImplementedError(tag, name)
    elif tag == "LC_LFD_n156to156_steps4":
        if name == "lfd":
            location = "LF_computed/LenseLeNet_Microscope/dualview_060918_added/LC_LFD_-156to156_steps4/Cam_Right_*.tif"
        elif name == "lf":
            location = "LF_computed/LenseLeNet_Microscope/dualview_060918_added/LC_rectified/Cam_Right_*.tif"
        else:
            raise NotImplementedError(tag, name)

    else:
        raise NotImplementedError(tag, name)

    return TensorInfo(name=name,
                      root=root,
                      location=location,
                      insert_singleton_axes_at=[0, 0],
                      tag=tag)
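
Usage sketch (not part of the original module): the tag/name pairs below are exactly the ones handled above, and any other combination raises NotImplementedError. The empty meta dict is an assumption; meta is accepted but unused here.

rc_lfd_info = get_tensor_info("RC_LFD_n156to156_steps4", name="lfd", meta={})
lc_lf_info = get_tensor_info("LC_LFD_n156to156_steps4", name="lf", meta={})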
Example 3
def get_tensor_info(tag: str, name: str, meta: dict):
    your_tensor_info = TensorInfo(
        name="lf",
        # root="my_data_root", #optional root name for location (as specified in hylfm._settings.local.py)
        location=
        "/path/to/data/with/glob_expr/in_*folders/and/or/files_*.tif",  # or .h5
        # insert_singleton_axes_at: hylfm expects an explicit batch dimension as first axis
        insert_singleton_axes_at=[0, 0
                                  ],  # e.g. [0, 0] for image as xy: xy -> bcxy
        meta=meta,
    )
    raise NotImplementedError(f"tag: {tag}, name: {name}, meta: {meta}")
    return your_tensor_info
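
For illustration, a hypothetical filled-in variant of the template above; the tag "my_local_stacks" and the glob path are placeholders, and only parameters already used in these examples appear.

def get_tensor_info(tag: str, name: str, meta: dict):
    if tag == "my_local_stacks" and name == "lf":
        return TensorInfo(
            name=name,
            location="/data/my_experiment/lf/*.tif",  # hypothetical glob; .h5 works as well
            insert_singleton_axes_at=[0, 0],  # image as xy -> bcxy
            meta=meta,
        )

    raise NotImplementedError(f"tag: {tag}, name: {name}, meta: {meta}")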
Example 4
    def get_dataset(self):
        assert self.config.dataset == DatasetChoice.predict_path
        assert self.dataset_part == DatasetPart.predict

        tensor_info = TensorInfo(
            name="lf",
            root=self.config.path,
            location=self.config.glob_lf,
            transforms=self.transforms_pipeline.sample_precache_trf,
            datasets_per_file=1,
            samples_per_dataset=1,
            remove_singleton_axes_at=tuple(),  # (-1,),
            insert_singleton_axes_at=(0, 0),
            z_slice=None,
            skip_indices=tuple(),
            meta=None,
        )

        dtst = get_dataset_from_info(tensor_info,
                                     cache=True,
                                     filters=[],
                                     indices=None)
        return ConcatDataset(
            [dtst], transform=self.transforms_pipeline.sample_preprocessing)
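
A minimal consumption sketch, assuming setup is an instance of the class that defines get_dataset above and that samples are keyed by tensor name, as in the tracing example further below.

dataset = setup.get_dataset()
first_sample = dataset[0]
lf = first_sample["lf"]  # carries the two singleton axes inserted at (0, 0)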
Example 5
from hylfm.datasets import TensorInfo

beads_right = TensorInfo(
    name="lf",
    root="GKRESHUK",
    location=
    "LF_computed/LenseLeNet_Microscope/DualView_comparison_heart_movie/beads/2018-09-06_13.16.46/stack_2_channel_0/RC_rectified/*.tif",
    insert_singleton_axes_at=[0, 0],
    tag="dualbeads_right",
)

beads_left = TensorInfo(
    name="lf",
    root="GKRESHUK",
    location=
    "LF_computed/LenseLeNet_Microscope/DualView_comparison_heart_movie/beads/2018-09-06_13.16.46/stack_2_channel_0/LC_rectified/*.tif",
    insert_singleton_axes_at=[0, 0],
    tag="dualbeads_left",
)

heart_right = TensorInfo(
    name="lf",
    root="GKRESHUK",
    location=
    "LF_computed/LenseLeNet_Microscope/DualView_comparison_heart_movie/heart/Rectified_RC/*.tif",
    insert_singleton_axes_at=[0, 0],
    tag="dualheart_right",
)


def get_tensor_info(tag: str, name: str, meta: dict):
Example 6
def trace_and_plot(
    tgt_path: Union[str, Path],
    tgt: str,
    roi: Tuple[slice, slice],
    plots: List[Union[Dict[str, Dict[str, Union[str, Path, List]]], Set[str]]],
    output_path: Path,
    nr_traces: int,
    background_threshold: Optional[float] = None,
    overwrite_existing_files: bool = False,
    smooth_diff_sigma: float = 1.3,
    peak_threshold_abs: float = 1.0,
    reduce_peak_area: str = "mean",
    time_range: Optional[Tuple[int, Optional[int]]] = None,
    plot_peaks: bool = False,
    compute_peaks_on: str = "std",  # std, diff, a*std+b*diff
    peaks_min_dist: int = 3,
    trace_radius: int = 3,
    compensate_motion: Optional[dict] = None,
    tag: str = "",  # for plot title only
    peak_path: Optional[Union[str, Path]] = None,
    compensated_peak_path: Optional[Union[str, Path]] = None,
):
    for plot in plots:
        if isinstance(plot, set):
            # bare sets of recon names have no "smooth" settings to normalize;
            # they are expanded into dicts with defaults further below
            continue

        for recon, kwargs in plot.items():
            if "smooth" not in kwargs:
                continue  # the default smoothing is filled in further below

            kwargs["smooth"] = [
                tuple(
                    [
                        None
                        if smoo is None
                        else tuple([json.dumps(sm, sort_keys=True) if isinstance(sm, dict) else sm for sm in smoo])
                        for smoo in smooth
                    ]
                )
                for smooth in kwargs["smooth"]
            ]

    if isinstance(tgt_path, str):
        tgt_path = Path(tgt_path)

    if isinstance(compensated_peak_path, str):
        compensated_peak_path = Path(compensated_peak_path)

    def path2string(obj):
        if isinstance(obj, Path):
            return str(obj)
        elif isinstance(obj, list):
            return [path2string(v) for v in obj]
        elif isinstance(obj, set):
            return {path2string(v) for v in obj}
        elif isinstance(obj, dict):
            return {k: path2string(v) for k, v in obj.items()}
        else:
            return obj

    all_kwargs = {
        "tgt_path": str(tgt_path),
        "tgt": tgt,
        "roi": str(roi),
        "plots": path2string(plots),
        "nr_traces": nr_traces,
        "background_threshold": background_threshold,
        "overwrite_existing_files": overwrite_existing_files,
        "smooth_diff_sigma": smooth_diff_sigma,
        "peak_threshold_abs": peak_threshold_abs,
        "reduce_peak_area": reduce_peak_area,
        "time_range": time_range,
        "plot_peaks": plot_peaks,
        "compute_peaks_on": compute_peaks_on,
        "peaks_min_dist": peaks_min_dist,
        "trace_radius": trace_radius,
        "compensate_motion": compensate_motion,
        "tag": tag,
        "peak_path": None if peak_path is None else str(peak_path),
        "compensated_peak_path": None if compensated_peak_path is None else str(compensated_peak_path),
    }
    descr_hash = sha256()
    descr_hash.update(json.dumps(all_kwargs, sort_keys=True).encode("utf-8"))
    output_path /= descr_hash.hexdigest()
    print("output path:", output_path)
    output_path.mkdir(exist_ok=True, parents=True)
    yaml.dump(all_kwargs, output_path / "kwargs.yaml")
    default_smooth = [(None, ("flat", 3))]
    for i in range(len(plots)):
        if isinstance(plots[i], set):
            plots[i] = {recon: {"path": tgt_path, "smooth": default_smooth} for recon in plots[i]}
        elif isinstance(plots[i], dict):
            full_recons = {}
            for recon, kwargs in plots[i].items():
                if "path" not in kwargs:
                    kwargs["path"] = tgt_path

                if "smooth" not in kwargs:
                    kwargs["smooth"] = default_smooth

                full_recons[recon] = kwargs

            plots[i] = full_recons
        else:
            raise TypeError(type(plots[i]))

    all_recon_paths = {}
    for recons in plots:
        for recon, kwargs in recons.items():
            if recon in all_recon_paths:
                assert kwargs["path"] == all_recon_paths[recon], (kwargs["path"], all_recon_paths[recon])
            else:
                all_recon_paths[recon] = kwargs["path"]

    ds_tgt = TiffDataset(info=TensorInfo(name=tgt, root=tgt_path, location=f"{tgt}/*.tif"))
    length = len(ds_tgt)
    datasets_to_trace = {
        recon: TiffDataset(info=TensorInfo(name=recon, root=path, location=f"{recon}/*.tif"))
        for recon, path in all_recon_paths.items()
    }
    assert all([len(recon_ds) == length for recon_ds in datasets_to_trace.values()]), [length] + [
        len(recon_ds) for recon_ds in datasets_to_trace.values()
    ]
    assert tgt not in datasets_to_trace
    datasets_to_trace[tgt] = ds_tgt
    if time_range is not None:
        for name, ds in datasets_to_trace.items():
            datasets_to_trace[name] = Subset(
                ds, numpy.arange(time_range[0], len(ds) if time_range[1] is None else time_range[1])
            )
            assert len(datasets_to_trace[name]) <= 600, "across TP??"

    # load data
    for name, ds in datasets_to_trace.items():
        datasets_to_trace[name] = numpy.stack(
            [
                sample[name].squeeze()[roi]
                for sample in DataLoader(
                    dataset=ds,
                    shuffle=False,
                    collate_fn=get_collate_fn(lambda b: b),
                    num_workers=settings.max_workers_for_trace,
                    pin_memory=False,
                )
            ]
        )
        assert numpy.isfinite(datasets_to_trace[name]).all()

        # plt.imshow(datasets_to_trace[name].max(axis=0))
        # plt.title(name)
        # plt.show()

    compensate_motion_of_peaks = compensate_motion is not None and compensate_motion.pop("of_peaks", False)
    if not compensate_motion_of_peaks and compensate_motion is not None:
        compensate_ref_name = compensate_motion.pop("compensate_ref", tgt)
        assert compensate_ref_name == tgt, "not implemented"
        compensate_ref = datasets_to_trace[compensate_ref_name]
        motion = skvideo.motion.blockMotion(compensate_ref, **compensate_motion)
        # print("motion", motion.shape, motion.max())
        assert numpy.isfinite(motion).all()
        for name in set(datasets_to_trace.keys()):
            data = datasets_to_trace[name]
            # print("data", data.shape)

            # compensate the video
            compensate = (
                skvideo.motion.blockComp(data, motion, mbSize=compensate_motion.get("mbSize", 8))
                .squeeze(axis=-1)
                .astype("float32")
            )
            # print("compensate", compensate.shape)
            assert numpy.isfinite(compensate).all()
            # write
            volwrite(output_path / f"{name}_motion_compensated.tif", compensate)
            volwrite(output_path / f"{name}_not_compensated.tif", data)
    else:
        motion = None

    compensated_peaks = None
    figs = {}
    if peak_path is None and compensated_peak_path is None:
        peak_path = output_path / f"{tgt}_peaks_of_{compute_peaks_on}.yml"
        peaks = None
        if peak_path.exists() and not overwrite_existing_files:
            peaks = numpy.asarray(yaml.load(peak_path))

            if peaks.shape[0] != nr_traces:
                peaks = None
    elif peak_path is None and compensated_peak_path is not None:
        assert compensated_peak_path.exists()
        compensated_peaks = numpy.asarray(yaml.load(compensated_peak_path)).T[None, ...]
        assert nr_traces == 1
    else:
        assert peak_path.exists()
        peaks = numpy.asarray(yaml.load(peak_path))
        nr_traces = min(nr_traces, peaks.shape[0])

    if compensated_peaks is None and (peaks is None or plot_peaks):
        all_projections = OrderedDict()
        for tensor_name in {tgt, *all_recon_paths.keys()}:
            min_path = output_path / f"{tensor_name}_min.tif"
            max_path = output_path / f"{tensor_name}_max.tif"
            mean_path = output_path / f"{tensor_name}_mean.tif"
            std_path = output_path / f"{tensor_name}_std.tif"

            if (
                min_path.exists()
                and max_path.exists()
                and mean_path.exists()
                and std_path.exists()
                and not overwrite_existing_files
            ):
                min_tensor = imread(min_path)
                max_tensor = imread(max_path)
                mean_tensor = imread(mean_path)
                std_tensor = imread(std_path)
            else:
                min_tensor, max_tensor, mean_tensor, std_tensor = get_min_max_mean_std(datasets_to_trace[tensor_name])
                imwrite(min_path, min_tensor)
                imwrite(max_path, max_tensor)
                imwrite(mean_path, mean_tensor)
                imwrite(std_path, std_tensor)

            all_projections[tensor_name] = {
                "min": min_tensor,
                "max": max_tensor,
                "mean": mean_tensor,
                "std": std_tensor,
            }

        all_projections.move_to_end(tgt, last=False)
        # diff_tensor = gaussian_filter(max_tensor, sigma=1.3, mode="constant") - gaussian_filter(min_tensor, sigma=1.3, mode="constant")
        # blobs = blob_dog(
        #     diff_tensor, min_sigma=1, max_sigma=16, sigma_ratio=1.6, threshold=.3, overlap=0.5, exclude_border=True
        # )
        # peaks = blob_dog(
        #     diff_tensor, min_sigma=1.0, max_sigma=5, sigma_ratio=1.1, threshold=.1, overlap=0.5, exclude_border=False
        # )
        # smooth_diff_tensor = diff_tensor
        # smooth_diff_tensor = gaussian_filter(diff_tensor, sigma=1.3, mode="constant")

        for tensor_name, projections in all_projections.items():
            diff_tensor = projections["max"] - projections["min"]
            plot_peaks_on = {
                "diff tensor": diff_tensor,
                "min tensor": projections["min"],
                "max tensor": projections["max"],
                "std tensor": projections["std"],
                "mean tensor": projections["mean"],
            }

            def get_peaks_on_tensor(comp_on: str):
                if comp_on in ["min", "max", "std", "mean"]:
                    peaks_on_tensor = projections[comp_on]
                elif comp_on == "diff":
                    return diff_tensor
                elif comp_on == "smooth_diff":
                    peaks_on_tensor = gaussian_filter(diff_tensor, sigma=smooth_diff_sigma, mode="constant")
                    plot_peaks_on["smooth diff tensor"] = peaks_on_tensor
                elif "+" in comp_on:
                    comp_on_parts = comp_on.split("+")
                    peaks_on_tensor = get_peaks_on_tensor(comp_on_parts[0])
                    for part in comp_on_parts[1:]:
                        peaks_on_tensor = numpy.add(peaks_on_tensor, get_peaks_on_tensor(part))
                elif "*" in comp_on:
                    factor, comp_on = comp_on.split("*")
                    factor = float(factor)
                    peaks_on_tensor = factor * get_peaks_on_tensor(comp_on)
                else:
                    raise NotImplementedError(compute_peaks_on)

                return peaks_on_tensor

            peaks_on_tensor = get_peaks_on_tensor(compute_peaks_on)
            background_mask = (
                None
                if background_threshold is None
                else (all_projections[tgt]["min"] > background_threshold).astype(numpy.int)
            )
            imwrite(output_path / "background_mask.tif", background_mask.astype("uint8") * 255)
            if tensor_name == tgt:
                peaks = peak_local_max(
                    peaks_on_tensor,
                    min_distance=peaks_min_dist,
                    threshold_abs=peak_threshold_abs,
                    exclude_border=True,
                    num_peaks=nr_traces,
                    labels=background_mask,
                )
                peaks = numpy.concatenate([peaks, numpy.full((peaks.shape[0], 1), trace_radius)], axis=1)

            # plot peak positions on different projections
            fig, axes = plt.subplots(nrows=math.ceil(len(plot_peaks_on) / 3), ncols=3, squeeze=False, figsize=(15, 10))
            plt.suptitle(tensor_name)
            [ax.set_axis_off() for ax in axes.flatten()]
            for ax, (name, tensor) in zip(axes.flatten(), plot_peaks_on.items()):
                title = f"peaks on {name}"
                ax.set_title(title)
                im = ax.imshow(tensor)
                fig.colorbar(im, ax=ax)
                for i, peak in enumerate(peaks):
                    y, x, r = peak
                    c = plt.Circle((x, y), r, color="r", linewidth=1, fill=False)
                    ax.text(x + 2 * int(r + 0.5), y, str(i))  # todo fix text
                    ax.add_patch(c)

            try:
                plt.tight_layout()
            except Exception as e:
                warnings.warn(e)

            fig_name = f"trace_positions_on_{tensor_name}"
            plt.savefig(output_path / f"{fig_name}.svg")
            plt.savefig(output_path / f"{fig_name}.png")
            figs[fig_name] = fig
            if SHOW_FIGS:
                plt.show()
            else:
                plt.close()

            yaml.dump(peaks.tolist(), peak_path)

    if compensated_peaks is not None:
        peaks = None
        all_compensated_peaks = {tgt: compensated_peaks}
        for name in all_recon_paths.keys():
            all_compensated_peaks[name] = compensated_peaks
    elif compensate_motion_of_peaks:
        only_on_tgt = compensate_motion.pop("only_on_tgt", False)
        print(f"get_motion_compensated_peaks for {tgt}")
        all_compensated_peaks = {
            tgt: get_motion_compensated_peaks(
                tensor=datasets_to_trace[tgt], peaks=peaks, output_path=output_path, **compensate_motion, name=tgt
            )
        }
        for name in {*all_recon_paths.keys()}:
            if only_on_tgt:
                all_compensated_peaks[name] = all_compensated_peaks[tgt]
            else:
                print(f"get_motion_compensated_peaks for {name}")
                all_compensated_peaks[name] = get_motion_compensated_peaks(
                    tensor=datasets_to_trace[name], peaks=peaks, output_path=output_path, **compensate_motion, name=name
                )
    else:
        all_compensated_peaks = None

    all_traces = {}
    if all_compensated_peaks is None:
        for name in {tgt, *all_recon_paths.keys()}:
            traces_path: Path = output_path / f"{name}_traces.npy"
            if traces_path.exists() and not overwrite_existing_files:
                all_traces[name] = numpy.load(str(traces_path))
            else:
                all_traces[name] = trace_straight_peaks(datasets_to_trace[name], peaks, reduce_peak_area, output_path)
                numpy.save(str(traces_path), all_traces[name])
    else:
        for name, compensated_peaks in all_compensated_peaks.items():
            # traces_path: Path = output_path / f"{name}_traces.npy"
            # if traces_path.exists() and not overwrite_existing_files:
            #     all_traces[name] = numpy.load(str(traces_path))
            # else:
            all_traces[name] = trace_tracked_peaks(
                datasets_to_trace[name],
                compensated_peaks,
                reduce_peak_area,
                output_path,
                name=name,
                n_radii=1
                if not compensate_motion_of_peaks or compensate_motion is None
                else compensate_motion["n_radii"],
            )
            # numpy.save(str(traces_path), all_traces[name])

    all_smooth_traces = {}
    correlations = {}
    trace_scaling = {}
    for recons in plots:
        for recon, kwargs in recons.items():
            for smooth, tgt_smooth in kwargs["smooth"]:
                if (recon, smooth, tgt_smooth, 0) not in correlations:
                    traces, tgt_traces = get_smooth_traces_pair(
                        recon, smooth, tgt, tgt_smooth, all_traces, all_smooth_traces
                    )
                    for t, (trace, tgt_trace) in enumerate(zip(traces, tgt_traces)):
                        try:
                            pr, _ = pearsonr(trace, tgt_trace)
                        except ValueError as e:
                            logger.error(e)
                            pr = numpy.nan

                        try:
                            sr, _ = spearmanr(trace, tgt_trace)
                        except ValueError as e:
                            logger.error(e)
                            sr = numpy.nan

                        correlations[recon, smooth, tgt_smooth, t] = {"pearson": pr, "spearman": sr}
                        trace_scaling[recon, smooth, tgt_smooth, t] = get_trace_scaling(trace, tgt_trace)

    best_correlations = {}
    for (recon, smooth, tgt_smooth, t), corrs in correlations.items():
        if (smooth, tgt_smooth) not in best_correlations:
            best_correlations[(smooth, tgt_smooth)] = {}
        for metric, value in corrs.items():
            if metric not in best_correlations[(smooth, tgt_smooth)]:
                best_correlations[(smooth, tgt_smooth)][metric] = {}

            best = best_correlations[(smooth, tgt_smooth)][metric].get(recon, [-9999, None])[0]
            if value > best:
                best_correlations[(smooth, tgt_smooth)][metric][recon] = [float(value), t]

    print("best correlations:")
    pprint(best_correlations)
    yaml.dump(best_correlations, output_path / "best_correlations.yml")

    trace_plots_output_path = output_path / "trace_plots"
    trace_plots_output_path.mkdir(exist_ok=True)
    trace_figs = plot_traces(
        tgt=tgt,
        plots=plots,
        all_traces=all_traces,
        all_smooth_traces=all_smooth_traces,
        correlations=correlations,
        trace_scaling=trace_scaling,
        output_path=trace_plots_output_path,
        tag=tag,
    )
    figs.update(trace_figs)
    return peaks, all_traces, correlations, figs, motion
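
A hedged call sketch using only arguments from the signature above; the paths and the plot specification are hypothetical, and the tensor names ls_slice and pred are borrowed from the next example.

from pathlib import Path

peaks, all_traces, correlations, figs, motion = trace_and_plot(
    tgt_path="/results/my_run",  # hypothetical; expected to contain ls_slice/*.tif and pred/*.tif
    tgt="ls_slice",
    roi=(slice(0, 256), slice(0, 256)),
    plots=[{"pred": {"smooth": [(None, ("flat", 3))]}}],  # "path" defaults to tgt_path
    output_path=Path("/results/my_run/traces"),  # hypothetical; a hash subfolder is created inside
    nr_traces=10,
)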
Example 7
    meta = {
        "z_out": 49,
        "nnum": 19,
        "interpolation_order": 2,
        "scale": 4,
        "z_ls_rescaled": 241,
        "pred_z_min": 0,
        "pred_z_max": 838,
    }
    datasets = OrderedDict()
    datasets["pred"] = get_dataset_from_info(
        TensorInfo(
            name="pred",
            root=Path("/scratch/beuttenm/lnet/care/results"),
            location=f"{subpath}/{model_name}/*.tif",
            insert_singleton_axes_at=[0, 0],
            z_slice=None,
            meta={"crop_name": "Heart_tightCrop", **meta},
        ),
        cache=True,
    )
    datasets["ls_slice"] = get_dataset_from_info(
        get_tensor_info("heart_dynamic.2019-12-09_04.54.38", name="ls_slice", meta=meta),
        cache=True,
        filters=[("z_range", {})],
    )

    assert len(datasets["pred"]) == 51 * 241, len(datasets["pred"])
    assert len(datasets["ls_slice"]) == 51 * 209, len(datasets["ls_slice"])
    # ipt_paths = {
    #     "pred": ,