def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    args = self.args

    args_sacred = get_maybe_missing_args(args.loggers, 'sacred')
    if args_sacred is None:
        self.use_sacred = False
    else:
        self.use_sacred = args_sacred.use_sacred

    if self.use_sacred:
        # -- Create the Sacred experiment and attach config and sources
        self.sacred_exp = Experiment(args.exp_name)
        self.sacred_exp.captured_out_filter = apply_backspaces_and_linefeeds
        self.sacred_exp.add_config(vars(args))
        for source in self.get_sources():
            self.sacred_exp.add_source_file(source)

        # -- Optionally attach a MongoDB observer
        if not args_sacred.mongodb_disable:
            url = "{0.mongodb_url}:{0.mongodb_port}".format(args_sacred)
            if (args_sacred.mongodb_name is not None
                    and args_sacred.mongodb_name != ''):
                db_name = args_sacred.mongodb_name
            else:
                # -- Default db name: prefix + alphanumeric dataset name
                db_name = args_sacred.mongodb_prefix + ''.join(
                    filter(str.isalnum, args.dataset_name.lower()))

            self.console_log.info('Connect to MongoDB@{}:{}'.format(
                url, db_name))
            self.sacred_exp.observers.append(
                MongoObserver.create(url=url, db_name=db_name))
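# A minimal sketch (not part of the original source) of the `loggers.sacred`
# config section consumed by __init__ above. Key names mirror the attribute
# accesses in the code; the values are illustrative assumptions only.
#
#   loggers:
#     sacred:
#       use_sacred: True
#       mongodb_disable: False
#       mongodb_url: 'localhost'
#       mongodb_port: 27017
#       mongodb_prefix: 'yapt_'
#       mongodb_name: ''   # empty -> prefix + alphanumeric dataset name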
def _configure_scheduler_optimizer(self) -> dict:
    args = self.args

    args_scheduler = get_maybe_missing_args(args.optimizer, 'scheduler')
    if args_scheduler is None or not args_scheduler:
        return {}

    scheduler_name = args_scheduler.name
    scheduler_params = args_scheduler.params

    # -- Get the scheduler class from its name
    scheduler_class = get_scheduler_optimizer(scheduler_name)

    # -- Instantiate the scheduler with its specific parameters
    scheduler = scheduler_class(
        optimizer=self.optimizer, **scheduler_params)
    return scheduler
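# A minimal sketch (assumption, not from the original source) of the
# `optimizer.scheduler` args consumed above. `name` is resolved through
# get_scheduler_optimizer (e.g. to a torch.optim.lr_scheduler class) and
# `params` is unpacked into its constructor together with the optimizer:
#
#   optimizer:
#     scheduler:
#       name: 'StepLR'
#       params:
#         step_size: 30
#         gamma: 0.1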
def early_stopping(self, current_stats: dict) -> (bool, bool):
    """This function implements a classic early stopping procedure
    with patience. An example of the arguments that can be used is
    provided below.

    ```yml
    early_stopping:
      dataset: 'validation'
      metric: 'validation/y_acc'
      patience: 10
      mode: 'max'  # or 'min'
      train_until_end: False  # it keeps going until the end of training
      warmup: -1  # number of epochs to skip early stopping
                  # (it disables the patience count)
    ```

    Args:
        current_stats (dict): a possibly nested dictionary of the results
            from validation at the current epoch. Keys should follow the
            hierarchy dataset->stats->metric. For custom stats, `metric`
            can be retrieved by overriding the method
            `self._get_metric_early_stopping`.

    Returns:
        A tuple (is_best, is_stop) describing the status of early stopping.
        `is_stop` is also assigned to the self object.
    """
    args = get_maybe_missing_args(self.args, 'early_stopping')
    if args is None:
        # -- Early stopping is not configured: not the best, do not stop
        return False, False

    dataset = args.dataset
    metric = args.metric
    patience = args.patience
    mode = args.mode
    train_until_end = get_maybe_missing_args(args, 'train_until_end', False)
    warmup = get_maybe_missing_args(args, 'warmup', -1)

    compare_op = max if mode == "max" else min
    is_best = False

    current = self._get_metric_early_stopping(current_stats)
    if current is None:
        raise ValueError(
            "Metric {} does not exist in current_stats['{}'] \n"
            "It contains only these keys: {}".format(
                metric, dataset, str(recursive_keys(current_stats))))

    # -- First epoch: initialize the best value and do not stop
    if len(self._best_stats) == 0:
        self._update_best_stats(current)
        return True, False

    best = self._best_stats[-1][1]
    if compare_op(current, best) != best:
        # -- Current value beats the best one so far
        self._update_best_stats(current)
        is_best = True
    else:
        if self.epoch > warmup:
            self._beaten_epochs += 1

    if (self._beaten_epochs >= patience
            and not train_until_end
            and self.epoch > warmup):
        self._early_stop = True

    return is_best, self._early_stop
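# Hedged usage sketch (not from the original source; the surrounding trainer
# members `save_checkpoint` and the epoch loop are hypothetical). It shows how
# the two returned flags are typically consumed, with `current_stats` following
# the dataset->stats->metric hierarchy described in the docstring:
#
#   current_stats = {'validation': {'stats': {'validation/y_acc': 0.87}}}
#   is_best, is_stop = self.early_stopping(current_stats)
#   if is_best:
#       self.save_checkpoint()   # hypothetical helper: keep the best model
#   if is_stop:
#       break                    # leave the epoch loop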
def get_maybe_missing_args(self, key, default=None):
    return get_maybe_missing_args(self.args, key, default)
def configure_loggers(self, external_logdir=None):
    """YAPT supports logging experiments with multiple loggers at the
    same time. By default, an experiment is logged by TensorBoardLogger.

    external_logdir: overrides the logdir. It can be useful to log into
        the same directory used by Ray Tune.
    """
    if external_logdir is not None:
        self.args.loggers.logdir = self._logdir = external_logdir
        self.console_log.warning(
            "external logdir {}".format(external_logdir))

        if self.args.loggers.debug:
            self.console_log.warning(
                "Debug flag is disabled when an external logdir is used")

    # -- No loggers in test mode
    if self.mode == 'test':
        return None

    args_logger = self.args.loggers
    loggers = dict()
    safe_mkdirs(self._logdir, exist_ok=True)

    # -- Tensorboard is not the default anymore
    if args_logger.tensorboard:
        loggers['tb'] = TensorBoardLogger(self._logdir)

    # -- Neptune
    if (get_maybe_missing_args(args_logger, 'neptune') is not None
            and len(args_logger.neptune.keys()) > 0):
        # TODO: because of the api key and sensitive data,
        # the neptune project should be per-project in a separate file
        # TODO: THIS SHOULD BE DONE FOR EACH LEAF

        # -- Convert OmegaConf containers to plain Python types
        args_neptune = dict()
        for key, val in args_logger.neptune.items():
            if isinstance(val, ListConfig):
                val = list(val)
            elif isinstance(val, DictConfig):
                val = dict(val)
            args_neptune[key] = val

        # -- Recursively search for files or extensions
        if 'upload_source_files' in args_neptune.keys():
            source_files = [
                str(path) for ext in args_neptune['upload_source_files']
                for path in Path('./').rglob(ext)]
            del args_neptune['upload_source_files']
        else:
            source_files = None

        loggers['neptune'] = NeptuneLogger(
            api_key=os.environ['NEPTUNE_API_TOKEN'],
            experiment_name=self.args.exp_name,
            params=flatten_dict(self.args),
            logger=self.console_log,
            upload_source_files=source_files,
            **args_neptune)

    # -- Wrap loggers
    loggers = LoggerDict(loggers)
    return loggers
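# A minimal sketch (assumption, not from the original source) of the `loggers`
# config section read by configure_loggers. Only keys the method actually
# accesses are shown; values and the extra Neptune key are illustrative, and
# any remaining `neptune` entries are forwarded to NeptuneLogger as kwargs:
#
#   loggers:
#     logdir: 'runs/my_experiment'
#     debug: False
#     tensorboard: True
#     neptune:
#       project_qualified_name: 'my-team/my-project'   # hypothetical key
#       upload_source_files: ['*.py', '*.yml']          # globbed recursively
#
# The Neptune API token is expected in the NEPTUNE_API_TOKEN environment variable.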