def main():  # pylint: disable=import-outside-toplevel,too-many-branches,too-many-statements
    """Evaluate detection checkpoint(s) on the configured test set with COCOeval.

    Parses command-line arguments via ``make_parser``, loads the network
    description module from ``args.file`` and runs ``worker`` over the test
    set for each selected checkpoint. With ``args.devices > 1`` one child
    process per device is spawned and results are streamed back through a
    queue; otherwise ``worker`` runs in-process and appends to a list.

    The formatted detections are written to a JSON file and scored with
    ``pycocotools`` (``iouType="bbox"``); the 12 summary numbers are logged.

    Checkpoint selection:
      * ``args.weight_file`` set: evaluate exactly that file once; results go
        to ``<model>_<weightname>.json`` in the working directory.
      * otherwise: evaluate ``log-of-<model>/epoch_<k>.pkl`` for every ``k``
        in ``[args.start_epoch, args.end_epoch]`` (``-1`` defaults to the last
        epoch); results go to ``log-of-<model>/epoch_<k>.json``.
    """
    from pycocotools.coco import COCO
    from pycocotools.cocoeval import COCOeval

    parser = make_parser()
    args = parser.parse_args()

    current_network = import_from_file(args.file)
    cfg = current_network.Cfg()
    # File name of the network description without extension; used as the
    # prefix of checkpoint / result paths.
    model_name = os.path.basename(args.file).split(".")[0]

    if args.weight_file:
        # A single explicit checkpoint: make the epoch loop run exactly once.
        args.start_epoch = args.end_epoch = -1
    else:
        if args.start_epoch == -1:
            args.start_epoch = cfg.max_epoch - 1
        if args.end_epoch == -1:
            args.end_epoch = args.start_epoch
        # Validate with parser.error rather than assert (asserts vanish under -O).
        if not 0 <= args.start_epoch <= args.end_epoch < cfg.max_epoch:
            parser.error(
                "epochs must satisfy 0 <= start_epoch <= end_epoch < max_epoch "
                "(got start_epoch={}, end_epoch={}, max_epoch={})".format(
                    args.start_epoch, args.end_epoch, cfg.max_epoch
                )
            )

    # Order of COCOeval.stats after summarize() for iouType="bbox".
    metrics = [
        "AP",
        "AP@0.5",
        "AP@0.75",
        "APs",
        "APm",
        "APl",
        "AR@1",
        "AR@10",
        "AR@100",
        "ARs",
        "ARm",
        "ARl",
    ]

    for epoch_num in range(args.start_epoch, args.end_epoch + 1):
        if args.weight_file:
            weight_file = args.weight_file
        else:
            weight_file = "log-of-{}/epoch_{}.pkl".format(model_name, epoch_num)

        result_list = []
        if args.devices > 1:
            # Number of results to pull from the queue, per dataset name.
            # Looked up before spawning so an unsupported dataset fails with
            # KeyError before any child process is started.
            # NOTE(review): presumably one result per test image — confirm.
            num_imgs = dict(coco=5000, objects365=30000)
            num_results = num_imgs[cfg.test_dataset["name"]]

            result_queue = Queue(2000)
            master_ip = "localhost"
            # Keep a reference to the server for as long as the workers run.
            server = dist.Server()
            port = server.py_server_port
            procs = []
            for i in range(args.devices):
                proc = Process(
                    target=worker,
                    args=(
                        current_network,
                        weight_file,
                        args.dataset_dir,
                        result_queue,
                        master_ip,
                        port,
                        args.devices,
                        i,
                    ),
                )
                proc.start()
                procs.append(proc)

            for _ in tqdm(range(num_results)):
                result_list.append(result_queue.get())
            # Join only after draining the queue: a child that still has
            # queued items may not exit, so joining first could deadlock.
            for p in procs:
                p.join()
        else:
            worker(current_network, weight_file, args.dataset_dir, result_list)

        all_results = DetEvaluator.format(result_list, cfg)
        if args.weight_file:
            # No epoch directory applies here; name the output after the
            # network file and the checkpoint instead of "epoch_-1".
            json_path = "{}_{}.json".format(
                model_name,
                os.path.basename(args.weight_file).split(".")[0],
            )
        else:
            json_path = "log-of-{}/epoch_{}.json".format(model_name, epoch_num)
        with open(json_path, "w") as fo:
            fo.write(json.dumps(all_results))
        logger.info("Save to %s finished, start evaluation!", json_path)

        eval_gt = COCO(
            os.path.join(
                args.dataset_dir,
                cfg.test_dataset["name"],
                cfg.test_dataset["ann_file"],
            )
        )
        eval_dt = eval_gt.loadRes(json_path)
        cocoEval = COCOeval(eval_gt, eval_dt, iouType="bbox")
        cocoEval.evaluate()
        cocoEval.accumulate()
        cocoEval.summarize()

        logger.info("mmAP".center(32, "-"))
        for i, m in enumerate(metrics):
            logger.info("|\t%s\t|\t%.03f\t|", m, cocoEval.stats[i])
        logger.info("-" * 32)
