예제 #1
0
def inference(args):
    """Run single-image detection and save a visualization to ``result.jpg``.

    Loads the checkpoint at ``args.resume_weights`` and runs the network on
    ``args.img_path``. Rows whose score does not exceed ``args.thresh`` are
    dropped. NMS is then applied: ``set_cpu_nms`` when the module-level
    ``if_set_nms`` flag is set, ``cpu_nms`` otherwise, both with threshold
    0.5. The surviving boxes are drawn with their class names onto the
    original image.

    Args:
        args: parsed command-line namespace; reads ``resume_weights``,
            ``img_path`` and ``thresh``.

    Raises:
        AssertionError: if the model file does not exist (checked with
            ``assert``, so the check is skipped under ``python -O``).
        OSError: if the visualization image cannot be written.
    """
    @jit.trace(symbolic=False)
    def val_func():
        pred_boxes = net(net.inputs)
        return pred_boxes

    # model path
    model_file = args.resume_weights
    assert os.path.exists(model_file), \
        'model file not found: {}'.format(model_file)
    # load model (inference mode)
    net = network.Network()
    net.eval()
    check_point = mge.load(model_file)
    net.load_state_dict(check_point['state_dict'])
    ori_image, image, im_info = get_data(args.img_path)
    net.inputs["image"].set_value(image.astype(np.float32))
    net.inputs["im_info"].set_value(im_info)
    pred_boxes = val_func().numpy()
    # Give every output row a class id in 1..num_tag (background excluded).
    # After tiling, each id repeats top_k times, the ids cycle through the
    # classes, and that pattern repeats once per top_k*num_tag rows.
    # NOTE(review): assumes this matches the network's row order; confirm.
    num_tag = config.num_classes - 1
    target_shape = (pred_boxes.shape[0] // num_tag // top_k, top_k)
    pred_tags = (np.arange(num_tag) + 1).reshape(-1, 1)
    pred_tags = np.tile(pred_tags, target_shape).reshape(-1, 1)
    # nms
    if if_set_nms:
        from set_nms_utils import set_cpu_nms
        # Append the index of each row's top_k-sized group as an extra
        # column for set_cpu_nms; the thresholded column moves to -2.
        n = pred_boxes.shape[0] // top_k
        idents = np.tile(np.arange(n)[:, None], (1, top_k)).reshape(-1, 1)
        pred_boxes = np.hstack((pred_boxes, idents))
        keep = pred_boxes[:, -2] > args.thresh
        pred_boxes = pred_boxes[keep]
        pred_tags = pred_tags[keep]
        keep = set_cpu_nms(pred_boxes, 0.5)
        # Drop the temporary group-index column again.
        pred_boxes = pred_boxes[keep][:, :-1]
        pred_tags = pred_tags[keep]
    else:
        from set_nms_utils import cpu_nms
        # The last column (presumably the confidence score) is thresholded.
        keep = pred_boxes[:, -1] > args.thresh
        pred_boxes = pred_boxes[keep]
        pred_tags = pred_tags[keep]
        keep = cpu_nms(pred_boxes, 0.5)
        pred_boxes = pred_boxes[keep]
        pred_tags = pred_tags[keep]
    pred_tags = pred_tags.astype(np.int32).flatten()
    pred_tags_name = np.array(config.class_names)[pred_tags]
    visual_utils.draw_boxes(ori_image, pred_boxes[:, :-1], pred_boxes[:, -1],
                            pred_tags_name)
    fpath = 'result.jpg'
    # cv2.imwrite reports failure by returning False rather than raising.
    if not cv2.imwrite(fpath, ori_image):
        raise OSError('failed to write visualization to {}'.format(fpath))
예제 #2
0
def inference(model_file, device, records, result_queue,
              score_thresh=0.05, nms_thresh=0.5):
    """Run detection over ``records`` and push one result dict per record.

    For each record, the network's predictions are filtered by
    ``score_thresh`` and then suppressed with NMS: ``set_cpu_nms`` when the
    module-level ``if_set_nms`` flag is set, ``cpu_nms`` otherwise. The
    surviving detections and the ground-truth boxes are serialized with
    ``boxes_dump`` and put on ``result_queue``.

    Args:
        model_file: checkpoint path loaded with ``mge.load``; its
            ``'state_dict'`` entry is loaded into the network.
        device: forwarded unchanged to ``get_data``.
        records: iterable of records, each passed to ``get_data``.
        result_queue: receives one dict per record via ``put_nowait`` with
            keys ``ID``, ``height``, ``width``, ``dtboxes`` and ``gtboxes``.
        score_thresh: rows whose thresholded column does not exceed this
            value are dropped before NMS (previously hard-coded to 0.05).
        nms_thresh: threshold passed to the NMS routine (previously
            hard-coded to 0.5).
    """
    @jit.trace(symbolic=False)
    def val_func():
        pred_boxes = net(net.inputs)
        return pred_boxes

    # load model once; nothing below switches it out of eval mode, so the
    # per-record net.eval() call was redundant.
    net = network.Network()
    net.eval()
    check_point = mge.load(model_file)
    net.load_state_dict(check_point['state_dict'])

    # Loop-invariant setup hoisted out of the per-record loop.
    np.set_printoptions(precision=2, suppress=True)
    num_tag = config.num_classes - 1
    base_tags = (np.arange(num_tag) + 1).reshape(-1, 1)
    if if_set_nms:
        from set_nms_utils import set_cpu_nms
    else:
        from set_nms_utils import cpu_nms

    for record in records:
        image, gt_boxes, im_info, ID = get_data(record, device)
        net.inputs["image"].set_value(image.astype(np.float32))
        net.inputs["im_info"].set_value(im_info)
        pred_boxes = val_func().numpy()
        # One class id (1..num_tag) per output row; each id repeats top_k
        # times and the ids cycle through the classes.
        # NOTE(review): assumes this matches the network's row order.
        target_shape = (pred_boxes.shape[0] // num_tag // top_k, top_k)
        pred_tags = np.tile(base_tags, target_shape).reshape(-1, 1)
        # nms
        if if_set_nms:
            # Append each row's top_k-group index for set_cpu_nms; the
            # thresholded column therefore moves to -2.
            n = pred_boxes.shape[0] // top_k
            idents = np.tile(np.arange(n)[:, None], (1, top_k)).reshape(-1, 1)
            pred_boxes = np.hstack((pred_boxes, idents))
            keep = pred_boxes[:, -2] > score_thresh
            pred_boxes = pred_boxes[keep]
            pred_tags = pred_tags[keep]
            keep = set_cpu_nms(pred_boxes, nms_thresh)
            # Drop the temporary group-index column again.
            pred_boxes = pred_boxes[keep][:, :-1]
            pred_tags = pred_tags[keep].flatten()
        else:
            keep = pred_boxes[:, -1] > score_thresh
            pred_boxes = pred_boxes[keep]
            pred_tags = pred_tags[keep]
            keep = cpu_nms(pred_boxes, nms_thresh)
            pred_boxes = pred_boxes[keep]
            pred_tags = pred_tags[keep].flatten()
        result_dict = dict(ID=ID,
                           height=int(im_info[0, -2]),
                           width=int(im_info[0, -1]),
                           dtboxes=boxes_dump(pred_boxes, pred_tags, False),
                           gtboxes=boxes_dump(gt_boxes, None, True))
        result_queue.put_nowait(result_dict)
예제 #3
0
def inference(args):
    """Run single-image detection and save a visualization as a PNG.

    Loads ``config.model_dir/epoch_<args.resume_weights>.pkl`` and runs the
    network on ``args.img_path``. Rows whose last column does not exceed
    ``args.thresh`` are dropped, and ``cpu_nms`` is applied with threshold
    0.5. The survivors are drawn onto the original image, which is written
    to ``/data/jupyter/<image stem>.png``.

    Args:
        args: parsed command-line namespace; reads ``resume_weights``,
            ``img_path`` and ``thresh``.

    Raises:
        AssertionError: if the checkpoint file does not exist (checked with
            ``assert``, so the check is skipped under ``python -O``).
        OSError: if the visualization image cannot be written.
    """
    @jit.trace(symbolic=False)
    def val_func():
        pred_boxes = net(net.inputs)
        return pred_boxes

    # model path
    saveDir = config.model_dir
    evalDir = config.eval_dir
    # NOTE(review): evalDir is created but nothing below writes into it.
    misc_utils.ensure_dir(evalDir)
    model_file = os.path.join(saveDir,
                              'epoch_{}.pkl'.format(args.resume_weights))
    assert os.path.exists(model_file), \
        'model file not found: {}'.format(model_file)
    # load model (inference mode)
    net = network.Network()
    net.eval()
    check_point = mge.load(model_file)
    net.load_state_dict(check_point['state_dict'])
    ori_image, image, im_info = get_data(args.img_path)
    net.inputs["image"].set_value(image.astype(np.float32))
    net.inputs["im_info"].set_value(im_info)
    pred_boxes = val_func().numpy()
    # One class id (1..num_tag, background excluded) per output row, with
    # the class id cycling on every row.
    # NOTE(review): assumes the network emits rows in that order; confirm.
    num_tag = config.num_classes - 1
    target_shape = (pred_boxes.shape[0] // num_tag, 1)
    pred_tags = (np.arange(num_tag) + 1).reshape(-1, 1)
    pred_tags = np.tile(pred_tags, target_shape).reshape(-1, 1)
    # nms: threshold the last column (presumably the score), then suppress.
    from set_nms_utils import cpu_nms
    keep = pred_boxes[:, -1] > args.thresh
    pred_boxes = pred_boxes[keep]
    pred_tags = pred_tags[keep]
    keep = cpu_nms(pred_boxes, 0.5)
    pred_boxes = pred_boxes[keep]
    pred_tags = pred_tags[keep]

    pred_tags = pred_tags.astype(np.int32).flatten()
    pred_tags_name = np.array(config.class_names)[pred_tags]
    visual_utils.draw_boxes(ori_image, pred_boxes[:, :-1], pred_boxes[:, -1],
                            pred_tags_name)
    # Image stem = basename without its final extension. Works for names
    # with no extension or with several dots, unlike split('.')[-2].
    name = os.path.splitext(os.path.basename(args.img_path))[0]
    fpath = '/data/jupyter/{}.png'.format(name)
    # cv2.imwrite reports failure by returning False rather than raising.
    if not cv2.imwrite(fpath, ori_image):
        raise OSError('failed to write visualization to {}'.format(fpath))
예제 #4
0
def inference(args):
    """Run single-image detection and dump the detections as JSON lines.

    Loads ``config.model_dir/epoch_<args.resume_weights>.pkl`` and runs the
    network on ``args.img_path``. Rows whose last column does not exceed
    0.05 are dropped, and ``cpu_nms`` is applied with threshold 0.5. One
    record with ``height``, ``width`` and ``dtboxes`` is written to
    ``<image stem>.json`` in the current working directory.

    Args:
        args: parsed command-line namespace; reads ``resume_weights`` and
            ``img_path``.

    Raises:
        AssertionError: if the checkpoint file does not exist (checked with
            ``assert``, so the check is skipped under ``python -O``).
    """
    @jit.trace(symbolic=False)
    def val_func():
        pred_boxes = net(net.inputs)
        return pred_boxes
    # model path
    saveDir = config.model_dir
    evalDir = config.eval_dir
    # NOTE(review): evalDir is created but the JSON goes to the CWD.
    misc_utils.ensure_dir(evalDir)
    model_file = os.path.join(saveDir,
            'epoch_{}.pkl'.format(args.resume_weights))
    assert os.path.exists(model_file), \
        'model file not found: {}'.format(model_file)
    # load model (inference mode)
    net = network.Network()
    net.eval()
    check_point = mge.load(model_file)
    net.load_state_dict(check_point['state_dict'])
    image, im_info = get_data(args.img_path)
    net.inputs["image"].set_value(image.astype(np.float32))
    net.inputs["im_info"].set_value(im_info)
    pred_boxes = val_func().numpy()
    # One class id (1..num_tag, background excluded) per output row, with
    # the class id cycling on every row.
    # NOTE(review): assumes the network emits rows in that order; confirm.
    num_tag = config.num_classes - 1
    target_shape = (pred_boxes.shape[0]//num_tag, 1)
    pred_tags = (np.arange(num_tag) + 1).reshape(-1,1)
    pred_tags = np.tile(pred_tags, target_shape).reshape(-1,1)
    # nms: threshold the last column (presumably the score), then suppress.
    from set_nms_utils import cpu_nms
    keep = pred_boxes[:, -1] > 0.05
    pred_boxes = pred_boxes[keep]
    pred_tags = pred_tags[keep]
    keep = cpu_nms(pred_boxes, 0.5)
    pred_boxes = pred_boxes[keep]
    pred_tags = pred_tags[keep].flatten()
    result_dict = dict(height=int(im_info[0, -2]), width=int(im_info[0, -1]),
        dtboxes=boxes_dump(pred_boxes, pred_tags))
    # Image stem = basename without its final extension. Works for names
    # with no extension or with several dots, unlike split('.')[-2].
    name = os.path.splitext(os.path.basename(args.img_path))[0]
    misc_utils.save_json_lines([result_dict], '{}.json'.format(name))