Example #1
    def test_decode(self):
        pyramid_shape = [(32, 32, 2, 2), (16, 16, 4, 4), (8, 8, 8, 8)]
        num_labels = 2
        img_height, img_width = (64, 64)

        exp_boxes = [
            torch.Tensor([[0, 0, 16, 16], [16, 16, 32, 32], [0, 0, 32, 32]]),
            torch.Tensor([
                [0, 0, 16, 16],
            ])
        ]

        exp_labels = [
            torch.Tensor([0, 0, 1]),
            torch.Tensor([1]),
        ]

        targets = []
        targets.append(
            encode_single_targets(exp_boxes[0], exp_labels[0], pyramid_shape,
                                  num_labels))
        targets.append(
            encode_single_targets(exp_boxes[1], exp_labels[1], pyramid_shape,
                                  num_labels))

        # Merge single targets into a batch.
        pyramid_sz = len(pyramid_shape)
        batch_sz = len(exp_boxes)
        batch_targets = []
        def prob2logit(prob):
            # Inverse of torch.sigmoid: the encoder produces
            # probabilities, but the decoder expects logits.
            return torch.log(prob / (1 - prob))

        for level_ind in range(pyramid_sz):
            reg_arrs, label_arrs, center_arrs = [], [], []
            for batch_ind in range(batch_sz):
                reg_arr, label_arr, center_arr = targets[batch_ind][level_ind]
                label_arr = prob2logit(label_arr)
                center_arr = prob2logit(center_arr)
                reg_arrs.append(reg_arr.unsqueeze(0))
                label_arrs.append(label_arr.unsqueeze(0))
                center_arrs.append(center_arr.unsqueeze(0))
            batch_targets.append((torch.cat(reg_arrs), torch.cat(label_arrs),
                                  torch.cat(center_arrs)))

        boxlists = decode_batch_output(batch_targets, pyramid_shape,
                                       img_height, img_width)
        for ind, boxlist in enumerate(boxlists):
            self.assertEqual(
                make_tuple_set(boxlist.boxes.int(),
                               boxlist.get_field('labels').int()),
                make_tuple_set(exp_boxes[ind], exp_labels[ind]))
            self.assertTrue(torch.all(boxlist.get_field('scores') == 1))
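
A quick sanity check on the prob2logit trick used above: it is the exact inverse of torch.sigmoid, which is presumably what decode_batch_output applies to the label and centerness channels, so round-tripping encoded probabilities through logits is lossless away from p = 0 and p = 1 (where the logit is +/-inf). A minimal, self-contained check:

import torch

p = torch.tensor([0.25, 0.5, 0.75])
logits = torch.log(p / (1 - p))  # prob2logit
assert torch.allclose(torch.sigmoid(logits), p)
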
Example #2
    def test_encode(self):
        pyramid_shape = [(32, 32, 2, 2), (16, 16, 4, 4), (8, 8, 8, 8)]
        num_labels = 2
        boxes = torch.Tensor([[0, 0, 16, 16], [16, 16, 32, 32],
                              [0, 0, 32, 32]])
        labels = torch.Tensor([0, 0, 1])
        targets = encode_single_targets(boxes, labels, pyramid_shape,
                                        num_labels)

        # stride 8
        reg_arr, label_arr, center_arr = targets[2]
        self.assertTrue(torch.all(reg_arr == 0))
        self.assertTrue(torch.all(label_arr == 0))
        self.assertTrue(reg_arr.shape == (4, 8, 8))
        self.assertTrue(label_arr.shape == (2, 8, 8))

        # stride 16
        reg_arr, label_arr, center_arr = targets[1]
        exp_reg_arr = torch.zeros((4, 4, 4))
        exp_reg_arr[:, 0, 0] = torch.Tensor([8, 8, 8, 8])
        exp_reg_arr[:, 1, 1] = torch.Tensor([8, 8, 8, 8])
        self.assertTrue(reg_arr.equal(exp_reg_arr))
        exp_label_arr = torch.zeros((2, 4, 4))
        exp_label_arr[0, 0, 0] = 1
        exp_label_arr[0, 1, 1] = 1
        self.assertTrue(label_arr.equal(exp_label_arr))

        # stride 32
        reg_arr, label_arr, center_arr = targets[0]
        exp_reg_arr = torch.zeros((4, 2, 2))
        exp_reg_arr[:, 0, 0] = torch.Tensor([16, 16, 16, 16])
        self.assertTrue(reg_arr.equal(exp_reg_arr))
        exp_label_arr = torch.zeros((2, 2, 2))
        exp_label_arr[1, 0, 0] = 1
        self.assertTrue(label_arr.equal(exp_label_arr))
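
The expected regression values above follow from the FCOS encoding: each grid cell whose center falls inside a box stores the distances from that center to the four box edges. A hedged sketch of the arithmetic; the (left, top, right, bottom) ordering follows the FCOS paper and is an assumption here, since all four distances coincide in these tests:

import torch

def fcos_reg_target(cx, cy, box):
    # Distances from a cell center (cx, cy) to the edges of
    # box = (x0, y0, x1, y1), in (left, top, right, bottom) order.
    x0, y0, x1, y1 = box
    return torch.Tensor([cx - x0, cy - y0, x1 - cx, y1 - cy])

# Stride 16: cell (0, 0) has center (8, 8), inside box (0, 0, 16, 16).
print(fcos_reg_target(8, 8, (0, 0, 16, 16)))    # tensor([8., 8., 8., 8.])
# Stride 32: cell (0, 0) has center (16, 16), inside box (0, 0, 32, 32).
print(fcos_reg_target(16, 16, (0, 0, 32, 32)))  # tensor([16., 16., 16., 16.])
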
Example #3
    def encode_decode_output(self, pyramid_shape, exp_boxes, exp_labels):
        score_thresh = 0.2
        num_labels = 2

        targets = encode_single_targets(exp_boxes, exp_labels, pyramid_shape,
                                        num_labels)
        boxlist = decode_single_output(targets,
                                       pyramid_shape,
                                       score_thresh=score_thresh)

        self.assertEqual(
            make_tuple_set(boxlist.boxes.int(),
                           boxlist.get_field('labels').int()),
            make_tuple_set(exp_boxes, exp_labels))
        self.assertTrue(torch.all(boxlist.get_field('scores') == 1))
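
The scores == 1 assertions in these round-trip tests depend on every encoded location sitting exactly at a box center. In the FCOS paper the centerness target is computed from the regression distances as below; it equals 1 precisely when l == r and t == b, as for all boxes used here. Whether encode_single_targets uses exactly this formula is an assumption:

import math

def centerness(l, t, r, b):
    # Centerness from the FCOS paper: 1.0 at the box center,
    # decaying toward 0 at the box edges.
    return math.sqrt((min(l, r) / max(l, r)) * (min(t, b) / max(t, b)))

print(centerness(8, 8, 8, 8))   # 1.0
print(centerness(4, 8, 12, 8))  # sqrt(4/12 * 8/8) ~= 0.577
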
Example #4
def fcos_batch_loss(out, targets, pyramid_shape, num_labels):
    """Compute loss for a single image.

    Note: the label_arr and center_arr for output is assumed to contain
    logits, and is assumed to contain probabilities for targets.

    Args:
        out: the output of the heads for the whole pyramid
        targets: list<BoxList> of length n

        the format of out is a list of tuples where each tuple corresponds to a
        pyramid level. tuple is of form (reg_arr, label_arr, center_arr) where
            - reg_arr is tensor<n, 4, h, w>,
            - label_arr is tensor<n, num_labels, h, w>
            - center_arr is tensor<n, 1, h, w>

        and label_arr and center_arr values are logits.

    Returns:
        dict of form {
            'reg_loss': tensor<1>,
            'label_loss': tensor<1>,
            'center_loss': tensor<1>
        }
    """
    iou_loss = IOULoss()
    batch_sz = len(targets)
    reg_arrs, label_arrs, center_arrs = [], [], []
    for single_targets in targets:
        single_targets = encode_single_targets(
            single_targets.boxes, single_targets.get_field('labels'),
            pyramid_shape, num_labels)
        for reg_arr, label_arr, center_arr in single_targets:
            # (4, H, W) -> (H, W, 4) -> (H*W, 4)
            reg_arrs.append(reg_arr.permute((1, 2, 0)).reshape((-1, 4)))
            # (C, H, W) -> (H, W, C) -> (H*W, C)
            label_arrs.append(
                label_arr.permute((1, 2, 0)).reshape((-1, num_labels)))
            # (1, H, W) -> (H, W, 1) -> (H*W,)
            center_arrs.append(center_arr.permute((1, 2, 0)).reshape((-1, )))

    targets_reg_arr, targets_label_arr, targets_center_arr = (
        torch.cat(reg_arrs), torch.cat(label_arrs), torch.cat(center_arrs))
    out_reg_arr, out_label_arr, out_center_arr = flatten_output(out)

    pos_indicator = targets_label_arr.sum(1) > 0.0
    out_reg_arr = out_reg_arr[pos_indicator, :]
    targets_reg_arr = targets_reg_arr[pos_indicator, :]
    out_center_arr = out_center_arr[pos_indicator]
    targets_center_arr = targets_center_arr[pos_indicator]

    # The +1 avoids division by zero when there are no positive locations.
    npos = targets_reg_arr.shape[0] + 1
    label_loss = focal_loss(out_label_arr, targets_label_arr) / npos
    reg_loss = torch.tensor(0.0, device=label_loss.device)
    center_loss = torch.tensor(0.0, device=label_loss.device)
    if npos > 1:
        reg_loss = iou_loss(out_reg_arr, targets_reg_arr, targets_center_arr)
        center_loss = nn.functional.binary_cross_entropy_with_logits(
            out_center_arr, targets_center_arr, reduction='mean')

    total_loss = label_loss + reg_loss + center_loss
    loss_dict = {
        'total_loss': total_loss,
        'label_loss': label_loss,
        'reg_loss': reg_loss,
        'center_loss': center_loss
    }
    return loss_dict
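
focal_loss and IOULoss are called above but not shown. As a rough sketch only: a sigmoid focal loss consistent with the call site (logits in out_label_arr, probabilities in targets_label_arr, a summed loss that the caller divides by npos). The alpha/gamma defaults follow the RetinaNet paper and are assumptions, not necessarily what this codebase uses:

import torch
import torch.nn.functional as F

def focal_loss(out, targets, alpha=0.25, gamma=2.0):
    # Sigmoid focal loss (Lin et al. 2017). `out` holds logits,
    # `targets` holds per-class probabilities in [0, 1]; both are
    # (num_positions, num_labels) after flattening.
    ce = F.binary_cross_entropy_with_logits(out, targets, reduction='none')
    p = torch.sigmoid(out)
    p_t = p * targets + (1 - p) * (1 - targets)
    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
    return (alpha_t * (1 - p_t) ** gamma * ce).sum()
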
Example #5
    def make_debug_plots(self,
                         dataset,
                         model,
                         classes,
                         output_dir,
                         max_plots=25,
                         score_thresh=0.25):
        preds_dir = join(output_dir, 'preds')
        zip_path = join(output_dir, 'preds.zip')
        make_dir(preds_dir, force_empty=True)

        model.eval()
        for img_id, (x, y) in enumerate(dataset):
            if img_id == max_plots:
                break

            # Get predictions
            boxlist, head_out = self.get_pred(x, model, score_thresh)

            # Plot image, ground truth, and predictions
            fig = self.plot_image_preds(x, y, boxlist, classes)
            plt.savefig(join(preds_dir, '{}.png'.format(img_id)),
                        dpi=200,
                        bbox_inches='tight')
            plt.close(fig)

            # Plot raw output of network at each level.
            for level, level_out in enumerate(head_out):
                stride = model.fpn.strides[level]
                reg_arr, label_arr, center_arr = level_out

                # Plot label_arr
                label_arr = label_arr[0].detach().cpu()
                label_probs = torch.sigmoid(label_arr)
                fig = self.plot_label_arr(label_probs, classes, stride)
                plt.savefig(join(preds_dir,
                                 '{}-{}-label-arr.png'.format(img_id, stride)),
                            dpi=100,
                            bbox_inches='tight')
                plt.close(fig)

                # Plot top, left, bottom, right from reg_arr and center_arr.
                reg_arr = reg_arr[0].detach().cpu()
                center_arr = center_arr[0][0].detach().cpu()
                center_probs = torch.sigmoid(center_arr)
                fig = plot_reg_center_arr(reg_arr, center_probs, stride)
                plt.savefig(join(
                    preds_dir,
                    '{}-{}-reg-center-arr.png'.format(img_id, stride)),
                            dpi=100,
                            bbox_inches='tight')
                plt.close(fig)

            # Get encoding of ground truth targets.
            h, w = x.shape[1:]
            targets = encode_single_targets(y.boxes, y.get_field('labels'),
                                            model.pyramid_shape,
                                            model.num_labels)

            # Plot encoding of ground truth at each level.
            for level, level_targets in enumerate(targets):
                stride = model.fpn.strides[level]
                reg_arr, label_arr, center_arr = level_targets

                # Plot label_arr
                label_probs = label_arr.detach().cpu()
                fig = self.plot_label_arr(label_probs, classes, stride)
                plt.savefig(join(
                    preds_dir, '{}-{}-label-arr-gt.png'.format(img_id,
                                                               stride)),
                            dpi=100,
                            bbox_inches='tight')
                plt.close(fig)

                # Plot top, left, bottom, right from reg_arr and center_arr.
                reg_arr = reg_arr.detach().cpu()
                # Ground truth centerness is already a probability, so no
                # sigmoid is applied here (unlike the raw head output above).
                center_probs = center_arr[0].detach().cpu()
                fig = plot_reg_center_arr(reg_arr, center_probs, stride)
                plt.savefig(join(
                    preds_dir,
                    '{}-{}-reg-center-arr-gt.png'.format(img_id, stride)),
                            dpi=100,
                            bbox_inches='tight')
                plt.close(fig)

        zipdir(preds_dir, zip_path)
        shutil.rmtree(preds_dir)
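
plot_label_arr and plot_reg_center_arr are project helpers whose definitions are not shown. For orientation only, a hypothetical stand-in that renders a (num_labels, h, w) probability tensor as one heatmap per class; the real helpers may differ:

import matplotlib.pyplot as plt

def plot_label_probs(label_probs, classes, stride):
    # One heatmap per class, probabilities in [0, 1], at the
    # resolution of a single pyramid level.
    fig, axes = plt.subplots(1, len(classes), squeeze=False)
    for ax, name, probs in zip(axes[0], classes, label_probs):
        ax.imshow(probs, vmin=0, vmax=1)
        ax.set_title('{} (stride {})'.format(name, stride))
        ax.axis('off')
    return fig
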