Пример #1
0
def evaluate(config, train_dir, weights, customize, nevents):
    """Evaluate a trained model on each configured validation dataset.

    Args:
        config: path to the configuration file, or None to use
            ``<train_dir>/config.yaml``.
        train_dir: training directory; evaluation output is written to
            ``<train_dir>/evaluation/epoch_<N>/<dataset name>``.
        weights: path to the weight file to load. When falsy, the result of
            ``get_best_checkpoint(train_dir)`` is used instead.
        customize: optional key into ``customization_functions``; the
            selected function is applied to the parsed config.
        nevents: when truthy, only this many elements are taken from each
            test dataset before batching.
    """
    if config is None:
        config = Path(train_dir) / "config.yaml"
        assert config.exists(
        ), "Could not find config file in train_dir, please provide one with -c <path/to/config>"
    config, _ = parse_config(config, weights=weights)

    if customize:
        customize_fn = customization_functions[customize]
        config = customize_fn(config)

    # Half precision additionally requires the global mixed-precision policy.
    use_half_precision = config["setup"]["dtype"] == "float16"
    if use_half_precision:
        model_dtype = tf.dtypes.float16
        mixed_precision.set_global_policy(
            mixed_precision.Policy("mixed_float16"))
    else:
        model_dtype = tf.dtypes.float32

    # Only the number of GPUs is needed here; the strategy object is unused.
    # NOTE: GPU memory growth (tf.config.experimental.set_memory_growth) is
    # intentionally not configured in this function.
    _, num_gpus = get_strategy()

    model = make_model(config, model_dtype)
    dataset_cfg = config["dataset"]
    input_shape = (1, dataset_cfg["padded_num_elem_size"],
                   dataset_cfg["num_input_features"])
    model.build(input_shape)

    # The weights must be loaded in the same trainable configuration in
    # which the model was originally set up.
    weights_config = config["setup"].get("weights_config", "all")
    configure_model_weights(model, weights_config)

    if not weights:
        weights = get_best_checkpoint(train_dir)
        print(
            "Loading best weights that could be found from {}".format(weights))
    model.load_weights(weights, by_name=True)

    # The epoch index is the second "-"-separated token of the weight file
    # name (the last "/"-separated component of the path).
    weight_file_name = weights.split("/")[-1]
    iepoch = int(weight_file_name.split("-")[1])
    epoch_dir = Path(train_dir) / "evaluation" / "epoch_{}".format(iepoch)

    for dsname in config["validation_datasets"]:
        ds_test, _ = get_heptfds_dataset(dsname,
                                         config,
                                         num_gpus,
                                         "test",
                                         supervised=False)
        if nevents:
            ds_test = ds_test.take(nevents)
        ds_test = ds_test.batch(5)

        eval_path = epoch_dir / dsname
        eval_path.mkdir(parents=True, exist_ok=True)
        eval_model(model, ds_test, config, str(eval_path))

    freeze_model(model, config, train_dir)
Пример #2
0
def evaluate(config, train_dir, weights, evaluation_dir):
    """Evaluate the trained model in train_dir.

    Args:
        config: path to the configuration file, or None to use
            ``<train_dir>/config.yaml``.
        train_dir: training directory; also used to locate the best
            checkpoint when ``weights`` is not given.
        weights: path to the weight file to load. When falsy, the result of
            ``get_best_checkpoint(train_dir)`` is used instead.
        evaluation_dir: directory that receives the evaluation output, or
            None to use ``<train_dir>/evaluation``. Created if missing.
    """
    if config is None:
        config = Path(train_dir) / "config.yaml"
        assert config.exists(
        ), "Could not find config file in train_dir, please provide one with -c <path/to/config>"
    config, _ = parse_config(config, weights=weights)

    if evaluation_dir is None:
        eval_dir = str(Path(train_dir) / "evaluation")
    else:
        eval_dir = evaluation_dir

    Path(eval_dir).mkdir(parents=True, exist_ok=True)

    if config["setup"]["dtype"] == "float16":
        model_dtype = tf.dtypes.float16
        policy = mixed_precision.Policy("mixed_float16")
        mixed_precision.set_global_policy(policy)
        # No optimizer exists during evaluation, so nothing is wrapped in a
        # LossScaleOptimizer here. The previous code read an unassigned
        # local `opt` at this point and raised UnboundLocalError.
    else:
        model_dtype = tf.dtypes.float32

    # Only the number of GPUs is needed; the strategy object is unused.
    _, num_gpus = get_strategy()
    ds_test, _ = get_heptfds_dataset(config["validation_dataset"], config,
                                     num_gpus, "test")
    ds_test = ds_test.batch(5)

    model = make_model(config, model_dtype)
    model.build((1, config["dataset"]["padded_num_elem_size"],
                 config["dataset"]["num_input_features"]))

    # need to load the weights in the same trainable configuration as the model was set up
    configure_model_weights(model, config["setup"].get("weights_config",
                                                       "all"))
    if not weights:
        weights = get_best_checkpoint(train_dir)
        print(
            "Loading best weights that could be found from {}".format(weights))
    model.load_weights(weights, by_name=True)

    eval_model(model, ds_test, config, eval_dir)
    # A single batch of the test dataset is handed to freeze_model together
    # with the model and config.
    freeze_model(model, config, ds_test.take(1), train_dir)
