def crop_mask(self, mask, boxes, h, w):
    """
    Resize ``mask`` to (h, w) and resample every box crop to 28 x 28.

    Arguments:
        mask (Tensor): 4-D map accepted by ``interpolate``. Only element
            [0, 0] (first batch entry, first channel) of each resampled
            crop is kept.
        boxes (Tensor): one row of four coordinates (x1, y1, x2, y2) per
            box, truncated to integers. The x pair slices dim 2 of the
            resized mask and the y pair slices dim 3.
        h, w (int): size the mask is resized to before cropping.

    Returns:
        Tensor of shape (len(boxes), 1, 28, 28) on ``boxes.device`` with the
        dtype of the resized mask. Boxes that are empty after clamping keep
        an all-zero entry.
    """
    out_size = (28, 28)
    mask = interpolate(mask, size=(h, w), mode='bilinear',
                       align_corners=False)
    prob = torch.zeros((len(boxes), 1) + out_size, device=boxes.device,
                       dtype=mask.dtype)
    for i, box in enumerate(boxes):
        x1, y1, x2, y2 = box.int().tolist()
        # Clamp to the resized map: a negative start would otherwise wrap
        # around to the far edge when used as a slice bound.
        x1, y1 = max(x1, 0), max(y1, 0)
        x2, y2 = min(x2, h), min(y2, w)
        if x1 >= x2 or y1 >= y2:
            # An empty crop cannot be resampled; leave the zeros in place.
            continue
        crop = mask[:, :, x1:x2, y1:y2]
        prob[i, 0] = interpolate(crop, size=out_size, mode='bilinear',
                                 align_corners=False)[0, 0]
    return prob
def compute_mask(self, mask, boxes, h, w):
    """
    Read one M x M mask per box from the channel vector at the box centre.

    ``mask`` is resized to (h, w) and moved to channels-last layout. For
    every box the channel vector at the centre of the clamped box is
    reshaped into an M x M grid, where M = int(sqrt(number of channels)).

    Arguments:
        mask (Tensor): 4-D map accepted by ``interpolate``. The final
            reshape only succeeds when the batch dimension is 1.
        boxes (Tensor): one row of four coordinates (x1, y1, x2, y2) per
            box, truncated to integers. x indexes dim 1 and y indexes dim 2
            of the channels-last map.
        h, w (int): size the mask is resized to.

    Returns:
        Tensor of shape (len(boxes), 1, M, M) on ``boxes.device``. Boxes
        that are empty after clamping keep an all-zero entry.
    """
    resized = interpolate(mask, size=(h, w), mode='bilinear',
                          align_corners=False)
    # Channels last, so a single spatial index yields a whole channel vector.
    channels_last = resized.permute(0, 2, 3, 1)
    side = int(math.sqrt(channels_last.shape[-1]))
    prob = torch.zeros((len(boxes), 1, side, side), device=boxes.device,
                       dtype=channels_last.dtype)
    for idx, box in enumerate(boxes):
        lo_x, lo_y, hi_x, hi_y = box.int().tolist()
        lo_x, lo_y = max(0, lo_x), max(0, lo_y)
        hi_x, hi_y = min(hi_x, h - 1), min(hi_y, w - 1)
        if lo_x < hi_x and lo_y < hi_y:
            centre_x = (lo_x + hi_x) // 2
            centre_y = (lo_y + hi_y) // 2
            prob[idx, 0] = channels_last[:, centre_x, centre_y, :].reshape(
                side, side)
    return prob
