예제 #1
0
def load_config(argv=None):
    """Build the frozen config from command-line arguments.

    The default config from ``get_default_config()`` is overlaid first with
    the YAML file given by the required ``--config`` flag, then with any
    trailing override arguments. The result is post-processed by
    ``update_config`` and frozen.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``. ``None``
            (the default) keeps the usual command-line behavior; passing a
            list lets the function be driven from code or tests.

    Returns:
        The merged, frozen config node.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, required=True)
    # Everything after the named flags is collected verbatim and handed to
    # yacs ``merge_from_list``, which expects alternating key/value entries.
    parser.add_argument('options', default=None, nargs=argparse.REMAINDER)
    args = parser.parse_args(argv)

    config = get_default_config()
    config.merge_from_file(args.config)
    config.merge_from_list(args.options)
    # NOTE(review): the return value of update_config is discarded, so this
    # relies on it modifying ``config`` in place; confirm.
    update_config(config)
    config.freeze()
    return config
예제 #2
0
def find_config_diff(
        config: yacs.config.CfgNode) -> Optional[yacs.config.CfgNode]:
    """Return the part of *config* that differs from the default config.

    *config* is compared key by key against ``get_default_config()``:

    - Leaf values that differ (by ``!=``) are placed into the result.
    - Nested nodes are compared recursively and kept only if they contain at
      least one difference.
    - Keys with no counterpart in the default config are reported as
      differences instead of raising ``KeyError``. yacs allows such keys on
      nodes created with ``new_allowed=True``.

    Args:
        config: The config to compare against the defaults.

    Returns:
        A new ``ConfigNode`` holding only the differing entries, nested like
        *config*, or ``None`` when nothing differs.
    """
    def _find_diff(
            node: yacs.config.CfgNode,
            default_node: yacs.config.CfgNode
    ) -> Optional[yacs.config.CfgNode]:
        root_node = ConfigNode()
        for key, val in node.items():
            in_default = key in default_node
            default_val = default_node[key] if in_default else None
            if isinstance(val, yacs.config.CfgNode):
                # A subtree without a matching default subtree is diffed
                # against an empty node, so all of its leaves are reported.
                if not isinstance(default_val, yacs.config.CfgNode):
                    default_val = ConfigNode()
                new_node = _find_diff(val, default_val)
                if new_node is not None:
                    root_node[key] = new_node
            elif not in_default or val != default_val:
                root_node[key] = val
        return root_node if len(root_node) > 0 else None

    return _find_diff(config, get_default_config())
예제 #3
0
def load_config(argv=None):
    """Build the frozen training config from command-line arguments.

    Merge order (later steps win):

    1. defaults from ``get_default_config()``;
    2. the YAML file given by ``--config``, if any;
    3. trailing override arguments (yacs key/value pairs);
    4. when ``--resume DIR`` is given, ``DIR/config.yaml`` and then
       ``train.resume = True``;
    5. ``train.dist.local_rank`` taken from ``--local_rank``;
    6. when CUDA is unavailable, ``device = 'cpu'`` and
       ``train.dataloader.pin_memory = False``.

    The result is passed through ``update_config`` and frozen.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``. ``None``
            (the default) keeps the usual command-line behavior; passing a
            list lets the function be driven from code or tests.

    Returns:
        The frozen config node returned by ``update_config``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str)
    parser.add_argument('--resume', type=str, default='')
    parser.add_argument('--local_rank', type=int, default=0)
    parser.add_argument('options', default=None, nargs=argparse.REMAINDER)
    args = parser.parse_args(argv)

    config = get_default_config()
    if args.config is not None:
        config.merge_from_file(args.config)
    config.merge_from_list(args.options)
    if args.resume:
        # NOTE(review): the saved config.yaml is merged after the
        # command-line overrides, so it takes precedence over them; confirm
        # this is intended.
        config_path = pathlib.Path(args.resume) / 'config.yaml'
        config.merge_from_file(config_path.as_posix())
        config.merge_from_list(['train.resume', True])
    config.merge_from_list(['train.dist.local_rank', args.local_rank])
    # The CPU fallback runs after every file merge. Otherwise a device value
    # loaded from a config file (in particular the one restored by --resume)
    # could switch back to CUDA on a host that has none.
    if not torch.cuda.is_available():
        config.device = 'cpu'
        config.train.dataloader.pin_memory = False
    config = update_config(config)
    config.freeze()
    return config