Пример #3
0
def main(config):
    """Adversarially train the MLPF reconstruction model against a discriminator.

    The reconstruction model (``model_pf``) is created with ``make_model`` and
    initialised from a hard-coded weight file; a discriminator
    (``model_disc``) is created with ``make_disc_model``. Two Keras models
    share these networks:

    * ``m1`` maps detector elements through ``model_pf`` into the
      discriminator and is compiled while ``model_disc.trainable`` is False.
    * ``m2`` feeds externally supplied particles into the discriminator and
      is compiled while ``model_pf.trainable`` is False.

    For each epoch the function trains both models batch by batch on the
    "train" split, evaluates them on the "test" split, saves a histogram of
    the discriminator outputs to ``logs/disc_<epoch>.pdf`` and manually
    triggers the TensorBoard, checkpoint and custom callbacks.

    Args:
        config: configuration dictionary. It is modified in place
            (``config['setup']['multi_output']`` and
            ``config["datasets"][ds]["batch_per_gpu"]`` are overwritten).
    """
    tf.config.run_functions_eagerly(False)
    config['setup']['multi_output'] = True
    model_pf = make_model(config, tf.float32)

    # The callbacks are not passed to Model.fit; they are attached to
    # model_pf here and invoked by hand at the end of every epoch.
    tb = tf.keras.callbacks.TensorBoard(
        log_dir="logs", histogram_freq=1, write_graph=False, write_images=False,
        update_freq='epoch',
        profile_batch=0,
    )
    tb.set_model(model_pf)

    cp_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath="logs/weights-{epoch:02d}.hdf5",
        save_weights_only=True,
        verbose=0
    )
    cp_callback.set_model(model_pf)

    # Call the model once on random input before loading weights; the
    # resulting output is also used below to size the discriminator input
    # (ypred.shape[-1]).
    x = np.random.randn(1, config["dataset"]["padded_num_elem_size"], config["dataset"]["num_input_features"])
    ypred = concat_pf([model_pf(x), x])
    # NOTE(review): hard-coded path to pretrained weights.
    model_pf.load_weights("experiments/cms_20210909_132136_111774.gpu0.local/weights/weights-100-1.280379.hdf5", by_name=True)
    #model_pf.load_weights("./logs/weights-02.hdf5", by_name=True)

    model_disc = make_disc_model(config, ypred.shape[-1])

    num_gpus = 1
    ds = "cms_pf_ttbar"
    batch_size = 4
    config["datasets"][ds]["batch_per_gpu"] = batch_size
    # The trailing 128 is presumably the number of events to load - TODO
    # confirm against get_heptfds_dataset.
    ds_train, ds_info = get_heptfds_dataset(ds, config, num_gpus, "train", 128)
    ds_test, _ = get_heptfds_dataset(ds, config, num_gpus, "test", 128)
    # NOTE(review): ds_val is requested from the same "test" split as ds_test.
    ds_val, _ = get_heptfds_dataset(ds, config, num_gpus, "test", 128)

    cb = CustomCallback(
        "logs",
        ds_val,
        ds_info, plot_freq=10)

    cb.set_model(model_pf)

    # NOTE(review): the batch dimension is fixed to 2*batch_size, matching the
    # concatenated [target, prediction] batches given to m2 below; m1 is fed
    # xb on its own - confirm its batch dimension is compatible.
    input_elems = tf.keras.layers.Input(
        shape=(config["dataset"]["padded_num_elem_size"], config["dataset"]["num_input_features"]),
        batch_size=2*batch_size,
        name="input_detector_elements"
    )
    input_reco = tf.keras.layers.Input(
        shape=(config["dataset"]["padded_num_elem_size"], ypred.shape[-1]), name="input_reco_particles")
    pf_out = tf.keras.layers.Lambda(concat_pf)([model_pf(input_elems), input_elems])
    # disc_out1: discriminator applied to the model_pf output;
    # disc_out2: discriminator applied to externally supplied particles.
    disc_out1 = model_disc([input_elems, pf_out])
    disc_out2 = model_disc([input_elems, input_reco])
    m1 = tf.keras.models.Model(inputs=[input_elems], outputs=[disc_out1], name="model_mlpf_disc")
    m2 = tf.keras.models.Model(inputs=[input_elems, input_reco], outputs=[disc_out2], name="model_reco_disc")

    # Binary cross-entropy on logits, used for both m1 and m2.
    def loss(x,y):
        return tf.keras.losses.binary_crossentropy(x,y, from_logits=True)

    # The MLPF reconstruction model (generator) is optimized to confuse the
    # discriminator; the discriminator is marked non-trainable before m1 is
    # compiled.
    optimizer1 = tf.keras.optimizers.Adam(lr=1e-5)
    model_disc.trainable = False
    m1.compile(loss=loss, optimizer=optimizer1)
    m1.summary()

    # The discriminator model (adversarial) is optimized to distinguish
    # between the true target and MLPF-reconstructed events; model_pf is
    # marked non-trainable before m2 is compiled.
    optimizer2 = tf.keras.optimizers.Adam(lr=1e-5)
    model_pf.trainable = False
    model_disc.trainable = True
    m2.compile(loss=loss, optimizer=optimizer2)
    m2.summary()

    epochs = 1000

    # Running count of processed training batches (not read anywhere else in
    # this function).
    ibatch = 0
    for epoch in range(epochs):
        # Per-epoch sums of the batch losses (not averaged).
        loss_tot1 = 0.0
        loss_tot2 = 0.0
        loss_tot1_test = 0.0
        loss_tot2_test = 0.0


        for step, (xb, yb, wb) in tqdm(enumerate(ds_train), desc="Training"):

            # 1.0 where the first input feature is non-zero, 0.0 elsewhere
            # (presumably marks non-padded elements - TODO confirm).
            msk_x = tf.cast(xb[:, :, 0:1]!=0, tf.float32)

            # yp: model prediction, yb: target; both passed through concat_pf
            # together with the inputs.
            yp = concat_pf([model_pf(xb, training=True), xb])
            yb = concat_pf([yb, xb])

            # Zero out the target rows that the mask marks as 0.
            yb = yb*msk_x

            # Train the discriminative (adversarial) model on the
            # concatenation [targets, predictions]: targets are labelled 0.99
            # and MLPF predictions 0.01.
            mlpf_train_inputs = tf.concat([xb, xb], axis=0)
            # mlpf_train_inputs = mlpf_train_inputs + tf.random.normal(mlpf_train_inputs.shape, stddev=0.0001)

            mlpf_train_outputs = tf.concat([yb, yp], axis=0)
            mlpf_train_disc_targets = tf.concat([batch_size*[0.99], batch_size*[0.01]], axis=0)
            loss2 = m2.train_on_batch([mlpf_train_inputs, mlpf_train_outputs], mlpf_train_disc_targets)

            # Train the MLPF reconstruction (generative) model with an
            # inverted target: its outputs are labelled 1.0.
            disc_train_disc_targets = tf.concat([batch_size*[1.0]], axis=0)
            loss1 = m1.train_on_batch(xb, disc_train_disc_targets)

            loss_tot1 += loss1
            loss_tot2 += loss2
            ibatch += 1

        # NOTE(review): these imports are executed on every epoch; consider
        # moving them to module level.
        import boost_histogram as bh
        import mplhep
        import matplotlib.pyplot as plt

        # Discriminator outputs on the test split: preds_0 for samples
        # labelled below 0.5 (MLPF predictions), preds_1 for the rest (targets).
        preds_0 = []
        preds_1 = []

        for step, (xb, yb, wb) in tqdm(enumerate(ds_test), desc="Testing"):
            msk_x = tf.cast(xb[:, :, 0:1]!=0, tf.float32)

            yp = concat_pf([model_pf(xb, training=False), xb])
            yb = concat_pf([yb, xb])

            yb = yb*msk_x

            # Evaluate (no weight update) the discriminator on
            # [targets, predictions] with the same labels as in training.
            mlpf_train_inputs = tf.concat([xb, xb], axis=0)
            mlpf_train_outputs = tf.concat([yb, yp], axis=0)
            mlpf_train_disc_targets = tf.concat([batch_size*[0.99], batch_size*[0.01]], axis=0)
            loss2 = m2.test_on_batch([mlpf_train_inputs, mlpf_train_outputs], mlpf_train_disc_targets)

            # Evaluate (no weight update) the reconstruction model with the
            # inverted target.
            disc_train_disc_targets = tf.concat([batch_size*[1.0]], axis=0)
            loss1 = m1.test_on_batch(xb, disc_train_disc_targets)

            # Collect the raw discriminator outputs, split by label.
            p = m2.predict_on_batch([mlpf_train_inputs, mlpf_train_outputs])
            preds_0 += list(p[mlpf_train_disc_targets<0.5, 0])
            preds_1 += list(p[mlpf_train_disc_targets>=0.5, 0])

            loss_tot1_test += loss1
            loss_tot2_test += loss2

        print("Epoch {}, l1={:.5E}/{:.5E}, l2={:.5E}/{:.5E}".format(epoch, loss_tot1, loss_tot1_test, loss_tot2, loss_tot2_test))

        # Draw histograms of the discriminator outputs for monitoring; both
        # histograms share one binning spanning all collected outputs.
        # NOTE(review): np.min/np.max raise on an empty list, i.e. if ds_test
        # yields no batches.
        minval = np.min(preds_0 + preds_1)
        maxval = np.max(preds_0 + preds_1)
        h0 = bh.Histogram(bh.axis.Regular(50, minval, maxval))
        h1 = bh.Histogram(bh.axis.Regular(50, minval, maxval))
        h0.fill(preds_0)
        h1.fill(preds_1)

        fig = plt.figure(figsize=(4,4))
        mplhep.histplot(h0, label="MLPF")
        mplhep.histplot(h1, label="Target")
        plt.xlabel("Adversarial classification output")
        plt.legend(loc="best", frameon=False)
        plt.savefig("logs/disc_{}.pdf".format(epoch), bbox_inches="tight")
        plt.close("all")

        # Drive the callbacks manually, since Model.fit is not used.
        tb.on_epoch_end(epoch, {
            "loss1": loss_tot1,
            "loss2": loss_tot2,
            "val_loss1": loss_tot1_test,
            "val_loss2": loss_tot2_test,
            "val_mean_p0": np.mean(preds_0),
            "val_std_p0": np.std(preds_0),
            "val_mean_p1": np.mean(preds_1),
            "val_std_p1": np.std(preds_1),
        })

        cp_callback.on_epoch_end(epoch)
        cb.on_epoch_end(epoch)
