def test_list_jumpstart_models_deprecated_models(
    self,
    patched_get_model_specs: Mock,
    patched_get_manifest: Mock,
):
    """A ``deprecated equals false`` filter drops models whose spec is deprecated.

    Every spec served by the patched spec getter is marked deprecated, so the
    filtered listing is expected to be empty, with one spec lookup per
    prototypical model and a single manifest fetch. An unfiltered listing is
    expected to be non-empty and to perform no spec lookups.
    """
    patched_get_manifest.side_effect = get_prototype_manifest

    def _always_deprecated(*args, **kwargs):
        # Start from the prototype spec and force the deprecation flag on.
        model_spec = get_prototype_model_spec(*args, **kwargs)
        model_spec.deprecated = True
        return model_spec

    patched_get_model_specs.side_effect = _always_deprecated
    expected_spec_calls = len(PROTOTYPICAL_MODEL_SPECS_DICT)

    filtered_listing = list_jumpstart_models("deprecated equals false")
    assert filtered_listing == []
    assert patched_get_model_specs.call_count == expected_spec_calls
    patched_get_manifest.assert_called_once()

    for mocked in (patched_get_manifest, patched_get_model_specs):
        mocked.reset_mock()

    # With no filter, listing must not consult individual model specs.
    unfiltered_listing = list_jumpstart_models()
    assert unfiltered_listing != []
    assert patched_get_model_specs.call_count == 0
def test_list_jumpstart_models_unsupported_models(
    self,
    patched_get_sagemaker_version: Mock,
    patched_get_model_specs: Mock,
    patched_get_manifest: Mock,
):
    """``supported_model`` filtering tracks the (mocked) SageMaker SDK version.

    With the SDK version mocked to ``"0.0.0"`` no model is reported as supported,
    and such queries are answered without any model-spec lookup, even when
    combined with a ``training_supported`` clause. With ``"999999.0.0"`` the
    supported listing becomes non-empty.
    """
    patched_get_model_specs.side_effect = get_prototype_model_spec
    patched_get_manifest.side_effect = get_prototype_manifest

    supported_query = "supported_model == True"
    training_query = "training_supported in [False, True]"

    patched_get_sagemaker_version.return_value = "0.0.0"

    assert list_jumpstart_models(supported_query) == []
    patched_get_model_specs.assert_not_called()

    # Combined with a spec-level clause, still no spec lookup is expected.
    combined_query = And(supported_query, training_query)
    assert list_jumpstart_models(combined_query) == []
    patched_get_model_specs.assert_not_called()

    assert list_jumpstart_models("supported_model == False") != []

    patched_get_sagemaker_version.return_value = "999999.0.0"
    assert list_jumpstart_models(supported_query) != []

    patched_get_model_specs.reset_mock()
    assert list_jumpstart_models(training_query) != []
    patched_get_model_specs.assert_called()
def test_list_jumpstart_models_multiple_level_index(
    self,
    patched_get_model_specs: Mock,
    patched_get_manifest: Mock,
):
    """A filter keyed on a dotted (nested) attribute raises ``NotImplementedError``."""
    patched_get_manifest.side_effect = get_prototype_manifest
    patched_get_model_specs.side_effect = get_prototype_model_spec

    nested_key_filter = "hosting_ecr_specs.py_version == py3"
    with pytest.raises(NotImplementedError):
        list_jumpstart_models(nested_key_filter)
def test_list_jumpstart_models_region(
    self, patched_get_model_specs: Mock, patched_get_manifest: Mock
):
    """The ``region`` argument is passed through, unchanged, to the manifest getter."""
    patched_get_model_specs.side_effect = get_prototype_model_spec

    def _manifest_ignoring_region(region):
        # Serve the us-west-2 prototype manifest for any requested region;
        # only the forwarded argument is under test here.
        return get_prototype_manifest(region="us-west-2")

    patched_get_manifest.side_effect = _manifest_ignoring_region

    list_jumpstart_models(region="some-region")

    patched_get_manifest.assert_called_once_with(region="some-region")
def test_list_jumpstart_models_vulnerable_models(
    self,
    patched_get_model_specs: Mock,
    patched_get_manifest: Mock,
):
    """Vulnerable models are excluded by ``... is false`` vulnerability filters.

    The check runs twice: once with every spec marked inference-vulnerable and
    once with every spec marked training-vulnerable. Each run expects an empty
    listing, one spec lookup per prototypical model and a single manifest fetch.
    A final unfiltered listing is expected to be non-empty with no spec lookups.
    """
    patched_get_manifest.side_effect = get_prototype_manifest

    def _spec_flagged(attribute_name):
        # Build a spec getter that sets ``attribute_name`` to True on the prototype spec.
        def _getter(*args, **kwargs):
            model_spec = get_prototype_model_spec(*args, **kwargs)
            setattr(model_spec, attribute_name, True)
            return model_spec

        return _getter

    expected_spec_calls = len(PROTOTYPICAL_MODEL_SPECS_DICT)

    for flagged_attribute in ("inference_vulnerable", "training_vulnerable"):
        patched_get_model_specs.side_effect = _spec_flagged(flagged_attribute)
        not_vulnerable = And(
            "inference_vulnerable is false", "training_vulnerable is false"
        )
        assert list_jumpstart_models(not_vulnerable) == []
        assert patched_get_model_specs.call_count == expected_spec_calls
        patched_get_manifest.assert_called_once()
        # reset_mock() keeps side_effect, so the last flagged getter stays installed.
        patched_get_manifest.reset_mock()
        patched_get_model_specs.reset_mock()

    assert list_jumpstart_models() != []
    assert patched_get_model_specs.call_count == 0
