def test_validate_metadata_missing_required():
    """Check that parse_metadata rejects metadata lacking a required key.

    An empty dict is reported as missing ``experiment_name``. Once that key
    is supplied, the missing ``pipeline_name`` is reported instead.
    """
    cases = (
        ({}, r"Key experiment_name not found.*"),
        ({'experiment_name': 'test'}, r"Key pipeline_name not found.*"),
    )
    for partial_metadata, expected_error in cases:
        with pytest.raises(ValueError, match=expected_error):
            metadata_utils.parse_metadata(partial_metadata)
# Example No. 2
# 0
    def __init__(self,
                 source_notebook_path: str,
                 notebook_metadata_overrides: dict = None,
                 debug: bool = False,
                 auto_snapshot: bool = False):
        """Load a notebook, validate its Kale metadata and set up logging.

        Args:
            source_notebook_path: Path to the source notebook file. It is
                converted to ``str`` and must exist on disk.
            notebook_metadata_overrides: Optional dict whose entries override
                the Kale metadata read from the notebook before validation.
            debug: When True the console handler emits DEBUG records,
                otherwise INFO and above. The log file always gets DEBUG.
            auto_snapshot: Stored on the instance as ``self.auto_snapshot``.

        Raises:
            ValueError: If ``source_notebook_path`` does not exist.
        """
        self.auto_snapshot = auto_snapshot
        self.source_path = str(source_notebook_path)
        if not os.path.exists(self.source_path):
            raise ValueError("Path {} does not exist".format(self.source_path))

        # read notebook
        self.notebook = nb.read(self.source_path, as_version=nb.NO_CONVERT)

        # Read the Kale notebook metadata; fall back to an empty dict when the
        # notebook does not specify any.
        notebook_metadata = self.notebook.metadata.get(
            KALE_NOTEBOOK_METADATA_KEY, dict())
        # Override notebook metadata with the provided arguments.
        # NOTE(review): when the notebook already has Kale metadata, .get()
        # presumably returns the stored dict itself, so the overrides are also
        # written into self.notebook.metadata — confirm this is intended.
        if notebook_metadata_overrides:
            notebook_metadata.update(notebook_metadata_overrides)

        # validate metadata and apply transformations when needed
        self.pipeline_metadata = parse_metadata(notebook_metadata)

        # used to set container step working dir same as current environment
        abs_working_dir = utils.get_abs_working_dir(self.source_path)
        self.pipeline_metadata['abs_working_dir'] = abs_working_dir
        self.detect_environment()

        # setup logging
        self.logger = logging.getLogger("kubeflow-kale")
        formatter = logging.Formatter(
            '%(asctime)s | %(name)s |  %(levelname)s: %(message)s',
            datefmt='%m-%d %H:%M')
        self.logger.setLevel(logging.DEBUG)

        # The "kubeflow-kale" logger is process-wide, so handlers attached by
        # a previously constructed instance are still registered. Detach and
        # close them so each record is not emitted once per instance and the
        # old log file handle is released.
        for old_handler in list(self.logger.handlers):
            if getattr(old_handler, '_kale_handler', False):
                self.logger.removeHandler(old_handler)
                old_handler.close()

        stream_handler = logging.StreamHandler()
        stream_handler.setLevel(logging.DEBUG if debug else logging.INFO)
        stream_handler.setFormatter(formatter)

        self.log_dir_path = "."
        file_handler = logging.FileHandler(
            filename=os.path.join(self.log_dir_path, 'kale.log'), mode='a')
        file_handler.setFormatter(formatter)
        file_handler.setLevel(logging.DEBUG)

        # Tag the handlers so a later instance can find and replace them.
        for new_handler in (file_handler, stream_handler):
            new_handler._kale_handler = True
            self.logger.addHandler(new_handler)

        # mute other loggers
        logging.getLogger('urllib3.connectionpool').setLevel(logging.CRITICAL)

        # Replace all requested cloned volumes with the snapshotted PVCs.
        # The list is sliced into a copy (presumably so create_cloned_volumes
        # cannot alter the parsed metadata's list in place); a missing or
        # empty 'volumes' entry is treated as no volumes.
        volumes = (self.pipeline_metadata.get('volumes') or [])[:]
        self.pipeline_metadata['volumes'] = self.create_cloned_volumes(volumes)
def test_validate_metadata(random_string, metadata, target):
    """Tests metadata is parsed correctly.

    ``random_string`` is a mock whose return value is fixed to ``'rnd'``
    here, so the parsed ``pipeline_name`` is expected to carry a ``-rnd``
    suffix. ``pipeline_name`` and ``experiment_name`` are required fields
    that will always be present in the parsed metadata dict.
    """
    random_string.return_value = 'rnd'
    # Build the expected dict as a copy instead of calling target.update():
    # ``target`` presumably comes from a parametrize table, and mutating it
    # in place would leak state into the shared parameter object.
    expected = {
        **target,
        'pipeline_name': metadata['pipeline_name'] + '-rnd',
        'experiment_name': metadata['experiment_name'],
    }
    assert expected == metadata_utils.parse_metadata(metadata)