Esempio n. 1
0
def evaluate(config, train_dir, weights, customize, nevents):
    """Evaluate the trained model in train_dir"""
    if config is None:
        config = Path(train_dir) / "config.yaml"
        assert config.exists(
        ), "Could not find config file in train_dir, please provide one with -c <path/to/config>"
    config, _ = parse_config(config, weights=weights)

    if customize:
        config = customization_functions[customize](config)

    if config["setup"]["dtype"] == "float16":
        model_dtype = tf.dtypes.float16
        policy = mixed_precision.Policy("mixed_float16")
        mixed_precision.set_global_policy(policy)
    else:
        model_dtype = tf.dtypes.float32

    strategy, num_gpus = get_strategy()
    # physical_devices = tf.config.list_physical_devices('GPU')
    # for dev in physical_devices:
    #    tf.config.experimental.set_memory_growth(dev, True)

    model = make_model(config, model_dtype)
    model.build((1, config["dataset"]["padded_num_elem_size"],
                 config["dataset"]["num_input_features"]))

    # need to load the weights in the same trainable configuration as the model was set up
    configure_model_weights(model, config["setup"].get("weights_config",
                                                       "all"))
    if weights:
        model.load_weights(weights, by_name=True)
    else:
        weights = get_best_checkpoint(train_dir)
        print(
            "Loading best weights that could be found from {}".format(weights))
        model.load_weights(weights, by_name=True)

    iepoch = int(weights.split("/")[-1].split("-")[1])

    for dsname in config["validation_datasets"]:
        ds_test, _ = get_heptfds_dataset(dsname,
                                         config,
                                         num_gpus,
                                         "test",
                                         supervised=False)
        if nevents:
            ds_test = ds_test.take(nevents)
        ds_test = ds_test.batch(5)
        eval_dir = str(
            Path(train_dir) / "evaluation" / "epoch_{}".format(iepoch) /
            dsname)
        Path(eval_dir).mkdir(parents=True, exist_ok=True)
        eval_model(model, ds_test, config, eval_dir)

    freeze_model(model, config, train_dir)
Esempio n. 2
0
def evaluate(config, train_dir, weights, evaluation_dir):
    """Evaluate the trained model in train_dir"""
    if config is None:
        config = Path(train_dir) / "config.yaml"
        assert config.exists(
        ), "Could not find config file in train_dir, please provide one with -c <path/to/config>"
    config, _ = parse_config(config, weights=weights)

    if evaluation_dir is None:
        eval_dir = str(Path(train_dir) / "evaluation")
    else:
        eval_dir = evaluation_dir

    Path(eval_dir).mkdir(parents=True, exist_ok=True)

    if config["setup"]["dtype"] == "float16":
        model_dtype = tf.dtypes.float16
        policy = mixed_precision.Policy("mixed_float16")
        mixed_precision.set_global_policy(policy)
        opt = mixed_precision.LossScaleOptimizer(opt)
    else:
        model_dtype = tf.dtypes.float32

    strategy, num_gpus = get_strategy()
    ds_test, _ = get_heptfds_dataset(config["validation_dataset"], config,
                                     num_gpus, "test")
    ds_test = ds_test.batch(5)

    model = make_model(config, model_dtype)
    model.build((1, config["dataset"]["padded_num_elem_size"],
                 config["dataset"]["num_input_features"]))

    # need to load the weights in the same trainable configuration as the model was set up
    configure_model_weights(model, config["setup"].get("weights_config",
                                                       "all"))
    if weights:
        model.load_weights(weights, by_name=True)
    else:
        weights = get_best_checkpoint(train_dir)
        print(
            "Loading best weights that could be found from {}".format(weights))
        model.load_weights(weights, by_name=True)

    eval_model(model, ds_test, config, eval_dir)
    freeze_model(model, config, ds_test.take(1), train_dir)
