コード例 #1
0
 def setUpClass(cls) -> None:
     """Build the frozen cnn_lstm/clevr configuration shared by the class.

     Creates a temporary directory and passes it as ``env.save_dir``,
     stores the frozen config on ``cls.config``, registers it in the
     registry under ``"config"``, and stores the logger returned by
     ``setup_logger()`` on ``cls.writer``.
     """
     # NOTE(review): the temporary directory is not removed in this method;
     # confirm a matching tearDownClass cleans it up.
     cls._tmpdir = tempfile.mkdtemp()
     args = argparse.Namespace()
     # Only the save_dir entry interpolates a value; the other two entries
     # are constants and therefore plain string literals.
     args.opts = [
         f"env.save_dir={cls._tmpdir}",
         "model=cnn_lstm",
         "dataset=clevr",
     ]
     args.config_override = None
     configuration = Configuration(args)
     configuration.freeze()
     cls.config = configuration.get_config()
     registry.register("config", cls.config)
     # Drop cached results before creating the logger; presumably so the
     # helpers see the config registered above -- TODO confirm.
     setup_output_folder.cache_clear()
     setup_logger.cache_clear()
     cls.writer = setup_logger()
コード例 #2
0
 def setUp(self):
     """Prepare a frozen butd/coco configuration from ``beam_search.yaml``.

     Seeds torch with 1234, loads the project config located relative to
     the MMF root, sets ``datasets`` to ``"coco"`` and registers the frozen
     config in the registry under ``"config"``.
     """
     setup_imports()
     torch.manual_seed(1234)
     beam_search_yaml = os.path.abspath(
         os.path.join(
             get_mmf_root(),
             "..",
             "projects",
             "butd",
             "configs",
             "coco",
             "beam_search.yaml",
         )
     )
     cli_args = dummy_args(model="butd", dataset="coco")
     cli_args.opts.append(f"config={beam_search_yaml}")
     butd_configuration = Configuration(cli_args)
     # Adjust the config while it is still unfrozen.
     butd_configuration.config.datasets = "coco"
     butd_configuration.freeze()
     self.config = butd_configuration.config
     registry.register("config", self.config)
コード例 #3
0
    def test_dataset_configs_for_keys(self):
        """Check that each registered dataset's config contains its own key.

        Iterates over ``registry.mapping["builder_name_mapping"]``. Builders
        whose ``config_path()`` is ``None`` are skipped with a warning. For
        every other builder a configuration is built and frozen, and the
        builder key must be present in ``config.dataset_config``.
        """
        builder_name = registry.mapping["builder_name_mapping"]

        for builder_key, builder_cls in builder_name.items():
            if builder_cls.config_path() is None:
                warnings.warn(
                    f"Dataset {builder_key} has no default configuration defined. "
                    "Skipping it. Make sure it is intentional"
                )
                continue

            # One sub-test per builder: a failure names the dataset and, on
            # runners that support sub-tests, the remaining builders are
            # still checked.
            with self.subTest(dataset=builder_key):
                # Silence anything printed while the configuration is built.
                with contextlib.redirect_stdout(StringIO()):
                    args = dummy_args(dataset=builder_key)
                    configuration = Configuration(args)
                    configuration.freeze()
                    config = configuration.get_config()
                self.assertIn(
                    builder_key,
                    config.dataset_config,
                    f"Key for dataset {builder_key} doesn't exists "
                    "in its configuration",
                )
コード例 #4
0
    def test_model_configs_for_keys(self):
        """Check that each registered model's config contains its own key.

        Iterates over ``registry.mapping["model_name_mapping"]``. Models whose
        ``config_path()`` is ``None`` are skipped with a warning. For every
        other model a configuration is built and frozen, and the model key
        must be present in ``config.model_config``.
        """
        models_mapping = registry.mapping["model_name_mapping"]

        for model_key, model_cls in models_mapping.items():
            if model_cls.config_path() is None:
                warnings.warn(
                    f"Model {model_key} has no default configuration defined. "
                    "Skipping it. Make sure it is intentional"
                )
                continue

            # One sub-test per model: a failure names the model and, on
            # runners that support sub-tests, the remaining models are still
            # checked.
            with self.subTest(model=model_key):
                # Silence anything printed while the configuration is built.
                with contextlib.redirect_stdout(StringIO()):
                    args = dummy_args(model=model_key)
                    configuration = Configuration(args)
                    configuration.freeze()
                    config = configuration.get_config()
                self.assertIn(
                    model_key,
                    config.model_config,
                    f"Key for model {model_key} doesn't exists "
                    "in its configuration",
                )
