def _dump_log(self, log_dict, trainer):
    """Append one rounded log record to the JSON log file.

    Every value of ``log_dict`` is passed through ``self._round_float`` and
    collected, in the original key order, into an ``OrderedDict``. This happens
    on every rank. Only the process whose ``trainer.rank`` equals 0 then
    appends the record to ``self.json_log_path``, followed by a newline.

    Args:
        log_dict: Mapping of log keys to values to be rounded and recorded.
        trainer: Object exposing a ``rank`` attribute; only rank 0 writes.
    """
    rounded = OrderedDict(
        (key, self._round_float(value)) for key, value in log_dict.items()
    )
    if trainer.rank != 0:
        return
    # Append mode keeps previously written records; the explicit newline
    # separates consecutive records in the file.
    with open(self.json_log_path, "a+") as log_file:
        torchie.dump(rounded, log_file, file_format="json")
        log_file.write("\n")
def main():
    """Evaluate a detector checkpoint on the test split and report results.

    Parses the command line and loads the config. It then builds the test
    dataset/dataloader and the detector, loads the checkpoint, runs ``test``
    and prints every entry of ``result_dict["results"]``. Depending on the
    flags it also:

    * ``--out``: dumps the raw detections to a ``.pkl``/``.pickle`` file
      (rank 0 only).
    * ``txt_result``: writes one KITTI-format label file per detection into
      ``./predictions`` and runs the KITTI evaluation on them (rank 0 only).

    Raises:
        AssertionError: if none of ``--out``, ``--show`` or ``--json_out`` is
            given.
        ValueError: if ``--out`` does not end in ``.pkl`` or ``.pickle``.
    """
    args = parse_args()

    assert args.out or args.show or args.json_out, (
        "Please specify at least one operation (save or show the results) "
        'with the argument "--out" or "--show" or "--json_out"'
    )

    if args.out is not None and not args.out.endswith((".pkl", ".pickle")):
        raise ValueError("The output file must be a pkl file.")

    if args.json_out is not None and args.json_out.endswith(".json"):
        # Strip the extension so only the base name is kept.
        # NOTE(review): args.json_out is not read again in this function.
        args.json_out = args.json_out[: -len(".json")]

    cfg = torchie.Config.fromfile(args.config)

    if cfg.get("cudnn_benchmark", False):
        torch.backends.cudnn.benchmark = True

    cfg.data.test.test_mode = True

    # Initialise the distributed environment first, since the logger depends
    # on the distributed info.
    if args.launcher == "none":
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)

    # Build the dataloader.
    # TODO: support multiple images per gpu (only minor changes are needed)
    dataset = build_dataset(cfg.data.test)
    data_loader = build_dataloader(
        dataset,
        batch_size=cfg.data.samples_per_gpu,
        workers_per_gpu=cfg.data.workers_per_gpu,
        dist=distributed,
        shuffle=False,
    )

    # Build the model and load the checkpoint.
    model = build_detector(cfg.model, train_cfg=None, test_cfg=cfg.test_cfg)
    checkpoint = load_checkpoint(model, args.checkpoint, map_location="cpu")

    # Old checkpoints did not store class names (and may have no "meta" entry
    # at all), so fall back to the dataset's classes instead of raising
    # KeyError on checkpoint["meta"].
    meta = checkpoint.get("meta") or {}
    if "CLASSES" in meta:
        model.CLASSES = meta["CLASSES"]
    else:
        model.CLASSES = dataset.CLASSES

    model = MegDataParallel(model, device_ids=[0])

    result_dict, detections = test(
        data_loader, model, save_dir=None, distributed=distributed
    )

    for k, v in result_dict["results"].items():
        print(f"Evaluation {k}: {v}")

    rank, _ = get_dist_info()

    if args.out and rank == 0:
        print("\nwriting results to {}".format(args.out))
        torchie.dump(detections, args.out)

    # Like --out, restrict file output to rank 0 so that several processes
    # do not write the same prediction files concurrently.
    if args.txt_result and rank == 0:
        res_dir = os.path.join(os.getcwd(), "predictions")
        # open(..., "w") does not create missing parent directories.
        os.makedirs(res_dir, exist_ok=True)

        for dt in detections:
            token = int(dt["metadata"]["token"])
            with open(os.path.join(res_dir, "%06d.txt" % token), "w") as fout:
                fout.writelines(
                    line + "\n" for line in kitti.annos_to_kitti_label(dt)
                )

        # NOTE(review): ground-truth locations are hard-coded for one
        # machine's filesystem layout.
        kitti_label_dir = "/data/Datasets/KITTI/Kitti/object/training/label_2"
        kitti_val_split = "/data/Datasets/KITTI/Kitti/ImageSets/val.txt"
        ap_result_str, _ = kitti_evaluate(
            kitti_label_dir,
            res_dir,
            label_split_file=kitti_val_split,
            current_class=0,
        )
        print(ap_result_str)