Example #1
0
def test_maybe_methods():
    """Exercise the ``maybe_get_*`` accessors before and after setup.

    Each accessor should yield ``None`` for something that has not been set
    yet, and the real value once it has been provided.
    """
    layer = Linear(5)

    # Input width: absent until set explicitly, then readable.
    assert layer.maybe_get_dim("nI") is None
    layer.set_dim("nI", 4)
    assert layer.maybe_get_dim("nI") == 4

    # An unknown ref name and the not-yet-initialized "W" both read as None.
    assert layer.maybe_get_ref("boo") is None
    assert layer.maybe_get_param("W") is None

    # After initialize(), the "W" parameter is available.
    layer.initialize()
    assert layer.maybe_get_param("W") is not None
def test_linear_dimensions_on_data():
    """Initialize a dimensionless ``Linear`` from mocked arrays.

    Checks that the input width is filled in from the data and that the
    label mock's ``max()`` is consulted (presumably to infer the output
    width -- confirm against ``Linear``'s init).
    """

    def fake_array(shape, ndim):
        # A numpy-spec'd mock that only carries the metadata Linear reads.
        arr = MagicMock(shape=shape, spec=numpy.ndarray)
        arr.ndim = ndim
        arr.dtype = "float32"
        return arr

    inputs = fake_array((5, 10), ndim=2)
    # NOTE(review): shape is 1-D but ndim is declared as 2 -- inconsistent
    # mock; kept as-is, confirm which attribute Linear actually relies on.
    labels = fake_array((8,), ndim=2)
    labels.max = MagicMock()

    layer = Linear()
    layer.initialize(inputs, labels)

    assert layer.get_dim("nI") is not None
    labels.max.assert_called_with()
Example #3
0
def get_model(W_values, b_values):
    """Build a ``NumpyOps``-backed ``Linear`` preloaded with given weights.

    The first two entries of ``W_values.shape`` become the layer's first
    two positional size arguments (presumably ``nO`` then ``nI`` -- confirm
    against ``Linear``'s signature). The layer is initialized, and then its
    "W" and "b" parameters are overwritten with the supplied values.
    """
    first_size = W_values.shape[0]
    second_size = W_values.shape[1]
    layer = Linear(first_size, second_size, ops=NumpyOps())
    layer.initialize()
    # Overwrite the freshly initialized parameters with the caller's values.
    for param_name, value in (("W", W_values), ("b", b_values)):
        layer.set_param(param_name, value)
    return layer