示例#1
0
def main():
    """Run single-GPU inference on the validation set and print KITTI results.

    Reads the config named by ``args.config``, builds the validation dataset
    and the detector, loads parameters from ``args.checkpoint``, runs
    ``single_test`` over the data loader, and prints whatever
    ``get_official_eval_result`` returns for the predictions.

    Raises:
        NotImplementedError: If ``args.gpus`` is not 1. Only the single-GPU
            path is implemented.
    """
    args = parse_args()

    cfg = mmcv.Config.fromfile(args.config)
    # Presumably prevents build_detector from loading pretrained weights;
    # parameters are loaded from the checkpoint below instead.
    cfg.model.pretrained = None

    dataset = utils.get_dataset(cfg.data.val)
    class_names = cfg.data.val.class_names

    # The original code had a bare `NotImplementedError` expression here,
    # which did nothing and let execution fall through to a NameError on
    # `outputs`. Fail explicitly instead.
    if args.gpus != 1:
        raise NotImplementedError(
            'only single-GPU testing (gpus == 1) is implemented, '
            'got gpus={}'.format(args.gpus))

    model = build_detector(
        cfg.model, train_cfg=None, test_cfg=cfg.test_cfg)
    model = MMDataParallel(model, device_ids=[0])
    # Parameters are loaded after wrapping the model in MMDataParallel.
    load_params_from_file(model, args.checkpoint)
    data_loader = build_dataloader(
        dataset,
        1,  # presumably samples per GPU -- confirm against build_dataloader
        cfg.data.workers_per_gpu,
        num_gpus=1,
        shuffle=False,
        dist=False)
    outputs = single_test(model, data_loader, args.out, class_names)

    # kitti evaluation
    gt_annos = kitti.get_label_annos(dataset.label_prefix, dataset.sample_ids)
    result = get_official_eval_result(gt_annos, outputs, current_classes=class_names)
    print(result)
示例#2
0
    def evaluate(self, runner, results):
        """Evaluate ``results`` against the dataset labels and log the outcome.

        Ground-truth annotations are read through ``kitti.get_label_annos``
        using the dataset's ``label_prefix`` and ``sample_ids``. The value
        returned by ``get_official_eval_result`` is written to the runner's
        logger, and the runner's log buffer is then marked ready.

        Args:
            runner: Object exposing ``logger`` and ``log_buffer``.
            results: Predictions handed to ``get_official_eval_result``.
        """
        dataset = self.dataset
        annos = kitti.get_label_annos(dataset.label_prefix, dataset.sample_ids)

        # NOTE(review): current_classes is fixed to 0 here -- presumably the
        # first class only; confirm against get_official_eval_result.
        summary = get_official_eval_result(annos, results, current_classes=0)
        runner.logger.info(summary)

        runner.log_buffer.ready = True
示例#3
0
文件: test.py 项目: zxduan90/SA-SSD
def main():
    """Run single-GPU inference, print KITTI results, optionally dump/evaluate.

    Reads the config named by ``args.config``, builds the validation dataset
    and the detector, loads ``args.checkpoint``, runs ``single_test`` and
    prints the result of ``get_official_eval_result``. When ``args.out`` is
    set, the raw outputs are dumped to that path; when ``args.eval`` is also
    set, the outputs are additionally passed through ``coco_eval``.

    Raises:
        NotImplementedError: If ``args.gpus`` is not 1. Only the single-GPU
            path is implemented.
    """
    args = parse_args()

    cfg = mmcv.Config.fromfile(args.config)
    # Presumably prevents build_detector from loading pretrained weights;
    # weights are loaded from the checkpoint below instead.
    cfg.model.pretrained = None

    dataset = utils.get_dataset(cfg.data.val)
    class_names = cfg.data.val.class_names

    # The original code had a bare `NotImplementedError` expression here,
    # which did nothing and let execution fall through to a NameError on
    # `outputs`. Fail explicitly instead.
    if args.gpus != 1:
        raise NotImplementedError(
            'only single-GPU testing (gpus == 1) is implemented, '
            'got gpus={}'.format(args.gpus))

    model = build_detector(cfg.model,
                           train_cfg=None,
                           test_cfg=cfg.test_cfg)
    load_checkpoint(model, args.checkpoint)
    model = MMDataParallel(model, device_ids=[0])

    data_loader = build_dataloader(
        dataset,
        1,  # presumably samples per GPU -- confirm against build_dataloader
        cfg.data.workers_per_gpu,
        num_gpus=1,
        shuffle=False,
        dist=False)
    outputs = single_test(model, data_loader, args.out, class_names)

    # kitti evaluation
    gt_annos = kitti.get_label_annos(dataset.label_prefix, dataset.sample_ids)
    result = get_official_eval_result(gt_annos,
                                      outputs,
                                      current_classes=class_names)
    print(result)

    # Everything below only applies when an output path was requested.
    if not args.out:
        return
    print('writing results to {}'.format(args.out))
    mmcv.dump(outputs, args.out)

    eval_types = args.eval
    if not eval_types:
        return

    # NOTE(review): this branch relies on `dataset.coco`; confirm that the
    # dataset returned by utils.get_dataset actually provides it.
    print('Starting evaluate {}'.format(' and '.join(eval_types)))
    if eval_types == ['proposal_fast']:
        # The dumped output file is evaluated directly.
        coco_eval(args.out, eval_types, dataset.coco)
    elif not isinstance(outputs[0], dict):
        # One flat list of results: convert once and evaluate.
        result_file = args.out + '.json'
        results2json(dataset, outputs, result_file)
        coco_eval(result_file, eval_types, dataset.coco)
    else:
        # Each output is a dict: evaluate every named entry separately.
        for name in outputs[0]:
            print('\nEvaluating {}'.format(name))
            outputs_ = [out[name] for out in outputs]
            result_file = args.out + '.{}.json'.format(name)
            results2json(dataset, outputs_, result_file)
            coco_eval(result_file, eval_types, dataset.coco)