示例#1
0
def beads_dataset(meta, tag: str = "beads.small1", names=("lf", "ls_reg")) -> "ZipDataset":
    """Zip the cached datasets of several tensors of one beads recording.

    Args:
        meta: meta dict forwarded to ``get_tensor_info`` for every tensor.
        tag: recording tag passed to ``get_tensor_info``; defaults to the
            previously hard-coded ``"beads.small1"``.
        names: tensor names to load, in order; each name becomes a key of the
            zipped dataset. Defaults to ``("lf", "ls_reg")``.

    Returns:
        A ``ZipDataset`` over an ``OrderedDict`` mapping each name to its
        cached dataset (``get_dataset_from_info(..., cache=True)``).
    """
    datasets = OrderedDict()
    for name in names:
        info = get_tensor_info(tag, name, meta=meta)
        datasets[name] = get_dataset_from_info(info=info, cache=True)

    return ZipDataset(datasets)
示例#2
0
def manual_test_this():
    """Visually compare the transformed and the registered light-sheet volume.

    Loads the ``ls``, ``ls_trf`` and ``ls_reg`` tensors of the
    ``heart_static.beads_ref_wholeFOV`` recording (cached), zips them, and
    passes the ``.max(2)`` reductions of ``ls_reg`` and ``ls_trf`` from the
    first sample to ``compare_slices``.
    """
    tag = "heart_static.beads_ref_wholeFOV"
    # z_min full: 0, z_max full: 838; 60/209*838=241; 838-10/209*838=798
    meta = dict(
        nnum=19,
        z_out=49,
        scale=4,
        shrink=8,
        interpolation_order=2,
        z_ls_rescaled=241,
        pred_z_min=0,
        pred_z_max=838,
        crop_names=["wholeFOV"],
    )
    tensor_names = ("ls", "ls_trf", "ls_reg")

    # resolve every tensor info first, then build the cached datasets
    infos = [get_tensor_info(tag, tensor_name, meta=meta) for tensor_name in tensor_names]
    zipped = ZipDataset(
        collections.OrderedDict(
            (tensor_name, get_dataset_from_info(info=info, cache=True))
            for tensor_name, info in zip(tensor_names, infos)
        )
    )

    first = zipped[0]
    projections = {key: first[key].max(2) for key in ("ls_reg", "ls_trf")}
    compare_slices(projections, "lala", "ls_reg", "ls_trf")
示例#3
0
    def get_dataset(self):
        """Build the dataset for ``self.dataset_part``.

        If ``self.config.dataset`` is not ``DatasetChoice.from_path``, this
        delegates to the module-level ``get_dataset`` (with ``nnum=19`` and
        ``z_out=49``). Otherwise, which is only allowed for
        ``DatasetPart.test`` (asserted), prediction and target tensors are read
        from ``self.config.path`` via ``self.config.pred_glob`` /
        ``self.config.trgt_glob``. They are zipped under
        ``self.config.pred_name`` / ``self.config.trgt_name`` and wrapped in a
        ``ConcatDataset`` with ``self.transforms_pipeline.sample_preprocessing``
        as its transform.
        """
        from_path = self.config.dataset == DatasetChoice.from_path
        if not from_path:
            return get_dataset(
                self.config.dataset,
                self.dataset_part,
                nnum=19,
                z_out=49,
                scale=self.scale,
                shrink=self.shrink,
                interpolation_order=self.config.interpolation_order,
                incl_pred_vol="pred_vol" in self.save_output_to_disk,
                load_lfd_and_care=self.load_lfd_and_care,
            )

        assert self.dataset_part == DatasetPart.test

        def file_tensor_info(tensor_name, location):
            # Identical settings for prediction and target files.
            # todo: remove hard coded per-file / singleton-axis settings
            return TensorInfo(
                name=tensor_name,
                root=self.config.path,
                location=location,
                transforms=self.transforms_pipeline.sample_precache_trf,
                datasets_per_file=1,
                samples_per_dataset=1,
                remove_singleton_axes_at=(-1,),
                insert_singleton_axes_at=(0, 0),
                z_slice=None,
                skip_indices=(),
                meta=None,
            )

        tensor_infos = {
            self.config.pred_name: file_tensor_info(self.config.pred_name, self.config.pred_glob),
            self.config.trgt_name: file_tensor_info(self.config.trgt_name, self.config.trgt_glob),
        }
        zipped = ZipDataset(
            {
                tensor_name: get_dataset_from_info(info, cache=True, filters=[], indices=None)
                for tensor_name, info in tensor_infos.items()
            }
        )
        return ConcatDataset([zipped], transform=self.transforms_pipeline.sample_preprocessing)
