def test_model_registry() -> None:
    """Exercise model creation, metadata edits, version registration and listing.

    Runs one basic mnist experiment and creates an ``mnist`` model. After each
    metadata change, it checks the local object against a fresh copy read back
    from the master. It then registers the experiment's top checkpoint as
    version 1, creates two more models, and checks the name-sorted listing.
    """
    experiment_id = exp.run_basic_test(
        conf.fixtures_path("mnist_pytorch/const-pytorch11.yaml"),
        conf.tutorials_path("mnist_pytorch"),
        None,
    )

    client = Determined(conf.make_master_url())
    model = client.create_model("mnist", "simple computer vision model")
    assert model.metadata == {}

    def check_metadata(expected):
        # Re-read the model so the comparison also covers what the master
        # persisted, not just the in-memory object.
        stored = client.get_model("mnist")
        assert model.metadata == stored.metadata
        assert model.metadata == expected

    model.add_metadata({"testing": "metadata"})
    check_metadata({"testing": "metadata"})

    model.add_metadata({"some_key": "some_value"})
    check_metadata({"testing": "metadata", "some_key": "some_value"})

    # Re-adding an existing key overwrites its value.
    model.add_metadata({"testing": "override"})
    check_metadata({"testing": "override", "some_key": "some_value"})

    model.remove_metadata(["some_key"])
    check_metadata({"testing": "override"})

    top = client.get_experiment(experiment_id).top_checkpoint()
    registered = model.register_version(top.uuid)
    assert registered.model_version == 1

    newest = model.get_version()
    assert newest is not None
    assert newest.uuid == top.uuid

    for extra_name, extra_description in (
        ("transformer", "all you need is attention"),
        ("object-detection", "a bounding box model"),
    ):
        client.create_model(extra_name, extra_description)

    listed = client.get_models(sort_by=ModelSortBy.NAME)
    assert [m.name for m in listed] == ["mnist", "object-detection", "transformer"]
def register(detmaster: str, experiment_id: int, model_name: str) -> bool:
    """Register an experiment's top checkpoint as a new model version.

    The checkpoint returned by ``top_checkpoint()`` for ``experiment_id`` is
    added as a new version of the registry model ``model_name``. If that
    model cannot be fetched, it is created first.

    Args:
        detmaster: Master address, written to the ``DET_MASTER`` environment
            variable before the client is constructed.
        experiment_id: ID of the experiment whose top checkpoint is registered.
        model_name: Name of the registry model that receives the new version.

    Returns:
        ``True`` once the version has been registered. Failures from the
        client (other than the model lookup) propagate as exceptions.
    """
    # SDK imported inside the function; presumably so the enclosing module
    # can be loaded without the Determined SDK installed.
    from determined.experimental import Determined
    import os

    # Determined() is constructed without an explicit master below, so the
    # address is supplied through the environment.
    os.environ['DET_MASTER'] = detmaster

    d = Determined()
    checkpoint = d.get_experiment(experiment_id).top_checkpoint()

    # NOTE(review): the top checkpoint is always registered; it is not
    # compared against the model's existing versions.
    try:
        model = d.get_model(model_name)
    except Exception:
        # A failed lookup is treated as "model not yet in registry".
        # `Exception` rather than a bare `except:` so KeyboardInterrupt and
        # SystemExit still propagate. NOTE(review): narrow this further to
        # the SDK's not-found error type once confirmed.
        print(f'Registering new Model: {model_name}')
        model = d.create_model(model_name)

    print(f'Registering new version: {model_name}')
    model.register_version(checkpoint.uuid)
    return True
