def remove(self, update):
    """Delete the per-update checkpoint file for ``update`` if it exists.

    Args:
        update: Update number formatted into the ``model_<update>.ckpt``
            file name that is looked up inside ``self.models_foldername``.
    """
    checkpoint_name = "model_%d.ckpt" % update
    target = os.path.join(self.models_foldername, checkpoint_name)
    if not PathManager.isfile(target):
        # Nothing on disk for this update, so there is nothing to delete.
        return
    PathManager.rm(target)
def setup_logger(
    output: str = None,
    color: bool = True,
    name: str = "mmf",
    disable: bool = False,
    clear_handlers=True,
    *args,
    **kwargs,
):
    """
    Create the MMF logger, attach its handlers and register it as "writer".

    Every handler is set to the "INFO" level. Outside libraries shouldn't
    call this if they have configured their own logging handlers; if they do
    and want to keep those handlers, they should pass ``clear_handlers=False``.

    Args:
        output (str): a file name or a directory to save the log to. If it
            ends with ".txt" or ".log" it is used as the file name, otherwise
            "train.log" inside that directory is used. When ``None``, the
            value returned by ``setup_output_folder()`` is used instead.
        color (bool): If false, the console handler uses the plain formatter
            instead of the colored one. Default: true
        name (str): the name of the logger to fetch. Defaults to "mmf".
        disable (bool): If true, nothing is configured and ``None`` is
            returned.
        clear_handlers (bool): If false, handlers already attached to the
            root logger are left in place.

    Returns:
        logging.Logger: the configured logger, or ``None`` when ``disable``
        is set.
    """
    if disable:
        return None

    logger = logging.getLogger(name)
    logger.propagate = False

    # Route ``warnings`` output through logging so it reaches our handlers.
    logging.captureWarnings(True)
    warnings_logger = logging.getLogger("py.warnings")

    date_format = "%Y-%m-%dT%H:%M:%S"
    plain_formatter = logging.Formatter(
        "%(asctime)s | %(levelname)s | %(name)s : %(message)s",
        datefmt=date_format,
    )

    distributed_rank = get_rank()
    handlers = []

    def _attach(handler, formatter):
        # Shared wiring for every handler created below.
        handler.setLevel(logging.INFO)
        handler.setFormatter(formatter)
        logger.addHandler(handler)
        warnings_logger.addHandler(handler)
        handlers.append(handler)

    # Console logging: rank 0 only.
    if distributed_rank == 0:
        logger.setLevel(logging.INFO)
        console = logging.StreamHandler(stream=sys.stdout)
        if not color:
            console_formatter = plain_formatter
        else:
            console_formatter = ColorfulFormatter(
                colored("%(asctime)s | %(name)s: ", "green") + "%(message)s",
                datefmt=date_format,
            )
        _attach(console, console_formatter)

    # file logging: all workers
    if output is None:
        output = setup_output_folder()

    if output is not None:
        if output.endswith((".txt", ".log")):
            filename = output
        else:
            filename = os.path.join(output, "train.log")
        if distributed_rank > 0:
            # Non-zero ranks get a rank-suffixed file of their own.
            filename += f".rank{distributed_rank}"
        PathManager.mkdirs(os.path.dirname(filename))

        _attach(logging.StreamHandler(_cached_log_stream(filename)), plain_formatter)

        # Slurm/FB output, only log the main process
        if distributed_rank == 0 and "train.log" not in filename:
            save_dir = get_mmf_env(key="save_dir")
            filename = os.path.join(save_dir, "train.log")
            _attach(
                logging.StreamHandler(_cached_log_stream(filename)),
                plain_formatter,
            )

        logger.info(f"Logging to: {filename}")

    # Remove existing handlers to add MMF specific handlers
    if clear_handlers:
        for existing in list(logging.root.handlers):
            logging.root.removeHandler(existing)
    # Now, add our handlers.
    logging.basicConfig(level=logging.INFO, handlers=handlers)

    registry.register("writer", logger)

    return logger
