def test_constraints(shape, constraint_):
    """Exercise the PyroParam lifecycle on a PyroModule.

    Checks that a constrained parameter exposes a constrained view ``x``
    backed by an unconstrained ``nn.Parameter`` ``x_unconstrained``, that
    assigning a new constrained value keeps both views consistent, that
    mutating the unconstrained storage is reflected through the constraint
    transform, and that ``del`` removes every trace of the parameter.
    """
    mod = PyroModule()

    # Initial creation from a PyroParam descriptor.
    mod.x = PyroParam(torch.full(shape, 1e-4), constraint_)
    assert isinstance(mod.x, torch.Tensor)
    assert isinstance(mod.x_unconstrained, nn.Parameter)
    assert mod.x.shape == shape
    assert constraint_.check(mod.x).all()

    # Re-assigning a plain tensor updates the constrained value in place.
    mod.x = torch.randn(shape).exp() * 1e-6
    assert isinstance(mod.x_unconstrained, nn.Parameter)
    assert isinstance(mod.x, torch.Tensor)
    assert mod.x.shape == shape
    assert constraint_.check(mod.x).all()

    # Mutating the unconstrained storage must flow through transform_to.
    assert isinstance(mod.x_unconstrained, torch.Tensor)
    raw = mod.x_unconstrained.data.normal_()
    assert_equal(mod.x.data, transform_to(constraint_)(raw))
    assert constraint_.check(mod.x).all()

    # Deleting the constrained attribute removes both attributes.
    del mod.x
    assert 'x' not in mod._pyro_params
    assert not hasattr(mod, 'x')
    assert not hasattr(mod, 'x_unconstrained')
def test_torch_serialize():
    """Round-trip a PyroModule through torch.save/torch.load.

    Verifies that both a constrained PyroParam and a plain nn.Parameter
    survive serialization: the constrained value matches and the set of
    named parameters is preserved, even after clearing the param store.
    """
    original = PyroModule()
    original.x = PyroParam(torch.tensor(1.234), constraints.positive)
    original.y = nn.Parameter(torch.randn(3))
    assert isinstance(original.x, torch.Tensor)

    # Work around https://github.com/pytorch/pytorch/issues/27972
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=UserWarning)
        buffer = io.BytesIO()
        torch.save(original, buffer)
        pyro.clear_param_store()
        buffer.seek(0)
        restored = torch.load(buffer)

    assert_equal(restored.x, original.x)
    restored_names = {name for name, _ in restored.named_parameters()}
    original_names = {name for name, _ in original.named_parameters()}
    assert restored_names == original_names