コード例 #1
0
ファイル: mng.py プロジェクト: garmin-data/garmdown
class Manager(object):
    """Manages downloading and database work.  This includes downloading data from
    the Garmin connect website and persisting status data in an SQLite
    database.

    """
    def __init__(self, config):
        """Initialize

        :param config: the application configuration
        """
        self.config = config
        self.download = config.download
        self._fetcher = PersistedWork('_fetcher', self, True)

    @property
    @persisted('_fetcher')
    def fetcher(self):
        """The :class:`Fetcher` created from the application configuration."""
        return Fetcher(self.config)

    @property
    @persisted('_persister')
    def persister(self):
        """The :class:`Persister` created from the application configuration."""
        return Persister(self.config)

    @property
    @persisted('_backuper')
    def backuper(self):
        """The :class:`Backuper` created from the application configuration and
        :obj:`persister`.

        """
        return Backuper(self.config, self.persister)

    def environment(self, writer=sys.stdout):
        """Write the configured database file, activities directory and backup
        directory as ``name=value`` lines.

        :param writer: the stream to output, which defaults to stdout

        """
        writer.write(f'database={self.config.db_file}\n')
        writer.write(f'activities={self.config.activities_dir}\n')
        writer.write(f'backup={self.config.db_backup_dir}\n')

    def sync_activities(self, limit=None, start_index=0):
        """Download and add activities to the SQLite database.  Note that this does not
        download the TCX files.

        :param limit: the number of activities to download
        :param start_index: the 0 based activity index (not contiguous page
            based)

        """
        # acts will be an iterable
        acts = self.fetcher.get_activities(limit, start_index)
        self.persister.insert_activities(acts)

    @staticmethod
    def _tcx_filename(activity):
        """Format a (non-directory) file name for ``activity``."""
        return f'{activity.start_date_str}_{activity.id}.tcx'

    def _download_tcx(self, act, dl_path):
        """Download the TCX file of ``act`` to ``dl_path`` and check its size.

        The file is removed again when the download or the size check fails.
        Otherwise a later run would find the partial file on disk and mark
        the activity as downloaded.

        :param act: the activity to download
        :param dl_path: the path of the file to write
        :raises ValueError: if the downloaded file is smaller than the
            configured ``download.min_size``

        """
        logger.debug(f'downloading {dl_path}')
        try:
            with open(dl_path, 'wb') as f:
                self.fetcher.download_tcx(act, f)
            sr = dl_path.stat()
            logger.debug(f'{dl_path} has size {sr.st_size}')
            if sr.st_size < self.download.min_size:
                m = f'downloaded file {dl_path} has size ' + \
                    f'{sr.st_size} < {self.download.min_size}'
                raise ValueError(m)
        except BaseException:
            # always re-raised; this only cleans up the unusable file
            if dl_path.exists():
                logger.warning(f'removing incomplete download {dl_path}')
                dl_path.unlink()
            raise

    def sync_tcx(self, limit=None):
        """Download TCX files and record each successful download as such in the
        database.

        :param limit: the maximum number of TCX files to download, which
            defaults to all
        :raises ValueError: if a downloaded file is smaller than the
            configured ``download.min_size``

        """
        dl_dir = self.config.activities_dir
        persister = self.persister
        if not dl_dir.exists():
            logger.info(f'creating download directory {dl_dir}')
            # exist_ok: tolerate the directory appearing after the check
            dl_dir.mkdir(parents=True, exist_ok=True)
        acts = persister.get_missing_downloaded(limit)
        logger.info(f'downloading {len(acts)} tcx files')
        for act in acts:
            dl_path = Path(dl_dir, self._tcx_filename(act))
            if dl_path.exists():
                logger.warning(f'activity {act.id} is downloaded ' +
                               'but not marked--marking now')
            else:
                self._download_tcx(act, dl_path)
            persister.mark_downloaded(act)

    def import_tcx(self, limit=None):
        """Copy downloaded TCX files to the import directory and record each
        successful copy as imported in the database.

        :param limit: the maximum number of activities to import, which
            defaults to all

        """
        persister = self.persister
        dl_dir = self.config.activities_dir
        import_dir = self.config.import_dir
        if not import_dir.exists():
            logger.info(f'creating imported directory {import_dir}')
            # exist_ok: tolerate the directory appearing after the check
            import_dir.mkdir(parents=True, exist_ok=True)
        acts = persister.get_missing_imported(limit)
        logger.info(f'importing {len(acts)} activities')
        for act in acts:
            fname = self._tcx_filename(act)
            dl_path = Path(dl_dir, fname)
            import_path = Path(import_dir, fname)
            if import_path.exists():
                logger.warning(f'activity {act.id} is imported ' +
                               'but not marked--marking now')
            else:
                logger.info(f'copying {dl_path} -> {import_path}')
                shutil.copy(dl_path, import_path)
            persister.mark_imported(act)

    def sync(self, limit=None):
        """Sync activities and TCX files, then import the TCX files.

        :param limit: the number of activities and TCX files to download,
            which defaults to the configuration values; the import step is
            invoked without a limit

        """
        self.sync_activities(limit)
        self.sync_tcx(limit)
        self.import_tcx()

    def clean_imported(self, limit=None):
        """Delete all TCX files from the import directory.  This is useful so that
        programs like GoldenCheetah that import them don't have to re-import
        them each time.

        :param limit: not used; kept so the signature stays compatible

        """
        import_dir = self.config.import_dir
        logger.info(f'removing import files from {import_dir}')
        if import_dir.exists():
            for path in import_dir.iterdir():
                logger.info(f'removing {path}')
                path.unlink()

    def write_not_downloaded(self,
                             detail=False,
                             limit=None,
                             writer=sys.stdout):
        """Write human readable formatted data of all activities not yet downloaded.

        :param detail: whether or not to give full information about the
            activity
        :param limit: the number of activities to report on
        :param writer: the stream to output, which defaults to stdout

        """
        for act in self.persister.get_missing_downloaded(limit):
            act.write(writer, detail=detail)

    def write_not_imported(self, detail=False, limit=None, writer=sys.stdout):
        """Write human readable formatted data of all activities not yet imported.

        :param detail: whether or not to give full information about the
            activity
        :param limit: the number of activities to report on
        :param writer: the stream to output, which defaults to stdout

        """
        for act in self.persister.get_missing_imported(limit):
            act.write(writer, detail=detail)

    def close(self):
        """Close all resources allocated by the manager."""
        self.fetcher.close()
        self._fetcher.clear()
コード例 #2
0
class DataframeStash(ReadOnlyStash,
                     Deallocatable,
                     Writable,
                     PrimeableStash,
                     metaclass=ABCMeta):
    """A factory stash that uses a Pandas data frame from which to load.  It uses
    the data frame index as the keys and :class:`pandas.Series` as values.  The
    dataframe is usually constructed by reading a file (e.g. CSV) and doing
    some transformation before using it in an implementation of this stash.

    The dataframe created by :meth:`_get_dataframe` must have a string or
    integer index since keys for all stashes are of type :class:`str`.  The
    index will be mapped to a string if it is an int automatically.

    """
    dataframe_path: Path = field()
    """The path to store the pickled version of the generated dataframe
    created with :meth:`_get_dataframe`.

    """
    def __post_init__(self):
        super().__post_init__()
        Deallocatable.__init__(self)
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f'split stash post init: {self.dataframe_path}')
        self._dataframe = PersistedWork(self.dataframe_path, self, mkdir=True)

    def deallocate(self):
        """Deallocate this stash and the persisted dataframe work."""
        super().deallocate()
        self._dataframe.deallocate()

    @abstractmethod
    def _get_dataframe(self) -> pd.DataFrame:
        """Get or create the dataframe

        """
        pass

    def _prepare_dataframe(self, df: pd.DataFrame) -> pd.DataFrame:
        """Validate the index of ``df`` and convert an integer index to strings.

        Any integer dtype is accepted (not only the platform default ``int``)
        and so are the pandas string dtypes, which do not compare equal to
        ``object``.

        :param df: the dataframe returned by :meth:`_get_dataframe`; its index
                   is replaced in place when it is an integer index
        :return: ``df``
        :raises DataframeError: if the index is neither string nor integer

        """
        dt = df.index.dtype
        dtypes = pd.api.types
        if dt == object or dtypes.is_string_dtype(dt):
            return df
        if not dtypes.is_integer_dtype(dt):
            s = f'Data frame index must be a string or int, but got: {dt}'
            raise DataframeError(s)
        df.index = df.index.map(str)
        return df

    @property
    @persisted('_dataframe')
    def dataframe(self):
        """The dataframe from :meth:`_get_dataframe` after index preparation."""
        df = self._get_dataframe()
        df = self._prepare_dataframe(df)
        return df

    def prime(self):
        """Prime the super class and force the dataframe to be created."""
        super().prime()
        self.dataframe

    def clear(self):
        """Clear the persisted dataframe."""
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug('clearing dataframe stash')
        self._dataframe.clear()

    def load(self, name: str) -> pd.Series:
        """Return the row whose index label is ``name``."""
        return self.dataframe.loc[name]

    def exists(self, name: str) -> bool:
        """Return whether ``name`` is a label in the dataframe index."""
        return name in self.dataframe.index

    def keys(self) -> Iterable[str]:
        """Return the index labels as strings."""
        return map(str, self.dataframe.index)

    def write(self, depth: int = 0, writer: TextIOBase = sys.stdout):
        """Write the row count and the column names of the dataframe.

        :param depth: the indentation depth passed to ``_write_line``
        :param writer: the stream to output, which defaults to stdout

        """
        df = self.dataframe
        self._write_line(f'rows: {df.shape[0]}', depth, writer)
        # column labels are not necessarily strings
        cols = ", ".join(map(str, df.columns))
        self._write_line(f'cols: {cols}', depth, writer)
