示例#1
0
def fuse_labels(t1: np.ndarray, id_: str, acq: Path, nib_obj) -> Tuple[np.ndarray, np.ndarray]:
    """Merge every lesion label file of one subject into two binary masks.

    All files in ``acq`` matching ``{id_}_LesionSmooth_*stx.nii.gz`` are loaded,
    validated with ``sanity_label`` against ``t1`` and OR-ed together.

    Args:
        t1: Reference volume; both output masks are allocated with its shape.
        id_: Subject identifier used in the label filename pattern.
        acq: Directory searched for the label files.
        nib_obj: Loaded image object for ``t1``; only its header zooms are used.

    Returns:
        ``(gt, gt1)``: uint8 masks where ``gt`` marks voxels whose label value is
        ``> 0`` and ``gt1`` voxels whose label value is ``> 1`` in any file.

    Raises:
        AssertionError: if no label file is found or a sanity check fails.
    """
    gt: np.ndarray = np.zeros_like(t1, dtype=np.uint8)
    gt1: np.ndarray = np.zeros_like(t1, dtype=np.uint8)

    labels: List[Path] = list(acq.glob(f"{id_}_LesionSmooth_*stx.nii.gz"))
    assert len(labels) >= 1, (acq, id_)

    for label_path in labels:
        label_obj = nib.load(str(label_path))
        label: np.ndarray = np.asarray(label_obj.dataobj)

        assert sanity_label(label, t1, label_obj.header.get_zooms(), nib_obj.header.get_zooms(), label_path)

        binary_label: np.ndarray = (label > 0).astype(np.uint8)
        binary_label1: np.ndarray = (label > 1).astype(np.uint8)
        assert binary_label.dtype == binary_label1.dtype == np.uint8, binary_label.dtype
        assert set(uniq(binary_label)) <= {0, 1}

        # Logical OR (not +=) so overlapping lesions from several files stay 0/1
        gt |= binary_label
        gt1 |= binary_label1

    assert set(uniq(gt)) <= {0, 1}
    assert gt.dtype == gt1.dtype == np.uint8

    return gt, gt1
    def stat(self, halos, sparticles, sgrp):
        """Build a structured per-halo summary for every group present in ``sgrp``.

        Args:
            halos: Mapping of halo catalogue columns indexed by ``group - 1``.
                Reads ``'npart5'`` (through its ``.npart`` attribute), ``'mvir'``,
                ``'rvir'``, ``'vmax'``, ``'rmax'``, ``'sigv'``, ``'x'``, ``'y'``,
                ``'z'``, ``'vx'``, ``'vy'`` and ``'vz'``.
            sparticles: Per-particle sort key; ``sgrp`` is reordered by it.
            sgrp: Per-particle group number (1-based row into ``halos``).

        Returns:
            Structured array with one row per distinct group. ``mvir``, ``rvir``,
            ``rmax`` and the positions are divided by ``self.h``; columns this
            method does not fill are zero.
        """
        # Distinct group numbers; np.unique returns them sorted.
        unique = np.unique(sgrp[np.argsort(sparticles)])
        ngrps = len(unique)
        # Group numbers are 1-based, catalogue rows 0-based.
        # NOTE(review): a group number 0 would map to row -1 (the last halo) — confirm 0 never occurs.
        amigaind = unique - 1

        names = ('grp', 'sat', 'contam', 'false', 'npart', 'ngas', 'nstar',
                 'ndar', 'mvir', 'rvir', 'gasmass', 'starmass', 'darkmass',
                 'vmax', 'rmax', 'sigv', 'xc', 'yc', 'zc', 'vx', 'vy', 'vz',
                 'aimgaorigid')
        # Exactly one format per name: grp, three byte-string flags ('S' is the
        # non-deprecated spelling of numpy's old 'a' code), four integer
        # counts, fourteen float quantities, then one integer column.
        formats = ('l',) + ('S',) * 3 + ('l',) * 4 + ('d',) * 14 + ('l',)
        dtype = np.dtype({'names': names, 'formats': formats})

        # zeros rather than empty: columns not filled below hold a defined value.
        stat = np.zeros(ngrps, dtype=dtype)
        stat['grp'] = unique
        # NOTE(review): 'npart5' read through `.npart` looks unusual — confirm against the halo reader.
        stat['npart'] = halos['npart5'].npart[amigaind]
        stat['mvir'] = halos['mvir'][amigaind] / self.h
        stat['rvir'] = halos['rvir'][amigaind] / self.h
        stat['vmax'] = halos['vmax'][amigaind]
        stat['rmax'] = halos['rmax'][amigaind] / self.h
        stat['sigv'] = halos['sigv'][amigaind]
        stat['xc'] = halos['x'][amigaind] / self.h
        stat['yc'] = halos['y'][amigaind] / self.h
        stat['zc'] = halos['z'][amigaind] / self.h
        stat['vx'] = halos['vx'][amigaind]
        stat['vy'] = halos['vy'][amigaind]
        stat['vz'] = halos['vz'][amigaind]

        if self.tipsy.ngas + self.tipsy.nstar > 0:
            # TODO: gmass is read but not yet used to fill gasmass/starmass.
            gmass = self.tipsy.gas['mass'][0]

        return stat
    def stat(self, halos, sparticles, sgrp):
        """Build a structured per-halo summary for every group present in ``sgrp``.

        Args:
            halos: Mapping of halo catalogue columns indexed by ``group - 1``.
                Reads ``'npart5'`` (through its ``.npart`` attribute), ``'mvir'``,
                ``'rvir'``, ``'vmax'``, ``'rmax'``, ``'sigv'``, ``'x'``, ``'y'``,
                ``'z'``, ``'vx'``, ``'vy'`` and ``'vz'``.
            sparticles: Per-particle sort key; ``sgrp`` is reordered by it.
            sgrp: Per-particle group number (1-based row into ``halos``).

        Returns:
            Structured array with one row per distinct group. ``mvir``, ``rvir``,
            ``rmax`` and the positions are divided by ``self.h``; columns this
            method does not fill are zero.
        """
        # Distinct group numbers; np.unique returns them sorted.
        unique = np.unique(sgrp[np.argsort(sparticles)])
        ngrps = len(unique)
        # Group numbers are 1-based, catalogue rows 0-based.
        # NOTE(review): a group number 0 would map to row -1 (the last halo) — confirm 0 never occurs.
        amigaind = unique - 1

        names = ('grp', 'sat', 'contam', 'false', 'npart', 'ngas', 'nstar',
                 'ndar', 'mvir', 'rvir', 'gasmass', 'starmass', 'darkmass',
                 'vmax', 'rmax', 'sigv', 'xc', 'yc', 'zc', 'vx', 'vy', 'vz',
                 'aimgaorigid')
        # Exactly one format per name: grp, three byte-string flags ('S' is the
        # non-deprecated spelling of numpy's old 'a' code), four integer
        # counts, fourteen float quantities, then one integer column.
        formats = ('l',) + ('S',) * 3 + ('l',) * 4 + ('d',) * 14 + ('l',)
        dtype = np.dtype({'names': names, 'formats': formats})

        # zeros rather than empty: columns not filled below hold a defined value.
        stat = np.zeros(ngrps, dtype=dtype)
        stat['grp'] = unique
        # NOTE(review): 'npart5' read through `.npart` looks unusual — confirm against the halo reader.
        stat['npart'] = halos['npart5'].npart[amigaind]
        stat['mvir'] = halos['mvir'][amigaind] / self.h
        stat['rvir'] = halos['rvir'][amigaind] / self.h
        stat['vmax'] = halos['vmax'][amigaind]
        stat['rmax'] = halos['rmax'][amigaind] / self.h
        stat['sigv'] = halos['sigv'][amigaind]
        stat['xc'] = halos['x'][amigaind] / self.h
        stat['yc'] = halos['y'][amigaind] / self.h
        stat['zc'] = halos['z'][amigaind] / self.h
        stat['vx'] = halos['vx'][amigaind]
        stat['vy'] = halos['vy'][amigaind]
        stat['vz'] = halos['vz'][amigaind]

        if self.tipsy.ngas + self.tipsy.nstar > 0:
            # TODO: gmass is read but not yet used to fill gasmass/starmass.
            gmass = self.tipsy.gas['mass'][0]

        return stat
