Example #1
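The listing omits its imports; a minimal set consistent with the calls in the class below might look like the following (Plotter and Utils are analysis-specific helpers, so their module is left as a placeholder):

import os
import pickle

import numpy as np
import pandas as pd
import xgboost as xgb
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split, KFold
from sklearn.metrics import roc_auc_score

# analysis-specific helpers; their module is not part of this listing
# from <analysis package> import Plotter, Utils
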
class BDTHelpers(object):
    """
    Functions to train a binary class BDT for signal/background separation

    :param data_obj: instance of the ROOTHelpers class, containing DataFrames for simulated signal, simulated background, and possibly data
    :type data_obj: ROOTHelpers
    :param train_vars: list of variables to train the BDT with
    :type train_vars: list
    :param train_frac: fraction of events used to train the network; the remaining 1-train_frac is used for testing
    :type train_frac: float
    :param eq_train: whether to train with the sum of signal weights scaled equal to the sum of background weights
    :type eq_train: bool
    """
    def __init__(self, data_obj, train_vars, train_frac, eq_train=True):

        self.data_obj = data_obj
        self.train_vars = train_vars
        self.train_frac = train_frac
        self.eq_train = eq_train

        self.X_train = None
        self.y_train = None
        self.train_weights = None
        self.train_weights_eq = None
        self.y_pred_train = None
        self.proc_arr_train = None

        self.X_test = None
        self.y_test = None
        self.test_weights = None
        self.y_pred_test = None
        self.proc_arr_test = None

        #attributes for the hp optimisation and cross validation
        self.clf = xgb.XGBClassifier(objective='binary:logistic',
                                     n_estimators=100,
                                     learning_rate=0.05,
                                     max_depth=4,
                                     min_child_weight=0.01,
                                     subsample=0.6,
                                     colsample_bytree=0.6,
                                     gamma=1)

        self.hp_grid_rnge = {
            'learning_rate': [0.01, 0.05, 0.1, 0.3],
            'max_depth': [x for x in range(3, 10)],
            #'min_child_weight':[x for x in range(0,3)], #FIXME: probs not appropriate range for a small sumw!
            'gamma': np.linspace(0, 5, 6).tolist(),
            'subsample': [0.5, 0.8, 1.0],
            'n_estimators': [200, 300, 400, 500]
        }

        self.X_folds_train = []
        self.y_folds_train = []
        self.X_folds_validate = []
        self.y_folds_validate = []
        self.w_folds_train = []
        self.w_folds_train_eq = []
        self.w_folds_validate = []
        self.validation_rocs = []

        #attributes for plotting.
        self.plotter = Plotter(data_obj, train_vars)
        self.sig_procs = np.unique(data_obj.mc_df_sig['proc']).tolist()
        self.bkg_procs = np.unique(data_obj.mc_df_bkg['proc']).tolist()
        del data_obj

    def create_X_and_y(self, mass_res_reweight=True):
        """
        Create X train/test and y train/test

        Arguments
        ---------
        mass_res_reweight: bool
           re-weight signal events by 1/sigma(m_ee), in training only.
        """

        mc_df_sig = self.data_obj.mc_df_sig
        mc_df_bkg = self.data_obj.mc_df_bkg

        #add y_target label
        mc_df_sig['y'] = np.ones(self.data_obj.mc_df_sig.shape[0]).tolist()
        mc_df_bkg['y'] = np.zeros(self.data_obj.mc_df_bkg.shape[0]).tolist()

        if self.eq_train:
            if mass_res_reweight:
                #be careful not to change the real weight variable, or test scores will be invalid
                mc_df_sig['MoM_weight'] = (mc_df_sig['weight']) * (
                    1. / mc_df_sig['dielectronSigmaMoM'])
                b_to_s_ratio = np.sum(mc_df_bkg['weight'].values) / np.sum(
                    mc_df_sig['MoM_weight'].values)
                mc_df_sig['eq_weight'] = (mc_df_sig['MoM_weight']) * (
                    b_to_s_ratio)
            else:
                b_to_s_ratio = np.sum(mc_df_bkg['weight'].values) / np.sum(
                    mc_df_sig['weight'].values)
                mc_df_sig['eq_weight'] = (mc_df_sig['weight']) * (b_to_s_ratio)
            mc_df_bkg['eq_weight'] = mc_df_bkg['weight']

            Z_tot = pd.concat([mc_df_sig, mc_df_bkg], ignore_index=True)
            X_train, X_test, train_w, test_w, train_w_eqw, test_w_eqw, y_train, y_test, proc_arr_train, proc_arr_test = train_test_split(
                Z_tot[self.train_vars],
                Z_tot['weight'],
                Z_tot['eq_weight'],
                Z_tot['y'],
                Z_tot['proc'],
                train_size=self.train_frac,
                #test_size=1-self.train_frac,
                shuffle=True,
                random_state=1357)
            #NB: will never test/evaluate with equalised weights. This is explicitly why we set another train weight attribute,
            #    because for overtraining we need to evaluate on the train set (and hence need nominal MC train weights)
            self.train_weights_eq = train_w_eqw.values
        elif mass_res_reweight:
            mc_df_sig['MoM_weight'] = (mc_df_sig['weight']) * (
                1. / mc_df_sig['dielectronSigmaMoM'])
            Z_tot = pd.concat([mc_df_sig, mc_df_bkg], ignore_index=True)
            X_train, X_test, train_w, test_w, train_w_eqw, test_w_eqw, y_train, y_test, proc_arr_train, proc_arr_test = train_test_split(
                Z_tot[self.train_vars],
                Z_tot['weight'],
                Z_tot['MoM_weight'],
                Z_tot['y'],
                Z_tot['proc'],
                train_size=self.train_frac,
                #test_size=1-self.train_frac,
                shuffle=True,
                random_state=1357)
            self.train_weights_eq = train_w_eqw.values
            self.eq_train = True  #use alternate weight in training. could probs rename this to something better
        else:
            print('not applying any reweighting...')
            Z_tot = pd.concat([mc_df_sig, mc_df_bkg], ignore_index=True)
            X_train, X_test, train_w, test_w, y_train, y_test, proc_arr_train, proc_arr_test = train_test_split(
                Z_tot[self.train_vars],
                Z_tot['weight'],
                Z_tot['y'],
                Z_tot['proc'],
                train_size=self.train_frac,
                #test_size=1-self.train_frac,
                shuffle=True,
                random_state=1357)

        self.X_train = X_train.values
        self.y_train = y_train.values
        self.train_weights = train_w.values
        self.proc_arr_train = proc_arr_train

        self.X_test = X_test.values
        self.y_test = y_test.values
        self.test_weights = test_w.values
        self.proc_arr_test = proc_arr_test

        self.X_data_train, self.X_data_test = train_test_split(
            self.data_obj.data_df[self.train_vars],
            train_size=self.train_frac,
            test_size=1 - self.train_frac,
            shuffle=True,
            random_state=1357)

    def create_X_and_y_three_class(self, third_class, mass_res_reweight=True):
        """
        Create X train/test and y train/test for a three-class BDT

        Arguments
        ---------
        mass_res_reweight: bool
           re-weight signal events by 1/sigma(m_ee), in training only.
        third_class: str
           name of the third (background) class for the classifier. The remaining classes are: 1) all other backgrounds, 2) signal
        """

        mc_df_sig = self.data_obj.mc_df_sig
        mc_df_bkg = self.data_obj.mc_df_bkg

        #add y_target label

        bkg_procs_key = [0, 1]
        bkg_procs_mask = []
        bkg_procs_mask.append(mc_df_bkg['proc'].ne(third_class))
        bkg_procs_mask.append(mc_df_bkg['proc'].eq(third_class))
        mc_df_bkg['y'] = np.select(bkg_procs_mask, bkg_procs_key)
        mc_df_sig['y'] = np.full(self.data_obj.mc_df_sig.shape[0], 2).tolist()

        if self.eq_train:
            if mass_res_reweight:
                #be careful not to change the real weight variable, or test scores will be invalid
                mc_df_sig['MoM_weight'] = (mc_df_sig['weight']) * (
                    1. / mc_df_sig['dielectronSigmaMoM'])
                bkg_sumw = np.sum(mc_df_bkg[mc_df_bkg.y == 0]['weight'].values)
                third_class_sumw = np.sum(
                    mc_df_bkg[mc_df_bkg.y == 1]['weight'].values)
                sig_sumw = np.sum(mc_df_sig['MoM_weight'].values)
                mc_df_sig['eq_weight'] = (mc_df_sig['MoM_weight']) * (
                    bkg_sumw / sig_sumw)
                mc_df_bkg.loc[mc_df_bkg.y == 1, 'weight'] = mc_df_bkg.loc[
                    mc_df_bkg.y == 1, 'weight'] * (bkg_sumw / third_class_sumw)
            else:
                #b_to_s_ratio = np.sum(mc_df_bkg['weight'].values)/np.sum(mc_df_sig['weight'].values)
                #mc_df_sig['eq_weight'] = (mc_df_sig['weight']) * (b_to_s_ratio)
                bkg_sumw = np.sum(mc_df_bkg[mc_df_bkg.y == 0]['weight'].values)
                third_class_sumw = np.sum(
                    mc_df_bkg[mc_df_bkg.y == 1]['weight'].values)
                sig_sumw = np.sum(mc_df_sig['weight'].values)
                mc_df_sig['eq_weight'] = (mc_df_sig['weight']) * (bkg_sumw /
                                                                  sig_sumw)
                mc_df_bkg.loc[mc_df_bkg.y == 1, 'weight'] = mc_df_bkg.loc[
                    mc_df_bkg.y == 1, 'weight'] * (bkg_sumw / third_class_sumw)
            mc_df_bkg['eq_weight'] = mc_df_bkg['weight']

            Z_tot = pd.concat([mc_df_sig, mc_df_bkg], ignore_index=True)
            X_train, X_test, train_w, test_w, train_w_eqw, test_w_eqw, y_train, y_test, proc_arr_train, proc_arr_test = train_test_split(
                Z_tot[self.train_vars],
                Z_tot['weight'],
                Z_tot['eq_weight'],
                Z_tot['y'],
                Z_tot['proc'],
                train_size=self.train_frac,
                #test_size=1-self.train_frac,
                shuffle=True,
                random_state=1357)
            #NB: will never test/evaluate with equalised weights. This is explicitly why we set another train weight attribute,
            #    because for overtraining we need to evaluate on the train set (and hence need nominal MC train weights)
            self.train_weights_eq = train_w_eqw.values
        elif mass_res_reweight:
            mc_df_sig['MoM_weight'] = (mc_df_sig['weight']) * (
                1. / mc_df_sig['dielectronSigmaMoM'])
            Z_tot = pd.concat([mc_df_sig, mc_df_bkg], ignore_index=True)
            X_train, X_test, train_w, test_w, train_w_eqw, test_w_eqw, y_train, y_test, proc_arr_train, proc_arr_test = train_test_split(
                Z_tot[self.train_vars],
                Z_tot['weight'],
                Z_tot['MoM_weight'],
                Z_tot['y'],
                Z_tot['proc'],
                train_size=self.train_frac,
                #test_size=1-self.train_frac,
                shuffle=True,
                random_state=1357)
            self.train_weights_eq = train_w_eqw.values
            self.eq_train = True  #use alternate weight in training. could probs rename this to something better
        else:
            print('not applying any reweighting...')
            Z_tot = pd.concat([mc_df_sig, mc_df_bkg], ignore_index=True)
            X_train, X_test, train_w, test_w, y_train, y_test, proc_arr_train, proc_arr_test = train_test_split(
                Z_tot[self.train_vars],
                Z_tot['weight'],
                Z_tot['y'],
                Z_tot['proc'],
                train_size=self.train_frac,
                #test_size=1-self.train_frac,
                shuffle=True,
                random_state=1357)
        self.X_train = X_train.values
        self.y_train = y_train.values
        self.train_weights = train_w.values
        self.proc_arr_train = proc_arr_train

        self.X_test = X_test.values
        self.y_test = y_test.values
        self.test_weights = test_w.values
        self.proc_arr_test = proc_arr_test

        self.X_data_train, self.X_data_test = train_test_split(
            self.data_obj.data_df[self.train_vars],
            train_size=self.train_frac,
            test_size=1 - self.train_frac,
            shuffle=True,
            random_state=1357)

    def train_classifier(self, file_path, save=False, model_name='my_model'):
        """
        Train the BDT and save the resulting classifier

        Arguments
        ---------
        file_path: string
            base file path used when saving model
        save: bool
            option to save the classifier
        model_name: string
            name of the model to be saved
        """

        if self.eq_train: train_weights = self.train_weights_eq
        else: train_weights = self.train_weights

        print('Training classifier... ')
        clf = self.clf.fit(self.X_train,
                           self.y_train,
                           sample_weight=train_weights)
        print('Finished Training classifier!')
        self.clf = clf

        Utils.check_dir(os.getcwd() + '/models')
        if save:
            pickle.dump(
                clf,
                open("{}/models/{}.pickle.dat".format(os.getcwd(), model_name),
                     "wb"))
            print("Saved classifier as: {}/models/{}.pickle.dat".format(
                os.getcwd(), model_name))

    def batch_gs_cv(self, k_folds=3, pt_rew=False):
        """
        Submit a set of hyper-parameter permutations (based on attribute hp_grid_rnge) to the IC batch.
        Perform k-fold cross validation; take care to separate the training weights, which
        may be modified w.r.t. the nominal weights, from the weights used when evaluating on the
        validation set, which should be the nominal weights

        Arguments
        ---------
        k_folds: int
            number of folds that the training+validation set are partitioned into
        """

        #get all possible HP sets from permutations of the above dict
        hp_perms = self.get_hp_perms()

        #submit job to the batch for the given HP range:
        for hp_string in hp_perms:
            Utils.sub_hp_script(self.eq_train, hp_string, k_folds, pt_rew)

    def get_hp_perms(self):
        """
        Return a list of all possible hyper parameter combinations (permutation template set in constructor) in format 'hp1:val1,hp2:val2, ...'

        Returns
        -------
        final_hps: all possible combinations of hyper-parameters given in self.hp_grid_rnge
        """

        from itertools import product

        hp_perms = list(product(*self.hp_grid_rnge.values()))
        final_hps = []
        counter = 0
        for hp_perm in hp_perms:
            l_entry = ''
            for hp_name, hp_value in zip(self.hp_grid_rnge.keys(), hp_perm):
                l_entry += '{}:{},'.format(hp_name, hp_value)
                counter += 1
                if (counter % len(self.hp_grid_rnge.keys())) == 0:
                    final_hps.append(l_entry[:-1])
        return final_hps
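
    # e.g. with a hypothetical reduced grid {'learning_rate': [0.01], 'max_depth': [3, 4]},
    # get_hp_perms() would return ['learning_rate:0.01,max_depth:3', 'learning_rate:0.01,max_depth:4']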

    def set_hyper_parameters(self, hp_string):
        """
        Set a given set of hyper-parameters for the classifier, and store the resulting classifier as a class attribute

        Arguments
        --------
        hp_string: string
            string of hyper-parameter values, with format 'hp1:val1,hp2:val2, ...'
        """

        hp_dict = {}
        for params in hp_string.split(','):
            hp_name = params.split(':')[0]
            hp_value = params.split(':')[1]
            try:
                hp_value = int(hp_value)
            except ValueError:
                hp_value = float(hp_value)
            hp_dict[hp_name] = hp_value
        self.clf = xgb.XGBClassifier(**hp_dict)

    def set_k_folds(self, k_folds):
        """
        Partition the X and y matrices into k_folds folds, and append each fold (X and y separately) to list attributes of the class, built from the training samples (i.e. X_train -> X_train + X_validate, and likewise for y and w)
        Used in conjunction with the get_i_fold function to pull one fold out for training+validating
        Note that validation weights should always be the nominal MC weights

        Arguments
        --------
        k_folds: int
            number of folds that the training+validation set are partitioned into
        """

        kf = KFold(n_splits=k_folds)
        for train_index, valid_index in kf.split(self.X_train):
            self.X_folds_train.append(self.X_train[train_index])
            self.X_folds_validate.append(self.X_train[valid_index])

            self.y_folds_train.append(self.y_train[train_index])
            self.y_folds_validate.append(self.y_train[valid_index])

            #deal with two possible train weight scenarios
            self.w_folds_train.append(self.train_weights[train_index])
            if self.eq_train:
                self.w_folds_train_eq.append(
                    self.train_weights_eq[train_index])

            self.w_folds_validate.append(self.train_weights[valid_index])

    def set_i_fold(self, i_fold):
        """
        Gets the training and validation fold for a given CV iteration from the class attributes,
        and overwrites self.X_train, self.y_train (and the weights) with the training fold, and self.X_test, self.y_test with the validation fold
        Note that for these purposes, our "test" sets are really the "validation" sets

        Arguments
        --------
        i_fold: int
            the index describing the train+validate datasets
        """

        self.X_train = self.X_folds_train[i_fold]
        self.train_weights = self.w_folds_train[
            i_fold]  #nominal MC weights needed for computing roc on train set (overtraining test)
        if self.eq_train:
            self.train_weights_eq = self.w_folds_train_eq[i_fold]
        self.y_train = self.y_folds_train[i_fold]

        self.X_test = self.X_folds_validate[i_fold]
        self.y_test = self.y_folds_validate[i_fold]
        self.test_weights = self.w_folds_validate[i_fold]
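
    # An illustrative cross-validation loop over the folds prepared by the two methods above
    # (a sketch, not part of the original listing; validation_rocs is filled by hand so that
    # compare_rocs() can later use it):
    #   self.set_k_folds(3)
    #   for i_fold in range(3):
    #       self.set_i_fold(i_fold)
    #       self.train_classifier(os.getcwd())
    #       self.validation_rocs.append(self.compute_roc())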

    def compare_rocs(self, roc_file, hp_string):
        """
        Compare the AUC for the current model to the current best AUC saved in a .txt file

        Arguments
        ---------
        roc_file: file object
            open file holding the current best AUC (as the final line)
        hp_string: string
            string containing each hyper-parameter for the network, with the following syntax: 'hp_1_name:hp_1_value, hp_2_name:hp_2_value, ...'
        """

        hp_roc = roc_file.readlines()
        avg_val_auc = np.average(self.validation_rocs)
        print('avg. validation roc is: {}'.format(avg_val_auc))
        if len(hp_roc) == 0:
            roc_file.write('{};{:.4f}'.format(hp_string, avg_val_auc))
        elif float(hp_roc[-1].split(';')[-1]) < avg_val_auc:
            roc_file.write('\n')
            roc_file.write('{};{:.4f}'.format(hp_string, avg_val_auc))

    def compute_roc(self):
        """
        Compute the area under the associated ROC curve, with mc weights. Also compute with blinded data as bkg

        Returns
        -------
        roc_auc_score: float
            area under the ROC curve evaluated on the test set
        """

        self.y_pred_train = self.clf.predict_proba(self.X_train)[:, 1:]
        print('Area under ROC curve for train set is: {:.4f}'.format(
            roc_auc_score(self.y_train,
                          self.y_pred_train,
                          sample_weight=self.train_weights)))

        self.y_pred_test = self.clf.predict_proba(self.X_test)[:, 1:]
        print('Area under ROC curve for test set is: {:.4f}'.format(
            roc_auc_score(self.y_test,
                          self.y_pred_test,
                          sample_weight=self.test_weights)))

        #get auc for bkg->data
        sig_y_pred_test = self.y_pred_test[self.y_test == 1]
        sig_weights_test = self.test_weights[self.y_test == 1]
        sig_y_true_test = self.y_test[self.y_test == 1]
        data_weights_test = np.ones(self.X_data_test.values.shape[0])
        data_y_true_test = np.zeros(self.X_data_test.values.shape[0])
        data_y_pred_test = self.clf.predict_proba(self.X_data_test.values)[:,
                                                                           1:]
        print('Area under ROC curve with data as bkg is: {:.4f}'.format(
            roc_auc_score(np.concatenate((sig_y_true_test, data_y_true_test),
                                         axis=None),
                          np.concatenate((sig_y_pred_test, data_y_pred_test),
                                         axis=None),
                          sample_weight=np.concatenate(
                              (sig_weights_test, data_weights_test),
                              axis=None))))

        return roc_auc_score(self.y_test,
                             self.y_pred_test,
                             sample_weight=self.test_weights)

    def compute_roc_three_class(self, third_class):
        """
        Compute the area under the associated ROC curves for three class problem, with mc weights. Also compute with blinded data as bkg

        """

        self.y_pred_train = self.clf.predict_proba(self.X_train)
        self.y_pred_test = self.clf.predict_proba(self.X_test)

        sig_y_train = np.where(self.y_train == 2, 1, 0)
        sig_y_test = np.where(self.y_test == 2, 1, 0)

        bkg_y_train = np.where(self.y_train > 0, 0, 1)
        bkg_y_test = np.where(self.y_test > 0, 0, 1)

        third_class_y_train = np.where(self.y_train == 1, 1, 0)
        third_class_y_test = np.where(self.y_test == 1, 1, 0)

        print('area under roc curve for training set (sig vs rest) = %1.3f' %
              (roc_auc_score(sig_y_train,
                             self.y_pred_train[:, 2],
                             sample_weight=self.train_weights)))
        print('area under roc curve for test set = %1.3f \n' %
              (roc_auc_score(sig_y_test,
                             self.y_pred_test[:, 2],
                             sample_weight=self.test_weights)))
        print('area under roc curve for training set (bkg vs rest) = %1.3f' %
              (roc_auc_score(bkg_y_train,
                             self.y_pred_train[:, 0],
                             sample_weight=self.train_weights)))
        print('area under roc curve for test set = %1.3f \n' %
              (roc_auc_score(bkg_y_test,
                             self.y_pred_test[:, 0],
                             sample_weight=self.test_weights)))
        print(
            'area under roc curve for training set (third class vs rest) = %1.3f'
            % (roc_auc_score(third_class_y_train,
                             self.y_pred_train[:, 1],
                             sample_weight=self.train_weights)))
        print('area under roc curve for test set = %1.3f' %
              (roc_auc_score(third_class_y_test,
                             self.y_pred_test[:, 1],
                             sample_weight=self.test_weights)))

        #get auc for bkg->data
        #sig_y_pred_test  = self.y_pred_test[self.y_test==1]
        #sig_weights_test = self.test_weights[self.y_test==1]
        #sig_y_true_test  = self.y_test[self.y_test==1]
        #data_weights_test = np.ones(self.X_data_test.values.shape[0])
        #data_y_true_test  = np.zeros(self.X_data_test.values.shape[0])
        #data_y_pred_test  = self.clf.predict_proba(self.X_data_test.values)[:,1:]
        #print 'Area under ROC curve with data as bkg is: {:.4f}'.format(roc_auc_score( np.concatenate((sig_y_true_test, data_y_true_test), axis=None),
        #                                                                               np.concatenate((sig_y_pred_test, data_y_pred_test), axis=None),
        #                                                                               sample_weight=np.concatenate((sig_weights_test, data_weights_test), axis=None)
        #                                                                             )
        #                                                               )

    def plot_roc(self, out_tag):
        """
        Method to plot the roc curve, using method from Plotter() class
        """

        roc_fig = self.plotter.plot_roc(self.y_train,
                                        self.y_pred_train,
                                        self.train_weights,
                                        self.y_test,
                                        self.y_pred_test,
                                        self.test_weights,
                                        out_tag=out_tag)

        Utils.check_dir('{}/plotting/plots/{}'.format(os.getcwd(), out_tag))
        roc_fig.savefig('{0}/plotting/plots/{1}/{1}_ROC_curve.pdf'.format(
            os.getcwd(), out_tag))
        print('saving: {0}/plotting/plots/{1}/{1}_ROC_curve.pdf'.format(
            os.getcwd(), out_tag))
        plt.close()

    def plot_output_score(self,
                          out_tag,
                          ratio_plot=False,
                          norm_to_data=False,
                          log=False):
        """
        Plot the output score for the classifier, for signal, background, and data

        Arguments
        ---------
        out_tag: string
            output tag used as part of the image name, when saving
        ratio_plot: bool
            whether to plot the ratio between simulated background and data
        norm_to_data: bool
            whether to normalise the integral of the simulated background, to the integral in data
        """

        output_score_fig = self.plotter.plot_output_score(
            self.y_test,
            self.y_pred_test,
            self.test_weights,
            self.proc_arr_test,
            self.clf.predict_proba(self.X_data_test.values)[:, 1:],
            ratio_plot=ratio_plot,
            norm_to_data=norm_to_data,
            log=log)

        Utils.check_dir('{}/plotting/plots/{}'.format(os.getcwd(), out_tag))
        output_score_fig.savefig(
            '{0}/plotting/plots/{1}/{1}_output_score.pdf'.format(
                os.getcwd(), out_tag))
        print('saving: {0}/plotting/plots/{1}/{1}_output_score.pdf'.format(
            os.getcwd(), out_tag))
        plt.close()

    def plot_output_score_three_class(self,
                                      out_tag,
                                      ratio_plot=False,
                                      norm_to_data=False,
                                      log=False,
                                      third_class=''):
        """
        Plot the output score for the classifier, for signal, background, and data

        Arguments
        ---------
        out_tag: string
            output tag used as part of the image name, when saving
        ratio_plot: bool
            whether to plot the ratio between simulated background and data
        norm_to_data: bool
            whether to normalise the integral of the simulated background, to the integral in data
        """

        #class_id = {'Background':0, 'Third_Class':1 ,'Signal':2}
        class_id = {'Other_backgrounds': 0, 'VBF_Z': 1, 'VBF_Signal': 2}
        for clf_class, _id in class_id.items():
            #plot all processes for each predicted class
            y_pred_test = self.y_pred_test[:, _id]
            output_score_fig = self.plotter.plot_output_score_three_class(
                self.y_test,
                y_pred_test,
                self.test_weights,
                norm_to_data=norm_to_data,
                log=log,
                clf_class=clf_class)

            Utils.check_dir('{}/plotting/plots/{}'.format(
                os.getcwd(), out_tag))
            output_score_fig.savefig(
                '{0}/plotting/plots/{1}/{1}_output_score_clf_class_{2}.pdf'.
                format(os.getcwd(), out_tag, clf_class))
            print(
                'saving: {0}/plotting/plots/{1}/{1}_output_score_clf_class_{2}.pdf'
                .format(os.getcwd(), out_tag, clf_class))
            plt.close()
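
A minimal usage sketch for BDTHelpers (illustrative only: root_obj and train_vars stand in for a ROOTHelpers instance and a list of training variables, neither of which is defined in this listing):

bdt = BDTHelpers(root_obj, train_vars, train_frac=0.7, eq_train=True)
bdt.create_X_and_y(mass_res_reweight=True)
bdt.train_classifier(os.getcwd(), save=True, model_name='vbf_bdt')
bdt.compute_roc()
bdt.plot_roc(out_tag='vbf_bdt')
bdt.plot_output_score(out_tag='vbf_bdt')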
Example #2
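As in the first example, the imports are omitted; a minimal set consistent with the calls below might be (dump/load match the pickle API used here, while gev_vars, Plotter and Utils are analysis-specific names whose module is not part of the listing):

import os

import numpy as np
import pandas as pd
import keras
from pickle import dump, load
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# analysis-specific names; their module is not part of this listing
# from <analysis package> import Plotter, Utils, gev_vars
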
class LSTM_DNN(object):
    """ 
    Class for training a DNN that uses LSTM and fully connected layers

    :param data_obj: instance of the ROOTHelpers class, containing DataFrames for simulated signal, simulated background, and possibly data
    :type data_obj: ROOTHelpers
    :param low_level_vars: 2d list of low-level objects used as inputs to LSTM network layers
    :type low_level_vars: list
    :param high_level_vars: 1d list of high-level objects used as inputs to fully connected network layers
    :type high_level_vars: list
    :param train_frac: fraction of events used to train the network; the remaining 1-train_frac is used for testing
    :type train_frac: float
    :param eq_weights: whether to train with the sum of signal weights scaled equal to the sum of background weights
    :type eq_weights: bool
    :param batch_boost: option to increase batch size based on ROC improvement. Needed for submitting to IC computing batch in hyper-parameter optimisation
    :type batch_boost: bool

    """
    def __init__(self,
                 data_obj,
                 low_level_vars,
                 high_level_vars,
                 train_frac,
                 eq_weights=True,
                 batch_boost=False):
        self.data_obj = data_obj
        self.low_level_vars = low_level_vars
        self.low_level_vars_flat = [
            var for sublist in low_level_vars for var in sublist
        ]
        self.high_level_vars = high_level_vars
        self.train_frac = train_frac
        self.batch_boost = batch_boost  #needed for HP opt
        self.eq_train = eq_weights
        self.max_epochs = 100

        self.X_tot = None
        self.y_tot = None

        self.X_train_low_level = None
        self.X_train_high_level = None
        self.y_train = None
        self.train_weights = None
        self.train_eqw = None
        self.proc_arr_train = None
        self.y_pred_train = None

        self.X_test_low_level = None
        self.X_test_high_level = None
        self.y_test = None
        self.test_weights = None
        self.proc_arr_test = None
        self.y_pred_test = None

        self.X_train_low_level = None
        self.X_valid_low_level = None
        self.y_valid = None
        self.valid_weights = None

        self.X_data_train_low_level = None
        self.X_data_train_high_level = None

        self.X_data_test_low_level = None
        self.X_data_test_high_level = None

        # high complex
        #self.set_model(n_lstm_layers=3, n_lstm_nodes=150, n_dense_1=2, n_nodes_dense_1=300,
        #               n_dense_2=3, n_nodes_dense_2=200, dropout_rate=0.3,
        #               learning_rate=0.001, batch_norm=True, batch_momentum=0.99)

        # med complex
        self.set_model(n_lstm_layers=2,
                       n_lstm_nodes=50,
                       n_dense_1=2,
                       n_nodes_dense_1=50,
                       n_dense_2=2,
                       n_nodes_dense_2=20,
                       dropout_rate=0.25,
                       learning_rate=0.001,
                       batch_norm=True,
                       batch_momentum=0.99)

        #ggH model that learns the m_ee
        #self.set_model(n_lstm_layers=3, n_lstm_nodes=50, n_dense_1=3, n_nodes_dense_1=150,
        #               n_dense_2=2, n_nodes_dense_2=100, dropout_rate=0.1,
        #               learning_rate=0.001, batch_norm=True, batch_momentum=0.99)

        #self.set_model(n_lstm_layers=2, n_lstm_nodes=50, n_dense_1=2, n_nodes_dense_1=100,
        #               n_dense_2=1, n_nodes_dense_2=50, dropout_rate=0.25,
        #               learning_rate=0.001, batch_norm=True, batch_momentum=0.99)

        # simple
        #self.set_model(n_lstm_layers=1, n_lstm_nodes=20, n_dense_1=2, n_nodes_dense_1=20,
        #               n_dense_2=1, n_nodes_dense_2=10, dropout_rate=0.2,
        #               learning_rate=0.001, batch_norm=False, batch_momentum=0.99)

        #self.hp_grid_rnge           = {'n_lstm_layers': [1,2,3], 'n_lstm_nodes':[100,150,200],
        #                               'n_dense_1':[1,2,3], 'n_nodes_dense_1':[100,200,300],
        #                               'n_dense_2':[1,2,3,4], 'n_nodes_dense_2':[100,200,300],
        #                               'dropout_rate':[0.1,0.2,0.3]
        #                              }

        self.hp_grid_rnge = {
            'n_lstm_layers': [1, 2, 3],
            'n_lstm_nodes': [50, 100, 150],
            'n_dense_1': [1, 2, 3],
            'n_nodes_dense_1': [50, 100, 150],
            'n_dense_2': [1, 2, 3, 4],
            'n_nodes_dense_2': [50, 100, 150],
            'dropout_rate': [0.1, 0.2]
        }

        #assign plotter attribute before data_obj is deleted for mem
        self.plotter = Plotter(data_obj,
                               self.low_level_vars_flat + self.high_level_vars)
        del data_obj

    def var_transform(self, do_data=False):
        """
        Apply natural log to GeV variables, unless variable value is empty. Do this for signal, background, and potentially data
        Note that the lead and sublead jet replacements make no difference for VBF, but in evaluating scores on ggH samples, we normally
        have 0J events; hence all jet variables need replacing
        
        Arguments
        ---------
        do_data : bool
            whether to apply the transforms to X_train in data. Used if plotting the DNN output score distribution in data
        
        """

        empty_vars = [
            'leadJetEn', 'leadJetPt', 'leadJetPhi', 'leadJetEta', 'leadJetQGL',
            'subleadJetEn', 'subleadJetPt', 'subleadJetPhi', 'subleadJetEta',
            'subleadJetQGL', 'subsubleadJetEn', 'subsubleadJetPt',
            'subsubleadJetPhi', 'subsubleadJetEta', 'subsubleadJetQGL',
            'dijetMinDRJetEle', 'dijetDieleAbsDEta', 'dijetDieleAbsDPhiTrunc',
            'dijetCentrality', 'dijetMass', 'dijetAbsDEta', 'dijetDPhi'
        ]

        replacement_value = -10

        for empty_var in empty_vars:
            self.data_obj.mc_df_sig[empty_var] = self.data_obj.mc_df_sig[
                empty_var].replace(-999., replacement_value)
            self.data_obj.mc_df_bkg[empty_var] = self.data_obj.mc_df_bkg[
                empty_var].replace(-999., replacement_value)
            if do_data:
                self.data_obj.data_df[empty_var] = self.data_obj.data_df[
                    empty_var].replace(-999., replacement_value)

        #print self.data_obj.mc_df_sig[empty_vars]
        #print np.isnan(self.data_obj.mc_df_sig[empty_vars]).any()

        for var in gev_vars:
            if var in (self.low_level_vars_flat + self.high_level_vars):
                self.data_obj.mc_df_sig[var] = self.data_obj.mc_df_sig.apply(
                    self.var_transform_helper,
                    axis=1,
                    args=[var, replacement_value])
                self.data_obj.mc_df_bkg[var] = self.data_obj.mc_df_bkg.apply(
                    self.var_transform_helper,
                    axis=1,
                    args=[var, replacement_value])
                if do_data:
                    self.data_obj.data_df[var] = self.data_obj.data_df.apply(
                        self.var_transform_helper,
                        axis=1,
                        args=[var, replacement_value])

        #print np.isnan(self.data_obj.mc_df_sig[empty_vars]).any()

    def var_transform_helper(self, row, var, replacement_value):
        """
        Helper function to decide whether to transform variable. 

        Arguments
        ---------
        row : pandas Series
            pandas Series object yielded by pandas apply(). Contains per-event information
        var : string
            name of the variable we are considering for transform
        """

        if row[var] == replacement_value: return row[var]
        else: return np.log(row[var])

    def create_X_y(self, mass_res_reweight=True):
        """
        Create X and y matrices to be used later for training and testing. 

        Arguments
        ---------
        mass_res_reweight: bool 
            re-weight signal events by 1/sigma(m_ee), in training only. Currently only implemented if also equalising weights,

        Returns
        --------
        X_tot: pandas dataframe of both low-level and high-level features. Low-level features are returned as 1D columns.
        y_tot: numpy ndarray of the target column (1 for signal, 0 for background)
        """

        if self.eq_train:
            if mass_res_reweight:
                self.data_obj.mc_df_sig['MoM_weight'] = (
                    self.data_obj.mc_df_sig['weight']) * (
                        1. / self.data_obj.mc_df_sig['dielectronSigmaMoM'])
                b_to_s_ratio = np.sum(
                    self.data_obj.mc_df_bkg['weight'].values) / np.sum(
                        self.data_obj.mc_df_sig['MoM_weight'].values)
                self.data_obj.mc_df_sig['eq_weight'] = (
                    self.data_obj.mc_df_sig['MoM_weight']) * (b_to_s_ratio)
            else:
                b_to_s_ratio = np.sum(
                    self.data_obj.mc_df_bkg['weight'].values) / np.sum(
                        self.data_obj.mc_df_sig['weight'].values)
                self.data_obj.mc_df_sig['eq_weight'] = self.data_obj.mc_df_sig[
                    'weight'] * b_to_s_ratio
            self.data_obj.mc_df_bkg['eq_weight'] = self.data_obj.mc_df_bkg[
                'weight']

        self.data_obj.mc_df_sig.reset_index(drop=True, inplace=True)
        self.data_obj.mc_df_bkg.reset_index(drop=True, inplace=True)
        X_tot = pd.concat([self.data_obj.mc_df_sig, self.data_obj.mc_df_bkg],
                          ignore_index=True)

        #add y_target label (1 for signal, 0 for background). Keep separate from X-train until after Z-scaling
        y_sig = np.ones(self.data_obj.mc_df_sig.shape[0])
        y_bkg = np.zeros(self.data_obj.mc_df_bkg.shape[0])
        y_tot = np.concatenate((y_sig, y_bkg))

        return X_tot, y_tot
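
    # An illustrative order in which the input-preparation methods of this class might be
    # called (a sketch; the driver script is not part of this listing):
    #   lstm.var_transform(do_data=True)
    #   X_tot, y_tot = lstm.create_X_y()
    #   lstm.split_X_y(X_tot, y_tot, do_data=True)
    #   lstm.get_X_scaler(lstm.all_vars_X_train)
    #   lstm.X_scale_train_test(do_data=True)
    #   lstm.set_low_level_2D_test_train(do_data=True)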

    def split_X_y(self, X_tot, y_tot, do_data=False):
        """
        Split X and y matrices into a set for training the LSTM, and testing set to evaluate model performance

        Arguments
        ---------
        X_tot: pandas Dataframe
            pandas dataframe of both low-level and high-level features. Low-level features are returned as 1D columns.
        y_tot: numpy ndarray 
            numpy ndarray of the target column (1 for signal, 0 for background)
        do_data : bool
            whether to form a test (and train) dataset in data, to use for plotting
        """

        if not self.eq_train:
            self.all_vars_X_train, self.all_vars_X_test, self.train_weights, self.test_weights, self.y_train, self.y_test, self.proc_arr_train, self.proc_arr_test = train_test_split(
                X_tot[self.low_level_vars_flat + self.high_level_vars],
                X_tot['weight'],
                y_tot,
                X_tot['proc'],
                train_size=self.train_frac,
                shuffle=True,
                random_state=1357)
        else:
            self.all_vars_X_train, self.all_vars_X_test, self.train_weights, self.test_weights, self.train_eqw, self.test_eqw, self.y_train, self.y_test, self.proc_arr_train, self.proc_arr_test = train_test_split(
                X_tot[self.low_level_vars_flat + self.high_level_vars],
                X_tot['weight'],
                X_tot['eq_weight'],
                y_tot,
                X_tot['proc'],
                train_size=self.train_frac,
                shuffle=True,
                random_state=1357)
            self.train_weights_eq = self.train_eqw.values

        if do_data:  #for plotting purposes
            self.all_X_data_train, self.all_X_data_test = train_test_split(
                self.data_obj.data_df[self.low_level_vars_flat +
                                      self.high_level_vars],
                train_size=self.train_frac,
                shuffle=True,
                random_state=1357)

    def get_X_scaler(self, X_train, out_tag='lstm_scaler', save=True):
        """
        Derive a transform on the X features to give zero mean and unit std. Derived on the train set, and saved for later use

        Arguments
        ---------
        X_train : Dataframe/ndarray
            training matrix on which to derive the transform
        out_tag : string
           output tag from the configuration file for the wrapper script e.g. LSTM_DNN
        """

        X_scaler = StandardScaler()
        X_scaler.fit(X_train.values)
        self.X_scaler = X_scaler
        if save:
            print('saving X scaler: models/{}_X_scaler.pkl'.format(out_tag))
            dump(X_scaler, open('models/{}_X_scaler.pkl'.format(out_tag),
                                'wb'))

    def load_X_scaler(self, out_tag='lstm_scaler'):
        """
        Load X feature scaler, where the transform has been derived from training sample

        Arguments
        ---------
        out_tag : string
           output tag from the configuration file for the wrapper script e.g. LSTM_DNN
        """

        print('loading X scaler: models/{}_X_scaler.pkl'.format(out_tag))
        self.X_scaler = load(
            open('models/{}_X_scaler.pkl'.format(out_tag), 'rb'))

    def X_scale_train_test(self, do_data=False):
        """ 
        Scale train and test X matrices to give zero mean and unit std. Annoying conversions between numpy <-> pandas but necessary for keeping feature names

        Arguments
        ---------
        do_data : bool
            whether to scale test (and train) dataset in data, to use for plotting
        """

        X_scaled_all_vars_train = self.X_scaler.transform(
            self.all_vars_X_train
        )  #returns np array so need to re-cast into pandas to get colums/variables
        X_scaled_all_vars_train = pd.DataFrame(
            X_scaled_all_vars_train,
            columns=self.low_level_vars_flat + self.high_level_vars)
        self.X_train_low_level = X_scaled_all_vars_train[
            self.
            low_level_vars_flat].values  #will get changed to 2D arrays later
        self.X_train_high_level = X_scaled_all_vars_train[
            self.high_level_vars].values

        X_scaled_all_vars_test = self.X_scaler.transform(
            self.all_vars_X_test)  #important to use scaler tuned on X train
        X_scaled_all_vars_test = pd.DataFrame(
            X_scaled_all_vars_test,
            columns=self.low_level_vars_flat + self.high_level_vars)
        self.X_test_low_level = X_scaled_all_vars_test[
            self.
            low_level_vars_flat].values  #will get changed to 2D arrays later
        self.X_test_high_level = X_scaled_all_vars_test[
            self.high_level_vars].values

        if do_data:  #for plotting purposes
            X_scaled_data_all_vars_train = self.X_scaler.transform(
                self.all_X_data_train)
            X_scaled_data_all_vars_train = pd.DataFrame(
                X_scaled_data_all_vars_train,
                columns=self.low_level_vars_flat + self.high_level_vars)
            self.X_data_train_high_level = X_scaled_data_all_vars_train[
                self.high_level_vars].values
            self.X_data_train_low_level = X_scaled_data_all_vars_train[
                self.low_level_vars_flat].values

            X_scaled_data_all_vars_test = self.X_scaler.transform(
                self.all_X_data_test)
            X_scaled_data_all_vars_test = pd.DataFrame(
                X_scaled_data_all_vars_test,
                columns=self.low_level_vars_flat + self.high_level_vars)
            self.X_data_test_high_level = X_scaled_data_all_vars_test[
                self.high_level_vars].values
            self.X_data_test_low_level = X_scaled_data_all_vars_test[
                self.low_level_vars_flat].values

    def set_low_level_2D_test_train(self, do_data=False, ignore_train=False):
        """
        Transform the 1D low-level variables into 2D variables, and overwrite the corresponding class attributes

        Arguments
        ---------
        do_data : bool
            whether to scale test (and train) dataset in data, to use for plotting
        ignore_train: bool
            do not join 2D train objects. Useful if we want to keep the low-level inputs as a 1D array when splitting train --> train+validate,
            since we want to do the 2D transform on the 1D sequences of the resulting train and validation sets.
        """

        if not ignore_train:
            self.X_train_low_level = self.join_objects(self.X_train_low_level)
        self.X_test_low_level = self.join_objects(self.X_test_low_level)
        if do_data:
            self.X_data_train_low_level = self.join_objects(
                self.X_data_train_low_level)
            self.X_data_test_low_level = self.join_objects(
                self.X_data_test_low_level)

    def create_train_valid_set(self):
        """
        Partition the X and y training matrix into a train + validation set (i.e. X_train -> X_train + X_validate, and same for y and w)
        This also means turning the low-level inputs into 2D arrays, which is why we were careful to keep them as 1D arrays earlier

        Note that validation weights should always be the nominal MC weights
        """

        if not self.eq_train:
            X_train_high_level, X_valid_high_level, X_train_low_level, X_valid_low_level, train_w, valid_w, y_train, y_valid = train_test_split(
                self.X_train_high_level,
                self.X_train_low_level,
                self.train_weights,
                self.y_train,
                train_size=0.7,
                test_size=0.3)
        else:
            X_train_high_level, X_valid_high_level, X_train_low_level, X_valid_low_level, train_w, valid_w, w_train_eq, w_valid_eq, y_train, y_valid = train_test_split(
                self.X_train_high_level,
                self.X_train_low_level,
                self.train_weights,
                self.train_weights_eq,
                self.y_train,
                train_size=0.7,
                test_size=0.3)
            self.train_weights_eq = w_train_eq

        #NOTE: might need to re-equalise weights in each fold as sumW_sig != sumW_bkg anymore!
        self.train_weights = train_w
        self.valid_weights = valid_w  #validation weights should never be equalised weights!

        print('creating validation dataset')
        self.X_train_high_level = X_train_high_level
        self.X_train_low_level = self.join_objects(X_train_low_level)

        self.X_valid_high_level = X_valid_high_level
        self.X_valid_low_level = self.join_objects(X_valid_low_level)
        print('finished creating validation dataset')

        self.y_train = y_train
        self.y_valid = y_valid

    def join_objects(self, X_low_level):
        """
        Function to take all low-level objects for each event, and transform them into a matrix:
            [ [jet1-pt, jet1-eta, ...],
              [jet2-pt, jet2-eta, ...],
              [jet3-pt, jet3-eta, ...] ]_evt1 ,

            [ [jet1-pt, jet1-eta, ...],
              [jet2-pt, jet2-eta, ...],
              [jet3-pt, jet3-eta, ...] ]_evt2 ,

             ...
           
        Note that the order of the low level inputs is important, and should be jet objects in descending pT

        Arguments
        ---------
        X_low_level: numpy ndarray
            array of X_features, with columns labelled in order: low-level vars to high-level vars

        Returns
        --------
        numpy ndarray: 2D representation of all jets in each event, for all events in X_low_level
        """

        print('Creating 2D object vars...')
        l_to_convert = []
        for index, row in pd.DataFrame(
                X_low_level, columns=self.low_level_vars_flat).iterrows(
                ):  #very slow; need a better way to do this
            l_event = []
            for i_object_list in self.low_level_vars:
                l_object = []
                for i_var in i_object_list:
                    l_object.append(row[i_var])
                l_event.append(l_object)
            l_to_convert.append(l_event)
        print('Finished creating train object vars')
        return np.array(l_to_convert, np.float32)
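
    # A vectorised alternative to the per-row loop above (a sketch, assuming each entry of
    # self.low_level_vars lists the same number of variables, so the flat column ordering
    # matches the nested ordering):
    #   n_objects, n_vars = len(self.low_level_vars), len(self.low_level_vars[0])
    #   return X_low_level.reshape(-1, n_objects, n_vars).astype(np.float32)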

    def set_model(self,
                  n_lstm_layers=3,
                  n_lstm_nodes=150,
                  n_dense_1=1,
                  n_nodes_dense_1=300,
                  n_dense_2=4,
                  n_nodes_dense_2=200,
                  dropout_rate=0.1,
                  learning_rate=0.001,
                  batch_norm=True,
                  batch_momentum=0.99):
        """
        Set hyper parameters of the network, including the general structure, learning rate, and regularisation coefficients.
        Resulting model is set as a class attribute, overwriting existing model.

        Arguments
        ---------
        n_lstm_layers : int
            number of lstm layers/units 
        n_lstm_nodes : int
            number of nodes in each lstm layer/unit
        n_dense_1 : int
            number of dense fully connected layers
        n_nodes_dense_1 : int
            number of nodes in each dense fully connected layer
        n_dense_2 : int
            number of regular fully connected layers
        n_nodes_dense_2 : int
            number of nodes in each regular fully connected layer
        dropout_rate : float
            fraction of weights to be dropped during training, to regularise the network
        learning_rate: float
            learning rate for gradient-descent based loss minimisation
        batch_norm: bool
             option to normalise each batch before training
        batch_momentum : float
             momentum for the gradient descent, evaluated on a given batch
        """

        input_objects = keras.layers.Input(shape=(len(self.low_level_vars),
                                                  len(self.low_level_vars[0])),
                                           name='input_objects')
        input_global = keras.layers.Input(shape=(len(self.high_level_vars), ),
                                          name='input_global')
        lstm = input_objects
        decay = 0.2
        for i_layer in range(n_lstm_layers):
            #lstm = keras.layers.LSTM(n_lstm_nodes, activation='tanh', kernel_regularizer=keras.regularizers.l2(decay), recurrent_regularizer=keras.regularizers.l2(decay), bias_regularizer=keras.regularizers.l2(decay), return_sequences=(i_layer!=(n_lstm_layers-1)), name='lstm_{}'.format(i_layer))(lstm)
            lstm = keras.layers.LSTM(n_lstm_nodes,
                                     activation='tanh',
                                     return_sequences=(i_layer !=
                                                       (n_lstm_layers - 1)),
                                     name='lstm_{}'.format(i_layer))(lstm)

        #inputs to dense layers are output of lstm and global-event variables. Also batch norm the FC layers
        dense = keras.layers.concatenate([input_global, lstm])
        for i in range(n_dense_1):
            dense = keras.layers.Dense(n_nodes_dense_1,
                                       activation='relu',
                                       kernel_initializer='he_uniform',
                                       name='dense1_%d' % i)(dense)
            if batch_norm:
                dense = keras.layers.BatchNormalization(
                    name='dense_batch_norm1_%d' % i)(dense)
        dense = keras.layers.Dropout(rate=dropout_rate,
                                     name='dense_dropout1_%d' % i)(dense)

        for i in range(n_dense_2):
            dense = keras.layers.Dense(n_nodes_dense_2,
                                       activation='relu',
                                       kernel_initializer='he_uniform',
                                       name='dense2_%d' % i)(dense)
            #add droput and norm if not on last layer
            if batch_norm and i < (n_dense_2 - 1):
                dense = keras.layers.BatchNormalization(
                    name='dense_batch_norm2_%d' % i)(dense)
            if i < (n_dense_2 - 1):
                dense = keras.layers.Dropout(rate=dropout_rate,
                                             name='dense_dropout2_%d' %
                                             i)(dense)

        output = keras.layers.Dense(1, activation='sigmoid',
                                    name='output')(dense)
        #optimiser = keras.optimizers.Nadam(lr = learning_rate)
        optimiser = keras.optimizers.Adam(lr=learning_rate)

        model = keras.models.Model(inputs=[input_global, input_objects],
                                   outputs=[output])
        model.compile(optimizer=optimiser, loss='binary_crossentropy')
        self.model = model

    def train_w_batch_boost(self,
                            out_tag='my_lstm',
                            save=True,
                            auc_threshold=0.01):
        """
        Alternative method of training, where the batch size is increased during training,
        if the improvement in (1-AUC) is above some threshold.
        Terminate the training early if no improvement is seen after max batch size update

        Arguments
        --------
        out_tag: string
            output tag used as part of the model name, when saving
        save: bool
            option to save the best model
        auc_threshold: float
            minimum improvement in (1-AUC) to warrant not updating the batch size. 
        """

        self.create_train_valid_set()

        #parameters that control the batch size
        best_auc = 0.5
        #current_batch_size = 1024
        current_batch_size = 64
        #max_batch_size     = 50000
        max_batch_size = 50000

        #keep track of epochs for plotting loss vs epoch, and for getting best model
        epoch_counter = 0
        best_epoch = 1

        keep_training = True

        while keep_training:
            epoch_counter += 1
            print('beginning training iteration for epoch {}'.format(
                epoch_counter))
            self.train_network(epochs=1, batch_size=current_batch_size)

            self.save_model(epoch_counter, out_tag)
            val_roc = self.compute_roc(
                batch_size=current_batch_size, valid_set=True
            )  #FIXME: what is the best BS here? final BS from batch boost... initial BS? current BS??

            #get average of validation rocs and clear list entries
            improvement = ((1 - best_auc) - (1 - val_roc)) / (1 - best_auc)

            #FIXME: if the validation roc does not improve after n bad "epochs", then update the batch size accordingly. Reset bad epochs to zero each time the batch size increases, if it does

            #do checks to see if batch size needs to change etc
            if improvement > auc_threshold:
                print(
                    'Improvement in (1-AUC) of {:.4f} percent. Keeping batch size at {}'
                    .format(improvement * 100, current_batch_size))
                best_auc = val_roc
                best_epoch = epoch_counter
            elif current_batch_size * 4 < max_batch_size:
                print(
                    'Improvement in (1-AUC) of only {:.4f} percent. Increasing batch size to {}'
                    .format(improvement * 100, current_batch_size * 4))
                current_batch_size *= 4
                if val_roc > best_auc:
                    best_auc = val_roc
                    best_epoch = epoch_counter
            elif current_batch_size < max_batch_size:
                print(
                    'Improvement in (1-AUC) of only {:.4f} percent. Increasing to max batch size of {}'
                    .format(improvement * 100, max_batch_size))
                current_batch_size = max_batch_size
                if val_roc > best_auc:
                    best_auc = val_roc
                    best_epoch = epoch_counter
            elif improvement > 0:
                print(
                    'Improvement in (1-AUC) of only {:.4f} percent. Cannot increase batch further'
                    .format(improvement * 100))
                best_auc = val_roc
                best_epoch = epoch_counter
            else:
                print(
                    'AUC did not improve and batch size cannot be increased further. Stopping training...'
                )
                keep_training = False

            if epoch_counter > self.max_epochs:
                print(
                    'At the maximum number of training epochs ({}). Stopping training...'
                    .format(self.max_epochs))
                keep_training = False
                best_epoch = self.max_epochs

        print('best epoch was: {}'.format(best_epoch))
        print('best validation auc was: {}'.format(best_auc))
        self.val_roc = best_auc

        #delete all models that aren't from the best training. Re-load best model for predicting on test set
        for epoch in range(1, epoch_counter + 1):
            if epoch != best_epoch:
                os.system('rm {}/models/{}_model_epoch_{}.hdf5'.format(
                    os.getcwd(), out_tag, epoch))
                os.system(
                    'rm {}/models/{}_model_architecture_epoch_{}.json'.format(
                        os.getcwd(), out_tag, epoch))
        os.system(
            'mv {0}/models/{1}_model_epoch_{2}.hdf5 {0}/models/{1}_model.hdf5'.
            format(os.getcwd(), out_tag, best_epoch))
        os.system(
            'mv {0}/models/{1}_model_architecture_epoch_{2}.json {0}/models/{1}_model_architecture.json'
            .format(os.getcwd(), out_tag, best_epoch))

        #reset model state and load in best weights
        with open(
                '{}/models/{}_model_architecture.json'.format(
                    os.getcwd(), out_tag), 'r') as model_json:
            best_model_architecture = model_json.read()
        self.model = keras.models.model_from_json(best_model_architecture)
        self.model.load_weights('{}/models/{}_model.hdf5'.format(
            os.getcwd(), out_tag))

        if not save:
            os.system('rm {}/models/{}_model_architecture.json'.format(
                os.getcwd(), out_tag))
            os.system('rm {}/models/{}_model.hdf5'.format(
                os.getcwd(), out_tag))

    def train_network(self, batch_size, epochs):
        """
        Train the network over a given number of epochs
        Arguments
        ---------
        batch_size: int
            number of training samples to compute the gradient on during training
        epochs: int
            number of full passes over the training sample
        """

        if self.eq_train:
            self.model.fit([self.X_train_high_level, self.X_train_low_level],
                           self.y_train,
                           epochs=epochs,
                           batch_size=batch_size,
                           sample_weight=self.train_weights_eq)
        else:
            self.model.fit([self.X_train_high_level, self.X_train_low_level],
                           self.y_train,
                           epochs=epochs,
                           batch_size=batch_size,
                           sample_weight=self.train_weights)

    def save_model(self, epoch=None, out_tag='my_lstm'):
        """
        Save the deep learning model, training up to a given epoch
        
        Arguments:
        ---------
        epoch: int
            the epoch to which to model is trained up to    
        out_tag: string
            output tag used as part of the model name, when saving
        """

        Utils.check_dir('./models/')
        if epoch is not None:
            self.model.save_weights('{}/models/{}_model_epoch_{}.hdf5'.format(
                os.getcwd(), out_tag, epoch))
            with open(
                    "{}/models/{}_model_architecture_epoch_{}.json".format(
                        os.getcwd(), out_tag, epoch), "w") as f_out:
                f_out.write(self.model.to_json())
        else:
            self.model.save_weights('{}/models/{}_model.hdf5'.format(
                os.getcwd(), out_tag))
            with open(
                    "{}/models/{}_model_architecture.json".format(
                        os.getcwd(), out_tag), "w") as f_out:
                f_out.write(self.model.to_json())

    def compare_rocs(self, roc_file, hp_string):
        """
        Compare the AUC for the current model, to the current best AUC saved in a .txt file 
        Arguments
        ---------
        roc_file: string
            path for the file holding the current best AUC (as the final line)
        hp_string: string
            string contraining each hyper_paramter for the network, with the following syntax: 'hp_1_name:hp_1_value, hp_2_name:hp_2_value, ...'
        """

        hp_roc = roc_file.readlines()
        val_auc = self.val_roc
        print('validation roc is: {}'.format(val_auc))
        if len(hp_roc) == 0:
            roc_file.write('{};{:.4f}'.format(hp_string, val_auc))
        elif float(hp_roc[-1].split(';')[-1]) < val_auc:
            roc_file.write('\n')
            roc_file.write('{};{:.4f}'.format(hp_string, val_auc))

    def batch_gs_cv(self, pt_rew=False):
        """
        Submit sets of hyperparameters permutations (based on attribute hp_grid_rnge) to the IC batch.
        Take care to separate training weights, which may be modified w.r.t nominal weights, 
        and the weights used when evaluating on the validation set which should be the nominal weights
        """
        #get all possible HP sets from permutations of the above dict
        hp_perms = self.get_hp_perms()
        #submit job to the batch for the given HP range:
        for hp_string in hp_perms:
            Utils.sub_lstm_hp_script(self.eq_train,
                                     self.batch_boost,
                                     hp_string,
                                     pt_rew=pt_rew)

    def get_hp_perms(self):
        """
        Get all possible combinations of the hyper-parameters specified in self.hp_grid_range
        
        Returns
        -------        
        final_hps: list of all possible hyper parameter combinations in format 'hp_1_name:hp_1_value, hp_2_name:hp_2_value, ...'
        """

        from itertools import product
        hp_perms = list(product(*self.hp_grid_rnge.values()))
        final_hps = []
        counter = 0
        for hp_perm in hp_perms:
            l_entry = ''
            for hp_name, hp_value in zip(self.hp_grid_rnge.keys(), hp_perm):
                l_entry += '{}:{},'.format(hp_name, hp_value)
                counter += 1
                if (counter % len(self.hp_grid_rnge.keys())) == 0:
                    final_hps.append(l_entry[:-1])
        return final_hps

    def set_hyper_parameters(self, hp_string):
        """
        Set the hyperparameters for the network, given some inut string of parameters
        Arguments:
        ---------
        hp_string: string
            string contraining each hyper_paramter for the network, with the following syntax: 'hp_1_name:hp_1_value, hp_2_name:hp_2_value, ...'
        """

        hp_dict = {}
        for params in hp_string.split(','):
            hp_name = params.split(':')[0]
            hp_value = params.split(':')[1]
            try:
                hp_value = int(hp_value)
            except ValueError:
                hp_value = float(hp_value)
            hp_dict[hp_name] = hp_value
        self.set_model(**hp_dict)

    def compute_roc(self, batch_size=64, valid_set=False):
        """
        Compute the area under the associated ROC curve, with usual mc weights
        Arguments
        ---------
        batch_size: int
            necessary to evaluate the network. Has no impact on the output score.
        valid_set: bool
            compute the roc score on validation set instead of than the test set
        Returns
        -------
        roc_test : float
            return the score on the test set (or validation set if performing any model selection)
        """

        self.y_pred_train = self.model.predict(
            [self.X_train_high_level, self.X_train_low_level],
            batch_size=batch_size).flatten()
        roc_train = roc_auc_score(self.y_train,
                                  self.y_pred_train,
                                  sample_weight=self.train_weights)
        print('ROC train score: {}'.format(roc_train))

        if valid_set:
            self.y_pred_valid = self.model.predict(
                [self.X_valid_high_level, self.X_valid_low_level],
                batch_size=batch_size).flatten()
            roc_test = roc_auc_score(self.y_valid,
                                     self.y_pred_valid,
                                     sample_weight=self.valid_weights)
            print('ROC valid score: {}'.format(roc_test))
        else:
            self.y_pred_test = self.model.predict(
                [self.X_test_high_level, self.X_test_low_level],
                batch_size=batch_size).flatten()
            roc_test = roc_auc_score(self.y_test,
                                     self.y_pred_test,
                                     sample_weight=self.test_weights)
            print('ROC test score: {}'.format(roc_test))

        return roc_test

    def plot_roc(self, out_tag):
        """
        Plot the roc curve for the classifier, using method from Plotter() class
        Arguments
        ---------
        out_tag: string
            output tag used as part of the image name, when saving
        """
        roc_fig = self.plotter.plot_roc(self.y_train,
                                        self.y_pred_train,
                                        self.train_weights,
                                        self.y_test,
                                        self.y_pred_test,
                                        self.test_weights,
                                        out_tag=out_tag)

        Utils.check_dir('{}/plotting/plots/{}'.format(os.getcwd(), out_tag))
        roc_fig.savefig('{0}/plotting/plots/{1}/{1}_ROC_curve.pdf'.format(
            os.getcwd(), out_tag))
        print('saving: {0}/plotting/plots/{1}/{1}_ROC_curve.pdf'.format(
            os.getcwd(), out_tag))
        plt.close()

        #save true labels, predicted scores, and event weights for MVA ROC comparisons later on
        np.savez("{}/models/{}_ROC_comp_arrays".format(os.getcwd(), out_tag),
                 self.y_test, self.y_pred_test, self.test_weights)

    def plot_output_score(self,
                          out_tag,
                          batch_size=64,
                          ratio_plot=False,
                          norm_to_data=False):
        """
        Plot the output score for the classifier, for signal, background, and data
        Arguments
        ---------
        out_tag: string
            output tag used as part of the image name, when saving
        batch_size: int
            necessary to evaluate the network. Has no impact on the output score.
        ratio_plot: bool
            whether to plot the ratio between simulated background and data
        norm_to_data: bool
            whether to normalise the integral of the simulated background, to the integral in data
        """

        output_score_fig = self.plotter.plot_output_score(
            self.y_test,
            self.y_pred_test,
            self.test_weights,
            self.proc_arr_test,
            self.model.predict(
                [self.X_data_test_high_level, self.X_data_test_low_level],
                batch_size=batch_size).flatten(),
            MVA='DNN',
            ratio_plot=ratio_plot,
            norm_to_data=norm_to_data)

        Utils.check_dir('{}/plotting/plots/{}'.format(os.getcwd(), out_tag))
        output_score_fig.savefig(
            '{0}/plotting/plots/{1}/{1}_output_score.pdf'.format(
                os.getcwd(), out_tag))
        print('saving: {0}/plotting/plots/{1}/{1}_output_score.pdf'.format(
            os.getcwd(), out_tag))
        plt.close()
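
#minimal usage sketch (illustrative only; the class name, constructor signature and the
#name of the batch-boost training method are assumptions about code defined earlier in
#this file):
#   lstm = LSTMHelpers(data_obj, train_vars, train_frac)   #hypothetical constructor
#   lstm.create_train_valid_set()
#   #...train, e.g. one epoch at a time via the batch-boost method above, then:
#   roc = lstm.compute_roc()                                #weighted AUC on the test set
#   lstm.plot_roc(out_tag='my_lstm')
#   lstm.plot_output_score(out_tag='my_lstm')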