def test_maybe_methods():
    """Check that the ``maybe_get_*`` accessors yield None until a value exists.

    A ``Linear`` layer built from a single positional size has no ``nI``
    dimension and no ``W`` parameter yet. The dimension appears once it is
    set explicitly, and ``W`` appears once the layer is initialized.
    """
    layer = Linear(5)

    # Dimension lookups: absent at first, then the value that was set.
    assert layer.maybe_get_dim("nI") is None
    layer.set_dim("nI", 4)
    assert layer.maybe_get_dim("nI") == 4

    # An unknown ref name and a not-yet-allocated param both come back as None.
    assert layer.maybe_get_ref("boo") is None
    assert layer.maybe_get_param("W") is None

    # Initialization allocates the weight parameter.
    layer.initialize()
    assert layer.maybe_get_param("W") is not None
def test_linear_dimensions_on_data():
    """Check that ``Linear.initialize`` fills in ``nI`` from sample data.

    Both the inputs and the labels are ``MagicMock`` stand-ins specced as
    ``numpy.ndarray``. After initialization the layer must report an ``nI``
    dimension, and ``max()`` must have been called with no arguments on the
    label mock.
    """

    def fake_array(shape):
        # Mock specced as ndarray, with ndim/dtype pinned as the test expects.
        arr = MagicMock(shape=shape, spec=numpy.ndarray)
        arr.ndim = 2
        arr.dtype = "float32"
        return arr

    inputs = fake_array((5, 10))
    # NOTE(review): labels declare a 1-D shape but ndim=2. The y.max() assertion
    # below relies on how Linear infers its output width from labels (not
    # visible here); confirm this shape/ndim pairing is intentional.
    labels = fake_array((8,))
    labels.max = MagicMock()

    layer = Linear()
    layer.initialize(inputs, labels)

    assert layer.get_dim("nI") is not None
    labels.max.assert_called_with()
def get_model(W_values, b_values):
    """Return an initialized ``Linear`` layer with fixed ``W`` and ``b`` values.

    The layer's two positional size arguments are ``W_values.shape[0]`` and
    ``W_values.shape[1]``, and it uses ``NumpyOps``. After ``initialize()``,
    the ``W`` and ``b`` parameters are overwritten with the given arrays.
    """
    n_rows = W_values.shape[0]
    n_cols = W_values.shape[1]
    layer = Linear(n_rows, n_cols, ops=NumpyOps())
    layer.initialize()
    # Replace the freshly initialized parameters with the caller's arrays.
    for param_name, param_value in (("W", W_values), ("b", b_values)):
        layer.set_param(param_name, param_value)
    return layer