def create_regression_target(bbox: Tensor, stride: int, size_target: Tuple[int, int]) -> Tensor:
    r"""Given a set of anchor boxes, creates regression targets for each anchor box.

    A coordinate grid is built for the requested feature map size, replicated once per box,
    and the absolute difference between every grid coordinate and the box coordinates is
    returned. Each location in the result therefore gives the distance from that location to
    the left, top, right, and bottom edges of the ground truth box (in that order).

    Args:
        bbox (:class:`torch.Tensor`):
            Ground truth anchor boxes in form :math:`x_1, y_1, x_2, y_2`.

        stride (int):
            Stride at the FPN level for which the target is being created

        size_target (tuple of int, int):
            Height and width of the grid, forwarded to ``FCOSLoss.coordinate_grid``.

    Shapes:
        * ``bbox`` - :math:`(N, 4)`
        * Output - presumably :math:`(N, 4, H, W)`; this assumes
          ``FCOSLoss.coordinate_grid`` returns a :math:`(2, H, W)` grid (TODO confirm)
    """
    check_is_tensor(bbox, "bbox")
    check_dimension(bbox, -1, 4, "bbox")

    box_count = bbox.shape[-2]
    rows, cols = size_target[0], size_target[1]

    # grid of coordinates, tiled twice along the channel dim so it lines up with
    # the four box coordinates, and once per box along a new leading dim
    coords = FCOSLoss.coordinate_grid(rows, cols, stride, indexing="xy", device=bbox.device)
    coords = coords.unsqueeze_(0)
    targets = coords.repeat(box_count, 2, 1, 1)

    # absolute distance from each grid location to each box edge
    edges = bbox[..., None, None]
    targets.sub_(edges)
    targets.abs_()
    return targets
def update(self, pred: torch.Tensor, target: torch.Tensor):
    """
    Update state with predictions and targets.

    The last dimension of ``pred`` must have size 6 and the last dimension of ``target``
    must have size 5. ``pred[..., -2]`` is used as the confidence score and the final
    column of both tensors is used as the class id.

    Args:
        pred: Predictions from model
        target: Ground truth values
    """
    check_is_tensor(pred, "pred")
    check_is_tensor(target, "target")
    check_ndim_match(pred, target, "pred", "target")
    check_dimension(pred, -1, 6, "pred")
    check_dimension(target, -1, 5, "target")

    # restrict the number of predicted boxes to the top K highest confidence boxes.
    # topk returns indices ordered by descending score; gather then selects the
    # corresponding rows along the box dimension for batched or unbatched input.
    limit = self.pred_box_limit
    if limit is not None and pred.shape[-2] > limit:
        top_indices = pred[..., -2].topk(limit, dim=-1).indices
        gather_index = top_indices.unsqueeze(-1).expand(*top_indices.shape, pred.shape[-1])
        pred = pred.gather(-2, gather_index)

    # restrict pred and target to class of interest
    if self.pos_label is not None:
        pred_keep = pred[..., -1] == self.pos_label
        pred = pred[pred_keep]
        target_keep = target[..., -1] == self.pos_label
        target = target[target_keep]

    pred_score, target_class, binary_target = self.get_pred_target_pairs(pred, target)

    # accumulate onto the running state
    self.pred_score = torch.cat([self.pred_score, pred_score])
    self.target_class = torch.cat([self.target_class, target_class])
    self.binary_target = torch.cat([self.binary_target, binary_target])
def combine_box_target(bbox: Tensor, label: Tensor, *extra_labels) -> Tensor:
    r"""Combines bounding box coordinates and labels into a single target tensor.

    The inputs are concatenated along the last dimension in the order ``bbox``, ``label``,
    followed by any ``extra_labels`` in the order given.

    Args:
        bbox (:class:`torch.Tensor`):
            Coordinates of the bounding box.

        label (:class:`torch.Tensor`):
            Label associated with each bounding box.

        *extra_labels (:class:`torch.Tensor`):
            Additional label tensors to append after ``label``. These are not validated here.

    Raises:
        ValueError: If ``bbox`` does not have 4 entries in its last dimension, or if the
            leading dimensions of ``bbox`` and ``label`` differ.

    Shape:
        * ``bbox`` - :math:`(*, N, 4)`
        * ``label`` - :math:`(*, N, 1)`
        * Output - :math:`(*, N, 4 + 1)` when no extra labels are given
    """
    check_is_tensor(bbox, "bbox")
    check_is_tensor(label, "label")

    coord_dim = bbox.shape[-1]
    if coord_dim != 4:
        raise ValueError(f"Expected bbox.shape[-1] == 4, found shape {bbox.shape}")

    leading_dims_agree = bbox.shape[:-1] == label.shape[:-1]
    if not leading_dims_agree:
        raise ValueError(
            f"Expected bbox.shape[:-1] == label.shape[:-1], found shapes {bbox.shape}, {label.shape}"
        )

    pieces = [bbox, label]
    pieces.extend(extra_labels)
    return torch.cat(pieces, dim=-1)
