Exemplo n.º 1
0
class TestCreate:
    """Entities created inside an organization workspace keep the requested visibility."""

    @pytest.mark.parametrize(
        ("entity_name", "visibility"),
        [
            ("dataset", OrgCustom(write=True)),
            ("project", OrgCustom(write=True, deploy=True)),
            ("registered_model", OrgCustom(write=True, deploy=True)),
        ],
    )
    def test_mdb_entity(self, client, organization, entity_name, visibility):
        """Create an entity via ``client.create_<entity_name>`` and check its visibility.

        The entity is deleted afterwards whether or not the assertion passes.
        """
        factory = getattr(client, "create_{}".format(entity_name))

        created = factory(workspace=organization.name, visibility=visibility)
        try:
            assert_visibility(created, visibility, entity_name)
        finally:
            created.delete()
            # Clear the cached project; otherwise client teardown tries to delete it.
            client._ctx.proj = None

    def test_endpoint(self, client, organization, created_entities):
        """Create an endpoint with ``OrgCustom(write=True)`` and check the visibility applied."""
        org_writable = OrgCustom(write=True)

        new_endpoint = client.create_endpoint(
            path=_utils.generate_default_name(),
            workspace=organization.name,
            visibility=org_writable,
        )
        # Handed to the created_entities fixture (presumably deleted at teardown).
        created_entities.append(new_endpoint)
        assert_endpoint_visibility(new_endpoint, org_writable)
Exemplo n.º 2
0
    def test_endpoint_update_run(self, client_2, client_3, organization, created_entities):
        """Update endpoint from someone else's run.

        ``client_2`` owns the endpoint and ``client_3`` owns the runs. Updating the
        endpoint must be rejected for a run in a private project or in an org project
        with ``deploy=False``. It must succeed when the project grants ``deploy=True``.
        """
        LogisticRegression = pytest.importorskip("sklearn.linear_model").LogisticRegression

        for member in (client_2, client_3):
            organization.add_member(member._conn.email)
        for member in (client_2, client_3):
            member.set_workspace(organization.name)

        endpoint = client_2.create_endpoint(_utils.generate_default_name())
        created_entities.append(endpoint)

        def new_run(project_visibility):
            """Create a client_3 project with *project_visibility* and return a run with a logged model."""
            created_entities.append(client_3.create_project(visibility=project_visibility))
            fresh_run = client_3.create_experiment_run()
            fresh_run.log_model(LogisticRegression(), custom_modules=[])
            fresh_run.log_environment(Python(["scikit-learn"]))
            return fresh_run

        # private run
        private_run = new_run(Private())
        with pytest.raises(requests.HTTPError, match="Access Denied|Forbidden"):
            endpoint.update(private_run)

        # org run, deploy=False
        no_deploy_run = new_run(OrgCustom(deploy=False))
        with pytest.raises(requests.HTTPError, match="Access Denied|Forbidden"):
            endpoint.update(no_deploy_run)

        # org run, deploy=True
        assert endpoint.update(new_run(OrgCustom(deploy=True)))
Exemplo n.º 3
0
    def test_registered_model(self, client, organization, created_entities):
        """``set_registered_model`` in an org workspace stores the expected visibility enum."""
        org_writable = OrgCustom(write=True)
        reg_model = client.set_registered_model(workspace=organization.name, visibility=org_writable)
        created_entities.append(reg_model)

        visibility_enum = _CommonCommonService.VisibilityEnum
        expected = (
            visibility_enum.ORG_SCOPED_PUBLIC
            if org_writable._to_public_within_org()
            else visibility_enum.PRIVATE
        )
        assert reg_model._msg.visibility == expected
Exemplo n.º 4
0
    def test_dataset(self, client, organization, created_entities):
        """``set_dataset`` in an org workspace stores the expected dataset visibility enum."""
        org_writable = OrgCustom(write=True)
        new_dataset = client.set_dataset(workspace=organization.name, visibility=org_writable)
        created_entities.append(new_dataset)

        visibility_enum = _DatasetService.DatasetVisibilityEnum
        expected = (
            visibility_enum.ORG_SCOPED_PUBLIC
            if org_writable._to_public_within_org()
            else visibility_enum.PRIVATE
        )
        assert new_dataset._msg.dataset_visibility == expected
Exemplo n.º 5
0
 def test_dataset(self, client, organization):
     """``set_dataset`` in an org workspace stores the expected visibility; the dataset is always deleted."""
     org_writable = OrgCustom(write=True)
     dataset = client.set_dataset(workspace=organization.name, visibility=org_writable)
     try:
         visibility_enum = _DatasetService.DatasetVisibilityEnum
         if org_writable._to_public_within_org():
             expected = visibility_enum.ORG_SCOPED_PUBLIC
         else:
             expected = visibility_enum.PRIVATE
         assert dataset._msg.dataset_visibility == expected
     finally:
         dataset.delete()
