class NotebookConfig(PipelineConfig):
    """Config store for a notebook.

    This config extends the base pipeline config to take into account
    some small differences in the handling of a notebook.
    """

    # Path to the notebook file this config describes.
    notebook_path = Field(type=str, required=True)
    # FIXME: Probably this can be removed. The labextension passes both
    # 'experiment_name' and 'experiment', but the latter is not used in the
    # backend.
    experiment = Field(type=dict)
    # Used in the UI to keep per-notebook state of the volumes snapshot toggle
    snapshot_volumes = Field(type=bool, default=False)
    # override from PipelineConfig: set the default value to False
    autosnapshot = Field(type=bool, default=False)

    @property
    def source_path(self):
        """Get the path to the source notebook."""
        return self.notebook_path

    def _preprocess(self, kwargs):
        # Normalize the raw `steps_defaults` metadata (a list of tag-style
        # strings) into the dict representation used downstream.
        kwargs["steps_defaults"] = self._parse_steps_defaults(
            kwargs.get("steps_defaults"))

    def _parse_steps_defaults(self, steps_defaults):
        """Parse common step configuration defined in the metadata.

        Args:
            steps_defaults: list of strings of the form
                ``<conf_type>:<rest>`` (e.g. ``label:<key>:<value>``).
                Any non-list input is returned unchanged.

        Returns:
            dict: parsed configuration, with (optional) keys
            ``annotations``, ``labels`` and ``limits``.

        Raises:
            ValueError: when an entry does not match any pattern of the
                recognized configuration language.
        """
        result = dict()

        if not isinstance(steps_defaults, list):
            return steps_defaults

        for entry in steps_defaults:
            # `not any(...)` instead of the non-idiomatic `any(...) is False`
            if not any(re.match(pattern, entry)
                       for pattern in _STEPS_DEFAULTS_LANGUAGE):
                raise ValueError("Unrecognized common step configuration:"
                                 " {}".format(entry))

            parts = entry.split(":")
            conf_type = parts.pop(0)

            if conf_type in ["annotation", "label"]:
                result_key = "{}s".format(conf_type)
                key, value = get_annotation_or_label_from_tag(parts)
                # setdefault replaces the explicit create-if-missing check
                result.setdefault(result_key, dict())[key] = value

            if conf_type == "limit":
                key, value = get_limit_from_tag(parts)
                result.setdefault("limits", dict())[key] = value

        return result
class VolumeConfig(Config):
    """Used for validating the `volumes` field of NotebookConfig."""

    name = Field(type=str, required=True,
                 validators=[validators.K8sNameValidator])
    mount_point = Field(type=str, required=True)
    snapshot = Field(type=bool, default=False)
    snapshot_name = Field(type=str)
    size = Field(type=int)  # fixme: validation for this field?
    size_type = Field(type=str)  # fixme: validation for this field?
    type = Field(type=str, required=True,
                 validators=[validators.VolumeTypeValidator])
    annotations = Field(type=list, default=list())

    def _postprocess(self):
        # Convert annotations to a {k: v} dictionary, dropping entries with
        # an empty key or value.
        try:
            # TODO: Make JupyterLab annotate with {k: v} instead of
            # {'key': k, 'value': v}
            self.annotations = {a['key']: a['value']
                                for a in self.annotations
                                if a['key'] != '' and a['value'] != ''}
        except KeyError as e:
            if str(e) in ["'key'", "'value'"]:
                # Chain the original KeyError (`from e`) so the root cause
                # is preserved in the traceback.
                raise ValueError("Volume spec: volume annotations must be a"
                                 " list of {'key': k, 'value': v}"
                                 " dicts") from e
            # Bare `raise` re-raises with the original traceback intact,
            # unlike the previous `raise e`.
            raise
class StepConfig(Config):
    """Config class used for the Step object."""

    # NOTE(review): this class is redefined later in this file with extra
    # retry/timeout fields; that later definition shadows this one, making
    # this copy apparently dead code — confirm and consider removing it.
    name = Field(type=str, required=True,
                 validators=[validators.StepNameValidator])
    # Labels attached to the step; validated as k8s labels
    labels = Field(type=dict, default=dict(),
                   validators=[validators.K8sLabelsValidator])
    # Annotations attached to the step; validated as k8s annotations
    annotations = Field(type=dict, default=dict(),
                        validators=[validators.K8sAnnotationsValidator])
    # Resource limits for the step; validated as k8s limits
    limits = Field(type=dict, default=dict(),
                   validators=[validators.K8sLimitsValidator])
class StepConfig(Config):
    """Config class used for the Step object."""

    name = Field(type=str, required=True,
                 validators=[validators.StepNameValidator])
    # Labels attached to the step; validated as k8s labels
    labels = Field(type=dict, default=dict(),
                   validators=[validators.K8sLabelsValidator])
    # Annotations attached to the step; validated as k8s annotations
    annotations = Field(type=dict, default=dict(),
                        validators=[validators.K8sAnnotationsValidator])
    # Resource limits for the step; validated as k8s limits
    limits = Field(type=dict, default=dict(),
                   validators=[validators.K8sLimitsValidator])
    # Number of retries for the step; 0 means no retries
    retry_count = Field(type=int, default=0)
    # Retry back-off settings. Interval fields are strings — presumably
    # duration strings (e.g. "2m"); confirm against the consumer.
    retry_interval = Field(type=str)
    retry_factor = Field(type=int)
    retry_max_interval = Field(type=str)
    # Step timeout; must be a positive integer (unit not visible here —
    # presumably seconds; confirm against the consumer)
    timeout = Field(type=int, validators=[validators.PositiveIntegerValidator])
