def decode(keypoint, reg, positions, stride, cfg, prob_thresh=0.05):
    """Decode keypoint and size maps into one BoxList per image.

    A location is kept when its keypoint value exceeds ``prob_thresh`` for at
    least one label and, if ``cfg.model.centernet.max_pool_nms`` is set, it is
    also the maximum of its 3x3 neighbourhood in that label's map. Each kept
    location yields one box centred on the matching entry of ``positions``
    with the size read from ``reg``, labelled with the highest-scoring label.

    Args:
        keypoint: tensor<n, num_labels, h, w> of keypoint scores
            (presumably probabilities, since they are compared directly
            against ``prob_thresh`` -- confirm with the caller).
        reg: tensor<n, 2, h, w> holding the box size at each location.
        positions: tensor<2, h, w> holding the centre coordinates of each
            location. The coordinate order is whatever the caller uses; it
            must match ``reg``.
        stride: factor used to scale ``h`` and ``w`` up to the image height
            and width that boxes are clamped to.
        cfg: config object; only ``cfg.model.centernet.max_pool_nms`` is read.
        prob_thresh: (float) minimum keypoint score for a detection.

    Returns:
        list of n BoxLists with ``labels`` and ``scores`` fields.
    """
    flat_positions = positions.permute((1, 2, 0)).reshape((-1, 2))
    img_height = positions.shape[1] * stride
    img_width = positions.shape[2] * stride

    boxlists = []
    for n in range(keypoint.shape[0]):
        per_keypoint = keypoint[n]
        per_reg = reg[n]
        num_labels = per_keypoint.shape[0]

        is_pos = per_keypoint > prob_thresh
        if cfg.model.centernet.max_pool_nms:
            # A location survives only if it equals the max of its 3x3
            # window, i.e. it is a local peak in that label's map.
            pooled = F.max_pool2d(
                per_keypoint, kernel_size=3, stride=1, padding=1)
            is_pos = is_pos & (per_keypoint == pooled)

        # A location is positive if any label fired there. Flattening the
        # (h, w) mask row-major matches the layout of the flattened
        # positions, sizes and scores below.
        flat_is_pos = is_pos.any(dim=0).reshape(-1)

        if not flat_is_pos.any():
            # Build the empty result on the same device and with the same
            # dtypes as the non-empty branch so callers can mix both.
            bl = BoxList(
                per_reg.new_empty((0, 4)),
                labels=torch.empty(
                    (0,), dtype=torch.long, device=per_keypoint.device),
                scores=per_keypoint.new_empty((0,)))
        else:
            sizes = per_reg.permute((1, 2, 0)).reshape((-1, 2))[flat_is_pos]
            centers = flat_positions[flat_is_pos]
            boxes = torch.cat(
                [centers - sizes / 2, centers + sizes / 2], dim=1)

            flat_scores = per_keypoint.permute((1, 2, 0)).reshape(
                (-1, num_labels))[flat_is_pos]
            scores, labels = flat_scores.max(1)

            bl = BoxList(boxes, labels=labels, scores=scores)
            # clamp returns the clamped BoxList, so keep its result; calling
            # it without rebinding would leave the boxes unclamped.
            bl = bl.clamp(img_height, img_width)
        boxlists.append(bl)
    return boxlists
def decode_batch_output(output, pyramid_shape, img_height, img_width,
                        iou_thresh=0.5):
    """Turn raw multi-level head output into one BoxList per image.

    For every image in the batch, the per-level outputs are sliced out, the
    label and centerness logits are converted to probabilities with a
    sigmoid, and the result is decoded by ``decode_single_output``. The
    decoded scores are then multiplied by centerness, the boxes are clamped
    to the image bounds, and NMS is applied.

    Args:
        output: list with one tuple per pyramid level, each of the form
            (reg_arr, label_arr, center_arr) where
            - reg_arr is tensor<n, 4, h, w>,
            - label_arr is tensor<n, num_labels, h, w> of logits,
            - center_arr is tensor<n, 1, h, w> of logits.
        pyramid_shape: passed through unchanged to ``decode_single_output``.
        img_height: height that boxes are clamped to.
        img_width: width that boxes are clamped to.
        iou_thresh: (float) iou threshold passed to NMS.

    Returns:
        list of n BoxLists
    """
    num_images = output[0][0].shape[0]
    results = []
    for img_ind in range(num_images):
        # Slice out this image at every level; logits become probabilities.
        per_level = [
            (reg_arr[img_ind],
             torch.sigmoid(label_arr[img_ind]),
             torch.sigmoid(center_arr[img_ind]))
            for reg_arr, label_arr, center_arr in output
        ]
        decoded = decode_single_output(per_level, pyramid_shape)

        labels = decoded.get_field('labels')
        scores = decoded.get_field('scores')
        centerness = decoded.get_field('centerness')
        # Down-weight each score by its centerness before running NMS.
        rescored = BoxList(
            decoded.boxes,
            labels=labels,
            scores=scores * centerness,
            centerness=centerness)

        clamped = rescored.clamp(img_height, img_width)
        results.append(clamped.nms(iou_thresh=iou_thresh))
    return results