Esempio n. 1
0
def evaluate_all(log_dir):
    """Evaluate every good saved model found under *log_dir*.

    Accepts either a log-directory object or a path string; a string is
    first resolved to an existing log directory (creation=False).
    """
    if isinstance(log_dir, str):
        log_dir = get_log_dir(log_dir, creation=False)

    for model_path in find_good_models(log_dir):
        model_info = parse_model_path(model_path)
        evaluate(model_path, model_info["step"], log_dir)
Esempio n. 2
0
    args = parser.parse_args()

    # Resolve the training file and the two validation files from the
    # requested training channel.
    if args.train_data == "dijet":
        train_data = "../data/FastSim/dijet/training_dijet_466554_prep.root"
        val_dijet_data = "../data/FastSim/dijet/test_dijet_after_dijet_186880_prep.root"
        val_zjet_data = "../data/FastSim/dijet/test_zjet_after_dijet_176773_prep.root"
    elif args.train_data == "zjet":
        train_data = "../data/FastSim/zjet/training_zjet_440168_prep.root"
        val_dijet_data = "../data/FastSim/zjet/test_dijet_after_zjet_186880_prep.root"
        val_zjet_data = "../data/FastSim/zjet/test_zjet_after_zjet_176773_prep.root"
    else:
        # BUG FIX: the error previously carried an empty message.
        raise ValueError("unknown train_data: {}".format(args.train_data))

    log_dir = get_log_dir(
        path=args.log_dir.format(name=args.channel),
        creation=True)

    # BUG FIX: "WRITE" was passed positionally after the keyword argument
    # `dpath`, which is a SyntaxError; pass it as the `mode` keyword,
    # matching Logger usage elsewhere in this project.
    logger = Logger(dpath=log_dir.path, mode="WRITE")
    logger.get_args(args)
    logger["train_data"] = train_data
    # BUG FIX: the z+jet entry previously recorded the dijet path.
    logger["val_zjet_data"] = val_zjet_data
    logger["val_dijet_data"] = val_dijet_data

    # data loader
    # NOTE(review): "DataLodaer" looks like a typo for "DataLoader", but the
    # same spelling appears in sibling scripts, so it is presumably the real
    # class name -- confirm before renaming.
    train_loader = DataLodaer(
        path=train_data,
        batch_size=args.train_batch_size,
        cyclic=False)

    # Number of optimisation steps per epoch; ceil keeps the trailing
    # partial batch.
    steps_per_epoch = np.ceil(len(train_loader) / train_loader.batch_size).astype(int)
