@classmethod
def setUpClass(cls) -> None:
    # Build a real Configuration from CLI-style override opts and register
    # it globally so components that read the registry see the same config.
    cls._tmpdir = tempfile.mkdtemp()
    args = argparse.Namespace()
    args.opts = [
        f"env.save_dir={cls._tmpdir}",
        "model=cnn_lstm",
        "dataset=clevr",
    ]
    args.config_override = None
    configuration = Configuration(args)
    configuration.freeze()
    cls.config = configuration.get_config()
    registry.register("config", cls.config)
    # Clear the memoized results so this test class gets a fresh output
    # folder and logger instead of ones cached by an earlier test run.
    setup_output_folder.cache_clear()
    setup_logger.cache_clear()
    cls.writer = setup_logger()
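# Note: mkdtemp() never deletes the directory it creates, so the fixture
# above leaks one temp dir per test class. A minimal matching cleanup, as a
# sketch (it assumes `shutil` is imported at module level alongside
# `tempfile`; the original suite may clean up elsewhere):
@classmethod
def tearDownClass(cls) -> None:
    # Remove the temporary save_dir created in setUpClass.
    shutil.rmtree(cls._tmpdir, ignore_errors=True)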
def setUp(self):
    self.tmpdir = tempfile.mkdtemp()
    # A bare Namespace stands in for the trainer; only the attributes that
    # LogisticsCallback touches are populated below.
    self.trainer = argparse.Namespace()
    self.config = OmegaConf.create(
        {
            "model": "simple",
            "model_config": {},
            "training": {
                "checkpoint_interval": 1,
                "evaluation_interval": 10,
                "early_stop": {"criteria": "val/total_loss"},
                "batch_size": 16,
                "log_interval": 10,
                "logger_level": "info",
            },
            "env": {"save_dir": self.tmpdir},
        }
    )
    # Keep an original copy for testing purposes
    self.trainer.config = deepcopy(self.config)
    registry.register("config", self.trainer.config)
    setup_logger.cache_clear()
    setup_logger()
    self.report = Mock(spec=Report)
    self.report.dataset_name = "abcd"
    self.report.dataset_type = "test"
    self.trainer.model = SimpleModule()
    self.trainer.val_loader = torch.utils.data.DataLoader(
        NumbersDataset(), batch_size=self.config.training.batch_size
    )
    self.trainer.optimizer = torch.optim.Adam(
        self.trainer.model.parameters(), lr=1e-1
    )
    self.trainer.device = "cpu"
    self.trainer.num_updates = 0
    self.trainer.current_iteration = 0
    self.trainer.current_epoch = 0
    self.trainer.max_updates = 0
    self.trainer.meter = Meter()
    self.cb = LogisticsCallback(self.config, self.trainer)
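# Likewise, a hedged sketch of a tearDown that balances the setUp above by
# removing its temp save_dir (again assuming a module-level `import shutil`):
def tearDown(self) -> None:
    # ignore_errors guards against the directory already being gone.
    shutil.rmtree(self.tmpdir, ignore_errors=True)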