コード例 #1
0
ファイル: test_reporter.py プロジェクト: Mokashaa/mmf
    def __init__(self, multi_task_instance):
        """Prepare a reporter that collects predictions for the test datasets.

        Args:
            multi_task_instance: object exposing ``dataset_type`` and
                ``get_datasets()``; the last dataset it yields is the one
                selected initially.

        Reports are written to ``<save_dir>/<ckpt name + override suffix>/reports``
        unless the ``report_dir`` env value is truthy, in which case that value
        is used verbatim. The chosen folder is created here.
        """
        self.test_task = multi_task_instance
        self.task_type = multi_task_instance.dataset_type

        # Shared objects published in the global registry.
        self.config = registry.get("config")
        self.writer = registry.get("writer")
        self.report = []
        self.timer = Timer()

        training_config = self.config.training
        self.training_config = training_config
        self.num_workers = training_config.num_workers
        self.batch_size = training_config.batch_size
        self.report_folder_arg = get_mmf_env(key="report_dir")
        self.experiment_name = training_config.experiment_name

        # Materialise every dataset; index -1 selects the last one first.
        self.datasets = list(self.test_task.get_datasets())
        self.current_dataset_idx = -1
        self.current_dataset = self.datasets[self.current_dataset_idx]

        self.save_dir = get_mmf_env(key="save_dir")
        run_folder = (ckpt_name_from_core_args(self.config)
                      + foldername_from_config_override(self.config))
        derived_folder = os.path.join(self.save_dir, run_folder, "reports")
        # A truthy report_dir env value wins over the derived location.
        self.report_folder = self.report_folder_arg or derived_folder

        PathManager.mkdirs(self.report_folder)
コード例 #2
0
def setup_output_folder(folder_only: bool = False):
    """Return the log folder, or the path of a fresh log file inside it.

    The folder defaults to ``<save_dir>/logs``; a truthy ``log_dir`` env value
    replaces it entirely. The folder is created when it does not exist yet.

    Args:
        folder_only (bool, optional): If True, return only the folder instead
            of a file path. Defaults to False.

    Returns:
        str: the folder, or ``<folder>/train_<timestamp>.log`` where the
        timestamp follows ``%Y_%m_%dT%H_%M_%S``.
    """
    save_dir = get_mmf_env(key="save_dir")
    stamp = Timer().get_time_hhmmss(None, format="%Y_%m_%dT%H_%M_%S")
    file_name = "train_" + stamp + ".log"

    override_dir = get_mmf_env(key="log_dir")
    log_folder = override_dir if override_dir else os.path.join(save_dir, "logs")

    if not PathManager.exists(log_folder):
        PathManager.mkdirs(log_folder)

    return log_folder if folder_only else os.path.join(log_folder, file_name)
コード例 #3
0
ファイル: download.py プロジェクト: zeta1999/mmf
def download_pretrained_model(model_name, *args, **kwargs):
    """Download the mmf model-zoo resources for ``model_name``.

    Requirements listed under ``zoo_requirements`` (a string or a list) are
    downloaded first, recursively. The actual download happens on the master
    process only; every process then waits in ``synchronize()``.

    Args:
        model_name (str): dotted key of the model in the model zoo yaml.
        *args, **kwargs: forwarded to the recursive requirement downloads.

    Returns:
        str: ``<data_dir>/models/<model_name>``, or
        ``<data_dir>/models/<model_name>.defaults`` when the entry's
        ``defaults`` child had to be used.

    Raises:
        ValueError: if ``model_name`` does not exist in the model zoo.
        omegaconf.errors.OmegaConfBaseException: if the zoo lookup itself
            fails, or the entry has no version/resources and no ``defaults``.
    """
    import omegaconf
    from omegaconf import OmegaConf

    from mmf.utils.configuration import load_yaml, get_mmf_env

    model_zoo = load_yaml(get_mmf_env(key="model_zoo"))
    OmegaConf.set_struct(model_zoo, True)
    OmegaConf.set_readonly(model_zoo, True)

    data_dir = get_absolute_path(get_mmf_env("data_dir"))
    model_data_dir = os.path.join(data_dir, "models")
    download_path = os.path.join(model_data_dir, model_name)

    try:
        model_config = OmegaConf.select(model_zoo, model_name)
    except omegaconf.errors.OmegaConfBaseException:
        print(f"No such model name {model_name} defined in mmf zoo")
        raise

    # OmegaConf.select returns None for an unknown key instead of raising;
    # without this check the membership test below fails with a TypeError.
    if model_config is None:
        raise ValueError(f"No such model name {model_name} defined in mmf zoo")

    if "version" not in model_config or "resources" not in model_config:
        # Version and Resources are not present time to try the defaults
        try:
            model_config = model_config.defaults
            download_path = os.path.join(model_data_dir, model_name + ".defaults")
        except omegaconf.errors.OmegaConfBaseException:
            print(
                f"Model name {model_name} doesn't specify 'resources' and 'version' "
                "while no defaults have been provided"
            )
            raise

    # Download requirements if any specified by "zoo_requirements" field
    # This can either be a list or a string
    if "zoo_requirements" in model_config:
        requirements = model_config.zoo_requirements
        if isinstance(requirements, str):
            requirements = [requirements]
        for item in requirements:
            download_pretrained_model(item, *args, **kwargs)

    version = model_config.version
    resources = model_config.resources

    if is_master():
        download_resources(resources, download_path, version)
    synchronize()

    return download_path
コード例 #4
0
    def _get_model_folder(self):
        """Return the folder of the ``mmbt.hateful_memes.images`` zoo model.

        The path is ``<home>/<data_dir>/models/mmbt.hateful_memes.images``.
        Note that ``os.path.join`` drops the home prefix when ``data_dir`` is
        already absolute.
        """
        segments = (
            str(Path.home()),
            get_mmf_env(key="data_dir"),
            "models",
            "mmbt.hateful_memes.images",
        )
        return os.path.join(*segments)