示例#4
0
def display_bins(x, y, binNum, velBin):
    """Expand per-bin values to per-pixel values and hand them to ``display_pixels``.

    Args:
        x, y: Pixel coordinates, same size as ``binNum``.
        binNum: Bin index of every pixel; used to index ``velBin``.
        velBin: One value per distinct bin number.

    Returns:
        Whatever ``display_pixels(x, y, velBin[binNum])`` returns.

    Raises:
        ValueError: if ``x``, ``y`` and ``binNum`` differ in size, or if
            ``velBin`` does not hold exactly one value per distinct bin.
    """
    if not (x.size == y.size == binNum.size):
        raise ValueError("The vectors (x, y, binNum) must have the same size")

    # np.unique gives the distinct bin numbers (numpy has no `uniq`).
    if np.unique(binNum).size != velBin.size:
        raise ValueError("velBin size does not match number of bins")

    img = display_pixels(x, y, velBin[binNum])

    return img
示例#5
0
def sanity_label(label, t1, resolution, t1_resolution, label_path) -> bool:
    """Validate one lesion label volume against its reference volume ``t1``.

    Args:
        label: Label volume; must have the shape of ``t1`` and dtype float64.
        t1: Reference volume the label belongs to.
        resolution: Voxel spacing reported for ``label``.
        t1_resolution: Voxel spacing reported for ``t1``; must equal ``resolution``.
        label_path: Only used in the assertion message to identify the file.

    Returns:
        True, so the call can itself be wrapped in ``assert``.

    Raises:
        AssertionError: if shape, spacing or dtype differ, or if the set of
            values in ``label`` is not exactly one of ``labels_allowed``.
    """
    assert label.shape == t1.shape
    assert resolution == t1_resolution

    assert label.dtype in [np.float64], label.dtype

    # Exact value sets accepted for a label file; > 0 means disease.
    # They are compared with ==, so the float values must stay verbatim.
    labels_allowed = [[0.0, 0.9999999997671694],
                      [0., 254.9999999406282],
                      [0., 0.9999999997671694, 253.99999994086102, 254.9999999406282],
                      [0.0, 0.9999999997671694, 1.9999999995343387, 252.99999994109385, 253.99999994086102, 254.9999999406282]]

    # Scan the volume for its distinct values once, not once per allowed set.
    found = set(uniq(label))
    matches: List[bool] = [found == set(allowed) for allowed in labels_allowed]
    assert any(matches), (found, label_path)

    return True
