def _create_label(self, anchor, bbox):
    """Assign an objectness label to every anchor.

    Label values: 1 = positive, 0 = negative, -1 = ignored ("don't care").

    Anchors are first labelled from their IoU against the ground-truth boxes
    (via ``self._calc_ious``). The labels are then randomly subsampled:
    * at most ``int(self.pos_ratio * self.n_sample)`` positives are kept;
    * at most ``self.n_sample - <remaining positives>`` negatives are kept.
    Any surplus is relabelled -1.

    Returns:
        (argmax_ious, label): the first value produced by ``self._calc_ious``
        and the int32 label Var with one entry per anchor.
    """
    label = -jt.ones((anchor.shape[0], ), dtype="int32")
    argmax_ious, max_ious, gt_argmax_ious = self._calc_ious(anchor, bbox)

    # Order matters: negatives are written first so that the two positive
    # rules below can overwrite them.
    label[max_ious < self.neg_iou_thresh] = 0
    # Each gt box's best-overlapping anchor(s) are positive ...
    label[gt_argmax_ious] = 1
    # ... as is every anchor above the positive IoU threshold.
    label[max_ious >= self.pos_iou_thresh] = 1

    def _ignore_surplus(candidates, budget):
        # Randomly relabel all but `budget` entries of `candidates` as -1.
        if not len(candidates) > budget:
            return
        n_candidates = candidates.shape[0]
        order = np.arange(0, n_candidates)
        np.random.shuffle(order)
        label[candidates[order[:n_candidates - budget]]] = -1

    _ignore_surplus(jt.where(label == 1)[0], int(self.pos_ratio * self.n_sample))

    # The negative budget is whatever the (possibly reduced) positives leave.
    n_neg = self.n_sample - jt.sum(label == 1).item()
    _ignore_surplus(jt.where(label == 0)[0], n_neg)

    return argmax_ious, label
def test(self):
    # 1-D input: a single index Var with the positions of the non-zero items.
    assert (jt.where([0, 1, 0, 1])[0].data == [1, 3]).all()

    idx, = jt.where([0, 1, 0, 1])
    # Before execution the length is data-dependent; the negative entry
    # appears to encode the upper bound (the input length).
    assert idx.uncertain_shape == [-4]
    idx.data  # force execution so the concrete length becomes known
    assert idx.uncertain_shape == [2]

    # 2-D input: one index Var per dimension.
    rows, cols = jt.where([[0, 0, 1], [1, 0, 0]])
    assert (rows.data == [0, 1]).all()
    assert (cols.data == [2, 0]).all()
def detect(self, batch_idx, conf_preds, decoded_boxes, mask_data, inst_data):
    """Perform nms for only the max scoring class that isn't background (class 0).

    Args:
        batch_idx: index of the image in the batch.
        conf_preds: class confidences. ``conf_preds[batch_idx]`` is indexed as
            (class, prior): column 0 of the class axis is dropped as background.
        decoded_boxes: boxes indexed by prior.
        mask_data: mask coefficients indexed as ``[batch, prior, ...]``.
        inst_data: accepted for interface compatibility; not used here.

    Returns:
        ``None`` when no prior's best foreground score exceeds
        ``self.conf_thresh``. Otherwise a dict with keys 'box', 'mask',
        'class' and 'score' from the selected NMS routine.
    """
    with timer.env('Slices'):
        # Drop the background row, then keep the priors whose best foreground
        # score is above the confidence threshold.
        cur_scores = conf_preds[batch_idx, 1:]
        conf_scores = jt.max(cur_scores, dim=0)
        keep = jt.where(conf_scores > self.conf_thresh)[0]

        scores = cur_scores[:, keep]
        boxes = decoded_boxes[keep]
        masks = mask_data[batch_idx, keep]

    if scores.shape[1] == 0:
        return None

    if self.use_fast_nms:
        nms_fn = self.cc_fast_nms if self.use_cross_class_nms else self.fast_nms
        boxes, masks, classes, scores = nms_fn(
            boxes, masks, scores, self.nms_thresh, self.top_k)
    else:
        boxes, masks, classes, scores = self.traditional_nms(
            boxes, masks, scores, self.nms_thresh, self.conf_thresh)
        if self.use_cross_class_nms:
            print(
                'Warning: Cross Class Traditional NMS is not implemented.')

    return {'box': boxes, 'mask': masks, 'class': classes, 'score': scores}
