Example #1
0
    def fit(self, max_epochs=None):
        """Run the training loop epoch after epoch until a stop condition hits.

        Fires ``'training_start'`` before looping and ``'training_end'``
        afterwards. Every item yielded by ``self.iterate_data('training')``
        increments ``training.global_step`` and ``training.epoch_step`` and
        then calls ``self.train()`` followed by ``self.training_iteration()``.

        The loop ends when ``training.should_stop_training`` is truthy
        (checked after every step and again after every epoch), or when
        ``training.epoch`` equals ``training.max_epochs`` at the end of an
        epoch. Otherwise ``training.epoch`` is incremented and another epoch
        starts.

        Args:
            max_epochs: When given, overrides ``training.max_epochs``. When
                ``None`` and ``training`` has no ``max_epochs`` entry yet,
                the entry is created as ``None``, meaning "no limit" (a
                warning is emitted).
        """
        self.event('training_start')
        limit_missing = 'max_epochs' not in self.training
        if limit_missing or max_epochs is not None:
            self.training.max_epochs = max_epochs

        if self.training.max_epochs is None:
            bd.warn('max_epochs not set, training will continue indefinitely')

        def run_one_epoch():
            # One pass over the training data; returns early on a stop request.
            self.training.epoch_step = 0
            for _batch in self.iterate_data('training'):
                self.training.global_step += 1
                self.training.epoch_step += 1
                self.train()
                self.training_iteration()
                if self.training.should_stop_training:
                    return

        while True:
            run_one_epoch()
            # Re-read the limit: it may have been changed during the epoch.
            epoch_limit = self.training.max_epochs
            if self.training.should_stop_training:
                break
            if epoch_limit is not None and self.training.epoch == epoch_limit:
                break
            self.training.epoch += 1

        self.event('training_end')
Example #2
0
    def register(self, *component):
        """Register a component and the events it handles on this engine.

        Supported call forms::

            eng.register(func)                  # events come from a decorator
            eng.register('foo', 'bar', func)    # explicit event names
            eng.register(('foo', 'bar', func))  # same, packed in a tuple/list

        Explicit event names are validated with ``_check_valid_eventname``.
        Event names stored in a ``_bd_engine_events`` attribute (on the
        component itself, or on ``__func__`` for bound/static methods) are
        added to them, and duplicates are dropped. Members of the component
        that carry ``_bd_engine_events`` are registered recursively, and the
        component's ``attach(engine)`` method is called if it has one.
        Finally an ``Action(event, component)`` is stored in
        ``self._actions[event]`` under ``action.id`` for every collected
        event; replacing an existing id emits a warning.

        Returns:
            ``self``, so calls can be chained.

        Raises:
            RuntimeError: If no component is given (including an empty
                list/tuple), or if events were collected for a component
                that is not callable.
        """
        if len(component) < 1:
            raise RuntimeError(
                'Engine.register did not receive any components')
        if len(component) == 1:
            component = component[0]

        events = []
        # Events given explicitly as a tuple or list, in the form:
        #     eng.register('foo', 'bar', func)
        # or
        #     eng.register(('foo', 'bar', func))
        if isinstance(component, (list, tuple)):
            if not component:
                # e.g. eng.register([]) -- report it instead of letting
                # component[-1] raise an opaque IndexError.
                raise RuntimeError(
                    'Engine.register did not receive any components')
            events += list(component[:-1])
            component = component[-1]
        for x in events:
            _check_valid_eventname(x)

        # Add any events from decorator
        if hasattr(component, '_bd_engine_events'):
            events += component._bd_engine_events

        # Bound or staticmethods are registered on __func__
        if hasattr(component, '__func__') and hasattr(component.__func__,
                                                      '_bd_engine_events'):
            events += component.__func__._bd_engine_events

        # The same event can be collected more than once above; dedupe.
        events = list(set(events))

        # If object or class is given, recurse and register members
        for key, val in inspect.getmembers(component):
            # This is to avoid doubly processing bound methods
            if key == '__func__':
                continue
            if hasattr(val, '_bd_engine_events'):
                self.register(val)

        # For objects with attach method, run attach
        if hasattr(component, 'attach'):
            component.attach(self)

        # Only components that actually handle events must be callable.
        if events and not callable(component):
            raise RuntimeError(
                f'Attempted to register non callable {component} in Engine'
            )

        # Register all events
        for event in events:
            action_dict = self._actions.get(event, {})
            action = Action(event, component)
            already_exists = action.id in action_dict
            action_dict[action.id] = action
            self._actions[event] = action_dict
            if already_exists:
                bd.warn(
                    f'Component {component} already defined for event {event}.'
                )
        return self
