def test_reset(self):
    """Timer.reset() should restart the elapsed-time measurement.

    Compares the elapsed milliseconds numerically instead of against the
    literal string "000ms": a handful of milliseconds may elapse between
    reset() and get_current(), which makes an exact string match flaky.
    """
    timer = Timer()
    time.sleep(2)
    timer.reset()
    # get_current() returns a string like "<n>ms"; keep only the number.
    elapsed_ms = int(timer.get_current().split("ms")[0])
    self.assertEqual(elapsed_ms, 0)
def test_reset(self):
    """Elapsed time reported by get_current() drops to zero after reset()."""
    stopwatch = Timer()
    time.sleep(2)
    stopwatch.reset()
    # get_current() yields a string such as "000ms"; strip the unit suffix.
    reported = stopwatch.get_current()
    self.assertEqual(int(reported.split("ms")[0]), 0)
def __init__(self, configuration):
    """Store the configuration and start the profiling timers.

    Args:
        configuration: Configuration object exposing ``get_config()`` and
            ``args``. Must not be None — it is dereferenced immediately.
    """
    self.configuration = configuration
    # get_config() dereferences `configuration` unconditionally, so it can
    # never be None past this line; the previous `is not None` guard around
    # the `args` assignment was dead code and has been removed.
    self.config = self.configuration.get_config()
    self.profiler = Timer()  # per-step profiling timer
    self.total_timer = Timer()  # whole-run wall-clock timer
    self.args = self.configuration.args
def test_get_time_since_start(self):
    """After sleeping ~2s, get_time_since_start() reports 2 whole seconds."""
    stopwatch = Timer()
    time.sleep(2)
    # The reported value looks like "2s ..."; keep the integer seconds part.
    seconds = int(stopwatch.get_time_since_start().split("s")[0])
    self.assertEqual(2, seconds)
def __init__(self, config, trainer):
    """
    Attr:
        config(mmf_typings.DictConfig): Config for the callback
        trainer(Type[BaseTrainer]): Trainer object
    """
    super().__init__(config, trainer)
    # Wall-clock timer covering the whole run.
    self.total_timer = Timer()
    self.log_interval = self.training_config.log_interval
    self.evaluation_interval = self.training_config.evaluation_interval
    self.checkpoint_interval = self.training_config.checkpoint_interval

    # Total iterations for snapshot
    # NOTE(review): assumes len(val_dataset) counts samples, so dividing by
    # batch_size yields the number of validation batches — confirm.
    self.snapshot_iterations = len(self.trainer.val_dataset)
    self.snapshot_iterations //= self.training_config.batch_size

    self.tb_writer = None

    if self.training_config.tensorboard:
        # Default to the run's output folder; the env-provided
        # tensorboard_logdir takes precedence when set.
        log_dir = setup_output_folder(folder_only=True)
        env_tb_logdir = get_mmf_env(key="tensorboard_logdir")
        if env_tb_logdir:
            log_dir = env_tb_logdir

        self.tb_writer = TensorboardLogger(log_dir, self.trainer.current_iteration)
def __init__(self, multi_task_instance):
    """Build a reporter over the datasets of `multi_task_instance`.

    Resolves the report output folder from, in order of precedence:
    the `report_dir` env key, otherwise `save_dir/<ckpt name>/reports`.
    """
    self.test_task = multi_task_instance
    self.task_type = multi_task_instance.dataset_type
    self.config = registry.get("config")
    self.writer = registry.get("writer")
    self.report = []
    self.timer = Timer()
    self.training_config = self.config.training
    self.num_workers = self.training_config.num_workers
    self.batch_size = self.training_config.batch_size
    self.report_folder_arg = get_mmf_env(key="report_dir")
    self.experiment_name = self.training_config.experiment_name

    self.datasets = []

    for dataset in self.test_task.get_datasets():
        self.datasets.append(dataset)

    # -1 so that the first next_dataset() call moves to index 0.
    self.current_dataset_idx = -1
    self.current_dataset = self.datasets[self.current_dataset_idx]

    self.save_dir = get_mmf_env(key="save_dir")
    self.report_folder = ckpt_name_from_core_args(self.config)
    self.report_folder += foldername_from_config_override(self.config)

    self.report_folder = os.path.join(self.save_dir, self.report_folder)
    self.report_folder = os.path.join(self.report_folder, "reports")

    # Explicit report_dir overrides the derived folder entirely.
    if self.report_folder_arg:
        self.report_folder = self.report_folder_arg

    PathManager.mkdirs(self.report_folder)
def train(self):
    """Run the main training loop.

    Falls through to inference() when run_type does not include "train".
    Exactly one of max_epochs/max_updates bounds the loop: whichever the
    config leaves unset is raised to infinity.
    """
    self.writer.write("===== Model =====")
    self.writer.write(self.model)
    print_model_parameters(self.model)

    if "train" not in self.run_type:
        self.inference()
        return

    should_break = False

    # Only one of the two limits is active; the other becomes unbounded.
    if self.max_epochs is None:
        self.max_epochs = math.inf
    else:
        self.max_updates = math.inf

    self.model.train()
    self.train_timer = Timer()
    self.snapshot_timer = Timer()

    self.profile("Setup Time")

    torch.autograd.set_detect_anomaly(True)
    self.writer.write("Starting training...")

    while self.num_updates < self.max_updates and not should_break:
        self.current_epoch += 1
        registry.register("current_epoch", self.current_epoch)

        # Seed the sampler in case if it is distributed
        self.dataset_loader.seed_sampler("train", self.current_epoch)

        if self.current_epoch > self.max_epochs:
            break

        for batch in self.train_loader:
            self.profile("Batch load time")
            self.current_iteration += 1
            self.writer.write(self.num_updates + 1, "debug")

            report = self._forward_pass(batch)
            loss = self._extract_loss(report)
            self._backward(loss)
            should_break = self._logistics(report)

            if self.num_updates > self.max_updates:
                should_break = True

            if should_break:
                break

    # In distributed, each worker will complete one epoch when we reach this
    # as each worker is an individual instance
    self.current_epoch += get_world_size() - 1
    self.finalize()
def __init__(self, log_folder="./logs", iteration=0):
    """Prepare logger state; the SummaryWriter itself is created lazily."""
    # No writer is built here — only rank information and the timestamped
    # destination folder are computed up front.
    self._summary_writer = None
    self._is_main = is_main()
    self.timer = Timer()
    self.log_folder = log_folder
    self.time_format = "%Y-%m-%dT%H:%M:%S"
    timestamp = self.timer.get_time_hhmmss(None, format=self.time_format)
    self.tensorboard_folder = os.path.join(
        self.log_folder, f"tensorboard_{timestamp}"
    )
def setup_output_folder(folder_only: bool = False):
    """Sets up and returns the output file where the logs will be placed
    based on the configuration passed. Usually "save_dir/logs/log_<timestamp>.txt".
    If env.log_dir is passed, logs will be directly saved in this folder.

    Args:
        folder_only (bool, optional): If folder should be returned and not the file.
            Defaults to False.

    Returns:
        str: folder or file path depending on folder_only flag
    """
    save_dir = get_mmf_env(key="save_dir")
    time_format = "%Y_%m_%dT%H_%M_%S"
    timestamp = Timer().get_time_hhmmss(None, format=time_format)
    log_filename = "train_" + timestamp + ".log"

    # An environment-provided log dir overrides the default save_dir/logs.
    log_folder = get_mmf_env(key="log_dir") or os.path.join(save_dir, "logs")

    if not PathManager.exists(log_folder):
        PathManager.mkdirs(log_folder)

    if folder_only:
        return log_folder

    return os.path.join(log_folder, log_filename)
def __init__(self, log_folder="./logs", iteration=0):
    """Create the tensorboard writer, but only on the master process."""
    # Importing lazily lets torch emit its missing-tensorboard warning here.
    from torch.utils.tensorboard import SummaryWriter

    self.summary_writer = None
    self._is_master = is_master()
    self.timer = Timer()
    self.log_folder = log_folder
    self.time_format = "%Y-%m-%dT%H:%M:%S"

    # Non-master ranks keep summary_writer as None.
    if not self._is_master:
        return

    stamp = self.timer.get_time_hhmmss(None, format=self.time_format)
    destination = os.path.join(self.log_folder, f"tensorboard_{stamp}")
    self.summary_writer = SummaryWriter(destination)
class TrainerProfilingMixin(ABC):
    # NOTE(review): this is a class-level Timer instance shared by all
    # subclasses/instances of the mixin. The original annotation
    # `Type[Timer]` described the class object, not the instance actually
    # assigned here, so it is corrected to `Timer`.
    profiler: Timer = Timer()

    def profile(self, text: str) -> None:
        """Log `text` plus the time elapsed since the previous profile()
        call, then restart the profiling timer.

        No-op unless training_config.logger_level == "debug".
        """
        if self.training_config.logger_level != "debug":
            return

        logging.debug(f"{text}: {self.profiler.get_time_since_start()}")
        self.profiler.reset()
