Ejemplo n.º 1
0
 def test_diff_cfg_no_new_allowed(self):
     """Adding a key under a node with new_allowed=False makes get_diff_cfg raise KeyError."""
     # Base config: node A forbids new keys and holds the single key Y.
     base = CfgNode()
     base.A = CfgNode()
     base.A.set_new_allowed(False)
     base.A.Y = 2

     # Updated config: a copy of the base with an extra key A.X
     # (new_allowed was set to False on the base before cloning).
     updated = base.clone()
     updated.A.X = 2

     with self.assertRaises(KeyError):
         get_diff_cfg(base, updated)
Ejemplo n.º 2
0
 def test_diff_cfg_with_new_allowed(self):
     """With new_allowed=True on node A, get_diff_cfg reports the newly added key."""
     # Base config: node A accepts new keys and holds the single key Y.
     base = CfgNode()
     base.A = CfgNode()
     base.A.set_new_allowed(True)
     base.A.Y = 2

     # Updated config: a copy of the base with an extra key A.X.
     updated = base.clone()
     updated.A.X = 2

     # Only the added key is expected in the diff; the unchanged A.Y is not.
     expected = CfgNode()
     expected.A = CfgNode()
     expected.A.X = 2

     diff = get_diff_cfg(base, updated)
     self.assertEqual(expected, diff)
Ejemplo n.º 3
0
 def test_get_diff_cfg(self):
     """get_diff_cfg returns only the changed value when no keys were added."""
     # Base config: a single key A.Y; new_allowed is not explicitly set on it.
     base = CfgNode()
     base.A = CfgNode()
     base.A.Y = 2

     # Updated config: same keys as the base, with A.Y changed.
     updated = base.clone()
     # NOTE(review): new_allowed is set on the *updated* config only, not on
     # the base; confirm which side get_diff_cfg consults for this flag.
     updated.set_new_allowed(False)
     updated.A.Y = 3

     # The diff should contain just the modified key with its new value.
     expected = CfgNode()
     expected.A = CfgNode()
     expected.A.Y = 3

     diff = get_diff_cfg(base, updated)
     self.assertEqual(expected, diff)
Ejemplo n.º 4
0
def create_cfg_from_cli_args(args, default_cfg):
    """Build a minimal config from sections of ``default_cfg`` plus CLI ``args``.

    Instead of loading from defaults.py, this binary only includes the configs
    it needs, built from scratch, and overrides them from ``args``. There are
    two levels of config:

        _C: the config system used by this binary, a subset of the training
            config. It is overridden by ``configurable_cfg`` and, for
            convenience, can also be overridden by ``args.opts``.
        configurable_cfg: common configs that the user must specify explicitly
            in ``args`` (``datasets``, ``min_size``, ``max_size``).

    Args:
        args: Parsed CLI arguments. Reads ``output_dir``, ``datasets``,
            ``min_size``, ``max_size`` and ``opts``.
        default_cfg: Config to copy the INPUT, DATASETS, DATALOADER, TEST,
            MODEL (subset) and, if present, D2GO_DATA / TENSORBOARD sections from.

    Returns:
        A clone of the assembled config with the CLI overrides merged in.

    Raises:
        ValueError: If ``default_cfg.MODEL.LOAD_PROPOSALS`` is not ``False``.
    """

    _C = CN()
    _C.INPUT = default_cfg.INPUT
    _C.DATASETS = default_cfg.DATASETS
    _C.DATALOADER = default_cfg.DATALOADER
    _C.TEST = default_cfg.TEST
    # Optional sections are copied only when the default config defines them.
    if hasattr(default_cfg, "D2GO_DATA"):
        _C.D2GO_DATA = default_cfg.D2GO_DATA
    if hasattr(default_cfg, "TENSORBOARD"):
        _C.TENSORBOARD = default_cfg.TENSORBOARD

    # NOTE: the MODEL configs below might not be strictly necessary, but they
    # must be present for the rest of the code to work.
    _C.MODEL = CN()
    _C.MODEL.META_ARCHITECTURE = default_cfg.MODEL.META_ARCHITECTURE
    _C.MODEL.MASK_ON = default_cfg.MODEL.MASK_ON
    _C.MODEL.KEYPOINT_ON = default_cfg.MODEL.KEYPOINT_ON
    _C.MODEL.LOAD_PROPOSALS = default_cfg.MODEL.LOAD_PROPOSALS
    # Explicit check rather than `assert`, so it is not stripped under `python -O`.
    if _C.MODEL.LOAD_PROPOSALS is not False:
        raise ValueError(
            "caffe2 model doesn't support MODEL.LOAD_PROPOSALS; expected False, "
            "got {!r}".format(_C.MODEL.LOAD_PROPOSALS)
        )

    _C.OUTPUT_DIR = args.output_dir

    # Flat [key, value, key, value, ...] list, as consumed by merge_from_list.
    configurable_cfg = [
        "DATASETS.TEST",
        args.datasets,
        "INPUT.MIN_SIZE_TEST",
        args.min_size,
        "INPUT.MAX_SIZE_TEST",
        args.max_size,
    ]

    # The sections above were assigned by reference from default_cfg; the
    # overrides are merged into a clone. Presumably clone() deep-copies so
    # default_cfg is left untouched -- confirm against the CfgNode implementation.
    cfg = _C.clone()
    cfg.merge_from_list(configurable_cfg)
    # Extra overrides are optional; skip the merge when none were given
    # (a None value would otherwise fail inside merge_from_list).
    if args.opts:
        cfg.merge_from_list(args.opts)

    return cfg
Ejemplo n.º 5
0
def do_train(
    cfg: CfgNode, trainer: pl.Trainer, task: GeneralizedRCNNTask
) -> Dict[str, str]:
    """Fit ``task`` with ``trainer`` and dump a config pointing at the final weights.

    Args:
        cfg: The normalized ConfigNode for this D2Go Task.
        trainer: PyTorch Lightning trainer used to fit ``task``.
        task: Lightning module instance to train. Its ``storage`` attribute is
            set to the active EventStorage before fitting.

    Returns:
        A map of model name to trained model config path, as produced by
        ``dump_trained_model_configs`` for the single entry ``"model_final"``.
    """
    with EventStorage() as event_storage:
        # Expose the event storage to the Lightning module before fitting.
        task.storage = event_storage
        trainer.fit(task)

        output_dir = cfg.OUTPUT_DIR
        # Save the final checkpoint explicitly (for validation monitor).
        ckpt_path = os.path.join(output_dir, FINAL_MODEL_CKPT)
        trainer.save_checkpoint(ckpt_path)

        # Clone the config and point MODEL.WEIGHTS at the checkpoint just saved;
        # the clone is temporarily defrosted so the field can be written.
        final_cfg = cfg.clone()
        with temp_defrost(final_cfg):
            final_cfg.MODEL.WEIGHTS = ckpt_path

        configs_by_name = {"model_final": final_cfg}
        dumped = dump_trained_model_configs(output_dir, configs_by_name)
    return dumped