def modular_layer(inputs, modules: ModulePool, parallel_count: int, context: ModularContext):
    """Route inputs through modules chosen per example and per parallel slot.

    A dense controller produces logits over ``modules.module_count`` modules
    for each of ``parallel_count`` slots. The module selection depends on
    ``context.mode``:

    * ``E_STEP``: the persisted best selection for each data index is used as
      the first of ``context.sample_size`` candidates; the remaining
      candidates are sampled from the controller.
    * ``M_STEP``: the persisted best selection is used directly.
    * ``EVALUATION``: the controller's mode is used.

    The selection, the persistent best-selection variable and the controller
    distribution are appended to ``context.layers``.

    Raises:
        ValueError: if ``context.mode`` is not one of the modes above.
    """
    with tf.variable_scope(None, 'modular_layer'):
        inputs = context.begin_modular(inputs)
        n_modules = modules.module_count

        flattened = tf.layers.flatten(inputs)
        dense_out = tf.layers.dense(flattened, n_modules * parallel_count)
        controller_logits = tf.reshape(dense_out, [-1, parallel_count, n_modules])
        ctrl = tfd.Categorical(controller_logits)

        # One stored module index per (dataset example, parallel slot),
        # initialised uniformly at random in [0, module_count).
        best_selection_persistent = tf.get_variable(
            'best_selection',
            [context.dataset_size, parallel_count],
            tf.int32,
            tf.random_uniform_initializer(maxval=n_modules, dtype=tf.int32))

        mode = context.mode
        if mode == ModularMode.E_STEP:
            # 1 x batch_size x parallel_count
            stored = tf.gather(best_selection_persistent, context.data_indices)[tf.newaxis]
            # sample_size x batch_size x parallel_count
            drawn = tf.reshape(ctrl.sample(), [context.sample_size, -1, parallel_count])
            # The stored best selection takes the place of the first drawn sample.
            candidates = tf.concat([stored, drawn[1:]], axis=0)
            selection = tf.reshape(candidates, [-1, parallel_count])
        elif mode == ModularMode.M_STEP:
            selection = tf.gather(best_selection_persistent, context.data_indices)
        elif mode == ModularMode.EVALUATION:
            selection = ctrl.mode()
        else:
            raise ValueError('Invalid modular mode')

        context.layers.append(ModularLayerAttributes(selection, best_selection_persistent, ctrl))
        return run_modules(inputs, selection, modules.module_fnc, modules.output_shape)
def decode(self, prev_state, prev_input, timestep):
    """Run one decoding step and sample the next input position.

    The cell runs on ``(prev_input, prev_state)``. Its output is scored
    against ``self.encoder_output`` by ``self.attention``, and a position is
    sampled from a Categorical over those scores (used as logits).

    Side effects: appends to ``self.mask_scores``, ``self.positions`` and
    ``self.log_softmax``, and adds ``one_hot(position)`` to ``self.mask``.

    Returns:
        ``(state, new_decoder_input)``, where the new input is gathered from
        ``self.h`` at the sampled position.
    """
    with tf.variable_scope("loop"):
        if timestep > 0:
            # Steps after the first share the variables created at step 0.
            tf.get_variable_scope().reuse_variables()

        cell_output, next_state = self.cell(prev_input, prev_state)
        scores = self.attention(self.encoder_output, cell_output)  # [batch_size, time_sequence]
        self.mask_scores.append(scores)

        position_dist = distr.Categorical(scores)
        chosen = tf.cast(position_dist.sample(), tf.int32)
        self.positions.append(chosen)
        # Log-probability of the sampled position, kept for backprop.
        self.log_softmax.append(position_dist.log_prob(chosen))

        self.mask = self.mask + tf.one_hot(chosen, self.seq_length)
        next_input = tf.gather(self.h, chosen)[0]
        return next_state, next_input