def test_tensorboard_logging_parity(
    self,
    summary_writer,
    mmf,
    lightning,
    logistics,
    logistics_logs,
    report_logs,
    trainer_logs,
    mkdirs,
):
    """Verify mmf and lightning trainers emit identical tensorboard scalars.

    Runs an equivalent 8-update / batch-size-2 training loop through both
    trainers, capturing add_scalars calls in memory instead of on disk,
    then compares the two captured log streams entry by entry.
    The extra parameters are mocks injected by patch decorators outside
    this view.
    """
    # mmf trainer
    mmf_trainer = get_mmf_trainer(
        max_updates=8,
        batch_size=2,
        max_epochs=None,
        log_interval=3,
        tensorboard=True,
    )

    def _add_scalars_mmf(log_dict, iteration):
        self.mmf_tensorboard_logs.append({iteration: log_dict})

    mmf_trainer.load_metrics()
    logistics_callback = LogisticsCallback(mmf_trainer.config, mmf_trainer)
    logistics_callback.snapshot_timer = MagicMock(return_value=None)
    logistics_callback.train_timer = Timer()
    # Capture scalars in memory rather than writing tensorboard files.
    logistics_callback.tb_writer.add_scalars = _add_scalars_mmf
    mmf_trainer.logistics_callback = logistics_callback
    mmf_trainer.callbacks = [logistics_callback]
    mmf_trainer.early_stop_callback = MagicMock(return_value=None)
    mmf_trainer.on_update_end = logistics_callback.on_update_end
    mmf_trainer.training_loop()

    # lightning_trainer
    trainer = get_lightning_trainer(
        max_steps=8,
        batch_size=2,
        prepare_trainer=False,
        log_every_n_steps=3,
        val_check_interval=9,
        tensorboard=True,
    )

    def _add_scalars_lightning(log_dict, iteration):
        self.lightning_tensorboard_logs.append({iteration: log_dict})

    def _on_fit_start_callback():
        # tb_writer only exists once fit starts, so patch it here.
        trainer.tb_writer.add_scalars = _add_scalars_lightning

    callback = LightningLoopCallback(trainer)
    run_lightning_trainer_with_callback(
        trainer, callback, on_fit_start_callback=_on_fit_start_callback
    )
    self.assertEqual(
        len(self.mmf_tensorboard_logs), len(self.lightning_tensorboard_logs)
    )
    for mmf, lightning in zip(
        self.mmf_tensorboard_logs, self.lightning_tensorboard_logs
    ):
        self.assertDictEqual(mmf, lightning)
class TrainerProfilingMixin(ABC):
    # NOTE(review): shared class-level Timer instance. The annotation is
    # corrected from `Type[Timer]` (the class) to `Timer` (the instance
    # actually assigned here).
    profiler: Timer = Timer()

    def profile(self, text: str) -> None:
        """Write `text` and the time since the previous profile() call to
        the trainer's writer at debug level, then reset the timer.

        No-op unless training_config.logger_level == "debug".
        """
        if self.training_config.logger_level != "debug":
            return

        self.writer.write(text + ": " + self.profiler.get_time_since_start(), "debug")
        self.profiler.reset()
def __init__(self, multi_task_instance, test_reporter_config: TestReporterConfigType = None):
    """Build a reporter over the datasets of `multi_task_instance`.

    Args:
        multi_task_instance: Multi-task object exposing `dataset_type`
            and `get_datasets()`.
        test_reporter_config: Reporter options; a mapping is coerced into
            TestReporterConfigType.
    """
    if not isinstance(test_reporter_config, TestReporter.TestReporterConfigType):
        # NOTE(review): the declared default None would raise here
        # (**None is a TypeError); callers appear to always pass a
        # mapping — confirm.
        test_reporter_config = TestReporter.TestReporterConfigType(
            **test_reporter_config
        )
    self.test_task = multi_task_instance
    self.task_type = multi_task_instance.dataset_type
    self.config = registry.get("config")
    self.report = []
    self.timer = Timer()
    self.training_config = self.config.training
    self.num_workers = self.training_config.num_workers
    self.batch_size = self.training_config.batch_size
    self.report_folder_arg = get_mmf_env(key="report_dir")
    self.experiment_name = self.training_config.experiment_name
    self.test_reporter_config = test_reporter_config

    self.datasets = []

    for dataset in self.test_task.get_datasets():
        self.datasets.append(dataset)

    # -1 so the first next_dataset() call advances to index 0.
    self.current_dataset_idx = -1
    self.current_dataset = self.datasets[self.current_dataset_idx]

    self.save_dir = get_mmf_env(key="save_dir")
    self.report_folder = ckpt_name_from_core_args(self.config)
    self.report_folder += foldername_from_config_override(self.config)

    self.report_folder = os.path.join(self.save_dir, self.report_folder)
    self.report_folder = os.path.join(self.report_folder, "reports")

    # Explicit report_dir overrides the derived folder entirely.
    if self.report_folder_arg:
        self.report_folder = self.report_folder_arg

    # Fall back to the defaults unless the config explicitly sets fields.
    self.candidate_fields = DEFAULT_CANDIDATE_FIELDS
    if not test_reporter_config.candidate_fields == MISSING:
        self.candidate_fields = test_reporter_config.candidate_fields

    PathManager.mkdirs(self.report_folder)
def __init__(
    self,
    datamodules: List[pl.LightningDataModule],
    config: Config = None,
    dataset_type: str = "train",
):
    """Build a reporter over lightning datamodules.

    Args:
        datamodules: Datamodules keyed by dataset name.
            NOTE(review): annotated as List but used as a mapping
            (`.keys()`, keyed indexing below) — likely Dict[str, ...];
            confirm before changing the annotation.
        config: Optional overrides merged over the structured Config.
        dataset_type: Which split this reporter covers.
    """
    self.test_reporter_config = OmegaConf.merge(
        OmegaConf.structured(self.Config), config
    )
    self.datamodules = datamodules
    self.dataset_type = dataset_type
    self.config = registry.get("config")
    self.report = []
    self.timer = Timer()
    self.training_config = self.config.training
    self.num_workers = self.training_config.num_workers
    self.batch_size = self.training_config.batch_size
    self.report_folder_arg = get_mmf_env(key="report_dir")
    self.experiment_name = self.training_config.experiment_name

    # -1 so the first advance moves to the first datamodule.
    self.current_datamodule_idx = -1
    self.dataset_names = list(self.datamodules.keys())
    self.current_datamodule = self.datamodules[
        self.dataset_names[self.current_datamodule_idx]
    ]
    self.current_dataloader = None

    self.save_dir = get_mmf_env(key="save_dir")
    self.report_folder = ckpt_name_from_core_args(self.config)
    self.report_folder += foldername_from_config_override(self.config)

    self.report_folder = os.path.join(self.save_dir, self.report_folder)
    self.report_folder = os.path.join(self.report_folder, "reports")

    # Explicit report_dir overrides the derived folder entirely.
    if self.report_folder_arg:
        self.report_folder = self.report_folder_arg

    self.candidate_fields = self.test_reporter_config.candidate_fields

    PathManager.mkdirs(self.report_folder)

    log_class_usage("TestReporter", self.__class__)
def test_validation_parity(self, summarize_report_fn, test_reporter, sw, mkdirs):
    """Run 8 updates with evaluation every 3 and check the summarized
    validation reports against the stored ground truths.

    The parameters are mocks injected by patch decorators outside this
    view; `summarize_report_fn` records each validation summary call.
    """
    mmf_trainer = get_mmf_trainer(
        max_updates=8, batch_size=2, max_epochs=None, evaluation_interval=3
    )
    mmf_trainer.load_metrics()
    logistics_callback = LogisticsCallback(mmf_trainer.config, mmf_trainer)
    logistics_callback.snapshot_timer = Timer()
    logistics_callback.train_timer = Timer()
    mmf_trainer.logistics_callback = logistics_callback
    mmf_trainer.callbacks.append(logistics_callback)
    mmf_trainer.early_stop_callback = MagicMock(return_value=None)
    mmf_trainer.on_validation_end = logistics_callback.on_validation_end
    mmf_trainer.training_loop()

    calls = summarize_report_fn.call_args_list
    # 8 updates at evaluation_interval=3 → validations at 3, 6 and the
    # final forced check.
    self.assertEqual(3, len(calls))
    self.assertEqual(len(self.ground_truths), len(calls))
    self._check_values(calls)
def __init__(self, config, trainer):
    """
    Attr:
        config(mmf_typings.DictConfig): Config for the callback
        trainer(Type[BaseTrainer]): Trainer object
    """
    super().__init__(config, trainer)
    # Wall-clock timer covering the whole run.
    self.total_timer = Timer()
    self.log_interval = self.training_config.log_interval
    self.evaluation_interval = self.training_config.evaluation_interval
    self.checkpoint_interval = self.training_config.checkpoint_interval

    # Total iterations for snapshot
    # len would be number of batches per GPU == max updates
    self.snapshot_iterations = len(self.trainer.val_loader)

    self.tb_writer = None
    self.wandb_logger = None

    if self.training_config.tensorboard:
        # env tensorboard_logdir, when set, overrides the default folder.
        log_dir = setup_output_folder(folder_only=True)
        env_tb_logdir = get_mmf_env(key="tensorboard_logdir")
        if env_tb_logdir:
            log_dir = env_tb_logdir

        self.tb_writer = TensorboardLogger(log_dir, self.trainer.current_iteration)

    if self.training_config.wandb.enabled:
        # NOTE(review): log_dir is computed here (with the wandb_logdir
        # env override) but never passed to WandbLogger — confirm whether
        # it should be forwarded.
        log_dir = setup_output_folder(folder_only=True)
        env_wandb_logdir = get_mmf_env(key="wandb_logdir")
        if env_wandb_logdir:
            log_dir = env_wandb_logdir

        self.wandb_logger = WandbLogger(
            entity=config.training.wandb.entity,
            config=config,
            project=config.training.wandb.project,
        )