def save_slices(img_p: Path, gt_p: Path,
                dest_dir: Path, shape: Tuple[int, int], n_augment: int,
                img_dir: str = "img", gt_dir: str = "gt") -> Tuple[float, float, float]:
    """Cut a 3D image/label pair into resized 2D PNG slices plus augmented copies.

    Slices are taken along the first axis. For every slice ``j``,
    ``n_augment + 1`` image/label pairs are written (``k == 0`` is the resized
    slice itself, the others come from ``augment``) as ``{p_id}_{k}_{j:02d}.png``
    under ``dest_dir / img_dir`` and ``dest_dir / gt_dir``.

    Returns:
        ``(total, smallest_nonzero, largest)`` over the per-slice counts of
        pixels equal to 1 in the resized label.

    Raises:
        AssertionError: on unexpected ids, dtypes, shapes or label values.
        ValueError: if no resized slice contains a pixel equal to 1
            (``min`` of an empty array).
    """
    p_id: str = get_p_id(img_p)
    assert "Case" in p_id
    assert p_id == get_p_id(gt_p)

    # Load the data
    img = imread(str(img_p), plugin='simpleitk')
    gt = imread(str(gt_p), plugin='simpleitk')

    assert img.shape == gt.shape
    assert img.dtype in [np.int16]
    assert gt.dtype in [np.int8]

    # Normalize the whole 3D volume, not 2D slices, so all slices share one scale
    norm_img = norm_arr(img)
    assert 0 == norm_img.min() and norm_img.max() == 255, (norm_img.min(), norm_img.max())
    assert norm_img.dtype == np.uint8

    # Loop-invariant resize configuration
    resize_: Callable = partial(resize, mode="constant", preserve_range=True, anti_aliasing=False)

    save_dir_img: Path = Path(dest_dir, img_dir)
    save_dir_gt: Path = Path(dest_dir, gt_dir)
    # One entry per slice along axis 0 — the axis the loop below iterates over
    sizes_2d: np.ndarray = np.zeros(len(img))
    for j in range(len(img)):
        img_s = norm_img[j, :, :]
        gt_s = gt[j, :, :]
        assert img_s.shape == gt_s.shape

        # Resize and check the data are still what we expect.
        # NOTE(review): the label uses resize's default interpolation order and is
        # then truncated by astype; an order=0 resize may be what is intended.
        r_img: np.ndarray = resize_(img_s, shape).astype(np.uint8)
        r_gt: np.ndarray = resize_(gt_s, shape).astype(np.uint8)
        assert r_img.dtype == r_gt.dtype == np.uint8
        assert 0 <= r_img.min() and r_img.max() <= 255  # The range might be smaller
        assert set(uniq(r_gt)).issubset(set(uniq(gt)))
        # Number of pixels equal to 1 in this resized slice
        sizes_2d[j] = r_gt[r_gt == 1].sum()

        for k in range(n_augment + 1):
            if k == 0:
                a_img, a_gt = r_img, r_gt
            else:
                a_img, a_gt = map_(np.asarray, augment(r_img, r_gt))

            for save_dir, data in zip([save_dir_img, save_dir_gt], [a_img, a_gt]):
                filename = f"{p_id}_{k}_{j:02d}.png"
                save_dir.mkdir(parents=True, exist_ok=True)

                with warnings.catch_warnings():
                    warnings.filterwarnings("ignore", category=UserWarning)
                    imsave(str(Path(save_dir, filename)), data)

    return sizes_2d.sum(), sizes_2d[sizes_2d > 0].min(), sizes_2d.max()