Esempio n. 3
0
def build_model_and_train(config,
                          checkpoint_dir=None,
                          full_config=None,
                          ntrain=None,
                          ntest=None,
                          name=None,
                          seeds=False):
    from ray import tune
    from ray.tune.integration.keras import TuneReportCheckpointCallback
    from raytune.search_space import set_raytune_search_parameters

    if seeds:
        # Set seeds for reproducibility
        random.seed(1234)
        np.random.seed(1234)
        tf.random.set_seed(1234)

    full_config, config_file_stem = parse_config(full_config)

    if config is not None:
        full_config = set_raytune_search_parameters(search_space=config,
                                                    config=full_config)

    strategy, num_gpus = get_strategy()

    ds_train, num_train_steps = get_datasets(
        full_config["train_test_datasets"], full_config, num_gpus, "train")
    ds_test, num_test_steps = get_datasets(full_config["train_test_datasets"],
                                           full_config, num_gpus, "test")
    ds_val, ds_info = get_heptfds_dataset(
        full_config["validation_datasets"][0],
        full_config,
        num_gpus,
        "test",
        full_config["setup"]["num_events_validation"],
        supervised=False,
    )
    ds_val = ds_val.batch(5)

    if ntrain:
        ds_train = ds_train.take(ntrain)
        num_train_steps = ntrain
    if ntest:
        ds_test = ds_test.take(ntest)
        num_test_steps = ntest

    print("num_train_steps", num_train_steps)
    print("num_test_steps", num_test_steps)
    total_steps = num_train_steps * full_config["setup"]["num_epochs"]
    print("total_steps", total_steps)

    callbacks = prepare_callbacks(
        full_config,
        tune.get_trial_dir(),
        ds_val,
    )

    callbacks = callbacks[:
                          -1]  # remove the CustomCallback at the end of the list

    with strategy.scope():
        lr_schedule, optim_callbacks = get_lr_schedule(full_config,
                                                       steps=total_steps)
        callbacks.append(optim_callbacks)
        opt = get_optimizer(full_config, lr_schedule)

        model = make_model(full_config, dtype=tf.dtypes.float32)

        # Run model once to build the layers
        model.build((1, full_config["dataset"]["padded_num_elem_size"],
                     full_config["dataset"]["num_input_features"]))

        full_config = set_config_loss(full_config,
                                      full_config["setup"]["trainable"])
        configure_model_weights(model, full_config["setup"]["trainable"])
        model.build((1, full_config["dataset"]["padded_num_elem_size"],
                     full_config["dataset"]["num_input_features"]))

        loss_dict, loss_weights = get_loss_dict(full_config)
        model.compile(
            loss=loss_dict,
            optimizer=opt,
            sample_weight_mode="temporal",
            loss_weights=loss_weights,
            metrics={
                "cls": [
                    FlattenedCategoricalAccuracy(name="acc_unweighted",
                                                 dtype=tf.float64),
                    FlattenedCategoricalAccuracy(use_weights=True,
                                                 name="acc_weighted",
                                                 dtype=tf.float64),
                ]
            },
        )
        model.summary()

        callbacks.append(
            TuneReportCheckpointCallback(metrics=[
                "adam_beta_1",
                "charge_loss",
                "cls_acc_unweighted",
                "cls_loss",
                "cos_phi_loss",
                "energy_loss",
                "eta_loss",
                "learning_rate",
                "loss",
                "pt_loss",
                "sin_phi_loss",
                "val_charge_loss",
                "val_cls_acc_unweighted",
                "val_cls_acc_weighted",
                "val_cls_loss",
                "val_cos_phi_loss",
                "val_energy_loss",
                "val_eta_loss",
                "val_loss",
                "val_pt_loss",
                "val_sin_phi_loss",
            ], ), )

        try:
            model.fit(
                ds_train.repeat(),
                validation_data=ds_test.repeat(),
                epochs=full_config["setup"]["num_epochs"],
                callbacks=callbacks,
                steps_per_epoch=num_train_steps,
                validation_steps=num_test_steps,
            )
        except tf.errors.ResourceExhaustedError:
            logging.warning(
                "Resource exhausted, skipping this hyperparameter configuration."
            )
            skiplog_file_path = Path(full_config["raytune"]["local_dir"]
                                     ) / name / "skipped_configurations.txt"
            lines = [
                "{}: {}\n".format(item[0], item[1]) for item in config.items()
            ]

            with open(skiplog_file_path, "a") as f:
                f.write("#" * 80 + "\n")
                for line in lines:
                    f.write(line)
                    logging.warning(line[:-1])
                f.write("#" * 80 + "\n\n")
