Example #1
0
def evaluate_baselines(experiment,
                       seed,
                       num_pairs,
                       samples_per_pair,
                       loop_size=None):
  """Runs `evaluate_joint` on each baseline coupling and collects the results.

  Four baselines are evaluated, in this order: "Independent", "ICDF",
  "ICDF (permuted)" and "Gumbel-max". The first three call the matching
  `coupling_util` coupling function directly and ignore the third argument
  they are handed. The Gumbel-max baseline is instead estimated via
  `coupling_util.joint_from_samples` with `coupling_util.gumbel_max_sampler`.

  Args:
    experiment: Forwarded unchanged to every `evaluate_joint` call.
    seed: Forwarded unchanged to every `evaluate_joint` call.
    num_pairs: Forwarded unchanged to every `evaluate_joint` call.
    samples_per_pair: Used as `num_samples` for the sample-based Gumbel-max
      joint, and as `joint_correction_num_samples` when evaluating it.
    loop_size: Forwarded as `loop_size` to `coupling_util.joint_from_samples`.

  Returns:
    A dict mapping each baseline's display name to whatever `evaluate_joint`
    returned for it.
  """
  # The sample-based joint is assembled up front, before any evaluation runs.
  sampled_gumbel_max = functools.partial(
      coupling_util.joint_from_samples,
      coupling_util.gumbel_max_sampler,
      num_samples=samples_per_pair,
      loop_size=loop_size)

  # The closed-form couplings take only the two leading arguments; these
  # adapters accept the third positional argument and discard it.
  def independent(first, second, unused_third):
    return coupling_util.independent_coupling(first, second)

  def inverse_cdf(first, second, unused_third):
    return coupling_util.inverse_cdf_coupling(first, second)

  def permuted_inverse_cdf(first, second, unused_third):
    return coupling_util.permuted_inverse_cdf_coupling(first, second)

  closed_form_baselines = (
      ("Independent", independent),
      ("ICDF", inverse_cdf),
      ("ICDF (permuted)", permuted_inverse_cdf),
  )

  results = {}
  for label, joint_fn in closed_form_baselines:
    results[label] = evaluate_joint(joint_fn, experiment, seed, num_pairs)

  # Only the sample-based baseline passes `joint_correction_num_samples`.
  results["Gumbel-max"] = evaluate_joint(
      sampled_gumbel_max,
      experiment,
      seed,
      num_pairs,
      joint_correction_num_samples=samples_per_pair)
  return results
Example #2
0
  logits_1 -= jax.scipy.special.logsumexp(logits_1)
  logits_2 -= jax.scipy.special.logsumexp(logits_2)

  probs_1 = jnp.exp(logits_1)
  probs_2 = jnp.exp(logits_2)

  independent_coupling = probs_1[:, None] * probs_2[None, :]
  gumbel_max_estimate = coupling_util.joint_from_samples(
      coupling_util.gumbel_max_sampler,
      logits_1,
      logits_2,
      vis_key,
      num_joint_samples,
      loop_size=500)
  icdf = coupling_util.inverse_cdf_coupling(logits_1, logits_2)
  icdf_perm = coupling_util.permuted_inverse_cdf_coupling(logits_1, logits_2)

  couplings = {
      "Independent": independent_coupling,
      "ICDF": icdf,
      "ICDF (permuted)": icdf_perm,
      "Gumbel-max": gumbel_max_estimate,
  }
  for experiment, result in zip(experiments, results):
    couplings[experiment.name] = coupling_util.joint_from_samples(
        experiment.build_sampler(result.params),
        logits_1,
        logits_2,
        vis_key,
        num_joint_samples,