Example #1
0
def pred_eval(predictor,
              test_data,
              imdb,
              cfg,
              vis=False,
              thresh=1e-3,
              logger=None,
              ignore_cache=True):
    """
    Run the detector over ``test_data``, post-process the detections and evaluate them.

    If ``<imdb.result_path>/<imdb.name>_detections.pkl`` exists and
    ``ignore_cache`` is False, the cached detections are evaluated and the
    function returns without running the network. Otherwise each batch goes
    through ``im_detect``. For every image and every foreground class
    (1 .. imdb.num_classes - 1), detections scoring above ``thresh`` are kept.
    Unless ``cfg.TEST.LEARN_NMS`` is set, they are then suppressed with
    (soft-)NMS. The per-image total is optionally capped at
    ``cfg.TEST.max_per_image``. The collected detections are pickled to the
    cache file and scored with ``imdb.evaluate_detections``.

    :param predictor: Predictor passed to ``im_detect``
    :param test_data: data iterator, must be non-shuffle (unless ``vis`` is set)
    :param imdb: image database
    :param cfg: config; reads TEST.SOFTNMS, TEST.NMS, TEST.max_per_image,
                TEST.LEARN_NMS, TEST.FIRST_N and CLASS_AGNOSTIC
    :param vis: when True, draw each image's detections with ``vis_all_detection``
    :param thresh: valid detection threshold (scores must be strictly greater)
    :param logger: optional logger for progress, evaluation and ratio output
    :param ignore_cache: when False, reuse an existing detection pickle instead
                         of running the network
    :return: None
    """

    # Fast path: evaluate previously pickled detections without running the net.
    det_file = os.path.join(imdb.result_path, imdb.name + '_detections.pkl')
    if os.path.exists(det_file) and not ignore_cache:
        with open(det_file, 'rb') as fid:
            all_boxes = cPickle.load(fid)
        info_str = imdb.evaluate_detections(all_boxes)
        if logger:
            logger.info('evaluate detections: \n{}'.format(info_str))
        return

    # Results are written by position (idx + delta), so the image order must be fixed.
    assert vis or not test_data.shuffle
    data_names = [k[0] for k in test_data.provide_data[0]]

    if not isinstance(test_data, PrefetchingIter):
        test_data = PrefetchingIter(test_data)

    # NOTE(review): the commented-out block below duplicates the live code that follows.
    #if cfg.TEST.SOFTNMS:
    #    nms = py_softnms_wrapper(cfg.TEST.NMS)
    #else:
    #    nms = py_nms_wrapper(cfg.TEST.NMS)

    # The soft-NMS wrapper returns the filtered detections themselves; the hard
    # NMS wrapper returns row indices to keep (see how each is used below).
    if cfg.TEST.SOFTNMS:
        nms = py_softnms_wrapper(cfg.TEST.NMS)
    else:
        nms = py_nms_wrapper(cfg.TEST.NMS)

    # limit detections to max_per_image over all classes
    max_per_image = cfg.TEST.max_per_image

    num_images = imdb.num_images
    # all detections are collected into:
    #    all_boxes[cls][image] = N x 5 array of detections in
    #    (x1, y1, x2, y2, score)
    all_boxes = [[[] for _ in range(num_images)]
                 for _ in range(imdb.num_classes)]
    # class_lut[cls] lists the image indices that had at least one detection of cls.
    class_lut = [[] for _ in range(imdb.num_classes)]
    # valid_tally: kept detections; valid_sum: candidate rows, counted only for
    # (image, class) pairs that kept at least one detection.
    valid_tally = 0
    valid_sum = 0

    idx = 0  # index of the first image of the current batch
    t = time.time()
    inference_count = 0
    all_inference_time = []
    post_processing_time = []
    for im_info, data_batch in test_data:
        # t1: time spent waiting for data since the previous batch finished
        t1 = time.time() - t
        t = time.time()

        scales = [iim_info[0, 2] for iim_info in im_info]
        scores_all, boxes_all, data_dict_all = im_detect(
            predictor, data_batch, data_names, scales, cfg)

        # t2: network forward time (im_detect)
        t2 = time.time() - t
        t = time.time()
        for delta, (scores, boxes, data_dict) in enumerate(
                zip(scores_all, boxes_all, data_dict_all)):
            if cfg.TEST.LEARN_NMS:
                # Learned-NMS output: scores and boxes are indexed with j - 1, and
                # boxes is 3-D here (boxes[indexes, j - 1, :]). Presumably this
                # head has no background column -- TODO confirm. No NMS is run.
                for j in range(1, imdb.num_classes):
                    indexes = np.where(scores[:, j - 1] > thresh)[0]
                    cls_scores = scores[indexes, j - 1:j]
                    cls_boxes = boxes[indexes, j - 1, :]
                    cls_dets = np.hstack((cls_boxes, cls_scores))
                    # count the valid ground truth
                    if len(cls_scores) > 0:
                        class_lut[j].append(idx + delta)
                        valid_tally += len(cls_scores)
                        valid_sum += len(scores)
                    all_boxes[j][idx + delta] = cls_dets
            else:
                for j in range(1, imdb.num_classes):
                    indexes = np.where(scores[:, j] > thresh)[0]
                    if cfg.TEST.FIRST_N > 0:
                        # Keep only candidates that are both above thresh and among
                        # the FIRST_N highest-scoring rows for this class.
                        # todo: check whether the order affects the result
                        sort_indices = np.argsort(
                            scores[:, j])[-cfg.TEST.FIRST_N:]
                        # sort_indices = np.argsort(-scores[:, j])[0:cfg.TEST.FIRST_N]
                        indexes = np.intersect1d(sort_indices, indexes)

                    cls_scores = scores[indexes, j, np.newaxis]
                    # Class-agnostic regression shares one box (columns 4:8) for all
                    # foreground classes; otherwise each class has its own 4 columns.
                    cls_boxes = boxes[indexes,
                                      4:8] if cfg.CLASS_AGNOSTIC else boxes[
                                          indexes, j * 4:(j + 1) * 4]
                    # count the valid ground truth
                    if len(cls_scores) > 0:
                        class_lut[j].append(idx + delta)
                        valid_tally += len(cls_scores)
                        valid_sum += len(scores)
                        # print np.min(cls_scores), valid_tally, valid_sum
                        # cls_scores = scores[:, j, np.newaxis]
                        # cls_scores[cls_scores <= thresh] = thresh
                        # cls_boxes = boxes[:, 4:8] if cfg.CLASS_AGNOSTIC else boxes[:, j * 4:(j + 1) * 4]
                    cls_dets = np.hstack((cls_boxes, cls_scores))
                    if cfg.TEST.SOFTNMS:
                        all_boxes[j][idx + delta] = nms(cls_dets)
                    else:
                        keep = nms(cls_dets)
                        all_boxes[j][idx + delta] = cls_dets[keep, :]

            # Cap the image at max_per_image detections across all classes by
            # dropping everything below the max_per_image-th highest score.
            if max_per_image > 0:
                image_scores = np.hstack([
                    all_boxes[j][idx + delta][:, -1]
                    for j in range(1, imdb.num_classes)
                ])
                if len(image_scores) > max_per_image:
                    image_thresh = np.sort(image_scores)[-max_per_image]
                    for j in range(1, imdb.num_classes):
                        keep = np.where(
                            all_boxes[j][idx + delta][:,
                                                      -1] >= image_thresh)[0]
                        all_boxes[j][idx +
                                     delta] = all_boxes[j][idx +
                                                           delta][keep, :]

            if vis:
                # Index 0 (background) gets an empty placeholder.
                boxes_this_image = [[]] + [
                    all_boxes[j][idx + delta]
                    for j in range(1, imdb.num_classes)
                ]
                vis_all_detection(data_dict['data'].asnumpy(),
                                  boxes_this_image, imdb.classes,
                                  scales[delta], cfg)

        # NOTE(review): advances by the full batch size; if the iterator yields a
        # short final batch, the progress counter overshoots -- confirm.
        idx += test_data.batch_size
        # t3: post-processing time (thresholding, NMS, capping, vis)
        t3 = time.time() - t
        t = time.time()
        post_processing_time.append(t3)
        all_inference_time.append(t1 + t2 + t3)
        inference_count += 1
        # Every 200 batches, report means over the last min(500, count) batches.
        if inference_count % 200 == 0:
            valid_count = 500 if inference_count > 500 else inference_count
            print("--->> running-average inference time per batch: {}".format(
                float(sum(all_inference_time[-valid_count:])) / valid_count))
            print("--->> running-average post processing time per batch: {}".
                  format(
                      float(sum(post_processing_time[-valid_count:])) /
                      valid_count))
        print 'testing {}/{} data {:.4f}s net {:.4f}s post {:.4f}s'.format(
            idx, imdb.num_images, t1, t2, t3)
        if logger:
            logger.info(
                'testing {}/{} data {:.4f}s net {:.4f}s post {:.4f}s'.format(
                    idx, imdb.num_images, t1, t2, t3))

    # Cache the detections so later runs with ignore_cache=False can skip inference.
    with open(det_file, 'wb') as f:
        cPickle.dump(all_boxes, f, protocol=cPickle.HIGHEST_PROTOCOL)

    # np.save('class_lut.npy', class_lut)

    info_str = imdb.evaluate_detections(all_boxes)
    if logger:
        logger.info('evaluate detections: \n{}'.format(info_str))
        # valid class ratio: (image, class) pairs with >= 1 detection per image.
        num_valid_classes = [len(x) for x in class_lut]
        logger.info('valid class ratio:{}'.format(
            np.sum(num_valid_classes) / float(num_images)))
        # valid score ratio: kept detections / candidate rows; +0.01 avoids a
        # division by zero when nothing was kept.
        logger.info('valid score ratio:{}'.format(
            float(valid_tally) / float(valid_sum + 0.01)))
