Example #1
 def doStep(w):
     
     # Pick random minibatch
     idx = np.random.randint(0, n_datapoints, size=(n_batch,))
     _x = ndict.getColsFromIndices(x, idx)
     
     # Evaluate likelihood and its gradient
     logpx, _, gw, _ = model.dlogpxz_dwz(w, _x, {})
     
     # Scale the minibatch likelihood gradient up to the full dataset
     # (float cast avoids integer division under Python 2)
     for i in w:
         gw[i] *= float(n_datapoints) / n_batch
     
     # Evaluate prior (spherical Gaussian) and its gradient
     logpw = 0
     for i in w:
         logpw -= (.5 * (w[i]**2) / (prior_sd**2)).sum()
         gw[i] -= w[i] / (prior_sd**2)
     
     for i in gw:
         # RMSprop-style running average of squared gradients
         gw_ss[i] += lambd * (gw[i]**2 - gw_ss[i])
         # Skip parameter updates during warmup, but keep accumulating gw_ss
         if batchi[0] < warmup: continue
         # Preconditioned gradient-ascent step
         w[i] += stepsize * gw[i] / np.sqrt(gw_ss[i] + 1e-8)
     
     batchi[0] += 1
     
     return logpx + logpw
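
Example #1 implements a MAP step: scale a minibatch likelihood gradient up to the full dataset, add the gradient of a spherical Gaussian prior, and take an RMSprop-style preconditioned ascent step. The function relies on variables from its enclosing scope (model, x, ndict, n_datapoints, n_batch, prior_sd, lambd, stepsize, warmup, gw_ss, batchi). Below is a minimal sketch of that scope with a hypothetical toy Gaussian model and a stub for ndict.getColsFromIndices; all names and values are illustrative assumptions, not the original setup.

import numpy as np

# Illustrative hyperparameters (assumptions, not the original values)
n_datapoints, n_batch = 1000, 100
prior_sd, lambd, stepsize, warmup = 1.0, 0.01, 1e-2, 10

class _NDict:
    # Minimal stand-in for the ndict helper used by doStep
    @staticmethod
    def getColsFromIndices(d, idx):
        return {i: d[i][:, idx] for i in d}
ndict = _NDict()

class _ToyModel:
    # Hypothetical model: log p(x|w) = sum_j -0.5 * ||x_j - w['w0']||^2
    def dlogpxz_dwz(self, w, _x, _z):
        resid = _x['x'] - w['w0']
        logpx = -0.5 * (resid**2).sum()
        gw = {'w0': resid.sum(axis=1, keepdims=True)}
        return logpx, 0.0, gw, {}
model = _ToyModel()

rng = np.random.RandomState(0)
x = {'x': 2.0 + rng.standard_normal((5, n_datapoints))}
w = {'w0': np.zeros((5, 1))}
gw_ss = {i: np.zeros_like(w[i]) for i in w}  # running average of squared gradients
batchi = [0]

# With the doStep from Example #1 defined in this scope:
for step in range(500):
    logp = doStep(w)
print(w['w0'].ravel())  # approaches the (slightly shrunken) column mean of x['x']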
Example #2
 def doStep(w):
     
     LB = 0  # lower bound
     
     # Pick a random minibatch (or use the full dataset when not stochastic)
     if stochastic:
         idx = np.random.randint(0, n_datapoints, size=(n_minibatch,))
         _x = ndict.getColsFromIndices(x, idx)
     else:
         _x = x
     
     # Draw sample _w from posterior q(w;eta1,eta2)
     eps = {}
     _w = {}
     for i in w:
         eps[i] = np.random.standard_normal(size=w[i].shape)
         _w[i] = w[i] + np.exp(logsd[i])*eps[i]
         LB += (0.5 + 0.5 * np.log(2 * np.pi) + logsd[i]).sum()  # entropy of q: 0.5*log(2*pi*e) + logsd per dim
     
     # Compute L = log p(x,w)
     logpx, logpz, gw, gz = model.dlogpxz_dwz(_w, _x, {})
     logpw, gw2 = model.dlogpw_dw(_w)
     # Scale the minibatch likelihood gradient to the full dataset, then add the prior gradient
     for i in gw:
         gw[i] = (float(n_datapoints) / float(n_minibatch)) * gw[i] + gw2[i]
     
     L = (logpx.sum() + logpz.sum()) * float(n_datapoints) / float(n_minibatch)
     L += logpw.sum()
     LB += L
     
     # Update params
     for i in w:
         
         # Noisy estimates g' and C'
         # l = log p(x,w)
         # w = mean + sigma * eps = - eta1/(2*eta2) - 1/(2*eta2) * eps = - (eta1+eps)/(2*eta2)
         # dw/deta1 = -1/(2*eta2)
         # dw/deta2 = (eta1 + eps)/(2*eta2^2)
         # g1hat = dl/deta1 = dl/dw dw/deta1 = gw[i] * dw/deta1
         # g2hat = dl/deta2 = dl/dw dw/deta2
         dwdeta1 = -1/(2*eta2[i])
         dwdeta2 = (eta1[i] + eps[i]) / (2*eta2[i]**2)
         g1hat = gw[i] * dwdeta1
         g2hat = gw[i] * dwdeta2
         
         # C11hat = dw/dw * dw/deta1
         # C12hat = d(w**2)/dw * dw/deta1
         # C21hat = dw/dw * dw/deta2
         # C22hat = d(w**2)/dw * dw/deta2
         C11hat = dwdeta1
         C12hat = 2 * _w[i] * dwdeta1
         C21hat = dwdeta2
         C22hat = 2 * _w[i] * dwdeta2
         
         # Update exponentially decaying averages of g and C
         g1[i] = (1-stepsize)*g1[i] + stepsize*g1hat
         g2[i] = (1-stepsize)*g2[i] + stepsize*g2hat
         
         C11[i] = (1-stepsize)*C11[i] + stepsize*C11hat
         C12[i] = (1-stepsize)*C12[i] + stepsize*C12hat
         C21[i] = (1-stepsize)*C21[i] + stepsize*C21hat
         C22[i] = (1-stepsize)*C22[i] + stepsize*C22hat
         
         if iter[0] > 0.1/stepsize:
             # Solve the 2x2 system eta = C^-1 g:
             # eta1 = (C22*g1 - C12*g2) / det(C)
             # eta2 = (-C21*g1 + C11*g2) / det(C)
             invdet = 1/(C11[i] * C22[i] - C12[i] * C21[i])
             eta1[i] = invdet * (C22[i] * g1[i] - C12[i] * g2[i])
             eta2[i] = invdet * (-C21[i] * g1[i] + C11[i] * g2[i])
             
             # Keep eta2 negative so the implied variance -1/(2*eta2) stays positive
             eta2[i] = -np.abs(eta2[i])
             
             # Map natural parameters to mean and variance parameters
             w[i] = - eta1[i]/(2*eta2[i])
             logsd[i] = 0.5 * np.log( - 1/(2*eta2[i]))
             
             if np.isnan(w[i]).sum() > 0:
                 print('w', i, np.isnan(w[i]).sum())
                 raise Exception()
             
             if np.isnan(logsd[i]).sum() > 0:
                 print('logsd', i, np.isnan(logsd[i]).sum())
                 raise Exception()
             
             
     iter[0] += 1
     
     return LB
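
Example #2 keeps the Gaussian posterior q(w) in natural parameters and, per parameter, solves a 2x2 system built from running averages to update them. Everything hinges on the mapping mean = -eta1/(2*eta2), sd^2 = -1/(2*eta2) used at the end of the update. A small round-trip check of that mapping (the numbers are illustrative):

import numpy as np

mean, sd = 0.3, 0.5
eta1 = mean / sd**2        # eta1 = mu / sigma^2
eta2 = -1.0 / (2 * sd**2)  # eta2 = -1 / (2 sigma^2), always negative

assert np.isclose(-eta1 / (2 * eta2), mean)                   # recovers w[i]
assert np.isclose(0.5 * np.log(-1 / (2 * eta2)), np.log(sd))  # recovers logsd[i]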
Example #3
 def doStep(w, z=None):
     if z is not None: raise Exception()
     
     L = [0]  # lower bound, boxed in a list so the nested add_grad can mutate it
     g_mean = ndict.cloneZeros(w)
     if var == 'diag' or var == 'row_isotropic':
         g_logsd = ndict.cloneZeros(w)
     elif var == 'isotropic':
         g_logsd = {i:0 for i in w}
     
     # Loop over randomly drawn sub-minibatches
     for l in range(n_batch):
         
         # Pick a random sub-minibatch of datapoints
         idx = np.random.randint(0, n_datapoints, size=(n_subbatch,))
         _x = ndict.getColsFromIndices(x, idx)
         
         # Function that adds gradients for given noise eps
         def add_grad(eps):
             # Compute noisy weights
             _w = {i: w[i] + np.exp(logsd[i]) * eps[i] for i in w}
             # Compute gradients of log p(x|theta) w.r.t. w
             logpx, logpz, g_w, g_z = model.dlogpxz_dwz(_w, _x, {})        
             for i in w:
                 cv = (_w[i] - w[i]) / np.exp(2*logsd[i])  #control variate
                 cov_mean[i] += cv_lr * (g_w[i]*cv - cov_mean[i])
                 var_mean[i] += cv_lr * (cv**2 - var_mean[i])
                 g_mean[i] += g_w[i] - cov_mean[i]/var_mean[i] * cv
                 
                 if var == 'diag' or var == 'row_isotropic':
                     grad = g_w[i] * eps[i] * np.exp(logsd[i])
                     cv = cv - 1 # this control variate (c.v.) is really similar to the c.v. for the mean!
                     cov_logsd[i] += cv_lr * (grad*cv - cov_logsd[i])
                     var_logsd[i] += cv_lr * (cv**2  - var_logsd[i])
                     g_logsd[i] += grad - cov_logsd[i]/var_logsd[i] * cv
                 elif var == 'isotropic':
                     g_logsd[i] += (g_w[i] * eps[i]).sum() * np.exp(logsd[i])
                 else: raise Exception()
                 
             L[0] += logpx.sum() + logpz.sum()
         
         # Gradients with generated noise
         eps = {i: np.random.standard_normal(size=w[i].shape) for i in w}
         if sgd: eps = {i: np.zeros(w[i].shape) for i in w}
         add_grad(eps)
         
         # Gradient with negative of noise
         if negNoise:
             for i in eps: eps[i] *= -1
             add_grad(eps)
     
     L = L[0]        
     L *= float(n_datapoints) / float(n_subbatch) / float(n_batch)
     if negNoise: L /= 2
     
     for i in w:
         c = float(n_datapoints) / (float(n_subbatch) * float(n_batch))
         if negNoise: c /= 2
         g_mean[i] *= c
         g_logsd[i] *= c
                     
         # Prior
         g_mean[i] += - w[i] / (prior_sd**2)
         g_logsd[i] += - np.exp(2 * logsd[i]) / (prior_sd**2)
         L += - (w[i]**2 + np.exp(2 * logsd[i])).sum() / (2 * prior_sd**2)
         L += - 0.5 * np.log(2 * np.pi * prior_sd**2) * float(w[i].size)
         
         # Entropy of q: per dimension, H = 0.5*log(2*pi*e) + logsd
         L += float(w[i].size) * 0.5 * math.log(2 * math.pi * math.e)
         if var == 'diag' or var == 'row_isotropic':
             g_logsd[i] += 1 # dH(q)/d[logsd] = 1 (nice!)
             L += logsd[i].sum()
         elif var == 'isotropic':
             g_logsd[i] += float(w[i].size) # dH(q)/d[logsd] = w.size for a shared logsd
             L += logsd[i] * float(w[i].size)
         else: raise Exception()
         
     # Update variational parameters.
     # With annealing, the raw AdaGrad sums make the steps shrink over time;
     # otherwise dividing by the step count turns the sums into running averages.
     c = 1
     if not anneal:
         c = 1. / (batchi[0] + 1)
     
     # For isotropic row variance, sum gradients per row
     if var == 'row_isotropic':
         for i in w:
             g_sum = g_logsd[i].sum(axis=1).reshape(w[i].shape[0], 1)
             g_logsd[i] = np.dot(g_sum, np.ones((1, w[i].shape[1])))
     
     for i in w:
         # AdaGrad accumulators and exponential-moving-average momentum
         g_w_ss[i] += g_mean[i]**2
         g_logsd_ss[i] += g_logsd[i]**2
         
         mom_w[i] += (1-momw) * (g_mean[i] - mom_w[i])
         mom_logsd[i] += (1-momsd) * (g_logsd[i] - mom_logsd[i])
         
         if batchi[0] < warmup: continue  # during warmup, accumulate statistics only
         
         w[i] += stepsize / np.sqrt(g_w_ss[i] * c + 1e-8) * mom_w[i]
         logsd[i] += stepsize / np.sqrt(g_logsd_ss[i] * c + 1e-8) * mom_logsd[i]            
         
     batchi[0] += 1
     
     return L
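
The distinguishing piece of Example #3 is the control variate: with eps ~ N(0,1) and _w = w + exp(logsd)*eps, the statistic cv = (_w - w)/exp(2*logsd) has expectation zero, so subtracting the fitted multiple cov_mean/var_mean * cv from the gradient lowers its variance without changing its expectation. The standalone sketch below demonstrates that effect on a toy statistic; all numbers are illustrative.

import numpy as np

rng = np.random.RandomState(0)
mu, sd = 1.0, 0.5
eps = rng.standard_normal(100000)
_w = mu + sd * eps

g = _w**2               # noisy statistic; E[g] = mu^2 + sd^2 = 1.25
cv = (_w - mu) / sd**2  # zero-mean control variate, as in the code above
beta = np.cov(g, cv)[0, 1] / cv.var()  # plays the role of cov_mean/var_mean

corrected = g - beta * cv
print(g.mean(), corrected.mean())  # both near 1.25: expectation unchanged
print(g.var(), corrected.var())    # corrected variance is much lower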