Exemplo n.º 1
0
def test_MergeOmeTiffWriter_mc(simple_transform_affine_nl, tmp_path):
    """Merged by-plane output equals the per-image outputs stacked in order.

    Three random 3-channel uint8 images are merged and written by plane. Each
    image is then written on its own with the same transform sequence. The
    merged file must have 9 planes, and each consecutive group of 3 planes
    must match the corresponding single-image file.
    """
    n_images = 3
    sources = [
        np.random.randint(0, 255, (3, 1024, 1024), dtype=np.uint8)
        for _ in range(n_images)
    ]
    out_dir = str(tmp_path)

    merged_image = MergeRegImage(
        list(sources),
        [1] * n_images,
        channel_names=[["1", "2", "3"] for _ in range(n_images)],
    )
    rts = RegTransformSeq(simple_transform_affine_nl)
    merged_writer = MergeOmeTiffWriter(
        merged_image, reg_transform_seqs=[rts] * n_images
    )
    merged_planes = imread(
        merged_writer.merge_write_image_by_plane(
            "merge_testimage_by_plane",
            ["1", "2", "3"],
            output_dir=out_dir,
        )
    )

    # Write every source image individually with the same transforms.
    single_planes = []
    for idx, source in enumerate(sources, start=1):
        writer = OmeTiffWriter(reg_image_loader(source, 1), reg_transform_seq=rts)
        out_fp = writer.write_image_by_plane(
            f"testimage_by_plane_s{idx}",
            output_dir=out_dir,
        )
        single_planes.append(imread(out_fp))

    assert merged_planes.shape[0] == 9
    for idx, single in enumerate(single_planes):
        lo = 3 * idx
        assert np.array_equal(merged_planes[lo:lo + 3, :, :], single)
Exemplo n.º 2
0
def test_OmeTiffWriter_by_tile(complex_transform, tmp_path):
    """Tiled writing of a 4096x4096 image (res 0.5) reloads as (1, 1024, 1024)."""
    source = np.ones((4096, 4096), dtype=np.uint8)
    tiled_writer = OmeTiffTiledWriter(
        reg_image_loader(source, 0.5),
        reg_transform_seq=RegTransformSeq(complex_transform),
    )

    out_fp = tiled_writer.write_image_by_tile(
        gen_project_name_str(),
        output_dir=str(tmp_path),
        zarr_temp_dir=tmp_path,
    )
    reloaded = reg_image_loader(out_fp, 2)
    assert reloaded.shape == (1, 1024, 1024)
Exemplo n.º 3
0
def test_OmeTiffWriter_by_plane(complex_transform, tmp_path):
    """By-plane writing of a 1024x1024 image reloads as (1, 1024, 1024)."""
    source = np.ones((1024, 1024), dtype=np.uint8)
    plane_writer = OmeTiffWriter(
        reg_image_loader(source, 1),
        reg_transform_seq=RegTransformSeq(complex_transform),
    )

    out_fp = plane_writer.write_image_by_plane(
        gen_project_name_str(),
        output_dir=str(tmp_path),
    )
    reloaded = reg_image_loader(out_fp, 2)
    assert reloaded.shape == (1, 1024, 1024)
Exemplo n.º 4
0
def test_wsireg_run_reg_with_crop_merge(data_out_dir, disk_im_gry):
    """Registering onto a bbox-cropped target, with a merge modality.

    mod1 and mod2 are both registered onto mod3, which is given a 512x512
    mask bounding box with ``crop_to_mask_bbox``. The first transformed output
    must be 2048x2048 when ``to_original_size=True`` and 512x512 when it is
    False.
    """
    wsi_reg = WsiReg2D(gen_project_name_str(), str(data_out_dir))
    img_fp1 = str(disk_im_gry)

    def _add(name, **extra):
        # Fresh channel lists on every call, as in separate literal calls.
        wsi_reg.add_modality(
            name,
            img_fp1,
            0.65,
            channel_names=["test"],
            channel_colors=["red"],
            **extra,
        )

    for name in ("mod1", "mod2"):
        _add(name)
    _add(
        "mod3",
        preprocessing={
            "mask_bbox": [512, 512, 512, 512],
            "crop_to_mask_bbox": True,
        },
    )

    for moving in ("mod1", "mod2"):
        wsi_reg.add_reg_path(moving, "mod3", reg_params=["rigid_test"])
    wsi_reg.register_images()
    wsi_reg.add_merge_modalities("merge", ["mod1", "mod2", "mod3"])

    # Full (uncropped) output size.
    full_fps = wsi_reg.transform_images(transform_non_reg=True,
                                        to_original_size=True)
    nocrop = reg_image_loader(full_fps[0], 1)

    # Cropped output size.
    cropped_fps = wsi_reg.transform_images(transform_non_reg=True,
                                           to_original_size=False)
    wsi_reg.save_transformations()
    cropped = reg_image_loader(cropped_fps[0], 1)

    assert nocrop.shape[1:] == (2048, 2048)
    assert cropped.shape[1:] == (512, 512)
