Example #1
0
class DatasetCache(object):
    """A wrapper class around a dataset cache folder.

    Entries live under a per-``prefix`` sub-folder of the user cache directory and are stored as
    ``<sha256>.json`` files.

    Args:
        prefix (str): A cache prefix to cluster all related entry together and avoid eventual collisions.

    """

    def __init__(self, prefix):
        # Resolve the prefixed cache folder and the entry-name pattern used to list it.
        self._path = Path(user_cache_dir(appname='plums')) / prefix
        self._resolver = PathResolver('{key}.json')

        # Make sure the prefixed folder exists before any retrieve() or cache() call.
        self._path.mkdir(parents=True, exist_ok=True)

    @staticmethod
    def hash(*keys):
        """Compute a SHA256 hash string from a tuple of strings.

        Args:
            *keys (str): String keys to hash.

        Returns:
            str: A SHA256 digest of the provided string keys.

        """
        # Keys are concatenated as-is before hashing.
        concatenated = ''.join(keys)
        return sha256(concatenated.encode('utf8')).hexdigest()

    def retrieve(self, *keys):
        """Retrieve a JSON-stored dataset from the cache prefixed folder.

        Args:
            *keys (str): The requested dataset string keys.

        Returns:
            Any: The deserialized JSON-stored dataset object corresponding to the provided keys.

        Raises:
            NotInCacheError: If the provided keys does not match any entry in the cache prefixed folder.

        """
        digest = self.hash(*keys)
        # Map every cached entry's key (extracted from its file name) to its path.
        entries = dict(
            (candidate.match['key'], candidate)
            for candidate in self._resolver.find(self._path)
        )

        if digest not in entries:
            raise NotInCacheError(self._path[-1], digest)

        return load(entries[digest])

    def cache(self, data, *keys):
        """Store a JSON-stored dataset in the cache prefixed folder.

        Args:
            data (Any): A JSON-serializable object to store in the cache.
            *keys (str): The requested dataset string keys.

        """
        # Dump the serialized dataset under its hashed entry name.
        filename = '{}.json'.format(self.hash(*keys))
        dump(data, self._path / filename)