def predict(self, images, score_thresh=0.7, nms_thresh=0.3):
    """Run inference and return per-image detections after per-class NMS.

    Args:
        images: batch of images; ``images.shape[0]`` is the batch size and the
            last two dims give (height, width).
        score_thresh: minimum class probability for a box to be kept.
        nms_thresh: IoU threshold passed to ``jt.nms``.

    Returns:
        list: one ``(boxes, scores, labels)`` tuple per image. Labels are the
        class indices ``1 .. self.n_class - 1``; class column 0 is skipped.
    """
    batch_size = images.shape[0]
    width, height = images.shape[-1], images.shape[-2]

    _, _, roi_cls_locs, roi_scores, rois, roi_indices = self.execute(images)
    roi_cls_locs = roi_cls_locs.reshape(roi_cls_locs.shape[0], -1, 4)
    probs = nn.softmax(roi_scores, dim=-1)

    # Decode one box per (roi, class) pair, then clip to the image:
    # even columns against the width, odd columns against the height.
    tiled_rois = rois.unsqueeze(1).repeat(1, self.n_class, 1)
    cls_bbox = loc2bbox(tiled_rois.reshape(-1, 4), roi_cls_locs.reshape(-1, 4))
    cls_bbox[:, 0::2] = jt.clamp(cls_bbox[:, 0::2], min_v=0, max_v=width)
    cls_bbox[:, 1::2] = jt.clamp(cls_bbox[:, 1::2], min_v=0, max_v=height)
    cls_bbox = cls_bbox.reshape(roi_cls_locs.shape)

    results = []
    for img_idx in range(batch_size):
        # RoIs belonging to this image.
        sel = jt.where(roi_indices == img_idx)[0]
        img_scores = probs[sel, :]
        img_boxes = cls_bbox[sel, :, :]

        box_parts, score_parts, label_parts = [], [], []
        for cls in range(1, self.n_class):
            cls_scores = img_scores[:, cls]
            hits = jt.where(cls_scores > score_thresh)[0]
            cls_boxes = img_boxes[:, cls, :][hits, :]
            cls_scores = cls_scores[hits]
            # NOTE(review): `hits` may be empty; confirm jt.nms accepts an
            # empty detection array.
            dets = jt.contrib.concat([cls_boxes, cls_scores.unsqueeze(1)], dim=1)
            survivors = jt.nms(dets, nms_thresh)
            cls_boxes = cls_boxes[survivors]
            cls_scores = cls_scores[survivors]
            box_parts.append(cls_boxes)
            score_parts.append(cls_scores)
            label_parts.append(jt.ones_like(cls_scores).int32() * cls)

        results.append((
            jt.contrib.concat(box_parts, dim=0),
            jt.contrib.concat(score_parts, dim=0),
            jt.contrib.concat(label_parts, dim=0),
        ))
    return results
def _calc_ious(self, anchor, bbox):
    """Compute the anchor / ground-truth IoU statistics used for labelling.

    Assumes ``bbox_iou(anchor, bbox)`` is laid out as (anchor, gt box).

    Returns:
        argmax_ious: per anchor, the index of its best-overlapping gt box.
        max_ious: per anchor, that best IoU value.
        gt_argmax_ious: anchor (row) indices of every IoU entry equal to its
            gt column's maximum. Ties are all included, and an anchor can
            appear more than once.
    """
    # ious between the anchors and the gt boxes
    ious = bbox_iou(anchor, bbox)
    best_gt_idx, best_gt_iou = ious.argmax(dim=1)
    _, best_anchor_iou = ious.argmax(dim=0)
    # NOTE(review): if a gt box overlaps no anchor, its column max is 0 and
    # every anchor with IoU 0 to it is returned here.
    best_anchor_idx = jt.where(ious == best_anchor_iou)[0]
    return best_gt_idx, best_gt_iou, best_anchor_idx
def nonzero(x):
    """Return the indices of the elements of ``x`` that are not zero.

    ``jt.where`` yields one index Var per dimension. Each becomes a column:
    * 1-D input: the single column, shape (n, 1), is returned;
    * higher-rank input: the columns are concatenated into an (n, ndim) Var.
    """
    columns = [coord.unsqueeze(1) for coord in jt.where(x)]
    if len(columns) >= 2:
        return jt.contrib.concat(columns, dim=1)
    return columns[0]
