def get_logs(
    self,
    db_session: Session,
    project: str,
    uid: str,
    size: int = -1,
    offset: int = 0,
    source: LogSources = LogSources.AUTO,
) -> typing.Tuple[str, bytes]:
    """
    Read the logs of a run, from the persisted log file or from the k8s logger pod.

    :param db_session: DB session used to read the run document (run must exist).
    :param project: Project name; falls back to mlconf.default_project when empty.
    :param uid: Run UID whose logs are requested.
    :param size: Number of bytes to read from the persisted log file (-1 = read to EOF).
        NOTE(review): size is only honored on the file path, not on the k8s path -
        presumably intentional since pod logs can't be partially fetched; confirm.
    :param offset: Byte offset to start reading from (applied on both paths).
    :param source: Where to read from - persisted file, k8s pod, or AUTO (file wins).
    :return: Tuple with:
        1. str of the run state (so watchers will know whether to continue polling for logs)
        2. bytes of the logs themselves
    """
    project = project or mlrun.mlconf.default_project
    out = b""
    log_file = log_path(project, uid)
    # The run must exist - its state is part of the return value on every path
    data = get_db().read_run(db_session, uid, project)
    if not data:
        log_and_raise(HTTPStatus.NOT_FOUND.value, project=project, uid=uid)
    run_state = data.get("status", {}).get("state", "")
    if log_file.exists() and source in [LogSources.AUTO, LogSources.PERSISTENCY]:
        # In AUTO mode the persisted file takes precedence over the pod
        with log_file.open("rb") as fp:
            fp.seek(offset)
            out = fp.read(size)
    elif source in [LogSources.AUTO, LogSources.K8S]:
        if get_k8s():
            pods = get_k8s().get_logger_pods(project, uid)
            if pods:
                # Only the first logger pod is consulted
                pod, pod_phase = list(pods.items())[0]
                if pod_phase != PodPhases.pending:
                    resp = get_k8s().logs(pod)
                    if resp:
                        out = resp.encode()[offset:]
    return run_state, out
def list_secret_keys(
    project: str,
    provider: schemas.SecretProviderName = schemas.SecretProviderName.vault,
    token: str = Header(None, alias=schemas.HeaderNames.secret_store_token),
):
    """Return the secret keys (not values) of a project for the requested provider.

    Vault requires a token header; k8s forbids one. Any other provider is rejected.
    """
    if provider == schemas.SecretProviderName.vault:
        if not token:
            raise mlrun.errors.MLRunInvalidArgumentError(
                "Vault list project secret keys request without providing token"
            )
        vault = VaultStore(token)
        secrets = vault.get_secrets(None, project=project)
        return schemas.SecretKeysData(
            provider=provider, secret_keys=list(secrets.keys())
        )
    if provider == schemas.SecretProviderName.kubernetes:
        if token:
            raise mlrun.errors.MLRunInvalidArgumentError(
                "Cannot specify token when requesting k8s secret keys")
        if not get_k8s():
            raise mlrun.errors.MLRunInternalServerError(
                "K8s provider cannot be initialized")
        keys = get_k8s().get_project_secret_keys(project) or []
        return schemas.SecretKeysData(provider=provider, secret_keys=keys)
    raise mlrun.errors.MLRunInvalidArgumentError(
        f"Provider requested is not supported. provider = {provider}")
def _mock_create_namespaced_pod(self):
    """Mock pod creation to echo back a succeeded pod, and pod listing to return nothing."""

    def _create_pod_response(namespace, pod):
        # One container that terminated successfully (exit code 0)
        terminated = client.V1ContainerStateTerminated(
            finished_at=datetime.now(timezone.utc), exit_code=0
        )
        container_status = client.V1ContainerStatus(
            state=client.V1ContainerState(terminated=terminated),
            image=self.image_name,
            image_id="must-provide-image-id",
            name=self.name,
            ready=True,
            restart_count=0,
        )
        result = deepcopy(pod)
        result.status = client.V1PodStatus(
            phase=PodPhases.succeeded, container_statuses=[container_status]
        )
        result.metadata.name = "test-pod"
        result.metadata.namespace = namespace
        return result

    get_k8s().v1api.create_namespaced_pod = unittest.mock.Mock(
        side_effect=_create_pod_response
    )
    # Our purpose is not to test the client watching on logs, mock empty list (used in get_logger_pods)
    get_k8s().v1api.list_namespaced_pod = unittest.mock.Mock(
        return_value=client.V1PodList(items=[])
    )