def GMM2Dslow(log_pis, mus, log_sigmas, corrs, clip_lo=-10, clip_hi=10):
    """Build a mixture of correlated 2-D Gaussians from flat parameter tensors.

    Each component k has mean ``tf.split(mus, GMM_c)[k]``, per-axis standard
    deviations ``exp(clip(log_sigmas))`` and correlation ``corrs[..., k]``.
    Its lower-triangular scale factor is built explicitly, and the
    components are combined with ``distributions.Mixture``.

    Args:
        log_pis: mixture logits, shape [..., GMM_c].
        mus: component means, shape [..., GMM_c*state_dim].
        log_sigmas: log standard deviations, shape [..., GMM_c*state_dim];
            clipped to [clip_lo, clip_hi] before exponentiation.
        corrs: correlation coefficients, shape [..., GMM_c].
        clip_lo: lower clipping bound for ``log_sigmas``.
        clip_hi: upper clipping bound for ``log_sigmas``.

    Returns:
        A ``distributions.Mixture`` of ``GMM_c`` ``MultivariateNormalTriL``
        components.
    """
    # shapes
    # pis: [..., GMM_c]
    # mus: [..., GMM_c*state_dim]
    # sigmas: [..., GMM_c*state_dim]
    # corrs: [..., GMM_c]
    GMM_c = log_pis.shape[-1]
    mus_split = tf.split(mus, GMM_c, axis=-1)
    sigmas = tf.exp(tf.clip_by_value(log_sigmas, clip_lo, clip_hi))

    # Sigma = [s1^2    p*s1*s2      L = [s1   0
    #          p*s1*s2 s2^2   ]          p*s2 sqrt(1-p^2)*s2]
    # Reshape the last axis to (GMM_c, 2), using the static shape with unknown
    # dims as -1; this requires state_dim == 2.
    sigmas_reshaped = tf.reshape(
        sigmas,
        [-1 if s.value is None else s.value for s in sigmas.shape[:-1]] + [GMM_c.value, 2])
    # Stacking on the last axis makes each entry below a COLUMN of L,
    # giving shape [..., GMM_c, 2, 2].
    Ls = tf.stack(
        [
            (sigmas_reshaped * tf.stack([tf.ones_like(corrs), corrs], -1)),  # [s1, p*s2]
            (sigmas_reshaped * tf.stack([tf.zeros_like(corrs), tf.sqrt(1 - corrs**2)], -1))
        ],  # [0, sqrt(1-p^2)*s2]
        axis=-1)
    # List of GMM_c tensors, each [..., 2, 2]: one scale factor per component.
    Ls_split = tf.unstack(Ls, axis=-3)

    cat = distributions.Categorical(logits=log_pis)
    dists = [
        distributions.MultivariateNormalTriL(mu, L)
        for mu, L in zip(mus_split, Ls_split)
    ]
    return distributions.Mixture(cat, dists)
def radial_gaussians(batch_size, n_mixture=8, std=0.01, radius=1.0, add_far=False):
    """Sample from an equally weighted ring of 2-D Gaussians.

    The component means sit at ``n_mixture`` evenly spaced angles on a circle
    of ``radius``, starting at ``(radius, 0)``. Every component has per-axis
    standard deviation ``std``.

    Args:
        batch_size: number of samples to draw.
        n_mixture: number of components on the ring.
        std: per-axis standard deviation of each component.
        radius: radius of the ring.
        add_far: accepted for signature compatibility; currently unused.

    Returns:
        The tensor returned by ``Mixture.sample(batch_size)``.
    """
    angles = np.linspace(0, 2 * np.pi, n_mixture + 1)[:-1]
    centers_x = radius * np.cos(angles)
    centers_y = radius * np.sin(angles)
    # All-zero logits give uniform mixture weights.
    mixing = ds.Categorical(tf.zeros(n_mixture))
    components = [
        ds.MultivariateNormalDiag([cx, cy], [std, std])
        for cx, cy in zip(centers_x.ravel(), centers_y.ravel())
    ]
    return ds.Mixture(mixing, components).sample(batch_size)
def _make_priors(self, time_step, prior_conditioning):
    """Instantiates prior distributions for discovery.

    Returns:
        ``(what_prior, where_prior, num_steps_prior)``. The first two come
        from ``self``. ``num_steps_prior`` depends on
        ``self._disc_prior_type``:

        * ``'geom'``: ``Geometric`` with success probability
          ``1 - self._init_disc_step_success_prob``.
        * ``'cat'``: ``Categorical`` over ``self._n_steps + 1`` classes.
          Its logits are built from a trainable bias, plus a second trainable
          bias when ``time_step != 0``, plus an optional MLP term on
          ``prior_conditioning``; an ELU is then applied.

    Raises:
        ValueError: for any other prior type.
    """
    is_first_timestep = tf.to_float(tf.equal(time_step, 0))
    prior_type = self._disc_prior_type

    if prior_type == 'geom':
        num_steps_prior = tfd.Geometric(probs=1. - self._init_disc_step_success_prob)
    elif prior_type == 'cat':
        n_classes = self._n_steps + 1
        step_logits = tf.Variable([0.] * n_classes, trainable=True, dtype=tf.float32,
                                  name='step_prior_bias')

        # Extra bias (initially +10 on "zero steps") active only when t > 0.
        timestep_bias = tf.Variable([10.] + [0] * self._n_steps, trainable=True,
                                    dtype=tf.float32, name='step_prior_timestep_bias')
        step_logits += (1. - is_first_timestep) * timestep_bias

        if prior_conditioning is not None:
            batched = tf.expand_dims(step_logits, 0)
            step_logits = batched + MLP(10, n_out=n_classes)(prior_conditioning)

        step_logits = tf.nn.elu(step_logits)
        num_steps_prior = tfd.Categorical(logits=step_logits)
    else:
        raise ValueError('Invalid prior type: {}'.format(prior_type))

    return self._what_prior, self._where_prior, num_steps_prior