Exemplo n.º 5
0
def test_ometiff_read_rgb():
    """An RGB OME-TIFF loads as 3-D uint8 RGB that is not flagged interleaved."""
    ri = reg_image_loader(os.path.join(PRIVATE_DIR, "czi_rgb.ome.tiff"), 1)
    dims = ri.im_dims
    assert len(dims) == 3
    assert ri.im_dtype == np.uint8
    assert ri.is_rgb is True
    assert ri.is_rgb_interleaved is False
Exemplo n.º 6
0
    def __init__(
        self,
        image_fp: List[Union[Path, str]],
        image_res: List[Union[int, float]],
        channel_names: Optional[List[List[str]]] = None,
        channel_colors: Optional[List[List[str]]] = None,
    ):
        """Load several images so they can be written as one merged image.

        Parameters
        ----------
        image_fp
            Images to merge. Each element is passed to ``reg_image_loader``.
        image_res
            Pixel spacing of each image, in the same order as ``image_fp``.
        channel_names
            Optional per-image lists of channel names. Any image whose names
            are missing, or do not match its channel count, gets
            ``C0``, ``C1``, ...
        channel_colors
            Optional per-image lists of channel colors.

        Raises
        ------
        ValueError
            If ``image_fp`` or ``image_res`` is not a list, if no images are
            given, or if any per-image argument does not have exactly one
            entry per image.
        """
        if not isinstance(image_fp, list):
            raise ValueError(
                "MergeRegImage requires a list of images to merge")

        if not isinstance(image_res, list):
            raise ValueError(
                "MergeRegImage requires a list of image resolutions for each image to merge"
            )

        n_images = len(image_fp)
        if n_images == 0:
            raise ValueError(
                "MergeRegImage requires at least one image to merge")

        if channel_names is None:
            channel_names = [None] * n_images

        if channel_colors is None:
            channel_colors = [None] * n_images

        # zip() below would silently drop images if any per-image argument
        # were shorter than image_fp, so check the lengths explicitly.
        for arg_name, values in (
            ("image_res", image_res),
            ("channel_names", channel_names),
            ("channel_colors", channel_colors),
        ):
            if len(values) != n_images:
                raise ValueError(
                    f"MergeRegImage got {n_images} images but {len(values)} "
                    f"entries in {arg_name}")

        images = []
        for image, res, ch_names, ch_colors in zip(
                image_fp, image_res, channel_names, channel_colors):
            imdata = reg_image_loader(
                image,
                res,
                channel_names=ch_names,
                channel_colors=ch_colors,
            )
            # Fall back to generic names when none were given or when they
            # do not match the number of channels actually loaded.
            if (imdata.channel_names is None
                    or len(imdata.channel_names) != imdata.n_ch):
                imdata._channel_names = [
                    f"C{idx}" for idx in range(imdata.n_ch)
                ]

            images.append(imdata)

        first_dtype = images[0].im_dtype
        if not all(im.im_dtype == first_dtype for im in images):
            warn(
                "MergeRegImage created with mixed data types, writing will cast "
                "to the largest data type")

        if any(im.is_rgb for im in images):
            warn(
                "MergeRegImage does not support writing merged interleaved RGB "
                "Data will be written as multi-channel")

        self.images = images
        self.image_fps = image_fp
        self.im_dtype = first_dtype

        # Merged output is not treated as RGB (see the warning above).
        self.is_rgb = False

        self.n_ch = np.sum([i.n_ch for i in self.images])
        self.channel_names = [i.channel_names for i in self.images]
        self.original_size_transform = None
