def test_allsort(self):
    """Check the shape of the arrays returned by ``parallel.allsort``.

    Every rank contributes a random ``(10, 2)`` block.  Only the shapes of
    the results are asserted, not their contents.
    """
    n_cols = 2
    n_local_rows = 10
    n_ranks = self.comm.size
    local_block = np.random.uniform(size=(n_local_rows, n_cols))

    # Explicit axis=0: the result is expected to hold the rows of all ranks.
    rows_sorted = parallel.allsort(
        local_block, axis=0, kind='quicksort', comm=self.comm)
    self.assertEqual(rows_sorted.shape, (n_ranks * n_local_rows, n_cols))

    # Default axis: the result is expected to grow along the second axis.
    default_sorted = parallel.allsort(local_block, comm=self.comm)
    self.assertEqual(default_sorted.shape,
                     (n_local_rows, n_cols * n_ranks))
def _get_sorted_data(self, N, anneal, A_pi_gamma, all_denoms, candidates,
                     logpj_all, my_y):
    """Optionally drop the local datapoints with the smallest denominators.

    When ``anneal['Ncut_factor']`` is not greater than zero the inputs are
    handed back untouched together with ``N``.  Otherwise a cut value is
    read from ``parallel.allsort(all_denoms)`` and only the local rows whose
    entry in ``all_denoms`` is strictly greater than that cut are kept.

    Returns
    -------
    tuple
        ``(N_use, my_y, candidates, logpj_all)`` where ``N_use`` is either
        ``N`` (no truncation) or the ``comm.allreduce`` of the number of
        locally kept rows.
    """
    comm = self.comm

    # Guard clause: no truncation requested.
    if not anneal['Ncut_factor'] > 0.0:
        return N, my_y, candidates, logpj_all

    # Number of datapoints the model expects to be able to explain.
    n_target = int(N * (1 - (1 - A_pi_gamma) * anneal['Ncut_factor']))

    # presumably allsort returns the ascending sort over all ranks, making
    # this the n_target-th largest denominator -- TODO confirm
    threshold = parallel.allsort(all_denoms)[-n_target]
    keep = np.array(all_denoms > threshold)

    kept_candidates = candidates[keep]
    kept_logpj = logpj_all[keep]
    kept_y = my_y[keep]

    local_count, _ = kept_y.shape
    return comm.allreduce(local_count), kept_y, kept_candidates, kept_logpj
