def encode_ground_truth(self, ground_truth, anchors):
    """
    Args:
        ground_truth: list(:len Batch) of torch.tensor(:shape [Boxes_i, 5])
        anchors: torch.tensor(:shape [AnchorBoxes, 4])
    Returns:
        target: torch.tensor(:shape [Batch, AnchorBoxes, TARGET_SIZE]
            ~ {loc, class, score} along the last dimension)
    """
    batch_size = len(ground_truth)
    num_anchors = anchors.size(0)
    device = anchors.device

    ground_truth = [x.to(device, non_blocking=True) for x in ground_truth]
    corner_anchors = box_utils.to_corners(anchors)

    target = torch.zeros((batch_size, num_anchors, TARGET_SIZE),
                         dtype=torch.float32, device=device)
    # Every anchor starts out as a negative with score 1.
    target[..., CLASS_INDEX] = torch.full_like(target[..., SCORE_INDEX],
                                               NEGATIVE_CLASS)
    target[..., SCORE_INDEX] = torch.ones_like(target[..., SCORE_INDEX])

    for i, gt in enumerate(ground_truth):
        if not len(gt):
            continue
        gt_boxes = gt[:, det_ds.LOC_INDEX_START:det_ds.LOC_INDEX_END]
        weights = box_utils.iou(gt_boxes, corner_anchors)
        box_idx = matcher.match_per_prediction(weights,
                                               self.matched_threshold,
                                               self.unmatched_threshold)
        matched = box_idx.ne(matcher.NOT_MATCHED) & box_idx.ne(matcher.IGNORE)
        target[i, matched, LOC_INDEX_START:LOC_INDEX_END] = gt[
            box_idx[matched], det_ds.LOC_INDEX_START:det_ds.LOC_INDEX_END]
        target[i, matched, CLASS_INDEX] = gt[box_idx[matched],
                                             det_ds.CLASS_INDEX]
        target[i, matched, SCORE_INDEX] = gt[box_idx[matched],
                                             det_ds.SCORE_INDEX]
        # Anchors the matcher marked as ignored are excluded from the loss.
        ignored = box_idx.eq(matcher.IGNORE)
        target[i, ignored, CLASS_INDEX] = IGNORE_CLASS
        target[i, ignored, SCORE_INDEX] = IGNORE_CLASS

    positive = (target[..., CLASS_INDEX].ne(NEGATIVE_CLASS)
                & target[..., CLASS_INDEX].ne(IGNORE_CLASS))
    assert not torch.isnan(
        target[..., LOC_INDEX_START:LOC_INDEX_END][positive]).any().item()
    return target
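# The target encoders and the loss in this section rely on module-level
# layout constants defined elsewhere in the repository. Below is a
# hypothetical sketch of values consistent with the [Batch, AnchorBoxes, 6]
# target used above and below ({[0-3] - box, [4] - class, [5] - score});
# the actual definitions may differ.
LOC_INDEX_START = 0
LOC_INDEX_END = 4
CLASS_INDEX = 4
SCORE_INDEX = 5
TARGET_SIZE = 6     # 4 box coordinates + class + score
NEGATIVE_CLASS = 0  # background / unmatched anchors
IGNORE_CLASS = -1   # anchors excluded from the loss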
def encode_ground_truth(self, ground_truth, priors):
    """
    Args:
        ground_truth: list(:len Batch) of torch.tensor(:shape [Boxes_i, 5])
        priors: torch.tensor(:shape [AnchorBoxes, 4])
    Returns:
        target_classes: torch.tensor(:shape [Batch, AnchorBoxes])
        target_locs: torch.tensor(:shape [Batch, AnchorBoxes, 4])
    """
    batch_size = len(ground_truth)
    num_priors = priors.size(0)
    device = priors.device

    ground_truth = [x.to(device, non_blocking=True) for x in ground_truth]
    corner_priors = box_utils.to_corners(priors)

    target_classes = torch.zeros((batch_size, num_priors),
                                 dtype=torch.long, device=device)
    target_locs = torch.zeros((batch_size, num_priors, 4),
                              dtype=torch.float32, device=device)

    for i, gt in enumerate(ground_truth):
        if len(gt) == 0:
            continue
        weights = box_utils.jaccard(gt[:, :4], corner_priors)
        box_idx = match_per_prediction(
            weights,
            matched_threshold=self.matched_threshold,
            unmatched_threshold=self.unmatched_threshold)
        matched = box_idx >= 0
        target_classes[i, matched] = gt[box_idx[matched], 4].long()
        target_locs[i, matched] = gt[box_idx[matched], :4]
        # A match index of -1 marks priors to be ignored by the loss.
        ignored = box_idx == -1
        target_classes[i, ignored] = -1

    target_locs = box_utils.to_centroids(target_locs)
    target_locs = self.box_coder.encode_box(target_locs, priors,
                                            inplace=False)

    assert not torch.isnan(target_locs[target_classes.gt(0)]).any().item()
    return target_classes, target_locs
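# A minimal sketch of what `match_per_prediction` plausibly does: a
# per-anchor argmax over the IoU matrix with two thresholds, as in standard
# SSD-style matching. The sentinel values and the omission of a forced best
# match per ground-truth box are assumptions, not the repository's actual
# matcher.
import torch

NOT_MATCHED = -2
IGNORE = -1


def match_per_prediction_sketch(weights, matched_threshold,
                                unmatched_threshold):
    """weights: IoU matrix of shape [GtBoxes, AnchorBoxes]."""
    # Best ground-truth box for every anchor.
    best_iou, box_idx = weights.max(dim=0)
    # Anchors below the lower threshold become plain negatives.
    box_idx[best_iou < unmatched_threshold] = NOT_MATCHED
    # Anchors between the thresholds are neither positive nor negative.
    between = ((best_iou >= unmatched_threshold)
               & (best_iou < matched_threshold))
    box_idx[between] = IGNORE
    return box_idx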
def postprocess(self, prediction, priors):
    """
    Args:
        prediction: tuple of
            torch.tensor(:shape [Batch, AnchorBoxes * Classes])
            torch.tensor(:shape [Batch, AnchorBoxes * 4])
        priors: torch.tensor(:shape [AnchorBoxes, 4])
    Returns:
        processed: list(:len Batch) of torch.tensor(:shape [Boxes_i, 6]
            ~ {[0-3] - box, [4] - class, [5] - score})
    """
    b_scores, b_boxes = prediction
    batch_size = b_scores.size(0)
    num_priors = priors.size(0)

    b_scores = b_scores.view(batch_size, num_priors, -1)
    b_scores = self.score_converter_fn(b_scores)
    num_classes = b_scores.size(-1)
    if self.score_converter != 'SIGMOID':
        # Softmax scores include an explicit background class at index 0.
        num_classes = num_classes - 1
        b_scores = b_scores[..., 1:]
    b_scores = b_scores.cpu()

    b_boxes = b_boxes.view(batch_size, num_priors, 4)
    b_boxes = b_boxes.to(priors.device)
    b_boxes = self.box_coder.decode_box(b_boxes, priors, inplace=False)
    b_boxes = box_utils.to_corners(b_boxes)
    b_boxes = b_boxes.cpu()

    processed = []
    for scores, boxes in zip(b_scores, b_boxes):
        picked = []
        for class_index in range(num_classes):
            class_scores = scores[:, class_index]
            mask = class_scores > self.score_threshold
            (boxes_picked, scores_picked), _ = self.nms(boxes[mask],
                                                        class_scores[mask])
            scores_picked = scores_picked.unsqueeze(1)
            classes_picked = torch.full_like(scores_picked, class_index + 1)
            picked.append(
                torch.cat([boxes_picked, classes_picked, scores_picked],
                          dim=-1))
        picked = torch.cat(picked, dim=0)
        # Keep only the highest-scoring detections if a global cap is set.
        if self.max_total is not None and self.max_total < picked.size(0):
            _, indexes = torch.topk(picked[:, 5], self.max_total,
                                    sorted=True, largest=True)
            picked = picked[indexes]
        processed.append(picked)
    return processed
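# A minimal sketch of the per-class NMS call used above, matching the
# `(boxes_picked, scores_picked), keep = self.nms(boxes, scores)` call shape.
# It delegates to torchvision's greedy IoU suppression; the thresholds, the
# torchvision dependency and the `max_per_class` cap are assumptions, not
# the repository's actual implementation.
import torch
from torchvision.ops import nms as tv_nms


def nms_sketch(boxes, scores, iou_threshold=0.5, max_per_class=100):
    """boxes: [N, 4] corner boxes, scores: [N] per-class scores."""
    if boxes.numel() == 0:
        empty_idx = torch.zeros(0, dtype=torch.long)
        return (boxes.new_zeros((0, 4)), scores.new_zeros(0)), empty_idx
    keep = tv_nms(boxes, scores, iou_threshold)[:max_per_class]
    return (boxes[keep], scores[keep]), keep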
def forward(self, pred, anchors, target):
    """
    Args:
        pred: tuple of
            torch.tensor(:shape [Batch, AnchorBoxes * Classes])
            torch.tensor(:shape [Batch, AnchorBoxes * 4])
        anchors: torch.tensor(:shape [AnchorBoxes, 4])
        target: torch.tensor(:shape [Batch, AnchorBoxes, 6])
    Returns:
        losses: tuple of (total loss, classification loss, localization loss)
    """
    scores, locs = pred
    target_locs = target[..., LOC_INDEX_START:LOC_INDEX_END]
    target_classes = target[..., CLASS_INDEX].long()
    target_scores = target[..., SCORE_INDEX]

    batch_size = target.size(0)
    num_priors = target.size(1)
    scores = scores.view(batch_size, num_priors, -1)
    locs = locs.view(batch_size, num_priors, 4)

    positive_mask = (target_classes.ne(NEGATIVE_CLASS)
                     & target_classes.ne(IGNORE_CLASS))

    # Subsample anchors (e.g. hard negative mining) for the classification loss.
    sampled_mask = self.sampler(scores, target_classes)
    scores = scores[sampled_mask]
    target_classes = target_classes[sampled_mask]
    target_scores = target_scores[sampled_mask]

    if self.multiclass:
        # One-vs-all targets: positives get their ground-truth score in the
        # column of their class (classes are 1-based, columns 0-based).
        class_target = torch.zeros_like(scores)
        mask = (target_classes.ne(NEGATIVE_CLASS)
                & target_classes.ne(IGNORE_CLASS))
        class_target[mask, target_classes[mask] - 1] = target_scores[mask]
    elif self.soft_target:
        class_target = torch.zeros_like(scores)
        mask = target_classes.ne(IGNORE_CLASS)
        class_target[mask, target_classes[mask]] = target_scores[mask]
    else:
        class_target = target_classes.view(-1)
    class_loss = self.classification_loss(scores, class_target)

    if self.iou_loss:
        # IoU-based losses compare decoded corner boxes directly.
        locs = self.box_coder.decode_box(locs, anchors)
        locs = box_utils.to_corners(locs)
    else:
        # Otherwise regress encoded offsets relative to the anchors.
        box_utils.to_centroids(target_locs, inplace=True)
        self.box_coder.encode_box(target_locs, anchors, inplace=True)
    positive_locs = locs[positive_mask].view(-1, 4)
    positive_target_locs = target_locs[positive_mask].view(-1, 4)
    loc_loss = self.localization_loss(positive_locs, positive_target_locs)

    # Normalize both terms by the number of positive anchors.
    divider = positive_mask.sum().clamp_(min=1).float()
    loc_loss.mul_(self.localization_weight).div_(divider)
    class_loss.mul_(self.classification_weight).div_(divider)
    loss = class_loss + loc_loss
    return loss, class_loss, loc_loss
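# A minimal sketch of an SSD-style box coder consistent with the
# `encode_box` / `decode_box` calls in this section, assuming centroid-form
# boxes ([cx, cy, w, h]) and the common (0.1, 0.2) variances. The repository's
# actual BoxCoder may use different variances, clamping or in-place behaviour,
# so treat this purely as an illustration.
import torch


class BoxCoderSketch:
    def __init__(self, center_variance=0.1, size_variance=0.2):
        self.center_variance = center_variance
        self.size_variance = size_variance

    def encode_box(self, boxes, priors, inplace=False):
        """Boxes and priors in centroid form; returns regression offsets."""
        out = boxes if inplace else boxes.clone()
        out[..., :2] = (boxes[..., :2] - priors[..., :2]) / (
            priors[..., 2:] * self.center_variance)
        out[..., 2:] = (torch.log(boxes[..., 2:] / priors[..., 2:])
                        / self.size_variance)
        return out

    def decode_box(self, locs, priors, inplace=False):
        """Inverse of encode_box; returns centroid-form boxes."""
        out = locs if inplace else locs.clone()
        out[..., :2] = (locs[..., :2] * self.center_variance
                        * priors[..., 2:] + priors[..., :2])
        out[..., 2:] = (torch.exp(locs[..., 2:] * self.size_variance)
                        * priors[..., 2:])
        return out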