Example No. 1
    def test_reset(self):
        timer = Timer()
        time.sleep(2)
        timer.reset()
        expected = "000ms"

        self.assertEqual(timer.get_current(), expected)
Example No. 2
    def test_reset(self):
        timer = Timer()
        time.sleep(2)
        timer.reset()
        expected = 0

        self.assertEqual(int(timer.get_current().split("ms")[0]), expected)
Example No. 3
    def __init__(self, configuration):
        self.configuration = configuration
        self.config = self.configuration.get_config()
        self.profiler = Timer()
        self.total_timer = Timer()
        if self.configuration is not None:
            self.args = self.configuration.args
Example No. 4
    def test_get_time_since_start(self):
        timer = Timer()
        time.sleep(2)
        expected = 2

        self.assertEqual(expected,
                         int(timer.get_time_since_start().split("s")[0]))
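The tests above (and Example No. 26 further down) exercise a small Timer surface: construction starts the clock, get_current() and get_time_since_start() return strings such as "000ms" or "02s ...", reset() restarts the clock, and later examples also use get_time_hhmmss(), unix_time_since_start() and a millisecond-valued start attribute. Below is a minimal sketch of a class with that behaviour, inferred only from these snippets; it is not the actual mmf Timer, and details such as zero padding or larger units (minutes, hours) may differ.

import time
from datetime import datetime


class Timer:
    """Minimal Timer sketch matching the usage in these examples (assumption)."""

    def __init__(self):
        # `start` is kept in milliseconds; Example No. 21 computes
        # `time.time() * 1000 - self.train_timer.start`.
        self.start = time.time() * 1000

    def reset(self):
        # Restart timing, so get_current() drops back to "000ms".
        self.start = time.time() * 1000

    def _render(self, gap_ms):
        # "000ms" below one second, "02s 000ms" above, so both
        # split("ms")[0] and split("s")[0] parse as in the tests.
        seconds, ms = divmod(int(gap_ms), 1000)
        if seconds == 0:
            return "%03dms" % ms
        return "%02ds %03dms" % (seconds, ms)

    def get_current(self):
        return self._render(time.time() * 1000 - self.start)

    def get_time_since_start(self):
        return self.get_current()

    def unix_time_since_start(self):
        # Elapsed seconds as a float, used for the "ups" rate in the callbacks.
        return time.time() - self.start / 1000

    def get_time_hhmmss(self, gap=None, format=None):
        # Two call patterns appear on this page:
        #   get_time_hhmmss(None, format="%Y-%m-%dT%H:%M:%S") -> wall-clock timestamp
        #   get_time_hhmmss(gap=<milliseconds>)                -> formatted duration (ETA)
        if gap is None:
            return datetime.now().strftime(format or "%Y-%m-%dT%H:%M:%S")
        return self._render(gap)

Note that Example No. 1 compares the literal string "000ms", while Examples No. 2 and 26 parse the integer prefix instead, which is less sensitive to the exact formatting.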
Example No. 5
    def __init__(self, config, trainer):
        """
        Attr:
            config(mmf_typings.DictConfig): Config for the callback
            trainer(Type[BaseTrainer]): Trainer object
        """
        super().__init__(config, trainer)

        self.total_timer = Timer()
        self.log_interval = self.training_config.log_interval
        self.evaluation_interval = self.training_config.evaluation_interval
        self.checkpoint_interval = self.training_config.checkpoint_interval

        # Total iterations for snapshot
        self.snapshot_iterations = len(self.trainer.val_dataset)
        self.snapshot_iterations //= self.training_config.batch_size

        self.tb_writer = None

        if self.training_config.tensorboard:
            log_dir = setup_output_folder(folder_only=True)
            env_tb_logdir = get_mmf_env(key="tensorboard_logdir")
            if env_tb_logdir:
                log_dir = env_tb_logdir

            self.tb_writer = TensorboardLogger(log_dir, self.trainer.current_iteration)
Example No. 6
    def __init__(self, multi_task_instance):
        self.test_task = multi_task_instance
        self.task_type = multi_task_instance.dataset_type
        self.config = registry.get("config")
        self.writer = registry.get("writer")
        self.report = []
        self.timer = Timer()
        self.training_config = self.config.training
        self.num_workers = self.training_config.num_workers
        self.batch_size = self.training_config.batch_size
        self.report_folder_arg = get_mmf_env(key="report_dir")
        self.experiment_name = self.training_config.experiment_name

        self.datasets = []

        for dataset in self.test_task.get_datasets():
            self.datasets.append(dataset)

        self.current_dataset_idx = -1
        self.current_dataset = self.datasets[self.current_dataset_idx]

        self.save_dir = get_mmf_env(key="save_dir")
        self.report_folder = ckpt_name_from_core_args(self.config)
        self.report_folder += foldername_from_config_override(self.config)

        self.report_folder = os.path.join(self.save_dir, self.report_folder)
        self.report_folder = os.path.join(self.report_folder, "reports")

        if self.report_folder_arg:
            self.report_folder = self.report_folder_arg

        PathManager.mkdirs(self.report_folder)
Example No. 7
    def train(self):
        self.writer.write("===== Model =====")
        self.writer.write(self.model)

        print_model_parameters(self.model)

        if "train" not in self.run_type:
            self.inference()
            return

        should_break = False

        if self.max_epochs is None:
            self.max_epochs = math.inf
        else:
            self.max_updates = math.inf

        self.model.train()
        self.train_timer = Timer()
        self.snapshot_timer = Timer()

        self.profile("Setup Time")

        torch.autograd.set_detect_anomaly(True)
        self.writer.write("Starting training...")

        while self.num_updates < self.max_updates and not should_break:
            self.current_epoch += 1
            registry.register("current_epoch", self.current_epoch)

            # Seed the sampler in case it is distributed
            self.dataset_loader.seed_sampler("train", self.current_epoch)

            if self.current_epoch > self.max_epochs:
                break

            for batch in self.train_loader:
                self.profile("Batch load time")
                self.current_iteration += 1
                self.writer.write(self.num_updates + 1, "debug")

                report = self._forward_pass(batch)
                loss = self._extract_loss(report)
                self._backward(loss)
                should_break = self._logistics(report)

                if self.num_updates > self.max_updates:
                    should_break = True

                if should_break:
                    break

            # In distributed, each worker will complete one epoch when we reach this
            # as each worker is an individual instance
            self.current_epoch += get_world_size() - 1
        self.finalize()
Example No. 8
    def __init__(self, log_folder="./logs", iteration=0):
        self._summary_writer = None
        self._is_main = is_main()
        self.timer = Timer()
        self.log_folder = log_folder
        self.time_format = "%Y-%m-%dT%H:%M:%S"
        current_time = self.timer.get_time_hhmmss(None,
                                                  format=self.time_format)
        self.tensorboard_folder = os.path.join(self.log_folder,
                                               f"tensorboard_{current_time}")
Example No. 9
def setup_output_folder(folder_only: bool = False):
    """Sets up and returns the output file where the logs will be placed
    based on the configuration passed. Usually "save_dir/logs/train_<timestamp>.log".
    If env.log_dir is passed, logs will be directly saved in this folder.

    Args:
        folder_only (bool, optional): If folder should be returned and not the file.
            Defaults to False.

    Returns:
        str: folder or file path depending on folder_only flag
    """
    save_dir = get_mmf_env(key="save_dir")
    time_format = "%Y_%m_%dT%H_%M_%S"
    log_filename = "train_"
    log_filename += Timer().get_time_hhmmss(None, format=time_format)
    log_filename += ".log"

    log_folder = os.path.join(save_dir, "logs")

    env_log_dir = get_mmf_env(key="log_dir")
    if env_log_dir:
        log_folder = env_log_dir

    if not PathManager.exists(log_folder):
        PathManager.mkdirs(log_folder)

    if folder_only:
        return log_folder

    log_filename = os.path.join(log_folder, log_filename)

    return log_filename
Example No. 10
    def __init__(self, log_folder="./logs", iteration=0):
        # This would handle warning of missing tensorboard
        from torch.utils.tensorboard import SummaryWriter

        self.summary_writer = None
        self._is_master = is_master()
        self.timer = Timer()
        self.log_folder = log_folder
        self.time_format = "%Y-%m-%dT%H:%M:%S"

        if self._is_master:
            current_time = self.timer.get_time_hhmmss(None,
                                                      format=self.time_format)
            tensorboard_folder = os.path.join(self.log_folder,
                                              f"tensorboard_{current_time}")
            self.summary_writer = SummaryWriter(tensorboard_folder)
Example No. 11
class TrainerProfilingMixin(ABC):
    profiler: Type[Timer] = Timer()

    def profile(self, text: str) -> None:
        if self.training_config.logger_level != "debug":
            return
        logging.debug(f"{text}: {self.profiler.get_time_since_start()}")
        self.profiler.reset()
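A usage sketch for the mixin above: a trainer that inherits it and calls profile() between stages, mirroring the self.profile("Batch load time") and self.profile("Forward time") calls in Examples No. 7 and 22. SketchTrainer and its one-field config are hypothetical stand-ins, and the Timer import path is an assumption; the mixin itself is repeated verbatim so the sketch is self-contained.

import logging
import time
from abc import ABC
from types import SimpleNamespace
from typing import Type

from mmf.utils.timer import Timer  # assumed import path for the Timer used above


class TrainerProfilingMixin(ABC):
    # Same mixin as Example No. 11, repeated here so this sketch is self-contained.
    profiler: Type[Timer] = Timer()

    def profile(self, text: str) -> None:
        if self.training_config.logger_level != "debug":
            return
        logging.debug(f"{text}: {self.profiler.get_time_since_start()}")
        self.profiler.reset()


class SketchTrainer(TrainerProfilingMixin):
    # Hypothetical trainer: only the config field profile() reads is provided.
    def __init__(self):
        self.training_config = SimpleNamespace(logger_level="debug")

    def train_step(self):
        time.sleep(0.05)                 # stand-in for batch loading
        self.profile("Batch load time")  # logs elapsed time, then resets the profiler
        time.sleep(0.10)                 # stand-in for forward/backward
        self.profile("Forward time")


logging.basicConfig(level=logging.DEBUG)
SketchTrainer().train_step()

Because profile() resets the shared profiler after every call, each logged value measures only the stage since the previous profile() call.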
Example No. 12
    def test_tensorboard_logging_parity(
        self,
        summary_writer,
        mmf,
        lightning,
        logistics,
        logistics_logs,
        report_logs,
        trainer_logs,
        mkdirs,
    ):
        # mmf trainer
        mmf_trainer = get_mmf_trainer(
            max_updates=8,
            batch_size=2,
            max_epochs=None,
            log_interval=3,
            tensorboard=True,
        )

        def _add_scalars_mmf(log_dict, iteration):
            self.mmf_tensorboard_logs.append({iteration: log_dict})

        mmf_trainer.load_metrics()
        logistics_callback = LogisticsCallback(mmf_trainer.config, mmf_trainer)
        logistics_callback.snapshot_timer = MagicMock(return_value=None)
        logistics_callback.train_timer = Timer()
        logistics_callback.tb_writer.add_scalars = _add_scalars_mmf
        mmf_trainer.logistics_callback = logistics_callback
        mmf_trainer.callbacks = [logistics_callback]
        mmf_trainer.early_stop_callback = MagicMock(return_value=None)
        mmf_trainer.on_update_end = logistics_callback.on_update_end
        mmf_trainer.training_loop()

        # lightning_trainer
        trainer = get_lightning_trainer(
            max_steps=8,
            batch_size=2,
            prepare_trainer=False,
            log_every_n_steps=3,
            val_check_interval=9,
            tensorboard=True,
        )

        def _add_scalars_lightning(log_dict, iteration):
            self.lightning_tensorboard_logs.append({iteration: log_dict})

        def _on_fit_start_callback():
            trainer.tb_writer.add_scalars = _add_scalars_lightning

        callback = LightningLoopCallback(trainer)
        run_lightning_trainer_with_callback(
            trainer, callback, on_fit_start_callback=_on_fit_start_callback)
        self.assertEqual(len(self.mmf_tensorboard_logs),
                         len(self.lightning_tensorboard_logs))
        for mmf, lightning in zip(self.mmf_tensorboard_logs,
                                  self.lightning_tensorboard_logs):
            self.assertDictEqual(mmf, lightning)
Example No. 13
class TrainerProfilingMixin(ABC):
    profiler: Type[Timer] = Timer()

    def profile(self, text: str) -> None:
        if self.training_config.logger_level != "debug":
            return
        self.writer.write(text + ": " + self.profiler.get_time_since_start(),
                          "debug")
        self.profiler.reset()
Example No. 14
    def __init__(self,
                 multi_task_instance,
                 test_reporter_config: TestReporterConfigType = None):
        if not isinstance(test_reporter_config,
                          TestReporter.TestReporterConfigType):
            test_reporter_config = TestReporter.TestReporterConfigType(
                **test_reporter_config)
        self.test_task = multi_task_instance
        self.task_type = multi_task_instance.dataset_type
        self.config = registry.get("config")
        self.report = []
        self.timer = Timer()
        self.training_config = self.config.training
        self.num_workers = self.training_config.num_workers
        self.batch_size = self.training_config.batch_size
        self.report_folder_arg = get_mmf_env(key="report_dir")
        self.experiment_name = self.training_config.experiment_name
        self.test_reporter_config = test_reporter_config

        self.datasets = []

        for dataset in self.test_task.get_datasets():
            self.datasets.append(dataset)

        self.current_dataset_idx = -1
        self.current_dataset = self.datasets[self.current_dataset_idx]

        self.save_dir = get_mmf_env(key="save_dir")
        self.report_folder = ckpt_name_from_core_args(self.config)
        self.report_folder += foldername_from_config_override(self.config)

        self.report_folder = os.path.join(self.save_dir, self.report_folder)
        self.report_folder = os.path.join(self.report_folder, "reports")

        if self.report_folder_arg:
            self.report_folder = self.report_folder_arg

        self.candidate_fields = DEFAULT_CANDIDATE_FIELDS

        if not test_reporter_config.candidate_fields == MISSING:
            self.candidate_fields = test_reporter_config.candidate_fields

        PathManager.mkdirs(self.report_folder)
Example No. 15
    def __init__(
        self,
        datamodules: List[pl.LightningDataModule],
        config: Config = None,
        dataset_type: str = "train",
    ):
        self.test_reporter_config = OmegaConf.merge(
            OmegaConf.structured(self.Config), config
        )
        self.datamodules = datamodules
        self.dataset_type = dataset_type
        self.config = registry.get("config")
        self.report = []
        self.timer = Timer()
        self.training_config = self.config.training
        self.num_workers = self.training_config.num_workers
        self.batch_size = self.training_config.batch_size
        self.report_folder_arg = get_mmf_env(key="report_dir")
        self.experiment_name = self.training_config.experiment_name

        self.current_datamodule_idx = -1
        self.dataset_names = list(self.datamodules.keys())
        self.current_datamodule = self.datamodules[
            self.dataset_names[self.current_datamodule_idx]
        ]
        self.current_dataloader = None

        self.save_dir = get_mmf_env(key="save_dir")
        self.report_folder = ckpt_name_from_core_args(self.config)
        self.report_folder += foldername_from_config_override(self.config)

        self.report_folder = os.path.join(self.save_dir, self.report_folder)
        self.report_folder = os.path.join(self.report_folder, "reports")

        if self.report_folder_arg:
            self.report_folder = self.report_folder_arg

        self.candidate_fields = self.test_reporter_config.candidate_fields

        PathManager.mkdirs(self.report_folder)

        log_class_usage("TestReporter", self.__class__)
Example No. 16
    def test_validation_parity(self, summarize_report_fn, test_reporter, sw,
                               mkdirs):
        mmf_trainer = get_mmf_trainer(max_updates=8,
                                      batch_size=2,
                                      max_epochs=None,
                                      evaluation_interval=3)
        mmf_trainer.load_metrics()
        logistics_callback = LogisticsCallback(mmf_trainer.config, mmf_trainer)
        logistics_callback.snapshot_timer = Timer()
        logistics_callback.train_timer = Timer()
        mmf_trainer.logistics_callback = logistics_callback
        mmf_trainer.callbacks.append(logistics_callback)
        mmf_trainer.early_stop_callback = MagicMock(return_value=None)
        mmf_trainer.on_validation_end = logistics_callback.on_validation_end
        mmf_trainer.training_loop()

        calls = summarize_report_fn.call_args_list
        self.assertEqual(3, len(calls))
        self.assertEqual(len(self.ground_truths), len(calls))
        self._check_values(calls)
Example No. 17
    def __init__(self, config, trainer):
        """
        Attr:
            config(mmf_typings.DictConfig): Config for the callback
            trainer(Type[BaseTrainer]): Trainer object
        """
        super().__init__(config, trainer)

        self.total_timer = Timer()
        self.log_interval = self.training_config.log_interval
        self.evaluation_interval = self.training_config.evaluation_interval
        self.checkpoint_interval = self.training_config.checkpoint_interval

        # Total iterations for snapshot
        # len would be number of batches per GPU == max updates
        self.snapshot_iterations = len(self.trainer.val_loader)

        self.tb_writer = None

        self.wandb_logger = None

        if self.training_config.tensorboard:
            log_dir = setup_output_folder(folder_only=True)
            env_tb_logdir = get_mmf_env(key="tensorboard_logdir")
            if env_tb_logdir:
                log_dir = env_tb_logdir

            self.tb_writer = TensorboardLogger(log_dir, self.trainer.current_iteration)

        if self.training_config.wandb.enabled:
            log_dir = setup_output_folder(folder_only=True)

            env_wandb_logdir = get_mmf_env(key="wandb_logdir")
            if env_wandb_logdir:
                log_dir = env_wandb_logdir

            self.wandb_logger = WandbLogger(
                entity=config.training.wandb.entity,
                config=config,
                project=config.training.wandb.project,
            )
Example No. 18
File: logger.py Project: naykun/mmf
class TensorboardLogger:
    def __init__(self, log_folder="./logs", iteration=0):
        # This would handle warning of missing tensorboard
        from torch.utils.tensorboard import SummaryWriter

        self.summary_writer = None
        self._is_master = is_master()
        self.timer = Timer()
        self.log_folder = log_folder
        self.time_format = "%Y-%m-%dT%H:%M:%S"

        if self._is_master:
            current_time = self.timer.get_time_hhmmss(None, format=self.time_format)
            tensorboard_folder = os.path.join(
                self.log_folder, f"tensorboard_{current_time}"
            )
            self.summary_writer = SummaryWriter(tensorboard_folder)

    def __del__(self):
        if getattr(self, "summary_writer", None) is not None:
            self.summary_writer.close()

    def _should_log_tensorboard(self):
        if self.summary_writer is None or not self._is_master:
            return False
        else:
            return True

    def add_scalar(self, key, value, iteration):
        if not self._should_log_tensorboard():
            return

        self.summary_writer.add_scalar(key, value, iteration)

    def add_scalars(self, scalar_dict, iteration):
        if not self._should_log_tensorboard():
            return

        for key, val in scalar_dict.items():
            self.summary_writer.add_scalar(key, val, iteration)

    def add_histogram_for_model(self, model, iteration):
        if not self._should_log_tensorboard():
            return

        for name, param in model.named_parameters():
            np_param = param.clone().cpu().data.numpy()
            self.summary_writer.add_histogram(name, np_param, iteration)

    def flush(self):
        if self._should_log_tensorboard():
            self.summary_writer.flush()
Example No. 19
class TensorboardLogger:
    def __init__(self, log_folder="./logs", iteration=0):
        self._summary_writer = None
        self._is_main = is_main()
        self.timer = Timer()
        self.log_folder = log_folder
        self.time_format = "%Y-%m-%dT%H:%M:%S"
        current_time = self.timer.get_time_hhmmss(None,
                                                  format=self.time_format)
        self.tensorboard_folder = os.path.join(self.log_folder,
                                               f"tensorboard_{current_time}")

    @property
    def summary_writer(self):
        # Only on rank zero
        if not self._is_main:
            return None

        if self._summary_writer is None:
            # This would handle warning of missing tensorboard
            from torch.utils.tensorboard import SummaryWriter

            self._summary_writer = SummaryWriter(self.tensorboard_folder)

        return self._summary_writer

    @skip_if_tensorboard_inactive
    def close(self):
        """
        Closes the tensorboard summary writer.
        """
        self.summary_writer.close()

    @skip_if_tensorboard_inactive
    def add_scalar(self, key, value, iteration):
        self.summary_writer.add_scalar(key, value, iteration)

    @skip_if_tensorboard_inactive
    def add_scalars(self, scalar_dict, iteration):
        for key, val in scalar_dict.items():
            self.summary_writer.add_scalar(key, val, iteration)

    @skip_if_tensorboard_inactive
    def add_histogram_for_model(self, model, iteration):
        for name, param in model.named_parameters():
            np_param = param.clone().cpu().data.numpy()
            self.summary_writer.add_histogram(name, np_param, iteration)
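Example No. 19 decorates its methods with skip_if_tensorboard_inactive, which is not shown on this page. The following is a sketch of what such a decorator would have to do, inferred only from how it is used above (turn the call into a no-op whenever the lazy summary_writer property returns None, i.e. off the main rank); the real decorator in mmf may be implemented differently.

from functools import wraps


def skip_if_tensorboard_inactive(fn):
    # Sketch (assumption): make a TensorboardLogger method a no-op when there
    # is no summary writer, e.g. on non-main ranks where the property above
    # returns None.
    @wraps(fn)
    def wrapped(self, *args, **kwargs):
        if self.summary_writer is None:
            return None
        return fn(self, *args, **kwargs)

    return wrapped

Compared with Example No. 18, which guards every method with _should_log_tensorboard(), this version centralizes the rank check and, via the lazy property, postpones importing torch.utils.tensorboard until a writer is actually needed.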
Example No. 20
    def __init__(self, lightning_trainer: Any):
        super().__init__()
        self.lightning_trainer = lightning_trainer
        # this is lightning trainer's config
        self.trainer_config = lightning_trainer.trainer_config
        # training config configures training parameters.
        self.training_config = lightning_trainer.training_config
        self.run_type = lightning_trainer.run_type

        # for logging
        self.total_timer = Timer()
        self.snapshot_timer = Timer()
        self.snapshot_iterations = len(self.lightning_trainer.val_loader)
        self.train_timer = Timer()
Example No. 21
class LogisticsCallback(Callback):
    """Callback for handling train/validation logistics, report summarization,
    logging etc.
    """

    def __init__(self, config, trainer):
        """
        Attr:
            config(mmf_typings.DictConfig): Config for the callback
            trainer(Type[BaseTrainer]): Trainer object
        """
        super().__init__(config, trainer)

        self.total_timer = Timer()
        self.log_interval = self.training_config.log_interval
        self.evaluation_interval = self.training_config.evaluation_interval
        self.checkpoint_interval = self.training_config.checkpoint_interval

        # Total iterations for snapshot
        self.snapshot_iterations = len(self.trainer.val_dataset)
        self.snapshot_iterations //= self.training_config.batch_size

        self.tb_writer = None

        if self.training_config.tensorboard:
            log_dir = setup_output_folder(folder_only=True)
            env_tb_logdir = get_mmf_env(key="tensorboard_logdir")
            if env_tb_logdir:
                log_dir = env_tb_logdir

            self.tb_writer = TensorboardLogger(log_dir, self.trainer.current_iteration)

    def on_train_start(self):
        self.train_timer = Timer()
        self.snapshot_timer = Timer()

    def on_update_end(self, **kwargs):
        if not kwargs["should_log"]:
            return
        extra = {}
        if "cuda" in str(self.trainer.device):
            extra["max mem"] = torch.cuda.max_memory_allocated() / 1024
            extra["max mem"] //= 1024

        if self.training_config.experiment_name:
            extra["experiment"] = self.training_config.experiment_name

        extra.update(
            {
                "epoch": self.trainer.current_epoch,
                "num_updates": self.trainer.num_updates,
                "iterations": self.trainer.current_iteration,
                "max_updates": self.trainer.max_updates,
                "lr": "{:.5f}".format(
                    self.trainer.optimizer.param_groups[0]["lr"]
                ).rstrip("0"),
                "ups": "{:.2f}".format(
                    self.log_interval / self.train_timer.unix_time_since_start()
                ),
                "time": self.train_timer.get_time_since_start(),
                "time_since_start": self.total_timer.get_time_since_start(),
                "eta": self._calculate_time_left(),
            }
        )
        self.train_timer.reset()
        self._summarize_report(kwargs["meter"], extra=extra)

    def on_validation_start(self, **kwargs):
        self.snapshot_timer.reset()

    def on_validation_end(self, **kwargs):
        extra = {
            "num_updates": self.trainer.num_updates,
            "epoch": self.trainer.current_epoch,
            "iterations": self.trainer.current_iteration,
            "max_updates": self.trainer.max_updates,
            "val_time": self.snapshot_timer.get_time_since_start(),
        }
        extra.update(self.trainer.early_stop_callback.early_stopping.get_info())
        self.train_timer.reset()
        self._summarize_report(kwargs["meter"], extra=extra)

    def on_test_end(self, **kwargs):
        prefix = "{}: full {}".format(
            kwargs["report"].dataset_name, kwargs["report"].dataset_type
        )
        self._summarize_report(kwargs["meter"], prefix)
        logger.info(f"Finished run in {self.total_timer.get_time_since_start()}")

    def _summarize_report(self, meter, should_print=True, extra=None):
        if extra is None:
            extra = {}
        if not is_master() and not is_xla():
            return

        if self.training_config.tensorboard:
            scalar_dict = meter.get_scalar_dict()
            self.tb_writer.add_scalars(scalar_dict, self.trainer.current_iteration)

        if not should_print:
            return
        log_dict = {}
        if hasattr(self.trainer, "num_updates") and hasattr(
            self.trainer, "max_updates"
        ):
            log_dict.update(
                {"progress": f"{self.trainer.num_updates}/{self.trainer.max_updates}"}
            )
        log_dict.update(meter.get_log_dict())
        log_dict.update(extra)

        log_progress(log_dict)

    def _calculate_time_left(self):
        time_taken_for_log = time.time() * 1000 - self.train_timer.start
        iterations_left = self.trainer.max_updates - self.trainer.num_updates
        num_logs_left = iterations_left / self.log_interval
        time_left = num_logs_left * time_taken_for_log

        snapshot_iteration = self.snapshot_iterations / self.log_interval
        snapshot_iteration *= iterations_left / self.evaluation_interval
        time_left += snapshot_iteration * time_taken_for_log

        return self.train_timer.get_time_hhmmss(gap=time_left)
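The _calculate_time_left method above (and its twin in Example No. 22) estimates an ETA from the time spent on the last log window. Below is a standalone sketch of the same arithmetic with made-up numbers; the function name and all values are hypothetical, and time_taken_for_log_ms is in milliseconds, matching time.time() * 1000 - self.train_timer.start.

def calculate_time_left_ms(time_taken_for_log_ms, max_updates, num_updates,
                           log_interval, snapshot_iterations,
                           evaluation_interval):
    # Same arithmetic as LogisticsCallback._calculate_time_left, extracted
    # into a pure function for illustration.
    iterations_left = max_updates - num_updates

    # Remaining training time: log windows left * time per log window.
    num_logs_left = iterations_left / log_interval
    time_left = num_logs_left * time_taken_for_log_ms

    # Validation snapshots still to come, each costed as `snapshot_iterations`
    # updates at the rate implied by the last log window.
    snapshot_iteration = snapshot_iterations / log_interval
    snapshot_iteration *= iterations_left / evaluation_interval
    time_left += snapshot_iteration * time_taken_for_log_ms

    return time_left


# Hypothetical numbers: 30s spent on the last 100-update log window, 10k of
# 50k updates done, 200 validation batches run every 2000 updates.
eta_ms = calculate_time_left_ms(30_000, 50_000, 10_000, 100, 200, 2_000)
print(round(eta_ms / 3_600_000, 1), "hours")  # prints 3.7 hours for these numbers

In the callback the result is fed to train_timer.get_time_hhmmss(gap=time_left), which formats the millisecond gap as the "eta" field.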
Example No. 22
class BaseTrainer:
    def __init__(self, configuration):
        self.configuration = configuration
        self.config = self.configuration.get_config()
        self.profiler = Timer()
        self.total_timer = Timer()
        if self.configuration is not None:
            self.args = self.configuration.args

    def load(self):
        self._set_device()

        self.run_type = self.config.get("run_type", "train")
        self.dataset_loader = DatasetLoader(self.config)
        self._datasets = self.config.datasets

        # Check if loader is already defined, else init it
        writer = registry.get("writer", no_warning=True)
        if writer:
            self.writer = writer
        else:
            self.writer = Logger(self.config)
            registry.register("writer", self.writer)

        self.configuration.pretty_print()

        self.config_based_setup()

        self.load_datasets()
        self.load_model_and_optimizer()
        self.load_metrics()

    def _set_device(self):
        self.local_rank = self.config.device_id
        self.device = self.local_rank
        self.distributed = False

        # Will be updated later based on distributed setup
        registry.register("global_device", self.device)

        if self.config.distributed.init_method is not None:
            self.distributed = True
            self.device = torch.device("cuda", self.local_rank)
        elif torch.cuda.is_available():
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")

        registry.register("current_device", self.device)

    def load_datasets(self):
        self.writer.write("Loading datasets", "info")
        self.dataset_loader.load_datasets()

        self.train_dataset = self.dataset_loader.train_dataset
        self.val_dataset = self.dataset_loader.val_dataset

        # Total iterations for snapshot
        self.snapshot_iterations = len(self.val_dataset)
        self.snapshot_iterations //= self.config.training.batch_size

        self.test_dataset = self.dataset_loader.test_dataset

        self.train_loader = self.dataset_loader.train_loader
        self.val_loader = self.dataset_loader.val_loader
        self.test_loader = self.dataset_loader.test_loader

    def load_metrics(self):
        metrics = self.config.evaluation.get("metrics", [])
        self.metrics = Metrics(metrics)
        self.metrics_params = self.metrics.required_params

    def load_model_and_optimizer(self):
        attributes = self.config.model_config[self.config.model]
        # Easy way to point to config for other model
        if isinstance(attributes, str):
            attributes = self.config.model_config[attributes]

        with omegaconf.open_dict(attributes):
            attributes.model = self.config.model

        self.model = build_model(attributes)

        if "cuda" in str(self.device):
            device_info = "CUDA Device {} is: {}".format(
                self.config.distributed.rank,
                torch.cuda.get_device_name(self.local_rank),
            )
            registry.register("global_device", self.config.distributed.rank)
            self.writer.write(device_info, log_all=True)

        self.model = self.model.to(self.device)
        self.optimizer = build_optimizer(self.model, self.config)

        registry.register("data_parallel", False)
        registry.register("distributed", False)

        self.load_extras()
        self.parallelize_model()

    def parallelize_model(self):
        training = self.config.training
        if ("cuda" in str(self.device) and torch.cuda.device_count() > 1
                and not self.distributed):
            registry.register("data_parallel", True)
            self.model = torch.nn.DataParallel(self.model)

        if "cuda" in str(self.device) and self.distributed:
            registry.register("distributed", True)
            self.model = torch.nn.parallel.DistributedDataParallel(
                self.model,
                device_ids=[self.local_rank],
                output_device=self.local_rank,
                check_reduction=True,
                find_unused_parameters=training.find_unused_parameters,
            )

    def load_extras(self):
        self.writer.write("Torch version is: " + torch.__version__)
        self.checkpoint = Checkpoint(self)
        self.meter = Meter()

        self.training_config = self.config.training

        early_stop_criteria = self.training_config.early_stop.criteria
        early_stop_minimize = self.training_config.early_stop.minimize
        early_stop_enabled = self.training_config.early_stop.enabled
        early_stop_patience = self.training_config.early_stop.patience

        self.log_interval = self.training_config.log_interval
        self.evaluation_interval = self.training_config.evaluation_interval
        self.checkpoint_interval = self.training_config.checkpoint_interval
        self.max_updates = self.training_config.max_updates
        self.should_clip_gradients = self.training_config.clip_gradients
        self.max_epochs = self.training_config.max_epochs

        self.early_stopping = EarlyStopping(
            self.model,
            self.checkpoint,
            early_stop_criteria,
            patience=early_stop_patience,
            minimize=early_stop_minimize,
            should_stop=early_stop_enabled,
        )
        self.current_epoch = 0
        self.current_iteration = 0
        self.num_updates = 0

        self.checkpoint.load_state_dict()

        self.not_debug = self.training_config.logger_level != "debug"

        self.lr_scheduler = None

        if self.training_config.lr_scheduler is True:
            self.lr_scheduler = build_scheduler(self.optimizer, self.config)

        self.tb_writer = None

        if self.training_config.tensorboard:
            log_dir = self.writer.log_dir
            env_tb_logdir = get_mmf_env(key="tensorboard_logdir")
            if env_tb_logdir:
                log_dir = env_tb_logdir

            self.tb_writer = TensorboardLogger(log_dir, self.current_iteration)

    def config_based_setup(self):
        seed = self.config.training.seed
        if seed is None:
            return

        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    def train(self):
        self.writer.write("===== Model =====")
        self.writer.write(self.model)

        print_model_parameters(self.model)

        if "train" not in self.run_type:
            self.inference()
            return

        should_break = False

        if self.max_epochs is None:
            self.max_epochs = math.inf
        else:
            self.max_updates = math.inf

        self.model.train()
        self.train_timer = Timer()
        self.snapshot_timer = Timer()

        self.profile("Setup Time")

        torch.autograd.set_detect_anomaly(True)
        self.writer.write("Starting training...")

        while self.num_updates < self.max_updates and not should_break:
            self.current_epoch += 1
            registry.register("current_epoch", self.current_epoch)

            # Seed the sampler in case it is distributed
            self.dataset_loader.seed_sampler("train", self.current_epoch)

            if self.current_epoch > self.max_epochs:
                break

            for batch in self.train_loader:
                self.profile("Batch load time")
                self.current_iteration += 1
                self.writer.write(self.num_updates + 1, "debug")

                report = self._forward_pass(batch)
                loss = self._extract_loss(report)
                self._backward(loss)
                should_break = self._logistics(report)

                if self.num_updates > self.max_updates:
                    should_break = True

                if should_break:
                    break

            # In distributed, each worker will complete one epoch when we reach this
            # as each worker is an individual instance
            self.current_epoch += get_world_size() - 1
        self.finalize()

    def _run_scheduler(self):
        if self.lr_scheduler is not None:
            self.lr_scheduler.step(self.num_updates)

    def _forward_pass(self, batch):
        prepared_batch = self.dataset_loader.prepare_batch(batch)
        self.profile("Batch prepare time")
        # Arguments should be a dict at this point
        model_output = self.model(prepared_batch)
        report = Report(prepared_batch, model_output)
        self.profile("Forward time")

        return report

    def _backward(self, loss):
        self.optimizer.zero_grad()
        loss.backward()

        if self.should_clip_gradients:
            clip_gradients(self.model, self.num_updates, self.tb_writer,
                           self.config)

        self.optimizer.step()
        self._run_scheduler()
        self.num_updates += 1
        self.profile("Backward time")

    def _extract_loss(self, report):
        loss_dict = report.losses
        loss = sum([loss.mean() for loss in loss_dict.values()])
        return loss

    def finalize(self):
        self.writer.write("Stepping into final validation check")

        # Only do when run_type has train as it shouldn't happen on validation and
        # inference runs. Inference will take care of this anyways. Also, don't run
        # if current iteration is divisible by snapshot interval as it will just
        # be a repeat
        if ("train" in self.run_type
                and self.num_updates % self.evaluation_interval != 0):
            self._try_full_validation(force=True)

        self.checkpoint.restore()
        self.checkpoint.finalize()
        self.inference()

        self.writer.write(
            f"Finished run in {self.total_timer.get_time_since_start()}")

    def _update_meter(self, report, meter=None, eval_mode=False):
        if meter is None:
            meter = self.meter

        if hasattr(report, "metrics"):
            metrics_dict = report.metrics
            reduced_metrics_dict = reduce_dict(metrics_dict)

        if not eval_mode:
            loss_dict = report.losses
            reduced_loss_dict = reduce_dict(loss_dict)

        with torch.no_grad():
            # Add metrics to meter only when mode is `eval`
            meter_update_dict = {}
            if not eval_mode:
                loss_key = report.dataset_type + "/total_loss"
                reduced_loss = sum(
                    [loss.mean() for loss in reduced_loss_dict.values()])
                if hasattr(reduced_loss, "item"):
                    reduced_loss = reduced_loss.item()

                registry.register(loss_key, reduced_loss)
                meter_update_dict.update({loss_key: reduced_loss})
                meter_update_dict.update(reduced_loss_dict)
            if hasattr(report, "metrics"):
                meter_update_dict.update(reduced_metrics_dict)
            meter.update(meter_update_dict, report.batch_size)

    def _logistics(self, report):
        registry.register("current_iteration", self.current_iteration)
        registry.register("num_updates", self.num_updates)

        should_print = self.num_updates % self.log_interval == 0
        should_break = False
        extra = {}

        if should_print is True:
            if "cuda" in str(self.device):
                extra["max mem"] = torch.cuda.max_memory_allocated() / 1024
                extra["max mem"] //= 1024

            if self.training_config.experiment_name:
                extra["experiment"] = self.training_config.experiment_name

            extra.update({
                "epoch":
                self.current_epoch,
                "num_updates":
                self.num_updates,
                "iterations":
                self.current_iteration,
                "max_updates":
                self.max_updates,
                "lr":
                "{:.5f}".format(
                    self.optimizer.param_groups[0]["lr"]).rstrip("0"),
                "ups":
                "{:.2f}".format(self.log_interval /
                                self.train_timer.unix_time_since_start()),
                "time":
                self.train_timer.get_time_since_start(),
                "time_since_start":
                self.total_timer.get_time_since_start(),
                "eta":
                self._calculate_time_left(),
            })

            self.train_timer.reset()
            # Calculate metrics every log interval for debugging
            if self.training_config.evaluate_metrics:
                report.metrics = self.metrics(report, report)
            self._update_meter(report, self.meter)

            self._summarize_report(self.meter,
                                   should_print=should_print,
                                   extra=extra)

        self._try_snapshot()
        should_break = self._try_full_validation()

        return should_break

    def _try_snapshot(self):
        if self.num_updates % self.checkpoint_interval == 0:
            self.writer.write("Checkpoint time. Saving a checkpoint.")
            self.checkpoint.save(self.num_updates,
                                 self.current_iteration,
                                 update_best=False)

    def _try_full_validation(self, force=False):
        should_break = False

        if self.num_updates % self.evaluation_interval == 0 or force:
            self.snapshot_timer.reset()
            self.writer.write(
                "Evaluation time. Running on full validation set...")
            # Validation and Early stopping
            # Create a new meter for this case
            report, meter = self.evaluate(self.val_loader)

            extra = {
                "num_updates": self.num_updates,
                "epoch": self.current_epoch,
                "iterations": self.current_iteration,
                "max_updates": self.max_updates,
                "val_time": self.snapshot_timer.get_time_since_start(),
            }

            stop = self.early_stopping(self.num_updates,
                                       self.current_iteration, meter)
            stop = bool(broadcast_scalar(stop, src=0, device=self.device))

            extra.update(self.early_stopping.get_info())

            self._summarize_report(meter, extra=extra)
            gc.collect()

            if "cuda" in str(self.device):
                torch.cuda.empty_cache()

            if stop is True:
                self.writer.write("Early stopping activated")
                should_break = True

            self.train_timer.reset()

        return should_break

    def evaluate(self, loader, use_tqdm=False, single_batch=False):
        meter = Meter()

        with torch.no_grad():
            self.model.eval()
            disable_tqdm = not use_tqdm or not is_master()
            combined_report = None

            for batch in tqdm(loader, disable=disable_tqdm):
                report = self._forward_pass(batch)
                self._update_meter(report, meter)

                # accumulate necessary params for metric calculation
                if combined_report is None:
                    combined_report = report
                else:
                    combined_report.accumulate_tensor_fields(
                        report, self.metrics.required_params)
                    combined_report.batch_size += report.batch_size

                if single_batch is True:
                    break

            combined_report.metrics = self.metrics(combined_report,
                                                   combined_report)
            self._update_meter(combined_report, meter, eval_mode=True)

            self.model.train()

        return combined_report, meter

    def _summarize_report(self, meter, should_print=True, extra=None):
        if extra is None:
            extra = {}
        if not is_master():
            return

        if self.training_config.tensorboard:
            scalar_dict = meter.get_scalar_dict()
            self.tb_writer.add_scalars(scalar_dict, self.current_iteration)

        if not should_print:
            return
        log_dict = {"progress": f"{self.num_updates}/{self.max_updates}"}
        log_dict.update(meter.get_log_dict())
        log_dict.update(extra)

        self.writer.log_progress(log_dict)

    def inference(self):
        if "val" in self.run_type:
            self._inference_run("val")

        if any(rt in self.run_type for rt in ["inference", "test", "predict"]):
            self._inference_run("test")

    def _inference_run(self, dataset_type):
        if self.config.evaluation.predict:
            self.predict(dataset_type)
            return

        self.writer.write(f"Starting inference on {dataset_type} set")

        report, meter = self.evaluate(getattr(self, f"{dataset_type}_loader"),
                                      use_tqdm=True)
        prefix = f"{report.dataset_name}: full {dataset_type}"
        self._summarize_report(meter, prefix)

    def _calculate_time_left(self):
        time_taken_for_log = time.time() * 1000 - self.train_timer.start
        iterations_left = self.max_updates - self.num_updates
        num_logs_left = iterations_left / self.log_interval
        time_left = num_logs_left * time_taken_for_log

        snapshot_iteration = self.snapshot_iterations / self.log_interval
        snapshot_iteration *= iterations_left / self.evaluation_interval
        time_left += snapshot_iteration * time_taken_for_log

        return self.train_timer.get_time_hhmmss(gap=time_left)

    def profile(self, text):
        if self.not_debug:
            return
        self.writer.write(text + ": " + self.profiler.get_time_since_start(),
                          "debug")
        self.profiler.reset()

    def predict(self, dataset_type):
        reporter = self.dataset_loader.get_test_reporter(dataset_type)
        with torch.no_grad():
            self.model.eval()
            message = f"Starting {dataset_type} inference predictions"
            self.writer.write(message)

            while reporter.next_dataset():
                dataloader = reporter.get_dataloader()

                for batch in tqdm(dataloader):
                    prepared_batch = reporter.prepare_batch(batch)
                    model_output = self.model(prepared_batch)
                    report = Report(prepared_batch, model_output)
                    reporter.add_to_report(report, self.model)

            self.writer.write("Finished predicting")
            self.model.train()
Example No. 23
    def __init__(self, config):
        self.config = config
        self.profiler = Timer()
        self.total_timer = Timer()
Example No. 24
    def __init__(self, config, name=None):
        self.logger = None
        self._is_master = is_master()

        self.timer = Timer()
        self.config = config
        self.save_dir = get_mmf_env(key="save_dir")
        self.log_format = config.training.log_format
        self.time_format = "%Y-%m-%dT%H:%M:%S"
        self.log_filename = "train_"
        self.log_filename += self.timer.get_time_hhmmss(None, format=self.time_format)
        self.log_filename += ".log"

        self.log_folder = os.path.join(self.save_dir, "logs")

        env_log_dir = get_mmf_env(key="log_dir")
        if env_log_dir:
            self.log_folder = env_log_dir

        if not PathManager.exists(self.log_folder):
            PathManager.mkdirs(self.log_folder)

        self.log_filename = os.path.join(self.log_folder, self.log_filename)

        if not self._is_master:
            return
        if self._is_master:
            print("Logging to:", self.log_filename)

        logging.captureWarnings(True)

        if not name:
            name = __name__
        self.logger = logging.getLogger(name)
        self._file_only_logger = logging.getLogger(name)
        warnings_logger = logging.getLogger("py.warnings")

        # Set level
        level = config.training.logger_level
        self.logger.setLevel(getattr(logging, level.upper()))
        self._file_only_logger.setLevel(getattr(logging, level.upper()))

        formatter = logging.Formatter(
            "%(asctime)s %(levelname)s: %(message)s", datefmt="%Y-%m-%dT%H:%M:%S"
        )

        # Add handler to file
        channel = logging.FileHandler(filename=self.log_filename, mode="a")
        channel.setFormatter(formatter)

        self.logger.addHandler(channel)
        self._file_only_logger.addHandler(channel)
        warnings_logger.addHandler(channel)

        # Add handler to stdout
        channel = logging.StreamHandler(sys.stdout)
        channel.setFormatter(formatter)

        self.logger.addHandler(channel)
        warnings_logger.addHandler(channel)

        should_not_log = self.config.training.should_not_log
        self.should_log = not should_not_log

        # Single log wrapper map
        self._single_log_map = set()
Example No. 25
class TestReporter(Dataset):
    @dataclass
    class Config:
        # A set of fields to be *considered* for exporting by the reporter
        # Note that `format_for_prediction` is what ultimately determines the
        # exported fields
        candidate_fields: List[str] = field(
            default_factory=lambda: DEFAULT_CANDIDATE_FIELDS)
        # csv or json
        predict_file_format: str = "json"

    def __init__(
        self,
        datamodules: List[pl.LightningDataModule],
        config: Config = None,
        dataset_type: str = "train",
    ):
        self.test_reporter_config = OmegaConf.merge(
            OmegaConf.structured(self.Config), config)
        self.datamodules = datamodules
        self.dataset_type = dataset_type
        self.config = registry.get("config")
        self.report = []
        self.timer = Timer()
        self.training_config = self.config.training
        self.num_workers = self.training_config.num_workers
        self.batch_size = self.training_config.batch_size
        self.report_folder_arg = get_mmf_env(key="report_dir")
        self.experiment_name = self.training_config.experiment_name

        self.current_datamodule_idx = -1
        self.dataset_names = list(self.datamodules.keys())
        self.current_datamodule = self.datamodules[self.dataset_names[
            self.current_datamodule_idx]]
        self.current_dataloader = None

        self.save_dir = get_mmf_env(key="save_dir")
        self.report_folder = ckpt_name_from_core_args(self.config)
        self.report_folder += foldername_from_config_override(self.config)

        self.report_folder = os.path.join(self.save_dir, self.report_folder)
        self.report_folder = os.path.join(self.report_folder, "reports")

        if self.report_folder_arg:
            self.report_folder = self.report_folder_arg

        self.candidate_fields = self.test_reporter_config.candidate_fields

        PathManager.mkdirs(self.report_folder)

    @property
    def current_dataset(self):
        self._check_current_dataloader()
        return self.current_dataloader.dataset

    def next_dataset(self, flush_report=True):
        if self.current_datamodule_idx >= 0:
            if flush_report:
                self.flush_report()
            else:
                self.report = []

        self.current_datamodule_idx += 1

        if self.current_datamodule_idx == len(self.datamodules):
            return False
        else:
            self.current_datamodule = self.datamodules[self.dataset_names[
                self.current_datamodule_idx]]
            logger.info(
                f"Predicting for {self.dataset_names[self.current_datamodule_idx]}"
            )
            return True

    def flush_report(self):
        if not is_master():
            return

        name = self.current_datamodule.dataset_name
        time_format = "%Y-%m-%dT%H:%M:%S"
        time = self.timer.get_time_hhmmss(None, format=time_format)

        filename = name + "_"

        if len(self.experiment_name) > 0:
            filename += self.experiment_name + "_"

        filename += self.dataset_type + "_"
        filename += time

        use_csv_writer = (self.config.evaluation.predict_file_format == "csv"
                          or self.test_reporter_config.predict_file_format
                          == "csv")

        if use_csv_writer:
            filepath = os.path.join(self.report_folder, filename + ".csv")
            self.csv_dump(filepath)
        else:
            filepath = os.path.join(self.report_folder, filename + ".json")
            self.json_dump(filepath)

        logger.info(
            f"Wrote predictions for {name} to {os.path.abspath(filepath)}")
        self.report = []

    def postprocess_dataset_report(self):
        self._check_current_dataloader()
        if hasattr(self.current_dataset, "on_prediction_end"):
            self.report = self.current_dataset.on_prediction_end(self.report)

    def csv_dump(self, filepath):
        with PathManager.open(filepath, "w") as f:
            title = self.report[0].keys()
            cw = csv.DictWriter(f,
                                title,
                                delimiter=",",
                                quoting=csv.QUOTE_MINIMAL)
            cw.writeheader()
            cw.writerows(self.report)

    def json_dump(self, filepath):
        with PathManager.open(filepath, "w") as f:
            json.dump(self.report, f)

    def get_dataloader(self):
        self.current_dataloader = getattr(self.current_datamodule,
                                          f"{self.dataset_type}_dataloader")()
        # Make sure to assign dataset to dataloader object as
        # required by MMF
        if not hasattr(self.current_dataloader, "dataset"):
            self.current_dataloader.dataset = getattr(
                self.current_datamodule, f"{self.dataset_type}_dataset")
        return self.current_dataloader

    def prepare_batch(self, batch):
        self._check_current_dataloader()
        if hasattr(self.current_dataset, "prepare_batch"):
            batch = self.current_dataset.prepare_batch(batch)

        batch = convert_batch_to_sample_list(batch)
        batch.dataset_name = self.current_dataset.dataset_name
        batch.dataset_type = self.dataset_type
        return batch

    def __len__(self):
        self._check_current_dataloader()
        return len(self.current_dataloader)

    def _check_current_dataloader(self):
        assert self.current_dataloader is not None, (
            "Please call `get_dataloader` before accessing any " +
            "'current_dataloader' based function")

    def add_to_report(self, report, model, *args, **kwargs):
        if "execute_on_master_only" in kwargs:
            warnings.warn(
                "'execute_on_master_only keyword is deprecated and isn't used anymore",
                DeprecationWarning,
            )
        self._check_current_dataloader()
        for key in self.candidate_fields:
            report = self.reshape_and_gather(report, key)

        results = []

        if hasattr(self.current_dataset, "format_for_prediction"):
            results = self.current_dataset.format_for_prediction(report)

        if hasattr(model, "format_for_prediction"):
            results = model.format_for_prediction(results, report)
        elif hasattr(model.module, "format_for_prediction"):
            results = model.module.format_for_prediction(results, report)

        self.report = self.report + results

    def reshape_and_gather(self, report, key):
        if key in report:
            num_dims = report[key].dim()
            if num_dims == 1:
                report[key] = gather_tensor(report[key]).view(-1)
            elif num_dims >= 2:
                # Collect dims other than batch
                other_dims = report[key].size()[1:]
                report[key] = gather_tensor(report[key]).view(-1, *other_dims)

        return report
Example No. 26
    def test_get_current(self):
        timer = Timer()
        expected = 0

        self.assertEqual(int(timer.get_current().split("ms")[0]), expected)
Example No. 27
class TestReporter(Dataset):
    def __init__(self, multi_task_instance):
        self.test_task = multi_task_instance
        self.task_type = multi_task_instance.dataset_type
        self.config = registry.get("config")
        self.report = []
        self.timer = Timer()
        self.training_config = self.config.training
        self.num_workers = self.training_config.num_workers
        self.batch_size = self.training_config.batch_size
        self.report_folder_arg = get_mmf_env(key="report_dir")
        self.experiment_name = self.training_config.experiment_name

        self.datasets = []

        for dataset in self.test_task.get_datasets():
            self.datasets.append(dataset)

        self.current_dataset_idx = -1
        self.current_dataset = self.datasets[self.current_dataset_idx]

        self.save_dir = get_mmf_env(key="save_dir")
        self.report_folder = ckpt_name_from_core_args(self.config)
        self.report_folder += foldername_from_config_override(self.config)

        self.report_folder = os.path.join(self.save_dir, self.report_folder)
        self.report_folder = os.path.join(self.report_folder, "reports")

        if self.report_folder_arg:
            self.report_folder = self.report_folder_arg

        PathManager.mkdirs(self.report_folder)

    def next_dataset(self):
        if self.current_dataset_idx >= 0:
            self.flush_report()

        self.current_dataset_idx += 1

        if self.current_dataset_idx == len(self.datasets):
            return False
        else:
            self.current_dataset = self.datasets[self.current_dataset_idx]
            logger.info(f"Predicting for {self.current_dataset.dataset_name}")
            return True

    def flush_report(self):
        if not is_master():
            return

        name = self.current_dataset.dataset_name
        time_format = "%Y-%m-%dT%H:%M:%S"
        time = self.timer.get_time_hhmmss(None, format=time_format)

        filename = name + "_"

        if len(self.experiment_name) > 0:
            filename += self.experiment_name + "_"

        filename += self.task_type + "_"
        filename += time

        if self.config.evaluation.predict_file_format == "csv":
            filepath = os.path.join(self.report_folder, filename + ".csv")
            self.csv_dump(filepath)
        else:
            filepath = os.path.join(self.report_folder, filename + ".json")
            self.json_dump(filepath)

        logger.info(
            f"Wrote evalai predictions for {name} to {os.path.abspath(filepath)}"
        )
        self.report = []

    def csv_dump(self, filepath):
        with PathManager.open(filepath, "w") as f:
            title = self.report[0].keys()
            cw = csv.DictWriter(f, title, delimiter=",", quoting=csv.QUOTE_MINIMAL)
            cw.writeheader()
            cw.writerows(self.report)

    def json_dump(self, filepath):
        with PathManager.open(filepath, "w") as f:
            json.dump(self.report, f)

    def get_dataloader(self):
        other_args = self._add_extra_args_for_dataloader()
        return DataLoader(
            dataset=self.current_dataset,
            collate_fn=BatchCollator(
                self.current_dataset.dataset_name, self.current_dataset.dataset_type
            ),
            num_workers=self.num_workers,
            pin_memory=self.config.training.pin_memory,
            **other_args,
        )

    def _add_extra_args_for_dataloader(self, other_args=None):
        if other_args is None:
            other_args = {}

        if is_dist_initialized():
            other_args["sampler"] = DistributedSampler(
                self.current_dataset, shuffle=False
            )
        else:
            other_args["shuffle"] = False

        other_args["batch_size"] = get_batch_size()

        return other_args

    def prepare_batch(self, batch):
        return self.current_dataset.prepare_batch(batch)

    def __len__(self):
        return len(self.current_dataset)

    def __getitem__(self, idx):
        return self.current_dataset[idx]

    def add_to_report(self, report, model):
        keys = ["id", "question_id", "image_id", "context_tokens", "captions", "scores"]
        for key in keys:
            report = self.reshape_and_gather(report, key)

        if not is_master():
            return

        results = self.current_dataset.format_for_prediction(report)

        if hasattr(model, "format_for_prediction"):
            results = model.format_for_prediction(results, report)
        elif hasattr(model, "module") and hasattr(
            model.module, "format_for_prediction"
        ):
            results = model.module.format_for_prediction(results, report)

        self.report = self.report + results

    def reshape_and_gather(self, report, key):
        if key in report:
            num_dims = report[key].dim()
            if num_dims == 1:
                report[key] = gather_tensor(report[key]).view(-1)
            elif num_dims >= 2:
                # Collect dims other than batch
                other_dims = report[key].size()[1:]
                report[key] = gather_tensor(report[key]).view(-1, *other_dims)

        return report
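The reporter above is meant to be driven by an external prediction loop. The following is a minimal sketch of such a loop under stated assumptions: `run_prediction` is a hypothetical helper, and `model` is assumed to return a dict of output tensors that `Report` can combine with the prepared batch.

# Hypothetical driver loop for the TestReporter above (illustrative sketch only).
import torch


def run_prediction(model, multi_task_instance):
    reporter = TestReporter(multi_task_instance)
    model.eval()
    with torch.no_grad():
        # next_dataset() flushes the previous dataset's report before advancing,
        # so every dataset (including the last one) gets written out.
        while reporter.next_dataset():
            dataloader = reporter.get_dataloader()
            for batch in dataloader:
                prepared_batch = reporter.prepare_batch(batch)
                model_output = model(prepared_batch)  # assumed dict of tensors
                report = Report(prepared_batch, model_output)
                reporter.add_to_report(report, model)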
Example No. 28
class TestReporter(Dataset):
    def __init__(self, multi_task_instance):
        self.test_task = multi_task_instance
        self.task_type = multi_task_instance.dataset_type
        self.config = registry.get("config")
        self.writer = registry.get("writer")
        self.report = []
        self.timer = Timer()
        self.training_config = self.config.training
        self.num_workers = self.training_config.num_workers
        self.batch_size = self.training_config.batch_size
        self.report_folder_arg = get_mmf_env(key="report_dir")
        self.experiment_name = self.training_config.experiment_name

        self.datasets = []

        for dataset in self.test_task.get_datasets():
            self.datasets.append(dataset)

        self.current_dataset_idx = -1
        self.current_dataset = self.datasets[self.current_dataset_idx]

        self.save_dir = get_mmf_env(key="save_dir")
        self.report_folder = ckpt_name_from_core_args(self.config)
        self.report_folder += foldername_from_config_override(self.config)

        self.report_folder = os.path.join(self.save_dir, self.report_folder)
        self.report_folder = os.path.join(self.report_folder, "reports")

        if self.report_folder_arg:
            self.report_folder = self.report_folder_arg

        PathManager.mkdirs(self.report_folder)

    def next_dataset(self):
        if self.current_dataset_idx >= 0:
            self.flush_report()

        self.current_dataset_idx += 1

        if self.current_dataset_idx == len(self.datasets):
            return False
        else:
            self.current_dataset = self.datasets[self.current_dataset_idx]
            self.writer.write("Predicting for " + self.current_dataset.dataset_name)
            return True

    def flush_report(self):
        if not is_master():
            return

        name = self.current_dataset.dataset_name
        time_format = "%Y-%m-%dT%H:%M:%S"
        time = self.timer.get_time_hhmmss(None, format=time_format)

        filename = name + "_"

        if len(self.experiment_name) > 0:
            filename += self.experiment_name + "_"

        filename += self.task_type + "_"

        filename += time + ".json"
        filepath = os.path.join(self.report_folder, filename)

        with PathManager.open(filepath, "w") as f:
            json.dump(self.report, f)

        self.writer.write(
            "Wrote evalai predictions for %s to %s" % (name, os.path.abspath(filepath))
        )
        self.report = []

    def get_dataloader(self):
        other_args = self._add_extra_args_for_dataloader()
        return DataLoader(
            dataset=self.current_dataset,
            collate_fn=BatchCollator(
                self.current_dataset.dataset_name, self.current_dataset.dataset_type
            ),
            num_workers=self.num_workers,
            pin_memory=self.config.training.pin_memory,
            **other_args
        )

    def _add_extra_args_for_dataloader(self, other_args=None):
        if other_args is None:
            other_args = {}

        if torch.distributed.is_initialized():
            other_args["sampler"] = DistributedSampler(
                self.current_dataset, shuffle=False
            )
        else:
            other_args["shuffle"] = False

        other_args["batch_size"] = get_batch_size()

        return other_args

    def prepare_batch(self, batch):
        return self.current_dataset.prepare_batch(batch)

    def __len__(self):
        return len(self.current_dataset)

    def __getitem__(self, idx):
        return self.current_dataset[idx]

    def add_to_report(self, report):
        # TODO: Later gather whole report for no opinions
        if self.current_dataset.dataset_name == "coco":
            report.captions = gather_tensor(report.captions)
            if isinstance(report.image_id, torch.Tensor):
                report.image_id = gather_tensor(report.image_id).view(-1)
        else:
            report.scores = gather_tensor(report.scores).view(
                -1, report.scores.size(-1)
            )
            if "question_id" in report:
                report.question_id = gather_tensor(report.question_id).view(-1)
            if "image_id" in report:
                _, enc_size = report.image_id.size()
                report.image_id = gather_tensor(report.image_id)
                report.image_id = report.image_id.view(-1, enc_size)
            if "context_tokens" in report:
                _, enc_size = report.context_tokens.size()
                report.context_tokens = gather_tensor(report.context_tokens)
                report.context_tokens = report.context_tokens.view(-1, enc_size)

        if not is_master():
            return

        results = self.current_dataset.format_for_evalai(report)

        self.report = self.report + results
Example No. 29
    def on_train_start(self):
        self.train_timer = Timer()
        self.snapshot_timer = Timer()
Example No. 30
class LightningLoopCallback(Callback):
    def __init__(self, lightning_trainer: Any):
        super().__init__()
        self.lightning_trainer = lightning_trainer
        # this is lightning trainer's config
        self.trainer_config = lightning_trainer.trainer_config
        # training config configures training parameters.
        self.training_config = lightning_trainer.training_config
        self.run_type = lightning_trainer.run_type

        # for logging
        self.total_timer = Timer()
        self.snapshot_timer = Timer()
        self.snapshot_iterations = len(self.lightning_trainer.val_loader)
        self.train_timer = Timer()

    def on_train_start(self, trainer: Trainer, pl_module: LightningModule):
        registry.register("current_epoch", trainer.current_epoch)
        self.train_combined_report = None

    def on_train_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        outputs: List,
        batch: SampleList,
        batch_idx: int,
        dataloader_idx: int,
    ):
        # prepare the next batch
        self.lightning_trainer.data_module.train_loader.change_dataloader()

        # aggregate train combined_report
        self.train_combined_report = self._update_and_create_report(
            SampleList(batch),
            batch_idx,
            outputs,
            pl_module,
            self.train_combined_report,
        )

        # log
        if (trainer.global_step + 1) % self.trainer_config.log_every_n_steps == 0:
            self._train_log(trainer, pl_module)

        # save checkpoints - TODO: @sash

    def on_train_end(self, trainer: Trainer, pl_module: LightningModule):
        # Only do this when run_type includes train, since it shouldn't happen on
        # validation and inference runs; inference will take care of it anyway. Also,
        # don't run if the current iteration is divisible by the snapshot interval,
        # as that would just be a repeat.
        if ("train" in self.run_type and trainer.global_step %
                self.trainer_config.val_check_interval != 0):
            logger.info("Stepping into final validation check")
            # Pytorch Lightning upgrades PR4945 and PR4948 will be enabled in 1.2
            # TODO: perform final validation check

    # Validation Callbacks
    def on_validation_start(self, trainer: Trainer, pl_module: LightningModule):
        logger.info("Evaluation time. Running on full validation set...")
        self.snapshot_timer.reset()
        self.val_combined_report = None
        pl_module.val_meter.reset()

    def on_validation_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        outputs: List,
        batch: SampleList,
        batch_idx: int,
        dataloader_idx: int,
    ):
        # prepare the next batch
        self.lightning_trainer.data_module.val_loader.change_dataloader()

        # aggregate val_combined_report
        self.val_combined_report = self._update_and_create_report(
            batch,
            batch_idx,
            outputs,
            pl_module,
            self.val_combined_report,
            update_meter=pl_module.val_meter,
        )
        self.val_combined_report.metrics = pl_module.metrics(
            self.val_combined_report, self.val_combined_report
        )
        pl_module.val_meter.update_from_report(
            self.val_combined_report, should_update_loss=False
        )

    def on_validation_end(self, trainer: Trainer, pl_module: LightningModule):
        iterations = self._get_iterations_for_logging(trainer)
        current_epochs = self._get_current_epoch_for_logging(trainer)
        num_updates = self._get_num_updates_for_logging(trainer)
        extra = {
            "num_updates": num_updates,
            "epoch": current_epochs,
            "iterations": iterations,
            "max_updates": trainer.max_steps,
            "val_time": self.snapshot_timer.get_time_since_start(),
        }
        # TODO: @sash populate early stop info for logging (next mvp)
        # extra.update(self.trainer.early_stop_callback.early_stopping.get_info())
        self.train_timer.reset()
        summarize_report(
            current_iteration=iterations,
            num_updates=num_updates,
            max_updates=trainer.max_steps,
            meter=pl_module.val_meter,
            extra=extra,
            tb_writer=self.lightning_trainer.tb_writer,
        )

    def _update_and_create_report(
        self,
        batch: Dict,
        batch_idx: int,
        step_output: Dict,
        pl_module: LightningModule,
        combined_report: Report = None,
        update_meter: Meter = None,
    ):
        report = Report(batch, step_output)

        if update_meter:
            update_meter.update_from_report(report)

        should_accumulate = (
            batch_idx % self.trainer_config.accumulate_grad_batches != 0
        )

        final_report = report
        if should_accumulate and combined_report is not None:
            combined_report.accumulate_tensor_fields_and_loss(
                report, pl_module.metrics.required_params)
            combined_report.batch_size += report.batch_size
            final_report = combined_report

        return final_report

    def get_optimizer(self, trainer: Trainer):
        assert (
            len(trainer.optimizers) == 1
        ), "mmf lightning_trainer supports 1 optimizer per model for now."
        optimizer = trainer.optimizers[0]
        return optimizer

    def _save_checkpoint(self, trainer: Trainer):
        logger.info("Checkpoint time. Saving a checkpoint.")
        # TODO: sash Needs implementation - next mvp
        return

    def _get_current_epoch_for_logging(self, trainer: Trainer):
        return trainer.current_epoch + 1

    def _get_iterations_for_logging(self, trainer: Trainer):
        return trainer.train_loop.batch_idx + 1

    def _get_num_updates_for_logging(self, trainer: Trainer):
        return trainer.global_step + 1

    def _train_log(self, trainer: Trainer, pl_module: LightningModule):
        if self.training_config.evaluate_metrics:
            self.train_combined_report.metrics = pl_module.metrics(
                self.train_combined_report, self.train_combined_report)

        pl_module.train_meter.update_from_report(self.train_combined_report)

        extra = {}
        if "cuda" in str(trainer.model.device):
            extra["max mem"] = torch.cuda.max_memory_allocated() / 1024
            extra["max mem"] //= 1024

        if self.training_config.experiment_name:
            extra["experiment"] = self.training_config.experiment_name

        optimizer = self.get_optimizer(trainer)
        num_updates = self._get_num_updates_for_logging(trainer)
        current_iteration = self._get_iterations_for_logging(trainer)
        extra.update(
            {
                "epoch": self._get_current_epoch_for_logging(trainer),
                "iterations": current_iteration,
                "num_updates": num_updates,
                "max_updates": trainer.max_steps,
                "lr": "{:.5f}".format(optimizer.param_groups[0]["lr"]).rstrip("0"),
                "ups": "{:.2f}".format(
                    self.trainer_config.log_every_n_steps
                    / self.train_timer.unix_time_since_start()
                ),
                "time": self.train_timer.get_time_since_start(),
                "time_since_start": self.total_timer.get_time_since_start(),
                "eta": calculate_time_left(
                    max_updates=trainer.max_steps,
                    num_updates=num_updates,
                    timer=self.train_timer,
                    num_snapshot_iterations=self.snapshot_iterations,
                    log_interval=self.trainer_config.log_every_n_steps,
                    eval_interval=self.trainer_config.val_check_interval,
                ),
            }
        )
        self.train_timer.reset()
        summarize_report(
            current_iteration=current_iteration,
            num_updates=num_updates,
            max_updates=trainer.max_steps,
            meter=pl_module.train_meter,
            extra=extra,
            tb_writer=self.lightning_trainer.tb_writer,
        )
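As a rough illustration of how a callback like the one above gets wired up, the sketch below passes it to a PyTorch Lightning Trainer. It is only a sketch under assumptions: `build_trainer` is a hypothetical helper, and `lightning_trainer` is assumed to be the MMF-side wrapper referenced above (exposing trainer_config, training_config, run_type, data_module, and tb_writer).

# Illustrative only: attaching the callback above to a PyTorch Lightning Trainer.
from pytorch_lightning import Trainer


def build_trainer(lightning_trainer, max_steps):
    callback = LightningLoopCallback(lightning_trainer)
    # Reuse the intervals the callback itself reads from trainer_config so that
    # logging and validation checks stay in sync.
    return Trainer(
        max_steps=max_steps,
        log_every_n_steps=lightning_trainer.trainer_config.log_every_n_steps,
        val_check_interval=lightning_trainer.trainer_config.val_check_interval,
        callbacks=[callback],
    )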