Beispiel #1
0
def test_stardistdata():
    """Fetch the first batch of a minimal two-sample StarDistData2D sequence.

    The same real 2D image/mask pair is used twice as training data; the first
    batch is unpacked into its single input and its (prob, dist) targets.

    Returns:
        ``((img,), (prob, dist), dataset)``.
    """
    from stardist.models import StarDistData2D

    image, labels = real_image2d()
    dataset = StarDistData2D(
        [image, image],
        [labels, labels],
        batch_size=1,
        patch_size=(30, 40),
        n_rays=32,
        length=1,
    )
    inputs, targets = dataset[0]
    (img,) = inputs
    prob, dist = targets
    return (img,), (prob, dist), dataset
Beispiel #2
0
def render_label_pred_example():
    """Show a StarDist2D prediction next to the ground truth.

    Loads the 2D test model from ``path_model2d()`` and predicts instances on
    the normalised real test image. It then displays a 1x4 figure: the input,
    the ground-truth labels, the predicted labels and the
    ``render_label_pred`` comparison image.

    Returns:
        The image returned by ``render_label_pred``.
    """
    path = path_model2d()
    model = StarDist2D(None, name=path.name, basedir=str(path.parent))

    img, y_gt = real_image2d()
    x = normalize(img, 1, 99.8)
    y_pred, _ = model.predict_instances(x)

    comparison = render_label_pred(y_gt, y_pred, img=x)

    import matplotlib.pyplot as plt

    panels = (
        (x, "img"),
        (render_label(y_gt, img=x), "gt"),
        (render_label(y_pred, img=x), "pred"),
        (comparison, "tp (green) fp (red) fn(blue)"),
    )
    plt.figure(1, figsize=(12, 4))
    for position, (panel, title) in enumerate(panels, start=1):
        plt.subplot(1, len(panels), position)
        plt.imshow(panel)
        plt.title(title)
    plt.tight_layout()
    plt.show()
    return comparison
Beispiel #3
0
def test_speed(model2d):
    """Benchmark predict_instances vs. predict_instances_big on a large image.

    The real test image is normalised and tiled 6x6. Every combination of
    mode ("normal"/"big"), ``n_tiles`` (None/(2, 2)) and ``sparse``
    (True/False) is run once. Each timing line is printed as it is produced,
    and all lines are printed again at the end. Nothing is asserted.
    """
    from time import time

    model = model2d
    img, _ = real_image2d()
    x = np.tile(normalize(img, 1, 99.8), (6, 6))
    print(x.shape)

    def predict(mode, n_tiles, sparse):
        if mode != "normal":
            return model.predict_instances_big(
                x,
                axes="YX",
                block_size=1024 + 256,
                context=64,
                min_overlap=64,
                n_tiles=n_tiles,
                sparse=sparse,
            )
        return model.predict_instances(x, n_tiles=n_tiles, sparse=sparse)

    timings = []
    grid = product(("normal", "big"), (None, (2, 2)), (True, False))
    for mode, n_tiles, sparse in grid:
        started = time()
        _labels, _res = predict(mode, n_tiles, sparse)
        elapsed = time() - started
        line = f"mode={mode}\ttiles={n_tiles}\tsparse={sparse}\t{elapsed:.2f}s"
        print(line)
        timings.append(line)

    for line in timings:
        print(line)