示例#7
0
def save_slices(img_p: Path,
                gt_p: Path,
                dest_dir: Path,
                shape: Tuple[int, int],
                n_augment: int,
                img_dir: str = "img",
                gt_dir: str = "gt") -> Tuple[Any, Any, Any, Any]:
    """Cut a 3D image/label pair into resized, augmented 2D PNG slices and
    measure the size of label 3.

    Slices are taken along the last axis. Each slice is resized to ``shape``
    (the label with ``order=0``), and ``n_augment + 1`` image/label pairs are
    written per slice (``k == 0`` is the un-augmented pair) as
    ``{p_id}_{f_id}_{k}_{j}.png`` under ``dest_dir / img_dir`` and
    ``dest_dir / gt_dir``.

    Returns:
        ``(slices_sizes_px, slices_sizes_mm2, volume_size_px, volume_size_mm3)``:
        per-slice pixel counts of label 3 (only slices where it is present),
        those counts times the resized in-plane pixel surface, their sum, and
        that sum times the resized voxel volume.

    Raises:
        AssertionError: on unexpected ids, slice spacing, dtypes, shapes or labels.
    """
    p_id: str = get_p_id(img_p)
    assert "patient" in p_id
    assert p_id == get_p_id(gt_p)

    f_id: str = get_frame(img_p.name)
    assert f_id == get_frame(gt_p.name)

    # Load the data; the image file is opened once for both header and voxels
    img_nib = nib.load(str(img_p))
    dx, dy, dz = img_nib.header.get_zooms()
    assert dz in [5, 6.5, 7, 10], dz
    img = np.asarray(img_nib.dataobj)
    gt = np.asarray(nib.load(str(gt_p)).dataobj)

    # Rescale the in-plane spacing to account for the resize to `shape`
    nx, ny = shape
    fx = nx / img.shape[0]
    fy = ny / img.shape[1]
    dx /= fx
    dy /= fy

    pixel_surface: float = dx * dy
    voxel_volume: float = dx * dy * dz

    assert img.shape == gt.shape
    assert img.dtype in [np.uint8, np.int16, np.float32]

    # Normalize the whole 3D volume, not 2D slices, so all slices share one scale
    norm_img = norm_arr(img)
    assert 0 == norm_img.min() and norm_img.max() == 255, (norm_img.min(),
                                                           norm_img.max())
    assert gt.dtype == norm_img.dtype == np.uint8

    resize_: Callable = partial(resize,
                                mode="constant",
                                preserve_range=True,
                                anti_aliasing=False)

    save_dir_img: Path = Path(dest_dir, img_dir)
    save_dir_gt: Path = Path(dest_dir, gt_dir)
    sizes_2d: np.ndarray = np.zeros(img.shape[-1])
    for j in range(img.shape[-1]):
        img_s = norm_img[:, :, j]
        gt_s = gt[:, :, j]
        assert img_s.shape == gt_s.shape
        assert gt_s.dtype == np.uint8

        # Resize; order=0 (nearest neighbour) keeps the label values discrete
        r_img: np.ndarray = resize_(img_s, shape).astype(np.uint8)
        r_gt: np.ndarray = resize_(gt_s, shape, order=0)
        assert set(uniq(r_gt)).issubset(set(uniq(gt))), (r_gt.dtype,
                                                         uniq(r_gt))
        r_gt = r_gt.astype(np.uint8)
        assert r_img.dtype == r_gt.dtype == np.uint8
        assert 0 <= r_img.min() and r_img.max() <= 255  # The range might be smaller
        sizes_2d[j] = (r_gt == 3).astype(np.int64).sum()

        for k in range(n_augment + 1):
            if k == 0:
                a_img, a_gt = r_img, r_gt
            else:
                a_img, a_gt = map_(np.asarray, augment(r_img, r_gt))

            for save_dir, data in zip([save_dir_img, save_dir_gt],
                                      [a_img, a_gt]):
                filename = f"{p_id}_{f_id}_{k}_{j}.png"
                save_dir.mkdir(parents=True, exist_ok=True)

                with warnings.catch_warnings():
                    warnings.filterwarnings("ignore", category=UserWarning)
                    imsave(str(Path(save_dir, filename)), data)

    # Cross-check the per-slice counts against one 3D order=0 resize of the label-3 mask
    lv_gt = (gt == 3).astype(np.uint8)
    assert set(np.unique(lv_gt)) <= {0, 1}
    assert lv_gt.shape == gt.shape

    lv_gt = resize_(lv_gt, (*shape, img.shape[-1]), order=0)
    assert set(np.unique(lv_gt)) <= {0, 1}

    slices_sizes_px = np.einsum("xyz->z", lv_gt.astype(np.int64))
    assert np.array_equal(slices_sizes_px,
                          sizes_2d), (slices_sizes_px, sizes_2d)
    # Keep only the slices where label 3 is present
    slices_sizes_px = slices_sizes_px[slices_sizes_px > 0]
    slices_sizes_mm2 = slices_sizes_px * pixel_surface

    volume_size_px = slices_sizes_px.sum()
    volume_size_mm3 = volume_size_px * voxel_volume

    return slices_sizes_px, slices_sizes_mm2, volume_size_px, volume_size_mm3