def save(self, update, iteration=None, update_best=False):
    """Write the current training state to the checkpoint files.

    The same checkpoint dictionary is written to ``model_<update>.ckpt``
    inside ``self.models_foldername``, to the "current" checkpoint, and,
    when ``update_best`` is set, to the "best" checkpoint as well.

    Args:
        update: Number of updates; used in the per-update file name and
            stored as ``num_updates``.
        iteration: Iteration stored as ``current_iteration``. A falsy value
            falls back to ``update``.
        update_best: When true, also overwrite the "best" checkpoint.
    """
    # Only save in main process
    # For xla we use xm.save method
    # Which ensures that actual checkpoint saving happens
    # only for the master node.
    # The method also takes care of all the necessary synchronization
    if not (is_master() or is_xla()):
        return

    logger.info("Checkpoint save operation started!")
    iteration = iteration or update

    per_update_path = os.path.join(
        self.models_foldername, "model_%d.ckpt" % update
    )
    best_path = os.path.join(self.ckpt_foldername, self.ckpt_prefix + "best.ckpt")
    current_path = os.path.join(
        self.ckpt_foldername, self.ckpt_prefix + "current.ckpt"
    )

    early_stopping = self.trainer.early_stop_callback.early_stopping
    best_iteration = early_stopping.best_monitored_iteration
    best_update = early_stopping.best_monitored_update
    best_metric = early_stopping.best_monitored_value

    model = self.trainer.model
    data_parallel = registry.get("data_parallel") or registry.get("distributed")

    fp16_scaler = getattr(self.trainer, "scaler", None)
    fp16_scaler_dict = (
        fp16_scaler.state_dict() if fp16_scaler is not None else None
    )

    if data_parallel is True:
        # Save the wrapped module's weights rather than the wrapper's.
        model = model.module

    ckpt = {
        "model": model.state_dict(),
        "optimizer": self.trainer.optimizer.state_dict(),
        "best_iteration": best_iteration,
        "current_iteration": iteration,
        "current_epoch": self.trainer.current_epoch,
        "num_updates": update,
        "best_update": best_update,
        "best_metric_value": best_metric,
        "fp16_scaler": fp16_scaler_dict,
        # Convert to container to avoid any dependencies
        "config": OmegaConf.to_container(self.config, resolve=True),
    }

    lr_scheduler = self.trainer.lr_scheduler_callback._scheduler
    if lr_scheduler is not None:
        ckpt["lr_scheduler"] = lr_scheduler.state_dict()

    if self.git_repo:
        ckpt.update(self._get_vcs_fields())

    def _write(destination):
        # Serialize the assembled checkpoint to a single destination.
        with PathManager.open(destination, "wb") as f:
            self.save_func(ckpt, f)

    _write(per_update_path)

    if update_best:
        logger.info("Saving best checkpoint")
        _write(best_path)

    # Save current always
    logger.info("Saving current checkpoint")
    _write(current_path)

    # Remove old checkpoints if max_to_keep is set
    if self.max_to_keep > 0:
        if len(self.saved_iterations) == self.max_to_keep:
            oldest = self.saved_iterations.pop(0)
            self.remove(oldest)
        self.saved_iterations.append(update)

    logger.info("Checkpoint save operation finished!")
def default_loader(path):
    """Open the image stored at ``path`` and return it converted to "RGB"."""
    with PathManager.open(path, "rb") as stream:
        # Convert before leaving the ``with`` block, while the stream is open.
        return Image.open(stream).convert("RGB")
def _cached_log_stream(filename):
    """Open ``filename`` in append mode via ``PathManager`` and return the stream.

    NOTE(review): despite the name, no caching is visible on this definition;
    confirm whether a caching decorator is expected here.
    """
    stream = PathManager.open(filename, "a")
    return stream
from mmf.utils.configuration import get_mmf_cache_dir
from mmf.utils.file_io import PathManager
from mmf.datasets.processors.processors import EvalAIAnswerProcessor

# Locations of the OK-VQA annotation files and of the vocab output folder,
# both under the MMF cache directory.
okvqa_defaults = os.path.join(
    get_mmf_cache_dir(), "data", "datasets", "okvqa", "defaults"
)
root_dir = os.path.join(okvqa_defaults, "annotations")
out_dir = os.path.join(okvqa_defaults, "extras", "vocabs")
train_path = os.path.join(root_dir, "mscoco_train2014_annotations.json")
val_path = os.path.join(root_dir, "mscoco_val2014_annotations.json")
out_path = os.path.join(out_dir, "gt2raw_answers.json")

evalai_answer_processor = EvalAIAnswerProcessor()

# Collect the train annotations followed by the val annotations.
annotations = []
for annotation_path in (train_path, val_path):
    with PathManager.open(annotation_path, "r") as f:
        annotations += json.load(f)["annotations"]