def main():  # pylint: disable=import-outside-toplevel,too-many-branches,too-many-statements
    """Evaluate detection checkpoint(s) on the configured test set with COCOeval.

    Parses command-line arguments via ``make_parser``, loads the network
    description module from ``args.file`` and runs ``worker`` over the test
    set for each selected checkpoint. With ``args.devices > 1`` the worker is
    wrapped by ``dist.launcher(n_gpus=args.devices)`` and the per-launch
    result lists are concatenated; otherwise ``worker`` is called directly
    and its return value is used as the result list.

    The formatted detections are written to a JSON file and scored with
    ``pycocotools`` (``iouType="bbox"``); the 12 summary numbers are logged.

    Checkpoint selection:
      * ``args.weight_file`` set: evaluate exactly that file once; results go
        to ``<model>_<weightname>.json`` in the working directory.
      * otherwise: evaluate ``log-of-<model>/epoch_<k>.pkl`` for every ``k``
        in ``[args.start_epoch, args.end_epoch]`` (``-1`` defaults to the last
        epoch); results go to ``log-of-<model>/epoch_<k>.json``.
    """
    import itertools

    from pycocotools.coco import COCO
    from pycocotools.cocoeval import COCOeval

    parser = make_parser()
    args = parser.parse_args()

    current_network = import_from_file(args.file)
    cfg = current_network.Cfg()
    # File name of the network description without extension; used as the
    # prefix of checkpoint / result paths.
    model_name = os.path.basename(args.file).split(".")[0]

    if args.weight_file:
        # A single explicit checkpoint: make the epoch loop run exactly once.
        args.start_epoch = args.end_epoch = -1
    else:
        if args.start_epoch == -1:
            args.start_epoch = cfg.max_epoch - 1
        if args.end_epoch == -1:
            args.end_epoch = args.start_epoch
        # Validate with parser.error rather than assert (asserts vanish under -O).
        if not 0 <= args.start_epoch <= args.end_epoch < cfg.max_epoch:
            parser.error(
                "epochs must satisfy 0 <= start_epoch <= end_epoch < max_epoch "
                "(got start_epoch={}, end_epoch={}, max_epoch={})".format(
                    args.start_epoch, args.end_epoch, cfg.max_epoch
                )
            )

    # Order of COCOeval.stats after summarize() for iouType="bbox".
    metrics = [
        "AP",
        "AP@0.5",
        "AP@0.75",
        "APs",
        "APm",
        "APl",
        "AR@1",
        "AR@10",
        "AR@100",
        "ARs",
        "ARm",
        "ARl",
    ]

    for epoch_num in range(args.start_epoch, args.end_epoch + 1):
        if args.weight_file:
            weight_file = args.weight_file
        else:
            weight_file = "log-of-{}/epoch_{}.pkl".format(model_name, epoch_num)

        if args.devices > 1:
            dist_worker = dist.launcher(n_gpus=args.devices)(worker)
            # The launcher returns one result list per launched worker
            # (presumably one per device — confirm). Flatten them in linear
            # time; sum(lists, []) re-copies the accumulator on every step.
            per_worker_results = dist_worker(
                current_network, weight_file, args.dataset_dir
            )
            result_list = list(itertools.chain.from_iterable(per_worker_results))
        else:
            result_list = worker(current_network, weight_file, args.dataset_dir)

        all_results = DetEvaluator.format(result_list, cfg)
        if args.weight_file:
            json_path = "{}_{}.json".format(
                model_name,
                os.path.basename(args.weight_file).split(".")[0],
            )
        else:
            json_path = "log-of-{}/epoch_{}.json".format(model_name, epoch_num)
        with open(json_path, "w") as fo:
            fo.write(json.dumps(all_results))
        logger.info("Save results to %s, start evaluation!", json_path)

        eval_gt = COCO(
            os.path.join(
                args.dataset_dir,
                cfg.test_dataset["name"],
                cfg.test_dataset["ann_file"],
            )
        )
        eval_dt = eval_gt.loadRes(json_path)
        cocoEval = COCOeval(eval_gt, eval_dt, iouType="bbox")
        cocoEval.evaluate()
        cocoEval.accumulate()
        cocoEval.summarize()

        logger.info("mmAP".center(32, "-"))
        for i, m in enumerate(metrics):
            logger.info("|\t%s\t|\t%.03f\t|", m, cocoEval.stats[i])
        logger.info("-" * 32)