def vbis_update(self, measurement, likelihood, prior, init_mean=0, init_var=1,
                init_alpha=0.5, init_xi=1, num_samples=None, use_LWIS=False):
    """VB update with importance sampling for Gaussian and Softmax."""
    if num_samples is None:
        num_samples = self.num_importance_samples

    if use_LWIS:
        q_mu = np.asarray(prior.means[0])
        log_c_hat = np.nan
    else:
        # Use VB update
        q_mu, var_VB, log_c_hat = self.vb_update(measurement, likelihood,
                                                 prior, init_mean, init_var,
                                                 init_alpha, init_xi)

    q_var = np.asarray(prior.covariances[0])

    # Importance distribution
    q = GaussianMixture(1, q_mu, q_var)

    # Importance sampling correction
    w = np.zeros(num_samples)  # Importance weights
    x = q.rvs(size=num_samples)  # Sampled points
    x = np.asarray(x)

    if hasattr(likelihood, 'subclasses'):
        measurement_class = likelihood.subclasses[measurement]
    else:
        measurement_class = likelihood.classes[measurement]

    # Compute parameters using samples
    w = prior.pdf(x) * measurement_class.probability(state=x) / q.pdf(x)
    w /= np.sum(w)  # Normalize weights

    mu_hat = np.sum(x.T * w, axis=-1)

    # <>TODO: optimize this
    var_hat = np.zeros_like(np.asarray(q_var))
    for i in range(num_samples):
        x_i = np.asarray(x[i])
        var_hat = var_hat + w[i] * np.outer(x_i, x_i)
    var_hat -= np.outer(mu_hat, mu_hat)

    # Ensure properly formatted output
    if mu_hat.size == 1 and mu_hat.ndim > 0:
        mu_post_vbis = mu_hat[0]
    else:
        mu_post_vbis = mu_hat
    if var_hat.size == 1:
        var_post_vbis = var_hat[0][0]
    else:
        var_post_vbis = var_hat

    logging.debug('VBIS update found mean of {} and variance of {}.'
                  .format(mu_post_vbis, var_post_vbis))

    return mu_post_vbis, var_post_vbis, log_c_hat

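# The importance-sampling correction above is the standard self-normalized
# estimator: with samples x_i drawn from q and normalized weights
# w_i proportional to prior(x_i) * P(measurement | x_i) / q(x_i), the fused
# moments are mu_hat = sum_i w_i x_i and
# var_hat = sum_i w_i outer(x_i, x_i) - outer(mu_hat, mu_hat).
#
# Illustrative usage sketch (the fusion object `vb`, softmax model `sm`, and
# class label 'Front' are assumptions for illustration, not part of this
# module's API):
#
#     prior = GaussianMixture(1, np.zeros(2), np.eye(2))
#     mu, var, log_c_hat = vb.vbis_update('Front', sm, prior)
#     posterior = GaussianMixture(1, mu, var)
#
# log_c_hat is the VB normalization estimate (NaN when the LWIS path is
# used), which `update` later combines with each mixand's weight.
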
def lwis_update(self, measurement, likelihood, prior, num_samples=None):
    """Likelihood-weighted importance sampling (LWIS) update.

    Unlike the VBIS update, the importance distribution is the prior
    itself, so no VB normalization constant is available (log_c_hat is
    returned as NaN).

    Note on mixand clustering: pairwise greedy merging (comparing means,
    weights and variances) can use Salmond's method or Runnalls' method
    (preferred).
    """
    if num_samples is None:
        num_samples = self.num_importance_samples
    log_c_hat = np.nan

    prior_mean = np.asarray(prior.means[0])
    prior_var = np.asarray(prior.covariances[0])

    # Importance distribution
    q = GaussianMixture(1, prior_mean, prior_var)

    # Importance sampling correction
    w = np.zeros(num_samples)  # Importance weights
    x = q.rvs(size=num_samples)  # Sampled points
    x = np.asarray(x)

    if hasattr(likelihood, 'subclasses'):
        measurement_class = likelihood.subclasses[measurement]
    else:
        measurement_class = likelihood.classes[measurement]

    for i in range(num_samples):
        w[i] = prior.pdf(x[i]) \
            * measurement_class.probability(state=x[i]) \
            / q.pdf(x[i])
    w /= np.sum(w)  # Normalize weights

    mu_hat = np.zeros_like(prior_mean)
    for i in range(num_samples):
        x_i = np.asarray(x[i])
        mu_hat = mu_hat + w[i] * x_i

    var_hat = np.zeros_like(prior_var)
    for i in range(num_samples):
        x_i = np.asarray(x[i])
        var_hat = var_hat + w[i] * np.outer(x_i, x_i)
    var_hat -= np.outer(mu_hat, mu_hat)

    # Ensure properly formatted output
    if mu_hat.size == 1 and mu_hat.ndim > 0:
        mu_lwis = mu_hat[0]
    else:
        mu_lwis = mu_hat
    if var_hat.size == 1:
        var_lwis = var_hat[0][0]
    else:
        var_lwis = var_hat

    logging.debug('LWIS update found mean of {} and variance of {}.'
                  .format(mu_lwis, var_lwis))

    return mu_lwis, var_lwis, log_c_hat

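# LWIS and VBIS differ only in the importance distribution: LWIS samples from
# the prior directly, while VBIS samples from the Gaussian fitted by the VB
# update, which usually places samples closer to the true posterior. A quick
# illustrative comparison (the `vb`, `sm`, and `prior` names are assumptions,
# as above):
#
#     mu_l, var_l, _ = vb.lwis_update('Front', sm, prior)
#     mu_v, var_v, _ = vb.vbis_update('Front', sm, prior)
#
# Both return moment-matched Gaussian parameters; with few samples the LWIS
# estimate typically has higher variance.
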
def update(self, measurement, likelihood, prior, use_LWIS=False, poly=None,
           num_std=1):
    """VB update using Gaussian mixtures and multimodal softmax.

    This uses Variational Bayes with Importance Sampling (VBIS) for each
    mixand-softmax pair available.
    """
    # If we have a polygon, update only the mixands intersecting with it
    if poly is None:
        update_intersections_only = False
    else:
        update_intersections_only = True

    h = 0
    relevant_subclasses = likelihood.classes[measurement].subclasses
    num_relevant_subclasses = len(relevant_subclasses)

    # Use intersecting priors only
    if update_intersections_only:
        other_priors = prior.copy()
        weights = []
        means = []
        covariances = []
        mixand_ids = []
        ellipses = prior.std_ellipses(num_std)

        any_intersection = False
        for i, ellipse in enumerate(ellipses):
            try:
                has_intersection = poly.intersects(ellipse)
            except ValueError:
                logging.warn('Null geometry error! Defaulting to true.')
                has_intersection = True

            if has_intersection:
                # Get parameters for intersecting priors
                mixand_ids.append(i)
                weights.append(prior.weights[i])
                means.append(prior.means[i])
                covariances.append(prior.covariances[i])
                any_intersection = True

        if not any_intersection:
            logging.debug('No intersection with any ellipse.')
            mu_hat = other_priors.means
            var_hat = other_priors.covariances
            beta_hat = other_priors.weights
            return mu_hat, var_hat, beta_hat

        # Remove these from the other priors
        other_priors.weights = \
            np.delete(other_priors.weights, mixand_ids, axis=0)
        other_priors.means = \
            np.delete(other_priors.means, mixand_ids, axis=0)
        other_priors.covariances = \
            np.delete(other_priors.covariances, mixand_ids, axis=0)

        # Retain total weight of intersecting mixands for renormalization
        max_intersection_weight = sum(weights)

        # Create new prior
        prior = GaussianMixture(weights, means, covariances)
        logging.debug('Using only mixands {} for VBIS fusion. Total weight {}'
                      .format(mixand_ids, max_intersection_weight))

    # Parameters for all new mixands
    K = num_relevant_subclasses * prior.weights.size
    mu_hat = np.zeros((K, prior.means.shape[1]))
    var_hat = np.zeros((K, prior.covariances.shape[1],
                        prior.covariances.shape[2]))
    log_beta_hat = np.zeros(K)  # Weight estimates

    for u, mixand_weight in enumerate(prior.weights):

        # Check whether the mixand is completely contained within the
        # softmax class (i.e. doesn't need an update)
        mixand = GaussianMixture(1, prior.means[u], prior.covariances[u])
        mixand_samples = mixand.rvs(self.num_mixand_samples)
        p_hat_ru_samples = likelihood.classes[measurement]\
            .probability(state=mixand_samples)
        mix_sm_corr = np.sum(p_hat_ru_samples) / self.num_mixand_samples

        if mix_sm_corr > self.mix_sm_corr_thresh:
            logging.debug('Mixand {}\'s correspondence with {} was {}, '
                          'above the threshold of {}, so VBIS was skipped.'
                          .format(u, measurement, mix_sm_corr,
                                  self.mix_sm_corr_thresh))

            # Append the prior's parameters to the mixand parameter lists
            mu_hat[h, :] = prior.means[u]
            var_hat[h, :] = prior.covariances[u]
            log_beta_hat[h] = np.log(mixand_weight)
            h += 1
            continue

        # Otherwise complete the full VBIS update
        ordered_subclasses = iter(sorted(relevant_subclasses.iteritems()))
        for label, subclass in ordered_subclasses:

            # Compute \hat{P}_s(r|u)
            mixand_samples = mixand.rvs(self.num_mixand_samples)
            p_hat_ru_samples = subclass.probability(state=mixand_samples)
            p_hat_ru_sampled = np.sum(p_hat_ru_samples) \
                / self.num_mixand_samples

            mu_vbis, var_vbis, log_c_hat = \
                self.vbis_update(label, subclass.softmax_collection, mixand,
                                 use_LWIS=use_LWIS)

            # Compute log odds of r given u
            if np.isnan(log_c_hat):  # from LWIS update
                log_p_hat_ru = np.log(p_hat_ru_sampled)
            else:
                log_p_hat_ru = np.max((log_c_hat, np.log(p_hat_ru_sampled)))

            # Find log of P(u,r|D_k) \approxequal \hat{B}_{ur}
            log_beta_vbis = np.log(mixand_weight) + log_p_hat_ru

            # Symmetrize var_vbis
            var_vbis = 0.5 * (var_vbis.T + var_vbis)

            # Update estimate values
            log_beta_hat[h] = log_beta_vbis
            mu_hat[h, :] = mu_vbis
            var_hat[h, :] = var_vbis
            h += 1

    # Renormalize and truncate (based on weight threshold)
    log_beta_hat = log_beta_hat - np.max(log_beta_hat)
    unnormalized_beta_hats = np.exp(log_beta_hat)
    beta_hat = unnormalized_beta_hats / np.sum(unnormalized_beta_hats)

    # Reattach untouched prior values
    if update_intersections_only:
        beta_hat = unnormalized_beta_hats * max_intersection_weight
        beta_hat = np.hstack((other_priors.weights, beta_hat))
        mu_hat = np.vstack((other_priors.means, mu_hat))
        var_hat = np.concatenate((other_priors.covariances, var_hat))

        # Shrink mu, var and beta if necessary
        h += other_priors.weights.size
        beta_hat = beta_hat[:h]
        mu_hat = mu_hat[:h]
        var_hat = var_hat[:h]

        beta_hat /= beta_hat.sum()
    else:
        # Shrink mu, var and beta if necessary
        beta_hat = beta_hat[:h]
        mu_hat = mu_hat[:h]
        var_hat = var_hat[:h]

    # Threshold based on weights
    mu_hat = mu_hat[beta_hat > self.weight_threshold, :]
    var_hat = var_hat[beta_hat > self.weight_threshold, :]
    beta_hat = beta_hat[beta_hat > self.weight_threshold]

    # Check if covariances are positive semidefinite
    for i, var in enumerate(var_hat):
        try:
            assert np.all(np.linalg.det(var) > 0)
        except AssertionError:
            logging.warn('Following variance is not positive '
                         'semidefinite: \n{}'.format(var))
            var_hat[i] = np.eye(var.shape[0]) * 10 ** -3

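# Each fused mixand's weight follows from
#     log beta_ur = log(prior mixand weight) + log P_hat(r | u),
# where log P_hat(r | u) is the larger of the VB normalization estimate
# log_c_hat and the log of the sampled class probability (just the sampled
# probability when LWIS is used).
#
# Illustrative usage sketch (the `vb`, `sm`, `prior_mixture`, and `room_poly`
# names are assumptions for illustration only):
#
#     means, covs, weights = vb.update('Inside', sm, prior_mixture,
#                                      poly=room_poly, num_std=2)
#     posterior = GaussianMixture(weights, means, covs)
#
# When a polygon is supplied, only mixands whose num_std-sigma ellipses
# intersect it are re-fused; the remaining mixands are passed through
# unchanged and the weights are renormalized at the end.
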