Exemplo n.º 7
0
def test_wsireg_run_reg_changeres(data_out_dir, disk_im_gry):
    """Register mod1 to mod2 while writing mod1 at ``output_res=0.325``.

    Both modalities use the same image at 0.65 spacing. The first transformed
    output, loaded at 0.325, is expected to be 4096x4096.
    """
    wsi_reg = WsiReg2D(gen_project_name_str(), str(data_out_dir))
    img_fp1 = str(disk_im_gry)

    per_modality_extra = {
        "mod1": {"output_res": 0.325},
        "mod2": {},
    }
    for mod_name, extra in per_modality_extra.items():
        wsi_reg.add_modality(
            mod_name,
            img_fp1,
            0.65,
            channel_names=["test"],
            channel_colors=["red"],
            **extra,
        )

    wsi_reg.add_reg_path("mod1", "mod2",
                         reg_params=["rigid_test", "affine_test"])
    wsi_reg.register_images()

    out_fps = wsi_reg.transform_images(transform_non_reg=False)
    registered = reg_image_loader(out_fps[0], 0.325)

    assert registered.shape[1:] == (4096, 4096)
Exemplo n.º 8
0
def test_czi_read_mc_read_channels():
    """Each channel of the 4-channel uint16 OME-TIFF reads as a distinct 2-D plane."""
    ri = reg_image_loader(os.path.join(PRIVATE_DIR, "czi_4ch_16bit.ome.tiff"), 1)
    channels = [ri.read_single_channel(ch_idx) for ch_idx in range(4)]
    plane_shape = ri.shape[1:]

    for ch in channels:
        assert np.squeeze(ch).shape == plane_shape
    for ch in channels:
        assert np.ndim(ch) == 2
    # Every pair of channels (0-1, 0-2, 0-3, 1-2, 1-3, 2-3) must differ.
    for i, first in enumerate(channels):
        for second in channels[i + 1:]:
            assert np.array_equal(first, second) is False
    for ch in channels:
        assert ch.dtype == np.uint16
Exemplo n.º 9
0
def test_huron_read_rgb():
    """A Huron RGB TIFF loads as 3-D uint8 RGB with 3 samples on the last axis."""
    ri = reg_image_loader(os.path.join(PRIVATE_DIR, "huron_rgb.tif"), 1)
    dims = ri.im_dims
    assert len(dims) == 3
    assert dims[2] == 3
    assert ri.im_dtype == np.uint8
    assert ri.is_rgb is True
Exemplo n.º 10
0
def test_czi_read_mc_fl_preprocess():
    """FL + ``as_uint8`` preprocessing gives a single-component, pixel-ID-1 image."""
    ri = reg_image_loader(
        os.path.join(PRIVATE_DIR, "czi_4ch_16bit.czi"),
        1,
        preprocessing={"image_type": "FL", "as_uint8": True},
    )
    ri.read_reg_image()
    prepared = ri.reg_image
    assert prepared.GetNumberOfComponentsPerPixel() == 1
    assert prepared.GetPixelID() == 1
Exemplo n.º 11
0
def test_reg_image_loader_zarr_mch(im_rgb_np_uneven):
    """With a ``GEOJSON_FP`` mask, the mask matches the prepared image geometry.

    NOTE(review): despite its name, this test loads ``im_rgb_np_uneven``, not
    a zarr multi-channel image. Another test in this file is also named
    ``test_reg_image_loader_zarr_mch``. If both live in the same module, the
    later definition shadows this one and pytest never collects it; consider
    giving this test a unique name.
    """
    loaded = reg_image_loader(im_rgb_np_uneven, 0.65, mask=GEOJSON_FP)
    loaded.read_reg_image()
    prepared = loaded.reg_image
    assert prepared.GetSpacing() == (0.65, 0.65)
    assert prepared.GetNumberOfComponentsPerPixel() == 1
    mask = loaded.mask
    assert mask.GetSize() == prepared.GetSize()
    assert mask.GetSpacing() == prepared.GetSpacing()
