def __init__(
    self,
    hdf_path,
    batch_size=32,
    shuffle=False,
    augmenter_ref=None,
    augmenter_mov=None,
    return_inverse=False,
    mlflow_log=False,
    indexes=None,
):
    """Initialize the generator over samples stored in an HDF5 file.

    Parameters
    ----------
    hdf_path : str or pathlib.Path
        Path to the HDF5 file. The total number of samples is taken as
        ``len(f["p"])``.
    batch_size : int
        Number of samples per batch.
    shuffle : bool
        Stored as ``self.shuffle`` (presumably consumed by
        ``on_epoch_end`` -- confirm).
    augmenter_ref, augmenter_mov : object or None
        Optional augmenters for the reference / moving images.
    return_inverse : bool
        Stored as ``self.return_inverse``.
    mlflow_log : bool
        If True, log the constructor arguments with ``mlflow.log_params``.
        Augmenters are logged as ``"active"`` / ``None`` rather than as
        the objects themselves.
    indexes : None, list, tuple, np.ndarray, str or pathlib.Path
        Samples to use. ``None`` selects every sample; a sequence is
        copied into a new list; a path is loaded with ``np.load``.

    Raises
    ------
    TypeError
        If ``indexes`` is of an unsupported type.
    """
    # Snapshot the constructor arguments. This must be the first statement
    # so that only the parameters are captured. ``dict(...)`` makes an
    # explicit copy: before Python 3.13 ``locals()`` returns the frame's own
    # namespace dict, which CPython re-synchronises with the real local
    # variables whenever a Python-level trace function (e.g. a debugger) is
    # active -- that would silently undo the sanitising edits below.
    self._locals = dict(locals())
    del self._locals["self"]
    # Log only whether an augmenter is set, never the augmenter object.
    self._locals["augmenter_mov"] = None if augmenter_mov is None else "active"
    self._locals["augmenter_ref"] = None if augmenter_ref is None else "active"

    if mlflow_log:
        mlflow.log_params(self._locals)

    self.hdf_path = str(hdf_path)
    self.batch_size = batch_size
    self.shuffle = shuffle
    self.augmenter_ref = augmenter_ref
    self.augmenter_mov = augmenter_mov
    self.return_inverse = return_inverse

    with h5py.File(self.hdf_path, "r") as f:
        length = len(f["p"])

    if indexes is None:
        self.indexes = list(np.arange(length))
    elif isinstance(indexes, (list, tuple, np.ndarray)):
        # Copy so the generator never shares (and, e.g. when shuffling in
        # place, never mutates) the caller's sequence.
        self.indexes = list(indexes)
    elif isinstance(indexes, (str, pathlib.Path)):
        self.indexes = list(np.load(str(indexes)))
    else:
        raise TypeError("Invalid indexes type {}".format(type(indexes)))

    # Atlas volume cached on the instance.
    self.volume = nissl_volume()
    # NOTE(review): ``times`` / ``temp`` are not used in this method; their
    # purpose is not visible here.
    self.times = []
    self.temp = []

    # Initialise per-epoch state (presumably shuffles ``self.indexes`` when
    # ``shuffle`` is True -- confirm in ``on_epoch_end``).
    self.on_epoch_end()
def __init__(self, sn, mov_imgs, dvfs):
    """Bundle moving images and their dvfs, one per atlas section.

    Parameters
    ----------
    sn : sequence of int
        Section numbers indexing the first axis of ``nissl_volume()``.
        Must be unique and lie in ``[0, 528)``.
    mov_imgs : sequence
        Moving images, one per entry of ``sn``.
    dvfs : sequence
        One entry per entry of ``sn``; presumably displacement fields used
        by ``self._warp`` to register ``mov_imgs`` -- TODO confirm.

    Raises
    ------
    ValueError
        If the three inputs differ in length, ``sn`` contains duplicates,
        or a section number falls outside ``[0, 528)``.
    """
    n_sections = 528  # size of the first axis of the Nissl volume

    # initial checks
    if not len(sn) == len(mov_imgs) == len(dvfs):
        raise ValueError("All the input lists need to have the same length")
    if len(set(sn)) != len(sn):
        raise ValueError("There are duplicate section numbers.")
    if not all(0 <= x < n_sections for x in sn):
        raise ValueError("All section numbers must lie in [0, 528).")

    self.sn = sn
    self.mov_imgs = mov_imgs
    self.dvfs = dvfs

    # attributes
    # Reference images for the requested sections (index 0 of the last axis).
    self.ref_imgs = nissl_volume()[self.sn, ..., 0]

    # Inverse lookup: section number -> position in ``sn``, ``None`` when the
    # section is absent. Built in one pass; because section numbers are
    # unique (checked above) every slot is written at most once.
    sn_to_ix = [None] * n_sections
    for ix, section in enumerate(self.sn):
        sn_to_ix[section] = ix
    self.sn_to_ix = sn_to_ix

    self.reg_imgs = self._warp()
