Esempio n. 1
0
def _test():
    """Smoke-test ``JetSeqSet`` wrapped in a ``DataIterator``.

    Fetches one batch from the ``min_pt=100`` training split and prints the
    shape of each of its entries, prints two ``get_shape`` results, then
    switches the iterator to fit-generator mode and prints the shapes of the
    ``x`` inputs for iterations 0, 1 and 2 (the loop stops at index 3).
    """
    from keras4hep.projects.qgjets.utils import get_dataset_paths

    training_path = get_dataset_paths(min_pt=100)["training"]

    # Both sequence inputs are given the same maximum length.
    maxlen_by_key = {"x_kin": 50, "x_pid": 50}
    dataset = JetSeqSet(training_path,
                        extra=["pt", "eta"],
                        seq_maxlen=maxlen_by_key)
    data_iter = DataIterator(dataset, batch_size=128)

    first_batch = data_iter.next()
    for name, array in first_batch.iteritems():
        print(name, array.shape)

    print(data_iter.get_shape("x_pid", False))
    print(data_iter.get_shape("x_kin", True))

    # From here on the iterator yields (x, y) pairs instead of batch objects.
    data_iter.fit_generator_input = {"x": ["x_kin", "x_pid"], "y": ["y"]}
    data_iter.fit_generator_mode = True
    for step, (inputs, _) in enumerate(data_iter):
        if step == 3:
            break

        print(len(inputs))
        for each_input in inputs:
            print("x: {}".format(each_input.shape))
Esempio n. 2
0
def _test():
    """Smoke-test ``get_data_iter`` with a ``seq_maxlen`` entry for ``"x"``.

    Prints the shape of every entry of the first batch and one ``get_shape``
    result, then switches to fit-generator mode and prints the ``x`` inputs
    of iteration 0 only (the loop stops at index 1).
    """
    from keras4hep.projects.qgjets.utils import get_dataset_paths

    training_path = get_dataset_paths(min_pt=100)["training"]

    maxlen_spec = {"x": (40, "float32")}
    data_iter = get_data_iter(path=training_path,
                              seq_maxlen=maxlen_spec,
                              batch_size=2)

    first_batch = data_iter.next()
    for name, array in first_batch.iteritems():
        print(name, array.shape)

    print(data_iter.get_shape("x", True))

    # From here on the iterator yields (x, y) pairs instead of batch objects.
    data_iter.fit_generator_input = {"x": ["x"], "y": ["y"]}
    data_iter.fit_generator_mode = True
    for step, (inputs, _) in enumerate(data_iter):
        if step == 1:
            break

        print(len(inputs))
        for each_input in inputs:
            # Prints the values themselves, not just their shapes.
            print("x: {}".format(each_input))
Esempio n. 3
0
def _test():
    """Build the model from the iterator's ``"x"`` shape and predict one batch.

    Prints the shape of the array returned by ``predict_on_batch``.
    """
    from dataset import get_data_iter
    from keras4hep.projects.qgjets.utils import get_dataset_paths

    training_path = get_dataset_paths(min_pt=100)["training"]
    data_iter = get_data_iter(training_path)

    first_batch = data_iter.next()
    input_shape = data_iter.get_shape("x", batch_shape=False)

    model = build_model(input_shape)

    prediction = model.predict_on_batch(first_batch.x)
    print("logits: {}".format(prediction.shape))
Esempio n. 4
0
def _test():
    """Build the ``"RNN"`` classifier for the ``"x"`` input and print its summary."""
    from dataset import get_data_iter
    from keras4hep.projects.qgjets.utils import get_dataset_paths

    training_path = get_dataset_paths(min_pt=100)["training"]

    maxlen_spec = {"x": (40, "float32")}
    data_iter = get_data_iter(training_path, seq_maxlen=maxlen_spec)

    # The batch itself is not used below; the call still advances the iterator.
    first_batch = data_iter.next()
    input_shape = data_iter.get_shape("x", batch_shape=False)

    classifier = build_classifier(input_shape, name="RNN")
    classifier.summary()
Esempio n. 5
0
def _test():
    """Build the two-input ``"RNN"`` model and run one batch through it.

    The model is built from the ``"x_kin"`` and ``"x_pid"`` shapes and is fed
    ``[x_kin, x_pid, x_len]`` from the first batch; the shape of the
    prediction is printed.
    """
    from dataset import get_data_iter
    from keras4hep.projects.qgjets.utils import get_dataset_paths

    training_path = get_dataset_paths(min_pt=100)["training"]

    maxlen_by_key = {"x_kin": 32, "x_pid": 32}
    data_iter = get_data_iter(training_path, seq_maxlen=maxlen_by_key)

    first_batch = data_iter.next()
    kin_shape = data_iter.get_shape("x_kin", batch_shape=False)
    pid_shape = data_iter.get_shape("x_pid", batch_shape=False)

    model = build_model(kin_shape, pid_shape, name="RNN")
    model.summary()

    model_inputs = [first_batch.x_kin, first_batch.x_pid, first_batch.x_len]
    y_score = model.predict_on_batch(model_inputs)
    print("y_score: {}".format(y_score.shape))
