Exemplo n.º 1
0
def boxes_to_masks(boxes, h, w, padding=0.0):
    """Rasterize boxes into one (h, w) mask per box.

    Each box is grown by ``padding`` times its own width/height plus one
    extra pixel on every side. The left/top edges are then clamped to >= 0
    and the right/bottom edges to <= w / h.

    Args:
        boxes: Var read as ``boxes[:, 0:4]`` = (x1, y1, x2, y2); presumably
            absolute pixel coordinates — TODO confirm with callers.
        h: mask height in pixels.
        w: mask width in pixels.
        padding: fractional growth relative to the box size.

    Returns:
        Var of shape (n, h, w), truthy where the pixel lies in
        ``[x1, x2) x [y1, y2)`` of the corresponding grown box.
    """
    n = boxes.shape[0]
    x1, y1, x2, y2 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
    b_w = x2 - x1
    b_h = y2 - y1
    x1 = jt.clamp(x1 - 1 - b_w * padding, min_v=0)
    x2 = jt.clamp(x2 + 1 + b_w * padding, max_v=w)
    y1 = jt.clamp(y1 - 1 - b_h * padding, min_v=0)
    y2 = jt.clamp(y2 + 1 + b_h * padding, max_v=h)

    # Pixel-coordinate grids broadcast to (n, h, w): xs varies along the last
    # (width) axis, ys along the middle (height) axis.
    xs = jt.arange(w, dtype=x1.dtype).view(1, 1, -1).expand((n, h, w))
    ys = jt.arange(h, dtype=x1.dtype).view(1, -1, 1).expand((n, h, w))

    inside_x = (xs >= x1.view(-1, 1, 1)) * (xs < x2.view(-1, 1, 1))
    inside_y = (ys >= y1.view(-1, 1, 1)) * (ys < y2.view(-1, 1, 1))
    return inside_x * inside_y
 def clip_to_image(self, remove_empty=True):
     """Clamp the box coordinates in place to the image extent.

     Columns 0/2 are clamped to [0, self.size[0] - 1] and columns 1/3 to
     [0, self.size[1] - 1]. If ``remove_empty`` is True, only boxes with
     column 3 > column 1 and column 2 > column 0 are kept (via
     ``self[keep]``); otherwise ``self`` is returned.
     The bbox is converted to a Jittor Var first if needed, and an empty
     bbox returns ``self`` unchanged.
     """
     if not isinstance(self.bbox, jt.Var):
         self.to_jittor()
     if self.bbox.numel() == 0:
         return self

     TO_REMOVE = 1
     max_x = self.size[0] - TO_REMOVE
     max_y = self.size[1] - TO_REMOVE
     for col, upper in ((0, max_x), (1, max_y), (2, max_x), (3, max_y)):
         self.bbox[:, col] = jt.clamp(self.bbox[:, col], min_v=0, max_v=upper)

     if not remove_empty:
         return self
     box = self.bbox
     keep = jt.logical_and(box[:, 3] > box[:, 1], box[:, 2] > box[:, 0])
     return self[keep]
 def _split_into_xyxy(self):
     """Return the boxes as four column Vars (xmin, ymin, xmax, ymax).

     For mode "xyxy" the columns are returned as-is. For mode "xywh",
     xmax = xmin + clamp(w - 1, 0, 9999999), and ymax is derived from h the
     same way. Any other mode raises RuntimeError.
     """
     def columns():
         box = self.bbox
         if box.shape[0] == 1:
             # Single-row case: slice each column explicitly instead of
             # using split.
             return box[:, :1], box[:, 1:2], box[:, 2:3], box[:, 3:]
         return box.split(1, dim=-1)

     if self.mode == "xyxy":
         xmin, ymin, xmax, ymax = columns()
         return xmin, ymin, xmax, ymax
     if self.mode == "xywh":
         TO_REMOVE = 1
         xmin, ymin, w, h = columns()
         xmax = xmin + jt.clamp(w - TO_REMOVE, min_v=0, max_v=9999999)
         ymax = ymin + jt.clamp(h - TO_REMOVE, min_v=0, max_v=9999999)
         return (xmin, ymin, xmax, ymax)
     raise RuntimeError("Should not be here")
Exemplo n.º 4
0
 def clip_to_image(self, remove_empty=True):
     """Clamp the box coordinates in place to the image extent.

     Works on a Jittor Var when ``self.jittor`` is truthy (converting with
     ``to_jittor`` first if needed), otherwise on a numpy array.
     Columns 0/2 are clamped to [0, self.size[0] - 1] and columns 1/3 to
     [0, self.size[1] - 1]. If ``remove_empty`` is True, only boxes with
     column 3 > column 1 and column 2 > column 0 are kept (via
     ``self[keep]``). An empty bbox returns ``self`` unchanged.
     """
     if self.jittor and not isinstance(self.bbox, jt.Var):
         self.to_jittor()
     use_jittor = self.jittor

     def clip_jt(values, upper):
         return jt.clamp(values, min_v=0, max_v=upper)

     def clip_np(values, upper):
         return np.clip(values, 0, upper)

     if use_jittor:
         if self.bbox.numel() == 0:
             return self
         clip = clip_jt
     else:
         if self.bbox.size == 0:
             return self
         clip = clip_np

     TO_REMOVE = 1
     for col, dim in ((0, 0), (1, 1), (2, 0), (3, 1)):
         self.bbox[:, col] = clip(self.bbox[:, col], self.size[dim] - TO_REMOVE)

     if not remove_empty:
         return self
     box = self.bbox
     valid_y = box[:, 3] > box[:, 1]
     valid_x = box[:, 2] > box[:, 0]
     if use_jittor:
         keep = jt.logical_and(valid_y, valid_x)
     else:
         keep = np.where(valid_y & valid_x)[0]
     return self[keep]
Exemplo n.º 5
0
    def ohem_conf_loss(self, conf_data, conf_t, pos, num):
        """Confidence loss with online hard negative mining (OHEM).

        Every prior is scored by a "hardness" loss. Positives and neutrals
        (``conf_t < 0``) are zeroed out. In each image, the top
        ``self.negpos_ratio * num_pos`` priors by that loss (capped at
        ``pos.shape[1] - 1``) are selected as negatives.
        ``cross_entropy_loss`` is then evaluated over positives + negatives.

        Args:
            conf_data: class scores. They are viewed as
                (-1, self.num_classes), and ``pos.unsqueeze(2)`` is expanded
                to their shape, so presumably [num, P, num_classes].
            conf_t: target class per prior, same layout as ``pos``; values
                < 0 mark neutral priors that are ignored.
            pos: boolean mask of positive priors, [num, P].
            num: batch size (leading dimension of ``pos``).

        Returns:
            ``cfg.conf_alpha`` times the summed confidence loss, optionally
            class-balanced when ``cfg.use_class_balanced_conf`` is set.
        """
        # Compute max conf across batch for hard negative mining
        batch_conf = conf_data.view(-1, self.num_classes)
        if cfg.ohem_use_most_confident:
            # i.e. max(softmax) along classes > 0
            batch_conf = nn.softmax(batch_conf, dim=1)
            loss_c = batch_conf[:, 1:].max(dim=1)
        else:
            # i.e. -log(softmax(x)[class 0]): the cross-entropy the prior
            # would incur if labelled as class 0.
            loss_c = log_sum_exp(batch_conf) - batch_conf[:, 0]

        # Hard Negative Mining
        loss_c = loss_c.view(num, -1)
        loss_c[pos] = 0  # filter out pos boxes
        loss_c[conf_t < 0] = 0  # filter out neutrals (conf_t = -1)
        # argsort of the descending argsort yields each prior's rank by loss.
        loss_idx, _ = loss_c.argsort(1, descending=True)
        idx_rank, _ = loss_idx.argsort(1)
        num_pos = pos.int32().sum(1, keepdims=True)
        num_neg = jt.clamp(self.negpos_ratio * num_pos, max_v=pos.shape[1] - 1)
        neg = idx_rank < num_neg.expand_as(idx_rank)
        neg = neg.int()

        # Just in case there aren't enough negatives, don't start using positives as negatives
        neg[pos] = 0
        neg[conf_t < 0] = 0  # Filter out neutrals
        neg = neg.bool()
        # Confidence Loss Including Positive and Negative Examples
        pos_idx = pos.unsqueeze(2).expand_as(conf_data)
        neg_idx = neg.unsqueeze(2).expand_as(conf_data)
        conf_p = conf_data[(pos_idx.int() + neg_idx.int()) > 0].view(
            -1, self.num_classes)
        targets_weighted = conf_t[(pos.int() + neg.int()) > 0]
        loss_c = cross_entropy_loss(conf_p, targets_weighted, reduction='none')

        if cfg.use_class_balanced_conf:
            # Lazy initialization
            if self.class_instances is None:
                # jt.zeros accepts only a shape (and dtype); the former
                # torch-style ``device=`` keyword is rejected by Jittor.
                self.class_instances = jt.zeros(self.num_classes)

            classes, counts = targets_weighted.unique(return_counts=True)

            for _cls, _cnt in zip(classes.numpy(), counts.numpy()):
                self.class_instances[_cls] += _cnt

            self.total_instances += targets_weighted.shape[0]

            # Rarer classes (smaller running share) get larger weights.
            weighting = 1 - (self.class_instances[targets_weighted] /
                             self.total_instances)
            weighting = jt.clamp(weighting, min_v=1 / self.num_classes)

            # If you do the math, the average weight of self.class_instances is this
            avg_weight = (self.num_classes - 1) / self.num_classes

            loss_c = (loss_c * weighting).sum() / avg_weight
        else:
            loss_c = loss_c.sum()

        return cfg.conf_alpha * loss_c
