def test_save_all(out_dir, tf_eager_mode, workers):
    """Train with ``save_all=True`` and check what ends up in the trial.

    The model is trained through ``train_model`` with
    ``SaveConfig(save_steps=[5])``; the resulting trial is then opened with
    ``create_trial_fast_refresh`` and checked for:

    * the number of saved tensor names (depends on eager mode and on the
      ``is_tf_2_2()`` / ``is_tf_2_3()`` helpers),
    * the number of steps reported by the trial,
    * the number of workers that wrote each tensor,
    * the saved scalars, via ``verify_files``.

    Args:
        out_dir: Directory the hook writes to and the trial is read from.
        tf_eager_mode: Truthy to run the eager path; forwarded as ``eager=``.
        workers: Forwarded as ``include_workers=``. ``"one"`` means every
            tensor is expected from exactly one worker; any other value means
            one worker per replica of the returned strategy.
    """
    save_config = SaveConfig(save_steps=[5])
    strategy, saved_scalars = train_model(
        out_dir,
        include_collections=None,
        save_all=True,
        save_config=save_config,
        steps=["train"],
        eager=tf_eager_mode,
        include_workers=workers,
    )

    tr = create_trial_fast_refresh(out_dir)
    print(tr.tensor_names())

    if tf_eager_mode:
        if is_tf_2_2():
            # This branch is only entered when is_tf_2_2() holds, so the
            # expected count is unconditional here (the previous inline
            # ``... if is_tf_2_2() else ...`` alternative was unreachable).
            # weights, metrics, losses, optimizer variables, scalar, inputs,
            # outputs, gradients, layers
            assert len(tr.tensor_names()) == 6 + 2 + 1 + 5 + 1 + 1 + 2 + 8 + 8
        else:
            assert len(tr.tensor_names()) == (
                6 + 2 + 1 + 5 + 1 if is_tf_2_3() else 6 + 3 + 1 + 5 + 1
            )
    else:
        # weights, grads, optimizer_variables, metrics, losses, outputs
        assert (
            len(tr.tensor_names())
            == 6
            + 6
            + 5
            + 3
            + 1
            + 3 * strategy.num_replicas_in_sync
            + 2 * strategy.num_replicas_in_sync
        )

    assert len(tr.steps()) == 3

    # The expected worker count does not depend on the tensor, so compute it
    # once instead of on every iteration.
    expected_workers = 1 if workers == "one" else strategy.num_replicas_in_sync
    for tname in tr.tensor_names():
        # NOTE(review): the argument to workers() is presumably a step
        # number - confirm against the trial API.
        assert len(tr.tensor(tname).workers(0)) == expected_workers

    verify_files(out_dir, save_config, saved_scalars)
def helper_tensorflow_tests(use_keras, collection, save_config, with_timestamp):
    """Run a TensorFlow model under a hook and verify the scalars it saved.

    A fresh, timestamped trial directory is created under
    ``SMDEBUG_TF_HOOK_TESTS_DIR``. Depending on ``use_keras`` the model is run
    with either ``TF_KerasHook`` + ``simple_tf_model`` or ``TF_SessionHook`` +
    ``tf_session_model``. The hook is then closed and the output is checked
    with ``verify_files`` (and ``check_tf_events`` when ``with_timestamp`` is
    truthy).

    Args:
        use_keras: Truthy selects the Keras hook/model, falsy the session one.
        collection: Two-item iterable ``(name, regex)``; only the name is used.
        save_config: Passed to the hook and to ``verify_files``.
        with_timestamp: Forwarded to the model runner; also enables the
            TensorBoard event check.
    """
    coll_name, _ = collection
    timestamp = datetime.now().strftime("%Y%m%d-%H%M%S%f")
    trial_dir = os.path.join(
        SMDEBUG_TF_HOOK_TESTS_DIR, "trial_" + coll_name + "-" + timestamp
    )

    keras_path = bool(use_keras)
    hook_cls, run_model = (
        (TF_KerasHook, simple_tf_model)
        if keras_path
        else (TF_SessionHook, tf_session_model)
    )

    # Both hook flavours are configured identically.
    hook_kwargs = {
        "out_dir": trial_dir,
        "include_collections": [coll_name],
        "save_config": save_config,
        "export_tensorboard": True,
    }
    hook = hook_cls(**hook_kwargs)
    saved_scalars = run_model(hook, with_timestamp=with_timestamp)

    if not keras_path:
        # Only the session path resets the default graph afterwards
        # (presumably because that model builds into it).
        tf.reset_default_graph()

    hook.close()
    verify_files(trial_dir, save_config, saved_scalars)
    if with_timestamp:
        check_tf_events(trial_dir, saved_scalars)
def helper_xgboost_tests(collection, save_config, with_timestamp):
    """Run the XGBoost model under ``XG_Hook`` and verify the saved scalars.

    A fresh, timestamped trial directory is created under
    ``SMDEBUG_XG_HOOK_TESTS_DIR``, ``simple_xg_model`` is run with the hook,
    the hook is closed, and the output is checked with ``verify_files`` (and
    ``check_tf_events`` when ``with_timestamp`` is truthy).

    Args:
        collection: Two-item iterable ``(name, regex)``; only the name is used.
        save_config: Passed to the hook and to ``verify_files``.
        with_timestamp: Forwarded to ``simple_xg_model``; also enables the
            TensorBoard event check.
    """
    coll_name, _ = collection
    timestamp = datetime.now().strftime("%Y%m%d-%H%M%S%f")
    trial_dir = os.path.join(
        SMDEBUG_XG_HOOK_TESTS_DIR, "trial_" + coll_name + "-" + timestamp
    )

    hook_kwargs = {
        "out_dir": trial_dir,
        "include_collections": [coll_name],
        "save_config": save_config,
        "export_tensorboard": True,
    }
    hook = XG_Hook(**hook_kwargs)

    saved_scalars = simple_xg_model(hook, with_timestamp=with_timestamp)
    hook.close()

    verify_files(trial_dir, save_config, saved_scalars)
    if not with_timestamp:
        return
    check_tf_events(trial_dir, saved_scalars)