def append_bbox_label(old_label: Tensor, new_label: Tensor) -> Tensor:
    r"""Adds a new label element to an existing bounding box target. The new label will be concatenated
    to the end of the last dimension in ``old_label``.

    Args:
        old_label (:class:`torch.Tensor`):
            The existing bounding box label

        new_label (:class:`torch.Tensor`):
            The label entry to add to ``old_label``

    Raises:
        ValueError: If the two labels differ in any dimension other than the last.

    Shape:
        * ``old_label`` - :math:`(*, N, C_0)`
        * ``new_label`` - :math:`(*, N, C_1)`
        * Output - :math:`(*, N, C_0 + C_1)`
    """
    check_is_tensor(old_label, "old_label")
    check_is_tensor(new_label, "new_label")
    check_ndim_match(old_label, new_label, "old_label", "new_label")

    if old_label.shape[:-1] != new_label.shape[:-1]:
        # both halves of the message are f-strings so the shapes are actually reported
        raise ValueError(
            "expected old_label.shape[:-1] == new_label.shape[:-1], "
            f"found {old_label.shape}, {new_label.shape}"
        )

    return torch.cat([old_label, new_label], dim=-1)
def get_global_pred_target_pairs(pred: Tensor, target: Tensor, pad_value: float = -1) -> Tensor:
    r"""Given predicted CenterNet heatmap and target bounding box label, create a pairing of
    per-class global heatmap maxima to binary labels indicating whether or not the class was
    present in the true label.

    Batched inputs are processed one example at a time through
    ``CenterNetMixin.get_global_pred_target_pairs`` and the per-example results are stacked.

    Args:
        pred (:class:`torch.Tensor`):
            Predicted heatmap.

        target (:class:`torch.Tensor`):
            Target bounding boxes in format ``x1, y1, x2, y2, class``.

        pad_value (float):
            Value used for padding a batched ``target`` input. Boxes whose every entry equals
            ``pad_value`` are ignored.

    Returns:
        Tensor pairing a predicted probability with a binary indicator

    Shape:
        * ``pred`` - :math:`(*, C+4, H, W)`
        * ``target`` - :math:`(*, N_i, 5)`
        * Output - :math:`(*, C, 2)`
    """
    check_is_tensor(pred, "pred")
    check_is_tensor(target, "target")
    is_batch = pred.ndim > 3
    assert pred.shape[-3] > 4

    if is_batch:
        assert target.ndim > 2, "pred batched but target not batched"
        batch_result = [
            CenterNetMixin.get_global_pred_target_pairs(pred_i, target_i, pad_value)
            for pred_i, target_i in zip(pred, target)
        ]
        return torch.stack(batch_result, dim=0)

    # we might be operating on a batched example that was padded, so remove these padded
    # locations. The caller-supplied pad_value is honoured rather than assuming -1.
    pad_locations = (target == pad_value).all(dim=-1)
    target = target[~pad_locations, :]

    # get the global max probability for each class in the heatmap
    num_classes = pred.shape[-3] - 4
    max_pred_scores = CenterNetMixin.heatmap_max_score(pred)

    # get boolean mask of which classes were present in the target
    target_class_present = torch.zeros(num_classes, device=target.device).bool()
    target_class_present[target[..., -1].unique().long()] = True

    assert max_pred_scores.shape == target_class_present.shape
    return torch.stack([max_pred_scores, target_class_present.type_as(max_pred_scores)], dim=-1)