Example #2
0
def pred_eval(predictor,
              test_data,
              imdb,
              cfg,
              vis=False,
              thresh=1e-3,
              logger=None,
              ignore_cache=True):
    """
    Multi-scale offline evaluation: detect at every ``cfg.TEST_SCALES`` entry,
    merge the per-scale detections, suppress duplicates and evaluate.

    If ``<imdb.result_path>/<imdb.name>_detections.pkl`` exists and
    ``ignore_cache`` is False, the cached detections are evaluated and the
    function returns. Otherwise:

    1. For each test scale, ``cfg.SCALES`` is set to that scale, the iterator
       is reset, ``detect_at_single_scale`` fills a per-scale result, and the
       result is pickled to ``<imdb.name>_detections_<scale_index>.pkl``.
    2. All per-scale pickles that exist are stacked per (class, image).
    3. Soft-NMS (``cfg.TEST.USE_SOFTNMS``) or hard NMS (``cfg.TEST.NMS``) is
       applied to each merged (class, image) set.
    4. The per-image total is optionally capped at ``cfg.TEST.max_per_image``.

    The final detections are pickled to the cache file and scored with
    ``imdb.evaluate_detections``.

    :param predictor: Predictor used by ``detect_at_single_scale``
    :param test_data: data iterator, must be non-shuffle (unless ``vis`` is set)
    :param imdb: image database
    :param cfg: config; ``cfg.SCALES`` is overwritten as a side effect
    :param vis: controls visualization
    :param thresh: valid detection threshold
    :param logger: optional logger
    :param ignore_cache: when False, reuse an existing final detection pickle
    :return: None
    """

    det_file = os.path.join(imdb.result_path, imdb.name + '_detections.pkl')
    if os.path.exists(det_file) and not ignore_cache:
        with open(det_file, 'rb') as fid:
            all_boxes = cPickle.load(fid)
        info_str = imdb.evaluate_detections(all_boxes)
        if logger:
            logger.info('evaluate detections: \n{}'.format(info_str))
        return

    assert vis or not test_data.shuffle
    data_names = [k[0] for k in test_data.provide_data[0]]

    if not isinstance(test_data, PrefetchingIter):
        test_data = PrefetchingIter(test_data)

    # limit detections to max_per_image over all classes
    max_per_image = cfg.TEST.max_per_image
    num_images = imdb.num_images

    def single_scale_det_file(test_scale_index):
        """Return the pickle path holding detections for one test scale."""
        return os.path.join(
            imdb.result_path,
            imdb.name + '_detections_' + str(test_scale_index) + '.pkl')

    # Pass 1: detect once per test scale and cache each scale's result.
    for test_scale_index, test_scale in enumerate(cfg.TEST_SCALES):
        # cfg.SCALES is overwritten in place; presumably the data pipeline reads
        # it to resize images to the current scale -- confirm.
        cfg.SCALES = [test_scale]
        test_data.reset()

        # all_boxes_single_scale[cls][image] = N x 5 array (x1, y1, x2, y2, score)
        all_boxes_single_scale = [[[] for _ in range(num_images)]
                                  for _ in range(imdb.num_classes)]

        detect_at_single_scale(predictor, data_names, imdb, test_data, cfg,
                               thresh, vis, all_boxes_single_scale, logger)

        with open(single_scale_det_file(test_scale_index), 'wb') as f:
            cPickle.dump(all_boxes_single_scale,
                         f,
                         protocol=cPickle.HIGHEST_PROTOCOL)

    # Pass 2: stack every available scale's detections per (class, image).
    # all_boxes[cls][image] = N x 5 array (x1, y1, x2, y2, score)
    all_boxes = [[[] for _ in range(num_images)]
                 for _ in range(imdb.num_classes)]

    for test_scale_index, _ in enumerate(cfg.TEST_SCALES):
        scale_file = single_scale_det_file(test_scale_index)
        if not os.path.exists(scale_file):
            continue
        with open(scale_file, 'rb') as fid:
            all_boxes_single_scale = cPickle.load(fid)
        for idx_class in range(1, imdb.num_classes):
            for idx_im in range(num_images):
                merged = all_boxes[idx_class][idx_im]
                extra = all_boxes_single_scale[idx_class][idx_im]
                if len(merged) == 0:
                    all_boxes[idx_class][idx_im] = extra
                elif len(extra) > 0:
                    # Empty additions are skipped: np.vstack of an (N, 5) array
                    # with a bare [] placeholder fails on a shape mismatch.
                    all_boxes[idx_class][idx_im] = np.vstack((merged, extra))

    # Pass 3: suppress duplicates across scales. The suppression function is
    # built once instead of once per (class, image) pair.
    use_soft_nms = cfg.TEST.USE_SOFTNMS
    if use_soft_nms:
        soft_nms = py_softnms_wrapper(cfg.TEST.SOFTNMS_THRESH,
                                      max_dets=max_per_image)
    else:
        nms = py_nms_wrapper(cfg.TEST.NMS)

    for idx_class in range(1, imdb.num_classes):
        for idx_im in range(num_images):
            dets = all_boxes[idx_class][idx_im]
            if use_soft_nms:
                # soft-NMS wrapper returns the filtered detections directly
                all_boxes[idx_class][idx_im] = soft_nms(dets)
            else:
                # hard-NMS wrapper returns row indices to keep
                keep = nms(dets)
                all_boxes[idx_class][idx_im] = dets[keep, :]

    # Pass 4: keep at most max_per_image detections per image over all classes.
    if max_per_image > 0:
        for idx_im in range(num_images):
            image_scores = np.hstack([
                all_boxes[j][idx_im][:, -1]
                for j in range(1, imdb.num_classes)
            ])
            if len(image_scores) > max_per_image:
                image_thresh = np.sort(image_scores)[-max_per_image]
                for j in range(1, imdb.num_classes):
                    keep = np.where(
                        all_boxes[j][idx_im][:, -1] >= image_thresh)[0]
                    all_boxes[j][idx_im] = all_boxes[j][idx_im][keep, :]

    with open(det_file, 'wb') as f:
        cPickle.dump(all_boxes, f, protocol=cPickle.HIGHEST_PROTOCOL)

    info_str = imdb.evaluate_detections(all_boxes)
    if logger:
        logger.info('evaluate detections: \n{}'.format(info_str))
