def __init__(self, trainer):
    """Build checkpoint/save paths for a trainer.

    Generates a path for saving model which can also be used for
    resuming from a checkpoint.

    Args:
        trainer: Trainer instance; its ``config``, ``args`` and ``model``
            attributes are read to derive the checkpoint folder layout.
    """
    self.trainer = trainer

    self.config = self.trainer.config
    self.save_dir = self.config.training_parameters.save_dir
    self.model_name = self.config.model

    # Folder name derives from core args plus any config overrides so
    # different runs land in distinct checkpoint directories.
    self.ckpt_foldername = ckpt_name_from_core_args(self.config)
    self.ckpt_foldername += foldername_from_config_override(self.trainer.args)

    self.device = registry.get("current_device")

    # A model may contribute its own prefix (e.g. a variant name).
    self.ckpt_prefix = ""
    if hasattr(self.trainer.model, "get_ckpt_name"):
        self.ckpt_prefix = self.trainer.model.get_ckpt_name() + "_"

    # Expose the bare folder name before it is absolutized below.
    self.config["log_foldername"] = self.ckpt_foldername
    self.ckpt_foldername = os.path.join(self.save_dir, self.ckpt_foldername)

    # Final checkpoint path:
    # <ckpt_folder>/<prefix><model_name><code_name>_final.pth
    self.pth_filepath = os.path.join(
        self.ckpt_foldername,
        self.ckpt_prefix
        + self.model_name
        + getattr(self.config.model_attributes, self.model_name).code_name
        + "_final.pth",
    )

    self.models_foldername = os.path.join(self.ckpt_foldername, "models")
    # exist_ok=True avoids the check-then-create race the original
    # `if not os.path.exists(...): os.makedirs(...)` pattern had when
    # several processes initialize concurrently.
    os.makedirs(self.models_foldername, exist_ok=True)

    self.save_config()

    # Repository root: three directories up from this file.
    self.repo_path = updir(os.path.abspath(__file__), n=3)
def __init__(self, multi_task_instance):
    """Collect all datasets from a multi-task instance and prepare the
    folder into which evaluation reports will be written.

    Args:
        multi_task_instance: Multi-task container exposing
            ``dataset_type`` and ``get_tasks()``; each task in turn
            exposes ``get_datasets()``.
    """
    self.test_task = multi_task_instance
    self.task_type = multi_task_instance.dataset_type
    self.config = registry.get("config")
    self.writer = registry.get("writer")
    self.report = []
    self.timer = Timer()

    self.training_parameters = self.config["training_parameters"]
    self.num_workers = self.training_parameters["num_workers"]
    self.batch_size = self.training_parameters["batch_size"]
    self.report_folder_arg = self.config.get("report_folder", None)
    self.experiment_name = self.training_parameters.get("experiment_name", "")

    # Flatten tasks -> datasets into one list.
    self.datasets = [
        dataset
        for task in self.test_task.get_tasks()
        for dataset in task.get_datasets()
    ]

    # Index starts at -1, so `current_dataset` is initially the *last*
    # dataset; presumably an advance step increments this to 0 before
    # real use — TODO confirm against the iteration methods.
    self.current_dataset_idx = -1
    self.current_dataset = self.datasets[self.current_dataset_idx]

    self.save_dir = self.config.get("save_dir", "./save")

    # Default report folder mirrors the checkpoint folder naming:
    # <save_dir>/<core-args name><config overrides>/reports
    self.report_folder = ckpt_name_from_core_args(self.config)
    self.report_folder += foldername_from_config_override(self.config)
    self.report_folder = os.path.join(self.save_dir, self.report_folder, "reports")

    # An explicitly configured report folder overrides the derived one.
    if self.report_folder_arg is not None:
        self.report_folder = self.report_folder_arg

    # exist_ok=True avoids the check-then-create race of the original
    # exists()/makedirs() pair.
    os.makedirs(self.report_folder, exist_ok=True)
def __init__(self, config):
    """Set up file + stdout logging and a tensorboard summary writer.

    Only the main process configures logging; on worker processes this
    returns early, leaving ``logger`` and ``summary_writer`` as ``None``.

    Args:
        config: Run configuration; ``training_parameters`` supplies
            ``save_dir``, ``logger_level`` and ``should_not_log``.
    """
    self.logger = None
    self.summary_writer = None

    if not is_main_process():
        return

    self.timer = Timer()
    self.config = config
    self.save_dir = config.training_parameters.save_dir
    self.log_folder = ckpt_name_from_core_args(config)
    self.log_folder += foldername_from_config_override(config)

    # Log file name: <core-args name>_<timestamp>.log
    time_format = "%Y-%m-%dT%H:%M:%S"
    self.log_filename = ckpt_name_from_core_args(config) + "_"
    self.log_filename += self.timer.get_time_hhmmss(None, format=time_format)
    self.log_filename += ".log"

    self.log_folder = os.path.join(self.save_dir, self.log_folder, "logs")

    # An explicitly configured log dir overrides the derived one.
    arg_log_dir = self.config.get("log_dir", None)
    if arg_log_dir:
        self.log_folder = arg_log_dir

    # exist_ok=True avoids the check-then-create race of the original
    # exists()/makedirs() pair.
    os.makedirs(self.log_folder, exist_ok=True)

    tensorboard_folder = os.path.join(self.log_folder, "tensorboard")
    self.summary_writer = SummaryWriter(tensorboard_folder)

    self.log_filename = os.path.join(self.log_folder, self.log_filename)

    print("Logging to:", self.log_filename)

    logging.captureWarnings(True)

    self.logger = logging.getLogger(__name__)
    # BUG FIX: the original called logging.getLogger(__name__) here as
    # well, which returns the *same* logger object as self.logger — so
    # the "file only" logger also carried the stdout handler and was not
    # file-only. Use a distinct child logger and disable propagation so
    # records do not bubble up to self.logger's stdout handler.
    self._file_only_logger = logging.getLogger(__name__ + ".file_only")
    self._file_only_logger.propagate = False
    warnings_logger = logging.getLogger("py.warnings")

    # Set level
    level = config["training_parameters"].get("logger_level", "info")
    self.logger.setLevel(getattr(logging, level.upper()))
    self._file_only_logger.setLevel(getattr(logging, level.upper()))

    formatter = logging.Formatter(
        "%(asctime)s %(levelname)s: %(message)s", datefmt="%Y-%m-%dT%H:%M:%S"
    )

    # Add handler to file
    channel = logging.FileHandler(filename=self.log_filename, mode="a")
    channel.setFormatter(formatter)

    self.logger.addHandler(channel)
    self._file_only_logger.addHandler(channel)
    warnings_logger.addHandler(channel)

    # Add handler to stdout
    channel = logging.StreamHandler(sys.stdout)
    channel.setFormatter(formatter)

    self.logger.addHandler(channel)
    warnings_logger.addHandler(channel)

    should_not_log = self.config["training_parameters"]["should_not_log"]
    self.should_log = not should_not_log

    # Single log wrapper map: tracks messages already emitted once.
    self._single_log_map = set()