Esempio n. 4
0
def hypertune(config, outdir, ntrain, ntest, recreate):
    config_file_path = config
    config, _ = parse_config(config, ntrain=ntrain, ntest=ntest)

    # Override number of epochs with max_epochs from Hyperband config if specified
    if config["hypertune"]["algorithm"] == "hyperband":
        config["setup"]["num_epochs"] = config["hypertune"]["hyperband"][
            "max_epochs"]

    strategy, num_gpus = get_strategy()

    ds_train, ds_info = get_heptfds_dataset(
        config["training_dataset"], config, num_gpus, "train",
        config["setup"]["num_events_train"])
    ds_test, _ = get_heptfds_dataset(config["testing_dataset"], config,
                                     num_gpus, "test",
                                     config["setup"]["num_events_test"])
    ds_val, _ = get_heptfds_dataset(
        config["validation_datasets"][0],
        config,
        num_gpus,
        "test",
        config["setup"]["num_events_validation"],
        supervised=False,
    )
    ds_val = ds_val.batch(5)

    num_train_steps = 0
    for _ in ds_train:
        num_train_steps += 1
    num_test_steps = 0
    for _ in ds_test:
        num_test_steps += 1

    model_builder, optim_callbacks = hypertuning.get_model_builder(
        config, num_train_steps)

    callbacks = prepare_callbacks(
        config,
        outdir,
        ds_val,
    )

    callbacks.append(optim_callbacks)
    callbacks.append(
        tf.keras.callbacks.EarlyStopping(patience=20, monitor="val_loss"))

    tuner = get_tuner(config["hypertune"], model_builder, outdir, recreate,
                      strategy)
    tuner.search_space_summary()

    tuner.search(
        ds_train.repeat(),
        epochs=config["setup"]["num_epochs"],
        validation_data=ds_test.repeat(),
        steps_per_epoch=num_train_steps,
        validation_steps=num_test_steps,
        callbacks=callbacks,
    )
    print("Hyperparameter search complete.")
    shutil.copy(config_file_path, outdir + "/config.yaml"
                )  # Copy the config file to the train dir for later reference

    tuner.results_summary()
    for trial in tuner.oracle.get_best_trials(num_trials=10):
        print(trial.hyperparameters.values, trial.score)
Esempio n. 5
0
def find_lr(config, outdir, figname, logscale):
    """Run the Learning Rate Finder to produce a batch loss vs. LR plot from
    which an appropriate LR-range can be determined"""
    config, _ = parse_config(config)

    # Decide tf.distribute.strategy depending on number of available GPUs
    strategy, num_gpus = get_strategy()

    ds_train, num_train_steps = get_datasets(config["train_test_datasets"],
                                             config, num_gpus, "train")

    with strategy.scope():
        opt = tf.keras.optimizers.Adam(
            learning_rate=1e-7
        )  # This learning rate will be changed by the lr_finder
        if config["setup"]["dtype"] == "float16":
            model_dtype = tf.dtypes.float16
            policy = mixed_precision.Policy("mixed_float16")
            mixed_precision.set_global_policy(policy)
            opt = mixed_precision.LossScaleOptimizer(opt)
        else:
            model_dtype = tf.dtypes.float32

        model = make_model(config, model_dtype)
        config = set_config_loss(config, config["setup"]["trainable"])

        # Run model once to build the layers
        model.build((1, config["dataset"]["padded_num_elem_size"],
                     config["dataset"]["num_input_features"]))

        configure_model_weights(model, config["setup"]["trainable"])

        loss_dict, loss_weights = get_loss_dict(config)
        model.compile(
            loss=loss_dict,
            optimizer=opt,
            sample_weight_mode="temporal",
            loss_weights=loss_weights,
            metrics={
                "cls": [
                    FlattenedCategoricalAccuracy(name="acc_unweighted",
                                                 dtype=tf.float64),
                    FlattenedCategoricalAccuracy(use_weights=True,
                                                 name="acc_weighted",
                                                 dtype=tf.float64),
                ]
            },
        )
        model.summary()

        max_steps = 200
        lr_finder = LRFinder(max_steps=max_steps)
        callbacks = [lr_finder]

        model.fit(
            ds_train.repeat(),
            epochs=max_steps,
            callbacks=callbacks,
            steps_per_epoch=1,
        )

        lr_finder.plot(save_dir=outdir, figname=figname, log_scale=logscale)
