def evaluate_baselines(experiment, seed, num_pairs, samples_per_pair,
                       loop_size=None):
  """Runs `evaluate_joint` on each of the baseline coupling methods.

  Four baselines are evaluated, in this order: the independent coupling, the
  inverse-CDF coupling, the permuted inverse-CDF coupling, and a sample-based
  Gumbel-max estimate built with `coupling_util.joint_from_samples`.

  Args:
    experiment: Forwarded unchanged to `evaluate_joint`.
    seed: Forwarded unchanged to `evaluate_joint`.
    num_pairs: Forwarded unchanged to `evaluate_joint`.
    samples_per_pair: Used as `num_samples` for the Gumbel-max sample-based
      estimate. For that baseline only, it is also passed to `evaluate_joint`
      as `joint_correction_num_samples`.
    loop_size: Forwarded as `loop_size` to `coupling_util.joint_from_samples`
      for the Gumbel-max estimate.

  Returns:
    A dict mapping each baseline name ("Independent", "ICDF",
    "ICDF (permuted)", "Gumbel-max") to its `evaluate_joint` result. Keys are
    inserted in that order.
  """
  # Bind the sampler and sampling parameters up front. The resulting callable
  # takes the same three positional arguments as the other baseline functions.
  sampled_gumbel_max = functools.partial(
      coupling_util.joint_from_samples,
      coupling_util.gumbel_max_sampler,
      num_samples=samples_per_pair,
      loop_size=loop_size)

  # These baselines ignore their third argument. It is presumably a PRNG key
  # supplied by `evaluate_joint` -- NOTE(review): confirm against its definition.
  def independent(p, q, _):
    return coupling_util.independent_coupling(p, q)

  def icdf(p, q, _):
    return coupling_util.inverse_cdf_coupling(p, q)

  def permuted_icdf(p, q, _):
    return coupling_util.permuted_inverse_cdf_coupling(p, q)

  # Each entry is (name, joint function, extra keyword args for
  # evaluate_joint). The list order fixes both the evaluation order and the
  # key order of the returned dict.
  baselines = [
      ("Independent", independent, {}),
      ("ICDF", icdf, {}),
      ("ICDF (permuted)", permuted_icdf, {}),
      ("Gumbel-max", sampled_gumbel_max,
       {"joint_correction_num_samples": samples_per_pair}),
  ]
  return {
      name: evaluate_joint(joint_fn, experiment, seed, num_pairs, **extra)
      for name, joint_fn, extra in baselines
  }
logits_1 -= jax.scipy.special.logsumexp(logits_1) logits_2 -= jax.scipy.special.logsumexp(logits_2) probs_1 = jnp.exp(logits_1) probs_2 = jnp.exp(logits_2) independent_coupling = probs_1[:, None] * probs_2[None, :] gumbel_max_estimate = coupling_util.joint_from_samples( coupling_util.gumbel_max_sampler, logits_1, logits_2, vis_key, num_joint_samples, loop_size=500) icdf = coupling_util.inverse_cdf_coupling(logits_1, logits_2) icdf_perm = coupling_util.permuted_inverse_cdf_coupling(logits_1, logits_2) couplings = { "Independent": independent_coupling, "ICDF": icdf, "ICDF (permuted)": icdf_perm, "Gumbel-max": gumbel_max_estimate, } for experiment, result in zip(experiments, results): couplings[experiment.name] = coupling_util.joint_from_samples( experiment.build_sampler(result.params), logits_1, logits_2, vis_key, num_joint_samples,