def output_function(self, state):
    """Sample one output step from the mixture-density head.

    ``state.h3`` goes through the ``'gmm'`` dense layer and is parsed into
    mixture weights, means, standard deviations, correlations and ``es``.
    Each component is a 2-D Gaussian with covariance
    ``[[s1^2, rho*s1*s2], [rho*s1*s2, s2^2]]``. For each batch row, one
    component index is sampled and the 2-D sample of that component is kept.

    Returns:
        ``concat([coords, float(e)], axis=1)``, where ``e`` is a Bernoulli
        sample with probability ``es``.
    """
    params = dense_layer(state.h3, self.output_units, scope='gmm', reuse=tf.compat.v1.AUTO_REUSE)
    pis, mus, sigmas, rhos, es = self._parse_parameters(params)

    mu_x, mu_y = tf.split(mus, 2, axis=1)
    means = tf.stack([mu_x, mu_y], axis=2)

    sigma_x, sigma_y = tf.split(sigmas, 2, axis=1)
    flat_cov = tf.stack(
        [
            tf.square(sigma_x),
            rhos * sigma_x * sigma_y,
            rhos * sigma_x * sigma_y,
            tf.square(sigma_y),
        ],
        axis=2)
    cov = tf.reshape(flat_cov, (self.batch_size, self.num_output_mixture_components, 2, 2))

    component_dist = tfd.MultivariateNormalFullCovariance(loc=means, covariance_matrix=cov)
    flag_dist = tfd.Bernoulli(probs=es)
    choice_dist = tfd.Categorical(probs=pis)

    flag = flag_dist.sample()
    every_component = component_dist.sample()
    chosen = choice_dist.sample()

    # (row, chosen component) pairs select one 2-D sample per batch row.
    gather_idx = tf.stack([tf.range(self.batch_size), chosen], axis=1)
    coords = tf.gather_nd(every_component, gather_idx)
    return tf.concat([coords, tf.cast(flag, tf.float32)], axis=1)
def sample(self, time, outputs, state, name=None):
    """Returns `sample_ids`.

    ``outputs`` is reshaped to [batch, n_features, 3 * n_mixtures] and split
    into three equal parts: locations, unconstrained scales and mixture
    logits. Each component is a Logistic shifted by -0.5 and quantized to
    the integers in [0, 2**16 - 1]. The mixture sample is divided by
    2**16 - 1 and clipped to [-1, 1].
    """
    del time, state
    max_level = 2.0**16 - 1.0
    with tf.variable_scope('dml'):
        shaped = tf.reshape(
            outputs, [self._batch_size, self._n_features, self._n_mixtures * 3])
        loc, unconstrained_scale, logits = tf.split(shaped, num_or_size_splits=3, axis=-1)

        loc = tf.minimum(tf.nn.softplus(loc), max_level)
        clipped_scale = tf.maximum(1e-8, tf.nn.softplus(unconstrained_scale))
        scale = tf.minimum(14.0, clipped_scale)

        logistic = tfd.Logistic(loc=loc, scale=scale)
        half_shift = tfb.AffineScalar(shift=-0.5)
        shifted = tfd.TransformedDistribution(distribution=logistic, bijector=half_shift)
        discretized = tfd.QuantizedDistribution(distribution=shifted, low=0., high=2**16 - 1.)

        mixing = tfd.Categorical(logits=logits)
        mixture = tfd.MixtureSameFamily(
            mixture_distribution=mixing, components_distribution=discretized)

        normalized = tf.maximum(-1.0, mixture.sample() / max_level)
        return tf.minimum(1.0, normalized)
def _reduce_argmax(self, x):
    """Reduce ``x`` by argmax over ``self._reduce_axis``, breaking ties at random.

    Args:
        x (Tensor): A ``Tensor`` of shape [batch, num_sums, max_sum_size] to reduce over the
            last axis.

    Returns:
        A ``Tensor`` reduced over the last axis.
    """
    if conf.argmax_zero:
        # Plain tf.argmax: ties always resolve towards the zeroth index.
        return tf.argmax(x, axis=self._reduce_axis)

    axis = self._reduce_axis
    # 1.0 wherever x attains its maximum along the reduced axis, else 0.0.
    is_max = tf.to_float(tf.equal(x, tf.expand_dims(self._reduce_mpe_inference(x), axis)))
    if self._masked:
        mask = tf.expand_dims(tf.to_float(self._build_mask()), axis=self._batch_axis)
        is_max *= mask
    # Uniform weights over the tied maxima; sampling picks one of them.
    tie_probs = is_max / tf.reduce_sum(is_max, axis=axis, keepdims=True)
    picker = tfd.Categorical(probs=tie_probs, name="StochasticArgMax", dtype=tf.int64)
    return picker.sample()
def create_gmm_1(d, K, name='gmm', reuse=False, scale_act=tf.nn.softplus, zero_mean=False, ki=None):
    """Create a factorised mixture of 1-D Gaussians over ``d`` dimensions.

    Each of the ``d`` dimensions is a ``K``-component mixture of Normals.
    The weights are the softmax of the trainable ``probs`` variable, the
    locations come from ``locs``, and the scales are ``scale_act(scales)``.
    ``Independent(..., 1)`` treats the ``d`` mixtures as one distribution
    over ``d``-dimensional vectors.

    Args:
        d: number of dimensions.
        K: number of mixture components per dimension.
        name: variable scope name.
        reuse: reuse flag for the variable scope (True to reuse variables
            already created under ``name``).
        scale_act: maps raw ``scales`` to positive standard deviations.
        zero_mean: if True, locations are fixed to zero (the ``locs``
            variable is still created).
        ki: unused.

    Returns:
        A ``tf.contrib.distributions.Independent`` distribution.
    """
    # ``reuse`` must be passed by keyword: the second positional parameter of
    # tf.variable_scope is ``default_name``, not ``reuse``.
    with tf.variable_scope(name, reuse=reuse):
        logits = tf.get_variable('probs', shape=[d, K], dtype=DTYPE, initializer=None)
        probs = tf.nn.softmax(logits, axis=-1)
        locs = tf.get_variable('locs', shape=[d, K], dtype=DTYPE, initializer=None)
        if zero_mean:
            locs = tf.zeros_like(locs)
        scales = tf.get_variable('scales', shape=[d, K], dtype=DTYPE, initializer=None)
        pis = tfd.Categorical(probs=probs)
        ps = tfd.Normal(loc=locs, scale=scale_act(scales))
        p = tf.contrib.distributions.MixtureSameFamily(pis, ps)
        p = tf.contrib.distributions.Independent(p, 1)
    return p
