예제 #1
0
def helper_mxnet_tests(collection, register_loss, save_config):
    """Run the simple MXNet model under a hook and check what was saved.

    Args:
        collection: Two-item sequence. The first item is the collection
            name; the second is unpacked but not used in this helper.
        register_loss: Forwarded to ``simple_mx_model`` as ``register_loss``.
        save_config: Assigned to the collection's ``save_config`` and queried
            for its TRAIN-mode ``save_steps`` / ``save_interval``.
    """
    name, _unused_regex = collection

    # A microsecond-resolution timestamp goes into the output directory name.
    stamp = datetime.now().strftime("%Y%m%d-%H%M%S%f")
    out_dir = os.path.join(SMDEBUG_MX_HOOK_TESTS_DIR, f"trial_{name}-{stamp}")

    hook = MX_Hook(out_dir=out_dir, include_collections=[name], export_tensorboard=True)
    hook.get_collection(name).save_config = save_config

    expected_steps = save_config.get_save_config(ModeKeys.TRAIN).save_steps
    if not expected_steps:
        # No explicit steps: derive them from the interval over steps 0..9.
        interval = save_config.get_save_config(ModeKeys.TRAIN).save_interval
        expected_steps = list(range(0, 10, interval))

    simple_mx_model(hook, register_loss=register_loss)
    hook.close()

    scalar_names = [
        "scalar/mx_before_train",
        "scalar/mx_train_loss",
        "scalar/mx_after_train",
    ]
    check_trials(out_dir, expected_steps, name, scalar_names)
    check_metrics_file(scalar_names)
def helper_mxnet_tests(collection, register_loss, save_config):
    """Run the simple MXNet model under a hook and verify the output files.

    Args:
        collection: Two-item sequence. The first item is the collection
            name; the second is unpacked but not used in this helper.
        register_loss: Forwarded to ``simple_mx_model`` as ``register_loss``.
        save_config: Passed to the hook constructor and to ``verify_files``.
    """
    name, _unused_regex = collection

    # A microsecond-resolution timestamp goes into the output directory name.
    stamp = datetime.now().strftime("%Y%m%d-%H%M%S%f")
    out_dir = os.path.join(SMDEBUG_MX_HOOK_TESTS_DIR, f"trial_{name}-{stamp}")

    hook = MX_Hook(
        out_dir=out_dir,
        include_collections=[name],
        save_config=save_config,
        export_tensorboard=True,
    )
    simple_mx_model(hook, register_loss=register_loss)
    hook.close()

    # Scalar names handed to verify_files for this run.
    scalar_names = [
        "scalar/mx_num_steps",
        "scalar/mx_before_train",
        "scalar/mx_after_train",
    ]
    verify_files(out_dir, save_config, scalar_names)
예제 #3
0
def helper_mxnet_tests(collection, register_loss, save_config, with_timestamp):
    """Run the simple MXNet model under a hook and verify the output files.

    Args:
        collection: Two-item sequence. The first item is the collection
            name; the second is unpacked but not used in this helper.
        register_loss: Forwarded to ``simple_mx_model`` as ``register_loss``.
        save_config: Passed to the hook constructor and to ``verify_files``.
        with_timestamp: Forwarded to ``simple_mx_model``; when truthy,
            ``check_tf_events`` is also run on the output directory.
    """
    name, _unused_regex = collection

    # A microsecond-resolution timestamp goes into the output directory name.
    stamp = datetime.now().strftime("%Y%m%d-%H%M%S%f")
    out_dir = os.path.join(SMDEBUG_MX_HOOK_TESTS_DIR, f"trial_{name}-{stamp}")

    hook = MX_Hook(
        out_dir=out_dir,
        include_collections=[name],
        save_config=save_config,
        export_tensorboard=True,
    )

    # The model run returns the scalars that the checks below are given.
    scalar_names = simple_mx_model(
        hook,
        register_loss=register_loss,
        with_timestamp=with_timestamp,
    )
    hook.close()

    verify_files(out_dir, save_config, scalar_names)
    if not with_timestamp:
        return
    check_tf_events(out_dir, scalar_names)