import inspect

import numpy as np

from onir import config, log, util


def load(configuration, extra_args=None, pretty=False):
    a = config.args()
    if extra_args:
        # copy before updating so the shared config dict is not mutated
        a = dict(a)
        a.update(extra_args)
    context = {}
    for name, obj in configuration.items():
        # handle modules that expose a registry of implementations
        if hasattr(obj, 'registry'):
            obj = obj.registry
        if isinstance(obj, util.Registry):
            registry = obj
            registered = registry.registered
            default = registry.default
            if default is None and name not in a:
                raise ValueError(f'Missing argument for `{name}`')
            if a.get(name, default) not in registered:
                raise ValueError('no `{}` found for a `{}`'.format(
                    a.get(name, default), name))
            cls = registered[a.get(name, default)]
        else:
            cls = obj
            default = ''  # no registry: nothing is selected by name for this entry
        spec = inspect.getargspec(cls.__init__)
        args = []
        for arg in spec.args:
            # HACK: the validation/test predictors take their own datasets
            if arg == 'dataset' and name == 'valid_pred':
                arg = 'valid_ds'
            if arg == 'dataset' and name == 'test_pred':
                arg = 'test_ds'
            if arg == 'self':
                continue
            elif arg in context:
                # inject a component that was already constructed
                args.append(context[arg])
            elif arg == 'random':
                args.append(
                    np.random.RandomState(int(config.args()['random_seed'])))
            elif arg == 'logger':
                logger_name = (name + ':' + a.get(name, default)).rstrip(':')
                args.append(log.Logger(logger_name))
            elif arg == 'config':
                args.append(config.apply_config(name, a, cls))
            else:
                raise ValueError(f'cannot match argument `{arg}` for `{cls}`')
        context[name] = cls(*args)
    if pretty:
        _log_config(configuration.keys(), context)
    return context
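# The example below is a minimal, standalone sketch of the injection pattern that
# `load` implements: each component's __init__ argument names are matched against
# objects already built, so later entries receive earlier ones automatically. It is
# illustrative only -- the class names are hypothetical and it does not use onir's
# Registry, config, or logging machinery.
import inspect


def _build_all(components):
    context = {}
    for name, cls in components.items():
        spec = inspect.getfullargspec(cls.__init__)
        # pass along any already-built object whose name matches a constructor argument
        kwargs = {arg: context[arg] for arg in spec.args if arg in context}
        context[name] = cls(**kwargs)
    return context


class _Vocab:
    def __init__(self):
        self.tokens = {}


class _Ranker:
    def __init__(self, vocab):
        self.vocab = vocab


if __name__ == '__main__':
    ctx = _build_all({'vocab': _Vocab, 'ranker': _Ranker})
    assert ctx['ranker'].vocab is ctx['vocab']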
def execute(self):
    self.logger = log.Logger(self.__class__.__name__)

    # Initialize the ranker, trainer, validation predictor, and datasets
    self.ranker.initialize(self.random.state)
    self.trainer.initialize(self.random.state, self.ranker, self.train_dataset)
    self.valid_pred.initialize(self.predictor_path, [self.val_metric],
                               self.random.state, self.ranker, self.val_dataset)
    self.train_dataset.initialize(self.ranker.vocab)
    self.val_dataset.initialize(self.ranker.vocab)

    validator = self.valid_pred.pred_ctxt()

    top_epoch, top_value, top_train_ctxt, top_valid_ctxt = None, None, None, None
    prev_train_ctxt = None

    file_output = {'validation_metric': self.val_metric}

    for train_ctxt in self.trainer.iter_train(only_cached=self.only_cached):
        # Report progress
        progress(train_ctxt['epoch'] / self.max_epoch)

        # Free weights from the previous epoch unless they belong to the best epoch so far
        if prev_train_ctxt is not None and top_epoch is not None \
                and prev_train_ctxt is not top_train_ctxt:
            self._purge_weights(prev_train_ctxt)

        if train_ctxt['epoch'] >= 0 and not self.only_cached:
            message = self._build_train_msg(train_ctxt)
            if train_ctxt['cached']:
                self.logger.debug(f'[train] [cached] {message}')
            else:
                self.logger.debug(f'[train] {message}')

        # Epoch -1 is the evaluation before any training; skip it unless requested
        if train_ctxt['epoch'] == -1 and not self.initial_eval:
            continue

        # Compute validation metrics
        valid_ctxt = dict(validator(train_ctxt))
        message = self._build_valid_msg(valid_ctxt)

        if valid_ctxt['epoch'] >= self.warmup:
            if self.val_metric == '':
                # No validation metric configured: treat the latest epoch as the best
                top_epoch = valid_ctxt['epoch']
                top_train_ctxt = train_ctxt
                top_valid_ctxt = valid_ctxt
            elif top_value is None or valid_ctxt['metrics'][self.val_metric] > top_value:
                # New best validation score
                message += ' <---'
                top_epoch = valid_ctxt['epoch']
                top_value = valid_ctxt['metrics'][self.val_metric]
                if top_train_ctxt is not None:
                    self._purge_weights(top_train_ctxt)
                top_train_ctxt = train_ctxt
                top_valid_ctxt = valid_ctxt
        else:
            # Still in warmup: these weights will never be selected, so discard them
            if prev_train_ctxt is not None:
                self._purge_weights(prev_train_ctxt)

        if not self.only_cached:
            if valid_ctxt['cached']:
                self.logger.debug(f'[valid] [cached] {message}')
            else:
                self.logger.info(f'[valid] {message}')

        # Early stopping: give up after `early_stop` epochs without improvement
        if top_epoch is not None:
            epochs_since_imp = valid_ctxt['epoch'] - top_epoch
            if self.early_stop > 0 and epochs_since_imp >= self.early_stop:
                self.logger.warn(
                    'stopping after epoch {epoch} ({early_stop} epochs with no '
                    'improvement to {val_metric})'.format(**valid_ctxt, **self.__dict__))
                break

        if train_ctxt['epoch'] >= self.max_epoch:
            self.logger.warn(
                'stopping after epoch {max_epoch} (max_epoch)'.format(**self.__dict__))
            break

        prev_train_ctxt = train_ctxt

    self.logger.info('top validation epoch={} {}={}'.format(
        top_epoch, self.val_metric, top_value))

    # Record the best epoch, its run, weights, and metrics
    file_output.update({
        'valid_epoch': top_epoch,
        'valid_run': top_valid_ctxt['run_path'],
        'valid_path': top_train_ctxt['ranker_path'],
        'valid_metrics': top_valid_ctxt['metrics'],
    })

    with open(self.valtest_path, 'wt') as f:
        json.dump(file_output, f)
        f.write('\n')

    self.logger.info('valid run at {}'.format(valid_ctxt['run_path']))
    self.logger.info('valid ' + self._build_valid_msg(top_valid_ctxt))
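# A standalone sketch of the model-selection and early-stopping rule applied in the
# loop above: keep the epoch with the best validation metric, and stop once `early_stop`
# consecutive epochs pass without improvement. The metric values are made up for
# illustration; this does not use onir's trainer or predictor contexts.
def _should_stop(epoch, top_epoch, early_stop):
    return early_stop > 0 and top_epoch is not None and (epoch - top_epoch) >= early_stop


if __name__ == '__main__':
    history = [0.21, 0.24, 0.23, 0.24, 0.22, 0.23]  # e.g. a validation metric per epoch
    top_epoch, top_value = None, None
    for epoch, value in enumerate(history):
        if top_value is None or value > top_value:
            top_epoch, top_value = epoch, value
        if _should_stop(epoch, top_epoch, early_stop=3):
            print(f'stopping at epoch {epoch}: best {top_value} at epoch {top_epoch}')
            break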
import sys
import bdb

logger = None


def handle_exception(exc_type, exc_value, exc_traceback):
    # Fall back to the default hook until the onir logger is available.
    if logger is None:
        sys.__excepthook__(exc_type, exc_value, exc_traceback)
        return
    if issubclass(exc_type, KeyboardInterrupt):
        logger.debug("Keyboard Interrupt", exc_info=(exc_type, exc_value, exc_traceback))
    elif issubclass(exc_type, bdb.BdbQuit):
        logger.debug("Quit", exc_info=(exc_type, exc_value, exc_traceback))
    else:
        logger.critical("Uncaught exception", exc_info=(exc_type, exc_value, exc_traceback))


# Install the hook before importing the rest of the package so that any uncaught
# exception (including one raised during the imports below) is reported.
sys.excepthook = handle_exception

from onir import log

logger = log.Logger('onir')

from onir import (util, injector, metrics, datasets, interfaces, rankers, config,
                  trainers, predictors, vocab, pipelines)
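# A standalone sketch (stdlib logging with a hypothetical 'demo' logger, not onir.log)
# of what routing uncaught exceptions through a logger provides: because exc_info is
# passed, the logging handlers render the full traceback, so it ends up in the
# configured log file rather than only on stderr.
import logging
import sys

_demo_logger = logging.getLogger('demo')


def _demo_hook(exc_type, exc_value, exc_traceback):
    if issubclass(exc_type, KeyboardInterrupt):
        _demo_logger.debug('Keyboard Interrupt', exc_info=(exc_type, exc_value, exc_traceback))
    else:
        _demo_logger.critical('Uncaught exception', exc_info=(exc_type, exc_value, exc_traceback))


if __name__ == '__main__':
    logging.basicConfig(filename='uncaught.log', level=logging.DEBUG)
    sys.excepthook = _demo_hook
    raise RuntimeError('this traceback is written to uncaught.log by the hook')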
def __init__(self):
    self.logger = log.Logger(self.__class__.__name__)