def __init__(self, period, model, data_loader, num_iter):
    """
    Set up the precise-BN hook, or mark it as disabled when the model has
    no BN layers to update.

    Args:
        period (int): the period this hook is run, or 0 to not run during
            training. The hook will always run at the end of training.
        model (nn.Module): a module whose BN layers in training mode will
            all be updated by precise BN. Note that the user is responsible
            for ensuring the BN layers to be updated are in training mode
            when this hook is triggered.
        data_loader (iterable): it will produce data to be run by
            `model(data)`.
        num_iter (int): number of iterations used to compute the precise
            statistics.
    """
    self._logger = logging.getLogger(__name__)

    # Assign every attribute before the early return below, so a disabled
    # hook is still a fully initialised object. Previously the disabled path
    # left these names undefined, and any later read of e.g. `self._period`
    # would raise AttributeError.
    self._model = model
    self._data_loader = data_loader
    self._num_iter = num_iter
    self._period = period
    self._data_iter = None
    self._disabled = len(get_bn_modules(model)) == 0

    if self._disabled:
        self._logger.info(
            "PreciseBN is disabled because model does not contain BN layers in training mode."
        )
        return

    # Only start iterating the loader when the hook will actually use it.
    self._data_iter = iter(self._data_loader)
def build_hooks(self):
    """
    Assemble the default training hooks: optimization step, lr scheduling,
    iteration timing, precise BN, checkpointing, evaluation and event
    writing.

    Returns:
        list[HookBase]: the hooks in the order they are built here. The
            precise-BN slot holds ``None`` when precise BN is not enabled
            in the config or `get_bn_modules` finds nothing in the model.
    """
    cfg = self.cfg

    def test_and_save_results():
        # Keep the most recent evaluation results on the trainer.
        self._last_eval_results = self.test(self.cfg, self.model)
        return self._last_eval_results

    optimization_hook = hooks.OptimizationHook(
        accumulate_grad_steps=cfg.SOLVER.BATCH_SUBDIVISIONS,
        grad_clipper=None,
        mixed_precision=cfg.TRAINER.FP16.ENABLED,
    )
    lr_hook = hooks.LRScheduler(self.optimizer, self.scheduler)
    timer_hook = hooks.IterationTimer()

    if cfg.TEST.PRECISE_BN.ENABLED and get_bn_modules(self.model):
        # Same frequency as evaluation (but placed before it), fed by a
        # dedicated train loader so the main training loader is untouched.
        precise_bn_hook = hooks.PreciseBN(
            cfg.TEST.EVAL_PERIOD,
            self.model,
            self.build_train_loader(cfg),
            cfg.TEST.PRECISE_BN.NUM_ITER,
        )
    else:
        precise_bn_hook = None

    hook_list = [optimization_hook, lr_hook, timer_hook, precise_bn_hook]

    # The checkpointer comes after PreciseBN because PreciseBN updates the
    # model and those updates need to be saved. This is not always ideal:
    # with a different checkpoint frequency, some checkpoints may carry more
    # precise statistics than others.
    if comm.is_main_process():
        checkpoint_hook = hooks.PeriodicCheckpointer(
            self.checkpointer,
            cfg.SOLVER.CHECKPOINT_PERIOD,
            max_iter=self.max_iter,
            max_epoch=self.max_epoch,
        )
        hook_list.append(checkpoint_hook)

    # Evaluation goes after the checkpointer: if it fails, the saved
    # checkpoint is available for debugging.
    eval_hook = hooks.EvalHook(cfg.TEST.EVAL_PERIOD, test_and_save_results)
    hook_list.append(eval_hook)

    # Writers run last so that evaluation metrics get written too.
    if comm.is_main_process():
        writer_hook = hooks.PeriodicWriter(
            self.build_writers(), period=self.cfg.GLOBAL.LOG_INTERVAL
        )
        hook_list.append(writer_hook)

    return hook_list
def build_hooks(self):
    """
    Build a list of default hooks, including timing, evaluation,
    checkpointing, lr scheduling, precise BN, writing events.

    The precise-BN data loader is built with ``DATALOADER.NUM_WORKERS``
    temporarily set to 0; the configured value is restored afterwards so
    the trainer's config is left unchanged.

    Returns:
        list[HookBase]: the hooks in the order they are built here. The
            precise-BN slot holds ``None`` when precise BN is not enabled
            in the config or `get_bn_modules` finds nothing in the model.
    """
    cfg = self.cfg

    ret = [
        hooks.IterationTimer(),
        hooks.LRScheduler(self.optimizer, self.scheduler),
    ]

    precise_bn = None
    if cfg.TEST.PRECISE_BN.ENABLED and get_bn_modules(self.model):
        # `cfg` is the trainer's own config object, not a copy. Setting
        # NUM_WORKERS = 0 permanently (as this method used to do) leaks the
        # override into everything built from `self.cfg` later on. Scope it
        # to the one loader that needs it and always restore the old value.
        # NOTE(review): assumes the loader reads NUM_WORKERS while it is
        # being built, not lazily afterwards -- confirm against
        # `build_train_loader`.
        configured_workers = cfg.DATALOADER.NUM_WORKERS
        cfg.DATALOADER.NUM_WORKERS = 0  # save some memory and time for PreciseBN
        try:
            # Build a new data loader to not affect training
            bn_loader = self.build_train_loader(cfg)
        finally:
            cfg.DATALOADER.NUM_WORKERS = configured_workers

        precise_bn = hooks.PreciseBN(
            # Run at the same freq as (but before) evaluation.
            cfg.TEST.EVAL_PERIOD,
            self.model,
            bn_loader,
            cfg.TEST.PRECISE_BN.NUM_ITER,
        )
    # Keep the placeholder entry so the returned list has the same layout
    # whether or not precise BN is active.
    ret.append(precise_bn)

    # Do PreciseBN before checkpointer, because it updates the model and
    # needs to be saved by checkpointer.
    # This is not always the best: if checkpointing has a different
    # frequency, some checkpoints may have more precise statistics than
    # others.
    if comm.is_main_process():
        ret.append(
            hooks.PeriodicCheckpointer(self.checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD)
        )

    def test_and_save_results():
        # Keep the most recent evaluation results on the trainer.
        self._last_eval_results = self.test(self.cfg, self.model)
        return self._last_eval_results

    # Do evaluation after checkpointer, because then if it fails,
    # we can use the saved checkpoint to debug.
    ret.append(hooks.EvalHook(cfg.TEST.EVAL_PERIOD, test_and_save_results))

    if comm.is_main_process():
        # run writers in the end, so that evaluation metrics are written
        ret.append(
            hooks.PeriodicWriter(self.build_writers(), period=self.cfg.GLOBAL.LOG_INTERVAL)
        )
    return ret