Esempio n. 6
0
def _test():
    """Print the shapes produced by the image data iterator.

    Shows the shape of every entry of the first batch, the per-example shape
    of ``"x_img"`` read directly from the underlying dataset, and the two
    ``get_shape("x_img", ...)`` variants.
    """
    from keras4hep.projects.qgjets.utils import get_dataset_paths
    from dataset import get_data_iter

    dataset_paths = get_dataset_paths(min_pt=100)
    data_iter = get_data_iter(dataset_paths["training"])

    first_batch = data_iter.next()
    for name, array in first_batch.iteritems():
        print(name, array.shape)

    # Take a one-example slice and drop its leading axis from the shape.
    single_example = data_iter._dataset[:1]
    print(single_example["x_img"].shape[1:])

    print(data_iter.get_shape("x_img", False))
    print(data_iter.get_shape("x_img", True))
Esempio n. 7
0
def _test():
    """Build ``"VanillaConvNet"`` (``kernel_size=7``) and predict one batch.

    Prints the shape of the prediction followed by the model summary.
    """
    from keras4hep.projects.qgjets.utils import get_dataset_paths
    from dataset import get_data_iter

    dataset_paths = get_dataset_paths(min_pt=100)
    data_iter = get_data_iter(dataset_paths["training"])

    first_batch = data_iter.next()
    input_shape = data_iter.get_shape("x", batch_shape=False)

    model = build_model(input_shape, kernel_size=7, name="VanillaConvNet")

    prediction = model.predict_on_batch(first_batch.x)
    print("logits: {}".format(prediction.shape))

    model.summary()
Esempio n. 8
0
def _test():
    """Build the single-input ``"RNN"`` model and run one batch through it.

    Prints the model summary and the shape of the prediction for the first
    batch's ``x``, passed to the model as a one-element list.
    """
    from dataset import get_data_iter
    from keras4hep.projects.qgjets.utils import get_dataset_paths

    training_path = get_dataset_paths(min_pt=100)["training"]

    maxlen_spec = {"x": (40, "float32")}
    data_iter = get_data_iter(training_path, seq_maxlen=maxlen_spec)

    first_batch = data_iter.next()
    input_shape = data_iter.get_shape("x", batch_shape=False)

    model = build_model(input_shape, name="RNN")
    model.summary()

    model_inputs = [first_batch.x]
    y_score = model.predict_on_batch(model_inputs)
    print("y_score: {}".format(y_score.shape))
Esempio n. 9
0
    def set_data_iter(self):
        """Create ``train_iter``, ``valid_iter`` and ``test_iter`` on ``self``.

        The dataset paths are looked up from ``self.config.pt`` and appended
        to the configuration. All three iterators run in fit-generator mode;
        the training and validation iterators cycle, the test iterator does
        not. Validation and test share ``self.config.test_batch_size``.
        """
        paths = get_dataset_paths(self.config.pt)
        self.config.append(paths)

        def make_iter(path, batch_size, cycle):
            # Every split uses fit_generator_mode=True; only these three vary.
            return self.get_data_iter(path=path,
                                      batch_size=batch_size,
                                      fit_generator_mode=True,
                                      cycle=cycle)

        self.train_iter = make_iter(paths["training"],
                                    self.config.batch_size,
                                    cycle=True)
        self.valid_iter = make_iter(paths["validation"],
                                    self.config.test_batch_size,
                                    cycle=True)
        self.test_iter = make_iter(paths["test"],
                                   self.config.test_batch_size,
                                   cycle=False)