示例#4
0
def try_dynamic():
    """Manually smoke-test an ``A04`` model on one sample (requires CUDA).

    Zips the light-field (``lf``) and light-sheet (``ls``) datasets, applies
    ``Normalize01`` (0th/100th percentile) to both, runs the first sample
    through the model, prints shape diagnostics and shows the target and the
    prediction with matplotlib.
    """
    import matplotlib.pyplot as plt

    from torch.utils.data import DataLoader
    from hylfm.datasets import ZipDataset
    from hylfm.transformations import Normalize01

    # The output is read back under the same key the model is told to write
    # its prediction to (``try_static`` likewise reads ``out_sample["pred"]``).
    prediction_name = "pred"
    m = A04(input_name="lf", prediction_name=prediction_name, z_out=49, nnum=19)
    # n_res2d: [976, 488, u, 244, 244, u, 122, 122]
    # inplanes_3d: 7
    # n_res3d: [[7, 7], [7], [1]]

    normalize = Normalize01(apply_to=["lf", "ls"],
                            min_percentile=0,
                            max_percentile=100)
    # NOTE(review): ``lfds`` and ``lsds`` are not defined in this function;
    # they must exist at module scope, otherwise this line raises NameError.
    ds = ZipDataset({"lf": lfds, "ls": lsds}, transformation=normalize)
    loader = DataLoader(ds, batch_size=1)
    device = torch.device("cuda")
    m = m.to(device)

    sample = next(iter(loader))
    ipt = sample["lf"]
    tgt = sample["ls"]

    z_slice = sample["meta"]["z_slice"]
    ipt, tgt = ipt.to(device), tgt.to(device)
    print("get_scaling", m.get_scaling(ipt.shape[2:]))
    print("get_shrinkage", m.get_shrinkage(ipt.shape[2:]))
    print("get_output_shape()", m.get_output_shape(ipt.shape[2:]))
    print("ipt", ipt.shape, "tgt", tgt.shape, z_slice)
    plt.imshow(tgt[0, 0].detach().cpu().numpy())
    plt.title("tgt")
    plt.show()

    out_sample = m(sample)
    out = out_sample[prediction_name]
    print("out", out.shape)
    plt.imshow(out[0, 0].detach().cpu().numpy())
    plt.title(f"out {z_slice}")
    plt.show()

    print("done")
示例#5
0
 def get_individual_dataset(self, dss: DatasetSetup) -> torch.utils.data.Dataset:
     """Zip one cached dataset per tensor info of ``dss``.

     Every info in ``dss.infos`` is loaded with ``cache=True``, the setup's
     ``indices`` and the concatenation ``dss.filters + self.filters``. The
     results are zipped (in ``dss.infos`` order) with
     ``self.sample_preprocessing`` as the sample transformation.
     """
     per_tensor = OrderedDict()
     for tensor_name, info in dss.infos.items():
         per_tensor[tensor_name] = get_dataset_from_info(
             info,
             cache=True,
             indices=dss.indices,
             filters=dss.filters + self.filters,  # fresh list for every tensor
         )
     return ZipDataset(per_tensor, transformation=self.sample_preprocessing)
示例#6
0
        sample = PoissonNoise(apply_to={
            "ls_slice": "ls_slice_trf",
            "lf": "lf_trf"
        },
                              peak=p,
                              seed=0)(sample)
        compare_slices(sample, f"{add_to_tag}_{p}", "ls_slice", "ls_slice_trf")
        compare_slices(sample, f"{add_to_tag}_{p}", "lf", "lf_trf")


if __name__ == "__main__":
    from hylfm.datasets import ZipDataset, get_dataset_from_info, get_tensor_info

    # Run the Poisson-noise comparison on the light-sheet slice and the light
    # field of several single-plane brain recordings.
    meta = {"nnum": 19, "z_out": 49, "interpolation_order": 2, "scale": 2}
    tags = (
        "brain.11_1__2020-03-11_03.22.33__SinglePlane_-330",
        "brain.11_2__2020-03-11_07.30.39__SinglePlane_-320",
        "brain.09_3__2020-03-09_06.43.40__SinglePlane_-330",
    )
    tensor_names = ("ls_slice", "lf")
    for tag in tags:
        # resolve both tensor infos before building any dataset
        infos = [get_tensor_info(tag, tensor_name, meta=meta) for tensor_name in tensor_names]
        ls_slice_dataset = ZipDataset(
            collections.OrderedDict(
                (tensor_name, get_dataset_from_info(info=info, cache=True))
                for tensor_name, info in zip(tensor_names, infos)
            )
        )

        manual_test_poisson(ls_slice_dataset, tag)
