def manual_test_this():
    """Manually compare the "ls_reg" and "ls_trf" tensors of one sample.

    Looks up the "ls", "ls_trf" and "ls_reg" tensor infos of
    "heart_static.beads_ref_wholeFOV", builds one dataset per info with
    ``cache=True``, zips them, takes sample 0 and passes the ``.max(2)``
    reductions of "ls_reg" and "ls_trf" to ``compare_slices`` under the
    tag "lala".
    """
    tag = "heart_static.beads_ref_wholeFOV"
    tensor_names = ("ls", "ls_trf", "ls_reg")

    # z_min full: 0, z_max full: 838; 60/209*838=241; 838-10/209*838=798
    meta = {
        "nnum": 19,
        "z_out": 49,
        "scale": 4,
        "shrink": 8,
        "interpolation_order": 2,
        "z_ls_rescaled": 241,
        "pred_z_min": 0,
        "pred_z_max": 838,
        "crop_names": ["wholeFOV"],
    }

    # Resolve every tensor info first; the datasets are only built afterwards.
    named_infos = [(name, get_tensor_info(tag, name, meta=meta)) for name in tensor_names]
    zipped = ZipDataset(
        collections.OrderedDict(
            [(name, get_dataset_from_info(info=info, cache=True)) for name, info in named_infos]
        )
    )

    first = zipped[0]
    reduced = {key: first[key].max(2) for key in ("ls_reg", "ls_trf")}
    compare_slices(reduced, "lala", "ls_reg", "ls_trf")
def beads_dataset(meta) -> "ZipDataset":
    """Build a zipped dataset of the "lf" and "ls_reg" tensors of "beads.small1".

    Args:
        meta: forwarded unchanged as ``meta`` to ``get_tensor_info`` for
            each tensor name.

    Returns:
        The object created by ``ZipDataset`` from an ``OrderedDict`` that
        maps "lf" and "ls_reg" (in that order) to the datasets returned by
        ``get_dataset_from_info(..., cache=True)``.
    """
    # The annotation is a string so that it is not evaluated at definition
    # time; the function returns whatever ``ZipDataset(...)`` constructs.
    datasets = OrderedDict()
    for name in ("lf", "ls_reg"):
        info = get_tensor_info("beads.small1", name, meta=meta)
        datasets[name] = get_dataset_from_info(info=info, cache=True)

    return ZipDataset(datasets)
def get_dataset(self):
    """Return the dataset selected by ``self.config.dataset``.

    For every choice other than ``DatasetChoice.from_path`` the call is
    forwarded to the ``get_dataset`` name visible from this method's
    enclosing scope (a bare name inside a method does not resolve to the
    method itself).

    For ``DatasetChoice.from_path`` the method asserts that
    ``self.dataset_part`` is ``DatasetPart.test``, builds one ``TensorInfo``
    for ``config.pred_name`` (located by ``config.pred_glob``) and one for
    ``config.trgt_name`` (located by ``config.trgt_glob``), both rooted at
    ``config.path``, converts each with ``get_dataset_from_info(...,
    cache=True)``, zips them and wraps the result in a ``ConcatDataset`` that
    uses ``transforms_pipeline.sample_preprocessing`` as its transform.
    """
    cfg = self.config

    if cfg.dataset != DatasetChoice.from_path:
        return get_dataset(
            cfg.dataset,
            self.dataset_part,
            nnum=19,
            z_out=49,
            scale=self.scale,
            shrink=self.shrink,
            interpolation_order=cfg.interpolation_order,
            incl_pred_vol="pred_vol" in self.save_output_to_disk,
            load_lfd_and_care=self.load_lfd_and_care,
        )

    assert self.dataset_part == DatasetPart.test

    def tensor_info_for(tensor_name, glob_pattern):
        # Both tensors share every setting except their name and glob.
        return TensorInfo(
            name=tensor_name,
            root=cfg.path,
            location=glob_pattern,
            transforms=self.transforms_pipeline.sample_precache_trf,
            datasets_per_file=1,  # todo: remove hard coded
            samples_per_dataset=1,
            remove_singleton_axes_at=(-1,),  # todo: remove hard coded
            insert_singleton_axes_at=(0, 0),  # todo: remove hard coded
            z_slice=None,
            skip_indices=(),
            meta=None,
        )

    tensor_infos = {
        cfg.pred_name: tensor_info_for(cfg.pred_name, cfg.pred_glob),
        cfg.trgt_name: tensor_info_for(cfg.trgt_name, cfg.trgt_glob),
    }
    zipped = ZipDataset(
        {
            key: get_dataset_from_info(info, cache=True, filters=[], indices=None)
            for key, info in tensor_infos.items()
        }
    )
    return ConcatDataset([zipped], transform=self.transforms_pipeline.sample_preprocessing)
