def _get_mmf_trainer(self, ckpt_config=None, model_config=None, seed=2, max_updates=6):
    """Build an MMF trainer wired with checkpoint and early-stopping callbacks.

    Args:
        ckpt_config: Forwarded to ``self._get_ckpt_config``.
        model_config: Optional mapping. When truthy it is stored on the
            config as ``model_config``, its first key becomes
            ``config.model`` and the trainer is asked to load the model
            from the config.
        seed: Forwarded to ``get_mmf_trainer``.
        max_updates: Passed to ``self._get_ckpt_config`` as ``max_steps``.

    Returns:
        The trainer produced by ``get_mmf_trainer``, with metrics loaded and
        both callbacks attached as attributes and appended to ``callbacks``.
    """
    config = self._get_ckpt_config(ckpt_config=ckpt_config, max_steps=max_updates)

    # A truthy model_config both overrides the config and switches the
    # trainer to config-driven model loading.
    use_model_config = bool(model_config)
    if use_model_config:
        config.model_config = model_config
        config.model = list(model_config.keys())[0]

    trainer = get_mmf_trainer(
        config=config,
        load_model_from_config=use_model_config,
        seed=seed,
    )
    trainer.load_metrics()

    # Checkpointing: the trainer's init-start / train-end hooks are replaced
    # by the checkpoint callback's own handlers.
    ckpt_cb = CheckpointCallback(config, trainer)
    trainer.on_init_start = ckpt_cb.on_init_start
    trainer.on_train_end = ckpt_cb.on_train_end
    trainer.callbacks.append(ckpt_cb)
    trainer.checkpoint_callback = ckpt_cb
    trainer.lr_scheduler_callback = None

    # Early stopping is created only after the checkpoint wiring is in place.
    stop_cb = EarlyStoppingCallback(config, trainer)
    trainer.early_stop_callback = stop_cb
    trainer.callbacks.append(stop_cb)

    return trainer
def configure_callbacks(self):
    """Create the trainer's built-in callbacks and register them for events.

    Sets ``checkpoint_callback``, ``early_stop_callback``,
    ``logistics_callback`` and ``lr_scheduler_callback`` on ``self``, then
    rebuilds ``self.callbacks`` with the LR scheduler, checkpoint and
    logistics callbacks, in that order. The early-stopping callback is
    created but is not added to ``self.callbacks``.
    """
    self.checkpoint_callback = CheckpointCallback(self.config, self)
    self.early_stop_callback = EarlyStoppingCallback(self.config, self)
    self.logistics_callback = LogisticsCallback(self.config, self)
    self.lr_scheduler_callback = LRSchedulerCallback(self.config, self)

    # Start from a fresh list: if `callbacks` is a class-level attribute it
    # would otherwise be shared (and keep growing) across repeated calls.
    self.callbacks = []

    # Add callbacks for execution during events.
    # The LR scheduler callback must run before the checkpoint callback so
    # that the scheduler has stepped before a checkpoint is saved; otherwise
    # the scheduler's saved `last_epoch` would be wrong.
    self.callbacks.append(self.lr_scheduler_callback)
    self.callbacks.append(self.checkpoint_callback)
    self.callbacks.append(self.logistics_callback)
def _init_early_stopping(self, checkpoint):
    """Reset trainer counters and install callbacks with preset best values.

    Args:
        checkpoint: Not used by this method.

    The trainer's ``num_updates``, ``current_iteration`` and
    ``current_epoch`` are zeroed, fresh checkpoint and early-stopping
    callbacks are attached, the early-stopping state is given fixed
    "best" values, and ``current_epoch`` is finally set to 2.
    """
    trainer = self.trainer

    # Zero the progress counters before the callbacks are constructed.
    for counter in ("num_updates", "current_iteration", "current_epoch"):
        setattr(trainer, counter, 0)

    trainer.checkpoint_callback = CheckpointCallback(self.config, trainer)

    stop_cb = EarlyStoppingCallback(self.config, trainer)
    trainer.early_stop_callback = stop_cb

    # Seed the early-stopping tracker with fixed best-so-far values.
    tracker = stop_cb.early_stopping
    tracker.best_monitored_iteration = 1000
    tracker.best_monitored_update = 1000
    tracker.best_monitored_value = 0.1

    trainer.current_epoch = 2
def configure_callbacks(self):
    """Create the trainer's built-in callbacks and register them for events.

    Sets ``checkpoint_callback``, ``early_stop_callback``,
    ``logistics_callback`` and ``lr_scheduler_callback`` on ``self``, then
    rebuilds ``self.callbacks`` with the LR scheduler, checkpoint and
    logistics callbacks, in that order. The early-stopping callback is
    created but is not added to ``self.callbacks``.
    """
    self.checkpoint_callback = CheckpointCallback(self.config, self)
    self.early_stop_callback = EarlyStoppingCallback(self.config, self)
    self.logistics_callback = LogisticsCallback(self.config, self)
    self.lr_scheduler_callback = LRSchedulerCallback(self.config, self)

    # Start from a fresh list: if `callbacks` is a class-level attribute it
    # would otherwise be shared (and keep growing) across repeated calls.
    self.callbacks = []

    # Add callbacks for execution during events
    self.callbacks.append(self.lr_scheduler_callback)

    # checkpoint_callback needs to be called after lr_scheduler_callback so that
    # lr_scheduler_callback._scheduler.step() happens before saving checkpoints
    # (otherwise the saved last_epoch in scheduler would be wrong)
    self.callbacks.append(self.checkpoint_callback)
    self.callbacks.append(self.logistics_callback)