def _assert_delete_namespaced_pods(expected_pod_names: List[str], expected_pod_namespace: str = None):
    """Verify delete_namespaced_pod was invoked for each expected pod (or not at all)."""
    delete_mock = get_k8s().v1api.delete_namespaced_pod
    if not expected_pod_names:
        assert delete_mock.call_count == 0
        return
    expected_calls = [
        unittest.mock.call(pod_name, expected_pod_namespace)
        for pod_name in expected_pod_names
    ]
    delete_mock.assert_has_calls(expected_calls)
def _assert_runtime_handler_list_resources(
    runtime_kind,
    expected_crds=None,
    expected_pods=None,
    expected_services=None,
):
    """List resources via the runtime handler and verify both the k8s API calls
    made and the shape of the returned response."""
    runtime_handler = get_runtime_handler(runtime_kind)
    resources = runtime_handler.list_resources()
    crd_group, crd_version, crd_plural = runtime_handler._get_crd_info()
    label_selector = runtime_handler._get_default_label_selector()
    get_k8s().v1api.list_namespaced_pod.assert_called_once_with(
        get_k8s().resolve_namespace(),
        label_selector=label_selector,
    )
    if expected_crds:
        get_k8s().crdapi.list_namespaced_custom_object.assert_called_once_with(
            crd_group,
            crd_version,
            get_k8s().resolve_namespace(),
            crd_plural,
            label_selector=label_selector,
        )
    if expected_services:
        get_k8s().v1api.list_namespaced_service.assert_called_once_with(
            get_k8s().resolve_namespace(),
            label_selector=label_selector,
        )
    TestRuntimeHandlerBase._assert_list_resources_response(
        resources,
        expected_crds=expected_crds,
        expected_pods=expected_pods,
        expected_services=expected_services,
    )
def setup_method_fixture(self, db: Session, client: fastapi.testclient.TestClient):
    """Mock the k8s API clients for every test.

    Can't live in setup_method because that runs before fixture initialization,
    and we need the client fixture (which needs the db one) to mock k8s stuff.
    """
    k8s_helper = get_k8s()
    k8s_helper.v1api = unittest.mock.Mock()
    k8s_helper.crdapi = unittest.mock.Mock()
    k8s_helper.is_running_inside_kubernetes_cluster = unittest.mock.Mock(
        return_value=True
    )
    # enable inheriting classes to do the same
    self.custom_setup_after_fixtures()
def _execute_run(self, runtime, **kwargs):
    """Execute the given runtime with this test's standard name/project/artifact path.

    :param runtime: The runtime object to run.
    :param kwargs: Extra keyword arguments forwarded to runtime.run().
    """
    # Reset the mock, so that when checking is create_pod was called, no leftovers are there (in case running
    # multiple runs in the same test)
    get_k8s().v1api.create_namespaced_pod.reset_mock()
    runtime.run(
        name=self.name,
        project=self.project,
        artifact_path=self.artifact_path,
        **kwargs,
    )
def _assert_list_namespaced_crds_calls(runtime_handler, expected_number_of_calls: int):
    """Assert the CRD-list call count and that some call used the handler's CRD info."""
    crd_group, crd_version, crd_plural = runtime_handler._get_crd_info()
    crd_list_mock = get_k8s().crdapi.list_namespaced_custom_object
    assert crd_list_mock.call_count == expected_number_of_calls
    crd_list_mock.assert_any_call(
        crd_group,
        crd_version,
        get_k8s().resolve_namespace(),
        crd_plural,
        label_selector=runtime_handler._get_default_label_selector(),
    )