Esempio n. 10
0
def main():
    """Train the RNN jet classifier and evaluate the selected checkpoints.

    Steps, in order:

    1. Parse the command-line options.
    2. Create the log directory tree, save the configuration and copy the
       project scripts into it.
    3. Build the training / validation / test iterators.
    4. Build and compile the model.
    5. Train with ``fit_generator``.
    6. Delete every checkpoint not returned by ``find_good_checkpoint`` and
       evaluate the remaining ones.
    """
    # Local import: only needed for the hostname fallback below.
    import socket

    parser = argparse.ArgumentParser()

    parser.add_argument("--name", default="rnn{}".format(datetime.now().strftime("%y%m%d%H%M%S")))
    parser.add_argument("--directory", default="./logs")

    # GPU
    parser.add_argument("--num_gpus", default=len(get_available_gpus()), type=int)
    parser.add_argument("--multi-gpu", default=False, action='store_true', dest='multi_gpu')

    # Hyperparameters
    parser.add_argument("--epoch", dest="epochs", default=200, type=int)
    parser.add_argument("--batch_size", default=128, type=int)
    parser.add_argument("--valid_batch_size", default=1024, type=int)

    # Optimizer
    parser.add_argument("--optimizer", default="Adam", type=str)
    parser.add_argument("--lr", default=0.001, type=float)
    parser.add_argument("--clipnorm", default=-1, type=float,
                        help="if it is greater than 0, then graidient clipping is activated")
    parser.add_argument("--clipvalue", default=-1, type=float)
    parser.add_argument("--use-class-weight", dest="use_class_weight",
                        default=False, action="store_true")

    # Frequencies
    parser.add_argument("--valid_freq", type=int, default=32)
    parser.add_argument("--save_freq", type=int, default=32)
    parser.add_argument("-v", "--verbose", action="store_true")

    # Project parameters
    parser.add_argument("--min-pt", dest="min_pt", default=100, type=int)

    # Model architecture
    parser.add_argument("--act", dest="activation", default="elu", type=str)
    parser.add_argument("--rnn", default="gru", type=str)

    args = parser.parse_args()

    ########################################
    # Log directory and run configuration
    ########################################
    log_dir = Directory(path=os.path.join(args.directory, args.name))
    for subdir in ("script", "checkpoint", "learning_curve", "roc_curve",
                   "model_response"):
        log_dir.mkdir(subdir)

    config = Config(log_dir.path, "w")
    config.append(args)
    # HOSTNAME is a shell variable that is not always exported to child
    # processes; fall back to the OS-reported name instead of raising KeyError.
    config["hostname"] = os.environ.get("HOSTNAME") or socket.gethostname()
    config["log_dir"] = log_dir.path
    config.save()

    # Keep a copy of the code that produced this run next to its results.
    scripts = [
        "./dataset.py",
        "./model.py",
        "./train.py",
    ]
    for each in scripts:
        shutil.copy2(each, log_dir.script.path)
    # Create an empty __init__.py in the script backup directory.
    with open(log_dir.script.concat("__init__.py"), 'w'):
        pass

    ########################################
    # Load training and validation datasets
    ########################################
    dset = get_dataset_paths(config.min_pt)
    config.append(dset)

    config["seq_maxlen"] = {
        "x_kin": 30,
        "x_pid": 30
    }

    train_iter = get_data_iter(
        path=dset["training"],
        batch_size=config.batch_size,
        seq_maxlen=config.seq_maxlen,
        fit_generator_mode=True)

    valid_iter = get_data_iter(
        path=dset["validation"],
        batch_size=config.valid_batch_size,
        seq_maxlen=config.seq_maxlen,
        fit_generator_mode=True)

    # The test iterator is only used by evaluate(), so it is not created in
    # fit-generator mode.
    test_iter = get_data_iter(
        path=dset["test"],
        batch_size=config.valid_batch_size,
        seq_maxlen=config.seq_maxlen,
        fit_generator_mode=False)

    if config.use_class_weight:
        class_weight = get_class_weight(train_iter)
        config["class_weight"] = list(class_weight)
    else:
        class_weight = None

    #################################
    # Build & Compile a model.
    #################################
    x_kin_shape = train_iter.get_shape("x_kin", batch_shape=False)
    x_pid_shape = train_iter.get_shape("x_pid", batch_shape=False)

    model = build_model(
        x_kin_shape,
        x_pid_shape,
        rnn=config.rnn,
        activation=config.activation,
        name=config.name)

    # Record the architecture before the model is possibly wrapped below.
    config["model"] = model.get_config()

    if config.multi_gpu:
        model = multi_gpu_model(model, gpus=config.num_gpus)

    # The model diagram is only drawn on this one host; elsewhere a text
    # summary is printed instead.
    if config.hostname == "cms05.sscc.uos.ac.kr":
        model_plot_path = log_dir.concat("model.png")
        plot_model(model, to_file=model_plot_path, show_shapes=True)
    else:
        model.summary()

    loss = 'categorical_crossentropy'

    # TODO capsulisation
    # Clipping options are forwarded only when positive (their default is -1).
    optimizer_kwargs = {}
    if config.clipnorm > 0:
        optimizer_kwargs["clipnorm"] = config.clipnorm
    if config.clipvalue > 0:
        optimizer_kwargs["clipvalue"] = config.clipvalue
    optimizer = getattr(optimizers, config.optimizer)(lr=config.lr, **optimizer_kwargs)

    metric_list = ["accuracy", roc_auc]

    model.compile(
        loss=loss,
        optimizer=optimizer,
        metrics=metric_list)

    config["loss"] = loss
    config["optimizer_config"] = optimizer.get_config()

    ###########################################################################
    # Callbacks
    ###########################################################################
    # The validation metrics are embedded in the checkpoint file name;
    # find_good_checkpoint() is later asked to select by "auc", "acc", "loss".
    ckpt_format_str = "weights_epoch-{epoch:02d}_loss-{val_loss:.4f}_acc-{val_acc:.4f}_auc-{val_roc_auc:.4f}.hdf5"
    ckpt_path = log_dir.checkpoint.concat(ckpt_format_str)
    csv_log_path = log_dir.concat("log_file.csv")

    learning_curve = LearningCurve(directory=log_dir.learning_curve.path)
    learning_curve.book(x="step", y="roc_auc", best="max")
    learning_curve.book(x="step", y="acc", best="max")
    learning_curve.book(x="step", y="loss", best="min")

    callback_list = [
        callbacks.ModelCheckpoint(filepath=ckpt_path),
        callbacks.ReduceLROnPlateau(verbose=1),
        callbacks.CSVLogger(csv_log_path),
        learning_curve,
    ]

    ############################################################################
    # Training
    ############################################################################
    model.fit_generator(
        train_iter,
        steps_per_epoch=len(train_iter),
        epochs=config.epochs,
        validation_data=valid_iter,
        validation_steps=len(valid_iter),
        callbacks=callback_list,
        shuffle=True,
        class_weight=class_weight)

    del model

    print("Training is over! :D")

    ###########################################
    # Evaluation
    ############################################
    # Evaluation reuses the training iterator outside fit-generator mode and
    # without cycling.
    train_iter.fit_generator_mode = False
    train_iter.cycle = False

    good_ckpt = find_good_checkpoint(
        log_dir.checkpoint.path,
        which={"max": ["auc", "acc"], "min": ["loss"]})

    # Every checkpoint that was not selected above is removed from disk.
    all_ckpt = set(log_dir.checkpoint.get_entries())
    useless_ckpt = all_ckpt.difference(good_ckpt)
    for each in useless_ckpt:
        os.remove(each)

    for idx, each in enumerate(good_ckpt, 1):
        print("[{}/{}] {}".format(idx, len(good_ckpt), each))

        # Start from a fresh backend session for every checkpoint.
        K.clear_session()
        evaluate(checkpoint_path=each,
                 train_iter=train_iter,
                 test_iter=test_iter,
                 log_dir=log_dir)

    config.save()
