def partial_EM(self, data, cond_muh_ijk, indices, weights=None, eps=1e-4, maxiter=10, verbose=0):
    (i, j, k) = indices
    converged = False
    previous_L = utilities.average(
        self.likelihood(data), weights=weights) / self.N
    mini_epochs = 0
    if verbose:
        print('Partial EM %s, L = %.3f' % (mini_epochs, previous_L))
    while not converged:
        if self.nature in ['Bernoulli', 'Spin']:
            f = np.dot(data, self.weights[[i, j, k], :].T)
        elif self.nature == 'Potts':
            f = cy_utilities.compute_output_C(
                data, self.weights[[i, j, k], :, :],
                np.zeros([data.shape[0], 3], dtype=curr_float))
        tmp = f - self.logZ[np.newaxis, [i, j, k]]
        tmp -= tmp.max(-1)[:, np.newaxis]
        cond_muh = np.exp(tmp) * self.muh[np.newaxis, [i, j, k]]
        cond_muh /= cond_muh.sum(-1)[:, np.newaxis]
        cond_muh *= cond_muh_ijk[:, np.newaxis]
        self.muh[[i, j, k]] = utilities.average(cond_muh, weights=weights)
        self.cum_muh = np.cumsum(self.muh)
        self.gh[[i, j, k]] = np.log(self.muh[[i, j, k]])
        self.gh -= self.gh.mean()
        if self.nature == 'Bernoulli':
            self.cond_muv[[i, j, k]] = utilities.average_product(
                cond_muh, data, mean1=True, weights=weights) / self.muh[[i, j, k], np.newaxis]
            self.weights[[i, j, k]] = np.log(
                (self.cond_muv[[i, j, k]] + eps) / (1 - self.cond_muv[[i, j, k]] + eps))
            self.logZ[[i, j, k]] = np.logaddexp(
                0, self.weights[[i, j, k]]).sum(-1)
        elif self.nature == 'Spin':
            self.cond_muv[[i, j, k]] = utilities.average_product(
                cond_muh, data, mean1=True, weights=weights) / self.muh[[i, j, k], np.newaxis]
            self.weights[[i, j, k]] = 0.5 * np.log(
                (1 + self.cond_muv[[i, j, k]] + eps) / (1 - self.cond_muv[[i, j, k]] + eps))
            self.logZ[[i, j, k]] = np.logaddexp(
                self.weights[[i, j, k]], -self.weights[[i, j, k]]).sum(-1)
        elif self.nature == 'Potts':
            self.cond_muv[[i, j, k]] = utilities.average_product(
                cond_muh, data, c2=self.n_c, mean1=True,
                weights=weights) / self.muh[[i, j, k], np.newaxis, np.newaxis]
            self.cum_cond_muv[[i, j, k]] = np.cumsum(
                self.cond_muv[[i, j, k]], axis=-1)
            self.weights[[i, j, k]] = np.log(self.cond_muv[[i, j, k]] + eps)
            self.weights[[i, j, k]] -= self.weights[[i, j, k]].mean(-1)[:, :, np.newaxis]
            self.logZ[[i, j, k]] = utilities.logsumexp(
                self.weights[[i, j, k]], axis=-1).sum(-1)
        current_L = utilities.average(
            self.likelihood(data), weights=weights) / self.N
        mini_epochs += 1
        converged = (mini_epochs >= maxiter) | (
            np.abs(current_L - previous_L) < eps)
        previous_L = current_L.copy()
        if verbose:
            print('Partial EM %s, L = %.3f' % (mini_epochs, current_L))
    return current_L
def couplings_gradients_h(W, X1_p, X1_n, X2_p, X2_n, n_c1, n_c2, l1=0, l1b=0, l1c=0,
                          l2=0, l1_custom=None, l1b_custom=None, weights=None,
                          weights_neg=None):
    update = utilities.average_product(X1_p, X2_p, c1=n_c1, c2=n_c2, mean1=True,
                                       mean2=False, weights=weights) \
        - utilities.average_product(X1_n, X2_n, c1=n_c1, c2=n_c2, mean1=False,
                                    mean2=True, weights=weights_neg)
    if l2 > 0:
        update -= l2 * W
    if l1 > 0:
        update -= l1 * np.sign(W)
    if l1b > 0:
        if n_c2 > 1:  # Potts RBM.
            update -= l1b * np.sign(W) * \
                np.abs(W).mean(-1).mean(-1)[:, np.newaxis, np.newaxis]
        else:
            update -= l1b * np.sign(W) * (np.abs(W).sum(1))[:, np.newaxis]
    if l1c > 0:  # NOT SUPPORTED FOR POTTS
        update -= l1c * np.sign(W) * ((np.abs(W).sum(1))**2)[:, np.newaxis]
    if l1_custom is not None:
        update -= l1_custom * np.sign(W)
    if l1b_custom is not None:
        update -= l1b_custom[0] * (l1b_custom[1] * np.abs(W)).mean(-1).mean(-1)[
            :, np.newaxis, np.newaxis] * np.sign(W)
    if weights is not None:
        update *= weights.sum() / X1_p.shape[0]
    return update
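# Hedged usage sketch for couplings_gradients_h on a Bernoulli RBM (n_c1 = n_c2 = 1).
# It is meant to run in the module that defines the function (utilities.average_product
# is called internally); the array shapes and their semantics below (hidden means on
# the data batch vs. conditional means on fantasy particles) are illustrative
# assumptions, not part of the original code.
def _example_couplings_gradient():
    import numpy as np
    rng = np.random.RandomState(0)
    n_h, n_v, batch = 10, 50, 128
    W = 0.01 * rng.randn(n_h, n_v)                       # couplings, shape (n_h, n_v)
    h_pos = rng.rand(batch, n_h)                         # hidden means on the data batch
    v_pos = (rng.rand(batch, n_v) > 0.5).astype(float)   # data batch
    h_neg = rng.rand(batch, n_h)                         # hidden means on fantasy particles
    v_neg = rng.rand(batch, n_v)                         # conditional visible means (mean2=True)
    grad_W = couplings_gradients_h(W, h_pos, h_neg, v_pos, v_neg,
                                   n_c1=1, n_c2=1, l1=0.01, l2=1e-4)
    return grad_W  # expected shape: (n_h, n_v)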
def minibatch_fit_symKL(self, data_PGM, PGM=None, data_MOI=None, F_PGM_dPGM=None,
                        F_PGM_dMOI=None, F_MOI_dPGM=None, F_MOI_dMOI=None,
                        cond_muh_dPGM=None, cond_muh_dMOI=None, weights=None):
    if data_MOI is None:
        data_MOI, _ = self.gen_data(data_PGM.shape[0])
    if F_PGM_dPGM is None:
        F_PGM_dPGM = PGM.free_energy(data_PGM)
    if F_PGM_dMOI is None:
        F_PGM_dMOI = PGM.free_energy(data_MOI)
    if (F_MOI_dPGM is None) | (cond_muh_dPGM is None):
        F_MOI_dPGM, cond_muh_dPGM = self.likelihood_and_expectation(data_PGM)
        F_MOI_dPGM *= -1
    if (F_MOI_dMOI is None) | (cond_muh_dMOI is None):
        F_MOI_dMOI, cond_muh_dMOI = self.likelihood_and_expectation(data_MOI)
        F_MOI_dMOI *= -1
    delta_lik = -F_PGM_dMOI + F_MOI_dMOI
    delta_lik -= delta_lik.mean()
    self.gradient = {}
    self.gradient['gh'] = utilities.average(cond_muh_dPGM, weights=weights) - \
        self.muh + (delta_lik[:, np.newaxis] * cond_muh_dMOI).mean(0)
    if self.nature in ['Bernoulli', 'Spin']:
        self.gradient['weights'] = utilities.average_product(
            cond_muh_dPGM, data_PGM, mean1=True, weights=weights) + \
            utilities.average_product(cond_muh_dMOI * delta_lik[:, np.newaxis],
                                      data_MOI, mean1=True)
        self.gradient['weights'] -= self.muh[:, np.newaxis] * self.cond_muv
    elif self.nature == 'Potts':
        self.gradient['weights'] = utilities.average_product(
            cond_muh_dPGM, data_PGM, mean1=True, c2=self.n_c, weights=weights) + \
            utilities.average_product(cond_muh_dMOI * delta_lik[:, np.newaxis],
                                      data_MOI, mean1=True, c2=self.n_c)
        self.gradient['weights'] -= self.muh[:, np.newaxis, np.newaxis] * self.cond_muv
    self.gh += self.learning_rate * self.gradient['gh']
    self.weights += self.learning_rate * self.gradient['weights']
    self.muh = np.exp(self.gh)
    self.muh /= self.muh.sum()
    self.cum_muh = np.cumsum(self.muh)
    if self.nature == 'Bernoulli':
        self.cond_muv = utilities.logistic(self.weights)
    elif self.nature == 'Spin':
        self.cond_muv = np.tanh(self.weights)
    elif self.nature == 'Potts':
        self.weights -= self.weights.mean(-1)[:, :, np.newaxis]
        self.cond_muv = np.exp(self.weights)
        self.cond_muv /= self.cond_muv.sum(-1)[:, :, np.newaxis]
        self.cum_cond_muv = np.cumsum(self.cond_muv, axis=-1)
    self.logpartition()
def split_merge_criterion(self, data, Cmax=5, weights=None):
    likelihood, cond_muh = self.likelihood_and_expectation(data)
    norm = np.sqrt(utilities.average(cond_muh**2, weights=weights))
    J_merge = utilities.average_product(cond_muh, cond_muh, weights=weights) / \
        (1e-10 + norm[np.newaxis, :] * norm[:, np.newaxis])
    J_merge = np.triu(J_merge, 1)
    proposed_merge = np.argsort(J_merge.flatten())[::-1][:Cmax]
    proposed_merge = [(merge % self.M, merge // self.M)
                      for merge in proposed_merge]
    tmp = cond_muh / self.muh[np.newaxis, :]
    if weights is None:
        J_split = np.array([utilities.average(likelihood, weights=tmp[:, m])
                            for m in range(self.M)])
    else:
        J_split = np.array([utilities.average(likelihood, weights=tmp[:, m] * weights)
                            for m in range(self.M)])
    proposed_split = np.argsort(J_split)[:3]
    proposed_merge_split = []
    for merge1, merge2 in proposed_merge:
        if proposed_split[0] in [merge1, merge2]:
            if proposed_split[1] in [merge1, merge2]:
                proposed_merge_split.append((merge1, merge2, proposed_split[2]))
            else:
                proposed_merge_split.append((merge1, merge2, proposed_split[1]))
        else:
            proposed_merge_split.append((merge1, merge2, proposed_split[0]))
    return proposed_merge_split
def maximization(self, data, cond_muh, weights=None, eps=1e-6):
    self.muh = utilities.average(cond_muh, weights=weights)
    self.cum_muh = np.cumsum(self.muh)
    self.gh = np.log(self.muh)
    self.gh -= self.gh.mean()
    if self.nature == 'Bernoulli':
        self.cond_muv = utilities.average_product(
            cond_muh, data, mean1=True, weights=weights) / self.muh[:, np.newaxis]
        self.weights = np.log((self.cond_muv + eps) / (1 - self.cond_muv + eps))
    elif self.nature == 'Spin':
        self.cond_muv = utilities.average_product(
            cond_muh, data, mean1=True, weights=weights) / self.muh[:, np.newaxis]
        self.weights = 0.5 * np.log((1 + self.cond_muv + eps) / (1 - self.cond_muv + eps))
    elif self.nature == 'Potts':
        self.cond_muv = utilities.average_product(
            cond_muh, data, c2=self.n_c, mean1=True,
            weights=weights) / self.muh[:, np.newaxis, np.newaxis]
        self.cum_cond_muv = np.cumsum(self.cond_muv, axis=-1)
        self.weights = np.log(self.cond_muv + eps)
        self.weights -= self.weights.mean(-1)[:, :, np.newaxis]
    self.logpartition()
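# Hedged stand-alone sketch of the Bernoulli branch of `maximization` above, written with
# plain numpy and no class state: muh_m = <p(m|v)>, cond_muv_{m,i} = <p(m|v) v_i> / muh_m,
# and weights_{m,i} = logit(cond_muv_{m,i}). The toy shapes and random responsibilities
# are assumptions for illustration only.
def _example_bernoulli_m_step():
    import numpy as np
    rng = np.random.RandomState(0)
    B, M, N = 200, 4, 30
    data = (rng.rand(B, N) > 0.5).astype(float)       # Bernoulli samples, shape (B, N)
    cond_muh = rng.dirichlet(np.ones(M), size=B)      # responsibilities p(m|v), shape (B, M)
    eps = 1e-6
    muh = cond_muh.mean(0)                                      # mixture weights
    cond_muv = cond_muh.T.dot(data) / (B * muh[:, None])        # conditional means <v_i | m>
    weights = np.log((cond_muv + eps) / (1 - cond_muv + eps))   # logit parametrization
    return muh, cond_muv, weights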
def minibatch_fit(self, data, weights=None, eps=1e-5, update=True):
    h = self.expectation(data)
    self.muh = self.learning_rate * utilities.average(h, weights=weights) + \
        (1 - self.learning_rate) * self.muh
    self.cum_muh = np.cumsum(self.muh)
    if update:
        self.gh = np.log(self.muh + eps)
        self.gh -= self.gh.mean()
    if self.nature == 'Bernoulli':
        self.muvh = self.learning_rate * utilities.average_product(
            h, data, weights=weights) + (1 - self.learning_rate) * self.muvh
        if update:
            self.cond_muv = self.muvh / self.muh[:, np.newaxis]
            self.weights = np.log((self.cond_muv + eps) / (1 - self.cond_muv + eps))
    elif self.nature == 'Spin':
        self.muvh = self.learning_rate * utilities.average_product(
            h, data, weights=weights) + (1 - self.learning_rate) * self.muvh
        if update:
            self.cond_muv = self.muvh / self.muh[:, np.newaxis]
            self.weights = 0.5 * np.log((1 + self.cond_muv + eps) / (1 - self.cond_muv + eps))
    else:
        self.muvh = self.learning_rate * utilities.average_product(
            h, data, c2=self.n_c, weights=weights) + (1 - self.learning_rate) * self.muvh
        if update:
            self.cond_muv = self.muvh / self.muh[:, np.newaxis, np.newaxis]
            self.weights = np.log(self.cond_muv + eps)
            self.weights -= self.weights.mean(-1)[:, :, np.newaxis]
    if update:
        self.logpartition()
def calculate_error(RBM, data_tr, N_sequences=800000, Nstep=10, background=None):
    N = RBM.n_v
    q = RBM.n_cv
    # Check how the moments are reproduced.
    # Means.
    mudata = RBM.mu_data  # empirical averages
    # datav, datah = RBM.gen_data(Nchains=int(100), Lchains=int(N_sequences / 100),
    #                             Nthermalize=int(500), background=background)
    datav, datah = RBM.gen_data(Nchains=int(100), Lchains=int(N_sequences / 100),
                                Nthermalize=int(500))
    mugen = utilities.average(datav, c=q, weights=None)
    # Correlations.
    covgen = utilities.average_product(datav, datav, c1=q, c2=q) - \
        mugen[:, np.newaxis, :, np.newaxis] * mugen[np.newaxis, :, np.newaxis, :]
    covdata = utilities.average_product(data_tr, data_tr, c1=q, c2=q) - \
        mudata[:, np.newaxis, :, np.newaxis] * mudata[np.newaxis, :, np.newaxis, :]
    fdata = utilities.average_product(data_tr, data_tr, c1=q, c2=q)
    # Zero out the diagonal blocks of the covariances.
    for i in range(N):
        covdata[i, i, :, :] = np.zeros((q, q))
        fdata[i, i, :, :] = np.zeros((q, q))
        covgen[i, i, :, :] = np.zeros((q, q))
    M = len(data_tr)
    maxp = float(1) / float(M)
    ps = 0.00001  # pseudo-count for fully conserved sites

    # Error on frequencies.
    errm = 0
    neffm = 0
    for i in range(N):
        for a in range(q):
            neffm += 1
            if mudata[i, a] < maxp:
                errm += np.power((mugen[i, a] - mudata[i, a]), 2) / \
                    (float(1 - maxp) * float(maxp))
            elif mudata[i, a] != 1.0:
                errm += np.power((mugen[i, a] - mudata[i, a]), 2) / \
                    (float(1 - mudata[i, a]) * float(mudata[i, a]))
            else:
                errm += np.power((mugen[i, a] - mudata[i, a]), 2) / \
                    (float(1 - mudata[i, a] - ps) * float(mudata[i, a]))
    errmt = np.sqrt(float(1) / (float(neffm) * float(maxp)) * float(errm))
    # Rigorously, the regularization term should also enter the difference errm.

    # Error on correlations.
    errc = 0
    neffc = 0
    for i in range(N):
        for j in range(i + 1, N):
            for a in range(q):
                for b in range(a + 1, q):
                    neffc += 1
                    if covdata[i, j, a, b] < maxp:
                        den = np.power(
                            np.sqrt(float(1 - maxp) * float(maxp)) +
                            mudata[i, a] * np.sqrt(mudata[j, b] * (1 - mudata[j, b])) +
                            mudata[j, b] * np.sqrt(mudata[i, a] * (1 - mudata[i, a])), 2)
                    else:
                        den = np.power(
                            np.sqrt(float(1 - fdata[i, j, a, b]) * float(fdata[i, j, a, b])) +
                            mudata[i, a] * np.sqrt(mudata[j, b] * (1 - mudata[j, b])) +
                            mudata[j, b] * np.sqrt(mudata[i, a] * (1 - mudata[i, a])), 2)
                    errc += np.power((covgen[i, j, a, b] - covdata[i, j, a, b]), 2) / float(den)
    errct = np.sqrt(float(1) / (float(neffc) * float(maxp)) * float(errc))
    return (errmt, errct)
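# Hedged, vectorized sketch of the single-site error accumulated loop-wise in
# calculate_error above: each term is (mu_gen - mu_data)^2 / (p * (1 - p)), with
# p = 1/M (the sampling floor, M = number of training sequences) when mu_data < 1/M,
# and p = mu_data otherwise; the total is then normalized by neff * maxp before the
# square root. All shapes and values here are illustrative assumptions.
def _example_frequency_error():
    import numpy as np
    rng = np.random.RandomState(0)
    N, q, M = 20, 21, 1000                       # sites, Potts states, training sequences
    mudata = rng.dirichlet(np.ones(q), size=N)   # empirical single-site frequencies
    mugen = rng.dirichlet(np.ones(q), size=N)    # frequencies of generated samples
    maxp = 1.0 / M
    p = np.where(mudata < maxp, maxp, mudata)
    errm = ((mugen - mudata) ** 2 / (p * (1 - p))).sum()
    return np.sqrt(errm / (N * q * maxp))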
def assess_moment_matching(RBM, data, data_gen, datah_gen=None, weights=None,
                           weights_neg=None, with_reg=True, show=True):
    h_data = RBM.mean_hiddens(data)
    if datah_gen is not None:
        h_gen = datah_gen
    else:
        h_gen = RBM.mean_hiddens(data_gen)
    mu = utilities.average(data, c=RBM.n_cv, weights=weights)
    if datah_gen is not None:
        condmu_gen = RBM.mean_visibles(datah_gen)
        mu_gen = utilities.average(condmu_gen, weights=weights_neg)
    else:
        mu_gen = utilities.average(data_gen, c=RBM.n_cv, weights=weights_neg)
    mu_h = utilities.average(h_data, weights=weights)
    mu_h_gen = utilities.average(h_gen, weights=weights_neg)
    if RBM.n_cv > 1:
        cov_vh = utilities.average_product(h_data, data, c2=RBM.n_cv, weights=weights) - \
            mu[np.newaxis, :, :] * mu_h[:, np.newaxis, np.newaxis]
    else:
        cov_vh = utilities.average_product(h_data, data, c2=RBM.n_cv, weights=weights) - \
            mu[np.newaxis, :] * mu_h[:, np.newaxis]
    if datah_gen is not None:
        if RBM.n_cv > 1:
            cov_vh_gen = utilities.average_product(
                datah_gen, condmu_gen, mean2=True, c2=RBM.n_cv, weights=weights_neg) - \
                mu_gen[np.newaxis, :, :] * mu_h_gen[:, np.newaxis, np.newaxis]
        else:
            cov_vh_gen = utilities.average_product(
                datah_gen, condmu_gen, mean2=True, c2=RBM.n_cv, weights=weights_neg) - \
                mu_gen[np.newaxis, :] * mu_h_gen[:, np.newaxis]
    else:
        if RBM.n_cv > 1:
            cov_vh_gen = utilities.average_product(
                h_gen, data_gen, c2=RBM.n_cv, weights=weights_neg) - \
                mu_gen[np.newaxis, :, :] * mu_h_gen[:, np.newaxis, np.newaxis]
        else:
            cov_vh_gen = utilities.average_product(
                h_gen, data_gen, c2=RBM.n_cv, weights=weights_neg) - \
                mu_gen[np.newaxis, :] * mu_h_gen[:, np.newaxis]

    if RBM.hidden == 'dReLU':
        I_data = RBM.vlayer.compute_output(data, RBM.weights)
        I_gen = RBM.vlayer.compute_output(data_gen, RBM.weights)
        mu_p_pos, mu_n_pos, mu2_p_pos, mu2_n_pos = RBM.hlayer.mean12_pm_from_inputs(I_data)
        mu_p_pos = utilities.average(mu_p_pos, weights=weights)
        mu_n_pos = utilities.average(mu_n_pos, weights=weights)
        mu2_p_pos = utilities.average(mu2_p_pos, weights=weights)
        mu2_n_pos = utilities.average(mu2_n_pos, weights=weights)
        mu_p_neg, mu_n_neg, mu2_p_neg, mu2_n_neg = RBM.hlayer.mean12_pm_from_inputs(I_gen)
        mu_p_neg = utilities.average(mu_p_neg, weights=weights_neg)
        mu_n_neg = utilities.average(mu_n_neg, weights=weights_neg)
        mu2_p_neg = utilities.average(mu2_p_neg, weights=weights_neg)
        mu2_n_neg = utilities.average(mu2_n_neg, weights=weights_neg)
        a = RBM.hlayer.gamma
        eta = RBM.hlayer.eta
        theta = RBM.hlayer.delta
        moment_theta = -mu_p_pos / np.sqrt(1 + eta) + mu_n_pos / np.sqrt(1 - eta)
        moment_theta_gen = -mu_p_neg / np.sqrt(1 + eta) + mu_n_neg / np.sqrt(1 - eta)
        moment_eta = 0.5 * a / (1 + eta)**2 * mu2_p_pos - \
            0.5 * a / (1 - eta)**2 * mu2_n_pos + \
            theta / (2 * np.sqrt(1 + eta)**3) * mu_p_pos - \
            theta / (2 * np.sqrt(1 - eta)**3) * mu_n_pos
        moment_eta_gen = 0.5 * a / (1 + eta)**2 * mu2_p_neg - \
            0.5 * a / (1 - eta)**2 * mu2_n_neg + \
            theta / (2 * np.sqrt(1 + eta)**3) * mu_p_neg - \
            theta / (2 * np.sqrt(1 - eta)**3) * mu_n_neg
        moment_theta *= -1
        moment_theta_gen *= -1
        moment_eta *= -1
        moment_eta_gen *= -1

    W = RBM.weights
    if with_reg:
        l2 = RBM.l2
        l1 = RBM.l1
        l1b = RBM.l1b
        l1c = RBM.l1c
        l1_custom = RBM.l1_custom
        l1b_custom = RBM.l1b_custom
        n_c2 = RBM.n_cv
        if l2 > 0:
            cov_vh_gen += l2 * W
        if l1 > 0:
            cov_vh_gen += l1 * np.sign(W)
        if l1b > 0:
            if n_c2 > 1:  # Potts RBM.
                cov_vh_gen += l1b * np.sign(W) * \
                    np.abs(W).mean(-1).mean(-1)[:, np.newaxis, np.newaxis]
            else:
                cov_vh_gen += l1b * np.sign(W) * (np.abs(W).sum(1))[:, np.newaxis]
        if l1c > 0:  # NOT SUPPORTED FOR POTTS
            cov_vh_gen += l1c * np.sign(W) * ((np.abs(W).sum(1))**2)[:, np.newaxis]
        if any([l1 > 0, l1b > 0, l1c > 0]):
            mask_cov = np.abs(W) > 1e-3
        else:
            mask_cov = np.ones(W.shape, dtype='bool')
    else:
        mask_cov = np.ones(W.shape, dtype='bool')

    if RBM.n_cv > 1:
        if RBM.n_cv == 21:
            list_aa = Proteins_utils.aa
        else:
            list_aa = Proteins_utils.aa[:-1]
        colors_template = np.array([
            matplotlib.colors.to_rgba(aa_color_scatter(letter)) for letter in list_aa])
        color = np.repeat(colors_template[np.newaxis, :, :], data.shape[1],
                          axis=0).reshape([data.shape[1] * RBM.n_cv, 4])
    else:
        color = 'C0'

    s2 = 14
    if RBM.hidden == 'dReLU':
        fig, ax = plt.subplots(3, 2)
        fig.set_figheight(3 * 5)
        fig.set_figwidth(2 * 5)
    else:
        fig, ax = plt.subplots(2, 2)
        fig.set_figheight(2 * 5)
        fig.set_figwidth(2 * 5)
    clean_ax(ax[1, 1])

    ax_ = ax[0, 0]
    ax_.scatter(mu.flatten(), mu_gen.flatten(), c=color)
    ax_.plot([mu.min(), mu.max()], [mu.min(), mu.max()])
    ax_.set_xlabel(r'$<v_i>_d$', fontsize=s2)
    ax_.set_ylabel(r'$<v_i>_m$', fontsize=s2)
    r2_mu = np.corrcoef(mu.flatten(), mu_gen.flatten())[0, 1]**2
    error_mu = np.sqrt(((mu - mu_gen)**2 / (mu * (1 - mu) + 1e-4)).mean())
    mini = mu.min()
    maxi = mu.max()
    ax_.text(0.6 * maxi + 0.4 * mini, 0.25 * maxi + 0.75 * mini,
             r'$R^2 = %.2f$' % r2_mu, fontsize=s2)
    ax_.text(0.6 * maxi + 0.4 * mini, 0.35 * maxi + 0.65 * mini,
             r'$Err = %.2e$' % error_mu, fontsize=s2)
    ax_.set_title('Mean visibles', fontsize=s2)

    ax_ = ax[0, 1]
    ax_.scatter(mu_h, mu_h_gen)
    ax_.plot([mu_h.min(), mu_h.max()], [mu_h.min(), mu_h.max()])
    ax_.set_xlabel(r'$<h_\mu>_d$', fontsize=s2)
    ax_.set_ylabel(r'$<h_\mu>_m$', fontsize=s2)
    r2_muh = np.corrcoef(mu_h, mu_h_gen)[0, 1]**2
    error_muh = np.sqrt(((mu_h - mu_h_gen)**2).mean())
    mini = mu_h.min()
    maxi = mu_h.max()
    ax_.text(0.6 * maxi + 0.4 * mini, 0.25 * maxi + 0.75 * mini,
             r'$R^2 = %.2f$' % r2_muh, fontsize=s2)
    ax_.text(0.6 * maxi + 0.4 * mini, 0.35 * maxi + 0.65 * mini,
             r'$Err = %.2e$' % error_muh, fontsize=s2)
    ax_.set_title('Mean hiddens', fontsize=s2)

    ax_ = ax[1, 0]
    if RBM.n_cv > 1:
        color = np.repeat(np.repeat(colors_template[np.newaxis, np.newaxis, :, :],
                                    RBM.n_h, axis=0), data.shape[1],
                          axis=1).reshape([RBM.n_v * RBM.n_h * RBM.n_cv, 4])
        color = color[mask_cov.flatten()]
    else:
        color = 'C0'
    cov_vh = cov_vh[mask_cov].flatten()
    cov_vh_gen = cov_vh_gen[mask_cov].flatten()
    ax_.scatter(cov_vh, cov_vh_gen, c=color)
    ax_.plot([cov_vh.min(), cov_vh.max()], [cov_vh.min(), cov_vh.max()])
    ax_.set_xlabel(r'Cov$(v_i \;, h_\mu)_d$', fontsize=s2)
    ax_.set_ylabel(r'Cov$(v_i \;, h_\mu)_m + \nabla_{w_{\mu i}} \mathcal{R}$', fontsize=s2)
    r2_vh = np.corrcoef(cov_vh, cov_vh_gen)[0, 1]**2
    error_vh = np.sqrt(((cov_vh - cov_vh_gen)**2).mean())
    mini = cov_vh.min()
    maxi = cov_vh.max()
    ax_.text(0.6 * maxi + 0.4 * mini, 0.25 * maxi + 0.75 * mini,
             r'$R^2 = %.2f$' % r2_vh, fontsize=s2)
    ax_.text(0.6 * maxi + 0.4 * mini, 0.35 * maxi + 0.65 * mini,
             r'$Err = %.2e$' % error_vh, fontsize=s2)
    ax_.set_title('Hiddens-Visibles correlations', fontsize=s2)

    if RBM.hidden == 'dReLU':
        ax_ = ax[2, 0]
        ax_.scatter(moment_theta, moment_theta_gen, c=theta)
        ax_.plot([moment_theta.min(), moment_theta.max()],
                 [moment_theta.min(), moment_theta.max()])
        ax_.set_xlabel(r'$<-\frac{\partial E}{\partial \theta}>_d$', fontsize=s2)
        ax_.set_ylabel(r'$<-\frac{\partial E}{\partial \theta}>_m$', fontsize=s2)
        r2_theta = np.corrcoef(moment_theta, moment_theta_gen)[0, 1]**2
        error_theta = np.sqrt(((moment_theta - moment_theta_gen)**2).mean())
        mini = moment_theta.min()
        maxi = moment_theta.max()
        ax_.text(0.6 * maxi + 0.4 * mini, 0.25 * maxi + 0.75 * mini,
                 r'$R^2 = %.2f$' % r2_theta, fontsize=s2)
        ax_.text(0.6 * maxi + 0.4 * mini, 0.35 * maxi + 0.65 * mini,
                 r'$Err = %.2e$' % error_theta, fontsize=s2)
        ax_.set_title('Moment theta', fontsize=s2)

        ax_ = ax[2, 1]
        ax_.scatter(moment_eta, moment_eta_gen, c=np.abs(eta))
        ax_.plot([moment_eta.min(), moment_eta.max()],
                 [moment_eta.min(), moment_eta.max()])
        ax_.set_xlabel(r'$<-\frac{\partial E}{\partial \eta}>_d$', fontsize=s2)
        ax_.set_ylabel(r'$<-\frac{\partial E}{\partial \eta}>_m$', fontsize=s2)
        r2_eta = np.corrcoef(moment_eta, moment_eta_gen)[0, 1]**2
        error_eta = np.sqrt(((moment_eta - moment_eta_gen)**2).mean())
        mini = moment_eta.min()
        maxi = moment_eta.max()
        ax_.text(0.6 * maxi + 0.4 * mini, 0.25 * maxi + 0.75 * mini,
                 r'$R^2 = %.2f$' % r2_eta, fontsize=s2)
        ax_.text(0.6 * maxi + 0.4 * mini, 0.35 * maxi + 0.65 * mini,
                 r'$Err = %.2e$' % error_eta, fontsize=s2)
        ax_.set_title('Moment eta', fontsize=s2)

    plt.tight_layout()
    if show:
        fig.show()
    if RBM.hidden == 'dReLU':
        errors = [error_mu, error_muh, error_vh, error_theta, error_eta]
        r2s = [r2_mu, r2_muh, r2_vh, r2_theta, r2_eta]
    else:
        errors = [error_mu, error_muh, error_vh]
        r2s = [r2_mu, r2_muh, r2_vh]
    return fig, errors, r2s
def get_cross_derivatives_ReLU(V_pos, psi_pos, hlayer, n_cv, weights=None):
    db_dw = average(V_pos, c=n_cv, weights=weights)
    a = hlayer.gamma[np.newaxis, :]
    theta = hlayer.delta[np.newaxis, :]
    b = hlayer.theta[np.newaxis, :]
    psi = psi_pos
    psi_plus = (-(psi - b) + theta) / np.sqrt(a)
    psi_minus = ((psi - b) + theta) / np.sqrt(a)
    Phi_plus = erf_times_gauss(psi_plus)
    Phi_minus = erf_times_gauss(psi_minus)
    p_plus = 1 / (1 + Phi_minus / Phi_plus)
    p_minus = 1 - p_plus
    e = (psi - b) - theta * (p_plus - p_minus)
    v = p_plus * p_minus * (2 * theta / np.sqrt(a)) * \
        (2 * theta / np.sqrt(a) - 1 / Phi_plus - 1 / Phi_minus)
    dpsi_plus_dpsi = -1 / np.sqrt(a)
    dpsi_minus_dpsi = 1 / np.sqrt(a)
    dpsi_plus_dtheta = 1 / np.sqrt(a)
    dpsi_minus_dtheta = 1 / np.sqrt(a)
    dpsi_plus_da = -1.0 / (2 * a) * psi_plus
    dpsi_minus_da = -1.0 / (2 * a) * psi_minus
    d2psi_plus_dadpsi = 0.5 / np.sqrt(a**3)
    d2psi_plus_dthetadpsi = 0
    d2psi_minus_dadpsi = -0.5 / np.sqrt(a**3)
    d2psi_minus_dthetadpsi = 0
    dp_plus_dpsi = p_plus * p_minus * \
        ((psi_plus - 1 / Phi_plus) * dpsi_plus_dpsi -
         (psi_minus - 1 / Phi_minus) * dpsi_minus_dpsi)
    dp_plus_dtheta = p_plus * p_minus * \
        ((psi_plus - 1 / Phi_plus) * dpsi_plus_dtheta -
         (psi_minus - 1 / Phi_minus) * dpsi_minus_dtheta)
    dp_plus_da = p_plus * p_minus * \
        ((psi_plus - 1 / Phi_plus) * dpsi_plus_da -
         (psi_minus - 1 / Phi_minus) * dpsi_minus_da)
    d2p_plus_dpsi2 = -(p_plus - p_minus) * p_plus * p_minus * \
        ((psi_plus - 1 / Phi_plus) * dpsi_plus_dpsi -
         (psi_minus - 1 / Phi_minus) * dpsi_minus_dpsi)**2 \
        + p_plus * p_minus * \
        ((dpsi_plus_dpsi)**2 * (1 + (psi_plus - 1 / Phi_plus) / Phi_plus) -
         (dpsi_minus_dpsi)**2 * (1 + (psi_minus - 1 / Phi_minus) / Phi_minus))
    d2p_plus_dadpsi = -(p_plus - p_minus) * \
        ((psi_plus - 1 / Phi_plus) * dpsi_plus_dpsi -
         (psi_minus - 1 / Phi_minus) * dpsi_minus_dpsi) * (dp_plus_da) \
        + p_plus * p_minus * \
        ((dpsi_plus_dpsi * dpsi_plus_da) * (1 + (psi_plus - 1 / Phi_plus) / Phi_plus) -
         (dpsi_minus_dpsi * dpsi_minus_da) * (1 + (psi_minus - 1 / Phi_minus) / Phi_minus) +
         (d2psi_plus_dadpsi) * (psi_plus - 1 / Phi_plus) -
         (d2psi_minus_dadpsi) * (psi_minus - 1 / Phi_minus))
    d2p_plus_dthetadpsi = -(p_plus - p_minus) * \
        ((psi_plus - 1 / Phi_plus) * dpsi_plus_dpsi -
         (psi_minus - 1 / Phi_minus) * dpsi_minus_dpsi) * (dp_plus_dtheta) \
        + p_plus * p_minus * \
        ((dpsi_plus_dpsi * dpsi_plus_dtheta) * (1 + (psi_plus - 1 / Phi_plus) / Phi_plus) -
         (dpsi_minus_dpsi * dpsi_minus_dtheta) * (1 + (psi_minus - 1 / Phi_minus) / Phi_minus) +
         (d2psi_plus_dthetadpsi) * (psi_plus - 1 / Phi_plus) -
         (d2psi_minus_dthetadpsi) * (psi_minus - 1 / Phi_minus))
    # dlogZ_dpsi = (p_plus * (psi_plus - 1 / Phi_plus) * dpsi_plus_dpsi +
    #               p_minus * (psi_minus - 1 / Phi_minus) * dpsi_minus_dpsi)
    # dlogZ_dtheta = (p_plus * (psi_plus - 1 / Phi_plus) * dpsi_plus_dtheta +
    #                 p_minus * (psi_minus - 1 / Phi_minus) * dpsi_minus_dtheta)
    # dlogZ_da = (p_plus * (psi_plus - 1 / Phi_plus) * dpsi_plus_da +
    #             p_minus * (psi_minus - 1 / Phi_minus) * dpsi_minus_da)
    de_dpsi = (1 + v)
    de_db = -de_dpsi
    de_da = 2 * (-theta) * dp_plus_da
    de_dtheta = -(p_plus - p_minus) + 2 * (-theta) * dp_plus_dtheta
    dv_dpsi = 2 * (-theta) * d2p_plus_dpsi2
    dv_db = -dv_dpsi
    dv_da = 2 * (-theta) * d2p_plus_dadpsi
    dv_dtheta = -2 * dp_plus_dpsi + 2 * (-theta) * d2p_plus_dthetadpsi
    var_e = average(e**2, weights=weights) - average(e, weights=weights)**2
    mean_v = average(v, weights=weights)
    dmean_v_da = average(dv_da, weights=weights)
    dmean_v_db = average(dv_db, weights=weights)
    dmean_v_dtheta = average(dv_dtheta, weights=weights)
    dvar_e_da = 2 * (average(e * de_da, weights=weights) -
                     average(e, weights=weights) * average(de_da, weights=weights))
    dvar_e_db = 2 * (average(e * de_db, weights=weights) -
                     average(e, weights=weights) * average(de_db, weights=weights))
    dvar_e_dtheta = 2 * (average(e * de_dtheta, weights=weights) -
                         average(e, weights=weights) * average(de_dtheta, weights=weights))
    tmp = np.sqrt((1 + mean_v)**2 + 4 * var_e)
    da_db = (dvar_e_db + 0.5 * dmean_v_db * (1 + mean_v + tmp)) / \
        (tmp - dvar_e_da - 0.5 * dmean_v_da * (1 + mean_v + tmp))
    da_dtheta = (dvar_e_dtheta + 0.5 * dmean_v_dtheta * (1 + mean_v + tmp)) / \
        (tmp - dvar_e_da - 0.5 * dmean_v_da * (1 + mean_v + tmp))
    dmean_v_dw = average_product(dv_dpsi, V_pos, c1=1, c2=n_cv, weights=weights)
    if n_cv > 1:
        dvar_e_dw = 2 * (
            average_product(e * de_dpsi, V_pos, c1=1, c2=n_cv, weights=weights) -
            average(e, weights=weights)[:, np.newaxis, np.newaxis] *
            average_product(de_dpsi, V_pos, c1=1, c2=n_cv, weights=weights))
        da_dw = (dvar_e_dw + 0.5 * dmean_v_dw *
                 (1 + mean_v + tmp)[:, np.newaxis, np.newaxis]) / \
            (tmp - dvar_e_da - 0.5 * dmean_v_da * (1 + mean_v + tmp))[:, np.newaxis, np.newaxis]
    else:
        dvar_e_dw = 2 * (
            average_product(e * de_dpsi, V_pos, c1=1, c2=1, weights=weights) -
            average(e, weights=weights)[:, np.newaxis] *
            average_product(de_dpsi, V_pos, c1=1, c2=1, weights=weights))
        da_dw = (dvar_e_dw + 0.5 * dmean_v_dw * (1 + mean_v + tmp)[:, np.newaxis]) / \
            (tmp - dvar_e_da - 0.5 * dmean_v_da * (1 + mean_v + tmp))[:, np.newaxis]
    return db_dw, da_db, da_dtheta, da_dw
def get_cross_derivatives_ReLU_plus(V_pos, psi_pos, hlayer, n_cv, weights=None):
    db_dw = average(V_pos, c=n_cv, weights=weights)
    a = hlayer.gamma[np.newaxis, :]
    b = hlayer.theta[np.newaxis, :]
    psi = psi_pos
    psi_plus = -(psi - b) / np.sqrt(a)
    Phi_plus = erf_times_gauss(psi_plus)
    e = (psi - b) + np.sqrt(a) / Phi_plus
    v = (psi_plus - 1 / Phi_plus) / Phi_plus
    dpsi_plus_dpsi = -1 / np.sqrt(a)
    dpsi_plus_da = -1.0 / (2 * a) * psi_plus
    de_dpsi = 1 + v
    de_db = -de_dpsi
    de_da = np.sqrt(a) * (1.0 / (2 * a * Phi_plus) -
                          (psi_plus - 1 / Phi_plus) / Phi_plus * dpsi_plus_da)
    dv_dpsi = dpsi_plus_dpsi * \
        (1 + psi_plus / Phi_plus - 1 / Phi_plus**2 -
         (psi_plus - 1 / Phi_plus)**2) / Phi_plus
    dv_db = -dv_dpsi
    dv_da = dpsi_plus_da * \
        (1 + psi_plus / Phi_plus - 1 / Phi_plus**2 -
         (psi_plus - 1 / Phi_plus)**2) / Phi_plus
    var_e = average(e**2, weights=weights) - average(e, weights=weights)**2
    mean_v = average(v, weights=weights)
    dmean_v_da = average(dv_da, weights=weights)
    dmean_v_db = average(dv_db, weights=weights)
    dvar_e_da = 2 * (average(e * de_da, weights=weights) -
                     average(e, weights=weights) * average(de_da, weights=weights))
    dvar_e_db = 2 * (average(e * de_db, weights=weights) -
                     average(e, weights=weights) * average(de_db, weights=weights))
    tmp = np.sqrt((1 + mean_v)**2 + 4 * var_e)
    denominator = (tmp - dvar_e_da - 0.5 * dmean_v_da * (1 + mean_v + tmp))
    denominator = np.maximum(denominator, 0.5)  # For numerical stability.
    da_db = (dvar_e_db + 0.5 * dmean_v_db * (1 + mean_v + tmp)) / denominator
    dmean_v_dw = average_product(dv_dpsi, V_pos, c1=1, c2=n_cv, weights=weights)
    if n_cv > 1:
        dvar_e_dw = 2 * (
            average_product(e * de_dpsi, V_pos, c1=1, c2=n_cv, weights=weights) -
            average(e, weights=weights)[:, np.newaxis, np.newaxis] *
            average_product(de_dpsi, V_pos, c1=1, c2=n_cv, weights=weights))
        da_dw = (dvar_e_dw + 0.5 * dmean_v_dw *
                 (1 + mean_v + tmp)[:, np.newaxis, np.newaxis]) / \
            denominator[:, np.newaxis, np.newaxis]
    else:
        dvar_e_dw = 2 * (
            average_product(e * de_dpsi, V_pos, c1=1, c2=1, weights=weights) -
            average(e, weights=weights)[:, np.newaxis] *
            average_product(de_dpsi, V_pos, c1=1, c2=1, weights=weights))
        da_dw = (dvar_e_dw + 0.5 * dmean_v_dw * (1 + mean_v + tmp)[:, np.newaxis]) / \
            denominator[:, np.newaxis]
    return db_dw, da_db, da_dw
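# Hedged usage sketch for get_cross_derivatives_ReLU_plus. It assumes it runs inside
# this module (average, average_product and erf_times_gauss are used internally) and it
# replaces the package's ReLU+ hidden layer by a SimpleNamespace stub exposing only the
# two attributes the function reads (gamma, theta); all shapes are illustrative
# assumptions, not the package's real API.
def _example_relu_plus_cross_derivatives():
    import numpy as np
    from types import SimpleNamespace
    rng = np.random.RandomState(0)
    B, N, n_h = 64, 30, 8
    V_pos = (rng.rand(B, N) > 0.5).astype(float)     # Bernoulli visible batch (n_cv = 1)
    psi_pos = rng.randn(B, n_h)                      # hidden-layer inputs psi = W v
    hlayer_stub = SimpleNamespace(gamma=np.ones(n_h), theta=np.zeros(n_h))
    db_dw, da_db, da_dw = get_cross_derivatives_ReLU_plus(
        V_pos, psi_pos, hlayer_stub, n_cv=1)
    return db_dw.shape, da_db.shape, da_dw.shape     # (N,), (n_h,), (n_h, N)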
def minibatch_fit(self, X, Y, weights=None):
    self.count_updates += 1
    grad = {}
    I = self.input_layer.compute_output(X, self.weights, direction='down')
    prediction = self.output_layer.mean_from_inputs(I)
    grad['weights'] = self.moments_XY - utilities.average_product(
        X, prediction, c1=self.n_cin, c2=self.n_cout, mean2=True, weights=weights)
    if self.nature in ['Bernoulli', 'Spin', 'Potts']:
        grad['output_layer'] = self.output_layer.internal_gradients(
            self.moments_Y, prediction, value='moments', value_neg='mean',
            weights_neg=weights)
    else:
        grad['output_layer'] = self.output_layer.internal_gradients(
            self.moments_Y, I, value='moments', value_neg='input',
            weights_neg=weights)
    for regtype, regtarget, regvalue in self.regularizers:
        if regtarget == 'weights':
            target_gradient = grad['weights']
            target = self.weights
        else:
            target_gradient = grad['output_layer'][regtarget]
            target = self.output_layer.__dict__[regtarget]
        if regtype == 'l1':
            target_gradient -= regvalue * np.sign(target)
        elif regtype == 'l2':
            target_gradient -= regvalue * target
        else:
            print(regtype, 'not supported')
    for key, gradient in grad['output_layer'].items():
        if self.output_layer.do_grad_updates[key]:
            if self.optimizer == 'SGD':
                self.output_layer.__dict__[key] += self.learning_rate * gradient
            elif self.optimizer == 'ADAM':
                self.gradient_moment1['output_layer'][key] *= self.beta1
                self.gradient_moment1['output_layer'][key] += (1 - self.beta1) * gradient
                self.gradient_moment2['output_layer'][key] *= self.beta2
                self.gradient_moment2['output_layer'][key] += (1 - self.beta2) * gradient**2
                self.output_layer.__dict__[key] += self.learning_rate / (1 - self.beta1) * \
                    (self.gradient_moment1['output_layer'][key] /
                     (1 - self.beta1**self.count_updates)) / \
                    (self.epsilon + np.sqrt(self.gradient_moment2['output_layer'][key] /
                                            (1 - self.beta2**self.count_updates)))
    if self.optimizer == 'SGD':
        self.weights += self.learning_rate * grad['weights']
    elif self.optimizer == 'ADAM':
        self.gradient_moment1['weights'] *= self.beta1
        self.gradient_moment1['weights'] += (1 - self.beta1) * grad['weights']
        self.gradient_moment2['weights'] *= self.beta2
        self.gradient_moment2['weights'] += (1 - self.beta2) * grad['weights']**2
        self.weights += self.learning_rate / (1 - self.beta1) * \
            (self.gradient_moment1['weights'] / (1 - self.beta1**self.count_updates)) / \
            (self.epsilon + np.sqrt(self.gradient_moment2['weights'] /
                                    (1 - self.beta2**self.count_updates)))
    if self.symmetric:
        if self.n_cout > 1:
            self.weights += np.swapaxes(np.swapaxes(self.weights, 0, 1), 2, 3)
            self.weights /= 2
        else:
            self.weights += self.weights.T
            self.weights /= 2
    if self.zero_diag:
        self.weights[np.arange(self.Nout), np.arange(self.Nout)] *= 0
    return
def fit(self, X, Y, weights=None, batch_size=100, learning_rate=None, lr_final=None,
        lr_decay=True, decay_after=0.5, extra_params=None, optimizer='ADAM', n_iter=10,
        verbose=1, regularizers=[]):
    self.batch_size = batch_size
    self.optimizer = optimizer
    self.n_iter = n_iter
    if self.n_iter <= 1:
        lr_decay = False
    if learning_rate is None:
        if self.optimizer == 'SGD':
            learning_rate = 0.01
        elif self.optimizer == 'ADAM':
            learning_rate = 5e-4
        else:
            print('Need to specify learning rate for optimizer.')
    if self.optimizer == 'ADAM':
        if extra_params is None:
            extra_params = [0.9, 0.99, 1e-3]
        self.beta1 = extra_params[0]
        self.beta2 = extra_params[1]
        self.epsilon = extra_params[2]
    if self.n_cout > 1:
        out0 = np.zeros([1, self.Nout, self.n_cout], dtype=curr_float)
    else:
        out0 = np.zeros([1, self.Nout], dtype=curr_float)
    grad = {
        'weights': np.zeros_like(self.weights),
        'output_layer': self.output_layer.internal_gradients(
            out0, out0, value='input', value_neg='input')
    }
    for key in grad['output_layer'].keys():
        grad['output_layer'][key] *= 0
    self.gradient_moment1 = copy.deepcopy(grad)
    self.gradient_moment2 = copy.deepcopy(grad)
    self.learning_rate_init = copy.copy(learning_rate)
    self.learning_rate = learning_rate
    self.lr_decay = lr_decay
    if self.lr_decay:
        self.decay_after = decay_after
        self.start_decay = int(self.n_iter * self.decay_after)
        if lr_final is None:
            self.lr_final = 1e-2 * self.learning_rate
        else:
            self.lr_final = lr_final
        self.decay_gamma = (float(self.lr_final) / float(self.learning_rate))**(
            1 / float(self.n_iter * (1 - self.decay_after)))
    else:
        self.decay_gamma = 1
    self.regularizers = regularizers
    n_samples = X.shape[0]
    n_batches = int(np.ceil(float(n_samples) / self.batch_size))
    batch_slices = list(utilities.gen_even_slices(n_batches * self.batch_size,
                                                  n_batches, n_samples))
    X = np.asarray(X, dtype=self.input_layer.type, order='c')
    Y = np.asarray(Y, dtype=self.output_layer.type, order='c')
    if weights is not None:
        weights = weights.astype(curr_float)
    self.moments_Y = self.output_layer.get_moments(Y, weights=weights, value='data')
    self.moments_XY = utilities.average_product(X, Y, c1=self.n_cin, c2=self.n_cout,
                                                mean1=False, mean2=False, weights=weights)
    self.count_updates = 0
    for epoch in range(1, n_iter + 1):
        if verbose:
            begin = time.time()
        if self.lr_decay:
            if (epoch > self.start_decay):
                self.learning_rate *= self.decay_gamma
        permutation = np.argsort(np.random.randn(n_samples))
        X = X[permutation, :]
        Y = Y[permutation, :]
        if weights is not None:
            weights = weights[permutation]
        if verbose:
            print('Starting epoch %s' % (epoch))
        for batch_slice in batch_slices:
            if weights is not None:
                self.minibatch_fit(X[batch_slice], Y[batch_slice],
                                   weights=weights[batch_slice])
            else:
                self.minibatch_fit(X[batch_slice], Y[batch_slice], weights=None)
        if verbose:
            end = time.time()
            lik = utilities.average(self.likelihood(X, Y), weights=weights)
            regularization = 0
            for regtype, regtarget, regvalue in self.regularizers:
                if regtarget == 'weights':
                    target = self.weights
                else:
                    target = self.output_layer.__dict__[regtarget]
                if regtype == 'l1':
                    regularization += (regvalue * np.abs(target)).sum()
                elif regtype == 'l2':
                    regularization += 0.5 * (regvalue * target**2).sum()
                else:
                    print(regtype, 'not supported')
            regularization /= self.Nout
            message = ("Iteration %d, time = %.2fs, likelihood = %.2f, "
                       "regularization = %.2e, loss = %.2f" % (
                           epoch, end - begin, lik, regularization,
                           -lik + regularization))
            print(message)
    return 'done'
def fit(self, data, batch_size=100, nchains=100, learning_rate=None, extra_params=None,
        init='independent', optimizer='SGD', N_PT=1, N_MC=1, n_iter=10, lr_decay=True,
        lr_final=None, decay_after=0.5, l1=0, l1b=0, l1c=0, l2=0, l2_fields=0,
        no_fields=False, batch_norm=False, update_betas=None, record_acceptance=None,
        epsilon=1e-6, verbose=1, record=[], record_interval=100, p=[1, 0, 0],
        pseudo_count=0, weights=None):
    self.nchains = nchains
    self.optimizer = optimizer
    self.record_swaps = False
    self.batch_norm = batch_norm
    self.layer.batch_norm = batch_norm
    self.n_iter = n_iter
    if learning_rate is None:
        if self.nature in ['Bernoulli', 'Spin', 'Potts']:
            learning_rate = 0.1
        else:
            learning_rate = 0.01
        if self.optimizer == 'ADAM':
            learning_rate *= 0.1
    self.learning_rate = learning_rate
    self.lr_decay = lr_decay
    if self.lr_decay:
        self.decay_after = decay_after
        self.start_decay = self.n_iter * self.decay_after
        if lr_final is None:
            self.lr_final = 1e-2 * self.learning_rate
        else:
            self.lr_final = lr_final
        self.decay_gamma = (float(self.lr_final) / float(self.learning_rate))**(
            1 / float(self.n_iter * (1 - self.decay_after)))
    self.gradient = self.initialize_gradient_dictionary()
    if self.optimizer == 'momentum':
        if extra_params is None:
            extra_params = 0.9
        self.momentum = extra_params
        self.previous_update = self.initialize_gradient_dictionary()
    elif self.optimizer == 'ADAM':
        if extra_params is None:
            extra_params = [0.9, 0.999, 1e-8]
        self.beta1 = extra_params[0]
        self.beta2 = extra_params[1]
        self.epsilon = extra_params[2]
        self.gradient_moment1 = self.initialize_gradient_dictionary()
        self.gradient_moment2 = self.initialize_gradient_dictionary()
    if weights is not None:
        weights = np.asarray(weights, dtype=float)
    mean = utilities.average(data, c=self.n_c, weights=weights)
    covariance = utilities.average_product(data, data, c1=self.n_c, c2=self.n_c,
                                           weights=weights)
    if pseudo_count > 0:
        p = data.shape[0] / float(data.shape[0] + pseudo_count)
        covariance = p**2 * covariance + p * (1 - p) * (
            mean[np.newaxis, :, np.newaxis, :] *
            mean[:, np.newaxis, :, np.newaxis]) / self.n_c + (1 - p)**2 / self.n_c**2
        mean = p * mean + (1 - p) / self.n_c
    iter_per_epoch = data.shape[0] // batch_size
    if init != 'previous':
        norm_init = 0
        self.init_couplings(norm_init)
        if init == 'independent':
            self.layer.init_params_from_data(data, eps=epsilon, value='data')
    self.N_PT = N_PT
    self.N_MC = N_MC
    self.l1 = l1
    self.l1b = l1b
    self.l1c = l1c
    self.l2 = l2
    self.tmp_l2_fields = l2_fields
    self.no_fields = no_fields
    if self.N_PT > 1:
        if record_acceptance is None:
            record_acceptance = True
        self.record_acceptance = record_acceptance
        if update_betas is None:
            update_betas = True
        self._update_betas = update_betas
        if self.record_acceptance:
            self.mavar_gamma = 0.95
            self.acceptance_rates = np.zeros(N_PT - 1)
            self.mav_acceptance_rates = np.zeros(N_PT - 1)
        self.count_swaps = 0
        if self._update_betas:
            record_acceptance = True
            self.update_betas_lr = 0.1
            self.update_betas_lr_decay = 1
        if self._update_betas | (not hasattr(self, 'betas')):
            self.betas = np.arange(N_PT) / float(N_PT - 1)
            self.betas = self.betas[::-1]
        if (len(self.betas) != N_PT):
            self.betas = np.arange(N_PT) / float(N_PT - 1)
            self.betas = self.betas[::-1]
    if self.nature == 'Potts':
        (self.fantasy_x, self.fantasy_fields_eff) = self.layer.sample_from_inputs(
            np.zeros([self.N_PT * self.nchains, self.N, self.n_c]), beta=0)
    else:
        (self.fantasy_x, self.fantasy_fields_eff) = self.layer.sample_from_inputs(
            np.zeros([self.N_PT * self.nchains, self.N]), beta=0)
    if self.N_PT > 1:
        self.fantasy_x = self.fantasy_x.reshape([self.N_PT, self.nchains, self.N])
        if self.nature == 'Potts':
            self.fantasy_fields_eff = self.fantasy_fields_eff.reshape(
                [self.N_PT, self.nchains, self.N, self.n_c])
        else:
            self.fantasy_fields_eff = self.fantasy_fields_eff.reshape(
                [self.N_PT, self.nchains, self.N])
        self.fantasy_E = np.zeros([self.N_PT, self.nchains])
    self.count_updates = 0
    if verbose:
        if weights is not None:
            lik = (self.pseudo_likelihood(data) * weights).sum() / weights.sum()
        else:
            lik = self.pseudo_likelihood(data).mean()
        print('Iteration number 0, pseudo-likelihood: %.2f' % lik)
    result = {}
    if 'J' in record:
        result['J'] = []
    if 'F' in record:
        result['F'] = []
    count = 0
    for epoch in range(1, n_iter + 1):
        if verbose:
            begin = time.time()
        if self.lr_decay:
            if (epoch > self.start_decay):
                self.learning_rate *= self.decay_gamma
        print('Starting epoch %s' % (epoch))
        for _ in range(iter_per_epoch):
            self.minibatch_fit(mean, covariance)
            if (count % record_interval == 0):
                if 'J' in record:
                    result['J'].append(self.layer.couplings.copy())
                if 'F' in record:
                    result['F'].append(self.layer.fields.copy())
            count += 1
        if verbose:
            end = time.time()
            if weights is not None:
                lik = (self.pseudo_likelihood(data) * weights).sum() / weights.sum()
            else:
                lik = self.pseudo_likelihood(data).mean()
            print("[%s] Iteration %d, pseudo-likelihood = %.2f,"
                  " time = %.2fs" % (type(self).__name__, epoch, lik, end - begin))
    return result
def fit(self, data, weights=None, pseudo_count=1e-4, verbose=1, zero_diag=True):
    fi = utilities.average(data, c=self.n_c, weights=weights)
    fij = utilities.average_product(data, data, c1=self.n_c, c2=self.n_c, weights=weights)
    for i in range(self.N):
        fij[i, i] = np.diag(fi[i])
    fi_PC = (1 - pseudo_count) * fi + pseudo_count / float(self.n_c)
    fij_PC = (1 - pseudo_count) * fij + pseudo_count / float(self.n_c)**2
    for i in range(self.N):
        fij_PC[i, i] = np.diag(fi_PC[i])
    Cij = fij_PC - fi_PC[np.newaxis, :, np.newaxis, :] * fi_PC[:, np.newaxis, :, np.newaxis]
    D = np.zeros([self.N, self.n_c - 1, self.n_c - 1])
    invD = np.zeros([self.N, self.n_c - 1, self.n_c - 1])
    for n in range(self.N):
        D[n] = scipy.linalg.sqrtm(Cij[n, n, :-1, :-1])
        invD[n] = np.linalg.inv(D[n])
    Gamma = np.zeros([self.N, self.n_c - 1, self.N, self.n_c - 1])
    for n1 in range(self.N):
        for n2 in range(self.N):
            Gamma[n1, :, n2, :] = np.dot(invD[n1], np.dot(Cij[n1, n2, :-1, :-1], invD[n2]))
    Gamma_bin = Gamma.reshape([self.N * (self.n_c - 1), self.N * (self.n_c - 1)])
    Gamma_bin = (Gamma_bin + Gamma_bin.T) / 2
    lam, v = np.linalg.eigh(Gamma_bin)
    order = np.argsort(lam)[::-1]
    v_ordered = np.rollaxis(
        v.reshape([self.N, self.n_c - 1, self.N * (self.n_c - 1)]), 2, 0)[order, :, :]
    lam_ordered = lam[order]
    DeltaL = 0.5 * (lam_ordered - 1 - np.log(lam_ordered))
    xi = np.zeros(v_ordered.shape)
    for n in range(self.N):
        xi[:, n, :] = np.dot(v_ordered[:, n, :], invD[n])
    xi = np.sqrt(np.abs(1 - 1 / lam_ordered))[:, np.newaxis, np.newaxis] * xi
    xi = np.concatenate((xi, np.zeros([self.N * (self.n_c - 1), self.N, 1])), axis=2)
    # Write in zero-sum gauge.
    xi -= xi.mean(-1)[:, :, np.newaxis]
    top_M_contrib = np.argsort(DeltaL)[::-1][:self.M]
    self.xi = xi[top_M_contrib]
    self.lam = lam_ordered[top_M_contrib]
    self.DeltaL = DeltaL[top_M_contrib]
    couplings = np.tensordot(self.xi[self.lam > 1], self.xi[self.lam > 1],
                             axes=[(0), (0)]) - \
        np.tensordot(self.xi[self.lam < 1], self.xi[self.lam < 1], axes=[(0), (0)])
    couplings = np.asarray(np.swapaxes(couplings, 1, 2), order='c')
    if zero_diag:
        # Zeroing the diagonal blocks works better in practice.
        for n in range(self.N):
            couplings[n, n, :, :] *= 0
    fields = np.log(fi_PC) - np.tensordot(couplings, fi_PC, axes=[(1, 3), (0, 1)])
    fields -= fields.mean(-1)[:, np.newaxis]
    self.layer.couplings = couplings
    self.layer.fields = fields
    if verbose:
        fig, ax = plt.subplots()
        ax2 = ax.twinx()
        ax.plot(self.DeltaL)
        ax2.semilogy(self.lam, c='red')
        ax.set_ylabel(r'$\Delta L$', color='blue')
        ax2.set_ylabel('Mode variance', color='red')
        for tl in ax.get_yticklabels():
            tl.set_color('blue')
        for tl in ax2.get_yticklabels():
            tl.set_color('red')
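# Reading guide for fit() above (a summary of the construction, not new functionality):
# after pseudo-counting, the connected correlation C_{ij}(a, b) is whitened site by site,
# Gamma = D^{-1} C D^{-1} with D_n = sqrtm(C_{nn}), and diagonalized. Each eigenvalue
# lambda contributes DeltaL = (lambda - 1 - log lambda) / 2 to the log-likelihood gain;
# the top-M modes are kept as patterns xi, rescaled by sqrt(|1 - 1/lambda|) and written
# in the zero-sum gauge, and the couplings follow the Hopfield-like decomposition
# J = sum_{lambda > 1} xi xi^T - sum_{lambda < 1} xi xi^T, with the fields chosen so that
# the model reproduces the pseudo-counted single-site frequencies at the mean-field level.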
def get_cross_derivatives_dReLU(V_pos, psi_pos, hlayer, n_cv, weights=None):
    # a = 2.0 / (1.0 / hlayer.a_plus + 1.0 / hlayer.a_minus)
    # eta = 0.5 * (a / hlayer.a_plus - a / hlayer.a_minus)
    # theta = (1. - eta**2) / 2. * (hlayer.theta_plus + hlayer.theta_minus)
    # b = (1. + eta) / 2. * hlayer.theta_plus - (1. - eta) / 2. * hlayer.theta_minus
    db_dw = average(V_pos, c=n_cv, weights=weights)
    a = hlayer.a[np.newaxis, :]
    eta = hlayer.eta[np.newaxis, :]
    theta = hlayer.theta[np.newaxis, :]
    b = hlayer.b[np.newaxis, :]
    psi = psi_pos
    psi_plus = (-np.sqrt(1 + eta) * (psi - b) + theta / np.sqrt(1 + eta)) / np.sqrt(a)
    psi_minus = (np.sqrt(1 - eta) * (psi - b) + theta / np.sqrt(1 - eta)) / np.sqrt(a)
    Phi_plus = erf_times_gauss(psi_plus)
    Phi_minus = erf_times_gauss(psi_minus)
    Z = Phi_plus * np.sqrt(1 + eta) + Phi_minus * np.sqrt(1 - eta)
    p_plus = 1 / (1 + (Phi_minus * np.sqrt(1 - eta)) / (Phi_plus * np.sqrt(1 + eta)))
    nans = np.isnan(p_plus)
    p_plus[nans] = 1.0 * (np.abs(psi_plus[nans]) > np.abs(psi_minus[nans]))
    p_minus = 1 - p_plus
    e = (psi - b) * (1 + eta * (p_plus - p_minus)) - theta * (p_plus - p_minus) + \
        2 * eta * np.sqrt(a) / Z
    v = eta * (p_plus - p_minus) + \
        p_plus * p_minus * (2 * theta / np.sqrt(a) - 2 * eta * (psi - b) / np.sqrt(a)) * \
        (2 * theta / np.sqrt(a) - 2 * eta * (psi - b) / np.sqrt(a) -
         np.sqrt(1 + eta) / Phi_plus - np.sqrt(1 - eta) / Phi_minus) - \
        2 * eta * e / (np.sqrt(a) * Z)
    dpsi_plus_dpsi = -np.sqrt((1 + eta) / a)
    dpsi_minus_dpsi = np.sqrt((1 - eta) / a)
    dpsi_plus_dtheta = 1 / np.sqrt(a * (1 + eta))
    dpsi_minus_dtheta = 1 / np.sqrt(a * (1 - eta))
    # dpsi_plus_da = -1.0 / (2 * a) * psi_plus
    # dpsi_minus_da = -1.0 / (2 * a) * psi_minus
    dpsi_plus_deta = -1.0 / (2 * np.sqrt(a * (1 + eta))) * ((psi - b) + theta / (1 + eta))
    dpsi_minus_deta = -1.0 / (2 * np.sqrt(a * (1 - eta))) * ((psi - b) - theta / (1 - eta))
    # d2psi_plus_dadpsi = 0.5 * np.sqrt((1 + eta) / a**3)
    d2psi_plus_dthetadpsi = 0
    d2psi_plus_detadpsi = -0.5 / np.sqrt((1 + eta) * a)
    # d2psi_minus_dadpsi = -0.5 * np.sqrt((1 - eta) / a**3)
    d2psi_minus_dthetadpsi = 0
    d2psi_minus_detadpsi = -0.5 / np.sqrt((1 - eta) * a)
    dp_plus_dpsi = p_plus * p_minus * (
        (psi_plus - 1 / Phi_plus) * dpsi_plus_dpsi -
        (psi_minus - 1 / Phi_minus) * dpsi_minus_dpsi)
    dp_plus_dtheta = p_plus * p_minus * (
        (psi_plus - 1 / Phi_plus) * dpsi_plus_dtheta -
        (psi_minus - 1 / Phi_minus) * dpsi_minus_dtheta)
    # dp_plus_da = p_plus * p_minus * (
    #     (psi_plus - 1 / Phi_plus) * dpsi_plus_da -
    #     (psi_minus - 1 / Phi_minus) * dpsi_minus_da)
    dp_plus_deta = p_plus * p_minus * (
        (psi_plus - 1 / Phi_plus) * dpsi_plus_deta -
        (psi_minus - 1 / Phi_minus) * dpsi_minus_deta + 1 / (1 - eta**2))
    d2p_plus_dpsi2 = -(p_plus - p_minus) * p_plus * p_minus * (
        (psi_plus - 1 / Phi_plus) * dpsi_plus_dpsi -
        (psi_minus - 1 / Phi_minus) * dpsi_minus_dpsi)**2 \
        + p_plus * p_minus * (
            (dpsi_plus_dpsi)**2 * (1 + (psi_plus - 1 / Phi_plus) / Phi_plus) -
            (dpsi_minus_dpsi)**2 * (1 + (psi_minus - 1 / Phi_minus) / Phi_minus))
    # d2p_plus_dadpsi = -(p_plus - p_minus) * (
    #     (psi_plus - 1 / Phi_plus) * dpsi_plus_dpsi -
    #     (psi_minus - 1 / Phi_minus) * dpsi_minus_dpsi) * (dp_plus_da) \
    #     + p_plus * p_minus * (
    #         (dpsi_plus_dpsi * dpsi_plus_da) * (1 + (psi_plus - 1 / Phi_plus) / Phi_plus) -
    #         (dpsi_minus_dpsi * dpsi_minus_da) * (1 + (psi_minus - 1 / Phi_minus) / Phi_minus) +
    #         (d2psi_plus_dadpsi) * (psi_plus - 1 / Phi_plus) -
    #         (d2psi_minus_dadpsi) * (psi_minus - 1 / Phi_minus))
    d2p_plus_dthetadpsi = -(p_plus - p_minus) * (
        (psi_plus - 1 / Phi_plus) * dpsi_plus_dpsi -
        (psi_minus - 1 / Phi_minus) * dpsi_minus_dpsi) * (dp_plus_dtheta) \
        + p_plus * p_minus * (
            (dpsi_plus_dpsi * dpsi_plus_dtheta) * (1 + (psi_plus - 1 / Phi_plus) / Phi_plus) -
            (dpsi_minus_dpsi * dpsi_minus_dtheta) * (1 + (psi_minus - 1 / Phi_minus) / Phi_minus) +
            (d2psi_plus_dthetadpsi) * (psi_plus - 1 / Phi_plus) -
            (d2psi_minus_dthetadpsi) * (psi_minus - 1 / Phi_minus))
    d2p_plus_detadpsi = -(p_plus - p_minus) * (
        (psi_plus - 1 / Phi_plus) * dpsi_plus_dpsi -
        (psi_minus - 1 / Phi_minus) * dpsi_minus_dpsi) * (dp_plus_deta) \
        + p_plus * p_minus * (
            (dpsi_plus_dpsi * dpsi_plus_deta) * (1 + (psi_plus - 1 / Phi_plus) / Phi_plus) -
            (dpsi_minus_dpsi * dpsi_minus_deta) * (1 + (psi_minus - 1 / Phi_minus) / Phi_minus) +
            (d2psi_plus_detadpsi) * (psi_plus - 1 / Phi_plus) -
            (d2psi_minus_detadpsi) * (psi_minus - 1 / Phi_minus))
    dlogZ_dpsi = (p_plus * (psi_plus - 1 / Phi_plus) * dpsi_plus_dpsi +
                  p_minus * (psi_minus - 1 / Phi_minus) * dpsi_minus_dpsi)
    dlogZ_dtheta = (p_plus * (psi_plus - 1 / Phi_plus) * dpsi_plus_dtheta +
                    p_minus * (psi_minus - 1 / Phi_minus) * dpsi_minus_dtheta)
    # dlogZ_da = (p_plus * (psi_plus - 1 / Phi_plus) * dpsi_plus_da +
    #             p_minus * (psi_minus - 1 / Phi_minus) * dpsi_minus_da)
    dlogZ_deta = (p_plus * (psi_plus - 1 / Phi_plus) * dpsi_plus_deta +
                  p_minus * (psi_minus - 1 / Phi_minus) * dpsi_minus_deta +
                  0.5 * (p_plus / (1 + eta) - p_minus / (1 - eta)))
    de_dpsi = (1 + v)
    de_db = -de_dpsi
    # de_da = 2 * ((psi - b) * eta - theta) * dp_plus_da + eta / (Z * np.sqrt(a)) - \
    #     2 * eta * np.sqrt(a) / Z * dlogZ_da
    de_dtheta = -(p_plus - p_minus) + 2 * ((psi - b) * eta - theta) * dp_plus_dtheta - \
        2 * eta * np.sqrt(a) / Z * dlogZ_dtheta
    de_deta = (psi - b) * (p_plus - p_minus) + \
        2 * ((psi - b) * eta - theta) * dp_plus_deta + \
        2 * np.sqrt(a) / Z - 2 * eta * np.sqrt(a) / Z * dlogZ_deta
    dv_dpsi = 4 * eta * dp_plus_dpsi + \
        2 * ((psi - b) * eta - theta) * d2p_plus_dpsi2 - \
        2 * eta / (np.sqrt(a) * Z) * (de_dpsi - e * dlogZ_dpsi)
    dv_db = -dv_dpsi
    # dv_da = eta * 2 * dp_plus_da + \
    #     2 * ((psi - b) * eta - theta) * d2p_plus_dadpsi - \
    #     2 * eta / (Z * np.sqrt(a)) * (-e / (2 * a) - e * dlogZ_da + de_da)
    dv_dtheta = 2 * eta * dp_plus_dtheta - 2 * dp_plus_dpsi + \
        2 * ((psi - b) * eta - theta) * d2p_plus_dthetadpsi - \
        2 * eta / (Z * np.sqrt(a)) * (-e * dlogZ_dtheta + de_dtheta)
    dv_deta = (p_plus - p_minus) + 2 * eta * dp_plus_deta + \
        2 * (psi - b) * dp_plus_dpsi + \
        2 * ((psi - b) * eta - theta) * d2p_plus_detadpsi - \
        2 * 1 / (Z * np.sqrt(a)) * (e - e * eta * dlogZ_deta + eta * de_deta)
    var_e = average(e**2, weights=weights) - average(e, weights=weights)**2
    mean_v = average(v, weights=weights)
    # dmean_v_da = average(dv_da, weights=weights)
    dmean_v_db = average(dv_db, weights=weights)
    dmean_v_dtheta = average(dv_dtheta, weights=weights)
    dmean_v_deta = average(dv_deta, weights=weights)
    # dvar_e_da = 2 * (average(e * de_da, weights=weights) -
    #                  average(e, weights=weights) * average(de_da, weights=weights))
    dvar_e_db = 2 * (average(e * de_db, weights=weights) -
                     average(e, weights=weights) * average(de_db, weights=weights))
    dvar_e_dtheta = 2 * (average(e * de_dtheta, weights=weights) -
                         average(e, weights=weights) * average(de_dtheta, weights=weights))
    dvar_e_deta = 2 * (average(e * de_deta, weights=weights) -
                       average(e, weights=weights) * average(de_deta, weights=weights))
    tmp = np.sqrt((1 + mean_v)**2 + 4 * var_e)
    denominator = tmp
    # denominator = (tmp - dvar_e_da - 0.5 * dmean_v_da * (1 + mean_v + tmp))
    # denominator = np.maximum(denominator, 0.5)  # For numerical stability.
    da_db = (dvar_e_db + 0.5 * dmean_v_db * (1 + mean_v + tmp)) / denominator
    da_dtheta = (dvar_e_dtheta + 0.5 * dmean_v_dtheta * (1 + mean_v + tmp)) / denominator
    da_deta = (dvar_e_deta + 0.5 * dmean_v_deta * (1 + mean_v + tmp)) / denominator
    dmean_v_dw = average_product(dv_dpsi, V_pos, c1=1, c2=n_cv, weights=weights)
    if n_cv > 1:
        dvar_e_dw = 2 * (
            average_product(e * de_dpsi, V_pos, c1=1, c2=n_cv, weights=weights) -
            average(e, weights=weights)[:, np.newaxis, np.newaxis] *
            average_product(de_dpsi, V_pos, c1=1, c2=n_cv, weights=weights))
        da_dw = (dvar_e_dw + 0.5 * dmean_v_dw *
                 (1 + mean_v + tmp)[:, np.newaxis, np.newaxis]) / \
            denominator[:, np.newaxis, np.newaxis]
    else:
        dvar_e_dw = 2 * (
            average_product(e * de_dpsi, V_pos, c1=1, c2=1, weights=weights) -
            average(e, weights=weights)[:, np.newaxis] *
            average_product(de_dpsi, V_pos, c1=1, c2=1, weights=weights))
        da_dw = (dvar_e_dw + 0.5 * dmean_v_dw * (1 + mean_v + tmp)[:, np.newaxis]) / \
            denominator[:, np.newaxis]
    return db_dw, da_db, da_dtheta, da_deta, da_dw