def _update_log_varpar_assignment(self, children_label_np):
    """Recompute ``log_varpar_assignment`` and the cached ``phi_x_tau_x_error``.

    The unnormalized log-weights are the sum of three terms:

    * the digamma expectation of the first stick parameter, padded with a
      trailing 0;
    * the cumulative digamma expectation of the second stick parameter,
      padded with a leading 0;
    * the exponentiated children labels multiplied with the transposed
      product of the exponentiated own labels and ``error_mat``.

    The result is normalized in log space along the last axis, stored on
    ``self`` reshaped to ``(num_children, num_tables_child, T)``, and used
    to refresh ``self.phi_x_tau_x_error``.

    Args:
        children_label_np: 2-D array of children label log-weights.
            Presumably one row per (child, table) pair, since the result
            is reshaped that way -- confirm against the caller.
    """
    stick = self.varpar_stick
    digamma_total = sps.digamma(np.sum(stick, axis=1))

    # Stick-breaking expectations; the padding zeros line the two vectors
    # up so that they can be added entry by entry.
    first_stick_term = np.append(
        sps.digamma(stick[:, 0]) - digamma_total, 0)[None, :]
    second_stick_term = np.append(
        0, np.cumsum(sps.digamma(stick[:, 1]) - digamma_total))[None, :]

    # Contribution of the children labels through the error matrix.
    label_x_error = np.matmul(np.exp(self.log_varpar_label), error_mat)
    children_term = np.matmul(
        np.exp(children_label_np), np.transpose(label_x_error))

    scores = first_stick_term + second_stick_term + children_term
    # Normalize each row in log space.
    scores = scores - spm.logsumexp(scores, axis=-1)[:, None]

    self.log_varpar_assignment = scores.reshape(
        self.num_children, self.num_tables_child, self.T)

    expected_label = np.exp(lla.logmatmul(scores, self.log_varpar_label))
    self.phi_x_tau_x_error = np.matmul(expected_label, error_mat).reshape(
        self.num_children, self.num_tables_child, num_symbols)
def _update_log_varpar_label(self, children_label_np):
    """Recompute ``self.log_varpar_label`` from the mother and the children.

    The unnormalized value is the mother's cached ``phi_x_tau_x_error``
    entry for this node (indexed by ``self.id``) plus a term that combines
    this node's assignment weights with the children labels and
    ``error_mat``.  Each row is then normalized in log space along the
    last axis.

    Args:
        children_label_np: children label log-weights, shaped
            (num_children*num_tables_child) x |∑|.
    """
    # NOTE(review): the mother's cache presumably equals
    # matmul(mother assignment, matmul(mother label, error_mat)) -- confirm
    # against where the mother computes phi_x_tau_x_error.
    mother_term = self.mother.phi_x_tau_x_error[self.id]

    flat_rows = self.num_children * self.num_tables_child
    # T x (num_children*num_tables_child)
    assignment_by_table = np.transpose(
        self.log_varpar_assignment.reshape(flat_rows, self.T))
    log_children_mix = lla.logmatmul(assignment_by_table, children_label_np)
    # error_mat is |∑| x |∑|
    children_term = np.matmul(np.exp(log_children_mix), error_mat)

    unnormalized = mother_term + children_term
    log_norm = spm.logsumexp(unnormalized, axis=-1)[:, None]
    self.log_varpar_label = unnormalized - log_norm
def set_log_varpar_assignment(self):
    """Randomly initialize the assignment log-weights for the customers.

    Draws one Dirichlet(1, ..., 1) sample of length ``self.T`` per entry
    of ``self.customers``, stores its log as ``self.log_varpar_assignment``
    and derives ``self.phi_x_tau_x_error`` from it.  When there are no
    assignment entries at all, ``self.log_varpar_label`` is reset to zeros
    of its current shape.
    """
    concentration = np.ones(self.T)
    num_customers = len(self.customers)
    # phi in Blei and Jordan.
    self.log_varpar_assignment = np.log(
        np.random.dirichlet(concentration, num_customers))

    log_mix = lla.logmatmul(self.log_varpar_assignment, self.log_varpar_label)
    self.phi_x_tau_x_error = np.matmul(np.exp(log_mix), error_mat)

    if self.log_varpar_assignment.size == 0:
        self.log_varpar_label = np.zeros(self.log_varpar_label.shape)
def update_varpars(self):
    """Run one round of variational-parameter updates, checking for NaNs.

    Nothing happens when ``self.log_varpar_assignment`` is empty.
    Otherwise the stick, label and assignment parameters are updated in
    that order, ``self.phi_x_tau_x_error`` is recomputed, and an assertion
    after each step verifies that the freshly updated array has no NaN.
    """
    # NOTE(review): the flattened source does not show where the `if` body
    # ends; this assumes every update below is guarded by the size check.
    if not self.log_varpar_assignment.size:
        return

    self._update_varpar_stick()
    assert not np.isnan(self.varpar_stick).any(), (
        'stick\n', self.varpar_stick)

    self._update_log_varpar_label()
    assert not np.isnan(self.log_varpar_label).any(), (
        'label\n', self.log_varpar_label)

    self._update_log_varpar_assignment()
    assert not np.isnan(self.log_varpar_assignment).any(), (
        'assignment\n', self.log_varpar_assignment)

    log_mix = lla.logmatmul(self.log_varpar_assignment, self.log_varpar_label)
    self.phi_x_tau_x_error = np.matmul(np.exp(log_mix), error_mat)
    assert not np.isnan(self.phi_x_tau_x_error).any(), (
        'phi_x_tau_x_error\n', self.phi_x_tau_x_error)
def set_log_varpar_assignment(self):
    """Randomly initialize the assignment log-weights for the children's tables.

    Records ``self.num_children`` and ``self.num_tables_child`` (the first
    dimension of the first child's ``log_varpar_label``), draws one
    Dirichlet(1, ..., 1) sample of length ``self.T`` per (child, table)
    pair, stores its log as ``self.log_varpar_assignment`` and derives
    ``self.phi_x_tau_x_error`` with shape
    ``(num_children, num_tables_child, num_symbols)``.
    """
    self.num_children = len(self.children)
    # T for children.
    self.num_tables_child = self.children[0].log_varpar_label.shape[0]
    grid_shape = (self.num_children, self.num_tables_child)

    # phi in Blei and Jordan.
    self.log_varpar_assignment = np.log(
        np.random.dirichlet(np.ones(self.T), grid_shape))

    # Collapse the (child, table) axes so the log-matmul sees a 2-D array.
    flat_assignment = self.log_varpar_assignment.reshape(
        self.num_children * self.num_tables_child, self.T)
    log_mix = lla.logmatmul(flat_assignment, self.log_varpar_label)
    self.phi_x_tau_x_error = np.matmul(np.exp(log_mix), error_mat).reshape(
        grid_shape + (num_symbols,))