class TensorboardLogger:
    def __init__(self, log_folder="./logs", iteration=0):
        # This would handle warning of missing tensorboard
        from torch.utils.tensorboard import SummaryWriter

        self.summary_writer = None
        self._is_master = is_master()
        self.timer = Timer()
        self.log_folder = log_folder
        self.time_format = "%Y-%m-%dT%H:%M:%S"

        if self._is_master:
            current_time = self.timer.get_time_hhmmss(None, format=self.time_format)
            tensorboard_folder = os.path.join(
                self.log_folder, f"tensorboard_{current_time}"
            )
            self.summary_writer = SummaryWriter(tensorboard_folder)

    def __del__(self):
        if getattr(self, "summary_writer", None) is not None:
            self.summary_writer.close()

    def _should_log_tensorboard(self):
        return self.summary_writer is not None and self._is_master

    def add_scalar(self, key, value, iteration):
        if not self._should_log_tensorboard():
            return

        self.summary_writer.add_scalar(key, value, iteration)

    def add_scalars(self, scalar_dict, iteration):
        if not self._should_log_tensorboard():
            return

        for key, val in scalar_dict.items():
            self.summary_writer.add_scalar(key, val, iteration)

    def add_histogram_for_model(self, model, iteration):
        if not self._should_log_tensorboard():
            return

        for name, param in model.named_parameters():
            np_param = param.clone().cpu().data.numpy()
            self.summary_writer.add_histogram(name, np_param, iteration)
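
# Illustrative usage sketch (not part of the original source): the logger is
# typically owned by the rank-0 trainer process and fed once per logging
# interval. The names `model`, `loss`, and `iteration` below are hypothetical.
#
#   tb_logger = TensorboardLogger(log_folder="./logs", iteration=0)
#   tb_logger.add_scalar("train/total_loss", loss, iteration)
#   tb_logger.add_scalars({"train/lr": 1e-4, "train/acc": 0.73}, iteration)
#   tb_logger.add_histogram_for_model(model, iteration)
#
# On non-master ranks `summary_writer` stays None, so all of the calls above
# become no-ops via `_should_log_tensorboard()`.
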
class TestReporter(Dataset):
    def __init__(self, multi_task_instance):
        self.test_task = multi_task_instance
        self.task_type = multi_task_instance.dataset_type
        self.config = registry.get("config")
        self.report = []
        self.timer = Timer()
        self.training_config = self.config.training
        self.num_workers = self.training_config.num_workers
        self.batch_size = self.training_config.batch_size
        self.report_folder_arg = get_mmf_env(key="report_dir")
        self.experiment_name = self.training_config.experiment_name

        self.datasets = []

        for dataset in self.test_task.get_datasets():
            self.datasets.append(dataset)

        self.current_dataset_idx = -1
        self.current_dataset = self.datasets[self.current_dataset_idx]

        self.save_dir = get_mmf_env(key="save_dir")
        self.report_folder = ckpt_name_from_core_args(self.config)
        self.report_folder += foldername_from_config_override(self.config)

        self.report_folder = os.path.join(self.save_dir, self.report_folder)
        self.report_folder = os.path.join(self.report_folder, "reports")

        # An explicit report_dir argument overrides the derived folder
        if self.report_folder_arg:
            self.report_folder = self.report_folder_arg

        PathManager.mkdirs(self.report_folder)

    def next_dataset(self):
        if self.current_dataset_idx >= 0:
            self.flush_report()

        self.current_dataset_idx += 1

        if self.current_dataset_idx == len(self.datasets):
            return False
        else:
            self.current_dataset = self.datasets[self.current_dataset_idx]
            logger.info(f"Predicting for {self.current_dataset.dataset_name}")
            return True

    def flush_report(self):
        if not is_master():
            return

        name = self.current_dataset.dataset_name
        time_format = "%Y-%m-%dT%H:%M:%S"
        time = self.timer.get_time_hhmmss(None, format=time_format)

        filename = name + "_"
        if len(self.experiment_name) > 0:
            filename += self.experiment_name + "_"
        filename += self.task_type + "_"
        filename += time

        if self.config.evaluation.predict_file_format == "csv":
            filepath = os.path.join(self.report_folder, filename + ".csv")
            self.csv_dump(filepath)
        else:
            filepath = os.path.join(self.report_folder, filename + ".json")
            self.json_dump(filepath)

        logger.info(f"Wrote predictions for {name} to {os.path.abspath(filepath)}")

        self.report = []

    def csv_dump(self, filepath):
        with PathManager.open(filepath, "w") as f:
            title = self.report[0].keys()
            cw = csv.DictWriter(f, title, delimiter=",", quoting=csv.QUOTE_MINIMAL)
            cw.writeheader()
            cw.writerows(self.report)

    def json_dump(self, filepath):
        with PathManager.open(filepath, "w") as f:
            json.dump(self.report, f)

    def get_dataloader(self):
        dataloader, _ = build_dataloader_and_sampler(
            self.current_dataset, self.training_config
        )
        return dataloader

    def prepare_batch(self, batch):
        if hasattr(self.current_dataset, "prepare_batch"):
            batch = self.current_dataset.prepare_batch(batch)
        return batch

    def __len__(self):
        return len(self.current_dataset)

    def __getitem__(self, idx):
        return self.current_dataset[idx]

    def add_to_report(self, report, model):
        keys = ["id", "question_id", "image_id", "context_tokens", "captions", "scores"]
        for key in keys:
            report = self.reshape_and_gather(report, key)

        if not is_master():
            return

        results = self.current_dataset.format_for_prediction(report)

        if hasattr(model, "format_for_prediction"):
            results = model.format_for_prediction(results, report)
        elif hasattr(model, "module") and hasattr(
            model.module, "format_for_prediction"
        ):
            # Fall back to the wrapped model (e.g. DataParallel/DDP)
            results = model.module.format_for_prediction(results, report)

        self.report = self.report + results

    def reshape_and_gather(self, report, key):
        if key in report:
            num_dims = report[key].dim()
            if num_dims == 1:
                report[key] = gather_tensor(report[key]).view(-1)
            elif num_dims >= 2:
                # Collect dims other than batch
                other_dims = report[key].size()[1:]
                report[key] = gather_tensor(report[key]).view(-1, *other_dims)

        return report
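
# Illustrative prediction-loop sketch (not part of the original source) showing
# how the reporter is meant to be driven: iterate datasets with next_dataset(),
# run the model over the reporter's dataloader, and accumulate predictions.
# `model` and the construction of `report` from the model output are assumed to
# come from the surrounding trainer/predict code.
#
#   test_reporter = TestReporter(multi_task_instance)
#   with torch.no_grad():
#       while test_reporter.next_dataset():
#           dataloader = test_reporter.get_dataloader()
#           for batch in dataloader:
#               prepared_batch = test_reporter.prepare_batch(batch)
#               report = model(prepared_batch)
#               test_reporter.add_to_report(report, model)
#
# flush_report() runs inside next_dataset() once a dataset has been consumed,
# writing one CSV/JSON file per dataset into the report folder.
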
class LogisticsCallback(Callback):
    """Callback for handling train/validation logistics, report summarization,
    logging etc.
    """

    def __init__(self, config, trainer):
        """
        Args:
            config (mmf_typings.DictConfig): Config for the callback
            trainer (Type[BaseTrainer]): Trainer object
        """
        super().__init__(config, trainer)

        self.total_timer = Timer()
        self.log_interval = self.training_config.log_interval
        self.evaluation_interval = self.training_config.evaluation_interval
        self.checkpoint_interval = self.training_config.checkpoint_interval

        # Total iterations for snapshot
        self.snapshot_iterations = len(self.trainer.val_dataset)
        self.snapshot_iterations //= self.training_config.batch_size

        self.tb_writer = None

        if self.training_config.tensorboard:
            log_dir = setup_output_folder(folder_only=True)
            env_tb_logdir = get_mmf_env(key="tensorboard_logdir")
            if env_tb_logdir:
                log_dir = env_tb_logdir

            self.tb_writer = TensorboardLogger(log_dir, self.trainer.current_iteration)

    def on_train_start(self):
        self.train_timer = Timer()
        self.snapshot_timer = Timer()

    def on_update_end(self, **kwargs):
        if not kwargs["should_log"]:
            return

        extra = {}
        if "cuda" in str(self.trainer.device):
            # Peak GPU memory, converted from bytes to MiB
            extra["max mem"] = torch.cuda.max_memory_allocated() / 1024
            extra["max mem"] //= 1024

        if self.training_config.experiment_name:
            extra["experiment"] = self.training_config.experiment_name

        extra.update(
            {
                "epoch": self.trainer.current_epoch,
                "num_updates": self.trainer.num_updates,
                "iterations": self.trainer.current_iteration,
                "max_updates": self.trainer.max_updates,
                "lr": "{:.5f}".format(
                    self.trainer.optimizer.param_groups[0]["lr"]
                ).rstrip("0"),
                "ups": "{:.2f}".format(
                    self.log_interval / self.train_timer.unix_time_since_start()
                ),
                "time": self.train_timer.get_time_since_start(),
                "time_since_start": self.total_timer.get_time_since_start(),
                "eta": self._calculate_time_left(),
            }
        )
        self.train_timer.reset()
        self._summarize_report(kwargs["meter"], extra=extra)

    def on_validation_start(self, **kwargs):
        self.snapshot_timer.reset()

    def on_validation_end(self, **kwargs):
        extra = {
            "num_updates": self.trainer.num_updates,
            "epoch": self.trainer.current_epoch,
            "iterations": self.trainer.current_iteration,
            "max_updates": self.trainer.max_updates,
            "val_time": self.snapshot_timer.get_time_since_start(),
        }
        extra.update(self.trainer.early_stop_callback.early_stopping.get_info())
        self.train_timer.reset()
        self._summarize_report(kwargs["meter"], extra=extra)

    def on_test_end(self, **kwargs):
        prefix = "{}: full {}".format(
            kwargs["report"].dataset_name, kwargs["report"].dataset_type
        )
        self._summarize_report(kwargs["meter"], prefix=prefix)
        logger.info(f"Finished run in {self.total_timer.get_time_since_start()}")

    def _summarize_report(self, meter, prefix="", should_print=True, extra=None):
        if extra is None:
            extra = {}
        if not is_master():
            return

        if self.training_config.tensorboard:
            scalar_dict = meter.get_scalar_dict()
            self.tb_writer.add_scalars(scalar_dict, self.trainer.current_iteration)

        if not should_print:
            return

        log_dict = {}
        if prefix:
            log_dict["prefix"] = prefix
        if hasattr(self.trainer, "num_updates") and hasattr(
            self.trainer, "max_updates"
        ):
            log_dict.update(
                {"progress": f"{self.trainer.num_updates}/{self.trainer.max_updates}"}
            )
        log_dict.update(meter.get_log_dict())
        log_dict.update(extra)

        log_progress(log_dict)

    def _calculate_time_left(self):
        time_taken_for_log = time.time() * 1000 - self.train_timer.start
        iterations_left = self.trainer.max_updates - self.trainer.num_updates
        num_logs_left = iterations_left / self.log_interval
        time_left = num_logs_left * time_taken_for_log

        snapshot_iteration = self.snapshot_iterations / self.log_interval
        snapshot_iteration *= iterations_left / self.evaluation_interval
        time_left += snapshot_iteration * time_taken_for_log

        return self.train_timer.get_time_hhmmss(gap=time_left)
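
# Illustrative sketch (not part of the original source) of how a trainer drives
# this callback. The trainer attributes referenced above (current_iteration,
# num_updates, max_updates, device, optimizer, early_stop_callback, val_dataset)
# and the `meter` objects are assumptions taken from the methods in this class;
# the loop structure below is hypothetical.
#
#   logistics = LogisticsCallback(config, trainer)
#   logistics.on_train_start()
#   for update in training_loop:
#       ...
#       logistics.on_update_end(should_log=(update % log_interval == 0), meter=meter)
#       if update % evaluation_interval == 0:
#           logistics.on_validation_start()
#           ...
#           logistics.on_validation_end(meter=val_meter)
#   logistics.on_test_end(report=test_report, meter=test_meter)
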