def _assert_run_logs(
    db: Session,
    project: str,
    uid: str,
    expected_log: str,
    logger_pod_name: str = None,
):
    """Assert that the persisted run logs equal expected_log.

    :param db: DB session passed through to the logs CRUD.
    :param project: Project of the run.
    :param uid: UID of the run.
    :param expected_log: The exact log text expected (compared as bytes).
    :param logger_pod_name: When given, also verify the log was read exactly once
        from this pod via the k8s API.
    """
    if logger_pod_name is not None:
        get_k8s().v1api.read_namespaced_pod_log.assert_called_once_with(
            name=logger_pod_name,
            namespace=get_k8s().resolve_namespace(),
        )
    # Read from the persisted file only - the pod read (if any) already happened
    _, log = crud.Logs.get_logs(db, project, uid, source=LogSources.PERSISTENCY)
    assert log == expected_log.encode()
def client() -> Generator:
    """Yield a TestClient with logs routed to a temp dir and monitoring disabled."""
    with TemporaryDirectory(suffix="mlrun-logs") as log_dir:
        mlconf.httpdb.logs_path = log_dir
        mlconf.runs_monitoring_interval = 0
        mlconf.runtimes_cleanup_interval = 0
        # in case some test setup already mocked them, don't override it
        for api_attr in ("v1api", "crdapi"):
            if not hasattr(get_k8s(), api_attr):
                setattr(get_k8s(), api_attr, unittest.mock.Mock())
        with TestClient(app) as test_client:
            yield test_client
def _assert_list_namespaced_pods_calls(
    runtime_handler,
    expected_number_of_calls: int,
    expected_label_selector: str = None,
):
    """Assert the pod-list call count and that some call used the expected selector."""
    pod_list_mock = get_k8s().v1api.list_namespaced_pod
    assert pod_list_mock.call_count == expected_number_of_calls
    if not expected_label_selector:
        # fall back to the handler's default selector
        expected_label_selector = runtime_handler._get_default_label_selector()
    pod_list_mock.assert_any_call(
        get_k8s().resolve_namespace(), label_selector=expected_label_selector
    )
def _mock_create_namespaced_pod(self):
    """Mock create_namespaced_pod to echo back the pod with a succeeded status."""

    def _create_pod_response(namespace, pod):
        # One container that terminated successfully (exit code 0)
        terminated = client.V1ContainerStateTerminated(
            finished_at=datetime.now(timezone.utc), exit_code=0
        )
        container_status = client.V1ContainerStatus(
            state=client.V1ContainerState(terminated=terminated),
            image=self.image_name,
            image_id="must-provide-image-id",
            name=self.name,
            ready=True,
            restart_count=0,
        )
        result = deepcopy(pod)
        result.status = client.V1PodStatus(
            phase=PodPhases.succeeded, container_statuses=[container_status]
        )
        result.metadata.name = "test-pod"
        result.metadata.namespace = namespace
        return result

    get_k8s().v1api.create_namespaced_pod = unittest.mock.Mock(
        side_effect=_create_pod_response
    )
def _mock_list_namespaced_pods(list_pods_call_responses: List[List[client.V1Pod]]):
    """Mock pod listing; each call returns the next response wrapped in a V1PodList.

    :return: The list of V1PodList objects fed to the mock (for later assertions).
    """
    calls = [
        client.V1PodList(items=pods_response)
        for pods_response in list_pods_call_responses
    ]
    get_k8s().v1api.list_namespaced_pod = unittest.mock.Mock(side_effect=calls)
    return calls
def setup_method(self, method):
    """Common per-test setup: fixed namespace, naming constants and vault/azure
    secret fixtures; delegates class-specific setup to custom_setup()."""
    self.namespace = mlconf.namespace = "test-namespace"
    get_k8s().namespace = self.namespace
    self._logger = logger

    # Identifiers used when generating runtimes/pods in the tests
    self.project = "test-project"
    self.name = "test-function"
    self.run_uid = "test_run_uid"
    self.image_name = "mlrun/mlrun:latest"
    self.artifact_path = "/tmp"
    self.function_name_label = "mlrun/name"
    self.code_filename = str(self.assets_path / "sample_function.py")

    # Vault / Azure key-vault secret fixtures
    self.vault_secrets = ["secret1", "secret2", "AWS_KEY"]
    self.vault_secret_value = "secret123!@"
    self.vault_secret_name = "vault-secret"
    self.azure_vault_secrets = ["azure_secret1", "azure_secret2"]
    self.azure_secret_value = "azure-secret-123!@"
    self.azure_vault_secret_name = "k8s-vault-secret"

    test_id = f"{self.__class__.__name__}::{method.__name__}"
    self._logger.info(f"Setting up test {test_id}")
    self.custom_setup()
    self._logger.info(f"Finished setting up test {test_id}")