def test_list_jumpstart_models_complex_queries(
    self,
    patched_get_model_specs: Mock,
    patched_get_manifest: Mock,
):
    """Nested ``Or``/``And``/``Not``/``Identity`` filters select the expected model.

    Both composite queries are expected to match exactly the TensorFlow
    image-classification prototype model.
    """
    patched_get_model_specs.side_effect = get_prototype_model_spec
    patched_get_manifest.side_effect = get_prototype_manifest

    expected = ["tensorflow-ic-bit-m-r101x1-ilsvrc2012-classification-1"]

    first_query = Or(
        And(
            "task is ic",
            "framework is not huggingface",
            And("training_supported is true", Not("false")),
            "true",
        ),
        "false",
        "unknown",
    )
    assert list_jumpstart_models(first_query) == expected

    # Deeper nesting, including a case-insensitive boolean literal ("falSE").
    wrapped_clause = Identity(
        And(
            And("incremental_training_supported==falSE"),
            "true",
            Or("unknown", "version equals 1.0.0"),
        )
    )
    second_query = Or(
        And(
            "task is ic",
            "framework==tensorflow",
            wrapped_clause,
            And("training_supported is true", Not("false")),
            "true",
        ),
        "false",
        "unknown",
    )
    assert list_jumpstart_models(second_query) == expected
def test_list_jumpstart_models_no_versions(
    self,
    patched_get_model_specs: Mock,
    patched_get_manifest: Mock,
):
    """The default call and ``list_versions=False`` both return bare model IDs."""
    patched_get_model_specs.side_effect = get_prototype_model_spec
    patched_get_manifest.side_effect = get_prototype_manifest

    expected_ids = [
        "catboost-classification-model",
        "huggingface-spc-bert-base-cased",
        "lightgbm-classification-model",
        "mxnet-semseg-fcn-resnet50-ade",
        "pytorch-eqa-bert-base-cased",
        "sklearn-classification-linear",
        "tensorflow-ic-bit-m-r101x1-ilsvrc2012-classification-1",
        "xgboost-classification-model",
    ]

    default_listing = list_jumpstart_models()
    assert default_listing == expected_ids

    explicit_listing = list_jumpstart_models(list_versions=False)
    assert explicit_listing == expected_ids
def test_list_jumpstart_models_task_filter(
    self, patched_get_model_specs: Mock, patched_get_manifest: Mock
):
    """Filters on ``task`` are answered from the manifest without model-spec lookups.

    For each operator exercised (``==``, ``!=``, ``in``, ``not in``) the manifest
    is expected to be fetched exactly once and no model spec fetched. ``in`` over
    all tasks lists every model; ``not in`` over all tasks lists none.
    """
    patched_get_model_specs.side_effect = get_prototype_model_spec
    patched_get_manifest.side_effect = get_prototype_manifest

    tasks = ["classification", "eqa", "ic", "semseg", "spc"]
    all_model_ids = (
        "catboost-classification-model",
        "huggingface-spc-bert-base-cased",
        "lightgbm-classification-model",
        "mxnet-semseg-fcn-resnet50-ade",
        "pytorch-eqa-bert-base-cased",
        "sklearn-classification-linear",
        "tensorflow-ic-bit-m-r101x1-ilsvrc2012-classification-1",
        "xgboost-classification-model",
    )

    def _check_manifest_only():
        # One manifest fetch and zero spec fetches for the preceding call.
        patched_get_model_specs.assert_not_called()
        patched_get_manifest.assert_called_once()

    def _reset_mocks():
        patched_get_manifest.reset_mock()
        patched_get_model_specs.reset_mock()

    for task in tasks:
        for comparison in ("==", "!="):
            list_jumpstart_models(filter=f"task {comparison} {task}")
            _check_manifest_only()
            _reset_mocks()

    versioned = list_jumpstart_models(filter=f"task in {tasks}", list_versions=True)
    assert versioned == [(model_id, "1.0.0") for model_id in all_model_ids]
    _check_manifest_only()
    _reset_mocks()

    excluded = list_jumpstart_models(filter=f"task not in {tasks}")
    assert excluded == []
    _check_manifest_only()
