コード例 #1
0
 def setUpClass(cls) -> None:
     """Build a frozen test configuration and a logger once for the class.

     A temporary directory is created and passed as ``env.save_dir``. The
     configuration is built from command-line style overrides selecting the
     ``cnn_lstm`` model and the ``clevr`` dataset, frozen, and registered
     under ``"config"``. The caches of ``setup_output_folder`` and
     ``setup_logger`` are then cleared before ``setup_logger()`` is called,
     and its return value is stored on ``cls.writer``.

     NOTE(review): the temporary directory is not removed here; confirm a
     matching ``tearDownClass`` deletes ``cls._tmpdir``.
     """
     cls._tmpdir = tempfile.mkdtemp()

     args = argparse.Namespace()
     # Only the save_dir override interpolates a value; the others are plain
     # string literals (no placeholder, so no f-prefix needed).
     args.opts = [
         f"env.save_dir={cls._tmpdir}",
         "model=cnn_lstm",
         "dataset=clevr",
     ]
     args.config_override = None

     configuration = Configuration(args)
     configuration.freeze()
     cls.config = configuration.get_config()
     registry.register("config", cls.config)

     # Both helpers expose cache_clear(), i.e. they are memoized; clearing
     # them presumably forces a rebuild against the config registered above
     # instead of reusing state from an earlier test class.
     setup_output_folder.cache_clear()
     setup_logger.cache_clear()
     cls.writer = setup_logger()
コード例 #2
0
ファイル: test_logistics.py プロジェクト: xinyuliu828/mmf
    def setUp(self):
        """Create a minimal trainer stand-in and the LogisticsCallback to test.

        The trainer is an ``argparse.Namespace`` carrying a deep copy of the
        test config, a ``SimpleModule`` model, a validation ``DataLoader`` over
        ``NumbersDataset``, an Adam optimizer, zeroed progress fields and a
        ``Meter``. ``self.config`` keeps the untouched original, while the
        trainer's copy is the one registered under ``"config"``. A ``Report``
        mock with dataset name ``"abcd"`` and type ``"test"`` is exposed as
        ``self.report``.
        """
        self.tmpdir = tempfile.mkdtemp()
        self.trainer = trainer = argparse.Namespace()

        training_section = {
            "checkpoint_interval": 1,
            "evaluation_interval": 10,
            "early_stop": {"criteria": "val/total_loss"},
            "batch_size": 16,
            "log_interval": 10,
            "logger_level": "info",
        }
        self.config = OmegaConf.create(
            {
                "model": "simple",
                "model_config": {},
                "training": training_section,
                "env": {"save_dir": self.tmpdir},
            }
        )

        # The trainer works on its own deep copy so self.config stays the
        # original for comparisons in the tests.
        trainer.config = deepcopy(self.config)
        registry.register("config", trainer.config)

        # setup_logger is memoized (it has cache_clear); reset it so the call
        # below runs again for this test instead of returning a cached result.
        setup_logger.cache_clear()
        setup_logger()

        report = Mock(spec=Report)
        report.dataset_name = "abcd"
        report.dataset_type = "test"
        self.report = report

        model = SimpleModule()
        trainer.model = model
        trainer.val_loader = torch.utils.data.DataLoader(
            NumbersDataset(),
            batch_size=self.config.training.batch_size,
        )
        trainer.optimizer = torch.optim.Adam(model.parameters(), lr=1e-01)

        trainer.device = "cpu"
        # Every step/epoch bookkeeping field starts at zero.
        for field_name in (
            "num_updates",
            "current_iteration",
            "current_epoch",
            "max_updates",
        ):
            setattr(trainer, field_name, 0)
        trainer.meter = Meter()

        self.cb = LogisticsCallback(self.config, trainer)