Example #1
0
 def relaxed_for_fixed_z(z):
     """Returns a Gumbel-softmax relaxation of the counterfactual q sample for `z`.

     Counterfactual Gumbel noise comes from `coupling_util.counterfactual_gumbels`
     given `log_ps_given_z[z]` and `p_observed`. Every call reuses the same
     closed-over `rng`, so all values of `z` share common random numbers.
     Instead of a hard argmax, the noisy q log-probabilities are divided by
     `temperature` and passed through a softmax (over the last axis).
     """
     noise = coupling_util.counterfactual_gumbels(
         log_ps_given_z[z], p_observed, rng)
     tempered_scores = (noise + log_qs_given_z[z]) / temperature
     return jax.nn.softmax(tempered_scores)
Example #2
0
    def counterfactual_sample(self, p_logits, q_logits, p_observed, rng):
        """Draws one sample from q, conditioned on having observed p_observed.

        The latent z is first inferred from the observation. Gumbel noise
        consistent with p_observed under that z is then drawn. The same noise
        is finally pushed through q's conditional logits.

        Args:
          p_logits: Logits describing the original distribution of p_observed.
          q_logits: Logits describing a new counterfactual intervention.
          p_observed: Sample index we observed.
          rng: PRNGKey. Sharing this across multiple calls produces an implicit
            coupling.

        Returns:
          Sampled integer index from q_logits, conditioned on observing p_observed
          under p_logits.
        """
        z_key, noise_key = jax.random.split(rng)

        prior = self.get_prior()
        forward_p = self.get_forward(p_logits)
        forward_q = self.get_forward(q_logits)

        # Joint log-probabilities over (z, p). The column for the observed p
        # gives unnormalized log-posterior weights over z; `categorical`
        # accepts unnormalized log-probabilities, so no explicit normalization
        # is needed.
        log_joint = prior[:, None] + forward_p
        z = jax.random.categorical(z_key, log_joint[:, p_observed])

        # Gumbel noise that is consistent with having seen p_observed under z.
        noise = coupling_util.counterfactual_gumbels(
            forward_p[z], p_observed, noise_key)

        # Reuse that noise for q and take the Gumbel-max choice.
        return jnp.argmax(noise + forward_q[z])
Example #3
0
    def counterfactual_sample(self, p_logits, q_logits, p_observed, rng):
        """Draws one sample from q, conditioned on having observed p_observed.

        The joint table built from q is transposed so that its axes line up
        with the table built from p before the shared noise is applied.

        Args:
          p_logits: Logits describing the original distribution of p_observed.
          q_logits: Logits describing a new counterfactual intervention.
          p_observed: Sample index we observed.
          rng: PRNGKey. Sharing this across multiple calls produces an implicit
            coupling.

        Returns:
          Sampled integer index from q_logits, conditioned on observing p_observed
          under p_logits.
        """
        hat_key, noise_key = jax.random.split(rng)

        # `get_joint` appears to put the caller's own variable on axis 1 (see
        # the p_observed column lookup below). Transposing q's table therefore
        # puts q on axis 0, matching the q_hat rows of p's table.
        joint_p = self.get_joint(p_logits)
        joint_q = self.get_joint(q_logits).T

        # Sample what p "thought" q was: a row of joint_p drawn from the
        # column that belongs to p_observed.
        q_hat = jax.random.categorical(hat_key, joint_p[:, p_observed])

        # Draw counterfactual Gumbels over the whole flattened joint table,
        # given that the (q_hat, p_observed) cell was the argmax under p's
        # estimate. Row-major flattening places that cell at
        # q_hat * S_dim + p_observed.
        observed_cell = q_hat * self.S_dim + p_observed
        flat_noise = coupling_util.counterfactual_gumbels(
            jnp.reshape(joint_p, [-1]), observed_cell, noise_key)
        noise = flat_noise.reshape([self.S_dim, self.S_dim])

        # For every candidate q (row), keep its best score across the p axis,
        # then return the winning row.
        best_per_q = jnp.max(noise + joint_q, axis=1)
        return jnp.argmax(best_per_q)