Esempio n. 3
0
def train():
    """Train an AK4 jet-tagging model with interleaved validation.

    Parses command-line options, builds train/validation data loaders,
    compiles the model (optionally wrapped for multi-GPU), then runs a
    step-based training loop that periodically validates, checkpoints and
    records metrics.

    Returns:
        The log directory object holding all outputs of this run.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--train_sample", default="dijet", type=str)
    parser.add_argument("--datasets_dir",
                        default="/data/slowmoyang/QGJets/root_100_200/3-JetImage/",
                        type=str)
    parser.add_argument("--model", default="ak4_without_residual", type=str)


    parser.add_argument("--log_dir", default="./logs/{name}", type=str)
    parser.add_argument("--num_gpus", default=len(get_available_gpus()), type=int)
    parser.add_argument("--multi-gpu", default=False, action='store_true', dest='multi_gpu')

    # Hyperparameters
    parser.add_argument("--num_epochs", default=30, type=int)
    parser.add_argument("--batch_size", default=512, type=int)
    parser.add_argument("--val_batch_size", default=1024, type=int)
    parser.add_argument("--lr", default=0.001, type=float)

    # Frequencies (in optimisation steps)
    parser.add_argument("--val_freq", type=int, default=32)
    parser.add_argument("--save_freq", type=int, default=32)

    # Project parameters
    parser.add_argument("--kernel_size", type=int, default=3)

    args = parser.parse_args()

    #########################################################
    # Log directory
    #######################################################
    # Fall back to a timestamped "Untitled" run name when the user kept the
    # default "{name}" placeholder in --log_dir.
    if '{name}' in args.log_dir:
        args.log_dir = args.log_dir.format(
            name="Untitled_{}".format(
                datetime.today().strftime("%Y-%m-%d_%H-%M-%S")))
    log_dir = get_log_dir(path=args.log_dir, creation=True)

    # Config: persists the run settings into the log directory.
    config = Config(dpath=log_dir.path, mode="WRITE")
    config.update(args)

    dataset_paths = get_dataset_paths(config.datasets_dir, config.train_sample)
    config.update(dataset_paths)

    ########################################
    # Load training and validation datasets
    ########################################


    train_loader = AK4Loader(
        path=config.training_set,
        batch_size=config.batch_size,
        cyclic=False)
    # Record the padded sequence length so the validation loaders below pad
    # exactly like the training loader.
    config["maxlen"] = train_loader.maxlen

    # NOTE(review): int() floors here, so a trailing partial batch is not
    # counted; sibling scripts use np.ceil -- confirm which is intended.
    # Only used for progress reporting via total_step.
    steps_per_epoch = int(len(train_loader) / train_loader.batch_size)
    total_step = config.num_epochs * steps_per_epoch

    # Cyclic loaders: next() always yields a batch, wrapping around the set.
    val_dijet_loader = AK4Loader(
        path=config.dijet_validation_set,
        maxlen=config.maxlen,
        batch_size=config.val_batch_size,
        cyclic=True)

    val_zjet_loader = AK4Loader(
        path=config.zjet_validation_set,
        maxlen=config.maxlen,
        batch_size=config.val_batch_size, 
        cyclic=True)


    #################################
    # Build & Compile a model.
    #################################
    config["model_type"] = "sequential"


    input0_shape, input1_shape = train_loader.get_shape()
    _model = build_a_model(
        model_type=config.model_type,
        model_name=config.model,
        input0_shape=input0_shape,
        input1_shape=input1_shape)

    # Keep a handle on the raw model (_model): checkpoints below are saved
    # from it, not from the multi-GPU wrapper.
    if config.multi_gpu:
        model = multi_gpu_model(_model, gpus=config.num_gpus)
    else:
        model = _model

    # TODO config should have these information.
    loss = 'categorical_crossentropy'
    optimizer = optimizers.Adam(lr=config.lr)
    metric_list = ['accuracy', roc_auc_score]

    model.compile(
        loss=loss,
        optimizer=optimizer,
        metrics=metric_list)

    # Reduces the learning rate when the monitored metric (dijet validation
    # loss, fed in the loop below) plateaus.
    lr_scheduler = train_utils.ReduceLROnPlateau(model)

    #######################################
    # Metric recording
    ###########################################

    meter = Meter(
        name_list=["step", "lr",
                   "train_loss", "dijet_loss", "zjet_loss",
                   "train_acc", "dijet_acc", "zjet_acc",
                   "train_auc", "dijet_auc", "zjet_auc"],
        dpath=log_dir.validation.path)

    #######################################
    # Training with validation
    #######################################
    step = 0
    for epoch in range(config.num_epochs):
        print("Epoch [{epoch}/{num_epochs}]".format(epoch=(epoch+1), num_epochs=config.num_epochs))

        for train_batch in train_loader:
            # Validation -- also runs on save steps so the checkpoint
            # filename below can embed up-to-date dijet metrics.
            if step % config.val_freq == 0 or step % config.save_freq == 0:
                val_dj_batch = val_dijet_loader.next()
                val_zj_batch = val_zjet_loader.next()

                train_loss, train_acc, train_auc = model.test_on_batch(
                    x=[train_batch["x_daus"], train_batch["x_glob"]],
                    y=train_batch["y"])
                dijet_loss, dijet_acc, dijet_auc = model.test_on_batch(
                    x=[val_dj_batch["x_daus"], val_dj_batch["x_glob"]],
                    y=val_dj_batch["y"])
                zjet_loss, zjet_acc, zjet_auc = model.test_on_batch(
                    x=[val_zj_batch["x_daus"], val_zj_batch["x_glob"]],
                    y=val_zj_batch["y"])

                # LR schedule is driven by the dijet validation loss.
                lr_scheduler.monitor(metrics=dijet_loss)

                print("Step [{step}/{total_step}]".format(step=step, total_step=total_step))
                print("  Training:\n\tLoss {:.3f} | Acc. {:.3f} | AUC {:.3f}".format(train_loss, train_acc, train_auc))
                print("  Validation on Dijet\n\tLoss {:.3f} | Acc. {:.3f} | AUC {:.3f}".format(dijet_loss, dijet_acc, dijet_auc))
                print("  Validation on Z+jet\n\tLoss {:.3f} | Acc. {:.3f} | AUC {:.3f}".format(zjet_loss,zjet_acc, zjet_auc))

                meter.append({
                    "step": step, "lr": K.get_value(model.optimizer.lr),
                    "train_loss": train_loss, "dijet_loss": dijet_loss, "zjet_loss": zjet_loss,
                    "train_acc": train_acc, "dijet_acc": dijet_acc, "zjet_acc": zjet_acc,
                    "train_auc": train_auc, "dijet_auc": dijet_auc, "zjet_auc": zjet_auc})

            # Save model -- the raw model, not the multi-GPU wrapper.
            if (step != 0) and (step % config.save_freq == 0):
                filepath = os.path.join(
                    log_dir.saved_models.path,
                    "model_step-{step:06d}_loss-{loss:.3f}_acc-{acc:.3f}_auc-{auc:.3f}.h5".format(
                        step=step, loss=dijet_loss, acc=dijet_acc, auc=dijet_auc))
                _model.save(filepath)

            # Train on batch
            model.train_on_batch(
                x=[train_batch["x_daus"], train_batch["x_glob"]],
                 y=train_batch["y"])
            step += 1

        ###############################
        # On Epoch End
        ###########################
        lr_scheduler.step(epoch=epoch)

    #############################
    # Final checkpoint and plots
    #############################
    filepath = os.path.join(log_dir.saved_models.path, "model_final.h5")
    _model.save(filepath)

    print("Training is over! :D")

    meter.add_plot(
        x="step",
        ys=[("train_loss", "Train/Dijet"),
            ("dijet_loss", "Validation/Dijet"),
            ("zjet_loss", "Validation/Z+jet")],
        title="Loss(CrossEntropy)", xlabel="Step", ylabel="Loss")

    meter.add_plot(
        x="step",
        ys=[("train_acc", "Train/Dijet"),
            ("dijet_acc", "Validation/Dijet"),
            ("zjet_acc", "Validation/Z+jet")],
        title="Accuracy", xlabel="Step", ylabel="Acc.")

    meter.add_plot(
        x="step",
        ys=[("train_auc", "Train/Dijet"),
            ("dijet_auc", "Validation/Dijet"),
            ("zjet_auc", "Validation/Z+jet")],
        title="AUC", xlabel="Step", ylabel="AUC")



    meter.finish()
    config.finish()

    return log_dir
Esempio n. 4
0
    # Optimiser learning rate and sequence padding length.
    parser.add_argument("--lr", type=float, default=0.001)
    parser.add_argument("--maxlen", type=int, default=50)

    # Freq (in optimisation steps)
    parser.add_argument("--val_freq", type=int, default=100)
    parser.add_argument("--save_freq", type=int, default=500)

    args = parser.parse_args()

    # Dataset file locations, all relative to --directory.
    train_data = args.directory + "/dijet_train.root"
    val_dijet_data = args.directory + "/dijet_test.root"
    val_zjet_data = args.directory + "/z_jet_test.root"

    # Substitute the model name into the log-dir template, if present.
    if '{name}' in args.log_dir:
        args.log_dir = args.log_dir.format(name=args.model)
    log_dir = get_log_dir(path=args.log_dir, creation=True)

    # Record the run configuration and dataset paths in the log directory.
    logger = Logger(log_dir.path, "WRITE")
    logger.get_args(args)
    logger["train_data"] = train_data
    logger["val_dijet_data"] = val_dijet_data
    logger["val_zjet_data"] = val_zjet_data

    # Compilation settings for the model built further below.
    loss = 'binary_crossentropy'
    optimizer = optimizers.Adam(lr=args.lr)
    metric_list = ['accuracy']

    # data loader
    train_loader = DataLoader(path=train_data,
                              maxlen=args.maxlen,
                              batch_size=args.train_batch_size,
Esempio n. 5
0
def train():
    """Train a sequence-based quark/gluon tagger with interleaved validation.

    Parses command-line options, builds a one-pass training iterator plus
    cyclic dijet/z+jet validation iterators, compiles the model (optionally
    wrapped for multi-GPU), then runs a step-based training loop that
    periodically validates, checkpoints and records metrics.

    Returns:
        The log directory object holding all outputs of this run.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--datasets_dir",
                        default="/store/slowmoyang/QGJets/data/root_100_200/2-Refined/",
                        type=str)

    parser.add_argument("--log_dir", default="./logs/{name}", type=str)
    parser.add_argument("--num_gpus", default=len(get_available_gpus()), type=int)
    parser.add_argument("--multi-gpu", default=False, action='store_true', dest='multi_gpu')

    # Hyperparameters
    parser.add_argument("--num_epochs", default=50, type=int)
    parser.add_argument("--batch_size", default=128, type=int)
    parser.add_argument("--valid_batch_size", default=1024, type=int)
    parser.add_argument("--lr", default=0.001, type=float)

    # Frequencies (in optimisation steps)
    parser.add_argument("--valid_freq", type=int, default=32)
    parser.add_argument("--save_freq", type=int, default=32)
    parser.add_argument("-v", "--verbose", action="store_true")

    # Project parameters

    args = parser.parse_args()

    #########################################################
    # Log directory
    #######################################################
    # Fall back to a timestamped "Untitled" run name when the user kept the
    # default "{name}" placeholder in --log_dir.
    if '{name}' in args.log_dir:
        args.log_dir = args.log_dir.format(
            name="Untitled_{}".format(
                datetime.today().strftime("%Y-%m-%d_%H-%M-%S")))
    log_dir = get_log_dir(path=args.log_dir, creation=True)

    # Config: persists the run settings into the log directory.
    config = Config(dpath=log_dir.path, mode="WRITE")
    config.update(args)

    dataset_paths = get_dataset_paths(config.datasets_dir)
    # BUG FIX: dict.iteritems() exists only in Python 2; items() behaves the
    # same here and also works on Python 3.
    for key, value in dataset_paths.items():
        print("{}: {}".format(key, value))
    config.update(dataset_paths)

    ########################################
    # Load training and validation datasets
    ########################################
    # Pad both input sequences to a fixed length of 40 entries.
    config["seq_maxlen"] = {"x_kinematics": 40, "x_pid": 40}

    train_iter = get_data_iter(
        path=config.dijet_training_set,
        seq_maxlen=config.seq_maxlen,
        batch_size=config.batch_size)

    # Cyclic iterators: next() always yields a batch, wrapping around.
    valid_dijet_iter = get_data_iter(
        path=config.dijet_validation_set,
        seq_maxlen=config.seq_maxlen,
        batch_size=config.valid_batch_size,
        cyclic=True)

    valid_zjet_iter = get_data_iter(
        path=config.zjet_validation_set,
        seq_maxlen=config.seq_maxlen,
        batch_size=config.valid_batch_size,
        cyclic=True)

    # Here len(train_iter) is taken directly as the number of batches.
    steps_per_epoch = len(train_iter)
    total_step = config.num_epochs * steps_per_epoch
    if config.verbose:
        print("# of steps per one epoch: {}".format(steps_per_epoch))
        print("Total step: {}".format(total_step))


    #################################
    # Build & Compile a model.
    #################################
    x_kinematics_shape = train_iter.get_shape("x_kinematics", batch_shape=False)
    x_pid_shape = train_iter.get_shape("x_pid", batch_shape=False)

    _model = build_a_model(x_kinematics_shape, x_pid_shape)

    # Keep a handle on the raw model (_model): checkpoints below are saved
    # from it, not from the multi-GPU wrapper.
    if config.multi_gpu:
        model = multi_gpu_model(_model, gpus=config.num_gpus)
    else:
        model = _model

    # TODO config should have these information.
    loss = 'binary_crossentropy'
    optimizer = optimizers.Adam(lr=config.lr)
    metric_list = ['accuracy']

    model.compile(
        loss=loss,
        optimizer=optimizer,
        metrics=metric_list)

    if config.verbose:
        model.summary()

    #######################################
    # Metric recording
    #######################################

    meter = Meter(
        name_list=["step", "lr",
                   "train_loss", "dijet_loss", "zjet_loss",
                   "train_acc", "dijet_acc", "zjet_acc"],
        dpath=log_dir.validation.path)

    #######################################
    # Training with validation
    #######################################
    start_message = "TRAINING START"
    print("$" * (len(start_message) + 4))
    print("$ {} $".format(start_message))
    print("$" * (len(start_message) + 4))

    step = 0
    for epoch in range(config.num_epochs):
        print("Epoch [{epoch}/{num_epochs}]".format(epoch=(epoch+1), num_epochs=config.num_epochs))

        for train_batch in train_iter:
            #########################################
            # Validation -- also runs on save steps so the checkpoint
            # filename below can embed up-to-date dijet metrics.
            ################################################
            if step % config.valid_freq == 0 or step % config.save_freq == 0:
                valid_dj_batch = valid_dijet_iter.next()
                valid_zj_batch = valid_zjet_iter.next()

                train_loss, train_acc = model.test_on_batch(
                    x=[train_batch.x_kinematics, train_batch.x_pid],
                    y=train_batch.y)

                dijet_loss, dijet_acc = model.test_on_batch(
                    x=[valid_dj_batch.x_kinematics, valid_dj_batch.x_pid],
                    y=valid_dj_batch.y)

                zjet_loss, zjet_acc = model.test_on_batch(
                    x=[valid_zj_batch.x_kinematics, valid_zj_batch.x_pid],
                    y=valid_zj_batch.y)

                print("Step [{step}/{total_step}]".format(step=step, total_step=total_step))
                print("  Training:\n\tLoss {:.3f} | Acc. {:.3f}".format(train_loss, train_acc))
                print("  Validation on Dijet\n\tLoss {:.3f} | Acc. {:.3f}".format(dijet_loss, dijet_acc))
                print("  Validation on Z+jet\n\tLoss {:.3f} | Acc. {:.3f}".format(zjet_loss,zjet_acc))

                meter.append({
                    "step": step, "lr": K.get_value(model.optimizer.lr),
                    "train_loss": train_loss, "dijet_loss": dijet_loss, "zjet_loss": zjet_loss,
                    "train_acc": train_acc, "dijet_acc": dijet_acc, "zjet_acc": zjet_acc})

            # Save model -- the raw model, not the multi-GPU wrapper.
            if (step != 0) and (step % config.save_freq == 0):
                filepath = os.path.join(
                    log_dir.saved_models.path,
                    "model_step-{step:06d}_loss-{loss:.3f}_acc-{acc:.3f}.h5".format(
                        step=step, loss=dijet_loss, acc=dijet_acc))
                _model.save(filepath)

            # Train on batch.
            # NOTE(review): step is incremented *before* training here,
            # unlike sibling scripts that increment after -- confirm intent.
            step += 1
            model.train_on_batch(
                x=[train_batch.x_kinematics, train_batch.x_pid],
                y=train_batch.y)

    #############################
    # Final checkpoint and plots
    #############################
    filepath = os.path.join(log_dir.saved_models.path, "model_final.h5")
    _model.save(filepath)

    print("Training is over! :D")

    meter.add_plot(
        x="step",
        ys=[("train_loss", "Train/Dijet"),
            ("dijet_loss", "Validation/Dijet"),
            ("zjet_loss", "Validation/Z+jet")],
        title="Loss(CrossEntropy)", xlabel="Step", ylabel="Loss")

    meter.add_plot(
        x="step",
        ys=[("train_acc", "Train/Dijet"),
            ("dijet_acc", "Validation/Dijet"),
            ("zjet_acc", "Validation/Z+jet")],
        title="Accuracy", xlabel="Step", ylabel="Acc.")


    meter.finish()
    config.finish()

    return log_dir