Пример #4
0
def build_model_and_train(config,
                          checkpoint_dir=None,
                          full_config=None,
                          ntrain=None,
                          ntest=None,
                          name=None,
                          seeds=False):
    """Build, compile and fit a model for one Ray Tune trial.

    Args:
        config: hyperparameter search-space sample for this trial, or None.
            When given, it is merged into the parsed full configuration with
            ``set_raytune_search_parameters``.
        checkpoint_dir: accepted for interface compatibility; not used here.
        full_config: configuration file handed to ``parse_config``.
        ntrain: when truthy, limits the training dataset to this many
            elements and is used as the number of training steps per epoch.
        ntest: same as ``ntrain`` for the test dataset.
        name: sub-directory of ``full_config["raytune"]["local_dir"]`` that
            receives ``skipped_configurations.txt``. When falsy, the file is
            written directly into ``local_dir``.
        seeds: when truthy, the Python, NumPy and TensorFlow random number
            generators are seeded with 1234.

    A ``tf.errors.ResourceExhaustedError`` raised by ``model.fit`` is not
    propagated: the configuration is logged as skipped instead.
    """
    from ray import tune
    from ray.tune.integration.keras import TuneReportCheckpointCallback
    from raytune.search_space import set_raytune_search_parameters

    if seeds:
        # Set seeds for reproducibility
        random.seed(1234)
        np.random.seed(1234)
        tf.random.set_seed(1234)

    full_config, config_file_stem = parse_config(full_config)

    if config is not None:
        full_config = set_raytune_search_parameters(search_space=config,
                                                    config=full_config)

    strategy, num_gpus = get_strategy()

    ds_train, num_train_steps = get_datasets(
        full_config["train_test_datasets"], full_config, num_gpus, "train")
    ds_test, num_test_steps = get_datasets(full_config["train_test_datasets"],
                                           full_config, num_gpus, "test")
    ds_val, ds_info = get_heptfds_dataset(
        full_config["validation_datasets"][0],
        full_config,
        num_gpus,
        "test",
        full_config["setup"]["num_events_validation"],
        supervised=False,
    )
    ds_val = ds_val.batch(5)

    if ntrain:
        ds_train = ds_train.take(ntrain)
        num_train_steps = ntrain
    if ntest:
        ds_test = ds_test.take(ntest)
        num_test_steps = ntest

    print("num_train_steps", num_train_steps)
    print("num_test_steps", num_test_steps)
    total_steps = num_train_steps * full_config["setup"]["num_epochs"]
    print("total_steps", total_steps)

    callbacks = prepare_callbacks(
        full_config,
        tune.get_trial_dir(),
        ds_val,
    )

    callbacks = callbacks[:
                          -1]  # remove the CustomCallback at the end of the list

    with strategy.scope():
        lr_schedule, optim_callbacks = get_lr_schedule(full_config,
                                                       steps=total_steps)
        callbacks.append(optim_callbacks)
        opt = get_optimizer(full_config, lr_schedule)

        model = make_model(full_config, dtype=tf.dtypes.float32)

        # Run model once to build the layers
        model.build((1, full_config["dataset"]["padded_num_elem_size"],
                     full_config["dataset"]["num_input_features"]))

        full_config = set_config_loss(full_config,
                                      full_config["setup"]["trainable"])
        configure_model_weights(model, full_config["setup"]["trainable"])
        # Build again after the trainable configuration has been applied.
        model.build((1, full_config["dataset"]["padded_num_elem_size"],
                     full_config["dataset"]["num_input_features"]))

        loss_dict, loss_weights = get_loss_dict(full_config)
        model.compile(
            loss=loss_dict,
            optimizer=opt,
            sample_weight_mode="temporal",
            loss_weights=loss_weights,
            metrics={
                "cls": [
                    FlattenedCategoricalAccuracy(name="acc_unweighted",
                                                 dtype=tf.float64),
                    FlattenedCategoricalAccuracy(use_weights=True,
                                                 name="acc_weighted",
                                                 dtype=tf.float64),
                ]
            },
        )
        model.summary()

        # Metrics reported back to Ray Tune after each epoch.
        callbacks.append(
            TuneReportCheckpointCallback(metrics=[
                "adam_beta_1",
                "charge_loss",
                "cls_acc_unweighted",
                "cls_loss",
                "cos_phi_loss",
                "energy_loss",
                "eta_loss",
                "learning_rate",
                "loss",
                "pt_loss",
                "sin_phi_loss",
                "val_charge_loss",
                "val_cls_acc_unweighted",
                "val_cls_acc_weighted",
                "val_cls_loss",
                "val_cos_phi_loss",
                "val_energy_loss",
                "val_eta_loss",
                "val_loss",
                "val_pt_loss",
                "val_sin_phi_loss",
            ], ), )

        try:
            model.fit(
                ds_train.repeat(),
                validation_data=ds_test.repeat(),
                epochs=full_config["setup"]["num_epochs"],
                callbacks=callbacks,
                steps_per_epoch=num_train_steps,
                validation_steps=num_test_steps,
            )
        except tf.errors.ResourceExhaustedError:
            logging.warning(
                "Resource exhausted, skipping this hyperparameter configuration."
            )
            # `name` defaults to None; joining a Path with None would raise
            # TypeError and hide the out-of-memory condition handled here.
            skiplog_dir = Path(full_config["raytune"]["local_dir"])
            if name:
                skiplog_dir = skiplog_dir / name
            skiplog_dir.mkdir(parents=True, exist_ok=True)
            skiplog_file_path = skiplog_dir / "skipped_configurations.txt"

            # `config` may legitimately be None (see the check above), in
            # which case there are no search parameters to list.
            search_parameters = config if config is not None else {}
            lines = [
                "{}: {}\n".format(key, value)
                for key, value in search_parameters.items()
            ]

            with open(skiplog_file_path, "a") as f:
                f.write("#" * 80 + "\n")
                for line in lines:
                    f.write(line)
                    logging.warning(line[:-1])
                f.write("#" * 80 + "\n\n")