def setup_method(self, method):
    """Per-test setup with Iguazio-style auto-mount config, naming constants and
    vault/azure secret fixtures; delegates class-specific setup to custom_setup()."""
    self.namespace = mlconf.namespace = "test-namespace"
    get_k8s().namespace = self.namespace
    # set auto-mount to work as if this is an Iguazio system (otherwise it may try to mount PVC)
    mlconf.igz_version = "1.1.1"
    mlconf.storage.auto_mount_type = "auto"
    mlconf.storage.auto_mount_params = ""
    self._logger = logger

    # Identifiers used when generating runtimes/pods in the tests
    self.project = "test-project"
    self.name = "test-function"
    self.run_uid = "test_run_uid"
    self.image_name = "mlrun/mlrun:latest"
    self.artifact_path = "/tmp"
    self.function_name_label = "mlrun/name"
    self.code_filename = str(self.assets_path / "sample_function.py")
    self.requirements_file = str(self.assets_path / "requirements.txt")

    # Vault / Azure key-vault secret fixtures
    self.vault_secrets = ["secret1", "secret2", "AWS_KEY"]
    self.vault_secret_value = "secret123!@"
    self.vault_secret_name = "vault-secret"
    self.azure_vault_secrets = ["azure_secret1", "azure_secret2"]
    self.azure_secret_value = "azure-secret-123!@"
    self.azure_vault_secret_name = "k8s-vault-secret"

    test_id = f"{self.__class__.__name__}::{method.__name__}"
    self._logger.info(f"Setting up test {test_id}")
    self.custom_setup()
    self._logger.info(f"Finished setting up test {test_id}")
def _mock_list_namespaced_crds(crd_dicts_call_responses: List[List[Dict]]):
    """Mock CRD listing; each call returns the next response wrapped as {"items": ...}.

    :return: The list of response dicts fed to the mock (for later assertions).
    """
    calls = [
        {"items": crd_response} for crd_response in crd_dicts_call_responses
    ]
    get_k8s().crdapi.list_namespaced_custom_object = unittest.mock.Mock(
        side_effect=calls
    )
    return calls
def _assert_runtime_handler_list_resources(
    self,
    runtime_kind,
    expected_crds=None,
    expected_pods=None,
    expected_services=None,
    group_by: Optional[
        mlrun.api.schemas.ListRuntimeResourcesGroupByField] = None,
):
    """List resources (optionally grouped by job) and verify the k8s API calls
    made and the shape of the returned response."""
    runtime_handler = get_runtime_handler(runtime_kind)
    default_label_selector = runtime_handler._get_default_label_selector()
    if group_by is None:
        project = "*"
        label_selector = default_label_selector
        assertion_func = TestRuntimeHandlerBase._assert_list_resources_response
    elif group_by == mlrun.api.schemas.ListRuntimeResourcesGroupByField.job:
        # Grouped listing is scoped to this test's project
        project = self.project
        label_selector = ",".join(
            [default_label_selector, f"mlrun/project={self.project}"]
        )
        assertion_func = (
            TestRuntimeHandlerBase._assert_list_resources_grouped_response
        )
    else:
        raise NotImplementedError("Unsupported group by value")
    resources = runtime_handler.list_resources(project, group_by=group_by)
    crd_group, crd_version, crd_plural = runtime_handler._get_crd_info()
    get_k8s().v1api.list_namespaced_pod.assert_called_once_with(
        get_k8s().resolve_namespace(),
        label_selector=label_selector,
    )
    if expected_crds:
        get_k8s().crdapi.list_namespaced_custom_object.assert_called_once_with(
            crd_group,
            crd_version,
            get_k8s().resolve_namespace(),
            crd_plural,
            label_selector=label_selector,
        )
    if expected_services:
        get_k8s().v1api.list_namespaced_service.assert_called_once_with(
            get_k8s().resolve_namespace(),
            label_selector=label_selector,
        )
    assertion_func(
        resources,
        expected_crds=expected_crds,
        expected_pods=expected_pods,
        expected_services=expected_services,
    )
