示例#1
0
def load(configuration, extra_args=None, pretty=False):
    """Instantiate every component named in *configuration* and wire them together.

    Args:
        configuration: mapping of component name -> class, ``util.Registry``,
            or module exposing a ``registry`` attribute.
        extra_args: optional mapping of overrides merged on top of the values
            returned by ``config.args()``.
        pretty: when True, log the resolved configuration via ``_log_config``.

    Returns:
        dict mapping each component name to its constructed instance.

    Raises:
        ValueError: when a registry lookup fails or a constructor argument
            cannot be resolved from the context.
    """
    a = config.args()
    if extra_args:
        # BUG FIX: the old code updated `a` with itself (b = a; a.update(b)),
        # silently discarding `extra_args`. Copy first so the shared config
        # dict is not mutated, then apply the overrides.
        a = dict(a)
        a.update(extra_args)
    context = {}
    for name, obj in configuration.items():
        # Modules may expose their registry via a `registry` attribute.
        if hasattr(obj, 'registry'):
            obj = obj.registry

        if isinstance(obj, util.Registry):
            registry = obj
            registered = registry.registered
            default = registry.default
            if default is None and name not in a:
                raise ValueError(f'Missing argument for `{name}`')
            if a.get(name, default) not in registered:
                raise ValueError('no `{}` found for a `{}`'.format(
                    a.get(name, default), name))
            cls = registered[a.get(name, default)]
        else:
            cls = obj
            # BUG FIX: `default` previously leaked from an earlier registry
            # iteration (or was unbound on the first pass) into the logger
            # branch below; a plain class has no registry default.
            default = None

        # getfullargspec replaces inspect.getargspec, removed in Python 3.11.
        spec = inspect.getfullargspec(cls.__init__)
        args = []
        for arg in spec.args:
            # HACK: prediction components share a `dataset` parameter name
            # but must receive different datasets from the context.
            if arg == 'dataset' and name == 'valid_pred':
                arg = 'valid_ds'
            if arg == 'dataset' and name == 'test_pred':
                arg = 'test_ds'

            if arg == 'self':
                continue
            elif arg in context:
                args.append(context[arg])
            elif arg == 'random':
                args.append(
                    np.random.RandomState(int(config.args()['random_seed'])))
            elif arg == 'logger':
                # `or ''` guards the non-registry case where default is None.
                logger_name = (name + ':' +
                               (a.get(name, default) or '')).rstrip(':')
                args.append(log.Logger(logger_name))
            elif arg == 'config':
                args.append(config.apply_config(name, a, cls))
            else:
                raise ValueError(f'cannot match argument `{arg}` for `{cls}`')
        context[name] = cls(*args)
    if pretty:
        _log_config(configuration.keys(), context)

    return context