Пример #5
0
def hypertune(config, outdir, ntrain, ntest, recreate):
    """Run a hyperparameter search with the tuner returned by ``get_tuner``.

    Args:
        config: path to the configuration file; it is parsed here and copied
            to ``<outdir>/config.yaml`` once the search has finished.
        outdir: output directory handed to the callbacks and the tuner.
        ntrain: forwarded to ``parse_config``.
        ntest: forwarded to ``parse_config``.
        recreate: forwarded to ``get_tuner``.
    """
    config_file_path = config
    config, _ = parse_config(config, ntrain=ntrain, ntest=ntest)
    setup_cfg = config["setup"]
    hypertune_cfg = config["hypertune"]

    # With Hyperband, the number of epochs is taken from its max_epochs
    # setting instead of the regular setup value.
    if hypertune_cfg["algorithm"] == "hyperband":
        setup_cfg["num_epochs"] = hypertune_cfg["hyperband"]["max_epochs"]

    strategy, num_gpus = get_strategy()

    ds_train, ds_info = get_heptfds_dataset(config["training_dataset"],
                                            config, num_gpus, "train",
                                            setup_cfg["num_events_train"])
    ds_test, _ = get_heptfds_dataset(config["testing_dataset"], config,
                                     num_gpus, "test",
                                     setup_cfg["num_events_test"])
    ds_val, _ = get_heptfds_dataset(
        config["validation_datasets"][0],
        config,
        num_gpus,
        "test",
        setup_cfg["num_events_validation"],
        supervised=False,
    )
    ds_val = ds_val.batch(5)

    # The step counts are obtained by iterating once over each dataset.
    num_train_steps = sum(1 for _ in ds_train)
    num_test_steps = sum(1 for _ in ds_test)

    model_builder, optim_callbacks = hypertuning.get_model_builder(
        config, num_train_steps)

    callbacks = prepare_callbacks(
        config,
        outdir,
        ds_val,
    )
    early_stopping = tf.keras.callbacks.EarlyStopping(patience=20,
                                                      monitor="val_loss")
    callbacks.append(optim_callbacks)
    callbacks.append(early_stopping)

    tuner = get_tuner(config["hypertune"], model_builder, outdir, recreate,
                      strategy)
    tuner.search_space_summary()

    tuner.search(
        ds_train.repeat(),
        epochs=config["setup"]["num_epochs"],
        validation_data=ds_test.repeat(),
        steps_per_epoch=num_train_steps,
        validation_steps=num_test_steps,
        callbacks=callbacks,
    )
    print("Hyperparameter search complete.")

    # Keep a copy of the config file next to the results for later reference.
    shutil.copy(config_file_path, outdir + "/config.yaml")

    tuner.results_summary()
    best_trials = tuner.oracle.get_best_trials(num_trials=10)
    for trial in best_trials:
        print(trial.hyperparameters.values, trial.score)