Beispiel #4
0
def test_predict2D(model2d, use_channel):
    """predict_instances_big must reproduce predict_instances on a 2D image.

    The normalised real test image is enlarged with ``repeat(..., 2)`` and can
    optionally get a trailing channel axis. Two things are compared:

    - the label images and the re-rendered polygons, which must match
      perfectly;
    - the polygon coordinates, centre points and probabilities, which must
      agree to within ``atol=1e-2``.

    Returns:
        ``(ref_polys, res_polys)`` from the regular and the big prediction.
    """
    model = model2d
    img = repeat(normalize(real_image2d()[0], 1, 99.8), 2)
    axes = 'YX'
    if use_channel:
        img = img[..., np.newaxis]
        axes = axes + 'C'

    ref_labels, ref_polys = model.predict_instances(img, axes=axes)
    res_labels, res_polys = model.predict_instances_big(
        img, axes=axes, block_size=288, min_overlap=32, context=96)

    def assert_perfect_match(expected, actual):
        stats = matching(expected, actual)
        assert (1.0, 1.0) == (stats.accuracy, stats.mean_true_score)

    assert_perfect_match(ref_labels, res_labels)
    assert_perfect_match(render_polygons(ref_polys, shape=img.shape),
                         render_polygons(res_polys, shape=img.shape))

    # Both predictions may list polygons in a different order, so sort each
    # by its centre points before comparing element-wise.
    ref_order = np.lexsort(ref_polys["points"].T)
    res_order = np.lexsort(res_polys["points"].T)
    for key in ("coord", "points", "prob"):
        assert np.allclose(ref_polys[key][ref_order],
                           res_polys[key][res_order], atol=1e-2)

    return ref_polys, res_polys
Beispiel #5
0
def test_cover2D(block_size, context, grid):
    """Check that block-wise reassembly covers a tiled label image.

    ``min_overlap`` is chosen one pixel larger than the maximal object extent
    of the real test labels. The labels are enlarged with ``repeat(..., 4)``,
    which must not change those extents, and are then handed to
    ``reassemble``.
    """
    labels = real_image2d()[1].astype(np.int32)

    extents = tuple(calculate_extents(labels, func=np.max))
    min_overlap = tuple(extent + 1 for extent in extents)

    labels = repeat(labels, 4)
    assert extents == tuple(calculate_extents(labels, func=np.max))

    reassemble(labels, 'YX', block_size, min_overlap, context, grid)
Beispiel #6
0
def test_load_and_predict_big():
    """Predict instances on an 8x8 tiling of the real test image.

    The model is loaded from ``path_model2d()``.

    Returns:
        The predicted label image.
    """
    path = path_model2d()
    model = StarDist2D(None, name=path.name, basedir=str(path.parent))
    image, _ = real_image2d()
    tiled = np.tile(normalize(image, 1, 99.8), (8, 8))
    labels, _polygons = model.predict_instances(tiled)
    return labels
Beispiel #7
0
def test_pretrained_integration():
    """Run the pretrained '2D_versatile_fluo' model on the real test image.

    The raw ``(prob, dist)`` output of ``model.predict`` is converted with the
    model's private ``_instances_from_prediction`` using ``nms_thresh=.3``.

    Returns:
        The two values returned by ``_instances_from_prediction``.
    """
    from stardist.models import StarDist2D

    image = normalize(real_image2d()[0])
    model = StarDist2D.from_pretrained("2D_versatile_fluo")

    prob, dist = model.predict(image)
    labels, result = model._instances_from_prediction(
        image.shape, prob, dist, nms_thresh=.3)
    return labels, result
Beispiel #8
0
def test_predict_dense_sparse(model2d):
    """Dense and sparse tiled prediction must give the same result.

    ``predict_instances`` is run with ``n_tiles=(2, 2)`` twice, once with
    ``sparse=False`` and once with ``sparse=True``. The test then checks
    that the two label images agree and that every array-valued entry of the
    result dicts agrees.

    Returns:
        ``(labels_dense, res_dense, labels_sparse, res_sparse)``.
    """
    model = model2d
    img, _ = real_image2d()
    x = normalize(img, 1, 99.8)
    labels1, res1 = model.predict_instances(x, n_tiles=(2, 2), sparse=False)
    labels2, res2 = model.predict_instances(x, n_tiles=(2, 2), sparse=True)
    assert np.allclose(labels1, labels2)
    assert all(
        np.allclose(res1[k], res2[k])
        for k in set(res1.keys()) | set(res2.keys())
        if isinstance(res1[k], np.ndarray))
    # Return each label image with its own result dict (previously labels2 was
    # returned twice and the dense labels were lost).
    return labels1, res1, labels2, res2