# Map every processed ground-truth answer to the set of processed raw
# answers it was seen with.
gt2raw = {}
for ann in tqdm(annotations):
    for ans in ann["answers"]:
        raw_ans = evalai_answer_processor(ans["raw_answer"])
        gt_ans = evalai_answer_processor(ans["answer"])
        gt2raw.setdefault(gt_ans, set()).add(raw_ans)
def test_log_writer(self) -> None:
    """Check that text passed to ``self.writer.write`` ends up in train.log."""
    self.writer.write(self._tmpfile_write_contents)
    log_path = os.path.join(self._tmpdir, "train.log")
    # Open through a context manager so the handle is closed even when the
    # assertion fails; the original left the file open.
    with PathManager.open(log_path) as f:
        self.assertTrue(
            any(self._tmpfile_write_contents in line for line in f.readlines())
        )
def test_file_io_mkdirs(self):
    """``PathManager.mkdirs`` must create the requested directory on disk."""
    target_dir = os.path.join(self._tmpdir, "test_dir")
    PathManager.mkdirs(target_dir)
    created = os.path.isdir(target_dir)
    self.assertTrue(created)
if __name__ == "__main__":
    # Source/destination datasets and the vocab files read for each of them.
    src_dataset, dst_dataset = "vqa2", "okvqa"
    src_fname, dst_fname = "answers_vqa.txt", "answers_okvqa.txt"
    gt2raw_fname = "gt2raw_answers.json"

    use_raw = True
    if use_raw:
        use_raw_str = "_raw"
    else:
        use_raw_str = ""
    out_fname = src_dataset + "2" + dst_dataset + use_raw_str + ".json"

    def _vocab_dir(dataset):
        # Vocab folder of ``dataset`` under the MMF cache directory.
        return os.path.join(
            get_mmf_cache_dir(),
            "data",
            "datasets",
            dataset,
            "defaults",
            "extras",
            "vocabs",
        )

    def _read_lines(path):
        # Whole file split into lines, without the line terminators.
        with PathManager.open(path, "r") as handle:
            return handle.read().splitlines()

    src_dir = _vocab_dir(src_dataset)
    dst_dir = _vocab_dir(dst_dataset)

    src_vocab = _read_lines(os.path.join(src_dir, src_fname))
    dst_vocab = _read_lines(os.path.join(dst_dir, dst_fname))
    if use_raw:
        with PathManager.open(os.path.join(dst_dir, gt2raw_fname), "r") as f:
            gt2raw = json.load(f)

    # Answer string -> index in the source vocab.
    src_dict = dict(zip(src_vocab, range(len(src_vocab))))

    # Destination index -> source index for answers present in both vocabs.
    qa_map = {}
    # NOTE(review): ``count`` is not used in the visible part of the script.
    count = 0
    for idx, word in enumerate(dst_vocab):
        if word in src_dict:
            qa_map[idx] = src_dict[word]
def test_file_io_copy(self):
    """``PathManager.copy`` must produce a file with the source's contents."""
    copy_path = os.path.join(self._tmpdir, "test_copy.txt")
    PathManager.copy(self._tmpfile, copy_path)
    # Read the copy back with the builtin ``open``.
    with open(copy_path, "r") as copied:
        copied_text = copied.read()
    self.assertEqual(copied_text, self._tmpfile_contents)
def test_file_io_exists(self):
    """``PathManager.exists`` must agree with ``os.path.exists``.

    Checked for the temp file and for a freshly generated random name inside
    the temp directory.
    """
    random_path = os.path.join(self._tmpdir, uuid.uuid4().hex)
    for candidate in (self._tmpfile, random_path):
        self.assertEqual(
            PathManager.exists(candidate), os.path.exists(candidate)
        )
def test_file_io_open(self):
    """Reading the temp file via ``PathManager.open`` yields its contents."""
    with PathManager.open(self._tmpfile, mode="r") as handle:
        contents = handle.read()
    self.assertEqual(contents, self._tmpfile_contents)
