def test_diff_cfg_no_new_allowed(self):
    """get_diff_cfg must raise KeyError for new keys when new_allowed is False."""
    # Base config: sub-node A explicitly disallows new keys.
    base_cfg = CfgNode()
    base_cfg.A = CfgNode()
    base_cfg.A.set_new_allowed(False)
    base_cfg.A.Y = 2

    # Case 2: new_allowed is False and the new config carries an extra key.
    changed_cfg = base_cfg.clone()
    changed_cfg.A.X = 2

    with self.assertRaises(KeyError):
        get_diff_cfg(base_cfg, changed_cfg)
def test_diff_cfg_with_new_allowed(self):
    """With new_allowed set to True, the diff contains only the added key."""
    # Base config: sub-node A accepts new keys.
    base_cfg = CfgNode()
    base_cfg.A = CfgNode()
    base_cfg.A.set_new_allowed(True)
    base_cfg.A.Y = 2

    # Case 3: new_allowed is True and the new config carries an extra key.
    changed_cfg = base_cfg.clone()
    changed_cfg.A.X = 2

    # Expected diff: just the key that is absent from the base config.
    expected = CfgNode()
    expected.A = CfgNode()
    expected.A.X = 2

    actual = get_diff_cfg(base_cfg, changed_cfg)
    self.assertEqual(expected, actual)
def test_get_diff_cfg(self):
    """The diff of two configs with identical keys holds only the changed value."""
    # Base (default) config.
    base_cfg = CfgNode()
    base_cfg.A = CfgNode()
    base_cfg.A.Y = 2

    # Case 1: no new keys; only the value of an existing key is overridden.
    changed_cfg = base_cfg.clone()
    changed_cfg.set_new_allowed(False)
    changed_cfg.A.Y = 3

    # Expected diff: the overridden key with its new value.
    expected = CfgNode()
    expected.A = CfgNode()
    expected.A.Y = 3

    actual = get_diff_cfg(base_cfg, changed_cfg)
    self.assertEqual(expected, actual)
def create_cfg_from_cli_args(args, default_cfg):
    """
    Build a minimal config from scratch and override it from CLI arguments.

    Rather than loading everything from defaults.py, only the sections this
    binary needs are taken from ``default_cfg``. Overrides are applied in two
    layers, in this order:

    1. The explicitly configurable values that the user is expected to pass
       on the command line (test datasets, min/max test input size).
    2. ``args.opts``, a free-form override list provided for convenience.

    Args:
        args: parsed CLI arguments; ``output_dir``, ``datasets``, ``min_size``,
            ``max_size`` and ``opts`` are read.
        default_cfg: config node supplying the default sections.

    Returns:
        A clone of the assembled config with both override layers merged in.
    """
    base_cfg = CN()

    # Sections that are always copied over from the defaults.
    for section in ("INPUT", "DATASETS", "DATALOADER", "TEST"):
        setattr(base_cfg, section, getattr(default_cfg, section))

    # Sections that are copied only when the defaults define them.
    for section in ("D2GO_DATA", "TENSORBOARD"):
        if hasattr(default_cfg, section):
            setattr(base_cfg, section, getattr(default_cfg, section))

    # NOTE configs below might not be necessary, but must add to make code work
    base_cfg.MODEL = CN()
    for field in ("META_ARCHITECTURE", "MASK_ON", "KEYPOINT_ON", "LOAD_PROPOSALS"):
        setattr(base_cfg.MODEL, field, getattr(default_cfg.MODEL, field))
    assert base_cfg.MODEL.LOAD_PROPOSALS is False, "caffe2 model doesn't support"

    base_cfg.OUTPUT_DIR = args.output_dir

    # Flat [key, value, key, value, ...] list, as consumed by merge_from_list.
    explicit_overrides = [
        "DATASETS.TEST",
        args.datasets,
        "INPUT.MIN_SIZE_TEST",
        args.min_size,
        "INPUT.MAX_SIZE_TEST",
        args.max_size,
    ]

    # Merge into a clone so the assembled base config itself is left as built.
    merged_cfg = base_cfg.clone()
    merged_cfg.merge_from_list(explicit_overrides)
    merged_cfg.merge_from_list(args.opts)
    return merged_cfg
def do_train(cfg: CfgNode, trainer: pl.Trainer, task: GeneralizedRCNNTask) -> Dict[str, str]:
    """Fit the task, save the final checkpoint and dump the trained config.

    Args:
        cfg: The normalized ConfigNode for this D2Go Task.
        trainer: PyTorch Lightning trainer.
        task: Lightning module instance.

    Returns:
        A map of model name to trained model config path.
    """
    output_dir = cfg.OUTPUT_DIR
    with EventStorage() as event_storage:
        # The task gets the active storage before fitting starts.
        task.storage = event_storage
        trainer.fit(task)

        checkpoint_path = os.path.join(output_dir, FINAL_MODEL_CKPT)
        trainer.save_checkpoint(checkpoint_path)  # for validation monitor

        # Work on a clone so the caller's cfg is not modified; the clone is
        # edited inside temp_defrost and records the final checkpoint path.
        final_cfg = cfg.clone()
        with temp_defrost(final_cfg):
            final_cfg.MODEL.WEIGHTS = checkpoint_path

        config_paths = dump_trained_model_configs(
            output_dir, {"model_final": final_cfg}
        )
    return config_paths