Code example #1
    def callback(epoch, iteration, model):

        if iteration % n_batches == 0:
            t0 = time.time()
            model.eval()

            offset = 0
            train_loss = []
            valid_loss = []
            yy, yy_pred = [], []
            for i in range(len(X_valid) // args.batch_size):
                idx = slice(offset, offset + args.batch_size)
                Xt, yt = X_train[idx], y_train[idx]
                X_var = wrap_X(Xt)
                y_var = wrap(yt)
                y_pred_1 = model(X_var)
                tl = unwrap(loss(y_pred_1, y_var))
                train_loss.append(tl)
                X = unwrap_X(X_var)
                y = unwrap(y_var)

                Xv, yv = X_valid[idx], y_valid[idx]
                X_var = wrap_X(Xv)
                y_var = wrap(yv)
                y_pred = model(X_var)
                vl = unwrap(loss(y_pred, y_var))
                valid_loss.append(vl)
                Xv = unwrap_X(X_var)
                yv = unwrap(y_var)
                y_pred = unwrap(y_pred)
                yy.append(yv)
                yy_pred.append(y_pred)

                offset += args.batch_size

            train_loss = np.mean(np.array(train_loss))
            valid_loss = np.mean(np.array(valid_loss))
            yy = np.concatenate(yy, 0)
            yy_pred = np.concatenate(yy_pred, 0)

            t1 = time.time()
            logging.info("Modeling validation data took {}s".format(t1 - t0))
            logging.info("len(yy_pred) = %d, len(yy) = %d, len(w_valid) = %d" %
                         (len(yy_pred), len(yy), len(w_valid)))
            logdict = dict(
                epoch=epoch,
                iteration=iteration,
                yy=yy,
                yy_pred=yy_pred,
                w_valid=w_valid[:len(yy_pred)],
                #w_valid=w_valid,
                train_loss=train_loss,
                valid_loss=valid_loss,
                settings=settings,
                model=model)
            eh.log(**logdict)

            # NOTE: passing the loss assumes a ReduceLROnPlateau-style
            # scheduler
            scheduler.step(valid_loss)
            model.train()
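
The wrap/unwrap helpers used throughout these examples are not shown. From their usage (numpy arrays in, model inputs out, and back again), a minimal sketch assuming old-style PyTorch Variables might look like the following; the whole block is an assumption, not the original implementation.

    import numpy as np
    import torch
    from torch.autograd import Variable

    def wrap(y, dtype='float'):
        # numpy array -> Variable, optionally on the GPU
        t = torch.from_numpy(np.asarray(y))
        t = t.long() if dtype == 'long' else t.float()
        if torch.cuda.is_available():
            t = t.cuda()
        return Variable(t)

    def unwrap(v):
        # Variable/tensor -> numpy array on the CPU
        return v.data.cpu().numpy()

    def wrap_X(X):
        # wrap the "content" array of each jet dict in place
        for jet in X:
            jet["content"] = wrap(jet["content"])
        return X

    def unwrap_X(X_var):
        # undo wrap_X, restoring numpy contents
        for jet in X_var:
            jet["content"] = unwrap(jet["content"])
        return X_var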
Code example #2
    def evaluate_models(X, yy, w, model_filenames, batch_size=64):
        rocs = []
        fprs = []
        tprs = []
        inv_fprs = []

        for i, filename in enumerate(model_filenames):
            if 'DS_Store' not in filename:
                logging.info("\t[{}] Loading {}".format(i, filename)),
                model = load_model(filename)
                if torch.cuda.is_available():
                    model.cuda()
                model_test_file = os.path.join(filename, 'test-rocs.pickle')
                work = not os.path.exists(model_test_file)
                if work:
                    model.eval()

                    offset = 0
                    yy_pred = []
                    n_batches, remainder = np.divmod(len(X), batch_size)
                    for _ in range(n_batches):  # don't shadow the outer i
                        X_batch = X[offset:offset + batch_size]
                        X_var = wrap_X(X_batch)
                        yy_pred.append(unwrap(model(X_var)))
                        unwrap_X(X_var)
                        offset += batch_size
                    if remainder > 0:
                        X_batch = X[-remainder:]
                        X_var = wrap_X(X_batch)
                        yy_pred.append(unwrap(model(X_var)))
                        unwrap_X(X_var)
                    yy_pred = np.squeeze(np.concatenate(yy_pred, 0), 1)

                    # Store y_pred and y_test to disk (note: np.save appends
                    # .npy to these names despite the .csv extension)
                    np.save(args.data_dir + 'Y_pred_60.csv', yy_pred)
                    np.save(args.data_dir + 'Y_test_60.csv', yy)
                    logging.info('Files saved')

                    logdict = dict(
                        model=filename.split('/')[-1],
                        yy=yy,
                        yy_pred=yy_pred,
                        w_valid=w[:len(yy_pred)],
                    )
                    eh.log(**logdict)
                    roc = eh.monitors['roc_auc'].value
                    fpr = eh.monitors['roc_curve'].value[0]
                    tpr = eh.monitors['roc_curve'].value[1]
                    inv_fpr = eh.monitors['inv_fpr'].value

                    with open(model_test_file, "wb") as fd:
                        pickle.dump((roc, fpr, tpr, inv_fpr), fd)
                else:
                    with open(model_test_file, "rb") as fd:
                        roc, fpr, tpr, inv_fpr = pickle.load(fd)
                    stats_dict = {'roc_auc': roc, 'inv_fpr': inv_fpr}
                    eh.stats_logger.log(stats_dict)
                rocs.append(roc)
                fprs.append(fpr)
                tprs.append(tpr)
                inv_fprs.append(inv_fpr)

        logging.info("\tMean ROC AUC = {:.4f} Mean 1/FPR = {:.4f}".format(
            np.mean(rocs), np.mean(inv_fprs)))

        return rocs, fprs, tprs, inv_fprs
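
A hypothetical invocation of evaluate_models: model_filenames appears to be a list of per-model directories (each is passed to load_model and also used as the location of test-rocs.pickle). The directory layout and variable names below are assumptions for illustration only.

    import os
    model_dir = os.path.join(MODELS_DIR, model_type, str(args.iters))
    model_filenames = [
        os.path.join(model_dir, d) for d in os.listdir(model_dir)
    ]
    rocs, fprs, tprs, inv_fprs = evaluate_models(
        X_test, y_test, w_test, model_filenames, batch_size=64)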
Code example #3
def train(args):
    _, Transform, model_type = TRANSFORMS[args.model_type]
    args.root_exp_dir = os.path.join(MODELS_DIR, model_type, str(args.iters))

    eh = ExperimentHandler(args)
    ''' DATA '''
    '''----------------------------------------------------------------------- '''
    logging.warning("Loading data...")

    tf = load_tf(args.data_dir, "{}-train.pickle".format(args.filename))
    X, y = load_data(args.data_dir, "{}-train.pickle".format(args.filename))
    for jet in X:
        jet["content"] = tf.transform(jet["content"])

    if args.n_train > 0:
        indices = torch.randperm(len(X)).numpy()[:args.n_train]
        X = [X[i] for i in indices]
        y = y[indices]

    logging.warning("Splitting into train and validation...")

    X_train, X_valid_uncropped, y_train, y_valid_uncropped = train_test_split(
        X, y, test_size=args.n_valid, random_state=0)
    logging.warning("\traw train size = %d" % len(X_train))
    logging.warning("\traw valid size = %d" % len(X_valid_uncropped))

    X_valid, y_valid, cropped_indices, w_valid = crop(
        X_valid_uncropped,
        y_valid_uncropped,
        0,
        return_cropped_indices=True,
        pileup=args.pileup)
    # add cropped indices to training data
    if not args.dont_add_cropped:
        X_train.extend([
            x for i, x in enumerate(X_valid_uncropped) if i in cropped_indices
        ])
        y_train = list(y_train)
        y_train.extend([
            y for i, y in enumerate(y_valid_uncropped) if i in cropped_indices
        ])
        y_train = np.array(y_train)
    logging.warning("\tfinal train size = %d" % len(X_train))
    logging.warning("\tfinal valid size = %d" % len(X_valid))
    ''' MODEL '''
    '''----------------------------------------------------------------------- '''
    # Initialization
    logging.info("Initializing model...")
    Predict = PredictFromParticleEmbedding
    if args.load is None:
        model_kwargs = {
            'features': args.features,
            'hidden': args.hidden,
            'iters': args.iters,
            'leaves': args.leaves,
        }
        model = Predict(Transform, **model_kwargs)
        settings = {
            "transform": Transform,
            "predict": Predict,
            "model_kwargs": model_kwargs,
            "step_size": args.step_size,
            "args": args,
        }
    else:
        with open(os.path.join(args.load, 'settings.pickle'), "rb") as f:
            settings = pickle.load(f, encoding='latin-1')
            Transform = settings["transform"]
            Predict = settings["predict"]
            model_kwargs = settings["model_kwargs"]

        model = PredictFromParticleEmbedding(Transform, **model_kwargs)

        try:
            with open(os.path.join(args.load, 'cpu_model_state_dict.pt'),
                      'rb') as f:
                state_dict = torch.load(f)
        except FileNotFoundError:
            with open(os.path.join(args.load, 'model_state_dict.pt'),
                      'rb') as f:
                state_dict = torch.load(f)

        model.load_state_dict(state_dict)

        if args.restart:
            args.step_size = settings["step_size"]

    logging.warning(model)
    out_str = 'Number of parameters: {}'.format(
        sum(np.prod(p.data.numpy().shape) for p in model.parameters()))
    logging.warning(out_str)

    if torch.cuda.is_available():
        logging.warning("Moving model to GPU")
        model.cuda()
        logging.warning("Moved model to GPU")

    eh.signal_handler.set_model(model)
    ''' OPTIMIZER AND LOSS '''
    '''----------------------------------------------------------------------- '''
    logging.info("Building optimizer...")
    optimizer = Adam(model.parameters(), lr=args.step_size)
    scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=args.decay)
    #scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5)

    n_batches = int(len(X_train) // args.batch_size)
    best_score = [-np.inf]  # yuck, but works
    best_model_state_dict = copy.deepcopy(model.state_dict())

    def loss(y_pred, y):
        l = log_loss(y, y_pred.squeeze(1)).mean()
        return l
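    # NOTE: log_loss is defined elsewhere. Judging from its use here
    # (per-sample values reduced with .mean()) and the elementwise
    # loss_adversarial in code example #4 below, it is plausibly per-sample
    # binary cross-entropy; a sketch (an assumption, not the original):
    #
    #   def log_loss(y, y_pred, eps=1e-7):
    #       y_pred = y_pred.clamp(eps, 1 - eps)
    #       return -(y * torch.log(y_pred) + (1 - y) * torch.log(1 - y_pred))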

    ''' VALIDATION '''
    '''----------------------------------------------------------------------- '''

    def callback(epoch, iteration, model):

        if iteration % n_batches == 0:
            t0 = time.time()
            model.eval()

            offset = 0
            train_loss = []
            valid_loss = []
            yy, yy_pred = [], []
            for i in range(len(X_valid) // args.batch_size):
                idx = slice(offset, offset + args.batch_size)
                Xt, yt = X_train[idx], y_train[idx]
                X_var = wrap_X(Xt)
                y_var = wrap(yt)
                tl = unwrap(loss(model(X_var), y_var))
                train_loss.append(tl)
                X = unwrap_X(X_var)
                y = unwrap(y_var)

                Xv, yv = X_valid[idx], y_valid[idx]
                X_var = wrap_X(Xv)
                y_var = wrap(yv)
                y_pred = model(X_var)
                vl = unwrap(loss(y_pred, y_var))
                valid_loss.append(vl)
                Xv = unwrap_X(X_var)
                yv = unwrap(y_var)
                y_pred = unwrap(y_pred)
                yy.append(yv)
                yy_pred.append(y_pred)

                offset += args.batch_size

            train_loss = np.mean(np.array(train_loss))
            valid_loss = np.mean(np.array(valid_loss))
            yy = np.concatenate(yy, 0)
            yy_pred = np.concatenate(yy_pred, 0)

            t1 = time.time()
            logging.info("Modeling validation data took {}s".format(t1 - t0))
            logdict = dict(
                epoch=epoch,
                iteration=iteration,
                yy=yy,
                yy_pred=yy_pred,
                w_valid=w_valid[:len(yy_pred)],
                #w_valid=w_valid,
                train_loss=train_loss,
                valid_loss=valid_loss,
                settings=settings,
                model=model)
            eh.log(**logdict)

            # NOTE: step(valid_loss) matches the commented-out
            # ReduceLROnPlateau above; the ExponentialLR actually configured
            # treats this argument as an epoch index
            scheduler.step(valid_loss)
            model.train()

    ''' TRAINING '''
    '''----------------------------------------------------------------------- '''
    eh.save(model, settings)
    logging.warning("Training...")
    iteration = 1
    for i in range(args.epochs):
        logging.info("epoch = %d" % i)
        logging.info("step_size = %.8f" % settings['step_size'])
        t0 = time.time()
        for _ in range(n_batches):
            iteration += 1
            model.train()
            optimizer.zero_grad()
            start = torch.round(
                torch.rand(1) *
                (len(X_train) - args.batch_size)).numpy()[0].astype(np.int32)
            idx = slice(start, start + args.batch_size)
            X, y = X_train[idx], y_train[idx]
            X_var = wrap_X(X)
            y_var = wrap(y)
            l = loss(model(X_var), y_var)
            l.backward()
            optimizer.step()
            X = unwrap_X(X_var)
            y = unwrap(y_var)
            callback(i, iteration, model)
        t1 = time.time()
        logging.info("Epoch took {} seconds".format(t1 - t0))

        scheduler.step()
        settings['step_size'] = args.step_size * (args.decay)**(i + 1)

    eh.finished()
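
The training loop above picks a random batch start by rounding a scaled uniform draw. A roughly equivalent, more direct sketch, assuming numpy (the edge-case behavior of round vs. randint differs slightly):

    import numpy as np
    start = np.random.randint(0, len(X_train) - args.batch_size + 1)
    idx = slice(start, start + args.batch_size)
    X, y = X_train[idx], y_train[idx]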
Code example #4
def train(args):
    _, Transform, model_type = TRANSFORMS[args.model_type]
    args.root_exp_dir = os.path.join(MODELS_DIR, model_type, str(args.iters))

    eh = ExperimentHandler(args)
    ''' DATA '''
    '''----------------------------------------------------------------------- '''
    logging.warning("Loading pileup antikt-kt-pileup40 data...")

    tf_pileup_40 = load_tf(args.data_dir,
                           "{}-train.pickle".format(args.filename))
    X_pileup_40, y_pileup_40 = load_data(
        args.data_dir, "{}-train.pickle".format(args.filename))
    for jet in X_pileup_40:
        jet["content"] = tf_pileup_40.transform(jet["content"])

    if args.n_train > 0:
        indices = torch.randperm(len(X_pileup_40)).numpy()[:args.n_train]
        X_pileup_40 = [X_pileup_40[i] for i in indices]
        y_pileup_40 = y_pileup_40[indices]

    logging.warning("Splitting into train and validation...")

    X_train_pileup_40, X_valid_uncropped_pileup_40, y_train_pileup_40, y_valid_uncropped_pileup_40 = train_test_split(
        X_pileup_40, y_pileup_40, test_size=args.n_valid, random_state=0)
    logging.warning("\traw train size = %d" % len(X_train_pileup_40))
    logging.warning("\traw valid size = %d" % len(X_valid_uncropped_pileup_40))

    X_valid_pileup_40, y_valid_pileup_40, cropped_indices_40, w_valid_40 = crop(
        X_valid_uncropped_pileup_40,
        y_valid_uncropped_pileup_40,
        pileup_lvl=40,
        return_cropped_indices=True,
        pileup=args.pileup)
    # add cropped indices to training data
    if not args.dont_add_cropped:
        X_train_pileup_40.extend([
            x for i, x in enumerate(X_valid_uncropped_pileup_40)
            if i in cropped_indices_40
        ])
        y_train_pileup_40 = list(y_train_pileup_40)
        y_train_pileup_40.extend([
            y for i, y in enumerate(y_valid_uncropped_pileup_40)
            if i in cropped_indices_40
        ])
        y_train_pileup_40 = np.array(y_train_pileup_40)

    Z_train_pileup_40 = [0] * len(y_train_pileup_40)
    Z_valid_pileup_40 = [0] * len(y_valid_pileup_40)
    ''' DATA '''
    '''----------------------------------------------------------------------- '''
    logging.warning("Loading pileup antikt-kt-pileup50 data...")
    args.filename = 'antikt-kt-pileup50'
    args.pileup = False

    tf_pileup_50 = load_tf(args.data_dir,
                           "{}-train.pickle".format(args.filename))
    X_pileup_50, y_pileup_50 = load_data(
        args.data_dir, "{}-train.pickle".format(args.filename))
    for jet in X_pileup_50:
        jet["content"] = tf_pileup_50.transform(jet["content"])

    if args.n_train > 0:
        indices = torch.randperm(len(X_pileup_50)).numpy()[:args.n_train]
        X_pileup_50 = [X_pileup_50[i] for i in indices]
        y_pileup_50 = y_pileup_50[indices]

    logging.warning("Splitting into train and validation...")

    X_train_pileup_50, X_valid_uncropped_pileup_50, y_train_pileup_50, y_valid_uncropped_pileup_50 = train_test_split(
        X_pileup_50, y_pileup_50, test_size=args.n_valid, random_state=0)
    logging.warning("\traw train size = %d" % len(X_train_pileup_50))
    logging.warning("\traw valid size = %d" % len(X_valid_uncropped_pileup_50))

    X_valid_pileup_50, y_valid_pileup_50, cropped_indices_50, w_valid_50 = crop(
        X_valid_uncropped_pileup_50,
        y_valid_uncropped_pileup_50,
        pileup_lvl=50,
        return_cropped_indices=True,
        pileup=args.pileup)
    # add cropped indices to training data
    if not args.dont_add_cropped:
        X_train_pileup_50.extend([
            x for i, x in enumerate(X_valid_uncropped_pileup_50)
            if i in cropped_indices_50
        ])
        y_train_pileup_50 = list(y_train_pileup_50)
        y_train_pileup_50.extend([
            y for i, y in enumerate(y_valid_uncropped_pileup_50)
            if i in cropped_indices_50
        ])
        y_train_pileup_50 = np.array(y_train_pileup_50)

    Z_train_pileup_50 = [1] * len(y_train_pileup_50)
    Z_valid_pileup_50 = [1] * len(y_valid_pileup_50)
    ''' DATA '''
    '''----------------------------------------------------------------------- '''
    logging.warning("Loading pileup antikt-kt-pileup60 data...")
    args.filename = 'antikt-kt-pileup60'
    args.pileup = False

    tf_pileup_60 = load_tf(args.data_dir,
                           "{}-train.pickle".format(args.filename))
    X_pileup_60, y_pileup_60 = load_data(
        args.data_dir, "{}-train.pickle".format(args.filename))
    for jet in X_pileup_60:
        jet["content"] = tf_pileup_60.transform(jet["content"])

    if args.n_train > 0:
        indices = torch.randperm(len(X_pileup_60)).numpy()[:args.n_train]
        X_pileup_60 = [X_pileup_60[i] for i in indices]
        y_pileup_60 = y_pileup_60[indices]

    logging.warning("Splitting into train and validation...")

    X_train_pileup_60, X_valid_uncropped_pileup_60, y_train_pileup_60, y_valid_uncropped_pileup_60 = train_test_split(
        X_pileup_60, y_pileup_60, test_size=args.n_valid, random_state=0)
    logging.warning("\traw train size = %d" % len(X_train_pileup_60))
    logging.warning("\traw valid size = %d" % len(X_valid_uncropped_pileup_60))

    X_valid_pileup_60, y_valid_pileup_60, cropped_indices_60, w_valid_60 = crop(
        X_valid_uncropped_pileup_60,
        y_valid_uncropped_pileup_60,
        pileup_lvl=60,
        return_cropped_indices=True,
        pileup=args.pileup)
    # add cropped indices to training data
    if not args.dont_add_cropped:
        X_train_pileup_60.extend([
            x for i, x in enumerate(X_valid_uncropped_pileup_60)
            if i in cropped_indices_60
        ])
        y_train_pileup_60 = list(y_train_pileup_60)
        y_train_pileup_60.extend([
            y for i, y in enumerate(y_valid_uncropped_pileup_60)
            if i in cropped_indices_60
        ])
        y_train_pileup_60 = np.array(y_train_pileup_60)

    Z_train_pileup_60 = [2] * len(y_train_pileup_60)
    Z_valid_pileup_60 = [2] * len(y_valid_pileup_60)

    X_train = np.concatenate(
        (X_train_pileup_40, X_train_pileup_50, X_train_pileup_60), axis=0)
    X_valid = np.concatenate(
        (X_valid_pileup_40, X_valid_pileup_50, X_valid_pileup_60), axis=0)
    y_train = np.concatenate(
        (y_train_pileup_40, y_train_pileup_50, y_train_pileup_60), axis=0)
    y_valid = np.concatenate(
        (y_valid_pileup_40, y_valid_pileup_50, y_valid_pileup_60), axis=0)
    Z_train = np.concatenate(
        (Z_train_pileup_40, Z_train_pileup_50, Z_train_pileup_60), axis=0)
    Z_valid = np.concatenate(
        (Z_valid_pileup_40, Z_valid_pileup_50, Z_valid_pileup_60), axis=0)
    w_valid = np.concatenate((w_valid_40, w_valid_50, w_valid_60), axis=0)

    X_train, y_train, Z_train = shuffle(X_train,
                                        y_train,
                                        Z_train,
                                        random_state=0)
    X_valid, y_valid, Z_valid = shuffle(X_valid,
                                        y_valid,
                                        Z_valid,
                                        random_state=0)

    logging.warning("\tfinal X train size = %d" % len(X_train))
    logging.warning("\tfinal X valid size = %d" % len(X_valid))
    logging.warning("\tfinal Y train size = %d" % len(y_train))
    logging.warning("\tfinal Y valid size = %d" % len(y_valid))
    logging.warning("\tfinal Z train size = %d" % len(Z_train))
    logging.warning("\tfinal Z valid size = %d" % len(Z_valid))
    logging.warning("\tfinal w valid size = %d" % len(w_valid))
    ''' MODEL '''
    '''----------------------------------------------------------------------- '''
    # Initialization
    logging.info("Initializing model...")
    Predict = PredictFromParticleEmbedding
    if args.load is None:
        model_kwargs = {
            'features': args.features,
            'hidden': args.hidden,
            'iters': args.iters,
            'leaves': args.leaves,
            'batch': args.batch_size,
        }
        logging.info('No previous models')
        model = Predict(Transform, **model_kwargs)
        adversarial_model = AdversarialParticleEmbedding(**model_kwargs)
        settings = {
            "transform": Transform,
            "predict": Predict,
            "model_kwargs": model_kwargs,
            "step_size": args.step_size,
            "args": args,
        }
    else:
        with open(os.path.join(args.load, 'settings.pickle'), "rb") as f:
            settings = pickle.load(f, encoding='latin-1')
            Transform = settings["transform"]
            Predict = settings["predict"]
            model_kwargs = settings["model_kwargs"]

        model = PredictFromParticleEmbedding(Transform, **model_kwargs)
        # adversarial_model is referenced below regardless of which branch
        # ran, so construct it here as well
        adversarial_model = AdversarialParticleEmbedding(**model_kwargs)

        try:
            with open(os.path.join(args.load, 'cpu_model_state_dict.pt'),
                      'rb') as f:
                state_dict = torch.load(f)
        except FileNotFoundError:
            with open(os.path.join(args.load, 'model_state_dict.pt'),
                      'rb') as f:
                state_dict = torch.load(f)

        model.load_state_dict(state_dict)

        if args.restart:
            args.step_size = settings["step_size"]

    logging.warning(model)
    logging.warning(adversarial_model)

    out_str = 'Number of parameters: {}'.format(
        sum(np.prod(p.data.numpy().shape) for p in model.parameters()))
    out_str_adversarial = 'Number of adversarial parameters: {}'.format(
        sum(
            np.prod(p.data.numpy().shape)
            for p in adversarial_model.parameters()))
    logging.warning(out_str)
    logging.warning(out_str_adversarial)

    if torch.cuda.is_available():
        logging.warning("Moving model to GPU")
        model.cuda()
        logging.warning("Moved model to GPU")
    else:
        logging.warning("No cuda")

    eh.signal_handler.set_model(model)
    ''' OPTIMIZER AND LOSS '''
    '''----------------------------------------------------------------------- '''

    logging.info("Building optimizer...")
    optimizer = Adam(model.parameters(), lr=args.step_size)
    optimizer_adv = Adam(adversarial_model.parameters(), lr=args.step_size)
    scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=args.decay)
    scheduler_adv = lr_scheduler.ExponentialLR(optimizer_adv, gamma=args.decay)
    #scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5)

    n_batches = int(len(X_train) // args.batch_size)
    best_score = [-np.inf]  # yuck, but works
    best_model_state_dict = copy.deepcopy(model.state_dict())

    def loss_adversarial(y_pred, y):
        return -(y * torch.log(y_pred) + (1. - y) * torch.log(1. - y_pred))

    def loss(y_pred, y):
        l = log_loss(y, y_pred.squeeze(1)).mean()
        return l

    ''' VALIDATION '''
    '''----------------------------------------------------------------------- '''

    def callback(epoch, iteration, model):

        if iteration % n_batches == 0:
            t0 = time.time()
            model.eval()

            offset = 0
            train_loss = []
            valid_loss = []
            yy, yy_pred = [], []
            for i in range(len(X_valid) // args.batch_size):
                idx = slice(offset, offset + args.batch_size)
                Xt, yt = X_train[idx], y_train[idx]
                X_var = wrap_X(Xt)
                y_var = wrap(yt)
                y_pred_1 = model(X_var)
                tl = unwrap(loss(y_pred_1, y_var))
                train_loss.append(tl)
                X = unwrap_X(X_var)
                y = unwrap(y_var)

                Xv, yv = X_valid[idx], y_valid[idx]
                X_var = wrap_X(Xv)
                y_var = wrap(yv)
                y_pred = model(X_var)
                vl = unwrap(loss(y_pred, y_var))
                valid_loss.append(vl)
                Xv = unwrap_X(X_var)
                yv = unwrap(y_var)
                y_pred = unwrap(y_pred)
                yy.append(yv)
                yy_pred.append(y_pred)

                offset += args.batch_size

            train_loss = np.mean(np.array(train_loss))
            valid_loss = np.mean(np.array(valid_loss))
            yy = np.concatenate(yy, 0)
            yy_pred = np.concatenate(yy_pred, 0)

            t1 = time.time()
            logging.info("Modeling validation data took {}s".format(t1 - t0))
            logging.info("len(yy_pred) = %d, len(yy) = %d, len(w_valid) = %d" %
                         (len(yy_pred), len(yy), len(w_valid)))
            logdict = dict(
                epoch=epoch,
                iteration=iteration,
                yy=yy,
                yy_pred=yy_pred,
                w_valid=w_valid[:len(yy_pred)],
                #w_valid=w_valid,
                train_loss=train_loss,
                valid_loss=valid_loss,
                settings=settings,
                model=model)
            eh.log(**logdict)

            # NOTE: step(valid_loss) matches the commented-out
            # ReduceLROnPlateau above; the ExponentialLR actually configured
            # treats this argument as an epoch index
            scheduler.step(valid_loss)
            model.train()

    ''' TRAINING '''
    '''----------------------------------------------------------------------- '''
    eh.save(model, settings)
    logging.warning("Training...")
    iteration = 1
    loss_rnn = []
    loss_adv = []
    logging.info("Lambda selected = %.8f" % args.lmbda)
    for i in range(args.epochs):
        logging.info("epoch = %d" % i)
        logging.info("step_size = %.8f" % settings['step_size'])
        t0 = time.time()
        for _ in range(n_batches):
            iteration += 1
            model.train()
            adversarial_model.train()
            optimizer.zero_grad()
            optimizer_adv.zero_grad()
            start = torch.round(
                torch.rand(1) *
                (len(X_train) - args.batch_size)).numpy()[0].astype(np.int32)
            idx = slice(start, start + args.batch_size)
            X, y, Z = X_train[idx], y_train[idx], Z_train[idx]
            X_var = wrap_X(X)
            y_var = wrap(y)
            #Z_var = wrap(Z, 'long')
            y_pred = model(X_var)
            #l = loss(y_pred, y_var) - loss(adversarial_model(y_pred), Z_var)
            #print(adversarial_model(y_pred))
            Z_var = Variable(torch.squeeze(torch.from_numpy(Z)))
            #print(Z_var)
            l_rnn = loss(y_pred, y_var)
            loss_rnn.append(l_rnn.data.cpu().numpy()[0])
            l_adv = F.nll_loss(adversarial_model(y_pred), Z_var)
            loss_adv.append(l_adv.data.cpu().numpy()[0])
            l = l_rnn - (args.lmbda * l_adv)
            # Take a step on the classifier: backprop first, then update
            l.backward(retain_graph=True)
            optimizer.step()

            # Take a step on the adversary; clear the gradients that the
            # combined loss already accumulated on it
            optimizer_adv.zero_grad()
            l_adv.backward()
            optimizer_adv.step()

            X = unwrap_X(X_var)
            y = unwrap(y_var)
            callback(i, iteration, model)
        t1 = time.time()
        logging.info("Epoch took {} seconds".format(t1 - t0))

        scheduler.step()
        scheduler_adv.step()
        settings['step_size'] = args.step_size * (args.decay)**(i + 1)
    #logging.info(loss_rnn)
    #logging.info('==================================================')
    #logging.info(loss_adv)
    logging.info('PID : %d' % os.getpid())
    pathset = os.path.join(args.data_dir, str(os.getpid()))
    os.mkdir(pathset)
    np.save(os.path.join(pathset, 'rnn_loss.csv'), np.array(loss_rnn))
    np.save(os.path.join(pathset, 'adv_loss.csv'), np.array(loss_adv))
    eh.finished()
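
The interleaved classifier/adversary update in code example #4 is easy to get wrong: gradients must be computed before stepping, and stale gradients cleared between the two updates. A minimal self-contained sketch of the alternation against modern PyTorch, with toy stand-in modules; every name here is illustrative, not the original code.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    classifier = nn.Sequential(
        nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1), nn.Sigmoid())
    adversary = nn.Sequential(
        nn.Linear(1, 8), nn.ReLU(), nn.Linear(8, 3), nn.LogSoftmax(dim=1))
    opt_clf = torch.optim.Adam(classifier.parameters(), lr=1e-3)
    opt_adv = torch.optim.Adam(adversary.parameters(), lr=1e-3)
    lmbda = 1.0

    x = torch.randn(32, 4)                    # toy features
    y = torch.randint(0, 2, (32, 1)).float()  # binary labels
    z = torch.randint(0, 3, (32,))            # nuisance class (e.g. pileup level)

    # 1) update the classifier on L_clf - lambda * L_adv
    opt_clf.zero_grad()
    y_pred = classifier(x)
    l_clf = F.binary_cross_entropy(y_pred, y)
    l_adv = F.nll_loss(adversary(y_pred), z)
    (l_clf - lmbda * l_adv).backward()
    opt_clf.step()

    # 2) update the adversary on L_adv alone, with fresh gradients and
    #    the classifier output detached
    opt_adv.zero_grad()
    l_adv = F.nll_loss(adversary(y_pred.detach()), z)
    l_adv.backward()
    opt_adv.step()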