Ejemplo n.º 1
0
    def _setup_prototype(self, *args, **kwargs):
        """Create one location and one scale parameter per stochastic site.

        After delegating to the parent implementation, every stochastic node
        of ``self.prototype_trace`` is mapped to unconstrained space and used
        to initialize an entry in ``self.locs`` and ``self.scales``. The
        unconstrained event dimension of each site is recorded in
        ``self._event_dims``.
        """
        super()._setup_prototype(*args, **kwargs)

        self._event_dims = {}
        self.locs = PyroModule()
        self.scales = PyroModule()

        for site_name, node in self.prototype_trace.iter_stochastic_nodes():
            # Map the traced value into unconstrained space.
            with helpful_support_errors(node):
                to_unconstrained = biject_to(node["fn"].support).inv
                loc = to_unconstrained(node["value"].detach()).detach()

            # The unconstrained event_dim may differ from the constrained one
            # by however many dims the bijection added or removed.
            extra_dims = loc.dim() - node["value"].dim()
            unconstrained_event_dim = node["fn"].event_dim + extra_dims
            self._event_dims[site_name] = unconstrained_event_dim

            # Under subsampling, tile the initial value up to the full size.
            for frame in node["cond_indep_stack"]:
                target_size = getattr(frame, "full_size", frame.size)
                if target_size == frame.size:
                    continue
                loc = periodic_repeat(
                    loc, target_size, frame.dim - unconstrained_event_dim
                ).contiguous()

            scale = torch.full_like(loc, self._init_scale)

            loc_param = PyroParam(loc, constraints.real,
                                  unconstrained_event_dim)
            deep_setattr(self.locs, site_name, loc_param)

            scale_param = PyroParam(scale, self.scale_constraint,
                                    unconstrained_event_dim)
            deep_setattr(self.scales, site_name, scale_param)
Ejemplo n.º 2
0
 def __init__(self):
     """Populate the module with a scalar parameter, a constant-filled
     linear layer, and a nested PyroModule holding its own parameter."""
     super().__init__()
     self.x = nn.Parameter(torch.tensor(0.0))

     # Linear layer with weights set to 1 and biases set to 2.
     linear = torch.nn.Linear(2, 3)
     self.m = linear
     linear.weight.data.fill_(1.0)
     linear.bias.data.fill_(2.0)

     # Nested PyroModule carrying a single parameter.
     child = PyroModule()
     self.p = child
     self.p.x = nn.Parameter(torch.tensor(3.0))
Ejemplo n.º 3
0
 def __init__(self):
     """Populate the module with plain and constrained parameters, both
     directly and on a plain ``nn.Module`` child and a ``PyroModule`` child."""
     super().__init__()

     # Direct attributes: one plain parameter, one positive-constrained.
     self.x = nn.Parameter(torch.tensor(0.0))
     self.y = PyroParam(torch.tensor(1.0), constraint=constraints.positive)

     # Plain torch child module with a single parameter.
     self.m = nn.Module()
     self.m.u = nn.Parameter(torch.tensor(2.0))

     # Pyro child module with a plain and a constrained parameter.
     self.p = PyroModule()
     self.p.v = nn.Parameter(torch.tensor(3.0))
     self.p.w = PyroParam(torch.tensor(4.0), constraint=constraints.positive)
Ejemplo n.º 4
0
def test_torch_serialize():
    """Round-trip a PyroModule through ``torch.save`` / ``torch.load``.

    Checks that the constrained value of ``x`` and the set of parameter names
    survive serialization, even after the param store has been cleared.
    """
    module = PyroModule()
    module.x = PyroParam(torch.tensor(1.234), constraints.positive)
    module.y = nn.Parameter(torch.randn(3))
    assert isinstance(module.x, torch.Tensor)

    # Work around https://github.com/pytorch/pytorch/issues/27972
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=UserWarning)
        f = io.BytesIO()
        torch.save(module, f)
        pyro.clear_param_store()
        f.seek(0)
        # The whole module object was pickled, not just a state dict, so
        # full unpickling must be requested explicitly: newer PyTorch
        # releases default to weights_only=True, which rejects it. The
        # buffer was produced a few lines above, so it is trusted input.
        actual = torch.load(f, weights_only=False)

    assert_equal(actual.x, module.x)
    actual_names = {name for name, _ in actual.named_parameters()}
    expected_names = {name for name, _ in module.named_parameters()}
    assert actual_names == expected_names