Пример #6
0
def train(config, weights, ntrain, ntest, nepochs, recreate, prefix, plot_freq,
          customize, comet_offline):
    """Train the model described by the configuration file.

    Args:
        config: path to the configuration file.
        weights: optional weight file, forwarded to ``parse_config`` and
            ``model_scope``.
        ntrain: when truthy, limits the training dataset to this many
            elements and is used as the number of training steps per epoch.
        ntest: same as ``ntrain`` for the test dataset.
        nepochs: forwarded to ``parse_config``.
        recreate: not used in this function.
        prefix: prefix of the experiment directory name.
        plot_freq: when truthy, overrides ``config["callbacks"]["plot_freq"]``.
        customize: optional key into ``customization_functions``.
        comet_offline: when truthy, a comet-ml ``OfflineExperiment`` storing
            its logs under ``<outdir>/cometml`` is used instead of an online
            ``Experiment``.
    """
    # tf.debugging.enable_check_numerics()
    config_file_path = config
    config, config_file_stem = parse_config(config,
                                            nepochs=nepochs,
                                            weights=weights)

    if plot_freq:
        config["callbacks"]["plot_freq"] = plot_freq

    if customize:
        customize_fn = customization_functions[customize]
        config = customize_fn(config)

    # Either Horovod or a tf.distribute strategy provides the GPU count.
    horovod_enabled = config["setup"]["horovod_enabled"]
    if horovod_enabled:
        num_gpus = initialize_horovod()
    else:
        strategy, num_gpus = get_strategy()

    # The experiment directory is only created without Horovod or on
    # Horovod rank 0; otherwise outdir stays an empty string.
    outdir = ""
    if not horovod_enabled or hvd.rank() == 0:
        outdir = create_experiment_dir(prefix=prefix + config_file_stem + "_",
                                       suffix=platform.node())
        # Copy the config file to the train dir for later reference
        shutil.copy(config_file_path, outdir + "/config.yaml")

    # Settings shared by the online and the offline comet-ml experiment.
    comet_settings = {
        "project_name": "particleflow-tf",
        "auto_metric_logging": True,
        "auto_param_logging": True,
        "auto_histogram_weight_logging": True,
        "auto_histogram_gradient_logging": False,
        "auto_histogram_activation_logging": False,
    }
    try:
        if comet_offline:
            print("Using comet-ml OfflineExperiment, saving logs locally.")
            from comet_ml import OfflineExperiment

            experiment = OfflineExperiment(
                offline_directory=outdir + "/cometml", **comet_settings)
        else:
            print("Using comet-ml Experiment, streaming logs to www.comet.ml.")
            from comet_ml import Experiment

            experiment = Experiment(**comet_settings)
    except Exception as e:
        # Training continues without a dashboard if comet-ml is unavailable.
        print("Failed to initialize comet-ml dashboard: {}".format(e))
        experiment = None

    if experiment:
        experiment.set_name(outdir)
        for code_path in ("mlpf/tfmodel/model.py", "mlpf/tfmodel/utils.py",
                          config_file_path):
            experiment.log_code(code_path)

    train_test_datasets = config["train_test_datasets"]
    ds_train, num_train_steps = get_datasets(train_test_datasets, config,
                                             num_gpus, "train")
    ds_test, num_test_steps = get_datasets(config["train_test_datasets"],
                                           config, num_gpus, "test")
    ds_val, ds_info = get_heptfds_dataset(
        config["validation_datasets"][0],
        config,
        num_gpus,
        "test",
        config["setup"]["num_events_validation"],
        supervised=False,
    )
    ds_val = ds_val.batch(5)

    if ntrain:
        ds_train = ds_train.take(ntrain)
        num_train_steps = ntrain
    if ntest:
        ds_test = ds_test.take(ntest)
        num_test_steps = ntest

    print("num_train_steps", num_train_steps)
    print("num_test_steps", num_test_steps)
    total_steps = num_train_steps * config["setup"]["num_epochs"]
    print("total_steps", total_steps)

    # Without Horovod the model is created inside the strategy scope.
    if not horovod_enabled:
        with strategy.scope():
            model, optim_callbacks, initial_epoch = model_scope(
                config, total_steps, weights)
    else:
        model, optim_callbacks, initial_epoch = model_scope(
            config, total_steps, weights, horovod_enabled)

    callbacks = prepare_callbacks(
        config,
        outdir,
        ds_val,
        comet_experiment=experiment,
        horovod_enabled=config["setup"]["horovod_enabled"])

    verbose = 1
    if horovod_enabled:
        callbacks.append(hvd.callbacks.BroadcastGlobalVariablesCallback(0))
        callbacks.append(hvd.callbacks.MetricAverageCallback())
        # Only rank 0 prints the progress output.
        if hvd.rank() != 0:
            verbose = 0

        # Divide the per-epoch step counts by the number of Horovod processes.
        num_train_steps /= hvd.size()
        num_test_steps /= hvd.size()

    callbacks.append(optim_callbacks)

    final_epoch = initial_epoch + config["setup"]["num_epochs"]
    model.fit(
        ds_train.repeat(),
        validation_data=ds_test.repeat(),
        epochs=final_epoch,
        callbacks=callbacks,
        steps_per_epoch=num_train_steps,
        validation_steps=num_test_steps,
        initial_epoch=initial_epoch,
        verbose=verbose,
    )
