def default_flow(self):
    """Build this model's default flow as a ``nux.sequential`` chain.

    Layers, in order:
      1. ``nux.Logit(scale=None)``
      2. two identical stages, each a ``nux.OneByOneConv()`` followed by a
         ``reverse_flow``-wrapped ``CouplingLogitsticMixtureLogit`` with
         8 components, ``self.network_kwargs`` and ``use_condition=True``
      3. ``nux.UnitGaussianPrior()``

    Returns:
        The ``nux.sequential`` flow object (a new one on every call).
    """
    # NOTE(review): another ``default_flow`` is defined later in this file; if
    # both live in the same class, the later definition shadows this one — confirm.

    def conv_and_coupling():
        # Construct the conv first, then the coupling, then its reverse_flow
        # wrapper — the same order as writing them inline. Order is kept on
        # purpose in case nux assigns module names at construction time.
        conv = nux.OneByOneConv()
        # NOTE(review): "Logitstic" spelling presumably matches the nux API — confirm.
        coupling = nux.reverse_flow(
            nux.CouplingLogitsticMixtureLogit(
                n_components=8,
                network_kwargs=self.network_kwargs,
                use_condition=True,
            )
        )
        return [conv, coupling]

    layers = [nux.Logit(scale=None)]
    for _ in range(2):
        layers += conv_and_coupling()
    layers.append(nux.UnitGaussianPrior())
    return nux.sequential(*layers)
def q_ugx(self):
    """Return the ``q_ugx`` flow, constructing and caching it on first use.

    The flow is ``nux.sequential`` of:
      * a ``reverse_flow``-wrapped ``nux.LogisticMixtureLogit`` with
        8 components, ``with_affine_coupling=False`` and ``coupling=False``
      * a ``nux.ParametrizedGaussianPrior`` built from ``self.network_kwargs``
        and ``self.create_network``

    The result is stored on ``self._qugx``, so later calls return the same
    object. The name presumably stands for q(u | x) — confirm with callers.
    """
    # Cached already? EAFP form of the hasattr check (hasattr itself just
    # calls getattr and catches AttributeError).
    try:
        return self._qugx
    except AttributeError:
        pass

    # Keep this simple, but a bit more complicated than p(u|z).
    # Construction order (mixture -> reverse_flow -> prior) matches the
    # nested-call original, in case nux names modules as they are created.
    mixture = nux.LogisticMixtureLogit(
        n_components=8,
        with_affine_coupling=False,
        coupling=False,
    )
    reversed_mixture = nux.reverse_flow(mixture)
    prior = nux.ParametrizedGaussianPrior(
        network_kwargs=self.network_kwargs,
        create_network=self.create_network,
    )
    self._qugx = nux.sequential(reversed_mixture, prior)
    return self._qugx
def default_flow(self):
    """Build a default flow from repeated spline/conv units plus a prior.

    Structure:
      * ``nux.repeat`` of a unit factory (``n_repeats=3``), where each unit is
        ``nux.sequential(RationalQuadraticSpline(...), OneByOneConv())``
      * the repeated stack wrapped in ``nux.reverse_flow``
      * a ``nux.ParametrizedGaussianPrior`` built from ``self.network_kwargs``
        and ``self.create_network``

    Returns:
        The ``nux.sequential`` flow object (a new one on every call).
    """
    # NOTE(review): an earlier ``default_flow`` exists in this file; if both
    # are in the same class, this definition is the one that takes effect.

    def spline_then_conv():
        # One repeated unit: a conditioned coupling RationalQuadraticSpline
        # (K=8, condition_method="nin") followed by a 1x1 convolution.
        # Attributes are read from ``self`` on each call, exactly as before.
        spline = nux.RationalQuadraticSpline(
            K=8,
            network_kwargs=self.network_kwargs,
            create_network=self.create_network,
            use_condition=True,
            coupling=True,
            condition_method="nin",
        )
        return nux.sequential(spline, nux.OneByOneConv())

    # Keep construction order: repeat -> reverse_flow -> prior -> sequential.
    repeated_units = nux.repeat(spline_then_conv, n_repeats=3)
    reversed_stack = nux.reverse_flow(repeated_units)
    prior = nux.ParametrizedGaussianPrior(
        network_kwargs=self.network_kwargs,
        create_network=self.create_network,
    )
    return nux.sequential(reversed_stack, prior)