def forward(self, locations, box_cls, box_regression, centerness,
            proposal_embed, proposal_margin, pixel_embed, image_sizes,
            targets):
    """
    Post-process the per-level head outputs and attach mask predictions.

    Arguments:
        locations: per-level locations, iterated together with the three
            per-level outputs below
        box_cls: list[tensor]
        box_regression: list[tensor]
        centerness: list[tensor]
        proposal_embed: indexed by level
        proposal_margin: indexed by level; replaced by a constant tensor of
            ``self.init_margin`` when ``self.fix_margin`` is set
        pixel_embed: 4-D tensor, upsampled by ``self.mask_scale_factor``
            before the mask step
        image_sizes: list[(h, w)]
        targets: not used by this method

    Returns:
        boxlists (list[BoxList]): the result of ``self.forward_for_mask``
            applied to the detections merged over all levels
    """
    per_level = []
    levels = zip(locations, box_cls, box_regression, centerness)
    for level, (loc, cls_out, reg_out, ctr_out) in enumerate(levels):
        embed = proposal_embed[level]
        margin = proposal_margin[level]
        if self.fix_margin:
            margin = torch.ones_like(margin) * self.init_margin
        per_level.append(
            self.forward_for_single_feature_map(
                loc, cls_out, reg_out, ctr_out, embed, margin,
                image_sizes, level
            )
        )

    # Regroup from "per level -> per image" to "per image -> per level"
    # and merge the levels of each image.
    boxlists = [cat_boxlist(per_image) for per_image in zip(*per_level)]
    boxlists = self.select_over_all_levels(boxlists)

    # resize pixel embedding for higher resolution
    _, _, m_h, m_w = pixel_embed.shape
    target_size = (m_h * self.mask_scale_factor,
                   m_w * self.mask_scale_factor)
    pixel_embed = interpolate(pixel_embed, size=target_size,
                              mode='bilinear', align_corners=False)

    return self.forward_for_mask(boxlists, pixel_embed)
def prepare_masks(self, o_h, o_w, r_h, r_w, targets_masks):
    """
    Pad each image's instance masks to (r_h, r_w) and resize to (o_h, o_w).

    Arguments:
        o_h, o_w (int): output mask size.
        r_h, r_w (int): canvas size the masks are zero-padded to (at the
            bottom/right) before resizing.
        targets_masks (list[Tensor]): one (n, h, w) tensor per image; an
            entry of length 0 means the image has no instances.

    Returns:
        list[Tensor]: per image, the bilinearly resized masks thresholded
        with ``> 0``, of shape (n, o_h, o_w); an empty 1-D tensor for images
        without instances.
    """
    resized_per_image = []
    for image_masks in targets_masks:
        if len(image_masks) == 0:
            resized_per_image.append(image_masks.new_tensor([]))
            continue

        count, height, width = image_masks.shape
        canvas = image_masks.new_zeros((count, r_h, r_w))
        canvas[:, :height, :width] = image_masks

        # Instances act as channels of a single-element batch.
        batch = canvas.float().unsqueeze(0)
        resized = interpolate(
            input=batch,
            size=(o_h, o_w),
            mode="bilinear",
            align_corners=False,
        )
        resized_per_image.append(resized[0].gt(0))
    return resized_per_image
def compute_single_instance_mask(self, masks):
    """
    Merge per-instance masks into one index map plus per-instance templates.

    Instances are ordered by decreasing mask sum. In the index map the
    largest instance keeps its own values (1 for a binary mask) and the
    k-th instance in that order overwrites its pixels with k, so smaller
    instances win where masks overlap; 0 stays background.

    Arguments:
        masks (Tensor): (n, H, W) instance masks with n >= 1. The arithmetic
            assumes binary values.

    Returns:
        index_map (Tensor): (1, H, W) map of instance indices.
        templates (Tensor): (n + 1, size * size) with
            size = int(sqrt(self.box_mask_pw_channels)). Row 0 is all zero
            (background); row k is the k-th instance cropped to its tight
            bounding box, resized to size x size and thresholded with
            ``> 0``. An instance with no foreground pixel gets an all-zero
            row instead of raising.
    """
    instances = torch.split(masks, [1] * len(masks), dim=0)
    instances = sorted(instances, key=lambda inst: inst.sum(), reverse=True)

    index_map = instances[0]
    for offset, inst in enumerate(instances[1:]):
        index_map = index_map * (1 - inst) + inst * (offset + 2)

    size = int(math.sqrt(self.box_mask_pw_channels))
    templates = []
    for inst in instances:
        loc = torch.nonzero(inst)
        if loc.numel() == 0:
            # No foreground pixel means no bounding box to crop. Build the
            # zero row through a comparison so it has the same dtype as the
            # thresholded rows below.
            templates.append(
                torch.zeros((1, size * size), device=inst.device) > 0)
            continue
        xmin, xmax = int(loc[:, 1].min()), int(loc[:, 1].max()) + 1
        ymin, ymax = int(loc[:, 2].min()), int(loc[:, 2].max()) + 1
        crop = inst[:, xmin:xmax, ymin:ymax].unsqueeze(0).float()
        resized = interpolate(crop, size=(size, size), mode='bilinear',
                              align_corners=False) > 0
        templates.append(resized.squeeze(0).reshape(1, -1))

    # Row 0 is the background template, matching index 0 of the index map.
    templates.insert(0, torch.zeros_like(templates[0]))
    return index_map, torch.cat(templates, dim=0)