class TensorboardLogger:
    """Wrapper over torch's SummaryWriter that writes scalars and
    histograms from the master process only."""

    def __init__(self, log_folder="./logs", iteration=0):
        # Importing here lets torch raise its missing-tensorboard warning.
        from torch.utils.tensorboard import SummaryWriter

        self.summary_writer = None
        self._is_master = is_master()
        self.timer = Timer()
        self.log_folder = log_folder
        self.time_format = "%Y-%m-%dT%H:%M:%S"

        # Non-master ranks never create a writer.
        if not self._is_master:
            return

        stamp = self.timer.get_time_hhmmss(None, format=self.time_format)
        destination = os.path.join(self.log_folder, f"tensorboard_{stamp}")
        self.summary_writer = SummaryWriter(destination)

    def __del__(self):
        # getattr: __init__ may not have completed when __del__ fires.
        writer = getattr(self, "summary_writer", None)
        if writer is not None:
            writer.close()

    def _should_log_tensorboard(self):
        # Log only when a writer exists and we are the master process.
        return self.summary_writer is not None and self._is_master

    def add_scalar(self, key, value, iteration):
        """Record one scalar under `key` at `iteration`."""
        if self._should_log_tensorboard():
            self.summary_writer.add_scalar(key, value, iteration)

    def add_scalars(self, scalar_dict, iteration):
        """Record every scalar in `scalar_dict` at `iteration`."""
        if not self._should_log_tensorboard():
            return

        for name, scalar in scalar_dict.items():
            self.summary_writer.add_scalar(name, scalar, iteration)

    def add_histogram_for_model(self, model, iteration):
        """Record a histogram for every named parameter of `model`."""
        if not self._should_log_tensorboard():
            return

        for name, param in model.named_parameters():
            values = param.clone().cpu().data.numpy()
            self.summary_writer.add_histogram(name, values, iteration)

    def flush(self):
        """Force any buffered events to disk."""
        if self._should_log_tensorboard():
            self.summary_writer.flush()
class TensorboardLogger:
    """Tensorboard logger that creates its SummaryWriter lazily and only
    on the main (rank-zero) process; non-main ranks get a None writer and
    the `skip_if_tensorboard_inactive` decorator turns calls into no-ops.
    """

    def __init__(self, log_folder="./logs", iteration=0):
        """Compute the timestamped destination folder; no writer yet."""
        self._summary_writer = None
        self._is_main = is_main()
        self.timer = Timer()
        self.log_folder = log_folder
        self.time_format = "%Y-%m-%dT%H:%M:%S"
        current_time = self.timer.get_time_hhmmss(None, format=self.time_format)
        self.tensorboard_folder = os.path.join(
            self.log_folder, f"tensorboard_{current_time}"
        )

    @property
    def summary_writer(self):
        # Only on rank zero
        if not self._is_main:
            return None

        if self._summary_writer is None:
            # This would handle warning of missing tensorboard
            from torch.utils.tensorboard import SummaryWriter

            self._summary_writer = SummaryWriter(self.tensorboard_folder)

        return self._summary_writer

    @skip_if_tensorboard_inactive
    def close(self):
        """
        Closes the tensorboard summary writer.
        """
        self.summary_writer.close()

    @skip_if_tensorboard_inactive
    def add_scalar(self, key, value, iteration):
        """Record one scalar under `key` at `iteration`."""
        self.summary_writer.add_scalar(key, value, iteration)

    @skip_if_tensorboard_inactive
    def add_scalars(self, scalar_dict, iteration):
        """Record every scalar in `scalar_dict` at `iteration`."""
        for key, val in scalar_dict.items():
            self.summary_writer.add_scalar(key, val, iteration)

    @skip_if_tensorboard_inactive
    def add_histogram_for_model(self, model, iteration):
        """Record a histogram for every named parameter of `model`."""
        for name, param in model.named_parameters():
            np_param = param.clone().cpu().data.numpy()
            self.summary_writer.add_histogram(name, np_param, iteration)
def __init__(self, lightning_trainer: Any):
    """Cache trainer handles and start the timers used for logging."""
    super().__init__()
    trainer = lightning_trainer
    self.lightning_trainer = trainer

    # Config of the lightning trainer itself.
    self.trainer_config = trainer.trainer_config
    # Parameters configuring the training run.
    self.training_config = trainer.training_config
    self.run_type = trainer.run_type

    # Timers consulted when emitting log lines.
    self.total_timer = Timer()
    self.snapshot_timer = Timer()
    self.train_timer = Timer()
    # Number of batches in one validation snapshot.
    self.snapshot_iterations = len(self.lightning_trainer.val_loader)
class LogisticsCallback(Callback):
    """Callback for handling train/validation logistics, report
    summarization, logging etc.
    """

    def __init__(self, config, trainer):
        """
        Attr:
            config(mmf_typings.DictConfig): Config for the callback
            trainer(Type[BaseTrainer]): Trainer object
        """
        super().__init__(config, trainer)

        # Wall-clock timer covering the whole run; reported at on_test_end.
        self.total_timer = Timer()
        self.log_interval = self.training_config.log_interval
        self.evaluation_interval = self.training_config.evaluation_interval
        self.checkpoint_interval = self.training_config.checkpoint_interval

        # Total iterations for snapshot
        self.snapshot_iterations = len(self.trainer.val_dataset)
        self.snapshot_iterations //= self.training_config.batch_size

        self.tb_writer = None

        if self.training_config.tensorboard:
            # env tensorboard_logdir, when set, overrides the default folder.
            log_dir = setup_output_folder(folder_only=True)
            env_tb_logdir = get_mmf_env(key="tensorboard_logdir")
            if env_tb_logdir:
                log_dir = env_tb_logdir

            self.tb_writer = TensorboardLogger(log_dir, self.trainer.current_iteration)

    def on_train_start(self):
        """Start the per-log-interval and per-snapshot timers."""
        self.train_timer = Timer()
        self.snapshot_timer = Timer()

    def on_update_end(self, **kwargs):
        """Summarize training progress every `log_interval` updates."""
        if not kwargs["should_log"]:
            return
        extra = {}
        if "cuda" in str(self.trainer.device):
            extra["max mem"] = torch.cuda.max_memory_allocated() / 1024
            extra["max mem"] //= 1024

        if self.training_config.experiment_name:
            extra["experiment"] = self.training_config.experiment_name

        extra.update(
            {
                "epoch": self.trainer.current_epoch,
                "num_updates": self.trainer.num_updates,
                "iterations": self.trainer.current_iteration,
                "max_updates": self.trainer.max_updates,
                "lr": "{:.5f}".format(
                    self.trainer.optimizer.param_groups[0]["lr"]
                ).rstrip("0"),
                "ups": "{:.2f}".format(
                    self.log_interval / self.train_timer.unix_time_since_start()
                ),
                "time": self.train_timer.get_time_since_start(),
                "time_since_start": self.total_timer.get_time_since_start(),
                "eta": self._calculate_time_left(),
            }
        )
        self.train_timer.reset()
        self._summarize_report(kwargs["meter"], extra=extra)

    def on_validation_start(self, **kwargs):
        """Reset snapshot timer so val_time measures this snapshot only."""
        self.snapshot_timer.reset()

    def on_validation_end(self, **kwargs):
        """Summarize validation results plus early-stopping info."""
        extra = {
            "num_updates": self.trainer.num_updates,
            "epoch": self.trainer.current_epoch,
            "iterations": self.trainer.current_iteration,
            "max_updates": self.trainer.max_updates,
            "val_time": self.snapshot_timer.get_time_since_start(),
        }
        extra.update(self.trainer.early_stop_callback.early_stopping.get_info())
        self.train_timer.reset()
        self._summarize_report(kwargs["meter"], extra=extra)

    def on_test_end(self, **kwargs):
        """Summarize the full test run and log the total run time."""
        prefix = "{}: full {}".format(
            kwargs["report"].dataset_name, kwargs["report"].dataset_type
        )
        # BUGFIX: `prefix` used to be passed positionally into the
        # `should_print` parameter of _summarize_report, so the label was
        # silently dropped; pass it by keyword so it reaches the log.
        self._summarize_report(kwargs["meter"], prefix=prefix)
        logger.info(f"Finished run in {self.total_timer.get_time_since_start()}")

    def _summarize_report(self, meter, should_print=True, extra=None, prefix=""):
        """Push meter scalars to tensorboard and, optionally, to the log.

        Args:
            meter: Meter holding the scalars to report.
            should_print: When False, only tensorboard is updated.
            extra: Additional key/value pairs appended to the log line.
            prefix: Optional label (e.g. dataset name) included in the log.
        """
        if extra is None:
            extra = {}
        if not is_master() and not is_xla():
            return

        if self.training_config.tensorboard:
            scalar_dict = meter.get_scalar_dict()
            self.tb_writer.add_scalars(scalar_dict, self.trainer.current_iteration)

        if not should_print:
            return
        log_dict = {}
        if prefix:
            log_dict["prefix"] = prefix
        if hasattr(self.trainer, "num_updates") and hasattr(
            self.trainer, "max_updates"
        ):
            log_dict.update(
                {"progress": f"{self.trainer.num_updates}/{self.trainer.max_updates}"}
            )
        log_dict.update(meter.get_log_dict())
        log_dict.update(extra)

        log_progress(log_dict)

    def _calculate_time_left(self):
        """Estimate remaining wall-clock time from per-interval timings."""
        time_taken_for_log = time.time() * 1000 - self.train_timer.start
        iterations_left = self.trainer.max_updates - self.trainer.num_updates
        num_logs_left = iterations_left / self.log_interval
        time_left = num_logs_left * time_taken_for_log

        snapshot_iteration = self.snapshot_iterations / self.log_interval
        snapshot_iteration *= iterations_left / self.evaluation_interval
        time_left += snapshot_iteration * time_taken_for_log

        return self.train_timer.get_time_hhmmss(gap=time_left)