def build_hooks(self):
    """
    Build a list of default hooks, including timing, evaluation,
    checkpointing, lr scheduling, precise BN, writing events.

    When ``cfg.TEST`` has a truthy ``SORT_BY`` entry, an extra evaluation
    hook saves a ``"model_best"`` checkpoint whenever the metric named by
    ``SORT_BY`` matches or beats the best value seen so far.

    Returns:
        list[HookBase]: the hooks in the order they are built here. The
            precise-BN slot holds ``None`` when precise BN is not enabled
            in the config or `get_bn_modules` finds nothing in the model.
    """
    cfg = self.cfg
    ret = [
        hooks.OptimizationHook(
            accumulate_grad_steps=cfg.SOLVER.BATCH_SUBDIVISIONS,
            grad_clipper=None,
            mixed_precision=cfg.TRAINER.FP16.ENABLED),
        hooks.LRScheduler(self.optimizer, self.scheduler),
        hooks.IterationTimer(),
        hooks.PreciseBN(
            # Run at the same freq as (but before) evaluation.
            cfg.TEST.EVAL_PERIOD,
            self.model,
            # Build a new data loader to not affect training
            self.build_train_loader(cfg),
            cfg.TEST.PRECISE_BN.NUM_ITER,
        )
        if cfg.TEST.PRECISE_BN.ENABLED and get_bn_modules(self.model)
        else None,
    ]

    # Do PreciseBN before checkpointer, because it updates the model and
    # needs to be saved by checkpointer.
    # This is not always the best: if checkpointing has a different
    # frequency, some checkpoints may have more precise statistics than
    # others.
    if comm.is_main_process():
        ret.append(
            hooks.PeriodicCheckpointer(self.checkpointer,
                                       cfg.SOLVER.CHECKPOINT_PERIOD,
                                       max_iter=self.max_iter,
                                       max_epoch=self.max_epoch))

    def test_and_save_results():
        # Keep the most recent evaluation results on the trainer.
        self._last_eval_results = self.test(self.cfg, self.model)
        return self._last_eval_results

    def save_best_model():
        # Save a "model_best" checkpoint when the tracked metric is at
        # least as good as the best value recorded so far.
        key = cfg.TEST.SORT_BY
        assert hasattr(
            self, '_last_eval_results'), "Must run after test_and_save_results()"
        cur_value = flatten_results_dict(self._last_eval_results)[key]
        if cur_value != cur_value:
            # NaN compares unequal to itself. It must never become the
            # "best" value, because every later comparison against it
            # would be False.
            return None

        best_results = self._max_eval_results
        # `None` means nothing has been recorded yet, so the first valid
        # result always wins. The baseline used to be a hard-coded 0.0,
        # which meant a metric with negative values could never be saved.
        if (best_results is None
                or cur_value >= flatten_results_dict(best_results)[key]):
            self._max_eval_results = self._last_eval_results
            # start save checkpoint
            self.checkpointer.save("model_best")
        return None

    # Do evaluation after checkpointer, because then if it fails,
    # we can use the saved checkpoint to debug.
    ret.append(hooks.EvalHook(cfg.TEST.EVAL_PERIOD, test_and_save_results))

    if cfg.TEST.get('SORT_BY'):
        # save max metric checkpoint, which means early stopping
        # NOTE(review): unlike the checkpointer and writers, this hook is
        # registered on every process -- confirm the metric lookup and
        # `checkpointer.save` behave as intended on non-main processes.
        self._max_eval_results = None
        ret.append(hooks.EvalHook(cfg.TEST.EVAL_PERIOD, save_best_model))

    if comm.is_main_process():
        # Here the default print/log frequency of each writer is used.
        # run writers in the end, so that evaluation metrics are written
        ret.append(
            hooks.PeriodicWriter(self.build_writers(), period=self.window_size))
        ret.append(
            hooks.PeriodicWriter(self.build_everystep_writers(), period=1))
    return ret