Exemplo n.º 6
0
 def test_project(self, client, organization):
     """``set_project`` in an org workspace stores the expected visibility; the project is always deleted."""
     org_writable = OrgCustom(write=True)
     project = client.set_project(workspace=organization.name, visibility=org_writable)
     try:
         if org_writable._to_public_within_org():
             expected = _ProjectService.ORG_SCOPED_PUBLIC
         else:
             expected = _ProjectService.PRIVATE
         assert project._msg.project_visibility == expected
     finally:
         project.delete()
         # Clear the cached project; otherwise client teardown tries to delete it.
         client._ctx.proj = None
Exemplo n.º 7
0
 def test_registered_model(self, client, organization):
     """``set_registered_model`` in an org workspace stores the expected visibility; always deleted."""
     org_writable = OrgCustom(write=True)
     reg_model = client.set_registered_model(workspace=organization.name, visibility=org_writable)
     try:
         visibility_enum = _CommonCommonService.VisibilityEnum
         if org_writable._to_public_within_org():
             expected = visibility_enum.ORG_SCOPED_PUBLIC
         else:
             expected = visibility_enum.PRIVATE
         assert reg_model._msg.visibility == expected
     finally:
         reg_model.delete()
Exemplo n.º 8
0
    def test_endpoint(self, client, organization, created_entities):
        """``set_endpoint`` in an org workspace records the expected visibility in its creator request."""
        org_writable = OrgCustom(write=True)
        endpoint = client.set_endpoint(
            path=_utils.generate_default_name(),
            workspace=organization.name,
            visibility=org_writable,
        )
        created_entities.append(endpoint)

        payload = endpoint._get_json_by_id(endpoint._conn, endpoint.workspace, endpoint.id)
        expected = "ORG_SCOPED_PUBLIC" if org_writable._to_public_within_org() else "PRIVATE"
        assert payload['creator_request']['visibility'] == expected
Exemplo n.º 9
0
 def test_repository(self, client, organization):
     """``set_repository`` in an org workspace stores the expected visibility; the repo is always deleted."""
     org_writable = OrgCustom(write=True)
     repo = client.set_repository(
         name=_utils.generate_default_name(),
         workspace=organization.name,
         visibility=org_writable,
     )
     try:
         proto = repo._get_proto_by_id(repo._conn, repo.id)
         visibility_enum = _VersioningService.RepositoryVisibilityEnum
         if org_writable._to_public_within_org():
             expected = visibility_enum.ORG_SCOPED_PUBLIC
         else:
             expected = visibility_enum.PRIVATE
         assert proto.repository_visibility == expected
     finally:
         repo.delete()
Exemplo n.º 10
0
    def test_repository(self, client, client_2, organization, created_entities):
        """Check org-visibility access rules for repositories.

        Repositories are exercised through ``set_repository`` because there is no
        ``client.create_repository()`` or ``client.get_repository()``.

        - ``Private()``: another org member cannot retrieve the repository.
        - ``OrgCustom(write=False)``: another org member can retrieve it but not delete it.
        - ``OrgCustom(write=True)``: another org member can retrieve and delete it.
        """
        organization.add_member(client_2._conn.email)
        client.set_workspace(organization.name)
        client_2.set_workspace(organization.name)

        # private
        private_repo = client.set_repository(_utils.generate_default_name(), visibility=Private())
        created_entities.append(private_repo)
        with pytest.raises(Exception, match="unable to get Repository"):
            client_2.set_repository(private_repo.name)

        # read-only
        read_repo = client.set_repository(_utils.generate_default_name(), visibility=OrgCustom(write=False))
        created_entities.append(read_repo)
        retrieved_repo = client_2.set_repository(read_repo.name)
        assert retrieved_repo.id == read_repo.id
        with pytest.raises(requests.HTTPError, match="Access Denied|Forbidden"):
            retrieved_repo.delete()

        # read-write: client_2's delete is the behavior under test, so any failure
        # must propagate and fail the test. The repo is handed to created_entities
        # for cleanup only when that delete did not complete.
        write_repo = client.set_repository(_utils.generate_default_name(), visibility=OrgCustom(write=True))
        deleted = False
        try:
            retrieved_repo = client_2.set_repository(write_repo.name)
            assert retrieved_repo.id == write_repo.id
            retrieved_repo.delete()
            deleted = True
        finally:
            if not deleted:
                created_entities.append(write_repo)
