def loss_masks(self, outputs, targets, indices, num_boxes):
    """Compute the mask losses: the focal loss and the dice loss.

    Each dict in ``targets`` must contain the key "masks" holding a tensor
    of dim [nb_target_boxes, h, w].

    Returns:
        A dict with two entries: "loss_mask" (result of
        ``sigmoid_focal_loss``) and "loss_dice" (result of ``dice_loss``),
        both evaluated on the flattened matched masks with ``num_boxes``.
    """
    assert "pred_masks" in outputs

    pred_sel = self._get_src_permutation_idx(indices)
    gt_sel = self._get_tgt_permutation_idx(indices)

    # Keep only the predictions picked out by the source permutation index.
    matched_preds = outputs["pred_masks"][pred_sel]

    # Batch the per-image target masks into a single tensor. The second
    # value returned by decompose() is not used by this loss.
    batched_gt, _ = nested_tensor_from_tensor_list(
        [t["masks"] for t in targets]
    ).decompose()
    # Move/cast the targets to match the predictions, then select the
    # targets picked out by the target permutation index.
    matched_gt = batched_gt.to(matched_preds)[gt_sel]

    # upsample predictions to the target size
    upsampled = interpolate(
        matched_preds[:, None],
        size=matched_gt.shape[-2:],
        mode="bilinear",
        align_corners=False,
    )
    flat_preds = upsampled[:, 0].flatten(1)
    flat_gt = matched_gt.flatten(1).view(flat_preds.shape)

    return {
        "loss_mask": sigmoid_focal_loss(flat_preds, flat_gt, num_boxes),
        "loss_dice": dice_loss(flat_preds, flat_gt, num_boxes),
    }
def resize(image, target, size, max_size=None):
    """Resize ``image`` and rescale the annotations in ``target`` to match.

    Parameters:
        image: the image to resize; its ``size`` attribute is read as
            (width, height).
        target: optional dict of annotations. The keys "boxes", "area" and
            "masks" are rescaled when present, and "size" is set to the new
            (height, width). The dict is shallow-copied before being changed.
        size: either a scalar (the desired length of the shorter side) or a
            (w, h) tuple/list giving the exact output size.
        max_size: optional upper bound on the longer side; only used when
            ``size`` is a scalar.

    Returns:
        ``(rescaled_image, target)``; ``target`` is None when no target was
        given.
    """

    def _resolve_hw(orig_wh, requested, limit):
        # An explicit (w, h) pair is simply flipped to (h, w).
        if isinstance(requested, (list, tuple)):
            return requested[::-1]

        width, height = orig_wh
        if limit is not None:
            short_side = float(min((width, height)))
            long_side = float(max((width, height)))
            # Shrink the requested short side if the long side would
            # otherwise exceed the limit.
            if long_side / short_side * requested > limit:
                requested = int(round(limit * short_side / long_side))

        # The shorter side already has the requested length: keep the size.
        already_sized = (width <= height and width == requested) or (
            height <= width and height == requested
        )
        if already_sized:
            return (height, width)

        # Scale the longer side by the same factor as the shorter one.
        if width < height:
            return (int(requested * height / width), requested)
        return (requested, int(requested * width / height))

    new_size = _resolve_hw(image.size, size, max_size)
    resized = F.resize(image, new_size)

    if target is None:
        return resized, None

    # Per-axis scale factors, measured on the images actually produced.
    ratio_width, ratio_height = (
        float(new) / float(old) for new, old in zip(resized.size, image.size)
    )

    target = target.copy()

    if "boxes" in target:
        box_scale = torch.as_tensor(
            [ratio_width, ratio_height, ratio_width, ratio_height]
        )
        target["boxes"] = target["boxes"] * box_scale

    if "area" in target:
        target["area"] = target["area"] * (ratio_width * ratio_height)

    out_h, out_w = new_size
    target["size"] = torch.tensor([out_h, out_w])

    if "masks" in target:
        # Nearest-neighbour resize on a float copy, then threshold back to
        # a boolean mask.
        mask_input = target['masks'][:, None].float()
        target['masks'] = interpolate(mask_input, new_size, mode="nearest")[:, 0] > 0.5

    return resized, target