def test_list_jumpstart_models_simple_case(
    self, patched_get_model_specs: Mock, patched_get_manifest: Mock
):
    """An unfiltered, versioned listing is built from the manifest alone."""
    patched_get_model_specs.side_effect = get_prototype_model_spec
    patched_get_manifest.side_effect = get_prototype_manifest

    model_ids = (
        "catboost-classification-model",
        "huggingface-spc-bert-base-cased",
        "lightgbm-classification-model",
        "mxnet-semseg-fcn-resnet50-ade",
        "pytorch-eqa-bert-base-cased",
        "sklearn-classification-linear",
        "tensorflow-ic-bit-m-r101x1-ilsvrc2012-classification-1",
        "xgboost-classification-model",
    )
    expected = [(model_id, "1.0.0") for model_id in model_ids]

    assert list_jumpstart_models(list_versions=True) == expected

    patched_get_manifest.assert_called()
    patched_get_model_specs.assert_not_called()
def test_list_jumpstart_models_script_filter(
    self, patched_get_model_specs: Mock, patched_get_manifest: Mock
):
    """Filters on ``training_supported`` need one model-spec lookup per manifest entry.

    For each operator exercised (``==``, ``!=``, ``in``, ``not in``) the spec
    getter is expected to be called once per manifest entry and the manifest
    fetched exactly once. ``in [True, False]`` lists every model; ``not in``
    lists none.
    """
    patched_get_model_specs.side_effect = get_prototype_model_spec
    patched_get_manifest.side_effect = get_prototype_manifest

    manifest_length = len(get_prototype_manifest())
    flags = [True, False]
    all_model_ids = (
        "catboost-classification-model",
        "huggingface-spc-bert-base-cased",
        "lightgbm-classification-model",
        "mxnet-semseg-fcn-resnet50-ade",
        "pytorch-eqa-bert-base-cased",
        "sklearn-classification-linear",
        "tensorflow-ic-bit-m-r101x1-ilsvrc2012-classification-1",
        "xgboost-classification-model",
    )

    def _check_spec_per_entry():
        # The preceding call looked up every manifest entry's spec, with one manifest fetch.
        assert patched_get_model_specs.call_count == manifest_length
        patched_get_manifest.assert_called_once()

    def _reset_mocks():
        patched_get_manifest.reset_mock()
        patched_get_model_specs.reset_mock()

    for flag in flags:
        for comparison in ("==", "!="):
            list_jumpstart_models(filter=f"training_supported {comparison} {flag}")
            _check_spec_per_entry()
            _reset_mocks()

    versioned = list_jumpstart_models(
        filter=f"training_supported in {flags}", list_versions=True
    )
    assert versioned == [(model_id, "1.0.0") for model_id in all_model_ids]
    _check_spec_per_entry()
    _reset_mocks()

    excluded = list_jumpstart_models(filter=f"training_supported not in {flags}")
    assert excluded == []
    _check_spec_per_entry()
def test_list_jumpstart_models_old_models(
    self,
    patched_get_model_specs: Mock,
    patched_get_manifest: Mock,
):
    """``list_old_models`` controls whether older versions of each model are listed.

    The patched manifest offers four unsorted versions of every prototypical
    model. With ``list_old_models=True`` every version is listed, newest first
    within each model, with no model-spec lookups and a single manifest fetch.
    With ``list_old_models=False`` only ``2.400.0`` is listed, and that matches
    the default listing.
    """
    manifest_versions = ["2.400.0", "1.4.0", "2.5.1", "1.300.0"]

    def _manifest_with_four_versions(region: str = JUMPSTART_DEFAULT_REGION_NAME):
        # One header per (model, version) pair, versions deliberately out of order.
        return [
            get_header_from_base_header(region=region, model_id=model_id, version=version)
            for model_id in PROTOTYPICAL_MODEL_SPECS_DICT
            for version in manifest_versions
        ]

    patched_get_manifest.side_effect = _manifest_with_four_versions

    model_ids = [
        "catboost-classification-model",
        "huggingface-spc-bert-base-cased",
        "lightgbm-classification-model",
        "mxnet-semseg-fcn-resnet50-ade",
        "pytorch-eqa-bert-base-cased",
        "sklearn-classification-linear",
        "tensorflow-ic-bit-m-r101x1-ilsvrc2012-classification-1",
        "xgboost-classification-model",
    ]
    versions_newest_first = ["2.400.0", "2.5.1", "1.300.0", "1.4.0"]

    every_version = [
        (model_id, version)
        for model_id in model_ids
        for version in versions_newest_first
    ]
    assert every_version == list_jumpstart_models(list_old_models=True, list_versions=True)
    patched_get_model_specs.assert_not_called()
    patched_get_manifest.assert_called_once()

    patched_get_manifest.reset_mock()
    patched_get_model_specs.reset_mock()

    latest_only = [(model_id, "2.400.0") for model_id in model_ids]
    assert latest_only == list_jumpstart_models(list_old_models=False, list_versions=True)

    explicit_no_old = list_jumpstart_models(list_old_models=False, list_versions=True)
    default_listing = list_jumpstart_models(list_versions=True)
    assert explicit_no_old == default_listing