def convert(self):
    """Unpack the dataset zip into ``images`` and ``annotations`` folders.

    Steps, in order: resolve the data folder, optionally verify the zip
    checksum, move or copy the zip into ``<data_dir>/images``, unzip it, move
    the annotation files into ``<data_dir>/annotations`` and move the image
    files that exist up into the ``images`` folder. Image files ending in
    ".tar.gz" are additionally decompressed.
    """
    config = self.configuration.get_config()
    default_data_dir = config.env.data_dir
    # A data folder given on the command line wins over the configured one.
    data_dir = self.args.mmf_data_folder or default_data_dir

    bypass_checksum = bool(self.args.bypass_checksum)

    print(f"Data folder is {data_dir}")
    print(f"Zip path is {self.args.zip_file}")

    base_path = data_dir

    images_path = os.path.join(base_path, "images")
    PathManager.mkdirs(images_path)

    move_dir = bool(self.args.move)

    if not bypass_checksum:
        self.checksum(self.args.zip_file, self.POSSIBLE_CHECKSUMS)

    src = self.args.zip_file
    dest = images_path
    if not move_dir:
        print(f"Copying {src}")
        copy(src, dest)
    else:
        print(f"Moving {src}")
        move(src, dest)

    print(f"Unzipping {src}")
    zip_name = os.path.basename(src)
    self.decompress_zip(dest, fname=zip_name, password=self.args.password)

    phase_one = self.assert_files(images_path)

    annotations_path = os.path.join(base_path, "annotations")
    PathManager.mkdirs(annotations_path)
    if phase_one is True:
        annotations = self.JSONL_PHASE_ONE_FILES
    else:
        annotations = self.JSONL_PHASE_TWO_FILES

    extracted_dir = os.path.join(images_path, "data")

    for annotation in annotations:
        print(f"Moving {annotation}")
        src = os.path.join(extracted_dir, annotation)
        dest = os.path.join(annotations_path, annotation)
        move(src, dest)

    for image_file in self.IMAGE_FILES:
        src = os.path.join(extracted_dir, image_file)
        if not PathManager.exists(src):
            # Image files that are not present are skipped.
            continue
        print(f"Moving {image_file}")
        dest = os.path.join(images_path, image_file)
        move(src, dest)
        if src.endswith(".tar.gz"):
            decompress(dest, fname=image_file, delete_original=False)
def __init__(self, vocab_file=None, embedding_dim=300, data_dir=None, *args, **kwargs):
    """Vocab class to be used when you want to train word embeddings from
    scratch based on a custom vocab. An uninitialized ``torch.FloatTensor``
    with one row per vocabulary entry is allocated in ``self.vectors``.

    The predefined words PAD - <pad>, SOS - <s>, EOS - </s> and UNK - <unk>
    are always registered first; words from the file follow.

    Parameters
    ----------
    vocab_file : str
        Path of the vocabulary file containing one word per line
    embedding_dim : int
        Size of the embedding
    data_dir : str
        Folder joined in front of ``vocab_file`` when that path is relative

    Raises
    ------
    RuntimeError
        If ``vocab_file`` is given but does not exist.
    """
    self.type = "base"

    special_tokens = (
        (self.SOS_TOKEN, self.SOS_INDEX),
        (self.EOS_TOKEN, self.EOS_INDEX),
        (self.PAD_TOKEN, self.PAD_INDEX),
        (self.UNK_TOKEN, self.UNK_INDEX),
    )
    self.word_dict = dict(special_tokens)
    self.itos = {
        self.PAD_INDEX: self.PAD_TOKEN,
        self.SOS_INDEX: self.SOS_TOKEN,
        self.EOS_INDEX: self.EOS_TOKEN,
        self.UNK_INDEX: self.UNK_TOKEN,
    }

    self.total_predefined = len(self.itos)

    if vocab_file is not None:
        if data_dir is not None and not os.path.isabs(vocab_file):
            vocab_file = os.path.join(data_dir, vocab_file)
            vocab_file = get_absolute_path(vocab_file)

        if not PathManager.exists(vocab_file):
            raise RuntimeError("Vocab not found at " + vocab_file)

        with PathManager.open(vocab_file, "r") as f:
            # File words are numbered right after the predefined tokens.
            for index, line in enumerate(f, start=self.total_predefined):
                word = line.strip()
                self.itos[index] = word
                self.word_dict[word] = index

    # Re-apply the predefined mappings so a file word equal to one of the
    # special tokens cannot override its index.
    self.word_dict.update(special_tokens)

    # Return unk index by default
    self.stoi = defaultdict(self.get_unk_index)
    self.stoi.update(self.word_dict)

    self.vectors = torch.FloatTensor(self.get_size(), embedding_dim)
