Example #1
import time
import unittest

from mmf.utils.timer import Timer  # assumed import path


class TestTimer(unittest.TestCase):
    def test_reset(self):
        timer = Timer()
        time.sleep(2)
        timer.reset()
        expected = 0

        self.assertEqual(int(timer.get_current().split("ms")[0]), expected)
Example #2
import time
import unittest

from mmf.utils.timer import Timer  # assumed import path


class TestTimer(unittest.TestCase):
    def test_reset(self):
        timer = Timer()
        time.sleep(2)
        timer.reset()
        expected = "000ms"

        self.assertEqual(timer.get_current(), expected)
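
Both tests exercise the same contract: after reset(), get_current() must report zero elapsed milliseconds, zero-padded in the second variant. The Timer implementation itself is not shown here; a minimal sketch consistent with both assertions (the millisecond formatting and the start attribute, which _calculate_time_left below also relies on, are assumptions) could look like:

import time


class Timer:
    """Minimal wall-clock timer sketch; start time is stored in milliseconds."""

    def __init__(self):
        self.start = time.time() * 1000

    def reset(self):
        # Restart measurement from the current moment
        self.start = time.time() * 1000

    def get_current(self):
        # Elapsed milliseconds since start/reset, zero-padded, e.g. "000ms"
        elapsed = int(time.time() * 1000 - self.start)
        return "{:03d}ms".format(elapsed)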
Example #3
class LogisticsCallback(Callback):
    """Callback for handling train/validation logistics, report summarization,
    logging etc.
    """

    def __init__(self, config, trainer):
        """
        Args:
            config (mmf_typings.DictConfig): Config for the callback
            trainer (Type[BaseTrainer]): Trainer object
        """
        super().__init__(config, trainer)

        self.total_timer = Timer()
        self.log_interval = self.training_config.log_interval
        self.evaluation_interval = self.training_config.evaluation_interval
        self.checkpoint_interval = self.training_config.checkpoint_interval

        # Total iterations for snapshot
        self.snapshot_iterations = len(self.trainer.val_dataset)
        self.snapshot_iterations //= self.training_config.batch_size

        self.tb_writer = None

        if self.training_config.tensorboard:
            log_dir = setup_output_folder(folder_only=True)
            env_tb_logdir = get_mmf_env(key="tensorboard_logdir")
            if env_tb_logdir:
                log_dir = env_tb_logdir

            self.tb_writer = TensorboardLogger(log_dir, self.trainer.current_iteration)

    def on_train_start(self):
        self.train_timer = Timer()
        self.snapshot_timer = Timer()

    def on_update_end(self, **kwargs):
        if not kwargs["should_log"]:
            return
        extra = {}
        if "cuda" in str(self.trainer.device):
            extra["max mem"] = torch.cuda.max_memory_allocated() / 1024
            extra["max mem"] //= 1024

        if self.training_config.experiment_name:
            extra["experiment"] = self.training_config.experiment_name

        extra.update(
            {
                "epoch": self.trainer.current_epoch,
                "num_updates": self.trainer.num_updates,
                "iterations": self.trainer.current_iteration,
                "max_updates": self.trainer.max_updates,
                "lr": "{:.5f}".format(
                    self.trainer.optimizer.param_groups[0]["lr"]
                ).rstrip("0"),
                "ups": "{:.2f}".format(
                    self.log_interval / self.train_timer.unix_time_since_start()
                ),
                "time": self.train_timer.get_time_since_start(),
                "time_since_start": self.total_timer.get_time_since_start(),
                "eta": self._calculate_time_left(),
            }
        )
        self.train_timer.reset()
        self._summarize_report(kwargs["meter"], extra=extra)

    def on_validation_start(self, **kwargs):
        self.snapshot_timer.reset()

    def on_validation_end(self, **kwargs):
        extra = {
            "num_updates": self.trainer.num_updates,
            "epoch": self.trainer.current_epoch,
            "iterations": self.trainer.current_iteration,
            "max_updates": self.trainer.max_updates,
            "val_time": self.snapshot_timer.get_time_since_start(),
        }
        extra.update(self.trainer.early_stop_callback.early_stopping.get_info())
        self.train_timer.reset()
        self._summarize_report(kwargs["meter"], extra=extra)

    def on_test_end(self, **kwargs):
        prefix = "{}: full {}".format(
            kwargs["report"].dataset_name, kwargs["report"].dataset_type
        )
        # NOTE: `prefix` only acts as a truthy `should_print` flag here; the
        # string itself is never logged
        self._summarize_report(kwargs["meter"], prefix)
        logger.info(f"Finished run in {self.total_timer.get_time_since_start()}")

    def _summarize_report(self, meter, should_print=True, extra=None):
        if extra is None:
            extra = {}
        if not is_master() and not is_xla():
            return

        if self.training_config.tensorboard:
            scalar_dict = meter.get_scalar_dict()
            self.tb_writer.add_scalars(scalar_dict, self.trainer.current_iteration)

        if not should_print:
            return
        log_dict = {}
        if hasattr(self.trainer, "num_updates") and hasattr(
            self.trainer, "max_updates"
        ):
            log_dict.update(
                {"progress": f"{self.trainer.num_updates}/{self.trainer.max_updates}"}
            )
        log_dict.update(meter.get_log_dict())
        log_dict.update(extra)

        log_progress(log_dict)

    def _calculate_time_left(self):
        time_taken_for_log = time.time() * 1000 - self.train_timer.start
        iterations_left = self.trainer.max_updates - self.trainer.num_updates
        num_logs_left = iterations_left / self.log_interval
        time_left = num_logs_left * time_taken_for_log

        snapshot_iteration = self.snapshot_iterations / self.log_interval
        snapshot_iteration *= iterations_left / self.evaluation_interval
        time_left += snapshot_iteration * time_taken_for_log

        return self.train_timer.get_time_hhmmss(gap=time_left)
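
_calculate_time_left extrapolates the duration of the most recent log interval over the remaining updates, then adds the expected cost of the validation snapshots still to come. For example, with 900 updates left, a log_interval of 100, and 60 s spent on the last interval, the training term alone is 9 * 60 s = 9 minutes. A standalone sketch of the same arithmetic (the function and parameter names are assumptions; the logic mirrors the method above):

import time


def estimate_eta_ms(train_timer_start_ms, max_updates, num_updates,
                    snapshot_iterations, log_interval, evaluation_interval):
    """Standalone sketch of _calculate_time_left; returns milliseconds."""
    time_taken_for_log = time.time() * 1000 - train_timer_start_ms
    iterations_left = max_updates - num_updates

    # Training term: remaining updates, measured in log intervals
    num_logs_left = iterations_left / log_interval
    time_left = num_logs_left * time_taken_for_log

    # Validation term: each evaluation runs snapshot_iterations batches,
    # costed in the same per-log-interval unit
    snapshot_term = snapshot_iterations / log_interval
    snapshot_term *= iterations_left / evaluation_interval
    time_left += snapshot_term * time_taken_for_log

    return time_left  # the callback formats this via get_time_hhmmss(gap=...)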
Example #4
class BaseTrainer:
    def __init__(self, configuration):
        self.configuration = configuration
        self.config = self.configuration.get_config()
        self.profiler = Timer()
        self.total_timer = Timer()
        # get_config() above already assumes configuration is not None,
        # so this guard is effectively redundant
        if self.configuration is not None:
            self.args = self.configuration.args

    def load(self):
        self._set_device()

        self.run_type = self.config.get("run_type", "train")
        self.dataset_loader = DatasetLoader(self.config)
        self._datasets = self.config.datasets

        # Check if writer is already defined, else init it
        writer = registry.get("writer", no_warning=True)
        if writer:
            self.writer = writer
        else:
            self.writer = Logger(self.config)
            registry.register("writer", self.writer)

        self.configuration.pretty_print()

        self.config_based_setup()

        self.load_datasets()
        self.load_model_and_optimizer()
        self.load_metrics()

    def _set_device(self):
        self.local_rank = self.config.device_id
        self.device = self.local_rank
        self.distributed = False

        # Will be updated later based on distributed setup
        registry.register("global_device", self.device)

        if self.config.distributed.init_method is not None:
            self.distributed = True
            self.device = torch.device("cuda", self.local_rank)
        elif torch.cuda.is_available():
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")

        registry.register("current_device", self.device)

    def load_datasets(self):
        self.writer.write("Loading datasets", "info")
        self.dataset_loader.load_datasets()

        self.train_dataset = self.dataset_loader.train_dataset
        self.val_dataset = self.dataset_loader.val_dataset

        # Total iterations for snapshot
        self.snapshot_iterations = len(self.val_dataset)
        self.snapshot_iterations //= self.config.training.batch_size

        self.test_dataset = self.dataset_loader.test_dataset

        self.train_loader = self.dataset_loader.train_loader
        self.val_loader = self.dataset_loader.val_loader
        self.test_loader = self.dataset_loader.test_loader

    def load_metrics(self):
        metrics = self.config.evaluation.get("metrics", [])
        self.metrics = Metrics(metrics)
        self.metrics_params = self.metrics.required_params

    def load_model_and_optimizer(self):
        attributes = self.config.model_config[self.config.model]
        # Easy way to point to the config of another model
        if isinstance(attributes, str):
            attributes = self.config.model_config[attributes]

        with omegaconf.open_dict(attributes):
            attributes.model = self.config.model

        self.model = build_model(attributes)

        if "cuda" in str(self.device):
            device_info = "CUDA Device {} is: {}".format(
                self.config.distributed.rank,
                torch.cuda.get_device_name(self.local_rank),
            )
            registry.register("global_device", self.config.distributed.rank)
            self.writer.write(device_info, log_all=True)

        self.model = self.model.to(self.device)
        self.optimizer = build_optimizer(self.model, self.config)

        registry.register("data_parallel", False)
        registry.register("distributed", False)

        self.load_extras()
        self.parallelize_model()

    def parallelize_model(self):
        training = self.config.training
        if ("cuda" in str(self.device) and torch.cuda.device_count() > 1
                and not self.distributed):
            registry.register("data_parallel", True)
            self.model = torch.nn.DataParallel(self.model)

        if "cuda" in str(self.device) and self.distributed:
            registry.register("distributed", True)
            self.model = torch.nn.parallel.DistributedDataParallel(
                self.model,
                device_ids=[self.local_rank],
                output_device=self.local_rank,
                check_reduction=True,
                find_unused_parameters=training.find_unused_parameters,
            )

    def load_extras(self):
        self.writer.write("Torch version is: " + torch.__version__)
        self.checkpoint = Checkpoint(self)
        self.meter = Meter()

        self.training_config = self.config.training

        early_stop_criteria = self.training_config.early_stop.criteria
        early_stop_minimize = self.training_config.early_stop.minimize
        early_stop_enabled = self.training_config.early_stop.enabled
        early_stop_patience = self.training_config.early_stop.patience

        self.log_interval = self.training_config.log_interval
        self.evaluation_interval = self.training_config.evaluation_interval
        self.checkpoint_interval = self.training_config.checkpoint_interval
        self.max_updates = self.training_config.max_updates
        self.should_clip_gradients = self.training_config.clip_gradients
        self.max_epochs = self.training_config.max_epochs

        self.early_stopping = EarlyStopping(
            self.model,
            self.checkpoint,
            early_stop_criteria,
            patience=early_stop_patience,
            minimize=early_stop_minimize,
            should_stop=early_stop_enabled,
        )
        self.current_epoch = 0
        self.current_iteration = 0
        self.num_updates = 0

        self.checkpoint.load_state_dict()

        self.not_debug = self.training_config.logger_level != "debug"

        self.lr_scheduler = None

        if self.training_config.lr_scheduler is True:
            self.lr_scheduler = build_scheduler(self.optimizer, self.config)

        self.tb_writer = None

        if self.training_config.tensorboard:
            log_dir = self.writer.log_dir
            env_tb_logdir = get_mmf_env(key="tensorboard_logdir")
            if env_tb_logdir:
                log_dir = env_tb_logdir

            self.tb_writer = TensorboardLogger(log_dir, self.current_iteration)

    def config_based_setup(self):
        seed = self.config.training.seed
        if seed is None:
            return

        # NOTE: only the cudnn determinism flags are set here; `seed` itself
        # is not applied to the RNGs in this snippet
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    def train(self):
        self.writer.write("===== Model =====")
        self.writer.write(self.model)

        print_model_parameters(self.model)

        if "train" not in self.run_type:
            self.inference()
            return

        should_break = False

        if self.max_epochs is None:
            self.max_epochs = math.inf
        else:
            self.max_updates = math.inf

        self.model.train()
        self.train_timer = Timer()
        self.snapshot_timer = Timer()

        self.profile("Setup Time")

        # Anomaly detection helps debug NaNs/infs in backward, at a
        # noticeable runtime cost
        torch.autograd.set_detect_anomaly(True)
        self.writer.write("Starting training...")

        while self.num_updates < self.max_updates and not should_break:
            self.current_epoch += 1
            registry.register("current_epoch", self.current_epoch)

            # Seed the sampler in case it is distributed
            self.dataset_loader.seed_sampler("train", self.current_epoch)

            if self.current_epoch > self.max_epochs:
                break

            for batch in self.train_loader:
                self.profile("Batch load time")
                self.current_iteration += 1
                self.writer.write(self.num_updates + 1, "debug")

                report = self._forward_pass(batch)
                loss = self._extract_loss(report)
                self._backward(loss)
                should_break = self._logistics(report)

                if self.num_updates > self.max_updates:
                    should_break = True

                if should_break:
                    break

            # In distributed training, each worker completes its own epoch by
            # the time we reach this point, as each worker is an individual instance
            self.current_epoch += get_world_size() - 1
        self.finalize()

    def _run_scheduler(self):
        if self.lr_scheduler is not None:
            self.lr_scheduler.step(self.num_updates)

    def _forward_pass(self, batch):
        prepared_batch = self.dataset_loader.prepare_batch(batch)
        self.profile("Batch prepare time")
        # Arguments should be a dict at this point
        model_output = self.model(prepared_batch)
        report = Report(prepared_batch, model_output)
        self.profile("Forward time")

        return report

    def _backward(self, loss):
        self.optimizer.zero_grad()
        loss.backward()

        if self.should_clip_gradients:
            clip_gradients(self.model, self.num_updates, self.tb_writer,
                           self.config)

        self.optimizer.step()
        self._run_scheduler()
        self.num_updates += 1
        self.profile("Backward time")

    def _extract_loss(self, report):
        loss_dict = report.losses
        loss = sum([loss.mean() for loss in loss_dict.values()])
        return loss

    def finalize(self):
        self.writer.write("Stepping into final validation check")

        # Only run when run_type includes train, as this shouldn't happen on
        # validation and inference runs; inference will take care of it anyway.
        # Also, don't run if the current iteration is divisible by the snapshot
        # interval, as that would just be a repeat
        if ("train" in self.run_type
                and self.num_updates % self.evaluation_interval != 0):
            self._try_full_validation(force=True)

        self.checkpoint.restore()
        self.checkpoint.finalize()
        self.inference()

        self.writer.write(
            f"Finished run in {self.total_timer.get_time_since_start()}")

    def _update_meter(self, report, meter=None, eval_mode=False):
        if meter is None:
            meter = self.meter

        if hasattr(report, "metrics"):
            metrics_dict = report.metrics
            reduced_metrics_dict = reduce_dict(metrics_dict)

        if not eval_mode:
            loss_dict = report.losses
            reduced_loss_dict = reduce_dict(loss_dict)

        with torch.no_grad():
            # Add the loss to the meter only in train mode; add metrics
            # whenever the report carries them
            meter_update_dict = {}
            if not eval_mode:
                loss_key = report.dataset_type + "/total_loss"
                reduced_loss = sum(
                    [loss.mean() for loss in reduced_loss_dict.values()])
                if hasattr(reduced_loss, "item"):
                    reduced_loss = reduced_loss.item()

                registry.register(loss_key, reduced_loss)
                meter_update_dict.update({loss_key: reduced_loss})
                meter_update_dict.update(reduced_loss_dict)
            if hasattr(report, "metrics"):
                meter_update_dict.update(reduced_metrics_dict)
            meter.update(meter_update_dict, report.batch_size)

    def _logistics(self, report):
        registry.register("current_iteration", self.current_iteration)
        registry.register("num_updates", self.num_updates)

        should_print = self.num_updates % self.log_interval == 0
        should_break = False
        extra = {}

        if should_print is True:
            if "cuda" in str(self.device):
                extra["max mem"] = torch.cuda.max_memory_allocated() / 1024
                extra["max mem"] //= 1024

            if self.training_config.experiment_name:
                extra["experiment"] = self.training_config.experiment_name

            extra.update({
                "epoch": self.current_epoch,
                "num_updates": self.num_updates,
                "iterations": self.current_iteration,
                "max_updates": self.max_updates,
                "lr": "{:.5f}".format(
                    self.optimizer.param_groups[0]["lr"]).rstrip("0"),
                "ups": "{:.2f}".format(
                    self.log_interval /
                    self.train_timer.unix_time_since_start()),
                "time": self.train_timer.get_time_since_start(),
                "time_since_start": self.total_timer.get_time_since_start(),
                "eta": self._calculate_time_left(),
            })

            self.train_timer.reset()
            # Calculate metrics every log interval for debugging
            if self.training_config.evaluate_metrics:
                report.metrics = self.metrics(report, report)
            self._update_meter(report, self.meter)

            self._summarize_report(self.meter,
                                   should_print=should_print,
                                   extra=extra)

        self._try_snapshot()
        should_break = self._try_full_validation()

        return should_break

    def _try_snapshot(self):
        if self.num_updates % self.checkpoint_interval == 0:
            self.writer.write("Checkpoint time. Saving a checkpoint.")
            self.checkpoint.save(self.num_updates,
                                 self.current_iteration,
                                 update_best=False)

    def _try_full_validation(self, force=False):
        should_break = False

        if self.num_updates % self.evaluation_interval == 0 or force:
            self.snapshot_timer.reset()
            self.writer.write(
                "Evaluation time. Running on full validation set...")
            # Validation and Early stopping
            # Create a new meter for this case
            report, meter = self.evaluate(self.val_loader)

            extra = {
                "num_updates": self.num_updates,
                "epoch": self.current_epoch,
                "iterations": self.current_iteration,
                "max_updates": self.max_updates,
                "val_time": self.snapshot_timer.get_time_since_start(),
            }

            # Rank 0 decides whether to early-stop; broadcast the decision so
            # all workers agree
            stop = self.early_stopping(self.num_updates,
                                       self.current_iteration, meter)
            stop = bool(broadcast_scalar(stop, src=0, device=self.device))

            extra.update(self.early_stopping.get_info())

            self._summarize_report(meter, extra=extra)
            gc.collect()

            if "cuda" in str(self.device):
                torch.cuda.empty_cache()

            if stop is True:
                self.writer.write("Early stopping activated")
                should_break = True

            self.train_timer.reset()

        return should_break

    def evaluate(self, loader, use_tqdm=False, single_batch=False):
        meter = Meter()

        with torch.no_grad():
            self.model.eval()
            disable_tqdm = not use_tqdm or not is_master()
            combined_report = None

            for batch in tqdm(loader, disable=disable_tqdm):
                report = self._forward_pass(batch)
                self._update_meter(report, meter)

                # accumulate necessary params for metric calculation
                if combined_report is None:
                    combined_report = report
                else:
                    combined_report.accumulate_tensor_fields(
                        report, self.metrics.required_params)
                    combined_report.batch_size += report.batch_size

                if single_batch is True:
                    break

            combined_report.metrics = self.metrics(combined_report,
                                                   combined_report)
            self._update_meter(combined_report, meter, eval_mode=True)

            self.model.train()

        return combined_report, meter

    def _summarize_report(self, meter, should_print=True, extra=None):
        if extra is None:
            extra = {}
        if not is_master():
            return

        if self.training_config.tensorboard:
            scalar_dict = meter.get_scalar_dict()
            self.tb_writer.add_scalars(scalar_dict, self.current_iteration)

        if not should_print:
            return
        log_dict = {"progress": f"{self.num_updates}/{self.max_updates}"}
        log_dict.update(meter.get_log_dict())
        log_dict.update(extra)

        self.writer.log_progress(log_dict)

    def inference(self):
        if "val" in self.run_type:
            self._inference_run("val")

        if any(rt in self.run_type for rt in ["inference", "test", "predict"]):
            self._inference_run("test")

    def _inference_run(self, dataset_type):
        if self.config.evaluation.predict:
            self.predict(dataset_type)
            return

        self.writer.write(f"Starting inference on {dataset_type} set")

        report, meter = self.evaluate(getattr(self, f"{dataset_type}_loader"),
                                      use_tqdm=True)
        prefix = f"{report.dataset_name}: full {dataset_type}"
        # NOTE: as in Example #3, `prefix` only acts as a truthy `should_print`
        # flag here; the string itself is never logged
        self._summarize_report(meter, prefix)

    def _calculate_time_left(self):
        time_taken_for_log = time.time() * 1000 - self.train_timer.start
        iterations_left = self.max_updates - self.num_updates
        num_logs_left = iterations_left / self.log_interval
        time_left = num_logs_left * time_taken_for_log

        snapshot_iteration = self.snapshot_iterations / self.log_interval
        snapshot_iteration *= iterations_left / self.evaluation_interval
        time_left += snapshot_iteration * time_taken_for_log

        return self.train_timer.get_time_hhmmss(gap=time_left)

    def profile(self, text):
        if self.not_debug:
            return
        self.writer.write(text + ": " + self.profiler.get_time_since_start(),
                          "debug")
        self.profiler.reset()

    def predict(self, dataset_type):
        reporter = self.dataset_loader.get_test_reporter(dataset_type)
        with torch.no_grad():
            self.model.eval()
            message = f"Starting {dataset_type} inference predictions"
            self.writer.write(message)

            while reporter.next_dataset():
                dataloader = reporter.get_dataloader()

                for batch in tqdm(dataloader):
                    prepared_batch = reporter.prepare_batch(batch)
                    model_output = self.model(prepared_batch)
                    report = Report(prepared_batch, model_output)
                    reporter.add_to_report(report, self.model)

            self.writer.write("Finished predicting")
            self.model.train()
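
The trainer above is driven in three steps: construct it with a configuration object, call load(), then train(). A hypothetical driver sketch follows; the Configuration stub below is not the real mmf class, it only illustrates the contract that BaseTrainer.__init__ and load() rely on (get_config(), an args attribute, pretty_print()):

class Configuration:
    """Stand-in sketch of the configuration object BaseTrainer expects."""

    def __init__(self, config, args=None):
        self._config = config  # e.g. an omegaconf DictConfig
        self.args = args

    def get_config(self):
        return self._config

    def pretty_print(self):
        print(self._config)


# trainer = BaseTrainer(Configuration(my_config))
# trainer.load()   # device, datasets, model, optimizer, extras
# trainer.train()  # update loop; calls finalize() when done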
Example #5
class LightningLoopCallback(Callback):
    def __init__(self, lightning_trainer: Any):
        super().__init__()
        self.lightning_trainer = lightning_trainer
        # Config of the lightning trainer itself
        self.trainer_config = lightning_trainer.trainer_config
        # Config holding MMF's training parameters
        self.training_config = lightning_trainer.training_config
        self.run_type = lightning_trainer.run_type

        # for logging
        self.total_timer = Timer()
        self.snapshot_timer = Timer()
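        # Number of validation batches per worker; used for the ETA estimate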
        self.snapshot_iterations = len(self.lightning_trainer.val_loader)
        self.train_timer = Timer()

    def on_train_start(self, trainer: Trainer, pl_module: LightningModule):
        registry.register("current_epoch", trainer.current_epoch)
        self.train_combined_report = None

    def on_train_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        outputs: List,
        batch: SampleList,
        batch_idx: int,
        dataloader_idx: int,
    ):
        # prepare the next batch
        self.lightning_trainer.data_module.train_loader.change_dataloader()

        # aggregate train combined_report
        self.train_combined_report = self._update_and_create_report(
            SampleList(batch), batch_idx, outputs, pl_module,
            self.train_combined_report)

        # log
        if (trainer.global_step +
                1) % self.trainer_config.log_every_n_steps == 0:
            self._train_log(trainer, pl_module)

        # save checkpoints - TODO: @sash

    def on_train_end(self, trainer: Trainer, pl_module: LightningModule):
        # Only run when run_type includes train, as this shouldn't happen on
        # validation and inference runs; inference will take care of it anyway.
        # Also, don't run if the current step is divisible by the validation
        # check interval, as that would just be a repeat
        if ("train" in self.run_type and trainer.global_step %
                self.trainer_config.val_check_interval != 0):
            logger.info("Stepping into final validation check")
            # PyTorch Lightning upgrades PR4945 and PR4948 will be enabled in 1.2
            # TODO: perform final validation check

    # Validation Callbacks
    def on_validation_start(self, trainer: Trainer,
                            pl_module: LightningModule):
        logger.info("Evaluation time. Running on full validation set...")
        self.snapshot_timer.reset()
        self.val_combined_report = None
        pl_module.val_meter.reset()

    def on_validation_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        outputs: List,
        batch: SampleList,
        batch_idx: int,
        dataloader_idx: int,
    ):
        # prepare the next batch
        self.lightning_trainer.data_module.val_loader.change_dataloader()

        # aggregate val_combined_report
        self.val_combined_report = self._update_and_create_report(
            batch,
            batch_idx,
            outputs,
            pl_module,
            self.val_combined_report,
            update_meter=pl_module.val_meter,
        )
        self.val_combined_report.metrics = pl_module.metrics(
            self.val_combined_report, self.val_combined_report)
        pl_module.val_meter.update_from_report(self.val_combined_report,
                                               should_update_loss=False)

    def on_validation_end(self, trainer: Trainer, pl_module: LightningModule):
        iterations = self._get_iterations_for_logging(trainer)
        current_epochs = self._get_current_epoch_for_logging(trainer)
        num_updates = self._get_num_updates_for_logging(trainer)
        extra = {
            "num_updates": num_updates,
            "epoch": current_epochs,
            "iterations": iterations,
            "max_updates": trainer.max_steps,
            "val_time": self.snapshot_timer.get_time_since_start(),
        }
        # TODO: @sash populate early stop info for logging (next mvp)
        # extra.update(self.trainer.early_stop_callback.early_stopping.get_info())
        self.train_timer.reset()
        summarize_report(
            current_iteration=iterations,
            num_updates=num_updates,
            max_updates=trainer.max_steps,
            meter=pl_module.val_meter,
            extra=extra,
            tb_writer=self.lightning_trainer.tb_writer,
        )

    def _update_and_create_report(
        self,
        batch: Dict,
        batch_idx: int,
        step_output: Dict,
        pl_module: LightningModule,
        combined_report: Report = None,
        update_meter: Meter = None,
    ):
        report = Report(batch, step_output)

        if update_meter:
            update_meter.update_from_report(report)

        # Accumulate into the combined report until batch_idx lands on a
        # gradient-accumulation boundary, where a fresh report is started
        should_accumulate = not (
            batch_idx % self.trainer_config.accumulate_grad_batches == 0)

        final_report = report
        if should_accumulate and combined_report is not None:
            combined_report.accumulate_tensor_fields_and_loss(
                report, pl_module.metrics.required_params)
            combined_report.batch_size += report.batch_size
            final_report = combined_report

        return final_report

    def get_optimizer(self, trainer: Trainer):
        assert (
            len(trainer.optimizers) == 1
        ), "mmf lightning_trainer supports 1 optimizer per model for now."
        optimizer = trainer.optimizers[0]
        return optimizer

    def _save_checkpoint(self, trainer: Trainer):
        logger.info("Checkpoint time. Saving a checkpoint.")
        # TODO: sash Needs implementation - next mvp
        return

    def _get_current_epoch_for_logging(self, trainer: Trainer):
        return trainer.current_epoch + 1

    def _get_iterations_for_logging(self, trainer: Trainer):
        return trainer.train_loop.batch_idx + 1

    def _get_num_updates_for_logging(self, trainer: Trainer):
        return trainer.global_step + 1

    def _train_log(self, trainer: Trainer, pl_module: LightningModule):
        if self.training_config.evaluate_metrics:
            self.train_combined_report.metrics = pl_module.metrics(
                self.train_combined_report, self.train_combined_report)

        pl_module.train_meter.update_from_report(self.train_combined_report)

        extra = {}
        if "cuda" in str(trainer.model.device):
            extra["max mem"] = torch.cuda.max_memory_allocated() / 1024
            extra["max mem"] //= 1024

        if self.training_config.experiment_name:
            extra["experiment"] = self.training_config.experiment_name

        optimizer = self.get_optimizer(trainer)
        num_updates = self._get_num_updates_for_logging(trainer)
        current_iteration = self._get_iterations_for_logging(trainer)
        extra.update({
            "epoch": self._get_current_epoch_for_logging(trainer),
            "iterations": current_iteration,
            "num_updates": num_updates,
            "max_updates": trainer.max_steps,
            "lr": "{:.5f}".format(optimizer.param_groups[0]["lr"]).rstrip("0"),
            "ups": "{:.2f}".format(
                self.trainer_config.log_every_n_steps /
                self.train_timer.unix_time_since_start()),
            "time": self.train_timer.get_time_since_start(),
            "time_since_start": self.total_timer.get_time_since_start(),
            "eta": calculate_time_left(
                max_updates=trainer.max_steps,
                num_updates=num_updates,
                timer=self.train_timer,
                num_snapshot_iterations=self.snapshot_iterations,
                log_interval=self.trainer_config.log_every_n_steps,
                eval_interval=self.trainer_config.val_check_interval,
            ),
        })
        self.train_timer.reset()
        summarize_report(
            current_iteration=current_iteration,
            num_updates=num_updates,
            max_updates=trainer.max_steps,
            meter=pl_module.train_meter,
            extra=extra,
            tb_writer=self.lightning_trainer.tb_writer,
        )
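
Example #5 (and Example #6 below) delegate the ETA computation to a free function calculate_time_left that is not shown in this section. A sketch consistent with the _calculate_time_left methods of Examples #3 and #4, with the trainer state passed in explicitly; the keyword-only signature is inferred from the call sites:

import time


def calculate_time_left(*, max_updates, num_updates, timer,
                        num_snapshot_iterations, log_interval, eval_interval):
    # Same arithmetic as the _calculate_time_left methods above
    time_taken_for_log = time.time() * 1000 - timer.start
    iterations_left = max_updates - num_updates
    num_logs_left = iterations_left / log_interval
    time_left = num_logs_left * time_taken_for_log

    snapshot_iteration = num_snapshot_iterations / log_interval
    snapshot_iteration *= iterations_left / eval_interval
    time_left += snapshot_iteration * time_taken_for_log

    return timer.get_time_hhmmss(gap=time_left)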
Example #6
class LogisticsCallback(Callback):
    """Callback for handling train/validation logistics, report summarization,
    logging etc.
    """
    def __init__(self, config, trainer):
        """
        Args:
            config (mmf_typings.DictConfig): Config for the callback
            trainer (Type[BaseTrainer]): Trainer object
        """
        super().__init__(config, trainer)

        self.total_timer = Timer()
        self.log_interval = self.training_config.log_interval
        self.evaluation_interval = self.training_config.evaluation_interval
        self.checkpoint_interval = self.training_config.checkpoint_interval

        # Total iterations for snapshot
        # len would be number of batches per GPU == max updates
        self.snapshot_iterations = len(self.trainer.val_loader)

        self.tb_writer = None

        if self.training_config.tensorboard:
            log_dir = setup_output_folder(folder_only=True)
            env_tb_logdir = get_mmf_env(key="tensorboard_logdir")
            if env_tb_logdir:
                log_dir = env_tb_logdir

            self.tb_writer = TensorboardLogger(log_dir,
                                               self.trainer.current_iteration)

    def on_train_start(self):
        self.train_timer = Timer()
        self.snapshot_timer = Timer()

    def on_update_end(self, **kwargs):
        if not kwargs["should_log"]:
            return
        extra = {}
        if "cuda" in str(self.trainer.device):
            extra["max mem"] = torch.cuda.max_memory_allocated() / 1024
            extra["max mem"] //= 1024

        if self.training_config.experiment_name:
            extra["experiment"] = self.training_config.experiment_name

        extra.update({
            "epoch": self.trainer.current_epoch,
            "num_updates": self.trainer.num_updates,
            "iterations": self.trainer.current_iteration,
            "max_updates": self.trainer.max_updates,
            "lr": "{:.5f}".format(
                self.trainer.optimizer.param_groups[0]["lr"]).rstrip("0"),
            "ups": "{:.2f}".format(
                self.log_interval /
                self.train_timer.unix_time_since_start()),
            "time": self.train_timer.get_time_since_start(),
            "time_since_start": self.total_timer.get_time_since_start(),
            "eta": calculate_time_left(
                max_updates=self.trainer.max_updates,
                num_updates=self.trainer.num_updates,
                timer=self.train_timer,
                num_snapshot_iterations=self.snapshot_iterations,
                log_interval=self.log_interval,
                eval_interval=self.evaluation_interval,
            ),
        })
        self.train_timer.reset()
        summarize_report(
            current_iteration=self.trainer.current_iteration,
            num_updates=self.trainer.num_updates,
            max_updates=self.trainer.max_updates,
            meter=kwargs["meter"],
            extra=extra,
            tb_writer=self.tb_writer,
        )

    def on_validation_start(self, **kwargs):
        self.snapshot_timer.reset()

    def on_validation_end(self, **kwargs):
        extra = {
            "num_updates": self.trainer.num_updates,
            "epoch": self.trainer.current_epoch,
            "iterations": self.trainer.current_iteration,
            "max_updates": self.trainer.max_updates,
            "val_time": self.snapshot_timer.get_time_since_start(),
        }
        extra.update(
            self.trainer.early_stop_callback.early_stopping.get_info())
        self.train_timer.reset()
        summarize_report(
            current_iteration=self.trainer.current_iteration,
            num_updates=self.trainer.num_updates,
            max_updates=self.trainer.max_updates,
            meter=kwargs["meter"],
            extra=extra,
            tb_writer=self.tb_writer,
        )

    def on_test_end(self, **kwargs):
        prefix = "{}: full {}".format(kwargs["report"].dataset_name,
                                      kwargs["report"].dataset_type)
        summarize_report(
            current_iteration=self.trainer.current_iteration,
            num_updates=self.trainer.num_updates,
            max_updates=self.trainer.max_updates,
            meter=kwargs["meter"],
            # NOTE: the prefix only acts as a truthy should_print flag here;
            # the string itself is never logged
            should_print=prefix,
            tb_writer=self.tb_writer,
        )
        logger.info(
            f"Finished run in {self.total_timer.get_time_since_start()}")