Ejemplo n.º 5
0
def test_constraints(shape, constraint_):
    """Exercise the lifecycle of a constrained PyroParam attribute.

    Covers initial assignment, re-assignment with a raw tensor, in-place
    mutation of the unconstrained parameter, and deletion.
    """
    module = PyroModule()

    def check_constrained_view():
        # ``x`` is exposed as a tensor satisfying the constraint, backed by
        # an ``nn.Parameter`` named ``x_unconstrained``.
        assert isinstance(module.x, torch.Tensor)
        assert isinstance(module.x_unconstrained, nn.Parameter)
        assert module.x.shape == shape
        assert constraint_.check(module.x).all()

    # Initial assignment through PyroParam.
    module.x = PyroParam(torch.full(shape, 1e-4), constraint_)
    check_constrained_view()

    # Re-assignment with a plain tensor.
    module.x = torch.randn(shape).exp() * 1e-6
    check_constrained_view()

    # Mutating the unconstrained storage is reflected in the constrained view.
    unconstrained = module.x_unconstrained
    assert isinstance(unconstrained, torch.Tensor)
    noise = unconstrained.data.normal_()
    expected = transform_to(constraint_)(noise)
    assert_equal(module.x.data, expected)
    assert constraint_.check(module.x).all()

    # Deleting removes both the constrained and unconstrained attributes.
    del module.x
    assert 'x' not in module._pyro_params
    for attr in ('x', 'x_unconstrained'):
        assert not hasattr(module, attr)
Ejemplo n.º 6
0
def test_cache():
    """Check PyroModule attribute caching within and across ``__call__``s.

    Within one invocation, repeated attribute reads must return the very same
    object. Across invocations, constrained parameters and sampled values
    must be recomputed, i.e. be distinct objects.
    """
    class MyModule(PyroModule):
        def forward(self):
            # Gather twice inside a single call so caching can be observed.
            return [self.gather(), self.gather()]

        def gather(self):
            return {
                "a": self.a,
                "b": self.b,
                "c": self.c,
                "p.d": self.p.d,
                "p.e": self.p.e,
                "p.f": self.p.f,
            }

    module = MyModule()
    module.a = nn.Parameter(torch.tensor(0.))
    module.b = PyroParam(torch.tensor(1.), constraint=constraints.positive)
    module.c = PyroSample(dist.Normal(0, 1))
    module.p = PyroModule()
    module.p.d = nn.Parameter(torch.tensor(3.))
    module.p.e = PyroParam(torch.tensor(4.), constraint=constraints.positive)
    module.p.f = PyroSample(dist.Normal(0, 1))

    assert module._pyro_context is module.p._pyro_context

    # Check that results are cached with an invocation of .__call__().
    result1 = module()
    actual, expected = result1
    for key in ["a", "c", "p.d", "p.f"]:
        assert actual[key] is expected[key], key

    # Check that results are not cached across invocations of .__call__().
    # Compare the per-key values: the enclosing dicts are always distinct
    # objects, so comparing them would pass regardless of caching.
    result2 = module()
    for key in ["b", "c", "p.e", "p.f"]:
        assert result1[0][key] is not result2[0][key], key
Ejemplo n.º 7
0
def test_submodule_contains_torch_module():
    """Attaching a PyroModule that already owns a plain torch module to a
    parent PyroModule must not raise."""
    child = PyroModule()
    child.linear = nn.Linear(1, 1)

    parent = PyroModule()
    parent.child = child