def complete_iou_loss(inputs: Tensor, targets: Tensor, reduction: str = "mean") -> Tensor:
    r"""Computes the Complete IoU (CIoU) loss between predicted and target boxes.

    The per-box loss is :math:`1 - IoU + d^2 / c^2 + \alpha v`, where :math:`d` is the distance
    between box centers, :math:`c` is the diagonal of the smallest box enclosing both boxes,
    :math:`v` measures aspect ratio consistency and :math:`\alpha` is a tradeoff weight.

    Args:
        inputs (:class:`torch.Tensor`):
            Predicted boxes in form :math:`x_1, y_1, x_2, y_2`.

        targets (:class:`torch.Tensor`):
            Target boxes in form :math:`x_1, y_1, x_2, y_2`. Must match the shape of ``inputs``.

        reduction (str):
            One of ``"mean"``, ``"sum"`` or ``"none"``.

    Raises:
        ValueError: If ``reduction`` is not one of the supported values.

    Shape:
        * ``inputs`` - :math:`(*, 4)`
        * ``targets`` - :math:`(*, 4)`
        * Output - scalar for ``"mean"`` / ``"sum"``, otherwise :math:`(*)`
    """
    # validation
    check_is_tensor(inputs, "inputs")
    check_is_tensor(targets, "targets")
    check_dimension(inputs, -1, 4, "inputs")
    check_dimension(targets, -1, 4, "targets")
    check_shapes_match(inputs, targets, "inputs", "targets")

    inputs = inputs.float()
    targets = targets.float()

    # compute euclidean distance between pred and true box centers
    pred_size = inputs[..., 2:] - inputs[..., :2]
    target_size = targets[..., 2:] - targets[..., :2]
    pred_center = pred_size.div(2).add(inputs[..., :2])
    target_center = target_size.div(2).add(targets[..., :2])
    euclidean_dist_squared = (pred_center - target_center).pow(2).sum(dim=-1)

    # compute c, the diagonal length of smallest box enclosing pred and true
    min_coords = torch.min(inputs[..., :2], targets[..., :2])
    max_coords = torch.max(inputs[..., 2:], targets[..., 2:])
    c_squared = (max_coords - min_coords).pow(2).sum(dim=-1)

    # compute the normalized center distance term. The diagonal is clamped so that two
    # coincident zero-size boxes (e.g. padding) do not produce 0 / 0 = NaN here.
    diou = euclidean_dist_squared / c_squared.clamp_min(1e-9)

    # compute vanilla IoU
    pred_area = pred_size[..., 0] * pred_size[..., 1]
    target_area = target_size[..., 0] * target_size[..., 1]
    lt = torch.max(inputs[..., :2], targets[..., :2])
    rb = torch.min(inputs[..., 2:], targets[..., 2:])
    wh = (rb - lt).clamp(min=0)
    inter = wh[..., 0] * wh[..., 1]
    iou = inter / (pred_area + target_area - inter).clamp_min(1e-9)

    # compute v, which measures aspect ratio consistency
    # NOTE: a box with both zero width and zero height still yields NaN in this term
    pred_w, pred_h = pred_size[..., 0], pred_size[..., 1]
    target_w, target_h = target_size[..., 0], target_size[..., 1]
    angle_diff = torch.atan(target_w / target_h) - torch.atan(pred_w / pred_h)
    v = 4 / pi ** 2 * angle_diff.pow(2)

    # compute alpha, the tradeoff parameter
    alpha = v / ((1 - iou) + v).clamp_min(1e-5)

    # compute the final ciou loss
    loss = 1 - iou + diou + alpha * v

    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    if reduction == "none":
        return loss
    raise ValueError(f"Unknown reduction {reduction}")
def filter_heatmap_classes(heatmap: Tensor, keep_classes: Iterable[int], return_inverse: bool = False, with_regression: bool = False) -> Tensor:
    r"""Filters a CenterNet heatmap based on class, dropping class channels that do not meet
    the criteria. Retained channels are always returned in ascending class id order.

    Args:
        heatmap (:class:`torch.Tensor`):
            Heatmap to filter

        keep_classes (iterable of ints):
            Integer id of the classes to keep

        return_inverse (bool):
            If ``True``, invert the selection: keep only the classes that are *not* listed
            in ``keep_classes``

        with_regression (bool):
            If ``True``, expect :math:`C+4` channels in ``heatmap``. The final four channels
            are passed through unchanged.

    Raises:
        TypeError: If ``keep_classes`` is not iterable.
        ValueError: If ``keep_classes`` is empty.

    Shape:
        * ``heatmap`` - :math:`(*, C, H, W)` or :math:`(*, C+4, H, W)`, where :math:`C` is the
          number of classes
        * Output - :math:`(*, C', H, W)` or :math:`(*, C'+4, H, W)`
    """
    check_is_tensor(heatmap, "heatmap")
    if not isinstance(keep_classes, Iterable):
        raise TypeError(f"Expected iterable for keep_classes, found {type(keep_classes)}")
    if not keep_classes:
        raise ValueError(f"Expected non-empty iterable for keep classes, found {keep_classes}")

    if with_regression:
        num_classes = heatmap.shape[-3] - 4
    else:
        num_classes = heatmap.shape[-3]
    assert num_classes > 0

    selected = set(keep_classes)
    if return_inverse:
        selected = set(range(num_classes)) - selected

    # set iteration order is arbitrary (e.g. {1, 8} iterates as 8, 1), so sort explicitly
    # to keep the surviving channels in their original relative order
    channel_ids = torch.tensor(sorted(selected), dtype=torch.long, device=heatmap.device)

    # indexing with a tensor copies, so the result never aliases the input
    keep_heatmap = heatmap[..., channel_ids, :, :]

    if with_regression:
        return torch.cat([keep_heatmap, heatmap[..., -4:, :, :]], dim=-3)
    return keep_heatmap