示例#8
0
def save_slices(img_p: Path, gt_p: Path,
                dest_dir: Path, shape: Tuple[int, int],
                img_dir: str = "img", gt_dir: str = "gt") -> Tuple[float, float, float]:
    """Cut a 3D image/label pair into resized 2D PNG slices plus weak labels.

    Slices are taken along the first axis. For every slice ``j`` the image,
    the label and a ``random_strat`` weak label are written as
    ``{p_id}_{j:02d}.png`` under ``dest_dir / img_dir``, ``dest_dir / gt_dir``
    and ``dest_dir / "weak"``.

    Returns:
        ``(total, smallest_nonzero, largest)`` over the per-slice counts of
        pixels equal to 1 in the resized label.

    Raises:
        AssertionError: on unexpected ids, spacing, geometry, dtypes or labels.
        ValueError: if no resized slice contains a pixel equal to 1.
    """
    p_id: str = get_p_id(img_p)
    assert "Case" in p_id
    assert p_id == get_p_id(gt_p)

    # Load the data
    img = imread(str(img_p), plugin='simpleitk')
    gt = imread(str(gt_p), plugin='simpleitk')

    assert img.shape == gt.shape
    assert img.dtype in [np.int16]
    assert gt.dtype in [np.int8]

    img_nib = sitk.ReadImage(str(img_p))
    dx, dy, dz = img_nib.GetSpacing()
    assert np.abs(dx - dy) <= 0.0000041, (dx, dy, dx - dy)
    assert 0.27 <= dx <= 0.75, dx
    assert 2.19994 <= dz <= 4.00001, dz

    x, y, z = img.shape
    assert (y, z) in [(320, 320), (512, 512), (256, 256), (384, 384)], (y, z)
    assert 15 <= x <= 54, x

    # Normalize the whole 3D volume, not 2D slices, so all slices share one scale
    norm_img = norm_arr(img)
    assert 0 == norm_img.min() and norm_img.max() == 255, (norm_img.min(), norm_img.max())
    assert norm_img.dtype == np.uint8

    # Loop-invariant resize configuration
    resize_: Callable = partial(resize, mode="constant", preserve_range=True, anti_aliasing=False)

    save_dir_img: Path = dest_dir / img_dir
    save_dir_gt: Path = dest_dir / gt_dir
    save_dir_weak: Path = dest_dir / "weak"
    # One entry per slice along axis 0 — the axis the loop below iterates over
    sizes_2d: np.ndarray = np.zeros(len(img))
    for j in range(len(img)):
        img_s = norm_img[j, :, :]
        gt_s = gt[j, :, :]
        assert img_s.shape == gt_s.shape

        # Resize and check the data are still what we expect.
        # NOTE(review): the label uses resize's default interpolation order and is
        # then truncated by astype; an order=0 resize may be what is intended.
        r_img: np.ndarray = resize_(img_s, shape).astype(np.uint8)
        r_gt: np.ndarray = resize_(gt_s, shape).astype(np.uint8)
        assert r_img.dtype == r_gt.dtype == np.uint8
        assert 0 <= r_img.min() and r_img.max() <= 255  # The range might be smaller
        assert set(uniq(r_gt)).issubset(set(uniq(gt)))
        # Number of pixels equal to 1 in this resized slice
        sizes_2d[j] = r_gt[r_gt == 1].sum()

        # The weak label is drawn from the unscaled label
        r_weak: np.ndarray = random_strat(r_gt, 1)

        # Multiply by 255 before saving (a 0/1 mask becomes 0/255)
        r_gt *= 255
        r_weak *= 255

        for save_dir, data in zip([save_dir_img, save_dir_gt, save_dir_weak],
                                  [r_img, r_gt, r_weak]):
            filename = f"{p_id}_{j:02d}.png"
            save_dir.mkdir(parents=True, exist_ok=True)

            with warnings.catch_warnings():
                warnings.filterwarnings("ignore", category=UserWarning)
                imsave(str(Path(save_dir, filename)), data)

    return sizes_2d.sum(), sizes_2d[sizes_2d > 0].min(), sizes_2d.max()