Exemplo n.º 11
0
    def test_endpoint(self, client, organization, created_entities):
        """Create an endpoint with ``OrgCustom(write=True)`` and check the visibility applied."""
        org_writable = OrgCustom(write=True)

        new_endpoint = client.create_endpoint(
            path=_utils.generate_default_name(),
            workspace=organization.name,
            visibility=org_writable,
        )
        created_entities.append(new_endpoint)
        assert_endpoint_visibility(new_endpoint, org_writable)
Exemplo n.º 12
0
    def test_endpoint_update_model_version(self, client_2, client_3,
                                           organization, created_entities):
        """Update endpoint from someone else's model version.

        ``client_2`` owns the endpoint and ``client_3`` owns the registered models.
        Updating the endpoint must be rejected for a version of a private model or
        of an org model with ``deploy=False``. It must succeed with ``deploy=True``.
        """
        LogisticRegression = pytest.importorskip(
            "sklearn.linear_model").LogisticRegression

        for member in (client_2, client_3):
            organization.add_member(member._conn.auth['Grpc-Metadata-email'])
        for member in (client_2, client_3):
            member.set_workspace(organization.name)

        endpoint = client_2.create_endpoint(_utils.generate_default_name())
        created_entities.append(endpoint)

        def new_model_version(model_visibility):
            """Register a client_3 model with *model_visibility*; return a version with a logged model."""
            reg_model = client_3.create_registered_model(visibility=model_visibility)
            created_entities.append(reg_model)
            version = reg_model.create_version()
            version.log_model(LogisticRegression(), custom_modules=[])
            version.log_environment(Python(["scikit-learn"]))
            return version

        # private model version
        private_ver = new_model_version(Private())
        with pytest.raises(requests.HTTPError,
                           match="Access Denied|Forbidden"):
            endpoint.update(private_ver)

        # org model version, deploy=False
        no_deploy_ver = new_model_version(OrgCustom(deploy=False))
        with pytest.raises(requests.HTTPError,
                           match="Access Denied|Forbidden"):
            endpoint.update(no_deploy_ver)

        # org model version, deploy=True
        assert endpoint.update(new_model_version(OrgCustom(deploy=True)))
Exemplo n.º 13
0
class TestSet:
    """``client.set_*`` applies visibility on first creation and ignores it on later sets."""

    @pytest.mark.parametrize(
        ("entity_name", "visibility"),
        [
            ("dataset", OrgCustom(write=True)),
            ("project", OrgCustom(write=True, deploy=True)),
            ("registered_model", OrgCustom(write=True, deploy=True)),
        ],
    )
    def test_mdb_entity(self, client, organization, entity_name, visibility):
        """The first set applies *visibility*; a second set warns and keeps it unchanged."""
        setter = getattr(client, "set_{}".format(entity_name))

        mdb_entity = setter(workspace=organization.name, visibility=visibility)
        try:
            assert_visibility(mdb_entity, visibility, entity_name)

            # second set ignores visibility
            with pytest.warns(UserWarning, match="cannot set"):
                mdb_entity = setter(
                    mdb_entity.name,
                    workspace=organization.name,
                    visibility=Private(),
                )
            assert_visibility(mdb_entity, visibility, entity_name)
        finally:
            mdb_entity.delete()
            # Clear the cached project; otherwise client teardown tries to delete it.
            client._ctx.proj = None

    def test_endpoint(self, client, organization, created_entities):
        """The first set_endpoint applies visibility; a second set warns and keeps it unchanged."""
        org_writable = OrgCustom(write=True)

        endpoint = client.set_endpoint(
            path=_utils.generate_default_name(),
            workspace=organization.name,
            visibility=org_writable,
        )
        created_entities.append(endpoint)
        assert_endpoint_visibility(endpoint, org_writable)

        # second set ignores visibility
        with pytest.warns(UserWarning, match="cannot set"):
            endpoint = client.set_endpoint(
                path=endpoint.path,
                workspace=organization.name,
                visibility=Private(),
            )
        assert_endpoint_visibility(endpoint, org_writable)
Exemplo n.º 14
0
    def test_endpoint(self, client, organization):
        """``create_endpoint`` applies the requested visibility; the endpoint is always deleted."""
        org_writable = OrgCustom(write=True)

        new_endpoint = client.create_endpoint(
            path=_utils.generate_default_name(),
            workspace=organization.name,
            visibility=org_writable,
        )
        try:
            assert_endpoint_visibility(new_endpoint, org_writable)
        finally:
            new_endpoint.delete()