Esempio n. 11
0
def main():
    """Evaluate a frozen ensemble of a pre-trained CNN and a pre-trained RNN.

    Two saved models are loaded from a host-specific checkpoint directory.
    The outputs of their ``cnn_softmax_0`` and ``rnn_softmax_0`` layers are
    added, passed through a ``Softmax`` layer, and the resulting model, with
    every layer frozen, is handed to ``evaluate``. No training happens here.

    Raises:
        NotImplementedError: if the host name is not one of the two hosts for
            which a checkpoint directory is known.
    """
    # Local import: only needed for the hostname fallback below.
    import socket

    ##########################
    # Argument Parsing
    ##########################
    parser = argparse.ArgumentParser()

    parser.add_argument("--logdir",
                        dest="log_dir",
                        type=str,
                        default="./logs/untitled-{}".format(
                            datetime.now().strftime("%y%m%d-%H%M%S")))

    parser.add_argument("--num_gpus",
                        default=len(get_available_gpus()),
                        type=int)
    parser.add_argument("--multi-gpu",
                        default=False,
                        action='store_true',
                        dest='multi_gpu')

    # Hyperparameters
    parser.add_argument("--optimizer", default="Adam", type=str)
    parser.add_argument("--lr", default=0.003, type=float)
    parser.add_argument(
        "--clipnorm",
        default=-1,
        type=float,
        help="if it is greater than 0, then graidient clipping is activated")
    parser.add_argument("--clipvalue", default=-1, type=float)
    parser.add_argument("--use-class-weight",
                        dest="use_class_weight",
                        default=False,
                        action="store_true")

    parser.add_argument("--batch_size", default=128, type=int)
    parser.add_argument("--valid_batch_size", default=1024, type=int)

    # Frequencies
    parser.add_argument("--valid_freq", type=int, default=32)
    parser.add_argument("--save_freq", type=int, default=32)
    parser.add_argument("-v", "--verbose", action="store_true")

    # Project parameters
    parser.add_argument("--min-pt", dest="min_pt", default=100, type=int)

    # Model architecture
    parser.add_argument("--act", dest="activation", default="elu", type=str)

    args = parser.parse_args()

    ########################################
    # Log directory and run configuration
    ########################################
    log_dir = Directory(path=args.log_dir)
    for subdir in ("script", "checkpoint", "learning_curve", "roc_curve",
                   "model_response"):
        log_dir.mkdir(subdir)

    backup_scripts(log_dir.script.path)

    # HOSTNAME is a shell variable that is not always exported to child
    # processes; fall back to the OS-reported name instead of raising KeyError.
    hostname = os.environ.get("HOSTNAME") or socket.gethostname()

    config = Config(log_dir.path, "w")
    config.append(args)
    config["hostname"] = hostname

    ###############################
    # Load
    #################################
    # The pre-trained checkpoints live in a different place on each host.
    if hostname == "cms05.sscc.uos.ac.kr":
        ckpt_dir = "/store/slowmoyang/QGJets/SJ-keras4hep/Dev-Composite"
    elif hostname == "gate2":
        ckpt_dir = "/scratch/slowmoyang/QGJets/SJ-keras4hep/Dev-Composite"
    else:
        raise NotImplementedError

    cnn_path = os.path.join(
        ckpt_dir,
        "VanillaConvNet_epoch-67_loss-0.4987_acc-0.7659_auc-0.8422.hdf5")
    rnn_path = os.path.join(
        ckpt_dir,
        "RNNGatherEmbedding_weights_epoch-121_loss-0.4963_acc-0.7658_auc-0.8431.hdf5"
    )

    # Each model is loaded with its own custom objects only.
    cnn_custom_objects = get_cnn_custom_objects()
    rnn_custom_objects = get_rnn_custom_objects()

    cnn = load_model(cnn_path, custom_objects=cnn_custom_objects)
    cnn.summary()
    print("\n" * 5)

    rnn = load_model(rnn_path, custom_objects=rnn_custom_objects)
    rnn.summary()
    print("\n" * 5)

    ######################################
    # Build
    ######################################
    # The combined model takes the CNN inputs followed by the RNN inputs.
    inputs = cnn.inputs + rnn.inputs

    # cnn_logits = cnn.get_layer("cnn_gap2d_0").output
    # rnn_logits = rnn.get_layer("rnn_dense_6").output

    cnn_softmax = cnn.get_layer("cnn_softmax_0").output
    rnn_softmax = rnn.get_layer("rnn_softmax_0").output

    # logits = Add()([cnn_logits, rnn_logits])
    # NOTE(review): judging by the layer names these are already softmax
    # outputs, so the Softmax below is applied on top of their sum -- confirm
    # this is intended rather than summing the logits (commented out above).
    logits = Add()([cnn_softmax, rnn_softmax])

    y_pred = Softmax()(logits)

    model = Model(inputs=inputs, outputs=y_pred)
    model.summary()

    ################################################
    # Freeze
    ##################################################
    for each in model.layers:
        each.trainable = False

    ###################################################
    # Data iterators
    ####################################################
    dset = get_dataset_paths(config.min_pt)
    # The order of "x" matches `inputs` above only if the CNN takes x_img and
    # the RNN takes (x_kin, x_pid, x_len) -- TODO confirm against the models.
    config["fit_generator_input"] = {
        "x": ["x_img", "x_kin", "x_pid", "x_len"],
        "y": ["y"]
    }

    train_iter = get_data_iter(path=dset["training"],
                               batch_size=config.batch_size,
                               fit_generator_input=config.fit_generator_input,
                               fit_generator_mode=True)

    test_iter = get_data_iter(path=dset["test"],
                              batch_size=config.valid_batch_size,
                              fit_generator_input=config.fit_generator_input,
                              fit_generator_mode=False)

    ###########################################
    # Evaluation
    ############################################
    # Evaluation uses the training iterator outside fit-generator mode and
    # without cycling.
    train_iter.fit_generator_mode = False
    train_iter.cycle = False

    evaluate(model=model,
             train_iter=train_iter,
             test_iter=test_iter,
             log_dir=log_dir)

    config.save()
