def test_stardistdata():
    """Build a small ``StarDistData2D`` generator and fetch its first item.

    Returns
    -------
    tuple
        ``((img,), (prob, dist), generator)`` where the first two entries are
        the unpacked contents of ``generator[0]``.
    """
    from stardist.models import StarDistData2D

    image, labels = real_image2d()
    generator = StarDistData2D(
        [image] * 2,
        [labels] * 2,
        batch_size=1,
        patch_size=(30, 40),
        n_rays=32,
        length=1,
    )
    # Unpacking doubles as a structural check: one input, two targets.
    (batch_img,), (prob, dist) = generator[0]
    return (batch_img,), (prob, dist), generator
def render_label_pred_example():
    """Show image, ground truth, prediction and their comparison side by side.

    Loads the model found at ``path_model2d()``, predicts instances on the
    normalized example image and displays a 1x4 matplotlib figure.

    Returns
    -------
    The image produced by ``render_label_pred`` (shown in the last panel).
    """
    model_dir = path_model2d()
    model = StarDist2D(None, name=model_dir.name, basedir=str(model_dir.parent))

    raw, y_gt = real_image2d()
    x = normalize(raw, 1, 99.8)
    y_pred, _details = model.predict_instances(x)
    overlay = render_label_pred(y_gt, y_pred, img=x)

    import matplotlib.pyplot as plt

    def panels():
        # (title, image) pairs in display order, rendered on demand
        yield "img", x
        yield "gt", render_label(y_gt, img=x)
        yield "pred", render_label(y_pred, img=x)
        yield "tp (green) fp (red) fn(blue)", overlay

    plt.figure(1, figsize=(12, 4))
    for position, (title, image) in enumerate(panels(), start=1):
        plt.subplot(1, 4, position)
        plt.imshow(image)
        plt.title(title)
    plt.tight_layout()
    plt.show()
    return overlay
def test_speed(model2d):
    """Time ``predict_instances`` against ``predict_instances_big``.

    The normalized example image is tiled 6x6 and every combination of
    prediction mode ("normal" / "big"), ``n_tiles`` (``None`` or ``(2, 2)``)
    and ``sparse`` (``True`` / ``False``) is run once.  Each timing is printed
    when it completes and all of them are printed again as a summary at the
    end.  The test makes no assertions.

    Parameters
    ----------
    model2d
        Model fixture providing ``predict_instances`` and
        ``predict_instances_big``.
    """
    # perf_counter is monotonic and high resolution, unlike wall-clock time()
    from time import perf_counter

    img, _ = real_image2d()
    x = np.tile(normalize(img, 1, 99.8), (6, 6))
    print(x.shape)

    stats = []
    for mode, n_tiles, sparse in product(("normal", "big"), (None, (2, 2)), (True, False)):
        start = perf_counter()
        if mode == "normal":
            model2d.predict_instances(x, n_tiles=n_tiles, sparse=sparse)
        else:
            model2d.predict_instances_big(
                x,
                axes="YX",
                block_size=1024 + 256,
                context=64,
                min_overlap=64,
                n_tiles=n_tiles,
                sparse=sparse,
            )
        elapsed = perf_counter() - start
        line = f"mode={mode}\ttiles={n_tiles}\tsparse={sparse}\t{elapsed:.2f}s"
        print(line)
        stats.append(line)

    # summary of all runs
    for line in stats:
        print(line)