コード例 #3
0
ファイル: executor.py プロジェクト: plandes/deeplearn
class ModelExecutor(PersistableContainer, Deallocatable, Writable):
    """This class creates and uses a network to train, validate and test the model.
    This class is either configured using a
    :class:`~zensols.config.factory.ConfigFactory` or is unpickled with
    :class:`.ModelManager`.  If the latter, it's from a previously trained (and
    possibly tested) state.

    Typically, after creating a nascent instance, :meth:`train` is called to
    train the model.  This returns the results, but the results are also
    available via the :class:`ResultManager` using the :obj:`model_manager`
    property.  To load previous results, use
    ``executor.result_manager.load()``.

    During training, the training set is used to train the weights of the model
    provided by the executor in the :obj:`model_settings`, then validated using
    the validation set.  When the validation loss is minimized, the following
    is saved to disk:

        * Settings: :obj:`net_settings`, :obj:`model_settings`,
        * the model weights,
        * the results of the training and validation thus far,
        * the entire configuration (which is later used to restore the
          executor),
        * random seed information, which includes Python, Torch and GPU random
          state.

    After the model is trained, you can immediately test the model with
    :meth:`test`.  To be more certain of being able to reproduce the same
    results, it is recommended to load the model with
    ``model_manager.load_executor()``, which loads the last instance of the
    model that produced a minimum validation loss.

    :see: :class:`.ModelExecutor`
    :see: :class:`.NetworkSettings`
    :see: :class:`zensols.deeplearn.model.ModelSettings`

    """
    ATTR_EXP_META = ('model_settings', )

    config_factory: ConfigFactory = field()
    """The configuration factory that created this instance."""

    config: Configurable = field()
    """The configuration used in the configuration factory to create this
    instance.

    """
    name: str = field()
    """The name given in the configuration."""

    model_settings: ModelSettings = field()
    """The configuration of the model."""

    net_settings: NetworkSettings = field()
    """The settings used to configure the network."""

    dataset_stash: DatasetSplitStash = field()
    """The split data set stash that contains the ``BatchStash``, which
    contains the batches on which to train and test.

    """
    dataset_split_names: List[str] = field()
    """The list of split names in the ``dataset_stash`` in the order: train,
    validation, test (see :meth:`_get_dataset_splits`)

    """
    result_path: Path = field(default=None)
    """If not ``None``, a path to a directory where the results are to be
    dumped; the directory will be created if it doesn't exist when the results
    are generated.

    """
    update_path: Path = field(default=None)
    """The path to check for commands/updates to make while training.  If this is
    set, and the file exists, then it is parsed as a JSON file.  If the file
    cannot be parsed, or 0 size etc., then the training is (early) stopped.

    If the file can be parsed, and there is a single ``epoch`` dict entry, then
    the current epoch is set to that value.

    """
    intermediate_results_path: Path = field(default=None)
    """If this is set, then save the model and results to this path after
    validation for each training epoch.

    """
    progress_bar: bool = field(default=False)
    """Create text/ASCII based progress bar if ``True``."""

    progress_bar_cols: int = field(default=None)
    """The number of console columns to use for the text/ASCII based progress
    bar.

    """
    def __post_init__(self):
        """Initialize the non-field state of the executor: the model
        references, the model result, the persisted work items and the batch
        cache.

        """
        super().__init__()
        # NOTE(review): the trailing ``and False`` makes this type check
        # unreachable, so ``dataset_stash`` is never validated here
        if not isinstance(self.dataset_stash, DatasetSplitStash) and False:
            raise ModelError('Expecting type DatasetSplitStash but ' +
                             f'got {self.dataset_stash.__class__}')
        self._model = None
        # whether this executor deallocates the model (see _deallocate_model)
        self._dealloc_model = False
        self.model_result: ModelResult = None
        self.batch_stash.delegate_attr: bool = True
        self._criterion_optimizer_scheduler = PersistedWork(
            '_criterion_optimizer_scheduler', self)
        self._result_manager = PersistedWork('_result_manager', self)
        self._train_manager = PersistedWork('_train_manager', self)
        self.cached_batches = {}
        # assigned through the ``debug`` property setter, which also sets the
        # flag on ``batch_iterator``
        self.debug = False

    @property
    def batch_stash(self) -> DatasetSplitStash:
        """The stash used to obtain the data for training and testing.  This stash
        should have a training, validation and test splits.  The names of these
        splits are given in the ``dataset_split_names``.

        """
        return self.dataset_stash.split_container

    @property
    def feature_stash(self) -> Stash:
        """The stash used to generate the feature, which is not to be confused
        with the batch source stash``batch_stash``.

        """
        return self.batch_stash.split_stash_container

    @property
    def torch_config(self) -> TorchConfig:
        """Return the PyTorch configuration used to convert models and data (usually
        GPU) during training and test.

        """
        return self.batch_stash.model_torch_config

    @property
    @persisted('_result_manager')
    def result_manager(self) -> ModelResultManager:
        """The manager used for controlling the life cycle of the results
        generated by this executor, or ``None`` when :obj:`result_path` is
        not set.

        """
        path = self.result_path
        if path is None:
            return None
        return self._create_result_manager(path)

    def _create_result_manager(self, path: Path) -> ModelResultManager:
        """Create a result manager that writes to ``path``, using the model
        name and model path found in :obj:`model_settings`.

        """
        settings = self.model_settings
        return ModelResultManager(
            name=settings.model_name, path=path, model_path=settings.path)

    @property
    @persisted('_model_manager')
    def model_manager(self) -> ModelManager:
        """The model manager, created with the model path of
        :obj:`model_settings`, the configuration factory and the name of
        this executor.

        """
        return ModelManager(
            self.model_settings.path, self.config_factory, self.name)

    @property
    @persisted('_batch_iterator')
    def batch_iterator(self) -> BatchIterator:
        """The batch iterator, an instance of the class named by
        ``batch_iteration_class_name`` in :obj:`model_settings` that is
        created with this executor and the module logger.

        """
        class_resolver = self.config_factory.class_resolver
        class_name = self.model_settings.batch_iteration_class_name
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f'batch_iteration: {class_name}')
        iterator = class_resolver.find_class(class_name)(self, logger)
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f'batch_iter={iterator}')
        return iterator

    @property
    def debug(self) -> Union[bool, int]:
        return self._debug

    @debug.setter
    def debug(self, debug: Union[bool, int]):
        self._debug = debug
        self.batch_iterator.debug = debug

    @property
    @persisted('_train_manager')
    def train_manager(self) -> TrainManager:
        """The train manager that assists with the training process, created
        with the loggers, :obj:`update_path` and the
        ``max_consecutive_increased_count`` of :obj:`model_settings`.

        """
        max_increased = self.model_settings.max_consecutive_increased_count
        return TrainManager(
            logger, progress_logger, self.update_path, max_increased)

    def _weight_reset(self, m):
        """Call ``reset_parameters`` on ``m`` when it has such a callable
        attribute; otherwise do nothing.

        """
        reset_fn = getattr(m, 'reset_parameters', None)
        if not callable(reset_fn):
            return
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f'resetting parameters on {m}')
        reset_fn()

    def reset(self):
        """Reset the executor to its nascent state by clearing the cached
        criterion/optimizer/scheduler and releasing the model.

        """
        info_enabled = logger.isEnabledFor(logging.INFO)
        if info_enabled:
            logger.info('resetting executor')
        self._criterion_optimizer_scheduler.clear()
        self._deallocate_model()

    def load(self) -> nn.Module:
        """Release the current model and reload the last trained model from the
        file system through the model manager.

        :return: the model that was loaded and registered in this instance of
                 the executor

        """
        if logger.isEnabledFor(logging.INFO):
            logger.info('reloading model weights')
        self._deallocate_model()
        manager = self.model_manager
        manager._load_model_optim_weights(self)
        loaded = self.model
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f'copied model to {loaded.device}')
        return loaded

    def deallocate(self):
        """Deallocate the model, the cached batches, the dataset stash, the
        settings and the persisted work items, then drop the model result.

        """
        super().deallocate()
        self._deallocate_model()
        self.deallocate_batches()
        self._try_deallocate(self.dataset_stash)
        self._deallocate_settings()
        works = (self._criterion_optimizer_scheduler, self._result_manager)
        for work in works:
            work.deallocate()
        self.model_result = None

    def _deallocate_model(self):
        """Drop the reference to the model, deallocating it first when this
        executor has been flagged to do so (``_dealloc_model``).

        """
        has_model = self._model is not None
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug('dealloc model: model exists/dealloc: ' +
                         f'{has_model}/{self._dealloc_model}')
        if has_model and self._dealloc_model:
            self._try_deallocate(self._model)
        self._model = None

    def _deallocate_settings(self):
        self.model_settings.deallocate()
        self.net_settings.deallocate()

    def deallocate_batches(self):
        set_of_ds_sets = self.cached_batches.values()
        ds_sets = chain.from_iterable(set_of_ds_sets)
        batches = chain.from_iterable(ds_sets)
        for batch in batches:
            batch.deallocate()
        self.cached_batches.clear()

    @property
    def model_exists(self) -> bool:
        """Return whether the executor has a model.

        :return: ``True`` if the model has been trained or loaded

        """
        return self._model is not None

    @property
    def model(self) -> BaseNetworkModule:
        """Get the PyTorch module that is used for training and test.

        """
        if self._model is None:
            raise ModelError('No model, is populated; use \'load\'')
        return self._model

    @model.setter
    def model(self, model: BaseNetworkModule):
        """Set the PyTorch module that is used for training and test.

        """
        self._set_model(model, False, True)

    def _set_model(self, model: BaseNetworkModule, take_owner: bool,
                   deallocate: bool):
        """Register ``model`` on this executor and clear the cached
        criterion/optimizer/scheduler.

        :param model: the model to set
        :param take_owner: whether this executor deallocates the model later
        :param deallocate: whether to release the current model first

        """
        debug_level = logging.DEBUG
        if logger.isEnabledFor(level=debug_level):
            logger.debug(f'setting model: {type(model)}')
        if deallocate:
            self._deallocate_model()
        self._model, self._dealloc_model = model, take_owner
        if logger.isEnabledFor(debug_level):
            logger.debug(f'setting dealloc model: {self._dealloc_model}')
        self._criterion_optimizer_scheduler.clear()

    def _get_or_create_model(self) -> BaseNetworkModule:
        """Return the current model, creating it first when there is none.  A
        model created here is flagged to be deallocated by this executor.

        """
        if self._model is None:
            self._dealloc_model = True
            self._model = self._create_model()
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f'created model as dealloc: {self._dealloc_model}')
        return self._model

    def _create_model(self) -> BaseNetworkModule:
        """Create the network model instance with the model manager from the
        network settings and the debug flag.

        """
        manager: ModelManager = self.model_manager
        module = manager._create_module(self.net_settings, self.debug)
        if logger.isEnabledFor(logging.INFO):
            where = f'created model on {module.device} '
            logger.info(where + f'with {self.torch_config}')
        return module

    def _create_model_result(self) -> ModelResult:
        """Create a new model result named after the model and the current
        number of runs.

        """
        settings = self.model_settings
        result_name = f'{settings.model_name}: {ModelResult.get_num_runs()}'
        result = ModelResult(
            self.config, result_name, settings, self.net_settings,
            self.batch_stash.decoded_attributes)
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f'creating model result ({id(result)}): ' +
                         settings.model_name)
        return result

    @property
    @persisted('_criterion_optimizer_scheduler')
    def criterion_optimizer_scheduler(self) -> \
            Tuple[nn.L1Loss, torch.optim.Optimizer, Any]:
        """Return the loss function, the descent optimizer and the scheduler
        as a tuple, in that order.

        """
        loss_fn = self._create_criterion()
        optim, sched = self._create_optimizer_scheduler()
        return loss_fn, optim, sched

    def _create_criterion(self) -> nn.Module:
        """Factory method to create the loss function (criterion).  The class
        named by ``criterion_class_name`` in :obj:`model_settings` is
        resolved with the configuration factory's class resolver and
        instantiated without arguments.

        :return: the new criterion instance

        """
        resolver = self.config_factory.class_resolver
        criterion_class_name = self.model_settings.criterion_class_name
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f'criterion: {criterion_class_name}')
        criterion_class = resolver.find_class(criterion_class_name)
        criterion = criterion_class()
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f'criterion={criterion}')
        return criterion

    def _create_optimizer_scheduler(self) -> \
            Tuple[torch.optim.Optimizer, Any]:
        """Factory method to create the optimizer and the learning rate scheduler (if
        any).

        The optimizer and scheduler classes are resolved from the class names
        in :obj:`model_settings`.  A class that is a
        :class:`ModelResourceFactory` is instantiated and the instance is
        called to create the resource; any other class is called directly.

        :return: the optimizer and the scheduler, the latter being ``None``
                 when no scheduler class name is configured

        """
        model = self.model
        resolver = self.config_factory.class_resolver
        optimizer_class_name = self.model_settings.optimizer_class_name
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f'optimizer: {optimizer_class_name}')
        optimizer_class = resolver.find_class(optimizer_class_name)
        # copy so the configured parameters are not modified
        if self.model_settings.optimizer_params is None:
            optimizer_params = {}
        else:
            optimizer_params = dict(self.model_settings.optimizer_params)
        optimizer_params['lr'] = self.model_settings.learning_rate
        if issubclass(optimizer_class, ModelResourceFactory):
            # model resource factories are callable
            opt_call = optimizer_class()
            optimizer_params['model'] = model
            optimizer_params['executor'] = self
        else:
            opt_call = optimizer_class
        optimizer = opt_call(model.parameters(), **optimizer_params)
        scheduler_class_name = self.model_settings.scheduler_class_name
        if scheduler_class_name is not None:
            scheduler_class = resolver.find_class(scheduler_class_name)
            scheduler_params = self.model_settings.scheduler_params
            # copy so the configured parameters are not modified
            if scheduler_params is None:
                scheduler_params = {}
            else:
                scheduler_params = dict(scheduler_params)
            scheduler_params['optimizer'] = optimizer
            if issubclass(scheduler_class, ModelResourceFactory):
                # model resource factories are callable
                sch_call = scheduler_class()
                scheduler_params['executor'] = self
            else:
                sch_call = scheduler_class
            scheduler = sch_call(**scheduler_params)
        else:
            scheduler = None
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f'optimizer={optimizer}')
        return optimizer, scheduler

    def get_model_parameter(self, name: str):
        """Return a parameter of the model, found in ``model_settings``.

        """
        return getattr(self.model_settings, name)

    def set_model_parameter(self, name: str, value: Any):
        """Safely set a parameter of the model, found in ``model_settings``.  This
        makes the corresponding update in the configuration, so that when it is
        restored (i.e for test) the parameters are consistent with the trained
        model.  The value is converted to a string as the configuration
        representation stores all data values as strings.

        *Important*: ``eval`` syntaxes are not supported, and probably not the
        kind of values you want to set a parameters with this interface anyway.

        :param name: the name of the value to set, which is the key in the
                     configuration file

        :param value: the value to set on the model and the configuration

        """
        self.config.set_option(name,
                               str(value),
                               section=self.model_settings.name)
        setattr(self.model_settings, name, value)

    def get_network_parameter(self, name: str):
        """Return a parameter of the network, found in ``network_settings``.

        """
        return getattr(self.net_settings, name)

    def set_network_parameter(self, name: str, value: Any):
        """Safely set a parameter of the network, found in ``network_settings``.  This
        makes the corresponding update in the configuration, so that when it is
        restored (i.e for test) the parameters are consistent with the trained
        network.  The value is converted to a string as the configuration
        representation stores all data values as strings.

        *Important*: ``eval`` syntaxes are not supported, and probably not the
        kind of values you want to set a parameters with this interface anyway.

        :param name: the name of the value to set, which is the key in the
                     configuration file

        :param value: the value to set on the network and the configuration

        """
        self.config.set_option(name,
                               str(value),
                               section=self.net_settings.name)
        setattr(self.net_settings, name, value)

    def _to_iter(self, ds):
        """Return the values of ``ds`` when it is a :class:`Stash`, otherwise
        ``ds`` itself.

        """
        return ds.values() if isinstance(ds, Stash) else ds

    def _gc(self, level: int):
        """Invoke the Python garbage collector when ``level`` is at most the
        ``gc_level`` of the model settings.  The *lower* the value of
        ``level``, the more often it will be run during training, testing
        and validation.  Observers are notified before and after the
        collection.

        :param level: the priority of the need to collect--the lower the more
                      it is needed

        """
        if not (level <= self.model_settings.gc_level):
            return
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug('garbage collecting')
        self._notify('gc_start')
        with time('garbage collected', logging.DEBUG):
            gc.collect()
        self._notify('gc_end')

    def _notify(self, event: str, context: Any = None):
        """Notify observers of events from this class.

        """
        self.model_settings.observer_manager.notify(event, self, context)

    def _train(self, train: List[Batch], valid: List[Batch]):
        """Train the network model and record validation and training losses.  Every
        time the validation loss shrinks, the model is saved to disk.

        ``model_result`` must be set on this executor before this is called
        since the epoch results are added to it.

        :param train: the training batches (or a stash of them)
        :param valid: the validation batches (or a stash of them)

        """
        n_epochs = self.model_settings.epochs
        # create network model, loss and optimization functions
        model = self._get_or_create_model()
        model = self.torch_config.to(model)
        self._model = model
        if logger.isEnabledFor(logging.INFO):
            logger.info(f'training model {type(model)} on {model.device} ' +
                        f'for {n_epochs} epochs using ' +
                        f'learning rate {self.model_settings.learning_rate}')
        criterion, optimizer, scheduler = self.criterion_optimizer_scheduler
        # create a second module manager for after epoch results
        if self.intermediate_results_path is not None:
            model_path = self.intermediate_results_path
            intermediate_manager = self._create_result_manager(model_path)
            intermediate_manager.file_pattern = '{prefix}.{ext}'
        else:
            intermediate_manager = None
        train_manager = self.train_manager
        action = UpdateAction.ITERATE_EPOCH
        # set up graphical progress bar, but only when neither logger is set
        # to INFO or lower (level 0 means the level is not set)
        exec_logger = logging.getLogger(__name__)
        if self.progress_bar and \
            (exec_logger.level == 0 or
             exec_logger.level > logging.INFO) and \
            (progress_logger.level == 0 or
             progress_logger.level > logging.INFO):
            pbar = tqdm(total=n_epochs, ncols=self.progress_bar_cols)
        else:
            pbar = None

        train_manager.start(optimizer, scheduler, n_epochs, pbar)
        self.model_result.train.start()
        self.model_result.validation.start()

        # epochs loop
        while action != UpdateAction.STOP:
            epoch: int = train_manager.current_epoch
            train_epoch_result = EpochResult(epoch, DatasetSplitType.train)
            valid_epoch_result = EpochResult(epoch,
                                             DatasetSplitType.validation)

            # guard with the same level the message is logged at
            if progress_logger.isEnabledFor(logging.DEBUG):
                progress_logger.debug(f'training on epoch: {epoch}')

            self.model_result.train.append(train_epoch_result)
            self.model_result.validation.append(valid_epoch_result)

            # train ----
            # prep model for training and train
            model.train()
            train_epoch_result.start()
            self._notify('train_start', {'epoch': epoch})
            for batch in self._to_iter(train):
                if logger.isEnabledFor(logging.DEBUG):
                    logger.debug(f'training on batch: {batch.id}')
                with time('trained batch', level=logging.DEBUG):
                    self.batch_iterator.iterate(model, optimizer, criterion,
                                                batch, train_epoch_result,
                                                DatasetSplitType.train)
                self._gc(3)
            self._notify('train_end', {'epoch': epoch})
            train_epoch_result.end()

            self._gc(2)

            # validate ----
            # prep model for evaluation and evaluate
            ave_valid_loss = 0
            model.eval()
            valid_epoch_result.start()
            self._notify('validation_start', {'epoch': epoch})
            for batch in self._to_iter(valid):
                # forward pass: compute predicted outputs by passing inputs
                # to the model
                with torch.no_grad():
                    loss = self.batch_iterator.iterate(
                        model, optimizer, criterion, batch, valid_epoch_result,
                        DatasetSplitType.validation)
                    ave_valid_loss += (loss.item() * batch.size())
                self._gc(3)
            self._notify('validation_end', {'epoch': epoch})
            valid_epoch_result.end()
            # NOTE(review): the sum is weighted by batch size but divided by
            # the number of batches, and an empty ``valid`` divides by
            # zero--confirm this is intended
            ave_valid_loss = ave_valid_loss / len(valid)

            self._gc(2)

            _, decreased = train_manager.update_loss(
                valid_epoch_result, train_epoch_result, ave_valid_loss)

            if decreased:
                self.model_manager._save_executor(self)
                if intermediate_manager is not None:
                    inter_res = self.model_result.get_intermediate()
                    intermediate_manager.save_text_result(inter_res)
                    intermediate_manager.save_plot_result(inter_res)

            # look for indication of update or early stopping
            status = train_manager.get_status()
            action = status.action

        val_losses = train_manager.validation_loss_decreases
        if logger.isEnabledFor(logging.INFO):
            logger.info('final minimum validation ' +
                        f'loss: {train_manager.valid_loss_min}, ' +
                        f'{val_losses} decreases')

        if val_losses == 0:
            # ``Logger.warn`` is a deprecated alias of ``warning``
            logger.warning(
                'no validation loss decreases encountered, ' +
                'so there was no model saved; model can not be tested')

        self.model_result.train.end()
        self.model_result.validation.end()
        self.model_manager._save_final_trained_results(self)

    def _test(self, batches: List[Batch]):
        """Run the model over ``batches`` as a single test epoch and record the
        outcome in the test portion of :obj:`model_result`.

        A model result is created first if none is set.  The batches are
        iterated with gradient tracking disabled.

        :param batches: the batches on which to run the model

        """
        # only the criterion and optimizer are handed to the batch iterator
        criterion, optimizer, _ = self.criterion_optimizer_scheduler
        model = self.torch_config.to(self.model)
        # the whole test run is tracked as one epoch
        epoch_result = EpochResult(0, DatasetSplitType.test)

        if logger.isEnabledFor(logging.INFO):
            logger.info(f'testing model {type(model)} on {model.device}')

        # if for some reason the model was trained but not tested, there is no
        # result yet, so start with one that has no train results (bad idea)
        if self.model_result is None:
            self.model_result = self._create_model_result()

        self.model_result.reset(DatasetSplitType.test)
        self.model_result.test.start()
        self.model_result.test.append(epoch_result)

        # switch the model to evaluation mode before the forward passes
        model.eval()
        epoch_result.start()
        for batch in self._to_iter(batches):
            # forward pass only: no gradients are needed for testing
            with torch.no_grad():
                self.batch_iterator.iterate(
                    model, optimizer, criterion, batch, epoch_result,
                    DatasetSplitType.test)
            self._gc(3)
        epoch_result.end()

        self._gc(2)

        self.model_result.test.end()

    def _preproces_training(self, ds_train: Tuple[Batch]):
        """Prepare the training set before it is used.  This implementation
        shuffles ``ds_train`` in place when ``shuffle_training`` is set in the
        model settings, and notifies observers before and after.

        :param ds_train: the first (training) data set, shuffled in place

        """
        self._notify('preprocess_training_start')
        shuffle = self.model_settings.shuffle_training
        if shuffle:
            if logger.isEnabledFor(logging.DEBUG):
                logger.debug('shuffling training dataset')
            # data sets are ordered with training as the first, which is the
            # one the caller passes here
            rand.shuffle(ds_train)
        self._notify('preprocess_training_end')

    def _calc_batch_limit(self, src: Stash, batch_limit: Union[int,
                                                               float]) -> int:
        """Compute the number of batches to take from ``src``.

        :param src: the stash of batches being limited

        :param batch_limit: the number of batches when an ``int``, or the
                            proportion of ``len(src)`` in the interval ``(0,
                            1]`` when a ``float``

        :return: the maximum number of batches to use from ``src``

        :raises ModelError: if ``batch_limit`` is not positive, or if it is a
                            ``float`` greater than 1

        """
        if batch_limit <= 0:
            raise ModelError(f'Batch limit must be positive: {batch_limit}')
        if isinstance(batch_limit, float):
            if batch_limit > 1.0:
                # 1.0 (use everything) is accepted, so the message must not
                # claim a strict upper bound
                raise ModelError('Batch limit must not be greater than 1 ' +
                                 f'when a float: {batch_limit}')
            # the float is a proportion of the available batches
            vlim = round(len(src) * batch_limit)
            if logger.isEnabledFor(logging.DEBUG):
                logger.debug('batch limit calculated as a percentage: ' +
                             f'{vlim} = {len(src)} * {batch_limit}')
        else:
            vlim = batch_limit
        # the split name is only used to make the log message more useful
        if isinstance(src, SplitStashContainer):
            desc = f' for {src.split_name}'
        else:
            desc = ''
        if logger.isEnabledFor(logging.INFO):
            logger.info(f'using batch limit: {vlim}{desc}')
        return vlim

    def _prepare_datasets(self, batch_limit: Union[int, float],
                          to_deallocate: List[Batch],
                          ds_src: List[Stash]) -> \
            Tuple[Union[int, str], List[List[Batch]]]:
        """Load the batches of each data set using the configured batch
        iteration method (``gpu``, ``cpu`` or ``buffered``).  The data sets are
        returned in the order given in ``ds_src`` (see
        :meth:`_get_dataset_splits`), and the first one is passed to
        :meth:`_preproces_training`.

        :param batch_limit: the batch limit given to
                            :meth:`_calc_batch_limit` for each data set

        :param to_deallocate: a list this method adds to with the batches the
                              caller is to deallocate when finished

        :param ds_src: the source stashes, one per data set

        :return: a tuple of the number of batches loaded (the string ``?``
                 for ``buffered`` iteration, where nothing is loaded here) and
                 the data sets, such as ``[(training batch 1..N), (validation
                 batch 1..N), (test batch 1..N)]``

        :raises ModelError: if the batch iteration method is not known

        """
        biter = self.model_settings.batch_iteration
        cnt = 0

        if logger.isEnabledFor(logging.INFO):
            logger.info(f'preparing datasets using iteration: {biter}')

        self._notify('prepare_datasets_start', biter)

        if biter == 'gpu':
            ds_dst = []
            for src in ds_src:
                vlim = self._calc_batch_limit(src, batch_limit)
                cpu_batches = tuple(it.islice(src.values(), vlim))
                gpu_batches = [batch.to() for batch in cpu_batches]
                cnt += len(gpu_batches)
                # the `to` call returns the same instance if the tensor is
                # already on the GPU, so only deallocate batches copied over
                to_deallocate.extend(
                    cpu_batch
                    for cpu_batch, gpu_batch in zip(cpu_batches, gpu_batches)
                    if cpu_batch is not gpu_batch)
                if not self.model_settings.cache_batches:
                    to_deallocate.extend(gpu_batches)
                ds_dst.append(gpu_batches)
        elif biter == 'cpu':
            ds_dst = []
            for src in ds_src:
                vlim = self._calc_batch_limit(src, batch_limit)
                batches = list(it.islice(src.values(), vlim))
                cnt += len(batches)
                if not self.model_settings.cache_batches:
                    to_deallocate.extend(batches)
                ds_dst.append(batches)
        elif biter == 'buffered':
            # the source stashes are used directly, so nothing is loaded and
            # the count is not known
            ds_dst = ds_src
            cnt = '?'
        else:
            raise ModelError(f'No such batch iteration method: {biter}')

        self._notify('prepare_datasets_end', biter)

        self._preproces_training(ds_dst[0])

        return cnt, ds_dst

    def _execute(self, sets_name: str, description: str, func: Callable,
                 ds_src: tuple) -> bool:
        """Either train or test the model based on method ``func``.

        The batches are taken from :obj:`cached_batches` when available,
        otherwise they are loaded with :meth:`_prepare_datasets`.  Batches
        marked for deallocation are always deallocated when this method
        exits.

        :param sets_name: the name of the data sets, such as ``train`` or
                          ``test``, also used as the key in
                          :obj:`cached_batches`

        :param description: if not ``None``, used with the result index to
                            name the model result after a successful run

        :param func: the method to call to do the training or testing

        :param ds_src: a tuple of datasets in a form such as ``(train,
                       validation, test)`` (see :meth:`_get_dataset_splits`)

        :return: ``True`` if training/testing was successful, or ``False`` if
                 it bailed early (an :class:`EarlyBailError` was raised); any
                 other exception is propagated after clean up

        :raises ModelError: if batches are configured to be cached with the
                            ``buffered`` batch iteration

        """
        to_deallocate: List[Batch] = []
        ds_dst: List[List[Batch]] = None
        batch_limit = self.model_settings.batch_limit
        biter = self.model_settings.batch_iteration

        if self.model_settings.cache_batches and biter == 'buffered':
            raise ModelError('Can not cache batches for batch ' +
                             'iteration setting \'buffered\'')

        if logger.isEnabledFor(logging.INFO):
            logger.info(f'batch iteration: {biter}, limit: {batch_limit}' +
                        f', caching: {self.model_settings.cache_batches}'
                        f', cached: {len(self.cached_batches)}')

        self._notify('execute_start', sets_name)

        self._gc(1)

        # reuse batches cached by a previous call for the same data sets
        ds_dst = self.cached_batches.get(sets_name)
        if ds_dst is None:
            # NOTE(review): the message is deliberately not an f-string;
            # presumably `time` formats it with this frame's locals on exit,
            # which is why `cnt` is assigned beforehand -- confirm before
            # renaming or removing `cnt`
            cnt = 0
            with time('loaded {cnt} batches'):
                cnt, ds_dst = self._prepare_datasets(batch_limit,
                                                     to_deallocate, ds_src)
            if self.model_settings.cache_batches:
                self.cached_batches[sets_name] = ds_dst

        if logger.isEnabledFor(logging.INFO):
            logger.info('train/test sets: ' +
                        f'{" ".join(map(lambda l: str(len(l)), ds_dst))}')

        try:
            with time(f'executed {sets_name}'):
                func(*ds_dst)
            if description is not None:
                res_name = f'{self.model_result.index}: {description}'
                self.model_result.name = res_name
            return True
        except EarlyBailError as e:
            logger.warning(f'<{e}>')
            self.reset()
            return False
        finally:
            # clean up regardless of how the run ended
            self._notify('execute_end', sets_name)
            self._train_manager.clear()
            if logger.isEnabledFor(logging.INFO):
                logger.info(f'deallocating {len(to_deallocate)} batches')
            for batch in to_deallocate:
                if logger.isEnabledFor(logging.DEBUG):
                    logger.debug(f'deallocating: {batch}')
                batch.deallocate()
            self._gc(1)
            self.torch_config.empty_cache()

    def _get_dataset_splits(self) -> List[BatchStash]:
        """Look up the stash of each data set tracked by this executor.

        :return: the split stashes of :obj:`dataset_stash`, one for each name
                 in :obj:`dataset_split_names` and in that order

        :raises ModelError: if a name in :obj:`dataset_split_names` has no
                            split in the dataset stash

        """
        splits = self.dataset_stash.splits
        found = []
        for name in self.dataset_split_names:
            split = splits.get(name)
            if split is None:
                raise ModelError(
                    f"No split '{name}' in {self.dataset_stash.split_names}, " +
                    f'executor splits: {self.dataset_split_names}')
            found.append(split)
        return tuple(found)

    def train(self, description: str = None) -> ModelResult:
        """Train the model using a newly created model result.  The first two
        data set splits (training and validation) are given to
        :meth:`_train`.

        :param description: an optional description used to name the result

        :return: the model result, which is also set on this executor

        """
        self.model_result = self._create_model_result()
        train_split, valid_split, _ = self._get_dataset_splits()
        self._execute('train', description, self._train,
                      (train_split, valid_split))
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f'trained model result: {self.model_result}')
        return self.model_result

    def test(self, description: str = None) -> ModelResult:
        """Test the model on the third (test) data set split.  If no model
        result is set, one is first loaded from the result manager.

        :param description: an optional description used to name the result

        :return: the model result, which is also set on this executor

        """
        _, _, test_split = self._get_dataset_splits()
        if self.model_result is None:
            logger.warning('no results found--loading')
            self.model_result = self.result_manager.load()
        self._execute('test', description, self._test, (test_split,))
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f'tested model result: {self.model_result}')
        return self.model_result

    def train_production(self, description: str = None) -> ModelResult:
        """Train the model on the union of the training and test data sets,
        still using the validation data set for validation.  This is used for
        a "production" model that is used for some purpose other than
        evaluation.

        :param description: an optional description used to name the result

        :return: the model result, which is also set on this executor

        """
        self.model_result = self._create_model_result()
        train, valid, test = self._get_dataset_splits()
        # the test data is folded in to the training data
        combined = UnionStash((train, test))
        self._execute('train production', description, self._train,
                      (combined, valid))
        return self.model_result

    def predict(self, batches: List[Batch]) -> ModelResult:
        """Create predictions on ad-hoc data by running ``batches`` through
        :meth:`_test`.

        :param batches: contains the data (X) on which to predict; this is
                        iterated more than once, so it must not be a one-shot
                        iterator

        :return: the test portion of the model result after the run

        """
        for batch in batches:
            # ad-hoc batches get their feature mappings from the batch stash
            self.batch_stash.populate_batch_feature_mapping(batch)
        self._test(batches)
        test_result = self.model_result.test
        return test_result

    def write_model(self, depth: int = 0, writer: TextIOBase = sys.stdout):
        """Write a ``model:`` header followed by the string form of the model,
        with each of its lines indented one level deeper than ``depth``.

        :param depth: the indentation level of the header

        :param writer: where the output is written

        """
        model = self._get_or_create_model()
        indent = self._sp(depth + 1)
        # same text `print` produces: the string form and a trailing newline
        text = str(model) + '\n'
        self._write_line('model:', depth, writer)
        indented = [indent + line for line in text.split('\n')]
        writer.write('\n'.join(indented))

    def write_settings(self, depth: int = 0, writer: TextIOBase = sys.stdout):
        """Write the network settings followed by the model settings, each as
        a header line at ``depth`` with its settings dictionary one level
        deeper.

        :param depth: the indentation level of the header lines

        :param writer: where the output is written

        """
        sections = (('network settings:', 'net_settings'),
                    ('model settings:', 'model_settings'))
        for title, attr in sections:
            self._write_line(title, depth, writer)
            self._write_dict(getattr(self, attr).asdict(), depth + 1, writer)

    def write(self,
              depth: int = 0,
              writer: TextIOBase = sys.stdout,
              include_settings: bool = False,
              include_model: bool = False):
        """Write the model name followed by the feature and batch splits, and
        optionally the settings and the model.

        :param depth: the indentation level

        :param writer: where the output is written

        :param include_settings: whether to also call :meth:`write_settings`

        :param include_model: whether to also call :meth:`write_model`

        """
        indent = self._sp(depth)
        child_depth = depth + 1
        writer.write(f'{indent}model: {self.model_settings.model_name}\n')
        writer.write(f'{indent}feature splits:\n')
        self.feature_stash.write(child_depth, writer)
        writer.write(f'{indent}batch splits:\n')
        self.dataset_stash.write(child_depth, writer)
        if include_settings:
            self.write_settings(depth, writer)
        if include_model:
            self.write_model(depth, writer)
