def run_time(self, Vt, Wt=None, Zt=None, Ht=None, niter=5):
    """Fit the time-domain factors to Vt, then take one meta EM step.

    Any of ``Wt``, ``Zt``, ``Ht`` left as None is replaced by the matching
    value from ``self.time_analyzer.initialize()``.  The factors are then
    refined with ``niter`` E-step/M-step rounds of ``self.time_analyzer``.
    The resulting ``Ht`` is collapsed with ``self.sum_pieces``, plotted, and
    passed through a single E-step/M-step of ``self.meta_time_analyzer``.

    Args:
        Vt: data assigned to ``self.time_analyzer.V`` before fitting.
        Wt, Zt, Ht: optional starting factors.
        niter: number of EM rounds to run on ``self.time_analyzer``.

    Returns:
        Tuple ``(Wt, Zt, Ht, meta_Ht)``; ``meta_Ht`` is the H factor
        returned by the meta analyzer's M-step.
    """
    self.time_analyzer.V = Vt
    if Wt is None or Zt is None or Ht is None:
        initW, initZ, initH = self.time_analyzer.initialize()
        if Wt is None:
            Wt = initW
        if Zt is None:
            Zt = initZ
        if Ht is None:
            Ht = initH

    # NOTE(review): only the E-step/M-step pair is taken to be the loop
    # body; confirm the plot and the meta step below are not meant to run
    # on every iteration.
    n = 0  # keeps the meta log line below valid when niter == 0
    for n in range(niter):
        logprob, _ = self.time_analyzer.do_estep(Wt, Zt, Ht)
        logger.info('Iteration t%d: logprob = %f', n, logprob)
        Wt, Zt, Ht = self.time_analyzer.do_mstep(n)

    # Sanity checks on the dimensionality of the fitted factors.
    assert Wt.ndim == 3
    assert Zt.ndim == 1
    assert Ht.ndim == 2

    # Shapes per the original author's note:
    #   Ht      : (rank, rank*F*T)
    #   meta_Ht : (rank, F*T)
    meta_Ht = self.sum_pieces(Ht)

    # Show the first (up to) three components as the colour channels of one
    # image, scaled relative to the global maximum.
    Htt = meta_Ht.reshape(self.rank, self.f_steps, self.t_steps)
    Hmax = np.max(Htt)
    plt.clf()
    plt.imshow(np.rollaxis(Htt[0:3]/Hmax*2, 0, 3), origin='lower',
               aspect='auto', interpolation='nearest')
    plt.draw()

    meta_logprob, _ = self.meta_time_analyzer.do_estep(Wt, Zt, meta_Ht)
    # Report the meta analyzer's own log-probability, not the value left
    # over from the last inner iteration.
    logger.info('Meta t%d: logprob = %f', n, meta_logprob)
    _, _, meta_Ht = self.meta_time_analyzer.do_mstep(0)

    return Wt, Zt, Ht, meta_Ht
def run_freq(self, Vf, Wf=None, Zf=None, Hf=None, niter=5):
    """Fit the frequency-domain factors to Vf, then take one meta EM step.

    Any of ``Wf``, ``Zf``, ``Hf`` left as None is replaced by the matching
    value from ``self.freq_analyzer.initialize()``.  The factors are then
    refined with ``niter`` E-step/M-step rounds of ``self.freq_analyzer``.
    ``Wf[..., 0]`` is plotted, ``Hf`` is collapsed with ``self.sum_pieces``,
    and the result is passed through a single E-step/M-step of
    ``self.meta_freq_analyzer``.

    Args:
        Vf: data assigned to ``self.freq_analyzer.V`` before fitting.
        Wf, Zf, Hf: optional starting factors.
        niter: number of EM rounds to run on ``self.freq_analyzer``.

    Returns:
        Tuple ``(Wf, Zf, Hf, meta_Hf)``; ``meta_Hf`` is the H factor
        returned by the meta analyzer's M-step.
    """
    self.freq_analyzer.V = Vf
    if Wf is None or Zf is None or Hf is None:
        initW, initZ, initH = self.freq_analyzer.initialize()
        if Wf is None:
            Wf = initW
        if Zf is None:
            Zf = initZ
        if Hf is None:
            Hf = initH

    # NOTE(review): only the E-step/M-step pair is taken to be the loop
    # body; confirm the plot and the meta step below are not meant to run
    # on every iteration.
    n = 0  # keeps the meta log line below valid when niter == 0
    for n in range(niter):
        logprob, _ = self.freq_analyzer.do_estep(Wf, Zf, Hf)
        logger.info('Iteration f%d: logprob = %f', n, logprob)
        Wf, Zf, Hf = self.freq_analyzer.do_mstep(n)

    plt.clf()
    plt.plot(Wf[...,0])
    plt.draw()

    # Shapes per the original author's note:
    #   Hf      : (rank, F, rank*T)
    #   meta_Hf : (rank, F, T)
    meta_Hf = self.sum_pieces(Hf)

    meta_logprob, _ = self.meta_freq_analyzer.do_estep(Wf, Zf, meta_Hf)
    # Report the meta analyzer's own log-probability, not the value left
    # over from the last inner iteration.
    logger.info('Meta f%d: logprob = %f', n, meta_logprob)
    _, _, meta_Hf = self.meta_freq_analyzer.do_mstep(0)

    return Wf, Zf, Hf, meta_Hf