def save_slices(img_p: Path,
                gt_p: Path,
                dest_dir: Path,
                shape: Tuple[int, int],
                img_dir: str = "img",
                gt_dir: str = "gt") -> None:
    """Cut a 3D image/label pair into resized 2D PNG slices plus weak labels.

    Slices are taken along the last axis. For every slice ``j`` the image,
    the label (resized with ``order=0``) and a ``random_strat`` weak label are
    written as ``{p_id}_{f_id}_{j:02d}.png`` under ``dest_dir / img_dir``,
    ``dest_dir / gt_dir`` and ``dest_dir / "weak"``. Label classes 0..3 are
    stored as 0, 85, 170, 255.

    Raises:
        AssertionError: on unexpected ids, slice spacing, dtypes, shapes or labels.
    """
    p_id: str = get_p_id(img_p)
    assert "patient" in p_id
    assert p_id == get_p_id(gt_p)

    f_id: str = get_frame(img_p.name)
    assert f_id == get_frame(gt_p.name)

    # Load the data; the image file is opened once for both header and voxels.
    # Only the third zoom (slice spacing) is checked here.
    img_nib = nib.load(str(img_p))
    _, _, dz = img_nib.header.get_zooms()
    assert dz in [5, 6.5, 7, 10], dz
    img = np.asarray(img_nib.dataobj)
    gt = np.asarray(nib.load(str(gt_p)).dataobj)

    assert img.shape == gt.shape
    assert img.dtype in [np.uint8, np.int16, np.float32]

    # Normalize the whole 3D volume, not 2D slices, so all slices share one scale
    norm_img = norm_arr(img)
    assert 0 == norm_img.min() and norm_img.max() == 255, (norm_img.min(),
                                                           norm_img.max())
    assert gt.dtype == norm_img.dtype == np.uint8

    resize_: Callable = partial(resize,
                                mode="constant",
                                preserve_range=True,
                                anti_aliasing=False)

    save_dir_img: Path = Path(dest_dir, img_dir)
    save_dir_gt: Path = Path(dest_dir, gt_dir)
    save_dir_weak: Path = Path(dest_dir, "weak")
    for j in range(img.shape[-1]):
        img_s = norm_img[:, :, j]
        gt_s = gt[:, :, j]
        assert img_s.shape == gt_s.shape
        assert gt_s.dtype == np.uint8

        # Resize; order=0 (nearest neighbour) keeps the label values discrete
        r_img: np.ndarray = resize_(img_s, shape).astype(np.uint8)
        r_gt: np.ndarray = resize_(gt_s, shape, order=0)
        assert set(uniq(r_gt)).issubset(set(uniq(gt))), (r_gt.dtype,
                                                         uniq(r_gt))
        r_gt = r_gt.astype(np.uint8)
        assert r_img.dtype == r_gt.dtype == np.uint8
        assert 0 <= r_img.min() and r_img.max() <= 255  # The range might be smaller

        # Weak label for classes 1, 2, 3 only — not for the background (0)
        r_weak: np.ndarray = random_strat(r_gt, [1, 2, 3])

        assert set(np.unique(r_gt)) <= {0, 1, 2, 3}
        assert set(np.unique(r_weak)) <= {0, 1, 2, 3}
        # Map class indices 0..3 to grey levels 0, 85, 170, 255
        r_gt *= 255 // 3
        r_weak *= 255 // 3
        assert set(np.unique(r_gt)) <= {0, 85, 170, 255}
        assert set(np.unique(r_weak)) <= {0, 85, 170, 255}

        for save_dir, data in zip([save_dir_img, save_dir_gt, save_dir_weak],
                                  [r_img, r_gt, r_weak]):
            filename = f"{p_id}_{f_id}_{j:02d}.png"
            save_dir.mkdir(parents=True, exist_ok=True)
            save_path = save_dir / filename

            with warnings.catch_warnings():
                warnings.filterwarnings("ignore", category=UserWarning)
                imsave(str(save_path), data)