def split_bbox_scores_class(
    target: Tensor, split_scores: Union[bool, Iterable[int]] = False
) -> Tuple[Tensor, ...]:
    r"""Split a predicted bounding box into box coordinates, probability score, and predicted class.
    This implementation supports multiple score assignments for each box. It is expected that ``target``
    be ordered along the last dimension as ``bbox``, ``scores``, ``class``.

    .. note::
        This operation returns views of the original tensor.

    Args:
        target (:class:`torch.Tensor`):
            The target to split.

        split_scores (bool or iterable of ints):
            Whether to further decompose the scores tensor. If ``split_scores`` is ``True``, split
            the scores tensor into single-column pieces along the last dimension. If an iterable of
            ints is given, treat each int as the width of one consecutive piece along the last
            dimension. Any iterable (including a generator) is accepted. Unlike
            :func:`torch.split`, the widths are not required to sum to :math:`S`.

    Shape:
        * ``target`` - :math:`(*, N, 4 + S + 1)` where :math:`N` is the number of boxes and :math:`S` is the
          number of scores associated with each box.
        * Output - :math:`(*, N, 4)`, :math:`(*, N, S)`, and :math:`(*, N, 1)`
    """
    check_is_tensor(target, "target")
    bbox = target[..., :4]
    scores = target[..., 4:-1]
    cls = target[..., -1:]

    if isinstance(split_scores, bool):
        if not split_scores:
            return bbox, scores, cls
        # a bare True means one piece per score column
        sizes = [1] * scores.shape[-1]
    else:
        # materialize once so one-shot iterables are supported
        sizes = list(split_scores)

    pieces = []
    start = 0
    for size in sizes:
        stop = start + size
        pieces.append(scores[..., start:stop])
        start = stop

    return (bbox, *pieces, cls)
def filter_bbox_classes(target: Tensor, keep_classes: Iterable[int], pad_value: float = -1, return_inverse: bool = False) -> Tensor:
    r"""Filters bounding boxes based on class, replacing bounding boxes that do not meet the criteria
    with padding. Integer class ids should be the last column in ``target``.

    The input is not modified; a filtered copy is returned.

    Args:
        target (:class:`torch.Tensor`):
            Bounding boxes to filter

        keep_classes (iterable of ints):
            Integer id of the classes to keep

        pad_value (float):
            Value written over every entry of a box that is filtered out

        return_inverse (bool):
            If ``True``, invert the selection: boxes whose class *is* in ``keep_classes`` are
            replaced with padding and all other boxes are left untouched

    Raises:
        TypeError: If ``keep_classes`` is not iterable or contains a non-numeric element.
        ValueError: If ``keep_classes`` is empty.

    Shape:
        * ``target`` - :math:`(*, N, C)`
        * Output - same as ``target``
    """
    check_is_tensor(target, "target")
    if not isinstance(keep_classes, Iterable):
        raise TypeError(f"Expected iterable for keep_classes, found {type(keep_classes)}")
    if not keep_classes:
        raise ValueError(f"Expected non-empty iterable for keep classes, found {keep_classes}")

    class_column = target[..., -1]
    locations_to_keep = torch.zeros_like(class_column).bool()
    for keep_cls in keep_classes:
        if not isinstance(keep_cls, (float, int)):
            raise TypeError(f"Expected int or float for keep_classes elements, found {type(keep_cls)}")
        locations_to_keep.logical_or_(class_column == keep_cls)

    if return_inverse:
        locations_to_keep.logical_not_()

    # write padding into a copy, honouring the caller-supplied pad_value
    target = target.clone()
    target[~locations_to_keep] = pad_value
    return target
def heatmap_max_score(heatmap: Tensor) -> Tensor:
    r"""Finds, for every class channel of a heatmap, the largest score over all spatial
    locations. The final four channels of the input are excluded from the result.

    Args:
        heatmap (:class:`torch.Tensor`):
            CenterNet heatmap

    Shape:
        * ``heatmap`` - :math:`(*, C+4, H, W)`
        * Output - :math:`(*, C)`
    """
    check_is_tensor(heatmap, "heatmap")

    # discard the trailing four channels, leaving only class channels
    class_maps = heatmap[..., :-4, :, :]

    # collapse H and W into a single axis and reduce over it
    leading_dims = class_maps.shape[:-2]
    flattened = class_maps.view(*leading_dims, -1)
    best_scores, _ = flattened.max(dim=-1)
    return best_scores