def paste_mask_in_image(mask, box, im_h, im_w, thresh=0.5, padding=1):
    """
    Resize a box-level mask to its box and paste it into a full-image canvas.

    Arguments:
        mask (Tensor): mask for one box; it is padded through
            ``expand_masks`` before resizing.
        box (Tensor): box coordinates (x1, y1, x2, y2); scaled through
            ``expand_boxes`` to match the padding, then cast to int32.
        im_h, im_w (int): size of the output canvas.
        thresh (float): when >= 0 the resized mask is binarised with
            ``> thresh``; when negative the values are scaled by 255 and
            stored unthresholded (for visualization and debugging).
        padding (int): padding forwarded to ``expand_masks``.

    Returns:
        Tensor: uint8 canvas of shape (im_h, im_w) holding the part of the
        mask that falls inside the image.
    """
    padded, scale = expand_masks(mask[None], padding=padding)
    box = expand_boxes(box[None], scale)[0].to(dtype=torch.int32)
    x_min, y_min, x_max, y_max = box[:4].tolist()

    # Box ends are inclusive, so a box spanning [a, b] covers b - a + 1
    # pixels; the size is never allowed to drop below one pixel.
    box_w = max(int(x_max - x_min + 1), 1)
    box_h = max(int(y_max - y_min + 1), 1)

    # Set shape to [batchxCxHxW] and resize to the box size.
    batched = padded[0, 0].expand((1, 1, -1, -1)).to(torch.float32)
    resized = interpolate(batched, size=(box_h, box_w), mode='bilinear',
                          align_corners=False)[0][0]

    if thresh < 0:
        # for visualization and debugging, we also
        # allow it to return an unmodified mask
        resized = (resized * 255).to(torch.uint8)
    else:
        resized = resized > thresh

    canvas = torch.zeros((im_h, im_w), dtype=torch.uint8)
    # Part of the box that lies inside the image.
    left = max(x_min, 0)
    right = min(x_max + 1, im_w)
    top = max(y_min, 0)
    bottom = min(y_max + 1, im_h)

    canvas[top:bottom, left:right] = resized[
        (top - y_min):(bottom - y_min), (left - x_min):(right - x_min)
    ]
    return canvas