class BaseTrainer:
    """Base training harness: wires device, datasets, model, optimizer and
    runs the train / validation / inference loops."""

    def __init__(self, configuration):
        # NOTE(review): `configuration` is dereferenced by get_config()
        # before the `is not None` guard below, so that guard can never be
        # False when reached — it is dead code.
        self.configuration = configuration
        self.config = self.configuration.get_config()
        self.profiler = Timer()  # per-step profiling timer
        self.total_timer = Timer()  # whole-run timer
        if self.configuration is not None:
            self.args = self.configuration.args

    def load(self):
        """Set up device, writer, datasets, model/optimizer and metrics."""
        self._set_device()

        self.run_type = self.config.get("run_type", "train")
        self.dataset_loader = DatasetLoader(self.config)
        self._datasets = self.config.datasets

        # Check if loader is already defined, else init it
        writer = registry.get("writer", no_warning=True)
        if writer:
            self.writer = writer
        else:
            self.writer = Logger(self.config)
            registry.register("writer", self.writer)

        self.configuration.pretty_print()

        self.config_based_setup()

        self.load_datasets()
        self.load_model_and_optimizer()
        self.load_metrics()

    def _set_device(self):
        """Choose cuda/cpu device, register it, and flag distributed mode."""
        self.local_rank = self.config.device_id
        self.device = self.local_rank
        self.distributed = False

        # Will be updated later based on distributed setup
        registry.register("global_device", self.device)

        if self.config.distributed.init_method is not None:
            self.distributed = True
            self.device = torch.device("cuda", self.local_rank)
        elif torch.cuda.is_available():
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")

        registry.register("current_device", self.device)

    def load_datasets(self):
        """Load train/val/test datasets and their dataloaders."""
        self.writer.write("Loading datasets", "info")
        self.dataset_loader.load_datasets()

        self.train_dataset = self.dataset_loader.train_dataset
        self.val_dataset = self.dataset_loader.val_dataset

        # Total iterations for snapshot
        self.snapshot_iterations = len(self.val_dataset)
        self.snapshot_iterations //= self.config.training.batch_size

        self.test_dataset = self.dataset_loader.test_dataset

        self.train_loader = self.dataset_loader.train_loader
        self.val_loader = self.dataset_loader.val_loader
        self.test_loader = self.dataset_loader.test_loader

    def load_metrics(self):
        """Build the Metrics object and cache its required params."""
        metrics = self.config.evaluation.get("metrics", [])
        self.metrics = Metrics(metrics)
        self.metrics_params = self.metrics.required_params

    def load_model_and_optimizer(self):
        """Build the model, move it to device, build optimizer and extras."""
        attributes = self.config.model_config[self.config.model]
        # Easy way to point to config for other model
        if isinstance(attributes, str):
            attributes = self.config.model_config[attributes]

        with omegaconf.open_dict(attributes):
            attributes.model = self.config.model

        self.model = build_model(attributes)

        if "cuda" in str(self.device):
            device_info = "CUDA Device {} is: {}".format(
                self.config.distributed.rank,
                torch.cuda.get_device_name(self.local_rank),
            )
            registry.register("global_device", self.config.distributed.rank)
            self.writer.write(device_info, log_all=True)

        self.model = self.model.to(self.device)
        self.optimizer = build_optimizer(self.model, self.config)

        registry.register("data_parallel", False)
        registry.register("distributed", False)

        self.load_extras()
        self.parallelize_model()

    def parallelize_model(self):
        """Wrap model in DataParallel / DistributedDataParallel as needed."""
        training = self.config.training
        if (
            "cuda" in str(self.device)
            and torch.cuda.device_count() > 1
            and not self.distributed
        ):
            registry.register("data_parallel", True)
            self.model = torch.nn.DataParallel(self.model)

        if "cuda" in str(self.device) and self.distributed:
            registry.register("distributed", True)
            self.model = torch.nn.parallel.DistributedDataParallel(
                self.model,
                device_ids=[self.local_rank],
                output_device=self.local_rank,
                check_reduction=True,
                find_unused_parameters=training.find_unused_parameters,
            )

    def load_extras(self):
        """Checkpointing, meter, early stopping, LR scheduler, tensorboard."""
        self.writer.write("Torch version is: " + torch.__version__)
        self.checkpoint = Checkpoint(self)
        self.meter = Meter()

        self.training_config = self.config.training

        early_stop_criteria = self.training_config.early_stop.criteria
        early_stop_minimize = self.training_config.early_stop.minimize
        early_stop_enabled = self.training_config.early_stop.enabled
        early_stop_patience = self.training_config.early_stop.patience

        self.log_interval = self.training_config.log_interval
        self.evaluation_interval = self.training_config.evaluation_interval
        self.checkpoint_interval = self.training_config.checkpoint_interval
        self.max_updates = self.training_config.max_updates
        self.should_clip_gradients = self.training_config.clip_gradients
        self.max_epochs = self.training_config.max_epochs

        self.early_stopping = EarlyStopping(
            self.model,
            self.checkpoint,
            early_stop_criteria,
            patience=early_stop_patience,
            minimize=early_stop_minimize,
            should_stop=early_stop_enabled,
        )
        self.current_epoch = 0
        self.current_iteration = 0
        self.num_updates = 0

        self.checkpoint.load_state_dict()

        self.not_debug = self.training_config.logger_level != "debug"

        self.lr_scheduler = None

        if self.training_config.lr_scheduler is True:
            self.lr_scheduler = build_scheduler(self.optimizer, self.config)

        self.tb_writer = None

        if self.training_config.tensorboard:
            # env tensorboard_logdir, when set, overrides the writer's dir.
            log_dir = self.writer.log_dir
            env_tb_logdir = get_mmf_env(key="tensorboard_logdir")
            if env_tb_logdir:
                log_dir = env_tb_logdir

            self.tb_writer = TensorboardLogger(log_dir, self.current_iteration)

    def config_based_setup(self):
        """Apply cudnn determinism knobs when a seed is configured."""
        seed = self.config.training.seed
        if seed is None:
            return

        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    def train(self):
        """Main training loop; falls back to inference() when run_type has
        no "train". Exactly one of max_epochs/max_updates bounds the loop."""
        self.writer.write("===== Model =====")
        self.writer.write(self.model)
        print_model_parameters(self.model)

        if "train" not in self.run_type:
            self.inference()
            return

        should_break = False

        if self.max_epochs is None:
            self.max_epochs = math.inf
        else:
            self.max_updates = math.inf

        self.model.train()
        self.train_timer = Timer()
        self.snapshot_timer = Timer()

        self.profile("Setup Time")

        torch.autograd.set_detect_anomaly(True)
        self.writer.write("Starting training...")

        while self.num_updates < self.max_updates and not should_break:
            self.current_epoch += 1
            registry.register("current_epoch", self.current_epoch)

            # Seed the sampler in case if it is distributed
            self.dataset_loader.seed_sampler("train", self.current_epoch)

            if self.current_epoch > self.max_epochs:
                break

            for batch in self.train_loader:
                self.profile("Batch load time")
                self.current_iteration += 1
                self.writer.write(self.num_updates + 1, "debug")

                report = self._forward_pass(batch)
                loss = self._extract_loss(report)
                self._backward(loss)
                should_break = self._logistics(report)

                if self.num_updates > self.max_updates:
                    should_break = True

                if should_break:
                    break

        # In distributed, each worker will complete one epoch when we reach this
        # as each worker is an individual instance
        self.current_epoch += get_world_size() - 1
        self.finalize()

    def _run_scheduler(self):
        """Step the LR scheduler (if any) to the current update count."""
        if self.lr_scheduler is not None:
            self.lr_scheduler.step(self.num_updates)

    def _forward_pass(self, batch):
        """Prepare `batch`, run the model and wrap the output in a Report."""
        prepared_batch = self.dataset_loader.prepare_batch(batch)
        self.profile("Batch prepare time")
        # Arguments should be a dict at this point
        model_output = self.model(prepared_batch)
        report = Report(prepared_batch, model_output)
        self.profile("Forward time")
        return report

    def _backward(self, loss):
        """Backprop `loss`, optionally clip gradients, and step optimizer."""
        self.optimizer.zero_grad()
        loss.backward()

        if self.should_clip_gradients:
            clip_gradients(self.model, self.num_updates, self.tb_writer, self.config)

        self.optimizer.step()
        self._run_scheduler()
        self.num_updates += 1
        self.profile("Backward time")

    def _extract_loss(self, report):
        """Sum the mean of every loss in the report into a single scalar."""
        loss_dict = report.losses
        loss = sum([loss.mean() for loss in loss_dict.values()])
        return loss

    def finalize(self):
        """Final validation, checkpoint restore/finalize, then inference."""
        self.writer.write("Stepping into final validation check")

        # Only do when run_type has train as it shouldn't happen on validation and
        # inference runs. Inference will take care of this anyways. Also, don't run
        # if current iteration is divisible by snapshot interval as it will just
        # be a repeat
        if (
            "train" in self.run_type
            and self.num_updates % self.evaluation_interval != 0
        ):
            self._try_full_validation(force=True)

        self.checkpoint.restore()
        self.checkpoint.finalize()
        self.inference()
        self.writer.write(f"Finished run in {self.total_timer.get_time_since_start()}")

    def _update_meter(self, report, meter=None, eval_mode=False):
        """Reduce the report's losses/metrics and push them into `meter`."""
        if meter is None:
            meter = self.meter

        if hasattr(report, "metrics"):
            metrics_dict = report.metrics
            reduced_metrics_dict = reduce_dict(metrics_dict)

        if not eval_mode:
            loss_dict = report.losses
            reduced_loss_dict = reduce_dict(loss_dict)

        with torch.no_grad():
            # Add metrics to meter only when mode is `eval`
            meter_update_dict = {}
            if not eval_mode:
                loss_key = report.dataset_type + "/total_loss"
                reduced_loss = sum(
                    [loss.mean() for loss in reduced_loss_dict.values()]
                )
                if hasattr(reduced_loss, "item"):
                    reduced_loss = reduced_loss.item()

                registry.register(loss_key, reduced_loss)
                meter_update_dict.update({loss_key: reduced_loss})
                meter_update_dict.update(reduced_loss_dict)
            if hasattr(report, "metrics"):
                meter_update_dict.update(reduced_metrics_dict)
            meter.update(meter_update_dict, report.batch_size)

    def _logistics(self, report):
        """Per-update bookkeeping: logging, snapshotting, validation.

        Returns True when training should stop (early stopping).
        """
        registry.register("current_iteration", self.current_iteration)
        registry.register("num_updates", self.num_updates)

        should_print = self.num_updates % self.log_interval == 0
        should_break = False
        extra = {}

        if should_print is True:
            if "cuda" in str(self.device):
                extra["max mem"] = torch.cuda.max_memory_allocated() / 1024
                extra["max mem"] //= 1024

            if self.training_config.experiment_name:
                extra["experiment"] = self.training_config.experiment_name

            extra.update(
                {
                    "epoch": self.current_epoch,
                    "num_updates": self.num_updates,
                    "iterations": self.current_iteration,
                    "max_updates": self.max_updates,
                    "lr": "{:.5f}".format(
                        self.optimizer.param_groups[0]["lr"]
                    ).rstrip("0"),
                    "ups": "{:.2f}".format(
                        self.log_interval / self.train_timer.unix_time_since_start()
                    ),
                    "time": self.train_timer.get_time_since_start(),
                    "time_since_start": self.total_timer.get_time_since_start(),
                    "eta": self._calculate_time_left(),
                }
            )

            self.train_timer.reset()

            # Calculate metrics every log interval for debugging
            if self.training_config.evaluate_metrics:
                report.metrics = self.metrics(report, report)
            self._update_meter(report, self.meter)

        self._summarize_report(self.meter, should_print=should_print, extra=extra)
        self._try_snapshot()
        should_break = self._try_full_validation()

        return should_break

    def _try_snapshot(self):
        """Save a (non-best) checkpoint every checkpoint_interval updates."""
        if self.num_updates % self.checkpoint_interval == 0:
            self.writer.write("Checkpoint time. Saving a checkpoint.")
            self.checkpoint.save(
                self.num_updates, self.current_iteration, update_best=False
            )

    def _try_full_validation(self, force=False):
        """Run full validation + early stopping at evaluation_interval.

        Returns True when early stopping fired.
        """
        should_break = False

        if self.num_updates % self.evaluation_interval == 0 or force:
            self.snapshot_timer.reset()
            self.writer.write("Evaluation time. Running on full validation set...")
            # Validation and Early stopping
            # Create a new meter for this case
            report, meter = self.evaluate(self.val_loader)

            extra = {
                "num_updates": self.num_updates,
                "epoch": self.current_epoch,
                "iterations": self.current_iteration,
                "max_updates": self.max_updates,
                "val_time": self.snapshot_timer.get_time_since_start(),
            }

            # Broadcast so every worker agrees on whether to stop.
            stop = self.early_stopping(self.num_updates, self.current_iteration, meter)
            stop = bool(broadcast_scalar(stop, src=0, device=self.device))

            extra.update(self.early_stopping.get_info())
            self._summarize_report(meter, extra=extra)
            gc.collect()

            if "cuda" in str(self.device):
                torch.cuda.empty_cache()

            if stop is True:
                self.writer.write("Early stopping activated")
                should_break = True

            self.train_timer.reset()

        return should_break

    def evaluate(self, loader, use_tqdm=False, single_batch=False):
        """Evaluate the model over `loader` without gradients.

        Returns (combined_report, meter); the model is restored to train
        mode before returning.
        """
        meter = Meter()

        with torch.no_grad():
            self.model.eval()
            disable_tqdm = not use_tqdm or not is_master()
            combined_report = None

            for batch in tqdm(loader, disable=disable_tqdm):
                report = self._forward_pass(batch)
                self._update_meter(report, meter)

                # accumulate necessary params for metric calculation
                if combined_report is None:
                    combined_report = report
                else:
                    combined_report.accumulate_tensor_fields(
                        report, self.metrics.required_params
                    )
                    combined_report.batch_size += report.batch_size

                if single_batch is True:
                    break

            combined_report.metrics = self.metrics(combined_report, combined_report)
            self._update_meter(combined_report, meter, eval_mode=True)

            self.model.train()

        return combined_report, meter

    def _summarize_report(self, meter, should_print=True, extra=None):
        """Push meter scalars to tensorboard and, optionally, to the log."""
        if extra is None:
            extra = {}
        if not is_master():
            return

        if self.training_config.tensorboard:
            scalar_dict = meter.get_scalar_dict()
            self.tb_writer.add_scalars(scalar_dict, self.current_iteration)

        if not should_print:
            return
        log_dict = {"progress": f"{self.num_updates}/{self.max_updates}"}
        log_dict.update(meter.get_log_dict())
        log_dict.update(extra)

        self.writer.log_progress(log_dict)

    def inference(self):
        """Run inference on val and/or test splits per run_type."""
        if "val" in self.run_type:
            self._inference_run("val")

        if any(rt in self.run_type for rt in ["inference", "test", "predict"]):
            self._inference_run("test")

    def _inference_run(self, dataset_type):
        """Evaluate (or predict) over the loader for `dataset_type`."""
        if self.config.evaluation.predict:
            self.predict(dataset_type)
            return

        self.writer.write(f"Starting inference on {dataset_type} set")

        report, meter = self.evaluate(
            getattr(self, f"{dataset_type}_loader"), use_tqdm=True
        )
        prefix = f"{report.dataset_name}: full {dataset_type}"
        # NOTE(review): `prefix` is passed positionally into the
        # `should_print` parameter of _summarize_report, so the label is
        # never actually displayed — confirm and consider a dedicated
        # prefix argument.
        self._summarize_report(meter, prefix)

    def _calculate_time_left(self):
        """Estimate remaining wall-clock time from per-interval timings."""
        time_taken_for_log = time.time() * 1000 - self.train_timer.start
        iterations_left = self.max_updates - self.num_updates
        num_logs_left = iterations_left / self.log_interval
        time_left = num_logs_left * time_taken_for_log

        snapshot_iteration = self.snapshot_iterations / self.log_interval
        snapshot_iteration *= iterations_left / self.evaluation_interval
        time_left += snapshot_iteration * time_taken_for_log

        return self.train_timer.get_time_hhmmss(gap=time_left)

    def profile(self, text):
        """Debug-level profiling log; no-op unless logger_level=="debug"."""
        if self.not_debug:
            return
        self.writer.write(text + ": " + self.profiler.get_time_since_start(), "debug")
        self.profiler.reset()

    def predict(self, dataset_type):
        """Run prediction over `dataset_type` via its test reporter."""
        reporter = self.dataset_loader.get_test_reporter(dataset_type)
        with torch.no_grad():
            self.model.eval()
            message = f"Starting {dataset_type} inference predictions"
            self.writer.write(message)

            while reporter.next_dataset():
                dataloader = reporter.get_dataloader()

                for batch in tqdm(dataloader):
                    prepared_batch = reporter.prepare_batch(batch)
                    model_output = self.model(prepared_batch)
                    report = Report(prepared_batch, model_output)
                    reporter.add_to_report(report, self.model)

            self.writer.write("Finished predicting")
            self.model.train()