コード例 #5
0
    def __init__(self, config, trainer):
        """Set up the timer, logging intervals and the optional TensorBoard writer.

        Attr:
            config(mmf_typings.DictConfig): Config for the callback
            trainer(Type[BaseTrainer]): Trainer object
        """
        super().__init__(config, trainer)

        training_config = self.training_config
        self.total_timer = Timer()
        self.log_interval = training_config.log_interval
        self.evaluation_interval = training_config.evaluation_interval
        self.checkpoint_interval = training_config.checkpoint_interval

        # Snapshot length in iterations: validation samples // batch size.
        self.snapshot_iterations = (
            len(self.trainer.val_dataset) // training_config.batch_size
        )

        self.tb_writer = None
        if not training_config.tensorboard:
            return

        # The default folder is always resolved (and created); a truthy
        # tensorboard_logdir env value then takes precedence over it.
        default_dir = setup_output_folder(folder_only=True)
        override_dir = get_mmf_env(key="tensorboard_logdir")
        log_dir = override_dir if override_dir else default_dir
        self.tb_writer = TensorboardLogger(log_dir, self.trainer.current_iteration)
コード例 #6
0
    def __init__(self, trainer):
        """
        Generates a path for saving model which can also be used for resuming
        from a checkpoint.

        Creates ``<save_dir>/models`` if needed, calls ``save_config()`` and,
        when the ``git`` module is truthy (presumably None when GitPython is
        unavailable — confirm at the import site) and the package lives inside
        a git checkout, keeps a ``git.Repo`` handle in ``self.git_repo``.

        Args:
            trainer: trainer exposing ``config`` and ``model``.
        """
        self.trainer = trainer

        self.config = self.trainer.config
        self.save_dir = get_mmf_env(key="save_dir")
        self.model_name = self.config.model

        self.ckpt_foldername = self.save_dir

        self.device = registry.get("current_device")

        self.ckpt_prefix = ""

        if hasattr(self.trainer.model, "get_ckpt_name"):
            self.ckpt_prefix = self.trainer.model.get_ckpt_name() + "_"

        self.pth_filepath = os.path.join(
            self.ckpt_foldername,
            self.ckpt_prefix + self.model_name + "_final.pth")

        self.models_foldername = os.path.join(self.ckpt_foldername, "models")
        if not PathManager.exists(self.models_foldername):
            PathManager.mkdirs(self.models_foldername)

        self.save_config()
        self.repo_path = updir(os.path.abspath(__file__), n=3)
        self.git_repo = None
        if git:
            try:
                self.git_repo = git.Repo(self.repo_path)
            except git.exc.InvalidGitRepositoryError:
                # The code is not inside a git checkout; git details are
                # simply unavailable, which must not abort checkpointing.
                self.git_repo = None
コード例 #7
0
    def __init__(self, config: BatchProcessorConfigType, *args, **kwargs):
        """Build the processors declared under ``config["processors"]``.

        Each processor is built with the mmf env ``data_dir`` passed as an
        extra ``data_dir`` keyword. Extra positional and keyword arguments are
        accepted but not used.
        """
        data_dir = get_mmf_env(key="data_dir")
        processor_specs = config.get("processors", {})

        # Imported at call time: mmf.utils.build imports the processors module
        # too, so a module-level import would be circular.
        from mmf.utils.build import build_processors

        self.processors = build_processors(processor_specs, data_dir=data_dir)
コード例 #8
0
    def __init__(self,
                 multi_task_instance,
                 test_reporter_config: TestReporterConfigType = None):
        """Prepare a reporter for the test datasets of ``multi_task_instance``.

        Args:
            multi_task_instance: object exposing ``dataset_type`` and
                ``get_datasets()``; the last dataset it yields is selected
                first.
            test_reporter_config: a ``TestReporterConfigType`` instance or a
                mapping of its fields. None (the default) means no overrides.

        Reports go to ``<save_dir>/<ckpt name + override suffix>/reports``
        unless the ``report_dir`` env value is truthy. The folder is created
        here.
        """
        if test_reporter_config is None:
            # ``**None`` raises TypeError, so the declared default used to
            # crash; treat it as an empty set of overrides instead.
            test_reporter_config = {}
        if not isinstance(test_reporter_config,
                          TestReporter.TestReporterConfigType):
            test_reporter_config = TestReporter.TestReporterConfigType(
                **test_reporter_config)
        self.test_task = multi_task_instance
        self.task_type = multi_task_instance.dataset_type
        self.config = registry.get("config")
        self.report = []
        self.timer = Timer()
        self.training_config = self.config.training
        self.num_workers = self.training_config.num_workers
        self.batch_size = self.training_config.batch_size
        self.report_folder_arg = get_mmf_env(key="report_dir")
        self.experiment_name = self.training_config.experiment_name
        self.test_reporter_config = test_reporter_config

        self.datasets = list(self.test_task.get_datasets())

        self.current_dataset_idx = -1
        self.current_dataset = self.datasets[self.current_dataset_idx]

        self.save_dir = get_mmf_env(key="save_dir")
        self.report_folder = ckpt_name_from_core_args(self.config)
        self.report_folder += foldername_from_config_override(self.config)

        self.report_folder = os.path.join(self.save_dir, self.report_folder)
        self.report_folder = os.path.join(self.report_folder, "reports")

        # An explicit report_dir env value overrides the derived folder.
        if self.report_folder_arg:
            self.report_folder = self.report_folder_arg

        self.candidate_fields = DEFAULT_CANDIDATE_FIELDS

        # MISSING means the field was not configured: keep the defaults.
        if test_reporter_config.candidate_fields != MISSING:
            self.candidate_fields = test_reporter_config.candidate_fields

        PathManager.mkdirs(self.report_folder)
コード例 #9
0
    def _load_loggers(self) -> None:
        """Create ``self.tb_writer`` when ``training.tensorboard`` is enabled.

        ``self.tb_writer`` stays None otherwise. The default output folder is
        always resolved first; a truthy ``tensorboard_logdir`` env value then
        replaces it.
        """
        self.tb_writer = None
        if not self.training_config.tensorboard:
            return

        # TODO: @sash PL logger upgrade
        default_dir = setup_output_folder(folder_only=True)
        override_dir = get_mmf_env(key="tensorboard_logdir")
        self.tb_writer = TensorboardLogger(
            override_dir if override_dir else default_dir
        )
