def validate_reco_method(basis_parameters,
                         verification_parameters,
                         base_equations=combination_utils.full_scan_terms,
                         name_suffix='',
                         title_suffix=''):
    """Compare reco-level linear combinations against generated samples.

    The reconstructed m_HH histograms of the basis samples are combined,
    via the amplitude function built from ``basis_parameters``, at every
    point in ``verification_parameters`` and plotted next to the histogram
    of the sample generated directly at that point.

    Args:
        basis_parameters: coupling points whose samples form the basis.
        verification_parameters: coupling points that have generated
            samples; iterated in lockstep with the events returned by
            ``fileio_utils.get_events`` for them.
        base_equations: term definitions passed to ``get_amplitude_function``.
        name_suffix: appended to the output plot name.
        title_suffix: appended to the plot title.
    """
    amplitude_vector = get_amplitude_function(basis_parameters,
                                              as_scalar=False,
                                              base_equations=base_equations)
    # 55 edges -> 54 equal-width bins spanning 200-2000 GeV.
    bin_edges = numpy.linspace(200, 2000, 55)

    coupling_file = fileio_utils.read_coupling_file()
    basis_events = fileio_utils.get_events(basis_parameters, coupling_file)
    generated_events = fileio_utils.get_events(verification_parameters,
                                               coupling_file)

    basis_hists = []
    for events in basis_events:
        basis_hists.append(fileio_utils.retrieve_reco_weights(bin_edges, events))
    # Regroup the per-sample (weights, errors) pairs into one stacked
    # weights array and one stacked errors array.
    stacked_weights, stacked_errors = numpy.array(list(zip(*basis_hists)))

    for events, couplings in zip(generated_events, verification_parameters):
        generated_weights, generated_errors = (
            fileio_utils.retrieve_reco_weights(bin_edges, events))
        lc_weights, lc_errors = reco_reweight(amplitude_vector, couplings,
                                              stacked_weights, stacked_errors)

        plot_histogram('reco_mHH' + name_suffix,
                       'NNT-Based Linear Combination:\n$m_{HH}$' + title_suffix,
                       bin_edges,
                       couplings,
                       lc_weights,
                       lc_errors,
                       generated_weights,
                       generated_errors,
                       xlabel='Reconstructed $m_{HH}$ (GeV)')
def view_reco_method(basis_parameters, view_params):
    """Preview reco-level linear combinations at arbitrary coupling points.

    No generated sample is compared against.  Each point in ``view_params``
    is printed to stdout and then plotted under the name
    'preview_reco_mHH_new'.

    Args:
        basis_parameters: coupling points whose samples form the basis.
        view_params: iterable of coupling points to preview.
    """
    amplitude_vector = get_amplitude_function(basis_parameters, as_scalar=False)
    # 31 edges -> 30 equal-width bins spanning 200-1200 GeV.
    bin_edges = numpy.linspace(200, 1200, 31)

    coupling_file = fileio_utils.read_coupling_file()
    basis_events = fileio_utils.get_events(basis_parameters, coupling_file)
    # Histogram every basis sample, then regroup the (weights, errors)
    # pairs into one stacked weights array and one stacked errors array.
    stacked_weights, stacked_errors = numpy.array(list(zip(*[
        fileio_utils.retrieve_reco_weights(bin_edges, events)
        for events in basis_events
    ])))

    for couplings in view_params:
        print(couplings)
        lc_weights, lc_errors = reco_reweight(amplitude_vector, couplings,
                                              stacked_weights, stacked_errors)
        plot_histogram('preview_reco_mHH_new',
                       'NNT-Based Linear Combination:\n$m_{HH}$',
                       bin_edges,
                       couplings,
                       lc_weights,
                       lc_errors,
                       xlabel='Reconstructed $m_{HH}$ (GeV)')