def finalize(self):
    """Write the trainer model's ``state_dict`` to ``self.pth_filepath``.

    Does nothing unless ``is_master()`` or ``is_xla()`` is true.
    """
    if not (is_master() or is_xla()):
        return
    with PathManager.open(self.pth_filepath, "wb") as out_file:
        self.save_func(self.trainer.model.state_dict(), out_file)
def download(url, path, fname, redownload=True, disable_tqdm=False):
    """
    Download file using `requests`.

    The data is streamed into ``<path>/<fname>.part`` and moved to
    ``<path>/<fname>`` once the transfer finished. An existing ``.part`` file
    is resumed with a ``Range`` request. Connection errors and read timeouts
    are retried up to five times with a growing sleep between attempts.

    If ``redownload`` is set to false, then will not download tar file again
    if it is present (default ``True``).

    Args:
        url: Address to fetch.
        path: Folder the file is stored in.
        fname: Name of the stored file.
        redownload: Download even when the target file already exists.
        disable_tqdm: Suppress the progress bar and the "Downloading" banner.

    Returns:
        Whether download actually happened or not.

    Raises:
        RuntimeWarning: When all retries failed, or when fewer bytes than
            announced in ``Content-Length`` were received.
    """
    outfile = os.path.join(path, fname)
    download = not PathManager.isfile(outfile) or redownload

    retry = 5
    # Indexed by the number of retries left: fewer retries -> longer sleep.
    exp_backoff = [2**r for r in reversed(range(retry))]

    pbar = None
    if download:
        # First test if the link is actually downloadable
        check_header(url)
        if not disable_tqdm:
            print("[ Downloading: " + url + " to " + outfile + " ]")
        pbar = tqdm.tqdm(
            unit="B",
            unit_scale=True,
            desc=f"Downloading {fname}",
            disable=disable_tqdm,
        )

    while download and retry >= 0:
        resume_file = outfile + ".part"
        resume = PathManager.isfile(resume_file)
        if resume:
            resume_pos = os.path.getsize(resume_file)
            mode = "ab"
        else:
            resume_pos = 0
            mode = "wb"
        response = None

        with requests.Session() as session:
            try:
                header = (
                    {
                        "Range": "bytes=%d-" % resume_pos,
                        "Accept-Encoding": "identity",
                    }
                    if resume
                    else {}
                )
                response = session.get(url, stream=True, timeout=5, headers=header)

                # negative reply could be 'none' or just missing
                if resume and response.headers.get("Accept-Ranges", "none") == "none":
                    resume_pos = 0
                    mode = "wb"

                CHUNK_SIZE = 32768
                total_size = int(response.headers.get("Content-Length", -1))
                # server returns remaining size if resuming, so adjust total
                total_size += resume_pos
                pbar.total = total_size
                done = resume_pos

                with PathManager.open(resume_file, mode) as f:
                    for chunk in response.iter_content(CHUNK_SIZE):
                        if chunk:  # filter out keep-alive new chunks
                            f.write(chunk)
                        if total_size > 0:
                            done += len(chunk)
                            if total_size < done:
                                # don't freak out if content-length was too small
                                total_size = done
                                pbar.total = total_size
                            pbar.update(len(chunk))
                    break
            except (
                requests.exceptions.ConnectionError,
                requests.exceptions.ReadTimeout,
            ):
                retry -= 1
                pbar.clear()
                if retry >= 0:
                    print("Connection error, retrying. (%d retries left)" % retry)
                    time.sleep(exp_backoff[retry])
                else:
                    print("Retried too many times, stopped retrying.")
            finally:
                # Compare against None: a Response is falsy for status codes
                # >= 400, so a plain truth test would leave those open.
                if response is not None:
                    response.close()

    if retry < 0:
        raise RuntimeWarning("Connection broken too many times. Stopped retrying.")

    # ``retry`` is >= 0 here. A download that succeeded on the last allowed
    # attempt (retry == 0) must be finalized too, so only ``download`` is
    # checked.
    if download:
        pbar.update(done - pbar.n)
        if done < total_size:
            raise RuntimeWarning(
                "Received less data than specified in "
                + "Content-Length header for "
                + url
                + ". There may be a download problem."
            )
        move(resume_file, outfile)

    if pbar:
        pbar.close()

    return download
def csv_dump(self, filepath):
    """Write ``self.report`` to ``filepath`` as CSV.

    The keys of the first report entry are used as the header row; every
    entry of the report becomes one data row.
    """
    with PathManager.open(filepath, "w") as out_file:
        fieldnames = self.report[0].keys()
        writer = csv.DictWriter(
            out_file,
            fieldnames,
            delimiter=",",
            quoting=csv.QUOTE_MINIMAL,
        )
        writer.writeheader()
        writer.writerows(self.report)