def test_vary_shape_dep2(self):
    src = jt.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    # Rows whose sum exceeds 7 -> [1, 2]; the count is data-dependent.
    rows, = jt.where(src.sum(1) > 7)
    rows = rows.broadcast([1, 3], dims=[1])  # [[1,1,1],[2,2,2]]
    cols = rows.index_var(1)                 # [[0,1,2],[0,1,2]]
    picked = src.reindex_var([rows, cols])

    # Checked before any .data access, while dim 0 is still unresolved.
    assert picked.uncertain_shape == [-3, 3]
    assert (picked.data == [[4, 5, 6], [7, 8, 9]]).all()
    assert (rows.data == [[1, 1, 1], [2, 2, 2]]).all()
    assert (cols.data == [[0, 1, 2], [0, 1, 2]]).all()
def remove_small_boxes(boxlist, min_size):
    """Keep only the boxes whose width and height are both >= ``min_size``.

    Arguments:
        boxlist (BoxList): boxes to filter; converted to "xywh" to read sizes.
        min_size (int): smallest allowed side length.

    Returns:
        ``boxlist`` indexed by the surviving box indices.
    """
    # TODO maybe add an API for querying the ws / hs
    xywh = boxlist.convert("xywh").bbox
    _, _, widths, heights = xywh.unbind(dim=1)
    big_enough = jt.logical_and(widths >= min_size, heights >= min_size)
    return boxlist[jt.where(big_enough)[0]]
