def test_save_all(out_dir, tf_eager_mode, workers):
    """Train with ``save_all=True`` and check what the resulting trial holds.

    Runs ``train_model`` on the "train" step with ``save_steps=[5]``, loads
    the trial written to ``out_dir`` and asserts on the number of saved
    tensor names, the number of steps, the number of workers reported for
    every tensor, and finally delegates to ``verify_files``.

    Args:
        out_dir: Directory the training run writes to and the trial is
            loaded from.
        tf_eager_mode: Forwarded to ``train_model`` as ``eager``; also
            selects which expected tensor count applies.
        workers: Forwarded to ``train_model`` as ``include_workers``. When
            it equals ``"one"`` each tensor must report exactly one worker,
            otherwise one per replica of the returned strategy.
    """
    save_config = SaveConfig(save_steps=[5])
    strategy, saved_scalars = train_model(
        out_dir,
        include_collections=None,
        save_all=True,
        save_config=save_config,
        steps=["train"],
        eager=tf_eager_mode,
        include_workers=workers,
    )
    tr = create_trial_fast_refresh(out_dir)
    print(tr.tensor_names())

    # Pick the expected tensor count once. The previous nested form re-tested
    # is_tf_2_2() inside the branch already guarded by it, which left an
    # unreachable alternative in the assertion.
    if tf_eager_mode:
        if is_tf_2_2():
            # weights, metrics, losses, optimizer variables, scalar, inputs, outputs, gradients, layers
            expected_tensor_count = 6 + 2 + 1 + 5 + 1 + 1 + 2 + 8 + 8
        elif is_tf_2_3():
            expected_tensor_count = 6 + 2 + 1 + 5 + 1
        else:
            expected_tensor_count = 6 + 3 + 1 + 5 + 1
    else:
        # weights, grads, optimizer_variables, metrics, losses, outputs
        # The last two terms scale with the number of replicas.
        expected_tensor_count = (6 + 6 + 5 + 3 + 1 +
                                 3 * strategy.num_replicas_in_sync +
                                 2 * strategy.num_replicas_in_sync)
    assert len(tr.tensor_names()) == expected_tensor_count

    assert len(tr.steps()) == 3

    # The expected worker count does not depend on the tensor, so compute it
    # once instead of on every loop iteration.
    expected_worker_count = (1 if workers == "one" else
                             strategy.num_replicas_in_sync)
    for tname in tr.tensor_names():
        # workers(0): the argument is presumably the step number -- confirm.
        assert len(tr.tensor(tname).workers(0)) == expected_worker_count
    verify_files(out_dir, save_config, saved_scalars)
def helper_tensorflow_tests(use_keras, collection, save_config,
                            with_timestamp):
    """Run a TensorFlow model under a debugger hook and verify its output.

    A fresh, timestamped trial directory is derived under
    ``SMDEBUG_TF_HOOK_TESTS_DIR``. Depending on ``use_keras`` either a
    ``TF_KerasHook`` driving ``simple_tf_model`` or a ``TF_SessionHook``
    driving ``tf_session_model`` is used; the session variant also resets
    the default graph afterwards. The hook is then closed and the written
    files are checked with ``verify_files``.

    Args:
        use_keras: Truthy to exercise the Keras hook, falsy for the session
            hook.
        collection: Two-element sequence; only the first element (the
            collection name) is used here.
        save_config: Save configuration handed to the hook and to
            ``verify_files``.
        with_timestamp: Forwarded to the model runner; when truthy the
            result is additionally checked with ``check_tf_events``.
    """
    # Only the collection name is needed; the second element is ignored.
    coll_name, _ = collection

    stamp = datetime.now().strftime("%Y%m%d-%H%M%S%f")
    trial_dir = os.path.join(SMDEBUG_TF_HOOK_TESTS_DIR,
                             "trial_" + coll_name + "-" + stamp)

    # Both hook flavours are configured identically.
    hook_kwargs = dict(
        out_dir=trial_dir,
        include_collections=[coll_name],
        save_config=save_config,
        export_tensorboard=True,
    )

    if not use_keras:
        hook = TF_SessionHook(**hook_kwargs)
        saved_scalars = tf_session_model(hook, with_timestamp=with_timestamp)
        # Reset the default graph before closing the hook, as the session
        # path always did.
        tf.reset_default_graph()
    else:
        hook = TF_KerasHook(**hook_kwargs)
        saved_scalars = simple_tf_model(hook, with_timestamp=with_timestamp)

    hook.close()
    verify_files(trial_dir, save_config, saved_scalars)
    if not with_timestamp:
        return
    check_tf_events(trial_dir, saved_scalars)
def helper_xgboost_tests(collection, save_config, with_timestamp):
    """Run the XGBoost model under a debugger hook and verify its output.

    A fresh, timestamped trial directory is derived under
    ``SMDEBUG_XG_HOOK_TESTS_DIR``, an ``XG_Hook`` is created for it and
    passed to ``simple_xg_model``. The hook is then closed and the written
    files are checked with ``verify_files``.

    Args:
        collection: Two-element sequence; only the first element (the
            collection name) is used here.
        save_config: Save configuration handed to the hook and to
            ``verify_files``.
        with_timestamp: Forwarded to ``simple_xg_model``; when truthy the
            result is additionally checked with ``check_tf_events``.
    """
    # Only the collection name is needed; the second element is ignored.
    coll_name, _ = collection

    stamp = datetime.now().strftime("%Y%m%d-%H%M%S%f")
    trial_dir = os.path.join(SMDEBUG_XG_HOOK_TESTS_DIR,
                             "trial_" + coll_name + "-" + stamp)

    hook_kwargs = dict(
        out_dir=trial_dir,
        include_collections=[coll_name],
        save_config=save_config,
        export_tensorboard=True,
    )
    hook = XG_Hook(**hook_kwargs)

    saved_scalars = simple_xg_model(hook, with_timestamp=with_timestamp)
    hook.close()
    verify_files(trial_dir, save_config, saved_scalars)
    if not with_timestamp:
        return
    check_tf_events(trial_dir, saved_scalars)