Beispiel #1
0
def main(configuration, init_distributed=False, predict=False):
    """Run a single training or inference process.

    Args:
        configuration: Configuration wrapper. ``import_user_dir``,
            ``get_config`` and ``args`` are used from it.
        init_distributed (bool): If true, ``distributed_init`` is called
            with the config before the seed is set.
        predict (bool): If true, ``trainer.inference()`` is run instead of
            ``trainer.train()``.
    """
    # Imports may need a (re)load before the user directory is imported.
    setup_imports()
    configuration.import_user_dir()
    config = configuration.get_config()

    if torch.cuda.is_available():
        torch.cuda.set_device(config.device_id)
        torch.cuda.init()

    if init_distributed:
        distributed_init(config)

    # A seed of -1 is handed to set_seed unchanged; any other seed is
    # offset by this process's rank first.
    base_seed = config.training.seed
    effective_seed = base_seed if base_seed == -1 else base_seed + get_rank()
    config.training.seed = set_seed(effective_seed)
    registry.register("seed", config.training.seed)

    config = build_config(configuration)
    training_config = config.training

    setup_logger(
        color=training_config.colored_logs,
        disable=training_config.should_not_log,
    )
    run_logger = logging.getLogger("multimodelity_cli.run")
    # Dump the raw args and environment details for debugging.
    run_logger.info(configuration.args)
    run_logger.info(f"Torch version: {torch.__version__}")
    log_device_names()
    run_logger.info(f"Using seed {config.training.seed}")

    trainer = build_trainer(config)
    trainer.load()
    run = trainer.inference if predict else trainer.train
    run()
Beispiel #2
0
    def get_data_t(self, t, data, batch_size_t, prev_output):
        """Prepare ``data`` for decoding timestep ``t``.

        With teacher forcing, ``batch_size_t`` becomes the number of entries
        in ``data["decode_lengths"]`` greater than ``t``. Otherwise, for
        greedy inference with a previous output available, the highest
        scoring index of ``prev_output`` is appended to ``data["texts"]``.
        Texts and LSTM states are then cut to the first ``batch_size_t``
        rows (states are freshly initialised when none exist yet) and the
        state dict is published to the registry.

        Returns:
            tuple: ``(data, batch_size_t)``.
        """
        if self.teacher_forcing:
            # Count how many sequences are still being decoded at timestep t
            batch_size_t = sum(
                length > t for length in data["decode_lengths"]
            )
        elif prev_output is not None and self.config.inference.type == "greedy":
            # Greedy decoding: append the t-1 output words to data["texts"]
            log_probs = torch.log_softmax(prev_output, dim=1)
            next_tokens = torch.max(log_probs, dim=1, keepdim=True)[1]
            data["texts"] = torch.cat(
                (data["texts"], next_tokens.view(batch_size_t, 1)), dim=1
            )

        # Keep only the first batch_size_t rows for this timestep
        texts = data["texts"][:batch_size_t]
        data["texts"] = texts
        if "state" not in data:
            td_h, td_c = self.init_hidden_state(texts)
            lm_h, lm_c = self.init_hidden_state(texts)
        else:
            previous = data["state"]
            td_h = previous["td_hidden"][0][:batch_size_t]
            td_c = previous["td_hidden"][1][:batch_size_t]
            lm_h = previous["lm_hidden"][0][:batch_size_t]
            lm_c = previous["lm_hidden"][1][:batch_size_t]

        state = {"td_hidden": (td_h, td_c), "lm_hidden": (lm_h, lm_c)}
        data["state"] = state
        registry.register(f"{td_h.device}_lstm_state", state)

        return data, batch_size_t
Beispiel #3
0
    def update_meter(self,
                     report: Dict[str, Any],
                     meter: Type[Meter] = None,
                     eval_mode: bool = False) -> None:
        """Push reduced losses and/or metrics from ``report`` into a meter.

        Args:
            report: Report object; ``metrics`` (optional), ``losses``,
                ``dataset_type`` and ``batch_size`` are read from it.
            meter: Meter to update. Falls back to ``self.meter`` when None.
            eval_mode: When true, losses are skipped and only metrics
                (if the report has any) are added.
        """
        target_meter = self.meter if meter is None else meter
        has_metrics = hasattr(report, "metrics")
        track_losses = not eval_mode

        # Reduce first (metrics, then losses), outside of the no_grad block.
        if has_metrics:
            reduced_metrics = reduce_dict(report.metrics)
        if track_losses:
            reduced_losses = reduce_dict(report.losses)

        with torch.no_grad():
            updates = {}
            if track_losses:
                # Losses are recorded only outside of eval mode; the total
                # is also published to the registry under the same key.
                loss_key = report.dataset_type + "/total_loss"
                updates, total_loss = self.update_dict(updates, reduced_losses)
                registry.register(loss_key, total_loss)
                updates[loss_key] = total_loss

            if has_metrics:
                updates, _ = self.update_dict(updates, reduced_metrics)

            target_meter.update(updates, report.batch_size)