Esempio n. 6
0
def compute_validation_loss(config, train_dir, weights):
    """Evaluate the trained model in train_dir"""
    if config is None:
        config = Path(train_dir) / "config.yaml"
        assert config.exists(
        ), "Could not find config file in train_dir, please provide one with -c <path/to/config>"
    config, _ = parse_config(config, weights=weights)

    if config["setup"]["dtype"] == "float16":
        model_dtype = tf.dtypes.float16
        policy = mixed_precision.Policy("mixed_float16")
        mixed_precision.set_global_policy(policy)
    else:
        model_dtype = tf.dtypes.float32

    strategy, num_gpus = get_strategy()
    ds_test, num_test_steps = get_datasets(config["train_test_datasets"],
                                           config, num_gpus, "test")

    with strategy.scope():
        model = make_model(config, model_dtype)
        model.build((1, config["dataset"]["padded_num_elem_size"],
                     config["dataset"]["num_input_features"]))

        # need to load the weights in the same trainable configuration as the model was set up
        configure_model_weights(model,
                                config["setup"].get("weights_config", "all"))
        if weights:
            model.load_weights(weights, by_name=True)
        else:
            weights = get_best_checkpoint(train_dir)
            print("Loading best weights that could be found from {}".format(
                weights))
            model.load_weights(weights, by_name=True)

        loss_dict, loss_weights = get_loss_dict(config)
        model.compile(
            loss=loss_dict,
            # sample_weight_mode="temporal",
            loss_weights=loss_weights,
            metrics={
                "cls": [
                    FlattenedCategoricalAccuracy(name="acc_unweighted",
                                                 dtype=tf.float64),
                    FlattenedCategoricalAccuracy(use_weights=True,
                                                 name="acc_weighted",
                                                 dtype=tf.float64),
                ] + [
                    SingleClassRecall(
                        icls, name="rec_cls{}".format(icls), dtype=tf.float64)
                    for icls in range(config["dataset"]["num_output_classes"])
                ]
            },
        )

        losses = model.evaluate(
            x=ds_test,
            steps=num_test_steps,
            return_dict=True,
        )
    with open("{}/losses.txt".format(train_dir), "w") as loss_file:
        loss_file.write(json.dumps(losses) + "\n")
Esempio n. 7
0
def train(config, weights, ntrain, ntest, nepochs, recreate, prefix, plot_freq,
          customize, comet_offline):

    # tf.debugging.enable_check_numerics()
    """Train a model defined by config"""
    config_file_path = config
    config, config_file_stem = parse_config(config,
                                            nepochs=nepochs,
                                            weights=weights)

    if plot_freq:
        config["callbacks"]["plot_freq"] = plot_freq

    if customize:
        config = customization_functions[customize](config)

    # Decide tf.distribute.strategy depending on number of available GPUs
    horovod_enabled = config["setup"]["horovod_enabled"]
    if horovod_enabled:
        num_gpus = initialize_horovod()
    else:
        strategy, num_gpus = get_strategy()

    outdir = ""
    if not horovod_enabled or hvd.rank() == 0:
        outdir = create_experiment_dir(prefix=prefix + config_file_stem + "_",
                                       suffix=platform.node())
        shutil.copy(
            config_file_path, outdir + "/config.yaml"
        )  # Copy the config file to the train dir for later reference

    try:
        if comet_offline:
            print("Using comet-ml OfflineExperiment, saving logs locally.")
            from comet_ml import OfflineExperiment

            experiment = OfflineExperiment(
                project_name="particleflow-tf",
                auto_metric_logging=True,
                auto_param_logging=True,
                auto_histogram_weight_logging=True,
                auto_histogram_gradient_logging=False,
                auto_histogram_activation_logging=False,
                offline_directory=outdir + "/cometml",
            )
        else:
            print("Using comet-ml Experiment, streaming logs to www.comet.ml.")
            from comet_ml import Experiment

            experiment = Experiment(
                project_name="particleflow-tf",
                auto_metric_logging=True,
                auto_param_logging=True,
                auto_histogram_weight_logging=True,
                auto_histogram_gradient_logging=False,
                auto_histogram_activation_logging=False,
            )
    except Exception as e:
        print("Failed to initialize comet-ml dashboard: {}".format(e))
        experiment = None
    if experiment:
        experiment.set_name(outdir)
        experiment.log_code("mlpf/tfmodel/model.py")
        experiment.log_code("mlpf/tfmodel/utils.py")
        experiment.log_code(config_file_path)

    ds_train, num_train_steps = get_datasets(config["train_test_datasets"],
                                             config, num_gpus, "train")
    ds_test, num_test_steps = get_datasets(config["train_test_datasets"],
                                           config, num_gpus, "test")
    ds_val, ds_info = get_heptfds_dataset(
        config["validation_datasets"][0],
        config,
        num_gpus,
        "test",
        config["setup"]["num_events_validation"],
        supervised=False,
    )
    ds_val = ds_val.batch(5)

    if ntrain:
        ds_train = ds_train.take(ntrain)
        num_train_steps = ntrain
    if ntest:
        ds_test = ds_test.take(ntest)
        num_test_steps = ntest

    print("num_train_steps", num_train_steps)
    print("num_test_steps", num_test_steps)
    total_steps = num_train_steps * config["setup"]["num_epochs"]
    print("total_steps", total_steps)

    if horovod_enabled:
        model, optim_callbacks, initial_epoch = model_scope(
            config, total_steps, weights, horovod_enabled)
    else:
        with strategy.scope():
            model, optim_callbacks, initial_epoch = model_scope(
                config, total_steps, weights)

    callbacks = prepare_callbacks(
        config,
        outdir,
        ds_val,
        comet_experiment=experiment,
        horovod_enabled=config["setup"]["horovod_enabled"])

    verbose = 1
    if horovod_enabled:
        callbacks.append(hvd.callbacks.BroadcastGlobalVariablesCallback(0))
        callbacks.append(hvd.callbacks.MetricAverageCallback())
        verbose = 1 if hvd.rank() == 0 else 0

        num_train_steps /= hvd.size()
        num_test_steps /= hvd.size()

    callbacks.append(optim_callbacks)

    model.fit(
        ds_train.repeat(),
        validation_data=ds_test.repeat(),
        epochs=initial_epoch + config["setup"]["num_epochs"],
        callbacks=callbacks,
        steps_per_epoch=num_train_steps,
        validation_steps=num_test_steps,
        initial_epoch=initial_epoch,
        verbose=verbose,
    )
