def __init__(self, config, trainer):
    """Set up the callback and, if enabled, its learning-rate scheduler.

    Args:
        config (mmf_typings.DictConfig): Config for the callback.
        trainer (Type[BaseTrainer]): Trainer object; its ``optimizer`` is
            handed to ``build_scheduler`` when a scheduler is requested.
    """
    super().__init__(config, trainer)

    # Default to "no scheduler"; only replaced below when explicitly enabled.
    self._scheduler = None

    # ``training_config`` is presumably populated by the base class
    # constructor called above -- TODO confirm. The identity check against
    # ``True`` is intentional: any other value leaves the scheduler disabled.
    if self.training_config.lr_scheduler is not True:
        return

    self._scheduler = build_scheduler(trainer.optimizer, self.config)
def load_extras(self):
    """Initialise checkpointing, metering, training limits and optional extras.

    In order, this logs the torch version, creates the ``Checkpoint`` and
    ``Meter``, copies interval/limit settings from ``self.config.training``,
    builds ``EarlyStopping``, resets the epoch/iteration/update counters,
    loads checkpoint state, and then optionally builds the LR scheduler and
    the TensorBoard logger.
    """
    self.writer.write("Torch version is: " + torch.__version__)

    # Checkpoint is given the whole trainer and is built before any of the
    # attributes below exist on ``self``; keep this ordering.
    self.checkpoint = Checkpoint(self)
    self.meter = Meter()

    training_config = self.config.training
    self.training_config = training_config

    stop_cfg = training_config.early_stop
    stop_criteria = stop_cfg.criteria
    stop_minimize = stop_cfg.minimize
    stop_enabled = stop_cfg.enabled
    stop_patience = stop_cfg.patience

    self.log_interval = training_config.log_interval
    self.evaluation_interval = training_config.evaluation_interval
    self.checkpoint_interval = training_config.checkpoint_interval
    self.max_updates = training_config.max_updates
    self.should_clip_gradients = training_config.clip_gradients
    self.max_epochs = training_config.max_epochs

    self.early_stopping = EarlyStopping(
        self.model,
        self.checkpoint,
        stop_criteria,
        patience=stop_patience,
        minimize=stop_minimize,
        should_stop=stop_enabled,
    )

    # Counters start at zero; the checkpoint load right after presumably
    # overwrites them when resuming a run -- TODO confirm in Checkpoint.
    self.current_epoch = self.current_iteration = self.num_updates = 0
    self.checkpoint.load_state_dict()

    # From here on, read settings through ``self``: the load above received
    # the whole trainer, so do not assume the local alias is still current.
    self.not_debug = self.training_config.logger_level != "debug"

    # NOTE(review): the scheduler is built *after* the checkpoint load; verify
    # that resumed runs restore the scheduler's progress elsewhere.
    self.lr_scheduler = None
    if self.training_config.lr_scheduler is True:
        self.lr_scheduler = build_scheduler(self.optimizer, self.config)

    self.tb_writer = None
    if not self.training_config.tensorboard:
        return

    # A non-empty ``tensorboard_logdir`` environment setting overrides the
    # writer's log directory.
    default_log_dir = self.writer.log_dir
    env_log_dir = get_mmf_env(key="tensorboard_logdir")
    self.tb_writer = TensorboardLogger(
        env_log_dir or default_log_dir, self.current_iteration
    )