Beispiel #4
0
    def __call__(self, sample_list, model_output, *args, **kwargs):
        """Evaluate every metric in ``self.metrics`` and return the results.

        Each result is stored under ``"<dataset_type>/<dataset_name>/<metric>"``
        as a float tensor with at least one dimension. The resulting dict is
        also written to the registry before being returned.
        """
        dataset_type = sample_list.dataset_type
        dataset_name = sample_list.dataset_name
        values = {}

        with torch.no_grad():
            for metric_name, metric_object in self.metrics.items():
                result = metric_object._calculate_with_checks(
                    sample_list, model_output, *args, **kwargs
                )

                # Normalise every result to a float tensor
                if isinstance(result, torch.Tensor):
                    result = result.float()
                else:
                    result = torch.tensor(result, dtype=torch.float)

                # 0-dim tensors are promoted to a single-element 1-dim tensor
                if result.dim() == 0:
                    result = result.view(1)

                values[f"{dataset_type}/{dataset_name}/{metric_name}"] = result

        registry_key = "{}.{}.{}".format("metrics", dataset_name, dataset_type)
        registry.register(registry_key, values)

        return values
Beispiel #5
0
    def __init__(self, args=None, default_only=False):
        """Build the merged configuration and publish it to the registry.

        Args:
            args: Namespace-like object with an ``opts`` list. When falsy, an
                empty namespace is used and only the default config is built.
            default_only (bool): When true, skip ``_build_other_configs``.
        """
        self.config = {}

        if not args:
            import argparse

            # No args given: fall back to defaults only
            args, default_only = argparse.Namespace(opts=[]), True

        self.args = args
        self._register_resolvers()

        self._default_config = self._build_default_config()
        other_configs = {} if default_only else self._build_other_configs()

        # Layer: defaults <- other configs <- command line dotlist
        self.config = OmegaConf.merge(self._default_config, other_configs)
        self.config = self._merge_with_dotlist(self.config, args.opts)
        self._update_specific(self.config)
        self.upgrade(self.config)

        # Resolve once the config is complete so that spawned workers
        # don't run into problems with it.
        resolved = OmegaConf.to_container(self.config, resolve=True)
        self.config = OmegaConf.create(resolved)
        registry.register("config", self.config)
Beispiel #6
0
    def forward(self, sample_list: Dict[str, Tensor],
                model_output: Dict[str, Tensor]):
        """Takes in the original ``SampleList`` returned from DataLoader
        and `model_output` returned from the model and returns a Dict containing
        loss for each of the losses in `losses`.

        Args:
            sample_list (SampleList): SampleList given by the dataloader.
            model_output (Dict): Dict returned from model as output.

        Returns:
            Dict: Dictionary containing loss value for each of the losses.
                Empty when ``sample_list`` has no ``"targets"`` entry.

        """
        output = {}
        if "targets" not in sample_list:
            # Without targets no loss is computed; warn unless the
            # ``_evaluation_predict`` flag is set.
            if not self._evaluation_predict:
                warnings.warn("Sample list has not field 'targets', are you "
                              "sure that your ImDB has labels? you may have "
                              "wanted to run with evaluation.predict=true")
            return output

        # Merge the dict returned by each loss; on a key collision the later
        # loss overwrites the earlier one (plain dict.update semantics).
        for loss in self.losses:
            output.update(loss(sample_list, model_output))

        # The registry write is skipped while the module is being scripted.
        if not torch.jit.is_scripting():
            registry_loss_key = "{}.{}.{}".format("losses",
                                                  sample_list["dataset_name"],
                                                  sample_list["dataset_type"])
            # Register the losses to registry
            registry.register(registry_loss_key, output)

        return output