Exemplo n.º 6
0
    def execute(self, mesh, eyes=None):
        """Shade ``mesh.textures`` with ambient and directional lights.

        * ``Gbuffer == "albedo"``: ``mesh`` is returned untouched.
        * ``Gbuffer`` "normal"/"depth": textures are first reset to ones.
        * ``light_mode == 'surface'``: lighting is accumulated per face, from
          surface normals and face-vertex averages.
        * ``light_mode == 'vertex'``: lighting is accumulated per vertex.
        * Any other light mode leaves the textures as they are.

        For rank-4 and rank-6 textures the result
        ``textures * diffuse + specular`` is clamped to [0, 1]. Other ranks
        are left unchanged. ``mesh`` is updated in place and returned.
        """
        if self.Gbuffer == "albedo":
            return mesh
        if self.Gbuffer in ("normal", "depth"):
            mesh.textures = jt.ones_like(mesh.textures)

        if self.light_mode == 'surface':
            per_face = True
        elif self.light_mode == 'vertex':
            per_face = False
        else:
            return mesh

        buffer_shape = mesh.faces.shape if per_face else mesh.vertices.shape
        diffuse = jt.zeros(buffer_shape)
        specular = jt.zeros(buffer_shape)
        diffuse = self.ambient(diffuse)

        for light in self.directionals:
            if per_face:
                normals = mesh.surface_normals
                # Sum over dim 2 divided by 3: face centroid, assuming
                # triangular faces.
                positions = jt.sum(mesh.face_vertices, dim=2) / 3.0
            else:
                normals = mesh.vertex_normals
                positions = mesh.vertices
            diffuse, specular = light(diffuse, specular, normals, positions,
                                      eyes, mesh.with_specular,
                                      mesh.metallic_textures,
                                      mesh.roughness_textures)

        texture_rank = len(mesh.textures.shape)
        if texture_rank == 4:
            extra_axes = 1
        elif texture_rank == 6:
            extra_axes = 3
        else:
            return mesh
        # Insert singleton axes at dim 2 so the light buffers broadcast over
        # the texture's extra dimensions.
        for _ in range(extra_axes):
            diffuse = diffuse.unsqueeze(2)
            specular = specular.unsqueeze(2)
        mesh.textures = jt.clamp(
            mesh.textures * diffuse + jt.ones_like(mesh.textures) * specular,
            0.0, 1.0)
        return mesh
Exemplo n.º 7
0
def elemwise_mask_iou(masks_a, masks_b):
    """Element-wise mask IoU along the last dimension.

    Both inputs are flattened to (-1, k), where k is their last dimension
    (e.g. [h, w, n] -> [h*w, n]). Column i of ``masks_a`` is compared with
    column i of ``masks_b``, giving a result of shape [k]. The union is
    floored at 0.1 and the ratio is capped at 1.
    """
    flat_a = masks_a.view(-1, masks_a.shape[-1])
    flat_b = masks_b.view(-1, masks_b.shape[-1])

    overlap = (flat_a * flat_b).sum(dim=0)
    union = flat_a.sum(dim=0) + flat_b.sum(dim=0) - overlap
    return jt.clamp(overlap / jt.clamp(union, min_v=0.1), max_v=1)
Exemplo n.º 8
0
def elemwise_box_iou(box_a, box_b):
    """IoU between row i of ``box_a`` and row i of ``box_b``.

    Inputs are [n, 4] boxes laid out as (x1, y1, x2, y2); the result has
    shape [n]. The union is floored at 0.1 and the ratio is capped at 1.
    """
    lo = jt.maximum(box_a[:, :2], box_b[:, :2])
    hi = jt.minimum(box_a[:, 2:], box_b[:, 2:])
    overlap_wh = jt.clamp(hi - lo, min_v=0)
    overlap = overlap_wh[:, 0] * overlap_wh[:, 1]

    def _area(boxes):
        return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])

    union = jt.clamp(_area(box_a) + _area(box_b) - overlap, min_v=0.1)
    return jt.clamp(overlap / union, max_v=1)
Exemplo n.º 9
0
def trunc_normal(tensor, mean=0., std=1., a=-2., b=2.):
    """Fill ``tensor`` with samples from a truncated normal distribution.

    Samples follow N(mean, std^2) restricted to [a, b]. They are produced
    by inverse-CDF sampling: draw uniformly over the CDF range of [a, b],
    map through ``erfinv``, then rescale and shift.

    The final values are written back into ``tensor`` and ``tensor`` is
    returned. Previously only the intermediate uniform sample was left in
    ``tensor``, so in-place callers got a wrong initialization. Both
    ``trunc_normal(w)`` and ``w = trunc_normal(w)`` now give the same result.

    Args:
        tensor: Jittor Var to initialize.
        mean: mean of the underlying normal.
        std: standard deviation of the underlying normal.
        a: lower truncation bound.
        b: upper truncation bound.

    Returns:
        ``tensor``, now holding the truncated-normal samples.
    """
    sqrt2 = 1.4142135623730951

    def norm_cdf(x):
        # Standard normal cumulative distribution function.
        return (1. + math.erf(x / sqrt2)) / 2.

    with jt.no_grad():
        # CDF values of the truncation bounds.
        lower = norm_cdf((a - mean) / std)
        upper = norm_cdf((b - mean) / std)

        # Uniform on [2*lower - 1, 2*upper - 1] (the erf-domain image of
        # [a, b]); erfinv then gives a truncated normal scaled by 1/sqrt(2).
        nn.init.uniform_(tensor, 2 * lower - 1, 2 * upper - 1)
        result = erfinv(tensor)

        # Transform to proper mean, std
        result = result * std * sqrt2 + mean

        # Clamp to ensure it's in the proper range
        result = jt.clamp(result, min_v=a, max_v=b)

        # Write the final values back into the caller's Var.
        tensor.assign(result)
    return tensor
Exemplo n.º 10
0
    def execute(self, inputs, targets, mask=None, act=False):
        """Lovasz hinge loss computed per item, then reduced.

        For each index i, ``self.flatten`` is called on ``inputs[i]`` and
        ``targets[i]``, plus ``mask[i]`` when a mask is given.
        If ``act`` is True, the flattened inputs are treated as
        probabilities: they are clamped to [1e-9, 1 - 1e-9] and converted to
        logits with log(p) - log(1 - p).
        Per-item losses are stacked and reduced according to
        ``self.reduction`` ("mean", "sum", or anything else for none).
        """
        per_item = []
        for i in range(len(inputs)):
            extra = () if mask is None else (mask[i],)
            logits, labels = self.flatten(inputs[i], targets[i], *extra)
            if act:
                eps = 1e-9
                logits = jt.clamp(logits, min_v=eps, max_v=1 - eps)
                logits = jt.log(logits) - jt.log(1 - logits)
            per_item.append(self.lovasz_hinge_flat(logits, labels))

        stacked = jt.stack(per_item)
        if self.reduction == "mean":
            return stacked.mean()
        if self.reduction == "sum":
            return stacked.sum()
        return stacked