def M_step(self, anneal, model_params, my_suff_stat, my_data):
    """BSC M-step: re-estimate W, pi, sigma and mu from the E-step results.

    Parameters
    ----------
    anneal
        Annealing schedule.  Keys read here:
        ``anneal['Ncut_factor']`` -- 0.: no truncation; > 0.: keep only the
        datapoints with the largest summed joints;
        ``anneal['data_noise']`` -- if > 0, ``my_data['data_noise']`` is
        added to the datapoints.
        ``anneal.crit_params`` is iterated at the end (see below).
    model_params : dict
        Current parameters under the keys 'W', 'pi', 'sigma' and 'mu'.
    my_suff_stat : dict
        ``my_suff_stat['logpj']``: log-joints, indexed as (my_N, no_states).
        Column 0 is the state without active cause, columns 1..H the
        single-cause states, the remaining columns follow
        ``self.state_matrix``.
    my_data : dict
        ``my_data['y']``: datapoints, (my_N, D);
        ``my_data['candidates']``: candidate H's according to the selection
        function, (my_N, Hprime).

    Returns
    -------
    dict
        Updated parameters under the keys 'W', 'pi', 'sigma' and 'mu'.
        Parameters not listed in ``self.to_learn`` are returned unchanged.
    """
    comm = self.comm
    H = self.H
    gamma = self.gamma
    W = model_params['W'].T
    pies = model_params['pi']
    sigma = model_params['sigma']
    mu = model_params['mu']

    # Read in data (copied because the optional noise is added in place)
    my_y = my_data['y'].copy()
    candidates = my_data['candidates']
    logpj_all = my_suff_stat['logpj']
    all_denoms = np.exp(logpj_all).sum(axis=1)
    my_N, D = my_y.shape
    N = comm.allreduce(my_N)

    # Joerg's data noise idea
    data_noise_scale = anneal['data_noise']
    if data_noise_scale > 0:
        my_y += my_data['data_noise']

    SM = self.state_matrix  # shape: (no_states, Hprime)

    # Precompute factor for pi update
    A_pi_gamma = 0
    B_pi_gamma = 0
    for gamma_p in range(gamma + 1):
        A_pi_gamma += comb(H, gamma_p) * (pies**gamma_p) * (
            (1 - pies)**(H - gamma_p))
        B_pi_gamma += gamma_p * comb(H, gamma_p) * (pies**gamma_p) * (
            (1 - pies)**(H - gamma_p))
    E_pi_gamma = pies * H * A_pi_gamma / B_pi_gamma

    # Truncate data
    if anneal['Ncut_factor'] > 0.0:
        tracing.tracepoint("M_step:truncating")
        N_use = int(N * (1 - (1 - A_pi_gamma) * anneal['Ncut_factor']))
        cut_denom = parallel.allsort(all_denoms)[-N_use]
        which = np.array(all_denoms >= cut_denom)
        candidates = candidates[which]
        logpj_all = logpj_all[which]
        my_y = my_y[which]
        my_N, D = my_y.shape
        N_use = comm.allreduce(my_N)
    else:
        N_use = N
    dlog.append('N', N_use)

    # Joints rescaled by their per-datapoint maximum; exp() of these cannot
    # underflow to an all-zero row the way exp(logpj) can.
    corr_all = logpj_all.max(axis=1)  # shape: (my_N,)
    pjb_all = np.exp(logpj_all - corr_all[:, None])  # (my_N, no_states)

    # Calculate truncated likelihood.  log(sum(exp(logpj))) is evaluated as
    # max + log(sum(exp(logpj - max))) so that strongly negative log-joints
    # do not turn the likelihood into -inf.
    L = H * np.log(1 - pies) - 0.5 * D * np.log(
        2 * pi * sigma**2) - np.log(A_pi_gamma)
    Fs = (corr_all + np.log(pjb_all.sum(axis=1))).sum()
    L += comm.allreduce(Fs) / N_use
    dlog.append('L', L)

    # Allocate local accumulators
    my_Wp = np.zeros_like(W)  # shape (H, D)
    my_Wq = np.zeros((H, H))  # shape (H, H)
    my_pi = 0.0
    my_sigma = 0.0
    my_mus = np.zeros(H)  # shape (H,)
    data_sum = my_y.sum(axis=0)  # sum over all local datapoints (mu update)

    # Iterate over all datapoints
    tracing.tracepoint("M_step:iterating")
    for n in range(my_N):
        y = my_y[n, :] - mu  # length D
        cand = candidates[n, :]  # length Hprime
        pjb = pjb_all[n, :]

        # Zero active hidden causes contribute nothing to W, pi and mu.

        # One active hidden cause
        this_Wp = np.outer(pjb[1:(H + 1)], y)  # (H, D)
        this_Wq = pjb[1:(H + 1)] * np.identity(H)  # (H, H)
        this_pi = pjb[1:(H + 1)].sum()
        this_mus = pjb[1:(H + 1)].copy()

        # Handle hidden states with more than 1 active cause
        this_Wp[cand] += np.dot(np.outer(y, pjb[(1 + H):]), SM).T
        this_Wq_tmp = np.zeros_like(my_Wq[cand])
        this_Wq_tmp[:, cand] = np.dot(pjb[(1 + H):] * SM.T, SM)
        this_Wq[cand] += this_Wq_tmp
        this_pi += np.inner(pjb[(1 + H):], SM.sum(axis=1))
        this_mus[cand] += np.inner(SM.T, pjb[(1 + H):])

        denom = pjb.sum()
        my_Wp += this_Wp / denom
        my_Wq += this_Wq / denom
        my_pi += this_pi / denom
        my_mus += this_mus / denom

    # Calculate updated W
    if 'W' in self.to_learn:
        tracing.tracepoint("M_step:update W")
        Wp = np.empty_like(my_Wp)
        Wq = np.empty_like(my_Wq)
        comm.Allreduce([my_Wp, MPI.DOUBLE], [Wp, MPI.DOUBLE])
        comm.Allreduce([my_Wq, MPI.DOUBLE], [Wq, MPI.DOUBLE])
        # rcond=None is only accepted from NumPy 1.14 on.  Compare
        # (major, minor) as integers so that 2.x and suffixed versions
        # such as '1.14.0rc1' are classified correctly.
        np_version = tuple(
            int(part) for part in np.__version__.split('.')[:2])
        rcond = None if np_version >= (1, 14) else -1
        W_new = np.linalg.lstsq(Wq, Wp, rcond=rcond)[0]
    else:
        W_new = W

    # Calculate updated pi
    if 'pi' in self.to_learn:
        tracing.tracepoint("M_step:update pi")
        pi_new = E_pi_gamma * comm.allreduce(my_pi) / H / N_use
    else:
        pi_new = pies

    # Calculate updated sigma
    if 'sigma' in self.to_learn:
        tracing.tracepoint("M_step:update sigma")
        for n in range(my_N):
            y = my_y[n, :] - mu  # length D
            cand = candidates[n, :]  # length Hprime
            pjb = pjb_all[n, :]  # same rescaled joints as in the W loop

            # Zero active hidden causes
            this_sigma = pjb[0] * (y**2).sum()

            # Hidden states with one active cause
            this_sigma += (pjb[1:(H + 1)] * ((W - y)**2).sum(axis=1)).sum()

            # Handle hidden states with more than 1 active cause
            W_ = W[cand]  # is (Hprime x D)
            Wbar = np.dot(SM, W_)
            this_sigma += (pjb[(H + 1):] *
                           ((Wbar - y)**2).sum(axis=1)).sum()

            denom = pjb.sum()
            my_sigma += this_sigma / denom

        sigma_new = np.sqrt(comm.allreduce(my_sigma) / D / N_use)
    else:
        sigma_new = sigma

    # Calculate updated mu
    if 'mu' in self.to_learn:
        tracing.tracepoint("M_step:update mu")
        mus = np.empty_like(my_mus)
        all_data_sum = np.empty_like(data_sum)
        comm.Allreduce([my_mus, MPI.DOUBLE], [mus, MPI.DOUBLE])
        comm.Allreduce([data_sum, MPI.DOUBLE], [all_data_sum, MPI.DOUBLE])
        # Both sums were reduced over all ranks, so they are normalised by
        # the global number of used datapoints, not the rank-local my_N.
        mu_new = all_data_sum / N_use - np.inner(W_new.T / N_use, mus)
    else:
        mu_new = mu

    for param in anneal.crit_params:
        # Each entry is evaluated as an expression in this function's
        # scope.  exec('this_param = ...') cannot create a function local
        # on Python 3, hence eval().
        # NOTE(review): crit_params must come from trusted configuration.
        this_param = eval(param)
        anneal.dyn_param(param, this_param)

    dlog.append('N_use', N_use)

    return {'W': W_new.T, 'pi': pi_new, 'sigma': sigma_new, 'mu': mu_new}