def compute_centerness_targets(reg_targets: Tensor) -> Tensor:
    r"""Computes centerness targets given regression targets.

    Under FCOS, a target regression map is created for each FPN level. Any map location
    that lies within a ground truth bounding box is assigned a regression target based on
    the left, right, top, and bottom distance from that location to the edges of the ground
    truth box.

    .. image:: ./fcos_target.png
        :width: 200px
        :align: center
        :height: 600px
        :alt: FCOS Centerness Target

    For each of these locations with regression targets :math:`l^*, r^*, t^*, b^*`,
    a "centerness" target is created as follows:

    .. math::
        centerness = \sqrt{\frac{\min(l^*, r^*)}{\max(l^*, r^*)} \times \frac{\min(t^*, b^*)}{\max(t^*, b^*)}}

    In this implementation each minimum is clamped to be at least 0 and each maximum is
    clamped to be at least 1 before dividing.

    Args:
        reg_targets (:class:`torch.Tensor`):
            Ground truth regression featuremap. Entries 0 and 2 of the last dimension are
            treated as the left/right pair and entries 1 and 3 as the top/bottom pair.

    Shapes:
        * ``reg_targets`` - :math:`(..., 4)`
        * Output - :math:`(..., 1)`
    """
    check_is_tensor(reg_targets, "reg_targets")
    check_dimension(reg_targets, -1, 4, "reg_targets")

    def _min_over_max(pair: Tensor) -> Tensor:
        # ratio of the smaller to the larger distance, with clamped operands
        smallest = pair.amin(dim=-1).clamp_min_(0)
        largest = pair.amax(dim=-1).clamp_min_(1)
        return smallest.true_divide_(largest)

    horizontal = reg_targets[..., (0, 2)].float()
    vertical = reg_targets[..., (1, 3)].float()

    ratio_lr = _min_over_max(horizontal)
    ratio_tb = _min_over_max(vertical)
    centerness = ratio_lr.mul_(ratio_tb).sqrt_().unsqueeze_(-1)

    # sanity checks on the output shape
    assert centerness.shape[:-1] == reg_targets.shape[:-1]
    assert centerness.shape[-1] == 1
    assert centerness.ndim == reg_targets.ndim
    return centerness
def split_box_target(
    target: Tensor, split_label: Union[bool, Iterable[int]] = False
) -> Tuple[Tensor, ...]:
    r"""Split a bounding box label set into box coordinates and label tensors.

    .. note::
        This operation returns views of the original tensor.

    Args:
        target (:class:`torch.Tensor`):
            The target to split.

        split_label (bool or iterable of ints):
            Whether to further decompose the label tensor. If ``split_label`` is ``True``, split
            the label tensor into single-column pieces along the last dimension. If an iterable of
            ints is given, treat each int as the width of one consecutive piece along the last
            dimension. Any iterable (including a generator) is accepted. Unlike
            :func:`torch.split`, the widths are not required to sum to :math:`C`.

    Shape:
        * ``target`` - :math:`(*, N, 4 + C)` where :math:`N` is the number of boxes and :math:`C` is the
          number of labels associated with each box.
        * Output - :math:`(*, N, 4)` and :math:`(*, N, C)`
    """
    check_is_tensor(target, "target")
    bbox = target[..., :4]
    label = target[..., 4:]

    if isinstance(split_label, bool):
        if not split_label:
            return bbox, label
        # a bare True means one piece per label column
        sizes = [1] * label.shape[-1]
    else:
        # materialize once so one-shot iterables are supported
        sizes = list(split_label)

    pieces = []
    start = 0
    for size in sizes:
        stop = start + size
        pieces.append(label[..., start:stop])
        start = stop

    return (bbox, *pieces)
def split_point_target(target: Tensor) -> Tuple[Tensor, Tensor]:
    r"""Separates a CenterNet target into its heatmap channels and its regression channels.
    The last four channels are the regression component; everything before them is the heatmap.

    .. note::
        This operation returns views of the original tensor.

    Args:
        target (:class:`torch.Tensor`):
            The target to split.

    Shape:
        * ``target`` - :math:`(*, C + 4, H, W)` where :math:`C` is the number of classes.
        * Output - :math:`(*, C, H, W)` and :math:`(*, 4, H, W)`
    """
    check_is_tensor(target, "target")
    return target[..., :-4, :, :], target[..., -4:, :, :]
def mask_to_polygon(mask: Tensor, num_classes: int) -> Tuple[Tuple[Tensor, ...], ...]:
    r"""Converts a batch of integer class masks into polygon contours using OpenCV.

    For every example in the batch and every class id in ``range(num_classes)``, the external
    contours of the region where the mask equals that class id are extracted.

    Args:
        mask (:class:`torch.Tensor`):
            Mask where each location holds a class id. May live on any device.

        num_classes (int):
            Number of class ids to extract contours for.

    Returns:
        Flat tuple with one entry per (example, class) pair, ordered example-major. Each entry
        is a tuple of CPU long tensors, one per contour found by ``cv2.findContours``.

    Shape:
        * ``mask`` - :math:`(B, 1, H, W)`
        * Output - tuple of length :math:`B \times` ``num_classes``
    """
    check_is_tensor(mask, "mask")
    assert mask.ndim == 4
    assert num_classes > 0

    batch_size, _, height, width = mask.shape

    # OpenCV needs host memory; detach/cpu make this safe for CUDA or grad-tracking
    # inputs, and reshape (unlike view) tolerates non-contiguous masks
    mask = mask.detach().reshape(batch_size, height, width).cpu()

    result = []
    for elem in mask:
        for cls in range(num_classes):
            cls_mask = (elem == cls).byte().numpy()
            found = cv2.findContours(cls_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)  # type: ignore
            # OpenCV 3.x returns (image, contours, hierarchy) while 4.x returns
            # (contours, hierarchy); contours are second from the end in both
            contours = found[-2]
            result.append(tuple(torch.from_numpy(x).long() for x in contours))
    return tuple(result)