Exemplo n.º 11
0
    def predict(self, images, score_thresh=0.7, nms_thresh=0.3):
        """Run the detector and return per-image (boxes, scores, labels).

        The per-class box regressions are decoded against the RoIs. X
        coordinates are clipped to [0, image width] and y coordinates to
        [0, image height]. Then, for every image and every class index
        from 1 to ``self.n_class - 1``:

        * scores above ``score_thresh`` are kept;
        * ``jt.nms`` is applied with ``nms_thresh``;
        * the survivors are labelled with the class index.

        Index 0 is skipped (presumably background).
        """
        batch = images.shape[0]
        img_w, img_h = images.shape[-1], images.shape[-2]
        _, _, roi_cls_locs, roi_scores, rois, roi_indices = self.execute(images)

        roi_cls_locs = roi_cls_locs.reshape(roi_cls_locs.shape[0], -1, 4)
        probs = nn.softmax(roi_scores, dim=-1)
        rois = rois.unsqueeze(1).repeat(1, self.n_class, 1)
        cls_bbox = loc2bbox(rois.reshape(-1, 4), roi_cls_locs.reshape(-1, 4))
        # Columns 0 and 2 are x coordinates, columns 1 and 3 are y.
        cls_bbox[:, 0::2] = jt.clamp(cls_bbox[:, 0::2], min_v=0, max_v=img_w)
        cls_bbox[:, 1::2] = jt.clamp(cls_bbox[:, 1::2], min_v=0, max_v=img_h)
        cls_bbox = cls_bbox.reshape(roi_cls_locs.shape)

        results = []
        for img in range(batch):
            index = jt.where(roi_indices == img)[0]
            img_probs = probs[index, :]
            img_boxes = cls_bbox[index, :, :]
            boxes, scores, labels = [], [], []
            for cls in range(1, self.n_class):
                box_c = img_boxes[:, cls, :]
                score_c = img_probs[:, cls]
                above = jt.where(score_c > score_thresh)[0]
                box_c = box_c[above, :]
                score_c = score_c[above]
                dets = jt.contrib.concat([box_c, score_c.unsqueeze(1)], dim=1)
                keep = jt.nms(dets, nms_thresh)
                box_c = box_c[keep]
                score_c = score_c[keep]
                boxes.append(box_c)
                scores.append(score_c)
                labels.append(jt.ones_like(score_c).int32() * cls)
            results.append((
                jt.contrib.concat(boxes, dim=0),
                jt.contrib.concat(scores, dim=0),
                jt.contrib.concat(labels, dim=0),
            ))
        return results
    
    
    
    
        
Exemplo n.º 12
0
    def compute_mask_prob(self, pixel_embed, proposal_embed, proposal_margin,
                          boxes):
        """Compute per-proposal mask probabilities via ``mask_prob_cuda``.

        ``pixel_embed`` is unpacked as (dim, m_h, m_w) and flattened to
        (m_h * m_w, dim). ``boxes`` are cast to int and grown by 2 px on
        each side, with left/top clamped at 0 and right/bottom at m_w/m_h.
        Note that the cast Var is modified in place.
        The kernel result is reshaped to (-1, m_h, m_w).
        """
        dim, m_h, m_w = pixel_embed.shape
        flat_embed = pixel_embed.view(dim, m_h * m_w).transpose(1, 0)

        boxes = boxes.int()
        boxes[:, 0] = jt.clamp(boxes[:, 0] - 2, min_v=0)
        boxes[:, 1] = jt.clamp(boxes[:, 1] - 2, min_v=0)
        boxes[:, 2] = jt.clamp(boxes[:, 2] + 2, max_v=m_w)
        boxes[:, 3] = jt.clamp(boxes[:, 3] + 2, max_v=m_h)

        box_w = boxes[:, 2] - boxes[:, 0]
        box_h = boxes[:, 3] - boxes[:, 1]
        box_areas = box_w * box_h
        total_area = box_areas.sum().item()

        prob = mask_prob_cuda(flat_embed, proposal_embed, proposal_margin,
                              boxes, box_areas, total_area, m_w)
        return prob.view(-1, m_h, m_w)
    def decode(self, rel_codes, boxes):
        """
        Decode relative box offsets back into absolute boxes.

        Arguments:
            rel_codes (Var): encoded offsets, interleaved with stride 4 as
                (dx, dy, dw, dh) per column group.
            boxes (Var): reference boxes (x1, y1, x2, y2); cast to the dtype
                of ``rel_codes``.

        Returns:
            Var shaped like ``rel_codes`` holding decoded (x1, y1, x2, y2)
            in the same interleaved layout. dw/dh are capped at
            ``self.bbox_xform_clip`` before exponentiation.
        """
        boxes = boxes.cast(rel_codes.dtype)

        TO_REMOVE = 1  # TODO remove
        widths = boxes[:, 2] - boxes[:, 0] + TO_REMOVE
        heights = boxes[:, 3] - boxes[:, 1] + TO_REMOVE
        ctr_x = (boxes[:, 0] + 0.5 * widths).unsqueeze(-1)
        ctr_y = (boxes[:, 1] + 0.5 * heights).unsqueeze(-1)
        widths = widths.unsqueeze(-1)
        heights = heights.unsqueeze(-1)

        wx, wy, ww, wh = self.weights
        dx = rel_codes[:, 0::4] / wx
        dy = rel_codes[:, 1::4] / wy
        # Cap dw/dh so that jt.exp() below cannot blow up.
        dw = jt.clamp(rel_codes[:, 2::4] / ww, max_v=self.bbox_xform_clip)
        dh = jt.clamp(rel_codes[:, 3::4] / wh, max_v=self.bbox_xform_clip)

        pred_ctr_x = dx * widths + ctr_x
        pred_ctr_y = dy * heights + ctr_y
        half_w = 0.5 * (jt.exp(dw) * widths)
        half_h = 0.5 * (jt.exp(dh) * heights)

        pred_boxes = jt.zeros_like(rel_codes)
        pred_boxes[:, 0::4] = pred_ctr_x - half_w
        pred_boxes[:, 1::4] = pred_ctr_y - half_h
        # x2 / y2: the "- 1" is intentional despite the asymmetry (the
        # original code explicitly notes it is correct).
        pred_boxes[:, 2::4] = pred_ctr_x + half_w - 1
        pred_boxes[:, 3::4] = pred_ctr_y + half_h - 1
        return pred_boxes
Exemplo n.º 14
0
def crop_by_box(masks, box, padding=0.0):
    """Zero out mask pixels outside each (padded) box.

    ``masks`` is unpacked as (n, h, w), and ``box`` is read as
    (x1, y1, x2, y2) per row. Each box is grown by ``padding`` times its
    size plus 1 px per side. Left/top are clamped to >= 0 and right/bottom
    to <= w - 1 / h - 1.
    NOTE(review): right/bottom use a strict ``<`` after clamping to
    w - 1 / h - 1, so the last column/row is never inside the crop.

    Returns:
        (masks * crop_mask.float(), crop_mask)
    """
    n, h, w = masks.size()

    bw = box[:, 2] - box[:, 0]
    bh = box[:, 3] - box[:, 1]
    left = jt.clamp(box[:, 0] - bw * padding - 1, min_v=0)
    right = jt.clamp(box[:, 2] + bw * padding + 1, max_v=w - 1)
    top = jt.clamp(box[:, 1] - bh * padding - 1, min_v=0)
    bottom = jt.clamp(box[:, 3] + bh * padding + 1, max_v=h - 1)

    xs = jt.arange(w, dtype=left.dtype).view(1, 1, -1).expand((n, h, w))
    ys = jt.arange(h, dtype=left.dtype).view(1, -1, 1).expand((n, h, w))

    crop_mask = ((xs >= left.view(n, 1, 1)) * (xs < right.view(n, 1, 1))
                 * (ys >= top.view(n, 1, 1)) * (ys < bottom.view(n, 1, 1)))
    return masks * crop_mask.float(), crop_mask
    def __call__(self, boxlists):
        """Map each box to an FPN level index (Eqn. 1 of the FPN paper).

        Arguments:
            boxlists (list[BoxList])

        Returns:
            int32 Var equal to
            clamp(floor(lvl0 + log2(sqrt(area) / s0 + eps)), k_min, k_max) - k_min.
        """
        areas = cat([bl.area() for bl in boxlists])
        scale = jt.sqrt(areas)
        level = jt.floor(self.lvl0 + jt.log2(scale / self.s0 + self.eps))
        level = jt.clamp(level, min_v=self.k_min, max_v=self.k_max)
        return level.int32() - self.k_min
    def __call__(self, boxlists):
        """Map each box to a level from its area relative to the image area.

        Arguments:
            boxlists (list[BoxList])

        Returns:
            int32 Var equal to
            clamp(ceil(k_max - log2(img_area / box_area + eps)), k_min, k_max) - k_min.
        """
        box_area = cat([bl.area() for bl in boxlists])
        image_area = cat([bl.image_area() for bl in boxlists])
        level = jt.ceil(self.k_max - jt.log2(image_area / box_area + self.eps))
        level = jt.clamp(level, min_v=self.k_min, max_v=self.k_max)
        return level.int32() - self.k_min
