Ejemplo n.º 1
0
def test_svi_multi():
    """Run 50 SVI steps driven by a MixedMultiOptimizer.

    ``objects_loc`` is updated by a ``Newton`` optimizer (trust radius 1.0)
    and ``noise_scale`` by ``Adam`` (lr 0.1), using the differentiable loss
    of ``TraceEnum_ELBO(max_plate_nesting=2)``.  The loss is asserted to be
    finite at every step, so a diverging optimizer fails the test instead
    of passing silently.
    """
    args = make_args()
    args.assignment_grad = True
    detections = generate_data(args)

    pyro.clear_param_store()
    pyro.param('noise_scale',
               torch.tensor(args.init_noise_scale),
               constraint=constraints.positive)
    pyro.param('objects_loc', torch.randn(args.max_num_objects, 1))

    # Learn object_loc via Newton and noise_scale via Adam.
    elbo = TraceEnum_ELBO(max_plate_nesting=2)
    adam = Adam({'lr': 0.1})
    newton = Newton(trust_radii={'objects_loc': 1.0})
    optim = MixedMultiOptimizer([(['noise_scale'], adam),
                                 (['objects_loc'], newton)])
    for svi_step in range(50):
        # Trace only param sites so we know which params the loss touched.
        with poutine.trace(param_only=True) as param_capture:
            loss = elbo.differentiable_loss(model, guide, detections, args)
        # Without this check the test could only fail by raising; a NaN/inf
        # loss would otherwise be fed straight into optim.step unnoticed.
        assert torch.isfinite(loss).all(), \
            'non-finite loss at step {}'.format(svi_step)
        # Hand the optimizer the unconstrained tensors of the traced params.
        params = {
            name: pyro.param(name).unconstrained()
            for name in param_capture.trace.nodes
        }
        optim.step(loss, params)
        logger.debug(
            'step {: >2d}, loss = {:0.6f}, noise_scale = {:0.6f}'.format(
                svi_step, loss.item(),
                pyro.param('noise_scale').item()))
Ejemplo n.º 2
0
def test_multi_optimizer_overlap_error():
    """Constructing a MixedMultiOptimizer from overlapping groups must fail.

    Both groups list the parameter name "y", and construction is expected
    to raise ValueError.
    """
    name_groups = (["x", "y"], ["y", "z"])
    learning_rates = (0.1, 0.01)
    parts = [
        (names, pyro.optim.Adam({"lr": lr}))
        for names, lr in zip(name_groups, learning_rates)
    ]
    with pytest.raises(ValueError):
        MixedMultiOptimizer(parts)
Ejemplo n.º 3
0
def test_multi_optimizer_overlap_error():
    """Overlapping parameter groups ('y' appears twice) must raise ValueError."""
    fast_adam = pyro.optim.Adam({'lr': 0.1})
    slow_adam = pyro.optim.Adam({'lr': 0.01})
    parts = [(['x', 'y'], fast_adam), (['y', 'z'], slow_adam)]
    with pytest.raises(ValueError):
        MixedMultiOptimizer(parts)
Ejemplo n.º 4
0
def test_multi_optimizer_disjoint_ok():
    """Disjoint parameter groups are accepted: construction must not raise."""
    fast_adam = pyro.optim.Adam({'lr': 0.1})
    slow_adam = pyro.optim.Adam({'lr': 0.01})
    groups = [(['w', 'x'], fast_adam), (['y', 'z'], slow_adam)]
    MixedMultiOptimizer(groups)
Ejemplo n.º 5
0
import torch
from torch.distributions import constraints

import pyro
import pyro.distributions as dist
import pyro.optim
import pyro.poutine as poutine
from pyro.optim.multi import MixedMultiOptimizer, Newton, PyroMultiOptimizer, TorchMultiOptimizer
from tests.common import assert_equal

# Zero-argument constructors, one per optimizer configuration; the test
# calls each factory to build a fresh optimizer for its parametrized case.
FACTORIES = [
    # Pyro-wrapped Adam.
    lambda: PyroMultiOptimizer(pyro.optim.Adam({"lr": 0.05})),
    # torch.optim.Adam class plus its kwargs, wrapped by TorchMultiOptimizer.
    lambda: TorchMultiOptimizer(torch.optim.Adam, {"lr": 0.05}),
    # Newton with a trust radius configured only for "z".
    lambda: Newton(trust_radii={"z": 0.2}),
    # Mixed: Adam (explicitly wrapped) for "y", Newton for "x" and "z".
    lambda: MixedMultiOptimizer([
        (["y"], PyroMultiOptimizer(pyro.optim.Adam({"lr": 0.05}))),
        (["x", "z"], Newton()),
    ]),
    # Mixed again, but passing the bare pyro.optim.Adam for "y".
    lambda: MixedMultiOptimizer([
        (["y"], pyro.optim.Adam({"lr": 0.05})),
        (["x", "z"], Newton()),
    ]),
]


@pytest.mark.parametrize('factory', FACTORIES)
def test_optimizers(factory):
    optim = factory()

    def model(loc, cov):
        x = pyro.param("x", torch.randn(2))
        y = pyro.param("y", torch.randn(3, 2))
        z = pyro.param("z",
                       torch.randn(4, 2).abs(),
Ejemplo n.º 6
0
def test_multi_optimizer_disjoint_ok():
    """Constructing a MixedMultiOptimizer from disjoint groups must succeed.

    The two groups share no parameter name, so no exception is expected.
    """
    name_groups = (["w", "x"], ["y", "z"])
    learning_rates = (0.1, 0.01)
    parts = [
        (names, pyro.optim.Adam({"lr": lr}))
        for names, lr in zip(name_groups, learning_rates)
    ]
    MixedMultiOptimizer(parts)
Ejemplo n.º 7
0
import pyro.optim
import pyro.poutine as poutine
from pyro.optim.multi import (
    MixedMultiOptimizer,
    Newton,
    PyroMultiOptimizer,
    TorchMultiOptimizer,
)
from tests.common import assert_equal

# Zero-argument constructors, one per optimizer configuration under test.
# Each call builds brand-new optimizer objects for its parametrized case.
FACTORIES = [
    # Adam via the Pyro wrapper.
    lambda: PyroMultiOptimizer(
        pyro.optim.Adam({"lr": 0.05})
    ),
    # Adam given as a torch.optim class plus its kwargs.
    lambda: TorchMultiOptimizer(
        torch.optim.Adam, {"lr": 0.05}
    ),
    # Newton, trust radius set only for "z".
    lambda: Newton(
        trust_radii={"z": 0.2}
    ),
    # Adam (explicitly wrapped) on "y"; Newton on "x" and "z".
    lambda: MixedMultiOptimizer([
        (["y"], PyroMultiOptimizer(pyro.optim.Adam({"lr": 0.05}))),
        (["x", "z"], Newton()),
    ]),
    # Same split, with the bare pyro.optim.Adam for "y".
    lambda: MixedMultiOptimizer([
        (["y"], pyro.optim.Adam({"lr": 0.05})),
        (["x", "z"], Newton()),
    ]),
]


@pytest.mark.parametrize("factory", FACTORIES)
def test_optimizers(factory):
    optim = factory()

    def model(loc, cov):
        x = pyro.param("x", torch.randn(2))
        y = pyro.param("y", torch.randn(3, 2))
        z = pyro.param("z",
                       torch.randn(4, 2).abs(),