Esempio n. 12
0
def main():
    """Train the CNN jet classifier and evaluate the selected checkpoints.

    Steps, in order:

    1. Parse the command-line options.
    2. Create the log directory tree and back up the project scripts.
    3. Build the training / validation / test iterators and the class weights.
    4. Build and compile the model (Adam optimizer, optional clipping).
    5. Train with ``fit_generator`` for ``--epoch`` epochs, with early
       stopping on ``val_loss``.
    6. Evaluate every checkpoint returned by ``find_good_checkpoint``.
    """
    # Local import: only needed for the hostname fallback below.
    import socket

    parser = argparse.ArgumentParser()

    parser.add_argument("--logdir",
                        dest="log_dir",
                        type=str,
                        default="./logs/untitled-{}".format(
                            datetime.now().strftime("%y%m%d-%H%M%S")))

    parser.add_argument("--num_gpus",
                        default=len(get_available_gpus()),
                        type=int)
    parser.add_argument("--multi-gpu",
                        default=False,
                        action='store_true',
                        dest='multi_gpu')

    # Hyperparameters
    # The default is 50 because training previously always ran for a
    # hard-coded 50 epochs regardless of this flag; keeping 50 leaves runs
    # that do not pass --epoch unchanged now that the flag is honoured.
    parser.add_argument("--epoch", dest="num_epochs", default=50, type=int)
    parser.add_argument("--batch_size", default=128, type=int)
    parser.add_argument("--valid_batch_size", default=1024, type=int)

    # Optimizer
    parser.add_argument("--lr", default=0.001, type=float)
    parser.add_argument(
        "--clipnorm",
        default=-1,
        type=float,
        help="if it is greater than 0, then graidient clipping is activated")
    parser.add_argument("--clipvalue", default=-1, type=float)

    # Frequencies
    parser.add_argument("--valid_freq", type=int, default=32)
    parser.add_argument("--save_freq", type=int, default=32)
    parser.add_argument("-v", "--verbose", action="store_true")

    # Project parameters
    parser.add_argument("--min-pt", dest="min_pt", default=100, type=int)
    parser.add_argument("--activation", default="relu", type=str)
    parser.add_argument("--top", default="dense", type=str)
    # type=int so that values given on the command line are integers like
    # the default, not strings.
    parser.add_argument("--filters_list", nargs="+", type=int,
                        default=[64, 256, 32])

    args = parser.parse_args()

    ########################################
    # Log directory and run configuration
    ########################################
    log_dir = Directory(path=args.log_dir)
    for subdir in ("script", "checkpoint", "learning_curve", "roc_curve",
                   "model_response"):
        log_dir.mkdir(subdir)

    backup_scripts(log_dir.script.path)

    config = Config(log_dir.path, "w")
    config.append(args)
    # HOSTNAME is a shell variable that is not always exported to child
    # processes; fall back to the OS-reported name instead of raising KeyError.
    config["hostname"] = os.environ.get("HOSTNAME") or socket.gethostname()

    ########################################
    # Load training and validation datasets
    ########################################
    dset = get_dataset_paths(config.min_pt)
    config.append(dset)

    # NOTE(review): config.training / config.preprocessing are presumably made
    # available by config.append(dset) above -- confirm against Config.
    train_iter = get_data_iter(path=config.training,
                               prep_path=config.preprocessing,
                               batch_size=config.batch_size,
                               fit_generator_mode=True,
                               drop_last=True)

    valid_iter = get_data_iter(path=dset["validation"],
                               prep_path=config.preprocessing,
                               batch_size=config.valid_batch_size,
                               fit_generator_mode=True)

    # The test iterator is only used by evaluate(), so it is not created in
    # fit-generator mode.
    test_iter = get_data_iter(path=dset["test"],
                              prep_path=config.preprocessing,
                              batch_size=config.valid_batch_size,
                              fit_generator_mode=False)

    class_weight = get_class_weight(train_iter)
    config["class_weight"] = list(class_weight)

    #################################
    # Build & Compile a model.
    #################################
    x_shape = train_iter.get_shape("x", batch_shape=False)

    model = build_a_model(x_shape)
    # Record the architecture before the model is possibly wrapped below.
    config["model"] = model.get_config()

    if config.multi_gpu:
        # Fixed: this referred to an undefined name `_model` (NameError).
        model = multi_gpu_model(model, gpus=config.num_gpus)

    model_plot_path = log_dir.concat("model.png")
    plot_model(model, to_file=model_plot_path, show_shapes=True)

    # TODO args should have these information.
    loss = 'categorical_crossentropy'

    # TODO capsulisation
    # Clipping options are forwarded only when positive (their default is -1).
    # Fixed: the assignments used a misspelled, undefined dict name.
    optimizer_kwargs = {}
    if config.clipnorm > 0:
        optimizer_kwargs["clipnorm"] = config.clipnorm
    if config.clipvalue > 0:
        optimizer_kwargs["clipvalue"] = config.clipvalue
    optimizer = optimizers.Adam(lr=config.lr, **optimizer_kwargs)

    metric_list = ["accuracy", roc_auc]

    config["loss"] = loss
    config["optimizer"] = "Adam"
    config["optimizer_config"] = optimizer.get_config()

    model.compile(loss=loss, optimizer=optimizer, metrics=metric_list)

    ###########################################################################
    # Callbacks
    ###########################################################################
    # The validation metrics are embedded in the checkpoint file name;
    # find_good_checkpoint() is later asked to select by "auc", "acc", "loss".
    ckpt_format_str = "weights_epoch-{epoch:02d}_loss-{val_loss:.4f}_acc-{val_acc:.4f}_auc-{val_roc_auc:.4f}.hdf5"
    ckpt_path = log_dir.checkpoint.concat(ckpt_format_str)
    csv_log_path = log_dir.concat("log_file.csv")

    learning_curve = LearningCurve(directory=log_dir.learning_curve.path)
    learning_curve.book(x="step", y="roc_auc", best="max")
    learning_curve.book(x="step", y="acc", best="max")
    learning_curve.book(x="step", y="loss", best="min")

    callback_list = [
        callbacks.ModelCheckpoint(filepath=ckpt_path),
        callbacks.EarlyStopping(monitor="val_loss", patience=5),
        callbacks.ReduceLROnPlateau(),
        callbacks.CSVLogger(csv_log_path),
        learning_curve,
    ]

    ############################################################################
    # Training
    ############################################################################
    model.fit_generator(train_iter,
                        steps_per_epoch=len(train_iter),
                        # Fixed: was a hard-coded 50 that ignored --epoch.
                        epochs=config.num_epochs,
                        validation_data=valid_iter,
                        validation_steps=len(valid_iter),
                        callbacks=callback_list,
                        shuffle=True,
                        class_weight=class_weight)

    print("Training is over! :D")

    del model
    K.clear_session()

    ###########################################
    # Evaluation
    ############################################
    # Evaluation reuses the training iterator outside fit-generator mode and
    # without cycling.
    train_iter.fit_generator_mode = False
    train_iter.cycle = False

    good_ckpt = find_good_checkpoint(log_dir.checkpoint.path,
                                     which={
                                         "max": ["auc", "acc"],
                                         "min": ["loss"]
                                     })

    for idx, each in enumerate(good_ckpt, 1):
        print("[{}/{}] {}".format(idx, len(good_ckpt), each))

        # Start from a fresh backend session for every checkpoint.
        K.clear_session()
        # roc_auc is a custom metric, so it is passed on for model loading.
        evaluate(custom_objects={"roc_auc": roc_auc},
                 checkpoint_path=each,
                 train_iter=train_iter,
                 test_iter=test_iter,
                 log_dir=log_dir)

    config.save()