def test_model_registry() -> None:
    """End-to-end check of the model registry, cleaning up the models it creates.

    Covers the following, then deletes every model still held in the
    ``finally`` clause:

    - metadata add/override/remove, checked against a fresh read from the master;
    - lookup by name and by ID;
    - labels, description patching, and archive/unarchive;
    - version registration, naming, notes, and deletion;
    - name-sorted model listing, combined labels, and model deletion.
    """
    experiment_id = exp.run_basic_test(
        conf.fixtures_path("mnist_pytorch/const-pytorch11.yaml"),
        conf.tutorials_path("mnist_pytorch"),
        None,
    )
    client = Determined(conf.make_master_url())

    # Handles to the models this test creates; any that are still set are
    # deleted in the `finally` clause.
    mnist = None
    objectdetect = None
    tform = None

    def check_metadata(expected):
        # Compare the local object with the persisted copy, then with the
        # expected value.
        stored = client.get_model(mnist.name)
        assert mnist.metadata == stored.metadata
        assert mnist.metadata == expected

    def register_top(experiment, expected_number):
        # Register the experiment's top checkpoint and confirm it became the
        # latest version with the expected version number.
        top = client.get_experiment(experiment).top_checkpoint()
        version = mnist.register_version(top.uuid)
        assert version.model_version == expected_number
        newest = mnist.get_version()
        assert newest is not None
        assert newest.checkpoint.uuid == top.uuid
        return newest

    try:
        mnist = client.create_model(
            "mnist", "simple computer vision model", labels=["a", "b"]
        )
        assert mnist.metadata == {}

        mnist.add_metadata({"testing": "metadata"})
        check_metadata({"testing": "metadata"})

        # The model can be looked up by ID, both explicitly and via get_model.
        by_id = client.get_model_by_id(mnist.model_id)
        assert by_id.name == "mnist"
        by_id = client.get_model(mnist.model_id)
        assert by_id.name == "mnist"
        # The DB fills in the owning username.
        assert by_id.username == "determined"

        mnist.add_metadata({"some_key": "some_value"})
        check_metadata({"testing": "metadata", "some_key": "some_value"})

        mnist.add_metadata({"testing": "override"})
        check_metadata({"testing": "override", "some_key": "some_value"})

        mnist.remove_metadata(["some_key"])
        check_metadata({"testing": "override"})

        mnist.set_labels(["hello", "world"])
        stored = client.get_model(mnist.name)
        assert mnist.labels == stored.labels
        assert stored.labels == ["hello", "world"]

        # Patching the description must leave metadata and labels intact.
        mnist.set_description("abcde")
        stored = client.get_model(mnist.name)
        assert stored.metadata == {"testing": "override"}
        assert stored.labels == ["hello", "world"]

        # Labels can be cleared by setting an empty list.
        mnist.set_labels([])
        stored = client.get_model(mnist.name)
        assert stored.labels == []

        # Round-trip the archived flag.
        assert mnist.archived is False
        mnist.archive()
        stored = client.get_model(mnist.name)
        assert stored.archived is True
        mnist.unarchive()
        stored = client.get_model(mnist.name)
        assert stored.archived is False

        newest = register_top(experiment_id, 1)

        newest.set_name("Test 2021")
        fetched = mnist.get_version()
        assert fetched is not None
        assert fetched.name == "Test 2021"

        newest.set_notes("# Hello Markdown")
        fetched = mnist.get_version()
        assert fetched is not None
        assert fetched.notes == "# Hello Markdown"

        # A second experiment's checkpoint becomes version 2 and the latest.
        second_experiment_id = exp.run_basic_test(
            conf.fixtures_path("mnist_pytorch/const-pytorch11.yaml"),
            conf.tutorials_path("mnist_pytorch"),
            None,
        )
        newest = register_top(second_experiment_id, 2)

        assert len(mnist.get_versions()) == 2

        newest.delete()
        assert len(mnist.get_versions()) == 1

        tform = client.create_model("transformer", "all you need is attention")
        objectdetect = client.create_model("ac - Dc", "a test name model")
        listed = client.get_models(sort_by=ModelSortBy.NAME)
        assert [m.name for m in listed] == ["ac - Dc", "mnist", "transformer"]

        # Labels are combined across models.
        mnist.set_labels(["hello", "world"])
        tform.set_labels(["world", "test", "zebra"])
        assert client.get_model_labels() == ["world", "hello", "test", "zebra"]

        # Deleting a model removes it from the listing; clear the handle so
        # cleanup does not delete it twice.
        tform.delete()
        tform = None
        listed = client.get_models(sort_by=ModelSortBy.NAME)
        assert [m.name for m in listed] == ["ac - Dc", "mnist"]
    finally:
        for leftover in (mnist, objectdetect, tform):
            if leftover is not None:
                leftover.delete()
def test_model_registry() -> None:
    """Exercise metadata edits, two version registrations, and model listing.

    Creates an ``mnist`` model and checks each metadata change against a copy
    re-read from the master. It registers the top checkpoints of two separate
    experiment runs as versions 1 and 2, checking the latest version each
    time. Finally it creates two more models and checks the name-sorted
    listing.
    """
    experiment_id = exp.run_basic_test(
        conf.fixtures_path("mnist_pytorch/const-pytorch11.yaml"),
        conf.tutorials_path("mnist_pytorch"),
        None,
    )
    client = Determined(conf.make_master_url())

    model = client.create_model("mnist", "simple computer vision model")
    assert model.metadata == {}

    def check_metadata(expected):
        # The local object must match both the persisted copy and `expected`.
        stored = client.get_model("mnist")
        assert model.metadata == stored.metadata
        assert model.metadata == expected

    def register_top(experiment, expected_number):
        # Register the experiment's top checkpoint and verify it is now the
        # latest version with the expected number.
        top = client.get_experiment(experiment).top_checkpoint()
        version = model.register_version(top.uuid)
        assert version.model_version == expected_number
        newest = model.get_version()
        assert newest is not None
        assert newest.uuid == top.uuid

    model.add_metadata({"testing": "metadata"})
    check_metadata({"testing": "metadata"})

    model.add_metadata({"some_key": "some_value"})
    check_metadata({"testing": "metadata", "some_key": "some_value"})

    model.add_metadata({"testing": "override"})
    check_metadata({"testing": "override", "some_key": "some_value"})

    model.remove_metadata(["some_key"])
    check_metadata({"testing": "override"})

    register_top(experiment_id, 1)

    # A second run's checkpoint must become version 2 and the new latest.
    second_experiment_id = exp.run_basic_test(
        conf.fixtures_path("mnist_pytorch/const-pytorch11.yaml"),
        conf.tutorials_path("mnist_pytorch"),
        None,
    )
    register_top(second_experiment_id, 2)

    assert len(model.get_versions()) == 2

    for extra_name, extra_description in (
        ("transformer", "all you need is attention"),
        ("object-detection", "a bounding box model"),
    ):
        client.create_model(extra_name, extra_description)

    listed = client.get_models(sort_by=ModelSortBy.NAME)
    assert [m.name for m in listed] == ["mnist", "object-detection", "transformer"]