def M_step(self, anneal, model_params, my_suff_stat, my_data):
    """MCA M-step: re-estimate W, pi and sigma from the E-step results.

    Parameters
    ----------
    anneal
        Annealing schedule.  Keys read here:
        ``anneal['T']`` -- temperature for det. annealing AND softmax;
        ``anneal['Ncut_factor']`` -- 0.: no truncation; > 0.: keep only the
        datapoints with the largest annealed log-denominators.
    model_params : dict
        Current parameters under the keys 'W', 'pi' and 'sigma'.  All
        entries of W must be strictly positive (``log(W)`` is asserted to
        be finite).
    my_suff_stat : dict
        ``my_suff_stat['logpj']``: log-joints, indexed as (my_N, no_states).
    my_data : dict
        ``my_data['y']``: datapoints, (my_N, D);
        ``my_data['candidates']``: candidate H's according to the selection
        function, (my_N, Hprime).

    Returns
    -------
    dict
        'W', 'pi', 'sigma' (updated if listed in ``self.to_learn``,
        otherwise unchanged) and 'Q', the approximate log-likelihood.
    """
    comm = self.comm
    H = self.H
    W = model_params['W'].T
    pies = model_params['pi']
    sigma = model_params['sigma']

    # Read in data:
    my_y = my_data['y']
    my_cand = my_data['candidates']
    my_logpj = my_suff_stat['logpj']
    my_N, D = my_y.shape
    N = comm.allreduce(my_N)

    state_mtx = self.state_matrix  # shape: (no_states, Hprime)
    state_abs = self.state_abs  # shape: (no_states,)

    # To compute et_loglike:
    my_ldenom_sum = 0.0

    # Precompute
    T = anneal['T']
    T_rho = np.maximum(T, self.rho_temp_bound)
    rho = 1. / (1. - 1. / T_rho)
    beta = 1. / T
    pil_bar = np.log(pies / (1. - pies))
    Wl = np.log(W)
    Wrho = np.exp(rho * Wl)
    Wsquared = W * W

    # Sanity checks on the current parameters
    assert np.isfinite(pil_bar).all()
    assert np.isfinite(Wl).all()
    assert np.isfinite(Wrho).all()
    assert (Wrho > 1e-86).all()

    my_corr = beta * ((my_logpj).max(axis=1))  # shape: (my_N,)
    my_pjb = np.exp(beta * my_logpj - my_corr[:, None])  # (my_N, no_states)

    # Precompute factor for pi/gamma update
    A_pi_gamma = 0.
    B_pi_gamma = 0.
    for gp in range(0, self.gamma + 1):
        a = comb(H, gp, exact=1) * pies**gp * (1. - pies)**(H - gp)
        A_pi_gamma += a
        B_pi_gamma += gp * a

    # Truncate data
    if anneal['Ncut_factor'] > 0.0:
        tracing.tracepoint("M_step:truncating")
        my_denoms = np.log(my_pjb.sum(axis=1)) + my_corr
        N_use = int(N * (1 - (1 - A_pi_gamma) * anneal['Ncut_factor']))
        cut_denom = parallel.allsort(my_denoms)[-N_use]
        which = np.array(my_denoms >= cut_denom)
        my_y = my_y[which]
        my_cand = my_cand[which]
        my_logpj = my_logpj[which]
        my_pjb = my_pjb[which]
        my_corr = my_corr[which]
        my_N, D = my_y.shape
        N_use = comm.allreduce(my_N)
    else:
        N_use = N
    dlog.append('N_use', N_use)

    # Allocate suff-stat arrays
    my_Wp = np.zeros_like(W)  # shape (H, D)
    my_Wq = np.zeros_like(W)  # shape (H, D)
    my_pi = 0.0
    my_sigma = 0.0

    # Iterate over all datapoints
    for n in range(my_N):
        tracing.tracepoint("M_step:iterating")
        y = my_y[n, :]  # shape (D,)
        cand = my_cand[n, :]  # shape (Hprime,)
        logpj = my_logpj[n, :]  # shape (no_states,)
        pjb = my_pjb[n, :]  # shape (no_states,)
        corr = my_corr[n]  # scalar

        this_Wp = np.zeros_like(W)  # numerator for W (H, D)
        this_Wq = np.zeros_like(W)  # denominator for W (H, D)
        this_pi = 0.0  # numerator for pi update
        this_sigma = 0.0  # numerator for sigma update

        # Zero active hidden causes (only sigma receives a contribution)
        this_sigma += pjb[0] * (y**2).sum()

        # One active hidden cause
        this_Wp += (pjb[1:(H + 1), None] * Wsquared[:, :]) * y[None, :]
        this_Wq += (pjb[1:(H + 1), None] * Wsquared[:, :])
        this_pi += pjb[1:(H + 1)].sum()
        this_sigma += (pjb[1:(H + 1)] * ((W - y)**2).sum(axis=1)).sum()

        # Handle hidden states with more than 1 active cause
        Wl_ = Wl[cand]  # is (Hprime, D)
        Wrho_ = Wrho[cand]  # is (Hprime, D)

        Wlrhom1 = (rho - 1) * Wl_  # is (Hprime, D)
        Wlbar = np.log(np.dot(state_mtx, Wrho_)) / rho  # (no_states, D)
        Wbar = np.exp(Wlbar)  # is (no_states, D)
        blpj = beta * logpj[1 + H:] - corr  # is (no_states,)

        Aid = (state_mtx[:, :, None] * np.exp(
            blpj[:, None, None] + (1 - rho) * Wlbar[:, None, :] +
            Wlrhom1[None, :, :])).sum(axis=0)

        assert np.isfinite(Wlbar).all()
        assert np.isfinite(Wbar).all()
        assert np.isfinite(pjb).all()
        assert np.isfinite(Aid).all()

        this_Wp[cand] += Aid * y[None, :]
        this_Wq[cand] += Aid
        this_pi += (pjb[1 + H:] * state_abs).sum()
        this_sigma += (pjb[1 + H:] * ((Wbar - y)**2).sum(axis=1)).sum()

        denom = pjb.sum()
        my_Wp += this_Wp / denom
        my_Wq += this_Wq / denom
        my_pi += this_pi / denom
        my_sigma += this_sigma / denom

        # For the log-likelihood: log(sum(exp(logpj))), evaluated with the
        # maximum factored out so that strongly negative log-joints do not
        # underflow to log(0) = -inf.
        logpj_max = logpj.max()
        my_ldenom_sum += logpj_max + np.log(
            np.sum(np.exp(logpj - logpj_max)))

    # Calculate updated W
    if 'W' in self.to_learn:
        tracing.tracepoint("M_step:update W")
        Wp = np.empty_like(my_Wp)
        Wq = np.empty_like(my_Wq)

        assert np.isfinite(my_Wp).all()
        assert np.isfinite(my_Wq).all()

        comm.Allreduce([my_Wp, MPI.DOUBLE], [Wp, MPI.DOUBLE])
        comm.Allreduce([my_Wq, MPI.DOUBLE], [Wq, MPI.DOUBLE])

        # Make sure we do not divide by zero
        tiny = np.finfo(Wq.dtype).tiny
        Wp[Wq < tiny] = 0.
        Wq[Wq < tiny] = tiny
        W_new = (Wp / Wq).T
    else:
        W_new = W.T

    # Calculate updated pi
    if 'pi' in self.to_learn:
        tracing.tracepoint("M_step:update pi")
        assert np.isfinite(my_pi).all()
        pi_new = A_pi_gamma / B_pi_gamma * pies * comm.allreduce(
            my_pi) / N_use
    else:
        pi_new = pies

    # Calculate updated sigma
    if 'sigma' in self.to_learn:
        # TODO: XXX see LinCA XXX (merge!)
        tracing.tracepoint("M_step:update sigma")
        assert np.isfinite(my_sigma).all()
        sigma_new = np.sqrt(comm.allreduce(my_sigma) / D / N_use)
    else:
        sigma_new = sigma

    # Put all together and compute (always) the approximate likelihood.
    # 0.5 * D instead of D/2: integer division would floor odd D on
    # Python 2.
    ldenom_sum = comm.allreduce(my_ldenom_sum)
    lAi = (H * np.log(1. - pi_new)) - (0.5 * D * np.log(2 * pi)) - (
        D * np.log(sigma_new))

    # For practical and ET-approximation reasons the sum of the restricted
    # responsibilities is taken to be 1.
    loglike_et = (lAi * N_use) + ldenom_sum

    return {'W': W_new, 'pi': pi_new, 'sigma': sigma_new, 'Q': loglike_et}
