def __init__(self, multi_task_instance):
    self.test_task = multi_task_instance
    self.task_type = multi_task_instance.dataset_type
    self.config = registry.get("config")
    self.writer = registry.get("writer")
    self.report = []
    self.timer = Timer()
    self.training_config = self.config.training
    self.num_workers = self.training_config.num_workers
    self.batch_size = self.training_config.batch_size
    self.report_folder_arg = get_mmf_env(key="report_dir")
    self.experiment_name = self.training_config.experiment_name

    self.datasets = []
    for dataset in self.test_task.get_datasets():
        self.datasets.append(dataset)

    self.current_dataset_idx = -1
    self.current_dataset = self.datasets[self.current_dataset_idx]

    self.save_dir = get_mmf_env(key="save_dir")
    self.report_folder = ckpt_name_from_core_args(self.config)
    self.report_folder += foldername_from_config_override(self.config)

    self.report_folder = os.path.join(self.save_dir, self.report_folder)
    self.report_folder = os.path.join(self.report_folder, "reports")

    if self.report_folder_arg:
        self.report_folder = self.report_folder_arg

    PathManager.mkdirs(self.report_folder)

def setup_output_folder(folder_only: bool = False):
    """Sets up and returns the output file where the logs will be placed
    based on the configuration passed. Usually "save_dir/logs/train_<timestamp>.log".
    If env.log_dir is passed, logs will be directly saved in this folder.

    Args:
        folder_only (bool, optional): If folder should be returned and not the file.
            Defaults to False.

    Returns:
        str: folder or file path depending on folder_only flag
    """
    save_dir = get_mmf_env(key="save_dir")
    time_format = "%Y_%m_%dT%H_%M_%S"
    log_filename = "train_"
    log_filename += Timer().get_time_hhmmss(None, format=time_format)
    log_filename += ".log"

    log_folder = os.path.join(save_dir, "logs")

    env_log_dir = get_mmf_env(key="log_dir")
    if env_log_dir:
        log_folder = env_log_dir

    if not PathManager.exists(log_folder):
        PathManager.mkdirs(log_folder)

    if folder_only:
        return log_folder

    log_filename = os.path.join(log_folder, log_filename)

    return log_filename

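# Hedged usage sketch (not part of the source): once MMF's environment/config has been
# initialized so that get_mmf_env(key="save_dir") resolves, the same helper returns
# either the logs folder or a fresh timestamped log file path.
log_folder = setup_output_folder(folder_only=True)  # e.g. "<save_dir>/logs"
log_file = setup_output_folder()  # e.g. "<save_dir>/logs/train_<timestamp>.log"
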
def download_pretrained_model(model_name, *args, **kwargs):
    import omegaconf
    from omegaconf import OmegaConf

    from mmf.utils.configuration import get_mmf_env, load_yaml

    model_zoo = load_yaml(get_mmf_env(key="model_zoo"))
    OmegaConf.set_struct(model_zoo, True)
    OmegaConf.set_readonly(model_zoo, True)

    data_dir = get_absolute_path(get_mmf_env("data_dir"))
    model_data_dir = os.path.join(data_dir, "models")
    download_path = os.path.join(model_data_dir, model_name)

    try:
        model_config = OmegaConf.select(model_zoo, model_name)
    except omegaconf.errors.OmegaConfBaseException as e:
        print(f"No such model name {model_name} defined in mmf zoo")
        raise e

    if "version" not in model_config or "resources" not in model_config:
        # Version and resources are not present, so fall back to the defaults variation
        try:
            model_config = model_config.defaults
            download_path = os.path.join(model_data_dir, model_name + ".defaults")
        except omegaconf.errors.OmegaConfBaseException as e:
            print(
                f"Model name {model_name} doesn't specify 'resources' and 'version' "
                "while no defaults have been provided"
            )
            raise e

    # Download requirements if any specified by "zoo_requirements" field
    # This can either be a list or a string
    if "zoo_requirements" in model_config:
        requirements = model_config.zoo_requirements
        if isinstance(requirements, str):
            requirements = [requirements]
        for item in requirements:
            download_pretrained_model(item, *args, **kwargs)

    version = model_config.version
    resources = model_config.resources

    if is_master():
        download_resources(resources, download_path, version)
    synchronize()

    return download_path

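# Hedged usage sketch: the zoo key below must exist in the model_zoo YAML pointed to by
# get_mmf_env(key="model_zoo"); "mmbt.hateful_memes.images" is used here only because it
# appears elsewhere in this file. Only the master process downloads; all ranks then
# synchronize and receive the same local folder path.
model_folder = download_pretrained_model("mmbt.hateful_memes.images")
print(model_folder)  # e.g. "<data_dir>/models/mmbt.hateful_memes.images"
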
def _get_model_folder(self):
    home = str(Path.home())
    data_dir = get_mmf_env(key="data_dir")
    model_folder = os.path.join(
        home, data_dir, "models", "mmbt.hateful_memes.images"
    )
    return model_folder

def __init__(self, config, trainer):
    """
    Attr:
        config(mmf_typings.DictConfig): Config for the callback
        trainer(Type[BaseTrainer]): Trainer object
    """
    super().__init__(config, trainer)

    self.total_timer = Timer()
    self.log_interval = self.training_config.log_interval
    self.evaluation_interval = self.training_config.evaluation_interval
    self.checkpoint_interval = self.training_config.checkpoint_interval

    # Total iterations for snapshot
    self.snapshot_iterations = len(self.trainer.val_dataset)
    self.snapshot_iterations //= self.training_config.batch_size

    self.tb_writer = None

    if self.training_config.tensorboard:
        log_dir = setup_output_folder(folder_only=True)
        env_tb_logdir = get_mmf_env(key="tensorboard_logdir")
        if env_tb_logdir:
            log_dir = env_tb_logdir

        self.tb_writer = TensorboardLogger(log_dir, self.trainer.current_iteration)

def __init__(self, trainer):
    """
    Generates a path for saving model which can also be used for resuming
    from a checkpoint.
    """
    self.trainer = trainer

    self.config = self.trainer.config
    self.save_dir = get_mmf_env(key="save_dir")
    self.model_name = self.config.model

    self.ckpt_foldername = self.save_dir

    self.device = registry.get("current_device")

    self.ckpt_prefix = ""

    if hasattr(self.trainer.model, "get_ckpt_name"):
        self.ckpt_prefix = self.trainer.model.get_ckpt_name() + "_"

    self.pth_filepath = os.path.join(
        self.ckpt_foldername, self.ckpt_prefix + self.model_name + "_final.pth"
    )

    self.models_foldername = os.path.join(self.ckpt_foldername, "models")
    if not PathManager.exists(self.models_foldername):
        PathManager.mkdirs(self.models_foldername)

    self.save_config()

    self.repo_path = updir(os.path.abspath(__file__), n=3)
    self.git_repo = None
    if git:
        self.git_repo = git.Repo(self.repo_path)

def __init__(self, config: BatchProcessorConfigType, *args, **kwargs):
    extra_params = {"data_dir": get_mmf_env(key="data_dir")}
    processors_dict = config.get("processors", {})

    # Since build_processors also imports processors, import it at runtime to
    # avoid circular dependencies
    from mmf.utils.build import build_processors

    self.processors = build_processors(processors_dict, **extra_params)

def __init__(
    self, multi_task_instance, test_reporter_config: TestReporterConfigType = None
):
    if not isinstance(test_reporter_config, TestReporter.TestReporterConfigType):
        test_reporter_config = TestReporter.TestReporterConfigType(
            **test_reporter_config
        )
    self.test_task = multi_task_instance
    self.task_type = multi_task_instance.dataset_type
    self.config = registry.get("config")
    self.report = []
    self.timer = Timer()
    self.training_config = self.config.training
    self.num_workers = self.training_config.num_workers
    self.batch_size = self.training_config.batch_size
    self.report_folder_arg = get_mmf_env(key="report_dir")
    self.experiment_name = self.training_config.experiment_name
    self.test_reporter_config = test_reporter_config

    self.datasets = []
    for dataset in self.test_task.get_datasets():
        self.datasets.append(dataset)

    self.current_dataset_idx = -1
    self.current_dataset = self.datasets[self.current_dataset_idx]

    self.save_dir = get_mmf_env(key="save_dir")
    self.report_folder = ckpt_name_from_core_args(self.config)
    self.report_folder += foldername_from_config_override(self.config)

    self.report_folder = os.path.join(self.save_dir, self.report_folder)
    self.report_folder = os.path.join(self.report_folder, "reports")

    if self.report_folder_arg:
        self.report_folder = self.report_folder_arg

    self.candidate_fields = DEFAULT_CANDIDATE_FIELDS
    if test_reporter_config.candidate_fields != MISSING:
        self.candidate_fields = test_reporter_config.candidate_fields

    PathManager.mkdirs(self.report_folder)

def _load_loggers(self) -> None:
    self.tb_writer = None
    if self.training_config.tensorboard:
        # TODO: @sash PL logger upgrade
        log_dir = setup_output_folder(folder_only=True)
        env_tb_logdir = get_mmf_env(key="tensorboard_logdir")
        if env_tb_logdir:
            log_dir = env_tb_logdir
        self.tb_writer = TensorboardLogger(log_dir)

def gqa_feature_loader():
    global img, img_info
    if img is not None:
        return img, img_info

    path = os.path.join(
        get_mmf_env("data_dir"), "datasets", "gqa", "defaults", "features"
    )
    h = h5py.File(f"{path}/gqa_spatial.hdf5", "r")
    img = h["features"]
    img_info = json.load(open(f"{path}/gqa_spatial_merged_info.json", "r"))
    return img, img_info

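# Hedged usage sketch: the loader memoizes the HDF5 feature matrix and the merged info
# mapping in the module-level `img`/`img_info` globals, so repeated calls only pay the
# file-opening cost once.
features, info = gqa_feature_loader()
features_again, _ = gqa_feature_loader()  # returns the cached handles
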
def __init__(
    self,
    datamodules: List[pl.LightningDataModule],
    config: Config = None,
    dataset_type: str = "train",
):
    self.test_reporter_config = OmegaConf.merge(
        OmegaConf.structured(self.Config), config
    )
    self.datamodules = datamodules
    self.dataset_type = dataset_type
    self.config = registry.get("config")
    self.report = []
    self.timer = Timer()
    self.training_config = self.config.training
    self.num_workers = self.training_config.num_workers
    self.batch_size = self.training_config.batch_size
    self.report_folder_arg = get_mmf_env(key="report_dir")
    self.experiment_name = self.training_config.experiment_name

    self.current_datamodule_idx = -1
    self.dataset_names = list(self.datamodules.keys())
    self.current_datamodule = self.datamodules[
        self.dataset_names[self.current_datamodule_idx]
    ]
    self.current_dataloader = None

    self.save_dir = get_mmf_env(key="save_dir")
    self.report_folder = ckpt_name_from_core_args(self.config)
    self.report_folder += foldername_from_config_override(self.config)

    self.report_folder = os.path.join(self.save_dir, self.report_folder)
    self.report_folder = os.path.join(self.report_folder, "reports")

    if self.report_folder_arg:
        self.report_folder = self.report_folder_arg

    self.candidate_fields = self.test_reporter_config.candidate_fields

    PathManager.mkdirs(self.report_folder)

    log_class_usage("TestReporter", self.__class__)

def __init__(self, config, trainer):
    """
    Attr:
        config(mmf_typings.DictConfig): Config for the callback
        trainer(Type[BaseTrainer]): Trainer object
    """
    super().__init__(config, trainer)

    self.total_timer = Timer()
    self.log_interval = self.training_config.log_interval
    self.evaluation_interval = self.training_config.evaluation_interval
    self.checkpoint_interval = self.training_config.checkpoint_interval

    # Total iterations for snapshot
    # len would be number of batches per GPU == max updates
    self.snapshot_iterations = len(self.trainer.val_loader)

    self.tb_writer = None
    self.wandb_logger = None

    if self.training_config.tensorboard:
        log_dir = setup_output_folder(folder_only=True)
        env_tb_logdir = get_mmf_env(key="tensorboard_logdir")
        if env_tb_logdir:
            log_dir = env_tb_logdir

        self.tb_writer = TensorboardLogger(log_dir, self.trainer.current_iteration)

    if self.training_config.wandb.enabled:
        log_dir = setup_output_folder(folder_only=True)
        env_wandb_logdir = get_mmf_env(key="wandb_logdir")
        if env_wandb_logdir:
            log_dir = env_wandb_logdir

        self.wandb_logger = WandbLogger(
            entity=config.training.wandb.entity,
            config=config,
            project=config.training.wandb.project,
        )

def test_import_user_module_from_file(self):
    self.assertIsNone(registry.get_builder_class("always_one"))
    self.assertIsNone(registry.get_model_class("simple"))

    user_dir = self._get_user_dir()
    user_file = os.path.join(user_dir, "models", "simple.py")
    import_user_module(user_file)

    # Only the model should be found; the builder should still be None
    self.assertIsNone(registry.get_builder_class("always_one"))
    self.assertIsNotNone(registry.get_model_class("simple"))
    self.assertTrue("mmf_user_dir" in sys.modules)
    self.assertTrue(user_dir in get_mmf_env("user_dir"))

def configure_monitor_callbacks(self) -> List[ModelCheckpoint]:
    criteria, mode = self.monitor_criteria()
    monitor_callback = ModelCheckpoint(
        monitor=criteria,
        dirpath=get_mmf_env(key="save_dir"),
        filename="best",
        mode=mode,
        save_top_k=1,
        save_last=False,
        verbose=True,
    )
    return [monitor_callback]

def configure_checkpoint_callbacks(self) -> List[ModelCheckpoint]:
    train_callback = ModelCheckpoint(
        monitor=None,
        every_n_train_steps=self.config.training.checkpoint_interval,
        dirpath=get_mmf_env(key="save_dir"),
        filename="models/model_{step}",
        save_top_k=-1,
        save_last=True,
        verbose=True,
    )
    train_callback.CHECKPOINT_NAME_LAST = "current"
    return [train_callback]

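# Hedged sketch (hypothetical helper name, not from the source): inside the lightning
# trainer class, the two callback factories above could be combined and handed to
# pl.Trainer, which then writes "best.ckpt", "current.ckpt" and
# "models/model_{step}.ckpt" under get_mmf_env(key="save_dir"). max_steps=1000 is an
# illustrative value.
def _sketch_configure_pl_trainer(self) -> pl.Trainer:
    callbacks = self.configure_checkpoint_callbacks()
    callbacks += self.configure_monitor_callbacks()
    return pl.Trainer(callbacks=callbacks, max_steps=1000)
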
def is_zoo_path(self, path) -> bool:
    from mmf.utils.configuration import get_mmf_env, load_yaml

    model_zoo = load_yaml(get_mmf_env(key="model_zoo"))
    OmegaConf.set_struct(model_zoo, True)
    OmegaConf.set_readonly(model_zoo, True)

    try:
        model_config = OmegaConf.select(model_zoo, path)
        return model_config is not None
    except omegaconf.errors.OmegaConfBaseException:
        return False

def _create_checkpoint_file(self, path):
    home = str(Path.home())
    data_dir = get_mmf_env(key="data_dir")
    model_folder = os.path.join(
        home, data_dir, "models", "mmbt.hateful_memes.images"
    )
    model_file = os.path.join(model_folder, "model.pth")
    config_file = os.path.join(model_folder, "config.yaml")

    config = load_yaml(config_file)
    with PathManager.open(model_file, "rb") as f:
        ckpt = torch.load(f)
    ckpt["config"] = config
    torch.save(ckpt, path)

def _load_trainer(self):
    lightning_params = self.trainer_config
    with omegaconf.open_dict(lightning_params):
        lightning_params.pop("max_steps")
        lightning_params.pop("max_epochs")

    lightning_params_dict = OmegaConf.to_container(lightning_params, resolve=True)
    self.trainer = Trainer(
        callbacks=self._callbacks,
        max_steps=self._max_updates,
        default_root_dir=get_mmf_env(key="log_dir"),
        **lightning_params_dict,
    )

def test_import_user_module_from_directory_absolute(self, abs_path=True):
    # Make sure the modules are not available first
    self.assertIsNone(registry.get_builder_class("always_one"))
    self.assertIsNone(registry.get_model_class("simple"))
    self.assertFalse("mmf_user_dir" in sys.modules)

    # Now, import and test
    user_dir = self._get_user_dir(abs_path)
    import_user_module(user_dir)

    self.assertIsNotNone(registry.get_builder_class("always_one"))
    self.assertIsNotNone(registry.get_model_class("simple"))
    self.assertTrue("mmf_user_dir" in sys.modules)
    self.assertTrue(user_dir in get_mmf_env("user_dir"))

def build(self, config, *args, **kwargs):
    # First, check whether manual downloads have been performed
    data_dir = get_mmf_env(key="data_dir")
    test_path = get_absolute_path(
        os.path.join(
            data_dir,
            "annotations",
            "train.jsonl",
        )
    )
    # NOTE: This doesn't check for files, but that is a fine assumption for now
    assert PathManager.exists(test_path), (
        "Hateful Memes Dataset doesn't do automatic downloads; please "
        + "follow instructions at https://fb.me/hm_prerequisites"
    )
    super().build(config, *args, **kwargs)

def load_extras(self):
    self.writer.write("Torch version is: " + torch.__version__)
    self.checkpoint = Checkpoint(self)
    self.meter = Meter()

    self.training_config = self.config.training

    early_stop_criteria = self.training_config.early_stop.criteria
    early_stop_minimize = self.training_config.early_stop.minimize
    early_stop_enabled = self.training_config.early_stop.enabled
    early_stop_patience = self.training_config.early_stop.patience

    self.log_interval = self.training_config.log_interval
    self.evaluation_interval = self.training_config.evaluation_interval
    self.checkpoint_interval = self.training_config.checkpoint_interval
    self.max_updates = self.training_config.max_updates
    self.should_clip_gradients = self.training_config.clip_gradients
    self.max_epochs = self.training_config.max_epochs

    self.early_stopping = EarlyStopping(
        self.model,
        self.checkpoint,
        early_stop_criteria,
        patience=early_stop_patience,
        minimize=early_stop_minimize,
        should_stop=early_stop_enabled,
    )
    self.current_epoch = 0
    self.current_iteration = 0
    self.num_updates = 0

    self.checkpoint.load_state_dict()

    self.not_debug = self.training_config.logger_level != "debug"

    self.lr_scheduler = None

    if self.training_config.lr_scheduler is True:
        self.lr_scheduler = build_scheduler(self.optimizer, self.config)

    self.tb_writer = None

    if self.training_config.tensorboard:
        log_dir = self.writer.log_dir
        env_tb_logdir = get_mmf_env(key="tensorboard_logdir")
        if env_tb_logdir:
            log_dir = env_tb_logdir

        self.tb_writer = TensorboardLogger(log_dir, self.current_iteration)

def build(self, config, *args, **kwargs):
    # First, check whether manual downloads have been performed
    data_dir = get_mmf_env(key="data_dir")
    test_path = get_absolute_path(
        os.path.join(
            data_dir,
            "datasets",
            self.dataset_name,
            "defaults",
            "annotations",
            "train.jsonl",
        )
    )
    # NOTE: This doesn't check for files, but that is a fine assumption for now
    assert PathManager.exists(test_path)
    super().build(config, *args, **kwargs)

def _download_requirement(
    self, config, requirement_key, requirement_variation="defaults"
):
    version, resources = get_zoo_config(
        requirement_key, requirement_variation, self.zoo_config_path, self.zoo_type
    )

    if resources is None:
        return

    requirement_split = requirement_key.split(".")
    dataset_name = requirement_split[0]

    # The dataset variation has been directly passed in the key so use it instead
    if len(requirement_split) >= 2:
        dataset_variation = requirement_split[1]
    else:
        dataset_variation = requirement_variation

    # We want to use root env data_dir so that we don't mix up our download
    # root dir with the dataset ones
    download_path = os.path.join(
        get_mmf_env("data_dir"), "datasets", dataset_name, dataset_variation
    )
    download_path = get_absolute_path(download_path)

    if not isinstance(resources, collections.abc.Mapping):
        self._download_resources(resources, download_path, version)
    else:
        use_features = config.get("use_features", False)
        use_images = config.get("use_images", False)

        if use_features:
            self._download_based_on_attribute(
                resources, download_path, version, "features"
            )

        if use_images:
            self._download_based_on_attribute(
                resources, download_path, version, "images"
            )

        self._download_based_on_attribute(
            resources, download_path, version, "annotations"
        )
        self._download_resources(
            resources.get("extras", []), download_path, version
        )

def get_absolute_path(paths):
    # String check should be first as Sequence would pass for string too
    if isinstance(paths, str):
        # If path is absolute return it directly
        if os.path.isabs(paths):
            return paths

        possible_paths = [
            # Direct path
            paths
        ]
        # Now, try relative to user_dir if it exists
        from mmf.utils.configuration import get_mmf_env

        mmf_root = get_mmf_root()
        user_dir = get_mmf_env(key="user_dir")
        if user_dir:
            possible_paths.append(os.path.join(user_dir, paths))
            # Also check the user dir resolved relative to the mmf root's parent
            possible_paths.append(os.path.join(mmf_root, "..", user_dir, paths))

        # Relative to root folder of mmf install
        possible_paths.append(os.path.join(mmf_root, "..", paths))
        # Relative to mmf root
        possible_paths.append(os.path.join(mmf_root, paths))

        # Test all these paths, if any exists return
        for path in possible_paths:
            if PathManager.exists(path):
                # URIs
                if path.find("://") == -1:
                    return os.path.abspath(path)
                else:
                    return path

        # If nothing works, return original path so that it throws an error
        return paths
    elif isinstance(paths, collections.abc.Iterable):
        return [get_absolute_path(path) for path in paths]
    else:
        raise TypeError(
            "Paths passed to dataset should either be a string or a list"
        )

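# Hedged illustration of the resolution order above (the concrete relative paths are
# only examples): absolute paths pass through unchanged, relative paths are tried as-is
# and then against user_dir and the mmf root, and iterables are resolved element-wise.
# Paths that exist nowhere come back unchanged so the caller surfaces a clear error.
print(get_absolute_path("/tmp/features.lmdb"))  # absolute: returned as-is
print(get_absolute_path("datasets/gqa/defaults/features"))  # first existing candidate
print(get_absolute_path(["train.jsonl", "val.jsonl"]))  # resolved element-wise
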
def __init__(self, trainer):
    """
    Generates a path for saving model which can also be used for resuming
    from a checkpoint.
    """
    self.trainer = trainer

    self.config = self.trainer.config
    self.save_dir = get_mmf_env(key="save_dir")
    self.model_name = self.config.model

    self.ckpt_foldername = self.save_dir

    self.device = registry.get("current_device")

    self.ckpt_prefix = ""

    if hasattr(self.trainer.model, "get_ckpt_name"):
        self.ckpt_prefix = self.trainer.model.get_ckpt_name() + "_"

    self.pth_filepath = os.path.join(
        self.ckpt_foldername, self.ckpt_prefix + self.model_name + "_final.pth"
    )

    self.models_foldername = os.path.join(self.ckpt_foldername, "models")
    if not PathManager.exists(self.models_foldername):
        PathManager.mkdirs(self.models_foldername)

    self.save_config()

    self.repo_path = updir(os.path.abspath(__file__), n=3)
    self.git_repo = None
    if git and self.config.checkpoint.save_git_details:
        try:
            self.git_repo = git.Repo(self.repo_path)
        except git.exc.InvalidGitRepositoryError:
            # Not a git repo, don't do anything
            pass

    self.max_to_keep = self.config.checkpoint.max_to_keep
    self.saved_iterations = []

def __init__(self, config, dataset_type, imdb_file_index, *args, **kwargs):
    super().__init__(
        config,
        dataset_type,
        imdb_file_index,
        dataset_name="visual_genome",
        *args,
        **kwargs,
    )

    self._return_scene_graph = config.return_scene_graph
    self._return_objects = config.return_objects
    self._return_relationships = config.return_relationships
    self._return_region_descriptions = config.return_region_descriptions
    self._no_unk = config.get("no_unk", False)
    self.scene_graph_db = None
    self.region_descriptions_db = None
    self.image_metadata_db = None
    self._max_feature = config.max_features

    build_scene_graph_db = (
        self._return_scene_graph
        or self._return_objects
        or self._return_relationships
    )
    # print("config", config)
    if self._return_region_descriptions:
        print("use_region_descriptions_true")
        self.region_descriptions_db = self.build_region_descriptions_db()
        self.image_metadata_db = self.build_image_metadata_db()

    if build_scene_graph_db:
        scene_graph_file = config.scene_graph_files[dataset_type][imdb_file_index]
        print("scene_graph_file", scene_graph_file)
        # scene_graph_file = self._get_absolute_path(scene_graph_file)
        scene_graph_file = get_absolute_path(
            get_mmf_env("data_dir") + "/" + scene_graph_file
        )
        print("scene_graph_file", scene_graph_file)
        self.scene_graph_db = SceneGraphDatabase(config, scene_graph_file)
        print("use_scene_graph_true")
        self.scene_graph_db = self.build_scene_graph_db()

def get_checkpoint_data(self) -> Dict[str, Any]:
    """This function gets the checkpoint file path on disk from
    config.trainer.params.resume_from_checkpoint. However, if it is not specified,
    it gets the checkpoint path from config.checkpoint. If config.resume is
    specified, it gets the latest checkpoint from the config's save directory
    (alternatively it gets the best checkpoint if config.resume_best is True).
    If config.resume is not specified, then it gets config.resume_file or the
    checkpoint file from config.resume_zoo (in that order).

    Returns:
        Dict[str, Any]: a dict containing the following keys,
        `checkpoint_path` (str) local file path for the checkpoint;
        `ckpt` (Dict[str, Any])
        `is_zoo` (Bool) whether or not the checkpoint is specified through a
            zoo identifier
        `config` (Dict[str, Any]) the config that is stored together with this
            checkpoint
    """
    # get ckpt file path from config.trainer.params.resume_from_checkpoint
    path = self.config.trainer.params.get("resume_from_checkpoint", None)
    if path is not None:
        is_zoo = self.is_zoo_path(path)
        ckpt_filepath = path
        if is_zoo:
            folder = download_pretrained_model(path)
            ckpt_filepath = get_ckpt_path_from_folder(folder)
            ckpt = get_ckpt_from_path(ckpt_filepath)
            config = get_config_from_folder_or_ckpt(folder, ckpt)
        else:
            ckpt = get_ckpt_from_path(ckpt_filepath)
            config = None

        return {
            "ckpt": ckpt,
            "checkpoint_path": ckpt_filepath,
            "is_zoo": is_zoo,
            "config": config,
        }

    is_zoo = False
    config = None
    ckpt = None
    # get ckpt file path from config.checkpoint
    ckpt_config = self.config.checkpoint
    suffix = "best.ckpt" if ckpt_config.resume_best else "current.ckpt"
    path = os.path.join(get_mmf_env(key="save_dir"), suffix)
    ckpt_filepath = None
    resume_from_specified_path = (
        ckpt_config.resume_file is not None or ckpt_config.resume_zoo is not None
    ) and (not ckpt_config.resume or not PathManager.exists(path))
    if resume_from_specified_path:
        if ckpt_config.resume_file and PathManager.exists(ckpt_config.resume_file):
            ckpt_filepath = ckpt_config.resume_file
        elif ckpt_config.resume_zoo is not None:
            is_zoo = True
            folder = download_pretrained_model(ckpt_config.resume_zoo)
            ckpt_filepath = get_ckpt_path_from_folder(folder)
            ckpt = get_ckpt_from_path(ckpt_filepath)
            config = get_config_from_folder_or_ckpt(folder, ckpt)
        else:
            raise RuntimeError(f"{ckpt_config.resume_file} doesn't exist")

    if ckpt_config.resume and PathManager.exists(path):
        ckpt_filepath = path

    if ckpt_filepath is not None:
        ckpt = get_ckpt_from_path(ckpt_filepath)

    return {
        "ckpt": ckpt,
        "checkpoint_path": ckpt_filepath,
        "is_zoo": is_zoo,
        "config": config,
    }

def __init__(self, config, name=None):
    self.logger = None
    self._is_master = is_master()

    self.timer = Timer()
    self.config = config
    self.save_dir = get_mmf_env(key="save_dir")
    self.log_format = config.training.log_format
    self.time_format = "%Y-%m-%dT%H:%M:%S"
    self.log_filename = "train_"
    self.log_filename += self.timer.get_time_hhmmss(None, format=self.time_format)
    self.log_filename += ".log"

    self.log_folder = os.path.join(self.save_dir, "logs")

    env_log_dir = get_mmf_env(key="log_dir")
    if env_log_dir:
        self.log_folder = env_log_dir

    if not PathManager.exists(self.log_folder):
        PathManager.mkdirs(self.log_folder)

    self.log_filename = os.path.join(self.log_folder, self.log_filename)

    if not self._is_master:
        return
    if self._is_master:
        print("Logging to:", self.log_filename)

    logging.captureWarnings(True)

    if not name:
        name = __name__

    self.logger = logging.getLogger(name)
    self._file_only_logger = logging.getLogger(name)
    warnings_logger = logging.getLogger("py.warnings")

    # Set level
    level = config.training.logger_level
    self.logger.setLevel(getattr(logging, level.upper()))
    self._file_only_logger.setLevel(getattr(logging, level.upper()))

    formatter = logging.Formatter(
        "%(asctime)s %(levelname)s: %(message)s", datefmt="%Y-%m-%dT%H:%M:%S"
    )

    # Add handler to file
    channel = logging.FileHandler(filename=self.log_filename, mode="a")
    channel.setFormatter(formatter)

    self.logger.addHandler(channel)
    self._file_only_logger.addHandler(channel)
    warnings_logger.addHandler(channel)

    # Add handler to stdout
    channel = logging.StreamHandler(sys.stdout)
    channel.setFormatter(formatter)

    self.logger.addHandler(channel)
    warnings_logger.addHandler(channel)

    should_not_log = self.config.training.should_not_log
    self.should_log = not should_not_log

    # Single log wrapper map
    self._single_log_map = set()

def __init__(self, config, name=None):
    self._logger = None
    self._is_master = is_master()

    self.timer = Timer()
    self.config = config
    self.save_dir = get_mmf_env(key="save_dir")
    self.log_format = config.training.log_format
    self.time_format = "%Y_%m_%dT%H_%M_%S"
    self.log_filename = "train_"
    self.log_filename += self.timer.get_time_hhmmss(None, format=self.time_format)
    self.log_filename += ".log"

    self.log_folder = os.path.join(self.save_dir, "logs")

    env_log_dir = get_mmf_env(key="log_dir")
    if env_log_dir:
        self.log_folder = env_log_dir

    if not PathManager.exists(self.log_folder):
        PathManager.mkdirs(self.log_folder)

    self.log_filename = os.path.join(self.log_folder, self.log_filename)

    if not self._is_master:
        return
    if self._is_master:
        print("Logging to:", self.log_filename)

    logging.captureWarnings(True)

    if not name:
        name = __name__

    self._logger = logging.getLogger(name)
    self._file_only_logger = logging.getLogger(name)
    self._warnings_logger = logging.getLogger("py.warnings")

    # Set level
    level = config.training.logger_level
    self._logger.setLevel(getattr(logging, level.upper()))
    self._file_only_logger.setLevel(getattr(logging, level.upper()))

    # Capture stdout to logger
    self._stdout_logger = None
    if self.config.training.stdout_capture:
        self._stdout_logger = StreamToLogger(
            logging.getLogger("stdout"), getattr(logging, level.upper())
        )
        sys.stdout = self._stdout_logger

    formatter = logging.Formatter(
        "%(asctime)s | %(levelname)s | %(name)s : %(message)s",
        datefmt="%Y-%m-%dT%H:%M:%S",
    )

    # Add handler to file
    channel = logging.FileHandler(filename=self.log_filename, mode="a")
    channel.setFormatter(formatter)
    self.add_handlers(channel)

    # Add handler to train.log. train.log is full log that is also used
    # by slurm/fbl output
    channel = logging.FileHandler(
        filename=os.path.join(self.save_dir, "train.log"), mode="a"
    )
    channel.setFormatter(formatter)
    self.add_handlers(channel)

    # Add handler to stdout. Only when we are not capturing stdout in
    # the logger
    if not self._stdout_logger:
        channel = logging.StreamHandler(sys.stdout)
        channel.setFormatter(formatter)

        self._logger.addHandler(channel)
        self._warnings_logger.addHandler(channel)

    should_not_log = self.config.training.should_not_log
    self.should_log = not should_not_log

    # Single log wrapper map
    self._single_log_map = set()

def setup_logger(
    output: str = None,
    color: bool = True,
    name: str = "mmf",
    disable: bool = False,
    clear_handlers=True,
    *args,
    **kwargs,
):
    """
    Initialize the MMF logger and set its verbosity level to "INFO".
    Outside libraries shouldn't call this in case they have set their own logging
    handlers and setup. If they do, and don't want to clear the handlers, pass the
    clear_handlers option.

    The initial version of this function was taken from D2 and adapted for MMF.

    Args:
        output (str): a file name or a directory to save log.
            If ends with ".txt" or ".log", assumed to be a file name.
            Default: Saved to file <save_dir/logs/train_<timestamp>.log>
        color (bool): If false, won't log colored logs. Default: true
        name (str): the root module name of this logger. Defaults to "mmf".
        clear_handlers (bool): If false, won't clear existing handlers.

    Returns:
        logging.Logger: a logger
    """
    if disable:
        return None
    logger = logging.getLogger(name)
    logger.propagate = False

    logging.captureWarnings(True)
    warnings_logger = logging.getLogger("py.warnings")

    plain_formatter = logging.Formatter(
        "%(asctime)s | %(levelname)s | %(name)s : %(message)s",
        datefmt="%Y-%m-%dT%H:%M:%S",
    )

    distributed_rank = get_rank()
    handlers = []

    logging_level = registry.get("config").training.logger_level.upper()
    if distributed_rank == 0:
        logger.setLevel(logging_level)
        ch = logging.StreamHandler(stream=sys.stdout)
        ch.setLevel(logging_level)
        if color:
            formatter = ColorfulFormatter(
                colored("%(asctime)s | %(name)s: ", "green") + "%(message)s",
                datefmt="%Y-%m-%dT%H:%M:%S",
            )
        else:
            formatter = plain_formatter
        ch.setFormatter(formatter)
        logger.addHandler(ch)
        warnings_logger.addHandler(ch)
        handlers.append(ch)

    # file logging: all workers
    if output is None:
        output = setup_output_folder()

    if output is not None:
        if output.endswith(".txt") or output.endswith(".log"):
            filename = output
        else:
            filename = os.path.join(output, "train.log")
        if distributed_rank > 0:
            filename = filename + f".rank{distributed_rank}"
        PathManager.mkdirs(os.path.dirname(filename))

        fh = logging.StreamHandler(_cached_log_stream(filename))
        fh.setLevel(logging_level)
        fh.setFormatter(plain_formatter)
        logger.addHandler(fh)
        warnings_logger.addHandler(fh)
        handlers.append(fh)

        # Slurm/FB output, only log the main process
        if "train.log" not in filename and distributed_rank == 0:
            save_dir = get_mmf_env(key="save_dir")
            filename = os.path.join(save_dir, "train.log")
            sh = logging.StreamHandler(_cached_log_stream(filename))
            sh.setLevel(logging_level)
            sh.setFormatter(plain_formatter)
            logger.addHandler(sh)
            warnings_logger.addHandler(sh)
            handlers.append(sh)

        logger.info(f"Logging to: {filename}")

    # Remove existing handlers to add MMF specific handlers
    if clear_handlers:
        for handler in logging.root.handlers[:]:
            logging.root.removeHandler(handler)
    # Now, add our handlers.
    logging.basicConfig(level=logging_level, handlers=handlers)

    registry.register("writer", logger)

    return logger

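# Hedged usage sketch: setup_logger reads training.logger_level from the registered
# config, so the config must already be in the registry. With no `output` argument the
# log file path comes from setup_output_folder(); rank 0 additionally logs to stdout
# and to <save_dir>/train.log.
logger = setup_logger()
logger.info("Starting training run")
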