def test_predict2D(model2d, use_channel):
    """Check that ``predict_instances_big`` reproduces ``predict_instances``.

    Parameters
    ----------
    model2d
        Model fixture.
    use_channel : bool
        If true, a trailing channel axis is appended to the image and 'C' is
        appended to the axes string.

    Returns
    -------
    tuple
        ``(ref_polys, res_polys)``: the polygon dicts of the regular and the
        block-wise prediction.
    """
    image = repeat(normalize(real_image2d()[0], 1, 99.8), 2)
    axes = 'YX'
    if use_channel:
        image = image[..., np.newaxis]
        axes = axes + 'C'

    ref_labels, ref_polys = model2d.predict_instances(image, axes=axes)
    res_labels, res_polys = model2d.predict_instances_big(
        image, axes=axes, block_size=288, min_overlap=32, context=96)

    # label images must match perfectly
    label_match = matching(ref_labels, res_labels)
    assert (1.0, 1.0) == (label_match.accuracy, label_match.mean_true_score)

    # ... and so must the label images rendered from the polygons
    rendered_ref = render_polygons(ref_polys, shape=image.shape)
    rendered_res = render_polygons(res_polys, shape=image.shape)
    poly_match = matching(rendered_ref, rendered_res)
    assert (1.0, 1.0) == (poly_match.accuracy, poly_match.mean_true_score)

    # bring both results into the same (lexicographic) order before comparing
    ref_order = np.lexsort(ref_polys["points"].T)
    res_order = np.lexsort(res_polys["points"].T)
    for key in ("coord", "points", "prob"):
        assert np.allclose(ref_polys[key][ref_order], res_polys[key][res_order], atol=1e-2)

    return ref_polys, res_polys
def test_cover2D(block_size, context, grid):
    """Run ``reassemble`` on a repeated label image.

    The minimal overlap is chosen one larger than the maximal object extents
    of the example label image; repeating the image must leave these extents
    unchanged.

    Parameters
    ----------
    block_size, context, grid
        Forwarded unchanged to ``reassemble``.
    """
    labels = real_image2d()[1].astype(np.int32)
    max_sizes = tuple(calculate_extents(labels, func=np.max))
    min_overlap = tuple(size + 1 for size in max_sizes)

    tiled = repeat(labels, 4)
    assert max_sizes == tuple(calculate_extents(tiled, func=np.max))
    reassemble(tiled, 'YX', block_size, min_overlap, context, grid)
def test_load_and_predict_big():
    """Load the model at ``path_model2d()`` and predict on an 8x8 tiled image.

    Returns
    -------
    The label image returned by ``predict_instances``.
    """
    model_dir = path_model2d()
    model = StarDist2D(None, name=model_dir.name, basedir=str(model_dir.parent))

    raw, _ = real_image2d()
    big_input = np.tile(normalize(raw, 1, 99.8), (8, 8))

    labels, _polygons = model.predict_instances(big_input)
    return labels
def test_pretrained_integration():
    """Predict with the pretrained "2D_versatile_fluo" model and extract instances.

    Returns
    -------
    tuple
        The two values returned by ``model._instances_from_prediction`` for the
        example image, using ``nms_thresh=.3``.
    """
    from stardist.models import StarDist2D

    image = normalize(real_image2d()[0])
    pretrained = StarDist2D.from_pretrained("2D_versatile_fluo")
    prob, dist = pretrained.predict(image)
    labels, details = pretrained._instances_from_prediction(
        image.shape, prob, dist, nms_thresh=.3)
    return labels, details
def test_predict_dense_sparse(model2d):
    """Check that dense and sparse instance prediction give the same result.

    Both predictions are run with ``n_tiles=(2, 2)``.  The label images and
    all array-valued entries of the two result dicts must be close.

    Parameters
    ----------
    model2d
        Model fixture providing ``predict_instances``.

    Returns
    -------
    tuple
        ``(labels1, res1, labels2, res2)`` where index 1 refers to the dense
        (``sparse=False``) and index 2 to the sparse (``sparse=True``) run.
    """
    img, _ = real_image2d()
    x = normalize(img, 1, 99.8)

    labels1, res1 = model2d.predict_instances(x, n_tiles=(2, 2), sparse=False)
    labels2, res2 = model2d.predict_instances(x, n_tiles=(2, 2), sparse=True)

    assert np.allclose(labels1, labels2)
    assert all(
        np.allclose(res1[k], res2[k])
        for k in set(res1.keys()).union(set(res2.keys()))
        if isinstance(res1[k], np.ndarray)
    )

    # return the dense labels first (previously the sparse labels were
    # returned twice)
    return labels1, res1, labels2, res2