Example #3
0
 def setup_criterion_checkpoints(self):
     """Return checkpoint specs for every attached criterion.

     Returns:
         A list with one ``{'state_key': 'criteria.<name>'}`` dict per key
         of ``self.criteria``, or ``None`` (after emitting a warning) when
         no criteria are attached yet.
     """
     if not self.criteria:
         bd.warn(
             'Attempted to setup criterion checkpoints before attaching criteria.'
         )
         return None
     return [dict(state_key=f'criteria.{k}') for k in self.criteria]
Example #4
0
 def setup_model_checkpoints(self):
     """Build one checkpoint spec per attached model.

     Each spec is a dict ``{'state_key': 'models.<name>'}`` for a key of
     ``self.models``. When no models are attached yet, a warning is
     emitted and ``None`` is returned instead of a list.
     """
     if self.models:
         return [
             {'state_key': f'models.{model_name}'}
             for model_name in self.models
         ]
     bd.warn('Attempted to setup model checkpoints before attaching models.')
     return None
Example #5
0
 def setup_optimizer_checkpoints(self):
     """Build one checkpoint spec per attached optimizer.

     Each spec is a dict ``{'state_key': 'optimizers.<name>'}`` for a key of
     ``self.optimizers``. When no optimizers are attached yet, a warning is
     emitted and ``None`` is returned instead of a list.
     """
     if self.optimizers:
         return [
             {'state_key': f'optimizers.{opt_name}'}
             for opt_name in self.optimizers
         ]
     bd.warn('Attempted to setup optimizer checkpoints before attaching optimizers.')
     return None
 def _warn_override(self, active_groups, arg_name, basename, count, line):
     """Warn that ``arg_name`` is being defined a second time.

     Args:
         active_groups: Iterable of group-name strings that also exposes an
             ``is_default`` flag; used to say where the duplicate happened.
         arg_name: Name of the argument defined twice.
         basename: File name reported in the message.
         count: Line number reported in the message.
         line: Raw source line; anything from the first ``'#'`` on is
             dropped and surrounding whitespace stripped before quoting it.
     """
     code_only = line.partition('#')[0].strip()
     if not active_groups.is_default:
         where = 'in groups: ({})'.format(', '.join(active_groups))
     else:
         where = 'in default group'
     location = "File '{}', line {}: '{}'.".format(basename, count, code_only)
     bd.warn(f'{arg_name} defined twice {where}. {location} Overwriting...')
Example #7
0
 def write(self, frame):
     """Pass ``frame`` to the underlying ``self.writer``.

     Must be called while the writer's context is active, i.e. ``self.writer``
     is set and ``self.context_count`` is positive. If the writer reports
     that it is not opened, the frame is dropped and a warning is emitted.

     Raises:
         RuntimeError: If called outside the managed context.
     """
     if self.writer is None or self.context_count <= 0:
         raise RuntimeError(
             'Can only write Video inside a managed context: E.g. use like:'
             '\nwith writer:\n\twriter.write(frame)\n')
     if not self.writer.isOpened():
         bd.warn('Writer not opened.')
         return
     self.writer.write(frame)
Example #8
0
 def attach_data(self, mode, force=False):
     """Build the dataset for ``mode`` and store it in ``self.data[mode]``.

     Skipped when ``'data.<mode>'`` is already contained in ``self``, unless
     ``force`` is true. The dataset is produced by calling
     ``self.setup_<mode>_data()``; if no such method exists a warning is
     emitted and nothing is attached.
     """
     if not force and f'data.{mode}' in self:
         return
     bd.log(f'Attaching {mode} dataset.')
     setup_name = f'setup_{mode}_data'
     if hasattr(self, setup_name):
         self.data[mode] = getattr(self, setup_name)()
         return
     bd.warn(
         f'Could not find setup function for {mode} dataset. Will not attach to engine.'
     )
Example #9
0
 def attach_criteria(self, force=False):
     """Populate ``self.criteria`` from ``self.setup_criteria()``.

     Does nothing when criteria are already attached, unless ``force`` is
     true. Without a ``setup_criteria`` method a warning is emitted and
     nothing is attached. A mapping returned by the setup function replaces
     ``self.criteria`` outright; any other value is stored as
     ``self.criteria.main``.
     """
     if not force and self.criteria:
         return
     if not hasattr(self, 'setup_criteria'):
         bd.warn(
             'Could not find setup function for criteria. Will not attach to engine.'
         )
         return
     result = self.setup_criteria()
     if not isinstance(result, Mapping):
         self.criteria.main = result
     else:
         self.criteria = result