def compare12_reco_method(basis_parameters,
                          k2v_basis_parameters,
                          kl_basis_parameters,
                          verification_parameters,
                          base_equations=combination_utils.full_scan_terms,
                          name_suffix='',
                          title_suffix=''):
    """Compare the full linear combination with the 1D k2v / kl combinations.

    For each verification point lying on the k2v axis (k2v != 1, kl == kv == 1)
    or on the kl axis (kl != 1, k2v == kv == 1), plot the reco m_HH histogram
    of the generated sample together with the combination from the full
    basis and the combination from the matching single-coupling basis.
    Every other point, including (1, 1, 1), is skipped.

    Args:
        basis_parameters: coupling points of the full basis.
        k2v_basis_parameters: coupling points of the k2v-only basis
            (combined with ``combination_utils.k2v_scan_terms``).
        kl_basis_parameters: coupling points of the kl-only basis
            (combined with ``combination_utils.kl_scan_terms``).
        verification_parameters: ``(k2v, kl, kv)`` coupling points that have
            generated samples; may be tuples, lists or numpy arrays.
        base_equations: term definitions for the full-basis amplitude.
        name_suffix: appended to the output plot name.
        title_suffix: appended to the plot title.
    """

    def stack_histograms(edges, events_list):
        # Histogram each sample and regroup the per-sample (weights, errors)
        # pairs into one stacked weights array and one stacked errors array.
        histograms = [
            fileio_utils.retrieve_reco_weights(edges, events)
            for events in events_list
        ]
        return numpy.array(list(zip(*histograms)))

    reweight_vector = get_amplitude_function(basis_parameters,
                                             as_scalar=False,
                                             base_equations=base_equations)
    k2v_reweight_vector = get_amplitude_function(
        k2v_basis_parameters,
        as_scalar=False,
        base_equations=combination_utils.k2v_scan_terms)
    kl_reweight_vector = get_amplitude_function(
        kl_basis_parameters,
        as_scalar=False,
        base_equations=combination_utils.kl_scan_terms)

    # 55 edges -> 54 equal-width bins spanning 200-2000 GeV.
    var_edges = numpy.linspace(200, 2000, 55)

    data_files = fileio_utils.read_coupling_file()
    base_events_list = fileio_utils.get_events(basis_parameters, data_files)
    k2v_base_events_list = fileio_utils.get_events(k2v_basis_parameters,
                                                   data_files)
    kl_base_events_list = fileio_utils.get_events(kl_basis_parameters,
                                                  data_files)
    verification_events_list = fileio_utils.get_events(verification_parameters,
                                                       data_files)

    base_weights, base_errors = stack_histograms(var_edges, base_events_list)
    k2v_base_weights, k2v_base_errors = stack_histograms(
        var_edges, k2v_base_events_list)
    kl_base_weights, kl_base_errors = stack_histograms(
        var_edges, kl_base_events_list)

    for verification_events, coupling_parameters in zip(
            verification_events_list, verification_parameters):
        k2v, kl, kv = coupling_parameters
        # Test the unpacked scalars rather than comparing the whole point to
        # a tuple literal: a list never equals a tuple (so (1, 1, 1) would
        # slip through), and a numpy-array comparison yields an array whose
        # truth value is ambiguous.
        on_k2v_axis = k2v != 1 and kl == 1 and kv == 1
        on_kl_axis = kl != 1 and k2v == 1 and kv == 1
        if not (on_k2v_axis or on_kl_axis):
            continue

        # Exactly one axis predicate holds here; pick the matching 1D basis.
        if on_k2v_axis:
            alt_combined_weights, alt_combined_errors = reco_reweight(
                k2v_reweight_vector, coupling_parameters, k2v_base_weights,
                k2v_base_errors)
        else:
            alt_combined_weights, alt_combined_errors = reco_reweight(
                kl_reweight_vector, coupling_parameters, kl_base_weights,
                kl_base_errors)

        verification_weights, verification_errors = fileio_utils.retrieve_reco_weights(
            var_edges, verification_events)
        combined_weights, combined_errors = reco_reweight(
            reweight_vector, coupling_parameters, base_weights, base_errors)

        plot_histogram(
            'reco_mHH_1-2D_compare' + name_suffix,
            'NNT-Based Linear Combination:\n$m_{HH}$' + title_suffix,
            var_edges,
            coupling_parameters,
            combined_weights,
            combined_errors,
            verification_weights,
            verification_errors,
            alt_linearly_combined_weights=alt_combined_weights,
            alt_linearly_combined_errors=alt_combined_errors,
            generated_label='3D Combination',
            xlabel='Reconstructed $m_{HH}$ (GeV)',
        )