def delete_project_secrets(
    project: str,
    provider: str,
    secrets: List[str] = Query(None, alias="secret"),
):
    """Delete project secrets (k8s provider only); responds 204 on success."""
    if provider == schemas.SecretProviderName.vault:
        raise mlrun.errors.MLRunInvalidArgumentError(
            f"Delete secret is not implemented for provider {provider}")
    if provider != schemas.SecretProviderName.kubernetes:
        raise mlrun.errors.MLRunInvalidArgumentError(
            f"Provider requested is not supported. provider = {provider}")
    if not get_k8s():
        raise mlrun.errors.MLRunInternalServerError(
            "K8s provider cannot be initialized")
    get_k8s().delete_project_secrets(project, secrets)
    return Response(status_code=HTTPStatus.NO_CONTENT.value)
def test_serving_with_secrets_remote_build(self, db: Session, client: TestClient):
    """Deploy a serving function through the build API (simulating a remote build)
    and verify the deploy happened with the expected basic config.

    Mocks _get_project_secrets_raw_data so secret lookups return nothing.
    """
    orig_function = get_k8s()._get_project_secrets_raw_data
    get_k8s()._get_project_secrets_raw_data = unittest.mock.Mock(return_value={})
    # BUGFIX: restore the mock even when an assertion fails - previously a failing
    # assert skipped the restore line and leaked the mock into subsequent tests.
    try:
        function = self._create_serving_function()

        # Simulate a remote build by issuing client's API. Code below is taken from httpdb.
        req = {
            "function": function.to_dict(),
            "with_mlrun": "no",
            "mlrun_version_specifier": "0.6.0",
        }
        response = client.post("/api/build/function", json=req)
        assert response.status_code == HTTPStatus.OK.value

        self._assert_deploy_called_basic_config(expected_class=self.class_name)
    finally:
        get_k8s()._get_project_secrets_raw_data = orig_function
def _assert_delete_namespaced_custom_objects(
    runtime_handler,
    expected_custom_object_names: List[str],
    expected_custom_object_namespace: str = None,
):
    """Verify delete_namespaced_custom_object was invoked for each expected name
    (or not at all when the expected list is empty)."""
    delete_mock = get_k8s().crdapi.delete_namespaced_custom_object
    if not expected_custom_object_names:
        assert delete_mock.call_count == 0
        return
    crd_group, crd_version, crd_plural = runtime_handler._get_crd_info()
    expected_calls = [
        unittest.mock.call(
            crd_group,
            crd_version,
            expected_custom_object_namespace,
            crd_plural,
            object_name,
        )
        for object_name in expected_custom_object_names
    ]
    delete_mock.assert_has_calls(expected_calls)
def _mock_vault_functionality(self):
    """Mock vault secret retrieval and the service account that carries the
    vault token secret reference."""
    secret_dict = {key: "secret" for key in self.vault_secrets}
    VaultStore.get_secrets = unittest.mock.Mock(return_value=secret_dict)

    service_account = client.V1ServiceAccount(
        metadata=client.V1ObjectMeta(
            name="test-service-account", namespace=self.namespace
        ),
        secrets=[
            client.V1ObjectReference(
                name=self.vault_secret_name, namespace=self.namespace
            )
        ],
    )
    get_k8s().v1api.read_namespaced_service_account = unittest.mock.Mock(
        return_value=service_account
    )
