Example #1
    def _update_and_create_report(
        self,
        batch: Dict,
        batch_idx: int,
        step_output: Dict,
        pl_module: LightningModule,
        combined_report: Report = None,
        update_meter: Meter = None,
    ):
        report = Report(batch, step_output)

        if update_meter:
            update_meter.update_from_report(report)

        should_accumulate = not (
            batch_idx % self.trainer_config.accumulate_grad_batches == 0)

        final_report = report
        if should_accumulate and combined_report is not None:
            combined_report.accumulate_tensor_fields_and_loss(
                report, pl_module.metrics.required_params)
            combined_report.batch_size += report.batch_size
            final_report = combined_report

        return final_report
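A minimal sketch, not part of the example above, tracing how its `should_accumulate` flag behaves. With an assumed `accumulate_grad_batches` of 4, batch indices 1-3 are folded into the running combined report, while indices 0, 4, 8, ... start a fresh one:

# Illustrative only: trace the accumulation flag computed in the example above.
accumulate_grad_batches = 4  # stands in for trainer_config.accumulate_grad_batches
for batch_idx in range(8):
    should_accumulate = not (batch_idx % accumulate_grad_batches == 0)
    print(batch_idx, "accumulate" if should_accumulate else "fresh report")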
Example #2
def build_meters(
    run_type: str
) -> Tuple[Optional[Meter], Optional[Meter], Optional[Meter]]:
    train_meter, val_meter, test_meter = None, None, None
    if "train" in run_type:
        train_meter = Meter()
        # val_meter used for validation after training loop
        val_meter = Meter()
    elif "val" in run_type or "inference" in run_type:
        val_meter = Meter()

    if "test" in run_type:
        test_meter = Meter()

    return train_meter, val_meter, test_meter
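A short usage sketch, assuming the `build_meters` helper above is in scope (the run_type string is illustrative). A run_type containing "train" also creates the validation meter, while the test meter stays None unless "test" is requested:

# Hypothetical call of the helper defined above.
train_meter, val_meter, test_meter = build_meters("train_val")
assert train_meter is not None and val_meter is not None
assert test_meter is None  # "test" was not part of run_type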
Example #3
    def test_meter_update_from_report(self):
        meter = Meter()
        prepared_batch = SampleList(
            {"targets": torch.tensor([1, 2, 3, 4]), "dataset_type": "val"}
        )
        for idx in range(5):
            model_output = {
                "scores": torch.tensor([0, 1, 2, 3]),
                "losses": {"loss": float(idx)},
            }
            report = Report(prepared_batch, model_output)
            meter.update_from_report(report)

        self.assertEqual(meter.loss.global_avg, 2.0)
        self.assertEqual(meter.loss.avg, 2.0)
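The expected value in the assertions follows from the loop above: the five reports carry losses 0.0 through 4.0, so both the windowed and the global average are (0 + 1 + 2 + 3 + 4) / 5 = 2.0:

# Sanity check of the expected average (not part of the test itself).
assert sum(float(idx) for idx in range(5)) / 5 == 2.0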
Example #4
    def evaluate(self, loader, use_tqdm=False, single_batch=False):
        meter = Meter()

        with torch.no_grad():
            self.model.eval()
            disable_tqdm = not use_tqdm or not is_master()
            combined_report = None

            for batch in tqdm(loader, disable=disable_tqdm):
                report = self._forward_pass(batch)
                self._update_meter(report, meter)

                # accumulate necessary params for metric calculation
                if combined_report is None:
                    combined_report = report
                else:
                    combined_report.accumulate_tensor_fields(
                        report, self.metrics.required_params)
                    combined_report.batch_size += report.batch_size

                if single_batch is True:
                    break

            combined_report.metrics = self.metrics(combined_report,
                                                   combined_report)
            self._update_meter(combined_report, meter, eval_mode=True)

            self.model.train()

        return combined_report, meter
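Example #5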
    def evaluation_loop(
            self,
            loader,
            use_tqdm: bool = False,
            single_batch: bool = False) -> Tuple[Dict[str, Any], Type[Meter]]:
        meter = Meter()

        with torch.no_grad():
            self.model.eval()
            disable_tqdm = not use_tqdm or not is_master()
            combined_report = None

            for batch in tqdm.tqdm(loader, disable=disable_tqdm):
                report = self._forward(batch)
                self.update_meter(report, meter)

                # accumulate necessary params for metric calculation
                if combined_report is None:
                    combined_report = report
                else:
                    combined_report.accumulate_tensor_fields(
                        report, self.metrics.required_params)
                    combined_report.batch_size += report.batch_size

                if single_batch is True:
                    break

            combined_report.metrics = self.metrics(combined_report,
                                                   combined_report)
            self.update_meter(combined_report, meter, eval_mode=True)

            # enable train mode again
            self.model.train()

        return combined_report, meter
Example #6
class TrainerReportingMixin(ABC):
    meter: Type[Meter] = Meter()

    def update_meter(self,
                     report: Dict[str, Any],
                     meter: Type[Meter] = None,
                     eval_mode: bool = False) -> None:
        if meter is None:
            meter = self.meter

        if hasattr(report, "metrics"):
            metrics_dict = report.metrics
            reduced_metrics_dict = reduce_dict(metrics_dict)

        if not eval_mode:
            loss_dict = report.losses
            reduced_loss_dict = reduce_dict(loss_dict)

        with torch.no_grad():
            # Add metrics to meter only when mode is `eval`
            meter_update_dict = {}
            if not eval_mode:
                loss_key = report.dataset_type + "/total_loss"
                reduced_loss = sum(
                    [loss.mean() for loss in reduced_loss_dict.values()])
                if hasattr(reduced_loss, "item"):
                    reduced_loss = reduced_loss.item()

                registry.register(loss_key, reduced_loss)
                meter_update_dict.update({loss_key: reduced_loss})
                meter_update_dict.update(reduced_loss_dict)
            if hasattr(report, "metrics"):
                meter_update_dict.update(reduced_metrics_dict)
            meter.update(meter_update_dict, report.batch_size)
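A rough sketch, in a single-process setting where `reduce_dict` is effectively a no-op, of the dictionary that reaches `meter.update` for a training-mode report whose only loss is 0.5. Keys and values are illustrative:

import torch

# Illustrative only: compose the update dict the way update_meter does above.
reduced_loss_dict = {"loss": torch.tensor(0.5)}
reduced_loss = sum(loss.mean() for loss in reduced_loss_dict.values()).item()

meter_update_dict = {"train/total_loss": reduced_loss}
meter_update_dict.update(reduced_loss_dict)
# meter_update_dict == {"train/total_loss": 0.5, "loss": tensor(0.5000)}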
Example #7
    def evaluation_loop(
            self,
            dataset_type: str,
            use_tqdm: bool = False,
            single_batch: bool = False) -> Tuple[Dict[str, Any], Type[Meter]]:
        meter = Meter()
        reporter = self.dataset_loader.get_test_reporter(dataset_type)

        with torch.no_grad():
            self.model.eval()
            disable_tqdm = not use_tqdm or not is_master()

            while reporter.next_dataset(flush_report=False):
                dataloader = reporter.get_dataloader()

                combined_report = None
                for batch in tqdm.tqdm(dataloader, disable=disable_tqdm):
                    prepared_batch = reporter.prepare_batch(batch)
                    prepared_batch = to_device(prepared_batch, self.device)
                    model_output = self.model(prepared_batch)
                    report = Report(prepared_batch, model_output)

                    self.update_meter(report, meter)

                    # accumulate necessary params for metric calculation
                    if combined_report is None:
                        # make a copy of report since `reporter.add_to_report` will
                        # change some of the report keys later
                        combined_report = Report(report)
                    else:
                        combined_report.accumulate_tensor_fields_and_loss(
                            report, self.metrics.required_params)
                        combined_report.batch_size += report.batch_size

                    # Each node generates a separate copy of predict JSON from the report,
                    # which will be used to evaluate dataset-level metrics
                    # (such as mAP in object detection or CIDEr in image captioning)
                    # Since `reporter.add_to_report` changes report keys (e.g. scores),
                    # do this after `combined_report.accumulate_tensor_fields_and_loss`
                    if "__prediction_report__" in self.metrics.required_params:
                        reporter.add_to_report(report,
                                               self.model,
                                               execute_on_master_only=False)

                    if single_batch is True:
                        break

                reporter.postprocess_dataset_report()
                # add prediction_report, which is used for set-level metrics
                combined_report.prediction_report = reporter.report

                combined_report.metrics = self.metrics(combined_report,
                                                       combined_report)
                self.update_meter(combined_report, meter, eval_mode=True)

            # enable train mode again
            self.model.train()

        return combined_report, meter
Example #8
    def load_extras(self):
        self.writer.write("Torch version is: " + torch.__version__)
        self.checkpoint = Checkpoint(self)
        self.meter = Meter()

        self.training_config = self.config.training

        early_stop_criteria = self.training_config.early_stop.criteria
        early_stop_minimize = self.training_config.early_stop.minimize
        early_stop_enabled = self.training_config.early_stop.enabled
        early_stop_patience = self.training_config.early_stop.patience

        self.log_interval = self.training_config.log_interval
        self.evaluation_interval = self.training_config.evaluation_interval
        self.checkpoint_interval = self.training_config.checkpoint_interval
        self.max_updates = self.training_config.max_updates
        self.should_clip_gradients = self.training_config.clip_gradients
        self.max_epochs = self.training_config.max_epochs

        self.early_stopping = EarlyStopping(
            self.model,
            self.checkpoint,
            early_stop_criteria,
            patience=early_stop_patience,
            minimize=early_stop_minimize,
            should_stop=early_stop_enabled,
        )
        self.current_epoch = 0
        self.current_iteration = 0
        self.num_updates = 0

        self.checkpoint.load_state_dict()

        self.not_debug = self.training_config.logger_level != "debug"

        self.lr_scheduler = None

        if self.training_config.lr_scheduler is True:
            self.lr_scheduler = build_scheduler(self.optimizer, self.config)

        self.tb_writer = None

        if self.training_config.tensorboard:
            log_dir = self.writer.log_dir
            env_tb_logdir = get_mmf_env(key="tensorboard_logdir")
            if env_tb_logdir:
                log_dir = env_tb_logdir

            self.tb_writer = TensorboardLogger(log_dir, self.current_iteration)
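Example #9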
class TrainerReportingMixin(ABC):
    meter: Type[Meter] = Meter()

    def update_meter(self,
                     report: Dict[str, Any],
                     meter: Type[Meter] = None,
                     eval_mode: bool = False) -> None:
        if meter is None:
            meter = self.meter

        if hasattr(report, "metrics"):
            metrics_dict = report.metrics
            reduced_metrics_dict = reduce_dict(metrics_dict)

        if not eval_mode:
            loss_dict = report.losses
            reduced_loss_dict = reduce_dict(loss_dict)

        with torch.no_grad():
            # Add metrics to meter only when mode is `eval`
            meter_update_dict = {}
            if not eval_mode:
                total_loss_key = report.dataset_type + "/total_loss"
                meter_update_dict, total_loss = self.update_dict(
                    meter_update_dict, reduced_loss_dict)
                registry.register(total_loss_key, total_loss)
                meter_update_dict.update({total_loss_key: total_loss})

            if hasattr(report, "metrics"):
                meter_update_dict, _ = self.update_dict(
                    meter_update_dict, reduced_metrics_dict)

            meter.update(meter_update_dict, report.batch_size)

    def update_dict(self, meter_update_dict, values_dict):
        total_val = 0
        for key, val in values_dict.items():
            if torch.is_tensor(val):
                if val.dim() == 1:
                    val = val.mean()

            if hasattr(val, "item"):
                val = val.item()

            meter_update_dict.update({key: val})
            total_val += val

        return meter_update_dict, total_val
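A small illustration, assuming the mixin above is importable and instantiating it purely for demonstration, of what `update_dict` does: 1-d tensors are reduced to their mean, tensor values are converted to Python floats via `.item()`, and the running total of all values is returned alongside the updated dict:

import torch

# Illustrative only: flatten two loss tensors and collect their running sum.
updated, total = TrainerReportingMixin().update_dict(
    {}, {"loss_a": torch.tensor([0.25, 0.75]), "loss_b": torch.tensor(0.5)})
# updated == {"loss_a": 0.5, "loss_b": 0.5}, total == 1.0
Example #10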
    def setUp(self):
        self.tmpdir = tempfile.mkdtemp()
        self.trainer = argparse.Namespace()
        self.config = load_yaml(os.path.join("configs", "defaults.yaml"))
        self.config = OmegaConf.merge(
            self.config,
            {
                "model": "simple",
                "model_config": {},
                "training": {
                    "checkpoint_interval": 1,
                    "evaluation_interval": 10,
                    "early_stop": {
                        "criteria": "val/total_loss"
                    },
                    "batch_size": 16,
                    "log_interval": 10,
                    "logger_level": "info",
                },
                "env": {
                    "save_dir": self.tmpdir
                },
            },
        )
        # Keep original copy for testing purposes
        self.trainer.config = deepcopy(self.config)
        registry.register("config", self.trainer.config)
        setup_logger()
        self.report = Mock(spec=Report)
        self.report.dataset_name = "abcd"
        self.report.dataset_type = "test"

        self.trainer.model = SimpleModule()
        self.trainer.val_loader = torch.utils.data.DataLoader(
            NumbersDataset(), batch_size=self.config.training.batch_size)

        self.trainer.optimizer = torch.optim.Adam(
            self.trainer.model.parameters(), lr=1e-01)
        self.trainer.device = "cpu"
        self.trainer.num_updates = 0
        self.trainer.current_iteration = 0
        self.trainer.current_epoch = 0
        self.trainer.max_updates = 0
        self.trainer.meter = Meter()
        self.cb = LogisticsCallback(self.config, self.trainer)
Example #11
    def evaluation_loop(
            self,
            dataset_type: str,
            use_tqdm: bool = False,
            single_batch: bool = False) -> Tuple[Dict[str, Any], Type[Meter]]:
        meter = Meter()
        reporter = self.dataset_loader.get_test_reporter(dataset_type)
        use_cpu = self.config.evaluation.get("use_cpu", False)
        loaded_batches = 0
        skipped_batches = 0

        with torch.no_grad():
            self.model.eval()
            disable_tqdm = not use_tqdm or not is_master()
            while reporter.next_dataset(flush_report=False):
                dataloader = reporter.get_dataloader()
                combined_report = None

                if self._can_use_tqdm(dataloader):
                    dataloader = tqdm.tqdm(dataloader, disable=disable_tqdm)
                for batch in dataloader:
                    # Do not timeout quickly on first batch, as workers might start at
                    # very different times.
                    with CompleteInTimeOrDie(600 if loaded_batches else 3600 *
                                             24):
                        loaded_batches += 1
                        prepared_batch = reporter.prepare_batch(batch)
                        prepared_batch = to_device(prepared_batch, self.device)
                        if not validate_batch_sizes(
                                prepared_batch.get_batch_size()):
                            logger.info(
                                "Skip batch due to uneven batch sizes.")
                            skipped_batches += 1
                            continue
                        model_output = self.model(prepared_batch)
                        report = Report(prepared_batch, model_output)
                        report = report.detach()

                        meter.update_from_report(report)

                        moved_report = report
                        # Move to CPU for metrics calculation later if needed
                        # Explicitly use `non_blocking=False` as this can cause
                        # race conditions in next accumulate
                        if use_cpu:
                            moved_report = report.copy().to("cpu",
                                                            non_blocking=False)

                        # accumulate necessary params for metric calculation
                        if combined_report is None:
                            # make a copy of report since `reporter.add_to_report` will
                            # change some of the report keys later
                            combined_report = moved_report.copy()
                        else:
                            combined_report.accumulate_tensor_fields_and_loss(
                                moved_report, self.metrics.required_params)
                            combined_report.batch_size += moved_report.batch_size

                        # Each node generates a separate copy of predict JSON from the
                        # report, which will be used to evaluate dataset-level metrics
                        # (such as mAP in object detection or CIDEr in image captioning)
                        # Since `reporter.add_to_report` changes report keys
                        # (e.g. scores), do this after
                        # `combined_report.accumulate_tensor_fields_and_loss`
                        if "__prediction_report__" in self.metrics.required_params:
                            # Still need to use original report here on GPU/TPU since
                            # it will be gathered
                            reporter.add_to_report(
                                report,
                                self.model,
                                execute_on_master_only=False)

                        if single_batch is True:
                            break

                logger.info(f"Finished training. Loaded {loaded_batches}")
                logger.info(f" -- skipped {skipped_batches} batches.")

                reporter.postprocess_dataset_report()
                assert (combined_report is not None
                        ), "Please check if your validation set is empty!"
                # add prediction_report, which is used for set-level metrics
                combined_report.prediction_report = reporter.report

                combined_report.metrics = self.metrics(combined_report,
                                                       combined_report)

                # Since update_meter will reduce the metrics over GPUs, we need to
                # move them back to GPU but we will only move metrics and losses
                # which are needed by update_meter to avoid OOM
                # Furthermore, do it in a non_blocking way to avoid any issues
                # in device to host or host to device transfer
                if use_cpu:
                    combined_report = combined_report.to(
                        self.device,
                        fields=["metrics", "losses"],
                        non_blocking=False)

                meter.update_from_report(combined_report,
                                         should_update_loss=False)

            # enable train mode again
            self.model.train()

        return combined_report, meter
Example #12
class TrainerTrainingLoopMixin(ABC):
    current_epoch: int = 0
    current_iteration: int = 0
    num_updates: int = 0
    meter: Meter = Meter()

    def training_loop(self) -> None:
        self.max_updates = self._calculate_max_updates()
        torch.autograd.set_detect_anomaly(self.training_config.detect_anomaly)

        logger.info("Starting training...")
        self.model.train()
        self.run_training_epoch()
        self.after_training_loop()

    def after_training_loop(self) -> None:
        logger.info("Stepping into final validation check")
        # Only do when run_type has train as it shouldn't happen on validation and
        # inference runs. Inference will take care of this anyway. Also, don't run
        # if current iteration is divisible by snapshot interval as it will just
        # be a repeat
        if ("train" in self.run_type and "val" in self.run_type
                and self.num_updates % self.training_config.evaluation_interval
                != 0):
            # Create a new meter for this case
            report, meter = self.evaluation_loop("val")

            # Validation end callbacks
            self.on_validation_end(report=report, meter=meter)

    def run_training_epoch(self) -> None:
        should_break = False
        while self.num_updates < self.max_updates and not should_break:
            self.current_epoch += 1
            registry.register("current_epoch", self.current_epoch)

            # Seed the sampler in case if it is distributed
            self.dataset_loader.seed_sampler("train", self.current_epoch)

            # For iterable datasets we cannot determine length of dataset properly.
            # For those cases we set num_remaining_batches to be the (number of
            # updates remaining x update_frequency)
            num_remaining_batches = (((self.max_updates - self.num_updates) *
                                      self.training_config.update_frequency)
                                     if isinstance(
                                         self.train_loader.current_dataset,
                                         torch.utils.data.IterableDataset) else
                                     len(self.train_loader))

            should_start_update = True
            for idx, batch in enumerate(self.train_loader):
                if should_start_update:
                    combined_report = None
                    self._start_update()
                    num_batches_for_this_update = min(
                        self.training_config.update_frequency,
                        num_remaining_batches)
                    should_start_update = False

                self.current_iteration += 1
                # batch execution starts here
                self.on_batch_start()
                self.profile("Batch load time")

                report = self.run_training_batch(batch,
                                                 num_batches_for_this_update)
                report = report.detach()

                # accumulate necessary params (including loss) for metric calculation
                if combined_report is None:
                    combined_report = report
                else:
                    combined_report.accumulate_tensor_fields_and_loss(
                        report, self.metrics.required_params)
                    combined_report.batch_size += report.batch_size

                # batch execution ends here
                self.on_batch_end(report=combined_report, meter=self.meter)

                # check if an update has finished or if it is the last; if not, continue
                if ((idx + 1) % self.training_config.update_frequency and
                        num_remaining_batches != num_batches_for_this_update):
                    continue

                self._finish_update()
                should_start_update = True

                should_log = False
                if self.num_updates % self.logistics_callback.log_interval == 0:
                    should_log = True
                    # Calculate metrics every log interval for debugging
                    if self.training_config.evaluate_metrics:
                        combined_report.metrics = self.metrics(
                            combined_report, combined_report)
                    self.meter.update_from_report(combined_report)

                self.on_update_end(report=combined_report,
                                   meter=self.meter,
                                   should_log=should_log)

                num_remaining_batches -= num_batches_for_this_update

                # Check if training should be stopped
                should_break = False

                if self.num_updates % self.training_config.evaluation_interval == 0:
                    # Validation begin callbacks
                    self.on_validation_start()

                    logger.info(
                        "Evaluation time. Running on full validation set...")
                    # Validation and Early stopping
                    # Create a new meter for this case
                    report, meter = self.evaluation_loop("val")

                    # Validation end callbacks
                    stop = self.early_stop_callback.on_validation_end(
                        report=report, meter=meter)
                    self.on_validation_end(report=report, meter=meter)

                    gc.collect()

                    if "cuda" in str(self.device):
                        torch.cuda.empty_cache()

                    if stop is True:
                        logger.info("Early stopping activated")
                        should_break = True

                if self.num_updates >= self.max_updates:
                    should_break = True

                if should_break:
                    break

    def run_training_batch(self, batch: Dict[str, Tensor],
                           loss_divisor: int) -> Report:
        report = self._forward(batch)
        if self.training_config.exit_on_nan_losses:
            self._check_nan_losses(report)
        loss = extract_loss(report, loss_divisor)
        self._backward(loss)
        return report

    def _check_nan_losses(self, report):
        # skip this check in XLA mode as calling .item() in forward pass
        # greatly slows down the training
        if not is_xla():
            # check whether NaN has occurred in the losses, and exit the training
            # when NaN happens
            loss_dict = report.losses
            nan_loss_keys = []
            for key, value in loss_dict.items():
                if torch.any(torch.isnan(value)).item():
                    nan_loss_keys.append(key)
            if len(nan_loss_keys) > 0:
                keys_str = ", ".join(nan_loss_keys)
                error_msg = (
                    f"NaN occurred in the following loss(es): {keys_str}; "
                    f"exiting the training")
                logger.info(error_msg)
                raise RuntimeError(error_msg)

    def _forward(self, batch: Dict[str, Tensor]) -> Dict[str, Any]:
        # Move the sample list to device if it isn't there already.
        prepared_batch = to_device(batch, self.device)
        self.profile("Batch prepare time")
        # Arguments should be a dict at this point

        with torch.cuda.amp.autocast(enabled=self.training_config.fp16):
            model_output = self.model(prepared_batch)
            report = Report(prepared_batch, model_output)

        self.profile("Forward time")
        return report

    def _start_update(self):
        logger.debug(self.num_updates + 1)
        self.on_update_start()
        self.optimizer.zero_grad()

    def _backward(self, loss: Tensor) -> None:
        self.scaler.scale(loss).backward()
        self.profile("Backward time")

    def _finish_update(self):
        if self.training_config.clip_gradients:
            clip_gradients(
                self.model,
                self.optimizer,
                self.num_updates,
                self.logistics_callback.tb_writer,
                self.config,
                scale=self.scaler.get_scale(),
            )
        if is_xla():
            import torch_xla.core.xla_model as xm

            # Assumes no model parallel
            xm.reduce_gradients(self.optimizer)

        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.num_updates += 1
        self.profile("Finished update")

    def _calculate_max_updates(self):
        config_max_updates = self.training_config.max_updates
        config_max_epochs = self.training_config.max_epochs
        max_updates, _ = get_max_updates(
            config_max_updates,
            config_max_epochs,
            self.train_loader,
            self.training_config.update_frequency,
        )

        return max_updates
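Example #13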
    def __init__(
        self,
        num_train_data,
        max_updates,
        max_epochs,
        config=None,
        optimizer=None,
        update_frequency=1,
        batch_size=1,
        batch_size_per_device=None,
        fp16=False,
        on_update_end_fn=None,
        scheduler_config=None,
        grad_clipping_config=None,
    ):
        if config is None:
            self.config = OmegaConf.create(
                {
                    "training": {
                        "detect_anomaly": False,
                        "evaluation_interval": 10000,
                        "update_frequency": update_frequency,
                        "fp16": fp16,
                        "batch_size": batch_size,
                        "batch_size_per_device": batch_size_per_device,
                    }
                }
            )
            self.training_config = self.config.training
        else:
            self.training_config = config.training
            self.config = config

        # Load batch size with custom config and cleanup
        original_config = registry.get("config")
        registry.register("config", self.config)
        batch_size = get_batch_size()
        registry.register("config", original_config)

        if max_updates is not None:
            self.training_config["max_updates"] = max_updates
        if max_epochs is not None:
            self.training_config["max_epochs"] = max_epochs

        self.model = SimpleModel({"in_dim": 1})
        self.model.build()
        if torch.cuda.is_available():
            self.model = self.model.cuda()
            self.device = "cuda"
        else:
            self.device = "cpu"
        self.distributed = False

        self.dataset_loader = MagicMock()
        self.dataset_loader.seed_sampler = MagicMock(return_value=None)
        self.dataset_loader.prepare_batch = lambda x: SampleList(x)
        if optimizer is None:
            self.optimizer = MagicMock()
            self.optimizer.step = MagicMock(return_value=None)
            self.optimizer.zero_grad = MagicMock(return_value=None)
        else:
            self.optimizer = optimizer

        if scheduler_config:
            config.training.lr_scheduler = True
            config.scheduler = scheduler_config
            self.lr_scheduler_callback = LRSchedulerCallback(config, self)
            self.callbacks.append(self.lr_scheduler_callback)
            on_update_end_fn = (
                on_update_end_fn
                if on_update_end_fn
                else self.lr_scheduler_callback.on_update_end
            )

        if grad_clipping_config:
            self.training_config.clip_gradients = True
            self.training_config.max_grad_l2_norm = grad_clipping_config[
                "max_grad_l2_norm"
            ]
            self.training_config.clip_norm_mode = grad_clipping_config["clip_norm_mode"]

        dataset = NumbersDataset(num_train_data)
        self.train_loader = torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=1,
            drop_last=False,
        )
        self.train_loader.current_dataset = dataset
        self.on_batch_start = MagicMock(return_value=None)
        self.on_update_start = MagicMock(return_value=None)
        self.logistics_callback = MagicMock(return_value=None)
        self.logistics_callback.log_interval = MagicMock(return_value=None)
        self.on_batch_end = MagicMock(return_value=None)
        self.on_update_end = (
            on_update_end_fn if on_update_end_fn else MagicMock(return_value=None)
        )
        self.meter = Meter()
        self.after_training_loop = MagicMock(return_value=None)
        self.on_validation_start = MagicMock(return_value=None)
        self.evaluation_loop = MagicMock(return_value=(None, None))
        self.scaler = torch.cuda.amp.GradScaler(enabled=False)
        self.val_loader = MagicMock(return_value=None)
        self.early_stop_callback = MagicMock(return_value=None)
        self.on_validation_end = MagicMock(return_value=None)
        self.metrics = MagicMock(return_value=None)
Example #14
    def __init__(
        self,
        num_train_data,
        max_updates,
        max_epochs,
        config=None,
        optimizer=None,
        update_frequency=1,
        batch_size=1,
        fp16=False,
        on_update_end_fn=None,
    ):
        if config is None:
            self.training_config = OmegaConf.create({
                "detect_anomaly": False,
                "evaluation_interval": 10000,
                "update_frequency": update_frequency,
                "fp16": fp16,
                "batch_size": batch_size,
            })
        else:
            self.training_config = config.training

        if max_updates is not None:
            self.training_config["max_updates"] = max_updates
        if max_epochs is not None:
            self.training_config["max_epochs"] = max_epochs

        self.model = SimpleModel(1)
        if torch.cuda.is_available():
            self.model = self.model.cuda()
            self.device = "cuda"
        else:
            self.device = "cpu"

        self.dataset_loader = MagicMock()
        self.dataset_loader.seed_sampler = MagicMock(return_value=None)
        self.dataset_loader.prepare_batch = lambda x: SampleList(x)
        if optimizer is None:
            self.optimizer = MagicMock()
            self.optimizer.step = MagicMock(return_value=None)
            self.optimizer.zero_grad = MagicMock(return_value=None)
        else:
            self.optimizer = optimizer

        dataset = NumbersDataset(num_train_data)
        self.train_loader = torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=1,
            drop_last=False,
        )
        self.train_loader.current_dataset = dataset
        self.on_batch_start = MagicMock(return_value=None)
        self.on_update_start = MagicMock(return_value=None)
        self.logistics_callback = MagicMock(return_value=None)
        self.logistics_callback.log_interval = MagicMock(return_value=None)
        self.on_batch_end = MagicMock(return_value=None)
        self.on_update_end = (on_update_end_fn if on_update_end_fn else
                              MagicMock(return_value=None))
        self.meter = Meter()
        self.after_training_loop = MagicMock(return_value=None)
        self.on_validation_start = MagicMock(return_value=None)
        self.evaluation_loop = MagicMock(return_value=(None, None))
        self.scaler = torch.cuda.amp.GradScaler(enabled=False)
        self.val_loader = MagicMock(return_value=None)
        self.early_stop_callback = MagicMock(return_value=None)
        self.on_validation_end = MagicMock(return_value=None)
        self.metrics = MagicMock(return_value=None)