def flatten_box_target(target: Tensor, pad_value: float = -1) -> Tensor:
    r"""Collapses a padded batch of bounding box targets into a single flat tensor of boxes,
    discarding every box whose entries are all equal to ``pad_value``.

    Args:
        target (:class:`torch.Tensor`):
            Batch of bounding box targets to flatten

        pad_value (float):
            Value used for padding when creating the batch

    Shape:
        * ``target`` - :math:`(B, N, 4 + C)` where :math:`N` is the number of boxes and :math:`C` is the
          number of labels associated with each box.
        * Output - :math:`(N_{tot}, 4 + C)`
    """
    check_is_tensor(target, "target")

    # a box is real if at least one of its entries differs from the pad value
    is_real_box = (target != pad_value).any(dim=-1)
    real_box_coords = torch.nonzero(is_real_box, as_tuple=True)
    return target[real_box_coords]
def append_heatmap_label(old_label: Tensor, new_label: Tensor) -> Tensor:
    r"""Inserts new heatmap channels into an existing CenterNet target. The new channels are
    placed after the existing heatmap channels and immediately before the final four
    (regression) channels of ``old_label``.

    Args:
        old_label (:class:`torch.Tensor`):
            The existing heatmap label

        new_label (:class:`torch.Tensor`):
            The heatmap channel to add to ``old_label``

    Shape:
        * ``old_label`` - :math:`(*, C_0 + 4, H, W)`
        * ``new_label`` - :math:`(*, C_1, H, W)`
        * Output - :math:`(*, C_0 + C_1 + 4, H, W)`
    """
    check_is_tensor(old_label, "old_label")
    check_is_tensor(new_label, "new_label")
    check_ndim_match(old_label, new_label, "old_label", "new_label")

    existing_heatmap = old_label[..., :-4, :, :]
    existing_regression = old_label[..., -4:, :, :]
    ordered_parts = [existing_heatmap, new_label, existing_regression]
    return torch.cat(ordered_parts, dim=-3)
def combine_regression(offset: Tensor, size: Tensor) -> Tensor:
    r"""Joins CenterNet offset and size maps into one regression tensor by concatenating
    them along the channel dimension, offset first.

    Args:
        offset (:class:`torch.Tensor`):
            Offset component of the heatmap

        size (:class:`torch.Tensor`):
            Size component of the heatmap

    Returns:
        Single tensor holding the offset channels followed by the size channels

    Shape:
        * ``offset`` - :math:`(*, 2, H, W)`
        * ``size`` - :math:`(*, 2, H, W)`
        * Output - :math:`(*, 4, H, W)`
    """
    check_is_tensor(offset, "offset")
    check_is_tensor(size, "size")
    components = (offset, size)
    return torch.cat(components, dim=-3)
def combine_bbox_scores_class(bbox: Tensor, cls: Tensor, scores: Tensor, *extra_scores) -> Tensor:
    r"""Merges box coordinates, one or more score tensors, and a class tensor into a single
    prediction tensor. The result is ordered along the last dimension as ``bbox``, ``scores``,
    any ``extra_scores``, then ``cls``.

    Args:
        bbox (:class:`torch.Tensor`):
            Coordinates of the bounding box.

        cls (:class:`torch.Tensor`):
            Class associated with each bounding box

        scores (:class:`torch.Tensor`):
            Probability associated with each bounding box.

        *extra_scores (:class:`torch.Tensor`):
            Additional scores to combine

    Raises:
        ValueError: If ``bbox`` does not have 4 entries in its last dimension.

    Shape:
        * ``bbox`` - :math:`(*, N, 4)`
        * ``scores`` - :math:`(*, N, S)`
        * ``cls`` - :math:`(*, N, 1)`
        * Output - :math:`(*, N, 4 + S + 1)`
    """
    check_is_tensor(bbox, "bbox")
    check_is_tensor(scores, "scores")
    check_is_tensor(cls, "cls")

    coord_dim = bbox.shape[-1]
    if coord_dim != 4:
        raise ValueError(f"Expected bbox.shape[-1] == 4, found shape {bbox.shape}")

    # class goes last, after every score tensor
    pieces = [bbox, scores]
    pieces.extend(extra_scores)
    pieces.append(cls)
    return torch.cat(pieces, dim=-1)
def combine_point_target(heatmap: Tensor, regression: Tensor) -> Tensor:
    r"""Builds a single CenterNet label by concatenating the heatmap channels with the four
    regression channels along the channel dimension.

    Args:
        heatmap (:class:`torch.Tensor`):
            The CenterNet heatmap.

        regression (:class:`torch.Tensor`):
            The CenterNet regression map.

    Raises:
        ValueError: If ``regression`` does not have exactly 4 channels.

    Shape:
        * ``heatmap`` - :math:`(*, C, H, W)`
        * ``regression`` - :math:`(*, 4, H, W)`
        * Output - :math:`(*, C+4, H, W)`
    """
    check_is_tensor(heatmap, "heatmap")
    check_is_tensor(regression, "regression")

    regression_channels = regression.shape[-3]
    if regression_channels != 4:
        raise ValueError(f"Expected regression.shape[-3] == 4, found shape {regression.shape}")

    parts = (heatmap, regression)
    return torch.cat(parts, dim=-3)