Beispiel #7
0
 def update_registry_for_model(self, config):
     """Register the text and answer vocabulary sizes for this dataset.

     Keys are ``<dataset_name>_text_vocab_size`` and
     ``<dataset_name>_num_final_outputs``. ``config`` is not used here.
     """
     text_key = self.dataset_name + "_text_vocab_size"
     text_vocab_size = self.dataset.text_processor.get_vocab_size()
     registry.register(text_key, text_vocab_size)

     outputs_key = self.dataset_name + "_num_final_outputs"
     num_final_outputs = self.dataset.answer_processor.get_vocab_size()
     registry.register(outputs_key, num_final_outputs)
Beispiel #8
0
def get_multimodelity_root():
    """Return the package root directory, caching it in the registry.

    If the registry has no ``"multimodelity_root"`` entry, the root is taken
    to be the parent of this file's directory and is registered for reuse.
    """
    from multimodelity.common.registry import registry

    cached_root = registry.get("multimodelity_root", no_warning=True)
    if cached_root is not None:
        return cached_root

    this_dir = os.path.dirname(os.path.abspath(__file__))
    computed_root = os.path.abspath(os.path.join(this_dir, ".."))
    registry.register("multimodelity_root", computed_root)
    return computed_root
Beispiel #9
0
def get_global_config(key=None):
    """Return the registered global config, or a sub-selection of it.

    If no config is registered yet, a default ``Configuration`` is built and
    its config is registered first.

    Args:
        key: Optional selector passed to ``OmegaConf.select``. A falsy key
            returns the whole config.
    """
    global_config = registry.get("config")
    if global_config is None:
        global_config = Configuration().get_config()
        registry.register("config", global_config)

    return OmegaConf.select(global_config, key) if key else global_config
Beispiel #10
0
    def get_data_t(self, data, batch_size_t):
        """Cut texts and LSTM states down to the first ``batch_size_t`` rows.

        When ``data`` has no ``"state"`` yet, both hidden states are created
        with ``self.init_hidden_state``. The state dict is stored back on
        ``data`` and published to the registry.

        Returns:
            tuple: ``(data, batch_size_t)``.
        """
        texts = data["texts"][:batch_size_t]
        data["texts"] = texts

        if "state" not in data:
            td_h, td_c = self.init_hidden_state(texts)
            lm_h, lm_c = self.init_hidden_state(texts)
        else:
            previous = data["state"]
            td_h = previous["td_hidden"][0][:batch_size_t]
            td_c = previous["td_hidden"][1][:batch_size_t]
            lm_h = previous["lm_hidden"][0][:batch_size_t]
            lm_c = previous["lm_hidden"][1][:batch_size_t]

        state = {"td_hidden": (td_h, td_c), "lm_hidden": (lm_h, lm_c)}
        data["state"] = state
        # The registry key is built from the device of the first hidden state
        registry.register(f"{td_h.device}_lstm_state", state)

        return data, batch_size_t
Beispiel #11
0
 def setUpClass(cls) -> None:
     """Create a frozen cnn_lstm/clevr config in a temp dir and a fresh logger."""
     cls._tmpdir = tempfile.mkdtemp()
     args = argparse.Namespace(
         opts=[
             f"env.save_dir={cls._tmpdir}",
             "model=cnn_lstm",
             "dataset=clevr",
         ],
         config_override=None,
     )

     configuration = Configuration(args)
     configuration.freeze()
     cls.config = configuration.get_config()
     registry.register("config", cls.config)

     # Drop cached results so the output folder and logger are set up again
     for cached_function in (setup_output_folder, setup_logger):
         cached_function.cache_clear()
     cls.writer = setup_logger()
Beispiel #12
0
    def configure_device(self) -> None:
        """Select the device for this trainer and record it in the registry.

        Sets ``self.local_rank``, ``self.device`` and ``self.distributed``.
        A distributed init method selects ``cuda:<local_rank>``; otherwise
        the default CUDA device is used when available, else the CPU.
        """
        config = self.config
        self.local_rank = config.device_id
        self.device = self.local_rank
        self.distributed = False

        # Provisional value; overwritten at the end of this method
        registry.register("global_device", self.device)

        if config.distributed.init_method is not None:
            self.distributed = True
            self.device = torch.device("cuda", self.local_rank)
            torch.cuda.set_device(self.local_rank)
        elif not torch.cuda.is_available():
            self.device = torch.device("cpu")
        else:
            self.device = torch.device("cuda")
            torch.cuda.set_device(0)

        # NOTE(review): the final value stored under "global_device" is the
        # distributed rank, not self.device -- confirm this is intended.
        registry.register("global_device", config.distributed.rank)
