def test_collection_defaults_to_hook_config():
    """Test that hook save_configs propagate to collection defaults.

    For example, if we set ModeKeys.TRAIN: save_interval=10 in the hook and
    ModeKeys.EVAL: save_interval=20 in a collection, we would like the
    collection to be finalized as
    {ModeKeys.TRAIN: save_interval=10, ModeKeys.EVAL: save_interval=20}.

    The hook's reduction_config is likewise expected to fill in the
    collection's unset reduction_config.
    """
    cm = CollectionManager()
    cm.create_collection("foo")
    cm.get("foo").include_regex = "*"
    # Only the EVAL mode is configured on the collection itself.
    cm.get("foo").save_config = {ModeKeys.EVAL: SaveConfigMode(save_interval=20)}

    hook = Hook(
        out_dir="/tmp/test_collections/" + str(datetime.datetime.now()),
        save_config={ModeKeys.TRAIN: SaveConfigMode(save_interval=10)},
        include_collections=["foo"],
        reduction_config=ReductionConfig(save_raw_tensor=True),
    )
    hook.collection_manager = cm

    # Before preparation the collection has no TRAIN config and no reduction_config.
    assert cm.get("foo").save_config.mode_save_configs[ModeKeys.TRAIN] is None
    assert cm.get("foo").reduction_config is None

    hook._prepare_collections()

    mode_configs = cm.get("foo").save_config.mode_save_configs
    # TRAIN is filled in from the hook's save_config ...
    assert mode_configs[ModeKeys.TRAIN].save_interval == 10
    # ... while the collection's explicit EVAL config must not be overridden,
    # which is the other half of the contract stated in the docstring.
    assert mode_configs[ModeKeys.EVAL].save_interval == 20
    assert cm.get("foo").reduction_config.save_raw_tensor is True
def test_export_load_dict_save_config():
    """Round-trip a Collection with a per-mode SaveConfig through JSON.

    The collection rebuilt by ``Collection.from_json`` must compare equal to
    the original and serialize to the same JSON dict.
    """
    per_mode_configs = {
        ModeKeys.TRAIN: SaveConfigMode(save_interval=10),
        ModeKeys.EVAL: SaveConfigMode(start_step=1),
    }
    original = Collection(
        "default",
        include_regex=["conv2d"],
        reduction_config=ReductionConfig(),
        save_config=SaveConfig(per_mode_configs),
    )

    serialized = original.to_json()
    restored = Collection.from_json(serialized)

    assert original == restored
    assert original.to_json_dict() == restored.to_json_dict()
def test_invalid_collection_config_exception():
    """Check that _prepare_collections validates the collection configuration.

    A freshly created collection with no include_regex must be rejected with
    InvalidCollectionConfiguration; once an include_regex is assigned,
    preparing the same hook's collections must succeed.
    """
    # Local import: pytest helpers are used instead of ``assert False``,
    # which is stripped when Python runs with -O and would let the test pass.
    import pytest

    cm = CollectionManager()
    cm.create_collection("foo")
    hook = Hook(
        out_dir="/tmp/test_collections/" + str(datetime.datetime.now()),
        save_config={ModeKeys.TRAIN: SaveConfigMode(save_interval=10)},
        include_collections=["foo"],
        reduction_config=ReductionConfig(save_raw_tensor=True),
    )
    hook.collection_manager = cm

    # "foo" has no include_regex yet, so preparation is expected to fail.
    with pytest.raises(InvalidCollectionConfiguration):
        hook._prepare_collections()

    cm.get("foo").include_regex = "*"
    try:
        hook._prepare_collections()
    except InvalidCollectionConfiguration as err:
        pytest.fail(f"Valid collection configuration raised an error: {err}")
) saved_scalars = simple_pt_model(hook, register_loss=register_loss, with_timestamp=with_timestamp) hook.close() verify_files(trial_dir, save_config, saved_scalars) if with_timestamp: check_tf_events(trial_dir, saved_scalars) @pytest.mark.parametrize("collection", [("all", ".*"), ("scalars", "^scalar")]) @pytest.mark.parametrize( "save_config", [ SaveConfig(save_steps=[0, 2, 4, 6, 8]), SaveConfig({ ModeKeys.TRAIN: SaveConfigMode(save_interval=2), ModeKeys.GLOBAL: SaveConfigMode(save_interval=3), ModeKeys.EVAL: SaveConfigMode(save_interval=1), }), ], ) @pytest.mark.parametrize("register_loss", [True, False]) @pytest.mark.parametrize("with_timestamp", [True, False]) def test_pytorch_save_scalar(collection, save_config, register_loss, with_timestamp): helper_pytorch_tests(collection, register_loss, save_config, with_timestamp) delete_local_trials([SMDEBUG_PT_HOOK_TESTS_DIR])