コード例 #10
0
def gqa_feature_loader():
    """Load (once) and return the GQA spatial features and their metadata.

    The results are cached in the module-level ``img`` / ``img_info`` globals,
    so later calls return the cached objects without reading from disk again.

    Returns:
        tuple: ``(img, img_info)``. ``img`` is the ``"features"`` entry of
        ``<data_dir>/datasets/gqa/defaults/features/gqa_spatial.hdf5``, and
        ``img_info`` is the parsed ``gqa_spatial_merged_info.json`` next to it.
    """
    global img, img_info
    if img is not None:
        return img, img_info
    path = os.path.join(get_mmf_env("data_dir"), "datasets", "gqa", "defaults",
                        "features")
    # The HDF5 file is intentionally kept open: ``img`` refers into it.
    h = h5py.File(os.path.join(path, "gqa_spatial.hdf5"), "r")
    # Close the JSON file deterministically instead of leaking the handle.
    with open(os.path.join(path, "gqa_spatial_merged_info.json"), "r") as f:
        info = json.load(f)
    # Publish both globals together, so a failed load cannot leave ``img``
    # cached while ``img_info`` is still unset.
    img, img_info = h["features"], info
    return img, img_info
コード例 #11
0
    def __init__(
        self,
        datamodules: List[pl.LightningDataModule],
        config: Config = None,
        dataset_type: str = "train",
    ):
        """Prepare a reporter over the given datamodules.

        Args:
            datamodules: despite the ``List`` annotation, this is used as a
                mapping from dataset name to datamodule (``keys()`` and
                ``[name]`` lookups). The last name is selected first.
            config: overrides merged over the ``self.Config`` structured
                defaults. None (the default) means no overrides.
            dataset_type (str): split being reported on. Defaults to "train".

        Reports go to ``<save_dir>/<ckpt name + override suffix>/reports``
        unless the ``report_dir`` env value is truthy. The folder is created
        here.
        """
        # OmegaConf.merge refuses to merge a None config, so the declared
        # default used to crash; merge an empty mapping to keep the defaults.
        if config is None:
            config = {}
        self.test_reporter_config = OmegaConf.merge(
            OmegaConf.structured(self.Config), config
        )
        self.datamodules = datamodules
        self.dataset_type = dataset_type
        self.config = registry.get("config")
        self.report = []
        self.timer = Timer()
        self.training_config = self.config.training
        self.num_workers = self.training_config.num_workers
        self.batch_size = self.training_config.batch_size
        self.report_folder_arg = get_mmf_env(key="report_dir")
        self.experiment_name = self.training_config.experiment_name

        self.current_datamodule_idx = -1
        self.dataset_names = list(self.datamodules.keys())
        self.current_datamodule = self.datamodules[
            self.dataset_names[self.current_datamodule_idx]
        ]
        self.current_dataloader = None

        self.save_dir = get_mmf_env(key="save_dir")
        self.report_folder = ckpt_name_from_core_args(self.config)
        self.report_folder += foldername_from_config_override(self.config)

        self.report_folder = os.path.join(self.save_dir, self.report_folder)
        self.report_folder = os.path.join(self.report_folder, "reports")

        # An explicit report_dir env value overrides the derived folder.
        if self.report_folder_arg:
            self.report_folder = self.report_folder_arg

        self.candidate_fields = self.test_reporter_config.candidate_fields

        PathManager.mkdirs(self.report_folder)

        log_class_usage("TestReporter", self.__class__)
コード例 #12
0
    def __init__(self, config, trainer):
        """Set up timers, intervals and the optional TensorBoard / W&B loggers.

        Attr:
            config(mmf_typings.DictConfig): Config for the callback
            trainer(Type[BaseTrainer]): Trainer object
        """
        super().__init__(config, trainer)

        training_config = self.training_config
        self.total_timer = Timer()
        self.log_interval = training_config.log_interval
        self.evaluation_interval = training_config.evaluation_interval
        self.checkpoint_interval = training_config.checkpoint_interval

        # Total iterations for snapshot: len of the validation loader, i.e.
        # the number of batches per GPU.
        self.snapshot_iterations = len(self.trainer.val_loader)

        self.tb_writer = None
        self.wandb_logger = None

        if training_config.tensorboard:
            tb_dir = setup_output_folder(folder_only=True)
            tb_override = get_mmf_env(key="tensorboard_logdir")
            if tb_override:
                tb_dir = tb_override
            self.tb_writer = TensorboardLogger(tb_dir, self.trainer.current_iteration)

        if training_config.wandb.enabled:
            wandb_dir = setup_output_folder(folder_only=True)
            wandb_override = get_mmf_env(key="wandb_logdir")
            if wandb_override:
                wandb_dir = wandb_override
            # NOTE(review): wandb_dir is resolved (and the default log folder
            # created) but never handed to WandbLogger — confirm whether it
            # should be forwarded.
            wandb_config = config.training.wandb
            self.wandb_logger = WandbLogger(
                entity=wandb_config.entity,
                config=config,
                project=wandb_config.project,
            )
コード例 #13
0
ファイル: test_env.py プロジェクト: vishalbelsare/pythia
    def test_import_user_module_from_file(self):
        """Importing a single user file registers its model but no builder."""
        # Precondition: nothing from the user dir is registered yet.
        self.assertIsNone(registry.get_builder_class("always_one"))
        self.assertIsNone(registry.get_model_class("simple"))

        user_dir = self._get_user_dir()
        import_user_module(os.path.join(user_dir, "models", "simple.py"))

        # Only the model file was imported, so the builder stays unregistered.
        self.assertIsNone(registry.get_builder_class("always_one"))
        self.assertIsNotNone(registry.get_model_class("simple"))
        self.assertIn("mmf_user_dir", sys.modules)
        self.assertIn(user_dir, get_mmf_env("user_dir"))