コード例 #5
0
 def setUp(self):
     """Prepare a frozen cnn_lstm/clevr configuration for the tests.

     Seeds torch with 1234, registers the CLEVR vocabulary size (80) and
     number of final outputs (32), loads the project ``defaults.yaml``
     located relative to the MMF root, sets ``datasets`` to ``"clevr"``
     and registers the frozen config in the registry under ``"config"``.
     """
     torch.manual_seed(1234)
     registry.register("clevr_text_vocab_size", 80)
     registry.register("clevr_num_final_outputs", 32)
     relative_parts = ("..", "projects", "others", "cnn_lstm", "clevr", "defaults.yaml")
     defaults_yaml = os.path.abspath(os.path.join(get_mmf_root(), *relative_parts))
     cli_args = dummy_args(model="cnn_lstm", dataset="clevr")
     cli_args.opts.append(f"config={defaults_yaml}")
     clevr_configuration = Configuration(cli_args)
     # Adjust the config while it is still unfrozen.
     clevr_configuration.config.datasets = "clevr"
     clevr_configuration.freeze()
     self.config = clevr_configuration.config
     registry.register("config", self.config)
コード例 #6
0
 def setUp(self):
     """Prepare a frozen butd/coco configuration from ``nucleus_sampling.yaml``.

     Seeds torch with 1234, loads the project config located relative to
     the MMF root, sets ``datasets`` to ``"coco"`` and the butd inference
     ``sum_threshold`` to 0.5, then registers the frozen config in the
     registry under ``"config"``.
     """
     setup_imports()
     torch.manual_seed(1234)
     relative_parts = (
         "..",
         "projects",
         "butd",
         "configs",
         "coco",
         "nucleus_sampling.yaml",
     )
     sampling_yaml = os.path.abspath(os.path.join(get_mmf_root(), *relative_parts))
     cli_args = dummy_args(model="butd", dataset="coco")
     cli_args.opts.append(f"config={sampling_yaml}")
     butd_configuration = Configuration(cli_args)
     # Both overrides are applied while the config is still unfrozen.
     editable = butd_configuration.config
     editable.datasets = "coco"
     editable.model_config.butd.inference.params.sum_threshold = 0.5
     butd_configuration.freeze()
     self.config = butd_configuration.config
     registry.register("config", self.config)
コード例 #7
0
    def test_init_processors(self):
        """Check that ``init_processors`` attaches and registers processors.

        Builds a vqa2 configuration from the dataset's ``defaults.yaml``,
        points the answer processor at the test vocabulary file, and creates
        a ``BaseDataset``. Before ``init_processors()`` none of the expected
        processors may exist as dataset attributes or registry entries;
        afterwards all of them must.
        """
        # Resolve both paths from this file's directory and normalize them,
        # so no path handed on still walks "through" the file name
        # (``<file>.py/../..``).
        test_dir = os.path.dirname(os.path.abspath(__file__))
        path = os.path.abspath(
            os.path.join(test_dir, "../../mmf/configs/datasets/vqa2/defaults.yaml")
        )
        args = dummy_args()
        args.opts.append(f"config={path}")
        configuration = Configuration(args)
        answer_processor = (
            configuration.get_config().dataset_config.vqa2.processors.answer_processor
        )
        answer_processor.params.vocab_file = os.path.abspath(
            os.path.join(test_dir, "..", "data", "vocab.txt")
        )
        self._fix_configuration(configuration)
        configuration.freeze()

        base_dataset = BaseDataset(
            "vqa2", configuration.get_config().dataset_config.vqa2, "train"
        )
        expected_processors = [
            "answer_processor",
            "ocr_token_processor",
            "bbox_processor",
        ]

        # Check no processors are initialized before init_processors call
        self.assertFalse(any(hasattr(base_dataset, key) for key in expected_processors))

        for processor in expected_processors:
            self.assertIsNone(registry.get(f"vqa2_{processor}"))

        # Check processors are initialized after init_processors
        base_dataset.init_processors()
        self.assertTrue(all(hasattr(base_dataset, key) for key in expected_processors))
        for processor in expected_processors:
            self.assertIsNotNone(registry.get(f"vqa2_{processor}"))
コード例 #8
0
 def test_config_overrides(self):
     """Check that indexed command-line overrides reach the final config.

     Loads the m4c/textvqa project ``defaults.yaml`` and overrides one
     element of ``training.lr_steps`` and one element of the textvqa
     ``zoo_requirements`` list, then asserts both values on the frozen
     config.
     """
     relative_parts = ("..", "projects", "m4c", "configs", "textvqa", "defaults.yaml")
     defaults_yaml = os.path.abspath(os.path.join(get_mmf_root(), *relative_parts))
     cli_args = dummy_args(model="m4c", dataset="textvqa")
     overrides = [
         f"config={defaults_yaml}",
         "training.lr_steps[1]=10000",
         'dataset_config.textvqa.zoo_requirements[0]="test"',
     ]
     cli_args.opts += overrides
     m4c_configuration = Configuration(cli_args)
     m4c_configuration.freeze()
     frozen = m4c_configuration.get_config()
     self.assertEqual(frozen.training.lr_steps[1], 10000)
     self.assertEqual(frozen.dataset_config.textvqa.zoo_requirements[0], "test")