コード例 #4
0
ファイル: stash.py プロジェクト: plandes/deeplearn
class BatchStash(TorchMultiProcessStash, SplitKeyContainer, Writeback,
                 Deallocatable, metaclass=ABCMeta):
    """A stash that vectorizes features in to easily consumable tensors for
    training and testing.  This stash produces instances of :class:`.Batch`,
    which is a batch in the machine learning sense, and the first dimension of
    what will become the tensor used in PyTorch.  Each of these batches has a
    logical one to many relationship to that batch's respective set of data
    points, which is encapsulated in the :class:`.DataPoint` class.

    The stash creates subprocesses to vectorize features in to tensors in
    chunks of IDs (data point IDs) from the subordinate stash using
    ``DataPointIDSet`` instances.

    To speed up experiments, all available features configured in
    ``vectorizer_manager_set`` are encoded on disk.  However, only the
    ``decoded_attributes`` (see attribute below) are available to the model
    regardless of what was created during encoding time.

    The lifecycle of the data follows:

    1. Feature data created by the client, which could be language features,
       row data etc.

    2. Vectorize the feature data using the vectorizers in
       ``vectorizer_manager_set``.  This creates the feature contexts
       (``FeatureContext``) specifically meant to be pickled.

    3. Pickle the feature contexts when dumping to disk, which is invoked in
       the child processes of this class.

    4. At train time, load the feature contexts from disk.

    5. Decode the feature contexts in to PyTorch tensors.

    6. The model manager uses the ``to`` method to copy the CPU tensors to the
       GPU (where GPUs are available).

    :see _process: for details on the pickling of the batch instances

    """
    _DICTABLE_WRITE_EXCLUDES = {'batch_feature_mappings'}

    data_point_type: Type[DataPoint] = field()
    """A subclass type of :class:`.DataPoint` implemented for the specific
    feature.

    """
    batch_type: Type[Batch] = field()
    """The batch class to be instantiated when creating batches.

    """
    split_stash_container: SplitStashContainer = field()
    """The source data stash that has both the data and data set keys for each
    split (i.e. ``train`` vs ``test``).

    """
    vectorizer_manager_set: FeatureVectorizerManagerSet = field()
    """Used to vectorize features in to tensors."""

    batch_size: int = field()
    """The number of data points in each batch, except the last (unless the
    data point cardinality divides the batch size).

    """
    model_torch_config: TorchConfig = field()
    """The PyTorch configuration used to (optionally) copy CPU to GPU memory.

    """
    data_point_id_sets_path: Path = field()
    """The path of where to store key data for the splits; note that the
    container might store its key splits in some other location.

    """
    decoded_attributes: InitVar[Set[str]] = field()
    """The attributes to decode; only these are available to the model
    regardless of what was created during encoding time; if None, all are
    available.

    """
    batch_feature_mappings: BatchFeatureMapping = field(default=None)
    """The meta data used to encode and decode each feature in to tensors.

    """
    batch_limit: int = field(default=sys.maxsize)
    """The max number of batches to process, which is useful for debugging."""

    def __post_init__(self, decoded_attributes):
        """Validate the split stash container, create the persisted work for
        the data point ID sets and set the decoded attributes.

        :param decoded_attributes: the attributes to decode (see the
                                   :obj:`decoded_attributes` property)

        :raises DeepLearnError: if :obj:`split_stash_container` is neither a
                                ``SplitStashContainer`` nor both a
                                ``SplitKeyContainer`` and a ``Stash``

        """
        super().__post_init__()
        Deallocatable.__init__(self)
        # TODO: this class conflates key split and delegate stash functionality
        # in the `split_stash_container`.  An instance of this type serves the
        # purpose, but it need not be.  Instead it just needs to be both a
        # SplitKeyContainer and a Stash.  This probably should be split out in
        # to two different fields.
        cont = self.split_stash_container
        if not isinstance(cont, SplitStashContainer) \
           and (not isinstance(cont, SplitKeyContainer) or
                not isinstance(cont, Stash)):
            raise DeepLearnError('Expecting SplitStashContainer but got ' +
                                 f'{self.split_stash_container.__class__}')
        self.data_point_id_sets_path.parent.mkdir(parents=True, exist_ok=True)
        self._batch_data_point_sets = PersistedWork(
            self.data_point_id_sets_path, self)
        self.priming = False
        self.decoded_attributes = decoded_attributes
        self._update_comp_stash_attribs()

    @property
    def decoded_attributes(self) -> Set[str]:
        """The attributes to decode.  Only these are available to the model
        regardless of what was created during encoding time; if None, all are
        available.

        """
        return self._decoded_attributes

    @decoded_attributes.setter
    def decoded_attributes(self, attribs: Set[str]):
        """The attributes to decode.  Only these are available to the model
        regardless of what was created during encoding time; if None, all are
        available.

        """
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f'setting decoded attributes: {attribs}')
        self._decoded_attributes = attribs
        # a composite delegate only loads the attributes to be decoded
        if isinstance(self.delegate, BatchDirectoryCompositeStash):
            self.delegate.load_keys = attribs

    @property
    @persisted('_batch_metadata')
    def batch_metadata(self) -> BatchMetadata:
        """The batch metadata, which pairs the field mapping of each decoded
        attribute with its vectorizer.  Only fields of vectorizer managers
        found in :obj:`vectorizer_manager_set` are included, and of those only
        the :obj:`decoded_attributes` (all of them when it is ``None``).

        """
        mapping: BatchFeatureMapping
        if self.batch_feature_mappings is not None:
            mapping = self.batch_feature_mappings
        else:
            # no mappings configured: ask a throw away batch for them
            batch: Batch = self.batch_type(None, None, None, None)
            batch.batch_stash = self
            mapping = batch._get_batch_feature_mappings()
            batch.deallocate()
        vec_mng_set: FeatureVectorizerManagerSet = self.vectorizer_manager_set
        attrib_keeps = self.decoded_attributes
        vec_mng_names = set(vec_mng_set.keys())
        by_attrib = {}
        mmng: ManagerFeatureMapping
        for mmng in mapping.manager_mappings:
            vec_mng_name: str = mmng.vectorizer_manager_name
            if vec_mng_name in vec_mng_names:
                vec_mng: FeatureVectorizerManager = vec_mng_set[vec_mng_name]
                fmap: FieldFeatureMapping
                for fmap in mmng.fields:
                    # no decoded attributes configured means keep all fields
                    if attrib_keeps is None or fmap.attr in attrib_keeps:
                        vec = vec_mng[fmap.feature_id]
                        by_attrib[fmap.attr] = BatchFieldMetadata(fmap, vec)
        return BatchMetadata(self.data_point_type, self.batch_type,
                             mapping, by_attrib)

    def _update_comp_stash_attribs(self):
        """Update the composite stash grouping if we're using one and if this
        class is already configured.  Attributes in the delegate's groups that
        are not in the batch metadata mapping are removed, and groups left
        empty are dropped.

        """
        if isinstance(self.delegate, BatchDirectoryCompositeStash):
            meta: BatchMetadata = self.batch_metadata
            meta_attribs: Set[str] = set(
                map(lambda f: f.attr, meta.mapping.get_attributes()))
            groups: Tuple[Set[str]] = self.delegate.groups
            # union of all grouped attributes; empty when there are no groups
            gattribs = set().union(*groups)
            to_remove = gattribs - meta_attribs
            new_groups = []
            if len(to_remove) > 0:
                group: Set[str]
                for group in groups:
                    ng: Set[str] = meta_attribs & group
                    if len(ng) > 0:
                        new_groups.append(ng)
                self.delegate.groups = tuple(new_groups)
            if logger.isEnabledFor(logging.DEBUG):
                logger.debug(f'meta attribs: {meta_attribs}, groups: {groups}')

    @property
    @persisted('_batch_data_point_sets')
    def batch_data_point_sets(self) -> List[DataPointIDSet]:
        """Create the data point ID sets.  Each instance returned will correlate
        to a batch and each set of keys point to a feature :class:`.DataPoint`.
        At most :obj:`batch_limit` sets are created for each split.

        """
        psets = []
        batch_id = 0
        cont = self.split_stash_container
        tc_seed = TorchConfig.get_random_seed_context()
        if logger.isEnabledFor(logging.INFO):
            logger.info(f'{self.name}: creating keys with ({type(cont)}) ' +
                        f'using batch size of {self.batch_size}')
        for split, keys in cont.keys_by_split.items():
            if logger.isEnabledFor(logging.INFO):
                logger.info(f'keys for split {split}: {len(keys)}')
            # keys are ordered and needed to be as such for consistency
            # keys = sorted(keys, key=int)
            cslice = it.islice(chunks(keys, self.batch_size), self.batch_limit)
            for chunk in cslice:
                chunk = tuple(chunk)
                if logger.isEnabledFor(logging.DEBUG):
                    logger.debug(f'chunked size: {len(chunk)}')
                # batch IDs are unique across all splits
                dp_set = DataPointIDSet(str(batch_id), chunk, split, tc_seed)
                psets.append(dp_set)
                batch_id += 1
        if logger.isEnabledFor(logging.INFO):
            logger.info(f'created {len(psets)} data point ID sets, each ' +
                        f'split limited to {self.batch_limit} batches')
        return psets

    def _get_keys_by_split(self) -> Dict[str, Tuple[str]]:
        """Return the batch IDs of :obj:`batch_data_point_sets` grouped by
        split name.

        """
        by_batch = collections.defaultdict(list)
        for dps in self.batch_data_point_sets:
            by_batch[dps.split_name].append(dps.batch_id)
        return {split: tuple(ids) for split, ids in by_batch.items()}

    def _create_data(self) -> List[DataPointIDSet]:
        """Data created for the sub processes are the first N data point ID
        sets.

        """
        return self.batch_data_point_sets

    def populate_batch_feature_mapping(self, batch: Batch):
        """Add batch feature mappings to a batch instance, which is only done
        when this stash has :obj:`batch_feature_mappings` set.

        """
        if self.batch_feature_mappings is not None:
            batch.batch_feature_mappings = self.batch_feature_mappings

    def create_batch(self, points: Tuple[DataPoint], split_name: str = None,
                     batch_id: str = None):
        """Create a new batch instance with data points, which happens when
        primed.

        :param points: the data points of the batch

        :param split_name: the name of the split the batch belongs to

        :param batch_id: the ID of the batch

        :return: the new instance of :obj:`batch_type`

        """
        bcls: Type[Batch] = self.batch_type
        batch: Batch = bcls(self, batch_id, split_name, points)
        self.populate_batch_feature_mapping(batch)
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f'created batch: {batch}')
        return batch

    def _process(self, chunk: List[DataPointIDSet]) -> \
            Iterable[Tuple[str, Any]]:
        """Create the batches by creating the set of data points for each
        :class:`.DataPointIDSet` instance.  When the subordinate stash dumps
        the batch (specifically a subclass of :class:`.Batch`), the overridden
        pickle logic is used to *detach* the batch by encoding all data in to
        :class:`.FeatureContext` instances.

        :param chunk: the data point ID sets, one per batch to create

        :return: tuples of the batch ID and the created batch

        """
        if logger.isEnabledFor(logging.INFO):
            logger.info(f'{self.name}: processing: {len(chunk)} data points')
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f'chunk data points: {chunk}')
        # the seed context of the first set is used for the whole chunk
        tseed = chunk[0].torch_seed_context
        dpcls: Type[DataPoint] = self.data_point_type
        cont = self.split_stash_container
        if tseed is not None:
            TorchConfig.set_random_seed(
                tseed['seed'], tseed['disable_cudnn'], False)
        dset: DataPointIDSet
        for dset in chunk:
            batch_id: str = dset.batch_id
            points: Tuple[DataPoint] = tuple(
                map(lambda dpid: dpcls(dpid, self, cont[dpid]),
                    dset.data_point_ids))
            batch: Batch = self.create_batch(points, dset.split_name, batch_id)
            if logger.isEnabledFor(logging.DEBUG):
                logger.debug(f'created batch: {batch}')
            yield (batch_id, batch)

    def _get_data_points_for_batch(self, batch: Any) -> Tuple[Any]:
        """Return the data points that were used to create ``batch``.

        """
        dpcls = self.data_point_type
        cont = self.split_stash_container
        return tuple(map(lambda dpid: dpcls(dpid, self, cont[dpid]),
                         batch.data_point_ids))

    def load(self, name: str):
        """Load the batch with ID ``name`` and, when found, set this stash and
        its feature mappings on it if they are missing.

        """
        # NOTE(review): the message is not an f-string; presumably `time`
        # formats it with the local variables of this frame -- confirm
        with time('loaded batch {name} ({obj.split_name})'):
            obj = super().load(name)
        # add back the container of the batch to reconstitute the original
        # features and use the CUDA for tensor device transforms
        if obj is not None:
            if not hasattr(obj, 'batch_stash'):
                obj.batch_stash = self
            if getattr(obj, 'batch_feature_mappings', None) is None:
                self.populate_batch_feature_mapping(obj)
        return obj

    def _prime_vectorizers(self):
        """Prime every vectorizer in :obj:`vectorizer_manager_set` that is
        ``Primeable``.

        """
        vec_mng_set: FeatureVectorizerManagerSet = self.vectorizer_manager_set
        vecs = map(lambda v: v.values(), vec_mng_set.values())
        for vec in chain.from_iterable(vecs):
            if isinstance(vec, Primeable):
                if logger.isEnabledFor(logging.DEBUG):
                    logger.debug(f'priming {vec}')
                vec.prime()

    def prime(self):
        """Create the data point ID sets, prime the vectorizers and then prime
        the super class.

        :raises DeepLearnError: if this stash is already priming

        """
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f'priming {self.__class__}, is child: ' +
                         f'{self.is_child}, currently priming: {self.priming}')
        if self.priming:
            raise DeepLearnError('Already priming')
        self.priming = True
        try:
            # accessed only to force creation of the data point ID sets
            self.batch_data_point_sets
            self._prime_vectorizers()
            super().prime()
        finally:
            self.priming = False

    def deallocate(self):
        """Deallocate the data point ID sets, the delegate, the split stash
        container and the vectorizer manager set.

        """
        self._batch_data_point_sets.deallocate()
        # avoid deallocating the same instance twice
        if self.delegate is not self.split_stash_container:
            self._try_deallocate(self.delegate)
        self._try_deallocate(self.split_stash_container)
        self.vectorizer_manager_set.deallocate()
        super().deallocate()

    def _from_dictable(self, *args, **kwargs):
        """Return the super class dictionary without the keys that start with
        an underscore.

        """
        # avoid long Writable.write output
        dct = super()._from_dictable(*args, **kwargs)
        rms = tuple(filter(lambda k: k.startswith('_'), dct.keys()))
        for k in rms:
            del dct[k]
        return dct

    def clear(self):
        """Clear the batch, batch data point sets."""
        logger.debug('clearing')
        super().clear()
        self._batch_data_point_sets.clear()

    def clear_all(self):
        """Clear the batch, batch data point sets, and the source data
        (:obj:`split_stash_container`).

        """
        self.clear()
        self.split_stash_container.clear()