Exemplo n.º 17
0
    def execute(self, xyz1, xyz2, points1, points2):
        """
        Propagate features from the S sampled points back to the N points.

        Each target point in ``xyz1`` gets a weighted average of the
        features of its 3 nearest points in ``xyz2`` (found by
        ``three_nn``), with weights proportional to 1 / squared distance.
        When S == 1 the single feature vector is copied to every target.
        The result is concatenated with ``points1`` (if given) along dim 1
        and passed through ``self.mlp``.

        Input:
            xyz1: target point positions, [B, N, C] (unpacked as B, N, C).
            xyz2: sampled point positions, [B, S, C].
            points1: skip features for the target points, or None. They are
                concatenated along dim 1 with the [B, D, N] interpolated
                features, so presumably [B, D1, N] — TODO confirm.
            points2: features of the sampled points, channel-first [B, D, S]
                (transposed to [B, S, D] below).
        Return:
            new_points: ``self.mlp`` applied to the channel-first features;
                presumably [B, D', N], depending on ``self.mlp``.
        """
        B, N, C = xyz1.shape
        _, S, _ = xyz2.shape

        points2 = points2.transpose(0, 2, 1)  # [B, S, D]

        if S == 1:
            interpolated_points = points2.repeat(1, N, 1)
        else:
            # Indices of the 3 nearest sampled points per target point.
            _, idx = three_nn(xyz1, xyz2)
            # Squared distances to those neighbours, [B, N, 3].
            dists = ((xyz1.unsqueeze(-2) - xyz2['i0', idx]) ** 2).sum(-1)

            # Avoid division by zero for coincident points.
            dists = jt.clamp(dists, min_v=1e-10)
            dist_recip = 1.0 / dists
            norm = jt.sum(dist_recip, dim=2, keepdims=True)
            weight = dist_recip / norm
            interpolated_points = jt.sum(index_points(points2, idx) * weight.view(B, N, 3, 1), dim=2)

        interpolated_points = interpolated_points.transpose(0, 2, 1)  # [B, D, N]
        if points1 is not None:
            new_points = concat([interpolated_points, points1], dim=1)
        else:
            new_points = interpolated_points

        new_points = self.mlp(new_points)
        return new_points
Exemplo n.º 18
0
    def direct_mask_loss(self, pos_idx, idx_t, loc_data, mask_data, priors,
                         masks):
        """ Crops the gt masks using the predicted bboxes, scales them down, and outputs the BCE loss.

        For each image ``idx`` (over ``mask_data.shape[0]``):

        * the predicted box of every selected prior is decoded, clamped to
          [0, 1], and converted to absolute pixel bounds with
          ``sanitize_coordinates``;
        * the matched ground-truth mask is cropped to that box;
        * the crop is pooled to ``cfg.mask_size`` and thresholded at 0.5.

        This target-building is done under ``jt.no_grad()``. The predicted
        masks of the selected priors are clamped to [0, 1] and scored with
        ``nn.bce_loss(..., size_average=False)``, scaled by
        ``cfg.mask_alpha``, and summed over images.

        Args:
            pos_idx: per-image prior selector; ``pos_idx[idx]`` is used both
                as an index into the decoded boxes and, via its column 1,
                into ``idx_t``/``mask_data`` (exact layout not visible here
                — TODO confirm).
            idx_t: matched ground-truth index, indexed ``[idx, prior]``.
            loc_data: predicted box regressions, ``loc_data[idx]`` decoded.
            mask_data: predicted masks, indexed ``[idx, prior, :]``.
            priors: prior boxes passed to ``decode``.
            masks: ground-truth masks; ``masks[idx][pos_lookup]`` is
                unpacked as (num_pos, img_height, img_width).

        Returns:
            The accumulated loss (the int 0 if ``mask_data.shape[0]`` is 0).
        """
        loss_m = 0
        for idx in range(mask_data.shape[0]):
            # Target construction below is excluded from the gradient.
            with jt.no_grad():
                cur_pos_idx = pos_idx[idx]
                cur_pos_idx_squeezed = cur_pos_idx[:, 1]

                # Shape: [num_priors, 4], decoded predicted bboxes
                # NOTE(review): in Jittor ``Var.data`` returns a numpy array;
                # confirm ``decode`` accepts that for ``priors``.
                pos_bboxes = decode(loc_data[idx], priors.data,
                                    cfg.use_yolo_regressors)
                pos_bboxes = pos_bboxes[cur_pos_idx].view(-1, 4).clamp(0, 1)
                pos_lookup = idx_t[idx, cur_pos_idx_squeezed]

                cur_masks = masks[idx]
                pos_masks = cur_masks[pos_lookup]

                # Convert bboxes to absolute coordinates
                num_pos, img_height, img_width = pos_masks.shape

                # Take care of all the bad behavior that can be caused by out of bounds coordinates
                x1, x2 = sanitize_coordinates(pos_bboxes[:, 0],
                                              pos_bboxes[:, 2], img_width)
                y1, y2 = sanitize_coordinates(pos_bboxes[:, 1],
                                              pos_bboxes[:, 3], img_height)

                # Crop each gt mask with the predicted bbox and rescale to the predicted mask size
                # Note that each bounding box crop is a different size so I don't think we can vectorize this
                scaled_masks = []
                for jdx in range(num_pos):
                    tmp_mask = pos_masks[jdx, y1[jdx]:y2[jdx], x1[jdx]:x2[jdx]]

                    # Restore any dimensions we've left out because our bbox was 1px wide
                    while tmp_mask.ndim < 2:
                        tmp_mask = tmp_mask.unsqueeze(0)

                    # A new pooling module is built per crop; it takes only
                    # the output size, so it could be hoisted out of the loop.
                    new_mask = nn.AdaptiveAvgPool2d(cfg.mask_size)(
                        tmp_mask.unsqueeze(0))
                    scaled_masks.append(new_mask.view(1, -1))

                mask_t = (jt.contrib.concat(scaled_masks, 0) >
                          0.5).float()  # Threshold downsampled mask

            pos_mask_data = mask_data[idx, cur_pos_idx_squeezed, :]
            loss_m += nn.bce_loss(jt.clamp(pos_mask_data, 0, 1),
                                  mask_t,
                                  size_average=False) * cfg.mask_alpha

        return loss_m