def test_polygon_order_2D(model2d):
    """Check that the i-th predicted polygon corresponds to label id ``i``.

    Prediction is run with ``nms_thresh=0`` so that objects in the label
    image are not occluded by each other.

    Parameters
    ----------
    model2d
        Model fixture providing ``predict_instances``.
    """
    image = normalize(real_image2d()[0], 1, 99.8)
    labels, polys = model2d.predict_instances(image, nms_thresh=0)

    for object_id, coord in enumerate(polys['coord'], start=1):
        # polygon of the object that should carry id `object_id`
        polygon = Polygon(coord, shape_max=labels.shape)
        # mask of that id inside the polygon's bounding slice of the label image
        label_mask = labels[polygon.slice] == object_id
        assert np.all(polygon.mask == label_mask)
def render_label_example(model2d):
    """Render the predicted labels on top of the example image and show them.

    Parameters
    ----------
    model2d
        Model fixture providing ``predict_instances``.

    Returns
    -------
    The image produced by ``render_label``.
    """
    raw, _y_gt = real_image2d()
    x = normalize(raw, 1, 99.8)
    y_pred, _details = model2d.predict_instances(x)

    rendered = render_label(y_pred, img=x, alpha=0.3, alpha_boundary=1)

    import matplotlib.pyplot as plt

    plt.figure(1)
    plt.imshow(rendered)
    plt.show()
    return rendered
def _check_single_val(n_classes, classes=1):
    """Check ``mask_to_categorical`` when ``classes`` is a single value.

    Parameters
    ----------
    n_classes : int
        Number of classes passed to ``mask_to_categorical``.
    classes
        Single class value passed to ``mask_to_categorical`` (default 1).

    Returns
    -------
    tuple
        ``(p, cls_dict)`` as returned by ``mask_to_categorical``.
    """
    _, y_gt = real_image2d()
    expected_labels = set(np.unique(y_gt[y_gt > 0]))

    p, cls_dict = mask_to_categorical(
        y_gt, n_classes=n_classes, classes=classes, return_cls_dict=True)

    # the result has n_classes + 1 entries along a new trailing axis
    assert p.shape == y_gt.shape + (n_classes + 1,)
    # the only key is `classes`, and it lists every foreground label
    assert tuple(cls_dict.keys()) == (classes,)
    assert set(cls_dict[classes]) == expected_labels
    # only channel 0 and channel `classes` contain nonzero values
    nonzero_channels = np.where(np.count_nonzero(p, axis=(0, 1)))[0]
    assert set(nonzero_channels) == {0, classes}
    return p, cls_dict
def test_optimize_thresholds(model2d):
    """Check the thresholds found by ``optimize_thresholds`` on the example image.

    Parameters
    ----------
    model2d
        Model fixture providing ``optimize_thresholds``.
    """
    raw, mask = real_image2d()
    x = normalize(raw, 1, 99.8)

    thresholds = model2d.optimize_thresholds(
        [x],
        [mask],
        nms_threshs=[.3, .5],
        iou_threshs=[.3, .5],
        optimize_kwargs=dict(tol=1e-1),
        save_to_json=False,
    )

    # expected values, compared to 3 decimals
    np.testing.assert_almost_equal(thresholds["prob"], 0.454617141955, decimal=3)
    np.testing.assert_almost_equal(thresholds["nms"], 0.3, decimal=3)
def test_load_and_predict():
    """Load the model at ``path_model2d()`` and sanity-check its predictions.

    Verifies output shapes of ``predict`` and ``predict_instances`` and the
    matching statistics against the ground-truth mask at ``thresh=0.5``.
    """
    model_dir = path_model2d()
    model = StarDist2D(None, name=model_dir.name, basedir=str(model_dir.parent))

    img, mask = real_image2d()
    x = normalize(img, 1, 99.8)

    prob, dist = model.predict(x, n_tiles=(2, 3))
    assert prob.shape == dist.shape[:2]
    assert model.config.n_rays == dist.shape[-1]

    labels, polygons = model.predict_instances(x)
    assert labels.shape == img.shape[:2]
    n_polygons = len(polygons['coord'])
    # one polygon per label id, and consistent lengths across all entries
    assert labels.max() == n_polygons
    assert n_polygons == len(polygons['points']) == len(polygons['prob'])

    stats = matching(mask, labels, thresh=0.5)
    assert (stats.fp, stats.tp, stats.fn) == (1, 48, 17)