コード例 #5
0
class SplitKeyDataframeStash(DataframeStash, SplitKeyContainer):
    """A stash and split key container that reads from a dataframe.

    """
    key_path: Path = field()
    """The path where the key splits (as a ``dict``) is pickled."""

    split_col: str = field()
    """The column name in the dataframe used to indicate the split
    (i.e. ``train`` vs ``test``).

    """
    def __post_init__(self):
        """Create the persisted work used to cache the keys by split at
        :obj:`key_path`.

        """
        super().__post_init__()
        self._keys_by_split = PersistedWork(self.key_path, self, mkdir=True)

    def deallocate(self):
        """Deallocate the super class and the cached keys by split."""
        super().deallocate()
        self._keys_by_split.deallocate()

    def _create_keys_for_split(self, split_name: str, df: pd.DataFrame) -> \
            Iterable[str]:
        """Generate an iterable of string keys.  It is expected this method to
        be potentially very expensive, so the results are cached to disk.
        This implementation returns the dataframe index.

        :param split_name: the name of the split (i.e. ``train`` vs ``test``)
        :param df: the data frame for the grouping of keys from CSV of data

        """
        return df.index

    def _get_counts_by_key(self) -> Dict[str, int]:
        """Return the number of rows in the dataframe for each split."""
        sc = self.split_col
        return dict(self.dataframe.groupby(sc)[sc].count().items())

    @persisted('_split_names')
    def _get_split_names(self) -> Set[str]:
        """Return the unique values of the split column."""
        return set(self.dataframe[self.split_col].unique())

    @persisted('_keys_by_split')
    def _get_keys_by_split(self) -> Dict[str, Tuple[str]]:
        """Return the keys created by :meth:`_create_keys_for_split` for each
        split, keyed by the split name.

        """
        keys_by_split = OrderedDict()
        split_col = self.split_col
        # group by the column name itself rather than a one element list:
        # iterating a group by of a length one list yields 1-tuples as the
        # group names in recent pandas versions, which would end up as the
        # split names
        for split, df in self.dataframe.groupby(split_col):
            logger.info(f'parsing keys for {split}')
            keys = self._create_keys_for_split(split, df)
            keys_by_split[split] = tuple(keys)
        return keys_by_split

    def clear(self):
        """Clear the super class and the cached keys."""
        super().clear()
        self.clear_keys()

    def clear_keys(self):
        """Clear only the cache of keys generated from the group by.

        """
        self._keys_by_split.clear()

    def write(self, depth: int = 0, writer: TextIOBase = sys.stdout):
        """Write the row count of each split with its percentage of all rows,
        followed by the total number of rows.

        :param depth: the indentation level

        :param writer: where the output is written

        """
        total = self.dataframe.shape[0]
        self._write_line('data frame splits:', depth, writer)
        for split, cnt in self.counts_by_key.items():
            self._write_line(f'{split}: {cnt} ({cnt/total*100:.1f}%)', depth,
                             writer)
        self._write_line(f'total: {total}', depth, writer)