コード例 #14
0
 def configure_monitor_callbacks(self) -> List[ModelCheckpoint]:
     """Return a one-element list with a callback that keeps the best checkpoint.

     The monitored metric and its mode come from ``self.monitor_criteria()``.
     Only the top-1 checkpoint is kept, under the file name ``best`` in the
     mmf ``save_dir``; no "last" checkpoint is written.
     """
     monitored_metric, direction = self.monitor_criteria()
     best_checkpoint = ModelCheckpoint(
         monitor=monitored_metric,
         dirpath=get_mmf_env(key="save_dir"),
         filename="best",
         mode=direction,
         save_top_k=1,
         save_last=False,
         verbose=True,
     )
     return [best_checkpoint]
コード例 #15
0
 def configure_checkpoint_callbacks(self) -> List[ModelCheckpoint]:
     """Return a one-element list with the periodic (unmonitored) checkpoint callback.

     A checkpoint ``models/model_{step}`` is saved every
     ``training.checkpoint_interval`` steps in the mmf ``save_dir``, all of
     them are kept, and the rolling "last" checkpoint is named ``current``.
     """
     every_n_steps = self.config.training.checkpoint_interval
     save_dir = get_mmf_env(key="save_dir")
     periodic_checkpoint = ModelCheckpoint(
         monitor=None,
         every_n_train_steps=every_n_steps,
         dirpath=save_dir,
         filename="models/model_{step}",
         save_top_k=-1,
         save_last=True,
         verbose=True,
     )
     # Rename the rolling "last" checkpoint to "current".
     periodic_checkpoint.CHECKPOINT_NAME_LAST = "current"
     return [periodic_checkpoint]
コード例 #16
0
    def is_zoo_path(self, path) -> bool:
        """Return True if ``path`` selects an entry in the mmf model zoo.

        A missing key (``OmegaConf.select`` returning None) and any OmegaConf
        error raised during the lookup both count as "not a zoo path".
        """
        from mmf.utils.configuration import get_mmf_env, load_yaml

        zoo = load_yaml(get_mmf_env(key="model_zoo"))
        OmegaConf.set_struct(zoo, True)
        OmegaConf.set_readonly(zoo, True)

        try:
            entry = OmegaConf.select(zoo, path)
        except omegaconf.errors.OmegaConfBaseException:
            return False
        return entry is not None
コード例 #17
0
 def _create_checkpoint_file(self, path):
     """Save the ``mmbt.hateful_memes.images`` checkpoint to ``path`` with its config embedded.

     Reads ``model.pth`` and ``config.yaml`` from
     ``<home>/<data_dir>/models/mmbt.hateful_memes.images``. The loaded config
     is stored under the checkpoint's ``"config"`` key before saving.
     """
     model_folder = os.path.join(str(Path.home()), get_mmf_env(key="data_dir"),
                                 "models", "mmbt.hateful_memes.images")
     config = load_yaml(os.path.join(model_folder, "config.yaml"))
     with PathManager.open(os.path.join(model_folder, "model.pth"), "rb") as f:
         # NOTE: torch.load unpickles; only use with trusted checkpoint files.
         checkpoint = torch.load(f)
     checkpoint["config"] = config
     torch.save(checkpoint, path)
コード例 #18
0
    def _load_trainer(self):
        """Build the PyTorch Lightning ``Trainer`` from ``self.trainer_config``.

        ``max_steps`` and ``max_epochs`` are removed from the forwarded
        parameters; ``max_steps`` is supplied from ``self._max_updates``
        instead. Every other entry is passed as a keyword argument, with
        interpolations resolved. ``default_root_dir`` is the mmf ``log_dir``.

        Raises:
            KeyError: if ``max_steps`` or ``max_epochs`` is absent from the
                trainer config.
        """
        # Pop from a resolved plain-dict copy instead of the config itself:
        # popping in place mutated the shared self.trainer_config, which made
        # a second call raise and hid these keys from later readers.
        lightning_params_dict = OmegaConf.to_container(self.trainer_config,
                                                       resolve=True)
        lightning_params_dict.pop("max_steps")
        lightning_params_dict.pop("max_epochs")

        self.trainer = Trainer(callbacks=self._callbacks,
                               max_steps=self._max_updates,
                               default_root_dir=get_mmf_env(key="log_dir"),
                               **lightning_params_dict)
コード例 #19
0
ファイル: test_env.py プロジェクト: vishalbelsare/pythia
    def test_import_user_module_from_directory_absolute(self, abs_path=True):
        """Importing a whole user directory registers both builder and model."""
        # Nothing from the user directory may be registered beforehand.
        self.assertIsNone(registry.get_builder_class("always_one"))
        self.assertIsNone(registry.get_model_class("simple"))
        self.assertNotIn("mmf_user_dir", sys.modules)

        user_dir = self._get_user_dir(abs_path)
        import_user_module(user_dir)

        self.assertIsNotNone(registry.get_builder_class("always_one"))
        self.assertIsNotNone(registry.get_model_class("simple"))
        self.assertIn("mmf_user_dir", sys.modules)
        self.assertIn(user_dir, get_mmf_env("user_dir"))
コード例 #20
0
 def build(self, config, *args, **kwargs):
     """Verify that the manual download is present, then build as usual.

     Raises:
         AssertionError: if ``<data_dir>/annotations/train.jsonl`` cannot be
             found. It is raised explicitly (not via ``assert``) so the check
             still runs under ``python -O``.
     """
     # First, check whether manual downloads have been performed
     data_dir = get_mmf_env(key="data_dir")
     test_path = get_absolute_path(
         os.path.join(data_dir, "annotations", "train.jsonl"))
     # NOTE: Only train.jsonl is checked; the other files are assumed present.
     if not PathManager.exists(test_path):
         raise AssertionError(
             "Hateful Memes Dataset doesn't do automatic downloads; please "
             "follow instructions at https://fb.me/hm_prerequisites")
     super().build(config, *args, **kwargs)