def render_label_example():
    """Render the predicted labels on top of the example image and show them.

    The model is loaded from the location given by ``path_model2d()``.

    Returns
    -------
    The image produced by ``render_label``.
    """
    model_dir = path_model2d()
    model = StarDist2D(None, name=model_dir.name, basedir=str(model_dir.parent))

    raw, _y_gt = real_image2d()
    x = normalize(raw, 1, 99.8)
    y_pred, _details = model.predict_instances(x)

    rendered = render_label(y_pred, img=x, alpha=0.3, alpha_boundary=1)

    import matplotlib.pyplot as plt

    plt.figure(1)
    plt.imshow(rendered)
    plt.show()
    return rendered
def test_load_and_predict(model2d):
    """Sanity-check the predictions of the model fixture on the example image.

    Verifies output shapes of ``predict`` and ``predict_instances`` and the
    matching statistics against the ground-truth mask at ``thresh=0.5``.

    Parameters
    ----------
    model2d
        Model fixture.

    Returns
    -------
    The label image returned by ``predict_instances``.
    """
    img, mask = real_image2d()
    x = normalize(img, 1, 99.8)

    prob, dist = model2d.predict(x, n_tiles=(2, 3))
    assert prob.shape == dist.shape[:2]
    assert model2d.config.n_rays == dist.shape[-1]

    labels, polygons = model2d.predict_instances(x)
    assert labels.shape == img.shape[:2]
    n_polygons = len(polygons['coord'])
    # one polygon per label id, and consistent lengths across all entries
    assert labels.max() == n_polygons
    assert n_polygons == len(polygons['points']) == len(polygons['prob'])

    stats = matching(mask, labels, thresh=0.5)
    assert (stats.fp, stats.tp, stats.fn) == (1, 48, 17)
    return labels
def test_optimize_thresholds():
    """Check the thresholds found by ``optimize_thresholds`` on the example image.

    The model is loaded from the location given by ``path_model2d()``.
    """
    model_dir = path_model2d()
    model = StarDist2D(None, name=model_dir.name, basedir=str(model_dir.parent))

    raw, mask = real_image2d()
    x = normalize(raw, 1, 99.8)

    thresholds = model.optimize_thresholds(
        [x],
        [mask],
        nms_threshs=[.3, .5],
        iou_threshs=[.3, .5],
        optimize_kwargs=dict(tol=1e-1),
        save_to_json=False,
    )

    # expected values, compared to 3 decimals
    np.testing.assert_almost_equal(thresholds["prob"], 0.454617141955, decimal=3)
    np.testing.assert_almost_equal(thresholds["nms"], 0.3, decimal=3)
def test_stardistdata(shape_completion, n_classes, classes):
    """Build a ``StarDistData2D`` generator with the given options and fetch item 0.

    Parameters
    ----------
    shape_completion, n_classes, classes
        Forwarded unchanged to ``StarDistData2D``.

    Returns
    -------
    tuple
        ``(a, b, generator)`` where ``a, b`` are the two parts of
        ``generator[0]``.
    """
    np.random.seed(42)
    from stardist.models import StarDistData2D

    image, labels = real_image2d()
    generator = StarDistData2D(
        [image] * 2,
        [labels] * 2,
        grid=(2, 2),
        n_classes=n_classes,
        classes=classes,
        shape_completion=shape_completion,
        b=8,
        batch_size=1,
        patch_size=(30, 40),
        n_rays=32,
        length=1,
    )
    inputs, targets = generator[0]
    return inputs, targets, generator