Example #2
0
    def save(self, path, force=False, **kwargs):
        """Save a |Model| to |Path|.

        The destination directory is created if needed. When a valid PMF model with the same
        :attr:`id` already exists at ``path``, files are copied lazily (hash-compared before copy).

        Args:
            path (PathLike): The |Path| where to save.
            force (bool): Optional. Default to ``False``. If path is an existing non-PMF path or a PMF model with the
                same :attr:`id`, do not raise and carry on saving.
            **kwargs: Extra options. ``checkpoints`` (bool, default ``True``) controls whether checkpoint
                files are copied to the destination; it is also forwarded to :meth:`Model.load` when
                inspecting a pre-existing destination model.

        Raises:
            ValueError: If ``path`` points to a file.
            OSError: If ``path`` points to:

                * A non-empty directory which does not contain a PMF model and ``force`` is ``False``.
                * A non-empty directory which contains a PMF model with the same :attr:`id` and ``force`` is ``False``.
                * A non-empty directory which contains a PMF model with a different :attr:`id`.
                * A non-empty directory which contains a PMF model with an invalid metadata file.

        """
        path = Path(path)
        # NOTE(review): a Mock stands in for the (possibly absent) pre-existing destination model so
        # that attribute chains like ``model_dst.checkpoint_collection.get(...)`` below never fail.
        # Since a Mock instance is never None, every ``lazy=model_dst is not None`` below evaluates
        # to True -- presumably the lazy (hash-comparing) copy path is always intended; confirm.
        model_dst = Mock()

        # sanity checks
        if path.exists():
            if path.is_file():
                raise ValueError('Invalid path: {} is a file.'.format(path))

            if (path / 'metadata.yaml').exists():
                # The destination already looks like a PMF model: read and validate its metadata.
                with open(str(path / 'metadata.yaml'), 'r') as f:
                    metadata = yaml.safe_load(f)

                try:
                    metadata = Metadata().validate(metadata)
                except (SchemaError, PlumsValidationError):
                    # If the metadata file happens to be invalid, we might enter uncharted territories we are not
                    # prepared for. Abort !
                    raise OSError(
                        'Invalid path: {} is not a valid PMF metadata file.'.
                        format(path / 'metadata.yaml'))

                if metadata['model']['id'] != self.id:
                    # If the destination model id is different from ours, we might enter uncharted territories we are
                    # not prepared for. Abort !
                    raise OSError(
                        'Invalid path: {} has a different PMF model id '
                        '({} != {}).'.format(path / 'metadata.yaml', self.id,
                                             metadata['model']['id']))

                # Same id: try to load the destination model so unchanged files can be copied lazily.
                try:
                    model_dst = Model.load(path,
                                           checkpoints=kwargs.get(
                                               'checkpoints', True))
                except (SchemaError, PlumsValidationError):
                    if not force:
                        raise OSError(
                            'Invalid path: {} is an invalid PMF model '
                            'with the same model id ({}).'.format(
                                path / 'metadata.yaml', self.id))
                    # Use the insider fail-agnostic back door to load what we can from the model anyway
                    model_dst = Model._init_from_path(path, metadata)
                    # We remove PMF related elements as the previously written model is not valid. Note that
                    # if the deletion fails, we ignore it because a valid PMF model will be written anyway and
                    # we never assume the save destination to be empty.
                    rmtree(path,
                           ignore_errors=True,
                           black_list=('metadata',
                                       model_dst.producer.configuration))
            else:
                # Existing directory without a PMF metadata file: refuse unless explicitly forced.
                if not force:
                    raise OSError(
                        'Invalid path: {} already exists.'.format(path))

        # Initialize destination
        path.mkdir(parents=True, exist_ok=True)

        # Prepare metadata dictionary (filled in progressively as files are written below).
        __metadata__ = {
            'format': {
                'version': self.__version__,
                'producer': {
                    'name': self.producer.name,
                    'version': {
                        'format': self.producer.version.format,
                        'value': self.producer.version.version
                    }
                }
            },
            'model': {
                'name': self.name,
                'id': self.id,
                'training': {
                    'status': self.training.status,
                    'start_epoch': self.training.start_epoch,
                    'start_time': self.training.start_timestamp,
                    'latest_epoch': self.training.latest_epoch,
                    'latest_time': self.training.latest_timestamp,
                    'end_epoch': self.training.end_epoch,
                    'end_time': self.training.end_timestamp,
                    'latest': self.checkpoint_collection.latest,
                    'checkpoints': {}
                },
                'initialisation': None,
                'configuration': {}
            }
        }

        # Initialize directory
        (path / 'data' / 'checkpoints').mkdir(parents=True, exist_ok=True)

        # Save build parameters
        # It should be a rather small file, so blindly overwriting it
        # should be faster than write-in-temp and lazy-copy
        with open(str(path / 'data' / 'build_parameters.yaml'), 'w') as f:
            yaml.safe_dump(self.build_parameters, f)

        # Copy configuration
        configuration_dst = path / self.producer.configuration[-1]
        copy(str(self.producer.configuration),
             str(configuration_dst),
             lazy=model_dst is not None)
        # Add configuration to metadata (path relative to the model root + MD5 checksum).
        __metadata__['model']['configuration'].update({
            'path':
            str(configuration_dst.anchor_to_path(path)),
            'hash':
            md5_checksum(self.producer.configuration)
        })

        # Copy initialisation
        # No initialisation: still create an empty data/initialisation folder.
        if self.initialisation is None:
            (path / 'data' / 'initialisation').mkdir(parents=True,
                                                     exist_ok=True)

        # Checkpoint-file initialisation: copy the file and record it under the 'file' key.
        if isinstance(self.initialisation, Checkpoint):
            (path / 'data' / 'initialisation').mkdir(parents=True,
                                                     exist_ok=True)
            checkpoint_dst = path / 'data' / 'initialisation' / self.initialisation.path[
                -1]
            copy(str(self.initialisation.path),
                 str(checkpoint_dst),
                 lazy=model_dst is not None,
                 src_hash=self.initialisation.hash,
                 dst_hash=getattr(model_dst.initialisation, 'name', None))
            # Add file initialisation to metadata
            __metadata__['model']['initialisation'] = {
                'file': {
                    'name': str(self.initialisation.name),
                    'path': str(checkpoint_dst.anchor_to_path(path)),
                    'hash': self.initialisation.hash
                }
            }

        # Nested-model initialisation: recursively save it (without its checkpoints) and record
        # it under the 'pmf' key.
        if isinstance(self.initialisation, Model):
            self.initialisation.save(path / 'data' / 'initialisation',
                                     force=force,
                                     checkpoints=False)
            # Add PMF initialisation to metadata
            __metadata__['model']['initialisation'] = {
                'pmf': {
                    'name':
                    self.initialisation.name,
                    'id':
                    self.initialisation.id,
                    'path':
                    str((path / 'data' /
                         'initialisation').anchor_to_path(path)),
                    'checkpoint':
                    self.initialisation.checkpoint
                }
            }

        # Copy checkpoint_collection
        for reference, checkpoint in self.checkpoint_collection.items():
            # Destination is None when checkpoint copying is disabled via the 'checkpoints' kwarg.
            checkpoint_dst = path / 'data' / 'checkpoints' / checkpoint.path[-1] \
                if kwargs.get('checkpoints', True) else None
            # Add checkpoint to metadata ('.' marks a checkpoint that was not copied over).
            __metadata__['model']['training']['checkpoints'][reference] = {
                'epoch':
                checkpoint.epoch,
                'path':
                str(checkpoint_dst.anchor_to_path(path)) if kwargs.get(
                    'checkpoints', True) else '.',
                'hash':
                checkpoint.hash
            }
            # If needed (usually), copy file to destination
            if kwargs.get('checkpoints', True):
                copy(str(checkpoint.path),
                     str(checkpoint_dst),
                     lazy=model_dst is not None,
                     src_hash=checkpoint.hash,
                     dst_hash=model_dst.checkpoint_collection.get(
                         checkpoint.name))

        # Save metadata
        with open(str(path / 'metadata.yaml'), 'w') as f:
            yaml.safe_dump(__metadata__, f)