Beispiel #9
0
def test_polygon_order_2D(model2d):
    """Polygon number i (1-based) must describe label i of the label image.

    Prediction uses ``nms_thresh=0``, so objects are not occluded in the label
    image. Each polygon's rasterised mask must therefore equal the pixels of
    its label id inside the polygon's bounding slice.
    """
    img = normalize(real_image2d()[0], 1, 99.8)
    labels, polys = model2d.predict_instances(img, nms_thresh=0)

    for index, coord in enumerate(polys['coord']):
        label_id = index + 1
        polygon = Polygon(coord, shape_max=labels.shape)
        label_mask = labels[polygon.slice] == label_id
        assert np.all(polygon.mask == label_mask)
Beispiel #10
0
def render_label_example(model2d):
    """Display the predicted labels overlaid on the normalised test image.

    The overlay uses ``alpha=0.3`` and ``alpha_boundary=1``.

    Returns:
        The image returned by ``render_label``.
    """
    img, _y_gt = real_image2d()
    x = normalize(img, 1, 99.8)
    y, _ = model2d.predict_instances(x)

    overlay = render_label(y, img=x, alpha=0.3, alpha_boundary=1)

    import matplotlib.pyplot as plt

    plt.figure(1)
    plt.imshow(overlay)
    plt.show()
    return overlay
Beispiel #11
0
 def _check_single_val(n_classes, classes=1):
     """Check mask_to_categorical when all objects share one class value.

     Expected result:

     - the output has one channel per class plus a background channel;
     - the returned class dict has ``classes`` as its only key, mapping to
       all object labels;
     - only the background channel and channel ``classes`` contain nonzero
       entries.

     Returns:
         ``(p, cls_dict)`` as returned by ``mask_to_categorical``.
     """
     _, y_gt = real_image2d()
     object_ids = set(np.unique(y_gt[y_gt > 0]))

     p, cls_dict = mask_to_categorical(
         y_gt, n_classes=n_classes, classes=classes, return_cls_dict=True)

     assert p.shape == y_gt.shape + (n_classes + 1,)
     assert tuple(cls_dict.keys()) == (classes,)
     assert set(cls_dict[classes]) == object_ids
     assert set(np.where(np.count_nonzero(p, axis=(0, 1)))[0]) == {0, classes}
     return p, cls_dict
Beispiel #12
0
def test_optimize_thresholds(model2d):
    """Threshold optimisation on the real test image must be reproducible.

    ``optimize_thresholds`` is run over NMS thresholds (.3, .5) and IoU
    thresholds (.3, .5) without writing JSON. The chosen ``prob`` and ``nms``
    thresholds are pinned to 3 decimals.
    """
    img, mask = real_image2d()
    x = normalize(img, 1, 99.8)

    options = dict(
        nms_threshs=[.3, .5],
        iou_threshs=[.3, .5],
        optimize_kwargs=dict(tol=1e-1),
        save_to_json=False,
    )
    res = model2d.optimize_thresholds([x], [mask], **options)

    expected = {"prob": 0.454617141955, "nms": 0.3}
    for key, value in expected.items():
        np.testing.assert_almost_equal(res[key], value, decimal=3)
Beispiel #13
0
def test_load_and_predict():
    """Load the 2D test model from disk and check its predictions.

    Checks performed:

    - the tiled ``predict`` output shapes are consistent with ``n_rays``;
    - ``predict_instances`` returns one polygon per label;
    - the match against the ground truth at IoU 0.5 gives exactly
      fp=1, tp=48, fn=17.
    """
    path = path_model2d()
    model = StarDist2D(None, name=path.name, basedir=str(path.parent))
    img, mask = real_image2d()
    x = normalize(img, 1, 99.8)

    prob, dist = model.predict(x, n_tiles=(2, 3))
    assert prob.shape == dist.shape[:2]
    assert model.config.n_rays == dist.shape[-1]

    labels, polygons = model.predict_instances(x)
    assert labels.shape == img.shape[:2]
    n_objects = len(polygons['coord'])
    assert labels.max() == n_objects
    assert n_objects == len(polygons['points']) == len(polygons['prob'])

    stats = matching(mask, labels, thresh=0.5)
    assert (stats.fp, stats.tp, stats.fn) == (1, 48, 17)