def render_label_pred_example(model2d):
    """Show image, ground truth, prediction and their comparison side by side.

    Parameters
    ----------
    model2d
        Model fixture providing ``predict_instances``.

    Returns
    -------
    The image produced by ``render_label_pred`` (shown in the last panel).
    """
    raw, y_gt = real_image2d()
    x = normalize(raw, 1, 99.8)
    y_pred, _details = model2d.predict_instances(x)
    overlay = render_label_pred(y_gt, y_pred, img=x)

    import matplotlib.pyplot as plt

    # (title, image factory) per panel; factories keep rendering lazy so each
    # image is produced right before it is drawn
    panels = (
        ("img", lambda: x),
        ("gt", lambda: render_label(y_gt, img=x)),
        ("pred", lambda: render_label(y_pred, img=x)),
        ("tp (green) fp (red) fn(blue)", lambda: overlay),
    )

    plt.figure(1, figsize=(12, 4))
    for position, (title, make_image) in enumerate(panels, start=1):
        plt.subplot(1, 4, position)
        plt.imshow(make_image())
        plt.title(title)
    plt.tight_layout()
    plt.show()
    return overlay
def test_edt_prob(anisotropy):
    """Compare the scipy-based and the edt-based ``edt_prob`` implementations.

    For several masks (tiled example labels, all zeros, all ones), dtypes and
    slicings, the 99.9th percentile of the absolute difference between both
    implementations must be below ``1e-3``.

    If the optional ``edt`` package cannot be imported, a message is printed
    and the test returns without checking anything.

    Parameters
    ----------
    anisotropy
        Forwarded to both implementations.
    """
    # Only the availability probe is guarded, so unrelated ImportErrors
    # (e.g. from stardist itself) are not misreported as "edt missing".
    try:
        import edt  # noqa: F401  (availability check only)
    except ImportError:
        print("Install edt to run test")
        return

    from stardist.utils import _edt_prob_edt, _edt_prob_scipy

    masks = (
        np.tile(real_image2d()[1], (2, 2)),
        np.zeros((117, 92)),
        np.ones((153, 112)),
    )
    dtypes = (np.uint16, np.int32)
    slices = (slice(None),) * 2, (slice(1, -1),) * 2

    for mask, dtype, sl in product(masks, dtypes, slices):
        mask = mask.astype(dtype)[sl]
        print(f"\nEDT {dtype.__name__} {mask.shape} slice {sl} ")
        with Timer("scipy "):
            ed1 = _edt_prob_scipy(mask, anisotropy=anisotropy)
        with Timer("edt: "):
            ed2 = _edt_prob_edt(mask, anisotropy=anisotropy)
        assert np.percentile(np.abs(ed1 - ed2), 99.9) < 1e-3
def test_pretrained_scales():
    """Report the accuracy of the pretrained model on rescaled inputs.

    The normalized example image is zoomed by ten factors between 0.5 and 5,
    predicted with the "2D_versatile_fluo" model, and the labels are zoomed
    back to the shape of the ground-truth mask before matching.

    Returns
    -------
    tuple
        One ``matching(...).accuracy`` value per scale.
    """
    from scipy.ndimage import zoom
    from stardist.matching import matching
    from skimage.measure import regionprops

    model = StarDist2D.from_pretrained("2D_versatile_fluo")
    img, mask = real_image2d()
    x = normalize(img, 1, 99.8)

    def pred_scale(scale=2):
        """Predict on the input zoomed by `scale`; return labels at mask shape."""
        scaled_input = zoom(x, scale, order=1)
        scaled_labels, _ = model.predict_instances(scaled_input)
        # per-axis factor that maps the prediction back onto the mask shape
        back_factors = tuple(
            target / current
            for target, current in zip(mask.shape, scaled_labels.shape)
        )
        return zoom(scaled_labels, back_factors, order=0)

    scales = np.linspace(.5, 5, 10)
    accuracies = []
    for scale in scales:
        accuracies.append(matching(mask, pred_scale(scale)).accuracy)
    accs = tuple(accuracies)

    print("scales ", np.round(scales, 2))
    print("accuracy ", np.round(accs, 2))
    return accs
