def filter_cls_data(yx_min, yx_max, mask):
    """Keep only the rows of ``yx_min`` / ``yx_max`` selected by ``mask``.

    Args:
        yx_min: tensor whose selected elements are regrouped into rows of 2
            (presumably the min corner of each box -- confirm with caller).
        yx_max: tensor handled exactly like ``yx_min`` (presumably the max
            corner of each box -- confirm with caller).
        mask: 1-D mask tensor with one entry per row.

    Returns:
        Tuple ``(yx_min, yx_max)`` of tensors viewed as ``(-1, 2)``. When
        ``mask`` has no elements at all, two fresh ``(0, 2)`` zero tensors
        passed through ``utils.ensure_device`` are returned instead.
    """
    if mask.numel() == 0:
        # all bboxes are difficult
        empty_min = utils.ensure_device(torch.zeros(0, 2))
        empty_max = utils.ensure_device(torch.zeros(0, 2))
        return empty_min, empty_max

    # Expand the row mask to one flag per coordinate (the original author
    # marks this as a workaround for a PyTorch bug), then regroup the
    # selected elements into 2-column rows.
    elem_mask = mask.unsqueeze(-1).repeat(1, 2)
    kept_min = yx_min[elem_mask].view(-1, 2)
    kept_max = yx_max[elem_mask].view(-1, 2)
    return kept_min, kept_max
def filter_visible(self, yx_min, yx_max, iou, prob, cls):
    """Drop detections whose score does not exceed the configured threshold.

    The score is ``iou`` compared against the ``threshold`` option of the
    ``detect`` config section. If that option is absent
    (``configparser.NoOptionError``), the score is ``prob`` compared against
    the ``threshold_cls`` option instead.

    Args:
        yx_min: tensor whose kept elements are regrouped into rows of 2.
        yx_max: tensor handled exactly like ``yx_min``.
        iou: 1-D score tensor used when ``threshold`` is configured.
        prob: 1-D score tensor used as the fallback.
        cls: per-row tensor filtered with the same mask.

    Returns:
        Tuple ``(yx_min, yx_max, cls, score)`` restricted to rows whose score
        is strictly greater than the threshold; the first two are viewed as
        ``(-1, 2)`` and the last two as 1-D.
    """
    try:
        limit = self.config.getfloat('detect', 'threshold')
        score = iou
    except configparser.NoOptionError:
        limit = self.config.getfloat('detect', 'threshold_cls')
        score = prob
    keep = score > limit

    # One flag per coordinate (marked by the original author as a workaround
    # for a PyTorch bug).
    keep_coords = keep.unsqueeze(-1).repeat(1, 2)
    kept_min = yx_min[keep_coords].view(-1, 2)
    kept_max = yx_max[keep_coords].view(-1, 2)
    kept_cls = cls[keep].view(-1)
    kept_score = score[keep].view(-1)
    return kept_min, kept_max, kept_cls, kept_score
def filter_cls_pred(yx_min, yx_max, score, mask):
    """Select the rows of ``yx_min`` / ``yx_max`` / ``score`` flagged by ``mask``.

    Args:
        yx_min: tensor whose kept elements are regrouped into rows of 2.
        yx_max: tensor handled exactly like ``yx_min``.
        score: tensor indexed directly with ``mask``.
        mask: 1-D mask tensor with one entry per row.

    Returns:
        Tuple ``(yx_min, yx_max, score)``; the first two are viewed as
        ``(-1, 2)`` and ``score`` is the result of ``score[mask]``.
    """
    # Per-coordinate copy of the row mask (marked by the original author as
    # a workaround for a PyTorch bug).
    coord_mask = mask.unsqueeze(-1).repeat(1, 2)
    picked_min = yx_min[coord_mask].view(-1, 2)
    picked_max = yx_max[coord_mask].view(-1, 2)
    picked_score = score[mask]
    return picked_min, picked_max, picked_score
def filter_valid(yx_min, yx_max, cls, difficult):
    """Keep rows where ``yx_min < yx_max`` in every column and ``difficult < 1``.

    Args:
        yx_min: ``(N, 2)`` tensor (presumably the min corner of each box --
            confirm with caller).
        yx_max: ``(N, 2)`` tensor compared element-wise against ``yx_min``.
        cls: tensor with one entry per row, filtered with the same mask.
        difficult: tensor with one entry per row; rows with a value of 1 or
            more are dropped.

    Returns:
        Tuple ``(yx_min, yx_max, cls)`` restricted to the kept rows. The
        first two keep their 2-column layout (``(0, 2)`` when nothing is
        kept).
    """
    # torch.all keeps the mask boolean. torch.prod on a bool tensor returns
    # int64, and indexing with an int64 tensor gathers rows 0/1 by position
    # instead of masking, silently returning the wrong rows.
    well_formed = torch.all(yx_min < yx_max, -1)
    mask = well_formed & (difficult < 1)
    # A 1-D boolean mask selects whole rows, so the 2-column shape survives.
    cls = cls[mask]
    yx_min, yx_max = (t[mask] for t in (yx_min, yx_max))
    return yx_min, yx_max, cls
def filter_valid(yx_min, yx_max, cls, difficult):
    """Keep rows where ``yx_min < yx_max`` in every column and ``difficult < 1``.

    Args:
        yx_min: ``(N, 2)`` tensor (presumably the min corner of each box --
            confirm with caller).
        yx_max: ``(N, 2)`` tensor compared element-wise against ``yx_min``.
        cls: tensor with one entry per row, filtered with the same mask.
        difficult: tensor with one entry per row; rows with a value of 1 or
            more are dropped.

    Returns:
        Tuple ``(yx_min, yx_max, cls)`` restricted to the kept rows. The
        first two keep their 2-column layout (``(0, 2)`` when nothing is
        kept).
    """
    # torch.all keeps the mask boolean. torch.prod on a bool tensor returns
    # int64, and indexing with an int64 tensor gathers rows 0/1 by position
    # instead of masking, silently returning the wrong rows.
    well_formed = torch.all(yx_min < yx_max, -1)
    mask = well_formed & (difficult < 1)
    # A 1-D boolean mask selects whole rows, so the 2-column shape survives.
    cls = cls[mask]
    yx_min, yx_max = (t[mask] for t in (yx_min, yx_max))
    return yx_min, yx_max, cls