示例#7
0
def try_static(backprop: bool = True):
    """Manually smoke-test an ``A04`` model on the ``b4mu_3`` beads data (requires CUDA).

    The pipeline runs in these steps:

    - Load the first light-field (``lf``) / light-sheet (``ls``) sample pair.
    - Resize the light-sheet target and assert its shape.
    - Crop ``ls``, normalize both tensors, extract light-field channels and
      cast both to float32 on CUDA.
    - Run the model and show max projections of target and prediction along
      axes 0, 1 and 2.

    Args:
        backprop: if True, compute an MSE loss between prediction and target,
            backpropagate it and take a single Adam step on the model.
    """
    import copy

    from torch.utils.data import DataLoader
    import matplotlib.pyplot as plt

    from hylfm.datasets.beads import b4mu_3_lf, b4mu_3_ls
    from hylfm.datasets import get_dataset_from_info, ZipDataset, N5CachedDatasetFromInfo, get_collate_fn
    from hylfm.transformations import Normalize01, ComposedTransformation, ChannelFromLightField, Cast, Crop

    m = A04(input_name="lf", prediction_name="pred", z_out=51, nnum=19)
    # n_res2d: [976, 488, u, 244, 244, u, 122, 122]
    # inplanes_3d: 7
    # n_res3d: [[7, 7], [7], [1]]

    # Extend a *copy* of the light-sheet tensor info. ``b4mu_3_ls`` is a
    # module-level object; extending its ``transformations`` in place (as
    # ``+=`` did) would stack another Resize/Assert pair on every call and
    # leak into every other user of ``b4mu_3_ls``.
    ls_info = copy.copy(b4mu_3_ls)
    ls_info.transformations = list(b4mu_3_ls.transformations) + [
        {
            "Resize": {
                "apply_to": "ls",
                "shape": [1.0, 121 / 838, 8 / 19, 8 / 19],
                "order": 2
            }
        },
        {
            "Assert": {
                "apply_to": "ls",
                "expected_tensor_shape": [None, 1, 121, None, None]
            }
        },
    ]

    lfds = N5CachedDatasetFromInfoSubset(N5CachedDatasetFromInfo(get_dataset_from_info(b4mu_3_lf)))
    lsds = N5CachedDatasetFromInfoSubset(N5CachedDatasetFromInfo(get_dataset_from_info(ls_info)))
    trf = ComposedTransformation(
        # presumably trims the 121 resized ls planes to 121 - 2 * 35 = 51 (= z_out) — confirm axis order
        Crop(apply_to="ls", crop=((0, None), (35, -35), (8, -8), (8, -8))),
        Normalize01(apply_to=["lf", "ls"],
                    min_percentile=0,
                    max_percentile=100),
        ChannelFromLightField(apply_to="lf", nnum=19),
        Cast(apply_to=["lf", "ls"], dtype="float32", device="cuda"),
    )
    ds = ZipDataset(OrderedDict(lf=lfds, ls=lsds), transformation=trf)
    loader = DataLoader(ds,
                        batch_size=1,
                        collate_fn=get_collate_fn(lambda t: t))

    device = torch.device("cuda")
    m = m.to(device)

    sample = next(iter(loader))
    ipt = sample["lf"]
    tgt = sample["ls"]
    print("get_scaling", m.get_scaling(ipt.shape[2:]))
    print("get_shrinkage", m.get_shrinkage(ipt.shape[2:]))
    print("get_output_shape()", m.get_output_shape(ipt.shape[2:]))
    print("ipt", ipt.shape, "tgt", tgt.shape)
    out_sample = m(sample)
    out = out_sample["pred"]
    if backprop:
        loss_fn = torch.nn.MSELoss()
        loss = loss_fn(out, tgt)
        loss.backward()
        adam = torch.optim.Adam(m.parameters())
        adam.step()

    def show_projections(volume, title):
        # one figure per max projection along axes 0, 1 and 2
        for axis in range(3):
            plt.imshow(volume.max(axis=axis))
            plt.title(title)
            plt.show()

    show_projections(tgt[0, 0].detach().cpu().numpy(), "tgt")

    print("pred", out.shape)
    show_projections(out[0, 0].detach().cpu().numpy(), "pred")

    print("done")
示例#8
0
    )
    datasets["ls_slice"] = get_dataset_from_info(
        get_tensor_info("heart_dynamic.2019-12-09_04.54.38", name="ls_slice", meta=meta),
        cache=True,
        filters=[("z_range", {})],
    )

    assert len(datasets["pred"]) == 51 * 241, len(datasets["pred"])
    assert len(datasets["ls_slice"]) == 51 * 209, len(datasets["ls_slice"])
    # ipt_paths = {
    #     "pred": ,
    #     "ls_slice": Path(
    #         "/g/kreshuk/LF_computed/lnet/logs/heart2/test_z_out49/lr_f4/heart_dynamic.2019-12-09_04.54.38/run000/ds0-0/ls_slice"
    #     ),
    # }
    ds = ZipDataset(datasets)
    assert len(ds) == 51 * 209, len(ds)

    out_paths = {
        "pred_slice": Path("/g/kreshuk/LF_computed/lnet/care/results") / subpath / model_name,
        # "pred_vol": Path("/g/kreshuk/LF_computed/lnet/care/results") / subpath / model_name / "vol",
    }
    for name, p in out_paths.items():
        p.mkdir(parents=True, exist_ok=True)
        print(name, p)

    trf = get_composed_transformation_from_config(yaml.load(args.trf_config_path))

    def do_work(i):
        out_file_paths = {name: p / f"{i:05}.tif" for name, p in out_paths.items()}
        # if all(p.exists() for p in out_file_paths.values()):