コード例 #1
0
def main(configuration, init_distributed=False, predict=False):
    """Run one multimodelity job (training or inference) in this process.

    Args:
        configuration: Configuration object. Its user dir is imported, its
            config selects the CUDA device and the seed, and it is passed to
            ``build_config`` to produce the config handed to the trainer.
        init_distributed (bool): When True, ``distributed_init`` is called on
            the config before the seed is set.
        predict (bool): When True, ``trainer.inference()`` is run instead of
            ``trainer.train()``.
    """
    # A reload of imports might be needed here, so set them up first.
    setup_imports()
    configuration.import_user_dir()
    config = configuration.get_config()

    if torch.cuda.is_available():
        torch.cuda.set_device(config.device_id)
        torch.cuda.init()

    if init_distributed:
        distributed_init(config)

    requested_seed = config.training.seed
    # A seed of -1 is forwarded unchanged; any other seed is offset by the
    # process rank so that each rank is seeded differently.
    # NOTE(review): set_seed presumably treats -1 as "choose a seed" — confirm.
    if requested_seed == -1:
        effective_seed = requested_seed
    else:
        effective_seed = requested_seed + get_rank()
    config.training.seed = set_seed(effective_seed)
    registry.register("seed", config.training.seed)

    config = build_config(configuration)

    setup_logger(
        color=config.training.colored_logs,
        disable=config.training.should_not_log,
    )
    logger = logging.getLogger("multimodelity_cli.run")
    # Log args for debugging purposes
    logger.info(configuration.args)
    logger.info(f"Torch version: {torch.__version__}")
    log_device_names()
    logger.info(f"Using seed {config.training.seed}")

    trainer = build_trainer(config)
    trainer.load()
    # Pick the trainer entry point for the requested mode, then invoke it.
    job = trainer.inference if predict else trainer.train
    job()
コード例 #2
0
 def setUp(self):
     """Create ``self.config`` and ``self.finetune_model`` for the transformer tests.

     The config comes from dummy test arguments; the model config's ``model``
     field is set to the model name before the model is built.
     """
     test_utils.setup_proxy()
     setup_imports()
     self.model_name = "multimodelity_transformer"
     self.config = Configuration(
         test_utils.dummy_args(model=self.model_name)
     ).get_config()
     model_cfg = self.config.model_config[self.model_name]
     model_cfg.model = self.model_name
     self.finetune_model = build_model(model_cfg)
コード例 #3
0
 def setUp(self):
     """Build a ``visual_bert`` model from dummy args as ``self.pretrain_model``.

     ``replace_with_jit()`` is invoked before the model is constructed.
     """
     test_utils.setup_proxy()
     setup_imports()
     replace_with_jit()
     name = "visual_bert"
     config = Configuration(test_utils.dummy_args(model=name)).get_config()
     visual_bert_cfg = config.model_config[name]
     visual_bert_cfg.model = name
     self.pretrain_model = build_model(visual_bert_cfg)
コード例 #4
0
def run(opts: typing.Optional[typing.List[str]] = None, predict: bool = False):
    """Run starts a job based on the command passed from the command line.

    You can optionally run the multimodelity job programmatically by passing an
    optlist as opts. Depending on the distributed settings, the job is run in
    spawned worker processes, in a single ``distributed_main`` call, or as a
    plain single-process ``main`` call on device 0.

    Args:
        opts (typing.Optional[typing.List[str]], optional): Optlist which can be
            used to override opts programmatically. For e.g. if you pass
            opts = ["training.batch_size=64", "checkpoint.resume=True"], this will
            set the batch size to 64 and resume from the checkpoint if present.
            Defaults to None, in which case arguments are parsed from the
            command line.
        predict (bool, optional): If predict is passed True, then the program runs in
            prediction mode. Defaults to False.

    Raises:
        ValueError: If no init method is configured or inferred and
            ``distributed.world_size`` is greater than the number of visible
            CUDA devices.
    """
    setup_imports()

    if opts is None:
        parser = flags.get_parser()
        args = parser.parse_args()
    else:
        args = argparse.Namespace(config_override=None)
        args.opts = opts

    configuration = Configuration(args)
    # Do set runtime args which can be changed by multimodelity
    configuration.args = args
    config = configuration.get_config()
    config.start_rank = 0
    if config.distributed.init_method is None:
        infer_init_method(config)

    if config.distributed.init_method is not None:
        if torch.cuda.device_count() > 1 and not config.distributed.no_spawn:
            # Spawn one worker per visible GPU. The configured rank is moved to
            # start_rank and cleared.
            # NOTE(review): presumably distributed_main derives each worker's
            # rank from start_rank — confirm.
            config.start_rank = config.distributed.rank
            config.distributed.rank = None
            torch.multiprocessing.spawn(
                fn=distributed_main,
                args=(configuration, predict),
                nprocs=torch.cuda.device_count(),
            )
        else:
            distributed_main(0, configuration, predict)
    elif config.distributed.world_size > 1:
        # Validate explicitly rather than with ``assert`` so the check is not
        # stripped when Python runs with ``-O`` and the user gets a clear message.
        available_devices = torch.cuda.device_count()
        if config.distributed.world_size > available_devices:
            raise ValueError(
                f"distributed.world_size ({config.distributed.world_size}) is "
                f"larger than the number of available CUDA devices "
                f"({available_devices})."
            )
        # No init method was given or inferred: use a random localhost TCP port.
        port = random.randint(10000, 20000)
        config.distributed.init_method = f"tcp://localhost:{port}"
        config.distributed.rank = None
        torch.multiprocessing.spawn(
            fn=distributed_main,
            args=(configuration, predict),
            nprocs=config.distributed.world_size,
        )
    else:
        config.device_id = 0
        main(configuration, predict=predict)