def __init__(self, config): self.config = config self.profiler = Timer() self.total_timer = Timer()
def __init__(self, config, name=None): self.logger = None self._is_master = is_master() self.timer = Timer() self.config = config self.save_dir = get_mmf_env(key="save_dir") self.log_format = config.training.log_format self.time_format = "%Y-%m-%dT%H:%M:%S" self.log_filename = "train_" self.log_filename += self.timer.get_time_hhmmss(None, format=self.time_format) self.log_filename += ".log" self.log_folder = os.path.join(self.save_dir, "logs") env_log_dir = get_mmf_env(key="log_dir") if env_log_dir: self.log_folder = env_log_dir if not PathManager.exists(self.log_folder): PathManager.mkdirs(self.log_folder) self.log_filename = os.path.join(self.log_folder, self.log_filename) if not self._is_master: return if self._is_master: print("Logging to:", self.log_filename) logging.captureWarnings(True) if not name: name = __name__ self.logger = logging.getLogger(name) self._file_only_logger = logging.getLogger(name) warnings_logger = logging.getLogger("py.warnings") # Set level level = config.training.logger_level self.logger.setLevel(getattr(logging, level.upper())) self._file_only_logger.setLevel(getattr(logging, level.upper())) formatter = logging.Formatter( "%(asctime)s %(levelname)s: %(message)s", datefmt="%Y-%m-%dT%H:%M:%S" ) # Add handler to file channel = logging.FileHandler(filename=self.log_filename, mode="a") channel.setFormatter(formatter) self.logger.addHandler(channel) self._file_only_logger.addHandler(channel) warnings_logger.addHandler(channel) # Add handler to stdout channel = logging.StreamHandler(sys.stdout) channel.setFormatter(formatter) self.logger.addHandler(channel) warnings_logger.addHandler(channel) should_not_log = self.config.training.should_not_log self.should_log = not should_not_log # Single log wrapper map self._single_log_map = set()
class TestReporter(Dataset):
    """Collects model predictions across one or more Lightning datamodules
    and flushes them to csv/json report files under `report_folder`.

    Iterate with `next_dataset()` / `get_dataloader()`; accumulate with
    `add_to_report()`; files are written (master process only) by
    `flush_report()`.
    """

    @dataclass
    class Config:
        # A set of fields to be *considered* for exporting by the reporter
        # Note that `format_for_prediction` is what ultimately determines the
        # exported fields
        candidate_fields: List[str] = field(
            default_factory=lambda: DEFAULT_CANDIDATE_FIELDS)
        # csv or json
        predict_file_format: str = "json"

    def __init__(
        self,
        datamodules: List[pl.LightningDataModule],
        config: Config = None,
        dataset_type: str = "train",
    ):
        """Merge the structured default Config with `config` and resolve the
        report output folder (env `report_dir` overrides save_dir-derived path).
        """
        self.test_reporter_config = OmegaConf.merge(
            OmegaConf.structured(self.Config), config)
        self.datamodules = datamodules
        self.dataset_type = dataset_type
        self.config = registry.get("config")
        self.report = []
        self.timer = Timer()
        self.training_config = self.config.training
        self.num_workers = self.training_config.num_workers
        self.batch_size = self.training_config.batch_size
        self.report_folder_arg = get_mmf_env(key="report_dir")
        self.experiment_name = self.training_config.experiment_name
        # -1 so the first next_dataset() call advances to index 0
        self.current_datamodule_idx = -1
        self.dataset_names = list(self.datamodules.keys())
        self.current_datamodule = self.datamodules[self.dataset_names[
            self.current_datamodule_idx]]
        self.current_dataloader = None
        self.save_dir = get_mmf_env(key="save_dir")
        self.report_folder = ckpt_name_from_core_args(self.config)
        self.report_folder += foldername_from_config_override(self.config)
        self.report_folder = os.path.join(self.save_dir, self.report_folder)
        self.report_folder = os.path.join(self.report_folder, "reports")
        # explicit report_dir argument wins over the derived path
        if self.report_folder_arg:
            self.report_folder = self.report_folder_arg
        self.candidate_fields = self.test_reporter_config.candidate_fields
        PathManager.mkdirs(self.report_folder)

    @property
    def current_dataset(self):
        """Dataset backing the current dataloader (requires get_dataloader)."""
        self._check_current_dataloader()
        return self.current_dataloader.dataset

    def next_dataset(self, flush_report=True):
        """Advance to the next datamodule; returns False when exhausted.

        Flushes (or discards, if flush_report is False) the accumulated
        report of the previous datamodule first.
        """
        if self.current_datamodule_idx >= 0:
            if flush_report:
                self.flush_report()
            else:
                self.report = []
        self.current_datamodule_idx += 1
        if self.current_datamodule_idx == len(self.datamodules):
            return False
        else:
            self.current_datamodule = self.datamodules[self.dataset_names[
                self.current_datamodule_idx]]
            logger.info(
                f"Predicting for {self.dataset_names[self.current_datamodule_idx]}"
            )
            return True

    def flush_report(self):
        """Write accumulated predictions to a timestamped csv/json file.

        Master process only; clears `self.report` afterwards.
        """
        if not is_master():
            return
        name = self.current_datamodule.dataset_name
        time_format = "%Y-%m-%dT%H:%M:%S"
        time = self.timer.get_time_hhmmss(None, format=time_format)
        filename = name + "_"
        if len(self.experiment_name) > 0:
            filename += self.experiment_name + "_"
        filename += self.dataset_type + "_"
        filename += time
        # csv wins if either the evaluation config or the reporter config asks for it
        use_csv_writer = (self.config.evaluation.predict_file_format == "csv"
                          or self.test_reporter_config.predict_file_format == "csv")
        if use_csv_writer:
            filepath = os.path.join(self.report_folder, filename + ".csv")
            self.csv_dump(filepath)
        else:
            filepath = os.path.join(self.report_folder, filename + ".json")
            self.json_dump(filepath)
        logger.info(
            f"Wrote predictions for {name} to {os.path.abspath(filepath)}")
        self.report = []

    def postprocess_dataset_report(self):
        """Give the dataset a chance to post-process the report, if it can."""
        self._check_current_dataloader()
        if hasattr(self.current_dataset, "on_prediction_end"):
            self.report = self.current_dataset.on_prediction_end(self.report)

    def csv_dump(self, filepath):
        """Write the report as csv; header comes from the first row's keys."""
        with PathManager.open(filepath, "w") as f:
            title = self.report[0].keys()
            cw = csv.DictWriter(f, title, delimiter=",", quoting=csv.QUOTE_MINIMAL)
            cw.writeheader()
            cw.writerows(self.report)

    def json_dump(self, filepath):
        """Write the report as a single json document."""
        with PathManager.open(filepath, "w") as f:
            json.dump(self.report, f)

    def get_dataloader(self):
        """Fetch the `<dataset_type>_dataloader` from the current datamodule."""
        self.current_dataloader = getattr(self.current_datamodule,
                                          f"{self.dataset_type}_dataloader")()
        # Make sure to assign dataset to dataloader object as
        # required by MMF
        if not hasattr(self.current_dataloader, "dataset"):
            self.current_dataloader.dataset = getattr(
                self.current_datamodule, f"{self.dataset_type}_dataset")
        return self.current_dataloader

    def prepare_batch(self, batch):
        """Convert a raw batch into a SampleList tagged with dataset info."""
        self._check_current_dataloader()
        if hasattr(self.current_dataset, "prepare_batch"):
            batch = self.current_dataset.prepare_batch(batch)
        batch = convert_batch_to_sample_list(batch)
        batch.dataset_name = self.current_dataset.dataset_name
        batch.dataset_type = self.dataset_type
        return batch

    def __len__(self):
        # Length in batches (dataloader), not in samples.
        self._check_current_dataloader()
        return len(self.current_dataloader)

    def _check_current_dataloader(self):
        assert self.current_dataloader is not None, (
            "Please call `get_dataloader` before accessing any "
            + "'current_dataloader' based function")

    def add_to_report(self, report, model, *args, **kwargs):
        """Gather candidate fields across processes and append formatted
        prediction rows (dataset first, then model overrides) to the report.
        """
        if "execute_on_master_only" in kwargs:
            warnings.warn(
                "'execute_on_master_only keyword is deprecated and isn't used anymore",
                DeprecationWarning,
            )
        self._check_current_dataloader()
        for key in self.candidate_fields:
            report = self.reshape_and_gather(report, key)
        results = []
        if hasattr(self.current_dataset, "format_for_prediction"):
            results = self.current_dataset.format_for_prediction(report)
        # model (or, for wrapped models, model.module) may reformat the results
        if hasattr(model, "format_for_prediction"):
            results = model.format_for_prediction(results, report)
        elif hasattr(model.module, "format_for_prediction"):
            results = model.module.format_for_prediction(results, report)
        self.report = self.report + results

    def reshape_and_gather(self, report, key):
        """All-gather `report[key]` across processes, flattening the batch dim."""
        if key in report:
            num_dims = report[key].dim()
            if num_dims == 1:
                report[key] = gather_tensor(report[key]).view(-1)
            elif num_dims >= 2:
                # Collect dims other than batch
                other_dims = report[key].size()[1:]
                report[key] = gather_tensor(report[key]).view(-1, *other_dims)
        return report