Exemplo n.º 19
0
def sanitize_coordinates(_x1,
                         _x2,
                         img_size: int,
                         padding: int = 0,
                         cast: bool = True):
    """
    Convert relative coordinates to absolute pixel bounds, ordered and clipped.

    Steps:

    1. Both inputs are multiplied by ``img_size``.
    2. If ``cast`` is True, they are cast to int32 before ordering and
       clamping; otherwise they keep their dtype.
    3. ``x1`` is the element-wise minimum and ``x2`` the maximum, so
       ``x1 <= x2``.
    4. ``padding`` is subtracted from ``x1`` and added to ``x2``.
    5. The results are clamped so that ``x1 >= 0`` and ``x2 <= img_size``.

    ``x1 == x2`` is NOT prevented. When both inputs map to the same value
    (and padding is 0), slicing with ``x1:x2`` yields an empty range.
    The inputs are not modified; new Vars are returned.

    Returns:
        (x1, x2)
    """
    _x1 = _x1 * img_size
    _x2 = _x2 * img_size
    if cast:
        _x1 = _x1.int32()
        _x2 = _x2.int32()
    x1 = jt.minimum(_x1, _x2)
    x2 = jt.maximum(_x1, _x2)
    x1 = jt.clamp(x1 - padding, min_v=0)
    x2 = jt.clamp(x2 + padding, max_v=img_size)

    return x1, x2
def adjust_dynamic_range(data, drange_in=(-1, 1), drange_out=(0, 1)):
    """
    Linearly remap ``data`` from ``drange_in`` to ``drange_out``.

    The scale and bias are computed in float32. The remap is skipped when
    the two ranges compare equal. The result is always clamped to [0, 1],
    regardless of ``drange_out``.
    :param data: input image data
    :param drange_in: original range of input
    :param drange_out: required range of output
    :return: img => colour range adjusted images
    """
    if drange_in != drange_out:
        in_lo, in_hi = np.float32(drange_in[0]), np.float32(drange_in[1])
        out_lo, out_hi = np.float32(drange_out[0]), np.float32(drange_out[1])
        scale = (out_hi - out_lo) / (in_hi - in_lo)
        bias = out_lo - in_lo * scale
        data = data * jt.array(scale) + jt.array(bias)
    return jt.clamp(data, min_v=0, max_v=1)
Exemplo n.º 21
0
def intersect(box_a, box_b):
    """Pairwise intersection areas between two batched box sets.

    Each corner pair is broadcast to [n, A, B, 2] with unsqueeze + expand.
    The overlap extent is clamped at 0, and its product over the last axis
    gives the area.

    Args:
      box_a: (tensor) bounding boxes, Shape: [n,A,4].
      box_b: (tensor) bounding boxes, Shape: [n,B,4].
    Return:
      (tensor) intersection area, Shape: [n,A,B].
    """
    n = box_a.shape[0]
    num_a = box_a.shape[1]
    num_b = box_b.shape[1]
    full = (n, num_a, num_b, 2)

    def _pair(part_a, part_b):
        return (part_a.unsqueeze(2).expand(full),
                part_b.unsqueeze(1).expand(full))

    hi = jt.minimum(*_pair(box_a[:, :, 2:], box_b[:, :, 2:]))
    lo = jt.maximum(*_pair(box_a[:, :, :2], box_b[:, :, :2]))
    return jt.clamp(hi - lo, min_v=0).prod(3)
Exemplo n.º 22
0
def intersect(box_a, box_b):
    """Pairwise intersection areas between two box sets.

    Each corner pair is broadcast to [A, B, 2]. The overlap width/height is
    clamped at 0 and multiplied together.

    Args:
      box_a: (tensor) bounding boxes, Shape: [A,4].
      box_b: (tensor) bounding boxes, Shape: [B,4].
    Return:
      (tensor) intersection area, Shape: [A,B].
    """
    num_a = box_a.size(0)
    num_b = box_b.size(0)
    a_hi = box_a[:, 2:].unsqueeze(1).expand(num_a, num_b, 2)
    b_hi = box_b[:, 2:].unsqueeze(0).expand(num_a, num_b, 2)
    a_lo = box_a[:, :2].unsqueeze(1).expand(num_a, num_b, 2)
    b_lo = box_b[:, :2].unsqueeze(0).expand(num_a, num_b, 2)
    overlap_wh = jt.clamp(jt.minimum(a_hi, b_hi) - jt.maximum(a_lo, b_lo),
                          min_v=0)
    return overlap_wh[:, :, 0] * overlap_wh[:, :, 1]
Exemplo n.º 23
0
def directional_lighting(diffuseLight,
                         specularLight,
                         normals,
                         light_intensity=0.5,
                         light_color=(1, 1, 1),
                         light_direction=(0, 1, 0),
                         positions=None,
                         eye=None,
                         with_specular=False,
                         metallic_textures=None,
                         roughness_textures=None,
                         Gbuffer="None",
                         transform=None):
    """Accumulate one directional light into the diffuse/specular buffers.

    Without specular shading, the diffuse buffer receives
    ``light_intensity * light_color * relu(sum(normals * light_direction))``.
    The Cook-Torrance microfacet path (GGX, GeometrySmith, fresnelSchlick)
    is used when ``with_specular`` is set and ``eye``, ``positions``,
    ``metallic_textures`` and ``roughness_textures`` are all provided.
    A None ``eye`` or missing material maps now fall back to the diffuse
    path instead of failing.

    Args:
        diffuseLight, specularLight: accumulators, updated with ``+=``.
        normals: normals; the dot product is taken over dim 2.
        light_intensity: scalar multiplier.
        light_color: color; 1-D input gets a leading axis of size 1.
        light_direction: direction(s); each vector is normalized along its
            last axis, and 1-D input gets a leading axis of size 1.
        positions: positions used for the view vector (and "depth").
        eye: camera position(s); None disables the specular path.
        with_specular: request the microfacet path.
        metallic_textures, roughness_textures: material maps of rank 4 or 6,
            averaged over dim 2 (rank 4) or dims 2-4 (rank 6) first.
        Gbuffer: "normal" or "depth" replaces the output with that buffer
            (specular zeroed); other values keep the shading.
        transform: object providing ``tranpos`` (used only for "depth").

    Returns:
        [diffuseLight, specularLight]
    """
    if eye is not None:
        eye = jt.array(eye, "float32")
    light_color = jt.array(light_color, "float32")
    light_direction = jt.array(light_direction, "float32")
    # Normalize along the vector axis; dim=0 would mix vectors across the
    # batch when a 2-D direction array is passed.
    light_direction = jt.normalize(light_direction,
                                   dim=len(light_direction.shape) - 1)

    if len(light_color.shape) == 1:
        light_color = light_color.unsqueeze(0)
    if len(light_direction.shape) == 1:
        light_direction = light_direction.unsqueeze(0)

    cosine = nn.relu(jt.sum(normals * light_direction, dim=2))
    has_materials = (metallic_textures is not None
                     and roughness_textures is not None)
    if with_specular and has_materials:
        # Average the material maps over their per-texel dimensions.
        if len(metallic_textures.shape) == 4:
            total = metallic_textures.shape[2] * 1.0
            metallic_textures = jt.sum(metallic_textures, dim=2) / total
            roughness_textures = jt.sum(roughness_textures, dim=2) / total
        elif len(metallic_textures.shape) == 6:
            total = metallic_textures.shape[2] * metallic_textures.shape[
                3] * metallic_textures.shape[4] * 1.0
            metallic_textures = jt.sum(metallic_textures, dim=2)
            metallic_textures = jt.sum(metallic_textures, dim=2)
            metallic_textures = jt.sum(metallic_textures, dim=2)
            metallic_textures = metallic_textures / total
            roughness_textures = jt.sum(roughness_textures, dim=2)
            roughness_textures = jt.sum(roughness_textures, dim=2)
            roughness_textures = jt.sum(roughness_textures, dim=2)
            roughness_textures = roughness_textures / total

    #Microfacet model
    if (with_specular and eye is not None and positions is not None
            and has_materials):
        N = normals
        if len(eye.shape) == 2:
            eye = eye.unsqueeze(1)
        V = jt.normalize(eye - positions, dim=2)
        L = light_direction
        H = jt.normalize(V + L, dim=2)

        #Default Setting
        metallic = metallic_textures
        roughness = roughness_textures
        F0 = jt.array((0.04, 0.04, 0.04), "float32")
        albedo = jt.array((1.0, 1.0, 1.0), "float32")

        F0 = F0.unsqueeze(0).unsqueeze(1) * (
            1 - metallic) + albedo.unsqueeze(0).unsqueeze(1) * metallic
        radiance = light_intensity * (light_color.unsqueeze(1) *
                                      cosine.unsqueeze(2))

        #Cook-Torrance BRDF
        NDF = GGX(N, H, roughness)
        G = GeometrySmith(N, V, L, roughness)
        F = fresnelSchlick(nn.relu(jt.sum(H * V, dim=2)), F0)
        KS = F
        KD = 1.0 - KS
        KD *= (1.0 - metallic)

        diffuseLight += KD * radiance
        numerator = NDF * G * F
        denominator = (4.0 * nn.relu(jt.sum(N * V, dim=2)) *
                       nn.relu(jt.sum(N * L, dim=2))).unsqueeze(2)
        # Floor the denominator at 0.01 to avoid division by ~0 at grazing angles.
        specular = numerator / jt.clamp(denominator, 0.01)
        specularLight += specular * radiance
    else:
        diffuseLight += light_intensity * (light_color.unsqueeze(1) *
                                           cosine.unsqueeze(2))
    if Gbuffer == "normal":
        specularLight *= 0.0
        diffuseLight = normals * 0.5 + 0.5
    elif Gbuffer == "depth":
        specularLight *= 0.0
        # NOTE(review): ``tranpos`` is called as-is; confirm the method name
        # on the transform object.
        viewpos = transform.tranpos(positions)
        diffuseLight = viewpos / jt.max(viewpos[..., 2])
        diffuseLight[..., 0] = viewpos[..., 2] / jt.max(viewpos[..., 2])
        diffuseLight[..., 1] = viewpos[..., 2] / jt.max(viewpos[..., 2])
    return [diffuseLight, specularLight]
