Exemplo n.º 1
0
def count_dots_lbl(lbl):
    """Return how many '1' positions the 0/1 string of integer label ``lbl`` has.

    The string comes from ``label_tools.int_to_label010``. Every position is
    asserted to be either '0' or '1'.
    """
    raised = 0
    for dot in label_tools.int_to_label010(lbl):
        assert dot in ('0', '1')
        if dot == '1':
            raised += 1
    return raised
Exemplo n.º 2
0
    def detection_collate(batch):
        '''
        Collate dataset samples into one detector batch with encoded targets.

        Adapted from RetinaNet's collate function, but a) accepts rects as input and
        b) returns (x, y) where y holds the encoded targets.

        Relies on names from the enclosing scope: ``device``, ``params``, ``encoder``,
        ``use_multiple_class_groups``, ``num_classes`` and ``label_tools``.

        :param batch: list of samples ``(image, rects, extra_params[, original_image])``:
            ``image`` -- CHW float tensor, copied into a ``(3, h, w)`` slot of the batch
                (``h, w`` from ``params.data.net_hw``);
            ``rects`` -- 2-D array with rows ``(left, top, right, bottom, class)``,
                coords in [0, 1];
            ``extra_params`` -- dict; its ``'calc_cls'`` entry (default True) fills
                ``calc_cls_mask``;
            ``original_image`` -- optional; present when not in train mode.
        :return: ``(images, (loc_targets, cls_targets, calc_cls_mask))`` where ``images``
            is a BCHW tensor and the targets are stacked over the batch. When the samples
            carry original images, they are appended as a third element:
            ``(images, (loc_targets, cls_targets, calc_cls_mask), original_images)``.
        '''
        # ``device`` deliberately comes from the enclosing scope (settings.device).

        # net_hw is (h, w); reversing and repeating it gives (w, h, w, h), which maps
        # [0, 1] coords to pixels (assumes net_hw is a list/tuple so ``* 2`` repeats it).
        # Built once per batch instead of once per sample.
        scale = torch.tensor(params.data.net_hw[::-1] * 2, dtype=torch.float32, device=device)
        boxes = [torch.tensor(b[1][:, :4], dtype=torch.float32, device=device) * scale
                 for b in batch]
        labels = [torch.tensor(b[1][:, 4], dtype=torch.long, device=device) for b in batch]
        if params.data.get_points:
            labels = [torch.tensor([0] * len(lb), dtype=torch.long, device=device) for lb in labels]
        elif use_multiple_class_groups:
            # Classes are numbered from 0 and "no class" is -1 ('1' -> 0, '0' -> -1);
            # encode() later uses cls_targets = 1 + labels.
            labels = [torch.tensor([[int(ch) - 1 for ch in label_tools.int_to_label010(int_lbl.item())]
                                    for int_lbl in lb],
                                   dtype=torch.long, device=device) for lb in labels]

        # Samples carry the un-augmented image as a 4th element when not in train mode.
        original_images = [b[3] for b in batch if len(b) > 3]

        imgs = [x[0] for x in batch]
        calc_cls_mask = torch.tensor([b[2].get('calc_cls', True) for b in batch],
                                     dtype=torch.bool,
                                     device=device)

        h, w = tuple(params.data.net_hw)
        num_imgs = len(batch)
        inputs = torch.zeros(num_imgs, 3, h, w).to(imgs[0])

        loc_targets = []
        cls_targets = []
        for i, (img, img_boxes, labels_i) in enumerate(zip(imgs, boxes, labels)):
            inputs[i] = img
            if use_multiple_class_groups and len(labels_i.shape) != 2:
                # Can happen when the image has no labels: give the empty tensor the
                # 2-D (0, number of class groups) shape encode() expects.
                labels_i = labels_i.reshape((0, len(num_classes)))
            loc_target, cls_target, _max_ious = encoder.encode(img_boxes, labels_i, input_size=(w, h))
            loc_targets.append(loc_target)
            cls_targets.append(cls_target)
        if original_images:  # inference mode
            return inputs, (torch.stack(loc_targets), torch.stack(cls_targets), calc_cls_mask), original_images
        return inputs, (torch.stack(loc_targets), torch.stack(cls_targets), calc_cls_mask)
