Ejemplo n.º 1
0
    def __init__(self, *args, **kwargs):
        """Initialize the object and, when enabled in the config, set up a
        Sacred ``Experiment`` (optionally attaching a MongoDB observer).

        All positional and keyword arguments are forwarded unchanged to the
        parent class; the configuration is then read from ``self.args``.
        """
        super().__init__(*args, **kwargs)
        args = self.args

        # -- Sacred is opt-in: a missing 'loggers.sacred' section disables it.
        args_sacred = get_maybe_missing_args(args.loggers, 'sacred')
        if args_sacred is None:
            self.use_sacred = False
        else:
            self.use_sacred = args_sacred.use_sacred

        if self.use_sacred:
            self.sacred_exp = Experiment(args.exp_name)
            # Filter interactive terminal noise out of the captured stdout.
            self.sacred_exp.captured_out_filter = apply_backspaces_and_linefeeds
            self.sacred_exp.add_config(vars(args))
            for source in self.get_sources():
                self.sacred_exp.add_source_file(source)

            # NOTE(review): the original read `args_sacred.sacred.mongodb_disable`
            # (a double `.sacred`, since args_sacred is already args.loggers.sacred)
            # and built the url from `args`, while every other MongoDB option below
            # is read from `args_sacred`. All MongoDB options are now read from
            # `args_sacred` for consistency -- confirm against the config schema.
            if not args_sacred.mongodb_disable:
                url = "{0.mongodb_url}:{0.mongodb_port}".format(args_sacred)
                # Explicit db name wins; otherwise derive one from the prefix
                # plus the alphanumeric characters of the dataset name.
                if (args_sacred.mongodb_name is not None
                        and args_sacred.mongodb_name != ''):
                    db_name = args_sacred.mongodb_name
                else:
                    db_name = args_sacred.mongodb_prefix + ''.join(
                        filter(str.isalnum, args.dataset_name.lower()))

                self.console_log.info('Connect to MongoDB@{}:{}'.format(
                    url, db_name))
                self.sacred_exp.observers.append(
                    MongoObserver.create(url=url, db_name=db_name))
Ejemplo n.º 2
0
    def _configure_scheduler_optimizer(self):
        """Build the LR scheduler described by ``args.optimizer.scheduler``.

        Returns:
            A scheduler instance wrapping ``self.optimizer``, or an empty
            dict when no (or an empty) scheduler section is configured.
            (The original ``-> dict`` annotation was wrong on the success
            path — a scheduler object is returned — so it was dropped.)
        """
        args = self.args

        # -- A missing or empty scheduler section means "no scheduler".
        args_scheduler = get_maybe_missing_args(args.optimizer, 'scheduler')
        if args_scheduler is None or not args_scheduler:
            return {}

        scheduler_name = args_scheduler.name
        scheduler_params = args_scheduler.params

        # -- Resolve the scheduler class from its configured name
        scheduler_class = get_scheduler_optimizer(scheduler_name)

        # -- Instantiate it on top of the already-built optimizer
        scheduler = scheduler_class(optimizer=self.optimizer,
                                    **scheduler_params)
        return scheduler
