class GBN(Layer):
    '''Gaussian belief network: a latent-variable model with a Gaussian
    prior over h, a recognition network q(h|x) (posterior) and a
    generation network p(y|h) (conditional).'''

    def __init__(self, dim_in, dim_h, posterior=None, conditional=None,
                 prior=None, name='gbn', **kwargs):
        self.dim_in = dim_in
        self.dim_h = dim_h
        self.posterior = posterior
        self.conditional = conditional
        self.prior = prior

        kwargs = init_weights(self, **kwargs)
        kwargs = init_rngs(self, **kwargs)
        super(GBN, self).__init__(name=name)

    @staticmethod
    def mlp_factory(dim_h, dims, distributions, prototype=None,
                    recognition_net=None, generation_net=None):
        '''Builds the recognition and generation MLPs from config dicts.'''
        mlps = {}

        if recognition_net is not None:
            t = recognition_net.get('type', None)
            if t == 'mmmlp':
                raise NotImplementedError()
            elif t == 'lfmlp':
                recognition_net['prototype'] = prototype
            input_name = recognition_net.get('input_layer')
            recognition_net['distribution'] = 'gaussian'
            recognition_net['dim_in'] = dims[input_name]
            recognition_net['dim_out'] = dim_h
            posterior = resolve_mlp(t).factory(**recognition_net)
            mlps['posterior'] = posterior

        if generation_net is not None:
            output_name = generation_net['output']
            generation_net['dim_in'] = dim_h
            t = generation_net.get('type', None)
            if t == 'mmmlp':
                for out in generation_net['graph']['outputs']:
                    generation_net['graph']['outs'][out] = dict(
                        dim=dims[out], distribution=distributions[out])
                conditional = resolve_mlp(t).factory(**generation_net)
            else:
                if t == 'lfmlp':
                    generation_net['filter_in'] = False
                    generation_net['prototype'] = prototype
                generation_net['dim_out'] = dims[output_name]
                generation_net['distribution'] = distributions[output_name]
                conditional = resolve_mlp(t).factory(**generation_net)
            mlps['conditional'] = conditional

        return mlps

    def set_params(self):
        self.params = OrderedDict()
        # Default components: Gaussian prior, Gaussian recognition MLP and
        # a binomial (Bernoulli) generation MLP.
        if self.prior is None:
            self.prior = Gaussian(self.dim_h)
        if self.posterior is None:
            self.posterior = MLP(self.dim_in, self.dim_h, dim_hs=[],
                                 rng=self.rng, trng=self.trng,
                                 h_act='T.nnet.sigmoid',
                                 distribution='gaussian')
        if self.conditional is None:
            self.conditional = MLP(self.dim_h, self.dim_in, dim_hs=[],
                                   rng=self.rng, trng=self.trng,
                                   h_act='T.nnet.sigmoid',
                                   distribution='binomial')

        self.posterior.name = self.name + '_posterior'
        self.conditional.name = self.name + '_conditional'

    def set_tparams(self, excludes=[]):
        tparams = super(GBN, self).set_tparams()
        tparams.update(**self.posterior.set_tparams())
        tparams.update(**self.conditional.set_tparams())
        tparams.update(**self.prior.set_tparams())
        tparams = OrderedDict(
            (k, v) for k, v in tparams.iteritems() if k not in excludes)
        return tparams

    def get_params(self):
        params = self.prior.get_params() + self.conditional.get_params()
        return params

    def get_prior_params(self, *params):
        params = list(params)
        return params[:self.prior.n_params]

    # Latent sampling ---------------------------------------------------------

    def sample_from_prior(self, n_samples=100):
        h, updates = self.prior.sample(n_samples)
        py = self.conditional.feed(h)
        return self.get_center(py), updates

    def visualize_latents(self):
        # Perturb the prior mean along each latent axis by 10 prior standard
        # deviations and return the change in the conditional mean.
        h0 = self.prior.mu
        y0_mu, y0_logsigma = self.conditional.distribution.split_prob(
            self.conditional.feed(h0))
        sigma = T.nlinalg.AllocDiag()(T.exp(self.prior.log_sigma))
        h = 10 * sigma.astype(floatX) + h0[None, :]
        y_mu, y_logsigma = self.conditional.distribution.split_prob(
            self.conditional.feed(h))
        py = y_mu - y0_mu[None, :]
        return py

    # Misc --------------------------------------------------------------------

    def get_center(self, p):
        return self.conditional.get_center(p)

    def p_y_given_h(self, h, *params):
        # `params` is the flat list from get_params(); the conditional's
        # parameters follow the prior's.
        start = self.prior.n_params
        stop = start + self.conditional.n_params
        params = params[start:stop]
        return self.conditional.step_feed(h, *params)
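    # Cost --------------------------------------------------------------------
    # For reference: kl_divergence below is the closed-form KL between two
    # diagonal Gaussians p and q, each encoded as the concatenation
    # (mu, log_sigma) along the last axis:
    #
    #   KL(p || q) = sum_i [ log s_q_i - log s_p_i
    #                        + (s_p_i**2 + (mu_p_i - mu_q_i)**2)
    #                          / (2 * s_q_i**2) - 0.5 ],   s = exp(log_sigma)
    #
    # which is exactly the T.exp(2 * log_sigma) expression in the code.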
    def kl_divergence(self, p, q):
        '''KL(p || q) for two diagonal Gaussians, each given as the
        concatenation (mu, log_sigma) along the last axis.'''
        dim = self.dim_h
        mu_p = _slice(p, 0, dim)
        log_sigma_p = _slice(p, 1, dim)
        mu_q = _slice(q, 0, dim)
        log_sigma_q = _slice(q, 1, dim)

        kl = log_sigma_q - log_sigma_p + 0.5 * (
            (T.exp(2 * log_sigma_p) + (mu_p - mu_q) ** 2) /
            T.exp(2 * log_sigma_q) - 1)
        return kl.sum(axis=kl.ndim - 1)

    def l2_decay(self, rate):
        rec_l2_cost = self.posterior.get_L2_weight_cost(rate)
        gen_l2_cost = self.conditional.get_L2_weight_cost(rate)

        rval = OrderedDict(
            rec_l2_cost=rec_l2_cost,
            gen_l2_cost=gen_l2_cost,
            cost=rec_l2_cost + gen_l2_cost)
        return rval

    def init_inference_samples(self, size):
        return self.posterior.distribution.prototype_samples(size)

    def __call__(self, x, y, qk=None, n_posterior_samples=10,
                 pass_gradients=False):
        # q0 is the output of the recognition network; qk is an optionally
        # refined posterior (e.g. from iterative inference). With no
        # refinement, qk defaults to q0.
        q0 = self.posterior.feed(x)
        if qk is None:
            qk = q0

        h, updates = self.posterior.sample(qk, n_samples=n_posterior_samples)
        py = self.conditional.feed(h)

        log_py_h = -self.conditional.neg_log_prob(y[None, :, :], py)
        KL_qk_p = self.prior.kl_divergence(qk)

        # qk is treated as a constant in KL(q_k || q_0) (see `constants`
        # below), so this term only trains the recognition network q0.
        qk_c = qk.copy()
        if pass_gradients:
            KL_qk_q0 = T.constant(0.).astype(floatX)
        else:
            KL_qk_q0 = self.kl_divergence(qk_c, q0)

        log_ph = -self.prior.neg_log_prob(h)
        log_qh = -self.posterior.neg_log_prob(h, qk[None, :, :])

        # Importance-weighted estimate of log p(y): log-mean-exp over
        # posterior samples of p(y|h) p(h) / q(h).
        log_p = (log_sum_exp(log_py_h + log_ph - log_qh, axis=0)
                 - T.log(n_posterior_samples))

        y_energy = -log_py_h.mean(axis=0)
        p_entropy = self.prior.entropy()
        q_entropy = self.posterior.entropy(qk)
        nll = -log_p

        lower_bound = -(y_energy + KL_qk_p).mean()
        if pass_gradients:
            cost = (y_energy + KL_qk_p).mean(0)
        else:
            cost = (y_energy + KL_qk_p + KL_qk_q0).mean(0)

        mu, log_sigma = self.posterior.distribution.split_prob(qk)

        results = OrderedDict({
            'mu': mu.mean(),
            'log_sigma': log_sigma.mean(),
            '-log p(x|h)': y_energy.mean(0),
            '-log p(x)': nll.mean(0),
            '-log p(h)': log_ph.mean(),
            '-log q(h)': log_qh.mean(),
            'KL(q_k||p)': KL_qk_p.mean(0),
            'KL(q_k||q_0)': KL_qk_q0.mean(),
            'H(p)': p_entropy,
            'H(q)': q_entropy.mean(0),
            'lower_bound': lower_bound,
            'cost': cost
        })

        samples = OrderedDict(
            py=py,
            batch_energies=y_energy)

        constants = [qk_c]
        return results, samples, constants
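
# ------------------------------------------------------------------------
# Usage sketch (illustrative, with assumptions): shows how the class above
# fits together. It assumes the module-level names used by GBN (Layer, MLP,
# Gaussian, floatX, resolve_mlp, etc.) are in scope and that parameter
# setup happens via the explicit calls below; adjust to the actual package
# layout and Layer base-class behavior.
if __name__ == '__main__':
    import theano
    import theano.tensor as T

    x = T.matrix('x')  # input batch, shape (batch_size, dim_in)

    gbn = GBN(dim_in=784, dim_h=200)  # placeholder dimensions
    gbn.set_params()    # builds default prior/posterior/conditional
    gbn.set_tparams()   # registers the Theano shared parameters

    # Autoencoding setup (y = x): `results` holds scalar monitors such as
    # the cost, the variational lower bound and the KL terms; `samples`
    # holds the conditional activations p(y|h).
    results, samples, constants = gbn(x, x, n_posterior_samples=10)

    cost_fn = theano.function([x], results['cost'])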