Example #1
def main(args):
    ''' ADMIN '''
    '''----------------------------------------------------------------------- '''
    img_path = os.path.join(REPORTS_DIR, 'matrices', args.load)
    if not os.path.exists(img_path):
        os.makedirs(img_path)
    ''' DATA '''
    '''----------------------------------------------------------------------- '''
    tf = load_tf(args.data_dir, "{}-train.pickle".format(args.filename))
    X, y = load_data(args.data_dir, "{}-train.pickle".format(args.filename))
    for ij, jet in enumerate(X):
        jet["content"] = tf.transform(jet["content"])

    X_valid_uncropped, y_valid_uncropped = X[:1000], y[:1000]
    X_valid, y_valid, _, _ = crop(X_valid_uncropped,
                                  y_valid_uncropped,
                                  return_cropped_indices=True)
    X_pos, X_neg = find_balanced_samples(X_valid, y_valid, args.n_viz)
    ''' MODEL '''
    '''----------------------------------------------------------------------- '''
    # Initialization
    with open(os.path.join(MODELS_DIR, args.load, 'settings.pickle'),
              "rb") as f:
        settings = pickle.load(f, encoding='latin-1')
        Transform = settings["transform"]
        Predict = settings["predict"]
        model_kwargs = settings["model_kwargs"]

    with open(os.path.join(MODELS_DIR, args.load, 'model_state_dict.pt'),
              'rb') as f:
        state_dict = torch.load(f)
        model = Predict(Transform, **model_kwargs)
        model.load_state_dict(state_dict)
    if torch.cuda.is_available():
        model.cuda()
    ''' GET MATRICES '''
    '''----------------------------------------------------------------------- '''
    AA_pos = get_matrices(model, X_pos)
    AA_neg = get_matrices(model, X_neg)
    ''' PLOT MATRICES '''
    '''----------------------------------------------------------------------- '''
    viz(AA_pos, os.path.join(img_path, 'positive'))
    viz(AA_neg, os.path.join(img_path, 'negative'))
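A minimal driver sketch (not part of the original file): it shows how main(args) above could be invoked, with the flag names inferred from the attributes read off args (load, data_dir, filename, n_viz), so they are assumptions rather than the project's actual CLI.

# Hypothetical argparse entry point; flag names are inferred from the snippet above.
import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--load', required=True, help='model directory name under MODELS_DIR')
    parser.add_argument('--data_dir', required=True, help='directory with the *-train.pickle files')
    parser.add_argument('--filename', default='antikt-kt')
    parser.add_argument('--n_viz', type=int, default=10, help='number of balanced samples to visualize')
    main(parser.parse_args())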
Example #2
def main():

    eh = EvaluationExperimentHandler(args)
    ''' GET RELATIVE PATHS TO DATA AND MODELS '''
    '''----------------------------------------------------------------------- '''
    if args.model_list_file is None:
        assert args.root_model_dir is not None
        model_paths = [(args.root_model_dir, args.root_model_dir)]
    else:
        with open(args.model_list_file, newline='') as f:
            reader = csv.DictReader(f)
            model_paths = [(row['model'], row['filename']) for row in reader]

    logging.info("DATASET\n{}".format("\n".join(args.filename)))
    data_path = args.filename
    logging.info("MODEL PATHS\n{}".format("\n".join(mp
                                                    for (_,
                                                         mp) in model_paths)))

    def evaluate_models(X, yy, w, model_filenames, batch_size=64):
        rocs = []
        fprs = []
        tprs = []
        inv_fprs = []

        for i, filename in enumerate(model_filenames):
            if 'DS_Store' not in filename:
                logging.info("\t[{}] Loading {}".format(i, filename)),
                model = load_model(filename)
                if torch.cuda.is_available():
                    model.cuda()
                model_test_file = os.path.join(filename, 'test-rocs.pickle')
                work = not os.path.exists(model_test_file)
                if work:
                    model.eval()

                    offset = 0
                    yy_pred = []
                    n_batches, remainder = np.divmod(len(X), batch_size)
                    for _ in range(n_batches):
                        X_batch = X[offset:offset + batch_size]
                        X_var = wrap_X(X_batch)
                        yy_pred.append(unwrap(model(X_var)))
                        unwrap_X(X_var)
                        offset += batch_size
                    if remainder > 0:
                        X_batch = X[-remainder:]
                        X_var = wrap_X(X_batch)
                        yy_pred.append(unwrap(model(X_var)))
                        unwrap_X(X_var)
                    yy_pred = np.squeeze(np.concatenate(yy_pred, 0), 1)

                    # Persist predictions and targets (np.save appends ".npy" and
                    # writes NumPy binary despite the .csv suffix in these names)
                    np.save(os.path.join(args.data_dir, 'Y_pred_60.csv'), yy_pred)
                    np.save(os.path.join(args.data_dir, 'Y_test_60.csv'), yy)
                    logging.info('Files Saved')

                    logdict = dict(
                        model=filename.split('/')[-1],
                        yy=yy,
                        yy_pred=yy_pred,
                        w_valid=w[:len(yy_pred)],
                    )
                    eh.log(**logdict)
                    roc = eh.monitors['roc_auc'].value
                    fpr = eh.monitors['roc_curve'].value[0]
                    tpr = eh.monitors['roc_curve'].value[1]
                    inv_fpr = eh.monitors['inv_fpr'].value

                    with open(model_test_file, "wb") as fd:
                        pickle.dump((roc, fpr, tpr, inv_fpr), fd)
                else:
                    with open(model_test_file, "rb") as fd:
                        roc, fpr, tpr, inv_fpr = pickle.load(fd)
                    stats_dict = {'roc_auc': roc, 'inv_fpr': inv_fpr}
                    eh.stats_logger.log(stats_dict)
                rocs.append(roc)
                fprs.append(fpr)
                tprs.append(tpr)
                inv_fprs.append(inv_fpr)

        logging.info("\tMean ROC AUC = {:.4f} Mean 1/FPR = {:.4f}".format(
            np.mean(rocs), np.mean(inv_fprs)))

        return rocs, fprs, tprs, inv_fprs

    def build_rocs(data, model_path, batch_size):
        X, y, w = data
        model_filenames = [
            os.path.join(model_path, fn) for fn in os.listdir(model_path)
        ]
        rocs, fprs, tprs, inv_fprs = evaluate_models(X, y, w, model_filenames,
                                                     batch_size)

        return rocs, fprs, tprs, inv_fprs

    ''' BUILD ROCS '''
    '''----------------------------------------------------------------------- '''
    if args.load_rocs is None and args.model_list_file is None:

        logging.info(
            'Building ROCs for models trained on {}'.format(data_path))
        tf = load_tf(args.data_dir, "{}-train.pickle".format(data_path))
        X, y = load_data(args.data_dir,
                         "{}-{}.pickle".format(data_path, args.set))
        for ij, jet in enumerate(X):
            jet["content"] = tf.transform(jet["content"])

        if args.n_test > 0:
            indices = torch.randperm(len(X)).numpy()[:args.n_test]
            X = [X[i] for i in indices]
            y = y[indices]

        X_test, y_test, cropped_indices, w_test = crop(
            X, y, 60, return_cropped_indices=True, pileup=args.pileup)

        data = (X_test, y_test, w_test)
        for model_path, _ in model_paths:
            logging.info(
                '\tBuilding ROCs for instances of {}'.format(model_path))
            logging.info('\tBuilding ROCs for instances of {}'.format(
                args.finished_models_dir))
            logging.info('\tBuilding ROCs for instances of {}'.format(
                os.path.join(args.finished_models_dir, model_path)))
            r, f, t, inv_fprs = build_rocs(
                data, os.path.join(args.finished_models_dir, model_path),
                args.batch_size)
            #remove_outliers_csv(os.path.join(args.finished_models_dir, model_path))
            absolute_roc_path = os.path.join(
                eh.exp_dir,
                "rocs-{}-{}.pickle".format("-".join(model_path.split('/')),
                                           data_path))
            with open(absolute_roc_path, "wb") as fd:
                pickle.dump((r, f, t, inv_fprs), fd)
    else:
        for _, model_path in model_paths:

            previous_absolute_roc_path = os.path.join(
                args.root_exp_dir, model_path,
                "rocs-{}-{}.pickle".format("-".join(model_path.split('/')),
                                           data_path))
            with open(previous_absolute_roc_path, "rb") as fd:
                r, f, t, inv_fprs = pickle.load(fd)

            absolute_roc_path = os.path.join(
                eh.exp_dir,
                "rocs-{}-{}.pickle".format("-".join(model_path.split('/')),
                                           data_path))
            with open(absolute_roc_path, "wb") as fd:
                pickle.dump((r, f, t, inv_fprs), fd)
    ''' PLOT ROCS '''
    '''----------------------------------------------------------------------- '''

    colors = (('red', (228, 26, 28)), ('blue', (55, 126, 184)),
              ('green', (77, 175, 74)), ('purple', (162, 78, 163)),
              ('orange', (255, 127, 0)))
    colors = [(name, tuple(x / 255 for x in tup)) for name, tup in colors]

    for (label, model_path), (_, color) in zip(model_paths, colors):
        absolute_roc_path = os.path.join(
            eh.exp_dir,
            "rocs-{}-{}.pickle".format("-".join(model_path.split('/')),
                                       data_path))
        with open(absolute_roc_path, "rb") as fd:
            r, f, t, inv_fprs = pickle.load(fd)

        if args.remove_outliers:
            r, f, t, inv_fprs = remove_outliers(r, f, t, inv_fprs)

        report_score(r, inv_fprs, label=label)
        plot_rocs(r, f, t, label=label, color=color)

    figure_filename = os.path.join(eh.exp_dir, 'rocs.png')
    plot_save(figure_filename)
    if args.plot:
        plot_show()

    eh.finished()
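The inner prediction loop of evaluate_models above can be read as the reusable pattern below; this is only a sketch, assuming the project's wrap_X / unwrap / unwrap_X helpers behave as they appear to here (wrap a batch into torch Variables, convert outputs back to numpy, release the wrapped batch).

# Sketch of the batched-prediction pattern used in evaluate_models above
# (assumes the project's wrap_X / unwrap / unwrap_X helpers are in scope).
import numpy as np

def predict_in_batches(model, X, batch_size=64):
    y_pred = []
    n_batches, remainder = divmod(len(X), batch_size)
    for b in range(n_batches):
        X_var = wrap_X(X[b * batch_size:(b + 1) * batch_size])  # to torch Variables
        y_pred.append(unwrap(model(X_var)))                      # back to numpy
        unwrap_X(X_var)                                          # release the batch
    if remainder > 0:
        X_var = wrap_X(X[-remainder:])
        y_pred.append(unwrap(model(X_var)))
        unwrap_X(X_var)
    return np.squeeze(np.concatenate(y_pred, 0), 1)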
Example #3
def train(args):
    model_type = MODEL_TYPES[args.model_type]
    eh = ExperimentHandler(args, os.path.join(MODELS_DIR, model_type))
    signal_handler = eh.signal_handler
    ''' DATA '''
    '''----------------------------------------------------------------------- '''
    logging.warning("Loading data...")
    tf = load_tf(DATA_DIR, "{}-train.pickle".format(args.filename))
    X, y = load_data(DATA_DIR, "{}-train.pickle".format(args.filename))

    for jet in X:
        jet["content"] = tf.transform(jet["content"])

    if args.n_train > 0:
        indices = torch.randperm(len(X)).numpy()[:args.n_train]
        X = [X[i] for i in indices]
        y = y[indices]

    logging.warning("Splitting into train and validation...")

    X_train, X_valid_uncropped, y_train, y_valid_uncropped = train_test_split(
        X, y, test_size=args.n_valid)
    logging.warning("\traw train size = %d" % len(X_train))
    logging.warning("\traw valid size = %d" % len(X_valid_uncropped))

    X_valid, y_valid, cropped_indices, w_valid = crop(
        X_valid_uncropped, y_valid_uncropped, return_cropped_indices=True)

    # add cropped indices to training data
    if args.add_cropped:
        X_train.extend([
            x for i, x in enumerate(X_valid_uncropped) if i in cropped_indices
        ])
        y_train = [y for y in y_train]
        y_train.extend([
            y for i, y in enumerate(y_valid_uncropped) if i in cropped_indices
        ])
        y_train = np.array(y_train)
    logging.warning("\tfinal train size = %d" % len(X_train))
    logging.warning("\tfinal valid size = %d" % len(X_valid))
    ''' MODEL '''
    '''----------------------------------------------------------------------- '''
    # Initialization
    Predict = PredictFromParticleEmbedding
    if args.load is None:
        Transform = TRANSFORMS[args.model_type]
        model_kwargs = {
            'n_features': args.n_features,
            'n_hidden': args.n_hidden,
        }
        if Transform in [MPNNTransform, GRNNTransformGated]:
            model_kwargs['n_iters'] = args.n_iters
            model_kwargs['leaves'] = args.leaves
        model = Predict(Transform, **model_kwargs)
        settings = {
            "transform": Transform,
            "predict": Predict,
            "model_kwargs": model_kwargs,
            "step_size": args.step_size,
            "args": args,
        }
    else:
        with open(os.path.join(args.load, 'settings.pickle'), "rb") as f:
            settings = pickle.load(f, encoding='latin-1')
            Transform = settings["transform"]
            Predict = settings["predict"]
            model_kwargs = settings["model_kwargs"]

        with open(os.path.join(args.load, 'model_state_dict.pt'), 'rb') as f:
            state_dict = torch.load(f)
            model = PredictFromParticleEmbedding(Transform, **model_kwargs)
            model.load_state_dict(state_dict)

        if args.restart:
            args.step_size = settings["step_size"]

    logging.warning(model)
    out_str = 'Number of parameters: {}'.format(
        sum(np.prod(p.data.numpy().shape) for p in model.parameters()))
    logging.warning(out_str)

    if torch.cuda.is_available():
        model.cuda()
    signal_handler.set_model(model)
    ''' OPTIMIZER AND LOSS '''
    '''----------------------------------------------------------------------- '''

    optimizer = Adam(model.parameters(), lr=args.step_size)
    scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=args.decay)
    #scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5)

    n_batches = int(np.ceil(len(X_train) / args.batch_size))
    best_score = [-np.inf]  # yuck, but works
    best_model_state_dict = copy.deepcopy(model.state_dict())

    def loss(y_pred, y):
        l = log_loss(y, y_pred.squeeze(1)).mean()
        return l

    ''' VALIDATION '''
    '''----------------------------------------------------------------------- '''

    def callback(iteration, model):
        out_str = None

        def save_everything(model):
            with open(os.path.join(eh.exp_dir, 'model_state_dict.pt'),
                      'wb') as f:
                torch.save(model.state_dict(), f)

            with open(os.path.join(eh.exp_dir, 'settings.pickle'), "wb") as f:
                pickle.dump(settings, f)

        if iteration % 25 == 0:
            model.eval()

            offset = 0
            train_loss = []
            valid_loss = []
            yy, yy_pred = [], []
            for i in range(len(X_valid) // args.batch_size):
                idx = slice(offset, offset + args.batch_size)
                Xt, yt = X_train[idx], y_train[idx]
                X_var = wrap_X(Xt)
                y_var = wrap(yt)
                tl = unwrap(loss(model(X_var), y_var))
                train_loss.append(tl)
                X = unwrap_X(X_var)
                y = unwrap(y_var)

                Xv, yv = X_valid[idx], y_valid[idx]
                X_var = wrap_X(Xv)
                y_var = wrap(yv)
                y_pred = model(X_var)
                vl = unwrap(loss(y_pred, y_var))
                valid_loss.append(vl)
                Xv = unwrap_X(X_var)
                yv = unwrap(y_var)
                y_pred = unwrap(y_pred)
                yy.append(yv)
                yy_pred.append(y_pred)

                offset += args.batch_size

            train_loss = np.mean(np.array(train_loss))
            valid_loss = np.mean(np.array(valid_loss))
            yy = np.concatenate(yy, 0)
            yy_pred = np.concatenate(yy_pred, 0)

            roc_auc = roc_auc_score(yy, yy_pred, sample_weight=w_valid)

            # 1/fpr
            fpr, tpr, _ = roc_curve(yy, yy_pred, sample_weight=w_valid)
            inv_fpr = inv_fpr_at_tpr_equals_half(tpr, fpr)

            if np.isnan(inv_fpr):
                logging.warning("NaN in 1/FPR\n")

            if inv_fpr > best_score[0]:
                best_score[0] = inv_fpr
                save_everything(model)

            out_str = "{:5d}\t~loss(train)={:.4f}\tloss(valid)={:.4f}\troc_auc(valid)={:.4f}".format(
                iteration,
                train_loss,
                valid_loss,
                roc_auc,
            )

            out_str += "\t1/FPR @ TPR = 0.5: {:.2f}\tBest 1/FPR @ TPR = 0.5: {:.2f}".format(
                inv_fpr, best_score[0])

            # NB: the metric argument is only meaningful for ReduceLROnPlateau;
            # ExponentialLR.step() treats its argument as an epoch index.
            scheduler.step(valid_loss)
            model.train()
        return out_str

    ''' TRAINING '''
    '''----------------------------------------------------------------------- '''

    logging.warning("Training...")
    for i in range(args.n_epochs):
        logging.info("epoch = %d" % i)
        logging.info("step_size = %.8f" % settings['step_size'])

        for j in range(n_batches):

            model.train()
            optimizer.zero_grad()
            start = torch.round(
                torch.rand(1) *
                (len(X_train) - args.batch_size)).numpy()[0].astype(np.int32)
            idx = slice(start, start + args.batch_size)
            X, y = X_train[idx], y_train[idx]
            X_var = wrap_X(X)
            y_var = wrap(y)
            l = loss(model(X_var), y_var)
            l.backward()
            optimizer.step()
            X = unwrap_X(X_var)
            y = unwrap(y_var)

            out_str = callback(j, model)

            if out_str is not None:
                signal_handler.results_strings.append(out_str)
                logging.info(out_str)

        scheduler.step()
        settings['step_size'] = args.step_size * (args.decay)**(i + 1)
    logging.info("FINISHED TRAINING")
    signal_handler.job_completed()
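The callback above reports 1/FPR at a signal efficiency of 0.5. A possible implementation of the inv_fpr_at_tpr_equals_half helper is sketched below, assuming it simply interpolates the ROC curve at TPR = 0.5; the project's actual version may differ.

# Sketch only: one plausible implementation of the metric used in the callback above.
import numpy as np

def inv_fpr_at_tpr_equals_half(tpr, fpr):
    # Interpolate the false-positive rate at a true-positive rate of 0.5
    # and return the background rejection 1/FPR at that working point.
    fpr_at_half = np.interp(0.5, tpr, fpr)
    return 1.0 / fpr_at_half if fpr_at_half > 0 else np.inf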
from monitors.losses import *
from monitors.monitors import *

from architectures import PredictFromParticleEmbedding
#from architectures import AdversarialParticleEmbedding

from loading import load_data
from loading import load_tf
from loading import crop

from sklearn.utils import shuffle

filename = 'antikt-kt'
data_dir = '/scratch/psn240/capstone/data/w-vs-qcd/pickles/'
tf = load_tf(data_dir, "{}-train.pickle".format(filename))
X, y = load_data(data_dir, "{}-train.pickle".format(filename))
for ij, jet in enumerate(X):
    jet["content"] = tf.transform(jet["content"])
Z = [0] * len(y)

print(len(X))
print(len(y))

filename = 'antikt-kt-pileup25-new'
data_dir = '/scratch/psn240/capstone/data/w-vs-qcd/pickles/'
tf_pileup = load_tf(data_dir, "{}-train.pickle".format(filename))
X_pileup, y_pileup = load_data(data_dir, "{}-train.pickle".format(filename))
for ij, jet in enumerate(X_pileup):
    jet["content"] = tf_pileup.transform(jet["content"])
Z_pileup = [1] * len(y_pileup)
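A possible continuation of the script above (not in the original): merge the nominal and pileup jets into one sample, keeping Z as the domain label that an adversary could later be trained to predict, as the adversarial train() further below does.

# Hypothetical continuation: combine both samples, with Z as the pileup/domain label.
import numpy as np
from sklearn.utils import shuffle

X_all = X + X_pileup
y_all = np.concatenate((y, y_pileup), axis=0)
Z_all = np.array(Z + Z_pileup)
X_all, y_all, Z_all = shuffle(X_all, y_all, Z_all, random_state=0)
print(len(X_all), len(y_all), len(Z_all))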
Example #5
File: evaluation.py  Project: wz1070/jets
def main():

    eh = ExperimentHandler(args, REPORTS_DIR)
    signal_handler = eh.signal_handler
    ''' GET RELATIVE PATHS TO DATA AND MODELS '''
    '''----------------------------------------------------------------------- '''
    with open(args.model_list_filename, "r") as f:
        model_paths = [l.strip('\n') for l in f if l.strip() and not l.startswith('#')]

    with open(args.data_list_filename, "r") as f:
        data_paths = [l.strip('\n') for l in f if l.strip() and not l.startswith('#')]

    logging.info("DATA PATHS\n{}".format("\n".join(data_paths)))
    logging.info("MODEL PATHS\n{}".format("\n".join(model_paths)))
    ''' BUILD ROCS '''
    '''----------------------------------------------------------------------- '''
    if args.load_rocs is None:
        for data_path in data_paths:

            logging.info(
                'Building ROCs for models trained on {}'.format(data_path))
            tf = load_tf(DATA_DIR, "{}-train.pickle".format(data_path))
            if args.set == 'test':
                data = load_test(tf, DATA_DIR,
                                 "{}-test.pickle".format(data_path),
                                 args.n_test)
            elif args.set == 'valid':
                data = load_test(tf, DATA_DIR,
                                 "{}-valid.pickle".format(data_path),
                                 args.n_test)
            elif args.set == 'train':
                data = load_test(tf, DATA_DIR,
                                 "{}-train.pickle".format(data_path),
                                 args.n_test)

            for model_path in model_paths:
                logging.info(
                    '\tBuilding ROCs for instances of {}'.format(model_path))
                r, f, t = build_rocs(data,
                                     os.path.join(MODELS_DIR,
                                                  model_path), args.batch_size)

                absolute_roc_path = os.path.join(
                    eh.exp_dir,
                    "rocs-{}-{}.pickle".format("-".join(model_path.split('/')),
                                               data_path))
                with open(absolute_roc_path, "wb") as fd:
                    pickle.dump((r, f, t), fd)
    else:
        for data_path in data_paths:
            for model_path in model_paths:

                previous_absolute_roc_path = os.path.join(
                    REPORTS_DIR, args.load_rocs,
                    "rocs-{}-{}.pickle".format("-".join(model_path.split('/')),
                                               data_path))
                with open(previous_absolute_roc_path, "rb") as fd:
                    r, f, t = pickle.load(fd)

                absolute_roc_path = os.path.join(
                    eh.exp_dir,
                    "rocs-{}-{}.pickle".format("-".join(model_path.split('/')),
                                               data_path))
                with open(absolute_roc_path, "wb") as fd:
                    pickle.dump((r, f, t), fd)
    ''' PLOT ROCS '''
    '''----------------------------------------------------------------------- '''

    labels = model_paths
    colors = ['c', 'm', 'y', 'k']

    for data_path in data_paths:
        for model_path, label, color in zip(model_paths, labels, colors):
            absolute_roc_path = os.path.join(
                eh.exp_dir,
                "rocs-{}-{}.pickle".format("-".join(model_path.split('/')),
                                           data_path))
            with open(absolute_roc_path, "rb") as fd:
                r, f, t = pickle.load(fd)

            if args.remove_outliers:
                r, f, t = remove_outliers(r, f, t)

            report_score(r, f, t, label=label)
            plot_rocs(r, f, t, label=label, color=color)

    figure_filename = os.path.join(eh.exp_dir, 'rocs.png')
    plot_save(figure_filename)
    if args.plot:
        plot_show()

    signal_handler.job_completed()
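For reference, a pickle written by main() above can be reloaded as shown below; the path is a hypothetical example, and the tuple layout (rocs, fprs, tprs) matches what this main() dumps.

# Hypothetical example of reloading one of the ROC pickles written by main() above.
import os
import pickle
import numpy as np

roc_path = os.path.join('reports', 'rocs-recnn-simple-antikt-kt.pickle')  # hypothetical filename
with open(roc_path, 'rb') as fd:
    rocs, fprs, tprs = pickle.load(fd)
print('mean ROC AUC over {} instances: {:.4f}'.format(len(rocs), np.mean(rocs)))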
def train(args):
    _, Transform, model_type = TRANSFORMS[args.model_type]
    args.root_exp_dir = os.path.join(MODELS_DIR, model_type, str(args.iters))

    eh = ExperimentHandler(args)
    ''' DATA '''
    '''----------------------------------------------------------------------- '''
    logging.warning("Loading data...")

    tf = load_tf(args.data_dir, "{}-train.pickle".format(args.filename))
    X, y = load_data(args.data_dir, "{}-train.pickle".format(args.filename))
    for ij, jet in enumerate(X):
        jet["content"] = tf.transform(jet["content"])

    if args.n_train > 0:
        indices = torch.randperm(len(X)).numpy()[:args.n_train]
        X = [X[i] for i in indices]
        y = y[indices]

    logging.warning("Splitting into train and validation...")

    X_train, X_valid_uncropped, y_train, y_valid_uncropped = train_test_split(
        X, y, test_size=args.n_valid, random_state=0)
    logging.warning("\traw train size = %d" % len(X_train))
    logging.warning("\traw valid size = %d" % len(X_valid_uncropped))

    X_valid, y_valid, cropped_indices, w_valid = crop(
        X_valid_uncropped,
        y_valid_uncropped,
        0,
        return_cropped_indices=True,
        pileup=args.pileup)
    # add cropped indices to training data
    if not args.dont_add_cropped:
        X_train.extend([
            x for i, x in enumerate(X_valid_uncropped) if i in cropped_indices
        ])
        y_train = [y for y in y_train]
        y_train.extend([
            y for i, y in enumerate(y_valid_uncropped) if i in cropped_indices
        ])
        y_train = np.array(y_train)
    logging.warning("\tfinal train size = %d" % len(X_train))
    logging.warning("\tfinal valid size = %d" % len(X_valid))
    ''' MODEL '''
    '''----------------------------------------------------------------------- '''
    # Initialization
    logging.info("Initializing model...")
    Predict = PredictFromParticleEmbedding
    if args.load is None:
        model_kwargs = {
            'features': args.features,
            'hidden': args.hidden,
            'iters': args.iters,
            'leaves': args.leaves,
        }
        model = Predict(Transform, **model_kwargs)
        settings = {
            "transform": Transform,
            "predict": Predict,
            "model_kwargs": model_kwargs,
            "step_size": args.step_size,
            "args": args,
        }
    else:
        with open(os.path.join(args.load, 'settings.pickle'), "rb") as f:
            settings = pickle.load(f, encoding='latin-1')
            Transform = settings["transform"]
            Predict = settings["predict"]
            model_kwargs = settings["model_kwargs"]

        model = PredictFromParticleEmbedding(Transform, **model_kwargs)

        try:
            with open(os.path.join(args.load, 'cpu_model_state_dict.pt'),
                      'rb') as f:
                state_dict = torch.load(f)
        except FileNotFoundError as e:
            with open(os.path.join(args.load, 'model_state_dict.pt'),
                      'rb') as f:
                state_dict = torch.load(f)

        model.load_state_dict(state_dict)

        if args.restart:
            args.step_size = settings["step_size"]

    logging.warning(model)
    out_str = 'Number of parameters: {}'.format(
        sum(np.prod(p.data.numpy().shape) for p in model.parameters()))
    logging.warning(out_str)

    if torch.cuda.is_available():
        logging.warning("Moving model to GPU")
        model.cuda()
        logging.warning("Moved model to GPU")

    eh.signal_handler.set_model(model)
    ''' OPTIMIZER AND LOSS '''
    '''----------------------------------------------------------------------- '''
    logging.info("Building optimizer...")
    optimizer = Adam(model.parameters(), lr=args.step_size)
    scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=args.decay)
    #scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5)

    n_batches = int(len(X_train) // args.batch_size)
    best_score = [-np.inf]  # yuck, but works
    best_model_state_dict = copy.deepcopy(model.state_dict())

    def loss(y_pred, y):
        l = log_loss(y, y_pred.squeeze(1)).mean()
        return l

    ''' VALIDATION '''
    '''----------------------------------------------------------------------- '''

    def callback(epoch, iteration, model):

        if iteration % n_batches == 0:
            t0 = time.time()
            model.eval()

            offset = 0
            train_loss = []
            valid_loss = []
            yy, yy_pred = [], []
            for i in range(len(X_valid) // args.batch_size):
                idx = slice(offset, offset + args.batch_size)
                Xt, yt = X_train[idx], y_train[idx]
                X_var = wrap_X(Xt)
                y_var = wrap(yt)
                tl = unwrap(loss(model(X_var), y_var))
                train_loss.append(tl)
                X = unwrap_X(X_var)
                y = unwrap(y_var)

                Xv, yv = X_valid[idx], y_valid[idx]
                X_var = wrap_X(Xv)
                y_var = wrap(yv)
                y_pred = model(X_var)
                vl = unwrap(loss(y_pred, y_var))
                valid_loss.append(vl)
                Xv = unwrap_X(X_var)
                yv = unwrap(y_var)
                y_pred = unwrap(y_pred)
                yy.append(yv)
                yy_pred.append(y_pred)

                offset += args.batch_size

            train_loss = np.mean(np.array(train_loss))
            valid_loss = np.mean(np.array(valid_loss))
            yy = np.concatenate(yy, 0)
            yy_pred = np.concatenate(yy_pred, 0)

            t1 = time.time()
            logging.info("Modeling validation data took {}s".format(t1 - t0))
            logdict = dict(
                epoch=epoch,
                iteration=iteration,
                yy=yy,
                yy_pred=yy_pred,
                w_valid=w_valid[:len(yy_pred)],
                #w_valid=w_valid,
                train_loss=train_loss,
                valid_loss=valid_loss,
                settings=settings,
                model=model)
            eh.log(**logdict)

            scheduler.step(valid_loss)
            model.train()

    ''' TRAINING '''
    '''----------------------------------------------------------------------- '''
    eh.save(model, settings)
    logging.warning("Training...")
    iteration = 1
    for i in range(args.epochs):
        logging.info("epoch = %d" % i)
        logging.info("step_size = %.8f" % settings['step_size'])
        t0 = time.time()
        for _ in range(n_batches):
            iteration += 1
            model.train()
            optimizer.zero_grad()
            start = torch.round(
                torch.rand(1) *
                (len(X_train) - args.batch_size)).numpy()[0].astype(np.int32)
            idx = slice(start, start + args.batch_size)
            X, y = X_train[idx], y_train[idx]
            X_var = wrap_X(X)
            y_var = wrap(y)
            l = loss(model(X_var), y_var)
            l.backward()
            optimizer.step()
            X = unwrap_X(X_var)
            y = unwrap(y_var)
            callback(i, iteration, model)
        t1 = time.time()
        logging.info("Epoch took {} seconds".format(t1 - t0))

        scheduler.step()
        settings['step_size'] = args.step_size * (args.decay)**(i + 1)

    eh.finished()
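The training loop above draws a random window start with torch.round(torch.rand(1) * ...); a near-equivalent, arguably clearer formulation is sketched below (an illustration, not the original code).

# Sketch of the random mini-batch window used in the training loop above.
import numpy as np

def random_batch_slice(n_examples, batch_size):
    # start is uniform over 0 .. n_examples - batch_size (inclusive)
    start = np.random.randint(0, n_examples - batch_size + 1)
    return slice(start, start + batch_size)

# usage: idx = random_batch_slice(len(X_train), args.batch_size); X, y = X_train[idx], y_train[idx]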
def train(args):
    _, Transform, model_type = TRANSFORMS[args.model_type]
    args.root_exp_dir = os.path.join(MODELS_DIR, model_type, str(args.iters))

    eh = ExperimentHandler(args)
    ''' DATA '''
    '''----------------------------------------------------------------------- '''
    logging.warning("Loading pileup antikt-kt-pileup40 data...")

    tf_pileup_40 = load_tf(args.data_dir,
                           "{}-train.pickle".format(args.filename))
    X_pileup_40, y_pileup_40 = load_data(
        args.data_dir, "{}-train.pickle".format(args.filename))
    for ij, jet in enumerate(X_pileup_40):
        jet["content"] = tf_pileup_40.transform(jet["content"])

    if args.n_train > 0:
        indices = torch.randperm(len(X_pileup_40)).numpy()[:args.n_train]
        X_pileup_40 = [X_pileup_40[i] for i in indices]
        y_pileup_40 = y_pileup_40[indices]

    logging.warning("Splitting into train and validation...")

    X_train_pileup_40, X_valid_uncropped_pileup_40, y_train_pileup_40, y_valid_uncropped_pileup_40 = train_test_split(
        X_pileup_40, y_pileup_40, test_size=args.n_valid, random_state=0)
    logging.warning("\traw train size = %d" % len(X_train_pileup_40))
    logging.warning("\traw valid size = %d" % len(X_valid_uncropped_pileup_40))

    X_valid_pileup_40, y_valid_pileup_40, cropped_indices_40, w_valid_40 = crop(
        X_valid_uncropped_pileup_40,
        y_valid_uncropped_pileup_40,
        pileup_lvl=40,
        return_cropped_indices=True,
        pileup=args.pileup)
    # add cropped indices to training data
    if not args.dont_add_cropped:
        X_train_pileup_40.extend([
            x for i, x in enumerate(X_valid_uncropped_pileup_40)
            if i in cropped_indices_40
        ])
        y_train_pileup_40 = [y for y in y_train_pileup_40]
        y_train_pileup_40.extend([
            y for i, y in enumerate(y_valid_uncropped_pileup_40)
            if i in cropped_indices_40
        ])
        y_train_pileup_40 = np.array(y_train_pileup_40)

    Z_train_pileup_40 = [0] * len(y_train_pileup_40)
    Z_valid_pileup_40 = [0] * len(y_valid_pileup_40)
    ''' DATA '''
    '''----------------------------------------------------------------------- '''
    logging.warning("Loading pileup antikt-kt-pileup50 data...")
    args.filename = 'antikt-kt-pileup50'
    args.pileup = False

    tf_pileup_50 = load_tf(args.data_dir,
                           "{}-train.pickle".format(args.filename))
    X_pileup_50, y_pileup_50 = load_data(
        args.data_dir, "{}-train.pickle".format(args.filename))
    for ij, jet in enumerate(X_pileup_50):
        jet["content"] = tf_pileup_50.transform(jet["content"])

    if args.n_train > 0:
        indices = torch.randperm(len(X_pileup_50)).numpy()[:args.n_train]
        X_pileup_50 = [X_pileup_50[i] for i in indices]
        y_pileup_50 = y_pileup_50[indices]

    logging.warning("Splitting into train and validation...")

    X_train_pileup_50, X_valid_uncropped_pileup_50, y_train_pileup_50, y_valid_uncropped_pileup_50 = train_test_split(
        X_pileup_50, y_pileup_50, test_size=args.n_valid, random_state=0)
    logging.warning("\traw train size = %d" % len(X_train_pileup_50))
    logging.warning("\traw valid size = %d" % len(X_valid_uncropped_pileup_50))

    X_valid_pileup_50, y_valid_pileup_50, cropped_indices_50, w_valid_50 = crop(
        X_valid_uncropped_pileup_50,
        y_valid_uncropped_pileup_50,
        pileup_lvl=50,
        return_cropped_indices=True,
        pileup=args.pileup)
    # add cropped indices to training data
    if not args.dont_add_cropped:
        X_train_pileup_50.extend([
            x for i, x in enumerate(X_valid_uncropped_pileup_50)
            if i in cropped_indices_50
        ])
        y_train_pileup_50 = [y for y in y_train_pileup_50]
        y_train_pileup_50.extend([
            y for i, y in enumerate(y_valid_uncropped_pileup_50)
            if i in cropped_indices_50
        ])
        y_train_pileup_50 = np.array(y_train_pileup_50)

    Z_train_pileup_50 = [1] * len(y_train_pileup_50)
    Z_valid_pileup_50 = [1] * len(y_valid_pileup_50)
    ''' DATA '''
    '''----------------------------------------------------------------------- '''
    logging.warning("Loading pileup antikt-kt-pileup60 data...")
    args.filename = 'antikt-kt-pileup60'
    args.pileup = False

    tf_pileup_60 = load_tf(args.data_dir,
                           "{}-train.pickle".format(args.filename))
    X_pileup_60, y_pileup_60 = load_data(
        args.data_dir, "{}-train.pickle".format(args.filename))
    for ij, jet in enumerate(X_pileup_60):
        jet["content"] = tf_pileup_60.transform(jet["content"])

    if args.n_train > 0:
        indices = torch.randperm(len(X_pileup_60)).numpy()[:args.n_train]
        X_pileup_60 = [X_pileup_60[i] for i in indices]
        y_pileup_60 = y_pileup_60[indices]

    logging.warning("Splitting into train and validation...")

    X_train_pileup_60, X_valid_uncropped_pileup_60, y_train_pileup_60, y_valid_uncropped_pileup_60 = train_test_split(
        X_pileup_60, y_pileup_60, test_size=args.n_valid, random_state=0)
    logging.warning("\traw train size = %d" % len(X_train_pileup_60))
    logging.warning("\traw valid size = %d" % len(X_valid_uncropped_pileup_60))

    X_valid_pileup_60, y_valid_pileup_60, cropped_indices_60, w_valid_60 = crop(
        X_valid_uncropped_pileup_60,
        y_valid_uncropped_pileup_60,
        pileup_lvl=60,
        return_cropped_indices=True,
        pileup=args.pileup)
    # add cropped indices to training data
    if not args.dont_add_cropped:
        X_train_pileup_60.extend([
            x for i, x in enumerate(X_valid_uncropped_pileup_60)
            if i in cropped_indices_60
        ])
        y_train_pileup_60 = [y for y in y_train_pileup_60]
        y_train_pileup_60.extend([
            y for i, y in enumerate(y_valid_uncropped_pileup_60)
            if i in cropped_indices_60
        ])
        y_train_pileup_60 = np.array(y_train_pileup_60)

    Z_train_pileup_60 = [2] * len(y_train_pileup_60)
    Z_valid_pileup_60 = [2] * len(y_valid_pileup_60)

    X_train = np.concatenate(
        (X_train_pileup_40, X_train_pileup_50, X_train_pileup_60), axis=0)
    X_valid = np.concatenate(
        (X_valid_pileup_40, X_valid_pileup_50, X_valid_pileup_60), axis=0)
    y_train = np.concatenate(
        (y_train_pileup_40, y_train_pileup_50, y_train_pileup_60), axis=0)
    y_valid = np.concatenate(
        (y_valid_pileup_40, y_valid_pileup_50, y_valid_pileup_60), axis=0)
    Z_train = np.concatenate(
        (Z_train_pileup_40, Z_train_pileup_50, Z_train_pileup_60), axis=0)
    Z_valid = np.concatenate(
        (Z_valid_pileup_40, Z_valid_pileup_50, Z_valid_pileup_60), axis=0)
    w_valid = np.concatenate((w_valid_40, w_valid_50, w_valid_60), axis=0)

    X_train, y_train, Z_train = shuffle(X_train,
                                        y_train,
                                        Z_train,
                                        random_state=0)
    X_valid, y_valid, Z_valid = shuffle(X_valid,
                                        y_valid,
                                        Z_valid,
                                        random_state=0)

    logging.warning("\tfinal X train size = %d" % len(X_train))
    logging.warning("\tfinal X valid size = %d" % len(X_valid))
    logging.warning("\tfinal Y train size = %d" % len(y_train))
    logging.warning("\tfinal Y valid size = %d" % len(y_valid))
    logging.warning("\tfinal Z train size = %d" % len(Z_train))
    logging.warning("\tfinal Z valid size = %d" % len(Z_valid))
    logging.warning("\tfinal w valid size = %d" % len(w_valid))
    ''' MODEL '''
    '''----------------------------------------------------------------------- '''
    # Initialization
    logging.info("Initializing model...")
    Predict = PredictFromParticleEmbedding
    if args.load is None:
        model_kwargs = {
            'features': args.features,
            'hidden': args.hidden,
            'iters': args.iters,
            'leaves': args.leaves,
            'batch': args.batch_size,
        }
        logging.info('No previous models')
        model = Predict(Transform, **model_kwargs)
        # NB: the adversary is only built here; if args.load is given, the
        # else-branch below never defines adversarial_model.
        adversarial_model = AdversarialParticleEmbedding(**model_kwargs)
        settings = {
            "transform": Transform,
            "predict": Predict,
            "model_kwargs": model_kwargs,
            "step_size": args.step_size,
            "args": args,
        }
    else:
        with open(os.path.join(args.load, 'settings.pickle'), "rb") as f:
            settings = pickle.load(f, encoding='latin-1')
            Transform = settings["transform"]
            Predict = settings["predict"]
            model_kwargs = settings["model_kwargs"]

        model = PredictFromParticleEmbedding(Transform, **model_kwargs)

        try:
            with open(os.path.join(args.load, 'cpu_model_state_dict.pt'),
                      'rb') as f:
                state_dict = torch.load(f)
        except FileNotFoundError as e:
            with open(os.path.join(args.load, 'model_state_dict.pt'),
                      'rb') as f:
                state_dict = torch.load(f)

        model.load_state_dict(state_dict)

        if args.restart:
            args.step_size = settings["step_size"]

    logging.warning(model)
    logging.warning(adversarial_model)

    out_str = 'Number of parameters: {}'.format(
        sum(np.prod(p.data.numpy().shape) for p in model.parameters()))
    out_str_adversarial = 'Number of parameters: {}'.format(
        sum(
            np.prod(p.data.numpy().shape)
            for p in adversarial_model.parameters()))
    logging.warning(out_str)
    logging.warning(out_str_adversarial)

    if torch.cuda.is_available():
        logging.warning("Moving model to GPU")
        model.cuda()
        logging.warning("Moved model to GPU")
    else:
        logging.warning("No cuda")

    eh.signal_handler.set_model(model)
    ''' OPTIMIZER AND LOSS '''
    '''----------------------------------------------------------------------- '''

    logging.info("Building optimizer...")
    optimizer = Adam(model.parameters(), lr=args.step_size)
    optimizer_adv = Adam(adversarial_model.parameters(), lr=args.step_size)
    scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=args.decay)
    scheduler_adv = lr_scheduler.ExponentialLR(optimizer_adv, gamma=args.decay)
    #scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5)

    n_batches = int(len(X_train) // args.batch_size)
    best_score = [-np.inf]  # yuck, but works
    best_model_state_dict = copy.deepcopy(model.state_dict())

    def loss_adversarial(y_pred, y):
        return -(y * torch.log(y_pred) + (1. - y) * torch.log(1. - y_pred))

    def loss(y_pred, y):
        l = log_loss(y, y_pred.squeeze(1)).mean()
        return l

    ''' VALIDATION '''
    '''----------------------------------------------------------------------- '''

    def callback(epoch, iteration, model):

        if iteration % n_batches == 0:
            t0 = time.time()
            model.eval()

            offset = 0
            train_loss = []
            valid_loss = []
            yy, yy_pred = [], []
            for i in range(len(X_valid) // args.batch_size):
                idx = slice(offset, offset + args.batch_size)
                Xt, yt = X_train[idx], y_train[idx]
                X_var = wrap_X(Xt)
                y_var = wrap(yt)
                y_pred_1 = model(X_var)
                tl = unwrap(loss(y_pred_1, y_var))
                train_loss.append(tl)
                X = unwrap_X(X_var)
                y = unwrap(y_var)

                Xv, yv = X_valid[idx], y_valid[idx]
                X_var = wrap_X(Xv)
                y_var = wrap(yv)
                y_pred = model(X_var)
                vl = unwrap(loss(y_pred, y_var))
                valid_loss.append(vl)
                Xv = unwrap_X(X_var)
                yv = unwrap(y_var)
                y_pred = unwrap(y_pred)
                yy.append(yv)
                yy_pred.append(y_pred)

                offset += args.batch_size

            train_loss = np.mean(np.array(train_loss))
            valid_loss = np.mean(np.array(valid_loss))
            yy = np.concatenate(yy, 0)
            yy_pred = np.concatenate(yy_pred, 0)

            t1 = time.time()
            logging.info("Modeling validation data took {}s".format(t1 - t0))
            logging.info(len(yy_pred))
            logging.info(len(yy))
            logging.info(len(w_valid))
            logdict = dict(
                epoch=epoch,
                iteration=iteration,
                yy=yy,
                yy_pred=yy_pred,
                w_valid=w_valid[:len(yy_pred)],
                #w_valid=w_valid,
                train_loss=train_loss,
                valid_loss=valid_loss,
                settings=settings,
                model=model)
            eh.log(**logdict)

            scheduler.step(valid_loss)
            model.train()

    ''' TRAINING '''
    '''----------------------------------------------------------------------- '''
    eh.save(model, settings)
    logging.warning("Training...")
    iteration = 1
    loss_rnn = []
    loss_adv = []
    logging.info("Lambda selected = %.8f" % args.lmbda)
    for i in range(args.epochs):
        logging.info("epoch = %d" % i)
        logging.info("step_size = %.8f" % settings['step_size'])
        t0 = time.time()
        for _ in range(n_batches):
            iteration += 1
            model.train()
            adversarial_model.train()
            optimizer.zero_grad()
            optimizer_adv.zero_grad()
            start = torch.round(
                torch.rand(1) *
                (len(X_train) - args.batch_size)).numpy()[0].astype(np.int32)
            idx = slice(start, start + args.batch_size)
            X, y, Z = X_train[idx], y_train[idx], Z_train[idx]
            X_var = wrap_X(X)
            y_var = wrap(y)
            #Z_var = wrap(Z, 'long')
            y_pred = model(X_var)
            #l = loss(y_pred, y_var) - loss(adversarial_model(y_pred), Z_var)
            #print(adversarial_model(y_pred))
            Z_var = Variable(torch.squeeze(torch.from_numpy(Z)))
            #print(Z_var)
            l_rnn = loss(y_pred, y_var)
            loss_rnn.append(l_rnn.data.cpu().numpy()[0])
            l_adv = F.nll_loss(adversarial_model(y_pred), Z_var)
            loss_adv.append(l_adv.data.cpu().numpy()[0])
            l = l_rnn - (args.lmbda * l_adv)
            # Classifier update: backpropagate the combined loss, then step
            l.backward(retain_graph=True)
            optimizer.step()

            # Adversary update: clear the gradients the combined loss left on the
            # adversary, backpropagate its own loss, then step
            optimizer_adv.zero_grad()
            l_adv.backward()
            optimizer_adv.step()

            X = unwrap_X(X_var)
            y = unwrap(y_var)
            callback(i, iteration, model)
        t1 = time.time()
        logging.info("Epoch took {} seconds".format(t1 - t0))

        scheduler.step()
        scheduler_adv.step()
        settings['step_size'] = args.step_size * (args.decay)**(i + 1)
    #logging.info(loss_rnn)
    #logging.info('==================================================')
    #logging.info(loss_adv)
    logging.info('PID : %d' % os.getpid())
    pathset = os.path.join(args.data_dir, str(os.getpid()))
    os.mkdir(pathset)
    # np.save appends ".npy" and writes NumPy binary despite the .csv suffix here
    np.save(os.path.join(pathset, 'rnn_loss.csv'), np.array(loss_rnn))
    np.save(os.path.join(pathset, 'adv_loss.csv'), np.array(loss_adv))
    eh.finished()
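AdversarialParticleEmbedding itself is not shown (its import is commented out in an earlier snippet). Below is a minimal sketch of an adversary head compatible with the F.nll_loss call above, under the assumption that it maps the classifier output to log-probabilities over the three pileup buckets (Z in {0, 1, 2}); it is not the project's actual module.

# Minimal sketch (assumption, not the project's AdversarialParticleEmbedding):
# an adversary head whose forward() returns log-probabilities, as F.nll_loss expects.
import torch.nn as nn
import torch.nn.functional as F

class PileupAdversary(nn.Module):
    def __init__(self, hidden=64, n_buckets=3):
        super(PileupAdversary, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(1, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_buckets),
        )

    def forward(self, y_pred):
        # y_pred: (batch, 1) classifier outputs; returns (batch, n_buckets) log-probabilities
        return F.log_softmax(self.net(y_pred), dim=1)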