Beispiel #13
0
def build_config(
    configuration: Type[Configuration], *args, **kwargs
) -> multimodelity_typings.DictConfig:
    """Freeze ``configuration`` and publish it and its config to the registry.

    Args:
        configuration (Configuration): Configuration object whose config is
            frozen, registered and returned.
        *args: Accepted but not used.
        **kwargs: Accepted but not used.

    Returns:
        (DictConfig): The object returned by ``configuration.get_config()``.
    """
    configuration.freeze()
    frozen_config = configuration.get_config()

    registrations = (
        ("config", frozen_config),
        ("configuration", configuration),
    )
    for registry_key, value in registrations:
        registry.register(registry_key, value)

    return frozen_config
Beispiel #14
0
 def setUp(self):
     """Build and register a frozen butd/coco config using beam search."""
     setup_imports()
     torch.manual_seed(1234)

     # Path of the project config, relative to the package root
     relative_parts = (
         "..", "projects", "butd", "configs", "coco", "beam_search.yaml",
     )
     config_path = os.path.abspath(
         os.path.join(get_multimodelity_root(), *relative_parts)
     )

     args = dummy_args(model="butd", dataset="coco")
     args.opts.append(f"config={config_path}")

     configuration = Configuration(args)
     configuration.config.datasets = "coco"
     configuration.freeze()

     self.config = configuration.config
     registry.register("config", self.config)
Beispiel #15
0
    def setUp(self):
        """Create a minimal trainer namespace and a ``LogisticsCallback``."""
        self.tmpdir = tempfile.mkdtemp()
        trainer = argparse.Namespace()
        self.trainer = trainer

        training_options = {
            "checkpoint_interval": 1,
            "evaluation_interval": 10,
            "early_stop": {"criteria": "val/total_loss"},
            "batch_size": 16,
            "log_interval": 10,
            "logger_level": "info",
        }
        self.config = OmegaConf.create({
            "model": "simple",
            "model_config": {},
            "training": training_options,
            "env": {"save_dir": self.tmpdir},
        })

        # The trainer works on a copy; self.config stays as the original
        trainer.config = deepcopy(self.config)
        registry.register("config", trainer.config)
        setup_logger.cache_clear()
        setup_logger()

        report = Mock(spec=Report)
        report.dataset_name = "abcd"
        report.dataset_type = "test"
        self.report = report

        trainer.model = SimpleModule()
        trainer.val_dataset = NumbersDataset()
        trainer.optimizer = torch.optim.Adam(
            trainer.model.parameters(), lr=1e-01
        )
        trainer.device = "cpu"
        for counter_name in (
            "num_updates", "current_iteration", "current_epoch", "max_updates",
        ):
            setattr(trainer, counter_name, 0)
        trainer.meter = Meter()

        self.cb = LogisticsCallback(self.config, trainer)
Beispiel #16
0
 def setUp(self):
     """Build and register a frozen butd/coco config using nucleus sampling."""
     setup_imports()
     torch.manual_seed(1234)

     # Path of the project config, relative to the package root
     relative_parts = (
         "..", "projects", "butd", "configs", "coco", "nucleus_sampling.yaml",
     )
     config_path = os.path.abspath(
         os.path.join(get_multimodelity_root(), *relative_parts)
     )

     args = dummy_args(model="butd", dataset="coco")
     args.opts.append(f"config={config_path}")

     configuration = Configuration(args)
     test_config = configuration.config
     test_config.datasets = "coco"
     test_config.model_config.butd.inference.params.sum_threshold = 0.5
     configuration.freeze()

     self.config = configuration.config
     registry.register("config", self.config)