def compare_bases_reco_method(basis_parameters_list,
                              verification_parameters,
                              base_equations=combination_utils.full_scan_terms,
                              name_suffix='',
                              title_suffix='',
                              labels=('', ''),
                              is_verification=True,
                              truth_level=False,
                              truth_data_files=None):
    """Overlay the linear combinations produced by different bases.

    Each entry of ``basis_parameters_list`` is turned into an amplitude
    function plus per-sample m_HH histograms.  These are reconstructed from
    NNT events, or taken from ``fileio_utils.extract_lhe_truth_data`` when
    ``truth_level`` is set.  For every test point, the first basis is drawn
    as the primary combination and, if a second basis is given, it is drawn
    as the alternate one.  Any further bases are combined but not drawn.

    Args:
        basis_parameters_list: sequence of bases (each a list of coupling
            points).
        verification_parameters: coupling points to plot.
        base_equations: term definitions passed to ``get_amplitude_function``.
        name_suffix: appended to the output plot name.
        title_suffix: appended to the plot title.
        labels: ``(primary_label, alternate_label)`` for the legend; the
            alternate label is only used when a second basis exists.
        is_verification: if true, load the generated sample at each point
            and plot it for comparison; otherwise points are only previewed.
        truth_level: use truth-level histograms instead of reco NNT events.
        truth_data_files: mapping from coupling point to the entry handed to
            ``fileio_utils.extract_lhe_truth_data``; required when
            ``truth_level`` is true.

    Raises:
        ValueError: if ``truth_level`` is true but ``truth_data_files`` is None.
    """
    if truth_level and truth_data_files is None:
        raise ValueError('truth_level=True requires truth_data_files')

    # 55 edges -> 54 equal-width bins spanning 200-2000 GeV.
    var_edges = numpy.linspace(200, 2000, 55)

    # The coupling file depends only on ``truth_level``, so read it once
    # instead of once per basis; it also locates the verification samples.
    if truth_level:
        data_files = fileio_utils.read_coupling_file(
            coupling_file='basis_files/truth_LHE_couplings_extended.dat')
    else:
        data_files = fileio_utils.read_coupling_file()

    basis_tuple_list = []
    for basis_parameters in basis_parameters_list:
        reweight_vector = get_amplitude_function(basis_parameters,
                                                 as_scalar=False,
                                                 base_equations=base_equations)
        if truth_level:
            basis_files = [
                truth_data_files[coupling] for coupling in basis_parameters
            ]
            base_weights, base_errors = fileio_utils.extract_lhe_truth_data(
                basis_files, var_edges)
        else:
            base_events_list = fileio_utils.get_events(basis_parameters,
                                                       data_files)
            base_histograms = [
                fileio_utils.retrieve_reco_weights(var_edges, base_events)
                for base_events in base_events_list
            ]
            # Regroup per-sample (weights, errors) pairs into two stacked arrays.
            base_weights, base_errors = numpy.array(list(
                zip(*base_histograms)))
        basis_tuple_list.append((base_weights, base_errors, reweight_vector))

    # Normalise every test point to (couplings, weights, errors); the
    # comparison histogram is None when there is nothing to verify against.
    if not is_verification:
        testpoint_list = [
            (param, None, None) for param in verification_parameters
        ]
    elif truth_level:
        # NOTE(review): basis histograms come from ``truth_data_files`` but
        # verification ones from the truth coupling file -- confirm intended.
        verification_files = [
            data_files[key] for key in verification_parameters
        ]
        truth_verification_weights, truth_verification_errors = fileio_utils.extract_lhe_truth_data(
            verification_files, var_edges)
        testpoint_list = zip(verification_parameters,
                             truth_verification_weights,
                             truth_verification_errors)
    else:
        testpoint_list = []
        verification_events_list = fileio_utils.get_events(
            verification_parameters, data_files)
        for events, param in zip(verification_events_list,
                                 verification_parameters):
            verification_weights, verification_errors = fileio_utils.retrieve_reco_weights(
                var_edges, events)
            testpoint_list.append(
                (param, verification_weights, verification_errors))

    # Plot labelling is loop-invariant.
    if truth_level:
        name = 'truth_mHH_compare' + name_suffix
        title = 'Truth LHE-Based Linear Combination:\nTruth $m_{HH}$' + title_suffix
        xlabel = 'Truth $m_{HH}$ (GeV)'
    else:
        name = 'reco_mHH_compare' + name_suffix
        title = 'NNT-Based Linear Combination:\n$m_{HH}$' + title_suffix
        xlabel = 'Reconstructed $m_{HH}$ (GeV)'

    for coupling_parameters, verification_weights, verification_errors in testpoint_list:
        combined_tuples = [
            reco_reweight(reweight_vector, coupling_parameters, base_weights,
                          base_errors)
            for base_weights, base_errors, reweight_vector in basis_tuple_list
        ]
        primary_weights, primary_errors = combined_tuples[0]

        # Only pass the alternate overlay when a second basis exists, so a
        # single-basis call falls back to plot_histogram's defaults instead
        # of raising IndexError.
        alt_kwargs = {}
        if len(combined_tuples) > 1:
            alt_weights, alt_errors = combined_tuples[1]
            alt_kwargs = {
                'alt_linearly_combined_weights': alt_weights,
                'alt_linearly_combined_errors': alt_errors,
                'alt_label': labels[1],
            }

        plot_histogram(
            name,
            title,
            var_edges,
            coupling_parameters,
            primary_weights,
            primary_errors,
            verification_weights,
            verification_errors,
            generated_label=labels[0],
            xlabel=xlabel,
            **alt_kwargs,
        )
