Example #1
File: train.py Project: lewfish/mlx
def loss_batch(model:nn.Module, xb:Tensor, yb:Tensor, loss_func:OptLossFunc=None,
               opt:OptOptimizer=None,
               cb_handler:Optional[CallbackHandler]=None)->Tuple[Union[Tensor,int,float,str]]:
    "Calculate loss and metrics for a batch, call out to callbacks as necessary."
    cb_handler = ifnone(cb_handler, CallbackHandler())
    device = xb.device
    # Translate from fastai box format to torchvision.
    batch_sz = len(xb)
    images = xb
    targets = []
    for i in range(batch_sz):
        boxes = yb[0][i]
        labels = yb[1][i]
        boxes = to_box_pixel(boxes, *images[0].shape[1:3])
        targets.append(BoxList(boxes, labels=labels))

    out = None
    loss = torch.Tensor([0.0]).to(device=device)
    if model.training:
        loss_dict = model(images, targets)
        loss = loss_dict['total_loss']
        cb_handler.state_dict['loss_dict'] = loss_dict
    else:
        out = model(images)

    out = cb_handler.on_loss_begin(out)

    if opt is not None:
        loss,skip_bwd = cb_handler.on_backward_begin(loss)
        if not skip_bwd:
            loss.backward()
        if not cb_handler.on_backward_end():
            opt.step()
        if not cb_handler.on_step_end():
            opt.zero_grad()

    return loss.detach().cpu()
Example #2
    def test_encode_decode(self):
        height, width = 256, 256
        stride = 4
        num_labels = 2
        boxes = torch.tensor([
            [0., 0., 128., 128.],
            [128., 128., 192., 192.],
            [64., 64., 128., 128.],
        ])
        # Shift by two since only coordinates that fall at the center of a
        # stride cell can be represented exactly.
        boxes += 2
        labels = torch.tensor([0, 0, 1])
        bl = BoxList(boxes, labels=labels)
        boxlists = [bl]
        positions = get_positions(height, width, stride, boxes.device)

        keypoint, reg = encode(boxlists, positions, stride, num_labels)
        path = '/opt/data/pascal2007/encoded.png'
        plot_encoded(boxlists[0], stride, keypoint[0], reg[0], path)

        decoded_boxlists = decode(
            keypoint, reg, positions, stride, prob_thresh=0.05)
        path = '/opt/data/pascal2007/decoded.png'
        plot_encoded(decoded_boxlists[0], stride, keypoint[0], reg[0], path)

        del decoded_boxlists[0].extras['scores']
        self.assertTrue(boxlists[0].equal(decoded_boxlists[0]))
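A standalone sketch (plain torch, no project imports) of why the test shifts the boxes by two: with a stride of 4, decoded coordinates can only land at cell centers, i.e. at stride//2 + k*stride.

import torch

stride = 4
# Representable cell-center coordinates at this stride: 2, 6, 10, ...
# A coordinate like 0 cannot round-trip exactly, but 0 + 2 can.
centers = torch.arange(stride // 2, stride * 8, stride)
print(centers)  # tensor([ 2,  6, 10, 14, 18, 22, 26, 30])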
Example #3
    def forward(self, input, targets=None):
        """Forward pass

        Args:
            input: tensor<n, 3, h, w> with batch of images
            targets: None or list<BoxList> of length n with boxes and labels

        Returns:
            if targets is None, returns list<BoxList> of length n, containing
            boxes, labels, and scores for boxes with score > 0.05. Further
            filtering based on score should be done before considering the
            prediction "final".

            if targets is a list, returns the losses as a dict mapping each
            component loss name returned by the wrapped model to a scalar
            tensor, plus 'total_loss', the sum of all the components.
        """
        if targets:
            _targets = [bl.xyxy() for bl in targets]
            _targets = [{
                'boxes': bl.boxes,
                'labels': bl.get_field('labels')
            } for bl in _targets]
            loss_dict = self.model(input, _targets)
            loss_dict['total_loss'] = sum(list(loss_dict.values()))
            return loss_dict

        out = self.model(input)
        return [
            BoxList(_out['boxes'],
                    labels=_out['labels'],
                    scores=_out['scores']).yxyx() for _out in out
        ]
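The coordinate swap behind xyxy()/yxyx() can be sketched with plain tensor indexing; the (ymin, xmin, ymax, xmax) storage order is an assumption inferred from how boxes are assembled in data.py (Example #4).

import torch

# yxyx -> xyxy: swap the y and x columns (the same index swaps back).
yxyx = torch.tensor([[10., 20., 50., 60.]])
xyxy = yxyx[:, [1, 0, 3, 2]]
print(xyxy)  # tensor([[20., 10., 60., 50.]])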
Example #4
File: data.py Project: lewfish/mlx
    def __getitem__(self, ind):
        img_fn = self.imgs[ind]
        img_id = self.img2id[img_fn]
        img = np.array(Image.open(join(self.img_dir, img_fn)))
        boxes, labels = self.id2boxes[img_id], self.id2labels[img_id]
        if self.transforms:
            out = self.transforms(image=img, bboxes=boxes, labels=labels)
            img = out['image']
            boxes = torch.tensor(out['bboxes'])
            labels = torch.tensor(out['labels'])

        if len(boxes) > 0:
            x, y, w, h = boxes[:, 0:1], boxes[:, 1:2], boxes[:, 2:3], boxes[:, 3:4]
            boxes = torch.cat([y, x, y + h, x + w], dim=1)
            boxlist = BoxList(boxes, labels=labels)
        else:
            boxlist = BoxList(torch.empty((0, 4)), labels=torch.empty((0, )))
        return (img, boxlist)
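A standalone check of the conversion above from (x, y, w, h) boxes to the (ymin, xmin, ymax, xmax) order handed to BoxList:

import torch

boxes = torch.tensor([[10., 20., 30., 40.]])  # x=10, y=20, w=30, h=40
x, y, w, h = boxes[:, 0:1], boxes[:, 1:2], boxes[:, 2:3], boxes[:, 3:4]
print(torch.cat([y, x, y + h, x + w], dim=1))  # tensor([[20., 10., 60., 40.]])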
Example #5
    def test_fcos_with_targets(self):
        h, w = 64, 64
        num_labels = 3
        x = torch.empty((1, 3, h, w))
        model = FCOS('resnet18', num_labels, pretrained=False)

        boxes = torch.tensor([[0, 0, 16, 16], [8, 8, 12, 12]])
        labels = torch.tensor([0, 1])
        targets = [BoxList(boxes, labels=labels)]

        loss_dict = model(x, targets)
        self.assertTrue('label_loss' in loss_dict)
        self.assertTrue('reg_loss' in loss_dict)
        self.assertTrue('center_loss' in loss_dict)
Example #6
    def test_backwards(self):
        h, w = 64, 64
        num_labels = 3
        x = 2.0 * torch.rand((1, 3, h, w)) - 1.0
        model = FCOS('resnet18', num_labels, pretrained=False)

        boxes = torch.tensor([[0, 0, 16, 16], [16, 16, 32, 32]])
        labels = torch.tensor([0, 1])
        targets = [BoxList(boxes, labels=labels)]

        model.train()
        model.zero_grad()
        loss_dict = model(x, targets)
        loss = sum(list(loss_dict.values()))
        loss.backward()

        for param in model.parameters():
            self.assertTrue(len(torch.nonzero(param.grad)) > 0)
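The same gradient sanity check, sketched on a stand-in nn.Linear so it runs without the FCOS dependency:

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
loss = model(torch.rand(8, 4)).sum()
loss.backward()
for param in model.parameters():
    # Every parameter should receive at least one nonzero gradient.
    assert len(torch.nonzero(param.grad)) > 0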
Example #7
def decode_batch_output(output, pyramid_shape, img_height, img_width,
                        iou_thresh=0.5):
    """Decode output for batch of images.

    Args:
        output: list of tuples where each tuple corresponds to a pyramid level
            tuple is of form (reg_arr, label_arr, center_arr) where
                - reg_arr is tensor<n, 4, h, w>,
                - label_arr is tensor<n, num_labels, h, w>
                - center_arr is tensor<n, 1, h, w>
            and label_arr and center_arr are logits
        pyramid_shape: list with one entry per pyramid level, where the first
            element of each entry is the stride of that level
        img_height: (int) height of the input image in pixels, used to clamp
            boxes
        img_width: (int) width of the input image in pixels, used to clamp
            boxes
        iou_thresh: (float) iou threshold passed to NMS

    Returns:
        list of n BoxLists
    """
    boxlists = []
    batch_sz = output[0][0].shape[0]
    for i in range(batch_sz):
        single_head_out = []
        for level, (reg_arr, label_arr, center_arr) in enumerate(output):
            # Convert logits in label_arr and center_arr to probabilities.
            single_head_out.append((
                reg_arr[i],
                torch.sigmoid(label_arr[i]),
                torch.sigmoid(center_arr[i])))
        boxlist = decode_single_output(single_head_out, pyramid_shape)
        boxlist = BoxList(
            boxlist.boxes, labels=boxlist.get_field('labels'),
            scores=boxlist.get_field('scores') * boxlist.get_field('centerness'),
            centerness=boxlist.get_field('centerness'))
        boxlist = boxlist.clamp(img_height, img_width)
        boxlist = boxlist.nms(iou_thresh=iou_thresh)
        boxlists.append(boxlist)
    return boxlists
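The BoxList construction above re-weights classification scores by centerness, the FCOS trick for down-ranking off-center boxes before NMS. A toy illustration:

import torch

scores = torch.tensor([0.9, 0.8])
centerness = torch.tensor([1.0, 0.1])
# The second box has a high class score but low centerness, so it is
# down-weighted before NMS.
print(scores * centerness)  # tensor([0.9000, 0.0800])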
Example #8
def decode_single_output(output, pyramid_shape, score_thresh=0.05):
    """Decode output of heads for all levels of pyramid for one image.

    Args:
        output: list of tuples where each tuple corresponds to a pyramid level
            tuple is of form (reg_arr, label_arr, center_arr) where
                - reg_arr is tensor<4, h, w>,
                - label_arr is tensor<num_labels, h, w>
                - center_arr is tensor<1, h, w>
            and label_arr and center_arr are probabilities
        score_thresh: (float) probability score threshold used to determine
            if a box is present at a cell

    Returns:
        BoxList
    """
    boxlists = []
    for level, level_out in enumerate(output):
        stride = pyramid_shape[level][0]
        boxlist = decode_level_output(
            *level_out, stride, score_thresh=score_thresh)
        boxlists.append(boxlist)
    return BoxList.cat(boxlists)
Example #9
def decode_level_output(reg_arr, label_arr, center_arr, stride, score_thresh=0.05):
    """Decode output of head for one level of the pyramid for one image.

    Args:
        reg_arr: (tensor) with shape (4, h, w). The first dimension ranges over
            t, l, b, r (ie. top, left, bottom, right).
        label_arr: (tensor) with shape (num_labels, h, w) containing
            probabilities
        center_arr: (tensor) with shape (1, h, w) containing values between
            0 and 1
        stride: (int) the stride of the level of the pyramid
        score_thresh: (float) probability score threshold used to determine
            if a box is present at a cell

    Returns:
        BoxList
    """
    device = reg_arr.device
    h, w = reg_arr.shape[1:]
    pos_arr = torch.empty((2, h, w), device=device)
    pos_arr[0, :, :] = torch.arange(
        stride//2, stride * h, stride, device=device)[:, None]
    pos_arr[1, :, :] = torch.arange(
        stride//2, stride * w, stride, device=device)[None, :]

    boxes = torch.empty((4, h, w), device=device)
    boxes[0:2, :, :] = pos_arr - reg_arr[0:2, :, :]
    boxes[2:, :, :] = pos_arr + reg_arr[2:, :, :]

    scores, labels = torch.max(label_arr, dim=0)

    boxes = boxes.reshape(4, -1).transpose(1, 0)
    labels = labels.reshape(-1)
    scores = scores.reshape(-1)
    centerness = center_arr.reshape(-1)
    boxlist = BoxList(
        boxes, labels=labels, scores=scores, centerness=centerness)
    return boxlist.score_filter(score_thresh)
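A standalone check of the grid and box arithmetic above, using stride 4 and a 2x2 level; all names are local to the sketch:

import torch

stride, h, w = 4, 2, 2
pos = torch.empty((2, h, w))
pos[0] = torch.arange(stride // 2, stride * h, stride)[:, None]
pos[1] = torch.arange(stride // 2, stride * w, stride)[None, :]
reg = torch.ones((4, h, w))  # t = l = b = r = 1 at every cell
boxes = torch.empty((4, h, w))
boxes[0:2] = pos - reg[0:2]
boxes[2:] = pos + reg[2:]
print(boxes[:, 0, 0])  # tensor([1., 1., 3., 3.]): a 2x2 box centered at (2, 2)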
Example #10
def decode(keypoint, reg, positions, stride, cfg, prob_thresh=0.05):
    N = keypoint.shape[0]
    boxlists = []
    flat_positions = positions.permute((1, 2, 0)).reshape((-1, 2))
    img_height = positions.shape[1] * stride
    img_width = positions.shape[2] * stride

    for n in range(N):
        per_keypoint = keypoint[n]
        per_reg = reg[n]
        num_labels = per_keypoint.shape[0]

        is_over_thresh = per_keypoint > prob_thresh
        is_local_max = torch.ones_like(is_over_thresh)
        if cfg.model.centernet.max_pool_nms:
            is_local_max = per_keypoint == F.max_pool2d(
                per_keypoint, kernel_size=3, stride=1, padding=1)
        is_pos = is_local_max * is_over_thresh
        num_pos = is_pos.sum()

        if num_pos == 0:
            bl = BoxList(
                torch.empty((0, 4)), labels=torch.empty((0,)),
                scores=torch.empty((0,)))
        else:
            flat_is_pos, _ = is_pos.permute((1, 2, 0)).reshape((-1, num_labels)).max(1)
            flat_per_reg = per_reg.permute((1, 2, 0)).reshape((-1, 2))
            flat_per_reg = flat_per_reg[flat_is_pos, :]
            sizes = flat_per_reg
            centers = flat_positions[flat_is_pos]
            boxes = torch.cat([centers - sizes / 2, centers + sizes / 2], dim=1)

            flat_per_keypoint = per_keypoint.permute((1, 2, 0)).reshape((-1, num_labels))
            flat_per_keypoint = flat_per_keypoint[flat_is_pos, :]
            scores, labels = flat_per_keypoint.max(1)

            bl = BoxList(boxes, labels=labels, scores=scores)
            bl = bl.clamp(img_height, img_width)
        boxlists.append(bl)
    return boxlists
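When max_pool_nms is enabled, a keypoint survives only where it equals the maximum of its 3x3 neighborhood, a cheap local-maximum NMS on the heatmap. A minimal standalone sketch:

import torch
import torch.nn.functional as F

heat = torch.zeros((1, 5, 5))
heat[0, 1, 1] = 0.9
heat[0, 1, 2] = 0.6  # adjacent to a stronger peak, so not a local max
heat[0, 3, 3] = 0.8
is_local_max = heat == F.max_pool2d(heat, kernel_size=3, stride=1, padding=1)
is_pos = is_local_max * (heat > 0.05)
print(is_pos.nonzero())  # only (1, 1) and (3, 3) survive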