Example #3
0
def pred_eval(predictor, test_data, imdb, cfg, vis=False, thresh=1e-3, logger=None, ignore_cache=True):
    """
    Run the detector over unlabeled images and write pseudo ground-truth annotations.

    Despite the name, this variant does not evaluate. It:

    1. loads a class co-occurrence matrix and column-normalizes it;
    2. runs ``im_detect`` on every batch;
    3. applies soft-NMS and the optional per-image cap;
    4. when ``vis`` is True, hands each image's detections to
       ``draw_all_detection``, which returns the growing annotation list.

    The list is stored under ``'annotations'`` in the unlabeled COCO
    annotation JSON, and the result is written to a new JSON file.

    :param predictor: Predictor passed to ``im_detect``
    :param test_data: data iterator exposing ``roidb``; must be non-shuffle unless ``vis``
    :param imdb: image database
    :param cfg: config; reads TEST.NMS, TEST.max_per_image and CLASS_AGNOSTIC
    :param vis: when True, annotations are produced via ``draw_all_detection``;
                when False the saved annotation list stays empty
    :param thresh: valid detection threshold (scores must be strictly greater)
    :param logger: optional logger for progress output
    :param ignore_cache: accepted for signature compatibility; unused here
    :return: None
    """
    co_occur_matrix = np.load('/home/user/Deformable-ConvNets2/tmp/co_occur_matrix.npy')
    co_occur_matrix = co_occur_matrix.astype(int)
    # Column-normalize: each non-zero column of nor_co_occur_matrix sums to 1,
    # and row_max[c] holds the largest normalized entry of column c.
    nor_co_occur_matrix = np.zeros((90, 90))
    row_max = np.zeros(90)
    for ind in range(len(co_occur_matrix)):
        col_sum = np.sum(co_occur_matrix[:, ind])
        if col_sum != 0:
            # float() forces true division: under Python 2 an int array divided
            # by an int floors, which zeroed almost every normalized entry.
            nor_co_occur_matrix[:, ind] = co_occur_matrix[:, ind] / float(col_sum)
        row_max[ind] = np.amax(nor_co_occur_matrix[:, ind])

    assert vis or not test_data.shuffle
    data_names = [k[0] for k in test_data.provide_data[0]]

    # roidb is read before wrapping; presumably PrefetchingIter does not forward it.
    roidb = test_data.roidb

    if not isinstance(test_data, PrefetchingIter):
        test_data = PrefetchingIter(test_data)

    soft_nms = py_softnms_wrapper(cfg.TEST.NMS)

    # limit detections to max_per_image over all classes
    max_per_image = cfg.TEST.max_per_image

    num_images = imdb.num_images

    # all_boxes[cls][image] holds that image's detections for class cls
    all_boxes = [[[] for _ in range(num_images)]
                 for _ in range(imdb.num_classes)]

    # Load the base annotation file up front (and close it) so its load time is
    # not counted as data time for the first batch.
    annotation_file = '/home/user/Deformable-ConvNets-test/data/coco/annotations/kinstances_unlabeled2017.json'
    with open(annotation_file, 'r') as f:
        dataset = json.load(f)
    annotations = []
    id_count = 1

    idx = 0  # index of the first image of the current batch
    data_time, net_time, post_time = 0.0, 0.0, 0.0
    t = time.time()

    for im_info, data_batch in test_data:
        t1 = time.time() - t
        t = time.time()

        scales = [iim_info[0, 2] for iim_info in im_info]
        scores_all, boxes_all, data_dict_all = im_detect(predictor, data_batch, data_names, scales, cfg)

        t2 = time.time() - t
        t = time.time()
        for delta, (scores, boxes, data_dict) in enumerate(zip(scores_all, boxes_all, data_dict_all)):
            im_idx = idx + delta
            for j in range(1, imdb.num_classes):
                indexes = np.where(scores[:, j] > thresh)[0]
                cls_scores = scores[indexes, j, np.newaxis]
                cls_boxes = boxes[indexes, 4:8] if cfg.CLASS_AGNOSTIC else boxes[indexes, j * 4:(j + 1) * 4]
                cls_dets = np.hstack((cls_boxes, cls_scores))
                # this wrapper's result is used as row indices into cls_dets
                keep = soft_nms(cls_dets).tolist()
                all_boxes[j][im_idx] = cls_dets[keep, :]

            if max_per_image > 0:
                image_scores = np.hstack([all_boxes[j][im_idx][:, -1]
                                          for j in range(1, imdb.num_classes)])
                if len(image_scores) > max_per_image:
                    image_thresh = np.sort(image_scores)[-max_per_image]
                    for j in range(1, imdb.num_classes):
                        keep = np.where(all_boxes[j][im_idx][:, -1] >= image_thresh)[0]
                        all_boxes[j][im_idx] = all_boxes[j][im_idx][keep, :]

            if vis:
                boxes_this_image = [[]] + [all_boxes[j][im_idx] for j in range(1, imdb.num_classes)]
                # Use this image's own roidb entry: indexing with idx alone gave
                # every image in a batch the first image's file name.
                im_name = roidb[im_idx]['image'].rsplit("/", 1)[-1]
                result = draw_all_detection(data_dict['data'].asnumpy(), boxes_this_image, imdb.classes,
                                            scales[delta], cfg, im_name, annotations, id_count,
                                            nor_co_occur_matrix, row_max)
                annotations = result['ann']
                id_count = result['id_count']

        idx += test_data.batch_size
        t3 = time.time() - t
        t = time.time()
        data_time += t1
        net_time += t2
        post_time += t3
        batch_size = test_data.batch_size
        progress = 'testing {}/{} data {:.4f}s net {:.4f}s post {:.4f}s'.format(
            idx, imdb.num_images,
            data_time / idx * batch_size,
            net_time / idx * batch_size,
            post_time / idx * batch_size)
        print(progress)
        if logger:
            logger.info(progress)

    dataset.update({'annotations': annotations})
    save_annotation_file = '/home/user/Deformable-ConvNets-test/data/coco/annotations/instances_unlabeled2017_ssl.json'
    with open(save_annotation_file, 'w') as f:
        json.dump(dataset, f)

    print("Finish generate pseudo ground truth!")