Exemplo n.º 12
0
def test_wsireg_run_reg_downsampling_m1m2_changeores(data_out_dir,
                                                     disk_im_gry):
    """Register 2x-downsampled modalities, writing mod1 at ``output_res=(1.3, 1.3)``.

    The first transformed output, loaded at 0.65, is expected to be 1024x1024.
    """
    wsi_reg = WsiReg2D(gen_project_name_str(), str(data_out_dir))
    img_fp1 = str(disk_im_gry)

    def _add(name, **extra):
        # Both modalities share the same image and downsampling settings.
        wsi_reg.add_modality(
            name,
            img_fp1,
            0.65,
            channel_names=["test"],
            channel_colors=["red"],
            preprocessing={"downsampling": 2},
            **extra,
        )

    _add("mod1", output_res=(1.3, 1.3))
    _add("mod2")

    wsi_reg.add_reg_path("mod1", "mod2", reg_params=["rigid_test"])
    wsi_reg.register_images()

    out_fps = wsi_reg.transform_images(transform_non_reg=False)
    registered = reg_image_loader(out_fps[0], 0.65)

    assert registered.shape[1:] == (1024, 1024)
Exemplo n.º 13
0
def test_OmeTiffWriter_compare_tile_plane_rgb_nl_large(tmp_path):
    """Tiled and by-plane writers give identical pixels for a large dask RGB image."""
    side = 2**13
    rgb = np.random.randint(0, 255, (side, side, 3), dtype=np.uint8)
    reg_image = reg_image_loader(da.from_array(rgb, chunks=(1024, 1024, 3)), 0.5)

    rts = RegTransformSeq(TFORM_FP)
    plane_writer = OmeTiffWriter(reg_image, reg_transform_seq=rts)
    tile_writer = OmeTiffTiledWriter(reg_image, reg_transform_seq=rts)
    out_dir = str(tmp_path)

    tile_fp = tile_writer.write_image_by_tile(
        gen_project_name_str(),
        output_dir=out_dir,
    )
    plane_fp = plane_writer.write_image_by_plane(
        gen_project_name_str(),
        output_dir=out_dir,
    )

    assert np.array_equal(imread(tile_fp), imread(plane_fp))
Exemplo n.º 14
0
def test_czi_read_mc_selectch_preprocess_list():
    """Selecting channel 0 via ``ch_indices`` gives a single-component image."""
    ri = reg_image_loader(
        os.path.join(PRIVATE_DIR, "czi_4ch_16bit.czi"),
        1,
        preprocessing={"ch_indices": [0]},
    )
    ri.read_reg_image()
    prepared = ri.reg_image
    assert prepared.GetNumberOfComponentsPerPixel() == 1
    assert prepared.GetPixelID() == 1
Exemplo n.º 15
0
def test_OmeTiffWriter_compare_tile_plane_rgb_nl(
    simple_transform_affine_nl, tmp_path
):
    """Tiled and by-plane writers give identical pixels for a random RGB image."""
    rgb = np.random.randint(0, 255, (1024, 1024, 3), dtype=np.uint8)
    reg_image = reg_image_loader(rgb, 1)

    rts = RegTransformSeq(simple_transform_affine_nl)
    plane_writer = OmeTiffWriter(reg_image, reg_transform_seq=rts)
    tile_writer = OmeTiffTiledWriter(reg_image, reg_transform_seq=rts)
    out_dir = str(tmp_path)

    tile_fp = tile_writer.write_image_by_tile(
        gen_project_name_str(),
        output_dir=out_dir,
    )
    plane_fp = plane_writer.write_image_by_plane(
        gen_project_name_str(),
        output_dir=out_dir,
    )

    assert np.array_equal(imread(tile_fp), imread(plane_fp))