Esempio n. 13
0
def main():
    """Train a joint head on top of a frozen pre-trained CNN and RNN.

    Two saved models are loaded from a host-specific checkpoint directory.
    The output of the CNN layer ``cnn_batch_norm_2`` (after global average
    pooling) and of the RNN layer ``rnn_dense_5`` are concatenated and fed
    through new BatchNormalization / Dense / Softmax layers. Layers whose
    name starts with ``cnn`` or ``rnn`` are frozen, the model is trained with
    ``fit_generator`` and the checkpoints returned by
    ``find_good_checkpoint`` are evaluated.

    Raises:
        NotImplementedError: if the host name is not one of the two hosts for
            which a checkpoint directory is known.
    """
    # Local import: only needed for the hostname fallback below.
    import socket

    ##########################
    # Argument Parsing
    ##########################
    parser = argparse.ArgumentParser()

    parser.add_argument("--logdir",
                        dest="log_dir",
                        type=str,
                        default="./logs/untitled-{}".format(
                            datetime.now().strftime("%y%m%d-%H%M%S")))

    parser.add_argument("--num_gpus",
                        default=len(get_available_gpus()),
                        type=int)
    parser.add_argument("--multi-gpu",
                        default=False,
                        action='store_true',
                        dest='multi_gpu')

    # Hyperparameters
    parser.add_argument("--epoch", dest="epochs", default=100, type=int)
    parser.add_argument("--batch_size", default=128, type=int)
    parser.add_argument("--valid_batch_size", default=1024, type=int)

    parser.add_argument("--optimizer", default="Adam", type=str)
    parser.add_argument("--lr", default=0.003, type=float)
    parser.add_argument(
        "--clipnorm",
        default=-1,
        type=float,
        help="if it is greater than 0, then graidient clipping is activated")
    parser.add_argument("--clipvalue", default=-1, type=float)
    parser.add_argument("--use-class-weight",
                        dest="use_class_weight",
                        default=False,
                        action="store_true")

    # Frequencies
    parser.add_argument("--valid_freq", type=int, default=32)
    parser.add_argument("--save_freq", type=int, default=32)
    parser.add_argument("-v", "--verbose", action="store_true")

    # Project parameters
    parser.add_argument("--min-pt", dest="min_pt", default=100, type=int)

    # Model architecture
    parser.add_argument("--act", dest="activation", default="elu", type=str)

    args = parser.parse_args()

    ########################################
    # Log directory and run configuration
    ########################################
    log_dir = Directory(path=args.log_dir)
    for subdir in ("script", "checkpoint", "learning_curve", "roc_curve",
                   "model_response"):
        log_dir.mkdir(subdir)

    backup_scripts(log_dir.script.path)

    # HOSTNAME is a shell variable that is not always exported to child
    # processes; fall back to the OS-reported name instead of raising KeyError.
    hostname = os.environ.get("HOSTNAME") or socket.gethostname()

    config = Config(log_dir.path, "w")
    config.append(args)
    config["hostname"] = hostname

    ###############################
    # Load
    #################################
    # The pre-trained checkpoints live in a different place on each host.
    if hostname == "cms05.sscc.uos.ac.kr":
        ckpt_dir = "/store/slowmoyang/QGJets/SJ-keras4hep/Dev-Composite"
    elif hostname == "gate2":
        ckpt_dir = "/scratch/slowmoyang/QGJets/SJ-keras4hep/Dev-Composite"
    else:
        raise NotImplementedError

    cnn_path = os.path.join(
        ckpt_dir,
        "VanillaConvNet_epoch-67_loss-0.4987_acc-0.7659_auc-0.8422.hdf5")
    rnn_path = os.path.join(
        ckpt_dir,
        "RNNGatherEmbedding_weights_epoch-121_loss-0.4963_acc-0.7658_auc-0.8431.hdf5"
    )

    cnn_custom_objects = get_cnn_custom_objects()
    rnn_custom_objects = get_rnn_custom_objects()

    # The merged dict is passed to evaluate() at the end, where checkpoints of
    # the joint model are loaded.
    custom_objects = {}
    custom_objects.update(cnn_custom_objects)
    custom_objects.update(rnn_custom_objects)

    cnn = load_model(cnn_path, custom_objects=cnn_custom_objects)
    cnn.summary()
    print("\n" * 5)

    rnn = load_model(rnn_path, custom_objects=rnn_custom_objects)
    rnn.summary()
    print("\n" * 5)

    ######################################
    # Build
    ######################################
    # The combined model takes the CNN inputs followed by the RNN inputs.
    inputs = cnn.inputs + rnn.inputs

    # cnn_last_hidden = cnn.get_layer("cnn_conv2d_3").output
    # rnn_last_hidden = rnn.get_layer("rnn_dense_5").output
    # cnn_flatten = Flatten()(cnn_last_hidden)
    # joint = Concatenate(axis=-1)([cnn_flatten, rnn_last_hidden])

    cnn_last_hidden = cnn.get_layer("cnn_batch_norm_2").output
    rnn_last_hidden = rnn.get_layer("rnn_dense_5").output

    cnn_gap = GlobalAveragePooling2D()(cnn_last_hidden)
    cnn_flatten = Flatten()(cnn_gap)
    joint = Concatenate(axis=-1)([cnn_flatten, rnn_last_hidden])
    joint = BatchNormalization(axis=-1, name="joint_batch_norm")(joint)
    joint = Dense(128)(joint)
    joint = Activation("relu")(joint)
    logits = Dense(2)(joint)

    y_pred = Softmax()(logits)

    model = Model(inputs=inputs, outputs=y_pred)

    model.summary()

    ################################################
    # Freeze
    ##################################################
    # Only layers named "cnn*" / "rnn*" are frozen; the layers created above
    # keep their default trainable state.
    for each in model.layers:
        if each.name.startswith("cnn") or each.name.startswith("rnn"):
            each.trainable = False

    ###################################################
    # Data iterators
    ####################################################
    dset = get_dataset_paths(config.min_pt)
    # The order of "x" matches `inputs` above only if the CNN takes x_img and
    # the RNN takes (x_kin, x_pid, x_len) -- TODO confirm against the models.
    config["fit_generator_input"] = {
        "x": ["x_img", "x_kin", "x_pid", "x_len"],
        "y": ["y"]
    }

    train_iter = get_data_iter(path=dset["training"],
                               batch_size=config.batch_size,
                               fit_generator_input=config.fit_generator_input,
                               fit_generator_mode=True)

    valid_iter = get_data_iter(path=dset["validation"],
                               batch_size=config.valid_batch_size,
                               fit_generator_input=config.fit_generator_input,
                               fit_generator_mode=True)

    # The test iterator is only used by evaluate(), so it is not created in
    # fit-generator mode.
    test_iter = get_data_iter(path=dset["test"],
                              batch_size=config.valid_batch_size,
                              fit_generator_input=config.fit_generator_input,
                              fit_generator_mode=False)

    if config.use_class_weight:
        class_weight = get_class_weight(train_iter)
        config["class_weight"] = list(class_weight)
    else:
        class_weight = None

    ######################################
    # Compile
    #######################################
    loss = 'categorical_crossentropy'

    # TODO capsulisation
    # Clipping options are forwarded only when positive (their default is -1).
    optimizer_kwargs = {}
    if config.clipnorm > 0:
        optimizer_kwargs["clipnorm"] = config.clipnorm
    if config.clipvalue > 0:
        optimizer_kwargs["clipvalue"] = config.clipvalue
    optimizer = getattr(optimizers, config.optimizer)(lr=config.lr,
                                                      **optimizer_kwargs)

    metric_list = ["accuracy", roc_auc]
    model.compile(loss=loss, optimizer=optimizer, metrics=metric_list)

    ###########################################################################
    # Callbacks
    ###########################################################################
    # The validation metrics are embedded in the checkpoint file name;
    # find_good_checkpoint() is later asked to select by "auc", "acc", "loss".
    ckpt_format_str = "weights_epoch-{epoch:02d}_loss-{val_loss:.4f}_acc-{val_acc:.4f}_auc-{val_roc_auc:.4f}.hdf5"
    ckpt_path = log_dir.checkpoint.concat(ckpt_format_str)
    csv_log_path = log_dir.concat("log_file.csv")

    learning_curve = LearningCurve(directory=log_dir.learning_curve.path)
    learning_curve.book(x="step", y="roc_auc", best="max")
    learning_curve.book(x="step", y="acc", best="max")
    learning_curve.book(x="step", y="loss", best="min")

    callback_list = [
        callbacks.ModelCheckpoint(filepath=ckpt_path),
        #         callbacks.EarlyStopping(monitor="val_loss" , patience=5),
        callbacks.ReduceLROnPlateau(verbose=1),
        callbacks.CSVLogger(csv_log_path),
        learning_curve,
    ]

    ############################################################################
    # Training
    ############################################################################
    model.fit_generator(train_iter,
                        steps_per_epoch=len(train_iter),
                        epochs=config.epochs,
                        validation_data=valid_iter,
                        validation_steps=len(valid_iter),
                        callbacks=callback_list,
                        shuffle=True,
                        class_weight=class_weight)

    print("Training is over! :D")

    del model

    ###########################################
    # Evaluation
    ############################################
    # Evaluation reuses the training iterator outside fit-generator mode and
    # without cycling.
    train_iter.fit_generator_mode = False
    train_iter.cycle = False

    good_ckpt = find_good_checkpoint(log_dir.checkpoint.path,
                                     which={
                                         "max": ["auc", "acc"],
                                         "min": ["loss"]
                                     })

    for idx, each in enumerate(good_ckpt, 1):
        print("[{}/{}] {}".format(idx, len(good_ckpt), each))

        # Start from a fresh backend session for every checkpoint.
        K.clear_session()
        evaluate(custom_objects=custom_objects,
                 checkpoint_path=each,
                 train_iter=train_iter,
                 test_iter=test_iter,
                 log_dir=log_dir)

    config.save()