Ejemplo n.º 3
0
    def early_stopping(self, current_stats: dict) -> "tuple[bool, bool]":
        """This function implements a classic early stopping
        procedure with patience. An example of the arguments that
        can be used is provided.

        ```yml
        early_stopping:
          dataset: 'validation'
          metric: 'validation/y_acc'
          patience: 10
          mode: 'max'              # or 'min'
          train_until_end: False   # it keeps going until the end of training
          warmup: -1               # number of epochs to skip early stopping
                                   # (it disable patience count)

        ```

        Args:
            current_stats (dict): a possibly nested dictionary of the results
                from validation at current epoch. Keys should follow an
                hierarchy as dataset->stats->metric. For custom stats,
                `metric` can be retrieved overriding the method
                `self._get_metric_early_stopping`.

        Returns:
            A tuple (is_best, is_stop) describing the status of early stopping.
            `is_stop` is also assigned to the self object.

        Note:
            The original annotation ``-> (bool, bool)`` built an actual tuple
            object at definition time; a string annotation documents the same
            contract without being evaluated.
        """

        # -- No 'early_stopping' config section: never best, never stop.
        args = get_maybe_missing_args(self.args, 'early_stopping')
        if args is None:
            # -- Do not save in best, and do not stop
            return False, False

        dataset = args.dataset
        metric = args.metric
        patience = args.patience
        mode = args.mode
        train_until_end = get_maybe_missing_args(args, 'train_until_end',
                                                 False)
        warmup = get_maybe_missing_args(args, 'warmup', -1)
        # 'max' mode means higher metric is better; otherwise lower is better.
        compare_op = max if mode == "max" else min

        is_best = False
        current = self._get_metric_early_stopping(current_stats)

        if current is None:
            raise ValueError(
                "Metric {} does not exist in current_stats['{}'] \n"
                "It contains only these keys: {}".format(
                    metric, dataset, str(recursive_keys(current_stats))))

        # -- first epoch, initialize and do not stop
        if len(self._best_stats) == 0:
            self._update_best_stats(current)
            return True, False

        # -- Compare against the best value seen so far; compare_op returns
        # -- `best` unchanged only when `current` does not improve on it.
        best = self._best_stats[-1][1]
        if compare_op(current, best) != best:
            self._update_best_stats(current)
            is_best = True
        else:
            # Only count non-improving epochs once the warmup phase is over.
            if self.epoch > warmup:
                self._beaten_epochs += 1

        # -- Stop when patience is exhausted, unless configured to train to
        # -- the end or still inside the warmup window.
        if (self._beaten_epochs >= patience and not train_until_end
                and self.epoch > warmup):
            self._early_stop = True

        return is_best, self._early_stop
Ejemplo n.º 4
0
 def get_maybe_missing_args(self, key, default=None):
     """Convenience wrapper around the module-level ``get_maybe_missing_args``.

     Looks up ``key`` on ``self.args`` and returns ``default`` when the key
     is absent.
     """
     own_args = self.args
     return get_maybe_missing_args(own_args, key, default)
Ejemplo n.º 5
0
    def configure_loggers(self, external_logdir=None):
        """
        YAPT supports logging experiments with multiple loggers at the same time.
        By default, an experiment is logged by TensorBoardLogger.

        Args:
            external_logdir: if you want to override the logdir.
                It could be useful to use the same directory used by Ray Tune.

        Returns:
            A ``LoggerDict`` wrapping the configured loggers, or ``None``
            when running in test mode (no logging).
        """
        if external_logdir is not None:
            self.args.loggers.logdir = self._logdir = external_logdir
            self.console_log.warning(
                "external logdir {}".format(external_logdir))

            if self.args.loggers.debug:
                # Fixed grammar of the original warning ("is disable").
                self.console_log.warning(
                    "Debug flag is disabled when external logdir is used")

        # -- No loggers in test mode
        if self.mode == 'test':
            return None

        args_logger = self.args.loggers
        loggers = dict()

        safe_mkdirs(self._logdir, exist_ok=True)

        # -- Tensorboard is not default anymore
        if args_logger.tensorboard:
            loggers['tb'] = TensorBoardLogger(self._logdir)

        # -- Neptune: only configured when a non-empty 'neptune' section exists
        if (get_maybe_missing_args(args_logger, 'neptune') is not None
                and len(args_logger.neptune.keys()) > 0):
            # TODO: because of api key and sensitive data,
            # neptune project should be per_project in a separate file

            # TODO: THIS THIS SHOULD BE DONE FOR EACH LEAF
            # -- Convert OmegaConf containers to plain Python so they can be
            # -- passed as keyword arguments.
            args_neptune = dict()
            for key, val in args_logger.neptune.items():
                if isinstance(val, ListConfig):
                    val = list(val)
                elif isinstance(val, DictConfig):
                    val = dict(val)
                args_neptune[key] = val

            # -- Recursively search for files or extensions
            if 'upload_source_files' in args_neptune.keys():
                source_files = [
                    str(path) for ext in args_neptune['upload_source_files']
                    for path in Path('./').rglob(ext)
                ]
                del args_neptune['upload_source_files']
            else:
                source_files = None

            # NOTE(review): raises KeyError if NEPTUNE_API_TOKEN is unset --
            # intentional fail-fast, confirm before softening to .get().
            loggers['neptune'] = NeptuneLogger(
                api_key=os.environ['NEPTUNE_API_TOKEN'],
                experiment_name=self.args.exp_name,
                params=flatten_dict(self.args),
                logger=self.console_log,
                upload_source_files=source_files,
                **(args_neptune))

        # Wrap loggers
        loggers = LoggerDict(loggers)
        return loggers