コード例 #21
0
ファイル: base_trainer.py プロジェクト: zeta1999/mmf
    def load_extras(self):
        """Create the trainer's helpers and read its training settings.

        Sets up the checkpoint handler, the meter, early stopping and the step
        counters, calls ``self.checkpoint.load_state_dict()``, and then creates
        the optional learning-rate scheduler and TensorBoard logger.
        """
        self.writer.write("Torch version is: " + torch.__version__)
        self.checkpoint = Checkpoint(self)
        self.meter = Meter()

        self.training_config = self.config.training
        cfg = self.training_config

        stop_cfg = cfg.early_stop
        stop_criteria = stop_cfg.criteria
        stop_minimize = stop_cfg.minimize
        stop_enabled = stop_cfg.enabled
        stop_patience = stop_cfg.patience

        self.log_interval = cfg.log_interval
        self.evaluation_interval = cfg.evaluation_interval
        self.checkpoint_interval = cfg.checkpoint_interval
        self.max_updates = cfg.max_updates
        self.should_clip_gradients = cfg.clip_gradients
        self.max_epochs = cfg.max_epochs

        self.early_stopping = EarlyStopping(
            self.model,
            self.checkpoint,
            stop_criteria,
            patience=stop_patience,
            minimize=stop_minimize,
            should_stop=stop_enabled,
        )

        # Counters are zeroed before load_state_dict() runs — presumably so a
        # resumed checkpoint can overwrite them; keep this ordering.
        self.current_epoch = self.current_iteration = self.num_updates = 0

        self.checkpoint.load_state_dict()

        self.not_debug = self.training_config.logger_level != "debug"

        # Only a literal True enables the scheduler (identity check kept).
        self.lr_scheduler = None
        if self.training_config.lr_scheduler is True:
            self.lr_scheduler = build_scheduler(self.optimizer, self.config)

        self.tb_writer = None
        if not self.training_config.tensorboard:
            return

        default_dir = self.writer.log_dir
        override_dir = get_mmf_env(key="tensorboard_logdir")
        log_dir = override_dir if override_dir else default_dir
        self.tb_writer = TensorboardLogger(log_dir, self.current_iteration)
コード例 #22
0
ファイル: builder.py プロジェクト: lilyli2004/mmf
 def build(self, config, *args, **kwargs):
     """Verify that the manually downloaded annotations exist, then build as usual.

     Raises:
         AssertionError: if
             ``<data_dir>/datasets/<dataset_name>/defaults/annotations/train.jsonl``
             cannot be found. It is raised explicitly (not via ``assert``) so
             the check still runs under ``python -O``.
     """
     # First, check whether manual downloads have been performed
     data_dir = get_mmf_env(key="data_dir")
     test_path = get_absolute_path(
         os.path.join(
             data_dir,
             "datasets",
             self.dataset_name,
             "defaults",
             "annotations",
             "train.jsonl",
         ))
     # NOTE: Only train.jsonl is checked; the other files are assumed present.
     if not PathManager.exists(test_path):
         raise AssertionError(
             f"{self.dataset_name} annotations not found at {test_path}; "
             "this dataset doesn't do automatic downloads")
     super().build(config, *args, **kwargs)
コード例 #23
0
    def _download_requirement(self,
                              config,
                              requirement_key,
                              requirement_variation="defaults"):
        """Download the zoo resources for a single dataset requirement.

        Args:
            config: dataset config. When the zoo resources are a mapping, its
                ``use_features`` / ``use_images`` flags select which groups are
                fetched; ``annotations`` and ``extras`` are always fetched.
            requirement_key (str): ``"<dataset>"`` or ``"<dataset>.<variation>"``.
            requirement_variation (str): variation used when the key carries
                none. Defaults to "defaults".

        Nothing happens when the zoo reports no resources for the key.
        """
        version, resources = get_zoo_config(requirement_key,
                                            requirement_variation,
                                            self.zoo_config_path,
                                            self.zoo_type)
        if resources is None:
            return

        dataset_name, *key_rest = requirement_key.split(".")
        # A variation embedded in the key takes precedence over the argument.
        dataset_variation = key_rest[0] if key_rest else requirement_variation

        # We want to use root env data_dir so that we don't mix up our download
        # root dir with the dataset ones
        download_path = get_absolute_path(
            os.path.join(get_mmf_env("data_dir"), "datasets", dataset_name,
                         dataset_variation))

        if not isinstance(resources, collections.abc.Mapping):
            self._download_resources(resources, download_path, version)
            return

        wants_features = config.get("use_features", False)
        wants_images = config.get("use_images", False)
        for attribute, wanted in (("features", wants_features),
                                  ("images", wants_images),
                                  ("annotations", True)):
            if wanted:
                self._download_based_on_attribute(resources, download_path,
                                                  version, attribute)
        self._download_resources(resources.get("extras", []), download_path,
                                 version)
コード例 #24
0
ファイル: general.py プロジェクト: snie2012/mmf
def get_absolute_path(paths):
    """Resolve ``paths`` to an existing absolute path where possible.

    An absolute string is returned unchanged. For a relative string, these
    candidates are tried in order and the first that exists is returned:
    the path as given; relative to the ``user_dir`` env value (directly, then
    under the folder above the mmf root) when it is set; relative to the
    folder above the mmf root; relative to the mmf root. Local hits are made
    absolute, while URIs (containing ``"://"``) are returned as found. If none
    exists, the input is returned unchanged so the caller's open fails on it.

    Any other iterable is resolved element by element into a list.

    Raises:
        TypeError: if ``paths`` is neither a string nor an iterable.
    """
    # The str check must come first: a str is itself an Iterable.
    if not isinstance(paths, str):
        if isinstance(paths, collections.abc.Iterable):
            return [get_absolute_path(path) for path in paths]
        raise TypeError("Paths passed to dataset should either be "
                        "string or list")

    if os.path.isabs(paths):
        return paths

    from mmf.utils.configuration import get_mmf_env

    mmf_root = get_mmf_root()
    above_root = os.path.join(mmf_root, "..")
    user_dir = get_mmf_env(key="user_dir")

    candidates = [paths]
    if user_dir:
        candidates.extend([
            os.path.join(user_dir, paths),
            os.path.join(above_root, user_dir, paths),
        ])
    candidates.extend([
        os.path.join(above_root, paths),
        os.path.join(mmf_root, paths),
    ])

    for candidate in candidates:
        if not PathManager.exists(candidate):
            continue
        return candidate if "://" in candidate else os.path.abspath(candidate)

    return paths