Beispiel #17
0
    def test_caption_bleu4(self):
        """Check ``CaptionBleu4Metric`` on a full and a partial match."""
        this_file = os.path.abspath(__file__)

        # Load the coco defaults and point the caption processor at the
        # test vocabulary file.
        defaults_path = os.path.join(
            this_file,
            "../../../multimodelity/configs/datasets/coco/defaults.yaml",
        )
        config = load_yaml(os.path.abspath(defaults_path))
        processor_config = (
            config.dataset_config.coco.processors.caption_processor
        )
        vocab_path = os.path.join(this_file, "..", "..", "data", "vocab.txt")
        processor_config.params.vocab.type = "random"
        processor_config.params.vocab.vocab_file = os.path.abspath(vocab_path)
        registry.register(
            "coco_caption_processor",
            CaptionProcessor(processor_config.params),
        )

        metric = metrics.CaptionBleu4Metric()
        expected = Sample()
        predicted = {}

        # Complete match: every predicted position scores index 4 highest
        expected.answers = torch.empty((5, 5, 10)).fill_(4)
        full_scores = torch.zeros((5, 10, 19))
        full_scores[:, :, 4] = 1.0
        predicted["scores"] = full_scores

        self.assertEqual(metric.calculate(expected, predicted).item(), 1.0)

        # Partial match: first five positions pick index 4, the rest index 18
        expected.answers = torch.empty((5, 5, 10)).fill_(4)
        partial_scores = torch.zeros((5, 10, 19))
        partial_scores[:, 0:5, 4] = 1.0
        partial_scores[:, 5:, 18] = 1.0
        predicted["scores"] = partial_scores

        self.assertAlmostEqual(
            metric.calculate(expected, predicted).item(), 0.3928, 4
        )
Beispiel #18
0
    def setUp(self):
        """Create a minimal trainer namespace with an ``LRSchedulerCallback``."""
        trainer = argparse.Namespace()
        self.trainer = trainer

        training_options = {
            "lr_scheduler": True,
            "lr_ratio": 0.1,
            "lr_steps": [1, 2],
            "use_warmup": False,
        }
        self.config = OmegaConf.create({
            "model": "simple",
            "model_config": {},
            "training": training_options,
        })

        # The trainer works on a copy; self.config stays as the original
        trainer.config = deepcopy(self.config)
        registry.register("config", trainer.config)

        trainer.model = SimpleModule()
        trainer.val_dataset = NumbersDataset()
        trainer.optimizer = torch.optim.Adam(
            trainer.model.parameters(), lr=1e-01
        )
        trainer.lr_scheduler_callback = LRSchedulerCallback(
            self.config, trainer
        )
Beispiel #19
0
    def parallelize_model(self) -> None:
        """Wrap ``self.model`` for multi-GPU execution when applicable.

        On a CUDA device, a non-distributed run with more than one visible
        GPU uses ``DataParallel``; a distributed run uses
        ``DistributedDataParallel`` on ``self.local_rank``. The registry
        flags ``"data_parallel"`` and ``"distributed"`` reflect the choice.
        """
        on_cuda = "cuda" in str(self.device)

        # Both flags start out false and are flipped only when wrapping occurs
        registry.register("data_parallel", False)
        registry.register("distributed", False)

        if on_cuda and torch.cuda.device_count() > 1 and not self.distributed:
            registry.register("data_parallel", True)
            self.model = torch.nn.DataParallel(self.model)

        if on_cuda and self.distributed:
            registry.register("distributed", True)
            find_unused = self.config.training.find_unused_parameters
            self.model = torch.nn.parallel.DistributedDataParallel(
                self.model,
                device_ids=[self.local_rank],
                output_device=self.local_rank,
                check_reduction=True,
                find_unused_parameters=find_unused,
            )
Beispiel #20
0
def setup_imports():
    """Import every dataset, model and trainer module so they self-register.

    Does nothing if the registry already has ``"imports_setup"`` set. The
    root folder comes from the ``"multimodelity_root"`` registry entry;
    otherwise from the ``multimodelity_PATH`` / ``PYTHIA_PATH`` environment
    variables, falling back to the parent of this file's directory.
    """
    from multimodelity.common.registry import registry

    # Nothing to do when a previous call already completed
    if registry.get("imports_setup", no_warning=True):
        return

    root_folder = registry.get("multimodelity_root", no_warning=True)
    if root_folder is None:
        default_root = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), ".."
        )
        environment_root = os.environ.get(
            "multimodelity_PATH", os.environ.get("PYTHIA_PATH"))
        root_folder = (
            default_root if environment_root is None else environment_root
        )

        registry.register("pythia_path", root_folder)
        registry.register("multimodelity_path", root_folder)

    importlib.import_module("multimodelity.common.meter")

    # Collect python files in the order: datasets, models, trainers
    files = []
    for package in ("datasets", "models", "trainers"):
        pattern = os.path.join(root_folder, package, "**", "*.py")
        files.extend(glob.glob(pattern, recursive=True))

    for path in files:
        path = os.path.realpath(path)
        if not path.endswith(".py") or path.endswith("__init__.py"):
            continue

        parts = path.split(os.sep)
        # The module path starts after the last "multimodelity" component
        prefix_index = max(
            (idx + 1 for idx, part in enumerate(parts)
             if part == "multimodelity"),
            default=0,
        )
        file_name = parts[-1]
        module_name = file_name[:file_name.find(".py")]
        dotted_path = ".".join(
            ["multimodelity", *parts[prefix_index:-1], module_name]
        )
        importlib.import_module(dotted_path)

    registry.register("imports_setup", True)