def analyze(cls, V, rank, niter=100, convergence_thresh=1e-9,
            printiter=50, plotiter=None, plotfilename=None,
            initWf=None, initWt=None, initZ=None, initH=None,
            updateW=True, updateZ=True, updateH=True, **kwargs):
    """Iteratively fit the factors Wf, Wt, Z, H to V with EM.

    A normalized copy of V (scaled to sum to one) is handed to
    ``cls(V, rank, **kwargs)``.  E-steps and M-steps then alternate until
    ``niter`` rounds have run or the log-probability stops improving.

    Args:
        V: array to decompose; must provide ``.sum()``.  The caller's array
            is not modified.
        rank: forwarded to ``cls(V, rank, **kwargs)``.
        niter: maximum number of EM rounds; must be at least 1.
        convergence_thresh: stop once the log-probability gain between two
            consecutive rounds drops below this value.
        printiter: log the log-probability every ``printiter`` rounds.
        plotiter: if truthy, call ``params.plot`` every ``plotiter`` rounds
            and once more after the final round.
        plotfilename: if not None, each plot is saved to
            ``'<plotfilename>_<round:04d>.png'``.
        initWf, initWt, initZ, initH: optional starting values (copied);
            anything left as None comes from ``params.initialize()``.
        updateW, updateZ, updateH: when False, the corresponding factor(s)
            keep their starting values throughout.
        **kwargs: forwarded to ``cls``.

    Returns:
        Tuple ``(Wf, Wt, Z, H, norm, recon, logprob)``.  ``norm`` is the
        original sum of V, ``recon`` is ``norm`` times the reconstruction
        from the last E-step, and ``logprob`` is that E-step's
        log-probability.

    Raises:
        ValueError: if ``niter`` is smaller than 1.
    """
    # NOTE(review): the first parameter is ``cls`` and is called as a
    # constructor; confirm a @classmethod decorator is applied to this
    # definition.
    if niter < 1:
        # Without at least one E-step there is no reconstruction to return.
        raise ValueError('niter must be at least 1, got %r' % (niter,))

    norm = V.sum()
    # Normalize into a fresh array so the caller's data is left untouched;
    # true_divide also keeps integer input from being floor-divided.
    V = np.true_divide(V, norm)

    params = cls(V, rank, **kwargs)
    iWf, iWt, iZ, iH = params.initialize()

    Wf = iWf if initWf is None else initWf.copy()
    Wt = iWt if initWt is None else initWt.copy()
    Z = iZ if initZ is None else initZ.copy()
    H = iH if initH is None else initH.copy()

    params.Wf = Wf
    params.Wt = Wt
    params.Z = Z
    params.H = H

    oldlogprob = -np.inf
    for n in range(niter):
        logprob, WZH = params.do_estep(Wf, Wt, Z, H)
        if n % printiter == 0:
            logger.info('Iteration %d: logprob = %f', n, logprob)
        if plotiter and n % plotiter == 0:
            params.plot(V, Wf, Wt, Z, H, n)
            if plotfilename is not None:
                plt.savefig('%s_%04d.png' % (plotfilename, n))
        if logprob < oldlogprob:
            logger.debug('Warning: logprob decreased from %f to %f at '
                         'iteration %d!', oldlogprob, logprob, n)
        elif n > 0 and logprob - oldlogprob < convergence_thresh:
            logger.info('Converged at iteration %d', n)
            break
        oldlogprob = logprob

        nWf, nWt, nZ, nH = params.do_mstep(n)

        # Only adopt the new estimates for factors that are being updated.
        if updateW:
            Wf = nWf
            Wt = nWt
        if updateZ:
            Z = nZ
        if updateH:
            H = nH

        # Write the (possibly held-fixed) factors back onto params.
        params.Wf = Wf
        params.Wt = Wt
        params.Z = Z
        params.H = H

    if plotiter:
        params.plot(V, Wf, Wt, Z, H, n)
        if plotfilename is not None:
            plt.savefig('%s_%04d.png' % (plotfilename, n))
    logger.info('Iteration %d: final logprob = %f', n, logprob)

    # Undo the normalization so the reconstruction is on the scale of the
    # input.
    recon = norm * WZH
    return Wf, Wt, Z, H, norm, recon, logprob
def _prune_undeeded_bases(self, Wf, Wt, Z, H, curriter):
    """Drop components whose weight in Z has effectively vanished.

    A component survives when its entry in Z exceeds ``10 * EPS``.  Nothing
    changes unless fewer than ``self.rank`` components survive and
    ``curriter`` has reached ``self.minpruneiter``.  When pruning happens,
    ``self.rank``, the ``VRWf``/``VRWt``/``VRH`` attributes and every
    array-valued ``alpha*`` attribute are trimmed along with the factors.

    Returns the (possibly reduced) tuple ``(Wf, Wt, Z, H)``.
    """
    keep = np.argwhere(Z > 10 * EPS).flatten()
    n_keep = len(keep)

    # Guard clause: nothing to drop, or too early to start pruning.
    if not (n_keep < self.rank and curriter >= self.minpruneiter):
        return Wf, Wt, Z, H

    logger.info('Rank decreased from %d to %d during iteration %d',
                self.rank, n_keep, curriter)
    self.rank = n_keep

    every = slice(None)

    # Factors handed back to the caller.
    Z = Z[keep]
    Wf = Wf[every, keep]
    Wt = Wt[keep, every]
    H = H[keep, every]

    # Matching arrays stored on the instance.
    self.VRWf = self.VRWf[every, keep]
    self.VRWt = self.VRWt[keep, every]
    self.VRH = self.VRH[every, keep]

    # The alpha* attributes are trimmed only when they are arrays; each
    # entry pairs the attribute name with the index selecting the kept
    # components.
    alpha_indices = (
        ('alphaWf', (every, keep)),
        ('alphaWt', (keep, every)),
        ('alphaH', (keep, every)),
        ('alphaZ', keep),
    )
    for attr_name, index in alpha_indices:
        current = getattr(self, attr_name)
        if isinstance(current, np.ndarray):
            setattr(self, attr_name, current[index])

    return Wf, Wt, Z, H