Exemplo n.º 24
0
def clip_coordinates(x, clip_limit):
    """Clamp ``x`` element-wise to the index range ``[0, clip_limit - 1]``."""
    upper = clip_limit - 1
    return jt.clamp(x, min_v=0, max_v=upper)
Exemplo n.º 25
0
def hardtanh(x, min_val=-1, max_val=1):
    """Hard-tanh activation: clip ``x`` element-wise to ``[min_val, max_val]``.

    Args:
        x: input accepted by ``jt.clamp``.
        min_val: lower clip bound (default -1).
        max_val: upper clip bound (default 1).

    Returns:
        The clamped result of ``jt.clamp``.
    """
    lower, upper = min_val, max_val
    return jt.clamp(x, min_v=lower, max_v=upper)
Exemplo n.º 26
0
    def execute(self, loc, score, anchor, img_size, scale=1.):
        """Turn per-anchor regression offsets and scores into RoI proposals.

        ``loc``, ``score`` and ``anchor`` are row-aligned: row ``i`` of each
        refers to the same anchor. ``R`` below is the number of anchors.
        All inputs are consumed with jittor ops (``jt.clamp``, ``jt.where``,
        ``jt.argsort``, ``jt.nms``), so they are expected to be jittor Vars.

        Args:
            loc: predicted offsets/scales for each anchor, shape ``(R, 4)``;
                decoded against ``anchor`` with ``loc2bbox``.
            score: predicted foreground probability per anchor, shape ``(R,)``.
            anchor: anchor coordinates, shape ``(R, 4)``.
            img_size (tuple of ints): ``(height, width)`` of the scaled image.
                Box columns 0 and 2 are clipped to ``[0, img_size[0]]`` and
                columns 1 and 3 to ``[0, img_size[1]]``.
            scale (float): factor the image was scaled by after loading; the
                minimum box size ``self.min_size`` is multiplied by it.

        Returns:
            Proposal boxes of shape ``(S, 4)``. ``S`` is capped by
            ``self.n_train_post_nms`` in training mode or
            ``self.n_test_post_nms`` otherwise (when those limits are > 0),
            and further depends on the size filter and on NMS.
        """
        # Pick the pre-/post-NMS budgets for the current mode.
        training = self.is_training()
        pre_limit = self.n_train_pre_nms if training else self.n_test_pre_nms
        post_limit = self.n_train_post_nms if training else self.n_test_post_nms

        # Decode anchors + offsets into boxes, then clip each coordinate
        # column into the image.
        roi = loc2bbox(anchor, loc)
        column_limits = ((0, img_size[0]), (2, img_size[0]),
                         (1, img_size[1]), (3, img_size[1]))
        for col, limit in column_limits:
            roi[:, col] = jt.clamp(roi[:, col], min_v=0, max_v=limit)

        # Discard boxes whose extent along either axis is under the scaled
        # minimum size.
        threshold = self.min_size * scale
        heights = roi[:, 2] - roi[:, 0]
        widths = roi[:, 3] - roi[:, 1]
        valid = jt.where((heights >= threshold) & (widths >= threshold))[0]
        roi = roi[valid, :]
        score = score[valid]

        # Rank by score, highest first, and keep the top `pre_limit`.
        ranking, _ = jt.argsort(score, descending=True)
        if pre_limit > 0:
            ranking = ranking[:pre_limit]
        roi = roi[ranking, :]
        score = score[ranking]

        # Non-maximum suppression on [box, score] rows, then cap survivors.
        dets = jt.contrib.concat([roi, score.unsqueeze(1)], dim=1)
        survivors = jt.nms(dets, self.nms_thresh)
        if post_limit > 0:
            survivors = survivors[:post_limit]
        return roi[survivors]
Exemplo n.º 27
0
 def logits(self):
     """Return ``self._logits``, or derive logits from ``self.probs`` if it is unset.

     When ``self._logits`` is ``None`` the result is ``log(probs)`` with
     ``probs`` first clamped to ``[eps, 1 - eps]`` so the log stays finite.
     The derived value is not stored back on the instance. ``eps`` is not
     defined in this block; presumably a module-level constant (verify).
     """
     if self._logits is not None:
         return self._logits
     clipped = jt.clamp(self.probs, min_v=eps, max_v=1-eps)
     return jt.log(clipped)
Exemplo n.º 28
0
def denormalize(var):
    """Undo mean/std normalization of an image batch and clip to ``[0, 255]``.

    ``std`` and ``mean`` are not defined in this block (presumably
    module-level per-channel statistics; verify). Each is broadcast over
    dimensions 0, 2 and 3 of ``var``, so it varies only along dimension 1,
    presumably the channel axis of an NCHW batch (TODO confirm).

    Returns:
        ``var * std + mean`` clamped to ``[0, 255]``.
    """
    std_b = jt.array(std).broadcast(var, [0, 2, 3])
    mean_b = jt.array(mean).broadcast(var, [0, 2, 3])
    restored = var * std_b + mean_b
    return jt.clamp(restored, 0, 255)
