Example #1
def get_elbo(pred,
             targ,
             weights,
             logdets,
             weight,
             dataset_size,
             prior=log_normal,
             lbda=0,
             output_type='categorical'):
    """
    negative elbo, an upper bound on NLL
    """

    logqw = -logdets
    """
    originally...
    logqw = - (0.5*(ep**2).sum(1)+0.5*T.log(2*np.pi)*num_params+logdets)
        --> constants are neglected in this wrapper
    """
    logpw = prior(weights, 0., -T.log(lbda)).sum(1)
    """
    using normal prior centered at zero, with lbda being the inverse 
    of the variance
    """
    kl = (logqw - logpw).mean()
    if output_type == 'categorical':
        logpyx = -cc(pred, targ).mean()
    elif output_type == 'real':
        logpyx = -se(pred, targ).mean()  # assume output is a vector !
    else:
        assert False
    loss = -(logpyx - weight * kl / T.cast(dataset_size, floatX))

    return loss, [logpyx, logpw, logqw]
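For reference, a minimal NumPy sketch of the same negative-ELBO assembly (log_normal_density below is an assumed form of the log_normal helper; names and numbers are illustrative, not from the repo):

import numpy as np

def log_normal_density(x, mean, log_var):
    # assumed form of the log_normal helper: elementwise Gaussian log-density
    return -0.5 * (np.log(2 * np.pi) + log_var + (x - mean) ** 2 / np.exp(log_var))

def negative_elbo_sketch(logpyx, weights, logdets, lbda=1.0, kl_weight=1.0, dataset_size=50000):
    logqw = -logdets                                                # flow log-density, constants neglected
    logpw = log_normal_density(weights, 0., -np.log(lbda)).sum(1)   # N(0, 1/lbda) prior
    kl = (logqw - logpw).mean()
    return -(logpyx - kl_weight * kl / dataset_size)                # negative ELBO (per data point)

rng = np.random.RandomState(0)
print(negative_elbo_sketch(logpyx=-0.3, weights=rng.randn(4, 10), logdets=rng.randn(4)))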
Example #2
    def _get_elbo(self):
        """
        negative elbo, an upper bound on NLL
        """

        # TODO: kldiv_bias = tf.reduce_sum(.5 * self.pvar_bias - .5 * self.logvar_bias + ((tf.exp(self.logvar_bias) + tf.square(self.mu_bias)) / (2 * tf.exp(self.pvar_bias))) - .5)

        # eqn14
        kl_q_w_z_p = 0
        for mu, sig, z_T_f in zip(self.mus, self.sigs, self.z_T_fs):
            kl_q_w_z_p += (sig**2).sum() - T.log(
                sig**2).sum() + mu**2 * z_T_f**2  # leaving off the -1
        kl_q_w_z_p *= 0.5

        # eqn15
        self.log_r_z_T_f_W = 0
        print('\n \n eqn15')
        for mu, sig, z_T_b, c, b_mu, b_logsig in zip(
                self.mus, self.sigs, self.z_T_bs, self.cs, self.b_mus,
                self.b_logsigs
        ):  # we'll compute this separately for every layer's W
            print('eqn15')
            print([tt.shape for tt in [mu, sig, z_T_b, c, b_mu, b_logsig]])
            # reparametrization trick for eqn 9/10
            cTW_mu = T.dot(c, mu)
            cTW_sig = T.dot(c, sig**2)**.5
            the_scalar = T.tanh(
                cTW_mu + cTW_sig * self.srng.normal(cTW_sig.shape)).sum(
                )  # TODO: double check (does the sum belong here??)
            # scaling b by the_scalar
            mu_tilde = (b_mu * the_scalar).squeeze()
            log_sig_tilde = (b_logsig * the_scalar).squeeze()
            self.log_r_z_T_f_W += (-.5 * T.exp(log_sig_tilde) *
                                   (z_T_b - mu_tilde)**2 -
                                   .5 * T.log(2 * np.pi) +
                                   .5 * log_sig_tilde).sum()
        self.log_r_z_T_f_W += self.logdets_z_T_b

        # -eqn13
        self.kl = (
            -self.logdets + kl_q_w_z_p -
            self.log_r_z_T_f_W).sum()  # TODO: why do I need the mean/sum??

        if self.output_type == 'categorical':
            self.logpyx = -cc(self.y, self.target_var).mean()
        elif self.output_type == 'real':
            self.logpyx = -se(self.y, self.target_var).mean()
        else:
            assert False
        # FIXME: not a scalar!?
        self.loss = -(self.logpyx -
                      self.weight * self.kl / T.cast(self.dataset_size, floatX))

        # DK - extra monitoring
        params = self.params
        ds = self.dataset_size
        self.monitored = []
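A rough NumPy sketch of the reparametrization step marked eqn 9/10 above, under the assumption that mu and sig are per-weight means and standard deviations and c is the auxiliary vector; shapes and values are illustrative only:

import numpy as np

rng = np.random.RandomState(0)
d_in, d_out = 5, 3
mu = rng.randn(d_in, d_out)            # per-weight means
sig = np.abs(rng.randn(d_in, d_out))   # per-weight standard deviations
c = rng.rand(d_in)                     # auxiliary vector c (nonnegative here so the variance term stays valid)

# mirrors cTW_mu / cTW_sig above: one reparametrized sample of c^T W,
# tanh-squashed and summed down to a single scalar
cTW_mu = c.dot(mu)
cTW_sig = c.dot(sig ** 2) ** 0.5
the_scalar = np.tanh(cTW_mu + cTW_sig * rng.randn(d_out)).sum()

# that scalar then rescales the mean and log-std used for log r(z_T_b | W)
b_mu, b_logsig = rng.randn(d_out), rng.randn(d_out)
mu_tilde, log_sig_tilde = b_mu * the_scalar, b_logsig * the_scalar
print(mu_tilde, log_sig_tilde)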
Example #3
    def _get_elbo(self):
        """
        negative elbo, an upper bound on NLL
        """

        logdets = self.logdets
        self.logqw = -logdets
        """
        originally...
        logqw = - (0.5*(ep**2).sum(1)+0.5*T.log(2*np.pi)*num_params+logdets)
            --> constants are neglected in this wrapper
        """
        self.logpw = self.prior(self.weights, 0., -T.log(self.lbda)).sum(1)
        """
        using normal prior centered at zero, with lbda being the inverse 
        of the variance
        """
        self.kl = (self.logqw - self.logpw).mean()
        if self.output_type == 'categorical':
            self.logpyx = -cc(self.y, self.target_var).mean()
        elif self.output_type == 'real':
            self.logpyx = -se(self.y, self.target_var).mean()
        else:
            assert False
        self.loss = -(self.logpyx -
                      self.weight * self.kl / T.cast(self.dataset_size, floatX))

        # DK - extra monitoring
        params = self.params
        ds = self.dataset_size
        self.logpyx_grad = flatten_list(
            T.grad(-self.logpyx, params, disconnected_inputs='warn')).norm(2)
        self.logpw_grad = flatten_list(
            T.grad(-self.logpw.mean() / ds, params,
                   disconnected_inputs='warn')).norm(2)
        self.logqw_grad = flatten_list(
            T.grad(self.logqw.mean() / ds, params,
                   disconnected_inputs='warn')).norm(2)
        self.monitored = [
            self.logpyx, self.logpw, self.logqw, self.logpyx_grad,
            self.logpw_grad, self.logqw_grad
        ]
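The extra monitoring in this variant just tracks the L2 norm of each term's gradient; a NumPy analogue of that flatten-and-norm pattern (assuming flatten_list concatenates the flattened gradient tensors):

import numpy as np

def grad_norm(grad_arrays):
    # analogue of flatten_list(...).norm(2): concatenate every gradient
    # tensor into one flat vector and take its L2 norm
    flat = np.concatenate([g.ravel() for g in grad_arrays])
    return np.sqrt((flat ** 2).sum())

rng = np.random.RandomState(0)
print(grad_norm([rng.randn(3, 4), rng.randn(7)]))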
Example #4
    def _get_elbo(self):
        # NTS: is KL waaay too big??
        self.kl = KL(self.prior_mean, self.prior_log_var, self.mean,
                     self.log_var).sum(-1).mean()

        if self.output_type == 'categorical':
            self.logpyx = -cc(self.y, self.target_var).mean()
        elif self.output_type == 'real':
            self.logpyx = -se(self.y, self.target_var).mean()
        else:
            assert False
        self.loss = -(self.logpyx -
                      self.weight * self.kl / T.cast(self.dataset_size, floatX))

        # DK - extra monitoring
        params = self.params
        ds = self.dataset_size
        self.logpyx_grad = flatten_list(
            T.grad(-self.logpyx, params, disconnected_inputs='warn')).norm(2)
        self.monitored = [self.logpyx, self.logpyx_grad,
                          self.kl]  #, self.target_var]
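The KL helper called above is presumably the standard closed-form KL between diagonal Gaussians, returned elementwise (consistent with the commented-out kldiv_bias formula in Example #2 and the trailing .sum(-1).mean()); a NumPy version under that assumption:

import numpy as np

def kl_diag_gaussians(prior_mean, prior_log_var, mean, log_var):
    # elementwise KL( N(mean, exp(log_var)) || N(prior_mean, exp(prior_log_var)) )
    return 0.5 * (prior_log_var - log_var
                  + (np.exp(log_var) + (mean - prior_mean) ** 2) / np.exp(prior_log_var)
                  - 1.0)

rng = np.random.RandomState(0)
mean, log_var = rng.randn(4, 10), 0.1 * rng.randn(4, 10)
print(kl_diag_gaussians(0.0, 0.0, mean, log_var).sum(-1).mean())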
Example #5
    def _get_elbo(self):
        """
        negative elbo, an upper bound on NLL
        """

        logdets = self.logdets
        logqw = -logdets
        """
        originally...
        logqw = - (0.5*(ep**2).sum(1)+0.5*T.log(2*np.pi)*num_params+logdets)
            --> constants are neglected in this wrapper
        """
        logpw = self.prior(self.weights, 0., -T.log(self.lbda)).sum(1)
        """
        using normal prior centered at zero, with lbda being the inverse 
        of the variance
        """
        kl = (logqw - logpw).mean()

        logpyx = -se(self.y, self.target_var).sum(1).mean()
        self.loss = -(logpyx - kl / T.cast(self.dataset_size, floatX))
        self.monitored = [self.loss]
Example #6
    def __init__(self):

        # inpv -> input variable, i.e. full image
        # ep -> std normal noise
        # beta -> border, p(z|beta)
        # w -> annealing weight
        self.inpv = T.tensor4('inpv')
        self.ep = T.matrix('ep')
        self.beta = T.tensor4('beta')
        self.sample = T.matrix('sample')
        self.w = T.scalar('w')
        self.lr = T.scalar('lr')

        self.enc_m, self.enc_s = get_encoder1()
        self.dec, self.dec1_input, self.dec2_input = get_decoder()

        self.k = np.prod(self.enc_m.output_shape[1:])

        self.qm = get_output(self.enc_m, self.inpv)
        self.qlogs = get_output(self.enc_s, self.inpv)
        self.qlogv = 2 * self.qlogs
        self.qs = T.exp(self.qlogs)
        self.qv = T.exp(self.qlogs * 2)

        self.z = self.qm + self.qs * self.ep
        self.rec = get_output(self.dec, {
            self.dec1_input: self.z,
            self.dec2_input: self.beta
        })
        self.ancestral = get_output(self.dec, {
            self.dec1_input: self.sample,
            self.dec2_input: self.beta
        })

        self.log_px_z = -se(self.rec, self.inpv)
        self.log_pz = -0.5 * (self.qm**2 + self.qv)
        self.log_qz_x = -0.5 * (1 + self.qlogv)

        self.kls = T.sum(self.log_qz_x - self.log_pz, 1)
        self.rec_errs = T.sum(-self.log_px_z, axis=[1, 2, 3])

        self.kl = T.mean(self.kls)
        self.rec_err = T.mean(self.rec_errs)
        self.loss = self.w * self.kl + self.rec_err

        self.params = np.concatenate([
            get_all_params(ly) for ly in [self.enc_m, self.enc_s, self.dec]
        ]).tolist()

        self.grads = T.grad(self.loss, self.params)
        self.scaled_grads = tnc(self.grads, max_norm)
        self.updates = lasagne.updates.adam(self.scaled_grads, self.params,
                                            self.lr)
        self.train_func = theano.function(
            [self.inpv, self.beta, self.ep, self.w, self.lr],
            [self.loss, self.rec_err, self.kl],
            updates=self.updates)
        self.recons_func = theano.function([self.inpv, self.beta, self.ep],
                                           self.rec)
        self.sample_func = theano.function([self.beta, self.sample],
                                           self.ancestral)
Example #7
    def __init__(
            self,
            srng=RandomStreams(seed=427),
            prior_mean=0,
            prior_log_var=0,
            n_hiddens=2,
            n_units=800,
            n_inputs=784,
            n_classes=10,
            output_type='categorical',
            random_biases=1,
            #dataset_size=None,
            opt='adam',
            #weight=1.,# the weight of the KL term
            **kargs):

        self.__dict__.update(locals())

        # TODO
        self.dataset_size = T.scalar('dataset_size')
        self.weight = T.scalar('weight')
        self.learning_rate = T.scalar('learning_rate')

        self.weight_shapes = []
        if n_hiddens > 0:
            self.weight_shapes.append((n_inputs, n_units))
            #self.params.append((theano.shared()))
            for i in range(1, n_hiddens):
                self.weight_shapes.append((n_units, n_units))
            self.weight_shapes.append((n_units, n_classes))
        else:
            self.weight_shapes = [(n_inputs, n_classes)]

        if self.random_biases:
            self.num_params = sum(
                (ws[0] + 1) * ws[1] for ws in self.weight_shapes)
        else:
            self.num_params = sum((ws[0]) * ws[1] for ws in self.weight_shapes)

        self.wd1 = 1
        self.X = T.matrix()
        self.y = T.matrix()
        self.mean = ts(self.num_params)
        self.log_var = ts(self.num_params, scale=1e-6, bias=-1e8)
        self.params = [self.mean, self.log_var]
        self.ep = self.srng.normal(size=(self.num_params, ), dtype=floatX)
        self.weights = self.mean + (T.exp(self.log_var) +
                                    np.float32(.000001)) * self.ep

        t = 0
        acts = self.X
        for nn, ws in enumerate(self.weight_shapes):
            if self.random_biases:
                num_param = (ws[0] + 1) * ws[1]
                weight_and_bias = self.weights[t:t + num_param]
                weight = weight_and_bias[:ws[0] * ws[1]].reshape(
                    (ws[0], ws[1]))
                bias = weight_and_bias[ws[0] * ws[1]:].reshape((ws[1], ))
                acts = T.dot(acts, weight) + bias
            else:
                assert False  # TODO
            if nn < len(self.weight_shapes) - 1:
                acts = (acts > 0.) * (acts)
            else:
                acts = T.nnet.softmax(acts)

            t += num_param

        y_hat = acts
        #y_hat = T.clip(y_hat, 0.001, 0.999) # stability
        self.y_hat = y_hat

        self.kl = KL(self.prior_mean, self.prior_log_var, self.mean,
                     self.log_var).sum(-1).mean()
        self.logpyx = -cc(self.y_hat, self.y).mean()
        self.logpyx = -se(self.y_hat, self.y).mean()
        self.loss = -(self.logpyx - self.weight * self.kl /
                      T.cast(self.dataset_size, floatX))
        self.loss = se(self.y_hat, self.y).mean()
        self.logpyx_grad = flatten_list(
            T.grad(-self.logpyx, self.params,
                   disconnected_inputs='warn')).norm(2)
        self.monitored = [self.logpyx, self.logpyx_grad, self.kl]

        #def _get_useful_funcs(self):
        self.predict_proba = theano.function([self.X], self.y_hat)
        self.predict = theano.function([self.X], self.y_hat.argmax(1))
        self.predict_fixed_mask = theano.function([self.X, self.weights],
                                                  self.y_hat)
        self.sample_weights = theano.function([], self.weights)
        self.monitor_fn = theano.function(
            [self.X, self.y], self.monitored)  #, (self.predict(x) == y).sum()

        #def _get_grads(self):
        grads = T.grad(self.loss, self.params)
        #mgrads = lasagne.updates.total_norm_constraint(grads, max_norm=self.max_norm)
        #cgrads = [T.clip(g, -self.clip_grad, self.clip_grad) for g in mgrads]
        cgrads = grads
        if self.opt == 'adam':
            self.updates = lasagne.updates.adam(
                cgrads, self.params, learning_rate=self.learning_rate)
        elif self.opt == 'momentum':
            self.updates = lasagne.updates.nesterov_momentum(
                cgrads, self.params, learning_rate=self.learning_rate)
        elif self.opt == 'sgd':
            self.updates = lasagne.updates.sgd(
                cgrads, self.params, learning_rate=self.learning_rate)

        #def _get_train_func(self):
        inputs = [
            self.X, self.y, self.dataset_size, self.learning_rate, self.weight
        ]
        train = theano.function(inputs,
                                self.loss,
                                updates=self.updates,
                                on_unused_input='warn')
        self.train_func_ = train
        # DK - putting this here, because it doesn't get overwritten by subclasses
        self.monitor_func = theano.function(
            [self.X, self.y, self.dataset_size, self.learning_rate],
            self.monitored,
            on_unused_input='warn')
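For intuition, a NumPy sketch of the forward pass built in the loop above: one flat Gaussian parameter sample is sliced into per-layer weight matrices and bias vectors, with a ReLU between layers and a softmax at the output; data and shapes are illustrative, and the helper name is hypothetical:

import numpy as np

def forward_sample(X, mean, log_var, weight_shapes, rng):
    # one reparametrized sample of all parameters at once, as in self.weights above
    eps = rng.randn(mean.size)
    w = mean + (np.exp(log_var) + 1e-6) * eps
    t, acts = 0, X
    for nn, (d_in, d_out) in enumerate(weight_shapes):
        num_param = (d_in + 1) * d_out                 # weights plus a bias row
        wb = w[t:t + num_param]
        W = wb[:d_in * d_out].reshape(d_in, d_out)
        b = wb[d_in * d_out:]
        acts = acts.dot(W) + b
        if nn < len(weight_shapes) - 1:
            acts = np.maximum(acts, 0.0)               # ReLU, written as (acts > 0.) * acts above
        else:
            e = np.exp(acts - acts.max(1, keepdims=True))
            acts = e / e.sum(1, keepdims=True)         # softmax
        t += num_param
    return acts

rng = np.random.RandomState(0)
shapes = [(784, 800), (800, 800), (800, 10)]
n_params = sum((a + 1) * b for a, b in shapes)
probs = forward_sample(rng.randn(2, 784), np.zeros(n_params), np.zeros(n_params), shapes, rng)
print(probs.shape, probs.sum(1))       # (2, 10), each row sums to 1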