コード例 #6
0
class ResultAnalyzer(object):
    """Load results from a previous run of the :class:`ModelExecutor` and a more
    recent run.  This run is usually a currently running model to compare the
    results during training.  This might provide meaningful information such as
    whether to early stop training.

    """
    executor: ModelExecutor = field()
    """The executor (not necessarily the running executor) that will load the
    results if not already loaded.

    """

    previous_results_key: str = field()
    """The key given to retrieve the previous results with
    :class:`ModelResultManager`.

    """

    cache_previous_results: bool = field()
    """If ``True``, globally cache the previous results to avoid having to
    reload each time.

    """

    def __post_init__(self):
        """Create the persisted work that caches the previous results."""
        self._previous_results = PersistedWork(
            '_previous_results', self,
            cache_global=self.cache_previous_results)

    def clear(self):
        """Clear the previous results, if cached.

        """
        self._previous_results.clear()

    @property
    @persisted('_previous_results')
    def previous_results(self) -> ModelResult:
        """Return the previous results (see class docs).

        :raises ModelError: if the executor has no result manager

        """
        rm: ModelResultManager = self.executor.result_manager
        if rm is None:
            # the error has to be raised; it was previously only assigned,
            # which led to a confusing failure on the lookup below
            raise ModelError('No result manager available')
        return rm[self.previous_results_key]

    @property
    def current_results(self) -> ModelResult:
        """Return the current results (see class docs), which are first loaded
        by the executor if it has none.

        """
        if self.executor.model_result is None:
            self.executor.load()
        return self.executor.model_result

    @property
    def comparison(self) -> DataComparison:
        """Load the results data and create a comparison instance ready to
        write or jsonify.  The comparison has one row per epoch of the current
        results with the previous and current validation loss and their
        difference.

        """
        prev, cur = self.previous_results, self.current_results
        prev_losses = prev.validation.losses
        cur_losses = cur.validation.losses
        cur_len = len(cur_losses)
        # NOTE(review): this assumes the previous run has at least as many
        # epochs as the current one; otherwise the columns differ in length
        df = pd.DataFrame({'epoch': range(cur_len),
                           'previous': prev_losses[:cur_len],
                           'current': cur_losses})
        # positive when the current run has the lower loss
        df['improvement'] = df['previous'] - df['current']
        return DataComparison(self.previous_results_key, prev, cur, df)
コード例 #7
0
ファイル: split.py プロジェクト: plandes/deeplearn
class StratifiedStashSplitKeyContainer(StashSplitKeyContainer):
    """Like :class:`.StashSplitKeyContainer` but data is stratified by a label
    (:obj:`partition_attr`) across each split.

    """
    partition_attr: str = field(default=None)
    """The label used to partition the strata across each split"""

    stratified_write: bool = field(default=True)
    """Whether or not to include the stratified counts when writing with
    :meth:`write`.

    """
    split_labels_path: Path = field(default=None)
    """If provided, the path is a pickled cache of
    :obj:`stratified_count_dataframe`.

    """
    def __post_init__(self):
        """Validate the configuration and create the cache that backs
        :obj:`stratified_split_labels`.

        :raises DatasetError: if :obj:`partition_attr` is not set

        """
        super().__post_init__()
        if self.partition_attr is None:
            raise DatasetError("Missing 'partition_attr' field")
        # use the configured path when given, otherwise a plain string name
        work_path = '_strat_split_labels' if self.split_labels_path is None \
            else self.split_labels_path
        self._strat_split_labels = PersistedWork(work_path, self, mkdir=True)

    def _create_splits(self) -> Dict[str, Tuple[str]]:
        """Assign every stash key to a split, stratified by label.

        For each label (the value of :obj:`partition_attr` on the stash item),
        the label's keys are dealt out to all but one split using the
        proportions in :obj:`distribution` (rounded up).  The keys left over
        go to the remaining split, which is the first key of
        :obj:`distribution`.

        :return: the split names mapped to their keys

        """
        dist_keys: Sequence[str] = self.distribution.keys()
        # the split that receives whatever is left once the others are filled
        dist_last: str = next(iter(dist_keys))
        # keep the distribution's own order: a set of strings iterates in a
        # hash dependent order, which made unshuffled splits vary between runs
        dists: List[str] = [d for d in dist_keys if d != dist_last]
        rows = [(k, getattr(v, self.partition_attr))
                for k, v in self.stash.items()]
        df = pd.DataFrame(rows, columns=['key', self.partition_attr])
        lab_splits: Dict[str, Set[str]] = collections.defaultdict(set)
        for _, dfg in df.groupby(self.partition_attr):
            keys: List[str] = dfg['key'].to_list()
            if self.shuffle:
                random.shuffle(keys)
            count = len(keys)
            for dist in dists:
                prop = self.distribution[dist]
                n_samples = math.ceil(float(count) * prop)
                lab_splits[dist].update(keys[:n_samples])
                keys = keys[n_samples:]
            lab_splits[dist_last].update(keys)
        # sanity checks: every key is in exactly one split
        assert sum(map(len, lab_splits.values())) == len(df)
        # ``set().union`` also copes with an empty stash (no splits at all)
        assert set().union(*lab_splits.values()) == set(df['key'].tolist())
        shuf_splits = {}
        for lab, keys in lab_splits.items():
            if self.shuffle:
                keys = list(keys)
                random.shuffle(keys)
            shuf_splits[lab] = tuple(keys)
        return shuf_splits

    def _count_proportions_by_split(self) -> Dict[str, Dict[str, str]]:
        """Count how many items of each label every split holds.

        :return: the split names, visited in sorted order, mapped to a
                 dictionary of label to item count

        """
        splits = self.keys_by_split
        counts_by_split = {}
        for name in sorted(splits.keys()):
            tally = collections.defaultdict(lambda: 0)
            for key in splits[name]:
                tally[getattr(self.stash[key], self.partition_attr)] += 1
            counts_by_split[name] = tally
        return counts_by_split

    @property
    @persisted('_strat_split_labels')
    def stratified_split_labels(self) -> pd.DataFrame:
        """A dataframe with one row per key, with columns ``split_name``,
        ``id`` (the key) and ``label`` (the item's :obj:`partition_attr`
        value).  Splits are visited in sorted name order.

        """
        splits = self.keys_by_split
        records = [
            (name, key, getattr(self.stash[key], self.partition_attr))
            for name in sorted(splits.keys())
            for key in splits[name]]
        return pd.DataFrame(records, columns=['split_name', 'id', 'label'])

    def clear(self):
        """Clear the parent's state, then the cache that backs
        :obj:`stratified_split_labels`.

        """
        super().clear()
        label_cache = self._strat_split_labels
        label_cache.clear()

    @property
    def stratified_count_dataframe(self) -> pd.DataFrame:
        """A count summarization of :obj:`stratified_split_labels`.

        Each row is a (``split_name``, ``label``) pair with its ``count`` and
        its ``proportion`` of the total count, sorted by split name and label.

        """
        group_cols = ['split_name', 'label']
        sizes = self.stratified_split_labels.groupby(group_cols).size()
        df = sizes.reset_index(name='count')
        df['proportion'] = df['count'] / df['count'].sum()
        return df.sort_values(group_cols).reset_index(drop=True)

    def _fmt_prop_by_split(self) -> Dict[str, Dict[str, str]]:
        """Format the label proportions of every split.

        :return: the split names mapped to a dictionary of label to that
                 label's share of *all* counted items, formatted as a
                 percentage with two decimals (i.e. ``12.50%``)

        """
        df = self.stratified_count_dataframe
        tot = df['count'].sum()
        dsets: Dict[str, Dict[str, str]] = collections.OrderedDict()
        for split_name, dfg in df.groupby('split_name'):
            # build the mapping straight from the group's rows instead of
            # adding a column to the group, which is a slice of ``df`` and can
            # trigger pandas' chained assignment warning
            dsets[split_name] = {
                lab: f'{cnt/tot*100:.2f}%'
                for lab, cnt in zip(dfg['label'], dfg['count'])}
        return dsets

    def write(self, depth: int = 0, writer: TextIOBase = sys.stdout):
        """Write the formatted per split label proportions and the stash size
        when :obj:`stratified_write` is set, otherwise defer to the parent.

        :param depth: the number of indentation levels

        :param writer: the data sink

        """
        if not self.stratified_write:
            super().write(depth, writer)
            return
        self._write_dict(self._fmt_prop_by_split(), depth, writer)
        self._write_line(f'Total: {len(self.stash)}', depth, writer)
