import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

LOGTWOPI = np.log(2.0 * np.pi)  # log-normalizer term of the von Mises density

# NOTE: `out_dim`, `num_mix`, `use_tfp`, `reduce`, `burn_in`, and `avg_count`
# are module-level configuration values set elsewhere in the repo; the
# assignments below are illustrative placeholders so the snippets run
# stand-alone.
out_dim = 3     # number of angular dimensions per sample
num_mix = 5     # number of mixture components
use_tfp = True  # use the TFP mixture path (else the manual log-density path)
reduce = True   # sum the masked loss to a scalar
burn_in = 100   # Gibbs sweeps discarded before averaging
avg_count = 10  # post-burn-in sweeps averaged into the returned sample


def mixture_sampling(logit):
    """
    Args:
    - logit: [B, 2 * out_dim * num_mix + num_mix]
    Returns:
    - sample: [B, out_dim]
    """
    mean, logit_kappa, logit_pi = tf.split(
        logit,
        num_or_size_splits=[out_dim * num_mix, out_dim * num_mix, num_mix],
        axis=-1,
        name='mix_ivm_coeff_split_sampling')
    mean = tf.reshape(mean, [-1, num_mix, out_dim])
    logit_kappa = tf.reshape(logit_kappa, [-1, num_mix, out_dim])
    # Softplus keeps the concentration parameter positive.
    kappa = tf.math.softplus(logit_kappa)
    logit_pi = tf.reshape(logit_pi, [-1, num_mix])
    # One component per mixture slot, each an independent von Mises over the
    # out_dim angular dimensions.
    means = tf.unstack(mean, axis=1)
    kappas = tf.unstack(kappa, axis=1)
    mixture = tfd.Mixture(
        cat=tfd.Categorical(logits=logit_pi),
        components=[
            tfd.Independent(
                distribution=tfd.VonMises(loc=loc, concentration=scale),
                reinterpreted_batch_ndims=1)
            for loc, scale in zip(means, kappas)
        ])
    sample = mixture.sample()
    return sample
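# A minimal usage sketch for `mixture_sampling`, written against the assumed
# placeholder configuration above (out_dim = 3, num_mix = 5). The random
# `logit` tensor stands in for the output of a network head; in practice it
# would come from the model.
def _demo_mixture_sampling():
    batch = 8
    logit = tf.random.normal([batch, 2 * out_dim * num_mix + num_mix])
    sample = mixture_sampling(logit)
    # Expected shape: [batch, out_dim], e.g. (8, 3).
    return sample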
def mixture_loss(y_true, logit, mask):
    """
    Args:
    - y_true: [B, L, out_dim]
    - logit: [B, L, 2 * out_dim * num_mix + num_mix]
    - mask: [B, L]
    Returns:
    - loss: a scalar if `reduce` is set, otherwise a [B, L] tensor
    """
    # Index tf.shape rather than unpacking it, so this also works when traced.
    shape = tf.shape(y_true)
    batch_size, time_step = shape[0], shape[1]
    mean, logit_kappa, logit_pi = tf.split(
        logit,
        num_or_size_splits=[out_dim * num_mix, out_dim * num_mix, num_mix],
        axis=-1,
        name='mix_ivm_coeff_split')
    mask = tf.reshape(mask, [-1])  # [B*L]
    mean = tf.reshape(mean, [-1, num_mix, out_dim])  # [B*L, num_mix, out_dim]
    logit_kappa = tf.reshape(
        logit_kappa, [-1, num_mix, out_dim])  # [B*L, num_mix, out_dim]
    logit_pi = tf.reshape(logit_pi, [-1, num_mix])  # [B*L, num_mix]
    # Rescale parameters: softplus keeps the concentration positive.
    kappa = tf.math.softplus(logit_kappa)
    if use_tfp:
        y_true = tf.reshape(y_true, [-1, out_dim])
        means = tf.unstack(mean, axis=1)
        kappas = tf.unstack(kappa, axis=1)
        mixture = tfd.Mixture(
            cat=tfd.Categorical(logits=logit_pi),
            components=[
                tfd.Independent(
                    distribution=tfd.VonMises(loc=loc, concentration=scale),
                    reinterpreted_batch_ndims=1)
                for loc, scale in zip(means, kappas)
            ])
        loss = -mixture.log_prob(y_true)
    else:
        # Manual log-density: log p(x) = kappa*cos(x - mu) - log(2*pi*I0(kappa)),
        # with log I0(kappa) computed stably as log(i0e(kappa)) + kappa.
        y_true = tf.reshape(y_true, [-1, 1, out_dim])  # [B*L, 1, out_dim]
        cos_diff = tf.cos(y_true - mean)
        log_probs = tf.reduce_sum(
            -LOGTWOPI - (tf.math.log(tf.math.bessel_i0e(kappa)) + kappa) +
            cos_diff * kappa,
            axis=-1)
        mixed_log_probs = log_probs + tf.nn.log_softmax(logit_pi, axis=-1)
        loss = -tf.reduce_logsumexp(mixed_log_probs, axis=-1)
    loss = tf.multiply(loss, mask, name='masking')
    if reduce:
        return tf.reduce_sum(loss)
    else:
        return tf.reshape(loss, [batch_size, time_step])
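# A hedged sketch of calling `mixture_loss`, assuming the targets are angles
# in radians and that `use_tfp` and `reduce` are the module flags above. The
# data here is synthetic and only illustrates the expected shapes.
def _demo_mixture_loss():
    batch, steps = 4, 10
    y_true = tf.random.uniform([batch, steps, out_dim],
                               minval=-np.pi, maxval=np.pi)
    logit = tf.random.normal([batch, steps, 2 * out_dim * num_mix + num_mix])
    mask = tf.ones([batch, steps])  # no padded steps in this toy example
    # Scalar with reduce=True, otherwise a [batch, steps] tensor.
    return mixture_loss(y_true, logit, mask)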
def gibbs_sampler(batch_size, mu, kappa, lambda_):
    """
    Gibbs sampling for the trivariate von Mises distribution.
    Note: this function is hardcoded for 3-dimensional samples.

    Each sweep resamples one coordinate at a time from its univariate von
    Mises full conditional: given the other two angles, coordinate j has
    location mu_j + atan(phi_j / kappa_j) and concentration
    sqrt(kappa_j^2 + phi_j^2), where phi_j collects the pairwise coupling
    terms from `lambda_` = (lambda_01, lambda_02, lambda_12).
    """
    x_0 = tf.zeros([batch_size])
    x_1 = tf.zeros([batch_size])
    x_2 = tf.zeros([batch_size])
    samples = []
    # Python range (rather than tf.range), so the list append and burn-in
    # check below also work when this function is traced.
    for i in range(burn_in + avg_count):
        # Conditional of x_0 given (x_1, x_2).
        phi_0 = (lambda_[:, 0] * tf.sin(x_1 - mu[:, 1]) +
                 lambda_[:, 1] * tf.sin(x_2 - mu[:, 2]))
        k_neg_0 = tf.sqrt(kappa[:, 0] * kappa[:, 0] + phi_0 * phi_0)
        mu_neg_0 = mu[:, 0] + tf.atan(phi_0 / kappa[:, 0])
        dist_0 = tfd.VonMises(loc=mu_neg_0, concentration=k_neg_0)
        x_0 = dist_0.sample()
        # Conditional of x_1 given (x_0, x_2).
        phi_1 = (lambda_[:, 0] * tf.sin(x_0 - mu[:, 0]) +
                 lambda_[:, 2] * tf.sin(x_2 - mu[:, 2]))
        k_neg_1 = tf.sqrt(kappa[:, 1] * kappa[:, 1] + phi_1 * phi_1)
        mu_neg_1 = mu[:, 1] + tf.atan(phi_1 / kappa[:, 1])
        dist_1 = tfd.VonMises(loc=mu_neg_1, concentration=k_neg_1)
        x_1 = dist_1.sample()
        # Conditional of x_2 given (x_0, x_1).
        phi_2 = (lambda_[:, 1] * tf.sin(x_0 - mu[:, 0]) +
                 lambda_[:, 2] * tf.sin(x_1 - mu[:, 1]))
        k_neg_2 = tf.sqrt(kappa[:, 2] * kappa[:, 2] + phi_2 * phi_2)
        mu_neg_2 = mu[:, 2] + tf.atan(phi_2 / kappa[:, 2])
        dist_2 = tfd.VonMises(loc=mu_neg_2, concentration=k_neg_2)
        x_2 = dist_2.sample()
        # Keep only post-burn-in sweeps.
        if i >= burn_in:
            samples.append(tf.stack([x_0, x_1, x_2], axis=1))
    # [avg_count, B, 3] -> [B, 3]
    return tf.reduce_mean(tf.stack(samples), axis=0)
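# A hedged example of driving `gibbs_sampler` for a batch of 2 trivariate von
# Mises distributions. All parameter values are made up; `lambda_` holds the
# pairwise couplings in the (lambda_01, lambda_02, lambda_12) order implied by
# the conditionals above.
def _demo_gibbs_sampler():
    mu = tf.constant([[0.0, 1.0, -1.0], [0.5, 0.5, 0.5]])
    kappa = tf.constant([[2.0, 2.0, 2.0], [4.0, 1.0, 4.0]])
    lambda_ = tf.constant([[0.3, -0.2, 0.1], [0.0, 0.0, 0.0]])
    # Returns the average of avg_count post-burn-in sweeps, shape [2, 3].
    return gibbs_sampler(2, mu, kappa, lambda_)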
def _init_distribution(conditions, **kwargs):
    loc, concentration = conditions["loc"], conditions["concentration"]
    return tfd.VonMises(loc=loc, concentration=concentration, **kwargs)
def _base_dist(self, mu: TensorLike, kappa: TensorLike, *args, **kwargs):
    # Thin wrapper mapping this library's (mu, kappa) naming onto TFP's
    # (loc, concentration). `TensorLike` is a type alias defined by the
    # surrounding module.
    return tfd.VonMises(loc=mu, concentration=kappa, *args, **kwargs)
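# A small hedged usage sketch for the wrappers above: `_init_distribution`
# builds a von Mises from a conditions dict, which is how the surrounding
# framework would feed it. The parameter values are made up for illustration.
def _demo_init_distribution():
    conditions = {"loc": tf.constant(0.0), "concentration": tf.constant(2.0)}
    dist = _init_distribution(conditions)
    return dist.sample(5)  # five sampled angles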