def sampled_evaluations(self, reducing_function):
    """Evaluate mean log-likelihood and accuracy in SAMPLES_EVALUATION mode.

    ``reducing_function(softmax_probs, context)`` must return the class
    probabilities used to build the output ``tfd.Categorical``.

    Returns:
        ``(llh, acc)``.
    """
    context = ModularContext(ModularMode.SAMPLES_EVALUATION)
    class_probs = tf.nn.softmax(self.logits(context))
    predictive = tfd.Categorical(probs=reducing_function(class_probs, context))
    mean_llh = tf.reduce_mean(self.llh(predictive, context))
    return mean_llh, self.accuracy(predictive, context)
def _build(self, inputs, nb_samples=10, seed=0, encoder_param_type='natural'):
    """Encode ``inputs``, combine with the GMM prior, sample latents and decode.

    Args:
        inputs: batch fed to ``self._encoder``.
        nb_samples: number of latent samples drawn per mixture component.
        seed: forwarded to ``subsample_x``.
        encoder_param_type: if ``'natural'``, the output of ``self._sigma_net``
            is multiplied by -1/2 before it becomes the diagonal of the second
            encoder natural parameter; otherwise it is used unchanged.

    Returns:
        ``(output_distribution, posterior_mixture_distribution,
        latent_k_samples, latent_samples, log_z_given_y_phi)``.
    """
    ### vae encode
    emb = self._encoder(inputs)
    enc_eta1 = self._mu_net(emb)
    enc_eta2_diag = self._sigma_net(emb)
    if encoder_param_type == 'natural':
        enc_eta2_diag *= -1. / 2
    enc_eta2 = tf.matrix_diag(enc_eta2_diag)

    ### GMM natural parameters
    gmm_pi, gmm_eta1, gmm_eta2 = self.phi_gmm()

    ### combined GMM and VAE latent parameters
    # eta1_tilde.shape = (N, K, D); eta2_tilde.shape = (N, K, D, D)
    eta1_tilde = tf.expand_dims(enc_eta1, axis=1) + tf.expand_dims(gmm_eta1, axis=0)
    eta2_tilde = tf.expand_dims(enc_eta2, axis=1) + tf.expand_dims(gmm_eta2, axis=0)

    log_z_given_y_phi = compute_log_z_given_y(enc_eta1, enc_eta2, gmm_eta1, gmm_eta2, gmm_pi)

    mu, cov = gaussian.natural_to_standard(eta1_tilde, eta2_tilde)
    # The mixture weights are intended to be exp(log_z_given_y_phi), so the
    # log-weights are passed as ``logits``. The first positional parameter of
    # Categorical is ``logits``; passing tf.exp(...) positionally would weight
    # the components by softmax(exp(log w)) instead of by w.
    posterior_mixture_distribution = tfd.MixtureSameFamily(
        mixture_distribution=tfd.Categorical(logits=log_z_given_y_phi),
        components_distribution=tfd.MultivariateNormalFullCovariance(
            loc=mu, covariance_matrix=cov))

    # sample x for each of the K components
    # latent_k_samples.shape == nb_samples, batch_size, nb_components, latent_dim
    latent_k_samples = posterior_mixture_distribution.components_distribution.sample(
        [nb_samples])

    ### vae decode
    output_mean = snt.BatchApply(self._decoder, n_dims=3)(latent_k_samples)
    # Learned parameter for the output distribution. Despite its name, it is
    # used as the pre-softplus diagonal scale below.
    output_variance = tf.get_variable(
        'output_variance', dtype=tf.float32,
        initializer=tf.zeros(output_mean.get_shape().as_list()), trainable=True)
    output_distribution = tfd.Independent(
        tfd.MultivariateNormalDiagWithSoftplusScale(
            loc=output_mean, scale_diag=output_variance),
        reinterpreted_batch_ndims=2)

    # subsample for each datum in minibatch (go from `nb_samples` per component to `nb_samples` total)
    latent_samples = subsample_x(
        tf.transpose(latent_k_samples, [1, 0, 2, 3]), log_z_given_y_phi, seed)
    return output_distribution, posterior_mixture_distribution, latent_k_samples, latent_samples, log_z_given_y_phi