def configure_callbacks(self):
    """Instantiate the built-in callbacks and rebuild ``self.callbacks``.

    Four callbacks are created and stored as attributes on ``self``. Three
    of them (LR scheduler, checkpoint, logistics) are registered, in that
    order, in a freshly created ``self.callbacks`` list. The early-stopping
    callback is stored as an attribute only.
    """
    cfg = self.config
    self.checkpoint_callback = CheckpointCallback(cfg, self)
    self.early_stop_callback = EarlyStoppingCallback(cfg, self)
    self.logistics_callback = LogisticsCallback(cfg, self)
    self.lr_scheduler_callback = LRSchedulerCallback(cfg, self)

    # A new list is assigned instead of appending to the existing one:
    # callbacks are class variables and would otherwise be shared between
    # multiple interactive shell calls to `run`.
    #
    # Order matters: the LR scheduler callback comes before the checkpoint
    # callback so that lr_scheduler_callback._scheduler.step() happens
    # before checkpoints are saved (otherwise the saved last_epoch in the
    # scheduler would be wrong).
    self.callbacks = [
        self.lr_scheduler_callback,
        self.checkpoint_callback,
        self.logistics_callback,
    ]
def ready_trainer(trainer):
    """Prepare ``trainer`` for use: writer, device, seed, model and callbacks.

    Args:
        trainer: Object exposing ``config``, ``callbacks``,
            ``configure_device``, ``configure_seed``, ``load_model`` and
            ``on_init_start``. It is modified in place.

    The writer registered under ``"writer"`` in the registry is reused when
    one exists; otherwise a new ``Logger`` is created and registered. Only
    the checkpoint callback is appended to ``trainer.callbacks``; the
    early-stopping callback is attached as an attribute.
    """
    from mmf.common.registry import registry
    from mmf.utils.logger import Logger, TensorboardLogger

    trainer.run_type = trainer.config.get("run_type", "train")

    registered_writer = registry.get("writer", no_warning=True)
    if not registered_writer:
        # Nothing registered yet: create a logger and publish it.
        trainer.writer = Logger(trainer.config)
        registry.register("writer", trainer.writer)
    else:
        trainer.writer = registered_writer

    trainer.configure_device()
    trainer.configure_seed()
    trainer.load_model()

    # Callback classes are imported only after the model has been loaded,
    # matching the original import position.
    from mmf.trainers.callbacks.checkpoint import CheckpointCallback
    from mmf.trainers.callbacks.early_stopping import EarlyStoppingCallback

    cfg = trainer.config
    trainer.checkpoint_callback = CheckpointCallback(cfg, trainer)
    trainer.early_stop_callback = EarlyStoppingCallback(cfg, trainer)
    trainer.callbacks.append(trainer.checkpoint_callback)

    trainer.on_init_start()
def configure_callbacks(self):
    """Instantiate built-in and user-defined callbacks into ``self.callbacks``.

    Four built-in callbacks are created and stored as attributes on
    ``self``. Three of them (LR scheduler, checkpoint, logistics) are
    registered, in that order, in a freshly created ``self.callbacks``
    list; the early-stopping callback is stored as an attribute only.
    Every entry under ``config.training.callbacks`` is then resolved
    through the registry by its ``type`` and instantiated with its
    ``params`` as keyword arguments.
    """
    cfg = self.config
    self.checkpoint_callback = CheckpointCallback(cfg, self)
    self.early_stop_callback = EarlyStoppingCallback(cfg, self)
    self.logistics_callback = LogisticsCallback(cfg, self)
    self.lr_scheduler_callback = LRSchedulerCallback(cfg, self)

    # A new list is assigned instead of appending to the existing one:
    # callbacks are class variables and would otherwise be shared between
    # multiple interactive shell calls to `run`.
    #
    # Order matters: the LR scheduler callback comes before the checkpoint
    # callback so that lr_scheduler_callback._scheduler.step() happens
    # before checkpoints are saved (otherwise the saved last_epoch in the
    # scheduler would be wrong).
    self.callbacks = [
        self.lr_scheduler_callback,
        self.checkpoint_callback,
        self.logistics_callback,
    ]

    # User-defined callbacks are appended after the built-in ones.
    for spec in self.config.training.get("callbacks", []):
        cb_type, cb_kwargs = spec.type, spec.params
        cb_class = registry.get_callback_class(cb_type)
        self.callbacks.append(cb_class(self.config, self, **cb_kwargs))