Example #4
0
def pred_eval(predictor,
              test_data,
              imdb,
              cfg,
              vis=False,
              thresh=1e-3,
              logger=None,
              ignore_cache=True):
    """
    Run the detector over ``test_data``, post-process the detections and evaluate them.

    If ``<imdb.result_path>/<imdb.name>_detections.pkl`` exists and
    ``ignore_cache`` is False, the cached detections are evaluated and the
    function returns. Otherwise each batch goes through ``im_detect``. For
    every image and foreground class, detections scoring above ``thresh`` are
    filtered by soft-NMS (``cfg.TEST.USE_SOFTNMS``) or by hard NMS
    (``cfg.TEST.NMS``). The per-image total is then optionally capped at
    ``cfg.TEST.max_per_image``. The results are pickled to the cache file and
    scored with ``imdb.evaluate_detections``.

    :param predictor: Predictor passed to ``im_detect``
    :param test_data: data iterator, must be non-shuffle (unless ``vis`` is set)
    :param imdb: image database
    :param cfg: config
    :param vis: when True, draw each image's detections with ``vis_all_detection``
    :param thresh: valid detection threshold (scores must be strictly greater)
    :param logger: optional logger for progress and evaluation output
    :param ignore_cache: when False, reuse an existing detection pickle
    :return: None
    """

    det_file = os.path.join(imdb.result_path, imdb.name + '_detections.pkl')
    if os.path.exists(det_file) and not ignore_cache:
        with open(det_file, 'rb') as fid:
            all_boxes = cPickle.load(fid)
        info_str = imdb.evaluate_detections(all_boxes)
        if logger:
            logger.info('evaluate detections: \n{}'.format(info_str))
        return

    assert vis or not test_data.shuffle
    data_names = [k[0] for k in test_data.provide_data[0]]

    if not isinstance(test_data, PrefetchingIter):
        test_data = PrefetchingIter(test_data)

    # limit detections to max_per_image over all classes
    max_per_image = cfg.TEST.max_per_image
    # Build exactly the suppression function that the loop below uses. The
    # hard-NMS wrapper used to be commented out, leaving the non-soft-NMS branch
    # calling a name ``nms`` that this function never defined.
    if cfg.TEST.USE_SOFTNMS:
        # soft-NMS wrapper returns the filtered detections directly
        soft_nms = py_softnms_wrapper(cfg.TEST.SOFTNMS_THRESH,
                                      max_dets=max_per_image)
    else:
        # hard-NMS wrapper returns row indices to keep
        nms = py_nms_wrapper(cfg.TEST.NMS)

    num_images = imdb.num_images
    # all detections are collected into:
    #    all_boxes[cls][image] = N x 5 array of detections in
    #    (x1, y1, x2, y2, score)
    all_boxes = [[[] for _ in range(num_images)]
                 for _ in range(imdb.num_classes)]

    idx = 0  # index of the first image of the current batch
    data_time, net_time, post_time = 0.0, 0.0, 0.0
    t = time.time()
    for im_info, data_batch in test_data:
        t1 = time.time() - t
        t = time.time()

        scales = [iim_info[0, 2] for iim_info in im_info]
        scores_all, boxes_all, data_dict_all = im_detect(
            predictor, data_batch, data_names, scales, cfg)

        t2 = time.time() - t
        t = time.time()
        for delta, (scores, boxes, data_dict) in enumerate(
                zip(scores_all, boxes_all, data_dict_all)):
            for j in range(1, imdb.num_classes):
                indexes = np.where(scores[:, j] > thresh)[0]
                cls_scores = scores[indexes, j, np.newaxis]
                cls_boxes = boxes[indexes,
                                  4:8] if cfg.CLASS_AGNOSTIC else boxes[
                                      indexes, j * 4:(j + 1) * 4]
                cls_dets = np.hstack((cls_boxes, cls_scores))
                if cfg.TEST.USE_SOFTNMS:
                    all_boxes[j][idx + delta] = soft_nms(cls_dets)
                else:
                    keep = nms(cls_dets)
                    all_boxes[j][idx + delta] = cls_dets[keep, :]

            if max_per_image > 0:
                image_scores = np.hstack([
                    all_boxes[j][idx + delta][:, -1]
                    for j in range(1, imdb.num_classes)
                ])
                if len(image_scores) > max_per_image:
                    image_thresh = np.sort(image_scores)[-max_per_image]
                    for j in range(1, imdb.num_classes):
                        keep = np.where(
                            all_boxes[j][idx + delta][:, -1] >= image_thresh)[0]
                        all_boxes[j][idx + delta] = all_boxes[j][idx + delta][keep, :]

            if vis:
                boxes_this_image = [[]] + [
                    all_boxes[j][idx + delta]
                    for j in range(1, imdb.num_classes)
                ]
                vis_all_detection(data_dict['data'].asnumpy(),
                                  boxes_this_image, imdb.classes,
                                  scales[delta], cfg)

        idx += test_data.batch_size
        t3 = time.time() - t
        t = time.time()
        data_time += t1
        net_time += t2
        post_time += t3
        batch_size = test_data.batch_size
        # running mean time per batch for each stage
        progress = 'testing {}/{} data {:.4f}s net {:.4f}s post {:.4f}s'.format(
            idx, imdb.num_images,
            data_time / idx * batch_size,
            net_time / idx * batch_size,
            post_time / idx * batch_size)
        print(progress)
        if logger:
            logger.info(progress)

    with open(det_file, 'wb') as f:
        cPickle.dump(all_boxes, f, protocol=cPickle.HIGHEST_PROTOCOL)

    info_str = imdb.evaluate_detections(all_boxes)
    if logger:
        logger.info('evaluate detections: \n{}'.format(info_str))