コード例 #5
0
ファイル: test_mmbt.py プロジェクト: hahaxun/mmf
 def setUp(self):
     """Build an ``mmbt`` model with a two-label classification head."""
     test_utils.setup_proxy()
     setup_imports()
     name = "mmbt"
     config = Configuration(test_utils.dummy_args(model=name)).get_config()
     mmbt_cfg = config.model_config[name]
     # Classification head with two labels.
     mmbt_cfg["training_head_type"] = "classification"
     mmbt_cfg["num_labels"] = 2
     mmbt_cfg.model = name
     self.finetune_model = build_model(mmbt_cfg)
コード例 #6
0
 def setUp(self):
     """Load the BUTD COCO beam-search config, freeze it and register it.

     The frozen config is stored on ``self.config`` and registered in the
     registry under ``"config"``.
     """
     setup_imports()
     torch.manual_seed(1234)
     relative_parts = (
         "..", "projects", "butd", "configs", "coco", "beam_search.yaml"
     )
     config_path = os.path.abspath(
         os.path.join(get_multimodelity_root(), *relative_parts)
     )
     args = dummy_args(model="butd", dataset="coco")
     args.opts.append(f"config={config_path}")
     configuration = Configuration(args)
     configuration.config.datasets = "coco"
     configuration.freeze()
     self.config = configuration.config
     registry.register("config", self.config)
コード例 #7
0
 def setUp(self):
     """Load the BUTD COCO nucleus-sampling config, freeze it and register it.

     ``sum_threshold`` of the BUTD inference params is overridden to 0.5
     before freezing; the frozen config is stored on ``self.config`` and
     registered under ``"config"``.
     """
     setup_imports()
     torch.manual_seed(1234)
     yaml_location = os.path.abspath(
         os.path.join(
             get_multimodelity_root(),
             "..", "projects", "butd", "configs", "coco", "nucleus_sampling.yaml",
         )
     )
     args = dummy_args(model="butd", dataset="coco")
     args.opts.append(f"config={yaml_location}")
     configuration = Configuration(args)
     cfg = configuration.config
     cfg.datasets = "coco"
     cfg.model_config.butd.inference.params.sum_threshold = 0.5
     configuration.freeze()
     self.config = configuration.config
     registry.register("config", self.config)
コード例 #8
0
    def setUp(self):
        """Build ``vilbert`` pretraining and fine-tuning models from one config.

        The same model config object is first set up for the pretraining head
        and used to build ``self.pretrain_model``; it is then switched to a
        two-label classification head to build ``self.finetune_model``.
        """
        test_utils.setup_proxy()
        setup_imports()
        name = "vilbert"
        config = Configuration(test_utils.dummy_args(model=name)).get_config()
        self.vision_feature_size = 1024
        self.vision_target_size = 1279
        vilbert_cfg = config.model_config[name]
        pretraining_overrides = {
            "training_head_type": "pretraining",
            "visual_embedding_dim": self.vision_feature_size,
            "v_feature_size": self.vision_feature_size,
            "v_target_size": self.vision_target_size,
            "dynamic_attention": False,
        }
        for key, value in pretraining_overrides.items():
            vilbert_cfg[key] = value
        vilbert_cfg.model = name
        self.pretrain_model = build_model(vilbert_cfg)

        # Reuse the same config object, switched to a classification head.
        vilbert_cfg["training_head_type"] = "classification"
        vilbert_cfg["num_labels"] = 2
        self.finetune_model = build_model(vilbert_cfg)
コード例 #9
0
ファイル: registry.py プロジェクト: hahaxun/mmf
        if ("writer" in cls.mapping["state"] and value == default
                and no_warning is False):
            cls.mapping["state"]["writer"].warning(
                "Key {} is not present in registry, returning default value "
                "of {}".format(original_name, default))
        return value

    @classmethod
    def unregister(cls, name):
        r"""Remove the entry stored under key ``name`` from the registry state.

        Args:
            name: Key which needs to be removed.

        Returns:
            The value that was stored under ``name``, or ``None`` if the key
            was not present.

        Usage::

            from multimodelity.common.registry import registry

            config = registry.unregister("config")
        """
        state = cls.mapping["state"]
        return state.pop(name, None)


# Shared module-level instance; other modules import it, e.g.
# ``from multimodelity.common.registry import registry``.
registry = Registry()

# Only setup imports in main process, this means registry won't be
# fully available in spawned child processes (such as dataloader processes)
# but instantiated. This is to prevent issues such as
# https://github.com/facebookresearch/multimodelity/issues/355
# NOTE(review): ``__name__`` equals "__main__" only when this file is executed
# directly as a script. When it is imported as a module, this guard is False
# and ``setup_imports()`` is not called here at all, so the "main process"
# wording above may be misleading. Confirm that callers run
# ``setup_imports()`` themselves.
# NOTE(review): the issue URL above may originally have pointed at the
# facebookresearch/mmf repository — verify the link.
if __name__ == "__main__":
    setup_imports()
コード例 #10
0
ファイル: test_configs_for_keys.py プロジェクト: hahaxun/mmf
 def setUp(self):
     setup_imports()