Esempio n. 6
0
    # Stream the z+jet test set through the model batch by batch,
    # accumulating the ROC curve and the output histogram.
    for batch in test_zjet_loader:
        y_pred = model.predict_on_batch([batch["x_daus"], batch["x_glob"]])
        roc_zjet.append(y_true=batch["y"], y_pred=y_pred)
        out_hist.fill("test_zjet", y_true=batch["y"], y_pred=y_pred)

    # Finalise (e.g. flush/persist) the accumulated ROC and histogram.
    roc_zjet.finish()

    out_hist.finish()


def evaluate_all(log_dir):
    """Run `evaluate` on each good checkpoint stored in *log_dir*.

    A string argument is resolved to an existing log directory first.
    """
    log_dir = get_log_dir(log_dir, creation=False) if isinstance(log_dir, str) else log_dir

    checkpoints = find_good_models(log_dir)
    for checkpoint in checkpoints:
        checkpoint_step = parse_model_path(checkpoint)["step"]
        evaluate(checkpoint, checkpoint_step, log_dir)


if __name__ == "__main__":
    # Command-line entry point: evaluate every good checkpoint in a run's
    # log directory.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('--log_dir',
                            type=str,
                            required=True,
                            help='the directory path of dataset')
    parsed_args = arg_parser.parse_args()

    evaluate_all(get_log_dir(path=parsed_args.log_dir, creation=False))