コード例 #8
0
ファイル: facade.py プロジェクト: plandes/deeplearn
class ModelFacade(PersistableContainer, Writable):
    """This class provides easy to use client entry points to the model executor,
    which trains, validates, tests, saves and loads the model.

    More common attributes, such as the learning rate and number of epochs, are
    properties that dispatch to :py:obj:`executor`.  For the others, go
    directly to the property.

    :see: :class:`zensols.deeplearn.domain.ModelSettings`

    """
    SINGLETONS = {}

    config: Configurable = field()
    """The configuraiton used to create the facade, and used to create a new
    configuration factory to load models.

    """
    config_factory: InitVar[ConfigFactory] = field(default=None)
    """The configuration factory used to create this facade, or ``None`` if no
    factory was used.

    """
    progress_bar: bool = field(default=True)
    """Create text/ASCII based progress bar if ``True``."""

    progress_bar_cols: Union[str, int] = field(default='term')
    """The number of console columns to use for the text/ASCII based progress
    bar.  If the value is ``term``, then use the terminal width.

    """
    executor_name: str = field(default='executor')
    """The configuration entry name for the executor, which defaults to
    ``executor``.

    """
    writer: TextIOBase = field(default=sys.stdout)
    """The writer to this in methods like :meth:`train`, and :meth:`test` for
    writing performance metrics results and predictions or ``None`` to not
    output them.

    """
    predictions_datafrmae_factory_class: Type[PredictionsDataFrameFactory] = \
        field(default=PredictionsDataFrameFactory)
    """The factory class used to create predictions.

    :see: :meth:`get_predictions_factory`

    """
    def __post_init__(self, config_factory: ConfigFactory):
        """Initialize the caches and resolve the progress bar width.

        :param config_factory: the configuration factory used to create this
                               facade, or ``None`` if no factory was used

        """
        super().__init__()
        self._init_config_factory(config_factory)
        self._config_factory = PersistedWork('_config_factory', self)
        self._executor = PersistedWork('_executor', self)
        self.debuged = False
        if self.progress_bar_cols != 'term':
            return
        try:
            # make space for embedded validation loss messages
            self.progress_bar_cols = os.get_terminal_size().columns - 5
        except OSError:
            logger.debug('unable to automatically determine '
                         'terminal width--skipping')
            self.progress_bar_cols = None

    @classmethod
    def get_singleton(cls, *args, **kwargs) -> Any:
        """Return the instance registered in :obj:`SINGLETONS` for this class,
        creating and registering it first when there is none.

        :param args: passed to the initializer when an instance is created

        :param kwargs: passed to the initializer when an instance is created

        """
        key = str(cls)
        inst = cls.SINGLETONS.get(key)
        if inst is not None:
            return inst
        inst = cls(*args, **kwargs)
        cls.SINGLETONS[key] = inst
        return inst

    def _init_config_factory(self, config_factory: ConfigFactory):
        """Remember the ``reload``, ``shared`` and ``reload_pattern``
        attributes of an :class:`ImportConfigFactory` so :obj:`config_factory`
        can be created with the same parameters.  No parameters are kept for
        any other kind of factory (including ``None``).

        """
        if not isinstance(config_factory, ImportConfigFactory):
            self._config_factory_params = {}
            return
        attrs = config_factory.__dict__
        keeps = set('reload shared reload_pattern'.split())
        params = {k: attrs[k] for k in set(attrs.keys()) & keeps}
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f'import config factory params: {params}')
        self._config_factory_params = params

    def _create_executor(self) -> ModelExecutor:
        """Create a new executor from the :obj:`executor_name` configuration
        entry, passing on the progress bar settings.  Used by
        :obj:`executor`.

        """
        logger.info('creating new executor')
        return self.config_factory(
            self.executor_name,
            progress_bar=self.progress_bar,
            progress_bar_cols=self.progress_bar_cols)

    @property
    @persisted('_config_factory')
    def config_factory(self):
        """The cached configuration factory, created from :obj:`config` and
        the parameters collected by :meth:`_init_config_factory`.

        """
        params = self._config_factory_params
        return ImportConfigFactory(self.config, **params)

    @property
    @persisted('_executor')
    def executor(self) -> ModelExecutor:
        """The executor tied to this facade instance, created with
        :meth:`_create_executor` and cached.

        """
        created: ModelExecutor = self._create_executor()
        return created

    @property
    def net_settings(self) -> NetworkSettings:
        """The network settings of :obj:`executor`.

        """
        executor = self.executor
        return executor.net_settings

    @property
    def model_settings(self) -> ModelSettings:
        """The model settings of :obj:`executor`.

        """
        executor = self.executor
        return executor.model_settings

    @property
    def result_manager(self) -> ModelResultManager:
        """Return the executor's result manager.

        NOTE(review): when the executor has no result manager, a
        :class:`.ModelError` *instance* is returned rather than raised.  This
        looks unintended, but confirm that no caller relies on this property
        not raising before changing it to ``raise``.

        """
        manager: ModelResultManager = self.executor.result_manager
        if manager is not None:
            return manager
        return ModelError('No result manager available')

    @property
    def feature_stash(self) -> Stash:
        """The executor's stash used to generate the features.  Do not confuse
        this with the batch source stash :obj:`batch_stash`.

        """
        executor = self.executor
        return executor.feature_stash

    @property
    def batch_stash(self) -> BatchStash:
        """The executor's stash used to encode and decode batches.

        """
        executor = self.executor
        return executor.batch_stash

    @property
    def dataset_stash(self) -> DatasetSplitStash:
        """The executor's stash used to encode and decode batches split by
        dataset.

        """
        executor = self.executor
        return executor.dataset_stash

    @property
    def vectorizer_manager_set(self) -> FeatureVectorizerManagerSet:
        """The vectorizer manager set used for the facade, which is taken from
        :obj:`batch_stash`.

        """
        stash = self.batch_stash
        return stash.vectorizer_manager_set

    @property
    def batch_metadata(self) -> BatchMetadata:
        """The batch metadata, which is taken from :obj:`batch_stash`.

        :see: :class:`zensols.deepnlp.model.module.EmbeddingNetworkSettings`

        """
        stash = self.batch_stash
        return stash.batch_metadata

    @property
    def label_attribute_name(self):
        """The label attribute name of the batch metadata's mapping, or
        ``None`` when there is no batch metadata.

        """
        bmeta = self.batch_metadata
        return None if bmeta is None else bmeta.mapping.label_attribute_name

    def _notify(self, event: str, context: Any = None):
        """Notify the model settings' observer manager of an event from this
        facade.

        :param event: the name of the event

        :param context: data passed along with the event

        """
        observers = self.model_settings.observer_manager
        observers.notify(event, self, context)

    def remove_metadata_mapping_field(self, attr: str) -> bool:
        """Remove a field by attribute if it exists across all metadata mappings.

        This is useful when a very expensive vectorizer slows down tasks, such
        as prediction, on a single run of a program.  For this use case,
        override :meth:`predict` to call this method before calling the super
        ``predict`` method.

        :param attr: the name of the field's attribute to remove

        :return: ``True`` if the field was removed from at least one mapping,
                 ``False`` otherwise

        """
        removed = False
        meta: BatchMetadata = self.batch_metadata
        mapping: BatchFeatureMapping
        for mapping in meta.mapping.manager_mappings:
            # call ``remove_field`` on the left side so ``or`` cannot short
            # circuit and skip the remaining mappings after the first removal
            removed = mapping.remove_field(attr) or removed
        return removed

    @property
    def dropout(self) -> float:
        """The dropout for the entire network, read from
        :obj:`net_settings`.

        """
        settings = self.net_settings
        return settings.dropout

    @dropout.setter
    def dropout(self, dropout: float):
        """Set the dropout for the entire network on :obj:`net_settings`.

        """
        settings = self.net_settings
        settings.dropout = dropout

    @property
    def epochs(self) -> int:
        """The number of epochs for training and validation, read from
        :obj:`model_settings`.

        """
        settings = self.model_settings
        return settings.epochs

    @epochs.setter
    def epochs(self, n_epochs: int):
        """Set the number of epochs for training and validation on
        :obj:`model_settings`.

        """
        settings = self.model_settings
        settings.epochs = n_epochs

    @property
    def learning_rate(self) -> float:
        """The learning rate to set on the optimizer, read from
        :obj:`model_settings`.

        """
        settings = self.model_settings
        return settings.learning_rate

    @learning_rate.setter
    def learning_rate(self, learning_rate: float):
        """Set the learning rate on the executor's model settings.

        """
        settings = self.executor.model_settings
        settings.learning_rate = learning_rate

    @property
    def cache_batches(self) -> bool:
        """The ``cache_batches`` setting, read from :obj:`model_settings`.

        """
        settings = self.model_settings
        return settings.cache_batches

    @cache_batches.setter
    def cache_batches(self, cache_batches: bool):
        """Set ``cache_batches`` on :obj:`model_settings`.  When the value
        changes, :meth:`clear` is called first so everything is deallocated
        and lazily recreated.

        """
        changed = self.model_settings.cache_batches != cache_batches
        if changed:
            # be safe when the caching strategy changes: deallocate and purge
            self.clear()
        # look the settings up again since ``clear`` drops the cached executor
        self.model_settings.cache_batches = cache_batches

    def clear(self):
        """Deallocate the executor and the configuration factory, then clear
        the caches holding them.

        """
        if logger.isEnabledFor(logging.INFO):
            logger.info('clearing')
        executor, config_factory = self.executor, self.config_factory
        for resource in (executor, config_factory):
            resource.deallocate()
        for cache in (self._executor, self._config_factory):
            cache.clear()

    def reload(self):
        """Clear all state with :meth:`clear`, then reload :obj:`config`.

        """
        self.clear()
        config = self.config
        config.reload()

    def deallocate(self):
        """Deallocate via the parent, then remove the entry registered for this
        instance's class from :obj:`SINGLETONS`, if any.

        """
        super().deallocate()
        key = str(self.__class__)
        self.SINGLETONS.pop(key, None)

    @classmethod
    def load_from_path(cls, path: Path, *args, **kwargs) -> ModelFacade:
        """Construct a new facade from the data saved in a persisted model file.

        The executor is loaded with :py:meth:`.ModelManager.load_from_path`
        and the facade is created from the executor's configuration, which
        means some attributes are taken from defaults if not given in
        ``*args`` or ``**kwargs``.  Unless given, ``executor_name`` is taken
        from the model manager.

        Arguments:
           Passed through to the initializer of invoking class ``cls``.

        :param path: the path of the persisted model, which is also set on the
                     loaded executor's model settings

        :return: a new instance of a :class:`.ModelFacade`

        :see: :meth:`.ModelManager.load_from_path`

        """
        if logger.isEnabledFor(logging.INFO):
            logger.info(f'loading from facade from {path}')
        manager = ModelManager.load_from_path(path)
        if 'executor_name' not in kwargs:
            kwargs['executor_name'] = manager.model_executor_name
        loaded = manager.load_executor()
        loaded.model_settings.path = path
        manager.config_factory.deallocate()
        facade: ModelFacade = cls(loaded.config, *args, **kwargs)
        # seed the new facade's caches with the loaded factory and executor
        facade._config_factory.set(loaded.config_factory)
        facade._executor.set(loaded)
        return facade

    def debug(self, debug_value: Union[bool, int] = True):
        """Debug the model: put the executor in debug mode, turn off its
        progress bar, limit it to a single batch and invoke training.  Logging
        must be configured properly to get the output, which is typically just
        invoking :py:meth:`logging.basicConfig`.

        After this call the facade is flagged as debugged.

        :param debug_value: ``True`` turns on executor debugging; if an
                            ``int``, the higher the value, the more the logging

        """
        ex = self.executor
        self._configure_debug_logging()
        ex.debug = debug_value
        ex.progress_bar = False
        ex.model_settings.batch_limit = 1
        self.debuged = True
        ex.train()

    def persist_result(self):
        """Save the last recorded result during an :py:meth:`.Executor.train` or
        :py:meth:`.Executor.test` invocation to disk.

        Nothing is saved when the executor has no result manager.

        """
        executor = self.executor
        # check the executor first and only then ask for the manager, so a
        # missing result manager stays a no-op however :obj:`result_manager`
        # reports that case
        if executor.result_manager is None:
            return
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f'dumping model result: {executor.model_result}')
        rmng: ModelResultManager = self.result_manager
        rmng.dump(executor.model_result)

    def train(self, description: str = None) -> ModelResult:
        """Reset the executor and train the model.  Observers are notified
        with ``train_start`` before and ``train_end`` after training.

        :param description: a description used in the results, which is useful
                            when making incremental hyperparameter changes to
                            the model

        :return: the result returned by the executor's ``train``

        """
        ex = self.executor
        ex.reset()
        logger.info('training...')
        self._notify('train_start', description)
        with time('trained'):
            result = ex.train(description)
        self._notify('train_end', description)
        return result

    def test(self, description: str = None) -> ModelResult:
        """Load the model with the executor and test it.  The result is
        written to :obj:`writer` unless it is ``None``.  Observers are
        notified with ``test_start`` and ``test_end``.

        :param description: a description passed to the executor's ``test``

        :raises ModelError: if :meth:`debug` was called on this facade

        """
        if self.debuged:
            raise ModelError('Testing is not allowed in debug mode')
        ex = self.executor
        ex.load()
        logger.info('testing...')
        self._notify('test_start', description)
        with time('tested'):
            result = ex.test(description)
        sink = self.writer
        if sink is not None:
            result.write(writer=sink)
        self._notify('test_end', description)
        return result

    def train_production(self, description: str = None) -> ModelResult:
        """Train on the training and test data sets, then test.

        The executor is reset and, unless :obj:`writer` is ``None``, written
        to it before training.  Observers are notified with
        ``train_production_start`` and ``train_production_end``.

        :param description: a description used in the results, which is useful
                            when making incremental hyperparameter changes to
                            the model

        """
        ex = self.executor
        ex.reset()
        sink = self.writer
        if sink is not None:
            ex.write(writer=sink)
        logger.info('training...')
        self._notify('train_production_start', description)
        with time('trained'):
            result = ex.train_production(description)
        self._notify('train_production_end', description)
        return result

    def predict(self, datas: Iterable[Any]) -> Any:
        """Make ad-hoc predictions on batches without labels, and return the
        results.

        A prediction mapper, named by the model settings'
        ``prediction_mapper_name``, creates the batches from ``datas`` and
        maps the first epoch result of the prediction to the returned value.
        The executor loads the model first if it reports that none exists.
        The mapper is deallocated and ``predict_end`` is notified even when
        the prediction fails.

        :param datas: the data predict on, each as a separate element as a data
                      point in a batch

        :raises ModelError: if no prediction mapper is configured

        """
        executor: ModelExecutor = self.executor
        settings: ModelSettings = self.model_settings
        mapper_name = settings.prediction_mapper_name
        if mapper_name is None:
            raise ModelError(
                f'The model settings ({settings.name}) is not configured '
                "to create prediction batches: no set 'prediction_mapper'")
        mapper: PredictionMapper = self.config_factory.new_instance(
            mapper_name, datas, self.batch_stash)
        self._notify('predict_start')
        try:
            batches: List[Batch] = mapper.batches
            if not executor.model_exists:
                executor.load()
            logger.info('predicting...')
            with time('predicted'):
                result: ModelResult = executor.predict(batches)
            mapped: Any = mapper.map_results(result.results[0])
        finally:
            self._notify('predict_end')
            mapper.deallocate()
        return mapped

    def stop_training(self):
        """Early stop training if the model is currently training.  This
        notifies observers with ``stop_training`` and then invokes
        :meth:`.TrainManager.stop`, which communicates to the training process
        to stop on the next check.

        :return: ``True`` if the application is configured to early stop and
                 the signal has not already been given

        """
        self._notify('stop_training')
        train_manager = self.executor.train_manager
        return train_manager.stop()

    @property
    def last_result(self) -> ModelResult:
        """The last recorded result during an :meth:`.ModelExecutor.train` or
        :meth:`.ModelExecutor.test` invocation.  When the executor has none,
        the result is loaded with :obj:`result_manager`.

        :raises ModelError: if no result was found either way

        """
        res = self.executor.model_result
        if res is not None:
            return res
        rm: ModelResultManager = self.result_manager
        res = rm.load()
        if res is None:
            raise ModelError('No results found')
        return res

    def write_result(self,
                     depth: int = 0,
                     writer: TextIOBase = sys.stdout,
                     include_settings: bool = False,
                     include_converged: bool = False,
                     include_config: bool = False):
        """Write :obj:`last_result`, which is loaded from the file system when
        the executor has no result.

        :param depth: the number of indentation levels

        :param writer: the data sink

        :param include_settings: whether or not to include model and network
                                 settings in the output

        :param include_converged: passed through to the result's ``write``
                                  method

        :param include_config: whether or not to include the configuration in
                               the output

        """
        if logger.isEnabledFor(logging.INFO):
            logger.info('load previous results')
        options = dict(include_settings=include_settings,
                       include_converged=include_converged,
                       include_config=include_config)
        self.last_result.write(depth, writer, **options)

    def plot_result(self,
                    result: ModelResult = None,
                    save: bool = False,
                    show: bool = False) -> ModelResult:
        """Plot results and optionally save and show them.  If this is called in a
        Jupyter notebook, the plot will be rendered in a cell.

        The grapher is obtained from the executor's result manager.

        :param result: the result to plot, or if ``None``, use
                       :py:meth:`last_result`

        :param save: if ``True``, save the plot to the results directory with
                     the same naming as the last data results

        :param show: if ``True``, invoke ``matplotlib``'s ``show`` function to
                     visualize in a non-Jupyter environment

        :return: the result used to graph, which comes from the executor when
                 none is given to the invocation

        """
        if result is None:
            result = self.last_result
        grapher = self.executor.result_manager.get_grapher()
        grapher.plot([result])
        if save:
            grapher.save()
        if show:
            grapher.show()
        return result

    def get_predictions_factory(self, column_names: List[str] = None,
                                transform: Callable[[DataPoint], tuple] = None,
                                batch_limit: int = sys.maxsize,
                                name: str = None) \
            -> PredictionsDataFrameFactory:
        """Generate a predictions factory from the test data set.

        The factory is an instance of
        :obj:`predictions_datafrmae_factory_class`.

        :param column_names: the list of string column names for each data item
                             the list returned from ``data_point_transform`` to
                             be added to the results for each label/prediction

        :param transform:

            a function that returns a tuple, each with an element respective of
            ``column_names`` to be added to the results for each
            label/prediction; if ``None`` (the default), ``str`` used (see the
            `Iris Jupyter Notebook
            <https://github.com/plandes/deeplearn/blob/master/notebook/iris.ipynb>`_
            example)

        :param batch_limit: the max number of batches of results to output

        :param name: the name/ID (name of the file sans extension in the
                     results directory) of the previously archived saved
                     results to fetch or ``None`` to get the last result

        :raises ModelError: if no result or no test results were found

        """
        manager: ModelResultManager = self.result_manager
        if name is None:
            result: ModelResult = self.last_result
            key: str = manager.get_last_key(False)
        else:
            result: ModelResult = manager.results_stash[name].model_result
            key: str = name
        if result is None:
            raise ModelError(f'No test results found: {name}')
        if not result.test.contains_results:
            raise ModelError('No test results found')
        factory_class = self.predictions_datafrmae_factory_class
        return factory_class(
            manager.key_to_path(key), result, self.batch_stash,
            column_names, transform, batch_limit)

    def get_predictions(self, *args, **kwargs) -> pd.DataFrame:
        """Generate a Pandas dataframe containing all predictions from the
        test data set, taken from the factory created by
        :meth:`get_predictions_factory`.  This method is meant to be
        overridden by application specific facades to customize prediction
        output.

        :see: :meth:`get_predictions_factory`

        :param args: arguments passed to :meth:`get_predictions_factory`

        :param kwargs: arguments passed to :meth:`get_predictions_factory`

        """
        factory = self.get_predictions_factory(*args, **kwargs)
        return factory.dataframe

    def write_predictions(self, lines: int = 10):
        """Print the first rows of the predictions made during the test phase
        of the model execution to :obj:`writer`.  As with any ``print`` call,
        standard out is used when the writer is ``None``.

        :param lines: the number of lines of the predictions data frame to be
                      printed

        """
        head = self.get_predictions().head(lines)
        print(head, file=self.writer)

    def get_result_analyzer(self, key: str = None,
                            cache_previous_results: bool = False) \
            -> ResultAnalyzer:
        """Return a results analyzer for comparing in flight training progress.

        :param key: the key of the previous results to compare with, which
                    defaults to the result manager's last key

        :param cache_previous_results: passed on to the analyzer

        """
        manager: ModelResultManager = self.result_manager
        previous_key = manager.get_last_key() if key is None else key
        return ResultAnalyzer(
            self.executor, previous_key, cache_previous_results)

    @property
    def class_explorer(self) -> FacadeClassExplorer:
        """A new facade explorer created with
        :meth:`_create_facade_explorer`.

        """
        explorer = self._create_facade_explorer()
        return explorer

    def _create_facade_explorer(self) -> FacadeClassExplorer:
        """Create the facade explorer used to print the facade's object graph.

        """
        explorer = FacadeClassExplorer()
        return explorer

    def write(self,
              depth: int = 0,
              writer: TextIOBase = None,
              include_executor: bool = True,
              include_metadata: bool = True,
              include_settings: bool = True,
              include_model: bool = True,
              include_config: bool = False,
              include_object_graph: bool = False):
        """Write the selected parts of the facade.

        :param depth: the number of indentation levels

        :param writer: the data sink, which defaults to :obj:`writer` and then
                       to standard out

        :param include_executor: whether to write the executor

        :param include_metadata: whether to write the batch metadata, which is
                                 skipped if it is not available

        :param include_settings: passed on to the executor's ``write``

        :param include_model: passed on to the executor's ``write``

        :param include_config: whether to write the configuration

        :param include_object_graph: whether to write the object graph

        """
        if writer is None:
            writer = self.writer
        if writer is None:
            writer = sys.stdout
        try:
            bmeta = self.batch_metadata
        except AttributeError:
            bmeta = None
        nested = depth + 1
        if include_executor:
            self._write_line(f'{self.executor.name}:', depth, writer)
            self.executor.write(nested,
                                writer,
                                include_settings=include_settings,
                                include_model=include_model)
        if include_metadata and bmeta is not None:
            self._write_line('metadata:', depth, writer)
            bmeta.write(nested, writer)
        if include_object_graph:
            self._write_line('graph:', depth, writer)
            explorer = self._create_facade_explorer()
            explorer.write(self, depth=nested, writer=writer)
        if include_config:
            self._write_line('config:', depth, writer)
            self.config.write(nested, writer)

    def _deallocate_config_instance(self, inst: Any):
        """Deallocate an instance when it is :class:`Deallocatable`.  With an
        :class:`ImportConfigFactory`, the instance is first replaced by what
        the factory's ``clear_instance`` returns.

        :param inst: the instance to deallocate

        """
        factory = self.config_factory
        if isinstance(factory, ImportConfigFactory):
            inst = factory.clear_instance(inst)
        dealloc = isinstance(inst, Deallocatable)
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f'deallocate {inst}: {type(inst)}: {dealloc}')
        if not dealloc:
            return
        inst.deallocate()

    def _configure_debug_logging(self):
        """Set the ``zensols.deeplearn.model`` logger and this module's logger
        to debug level.  The correct loggers need to be in debug mode to print
        the model debugging information such as matrix shapes.

        """
        for logger_name in ('zensols.deeplearn.model', __name__):
            debug_logger = logging.getLogger(logger_name)
            debug_logger.setLevel(logging.DEBUG)

    def _configure_cli_logging(self, info_loggers: List[str],
                               debug_loggers: List[str]):
        """Add the names of the loggers to configure for command line use.
        More loggers are added when the progress bar is turned off.

        :param info_loggers: the list to which info level logger names are
                             added

        :param debug_loggers: the list for debug level logger names, which is
                              left untouched here

        """
        always = [
            # multi-process (i.e. batch creation)
            'zensols.multi.stash',
            'zensols.deeplearn.batch.multi',
            # validation/training loss messages
            'zensols.deeplearn.model.executor.status',
            __name__
        ]
        info_loggers.extend(always)
        if self.progress_bar:
            return
        without_progress_bar = [
            # load messages
            'zensols.deeplearn.batch.stash',
            # save results messages
            'zensols.deeplearn.result',
            # validation/training loss messages
            'zensols.deeplearn.model.executor.progress',
            # model save/load
            'zensols.deeplearn.model.manager',
            # early stop messages
            'zensols.deeplearn.model.trainmng',
            # performance metrics formatting
            'zensols.deeplearn.model.format',
            # model save messages
            'zensols.deeplearn.result.manager',
            # observer module API messages
            'zensols.deeplearn.observer.status',
            # CLI interface
            'zensols.deeplearn.cli.app'
        ]
        info_loggers.extend(without_progress_bar)

    @staticmethod
    def configure_default_cli_logging(log_level: int = logging.WARNING):
        """Configure the logging system with a default format and the given
        level using :py:func:`logging.basicConfig`.

        :param log_level: the level passed to the logging system

        """
        logging.basicConfig(
            format='%(asctime)s[%(levelname)s]%(name)s: %(message)s',
            level=log_level)

    def configure_cli_logging(self, log_level: int = None):
        """Configure command line (or Python REPL) debugging.  Each facade can
        turn on name spaces that make sense as useful information output for
        long running training/testing iterations.

        This calls :meth:`_configure_cli_logging` to collect the names of
        loggers at various levels.

        :param log_level: if not ``None``, first configure the logging system
                          with :meth:`configure_default_cli_logging` using
                          this level

        """
        info, debug = [], []
        if log_level is not None:
            self.configure_default_cli_logging(log_level)
        self._configure_cli_logging(info, debug)
        for level, names in ((logging.INFO, info), (logging.DEBUG, debug)):
            for name in names:
                logging.getLogger(name).setLevel(level)

    def configure_jupyter(self,
                          log_level: int = logging.WARNING,
                          progress_bar_cols: int = 120):
        """Configures logging and other configuration related to a Jupyter notebook.
        This is just like :py:meth:`configure_cli_logging`, but adjusts logging
        for what is conducive for reporting in Jupyter cells.

        It also sets the executor's progress bar width and turns off the
        (non-logging) console output by setting :obj:`writer` to ``None``.

        :param log_level: the default logging level for the logging system

        :param progress_bar_cols: the number of columns to use for the progress
                                  bar

        """
        self.configure_cli_logging(log_level)
        quieted = (
            # turn off loading messages
            'zensols.deeplearn.batch.stash',
            # turn off model save messages
            'zensols.deeplearn.result.manager',
        )
        for logger_name in quieted:
            logging.getLogger(logger_name).setLevel(logging.WARNING)
        # number of columns for the progress bar
        self.executor.progress_bar_cols = progress_bar_cols
        # turn off console output (non-logging)
        self.writer = None

    @staticmethod
    def get_encode_sparse_matrices() -> bool:
        """Return whether or not sparse matrices are encoded, which is the
        value of ``SparseTensorFeatureContext.USE_SPARSE``.

        :see: :meth:`set_encode_sparse_matrices`

        """
        use_sparse: bool = SparseTensorFeatureContext.USE_SPARSE
        return use_sparse

    @staticmethod
    def set_encode_sparse_matrices(use_sparse: bool = False):
        """Set whether tensors are encoded as sparse by storing the flag on
        ``SparseTensorFeatureContext.USE_SPARSE``.

        If called before batches are created with ``use_sparse`` set to
        ``False``, all tensors that would be encoded as sparse are encoded as
        dense.  Otherwise, tensors will be encoded as sparse where it makes
        sense on a per vectorizer basis.

        :param use_sparse: the new value of the flag

        """
        context_class = SparseTensorFeatureContext
        context_class.USE_SPARSE = use_sparse