Пример #7
0
def find_lr(config, outdir, figname, logscale):
    """Run the learning-rate finder and save a batch loss vs. LR plot.

    The model is fitted for 200 single-step epochs on one element of the
    training dataset while ``LRFinder`` is attached as a callback; the
    resulting plot can be used to pick an appropriate learning-rate range.

    Args:
        config: path to the configuration file.
        outdir: directory handed to ``LRFinder.plot`` as ``save_dir``.
        figname: file name handed to ``LRFinder.plot``.
        logscale: handed to ``LRFinder.plot`` as ``log_scale``.
    """
    config, _ = parse_config(config)

    # Decide tf.distribute.strategy depending on number of available GPUs
    strategy, num_gpus = get_strategy()

    ds_train, _ = get_heptfds_dataset(config["training_dataset"], config,
                                      num_gpus, "train",
                                      config["setup"]["num_events_train"])
    # A single dataset element is enough; it is repeated during fitting.
    ds_train = ds_train.take(1)

    with strategy.scope():
        # This learning rate will be changed by the lr_finder
        opt = tf.keras.optimizers.Adam(learning_rate=1e-7)

        half_precision = config["setup"]["dtype"] == "float16"
        if not half_precision:
            model_dtype = tf.dtypes.float32
        else:
            model_dtype = tf.dtypes.float16
            mixed_precision.set_global_policy(
                mixed_precision.Policy("mixed_float16"))
            opt = mixed_precision.LossScaleOptimizer(opt)

        model = make_model(config, model_dtype)
        config = set_config_loss(config, config["setup"]["trainable"])

        # Run model once to build the layers
        dataset_cfg = config["dataset"]
        model.build((1, dataset_cfg["padded_num_elem_size"],
                     dataset_cfg["num_input_features"]))

        configure_model_weights(model, config["setup"]["trainable"])

        loss_dict, loss_weights = get_loss_dict(config)
        cls_metrics = [
            FlattenedCategoricalAccuracy(name="acc_unweighted",
                                         dtype=tf.float64),
            FlattenedCategoricalAccuracy(use_weights=True,
                                         name="acc_weighted",
                                         dtype=tf.float64),
        ]
        model.compile(
            loss=loss_dict,
            optimizer=opt,
            sample_weight_mode="temporal",
            loss_weights=loss_weights,
            metrics={"cls": cls_metrics},
        )
        model.summary()

        # One optimisation step per epoch, max_steps epochs in total.
        max_steps = 200
        lr_finder = LRFinder(max_steps=max_steps)

        model.fit(
            ds_train.repeat(),
            epochs=max_steps,
            callbacks=[lr_finder],
            steps_per_epoch=1,
        )

        lr_finder.plot(save_dir=outdir, figname=figname, log_scale=logscale)
