示例#1
0
def test_automate():
    """Check that GrowCut recovers a filled disc from two small seed patches.

    A binary disc of radius ``r`` is drawn on a ``shape`` grid (``field`` is
    1.0 inside the disc, 0.0 outside). Two seeds are placed, each with
    strength ``theta`` 1.0:

    * a label -1 seed in the top-left corner;
    * a label 1 seed at the centre of the disc.

    ``growcut.automate`` is then applied ``n_iterations`` times. The final
    label map must be 1 everywhere inside the disc and -1 everywhere outside.
    """
    shape = (20, 20)
    cx, cy = shape[1] / 2.0, shape[0] / 2.0
    r = 5
    n_iterations = 100

    y, x = np.ogrid[0:shape[0], 0:shape[1]]
    mask = (np.power((y - cy), 2) + np.power((x - cx), 2)) < np.power(r, 2)

    field = np.zeros(mask.shape)
    field[mask] = 1.0
    label = np.zeros_like(field)
    theta = np.zeros_like(field)

    # Slice bounds must be integers: ``r / 4`` is a float under true division
    # and ``cx``/``cy`` are floats by construction. NumPy rejects float
    # indices, so compute the seed half-width and the centre as ints once.
    q = r // 4
    icx, icy = int(cx), int(cy)

    # Corner seed, label -1 (the label expected outside the disc).
    label[0:q, 0:q] = -1
    theta[0:q, 0:q] = 1.0

    # Centre seed, label 1 (the label expected inside the disc).
    label[icy - q:icy + q, icx - q:icx + q] = 1
    theta[icy - q:icy + q, icx - q:icx + q] = 1.0

    # Let the automaton propagate the seed labels.
    for _ in range(n_iterations):
        theta, label = growcut.automate(field, theta, label)

    # The label map should reproduce the disc mask exactly.
    assert np.allclose(label[mask], 1.0) and np.allclose(label[~mask], -1), \
        "Segmentation did not converge after {} iterations".format(n_iterations)
示例#2
0
# Form a label grid (-1: unlabeled, 1: foreground seed, 0: background seed).
# ``dtype=int`` replaces the ``np.int`` alias, which NumPy 1.24 removed.
label = np.full_like(lum, -1, dtype=int)
label[75:90, 100:110] = 1
label[110:120, 150:160] = 1
label[50:55, 160:165] = 1
label[50:55, 180:185] = 0
label[0:10, 0:10] = 0
label[75:90, 0:10] = 0
label[0:10, 200:210] = 0
label[75:90, 200:210] = 0

# Form a strength grid: 1.0 on every seeded pixel, 0.0 on unlabeled ones.
strength = np.zeros_like(lum, dtype=np.float64)
strength[label != -1] = 1.0

# Snapshot the seeds. Each implementation below is then timed on the same
# initial state, instead of on the previous implementation's output.
label_seed = label.copy()
strength_seed = strength.copy()

# Single-argument ``print(...)`` behaves identically on Python 2 and 3.
strength, label = strength_seed.copy(), label_seed.copy()
t0 = time.time()
coordinates = automata.formSamples(lum.shape, neighbours=automata.CONNECT_4)
strength, label = growcut.numpyAutomate(coordinates, lum, strength, label)
print("Numpy vectorized: " + str(1000 * (time.time() - t0)) + " ms")

strength, label = strength_seed.copy(), label_seed.copy()
t0 = time.time()
strength, label = automate_cy(lum, strength, label, connectivity=4)
print("Cython: " + str(1000 * (time.time() - t0)) + " ms")

strength, label = strength_seed.copy(), label_seed.copy()
t0 = time.time()
strength, label = growcut.automate(lum, strength, label)
print("Pure Python: " + str(1000 * (time.time() - t0)) + " ms")