def main(configuration, init_distributed=False, predict=False):
    """Run a single MMF job (training or inference) in the current process.

    Args:
        configuration: Project ``Configuration`` object; its config supplies
            device, distributed, seed and logging settings.
        init_distributed (bool): When True, call ``distributed_init`` on the
            config before seeding.
        predict (bool): When True, call ``trainer.inference()``; otherwise
            call ``trainer.train()``.
    """
    # A reload might be needed for imports
    setup_imports()
    configuration.import_user_dir()
    cfg = configuration.get_config()

    if torch.cuda.is_available():
        torch.cuda.set_device(cfg.device_id)
        torch.cuda.init()

    if init_distributed:
        distributed_init(cfg)

    requested_seed = cfg.training.seed
    if requested_seed == -1:
        # -1 is forwarded to set_seed untouched (presumably "pick a random
        # seed" -- semantics of set_seed not visible here).
        seed_arg = requested_seed
    else:
        # Offset by get_rank() so each distributed process presumably gets
        # its own seed; get_rank() is only called on this path.
        seed_arg = requested_seed + get_rank()
    cfg.training.seed = set_seed(seed_arg)
    registry.register("seed", cfg.training.seed)

    cfg = build_config(configuration)

    setup_logger(
        color=cfg.training.colored_logs, disable=cfg.training.should_not_log
    )
    log = logging.getLogger("mmf_cli.run")
    # Log args for debugging purposes
    log.info(configuration.args)
    log.info(f"Torch version: {torch.__version__}")
    log_device_names()
    log.info(f"Using seed {cfg.training.seed}")

    trainer = build_trainer(cfg)
    trainer.load()
    run_job = trainer.inference if predict else trainer.train
    run_job()
def setUp(self):
    """Build an ``mmf_transformer`` model from the dummy test arguments.

    Sets ``self.model_name``, ``self.config`` and ``self.finetune_model``.
    """
    test_utils.setup_proxy()
    setup_imports()
    self.model_name = "mmf_transformer"
    configuration = Configuration(test_utils.dummy_args(model=self.model_name))
    self.config = configuration.get_config()
    transformer_config = self.config.model_config[self.model_name]
    transformer_config.model = self.model_name
    self.finetune_model = build_model(transformer_config)
def run(opts: typing.Optional[typing.List[str]] = None, predict: bool = False):
    """Start an MMF job from the command line or from a programmatic optlist.

    Args:
        opts (typing.Optional[typing.List[str]], optional): Optlist used to
            override options programmatically. For example, passing
            ``opts = ["training.batch_size=64", "checkpoint.resume=True"]``
            sets the batch size to 64 and resumes from the checkpoint if one
            is present. When None, arguments are parsed from the command line
            via ``flags.get_parser()``. Defaults to None.
        predict (bool, optional): Forwarded to the worker entry point; in
            ``main`` it selects ``trainer.inference()`` instead of
            ``trainer.train()``. Defaults to False.

    Raises:
        AssertionError: If no distributed init method is available,
            ``distributed.world_size`` > 1, and fewer CUDA devices than
            ``world_size`` are visible.
    """
    setup_imports()

    if opts is None:
        parser = flags.get_parser()
        args = parser.parse_args()
    else:
        args = argparse.Namespace(config_override=None)
        args.opts = opts

    # Expose the args on the project's perturbation_arguments module
    # (its consumer is not visible here).
    perturbation_arguments.args = args
    configuration = Configuration(args)
    # Do set runtime args which can be changed by MMF
    configuration.args = args
    config = configuration.get_config()
    config.start_rank = 0

    if config.distributed.init_method is None:
        infer_init_method(config)

    if config.distributed.init_method is not None:
        if torch.cuda.device_count() > 1 and not config.distributed.no_spawn:
            config.start_rank = config.distributed.rank
            config.distributed.rank = None
            torch.multiprocessing.spawn(
                fn=distributed_main,
                args=(configuration, predict),
                nprocs=torch.cuda.device_count(),
            )
        else:
            distributed_main(0, configuration, predict)
    elif config.distributed.world_size > 1:
        world_size = config.distributed.world_size
        num_devices = torch.cuda.device_count()
        # Explicit check rather than `assert`: an assert is stripped under
        # `python -O`, after which spawn would start more workers than GPUs.
        # AssertionError is kept so callers see the same exception type.
        if world_size > num_devices:
            raise AssertionError(
                f"distributed.world_size ({world_size}) exceeds the number "
                f"of available CUDA devices ({num_devices})"
            )
        port = random.randint(10000, 20000)
        config.distributed.init_method = f"tcp://localhost:{port}"
        config.distributed.rank = None
        torch.multiprocessing.spawn(
            fn=distributed_main,
            args=(configuration, predict),
            nprocs=world_size,
        )
    else:
        config.device_id = 0
        main(configuration, predict=predict)