Exemplo n.º 16
0
def test_wsireg_run_reg_wmerge(data_out_dir, disk_im_gry):
    """Register mod1 to mod2 with a merge modality; the first output is (2, 2048, 2048)."""
    wsi_reg = WsiReg2D(gen_project_name_str(), str(data_out_dir))
    img_fp1 = str(disk_im_gry)

    for mod_name in ("mod1", "mod2"):
        wsi_reg.add_modality(
            mod_name,
            img_fp1,
            0.65,
            channel_names=["test"],
            channel_colors=["red"],
        )

    wsi_reg.add_reg_path("mod1", "mod2",
                         reg_params=["rigid_test", "affine_test"])
    wsi_reg.add_merge_modalities("test_merge", ["mod1", "mod2"])
    wsi_reg.register_images()
    wsi_reg.save_transformations()

    out_fps = wsi_reg.transform_images(transform_non_reg=True)
    merged_fp = out_fps[0]
    merged_im = reg_image_loader(merged_fp, 0.65)
    assert Path(merged_fp).exists() is True
    assert merged_im.shape == (2, 2048, 2048)
Exemplo n.º 17
0
def test_wsireg_run_reg_downsampling_m1m2_merge_no_prepro(
        data_out_dir, disk_im_gry):
    """Merge 2x-downsampled mod1/mod2 with ``remove_merged=True``.

    The first transformed output, loaded at 0.65, is expected to have shape
    (2, 2048, 2048).
    """
    wsi_reg = WsiReg2D(gen_project_name_str(), str(data_out_dir))
    img_fp1 = str(disk_im_gry)

    for mod_name in ("mod1", "mod2"):
        wsi_reg.add_modality(
            mod_name,
            img_fp1,
            0.65,
            channel_names=["test"],
            channel_colors=["red"],
            preprocessing={"downsampling": 2},
        )

    wsi_reg.add_reg_path("mod1", "mod2", reg_params=["rigid_test"])
    wsi_reg.add_merge_modalities("mod12-merge", ["mod1", "mod2"])
    wsi_reg.register_images()

    out_fps = wsi_reg.transform_images(transform_non_reg=False,
                                       remove_merged=True)
    merged = reg_image_loader(out_fps[0], 0.65)

    assert merged.shape == (2, 2048, 2048)
Exemplo n.º 18
0
def test_czi_read_mc_std_preprocess():
    """Default ``ImagePreproParams`` give a single-component, pixel-ID-1 image."""
    ri = reg_image_loader(
        os.path.join(PRIVATE_DIR, "czi_4ch_16bit.czi"),
        1,
        preprocessing=ImagePreproParams(),
    )
    ri.read_reg_image()
    prepared = ri.reg_image
    assert prepared.GetNumberOfComponentsPerPixel() == 1
    assert prepared.GetPixelID() == 1
Exemplo n.º 19
0
def test_czi_read_rgb():
    """An RGB CZI loads as 3-D uint8 RGB with 3 samples on the last axis."""
    ri = reg_image_loader(os.path.join(PRIVATE_DIR, "czi_rgb.czi"), 1)
    dims = ri.shape
    assert len(dims) == 3
    assert dims[2] == 3
    assert ri.im_dtype == np.uint8
    assert ri.is_rgb is True
Exemplo n.º 20
0
def test_czi_read_mc():
    """A 4-channel uint16 CZI loads channel-first and is not flagged RGB."""
    ri = reg_image_loader(os.path.join(PRIVATE_DIR, "czi_4ch_16bit.czi"), 1)
    dims = ri.shape
    assert len(dims) == 3
    assert dims[0] == 4
    assert dims[2] > 3
    assert ri.im_dtype == np.uint16
    assert ri.is_rgb is False
Exemplo n.º 21
0
def test_reg_image_loader_to_itk(im_gry_np, mask_np):
    """After ``sitk_to_itk``, image and mask are ``itk.Image`` with 0.65 spacing."""
    reg_image = reg_image_loader(im_gry_np, 0.65, mask=mask_np)
    reg_image.read_reg_image()
    reg_image.sitk_to_itk(cast_to_float32=True)

    expected_spacing = (0.65, 0.65)
    converted = reg_image.image
    assert isinstance(converted, itk.Image) is True
    assert converted.GetSpacing() == expected_spacing
    converted_mask = reg_image.mask
    assert isinstance(converted_mask, itk.Image) is True
    assert converted_mask.GetSpacing() == expected_spacing
