def train_model(
    trial_dir,
    save_all=False,
    include_collections=None,
    reduction_config=None,
    save_config=None,
    use_tf_keras=True,
    hook=None,
    eager=False,
    use_keras_optimizer=True,
    create_relu_collection=False,
    steps=None,
    add_callbacks=None,
):
    """Train a small MNIST classifier with a debugging hook attached.

    The model is ``Flatten -> Dense(128, relu) -> Dropout(0.2) ->
    Dense(10, softmax)``. Its optimizer is wrapped by the hook, the hook is
    appended to the callback list after any optional callbacks, and every
    entry of ``steps`` triggers one ``fit`` / ``evaluate`` / ``predict`` call.
    The hook is cleaned up before returning, including when one of those
    calls raises.

    Args:
        trial_dir: First positional argument given to ``KerasHook`` when
            ``hook`` is None; unused otherwise.
        save_all: Forwarded to ``KerasHook``. When false and
            ``include_collections`` is given, every collection the hook
            includes that the caller did not list gets
            ``SaveConfig(end_step=0)``.
        include_collections: Forwarded to ``KerasHook``.
        reduction_config: Forwarded to ``KerasHook``.
        save_config: Forwarded to ``KerasHook``; defaults to
            ``SaveConfig(save_interval=3)``.
        use_tf_keras: Import ``keras`` from ``tensorflow`` when true, the
            standalone ``keras`` package otherwise.
        hook: Pre-built hook. When given, none of the hook-construction
            arguments above are used.
        eager: Passed as ``run_eagerly`` to ``compile``; only used when
            ``use_tf_keras`` is true.
        use_keras_optimizer: Use ``keras.optimizers.RMSprop()`` when true,
            ``tf.train.RMSPropOptimizer(0.1)`` otherwise.
        create_relu_collection: When true, register the 128-unit dense layer
            (inputs and outputs) with the hook's ``"relu"`` collection.
        steps: Iterable of phase names. ``"train"``, ``"eval"`` and
            ``"predict"`` are recognised; any other value is ignored.
            Defaults to ``["train"]``.
        add_callbacks: Optional container of callback names.
            ``"tensorboard"`` adds a TensorBoard callback logging to
            ``/tmp/logs``; ``"fetch_tensor"`` adds a ``FetchTensorCallback``
            over the model outputs and weights.
    """
    if use_tf_keras:
        from tensorflow import keras
    else:
        import keras

    mnist = keras.datasets.mnist

    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    # Scale pixel values from [0, 255] down to [0, 1].
    x_train, x_test = x_train / 255.0, x_test / 255.0

    # Kept in a variable so it can be registered with the "relu" collection.
    relu_layer = keras.layers.Dense(128, activation="relu")

    model = keras.models.Sequential([
        keras.layers.Flatten(input_shape=(28, 28)),
        relu_layer,
        keras.layers.Dropout(0.2),
        keras.layers.Dense(10, activation="softmax"),
    ])

    if hook is None:
        if save_config is None:
            save_config = SaveConfig(save_interval=3)

        hook = KerasHook(
            trial_dir,
            save_config=save_config,
            save_all=save_all,
            include_collections=include_collections,
            reduction_config=reduction_config,
        )

        if not save_all and include_collections is not None:
            # Collections the hook includes on its own, but which the caller
            # did not ask for, get a save config that ends at step 0.
            for cname in hook.include_collections:
                if cname not in include_collections:
                    hook.get_collection(cname).save_config = SaveConfig(
                        end_step=0)

    if create_relu_collection:
        hook.get_collection("relu").add_keras_layer(relu_layer,
                                                    inputs=True,
                                                    outputs=True)

    if use_keras_optimizer:
        opt = keras.optimizers.RMSprop()
    else:
        opt = tf.train.RMSPropOptimizer(0.1)

    opt = hook.wrap_optimizer(opt)

    if use_tf_keras:
        model.compile(
            optimizer=opt,
            loss="sparse_categorical_crossentropy",
            run_eagerly=eager,
            metrics=["accuracy"],
        )
    else:
        # run_eagerly is only passed on the tf.keras path.
        model.compile(optimizer=opt,
                      loss="sparse_categorical_crossentropy",
                      metrics=["accuracy"])

    hooks = []
    if add_callbacks:
        if "tensorboard" in add_callbacks:
            hooks.append(
                tf.keras.callbacks.TensorBoard(log_dir="/tmp/logs",
                                               histogram_freq=1,
                                               write_grads=True,
                                               write_images=True))
        if "fetch_tensor" in add_callbacks:
            hooks.append(FetchTensorCallback(model.outputs + model.weights))
    # The debugging hook always goes last in the callback list.
    hooks.append(hook)

    if steps is None:
        steps = ["train"]
    try:
        for step in steps:
            if step == "train":
                model.fit(x_train,
                          y_train,
                          epochs=1,
                          steps_per_epoch=10,
                          callbacks=hooks,
                          verbose=0)
            elif step == "eval":
                model.evaluate(x_test,
                               y_test,
                               steps=10,
                               callbacks=hooks,
                               verbose=0)
            elif step == "predict":
                model.predict(x_test[:100], callbacks=hooks, verbose=0)
    finally:
        # Release the hook even when a phase raises, so a failed run does not
        # leave it open for whatever runs next in the same process.
        hook._cleanup()
