Example #1
    def setUp(self):
        self.trainer = argparse.Namespace()
        self.config = OmegaConf.create(
            {
                "model": "simple",
                "model_config": {},
                "training": {
                    "lr_scheduler": True,
                    "lr_ratio": 0.1,
                    "lr_steps": [1, 2],
                    "use_warmup": False,
                },
            }
        )
        # Keep original copy for testing purposes
        self.trainer.config = deepcopy(self.config)
        registry.register("config", self.trainer.config)

        self.trainer.model = SimpleModule()
        self.trainer.val_dataset = NumbersDataset()

        self.trainer.optimizer = torch.optim.Adam(
            self.trainer.model.parameters(), lr=1e-01
        )
        self.trainer.lr_scheduler_callback = LRSchedulerCallback(
            self.config, self.trainer
        )
Example #2
    def setUp(self):
        self.trainer = argparse.Namespace()
        self.config = load_yaml(os.path.join("configs", "defaults.yaml"))
        self.config = OmegaConf.merge(
            self.config,
            {
                "model": "simple",
                "model_config": {},
                "training": {
                    "lr_scheduler": True,
                    "lr_ratio": 0.1,
                    "lr_steps": [1, 2],
                    "use_warmup": False,
                },
            },
        )
        # Keep original copy for testing purposes
        self.trainer.config = deepcopy(self.config)
        registry.register("config", self.trainer.config)

        self.trainer.model = SimpleModule()
        self.trainer.val_loader = torch.utils.data.DataLoader(
            NumbersDataset(), batch_size=self.config.training.batch_size)

        self.trainer.optimizer = torch.optim.Adam(
            self.trainer.model.parameters(), lr=1e-01)
        self.trainer.lr_scheduler_callback = LRSchedulerCallback(
            self.config, self.trainer)
Example #3
    def setUp(self):
        import argparse

        torch.manual_seed(1234)
        # An easy way to get an AttributeDict object
        self.trainer = argparse.Namespace()
        self.config = load_yaml(os.path.join("configs", "defaults.yaml"))
        self.config = OmegaConf.merge(
            self.config,
            {
                "model": "simple",
                "model_config": {},
                "checkpoint": {
                    "save_git_details": False,
                    "reset": {
                        "optimizer": False,
                        "counts": False,
                        "all": False,
                        "fp16_scaler": False,
                    },
                    "pretrained_state_mapping": {
                        "base_test": "base"
                    },
                    "max_to_keep": 5,
                },
                "config_override": None,
                "training": {
                    "checkpoint_interval": 1,
                    "early_stop": {
                        "criteria": "val/total_loss",
                        "minimize": True
                    },
                    "lr_scheduler": True,
                },
                "scheduler": {
                    "type": "multi_step",
                    "params": {
                        "use_warmup": False,
                        "lr_steps": [10, 20],
                        "lr_ratio": 0.1,
                        "warmup_factor": 1.0,
                    },
                },
            },
        )
        # Keep original copy for testing purposes
        self.trainer.config = deepcopy(self.config)

        self.trainer.model = SimpleModule()
        self.trainer.scaler = torch.cuda.amp.GradScaler()

        self.trainer.optimizer = torch.optim.Adam(
            self.trainer.model.parameters(), lr=1e-01)

        self.trainer.lr_scheduler_callback = LRSchedulerCallback(
            self.config, self.trainer)
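The scheduler block in Example #3 (type multi_step with lr_steps, lr_ratio, and use_warmup: False) is what LRSchedulerCallback turns into an actual scheduler. Below is a minimal plain-PyTorch sketch of the assumed behavior, treating the warmup-free multi_step schedule as equivalent to torch's MultiStepLR; this is an assumption for illustration, not MMF code.

# Assumption: with use_warmup=False, the multi_step schedule decays the LR by
# lr_ratio at each update count in lr_steps, like torch's MultiStepLR.
# Plain PyTorch only; LRSchedulerCallback wires up the real scheduler internally.
import torch

model = torch.nn.Linear(1, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-01)   # lr as in setUp
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=[10, 20], gamma=0.1                # lr_steps / lr_ratio
)

for _ in range(25):
    optimizer.zero_grad()
    model(torch.randn(4, 1)).sum().backward()
    optimizer.step()
    scheduler.step()          # what on_update_end triggers once per update

print(optimizer.param_groups[0]["lr"])  # 0.1 -> 0.01 -> 0.001 after updates 10 and 20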
Example #4
    def configure_callbacks(self):
        self.checkpoint_callback = CheckpointCallback(self.config, self)
        self.early_stop_callback = EarlyStoppingCallback(self.config, self)
        self.logistics_callback = LogisticsCallback(self.config, self)
        self.lr_scheduler_callback = LRSchedulerCallback(self.config, self)

        # Add callbacks for execution during events
        self.callbacks.append(self.checkpoint_callback)
        self.callbacks.append(self.logistics_callback)
        self.callbacks.append(self.lr_scheduler_callback)
Example #5
    def configure_callbacks(self):
        self.checkpoint_callback = CheckpointCallback(self.config, self)
        self.early_stop_callback = EarlyStoppingCallback(self.config, self)
        self.logistics_callback = LogisticsCallback(self.config, self)
        self.lr_scheduler_callback = LRSchedulerCallback(self.config, self)

        # Add callbacks for execution during events
        self.callbacks.append(self.lr_scheduler_callback)
        # checkpoint_callback needs to be called after lr_scheduler_callback so that
        # lr_scheduler_callback._scheduler.step() happens before saving checkpoints
        # (otherwise the saved last_epoch in scheduler would be wrong)
        self.callbacks.append(self.checkpoint_callback)
        self.callbacks.append(self.logistics_callback)
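The ordering comment in Examples #4 through #6 comes down to last_epoch bookkeeping: scheduler.step() is what advances it, so the checkpoint must be written afterwards. Here is a small plain-PyTorch sketch of that reasoning; the checkpoint dict and file name are illustrative, not CheckpointCallback's actual format.

# Plain-PyTorch sketch. Saving before scheduler.step() would persist a stale
# last_epoch; stepping first records the update that was just applied.
import torch

model = torch.nn.Linear(1, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[1], gamma=0.1)

optimizer.step()
scheduler.step()                                    # last_epoch: 0 -> 1
state = {
    "optimizer": optimizer.state_dict(),
    "lr_scheduler": scheduler.state_dict(),         # now records last_epoch == 1
}
torch.save(state, "checkpoint.pth")                 # illustrative path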
Example #6
    def configure_callbacks(self):
        self.checkpoint_callback = CheckpointCallback(self.config, self)
        self.early_stop_callback = EarlyStoppingCallback(self.config, self)
        self.logistics_callback = LogisticsCallback(self.config, self)
        self.lr_scheduler_callback = LRSchedulerCallback(self.config, self)

        # Reset callbacks as they are class variables and would be shared between
        # multiple interactive shell calls to `run`
        self.callbacks = []
        # Add callbacks for execution during events
        self.callbacks.append(self.lr_scheduler_callback)
        # checkpoint_callback needs to be called after lr_scheduler_callback so that
        # lr_scheduler_callback._scheduler.step() happens before saving checkpoints
        # (otherwise the saved last_epoch in scheduler would be wrong)
        self.callbacks.append(self.checkpoint_callback)
        self.callbacks.append(self.logistics_callback)
Example #7
    def setUp(self):
        self.trainer = argparse.Namespace()
        self.config = load_yaml(os.path.join("configs", "defaults.yaml"))
        self.config = OmegaConf.merge(
            self.config,
            {
                "model": "simple",
                "model_config": {},
                "training": {
                    "lr_scheduler": True,
                    "lr_ratio": 0.1,
                    "lr_steps": [1, 2],
                    "use_warmup": False,
                    "callbacks": [{
                        "type": "test_callback",
                        "params": {}
                    }],
                },
            },
        )
        # Keep original copy for testing purposes
        self.trainer.config = deepcopy(self.config)
        registry.register("config", self.trainer.config)

        model = SimpleModel(SimpleModel.Config())
        model.build()
        self.trainer.model = model
        self.trainer.val_loader = torch.utils.data.DataLoader(
            NumbersDataset(2), batch_size=self.config.training.batch_size)

        self.trainer.optimizer = torch.optim.Adam(
            self.trainer.model.parameters(), lr=1e-01)
        self.trainer.lr_scheduler_callback = LRSchedulerCallback(
            self.config, self.trainer)

        self.trainer.callbacks = []
        for callback in self.config.training.get("callbacks", []):
            callback_type = callback.type
            callback_param = callback.params
            callback_cls = registry.get_callback_class(callback_type)
            self.trainer.callbacks.append(
                callback_cls(self.trainer.config, self.trainer,
                             **callback_param))
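Example #7 resolves the "test_callback" entry through registry.get_callback_class, which only succeeds if a callback class was registered under that name first. Below is a minimal sketch of such a registration, assuming MMF's register_callback decorator and the Callback base class in mmf.trainers.callbacks.base; the class body is illustrative.

# Sketch only: the class name and hook body are hypothetical; the registered
# name "test_callback" matches the type listed in training.callbacks above.
from mmf.common.registry import registry
from mmf.trainers.callbacks.base import Callback


@registry.register_callback("test_callback")
class TestCallback(Callback):
    def __init__(self, config, trainer, **params):
        super().__init__(config, trainer)
        self.params = params

    def on_update_end(self, **kwargs):
        # Runs after each optimizer update, alongside LRSchedulerCallback.
        pass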
Example #8
    def test_lr_schedule_compared_to_mmf_is_same(self):
        config = get_config_with_defaults(
            {"training": {"max_updates": 8, "max_epochs": None, "lr_scheduler": True}}
        )

        mmf_trainer = get_mmf_trainer(config=config)
        mmf_trainer.lr_scheduler_callback = LRSchedulerCallback(config, mmf_trainer)
        mmf_trainer.callbacks.append(mmf_trainer.lr_scheduler_callback)
        mmf_trainer.on_update_end = mmf_trainer.lr_scheduler_callback.on_update_end
        mmf_trainer.evaluation_loop = MagicMock(return_value=(None, None))
        mmf_trainer.training_loop()

        with patch("mmf.trainers.lightning_trainer.get_mmf_env", return_value=""):
            config = self._get_config(max_steps=8, lr_scheduler=True)
            trainer = get_lightning_trainer(config=config)
            trainer.trainer.fit(trainer.model, trainer.data_module.train_loader)

            mmf_trainer.model.to(trainer.model.device)
            last_model_param1 = list(mmf_trainer.model.parameters())[-1]
            last_model_param2 = list(trainer.model.parameters())[-1]
            self.assertTrue(torch.allclose(last_model_param1, last_model_param2))
Example #9
    def configure_callbacks(self):
        self.checkpoint_callback = CheckpointCallback(self.config, self)
        self.early_stop_callback = EarlyStoppingCallback(self.config, self)
        self.logistics_callback = LogisticsCallback(self.config, self)
        self.lr_scheduler_callback = LRSchedulerCallback(self.config, self)

        # Reset callbacks as they are class variables and would be shared between
        # multiple interactive shell calls to `run`
        self.callbacks = []
        # Add callbacks for execution during events
        self.callbacks.append(self.lr_scheduler_callback)
        # checkpoint_callback needs to be called after lr_scheduler_callback so that
        # lr_scheduler_callback._scheduler.step() happens before saving checkpoints
        # (otherwise the saved last_epoch in scheduler would be wrong)
        self.callbacks.append(self.checkpoint_callback)
        self.callbacks.append(self.logistics_callback)
        # Add all custom callbacks defined in the user config
        for callback in self.config.training.get("callbacks", []):
            callback_type = callback.type
            callback_param = callback.params
            callback_cls = registry.get_callback_class(callback_type)
            self.callbacks.append(
                callback_cls(self.config, self, **callback_param))
Example #10
    def __init__(
        self,
        num_train_data,
        max_updates,
        max_epochs,
        config=None,
        optimizer=None,
        update_frequency=1,
        batch_size=1,
        batch_size_per_device=None,
        fp16=False,
        on_update_end_fn=None,
        scheduler_config=None,
        grad_clipping_config=None,
    ):
        if config is None:
            self.config = OmegaConf.create(
                {
                    "training": {
                        "detect_anomaly": False,
                        "evaluation_interval": 10000,
                        "update_frequency": update_frequency,
                        "fp16": fp16,
                        "batch_size": batch_size,
                        "batch_size_per_device": batch_size_per_device,
                    }
                }
            )
            self.training_config = self.config.training
        else:
            self.training_config = config.training
            self.config = config

        # Resolve the batch size with the custom config registered, then restore the original config
        original_config = registry.get("config")
        registry.register("config", self.config)
        batch_size = get_batch_size()
        registry.register("config", original_config)

        if max_updates is not None:
            self.training_config["max_updates"] = max_updates
        if max_epochs is not None:
            self.training_config["max_epochs"] = max_epochs

        self.model = SimpleModel({"in_dim": 1})
        self.model.build()
        if torch.cuda.is_available():
            self.model = self.model.cuda()
            self.device = "cuda"
        else:
            self.device = "cpu"
        self.distributed = False

        self.dataset_loader = MagicMock()
        self.dataset_loader.seed_sampler = MagicMock(return_value=None)
        self.dataset_loader.prepare_batch = lambda x: SampleList(x)
        if optimizer is None:
            self.optimizer = MagicMock()
            self.optimizer.step = MagicMock(return_value=None)
            self.optimizer.zero_grad = MagicMock(return_value=None)
        else:
            self.optimizer = optimizer

        if scheduler_config:
            config.training.lr_scheduler = True
            config.scheduler = scheduler_config
            self.lr_scheduler_callback = LRSchedulerCallback(config, self)
            self.callbacks.append(self.lr_scheduler_callback)
            on_update_end_fn = (
                on_update_end_fn
                if on_update_end_fn
                else self.lr_scheduler_callback.on_update_end
            )

        if grad_clipping_config:
            self.training_config.clip_gradients = True
            self.training_config.max_grad_l2_norm = grad_clipping_config[
                "max_grad_l2_norm"
            ]
            self.training_config.clip_norm_mode = grad_clipping_config["clip_norm_mode"]

        dataset = NumbersDataset(num_train_data)
        self.train_loader = torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=1,
            drop_last=False,
        )
        self.train_loader.current_dataset = dataset
        self.on_batch_start = MagicMock(return_value=None)
        self.on_update_start = MagicMock(return_value=None)
        self.logistics_callback = MagicMock(return_value=None)
        self.logistics_callback.log_interval = MagicMock(return_value=None)
        self.on_batch_end = MagicMock(return_value=None)
        self.on_update_end = (
            on_update_end_fn if on_update_end_fn else MagicMock(return_value=None)
        )
        self.meter = Meter()
        self.after_training_loop = MagicMock(return_value=None)
        self.on_validation_start = MagicMock(return_value=None)
        self.evaluation_loop = MagicMock(return_value=(None, None))
        self.scaler = torch.cuda.amp.GradScaler(enabled=False)
        self.val_loader = MagicMock(return_value=None)
        self.early_stop_callback = MagicMock(return_value=None)
        self.on_validation_end = MagicMock(return_value=None)
        self.metrics = MagicMock(return_value=None)
Example #11
    def __init__(
        self,
        num_train_data,
        max_updates,
        max_epochs,
        config=None,
        optimizer=None,
        update_frequency=1,
        batch_size=1,
        batch_size_per_device=None,
        fp16=False,
        on_update_end_fn=None,
        scheduler_config=None,
        grad_clipping_config=None,
        tensorboard=False,
    ):
        if config is None:
            self.config = OmegaConf.create({
                "training": {
                    "detect_anomaly": False,
                    "evaluation_interval": 10000,
                    "update_frequency": update_frequency,
                    "fp16": fp16,
                    "batch_size": batch_size,
                    "batch_size_per_device": batch_size_per_device,
                    "tensorboard": tensorboard,
                }
            })
            self.training_config = self.config.training
        else:
            config.training.batch_size = batch_size
            config.training.fp16 = fp16
            config.training.update_frequency = update_frequency
            config.training.tensorboard = tensorboard
            self.training_config = config.training
            self.config = config

        registry.register("config", self.config)

        if max_updates is not None:
            self.training_config["max_updates"] = max_updates
        if max_epochs is not None:
            self.training_config["max_epochs"] = max_epochs
        self.model = SimpleModel({"in_dim": 1})
        self.model.build()
        if torch.cuda.is_available():
            self.model = self.model.cuda()
            self.device = "cuda"
        else:
            self.device = "cpu"
        self.distributed = False

        if optimizer is None:
            self.optimizer = MagicMock()
            self.optimizer.step = MagicMock(return_value=None)
            self.optimizer.zero_grad = MagicMock(return_value=None)
        else:
            self.optimizer = optimizer

        if scheduler_config:
            config.training.lr_scheduler = True
            config.scheduler = scheduler_config
            self.lr_scheduler_callback = LRSchedulerCallback(config, self)
            self.callbacks.append(self.lr_scheduler_callback)
            on_update_end_fn = (on_update_end_fn if on_update_end_fn else
                                self.lr_scheduler_callback.on_update_end)
        if grad_clipping_config:
            self.training_config.clip_gradients = True
            self.training_config.max_grad_l2_norm = grad_clipping_config[
                "max_grad_l2_norm"]
            self.training_config.clip_norm_mode = grad_clipping_config[
                "clip_norm_mode"]

        self.on_batch_start = MagicMock(return_value=None)
        self.on_update_start = MagicMock(return_value=None)
        self.logistics_callback = MagicMock(return_value=None)
        self.logistics_callback.log_interval = MagicMock(return_value=None)
        self.on_batch_end = MagicMock(return_value=None)
        self.on_update_end = (on_update_end_fn if on_update_end_fn else
                              MagicMock(return_value=None))
        self.after_training_loop = MagicMock(return_value=None)
        self.on_validation_start = MagicMock(return_value=None)
        self.scaler = torch.cuda.amp.GradScaler(enabled=False)
        self.early_stop_callback = MagicMock(return_value=None)
        self.on_validation_end = MagicMock(return_value=None)
        self.metrics = MagicMock(return_value={})

        self.num_data = num_train_data