def setUp(self):
    """Build a ``visual_bert`` model (after ``replace_with_jit()``).

    Sets ``self.pretrain_model``.
    """
    test_utils.setup_proxy()
    setup_imports()
    replace_with_jit()
    name = "visual_bert"
    cfg = Configuration(test_utils.dummy_args(model=name)).get_config()
    visual_bert_config = cfg.model_config[name]
    visual_bert_config.model = name
    self.pretrain_model = build_model(visual_bert_config)
def setUp(self):
    """Build an ``mmbt`` model with a two-label classification head.

    Sets ``self.finetune_model``.
    """
    test_utils.setup_proxy()
    setup_imports()
    name = "mmbt"
    cfg = Configuration(test_utils.dummy_args(model=name)).get_config()
    mmbt_config = cfg.model_config[name]
    # Applied in insertion order, matching the original assignment order.
    head_overrides = {"training_head_type": "classification", "num_labels": 2}
    for key, value in head_overrides.items():
        mmbt_config[key] = value
    mmbt_config.model = name
    self.finetune_model = build_model(mmbt_config)
def setUp(self):
    """Load the BUTD COCO ``beam_search.yaml`` config and register it.

    The frozen config is stored on ``self.config`` and registered in the
    registry under ``"config"``.
    """
    setup_imports()
    torch.manual_seed(1234)
    relative_parts = (
        "..",
        "projects",
        "butd",
        "configs",
        "coco",
        "beam_search.yaml",
    )
    config_path = os.path.abspath(os.path.join(get_mmf_root(), *relative_parts))
    args = dummy_args(model="butd", dataset="coco")
    args.opts.append(f"config={config_path}")
    configuration = Configuration(args)
    configuration.config.datasets = "coco"
    configuration.freeze()
    self.config = configuration.config
    registry.register("config", self.config)
def setUp(self):
    """Build ``vilbert`` pretraining and two-label classification models.

    Both models are built from the same model-config object, which is
    mutated between the two ``build_model`` calls. Sets
    ``self.vision_feature_size``, ``self.vision_target_size``,
    ``self.pretrain_model`` and ``self.finetune_model``.
    """
    test_utils.setup_proxy()
    setup_imports()
    name = "vilbert"
    cfg = Configuration(test_utils.dummy_args(model=name)).get_config()
    self.vision_feature_size = 1024
    self.vision_target_size = 1279
    vilbert_config = cfg.model_config[name]
    # Applied in insertion order, matching the original assignment order.
    pretrain_overrides = {
        "training_head_type": "pretraining",
        "visual_embedding_dim": self.vision_feature_size,
        "v_feature_size": self.vision_feature_size,
        "v_target_size": self.vision_target_size,
        "dynamic_attention": False,
    }
    for key, value in pretrain_overrides.items():
        vilbert_config[key] = value
    vilbert_config.model = name
    self.pretrain_model = build_model(vilbert_config)

    # NOTE(review): this mutates the same config object handed to the
    # pretraining model above; if build_model keeps a reference to it, the
    # pretrain model's config view changes too -- confirm this is intended.
    vilbert_config["training_head_type"] = "classification"
    vilbert_config["num_labels"] = 2
    self.finetune_model = build_model(vilbert_config)
def setUp(self):
    """Load the BUTD COCO ``nucleus_sampling.yaml`` config and register it.

    Sets the BUTD inference ``sum_threshold`` to 0.5 before freezing; the
    frozen config is stored on ``self.config`` and registered in the
    registry under ``"config"``.
    """
    setup_imports()
    torch.manual_seed(1234)
    relative_parts = (
        "..",
        "projects",
        "butd",
        "configs",
        "coco",
        "nucleus_sampling.yaml",
    )
    config_path = os.path.abspath(os.path.join(get_mmf_root(), *relative_parts))
    args = dummy_args(model="butd", dataset="coco")
    args.opts.append(f"config={config_path}")
    configuration = Configuration(args)
    configuration.config.datasets = "coco"
    inference_params = configuration.config.model_config.butd.inference.params
    inference_params.sum_threshold = 0.5
    configuration.freeze()
    self.config = configuration.config
    registry.register("config", self.config)
): cls.mapping["state"]["writer"].warning( "Key {} is not present in registry, returning default value " "of {}".format(original_name, default) ) return value @classmethod def unregister(cls, name): r"""Remove an item from registry with key 'name' Args: name: Key which needs to be removed. Usage:: from VisualBERT.mmf.common.registry import registry config = registry.unregister("config") """ return cls.mapping["state"].pop(name, None) registry = Registry() # Only setup imports in main process, this means registry won't be # fully available in spawned child processes (such as dataloader processes) # but instantiated. This is to prevent issues such as # https://github.com/facebookresearch/mmf/issues/355 if __name__ == "__main__": setup_imports()
def setUp(self):
    """Run the project's import setup (``setup_imports``) before each test."""
    setup_imports()