def p(self,
      observed: Optional[Mapping[str, TensorOrData]] = None,
      n_z: Optional[int] = None) -> tk.BayesianNet:
    """Build the generative network: latent ``z`` followed by ``x`` given ``z``.

    Args:
        observed: Forwarded unchanged to ``tk.BayesianNet(observed=...)``;
            presumably maps node names ('z', 'x') to observed values.
        n_z: Forwarded as ``n_samples`` when adding the latent node ``z``.

    Returns:
        The ``tk.BayesianNet`` containing node 'z' (``tk.UnitNormal`` of
        shape ``[1, self.config.z_dim]``, ``event_ndims=1``) and node 'x'
        (``tk.Bernoulli`` with logits ``self.px_logits(z.tensor)``,
        ``event_ndims=1``).
    """
    bn = tk.BayesianNet(observed=observed)

    # latent code drawn from the unit-normal prior p(z)
    prior = tk.UnitNormal([1, self.config.z_dim], event_ndims=1)
    latent = bn.add('z', prior, n_samples=n_z)

    # observation model: Bernoulli whose logits are decoded from z
    bn.add('x', tk.Bernoulli(logits=self.px_logits(latent.tensor), event_ndims=1))
    return bn
def q(self,
      x: T.Tensor,
      observed: Optional[Mapping[str, TensorOrData]] = None,
      n_z: Optional[int] = None) -> tk.BayesianNet:
    """Build the variational posterior network for ``z`` given ``x``.

    Args:
        x: Input tensor; cast to ``T.float32`` before feature extraction.
        observed: Forwarded unchanged to ``tk.BayesianNet(observed=...)``.
        n_z: Forwarded as ``n_samples`` when adding the node 'z'.

    Returns:
        The ``tk.BayesianNet`` holding node 'z', a ``tk.Normal`` with
        ``event_ndims=1`` whose mean and log-std are produced by
        ``self.qz_mean`` / ``self.qz_logstd`` from shared features.
    """
    bn = tk.BayesianNet(observed=observed)

    # shared hidden features feeding both posterior heads
    features = self.hx_for_qz(T.cast(x, dtype=T.float32))
    posterior = tk.Normal(
        mean=self.qz_mean(features),
        logstd=self.qz_logstd(features),
        event_ndims=1,
    )

    bn.add('z', posterior, n_samples=n_z)
    return bn
def q(self,
      x: T.Tensor,
      observed: Optional[Mapping[str, TensorOrData]] = None,
      n_z: Optional[int] = None) -> tk.BayesianNet:
    """Build a flow-augmented variational posterior network for ``z`` given ``x``.

    A diagonal ``tk.Normal`` base (``event_ndims=1``) is parameterized from
    features of ``x`` and then wrapped in ``tk.FlowDistribution`` with
    ``self.posterior_flow``.

    Args:
        x: Input tensor; cast to ``T.float32`` before feature extraction.
        observed: Forwarded unchanged to ``tk.BayesianNet(observed=...)``.
        n_z: Forwarded as ``n_samples`` when adding the node 'z'.

    Returns:
        The ``tk.BayesianNet`` holding node 'z' drawn from the flow
        distribution.
    """
    bn = tk.BayesianNet(observed=observed)

    features = self.hx_for_qz(T.cast(x, dtype=T.float32))
    mean = self.qz_mean(features)
    raw_logstd = self.qz_logstd(features)
    # only min_val is passed, so presumably this is a lower bound only on the
    # log-std (limit taken from config) -- confirm T.maybe_clip semantics
    logstd = T.maybe_clip(raw_logstd, min_val=self.config.qz_logstd_min)

    base = tk.Normal(mean=mean, logstd=logstd, event_ndims=1)
    flowed = tk.FlowDistribution(base, self.posterior_flow)

    bn.add('z', flowed, n_samples=n_z)
    return bn