def split_regression(regression: Tensor) -> Tuple[Tensor, Tensor]:
    r"""Separates a CenterNet regression prediction into its offset and size components.
    The first two channels are the offset and the remaining channels are the size.

    .. note::
        This operation returns views of the original tensor.

    Args:
        regression (:class:`torch.Tensor`):
            The target to split.

    Returns:
        Tuple of offset and size tensors

    Shape:
        * ``regression`` - :math:`(*, 4, H, W)`
        * Output - :math:`(*, 2, H, W)` and :math:`(*, 2, H, W)`
    """
    check_is_tensor(regression, "regression")

    offset, size = regression[..., :2, :, :], regression[..., 2:, :, :]

    # each half must carry exactly two channels
    for component in (offset, size):
        assert component.shape[-3] == 2
    return offset, size
def unbatch_box_target(target: Tensor, pad_value: float = -1) -> List[Tensor]:
    r"""Splits a padded batch of bounding box target tensors into a list of unpadded target tensors.

    The returned list always has one entry per batch element, in batch order. A batch element
    that consists entirely of padding yields an empty :math:`(0, 4 + C)` tensor rather than being
    dropped, so that the output stays aligned with the batch.

    Args:
        target (:class:`torch.Tensor`):
            Batch of bounding box targets to split

        pad_value (float):
            Value used for padding when creating the batch

    Shape:
        * ``target`` - :math:`(B, N, 4 + C)` where :math:`N` is the number of boxes and :math:`C` is the
          number of labels associated with each box.
        * Output - list of :math:`B` tensors, each :math:`(N_i, 4 + C)`
    """
    check_is_tensor(target, "target")

    # a box is padding only if every one of its entries equals pad_value
    keep = ~(target == pad_value).all(dim=-1)

    # boolean indexing preserves row-major order, so boxes stay grouped by batch element
    flat_result = target[keep]

    # count the surviving boxes per batch element, including elements with zero boxes
    per_example = keep.flatten(start_dim=1) if keep.ndim > 1 else keep.unsqueeze(-1)
    split_sizes = per_example.sum(dim=1).tolist()

    return list(torch.split(flat_result, split_sizes, dim=0))
def create_classification_target(
    bbox: Tensor,
    cls: Tensor,
    mask: Tensor,
    num_classes: int,
    size_target: Tuple[int, int],
) -> Tensor:
    r"""Creates a one-hot style classification target map from per-box positive location masks.

    Every location that is nonzero in the mask of box ``i`` is set to ``1.0`` in the output
    channel given by the class of box ``i``. All other locations are ``0.0``.

    Args:
        bbox (:class:`torch.Tensor`):
            Bounding boxes. Only used for shape validation.

        cls (:class:`torch.Tensor`):
            Class id for each box. Converted to integer indices before use.

        mask (:class:`torch.Tensor`):
            Per-box mask of positive locations.

        num_classes (int):
            Number of channels in the output target.

        size_target (tuple of int, int):
            Unused by this function; the output height and width are taken from ``mask``.

    Shape:
        * ``bbox`` - :math:`(N, 4)`
        * ``cls`` - :math:`(N, 1)`
        * ``mask`` - :math:`(N, H, W)`
        * Output - :math:`(C, H, W)` where :math:`C` is ``num_classes``
    """
    check_is_tensor(bbox, "bbox")
    check_is_tensor(cls, "cls")
    check_is_tensor(mask, "mask")
    check_dimension_match(bbox, cls, -2, "bbox", "cls")
    check_dimension_match(bbox, mask, 0, "bbox", "mask")
    check_dimension(bbox, -1, 4, "bbox")
    check_dimension(cls, -1, 1, "cls")

    target = torch.zeros(num_classes, *mask.shape[-2:], device=mask.device, dtype=torch.float)

    # coordinates of every positive location, along with the box that owns it
    box_id, h, w = mask.nonzero(as_tuple=True)

    # index tensors must be integral; class ids may arrive as floats
    class_id = cls[box_id, 0].long()
    target[class_id, h, w] = 1.0
    return target