def get_individual_dataset(self, dss: DatasetSetup) -> torch.utils.data.Dataset:
    """Zip the datasets described by ``dss.infos`` into a single dataset.

    Args:
        dss: setup object providing ``infos`` (a mapping from tensor name to
            dataset info), ``indices`` and ``filters``.

    Returns:
        A ``ZipDataset`` over one ``get_dataset_from_info(..., cache=True)``
        result per entry of ``dss.infos``, created with
        ``self.sample_preprocessing`` as its ``transformation``.
    """
    per_tensor = OrderedDict()
    for name, dsinfo in dss.infos.items():
        # The combined filter list is concatenated anew for every tensor, so
        # no list object is shared between the individual datasets.
        per_tensor[name] = get_dataset_from_info(
            dsinfo,
            cache=True,
            indices=dss.indices,
            filters=dss.filters + self.filters,
        )

    return ZipDataset(per_tensor, transformation=self.sample_preprocessing)
def get_dataset(self):
    """Return the prediction dataset for the "lf" tensor found under ``config.path``.

    Asserts that ``self.config.dataset`` is ``DatasetChoice.predict_path`` and
    that ``self.dataset_part`` is ``DatasetPart.predict``. It then describes
    the "lf" tensor (located by ``config.glob_lf``) with a ``TensorInfo``,
    converts it with ``get_dataset_from_info(..., cache=True)`` and wraps the
    result in a ``ConcatDataset`` that uses
    ``transforms_pipeline.sample_preprocessing`` as its transform.
    """
    cfg = self.config
    assert cfg.dataset == DatasetChoice.predict_path
    assert self.dataset_part == DatasetPart.predict

    lf_info = TensorInfo(
        name="lf",
        root=cfg.path,
        location=cfg.glob_lf,
        transforms=self.transforms_pipeline.sample_precache_trf,
        datasets_per_file=1,
        samples_per_dataset=1,
        remove_singleton_axes_at=(),  # (-1,),
        insert_singleton_axes_at=(0, 0),
        z_slice=None,
        skip_indices=(),
        meta=None,
    )

    return ConcatDataset(
        [get_dataset_from_info(lf_info, cache=True, filters=[], indices=None)],
        transform=self.transforms_pipeline.sample_preprocessing,
    )
