def inference(args, config, network):
    """Run the detector on one image and save the annotated picture.

    The checkpoint ``dump-<args.resume_weights>.pth`` is looked up under
    ``../model/<args.model_dir>/<config.model_dir>`` and loaded on the CPU.
    The network is run on ``args.img_path``, the post-processed boxes are
    drawn on the image and the result is written to
    ``outputs/<image file stem>.png``.

    Args:
        args: Object whose ``model_dir``, ``resume_weights`` and ``img_path``
            attributes are read.
        config: Object whose ``model_dir``, ``eval_image_short_size``,
            ``eval_image_max_size`` and ``class_names`` attributes are read.
        network: Zero-argument callable that builds the model.

    Raises:
        AssertionError: If the checkpoint file does not exist.
    """
    # model_path
    misc_utils.ensure_dir('outputs')
    saveDir = os.path.join('../model', args.model_dir, config.model_dir)
    model_file = os.path.join(
        saveDir, 'dump-{}.pth'.format(args.resume_weights))
    assert os.path.exists(model_file), \
        'checkpoint not found: {}'.format(model_file)

    # build network
    net = network()
    net.eval()
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    check_point = torch.load(model_file, map_location=torch.device('cpu'))
    net.load_state_dict(check_point['state_dict'])

    # get data
    image, resized_img, im_info = get_data(
        args.img_path,
        config.eval_image_short_size,
        config.eval_image_max_size)

    # Inference only: no autograd graph is needed for the forward pass.
    with torch.no_grad():
        pred_boxes = net(resized_img, im_info).numpy()
    # im_info[0, 2] is forwarded untouched; presumably the resize scale
    # computed by get_data -- TODO confirm.
    pred_boxes = post_process(pred_boxes, config, im_info[0, 2])

    # Columns as used below: 0-3 box coordinates, 4 score, 5 class index.
    pred_tags = pred_boxes[:, 5].astype(np.int32).flatten()
    pred_tags_name = np.array(config.class_names)[pred_tags]

    # inplace draw
    image = visual_utils.draw_boxes(
        image,
        pred_boxes[:, :4],
        scores=pred_boxes[:, 4],
        tags=pred_tags_name,
        line_thick=1,
        line_color='white')

    # Take the file stem with os.path so that names without an extension
    # (previously an IndexError) and names containing several dots
    # (previously truncated to the second-to-last piece) are handled.
    name = os.path.splitext(os.path.basename(args.img_path))[0]
    fpath = 'outputs/{}.png'.format(name)
    cv2.imwrite(fpath, image)
def inference(config, network):
    """Detect people in every file of ``../test_imgs`` and save the crops.

    The checkpoint ``../model/rcnn_emd_simple/outputs/rcnn_emd_simple_mge.pth``
    is loaded on the CPU. For each directory entry the network is run, the
    boxes are post-processed and handed to ``visual_utils.draw_boxes``; every
    item of its return value is written to
    ``outputs/<image stem>_person_<k>.png`` (``k`` starts at 1).

    Args:
        config: Object whose ``eval_image_short_size`` and
            ``eval_image_max_size`` attributes are read; it is also passed to
            ``post_process``.
        network: Zero-argument callable that builds the model.

    Raises:
        AssertionError: If the checkpoint file does not exist.
    """
    # model_path
    saveDir = os.path.join('../model', 'rcnn_emd_simple')
    model_file = os.path.join(saveDir, 'outputs', 'rcnn_emd_simple_mge.pth')
    assert os.path.exists(model_file), \
        'checkpoint not found: {}'.format(model_file)

    # build network
    net = network()
    net.eval()
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    check_point = torch.load(model_file, map_location=torch.device('cpu'))
    net.load_state_dict(check_point['state_dict'])

    # cv2.imwrite does not create missing directories, so make sure the
    # target directory exists before the first write.
    os.makedirs('outputs', exist_ok=True)

    # get data
    img_dir = '../test_imgs'
    # Sorted so the processing order does not depend on the file system.
    for imgname in sorted(os.listdir(img_dir)):
        image, resized_img, im_info = get_data(
            os.path.join(img_dir, imgname),
            config.eval_image_short_size,
            config.eval_image_max_size)

        # Inference only: no autograd graph is needed for the forward pass.
        with torch.no_grad():
            pred_boxes = net(resized_img, im_info).numpy()
        pred_boxes = post_process(pred_boxes, config, im_info[0, 2])

        # draw_boxes is iterated below, so it is expected to return one
        # image per detected person here -- TODO confirm against its
        # implementation.
        persons = visual_utils.draw_boxes(image, pred_boxes[:, :4])

        # Save every pedestrian detected in this image. The image stem is
        # part of the file name; without it each image overwrote the crops
        # written for the previous one.
        stem = os.path.splitext(imgname)[0]
        for person_id, person in enumerate(persons, start=1):
            fpath = 'outputs/{}_person_{}.png'.format(stem, person_id)
            print(fpath)
            cv2.imwrite(fpath, person)