Beispiel #14
0
def render_label_example():
    """Display predicted labels overlaid on the normalised test image.

    The model is loaded from ``path_model2d()``. The overlay uses
    ``alpha=0.3`` and ``alpha_boundary=1``.

    Returns:
        The image returned by ``render_label``.
    """
    path = path_model2d()
    model = StarDist2D(None, name=path.name, basedir=str(path.parent))

    img, _y_gt = real_image2d()
    x = normalize(img, 1, 99.8)
    y, _ = model.predict_instances(x)

    overlay = render_label(y, img=x, alpha=0.3, alpha_boundary=1)

    import matplotlib.pyplot as plt

    plt.figure(1)
    plt.imshow(overlay)
    plt.show()
    return overlay
Beispiel #15
0
def test_load_and_predict(model2d):
    """Check predictions of the ``model2d`` fixture on the real test image.

    Checks performed:

    - the tiled ``predict`` output shapes are consistent with ``n_rays``;
    - ``predict_instances`` returns one polygon per label;
    - the match against the ground truth at IoU 0.5 gives exactly
      fp=1, tp=48, fn=17.

    Returns:
        The predicted label image.
    """
    img, mask = real_image2d()
    x = normalize(img, 1, 99.8)

    prob, dist = model2d.predict(x, n_tiles=(2, 3))
    assert prob.shape == dist.shape[:2]
    assert model2d.config.n_rays == dist.shape[-1]

    labels, polygons = model2d.predict_instances(x)
    assert labels.shape == img.shape[:2]
    n_objects = len(polygons['coord'])
    assert labels.max() == n_objects
    assert n_objects == len(polygons['points']) == len(polygons['prob'])

    stats = matching(mask, labels, thresh=0.5)
    assert (stats.fp, stats.tp, stats.fn) == (1, 48, 17)
    return labels
Beispiel #16
0
def test_optimize_thresholds():
    """Threshold optimisation with the on-disk 2D test model must be reproducible.

    ``optimize_thresholds`` is run over NMS thresholds (.3, .5) and IoU
    thresholds (.3, .5) without writing JSON. The chosen ``prob`` and ``nms``
    thresholds are pinned to 3 decimals.
    """
    path = path_model2d()
    model = StarDist2D(None, name=path.name, basedir=str(path.parent))
    img, mask = real_image2d()
    x = normalize(img, 1, 99.8)

    options = dict(
        nms_threshs=[.3, .5],
        iou_threshs=[.3, .5],
        optimize_kwargs=dict(tol=1e-1),
        save_to_json=False,
    )
    res = model.optimize_thresholds([x], [mask], **options)

    expected = {"prob": 0.454617141955, "nms": 0.3}
    for key, value in expected.items():
        np.testing.assert_almost_equal(res[key], value, decimal=3)
Beispiel #17
0
def test_stardistdata(shape_completion, n_classes, classes):
    """Fetch the first batch of a StarDistData2D sequence with classes/grid options.

    NumPy's global RNG is seeded with 42 beforehand. The same real image/mask
    pair is used twice as data.

    Returns:
        ``(inputs, targets, dataset)`` where inputs/targets come from
        ``dataset[0]``.
    """
    np.random.seed(42)
    from stardist.models import StarDistData2D

    img, mask = real_image2d()
    settings = dict(
        grid=(2, 2),
        n_classes=n_classes,
        classes=classes,
        shape_completion=shape_completion,
        b=8,
        batch_size=1,
        patch_size=(30, 40),
        n_rays=32,
        length=1,
    )
    dataset = StarDistData2D([img, img], [mask, mask], **settings)
    inputs, targets = dataset[0]
    return inputs, targets, dataset