Example #5
0
def pred_eval(predictor, test_data, imdb, cfg, vis=False, thresh=1e-3, logger=None, ignore_cache=True):
    """
    Detect on every batch of ``test_data`` and score the results with ``imdb``.

    For each image and foreground class, detections scoring above ``thresh``
    are passed through the soft-NMS wrapper, whose result is used as row
    indices into the class detections. They are then optionally capped at
    ``cfg.TEST.max_per_image`` detections across all classes. No detection
    cache is read or written.

    :param predictor: Predictor forwarded to ``im_detect``
    :param test_data: data iterator, must be non-shuffle unless ``vis`` is set
    :param imdb: image database providing num_images, num_classes, classes and
                 evaluate_detections
    :param cfg: config; reads TEST.NMS, TEST.max_per_image and CLASS_AGNOSTIC
    :param vis: when True, draw each image's detections with ``vis_all_detection``
    :param thresh: minimum (exclusive) class score for a detection to be kept
    :param logger: optional logger for progress and evaluation output
    :param ignore_cache: accepted for signature compatibility; unused here
    :return: None
    """
    assert vis or not test_data.shuffle
    input_names = [desc[0] for desc in test_data.provide_data[0]]

    if not isinstance(test_data, PrefetchingIter):
        test_data = PrefetchingIter(test_data)

    suppress = py_softnms_wrapper(cfg.TEST.NMS)

    # per-image detection budget over all classes (<= 0 disables the cap)
    cap = cfg.TEST.max_per_image

    n_images = imdb.num_images
    # all_boxes[cls_idx][image] -> array of rows (x1, y1, x2, y2, score)
    all_boxes = [[[] for _ in range(n_images)]
                 for _ in range(imdb.num_classes)]

    offset = 0
    data_time = net_time = post_time = 0.0
    tic = time.time()
    for im_info, data_batch in test_data:
        t_data = time.time() - tic
        tic = time.time()

        scales = [info[0, 2] for info in im_info]
        scores_all, boxes_all, data_dict_all = im_detect(predictor, data_batch, input_names, scales, cfg)

        t_net = time.time() - tic
        tic = time.time()
        for delta, (scores, boxes, data_dict) in enumerate(zip(scores_all, boxes_all, data_dict_all)):
            image_index = offset + delta
            for cls_idx in range(1, imdb.num_classes):
                picked = np.where(scores[:, cls_idx] > thresh)[0]
                picked_scores = scores[picked, cls_idx, np.newaxis]
                if cfg.CLASS_AGNOSTIC:
                    picked_boxes = boxes[picked, 4:8]
                else:
                    picked_boxes = boxes[picked, cls_idx * 4:(cls_idx + 1) * 4]
                dets = np.hstack((picked_boxes, picked_scores))
                survivors = suppress(dets).tolist()
                all_boxes[cls_idx][image_index] = dets[survivors, :]

            if cap > 0:
                pooled = np.hstack([all_boxes[cls_idx][image_index][:, -1]
                                    for cls_idx in range(1, imdb.num_classes)])
                if len(pooled) > cap:
                    cutoff = np.sort(pooled)[-cap]
                    for cls_idx in range(1, imdb.num_classes):
                        dets = all_boxes[cls_idx][image_index]
                        all_boxes[cls_idx][image_index] = dets[np.where(dets[:, -1] >= cutoff)[0], :]

            if vis:
                # slot 0 (background) is an empty placeholder
                per_class = [all_boxes[cls_idx][image_index] for cls_idx in range(1, imdb.num_classes)]
                vis_all_detection(data_dict['data'].asnumpy(), [[]] + per_class, imdb.classes, scales[delta], cfg)

        offset += test_data.batch_size
        t_post = time.time() - tic
        tic = time.time()
        data_time += t_data
        net_time += t_net
        post_time += t_post
        per_batch = test_data.batch_size
        message = 'testing {}/{} data {:.4f}s net {:.4f}s post {:.4f}s'.format(
            offset, imdb.num_images,
            data_time / offset * per_batch,
            net_time / offset * per_batch,
            post_time / offset * per_batch)
        print(message)
        if logger:
            logger.info(message)

    info_str = imdb.evaluate_detections(all_boxes)
    if logger:
        logger.info('evaluate detections: \n{}'.format(info_str))