def test_run_with_k8s_secrets(self, db: Session, client: TestClient):
    """Run a task with k8s secrets and verify the pod gets env vars that draw
    their values from the project's k8s secret."""
    project_secret_name = "dummy_secret_name"
    secret_keys = ["secret1", "secret2", "secret3"]

    # Need to do some mocking, so code thinks that the secret contains these keys. Otherwise it will not add
    # the env. variables to the pod spec.
    get_k8s().get_project_secret_name = unittest.mock.Mock(
        return_value=project_secret_name)
    get_k8s().get_project_secret_keys = unittest.mock.Mock(
        return_value=secret_keys)

    runtime = self._generate_runtime()
    task = self._generate_task()
    task.metadata.project = self.project
    secret_source = {
        "kind": "kubernetes",
        "source": secret_keys,
    }
    task.with_secrets(secret_source["kind"], secret_keys)

    # What we expect in this case is that environment variables will be added to the pod which get their
    # value from the k8s secret, using the correct keys.
    expected_env_from_secrets = {
        SecretsStore._k8s_env_variable_name_for_secret(key): {
            project_secret_name: key
        }
        for key in secret_keys
    }

    self._execute_run(runtime, runspec=task)
    self._assert_pod_creation_config(
        expected_secrets=secret_source,
        expected_env_from_secrets=expected_env_from_secrets,
    )
def initialize_project_secrets(
    project: str,
    secrets: schemas.SecretsData,
):
    """Create/store project secrets with the requested provider; responds 201."""
    if secrets.provider == schemas.SecretProviderName.vault:
        # Init is idempotent and will do nothing if infra is already in place
        init_project_vault_configuration(project)
        # If no secrets were passed, no need to touch the actual secrets.
        if secrets.secrets:
            add_vault_project_secrets(project, secrets.secrets)
    elif secrets.provider == schemas.SecretProviderName.kubernetes:
        # K8s secrets is the only other option right now
        if not get_k8s():
            raise mlrun.errors.MLRunInternalServerError(
                "K8s provider cannot be initialized")
        get_k8s().store_project_secrets(project, secrets.secrets)
    else:
        raise mlrun.errors.MLRunInvalidArgumentError(
            f"Provider requested is not supported. provider = {secrets.provider}"
        )
    return Response(status_code=HTTPStatus.CREATED.value)
def _generate_mpijob_crd(project, uid, status=None):
    """Build a minimal mpijob CRD dict labeled for the given project/uid;
    status is attached only when provided."""
    labels = {
        "mlrun/class": "mpijob",
        "mlrun/function": "trainer",
        "mlrun/name": "train",
        "mlrun/project": project,
        "mlrun/scrape_metrics": "False",
        "mlrun/tag": "latest",
        "mlrun/uid": uid,
    }
    crd_dict = {
        "metadata": {
            "name": "train-eaf63df8",
            "namespace": get_k8s().resolve_namespace(),
            "labels": labels,
        },
    }
    if status is not None:
        crd_dict["status"] = status
    return crd_dict
def _generate_pod(name, labels, phase=PodPhases.succeeded):
    """Build a V1Pod with one successfully-terminated container and the given
    name, labels and phase."""
    container_status = client.V1ContainerStatus(
        state=client.V1ContainerState(
            terminated=client.V1ContainerStateTerminated(
                finished_at=datetime.now(timezone.utc), exit_code=0
            )
        ),
        image="must/provide:image",
        image_id="must-provide-image-id",
        name="must-provide-name",
        ready=True,
        restart_count=0,
    )
    return client.V1Pod(
        metadata=client.V1ObjectMeta(
            name=name, labels=labels, namespace=get_k8s().resolve_namespace()
        ),
        status=client.V1PodStatus(
            phase=phase, container_statuses=[container_status]
        ),
    )
def _generate_sparkjob_crd(project, uid, status=None):
    """Build a sparkjob CRD dict labeled for the given project/uid;
    defaults to a completed status when none is provided."""
    if status is None:
        status = TestSparkjobRuntimeHandler._get_completed_crd_status()
    labels = {
        "mlrun/class": "spark",
        "mlrun/function": "my-spark-jdbc",
        "mlrun/name": "my-spark-jdbc",
        "mlrun/project": project,
        "mlrun/scrape_metrics": "False",
        "mlrun/tag": "latest",
        "mlrun/uid": uid,
    }
    return {
        "metadata": {
            "name": "my-spark-jdbc-2ea432f1",
            "namespace": get_k8s().resolve_namespace(),
            "labels": labels,
        },
        "status": status,
    }