def test_get_current(self): timer = Timer() expected = 0 self.assertEqual(int(timer.get_current().split("ms")[0]), expected)
class TestReporter(Dataset):
    """Prediction reporter over a multi-task instance's datasets.

    Cycles datasets via `next_dataset()`, accumulates gathered/formatted
    rows via `add_to_report()`, and writes csv/json files (master only)
    in `flush_report()`.
    """

    def __init__(self, multi_task_instance):
        self.test_task = multi_task_instance
        self.task_type = multi_task_instance.dataset_type
        self.config = registry.get("config")
        self.report = []
        self.timer = Timer()
        self.training_config = self.config.training
        self.num_workers = self.training_config.num_workers
        self.batch_size = self.training_config.batch_size
        self.report_folder_arg = get_mmf_env(key="report_dir")
        self.experiment_name = self.training_config.experiment_name
        self.datasets = []
        for dataset in self.test_task.get_datasets():
            self.datasets.append(dataset)
        # -1 so the first next_dataset() call moves to index 0
        self.current_dataset_idx = -1
        self.current_dataset = self.datasets[self.current_dataset_idx]
        self.save_dir = get_mmf_env(key="save_dir")
        self.report_folder = ckpt_name_from_core_args(self.config)
        self.report_folder += foldername_from_config_override(self.config)
        self.report_folder = os.path.join(self.save_dir, self.report_folder)
        self.report_folder = os.path.join(self.report_folder, "reports")
        # explicit report_dir argument wins over the derived path
        if self.report_folder_arg:
            self.report_folder = self.report_folder_arg
        PathManager.mkdirs(self.report_folder)

    def next_dataset(self):
        """Flush the previous dataset's report and advance to the next one.

        Returns:
            bool: False when all datasets have been consumed.
        """
        if self.current_dataset_idx >= 0:
            self.flush_report()
        self.current_dataset_idx += 1
        if self.current_dataset_idx == len(self.datasets):
            return False
        else:
            self.current_dataset = self.datasets[self.current_dataset_idx]
            logger.info(f"Predicting for {self.current_dataset.dataset_name}")
            return True

    def flush_report(self):
        """Write accumulated predictions to a timestamped csv/json file.

        Master process only; clears `self.report` afterwards.
        """
        if not is_master():
            return
        name = self.current_dataset.dataset_name
        time_format = "%Y-%m-%dT%H:%M:%S"
        time = self.timer.get_time_hhmmss(None, format=time_format)
        filename = name + "_"
        if len(self.experiment_name) > 0:
            filename += self.experiment_name + "_"
        filename += self.task_type + "_"
        filename += time
        if self.config.evaluation.predict_file_format == "csv":
            filepath = os.path.join(self.report_folder, filename + ".csv")
            self.csv_dump(filepath)
        else:
            filepath = os.path.join(self.report_folder, filename + ".json")
            self.json_dump(filepath)
        logger.info(
            f"Wrote evalai predictions for {name} to {os.path.abspath(filepath)}"
        )
        self.report = []

    def csv_dump(self, filepath):
        """Write the report as csv; header comes from the first row's keys."""
        with PathManager.open(filepath, "w") as f:
            title = self.report[0].keys()
            cw = csv.DictWriter(f, title, delimiter=",", quoting=csv.QUOTE_MINIMAL)
            cw.writeheader()
            cw.writerows(self.report)

    def json_dump(self, filepath):
        """Write the report as a single json document."""
        with PathManager.open(filepath, "w") as f:
            json.dump(self.report, f)

    def get_dataloader(self):
        """Build a DataLoader over the current dataset with MMF's collator."""
        other_args = self._add_extra_args_for_dataloader()
        return DataLoader(
            dataset=self.current_dataset,
            collate_fn=BatchCollator(
                self.current_dataset.dataset_name, self.current_dataset.dataset_type
            ),
            num_workers=self.num_workers,
            pin_memory=self.config.training.pin_memory,
            **other_args,
        )

    def _add_extra_args_for_dataloader(self, other_args=None):
        """Add sampler/shuffle/batch_size args; distributed runs get an
        unshuffled DistributedSampler so every sample is predicted once.
        """
        if other_args is None:
            other_args = {}
        if is_dist_initialized():
            other_args["sampler"] = DistributedSampler(
                self.current_dataset, shuffle=False
            )
        else:
            other_args["shuffle"] = False
        other_args["batch_size"] = get_batch_size()
        return other_args

    def prepare_batch(self, batch):
        # Delegate batch preparation to the current dataset.
        return self.current_dataset.prepare_batch(batch)

    def __len__(self):
        return len(self.current_dataset)

    def __getitem__(self, idx):
        return self.current_dataset[idx]

    def add_to_report(self, report, model):
        """Gather well-known fields across processes, then (master only)
        append dataset/model-formatted prediction rows to the report.
        """
        keys = ["id", "question_id", "image_id", "context_tokens", "captions", "scores"]
        for key in keys:
            report = self.reshape_and_gather(report, key)
        if not is_master():
            return
        results = self.current_dataset.format_for_prediction(report)
        # model (or, for wrapped models, model.module) may reformat the results
        if hasattr(model, "format_for_prediction"):
            results = model.format_for_prediction(results, report)
        elif hasattr(model.module, "format_for_prediction"):
            results = model.module.format_for_prediction(results, report)
        self.report = self.report + results

    def reshape_and_gather(self, report, key):
        """All-gather `report[key]` across processes, flattening the batch dim."""
        if key in report:
            num_dims = report[key].dim()
            if num_dims == 1:
                report[key] = gather_tensor(report[key]).view(-1)
            elif num_dims >= 2:
                # Collect dims other than batch
                other_dims = report[key].size()[1:]
                report[key] = gather_tensor(report[key]).view(-1, *other_dims)
        return report
