コード例 #1
0
ファイル: test_svgd.py プロジェクト: zeta1999/pyro
def test_shapes(shape, stein_kernel):
    """SVGD with a zero learning rate returns correctly shaped particles at the model means.

    The model has two tensor-valued sites whose shapes extend ``shape`` on
    the left (``z1``) and on the right (``z2``), plus a scalar site. After
    two SVGD steps the named particles must carry a leading particle
    dimension. Every particle must also match its site's mean, to 1e-6
    precision. For the LogNormal site ``z1`` that mean is exponentiated.
    """
    pyro.clear_param_store()
    left_shape = (5,) + shape
    right_shape = shape + (6,)

    # Scaled by 1/100, presumably to keep exp(z1_mean) moderate.
    z1_mean = torch.arange(_product(left_shape)).double().reshape(left_shape) / 100.0
    z2_mean = torch.arange(_product(right_shape)).double().reshape(right_shape)

    def model():
        # The 1e-8 scales make each tensor site essentially deterministic at its mean.
        pyro.sample("z1",
                    dist.LogNormal(z1_mean, 1.0e-8).to_event(len(left_shape)))
        pyro.sample("scalar", dist.Normal(0.0, 1.0))
        pyro.sample("z2",
                    dist.Normal(z2_mean, 1.0e-8).to_event(len(right_shape)))

    num_particles = 7
    # The trailing positional 0 is passed straight through to SVGD (TODO confirm its meaning).
    svgd = SVGD(model, stein_kernel(), Adam({"lr": 0.0}), num_particles, 0)
    for _ in range(2):
        svgd.step()

    named = svgd.get_named_particles()
    z1_particles = named['z1']
    assert z1_particles.shape == (num_particles,) + left_shape
    z2_particles = named['z2']
    assert z2_particles.shape == (num_particles,) + right_shape

    # A LogNormal sample is exp of the underlying Normal, so z1 is compared against exp(mean).
    z1_expected = z1_mean.exp()
    for idx in range(num_particles):
        assert_equal(z1_particles[idx, ...], z1_expected, prec=1.0e-6)
        assert_equal(z2_particles[idx, ...], z2_mean, prec=1.0e-6)
コード例 #2
0
ファイル: guides.py プロジェクト: jamestwebber/pyro
    def _setup_prototype(self, *args, **kwargs):
        """Prepare the prototype and record per-site metadata for the packed latent.

        First delegates to the parent class's ``_setup_prototype``. Then, for
        every stochastic site in ``self.prototype_trace``, it records the
        following, keyed by site name:

        * ``self._unconstrained_shapes``: the shape of the site's value after
          mapping it through ``biject_to(site["fn"].support).inv``. This may
          differ from the constrained value's shape.
        * ``self._cond_indep_stacks``: the site's ``cond_indep_stack``.

        Finally sets ``self.latent_dim`` to the total number of unconstrained
        elements across all sites.

        Raises:
            RuntimeError: if ``latent_dim`` is 0, i.e. there is nothing to pack.
        """
        super(AutoContinuous, self)._setup_prototype(*args, **kwargs)
        self._unconstrained_shapes = {}
        self._cond_indep_stacks = {}
        for site_name, site_info in self.prototype_trace.iter_stochastic_nodes():
            # Map the constrained sample to unconstrained space just to learn its shape.
            to_unconstrained = biject_to(site_info["fn"].support).inv
            self._unconstrained_shapes[site_name] = to_unconstrained(site_info["value"]).shape
            self._cond_indep_stacks[site_name] = site_info["cond_indep_stack"]

        self.latent_dim = sum(map(_product, self._unconstrained_shapes.values()))
        if self.latent_dim == 0:
            guide_name = type(self).__name__
            raise RuntimeError('{} found no latent variables; Use an empty guide instead'.format(guide_name))
コード例 #3
0
ファイル: guides.py プロジェクト: jamestwebber/pyro
    def _unpack_latent(self, latent):
        """
        Unpacks a packed latent tensor, iterating over tuples of the form::

            (site, unconstrained_value)

        The last dimension of ``latent`` is read as a concatenation of
        flattened unconstrained values. There is one segment per stochastic
        site of ``self.prototype_trace``, in ``iter_stochastic_nodes()``
        order, and each segment is ``_product(self._unconstrained_shapes[name])``
        entries long. Any leading dimensions of ``latent`` (``batch_shape``)
        are broadcast with each site's unconstrained shape. They are aligned
        to the left of that site's event dimensions.

        This is a generator: each segment is reshaped with ``.view`` and
        yielded lazily. Once every site has been yielded, it asserts that
        the whole last dimension was consumed. The assertion is skipped
        while JIT tracing.
        """
        batch_shape = latent.shape[:-1]  # for plates outside of _setup_prototype, e.g. parallel particles
        pos = 0  # running offset into the last dimension of ``latent``
        for name, site in self.prototype_trace.iter_stochastic_nodes():
            constrained_shape = site["value"].shape
            unconstrained_shape = self._unconstrained_shapes[name]
            size = _product(unconstrained_shape)
            # Event rank in unconstrained space: the distribution's event_dim,
            # adjusted by any rank change between constrained and unconstrained shapes.
            event_dim = site["fn"].event_dim + len(unconstrained_shape) - len(constrained_shape)
            # Pad batch_shape with event_dim ones so that, under right-aligned
            # broadcasting, it lines up just left of the event dimensions.
            unconstrained_shape = broadcast_shape(unconstrained_shape,
                                                  batch_shape + (1,) * event_dim)
            # .view (not reshape): the result shares storage with ``latent``.
            unconstrained_value = latent[..., pos:pos + size].view(unconstrained_shape)
            yield site, unconstrained_value
            pos += size
        # NOTE(review): presumably skipped under tracing because the size check
        # would not be a plain Python comparison there; confirm.
        if not torch._C._get_tracing_state():
            assert pos == latent.size(-1)