def get_pdf(param_vec, vehicle_type):
    # See https://ericmjl.github.io/blog/2019/5/29/reasoning-about-shapes-and-probability-distributions/
    # for info on shapes.
    if vehicle_type == 'other_vehicle':
        # Unpack parameter vectors
        alpha, mus, sigmas = slice_pvector(param_vec, vehicle_type)
        mvn = tfd.MixtureSameFamily(
            mixture_distribution=tfd.Categorical(probs=alpha),
            components_distribution=tfd.Normal(loc=mus, scale=sigmas))
    elif vehicle_type == 'merge_vehicle':
        alphas, mus_long, sigmas_long, mus_lat, \
            sigmas_lat, rhos = slice_pvector(param_vec, vehicle_type)
        cov = get_CovMatrix(rhos, sigmas_long, sigmas_lat)
        mus = tf.stack([mus_long, mus_lat], axis=3, name='mus')
        mvn = tfd.MixtureSameFamily(
            mixture_distribution=tfd.Categorical(probs=alphas),
            components_distribution=tfd.MultivariateNormalTriL(
                loc=mus,
                scale_tril=tf.linalg.cholesky(cov),
                name='MultivariateNormalTriL'))
    else:
        raise ValueError('Unknown vehicle_type: {}'.format(vehicle_type))
    # print('mus shape: ', mus.shape)
    return mvn
def conditional_distribution(self, x, slice_in, slice_out):
    marginal_in = ds.MixtureSameFamily(
        mixture_distribution=ds.Categorical(probs=self._priors),
        components_distribution=ds.MultivariateNormalFullCovariance(
            loc=self._locs[:, slice_in],
            covariance_matrix=self._covs[:, slice_in, slice_in]))

    # Posterior over mixture components given the observed slice of x.
    p_k_in = ds.Categorical(logits=tf_utils.log_normalize(
        marginal_in.components_distribution.log_prob(x[:, None]) +
        marginal_in.mixture_distribution.logits[None],
        axis=1))

    # Per-component Gaussian conditionals p(out | in, k).
    sigma_in_out = self._covs[:, slice_in, slice_out]
    inv_sigma_in_in = tf.linalg.inv(self._covs[:, slice_in, slice_in])
    inv_sigma_out_in = tf.matmul(sigma_in_out, inv_sigma_in_in,
                                 transpose_a=True)

    A = inv_sigma_out_in
    b = self._locs[:, slice_out] - tf.matmul(
        inv_sigma_out_in, self._locs[:, slice_in, None])[:, :, 0]
    cov_est = (self._covs[:, slice_out, slice_out] -
               tf.matmul(inv_sigma_out_in, sigma_in_out))

    ys = tf.einsum('aij,bj->abi', A, x) + b[:, None]
    p_out_in_k = ds.MultivariateNormalFullCovariance(
        tf.transpose(ys, perm=(1, 0, 2)), cov_est)

    return ds.MixtureSameFamily(
        mixture_distribution=p_k_in,
        components_distribution=p_out_in_k)
def prepare_t_student_mixture(self):
    cluster_weights_unnorm = self.get_cluster_unnormalized_weights()
    cluster_weights_norm = cluster_weights_unnorm / np.sum(
        cluster_weights_unnorm)
    cat_distr = tpd.Categorical(
        probs=tf.cast(cluster_weights_norm, dtype=tf.float32))
    t_student_params = [
        self.prepare_t_student_params(ind)
        for ind in range(self.clusters_num)
    ]
    dofs = tf.constant(
        [t_student_params[ind]['df'] for ind in range(self.clusters_num)],
        dtype=tf.float32)
    means = tf.stack(
        [t_student_params[ind]['mean'] for ind in range(self.clusters_num)])
    cov_chols = tf.stack([
        t_student_params[ind]['cov_chol'] for ind in range(self.clusters_num)
    ])
    t_student_distr = tpd.MultivariateStudentTLinearOperator(
        df=dofs,
        loc=means,
        scale=tf.linalg.LinearOperatorLowerTriangular(cov_chols))
    t_student_mixture = tpd.MixtureSameFamily(
        mixture_distribution=cat_distr,
        components_distribution=t_student_distr)
    return t_student_mixture
def _build(self, inputs=None):
    return tfd.MixtureSameFamily(
        components_distribution=tfd.MultivariateNormalDiag(
            loc=self._loc,
            scale_diag=tf.nn.softplus(self._raw_scale_diag)),
        mixture_distribution=tfd.Categorical(logits=self._mixture_logits),
        name="prior")
def get_distribution(self, x, **kwargs):
    """Build the mixture distribution implied by the set of oracles
    that are trained in this module.

    Args:
        x: tf.Tensor
            a batch of training inputs shaped like [batch_size, channels]

    Returns:
        distribution: tfpd.Distribution
            the mixture of gaussian distributions implied by the oracles
    """
    # get the distribution parameters for all models
    params = defaultdict(list)
    for fm in self.forward_models:
        for key, val in fm.get_params(x, **kwargs).items():
            params[key].append(val)

    # stack the parameters in a new component axis
    for key, val in params.items():
        params[key] = tf.stack(val, axis=-1)

    # build the mixture distribution using the family of component one
    weights = tf.fill([self.bootstraps], 1 / self.bootstraps)
    return tfpd.MixtureSameFamily(
        tfpd.Categorical(probs=weights),
        self.forward_models[0].distribution(**params))
def get_steering(preds):
    alpha, mu, sigma = slice_parameter_vectors(preds.numpy(), components)
    # print(alpha)
    max_prob = np.max(alpha, axis=-1)
    if max_prob > 0.9995:
        # One component dominates: use its mean directly.
        index = np.argmax(alpha, axis=-1)
        angle = mu[:, index[0]]
    else:
        # Otherwise use the mixture mean (weights times component means).
        angle = np.multiply(alpha, mu).sum(axis=-1)
    # print(angle)

    gm = tfd.MixtureSameFamily(
        mixture_distribution=tfd.Categorical(probs=alpha),
        components_distribution=tfd.Normal(loc=mu, scale=sigma),
    )
    x = np.linspace(-1, 1, int(1e3))
    pyx = gm.prob(x)

    plot = cv2.plot.Plot2d_create(
        np.array(x).astype(np.float64),
        np.array(pyx).astype(np.float64))
    plot.setPlotBackgroundColor((255, 255, 255))
    plot.setInvertOrientation(True)
    plot.setPlotLineColor(0)
    plot = plot.render()
    cv2.imshow("Distribution", plot)
    return angle[0]
def create_dp_sb_gmm(nobs, K, dtype=np.float64):
    return tfd.JointDistributionNamed(dict(
        # Mixture means
        mu=tfd.Independent(
            tfd.Normal(np.zeros(K, dtype), 3),
            reinterpreted_batch_ndims=1),
        # Mixture scales
        sigma=tfd.Independent(
            tfd.LogNormal(loc=np.full(K, -2, dtype), scale=0.5),
            reinterpreted_batch_ndims=1),
        # Mixture weights (stick-breaking construction)
        alpha=tfd.Gamma(concentration=np.float64(1.0), rate=10.0),
        v=lambda alpha: tfd.Independent(
            # NOTE: Dave Moore suggests doing this instead, to ensure
            # that a batch dimension in alpha doesn't conflict with
            # the other parameters.
            tfd.Beta(np.ones(K - 1, dtype), alpha[..., tf.newaxis]),
            reinterpreted_batch_ndims=1),
        # Observations (likelihood)
        obs=lambda mu, sigma, v: tfd.Sample(
            tfd.MixtureSameFamily(
                # This will be marginalized over.
                mixture_distribution=tfd.Categorical(probs=stickbreak(v)),
                components_distribution=tfd.Normal(mu, sigma)),
            sample_shape=nobs)))
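# A minimal usage sketch for the joint model above (hypothetical, not from the
# original source). Assumes the `stickbreak` helper referenced above is in scope.
dp_sb_gmm = create_dp_sb_gmm(nobs=100, K=5)
prior_sample = dp_sb_gmm.sample(seed=1)      # dict with mu, sigma, alpha, v, obs
joint_lp = dp_sb_gmm.log_prob(prior_sample)  # scalar joint log-density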
def mix(n, eta, loc, scale, name):
    return tfd.Sample(
        tfd.MixtureSameFamily(
            mixture_distribution=tfd.Categorical(probs=eta),
            components_distribution=tfd.Normal(loc=loc, scale=scale),
            name=name),
        sample_shape=n)
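# Illustrative call (hypothetical values): 100 iid draws from a two-component
# Gaussian mixture, e.g. as a likelihood term inside a larger joint model.
obs_dist = mix(n=100, eta=[0.5, 0.5], loc=[-2., 2.], scale=[1., 1.], name='obs')
samples = obs_dist.sample()  # shape [100]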
def __call__(self):
    """Get the distribution object from the backend"""
    if get_backend() == "pytorch":
        # import torch.distributions as tod
        raise NotImplementedError
    else:
        import tensorflow as tf
        from tensorflow_probability import distributions as tfd

        # Convert to tensorflow distributions if probflow distributions
        if isinstance(self.distributions, BaseDistribution):
            self.distributions = self.distributions()

        # Broadcast probs/logits
        shape = self.distributions.batch_shape
        args = {"logits": None, "probs": None}
        if self.logits is not None:
            args["logits"] = tf.broadcast_to(self["logits"], shape)
        else:
            args["probs"] = tf.broadcast_to(self["probs"], shape)

        # Return TFP distribution object
        return tfd.MixtureSameFamily(
            tfd.Categorical(**args), self.distributions)
def make_mixture_prior(latent_size, mixture_components):
    """Creates the mixture of Gaussians prior distribution.

    Args:
        latent_size: The dimensionality of the latent representation.
        mixture_components: Number of elements of the mixture.

    Returns:
        random_prior: A `tfd.Distribution` instance representing the
            distribution over encodings in the absence of any evidence.
    """
    if mixture_components == 1:
        return tfd.MultivariateNormalDiag(
            loc=tf.zeros([latent_size]), scale_identity_multiplier=1.0)

    loc = tf.get_variable(
        name="loc", shape=[mixture_components, latent_size])
    raw_scale_diag = tf.get_variable(
        name="raw_scale_diag", shape=[mixture_components, latent_size])
    mixture_logits = tf.get_variable(
        name="mixture_logits", shape=[mixture_components])

    return tfd.MixtureSameFamily(
        components_distribution=tfd.MultivariateNormalDiag(
            loc=loc, scale_diag=tf.nn.softplus(raw_scale_diag)),
        mixture_distribution=tfd.Categorical(logits=mixture_logits),
        name="prior")
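# Hypothetical call site, shown only to make the shapes concrete. Because the
# prior creates variables via `tf.get_variable`, this is TF1-style graph code.
prior = make_mixture_prior(latent_size=16, mixture_components=10)
z = prior.sample(4)  # a [4, 16] batch of latent codes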
def find_pdf(index):
    alpha, mu, sigma = vector_unfold(mdn.predict(x_test[index].reshape(1, -1)))
    gm = tf_prob.MixtureSameFamily(
        mixture_distribution=tf_prob.Categorical(probs=alpha),
        components_distribution=tf_prob.Normal(loc=mu, scale=sigma))
    pyx = gm.prob(x)
    return pyx
def mix(gamma, eta, loc, scale, neg_inf, n):
    return tfd.Mixture(
        cat=tfd.Categorical(probs=tf.stack([gamma, 1 - gamma], axis=-1)),
        components=[
            tfd.Sample(tfd.Normal(np.float64(neg_inf), 1e-5), sample_shape=n),
            tfd.Sample(
                tfd.MixtureSameFamily(
                    mixture_distribution=tfd.Categorical(probs=eta),
                    components_distribution=tfd.Normal(loc=loc, scale=scale)),
                sample_shape=n)
        ])
def gnll_loss(y, parameter_vector):
    # Calculate negative log-likelihood
    alpha, mu, sigma = vector_unfold(parameter_vector)
    gm = tf_prob.MixtureSameFamily(
        mixture_distribution=tf_prob.Categorical(probs=alpha),
        components_distribution=tf_prob.Normal(loc=mu, scale=sigma))
    log_likelihood = gm.log_prob(tf.transpose(y))
    return -tf.reduce_mean(log_likelihood, axis=-1)
def mix(gamma, eta, loc, scale, neg_inf):
    _gamma = gamma[..., tf.newaxis]
    # FIXME: Possible to use tfd.Blockwise?
    return tfd.Mixture(
        cat=tfd.Categorical(probs=tf.concat([_gamma, 1 - _gamma], axis=-1)),
        components=[
            tfd.Deterministic(np.float64(neg_inf)),
            tfd.MixtureSameFamily(
                mixture_distribution=tfd.Categorical(probs=eta),
                components_distribution=tfd.Normal(loc=loc, scale=scale)),
        ])
def mixture_prior(FLAGS):
    """Unsure exactly how this works. I thought it was supposed to use a
    single Gaussian prior."""
    loc = tf.get_variable(
        name="loc", shape=[FLAGS['vae_k'], FLAGS['z_size']])
    raw_scale_diag = tf.get_variable(
        name="raw_scale_diag", shape=[FLAGS['vae_k'], FLAGS['z_size']])
    mixture_logits = tf.get_variable(
        name="mixture_logits", shape=[FLAGS['vae_k']])
    return tfd.MixtureSameFamily(
        components_distribution=tfd.MultivariateNormalDiag(
            loc=loc, scale_diag=tf.nn.softplus(raw_scale_diag)),
        mixture_distribution=tfd.Categorical(logits=mixture_logits),
        name="prior")
def plot_rate(ax, index, color_index):
    alpha, mu, sigma = vector_unfold(mdn.predict(x_test[index].reshape(1, -1)))
    gm = tf_prob.MixtureSameFamily(
        mixture_distribution=tf_prob.Categorical(probs=alpha),
        components_distribution=tf_prob.Normal(loc=mu, scale=sigma))
    pyx = gm.prob(x)
    ax.plot(x, pyx,
            alpha=1,
            color=sns.color_palette()[color_index],
            linewidth=2,
            label="PDF for prediction {}".format(index))
def gnll_loss(y, parameter_vector):
    """Computes the mean negative log-likelihood loss of y given the
    mixture parameters."""
    # Unpack parameter vectors
    alpha, mu, sigma = slice_parameter_vectors(parameter_vector)

    gm = tfd.MixtureSameFamily(
        mixture_distribution=tfd.Categorical(probs=alpha),
        components_distribution=tfd.Normal(loc=mu, scale=sigma))

    # Evaluate log-probability of y
    log_likelihood = gm.log_prob(tf.transpose(y))

    return -tf.reduce_mean(log_likelihood, axis=-1)
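# Hypothetical sketch of wiring gnll_loss into a Keras mixture density network.
# The architecture and `components` count are illustrative assumptions, not from
# the original source; the final layer must emit the flat parameter vector that
# slice_parameter_vectors expects (alpha, mu, sigma per component).
components = 3
mdn_head = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(3 * components),
])
mdn_head.compile(optimizer='adam', loss=gnll_loss)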
def gen_mixture(self, out):
    pvs = self.slice_parameter_vectors(out)
    mixtures = []
    for pv in pvs:
        logits, locs, log_scales = pv
        # Map the unconstrained log-scales back to positive scales.
        scales = tf.math.exp(log_scales)
        mixtures.append(
            tfd.MixtureSameFamily(
                mixture_distribution=tfd.Categorical(logits=logits),
                components_distribution=tfd.Normal(loc=locs, scale=scales)))
    joint = tfd.JointDistributionSequential(mixtures, name='joint_mixtures')
    blkws = tfd.Blockwise(joint)
    return blkws
def mdn_loss(self, z, alpha, mu, sigma):
    alpha = K.repeat_elements(alpha, self.z_dim, axis=1)
    alpha = K.expand_dims(alpha, axis=3)
    mu = K.reshape(mu, (tf.shape(mu)[0], self.z_dim, self.cat_num))
    mu = K.expand_dims(mu, axis=3)
    sigma = K.reshape(sigma, (tf.shape(sigma)[0], self.z_dim, self.cat_num))
    sigma = K.expand_dims(sigma, axis=3)
    gm = tfd.MixtureSameFamily(
        mixture_distribution=tfd.Categorical(probs=alpha),
        components_distribution=tfd.Normal(loc=mu, scale=sigma))
    z = tf.transpose(z, (0, 2, 1))
    return tf.reduce_mean(-gm.log_prob(z))
def _init_distribution(conditions, **kwargs):
    p, d = conditions["p"], conditions["distributions"]
    # if 'd' is a sequence of pymc distributions, then use the underlying
    # tfp distributions for the mixture
    if isinstance(d, collections.abc.Sequence):
        if any(not isinstance(el, Distribution) for el in d):
            raise TypeError(
                "every element in 'distribution' needs to be a "
                "pymc4.Distribution object")
        distr = [el._distribution for el in d]
        return tfd.Mixture(
            tfd.Categorical(probs=p, **kwargs), distr, **kwargs,
            use_static_graph=True)
    # else if 'd' is a pymc distribution with batch_size > 1
    elif isinstance(d, Distribution):
        return tfd.MixtureSameFamily(
            tfd.Categorical(probs=p, **kwargs), d._distribution, **kwargs)
    else:
        raise TypeError(
            "'distribution' needs to be a pymc4.Distribution object or a "
            "sequence of distributions")
def compute_mdn_probs(logits, locs, scales, array, FLAGS):
    """compute probability of all blocks being in correct spots"""
    if CACHE['MDN'] is None:
        CACHE['logits_ph'] = tf.placeholder(
            tf.float32, shape=[None, *logits.shape[1:]], name='logits_ph')
        CACHE['locs_ph'] = tf.placeholder(
            tf.float32, shape=[None, *locs.shape[1:]], name='locs_ph')
        CACHE['scales_ph'] = tf.placeholder(
            tf.float32, shape=[None, *scales.shape[1:]], name='scales_ph')
        CACHE['array_ph'] = tf.placeholder(
            tf.float32, shape=[None, *array.shape[1:]], name='array_ph')

        # Build the mixture from the placeholders so the cached graph can be
        # re-fed with new parameters on later calls.
        cat = tfd.Categorical(logits=CACHE['logits_ph'])
        comp = tfd.MultivariateNormalDiag(
            loc=CACHE['locs_ph'], scale_diag=CACHE['scales_ph'])
        mixture = tfd.MixtureSameFamily(cat, comp)
        CACHE['MDN'] = mixture

        mask0, mask0_count, tstate = mask_state({'array': CACHE['array_ph']})
        # TODO: add mask state logic here if needed
        # TODO: probably want to get rid of masked ones and scale by number of
        # objects to make it more balanced for more vs. less objects
        # CACHE['prob_op'] = mixture.prob(tstate)
        CACHE['log_prob_op'] = mixture.log_prob(tstate)

    sess = get_session()
    feed_dict = {
        CACHE['logits_ph']: logits,
        CACHE['locs_ph']: locs,
        CACHE['scales_ph']: scales,
        CACHE['array_ph']: array
    }
    log_prob = sess.run(CACHE['log_prob_op'], feed_dict)
    combined_log_prob = np.sum(log_prob, axis=0)  # combine along number of shapes
    return -combined_log_prob
def construct_model(self):
    with self.graph.as_default():
        self.random = np.random.RandomState(self.seed)
        tf.compat.v1.set_random_seed(
            self.random.randint(1e10, dtype=np.int64))
        self.global_step = tf.Variable(0, trainable=False, name='global_step')
        self.is_training = tf.compat.v1.placeholder_with_default(
            False, [], name='is_training')
        x = self.x = tf.compat.v1.placeholder(
            dtype=tf.float32, shape=[None, self.n_in], name='x')
        y = self.y = tf.compat.v1.placeholder(
            dtype=tf.float32, shape=[None, self.n_pred], name='y')
        T = self.T = tf.compat.v1.placeholder(
            dtype=tf.float32, shape=None, name='T')
        C = self.C = tf.compat.v1.placeholder(
            dtype=tf.float32, shape=None, name='C')

        estimate = self.forward(x)
        with tf.control_dependencies(
                self._debug_nan([estimate, x], names=['estim', 'x'])):
            self.coefs = prior, mu, sigma = self.get_coefs(estimate)

        dist = getattr(tfd, self.distribution)(mu, sigma)
        prob = tfd.Categorical(probs=prior)
        mix = tfd.MixtureSameFamily(prob, dist)

        def impute():
            # Replace NaN targets with samples from the mixture and average
            # the log-probability over several imputations.
            return tf.reduce_mean([
                mix.log_prob(
                    tf.compat.v2.where(tf.math.is_nan(y), mix.sample(), y))
                for _ in range(self.imputations)
            ], 0)

        likelihood = tf.compat.v2.cond(
            tf.reduce_any(tf.math.is_nan(y)), impute,
            lambda: mix.log_prob(y))

        neg_log_pr = tf.reduce_mean(-likelihood)
        l2_loss = tf_layers.apply_regularization(
            tf_layers.l2_regularizer(scale=self.l2))
        total_loss = neg_log_pr + l2_loss
        self.neg_log_pr = neg_log_pr

        with tf.control_dependencies(
                tf.compat.v1.get_collection(
                    tf.compat.v1.GraphKeys.UPDATE_OPS)):
            learn_rate = self.lr
            # learn_rate = tf.train.polynomial_decay(self.lr, self.global_step, decay_steps=self.n_iter, end_learning_rate=self.lr/10)
            train_op = tf.compat.v1.train.AdamOptimizer(learn_rate)
            grads, var = zip(*train_op.compute_gradients(total_loss))
            with tf.control_dependencies(
                    self._debug_nan(
                        list(grads) + [total_loss],
                        names=[v.name.split(':')[0] for v in var] + ['loss'])):
                self.train = train_op.apply_gradients(
                    zip(grads, var),
                    global_step=self.global_step,
                    name='train_op')

        self.loss = tf.identity(total_loss, name='model_loss')
        tf.compat.v1.global_variables_initializer().run(session=self.session)
        self.saver = tf.compat.v1.train.Saver(
            max_to_keep=1, save_relative_paths=True)
class BatchShapeInferenceTests(test_util.TestCase):

  @parameterized.named_parameters(
      {'testcase_name': '_trivial',
       'value_fn': lambda: tfd.Normal(loc=0., scale=1.),
       'expected_batch_shape_parts': {'loc': [], 'scale': []},
       'expected_batch_shape': []},
      {'testcase_name': '_simple_tensor_broadcasting',
       'value_fn': lambda: tfd.MultivariateNormalDiag(  # pylint: disable=g-long-lambda
           loc=[0., 0.],
           scale_diag=tf.convert_to_tensor([[1., 1.], [1., 1.]])),
       'expected_batch_shape_parts': {'loc': [], 'scale_diag': [2]},
       'expected_batch_shape': [2]},
      {'testcase_name': '_rank_deficient_tensor_broadcasting',
       'value_fn': lambda: tfd.MultivariateNormalDiag(  # pylint: disable=g-long-lambda
           loc=0.,
           scale_diag=tf.convert_to_tensor([[1., 1.], [1., 1.]])),
       'expected_batch_shape_parts': {'loc': [], 'scale_diag': [2]},
       'expected_batch_shape': [2]},
      {'testcase_name': '_dynamic_event_ndims',
       'value_fn': lambda: _MVNTriLWithDynamicParamNdims(  # pylint: disable=g-long-lambda
           loc=[[0., 0.], [1., 1.], [2., 2.]],
           scale_tril=[[1., 0.], [-1., 1.]]),
       'expected_batch_shape_parts': {'loc': [3], 'scale_tril': []},
       'expected_batch_shape': [3]},
      {'testcase_name': '_mixture_same_family',
       'value_fn': lambda: tfd.MixtureSameFamily(  # pylint: disable=g-long-lambda
           mixture_distribution=tfd.Categorical(
               logits=[[[1., 2., 3.], [4., 5., 6.]]]),
           components_distribution=tfd.Normal(
               loc=0., scale=[[[1., 2., 3.], [4., 5., 6.]]])),
       'expected_batch_shape_parts': {
           'mixture_distribution': [1, 2],
           'components_distribution': [1, 2]},
       'expected_batch_shape': [1, 2]},
      {'testcase_name': '_deeply_nested',
       'value_fn': lambda: tfd.Independent(  # pylint: disable=g-long-lambda
           tfd.Independent(
               tfd.Independent(
                   tfd.Independent(
                       tfd.Normal(loc=0., scale=[[[[[[[[1.]]]]]]]]),
                       reinterpreted_batch_ndims=2),
                   reinterpreted_batch_ndims=0),
               reinterpreted_batch_ndims=1),
           reinterpreted_batch_ndims=1),
       'expected_batch_shape_parts': {'distribution': [1, 1, 1, 1]},
       'expected_batch_shape': [1, 1, 1, 1]},
      {'testcase_name': 'noparams',
       'value_fn': tfb.Exp,
       'expected_batch_shape_parts': {},
       'expected_batch_shape': []})
  @test_util.numpy_disable_test_missing_functionality('b/188002189')
  def test_batch_shape_inference_is_correct(
      self, value_fn, expected_batch_shape_parts, expected_batch_shape):
    value = value_fn()  # Defer construction until we're in the right graph.

    parts = batch_shape_lib.batch_shape_parts(value)
    self.assertAllEqualNested(
        parts,
        nest.map_structure_up_to(
            parts, tf.TensorShape, expected_batch_shape_parts))

    self.assertAllEqual(
        expected_batch_shape,
        batch_shape_lib.inferred_batch_shape_tensor(value))

    batch_shape = batch_shape_lib.inferred_batch_shape(value)
    self.assertIsInstance(batch_shape, tf.TensorShape)
    self.assertTrue(batch_shape.is_compatible_with(expected_batch_shape))

  def test_bijector_event_ndims(self):
    bij = tfb.Sigmoid(low=tf.zeros([2]), high=tf.ones([3, 2]))
    self.assertAllEqual(batch_shape_lib.inferred_batch_shape(bij), [3, 2])
    self.assertAllEqual(batch_shape_lib.inferred_batch_shape_tensor(bij),
                        [3, 2])
    self.assertAllEqual(
        batch_shape_lib.inferred_batch_shape(bij, bijector_x_event_ndims=1),
        [3])
    self.assertAllEqual(
        batch_shape_lib.inferred_batch_shape_tensor(
            bij, bijector_x_event_ndims=1),
        [3])

    # Verify that we don't pass Nones through to component
    # `experimental_batch_shape(x_event_ndims=None)` calls, where they'd be
    # incorrectly interpreted as `x_event_ndims=forward_min_event_ndims`.
    joint_bij = tfb.JointMap([bij, bij])
    self.assertAllEqual(
        batch_shape_lib.inferred_batch_shape(
            joint_bij, bijector_x_event_ndims=[None, None]),
        tf.TensorShape(None))
def build_psis(n_steps,
               n_state,
               n_dim,
               scale=None,
               fixed_handles=True,
               dtype=tf.float32,
               dim_obs=None):
    """
    :param n_steps: Length of each trajectory
    :type n_steps: list of int
    :param n_state: Number of basis functions
    :type n_state: int
    :type scale: float 0. - 1.
    :return: hs, psis, handles
    """
    from tensorflow_probability import distributions as ds

    # create handles for each demo
    if fixed_handles:
        handles = tf.stack([
            tf.linspace(tf.cast(0., dtype=dtype), tf.cast(n, dtype=dtype),
                        n_state) for n in n_steps
        ])
    else:
        handles = tf.Variable([
            tf.linspace(tf.cast(0., dtype=dtype), tf.cast(n, dtype=dtype),
                        n_state) for n in n_steps
        ])

    n_traj = len(n_steps)

    if scale is None:
        scale = 1. / n_state

    # create mixtures whose p_z will be the activations
    h_mixture = [
        ds.MixtureSameFamily(
            mixture_distribution=ds.Categorical(logits=tf.ones(n_state)),
            components_distribution=ds.MultivariateNormalDiag(
                loc=handles[j][:, None],
                scale_diag=tf.cast(scale, dtype)[None, None] * n_steps[j]))
        for j in range(n_traj)
    ]

    # create evaluation points of the mixture for each demo
    idx = [tf.range(n, dtype=dtype) for n in n_steps]

    from .tf_utils import log_normalize

    # create activations
    hs = [
        tf.exp(
            log_normalize(
                h_mixture[j].components_distribution.log_prob(
                    idx[j][:, None, None]),
                axis=1)) for j in range(n_traj)
    ]

    if dim_obs is None:
        psis = [
            build_psi(hs[i], n_dim, n, n_state, dtype=dtype)
            for i, n in enumerate(n_steps)
        ]
    else:
        psis = [
            build_psi_partial(hs[i], n_dim, dim_obs, n, n_state, dtype=dtype)
            for i, n in enumerate(n_steps)
        ]

    return hs, psis, handles
def __init__(self,
             log_unnormalized_prob,
             gmm=None,
             k=10,
             loc=0.,
             std=1.,
             ndim=None,
             loc_tril=None,
             samples=20,
             temp=1.,
             cov_type='diag',
             loc_scale=1.,
             priors_scale=1e1):
    """
    :param log_unnormalized_prob: Unnormalized log density to estimate
    :type log_unnormalized_prob: a tensorflow function that takes
        [batch_size, ndim] as input and returns [batch_size]
    :param gmm:
    :param k: number of components for the GMM approximation
    :param loc: for initialization, mean
    :param std: for initialization, standard deviation
    :param ndim:
    """
    self.log_prob = log_unnormalized_prob
    self.ndim = ndim
    self.temp = temp

    if gmm is None:
        assert ndim is not None, "If no gmm is defined, should give the shape of x"

        if cov_type == 'diag':
            _log_priors_var = tf.Variable(
                1. / priors_scale * log_normalize(tf.ones(k)))
            log_priors = priors_scale * _log_priors_var

            if isinstance(loc, tf.Tensor) and loc.shape.ndims == 2:
                _locs_var = tf.Variable(1. / loc_scale * loc)
            else:
                _locs_var = tf.Variable(
                    1. / loc_scale * tf.random.normal((k, ndim), loc, std))
            locs = loc_scale * _locs_var

            log_std_diags = tf.Variable(
                tf.math.log(std / k * tf.ones((k, ndim))))

            self._opt_params = [_log_priors_var, _locs_var, log_std_diags]

            gmm = _distributions.MixtureSameFamily(
                mixture_distribution=_distributions.Categorical(
                    logits=log_priors),
                components_distribution=_distributions.MultivariateNormalDiag(
                    loc=locs, scale_diag=tf.math.exp(log_std_diags)))
        elif cov_type == 'full':
            _log_priors_var = tf.Variable(
                1. / priors_scale * log_normalize(tf.ones(k)))
            log_priors = priors_scale * _log_priors_var

            if isinstance(loc, tf.Tensor) and loc.shape.ndims == 2:
                _locs_var = tf.Variable(1. / loc_scale * loc)
            else:
                _locs_var = tf.Variable(
                    1. / loc_scale * tf.random.normal((k, ndim), loc, std))
            locs = loc_scale * _locs_var

            loc_tril = loc_tril if loc_tril is not None else std / k
            # tril_cov = tf.Variable(loc_tril ** 2 * tf.eye(ndim, batch_shape=(k,)))
            tril_cov = tf.Variable(
                tf1.log(loc_tril) * tf.eye(ndim, batch_shape=(k,)))
            covariance = tf.linalg.expm(
                tril_cov + tf1.matrix_transpose(tril_cov))
            # self._opt_params = [_log_priors_var, _locs_var, tril_cov]

            gmm = _distributions.MixtureSameFamily(
                mixture_distribution=_distributions.Categorical(
                    logits=log_priors),
                components_distribution=_distributions.MultivariateNormalFullCovariance(
                    loc=locs, covariance_matrix=covariance))
        else:
            raise ValueError("Unrecognized covariance type")

    self.k = k
    self.num_samples = samples
    self.gmm = gmm
def conditional_distribution(self, x):
    return ds.MixtureSameFamily(
        mixture_distribution=self._gate.conditional_mixture_distribution(x),
        components_distribution=self._experts.conditional_components_distribution(x))
def mixture_distribution(alpha, mu, sigma):
    # scale_tril = tfp.math.fill_triangular(sigma)
    gm = tfd.MixtureSameFamily(
        mixture_distribution=tfd.Categorical(probs=tf.squeeze(alpha, 3)),
        components_distribution=tfd.MultivariateNormalDiag(
            loc=mu, scale_diag=sigma))
    return gm
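# Illustrative shapes for the helper above (assumed, not from the source):
# alpha [B, T, K, 1] with mu, sigma [B, T, K, D] yields a mixture with
# batch_shape [B, T] and event_shape [D].
alpha = tf.fill([4, 7, 3, 1], 1. / 3.)
mu = tf.zeros([4, 7, 3, 2])
sigma = tf.ones([4, 7, 3, 2])
gm = mixture_distribution(alpha, mu, sigma)  # batch [4, 7], event [2]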
def uniform_mixture(dist, dtype=None):
    if dist.batch_shape[-1] == 1:
        return tfd.BatchReshape(dist, dist.batch_shape[:-1])
    dtype = dtype or prec.global_policy().compute_dtype
    weights = tfd.Categorical(tf.zeros(dist.batch_shape, dtype))
    return tfd.MixtureSameFamily(weights, dist)
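# Sketch with assumed shapes: mix uniformly over the trailing batch axis, e.g.
# to average an ensemble of 5 predictions per example. Passing dtype explicitly
# avoids the `prec.global_policy()` dependency.
ensemble = tfd.Normal(loc=tf.zeros([8, 5]), scale=tf.ones([8, 5]))
merged = uniform_mixture(ensemble, dtype=tf.float32)  # batch_shape [8]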
def mix(eta, loc, scale):
    return tfd.MixtureSameFamily(
        mixture_distribution=tfd.Categorical(probs=eta),
        components_distribution=tfd.Normal(loc=loc, scale=scale))
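# Illustrative call (hypothetical values): a two-component 1-D Gaussian mixture.
d = mix(eta=[0.25, 0.75], loc=[-1., 1.], scale=[0.5, 0.5])
d.log_prob(0.)  # scalar log-density
d.sample(3)     # three draws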