Example #1
    def mixture_sampling(logit):
        """
    Args:
      - logit: [B, 2 * out_dim * num_mix + num_mix]

    Returns:
      - sample: [B, out_dim]
    """
        mean, logit_kappa, logit_pi = tf.split(
            logit,
            num_or_size_splits=[out_dim * num_mix, out_dim * num_mix, num_mix],
            axis=-1,
            name='mix_ivm_coeff_split_sampling')
        mean = tf.reshape(mean, [-1, num_mix, out_dim])
        logit_kappa = tf.reshape(logit_kappa, [-1, num_mix, out_dim])
        # softplus keeps the concentration parameter strictly positive
        kappa = tf.math.softplus(logit_kappa)
        logit_pi = tf.reshape(logit_pi, [-1, num_mix])

        # one (loc, concentration) slice per mixture component
        means = tf.unstack(mean, axis=1)
        kappas = tf.unstack(kappa, axis=1)

        # mixture of num_mix independent von Mises factors over out_dim angles
        mixture = tfd.Mixture(cat=tfd.Categorical(logits=logit_pi),
                              components=[
                                  tfd.Independent(distribution=tfd.VonMises(
                                      loc=loc, concentration=scale),
                                                  reinterpreted_batch_ndims=1)
                                  for loc, scale in zip(means, kappas)
                              ])

        sample = mixture.sample()
        return sample
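A minimal usage sketch for the sampler above. The function closes over `out_dim` and `num_mix`, so the values and imports below are assumptions, not part of the original snippet:

    import tensorflow as tf
    import tensorflow_probability as tfp
    tfd = tfp.distributions

    out_dim, num_mix = 3, 4  # assumed enclosing-scope settings
    logit = tf.random.normal([8, 2 * out_dim * num_mix + num_mix])
    angles = mixture_sampling(logit)  # -> [8, out_dim] angular samples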
Example #2
    def mixture_loss(y_true, logit, mask):
        """
    Args:
      - y_true: [B, L, out_dim]
      - logit: [B, L, 2 * out_dim * num_mix + num_mix]
      - mask: [B, L]

    Return:
      - loss
    """
        batch_size, time_step, _ = tf.shape(y_true)
        mean, logit_kappa, logit_pi = tf.split(
            logit,
            num_or_size_splits=[out_dim * num_mix, out_dim * num_mix, num_mix],
            axis=-1,
            name='mix_ivm_coeff_split')

        mask = tf.reshape(mask, [-1])  # [B*L]
        mean = tf.reshape(mean,
                          [-1, num_mix, out_dim])  # [B*L, num_mix, out_dim]
        logit_kappa = tf.reshape(
            logit_kappa, [-1, num_mix, out_dim])  # [B*L, num_mix, out_dim]
        logit_pi = tf.reshape(logit_pi, [-1, num_mix])  # [B*L, num_mix]
        # rescale parameters
        kappa = tf.math.softplus(logit_kappa)

        if use_tfp:
            y_true = tf.reshape(y_true, [-1, out_dim])
            means = tf.unstack(mean, axis=1)
            kappas = tf.unstack(kappa, axis=1)

            mixture = tfd.Mixture(cat=tfd.Categorical(logits=logit_pi),
                                  components=[
                                      tfd.Independent(
                                          distribution=tfd.VonMises(
                                              loc=loc, concentration=scale),
                                          reinterpreted_batch_ndims=1)
                                      for loc, scale in zip(means, kappas)
                                  ])

            loss = -mixture.log_prob(y_true)
        else:
            y_true = tf.reshape(y_true, [-1, 1, out_dim])  # [B*L, 1, out_dim]
            cos_diff = tf.cos(y_true - mean)
            # von Mises log-density per component, summed over out_dim;
            # log I0(kappa) = log(i0e(kappa)) + kappa avoids overflow
            log_probs = tf.reduce_sum(
                -LOGTWOPI - (tf.math.log(tf.math.bessel_i0e(kappa)) + kappa) +
                cos_diff * kappa,
                axis=-1)
            # weight by log mixture coefficients, marginalize via logsumexp
            mixed_log_probs = log_probs + tf.nn.log_softmax(logit_pi, axis=-1)
            loss = -tf.reduce_logsumexp(mixed_log_probs, axis=-1)

        loss = tf.multiply(loss, mask, name='masking')

        if reduce:
            return tf.reduce_sum(loss)
        else:
            return tf.reshape(loss, [batch_size, time_step])
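The `else` branch evaluates the von Mises log-density by hand, using log I0(κ) = log i0e(κ) + κ for numerical stability; LOGTWOPI is presumably a module-level constant for log(2π). A hedged sanity check that the manual density matches tfd.VonMises.log_prob (the names and values below are assumptions):

    import numpy as np
    import tensorflow as tf
    import tensorflow_probability as tfp
    tfd = tfp.distributions

    LOGTWOPI = np.log(2.0 * np.pi)  # assumed definition of the constant
    x, mu, kappa = 0.3, 0.1, tf.constant(2.5)
    manual = (kappa * tf.cos(x - mu) - LOGTWOPI
              - (tf.math.log(tf.math.bessel_i0e(kappa)) + kappa))
    reference = tfd.VonMises(loc=mu, concentration=kappa).log_prob(x)
    # manual and reference should agree up to floating-point error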
Example #3
    def gibbs_sampler(batch_size, mu, kappa, lambda_):
        """
    Gibbs sampling for triviate Von Mises distribution.
      Note: this function is hardcoded for 3-dimensional samples
    """
        x_0 = tf.zeros([batch_size])
        x_1 = tf.zeros([batch_size])
        x_2 = tf.zeros([batch_size])
        samples = []
        for i in tf.range(burn_in + avg_count):
            # each full conditional of the multivariate von Mises is itself
            # von Mises; phi folds in the sine coupling to the other angles
            phi_0 = lambda_[:, 0] * tf.sin(
                x_1 - mu[:, 1]) + lambda_[:, 1] * tf.sin(x_2 - mu[:, 2])
            k_neg_0 = tf.sqrt(kappa[:, 0] * kappa[:, 0] + phi_0 * phi_0)
            mu_neg_0 = mu[:, 0] + tf.atan(phi_0 / kappa[:, 0])
            dist_0 = tfd.VonMises(loc=mu_neg_0, concentration=k_neg_0)
            x_0 = dist_0.sample()

            phi_1 = lambda_[:, 0] * tf.sin(
                x_0 - mu[:, 0]) + lambda_[:, 2] * tf.sin(x_2 - mu[:, 2])
            k_neg_1 = tf.sqrt(kappa[:, 1] * kappa[:, 1] + phi_1 * phi_1)
            mu_neg_1 = mu[:, 1] + tf.atan(phi_1 / kappa[:, 1])
            dist_1 = tfd.VonMises(loc=mu_neg_1, concentration=k_neg_1)
            x_1 = dist_1.sample()

            phi_2 = lambda_[:, 1] * tf.sin(
                x_0 - mu[:, 0]) + lambda_[:, 2] * tf.sin(x_1 - mu[:, 1])
            k_neg_2 = tf.sqrt(kappa[:, 2] * kappa[:, 2] + phi_2 * phi_2)
            mu_neg_2 = mu[:, 2] + tf.atan(phi_2 / kappa[:, 2])
            dist_2 = tfd.VonMises(loc=mu_neg_2, concentration=k_neg_2)
            x_2 = dist_2.sample()

            if i >= burn_in:
                # each entry is [B, out_dim]; the list collects avg_count draws
                samples.append(tf.stack([x_0, x_1, x_2], axis=1))

        # average the retained draws: [avg_count, B, out_dim] -> [B, out_dim]
        return tf.reduce_mean(tf.stack(samples), axis=0)
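A sketch of how the sampler might be driven; `burn_in` and `avg_count` are free variables it closes over, so the settings here are assumptions:

    burn_in, avg_count = 100, 10  # assumed enclosing-scope settings
    batch_size = 8
    mu = tf.zeros([batch_size, 3])            # per-angle means
    kappa = tf.ones([batch_size, 3])          # per-angle concentrations
    lambda_ = 0.1 * tf.ones([batch_size, 3])  # couplings for pairs (0,1), (0,2), (1,2)
    sample = gibbs_sampler(batch_size, mu, kappa, lambda_)  # -> [batch_size, 3]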
Example #4
 def _init_distribution(conditions, **kwargs):
     loc, concentration = conditions["loc"], conditions["concentration"]
     return tfd.VonMises(loc=loc, concentration=concentration, **kwargs)
Example #5
 def _base_dist(self, mu: TensorLike, kappa: TensorLike, *args, **kwargs):
     return tfd.VonMises(loc=mu, concentration=kappa, *args, **kwargs)
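Examples #4 and #5 are thin backend adapters (PyMC-style hooks, by the look of them) that forward a model's named parameters to tfd.VonMises. A minimal call sketch for the first one, with the surrounding class machinery assumed to be in place:

 conditions = {"loc": 0.0, "concentration": 1.0}
 dist = _init_distribution(conditions)
 draws = dist.sample(5)  # five samples from VonMises(loc=0, concentration=1)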