def e_step(self):
    """Build the E-step op that stores the best module selection per example.

    In E_STEP mode each example has ``self.config.sample_size`` candidate
    selections. Each candidate is scored by its selection log-probability
    plus the output log-likelihood. The highest-scoring candidate index is
    passed (without gradient) to ``context.update_best_selection``.
    """
    context = ModularContext(ModularMode.E_STEP, self.data_indices, self.dataset_size,
                             self.config.sample_size)
    # batch_size * sample_size
    output_llh = self.llh(tfd.Categorical(self.logits(context)), context)
    scores = context.selection_logprob() + output_llh
    per_candidate = tf.reshape(scores, [self.config.sample_size, -1])
    winners = tf.stop_gradient(tf.argmax(per_candidate, axis=0))
    return context.update_best_selection(winners)
def swiss(batch_size, size=1., std=0.01):
    """Sample from a 2-D swiss-roll shaped mixture of Gaussians.

    1000 points come from ``datasets.make_swiss_roll``. Columns 0 and 2 are
    scaled by ``size / max(columns 0 and 2)`` and become the means of an
    equally weighted mixture of diagonal Gaussians with std ``std``.
    """
    roll, _ = datasets.make_swiss_roll(1000)
    norm = roll[:, ::2].max()
    first_coord = roll[:, 0] * size / norm
    third_coord = roll[:, 2] * size / norm
    # Zero logits: every point gets the same mixture weight.
    mixing = ds.Categorical(tf.zeros(len(roll)))
    components = [
        ds.MultivariateNormalDiag([cx, cy], [std, std])
        for cx, cy in zip(first_coord.ravel(), third_coord.ravel())
    ]
    return ds.Mixture(mixing, components).sample(batch_size)
def radial_gaussians2(batch_size, n_mixture=8, std=0.01, r1=1.0, r2=2.0):
    """Sample from two concentric rings of equally weighted 2-D Gaussians.

    Each ring has ``n_mixture`` means at evenly spaced angles, placed at
    ``(r*sin(theta), r*cos(theta))`` with ``r`` equal to ``r1`` (inner ring)
    and ``r2`` (outer ring). All components have per-axis std ``std``.
    """
    angles = np.linspace(0, 2 * np.pi, n_mixture + 1)[:-1]
    # Row 0: inner ring, row 1: outer ring; ravel() keeps that order.
    xs = np.vstack([r1 * np.sin(angles), r2 * np.sin(angles)])
    ys = np.vstack([r1 * np.cos(angles), r2 * np.cos(angles)])
    mixing = ds.Categorical(tf.zeros(n_mixture * 2))
    components = [
        ds.MultivariateNormalDiag([cx, cy], [std, std])
        for cx, cy in zip(xs.ravel(), ys.ravel())
    ]
    return ds.Mixture(mixing, components).sample(batch_size)
def kl_categorical(p=None, q=None, p_logits=None, q_logits=None, eps=1e-6):
    '''Return KL(P || Q) between two Categorical distributions.

    Provide EITHER both ``p_logits`` and ``q_logits`` OR both ``p`` and ``q``
    (probabilities). In the probability case ``eps`` is added to every
    probability to avoid log(0) / divide-by-zero issues.

    Raises:
        Exception: if neither complete pair is provided.
    '''
    if p_logits is not None and q_logits is not None:
        Q = distributions.Categorical(logits=q_logits, dtype=tf.float32)
        P = distributions.Categorical(logits=p_logits, dtype=tf.float32)
    elif p is not None and q is not None:
        # A single parenthesised argument prints the same text under Python 2
        # and Python 3; the bare ``print`` statement is a Python 3 SyntaxError.
        print('p shp =  {}  | q shp =  {}'.format(
            p.get_shape().as_list(), q.get_shape().as_list()))
        Q = distributions.Categorical(probs=q + eps, dtype=tf.float32)
        P = distributions.Categorical(probs=p + eps, dtype=tf.float32)
    else:
        raise Exception("please provide either logits or dists")
    return distributions.kl_divergence(P, Q)
def line_1d(batch_size, n_mixture=5, std=0.01, d=1.0, add_far=False):
    """Sample from an equally weighted mixture of 1-D Gaussians on a line.

    Means are ``n_mixture`` evenly spaced float32 points in [-d, d]. With
    ``add_far``, two extra components at -10*d and 10*d are added. All
    components share standard deviation ``std`` and have zero logits
    (equal weights).

    Returns:
        The tensor returned by ``Mixture.sample(batch_size)``.
    """
    xs = np.linspace(-d, d, n_mixture, dtype=np.float32)
    logits = [0.] * n_mixture
    if add_far:
        # Keep float32: concatenating default (float64) arrays would upcast
        # xs, so the component dtype would depend on ``add_far``.
        far_left = np.array([-10 * d], dtype=np.float32)
        far_right = np.array([10 * d], dtype=np.float32)
        xs = np.concatenate([far_left, xs, far_right], 0)
        logits = [0.] + logits + [0.]
    cat = ds.Categorical(tf.constant(logits))
    comps = [ds.MultivariateNormalDiag([xi], [std]) for xi in xs.ravel()]
    data = ds.Mixture(cat, comps)
    return data.sample(batch_size)