Exemplo n.º 29
0
    def lincomb_mask_loss(self, pos, idx_t, loc_data, mask_data, priors, proto_data, masks, gt_box_t, score_data, inst_data, labels, interpolation_mode='bilinear'):
        """Mask loss for linear-combination (prototype x coefficient) masks.

        For every image ``idx`` of the batch:
        - The ground-truth masks ``masks[idx]`` are resized to the prototype
          resolution.
        - The predicted masks of the positive priors are assembled as
          ``proto_data[idx] @ coefficients.T``.
        - These are compared with the matched ground-truth masks: BCE when
          the mask activation is sigmoid, smooth-L1 otherwise.

        Optional behaviours are switched by the global ``cfg``: cropping,
        reweighting, area normalisation, coefficient diversity and mask-IoU
        targets.

        Args:
            pos: per-prior positive flags, indexed ``pos[idx]`` per image.
                It is copied first, and zeroed for tiny ground-truth masks,
                when ``cfg.mask_proto_remove_empty_masks`` is set.
            idx_t: index of the matched ground-truth object for each prior.
            loc_data: box regressions; only decoded when
                ``cfg.mask_proto_crop_with_pred_box`` is set.
            mask_data: mask coefficients, indexed ``[idx, prior, :]``.
            priors: prior boxes, passed to ``decode``.
            proto_data: prototypes. ``shape[1]``/``shape[2]`` are used as
                ``mask_h``/``mask_w``.
            masks: per-image ground-truth masks, ``masks[idx]`` presumably
                ``(num_objs, H, W)`` (TODO confirm).
            gt_box_t: matched ground-truth box per prior (point form).
            score_data: mask scores; only read when ``cfg.use_mask_scoring``.
            inst_data: optional coefficients for the diversity loss. The mask
                coefficients are used when it is ``None``.
            labels: per-image class labels; only used for mask-IoU targets.
            interpolation_mode: mode passed to ``nn.interpolate`` when
                downsampling the ground-truth masks.

        Returns:
            dict: ``{'M': mask loss}``, plus ``'D'`` when the coefficient
            diversity loss is enabled. When ``cfg.use_maskiou`` is set, a
            tuple ``(losses, maskiou_targets)`` is returned instead.
            ``maskiou_targets`` is ``[net_input, iou_targets, labels]``, or
            ``None`` when every mask was discarded.
        """
        mask_h = proto_data.shape[1]
        mask_w = proto_data.shape[2]

        process_gt_bboxes = cfg.mask_proto_normalize_emulate_roi_pooling or cfg.mask_proto_crop

        if cfg.mask_proto_remove_empty_masks:
            # Make sure to store a copy of this because we edit it to get rid of all-zero masks
            pos = pos.clone()

        loss_m = 0
        loss_d = 0 # Coefficient diversity loss

        maskiou_t_list = []
        maskiou_net_input_list = []
        label_t_list = []

        for idx in range(mask_data.shape[0]):
            with jt.no_grad():
                downsampled_masks = nn.interpolate(masks[idx].unsqueeze(0), (mask_h, mask_w),
                                                  mode=interpolation_mode, align_corners=False).squeeze(0)
                downsampled_masks = downsampled_masks.permute(1, 2, 0)

                if cfg.mask_proto_binarize_downsampled_gt:
                    downsampled_masks = (downsampled_masks>0.5).float()

                if cfg.mask_proto_remove_empty_masks:
                    # Get rid of gt masks that are so small they get downsampled away
                    very_small_masks = (downsampled_masks.sum(0).sum(0) <= 0.0001)
                    for i in range(very_small_masks.shape[0]):
                        if very_small_masks[i]:
                            pos[idx, idx_t[idx] == i] = 0

                if cfg.mask_proto_reweight_mask_loss:
                    # Ensure that the gt is binary
                    if not cfg.mask_proto_binarize_downsampled_gt:
                        bin_gt = (downsampled_masks>0.5).float()
                    else:
                        bin_gt = downsampled_masks

                    gt_foreground_norm = bin_gt     / (jt.sum(bin_gt,   dim=(0,1), keepdim=True) + 0.0001)
                    gt_background_norm = (1-bin_gt) / (jt.sum(1-bin_gt, dim=(0,1), keepdim=True) + 0.0001)

                    mask_reweighting   = gt_foreground_norm * cfg.mask_proto_reweight_coeff + gt_background_norm
                    mask_reweighting  *= mask_h * mask_w

            cur_pos = pos[idx]
            cur_pos = jt.where(cur_pos)[0]
            pos_idx_t = idx_t[idx, cur_pos]

            if process_gt_bboxes:
                # Note: this is in point-form
                if cfg.mask_proto_crop_with_pred_box:
                    pos_gt_box_t = decode(loc_data[idx, :, :], priors.data, cfg.use_yolo_regressors)[cur_pos]
                else:
                    pos_gt_box_t = gt_box_t[idx, cur_pos]

            if pos_idx_t.shape[0] == 0:
                continue

            proto_masks = proto_data[idx]
            proto_coef  = mask_data[idx, cur_pos, :]
            if cfg.use_mask_scoring:
                mask_scores = score_data[idx, cur_pos, :]

            if cfg.mask_proto_coeff_diversity_loss:
                if inst_data is not None:
                    div_coeffs = inst_data[idx, cur_pos, :]
                else:
                    div_coeffs = proto_coef

                loss_d += self.coeff_diversity_loss(div_coeffs, pos_idx_t)

            # If we have over the allowed number of masks, select a random sample
            old_num_pos = proto_coef.shape[0]
            if old_num_pos > cfg.masks_to_train:
                perm = jt.randperm(proto_coef.shape[0])
                select = perm[:cfg.masks_to_train]

                proto_coef = proto_coef[select, :]
                pos_idx_t  = pos_idx_t[select]

                if process_gt_bboxes:
                    pos_gt_box_t = pos_gt_box_t[select, :]
                if cfg.use_mask_scoring:
                    mask_scores = mask_scores[select, :]

            num_pos = proto_coef.shape[0]
            mask_t = downsampled_masks[:, :, pos_idx_t]
            label_t = labels[idx][pos_idx_t]

            # Size: [mask_h, mask_w, num_pos]
            pred_masks = proto_masks @ proto_coef.transpose(1,0)

            pred_masks = cfg.mask_proto_mask_activation(pred_masks)

            if cfg.mask_proto_double_loss:
                if cfg.mask_proto_mask_activation == activation_func.sigmoid:
                    pre_loss = nn.bce_loss(jt.clamp(pred_masks, 0, 1), mask_t, size_average=False)
                else:
                    pre_loss = nn.smooth_l1_loss(pred_masks, mask_t, reduction='sum')

                loss_m += cfg.mask_proto_double_loss_alpha * pre_loss

            if cfg.mask_proto_crop:
                pred_masks = crop(pred_masks, pos_gt_box_t)

            if cfg.mask_proto_mask_activation == activation_func.sigmoid:
                pre_loss = binary_cross_entropy(jt.clamp(pred_masks, 0, 1), mask_t)
            else:
                pre_loss = nn.smooth_l1_loss(pred_masks, mask_t, reduction='none')

            if cfg.mask_proto_normalize_mask_loss_by_sqrt_area:
                gt_area  = jt.sum(mask_t, dim=(0, 1), keepdims=True)
                pre_loss = pre_loss / (jt.sqrt(gt_area) + 0.0001)

            if cfg.mask_proto_reweight_mask_loss:
                pre_loss = pre_loss * mask_reweighting[:, :, pos_idx_t]

            if cfg.mask_proto_normalize_emulate_roi_pooling:
                weight = mask_h * mask_w if cfg.mask_proto_crop else 1
                pos_gt_csize = center_size(pos_gt_box_t)
                gt_box_width  = pos_gt_csize[:, 2] * mask_w
                gt_box_height = pos_gt_csize[:, 3] * mask_h
                pre_loss = pre_loss.sum(0).sum(0) / gt_box_width / gt_box_height * weight

            # If the number of masks were limited scale the loss accordingly
            if old_num_pos > num_pos:
                pre_loss *= old_num_pos / num_pos

            loss_m += jt.sum(pre_loss)

            if cfg.use_maskiou:
                if cfg.discard_mask_area > 0:
                    gt_mask_area = jt.sum(mask_t, dim=(0, 1))
                    select = gt_mask_area > cfg.discard_mask_area

                    if jt.sum(select).item() < 1:
                        continue

                    pos_gt_box_t = pos_gt_box_t[select, :]
                    pred_masks = pred_masks[:, :, select]
                    mask_t = mask_t[:, :, select]
                    label_t = label_t[select]

                maskiou_net_input = pred_masks.permute(2, 0, 1).unsqueeze(1)
                pred_masks = (pred_masks>0.5).float()
                maskiou_t = self._mask_iou(pred_masks, mask_t)

                maskiou_net_input_list.append(maskiou_net_input)
                maskiou_t_list.append(maskiou_t)
                label_t_list.append(label_t)

        losses = {'M': loss_m * cfg.mask_alpha / mask_h / mask_w}

        if cfg.mask_proto_coeff_diversity_loss:
            losses['D'] = loss_d

        if cfg.use_maskiou:
            # discard_mask_area discarded every mask in the batch, so nothing to do here
            if len(maskiou_t_list) == 0:
                return losses, None

            maskiou_t = jt.contrib.concat(maskiou_t_list)
            label_t = jt.contrib.concat(label_t_list)
            maskiou_net_input = jt.contrib.concat(maskiou_net_input_list)

            num_samples = maskiou_t.shape[0]
            if cfg.maskious_to_train > 0 and num_samples > cfg.maskious_to_train:
                perm = jt.randperm(num_samples)
                # Subsample to the same limit the condition above checks
                # (maskious_to_train), not the per-image masks_to_train.
                select = perm[:cfg.maskious_to_train]
                maskiou_t = maskiou_t[select]
                label_t = label_t[select]
                maskiou_net_input = maskiou_net_input[select]

            return losses, [maskiou_net_input, maskiou_t, label_t]

        return losses