コード例 #9
0
class FacadeApplication(Deallocatable):
    """Base class for applications that create and use a
    :class:`.ModelFacade`.

    Facades are created either from the application configuration or loaded
    from :obj:`model_path`, and every created facade is tracked so it is
    released in :meth:`deallocate`.

    """
    CLI_META = {'mnemonic_excludes': {'get_cached_facade', 'create_facade',
                                      'deallocate', 'clear_cached_facade'},
                'option_overrides': {'model_path': {'long_name': 'model',
                                                    'short_name': None}}}
    """Tell the command line app API to ignore subclass and client specific
    use case methods.

    """
    config: Configurable = field()
    """The config used to create facade instances."""

    facade_name: str = field(default='facade')
    """The name used to look up the client facade in the config factory."""

    # simply copy this field and documentation to the implementation class to
    # add model path location (for those subclasses that don't have the
    # ``CLASS_INSPECTOR`` class level attribute set (see
    # :obj:`~zensols.util.introspect.inspect.ClassInspector.INSPECT_META`);
    # this can also be set as a parameter such as with
    # :meth:`.FacadeModelApplication.test`
    model_path: Path = field(default=None)
    """The path to the model or use the last trained model if not provided.

    """
    config_factory_args: Dict[str, Any] = field(default_factory=dict)
    """The arguments given to the :class:`~zensols.config.ImportConfigFactory`,
    which could be useful for reloading all classes while debugging.

    """
    config_overwrites: Configurable = field(default=None)
    """A configurable that clobbers any configuration in :obj:`config` for those
    sections/options set.

    """
    def __post_init__(self):
        # everything created by this application that needs deallocation
        self.dealloc_resources = []
        self._cached_facade = PersistedWork('_cached_facade', self, True)

    def _enable_cli_logging(self, facade: ModelFacade):
        """Turn off the progress bar on ``facade`` and configure its command
        line logging.

        """
        facade.progress_bar = False
        facade.configure_cli_logging()

    def create_facade(self) -> ModelFacade:
        """Create a new instance of the facade.

        When :obj:`model_path` is set the facade is loaded from that path,
        otherwise it is created from the (possibly overwritten) configuration.

        :return: the new facade, which is also registered for deallocation

        """
        # we must create a new (non-shared) instance of the facade since it
        # will get deallocated after complete.
        conf = self.config
        if self.config_overwrites is not None:
            # never modify the shared configuration
            conf = cp.deepcopy(conf)
            conf.merge(self.config_overwrites)
        path = self.model_path
        if path is not None:
            if logger.isEnabledFor(logging.INFO):
                logger.info(f'loading model from {path}')
            # the factory is only needed long enough to resolve the class
            with dealloc(ImportConfigFactory(
                    conf, **self.config_factory_args)) as factory:
                facade_class: Type[ModelFacade] = \
                    factory.get_class(self.facade_name)
            loaded: ModelFacade = facade_class.load_from_path(path)
            if logger.isEnabledFor(logging.DEBUG):
                logger.debug(
                    f'created facade: {type(loaded)} from path: {path}')
            self.dealloc_resources.append(loaded)
            return loaded
        factory = ImportConfigFactory(conf, **self.config_factory_args)
        created: ModelFacade = factory.instance(self.facade_name)
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f'created facade: {created}')
        # the factory stays alive with the facade it created
        self.dealloc_resources.extend((factory, created))
        return created

    @persisted('_cached_facade')
    def get_cached_facade(self, path: Path = None) -> ModelFacade:
        """Return a created facade that is cached in this application instance.

        :param path: not used by this implementation

        """
        facade: ModelFacade = self.create_facade()
        return facade

    def clear_cached_facade(self):
        """Deallocate the cached facade, if there is one, and clear the cache.

        """
        cached = self._cached_facade
        if cached.is_set():
            cached().deallocate()
        cached.clear()

    def deallocate(self):
        """Deallocate all resources created by this application and the facade
        cache.

        """
        super().deallocate()
        self._try_deallocate(self.dealloc_resources, recursive=True)
        self._cached_facade.deallocate()
コード例 #10
0
class WeightedModelExecutor(ModelExecutor):
    """A class that weighs labels non-uniformly.  This class uses inverse class
    sampling counts to help the minority label.

    """
    weighted_split_name: str = field(default='train')
    """The split name used to re-weight labels."""

    weighted_split_path: InitVar[Path] = field(default=None)
    """The directory in which the label counts are cached in a file named
    after :obj:`weighted_split_name`.  If not given, the counts are cached
    under the ``_label_counts`` name instead.

    """
    use_weighted_criterion: bool = field(default=True)
    """If ``True``, use the class weights in the initializer of the criterion.
    Setting this to ``False`` effectively disables this class.

    """

    def __post_init__(self, weighted_split_path: Path):
        super().__post_init__()
        if weighted_split_path is None:
            path = '_label_counts'
        else:
            file_name = f'weighted-labels-{self.weighted_split_name}.dat'
            path = weighted_split_path / file_name
        self._label_counts = PersistedWork(path, self)

    def clear(self):
        """Clear the executor state along with the cached label counts and the
        class weights derived from them.

        """
        super().clear()
        self._label_counts.clear()
        # the class weights are computed from the label counts, so they must
        # not outlive them; presumably ``persisted`` keeps its cache under this
        # attribute name once :meth:`get_class_weights` was called -- the
        # ``isinstance`` guard makes this a no-op otherwise
        class_weights = getattr(self, '_class_weighs', None)
        if isinstance(class_weights, PersistedWork):
            class_weights.clear()

    @persisted('_label_counts')
    def get_label_counts(self) -> Dict[int, int]:
        """Count the number of occurrences of each label in the split named
        :obj:`weighted_split_name`.

        :return: the labels as keys and their occurrence counts as values

        """
        stash = self.dataset_stash.splits[self.weighted_split_name]
        label_counts = collections.Counter()
        batches = tuple(stash.values())
        try:
            for batch in batches:
                label_counts.update(
                    label.item() for label in batch.get_labels())
        finally:
            # release the batches even when counting fails
            for batch in batches:
                batch.deallocate()
        return dict(label_counts)

    @persisted('_class_weighs')
    def get_class_weights(self) -> torch.Tensor:
        """Compute inverse class sampling counts to return the weighted class.

        :return: a tensor with one weight per label, ordered by label, where
                 each weight is the mean label count divided by the count of
                 that label

        """
        by_label = sorted(self.get_label_counts().items(),
                          key=lambda item: item[0])
        counts = self.torch_config.from_iterable(
            count for _, count in by_label)
        return counts.mean() / counts

    def get_label_statistics(self) -> Dict[str, Dict[str, Any]]:
        """Return a dictionary whose keys are the labels and values are
        dictionaries containing statistics on that label.

        Each value has the ``index``, ``count`` and ``weight`` of the label.

        """
        counts = self.get_label_counts()
        weights = self.get_class_weights().cpu().numpy()
        n_classes = weights.shape[0]
        # any batch will do since it is only used to get the label vectorizer
        batch = next(iter(self.dataset_stash.values()))
        vec = batch.batch_stash.get_label_feature_vectorizer(batch)
        classes = vec.get_classes(range(n_classes))
        return {label: {'index': ix,
                        'count': counts[ix],
                        'weight': weights[ix]}
                for ix, label in zip(range(n_classes), classes)}

    def _create_criterion(self) -> torch.optim.Optimizer:
        """Create the criterion configured in the model settings, passing the
        class weights as the ``weight`` argument when
        :obj:`use_weighted_criterion` is ``True``.

        The class weights are only computed when they are used.

        """
        # NOTE(review): the return annotation says optimizer, but an instance
        # of the configured criterion class is returned -- confirm against the
        # super class before changing it
        resolver = self.config_factory.class_resolver
        criterion_class_name = self.model_settings.criterion_class_name
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f'criterion: {criterion_class_name}')
        criterion_class = resolver.find_class(criterion_class_name)
        if not self.use_weighted_criterion:
            # computing the weights iterates over the whole split, so skip it
            # (and the misleading log message) when they are not used
            return criterion_class()
        with time('weighted classes'):
            class_weights = self.get_class_weights()
        if logger.isEnabledFor(logging.INFO):
            logger.info(f'using class weights: {class_weights}')
        return criterion_class(weight=class_weights)
