Esempio n. 1
0
def get_mmd_postprop(exp_desc, seed):
    """
    Calculates the MMD for a given Post Prop experiment.

    If a cached result exists (``<res_file>.pkl``), it is loaded and returned.
    Otherwise the squared MMD against the true samples is computed for every
    saved proposal except the first, and for the final posterior. The result
    is then saved to ``res_file`` and returned.

    :param exp_desc: experiment descriptor whose ``inf`` is a PostProp_Descriptor
    :param seed: selects the ``results/seed_<seed>`` and ``experiments/seed_<seed>`` folders
    :return: tuple ``(all_prop_mmds, post_err)``; a list of squared MMDs (one per
        proposal after the first) and the squared MMD of the posterior
    :raises misc.NonExistentExperiment: if the experiment directory is missing
    """
    assert isinstance(exp_desc.inf, ed.PostProp_Descriptor)
    res_file = os.path.join(root, 'results/seed_' + str(seed), exp_desc.get_dir(), 'mmd')

    # The existence check adds '.pkl'; presumably util.io.save/load add that
    # extension themselves -- NOTE(review): confirm against util.io.
    if os.path.exists(res_file + '.pkl'):
        all_prop_mmds, post_err = util.io.load(res_file)

    else:
        exp_dir = os.path.join(root, 'experiments/seed_' + str(seed), exp_desc.get_dir(), '0')

        if not os.path.exists(exp_dir):
            raise misc.NonExistentExperiment(exp_desc)

        true_samples = get_true_samples(seed)
        # Kernel scale shared by all MMD evaluations below.
        scale = util.math.median_distance(true_samples)

        all_proposals, posterior, _, _ = util.io.load(os.path.join(exp_dir, 'results'))

        def _sq_mmd_of(dist):
            # Draw samples from `dist` using the module-level rng and compare to the truth.
            samples = dist.gen(n_mcmc_samples, rng=rng)
            return two_sample.sq_maximum_mean_discrepancy(samples, true_samples, scale=scale)

        # The first proposal is deliberately skipped (NOTE(review): presumably the
        # prior -- confirm). Proposals are sampled in order before the posterior,
        # so the shared rng stream is consumed in the same order as before.
        all_prop_mmds = [_sq_mmd_of(proposal) for proposal in all_proposals[1:]]
        post_err = _sq_mmd_of(posterior)

        util.io.save((all_prop_mmds, post_err), res_file)

    return all_prop_mmds, post_err
Esempio n. 2
0
def get_mmd_snl(exp_desc, seed):
    """
    Calculates the MMD for a given SNL experiment.

    Returns the cached list of squared MMDs when ``<cache>.pkl`` exists.
    Otherwise it computes one squared MMD per sample set yielded by
    ``get_samples_snl(exp_desc, seed)``, saves the list, and returns it.
    """
    assert isinstance(exp_desc.inf, ed.SNL_Descriptor)
    cache_path = os.path.join(root, 'results/seed_' + str(seed), exp_desc.get_dir(), 'mmd')

    if os.path.exists(cache_path + '.pkl'):
        return util.io.load(cache_path)

    reference = get_true_samples(seed)
    bandwidth = util.math.median_distance(reference)

    mmds = [
        two_sample.sq_maximum_mean_discrepancy(samples, reference, scale=bandwidth)
        for samples in get_samples_snl(exp_desc, seed)
    ]

    util.io.save(mmds, cache_path)
    return mmds
Esempio n. 3
0
def get_mmd_snpe(exp_desc, seed):
    """
    Calculates the MMD for a given SNPE experiment.

    Loads the cached list of squared MMDs if ``<cache>.pkl`` exists.
    Otherwise it loads the saved posteriors from the experiment directory,
    skipping the first one. It then scores each remaining posterior by the
    squared MMD between its samples and the true samples, caches the list,
    and returns it.

    Raises misc.NonExistentExperiment when the experiment directory is absent.
    """
    assert isinstance(exp_desc.inf, ed.SNPE_MDN_Descriptor)
    cache_path = os.path.join(root, 'results/seed_' + str(seed), exp_desc.get_dir(), 'mmd')

    if os.path.exists(cache_path + '.pkl'):
        return util.io.load(cache_path)

    experiment_dir = os.path.join(root, 'experiments/seed_' + str(seed), exp_desc.get_dir(), '0')
    if not os.path.exists(experiment_dir):
        raise misc.NonExistentExperiment(exp_desc)

    reference = get_true_samples(seed)
    bandwidth = util.math.median_distance(reference)

    posteriors = util.io.load(os.path.join(experiment_dir, 'results'))[0:4]
    posteriors = posteriors[0] if len(posteriors) == 4 else None
    mmds = [
        two_sample.sq_maximum_mean_discrepancy(
            post.gen(n_mcmc_samples, rng=rng), reference, scale=bandwidth)
        for post in posteriors[1:]
    ]

    util.io.save(mmds, cache_path)
    return mmds