def _ocr_numbered_crops(folder, prefix, count, tess_config=None):
    """Run Tesseract on the crops ``<folder>/<prefix>1.jpg`` .. ``<prefix><count>.jpg``.

    Each recognised string is printed as ``"<k>: <text>"``; empty results are
    reported and dropped.

    :param folder: directory holding the numbered crops
    :param prefix: crop file-name prefix (e.g. ``"item_name"``)
    :param count: number of crops written for the current image
    :param tess_config: optional extra Tesseract options (e.g. ``"--psm 13"``)
    :return: list of the non-empty OCR strings, in crop order
    """
    texts = []
    # Crops are numbered from 1 to ``count`` inclusive.
    for k in range(1, count + 1):
        crop = Image.open(os.path.join(folder, prefix + str(k) + ".jpg"))
        if tess_config:
            text = pytesseract.image_to_string(crop, config=tess_config)
        else:
            text = pytesseract.image_to_string(crop)
        print(str(k) + ": " + text)
        if text == "":
            print("empty and not relevant")
        else:
            texts.append(text)
    return texts


def _save_crop(output_base_dir, sub_dir, file_name, image):
    """Write ``image`` to ``<output_base_dir>/<sub_dir>/<file_name>``, creating the folder if needed."""
    folder = os.path.join(output_base_dir, sub_dir)
    if not os.path.exists(folder):
        os.makedirs(folder)
    cv2.imwrite(os.path.join(folder, file_name), image)


def _append_ocr_results(results_path, total_text, items, rows, prices):
    """Append one receipt's OCR output to ``results_path``.

    Writes the total first. Then, for every item, it writes the item followed by
    each price whose text occurs in a row that also contains the item. If no
    such row/price exists, it writes the item followed by a line holding a
    single space. Texts are compared after dropping non-ASCII characters.
    """
    with open(results_path, "a") as results_file:
        results_file.write(total_text + "\n")
        for item_text in items:
            item_ascii = item_text.encode('ascii', 'ignore')
            matched = False
            for row_text in rows:
                row_ascii = row_text.encode('ascii', 'ignore')
                if item_ascii not in row_ascii:
                    continue
                for price_text in prices:
                    if price_text.encode('ascii', 'ignore') in row_ascii:
                        results_file.write(item_text + "\n")
                        results_file.write(str(price_text) + "\n")
                        matched = True
            if not matched:
                # Product not found in any row: keep it, with an empty price.
                results_file.write(item_text + "\n")
                results_file.write(" " + "\n")


def _detect_all_scales(im, sym, arg_params, aux_params):
    """Run the detector on ``im`` at every scale in ``config.SCALES``.

    :param im: image as read by ``cv2.imread``
    :return: list with one entry per foreground class (class ``j`` at index
        ``j - 1``); each entry stacks the raw, pre-NMS
        ``[x1, y1, x2, y2, score]`` rows of all scales.
    """
    data_names = ['data', 'im_info']
    dets_per_class = [[] for _ in range(len(TOTAL_CLASSES) - 1)]
    for testScale in config.SCALES:
        target_size, max_size = testScale[0], testScale[1]
        im_resized, im_scale = resize(im, target_size, max_size,
                                      stride=config.network.IMAGE_STRIDE)
        im_tensor = transform(im_resized, config.network.PIXEL_MEANS)
        im_info = np.array(
            [[im_tensor.shape[2], im_tensor.shape[3], im_scale]],
            dtype=np.float32)
        inputs = {'data': im_tensor, 'im_info': im_info}
        data = [mx.nd.array(inputs[name]) for name in data_names]
        provide_data = [[(k, v.shape) for k, v in zip(data_names, data)]]

        # The predictor is rebuilt per scale because the input shapes change.
        predictor = Predictor(sym,
                              data_names,
                              [],
                              context=[mx.gpu(0)],
                              max_data_shapes=[[('data', (1, 3, target_size, max_size))]],
                              provide_data=provide_data,
                              provide_label=[None],
                              arg_params=arg_params,
                              aux_params=aux_params)
        data_batch = mx.io.DataBatch(data=[data],
                                     label=[],
                                     pad=0,
                                     index=0,
                                     provide_data=provide_data,
                                     provide_label=[None])
        scales = [batch[1].asnumpy()[0, 2] for batch in data_batch.data]
        scores, boxes, _ = im_detect(predictor, data_batch, data_names,
                                     scales, config)
        boxes = boxes[0].astype('f')
        scores = scores[0].astype('f')

        for j in range(1, scores.shape[1]):
            if config.CLASS_AGNOSTIC:
                cls_boxes = boxes[:, 4:8]
            else:
                cls_boxes = boxes[:, j * 4:(j + 1) * 4]
            cls_dets = np.hstack((cls_boxes, scores[:, j, np.newaxis]))
            if len(dets_per_class[j - 1]) == 0:
                dets_per_class[j - 1] = cls_dets
            else:
                # Concatenate the rows of all scales (``+=`` would add the
                # arrays element-wise).
                dets_per_class[j - 1] = np.vstack((dets_per_class[j - 1], cls_dets))
    return dets_per_class