コード例 #25
0
    def __init__(self, trainer):
        """
        Generates a path for saving model which can also be used for resuming
        from a checkpoint.

        Creates ``<save_dir>/models`` when missing and calls ``save_config()``.
        When the ``git`` module is truthy and ``checkpoint.save_git_details`` is
        enabled, a ``git.Repo`` for the package checkout is kept in
        ``self.git_repo``; a non-git checkout leaves it None.
        """
        self.trainer = trainer
        self.config = trainer.config
        self.save_dir = get_mmf_env(key="save_dir")
        self.model_name = self.config.model
        self.ckpt_foldername = self.save_dir
        self.device = registry.get("current_device")

        model = self.trainer.model
        if hasattr(model, "get_ckpt_name"):
            self.ckpt_prefix = model.get_ckpt_name() + "_"
        else:
            self.ckpt_prefix = ""

        final_ckpt_name = self.ckpt_prefix + self.model_name + "_final.pth"
        self.pth_filepath = os.path.join(self.ckpt_foldername, final_ckpt_name)

        self.models_foldername = os.path.join(self.ckpt_foldername, "models")
        if not PathManager.exists(self.models_foldername):
            PathManager.mkdirs(self.models_foldername)

        self.save_config()

        self.repo_path = updir(os.path.abspath(__file__), n=3)
        self.git_repo = None
        if git and self.config.checkpoint.save_git_details:
            try:
                self.git_repo = git.Repo(self.repo_path)
            except git.exc.InvalidGitRepositoryError:
                # Not a git checkout: git details are simply unavailable.
                pass

        checkpoint_config = self.config.checkpoint
        self.max_to_keep = checkpoint_config.max_to_keep
        self.saved_iterations = []
コード例 #26
0
    def __init__(self, config, dataset_type, imdb_file_index, *args, **kwargs):
        """Visual Genome dataset with optional scene-graph and region databases.

        The ``return_scene_graph``, ``return_objects``, ``return_relationships``
        and ``return_region_descriptions`` config flags decide which auxiliary
        databases are built; ``no_unk`` defaults to False when absent.
        """
        # Positional *args still land right after imdb_file_index, exactly as
        # when they were written after the keyword argument.
        super().__init__(config,
                         dataset_type,
                         imdb_file_index,
                         *args,
                         dataset_name="visual_genome",
                         **kwargs)

        self._return_scene_graph = config.return_scene_graph
        self._return_objects = config.return_objects
        self._return_relationships = config.return_relationships
        self._return_region_descriptions = config.return_region_descriptions
        self._no_unk = config.get("no_unk", False)
        self.scene_graph_db = None
        self.region_descriptions_db = None
        self.image_metadata_db = None
        self._max_feature = config.max_features

        # Objects and relationships are both read from the scene graph.
        needs_scene_graph = any((self._return_scene_graph,
                                 self._return_objects,
                                 self._return_relationships))

        if self._return_region_descriptions:
            print("use_region_descriptions_true")
            self.region_descriptions_db = self.build_region_descriptions_db()
            self.image_metadata_db = self.build_image_metadata_db()

        if not needs_scene_graph:
            return

        relative_file = config.scene_graph_files[dataset_type][imdb_file_index]
        print("scene_graph_file", relative_file)
        scene_graph_file = get_absolute_path(
            "/".join((get_mmf_env("data_dir"), relative_file)))
        print("scene_graph_file", scene_graph_file)
        self.scene_graph_db = SceneGraphDatabase(config, scene_graph_file)
        print("use_scene_graph_true")
        # NOTE(review): the database built from the data_dir-based path just
        # above is immediately replaced by build_scene_graph_db(); confirm
        # which one is intended — as written the first load is wasted work.
        self.scene_graph_db = self.build_scene_graph_db()
コード例 #27
0
    def get_checkpoint_data(self) -> Dict[str, Any]:
        """This function gets checkpoint file path on disk from
        config.trainer.params.resume_from_checkpoint. However if it is not
        specified, it gets checkpoint path from config.checkpoint. If
        config.resume is specified it gets the latest checkpoint from the
        config's save directory (alternatively it gets the best checkpoint if
        config.resume_best is True). If config.resume is not specified, then it
        gets config.resume_file or the checkpoint file from config.resume_zoo
        (in that order).

        Returns:
            Dict[str, Any]: a dict containing the following keys,
            `checkpoint_path` (str) local file path for the checkpoint;
            `ckpt` (Dict[str, Any])
            `is_zoo` (Bool) whether or not the checkpoint is specified through a
                zoo identifier
            `config` (Dict[str, Any]) the config that is stored together with
                this checkpoint

        Raises:
            RuntimeError: if a resume file is requested, does not exist, and no
                resume_zoo fallback is configured.
        """
        # get ckpt file path from config.trainer.params.resume_from_checkpoint
        path = self.config.trainer.params.get("resume_from_checkpoint", None)
        if path is not None:
            is_zoo = self.is_zoo_path(path)
            ckpt_filepath = path
            if is_zoo:
                folder = download_pretrained_model(path)
                ckpt_filepath = get_ckpt_path_from_folder(folder)
                ckpt = get_ckpt_from_path(ckpt_filepath)
                config = get_config_from_folder_or_ckpt(folder, ckpt)
            else:
                ckpt = get_ckpt_from_path(ckpt_filepath)
                config = None

            return {
                "ckpt": ckpt,
                "checkpoint_path": ckpt_filepath,
                "is_zoo": is_zoo,
                "config": config,
            }

        is_zoo = False
        config = None
        ckpt = None
        # get ckpt file path from config.checkpoint
        ckpt_config = self.config.checkpoint
        suffix = "best.ckpt" if ckpt_config.resume_best else "current.ckpt"
        path = os.path.join(get_mmf_env(key="save_dir"), suffix)
        ckpt_filepath = None

        # Evaluate "resume from save_dir" once: it decides both branches below
        # and makes them mutually exclusive (no repeated filesystem check).
        resume_existing = ckpt_config.resume and PathManager.exists(path)
        resume_from_specified_path = (
            ckpt_config.resume_file is not None
            or ckpt_config.resume_zoo is not None
        ) and not resume_existing
        if resume_from_specified_path:
            if ckpt_config.resume_file and PathManager.exists(
                    ckpt_config.resume_file):
                ckpt_filepath = ckpt_config.resume_file
            elif ckpt_config.resume_zoo is not None:
                is_zoo = True
                folder = download_pretrained_model(ckpt_config.resume_zoo)
                ckpt_filepath = get_ckpt_path_from_folder(folder)
                ckpt = get_ckpt_from_path(ckpt_filepath)
                config = get_config_from_folder_or_ckpt(folder, ckpt)
            else:
                raise RuntimeError(f"{ckpt_config.resume_file} doesn't exist")

        if resume_existing:
            ckpt_filepath = path

        # The zoo branch has already loaded its checkpoint; only load here when
        # nothing is in memory yet instead of reading the same file twice.
        if ckpt is None and ckpt_filepath is not None:
            ckpt = get_ckpt_from_path(ckpt_filepath)

        return {
            "ckpt": ckpt,
            "checkpoint_path": ckpt_filepath,
            "is_zoo": is_zoo,
            "config": config,
        }