def compare1D3S9S_reco_method(k2v_3S_basis_parameters, k2v_9S_basis_tuple):
    """Compare a single k2v basis against a piecewise set of k2v sub-bases.

    Steps:
    1. Run ``generate_1D9S_pojection_scans``.
    2. Draw a k2v-vs-m_HH heatmap from the single basis.
    3. For 11 evenly spaced k2v values in [-2, 4] (kl = kv = 1), overlay two
       reco m_HH combinations: one from the single basis, and one from the
       sub-basis selected for that k2v.

    Args:
        k2v_3S_basis_parameters: coupling points of the single k2v basis.
        k2v_9S_basis_tuple: pair ``(index_bounds, k2v_lists)``.  Each entry
            of ``k2v_lists`` is a list of k2v values, turned into
            ``(k2v, 1, 1)`` couplings to form one sub-basis.
            ``index_bounds[0]`` and ``index_bounds[1]`` are inclusive upper
            k2v limits for sub-bases 0 and 1; larger values use sub-basis 2.
    """

    def stack_histograms(edges, events_list):
        # Histogram each sample and regroup the per-sample (weights, errors)
        # pairs into one stacked weights array and one stacked errors array.
        return numpy.array(list(zip(*[
            fileio_utils.retrieve_reco_weights(edges, events)
            for events in events_list
        ])))

    def pick_sub_basis(k2v, bounds):
        # Inclusive upper limits: <= bounds[0] -> 0, <= bounds[1] -> 1, else 2.
        if k2v <= bounds[0]:
            return 0
        if k2v <= bounds[1]:
            return 1
        return 2

    color_min, color_max = 1e-5, 5
    generate_1D9S_pojection_scans(k2v_9S_basis_tuple, color_min, color_max)

    heatmap_edges = numpy.linspace(200, 1200, 31)  # 30 bins, 200-1200 GeV
    overlay_edges = numpy.linspace(200, 2000, 55)  # 54 bins, 200-2000 GeV
    scan_bins = 10
    scan_k2v_values = numpy.linspace(-2, 4, scan_bins + 1)
    heatmap_k2v_values = numpy.linspace(-2, 4, 100 + 1)

    coupling_file = fileio_utils.read_coupling_file()
    single_reweight_vector = get_amplitude_function(
        k2v_3S_basis_parameters,
        as_scalar=False,
        base_equations=combination_utils.k2v_scan_terms)
    single_events = fileio_utils.get_events(k2v_3S_basis_parameters,
                                            coupling_file)
    single_weights, single_errors = stack_histograms(overlay_edges,
                                                     single_events)
    heatmap_weights, heatmap_errors = stack_histograms(heatmap_edges,
                                                       single_events)
    draw_1D_mhh_heatmap(k2v_3S_basis_parameters,
                        heatmap_weights,
                        heatmap_edges,
                        heatmap_k2v_values,
                        1,
                        1,
                        base_equations=combination_utils.k2v_scan_terms,
                        which_coupling='k2v',
                        filename='projectionscan_k2v_multicompare',
                        title_suffix='Using Single Basis',
                        vrange=(color_min, color_max))

    sub_bases = []
    for k2v_list in k2v_9S_basis_tuple[1]:
        sub_parameters = [(value, 1, 1) for value in k2v_list]
        sub_events = fileio_utils.get_events(sub_parameters, coupling_file)
        sub_weights, sub_errors = stack_histograms(overlay_edges, sub_events)
        sub_reweight_vector = combination_utils.get_amplitude_function(
            sub_parameters,
            as_scalar=False,
            base_equations=combination_utils.k2v_scan_terms)
        sub_bases.append((sub_weights, sub_errors, sub_reweight_vector))

    bounds = k2v_9S_basis_tuple[0]
    for k2v in scan_k2v_values:
        couplings = [k2v, 1, 1]
        single_lc_weights, single_lc_errors = reco_reweight(
            single_reweight_vector, couplings, single_weights, single_errors)

        chosen_weights, chosen_errors, chosen_reweight_vector = sub_bases[
            pick_sub_basis(k2v, bounds)]
        multi_lc_weights, multi_lc_errors = reco_reweight(
            chosen_reweight_vector, couplings, chosen_weights, chosen_errors)

        view_linear_combination.plot_histogram(
            'preview_reco_mHH_multibasis',
            'NNT-Based Linear Combination:\n$m_{HH}$',
            overlay_edges,
            couplings,
            single_lc_weights,
            single_lc_errors,
            alt_linearly_combined_weights=multi_lc_weights,
            alt_linearly_combined_errors=multi_lc_errors,
            alt_label='3-Basis Set',
            generated_label='1-Basis Equation',
            xlabel='Reconstructed $m_{HH}$ (GeV)',
        )