def __init__(self, multi_task_instance):
    self.test_task = multi_task_instance
    self.task_type = multi_task_instance.dataset_type
    self.config = registry.get("config")
    self.report = []
    self.timer = Timer()
    self.training_config = self.config.training
    self.num_workers = self.training_config.num_workers
    self.batch_size = self.training_config.batch_size
    self.report_folder_arg = get_mmf_env(key="report_dir")
    self.experiment_name = self.training_config.experiment_name

    self.datasets = []
    for dataset in self.test_task.get_datasets():
        self.datasets.append(dataset)

    self.current_dataset_idx = -1
    self.current_dataset = self.datasets[self.current_dataset_idx]

    self.save_dir = get_mmf_env(key="save_dir")
    self.report_folder = ckpt_name_from_core_args(self.config)
    self.report_folder += foldername_from_config_override(self.config)

    self.report_folder = os.path.join(self.save_dir, self.report_folder)
    self.report_folder = os.path.join(self.report_folder, "reports")

    if self.report_folder_arg:
        self.report_folder = self.report_folder_arg

    PathManager.mkdirs(self.report_folder)
def setup_output_folder(folder_only: bool = False):
    """Sets up and returns the output file where the logs will be placed
    based on the configuration passed. Usually "save_dir/logs/train_<timestamp>.log".
    If env.log_dir is passed, logs will be directly saved in this folder.

    Args:
        folder_only (bool, optional): If True, return the folder instead of
            the file path. Defaults to False.

    Returns:
        str: folder or file path depending on folder_only flag
    """
    save_dir = get_mmf_env(key="save_dir")
    time_format = "%Y_%m_%dT%H_%M_%S"
    log_filename = "train_"
    log_filename += Timer().get_time_hhmmss(None, format=time_format)
    log_filename += ".log"

    log_folder = os.path.join(save_dir, "logs")

    env_log_dir = get_mmf_env(key="log_dir")
    if env_log_dir:
        log_folder = env_log_dir

    if not PathManager.exists(log_folder):
        PathManager.mkdirs(log_folder)

    if folder_only:
        return log_folder

    log_filename = os.path.join(log_folder, log_filename)

    return log_filename
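# A minimal usage sketch (not part of the library): assumes MMF's env/config
# has been initialized so get_mmf_env() resolves "save_dir"; the paths in the
# trailing comments are illustrative, not guaranteed outputs.
log_file = setup_output_folder()                 # e.g. "<save_dir>/logs/train_<timestamp>.log"
log_dir = setup_output_folder(folder_only=True)  # e.g. "<save_dir>/logs" (or env.log_dir)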
def make_dir(path):
    """
    Make the directory and any nonexistent parent directories (`mkdir -p`).
    """
    # the current working directory is a fine path
    if path != "":
        PathManager.mkdirs(path)
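# Hypothetical usage: behaves like `mkdir -p`, and the empty string is a
# no-op since the current working directory already exists.
make_dir("output/features/coco")  # illustrative path
make_dir("")                      # safe: nothing to create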
def resolve_cache_dir(env_variable="MMF_CACHE_DIR", default="mmf"):
    # Some of this follows what "transformers" does for its cache resolution
    try:
        from torch.hub import _get_torch_home

        torch_cache_home = _get_torch_home()
    except ImportError:
        torch_cache_home = os.path.expanduser(
            os.getenv(
                "TORCH_HOME",
                os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch"),
            )
        )
    default_cache_path = os.path.join(torch_cache_home, default)

    cache_path = os.getenv(env_variable, default_cache_path)

    if not PathManager.exists(cache_path):
        try:
            PathManager.mkdirs(cache_path)
        except PermissionError:
            # Fall back to a cache inside the MMF source tree when the
            # default location isn't writable
            cache_path = os.path.join(get_mmf_root(), ".mmf_cache")
            PathManager.mkdirs(cache_path)

    return cache_path
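# Usage sketch: the env variable, when set, takes precedence over the default
# "<torch cache>/mmf" location; the directory is created on first use.
os.environ["MMF_CACHE_DIR"] = "/tmp/mmf_cache"  # illustrative override
cache_dir = resolve_cache_dir()                 # -> "/tmp/mmf_cache"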
def resolve_dir(env_variable, default="data"):
    default_dir = os.path.join(resolve_cache_dir(), default)
    dir_path = os.getenv(env_variable, default_dir)

    if not PathManager.exists(dir_path):
        PathManager.mkdirs(dir_path)

    return dir_path
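# Usage sketch: resolve the data directory, falling back to "<cache_dir>/data"
# when the env variable is unset. "MMF_DATA_DIR" is the variable MMF
# conventionally uses here; treat the name as an assumption of this example.
data_dir = resolve_dir("MMF_DATA_DIR")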
def _download_model(self):
    _is_master = is_master()

    model_file_path = os.path.join(get_mmf_cache_dir(), "wiki.en.bin")

    # Only the master process downloads; other ranks immediately return
    # the path where the vectors are expected to appear.
    if not _is_master:
        return model_file_path

    if PathManager.exists(model_file_path):
        logger.info(f"Vectors already present at {model_file_path}.")
        return model_file_path

    # Imported lazily so these stay optional dependencies.
    import requests
    from tqdm import tqdm

    from VisualBERT.mmf.common.constants import FASTTEXT_WIKI_URL

    PathManager.mkdirs(os.path.dirname(model_file_path))

    response = requests.get(FASTTEXT_WIKI_URL, stream=True)

    with PathManager.open(model_file_path, "wb") as f:
        # Track download progress in bytes so the bar's total and its
        # updates use the same unit.
        pbar = tqdm(
            total=int(response.headers["Content-Length"]),
            miniters=50,
            disable=not _is_master,
        )

        for data in response.iter_content(chunk_size=4096):
            if data:
                pbar.update(len(data))
                f.write(data)

        pbar.close()

    logger.info(f"fastText bin downloaded at {model_file_path}.")

    return model_file_path
def __init__(self, trainer):
    """
    Generates a path for saving the model, which can also be used to
    resume from a checkpoint.
    """
    self.trainer = trainer

    self.config = self.trainer.config
    self.save_dir = get_mmf_env(key="save_dir")
    self.model_name = self.config.model

    self.ckpt_foldername = self.save_dir

    self.device = get_current_device()

    self.ckpt_prefix = ""

    if hasattr(self.trainer.model, "get_ckpt_name"):
        self.ckpt_prefix = self.trainer.model.get_ckpt_name() + "_"

    self.pth_filepath = os.path.join(
        self.ckpt_foldername, self.ckpt_prefix + self.model_name + "_final.pth"
    )

    self.models_foldername = os.path.join(self.ckpt_foldername, "models")
    if not PathManager.exists(self.models_foldername):
        PathManager.mkdirs(self.models_foldername)

    self.save_config()

    self.repo_path = updir(os.path.abspath(__file__), n=3)
    self.git_repo = None
    if git and self.config.checkpoint.save_git_details:
        try:
            self.git_repo = git.Repo(self.repo_path)
        except git.exc.InvalidGitRepositoryError:
            # Not a git repo, don't do anything
            pass

    self.max_to_keep = self.config.checkpoint.max_to_keep
    self.saved_iterations = []
def setup_logger(
    output: str = None,
    color: bool = True,
    name: str = "mmf",
    disable: bool = False,
    clear_handlers=True,
    *args,
    **kwargs,
):
    """
    Initialize the MMF logger and set its verbosity level to "INFO".
    Outside libraries shouldn't call this in case they have set their own
    logging handlers and setup. If they do and don't want their handlers
    cleared, they should pass clear_handlers=False.

    The initial version of this function was taken from D2 and adapted for MMF.

    Args:
        output (str): a file name or a directory to save logs. If it ends
            with ".txt" or ".log", it is assumed to be a file name.
            Default: saved to <save_dir>/logs/train_[timestamp].log
        color (bool): If False, won't log colored logs. Default: True
        name (str): the root module name of this logger. Defaults to "mmf".
        disable (bool): If True, skip setup entirely and return None.
        clear_handlers (bool): If False, won't clear existing handlers.

    Returns:
        logging.Logger: a logger
    """
    if disable:
        return None
    logger = logging.getLogger(name)
    logger.propagate = False

    logging.captureWarnings(True)
    warnings_logger = logging.getLogger("py.warnings")

    plain_formatter = logging.Formatter(
        "%(asctime)s | %(levelname)s | %(name)s : %(message)s",
        datefmt="%Y-%m-%dT%H:%M:%S",
    )

    distributed_rank = get_rank()
    handlers = []

    # console logging: main process only
    if distributed_rank == 0:
        logger.setLevel(logging.INFO)
        ch = logging.StreamHandler(stream=sys.stdout)
        ch.setLevel(logging.INFO)
        if color:
            formatter = ColorfulFormatter(
                colored("%(asctime)s | %(name)s: ", "green") + "%(message)s",
                datefmt="%Y-%m-%dT%H:%M:%S",
            )
        else:
            formatter = plain_formatter
        ch.setFormatter(formatter)
        logger.addHandler(ch)
        warnings_logger.addHandler(ch)
        handlers.append(ch)

    # file logging: all workers
    if output is None:
        output = setup_output_folder()

    if output is not None:
        if output.endswith(".txt") or output.endswith(".log"):
            filename = output
        else:
            filename = os.path.join(output, "train.log")
        if distributed_rank > 0:
            filename = filename + f".rank{distributed_rank}"
        PathManager.mkdirs(os.path.dirname(filename))

        fh = logging.StreamHandler(_cached_log_stream(filename))
        fh.setLevel(logging.INFO)
        fh.setFormatter(plain_formatter)
        logger.addHandler(fh)
        warnings_logger.addHandler(fh)
        handlers.append(fh)

        # Slurm/FB output, only log the main process
        if "train.log" not in filename and distributed_rank == 0:
            save_dir = get_mmf_env(key="save_dir")
            filename = os.path.join(save_dir, "train.log")
            sh = logging.StreamHandler(_cached_log_stream(filename))
            sh.setLevel(logging.INFO)
            sh.setFormatter(plain_formatter)
            logger.addHandler(sh)
            warnings_logger.addHandler(sh)
            handlers.append(sh)

        logger.info(f"Logging to: {filename}")

    # Remove existing handlers to add MMF specific handlers
    if clear_handlers:
        for handler in logging.root.handlers[:]:
            logging.root.removeHandler(handler)
    # Now, add our handlers.
    logging.basicConfig(level=logging.INFO, handlers=handlers)

    registry.register("writer", logger)

    return logger
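# Usage sketch (assumes MMF env/config is initialized so the default output
# folder can be resolved): set up logging once at startup, then log normally.
logger = setup_logger()        # console handler on rank 0 + file handlers
logger.info("setup complete")

# Or direct logs to an explicit file and disable colored console output:
logger = setup_logger(output="./run.log", color=False)  # illustrative path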
def test_file_io_mkdirs(self):
    dir_path = os.path.join(self._tmpdir, "test_dir")
    PathManager.mkdirs(dir_path)
    self.assertTrue(os.path.isdir(dir_path))
def convert(self):
    config = self.configuration.get_config()
    data_dir = config.env.data_dir

    if self.args.mmf_data_folder:
        data_dir = self.args.mmf_data_folder

    bypass_checksum = False
    if self.args.bypass_checksum:
        bypass_checksum = bool(self.args.bypass_checksum)

    print(f"Data folder is {data_dir}")
    print(f"Zip path is {self.args.zip_file}")

    base_path = os.path.join(data_dir, "datasets", "hateful_memes", "defaults")

    images_path = os.path.join(base_path, "images")
    PathManager.mkdirs(images_path)

    move_dir = False
    if self.args.move:
        move_dir = bool(self.args.move)

    if not bypass_checksum:
        self.checksum(self.args.zip_file, self.POSSIBLE_CHECKSUMS)

    src = self.args.zip_file
    dest = images_path
    if move_dir:
        print(f"Moving {src}")
        move(src, dest)
    else:
        print(f"Copying {src}")
        copy(src, dest)

    print(f"Unzipping {src}")
    self.decompress_zip(
        dest, fname=os.path.basename(src), password=self.args.password
    )

    phase_one = self.assert_files(images_path)

    annotations_path = os.path.join(base_path, "annotations")
    PathManager.mkdirs(annotations_path)
    annotations = (
        self.JSONL_PHASE_ONE_FILES
        if phase_one is True
        else self.JSONL_PHASE_TWO_FILES
    )

    for annotation in annotations:
        print(f"Moving {annotation}")
        src = os.path.join(images_path, "data", annotation)
        dest = os.path.join(annotations_path, annotation)
        move(src, dest)

    images = self.IMAGE_FILES

    for image_file in images:
        src = os.path.join(images_path, "data", image_file)
        if PathManager.exists(src):
            print(f"Moving {image_file}")
        else:
            continue

        dest = os.path.join(images_path, image_file)
        move(src, dest)
        if src.endswith(".tar.gz"):
            decompress(dest, fname=image_file, delete_original=False)