コード例 #28
0
ファイル: logger.py プロジェクト: zeta1999/mmf
    def __init__(self, config, name=None):
        """Set up file and stdout logging.

        The log file is ``<save_dir>/logs/train_<timestamp>.log`` unless a
        truthy ``log_dir`` env value replaces the folder. Non-master processes
        only compute the paths and return early, without attaching handlers
        (``should_log`` and ``_single_log_map`` are then not set).

        Args:
            config: full config; ``training.log_format``,
                ``training.logger_level`` and ``training.should_not_log`` are
                read.
            name (str, optional): logger name; defaults to this module's name.
        """
        self.logger = None
        self._is_master = is_master()

        self.timer = Timer()
        self.config = config
        self.save_dir = get_mmf_env(key="save_dir")
        self.log_format = config.training.log_format
        self.time_format = "%Y-%m-%dT%H:%M:%S"
        # The file name uses a colon-free timestamp: ':' is not a valid file
        # name character on Windows. self.time_format is left unchanged.
        file_time_format = "%Y_%m_%dT%H_%M_%S"
        self.log_filename = "train_"
        self.log_filename += self.timer.get_time_hhmmss(
            None, format=file_time_format)
        self.log_filename += ".log"

        self.log_folder = os.path.join(self.save_dir, "logs")

        env_log_dir = get_mmf_env(key="log_dir")
        if env_log_dir:
            self.log_folder = env_log_dir

        if not PathManager.exists(self.log_folder):
            PathManager.mkdirs(self.log_folder)

        self.log_filename = os.path.join(self.log_folder, self.log_filename)

        # Only the master process attaches handlers.
        if not self._is_master:
            return
        print("Logging to:", self.log_filename)

        logging.captureWarnings(True)

        if not name:
            name = __name__
        self.logger = logging.getLogger(name)
        self._file_only_logger = logging.getLogger(name)
        warnings_logger = logging.getLogger("py.warnings")

        # Set level
        level = config.training.logger_level
        self.logger.setLevel(getattr(logging, level.upper()))
        self._file_only_logger.setLevel(getattr(logging, level.upper()))

        formatter = logging.Formatter(
            "%(asctime)s %(levelname)s: %(message)s", datefmt="%Y-%m-%dT%H:%M:%S"
        )

        # Add handler to file
        channel = logging.FileHandler(filename=self.log_filename, mode="a")
        channel.setFormatter(formatter)

        self.logger.addHandler(channel)
        self._file_only_logger.addHandler(channel)
        warnings_logger.addHandler(channel)

        # Add handler to stdout
        channel = logging.StreamHandler(sys.stdout)
        channel.setFormatter(formatter)

        self.logger.addHandler(channel)
        warnings_logger.addHandler(channel)

        should_not_log = self.config.training.should_not_log
        self.should_log = not should_not_log

        # Single log wrapper map
        self._single_log_map = set()
コード例 #29
0
    def __init__(self, config, name=None):
        """Configure logging to a timestamped file, ``train.log`` and stdout.

        The timestamped file lives in ``<save_dir>/logs`` unless a truthy
        ``log_dir`` env value replaces the folder. Non-master processes stop
        right after computing the file path, without attaching handlers. With
        ``training.stdout_capture`` enabled, ``sys.stdout`` is replaced by a
        ``StreamToLogger`` and no stdout handler is added.

        Args:
            config: full config; ``training.log_format``,
                ``training.logger_level``, ``training.stdout_capture`` and
                ``training.should_not_log`` are read.
            name (str, optional): logger name; defaults to this module's name.
        """
        self._logger = None
        self._is_master = is_master()

        self.timer = Timer()
        self.config = config
        self.save_dir = get_mmf_env(key="save_dir")
        self.log_format = config.training.log_format
        self.time_format = "%Y_%m_%dT%H_%M_%S"
        stamp = self.timer.get_time_hhmmss(None, format=self.time_format)
        self.log_filename = "train_" + stamp + ".log"

        override_dir = get_mmf_env(key="log_dir")
        if override_dir:
            self.log_folder = override_dir
        else:
            self.log_folder = os.path.join(self.save_dir, "logs")

        if not PathManager.exists(self.log_folder):
            PathManager.mkdirs(self.log_folder)

        self.log_filename = os.path.join(self.log_folder, self.log_filename)

        # Handlers are attached on the master process only.
        if not self._is_master:
            return
        print("Logging to:", self.log_filename)

        logging.captureWarnings(True)

        logger_name = name if name else __name__
        self._logger = logging.getLogger(logger_name)
        self._file_only_logger = logging.getLogger(logger_name)
        self._warnings_logger = logging.getLogger("py.warnings")

        numeric_level = getattr(logging, config.training.logger_level.upper())
        self._logger.setLevel(numeric_level)
        self._file_only_logger.setLevel(numeric_level)

        # Optionally route everything printed to stdout into the logger.
        self._stdout_logger = None
        if self.config.training.stdout_capture:
            self._stdout_logger = StreamToLogger(
                logging.getLogger("stdout"), numeric_level)
            sys.stdout = self._stdout_logger

        formatter = logging.Formatter(
            "%(asctime)s | %(levelname)s | %(name)s : %(message)s",
            datefmt="%Y-%m-%dT%H:%M:%S",
        )

        # The timestamped file, then train.log: the full log that is also used
        # by slurm/fbl output.
        for log_path in (self.log_filename,
                         os.path.join(self.save_dir, "train.log")):
            file_channel = logging.FileHandler(filename=log_path, mode="a")
            file_channel.setFormatter(formatter)
            self.add_handlers(file_channel)

        # Echo to stdout only when stdout is not already captured by the logger.
        if not self._stdout_logger:
            console_channel = logging.StreamHandler(sys.stdout)
            console_channel.setFormatter(formatter)
            self._logger.addHandler(console_channel)
            self._warnings_logger.addHandler(console_channel)

        self.should_log = not self.config.training.should_not_log

        # Single log wrapper map
        self._single_log_map = set()