def test_stardistdata_multithreaded(workers=5):
    """Check that ``StarDistData2D`` yields the same items from several threads.

    All items of the generator are fetched once sequentially and once through
    a thread pool; inputs and both targets must be close between the two
    passes.

    Parameters
    ----------
    workers : int
        Number of threads in the pool used for the concurrent pass.

    Returns
    -------
    tuple
        ``(a2, b2, s)``: inputs and targets from the threaded pass, and the
        generator.
    """
    np.random.seed(42)
    from stardist.models import StarDistData2D
    from scipy.ndimage import rotate
    from concurrent.futures import ThreadPoolExecutor

    def augmenter(x, y):
        # the rotation angle is a deterministic function of the labels
        deg = int(np.sum(y) % 117)
        print(deg)
        x = rotate(x, deg, reshape=False, order=0)
        y = rotate(y, deg, reshape=False, order=0)
        return x, y

    n_samples = 4
    _, mask = real_image2d()
    # shift the labels so that every sample (and thus its angle) differs
    Y = np.stack([mask + i for i in range(n_samples)])
    s = StarDistData2D(
        Y.astype(np.float32),
        Y,
        grid=(1, 1),
        n_classes=None,
        augmenter=augmenter,
        batch_size=1,
        patch_size=mask.shape,
        n_rays=32,
        length=len(Y),
    )

    # sequential reference pass
    a1, b1 = tuple(zip(*tuple(s[i] for i in range(n_samples))))

    # concurrent pass; honour `workers` (previously ignored in favour of a
    # hard-coded pool size of n_samples)
    with ThreadPoolExecutor(max_workers=workers) as e:
        a2, b2 = tuple(zip(*tuple(e.map(lambda i: s[i], range(n_samples)))))

    assert all(np.allclose(_r1[0], _r2[0]) for _r1, _r2 in zip(a1, a2))
    assert all(np.allclose(_r1[0], _r2[0]) for _r1, _r2 in zip(b1, b2))
    assert all(np.allclose(_r1[1], _r2[1]) for _r1, _r2 in zip(b1, b2))
    return a2, b2, s
import numpy as np
import pytest
from stardist import star_dist, edt_prob, non_maximum_suppression, dist_to_coord, polygons_to_label
from stardist.matching import matching
from utils import random_image, real_image2d, check_similar


# NOTE: the parametrize arguments are evaluated at import time, i.e. the
# example and random images are created when this module is collected.
@pytest.mark.parametrize('img', (real_image2d()[1], random_image((128, 123))))
def test_bbox_search(img):
    """Check that NMS gives similar results with and without ``max_bbox_search``.

    Probabilities and star-convex distances (32 rays, "cpp" mode) are computed
    from the label image ``img``; both NMS runs use ``prob_thresh=0.4``.
    """
    prob = edt_prob(img)
    dist = star_dist(img, n_rays=32, mode="cpp")
    coord = dist_to_coord(dist)
    nms_a = non_maximum_suppression(coord, prob, prob_thresh=0.4, verbose=False, max_bbox_search=False)
    nms_b = non_maximum_suppression(coord, prob, prob_thresh=0.4, verbose=False, max_bbox_search=True)
    check_similar(nms_a, nms_b)


@pytest.mark.parametrize('img', (real_image2d()[1], ))
def test_acc(img):
    """Run the NMS pipeline on the example label image.

    NOTE(review): no assertion is visible here, and the imported ``matching``
    and ``polygons_to_label`` are unused in this view -- the remainder of this
    test may lie outside the visible source; confirm before relying on it.
    """
    prob = edt_prob(img)
    dist = star_dist(img, n_rays=32, mode="cpp")
    coord = dist_to_coord(dist)
    points = non_maximum_suppression(coord, prob, prob_thresh=0.4)
def test_imagej_rois_export(tmpdir, model2d):
    """Export the predicted polygons of the example image as ImageJ ROIs.

    Parameters
    ----------
    tmpdir
        Directory in which ``img_rois.zip`` is written.
    model2d
        Model fixture providing ``predict_instances``.
    """
    image = normalize(real_image2d()[0], 1, 99.8)
    _labels, polys = model2d.predict_instances(image)

    roi_file = Path(tmpdir) / 'img_rois.zip'
    export_imagej_rois(str(roi_file), polys['coord'])