示例#10
0
def process_patient(img_p: Path,
                    gt_p: Path,
                    dest_dir: Path,
                    shape: Tuple[int, int],
                    cr: int,
                    img_dir: str = "img",
                    gt_dir: str = "gt") -> np.ndarray:
    """Crop, pad and resize one image/label volume pair, save its slices and
    return the per-slice label pixel counts.

    Args:
        img_p, gt_p: Paths of the image and label volumes (same patient id).
        dest_dir: Root output directory.
        shape: In-plane output size; the last axis (z) keeps its length.
        cr: Number of rows cropped from each end of axis 1; ``cr == 0``
            keeps the whole axis.
        img_dir, gt_dir: Sub-directories of ``dest_dir`` for images and labels.

    Returns:
        int64 array of length z holding the sum of the 0/1 label in each
        resized slice.

    Raises:
        AssertionError: if geometry, spacing, dtypes or label values are not
            as expected, or if cropping would discard labelled voxels.
    """
    p_id: str = get_p_id(img_p)
    assert p_id == get_p_id(gt_p)

    # Load the data
    img_nib = nib.load(str(img_p))
    x, y, z = img_nib.dataobj.shape
    dx, dy, dz = img_nib.header.get_zooms()

    # Make sure data is consistent with the expected dataset description
    assert (x, y, z) == (39, 305, 305), (x, y, z)
    assert 1.9 <= dx <= 2, dx
    assert dy == dz, (dy, dz)
    assert 1 <= dy <= 1.25, dy

    img = np.asarray(img_nib.dataobj)
    gt = np.asarray(nib.load(str(gt_p)).dataobj)

    assert img.shape == gt.shape
    assert img.dtype in [np.int16], img.dtype
    assert gt.dtype in [np.uint8], gt.dtype

    # Normalize the whole 3D volume, not 2D slices, so all slices share one scale
    norm_img = norm_arr(img)
    assert 0 == norm_img.min() and norm_img.max() == 255, (norm_img.min(),
                                                           norm_img.max())
    assert norm_img.dtype == np.uint8

    norm_gt = gt.astype(np.uint8)
    assert set(uniq(gt)) == set(uniq(norm_gt)) == {0, 1}
    del img  # Keep gt for sanity checks

    # Crop `cr` rows from both ends of axis 1. The explicit stop index keeps
    # cr == 0 working: `cr:-cr` would then be `0:0`, an empty slice.
    crop_img = norm_img[:, cr:y - cr, :]
    crop_gt = norm_gt[:, cr:y - cr, :]
    # Make sure we did not discard any part of the object
    assert norm_gt.sum() == crop_gt.sum()
    del norm_img, norm_gt

    # Pad axis 0 (centred) to get square slices
    _, ny, _ = crop_img.shape
    offset_x: int = (ny - x) // 2

    pad_img = np.zeros((ny, ny, z), dtype=np.uint8)
    pad_img[offset_x:offset_x + x, ...] = crop_img

    pad_gt = np.zeros((ny, ny, z), dtype=np.uint8)
    pad_gt[offset_x:offset_x + x, ...] = crop_gt
    del crop_img, crop_gt

    resize_: Callable = partial(resize,
                                output_shape=(*shape, z),
                                mode="constant",
                                preserve_range=True,
                                anti_aliasing=False)

    resized_img = resize_(pad_img).astype(np.uint8)
    # order=0 (nearest neighbour) keeps the label values 0/1
    resized_gt = resize_(pad_gt, order=0)
    assert set(uniq(resized_gt)).issubset(set(uniq(gt))), (resized_gt.dtype,
                                                           uniq(resized_gt))
    resized_gt = resized_gt.astype(np.uint8)
    del pad_img, pad_gt

    save_dir_img: Path = Path(dest_dir, img_dir)
    save_dir_gt: Path = Path(dest_dir, gt_dir)
    save_slices([resized_img, resized_gt], [save_dir_img, save_dir_gt], p_id)

    sizes = np.einsum("xyz->z", resized_gt, dtype=np.int64)

    return sizes
