def test_list_jumpstart_models_deprecated_models(
        self,
        patched_get_model_specs: Mock,
        patched_get_manifest: Mock,
    ):
        """A ``deprecated equals false`` filter drops every deprecated model.

        Every spec is patched to report ``deprecated = True``, so the filter must
        match nothing and must look up each spec exactly once. An unfiltered
        listing must be non-empty and needs no spec lookups.
        """

        def _always_deprecated(*args, **kwargs):
            # Start from the prototype spec and force the deprecated flag on.
            model_spec = get_prototype_model_spec(*args, **kwargs)
            model_spec.deprecated = True
            return model_spec

        patched_get_manifest.side_effect = get_prototype_manifest
        patched_get_model_specs.side_effect = _always_deprecated
        expected_spec_calls = len(PROTOTYPICAL_MODEL_SPECS_DICT)

        filtered = list_jumpstart_models("deprecated equals false")
        assert filtered == []
        assert patched_get_model_specs.call_count == expected_spec_calls
        patched_get_manifest.assert_called_once()

        for mocked in (patched_get_manifest, patched_get_model_specs):
            mocked.reset_mock()

        unfiltered = list_jumpstart_models()
        assert unfiltered != []
        assert patched_get_model_specs.call_count == 0
    def test_list_jumpstart_models_unsupported_models(
        self,
        patched_get_sagemaker_version: Mock,
        patched_get_model_specs: Mock,
        patched_get_manifest: Mock,
    ):
        """``supported_model`` filtering follows the (patched) SageMaker version.

        With version ``0.0.0`` a ``supported_model == True`` filter, alone or
        combined with another clause, matches nothing and fetches no model spec.
        With a very large version, supported models are listed. A spec-level
        filter on ``training_supported`` does fetch specs.
        """
        patched_get_model_specs.side_effect = get_prototype_model_spec
        patched_get_manifest.side_effect = get_prototype_manifest

        # An ancient SDK version: nothing should count as supported.
        patched_get_sagemaker_version.return_value = "0.0.0"

        supported_only = "supported_model == True"
        assert list_jumpstart_models(supported_only) == []
        patched_get_model_specs.assert_not_called()

        combined = And(supported_only, "training_supported in [False, True]")
        assert list_jumpstart_models(combined) == []
        patched_get_model_specs.assert_not_called()

        assert list_jumpstart_models("supported_model == False") != []

        # A far-future SDK version: supported models now appear.
        patched_get_sagemaker_version.return_value = "999999.0.0"
        assert list_jumpstart_models("supported_model == True") != []

        patched_get_model_specs.reset_mock()

        assert list_jumpstart_models("training_supported in [False, True]") != []
        patched_get_model_specs.assert_called()
    def test_list_jumpstart_models_multiple_level_index(
        self,
        patched_get_model_specs: Mock,
        patched_get_manifest: Mock,
    ):
        """Filtering on a nested (dotted) spec attribute raises NotImplementedError."""
        patched_get_manifest.side_effect = get_prototype_manifest
        patched_get_model_specs.side_effect = get_prototype_model_spec

        nested_attribute_filter = "hosting_ecr_specs.py_version == py3"
        with pytest.raises(NotImplementedError):
            list_jumpstart_models(nested_attribute_filter)
    def test_list_jumpstart_models_region(
        self, patched_get_model_specs: Mock, patched_get_manifest: Mock
    ):
        """The ``region`` argument is forwarded unchanged to the manifest lookup."""
        patched_get_model_specs.side_effect = get_prototype_model_spec

        def _manifest_for(region):
            # Serve the us-west-2 prototype regardless of the requested region;
            # the test only checks which region was *requested*.
            return get_prototype_manifest(region="us-west-2")

        patched_get_manifest.side_effect = _manifest_for

        list_jumpstart_models(region="some-region")

        patched_get_manifest.assert_called_once_with(region="some-region")
    def test_list_jumpstart_models_vulnerable_models(
        self,
        patched_get_model_specs: Mock,
        patched_get_manifest: Mock,
    ):
        """Models flagged inference- or training-vulnerable are filtered out.

        Each pass patches every spec with one vulnerability flag set to True. The
        combined ``... is false`` filter must then return nothing after looking up
        every spec exactly once. An unfiltered listing needs no spec lookups.
        """
        patched_get_manifest.side_effect = get_prototype_manifest

        def _with_flag(flag_name):
            # Build a spec factory that sets ``flag_name`` to True on each prototype spec.
            def _factory(*args, **kwargs):
                model_spec = get_prototype_model_spec(*args, **kwargs)
                setattr(model_spec, flag_name, True)
                return model_spec

            return _factory

        expected_spec_calls = len(PROTOTYPICAL_MODEL_SPECS_DICT)

        for flag_name in ("inference_vulnerable", "training_vulnerable"):
            patched_get_model_specs.side_effect = _with_flag(flag_name)

            not_vulnerable = And(
                "inference_vulnerable is false", "training_vulnerable is false"
            )
            assert list_jumpstart_models(not_vulnerable) == []
            assert patched_get_model_specs.call_count == expected_spec_calls
            patched_get_manifest.assert_called_once()

            patched_get_manifest.reset_mock()
            patched_get_model_specs.reset_mock()

        assert list_jumpstart_models() != []
        assert patched_get_model_specs.call_count == 0
    def test_list_jumpstart_models_complex_queries(
        self,
        patched_get_model_specs: Mock,
        patched_get_manifest: Mock,
    ):
        """Nested ``And``/``Or``/``Not``/``Identity`` filters compose correctly.

        Both queries below must select exactly the TensorFlow image-classification
        model from the prototype manifest.
        """
        patched_get_model_specs.side_effect = get_prototype_model_spec
        patched_get_manifest.side_effect = get_prototype_manifest

        expected = ["tensorflow-ic-bit-m-r101x1-ilsvrc2012-classification-1"]

        first_query = Or(
            And(
                "task is ic",
                "framework is not huggingface",
                And("training_supported is true", Not("false")),
                "true",
            ),
            "false",
            "unknown",
        )
        assert list_jumpstart_models(first_query) == expected

        # Deeper nesting: an Identity-wrapped conjunction with its own inner Or.
        version_clause = Identity(
            And(
                And("incremental_training_supported==falSE"),
                "true",
                Or("unknown", "version equals 1.0.0"),
            )
        )
        second_query = Or(
            And(
                "task is ic",
                "framework==tensorflow",
                version_clause,
                And("training_supported is true", Not("false")),
                "true",
            ),
            "false",
            "unknown",
        )
        assert list_jumpstart_models(second_query) == expected
    def test_list_jumpstart_models_no_versions(
        self,
        patched_get_model_specs: Mock,
        patched_get_manifest: Mock,
    ):
        """By default, and with ``list_versions=False``, only model IDs are returned."""
        patched_get_model_specs.side_effect = get_prototype_model_spec
        patched_get_manifest.side_effect = get_prototype_manifest

        expected_ids = [
            "catboost-classification-model",
            "huggingface-spc-bert-base-cased",
            "lightgbm-classification-model",
            "mxnet-semseg-fcn-resnet50-ade",
            "pytorch-eqa-bert-base-cased",
            "sklearn-classification-linear",
            "tensorflow-ic-bit-m-r101x1-ilsvrc2012-classification-1",
            "xgboost-classification-model",
        ]

        default_listing = list_jumpstart_models()
        assert default_listing == expected_ids

        explicit_listing = list_jumpstart_models(list_versions=False)
        assert explicit_listing == expected_ids
    def test_list_jumpstart_models_task_filter(
        self, patched_get_model_specs: Mock, patched_get_manifest: Mock
    ):
        """Filters on ``task`` are answered from the manifest alone.

        For each task value, ``task == val`` and ``task != val`` must fetch no
        model spec and read the manifest exactly once. Their results must also
        partition the full model list, and ``==`` must be non-empty because every
        listed task occurs in the prototype model IDs. ``in`` over all tasks
        selects every model; ``not in`` selects none.
        """
        patched_get_model_specs.side_effect = get_prototype_model_spec
        patched_get_manifest.side_effect = get_prototype_manifest

        expected_model_versions = [
            ("catboost-classification-model", "1.0.0"),
            ("huggingface-spc-bert-base-cased", "1.0.0"),
            ("lightgbm-classification-model", "1.0.0"),
            ("mxnet-semseg-fcn-resnet50-ade", "1.0.0"),
            ("pytorch-eqa-bert-base-cased", "1.0.0"),
            ("sklearn-classification-linear", "1.0.0"),
            ("tensorflow-ic-bit-m-r101x1-ilsvrc2012-classification-1", "1.0.0"),
            ("xgboost-classification-model", "1.0.0"),
        ]
        all_model_ids = [model_id for model_id, _ in expected_model_versions]

        vals = [
            "classification",
            "eqa",
            "ic",
            "semseg",
            "spc",
        ]
        for val in vals:
            kwargs = {"filter": f"task == {val}"}
            matching = list_jumpstart_models(**kwargs)
            patched_get_model_specs.assert_not_called()
            patched_get_manifest.assert_called_once()

            patched_get_manifest.reset_mock()
            patched_get_model_specs.reset_mock()

            kwargs = {"filter": f"task != {val}"}
            non_matching = list_jumpstart_models(**kwargs)
            patched_get_model_specs.assert_not_called()
            patched_get_manifest.assert_called_once()

            patched_get_manifest.reset_mock()
            patched_get_model_specs.reset_mock()

            # The results used to be discarded, so a broken filter went unnoticed.
            # ``==`` and ``!=`` must split the catalogue into complementary groups.
            assert matching, f"no model matched task {val}"
            assert set(matching).isdisjoint(non_matching)
            assert sorted(matching + non_matching) == all_model_ids

        kwargs = {"filter": f"task in {vals}", "list_versions": True}
        assert list_jumpstart_models(**kwargs) == expected_model_versions
        patched_get_model_specs.assert_not_called()
        patched_get_manifest.assert_called_once()

        patched_get_manifest.reset_mock()
        patched_get_model_specs.reset_mock()

        kwargs = {"filter": f"task not in {vals}"}
        models = list_jumpstart_models(**kwargs)
        assert [] == models
        patched_get_model_specs.assert_not_called()
        patched_get_manifest.assert_called_once()
    def test_list_jumpstart_models_simple_case(
        self, patched_get_model_specs: Mock, patched_get_manifest: Mock
    ):
        """An unfiltered, versioned listing is built from the manifest alone."""
        patched_get_model_specs.side_effect = get_prototype_model_spec
        patched_get_manifest.side_effect = get_prototype_manifest

        model_ids = [
            "catboost-classification-model",
            "huggingface-spc-bert-base-cased",
            "lightgbm-classification-model",
            "mxnet-semseg-fcn-resnet50-ade",
            "pytorch-eqa-bert-base-cased",
            "sklearn-classification-linear",
            "tensorflow-ic-bit-m-r101x1-ilsvrc2012-classification-1",
            "xgboost-classification-model",
        ]
        expected = [(model_id, "1.0.0") for model_id in model_ids]

        assert list_jumpstart_models(list_versions=True) == expected

        patched_get_manifest.assert_called()
        patched_get_model_specs.assert_not_called()
    def test_list_jumpstart_models_script_filter(
        self, patched_get_model_specs: Mock, patched_get_manifest: Mock
    ):
        """Filters on ``training_supported`` look up one spec per manifest entry.

        For each boolean value, ``== val`` and ``!= val`` must each fetch exactly
        ``len(manifest)`` specs and read the manifest once. Their results must
        also partition the full model list. ``in [True, False]`` selects every
        model; ``not in`` selects none.
        """
        patched_get_model_specs.side_effect = get_prototype_model_spec
        patched_get_manifest.side_effect = get_prototype_manifest

        expected_model_versions = [
            ("catboost-classification-model", "1.0.0"),
            ("huggingface-spc-bert-base-cased", "1.0.0"),
            ("lightgbm-classification-model", "1.0.0"),
            ("mxnet-semseg-fcn-resnet50-ade", "1.0.0"),
            ("pytorch-eqa-bert-base-cased", "1.0.0"),
            ("sklearn-classification-linear", "1.0.0"),
            ("tensorflow-ic-bit-m-r101x1-ilsvrc2012-classification-1", "1.0.0"),
            ("xgboost-classification-model", "1.0.0"),
        ]
        all_model_ids = [model_id for model_id, _ in expected_model_versions]

        manifest_length = len(get_prototype_manifest())
        vals = [True, False]
        for val in vals:
            kwargs = {"filter": f"training_supported == {val}"}
            matching = list_jumpstart_models(**kwargs)
            assert patched_get_model_specs.call_count == manifest_length
            patched_get_manifest.assert_called_once()

            patched_get_manifest.reset_mock()
            patched_get_model_specs.reset_mock()

            kwargs = {"filter": f"training_supported != {val}"}
            non_matching = list_jumpstart_models(**kwargs)
            assert patched_get_model_specs.call_count == manifest_length
            patched_get_manifest.assert_called_once()

            patched_get_manifest.reset_mock()
            patched_get_model_specs.reset_mock()

            # The results used to be discarded; ``==`` and ``!=`` on a boolean
            # attribute must select complementary sets of models.
            assert set(matching).isdisjoint(non_matching)
            assert sorted(matching + non_matching) == all_model_ids

        kwargs = {"filter": f"training_supported in {vals}", "list_versions": True}
        assert list_jumpstart_models(**kwargs) == expected_model_versions
        assert patched_get_model_specs.call_count == manifest_length
        patched_get_manifest.assert_called_once()

        patched_get_manifest.reset_mock()
        patched_get_model_specs.reset_mock()

        kwargs = {"filter": f"training_supported not in {vals}"}
        models = list_jumpstart_models(**kwargs)
        assert [] == models
        assert patched_get_model_specs.call_count == manifest_length
        patched_get_manifest.assert_called_once()
    def test_list_jumpstart_models_old_models(
        self,
        patched_get_model_specs: Mock,
        patched_get_manifest: Mock,
    ):
        """``list_old_models`` switches between all versions and only the newest.

        The manifest is patched to offer four versions of every model, listed out
        of order. With ``list_old_models=True`` every version is returned, newest
        first per model. With ``list_old_models=False`` only the newest version is
        returned, which matches the default behaviour.
        """
        unordered_versions = ["2.400.0", "1.4.0", "2.5.1", "1.300.0"]

        def get_manifest_more_versions(region: str = JUMPSTART_DEFAULT_REGION_NAME):
            # One manifest header per (model, version) pair, model-major order.
            return [
                get_header_from_base_header(region=region, model_id=model_id, version=version)
                for model_id in PROTOTYPICAL_MODEL_SPECS_DICT.keys()
                for version in unordered_versions
            ]

        patched_get_manifest.side_effect = get_manifest_more_versions

        model_ids = [
            "catboost-classification-model",
            "huggingface-spc-bert-base-cased",
            "lightgbm-classification-model",
            "mxnet-semseg-fcn-resnet50-ade",
            "pytorch-eqa-bert-base-cased",
            "sklearn-classification-linear",
            "tensorflow-ic-bit-m-r101x1-ilsvrc2012-classification-1",
            "xgboost-classification-model",
        ]
        newest_first = ["2.400.0", "2.5.1", "1.300.0", "1.4.0"]

        every_version = [
            (model_id, version) for model_id in model_ids for version in newest_first
        ]
        assert every_version == list_jumpstart_models(
            list_old_models=True, list_versions=True
        )

        patched_get_model_specs.assert_not_called()
        patched_get_manifest.assert_called_once()

        patched_get_manifest.reset_mock()
        patched_get_model_specs.reset_mock()

        latest_version = [(model_id, newest_first[0]) for model_id in model_ids]
        assert latest_version == list_jumpstart_models(
            list_old_models=False, list_versions=True
        )

        latest_only = list_jumpstart_models(list_old_models=False, list_versions=True)
        default_listing = list_jumpstart_models(list_versions=True)
        assert latest_only == default_listing