import itertools

import distrax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import tensorflow_probability as tfp
from distrax._src.utils import jittable  # internal distrax helper providing `Jittable`
from jax import jit, value_and_grad, vmap
from jax.example_libraries.optimizers import adam
from jax.lax import scan
from jax.nn import softmax
from jax.random import PRNGKey, permutation, split, uniform
from jax.scipy.special import expit, logit

# `MixtureSameFamily` is assumed to be a local distrax-compatible mixture class
# that additionally exposes `posterior_mode` and `posterior_marginal`
# (mirroring tfp.distributions.MixtureSameFamily); the module name below is a
# placeholder for wherever it is defined.
from mixture_lib import MixtureSameFamily

import pyprobml_utils as pml  # figure-saving helper used by `BMM.plot`


# The two tests below are excerpts; they are assumed to run inside a test case
# that provides `assertEqual`/`assertAllClose` (e.g. TFP's JAX test utilities).
def test_posterior_mode(self):
    mix_dist_probs = jnp.array([[0.5, 0.5],
                                [0.01, 0.99]])
    locs = jnp.array([[-1., 1.],
                      [-1., 1.]])
    scale = jnp.array([1.])

    gm = MixtureSameFamily(
        mixture_distribution=distrax.Categorical(probs=mix_dist_probs),
        components_distribution=distrax.Normal(loc=locs, scale=scale))

    mode = gm.posterior_mode(jnp.array([[1.], [-1.], [-6.]]))

    self.assertEqual((3, 2), mode.shape)
    self.assertAllClose(jnp.array([[1, 1], [0, 1], [0, 0]]), mode)
# Drop-in replacement for the `model` setter of `BMM` below, parameterized by
# logits rather than probabilities (useful with unconstrained optimizers).
def model(self, value):
    mixing_coeffs_logits, probs_logits = value
    self._model = MixtureSameFamily(
        mixture_distribution=distrax.Categorical(logits=mixing_coeffs_logits),
        components_distribution=distrax.Independent(
            distrax.Bernoulli(logits=probs_logits),
            reinterpreted_batch_ndims=1))
def test_posterior_marginal(self):
    mix_dist_probs = jnp.array([0.1, 0.9])
    component_dist_probs = jnp.array([[.2, .3, .5],
                                      [.7, .2, .1]])

    bm = MixtureSameFamily(
        mixture_distribution=distrax.Categorical(probs=mix_dist_probs),
        components_distribution=distrax.Categorical(probs=component_dist_probs))

    marginal_dist = bm.posterior_marginal(jnp.array([0., 1., 2.]))
    marginals = marginal_dist.probs

    self.assertEqual((3, 2), marginals.shape)

    # p(z = k | x) via Bayes' rule: prior * likelihood, normalized over k.
    expected_marginals = jnp.array([
        [(.1 * .2) / (.1 * .2 + .9 * .7), (.9 * .7) / (.1 * .2 + .9 * .7)],
        [(.1 * .3) / (.1 * .3 + .9 * .2), (.9 * .2) / (.1 * .3 + .9 * .2)],
        [(.1 * .5) / (.1 * .5 + .9 * .1), (.9 * .1) / (.1 * .5 + .9 * .1)]])
    self.assertAllClose(marginals, expected_marginals)
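# For reference, `posterior_marginal` amounts to Bayes' rule over the mixture
# index: r[n, k] is proportional to pi[k] * p(x[n] | component k). The sketch
# below reproduces it with plain distrax objects; `posterior_marginal_by_hand`
# is a hypothetical helper, not part of the library API.
def posterior_marginal_by_hand(mix_probs, components, x):
    # Unnormalized log posterior: log pi_k + log p(x_n | k), broadcasting each
    # observation against the component batch dimension.
    log_joint = jnp.log(mix_probs) + components.log_prob(x[..., None])
    # Categorical normalizes the logits over k, yielding p(z = k | x_n).
    return distrax.Categorical(logits=log_joint)

# e.g. posterior_marginal_by_hand(
#     mix_dist_probs, distrax.Categorical(probs=component_dist_probs),
#     jnp.array([0., 1., 2.])).probs  # matches expected_marginals above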
class BMM(jittable.Jittable):

    def __init__(self, K, n_vars, rng_key=None):
        '''
        Initializes Bernoulli Mixture Model

        Parameters
        ----------
        K : int
            Number of mixture components
        n_vars : int
            Dimension of the binary random variable
        rng_key : array
            Random key of shape (2,) and dtype uint32
        '''
        rng_key = PRNGKey(0) if rng_key is None else rng_key
        # Nearly uniform mixing coefficients with a small random perturbation.
        mixing_coeffs = uniform(rng_key, (K,), minval=100, maxval=200)
        mixing_coeffs = mixing_coeffs / mixing_coeffs.sum()
        initial_probs = jnp.full((K, n_vars), 1.0 / K)
        self._probs = initial_probs
        self.model = (mixing_coeffs, initial_probs)

    @property
    def mixing_coeffs(self):
        return self._model.mixture_distribution.probs

    @property
    def probs(self):
        return self._probs

    @property
    def model(self):
        return self._model

    @model.setter
    def model(self, value):
        mixing_coeffs, probs = value
        self._model = MixtureSameFamily(
            mixture_distribution=distrax.Categorical(probs=mixing_coeffs),
            components_distribution=distrax.Independent(
                distrax.Bernoulli(probs=probs),
                reinterpreted_batch_ndims=1))

    def responsibilities(self, observations):
        '''
        Computes the responsibilities, i.e. the posterior probabilities p(z|x)

        Parameters
        ----------
        observations : array(N, seq_len)
            Dataset

        Returns
        -------
        * array
            Responsibilities
        '''
        return jnp.nan_to_num(self._model.posterior_marginal(observations).probs)

    def expected_log_likelihood(self, observations):
        '''
        Calculates the expected log likelihood

        Parameters
        ----------
        observations : array(N, seq_len)
            Dataset

        Returns
        -------
        * float
            Log likelihood
        '''
        return jnp.sum(jnp.nan_to_num(self._model.log_prob(observations)))

    def _m_step(self, observations):
        '''
        Maximization step

        Parameters
        ----------
        observations : array(N, seq_len)
            Dataset

        Returns
        -------
        * array
            Mixing coefficients
        * array
            Probabilities
        '''
        n_obs, _ = observations.shape

        def m_step_per_bernoulli(responsibility):
            # Responsibility-weighted mean of the observations for one component.
            norm_const = responsibility.sum()
            mu = jnp.sum(responsibility[:, None] * observations, axis=0) / norm_const
            return mu, norm_const

        # Map over components, i.e. axis 1 of the responsibility matrix p(z|x).
        mus, ns = vmap(m_step_per_bernoulli, in_axes=1)(self.responsibilities(observations))
        return ns / n_obs, mus

    def fit_em(self, observations, num_of_iters=10):
        '''
        Fits the model using the EM algorithm.

        Parameters
        ----------
        observations : array(N, seq_len)
            Dataset
        num_of_iters : int
            The number of iterations the training process takes

        Returns
        -------
        * array
            Log likelihoods found per iteration
        * array
            Responsibilities found per iteration
        '''
        iterations = jnp.arange(num_of_iters)

        def train_step(params, i):
            self.model = params
            log_likelihood = self.expected_log_likelihood(observations)
            responsibilities = self.responsibilities(observations)
            mixing_coeffs, probs = self._m_step(observations)
            return (mixing_coeffs, probs), (log_likelihood, responsibilities)

        initial_params = (self.mixing_coeffs, self.probs)
        final_params, history = scan(train_step, initial_params, iterations)

        self.model = final_params
        _, probs = final_params
        self._probs = probs

        # The scan records quantities *before* each update, so append the
        # values obtained with the final parameters.
        ll_hist, responsibility_hist = history
        ll_hist = jnp.append(ll_hist, self.expected_log_likelihood(observations))
        responsibility_hist = jnp.vstack(
            [responsibility_hist, jnp.array([self.responsibilities(observations)])])
        return ll_hist, responsibility_hist

    def _make_minibatches(self, observations, batch_size, rng_key):
        '''
        Creates minibatches consisting of a random permutation of the given
        observation sequences

        Parameters
        ----------
        observations : array(N, seq_len)
            Dataset
        batch_size : int
            The number of observation sequences included in each minibatch
        rng_key : array
            Random key of shape (2,) and dtype uint32

        Returns
        -------
        * array(num_batches, batch_size, max_len)
            Minibatches
        '''
        num_train = len(observations)
        perm = permutation(rng_key, num_train)

        def create_mini_batch(batch_idx):
            return observations[batch_idx]

        num_batches = num_train // batch_size
        # Drop the remainder so that the permutation reshapes evenly.
        batch_indices = perm[:num_batches * batch_size].reshape((num_batches, -1))
        minibatches = vmap(create_mini_batch)(batch_indices)
        return minibatches

    @jit
    def loss_fn(self, params, batch):
        '''
        Calculates the expected mean negative loglikelihood.

        Parameters
        ----------
        params : tuple
            Consists of the unconstrained mixing coefficients and the logits of
            the Bernoulli probabilities, respectively.
        batch : array
            The subset of observations

        Returns
        -------
        * float
            Negative log likelihood
        '''
        mixing_coeffs, probs = params
        self.model = (softmax(mixing_coeffs), expit(probs))
        return -self.expected_log_likelihood(batch) / len(batch)

    @jit
    def update(self, i, opt_state, batch):
        '''
        Updates the optimizer state after taking the derivative

        Parameters
        ----------
        i : int
            The current iteration
        opt_state : jax.example_libraries.optimizers.OptimizerState
            The current state of the parameters
        batch : array
            The subset of observations

        Returns
        -------
        * jax.example_libraries.optimizers.OptimizerState
            The updated state
        * float
            Loss value calculated on the current batch
        '''
        params = get_params(opt_state)
        loss, grads = value_and_grad(self.loss_fn)(params, batch)
        return opt_update(i, grads, opt_state), loss

    def fit_sgd(self, observations, batch_size, rng_key=None, optimizer=None, num_epochs=1):
        '''
        Fits the model using gradient descent with the given hyperparameters.

        Parameters
        ----------
        observations : array
            The observation sequences on which the Bernoulli Mixture Model is trained
        batch_size : int
            The size of the batch
        rng_key : array
            Random key of shape (2,) and dtype uint32
        optimizer : jax.example_libraries.optimizers.Optimizer
            Optimizer to be used
        num_epochs : int
            The number of epochs the training process takes

        Returns
        -------
        * array
            Mean loss values found per epoch
        * array
            Mixing coefficients found per epoch
        * array
            Probabilities of the Bernoulli distribution found per epoch
        * array
            Responsibilities found per epoch
        '''
        global opt_init, opt_update, get_params

        if rng_key is None:
            rng_key = PRNGKey(0)
        if optimizer is None:
            # Fall back to a default optimizer; the step size is an arbitrary choice.
            optimizer = adam(1e-3)
        opt_init, opt_update, get_params = optimizer

        # Optimize in unconstrained space: log mixing coefficients and Bernoulli logits.
        opt_state = opt_init((jnp.log(self.mixing_coeffs), logit(self.probs)))
        itercount = itertools.count()

        def epoch_step(opt_state, key):

            def train_step(opt_state, batch):
                # NOTE: next(itercount) runs at trace time, so the step index is
                # constant within the scans; this skews Adam's bias correction.
                opt_state, loss = self.update(next(itercount), opt_state, batch)
                return opt_state, loss

            batches = self._make_minibatches(observations, batch_size, key)
            opt_state, losses = scan(train_step, opt_state, batches)

            params = get_params(opt_state)
            mixing_coeffs, probs_logits = params
            probs = expit(probs_logits)
            self.model = (softmax(mixing_coeffs), probs)
            self._probs = probs
            return opt_state, (losses.mean(), *params, self.responsibilities(observations))

        epochs = split(rng_key, num_epochs)
        opt_state, history = scan(epoch_step, opt_state, epochs)

        params = get_params(opt_state)
        mixing_coeffs, probs_logits = params
        probs = expit(probs_logits)
        self.model = (softmax(mixing_coeffs), probs)
        self._probs = probs
        return history

    def plot(self, n_row, n_col, file_name):
        '''
        Plots the mean of each Bernoulli distribution as an image.

        Parameters
        ----------
        n_row : int
            The number of rows of the figure
        n_col : int
            The number of columns of the figure
        file_name : str
            The path where the figure will be stored
        '''
        if n_row * n_col != len(self.mixing_coeffs):
            raise ValueError('The number of rows and columns does not match the number of mixture components.')

        fig, axes = plt.subplots(n_row, n_col)
        for (coeff, mean), ax in zip(zip(self.mixing_coeffs, self.probs), axes.flatten()):
            ax.imshow(mean.reshape(28, 28), cmap=plt.cm.gray)
            ax.set_title("%1.2f" % coeff)
            ax.axis("off")
        fig.tight_layout(pad=1.0)
        pml.savefig(f"{file_name}.pdf")
        plt.show()
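# A minimal usage sketch for `BMM` (not from the original source): fit a
# two-component mixture to synthetic binary data with EM. All shapes and
# hyperparameters are illustrative assumptions.
def _demo_bmm():
    key = PRNGKey(42)
    # 200 samples of a 10-dimensional binary vector, biased towards zeros.
    data = (uniform(key, (200, 10)) < 0.3).astype(jnp.float32)
    bmm = BMM(K=2, n_vars=10, rng_key=key)
    ll_hist, _ = bmm.fit_em(data, num_of_iters=5)
    # EM should not decrease the log likelihood from one iteration to the next.
    print(ll_hist)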
class GMM(jittable.Jittable):

    def __init__(self, mixing_coeffs, means, covariances):
        '''
        Initializes Gaussian Mixture Model

        Parameters
        ----------
        mixing_coeffs : array
        means : array
        covariances : array
        '''
        self.model = (mixing_coeffs, means, covariances)

    @property
    def mixing_coeffs(self):
        return self._model.mixture_distribution.probs

    @property
    def means(self):
        return self._model.components_distribution.loc

    @property
    def covariances(self):
        return self._model.components_distribution.covariance()

    @property
    def model(self):
        return self._model

    @model.setter
    def model(self, value):
        mixing_coeffs, means, covariances = value
        components_distribution = distrax.as_distribution(
            tfp.substrates.jax.distributions.MultivariateNormalFullCovariance(
                loc=means, covariance_matrix=covariances, validate_args=True))
        self._model = MixtureSameFamily(
            mixture_distribution=distrax.Categorical(probs=mixing_coeffs),
            components_distribution=components_distribution)

    def expected_log_likelihood(self, observations):
        '''
        Calculates the expected log likelihood

        Parameters
        ----------
        observations : array(N, seq_len)
            Dataset

        Returns
        -------
        * float
            Log likelihood
        '''
        return jnp.sum(self._model.log_prob(observations))

    def responsibility(self, observations, comp_dist_idx):
        '''
        Computes the responsibility for one component, i.e. the posterior
        probability p(z_{comp_dist_idx} | x)

        Parameters
        ----------
        observations : array(N, seq_len)
            Dataset
        comp_dist_idx : int
            Index of the mixture component

        Returns
        -------
        * array
            Responsibilities
        '''
        return self._model.posterior_marginal(observations).prob(comp_dist_idx)

    def responsibilities(self, observations):
        '''
        Computes the responsibilities, i.e. the posterior probabilities p(z|x)

        Parameters
        ----------
        observations : array(N, seq_len)
            Dataset

        Returns
        -------
        * array
            Responsibilities
        '''
        return self.model.posterior_marginal(observations).probs

    def _m_step(self, observations, S, eta):
        '''
        Maximization step

        Parameters
        ----------
        observations : array(N, seq_len)
            Dataset
        S : array
            Scale matrix of the prior p(theta); if given, MAP estimates are
            computed instead of maximum likelihood estimates
        eta : int
            Strength of the prior p(theta)

        Returns
        -------
        * array
            Mixing coefficients
        * array
            Means
        * array
            Covariances
        '''
        n_obs, dim = observations.shape

        def m_step_per_gaussian(responsibility):
            effective_prob = responsibility.sum()
            mean = (responsibility[:, None] * observations).sum(axis=0) / effective_prob
            centered_observations = observations - mean
            # Responsibility-weighted sum of outer products of the centered observations.
            covariance = responsibility[:, None, None] * jnp.einsum(
                "ij, ik->ikj", centered_observations, centered_observations)
            covariance = covariance.sum(axis=0)
            if eta is None:
                # Maximum likelihood estimate.
                covariance = covariance / effective_prob
            else:
                # MAP estimate under an inverse-Wishart prior.
                covariance = (S + covariance) / (eta + effective_prob + dim + 2)
            mixing_coeff = effective_prob / n_obs
            return (mixing_coeff, mean, covariance)

        # Map over components, i.e. axis 1 of the responsibility matrix p(z|x).
        mixing_coeffs, means, covariances = vmap(m_step_per_gaussian, in_axes=1)(
            self.responsibilities(observations))
        return mixing_coeffs, means, covariances

    def _add_final_values_to_history(self, history, observations):
        '''
        Appends the final values of the log likelihood, mixing coefficients,
        means, covariances and responsibilities to the history

        Parameters
        ----------
        history : tuple
            Log likelihoods, mixing coefficients, means, covariances and
            responsibilities found per iteration
        observations : array(N, seq_len)
            Dataset

        Returns
        -------
        * array
            Log likelihoods found per iteration
        * array
            Mixing coefficients found per iteration
        * array
            Means of the Gaussian distributions found per iteration
        * array
            Covariances of the Gaussian distributions found per iteration
        * array
            Responsibilities found per iteration
        '''
        ll_hist, mix_dist_probs_hist, comp_dist_loc_hist, comp_dist_cov_hist, responsibility_hist = history
        ll_hist = jnp.append(ll_hist, self.expected_log_likelihood(observations))
        mix_dist_probs_hist = jnp.vstack([mix_dist_probs_hist, self.mixing_coeffs])
        comp_dist_loc_hist = jnp.vstack([comp_dist_loc_hist, self.means[None, :]])
        comp_dist_cov_hist = jnp.vstack([comp_dist_cov_hist, self.covariances[None, :]])
        responsibility_hist = jnp.vstack(
            [responsibility_hist, jnp.array([self.responsibility(observations, 0)])])
        return (ll_hist, mix_dist_probs_hist, comp_dist_loc_hist, comp_dist_cov_hist,
                responsibility_hist)

    def fit_em(self, observations, num_of_iters, S=None, eta=None):
        '''
        Fits the model using the EM algorithm.

        Parameters
        ----------
        observations : array(N, seq_len)
            Dataset
        num_of_iters : int
            The number of iterations the training process takes
        S : array
            Scale matrix of the prior p(theta); if given, MAP estimates are
            computed instead of maximum likelihood estimates
        eta : int
            Strength of the prior p(theta)

        Returns
        -------
        * array
            Log likelihoods found per iteration
        * array
            Mixing coefficients found per iteration
        * array
            Means of the Gaussian distributions found per iteration
        * array
            Covariances of the Gaussian distributions found per iteration
        * array
            Responsibilities found per iteration
        '''
        initial_mixing_coeffs = self.mixing_coeffs
        initial_means = self.means
        initial_covariances = self.covariances
        iterations = jnp.arange(num_of_iters)

        def train_step(params, i):
            self.model = params
            log_likelihood = self.expected_log_likelihood(observations)
            responsibility = self.responsibility(observations, 0)
            mixing_coeffs, means, covariances = self._m_step(observations, S, eta)
            return (mixing_coeffs, means, covariances), (log_likelihood, *params, responsibility)

        initial_params = (initial_mixing_coeffs, initial_means, initial_covariances)
        final_params, history = scan(train_step, initial_params, iterations)
        self.model = final_params

        # The scan records quantities *before* each update, so append the
        # values obtained with the final parameters.
        history = self._add_final_values_to_history(history, observations)
        return history

    def _make_minibatches(self, observations, batch_size, rng_key):
        '''
        Creates minibatches consisting of a random permutation of the given
        observation sequences

        Parameters
        ----------
        observations : array(N, seq_len)
            Dataset
        batch_size : int
            The number of observation sequences included in each minibatch
        rng_key : array
            Random key of shape (2,) and dtype uint32

        Returns
        -------
        * array(num_batches, batch_size, max_len)
            Minibatches
        '''
        num_train = len(observations)
        perm = permutation(rng_key, num_train)

        def create_mini_batch(batch_idx):
            return observations[batch_idx]

        num_batches = num_train // batch_size
        # Drop the remainder so that the permutation reshapes evenly.
        batch_indices = perm[:num_batches * batch_size].reshape((num_batches, -1))
        minibatches = vmap(create_mini_batch)(batch_indices)
        return minibatches

    def _transform_to_covariance_matrix(self, sq_mat):
        '''
        Maps an unconstrained square matrix to a positive semi-definite one by
        taking its upper-triangular part U and returning U^T U; see
        https://ericmjl.github.io/notes/stats-ml/estimating-a-multivariate-gaussians-parameters-by-gradient-descent/

        Parameters
        ----------
        sq_mat : array
            Square matrix

        Returns
        -------
        * array
            Positive semi-definite matrix
        '''
        U = jnp.triu(sq_mat)
        U_T = jnp.transpose(U)
        return jnp.dot(U_T, U)

    def loss_fn(self, params, batch):
        '''
        Calculates the expected mean negative loglikelihood.

        Parameters
        ----------
        params : tuple
            Consists of the mixing coefficients' logits, the means and the
            unconstrained covariance factors of the Gaussian distributions,
            respectively.
        batch : array
            The subset of observations

        Returns
        -------
        * float
            Negative log likelihood
        '''
        mixing_coeffs, means, untransformed_cov = params
        cov_matrix = vmap(self._transform_to_covariance_matrix)(untransformed_cov)
        self.model = (softmax(mixing_coeffs), means, cov_matrix)
        return -self.expected_log_likelihood(batch) / len(batch)

    def update(self, i, opt_state, batch):
        '''
        Updates the optimizer state after taking the derivative

        Parameters
        ----------
        i : int
            The current iteration
        opt_state : jax.example_libraries.optimizers.OptimizerState
            The current state of the parameters
        batch : array
            The subset of observations

        Returns
        -------
        * jax.example_libraries.optimizers.OptimizerState
            The updated state
        * float
            Loss value calculated on the current batch
        '''
        params = get_params(opt_state)
        loss, grads = value_and_grad(self.loss_fn)(params, batch)
        return opt_update(i, grads, opt_state), loss

    def fit_sgd(self, observations, batch_size, rng_key=None, optimizer=None, num_epochs=3):
        '''
        Finds the parameters of the Gaussian Mixture Model using gradient
        descent with the given hyperparameters.

        Parameters
        ----------
        observations : array
            The observation sequences on which the Gaussian Mixture Model is trained
        batch_size : int
            The size of the batch
        rng_key : array
            Random key of shape (2,) and dtype uint32
        optimizer : jax.example_libraries.optimizers.Optimizer
            Optimizer to be used
        num_epochs : int
            The number of epochs the training process takes

        Returns
        -------
        * array
            Mean loss values found per epoch
        * array
            Mixing coefficients found per epoch
        * array
            Means of the Gaussian distributions found per epoch
        * array
            Covariances of the Gaussian distributions found per epoch
        * array
            Responsibilities found per epoch
        '''
        global opt_init, opt_update, get_params

        if rng_key is None:
            rng_key = PRNGKey(0)
        if optimizer is None:
            # Fall back to a default optimizer; the step size is an arbitrary choice.
            optimizer = adam(1e-3)
        opt_init, opt_update, get_params = optimizer

        # Optimize in unconstrained space: log mixing coefficients and
        # upper-triangular factors U, so that cov = U^T U is recovered by
        # `_transform_to_covariance_matrix`.
        unconstrained_covs = jnp.swapaxes(jnp.linalg.cholesky(self.covariances), -1, -2)
        opt_state = opt_init((jnp.log(self.mixing_coeffs), self.means, unconstrained_covs))
        itercount = itertools.count()

        def epoch_step(opt_state, key):

            def train_step(opt_state, batch):
                # NOTE: next(itercount) runs at trace time, so the step index is
                # constant within the scans; this skews Adam's bias correction.
                opt_state, loss = self.update(next(itercount), opt_state, batch)
                return opt_state, loss

            batches = self._make_minibatches(observations, batch_size, key)
            opt_state, losses = scan(train_step, opt_state, batches)

            params = get_params(opt_state)
            mixing_coeffs, means, untransformed_cov = params
            cov_matrix = vmap(self._transform_to_covariance_matrix)(untransformed_cov)
            self.model = (softmax(mixing_coeffs), means, cov_matrix)
            responsibilities = self.responsibilities(observations)
            return opt_state, (losses.mean(), *params, responsibilities)

        epochs = split(rng_key, num_epochs)
        opt_state, history = scan(epoch_step, opt_state, epochs)

        params = get_params(opt_state)
        mixing_coeffs, means, untransformed_cov = params
        cov_matrix = vmap(self._transform_to_covariance_matrix)(untransformed_cov)
        self.model = (softmax(mixing_coeffs), means, cov_matrix)
        return history

    def plot(self, observations, means=None, covariances=None, responsibilities=None,
             step=0.01, cmap="viridis", colors=None, ax=None):
        '''
        Plots the Gaussian Mixture Model.

        Parameters
        ----------
        observations : array
            Dataset
        means : array
        covariances : array
        responsibilities : array
        step : float
            Step size of the grid for the density contour.
        cmap : str
        colors : array
        ax : matplotlib.axes.Axes
        '''
        means = self.means if means is None else means
        covariances = self.covariances if covariances is None else covariances
        if responsibilities is None:
            responsibilities = self.model.posterior_marginal(observations).probs
        colors = uniform(PRNGKey(100), (means.shape[0], 3)) if colors is None else colors
        ax = ax if ax is not None else plt.subplots()[1]

        min_x, min_y = observations.min(axis=0)
        max_x, max_y = observations.max(axis=0)

        xs, ys = jnp.meshgrid(jnp.arange(min_x, max_x, step), jnp.arange(min_y, max_y, step))
        grid = jnp.vstack([xs.ravel(), ys.ravel()]).T

        def multivariate_normal(mean, cov):
            '''
            Initializes a multivariate normal distribution with the given mean
            and covariance. Note that the pdf has the same precision as its
            parameters' dtype.
            '''
            return tfp.substrates.jax.distributions.MultivariateNormalFullCovariance(
                loc=mean, covariance_matrix=cov)

        for (mean, cov), color in zip(zip(means, covariances), colors):
            normal_dist = multivariate_normal(mean, cov)
            density = normal_dist.prob(grid).reshape(xs.shape)
            # Wrap the RGB triple in a list so matplotlib does not mistake it
            # for three separate color specs.
            ax.contour(xs, ys, density, levels=1, colors=[tuple(color)], linewidths=5)

        ax.scatter(*observations.T, alpha=0.7, c=responsibilities, cmap=cmap, s=10)
        ax.set_xlim(min_x, max_x)
        ax.set_ylim(min_y, max_y)
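# A minimal usage sketch for `GMM` (not from the original source): fit a
# two-component model to synthetic 2-D data with EM. All shapes and
# hyperparameters are illustrative assumptions.
def _demo_gmm():
    key1, key2 = split(PRNGKey(0))
    # Two well-separated blobs of 100 points each.
    data = jnp.vstack([2.0 * uniform(key1, (100, 2)) - 5.0,
                       2.0 * uniform(key2, (100, 2)) + 5.0])
    gmm = GMM(mixing_coeffs=jnp.array([0.5, 0.5]),
              means=jnp.array([[-1.0, -1.0], [1.0, 1.0]]),
              covariances=jnp.stack([jnp.eye(2), jnp.eye(2)]))
    ll_hist, *_ = gmm.fit_em(data, num_of_iters=10)
    # EM should not decrease the log likelihood from one iteration to the next.
    print(ll_hist)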