def rect(batch_size, std=0.01, nx=5, ny=5, h=2, w=2):
    """Sample from an equally weighted rectangular grid of 2-D Gaussians.

    The first coordinate takes ``nx`` values in [-h, h] and the second takes
    ``ny`` values in [-w, w]. Every grid point is the mean of a diagonal
    Gaussian with std ``std``.
    """
    first_axis = np.linspace(-h, h, nx)
    second_axis = np.linspace(-w, w, ny)
    centers = [(a, b) for a in first_axis for b in second_axis]
    mixing = ds.Categorical(tf.zeros(len(centers)))
    components = [ds.MultivariateNormalDiag([a, b], [std, std]) for a, b in centers]
    return ds.Mixture(mixing, components).sample(batch_size)
def mode_evaluations(self):
    """Evaluate the model in MODE_EVALUATION mode.

    Returns:
        ``(evaluations, module_proportions)``. ``evaluations`` maps metric
        names to mean log-likelihood, accuracy, selection entropy and
        batch selection entropy.
    """
    context = ModularContext(ModularMode.MODE_EVALUATION)
    predictive = tfd.Categorical(logits=self.logits(context))

    evaluations = {}
    evaluations["loglikelihood/mode"] = tf.reduce_mean(self.llh(predictive, context))
    evaluations["accuracy/mode"] = self.accuracy(predictive, context)
    evaluations["entropy/selection"] = context.selection_entropy()
    evaluations["entropy/batch"] = context.batch_selection_entropy()

    return evaluations, context.module_proportions()
def ring_mog(batch_size, n_mixture=8, std=0.01, radius=1.0):
    """Sample from a ring of equally weighted 2-D Gaussians.

    Means are at ``(radius*sin(theta), radius*cos(theta))`` (float32) for
    ``n_mixture`` evenly spaced angles in [0, 2*pi).

    Returns:
        ``(samples, means)``, where ``means`` is an [n_mixture, 2] array.
    """
    angles = np.linspace(0, 2 * np.pi, n_mixture, endpoint=False)
    xs = radius * np.sin(angles, dtype=np.float32)
    ys = radius * np.cos(angles, dtype=np.float32)
    mixing = ds.Categorical(tf.zeros(n_mixture))
    components = [
        ds.MultivariateNormalDiag([cx, cy], [std, std])
        for cx, cy in zip(xs.ravel(), ys.ravel())
    ]
    samples = ds.Mixture(mixing, components).sample(batch_size)
    return samples, np.stack([xs, ys], axis=1)
def decode_softmax(self, prev_state_0, prev_state_1, prev_input, position, mask):
    """Return the log-probability of ``position`` under the attention scores.

    Reuses the variables of the ``"loop"`` scope. ``prev_state_0`` and
    ``prev_state_1`` are not used. ``self.mask`` is set to ``mask`` before
    ``self.attention`` is called (presumably attention reads it — confirm).
    """
    with tf.variable_scope("loop"):
        tf.get_variable_scope().reuse_variables()
        hidden = self.mlp(prev_input)
        self.mask = mask
        scores = self.attention(self.encoder_output_ex, hidden)  # [batch_size, time_sequence]
        return distr.Categorical(scores).log_prob(position)
def m_step(self):
    """Build the M-step training op.

    Minimises ``-mean(selection_logprob) - mean(llh)`` with the
    ``tf.train`` optimizer named by ``self.config.optimizer``, using
    ``self.config.learning_rate``.
    """
    context = ModularContext(ModularMode.M_STEP, self.data_indices, self.dataset_size)
    output_llh = self.llh(tfd.Categorical(self.logits(context)), context)
    ctrl_loss = -tf.reduce_mean(context.selection_logprob())
    module_loss = -tf.reduce_mean(output_llh)
    total_loss = ctrl_loss + module_loss

    optimizer_cls = getattr(tf.train, self.config.optimizer)
    return optimizer_cls(self.config.learning_rate).minimize(total_loss)
def grid_mog(batch_size, n_mixture=25, std=0.05, space=2.0):
    """Sample from an equally weighted square grid of 2-D Gaussians.

    The grid has ``sqrt(n_mixture)`` points per side, centred on the origin,
    with adjacent modes ``space`` apart. Each mode is a diagonal Gaussian
    with std ``std``.

    Returns:
        ``(samples, modes)``, where ``modes`` is an [n_mixture, 2] float32
        array.

    Raises:
        ValueError: if ``n_mixture`` is not a perfect square. The grid would
            otherwise have fewer modes than mixture weights.
    """
    grid_range = int(np.sqrt(n_mixture))
    if grid_range * grid_range != n_mixture:
        raise ValueError(
            'grid_mog requires a perfect-square n_mixture, got {}'.format(n_mixture))
    offsets = range(-grid_range + 1, grid_range, 2)
    modes = np.array(
        [np.array([i, j]) for i, j in itertools.product(offsets, offsets)],
        dtype=np.float32)
    modes = modes * space / 2.
    cat = ds.Categorical(tf.zeros(n_mixture))
    comps = [ds.MultivariateNormalDiag(mu, [std, std]) for mu in modes]
    data = ds.Mixture(cat, comps)
    return data.sample(batch_size), modes