def test_save_and_load_state_dict(self):
    """Save checkpoints at several updates, then verify what gets restored.

    Covers: which files a plain save and a save with ``update_best`` create,
    loading with ``resume``, loading with ``resume_best``, and loading after
    the model has been wrapped in ``torch.nn.DataParallel``.
    """
    with mock_env_with_temp() as d:
        checkpoint = Checkpoint(self.trainer)
        self._init_early_stopping(checkpoint)
        self._do_a_pass()
        # Test normal case
        checkpoint.save(1500)

        # A plain save writes the per-update and "current" files, no "best".
        self.assertTrue(
            PathManager.exists(os.path.join(d, "models", "model_1500.ckpt")))
        self.assertTrue(PathManager.exists(os.path.join(d, "current.ckpt")))
        self.assertFalse(PathManager.exists(os.path.join(d, "best.ckpt")))
        # Remove them so the existence checks after the next save are
        # not satisfied by leftovers.
        os.remove(os.path.join(d, "models", "model_1500.ckpt"))
        os.remove(os.path.join(d, "current.ckpt"))

        # Snapshot the state that is about to be saved as "best".
        best_model = deepcopy(self.trainer.model)
        best_optimizer = deepcopy(self.trainer.optimizer)
        # Test with update_best
        checkpoint.save(2000, update_best=True)

        self.assertTrue(
            PathManager.exists(os.path.join(d, "models", "model_2000.ckpt")))
        self.assertTrue(PathManager.exists(os.path.join(d, "best.ckpt")))
        self.assertTrue(PathManager.exists(os.path.join(d, "current.ckpt")))

        # Another pass and save, so "current" is saved again after "best".
        self._do_a_pass()
        checkpoint.save(2500)

        # Test resume
        self.trainer.config.checkpoint.resume = True

        current_model = deepcopy(self.trainer.model)
        current_optimizer = deepcopy(self.trainer.optimizer)
        checkpoint.load_state_dict()

        # The loaded model must match the current snapshot, not the best one.
        self.assertFalse(
            compare_state_dicts(self.trainer.model.state_dict(),
                                best_model.state_dict()))
        self.assertTrue(
            compare_state_dicts(self.trainer.model.state_dict(),
                                current_model.state_dict()))

        # Optimizers only compare equal to the current snapshot, and only
        # when ``skip_keys=True`` is passed to the helper.
        self.assertFalse(
            self._compare_optimizers(self.trainer.optimizer, best_optimizer))
        self.assertFalse(
            self._compare_optimizers(self.trainer.optimizer,
                                     best_optimizer,
                                     skip_keys=True))
        self.assertFalse(
            self._compare_optimizers(self.trainer.optimizer,
                                     current_optimizer))
        self.assertTrue(
            self._compare_optimizers(self.trainer.optimizer,
                                     current_optimizer,
                                     skip_keys=True))

        # Kept for the DataParallel check at the end.
        base_0_weight_current = self.trainer.model.base[
            0].weight.data.clone()

        # Test resume_best
        self.trainer.config.checkpoint.resume = True
        self.trainer.config.checkpoint.resume_best = True

        checkpoint.load_state_dict()

        # Now the loaded model must match the best snapshot.
        self.assertTrue(
            compare_state_dicts(self.trainer.model.state_dict(),
                                best_model.state_dict()))
        self.assertFalse(
            self._compare_optimizers(self.trainer.optimizer, best_optimizer))
        self.assertTrue(
            self._compare_optimizers(self.trainer.optimizer,
                                     best_optimizer,
                                     skip_keys=True))
        self.assertFalse(
            self._compare_optimizers(self.trainer.optimizer,
                                     current_optimizer))
        self.assertFalse(
            self._compare_optimizers(self.trainer.optimizer,
                                     current_optimizer,
                                     skip_keys=True))

        base_0_weight_best = self.trainer.model.base[0].weight.data.clone()

        # Back to plain resume for the wrapped-model check below.
        self.trainer.config.checkpoint.resume_best = False
        # Test distributed settings
        self.trainer.model = torch.nn.DataParallel(self.trainer.model)
        checkpoint.load_state_dict()

        # With the model wrapped, the weights live under ``.module``; they
        # must equal the current weights, not the best ones.
        weight_to_be_tested = self.trainer.model.module.base[0].weight
        weight_device = weight_to_be_tested.device

        self.assertTrue(
            torch.equal(weight_to_be_tested,
                        base_0_weight_current.to(weight_device)))
        self.assertFalse(
            torch.equal(weight_to_be_tested,
                        base_0_weight_best.to(weight_device)))