Example #10
0
 def attach_models(self, force=False):
     """Populate ``self.models`` from ``self.setup_models()``.

     Does nothing when models are already attached, unless ``force`` is
     true. Without a ``setup_models`` method a warning is emitted and
     nothing is attached. A mapping returned by the setup function replaces
     ``self.models`` outright; any other value is stored as
     ``self.models.main``.
     """
     if not force and self.models:
         return
     if hasattr(self, 'setup_models'):
         built = self.setup_models()
         if isinstance(built, Mapping):
             self.models = built
         else:
             self.models.main = built
         return
     bd.warn(
         'Could not find setup function for models. Will not attach to engine.'
     )
Example #11
0
 def __next__(self):
     """Return the next frame read from ``self.capture``.

     Raises:
         StopIteration: When ``capture.read()`` reports no more frames, or
             when there is no open capture (a warning is emitted first;
             this typically means the context manager is not active).
     """
     if (self.capture is None) or not self.capture.isOpened():
         bd.warn(
             'Something went wrong with video iteration. Perhaps context manager is inactive.'
         )
         # Must *raise*: merely evaluating StopIteration would make
         # __next__ return None, and a for-loop would then never end.
         raise StopIteration
     ret, frame = self.capture.read()
     if not ret:
         raise StopIteration
     return frame
Example #12
0
 def attach_optimizers(self, force=False):
     """Populate ``self.optimizers`` from ``self.setup_optimizers()``.

     Does nothing when optimizers are already attached, unless ``force`` is
     true. Without a ``setup_optimizers`` method a warning is emitted and
     nothing is attached. A mapping returned by the setup function replaces
     ``self.optimizers`` outright. Any other value is stored as
     ``self.optimizers.main``, and a warning is emitted if it is not an
     ``optim.Optimizer`` instance.
     """
     if not force and self.optimizers:
         return
     if not hasattr(self, 'setup_optimizers'):
         bd.warn(
             'Could not find setup function for optimizers. Will not attach to engine.'
         )
         return
     built = self.setup_optimizers()
     if isinstance(built, Mapping):
         self.optimizers = built
         return
     self.optimizers.main = built
     if not isinstance(built, optim.Optimizer):
         bd.warn('Optimizer is not of type torch.optim.Optimizer')
Example #13
0
 def attach_loggers(self, force=False):
     """Populate ``self.loggers`` from ``self.setup_loggers()``.

     Does nothing when loggers are already attached, unless ``force`` is
     true. Without a ``setup_loggers`` method a warning is emitted and
     nothing is attached. Every returned logger is passed through
     ``_make_logger``:

     * a mapping keeps its keys;
     * a (non-string) sequence gets keys ``logger_<i>``, where ``i`` is the
       position in the sequence; falsy entries are skipped.

     Raises:
         RuntimeError: If ``setup_loggers()`` returns anything else. This
             includes a plain string or bytes, which would otherwise be
             split into one "logger" per character.
     """
     if not force and self.loggers:
         return
     if not hasattr(self, 'setup_loggers'):
         bd.warn(
             'Could not find setup function for loggers. Will not attach to engine.'
         )
         return
     loggers = self.setup_loggers()
     if isinstance(loggers, Mapping):
         self.loggers = {
             k: _make_logger(logger)
             for k, logger in loggers.items()
         }
     elif (isinstance(loggers, Sequence)
           and not isinstance(loggers, (str, bytes, bytearray))):
         self.loggers = {
             f'logger_{i}': _make_logger(logger)
             for i, logger in enumerate(loggers) if logger
         }
     else:
         raise RuntimeError(
             'setup_loggers() must return a dictionary or a sequence of '
             f'loggers, got {type(loggers).__name__}'
         )
Example #14
0
 def _use_argument_categories(self, *arg_cats, empty=True):
     """Add predefined argument categories to the argparse setup.

     Each category name must be a key of ``DEFAULT_CFG_DICT``. A deep copy of
     its argument specs is appended to ``_prv['argparse_arguments']``, and
     the name is added to ``_prv['argument_prebaked_categories']``. A warning
     is emitted for every already-defined flag that a category redefines.

     Args:
         *arg_cats: Category names; passing ``'core'`` also sets
             ``_prv['has_core_config']``.
         empty: If true (default), start from an empty argument list and
             category set instead of extending the current ones.

     Raises:
         RuntimeError: If called after setup has finished.
         ValueError: If a category is not in ``DEFAULT_CFG_DICT``.
     """
     if self._prv['done_setup']:
         raise RuntimeError('Attempted to change configuration after setup.')
     if empty:
         self._prv['argparse_arguments'] = []
         self._prv['argument_prebaked_categories'] = set()
     if 'core' in arg_cats:
         self._prv['has_core_config'] = True
     # Flags (without the leading '--') defined so far. This list is extended
     # after each category so that categories passed together in one call
     # are also checked against each other, not only against earlier calls.
     existing_flags = [x['flag'][2:] for x in self._prv['argparse_arguments']]
     for arg in arg_cats:
         if arg not in DEFAULT_CFG_DICT:
             raise ValueError(f'Unknown configuration: {arg}')
         new_args = deepcopy(DEFAULT_CFG_DICT[arg])
         new_flags = [x['flag'][2:] for x in new_args]
         new_flag_set = set(new_flags)
         for ef in existing_flags:
             if ef in new_flag_set:
                 bd.warn(f'Overriding previously defined argument: {ef}')
         self._prv['argparse_arguments'] += new_args
         self._prv['argument_prebaked_categories'].add(arg)
         existing_flags += new_flags
