Example #1
 def remove(self, update):
     ckpt_filepath = os.path.join(self.models_foldername,
                                  "model_%d.ckpt" % update)
     if PathManager.isfile(ckpt_filepath):
         PathManager.rm(ckpt_filepath)
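A minimal usage sketch for this helper: when a rotation limit such as max_to_keep is reached, the oldest recorded update is removed before appending the new one. The attribute names below (max_to_keep, saved_iterations) are taken from Example #3; the checkpoint instance itself is hypothetical.

# Hypothetical rotation logic, mirroring the max_to_keep handling in Example #3
if checkpoint.max_to_keep > 0:
    if len(checkpoint.saved_iterations) == checkpoint.max_to_keep:
        checkpoint.remove(checkpoint.saved_iterations.pop(0))
    checkpoint.saved_iterations.append(update)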
Example #2
File: logger.py Project: naykun/mmf
def setup_logger(
    output: str = None,
    color: bool = True,
    name: str = "mmf",
    disable: bool = False,
    clear_handlers=True,
    *args,
    **kwargs,
):
    """
    Initialize the MMF logger and set its verbosity level to "INFO".
    Outside libraries shouldn't call this in case they have set their
    own logging handlers and setup. If they do, and don't want their
    handlers cleared, pass clear_handlers=False.

    The initial version of this function was taken from D2 and adapted
    for MMF.

    Args:
        output (str): a file name or a directory in which to save the log.
            If it ends with ".txt" or ".log", it is assumed to be a file name.
            Default: saved to file <save_dir/logs/log_[timestamp].txt>
        color (bool): If False, colored logs are disabled. Default: True
        name (str): the root module name of this logger. Defaults to "mmf".
        clear_handlers (bool): If false, won't clear existing handlers.

    Returns:
        logging.Logger: a logger
    """
    if disable:
        return None
    logger = logging.getLogger(name)
    logger.propagate = False

    logging.captureWarnings(True)
    warnings_logger = logging.getLogger("py.warnings")

    plain_formatter = logging.Formatter(
        "%(asctime)s | %(levelname)s | %(name)s : %(message)s",
        datefmt="%Y-%m-%dT%H:%M:%S",
    )

    distributed_rank = get_rank()
    handlers = []

    if distributed_rank == 0:
        logger.setLevel(logging.INFO)
        ch = logging.StreamHandler(stream=sys.stdout)
        ch.setLevel(logging.INFO)
        if color:
            formatter = ColorfulFormatter(
                colored("%(asctime)s | %(name)s: ", "green") + "%(message)s",
                datefmt="%Y-%m-%dT%H:%M:%S",
            )
        else:
            formatter = plain_formatter
        ch.setFormatter(formatter)
        logger.addHandler(ch)
        warnings_logger.addHandler(ch)
        handlers.append(ch)

    # file logging: all workers
    if output is None:
        output = setup_output_folder()

    if output is not None:
        if output.endswith(".txt") or output.endswith(".log"):
            filename = output
        else:
            filename = os.path.join(output, "train.log")
        if distributed_rank > 0:
            filename = filename + f".rank{distributed_rank}"
        PathManager.mkdirs(os.path.dirname(filename))

        fh = logging.StreamHandler(_cached_log_stream(filename))
        fh.setLevel(logging.INFO)
        fh.setFormatter(plain_formatter)
        logger.addHandler(fh)
        warnings_logger.addHandler(fh)
        handlers.append(fh)

        # Slurm/FB output, only log the main process
        if "train.log" not in filename and distributed_rank == 0:
            save_dir = get_mmf_env(key="save_dir")
            filename = os.path.join(save_dir, "train.log")
            sh = logging.StreamHandler(_cached_log_stream(filename))
            sh.setLevel(logging.INFO)
            sh.setFormatter(plain_formatter)
            logger.addHandler(sh)
            warnings_logger.addHandler(sh)
            handlers.append(sh)

        logger.info(f"Logging to: {filename}")

    # Remove existing handlers to add MMF specific handlers
    if clear_handlers:
        for handler in logging.root.handlers[:]:
            logging.root.removeHandler(handler)
    # Now, add our handlers.
    logging.basicConfig(level=logging.INFO, handlers=handlers)

    registry.register("writer", logger)

    return logger
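A short usage sketch, assuming the MMF environment (and hence get_rank and setup_output_folder) has already been initialized; the output directory below is a placeholder.

# Hypothetical call site: write plain-text logs under ./save/logs
logger = setup_logger(output="./save/logs", color=False)
logger.info("Logger is ready")

# Passing disable=True returns None, so guard the result in that case
assert setup_logger(disable=True) is None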
Example #3
    def save(self, update, iteration=None, update_best=False):
        # Only save in the main process. For XLA we use the xm.save method,
        # which ensures that actual checkpoint saving happens only on the
        # master node and also takes care of all the necessary synchronization.
        if not is_master() and not is_xla():
            return

        logger.info("Checkpoint save operation started!")
        if not iteration:
            iteration = update

        ckpt_filepath = os.path.join(self.models_foldername,
                                     "model_%d.ckpt" % update)
        best_ckpt_filepath = os.path.join(self.ckpt_foldername,
                                          self.ckpt_prefix + "best.ckpt")
        current_ckpt_filepath = os.path.join(self.ckpt_foldername,
                                             self.ckpt_prefix + "current.ckpt")

        best_iteration = (self.trainer.early_stop_callback.early_stopping.
                          best_monitored_iteration)
        best_update = (self.trainer.early_stop_callback.early_stopping.
                       best_monitored_update)
        best_metric = (self.trainer.early_stop_callback.early_stopping.
                       best_monitored_value)
        model = self.trainer.model
        data_parallel = registry.get("data_parallel") or registry.get(
            "distributed")
        fp16_scaler = getattr(self.trainer, "scaler", None)
        fp16_scaler_dict = None

        if fp16_scaler is not None:
            fp16_scaler_dict = fp16_scaler.state_dict()

        if data_parallel is True:
            model = model.module

        ckpt = {
            "model": model.state_dict(),
            "optimizer": self.trainer.optimizer.state_dict(),
            "best_iteration": best_iteration,
            "current_iteration": iteration,
            "current_epoch": self.trainer.current_epoch,
            "num_updates": update,
            "best_update": best_update,
            "best_metric_value": best_metric,
            "fp16_scaler": fp16_scaler_dict,
            # Convert to container to avoid any dependencies
            "config": OmegaConf.to_container(self.config, resolve=True),
        }

        lr_scheduler = self.trainer.lr_scheduler_callback._scheduler
        if lr_scheduler is not None:
            ckpt["lr_scheduler"] = lr_scheduler.state_dict()

        if self.git_repo:
            git_metadata_dict = self._get_vcs_fields()
            ckpt.update(git_metadata_dict)

        with PathManager.open(ckpt_filepath, "wb") as f:
            self.save_func(ckpt, f)

        if update_best:
            logger.info("Saving best checkpoint")
            with PathManager.open(best_ckpt_filepath, "wb") as f:
                self.save_func(ckpt, f)

        # Save current always

        logger.info("Saving current checkpoint")
        with PathManager.open(current_ckpt_filepath, "wb") as f:
            self.save_func(ckpt, f)

        # Remove old checkpoints if max_to_keep is set
        if self.max_to_keep > 0:
            if len(self.saved_iterations) == self.max_to_keep:
                self.remove(self.saved_iterations.pop(0))
            self.saved_iterations.append(update)

        logger.info("Checkpoint save operation finished!")
Example #4
def default_loader(path):
    with PathManager.open(path, "rb") as f:
        img = Image.open(f)
        return img.convert("RGB")
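A minimal usage sketch; the image path is a placeholder.

# Load an image and confirm it comes back as 3-channel RGB
img = default_loader("/path/to/image.jpg")  # placeholder path
print(img.mode)  # "RGB"
print(img.size)  # (width, height)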
Example #5
File: logger.py Project: naykun/mmf
def _cached_log_stream(filename):
    return PathManager.open(filename, "a")
Example #6
from mmf.utils.configuration import get_mmf_cache_dir
from mmf.utils.file_io import PathManager
from mmf.datasets.processors.processors import EvalAIAnswerProcessor

root_dir = os.path.join(get_mmf_cache_dir(), "data", "datasets", "okvqa",
                        "defaults", "annotations")
out_dir = os.path.join(get_mmf_cache_dir(), "data", "datasets", "okvqa",
                       "defaults", "extras", "vocabs")
train_path = os.path.join(root_dir, "mscoco_train2014_annotations.json")
val_path = os.path.join(root_dir, "mscoco_val2014_annotations.json")
out_path = os.path.join(out_dir, "gt2raw_answers.json")

evalai_answer_processor = EvalAIAnswerProcessor()

with PathManager.open(train_path, "r") as f:
    annotations = json.load(f)["annotations"]

with PathManager.open(val_path, "r") as f:
    annotations += json.load(f)["annotations"]

gt2raw = {}
for ann in tqdm(annotations):
    for ans in ann["answers"]:
        raw_ans = evalai_answer_processor(ans["raw_answer"])
        gt_ans = evalai_answer_processor(ans["answer"])

        if gt_ans in gt2raw:
            gt2raw[gt_ans].add(raw_ans)
        else:
            gt2raw[gt_ans] = set([raw_ans])
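The snippet above builds gt2raw but stops before writing it out, even though out_path is defined. A plausible continuation, assuming the sets are meant to be serialized as lists to that path, is:

# Assumed continuation: JSON cannot encode sets, so convert them to sorted lists
PathManager.mkdirs(out_dir)
with PathManager.open(out_path, "w") as f:
    json.dump({gt: sorted(raw) for gt, raw in gt2raw.items()}, f)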
Example #7
 def test_log_writer(self) -> None:
     self.writer.write(self._tmpfile_write_contents)
     f = PathManager.open(os.path.join(self._tmpdir, "train.log"))
     self.assertTrue(
         any(self._tmpfile_write_contents in line
             for line in f.readlines()))
Example #8
 def test_file_io_mkdirs(self):
     dir_path = os.path.join(self._tmpdir, "test_dir")
     PathManager.mkdirs(dir_path)
     self.assertTrue(os.path.isdir(dir_path))
Example #9
if __name__ == "__main__":
    src_dataset = 'vqa2'
    dst_dataset = 'okvqa'
    src_fname = "answers_vqa.txt"
    dst_fname = "answers_okvqa.txt"
    gt2raw_fname = "gt2raw_answers.json"
    use_raw = True
    use_raw_str = "_raw" if use_raw else ""
    out_fname = f"{src_dataset}2{dst_dataset}{use_raw_str}.json"
    src_dir = os.path.join(get_mmf_cache_dir(), "data", "datasets",
                           src_dataset, "defaults", "extras", "vocabs")
    dst_dir = os.path.join(get_mmf_cache_dir(), "data", "datasets",
                           dst_dataset, "defaults", "extras", "vocabs")

    with PathManager.open(os.path.join(src_dir, src_fname), "r") as f:
        src_vocab = f.read().splitlines()

    with PathManager.open(os.path.join(dst_dir, dst_fname), "r") as f:
        dst_vocab = f.read().splitlines()

    if use_raw:
        with PathManager.open(os.path.join(dst_dir, gt2raw_fname), "r") as f:
            gt2raw = json.load(f)

    src_dict = {w: i for i, w in enumerate(src_vocab)}
    qa_map = {}
    count = 0
    for idx, word in enumerate(dst_vocab):
        if word in src_dict:
            qa_map[idx] = src_dict[word]
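The script ends mid-way: gt2raw is loaded and out_fname is defined but neither is used, and count is never incremented. A hedged continuation, assuming unmatched answers should be retried via their raw variants and the mapping then written to dst_dir, might look like:

    # Assumed continuation of the loop above: fall back to raw answer variants,
    # then dump the mapping next to the destination vocab
    for idx, word in enumerate(dst_vocab):
        if idx in qa_map:
            continue
        if use_raw and word in gt2raw:
            for raw_word in gt2raw[word]:
                if raw_word in src_dict:
                    qa_map[idx] = src_dict[raw_word]
                    count += 1
                    break

    with PathManager.open(os.path.join(dst_dir, out_fname), "w") as f:
        json.dump(qa_map, f)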
Example #10
 def test_file_io_copy(self):
     PathManager.copy(self._tmpfile, os.path.join(self._tmpdir, "test_copy.txt"))
     with open(os.path.join(self._tmpdir, "test_copy.txt"), "r") as f:
         s = f.read()
     self.assertEqual(s, self._tmpfile_contents)
Example #11
 def test_file_io_exists(self):
     self.assertEqual(
         PathManager.exists(self._tmpfile), os.path.exists(self._tmpfile)
     )
     fake_path = os.path.join(self._tmpdir, uuid.uuid4().hex)
     self.assertEqual(PathManager.exists(fake_path), os.path.exists(fake_path))
Example #12
 def test_file_io_open(self):
     with PathManager.open(self._tmpfile, mode="r") as f:
         s = f.read()
     self.assertEqual(s, self._tmpfile_contents)
Example #13
    def convert(self):
        config = self.configuration.get_config()
        data_dir = config.env.data_dir

        if self.args.mmf_data_folder:
            data_dir = self.args.mmf_data_folder

        bypass_checksum = False
        if self.args.bypass_checksum:
            bypass_checksum = bool(self.args.bypass_checksum)

        print(f"Data folder is {data_dir}")
        print(f"Zip path is {self.args.zip_file}")

        base_path = data_dir

        images_path = os.path.join(base_path, "images")
        PathManager.mkdirs(images_path)

        move_dir = False
        if self.args.move:
            move_dir = bool(self.args.move)

        if not bypass_checksum:
            self.checksum(self.args.zip_file, self.POSSIBLE_CHECKSUMS)

        src = self.args.zip_file
        dest = images_path
        if move_dir:
            print(f"Moving {src}")
            move(src, dest)
        else:
            print(f"Copying {src}")
            copy(src, dest)

        print(f"Unzipping {src}")
        self.decompress_zip(dest,
                            fname=os.path.basename(src),
                            password=self.args.password)

        phase_one = self.assert_files(images_path)

        annotations_path = os.path.join(base_path, "annotations")
        PathManager.mkdirs(annotations_path)
        annotations = (self.JSONL_PHASE_ONE_FILES
                       if phase_one is True else self.JSONL_PHASE_TWO_FILES)

        for annotation in annotations:
            print(f"Moving {annotation}")
            src = os.path.join(images_path, "data", annotation)
            dest = os.path.join(annotations_path, annotation)
            move(src, dest)

        images = self.IMAGE_FILES

        for image_file in images:
            src = os.path.join(images_path, "data", image_file)
            if PathManager.exists(src):
                print(f"Moving {image_file}")
            else:
                continue
            dest = os.path.join(images_path, image_file)
            move(src, dest)
            if src.endswith(".tar.gz"):
                decompress(dest, fname=image_file, delete_original=False)
Example #14
    def __init__(self,
                 vocab_file=None,
                 embedding_dim=300,
                 data_dir=None,
                 *args,
                 **kwargs):
        """Vocab class to be used when you want to train word embeddings from
        scratch based on a custom vocab. This will initialize the random
        vectors for the vocabulary you pass. Get the vectors using the
        `get_vectors` function. This will also create random embeddings for
        some predefined words such as PAD - <pad>, SOS - <s>, EOS - </s>,
        UNK - <unk>.

        Parameters
        ----------
        vocab_file : str
            Path of the vocabulary file containing one word per line
        embedding_dim : int
            Size of the embedding
        data_dir : str
            Directory against which ``vocab_file`` is resolved when it is
            not an absolute path

        """
        self.type = "base"
        self.word_dict = {}
        self.itos = {}

        self.itos[self.PAD_INDEX] = self.PAD_TOKEN
        self.itos[self.SOS_INDEX] = self.SOS_TOKEN
        self.itos[self.EOS_INDEX] = self.EOS_TOKEN
        self.itos[self.UNK_INDEX] = self.UNK_TOKEN

        self.word_dict[self.SOS_TOKEN] = self.SOS_INDEX
        self.word_dict[self.EOS_TOKEN] = self.EOS_INDEX
        self.word_dict[self.PAD_TOKEN] = self.PAD_INDEX
        self.word_dict[self.UNK_TOKEN] = self.UNK_INDEX

        index = len(self.itos.keys())

        self.total_predefined = len(self.itos.keys())

        if vocab_file is not None:
            if not os.path.isabs(vocab_file) and data_dir is not None:
                vocab_file = os.path.join(data_dir, vocab_file)
                vocab_file = get_absolute_path(vocab_file)

            if not PathManager.exists(vocab_file):
                raise RuntimeError("Vocab not found at " + vocab_file)

            with PathManager.open(vocab_file, "r") as f:
                for line in f:
                    self.itos[index] = line.strip()
                    self.word_dict[line.strip()] = index
                    index += 1

        self.word_dict[self.SOS_TOKEN] = self.SOS_INDEX
        self.word_dict[self.EOS_TOKEN] = self.EOS_INDEX
        self.word_dict[self.PAD_TOKEN] = self.PAD_INDEX
        self.word_dict[self.UNK_TOKEN] = self.UNK_INDEX
        # Return unk index by default
        self.stoi = defaultdict(self.get_unk_index)
        self.stoi.update(self.word_dict)

        self.vectors = torch.FloatTensor(self.get_size(), embedding_dim)
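A small usage sketch. The class name BaseVocab and the placeholder paths are assumptions (only the constructor body is shown above); get_vectors is the accessor mentioned in the docstring, and vocab.txt stands in for a file with one word per line.

# Hypothetical construction of the vocab and retrieval of the random embeddings
vocab = BaseVocab(vocab_file="vocab.txt",        # assumed class name and paths
                  data_dir="/path/to/data",
                  embedding_dim=300)
print(vocab.get_size())        # predefined tokens plus words from vocab.txt
print(vocab.stoi["<unk>"])     # unknown words fall back to the UNK index
vectors = vocab.get_vectors()  # accessor referenced in the docstring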
Example #15
 def finalize(self):
     if is_master() or is_xla():
         with PathManager.open(self.pth_filepath, "wb") as f:
             self.save_func(self.trainer.model.state_dict(), f)
Example #16
def download(url, path, fname, redownload=True, disable_tqdm=False):
    """
    Download file using `requests`.

    If ``redownload`` is set to ``False``, the file will not be downloaded again
    if it is already present (default ``True``).

    Returns whether the download actually happened or not.
    """
    outfile = os.path.join(path, fname)
    download = not PathManager.isfile(outfile) or redownload
    retry = 5
    exp_backoff = [2**r for r in reversed(range(retry))]

    pbar = None
    if download:
        # First test if the link is actually downloadable
        check_header(url)
        if not disable_tqdm:
            print("[ Downloading: " + url + " to " + outfile + " ]")
        pbar = tqdm.tqdm(unit="B",
                         unit_scale=True,
                         desc=f"Downloading {fname}",
                         disable=disable_tqdm)

    while download and retry >= 0:
        resume_file = outfile + ".part"
        resume = PathManager.isfile(resume_file)
        if resume:
            resume_pos = os.path.getsize(resume_file)
            mode = "ab"
        else:
            resume_pos = 0
            mode = "wb"
        response = None

        with requests.Session() as session:
            try:
                header = ({
                    "Range": "bytes=%d-" % resume_pos,
                    "Accept-Encoding": "identity"
                } if resume else {})
                response = session.get(url,
                                       stream=True,
                                       timeout=5,
                                       headers=header)

                # negative reply could be 'none' or just missing
                if resume and response.headers.get("Accept-Ranges",
                                                   "none") == "none":
                    resume_pos = 0
                    mode = "wb"

                CHUNK_SIZE = 32768
                total_size = int(response.headers.get("Content-Length", -1))
                # server returns remaining size if resuming, so adjust total
                total_size += resume_pos
                pbar.total = total_size
                done = resume_pos

                with PathManager.open(resume_file, mode) as f:
                    for chunk in response.iter_content(CHUNK_SIZE):
                        if chunk:  # filter out keep-alive new chunks
                            f.write(chunk)
                        if total_size > 0:
                            done += len(chunk)
                            if total_size < done:
                                # don't freak out if content-length was too small
                                total_size = done
                                pbar.total = total_size
                            pbar.update(len(chunk))
                    break
            except (
                    requests.exceptions.ConnectionError,
                    requests.exceptions.ReadTimeout,
            ):
                retry -= 1
                pbar.clear()
                if retry >= 0:
                    print("Connection error, retrying. (%d retries left)" %
                          retry)
                    time.sleep(exp_backoff[retry])
                else:
                    print("Retried too many times, stopped retrying.")
            finally:
                if response:
                    response.close()
    if retry < 0:
        raise RuntimeWarning(
            "Connection broken too many times. Stopped retrying.")

    if download and retry > 0:
        pbar.update(done - pbar.n)
        if done < total_size:
            raise RuntimeWarning("Received less data than specified in " +
                                 "Content-Length header for " + url +
                                 ". There may be a download problem.")
        move(resume_file, outfile)

    if pbar:
        pbar.close()

    return download
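A short call example; the URL and destination folder are placeholders.

# Hypothetical call: fetch a file only if it is not already on disk
happened = download(
    "https://example.com/data/vocab.tar.gz",  # placeholder URL
    path="/tmp/mmf_downloads",
    fname="vocab.tar.gz",
    redownload=False,
)
print("downloaded" if happened else "already present, skipped")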
Example #17
 def csv_dump(self, filepath):
     with PathManager.open(filepath, "w") as f:
         title = self.report[0].keys()
         cw = csv.DictWriter(f, title, delimiter=",", quoting=csv.QUOTE_MINIMAL)
         cw.writeheader()
         cw.writerows(self.report)
Example #18
    def test_save_and_load_state_dict(self):
        with mock_env_with_temp() as d:
            checkpoint = Checkpoint(self.trainer)
            self._init_early_stopping(checkpoint)
            self._do_a_pass()
            # Test normal case
            checkpoint.save(1500)

            self.assertTrue(
                PathManager.exists(os.path.join(d, "models",
                                                "model_1500.ckpt")))
            self.assertTrue(PathManager.exists(os.path.join(d,
                                                            "current.ckpt")))
            self.assertFalse(PathManager.exists(os.path.join(d, "best.ckpt")))
            os.remove(os.path.join(d, "models", "model_1500.ckpt"))
            os.remove(os.path.join(d, "current.ckpt"))

            best_model = deepcopy(self.trainer.model)
            best_optimizer = deepcopy(self.trainer.optimizer)
            # Test with update_best
            checkpoint.save(2000, update_best=True)

            self.assertTrue(
                PathManager.exists(os.path.join(d, "models",
                                                "model_2000.ckpt")))
            self.assertTrue(PathManager.exists(os.path.join(d, "best.ckpt")))
            self.assertTrue(PathManager.exists(os.path.join(d,
                                                            "current.ckpt")))

            self._do_a_pass()
            checkpoint.save(2500)

            # Test resume
            self.trainer.config.checkpoint.resume = True

            current_model = deepcopy(self.trainer.model)
            current_optimizer = deepcopy(self.trainer.optimizer)
            checkpoint.load_state_dict()

            self.assertFalse(
                compare_state_dicts(self.trainer.model.state_dict(),
                                    best_model.state_dict()))
            self.assertTrue(
                compare_state_dicts(self.trainer.model.state_dict(),
                                    current_model.state_dict()))
            self.assertFalse(
                self._compare_optimizers(self.trainer.optimizer,
                                         best_optimizer))
            self.assertFalse(
                self._compare_optimizers(self.trainer.optimizer,
                                         best_optimizer,
                                         skip_keys=True))
            self.assertFalse(
                self._compare_optimizers(self.trainer.optimizer,
                                         current_optimizer))
            self.assertTrue(
                self._compare_optimizers(self.trainer.optimizer,
                                         current_optimizer,
                                         skip_keys=True))

            base_0_weight_current = self.trainer.model.base[
                0].weight.data.clone()

            # Test resume_best
            self.trainer.config.checkpoint.resume = True
            self.trainer.config.checkpoint.resume_best = True

            checkpoint.load_state_dict()

            self.assertTrue(
                compare_state_dicts(self.trainer.model.state_dict(),
                                    best_model.state_dict()))
            self.assertFalse(
                self._compare_optimizers(self.trainer.optimizer,
                                         best_optimizer))
            self.assertTrue(
                self._compare_optimizers(self.trainer.optimizer,
                                         best_optimizer,
                                         skip_keys=True))
            self.assertFalse(
                self._compare_optimizers(self.trainer.optimizer,
                                         current_optimizer))
            self.assertFalse(
                self._compare_optimizers(self.trainer.optimizer,
                                         current_optimizer,
                                         skip_keys=True))
            base_0_weight_best = self.trainer.model.base[0].weight.data.clone()

            self.trainer.config.checkpoint.resume_best = False
            # Test distributed settings
            self.trainer.model = torch.nn.DataParallel(self.trainer.model)
            checkpoint.load_state_dict()

            weight_to_be_tested = self.trainer.model.module.base[0].weight
            weight_device = weight_to_be_tested.device

            self.assertTrue(
                torch.equal(weight_to_be_tested,
                            base_0_weight_current.to(weight_device)))
            self.assertFalse(
                torch.equal(weight_to_be_tested,
                            base_0_weight_best.to(weight_device)))
Example #19
 def test_logger_files(self) -> None:
     self.assertTrue(
         PathManager.exists(os.path.join(self._tmpdir, "train.log")))
     self.assertTrue(PathManager.exists(os.path.join(self._tmpdir, "logs")))
Example #20
 def test_on_test_end(self):
     self.cb.on_test_end(report=self.report, meter=self.trainer.meter)
     f = PathManager.open(os.path.join(self.tmpdir, "train.log"))
     self.assertTrue(
         any("Finished run in" in line for line in f.readlines()))
Example #21
 def json_dump(self, filepath):
     with PathManager.open(filepath, "w") as f:
         json.dump(self.report, f)
Example #22
File: logger.py Project: srag21/mmf
    def __init__(self, config, name=None):
        self._logger = None
        self._is_master = is_master()

        self.timer = Timer()
        self.config = config
        self.save_dir = get_mmf_env(key="save_dir")
        self.log_format = config.training.log_format
        self.time_format = "%Y_%m_%dT%H_%M_%S"
        self.log_filename = "train_"
        self.log_filename += self.timer.get_time_hhmmss(None, format=self.time_format)
        self.log_filename += ".log"

        self.log_folder = os.path.join(self.save_dir, "logs")

        env_log_dir = get_mmf_env(key="log_dir")
        if env_log_dir:
            self.log_folder = env_log_dir

        if not PathManager.exists(self.log_folder):
            PathManager.mkdirs(self.log_folder)

        self.log_filename = os.path.join(self.log_folder, self.log_filename)

        if not self._is_master:
            return

        print("Logging to:", self.log_filename)

        logging.captureWarnings(True)

        if not name:
            name = __name__
        self._logger = logging.getLogger(name)
        self._file_only_logger = logging.getLogger(name)
        self._warnings_logger = logging.getLogger("py.warnings")

        # Set level
        level = config.training.logger_level
        self._logger.setLevel(getattr(logging, level.upper()))
        self._file_only_logger.setLevel(getattr(logging, level.upper()))

        # Capture stdout to logger
        self._stdout_logger = None
        if self.config.training.stdout_capture:
            self._stdout_logger = StreamToLogger(
                logging.getLogger("stdout"), getattr(logging, level.upper())
            )
            sys.stdout = self._stdout_logger

        formatter = logging.Formatter(
            "%(asctime)s | %(levelname)s | %(name)s : %(message)s",
            datefmt="%Y-%m-%dT%H:%M:%S",
        )

        # Add handler to file
        channel = logging.StreamHandler(PathManager.open(self.log_filename, mode="a"))
        channel.setFormatter(formatter)
        self.add_handlers(channel)

        # Add handler to train.log. train.log is the full log that is also
        # used by slurm/fbl output
        channel = logging.StreamHandler(
            PathManager.open(os.path.join(self.save_dir, "train.log"), mode="a")
        )
        channel.setFormatter(formatter)
        self.add_handlers(channel)

        # Add handler to stdout. Only when we are not capturing stdout in
        # the logger
        if not self._stdout_logger:
            channel = logging.StreamHandler(sys.stdout)
            channel.setFormatter(formatter)

            self._logger.addHandler(channel)
            self._warnings_logger.addHandler(channel)

        should_not_log = self.config.training.should_not_log
        self.should_log = not should_not_log

        # Single log wrapper map
        self._single_log_map = set()