Exemplo n.º 22
0
def test_reg_image_loader_image_np_rgb_std_prepro_fliph(im_rgb_np):
    """A horizontal flip is recorded as exactly one pre-registration transform."""
    reg_image = reg_image_loader(im_rgb_np, 0.65, preprocessing={"flip": "h"})
    reg_image.read_reg_image()

    prepared = reg_image.image
    assert prepared.GetSize() == (2048, 2048)
    assert prepared.GetNumberOfComponentsPerPixel() == 1
    pre_tforms = reg_image.pre_reg_transforms
    assert pre_tforms is not None
    assert len(pre_tforms) == 1
    assert prepared.GetSpacing() == (0.65, 0.65)
Exemplo n.º 23
0
def test_gj_reg_image_loader_mask_flip(im_gry_np, mask_geojson):
    """A GeoJSON mask with a vertical flip yields a ``sitk.Image`` with 0.65 spacing."""
    reg_image = reg_image_loader(
        im_gry_np, 0.65, preprocessing={"flip": "v"}, mask=mask_geojson
    )
    reg_image.read_reg_image()

    mask = reg_image.mask
    assert mask is not None
    assert isinstance(mask, sitk.Image) is True
    assert mask.GetSpacing() == (0.65, 0.65)
Exemplo n.º 24
0
def test_reg_image_loader_image_np_gry_std_prepro_rot(im_gry_np):
    """A 90-degree rotation is recorded as exactly one pre-registration transform."""
    reg_image = reg_image_loader(im_gry_np, 0.65, preprocessing={"rot_cc": 90})
    reg_image.read_reg_image()

    prepared = reg_image.image
    assert prepared.GetSize() == (2048, 2048)
    assert prepared.GetNumberOfComponentsPerPixel() == 1
    pre_tforms = reg_image.pre_reg_transforms
    assert pre_tforms is not None
    assert len(pre_tforms) == 1
    assert prepared.GetSpacing() == (0.65, 0.65)
Exemplo n.º 25
0
def test_reg_image_loader_mask_rot(im_gry_np, mask_np):
    """An array mask with a 90-degree rotation yields a ``sitk.Image`` with 0.65 spacing."""
    reg_image = reg_image_loader(
        im_gry_np, 0.65, preprocessing={"rot_cc": 90}, mask=mask_np
    )
    reg_image.read_reg_image()

    mask = reg_image.mask
    assert mask is not None
    assert isinstance(mask, sitk.Image) is True
    assert mask.GetSpacing() == (0.65, 0.65)
Exemplo n.º 26
0
def test_wsireg_config_full_merge_rgb_mc(config_fp, data_out_dir):
    """Run a full config-driven registration and check the first output image."""
    wsi_reg1 = config_to_WsiReg2D(config_fp, data_out_dir)
    wsi_reg1.add_data_from_config(config_fp)
    wsi_reg1.register_images()

    first_fp = wsi_reg1.transform_images()[0]
    result = reg_image_loader(first_fp, 1)

    assert result.im_dtype == np.uint16
    assert result.im_dims == (9, 3993, 3397)
Exemplo n.º 27
0
def test_reg_image_loader_dask_rgb(dask_im_rgb_np):
    """A dask RGB array is detected as 3-channel RGB with a 1-component reg image."""
    reg_image = reg_image_loader(dask_im_rgb_np, 0.65)
    reg_image.read_reg_image()

    dims = reg_image.shape
    assert len(dims) == 3
    assert dims[-1] == 3
    assert reg_image.is_rgb
    assert reg_image.n_ch == 3
    prepared = reg_image.reg_image
    assert prepared.GetSpacing() == (0.65, 0.65)
    assert prepared.GetNumberOfComponentsPerPixel() == 1
Exemplo n.º 28
0
def test_reg_image_loader_zarr_mch(zarr_im_mch_np):
    """A zarr multi-channel array loads channel-first, non-RGB, with 3 channels."""
    reg_image = reg_image_loader(zarr_im_mch_np, 0.65)
    reg_image.read_reg_image()

    dims = reg_image.shape
    assert len(dims) == 3
    assert dims[0] == 3
    assert reg_image.is_rgb is False
    assert reg_image.n_ch == 3
    prepared = reg_image.reg_image
    assert prepared.GetSpacing() == (0.65, 0.65)
    assert prepared.GetNumberOfComponentsPerPixel() == 1
