def fit(self, max_epochs=None):
    """Run the main training loop until a stop is requested or max_epochs is reached.

    Fires the 'training_start' and 'training_end' events around the loop and
    advances ``self.training.global_step`` / ``epoch_step`` / ``epoch`` counters.

    Args:
        max_epochs: Optional epoch limit. It is stored on
            ``self.training.max_epochs`` when explicitly provided, or when the
            training state has no ``max_epochs`` entry yet. ``None`` means
            train indefinitely (a warning is emitted).
    """
    self.event('training_start')
    # Store the argument when provided, or when training state has no
    # max_epochs entry at all (in which case None is stored and we warn below).
    if ('max_epochs' not in self.training) or (max_epochs is not None):
        self.training.max_epochs = max_epochs
    if self.training.max_epochs is None:
        bd.warn('max_epochs not set, training will continue indefinitely')
    while True:
        self.training.epoch_step = 0
        for _ in self.iterate_data('training'):
            self.training.global_step += 1
            self.training.epoch_step += 1
            self.train()
            self.training_iteration()
            # Allow components to request an early exit mid-epoch.
            if self.training.should_stop_training:
                break
        # Get max epochs again in case it changed:
        me = self.training.max_epochs
        # NOTE(review): equality (==) is used here, so if epoch were ever
        # already past max_epochs this check would not stop the loop — confirm
        # that epoch always starts at or below max_epochs.
        if self.training.should_stop_training or (
                (me is not None) and (self.training.epoch == me)):
            break
        self.training.epoch += 1
    self.event('training_end')
def register(self, *component):
    """Register a component (callable, object, or class) on the engine.

    Accepted call forms:
        eng.register(func)                      # events taken from decorator
        eng.register('foo', 'bar', func)        # explicit event names
        eng.register(('foo', 'bar', func))      # same, as one tuple/list

    Objects/classes are walked with ``inspect.getmembers`` and any member
    carrying ``_bd_engine_events`` is registered recursively. If the component
    has an ``attach`` method it is called with the engine. Returns ``self``
    so calls can be chained.

    Raises:
        RuntimeError: if no component is given, or a non-callable component
            is registered for an event.
    """
    if len(component) < 1:
        # BUG FIX: message previously read 'did not receive and components'.
        raise RuntimeError(
            'Engine.register did not receive any components')
    if len(component) == 1:
        component = component[0]
    events = []
    # Events from adding as a tuple or list in the form:
    #   eng.register('foo', 'bar', func)
    # Or
    #   eng.register(('foo', 'bar', func))
    if isinstance(component, (list, tuple)):
        events += list(component[:-1])
        component = component[-1]
    for x in events:
        _check_valid_eventname(x)
    # Add any events from decorator
    if hasattr(component, '_bd_engine_events'):
        events += component._bd_engine_events
    # Bound or staticmethods are registered on __func__
    if hasattr(component, '__func__') and hasattr(
            component.__func__, '_bd_engine_events'):
        events += component.__func__._bd_engine_events
    # De-duplicate event names (order is irrelevant for registration).
    events = list(set(events))
    # If object or class is given, recurse and register members
    for key, val in inspect.getmembers(component):
        # This is to avoid doubly processing bound methods
        if key == '__func__':
            continue
        if hasattr(val, '_bd_engine_events'):
            self.register(val)
    # For objects with attach method, run attach
    if hasattr(component, 'attach'):
        component.attach(self)
    # Register all events
    for event in events:
        if not callable(component):
            raise RuntimeError(
                f'Attempted to register non callable {component} in Engine'
            )
        action_dict = self._actions.get(event, {})
        action = Action(event, component)
        already_exists = action.id in action_dict
        action_dict[action.id] = action
        self._actions[event] = action_dict
        if already_exists:
            # BUG FIX: 'Compoenent' typo corrected.
            bd.warn(
                f'Component {component} already defined for event {event}.'
            )
    return self