Exemplo n.º 15
0
    def test_endpoint(self, client, organization, created_entities):
        """A second ``set_endpoint`` on the same path warns and keeps the original visibility."""
        org_writable = OrgCustom(write=True)
        endpoint_path = _utils.generate_default_name()

        endpoint = client.set_endpoint(
            path=endpoint_path,
            workspace=organization.name,
            visibility=org_writable,
        )
        created_entities.append(endpoint)
        assert_endpoint_visibility(endpoint, org_writable)

        # Re-setting the existing endpoint with Private() must warn, not change visibility.
        with pytest.warns(UserWarning, match="cannot set"):
            endpoint = client.set_endpoint(
                path=endpoint.path,
                workspace=organization.name,
                visibility=Private(),
            )
        assert_endpoint_visibility(endpoint, org_writable)
Exemplo n.º 16
0
    def test_read_write(self, client, client_2, organization, created_entities, entity_name):
        """Org member can get, and delete.

        ``client`` creates the entity with ``OrgCustom(write=True)``. ``client_2``,
        another org member, must be able to retrieve it by name and delete it.
        """
        organization.add_member(client_2._conn.email)
        client.set_workspace(organization.name)
        client_2.set_workspace(organization.name)
        name = _utils.generate_default_name()
        visibility = OrgCustom(write=True)

        entity = getattr(client, "create_{}".format(entity_name))(name, visibility=visibility)

        # client_2's get+delete is the behavior under test, so any failure must
        # propagate and fail the test. The entity is handed to created_entities for
        # cleanup only when that delete did not complete.
        deleted = False
        try:
            retrieved_entity = getattr(client_2, "get_{}".format(entity_name))(name)
            assert retrieved_entity.id == entity.id
            retrieved_entity.delete()
            deleted = True
        finally:
            if not deleted:
                created_entities.append(entity)
Exemplo n.º 17
0
    def test_read(self, client, client_2, organization, created_entities, entity_name):
        """Org member can get, but not delete.

        ``client`` creates the entity with ``OrgCustom(write=False)``. ``client_2``
        can retrieve it by name, but its delete must be rejected.
        """
        organization.add_member(client_2._conn.email)
        for member in (client, client_2):
            member.set_workspace(organization.name)
        name = _utils.generate_default_name()
        read_only = OrgCustom(write=False)

        creator = getattr(client, "create_{}".format(entity_name))
        entity = creator(name, visibility=read_only)
        created_entities.append(entity)

        fetch = getattr(client_2, "get_{}".format(entity_name))
        retrieved = fetch(name)
        assert retrieved.id == entity.id

        # write=False: deletion by another member must be refused.
        with pytest.raises(requests.HTTPError, match="Access Denied|Forbidden"):
            retrieved.delete()
Exemplo n.º 18
0
    def test_read_registry(self, client, client_2, organization, created_entities):
        """Registry entities erroneously masked 403s in _update().

        With ``OrgCustom(write=False)``, another org member must get an HTTP 403-style
        error when adding a label to the registered model or to one of its versions.
        """
        organization.add_member(client_2._conn.email)
        client.set_workspace(organization.name)
        client_2.set_workspace(organization.name)
        visibility = OrgCustom(write=False)

        reg_model = client.create_registered_model(visibility=visibility)
        # Register for cleanup; previously this model was never deleted.
        # NOTE(review): assumes deleting the registered model also removes model_ver — confirm.
        created_entities.append(reg_model)
        retrieved_reg_model = client_2.get_registered_model(reg_model.name)
        with pytest.raises(requests.HTTPError, match="Access Denied|Forbidden"):
            retrieved_reg_model.add_label("foo")

        model_ver = reg_model.create_version()
        retrieved_model_ver = retrieved_reg_model.get_version(model_ver.name)
        with pytest.raises(requests.HTTPError, match="Access Denied|Forbidden"):
            retrieved_model_ver.add_label("foo")
Exemplo n.º 19
0
    def test_repository(self, client, organization):
        """The first ``set_repository`` applies visibility; a second set leaves it unchanged.

        The repository is always deleted afterwards.
        """
        org_writable = OrgCustom(write=True)

        repo = client.set_repository(
            name=_utils.generate_default_name(),
            workspace=organization.name,
            visibility=org_writable,
        )
        try:
            assert_repository_visibility(repo, org_writable)

            # second set ignores visibility
            repo = client.set_repository(
                name=repo.name,
                workspace=organization.name,
                visibility=Private(),
            )
            assert_repository_visibility(repo, org_writable)
        finally:
            repo.delete()