コード例 #1
0
ファイル: test_models.py プロジェクト: uri-granta/trieste
def test_sparse_variational_optimize_with_defaults() -> None:
    x_observed = np.linspace(0, 100, 100).reshape((-1, 1))
    y_observed = _3x_plus_gaussian_noise(x_observed)
    data = x_observed, y_observed
    dataset = Dataset(*data)
    optimizer = create_optimizer(tf.optimizers.Adam(), dict(max_iter=20))
    model = SparseVariational(svgp_model(x_observed, y_observed), optimizer=optimizer)
    loss = model.model.training_loss(data)
    model.optimize(dataset)
    assert model.model.training_loss(data) < loss
コード例 #2
0
ファイル: test_models.py プロジェクト: uri-granta/trieste
def test_sparse_variational_optimize(batcher: DatasetTransformer, compile: bool) -> None:
    x_observed = np.linspace(0, 100, 100).reshape((-1, 1))
    y_observed = _3x_plus_gaussian_noise(x_observed)
    data = x_observed, y_observed
    dataset = Dataset(*data)

    optimizer = create_optimizer(
        tf.optimizers.Adam(),
        dict(max_iter=10, batch_size=10, dataset_builder=batcher, compile=compile),
    )
    model = SparseVariational(svgp_model(x_observed, y_observed), optimizer=optimizer)
    loss = model.model.training_loss(data)
    model.optimize(dataset)
    assert model.model.training_loss(data) < loss
コード例 #3
0
ファイル: test_models.py プロジェクト: uri-granta/trieste
def test_sparse_variational_update_updates_num_data() -> None:
    model = SparseVariational(
        svgp_model(tf.zeros([1, 4]), tf.zeros([1, 1])),
    )
    model.update(Dataset(tf.zeros([5, 4]), tf.zeros([5, 1])))
    assert model.model.num_data == 5
コード例 #4
0
ファイル: test_models.py プロジェクト: uri-granta/trieste
def test_sparse_variational_update_raises_for_invalid_shapes(new_data: Dataset) -> None:
    model = SparseVariational(
        svgp_model(tf.zeros([1, 4]), tf.zeros([1, 1])),
    )
    with pytest.raises(ValueError):
        model.update(new_data)
コード例 #5
0
ファイル: test_models.py プロジェクト: uri-granta/trieste
def test_sparse_variational_model_attribute() -> None:
    model = svgp_model(*mock_data())
    sv = SparseVariational(model)
    assert sv.model is model