sample = PoissonNoise(apply_to={ "ls_slice": "ls_slice_trf", "lf": "lf_trf" }, peak=p, seed=0)(sample) compare_slices(sample, f"{add_to_tag}_{p}", "ls_slice", "ls_slice_trf") compare_slices(sample, f"{add_to_tag}_{p}", "lf", "lf_trf") if __name__ == "__main__": from hylfm.datasets import ZipDataset, get_dataset_from_info, get_tensor_info meta = {"nnum": 19, "z_out": 49, "interpolation_order": 2, "scale": 2} for tag in [ "brain.11_1__2020-03-11_03.22.33__SinglePlane_-330", "brain.11_2__2020-03-11_07.30.39__SinglePlane_-320", "brain.09_3__2020-03-09_06.43.40__SinglePlane_-330", ]: ls_slice_info = get_tensor_info(tag, "ls_slice", meta=meta) lf_info = get_tensor_info(tag, "lf", meta=meta) ls_slice_dataset = ZipDataset( collections.OrderedDict([ ("ls_slice", get_dataset_from_info(info=ls_slice_info, cache=True)), ("lf", get_dataset_from_info(info=lf_info, cache=True)), ])) manual_test_poisson(ls_slice_dataset, tag)
def try_static(backprop: bool = True):
    """Run an ``A04`` model on one static beads sample and plot the result.

    Builds "lf" and "ls" datasets from ``b4mu_3_lf`` / ``b4mu_3_ls``, zips
    them with a crop / normalize / channel-from-light-field / cast pipeline,
    pulls one batch through a ``DataLoader`` (batch size 1), feeds it to the
    model on the "cuda" device, and shows maximum projections along axes 0, 1
    and 2 for the target and for the prediction.

    Args:
        backprop: when true, an ``MSELoss`` between prediction and target is
            back-propagated and a single ``Adam`` step is taken before
            plotting.
    """
    from torch.utils.data import DataLoader
    import matplotlib.pyplot as plt

    from hylfm.datasets.beads import b4mu_3_lf, b4mu_3_ls
    from hylfm.datasets import get_dataset_from_info, ZipDataset, N5CachedDatasetFromInfo, get_collate_fn
    from hylfm.transformations import Normalize01, ComposedTransformation, ChannelFromLightField, Cast, Crop

    def cached_subset(info):
        # Same wrapping chain for both tensors: info -> dataset -> cached -> subset.
        return N5CachedDatasetFromInfoSubset(N5CachedDatasetFromInfo(get_dataset_from_info(info)))

    def show_max_projections(volume, title):
        # One figure per axis (0, 1, 2) of the given array.
        for axis in range(3):
            plt.imshow(volume.max(axis=axis))
            plt.title(title)
            plt.show()

    # Other model configurations left as notes in the original:
    #   A04(input_name="lf", prediction_name="pred", z_out=51, nnum=19, n_res2d=(488, 488, "u", 244, 244))
    #   n_res2d=(488, 488, "u", 244, 244), n_res3d=[[7], [7], [7]]
    #   n_res2d: [976, 488, u, 244, 244, u, 122, 122]; inplanes_3d: 7; n_res3d: [[7, 7], [7], [1]]
    model = A04(input_name="lf", prediction_name="pred", z_out=51, nnum=19)

    # NOTE: this updates the ``transformations`` attribute of the imported
    # ``b4mu_3_ls`` object itself, not a copy of it.
    ls_resize = {"Resize": {"apply_to": "ls", "shape": [1.0, 121 / 838, 8 / 19, 8 / 19], "order": 2}}
    ls_shape_check = {"Assert": {"apply_to": "ls", "expected_tensor_shape": [None, 1, 121, None, None]}}
    b4mu_3_ls.transformations += [ls_resize, ls_shape_check]

    # Commented-out inline alternative to b4mu_3_lf in the original:
    #   TensorInfo(name="lf", root="GHUFNAGELLFLenseLeNet_Microscope",
    #              location="20191031_Beads_MixedSizes/Beads_01micron_highConcentration/2019-10-31_04.57.13/stack_0_channel_0/TP_*/RC_rectified/Cam_Right_1_rectified.tif",
    #              insert_singleton_axes_at=[0, 0])
    lfds = cached_subset(b4mu_3_lf)

    # Commented-out inline alternative to b4mu_3_ls in the original:
    #   TensorInfo(name="ls", root="GHUFNAGELLFLenseLeNet_Microscope",
    #              location="20191031_Beads_MixedSizes/Beads_01micron_highConcentration/2019-10-31_04.57.13/stack_1_channel_1/TP_*/LC/Cam_Left_registered.tif",
    #              insert_singleton_axes_at=[0, 0],
    #              transformations=[{"Resize": {"apply_to": "ls",
    #                  "shape": [1.0, 121, 0.21052631578947368421052631578947, 0.21052631578947368421052631578947],
    #                  "order": 2}}])
    lsds = cached_subset(b4mu_3_ls)

    steps = [
        Crop(apply_to="ls", crop=((0, None), (35, -35), (8, -8), (8, -8))),
        Normalize01(apply_to=["lf", "ls"], min_percentile=0, max_percentile=100),
        ChannelFromLightField(apply_to="lf", nnum=19),
        Cast(apply_to=["lf", "ls"], dtype="float32", device="cuda"),
    ]
    trf = ComposedTransformation(*steps)

    ds = ZipDataset(OrderedDict([("lf", lfds), ("ls", lsds)]), transformation=trf)
    loader = DataLoader(ds, batch_size=1, collate_fn=get_collate_fn(lambda t: t))

    device = torch.device("cuda")
    model = model.to(device)
    # state = torch.load(checkpoint, map_location=device)
    # m.load_state_dict(state, strict=False)

    sample = next(iter(loader))
    ipt, tgt = sample["lf"], sample["ls"]
    # ipt = torch.rand(1, nnum ** 2, 5, 5)

    spatial_shape = ipt.shape[2:]
    print("get_scaling", model.get_scaling(spatial_shape))
    print("get_shrinkage", model.get_shrinkage(spatial_shape))
    print("get_output_shape()", model.get_output_shape(spatial_shape))
    print("ipt", ipt.shape, "tgt", tgt.shape)

    out = model(sample)["pred"]

    if backprop:
        torch.nn.MSELoss()(out, tgt).backward()
        # The optimizer is only created after the backward pass, as in the original.
        torch.optim.Adam(model.parameters()).step()

    show_max_projections(tgt[0, 0].detach().cpu().numpy(), "tgt")

    print("pred", out.shape)
    show_max_projections(out[0, 0].detach().cpu().numpy(), "pred")

    print("done")
meta = { "z_out": 49, "nnum": 19, "interpolation_order": 2, "scale": 4, "z_ls_rescaled": 241, "pred_z_min": 0, "pred_z_max": 838, } datasets = OrderedDict() datasets["pred"] = get_dataset_from_info( TensorInfo( name="pred", root=Path("/scratch/beuttenm/lnet/care/results"), location=f"{subpath}/{model_name}/*.tif", insert_singleton_axes_at=[0, 0], z_slice=None, meta={"crop_name": "Heart_tightCrop", **meta}, ), cache=True, ) datasets["ls_slice"] = get_dataset_from_info( get_tensor_info("heart_dynamic.2019-12-09_04.54.38", name="ls_slice", meta=meta), cache=True, filters=[("z_range", {})], ) assert len(datasets["pred"]) == 51 * 241, len(datasets["pred"]) assert len(datasets["ls_slice"]) == 51 * 209, len(datasets["ls_slice"]) # ipt_paths = { # "pred": ,