Example #1
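# Assumed context for this excerpt (the source does not show its imports): matplotlib.pyplot is
# imported as plt, JointDistribution and the helper module 'c' (providing deep_expand_dims) come
# from the package under test, and 'analysis', 'pars', 'samples', 'curve_par' and 'test_name'
# are supplied as pytest-style fixtures/arguments.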
def test_plot_quad_logl(analysis, pars, samples, curve_par, test_name):
    """Create plots of test logl curves, comparing direct fits with 'quad' versions"""
    print("pars:", pars)
    print("samples:", samples)

    # We want to fit all samples for all parameters, so we need to make sure the batch shapes
    # can be broadcast against each other. Easiest way is to insert some extra dimensions into
    # both (since they are both 1D batches to start with)
    pars_batch = c.deep_expand_dims(pars, axis=0)
    samples_batch = c.deep_expand_dims(samples, axis=1)
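    # For example (a sketch based on the batch shapes quoted in the comments below, i.e. 2 samples
    # and 20 hypotheses):
    #   pars    batch shape (20,)  ->  pars_batch    batch shape (1, 20)
    #   samples batch shape (2,)   ->  samples_batch batch shape (2, 1)
    # so the two broadcast together to a (2, 20) batch.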

    joint = JointDistribution([analysis], pars_batch)
    log_prob, joint_fitted_nuisance, fitted_pars_nuisance = joint.fit_nuisance(
        samples_batch)
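    # fit_nuisance profiles only the nuisance parameters, with the 'signal' parameters held at
    # pars_batch, so (per the shapes above) there is one nuisance fit per (sample, hypothesis)
    # pair and log_prob should come out with batch shape (2, 20).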

    print("log_prob:", log_prob)
    print("pars_batch:", pars_batch)
    print("fitted_pars_nuisance['fitted']:", fitted_pars_nuisance['fitted'])
    print("fitted_pars_nuisance['fixed']:", fitted_pars_nuisance['fixed'])
    print("pars_batch:", pars_batch)

    # The 'quad' log_prob, expanding about every nuisance BF point for every hypothesis (i.e. NOT what is done in lee-correction)
    # So the following happens:
    # 1. (2,20) shape nuisance parameter fits obtained (2 samples * 20 hypotheses, broadcast against each other)
    #    These become the expansion points in the log_prob_quad evaluation
    # 2. (2,1) shaped samples are provided to create the log_prob_quad_f function
    # 3. (1,20) parameters are provided for log-likelihood evaluation
    #    These are the same parameters used as input to the fits, so should cause evaluation to occur exactly at the expansion points
    # 4. Result has shape (2,20)
    f = joint_fitted_nuisance.log_prob_quad_f(samples_batch)
    # Could equivalently pass fitted_pars_nuisance['fixed'] here, since the 'fixed' parameters
    # include the 'signal' ones, i.e. it is the same thing as pars_batch.
    log_prob_quad = f(pars_batch)
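    # Because the evaluation parameters coincide with the expansion points, log_prob_quad should
    # reproduce log_prob up to numerical error. A possible sanity check (a sketch only; assumes
    # numpy is imported as np, and the tolerance is chosen arbitrarily):
    #   assert np.allclose(log_prob.numpy(), log_prob_quad.numpy(), atol=1e-4)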

    # The 'quad' log_prob, expanding just once about the global BF point amongst input hypotheses, per sample
    # (i.e. what IS done in lee-correction, more or less. Actually we use a null-hypothesis point rather than
    # the BF, but the point is there is just one expansion point per sample)
    # So the following happens:
    # 1. (2,1) shape nuisance parameter fits obtained (2 samples * 1 hypothesis)
    #    These become the expansion points in the log_prob_quad evaluation
    # 2. (2,1) shaped samples are provided to create the log_prob_quad_f function
    # 3. (1,20) parameters are provided for log-likelihood evaluation
    #    These are DIFFERENT parameters from those used as input to the fits, so they should cause a more
    #    non-trivial evaluation of the log_prob_quad function. Results will be less accurate, of course, but
    #    this is the "real-world" use case. Will make plots to check accuracy.
    log_prob_g, joint_fitted_all, fitted_pars_all = joint.fit_all(
        samples_batch)
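    # fit_all fits ALL free parameters (signal and nuisance together), so there is a single global
    # best-fit expansion point per sample, i.e. fits with batch shape (2, 1) as described above.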
    print("log_prob_g:", log_prob_g)
    print("fitted_pars_all['fitted']:", fitted_pars_all['fitted'])
    print("fitted_pars_all['fixed']:", fitted_pars_all['fixed'])

    f2 = joint_fitted_all.log_prob_quad_f(samples_batch)
    log_prob_quad_2 = f2(pars_batch)
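    # Away from that single expansion point the quadratic approximation is not exact, so log_prob_quad_2
    # is expected to deviate from log_prob across the hypothesis range; the plots below show the size of
    # that deviation.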

    print("log_prob:", log_prob)
    print("log_prob_quad   (expanded from exact signal points):",
          log_prob_quad)
    print("log_prob_quad_2 (global BF expansion):", log_prob_quad_2)

    # Ok let's make some plots!

    fig = plt.figure()
    ax = fig.add_subplot(111)

    # Plot curve for each sample (0th axis of batch)
    if isinstance(curve_par, str):
        cpar, index = (curve_par, None)
    else:
        try:
            cpar, index = curve_par
        except ValueError as e:
            msg = "Failed to interpret curve 'parameter' specification! Needs to be either a string, or a (string,index) tuple indicating which parameter (and which index if multivariate) is the one that varies for this test!"
            raise ValueError(msg) from e

    if index is None:
        x = pars[analysis.name][cpar]
    else:
        x = pars[analysis.name][cpar][:, index]

    first = True
    for y, y_quad_1, y_quad_2 in zip(log_prob, log_prob_quad, log_prob_quad_2):
        if first:
            ax.plot(x, y, c='k', label="Full numerical profiling")
            ax.plot(x, y_quad_1, c='g', ls='--',
                    label="\"quad\" expansion at profiled points (i.e. no real expansion done)")
            ax.plot(x, y_quad_2, c='r', ls='--',
                    label="\"quad\" expansion around single global best fit per sample")
            first = False
        else:
            # No labels this time
            ax.plot(x, y, c='k')
            ax.plot(x, y_quad_1, c='g', ls='--')
            ax.plot(x, y_quad_2, c='r', ls='--')

    ax.set_ylabel("log_prob")
    ax.set_xlabel(curve_par)
    ax.set_title(
        "log_prob_quad curve test for analysis {0}, parameter {1}".format(
            analysis.name, curve_par))
    ax.legend(loc=0, frameon=False, framealpha=0, prop={'size': 10}, ncol=1)
    plt.tight_layout()
    fig.savefig(
        "unit_test_output/log_prob_quad_comparison_{0}.png".format(test_name))
Example #2
# fit
my_sample = {
    'Test normal::x': 4.3,
    'Test normal::x_theta': 0,
    'Test binned::n': [9, 53],
    'Test binned::x': [0, 0]
}
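# Note: the sample keys appear to follow an "<analysis name>::<observable>" naming convention
# (cf. the "Test normal" analysis name used in fixed_pars below).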
# Convert standard numeric types to TensorFlow objects (must be float32)
#my_sample = {k: tf.constant(x, dtype="float32") for k, x in my_sample.items()}
fixed_pars = {
    "Test normal": {
        "sigma_t": sig_t
    }
}  # Extra fixed parameter for NormalDist analysis
print("First 'fit_all'")
q, joint_fitted, par_dicts = joint.fit_all(my_sample)  # fixed_pars could also be passed as a second argument here
print("q:", q)
print(par_dicts)


# The output is not so pretty because the parameters are TensorFlow objects
# We can convert them to numpy for better viewing:
def to_numpy(d):
    """Recursively convert a (possibly nested) dict of TensorFlow tensors to numpy arrays."""
    out = {}
    for k, v in d.items():
        if isinstance(v, dict):
            out[k] = to_numpy(v)
        else:
            out[k] = v.numpy()
    return out


print(to_numpy(par_dicts["all"]))
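
# Alternatively (a sketch; tf.nest from TensorFlow performs the same nested tensor-to-numpy
# conversion as to_numpy above in a single call):
import tensorflow as tf
print(tf.nest.map_structure(lambda t: t.numpy(), par_dicts["all"]))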