Esempio n. 8
0
def debug_data(config, ntrain, ntest):
    """Train a model defined by config"""
    config, config_file_stem, global_batch_size, n_train, n_test, n_epochs, weights = parse_config(
        config,
        ntrain,
        ntest,
        weights=None,
    )

    dataset_def = get_dataset_def(config)
    ds_train, ds_test, dataset_transform = get_train_val_datasets(
        config, global_batch_size=1, n_train=n_train, n_test=n_test)

    # cand_counts = np.zeros(8)
    # for data_item in tqdm(ds_train, desc="Counting"):
    #     import pdb; pdb.set_trace()
    #     cand_vals, cand_count = np.unique(np.argmax(data_item[1]['cls'], axis=2), return_counts=True)
    #     cand_counts[cand_vals.astype("int32")] += cand_count
    # print("cand_counts: ", cand_counts)

    dsf = CMSDatasetFactory(config)
    ds_train, _ = dsf.get_dataset(split="train")
    ds_test, _ = dsf.get_dataset(split="test")
    for data_item in tqdm(ds_train, desc="Counting"):
        import pdb
        pdb.set_trace()
Esempio n. 9
0
def train(config, weights, ntrain, ntest, nepochs, recreate, prefix, plot_freq,
          customize):

    try:
        from comet_ml import Experiment
        experiment = Experiment(
            project_name="particleflow-tf",
            auto_metric_logging=True,
            auto_param_logging=True,
            auto_histogram_weight_logging=True,
            auto_histogram_gradient_logging=False,
            auto_histogram_activation_logging=False,
        )
    except Exception as e:
        print("Failed to initialize comet-ml dashboard")
        experiment = None
    """Train a model defined by config"""
    config_file_path = config
    config, config_file_stem = parse_config(config,
                                            nepochs=nepochs,
                                            weights=weights)

    if plot_freq:
        config["callbacks"]["plot_freq"] = plot_freq

    if customize:
        config = customization_functions[customize](config)

    if recreate or (weights is None):
        outdir = create_experiment_dir(prefix=prefix + config_file_stem + "_",
                                       suffix=platform.node())
    else:
        outdir = str(Path(weights).parent)

    # Decide tf.distribute.strategy depending on number of available GPUs
    strategy, num_gpus = get_strategy()
    #if "CPU" not in strategy.extended.worker_devices[0]:
    #    nvidia_smi_call = "nvidia-smi --query-gpu=timestamp,name,pci.bus_id,pstate,power.draw,temperature.gpu,utilization.gpu,utilization.memory,memory.total,memory.free,memory.used --format=csv -l 1 -f {}/nvidia_smi_log.csv".format(outdir)
    #    p = subprocess.Popen(shlex.split(nvidia_smi_call))

    ds_train, num_train_steps = get_datasets(config["train_test_datasets"],
                                             config, num_gpus, "train")
    ds_test, num_test_steps = get_datasets(config["train_test_datasets"],
                                           config, num_gpus, "test")
    ds_val, ds_info = get_heptfds_dataset(
        config["validation_dataset"], config, num_gpus, "test",
        config["setup"]["num_events_validation"])
    ds_val = ds_val.batch(5)

    if ntrain:
        ds_train = ds_train.take(ntrain)
        num_train_steps = ntrain
    if ntest:
        ds_test = ds_test.take(ntest)
        num_test_steps = ntest

    print("num_train_steps", num_train_steps)
    print("num_test_steps", num_test_steps)
    total_steps = num_train_steps * config["setup"]["num_epochs"]
    print("total_steps", total_steps)

    if experiment:
        experiment.set_name(outdir)
        experiment.log_code("mlpf/tfmodel/model.py")
        experiment.log_code("mlpf/tfmodel/utils.py")
        experiment.log_code(config_file_path)

    shutil.copy(config_file_path, outdir + "/config.yaml"
                )  # Copy the config file to the train dir for later reference

    with strategy.scope():
        lr_schedule, optim_callbacks = get_lr_schedule(config,
                                                       steps=total_steps)
        opt = get_optimizer(config, lr_schedule)

        if config["setup"]["dtype"] == "float16":
            model_dtype = tf.dtypes.float16
            policy = mixed_precision.Policy("mixed_float16")
            mixed_precision.set_global_policy(policy)
            opt = mixed_precision.LossScaleOptimizer(opt)
        else:
            model_dtype = tf.dtypes.float32

        model = make_model(config, model_dtype)

        # Build the layers after the element and feature dimensions are specified
        model.build((1, config["dataset"]["padded_num_elem_size"],
                     config["dataset"]["num_input_features"]))

        initial_epoch = 0
        if weights:
            # We need to load the weights in the same trainable configuration as the model was set up
            configure_model_weights(
                model, config["setup"].get("weights_config", "all"))
            model.load_weights(weights, by_name=True)
            initial_epoch = int(weights.split("/")[-1].split("-")[1])
        model.build((1, config["dataset"]["padded_num_elem_size"],
                     config["dataset"]["num_input_features"]))

        config = set_config_loss(config, config["setup"]["trainable"])
        configure_model_weights(model, config["setup"]["trainable"])
        model.build((1, config["dataset"]["padded_num_elem_size"],
                     config["dataset"]["num_input_features"]))

        print("model weights")
        tw_names = [m.name for m in model.trainable_weights]
        for w in model.weights:
            print("layer={} trainable={} shape={} num_weights={}".format(
                w.name, w.name in tw_names, w.shape, np.prod(w.shape)))

        loss_dict, loss_weights = get_loss_dict(config)
        model.compile(
            loss=loss_dict,
            optimizer=opt,
            sample_weight_mode="temporal",
            loss_weights=loss_weights,
            metrics={
                "cls": [
                    FlattenedCategoricalAccuracy(name="acc_unweighted",
                                                 dtype=tf.float64),
                    FlattenedCategoricalAccuracy(use_weights=True,
                                                 name="acc_weighted",
                                                 dtype=tf.float64),
                ] + [
                    SingleClassRecall(
                        icls, name="rec_cls{}".format(icls), dtype=tf.float64)
                    for icls in range(config["dataset"]["num_output_classes"])
                ]
            },
        )
        model.summary()

    callbacks = prepare_callbacks(config["callbacks"],
                                  outdir,
                                  ds_val,
                                  ds_info,
                                  comet_experiment=experiment)
    callbacks.append(optim_callbacks)

    fit_result = model.fit(
        ds_train.repeat(),
        validation_data=ds_test.repeat(),
        epochs=initial_epoch + config["setup"]["num_epochs"],
        callbacks=callbacks,
        steps_per_epoch=num_train_steps,
        validation_steps=num_test_steps,
        initial_epoch=initial_epoch,
    )

    history_path = Path(outdir) / "history"
    history_path = str(history_path)
    with open("{}/history.json".format(history_path), "w") as fi:
        json.dump(fit_result.history, fi)

    weights = get_best_checkpoint(outdir)
    print("Loading best weights that could be found from {}".format(weights))
    model.load_weights(weights, by_name=True)

    model.save(outdir + "/model_full", save_format="tf")

    print("Training done.")