コード例 #1
0
    def _get_mmf_trainer(self,
                         ckpt_config=None,
                         model_config=None,
                         seed=2,
                         max_updates=6):
        """Build an MMF trainer wired with checkpoint and early-stopping callbacks.

        Args:
            ckpt_config: Checkpoint overrides forwarded to
                ``self._get_ckpt_config``.
            model_config: Optional mapping; when truthy it replaces
                ``config.model_config``, its first key becomes
                ``config.model``, and the trainer is asked to load the model
                from config.
            seed: Seed forwarded to ``get_mmf_trainer``.
            max_updates: Forwarded as ``max_steps`` to
                ``self._get_ckpt_config``.

        Returns:
            The trainer returned by ``get_mmf_trainer`` after metrics are
            loaded and the callbacks are attached.
        """
        config = self._get_ckpt_config(ckpt_config=ckpt_config,
                                       max_steps=max_updates)

        build_from_config = bool(model_config)
        if build_from_config:
            config.model_config = model_config
            # The first key of model_config is taken as the model name.
            config.model = next(iter(model_config.keys()))

        trainer = get_mmf_trainer(
            config=config,
            load_model_from_config=build_from_config,
            seed=seed,
        )
        trainer.load_metrics()

        ckpt_cb = CheckpointCallback(config, trainer)
        # Route the trainer's init/teardown hooks straight to the checkpoint
        # callback, in addition to registering it as a regular callback.
        trainer.on_init_start = ckpt_cb.on_init_start
        trainer.on_train_end = ckpt_cb.on_train_end
        trainer.callbacks.append(ckpt_cb)
        trainer.checkpoint_callback = ckpt_cb

        # These trainers run without an LR scheduler callback.
        trainer.lr_scheduler_callback = None

        stop_cb = EarlyStoppingCallback(config, trainer)
        trainer.early_stop_callback = stop_cb
        trainer.callbacks.append(stop_cb)

        return trainer
コード例 #2
0
ファイル: mmf_trainer.py プロジェクト: slbinilkumar/mmf
    def configure_callbacks(self):
        """Create the built-in callbacks and register the event-driven ones.

        Sets ``checkpoint_callback``, ``early_stop_callback``,
        ``logistics_callback`` and ``lr_scheduler_callback`` on ``self`` and
        rebuilds ``self.callbacks`` from scratch with the callbacks that run
        during training events.
        """
        self.checkpoint_callback = CheckpointCallback(self.config, self)
        self.early_stop_callback = EarlyStoppingCallback(self.config, self)
        self.logistics_callback = LogisticsCallback(self.config, self)
        self.lr_scheduler_callback = LRSchedulerCallback(self.config, self)

        # Start from a fresh list: `callbacks` may be a class variable, and
        # appending to it in place would leak callbacks across trainers and
        # across repeated calls to `run`.
        self.callbacks = []
        # lr_scheduler_callback must run before checkpoint_callback so that the
        # scheduler has stepped before a checkpoint is saved; otherwise the
        # scheduler's saved last_epoch would be wrong.
        self.callbacks.append(self.lr_scheduler_callback)
        self.callbacks.append(self.checkpoint_callback)
        self.callbacks.append(self.logistics_callback)
        # NOTE(review): early_stop_callback is created but not registered here;
        # presumably the training loop invokes it directly — confirm.
コード例 #3
0
ファイル: test_checkpoint.py プロジェクト: srag21/mmf
 def _init_early_stopping(self, checkpoint):
     """Reset trainer counters and attach fresh checkpoint/early-stop callbacks.

     The early-stopping state is then seeded with a best monitored value of
     0.1 recorded at iteration/update 1000, and the trainer is left at
     epoch 2. The ``checkpoint`` argument is not used.
     """
     trainer = self.trainer
     # Zero the progress counters *before* building the callbacks, in case
     # their constructors read them.
     trainer.num_updates = 0
     trainer.current_iteration = 0
     trainer.current_epoch = 0
     trainer.checkpoint_callback = CheckpointCallback(self.config, trainer)
     trainer.early_stop_callback = EarlyStoppingCallback(self.config, trainer)

     stopper = trainer.early_stop_callback.early_stopping
     stopper.best_monitored_iteration = 1000
     stopper.best_monitored_update = 1000
     stopper.best_monitored_value = 0.1
     trainer.current_epoch = 2
コード例 #4
0
    def configure_callbacks(self):
        """Create the built-in callbacks and register the event-driven ones.

        Sets ``checkpoint_callback``, ``early_stop_callback``,
        ``logistics_callback`` and ``lr_scheduler_callback`` on ``self`` and
        rebuilds ``self.callbacks`` from scratch with the callbacks that run
        during training events.
        """
        self.checkpoint_callback = CheckpointCallback(self.config, self)
        self.early_stop_callback = EarlyStoppingCallback(self.config, self)
        self.logistics_callback = LogisticsCallback(self.config, self)
        self.lr_scheduler_callback = LRSchedulerCallback(self.config, self)

        # Start from a fresh list: `callbacks` may be a class variable, and
        # appending to it in place would leak callbacks across trainers and
        # across repeated calls to `run`.
        self.callbacks = []
        # Add callbacks for execution during events
        self.callbacks.append(self.lr_scheduler_callback)
        # checkpoint_callback needs to be called after lr_scheduler_callback so that
        # lr_scheduler_callback._scheduler.step() happens before saving checkpoints
        # (otherwise the saved last_epoch in scheduler would be wrong)
        self.callbacks.append(self.checkpoint_callback)
        self.callbacks.append(self.logistics_callback)