Beispiel #18
0
def render_label_pred_example(model2d):
    """Show a prediction of the ``model2d`` fixture next to the ground truth.

    Displays a 1x4 figure: the input, the ground-truth labels, the predicted
    labels and the ``render_label_pred`` comparison image.

    Returns:
        The image returned by ``render_label_pred``.
    """
    img, y_gt = real_image2d()
    x = normalize(img, 1, 99.8)
    y_pred, _ = model2d.predict_instances(x)

    comparison = render_label_pred(y_gt, y_pred, img=x)

    import matplotlib.pyplot as plt

    panels = (
        (x, "img"),
        (render_label(y_gt, img=x), "gt"),
        (render_label(y_pred, img=x), "pred"),
        (comparison, "tp (green) fp (red) fn(blue)"),
    )
    plt.figure(1, figsize=(12, 4))
    for position, (panel, title) in enumerate(panels, start=1):
        plt.subplot(1, len(panels), position)
        plt.imshow(panel)
        plt.title(title)
    plt.tight_layout()
    plt.show()
    return comparison
Beispiel #19
0
def test_edt_prob(anisotropy):
    """Compare the ``edt``-based and scipy-based ``_edt_prob`` implementations.

    Masks tested: a 2x2 tiling of the real test labels, an all-zero mask and
    an all-one mask. Each is run in two integer dtypes, both full and with a
    one-pixel border sliced off. The 99.9th percentile of the absolute
    difference between the two implementations must stay below 1e-3.

    If the optional ``edt`` package (or the helpers depending on it) cannot be
    imported, a message is printed and the test returns without checking.
    """
    # Only the imports are guarded: an ImportError raised later during the
    # comparison is a genuine failure and must not be reported as "install edt".
    try:
        import edt  # noqa: F401  (imported only to check availability)
        from stardist.utils import _edt_prob_edt, _edt_prob_scipy
    except ImportError:
        print("Install edt to run test")
        return

    masks = (np.tile(real_image2d()[1], (2, 2)),
             np.zeros((117, 92)),
             np.ones((153, 112)))
    dtypes = (np.uint16, np.int32)
    slices = (slice(None), ) * 2, (slice(1, -1), ) * 2
    for mask, dtype, sl in product(masks, dtypes, slices):
        mask = mask.astype(dtype)[sl]
        print(f"\nEDT {dtype.__name__} {mask.shape} slice {sl} ")
        with Timer("scipy "):
            ed1 = _edt_prob_scipy(mask, anisotropy=anisotropy)
        with Timer("edt:  "):
            ed2 = _edt_prob_edt(mask, anisotropy=anisotropy)
        assert np.percentile(np.abs(ed1 - ed2), 99.9) < 1e-3
Beispiel #20
0
def test_pretrained_scales():
    """Measure pretrained-model accuracy on the test image at different scales.

    For each of 10 zoom factors in [0.5, 5]:

    1. the normalised image is zoomed (linear interpolation);
    2. instances are predicted with '2D_versatile_fluo';
    3. the labels are zoomed back to the mask shape (nearest neighbour);
    4. the result is matched against the ground truth.

    Returns:
        Tuple of matching accuracies, one per scale.
    """
    from scipy.ndimage import zoom
    from stardist.matching import matching
    # (the unused skimage ``regionprops`` import was removed)

    model = StarDist2D.from_pretrained("2D_versatile_fluo")
    img, mask = real_image2d()
    x = normalize(img, 1, 99.8)

    def pred_scale(scale=2):
        x2 = zoom(x, scale, order=1)
        labels2, _ = model.predict_instances(x2)
        # Per-axis factors that map the prediction back onto the mask grid.
        factors = tuple(_s1 / _s2 for _s1, _s2 in zip(mask.shape, labels2.shape))
        return zoom(labels2, factors, order=0)

    scales = np.linspace(.5, 5, 10)
    accs = tuple(matching(mask, pred_scale(s)).accuracy for s in scales)
    print("scales   ", np.round(scales, 2))
    print("accuracy ", np.round(accs, 2))

    return accs