def eval_all(args):
    """Draw detections and ground-truth boxes for records of a JSON-lines file.

    At most ``args.number`` records are read from ``args.json_file``. For each
    record the detection boxes above ``args.visual_thresh`` are drawn in blue
    and the ground-truth boxes in white on the image
    ``img_root + record['ID'] + '.png'``; the result is written to
    ``outputs/<ID>.png`` and a one-line summary is printed.

    ``img_root`` is not defined in this function; it has to be provided by the
    enclosing module scope.

    Args:
        args: Object whose ``json_file``, ``number`` and ``visual_thresh``
            attributes are read.

    Raises:
        AssertionError: If ``args.json_file`` does not exist.
    """
    # json file
    assert os.path.exists(args.json_file), "Wrong json path!"
    misc_utils.ensure_dir('outputs')

    all_records = misc_utils.load_json_lines(args.json_file)
    for rec in all_records[:args.number]:
        rec_id = rec['ID']

        detections = misc_utils.load_bboxes(
            rec,
            key_name='dtboxes',
            key_box='box',
            key_score='score',
            key_tag='tag')
        ground_truth = misc_utils.load_bboxes(rec, 'gtboxes', 'box')
        detections = misc_utils.xywh_to_xyxy(detections)
        ground_truth = misc_utils.xywh_to_xyxy(ground_truth)

        # Keep rows whose second-to-last column exceeds the threshold
        # (presumably the score column, given the box/score/tag keys above).
        detections = detections[detections[:, -2] > args.visual_thresh]

        print("{}: dt:{}, gt:{}.".format(
            rec_id, len(detections), len(ground_truth)))

        canvas = misc_utils.load_img(img_root + rec_id + '.png')
        # Return values are ignored: drawing is relied on to modify the
        # image in place.
        visual_utils.draw_boxes(
            canvas, detections, line_thick=1, line_color='blue')
        visual_utils.draw_boxes(
            canvas, ground_truth, line_thick=1, line_color='white')

        cv2.imwrite('outputs/{}.png'.format(rec_id), canvas)
def inference(config, network):
    """Detect people in a video and save the crops of every sampled frame.

    The checkpoint ``../model/rcnn_emd_simple/outputs/rcnn_emd_simple_mge.pth``
    is loaded on the CPU. Every fourth frame of
    ``../test_video/Crash_walk_f_cm_np1_ri_med_0.avi`` is run through the
    network; each item returned by ``visual_utils.draw_boxes`` is written to
    ``outputs/<frame index>_person_<k>.png`` (``k`` starts at 1).

    Args:
        config: Object whose ``eval_image_short_size`` and
            ``eval_image_max_size`` attributes are read; it is also passed to
            ``post_process``.
        network: Zero-argument callable that builds the model.

    Raises:
        AssertionError: If the checkpoint file does not exist.
        RuntimeError: If the video cannot be opened.
    """
    # model_path
    saveDir = os.path.join('../model', 'rcnn_emd_simple')
    model_file = os.path.join(saveDir, 'outputs', 'rcnn_emd_simple_mge.pth')
    assert os.path.exists(model_file), \
        'checkpoint not found: {}'.format(model_file)

    # build network
    net = network()
    net.eval()
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    check_point = torch.load(model_file, map_location=torch.device('cpu'))
    net.load_state_dict(check_point['state_dict'])

    # get data
    video = "../test_video/Crash_walk_f_cm_np1_ri_med_0.avi"
    cap = cv2.VideoCapture(video)
    if not cap.isOpened():
        raise RuntimeError('cannot open video: {}'.format(video))

    # cv2.imwrite does not create missing directories.
    os.makedirs('outputs', exist_ok=True)

    # Sampling interval: only every timeF-th frame is processed.
    timeF = 4
    # Number of frames read so far.
    c = 0
    try:
        while True:
            retaining, frame = cap.read()
            # Stop at the end of the stream *before* touching the frame:
            # a failed read yields no image, and converting it used to
            # raise once the video was exhausted.
            if not retaining or frame is None:
                break

            if c % timeF == 0:
                # Convert only the frames that are actually used.
                # NOTE(review): frames are fed to the model as RGB while
                # cv2.imwrite expects BGR -- confirm that get_numpy_data /
                # draw_boxes hand back crops in the channel order wanted
                # on disk.
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                image, resized_img, im_info = get_numpy_data(
                    frame,
                    config.eval_image_short_size,
                    config.eval_image_max_size)

                # Inference only: no autograd graph is needed.
                with torch.no_grad():
                    pred_boxes = net(resized_img, im_info).numpy()
                pred_boxes = post_process(pred_boxes, config, im_info[0, 2])

                # draw_boxes is iterated below, so it is expected to return
                # one image per detected person here -- TODO confirm.
                persons = visual_utils.draw_boxes(image, pred_boxes[:, :4])

                # Save every pedestrian detected in this frame.
                for person_id, person in enumerate(persons, start=1):
                    fpath = 'outputs/{}_{}.png'.format(
                        str(c), 'person' + '_' + str(person_id))
                    print(fpath)
                    cv2.imwrite(fpath, person)
            c += 1
    finally:
        # Always give the capture handle back, even if inference fails.
        cap.release()