示例#2
0
    def execute(self):
        """Run the training loop with per-epoch validation and early stopping.

        Trains via ``self.trainer``, evaluates each epoch with
        ``self.valid_pred``, tracks the best epoch by ``self.val_metric``,
        purges weight files of superseded checkpoints, and finally writes a
        JSON summary of the top validation result to ``self.valtest_path``.
        """
        self.logger = log.Logger(self.__class__.__name__)
        self.ranker.initialize(self.random.state)
        self.trainer.initialize(self.random.state, self.ranker,
                                self.train_dataset)
        self.valid_pred.initialize(self.predictor_path, [self.val_metric],
                                   self.random.state, self.ranker,
                                   self.val_dataset)
        self.train_dataset.initialize(self.ranker.vocab)
        self.val_dataset.initialize(self.ranker.vocab)

        validator = self.valid_pred.pred_ctxt()

        # Best-so-far tracking: epoch index, metric value, and the train/valid
        # contexts that produced them.
        top_epoch, top_value, top_train_ctxt, top_valid_ctxt = None, None, None, None
        prev_train_ctxt = None

        file_output = {'validation_metric': self.val_metric}

        for train_ctxt in self.trainer.iter_train(
                only_cached=self.only_cached):
            # Report progress
            progress(train_ctxt['epoch'] / self.max_epoch)

            # Free the previous epoch's weights unless they belong to the
            # current best checkpoint.
            if prev_train_ctxt is not None and top_epoch is not None and prev_train_ctxt is not top_train_ctxt:
                self._purge_weights(prev_train_ctxt)

            if train_ctxt['epoch'] >= 0 and not self.only_cached:
                message = self._build_train_msg(train_ctxt)

                if train_ctxt['cached']:
                    self.logger.debug(f'[train] [cached] {message}')
                else:
                    self.logger.debug(f'[train] {message}')

            # Epoch -1 is the untrained state; skip its evaluation unless an
            # initial eval was explicitly requested.
            if train_ctxt['epoch'] == -1 and not self.initial_eval:
                continue

            # Compute validation metrics
            valid_ctxt = dict(validator(train_ctxt))

            message = self._build_valid_msg(valid_ctxt)

            if valid_ctxt['epoch'] >= self.warmup:
                if self.val_metric == '':
                    # No metric configured: the latest epoch always wins.
                    top_epoch = valid_ctxt['epoch']
                    top_train_ctxt = train_ctxt
                    top_valid_ctxt = valid_ctxt
                elif top_value is None or valid_ctxt['metrics'][
                        self.val_metric] > top_value:
                    # New best: mark the log line and purge the weights of
                    # the previously-best checkpoint.
                    message += ' <---'
                    top_epoch = valid_ctxt['epoch']
                    top_value = valid_ctxt['metrics'][self.val_metric]
                    if top_train_ctxt is not None:
                        self._purge_weights(top_train_ctxt)
                    top_train_ctxt = train_ctxt
                    top_valid_ctxt = valid_ctxt
            else:
                # Still in warmup: no checkpoint can become "best" yet.
                if prev_train_ctxt is not None:
                    self._purge_weights(prev_train_ctxt)

            if not self.only_cached:
                if valid_ctxt['cached']:
                    self.logger.debug(f'[valid] [cached] {message}')
                else:
                    self.logger.info(f'[valid] {message}')

            # Early stopping: give up after `early_stop` epochs without
            # improvement to the validation metric.
            if top_epoch is not None:
                epochs_since_imp = valid_ctxt['epoch'] - top_epoch
                if self.early_stop > 0 and epochs_since_imp >= self.early_stop:
                    self.logger.warn(
                        'stopping after epoch {epoch} ({early_stop} epochs with no '
                        'improvement to {val_metric})'.format(
                            **valid_ctxt, **self.__dict__))
                    break

            if train_ctxt['epoch'] >= self.max_epoch:
                self.logger.warn(
                    'stopping after epoch {max_epoch} (max_epoch)'.format(
                        **self.__dict__))
                break

            prev_train_ctxt = train_ctxt

        self.logger.info('top validation epoch={} {}={}'.format(
            top_epoch, self.val_metric, top_value))

        # NOTE(review): if no epoch ever reached `warmup`, top_valid_ctxt and
        # top_train_ctxt are still None here and the subscripts below raise —
        # confirm callers guarantee at least one post-warmup validation.
        file_output.update({
            'valid_epoch': top_epoch,
            'valid_run': top_valid_ctxt['run_path'],
            'valid_path': top_train_ctxt['ranker_path'],
            'valid_metrics': top_valid_ctxt['metrics'],
        })

        with open(self.valtest_path, 'wt') as f:
            json.dump(file_output, f)
            f.write('\n')

        self.logger.info('valid run at {}'.format(valid_ctxt['run_path']))
        self.logger.info('valid ' + self._build_valid_msg(top_valid_ctxt))
示例#3
0
import sys
import bdb

# Filled in once the onir logger is available; until then we fall back to the
# interpreter's default excepthook.
logger = None


def handle_exception(exc_type, exc_value, exc_traceback):
    """Route uncaught exceptions through the module logger when available."""
    if logger is None:
        # Logger not configured yet: behave exactly like the default hook.
        sys.__excepthook__(exc_type, exc_value, exc_traceback)
        return
    exc_info = (exc_type, exc_value, exc_traceback)
    # Interactive/debugger exits are expected -> debug level; everything else
    # is a genuine crash -> critical.
    if issubclass(exc_type, KeyboardInterrupt):
        log_fn, message = logger.debug, "Keyboard Interrupt"
    elif issubclass(exc_type, bdb.BdbQuit):
        log_fn, message = logger.debug, "Quit"
    else:
        log_fn, message = logger.critical, "Uncaught exception"
    log_fn(message, exc_info=exc_info)


# Install as the process-wide hook for uncaught exceptions.
sys.excepthook = handle_exception

from onir import log

logger = log.Logger('onir')

from onir import util, injector, metrics, datasets, interfaces, rankers, config, trainers, predictors, vocab, pipelines
示例#4
0
 def __init__(self):
     """Create a logger named after the concrete (sub)class."""
     self.logger = log.Logger(type(self).__name__)