def train_model(
    trial_dir,
    save_all=False,
    hook=None,
    include_collections=None,
    reduction_config=None,
    save_config=None,
    eager=True,
    strategy=None,
    steps=None,
    add_callbacks=None,
    include_workers="all",
):
    """Train a small MNIST convnet under a distribution strategy with a hook.

    Loads MNIST through ``tfds``, builds and compiles a
    ``Conv2D -> MaxPooling2D -> Flatten -> Dense(64, relu) ->
    Dense(10, softmax)`` model inside ``strategy.scope()``, saves one scalar
    named ``"foobar"`` through the hook, then runs one ``fit`` / ``evaluate``
    / ``predict`` call per entry of ``steps``. The hook returned by
    ``smd.get_hook()`` is closed before returning, including when one of
    those calls raises.

    Args:
        trial_dir: Passed as ``out_dir`` to ``KerasHook`` when ``hook`` is
            None; unused otherwise.
        save_all: Forwarded to ``KerasHook``. When false and
            ``include_collections`` is given, every collection the hook
            includes that the caller did not list gets
            ``SaveConfig(end_step=0)``.
        hook: Pre-built hook. When given, none of the hook-construction
            arguments are used.
        include_collections: Forwarded to ``KerasHook``.
        reduction_config: Forwarded to ``KerasHook``.
        save_config: Forwarded to ``KerasHook``; defaults to
            ``SaveConfig(save_interval=3)``.
        eager: When false, ``tf.compat.v1.disable_eager_execution()`` is
            called before anything else is built.
        strategy: Distribution strategy to use; defaults to a new
            ``tf.distribute.MirroredStrategy()``.
        steps: Iterable of phase names. ``"train"``, ``"eval"`` and
            ``"predict"`` are recognised; any other value is ignored.
            Defaults to ``["train"]``.
        add_callbacks: Optional container of callback names; only
            ``"tensorboard"`` is recognised and adds a TensorBoard callback
            logging to ``/tmp/logs``.
        include_workers: Forwarded to ``KerasHook``.

    Returns:
        A ``(strategy, scalars_to_be_saved)`` tuple. The dict maps
        ``"scalar/foobar"`` to ``(timestamp, steps)``, where ``steps`` is the
        value the caller passed (``None`` when the default was used).
    """
    tf.keras.backend.clear_session()
    if not eager:
        tf.compat.v1.disable_eager_execution()

    datasets, info = tfds.load(name="mnist",
                               with_info=True,
                               as_supervised=True)

    mnist_train, mnist_test = datasets["train"], datasets["test"]

    if strategy is None:
        strategy = tf.distribute.MirroredStrategy()

    # You can also do info.splits.total_num_examples to get the total
    # number of examples in the dataset.

    BUFFER_SIZE = 10000

    # The global batch size grows with the number of replicas in sync.
    BATCH_SIZE_PER_REPLICA = 64
    BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync

    def scale(image, label):
        """Cast the image to float32 and divide it by 255."""
        image = tf.cast(image, tf.float32)
        image /= 255

        return image, label

    train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(
        BATCH_SIZE)
    eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)

    if hook is None:
        if save_config is None:
            save_config = SaveConfig(save_interval=3)

        hook = KerasHook(
            out_dir=trial_dir,
            save_config=save_config,
            reduction_config=reduction_config,
            include_collections=include_collections,
            save_all=save_all,
            include_workers=include_workers,
        )

        if not save_all and include_collections is not None:
            # Collections the hook includes on its own, but which the caller
            # did not ask for, get a save config that ends at step 0.
            for cname in hook.include_collections:
                if cname not in include_collections:
                    hook.get_collection(cname).save_config = SaveConfig(
                        end_step=0)

    opt = tf.keras.optimizers.Adam()

    opt = hook.wrap_optimizer(opt)

    with strategy.scope():
        relu_layer = tf.keras.layers.Dense(64, activation="relu")
        model = tf.keras.Sequential([
            tf.keras.layers.Conv2D(32,
                                   3,
                                   activation="relu",
                                   input_shape=(28, 28, 1)),
            tf.keras.layers.MaxPooling2D(),
            tf.keras.layers.Flatten(),
            relu_layer,
            tf.keras.layers.Dense(10, activation="softmax"),
        ])
        model.compile(loss="sparse_categorical_crossentropy",
                      optimizer=opt,
                      metrics=["accuracy"])

    hooks = []
    if add_callbacks:
        if "tensorboard" in add_callbacks:
            hooks.append(
                # write_grads=True is deliberately left out: it causes a crash
                # saying the handle must be created in scope, an error like
                # https://stackoverflow.com/questions/56836895/custom-training-loop-using-tensorflow-gpu-1-14-and-tf-distribute-mirroredstrateg
                # The crash happens even if the callback is off.
                tf.keras.callbacks.TensorBoard(log_dir="/tmp/logs",
                                               histogram_freq=4,
                                               write_images=True))

    # The debugging hook always goes last in the callback list.
    hooks.append(hook)
    scalars_to_be_saved = dict()
    ts = time.time()
    # Records `steps` exactly as passed in, i.e. before the default below.
    scalars_to_be_saved["scalar/foobar"] = (ts, steps)
    hook.save_scalar("foobar", 1, sm_metric=True, timestamp=ts)

    if steps is None:
        steps = ["train"]
    try:
        for step in steps:
            if step == "train":
                model.fit(train_dataset,
                          epochs=1,
                          steps_per_epoch=10,
                          callbacks=hooks,
                          verbose=0)
            elif step == "eval":
                model.evaluate(eval_dataset,
                               steps=10,
                               callbacks=hooks,
                               verbose=0)
            elif step == "predict":
                model.predict(train_dataset,
                              steps=4,
                              callbacks=hooks,
                              verbose=0)
    finally:
        # Close the hook even when a phase raises, so a failed run does not
        # leave it open for whatever runs next in the same process.
        smd.get_hook().close()
    return strategy, scalars_to_be_saved