Exemplo n.º 1
0
def test_gen_index_combs():
    """All index combinations of two length-10 ranges are produced.

    Two index vectors of length 10 are combined; the result must have
    10 * 10 rows (one per combination) and 2 columns (one per input vector).
    """
    net = MRAIDenseNeuralNetwork()
    first_indices = np.arange(10)
    second_indices = np.arange(10)
    combinations = net.gen_index_combs([first_indices, second_indices])

    expected_rows = first_indices.size * second_indices.size
    assert combinations.shape[0] == expected_rows
    assert combinations.shape[1] == 2
Exemplo n.º 2
0
def test_subsample_rows():
    """Correct shape and contents.

    Draw 5 rows from a 12x2 matrix and check that the result has 5 rows and
    2 columns, that every drawn row is a real row of the source matrix, and
    that the drawn rows are distinct.
    """
    A = np.arange(24).reshape((12, 2))
    N = MRAIDenseNeuralNetwork()
    a = N.subsample_rows(A, num_draw=5)

    # Shape: num_draw rows, same column count as the source matrix.
    assert a.shape[0] == 5
    assert a.shape[1] == 2

    # Contents: compare every drawn row against every source row
    # (broadcast to (5, 12, 2)) and require each drawn row to match one.
    is_source_row = (a[:, None, :] == A[None, :, :]).all(axis=2).any(axis=1)
    assert is_source_row.all()

    # Distinct rows: 5 draws must give 5 unique rows. The length is taken of
    # the unique rows themselves, not of a comparison array.
    # NOTE(review): assumes subsample_rows draws without replacement — confirm.
    assert len(np.unique(a, axis=0)) == 5
Exemplo n.º 3
0
def test_matrix2sparse():
    """Every pixel of a 6x4 matrix is represented in the sparse array.

    The sparse array must have one row per pixel (24) and three columns.
    Column 0 must lie in range(6), column 1 in range(4), and column 2 in
    range(24) (presumably row index, column index and value respectively).
    """
    n_rows, n_cols = 6, 4
    n_pixels = n_rows * n_cols
    dense = np.arange(n_pixels).reshape((n_rows, n_cols))
    net = MRAIDenseNeuralNetwork()
    sparse = net.matrix2sparse(dense)

    assert sparse.shape[0] == n_pixels
    assert sparse.shape[1] == 3

    valid_ranges = (np.arange(n_rows), np.arange(n_cols), np.arange(n_pixels))
    for col, valid in enumerate(valid_ranges):
        outside = np.setdiff1d(sparse[:, col], valid)
        assert len(outside) == 0
Exemplo n.º 4
0
def test_sample_pairs():
    """sample_pairs yields square single-channel 4-D patches with 0/1 labels.

    Both patch stacks in the returned pairs must be 4-D, have equal spatial
    dimensions, one channel, and as many entries as their label vectors.
    All returned label arrays (and S) may only contain the values 0 and 1.
    """
    side = 32
    net = MRAIDenseNeuralNetwork(patch_size=(9, 9), dense_size=[2], num_draw=2)

    def _label_image():
        # Integer labels 0..3 increasing along the flattened image.
        return np.round(np.linspace(0, 3, side**2)).reshape((side, side))

    # Source scan with its sparse label representation.
    source_scan = rn.randn(side, side)
    source_labels = net.matrix2sparse(_label_image(), edge=(4, 4))

    # Target scan with its sparse label representation.
    target_scan = rn.randn(side, side)
    target_labels = net.matrix2sparse(_label_image(), edge=(4, 4))

    pairs, S = net.sample_pairs(source_scan, source_labels,
                                target_scan, target_labels,
                                num_draw=(2, 1))
    left_patches, right_patches, left_lbl, right_lbl = pairs

    # Shape checks, first for the left stack and then for the right one.
    for patches, labels in ((left_patches, left_lbl),
                            (right_patches, right_lbl)):
        assert len(patches.shape) == 4
        assert patches.shape[1] == patches.shape[2]
        assert patches.shape[3] == 1
        assert patches.shape[0] == labels.shape[0]

    # Content checks: only binary values are allowed.
    for values in (left_lbl, right_lbl, S):
        assert len(np.setdiff1d(np.unique(values), [0, 1])) == 0
Exemplo n.º 5
0
def test_init_patch_size():
    """A negative patch dimension is rejected with ValueError."""
    invalid_patch = (-1, 1)
    with pytest.raises(ValueError):
        MRAIDenseNeuralNetwork(patch_size=invalid_patch)
Exemplo n.º 6
0
def test_l2_norm():
    """Test non-negative norm.

    The L2 norm of two random 100x1 vectors, evaluated in a TensorFlow
    session, must be elementwise non-negative.
    """
    N = MRAIDenseNeuralNetwork()
    # Entering the session itself (instead of ``as_default()``) installs it as
    # the default session for ``.eval()`` AND closes it on exit, so the
    # session's resources are not leaked.
    with tf.Session():
        norms = N.l2_norm([rn.randn(100, 1), rn.randn(100, 1)]).eval()
        assert np.all(np.array(norms) >= 0)
Exemplo n.º 7
0
def test_contrastive_loss():
    """Test non-negative contrastive loss.

    For both label values (1 and 0) and a random distance, the evaluated
    contrastive loss must be non-negative.
    """
    N = MRAIDenseNeuralNetwork()
    # ``with tf.Session()`` makes the session the default one for ``.eval()``
    # and closes it on exit; ``as_default()`` alone would leave it open.
    with tf.Session():
        assert N.contrastive_loss(label=1, distance=rn.randn(1)).eval() >= 0
        assert N.contrastive_loss(label=0, distance=rn.randn(1)).eval() >= 0
Exemplo n.º 8
0
def test_init_dense_size():
    """An empty list of dense-layer sizes is rejected with ValueError."""
    no_layers = []
    with pytest.raises(ValueError):
        MRAIDenseNeuralNetwork(dense_size=no_layers)
Exemplo n.º 9
0
# Load source MRI-scan and corresponding segmentation
X = np.load('./data/subject01_GE2D_1.5T.npy')
Y = np.load('./data/subject01_segmentation.npy')

# Load target MRI-scan and corresponding segmentation
Z = np.load('./data/subject02_GE2D_3.0T.npy')
U = np.load('./data/subject02_segmentation.npy')

# Note that U is missing a lot of labels. Missing labels are encoded as NaN
# (as the isnan check assumes), so the proportion *missing* is the mean of
# the NaN indicator itself. Its negation would give the proportion observed.
print('Proportion missing labels = ' + str(np.mean(np.isnan(U))))

# Patch size used both by the network and by the patch extraction below;
# kept in one place so the two cannot drift apart.
PATCH_SIZE = (31, 31)

# --- Set up MRAI network ---

# Initialize and compile a small neural network
N = MRAIDenseNeuralNetwork(patch_size=PATCH_SIZE,
                           dense_size=[16, 8],
                           batch_size=128,
                           num_epochs=4,
                           num_draw=10,
                           margin=10)

# --- Train the net ---

# Call training procedure on source and target data
N.train(X, Y, Z, U, num_targets=1)

# --- Map images to MRAI representation ---

# Extract all source patches and feed them through network.
PX = extract_all_patches(X[0], patch_size=PATCH_SIZE, add_4d=True)
HX = N.feedforward(PX, scan_ID=0)

# Map label image to sparse array format
sY = N.matrix2sparse(Y[0])