def main():
    """Detect receipt fields in the listed test images and OCR them.

    For every image named in ``<dataBaseDir>/ImageSets/image.txt`` this:
      1. runs the detector at every scale in ``config.SCALES`` and merges the
         per-class detections of all scales;
      2. applies soft-NMS or NMS and drops detections scoring at or below
         ``CONFIDENCE_THRESHOLD``;
      3. draws the detections and saves per-class crops under ``outputBaseDir``;
      4. OCRs the item / price / row / total crops with Tesseract and appends
         the total and the matched item/price lines to ``results_path``;
      5. writes the XML output and, optionally, annotated images and
         Pascal VOC annotations.

    ``outputBaseDir`` is deleted and recreated on every run; ``results_path``
    is only ever appended to.
    """
    # get symbol
    pprint.pprint(config)
    config.symbol = 'resnet_v1_101_fpn_dcn_rcnn'
    sym_instance = eval(config.symbol + '.' + config.symbol)()
    sym = sym_instance.get_symbol(config, is_train=False)
    max_per_image = config.TEST.max_per_image

    # Print the test scales
    print("Train scales: %s" % str(config.SCALES))
    print("Test scales: %s" % str(config.TEST_SCALES))

    dataBaseDir = '/netscratch/queling/data/'
    outputBaseDir = '/netscratch/queling/Deformable/output/fpn/deep_receipt/results/' + EXPERIMENT_NAME
    results_path = "/netscratch/queling/Deformable/fpn/results.txt"

    if os.path.exists(outputBaseDir):
        shutil.rmtree(outputBaseDir)
    os.mkdir(outputBaseDir)

    detectionResultsPath = os.path.join(outputBaseDir, 'Detections')
    annotationResultsPath = os.path.join(outputBaseDir, 'Annotations')
    # outputBaseDir was just recreated, so none of these exist yet.
    for sub_dir in ('IncorrectDetections', 'Detections', 'Annotations'):
        os.mkdir(os.path.join(outputBaseDir, sub_dir))

    # test.txt for the whole dataset, image.txt for one image
    with open(os.path.join(dataBaseDir, 'ImageSets/image.txt'), 'r') as im_names_file:
        im_names = [line.strip() for line in im_names_file]

    # Neither the checkpoint nor the NMS settings depend on the image: set up once.
    arg_params, aux_params = load_param(MODEL_PATH, MODEL_EPOCH, process=True)
    if config.TEST.USE_SOFTNMS:
        soft_nms = py_softnms_wrapper(config.TEST.SOFTNMS_THRESH,
                                      max_dets=max_per_image)
    else:
        nms = py_nms_wrapper(config.TEST.NMS)

    outputFile = open(os.path.join(outputBaseDir, 'output.txt'), 'w')
    errorStatsFile = open(
        os.path.join(outputBaseDir, 'incorrect-detections.txt'), 'w')
    try:
        outputFile.write('<?xml version="1.0" encoding="UTF-8"?>\n')

        for im_name in im_names:
            if not im_name:
                continue  # ignore blank lines in the image list

            im_path = None
            for ext in IMAGE_EXTENSIONS:
                # 'Images' for the whole dataset, 'Test' for one image
                candidate = os.path.join(dataBaseDir, 'Test', im_name + ext)
                if os.path.exists(candidate):
                    im_name_with_ext = im_name + ext
                    im_path = candidate
                    break
            if im_path is None:
                print("Error: Unable to locate file %s" % (im_name))
                exit(-1)

            tic()

            # Read once, with the same flags for detection and for drawing and
            # cropping, so the boxes refer to the same pixel grid.
            clean_im = cv2.imread(im_path,
                                  cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION)
            if clean_im is None:
                print("Error: Unable to read image %s" % (im_path))
                exit(-1)

            dets_nms = _detect_all_scales(clean_im, sym, arg_params, aux_params)
            for cls_iter, cls_dets in enumerate(dets_nms):
                if config.TEST.USE_SOFTNMS:
                    cls_dets = soft_nms(cls_dets)
                else:
                    cls_dets = cls_dets[nms(cls_dets), :]
                dets_nms[cls_iter] = cls_dets[cls_dets[:, -1] > CONFIDENCE_THRESHOLD, :]

            print('Processing image: {} {:.4f}s'.format(im_name, toc()))

            # Boxes and labels are drawn on ``im``; crops come from ``clean_im``
            # so OCR never sees the annotations.
            im = clean_im.copy()
            item = price = row = 0
            asd = 0  # index used in the top-level crop file name (last counter touched)
            total_found = False

            for cls_idx, cls_name in enumerate(CONCERNED_ERRORS):
                for det in dets_nms[cls_idx]:
                    x1, y1, x2, y2 = int(det[0]), int(det[1]), int(det[2]), int(det[3])
                    cv2.rectangle(im, (x1, y1), (x2, y2), (0, 0, 255), 1)
                    w = det[2] - det[0]
                    cv2.putText(im, cls_name,
                                (int(det[0] + (w / 2.0) - 100), int(det[1] - 5)),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.3, (0, 255, 0), 1)

                    crop_im = clean_im[y1:y2, x1:x2]
                    if crop_im.size == 0:
                        continue  # degenerate box: nothing to save or OCR
                    blurred = cv2.medianBlur(
                        cv2.cvtColor(crop_im, cv2.COLOR_BGR2GRAY), 3)

                    if cls_name == "price":
                        price += 1
                        asd = price
                        _save_crop(outputBaseDir, "price",
                                   cls_name + str(asd) + ".jpg", crop_im)
                    elif cls_name == "item_name":
                        item += 1
                        asd = item
                        _save_crop(outputBaseDir, "item",
                                   cls_name + str(asd) + ".jpg", blurred)
                    elif cls_name == "row":
                        row += 1
                        asd = row
                        _save_crop(outputBaseDir, "row",
                                   cls_name + str(asd) + ".jpg", blurred)
                    elif cls_name == 'total_price':
                        print("Found Total")
                        total_found = True
                        _save_crop(outputBaseDir, "total", cls_name + ".jpg", blurred)
                    elif cls_name == 'header':
                        _save_crop(outputBaseDir, "header", cls_name + ".jpg", blurred)

                    # Every crop is also written, in colour, to the top-level folder.
                    cv2.imwrite(os.path.join(outputBaseDir, cls_name + str(asd) + ".jpg"),
                                crop_im)

            items = _ocr_numbered_crops(os.path.join(outputBaseDir, "item"),
                                        "item_name", item)
            print("-------------------------------------------------------------")
            prices = _ocr_numbered_crops(os.path.join(outputBaseDir, "price"),
                                         "price", price, tess_config="--psm 13")
            print("-------------------------------------------------------------")
            rows = _ocr_numbered_crops(os.path.join(outputBaseDir, "row"),
                                       "row", row)

            if total_found:
                total_text = pytesseract.image_to_string(
                    Image.open(os.path.join(outputBaseDir, "total", "total_price.jpg")))
            else:
                # No total on this receipt: keep one total line per receipt and
                # never OCR a stale crop left over from a previous image.
                total_text = ""
            _append_ocr_results(results_path, total_text, items, rows, prices)

            # Write the output in ICDAR Format
            outputFile.write(convertToXML(im_name_with_ext, dets_nms))

            if WRITE_DETECTION_RESULTS:
                outputImagePath = os.path.join(detectionResultsPath,
                                               im_name + ".jpg")
                print("Writing image: %s" % (outputImagePath))
                cv2.imwrite(outputImagePath, im)

            if WRITE_ANNOTATION_RESULTS:
                exportToPascalVOCFormat(im_name, im_path, dets_nms,
                                        annotationResultsPath)
    finally:
        outputFile.close()
        errorStatsFile.close()

    # TODO: precision / recall / F-measure are not computed yet (the
    # ground-truth comparison is disabled); only the empty stats file is created.
    with open(os.path.join(outputBaseDir,
                           'output-stats-' + EXPERIMENT_NAME + '.txt'), 'w'):
        pass