コード例 #30
0
def setup_logger(
    output: str = None,
    color: bool = True,
    name: str = "mmf",
    disable: bool = False,
    clear_handlers=True,
    *args,
    **kwargs,
):
    """
    Initialize the MMF logger and set its verbosity level from the
    registered config's ``training.logger_level``. Falls back to "INFO"
    when no config is registered or the key is missing.

    Outside libraries shouldn't call this in case they have set their
    own logging handlers and setup. If they do, and don't want to
    clear handlers, pass ``clear_handlers=False``.

    The initial version of this function was taken from D2 and adapted
    for MMF.

    Args:
        output (str): a file name or a directory to save log.
            If ends with ".txt" or ".log", assumed to be a file name;
            otherwise "train.log" inside that directory is used.
            Default: the path returned by ``setup_output_folder()``.
        color (bool): If false, won't log colored logs. Default: true
        name (str): the root module name of this logger. Defaults to "mmf".
        disable (bool): If true, do nothing and return None.
        clear_handlers (bool): If false, won't remove the handlers already
            installed on the root logger.
        *args, **kwargs: Accepted for call-site compatibility; unused.

    Returns:
        logging.Logger: a logger, or None when ``disable`` is True.
    """
    if disable:
        return None
    logger = logging.getLogger(name)
    logger.propagate = False

    # Route ``warnings.warn`` output through the "py.warnings" logger so it
    # reaches the same handlers as the MMF logger.
    logging.captureWarnings(True)
    warnings_logger = logging.getLogger("py.warnings")

    plain_formatter = logging.Formatter(
        "%(asctime)s | %(levelname)s | %(name)s : %(message)s",
        datefmt="%Y-%m-%dT%H:%M:%S",
    )

    distributed_rank = get_rank()
    handlers = []

    # Tolerate a missing config (or missing training.logger_level) instead
    # of raising AttributeError on None; default to INFO.
    training_config = getattr(registry.get("config"), "training", None)
    logging_level = str(
        getattr(training_config, "logger_level", None) or "info"
    ).upper()

    # Set the level on every rank: non-master ranks also write to a file
    # below, and without an explicit level they inherit the root logger's
    # level, which basicConfig does not change when the root logger still
    # has handlers (clear_handlers=False).
    logger.setLevel(logging_level)

    # Console output: master process only.
    if distributed_rank == 0:
        ch = logging.StreamHandler(stream=sys.stdout)
        ch.setLevel(logging_level)
        if color:
            formatter = ColorfulFormatter(
                colored("%(asctime)s | %(name)s: ", "green") + "%(message)s",
                datefmt="%Y-%m-%dT%H:%M:%S",
            )
        else:
            formatter = plain_formatter
        ch.setFormatter(formatter)
        logger.addHandler(ch)
        warnings_logger.addHandler(ch)
        handlers.append(ch)

    # file logging: all workers
    if output is None:
        output = setup_output_folder()

    if output is not None:
        if output.endswith((".txt", ".log")):
            log_filename = output
        else:
            log_filename = os.path.join(output, "train.log")
        if distributed_rank > 0:
            # Give each non-master rank its own file.
            log_filename = log_filename + f".rank{distributed_rank}"
        log_dir = os.path.dirname(log_filename)
        # A bare file name (e.g. "log.txt") has no directory component, and
        # creating the directory "" would fail.
        if log_dir:
            PathManager.mkdirs(log_dir)

        fh = logging.StreamHandler(_cached_log_stream(log_filename))
        fh.setLevel(logging_level)
        fh.setFormatter(plain_formatter)
        logger.addHandler(fh)
        warnings_logger.addHandler(fh)
        handlers.append(fh)

        # Slurm/FB output, only log the main process. Skipped when the
        # primary file name already contains "train.log" (substring match).
        if "train.log" not in log_filename and distributed_rank == 0:
            save_dir = get_mmf_env(key="save_dir")
            train_log_filename = os.path.join(save_dir, "train.log")
            sh = logging.StreamHandler(_cached_log_stream(train_log_filename))
            sh.setLevel(logging_level)
            sh.setFormatter(plain_formatter)
            logger.addHandler(sh)
            warnings_logger.addHandler(sh)
            handlers.append(sh)

        # Report the primary log file, not the auxiliary train.log copy.
        logger.info(f"Logging to: {log_filename}")

    # Remove existing root handlers to add MMF specific handlers
    if clear_handlers:
        for handler in logging.root.handlers[:]:
            logging.root.removeHandler(handler)
    # Now, add our handlers to the root logger. basicConfig is a no-op if the
    # root logger still has handlers (i.e. clear_handlers=False).
    logging.basicConfig(level=logging_level, handlers=handlers)

    registry.register("writer", logger)

    return logger