def _test_model_multiclass(n_classes=1, classes="auto", n_channel=None, basedir=None):
    """Train a small ``StarDist2D`` model and compare its prediction variants.

    The model is trained on three copies of the normalized example image and
    then applied to a 4x4 tiled version of it via ``predict_instances``
    (dense and sparse) and ``predict_instances_big``.  Dense and sparse
    results must agree.

    Parameters
    ----------
    n_classes : int or None
        Number of object classes for the model configuration.  With ``None``
        the validation data is passed without class information.
    classes
        Passed to ``model.train``.  If it is ``"auto"`` and ``n_classes > 1``,
        a label -> class dict is built by sorting the regions of the mask by
        area and splitting them into ``n_classes`` consecutive groups.
    n_channel : int or None
        If given, the image is repeated this many times along a new trailing
        axis; otherwise a single input channel is configured.
    basedir : path-like or None
        Base directory for the model.  With ``None`` no name and no base
        directory are passed to the model.

    Returns
    -------
    tuple
        ``(model, img, res1, res2, res3)``: the trained model, the tiled input
        image and the result dicts of the dense, sparse and block-wise
        predictions.
    """
    from skimage.measure import regionprops

    img, mask = real_image2d()
    img = normalize(img, 1, 99.8)

    if n_channel is not None:
        img = np.repeat(img[..., np.newaxis], n_channel, axis=-1)
    else:
        n_channel = 1

    X, Y = [img, img, img], [mask, mask, mask]

    conf = Config2D(
        n_rays=48,
        grid=(2, 2),
        n_channel_in=n_channel,
        n_classes=n_classes,
        use_gpu=False,
        train_epochs=1,
        train_steps_per_epoch=1,
        train_batch_size=1,
        train_dist_loss="iou",
        train_patch_size=(128, 128),
    )

    if n_classes is not None and n_classes > 1 and classes == "auto":
        regs = regionprops(mask)
        areas = tuple(r.area for r in regs)
        inds = np.argsort(areas)
        # consecutive, (almost) equally sized chunks of the area-sorted regions
        chunks = tuple(
            slice(n * len(regs) // n_classes, (n + 1) * len(regs) // n_classes)
            for n in range(n_classes))
        label_to_class = {}
        for class_id, chunk in enumerate(chunks, start=1):
            for j in inds[chunk]:
                label_to_class[regs[j].label] = class_id
        # same mapping for every training image
        classes = (label_to_class, ) * len(X)

    # Pass None through unchanged: str(None) would yield the literal
    # directory name "None" instead of "no base directory".
    model = StarDist2D(
        conf,
        name=None if basedir is None else "stardist",
        basedir=None if basedir is None else str(basedir),
    )

    # for validation, every foreground label is assigned class 1
    val_classes = {k: 1 for k in set(mask[mask > 0])}

    model.train(
        X, Y,
        classes=classes,
        epochs=30,
        validation_data=(X[:1], Y[:1]) if n_classes is None else (X[:1], Y[:1], (val_classes, )))

    # tile spatially only; keep the channel axis (if any) untouched
    img = np.tile(img, (4, 4) if img.ndim == 2 else (4, 4, 1))

    kwargs = dict(prob_thresh=.2)
    labels1, res1 = model.predict_instances(img, **kwargs)
    labels2, res2 = model.predict_instances(img, sparse=True, **kwargs)
    _, res3 = model.predict_instances_big(
        img,
        axes="YX" if img.ndim == 2 else "YXC",
        block_size=640,
        min_overlap=32,
        context=96,
        **kwargs)

    assert np.allclose(labels1, labels2)
    assert all(
        np.allclose(res1[k], res2[k])
        for k in set(res1.keys()).union(set(res2.keys()))
        if isinstance(res1[k], np.ndarray)
    )

    return model, img, res1, res2, res3