Exemplo n.º 3
0
def dot_metrics_rects(boxes, labels, gt_rects, image_wh, img,
                      do_filter_lonely_rects):
    """Compute per-dot (tp, fp, fn) counts of recognized rects against ground truth.

    A recognized rect and a ground-truth rect are matched when each is the other's
    highest-IoU partner and their IoU is positive. For a matched pair the labels'
    0/1 strings are compared position by position. An unmatched ground-truth rect
    adds all of its '1' positions to ``fn``, and an unmatched recognized rect adds
    all of its '1' positions to ``fp``.

    :param boxes: recognized rects ``(left, top, right, bottom)``, in the same units
        as ``gt_rects`` coords after scaling by ``image_wh``.
    :param labels: integer label of each recognized rect, parallel to ``boxes``.
    :param gt_rects: ground-truth rects ``(left, top, right, bottom, label, ...)``;
        coords are multiplied by ``image_wh`` (presumably normalized to [0, 1]).
    :param image_wh: ``(width, height)`` used to scale ``gt_rects`` coords.
    :param img: image passed through to ``filter_lonely_rects``.
    :param do_filter_lonely_rects: if true, ``filter_lonely_rects`` is applied first.
    :return: ``(tp, fp, fn)`` counted over individual dot positions.
    """
    if do_filter_lonely_rects:
        boxes, labels = filter_lonely_rects(boxes, labels, img)
    gt_labels = [r[4] for r in gt_rects]
    gt_rec_labels = [-1] * len(gt_rects)  # recognized label for each gt rect, -1 = missed
    rec_is_false = [1] * len(labels)  # 1 = recognized rect matched no gt rect

    if len(gt_rects) and len(labels):
        # as_tensor avoids a copy (and torch's warning) when boxes is already a tensor.
        boxes = torch.as_tensor(boxes)
        gt_boxes = torch.tensor(
            [r[:4] for r in gt_rects], dtype=torch.float32) * torch.tensor(
                [image_wh[0], image_wh[1], image_wh[0], image_wh[1]])

        # Pairwise IoU matrix: rows = gt rects, columns = recognized rects.
        areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
        gt_areas = (gt_boxes[:, 2] - gt_boxes[:, 0]) * (gt_boxes[:, 3] -
                                                        gt_boxes[:, 1])
        x1 = torch.max(gt_boxes[:, 0].unsqueeze(1), boxes[:, 0].unsqueeze(0))
        y1 = torch.max(gt_boxes[:, 1].unsqueeze(1), boxes[:, 1].unsqueeze(0))
        x2 = torch.min(gt_boxes[:, 2].unsqueeze(1), boxes[:, 2].unsqueeze(0))
        y2 = torch.min(gt_boxes[:, 3].unsqueeze(1), boxes[:, 3].unsqueeze(0))
        intersect_area = (x2 - x1).clamp(min=0) * (y2 - y1).clamp(min=0)
        iou = intersect_area / (gt_areas.unsqueeze(1) + areas.unsqueeze(0) -
                                intersect_area)

        # Best partner in each direction, computed once; a pair is accepted only
        # when the choice is mutual and the overlap is non-empty.
        best_rec_for_gt = iou.argmax(dim=1).tolist()
        best_gt_for_rec = iou.argmax(dim=0).tolist()
        for gt_i, rec_i in enumerate(best_rec_for_gt):
            if iou[gt_i, rec_i] > 0 and best_gt_for_rec[rec_i] == gt_i:
                gt_rec_labels[gt_i] = labels[rec_i]
                rec_is_false[rec_i] = 0

    tp = fp = fn = 0
    for gt_label, rec_label in zip(gt_labels, gt_rec_labels):
        if rec_label == -1:
            # Missed gt rect: all of its dots are false negatives.
            fn += count_dots_lbl(gt_label)
            continue
        res010 = label_tools.int_to_label010(rec_label)
        gt010 = label_tools.int_to_label010(gt_label)
        # Compare every position rather than a fixed count, consistent with
        # count_dots_lbl, which also counts every position.
        for res_dot, gt_dot in zip(res010, gt010):
            if res_dot == '1' and gt_dot == '0':
                fp += 1
            elif res_dot == '0' and gt_dot == '1':
                fn += 1
            elif res_dot == '1' and gt_dot == '1':
                tp += 1
    for label, is_false in zip(labels, rec_is_false):
        if is_false:
            # Unmatched recognized rect: all of its dots are false positives.
            fp += count_dots_lbl(label)
    return tp, fp, fn
Exemplo n.º 4
0
def pseudo_char_to_label010(ch):
    """Map pseudo-character ``ch`` to the 0/1 string of its integer label.

    The integer label is the code-point offset of ``ch`` from '0'. It is checked with
    ``label_tools.validate_int`` (presumably raising on invalid labels) before conversion.
    """
    int_label = ord(ch) - ord('0')
    label_tools.validate_int(int_label)
    return label_tools.int_to_label010(int_label)