def _forward_train(self,features,img_size,boxes,labels):
    """Compute the Faster R-CNN training losses for a batch.

    Args:
        features: backbone feature maps; ``features.shape[0]`` is the batch size N.
        img_size: forwarded unchanged to ``self.rpn`` and ``self.anchor_target_creator``.
        boxes: per-image ground-truth boxes, indexed ``boxes[i]``.
        labels: per-image ground-truth labels, indexed ``labels[i]``.

    Returns:
        list: ``[rpn_loc_loss, rpn_cls_loss, roi_loc_loss, roi_cls_loss, total]``,
        where ``total`` is the sum of the first four.
    """
    N = features.shape[0]
    rpn_locs, rpn_scores, rois, roi_indices, anchor = self.rpn(features, img_size)

    sample_rois = []
    gt_roi_locs = []
    gt_roi_labels = []
    sample_roi_indexs = []
    gt_rpn_locs = []
    gt_rpn_labels = []
    for i in range(N):
        # Proposals that belong to image i.
        index = jt.where(roi_indices == i)[0]
        roi = rois[index,:]
        box = boxes[i]
        label = labels[i]
        # Sample RoIs and build their regression / classification targets.
        sample_roi, gt_roi_loc, gt_roi_label = self.proposal_target_creator(roi,box,label)
        # Tag every sampled RoI with its image index for the RoI head.
        sample_roi_index = i*jt.ones((sample_roi.shape[0],))

        sample_rois.append(sample_roi)
        gt_roi_labels.append(gt_roi_label)
        gt_roi_locs.append(gt_roi_loc)
        sample_roi_indexs.append(sample_roi_index)

        # Per-anchor RPN targets (offsets and 1/0/-1 labels) for this image.
        gt_rpn_loc, gt_rpn_label = self.anchor_target_creator(box,anchor,img_size)
        gt_rpn_locs.append(gt_rpn_loc)
        gt_rpn_labels.append(gt_rpn_label)

    sample_roi_indexs = jt.contrib.concat(sample_roi_indexs,dim=0)
    sample_rois = jt.contrib.concat(sample_rois,dim=0)
    roi_cls_loc, roi_score = self.head(features,sample_rois,sample_roi_indexs)

    # ------------------ RPN losses -------------------#
    rpn_locs = rpn_locs.reshape(-1,4)
    rpn_scores = rpn_scores.reshape(-1,2)
    gt_rpn_labels = jt.contrib.concat(gt_rpn_labels,dim=0)
    gt_rpn_locs = jt.contrib.concat(gt_rpn_locs,dim=0)
    rpn_loc_loss = _fast_rcnn_loc_loss(rpn_locs,gt_rpn_locs,gt_rpn_labels,self.rpn_sigma)
    # Anchors labelled -1 ("ignore") are excluded from the classification loss.
    rpn_cls_loss = nn.cross_entropy_loss(rpn_scores[gt_rpn_labels>=0,:],gt_rpn_labels[gt_rpn_labels>=0])

    # ------------------ ROI losses (fast rcnn loss) -------------------#
    gt_roi_locs = jt.contrib.concat(gt_roi_locs,dim=0)
    gt_roi_labels = jt.contrib.concat(gt_roi_labels,dim=0)
    n_sample = roi_cls_loc.shape[0]
    # View the head output as (n_sample, n_per_roi_boxes, 4): one 4-vector per class.
    roi_cls_loc = roi_cls_loc.view(n_sample, np.prod(roi_cls_loc.shape[1:]).item()//4, 4)
    # For each sampled RoI, take the offsets predicted for its target class.
    roi_loc = roi_cls_loc[jt.arange(0, n_sample).int32(), gt_roi_labels]
    roi_loc_loss = _fast_rcnn_loc_loss(roi_loc,gt_roi_locs,gt_roi_labels,self.roi_sigma)
    roi_cls_loss = nn.cross_entropy_loss(roi_score, gt_roi_labels)

    losses = [rpn_loc_loss, rpn_cls_loss, roi_loc_loss, roi_cls_loss]
    # Append the total loss as the last element.
    losses = losses + [sum(losses)]
    return losses
def execute(self, bbox, anchor, img_size):
    """Assign ground truth supervision to sampled subset of anchors.

    Only anchors lying entirely inside the image are labelled. Their targets
    are scattered back onto the full anchor set afterwards.

    Notations:
    * :math:`S` is the number of anchors.
    * :math:`R` is the number of bounding boxes.

    Args:
        bbox (array): Coordinates of bounding boxes. Its shape is :math:`(R, 4)`.
        anchor (array): Coordinates of anchors. Its shape is :math:`(S, 4)`.
        img_size (tuple of ints): Unpacked as ``(width, height)``. Anchor
            columns 0/2 are checked against the width and columns 1/3 against
            the height.

    Returns:
        (array, array):
        * **loc**: Offsets and scales (not only offsets) to match the anchors
          to the ground truth bounding boxes. Its shape is :math:`(S, 4)`.
          Rows of anchors outside the image are filled (via ``_unmap``) with 0.
        * **label**: Labels of anchors with values
          :obj:`(1=positive, 0=negative, -1=ignore)`. Its shape is
          :math:`(S,)`. Anchors outside the image are filled with -1.
    """
    img_W, img_H = img_size
    total_anchors = len(anchor)

    # Keep only the anchors that lie fully inside the image.
    inside_mask = ((anchor[:, 0] >= 0) & (anchor[:, 1] >= 0)
                   & (anchor[:, 2] <= img_W) & (anchor[:, 3] <= img_H))
    inside_index = jt.where(inside_mask)[0]
    inside_anchor = anchor[inside_index]

    argmax_ious, inside_label = self._create_label(inside_anchor, bbox)
    # Regression targets towards each anchor's best-matching gt box.
    inside_loc = bbox2loc(inside_anchor, bbox[argmax_ious])

    # Scatter back onto the original (full) anchor set.
    full_label = _unmap(inside_label, total_anchors, inside_index, fill=-1)
    full_loc = _unmap(inside_loc, total_anchors, inside_index, fill=0)
    return full_loc, full_label
def process_batch(self, detections, labels):
    """
    Accumulate one image's detections into the confusion matrix.

    Both sets of boxes are expected to be in (x1, y1, x2, y2) format.

    Arguments:
        detections (Array[N, 6]), x1, y1, x2, y2, conf, class
        labels (Array[M, 5]), class, x1, y1, x2, y2
    Returns:
        None, updates confusion matrix accordingly

    ``self.matrix`` is indexed as ``[predicted_class, true_class]``. Index
    ``self.nc`` stands for background:
    * an unmatched label is counted in row ``self.nc``;
    * an unmatched detection is counted in column ``self.nc``;
    * matched pairs use the same orientation.
    """
    # Discard detections below the confidence threshold.
    detections = detections[detections[:, 4] > self.conf]
    gt_classes = labels[:, 0].int()
    detection_classes = detections[:, 5].int()
    # IoU matrix laid out as (label, detection).
    iou = general.box_iou(labels[:, 1:], detections[:, :4])

    # Candidate (label_index, detection_index) pairs above the IoU threshold.
    x = jt.where(iou > self.iou_thres)
    if x[0].shape[0]:
        # Rows of `matches`: [label_index, detection_index, iou].
        matches = jt.contrib.concat(
            (jt.stack(x, 1), iou[x[0], x[1]][:, None]), 1).numpy()
        if x[0].shape[0] > 1:
            # Greedy one-to-one matching: sort by IoU (descending), keep the
            # best pair per detection, then the best pair per label.
            matches = matches[matches[:, 2].argsort()[::-1]]
            matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
            matches = matches[matches[:, 2].argsort()[::-1]]
            matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
    else:
        matches = np.zeros((0, 3))

    n = matches.shape[0] > 0
    # int64 rather than int16 so the indices cannot overflow on crowded images.
    m0, m1, _ = matches.transpose().astype(np.int64)
    for i, gc in enumerate(gt_classes):
        j = m0 == i
        if n and sum(j) == 1:
            # Row = predicted class, column = true class, matching the
            # orientation of the two background updates.
            self.matrix[detection_classes[m1[j]], gc] += 1  # correct
        else:
            self.matrix[self.nc, gc] += 1  # label missed: predicted background

    if n:
        for i, dc in enumerate(detection_classes):
            if not any(m1 == i):
                self.matrix[dc, self.nc] += 1  # unmatched detection: true background
def execute(self, roi, bbox, label):
    """Assigns ground truth to sampled proposals.

    Up to :obj:`self.n_sample` RoIs are sampled from :obj:`roi`.

    * Foreground RoIs have IoU >= :obj:`self.pos_iou_thresh` with their
      best-matching gt box. At most ``round(self.n_sample * self.pos_ratio)``
      of them are drawn, at random.
    * Background RoIs have IoU in [:obj:`self.neg_iou_thresh_lo`,
      :obj:`self.neg_iou_thresh_hi`). They are drawn at random to fill the
      remaining slots.

    Each sampled RoI gets the label of its best-matching gt box, or 0 for
    background samples. It also gets the offsets and scales towards that box,
    computed by :func:`bbox2loc`.

    NOTE(review): the gt boxes are not appended to :obj:`roi` before sampling
    here; only :obj:`roi` itself is sampled.

    Notations:
    * :math:`S` is the number of sampled RoIs. It is at most
      :obj:`self.n_sample`, and fewer when there are not enough candidates.
    * :math:`L` is number of object classes possibly including the background.

    Args:
        roi (array): Region of Interests (RoIs) from which we sample.
            Its shape is :math:`(R, 4)`.
        bbox (array): The coordinates of ground truth bounding boxes.
            Its shape is :math:`(R', 4)`.
        label (array): Ground truth bounding box labels. Its shape is
            :math:`(R',)`. NOTE(review): the values are used as-is (no +1
            offset) while 0 marks background in the output. Confirm that
            dataset labels are already 1-based.

    Returns:
        (array, array, array):
        * **sample_roi**: Regions of interests that are sampled.
          Its shape is :math:`(S, 4)`.
        * **gt_roi_loc**: Offsets and scales to match the sampled RoIs to the
          ground truth bounding boxes. Its shape is :math:`(S, 4)`.
        * **gt_roi_label**: Labels assigned to sampled RoIs. Its shape is
          :math:`(S,)`. Positive samples come first; the label with value 0
          is the background.
    """
    pos_roi_per_image = np.round(self.n_sample * self.pos_ratio)
    iou = bbox_iou(roi, bbox)
    # Per RoI: index of, and IoU with, its best-matching gt box.
    gt_assignment, max_iou = iou.argmax(dim=1)
    # Label of the best-matching gt box (no offset is applied here).
    gt_roi_label = label[gt_assignment]

    # Select foreground RoIs as those with >= pos_iou_thresh IoU.
    pos_index = jt.where(max_iou >= self.pos_iou_thresh)[0]
    pos_roi_per_this_image = int(min(pos_roi_per_image, pos_index.shape[0]))
    if pos_index.shape[0] > 0:
        # Random subset of size pos_roi_per_this_image.
        tmp_indexes = np.arange(0, pos_index.shape[0])
        np.random.shuffle(tmp_indexes)
        tmp_indexes = tmp_indexes[:pos_roi_per_this_image]
        pos_index = pos_index[tmp_indexes]

    # Select background RoIs as those within
    # [neg_iou_thresh_lo, neg_iou_thresh_hi).
    neg_index = jt.where((max_iou < self.neg_iou_thresh_hi) & (max_iou >= self.neg_iou_thresh_lo))[0]
    # Negatives fill whatever the positives left of the n_sample budget.
    neg_roi_per_this_image = self.n_sample - pos_roi_per_this_image
    neg_roi_per_this_image = int(
        min(neg_roi_per_this_image, neg_index.shape[0]))
    if neg_index.shape[0] > 0:
        tmp_indexes = np.arange(0, neg_index.shape[0])
        np.random.shuffle(tmp_indexes)
        tmp_indexes = tmp_indexes[:neg_roi_per_this_image]
        neg_index = neg_index[tmp_indexes]

    # The indices that we're selecting (both positive and negative).
    keep_index = jt.contrib.concat((pos_index, neg_index), dim=0)
    gt_roi_label = gt_roi_label[keep_index]
    gt_roi_label[pos_roi_per_this_image:] = 0  # negative labels --> 0
    sample_roi = roi[keep_index]

    # Compute offsets and scales to match sampled RoIs to the GTs.
    gt_roi_loc = bbox2loc(sample_roi, bbox[gt_assignment[keep_index]])

    return sample_roi, gt_roi_loc, gt_roi_label
def execute(self, loc, score, anchor, img_size, scale=1.):
    """Turn per-anchor regressions and scores into proposal boxes (RoIs).

    Inputs :obj:`loc, score, anchor` refer to the same anchor when indexed by
    the same index. :math:`R` is the total number of anchors.

    Pipeline:
    1. decode the anchors with ``loc`` (``loc2bbox``);
    2. clip the boxes to the image;
    3. drop boxes smaller than ``self.min_size * scale``;
    4. sort by score and keep the top ``n_pre_nms`` (when > 0);
    5. run NMS and keep the first ``n_post_nms`` (when > 0).
    The train-time or test-time limits are chosen by ``self.is_training()``.

    Args:
        loc (array): Predicted offsets and scaling to anchors.
            Its shape is :math:`(R, 4)`.
        score (array): Predicted foreground probability for anchors.
            Its shape is :math:`(R,)`.
        anchor (array): Coordinates of anchors. Its shape is :math:`(R, 4)`.
        img_size (tuple of ints): Image size after scaling. Box columns 0/2
            are clamped to ``img_size[0]`` and columns 1/3 to ``img_size[1]``.
            NOTE(review): confirm whether callers pass (width, height) or
            (height, width).
        scale (float): The scaling factor used to scale an image after
            reading it from a file.

    Returns:
        array: An array of coordinates of proposal boxes, shape :math:`(S, 4)`.
    """
    training = self.is_training()
    n_pre_nms = self.n_train_pre_nms if training else self.n_test_pre_nms
    n_post_nms = self.n_train_post_nms if training else self.n_test_post_nms

    # Convert anchors into proposals, then clip each coordinate column.
    roi = loc2bbox(anchor, loc)
    for col, limit in ((0, img_size[0]), (2, img_size[0]),
                       (1, img_size[1]), (3, img_size[1])):
        roi[:, col] = jt.clamp(roi[:, col], min_v=0, max_v=limit)

    # Drop boxes whose extent along either coordinate pair is too small.
    min_size = self.min_size * scale
    extent_02 = roi[:, 2] - roi[:, 0]
    extent_13 = roi[:, 3] - roi[:, 1]
    big_enough = jt.where((extent_02 >= min_size) & (extent_13 >= min_size))[0]
    roi = roi[big_enough, :]
    score = score[big_enough]

    # Highest score first, optionally truncated before NMS.
    order, _ = jt.argsort(score, descending=True)
    if n_pre_nms > 0:
        order = order[:n_pre_nms]
    roi = roi[order, :]
    score = score[order]

    dets = jt.contrib.concat([roi, score.unsqueeze(1)], dim=1)
    survivors = jt.nms(dets, self.nms_thresh)
    if n_post_nms > 0:
        survivors = survivors[:n_post_nms]
    return roi[survivors]
def test_vary_shape_setitem(self):
    target = jt.array([1, 2, 3, 4, 5])
    probe = jt.array([1, 2, 3, 4, 5])
    # jt.where returns a tuple of index Vars that can be used as a subscript.
    positions = jt.where(probe > 3)
    target[positions] = 0
    expected = [1, 2, 3, 0, 0]
    assert (target.data == expected).all()
def test_vary_shape_dep(self):
    src, = jt.where([1, 0, 1])
    idx, = src.index_var()
    # Both shapes are still unresolved (data-dependent) at this point.
    assert src.uncertain_shape == [-3]
    assert idx.uncertain_shape == [-3]
    assert (idx.data == [0, 1]).all()
def lincomb_mask_loss(self, pos, idx_t, loc_data, mask_data, priors, proto_data, masks, gt_box_t, score_data, inst_data, labels, interpolation_mode='bilinear'):
    """Mask loss for the prototype x coefficient ("lincomb") mask head.

    For each image, predicted masks are ``proto_data[idx] @ coef.T`` over the
    positive priors, passed through ``cfg.mask_proto_mask_activation``. They
    are compared against the gt masks downsampled to the prototype resolution.
    The ``cfg.mask_proto_*`` flags select the loss variant (BCE vs smooth-L1,
    cropping, reweighting, area normalisation, ...).

    Args:
        pos: per-image positive-prior flags; non-zero entries of ``pos[idx]``
            are the positives.
        idx_t: per-prior index of the matched gt object.
        loc_data: box regressions, decoded only when
            ``cfg.mask_proto_crop_with_pred_box``.
        mask_data: mask coefficients, indexed ``[idx, prior, :]``.
        priors: prior boxes (``priors.data`` is passed to ``decode``).
        proto_data: prototypes; ``shape[1]`` / ``shape[2]`` give mask h / w.
        masks: per-image gt masks, interpolated to (mask_h, mask_w).
        gt_box_t: matched gt boxes per prior (point form).
        score_data: mask-scoring outputs, sliced when ``cfg.use_mask_scoring``.
        inst_data: optional alternative coefficients for the diversity loss.
        labels: per-image gt class labels.
        interpolation_mode: mode for ``nn.interpolate`` when downsampling.

    Returns:
        The ``losses`` dict, with key 'M' plus 'D' when the coefficient
        diversity loss is enabled. When ``cfg.use_maskiou`` is set, the return
        value is instead one of:
        * ``(losses, [maskiou_net_input, maskiou_t, label_t])``;
        * ``(losses, None)`` if ``cfg.discard_mask_area`` removed every mask.
    """
    mask_h = proto_data.shape[1]
    mask_w = proto_data.shape[2]

    process_gt_bboxes = cfg.mask_proto_normalize_emulate_roi_pooling or cfg.mask_proto_crop

    if cfg.mask_proto_remove_empty_masks:
        # Make sure to store a copy of this because we edit it to get rid of all-zero masks
        pos = pos.clone()

    loss_m = 0
    loss_d = 0  # Coefficient diversity loss

    maskiou_t_list = []
    maskiou_net_input_list = []
    label_t_list = []

    for idx in range(mask_data.shape[0]):
        with jt.no_grad():
            downsampled_masks = nn.interpolate(masks[idx].unsqueeze(0), (mask_h, mask_w), mode=interpolation_mode, align_corners=False).squeeze(0)
            # (num_objs, h, w) -> (h, w, num_objs)
            downsampled_masks = downsampled_masks.permute(1, 2, 0)

            if cfg.mask_proto_binarize_downsampled_gt:
                downsampled_masks = (downsampled_masks>0.5).float()

            if cfg.mask_proto_remove_empty_masks:
                # Get rid of gt masks that are so small they get downsampled away
                very_small_masks = (downsampled_masks.sum(0).sum(0) <= 0.0001)
                for i in range(very_small_masks.shape[0]):
                    if very_small_masks[i]:
                        pos[idx, idx_t[idx] == i] = 0

            if cfg.mask_proto_reweight_mask_loss:
                # Ensure that the gt is binary
                if not cfg.mask_proto_binarize_downsampled_gt:
                    bin_gt = (downsampled_masks>0.5).float()
                else:
                    bin_gt = downsampled_masks

                gt_foreground_norm = bin_gt / (jt.sum(bin_gt, dim=(0,1), keepdim=True) + 0.0001)
                gt_background_norm = (1-bin_gt) / (jt.sum(1-bin_gt, dim=(0,1), keepdim=True) + 0.0001)

                mask_reweighting = gt_foreground_norm * cfg.mask_proto_reweight_coeff + gt_background_norm
                mask_reweighting *= mask_h * mask_w

        cur_pos = pos[idx]
        cur_pos = jt.where(cur_pos)[0]
        pos_idx_t = idx_t[idx, cur_pos]

        if process_gt_bboxes:
            # Note: this is in point-form
            if cfg.mask_proto_crop_with_pred_box:
                pos_gt_box_t = decode(loc_data[idx, :, :], priors.data, cfg.use_yolo_regressors)[cur_pos]
            else:
                pos_gt_box_t = gt_box_t[idx, cur_pos]

        if pos_idx_t.shape[0] == 0:
            continue

        proto_masks = proto_data[idx]
        proto_coef = mask_data[idx, cur_pos, :]
        if cfg.use_mask_scoring:
            mask_scores = score_data[idx, cur_pos, :]

        if cfg.mask_proto_coeff_diversity_loss:
            if inst_data is not None:
                div_coeffs = inst_data[idx, cur_pos, :]
            else:
                div_coeffs = proto_coef

            loss_d += self.coeff_diversity_loss(div_coeffs, pos_idx_t)

        # If we have over the allowed number of masks, select a random sample
        old_num_pos = proto_coef.shape[0]
        if old_num_pos > cfg.masks_to_train:
            perm = jt.randperm(proto_coef.shape[0])
            select = perm[:cfg.masks_to_train]

            proto_coef = proto_coef[select, :]
            pos_idx_t = pos_idx_t[select]

            if process_gt_bboxes:
                pos_gt_box_t = pos_gt_box_t[select, :]
            if cfg.use_mask_scoring:
                mask_scores = mask_scores[select, :]

        num_pos = proto_coef.shape[0]
        mask_t = downsampled_masks[:, :, pos_idx_t]
        label_t = labels[idx][pos_idx_t]

        # Size: [mask_h, mask_w, num_pos]
        pred_masks = proto_masks @ proto_coef.transpose(1,0)
        pred_masks = cfg.mask_proto_mask_activation(pred_masks)

        if cfg.mask_proto_double_loss:
            if cfg.mask_proto_mask_activation == activation_func.sigmoid:
                pre_loss = nn.bce_loss(jt.clamp(pred_masks, 0, 1), mask_t, size_average=False)
            else:
                pre_loss = nn.smooth_l1_loss(pred_masks, mask_t, reduction='sum')

            loss_m += cfg.mask_proto_double_loss_alpha * pre_loss

        if cfg.mask_proto_crop:
            pred_masks = crop(pred_masks, pos_gt_box_t)

        if cfg.mask_proto_mask_activation == activation_func.sigmoid:
            pre_loss = binary_cross_entropy(jt.clamp(pred_masks, 0, 1), mask_t)
        else:
            pre_loss = nn.smooth_l1_loss(pred_masks, mask_t, reduction='none')

        if cfg.mask_proto_normalize_mask_loss_by_sqrt_area:
            gt_area = jt.sum(mask_t, dim=(0, 1), keepdims=True)
            pre_loss = pre_loss / (jt.sqrt(gt_area) + 0.0001)

        if cfg.mask_proto_reweight_mask_loss:
            pre_loss = pre_loss * mask_reweighting[:, :, pos_idx_t]

        if cfg.mask_proto_normalize_emulate_roi_pooling:
            weight = mask_h * mask_w if cfg.mask_proto_crop else 1
            pos_gt_csize = center_size(pos_gt_box_t)
            gt_box_width = pos_gt_csize[:, 2] * mask_w
            gt_box_height = pos_gt_csize[:, 3] * mask_h
            pre_loss = pre_loss.sum(0).sum(0) / gt_box_width / gt_box_height * weight

        # If the number of masks were limited scale the loss accordingly
        if old_num_pos > num_pos:
            pre_loss *= old_num_pos / num_pos

        loss_m += jt.sum(pre_loss)

        if cfg.use_maskiou:
            if cfg.discard_mask_area > 0:
                gt_mask_area = jt.sum(mask_t, dim=(0, 1))
                select = gt_mask_area > cfg.discard_mask_area

                if jt.sum(select).item() < 1:
                    continue

                # pos_gt_box_t only exists when process_gt_bboxes is set.
                if process_gt_bboxes:
                    pos_gt_box_t = pos_gt_box_t[select, :]
                pred_masks = pred_masks[:, :, select]
                mask_t = mask_t[:, :, select]
                label_t = label_t[select]

            maskiou_net_input = pred_masks.permute(2, 0, 1).unsqueeze(1)
            pred_masks = (pred_masks>0.5).float()
            maskiou_t = self._mask_iou(pred_masks, mask_t)

            maskiou_net_input_list.append(maskiou_net_input)
            maskiou_t_list.append(maskiou_t)
            label_t_list.append(label_t)

    losses = {'M': loss_m * cfg.mask_alpha / mask_h / mask_w}

    if cfg.mask_proto_coeff_diversity_loss:
        losses['D'] = loss_d

    if cfg.use_maskiou:
        # discard_mask_area discarded every mask in the batch, so nothing to do here
        if len(maskiou_t_list) == 0:
            return losses, None

        maskiou_t = jt.contrib.concat(maskiou_t_list)
        label_t = jt.contrib.concat(label_t_list)
        maskiou_net_input = jt.contrib.concat(maskiou_net_input_list)

        num_samples = maskiou_t.shape[0]
        if cfg.maskious_to_train > 0 and num_samples > cfg.maskious_to_train:
            perm = jt.randperm(num_samples)
            # Cap at maskious_to_train, the same limit tested above.
            select = perm[:cfg.maskious_to_train]
            maskiou_t = maskiou_t[select]
            label_t = label_t[select]
            maskiou_net_input = maskiou_net_input[select]

        return losses, [maskiou_net_input, maskiou_t, label_t]

    return losses