示例#11
0
def save_slices(img_p: Path,
                gt_p: Path,
                dest_dir: Path,
                shape: Tuple[int, int],
                n_augment: int,
                img_dir: str = "img",
                gt_dir: str = "gt") -> Tuple[int, int, int, str, List[int]]:
    """Cut a 3D image/label pair into resized, augmented 2D PNG slices.

    Slices are taken along the last axis. For every slice ``j``,
    ``n_augment + 1`` image/label pairs are written (``k == 0`` is the resized
    slice itself, the others come from ``augment``) as
    ``{p_id}_{f_id}_{k}_{j}.png`` under ``dest_dir / img_dir`` and
    ``dest_dir / gt_dir``.

    Returns:
        ``(total, smallest_nonzero, largest, f_id, sizes_2d)`` where
        ``sizes_2d`` holds, per slice, the number of resized label pixels equal to 3.

    Raises:
        AssertionError: on unexpected ids, dtypes, shapes or label values.
        ValueError: if no resized slice contains label 3.
    """
    p_id: str = get_p_id(img_p)
    assert "patient" in p_id
    assert p_id == get_p_id(gt_p)

    f_id: str = get_frame(img_p.name)
    assert f_id == get_frame(gt_p.name)

    # Load the data
    img = np.asarray(nib.load(str(img_p)).dataobj)
    gt = np.asarray(nib.load(str(gt_p)).dataobj)

    assert img.shape == gt.shape
    assert img.dtype in [np.uint8, np.int16, np.float32]

    # Normalize the whole 3D volume, not 2D slices, so all slices share one scale
    norm_img = norm_arr(img)
    assert 0 == norm_img.min() and norm_img.max() == 255, (norm_img.min(),
                                                           norm_img.max())
    assert gt.dtype == norm_img.dtype == np.uint8

    # Loop-invariant resize configuration
    resize_: Callable = partial(resize,
                                mode="constant",
                                preserve_range=True,
                                anti_aliasing=False)

    save_dir_img: Path = Path(dest_dir, img_dir)
    save_dir_gt: Path = Path(dest_dir, gt_dir)
    sizes_2d: np.ndarray = np.zeros(img.shape[-1])
    for j in range(img.shape[-1]):
        img_s = norm_img[:, :, j]
        gt_s = gt[:, :, j]
        assert img_s.shape == gt_s.shape

        # The label map is resized with order=0 (nearest neighbour): an
        # interpolating order would blend e.g. classes 0 and 3 into 1 or 2.
        r_img: np.ndarray = resize_(img_s, shape).astype(np.uint8)
        r_gt: np.ndarray = resize_(gt_s, shape, order=0)
        assert set(uniq(r_gt)).issubset(set(uniq(gt))), (r_gt.dtype,
                                                         uniq(r_gt))
        r_gt = r_gt.astype(np.uint8)
        assert r_img.dtype == r_gt.dtype == np.uint8
        assert 0 <= r_img.min() and r_img.max() <= 255  # The range might be smaller
        # Per-slice pixel count of label 3
        sizes_2d[j] = (r_gt == 3).sum()

        for k in range(n_augment + 1):
            if k == 0:
                a_img, a_gt = r_img, r_gt
            else:
                # Per the original author, augment only rotates, flips and mirrors
                a_img, a_gt = map_(np.asarray, augment(r_img, r_gt))

            for save_dir, data in zip([save_dir_img, save_dir_gt],
                                      [a_img, a_gt]):
                filename = f"{p_id}_{f_id}_{k}_{j}.png"
                save_dir.mkdir(parents=True, exist_ok=True)

                with warnings.catch_warnings():
                    warnings.filterwarnings("ignore", category=UserWarning)
                    imsave(str(Path(save_dir, filename)), data)

    # 3D size, smallest positive slice size, largest slice size, f_id, per-slice sizes
    return sizes_2d.sum(), sizes_2d[
        sizes_2d > 0].min(), sizes_2d.max(), f_id, sizes_2d