Пример #8
0
def train(config, weights, ntrain, ntest, nepochs, recreate, prefix, plot_freq,
          customize):
    """Train a model defined by config.

    Args:
        config: path to the configuration file.
        weights: optional weight file to resume from. The initial epoch is
            parsed from the second "-"-separated token of its file name.
        ntrain: when truthy, limits the training dataset to this many
            elements and is used as the number of training steps per epoch.
        ntest: same as ``ntrain`` for the test dataset.
        nepochs: forwarded to ``parse_config``.
        recreate: when truthy (or when ``weights`` is None) a new experiment
            directory is created; otherwise the parent directory of
            ``weights`` is reused as output directory.
        prefix: prefix of the experiment directory name.
        plot_freq: when truthy, overrides ``config["callbacks"]["plot_freq"]``.
        customize: optional key into ``customization_functions``.

    Side effects: writes ``config.yaml``, ``history/history.json`` and
    ``model_full`` into the output directory.
    """
    try:
        from comet_ml import Experiment
        experiment = Experiment(
            project_name="particleflow-tf",
            auto_metric_logging=True,
            auto_param_logging=True,
            auto_histogram_weight_logging=True,
            auto_histogram_gradient_logging=False,
            auto_histogram_activation_logging=False,
        )
    except Exception as e:
        # Training continues without a dashboard; report why it is missing.
        print("Failed to initialize comet-ml dashboard: {}".format(e))
        experiment = None

    config_file_path = config
    config, config_file_stem = parse_config(config,
                                            nepochs=nepochs,
                                            weights=weights)

    if plot_freq:
        config["callbacks"]["plot_freq"] = plot_freq

    if customize:
        config = customization_functions[customize](config)

    if recreate or (weights is None):
        outdir = create_experiment_dir(prefix=prefix + config_file_stem + "_",
                                       suffix=platform.node())
    else:
        outdir = str(Path(weights).parent)

    # Decide tf.distribute.strategy depending on number of available GPUs
    strategy, num_gpus = get_strategy()
    #if "CPU" not in strategy.extended.worker_devices[0]:
    #    nvidia_smi_call = "nvidia-smi --query-gpu=timestamp,name,pci.bus_id,pstate,power.draw,temperature.gpu,utilization.gpu,utilization.memory,memory.total,memory.free,memory.used --format=csv -l 1 -f {}/nvidia_smi_log.csv".format(outdir)
    #    p = subprocess.Popen(shlex.split(nvidia_smi_call))

    ds_train, num_train_steps = get_datasets(config["train_test_datasets"],
                                             config, num_gpus, "train")
    ds_test, num_test_steps = get_datasets(config["train_test_datasets"],
                                           config, num_gpus, "test")
    ds_val, ds_info = get_heptfds_dataset(
        config["validation_dataset"], config, num_gpus, "test",
        config["setup"]["num_events_validation"])
    ds_val = ds_val.batch(5)

    if ntrain:
        ds_train = ds_train.take(ntrain)
        num_train_steps = ntrain
    if ntest:
        ds_test = ds_test.take(ntest)
        num_test_steps = ntest

    print("num_train_steps", num_train_steps)
    print("num_test_steps", num_test_steps)
    total_steps = num_train_steps * config["setup"]["num_epochs"]
    print("total_steps", total_steps)

    if experiment:
        experiment.set_name(outdir)
        experiment.log_code("mlpf/tfmodel/model.py")
        experiment.log_code("mlpf/tfmodel/utils.py")
        experiment.log_code(config_file_path)

    shutil.copy(config_file_path, outdir + "/config.yaml"
                )  # Copy the config file to the train dir for later reference

    with strategy.scope():
        lr_schedule, optim_callbacks = get_lr_schedule(config,
                                                       steps=total_steps)
        opt = get_optimizer(config, lr_schedule)

        if config["setup"]["dtype"] == "float16":
            model_dtype = tf.dtypes.float16
            policy = mixed_precision.Policy("mixed_float16")
            mixed_precision.set_global_policy(policy)
            opt = mixed_precision.LossScaleOptimizer(opt)
        else:
            model_dtype = tf.dtypes.float32

        model = make_model(config, model_dtype)

        # Build the layers after the element and feature dimensions are specified
        model.build((1, config["dataset"]["padded_num_elem_size"],
                     config["dataset"]["num_input_features"]))

        initial_epoch = 0
        if weights:
            # We need to load the weights in the same trainable configuration as the model was set up
            configure_model_weights(
                model, config["setup"].get("weights_config", "all"))
            model.load_weights(weights, by_name=True)
            initial_epoch = int(weights.split("/")[-1].split("-")[1])
        model.build((1, config["dataset"]["padded_num_elem_size"],
                     config["dataset"]["num_input_features"]))

        # Switch to the trainable configuration requested for this run and
        # build once more afterwards.
        config = set_config_loss(config, config["setup"]["trainable"])
        configure_model_weights(model, config["setup"]["trainable"])
        model.build((1, config["dataset"]["padded_num_elem_size"],
                     config["dataset"]["num_input_features"]))

        print("model weights")
        tw_names = [m.name for m in model.trainable_weights]
        for w in model.weights:
            print("layer={} trainable={} shape={} num_weights={}".format(
                w.name, w.name in tw_names, w.shape, np.prod(w.shape)))

        loss_dict, loss_weights = get_loss_dict(config)
        model.compile(
            loss=loss_dict,
            optimizer=opt,
            sample_weight_mode="temporal",
            loss_weights=loss_weights,
            metrics={
                "cls": [
                    FlattenedCategoricalAccuracy(name="acc_unweighted",
                                                 dtype=tf.float64),
                    FlattenedCategoricalAccuracy(use_weights=True,
                                                 name="acc_weighted",
                                                 dtype=tf.float64),
                ] + [
                    # One recall metric per output class.
                    SingleClassRecall(
                        icls, name="rec_cls{}".format(icls), dtype=tf.float64)
                    for icls in range(config["dataset"]["num_output_classes"])
                ]
            },
        )
        model.summary()

    callbacks = prepare_callbacks(config["callbacks"],
                                  outdir,
                                  ds_val,
                                  ds_info,
                                  comet_experiment=experiment)
    callbacks.append(optim_callbacks)

    fit_result = model.fit(
        ds_train.repeat(),
        validation_data=ds_test.repeat(),
        epochs=initial_epoch + config["setup"]["num_epochs"],
        callbacks=callbacks,
        steps_per_epoch=num_train_steps,
        validation_steps=num_test_steps,
        initial_epoch=initial_epoch,
    )

    # Make sure the history directory exists before writing into it, so a
    # finished training run cannot fail on a missing directory.
    history_path = Path(outdir) / "history"
    history_path.mkdir(parents=True, exist_ok=True)
    history_path = str(history_path)
    with open("{}/history.json".format(history_path), "w") as fi:
        json.dump(fit_result.history, fi)

    # Export the full model using the best checkpoint found in outdir.
    weights = get_best_checkpoint(outdir)
    print("Loading best weights that could be found from {}".format(weights))
    model.load_weights(weights, by_name=True)

    model.save(outdir + "/model_full", save_format="tf")

    print("Training done.")