def setup_method(self, method):
    """Minimal per-test setup: fixed namespace and naming constants;
    delegates class-specific setup to custom_setup()."""
    self.namespace = mlconf.namespace = "test-namespace"
    get_k8s().namespace = self.namespace
    self._logger = logger

    # Identifiers used when generating runtimes/pods in the tests
    self.project = "test-project"
    self.name = "test-function"
    self.run_uid = "test_run_uid"
    self.image_name = "mlrun/mlrun:latest"
    self.artifact_path = "/tmp"
    self.function_name_label = "mlrun/name"
    self.code_filename = str(self.assets_path / "sample_function.py")

    test_id = f"{self.__class__.__name__}::{method.__name__}"
    self._logger.info(f"Setting up test {test_id}")
    self.custom_setup()
    self._logger.info(f"Finished setting up test {test_id}")
def _generate_mpijob_crd(project, uid, status=None):
    """Build an mpijob CRD dict labeled for the given project/uid;
    defaults to a succeeded status when none is provided."""
    if status is None:
        status = TestMPIjobRuntimeHandler._get_succeeded_crd_status()
    labels = {
        "mlrun/class": "mpijob",
        "mlrun/function": "trainer",
        "mlrun/name": "train",
        "mlrun/project": project,
        "mlrun/scrape_metrics": "False",
        "mlrun/tag": "latest",
        "mlrun/uid": uid,
    }
    return {
        "metadata": {
            "name": "train-eaf63df8",
            "namespace": get_k8s().resolve_namespace(),
            "labels": labels,
        },
        "status": status,
    }
def pod_create_mock():
    """Fixture that mocks pod creation and the surrounding run machinery, yields
    the create_pod mock, and restores every patched function afterwards.

    Patches: get_k8s().create_pod / _get_project_secrets_raw_data,
    KubejobRuntime._update_run_state, BaseRuntime._wrap_run_result and
    AuthVerifier().authenticate_request.
    """
    # Capture the originals so they can be restored after the test
    create_pod_orig_function = get_k8s().create_pod
    _get_project_secrets_raw_data_orig_function = (
        get_k8s()._get_project_secrets_raw_data)
    update_run_state_orig_function = (
        mlrun.runtimes.kubejob.KubejobRuntime._update_run_state)
    wrap_run_result_orig_function = mlrun.runtimes.base.BaseRuntime._wrap_run_result
    authenticate_request_orig_function = (
        mlrun.api.utils.auth.verifier.AuthVerifier().authenticate_request)

    get_k8s().create_pod = unittest.mock.Mock(return_value=("pod-name", "namespace"))
    get_k8s()._get_project_secrets_raw_data = unittest.mock.Mock(return_value={})
    mlrun.runtimes.kubejob.KubejobRuntime._update_run_state = unittest.mock.Mock()

    # The run result is replaced with a canned run object
    mock_run_object = mlrun.RunObject()
    mock_run_object.metadata.uid = "1234567890"
    mock_run_object.metadata.project = "project-name"
    mlrun.runtimes.base.BaseRuntime._wrap_run_result = unittest.mock.Mock(
        return_value=mock_run_object)

    # Authentication always succeeds with this fixed identity
    auth_info_mock = AuthInfo(
        username=username, session="some-session", data_session=access_key)
    mlrun.api.utils.auth.verifier.AuthVerifier(
    ).authenticate_request = unittest.mock.Mock(return_value=auth_info_mock)

    # BUGFIX: restore inside try/finally - previously a failing test body skipped
    # the restore section entirely, leaking the mocks into every following test.
    try:
        yield get_k8s().create_pod
    finally:
        # Have to revert the mocks, otherwise other tests are failing
        get_k8s().create_pod = create_pod_orig_function
        get_k8s()._get_project_secrets_raw_data = (
            _get_project_secrets_raw_data_orig_function)
        mlrun.runtimes.kubejob.KubejobRuntime._update_run_state = (
            update_run_state_orig_function)
        mlrun.runtimes.base.BaseRuntime._wrap_run_result = wrap_run_result_orig_function
        mlrun.api.utils.auth.verifier.AuthVerifier().authenticate_request = (
            authenticate_request_orig_function)
def _mock_list_services(services):
    """Mock list_namespaced_service to return the given services wrapped in a
    V1ServiceList; returns the services back for convenience."""
    get_k8s().v1api.list_namespaced_service = unittest.mock.Mock(
        return_value=client.V1ServiceList(items=services)
    )
    return services