def M_step(self, anneal, model_params, my_suff_stat, my_data):
    """Ternary Sparse Coding M-Step

    This function is responsible for finding the optimal model parameters
    given an approximation of the posterior distribution.

    Parameters
    ----------
    anneal : Annealing object
        Annealing type object containing training schedule information.
        anneal['Ncut_factor']: 0. no truncation; > 0. keep only the
        datapoints with the largest summed joints.
        ``anneal.crit_params`` is iterated at the end (see below).
    model_params : dict
        dictionary containing model parameters
        model_params['W']: linear dictionary (transposed on entry)
        model_params['pi']: prior parameter(s)
        model_params['sigma']: standard deviation of noise model
    my_suff_stat : dict
        dictionary containing information about the joint distribution
        my_suff_stat['logpj']: (my_N, no_states) ndarray
            logarithm of joint of data and latent variable states
    my_data : dict
        data dictionary
        my_data['y']: (my_N, D) ndarray
            datapoints
        my_data['candidates']: (my_N, Hprime)
            Candidate H's according to selection func.

    Returns
    -------
    dict
        dictionary containing updated model parameters
        dict['W']: linear dictionary
        dict['pi']: prior parameter(s)
        dict['sigma']: standard deviation of noise model
        dict['Q']: always 0.
        Parameters not listed in ``self.to_learn`` are returned unchanged.
    """
    comm = self.comm
    H = self.H
    gamma = self.gamma
    W = model_params['W'].T
    pi = model_params['pi']
    sigma = model_params['sigma']

    # Read in data:
    my_y = my_data['y'].copy()
    candidates = my_data['candidates']
    logpj_all = my_suff_stat['logpj']
    all_denoms = np.exp(logpj_all).sum(axis=1)
    my_N, D = my_y.shape
    N = comm.allreduce(my_N)

    SM = self.state_matrix  # shape: (no_states, Hprime)
    state_abs = np.abs(SM).sum(axis=1)  # number of non-zero causes per state

    # Precompute factor for pi update
    A_pi_gamma = 0.0
    B_pi_gamma = 0.0
    for gam1 in range(gamma + 1):
        for gam2 in range(gamma - gam1 + 1):
            cmb = comb(gam1, gam1) * comb(gam1 + gam2, gam2) * comb(
                H, H - gam1 - gam2)
            A_pi_gamma += cmb * ((pi / 2)**(gam1 + gam2)) * (
                (1 - pi)**(H - gam1 - gam2))
            B_pi_gamma += (gam1 + gam2) * cmb * (
                (pi / 2)**(gam1 + gam2)) * ((1 - pi)**(H - gam1 - gam2))
    E_pi_gamma = pi * H * A_pi_gamma / B_pi_gamma

    # Truncate data
    if anneal['Ncut_factor'] > 0.0:
        N_use = int(N * (1 - (1 - A_pi_gamma) * anneal['Ncut_factor']))
        cut_denom = parallel.allsort(all_denoms)[-N_use]
        which = np.array(all_denoms >= cut_denom)
        candidates = candidates[which]
        logpj_all = logpj_all[which]
        my_y = my_y[which]
        my_N, D = my_y.shape
        N_use = comm.allreduce(my_N)
    else:
        N_use = N

    # Joints rescaled by their per-datapoint maximum; exp() of these cannot
    # underflow to an all-zero row the way exp(logpj) can.
    corr_all = logpj_all.max(axis=1)  # shape: (my_N,)
    pjb_all = np.exp(logpj_all - corr_all[:, None])  # (my_N, no_states)

    # Log-likelihood.  log(sum(exp(logpj))) is evaluated as
    # max + log(sum(exp(logpj - max))) so that strongly negative log-joints
    # do not turn the likelihood into -inf.
    L = -0.5 * D * np.log(2 * np.pi * sigma**2) - np.log(A_pi_gamma)
    Fs = (corr_all + np.log(pjb_all.sum(axis=1))).sum()
    L += comm.allreduce(Fs) / N_use
    dlog.append('L', L)

    # Allocate local accumulators
    my_Wp = np.zeros_like(W)  # shape (H, D)
    my_Wq = np.zeros((H, H))  # shape (H, H)
    my_pi = 0.0
    my_sigma = 0.0

    # Iterate over all datapoints
    for n in range(my_N):
        y = my_y[n, :]  # length D
        cand = candidates[n, :]  # length Hprime
        pjb = pjb_all[n, :]

        this_Wp = np.zeros_like(my_Wp)  # numerator for W (H, D)
        this_Wq = np.zeros_like(my_Wq)  # denominator for W (H, H)
        this_pi = np.zeros_like(pi)  # numerator for pi update

        # All states are expressed through the state matrix
        this_Wp[cand] += np.dot(np.outer(y, pjb), SM).T
        this_Wq_tmp = np.zeros_like(my_Wq[cand])
        this_Wq_tmp[:, cand] = np.dot(pjb * SM.T, SM)
        this_Wq[cand] += this_Wq_tmp
        this_pi += np.inner(pjb, state_abs)

        denom = pjb.sum()
        my_Wp += this_Wp / denom
        my_Wq += this_Wq / denom
        my_pi += this_pi / denom

    # Calculate updated W
    if 'W' in self.to_learn:
        Wp = np.empty_like(my_Wp)
        Wq = np.empty_like(my_Wq)
        comm.Allreduce([my_Wp, MPI.DOUBLE], [Wp, MPI.DOUBLE])
        comm.Allreduce([my_Wq, MPI.DOUBLE], [Wq, MPI.DOUBLE])
        W_new = np.dot(np.linalg.pinv(Wq), Wp)
    else:
        W_new = W

    # Calculate updated pi
    if 'pi' in self.to_learn:
        pi_new = E_pi_gamma * comm.allreduce(my_pi) / H / N_use
    else:
        pi_new = pi

    # Calculate updated sigma
    if 'sigma' in self.to_learn:
        for n in range(my_N):
            y = my_y[n, :]  # length D
            cand = candidates[n, :]  # length Hprime
            pjb = pjb_all[n, :]  # same rescaled joints as in the W loop

            W_ = W[cand]  # is (Hprime x D)
            Wbar = np.dot(SM, W_)
            this_sigma = (pjb * ((Wbar - y)**2).sum(axis=1)).sum()

            denom = pjb.sum()
            my_sigma += this_sigma / denom

        sigma_new = np.sqrt(comm.allreduce(my_sigma) / D / N_use)
    else:
        sigma_new = sigma

    for param in anneal.crit_params:
        # Each entry is evaluated as an expression in this function's
        # scope.  exec('this_param = ...') cannot create a function local
        # on Python 3, hence eval().
        # NOTE(review): crit_params must come from trusted configuration.
        this_param = eval(param)
        anneal.dyn_param(param, this_param)

    dlog.append('N_use', N_use)

    return {
        'W': W_new.transpose(),
        'pi': pi_new,
        'sigma': sigma_new,
        'Q': 0.
    }
