Example #1
0
    def __init__(
        self,
        hdf_path,
        batch_size=32,
        shuffle=False,
        augmenter_ref=None,
        augmenter_mov=None,
        return_inverse=False,
        mlflow_log=False,
        indexes=None,
    ):
        """Set up a batch generator over samples stored in an HDF5 file.

        Parameters
        ----------
        hdf_path : str or pathlib.Path
            Path to the HDF5 file. The length of its ``"p"`` dataset is the
            number of available samples.
        batch_size : int
            Stored as ``self.batch_size``.
        shuffle : bool
            Stored as ``self.shuffle``.
        augmenter_ref, augmenter_mov : object or None
            Optional augmenters. Only whether each one is set is logged.
        return_inverse : bool
            Stored as ``self.return_inverse``.
        mlflow_log : bool
            If True, log the constructor arguments via ``mlflow.log_params``.
        indexes : None, list, str or pathlib.Path
            Sample indices to use. ``None`` selects every sample in the file,
            a list is used as given (a copy is stored), and a path is read
            with ``np.load``.

        Raises
        ------
        TypeError
            If ``indexes`` is of any other type.
        """
        # Snapshot the constructor arguments for logging. This must be the
        # first statement so that only the parameters end up in the dict.
        self._locals = locals()
        del self._locals["self"]
        # Replace the augmenter objects by a marker: only their presence
        # is recorded, not the objects themselves.
        self._locals[
            "augmenter_mov"] = None if augmenter_mov is None else "active"
        self._locals[
            "augmenter_ref"] = None if augmenter_ref is None else "active"

        if mlflow_log:
            mlflow.log_params(self._locals)

        self.hdf_path = str(hdf_path)
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.augmenter_ref = augmenter_ref
        self.augmenter_mov = augmenter_mov
        self.return_inverse = return_inverse

        # Number of samples, taken from the "p" dataset of the HDF5 file.
        with h5py.File(self.hdf_path, "r") as f:
            length = len(f["p"])

        if indexes is None:
            self.indexes = list(np.arange(length))

        elif isinstance(indexes, list):
            # Store a copy so that any later in-place reordering of
            # ``self.indexes`` (presumably done by ``on_epoch_end`` when
            # ``shuffle`` is set -- TODO confirm) never mutates the caller's
            # list or another generator built from the same list.
            self.indexes = list(indexes)

        elif isinstance(indexes, (str, pathlib.Path)):
            self.indexes = list(np.load(str(indexes)))

        else:
            raise TypeError("Invalid indexes type {}".format(type(indexes)))

        # Reference volume, loaded once per generator.
        self.volume = nissl_volume()

        # NOTE(review): purpose not visible here; presumably filled by other
        # methods (e.g. timing / scratch data) -- confirm.
        self.times = []
        self.temp = []

        self.on_epoch_end()
Example #2
0
    def __init__(self, sn, mov_imgs, dvfs):
        """Validate inputs, load the matching reference slices and warp.

        Parameters
        ----------
        sn : sequence of int
            Section numbers, used as indices along the first axis of
            ``nissl_volume()``. They must be unique and lie in ``[0, 528)``.
        mov_imgs : sequence
            Moving images, one per entry of ``sn``.
        dvfs : sequence
            One entry per element of ``sn`` (presumably displacement vector
            fields consumed by ``_warp`` -- confirm).

        Raises
        ------
        ValueError
            If the three inputs differ in length, ``sn`` contains duplicates,
            or a section number lies outside ``[0, 528)``.
        """
        n_sections = 528  # valid section numbers are [0, n_sections)

        # initial checks
        if not len(sn) == len(mov_imgs) == len(dvfs):
            raise ValueError("All the input lists need to have the same length")

        if len(set(sn)) != len(sn):
            raise ValueError("There are duplicate section numbers.")

        if not all(0 <= x < n_sections for x in sn):
            raise ValueError("All section numbers must lie in [0, 528).")

        self.sn = sn
        self.mov_imgs = mov_imgs
        self.dvfs = dvfs

        # attributes
        self.ref_imgs = nissl_volume()[self.sn, ..., 0]
        # For every possible section number, its position in ``sn`` (None if
        # absent). A dict built in one pass replaces per-section
        # ``list.index`` scans (O(len(sn)) instead of O(528 * len(sn))), and
        # it also works when ``sn`` is not a list. This is unambiguous because
        # duplicates were rejected above.
        position = {s: i for i, s in enumerate(self.sn)}
        self.sn_to_ix = [position.get(x) for x in range(n_sections)]
        self.reg_imgs = self._warp()