class TestReporter(Dataset):
    """Older evalai prediction reporter (json-only, writer-based logging).

    Same shape as the newer reporters: iterate datasets with
    `next_dataset()`, accumulate with `add_to_report()`, write files in
    `flush_report()` (master only).
    """

    def __init__(self, multi_task_instance):
        self.test_task = multi_task_instance
        self.task_type = multi_task_instance.dataset_type
        self.config = registry.get("config")
        self.writer = registry.get("writer")
        self.report = []
        self.timer = Timer()
        self.training_config = self.config.training
        self.num_workers = self.training_config.num_workers
        self.batch_size = self.training_config.batch_size
        self.report_folder_arg = get_mmf_env(key="report_dir")
        self.experiment_name = self.training_config.experiment_name
        self.datasets = []
        for dataset in self.test_task.get_datasets():
            self.datasets.append(dataset)
        # -1 so the first next_dataset() call moves to index 0
        self.current_dataset_idx = -1
        self.current_dataset = self.datasets[self.current_dataset_idx]
        self.save_dir = get_mmf_env(key="save_dir")
        self.report_folder = ckpt_name_from_core_args(self.config)
        self.report_folder += foldername_from_config_override(self.config)
        self.report_folder = os.path.join(self.save_dir, self.report_folder)
        self.report_folder = os.path.join(self.report_folder, "reports")
        # explicit report_dir argument wins over the derived path
        if self.report_folder_arg:
            self.report_folder = self.report_folder_arg
        PathManager.mkdirs(self.report_folder)

    def next_dataset(self):
        """Flush the previous dataset's report and advance to the next one.

        Returns:
            bool: False when all datasets have been consumed.
        """
        if self.current_dataset_idx >= 0:
            self.flush_report()
        self.current_dataset_idx += 1
        if self.current_dataset_idx == len(self.datasets):
            return False
        else:
            self.current_dataset = self.datasets[self.current_dataset_idx]
            self.writer.write("Predicting for " + self.current_dataset.dataset_name)
            return True

    def flush_report(self):
        """Dump accumulated predictions to a timestamped json file (master
        only) and clear the report buffer.
        """
        if not is_master():
            return
        name = self.current_dataset.dataset_name
        time_format = "%Y-%m-%dT%H:%M:%S"
        time = self.timer.get_time_hhmmss(None, format=time_format)
        filename = name + "_"
        if len(self.experiment_name) > 0:
            filename += self.experiment_name + "_"
        filename += self.task_type + "_"
        filename += time + ".json"
        filepath = os.path.join(self.report_folder, filename)
        with PathManager.open(filepath, "w") as f:
            json.dump(self.report, f)
        self.writer.write(
            "Wrote evalai predictions for %s to %s" % (name, os.path.abspath(filepath))
        )
        self.report = []

    def get_dataloader(self):
        """Build a DataLoader over the current dataset with MMF's collator."""
        other_args = self._add_extra_args_for_dataloader()
        return DataLoader(
            dataset=self.current_dataset,
            collate_fn=BatchCollator(
                self.current_dataset.dataset_name, self.current_dataset.dataset_type
            ),
            num_workers=self.num_workers,
            pin_memory=self.config.training.pin_memory,
            **other_args
        )

    def _add_extra_args_for_dataloader(self, other_args=None):
        """Add sampler/shuffle/batch_size args; distributed runs get an
        unshuffled DistributedSampler so every sample is predicted once.
        """
        if other_args is None:
            other_args = {}
        if torch.distributed.is_initialized():
            other_args["sampler"] = DistributedSampler(
                self.current_dataset, shuffle=False
            )
        else:
            other_args["shuffle"] = False
        other_args["batch_size"] = get_batch_size()
        return other_args

    def prepare_batch(self, batch):
        # Delegate batch preparation to the current dataset.
        return self.current_dataset.prepare_batch(batch)

    def __len__(self):
        return len(self.current_dataset)

    def __getitem__(self, idx):
        return self.current_dataset[idx]

    def add_to_report(self, report):
        """Gather per-dataset fields across processes, then (master only)
        append evalai-formatted rows to the report buffer.
        """
        # TODO: Later gather whole report for no opinions
        if self.current_dataset.dataset_name == "coco":
            report.captions = gather_tensor(report.captions)
            if isinstance(report.image_id, torch.Tensor):
                report.image_id = gather_tensor(report.image_id).view(-1)
        else:
            report.scores = gather_tensor(report.scores).view(
                -1, report.scores.size(-1)
            )
            if "question_id" in report:
                report.question_id = gather_tensor(report.question_id).view(-1)
            if "image_id" in report:
                _, enc_size = report.image_id.size()
                report.image_id = gather_tensor(report.image_id)
                report.image_id = report.image_id.view(-1, enc_size)
            if "context_tokens" in report:
                _, enc_size = report.context_tokens.size()
                report.context_tokens = gather_tensor(report.context_tokens)
                report.context_tokens = report.context_tokens.view(-1, enc_size)
        if not is_master():
            return
        results = self.current_dataset.format_for_evalai(report)
        self.report = self.report + results