Exemplo n.º 30
0
    def execute(self, net, predictions, targets, masks, num_crowds):
        """Multibox loss: box, confidence and mask losses for a batch.

        Args:
            net: the network; only passed on to ``self.mask_iou_loss``.
            predictions (dict): network outputs keyed by
                ``'loc'``    (batch_size, num_priors, 4),
                ``'conf'``   (batch_size, num_priors, num_classes),
                ``'mask'``   (batch_size, num_priors, mask_dim),
                ``'priors'`` (num_priors, 4),
                ``'proto'``  (batch_size, mask_h, mask_w, mask_dim), lincomb only.
                ``'score'``, ``'inst'``, ``'classes'`` and ``'segm'`` are
                also read when the matching ``cfg`` options are enabled.
            targets (list of Var): ground-truth boxes and labels per image,
                shape ``[batch_size][num_objs, 5]`` (last column is the label).
            masks (list of Var): ground-truth masks per image, shape
                ``[batch_size][num_objs, im_height, im_width]``. Crowd
                entries are stripped from this list in place.
            num_crowds (list of int): number of crowd annotations per image.
                They must be the last ``num_crowds`` entries of ``targets``
                and ``masks``.

        Returns:
            dict: losses keyed as follows, each present only when enabled:
            - ``B``: box localization
            - ``C``: class confidence
            - ``M``: mask
            - ``P``: prototype
            - ``D``: coefficient diversity
            - ``I``: mask IoU
            - ``E``: class existence
            - ``S``: semantic segmentation

            ``P``, ``E`` and ``S`` are divided by the batch size; all
            others are divided by the total number of positive priors.
        """

        loc_data  = predictions['loc']
        conf_data = predictions['conf']
        mask_data = predictions['mask']
        priors    = predictions['priors']

        if cfg.mask_type == mask_type.lincomb:
            proto_data = predictions['proto']

        score_data = predictions['score'] if cfg.use_mask_scoring   else None
        inst_data  = predictions['inst']  if cfg.use_instance_coeff else None

        labels = [None] * len(targets) # Used in sem segm loss

        batch_size = loc_data.shape[0]
        num_priors = priors.shape[0]
        num_classes = self.num_classes

        # Match priors (default boxes) and ground truth boxes.
        # loc_t / conf_t / idx_t are passed to `match` below (presumably
        # filled in place there); gt_box_t is filled by the loop.
        loc_t = jt.empty((batch_size, num_priors, 4),dtype=loc_data.dtype)
        gt_box_t = jt.empty((batch_size, num_priors, 4),dtype=loc_data.dtype)
        conf_t = jt.empty((batch_size, num_priors)).int32()
        idx_t = jt.empty((batch_size, num_priors)).int32()

        if cfg.use_class_existence_loss:
            class_existence_t = jt.empty((batch_size, num_classes-1),dtype=loc_data.dtype)

        for idx in range(batch_size):
            truths      = targets[idx][:, :-1]
            labels[idx] = targets[idx][:, -1].int32()

            if cfg.use_class_existence_loss:
                # Construct a one-hot vector for each object and collapse it into an existence vector with max
                # Also it's fine to include the crowd annotations here.
                # Jittor's reduce-max returns the values directly (torch's
                # returns a (values, indices) tuple), so no trailing [0].
                class_existence_t[idx,:] = jt.eye(num_classes-1)[labels[idx]].max(dim=0)

            # Split the crowd annotations because they come bundled in
            cur_crowds = num_crowds[idx]
            if cur_crowds > 0:
                split = lambda x: (x[-cur_crowds:], x[:-cur_crowds])
                crowd_boxes, truths = split(truths)

                # We don't use the crowd labels or masks
                _, labels[idx] = split(labels[idx])
                _, masks[idx]  = split(masks[idx])
            else:
                crowd_boxes = None

            match(self.pos_threshold, self.neg_threshold,
                  truths, priors, labels[idx], crowd_boxes,
                  loc_t, conf_t, idx_t, idx, loc_data[idx])

            gt_box_t[idx,:,:] = truths[idx_t[idx]]

        # wrap targets
        loc_t.stop_grad()
        conf_t.stop_grad()
        idx_t.stop_grad()

        pos = conf_t > 0
        num_pos = pos.sum(dim=1, keepdims=True)

        # Shape: [batch,num_priors,4]
        pos_idx = pos.unsqueeze(pos.ndim).expand_as(loc_data)

        losses = {}
        # Only produced on the lincomb + use_maskiou path below; default to
        # None so the Mask IoU check never reads an unbound name.
        maskiou_targets = None

        # Localization Loss (Smooth L1)
        if cfg.train_boxes:
            loc_p = loc_data[pos_idx].view(-1, 4)
            loc_t = loc_t[pos_idx].view(-1, 4)
            losses['B'] = nn.smooth_l1_loss(loc_p, loc_t, reduction='sum') * cfg.bbox_alpha

        if cfg.train_masks:
            if cfg.mask_type == mask_type.direct:
                if cfg.use_gt_bboxes:
                    pos_masks = []
                    for idx in range(batch_size):
                        pos_masks.append(masks[idx][idx_t[idx, pos[idx]]])
                    masks_t = jt.contrib.concat(pos_masks, 0)
                    masks_p = mask_data[pos, :].view(-1, cfg.mask_dim)
                    losses['M'] = nn.bce_loss(jt.clamp(masks_p, 0, 1), masks_t, size_average=False) * cfg.mask_alpha
                else:
                    losses['M'] = self.direct_mask_loss(pos_idx, idx_t, loc_data, mask_data, priors, masks)
            elif cfg.mask_type == mask_type.lincomb:
                ret = self.lincomb_mask_loss(pos, idx_t, loc_data, mask_data, priors, proto_data, masks, gt_box_t, score_data, inst_data, labels)
                if cfg.use_maskiou:
                    loss, maskiou_targets = ret
                else:
                    loss = ret
                losses.update(loss)

                if cfg.mask_proto_loss is not None:
                    if cfg.mask_proto_loss == 'l1':
                        losses['P'] = jt.mean(jt.abs(proto_data)) / self.l1_expected_area * self.l1_alpha
                    elif cfg.mask_proto_loss == 'disj':
                        # Reduce-max already returns the values Var in jittor.
                        losses['P'] = -jt.mean(jt.max(nn.log_softmax(proto_data, dim=-1), dim=-1))

        # Confidence loss
        if cfg.use_focal_loss:
            if cfg.use_sigmoid_focal_loss:
                losses['C'] = self.focal_conf_sigmoid_loss(conf_data, conf_t)
            elif cfg.use_objectness_score:
                losses['C'] = self.focal_conf_objectness_loss(conf_data, conf_t)
            else:
                losses['C'] = self.focal_conf_loss(conf_data, conf_t)
        else:
            if cfg.use_objectness_score:
                losses['C'] = self.conf_objectness_loss(conf_data, conf_t, batch_size, loc_p, loc_t, priors)
            else:
                losses['C'] = self.ohem_conf_loss(conf_data, conf_t, pos, batch_size)

        # Mask IoU Loss
        if cfg.use_maskiou and maskiou_targets is not None:
            losses['I'] = self.mask_iou_loss(net, maskiou_targets)

        # These losses also don't depend on anchors
        if cfg.use_class_existence_loss:
            losses['E'] = self.class_existence_loss(predictions['classes'], class_existence_t)
        if cfg.use_semantic_segmentation_loss:
            losses['S'] = self.semantic_segmentation_loss(predictions['segm'], masks, labels)

        # Divide all losses by the number of positives.
        # P, E and S don't depend on the anchors, so they are divided by the
        # batch size instead.
        total_num_pos = num_pos.sum().float()
        for k in losses:
            if k not in ('P', 'E', 'S'):
                losses[k] /= total_num_pos
            else:
                losses[k] /= batch_size

        # Loss Key:
        #  - B: Box Localization Loss
        #  - C: Class Confidence Loss
        #  - M: Mask Loss
        #  - P: Prototype Loss
        #  - D: Coefficient Diversity Loss
        #  - I: Mask IoU Loss
        #  - E: Class Existence Loss
        #  - S: Semantic Segmentation Loss
        return losses