def M_step(self, anneal, model_params, my_suff_stat, my_data):
    """MCA M-step: re-estimate W, pi and sigma from the E-step results.

    NumPy divide and underflow warnings are suppressed for the duration of
    the computation and the previous error state is restored afterwards,
    also when an exception is raised.

    Parameters
    ----------
    anneal
        Annealing schedule.  Keys read here:
        ``anneal['T']`` -- temperature for det. annealing AND softmax;
        ``anneal['Ncut_factor']`` -- 0.: no truncation; > 0.: keep only the
        datapoints with the largest annealed log-denominators.
    model_params : dict
        Current parameters under the keys 'W', 'pi' and 'sigma'.  No entry
        of W may be zero (``log(abs(W))`` is asserted to be finite).
    my_suff_stat : dict
        ``my_suff_stat['logpj']``: log-joints, indexed as (my_N, no_states).
    my_data : dict
        ``my_data['y']``: datapoints, (my_N, D);
        ``my_data['candidates']``: candidate H's according to the selection
        function, (my_N, Hprime).

    Returns
    -------
    dict
        'W', 'pi', 'sigma' (updated if listed in ``self.to_learn``,
        otherwise unchanged) and 'Q', the approximate log-likelihood.
    """
    comm = self.comm
    H = self.H
    W = model_params['W'].T
    pies = model_params['pi']
    sigma = model_params['sigma']

    # Read in data:
    my_y = my_data['y']
    my_cand = my_data['candidates']
    my_logpj = my_suff_stat['logpj']
    my_N, D = my_y.shape
    N = comm.allreduce(my_N)

    state_mtx = self.state_matrix  # shape: (no_states, Hprime)
    state_abs = self.state_abs  # shape: (no_states,)

    # Disable some warnings.  The context manager restores the previous
    # settings on every exit path, including failing asserts.
    with np.errstate(divide='ignore', under='ignore'):
        # To compute et_loglike:
        my_ldenom_sum = 0.0

        # Precompute
        T = anneal['T']
        T_rho = np.maximum(T, self.rho_T_bound)
        rho = 1. / (1. - 1. / T_rho)
        rho = np.maximum(np.minimum(rho, self.rho_ubound), self.rho_lbound)
        beta = 1. / T
        pil_bar = np.log(pies / (1. - pies))
        Wl = accel.log(np.abs(W))
        Wrho = accel.exp(rho * Wl)
        Wrhos = np.sign(W) * Wrho

        # Sanity checks on the current parameters
        assert np.isfinite(pil_bar).all()
        assert np.isfinite(Wl).all()
        assert np.isfinite(Wrho).all()
        assert (Wrho > 1e-86).all()

        my_corr = beta * ((my_logpj).max(axis=1))  # shape: (my_N,)
        my_logpjb = beta * my_logpj - my_corr[:, None]  # (my_N, no_states)
        my_pjb = accel.exp(my_logpjb)  # (my_N, no_states)

        # Precompute factor for pi update and ET cutting
        A_pi_gamma = 0.
        B_pi_gamma = 0.
        for gp in range(0, self.gamma + 1):
            a = comb(H, gp) * pies**gp * (1. - pies)**(H - gp)
            A_pi_gamma += a
            B_pi_gamma += gp * a

        # Truncate data: my_sel holds the indices of the local datapoints
        # that take part in the update.
        if anneal['Ncut_factor'] > 0.0:
            tracing.tracepoint("M_step:truncating")
            my_logdenoms = accel.log(my_pjb.sum(axis=1)) + my_corr
            N_use = int(N * (1 - (1 - A_pi_gamma) * anneal['Ncut_factor']))
            cut_denom = parallel.allsort(my_logdenoms)[-N_use]
            my_sel, = np.where(my_logdenoms >= cut_denom)
            my_N, = my_sel.shape
            N_use = comm.allreduce(my_N)
        else:
            my_N, _ = my_y.shape
            my_sel = np.arange(my_N)
            N_use = N

        # Allocate suff-stat arrays
        my_Wp = np.zeros_like(W)  # shape (H, D)
        my_Wq = np.zeros_like(W)  # shape (H, D)
        my_pi = 0.0
        my_sigma = 0.0

        # Iterate over all selected datapoints
        tracing.tracepoint("M_step:iterating...")
        dlog.append('N_use', N_use)
        for n in my_sel:
            y = my_y[n, :]  # shape (D,)
            cand = my_cand[n, :]  # shape (Hprime,)
            logpj = my_logpj[n, :]  # shape (no_states,)
            logpjb = my_logpjb[n, :]  # shape (no_states,)
            pjb = my_pjb[n, :]  # shape (no_states,)

            this_Wp = np.zeros_like(W)  # numerator for W (H, D)
            this_Wq = np.zeros_like(W)  # denominator for W (H, D)
            this_pi = 0.0  # numerator for pi update
            this_sigma = 0.0  # numerator for sigma update

            # Zero active hidden causes (only sigma receives a contribution)
            this_sigma += pjb[0] * (y**2).sum()

            # One active hidden cause
            this_Wp += (pjb[1:(H + 1), None]) * y[None, :]
            this_Wq += (pjb[1:(H + 1), None])
            this_pi += pjb[1:(H + 1)].sum()
            this_sigma += (pjb[1:(H + 1)] * ((W - y)**2).sum(axis=1)).sum()

            # Handle hidden states with more than 1 active cause
            Wl_ = Wl[cand]  # is (Hprime, D)
            Wrhos_ = Wrhos[cand]  # is (Hprime, D)

            t0 = np.dot(state_mtx, Wrhos_)
            Wlbar = accel.log(np.abs(t0)) / rho  # is (no_states, D)
            Wbar = np.sign(t0) * accel.exp(Wlbar)  # is (no_states, D)

            t = Wlbar[:, None, :] - Wl_[None, :, :]
            t = np.maximum(t, 0.)
            Aid = state_mtx[:, :, None] * accel.exp(
                logpjb[H + 1:, None, None] - (rho - 1) * t)
            Aid = Aid.sum(axis=0)

            this_Wp[cand] += Aid * y[None, :]
            this_Wq[cand] += Aid
            this_pi += (pjb[1 + H:] * state_abs).sum()
            this_sigma += (pjb[1 + H:] * ((Wbar - y)**2).sum(axis=1)).sum()

            denom = pjb.sum()
            my_Wp += this_Wp / denom
            my_Wq += this_Wq / denom
            my_pi += this_pi / denom
            my_sigma += this_sigma / denom

            # For the log-likelihood: log(sum(exp(logpj))), evaluated with
            # the maximum factored out so that strongly negative log-joints
            # do not underflow to log(0) = -inf.
            logpj_max = logpj.max()
            my_ldenom_sum += logpj_max + accel.log(
                np.sum(accel.exp(logpj - logpj_max)))

        # Calculate updated W
        if 'W' in self.to_learn:
            tracing.tracepoint("M_step:update W")
            Wp = np.empty_like(my_Wp)
            Wq = np.empty_like(my_Wq)

            assert np.isfinite(my_Wp).all()
            assert np.isfinite(my_Wq).all()

            comm.Allreduce([my_Wp, MPI.DOUBLE], [Wp, MPI.DOUBLE])
            comm.Allreduce([my_Wq, MPI.DOUBLE], [Wq, MPI.DOUBLE])

            # Make sure we do not divide by zero
            tiny = self.tol
            Wq[Wq < tiny] = tiny

            # Calculate updated W
            W_new = Wp / Wq

            # Add inertia depending on Wq: entries with little evidence
            # stay closer to the old W (at least 20% of the new value).
            alpha = 2.5
            inertia = np.maximum(1. - accel.exp(-Wq / alpha), 0.2)
            W_new = inertia * W_new + (1 - inertia) * W
            W_new = W_new.T
        else:
            W_new = W.T

        # Calculate updated pi
        if 'pi' in self.to_learn:
            tracing.tracepoint("M_step:update pi")
            assert np.isfinite(my_pi).all()
            pi_new = A_pi_gamma / B_pi_gamma * pies * comm.allreduce(
                my_pi) / N_use
        else:
            pi_new = pies

        # Calculate updated sigma
        if 'sigma' in self.to_learn:
            # TODO: XXX see LinCA XXX (merge!)
            tracing.tracepoint("M_step:update sigma")
            assert np.isfinite(my_sigma).all()
            sigma_new = np.sqrt(comm.allreduce(my_sigma) / D / N_use)
        else:
            sigma_new = sigma

        # Put all together and compute (always) the approximate likelihood.
        # 0.5 * D instead of D/2: integer division would floor odd D on
        # Python 2.
        ldenom_sum = comm.allreduce(my_ldenom_sum)
        lAi = (H * np.log(1. - pi_new)) - (0.5 * D * np.log(2 * pi)) - (
            D * np.log(sigma_new))

        # For practical and ET-approximation reasons the sum of the
        # restricted responsibilities is taken to be 1.
        loglike_et = (lAi * N_use) + ldenom_sum

    return {'W': W_new, 'pi': pi_new, 'sigma': sigma_new, 'Q': loglike_et}