def _update_log_varpar_assignment(self, children_label_np):
     """Recompute the log assignment parameters and ``phi_x_tau_x_error``.

     The unnormalized scores are the sum of three terms: two built from
     digamma differences of ``self.varpar_stick`` and one built from
     ``children_label_np``, ``self.log_varpar_label`` and the module-level
     ``error_mat``.  Each row is then normalized with ``logsumexp`` over the
     last axis.

     Args:
         children_label_np: 2-D array that is exponentiated and
             matrix-multiplied from the left.  Presumably the stacked log
             label parameters of the children, one row per
             (child, child table) pair -- confirm against the caller.

     Side effects:
         Sets ``self.log_varpar_assignment`` (reshaped to
         ``(num_children, num_tables_child, T)``) and
         ``self.phi_x_tau_x_error`` (reshaped to
         ``(num_children, num_tables_child, num_symbols)``).
     """
     stick = self.varpar_stick
     digamma_row_total = sps.digamma(np.sum(stick, axis=1))

     # digamma(gamma_1) - digamma(gamma_1 + gamma_2) and the gamma_2
     # counterpart.  NOTE(review): these look like the E[log V] and
     # E[log(1 - V)] terms of a stick-breaking update -- confirm.
     first_col_diff = sps.digamma(stick[:, 0]) - digamma_row_total
     second_col_diff = sps.digamma(stick[:, 1]) - digamma_row_total

     # A trailing 0 pads the first term; the running sum of the second
     # term is shifted right by a leading 0.  Both become single-row
     # arrays that broadcast against the per-row likelihood term.
     own_stick_term = np.append(first_col_diff, 0)[None, :]
     earlier_sticks_term = np.append(0, np.cumsum(second_col_diff))[None, :]

     label_x_error = np.matmul(np.exp(self.log_varpar_label), error_mat)
     likelihood_term = np.matmul(
         np.exp(children_label_np), np.transpose(label_x_error))

     scores = own_stick_term + earlier_sticks_term + likelihood_term
     # Normalize every row in log space.
     scores -= spm.logsumexp(scores, axis=-1)[:, None]

     self.log_varpar_assignment = scores.reshape(
         self.num_children, self.num_tables_child, self.T)

     log_mixed_label = lla.logmatmul(scores, self.log_varpar_label)
     mixed_x_error = np.matmul(np.exp(log_mixed_label), error_mat)
     self.phi_x_tau_x_error = mixed_x_error.reshape(
         self.num_children, self.num_tables_child, num_symbols)
 def _update_log_varpar_label(self, children_label_np):
     """Recompute ``self.log_varpar_label`` and normalize it row-wise.

     The unnormalized value is the sum of the mother's cached
     ``phi_x_tau_x_error`` entry for this node (indexed by ``self.id``) and
     a term combining this node's assignment parameters, the children's
     labels and the module-level ``error_mat``.

     Args:
         children_label_np: array of shape
             (num_children * num_tables_child) x |∑|, passed as the right
             operand of ``lla.logmatmul``.

     Side effects:
         Overwrites ``self.log_varpar_label``.
     """
     mother_term = self.mother.phi_x_tau_x_error[self.id]

     flat_assignment = self.log_varpar_assignment.reshape(
         self.num_children * self.num_tables_child, self.T)
     # (T x (num_children*num_tables_child)) times
     # ((num_children*num_tables_child) x |∑|), done by lla.logmatmul.
     log_children_mix = lla.logmatmul(
         np.transpose(flat_assignment), children_label_np)
     # error_mat is |∑| x |∑|.
     children_term = np.matmul(np.exp(log_children_mix), error_mat)

     unnormalized = mother_term + children_term
     log_row_norm = spm.logsumexp(unnormalized, axis=-1)[:, None]
     self.log_varpar_label = unnormalized - log_row_norm
 def set_log_varpar_assignment(self):
     """Randomly initialize the log assignment parameters.

     Draws one Dirichlet(1, ..., 1) sample of length ``self.T`` per entry
     of ``self.customers``, stores its log as
     ``self.log_varpar_assignment`` and refreshes
     ``self.phi_x_tau_x_error`` from it, ``self.log_varpar_label`` and the
     module-level ``error_mat``.  When there are no assignment entries at
     all, ``self.log_varpar_label`` is reset to zeros of the same shape.
     """
     flat_concentration = np.ones(self.T)
     num_customers = len(self.customers)
     phi = np.random.dirichlet(flat_concentration, num_customers)
     self.log_varpar_assignment = np.log(phi)  # phi in Blei and Jordan.

     log_mixed_label = lla.logmatmul(
         self.log_varpar_assignment, self.log_varpar_label)
     self.phi_x_tau_x_error = np.matmul(np.exp(log_mixed_label), error_mat)

     if self.log_varpar_assignment.size == 0:
         self.log_varpar_label = np.zeros(self.log_varpar_label.shape)
 def update_varpars(self):
     """Run one round of variational-parameter updates.

     Does nothing when ``self.log_varpar_assignment`` is empty.  Otherwise
     updates, in order, the stick, label and assignment parameters and
     then recomputes ``self.phi_x_tau_x_error``.  After each step an
     ``assert`` checks the freshly written array for NaNs.
     """
     if not self.log_varpar_assignment.size:
         return

     self._update_varpar_stick()
     assert not np.isnan(self.varpar_stick).any(), (
         'stick\n', self.varpar_stick)

     self._update_log_varpar_label()
     assert not np.isnan(self.log_varpar_label).any(), (
         'label\n', self.log_varpar_label)

     self._update_log_varpar_assignment()
     assert not np.isnan(self.log_varpar_assignment).any(), (
         'assignment\n', self.log_varpar_assignment)

     log_mixed_label = lla.logmatmul(
         self.log_varpar_assignment, self.log_varpar_label)
     self.phi_x_tau_x_error = np.matmul(np.exp(log_mixed_label), error_mat)
     assert not np.isnan(self.phi_x_tau_x_error).any(), (
         'phi_x_tau_x_error\n', self.phi_x_tau_x_error)
 def set_log_varpar_assignment(self):
     """Randomly initialize the per-child log assignment parameters.

     Records ``self.num_children`` and ``self.num_tables_child`` (the
     number of rows of the first child's ``log_varpar_label``), draws a
     Dirichlet(1, ..., 1) sample of length ``self.T`` for every
     (child, child table) pair, stores its log as
     ``self.log_varpar_assignment`` and refreshes
     ``self.phi_x_tau_x_error`` with shape
     ``(num_children, num_tables_child, num_symbols)``.

     Raises:
         IndexError: if ``self.children`` is empty, from ``children[0]``.
     """
     self.num_children = len(self.children)
     # T for children.
     first_child_label = self.children[0].log_varpar_label
     self.num_tables_child = first_child_label.shape[0]

     flat_concentration = np.ones(self.T)
     sample_shape = (self.num_children, self.num_tables_child)
     phi = np.random.dirichlet(flat_concentration, sample_shape)
     self.log_varpar_assignment = np.log(phi)  # phi in Blei and Jordan.

     # Collapse the (child, table) axes so lla.logmatmul sees a 2-D array.
     flat_assignment = self.log_varpar_assignment.reshape(
         self.num_children * self.num_tables_child, self.T)
     log_mixed_label = lla.logmatmul(flat_assignment, self.log_varpar_label)
     mixed_x_error = np.matmul(np.exp(log_mixed_label), error_mat)
     self.phi_x_tau_x_error = mixed_x_error.reshape(
         self.num_children, self.num_tables_child, num_symbols)