def test_load_works(self, monkeypatch):
    """Check that ``nissl_volume`` yields a finite float32 volume in [0, 1].

    ``numpy.load`` and ``atlalign.data.img_as_float32`` are replaced by
    stubs that return all-zero float32 arrays.
    """

    def fake_load(*args, **kwargs):
        return np.zeros((528, 320, 456), dtype=np.float32)

    def fake_img_as_float32(*args, **kwargs):
        return np.zeros((320, 456), dtype=np.float32)

    monkeypatch.setattr("numpy.load", fake_load)
    monkeypatch.setattr("atlalign.data.img_as_float32", fake_img_as_float32)

    volume = nissl_volume()

    # Shape, finiteness, value range and dtype of the loaded atlas.
    assert volume.shape == (528, 320, 456, 1)
    assert np.all(np.isfinite(volume))
    assert volume.min() >= 0
    assert volume.max() <= 1
    assert volume.dtype == np.float32
"--path", help="Path to saving folder", type=str, default="{}/.atlalign/label/".format(str(Path.home())), ) parser.add_argument("-t", "--title", help="Title", type=str, default="") args = parser.parse_args() fixed = args.fixed moving = args.moving path = args.path title = args.title # Get fixed image img_ref = equalize_hist(nissl_volume()[fixed, ..., 0]).astype(np.float32) # Get moving image print(moving) img_mov_ = cv2.imread(moving, 0) img_mov = img_as_float32(cv2.imread(moving, 0)) # Check folder exists path = path if path[-1] == "/" else path + "/" dir_exists = os.path.isdir(path) if dir_exists: warnings.warn( "Directory {} already exists, potentially rewriting data.".format( path)) else:
def main(argv=None):
    """Run CLI.

    Open the labelling GUI on a reference/moving image pair. Save the
    resulting ``run_GUI`` output (presumably a displacement field) as
    ``df.npy`` and the warped moving image as ``registered.png`` inside
    ``output_path``.

    Parameters
    ----------
    argv : list of str or None
        Arguments to parse; ``None`` makes argparse read ``sys.argv[1:]``.
    """
    n_sections = 528  # valid section numbers are [0, n_sections)

    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        "ref",
        type=str,
        help="Either a path to a reference image or a number from [0, 528) "
        "representing the coronal dimension in the nissl stain volume.",
    )
    parser.add_argument(
        "mov",
        type=str,
        help="Path to a moving image. Needs to be of the same shape as "
        "reference.",
    )
    parser.add_argument("output_path",
                        type=str,
                        help="Folder where the outputs will be stored.")
    parser.add_argument(
        "-s",
        "--swap",
        default=False,
        help="Swap to the moving to reference mode.",
        action="store_true",
    )

    args = parser.parse_args(argv)
    ref = args.ref
    mov = args.mov
    output_path = args.output_path
    swap = args.swap

    # ``isdecimal`` (unlike ``isdigit``) only accepts characters that ``int``
    # can parse -- e.g. "²".isdigit() is True but int("²") raises.
    ref_is_section = ref.isdecimal()
    if ref_is_section and int(ref) >= n_sections:
        # Fail fast with a usage error (exit status 2) before importing
        # atlalign or creating the output folder, instead of an IndexError
        # from indexing the volume.
        parser.error("ref section number must lie in [0, {}), got {}.".format(
            n_sections, ref))

    # Imports (kept local to the function, as in the rest of this CLI)
    from atlalign.data import nissl_volume
    from atlalign.label import load_image, run_GUI

    output_path = pathlib.Path(output_path)
    output_path.mkdir(exist_ok=True, parents=True)

    if ref_is_section:
        img_ref = nissl_volume()[int(ref), ..., 0]
    else:
        img_ref = load_image(pathlib.Path(ref))

    img_mov = load_image(pathlib.Path(mov),
                         output_channels=1,
                         keep_last=False,
                         output_dtype="float32")

    result_df = run_GUI(img_ref,
                        img_mov,
                        mode="mov2ref" if swap else "ref2mov")[0]
    img_reg = result_df.warp(img_mov)

    result_df.save(output_path / "df.npy")
    plt.imsave(str(output_path / "registered.png"), img_reg)