Esempio n. 7
0
def main():
    """Train an image-based jet tagger with periodic validation.

    Parses command-line options, resolves the training dataset path,
    trains with `train_on_batch`, validates every `--val_freq` steps and
    checkpoints every `--save_freq` steps, recording metrics via a Meter.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--train_data", type=str, default="dijet")
    # NOTE(review): name="name" bakes the literal string "name" into the
    # default path; the parallel script uses name="{name}" so the
    # placeholder survives until it is filled with the channel below --
    # confirm which was intended.
    parser.add_argument(
        '--log_dir',
        type=str,
        default='./logs/{name}-{date}'.format(
            name="name", date=datetime.today().strftime("%Y-%m-%d_%H-%M-%S")))

    parser.add_argument("--num_epochs", type=int, default=10)
    parser.add_argument("--num_gpus",
                        type=int,
                        default=len(get_available_gpus()))
    parser.add_argument("--train_batch_size", type=int, default=500)
    parser.add_argument("--val_batch_size", type=int, default=500)

    # Hyperparameter
    parser.add_argument("--lr", type=float, default=0.001)

    # Freq (in optimisation steps)
    parser.add_argument("--val_freq", type=int, default=100)
    parser.add_argument("--save_freq", type=int, default=500)

    ##
    parser.add_argument("--channel", type=str, default="cpt")

    args = parser.parse_args()

    data_dir = "../../SJ-JetImage/images33x33/"
    # TODO(review): both branches join an empty filename, so train_data is
    # just the directory; the real file names still need filling in.  The
    # validation paths (val_dijet_data / val_zjet_data) used below are never
    # assigned anywhere in this function.
    if args.train_data == "dijet":
        train_data = os.path.join(data_dir, "")
    elif args.train_data == "zjet":
        train_data = os.path.join(data_dir, "")
    else:
        # BUG FIX: the error previously carried an empty message.
        raise ValueError("unknown train_data: {}".format(args.train_data))

    log_dir = get_log_dir(path=args.log_dir.format(name=args.channel),
                          creation=True)

    # Record the run configuration and dataset paths in the log directory.
    logger = Logger(dpath=log_dir.path, mode="WRITE")
    logger.get_args(args)
    logger["train_data"] = train_data
    # BUG FIX: the z+jet entry previously recorded the dijet path.
    logger["val_zjet_data"] = val_zjet_data
    logger["val_dijet_data"] = val_dijet_data

    # data loader
    # data loader for training data
    # NOTE(review): "DataLodaer" looks like a typo for "DataLoader", but the
    # same spelling appears throughout this script, so it is presumably the
    # real class name -- confirm before renaming.
    train_loader = DataLodaer(path=train_data,
                              batch_size=args.train_batch_size,
                              cyclic=False)

    # ceil keeps the trailing partial batch in the step count.
    steps_per_epoch = np.ceil(len(train_loader) /
                              train_loader.batch_size).astype(int)
    total_step = args.num_epochs * steps_per_epoch

    # data loaders for dijet/z+jet validation data (cyclic: next() wraps)
    val_dijet_loader = DataLodaer(path=val_dijet_data,
                                  batch_size=args.val_batch_size,
                                  cyclic=True)

    val_zjet_loader = DataLodaer(path=val_zjet_data,
                                 batch_size=args.val_batch_size,
                                 cyclic=True)

    # build a model
    # TODO(review): input_shape is not defined anywhere in this function.
    _model = build_a_model(input_shape)
    model = multi_gpu_model(_model, gpus=args.num_gpus)

    # Compilation settings
    loss = 'binary_crossentropy'
    optimizer = optimizers.Adam(lr=args.lr)
    metrics = ['accuracy']

    model.compile(loss=loss, optimizer=optimizer, metrics=metrics)

    # Meter -- training columns are named after the training channel.
    tr_acc_ = "train_{}_acc".format(args.train_data)
    tr_loss_ = "train_{}_loss".format(args.train_data)

    meter = Meter(data_name_list=[
        "step", tr_acc_, "val_acc_dijet", "val_acc_zjet", tr_loss_,
        "val_loss_dijet", "val_loss_zjet"
    ],
                  dpath=log_dir.validation.path)

    # Training with validation
    step = 0
    for epoch in range(args.num_epochs):

        print("Epoch [{epoch}/{num_epochs}]".format(
            epoch=(epoch + 1), num_epochs=args.num_epochs))

        for x_train, y_train in train_loader:

            # Validate model
            if step % args.val_freq == 0:
                x_dijet, y_dijet = val_dijet_loader.next()
                x_zjet, y_zjet = val_zjet_loader.next()

                loss_train, train_acc = model.test_on_batch(x=x_train,
                                                            y=y_train)
                loss_dijet, acc_dijet = model.test_on_batch(x=x_dijet,
                                                            y=y_dijet)
                loss_zjet, acc_zjet = model.test_on_batch(x=x_zjet, y=y_zjet)

                print("Step [{step}/{total_step}]".format(
                    step=step, total_step=total_step))

                print("  Training:")
                print(
                    "    Loss {loss_train:.3f} | Acc. {train_acc:.3f}".format(
                        loss_train=loss_train, train_acc=train_acc))

                print("  Validation on Dijet")
                print("    Loss {val_loss:.3f} | Acc. {val_acc:.3f}".format(
                    val_loss=loss_dijet, val_acc=acc_dijet))

                print("  Validation on Z+jet")
                print("    Loss {val_loss:.3f} | Acc. {val_acc:.3f}".format(
                    val_loss=loss_zjet, val_acc=acc_zjet))

                meter.append(
                    data_dict={
                        "step": step,
                        tr_loss_: loss_train,
                        "val_loss_dijet": loss_dijet,
                        "val_loss_zjet": loss_zjet,
                        tr_acc_: train_acc,
                        "val_acc_dijet": acc_dijet,
                        "val_acc_zjet": acc_zjet
                    })

            # Save model -- the raw model, not the multi-GPU wrapper.
            if (step != 0) and (step % args.save_freq == 0):
                filepath = os.path.join(
                    log_dir.saved_models.path,
                    "{name}_{step}.h5".format(name="model", step=step))
                _model.save(filepath)

            # Train on batch
            model.train_on_batch(x=x_train, y=y_train)
            step += 1

    print("Training is over! :D")

    # BUG FIX: the plot labels previously contained an unformatted "{}"
    # placeholder; interpolate the training channel as clearly intended.
    meter.prepare(x="step",
                  ys=[(tr_acc_, "Training Acc on {}".format(args.train_data)),
                      ("val_acc_dijet", "Validation on Dijet"),
                      ("val_acc_zjet", "Validation on Z+jet")],
                  title="Accuracy",
                  xaxis="Step",
                  yaxis="Accuracy")

    # BUG FIX: this plot referenced the undefined name `tr_loss_name`
    # (NameError) and plotted the accuracy columns under loss labels; it now
    # uses `tr_loss_` and the loss columns, with the "Loss os" typo and the
    # y-axis label corrected.
    meter.prepare(x="step",
                  ys=[(tr_loss_, "Training Loss on {}".format(args.train_data)),
                      ("val_loss_dijet", "Validation Loss on Dijet"),
                      ("val_loss_zjet", "Validation Loss on Z+jet")],
                  title="Loss",
                  xaxis="Step",
                  yaxis="Loss")

    meter.finish()
    logger.finish()
Esempio n. 8
0
def main():
    """Train an image-based (33x33 jet-image) tagger with validation.

    Parses command-line options, resolves dataset paths from the chosen
    training channel, trains with `train_on_batch`, validates every
    `--val_freq` steps and checkpoints every `--save_freq` steps, recording
    metrics via a Meter.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--train_data", type=str, default="dijet")
    parser.add_argument("--directory", type=str, default="../../SJ-JetImage/image33x33/")
    parser.add_argument("--model", type=str, default="resnet")
    parser.add_argument('--log_dir', type=str,
	                default='./logs/{name}-{date}'.format(
                            name="{name}",
                            date=datetime.today().strftime("%Y-%m-%d_%H-%M-%S")))

    parser.add_argument("--num_epochs", type=int, default=10)
    parser.add_argument("--num_gpus", type=int, default=len(get_available_gpus()))
    parser.add_argument("--train_batch_size", type=int, default=500)
    parser.add_argument("--val_batch_size", type=int, default=500)
    parser.add_argument("--multi-gpu", default=False, action='store_true', dest='multi_gpu')

    # Hyperparameter
    parser.add_argument("--lr", type=float, default=0.001)

    # Freq (in optimisation steps)
    parser.add_argument("--val_freq", type=int, default=100)
    parser.add_argument("--save_freq", type=int, default=500)

    args = parser.parse_args()

    # Resolve the training file and the two validation files (both
    # validation sets are evaluated regardless of the training channel).
    if args.train_data == "dijet":
        train_data = args.directory+"/dijet_train.root"
        val_dijet_data = args.directory+"/dijet_test_after_dijet.root"
        val_zjet_data = args.directory+"/z_jet_test_after_dijet.root"
    elif args.train_data == "zjet":
        train_data = args.directory+"/z_jet_train.root"
        val_dijet_data = args.directory+"/dijet_test_after_zjet.root"
        val_zjet_data = args.directory+"/z_jet_test_after_zjet.root"
    else:
        raise ValueError("")

    # Substitute the model name into the log-dir template, if present.
    if '{name}' in args.log_dir: args.log_dir = args.log_dir.format(name=args.model)
    log_dir = get_log_dir(path=args.log_dir, creation=True)

    # Record the run configuration and dataset paths in the log directory.
    logger = Logger(dpath=log_dir.path, mode="WRITE")
    logger.get_args(args)
    logger["train_data"] = train_data
    logger["val_zjet_data"] = val_zjet_data
    logger["val_dijet_data"] = val_dijet_data

    # data loader
    train_loader = DataLoader(
        path=train_data,
        batch_size=args.train_batch_size,
        cyclic=False)

    # ceil keeps the trailing partial batch in the step count.
    steps_per_epoch = np.ceil( len(train_loader) / train_loader.batch_size ).astype(int)
    total_step = args.num_epochs * steps_per_epoch

    # Cyclic loaders: next() always yields a batch, wrapping around the set.
    val_dijet_loader = DataLoader(
        path=val_dijet_data,
        batch_size=args.val_batch_size,
        cyclic=True)

    val_zjet_loader = DataLoader(
        path=val_zjet_data,
        batch_size=args.val_batch_size,
        cyclic=True)


    loss = 'binary_crossentropy'
    optimizer = optimizers.Adam(lr=args.lr)
    metric_list = ['accuracy']

    # build a model and compile it
    # NOTE(review): reaches into the loader's private `_image_shape`
    # attribute -- a public accessor would be preferable if one exists.
    _model = build_a_model(model_name=args.model, input_shape=train_loader._image_shape)

    # Keep a handle on the raw model (_model): checkpoints below are saved
    # from it, not from the multi-GPU wrapper.
    if args.multi_gpu:
        model = multi_gpu_model(_model, gpus=args.num_gpus)
    else:
        model = _model


    model.compile(
        loss=loss,
        optimizer=optimizer,
        metrics=metric_list
    )


    # Meter -- training columns are named after the training channel.
    tr_acc_ = "train_{}_acc".format(args.train_data)
    tr_loss_ = "train_{}_loss".format(args.train_data)

    meter = Meter(
        data_name_list=[
            "step",
            tr_acc_, "val_dijet_acc", "val_zjet_acc",
            tr_loss_, "val_dijet_loss", "val_zjet_loss"],
        dpath=log_dir.validation.path)
    
    meter.prepare(
        data_pair_list=[("step", tr_acc_),
                        ("step", "val_dijet_acc"),
                        ("step", "val_zjet_acc")],
        title="Accuracy")

    meter.prepare(
        data_pair_list=[("step", tr_loss_),
                        ("step", "val_dijet_loss"),
                        ("step", "val_zjet_loss")],
        title="Loss(Cross-entropy)")


    # Training with validation
    step = 0
    for epoch in range(args.num_epochs):

        print("Epoch [{epoch}/{num_epochs}]".format(
            epoch=(epoch+1), num_epochs=args.num_epochs))


        for x_train, y_train in train_loader:

            # Validate model
            if step % args.val_freq == 0:
                x_dijet, y_dijet = val_dijet_loader.next()
                x_zjet, y_zjet = val_zjet_loader.next()

                train_loss, train_acc = model.test_on_batch(x=x_train, y=y_train)
                dijet_loss, dijet_acc = model.test_on_batch(x=x_dijet, y=y_dijet)
                zjet_loss, zjet_acc = model.test_on_batch(x=x_zjet, y=y_zjet)

                print("Step [{step}/{total_step}]".format(
                    step=step, total_step=total_step))

                print("  Training:")
                print("    Loss {train_loss:.3f} | Acc. {train_acc:.3f}".format(
                    train_loss=train_loss, train_acc=train_acc))

                print("  Validation on Dijet")
                print("    Loss {val_loss:.3f} | Acc. {val_acc:.3f}".format(
                    val_loss=dijet_loss, val_acc=dijet_acc))

                print("  Validation on Z+jet")
                print("    Loss {val_loss:.3f} | Acc. {val_acc:.3f}".format(
                    val_loss=zjet_loss, val_acc=zjet_acc))

                meter.append(data_dict={
                    "step": step,
                    tr_loss_: train_loss,
                    "val_dijet_loss": dijet_loss,
                    "val_zjet_loss": zjet_loss,
                    tr_acc_: train_acc,
                    "val_dijet_acc": dijet_acc,
                    "val_zjet_acc": zjet_acc})

            # Checkpoint the raw model; the filename embeds only the step.
            if (step!=0) and (step % args.save_freq == 0):
                filepath = os.path.join(
                    log_dir.saved_models.path,
                    "{name}_{step}.h5".format(name="model", step=step))
                _model.save(filepath)


            # Train on batch
            model.train_on_batch(x=x_train, y=y_train)
            step += 1

    # Final checkpoint, then flush the meter and logger.
    filepath = os.path.join(log_dir.saved_models.path,
                            "model_final.h5")
    _model.save(filepath)
    print("Training is over! :D")
    meter.finish()
    logger.finish()