コード例 #5
0
    def configure_callbacks(self):
        """Instantiate the built-in callbacks and register the event-driven ones.

        All four callbacks are stored as attributes on ``self``; only the LR
        scheduler, checkpoint and logistics callbacks are placed (in that
        order) into a freshly built ``self.callbacks`` list.
        """
        self.checkpoint_callback = CheckpointCallback(self.config, self)
        self.early_stop_callback = EarlyStoppingCallback(self.config, self)
        self.logistics_callback = LogisticsCallback(self.config, self)
        self.lr_scheduler_callback = LRSchedulerCallback(self.config, self)

        # Build a brand-new list rather than appending: `callbacks` is a class
        # variable and would otherwise be shared between interactive `run`
        # calls. Order matters — the LR scheduler must step before the
        # checkpoint callback saves, or the saved scheduler last_epoch would
        # be wrong.
        self.callbacks = [
            self.lr_scheduler_callback,
            self.checkpoint_callback,
            self.logistics_callback,
        ]
コード例 #6
0
def ready_trainer(trainer):
    """Prepare ``trainer`` for use: writer, device, seed, model, checkpointing.

    Reuses the ``"writer"`` entry from the registry when one exists;
    otherwise creates a ``Logger`` from ``trainer.config`` and registers it.
    Then configures device and seed, loads the model, attaches checkpoint
    and early-stopping callbacks, and fires ``trainer.on_init_start()``.

    Args:
        trainer: Object exposing ``config``, ``configure_device``,
            ``configure_seed``, ``load_model``, ``callbacks`` and
            ``on_init_start``.
    """
    from mmf.common.registry import registry
    from mmf.trainers.callbacks.checkpoint import CheckpointCallback
    from mmf.trainers.callbacks.early_stopping import EarlyStoppingCallback
    from mmf.utils.logger import Logger

    trainer.run_type = trainer.config.get("run_type", "train")

    writer = registry.get("writer", no_warning=True)
    if not writer:
        writer = Logger(trainer.config)
        registry.register("writer", writer)
    trainer.writer = writer

    trainer.configure_device()
    trainer.configure_seed()
    trainer.load_model()

    trainer.checkpoint_callback = CheckpointCallback(trainer.config, trainer)
    trainer.early_stop_callback = EarlyStoppingCallback(
        trainer.config, trainer)
    # Rebind to a new list instead of appending in place: `callbacks` may be
    # a class-level list shared by every trainer instance, and in-place
    # appends would accumulate checkpoint callbacks across calls.
    trainer.callbacks = [*trainer.callbacks, trainer.checkpoint_callback]
    trainer.on_init_start()
コード例 #7
0
ファイル: mmf_trainer.py プロジェクト: facebookresearch/mmf
    def configure_callbacks(self):
        """Create the built-in callbacks plus any user-declared ones.

        Sets ``checkpoint_callback``, ``early_stop_callback``,
        ``logistics_callback`` and ``lr_scheduler_callback`` on ``self``,
        rebuilds ``self.callbacks`` with the event-driven built-ins, then
        appends one instance per entry of ``config.training.callbacks``.
        Each entry's ``type`` is resolved through
        ``registry.get_callback_class`` and instantiated as
        ``cls(self.config, self, **params)``.
        """
        self.checkpoint_callback = CheckpointCallback(self.config, self)
        self.early_stop_callback = EarlyStoppingCallback(self.config, self)
        self.logistics_callback = LogisticsCallback(self.config, self)
        self.lr_scheduler_callback = LRSchedulerCallback(self.config, self)

        # Reset callbacks as they are class variables and would be shared between
        # multiple interactive shell calls to `run`
        self.callbacks = []
        # Add callbacks for execution during events
        self.callbacks.append(self.lr_scheduler_callback)
        # checkpoint_callback needs to be called after lr_scheduler_callback so that
        # lr_scheduler_callback._scheduler.step() happens before saving checkpoints
        # (otherwise the saved last_epoch in scheduler would be wrong)
        self.callbacks.append(self.checkpoint_callback)
        self.callbacks.append(self.logistics_callback)

        # Add all customized callbacks defined by users. A null `callbacks`
        # value is treated as "no callbacks", and `params` is optional (or may
        # be null) so callbacks without constructor kwargs need not declare it.
        user_callbacks = self.config.training.get("callbacks", None) or []
        for callback in user_callbacks:
            callback_type = callback.type
            callback_params = callback.get("params", None) or {}
            callback_cls = registry.get_callback_class(callback_type)
            self.callbacks.append(
                callback_cls(self.config, self, **callback_params))