Example #1
    def forward(self, outputs, batch):
        opt = self.opt
        hm_loss, iou_loss = 0, 0
        # hm_loss, iou_loss = [torch.autograd.Variable(torch.tensor(0.)).cuda() for _ in [1,1]]
        # hm_loss.require_grad = True
        # iou_loss.require_grad = True

        for s in range(opt.num_stacks):
            output = outputs[s]
            output['hm'] = _sigmoid(output['hm'])
            hm_loss += self.crit(output['hm'], batch['hm'])
            WH = _tranpose_and_gather_feat(output['wh'], batch['ind'])
            REG = _tranpose_and_gather_feat(output['reg'], batch['ind'])
            for hm, ct_int, mask, reg, wh, reg_, wh_ in zip(
                    output['hm'], batch['ctr'], batch['reg_mask'],
                    batch['reg'], batch['wh'], REG, WH):
                if mask.sum():
                    iou_loss += self.iou(reg[mask], wh[mask], reg_[mask],
                                         wh_[mask])
                    # ct_int = ct_int[mask]
                    # xs = ct_int[:, 0]
                    # ys = ct_int[:, 1]
                    # self.hau(hm[0], torch.stack([ys, xs], -1))
                    # hm_loss += self.hau(hm[0], torch.stack([ys, xs], -1))
            # breakpoint()
        # hm_loss /= opt.batch_size * opt.num_stacks
        hm_loss /= opt.num_stacks
        iou_loss /= opt.batch_size * opt.num_stacks
        loss = opt.hm_weight * hm_loss + opt.wh_weight * iou_loss
        loss_stats = {
            'loss': loss,
            'hm_loss': hm_loss,
            'iou_loss': iou_loss or torch.zeros_like(loss)
        }
        return loss, loss_stats
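
All of the examples on this page lean on the same helper. For reference, here is a minimal sketch of `_tranpose_and_gather_feat` as it appears in CenterNet-style codebases (the misspelling of "transpose" is part of the actual identifier): it flattens the spatial grid and gathers the feature vector at each object-center index in `ind`.

import torch

def _gather_feat(feat, ind):
    # feat: (B, H*W, C); ind: (B, K) flat spatial indices
    dim = feat.size(2)
    ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim)
    return feat.gather(1, ind)

def _tranpose_and_gather_feat(feat, ind):
    # (B, C, H, W) -> (B, H*W, C), then pick the K indexed cells
    feat = feat.permute(0, 2, 3, 1).contiguous()
    feat = feat.view(feat.size(0), -1, feat.size(3))
    return _gather_feat(feat, ind)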
Example #2
 def forward(self, output, mask, ind, target):
     pred = _tranpose_and_gather_feat(output, ind)
     mask = mask.unsqueeze(2).expand_as(pred).float()
     loss = F.l1_loss(pred * mask,
                      target * mask,
                      reduction='mean')  # 'elementwise_mean' in PyTorch < 1.0
     return loss
Example #3
 def forward(self, output, mask, ind, target):
     pred = _tranpose_and_gather_feat(output, ind)
     mask = mask.float()
     # loss = F.l1_loss(pred * mask, target * mask, reduction='elementwise_mean')
     loss = F.l1_loss(pred * mask, target * mask, reduction='sum')  # size_average=False in PyTorch < 1.0
     loss = loss / (mask.sum() + 1e-4)
     return loss
Example #4
 def forward(self, output, mask, ind, target):
     pred = _tranpose_and_gather_feat(output, ind)
     mask = mask.unsqueeze(2).expand_as(pred).float()
     # loss = F.l1_loss(pred * mask, target * mask, reduction='elementwise_mean')
     pred = pred / (target + 1e-4)
     target = target * 0 + 1
     loss = F.l1_loss(pred * mask, target * mask, reduction='sum')  # size_average=False in PyTorch < 1.0
     loss = loss / (mask.sum() + 1e-4)
     return loss
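
Example #4 divides the prediction by the target before the L1, which turns the absolute error into a relative (scale-normalized) one, so small boxes are penalized as strongly as large ones. A tiny numeric illustration (values are made up):

import torch

pred = torch.tensor([10.0, 100.0])
target = torch.tensor([8.0, 98.0])
abs_err = (pred - target).abs()               # tensor([2., 2.]): same absolute error
rel_err = (pred / (target + 1e-4) - 1).abs()  # ~tensor([0.2500, 0.0204]): the small box dominates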
Example #5
    def forward(self, outputs, batch, dataparallel=False):
        opt = self.opt
        hm_loss, wh_loss, off_loss, id_loss = 0, 0, 0, 0
        for s in range(opt.num_stacks):
            output = outputs[s]
            if not opt.mse_loss:
                output['hm'] = _sigmoid(output['hm'])

            hm_loss += self.crit(output['hm'], batch['hm']) / opt.num_stacks
            if opt.wh_weight > 0:
                if opt.dense_wh:
                    mask_weight = batch['dense_wh_mask'].sum() + 1e-4
                    wh_loss += (self.crit_wh(
                        output['wh'] * batch['dense_wh_mask'],
                        batch['dense_wh'] * batch['dense_wh_mask']) /
                                mask_weight) / opt.num_stacks
                else:
                    wh_loss += self.crit_reg(output['wh'], batch['reg_mask'],
                                             batch['ind'],
                                             batch['wh']) / opt.num_stacks

            if opt.reg_offset and opt.off_weight > 0:
                off_loss += self.crit_reg(output['reg'], batch['reg_mask'],
                                          batch['ind'],
                                          batch['reg']) / opt.num_stacks

            if opt.id_weight > 0:
                id_head = _tranpose_and_gather_feat(output['id'], batch['ind'])
                id_head = id_head[batch['reg_mask'] > 0].contiguous()
                id_head = self.emb_scale * F.normalize(id_head)
                id_target = batch['ids'][batch['reg_mask'] > 0]
                id_output = self.classifier(id_head).contiguous()
                id_loss += self.IDLoss(id_output, id_target)

        det_loss = opt.hm_weight * hm_loss + opt.wh_weight * wh_loss + opt.off_weight * off_loss

        loss = torch.exp(-self.s_det) * det_loss + torch.exp(
            -self.s_id) * id_loss + (self.s_det + self.s_id)
        loss *= 0.5

        if not dataparallel:
            loss_stats = {
                'loss': loss.mean().item(),
                'hm_loss': hm_loss.mean().item(),
                'wh_loss': wh_loss.mean().item(),
                'off_loss': off_loss.mean().item(),
                'id_loss': id_loss.mean().item()
            }
        else:
            loss_stats = {
                'loss': loss,
                'hm_loss': hm_loss,
                'wh_loss': wh_loss,
                'off_loss': off_loss,
                'id_loss': id_loss
            }
        return loss, loss_stats
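
The s_det / s_id terms above implement uncertainty-based task weighting in the style of Kendall et al. (2018). A minimal sketch of how they would be declared, assuming they are learnable log-variances registered on the loss module (the initial values mirror FairMOT's, but treat them as an assumption):

import torch
import torch.nn as nn

class UncertaintyWeighting(nn.Module):
    def __init__(self, init_det=-1.85, init_id=-1.05):
        super().__init__()
        # learnable log-variances; exp(-s) down-weights the noisier task
        self.s_det = nn.Parameter(init_det * torch.ones(1))
        self.s_id = nn.Parameter(init_id * torch.ones(1))

    def forward(self, det_loss, id_loss):
        loss = (torch.exp(-self.s_det) * det_loss
                + torch.exp(-self.s_id) * id_loss
                + self.s_det + self.s_id)
        return 0.5 * loss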
Example #6
    def forward(self, outputs, batch):
        opt = self.opt
        hm_loss, wh_loss, off_loss, id_loss = 0, 0, 0, 0
        for s in range(opt.num_stacks):
            output = outputs[s]
            if not opt.mse_loss:
                output['hm'] = _sigmoid(output['hm'])

            hm_loss += self.crit(output['hm'], batch['hm']) / opt.num_stacks
            if opt.wh_weight > 0:
                wh_loss += self.crit_reg(output['wh'], batch['reg_mask'],
                                         batch['ind'],
                                         batch['wh']) / opt.num_stacks

            if opt.reg_offset and opt.off_weight > 0:
                off_loss += self.crit_reg(output['reg'], batch['reg_mask'],
                                          batch['ind'],
                                          batch['reg']) / opt.num_stacks

            if opt.id_weight > 0:
                id_head = _tranpose_and_gather_feat(output['id'], batch['ind'])
                id_head = id_head[batch['reg_mask'] > 0].contiguous()
                id_head = self.emb_scale * F.normalize(id_head)
                id_target = batch['ids'][batch['reg_mask'] > 0]

                id_output = self.classifier(id_head).contiguous()
                if self.opt.id_loss == 'focal':
                    id_target_one_hot = id_output.new_zeros(
                        (id_head.size(0),
                         self.nID)).scatter_(1,
                                             id_target.long().view(-1, 1), 1)
                    id_loss += sigmoid_focal_loss_jit(
                        id_output,
                        id_target_one_hot,
                        alpha=0.25,
                        gamma=2.0,
                        reduction="sum") / id_output.size(0)
                else:
                    id_loss += self.IDLoss(id_output, id_target)

        det_loss = opt.hm_weight * hm_loss + opt.wh_weight * wh_loss + opt.off_weight * off_loss
        if opt.multi_loss == 'uncertainty':
            loss = torch.exp(-self.s_det) * det_loss + torch.exp(
                -self.s_id) * id_loss + (self.s_det + self.s_id)
            loss *= 0.5
        else:
            loss = det_loss + 0.1 * id_loss

        loss_stats = {
            'loss': loss,
            'hm_loss': hm_loss,
            'wh_loss': wh_loss,
            'off_loss': off_loss,
            'id_loss': id_loss
        }
        return loss, loss_stats
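
`sigmoid_focal_loss_jit` is presumably detectron2's JIT-scripted focal loss. For reference, a plain-PyTorch sketch of the same computation (the example above then divides the summed loss by the number of samples):

import torch
import torch.nn.functional as F

def sigmoid_focal_loss(logits, targets, alpha=0.25, gamma=2.0):
    # targets are one-hot, as built with scatter_ above
    p = torch.sigmoid(logits)
    ce = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
    p_t = p * targets + (1 - p) * (1 - targets)   # probability of the true class
    loss = ce * (1 - p_t) ** gamma                # focus on hard examples
    if alpha >= 0:
        loss = (alpha * targets + (1 - alpha) * (1 - targets)) * loss
    return loss.sum()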
Example #7
  def debug(self, detections, targets, ae_threshold):
      tl_heat = detections['tl_heatmap']
      br_heat = detections['br_heatmap']
      ct_heat = detections['ct_heatmap']

      tl_scores, tl_inds, tl_clses, tl_ys, tl_xs = _topk_original(tl_heat, K=128)
      br_scores, br_inds, br_clses, br_ys, br_xs = _topk_original(br_heat, K=128)
      ct_scores, ct_inds, ct_clses, ct_ys, ct_xs = _topk_original(ct_heat, K=128)

      tl_tag = detections['tl_tag']
      br_tag = detections['br_tag']

      # gather by gt
      # tl_tag = _tranpose_and_gather_feat(tl_tag, targets['tl_tag'].to(torch.device("cuda")))
      # br_tag = _tranpose_and_gather_feat(br_tag, targets['br_tag'].to(torch.device("cuda")))
      # gather by top k
      tl_tag = _tranpose_and_gather_feat(tl_tag, tl_inds)
      br_tag = _tranpose_and_gather_feat(br_tag, br_inds)

      dists_tl_br = torch.abs(tl_tag - br_tag)

      dist_inds = (dists_tl_br < ae_threshold)
      dist_inds = dist_inds.squeeze(2)

      # get tl, bl and br index of heatmap after grouping
      tl_inds = tl_inds[dist_inds].to(torch.device("cpu")).numpy()
      br_inds = br_inds[dist_inds].to(torch.device("cpu")).numpy()
      # tl bl br index of heatmap groundtruth
      tl_tag_gt = targets['tl_tag'].to(torch.device("cpu")).numpy()
      br_tag_gt = targets['br_tag'].to(torch.device("cpu")).numpy()

      tl_intersect, _, tl_inds = np.intersect1d(tl_inds, tl_tag_gt[0], return_indices=True)
      br_intersect, _, br_inds = np.intersect1d(br_inds, br_tag_gt[0], return_indices=True)

      tl_br_intersect = np.intersect1d(tl_inds, br_inds)

      # true_positive = (dist_inds & targets['reg_mask'].to(torch.device("cuda")))
      # true_positive = true_positive.to(torch.device("cpu")).numpy()

      # print("Recall is {} out of {}".format(true_positive.sum(), targets['reg_mask'].numpy().sum()))

      # return true_positive.sum() / targets['reg_mask'].numpy().sum()
      return len(tl_br_intersect)
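
`_topk_original` appears to be CornerNet's `_topk`: pick the K strongest heatmap peaks and decode class and (x, y) from each flat index. A sketch under that assumption:

import torch

def _topk(scores, K=128):
    # scores: (B, C, H, W) corner/center heatmap
    batch, cat, height, width = scores.size()
    topk_scores, topk_inds = torch.topk(scores.view(batch, -1), K)
    topk_clses = topk_inds // (height * width)   # class channel of each peak
    topk_inds = topk_inds % (height * width)     # flat spatial index
    topk_ys = (topk_inds // width).float()
    topk_xs = (topk_inds % width).float()
    return topk_scores, topk_inds, topk_clses, topk_ys, topk_xs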
Example #8
    def predict(self, img0):
        # img0 = cv2.imread(img_path)  # BGR
        img, _, _, _ = letterbox(img0, height=640, width=640)
        # Normalize RGB
        img = img[:, :, ::-1].transpose(2, 0, 1)
        img = np.ascontiguousarray(img, dtype=np.float32)
        img /= 255.0
        im_blob = torch.from_numpy(img).unsqueeze(0).to(self.device)

        width = img0.shape[1]
        height = img0.shape[0]
        inp_height = im_blob.shape[2]
        inp_width = im_blob.shape[3]
        c = np.array([width / 2., height / 2.], dtype=np.float32)
        s = max(float(inp_width) / float(inp_height) * height, width) * 1.0
        meta = {
            'c': c,
            's': s,
            'out_height': inp_height // down_ratio,
            'out_width': inp_width // down_ratio
        }
        ''' Step 1: Network forward, get detections & embeddings'''
        with torch.no_grad():
            output = self.model(im_blob)[-1]
            hm = output['hm'].sigmoid_()
            wh = output['wh']
            id_feature = output['id']
            id_feature = F.normalize(id_feature, dim=1)

            reg = output['reg']
            dets, inds = mot_decode(hm, wh, reg=reg, ltrb=True, K=500)
            id_feature = _tranpose_and_gather_feat(id_feature, inds)
            id_feature = id_feature.squeeze(0)
            id_feature = id_feature.cpu().numpy()

        dets = post_process(dets, meta)
        dets = merge_outputs([dets])[1]
        remain_inds = dets[:, 4] > self.conf_thres
        dets = dets[remain_inds]
        id_feature = id_feature[remain_inds]
        res = []
        for i in range(0, dets.shape[0]):
            bbox = dets[i][0:4]
            bbox = [
                int(min(bbox[0], bbox[2])),
                int(min(bbox[1], bbox[3])),
                int(max(bbox[0], bbox[2])),
                int(max(bbox[1], bbox[3])),
            ]
            res.append({
                "reid": id_feature[i],
                "bbox": bbox,
            })
        return res
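
A hypothetical caller, just to show the input/output contract (the `detector` instance and the file name are assumptions): predict() takes a BGR array as loaded by OpenCV and returns one dict per detection.

import cv2

img0 = cv2.imread('frame_000001.jpg')   # BGR frame, any resolution
for det in detector.predict(img0):      # `detector` is an instance of the class above
    x1, y1, x2, y2 = det['bbox']        # ints: left, top, right, bottom
    print(det['reid'].shape, (x1, y1, x2, y2))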
Example #9
    def forward(self, outputs, batch):
        opt = self.opt
        hm_loss, wh_loss, off_loss, id_loss, count_loss = 0, 0, 0, 0, 0
        density_loss = 0  # loss_stats below reads this even when count_weight == 0
        for s in range(opt.num_stacks):
            output = outputs[s]
            if not opt.mse_loss:
                output['hm'] = _sigmoid(output['hm'])

            hm_loss += self.crit(output['hm'], batch['hm']) / opt.num_stacks
            if opt.wh_weight > 0:
                wh_loss += self.crit_reg(output['wh'], batch['reg_mask'],
                                         batch['ind'],
                                         batch['wh']) / opt.num_stacks

            if opt.reg_offset and opt.off_weight > 0:
                off_loss += self.crit_reg(output['reg'], batch['reg_mask'],
                                          batch['ind'],
                                          batch['reg']) / opt.num_stacks

            if opt.id_weight > 0:
                id_head = _tranpose_and_gather_feat(output['id'], batch['ind'])
                id_head = id_head[batch['reg_mask'] > 0].contiguous()
                id_head = self.emb_scale * F.normalize(id_head)
                id_target = batch['ids'][batch['reg_mask'] > 0]

                id_output = self.classifier(id_head).contiguous()
                id_loss += self.IDLoss(id_output, id_target)

            if opt.count_weight > 0:
                output['density'] = torch.sigmoid(output['density'])  # F.sigmoid is deprecated
                density_loss = self.crit_density_focal(
                    output['density'],
                    batch['density']) + self.crit_density_ssim(
                        output['density'], batch['density'])
                count_loss += self.crit_count(output['count'],
                                              batch['count']) + density_loss

        det_loss = opt.hm_weight * hm_loss + opt.wh_weight * wh_loss + opt.off_weight * off_loss + opt.count_weight * count_loss

        loss = torch.exp(-self.s_det) * det_loss + torch.exp(
            -self.s_id) * id_loss + (self.s_det + self.s_id)
        loss *= 0.5

        loss_stats = {
            'loss': loss,
            'hm_loss': hm_loss,
            'wh_loss': wh_loss,
            'off_loss': off_loss,
            'id_loss': id_loss,
            "density_loss": density_loss,
            'count_loss': count_loss
        }
        return loss, loss_stats
Example #10
    def forward(self, outputs, batch):
        opt = self.opt
        hm_loss, wh_loss, off_loss, id_loss = 0, 0, 0, 0
        for s in range(opt.num_stacks):
            output = outputs[s]
            if not opt.mse_loss:
                output['hm'] = _sigmoid(output['hm'])

            hm_loss_temp = self.crit(output['hm'],
                                     batch['hm']) / opt.num_stacks
            hm_loss += hm_loss_temp
            if opt.wh_weight > 0:
                wh_loss += self.crit_reg(output['wh'], batch['reg_mask'],
                                         batch['ind'],
                                         batch['wh']) / opt.num_stacks

            if opt.reg_offset and opt.off_weight > 0:
                off_loss += self.crit_reg(output['reg'], batch['reg_mask'],
                                          batch['ind'],
                                          batch['reg']) / opt.num_stacks

            if opt.id_weight > 0:
                id_head = _tranpose_and_gather_feat(output['id'], batch['ind'])
                id_head = id_head[batch['reg_mask'] > 0].contiguous()
                id_head = self.emb_scale * F.normalize(id_head)
                id_target = batch['ids'][batch['reg_mask'] > 0]

                id_output = self.classifier(id_head).contiguous()
                if len(id_target) > 0 and len(id_output) > 0:
                    id_loss += self.IDLoss(id_output, id_target)

            if torch.isnan(hm_loss_temp):
                # fail fast so the offending batch can be inspected
                raise ValueError("yee: Hit NaN")

        det_loss = opt.hm_weight * hm_loss + opt.wh_weight * wh_loss + opt.off_weight * off_loss

        loss = torch.exp(-self.s_det) * det_loss + torch.exp(
            -self.s_id) * id_loss + (self.s_det + self.s_id)
        loss *= 0.5

        loss_stats = {
            'loss': loss,
            'hm_loss': hm_loss,
            'wh_loss': wh_loss,
            'off_loss': off_loss,
            'id_loss': id_loss
        }
        return loss, loss_stats
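
When the NaN guard above trips, PyTorch's built-in anomaly detection (standard torch.autograd API) can point at the op that produced the NaN:

# enable only while debugging: every forward op records a traceback, and a
# backward pass that produces NaN raises with that traceback attached
torch.autograd.set_detect_anomaly(True)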
Example #11
    def forward(self, outputs, batch):
        opt = self.opt
        hm_loss, wh_loss, off_loss, id_loss = 0, 0, 0, 0
        for s in range(opt.num_stacks):
            output = outputs[s]
            if not opt.mse_loss:
                output['hm'] = _sigmoid(output['hm'])

            hm_loss += self.crit(output['hm'], batch['hm']) / opt.num_stacks
            if opt.wh_weight > 0:
                wh_loss += self.crit_reg(
                    output['wh'], batch['reg_mask'],
                    batch['ind'], batch['wh']) / opt.num_stacks

            if opt.reg_offset and opt.off_weight > 0:
                off_loss += self.crit_reg(output['reg'], batch['reg_mask'],
                                          batch['ind'], batch['reg']) / opt.num_stacks

            if opt.id_weight > 0:
                id_head = _tranpose_and_gather_feat(output['id'], batch['ind'])
                # id_head = id_head[batch['reg_mask'] > 0].contiguous()
                id_head = paddle.to_tensor(id_head.numpy()[batch['reg_mask'].numpy()>0])
                # id_head = paddle.masked_select(id_head, batch['reg_mask'] > 0)                
                id_head = self.emb_scale * F.normalize(id_head)
                # id_target = batch['ids'][batch['reg_mask'] > 0]
                id_target = paddle.to_tensor(batch['ids'].numpy()[batch['reg_mask'].numpy() > 0])
                # id_target = paddle.masked_select(batch['ids'], batch['reg_mask'] > 0)
                id_output = self.classifier(id_head)#.contiguous()
                id_target.stop_gradient = True
                id_loss += self.IDLoss(id_output, id_target)

        det_loss = opt.hm_weight * hm_loss + opt.wh_weight * wh_loss + opt.off_weight * off_loss
        if opt.multi_loss == 'uncertainty':
            # loss = torch.exp(-self.s_det) * det_loss + torch.exp(-self.s_id) * id_loss + (self.s_det + self.s_id)
            loss = paddle.exp(-self.s_det) * det_loss + paddle.exp(-self.s_id) * id_loss + (self.s_det + self.s_id)
            loss *= 0.5
            
        else:
            loss = det_loss + 0.1 * id_loss

        loss_stats = {'loss': loss, 'hm_loss': hm_loss,
                      'wh_loss': wh_loss, 'off_loss': off_loss, 'id_loss': id_loss}
        id_classifier_state_dict = self.classifier.state_dict()
        return loss, loss_stats, id_classifier_state_dict
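
One caveat in this Paddle port: the .numpy() round-trips used for boolean indexing detach id_head from the autograd graph, so no gradient reaches the backbone through the ID branch. A sketch of a differentiable alternative, assuming a Paddle version whose masked_select accepts a broadcast boolean mask:

# sketch only: differentiable masking without the .numpy() round-trip
mask = batch['reg_mask'] > 0                        # (B, K) boolean
emb_dim = id_head.shape[-1]
id_head = paddle.masked_select(
    id_head, mask.unsqueeze(-1).expand(id_head.shape)).reshape([-1, emb_dim])
id_target = paddle.masked_select(batch['ids'], mask)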
Example #12
    def update(self, im_blob, img0):
        self.frame_id += 1
        activated_starcks = []
        refind_stracks = []
        lost_stracks = []
        removed_stracks = []

        width = img0.shape[1]
        height = img0.shape[0]
        inp_height = im_blob.shape[2]
        inp_width = im_blob.shape[3]
        c = np.array([width / 2., height / 2.], dtype=np.float32)
        s = max(float(inp_width) / float(inp_height) * height, width) * 1.0
        meta = {
            'c': c,
            's': s,
            'out_height': inp_height // self.opt.down_ratio,
            'out_width': inp_width // self.opt.down_ratio
        }
        ''' Step 1: Network forward, get detections & embeddings'''
        with torch.no_grad():
            if hasattr(self.model, 'relation'):
                outputs, stuff = self.model(im_blob)
                det_heads = set(['wh', 'hm', 'reg'])
                trk_heads = set(['id'])
                for head in (set(self.model.backend.heads) & det_heads):
                    outputs[head] = getattr(self.model.backend,
                                            head)(outputs['raw'])
                # for head in (set(self.model.heads) & trk_heads):
                #     outputs[head] = getattr(self.model, head)(outputs['raw_trk'])
                # del outputs['raw_trk']
                del outputs['raw']
                output = outputs
                if hasattr(self.model.relation, 'loss'):
                    cur_feats = stuff[-2]
                    self.model.relation.lock.acquire()
                    self.model.relation.feature_bank.append(
                        cur_feats.detach().cpu())
                    self.model.relation.lock.release()
            else:
                output = self.model(im_blob)[-1]
            hm = output['hm'].sigmoid_()
            wh = output['wh']
            id_feature = output['id']
            id_feature = F.normalize(id_feature, dim=1)

            reg = output['reg'] if self.opt.reg_offset else None
            dets, inds = mot_decode(hm,
                                    wh,
                                    reg=reg,
                                    ltrb=self.opt.ltrb,
                                    K=self.opt.K)
            id_feature = _tranpose_and_gather_feat(id_feature, inds)
            id_feature = id_feature.squeeze(0)
            id_feature = id_feature.cpu().numpy()

        dets = self.post_process(dets, meta)
        dets = self.merge_outputs([dets])[1]

        remain_inds = dets[:, 4] > self.opt.conf_thres
        dets = dets[remain_inds]
        id_feature = id_feature[remain_inds]
        self.inputs_embs.append((dets, id_feature))

        # vis
        '''
        for i in range(0, dets.shape[0]):
            bbox = dets[i][0:4]
            cv2.rectangle(img0, (bbox[0], bbox[1]),
                          (bbox[2], bbox[3]),
                          (0, 255, 0), 2)
        cv2.imshow('dets', img0)
        cv2.waitKey(0)
        id0 = id0-1
        '''

        if len(dets) > 0:
            '''Detections'''
            detections = [
                STrack(STrack.tlbr_to_tlwh(tlbrs[:4]), tlbrs[4], f, 30)
                for (tlbrs, f) in zip(dets[:, :5], id_feature)
            ]
        else:
            detections = []
        ''' Add newly detected tracklets to tracked_stracks'''
        unconfirmed = []
        tracked_stracks = []  # type: list[STrack]
        for track in self.tracked_stracks:
            if not track.is_activated:
                unconfirmed.append(track)
            else:
                tracked_stracks.append(track)
        ''' Step 2: First association, with embedding'''
        strack_pool = joint_stracks(tracked_stracks, self.lost_stracks)
        # Predict the current location with KF
        #for strack in strack_pool:
        #strack.predict()
        STrack.multi_predict(strack_pool)
        dists = matching.embedding_distance(strack_pool, detections)
        #dists = matching.iou_distance(strack_pool, detections)
        dists = matching.fuse_motion(self.kalman_filter, dists, strack_pool,
                                     detections)
        matches, u_track, u_detection = matching.linear_assignment(dists,
                                                                   thresh=0.4)

        for itracked, idet in matches:
            track = strack_pool[itracked]
            det = detections[idet]
            if track.state == TrackState.Tracked:
                track.update(detections[idet], self.frame_id)
                activated_starcks.append(track)
            else:
                track.re_activate(det, self.frame_id, new_id=False)
                refind_stracks.append(track)
        ''' Step 3: Second association, with IOU'''
        detections = [detections[i] for i in u_detection]
        r_tracked_stracks = [
            strack_pool[i] for i in u_track
            if strack_pool[i].state == TrackState.Tracked
        ]
        dists = matching.iou_distance(r_tracked_stracks, detections)
        matches, u_track, u_detection = matching.linear_assignment(dists,
                                                                   thresh=0.5)

        for itracked, idet in matches:
            track = r_tracked_stracks[itracked]
            det = detections[idet]
            if track.state == TrackState.Tracked:
                track.update(det, self.frame_id)
                activated_starcks.append(track)
            else:
                track.re_activate(det, self.frame_id, new_id=False)
                refind_stracks.append(track)

        for it in u_track:
            track = r_tracked_stracks[it]
            if not track.state == TrackState.Lost:
                track.mark_lost()
                lost_stracks.append(track)
        '''Deal with unconfirmed tracks, usually tracks with only one beginning frame'''
        detections = [detections[i] for i in u_detection]
        dists = matching.iou_distance(unconfirmed, detections)
        matches, u_unconfirmed, u_detection = matching.linear_assignment(
            dists, thresh=0.7)
        for itracked, idet in matches:
            unconfirmed[itracked].update(detections[idet], self.frame_id)
            activated_starcks.append(unconfirmed[itracked])
        for it in u_unconfirmed:
            track = unconfirmed[it]
            track.mark_removed()
            removed_stracks.append(track)
        """ Step 4: Init new stracks"""
        for inew in u_detection:
            track = detections[inew]
            if track.score < self.det_thresh:
                continue
            track.activate(self.kalman_filter, self.frame_id)
            activated_starcks.append(track)
        """ Step 5: Update state"""
        for track in self.lost_stracks:
            if self.frame_id - track.end_frame > self.max_time_lost:
                track.mark_removed()
                removed_stracks.append(track)

        # print('Ramained match {} s'.format(t4-t3))

        self.tracked_stracks = [
            t for t in self.tracked_stracks if t.state == TrackState.Tracked
        ]
        self.tracked_stracks = joint_stracks(self.tracked_stracks,
                                             activated_starcks)
        self.tracked_stracks = joint_stracks(self.tracked_stracks,
                                             refind_stracks)
        self.lost_stracks = sub_stracks(self.lost_stracks,
                                        self.tracked_stracks)
        self.lost_stracks.extend(lost_stracks)
        self.lost_stracks = sub_stracks(self.lost_stracks,
                                        self.removed_stracks)
        self.removed_stracks.extend(removed_stracks)
        self.tracked_stracks, self.lost_stracks = remove_duplicate_stracks(
            self.tracked_stracks, self.lost_stracks)
        # get scores of lost tracks
        output_stracks = [
            track for track in self.tracked_stracks if track.is_activated
        ]

        logger.debug('===========Frame {}=========='.format(self.frame_id))
        logger.debug('Activated: {}'.format(
            [track.track_id for track in activated_starcks]))
        logger.debug('Refind: {}'.format(
            [track.track_id for track in refind_stracks]))
        logger.debug('Lost: {}'.format(
            [track.track_id for track in lost_stracks]))
        logger.debug('Removed: {}'.format(
            [track.track_id for track in removed_stracks]))

        return output_stracks
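
For context, joint_stracks and sub_stracks are small set-like helpers keyed on track_id; a sketch matching their usual FairMOT definitions:

def joint_stracks(tlista, tlistb):
    # union of two track lists; entries from tlista win on track_id conflicts
    exists, res = {}, []
    for t in tlista:
        exists[t.track_id] = 1
        res.append(t)
    for t in tlistb:
        if not exists.get(t.track_id, 0):
            exists[t.track_id] = 1
            res.append(t)
    return res

def sub_stracks(tlista, tlistb):
    # tracks of tlista whose track_id does not appear in tlistb
    stracks = {t.track_id: t for t in tlista}
    for t in tlistb:
        stracks.pop(t.track_id, None)
    return list(stracks.values())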
Example #13
    def update(self, im_blob, img0):        # process the detections in the current frame
        self.frame_id += 1
        activated_starcks = []
        refind_stracks = []                 # tracks newly found between the previous frame and this one
        lost_stracks = []                   # tracks lost between the previous frame and this one
        removed_stracks = []                # tracks to be removed between the previous frame and this one

        width = img0.shape[1]
        height = img0.shape[0]
        inp_height = im_blob.shape[2]
        inp_width = im_blob.shape[3]

        c = np.array([width / 2., height / 2.], dtype=np.float32)
        s = max(float(inp_width) / float(inp_height) * height, width) * 1.0
        meta = {'c': c, 's': s,
                'out_height': inp_height // self.opt.down_ratio,
                'out_width': inp_width // self.opt.down_ratio}

        ''' Step 1: Network forward, get detections & embeddings'''
        with torch.no_grad():
            output = self.model(im_blob)[-1]
            hm = output['hm'].sigmoid_()
            wh = output['wh']
            id_feature = output['id']
            id_feature = F.normalize(id_feature, dim=1)                 # torch.Size([1, 512, 152, 272])
            reg = output['reg'] if self.opt.reg_offset else None

            dets, inds = mot_decode(hm, wh, reg=reg, cat_spec_wh=self.opt.cat_spec_wh, K=self.opt.K)    # boxes as top-left / bottom-right corners plus score and class; inds are indices into the flattened image
            # inds holds the 128 highest-confidence positions in the flattened image, i.e. the maximum number of output targets
            id_feature = _tranpose_and_gather_feat(id_feature, inds)        # id_feature torch.Size([1, 512, 152, 272]), inds torch.Size([1, 128])
            id_feature = id_feature.squeeze(0)                              # torch.Size([128, 512])
            id_feature = id_feature.cpu().numpy()

        dets = self.post_process(dets, meta)                                # map the feature-space predictions back onto the original image: coordinates and confidences of the 128 boxes
        dets = self.merge_outputs([dets])[1]                    # (128, 5)

        remain_inds = dets[:, 4] > self.opt.conf_thres                      # keep only boxes whose confidence exceeds the threshold
        dets = dets[remain_inds]                    # e.g. (2, 5): two boxes survive as the final result
        id_feature = id_feature[remain_inds]        # e.g. (2, 512): the corresponding features

        # vis
        '''
        for i in range(0, dets.shape[0]):
            bbox = dets[i][0:4]
            cv2.rectangle(img0, (bbox[0], bbox[1]),
                          (bbox[2], bbox[3]),
                          (0, 255, 0), 2)
        cv2.imshow('dets', img0)
        cv2.waitKey(0)
        id0 = id0-1
        '''

        if len(dets) > 0:
            '''Detections'''
            detections = [STrack(STrack.tlbr_to_tlwh(tlbrs[:4]), tlbrs[4], f, 30) for               # the class method is called directly; is there anything special about that?
                          (tlbrs, f) in zip(dets[:, :5], id_feature)]                               # build the STracks; these act as the tracklets
        else:
            detections = []

        ''' Add newly detected tracklets to tracked_stracks'''
        unconfirmed = []
        tracked_stracks = []  # type: list[STrack]
        for track in self.tracked_stracks:                      # split the tracks existing before this frame into unconfirmed and tracked_stracks
            if not track.is_activated:
                unconfirmed.append(track)
            else:
                tracked_stracks.append(track)

        ''' Step 2: First association, with embedding'''
        strack_pool = joint_stracks(tracked_stracks, self.lost_stracks)     # take the union
        # Predict the current location with KF
        #for strack in strack_pool:
            #strack.predict()
        STrack.multi_predict(strack_pool)                               # Kalman-filter prediction of each target's state in the next frame, via each track's predict method

        dists = matching.embedding_distance(strack_pool, detections)            # match by embedding: distance matrix between each detection and each existing track's smooth feature
        #dists = matching.gate_cost_matrix(self.kalman_filter, dists, strack_pool, detections)
        dists = matching.fuse_motion(self.kalman_filter, dists, strack_pool, detections)    # for each track, compute the gated distance to every detection in this frame
        matches, u_track, u_detection = matching.linear_assignment(dists, thresh=0.7)       # Hungarian matching on the gated distances, yielding the three result sets

        for itracked, idet in matches:
            track = strack_pool[itracked]
            det = detections[idet]
            if track.state == TrackState.Tracked:               # was in the tracked state in the previous frame
                track.update(detections[idet], self.frame_id)   # update the track state, including the KF mean vector and covariance matrix
                activated_starcks.append(track)
            else:                                               # was in the new state in the previous frame
                track.re_activate(det, self.frame_id, new_id=False)
                refind_stracks.append(track)

        ''' Step 3: Second association, with IOU'''             # second pass: try to match the remaining unmatched detections and tracks
        detections = [detections[i] for i in u_detection]
        r_tracked_stracks = [strack_pool[i] for i in u_track if strack_pool[i].state == TrackState.Tracked]
        dists = matching.iou_distance(r_tracked_stracks, detections)
        matches, u_track, u_detection = matching.linear_assignment(dists, thresh=0.5)

        for itracked, idet in matches:
            track = r_tracked_stracks[itracked]
            det = detections[idet]
            if track.state == TrackState.Tracked:
                track.update(det, self.frame_id)
                activated_starcks.append(track)
            else:
                track.re_activate(det, self.frame_id, new_id=False)
                refind_stracks.append(track)

        for it in u_track:
            track = r_tracked_stracks[it]
            if not track.state == TrackState.Lost:      # check whether the track is already lost
                track.mark_lost()
                lost_stracks.append(track)

        '''Third matching pass. Deal with unconfirmed tracks, usually tracks with only one beginning frame (a track followed for only one frame is an unconfirmed track)'''
        detections = [detections[i] for i in u_detection]
        dists = matching.iou_distance(unconfirmed, detections)
        matches, u_unconfirmed, u_detection = matching.linear_assignment(dists, thresh=0.7)

        for itracked, idet in matches:
            unconfirmed[itracked].update(detections[idet], self.frame_id)
            activated_starcks.append(unconfirmed[itracked])

        for it in u_unconfirmed:
            track = unconfirmed[it]
            track.mark_removed()
            removed_stracks.append(track)

        """ Step 4: Init new stracks"""
        for inew in u_detection:
            track = detections[inew]
            if track.score < self.det_thresh:                           # compare against the tracking confidence threshold
                continue
            track.activate(self.kalman_filter, self.frame_id)
            activated_starcks.append(track)

        """ Step 5: Update state"""
        for track in self.lost_stracks:
            if self.frame_id - track.end_frame > self.max_time_lost:
                track.mark_removed()                                    # remove the tracks that meet the condition
                removed_stracks.append(track)

        # print('Ramained match {} s'.format(t4-t3))

        self.tracked_stracks = [t for t in self.tracked_stracks if t.state == TrackState.Tracked]
        self.tracked_stracks = joint_stracks(self.tracked_stracks, activated_starcks)
        self.tracked_stracks = joint_stracks(self.tracked_stracks, refind_stracks)
        self.lost_stracks = sub_stracks(self.lost_stracks, self.tracked_stracks)
        self.lost_stracks.extend(lost_stracks)
        self.lost_stracks = sub_stracks(self.lost_stracks, self.removed_stracks)
        self.removed_stracks.extend(removed_stracks)
        self.tracked_stracks, self.lost_stracks = remove_duplicate_stracks(self.tracked_stracks, self.lost_stracks)
        # get scores of lost tracks
        output_stracks = [track for track in self.tracked_stracks if track.is_activated]

        logger.debug('===========Frame {}=========='.format(self.frame_id))
        logger.debug('Activated: {}'.format([track.track_id for track in activated_starcks]))
        logger.debug('Refind: {}'.format([track.track_id for track in refind_stracks]))
        logger.debug('Lost: {}'.format([track.track_id for track in lost_stracks]))
        logger.debug('Removed: {}'.format([track.track_id for track in removed_stracks]))

        return output_stracks
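
The first association is driven by matching.embedding_distance; a sketch of the usual cosine-distance implementation (assuming smooth_feat / curr_feat attributes as in FairMOT):

import numpy as np
from scipy.spatial.distance import cdist

def embedding_distance(tracks, detections, metric='cosine'):
    # rows: existing tracks, columns: new detections
    cost_matrix = np.zeros((len(tracks), len(detections)), dtype=np.float32)
    if cost_matrix.size == 0:
        return cost_matrix
    det_features = np.asarray([d.curr_feat for d in detections], dtype=np.float32)
    trk_features = np.asarray([t.smooth_feat for t in tracks], dtype=np.float32)
    return np.maximum(0.0, cdist(trk_features, det_features, metric))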
Example #14
    def forward(self, outputs, batch, global_step, tb_writer):
        opt = self.opt
        hm_loss, wh_loss, off_loss = 0, 0, 0
        batch_hm_loss, batch_wh_loss, batch_off_loss = 0, 0, 0  #per batch losses
        for s in range(opt.num_stacks):

            output = outputs[s]

            output['hm'] = _sigmoid(output['hm'])

            if opt.eval_oracle_hm:
                output['hm'] = batch['hm']
            if opt.eval_oracle_wh:
                output['wh'] = torch.from_numpy(
                    gen_oracle_map(batch['wh'].detach().cpu().numpy(),
                                   batch['ind'].detach().cpu().numpy(),
                                   output['wh'].shape[3],
                                   output['wh'].shape[2])).to(opt.device)
            if opt.eval_oracle_offset:
                output['reg'] = torch.from_numpy(
                    gen_oracle_map(batch['reg'].detach().cpu().numpy(),
                                   batch['ind'].detach().cpu().numpy(),
                                   output['reg'].shape[3],
                                   output['reg'].shape[2])).to(opt.device)

            tmp = self.crit(output['hm'], batch['hm'])
            hm_loss = hm_loss + tmp[0] / opt.num_stacks
            batch_hm_loss = batch_hm_loss + tmp[1] / opt.num_stacks
            if opt.wh_weight > 0:
                if opt.mdn:
                    BS = output['mdn_logits'].shape[0]
                    M = opt.mdn_n_comps
                    H, W = output['mdn_logits'].shape[-2:]
                    K = opt.num_classes if opt.cat_spec_wh else 1

                    mdn_logits = output['mdn_logits']
                    mdn_logits = mdn_logits.reshape((BS, M, K, H, W)).permute(
                        (2, 0, 1, 3, 4))
                    # mdn_logits.shape: torch.Size([80, 2, 3, 128, 128])

                    mdn_pi = torch.clamp(
                        torch.nn.Softmax(dim=2)(mdn_logits), 1e-4, 1. - 1e-4)
                    # mdn_pi.shape: torch.Size([80, 2, 3, 128, 128])

                    mdn_sigma = torch.clamp(
                        torch.nn.ELU()(output['mdn_sigma']) +
                        opt.mdn_min_sigma, 1e-4, 1e5)
                    mdn_sigma = mdn_sigma.reshape(
                        (BS, M * 2, K, H, W)).permute((2, 0, 1, 3, 4))
                    # mdn_sigma.shape: torch.Size([80, 2, 6, 128, 128])

                    mdn_mu = output['wh']
                    mdn_mu = mdn_mu.reshape((BS, M * 2, K, H, W)).permute(
                        (2, 0, 1, 3, 4))
                    # mdn_mu.shape: torch.Size([80, 2, 6, 128, 128])

                    gt = batch['cat_spec_wh'] if opt.cat_spec_wh else batch[
                        'wh']
                    gt = gt.reshape(
                        (BS, -1, opt.num_classes if opt.cat_spec_wh else 1,
                         2)).permute((2, 0, 1, 3))
                    # gt.shape: torch.Size([80, 2, 128, 2])

                    if opt.cat_spec_wh:
                        mask = batch['cat_spec_mask'][:, :, 0::2].unsqueeze(
                            -1).permute((2, 0, 1, 3))
                        # mask.shape: torch.Size([80, 2, 128, 1])
                    else:
                        mask = batch['reg_mask'].unsqueeze(2).unsqueeze(0)
                        # print("mask.shape:", mask.shape)
                        # mask.shape: torch.Size([1, 2, 128, 1])

                    V = torch.Tensor([opt.mdn_V]).cuda()

                    I = mask.shape[-2]
                    _gt = gt.reshape((K * BS, I, -1))
                    _mask = mask.reshape((K * BS, I, -1))
                    batch_ind = torch.repeat_interleave(batch['ind'], K, dim=0)
                    _mdn_mu = _tranpose_and_gather_feat(
                        mdn_mu.reshape((K * BS, -1, H, W)), batch_ind)
                    _mdn_sigma = _tranpose_and_gather_feat(
                        mdn_sigma.reshape((K * BS, -1, H, W)), batch_ind)
                    _mdn_pi = _tranpose_and_gather_feat(
                        mdn_pi.reshape((K * BS, -1, H, W)), batch_ind)

                    # mdn_n_comps=3
                    # batch['ind'].shape: torch.Size([2, 128])
                    # gt.shape: torch.Size([1, 2, 128, 2])
                    # mask.shape: torch.Size([1, 2, 128, 1])
                    # mdn_mu.shape:    torch.Size([1, 2, 6, 128, 128])
                    # mdn_pi.shape:    torch.Size([1, 2, 3, 128, 128])
                    # mdn_sigma.shape: torch.Size([1, 2, 6, 128, 128])

                    # batch['ind'].shape: torch.Size([2, 128])
                    # _gt.shape: torch.Size([2, 128, 2])
                    # _mask.shape: torch.Size([2, 128, 1])
                    # _mdn_mu.shape: torch.Size([2, 128, 6])
                    # _mdn_pi.shape: torch.Size([2, 128, 3])
                    # _mdn_sigma.shape: torch.Size([2, 128, 6])

                    tmp = self.crit_wh(_gt,
                                       _mdn_mu,
                                       _mdn_sigma,
                                       _mdn_pi,
                                       _mask,
                                       V,
                                       C=1)
                    wh_loss += tmp[0] / opt.num_stacks
                    batch_wh_loss += tmp[1] / opt.num_stacks

                    for _c in range(opt.num_classes if opt.cat_spec_wh else 1):
                        _mdn_pi = _tranpose_and_gather_feat(
                            mdn_pi[_c], batch['ind'])
                        _mdn_sigma = _tranpose_and_gather_feat(
                            mdn_sigma[_c], batch['ind'])
                        _, _max_pi_ind = torch.max(_mdn_pi, -1)
                        if tb_writer is not None:
                            _cat = opt.cls_id_to_cls_name(_c)
                            tb_writer.add_histogram(
                                'mdn_pi_max_comp/{}'.format(_cat),
                                _max_pi_ind + 1,
                                global_step=global_step)
                            for i in range(_mdn_pi.shape[2]):
                                tb_writer.add_histogram(
                                    'mdn_pi/{}/{}'.format(_cat, i),
                                    _mdn_pi[:, :, i],
                                    global_step=global_step)
                                tb_writer.add_histogram(
                                    'mdn_sigma/{}/{}'.format(_cat, i),
                                    _mdn_sigma[:, :, i * 2:i * 2 + 2],
                                    global_step=global_step)
                else:
                    if opt.dense_wh:
                        mask_weight = batch['dense_wh_mask'].sum() + 1e-4
                        wh_loss += (self.crit_wh(
                            output['wh'] * batch['dense_wh_mask'],
                            batch['dense_wh'] * batch['dense_wh_mask']) /
                                    mask_weight) / opt.num_stacks
                    elif opt.cat_spec_wh:
                        wh_loss += self.crit_wh(
                            output['wh'], batch['cat_spec_mask'], batch['ind'],
                            batch['cat_spec_wh']) / opt.num_stacks
                    else:
                        tmp = self.crit_reg(output['wh'], batch['reg_mask'],
                                            batch['ind'], batch['wh'])
                        wh_loss += tmp[0] / opt.num_stacks
                        batch_wh_loss += tmp[1] / opt.num_stacks
                    '''
          output['wh'].shape: torch.Size([2, 160, 128, 128])
          batch['ind'].shape: torch.Size([2, 128])
          batch['cat_spec_mask'].shape: torch.Size([2, 128, 160])
          '''

            if opt.reg_offset and opt.off_weight > 0:
                tmp = self.crit_reg(output['reg'], batch['reg_mask'],
                                    batch['ind'], batch['reg'])
                off_loss += tmp[0] / opt.num_stacks
                batch_off_loss += tmp[1] / opt.num_stacks

        loss_stats = {}
        loss, batch_loss = 0, 0

        loss += opt.hm_weight * hm_loss + opt.wh_weight * wh_loss + \
              opt.off_weight * off_loss
        batch_loss += opt.hm_weight * batch_hm_loss + opt.wh_weight * batch_wh_loss + \
              opt.off_weight * batch_off_loss

        loss_stats.update({'loss': loss, 'hm_loss': hm_loss, 'wh_loss': wh_loss, 'off_loss': off_loss,
                      'batch_loss': batch_loss, 'batch_hm_loss': batch_hm_loss,\
                      'batch_wh_loss': batch_wh_loss, 'batch_off_loss': batch_off_loss})
        return loss, loss_stats
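
The crit_wh used in the MDN branch is not shown. A generic sketch of a diagonal-Gaussian mixture negative log-likelihood over the gathered quantities (shapes follow the comments above; an illustration, not the author's exact loss):

import torch

def mdn_nll(gt, mu, sigma, pi, mask):
    # gt: (B, I, 2); mu, sigma: (B, I, M*2); pi: (B, I, M); mask: (B, I, 1)
    B, I, M = pi.shape
    comp = torch.distributions.Normal(mu.reshape(B, I, M, 2),
                                      sigma.reshape(B, I, M, 2))
    log_prob = comp.log_prob(gt.unsqueeze(2)).sum(-1)        # (B, I, M)
    log_mix = torch.logsumexp(torch.log(pi) + log_prob, -1)  # (B, I)
    return -(log_mix * mask.squeeze(-1)).sum() / (mask.sum() + 1e-4)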
Example #15
    def update(self, im_blob, img0):
        self.frame_id += 1
        activated_starcks = []
        refind_stracks = []
        lost_stracks = []
        removed_stracks = []

        width = img0.shape[1]
        height = img0.shape[0]
        inp_height = im_blob.shape[2]
        inp_width = im_blob.shape[3]
        c = np.array([width / 2., height / 2.], dtype=np.float32)
        s = max(float(inp_width) / float(inp_height) * height, width) * 1.0
        meta = {
            'c': c,
            's': s,
            'out_height': inp_height // self.opt.down_ratio,
            'out_width': inp_width // self.opt.down_ratio
        }
        ''' Step 1: Network forward, get detections & embeddings'''
        with torch.no_grad():
            output = self.model(im_blob)[-1]
            hm = output['hm'].sigmoid_()
            wh = output['wh']
            id_feature = output['id']
            id_feature = F.normalize(id_feature, dim=1)

            reg = output['reg'] if self.opt.reg_offset else None
            dets, inds = mot_decode(hm,
                                    wh,
                                    reg=reg,
                                    cat_spec_wh=self.opt.cat_spec_wh,
                                    K=self.opt.K)
            id_feature = _tranpose_and_gather_feat(id_feature, inds)
            id_feature = id_feature.squeeze(0)
            id_feature = id_feature.cpu().numpy()

        dets = self.post_process(dets, meta)
        dets = self.merge_outputs([dets])[1]

        remain_inds = dets[:, 4] > self.opt.conf_thres
        dets = dets[remain_inds]
        id_feature = id_feature[remain_inds]

        # vis
        '''
        for i in range(0, dets.shape[0]):
            bbox = dets[i][0:4]
            cv2.rectangle(img0, (bbox[0], bbox[1]),
                          (bbox[2], bbox[3]),
                          (0, 255, 0), 2)
        cv2.imshow('dets', img0)
        cv2.waitKey(0)
        id0 = id0-1
        '''

        if len(dets) > 0:
            '''Detections'''
            detections = [
                STrack(STrack.tlbr_to_tlwh(tlbrs[:4]), tlbrs[4], f, 30)
                for (tlbrs, f) in zip(dets[:, :5], id_feature)
            ]
        else:
            detections = []
        ''' Add newly detected tracklets to tracked_stracks'''
        unconfirmed = []
        tracked_stracks = []  # type: list[STrack]
        for track in self.tracked_stracks:
            if not track.is_activated:
                unconfirmed.append(track)
            else:
                tracked_stracks.append(track)
        ''' Step 2: First association, with embedding'''
        strack_pool = joint_stracks(tracked_stracks, self.lost_stracks)
        dists = matching.embedding_distance(
            strack_pool, detections)  # cosine distance between the new detections and the tracked trackers
        STrack.multi_predict(strack_pool)  # Kalman prediction
        dists = matching.fuse_motion(
            self.kalman_filter, dists, strack_pool,
            detections)  # use the Kalman state to compute the distance cost between detections and the pooled trackers
        matches, u_track, u_detection = matching.linear_assignment(
            dists,
            thresh=0.7)  # Hungarian matching of track boxes to detection boxes; u_track holds the indices of unmatched trackers

        for itracked, idet in matches:  # matches is e.g. 63x2: one row per matched pair, column 0 the tracked-tracker index, column 1 the detection index
            track = strack_pool[itracked]
            det = detections[idet]
            if track.state == TrackState.Tracked:
                track.update(
                    det, self.frame_id)  # matched pool tracker and detection: update the features and the Kalman state
                activated_starcks.append(track)
            else:
                track.re_activate(det, self.frame_id,
                                  new_id=False)  # if it was in the lost set, re-activate it
                refind_stracks.append(track)
        ''' Step 3: Second association, with IOU''' """ re-match the detections and trackers left unmatched by the cosine distance, this time using IoU """
        detections = [detections[i]
                      for i in u_detection]  # u_detection holds the indices of the unmatched detections
        r_tracked_stracks = [
            strack_pool[i] for i in u_track
            if strack_pool[i].state == TrackState.Tracked
        ]
        dists = matching.iou_distance(r_tracked_stracks, detections)
        matches, u_track, u_detection = matching.linear_assignment(dists,
                                                                   thresh=0.5)

        for itracked, idet in matches:
            track = r_tracked_stracks[itracked]
            det = detections[idet]
            if track.state == TrackState.Tracked:
                track.update(det, self.frame_id)
                activated_starcks.append(track)
            else:
                track.re_activate(
                    det, self.frame_id,
                    new_id=False)  # unreachable in practice: the list above was already restricted to TrackState.Tracked
                refind_stracks.append(track)

        for it in u_track:
            track = r_tracked_stracks[it]
            if not track.state == TrackState.Lost:
                track.mark_lost()
                lost_stracks.append(
                    track)  # mark trackers with no IoU match against the tracked trackers as lost

        '''Deal with unconfirmed tracks, usually tracks with only one beginning frame'''
        detections = [detections[i] for i in u_detection
                      ]  # match the detections unmatched by cosine/IoU against the unconfirmed trackers
        dists = matching.iou_distance(unconfirmed, detections)
        matches, u_unconfirmed, u_detection = matching.linear_assignment(
            dists, thresh=0.7)
        for itracked, idet in matches:
            unconfirmed[itracked].update(detections[idet], self.frame_id)
            activated_starcks.append(unconfirmed[itracked])
        for it in u_unconfirmed:
            track = unconfirmed[it]
            track.mark_removed()
            removed_stracks.append(track)
        """ Step 4: Init new stracks"""
        for inew in u_detection:  # detections unmatched by cosine, IoU and the unconfirmed trackers are re-initialized as new unconfirmed trackers
            track = detections[inew]
            if track.score < self.det_thresh:
                continue
            track.activate(self.kalman_filter,
                           self.frame_id)  # activate the track; activated=True only on the first frame, False otherwise
            activated_starcks.append(track)
        """ Step 5: Update state"""
        for track in self.lost_stracks:
            if self.frame_id - track.end_frame > self.max_time_lost:  # after being gone for 15 frames
                track.mark_removed()
                removed_stracks.append(track)

        # print('Ramained match {} s'.format(t4-t3))

        self.tracked_stracks = [
            t for t in self.tracked_stracks if t.state == TrackState.Tracked
        ]  # keep only the trackers in the tracked state
        self.tracked_stracks = joint_stracks(
            self.tracked_stracks,
            activated_starcks)  # add the newly activated detections to self.tracked_stracks
        self.tracked_stracks = joint_stracks(self.tracked_stracks,
                                             refind_stracks)  # and the re-matched trackers
        self.lost_stracks = sub_stracks(self.lost_stracks,
                                        self.tracked_stracks)
        self.lost_stracks.extend(lost_stracks)
        self.lost_stracks = sub_stracks(self.lost_stracks,
                                        self.removed_stracks)
        self.removed_stracks.extend(removed_stracks)
        self.tracked_stracks, self.lost_stracks = remove_duplicate_stracks(
            self.tracked_stracks, self.lost_stracks)
        # get scores of lost tracks
        output_stracks = [
            track for track in self.tracked_stracks if track.is_activated
        ]

        logger.debug('===========Frame {}=========='.format(self.frame_id))
        logger.debug('Activated: {}'.format(
            [track.track_id for track in activated_starcks]))
        logger.debug('Refind: {}'.format(
            [track.track_id for track in refind_stracks]))
        logger.debug('Lost: {}'.format(
            [track.track_id for track in lost_stracks]))
        logger.debug('Removed: {}'.format(
            [track.track_id for track in removed_stracks]))

        return output_stracks
Example #16
    def forward(self, outputs, batch):
        opt = self.opt
        hm_loss, wh_loss, off_loss, id_loss = 0, 0, 0, 0
        for s in range(opt.num_stacks):
            output = outputs[s]
            if not opt.mse_loss:
                output['hm'] = _sigmoid(output['hm'])

            hm_loss += self.crit(output['hm'], batch['hm']) / opt.num_stacks
            if opt.wh_weight > 0:
                if opt.dense_wh:
                    mask_weight = batch['dense_wh_mask'].sum() + 1e-4
                    wh_loss += (
                                   self.crit_wh(output['wh'] * batch['dense_wh_mask'],
                                                batch['dense_wh'] * batch['dense_wh_mask']) /
                                   mask_weight) / opt.num_stacks
                else:
                    wh_loss += self.crit_reg(
                        output['wh'], batch['reg_mask'],
                        batch['ind'], batch['wh']) / opt.num_stacks

            if opt.reg_offset and opt.off_weight > 0:
                off_loss += self.crit_reg(output['reg'], batch['reg_mask'],
                                          batch['ind'], batch['reg']) / opt.num_stacks

            if opt.id_weight > 0:
                '''
                ids_label_9 = batch['ids'].repeat(1, 9)     # batch, 9K
                mask_label_9 = batch['reg_mask'].repeat(1, 9)     # batch, 9K
                ind_labels_9 = batch['ind'].repeat(1, 9)     # batch, 9K
                K = opt.K
                hm_w = output['hm'].shape[3]
                hm_h = output['hm'].shape[2]
                ind_labels_9[:, K:2 * K] -= 5
                ind_labels_9[:, 2 * K:3 * K] += 5
                ind_labels_9[:, 3 * K:4 * K] -= 5 * hm_w
                ind_labels_9[:, 4 * K:5 * K] -= 5 * (hm_w - 1)
                ind_labels_9[:, 5 * K:6 * K] -= 5 * (hm_w + 1)
                ind_labels_9[:, 6 * K:7 * K] += 5 * hm_w
                ind_labels_9[:, 7 * K:8 * K] += 5 * (hm_w - 1)
                ind_labels_9[:, 8 * K:9 * K] += 5 * (hm_w + 1)
                ind_labels_9 = torch.clamp(ind_labels_9, 0, hm_w * hm_h - 1)
                id_head = _tranpose_and_gather_feat(output['id'], ind_labels_9)
                id_head = id_head[mask_label_9 > 0].contiguous()
                id_head = self.emb_scale * F.normalize(id_head)
                id_target = ids_label_9[mask_label_9 > 0]
                '''
                id_head = _tranpose_and_gather_feat(output['id'], batch['ind'])
                id_head = id_head[batch['reg_mask'] > 0].contiguous()
                id_head = self.emb_scale * F.normalize(id_head)
                id_target = batch['ids'][batch['reg_mask'] > 0]

                id_output = self.classifier(id_head).contiguous()
                id_loss += self.IDLoss(id_output, id_target)

        #loss = opt.hm_weight * hm_loss + opt.wh_weight * wh_loss + opt.off_weight * off_loss + opt.id_weight * id_loss

        det_loss = opt.hm_weight * hm_loss + opt.wh_weight * wh_loss + opt.off_weight * off_loss

        loss = torch.exp(-self.s_det) * det_loss + torch.exp(-self.s_id) * id_loss + (self.s_det + self.s_id)
        loss *= 0.5
        #loss = det_loss

        #print(loss, hm_loss, wh_loss, off_loss, id_loss)

        loss_stats = {'loss': loss, 'hm_loss': hm_loss,
                      'wh_loss': wh_loss, 'off_loss': off_loss, 'id_loss': id_loss}
        return loss, loss_stats
Example #17
        'c': c,
        's': s,
        'out_height': inp_height // down_ratio,
        'out_width': inp_width // down_ratio
    }
    ''' Step 1: Network forward, get detections & embeddings'''
    with torch.no_grad():
        output = model(im_blob)[-1]
        hm = output['hm'].sigmoid_()
        wh = output['wh']
        id_feature = output['id']
        id_feature = F.normalize(id_feature, dim=1)

        reg = output['reg'] if reg_offset else None
        dets, inds = mot_decode(hm, wh, reg=reg, ltrb=ltrb, K=Kt)
        id_feature = _tranpose_and_gather_feat(id_feature, inds)
        id_feature = id_feature.squeeze(0)
        id_feature = id_feature.cpu().numpy()

    dets = post_process(dets, meta)
    dets = merge_outputs([dets])[1]
    remain_inds = dets[:, 4] > conf_thres
    dets = dets[remain_inds]
    id_feature = id_feature[remain_inds]

    # vis
    person_count += len(dets)
    for i in range(0, dets.shape[0]):
        bbox = dets[i][0:4]
        cv2.rectangle(img0, (int(bbox[0]), int(bbox[1])),
                      (int(bbox[2]), int(bbox[3])), (0, 255, 0), 5)
Example #18
 def forward(self, output, mask, ind, target):
     pred = _tranpose_and_gather_feat(output, ind)
     loss = _reg_loss(pred, target, mask)
     return loss
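
`_reg_loss` here is presumably CenterNet's masked smooth-L1, normalized by the number of valid objects; a sketch:

import torch.nn.functional as F

def _reg_loss(regr, gt_regr, mask):
    # regr, gt_regr: (B, K, D); mask: (B, K) with 1 at valid objects
    num = mask.float().sum()
    mask = mask.unsqueeze(2).expand_as(gt_regr).float()
    loss = F.smooth_l1_loss(regr * mask, gt_regr * mask, reduction='sum')
    return loss / (num + 1e-4)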
Example #19
    def forward(self, outputs, batch):
        opt = self.opt
        hm_loss, wh_loss, off_loss, id_loss = 0, 0, 0, 0
        for s in range(opt.num_stacks):
            output = outputs[s]
            if not opt.mse_loss:
                output["hm"] = _sigmoid(output["hm"])

            hm_loss += self.crit(output["hm"], batch["hm"]) / opt.num_stacks
            if opt.wh_weight > 0:
                if opt.dense_wh:
                    mask_weight = batch["dense_wh_mask"].sum() + 1e-4
                    wh_loss += (
                        self.crit_wh(
                            output["wh"] * batch["dense_wh_mask"],
                            batch["dense_wh"] * batch["dense_wh_mask"],
                        )
                        / mask_weight
                    ) / opt.num_stacks
                else:
                    wh_loss += (
                        self.crit_reg(
                            output["wh"], batch["reg_mask"], batch["ind"], batch["wh"]
                        )
                        / opt.num_stacks
                    )

            if opt.reg_offset and opt.off_weight > 0:
                off_loss += (
                    self.crit_reg(
                        output["reg"], batch["reg_mask"], batch["ind"], batch["reg"]
                    )
                    / opt.num_stacks
                )

            if opt.id_weight > 0:
                id_head = _tranpose_and_gather_feat(output["id"], batch["ind"])
                id_head = id_head[batch["reg_mask"] > 0].contiguous()
                id_head = self.emb_scale * F.normalize(id_head)
                id_target = batch["ids"][batch["reg_mask"] > 0]
                id_output = self.classifier(id_head).contiguous()
                id_loss += self.IDLoss(id_output, id_target)
                # id_loss += self.IDLoss(id_output, id_target) + self.TriLoss(id_head, id_target)

        # loss = opt.hm_weight * hm_loss + opt.wh_weight * wh_loss + opt.off_weight * off_loss + opt.id_weight * id_loss

        det_loss = (
            opt.hm_weight * hm_loss
            + opt.wh_weight * wh_loss
            + opt.off_weight * off_loss
        )

        loss = (
            torch.exp(-self.s_det) * det_loss
            + torch.exp(-self.s_id) * id_loss
            + (self.s_det + self.s_id)
        )
        loss *= 0.5

        loss_stats = {
            "loss": loss,
            "hm_loss": hm_loss,
            "wh_loss": wh_loss,
            "off_loss": off_loss,
            "id_loss": id_loss,
        }
        return loss, loss_stats
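
The s_det / s_id terms above implement homoscedastic uncertainty weighting (Kendall et al., 2018): each task loss is scaled by a learned exp(-s) factor and s itself is added as a regularizer, letting the optimizer balance detection against re-identification. A toy sketch of how the gradients move those scalars (the numbers are illustrative only, not from training):

import torch

s_det = torch.tensor(-1.85, requires_grad=True)  # illustrative initial values
s_id = torch.tensor(-1.05, requires_grad=True)
det_loss, id_loss = torch.tensor(2.0), torch.tensor(5.0)  # made-up task losses

loss = 0.5 * (torch.exp(-s_det) * det_loss
              + torch.exp(-s_id) * id_loss
              + s_det + s_id)
loss.backward()
# d(loss)/d(s) = 0.5 * (1 - exp(-s) * task_loss): s grows while its task loss
# dominates, shrinking exp(-s) and hence that task's effective weight.
print(s_det.grad, s_id.grad)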
Example #20
0
 def forward(self, output, mask, ind, rotbin, rotres):
     pred = _tranpose_and_gather_feat(output, ind)
     loss = compute_rot_loss(pred, rotbin, rotres, mask)
     return loss
Example #21
0
    def update(self, im_blob, img0):
        self.frame_id += 1
        activated_starcks = []
        refind_stracks = []
        lost_stracks = []
        removed_stracks = []

        width = img0.shape[1]
        height = img0.shape[0]
        inp_height = im_blob.shape[2]
        inp_width = im_blob.shape[3]
        c = np.array([width / 2., height / 2.], dtype=np.float32)
        s = max(float(inp_width) / float(inp_height) * height, width) * 1.0
        meta = {'c': c, 's': s,
                'out_height': inp_height // self.opt.down_ratio,
                'out_width': inp_width // self.opt.down_ratio}

        ''' Step 1: Network forward, get detections & embeddings'''
        with torch.no_grad():
            output = self.model(im_blob)[-1]
            hm = output['hm'].sigmoid_()
            wh = output['wh']
            id_feature = output['id']
            id_feature = F.normalize(id_feature, dim=1)

            reg = output['reg'] if self.opt.reg_offset else None
            dets, inds = mot_decode(hm, wh, reg=reg, ltrb=self.opt.ltrb, K=self.opt.K)
            id_feature = _tranpose_and_gather_feat(id_feature, inds)
            id_feature = id_feature.squeeze(0)
            id_feature = id_feature.cpu().numpy()

        dets = self.post_process(dets, meta)
        dets = self.merge_outputs([dets])[1]

        remain_inds = dets[:, 4] > self.opt.conf_thres
        dets = dets[remain_inds]
        id_feature = id_feature[remain_inds]

        # vis
        '''
        for i in range(0, dets.shape[0]):
            bbox = dets[i][0:4]
            cv2.rectangle(img0, (bbox[0], bbox[1]),
                          (bbox[2], bbox[3]),
                          (0, 255, 0), 2)
        cv2.imshow('dets', img0)
        cv2.waitKey(0)
        id0 = id0-1
        '''

        if len(dets) > 0:
            '''Detections'''
            detections = [STrack(STrack.tlbr_to_tlwh(tlbrs[:4]), tlbrs[4], f, 30) for
                          (tlbrs, f) in zip(dets[:, :5], id_feature)]
        else:
            detections = []

        ''' Add newly detected tracklets to tracked_stracks'''
        unconfirmed = []
        tracked_stracks = []  # type: list[STrack]
        for track in self.tracked_stracks:
            if not track.is_activated:
                unconfirmed.append(track)
            else:
                tracked_stracks.append(track)

        ''' Step 2: First association, with embedding'''
        # Merge currently tracked and previously lost tracks into one pool
        strack_pool = joint_stracks(tracked_stracks, self.lost_stracks)
        # Predict the current location of every track in the pool with the Kalman filter
        STrack.multi_predict(strack_pool)

        # Cost matrix between tracks and detections, from embedding distances
        dists = matching.embedding_distance(strack_pool, detections)
        #dists = matching.iou_distance(strack_pool, detections)

        # Gate the cost matrix with the Kalman filter: assignments too far from
        # the motion prediction are given infinite cost
        dists = matching.fuse_motion(self.kalman_filter, dists, strack_pool, detections)
        # Solve the assignment; u_track and u_detection hold the indices of
        # unmatched tracks and detections respectively
        matches, u_track, u_detection = matching.linear_assignment(dists, thresh=0.4)
        # Update the matched tracks
        for itracked, idet in matches:
            track = strack_pool[itracked]
            det = detections[idet]
            if track.state == TrackState.Tracked:
                track.update(detections[idet], self.frame_id)
                activated_starcks.append(track)
            else:
                track.re_activate(det, self.frame_id, new_id=False)
                refind_stracks.append(track)

        ''' Step 3: Second association, with IOU'''
        detections = [detections[i] for i in u_detection]
        # Keep the unmatched tracks that are still in the Tracked state
        r_tracked_stracks = [strack_pool[i] for i in u_track if strack_pool[i].state == TrackState.Tracked]
        # IoU-based cost matrix for the second association
        dists = matching.iou_distance(r_tracked_stracks, detections)
        matches, u_track, u_detection = matching.linear_assignment(dists, thresh=0.6) #Default 0.5
        for itracked, idet in matches:
            track = r_tracked_stracks[itracked]
            det = detections[idet]
            if track.state == TrackState.Tracked:
                track.update(det, self.frame_id)
                activated_starcks.append(track)
            else:
                track.re_activate(det, self.frame_id, new_id=False)
                refind_stracks.append(track)
        #For all of the unmatched tracks, mark them as lost
        for it in u_track:
            track = r_tracked_stracks[it]
            if not track.state == TrackState.Lost:
                track.mark_lost()
                lost_stracks.append(track)

        '''Deal with unconfirmed tracks, usually tracks with only one beginning frame'''
        #For the unconfirmed tracks, tracks with only one beginning frame, use the remaining detection to try to pair them
        detections = [detections[i] for i in u_detection]
        dists = matching.iou_distance(unconfirmed, detections)
        matches, u_unconfirmed, u_detection = matching.linear_assignment(dists, thresh=0.7)
        #Add the matched ones
        for itracked, idet in matches:
            unconfirmed[itracked].update(detections[idet], self.frame_id)
            activated_starcks.append(unconfirmed[itracked])
        #For the ones that couldn't be matched, remove them
        for it in u_unconfirmed:
            track = unconfirmed[it]
            track.mark_removed()
            removed_stracks.append(track)

        """ Step 4: Init new stracks"""
        for inew in u_detection:
            track = detections[inew]
            if track.score < self.det_thresh:
                continue
            track.activate(self.kalman_filter, self.frame_id)
            activated_starcks.append(track)
        """ Step 5: Update state"""
        for track in self.lost_stracks:
            # If a track has been lost for more than max_time_lost frames, remove it
            if self.frame_id - track.end_frame > self.max_time_lost:
                track.mark_removed()
                removed_stracks.append(track)

        self.tracked_stracks = [t for t in self.tracked_stracks if t.state == TrackState.Tracked]
        self.tracked_stracks = joint_stracks(self.tracked_stracks, activated_starcks)
        self.tracked_stracks = joint_stracks(self.tracked_stracks, refind_stracks)
        self.lost_stracks = sub_stracks(self.lost_stracks, self.tracked_stracks)
        self.lost_stracks.extend(lost_stracks)
        self.lost_stracks = sub_stracks(self.lost_stracks, self.removed_stracks)
        self.removed_stracks.extend(removed_stracks)
        self.tracked_stracks, self.lost_stracks = remove_duplicate_stracks(self.tracked_stracks, self.lost_stracks)
        # get scores of lost tracks
        output_stracks = [track for track in self.tracked_stracks if track.is_activated]

        logger.debug('===========Frame {}=========='.format(self.frame_id))
        logger.debug('Activated: {}'.format([track.track_id for track in activated_starcks]))
        logger.debug('Refind: {}'.format([track.track_id for track in refind_stracks]))
        logger.debug('Lost: {}'.format([track.track_id for track in lost_stracks]))
        logger.debug('Removed: {}'.format([track.track_id for track in removed_stracks]))

        return output_stracks
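
matching.linear_assignment(dists, thresh) is used above as a thresholded Hungarian assignment: pairs costing more than thresh can never match, and the unmatched row and column indices come back alongside the matches. A sketch of that behavior built on SciPy (the real helper may differ; FairMOT-style trackers typically build it on lap.lapjv):

import numpy as np
from scipy.optimize import linear_sum_assignment

def linear_assignment(cost_matrix, thresh):
    """Hungarian matching where any pair costing more than thresh is forbidden."""
    if cost_matrix.size == 0:
        return (np.empty((0, 2), dtype=int),
                list(range(cost_matrix.shape[0])),
                list(range(cost_matrix.shape[1])))
    cost = np.where(cost_matrix > thresh, thresh + 1e-4, cost_matrix)
    rows, cols = linear_sum_assignment(cost)
    matches = [(r, c) for r, c in zip(rows, cols) if cost_matrix[r, c] <= thresh]
    matched_r = {r for r, _ in matches}
    matched_c = {c for _, c in matches}
    u_track = [r for r in range(cost_matrix.shape[0]) if r not in matched_r]
    u_detection = [c for c in range(cost_matrix.shape[1]) if c not in matched_c]
    return np.asarray(matches, dtype=int).reshape(-1, 2), u_track, u_detection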
Example #22
0
  def forward(self, outputs, batch, global_step, tb_writer):
    opt = self.opt
    hm_loss, wh_loss, off_loss = 0, 0, 0
    hp_loss, hm_hp_loss, hp_offset_loss = 0, 0, 0

    loss_stats = {}
    for s in range(opt.num_stacks):
      output = outputs[s]
      output['hm'] = _sigmoid(output['hm'])
      if opt.hm_hp:
        output['hm_hp'] = _sigmoid(output['hm_hp'])
      
      if opt.eval_oracle_hmhp:
        output['hm_hp'] = batch['hm_hp']
      if opt.eval_oracle_hm:
        output['hm'] = batch['hm']
      if opt.eval_oracle_kps:
        if opt.dense_hp:
          output['hps'] = batch['dense_hps']
        else:
          output['hps'] = torch.from_numpy(gen_oracle_map(
            batch['hps'].detach().cpu().numpy(), 
            batch['ind'].detach().cpu().numpy(), 
            opt.output_res, opt.output_res)).to(opt.device)
      if opt.eval_oracle_hp_offset:
        output['hp_offset'] = torch.from_numpy(gen_oracle_map(
          batch['hp_offset'].detach().cpu().numpy(), 
          batch['hp_ind'].detach().cpu().numpy(), 
          opt.output_res, opt.output_res)).to(opt.device)

      if opt.mdn:
        # Per-keypoint sigmas (the COCO OKS constants), scaled by 1/10
        V = torch.Tensor((np.array(
            [.26, .25, .25, .35, .35, .79, .79, .72, .72, .62,
             .62, 1.07, 1.07, .87, .87, .89, .89]) / 10.0).astype(np.float32)).cuda()
      hm_loss += self.crit(output['hm'], batch['hm'])[0] / opt.num_stacks

      if opt.mdn:
        mdn_logits = output['mdn_logits']
        # mdn_logits.shape: torch.Size([2, 3, 128, 128])
        if opt.mdn_dropout > 0 and opt.epoch < opt.mdn_dropout_stop:
          # Component dropout: zero the logits of a random subset of mixture components
          M = opt.mdn_n_comps
          n_drop = torch.randint(1, 1 + opt.mdn_dropout, (1,)).item()
          ridx = torch.randperm(M)[:n_drop]
          drop_mask = torch.ones((M,))
          drop_mask[ridx] = 0
          drop_mask = torch.reshape(drop_mask, (1, -1, 1, 1)).float().cuda()
          mdn_logits = mdn_logits * drop_mask
          if tb_writer is not None:
            tb_writer.add_histogram('drop_out_idx', ridx + 1, global_step=global_step)
        # Mixture weights; computed outside the branch above so that mdn_pi is
        # defined even when component dropout is active
        mdn_pi = torch.clamp(torch.nn.Softmax(dim=1)(mdn_logits), 1e-4, 1. - 1e-4)
        
        mdn_sigma = torch.clamp(torch.nn.ELU()(output['mdn_sigma']) + opt.mdn_min_sigma, 1e-4, 1e5)
        mdn_mu = output['hps']

        if tb_writer is not None:
          for i in range(mdn_pi.shape[1]):
            tb_writer.add_histogram('mdn_pi/{}'.format(i), mdn_pi[:, i], global_step=global_step)
            tb_writer.add_histogram('mdn_sigma/{}'.format(i), mdn_sigma[:, i * 2:i * 2 + 2], global_step=global_step)

        if opt.dense_hp:
          gt = batch['dense_hps']
          mask = batch['dense_hps_mask'][:, 0::2, :, :]
          _, max_pi_ind = torch.max(mdn_pi, 1)
        else:
          gt = batch['hps']
          mask = batch['hps_mask'][:, :, 0::2]
          mdn_mu = _tranpose_and_gather_feat(mdn_mu, batch['ind'])
          mdn_pi = _tranpose_and_gather_feat(mdn_pi, batch['ind'])
          mdn_sigma = _tranpose_and_gather_feat(mdn_sigma, batch['ind'])
          _, max_pi_ind = torch.max(mdn_pi, -1)

        if tb_writer is not None:
          tb_writer.add_histogram('mdn_pi_max_comp', max_pi_ind + 1, global_step=global_step)
        '''
          mdn_n_comps=3
          batch['hps'].shape: torch.Size([2, 32, 34])
          batch['hps_mask'].shape: torch.Size([2, 32, 34])
          batch['ind'].shape: torch.Size([2, 32])
          gt.shape: torch.Size([2, 32, 34])
          mask.shape: torch.Size([2, 32, 17])
          before gather, after gather
          mdn_mu.shape: torch.Size([2, 102, 128, 128]), torch.Size([2, 32, 102])
          mdn_pi.shape: torch.Size([2, 3, 128, 128]), torch.Size([2, 32, 3])
          mdn_sigma.shape: torch.Size([2, 6, 128, 128]), torch.Size([2, 32, 6])
        '''
        if opt.mdn_inter:
          hp_loss += self.crit_kp(gt, mdn_mu, mdn_sigma, mdn_pi, mask, V, debug=opt.debug == 6)[0] / opt.num_stacks
        else:
          hp_loss = self.crit_kp(gt, mdn_mu, mdn_sigma, mdn_pi, mask, V, debug=opt.debug == 6)[0] / opt.num_stacks
      else:
        if opt.dense_hp:
          mask_weight = batch['dense_hps_mask'].sum() + 1e-4
          hp_loss += (self.crit_kp(output['hps'] * batch['dense_hps_mask'], 
                                  batch['dense_hps'] * batch['dense_hps_mask']) / 
                                  mask_weight) / opt.num_stacks
        else:
          hp_loss += self.crit_kp(output['hps'], batch['hps_mask'], 
                                batch['ind'], batch['hps']) / opt.num_stacks
      
      if opt.wh_weight > 0:
        wh_loss += self.crit_reg(output['wh'], batch['reg_mask'],
                                 batch['ind'], batch['wh'])[0] / opt.num_stacks
      if opt.reg_offset and opt.off_weight > 0:
        off_loss += self.crit_reg(output['reg'], batch['reg_mask'],
                                  batch['ind'], batch['reg'])[0] / opt.num_stacks
      if opt.reg_hp_offset and opt.off_weight > 0:
        hp_offset_loss += self.crit_reg(
          output['hp_offset'], batch['hp_mask'],
          batch['hp_ind'], batch['hp_offset'])[0] / opt.num_stacks
      if opt.hm_hp and opt.hm_hp_weight > 0:
        hm_hp_loss += self.crit_hm_hp(
          output['hm_hp'], batch['hm_hp'])[0] / opt.num_stacks

    loss = opt.hm_weight * hm_loss + opt.wh_weight * wh_loss + \
           opt.off_weight * off_loss + opt.hp_weight * hp_loss + \
           opt.hm_hp_weight * hm_hp_loss + opt.off_weight * hp_offset_loss
    
    loss_stats.update({'loss': loss, 'hm_loss': hm_loss, 'hp_loss': hp_loss, 
                  'hm_hp_loss': hm_hp_loss, 'hp_offset_loss': hp_offset_loss,
                  'wh_loss': wh_loss, 'off_loss': off_loss})
    return loss, loss_stats
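
self.crit_kp in the MDN branch is not shown. What it has to compute is the negative log-likelihood of the ground-truth keypoints under a mixture of Gaussians with weights mdn_pi, means mdn_mu and scales mdn_sigma. A compact sketch of that NLL for already-gathered tensors, ignoring the per-joint V weighting and the masking details of the real loss (the shapes below are assumptions):

import torch

def mdn_nll(gt, mu, sigma, pi, eps=1e-8):
    """NLL of gt under a mixture of diagonal Gaussians.

    gt:    (B, K, D)     ground-truth vectors
    mu:    (B, K, M, D)  per-component means
    sigma: (B, K, M, D)  per-component stddevs (positive)
    pi:    (B, K, M)     mixture weights (sum to 1 over M)
    """
    dist = torch.distributions.Normal(mu, sigma)
    # independent dimensions: sum log-densities over D
    log_prob = dist.log_prob(gt.unsqueeze(2)).sum(dim=-1)              # (B, K, M)
    log_mix = torch.logsumexp(torch.log(pi + eps) + log_prob, dim=-1)  # (B, K)
    return -log_mix.mean()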
Example #23
0
def test_emb(
        opt,
        batch_size=16,
        img_size=(1088, 608),
        print_interval=40,
):
    data_cfg = opt.data_cfg
    with open(data_cfg) as f:
        data_cfg_dict = json.load(f)
    nC = 1
    test_paths = data_cfg_dict['test_emb']
    dataset_root = data_cfg_dict['root']
    if opt.gpus[0] >= 0:
        opt.device = torch.device('cuda')
    else:
        opt.device = torch.device('cpu')
    print('Creating model...')
    model = create_model(opt.arch, opt.heads, opt.head_conv)
    model = load_model(model, opt.load_model)
    # model = torch.nn.DataParallel(model)
    model = model.to(opt.device)
    model.eval()

    # Get dataloader
    transforms = T.Compose([T.ToTensor()])
    dataset = JointDataset(opt, dataset_root, test_paths, img_size, augment=False, transforms=transforms)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False,
                                             num_workers=8, drop_last=False)
    embedding, id_labels = [], []
    print('Extracting pedestrian features...')
    for batch_i, batch in enumerate(dataloader):
        t = time.time()
        output = model(batch['input'].cuda())[-1]
        id_head = _tranpose_and_gather_feat(output['id'], batch['ind'].cuda())
        id_head = id_head[batch['reg_mask'].cuda() > 0].contiguous()
        emb_scale = math.sqrt(2) * math.log(opt.nID - 1)
        id_head = emb_scale * F.normalize(id_head)
        id_target = batch['ids'].cuda()[batch['reg_mask'].cuda() > 0]

        for i in range(id_head.shape[0]):
            feat, label = id_head[i], id_target[i].long()
            if label != -1:
                embedding.append(feat)
                id_labels.append(label)

        if batch_i % print_interval == 0:
            print(
                'Extracting {}/{}, # of instances {}, time {:.2f} sec.'.format(batch_i, len(dataloader), len(id_labels),
                                                                               time.time() - t))

    print('Computing pairwise similarity...')
    if len(embedding) < 1:
        return None

    embedding = torch.stack(embedding, dim=0).cuda()
    id_labels = torch.LongTensor(id_labels)
    n = len(id_labels)
    print(n, len(embedding))
    assert len(embedding) == n

    embedding = F.normalize(embedding, dim=1)
    pdist = torch.mm(embedding, embedding.t()).cpu().numpy()
    gt = id_labels.expand(n, n).eq(id_labels.expand(n, n).t()).numpy()

    up_triangle = np.where(np.triu(pdist) - np.eye(n) * pdist != 0)
    pdist = pdist[up_triangle]
    gt = gt[up_triangle]

    far_levels = [1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1]
    far, tar, threshold = metrics.roc_curve(gt, pdist)
    interp = interpolate.interp1d(far, tar)
    tar_at_far = [interp(x) for x in far_levels]
    for f, fa in enumerate(far_levels):
        print('TPR@FAR={:.7f}: {:.4f}'.format(fa, tar_at_far[f]))
    return tar_at_far
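
One caveat in the evaluation above: interpolate.interp1d(far, tar) raises whenever a requested FAR level falls outside the range of measured FARs, which is likely for 1e-6 with few test pairs. A tolerant variant under that assumption (far sorted ascending, as sklearn's roc_curve returns it):

import numpy as np
from scipy import interpolate

def tpr_at_far(far, tar, far_levels):
    """TPR at the requested FAR levels, clamping levels outside the measured range."""
    interp = interpolate.interp1d(np.asarray(far), np.asarray(tar),
                                  bounds_error=False,
                                  fill_value=(tar[0], tar[-1]))
    return [float(interp(x)) for x in far_levels]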
Example #24
0
    def process(self, images, kernel=1, ae_threshold=1, K=100, num_dets=100):
        with torch.no_grad():
            output = self.model(images)[-1]

            tl_heat = output['tl'].sigmoid_()
            bl_heat = output['bl'].sigmoid_()
            br_heat = output['br'].sigmoid_()
            ct_heat = output['ct'].sigmoid_()

            tl_tag = output['tl_tag']
            bl_tag = output['bl_tag']
            br_tag = output['br_tag']

            tl_reg = output['tl_reg']
            bl_reg = output['bl_reg']
            br_reg = output['br_reg']
            ct_reg = output['ct_reg']

            batch, cat, height, width = tl_heat.size()

            tl_heat = _nms(tl_heat, kernel=3)
            bl_heat = _nms(bl_heat, kernel=3)
            br_heat = _nms(br_heat, kernel=3)
            ct_heat = _nms(ct_heat, kernel=3)

            tl_scores, tl_inds, tl_clses, tl_ys, tl_xs = _topk(tl_heat, K=K)
            bl_scores, bl_inds, bl_clses, bl_ys, bl_xs = _topk(bl_heat, K=K)
            br_scores, br_inds, br_clses, br_ys, br_xs = _topk(br_heat, K=K)
            ct_scores, ct_inds, ct_clses, ct_ys, ct_xs = _topk(ct_heat, K=K)

            tl_ys = tl_ys.view(batch, K, 1, 1).expand(batch, K, K, K)
            tl_xs = tl_xs.view(batch, K, 1, 1).expand(batch, K, K, K)
            bl_ys = bl_ys.view(batch, 1, K, 1).expand(batch, K, K, K)
            bl_xs = bl_xs.view(batch, 1, K, 1).expand(batch, K, K, K)
            br_ys = br_ys.view(batch, 1, 1, K).expand(batch, K, K, K)
            br_xs = br_xs.view(batch, 1, 1, K).expand(batch, K, K, K)
            ct_ys = ct_ys.view(batch, 1, K).expand(batch, K, K)
            ct_xs = ct_xs.view(batch, 1, K).expand(batch, K, K)

            if tl_reg is not None and bl_reg is not None and br_reg is not None:
                tl_reg = _tranpose_and_gather_feat(tl_reg, tl_inds)
                tl_reg = tl_reg.view(batch, K, 1, 1, 2)
                bl_reg = _tranpose_and_gather_feat(bl_reg, bl_inds)
                bl_reg = bl_reg.view(batch, 1, K, 1, 2)
                br_reg = _tranpose_and_gather_feat(br_reg, br_inds)
                br_reg = br_reg.view(batch, 1, 1, K, 2)
                ct_reg = _tranpose_and_gather_feat(ct_reg, ct_inds)
                ct_reg = ct_reg.view(batch, 1, K, 2)

                tl_xs = tl_xs + tl_reg[..., 0]
                tl_ys = tl_ys + tl_reg[..., 1]
                bl_xs = bl_xs + bl_reg[..., 0]
                bl_ys = bl_ys + bl_reg[..., 1]
                br_xs = br_xs + br_reg[..., 0]
                br_ys = br_ys + br_reg[..., 1]
                ct_xs = ct_xs + ct_reg[..., 0]
                ct_ys = ct_ys + ct_reg[..., 1]

            # all possible boxes based on top k corners (ignoring class)
            bboxes = torch.stack((tl_xs, tl_ys, bl_xs, bl_ys, br_xs, br_ys),
                                 dim=4)

            tl_tag = _tranpose_and_gather_feat(tl_tag, tl_inds)
            tl_tag = tl_tag.view(batch, K, 1, 1)
            bl_tag = _tranpose_and_gather_feat(bl_tag, bl_inds)
            bl_tag = bl_tag.view(batch, 1, K, 1)
            br_tag = _tranpose_and_gather_feat(br_tag, br_inds)
            br_tag = br_tag.view(batch, 1, 1, K)
            avg_tag = (tl_tag + bl_tag + br_tag) / 3
            dists = (torch.abs(tl_tag - avg_tag) + torch.abs(bl_tag - avg_tag)
                     + torch.abs(br_tag - avg_tag)) / 3

            tl_scores = tl_scores.view(batch, K, 1, 1).expand(batch, K, K, K)
            bl_scores = bl_scores.view(batch, 1, K, 1).expand(batch, K, K, K)
            br_scores = br_scores.view(batch, 1, 1, K).expand(batch, K, K, K)
            # reject boxes based on corner scores
            # sc_inds = (tl_scores < scores_thresh) | (bl_scores < scores_thresh) | (br_scores < scores_thresh)
            scores = (tl_scores + bl_scores + br_scores) / 3

            # reject boxes based on classes
            tl_clses = tl_clses.view(batch, K, 1, 1).expand(batch, K, K, K)
            bl_clses = bl_clses.view(batch, 1, K, 1).expand(batch, K, K, K)
            br_clses = br_clses.view(batch, 1, 1, K).expand(batch, K, K, K)
            cls_inds = (tl_clses != bl_clses) | (bl_clses != br_clses) | (
                tl_clses != br_clses)

            # reject boxes based on distances
            dist_inds = (dists > ae_threshold)

            scores[cls_inds] = -1
            scores[dist_inds] = -1
            # scores[sc_inds] = -1

            scores = scores.view(batch, -1)
            scores, inds = torch.topk(scores, num_dets)
            scores = scores.unsqueeze(2)

            bboxes = bboxes.view(batch, -1, 6)
            bboxes = _gather_feat(bboxes, inds)

            clses = bl_clses.contiguous().view(batch, -1, 1)
            clses = _gather_feat(clses, inds).float()
            tl_scores = tl_scores.contiguous().view(batch, -1, 1)
            tl_scores = _gather_feat(tl_scores, inds).float()
            bl_scores = bl_scores.contiguous().view(batch, -1, 1)
            bl_scores = _gather_feat(bl_scores, inds).float()
            br_scores = br_scores.contiguous().view(batch, -1, 1)
            br_scores = _gather_feat(br_scores, inds).float()

            ct_xs = ct_xs[:, 0, :]
            ct_ys = ct_ys[:, 0, :]

            centers = torch.cat([
                ct_xs.unsqueeze(2),
                ct_ys.unsqueeze(2),
                ct_clses.float().unsqueeze(2),
                ct_scores.unsqueeze(2)
            ],
                                dim=2)
            detections = torch.cat(
                [bboxes, scores, tl_scores, bl_scores, br_scores, clses],
                dim=2)

            # instead of filtering predictions according to out-of-bound rotation,
            # do data augmentation to mirror the groundtruth

        return detections, centers
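
The decode above leans on two more unshown helpers: _nms, which keeps only the local maxima of a heatmap via a max-pool trick, and _topk, which returns the scores, flattened indices, classes and coordinates of the K strongest peaks. Sketches in the CenterNet style, consistent with how their outputs are consumed here:

import torch
import torch.nn.functional as F

def _nms(heat, kernel=3):
    # a pixel survives iff it equals the max of its kernel x kernel window
    pad = (kernel - 1) // 2
    hmax = F.max_pool2d(heat, (kernel, kernel), stride=1, padding=pad)
    keep = (hmax == heat).float()
    return heat * keep

def _topk(scores, K=40):
    batch, cat, height, width = scores.size()
    # top K per class, then top K across classes
    topk_scores, topk_inds = torch.topk(scores.view(batch, cat, -1), K)
    topk_inds = topk_inds % (height * width)
    topk_ys = torch.div(topk_inds, width, rounding_mode='floor').float()
    topk_xs = (topk_inds % width).float()
    topk_score, topk_ind = torch.topk(topk_scores.view(batch, -1), K)
    topk_clses = torch.div(topk_ind, K, rounding_mode='floor')
    topk_inds = topk_inds.view(batch, -1).gather(1, topk_ind)
    topk_ys = topk_ys.view(batch, -1).gather(1, topk_ind)
    topk_xs = topk_xs.view(batch, -1).gather(1, topk_ind)
    return topk_score, topk_inds, topk_clses, topk_ys, topk_xs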
Example #25
0
    def forward(self, outputs, batch, dataparallel=False):
        opt = self.opt
        hm_loss, wh_loss, off_loss, id_loss, edge_loss = 0, 0, 0, 0, 0
        for s in range(opt.num_stacks):
            output = outputs[s]
            if not opt.mse_loss:
                output['hm'] = _sigmoid(output['hm'])

            hm_loss += self.crit(output['hm'], batch['hm']) / opt.num_stacks
            if opt.wh_weight > 0:
                if opt.dense_wh:
                    mask_weight = batch['dense_wh_mask'].sum() + 1e-4
                    wh_loss += (self.crit_wh(
                        output['wh'] * batch['dense_wh_mask'],
                        batch['dense_wh'] * batch['dense_wh_mask']) /
                                mask_weight) / opt.num_stacks
                else:
                    wh_loss += self.crit_reg(output['wh'], batch['reg_mask'],
                                             batch['ind'],
                                             batch['wh']) / opt.num_stacks

            if opt.reg_offset and opt.off_weight > 0:
                off_loss += self.crit_reg(output['reg'], batch['reg_mask'],
                                          batch['ind'],
                                          batch['reg']) / opt.num_stacks

            if opt.id_weight > 0:
                id_head = _tranpose_and_gather_feat(output['id'], batch['ind'])
                id_head = id_head[batch['reg_mask'] > 0].contiguous()
                id_head = self.emb_scale * F.normalize(id_head)
                id_target = batch['ids'][batch['reg_mask'] > 0]
                id_output = self.classifier(id_head).contiguous()
                id_loss += self.IDLoss(id_output, id_target)
                #id_loss += self.IDLoss(id_output, id_target) + self.TriLoss(id_head, id_target)

            if opt.edge_reg_weight > 0:
                # TODO: compute the edge regression with BCELoss here
                if output['edge_preds'] is None or output[
                        'edge_labels'] is None:
                    edge_loss += torch.tensor(0.).to(hm_loss.device)
                elif len(output['edge_preds']) == 0 or len(
                        output['edge_labels']) == 0:
                    edge_loss += torch.tensor(0.).to(hm_loss.device)
                else:
                    edge_loss += self.crit_edge(
                        output['edge_preds'],
                        output['edge_labels']) * opt.edge_reg_weight
                if torch.isnan(edge_loss):
                    print(output['edge_preds'])
                    print(output['edge_labels'])
                    import ipdb
                    ipdb.set_trace()
        #loss = opt.hm_weight * hm_loss + opt.wh_weight * wh_loss + opt.off_weight * off_loss + opt.id_weight * id_loss

        det_loss = opt.hm_weight * hm_loss + opt.wh_weight * wh_loss + opt.off_weight * off_loss

        # loss = torch.exp(-self.s_det) * det_loss + torch.exp(-self.s_id) * id_loss + (self.s_det + self.s_id)
        loss = torch.exp(-self.s_det) * det_loss + torch.exp(-self.s_id) * (
            id_loss + edge_loss) + (self.s_det + self.s_id)
        loss *= 0.5

        if not dataparallel:
            loss_stats = {
                'loss': loss.mean().item(),
                'hm_loss': hm_loss.mean().item(),
                'wh_loss': wh_loss.mean().item(),
                'off_loss': off_loss.mean().item(),
                'id_loss': id_loss.mean().item(),
                'edge_loss': edge_loss.mean().item()
            }
        else:
            loss_stats = {
                'loss': loss,
                'hm_loss': hm_loss,
                'wh_loss': wh_loss,
                'off_loss': off_loss,
                'id_loss': id_loss,
                'edge_loss': edge_loss
            }
        return loss, loss_stats
    def multi_pose_decode(self,
                          heat,
                          wh,
                          kps,
                          reg=None,
                          hm_hp=None,
                          hp_offset=None,
                          K=100):
        batch, cat, height, width = heat.size()
        num_joints = kps.shape[1] // 2
        # perform nms on heatmaps
        heat = self._nms(heat)
        scores, inds, clses, ys, xs = self._topk(heat, K=K)

        kps = _tranpose_and_gather_feat(kps, inds)
        kps = kps.view(batch, K, num_joints * 2)
        kps[..., ::2] += xs.view(batch, K, 1).expand(batch, K, num_joints)
        kps[..., 1::2] += ys.view(batch, K, 1).expand(batch, K, num_joints)
        if reg is not None:
            reg = _tranpose_and_gather_feat(reg, inds)
            reg = reg.view(batch, K, 2)
            xs = xs.view(batch, K, 1) + reg[:, :, 0:1]
            ys = ys.view(batch, K, 1) + reg[:, :, 1:2]
        else:
            xs = xs.view(batch, K, 1) + 0.5
            ys = ys.view(batch, K, 1) + 0.5
        wh = _tranpose_and_gather_feat(wh, inds)
        wh = wh.view(batch, K, 2)
        clses = clses.view(batch, K, 1).float()
        scores = scores.view(batch, K, 1)

        bboxes = torch.cat([
            xs - wh[..., 0:1] / 2, ys - wh[..., 1:2] / 2,
            xs + wh[..., 0:1] / 2, ys + wh[..., 1:2] / 2
        ],
                           dim=2)
        if hm_hp is not None:
            hm_hp = self._nms(hm_hp)
            thresh = 0.1
            kps = kps.view(batch, K, num_joints,
                           2).permute(0, 2, 1, 3).contiguous()  # b x J x K x 2
            reg_kps = kps.unsqueeze(3).expand(batch, num_joints, K, K, 2)
            hm_score, hm_inds, hm_ys, hm_xs = self._topk_channel(
                hm_hp, K=K)  # b x J x K
            if hp_offset is not None:
                hp_offset = _tranpose_and_gather_feat(hp_offset,
                                                      hm_inds.view(batch, -1))
                hp_offset = hp_offset.view(batch, num_joints, K, 2)
                hm_xs = hm_xs + hp_offset[:, :, :, 0]
                hm_ys = hm_ys + hp_offset[:, :, :, 1]
            else:
                hm_xs = hm_xs + 0.5
                hm_ys = hm_ys + 0.5

            mask = (hm_score > thresh).float()
            hm_score = (1 - mask) * -1 + mask * hm_score
            hm_ys = (1 - mask) * (-10000) + mask * hm_ys
            hm_xs = (1 - mask) * (-10000) + mask * hm_xs
            hm_kps = torch.stack([hm_xs, hm_ys], dim=-1).unsqueeze(2).expand(
                batch, num_joints, K, K, 2)
            dist = (((reg_kps - hm_kps)**2).sum(dim=4)**0.5)
            min_dist, min_ind = dist.min(dim=3)  # b x J x K
            hm_score = hm_score.gather(2,
                                       min_ind).unsqueeze(-1)  # b x J x K x 1
            min_dist = min_dist.unsqueeze(-1)
            min_ind = min_ind.view(batch, num_joints, K, 1,
                                   1).expand(batch, num_joints, K, 1, 2)
            hm_kps = hm_kps.gather(3, min_ind)
            hm_kps = hm_kps.view(batch, num_joints, K, 2)
            l = bboxes[:, :, 0].view(batch, 1, K,
                                     1).expand(batch, num_joints, K, 1)
            t = bboxes[:, :, 1].view(batch, 1, K,
                                     1).expand(batch, num_joints, K, 1)
            r = bboxes[:, :, 2].view(batch, 1, K,
                                     1).expand(batch, num_joints, K, 1)
            b = bboxes[:, :, 3].view(batch, 1, K,
                                     1).expand(batch, num_joints, K, 1)
            mask = (hm_kps[..., 0:1] < l) + (hm_kps[..., 0:1] > r) + \
                 (hm_kps[..., 1:2] < t) + (hm_kps[..., 1:2] > b) + \
                 (hm_score < thresh) + (min_dist > (torch.max(b - t, r - l) * 0.3))
            mask = (mask > 0).float().expand(batch, num_joints, K, 2)
            kps = (1 - mask) * hm_kps + mask * kps
            kps = kps.permute(0, 2, 1,
                              3).contiguous().view(batch, K, num_joints * 2)
        detections = torch.cat([
            bboxes, scores, kps,
            torch.transpose(hm_score.squeeze(dim=3), 1, 2)
        ],
                               dim=2)
        return detections
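
self._topk_channel above differs from _topk in that it keeps the per-channel (per-joint) structure instead of mixing classes; given that its outputs are used with shape b x J x K, a plausible sketch:

import torch

def _topk_channel(scores, K=40):
    # scores: (B, J, H, W) -> top K peaks per channel, keeping the J axis
    batch, cat, height, width = scores.size()
    topk_scores, topk_inds = torch.topk(scores.view(batch, cat, -1), K)
    topk_inds = topk_inds % (height * width)
    topk_ys = torch.div(topk_inds, width, rounding_mode='floor').float()
    topk_xs = (topk_inds % width).float()
    return topk_scores, topk_inds, topk_ys, topk_xs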
Example #27
0
    def forward(self, output_dict, batch):
        opt = self.opt
        loss_results = {loss: 0 for loss in self.loss_states}

        outputs = output_dict['orig']
        flipped_outputs = output_dict['flipped'] if 'flipped' in output_dict else None

        # Take loss at each scale
        for s in range(opt.num_stacks):
            output = outputs[s]

            # Supervised loss on predicted heatmap
            if not opt.mse_loss:
                output['hm'] = _sigmoid(output['hm'])

            loss_results['hm'] += self.crit(output['hm'], batch['hm']) / opt.num_stacks

            # Supervised loss on object sizes
            if opt.wh_weight > 0:
                if opt.dense_wh:
                    mask_weight = batch['dense_wh_mask'].sum() + 1e-4
                    loss_results['wh'] += (self.crit_wh(output['wh'] * batch['dense_wh_mask'],
                                                        batch['dense_wh'] * batch['dense_wh_mask']) /
                                           mask_weight) / opt.num_stacks
                else:
                    loss_results['wh'] += self.crit_reg(
                        output['wh'], batch['reg_mask'],
                        batch['ind'], batch['wh']) / opt.num_stacks

            # Supervised loss on offsets
            if opt.reg_offset and opt.off_weight > 0:
                loss_results['off'] += self.crit_reg(output['reg'], batch['reg_mask'],
                                                     batch['ind'], batch['reg']) / opt.num_stacks

            id_head = _tranpose_and_gather_feat(output['id'], batch['ind'])
            id_head = id_head[batch['reg_mask'] > 0].contiguous()
            id_head = self.emb_scale * F.normalize(id_head)

            # Supervised loss on object ID predictions
            if opt.id_weight > 0 and not opt.unsup:
                id_target = batch['ids'][batch['reg_mask'] > 0]
                id_output = self.classifier(id_head).contiguous()
                loss_results['id'] += self.IDLoss(id_output, id_target)

            # Take self-supervised loss using negative sample (flipped img)
            if opt.unsup and flipped_outputs is not None:
                flipped_output = flipped_outputs[s]

                flipped_id_head = _tranpose_and_gather_feat(flipped_output['id'], batch['flipped_ind'])
                flipped_id_head = flipped_id_head[batch['reg_mask'] > 0].contiguous()
                # flipped_id_head = self.emb_scale * F.normalize(flipped_id_head)
                flipped_id_head = F.normalize(flipped_id_head)

                # Compute loss between the positive and negative set of reid features
                loss_results[opt.unsup_loss] = self.SelfSupLoss(id_head, flipped_id_head, batch['num_objs'])

        # Total supervised
        det_loss = opt.hm_weight * loss_results['hm'] + \
                   opt.wh_weight * loss_results['wh'] + \
                   opt.off_weight * loss_results['off']

        # Pick the supervised or self-supervised embedding loss; the exp(-self.s_id)
        # weighting is applied exactly once, in the total below
        id_loss = loss_results['id'] if not opt.unsup else loss_results[opt.unsup_loss]

        # Total of supervised and self-supervised losses on object embeddings
        total_loss = torch.exp(-self.s_det) * det_loss + \
                     torch.exp(-self.s_id) * id_loss + \
                     self.s_det + self.s_id

        total_loss *= 0.5
        loss_results['loss'] = total_loss

        return total_loss, loss_results
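
self.SelfSupLoss is configurable via opt.unsup_loss and not shown. Since it receives the re-ID features of the same objects in the original and flipped views, one plausible instantiation is a per-object cosine-consistency loss; the sketch below is an assumption, not the repo's code (it also ignores the num_objs argument used above):

import torch

def cosine_consistency_loss(id_head, flipped_id_head):
    """Pull each object's embedding toward its embedding in the flipped view.

    id_head, flipped_id_head: (N, C) features of the same N objects in the two
    views, in matching order; both are assumed L2-normalized.
    """
    cos = (id_head * flipped_id_head).sum(dim=1)  # per-object cosine similarity
    return (1.0 - cos).mean()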
Example #28
0
    def debug(self, detections, targets, ae_threshold):
        tl_heat = detections['tl_heatmap']
        bl_heat = detections['bl_heatmap']
        br_heat = detections['br_heatmap']
        ct_heat = detections['ct_heatmap']

        targets['tl_tag'] = targets['tl_tag'][targets['reg_mask']].unsqueeze(0)
        targets['bl_tag'] = targets['bl_tag'][targets['reg_mask']].unsqueeze(0)
        targets['br_tag'] = targets['br_tag'][targets['reg_mask']].unsqueeze(0)
        targets['ct_tag'] = targets['ct_tag'][targets['reg_mask']].unsqueeze(0)
        targets['tl_reg'] = targets['tl_reg'][targets['reg_mask']].unsqueeze(0)
        targets['bl_reg'] = targets['bl_reg'][targets['reg_mask']].unsqueeze(0)
        targets['br_reg'] = targets['br_reg'][targets['reg_mask']].unsqueeze(0)
        targets['ct_reg'] = targets['ct_reg'][targets['reg_mask']].unsqueeze(0)

        batch, cat, height, width = tl_heat.size()

        # tl_scores, tl_inds, tl_clses, tl_ys, tl_xs = _topk(tl_heat, K=256)
        # bl_scores, bl_inds, bl_clses, bl_ys, bl_xs = _topk(bl_heat, K=256)
        # br_scores, br_inds, br_clses, br_ys, br_xs = _topk(br_heat, K=256)
        # ct_scores, ct_inds, ct_clses, ct_ys, ct_xs = _topk(ct_heat, K=256)

        tl_tag = detections['tl_tag']
        bl_tag = detections['bl_tag']
        br_tag = detections['br_tag']
        tl_reg = detections['tl_reg']
        bl_reg = detections['bl_reg']
        br_reg = detections['br_reg']
        ct_reg = detections['ct_reg']

        # gather by gt
        tl_tag = _tranpose_and_gather_feat(
            tl_tag, targets['tl_tag'].to(torch.device("cuda")))
        bl_tag = _tranpose_and_gather_feat(
            bl_tag, targets['bl_tag'].to(torch.device("cuda")))
        br_tag = _tranpose_and_gather_feat(
            br_tag, targets['br_tag'].to(torch.device("cuda")))
        # gather by top k
        # tl_tag = _tranpose_and_gather_feat(tl_tag, tl_inds)
        # bl_tag = _tranpose_and_gather_feat(bl_tag, bl_inds)
        # br_tag = _tranpose_and_gather_feat(br_tag, br_inds)

        avg_tag = (tl_tag + bl_tag + br_tag) / 3

        dists_tl = torch.abs(avg_tag - tl_tag).to(torch.device("cpu")).numpy()
        dists_bl = torch.abs(bl_tag - avg_tag).to(torch.device("cpu")).numpy()
        dists_br = torch.abs(avg_tag - br_tag).to(torch.device("cpu")).numpy()
        dists_avg = (dists_tl.sum() + dists_bl.sum() +
                     dists_br.sum()) / dists_tl.shape[1] / 3
        min_tl = dists_tl.min()
        max_tl = dists_tl.max()
        min_bl = dists_bl.min()
        max_bl = dists_bl.max()
        min_br = dists_br.min()
        max_br = dists_br.max()

        # gather by gt
        tl_reg = _tranpose_and_gather_feat(
            tl_reg, targets['tl_tag'].to(torch.device("cuda")))
        bl_reg = _tranpose_and_gather_feat(
            bl_reg, targets['bl_tag'].to(torch.device("cuda")))
        br_reg = _tranpose_and_gather_feat(
            br_reg, targets['br_tag'].to(torch.device("cuda")))
        ct_reg = _tranpose_and_gather_feat(
            ct_reg, targets['ct_tag'].to(torch.device("cuda")))

        # reg_diff_tl = tl_reg - targets['tl_reg'].to(torch.device("cuda"))
        # reg_diff_tl = torch.sqrt(reg_diff_tl[..., 0]*reg_diff_tl[..., 0] + reg_diff_tl[..., 1]*reg_diff_tl[..., 1])
        # reg_diff_bl = bl_reg - targets['bl_reg'].to(torch.device("cuda"))
        # reg_diff_bl = torch.sqrt(reg_diff_bl[..., 0] * reg_diff_bl[..., 0] + reg_diff_bl[..., 1] * reg_diff_bl[..., 1])
        # reg_diff_br = br_reg - targets['br_reg'].to(torch.device("cuda"))
        # reg_diff_br = torch.sqrt(reg_diff_br[..., 0] * reg_diff_br[..., 0] + reg_diff_br[..., 1] * reg_diff_br[..., 1])
        # reg_diff_ct = ct_reg - targets['ct_reg'].to(torch.device("cuda"))
        # reg_diff_ct = torch.sqrt(reg_diff_ct[..., 0] * reg_diff_ct[..., 0] + reg_diff_ct[..., 1] * reg_diff_ct[..., 1])

        tl_xs = ((targets['tl_tag'] % (width * height)) %
                 width).int().float().to(torch.device("cuda"))
        tl_ys = ((targets['tl_tag'] % (width * height)) /
                 width).int().float().to(torch.device("cuda"))
        bl_xs = ((targets['bl_tag'] % (width * height)) %
                 width).int().float().to(torch.device("cuda"))
        bl_ys = ((targets['bl_tag'] % (width * height)) /
                 width).int().float().to(torch.device("cuda"))
        br_xs = ((targets['br_tag'] % (width * height)) %
                 width).int().float().to(torch.device("cuda"))
        br_ys = ((targets['br_tag'] % (width * height)) /
                 width).int().float().to(torch.device("cuda"))
        ct_xs = ((targets['ct_tag'] % (width * height)) %
                 width).int().float().to(torch.device("cuda"))
        ct_ys = ((targets['ct_tag'] % (width * height)) /
                 width).int().float().to(torch.device("cuda"))

        tl_xs_pr = (tl_xs + tl_reg[..., 0]).squeeze(0).to(
            torch.device("cpu")).numpy()
        tl_ys_pr = (tl_ys + tl_reg[..., 1]).squeeze(0).to(
            torch.device("cpu")).numpy()
        bl_xs_pr = (bl_xs + bl_reg[..., 0]).squeeze(0).to(
            torch.device("cpu")).numpy()
        bl_ys_pr = (bl_ys + bl_reg[..., 1]).squeeze(0).to(
            torch.device("cpu")).numpy()
        br_xs_pr = (br_xs + br_reg[..., 0]).squeeze(0).to(
            torch.device("cpu")).numpy()
        br_ys_pr = (br_ys + br_reg[..., 1]).squeeze(0).to(
            torch.device("cpu")).numpy()
        ct_xs_pr = (ct_xs + ct_reg[..., 0]).squeeze(0).to(
            torch.device("cpu")).numpy()
        ct_ys_pr = (ct_ys + ct_reg[..., 1]).squeeze(0).to(
            torch.device("cpu")).numpy()

        tl_xs_gt = (tl_xs + targets['tl_reg'][..., 0].to(
            torch.device("cuda"))).squeeze(0).to(torch.device("cpu")).numpy()
        tl_ys_gt = (tl_ys + targets['tl_reg'][..., 1].to(
            torch.device("cuda"))).squeeze(0).to(torch.device("cpu")).numpy()
        bl_xs_gt = (bl_xs + targets['bl_reg'][..., 0].to(
            torch.device("cuda"))).squeeze(0).to(torch.device("cpu")).numpy()
        bl_ys_gt = (bl_ys + targets['bl_reg'][..., 1].to(
            torch.device("cuda"))).squeeze(0).to(torch.device("cpu")).numpy()
        br_xs_gt = (br_xs + targets['br_reg'][..., 0].to(
            torch.device("cuda"))).squeeze(0).to(torch.device("cpu")).numpy()
        br_ys_gt = (br_ys + targets['br_reg'][..., 1].to(
            torch.device("cuda"))).squeeze(0).to(torch.device("cpu")).numpy()
        ct_xs_gt = (ct_xs + targets['ct_reg'][..., 0].to(
            torch.device("cuda"))).squeeze(0).to(torch.device("cpu")).numpy()
        ct_ys_gt = (ct_ys + targets['ct_reg'][..., 1].to(
            torch.device("cuda"))).squeeze(0).to(torch.device("cpu")).numpy()

        bboxes_gt = targets['bbox'][targets['reg_mask']]

        nm_instances = tl_xs_pr.shape[0]

        for i in range(nm_instances):
            bbox_gt = bboxes_gt[i, :]
            # prediction
            bbox_coord_pr = []
            tl_x_pr = tl_xs_pr[i]
            tl_y_pr = tl_ys_pr[i]
            bl_x_pr = bl_xs_pr[i]
            bl_y_pr = bl_ys_pr[i]
            br_x_pr = br_xs_pr[i]
            br_y_pr = br_ys_pr[i]

            # center
            x_c = (tl_x_pr + br_x_pr) / 2.
            y_c = (tl_y_pr + br_y_pr) / 2.

            if bl_x_pr == br_x_pr:
                p_y = tl_y_pr
                p_x = br_x_pr
                if br_y_pr > bl_y_pr:
                    angle = np.pi / 2.
                else:
                    angle = -np.pi / 2.
            elif bl_y_pr == br_y_pr:
                p_x = tl_x_pr
                p_y = br_y_pr
                angle = 0.
            else:
                # angle of the bottom edge
                angle = math.atan2(-(br_y_pr - bl_y_pr), br_x_pr - bl_x_pr)
                # line through bl-br: y = a * x + b
                a = (br_y_pr - bl_y_pr) / (br_x_pr - bl_x_pr)
                b = br_y_pr - a * br_x_pr
                delta_x = br_x_pr - bl_x_pr
                delta_y = br_y_pr - bl_y_pr
                # foot of the perpendicular from the top-left corner onto that line
                p_x = (delta_x * tl_x_pr + delta_y * tl_y_pr -
                       delta_y * b) / (delta_x + delta_y * a)
                p_y = a * p_x + b
                # w, h
            w = np.sqrt((br_x_pr - p_x) * (br_x_pr - p_x) + (br_y_pr - p_y) *
                        (br_y_pr - p_y))
            h = np.sqrt((tl_x_pr - p_x) * (tl_x_pr - p_x) + (tl_y_pr - p_y) *
                        (tl_y_pr - p_y))

            bbox_coord_pr.append(
                [x_c - w / 2, y_c - h / 2, x_c + w / 2, y_c + h / 2, angle])
            bbox_coord_pr = np.array(bbox_coord_pr)

            # groundtruth
            boxes_coord_gt = []
            tl_x_gt = tl_xs_gt[i]
            tl_y_gt = tl_ys_gt[i]
            bl_x_gt = bl_xs_gt[i]
            bl_y_gt = bl_ys_gt[i]
            br_x_gt = br_xs_gt[i]
            br_y_gt = br_ys_gt[i]
            # center
            x_c = (tl_x_gt + br_x_gt) / 2.
            y_c = (tl_y_gt + br_y_gt) / 2.
            if bl_x_gt == br_x_gt:
                # vertical bottom edge
                p_y = tl_y_gt
                p_x = bl_x_gt
                if br_y_gt > bl_y_gt:
                    angle = np.pi / 2.
                else:
                    angle = -np.pi / 2.
            else:
                # angle
                angle = math.atan(-(br_y_gt - bl_y_gt) / (br_x_gt - bl_x_gt))
                # find intersected point
                a = (br_y_gt - bl_y_gt) / (br_x_gt - bl_x_gt)
                b = br_y_gt - a * br_x_gt
                delta_x = br_x_gt - bl_x_gt
                delta_y = br_y_gt - bl_y_gt
                p_x = (delta_x * tl_x_gt + delta_y * tl_y_gt -
                       delta_y * b) / (delta_x + delta_y * a)
                p_y = a * p_x + b
                # w, h
            w = np.sqrt((br_x_gt - p_x) * (br_x_gt - p_x) + (br_y_gt - p_y) *
                        (br_y_gt - p_y))
            h = np.sqrt((tl_x_gt - p_x) * (tl_x_gt - p_x) + (tl_y_gt - p_y) *
                        (tl_y_gt - p_y))
            boxes_coord_gt.append(
                [x_c - w / 2, y_c - h / 2, x_c + w / 2, y_c + h / 2, angle])
            boxes_coord_gt = np.array(boxes_coord_gt)
            # print(np.array_equal(bbox_gt, boxes_coord_gt))

            overlaps = _bbox_overlaps(
                np.ascontiguousarray(bbox_coord_pr[:, :4], dtype=np.float32),
                np.ascontiguousarray(boxes_coord_gt[:, :4], dtype=np.float32),
                bbox_coord_pr[:, -1], boxes_coord_gt[:, -1], 128, 128)

            flag_suc = False
            flag_exit = 0
            # m, n index the prediction/gt box pair; i stays the instance index
            for m in range(overlaps.shape[0]):
                for n in range(overlaps.shape[1]):
                    value_overlap = overlaps[m, n]
                    angle_diff = math.fabs(bbox_coord_pr[m, -1] -
                                           boxes_coord_gt[n, -1])

                    if value_overlap > 0.25 and angle_diff < np.pi / 6:
                        flag_suc = True
                        flag_exit = 1
                        break
                if flag_exit:
                    break
            if flag_exit:
                break

        return min_tl, max_tl, min_bl, max_bl, min_br, max_br, dists_avg, flag_suc
Example #29
0
    def update(self, im_blob, img0, p_crops, p_crops_lengths, edge_index, gnn_output_layer=-1, p_imgs=None,
               conf_thres=0.3):
        self.frame_id += 1
        activated_starcks = []
        refind_stracks = []
        lost_stracks = []
        removed_stracks = []

        width = img0.shape[1]
        height = img0.shape[0]
        inp_height = im_blob.shape[2]
        inp_width = im_blob.shape[3]
        c = np.array([width / 2., height / 2.], dtype=np.float32)
        s = max(float(inp_width) / float(inp_height) * height, width) * 1.0
        meta = {'c': c, 's': s,
                'out_height': inp_height // self.opt.down_ratio,
                'out_width': inp_width // self.opt.down_ratio}

        ''' Step 1: Network forward, get detections & embeddings'''
        with torch.no_grad():
            output = self.model(im_blob, p_crops, p_crops_lengths, edge_index, p_imgs=p_imgs)[gnn_output_layer]
            if type(output) is list:
                output = output[-1]
            hm = output['hm'].sigmoid_()
            wh = output['wh']
            id_feature = output['id']
            id_feature = F.normalize(id_feature, dim=1)

            reg = output['reg'] if self.opt.reg_offset else None
            dets, inds = mot_decode(hm, wh, reg=reg, cat_spec_wh=self.opt.cat_spec_wh, K=self.opt.K)
            id_feature = _tranpose_and_gather_feat(id_feature, inds)
            id_feature = id_feature.squeeze(0)
            id_feature = id_feature.cpu().numpy()
        if self.viz_attention and self.frame_id == self.opt.vis_attn_frame:
            # vis attention
            attn = output['p']
            node0_neighbor_idx = output['node0_neighbor_idx']
            keep = torch.where(attn > self.opt.vis_attn_thres)[0]
            self.visualize_centers(im_blob, keep, node0_neighbor_idx, attn, output, p_imgs)

        dets = self.post_process(dets, meta)
        dets = self.merge_outputs([dets])[1]

        # remain_inds = dets[:, 4] > self.opt.conf_thres
        remain_inds = dets[:, 4] > conf_thres
        dets = dets[remain_inds]
        id_feature = id_feature[remain_inds]


        # vis
        '''
        for i in range(0, dets.shape[0]):
            bbox = dets[i][0:4]
            cv2.rectangle(img0, (bbox[0], bbox[1]),
                          (bbox[2], bbox[3]),
                          (0, 255, 0), 2)
        cv2.imshow('dets', img0)
        cv2.waitKey(0)
        id0 = id0-1
        '''

        if len(dets) > 0:
            '''Detections'''
            detections = [STrack(STrack.tlbr_to_tlwh(tlbrs[:4]), tlbrs[4], f, 30) for
                          (tlbrs, f) in zip(dets[:, :5], id_feature)]
        else:
            detections = []

        ''' Add newly detected tracklets to tracked_stracks'''
        unconfirmed = []
        tracked_stracks = []  # type: list[STrack]
        for track in self.tracked_stracks:
            if not track.is_activated:
                unconfirmed.append(track)
            else:
                tracked_stracks.append(track)

        ''' Step 2: First association, with embedding'''
        strack_pool = joint_stracks(tracked_stracks, self.lost_stracks)
        # Predict the current location with KF
        # for strack in strack_pool:
        #     strack.predict()
        STrack.multi_predict(strack_pool)
        dists = matching.embedding_distance(strack_pool, detections)
        #dists = matching.gate_cost_matrix(self.kalman_filter, dists, strack_pool, detections)
        dists = matching.fuse_motion(self.kalman_filter, dists, strack_pool, detections)
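        # fuse_motion gates the appearance-cost matrix with Kalman-predicted
        # motion consistency before the Hungarian assignment below.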
        matches, u_track, u_detection = matching.linear_assignment(dists, thresh=0.7)

        for itracked, idet in matches:
            track = strack_pool[itracked]
            det = detections[idet]
            if track.state == TrackState.Tracked:
                track.update(detections[idet], self.frame_id)
                activated_stracks.append(track)
            else:
                track.re_activate(det, self.frame_id, new_id=False)
                refind_stracks.append(track)

        ''' Step 3: Second association, with IOU'''
        detections = [detections[i] for i in u_detection]
        r_tracked_stracks = [strack_pool[i] for i in u_track if strack_pool[i].state == TrackState.Tracked]
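        # Tracks that missed the embedding match get a second chance on pure
        # spatial overlap, under a stricter threshold (0.5 vs 0.7).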
        dists = matching.iou_distance(r_tracked_stracks, detections)
        matches, u_track, u_detection = matching.linear_assignment(dists, thresh=0.5)

        for itracked, idet in matches:
            track = r_tracked_stracks[itracked]
            det = detections[idet]
            if track.state == TrackState.Tracked:
                track.update(det, self.frame_id)
                activated_stracks.append(track)
            else:
                track.re_activate(det, self.frame_id, new_id=False)
                refind_stracks.append(track)

        for it in u_track:
            track = r_tracked_stracks[it]
            if track.state != TrackState.Lost:
                track.mark_lost()
                lost_stracks.append(track)

        '''Deal with unconfirmed tracks, usually tracks with only one beginning frame'''
        detections = [detections[i] for i in u_detection]
        dists = matching.iou_distance(unconfirmed, detections)
        matches, u_unconfirmed, u_detection = matching.linear_assignment(dists, thresh=0.7)
        for itracked, idet in matches:
            unconfirmed[itracked].update(detections[idet], self.frame_id)
            activated_stracks.append(unconfirmed[itracked])
        for it in u_unconfirmed:
            track = unconfirmed[it]
            track.mark_removed()
            removed_stracks.append(track)

        """ Step 4: Init new stracks"""
        for inew in u_detection:
            track = detections[inew]
            # if track.score < self.det_thresh:
            if track.score < conf_thres:
                continue
            track.activate(self.kalman_filter, self.frame_id)
            activated_stracks.append(track)
        """ Step 5: Update state"""
        for track in self.lost_stracks:
            if self.frame_id - track.end_frame > self.max_time_lost:
                track.mark_removed()
                removed_stracks.append(track)

        # print('Remained match {} s'.format(t4 - t3))

        self.tracked_stracks = [t for t in self.tracked_stracks if t.state == TrackState.Tracked]
        self.tracked_stracks = joint_stracks(self.tracked_stracks, activated_stracks)
        self.tracked_stracks = joint_stracks(self.tracked_stracks, refind_stracks)
        self.lost_stracks = sub_stracks(self.lost_stracks, self.tracked_stracks)
        self.lost_stracks.extend(lost_stracks)
        self.lost_stracks = sub_stracks(self.lost_stracks, self.removed_stracks)
        self.removed_stracks.extend(removed_stracks)
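        # remove_duplicate_stracks resolves tracks that ended up in both
        # lists, keeping the longer-tracked member of each overlapping pair.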
        self.tracked_stracks, self.lost_stracks = remove_duplicate_stracks(self.tracked_stracks, self.lost_stracks)
        # get scores of lost tracks
        output_stracks = [track for track in self.tracked_stracks if track.is_activated]

        logger.debug('===========Frame {}=========='.format(self.frame_id))
        logger.debug('Activated: {}'.format([track.track_id for track in activated_stracks]))
        logger.debug('Refind: {}'.format([track.track_id for track in refind_stracks]))
        logger.debug('Lost: {}'.format([track.track_id for track in lost_stracks]))
        logger.debug('Removed: {}'.format([track.track_id for track in removed_stracks]))

        return output_stracks
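
Both trackers lean on a pair of small list helpers. For reference, here is a sketch of how joint_stracks and sub_stracks are commonly implemented in FairMOT-style repositories (an assumption, since their definitions are not part of this excerpt):

def joint_stracks(tlista, tlistb):
    """Union of two track lists, deduplicated by track_id."""
    exists = {}
    res = []
    for t in tlista:
        exists[t.track_id] = True
        res.append(t)
    for t in tlistb:
        if not exists.get(t.track_id, False):
            exists[t.track_id] = True
            res.append(t)
    return res

def sub_stracks(tlista, tlistb):
    """Tracks in tlista whose track_id does not appear in tlistb."""
    stracks = {t.track_id: t for t in tlista}
    for t in tlistb:
        stracks.pop(t.track_id, None)
    return list(stracks.values())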
Example #30
    def update(self, im_blob, img0):
        self.frame_id += 1
        activated_stracks = []
        refind_stracks = []
        lost_stracks = []
        removed_stracks = []

        width = img0.shape[1]
        height = img0.shape[0]
        inp_height = im_blob.shape[2]
        inp_width = im_blob.shape[3]
        c = np.array([width / 2.0, height / 2.0], dtype=np.float32)
        s = max(float(inp_width) / float(inp_height) * height, width) * 1.0
        meta = {
            "c": c,
            "s": s,
            "out_height": inp_height // self.opt.down_ratio,
            "out_width": inp_width // self.opt.down_ratio,
        }
        """ Step 1: Network forward, get detections & embeddings"""
        with torch.no_grad():
            output = self.model(im_blob)[-1]
            hm = output["hm"].sigmoid_()
            wh = output["wh"]
            id_feature = output["id"]
            id_feature = F.normalize(id_feature, dim=1)

            reg = output["reg"] if self.opt.reg_offset else None
            dets, inds = mot_decode(hm,
                                    wh,
                                    reg=reg,
                                    cat_spec_wh=self.opt.cat_spec_wh,
                                    K=self.opt.K)
            id_feature = _tranpose_and_gather_feat(id_feature, inds)
            id_feature = id_feature.squeeze(0)
            id_feature = id_feature.cpu().numpy()

        dets = self.post_process(dets, meta)
        dets = self.merge_outputs([dets])[1]

        remain_inds = dets[:, 4] > self.opt.conf_thres
        dets = dets[remain_inds]
        id_feature = id_feature[remain_inds]

        # vis
        """
        for i in range(0, dets.shape[0]):
            bbox = dets[i][0:4]
            cv2.rectangle(img0, (bbox[0], bbox[1]),
                          (bbox[2], bbox[3]),
                          (0, 255, 0), 2)
        cv2.imshow('dets', img0)
        cv2.waitKey(0)
        id0 = id0-1
        """

        if len(dets) > 0:
            """Detections"""
            detections = [
                STrack(STrack.tlbr_to_tlwh(tlbrs[:4]), tlbrs[4], f, 30)
                for (tlbrs, f) in zip(dets[:, :5], id_feature)
            ]
        else:
            detections = []
        """ Add newly detected tracklets to tracked_stracks"""
        unconfirmed = []
        tracked_stracks = []  # type: list[STrack]
        for track in self.tracked_stracks:
            if not track.is_activated:
                unconfirmed.append(track)
            else:
                tracked_stracks.append(track)
        """ Step 2: First association, with embedding"""
        strack_pool = joint_stracks(tracked_stracks, self.lost_stracks)
        # Predict the current location with KF
        # for strack in strack_pool:
        # strack.predict()
        STrack.multi_predict(strack_pool)
        dists = matching.embedding_distance(strack_pool, detections)
        # dists = matching.gate_cost_matrix(self.kalman_filter, dists, strack_pool, detections)
        dists = matching.fuse_motion(self.kalman_filter, dists, strack_pool,
                                     detections)
        matches, u_track, u_detection = matching.linear_assignment(dists,
                                                                   thresh=0.7)

        for itracked, idet in matches:
            track = strack_pool[itracked]
            det = detections[idet]
            if track.state == TrackState.Tracked:
                track.update(detections[idet], self.frame_id)
                activated_stracks.append(track)
            else:
                track.re_activate(det, self.frame_id, new_id=False)
                refind_stracks.append(track)
        """ Step 3: Second association, with IOU"""
        detections = [detections[i] for i in u_detection]
        r_tracked_stracks = [
            strack_pool[i] for i in u_track
            if strack_pool[i].state == TrackState.Tracked
        ]
        dists = matching.iou_distance(r_tracked_stracks, detections)
        matches, u_track, u_detection = matching.linear_assignment(dists,
                                                                   thresh=0.5)

        for itracked, idet in matches:
            track = r_tracked_stracks[itracked]
            det = detections[idet]
            if track.state == TrackState.Tracked:
                track.update(det, self.frame_id)
                activated_stracks.append(track)
            else:
                track.re_activate(det, self.frame_id, new_id=False)
                refind_stracks.append(track)

        for it in u_track:
            track = r_tracked_stracks[it]
            if track.state != TrackState.Lost:
                track.mark_lost()
                lost_stracks.append(track)
        """Deal with unconfirmed tracks, usually tracks with only one beginning frame"""
        detections = [detections[i] for i in u_detection]
        dists = matching.iou_distance(unconfirmed, detections)
        matches, u_unconfirmed, u_detection = matching.linear_assignment(
            dists, thresh=0.7)
        for itracked, idet in matches:
            unconfirmed[itracked].update(detections[idet], self.frame_id)
            activated_stracks.append(unconfirmed[itracked])
        for it in u_unconfirmed:
            track = unconfirmed[it]
            track.mark_removed()
            removed_stracks.append(track)
        """ Step 4: Init new stracks"""
        for inew in u_detection:
            track = detections[inew]
            if track.score < self.det_thresh:
                continue
            track.activate(self.kalman_filter, self.frame_id)
            activated_stracks.append(track)
        """ Step 5: Update state"""
        for track in self.lost_stracks:
            if self.frame_id - track.end_frame > self.max_time_lost:
                track.mark_removed()
                removed_stracks.append(track)

        # print('Remained match {} s'.format(t4 - t3))

        self.tracked_stracks = [
            t for t in self.tracked_stracks if t.state == TrackState.Tracked
        ]
        self.tracked_stracks = joint_stracks(self.tracked_stracks,
                                             activated_stracks)
        self.tracked_stracks = joint_stracks(self.tracked_stracks,
                                             refind_stracks)
        self.lost_stracks = sub_stracks(self.lost_stracks,
                                        self.tracked_stracks)
        self.lost_stracks.extend(lost_stracks)
        self.lost_stracks = sub_stracks(self.lost_stracks,
                                        self.removed_stracks)
        self.removed_stracks.extend(removed_stracks)
        self.tracked_stracks, self.lost_stracks = remove_duplicate_stracks(
            self.tracked_stracks, self.lost_stracks)
        # get scores of lost tracks
        output_stracks = [
            track for track in self.tracked_stracks if track.is_activated
        ]

        logger.debug("===========Frame {}==========".format(self.frame_id))
        logger.debug("Activated: {}".format(
            [track.track_id for track in activated_stracks]))
        logger.debug("Refind: {}".format(
            [track.track_id for track in refind_stracks]))
        logger.debug("Lost: {}".format(
            [track.track_id for track in lost_stracks]))
        logger.debug("Removed: {}".format(
            [track.track_id for track in removed_stracks]))

        return output_stracks
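
Finally, matching.linear_assignment in trackers like these is typically a cost-thresholded Hungarian assignment. The stand-in below is a self-contained sketch using SciPy (real repositories often use the lap package instead); the return contract matches the (matches, u_track, u_detection) triple consumed above:

import numpy as np
from scipy.optimize import linear_sum_assignment

def linear_assignment(cost_matrix, thresh):
    """Hungarian matching; pairs costlier than thresh stay unmatched."""
    if cost_matrix.size == 0:
        return (np.empty((0, 2), dtype=int),
                tuple(range(cost_matrix.shape[0])),
                tuple(range(cost_matrix.shape[1])))
    rows, cols = linear_sum_assignment(cost_matrix)
    matches = [[r, c] for r, c in zip(rows, cols)
               if cost_matrix[r, c] <= thresh]
    matched_rows = {r for r, _ in matches}
    matched_cols = {c for _, c in matches}
    u_track = tuple(i for i in range(cost_matrix.shape[0])
                    if i not in matched_rows)
    u_detection = tuple(j for j in range(cost_matrix.shape[1])
                        if j not in matched_cols)
    return np.asarray(matches, dtype=int).reshape(-1, 2), u_track, u_detection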