Code example #1
File: build.py Project: xinyuliu828/mmf
def build_config(configuration: Configuration, *args, **kwargs) -> DictConfig:
    """Builder function for config. Freezes the configuration and registers
    configuration object and config DictConfig object to registry.

    Args:
        configuration (Configuration): Configuration object that will be
            used to create the config.

    Returns:
        (DictConfig): A config which is of type omegaconf.DictConfig
    """
    configuration.freeze()
    config = configuration.get_config()
    registry.register("config", config)
    registry.register("configuration", configuration)

    return config
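
For reference, here is a minimal sketch of how the objects registered above can be read back elsewhere; it assumes only the registry.get API that several later examples in this list use:

# Minimal usage sketch: read back what build_config() registered.
from mmf.common.registry import registry

config = registry.get("config")                # the frozen DictConfig
configuration = registry.get("configuration")  # the Configuration object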
Code example #2
    def setUp(self):
        self.tmpdir = tempfile.mkdtemp()
        self.trainer = argparse.Namespace()
        self.config = load_yaml(os.path.join("configs", "defaults.yaml"))
        self.config = OmegaConf.merge(
            self.config,
            {
                "model": "simple",
                "model_config": {},
                "training": {
                    "checkpoint_interval": 1,
                    "evaluation_interval": 10,
                    "early_stop": {
                        "criteria": "val/total_loss"
                    },
                    "batch_size": 16,
                    "log_interval": 10,
                    "logger_level": "info",
                },
                "env": {
                    "save_dir": self.tmpdir
                },
            },
        )
        # Keep original copy for testing purposes
        self.trainer.config = deepcopy(self.config)
        registry.register("config", self.trainer.config)
        setup_logger()
        self.report = Mock(spec=Report)
        self.report.dataset_name = "abcd"
        self.report.dataset_type = "test"

        self.trainer.model = SimpleModule()
        self.trainer.val_loader = torch.utils.data.DataLoader(
            NumbersDataset(), batch_size=self.config.training.batch_size)

        self.trainer.optimizer = torch.optim.Adam(
            self.trainer.model.parameters(), lr=1e-01)
        self.trainer.device = "cpu"
        self.trainer.num_updates = 0
        self.trainer.current_iteration = 0
        self.trainer.current_epoch = 0
        self.trainer.max_updates = 0
        self.trainer.meter = Meter()
        self.cb = LogisticsCallback(self.config, self.trainer)
Code example #3
    def __init__(self,
                 config,
                 max_steps,
                 max_epochs=None,
                 callback=None,
                 num_data_size=100,
                 batch_size=1,
                 accumulate_grad_batches=1,
                 lr_scheduler=False,
                 gradient_clip_val=0.0,
                 precision=32,
                 **kwargs):
        self.config = config
        self._callbacks = None
        if callback:
            self._callbacks = [callback]

        # settings
        trainer_config = self.config.trainer.params
        trainer_config.accumulate_grad_batches = accumulate_grad_batches
        trainer_config.precision = precision
        trainer_config.max_steps = max_steps
        trainer_config.max_epochs = max_epochs
        trainer_config.gradient_clip_val = gradient_clip_val

        for key, value in kwargs.items():
            trainer_config[key] = value

        self.trainer_config = trainer_config
        self.training_config = self.config.training
        self.training_config.batch_size = batch_size
        self.run_type = self.config.get("run_type", "train")
        self.config.training.lr_scheduler = lr_scheduler
        registry.register("config", self.config)

        self.data_module = MultiDataModuleNumbersTestObject(
            num_data=num_data_size, batch_size=batch_size)
        self.train_loader = self.data_module.train_dataloader()
        self.val_loader = self.data_module.val_dataloader()
        self.test_loader = self.data_module.test_dataloader()
Code example #4
    def configure_device(self) -> None:
        if self.config.training.get("device", "cuda") == "xla":
            import torch_xla.core.xla_model as xm

            self.device = xm.xla_device()
            self.distributed = True
            self.local_rank = xm.get_local_ordinal()
            is_xla = True
        else:
            is_xla = False
            if "device_id" not in self.config:
                warnings.warn(
                    "No 'device_id' in 'config', setting to -1. "
                    "This can cause issues later in training. Ensure that "
                    "distributed setup is properly initialized.")
                self.local_rank = -1
            else:
                self.local_rank = self.config.device_id
            self.device = self.local_rank
            self.distributed = False

        # Will be updated later based on distributed setup
        registry.register("global_device", self.device)

        if self.config.distributed.init_method is not None:
            self.distributed = True
            self.device = torch.device("cuda", self.local_rank)
            torch.cuda.set_device(self.local_rank)
        elif torch.cuda.is_available():
            self.device = torch.device("cuda")
            torch.cuda.set_device(0)
        elif not is_xla:
            self.device = torch.device("cpu")

        if "rank" not in self.config.distributed:
            if (torch.distributed.is_available()
                    and torch.distributed.is_initialized()):
                global_rank = torch.distributed.get_rank()
            else:
                global_rank = -1
            with open_dict(self.config.distributed):
                self.config.distributed.rank = global_rank

        registry.register("global_device", self.config.distributed.rank)
Code example #5
    def configure_device(self) -> None:
        self.local_rank = self.config.device_id
        self.device = self.local_rank
        self.distributed = False

        # Will be updated later based on distributed setup
        registry.register("global_device", self.device)

        if self.config.distributed.init_method is not None:
            self.distributed = True
            self.device = torch.device("cuda", self.local_rank)
            torch.cuda.set_device(self.local_rank)
        elif torch.cuda.is_available():
            self.device = torch.device("cuda")
            torch.cuda.set_device(0)
        else:
            self.device = torch.device("cpu")

        registry.register("global_device", self.config.distributed.rank)
Code example #6
    def setUp(self):
        self.trainer = argparse.Namespace()
        self.config = load_yaml(os.path.join("configs", "defaults.yaml"))
        self.config = OmegaConf.merge(
            self.config,
            {
                "model": "simple",
                "model_config": {},
                "training": {
                    "lr_scheduler": True,
                    "lr_ratio": 0.1,
                    "lr_steps": [1, 2],
                    "use_warmup": False,
                    "callbacks": [{
                        "type": "test_callback",
                        "params": {}
                    }],
                },
            },
        )
        # Keep original copy for testing purposes
        self.trainer.config = deepcopy(self.config)
        registry.register("config", self.trainer.config)

        model = SimpleModel(SimpleModel.Config())
        model.build()
        self.trainer.model = model
        self.trainer.val_loader = torch.utils.data.DataLoader(
            NumbersDataset(2), batch_size=self.config.training.batch_size)

        self.trainer.optimizer = torch.optim.Adam(
            self.trainer.model.parameters(), lr=1e-01)
        self.trainer.lr_scheduler_callback = LRSchedulerCallback(
            self.config, self.trainer)

        self.trainer.callbacks = []
        for callback in self.config.training.get("callbacks", []):
            callback_type = callback.type
            callback_param = callback.params
            callback_cls = registry.get_callback_class(callback_type)
            self.trainer.callbacks.append(
                callback_cls(self.trainer.config, self.trainer,
                             **callback_param))
Code example #7
    def parallelize_model(self):
        training = self.config.training
        if (
            "cuda" in str(self.device)
            and torch.cuda.device_count() > 1
            and not self.distributed
        ):
            registry.register("data_parallel", True)
            self.model = torch.nn.DataParallel(self.model)

        if "cuda" in str(self.device) and self.distributed:
            registry.register("distributed", True)
            self.model = torch.nn.parallel.DistributedDataParallel(
                self.model,
                device_ids=[self.local_rank],
                output_device=self.local_rank,
                check_reduction=True,
                find_unused_parameters=training.find_unused_parameters,
            )
Code example #8
    def setUp(self):
        setup_imports()
        torch.manual_seed(1234)
        config_path = os.path.join(
            get_mmf_root(),
            "..",
            "projects",
            "butd",
            "configs",
            "coco",
            "beam_search.yaml",
        )
        config_path = os.path.abspath(config_path)
        args = dummy_args(model="butd", dataset="coco")
        args.opts.append(f"config={config_path}")
        configuration = Configuration(args)
        configuration.config.datasets = "coco"
        configuration.freeze()
        self.config = configuration.config
        registry.register("config", self.config)
Code example #9
    def _init_processors(self):
        args = Namespace()
        args.opts = [
            "config=projects/pythia/configs/vqa2/defaults.yaml",
            "datasets=vqa2", "model=visual_bert", "evaluation.predict=True"
        ]
        args.config_override = None

        configuration = Configuration(args=args)

        config = self.config = configuration.config
        vqa2_config = config.dataset_config.vqa2
        text_processor_config = vqa2_config.processors.text_processor

        text_processor_config.params.vocab.vocab_file = "../model_data/vocabulary_100k.txt"

        # Add preprocessor as that will be needed when we get questions from the user
        self.text_processor = VocabProcessor(text_processor_config.params)

        registry.register("coco_text_processor", self.text_processor)
Code example #10
    def setUp(self):
        setup_imports()
        torch.manual_seed(1234)
        config_path = os.path.join(
            get_mmf_root(),
            "..",
            "projects",
            "butd",
            "configs",
            "coco",
            "nucleus_sampling.yaml",
        )
        config_path = os.path.abspath(config_path)
        args = dummy_args(model="butd", dataset="coco")
        args.opts.append(f"config={config_path}")
        configuration = Configuration(args)
        configuration.config.datasets = "coco"
        configuration.config.model_config.butd.inference.params.sum_threshold = 0.5
        configuration.freeze()
        self.config = configuration.config
        registry.register("config", self.config)
Code example #11
def ready_trainer(trainer):
    from mmf.common.registry import registry
    from mmf.utils.logger import Logger, TensorboardLogger
    trainer.run_type = trainer.config.get("run_type", "train")
    writer = registry.get("writer", no_warning=True)
    if writer:
        trainer.writer = writer
    else:
        trainer.writer = Logger(trainer.config)
        registry.register("writer", trainer.writer)

    trainer.configure_device()
    trainer.configure_seed()
    trainer.load_model()
    from mmf.trainers.callbacks.checkpoint import CheckpointCallback
    from mmf.trainers.callbacks.early_stopping import EarlyStoppingCallback
    trainer.checkpoint_callback = CheckpointCallback(trainer.config, trainer)
    trainer.early_stop_callback = EarlyStoppingCallback(
        trainer.config, trainer)
    trainer.callbacks.append(trainer.checkpoint_callback)
    trainer.on_init_start()
Code example #12
File: base_trainer.py Project: zeta1999/mmf
    def load_model_and_optimizer(self):
        attributes = self.config.model_config[self.config.model]
        # Easy way to point to config for other model
        if isinstance(attributes, str):
            attributes = self.config.model_config[attributes]

        with omegaconf.open_dict(attributes):
            attributes.model = self.config.model

        self.model = build_model(attributes)

        if "cuda" in str(self.device):
            device_info = "CUDA Device {} is: {}".format(
                self.config.distributed.rank,
                torch.cuda.get_device_name(self.local_rank),
            )
            registry.register("global_device", self.config.distributed.rank)
            self.writer.write(device_info, log_all=True)

        self.model = self.model.to(self.device)
        self.optimizer = build_optimizer(self.model, self.config)

        registry.register("data_parallel", False)
        registry.register("distributed", False)

        self.load_extras()
        self.parallelize_model()
Code example #13
File: device.py Project: EXYNOS-999/DeepMeMes
    def configure_device(self) -> None:
        self.local_rank = self.config.device_id
        self.device = self.local_rank
        self.distributed = False

        # Will be updated later based on distributed setup
        registry.register("global_device", self.device)

        if self.config.distributed.init_method is not None:
            self.distributed = True
            self.device = torch.device("cuda", self.local_rank)
        elif torch.cuda.is_available():
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")

        registry.register("current_device", self.device)

        if "cuda" in str(self.device):
            device_info = "CUDA Device {} is: {}".format(
                self.config.distributed.rank,
                torch.cuda.get_device_name(self.local_rank),
            )
            registry.register("global_device", self.config.distributed.rank)
            self.writer.write(device_info, log_all=True)
Code example #14
    def _load_counts(self, ckpt):
        ckpt_config = self.trainer.config.checkpoint
        if "best_update" in ckpt:
            if ckpt_config.resume_best:
                self.trainer.num_updates = ckpt.get("best_update",
                                                    self.trainer.num_updates)
                self.trainer.current_iteration = ckpt.get(
                    "best_iteration", self.trainer.current_iteration)
            else:
                self.trainer.num_updates = ckpt.get("num_updates",
                                                    self.trainer.num_updates)
                self.trainer.current_iteration = ckpt.get(
                    "current_iteration", self.trainer.current_iteration)

            self.trainer.current_epoch = ckpt.get("current_epoch",
                                                  self.trainer.current_epoch)
        elif "best_iteration" in ckpt:
            # Preserve old behavior for old checkpoints where we always
            # load best iteration
            if ckpt_config.resume_best and "current_iteration" in ckpt:
                self.trainer.current_iteration = ckpt["current_iteration"]
            else:
                self.trainer.current_iteration = ckpt.get(
                    "best_iteration", self.trainer.current_iteration)

            self.trainer.num_updates = self.trainer.current_iteration

        registry.register("current_iteration", self.trainer.current_iteration)
        registry.register("num_updates", self.trainer.num_updates)

        self.trainer.current_epoch = ckpt.get("best_epoch",
                                              self.trainer.current_epoch)
        registry.register("current_epoch", self.trainer.current_epoch)
Code example #15
File: test_bert.py Project: MetaVQA/MetaVQA
    def _init_processors(self):
        args = Namespace()
        args.opts = [
            "config=projects/visual_bert/configs/vqa2/defaults.yaml",
            "datasets=vqa2",
            "model=visual_bert",
            "evaluation.predict=True"
        ]
        args.config_override = None

        configuration = Configuration(args=args)

        config = self.config = configuration.config
        vqa_config = config.dataset_config.vqa2
        text_processor_config = vqa_config.processors.text_processor
        answer_processor_config = vqa_config.processors.answer_processor

        text_processor_config.params.vocab.vocab_file = self.root + "/content/model_data/vocabulary_100k.txt"
        answer_processor_config.params.vocab_file = self.root + "/content/model_data/answers_vqa.txt"
        # Add preprocessor as that will be needed when we get questions from the user
        self.text_processor = BertTokenizer(text_processor_config.params)
        self.answer_processor = VQAAnswerProcessor(answer_processor_config.params)

        registry.register("vqa2_text_processor", self.text_processor)
        registry.register("vqa2_answer_processor", self.answer_processor)
        registry.register("vqa2_num_final_outputs", 
                          self.answer_processor.get_vocab_size())
Code example #16
File: base_trainer.py Project: zeta1999/mmf
    def load(self):
        self._set_device()

        self.run_type = self.config.get("run_type", "train")
        self.dataset_loader = DatasetLoader(self.config)
        self._datasets = self.config.datasets

        # Check if loader is already defined, else init it
        writer = registry.get("writer", no_warning=True)
        if writer:
            self.writer = writer
        else:
            self.writer = Logger(self.config)
            registry.register("writer", self.writer)

        self.configuration.pretty_print()

        self.config_based_setup()

        self.load_datasets()
        self.load_model_and_optimizer()
        self.load_metrics()
Code example #17
def main(configuration, init_distributed=False):
    # A reload might be needed for imports
    setup_imports()
    configuration.import_user_dir()
    config = configuration.get_config()

    if torch.cuda.is_available():
        torch.cuda.set_device(config.device_id)
        torch.cuda.init()

    if init_distributed:
        distributed_init(config)

    config.training.seed = set_seed(config.training.seed)
    registry.register("seed", config.training.seed)
    print("Using seed {}".format(config.training.seed))

    registry.register("writer", Logger(config, name="mmf.train"))

    trainer = build_trainer(configuration)
    trainer.load()
    trainer.train()
Code example #18
def setup_imports():
    from mmf.common.registry import registry

    # First, check if imports are already setup
    has_already_setup = registry.get("imports_setup", no_warning=True)
    if has_already_setup:
        return
    # Automatically load all of the modules, so that
    # they register with registry
    root_folder = registry.get("mmf_root", no_warning=True)

    if root_folder is None:
        root_folder = os.path.dirname(os.path.abspath(__file__))
        root_folder = os.path.join(root_folder, "..")

        environment_mmf_path = os.environ.get("MMF_PATH",
                                              os.environ.get("PYTHIA_PATH"))

        if environment_mmf_path is not None:
            root_folder = environment_mmf_path

        registry.register("pythia_path", root_folder)
        registry.register("mmf_path", root_folder)

    trainer_folder = os.path.join(root_folder, "trainers")
    trainer_pattern = os.path.join(trainer_folder, "**", "*.py")
    datasets_folder = os.path.join(root_folder, "datasets")
    datasets_pattern = os.path.join(datasets_folder, "**", "*.py")
    model_folder = os.path.join(root_folder, "models")
    common_folder = os.path.join(root_folder, "common")
    modules_folder = os.path.join(root_folder, "modules")
    model_pattern = os.path.join(model_folder, "**", "*.py")
    common_pattern = os.path.join(common_folder, "**", "*.py")
    modules_pattern = os.path.join(modules_folder, "**", "*.py")

    importlib.import_module("mmf.common.meter")

    files = (glob.glob(datasets_pattern, recursive=True) +
             glob.glob(model_pattern, recursive=True) +
             glob.glob(trainer_pattern, recursive=True) +
             glob.glob(common_pattern, recursive=True) +
             glob.glob(modules_pattern, recursive=True))

    for f in files:
        f = os.path.realpath(f)
        if f.endswith(".py") and not f.endswith("__init__.py"):
            splits = f.split(os.sep)
            import_prefix_index = 0
            for idx, split in enumerate(splits):
                if split == "mmf":
                    import_prefix_index = idx + 1
            file_name = splits[-1]
            module_name = file_name[:file_name.find(".py")]
            module = ".".join(["mmf"] + splits[import_prefix_index:-1] +
                              [module_name])
            importlib.import_module(module)

    registry.register("imports_setup", True)
Code example #19
File: test_metrics.py Project: zhang703652632/mmf
    def test_caption_bleu4(self):
        path = os.path.join(
            os.path.abspath(__file__),
            "../../../mmf/configs/datasets/coco/defaults.yaml",
        )
        config = load_yaml(os.path.abspath(path))
        captioning_config = config.dataset_config.coco
        caption_processor_config = captioning_config.processors.caption_processor
        vocab_path = os.path.join(os.path.abspath(__file__), "..", "..",
                                  "data", "vocab.txt")
        caption_processor_config.params.vocab.type = "random"
        caption_processor_config.params.vocab.vocab_file = os.path.abspath(
            vocab_path)
        caption_processor = CaptionProcessor(caption_processor_config.params)
        registry.register("coco_caption_processor", caption_processor)

        caption_bleu4 = metrics.CaptionBleu4Metric()
        expected = Sample()
        predicted = dict()

        # Test complete match
        expected.answers = torch.empty((5, 5, 10))
        expected.answers.fill_(4)
        predicted["scores"] = torch.zeros((5, 10, 19))
        predicted["scores"][:, :, 4] = 1.0

        self.assertEqual(
            caption_bleu4.calculate(expected, predicted).item(), 1.0)

        # Test partial match
        expected.answers = torch.empty((5, 5, 10))
        expected.answers.fill_(4)
        predicted["scores"] = torch.zeros((5, 10, 19))
        predicted["scores"][:, 0:5, 4] = 1.0
        predicted["scores"][:, 5:, 18] = 1.0

        self.assertAlmostEqual(
            caption_bleu4.calculate(expected, predicted).item(), 0.3928, 4)
Code example #20
    def setUp(self):
        self.trainer = argparse.Namespace()
        self.config = OmegaConf.create({
            "model": "simple",
            "model_config": {},
            "training": {
                "lr_scheduler": True,
                "lr_ratio": 0.1,
                "lr_steps": [1, 2],
                "use_warmup": False,
            },
        })
        # Keep original copy for testing purposes
        self.trainer.config = deepcopy(self.config)
        registry.register("config", self.trainer.config)

        self.trainer.model = SimpleModule()
        self.trainer.val_dataset = NumbersDataset()

        self.trainer.optimizer = torch.optim.Adam(
            self.trainer.model.parameters(), lr=1e-01)
        self.trainer.lr_scheduler_callback = LRSchedulerCallback(
            self.config, self.trainer)
Code example #21
File: meter.py Project: SunYanCN/pythia
    def update_from_report(self, report, should_update_loss=True):
        """
        this method updates the provided meter with report info.
        this method by default handles reducing metrics.

        Args:
            report (Report): report object which content is used to populate
            the current meter

        Usage::

        >>> meter = Meter()
        >>> report = Report(prepared_batch, model_output)
        >>> meter.update_from_report(report)
        """
        if hasattr(report, "metrics"):
            metrics_dict = report.metrics
            reduced_metrics_dict = reduce_dict(metrics_dict)

        if should_update_loss:
            loss_dict = report.losses
            reduced_loss_dict = reduce_dict(loss_dict)

        with torch.no_grad():
            meter_update_dict = {}
            if should_update_loss:
                meter_update_dict = scalarize_dict_values(reduced_loss_dict)
                total_loss_key = report.dataset_type + "/total_loss"
                total_loss = sum(meter_update_dict.values())
                registry.register(total_loss_key, total_loss)
                meter_update_dict.update({total_loss_key: total_loss})

            if hasattr(report, "metrics"):
                metrics_dict = scalarize_dict_values(reduced_metrics_dict)
                meter_update_dict.update(**metrics_dict)

            self._update(meter_update_dict, report.batch_size)
Code example #22
    def update_registry_for_model(self, config):
        if hasattr(self.dataset, "text_processor"):
            registry.register(
                self.dataset_name + "_text_vocab_size",
                self.dataset.text_processor.get_vocab_size(),
            )
            registry.register(
                f"{self.dataset_name}_text_processor", self.dataset.text_processor
            )
        if hasattr(self.dataset, "answer_processor"):
            registry.register(
                self.dataset_name + "_num_final_outputs",
                self.dataset.answer_processor.get_vocab_size(),
            )
            registry.register(
                f"{self.dataset_name}_answer_processor", self.dataset.answer_processor
            )
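
A short usage sketch follows; the dataset name "vqa2" below is a hypothetical placeholder (the source registers under whatever self.dataset_name is), showing how a model built after this registration could read the values back:

# Usage sketch: read back the per-dataset values registered above.
# NOTE: "vqa2" is a hypothetical dataset_name, used only for illustration.
vocab_size = registry.get("vqa2_text_vocab_size")
num_outputs = registry.get("vqa2_num_final_outputs")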
Code example #23
    def _load_counts_and_lr_scheduler(self, ckpt):
        ckpt_config = self.trainer.config.checkpoint
        if "best_update" in ckpt:
            if ckpt_config.resume_best:
                self.trainer.num_updates = ckpt.get(
                    "best_update", self.trainer.num_updates
                )
                self.trainer.current_iteration = ckpt.get(
                    "best_iteration", self.trainer.current_iteration
                )
            else:
                self.trainer.num_updates = ckpt.get(
                    "num_updates", self.trainer.num_updates
                )
                self.trainer.current_iteration = ckpt.get(
                    "current_iteration", self.trainer.current_iteration
                )

            self.trainer.current_epoch = ckpt.get(
                "current_epoch", self.trainer.current_epoch
            )
        elif "best_iteration" in ckpt:
            # Preserve old behavior for old checkpoints where we always
            # load best iteration
            if ckpt_config.resume_best and "current_iteration" in ckpt:
                self.trainer.current_iteration = ckpt["current_iteration"]
            else:
                self.trainer.current_iteration = ckpt.get(
                    "best_iteration", self.trainer.current_iteration
                )

            self.trainer.num_updates = self.trainer.current_iteration

        lr_scheduler = self.trainer.lr_scheduler_callback

        if (
            lr_scheduler is not None
            and getattr(lr_scheduler, "_scheduler", None) is not None
        ):
            lr_scheduler = lr_scheduler._scheduler

            if "lr_scheduler" in ckpt:
                lr_scheduler.load_state_dict(ckpt["lr_scheduler"])
            else:
                warnings.warn(
                    "'lr_scheduler' key is not present in the "
                    "checkpoint asked to be loaded. Setting lr_scheduler's "
                    "last_epoch to current_iteration."
                )
                lr_scheduler.last_epoch = self.trainer.current_iteration

        registry.register("current_iteration", self.trainer.current_iteration)
        registry.register("num_updates", self.trainer.num_updates)

        self.trainer.current_epoch = ckpt.get("best_epoch", self.trainer.current_epoch)
        registry.register("current_epoch", self.trainer.current_epoch)
Code example #24
    def __init__(self, config, num_data_size=100, **kwargs):
        super().__init__(config)

        self.config = config
        self.callbacks = []

        # settings
        trainer_config = self.config.trainer.params
        self.trainer_config = trainer_config
        self.training_config = self.config.training

        for key, value in kwargs.items():
            trainer_config[key] = value

        # data
        self.data_module = MultiDataModuleNumbersTestObject(
            config=config, num_data=num_data_size)

        self.run_type = self.config.get("run_type", "train")
        registry.register("config", self.config)

        self.train_loader = self.data_module.train_dataloader()
        self.val_loader = self.data_module.val_dataloader()
        self.test_loader = self.data_module.test_dataloader()
Code example #25
File: configuration.py Project: Mokashaa/mmf
    def __init__(self, args=None, default_only=False):
        self.config = {}

        if not args:
            import argparse

            args = argparse.Namespace(opts=[])
            default_only = True

        self.args = args
        self._register_resolvers()

        self._default_config = self._build_default_config()

        if default_only:
            other_configs = {}
        else:
            other_configs = self._build_other_configs()

        self.config = OmegaConf.merge(self._default_config, other_configs)

        self.config = self._merge_with_dotlist(self.config, args.opts)
        self._update_specific(self.config)
        registry.register("config", self.config)
Code example #26
    def test_batch_size_per_device(self, a):
        # Need to patch mmf.utils.general's get_world_size, not
        # mmf.utils.distributed's, as the former is the one that will be used
        with patch("mmf.utils.general.get_world_size", return_value=2):
            config = self._get_config(max_updates=2,
                                      max_epochs=None,
                                      batch_size=4)
            trainer = TrainerTrainingLoopMock(config=config)
            add_model(trainer, SimpleModel({"in_dim": 1}))
            add_optimizer(trainer, config)
            registry.register("config", trainer.config)
            batch_size = get_batch_size()
            trainer.config.training.batch_size = batch_size
            trainer.load_datasets()
            # Train loader has batch size per device; for global batch size 4
            # with world size 2, batch size per device should be 4 // 2 = 2
            self.assertEqual(trainer.train_loader.current_loader.batch_size, 2)
            # This is per device, so it should stay the same
            config = self._get_config(max_updates=2,
                                      max_epochs=None,
                                      batch_size_per_device=4)
            trainer = TrainerTrainingLoopMock(config=config)
            add_model(trainer, SimpleModel({"in_dim": 1}))
            add_optimizer(trainer, config)
            registry.register("config", trainer.config)
            batch_size = get_batch_size()
            trainer.config.training.batch_size = batch_size
            trainer.load_datasets()
            self.assertEqual(trainer.train_loader.current_loader.batch_size, 4)

        max_updates = trainer._calculate_max_updates()
        self.assertEqual(max_updates, 2)

        self.check_values(trainer, 0, 0, 0)
        trainer.training_loop()
        self.check_values(trainer, 2, 1, 2)
Code example #27
    def parallelize_model(self) -> None:
        registry.register("data_parallel", False)
        registry.register("distributed", False)
        if ("cuda" in str(self.device) and torch.cuda.device_count() > 1
                and not self.distributed):
            registry.register("data_parallel", True)
            self.model = torch.nn.DataParallel(self.model)

        if "cuda" in str(self.device) and self.distributed:
            registry.register("distributed", True)
            set_torch_ddp = True
            try:
                from fairscale.nn.data_parallel import ShardedDataParallel
                from fairscale.optim.oss import OSS

                if isinstance(self.optimizer, OSS):
                    self.model = ShardedDataParallel(self.model,
                                                     self.optimizer)
                    set_torch_ddp = False
                    logger.info("Using FairScale ShardedDataParallel")
            except ImportError:
                logger.info("Using PyTorch DistributedDataParallel")
                warnings.warn(
                    "You can enable ZeRO and Sharded DDP, by installing fairscale "
                    + "and setting optimizer.enable_state_sharding=True.")

            if set_torch_ddp:
                self.model = torch.nn.parallel.DistributedDataParallel(
                    self.model,
                    device_ids=[self.local_rank],
                    output_device=self.local_rank,
                    find_unused_parameters=self.config.training.find_unused_parameters,
                )

        if is_xla() and get_world_size() > 1:
            broadcast_xla_master_model_param(self.model)
Code example #28
File: m4c.py Project: zhangshengHust/mmf
    def update_registry_for_pretrained(cls, config, checkpoint, full_output):
        from omegaconf import OmegaConf

        # Hack datasets using OmegaConf
        datasets = full_output["full_config"].datasets
        dataset = datasets.split(",")[0]
        config_mock = OmegaConf.create({"datasets": datasets})
        registry.register("config", config_mock)
        registry.register(
            f"{dataset}_num_final_outputs",
            # Need to add as it is subtracted
            checkpoint["classifier.module.weight"].size(0) +
            config.classifier.ocr_max_num,
        )
        # Fix this later, when processor pipeline is available
        answer_processor = OmegaConf.create({"BOS_IDX": 1})
        registry.register(f"{dataset}_answer_processor", answer_processor)
Code example #29
File: configuration.py Project: hivestrung/mmf
    def __init__(self, args=None, default_only=False, load_dataset=True):
        self.config = {}

        if not args:
            import argparse

            args = argparse.Namespace(opts=[])
            default_only = True

        self.args = args
        self._register_resolvers()

        self._default_config = self._build_default_config()

        # Initially, silently add opts so that some of the overrides for the defaults
        # from command line required for setup can be honored
        self._default_config = _merge_with_dotlist(self._default_config,
                                                   args.opts,
                                                   skip_missing=True,
                                                   log_info=False)
        # Register the config and configuration for setup
        registry.register("config", self._default_config)
        registry.register("configuration", self)

        if default_only:
            other_configs = {}
        else:
            other_configs = self._build_other_configs(
                load_dataset=load_dataset)

        self.config = OmegaConf.merge(self._default_config, other_configs)

        self.config = _merge_with_dotlist(self.config, args.opts)
        self._update_specific(self.config)
        self.upgrade(self.config)
        # Resolve the config here itself after full creation so that spawned workers
        # don't face any issues
        self.config = OmegaConf.create(
            OmegaConf.to_container(self.config, resolve=True))

        # Update the registry with final config
        registry.register("config", self.config)
Code example #30
File: test_cnn_lstm.py Project: vishalbelsare/pythia
    def setUp(self):
        torch.manual_seed(1234)
        registry.register("clevr_text_vocab_size", 80)
        registry.register("clevr_num_final_outputs", 32)
        config_path = os.path.join(
            get_mmf_root(),
            "..",
            "projects",
            "others",
            "cnn_lstm",
            "clevr",
            "defaults.yaml",
        )
        config_path = os.path.abspath(config_path)
        args = dummy_args(model="cnn_lstm", dataset="clevr")
        args.opts.append(f"config={config_path}")
        configuration = Configuration(args)
        configuration.config.datasets = "clevr"
        configuration.freeze()
        self.config = configuration.config
        registry.register("config", self.config)