Example #1
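# Note: this example assumes the following names are provided by the
# surrounding codebase (not shown here): Layer, MLP, Gaussian, init_weights,
# init_rngs, resolve_mlp, _slice, log_sum_exp and floatX, in addition to
# `from collections import OrderedDict` and `import theano.tensor as T`.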
class GBN(Layer):
    def __init__(self,
                 dim_in,
                 dim_h,
                 posterior=None,
                 conditional=None,
                 prior=None,
                 name='gbn',
                 **kwargs):

        self.dim_in = dim_in
        self.dim_h = dim_h

        self.posterior = posterior
        self.conditional = conditional
        self.prior = prior

        kwargs = init_weights(self, **kwargs)
        kwargs = init_rngs(self, **kwargs)

        super(GBN, self).__init__(name=name)

    @staticmethod
    def mlp_factory(dim_h,
                    dims,
                    distributions,
                    prototype=None,
                    recognition_net=None,
                    generation_net=None):
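        # Build the recognition (posterior) and generation (conditional) MLPs
        # from the given configuration dictionaries and return them keyed by role.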
        mlps = {}

        if recognition_net is not None:
            t = recognition_net.get('type', None)
            if t == 'mmmlp':
                raise NotImplementedError()
            elif t == 'lfmlp':
                recognition_net['prototype'] = prototype

            input_name = recognition_net.get('input_layer')
            recognition_net['distribution'] = 'gaussian'
            recognition_net['dim_in'] = dims[input_name]
            recognition_net['dim_out'] = dim_h
            posterior = resolve_mlp(t).factory(**recognition_net)
            mlps['posterior'] = posterior

        if generation_net is not None:
            output_name = generation_net['output']
            generation_net['dim_in'] = dim_h

            t = generation_net.get('type', None)

            if t == 'mmmlp':
                for out in generation_net['graph']['outputs']:
                    generation_net['graph']['outs'][out] = dict(
                        dim=dims[out], distribution=distributions[out])
                conditional = resolve_mlp(t).factory(**generation_net)
            else:
                if t == 'lfmlp':
                    generation_net['filter_in'] = False
                    generation_net['prototype'] = prototype
                generation_net['dim_out'] = dims[output_name]
                generation_net['distribution'] = distributions[output_name]
                conditional = resolve_mlp(t).factory(**generation_net)
            mlps['conditional'] = conditional

        return mlps

    def set_params(self):
        self.params = OrderedDict()

        if self.prior is None:
            self.prior = Gaussian(self.dim_h)

        if self.posterior is None:
            self.posterior = MLP(self.dim_in,
                                 self.dim_h,
                                 dim_hs=[],
                                 rng=self.rng,
                                 trng=self.trng,
                                 h_act='T.nnet.sigmoid',
                                 distribution='gaussian')
        if self.conditional is None:
            self.conditional = MLP(self.dim_h,
                                   self.dim_in,
                                   dim_hs=[],
                                   rng=self.rng,
                                   trng=self.trng,
                                   h_act='T.nnet.sigmoid',
                                   distribution='binomial')

        self.posterior.name = self.name + '_posterior'
        self.conditional.name = self.name + '_conditional'

    def set_tparams(self, excludes=[]):
        tparams = super(GBN, self).set_tparams()
        tparams.update(**self.posterior.set_tparams())
        tparams.update(**self.conditional.set_tparams())
        tparams.update(**self.prior.set_tparams())
        tparams = OrderedDict(
            (k, v) for k, v in tparams.items() if k not in excludes)

        return tparams

    def get_params(self):
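        # Parameter ordering convention: prior parameters first, then conditional
        # parameters; get_prior_params and p_y_given_h slice this list accordingly.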
        params = self.prior.get_params() + self.conditional.get_params()
        return params

    def get_prior_params(self, *params):
        params = list(params)
        return params[:self.prior.n_params]

    # Latent sampling ---------------------------------------------------------

    def sample_from_prior(self, n_samples=100):
        h, updates = self.prior.sample(n_samples)
        py = self.conditional.feed(h)
        return self.get_center(py), updates

    def visualize_latents(self):
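        # Perturb the prior mean by 10 standard deviations along each latent
        # dimension and return the resulting change in the conditional mean.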
        h0 = self.prior.mu
        y0_mu, y0_logsigma = self.conditional.distribution.split_prob(
            self.conditional.feed(h0))

        sigma = T.nlinalg.AllocDiag()(T.exp(self.prior.log_sigma))
        h = 10 * sigma.astype(floatX) + h0[None, :]
        y_mu, y_logsigma = self.conditional.distribution.split_prob(
            self.conditional.feed(h))
        py = y_mu - y0_mu[None, :]
        return py  # / py.std()#T.exp(y_logsigma)

    # Misc --------------------------------------------------------------------

    def get_center(self, p):
        return self.conditional.get_center(p)

    def p_y_given_h(self, h, *params):
        start = self.prior.n_params
        stop = start + self.conditional.n_params
        params = params[start:stop]
        return self.conditional.step_feed(h, *params)

    def kl_divergence(self, p, q):
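        # Closed-form KL divergence between two diagonal Gaussians, with p and q
        # given as concatenated (mu, log_sigma) vectors:
        #   KL(p||q) = log(sigma_q / sigma_p)
        #              + (sigma_p^2 + (mu_p - mu_q)^2) / (2 * sigma_q^2) - 1/2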
        dim = self.dim_h
        mu_p = _slice(p, 0, dim)
        log_sigma_p = _slice(p, 1, dim)
        mu_q = _slice(q, 0, dim)
        log_sigma_q = _slice(q, 1, dim)

        kl = log_sigma_q - log_sigma_p + 0.5 * (
            (T.exp(2 * log_sigma_p) +
             (mu_p - mu_q)**2) / T.exp(2 * log_sigma_q) - 1)
        return kl.sum(axis=kl.ndim - 1)

    def l2_decay(self, rate):
        rec_l2_cost = self.posterior.get_L2_weight_cost(rate)
        gen_l2_cost = self.conditional.get_L2_weight_cost(rate)

        rval = OrderedDict(rec_l2_cost=rec_l2_cost,
                           gen_l2_cost=gen_l2_cost,
                           cost=rec_l2_cost + gen_l2_cost)

        return rval

    def init_inference_samples(self, size):
        return self.posterior.distribution.prototype_samples(size)

    def __call__(self,
                 x,
                 y,
                 qk=None,
                 n_posterior_samples=10,
                 pass_gradients=False):
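        # Compute the variational lower bound, the training cost and monitoring
        # statistics from input x and target y, using n_posterior_samples latent
        # samples drawn from the approximate posterior.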
        q0 = self.posterior.feed(x)

        if qk is None:
            qk = q0

        h, updates = self.posterior.sample(qk, n_samples=n_posterior_samples)
        py = self.conditional.feed(h)

        log_py_h = -self.conditional.neg_log_prob(y[None, :, :], py)
        KL_qk_p = self.prior.kl_divergence(qk)
        qk_c = qk.copy()
        if pass_gradients:
            KL_qk_q0 = T.constant(0.).astype(floatX)
        else:
            KL_qk_q0 = self.kl_divergence(qk_c, q0)

        log_ph = -self.prior.neg_log_prob(h)
        log_qh = -self.posterior.neg_log_prob(h, qk[None, :, :])
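        # log_py_h + log_ph - log_qh are the log importance weights used to
        # estimate log p(y) ~= log((1/K) * sum_k p(y|h_k) p(h_k) / q(h_k)).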

        log_p = (log_sum_exp(log_py_h + log_ph - log_qh, axis=0) -
                 T.log(n_posterior_samples))

        y_energy = -log_py_h.mean(axis=0)
        p_entropy = self.prior.entropy()
        q_entropy = self.posterior.entropy(qk)
        nll = -log_p

        lower_bound = -(y_energy + KL_qk_p).mean()
        if pass_gradients:
            cost = (y_energy + KL_qk_p).mean(0)
        else:
            cost = (y_energy + KL_qk_p + KL_qk_q0).mean(0)

        mu, log_sigma = self.posterior.distribution.split_prob(qk)

        results = OrderedDict({
            'mu': mu.mean(),
            'log_sigma': log_sigma.mean(),
            '-log p(x|h)': y_energy.mean(0),
            '-log p(x)': nll.mean(0),
            '-log p(h)': log_ph.mean(),
            '-log q(h)': log_qh.mean(),
            'KL(q_k||p)': KL_qk_p.mean(0),
            'KL(q_k||q_0)': KL_qk_q0.mean(),
            'H(p)': p_entropy,
            'H(q)': q_entropy.mean(0),
            'lower_bound': lower_bound,
            'cost': cost
        })

        samples = OrderedDict(py=py, batch_energies=y_energy)

        constants = [qk_c]
        return results, samples, constants
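
A minimal usage sketch follows. It assumes the components referenced above (Layer, MLP, Gaussian, the weight/RNG initializers, etc.) are importable from the surrounding codebase; the dimensions, variable names, and the use of the input as its own reconstruction target are illustrative only.

# Hypothetical usage sketch; dims and data handling are placeholders.
import theano.tensor as T

model = GBN(dim_in=784, dim_h=200)       # illustrative dimensions
model.set_params()                       # build default prior/posterior/conditional
tparams = model.set_tparams()            # shared Theano variables for the optimizer

X = T.matrix('x')                        # a minibatch of observations
# Here the observations serve as both input and reconstruction target.
results, samples, constants = model(X, X, n_posterior_samples=10)
cost = results['cost']                   # scalar cost to differentiate w.r.t. tparams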