コード例 #1
0
def test_collection_defaults_to_hook_config():
    """Hook-level configs fill in the modes a collection leaves unset.

    The hook is configured with ``ModeKeys.TRAIN: save_interval=10`` while the
    collection itself only specifies ``ModeKeys.EVAL: save_interval=20``.
    After the hook prepares its collections, the collection is expected to be
    finalized as ``{ModeKeys.TRAIN: save_interval=10,
    ModeKeys.EVAL: save_interval=20}`` and to pick up the hook's
    reduction config.
    """
    manager = CollectionManager()
    manager.create_collection("foo")

    foo = manager.get("foo")
    foo.include_regex = "*"
    foo.save_config = {ModeKeys.EVAL: SaveConfigMode(save_interval=20)}

    out_dir = "/tmp/test_collections/" + str(datetime.datetime.now())
    hook = Hook(
        out_dir=out_dir,
        save_config={ModeKeys.TRAIN: SaveConfigMode(save_interval=10)},
        include_collections=["foo"],
        reduction_config=ReductionConfig(save_raw_tensor=True),
    )
    hook.collection_manager = manager

    # Before preparation, TRAIN mode and the reduction config are still unset.
    assert foo.save_config.mode_save_configs[ModeKeys.TRAIN] is None
    assert foo.reduction_config is None

    hook._prepare_collections()

    # Look the collection up again rather than relying on the earlier alias.
    prepared = manager.get("foo")
    train_config = prepared.save_config.mode_save_configs[ModeKeys.TRAIN]
    assert train_config.save_interval == 10
    assert prepared.reduction_config.save_raw_tensor is True
コード例 #2
0
def test_export_load_dict_save_config():
    """A Collection with a per-mode SaveConfig survives a JSON round trip."""
    reduction = ReductionConfig()
    per_mode_config = SaveConfig({
        ModeKeys.TRAIN: SaveConfigMode(save_interval=10),
        ModeKeys.EVAL: SaveConfigMode(start_step=1),
    })
    original = Collection(
        "default",
        include_regex=["conv2d"],
        reduction_config=reduction,
        save_config=per_mode_config,
    )

    serialized = original.to_json()
    restored = Collection.from_json(serialized)

    # Both object equality and the dict form must match after deserializing.
    assert original == restored
    assert original.to_json_dict() == restored.to_json_dict()
コード例 #3
0
def test_invalid_collection_config_exception():
    """Preparing collections fails without an include_regex and succeeds with one.

    Collection "foo" is created here without an ``include_regex``. The first
    call to ``hook._prepare_collections()`` must therefore raise
    ``InvalidCollectionConfiguration``. After ``include_regex`` is set, the same
    call must complete without raising it.
    """
    cm = CollectionManager()
    cm.create_collection("foo")

    hook = Hook(
        out_dir="/tmp/test_collections/" + str(datetime.datetime.now()),
        save_config={ModeKeys.TRAIN: SaveConfigMode(save_interval=10)},
        include_collections=["foo"],
        reduction_config=ReductionConfig(save_raw_tensor=True),
    )
    hook.collection_manager = cm
    try:
        hook._prepare_collections()
    except InvalidCollectionConfiguration:
        pass
    else:
        # Explicit raise instead of `assert False`: asserts are stripped under
        # `python -O`, which would make this check silently pass.
        raise AssertionError("Invalid Collection Name did not raise error")

    cm.get("foo").include_regex = "*"
    try:
        hook._prepare_collections()
    except InvalidCollectionConfiguration as err:
        # Chain the original exception so its message and traceback stay visible.
        raise AssertionError("Valid Collection Name raised an error") from err
コード例 #4
0
    )

    saved_scalars = simple_pt_model(hook,
                                    register_loss=register_loss,
                                    with_timestamp=with_timestamp)
    hook.close()
    verify_files(trial_dir, save_config, saved_scalars)
    if with_timestamp:
        check_tf_events(trial_dir, saved_scalars)


@pytest.mark.parametrize("collection", [("all", ".*"), ("scalars", "^scalar")])
@pytest.mark.parametrize(
    "save_config",
    [
        SaveConfig(save_steps=[0, 2, 4, 6, 8]),
        SaveConfig({
            ModeKeys.TRAIN: SaveConfigMode(save_interval=2),
            ModeKeys.GLOBAL: SaveConfigMode(save_interval=3),
            ModeKeys.EVAL: SaveConfigMode(save_interval=1),
        }),
    ],
)
@pytest.mark.parametrize("register_loss", [True, False])
@pytest.mark.parametrize("with_timestamp", [True, False])
def test_pytorch_save_scalar(collection, save_config, register_loss,
                             with_timestamp):
    """Run the PyTorch scalar-saving helper for one parameter combination.

    Args:
        collection: ``(name, regex)`` pair, e.g. ``("all", ".*")``.
        save_config: ``SaveConfig`` controlling which steps are saved.
        register_loss: forwarded to ``helper_pytorch_tests``.
        with_timestamp: forwarded to ``helper_pytorch_tests``.
    """
    try:
        helper_pytorch_tests(collection, register_loss, save_config,
                             with_timestamp)
    finally:
        # Clean up even when the helper fails. Otherwise a failing case leaves
        # trial output in SMDEBUG_PT_HOOK_TESTS_DIR, where later parametrized
        # cases may pick it up.
        delete_local_trials([SMDEBUG_PT_HOOK_TESTS_DIR])