def setup_criterion_checkpoints(self):
    """Build checkpoint state-key specs for every attached criterion.

    Returns:
        A list of dicts, one ``{'state_key': 'criteria.<name>'}`` per entry
        in ``self.criteria``, or ``None`` (with a warning) when no criteria
        are attached yet.
    """
    if not self.criteria:
        # BUG FIX: warning previously said 'before attaching optimizers'
        # (copy-paste from setup_optimizer_checkpoints).
        bd.warn(
            'Attempted to setup criterion checkpoints before attaching criteria.'
        )
        return None
    return [dict(state_key=f'criteria.{k}') for k in self.criteria]
def setup_model_checkpoints(self):
    """Build checkpoint state-key specs for every attached model.

    Returns:
        A list of dicts, one ``{'state_key': 'models.<name>'}`` per entry in
        ``self.models``, or ``None`` (with a warning) when no models are
        attached yet.
    """
    if self.models:
        return [{'state_key': f'models.{name}'} for name in self.models]
    bd.warn(
        'Attempted to setup model checkpoints before attaching models.'
    )
    return None
def setup_optimizer_checkpoints(self):
    """Build checkpoint state-key specs for every attached optimizer.

    Returns:
        A list of dicts, one ``{'state_key': 'optimizers.<name>'}`` per entry
        in ``self.optimizers``, or ``None`` (with a warning) when no
        optimizers are attached yet.
    """
    if self.optimizers:
        return [{'state_key': f'optimizers.{name}'} for name in self.optimizers]
    bd.warn(
        'Attempted to setup optimizer checkpoints before attaching optimizers.'
    )
    return None
def _warn_override(self, arg_groups, arg_name, basename, count, line):
    """Emit a warning that *arg_name* was defined twice and will be overwritten.

    The warning names the active group(s) and the file/line where the second
    definition was found; any trailing '#'-comment on the line is stripped.
    """
    stripped = line.split('#')[0].strip()
    if arg_groups.is_default:
        where = 'in default group'
    else:
        where = 'in groups: (' + ', '.join(arg_groups) + ')'
    location = f'File \'{basename}\', line {count}: \'{stripped}\'.'
    bd.warn(f'{arg_name} defined twice {where}. {location} Overwriting...')
def write(self, frame):
    """Write a single frame through the underlying video writer.

    Raises:
        RuntimeError: when called outside an active managed context
            (no writer, or context count is not positive).
    """
    writer = self.writer
    if writer is None or self.context_count <= 0:
        raise RuntimeError(
            'Can only write Video inside a managed context: E.g. use like:'
            '\nwith writer:\n\twriter.write(frame)\n')
    if writer.isOpened():
        writer.write(frame)
    else:
        # Best-effort: warn rather than crash if the backend never opened.
        bd.warn('Writer not opened.')
def attach_data(self, mode, force=False):
    """Attach the dataset for *mode* by calling ``setup_<mode>_data``.

    Does nothing when ``data.<mode>`` is already present, unless *force* is
    true. Warns and returns without attaching when no setup function exists.
    """
    # Already attached and no force: nothing to do.
    if not force and (f'data.{mode}' in self):
        return
    bd.log(f'Attaching {mode} dataset.')
    setup_name = f'setup_{mode}_data'
    if not hasattr(self, setup_name):
        bd.warn(
            f'Could not find setup function for {mode} dataset. Will not attach to engine.'
        )
        return
    self.data[mode] = getattr(self, setup_name)()
def attach_criteria(self, force=False):
    """Attach criteria built by ``setup_criteria`` to the engine.

    Skipped when criteria already exist, unless *force* is true. A mapping
    result replaces ``self.criteria`` wholesale; a single criterion is stored
    under ``self.criteria.main``.
    """
    if self.criteria and not force:
        return
    if not hasattr(self, 'setup_criteria'):
        bd.warn(
            'Could not find setup function for criteria. Will not attach to engine.'
        )
        return
    built = self.setup_criteria()
    if isinstance(built, Mapping):
        self.criteria = built
    else:
        self.criteria.main = built
