Ejemplo n.º 1
0
def test_export_load():
    """Round-trip a Collection through JSON without an explicit save_config.

    The reloaded collection must compare equal to the original, carry the
    same tensor names, and hold those names as a ``set`` even though they
    were supplied as a list.
    """
    original = Collection(
        "default",
        include_regex=["conv2d"],
        tensor_names=["a", "b"],
        reduction_config=ReductionConfig(),
    )
    serialized = original.to_json()
    restored = Collection.from_json(serialized)

    assert original == restored
    assert original.tensor_names == restored.tensor_names
    # The names went in as a list; deserialization must yield a set.
    assert isinstance(restored.tensor_names, set)
Ejemplo n.º 2
0
def test_export_load_dict_save_config():
    """Round-trip a Collection whose SaveConfig is keyed per mode.

    TRAIN uses a save interval of 10 and EVAL starts at step 1. Both the
    objects and their JSON-dict forms must match after reloading.
    """
    reduction = ReductionConfig()
    per_mode_configs = {
        ModeKeys.TRAIN: SaveConfigMode(save_interval=10),
        ModeKeys.EVAL: SaveConfigMode(start_step=1),
    }
    original = Collection(
        "default",
        include_regex=["conv2d"],
        reduction_config=reduction,
        save_config=SaveConfig(per_mode_configs),
    )
    restored = Collection.from_json(original.to_json())

    assert original == restored
    assert original.to_json_dict() == restored.to_json_dict()
Ejemplo n.º 3
0
def test_manager_export_load():
    """Export a CollectionManager to disk and check it loads back equal.

    The export goes into a fresh temporary directory instead of a
    hard-coded ``/tmp/dummy_trial``. As a result the test:

    - leaves no files behind;
    - cannot pick up stale output from an earlier or concurrent run;
    - also works on platforms that have no ``/tmp``.
    """
    import tempfile

    cm = CollectionManager()
    cm.create_collection("default")
    cm.get("default").include("loss")
    cm.add(Collection("trial1"))
    cm.add("trial2")
    cm.get("trial2").include("total_loss")

    # The directory (and everything export writes under it) is removed when
    # the ``with`` block exits; cm2 is fully loaded into memory before then.
    with tempfile.TemporaryDirectory() as trial_dir:
        cm.export(trial_dir, DEFAULT_COLLECTIONS_FILE_NAME)
        cm2 = CollectionManager.load(
            os.path.join(get_path_to_collections(trial_dir),
                         DEFAULT_COLLECTIONS_FILE_NAME))

    assert cm == cm2
Ejemplo n.º 4
0
def test_manager():
    """Exercise in-memory CollectionManager creation, lookup and mutation.

    Builds three collections ("default", "trial1", "trial2"). It then checks:

    - the collection count;
    - that ``get`` agrees with the ``collections`` mapping;
    - that the include patterns and tensor names added along the way are
      visible.
    """
    manager = CollectionManager()
    manager.create_collection("default")

    # Mutate the "default" collection via the object returned by ``get``;
    # the assertions below read it back through ``get`` again.
    default_coll = manager.get("default")
    default_coll.include("loss")
    default_coll.add_tensor_name("assaas")

    manager.add(Collection("trial1"))
    manager.add("trial2")
    manager.get("trial2").include("total_loss")

    assert len(manager.collections) == 3
    assert manager.get("default") == manager.collections["default"]
    assert "loss" in manager.get("default").include_regex
    assert len(manager.get("default").tensor_names) > 0
    assert "total_loss" in manager.collections["trial2"].include_regex
Ejemplo n.º 5
0
def write_dummy_collection_file(trial):
    """Export a minimal CollectionManager for ``trial``.

    The manager contains a "default" collection plus one collection named
    after ``trial``. It is exported with ``trial`` as the first argument to
    ``export``, which these tests otherwise use as a directory path, and
    with DEFAULT_COLLECTIONS_FILE_NAME as the file name.
    """
    manager = CollectionManager()
    manager.create_collection("default")
    trial_collection = Collection(trial)
    manager.add(trial_collection)
    manager.export(trial, DEFAULT_COLLECTIONS_FILE_NAME)
Ejemplo n.º 6
0
def test_load_empty():
    """A Collection constructed with only a name survives a JSON round trip."""
    empty = Collection("trial")
    reloaded = Collection.from_json(empty.to_json())
    assert empty == reloaded