def on_train_start(self): self.train_timer = Timer() self.snapshot_timer = Timer()
class LightningLoopCallback(Callback):
    """Lightning callback mirroring MMF's native training-loop logging:
    accumulates per-batch Reports, updates meters, and emits periodic
    summaries (train every `log_every_n_steps`, val at validation end).
    """

    def __init__(self, lightning_trainer: Any):
        super().__init__()
        self.lightning_trainer = lightning_trainer
        # this is lightning trainer's config
        self.trainer_config = lightning_trainer.trainer_config
        # training config configures training parameters.
        self.training_config = lightning_trainer.training_config
        self.run_type = lightning_trainer.run_type

        # for logging
        self.total_timer = Timer()
        self.snapshot_timer = Timer()
        # number of validation batches, used for ETA estimation
        self.snapshot_iterations = len(self.lightning_trainer.val_loader)
        self.train_timer = Timer()

    def on_train_start(self, trainer: Trainer, pl_module: LightningModule):
        registry.register("current_epoch", trainer.current_epoch)
        self.train_combined_report = None

    def on_train_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        outputs: List,
        batch: SampleList,
        batch_idx: int,
        dataloader_idx: int,
    ):
        """Accumulate the batch's report and log on the configured cadence."""
        # prepare the next batch
        self.lightning_trainer.data_module.train_loader.change_dataloader()

        # aggregate train combined_report
        self.train_combined_report = self._update_and_create_report(
            SampleList(batch), batch_idx, outputs, pl_module,
            self.train_combined_report)

        # log
        if (trainer.global_step + 1) % self.trainer_config.log_every_n_steps == 0:
            self._train_log(trainer, pl_module)

        # save checkpoints - TODO: @sash

    def on_train_end(self, trainer: Trainer, pl_module: LightningModule):
        # Only do when run_type has train as it shouldn't happen on validation and
        # inference runs. Inference will take care of this anyways. Also, don't run
        # if current iteration is divisible by snapshot interval as it will just
        # be a repeat
        if ("train" in self.run_type
                and trainer.global_step % self.trainer_config.val_check_interval != 0):
            logger.info("Stepping into final validation check")
            # Pytorch Lightning upgrades PR4945 and PR4948 will be enabled in 1.2
            # TODO: perform final validation check

    # Validation Callbacks
    def on_validation_start(self, trainer: Trainer, pl_module: LightningModule):
        """Reset the snapshot timer and validation meter for a fresh run."""
        logger.info("Evaluation time. Running on full validation set...")
        self.snapshot_timer.reset()
        self.val_combined_report = None
        pl_module.val_meter.reset()

    def on_validation_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        outputs: List,
        batch: SampleList,
        batch_idx: int,
        dataloader_idx: int,
    ):
        """Accumulate the validation report and refresh metrics/meter."""
        # prepare the next batch
        self.lightning_trainer.data_module.val_loader.change_dataloader()

        # aggregate val_combined_report
        self.val_combined_report = self._update_and_create_report(
            batch,
            batch_idx,
            outputs,
            pl_module,
            self.val_combined_report,
            update_meter=pl_module.val_meter,
        )
        # metrics are recomputed over the running accumulated report
        self.val_combined_report.metrics = pl_module.metrics(
            self.val_combined_report, self.val_combined_report)
        pl_module.val_meter.update_from_report(self.val_combined_report,
                                               should_update_loss=False)

    def on_validation_end(self, trainer: Trainer, pl_module: LightningModule):
        """Summarize the validation run (val_time, epoch, progress counters)."""
        iterations = self._get_iterations_for_logging(trainer)
        current_epochs = self._get_current_epoch_for_logging(trainer)
        num_updates = self._get_num_updates_for_logging(trainer)
        extra = {
            "num_updates": num_updates,
            "epoch": current_epochs,
            "iterations": iterations,
            "max_updates": trainer.max_steps,
            "val_time": self.snapshot_timer.get_time_since_start(),
        }
        # TODO: @sash populate early stop info for logging (next mvp)
        # extra.update(self.trainer.early_stop_callback.early_stopping.get_info())
        # train_timer is reset so validation time is excluded from train timing
        self.train_timer.reset()
        summarize_report(
            current_iteration=iterations,
            num_updates=num_updates,
            max_updates=trainer.max_steps,
            meter=pl_module.val_meter,
            extra=extra,
            tb_writer=self.lightning_trainer.tb_writer,
        )

    def _update_and_create_report(
        self,
        batch: Dict,
        batch_idx: int,
        step_output: Dict,
        pl_module: LightningModule,
        combined_report: Report = None,
        update_meter: Meter = None,
    ):
        """Wrap the step output in a Report; while inside a gradient
        accumulation window, merge it into `combined_report` instead of
        replacing it.
        """
        report = Report(batch, step_output)
        if update_meter:
            update_meter.update_from_report(report)

        # accumulate unless this batch completes an accumulation window
        should_accumulate = not (
            batch_idx % self.trainer_config.accumulate_grad_batches == 0)

        final_report = report
        if should_accumulate and combined_report is not None:
            combined_report.accumulate_tensor_fields_and_loss(
                report, pl_module.metrics.required_params)
            combined_report.batch_size += report.batch_size
            final_report = combined_report
        return final_report

    def get_optimizer(self, trainer: Trainer):
        """Return the single optimizer MMF's lightning trainer supports."""
        assert (
            len(trainer.optimizers) == 1
        ), "mmf lightning_trainer supports 1 optimizer per model for now."
        optimizer = trainer.optimizers[0]
        return optimizer

    def _save_checkpoint(self, trainer: Trainer):
        logger.info("Checkpoint time. Saving a checkpoint.")
        return
        # TODO: sash Needs implementation - next mvp

    def _get_current_epoch_for_logging(self, trainer: Trainer):
        # 1-based for human-readable logs
        return trainer.current_epoch + 1

    def _get_iterations_for_logging(self, trainer: Trainer):
        # 1-based for human-readable logs
        return trainer.train_loop.batch_idx + 1

    def _get_num_updates_for_logging(self, trainer: Trainer):
        # 1-based for human-readable logs
        return trainer.global_step + 1

    def _train_log(self, trainer: Trainer, pl_module: LightningModule):
        """Compute (optionally) train metrics, update the train meter and
        emit a progress summary with lr/ups/time/eta extras.
        """
        if self.training_config.evaluate_metrics:
            self.train_combined_report.metrics = pl_module.metrics(
                self.train_combined_report, self.train_combined_report)
        pl_module.train_meter.update_from_report(self.train_combined_report)

        extra = {}
        if "cuda" in str(trainer.model.device):
            # peak GPU memory in MB
            extra["max mem"] = torch.cuda.max_memory_allocated() / 1024
            extra["max mem"] //= 1024

        if self.training_config.experiment_name:
            extra["experiment"] = self.training_config.experiment_name

        optimizer = self.get_optimizer(trainer)
        num_updates = self._get_num_updates_for_logging(trainer)
        current_iteration = self._get_iterations_for_logging(trainer)
        extra.update({
            "epoch": self._get_current_epoch_for_logging(trainer),
            "iterations": current_iteration,
            "num_updates": num_updates,
            "max_updates": trainer.max_steps,
            "lr": "{:.5f}".format(optimizer.param_groups[0]["lr"]).rstrip("0"),
            # updates per second over the last logging window
            "ups": "{:.2f}".format(self.trainer_config.log_every_n_steps /
                                   self.train_timer.unix_time_since_start()),
            "time": self.train_timer.get_time_since_start(),
            "time_since_start": self.total_timer.get_time_since_start(),
            "eta": calculate_time_left(
                max_updates=trainer.max_steps,
                num_updates=num_updates,
                timer=self.train_timer,
                num_snapshot_iterations=self.snapshot_iterations,
                log_interval=self.trainer_config.log_every_n_steps,
                eval_interval=self.trainer_config.val_check_interval,
            ),
        })
        self.train_timer.reset()
        summarize_report(
            current_iteration=current_iteration,
            num_updates=num_updates,
            max_updates=trainer.max_steps,
            meter=pl_module.train_meter,
            extra=extra,
            tb_writer=self.lightning_trainer.tb_writer,
        )