Ejemplo n.º 8
0
def test_delete():
    """A PyroParam attribute can be deleted and then re-created with a new
    value under the same name."""
    module = PyroModule()

    # Create, then remove, the attribute.
    module.a = PyroParam(torch.tensor(1.))
    del module.a

    # Re-create it and confirm the new value is what is read back.
    module.a = PyroParam(torch.tensor(0.1))
    observed = module.a.detach()
    assert_equal(observed, torch.tensor(0.1))
Ejemplo n.º 9
0
    def _setup_prototype(self, *args, **kwargs):
        """Initialize guide parameters and a site ordering from the prototype trace.

        After delegating to the parent implementation, this method:

        1. Creates empty ``PyroModule`` containers (``locs``, ``scales``,
           ``scale_trils``, ``conds``, ``deps``) and shape bookkeeping dicts.
        2. Calls ``self._auto_config`` with the stochastic sites of
           ``self.prototype_trace``.
        3. Computes, per site, a flattened initial location in unconstrained
           space together with its batch shape and unconstrained event shape.
        4. Registers per-site parameters according to ``self.conditionals``
           and per-edge dependency modules according to ``self.dependencies``.
        5. Topologically sorts the sites into ``self._sorted_sites``.
        6. Clears every node of the prototype trace.

        :raises ValueError: if a non-callable conditional is not one of
            ``"delta"``, ``"normal"``, ``"mvn"``, or if a dependency is neither
            the string ``"linear"`` nor a callable.
        """
        super()._setup_prototype(*args, **kwargs)

        self.locs = PyroModule()
        self.scales = PyroModule()
        self.scale_trils = PyroModule()
        self.conds = PyroModule()
        self.deps = PyroModule()
        self._batch_shapes = {}
        self._unconstrained_event_shapes = {}
        sample_sites = OrderedDict(
            self.prototype_trace.iter_stochastic_nodes())
        # NOTE(review): self.conditionals and self.dependencies are read
        # below; presumably _auto_config is what finalizes them — confirm.
        self._auto_config(sample_sites, args, kwargs)

        # Collect unconstrained shapes.
        init_locs = {}
        numel = {}
        for name, site in sample_sites.items():
            with helpful_support_errors(site):
                init_loc = (biject_to(site["fn"].support).inv(
                    site["value"].detach()).detach())
            self._batch_shapes[name] = site["fn"].batch_shape
            # Everything to the right of the batch dims is the event shape
            # in unconstrained space.
            self._unconstrained_event_shapes[name] = init_loc.shape[
                len(site["fn"].batch_shape):]
            numel[name] = init_loc.numel()
            # Parameters are stored flattened to one dimension.
            init_locs[name] = init_loc.reshape(-1)

        # Initialize guide params.
        children = defaultdict(list)  # upstream site name -> downstream names
        num_pending = {}  # site name -> number of unresolved upstream sites
        for name, site in sample_sites.items():
            # Initialize location parameters.
            init_loc = init_locs[name]
            deep_setattr(self.locs, name, PyroParam(init_loc))

            # Initialize parameters of conditional distributions.
            conditional = self.conditionals[name]
            if callable(conditional):
                # A user-supplied callable is stored as-is.
                deep_setattr(self.conds, name, conditional)
            else:
                if conditional not in ("delta", "normal", "mvn"):
                    raise ValueError(
                        f"Unsupported conditional type: {conditional}")
                # Both "normal" and "mvn" get a scale vector ...
                if conditional in ("normal", "mvn"):
                    init_scale = torch.full_like(init_loc, self._init_scale)
                    deep_setattr(self.scales, name,
                                 PyroParam(init_scale, self.scale_constraint))
                # ... and "mvn" additionally gets a scale_tril matrix,
                # initialized via eye_like with side init_loc.numel().
                if conditional == "mvn":
                    init_scale_tril = eye_like(init_loc, init_loc.numel())
                    deep_setattr(
                        self.scale_trils,
                        name,
                        PyroParam(init_scale_tril, self.scale_tril_constraint),
                    )

            # Initialize dependencies on upstream variables.
            num_pending[name] = 0
            deps = PyroModule()
            deep_setattr(self.deps, name, deps)
            for upstream, dep in self.dependencies.get(name, {}).items():
                assert upstream in sample_sites
                children[upstream].append(name)
                num_pending[name] += 1
                if isinstance(dep, str) and dep == "linear":
                    # Bias-free linear map from the flattened upstream site
                    # to this flattened site, with weights starting at zero.
                    dep = torch.nn.Linear(numel[upstream],
                                          numel[name],
                                          bias=False)
                    dep.weight.data.zero_()
                elif not callable(dep):
                    raise ValueError(
                        f"Expected either the string 'linear' or a callable, but got {dep}"
                    )
                deep_setattr(deps, upstream, dep)

        # Topologically sort sites.
        # TODO should we choose a more optimal structure?
        self._sorted_sites = []
        while num_pending:
            # Pick the site with the fewest pending upstreams, breaking ties
            # by name so the resulting order is deterministic.
            name, count = min(num_pending.items(),
                              key=lambda kv: (kv[1], kv[0]))
            # If even the minimum has pending upstreams, there is a cycle.
            assert count == 0, f"cyclic dependency: {name}"
            del num_pending[name]
            for child in children[name]:
                num_pending[child] -= 1
            site = self._compress_site(sample_sites[name])
            self._sorted_sites.append((name, site))

        # Prune non-essential parts of the trace to save memory.
        for name, site in self.prototype_trace.nodes.items():
            site.clear()
