def doStep(w):
    # Note: x, model, n_datapoints, n_batch, prior_sd, lambd, stepsize, warmup,
    # gw_ss and batchi are not defined here; they are state and hyperparameters
    # from the enclosing scope.

    # Pick random minibatch
    idx = np.random.randint(0, n_datapoints, size=(n_batch,))
    _x = ndict.getColsFromIndices(x, idx)

    # Evaluate likelihood and its gradient, rescaled to the full dataset
    logpx, _, gw, _ = model.dlogpxz_dwz(w, _x, {})
    for i in w:
        gw[i] *= float(n_datapoints) / float(n_batch)

    # Evaluate prior and its gradient
    logpw = 0
    for i in w:
        logpw -= (.5 * (w[i]**2) / (prior_sd**2)).sum()
        gw[i] -= w[i] / (prior_sd**2)

    # Preconditioned ascent step: scale the gradient by a running average
    # of squared gradients; skip the parameter update during warmup
    for i in gw:
        gw_ss[i] += lambd * (gw[i]**2 - gw_ss[i])
        if batchi[0] < warmup: continue
        w[i] += stepsize * gw[i] / np.sqrt(gw_ss[i] + 1e-8)

    batchi[0] += 1

    return logpx.sum() + logpw
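# --- Illustration (standalone snippet, not part of this module) ---
# A minimal sketch of the update rule used above: an exponential moving average
# of squared gradients (gw_ss) scales the step, much like an RMSProp-style
# preconditioner. The toy objective and every name below (toy_rmsprop_ascent,
# target, n_steps) are assumptions made for demonstration only.
import numpy as np

def toy_rmsprop_ascent(n_steps=200, stepsize=0.1, lambd=0.1):
    w = np.zeros(3)                      # parameters
    gw_ss = np.ones_like(w)              # running average of squared gradients
    target = np.array([1.0, -2.0, 0.5])  # maximum of the toy log-density
    for _ in range(n_steps):
        g = -(w - target)                # gradient of -0.5*||w - target||^2
        gw_ss += lambd * (g**2 - gw_ss)  # same moving-average rule as gw_ss above
        w += stepsize * g / np.sqrt(gw_ss + 1e-8)
    return w

print(toy_rmsprop_ascent())              # moves toward [1, -2, 0.5]
# --- End illustration ---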
def doStep(w):
    LB = 0  # Lower bound

    # Pick random minibatch (or the full dataset in non-stochastic mode)
    idx = np.random.randint(0, n_datapoints, size=(n_minibatch,))
    _x = ndict.getColsFromIndices(x, idx)
    if not stochastic:
        _x = x

    # Draw sample _w from posterior q(w;eta1,eta2)
    eps = {}
    _w = {}
    for i in w:
        eps[i] = np.random.standard_normal(size=w[i].shape)
        _w[i] = w[i] + np.exp(logsd[i]) * eps[i]
        # Entropy term of the lower bound
        LB += (0.5 + 0.5 * np.log(2 * np.pi) + logsd[i]).sum()

    # Compute L = log p(x,w) and its gradient w.r.t. the sampled weights
    logpx, logpz, gw, gz = model.dlogpxz_dwz(_w, _x, {})
    logpw, gw2 = model.dlogpw_dw(_w)
    for i in gw:
        gw[i] = (float(n_datapoints) / float(n_minibatch)) * gw[i] + gw2[i]
    L = (logpx.sum() + logpz.sum()) * float(n_datapoints) / float(n_minibatch)
    L += logpw.sum()
    LB += L

    # Update params
    for i in w:
        # Noisy estimates g' and C'
        # l = log p(x,w)
        # w = mean + sigma * eps = - eta1/(2*eta2) - 1/(2*eta2) * eps = - (eta1+eps)/(2*eta2)
        # dw/deta1 = -1/(2*eta2)
        # dw/deta2 = (eta1 + eps)/(2*eta2^2)
        # g1hat = dl/deta1 = dl/dw * dw/deta1 = gw[i] * dw/deta1
        # g2hat = dl/deta2 = dl/dw * dw/deta2
        dwdeta1 = -1 / (2 * eta2[i])
        dwdeta2 = (eta1[i] + eps[i]) / (2 * eta2[i]**2)
        g1hat = gw[i] * dwdeta1
        g2hat = gw[i] * dwdeta2

        # C11hat = dw/dw      * dw/deta1
        # C12hat = d(w**2)/dw * dw/deta1
        # C21hat = dw/dw      * dw/deta2
        # C22hat = d(w**2)/dw * dw/deta2
        C11hat = dwdeta1
        C12hat = 2 * _w[i] * dwdeta1
        C21hat = dwdeta2
        C22hat = 2 * _w[i] * dwdeta2

        # Update running averages of g and C
        g1[i] = (1-stepsize)*g1[i] + stepsize*g1hat
        g2[i] = (1-stepsize)*g2[i] + stepsize*g2hat
        C11[i] = (1-stepsize)*C11[i] + stepsize*C11hat
        C12[i] = (1-stepsize)*C12[i] + stepsize*C12hat
        C21[i] = (1-stepsize)*C21[i] + stepsize*C21hat
        C22[i] = (1-stepsize)*C22[i] + stepsize*C22hat

        if iter[0] > 0.1/stepsize:
            # Compute natural parameters given current g and C: eta = C^-1 g
            # For this 2x2 system:
            #   eta1 = (C22*g1 - C12*g2) / det(C)
            #   eta2 = (-C21*g1 + C11*g2) / det(C)
            invdet = 1 / (C11[i] * C22[i] - C12[i] * C21[i])
            eta1[i] = invdet * (C22[i] * g1[i] - C12[i] * g2[i])
            eta2[i] = invdet * (-C21[i] * g1[i] + C11[i] * g2[i])
            eta2[i] = -np.abs(eta2[i])  # keep the second natural parameter negative

            # Map natural parameters to mean and variance parameters
            w[i] = - eta1[i] / (2 * eta2[i])
            logsd[i] = 0.5 * np.log(-1 / (2 * eta2[i]))

        if np.isnan(w[i]).sum() > 0:
            print 'w', i, np.isnan(w[i]).sum()
            raise Exception()
        if np.isnan(logsd[i]).sum() > 0:
            print 'logsd', i, np.isnan(logsd[i]).sum()
            raise Exception()

    iter[0] += 1

    return LB
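# --- Illustration (standalone snippet, not part of this module) ---
# The mapping used above between Gaussian natural parameters (eta1, eta2) and
# the mean/log-sd parameterization. For q(w) = N(mu, sigma^2):
#   eta1 = mu / sigma^2,  eta2 = -1 / (2 * sigma^2),
# so mu = -eta1/(2*eta2) and sigma = sqrt(-1/(2*eta2)), matching the
# "Map natural parameters to mean and variance parameters" step. All names
# here are illustrative, not taken from the surrounding code.
import numpy as np

mu, sigma = 0.7, 0.3
eta1 = mu / sigma**2
eta2 = -1.0 / (2 * sigma**2)

mu_back = -eta1 / (2 * eta2)
logsd_back = 0.5 * np.log(-1.0 / (2 * eta2))

assert np.allclose(mu_back, mu)
assert np.allclose(np.exp(logsd_back), sigma)
# --- End illustration ---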
def doStep(w, z=None):
    if z is not None:
        raise Exception()

    L = [0]  # Lower bound

    g_mean = ndict.cloneZeros(w)
    if var == 'diag' or var == 'row_isotropic':
        g_logsd = ndict.cloneZeros(w)
    elif var == 'isotropic':
        g_logsd = {i: 0 for i in w}

    # Loop over random datapoints
    for l in range(n_batch):

        # Pick random datapoints
        idx = np.random.randint(0, n_datapoints, size=(n_subbatch,))
        _x = ndict.getColsFromIndices(x, idx)

        # Function that adds gradients for given noise eps
        def add_grad(eps):
            # Compute noisy weights
            _w = {i: w[i] + np.exp(logsd[i]) * eps[i] for i in w}
            # Compute gradients of log p(x|theta) w.r.t. w
            logpx, logpz, g_w, g_z = model.dlogpxz_dwz(_w, _x, {})
            for i in w:
                cv = (_w[i] - w[i]) / np.exp(2*logsd[i])  # control variate
                cov_mean[i] += cv_lr * (g_w[i]*cv - cov_mean[i])
                var_mean[i] += cv_lr * (cv**2 - var_mean[i])
                g_mean[i] += g_w[i] - cov_mean[i]/var_mean[i] * cv
                if var == 'diag' or var == 'row_isotropic':
                    grad = g_w[i] * eps[i] * np.exp(logsd[i])
                    cv = cv - 1  # this control variate (c.v.) is very similar to the c.v. for the mean
                    cov_logsd[i] += cv_lr * (grad*cv - cov_logsd[i])
                    var_logsd[i] += cv_lr * (cv**2 - var_logsd[i])
                    g_logsd[i] += grad - cov_logsd[i]/var_logsd[i] * cv
                elif var == 'isotropic':
                    g_logsd[i] += (g_w[i] * eps[i]).sum() * np.exp(logsd[i])
                else:
                    raise Exception()
            L[0] += logpx.sum() + logpz.sum()

        # Gradients with generated noise
        eps = {i: np.random.standard_normal(size=w[i].shape) for i in w}
        if sgd:
            eps = {i: np.zeros(w[i].shape) for i in w}
        add_grad(eps)

        # Gradient with negative of noise
        if negNoise:
            for i in eps:
                eps[i] *= -1
            add_grad(eps)

    L = L[0]
    L *= float(n_datapoints) / float(n_subbatch) / float(n_batch)
    if negNoise:
        L /= 2

    for i in w:
        c = float(n_datapoints) / (float(n_subbatch) * float(n_batch))
        if negNoise:
            c /= 2
        g_mean[i] *= c
        g_logsd[i] *= c

        # Prior
        g_mean[i] += - w[i] / (prior_sd**2)
        g_logsd[i] += - np.exp(2 * logsd[i]) / (prior_sd**2)
        L += - (w[i]**2 + np.exp(2 * logsd[i])).sum() / (2 * prior_sd**2)
        L += - 0.5 * np.log(2 * np.pi * prior_sd**2) * float(w[i].size)

        # Entropy of q: 0.5 * log(2*pi*e) + logsd per element
        L += float(w[i].size) * 0.5 * math.log(2 * math.pi * math.e)
        if var == 'diag' or var == 'row_isotropic':
            g_logsd[i] += 1  # dH(q)/d[logsd] = 1 (nice!)
            L += logsd[i].sum()
        elif var == 'isotropic':
            g_logsd[i] += float(w[i].size)  # one shared logsd per tensor, so dH(q)/d[logsd] = size
            L += logsd[i] * float(w[i].size)
        else:
            raise Exception()

    # Update variational parameters
    c = 1
    if not anneal:
        c = 1. / (batchi[0] + 1)

    # For isotropic row variance, sum gradients per row
    if var == 'row_isotropic':
        for i in w:
            g_sum = g_logsd[i].sum(axis=1).reshape(w[i].shape[0], 1)
            g_logsd[i] = np.dot(g_sum, np.ones((1, w[i].shape[1])))

    for i in w:
        g_w_ss[i] += g_mean[i]**2
        g_logsd_ss[i] += g_logsd[i]**2
        mom_w[i] += (1-momw) * (g_mean[i] - mom_w[i])
        mom_logsd[i] += (1-momsd) * (g_logsd[i] - mom_logsd[i])
        if batchi[0] < warmup:
            continue
        w[i] += stepsize / np.sqrt(g_w_ss[i] * c + 1e-8) * mom_w[i]
        logsd[i] += stepsize / np.sqrt(g_logsd_ss[i] * c + 1e-8) * mom_logsd[i]

    batchi[0] += 1

    return L
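# --- Illustration (standalone snippet, not part of this module) ---
# The control-variate idea used for g_mean above: a noisy estimator g is
# combined with a zero-mean control variate cv using the coefficient
# cov(g, cv) / var(cv), which minimizes the variance of g - a*cv while leaving
# its expectation unchanged (this matches the running cov_mean/var_mean ratio).
# All quantities below are synthetic; this is not the model's estimator.
import numpy as np

rng = np.random.RandomState(0)
eps = rng.standard_normal(100000)
g = np.exp(0.1 * eps)                    # noisy estimates whose mean we want
cv = eps                                 # control variate with known mean 0

a = np.cov(g, cv)[0, 1] / np.var(cv)     # optimal coefficient, cf. cov_mean/var_mean
g_corrected = g - a * cv

print(np.var(g), np.var(g_corrected))    # corrected estimator has lower variance
print(np.mean(g), np.mean(g_corrected))  # both estimate the same expectation
# --- End illustration ---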