def bbox_to_mask(bbox: Tensor, stride: int, size_target: Tuple[int, int], center_radius: Optional[float] = None) -> Tensor:
    r"""Creates a mask for each input anchor box indicating which heatmap locations for that
    box should be positive examples.

    Under FCOS, target maps are created for each FPN level. Any map location that lies
    within ``center_radius * stride`` units from the center of the ground truth bounding box is
    considered a positive example for regression and classification. This method creates a mask
    for the FPN level with stride ``stride``. The mask will have shape :math:`(N, H, W)` where
    :math:`(H, W)` are given in ``size_target``. Mask locations that lie within
    ``center_radius * stride`` units of the box center (and inside the box) will be ``True``.
    If ``center_radius=None``, all locations within a box will be considered positive.

    Args:
        bbox (:class:`torch.Tensor`):
            Ground truth anchor boxes in form :math:`x_1, y_1, x_2, y_2`.

        stride (int):
            Stride at the FPN level for which the target is being created

        size_target (tuple of int, int):
            Height and width of the mask. Should match the height and width of the FPN
            level for which a target is being created.

        center_radius (float, optional):
            Radius (in units of ``stride``) about the center of each box for which examples
            should be considered positive. If ``center_radius=None``, all locations within
            a box will be considered positive.

    Shapes:
        * ``bbox`` - :math:`(N, 4)`
        * Output - :math:`(N, H, W)`
    """
    check_is_tensor(bbox, "bbox")
    check_dimension(bbox, -1, 4, "bbox")

    # grid holding the (h, w) coordinate of the center of every map location, in input
    # units; one copy of the grid per box
    box_count = bbox.shape[-2]
    rows = torch.arange(size_target[0], dtype=torch.float, device=bbox.device)
    cols = torch.arange(size_target[1], dtype=torch.float, device=bbox.device)
    centers = torch.stack(torch.meshgrid(rows, cols), 0)
    centers.mul_(stride).add_(stride / 2)
    centers = centers.unsqueeze_(0).expand(box_count, -1, -1, -1)

    # positive region starts as the whole box
    region_min = bbox[..., :2]
    region_max = bbox[..., 2:]

    if center_radius is not None:
        assert center_radius >= 1

        # shrink the region to the window around the box center, never exceeding the box
        midpoint = (bbox[..., :2] + bbox[..., 2:]).true_divide(2)
        reach = midpoint.new_tensor([stride, stride]).mul_(center_radius)
        region_min = torch.max(region_min, midpoint - reach[None])
        region_max = torch.min(region_max, midpoint + reach[None])

    # reorder (x, y) to (h, w) and add trailing spatial dims for broadcasting
    region_min = region_min[..., (1, 0), None, None]
    region_max = region_max[..., (1, 0), None, None]

    # a location is positive when both of its coordinates fall inside the region
    past_min = centers >= region_min
    before_max = centers <= region_max
    return past_min.logical_and_(before_max).all(dim=-3)
def visualize_heatmap(
    heatmap: Tensor,
    background: Optional[Tensor] = None,
    cmap: str = "gnuplot",
    same_on_batch: bool = True,
    heatmap_alpha: float = 0.5,
    background_alpha: float = 0.5,
) -> List[ByteTensor]:
    r"""Generates visualizations of a CenterNet heatmap. Can optionally overlay the
    heatmap on top of a background image.

    Args:
        heatmap (:class:`torch.Tensor`):
            The heatmap to visualize

        background (:class:`torch.Tensor`):
            An optional background image for the heatmap visualization. A single channel
            background is repeated to three channels. When given, each heatmap channel is
            resized to the background's height and width before blending.

        cmap (str):
            Matplotlib colormap

        same_on_batch (bool):
            See :func:`combustion.vision.to_8bit`

        heatmap_alpha (float):
            See :func:`combustion.util.alpha_blend`

        background_alpha (float):
            See :func:`combustion.util.alpha_blend`

    Returns:
        List of tensors, where each tensor is a heatmap visualization for one class in the heatmap

    Shape:
        * ``heatmap`` - :math:`(N, C, H, W)` where :math:`C` is the number of classes in the heatmap.
        * Output - :math:`(N, 3, H, W)`
    """
    check_is_tensor(heatmap, "heatmap")
    if background is not None:
        check_is_tensor(background, "background")
        # need background to be float [0, 1] for alpha blend w/ heatmap
        background = to_8bit(background, same_on_batch=same_on_batch).float().div_(255).cpu()

        # expand a grayscale background to three channels
        if background.shape[-3] == 1:
            repetitions = [1] * background.ndim
            repetitions[-3] = 3
            background = background.repeat(*repetitions)

    num_channels = heatmap.shape[-3]

    result = []
    for channel_idx in range(num_channels):
        # slice (rather than index) so the channel dimension is retained
        channel = heatmap[..., channel_idx : channel_idx + 1, :, :]
        channel = to_8bit(channel, same_on_batch=same_on_batch)

        # output is float from [0, 1]
        heatmap_channel = apply_colormap(channel.cpu(), cmap=cmap)

        # drop alpha channel
        heatmap_channel = heatmap_channel[..., :3, :, :]

        # alpha blend w/ background
        if background is not None:
            heatmap_channel = F.interpolate(
                heatmap_channel, size=background.shape[-2:], mode="bilinear", align_corners=True
            )
            heatmap_channel = alpha_blend(heatmap_channel, background, heatmap_alpha, background_alpha)[0]

        heatmap_channel = heatmap_channel.mul_(255).byte()
        result.append(heatmap_channel)

    return result