def test_logger_files(self) -> None:
    """Both ``train.log`` and ``logs`` must exist inside the temp directory."""
    for entry in ("train.log", "logs"):
        entry_path = os.path.join(self._tmpdir, entry)
        self.assertTrue(PathManager.exists(entry_path))
def test_on_test_end(self):
    """``on_test_end`` must leave a "Finished run in" line in train.log."""
    self.cb.on_test_end(report=self.report, meter=self.trainer.meter)
    log_path = os.path.join(self.tmpdir, "train.log")
    # Open through a context manager so the handle is closed even when the
    # assertion fails; the original left the file open.
    with PathManager.open(log_path) as f:
        self.assertTrue(
            any("Finished run in" in line for line in f.readlines())
        )
def json_dump(self, filepath):
    """Serialize ``self.report`` as JSON into ``filepath``."""
    with PathManager.open(filepath, "w") as out_file:
        json.dump(self.report, out_file)
def __init__(self, config, name=None):
    """Set up file and console logging for a training run.

    The log file path is computed and the log folder is created on every
    process. The loggers and their handlers are only configured when
    ``is_master()`` is true; otherwise the constructor returns early with
    ``self._logger`` left as ``None``.

    Args:
        config: Configuration whose ``training`` node provides
            ``log_format``, ``logger_level``, ``stdout_capture`` and
            ``should_not_log``.
        name: Logger name; falls back to this module's ``__name__``.
    """
    self._logger = None
    self._is_master = is_master()

    self.timer = Timer()
    self.config = config
    self.save_dir = get_mmf_env(key="save_dir")
    self.log_format = config.training.log_format
    self.time_format = "%Y_%m_%dT%H_%M_%S"
    timestamp = self.timer.get_time_hhmmss(None, format=self.time_format)
    self.log_filename = "train_" + timestamp + ".log"

    # An explicitly configured log dir wins over "<save_dir>/logs".
    default_log_folder = os.path.join(self.save_dir, "logs")
    env_log_dir = get_mmf_env(key="log_dir")
    self.log_folder = env_log_dir or default_log_folder

    if not PathManager.exists(self.log_folder):
        PathManager.mkdirs(self.log_folder)

    self.log_filename = os.path.join(self.log_folder, self.log_filename)

    if not self._is_master:
        return

    print("Logging to:", self.log_filename)

    logging.captureWarnings(True)

    if not name:
        name = __name__

    # Both attributes are fetched with the same name.
    self._logger = logging.getLogger(name)
    self._file_only_logger = logging.getLogger(name)
    self._warnings_logger = logging.getLogger("py.warnings")

    # Set level
    level = config.training.logger_level
    log_level = getattr(logging, level.upper())
    self._logger.setLevel(log_level)
    self._file_only_logger.setLevel(log_level)

    # Capture stdout to logger
    self._stdout_logger = None
    if self.config.training.stdout_capture:
        self._stdout_logger = StreamToLogger(
            logging.getLogger("stdout"), log_level
        )
        sys.stdout = self._stdout_logger

    formatter = logging.Formatter(
        "%(asctime)s | %(levelname)s | %(name)s : %(message)s",
        datefmt="%Y-%m-%dT%H:%M:%S",
    )

    # Add handlers to the timestamped log file and to train.log. train.log
    # is full log that is also used by slurm/fbl output
    file_targets = (
        self.log_filename,
        os.path.join(self.save_dir, "train.log"),
    )
    for log_path in file_targets:
        channel = logging.StreamHandler(PathManager.open(log_path, mode="a"))
        channel.setFormatter(formatter)
        self.add_handlers(channel)

    # Add handler to stdout. Only when we are not capturing stdout in
    # the logger
    if not self._stdout_logger:
        channel = logging.StreamHandler(sys.stdout)
        channel.setFormatter(formatter)
        self._logger.addHandler(channel)
        self._warnings_logger.addHandler(channel)

    self.should_log = not self.config.training.should_not_log

    # Single log wrapper map
    self._single_log_map = set()