Ejemplo n.º 1
0
def test_DataGenerator_workers():
    """Tests probflow.data.DataGenerator w/ multiple worker processes

    Fits a fresh LinearRegression once with a single worker and once with
    four workers. The test passes if neither fit raises.
    """

    # Noise-free linear data with 3 features
    inputs = np.random.randn(100, 3).astype('float32')
    coefs = np.random.randn(3, 1).astype('float32')
    targets = inputs @ coefs

    # Same training run, varying only the number of data-loading workers
    for n_workers in (1, 4):
        model = pf.LinearRegression(3)
        model.fit(inputs, targets, batch_size=10, epochs=10,
                  num_workers=n_workers)
Ejemplo n.º 2
0
def test_example_linear_regression_multiple():
    """Tests example_linear_regression multiple linear regression

    Generates synthetic data with several input features and checks that a
    LinearRegression with a matching input dimension can be fit to it.
    """

    # Generate data: y = x @ w + b + noise, with D features.
    # (Previously x and y were never defined, so this test raised NameError.)
    N = 100
    D = 3
    x = np.random.randn(N, D).astype('float32')
    w = np.random.randn(D, 1).astype('float32')
    b = np.random.randn()
    noise = np.random.randn(N, 1)
    y = (x @ w + b + noise).astype('float32')

    # Create and fit the model; pass D explicitly, as the other tests do
    model = pf.LinearRegression(D)
    model.fit(x, y)
Ejemplo n.º 3
0
def test_dumps_and_loads_before_fitting():
    """An unfitted model survives a dumps()/loads() round trip.

    The restored object must be a distinct LinearRegression whose posterior
    means match the original's for every parameter.
    """
    original = pf.LinearRegression(7)
    restored = pf.loads(original.dumps())

    assert isinstance(restored, pf.LinearRegression)
    assert original is not restored

    means_original = original.posterior_mean()
    means_restored = restored.posterior_mean()
    for param_name, value in means_original.items():
        assert isclose(value, means_restored[param_name])
Ejemplo n.º 4
0
def test_dumps_and_loads_after_fitting():
    """A briefly fitted model survives a dumps()/loads() round trip.

    After two training epochs, the restored object must be a distinct
    LinearRegression whose posterior means match the original's.
    """
    original = pf.LinearRegression(7)
    features, targets = get_test_data(1024, 7)
    original.fit(features, targets, epochs=2)

    restored = pf.loads(original.dumps())
    assert isinstance(restored, pf.LinearRegression)
    assert original is not restored

    means_original = original.posterior_mean()
    means_restored = restored.posterior_mean()
    for param_name, value in means_original.items():
        assert isclose(value, means_restored[param_name])
Ejemplo n.º 5
0
def test_dump_and_load_before_fitting(tmpdir):
    """An unfitted model survives a save()/load() round trip via a file.

    ``tmpdir`` is pytest's per-test temporary directory fixture.
    """
    original = pf.LinearRegression(7)
    path = str(tmpdir.join("test_model.pkl"))
    original.save(path)
    restored = pf.load(path)

    assert isinstance(restored, pf.LinearRegression)
    assert original is not restored

    means_original = original.posterior_mean()
    means_restored = restored.posterior_mean()
    for param_name, value in means_original.items():
        assert isclose(value, means_restored[param_name])
Ejemplo n.º 6
0
def test_dump_and_load_after_fitting(tmpdir):
    """A briefly fitted model survives a save()/load() round trip via a file.

    ``tmpdir`` is pytest's per-test temporary directory fixture. After two
    training epochs, the reloaded model's posterior means must match.
    """
    original = pf.LinearRegression(7)
    features, targets = get_test_data(1024, 7)
    original.fit(features, targets, epochs=2)

    path = str(tmpdir.join("test_model.pkl"))
    original.save(path)
    restored = pf.load(path)

    assert isinstance(restored, pf.LinearRegression)
    assert original is not restored

    means_original = original.posterior_mean()
    means_restored = restored.posterior_mean()
    for param_name, value in means_original.items():
        assert isclose(value, means_restored[param_name])
Ejemplo n.º 7
0
def test_linear_regression_times():
    """Benchmark LinearRegression fit time over a grid of settings.

    For every combination of the module-level ``ns``, ``ds``, ``backends``
    and ``eagers`` values, fit a model briefly to warm up, then time a fit
    of ``EPOCHS`` epochs. A table of the results is printed.

    NOTE(review): ``pf.set_backend`` is called inside the loop and never
    restored, so the global backend remains whatever was iterated last.
    Tests run afterwards may be affected.
    """

    times = []

    for N, D, backend, eager in product(ns, ds, backends, eagers):
        pf.set_backend(backend)
        model = pf.LinearRegression(D)
        x, y = get_data(N, D)

        # Warm-up fit so that compilation time is not included in the timing
        model.fit(x, y, epochs=2, eager=eager)

        # perf_counter is monotonic and high-resolution, unlike time.time(),
        # which tracks the wall clock and can jump (e.g. NTP adjustments)
        t0 = time.perf_counter()
        model.fit(x, y, epochs=EPOCHS, eager=eager)
        t1 = time.perf_counter()

        times.append({
            "N": N,
            "D": D,
            "backend": backend,
            "eager": eager,
            "seconds": t1 - t0,
        })

    df = pd.DataFrame.from_records(times)
    print(df)
Ejemplo n.º 8
0
def test_linear_regression():
    """Test that a linear regression recovers the true parameters

    Synthetic data is drawn from a known linear model with Gaussian noise.
    After fitting, the posterior credible interval of the weights, the bias
    and the noise std must each contain the value that generated the data.
    """

    # Seed both the NumPy and TensorFlow RNGs for a reproducible run
    np.random.seed(1234)
    tf.random.set_seed(1234)

    # Draw the ground-truth parameters and the noisy targets.
    # The order of the np.random calls determines the seeded values.
    n_samples, n_features = 1000, 5
    x = np.random.randn(n_samples, n_features).astype('float32')
    true_w = np.random.randn(n_features, 1)
    true_b = np.random.randn()
    true_std = np.exp(np.random.randn())
    eps = true_std * np.random.randn(n_samples, 1)
    y = (x @ true_w + true_b + eps).astype('float32')

    # Fit the model
    model = pf.LinearRegression(n_features)
    model.fit(x, y, batch_size=100, epochs=1000, learning_rate=1e-2)

    # Weights: every component must lie strictly inside its interval
    w_lo, w_hi = model.posterior_ci('weights')
    assert np.all(w_lo < true_w)
    assert np.all(w_hi > true_w)

    # Bias
    b_lo, b_hi = model.posterior_ci('bias')
    assert b_lo < true_b
    assert b_hi > true_b

    # Observation noise standard deviation
    s_lo, s_hi = model.posterior_ci('std')
    assert s_lo < true_std
    assert s_hi > true_std