def sample(self, inputs, seed=None):
    """Pick one input column per output unit by sampling from ``self.weights[0]``.

    ``inputs`` is concatenated along axis 2. The result is assumed to be
    [batch, C, N] — TODO confirm. For every batch row, one index is sampled
    for each of the ``self.size`` output units from a Categorical with
    logits ``transpose(self.weights[0])``. The same index is used at every
    position along axis 1.

    Returns:
        Tensor with ``result[b, c, k] = inputs[b, c, idx[b, k]]``, of shape
        [batch, C, self.size].
    """
    inputs = tf.concat(inputs, 2)
    logits = tf.transpose(self.weights[0])
    dist = dists.Categorical(logits=logits)
    # [batch, size]: one sampled index per output unit.
    indices = dist.sample([inputs.shape[0]], seed=seed)
    # Repeat along a new axis 1 -> [batch, C, size]. Tiling along the last
    # axis and then reshaping would interleave the indices of different
    # output units whenever C > 1.
    indices = tf.tile(tf.expand_dims(indices, 1), [1, inputs.shape[1], 1])
    # With 'xy' indexing, meshgrid returns [batch, C, size] grids where
    # others[0] varies along axis 1 (C) and others[1] along axis 0 (batch).
    others = tf.meshgrid(tf.range(inputs.shape[1]), tf.range(inputs.shape[0]), tf.range(self.size))
    indices = tf.stack([others[1], others[0], indices], axis=-1)
    result = tf.gather_nd(inputs, indices)
    return result
def input_tensor(batch_size, return_mixture=False):
    """Sample from an equal-weight mixture of four 2-D Gaussians.

    Components are centred at (5, 5), (-5, 5), (-5, -5) and (5, -5), each with
    per-axis standard deviation 0.5.

    Returns:
        ``(samples, mixture)`` if ``return_mixture`` is true, else ``samples``.
    """
    gaussians = [
        ds.MultivariateNormalDiag(loc=(5.0, 5.0), scale_diag=(0.5, 0.5)),
        ds.MultivariateNormalDiag(loc=(-5.0, 5.0), scale_diag=(0.5, 0.5)),
        ds.MultivariateNormalDiag(loc=(-5.0, -5.0), scale_diag=(0.5, 0.5)),
        ds.MultivariateNormalDiag(loc=(5.0, -5.0), scale_diag=(0.5, 0.5)),
    ]
    # Float division keeps the weights non-zero under Python 2. They are
    # passed as ``probs`` because Categorical's first positional parameter
    # is ``logits``.
    uniform_mixture_probs = [1.0 / len(gaussians)] * len(gaussians)
    mixture = ds.Mixture(cat=ds.Categorical(probs=uniform_mixture_probs), components=gaussians)
    sampled = mixture.sample(batch_size)
    return (sampled, mixture) if return_mixture else sampled
def __init__(self, nb_components, dimension, mu_init=None, cov_init=None, trainable=False, name='gmm'):
    """Gaussian mixture with ``nb_components`` full-covariance components.

    Args:
        nb_components: number of mixture components.
        dimension: dimensionality of each component.
        mu_init: optional [nb_components, dimension] initial means.
        cov_init: optional [nb_components, dimension, dimension] initial
            covariances; its Cholesky factor initialises ``_L_k_raw``.
        trainable: whether the created variables are trainable.
        name: module name.

    ``self.model`` is a ``MixtureSameFamily`` with logits ``self.pi``, means
    ``self.mu`` and covariances ``L L^T`` where ``L = self._L_k_raw``.
    """
    super(GMM, self).__init__(name=name)
    with self._enter_variable_scope():
        # Unnormalised mixture logits.
        self.pi = tf.get_variable("pi", shape=(nb_components,), dtype=tf.float32, trainable=trainable)

        if mu_init is not None:
            assert mu_init.get_shape().as_list() == [nb_components, dimension]
            self.mu = tf.get_variable("mixture_mu", initializer=mu_init,
                                      dtype=tf.float32, trainable=trainable)
        else:
            self.mu = tf.get_variable("mixture_mu", shape=(nb_components, dimension),
                                      dtype=tf.float32, trainable=trainable)

        if cov_init is not None:
            assert cov_init.get_shape().as_list() == [nb_components, dimension, dimension]
            self._L_k_raw = tf.get_variable(
                "mixture_lower_cov", initializer=tf.cholesky(cov_init),
                dtype=tf.float32, trainable=trainable)
        else:
            self._L_k_raw = tf.get_variable(
                "mixture_lower_cov", shape=(nb_components, dimension, dimension),
                dtype=tf.float32, trainable=trainable)

        # _L_k_raw holds a Cholesky-style factor L (see cholesky(cov_init)
        # above), so the covariance is L L^T. Using L itself as the covariance
        # would not reproduce cov_init.
        covariance = tf.matmul(self._L_k_raw, self._L_k_raw, transpose_b=True)
        self.model = tfd.MixtureSameFamily(
            mixture_distribution=tfd.Categorical(logits=self.pi),
            components_distribution=tfd.MultivariateNormalFullCovariance(
                loc=self.mu, covariance_matrix=covariance))