def attach_models(self, force=False):
    """Attach models built by ``setup_models`` to the engine.

    Skipped when models already exist, unless *force* is true. A mapping
    result replaces ``self.models`` wholesale; a single model is stored under
    ``self.models.main``.
    """
    if self.models and not force:
        return
    if not hasattr(self, 'setup_models'):
        bd.warn(
            'Could not find setup function for models. Will not attach to engine.'
        )
        return
    built = self.setup_models()
    if isinstance(built, Mapping):
        self.models = built
    else:
        self.models.main = built
def __next__(self): if (self.capture is not None) and self.capture.isOpened(): ret, frame = self.capture.read() if not ret: raise StopIteration else: return frame else: bd.warn( 'Something went wrong with video iteration. Perhaps context manager is inactive.' ) StopIteration
def attach_optimizers(self, force=False):
    """Attach optimizers built by ``setup_optimizers`` to the engine.

    Skipped when optimizers already exist, unless *force* is true. A mapping
    result replaces ``self.optimizers`` wholesale; a single optimizer is
    stored under ``self.optimizers.main``.
    """
    if force or (not self.optimizers):
        if not hasattr(self, 'setup_optimizers'):
            bd.warn(
                'Could not find setup function for optimizers. Will not attach to engine.'
            )
            return
        optims = self.setup_optimizers()
        if isinstance(optims, Mapping):
            self.optimizers = optims
        else:
            self.optimizers.main = optims
            # NOTE(review): the type sanity-check below only covers the
            # single-optimizer case; mapping values are not checked — confirm
            # that is intentional.
            if not isinstance(optims, optim.Optimizer):
                bd.warn('Optimizer is not of type torch.optim.Optimizer')
def attach_loggers(self, force=False):
    """Attach loggers built by ``setup_loggers`` to the engine.

    Skipped when loggers already exist, unless *force* is true. A mapping
    result is wrapped value-by-value with ``_make_logger``; a sequence is
    converted to a dict with generated ``logger_<i>`` keys (falsy entries are
    skipped).

    Raises:
        RuntimeError: when ``setup_loggers()`` returns neither a mapping nor
            a sequence.
    """
    if force or (not self.loggers):
        if not hasattr(self, 'setup_loggers'):
            bd.warn(
                'Could not find setup function for loggers. Will not attach to engine.'
            )
            return
        loggers = self.setup_loggers()
        if isinstance(loggers, Mapping):
            self.loggers = {
                k: _make_logger(logger) for k, logger in loggers.items()
            }
        elif isinstance(loggers, Sequence):
            self.loggers = {
                f'logger_{i}': _make_logger(logger)
                for i, logger in enumerate(loggers) if logger
            }
        else:
            # BUG FIX: message previously claimed only a dictionary is
            # accepted, contradicting the Sequence branch above.
            raise RuntimeError(
                'setup_loggers() must return a dictionary or a sequence of loggers'
            )
def _use_argument_categories(self, *arg_cats, empty=True):
    """Load pre-baked argparse argument categories from DEFAULT_CFG_DICT.

    Args:
        *arg_cats: category names; each must be a key of DEFAULT_CFG_DICT.
        empty: when true (default), discard any previously configured
            arguments and categories before loading.

    Raises:
        RuntimeError: if called after setup has completed.
        ValueError: for a category name not present in DEFAULT_CFG_DICT.
    """
    if self._prv['done_setup']:
        raise RuntimeError('Attempted to change configuration after setup.')
    if empty:
        self._prv['argparse_arguments'] = []
        self._prv['argument_prebaked_categories'] = set()
    if 'core' in arg_cats:
        self._prv['has_core_config'] = True
    # Flags already configured, with the leading '--' stripped.
    existing_flags = [x['flag'][2:] for x in self._prv['argparse_arguments']]
    for arg in arg_cats:
        if arg in DEFAULT_CFG_DICT:
            # Deep-copy so later mutation never leaks into the shared defaults.
            new_args = deepcopy(DEFAULT_CFG_DICT[arg])
            new_arg_flags = set(x['flag'][2:] for x in new_args)
            # Warn on collisions; the new category's entries are appended
            # regardless (they override at parse time).
            for ef in existing_flags:
                if ef in new_arg_flags:
                    bd.warn(f'Overriding previously defined argument: {ef}')
            self._prv['argparse_arguments'] += new_args
            self._prv['argument_prebaked_categories'].add(arg)
        else:
            raise ValueError(f'Unknown configuration: {arg}')
