Example #1
0
def test_fit():
    """Generator-style (nose) test of ``GPARRegressor.fit``.

    Each ``yield`` hands the runner a ``(check, *args)`` tuple to call.
    Covered here:

    * greedy search raising ``NotImplementedError``;
    * NaN-free stored outputs after fitting, including when one output is
      constant (zero variance);
    * round-tripping of the output transform and of the normalisation;
    * fitting with ``fix=False`` and with ``fix=True``.
    """
    reg = GPARRegressor(replace=False,
                        impute=False,
                        normalise_y=True,
                        transform_y=squishing_transform)
    x = np.linspace(0, 5, 10)
    y = reg.sample(x, p=2)

    # TODO: Remove this once greedy search is implemented.
    yield raises, NotImplementedError, lambda: reg.fit(x, y, greedy=True)

    # Test that data is correctly transformed if it has an output with zero
    # variance.
    reg.fit(x, y, iters=0)
    yield ok, (~B.isnan(reg.y)).numpy().all()
    y_pathological = y.copy()
    # Make the first output constant so that its variance is zero.
    y_pathological[:, 0] = 1
    reg.fit(x, y_pathological, iters=0)
    yield ok, (~B.isnan(reg.y)).numpy().all()

    # Test transformation and normalisation of outputs.
    z = B.linspace(-1, 1, 10, dtype=torch.float64)
    z = B.stack([z, 2 * z], axis=1)
    yield allclose, reg._untransform_y(reg._transform_y(z)), z
    yield allclose, reg._unnormalise_y(reg._normalise_y(z)), z

    # Test that fitting runs without issues.
    vs = reg.vs.detach()
    yield lambda x_, y_: reg.fit(x_, y_, fix=False), x, y
    # Reinstate the detached variables saved before the first fit. This
    # assumes the runner consumes this generator lazily, so the first fit
    # has already run by now (TODO confirm).
    reg.vs = vs
    # Fit on the yielded arguments rather than the closed-over ``x``/``y``,
    # matching the check above.
    yield lambda x_, y_: reg.fit(x_, y_, fix=True), x, y
Example #2
0
def test_condition_and_fit():
    """Check conditioning, output maps and fitting of ``GPARRegressor``.

    The test conditions on sampled data and verifies that:

    * the stored outputs have zero mean and unit standard deviation;
    * the stored outputs contain no NaNs when one output is constant;
    * the output transform and the normalisation invert correctly;
    * ``fit`` runs with both ``fix=False`` and ``fix=True``;
    * greedy search still raises ``NotImplementedError``.
    """
    reg = GPARRegressor(
        replace=False,
        impute=False,
        normalise_y=True,
        transform_y=squishing_transform,
    )
    x = np.linspace(0, 5, 10)
    y = reg.sample(x, p=2)

    # After conditioning, each output column should be standardised.
    reg.condition(x, y)
    expected_stats = (
        (B.mean, B.zeros(reg.p)),
        (B.std, B.ones(reg.p)),
    )
    for statistic, target in expected_stats:
        approx(statistic(reg.y, axis=0), target)

    # A constant first output has zero variance; standardising it must not
    # introduce NaNs into the stored outputs.
    y_constant_first = y.copy()
    y_constant_first[:, 0] = 1
    reg.condition(x, y_constant_first)
    assert (~B.isnan(reg.y)).numpy().all()

    # Each forward output map should be undone by its inverse.
    base = torch.linspace(-1, 1, 10, dtype=torch.float64)
    z = B.stack(base, 2 * base, axis=1)
    round_trips = (
        (reg._transform_y, reg._untransform_y),
        (reg._normalise_y, reg._unnormalise_y),
    )
    for forward, inverse in round_trips:
        allclose(inverse(forward(z)), z)

    # Snapshot the variables (detached copy) before the first fit and
    # reinstate them before the second.
    saved_vs = reg.vs.copy(detach=True)
    reg.fit(x, y, fix=False)
    reg.vs = saved_vs
    reg.fit(x, y, fix=True)

    # TODO: Remove this once greedy search is implemented.
    with pytest.raises(NotImplementedError):
        reg.fit(x, y, greedy=True)