# Assumed imports for this section (inferred from usage; the surrounding module may
# already provide them):
import matplotlib.pyplot as plt
from jmctf import JointDistribution
import jmctf.common as c


def test_plot_quad_logl(analysis, pars, samples, curve_par, test_name):
    """Create plots of test logl curves, comparing direct fits with 'quad' versions"""
    print("pars:", pars)
    print("samples:", samples)

    # We want to fit all samples for all parameters, so we need to make sure the batch
    # shapes can be broadcast against each other. The easiest way is to insert some extra
    # dimensions into both (since they both start as 1D batches).
    pars_batch = c.deep_expand_dims(pars, axis=0)
    samples_batch = c.deep_expand_dims(samples, axis=1)

    joint = JointDistribution([analysis], pars_batch)
    log_prob, joint_fitted_nuisance, fitted_pars_nuisance = joint.fit_nuisance(samples_batch)
    print("log_prob:", log_prob)
    print("pars_batch:", pars_batch)
    print("fitted_pars_nuisance['fitted']:", fitted_pars_nuisance['fitted'])
    print("fitted_pars_nuisance['fixed']:", fitted_pars_nuisance['fixed'])

    # The 'quad' log_prob, expanding about every nuisance BF point for every hypothesis
    # (i.e. NOT what is done in lee-correction). So the following happens:
    # 1. (2,20) shape nuisance parameter fits are obtained (2 samples * 20 hypotheses,
    #    broadcast against each other). These become the expansion points in the
    #    log_prob_quad evaluation.
    # 2. (2,1) shaped samples are provided to create the log_prob_quad_f function.
    # 3. (1,20) parameters are provided for the log-likelihood evaluation.
    #    These are the same parameters used as input to the fits, so evaluation should
    #    occur exactly at the expansion points.
    # 4. The result has shape (2,20).
    f = joint_fitted_nuisance.log_prob_quad_f(samples_batch)
    # Could also pass fitted_pars_nuisance['fixed'] here; the 'fixed' parameters include
    # the 'signal' ones, so it is the same thing as pars_batch.
    log_prob_quad = f(pars_batch)

    # The 'quad' log_prob, expanding just once about the global BF point amongst the input
    # hypotheses, per sample (i.e. more or less what IS done in lee-correction; there we
    # actually use a null-hypothesis point rather than the BF, but the point is that there
    # is just one expansion point per sample). So the following happens:
    # 1. (2,1) shape nuisance parameter fits are obtained (2 samples * 1 hypothesis).
    #    These become the expansion points in the log_prob_quad evaluation.
    # 2. (2,1) shaped samples are provided to create the log_prob_quad_f function.
    # 3. (1,20) parameters are provided for the log-likelihood evaluation.
    #    These are DIFFERENT parameters from those used as input to the fits, so a more
    #    non-trivial evaluation of the log_prob_quad function occurs. The results will of
    #    course be less accurate, but this is the "real-world" use case, so we make plots
    #    to check the accuracy.
    log_prob_g, joint_fitted_all, fitted_pars_all = joint.fit_all(samples_batch)
    print("log_prob_g:", log_prob_g)
    print("fitted_pars_all['fitted']:", fitted_pars_all['fitted'])
    print("fitted_pars_all['fixed']:", fitted_pars_all['fixed'])
    f2 = joint_fitted_all.log_prob_quad_f(samples_batch)
    log_prob_quad_2 = f2(pars_batch)

    print("log_prob:", log_prob)
    print("log_prob_quad (expanded from exact signal points):", log_prob_quad)
    print("log_prob_quad_2 (global BF expansion):", log_prob_quad_2)

    # OK, let's make some plots!
    fig = plt.figure()
    ax = fig.add_subplot(111)

    # Plot a curve for each sample (0th axis of batch)
    if isinstance(curve_par, str):
        cpar, index = (curve_par, None)
    else:
        try:
            cpar, index = curve_par
        except ValueError as e:
            msg = ("Failed to interpret curve 'parameter' specification! "
                   "Needs to be either a string, or a (string, index) tuple indicating "
                   "which parameter (and which index, if multivariate) is the one that "
                   "varies for this test!")
            raise ValueError(msg) from e

    if index is None:
        x = pars[analysis.name][cpar]
    else:
        x = pars[analysis.name][cpar][:, index]

    first = True
    for y, y_quad_1, y_quad_2 in zip(log_prob, log_prob_quad, log_prob_quad_2):
        if first:
            ax.plot(x, y, c='k', label="Full numerical profiling")
            ax.plot(x, y_quad_1, c='g', ls='--',
                    label="\"quad\" expansion at profiled points (i.e. no real expansion done)")
            ax.plot(x, y_quad_2, c='r', ls='--',
                    label="\"quad\" expansion around single global best fit per sample")
            first = False
        else:
            # No labels this time
            ax.plot(x, y, c='k')
            ax.plot(x, y_quad_1, c='g', ls='--')
            ax.plot(x, y_quad_2, c='r', ls='--')

    ax.set_ylabel("log_prob")
    ax.set_xlabel(curve_par)
    ax.set_title("log_prob_quad curve test for analysis {0}, parameter {1}".format(
        analysis.name, curve_par))
    ax.legend(loc=0, frameon=False, framealpha=0, prop={'size': 10}, ncol=1)
    plt.tight_layout()
    fig.savefig("unit_test_output/log_prob_quad_comparison_{0}.png".format(test_name))
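
# A minimal standalone sketch of what the 'quad' curves above are testing (an illustration
# only; it does not use the jmctf API, and toy_log_prob/toy_log_prob_quad are invented
# names). log_prob_quad replaces the full log-likelihood with its second-order Taylor
# expansion about an expansion point. For a Gaussian log-likelihood the expansion is exact,
# which is why the green curve ("expansion at the profiled points") should track the full
# numerical result, while a single global expansion point (red curve) is only approximate
# once the likelihood is non-Gaussian in the parameters.
import numpy as np

def toy_log_prob(theta, x=4.3):
    # Toy Gaussian log-likelihood for one observed value x, as a function of theta
    return -0.5 * (x - theta)**2 - 0.5 * np.log(2 * np.pi)

def toy_log_prob_quad(theta, theta0, x=4.3, eps=1e-4):
    # Second-order Taylor expansion of toy_log_prob about theta0 (numerical derivatives)
    f0 = toy_log_prob(theta0, x)
    grad = (toy_log_prob(theta0 + eps, x) - toy_log_prob(theta0 - eps, x)) / (2 * eps)
    hess = (toy_log_prob(theta0 + eps, x) - 2 * f0 + toy_log_prob(theta0 - eps, x)) / eps**2
    d = theta - theta0
    return f0 + grad * d + 0.5 * hess * d**2

theta_grid = np.linspace(0., 8., 5)
# Exact for a Gaussian: expanding about the best fit (theta0 = x) reproduces the full curve
assert np.allclose(toy_log_prob(theta_grid), toy_log_prob_quad(theta_grid, theta0=4.3))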
# Fit to a single sample ('joint', a JointDistribution, and 'sig_t' are assumed to have
# been constructed earlier)
my_sample = {
    'Test normal::x': 4.3,
    'Test normal::x_theta': 0,
    'Test binned::n': [9, 53],
    'Test binned::x': [0, 0]
}

# Convert standard numeric types to TensorFlow objects (must be float32)
#my_sample = {k1: {k2: tf.constant(x,dtype="float32") for k2,x in inner.items()} for k1,inner in my_sample.items()}

# Extra fixed parameter for the NormalDist analysis
fixed_pars = {"Test normal": {"sigma_t": sig_t}}

print("First 'fit_all'")
q, joint_fitted, par_dicts = joint.fit_all(my_sample)  # ,fixed_pars)
print("q:", q)
print(par_dicts)

# The output is not so pretty because the parameters are TensorFlow objects.
# We can convert them to numpy arrays for better viewing:
def to_numpy(d):
    """Recursively convert a (possibly nested) dict of TensorFlow tensors to numpy arrays"""
    out = {}
    for k, v in d.items():
        if isinstance(v, dict):
            out[k] = to_numpy(v)
        else:
            out[k] = v.numpy()
    return out

print(to_numpy(par_dicts["all"]))
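
# For reference, the fixed-parameter variant suggested by the commented-out argument above
# (a sketch; it assumes fit_all accepts the fixed-parameter dictionary as a second argument):
q_fixed, joint_fitted_fixed, par_dicts_fixed = joint.fit_all(my_sample, fixed_pars)
print("q (with sigma_t held fixed):", q_fixed)
print(to_numpy(par_dicts_fixed["all"]))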