def register(self, loss, name, weight=1.0):
    """Register (or replace) a named, weighted loss on this MultiLoss.

    The loss is also set as an attribute under *name*. Returns ``self`` so
    calls can be chained.

    Raises:
        ValueError: for names starting with '_' or equal to 'total'.
    """
    if name.startswith('_'):
        raise ValueError('Loss name can not start with "_".')
    if name == 'total':
        raise ValueError('Loss name can not be "total".')
    setattr(self, name, loss)
    # This is to get the built loss if it's provided in a "magic" context
    loss = getattr(self, name)
    weight = float(weight)
    slot = self._idx.get(name)
    if slot is None:
        self._losses.append(loss)
        self._names.append(name)
        self._weights.append(weight)
        self._idx[name] = len(self._names) - 1
    else:
        bd.warn(f'Replacing {name} in MultiLoss')
        self._losses[slot] = loss
        self._names[slot] = name
        self._weights[slot] = weight
    return self
def add_arguments(self, *args, override=False):
    """Add argparse-style argument specs to the pending configuration.

    Each *arg* is a dict with at least a 'flag' entry; the '--' prefix is
    added when missing. Names must be valid, must not collide with core /
    automatic arguments or with attributes of the tool, and must not repeat
    unless *override* is true.

    Raises:
        RuntimeError: after setup is done, for invalid/conflicting names,
            or for duplicates without ``override=True``.
    """
    if self._prv['done_setup']:
        raise RuntimeError('Attempted to add argument after setup.')
    # Map of already-known flag names (without '--') to their index.
    current_flags = {
        x['flag'][2:]: i
        for i, x in enumerate(self._prv['argparse_arguments'])
    }
    for arg in args:
        arg_name = arg['flag']
        if arg_name.startswith('--'):
            arg_name = arg_name[2:]
        else:
            # Add the dashes if missing
            arg['flag'] = f'--{arg_name}'
        if arg_name.startswith('-') or (not _is_valid_argname(arg_name)):
            raise RuntimeError(f'Argument {arg_name} is invalid.')
        if arg_name in CORE_ARGNAMES + AUTOMATIC_ARGS:
            msg = f'Argument \'{arg_name}\' is in the core arguments.'
            if override:
                # Typo fix: 'overriden' -> 'overridden'.
                msg = msg[:-1] + ' and can not be overridden.'
            raise RuntimeError(msg)
        # Don't allow setting things in dir(self)
        if arg_name in dir(self):
            # BUG FIX: original lacked the f-prefix and raised the literal
            # string 'Can not set {} as an argument.'.
            raise RuntimeError(f'Can not set {arg_name} as an argument.')
        if arg_name in current_flags:
            if override:
                bd.warn(f'Overriding {arg_name} for argparse')
                del self._prv['argparse_arguments'][current_flags[arg_name]]
                # BUG FIX: deleting an entry shifts every later index, so the
                # map must be rebuilt or subsequent overrides would delete the
                # wrong argument.
                current_flags = {
                    x['flag'][2:]: i
                    for i, x in enumerate(self._prv['argparse_arguments'])
                }
            else:
                raise RuntimeError(
                    f'Argument \'{arg_name}\' already defined. '
                    f'Existing flags can be overridden by '
                    f'passing override=True to add_args'
                )
        self._prv['argparse_arguments'].append(arg)
        # Track the new entry so duplicates within this same call are caught.
        current_flags[arg_name] = len(self._prv['argparse_arguments']) - 1