Beispiel #21
0
def test_stardistdata_multithreaded(workers=5):
    """StarDistData2D must yield identical samples when read from threads.

    The dataset holds ``n_samples`` shifted copies of the real test mask (also
    used as float input). It uses a rotation augmenter whose angle is derived
    from the label data rather than from a shared RNG. Every sample is read
    sequentially and then concurrently through a thread pool, and the inputs
    and both targets must agree.

    Parameters
    ----------
    workers : int
        Number of threads in the pool used for concurrent access. Previously
        this argument was ignored.

    Returns
    -------
    The concurrently obtained inputs, targets and the dataset object.
    """
    np.random.seed(42)
    from stardist.models import StarDistData2D
    from scipy.ndimage import rotate
    from concurrent.futures import ThreadPoolExecutor

    def augmenter(x, y):
        # Angle depends only on the label data, not on thread scheduling.
        deg = int(np.sum(y) % 117)
        print(deg)
        x = rotate(x, deg, reshape=False, order=0)
        y = rotate(y, deg, reshape=False, order=0)
        return x, y

    n_samples = 4

    _, mask = real_image2d()
    Y = np.stack([mask + i for i in range(n_samples)])
    s = StarDistData2D(Y.astype(np.float32),
                       Y,
                       grid=(1, 1),
                       n_classes=None,
                       augmenter=augmenter,
                       batch_size=1,
                       patch_size=mask.shape,
                       n_rays=32,
                       length=len(Y))

    a1, b1 = tuple(zip(*tuple(s[i] for i in range(n_samples))))

    with ThreadPoolExecutor(max_workers=workers) as e:
        a2, b2 = tuple(zip(*tuple(e.map(lambda i: s[i], range(n_samples)))))

    assert all(np.allclose(_r1[0], _r2[0]) for _r1, _r2 in zip(a1, a2))
    assert all(np.allclose(_r1[0], _r2[0]) for _r1, _r2 in zip(b1, b2))
    assert all(np.allclose(_r1[1], _r2[1]) for _r1, _r2 in zip(b1, b2))
    return a2, b2, s
Beispiel #22
0
import numpy as np
import pytest
from stardist import star_dist, edt_prob, non_maximum_suppression, dist_to_coord, polygons_to_label
from stardist.matching import matching
from utils import random_image, real_image2d, check_similar


@pytest.mark.parametrize('img', (real_image2d()[1], random_image((128, 123))))
def test_bbox_search(img):
    """NMS must give similar results with and without ``max_bbox_search``.

    Probabilities come from ``edt_prob``. Polygon coordinates come from
    32-ray ``star_dist`` distances computed in "cpp" mode.
    """
    prob = edt_prob(img)
    coord = dist_to_coord(star_dist(img, n_rays=32, mode="cpp"))

    results = [
        non_maximum_suppression(coord,
                                prob,
                                prob_thresh=0.4,
                                verbose=False,
                                max_bbox_search=use_bbox_search)
        for use_bbox_search in (False, True)
    ]
    check_similar(*results)


@pytest.mark.parametrize('img', (real_image2d()[1], ))
def test_acc(img):
    """Labels rebuilt from NMS-selected polygons must match the input labels.

    ``prob`` and star distances are derived from the label image itself. NMS
    then selects polygon centres, and the polygons are rasterised back into a
    label image. That label image is matched against the original, and the
    accuracy must exceed 0.9.
    Previously the test computed the NMS points but asserted nothing.
    """
    prob = edt_prob(img)
    dist = star_dist(img, n_rays=32, mode="cpp")
    coord = dist_to_coord(dist)
    points = non_maximum_suppression(coord, prob, prob_thresh=0.4)
    img2 = polygons_to_label(coord, prob, points, shape=img.shape)
    acc = matching(img, img2).accuracy
    print(f"accuracy {acc:.2f}")
    assert acc > 0.9
