def update(self, Y, max_em_itr=20):
    '''
    hmm.update(Y)

    Run up to max_em_itr VB-EM iterations on the data Y.  The iteration
    counter continues after the iterations recorded by earlier calls
    (self.vbs is extended rather than replaced).

    @argvs
    Y: data, np.array(data_dim, data_len)
    max_em_itr: maximum number of iterations to run in this call
    '''
    # Grow (or create) the VB history by max_em_itr zero slots; this call
    # fills indices [ibgn, iend).
    if self.vbs is not None:
        ibgn = len(self.vbs)
        self.vbs = append(self.vbs, zeros(max_em_itr))
    else:
        ibgn = 0
        self.vbs = zeros(max_em_itr)
    iend = ibgn + max_em_itr
    # NOTE(review): on early stop the remaining zero slots stay in self.vbs,
    # and a later call appends after them -- confirm this is intended.

    self.data_dim, data_len = Y.shape
    self.init_expt_s(data_len)

    # Per-frame outer products y_t y_t^T: (data_dim, data_dim, data_len).
    YY = einsum('dt,et->det', Y, Y)

    logger.info('Update order: %s' % self.update_order)
    for itr in range(ibgn, iend):
        self.log_info_update_itr(iend, itr, interval_digit=1)
        for step in self.update_order:
            if step == 'E':
                self.qs.update(Y, self.theta, YY)
            elif step == 'M':
                self.theta.update(Y, self.qs.expt, self.qs.expt2, YY)
            else:
                # Unknown steps are logged and skipped, not fatal.
                logger.error('%s is not supported' % step)
        # _update_vb reports whether iteration should stop early.
        if self._update_vb(itr):
            break
    self.expt_s = self.qs.expt
def cluster_mdl(self, n_states):
    '''
    Return the trained cluster model registered under n_states.

    @argvs
    n_states: key into self._cluster_dic (formatted with %d in the log)
    @return
    the stored model, or None (after logging an error) when no model is
    stored for n_states
    '''
    found = self._cluster_dic.get(n_states)
    if found is not None:
        return found
    logger.error('cluster model %d states has not been trained' % n_states)
    return None
def lda_mdl(self, n_states, n_cat):
    '''
    Return the trained LDA model registered under (n_states, n_cat).

    @argvs
    n_states, n_cat: together form the key into self._lda_dic
    @return
    the stored model, or None (after logging an error) when no model is
    stored for that key
    '''
    key = (n_states, n_cat)
    found = self._lda_dic.get(key)
    if found is not None:
        return found
    logger.error('LDA (%d, %d) states has not been trained' % key)
    return None
def save_params(self, file_name, by_posterior=True):
    '''
    Pickle the result of self.get_params(by_posterior) to file_name.

    @argvs
    file_name: path of the output file (written in binary mode)
    by_posterior: forwarded unchanged to self.get_params
    @return
    True on success; False if opening/writing/logging raised (the
    exception is logged, not propagated)
    '''
    params = self.get_params(by_posterior)
    try:
        with open(file_name, 'wb') as fout:
            pickle.dump(params, fout)
        logger.info('Saved: %s' % file_name)
    except Exception as err:
        # Best-effort save: report the failure to the caller via False.
        logger.error(err)
        return False
    return True
def update(self, Y, s, ss, YY=None):
    '''
    theta.update(Y, s, ss, YY=None)

    Update each posterior factor listed in self.update_order, in order:
    'pi' uses s, 'A' uses ss, 'MuR' uses Y, s and YY.  Unknown names are
    logged and skipped.
    '''
    actions = (
        ('pi', lambda: self.qpi.update(s)),
        ('A', lambda: self.qA.update(ss)),
        ('MuR', lambda: self.qmur.update(Y, s, YY=YY)),
    )
    for factor in self.update_order:
        for key, action in actions:
            if factor == key:
                action()
                break
        else:
            logger.error('%s is not supported' % factor)
def load_param_dict(self, file_name):
    '''
    mfa.load_param_dict(file_name)

    Load a pickled parameter object from file_name and apply it with
    self.set_params.

    @argvs
    file_name: string, path to a pickle file (e.g. one written in 'wb'
               mode by save_params)
    @return
    True on success, False if the file could not be opened/read (IOError,
    which is logged)

    NOTE: pickle.load can execute arbitrary code -- only load trusted files.
    '''
    ret = False
    try:
        # Pickle data is binary: open in 'rb'.  Text mode ('r') makes
        # pickle.load raise TypeError on Python 3, which the IOError
        # handler below does not catch.
        with open(file_name, 'rb') as f:
            prm = pickle.load(f)
        self.set_params(prm)
        ret = True
    except IOError as e:
        logger.error('%s' % e)
    return ret