Beispiel #21
0
    def update_registry_for_pretrained(cls, config, checkpoint, full_output):
        """Seed the registry with the entries needed to load a pretrained model.

        Registers a stand-in config holding only ``datasets``, the number of
        final outputs for the first dataset, and a placeholder answer
        processor exposing ``BOS_IDX``.
        """
        from omegaconf import OmegaConf

        # Hack: mock a config that only carries the datasets string
        datasets = full_output["full_config"].datasets
        first_dataset = datasets.split(",")[0]
        registry.register("config", OmegaConf.create({"datasets": datasets}))

        # ocr_max_num is added back here because it is subtracted elsewhere
        classifier_rows = checkpoint["classifier.module.weight"].size(0)
        registry.register(
            f"{first_dataset}_num_final_outputs",
            classifier_rows + config.classifier.ocr_max_num,
        )

        # Placeholder until a processor pipeline is available
        placeholder_processor = OmegaConf.create({"BOS_IDX": 1})
        registry.register(
            f"{first_dataset}_answer_processor", placeholder_processor
        )
Beispiel #22
0
    def _load_counts_and_lr_scheduler(self, ckpt):
        """Restore trainer counters and LR scheduler state from ``ckpt``.

        Updates ``num_updates``, ``current_iteration`` and ``current_epoch``
        on ``self.trainer``, loads the scheduler state when present, and
        publishes the three counters to the registry.

        Args:
            ckpt: Checkpoint mapping; keys are looked up with ``in`` and
                ``.get`` so missing counters fall back to the trainer's
                current values.
        """
        ckpt_config = self.trainer.config.checkpoint
        # Checkpoints that carry "best_update" hold both best and current
        # counters; ``resume_best`` selects which pair is restored.
        if "best_update" in ckpt:
            if ckpt_config.resume_best:
                self.trainer.num_updates = ckpt.get("best_update",
                                                    self.trainer.num_updates)
                self.trainer.current_iteration = ckpt.get(
                    "best_iteration", self.trainer.current_iteration)
            else:
                self.trainer.num_updates = ckpt.get("num_updates",
                                                    self.trainer.num_updates)
                self.trainer.current_iteration = ckpt.get(
                    "current_iteration", self.trainer.current_iteration)

            self.trainer.current_epoch = ckpt.get("current_epoch",
                                                  self.trainer.current_epoch)
        elif "best_iteration" in ckpt:
            # Preserve old behavior for old checkpoints where we always
            # load best iteration
            # NOTE(review): with resume_best set this branch loads
            # "current_iteration" rather than "best_iteration", which reads
            # inverted relative to the branch above -- confirm intent.
            if ckpt_config.resume_best and "current_iteration" in ckpt:
                self.trainer.current_iteration = ckpt["current_iteration"]
            else:
                self.trainer.current_iteration = ckpt.get(
                    "best_iteration", self.trainer.current_iteration)

            # Old checkpoints have no update counter; mirror the iteration
            self.trainer.num_updates = self.trainer.current_iteration

        lr_scheduler = self.trainer.lr_scheduler_callback._scheduler
        if lr_scheduler is not None:
            if "lr_scheduler" in ckpt:
                lr_scheduler.load_state_dict(ckpt["lr_scheduler"])
            else:
                warnings.warn(
                    "'lr_scheduler' key is not present in the "
                    "checkpoint asked to be loaded. Setting lr_scheduler's "
                    "last_epoch to current_iteration.")
                lr_scheduler.last_epoch = self.trainer.current_iteration

        registry.register("current_iteration", self.trainer.current_iteration)
        registry.register("num_updates", self.trainer.num_updates)

        # NOTE(review): this overrides the "current_epoch" value restored
        # above whenever the checkpoint has a "best_epoch" key, regardless of
        # resume_best -- confirm intent.
        self.trainer.current_epoch = ckpt.get("best_epoch",
                                              self.trainer.current_epoch)
        registry.register("current_epoch", self.trainer.current_epoch)
