예제 #1
0
    def __init__(self):
        """Build the MLMD context for the pipeline step running in this pod.

        Connects to MLMD, discovers the current pod, its namespace and the
        Argo workflow that owns it, derives the KFP run ID, pipeline name,
        component ID and execution hash, then creates (or reuses) the run
        context, creates an execution in it, and labels the pod with their
        IDs.
        """
        log.info("%s Initializing MLMD context... %s", "-" * 10, "-" * 10)
        log.info("Connecting to MLMD...")
        self.store = self._connect()
        log.info("Successfully connected to MLMD")
        log.info("Getting step details...")
        log.info("Getting pod name...")
        self.pod_name = podutils.get_pod_name()
        log.info("Successfully retrieved pod name: %s", self.pod_name)
        log.info("Getting pod namespace...")
        self.pod_namespace = podutils.get_namespace()
        log.info("Successfully retrieved pod namespace: %s",
                 self.pod_namespace)
        log.info("Getting pod...")
        self.pod = podutils.get_pod(self.pod_name, self.pod_namespace)
        log.info("Successfully retrieved pod")
        log.info("Getting workflow name from pod...")
        # metadata.labels is None (not {}) when the object carries no labels,
        # so guard before calling .get().
        pod_labels = self.pod.metadata.labels or {}
        self.workflow_name = pod_labels.get(
            workflowutils.ARGO_WORKFLOW_LABEL_KEY)
        log.info("Successfully retrieved workflow name: %s",
                 self.workflow_name)
        log.info("Getting workflow...")
        self.workflow = workflowutils.get_workflow(self.workflow_name,
                                                   self.pod_namespace)
        log.info("Successfully retrieved workflow")

        workflow_metadata = self.workflow["metadata"]
        # `or {}` also covers a key that is present but explicitly null.
        workflow_labels = workflow_metadata.get("labels") or {}
        # Fall back to the workflow name when the KFP run-id label is absent.
        self.run_uuid = workflow_labels.get(podutils.KFP_RUN_ID_LABEL_KEY,
                                            self.workflow_name)
        log.info("Successfully retrieved KFP run ID: %s", self.run_uuid)

        workflow_annotations = workflow_metadata.get("annotations") or {}
        # The pipeline spec is stored as a JSON string annotation; a missing
        # annotation parses as an empty spec.
        pipeline_spec = json.loads(
            workflow_annotations.get("pipelines.kubeflow.org/pipeline_spec",
                                     "{}"))
        self.pipeline_name = pipeline_spec.get("name", self.workflow_name)
        if self.pipeline_name:
            log.info("Successfully retrieved KFP pipeline_name: %s",
                     self.pipeline_name)
        else:
            log.info("Could not retrieve KFP pipeline name")

        self.component_id = podutils.compute_component_id(self.pod)
        pod_annotations = self.pod.metadata.annotations or {}
        self.execution_hash = pod_annotations.get(
            MLMD_EXECUTION_HASH_PROPERTY_KEY)
        if self.execution_hash:
            log.info("Successfully retrieved execution hash: %s",
                     self.execution_hash)
        else:
            # No hash annotation on the pod: use a random one so the
            # execution still gets an identifier.
            self.execution_hash = utils.random_string(10)
            log.info(
                "Failed to retrieve execution hash."
                " Generating random string...: %s", self.execution_hash)

        self.run_context = self._get_or_create_run_context()
        self.execution = self._create_execution_in_run_context()
        self._label_with_context_and_execution()
        log.info("%s Successfully initialized MLMD context %s", "-" * 10,
                 "-" * 10)
예제 #2
0
    def _label_with_context_and_execution(self):
        """Label the current pod with the MLMD execution and run-context IDs.

        Re-fetches the pod so the patch starts from its latest labels. Adds
        the execution-ID and context-ID labels only when they are not already
        set (``setdefault``), then patches the pod's labels.
        """
        self.pod = podutils.get_pod(self.pod_name, self.pod_namespace)
        # metadata.labels is None when the pod has no labels at all; replace
        # it with an empty dict in place so setdefault() below cannot fail
        # and self.pod stays consistent with what gets patched.
        if self.pod.metadata.labels is None:
            self.pod.metadata.labels = {}
        labels = self.pod.metadata.labels

        labels.setdefault(METADATA_EXECUTION_ID_LABEL_KEY,
                          str(self.execution.id))
        labels.setdefault(METADATA_CONTEXT_ID_LABEL_KEY,
                          str(self.run_context.id))
        podutils.patch_pod(self.pod_name, self.pod_namespace,
                           {"metadata": {
                               "labels": labels
                           }})
예제 #3
0
파일: nb.py 프로젝트: ydataai/kale
def find_poddefault_labels_on_server(request):
    """Find server's labels that correspond to poddefaults applied.

    Looks up the pod this code is running in and asks ``kfutils`` which of
    the listed PodDefaults were applied to it. Logs their names and returns
    the labels those PodDefaults contribute, as computed by
    ``kfutils.get_poddefault_labels``.
    """
    logger = request.log
    logger.info("Retrieving PodDefaults applied to server...")

    server_pod = podutils.get_pod(podutils.get_pod_name(),
                                  podutils.get_namespace())
    available = kfutils.list_poddefaults()
    applied = kfutils.find_applied_poddefaults(server_pod, available)
    logger.info("Retrieved applied PodDefaults: %s",
                [poddefault["metadata"]["name"] for poddefault in applied])

    server_labels = kfutils.get_poddefault_labels(applied)
    summary = ", ".join("%s: %s" % (key, value)
                        for key, value in server_labels.items())
    logger.info("PodDefault labels applied on server: %s", summary)
    return server_labels
예제 #4
0
    def _annotate_artifacts(self, annotation_key, ids):
        """Append ``ids`` to the JSON list stored in a pod annotation.

        Re-fetches the pod and reads ``annotation_key`` as a JSON list
        (an empty list when absent). Extends the list with ``ids``, sorts it
        and writes it back as JSON, then patches the pod's annotations.
        Duplicates are kept.

        Args:
            annotation_key: Name of the pod annotation holding the list.
            ids: Iterable of new IDs. They must be JSON-serializable and
                mutually comparable, because the combined list is
                ``sort()``-ed and ``json.dumps``-ed. Presumably these are
                MLMD artifact IDs (ints); TODO confirm against callers.
        """
        self.pod = podutils.get_pod(self.pod_name, self.pod_namespace)
        # metadata.annotations is None when the pod has no annotations;
        # replace it with an empty dict in place so the read/write below
        # cannot fail and self.pod stays consistent with the patch.
        if self.pod.metadata.annotations is None:
            self.pod.metadata.annotations = {}
        annotations = self.pod.metadata.annotations

        all_ids = json.loads(annotations.get(annotation_key, "[]"))
        all_ids.extend(ids)
        all_ids.sort()
        annotations[annotation_key] = json.dumps(all_ids)

        podutils.patch_pod(self.pod_name, self.pod_namespace,
                           {"metadata": {
                               "annotations": annotations
                           }})