Exemplo n.º 29
0
def test_wsireg_run_reg_wattachment_ds2(data_out_dir, disk_im_gry):
    """Attachment images are transformed together with their parent modality.

    mod2 is registered onto mod1 (both 2x-downsampled for registration). The
    attachment of mod2 must match the transformed mod2 output, and the
    attachment of mod1 must equal the original ``im1`` pixels.
    """
    wsi_reg = WsiReg2D(gen_project_name_str(), str(data_out_dir))
    shape = (2048, 2048)
    im1 = np.random.randint(0, 255, shape, dtype=np.uint16)
    im2 = np.random.randint(0, 255, shape, dtype=np.uint16)

    for mod_name, pixels in (("mod1", im1), ("mod2", im2)):
        wsi_reg.add_modality(
            mod_name,
            pixels,
            0.65,
            channel_names=["test"],
            channel_colors=["red"],
            preprocessing={"downsampling": 2},
        )
    wsi_reg.add_attachment_images("mod2", "attached", im2, image_res=0.65)
    wsi_reg.add_attachment_images("mod1", "attached2", im1, image_res=0.65)

    wsi_reg.add_reg_path("mod2", "mod1",
                         reg_params=["rigid_test", "affine_test"])

    wsi_reg.register_images()
    out_fps = wsi_reg.transform_images(transform_non_reg=False)

    wsi_reg.save_transformations()

    regim, attachim, attachim2 = [
        reg_image_loader(out_fps[i], 0.65) for i in range(3)
    ]

    def _pixels(img):
        return np.squeeze(img.dask_image.compute())

    assert np.array_equal(_pixels(regim), _pixels(attachim))
    assert np.array_equal(np.squeeze(im1), _pixels(attachim2))
Exemplo n.º 30
0
def test_wsireg_run_reg_with_flip_crop(data_out_dir, disk_im_gry):
    """Register onto a flipped, bbox-cropped target at full and cropped size.

    mod2 is flipped horizontally and given a 512x512 mask bounding box with
    cropping enabled. With ``to_original_size=True`` both outputs must be
    2048x2048; with ``to_original_size=False`` both must be 512x512.
    """
    # Generated project name, like the other WsiReg2D tests, instead of a
    # fixed name that could clash with other runs in the same output dir.
    wsi_reg = WsiReg2D(gen_project_name_str(), str(data_out_dir))
    img_fp1 = str(disk_im_gry)

    wsi_reg.add_modality(
        "mod1",
        img_fp1,
        0.65,
        channel_names=["test"],
        channel_colors=["red"],
    )

    # Use the ``preprocessing`` keyword, as every other add_modality call in
    # these tests does. Request the crop explicitly, as the crop+merge test
    # does, so the cropped outputs really are 512x512.
    wsi_reg.add_modality(
        "mod2",
        img_fp1,
        0.65,
        channel_names=["test"],
        channel_colors=["red"],
        preprocessing={
            "mask_bbox": [512, 512, 512, 512],
            "crop_to_mask_bbox": True,
            "flip": "h",
        },
    )

    wsi_reg.add_reg_path(
        "mod1", "mod2", reg_params=["rigid_test", "affine_test"]
    )
    wsi_reg.register_images()

    # Not cropped.
    im_fps = wsi_reg.transform_images(
        transform_non_reg=True, to_original_size=True
    )
    registered_image_nocrop = reg_image_loader(im_fps[0], 1)
    unregistered_image_nocrop = reg_image_loader(im_fps[1], 1)

    # Cropped.
    im_fps = wsi_reg.transform_images(
        transform_non_reg=True, to_original_size=False
    )
    registered_image_crop = reg_image_loader(im_fps[0], 1)
    unregistered_image_crop = reg_image_loader(im_fps[1], 1)

    assert registered_image_nocrop.im_dims[1:] == (2048, 2048)
    assert unregistered_image_nocrop.im_dims[1:] == (2048, 2048)
    assert registered_image_crop.im_dims[1:] == (512, 512)
    assert unregistered_image_crop.im_dims[1:] == (512, 512)