Esempio n. 4
0
def get_mmd_sl(exp_desc, seed):
    """
    Calculates the MMD for a given synth likelihood experiment.

    Returns the pair ``(err, n_sims)``. ``err`` is the squared MMD between the
    experiment's saved samples and the true samples. ``n_sims`` is the value
    stored alongside those samples. The pair is read from ``<cache>.pkl`` when
    it exists; otherwise it is computed and then saved there.

    Raises misc.NonExistentExperiment when the experiment directory is absent.
    """
    assert isinstance(exp_desc.inf, ed.SynthLik_Descriptor)
    cache_path = os.path.join(root, 'results/seed_' + str(seed), exp_desc.get_dir(), 'mmd')

    if os.path.exists(cache_path + '.pkl'):
        err, n_sims = util.io.load(cache_path)
        return err, n_sims

    experiment_dir = os.path.join(root, 'experiments/seed_' + str(seed), exp_desc.get_dir(), '0')
    if not os.path.exists(experiment_dir):
        raise misc.NonExistentExperiment(exp_desc)

    generated, n_sims = util.io.load(os.path.join(experiment_dir, 'results'))
    reference = get_true_samples(seed)
    bandwidth = util.math.median_distance(reference)

    err = two_sample.sq_maximum_mean_discrepancy(generated, reference, scale=bandwidth)
    util.io.save((err, n_sims), cache_path)
    return err, n_sims
Esempio n. 5
0
def get_mmd_smc(exp_desc, seed):
    """
    Calculates the MMD for a given SMC ABC experiment.

    For each saved (samples, log_weights) round, it computes a weighted squared
    MMD against the true samples; the weights are ``exp(log_weights)``.
    Returns ``(all_mmds, all_n_sims)``, where ``all_n_sims`` is the value
    stored in the experiment results. The pair is read from ``<cache>.pkl``
    when present; otherwise it is computed and then saved.

    Raises misc.NonExistentExperiment when the experiment directory is absent.
    """
    assert isinstance(exp_desc.inf, ed.SMC_ABC_Descriptor)
    cache_path = os.path.join(root, 'results/seed_' + str(seed), exp_desc.get_dir(), 'mmd')

    if os.path.exists(cache_path + '.pkl'):
        all_mmds, all_n_sims = util.io.load(cache_path)
        return all_mmds, all_n_sims

    experiment_dir = os.path.join(root, 'experiments/seed_' + str(seed), exp_desc.get_dir(), '0')
    if not os.path.exists(experiment_dir):
        raise misc.NonExistentExperiment(exp_desc)

    reference = get_true_samples(seed)
    bandwidth = util.math.median_distance(reference)

    rounds_samples, rounds_log_weights, _, _, all_n_sims = \
        util.io.load(os.path.join(experiment_dir, 'results'))

    all_mmds = [
        two_sample.sq_maximum_mean_discrepancy(
            xs=round_samples, ys=reference, wxs=np.exp(round_log_w), scale=bandwidth)
        for round_samples, round_log_w in zip(rounds_samples, rounds_log_weights)
    ]

    util.io.save((all_mmds, all_n_sims), cache_path)
    return all_mmds, all_n_sims
Esempio n. 6
0
def calc_mmd(model):
    """
    Calculates MMD between true samples and a given likelihood model.

    Draws as many samples from ``model`` as there are true samples. It uses a
    freshly seeded ``RandomState(42)``, so repeated calls are reproducible.
    Returns the squared MMD computed with the scale supplied by ``get_truth()``.
    """
    _, reference, bandwidth = get_truth()
    n_draws = reference.shape[0]
    generated = model.gen(n_draws, rng=np.random.RandomState(42))
    result = two_sample.sq_maximum_mean_discrepancy(generated, reference, scale=bandwidth)
    return result
Esempio n. 7
0
def calc_mmd_cond(net):
    """
    Calculates MMD between true samples and a given conditional likelihood model.

    Conditions ``net`` on the true parameters from ``get_truth()``. It draws as
    many samples as there are true samples, using a freshly seeded
    ``RandomState(42)``, and returns their squared MMD against the true
    samples at the truth's scale.
    """
    params, reference, bandwidth = get_truth()
    n_draws = reference.shape[0]
    generator_rng = np.random.RandomState(42)
    generated = net.gen(params, n_draws, rng=generator_rng)
    result = two_sample.sq_maximum_mean_discrepancy(generated, reference, scale=bandwidth)
    return result
Esempio n. 8
0
def get_mmd_nde(exp_desc, seed):
    """
    Calculates the MMD for a given NDE experiment.

    Returns the squared MMD between ``get_samples_nde(exp_desc, seed)`` and the
    true samples. The value is read from ``<cache>.pkl`` when that file exists;
    otherwise it is computed and then saved.
    """
    assert isinstance(exp_desc.inf, ed.NDE_Descriptor)
    cache_path = os.path.join(root, 'results/seed_' + str(seed), exp_desc.get_dir(), 'mmd')

    if os.path.exists(cache_path + '.pkl'):
        return util.io.load(cache_path)

    generated = get_samples_nde(exp_desc, seed)
    reference = get_true_samples(seed)
    bandwidth = util.math.median_distance(reference)

    err = two_sample.sq_maximum_mean_discrepancy(generated, reference, scale=bandwidth)
    util.io.save(err, cache_path)
    return err