def forward(self, outputs, processed_sizes, target_sizes=None):  # noqa: C901
    """This function computes the panoptic prediction from the model's predictions.

    Parameters:
        outputs: This is a dict coming directly from the model. See the model doc for the content.
            The keys "pred_logits", "pred_masks" and "pred_boxes" are read.
        processed_sizes: This is a list of tuples (or torch tensors) of sizes of the images that were
            passed to the model, ie the size after data augmentation but before batching.
        target_sizes: This is a list of tuples (or torch tensors) corresponding to the requested final
            size of each prediction. If left to None, it will default to the processed_sizes

    Returns:
        A list with one dict per image, each holding "png_string" (the
        PNG-encoded segmentation image as bytes) and "segments_info" (a list
        of dicts with the keys "id", "isthing", "category_id" and "area").
    """
    if target_sizes is None:
        target_sizes = processed_sizes
    assert len(processed_sizes) == len(target_sizes)
    out_logits, raw_masks, raw_boxes = (
        outputs["pred_logits"],
        outputs["pred_masks"],
        outputs["pred_boxes"],
    )
    assert len(out_logits) == len(raw_masks) == len(target_sizes)
    preds = []

    def to_tuple(tup):
        if isinstance(tup, tuple):
            return tup
        return tuple(tup.cpu().tolist())

    # The last class index marks an empty query; it is the same for every
    # image, so compute it once instead of on each iteration.
    empty_label = out_logits.shape[-1] - 1

    for cur_logits, cur_masks, cur_boxes, size, target_size in zip(
        out_logits, raw_masks, raw_boxes, processed_sizes, target_sizes
    ):
        # we filter empty queries and detection below threshold
        scores, labels = cur_logits.softmax(-1).max(-1)
        keep = labels.ne(empty_label) & (scores > self.threshold)
        # Reuse the scores/labels computed above rather than running the
        # softmax and max a second time.
        cur_scores = scores[keep]
        cur_classes = labels[keep]
        cur_masks = cur_masks[keep]
        cur_masks = interpolate(
            cur_masks[:, None], to_tuple(size), mode="bilinear"
        ).squeeze(1)
        cur_boxes = box_ops.box_cxcywh_to_xyxy(cur_boxes[keep])

        h, w = cur_masks.shape[-2:]
        assert len(cur_boxes) == len(cur_classes)

        # It may be that we have several predicted masks for the same stuff class.
        # In the following, we track the list of masks ids for each stuff class (they are merged later on)
        cur_masks = cur_masks.flatten(1)
        stuff_equiv_classes = defaultdict(list)
        for k, label in enumerate(cur_classes):
            if not self.is_thing_map[label.item()]:
                stuff_equiv_classes[label.item()].append(k)

        def get_ids_area(masks, scores, dedup=False):
            # This helper function creates the final panoptic segmentation image
            # It also returns the area of the masks that appears on the image
            # It reads h, w, target_size and stuff_equiv_classes from the
            # current loop iteration.
            m_id = masks.transpose(0, 1).softmax(-1)

            if m_id.shape[-1] == 0:
                # We didn't detect any mask :(
                m_id = torch.zeros((h, w), dtype=torch.long, device=m_id.device)
            else:
                # Each pixel is assigned to the mask with the highest value.
                m_id = m_id.argmax(-1).view(h, w)

            if dedup:
                # Merge the masks corresponding to the same stuff class
                for equiv in stuff_equiv_classes.values():
                    if len(equiv) > 1:
                        for eq_id in equiv:
                            m_id.masked_fill_(m_id.eq(eq_id), equiv[0])

            final_h, final_w = to_tuple(target_size)

            seg_img = Image.fromarray(id2rgb(m_id.view(h, w).cpu().numpy()))
            seg_img = seg_img.resize(
                size=(final_w, final_h), resample=Image.NEAREST
            )

            # Decode the resized RGB image back into an id map so that the
            # areas are counted at the final resolution.
            np_seg_img = (
                torch.ByteTensor(torch.ByteStorage.from_buffer(seg_img.tobytes()))
                .view(final_h, final_w, 3)
                .numpy()
            )
            m_id = torch.from_numpy(rgb2id(np_seg_img))

            # Number of pixels assigned to each mask id.
            area = [m_id.eq(i).sum().item() for i in range(len(scores))]
            return area, seg_img

        area, seg_img = get_ids_area(cur_masks, cur_scores, dedup=True)
        if cur_classes.numel() > 0:
            # We now filter empty masks as long as we find some
            while True:
                # `area` has one entry per remaining mask; a mask covering
                # at most 4 pixels is dropped.
                filtered_small = torch.as_tensor(
                    [a <= 4 for a in area],
                    dtype=torch.bool,
                    device=keep.device,
                )
                if not filtered_small.any().item():
                    break
                cur_scores = cur_scores[~filtered_small]
                cur_classes = cur_classes[~filtered_small]
                cur_masks = cur_masks[~filtered_small]
                # NOTE(review): the recomputation runs without dedup=True,
                # as in the original code - confirm this is intended.
                area, seg_img = get_ids_area(cur_masks, cur_scores)
        else:
            # Nothing was kept: `area` is empty here, so this placeholder
            # is never read by the loop below.
            cur_classes = torch.ones(1, dtype=torch.long, device=cur_classes.device)

        segments_info = []
        for i, a in enumerate(area):
            cat = cur_classes[i].item()
            segments_info.append(
                {
                    "id": i,
                    "isthing": self.is_thing_map[cat],
                    "category_id": cat,
                    "area": a,
                }
            )

        with io.BytesIO() as out:
            seg_img.save(out, format="PNG")
            predictions = {
                "png_string": out.getvalue(),
                "segments_info": segments_info,
            }
        preds.append(predictions)
    return preds