def compute_targets_for_locations(self, locations, feature_size, targets,
                                  object_sizes_of_interest):
    """
    Assign a label, a box regression target and a mask target per location.

    For every image, a location is a candidate for a ground-truth box when
    it lies strictly inside the box and its largest distance to the box
    sides is within the location's size range. When several boxes qualify,
    the one with the smallest area is chosen. Locations without any
    candidate get label 0 and an all-zero mask target.

    Arguments:
        locations (Tensor): column 0 holds x, column 1 holds y.
        feature_size: iterable of (height, width) sizes; the instance index
            map is resized to each of them and the results are concatenated.
            NOTE(review): the concatenated length is presumably expected to
            equal len(locations) -- confirm against the caller.
        targets (list[BoxList]): ground truth in "xyxy" mode with the
            fields "labels" and "masks".
        object_sizes_of_interest (Tensor): per location, column 0 is the
            lower and column 1 the upper bound on the largest regression
            distance.

    Returns:
        labels, reg_targets, mask_targets: three lists with one tensor per
        image.
    """
    labels, reg_targets, mask_targets = [], [], []
    xs, ys = locations[:, 0], locations[:, 1]
    num_locations = len(locations)
    device = targets[0].bbox.device

    for targets_per_im in targets:
        assert targets_per_im.mode == "xyxy"
        bboxes = targets_per_im.bbox
        labels_per_im = targets_per_im.get_field("labels")
        masks_per_im = targets_per_im.get_field(
            'masks').get_mask_tensor().to(device)
        area = targets_per_im.area()

        # A single instance may come back without the instance dimension.
        if len(masks_per_im.size()) < 3:
            masks_per_im = masks_per_im.unsqueeze(0)
        instance_mask, instances = self.compute_single_instance_mask(
            masks_per_im)

        # Look up, per feature-map size, the instance template selected by
        # the resized index map at every position.
        per_size = []
        for size in feature_size:
            with torch.no_grad():
                index_map = interpolate(
                    instance_mask.unsqueeze(0).float(),
                    size=size, mode='bilinear', align_corners=False
                )
            picked = instances[index_map.squeeze().long()]
            per_size.append(picked.float().reshape(size[0] * size[1], -1))
        masks = torch.cat(per_size, dim=0)

        # Distances from every location to the four sides of every box.
        left = xs[:, None] - bboxes[None, :, 0]
        top = ys[:, None] - bboxes[None, :, 1]
        right = bboxes[None, :, 2] - xs[:, None]
        bottom = bboxes[None, :, 3] - ys[:, None]
        reg_targets_per_im = torch.stack([left, top, right, bottom], dim=2)

        is_in_boxes = reg_targets_per_im.min(dim=2)[0] > 0

        # limit the regression range for each location
        max_reg_targets_per_im = reg_targets_per_im.max(dim=2)[0]
        lower = object_sizes_of_interest[:, [0]]
        upper = object_sizes_of_interest[:, [1]]
        is_cared_in_the_level = \
            (max_reg_targets_per_im >= lower) & \
            (max_reg_targets_per_im <= upper)

        locations_to_gt_area = area[None].repeat(num_locations, 1)
        locations_to_gt_area[is_in_boxes == 0] = INF
        locations_to_gt_area[is_cared_in_the_level == 0] = INF

        # if there are still more than one objects for a location,
        # we choose the one with minimal area
        locations_to_min_area, locations_to_gt_inds = \
            locations_to_gt_area.min(dim=1)
        unmatched = locations_to_min_area == INF

        reg_targets_per_im = reg_targets_per_im[
            range(num_locations), locations_to_gt_inds]
        labels_per_im = labels_per_im[locations_to_gt_inds]
        labels_per_im[unmatched] = 0
        masks[unmatched] = 0

        labels.append(labels_per_im)
        reg_targets.append(reg_targets_per_im)
        mask_targets.append(masks)

    return labels, reg_targets, mask_targets