def pred_eval(predictor, test_data, imdb, cfg, vis=False, thresh=1e-3, logger=None, ignore_cache=True):
    """
    Multi-scale offline evaluation.

    The steps are:
      1. Run ``detect_at_single_scale`` once per entry of ``cfg.TEST_SCALES``.
      2. Cache each scale's detections in
         ``<result_path>/<name>_detections_<i>.pkl``.
      3. Merge the per-scale detections.
      4. Apply soft-NMS (if ``cfg.TEST.USE_SOFTNMS``) or NMS per class and image.
      5. Keep at most ``cfg.TEST.max_per_image`` detections per image.
      6. Cache the merged result in ``<result_path>/<name>_detections.pkl``
         and evaluate it.

    :param predictor: Predictor
    :param test_data: data iterator, must be non-shuffle (unless ``vis``)
    :param imdb: image database
    :param cfg: config; ``cfg.SCALES`` is set to each test scale in turn and
        restored to its original value afterwards
    :param vis: controls visualization
    :param thresh: valid detection threshold
    :param logger: optional logger that receives the evaluation report
    :param ignore_cache: if False and the merged detection file exists, it is
        evaluated directly without running detection
    :return: None
    """

    det_file = os.path.join(imdb.result_path, imdb.name + '_detections.pkl')
    if os.path.exists(det_file) and not ignore_cache:
        with open(det_file, 'rb') as fid:
            all_boxes = cPickle.load(fid)
        info_str = imdb.evaluate_detections(all_boxes)
        if logger:
            logger.info('evaluate detections: \n{}'.format(info_str))
        return

    assert vis or not test_data.shuffle
    data_names = [k[0] for k in test_data.provide_data[0]]

    if not isinstance(test_data, PrefetchingIter):
        test_data = PrefetchingIter(test_data)

    # limit detections to max_per_image over all classes
    max_per_image = cfg.TEST.max_per_image
    num_images = imdb.num_images
    num_classes = imdb.num_classes

    def _scale_file(scale_index):
        # Cache file for the detections of one test scale.
        return os.path.join(imdb.result_path,
                            imdb.name + '_detections_' + str(scale_index) + '.pkl')

    # Detect at every test scale. ``cfg.SCALES`` is switched to the current
    # scale (presumably read by the data loader when building batches — verify)
    # and restored afterwards so the caller's config is not left modified.
    original_scales = cfg.SCALES
    try:
        for test_scale_index, test_scale in enumerate(cfg.TEST_SCALES):
            cfg.SCALES = [test_scale]
            test_data.reset()

            # all_boxes_single_scale[cls][image] = N x 5 array (x1, y1, x2, y2, score)
            all_boxes_single_scale = [[[] for _ in range(num_images)]
                                      for _ in range(num_classes)]
            detect_at_single_scale(predictor, data_names, imdb, test_data, cfg,
                                   thresh, vis, all_boxes_single_scale, logger)

            with open(_scale_file(test_scale_index), 'wb') as f:
                cPickle.dump(all_boxes_single_scale, f,
                             protocol=cPickle.HIGHEST_PROTOCOL)
    finally:
        cfg.SCALES = original_scales

    # Merge all scales: all_boxes[cls][image] = N x 5 array (x1, y1, x2, y2, score)
    all_boxes = [[[] for _ in range(num_images)] for _ in range(num_classes)]
    for test_scale_index in range(len(cfg.TEST_SCALES)):
        det_file_single_scale = _scale_file(test_scale_index)
        if not os.path.exists(det_file_single_scale):
            continue
        with open(det_file_single_scale, 'rb') as fid:
            all_boxes_single_scale = cPickle.load(fid)
        for idx_class in range(1, num_classes):
            for idx_im in range(num_images):
                scale_dets = all_boxes_single_scale[idx_class][idx_im]
                if len(scale_dets) == 0:
                    continue  # nothing to add (may be an unfilled [] entry)
                merged = all_boxes[idx_class][idx_im]
                if len(merged) == 0:
                    all_boxes[idx_class][idx_im] = scale_dets
                else:
                    all_boxes[idx_class][idx_im] = np.vstack((merged, scale_dets))

    # The suppression settings are the same for every class and image.
    if cfg.TEST.USE_SOFTNMS:
        soft_nms = py_softnms_wrapper(cfg.TEST.SOFTNMS_THRESH, max_dets=max_per_image)
    else:
        nms = py_nms_wrapper(cfg.TEST.NMS)

    for idx_class in range(1, num_classes):
        for idx_im in range(num_images):
            dets = all_boxes[idx_class][idx_im]
            if len(dets) == 0:
                # Nothing to suppress. Store an empty (0, 5) array so the
                # per-image cap below can index ``[:, -1]`` uniformly.
                all_boxes[idx_class][idx_im] = np.zeros((0, 5), dtype=np.float32)
                continue
            if cfg.TEST.USE_SOFTNMS:
                all_boxes[idx_class][idx_im] = soft_nms(dets)
            else:
                all_boxes[idx_class][idx_im] = dets[nms(dets), :]

    if max_per_image > 0:
        for idx_im in range(num_images):
            image_scores = np.hstack([all_boxes[j][idx_im][:, -1]
                                      for j in range(1, num_classes)])
            if len(image_scores) > max_per_image:
                # Keep only detections scoring at least the max_per_image-th best.
                image_thresh = np.sort(image_scores)[-max_per_image]
                for j in range(1, num_classes):
                    keep = np.where(all_boxes[j][idx_im][:, -1] >= image_thresh)[0]
                    all_boxes[j][idx_im] = all_boxes[j][idx_im][keep, :]

    with open(det_file, 'wb') as f:
        cPickle.dump(all_boxes, f, protocol=cPickle.HIGHEST_PROTOCOL)

    info_str = imdb.evaluate_detections(all_boxes)
    if logger:
        logger.info('evaluate detections: \n{}'.format(info_str))