示例#1
0
文件: config.py 项目: zcl912/TraDeS
def config_factory(configuration_name: str) -> Union[DetectionConfig, TrackingConfig]:
    """
    Creates a *Config instance that can be used to initialize a *Eval instance, where * stands for Detection/Tracking.
    The part of the name before the first underscore selects the task ("detection" or "tracking").
    Note that this only works if the config file is located at ../<task>/configs/<configuration_name>.json,
    relative to the directory that contains this file.
    :param configuration_name: Name of the desired configuration, e.g. "detection_..." or "tracking_...".
    :return: *Config instance.
    :raises AssertionError: If the name contains no underscore, or if no matching config file exists.
    :raises ValueError: If a config file exists but the task prefix is neither "detection" nor "tracking".
    """
    # The validation below raises AssertionError explicitly instead of using `assert` statements:
    # `assert` is removed when Python runs with -O, which would silently skip these checks.
    # The exception type is kept so that existing callers continue to work.
    task, separator, _ = configuration_name.partition('_')
    if not separator:
        raise AssertionError('Error: Configuration name must have prefix "detection_" or "tracking_"!')

    # Check if config exists.
    this_dir = os.path.dirname(os.path.abspath(__file__))
    cfg_path = os.path.join(this_dir, '..', task, 'configs', '%s.json' % configuration_name)
    if not os.path.exists(cfg_path):
        raise AssertionError('Requested unknown configuration {}'.format(configuration_name))

    # Load config file and deserialize it.
    # JSON is read as UTF-8 explicitly so the result does not depend on the platform's locale encoding.
    with open(cfg_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    if task == 'detection':
        return DetectionConfig.deserialize(data)
    if task == 'tracking':
        return TrackingConfig.deserialize(data)

    # ValueError is a subclass of Exception, so callers catching Exception are unaffected.
    raise ValueError('Error: Invalid config file name: %s' % configuration_name)
示例#2
0
                        default=1,
                        help='Whether to render statistic curves to disk.')
    parser.add_argument('--verbose', type=int, default=1, help='Whether to print to stdout.')
    args = parser.parse_args()

    # Unpack the parsed command-line options.
    # Only the two path options below get a leading "~" expanded.
    result_path_ = os.path.expanduser(args.result_path)
    output_dir_ = os.path.expanduser(args.output_dir)

    # These options are passed through unchanged.
    eval_set_ = args.eval_set
    dataroot_ = args.dataroot
    version_ = args.version
    config_path = args.config_path

    # The integer flags are converted to booleans (0 -> False, anything else -> True).
    render_curves_ = bool(args.render_curves)
    verbose_ = bool(args.verbose)

    # A non-empty --config_path is read as a JSON tracking config;
    # an empty one falls back to the named 'tracking_nips_2019' configuration.
    if config_path != '':
        with open(config_path, 'r') as _f:
            cfg_ = TrackingConfig.deserialize(json.load(_f))
    else:
        cfg_ = config_factory('tracking_nips_2019')

    # Build the evaluator from the collected settings and run it.
    nusc_eval = TrackingEval(
        config=cfg_,
        result_path=result_path_,
        eval_set=eval_set_,
        output_dir=output_dir_,
        nusc_version=version_,
        nusc_dataroot=dataroot_,
        verbose=verbose_,
    )
    nusc_eval.main(render_curves=render_curves_)