def build_config(configuration: Configuration, *args, **kwargs) -> DictConfig:
    """Builder function for config. Freezes the configuration and registers
    configuration object and config DictConfig object to registry.

    Args:
        configuration (Configuration): Configuration object that will be used
            to create the config.

    Returns:
        (DictConfig): A config which is of type omegaconf.DictConfig
    """
    configuration.freeze()
    config = configuration.get_config()
    registry.register("config", config)
    registry.register("configuration", configuration)
    return config
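# Usage sketch (illustrative, not from the source): once build_config() has run,
# downstream components can pull the frozen config back out of the registry
# instead of threading it through every constructor. registry.get() mirrors
# registry.register() as used elsewhere in this code.
from mmf.common.registry import registry

config = registry.get("config")
configuration = registry.get("configuration")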
def setUp(self):
    self.tmpdir = tempfile.mkdtemp()
    self.trainer = argparse.Namespace()
    self.config = load_yaml(os.path.join("configs", "defaults.yaml"))
    self.config = OmegaConf.merge(
        self.config,
        {
            "model": "simple",
            "model_config": {},
            "training": {
                "checkpoint_interval": 1,
                "evaluation_interval": 10,
                "early_stop": {"criteria": "val/total_loss"},
                "batch_size": 16,
                "log_interval": 10,
                "logger_level": "info",
            },
            "env": {"save_dir": self.tmpdir},
        },
    )
    # Keep original copy for testing purposes
    self.trainer.config = deepcopy(self.config)
    registry.register("config", self.trainer.config)
    setup_logger()

    self.report = Mock(spec=Report)
    self.report.dataset_name = "abcd"
    self.report.dataset_type = "test"

    self.trainer.model = SimpleModule()
    self.trainer.val_loader = torch.utils.data.DataLoader(
        NumbersDataset(), batch_size=self.config.training.batch_size
    )
    self.trainer.optimizer = torch.optim.Adam(
        self.trainer.model.parameters(), lr=1e-01
    )
    self.trainer.device = "cpu"
    self.trainer.num_updates = 0
    self.trainer.current_iteration = 0
    self.trainer.current_epoch = 0
    self.trainer.max_updates = 0
    self.trainer.meter = Meter()
    self.cb = LogisticsCallback(self.config, self.trainer)
def __init__(
    self,
    config,
    max_steps,
    max_epochs=None,
    callback=None,
    num_data_size=100,
    batch_size=1,
    accumulate_grad_batches=1,
    lr_scheduler=False,
    gradient_clip_val=0.0,
    precision=32,
    **kwargs,
):
    self.config = config
    self._callbacks = None
    if callback:
        self._callbacks = [callback]

    # data
    self.data_module = MultiDataModuleNumbersTestObject(
        num_data=num_data_size, batch_size=batch_size
    )

    # settings
    trainer_config = self.config.trainer.params
    trainer_config.accumulate_grad_batches = accumulate_grad_batches
    trainer_config.precision = precision
    trainer_config.max_steps = max_steps
    trainer_config.max_epochs = max_epochs
    trainer_config.gradient_clip_val = gradient_clip_val
    for key, value in kwargs.items():
        trainer_config[key] = value
    self.trainer_config = trainer_config
    self.training_config = self.config.training
    self.training_config.batch_size = batch_size
    self.run_type = self.config.get("run_type", "train")
    self.config.training.lr_scheduler = lr_scheduler
    registry.register("config", self.config)

    self.train_loader = self.data_module.train_dataloader()
    self.val_loader = self.data_module.val_dataloader()
    self.test_loader = self.data_module.test_dataloader()
def configure_device(self) -> None:
    if self.config.training.get("device", "cuda") == "xla":
        import torch_xla.core.xla_model as xm

        self.device = xm.xla_device()
        self.distributed = True
        self.local_rank = xm.get_local_ordinal()
        is_xla = True
    else:
        is_xla = False
        if "device_id" not in self.config:
            warnings.warn(
                "No 'device_id' in 'config', setting to -1. "
                "This can cause issues later in training. Ensure that "
                "distributed setup is properly initialized."
            )
            self.local_rank = -1
        else:
            self.local_rank = self.config.device_id
        self.device = self.local_rank
        self.distributed = False

    # Will be updated later based on distributed setup
    registry.register("global_device", self.device)

    if self.config.distributed.init_method is not None:
        self.distributed = True
        self.device = torch.device("cuda", self.local_rank)
        torch.cuda.set_device(self.local_rank)
    elif torch.cuda.is_available():
        self.device = torch.device("cuda")
        torch.cuda.set_device(0)
    elif not is_xla:
        self.device = torch.device("cpu")

    if "rank" not in self.config.distributed:
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            global_rank = torch.distributed.get_rank()
        else:
            global_rank = -1
        with open_dict(self.config.distributed):
            self.config.distributed.rank = global_rank

    registry.register("global_device", self.config.distributed.rank)
def configure_device(self) -> None:
    self.local_rank = self.config.device_id
    self.device = self.local_rank
    self.distributed = False

    # Will be updated later based on distributed setup
    registry.register("global_device", self.device)

    if self.config.distributed.init_method is not None:
        self.distributed = True
        self.device = torch.device("cuda", self.local_rank)
        torch.cuda.set_device(self.local_rank)
    elif torch.cuda.is_available():
        self.device = torch.device("cuda")
        torch.cuda.set_device(0)
    else:
        self.device = torch.device("cpu")

    registry.register("global_device", self.config.distributed.rank)
def setUp(self):
    self.trainer = argparse.Namespace()
    self.config = load_yaml(os.path.join("configs", "defaults.yaml"))
    self.config = OmegaConf.merge(
        self.config,
        {
            "model": "simple",
            "model_config": {},
            "training": {
                "lr_scheduler": True,
                "lr_ratio": 0.1,
                "lr_steps": [1, 2],
                "use_warmup": False,
                "callbacks": [{"type": "test_callback", "params": {}}],
            },
        },
    )
    # Keep original copy for testing purposes
    self.trainer.config = deepcopy(self.config)
    registry.register("config", self.trainer.config)

    model = SimpleModel(SimpleModel.Config())
    model.build()
    self.trainer.model = model
    self.trainer.val_loader = torch.utils.data.DataLoader(
        NumbersDataset(2), batch_size=self.config.training.batch_size
    )
    self.trainer.optimizer = torch.optim.Adam(
        self.trainer.model.parameters(), lr=1e-01
    )
    self.trainer.lr_scheduler_callback = LRSchedulerCallback(self.config, self.trainer)

    self.trainer.callbacks = []
    for callback in self.config.training.get("callbacks", []):
        callback_type = callback.type
        callback_param = callback.params
        callback_cls = registry.get_callback_class(callback_type)
        self.trainer.callbacks.append(
            callback_cls(self.trainer.config, self.trainer, **callback_param)
        )
def parallelize_model(self):
    training = self.config.training

    if (
        "cuda" in str(self.device)
        and torch.cuda.device_count() > 1
        and not self.distributed
    ):
        registry.register("data_parallel", True)
        self.model = torch.nn.DataParallel(self.model)

    if "cuda" in str(self.device) and self.distributed:
        registry.register("distributed", True)
        self.model = torch.nn.parallel.DistributedDataParallel(
            self.model,
            device_ids=[self.local_rank],
            output_device=self.local_rank,
            check_reduction=True,
            find_unused_parameters=training.find_unused_parameters,
        )
def setUp(self):
    setup_imports()
    torch.manual_seed(1234)
    config_path = os.path.join(
        get_mmf_root(),
        "..",
        "projects",
        "butd",
        "configs",
        "coco",
        "beam_search.yaml",
    )
    config_path = os.path.abspath(config_path)
    args = dummy_args(model="butd", dataset="coco")
    args.opts.append(f"config={config_path}")
    configuration = Configuration(args)
    configuration.config.datasets = "coco"
    configuration.freeze()
    self.config = configuration.config
    registry.register("config", self.config)
def _init_processors(self):
    args = Namespace()
    args.opts = [
        "config=projects/pythia/configs/vqa2/defaults.yaml",
        "datasets=vqa2",
        "model=visual_bert",
        "evaluation.predict=True",
    ]
    args.config_override = None

    configuration = Configuration(args=args)
    config = self.config = configuration.config
    vqa2_config = config.dataset_config.vqa2
    text_processor_config = vqa2_config.processors.text_processor
    text_processor_config.params.vocab.vocab_file = "../model_data/vocabulary_100k.txt"

    # Add the preprocessor, as it will be needed when we take questions from the user
    self.text_processor = VocabProcessor(text_processor_config.params)
    registry.register("coco_text_processor", self.text_processor)
def setUp(self):
    setup_imports()
    torch.manual_seed(1234)
    config_path = os.path.join(
        get_mmf_root(),
        "..",
        "projects",
        "butd",
        "configs",
        "coco",
        "nucleus_sampling.yaml",
    )
    config_path = os.path.abspath(config_path)
    args = dummy_args(model="butd", dataset="coco")
    args.opts.append(f"config={config_path}")
    configuration = Configuration(args)
    configuration.config.datasets = "coco"
    configuration.config.model_config.butd.inference.params.sum_threshold = 0.5
    configuration.freeze()
    self.config = configuration.config
    registry.register("config", self.config)
def ready_trainer(trainer):
    from mmf.common.registry import registry
    from mmf.utils.logger import Logger, TensorboardLogger

    trainer.run_type = trainer.config.get("run_type", "train")

    writer = registry.get("writer", no_warning=True)
    if writer:
        trainer.writer = writer
    else:
        trainer.writer = Logger(trainer.config)
        registry.register("writer", trainer.writer)

    trainer.configure_device()
    trainer.configure_seed()
    trainer.load_model()

    from mmf.trainers.callbacks.checkpoint import CheckpointCallback
    from mmf.trainers.callbacks.early_stopping import EarlyStoppingCallback

    trainer.checkpoint_callback = CheckpointCallback(trainer.config, trainer)
    trainer.early_stop_callback = EarlyStoppingCallback(trainer.config, trainer)
    trainer.callbacks.append(trainer.checkpoint_callback)

    trainer.on_init_start()
def load_model_and_optimizer(self):
    attributes = self.config.model_config[self.config.model]
    # Easy way to point to config for other model
    if isinstance(attributes, str):
        attributes = self.config.model_config[attributes]

    with omegaconf.open_dict(attributes):
        attributes.model = self.config.model

    self.model = build_model(attributes)

    if "cuda" in str(self.device):
        device_info = "CUDA Device {} is: {}".format(
            self.config.distributed.rank,
            torch.cuda.get_device_name(self.local_rank),
        )
        registry.register("global_device", self.config.distributed.rank)
        self.writer.write(device_info, log_all=True)

    self.model = self.model.to(self.device)
    self.optimizer = build_optimizer(self.model, self.config)

    registry.register("data_parallel", False)
    registry.register("distributed", False)

    self.load_extras()
    self.parallelize_model()
def configure_device(self) -> None:
    self.local_rank = self.config.device_id
    self.device = self.local_rank
    self.distributed = False

    # Will be updated later based on distributed setup
    registry.register("global_device", self.device)

    if self.config.distributed.init_method is not None:
        self.distributed = True
        self.device = torch.device("cuda", self.local_rank)
    elif torch.cuda.is_available():
        self.device = torch.device("cuda")
    else:
        self.device = torch.device("cpu")

    registry.register("current_device", self.device)

    if "cuda" in str(self.device):
        device_info = "CUDA Device {} is: {}".format(
            self.config.distributed.rank,
            torch.cuda.get_device_name(self.local_rank),
        )
        registry.register("global_device", self.config.distributed.rank)
        self.writer.write(device_info, log_all=True)
def _load_counts(self, ckpt):
    ckpt_config = self.trainer.config.checkpoint
    if "best_update" in ckpt:
        if ckpt_config.resume_best:
            self.trainer.num_updates = ckpt.get(
                "best_update", self.trainer.num_updates
            )
            self.trainer.current_iteration = ckpt.get(
                "best_iteration", self.trainer.current_iteration
            )
        else:
            self.trainer.num_updates = ckpt.get(
                "num_updates", self.trainer.num_updates
            )
            self.trainer.current_iteration = ckpt.get(
                "current_iteration", self.trainer.current_iteration
            )
        self.trainer.current_epoch = ckpt.get(
            "current_epoch", self.trainer.current_epoch
        )
    elif "best_iteration" in ckpt:
        # Preserve old behavior for old checkpoints where we always
        # load best iteration
        if ckpt_config.resume_best and "current_iteration" in ckpt:
            self.trainer.current_iteration = ckpt["current_iteration"]
        else:
            self.trainer.current_iteration = ckpt.get(
                "best_iteration", self.trainer.current_iteration
            )
        self.trainer.num_updates = self.trainer.current_iteration

    registry.register("current_iteration", self.trainer.current_iteration)
    registry.register("num_updates", self.trainer.num_updates)

    self.trainer.current_epoch = ckpt.get("best_epoch", self.trainer.current_epoch)
    registry.register("current_epoch", self.trainer.current_epoch)
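# Illustrative sketch (an assumption, not from the source): after _load_counts()
# runs, components that hold no reference to the trainer can read the resume
# state straight from the registry under the same keys registered above.
from mmf.common.registry import registry

current_iteration = registry.get("current_iteration")
num_updates = registry.get("num_updates")
current_epoch = registry.get("current_epoch")
print(f"Resumed at update {num_updates}, iteration {current_iteration}, epoch {current_epoch}")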
def _init_processors(self):
    args = Namespace()
    args.opts = [
        "config=projects/visual_bert/configs/vqa2/defaults.yaml",
        "datasets=vqa2",
        "model=visual_bert",
        "evaluation.predict=True",
    ]
    args.config_override = None

    configuration = Configuration(args=args)
    config = self.config = configuration.config
    vqa_config = config.dataset_config.vqa2
    text_processor_config = vqa_config.processors.text_processor
    answer_processor_config = vqa_config.processors.answer_processor
    text_processor_config.params.vocab.vocab_file = (
        self.root + "/content/model_data/vocabulary_100k.txt"
    )
    answer_processor_config.params.vocab_file = (
        self.root + "/content/model_data/answers_vqa.txt"
    )

    # Add the preprocessors, as they will be needed when we take questions from the user
    self.text_processor = BertTokenizer(text_processor_config.params)
    self.answer_processor = VQAAnswerProcessor(answer_processor_config.params)

    registry.register("vqa2_text_processor", self.text_processor)
    registry.register("vqa2_answer_processor", self.answer_processor)
    registry.register("vqa2_num_final_outputs", self.answer_processor.get_vocab_size())
def load(self):
    self._set_device()

    self.run_type = self.config.get("run_type", "train")
    self.dataset_loader = DatasetLoader(self.config)
    self._datasets = self.config.datasets

    # Check if loader is already defined, else init it
    writer = registry.get("writer", no_warning=True)
    if writer:
        self.writer = writer
    else:
        self.writer = Logger(self.config)
        registry.register("writer", self.writer)

    self.configuration.pretty_print()

    self.config_based_setup()

    self.load_datasets()
    self.load_model_and_optimizer()
    self.load_metrics()
def main(configuration, init_distributed=False):
    # A reload might be needed for imports
    setup_imports()
    configuration.import_user_dir()
    config = configuration.get_config()

    if torch.cuda.is_available():
        torch.cuda.set_device(config.device_id)
        torch.cuda.init()

    if init_distributed:
        distributed_init(config)

    config.training.seed = set_seed(config.training.seed)
    registry.register("seed", config.training.seed)
    print("Using seed {}".format(config.training.seed))

    registry.register("writer", Logger(config, name="mmf.train"))

    trainer = build_trainer(configuration)
    trainer.load()
    trainer.train()
def setup_imports():
    from mmf.common.registry import registry

    # First, check if imports are already setup
    has_already_setup = registry.get("imports_setup", no_warning=True)
    if has_already_setup:
        return

    # Automatically load all of the modules, so that
    # they register with registry
    root_folder = registry.get("mmf_root", no_warning=True)

    if root_folder is None:
        root_folder = os.path.dirname(os.path.abspath(__file__))
        root_folder = os.path.join(root_folder, "..")

        environment_mmf_path = os.environ.get("MMF_PATH", os.environ.get("PYTHIA_PATH"))

        if environment_mmf_path is not None:
            root_folder = environment_mmf_path

        registry.register("pythia_path", root_folder)
        registry.register("mmf_path", root_folder)

    trainer_folder = os.path.join(root_folder, "trainers")
    trainer_pattern = os.path.join(trainer_folder, "**", "*.py")
    datasets_folder = os.path.join(root_folder, "datasets")
    datasets_pattern = os.path.join(datasets_folder, "**", "*.py")
    model_folder = os.path.join(root_folder, "models")
    common_folder = os.path.join(root_folder, "common")
    modules_folder = os.path.join(root_folder, "modules")
    model_pattern = os.path.join(model_folder, "**", "*.py")
    common_pattern = os.path.join(common_folder, "**", "*.py")
    modules_pattern = os.path.join(modules_folder, "**", "*.py")

    importlib.import_module("mmf.common.meter")

    files = (
        glob.glob(datasets_pattern, recursive=True)
        + glob.glob(model_pattern, recursive=True)
        + glob.glob(trainer_pattern, recursive=True)
        + glob.glob(common_pattern, recursive=True)
        + glob.glob(modules_pattern, recursive=True)
    )

    for f in files:
        f = os.path.realpath(f)
        if f.endswith(".py") and not f.endswith("__init__.py"):
            splits = f.split(os.sep)
            import_prefix_index = 0
            for idx, split in enumerate(splits):
                if split == "mmf":
                    import_prefix_index = idx + 1
            file_name = splits[-1]
            module_name = file_name[: file_name.find(".py")]
            module = ".".join(["mmf"] + splits[import_prefix_index:-1] + [module_name])
            importlib.import_module(module)

    registry.register("imports_setup", True)
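# Hypothetical sketch of why setup_imports() walks those folders: any module it
# imports can self-register with the registry at import time via a decorator.
# register_model is MMF's registration convention; treat the toy model below as
# an illustrative assumption, not code from the source.
from mmf.common.registry import registry
from mmf.models.base_model import BaseModel


@registry.register_model("my_toy_model")
class MyToyModel(BaseModel):
    ...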
def test_caption_bleu4(self):
    path = os.path.join(
        os.path.abspath(__file__),
        "../../../mmf/configs/datasets/coco/defaults.yaml",
    )
    config = load_yaml(os.path.abspath(path))
    captioning_config = config.dataset_config.coco
    caption_processor_config = captioning_config.processors.caption_processor

    vocab_path = os.path.join(
        os.path.abspath(__file__), "..", "..", "data", "vocab.txt"
    )
    caption_processor_config.params.vocab.type = "random"
    caption_processor_config.params.vocab.vocab_file = os.path.abspath(vocab_path)
    caption_processor = CaptionProcessor(caption_processor_config.params)
    registry.register("coco_caption_processor", caption_processor)

    caption_bleu4 = metrics.CaptionBleu4Metric()
    expected = Sample()
    predicted = dict()

    # Test complete match
    expected.answers = torch.empty((5, 5, 10))
    expected.answers.fill_(4)
    predicted["scores"] = torch.zeros((5, 10, 19))
    predicted["scores"][:, :, 4] = 1.0

    self.assertEqual(caption_bleu4.calculate(expected, predicted).item(), 1.0)

    # Test partial match
    expected.answers = torch.empty((5, 5, 10))
    expected.answers.fill_(4)
    predicted["scores"] = torch.zeros((5, 10, 19))
    predicted["scores"][:, 0:5, 4] = 1.0
    predicted["scores"][:, 5:, 18] = 1.0

    self.assertAlmostEqual(
        caption_bleu4.calculate(expected, predicted).item(), 0.3928, 4
    )
def setUp(self):
    self.trainer = argparse.Namespace()
    self.config = OmegaConf.create(
        {
            "model": "simple",
            "model_config": {},
            "training": {
                "lr_scheduler": True,
                "lr_ratio": 0.1,
                "lr_steps": [1, 2],
                "use_warmup": False,
            },
        }
    )
    # Keep original copy for testing purposes
    self.trainer.config = deepcopy(self.config)
    registry.register("config", self.trainer.config)

    self.trainer.model = SimpleModule()
    self.trainer.val_dataset = NumbersDataset()

    self.trainer.optimizer = torch.optim.Adam(
        self.trainer.model.parameters(), lr=1e-01
    )
    self.trainer.lr_scheduler_callback = LRSchedulerCallback(self.config, self.trainer)
def update_from_report(self, report, should_update_loss=True):
    """Updates the provided meter with info from a report.

    By default, this method also handles reducing metrics.

    Args:
        report (Report): report object whose content is used to populate
            the current meter

    Usage::

        >>> meter = Meter()
        >>> report = Report(prepared_batch, model_output)
        >>> meter.update_from_report(report)
    """
    if hasattr(report, "metrics"):
        metrics_dict = report.metrics
        reduced_metrics_dict = reduce_dict(metrics_dict)

    if should_update_loss:
        loss_dict = report.losses
        reduced_loss_dict = reduce_dict(loss_dict)

    with torch.no_grad():
        meter_update_dict = {}
        if should_update_loss:
            meter_update_dict = scalarize_dict_values(reduced_loss_dict)
            total_loss_key = report.dataset_type + "/total_loss"
            total_loss = sum(meter_update_dict.values())
            registry.register(total_loss_key, total_loss)
            meter_update_dict.update({total_loss_key: total_loss})

        if hasattr(report, "metrics"):
            metrics_dict = scalarize_dict_values(reduced_metrics_dict)
            meter_update_dict.update(**metrics_dict)
        self._update(meter_update_dict, report.batch_size)
def update_registry_for_model(self, config):
    if hasattr(self.dataset, "text_processor"):
        registry.register(
            self.dataset_name + "_text_vocab_size",
            self.dataset.text_processor.get_vocab_size(),
        )
        registry.register(
            f"{self.dataset_name}_text_processor", self.dataset.text_processor
        )

    if hasattr(self.dataset, "answer_processor"):
        registry.register(
            self.dataset_name + "_num_final_outputs",
            self.dataset.answer_processor.get_vocab_size(),
        )
        registry.register(
            f"{self.dataset_name}_answer_processor", self.dataset.answer_processor
        )
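# Usage sketch (assumed, not from the source): a model built for this dataset
# can now size its heads from the registry rather than from the dataset object,
# matching the keys registered above. The dataset name is illustrative.
from mmf.common.registry import registry

dataset_name = "vqa2"  # illustrative
num_outputs = registry.get(f"{dataset_name}_num_final_outputs")
vocab_size = registry.get(f"{dataset_name}_text_vocab_size")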
def _load_counts_and_lr_scheduler(self, ckpt):
    ckpt_config = self.trainer.config.checkpoint
    if "best_update" in ckpt:
        if ckpt_config.resume_best:
            self.trainer.num_updates = ckpt.get(
                "best_update", self.trainer.num_updates
            )
            self.trainer.current_iteration = ckpt.get(
                "best_iteration", self.trainer.current_iteration
            )
        else:
            self.trainer.num_updates = ckpt.get(
                "num_updates", self.trainer.num_updates
            )
            self.trainer.current_iteration = ckpt.get(
                "current_iteration", self.trainer.current_iteration
            )
        self.trainer.current_epoch = ckpt.get(
            "current_epoch", self.trainer.current_epoch
        )
    elif "best_iteration" in ckpt:
        # Preserve old behavior for old checkpoints where we always
        # load best iteration
        if ckpt_config.resume_best and "current_iteration" in ckpt:
            self.trainer.current_iteration = ckpt["current_iteration"]
        else:
            self.trainer.current_iteration = ckpt.get(
                "best_iteration", self.trainer.current_iteration
            )
        self.trainer.num_updates = self.trainer.current_iteration

    lr_scheduler = self.trainer.lr_scheduler_callback
    if (
        lr_scheduler is not None
        and getattr(lr_scheduler, "_scheduler", None) is not None
    ):
        lr_scheduler = lr_scheduler._scheduler
        if "lr_scheduler" in ckpt:
            lr_scheduler.load_state_dict(ckpt["lr_scheduler"])
        else:
            warnings.warn(
                "'lr_scheduler' key is not present in the "
                "checkpoint asked to be loaded. Setting lr_scheduler's "
                "last_epoch to current_iteration."
            )
            lr_scheduler.last_epoch = self.trainer.current_iteration

    registry.register("current_iteration", self.trainer.current_iteration)
    registry.register("num_updates", self.trainer.num_updates)

    self.trainer.current_epoch = ckpt.get("best_epoch", self.trainer.current_epoch)
    registry.register("current_epoch", self.trainer.current_epoch)
def __init__(self, config, num_data_size=100, **kwargs):
    super().__init__(config)
    self.config = config
    self.callbacks = []

    # settings
    trainer_config = self.config.trainer.params
    self.trainer_config = trainer_config
    self.training_config = self.config.training
    for key, value in kwargs.items():
        trainer_config[key] = value

    # data
    self.data_module = MultiDataModuleNumbersTestObject(
        config=config, num_data=num_data_size
    )
    self.run_type = self.config.get("run_type", "train")

    registry.register("config", self.config)

    self.train_loader = self.data_module.train_dataloader()
    self.val_loader = self.data_module.val_dataloader()
    self.test_loader = self.data_module.test_dataloader()
def __init__(self, args=None, default_only=False):
    self.config = {}

    if not args:
        import argparse

        args = argparse.Namespace(opts=[])
        default_only = True

    self.args = args
    self._register_resolvers()

    self._default_config = self._build_default_config()

    if default_only:
        other_configs = {}
    else:
        other_configs = self._build_other_configs()

    self.config = OmegaConf.merge(self._default_config, other_configs)

    self.config = self._merge_with_dotlist(self.config, args.opts)
    self._update_specific(self.config)

    registry.register("config", self.config)
def test_batch_size_per_device(self, a):
    # Need to patch mmf.utils.general's world size, not mmf.utils.distributed's,
    # as the former is the one that will be used here
    with patch("mmf.utils.general.get_world_size", return_value=2):
        config = self._get_config(max_updates=2, max_epochs=None, batch_size=4)
        trainer = TrainerTrainingLoopMock(config=config)
        add_model(trainer, SimpleModel({"in_dim": 1}))
        add_optimizer(trainer, config)
        registry.register("config", trainer.config)
        batch_size = get_batch_size()
        trainer.config.training.batch_size = batch_size
        trainer.load_datasets()
        # Train loader has batch size per device; for global batch size 4
        # with world size 2, batch size per device should be 4 // 2 = 2
        self.assertEqual(trainer.train_loader.current_loader.batch_size, 2)

        # This is per device, so should stay the same
        config = self._get_config(
            max_updates=2, max_epochs=None, batch_size_per_device=4
        )
        trainer = TrainerTrainingLoopMock(config=config)
        add_model(trainer, SimpleModel({"in_dim": 1}))
        add_optimizer(trainer, config)
        registry.register("config", trainer.config)
        batch_size = get_batch_size()
        trainer.config.training.batch_size = batch_size
        trainer.load_datasets()
        self.assertEqual(trainer.train_loader.current_loader.batch_size, 4)

    max_updates = trainer._calculate_max_updates()
    self.assertEqual(max_updates, 2)

    self.check_values(trainer, 0, 0, 0)
    trainer.training_loop()
    self.check_values(trainer, 2, 1, 2)
def parallelize_model(self) -> None:
    registry.register("data_parallel", False)
    registry.register("distributed", False)

    if (
        "cuda" in str(self.device)
        and torch.cuda.device_count() > 1
        and not self.distributed
    ):
        registry.register("data_parallel", True)
        self.model = torch.nn.DataParallel(self.model)

    if "cuda" in str(self.device) and self.distributed:
        registry.register("distributed", True)
        set_torch_ddp = True
        try:
            from fairscale.nn.data_parallel import ShardedDataParallel
            from fairscale.optim.oss import OSS

            if isinstance(self.optimizer, OSS):
                self.model = ShardedDataParallel(self.model, self.optimizer)
                set_torch_ddp = False
                logger.info("Using FairScale ShardedDataParallel")
        except ImportError:
            logger.info("Using PyTorch DistributedDataParallel")
            warnings.warn(
                "You can enable ZeRO and Sharded DDP, by installing fairscale "
                "and setting optimizer.enable_state_sharding=True."
            )

        if set_torch_ddp:
            self.model = torch.nn.parallel.DistributedDataParallel(
                self.model,
                device_ids=[self.local_rank],
                output_device=self.local_rank,
                find_unused_parameters=self.config.training.find_unused_parameters,
            )

    if is_xla() and get_world_size() > 1:
        broadcast_xla_master_model_param(self.model)
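# Hedged sketch (my assumption about a consumer of these flags, not from the
# source): code that needs the bare model, e.g. for checkpointing, can consult
# the flags registered by parallelize_model() to decide whether to unwrap the
# DataParallel/DDP wrapper.
from mmf.common.registry import registry


def unwrap_model(model):
    # "data_parallel" / "distributed" were registered above
    if registry.get("data_parallel") or registry.get("distributed"):
        return model.module
    return model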
def update_registry_for_pretrained(cls, config, checkpoint, full_output):
    from omegaconf import OmegaConf

    # Hack datasets using OmegaConf
    datasets = full_output["full_config"].datasets
    dataset = datasets.split(",")[0]
    config_mock = OmegaConf.create({"datasets": datasets})
    registry.register("config", config_mock)
    registry.register(
        f"{dataset}_num_final_outputs",
        # Need to add as it is subtracted
        checkpoint["classifier.module.weight"].size(0) + config.classifier.ocr_max_num,
    )
    # Fix this later, when processor pipeline is available
    answer_processor = OmegaConf.create({"BOS_IDX": 1})
    registry.register(f"{dataset}_answer_processor", answer_processor)
def __init__(self, args=None, default_only=False, load_dataset=True):
    self.config = {}

    if not args:
        import argparse

        args = argparse.Namespace(opts=[])
        default_only = True

    self.args = args
    self._register_resolvers()

    self._default_config = self._build_default_config()

    # Initially, silently add opts so that some of the overrides for the defaults
    # from command line required for setup can be honored
    self._default_config = _merge_with_dotlist(
        self._default_config, args.opts, skip_missing=True, log_info=False
    )
    # Register the config and configuration for setup
    registry.register("config", self._default_config)
    registry.register("configuration", self)

    if default_only:
        other_configs = {}
    else:
        other_configs = self._build_other_configs(load_dataset=load_dataset)

    self.config = OmegaConf.merge(self._default_config, other_configs)

    self.config = _merge_with_dotlist(self.config, args.opts)
    self._update_specific(self.config)
    self.upgrade(self.config)

    # Resolve the config here itself after full creation so that spawned workers
    # don't face any issues
    self.config = OmegaConf.create(OmegaConf.to_container(self.config, resolve=True))

    # Update the registry with final config
    registry.register("config", self.config)
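# Minimal usage sketch, assuming only the constructor above: build a
# default-only Configuration (no CLI opts) and read the merged config back
# from the registry, where the final registry.register("config", ...) put it.
from mmf.common.registry import registry

configuration = Configuration(default_only=True)
config = registry.get("config")  # same object as configuration.config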
def setUp(self):
    torch.manual_seed(1234)
    registry.register("clevr_text_vocab_size", 80)
    registry.register("clevr_num_final_outputs", 32)
    config_path = os.path.join(
        get_mmf_root(),
        "..",
        "projects",
        "others",
        "cnn_lstm",
        "clevr",
        "defaults.yaml",
    )
    config_path = os.path.abspath(config_path)
    args = dummy_args(model="cnn_lstm", dataset="clevr")
    args.opts.append(f"config={config_path}")
    configuration = Configuration(args)
    configuration.config.datasets = "clevr"
    configuration.freeze()
    self.config = configuration.config
    registry.register("config", self.config)