def create_mixgaussian2D(num_components=8, std=0.05):
    """Build (not sample) an equally weighted ring of 2-D Gaussians.

    Means are ``(cos(theta), sin(theta))`` for ``num_components`` evenly
    spaced angles in [0, 2*pi) on the unit circle. Each component has
    per-axis standard deviation ``std`` (float32).

    Returns:
        A ``ds.Mixture`` distribution.
    """
    mixing = ds.Categorical(tf.zeros(num_components, dtype=tf.float32))
    angles = np.linspace(0, 2 * np.pi, num_components, endpoint=False)
    centers = np.array([(np.cos(a), np.sin(a)) for a in angles], dtype=np.float32)
    spread = np.array([std, std]).astype(np.float32)
    components = [ds.MultivariateNormalDiag(center, spread) for center in centers]
    return ds.Mixture(mixing, components)
def GMMdiag(log_pis, mus, log_sigmas, clip_lo=-10, clip_hi=10):
    """Build a mixture of diagonal Gaussians from flat parameter tensors.

    Args:
        log_pis: mixture logits, shape [..., GMM_c].
        mus: means, shape [..., GMM_c*state_dim]; split into GMM_c chunks.
        log_sigmas: log standard deviations with the same layout as ``mus``;
            clipped to [clip_lo, clip_hi] before exponentiation.
        clip_lo: lower clipping bound for ``log_sigmas``.
        clip_hi: upper clipping bound for ``log_sigmas``.

    Returns:
        A ``distributions.Mixture`` of ``MultivariateNormalDiag`` components.
    """
    n_components = log_pis.shape[-1]
    last_axis = len(mus.shape) - 1
    means = tf.split(mus, n_components, axis=last_axis)
    stddevs = tf.exp(tf.clip_by_value(log_sigmas, clip_lo, clip_hi))
    stddev_chunks = tf.split(stddevs, n_components, axis=last_axis)
    mixing = distributions.Categorical(logits=log_pis)
    components = [
        distributions.MultivariateNormalDiag(mean, scale)
        for mean, scale in zip(means, stddev_chunks)
    ]
    return distributions.Mixture(mixing, components)
def sample(self, time, outputs, state, name=None):
    """Returns `sample_ids`.

    Decodes ``outputs`` into a mixture of ``M = self._n_mixtures`` diagonal
    Gaussians over ``F = self._n_features`` dimensions and draws one sample.
    The columns of ``outputs`` are laid out as follows:

    * ``[0, F*M)``: means, reshaped to [batch, F, M];
    * ``[F*M, 2*F*M)``: pre-softplus scales, clipped to [0.1, 10000];
    * ``[2*F*M, 2*F*M + M)``: mixture logits, turned into weights by softmax.
    """
    del time, state
    batch, n_feat, n_mix = self._batch_size, self._n_features, self._n_mixtures
    block = n_feat * n_mix
    with tf.variable_scope('mdn'):
        means = tf.reshape(
            tf.slice(outputs, [0, 0], [batch, block]),
            [batch, n_feat, n_mix], name='means')

        raw_sigmas = tf.reshape(
            tf.slice(outputs, [0, block], [batch, block], name='sigmas_pre_norm'),
            [batch, n_feat, n_mix])
        sigmas = tf.minimum(10000.0, tf.maximum(1e-1, tf.nn.softplus(raw_sigmas)))

        raw_weights = tf.reshape(
            tf.slice(outputs, [0, 2 * block], [batch, n_mix], name='weights_pre_norm'),
            [batch, n_mix])
        weights = tf.nn.softmax(raw_weights, name='weights')

        components = [
            tfd.MultivariateNormalDiag(loc=means[:, :, k], scale_diag=sigmas[:, :, k])
            for k in range(n_mix)
        ]
        gauss = tfd.Mixture(cat=tfd.Categorical(probs=weights), components=components)
        return gauss.sample()
def get_mixture_model(self):
    """Return an equally weighted mixture of ``self.num_gaussians`` Gaussians.

    ``ds.Mixture`` needs a Categorical distribution to choose which component
    generates each sample. Component ``i`` is a ``MultivariateNormalTriL``
    with mean ``self.get_mus()[i]`` and scale ``self.get_scale_matrices()[i]``.
    """
    n = self.num_gaussians
    weights = [1. / n] * n
    centers = self.get_mus()
    tril_scales = self.get_scale_matrices()
    components = [
        ds.MultivariateNormalTriL(loc=centers[k], scale_tril=tril_scales[k])
        for k in range(n)
    ]
    return ds.Mixture(cat=ds.Categorical(probs=weights), components=components)
def get_mixture(j, xoo, sx_defalut):
    """Build a ``tfd.Mixture`` with one component per entry of ``xoo``.

    For each index ``i``, ``get_mix(1/len(xoo), xoo[i], sx_defalut, j, i)``
    returns a ``(weight, component)`` pair. The last weight is then replaced
    by ``get_normalized_complement`` of the other weights. The intermediate
    weights are printed.
    """
    weight = 1.0 / (1.0 * len(xoo))
    pairs = [get_mix(weight, xoo[i], sx_defalut, j, i) for i in range(len(xoo))]
    mms = [w for w, _ in pairs]
    nns = [comp for _, comp in pairs]

    print(mms[:-1])
    mcomp = get_normalized_complement(mms[:-1])
    print(mcomp)
    mms = mms[:-1] + [mcomp]
    print(mms)

    return tfd.Mixture(cat=tfd.Categorical(probs=mms), components=nns)