Example #3
0
    def test_load_works(self, monkeypatch):
        """Check that ``nissl_volume`` yields a finite float32 volume in [0, 1]."""

        def fake_np_load(*args, **kwargs):
            # Stand-in for the on-disk volume: all zeros, no channel axis.
            return np.zeros((528, 320, 456), dtype=np.float32)

        def fake_img_as_float32(*args, **kwargs):
            return np.zeros((320, 456), dtype=np.float32)

        # Let's replace disk access and dtype conversion with deterministic fakes.
        monkeypatch.setattr("numpy.load", fake_np_load)
        monkeypatch.setattr("atlalign.data.img_as_float32", fake_img_as_float32)

        volume = nissl_volume()

        # A trailing channel axis is expected on top of the loaded array.
        assert volume.shape == (528, 320, 456, 1)
        # Values must be finite and normalized to the unit interval.
        assert np.all(np.isfinite(volume))
        assert volume.min() >= 0
        assert volume.max() <= 1
        assert volume.dtype == np.float32
Example #4
0
        "--path",
        help="Path to saving folder",
        type=str,
        default="{}/.atlalign/label/".format(str(Path.home())),
    )
    parser.add_argument("-t", "--title", help="Title", type=str, default="")

    args = parser.parse_args()

    fixed = args.fixed
    moving = args.moving
    path = args.path
    title = args.title

    # Get fixed image
    img_ref = equalize_hist(nissl_volume()[fixed, ..., 0]).astype(np.float32)

    # Get moving image
    print(moving)
    img_mov_ = cv2.imread(moving, 0)
    img_mov = img_as_float32(cv2.imread(moving, 0))

    # Check folder exists
    path = path if path[-1] == "/" else path + "/"
    dir_exists = os.path.isdir(path)

    if dir_exists:
        warnings.warn(
            "Directory {} already exists, potentially rewriting data.".format(
                path))
    else:
Example #5
0
def main(argv=None):
    """Run CLI.

    Open the labelling GUI to register a moving image against a reference
    image, then store the result.

    Parameters
    ----------
    argv : list of str or None
        Arguments to parse (without the program name). ``None`` makes
        ``argparse`` read ``sys.argv[1:]``.

    Notes
    -----
    Writes ``df.npy`` (the object returned by ``run_GUI``, saved through its
    ``save`` method) and ``registered.png`` (the moving image warped by that
    object) into ``output_path``, creating the folder if needed. Exits with
    status 2 when ``ref`` is a section number outside ``[0, 528)``.
    """
    n_sections = 528  # valid section numbers are [0, n_sections), see ``ref``

    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        "ref",
        type=str,
        help="Either a path to a reference image or a number from [0, 528) "
        "representing the coronal dimension in the nissl stain volume.",
    )
    parser.add_argument(
        "mov",
        type=str,
        help="Path to a moving image. Needs to be of the same shape as "
        "reference.",
    )
    parser.add_argument("output_path",
                        type=str,
                        help="Folder where the outputs will be stored.")
    parser.add_argument(
        "-s",
        "--swap",
        default=False,
        help="Swap to the moving to reference mode.",
        action="store_true",
    )

    args = parser.parse_args(argv)

    ref = args.ref
    mov = args.mov
    output_path = args.output_path
    swap = args.swap

    # A purely numeric ``ref`` selects a section of the Nissl volume. Validate
    # it before touching the filesystem, so bad input gives a clean CLI error
    # instead of an IndexError after the output folder was already created.
    section = int(ref) if ref.isdigit() else None
    if section is not None and not 0 <= section < n_sections:
        parser.error(
            "ref must be a path or a section number in [0, {}), got {}".format(
                n_sections, ref))

    # Imports (deferred until the arguments have been parsed and validated)
    from atlalign.data import nissl_volume
    from atlalign.label import load_image, run_GUI

    output_path = pathlib.Path(output_path)
    output_path.mkdir(exist_ok=True, parents=True)

    if section is not None:
        img_ref = nissl_volume()[section, ..., 0]
    else:
        img_ref = load_image(pathlib.Path(ref))

    img_mov_path = pathlib.Path(mov)
    img_mov = load_image(img_mov_path,
                         output_channels=1,
                         keep_last=False,
                         output_dtype="float32")

    result_df = run_GUI(img_ref,
                        img_mov,
                        mode="mov2ref" if swap else "ref2mov")[0]

    img_reg = result_df.warp(img_mov)

    result_df.save(output_path / "df.npy")
    plt.imsave(str(output_path / "registered.png"), img_reg)