Beispiel #23
0
 def setUp(self):
     """Build and register a frozen cnn_lstm/clevr config for the tests."""
     torch.manual_seed(1234)

     # Vocabulary and output sizes the model reads from the registry
     for registry_key, size in (
         ("clevr_text_vocab_size", 80),
         ("clevr_num_final_outputs", 32),
     ):
         registry.register(registry_key, size)

     # Path of the project config, relative to the package root
     relative_parts = (
         "..", "projects", "others", "cnn_lstm", "clevr", "defaults.yaml",
     )
     config_path = os.path.abspath(
         os.path.join(get_multimodelity_root(), *relative_parts)
     )

     args = dummy_args(model="cnn_lstm", dataset="clevr")
     args.opts.append(f"config={config_path}")

     configuration = Configuration(args)
     configuration.config.datasets = "clevr"
     configuration.freeze()

     self.config = configuration.config
     registry.register("config", self.config)
Beispiel #24
0
 def update_registry_for_model(self, config):
     """Register the masked-token vocabulary size for this dataset.

     The key is ``<dataset_name>_text_vocab_size``. ``config`` is not used.
     """
     registry_key = self.dataset_name + "_text_vocab_size"
     vocab_size = self.dataset.masked_token_processor.get_vocab_size()
     registry.register(registry_key, vocab_size)
Beispiel #25
0
    def run_training_epoch(self) -> None:
        """Run epochs over ``self.train_loader`` until training should stop.

        Despite the name, the outer loop keeps starting new epochs until
        ``self.num_updates`` reaches ``self.max_updates`` or a stop is
        requested (early stopping or the update limit). Within an epoch,
        every ``update_frequency`` batches form one update; metrics/meter
        updates happen every ``log_interval`` updates and validation every
        ``evaluation_interval`` updates.
        """
        should_break = False
        while self.num_updates < self.max_updates and not should_break:
            self.current_epoch += 1
            registry.register("current_epoch", self.current_epoch)

            # Seed the sampler in case if it is distributed
            self.dataset_loader.seed_sampler("train", self.current_epoch)

            # For iterable datasets we cannot determine length of dataset properly.
            # For those cases we set num_remaining_batches to be the (number of
            # updates remaining x update_frequency)
            num_remaining_batches = (((self.max_updates - self.num_updates) *
                                      self.training_config.update_frequency)
                                     if isinstance(
                                         self.train_loader.current_dataset,
                                         torch.utils.data.IterableDataset) else
                                     len(self.train_loader))

            combined_report = None
            num_batches_for_this_update = 1
            for idx, batch in enumerate(self.train_loader):

                # NOTE(review): this condition is true on the LAST batch of
                # an update window (same test as the "update finished" check
                # below), so with update_frequency > 1 it resets
                # combined_report right before the final batch and discards
                # what was accumulated -- confirm whether
                # `idx % update_frequency == 0` was intended.
                if (idx + 1) % self.training_config.update_frequency == 0:
                    combined_report = None
                    num_batches_for_this_update = min(
                        self.training_config.update_frequency,
                        num_remaining_batches)

                    self._start_update()

                # batch execution starts here
                self.on_batch_start()
                self.profile("Batch load time")

                report = self.run_training_batch(batch,
                                                 num_batches_for_this_update)

                # accumulate necessary params for metric calculation
                if combined_report is None:
                    combined_report = report
                else:
                    combined_report.accumulate_tensor_fields(
                        report, self.metrics.required_params)
                    combined_report.batch_size += report.batch_size

                # batch execution ends here
                self.on_batch_end(report=combined_report, meter=self.meter)

                # check if an update has finished, if no continue
                if (idx + 1) % self.training_config.update_frequency:
                    continue

                self._finish_update()

                should_log = False
                if self.num_updates % self.logistics_callback.log_interval == 0:
                    should_log = True
                    # Calculate metrics every log interval for debugging
                    if self.training_config.evaluate_metrics:
                        combined_report.metrics = self.metrics(
                            combined_report, combined_report)
                    self.update_meter(combined_report, self.meter)

                self.on_update_end(report=combined_report,
                                   meter=self.meter,
                                   should_log=should_log)

                num_remaining_batches -= num_batches_for_this_update

                # Check if training should be stopped
                should_break = False

                if self.num_updates % self.training_config.evaluation_interval == 0:
                    # Validation begin callbacks
                    self.on_validation_start()

                    logger.info(
                        "Evaluation time. Running on full validation set...")
                    # Validation and Early stopping
                    # Create a new meter for this case
                    report, meter = self.evaluation_loop(self.val_loader)

                    # Validation end callbacks
                    stop = self.early_stop_callback.on_validation_end(
                        report=report, meter=meter)
                    self.on_validation_end(report=report, meter=meter)

                    # Free memory held by the validation pass
                    gc.collect()

                    if "cuda" in str(self.device):
                        torch.cuda.empty_cache()

                    if stop is True:
                        logger.info("Early stopping activated")
                        should_break = True
                # Stop once the configured number of updates has been reached
                if self.num_updates >= self.max_updates:
                    should_break = True

                # Leaves the batch loop; the outer while then exits as well
                if should_break:
                    break