Example #15
0
 def register(self, loss, name, weight=1.0):
     """Add ``loss`` under ``name`` with ``weight``, or replace an existing one.

     The loss is set as attribute ``name`` on this object and then read back,
     so what gets registered is whatever ``setattr`` stored (the original
     notes this is to get the built loss when it is provided in a "magic"
     context). Re-registering an existing name warns and replaces its loss
     and weight in place, keeping its position.

     Args:
         loss: The loss to register.
         name: Attribute name; must not start with ``'_'`` and must not be
             ``'total'``.
         weight: Multiplier for this loss; converted with ``float()``.

     Returns:
         ``self``, so calls can be chained.

     Raises:
         ValueError: For a reserved ``name``, or (like ``TypeError``) from
             ``float()`` for an unconvertible ``weight``. In either case
             nothing is registered.
     """
     if name.startswith('_'):
         raise ValueError('Loss name can not start with "_".')
     if name == 'total':
         raise ValueError('Loss name can not be "total".')
     # Convert before touching any state, so an invalid weight cannot leave
     # the loss half-registered (set as an attribute but not in the lists).
     weight = float(weight)
     setattr(self, name, loss)
     # This is to get the built loss if it's provided in a "magic" context
     loss = getattr(self, name)
     if name in self._idx:
         bd.warn(f'Replacing {name} in MultiLoss')
         idx = self._idx[name]
         self._losses[idx] = loss
         self._names[idx] = name
         self._weights[idx] = weight
     else:
         self._losses.append(loss)
         self._names.append(name)
         self._weights.append(weight)
         self._idx[name] = len(self._names) - 1
     return self
Example #16
0
    def add_arguments(self, *args, override=False):
        """Register custom command-line arguments.

        Each ``arg`` is a dict with at least a ``'flag'`` key. A missing
        ``'--'`` prefix is added to the dict in place. Accepted arguments
        are appended to ``self._prv['argparse_arguments']``.

        Args:
            *args: Argument specification dicts.
            override: If true, an argument whose flag is already registered
                (including one added earlier in this same call) replaces the
                existing entry, with a warning, instead of raising.

        Raises:
            RuntimeError: If called after setup, or if a name is invalid,
                is a core/automatic argument, clashes with an attribute of
                this object, or is already defined while ``override`` is
                false.
        """
        if self._prv['done_setup']:
            raise RuntimeError('Attempted to add argument after setup.')
        arguments = self._prv['argparse_arguments']
        reserved = CORE_ARGNAMES + AUTOMATIC_ARGS
        own_attributes = set(dir(self))
        for arg in args:
            arg_name = arg['flag']
            if arg_name.startswith('--'):
                arg_name = arg_name[2:]
            else:
                # Add the dashes if missing
                arg['flag'] = f'--{arg_name}'
            if arg_name.startswith('-') or (not _is_valid_argname(arg_name)):
                raise RuntimeError(f'Argument {arg_name} is invalid.')

            if arg_name in reserved:
                msg = f'Argument \'{arg_name}\' is in the core arguments.'
                if override:
                    msg = msg[:-1] + ' and can not be overridden.'
                raise RuntimeError(msg)

            # Don't allow setting things in dir(self)
            if arg_name in own_attributes:
                raise RuntimeError(f'Can not set {arg_name} as an argument.')

            # Look the flag up in the live list for every argument. Indices
            # cached before the loop go stale after a deletion, and arguments
            # appended earlier in this call must be found too. The last match
            # wins, as it would in a {flag: index} dict built over the list.
            existing = None
            for i, x in enumerate(arguments):
                if x['flag'][2:] == arg_name:
                    existing = i
            if existing is not None:
                if override:
                    bd.warn(f'Overriding {arg_name} for argparse')
                    del arguments[existing]
                else:
                    raise RuntimeError(
                        f'Argument \'{arg_name}\' already defined. '
                        f'Existing flags can be overridden by '
                        f'passing override=True to add_arguments'
                    )
            arguments.append(arg)