def test_shapes(shape, stein_kernel):
    """
    Check that SVGD particles keep the expected shapes and initial values.

    Two sites with sharply peaked distributions (scale ``1.0e-8``) are built
    around known means, plus one scalar site. SVGD is stepped twice with a
    learning rate of zero, after which every particle of ``z1`` and ``z2``
    must have the expected shape and match its initial mean.
    """
    pyro.clear_param_store()
    shape1 = (5, ) + shape
    shape2 = shape + (6, )

    # Distinct, deterministic entries so that any reshaping mistake shows up.
    flat1 = torch.arange(_product(shape1)).double()
    mean_init1 = flat1.reshape(shape1) / 100.0
    flat2 = torch.arange(_product(shape2)).double()
    mean_init2 = flat2.reshape(shape2)

    def model():
        z1_dist = dist.LogNormal(mean_init1, 1.0e-8).to_event(len(shape1))
        pyro.sample("z1", z1_dist)
        pyro.sample("scalar", dist.Normal(0.0, 1.0))
        z2_dist = dist.Normal(mean_init2, 1.0e-8).to_event(len(shape2))
        pyro.sample("z2", z2_dist)

    num_particles = 7
    kernel = stein_kernel()
    optimizer = Adam({"lr": 0.0})
    svgd = SVGD(model, kernel, optimizer, num_particles, 0)

    for _ in range(2):
        svgd.step()

    named = svgd.get_named_particles()
    z1_particles = named['z1']
    z2_particles = named['z2']
    assert z1_particles.shape == (num_particles, ) + shape1
    assert z2_particles.shape == (num_particles, ) + shape2

    # z1 is LogNormal, so its particles are compared against exp(mean).
    expected_z1 = mean_init1.exp()
    for idx in range(num_particles):
        assert_equal(z1_particles[idx, ...], expected_z1, prec=1.0e-6)
        assert_equal(z2_particles[idx, ...], mean_init2, prec=1.0e-6)
def _setup_prototype(self, *args, **kwargs):
    """
    Run the parent prototype setup, then record per-site metadata.

    For every stochastic node of ``self.prototype_trace`` this stores the
    shape of its unconstrained value and its ``cond_indep_stack``, and sets
    ``self.latent_dim`` to the total number of unconstrained elements.

    :raises RuntimeError: if the total latent dimension is zero.
    """
    super(AutoContinuous, self)._setup_prototype(*args, **kwargs)

    # Bind the dicts to ``self`` up front; they are filled in place below.
    shapes = self._unconstrained_shapes = {}
    stacks = self._cond_indep_stacks = {}

    for name, site in self.prototype_trace.iter_stochastic_nodes():
        # Shape of the unconstrained value, which may differ from the
        # shape of the constrained value.
        to_unconstrained = biject_to(site["fn"].support).inv
        shapes[name] = to_unconstrained(site["value"]).shape

        # Independence context of the site.
        stacks[name] = site["cond_indep_stack"]

    self.latent_dim = sum(map(_product, shapes.values()))
    if self.latent_dim == 0:
        guide_name = type(self).__name__
        raise RuntimeError('{} found no latent variables; Use an empty guide instead'.format(guide_name))
def _unpack_latent(self, latent):
    """
    Split a packed latent tensor into per-site unconstrained values.

    Walks the stochastic nodes of ``self.prototype_trace`` and yields
    tuples of the form::

        (site, unconstrained_value)

    where ``unconstrained_value`` is the site's slice of the last dimension
    of ``latent``, viewed with the site's unconstrained shape broadcast
    against the leading dimensions of ``latent``.
    """
    # Leading dims belong to plates outside of _setup_prototype,
    # e.g. parallel particles.
    outer_shape = latent.shape[:-1]
    offset = 0
    for name, site in self.prototype_trace.iter_stochastic_nodes():
        value_shape = site["value"].shape
        base_shape = self._unconstrained_shapes[name]
        numel = _product(base_shape)

        # Adjust the event dim by the rank difference between the
        # unconstrained and constrained values.
        event_dim = site["fn"].event_dim + len(base_shape) - len(value_shape)
        padded_outer = outer_shape + (1,) * event_dim
        target_shape = broadcast_shape(base_shape, padded_outer)

        chunk = latent[..., offset:offset + numel]
        yield site, chunk.view(target_shape)
        offset += numel

    # Skip the consistency check while tracing state is active.
    if not torch._C._get_tracing_state():
        assert offset == latent.size(-1)