class KatibConfig(Config): """Used to validate the `katib_metadata` field of NotebookConfig.""" # fixme: improve validation of single fields parameters = Field(type=list, default=[]) objective = Field(type=dict, default={}) algorithm = Field(type=dict, default={}) # fixme: Change these names to be Pythonic (need to change how the # labextension passes them) maxTrialCount = Field(type=int, default=12) maxFailedTrialCount = Field(type=int, default=3) parallelTrialCount = Field(type=int, default=3)
class PipelineConfig(Config):
    """Main config class to validate the pipeline metadata."""

    pipeline_name = Field(type=str, required=True,
                          validators=[validators.PipelineNameValidator])
    experiment_name = Field(type=str, required=True)
    pipeline_description = Field(type=str, default="")
    docker_image = Field(type=str, default="")
    volumes = Field(type=list, items_config_type=VolumeConfig, default=[])
    katib_run = Field(type=bool, default=False)
    katib_metadata = Field(type=KatibConfig)
    abs_working_dir = Field(type=str, default="")
    marshal_volume = Field(type=bool, default=True)
    marshal_path = Field(type=str, default="/marshal")
    autosnapshot = Field(type=bool, default=True)
    steps_defaults = Field(type=dict, default=dict())
    kfp_host = Field(type=str)
    storage_class_name = Field(type=str,
                               validators=[validators.K8sNameValidator])
    volume_access_mode = Field(
        type=str,
        validators=[validators.IsLowerValidator,
                    validators.VolumeAccessModeValidator])

    @property
    def source_path(self):
        """Get the path to the main entry point script."""
        return utils.get_main_source_path()

    def _postprocess(self):
        # Order matters: storage class and access mode must be resolved
        # before being propagated to the volumes, and the working dir must
        # be set before the marshal path is computed from it.
        self._randomize_pipeline_name()
        self._set_docker_image()
        self._set_volume_storage_class()
        self._set_volume_access_mode()
        self._sort_volumes()
        self._set_abs_working_dir()
        self._set_marshal_path()

    def _randomize_pipeline_name(self):
        # Append a random suffix to avoid name clashes between runs.
        self.pipeline_name = "%s-%s" % (self.pipeline_name,
                                        utils.random_string())

    def _set_docker_image(self):
        """Default the docker image to the current pod's base image."""
        if self.docker_image:
            return
        try:
            self.docker_image = podutils.get_docker_base_image()
        except (ConfigException, FileNotFoundError, ApiException):
            # no K8s config found; use kfp default image
            self.docker_image = ""

    def _set_volume_storage_class(self):
        """Propagate the pipeline-level storage class to the volumes."""
        if not self.storage_class_name:
            return
        for v in self.volumes:
            # NOTE(review): assumes the volume objects expose a
            # `storage_class_name` attribute — not declared in the visible
            # VolumeConfig; confirm against its full definition.
            if not v.storage_class_name:
                v.storage_class_name = self.storage_class_name

    def _set_volume_access_mode(self):
        """Resolve the access mode and propagate it to the volumes."""
        if not self.volume_access_mode:
            self.volume_access_mode = DEFAULT_VOLUME_ACCESS_MODE
        else:
            self.volume_access_mode = VOLUME_ACCESS_MODE_MAP[
                self.volume_access_mode]
        for v in self.volumes:
            if not v.volume_access_mode:
                v.volume_access_mode = self.volume_access_mode

    def _sort_volumes(self):
        # The Jupyter Web App assumes the first volume of the notebook is the
        # working directory, so we make sure to make it appear first in the
        # spec.
        self.volumes = sorted(self.volumes, reverse=True,
                              key=lambda _v: podutils.is_workspace_dir(
                                  _v.mount_point))

    def _set_abs_working_dir(self):
        if not self.abs_working_dir:
            self.abs_working_dir = utils.abs_working_dir(self.source_path)

    @staticmethod
    def _is_subpath(path, parent):
        """Return True if `path` is `parent` or lies below it.

        A plain `str.startswith` prefix test would wrongly match sibling
        directories sharing a name prefix (e.g. `/data2` under mount point
        `/data`), so the comparison is anchored on a path-separator
        boundary.
        """
        parent = parent.rstrip(os.sep)
        return path == parent or path.startswith(parent + os.sep)

    def _set_marshal_path(self):
        """Decide where marshalled data will be stored.

        Check if the workspace directory is under a mounted volume.
        If so, marshal data into a hidden folder in that volume; otherwise
        keep the defaults (create a new volume and mount it at /marshal).
        """
        wd = os.path.realpath(self.abs_working_dir)
        # volumes whose mount point contains the working directory
        vols = [v for v in self.volumes
                if self._is_subpath(wd, v.mount_point)]
        # if we found any, then set marshal directory inside working directory
        if vols:
            basename = os.path.basename(self.source_path)
            marshal_dir = ".{}.kale.marshal.dir".format(basename)
            self.marshal_volume = False
            self.marshal_path = os.path.join(wd, marshal_dir)