def update(self, Y, expt_s, YY=None):
    '''
    theta.update(Y, expt_s, YY=None)

    Update each posterior factor listed in self.update_order, in order.
    Unknown names are logged and skipped.

    @argv
    Y: data, np.array(data_dim, data_len)
    expt_s: <S>, np.array(n_states, data_len)
    YY: YY^T, np.array(data_dim, data_dim, data_len)
    '''
    actions = (
        ('MuR', lambda: self.qmur.update(Y, expt_s, YY=YY)),
        ('pi', lambda: self.qpi.update(expt_s)),
    )
    for factor in self.update_order:
        for key, action in actions:
            if factor == key:
                action()
                break
        else:
            logger.error('%s is not supported' % factor)
def update(self, Y, max_em_itr):
    '''
    mfa.update()

    Run max_em_itr EM iterations, applying the steps in self.update_order
    ('E' -> self.zs, 'M' -> self.theta) in each iteration.  An unknown
    step is logged and terminates the process via sys.exit(-1).

    @argvs
    Y: np.array(data_dim, data_len)
    '''
    logger.info('update order %s, in Theta %s'
                % (self.update_order, self.theta.update_order))
    for itr in range(max_em_itr):
        for step in self.update_order:
            # Progress is logged only every 10th iteration.
            if itr % 10 == 0:
                logger.info('iteration %3d (%s)' % (itr, step))
            if step == 'E':
                self.zs.update(Y, self.theta)
            elif step == 'M':
                self.theta.update(Y, self.zs)
            else:
                logger.error('%s is not supported' % step)
                sys.exit(-1)
def update(self, Y, zs):
    '''
    theta.update(Y, zs)

    Refresh the expectations held by zs, reduce them to frame-summed
    statistics, then update the factors named in self.update_order
    ('Lamb', 'R', 'Pi').  An unknown name is logged and terminates the
    process via sys.exit(-1).

    Y: np.array(data_dim, data_len)
    zs: qZS class object
    '''
    n_frames = Y.shape[-1]
    zs.init_expt(n_frames, Y)
    expt_s = zs.s.expt

    # Statistics summed over the last (frame) axis of each operand.
    sum_szz = einsum('ljkt->ljk', zs.expt_szz)
    sum_ysz = einsum('dt,lkt->dlk', Y, zs.expt_sz)
    sum_yys = einsum('dt,dt,kt->dk', Y, Y, expt_s)

    actions = (
        ('Lamb', lambda: self.lamb.update(self.r, sum_szz, sum_ysz)),
        ('R', lambda: self.r.update(self.lamb, expt_s, sum_szz, sum_yys, sum_ysz)),
        ('Pi', lambda: self.pi.update(expt_s)),
    )
    for target in self.update_order:
        for key, action in actions:
            if target == key:
                action()
                break
        else:
            logger.error('%s is not supported' % target)
            sys.exit(-1)
def check_vb_increase(cls, vbs, i):
    '''
    Check that the variational bound did not decrease at iteration i.

    vbs[i - 1] and vbs[i] are both rounded to cls.decimals before they are
    compared, so changes below that precision count as "no change".

    @argvs
    vbs: sequence of per-iteration bound values
    i: index of the current iteration in vbs
    @return
    True if i < 1 (nothing to compare) or the rounded bound did not
    decrease; False (after logging an error) if it decreased.
    '''
    if i < 1:
        return True
    vb_prv = nround(vbs[i - 1], decimals=cls.decimals)
    vb_cur = nround(vbs[i], decimals=cls.decimals)
    vb_diff = vb_cur - vb_prv
    if vb_diff < 0:
        logger.error(' '.join([
            'vb decreased.',
            'diff: %.10f' % vb_diff,
            'iter %3d: %.10f' % (i, vb_cur),
            'iter %3d: %.10f' % (i - 1, vb_prv),
        ]))
        # Previously this False was unconditionally overwritten with True,
        # so a decrease was logged but never reported to the caller.
        return False
    return True
def estimate(self, data, n_states, n_cat):
    '''
    Estimate the states and category of one data chunk.

    Runs the cluster model for n_states on data.T, then feeds the estimated
    state sequence (as a one-element list) to the LDA model for
    (n_states, n_cat).

    @return
    (est_x, states, category, cluster_prms, lda_prms), or None (after
    logging an error) if either model is not registered.
    '''
    cluster = self._cluster_dic.get(n_states, None)
    lda = self._lda_dic.get((n_states, n_cat), None)
    if cluster is None:
        logger.error('Cluster model %d states has not been trained' % n_states)
        return None
    if lda is None:
        logger.error('LDA (%d, %d) states has not been trained' % (n_states, n_cat))
        return None

    est_x, est_s, _vb = cluster.estimate(data.T)
    cluster_prms = cluster.get_params()
    # predict takes a list of sequences; only the first result is ours.
    s_list, z_list, lda_prms = lda.predict([est_s])
    return est_x, s_list[0], z_list[0], cluster_prms, lda_prms