Beispiel #26
0
def setup_logger(
    output: str = None,
    color: bool = True,
    name: str = "multimodelity",
    disable: bool = False,
    clear_handlers=True,
    *args,
    **kwargs,
):
    """
    Initialize the multimodelity logger and set its verbosity level to "INFO".
    Outside libraries shouldn't call this in case they have set their
    own logging handlers and setup. If they do, and don't want to
    clear handlers, pass clear_handlers options.

    The initial version of this function was taken from D2 and adapted
    for multimodelity.

    Args:
        output (str): a file name or a directory to save log.
            If ends with ".txt" or ".log", assumed to be a file name;
            otherwise "train.log" inside that directory is used.
            Default: the value returned by ``setup_output_folder()``.
        color (bool): If false, won't log colored logs. Default: true
        name (str): the root module name of this logger. Defaults to "multimodelity".
        disable (bool): If true, nothing is set up and None is returned.
        clear_handlers (bool): If false, won't clear existing handlers.

    Returns:
        logging.Logger: a logger, or None when ``disable`` is true
    """
    if disable:
        return None
    logger = logging.getLogger(name)
    logger.propagate = False

    # Route warnings.warn(...) output through the "py.warnings" logger
    logging.captureWarnings(True)
    warnings_logger = logging.getLogger("py.warnings")

    plain_formatter = logging.Formatter(
        "%(asctime)s | %(levelname)s | %(name)s : %(message)s",
        datefmt="%Y-%m-%dT%H:%M:%S",
    )

    distributed_rank = get_rank()
    handlers = []

    # stdout logging: rank 0 only
    if distributed_rank == 0:
        logger.setLevel(logging.INFO)
        ch = logging.StreamHandler(stream=sys.stdout)
        ch.setLevel(logging.INFO)
        if color:
            formatter = ColorfulFormatter(
                colored("%(asctime)s | %(name)s: ", "green") + "%(message)s",
                datefmt="%Y-%m-%dT%H:%M:%S",
            )
        else:
            formatter = plain_formatter
        ch.setFormatter(formatter)
        logger.addHandler(ch)
        warnings_logger.addHandler(ch)
        handlers.append(ch)

    # file logging: all workers
    if output is None:
        output = setup_output_folder()

    if output is not None:
        if output.endswith(".txt") or output.endswith(".log"):
            filename = output
        else:
            filename = os.path.join(output, "train.log")
        # Non-zero ranks write to their own file, suffixed with the rank
        if distributed_rank > 0:
            filename = filename + f".rank{distributed_rank}"
        PathManager.mkdirs(os.path.dirname(filename))

        fh = logging.StreamHandler(_cached_log_stream(filename))
        fh.setLevel(logging.INFO)
        fh.setFormatter(plain_formatter)
        logger.addHandler(fh)
        warnings_logger.addHandler(fh)
        handlers.append(fh)

        # Slurm/FB output, only log the main process
        # (an additional "train.log" in the env save_dir when the main log
        # file name does not already contain "train.log")
        if "train.log" not in filename and distributed_rank == 0:
            save_dir = get_multimodelity_env(key="save_dir")
            filename = os.path.join(save_dir, "train.log")
            sh = logging.StreamHandler(_cached_log_stream(filename))
            sh.setLevel(logging.INFO)
            sh.setFormatter(plain_formatter)
            logger.addHandler(sh)
            warnings_logger.addHandler(sh)
            handlers.append(sh)

        logger.info(f"Logging to: {filename}")

    # Remove existing handlers to add multimodelity specific handlers
    if clear_handlers:
        for handler in logging.root.handlers[:]:
            logging.root.removeHandler(handler)
    # Now, add our handlers.
    logging.basicConfig(level=logging.INFO, handlers=handlers)

    # Make the logger retrievable from the registry under "writer"
    registry.register("writer", logger)

    return logger