Beispiel #23
0
def test_imagej_rois_export(tmpdir, model2d):
    """Export predicted polygons as ImageJ ROIs and check the archive exists.

    Previously the test only checked that the export did not raise.
    """
    img = normalize(real_image2d()[0], 1, 99.8)
    _, polys = model2d.predict_instances(img)
    fname = Path(tmpdir) / 'img_rois.zip'
    export_imagej_rois(str(fname), polys['coord'])
    assert fname.exists()
Beispiel #24
0
def _test_model_multiclass(n_classes=1,
                           classes="auto",
                           n_channel=None,
                           basedir=None):
    """Train a tiny StarDist2D model and cross-check its prediction modes.

    Trains on three copies of the normalised real test image. It then
    predicts on a 4x4 tiling in three ways: dense, sparse, and via
    ``predict_instances_big``. The dense and sparse results must agree.

    Parameters
    ----------
    n_classes : int or None
        Passed to ``Config2D``; ``None`` disables object classes.
    classes : "auto" or sequence
        Per-image class assignment forwarded to ``model.train``. With
        ``"auto"`` and ``n_classes > 1``, objects are sorted by area and split
        into ``n_classes`` equally sized groups, labelled 1..n_classes.
    n_channel : int or None
        If given, the grayscale image is repeated along a new last axis to
        this many channels.
    basedir : path-like or None
        Model directory. If ``None``, both ``name`` and ``basedir`` are passed
        to StarDist2D as ``None``.

    Returns
    -------
    ``(model, tiled_img, res_dense, res_sparse, res_big)``
    """
    from skimage.measure import regionprops
    img, mask = real_image2d()
    img = normalize(img, 1, 99.8)

    if n_channel is not None:
        img = np.repeat(img[..., np.newaxis], n_channel, axis=-1)
    else:
        n_channel = 1

    X, Y = [img, img, img], [mask, mask, mask]

    conf = Config2D(
        n_rays=48,
        grid=(2, 2),
        n_channel_in=n_channel,
        n_classes=n_classes,
        use_gpu=False,
        train_epochs=1,
        train_steps_per_epoch=1,
        train_batch_size=1,
        train_dist_loss="iou",
        train_patch_size=(128, 128),
    )

    if n_classes is not None and n_classes > 1 and classes == "auto":
        regs = regionprops(mask)
        areas = tuple(r.area for r in regs)
        inds = np.argsort(areas)
        ss = tuple(
            slice(n * len(regs) // n_classes, (n + 1) * len(regs) // n_classes)
            for n in range(n_classes))
        classes = {}
        for i, sl in enumerate(ss):
            for j in inds[sl]:
                classes[regs[j].label] = i + 1
        classes = (classes, ) * len(X)

    # str(None) would be the literal directory name "None"; pass None through.
    model = StarDist2D(conf,
                       name=None if basedir is None else "stardist",
                       basedir=None if basedir is None else str(basedir))

    val_classes = {k: 1 for k in set(mask[mask > 0])}

    model.train(X,
                Y,
                classes=classes,
                epochs=30,
                validation_data=(X[:1], Y[:1]) if n_classes is None else
                (X[:1], Y[:1], (val_classes, )))

    img = np.tile(img, (4, 4) if img.ndim == 2 else (4, 4, 1))

    kwargs = dict(prob_thresh=.2)
    labels1, res1 = model.predict_instances(img, **kwargs)
    labels2, res2 = model.predict_instances(img, sparse=True, **kwargs)
    labels3, res3 = model.predict_instances_big(
        img,
        axes="YX" if img.ndim == 2 else "YXC",
        block_size=640,
        min_overlap=32,
        context=96,
        **kwargs)

    assert np.allclose(labels1, labels2)
    assert all(
        np.allclose(res1[k], res2[k])
        for k in set(res1.keys()) | set(res2.keys())
        if isinstance(res1[k], np.ndarray))

    return model, img, res1, res2, res3