예제 #1
0
파일: test_module.py 프로젝트: xidulu/pyro
def test_constraints(shape, constraint_):
    """Exercise a constrained ``PyroParam`` through set, overwrite, raw mutation and delete.

    ``shape`` and ``constraint_`` come from the caller (presumably pytest
    parametrization, which is not visible here). After every mutation the
    constrained view ``module.x`` must keep ``shape`` and satisfy ``constraint_``,
    while its backing storage ``module.x_unconstrained`` stays an ``nn.Parameter``.
    """

    def check_constrained_view(mod):
        # Invariants shared by the first two stages: storage type, view type,
        # shape, and constraint satisfaction.
        assert isinstance(mod.x_unconstrained, nn.Parameter)
        assert isinstance(mod.x, torch.Tensor)
        assert mod.x.shape == shape
        assert constraint_.check(mod.x).all()

    module = PyroModule()

    # Stage 1: register x as a constrained parameter initialised to 1e-4.
    module.x = PyroParam(torch.full(shape, 1e-4), constraint_)
    check_constrained_view(module)

    # Stage 2: assign a plain tensor to the already-registered name; the test
    # expects x to remain constrained and backed by an nn.Parameter.
    module.x = torch.randn(shape).exp() * 1e-6
    check_constrained_view(module)

    # Stage 3: overwrite the unconstrained storage in place with fresh normal
    # samples; the constrained view must equal transform_to(constraint_) of them.
    assert isinstance(module.x_unconstrained, torch.Tensor)
    raw = module.x_unconstrained.data.normal_()
    expected = transform_to(constraint_)(raw)
    assert_equal(module.x.data, expected)
    assert constraint_.check(module.x).all()

    # Stage 4: deleting x must unregister it and drop both attributes.
    del module.x
    assert 'x' not in module._pyro_params
    for attr_name in ('x', 'x_unconstrained'):
        assert not hasattr(module, attr_name)
예제 #2
0
파일: test_module.py 프로젝트: xidulu/pyro
def test_torch_serialize():
    """Check that a PyroModule survives a ``torch.save`` / ``torch.load`` round trip.

    The module holds one constrained ``PyroParam`` (``x``, positive) and one
    plain ``nn.Parameter`` (``y``). After saving, clearing the global Pyro param
    store, and reloading, the constrained value of ``x`` must match the original
    and the set of registered parameter names must be unchanged.
    """
    import inspect

    module = PyroModule()
    module.x = PyroParam(torch.tensor(1.234), constraints.positive)
    module.y = nn.Parameter(torch.randn(3))
    assert isinstance(module.x, torch.Tensor)

    # The whole module object is pickled, so reloading needs the full unpickler.
    # Recent PyTorch releases default ``torch.load`` to ``weights_only=True``,
    # which refuses arbitrary classes such as PyroModule. Older releases predate
    # the keyword and would forward it to ``pickle.load`` (TypeError), so only
    # pass it when the installed ``torch.load`` declares it.
    load_kwargs = {}
    if "weights_only" in inspect.signature(torch.load).parameters:
        load_kwargs["weights_only"] = False

    # Work around https://github.com/pytorch/pytorch/issues/27972
    with warnings.catch_warnings(), io.BytesIO() as f:
        warnings.filterwarnings("ignore", category=UserWarning)
        torch.save(module, f)
        # Presumably cleared so the reloaded module cannot rely on values
        # cached in the global param store.
        pyro.clear_param_store()
        f.seek(0)
        actual = torch.load(f, **load_kwargs)

    assert_equal(actual.x, module.x)
    actual_names = {name for name, _ in actual.named_parameters()}
    expected_names = {name for name, _ in module.named_parameters()}
    assert actual_names == expected_names