Ejemplo n.º 10
0
 def _getattr(obj, attr):
     """Return ``obj.attr``, first creating it as an empty ``PyroModule``
     when it is missing or ``None``."""
     existing = getattr(obj, attr, None)
     if existing is None:
         setattr(obj, attr, PyroModule())
         # Read back through getattr rather than returning the new object
         # directly, so any attribute hooks on ``obj`` are honored.
         existing = getattr(obj, attr)
     return existing
Ejemplo n.º 11
0
    def _setup_prototype(self, *args, **kwargs) -> None:
        """Initialize guide parameters from the prototype trace and model dependencies.

        After delegating to the parent implementation, this method:

        1. Creates empty ``PyroModule`` containers (``locs``, ``scales``,
           ``white_vecs``, ``prec_sqrts``) and ordered bookkeeping dicts.
        2. Computes ``self.dependencies`` by running ``get_dependencies`` on
           the original model and keeping its ``"prior_dependencies"`` entry.
        3. Drops sites whose upstream sites are all observed.
        4. For each remaining sample site, records its compressed factor and
           plates and, for latent sites, registers location and scale
           parameters in unconstrained space.
        5. For each factor, registers a ``white_vec`` and ``prec_sqrt``
           parameter sized from the factor and its latent upstream sites.

        :raises ValueError: if a site's ``batch_shape`` does not match the
            shape implied by its plates.
        """
        super()._setup_prototype(*args, **kwargs)

        self.locs = PyroModule()
        self.scales = PyroModule()
        self.white_vecs = PyroModule()
        self.prec_sqrts = PyroModule()
        self._factors = OrderedDict()
        self._plates = OrderedDict()
        self._event_numel = OrderedDict()
        self._unconstrained_event_shapes = OrderedDict()

        # Trace model dependencies.
        model = self._original_model[0]
        # The stored reference is dropped once the model has been read.
        self._original_model = None
        self.dependencies = poutine.block(get_dependencies)(model, args, kwargs)[
            "prior_dependencies"
        ]

        # Eliminate observations with no upstream latents.
        # Iterate over a list copy because entries are deleted in the loop.
        for d, upstreams in list(self.dependencies.items()):
            if all(self.prototype_trace.nodes[u]["is_observed"] for u in upstreams):
                del self.dependencies[d]
                del self.prototype_trace.nodes[d]

        # Collect factors and plates.
        for d, site in self.prototype_trace.nodes.items():
            # Prune non-essential parts of the trace to save memory.
            # The dict stored in the trace is emptied; a shallow copy is kept
            # locally for the rest of this iteration.
            pruned_site, site = site, site.copy()
            pruned_site.clear()

            # Collect factors and plates.
            if site["type"] != "sample" or site_is_subsample(site):
                continue
            assert all(f.vectorized for f in site["cond_indep_stack"])
            self._factors[d] = self._compress_site(site)
            plates = frozenset(site["cond_indep_stack"])
            if site["fn"].batch_shape != _plates_to_shape(plates):
                raise ValueError(
                    f"Shape mismatch at site '{d}'. "
                    "Are you missing a pyro.plate() or .to_event()?"
                )
            if site["is_observed"]:
                # Break irrelevant observation plates.
                # Keep only plates shared with at least one upstream site.
                plates &= frozenset().union(
                    *(self._plates[u] for u in self.dependencies[d] if u != d)
                )
            self._plates[d] = plates

            # Create location-scale parameters, one per latent variable.
            if site["is_observed"]:
                # This may slightly overestimate, e.g. for Multinomial.
                self._event_numel[d] = site["fn"].event_shape.numel()
                # Account for broken irrelevant observation plates.
                for f in set(site["cond_indep_stack"]) - plates:
                    self._event_numel[d] *= f.size
                # Observed sites get no location/scale parameters.
                continue
            with helpful_support_errors(site):
                init_loc = biject_to(site["fn"].support).inv(site["value"]).detach()
            batch_shape = site["fn"].batch_shape
            # Everything to the right of the batch dims is the event shape
            # in unconstrained space.
            event_shape = init_loc.shape[len(batch_shape) :]
            self._unconstrained_event_shapes[d] = event_shape
            self._event_numel[d] = event_shape.numel()
            event_dim = len(event_shape)
            deep_setattr(self.locs, d, PyroParam(init_loc, event_dim=event_dim))
            deep_setattr(
                self.scales,
                d,
                PyroParam(
                    torch.full_like(init_loc, self._init_scale),
                    constraint=self.scale_constraint,
                    event_dim=event_dim,
                ),
            )

        # Create parameters for dependencies, one per factor.
        for d, site in self._factors.items():
            # u_size: total flattened size of all latent upstream sites,
            # including any of their plates that are not plates of d.
            u_size = 0
            for u in self.dependencies[d]:
                if not self._factors[u]["is_observed"]:
                    broken_shape = _plates_to_shape(self._plates[u] - self._plates[d])
                    u_size += broken_shape.numel() * self._event_numel[u]
            d_size = self._event_numel[d]
            if site["is_observed"]:
                d_size = min(d_size, u_size)  # just an optimization
            batch_shape = _plates_to_shape(self._plates[d])

            # Create parameters of each Gaussian factor.
            # NOTE(review): init_loc here is whatever the previous loop last
            # bound; it is used only as a dtype/device template. It looks
            # like it would be unbound if no latent site was processed
            # above — confirm.
            white_vec = init_loc.new_zeros(batch_shape + (d_size,))
            # We initialize with noise to avoid singular gradient.
            prec_sqrt = torch.rand(
                batch_shape + (u_size, d_size),
                dtype=init_loc.dtype,
                device=init_loc.device,
            )
            # Shift uniform [0, 1) noise to [-0.5, 0.5), then scale it.
            prec_sqrt.sub_(0.5).mul_(self._init_scale)
            if not site["is_observed"]:
                # Initialize the [d,d] block to the identity matrix.
                prec_sqrt.diagonal(dim1=-2, dim2=-1).fill_(1)
            deep_setattr(self.white_vecs, d, PyroParam(white_vec, event_dim=1))
            deep_setattr(self.prec_sqrts, d, PyroParam(prec_sqrt, event_dim=2))