Esempio n. 1
0
def test_mask(batch, seq_starts, expected):
    """Check the mask that ``sanitize_batch`` builds for a scalar sequence input.

    Arguments are presumably supplied by a ``pytest.mark.parametrize``
    decorator not visible here.

    Args:
        batch: Input batch passed straight to ``sanitize_batch``.
        seq_starts: Sequence-start flags passed straight to ``sanitize_batch``.
        expected: Either an exception class that ``sanitize_batch`` must
            raise for this ``batch``/``seq_starts`` combination, or
            array-like values that ``s.mask`` must match (``np.allclose``).
    """
    shape = ()
    var = sequence.input_variable(shape)
    # Only exception *classes* select the failure path; comparing against
    # ``type(ValueError)`` (i.e. ``type``) would accept any class at all.
    if isinstance(expected, type) and issubclass(expected, BaseException):
        with pytest.raises(expected):
            sanitize_batch(var, batch, seq_starts)
    else:
        s = sanitize_batch(var, batch, seq_starts)
        assert np.allclose(s.mask, expected)
Esempio n. 2
0
def test_mask(batch, seq_starts, expected):
    """Verify ``sanitize_batch`` mask output (or failure) for scalar sequences.

    The arguments are presumably provided by a ``pytest.mark.parametrize``
    decorator that is not visible in this snippet.

    Args:
        batch: Data handed unchanged to ``sanitize_batch``.
        seq_starts: Sequence-start information handed unchanged to
            ``sanitize_batch``.
        expected: An exception class the call must raise, or mask values
            compared with ``np.allclose`` against the result's ``mask``.
    """
    shape = ()
    var = sequence.input_variable(shape)
    expects_error = isinstance(expected, type) and issubclass(
        expected, BaseException
    )
    if expects_error:
        # Result is irrelevant here; only the raised exception matters.
        with pytest.raises(expected):
            sanitize_batch(var, batch, seq_starts)
    else:
        s = sanitize_batch(var, batch, seq_starts)
        assert np.allclose(s.mask, expected)
Esempio n. 3
0
def test_sanitize_batch_contiguity():
    """Check ``sanitize_batch`` on transposed vs. untransposed 2x2 inputs.

    The transposed arrays (presumably non-C-contiguous, assuming ``AA``
    builds NumPy arrays) must trigger a ``RuntimeWarning``. The original
    arrays are sanitized outside any warning check. Both batches must
    produce shape ``(2, 1, 2, 2)``: 2 sequences of length 1 with sample
    shape (2, 2).
    """
    a1 = AA([[1, 2], [3, 4]])
    a2 = AA([[5, 6], [7, 8]])
    var = sequence.input_variable((2, 2), is_sparse=True)

    batch = [a1.T, a2.T]
    # Keep the warns-context limited to the call expected to warn.
    with pytest.warns(RuntimeWarning):
        b = sanitize_batch(var, batch)
    assert b.shape == (2, 1, 2, 2)

    batch = [a1, a2]
    b = sanitize_batch(var, batch)
    assert b.shape == (2, 1, 2, 2)
Esempio n. 4
0
def test_sanitize_batch_contiguity():
    """Exercise ``sanitize_batch`` with transposed and plain 2x2 arrays.

    Passing transposes (presumably non-C-contiguous if ``AA`` yields NumPy
    arrays) must emit a ``RuntimeWarning``. Passing the arrays as built is
    not checked for warnings. Either way the sanitized batch has shape
    ``(2, 1, 2, 2)``: 2 sequences of length 1 with sample shape (2, 2).
    """
    a1 = AA([[1, 2], [3, 4]])
    a2 = AA([[5, 6], [7, 8]])
    var = sequence.input_variable((2, 2), is_sparse=True)

    batch = [a1.T, a2.T]
    # Only the sanitize call belongs inside the warning check.
    with pytest.warns(RuntimeWarning):
        b = sanitize_batch(var, batch)
    assert b.shape == (2, 1, 2, 2)

    batch = [a1, a2]
    b = sanitize_batch(var, batch)
    assert b.shape == (2, 1, 2, 2)
Esempio n. 5
0
def test_sanitize_batch_sparse():
    """Sanitize a batch of two CSR sequences of different lengths.

    The first sequence has two rows of width 3. The second is built from a
    single 3-element row. The result is expected to have shape
    ``(2, 2, 3)``.
    """
    long_seq = csr([[1, 0, 2], [2, 3, 0]])
    short_seq = csr([5, 0, 1])
    sparse_var = sequence.input_variable(3, is_sparse=True)

    sanitized = sanitize_batch(sparse_var, [long_seq, short_seq])
    # 2 sequences, max sequence length 2, dimension 3
    assert sanitized.shape == (2, 2, 3)
Esempio n. 6
0
def test_sanitize_batch_sparse():
    """Check the batch shape ``sanitize_batch`` yields for ragged CSR input.

    Two sparse sequences are given: one with two rows, one with a single
    row, each of dimension 3. The expected result shape is ``(2, 2, 3)``.
    """
    sequences = [
        csr([[1, 0, 2], [2, 3, 0]]),
        csr([5, 0, 1]),
    ]

    input_var = sequence.input_variable(3, is_sparse=True)
    result = sanitize_batch(input_var, sequences)

    # (number of sequences, longest sequence length, feature dimension)
    assert result.shape == (2, 2, 3)