def __call__(self, locations, box_cls, box_regression, centerness,
             proposal_embed, proposal_margin, pixel_embed, targets):
    """
    Compute the classification, regression, centerness and mask losses.

    Arguments:
        locations (list[BoxList])
        box_cls (list[Tensor])
        box_regression (list[Tensor])
        centerness (list[Tensor])
        proposal_embed: indexed per level
        proposal_margin: indexed per level
        pixel_embed (Tensor): 4-D, first dimension is the batch
        targets (list[BoxList])

    Returns:
        cls_loss (Tensor)
        reg_loss (Tensor)
        centerness_loss (Tensor)
        mask_loss (Tensor)
    """
    num_classes = box_cls[0].size(1)
    # The input size is recovered from level index 4 and its stride.
    coarse_stride = self.fpn_strides[4]
    im_h = box_cls[4].shape[2] * coarse_stride
    im_w = box_cls[4].shape[3] * coarse_stride
    (labels_per_level, reg_targets_per_level, _labels, _reg_targets,
     matched_idxes) = self.prepare_targets(locations, targets, im_w, im_h)

    # Flatten every level to (positions, ...) and concatenate the levels.
    levels = range(len(labels_per_level))
    box_cls_flatten = torch.cat(
        [box_cls[lvl].permute(0, 2, 3, 1).reshape(-1, num_classes)
         for lvl in levels], dim=0)
    box_regression_flatten = torch.cat(
        [box_regression[lvl].permute(0, 2, 3, 1).reshape(-1, 4)
         for lvl in levels], dim=0)
    centerness_flatten = torch.cat(
        [centerness[lvl].reshape(-1) for lvl in levels], dim=0)
    labels_flatten = torch.cat(
        [labels_per_level[lvl].reshape(-1) for lvl in levels], dim=0)
    reg_targets_flatten = torch.cat(
        [reg_targets_per_level[lvl].reshape(-1, 4) for lvl in levels],
        dim=0)

    # Only positive locations contribute to regression and centerness.
    pos_inds = torch.nonzero(labels_flatten > 0).squeeze(1)
    box_regression_flatten = box_regression_flatten[pos_inds]
    reg_targets_flatten = reg_targets_flatten[pos_inds]
    centerness_flatten = centerness_flatten[pos_inds]

    num_gpus = get_num_gpus()
    # sync num_pos from all gpus
    total_num_pos = reduce_sum(
        pos_inds.new_tensor([pos_inds.numel()])).item()
    num_pos_avg_per_gpu = max(total_num_pos / float(num_gpus), 1.0)

    cls_loss = self.cls_loss_func(
        box_cls_flatten, labels_flatten.int()) / num_pos_avg_per_gpu

    if pos_inds.numel() == 0:
        reg_loss = box_regression_flatten.sum()
        # Issue the same reduce_sum call as the branch with positives.
        reduce_sum(centerness_flatten.new_tensor([0.0]))
        centerness_loss = centerness_flatten.sum()
    else:
        centerness_targets = self.compute_centerness_targets(
            reg_targets_flatten)
        # average sum_centerness_targets from all gpus,
        # which is used to normalize centerness-weighed reg loss
        sum_centerness_targets_avg_per_gpu = \
            reduce_sum(centerness_targets.sum()).item() / float(num_gpus)
        reg_loss = self.box_reg_loss_func(
            box_regression_flatten, reg_targets_flatten,
            centerness_targets) / sum_centerness_targets_avg_per_gpu
        centerness_loss = self.centerness_loss_func(
            centerness_flatten, centerness_targets) / num_pos_avg_per_gpu

    # ------------------------ Mask related losses ------------------------
    # get positive proposal labels for each gt instance
    pos_proposal_labels_for_targets = self.get_pos_proposal_indexes(
        locations, box_regression, matched_idxes, targets)

    # get positive samples of embeddings & margins for each gt instance
    proposal_embed_for_targets, valids_for_targets = \
        self.get_proposal_element(
            proposal_embed, pos_proposal_labels_for_targets)
    proposal_margin_for_targets, _ = self.get_proposal_element(
        proposal_margin, pos_proposal_labels_for_targets)

    # Zero-valued terms that keep every mask-branch output connected to
    # the returned loss even when no target contributes below.
    mask_loss = box_cls[0].new_tensor(0.0)
    for level in range(len(proposal_embed)):
        mask_loss += 0 * proposal_embed[level].sum()
        mask_loss += 0 * proposal_margin[level].sum()
    mask_loss += 0 * pixel_embed.sum()

    # get target masks in prefer size
    N, _, m_h, m_w = pixel_embed.shape
    o_h = m_h * self.mask_scale_factor
    o_w = m_w * self.mask_scale_factor
    finest_stride = self.fpn_strides[0]
    r_h = int(m_h * finest_stride)
    r_w = int(m_w * finest_stride)
    # Factor that maps image-space boxes onto the upsampled embedding.
    stride = finest_stride / self.mask_scale_factor

    targets_masks = [
        target_im.get_field('masks').convert('mask').instances.masks.to(
            device=pixel_embed.device)
        for target_im in targets
    ]
    masks_t = self.prepare_masks(o_h, o_w, r_h, r_w, targets_masks)
    pixel_embed = interpolate(input=pixel_embed, size=(o_h, o_w),
                              mode="bilinear", align_corners=False)

    if self.loss_mask_alpha > 0:
        for im in range(N):
            valid = valids_for_targets[im]
            if valid.sum() == 0:
                continue

            proposal_embed_im = proposal_embed_for_targets[im][valid]
            proposal_margin_im = proposal_margin_for_targets[im][valid]
            masks_t_im = masks_t[im][valid]
            boxes_t_im = targets[im].bbox[valid] / stride

            masks_prob = self.compute_mask_prob(
                proposal_embed_im, proposal_margin_im, pixel_embed[im])

            if self.box_padding >= 0:
                masks_prob_crop, crop_mask = crop_by_box(
                    masks_prob, boxes_t_im, self.box_padding)
                mask_loss_per_target = self.mask_loss_func(
                    masks_prob_crop, masks_t_im, mask=crop_mask, act=True)
            else:
                mask_loss_per_target = self.mask_loss_func(
                    masks_prob, masks_t_im, act=True)

            mask_loss += mask_loss_per_target.mean()

        mask_loss = mask_loss / N * self.loss_mask_alpha

    return cls_loss, reg_loss, centerness_loss, mask_loss