def get_elbo(pred, targ, weights, logdets, weight, dataset_size,
             prior=log_normal, lbda=0, output_type='categorical'):
    """ negative elbo, an upper bound on NLL """
    logqw = -logdets
    # originally:
    #   logqw = -(0.5*(ep**2).sum(1) + 0.5*T.log(2*np.pi)*num_params + logdets)
    # but the constants are neglected in this wrapper
    logpw = prior(weights, 0., -T.log(lbda)).sum(1)
    # normal prior centered at zero, with lbda being the inverse of the variance
    kl = (logqw - logpw).mean()

    if output_type == 'categorical':
        logpyx = -cc(pred, targ).mean()
    elif output_type == 'real':
        logpyx = -se(pred, targ).mean()  # assumes the output is a vector!
    else:
        assert False

    loss = -(logpyx - weight * kl / T.cast(dataset_size, floatX))
    return loss, [logpyx, logpw, logqw]
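# A minimal, self-contained numpy sketch (not part of this codebase) of the
# arithmetic get_elbo performs on the KL side. log_normal_np and the toy
# shapes/values are illustrative stand-ins, assuming log_normal(x, mean, log_var)
# in utils is a diagonal-Gaussian log density.
import numpy as np

def log_normal_np(x, mean, log_var):
    # elementwise log N(x; mean, exp(log_var)), constants included here
    return -0.5 * (np.log(2 * np.pi) + log_var + (x - mean) ** 2 / np.exp(log_var))

weights_np = np.random.randn(4, 10).astype('float32')  # (n_mc_samples, n_params)
logdets_np = np.random.randn(4).astype('float32')       # flow log-determinants
lbda_np = 1.0                                            # prior precision (1 / variance)
logqw_np = -logdets_np                                   # constants dropped, as in get_elbo
logpw_np = log_normal_np(weights_np, 0., -np.log(lbda_np)).sum(1)
kl_np = (logqw_np - logpw_np).mean()                     # Monte Carlo KL estimate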
def _get_elbo(self):
    """ negative elbo, an upper bound on NLL """
    # TODO: kldiv_bias = tf.reduce_sum(.5 * self.pvar_bias - .5 * self.logvar_bias
    #                                  + ((tf.exp(self.logvar_bias) + tf.square(self.mu_bias))
    #                                     / (2 * tf.exp(self.pvar_bias))) - .5)

    # eqn 14
    kl_q_w_z_p = 0
    for mu, sig, z_T_f in zip(self.mus, self.sigs, self.z_T_fs):
        kl_q_w_z_p += (sig**2).sum() - T.log(sig**2).sum() + mu**2 * z_T_f**2  # leaving off the -1
    kl_q_w_z_p *= 0.5

    # eqn 15
    self.log_r_z_T_f_W = 0
    print('\n \n eqn15')
    for mu, sig, z_T_b, c, b_mu, b_logsig in zip(
            self.mus, self.sigs, self.z_T_bs, self.cs, self.b_mus, self.b_logsigs):
        # we compute this separately for every layer's W
        print('eqn15')
        print([tt.shape for tt in [mu, sig, z_T_b, c, b_mu, b_logsig]])

        # reparametrization trick for eqn 9/10
        cTW_mu = T.dot(c, mu)
        cTW_sig = T.dot(c, sig**2)**.5
        the_scalar = T.tanh(
            cTW_mu + cTW_sig * self.srng.normal(cTW_sig.shape)).sum()
        # TODO: double check (does the sum belong here??)

        # scaling b by the_scalar
        mu_tilde = (b_mu * the_scalar).squeeze()
        log_sig_tilde = (b_logsig * the_scalar).squeeze()
        self.log_r_z_T_f_W += (-.5 * T.exp(log_sig_tilde) * (z_T_b - mu_tilde)**2
                               - .5 * T.log(2 * np.pi)
                               + .5 * log_sig_tilde).sum()

    self.log_r_z_T_f_W += self.logdets_z_T_b

    # -eqn 13
    self.kl = (-self.logdets + kl_q_w_z_p - self.log_r_z_T_f_W).sum()  # TODO: why do I need the mean/sum??

    if self.output_type == 'categorical':
        self.logpyx = -cc(self.y, self.target_var).mean()
    elif self.output_type == 'real':
        self.logpyx = -se(self.y, self.target_var).mean()
    else:
        assert False

    # FIXME: not a scalar!?
    self.loss = -(self.logpyx -
                  self.weight * self.kl / T.cast(self.dataset_size, floatX))

    # DK - extra monitoring
    params = self.params
    ds = self.dataset_size
    self.monitored = []
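# A toy numpy check (illustrative, not repo code) of the closed form behind
# "eqn 14" above: for an entrywise Gaussian q(W|z) = N(z * mu, sig**2) and a
# standard normal prior, KL(q || p) = 0.5 * sum(sig**2 - log(sig**2) + (z * mu)**2 - 1);
# the loop above accumulates essentially this quantity, leaving off the constant -1.
# Shapes and the broadcasting of z over rows are toy assumptions.
import numpy as np

rng = np.random.RandomState(0)
mu = rng.randn(4, 3)
sig = np.abs(rng.randn(4, 3)) + 0.1
z = rng.randn(4, 1)                      # multiplicative auxiliary variable

kl_exact = 0.5 * np.sum(sig**2 - np.log(sig**2) + (z * mu)**2 - 1.0)
kl_code_style = 0.5 * (np.sum(sig**2) - np.sum(np.log(sig**2)) + np.sum((mu**2) * (z**2)))
assert np.allclose(kl_code_style, kl_exact + 0.5 * mu.size)  # they differ only by the dropped constant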
def _get_elbo(self):
    """ negative elbo, an upper bound on NLL """
    logdets = self.logdets
    self.logqw = -logdets
    # originally:
    #   logqw = -(0.5*(ep**2).sum(1) + 0.5*T.log(2*np.pi)*num_params + logdets)
    # but the constants are neglected in this wrapper
    self.logpw = self.prior(self.weights, 0., -T.log(self.lbda)).sum(1)
    # normal prior centered at zero, with lbda being the inverse of the variance
    self.kl = (self.logqw - self.logpw).mean()

    if self.output_type == 'categorical':
        self.logpyx = -cc(self.y, self.target_var).mean()
    elif self.output_type == 'real':
        self.logpyx = -se(self.y, self.target_var).mean()
    else:
        assert False

    self.loss = -(self.logpyx -
                  self.weight * self.kl / T.cast(self.dataset_size, floatX))

    # DK - extra monitoring
    params = self.params
    ds = self.dataset_size
    self.logpyx_grad = flatten_list(
        T.grad(-self.logpyx, params, disconnected_inputs='warn')).norm(2)
    self.logpw_grad = flatten_list(
        T.grad(-self.logpw.mean() / ds, params, disconnected_inputs='warn')).norm(2)
    self.logqw_grad = flatten_list(
        T.grad(self.logqw.mean() / ds, params, disconnected_inputs='warn')).norm(2)
    self.monitored = [
        self.logpyx, self.logpw, self.logqw,
        self.logpyx_grad, self.logpw_grad, self.logqw_grad
    ]
def _get_elbo(self):
    # NTS: is KL waaay too big??
    self.kl = KL(self.prior_mean, self.prior_log_var,
                 self.mean, self.log_var).sum(-1).mean()

    if self.output_type == 'categorical':
        self.logpyx = -cc(self.y, self.target_var).mean()
    elif self.output_type == 'real':
        self.logpyx = -se(self.y, self.target_var).mean()
    else:
        assert False

    self.loss = -(self.logpyx -
                  self.weight * self.kl / T.cast(self.dataset_size, floatX))

    # DK - extra monitoring
    params = self.params
    ds = self.dataset_size
    self.logpyx_grad = flatten_list(
        T.grad(-self.logpyx, params, disconnected_inputs='warn')).norm(2)
    self.monitored = [self.logpyx, self.logpyx_grad, self.kl]  # , self.target_var]
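# A minimal sketch, assuming the KL helper imported from utils computes the
# standard closed-form KL between diagonal Gaussians, KL(q || p) with
# q = N(mean, exp(log_var)) and p = N(prior_mean, exp(prior_log_var)).
# KL_np below is a hypothetical stand-in, not the repo's function.
import numpy as np

def KL_np(p_mean, p_log_var, q_mean, q_log_var):
    # per-dimension KL( N(q_mean, exp(q_log_var)) || N(p_mean, exp(p_log_var)) )
    return 0.5 * (p_log_var - q_log_var
                  + (np.exp(q_log_var) + (q_mean - p_mean) ** 2) / np.exp(p_log_var)
                  - 1.0)

assert np.allclose(KL_np(0.0, 0.0, np.zeros(3), np.zeros(3)), 0.0)  # zero when q == p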
def _get_elbo(self):
    """ negative elbo, an upper bound on NLL """
    logdets = self.logdets
    logqw = -logdets
    # originally:
    #   logqw = -(0.5*(ep**2).sum(1) + 0.5*T.log(2*np.pi)*num_params + logdets)
    # but the constants are neglected in this wrapper
    logpw = self.prior(self.weights, 0., -T.log(self.lbda)).sum(1)
    # normal prior centered at zero, with lbda being the inverse of the variance
    kl = (logqw - logpw).mean()
    logpyx = -se(self.y, self.target_var).sum(1).mean()
    self.loss = -(logpyx - kl / T.cast(self.dataset_size, floatX))
    self.monitored = [self.loss]
def __init__(self):
    # inpv -> input variable, i.e. the full image
    # ep   -> standard normal noise
    # beta -> border, p(z|beta)
    # w    -> annealing weight on the KL term
    self.inpv = T.tensor4('inpv')
    self.ep = T.matrix('ep')
    self.beta = T.tensor4('beta')
    self.sample = T.matrix('sample')
    self.w = T.scalar('w')
    self.lr = T.scalar('lr')

    self.enc_m, self.enc_s = get_encoder1()
    self.dec, self.dec1_input, self.dec2_input = get_decoder()
    self.k = np.prod(self.enc_m.output_shape[1:])

    # posterior q(z|x): mean, log-std, log-variance, std, variance
    self.qm = get_output(self.enc_m, self.inpv)
    self.qlogs = get_output(self.enc_s, self.inpv)
    self.qlogv = 2 * self.qlogs
    self.qs = T.exp(self.qlogs)
    self.qv = T.exp(self.qlogv)
    self.z = self.qm + self.qs * self.ep  # reparametrized sample

    self.rec = get_output(self.dec, {
        self.dec1_input: self.z,
        self.dec2_input: self.beta
    })
    self.ancestral = get_output(self.dec, {
        self.dec1_input: self.sample,
        self.dec2_input: self.beta
    })

    # per-dimension log densities with constants dropped
    self.log_px_z = -se(self.rec, self.inpv)
    self.log_pz = -0.5 * (self.qm**2 + self.qv)
    self.log_qz_x = -0.5 * (1 + self.qlogv)

    self.kls = T.sum(self.log_qz_x - self.log_pz, 1)
    self.rec_errs = T.sum(-self.log_px_z, axis=[1, 2, 3])
    self.kl = T.mean(self.kls)
    self.rec_err = T.mean(self.rec_errs)
    self.loss = self.w * self.kl + self.rec_err

    self.params = np.concatenate([
        get_all_params(ly) for ly in [self.enc_m, self.enc_s, self.dec]
    ]).tolist()
    self.grads = T.grad(self.loss, self.params)
    self.scaled_grads = tnc(self.grads, max_norm)
    self.updates = lasagne.updates.adam(self.scaled_grads, self.params, self.lr)

    self.train_func = theano.function(
        [self.inpv, self.beta, self.ep, self.w, self.lr],
        [self.loss, self.rec_err, self.kl],
        updates=self.updates)
    self.recons_func = theano.function([self.inpv, self.beta, self.ep],
                                       self.rec)
    self.sample_func = theano.function([self.beta, self.sample],
                                       self.ancestral)
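# A toy numpy check (illustrative only) that the constant-dropped log_pz and
# log_qz_x above give the usual Gaussian KL once subtracted: per dimension,
# log_qz_x - log_pz = 0.5 * (qm**2 + qv - 1 - qlogv), i.e. KL(N(qm, qv) || N(0, 1));
# the shared -0.5*log(2*pi) terms would cancel anyway. Shapes/values are toy assumptions.
import numpy as np

rng = np.random.RandomState(0)
qm = rng.randn(8, 16)
qlogv = rng.randn(8, 16)
qv = np.exp(qlogv)

kls_from_code = np.sum(-0.5 * (1 + qlogv) - (-0.5 * (qm**2 + qv)), axis=1)
kls_standard = 0.5 * np.sum(qm**2 + qv - 1.0 - qlogv, axis=1)
assert np.allclose(kls_from_code, kls_standard)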
def __init__(self,
             srng=RandomStreams(seed=427),
             prior_mean=0,
             prior_log_var=0,
             n_hiddens=2,
             n_units=800,
             n_inputs=784,
             n_classes=10,
             output_type='categorical',
             random_biases=1,
             # dataset_size=None,
             opt='adam',
             # weight=1.,  # the weight of the KL term
             **kargs):
    self.__dict__.update(locals())

    # TODO
    self.dataset_size = T.scalar('dataset_size')
    self.weight = T.scalar('weight')
    self.learning_rate = T.scalar('learning_rate')

    self.weight_shapes = []
    if n_hiddens > 0:
        self.weight_shapes.append((n_inputs, n_units))
        # self.params.append((theano.shared()))
        for i in range(1, n_hiddens):
            self.weight_shapes.append((n_units, n_units))
        self.weight_shapes.append((n_units, n_classes))
    else:
        self.weight_shapes = [(n_inputs, n_classes)]

    if self.random_biases:
        self.num_params = sum(
            (ws[0] + 1) * ws[1] for ws in self.weight_shapes)
    else:
        self.num_params = sum(ws[0] * ws[1] for ws in self.weight_shapes)

    self.wd1 = 1
    self.X = T.matrix()
    self.y = T.matrix()

    self.mean = ts(self.num_params)
    self.log_var = ts(self.num_params, scale=1e-6, bias=-1e8)
    self.params = [self.mean, self.log_var]
    self.ep = self.srng.normal(size=(self.num_params, ), dtype=floatX)
    self.weights = self.mean + (T.exp(self.log_var) +
                                np.float32(.000001)) * self.ep

    t = 0
    acts = self.X
    for nn, ws in enumerate(self.weight_shapes):
        if self.random_biases:
            num_param = (ws[0] + 1) * ws[1]
            weight_and_bias = self.weights[t:t + num_param]
            weight = weight_and_bias[:ws[0] * ws[1]].reshape((ws[0], ws[1]))
            bias = weight_and_bias[ws[0] * ws[1]:].reshape((ws[1], ))
            acts = T.dot(acts, weight) + bias
        else:
            assert False  # TODO

        if nn < len(self.weight_shapes) - 1:
            acts = (acts > 0.) * acts  # ReLU
        else:
            acts = T.nnet.softmax(acts)
        t += num_param

    y_hat = acts
    # y_hat = T.clip(y_hat, 0.001, 0.999)  # stability
    self.y_hat = y_hat

    self.kl = KL(self.prior_mean, self.prior_log_var, self.mean,
                 self.log_var).sum(-1).mean()
    self.logpyx = -cc(self.y_hat, self.y).mean()
    self.logpyx = -se(self.y_hat, self.y).mean()  # NB: overrides the categorical term above
    self.loss = -(self.logpyx -
                  self.weight * self.kl / T.cast(self.dataset_size, floatX))
    self.loss = se(self.y_hat, self.y).mean()  # NB: overrides the ELBO loss with plain MSE
    self.logpyx_grad = flatten_list(
        T.grad(-self.logpyx, self.params, disconnected_inputs='warn')).norm(2)
    self.monitored = [self.logpyx, self.logpyx_grad, self.kl]

    # def _get_useful_funcs(self):
    self.predict_proba = theano.function([self.X], self.y_hat)
    self.predict = theano.function([self.X], self.y_hat.argmax(1))
    self.predict_fixed_mask = theano.function([self.X, self.weights],
                                              self.y_hat)
    self.sample_weights = theano.function([], self.weights)
    self.monitor_fn = theano.function([self.X, self.y],
                                      self.monitored)  # , (self.predict(x) == y).sum()

    # def _get_grads(self):
    grads = T.grad(self.loss, self.params)
    # mgrads = lasagne.updates.total_norm_constraint(grads, max_norm=self.max_norm)
    # cgrads = [T.clip(g, -self.clip_grad, self.clip_grad) for g in mgrads]
    cgrads = grads
    if self.opt == 'adam':
        self.updates = lasagne.updates.adam(
            cgrads, self.params, learning_rate=self.learning_rate)
    elif self.opt == 'momentum':
        self.updates = lasagne.updates.nesterov_momentum(
            cgrads, self.params, learning_rate=self.learning_rate)
    elif self.opt == 'sgd':
        self.updates = lasagne.updates.sgd(
            cgrads, self.params, learning_rate=self.learning_rate)

    # def _get_train_func(self):
    inputs = [
        self.X, self.y, self.dataset_size, self.learning_rate, self.weight
    ]
    train = theano.function(inputs,
                            self.loss,
                            updates=self.updates,
                            on_unused_input='warn')
    self.train_func_ = train
    # DK - putting this here because it doesn't get overwritten by subclasses
    self.monitor_func = theano.function(
        [self.X, self.y, self.dataset_size, self.learning_rate],
        self.monitored,
        on_unused_input='warn')