コード例 #11
0
ファイル: stash.py プロジェクト: plandes/deeplearn
class DatasetSplitStash(DelegateStash, SplitStashContainer,
                        PersistableContainer, Writable):
    """A default implementation of :class:`.SplitStashContainer`.  However, it
    needs an instance of a :class:`.SplitKeyContainer`.  This implementation
    generates a separate stash instance for each data set split (i.e. ``train``
    vs ``test``).  Each split instance holds the data (keys and values) for
    each split.

    Stash instances by split are obtained with ``splits``, and will have
    a ``split`` attribute that give the name of the split.

    To maintain reproducibility, key ordering must be considered (see
    :class:`.SortedDatasetSplitStash`).

    :see: :meth:`.SplitStashContainer.splits`

    """
    split_container: SplitKeyContainer = field()
    """The instance that provides the splits in the dataset."""

    def __post_init__(self):
        super().__post_init__()
        PersistableContainer.__init__(self)
        if not isinstance(self.split_container, SplitKeyContainer):
            raise DatasetError('Expecting type SplitKeyContainer but ' +
                               f'got: {type(self.split_container)}')
        # ``None`` means this instance is not restricted to a single split
        self._inst_split_name = None
        self._keys_by_split = PersistedWork('_keys_by_split', self)
        self._splits = PersistedWork('_splits', self)

    def _add_keys(self, split_name: str, to_populate: Dict[str, str],
                  keys: List[str]):
        """Add ``keys`` as a tuple to ``to_populate`` under ``split_name``."""
        to_populate[split_name] = tuple(keys)

    @persisted('_keys_by_split')
    def _get_keys_by_split(self) -> Dict[str, Tuple[str]]:
        """Return keys by split type (i.e. ``train`` vs ``test``) for only those
        keys available by the delegate backing stash.

        """
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug('creating in memory available keys data structure')
        with time('created key data structures', logging.DEBUG):
            # a set makes the membership test below constant time
            delegate_keys = set(self.delegate.keys())
            avail_kbs = OrderedDict()
            for split, keys in self.split_container.keys_by_split.items():
                ks = [k for k in keys if k in delegate_keys]
                if logger.isEnabledFor(logging.DEBUG):
                    logger.debug(f'{split} has {len(ks)} keys')
                self._add_keys(split, avail_kbs, ks)
            return avail_kbs

    def _get_counts_by_key(self) -> Dict[str, int]:
        """Return the number of keys of each split by split name."""
        return {split: len(keys)
                for split, keys in self.keys_by_split.items()}

    def check_key_consistent(self) -> bool:
        """Return if the :obj:`split_container` has the same key count division
        as this stash's split counts.

        """
        return self.counts_by_key == self.split_container.counts_by_key

    def keys(self) -> Iterable[str]:
        """Return the keys of this instance's split, or the keys of all splits
        when this instance is not restricted to a split.

        """
        self.prime()
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f'keys for {self.split_name}')
        kbs = self.keys_by_split
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f'obtained keys for {self.split_name}')
        if self.split_name is None:
            return chain.from_iterable(kbs.values())
        else:
            return kbs[self.split_name]

    def exists(self, name: str) -> bool:
        """Return whether ``name`` is a key of this instance's split, or
        whether it exists in the super class when not restricted to a split.

        """
        if self.split_name is None:
            return super().exists(name)
        else:
            return name in self.keys_by_split[self.split_name]

    def load(self, name: str) -> Any:
        """Load the item for key ``name``.

        :return: the loaded item, or ``None`` when this instance is restricted
                 to a split that does not have ``name`` as a key

        """
        if self.split_name is None or \
           name in self.keys_by_split[self.split_name]:
            return super().load(name)
        return None

    def get(self, name: str, default: Any = None) -> Any:
        """Return the item for key ``name``.

        :param name: the key of the item to return

        :param default: what to return when the item is not available

        :return: the item, or ``default`` when this instance is restricted to a
                 split that does not have ``name`` as a key

        """
        if self.split_name is None or \
           name in self.keys_by_split[self.split_name]:
            # pass the default on so it is not replaced by the super class's
            return super().get(name, default)
        return default

    def prime(self):
        """Prime the super class and create the keys by split data."""
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug('priming ds split stash')
        super().prime()
        # accessed only for the side effect of creating the key data
        self.keys_by_split

    def _delegate_has_data(self):
        """Return the delegate's ``has_data`` when it is a
        :class:`PreemptiveStash`, and ``True`` for any other delegate.

        """
        return not isinstance(self.delegate, PreemptiveStash) or \
            self.delegate.has_data

    def deallocate(self):
        """Deallocate the delegate, the split container, the cached key data
        and all split stashes created by this instance.

        """
        # avoid deallocating the same instance twice
        if id(self.delegate) != id(self.split_container):
            self._try_deallocate(self.delegate)
        self._try_deallocate(self.split_container)
        self._keys_by_split.deallocate()
        if self._splits.is_set():
            # get the split stashes before the cache that has them is cleared
            splits = tuple(self._splits().values())
            self._splits.clear()
            if logger.isEnabledFor(logging.DEBUG):
                logger.debug(f'deallocating: {len(splits)} stash data splits')
            for v in splits:
                self._try_deallocate(v, recursive=True)
        self._splits.deallocate()
        super().deallocate()

    def clear_keys(self):
        """Clear any cache state for keys, and keys by split.  It does this by
        calling :meth:`clear` on the :obj:`split_container` and then clearing
        the key state of this stash.

        """
        self.split_container.clear()
        self._keys_by_split.clear()

    def clear(self):
        """Clear and destroy key and delegate data.  Nothing is done when the
        delegate reports that it has no data.

        """
        del_has_data = self._delegate_has_data()
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f'clearing: {del_has_data}')
        if del_has_data:
            if logger.isEnabledFor(logging.DEBUG):
                logger.debug('clearing delegate and split container')
            super().clear()
            self.clear_keys()

    def _get_split_names(self) -> Set[str]:
        """Return the split names of the :obj:`split_container`."""
        return self.split_container.split_names

    def _get_split_name(self) -> str:
        """Return the name of the split this instance is restricted to, or
        ``None`` if it is not restricted to a split.

        """
        return self._inst_split_name

    @persisted('_splits')
    def _get_splits(self) -> Dict[str, Stash]:
        """Return a stash for each split that contains only the data for that
        split.

        :return: the split names (i.e. ``train``, ``test``) as keys with a
                 clone of this instance restricted to that split as values

        """
        self.prime()
        stashes = OrderedDict()
        for split_name in self.split_names:
            clone = self.__class__(
                delegate=self.delegate, split_container=self.split_container)
            # release the caches created by the clone's initializer since they
            # are replaced with this instance's state on the next line
            clone._keys_by_split.deallocate()
            clone._splits.deallocate()
            clone.__dict__.update(self.__dict__)
            clone._inst_split_name = split_name
            stashes[split_name] = clone
        return stashes

    def write(self, depth: int = 0, writer: TextIOBase = sys.stdout,
              include_delegate: bool = False):
        """Write the key count and portion of each split, the totals and
        whether the keys are consistent.

        :param depth: the starting indentation depth

        :param writer: the writer to dump the content of this writable

        :param include_delegate: if ``True`` also write the delegate when it
                                 is a :class:`Writable`

        """
        self._write_line('split stash splits:', depth, writer)
        keys_by_split = self.split_container.keys_by_split
        t = sum(len(ks) for ks in keys_by_split.values())
        for k, ks in keys_by_split.items():
            ln = len(ks)
            # all splits are empty when the total is zero
            pct = ln / t * 100 if t > 0 else 0.0
            self._write_line(f'{k}: {ln} ({pct:.1f}%)',
                             depth + 1, writer)
        self._write_line(f'total: {t}', depth + 1, writer)
        ckc = self.check_key_consistent()
        self._write_line(f'total this instance: {len(self)}', depth, writer)
        self._write_line(f'keys consistent: {ckc}', depth, writer)
        if include_delegate and isinstance(self.delegate, Writable):
            self._write_line('delegate:', depth, writer)
            self.delegate.write(depth + 1, writer)
コード例 #12
0
ファイル: resource.py プロジェクト: plandes/deepnlp
class TransformerResource(PersistableContainer, Dictable):
    """A utility base class that allows configuration and creates various
    huggingface models.

    """
    name: str = field()
    """The name of the model given by the configuration.  Used for debugging.

    """

    torch_config: TorchConfig = field()
    """The config device used to copy the embedding data."""

    model_id: str = field()
    """The ID of the model (i.e. ``bert-base-uncased``).

    Token embedding using :class:`.TransformerEmbedding` has been tested with:

      * ``bert-base-cased``

      * ``bert-large-cased``

      * ``roberta-base``

      * ``distilbert-base-cased``

    :see: `Pretrained Models <https://huggingface.co/transformers/pretrained_models.html>`_

    """
    cased: bool = field(default=None)
    """``True`` for case sensitive models, ``False`` otherwise.  When not set
    (the default), it is ``False`` if :obj:`model_id` contains ``uncased`` and
    ``True`` otherwise.  The negated value of it is also used as the
    ``do_lower_case`` parameter in the ``*.from_pretrained`` calls to
    huggingface transformers.

    """
    trainable: bool = field(default=False)
    """If ``False`` the weights on the transformer model are frozen and the use
    of the model (including in subclasses) turn off autograd when executing.

    """
    args: Dict[str, Any] = field(default_factory=dict)
    """Additional arguments to pass to the `from_pretrained` method for the
    tokenizer and the model.

    """
    tokenizer_args: Dict[str, Any] = field(default_factory=dict)
    """Additional arguments to pass to the `from_pretrained` method for the
    tokenizer.

    """

    model_args: Dict[str, Any] = field(default_factory=dict)
    """Additional arguments to pass to the `from_pretrained` method for the
    model.

    """
    model_class: str = field(default='transformers.AutoModel')
    """The model fully qualified class used to create models with the
    ``from_pretrained`` static method.

    """
    tokenizer_class: str = field(default='transformers.AutoTokenizer')
    """The model fully qualified class used to create tokenizers with the
    ``from_pretrained`` static method.

    """
    cache: InitVar[bool] = field(default=False)
    """When set to ``True`` cache a global space model using the parameters from
    the first instance creation.

    """
    cache_dir: Path = field(default=None)
    """The directory that contains the BERT model(s)."""

    def __post_init__(self, cache: bool):
        super().__init__()
        if self.cache_dir is not None and not self.cache_dir.exists():
            if logger.isEnabledFor(logging.INFO):
                logger.info(f'creating cache directory: {self.cache_dir}')
            self.cache_dir.mkdir(parents=True, exist_ok=True)
        if self.cased is None:
            # derive the case sensitivity from the model ID
            if 'uncased' in self.model_id:
                self.cased = False
            else:
                logger.info("'cased' not given--assuming a cased model")
                self.cased = True
        self._tokenizer = PersistedWork('_tokenzier', self, cache)
        self._model = PersistedWork('_model', self, cache)
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f'id: {self.model_id}, cased: {self.cased}')

    @property
    def cached(self) -> bool:
        """If the model is cached.

        :see: :obj:`cache`

        """
        return self._tokenizer.cache_global

    @cached.setter
    def cached(self, cached: bool):
        """If the model is cached.

        :see: :obj:`cache`

        """
        # the tokenizer and the model are always cached together
        self._tokenizer.cache_global = cached
        self._model.cache_global = cached

    def _is_roberta(self):
        """Return whether :obj:`model_id` contains ``roberta``."""
        return 'roberta' in self.model_id

    def _create_tokenizer_class(self) -> Type[PreTrainedTokenizer]:
        """Create the huggingface class used for tokenizer."""
        ci = ClassImporter(self.tokenizer_class)
        return ci.get_class()

    @property
    @persisted('_tokenizer')
    def tokenizer(self) -> PreTrainedTokenizer:
        """The tokenizer created with ``from_pretrained`` for :obj:`model_id`.

        :raises TransformerError: for an uncased RoBERTa model

        """
        params = {'do_lower_case': not self.cased}
        if self.cache_dir is not None:
            params['cache_dir'] = str(self.cache_dir.absolute())
        # the more specific tokenizer arguments override the shared ones
        params.update(self.args)
        params.update(self.tokenizer_args)
        if self._is_roberta():
            if not self.cased:
                raise TransformerError('RoBERTa only supports cased models')
            params['add_prefix_space'] = True
        cls = self._create_tokenizer_class()
        return cls.from_pretrained(self.model_id, **params)

    def _create_model_class(self) -> Type[PreTrainedModel]:
        """Create the huggingface class used for the model."""
        ci = ClassImporter(self.model_class)
        return ci.get_class()

    @property
    @persisted('_model')
    def model(self) -> PreTrainedModel:
        """The model created with ``from_pretrained`` for :obj:`model_id` and
        copied with :obj:`torch_config`.  Unless :obj:`trainable`, the model is
        put in evaluation mode and its base model parameters are frozen.

        """
        # load pre-trained model (weights)
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f'loading model: {self.model_id}')
        params = {}
        if self.cache_dir is not None:
            params['cache_dir'] = str(self.cache_dir.absolute())
        # the more specific model arguments override the shared ones
        params.update(self.args)
        params.update(self.model_args)
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f'creating model using: {params}')
        with time(f'loaded model from pretrained {self.model_id}'):
            cls = self._create_model_class()
            model = cls.from_pretrained(self.model_id, **params)
        # put the model in `evaluation` mode, meaning feed-forward operation.
        if self.trainable:
            logger.debug('model is trainable')
        else:
            logger.debug('turning off grad for non-trainable transformer')
            model.eval()
            for param in model.base_model.parameters():
                param.requires_grad = False
        model = self.torch_config.to(model)
        return model

    def _from_dictable(self, *args, **kwargs) -> Dict[str, Any]:
        """Add a ``model`` entry to the super class's dictionary with the
        parameter shapes grouped by the first component of the parameter name
        (``sections``) and the total number of parameters (``params``).

        """
        dct = super()._from_dictable(*args, **kwargs)
        secs = collections.OrderedDict()
        name: str
        param: Tensor
        n_total_params = 0
        for name, param in self.model.named_parameters():
            # everything up to the first dot, or the whole name without one
            prefix = name.partition('.')[0]
            layer: Dict[str, Tuple[int, int]] = secs.get(prefix)
            if layer is None:
                layer = collections.OrderedDict()
                secs[prefix] = layer
            shape: Tuple[int, int] = tuple(param.shape)
            # also correct for scalar parameters, which have an empty shape
            n_total_params += param.numel()
            layer[name] = shape
        dct['model'] = {'sections': secs, 'params': n_total_params}
        return dct

    def _write_dict(self, data: dict, depth: int, writer: TextIOBase):
        """Write ``data``, telling the super class whether the first value of
        ``data`` is a tuple of length two.

        """
        is_param = False
        if len(data) > 0:
            val = next(iter(data.values()))
            is_param = (isinstance(val, tuple) and len(val) == 2)
        super()._write_dict(data, depth, writer, is_param)

    def clear(self):
        """Clear the cached tokenizer and model."""
        self._tokenizer.clear()
        self._model.clear()

    def __str__(self) -> str:
        return f'{self.name}: id: {self.model_id}, cased: {self.cased}'