Example #1
    def forward(self, outputs, batch):
        """
        :param outputs:
        :param batch:
        :return:
        """
        opt = self.opt
        hm_loss, wh_loss, off_loss, id_loss = 0.0, 0.0, 0.0, 0.0  # initialize the four loss terms
        for s in range(opt.num_stacks):
            output = outputs[s]
            if not opt.mse_loss:
                output['hm'] = _sigmoid(output['hm'])

            # heatmap loss
            hm_loss += self.crit(output['hm'], batch['hm']) / opt.num_stacks
            if opt.wh_weight > 0:
                if opt.dense_wh:
                    mask_weight = batch['dense_wh_mask'].sum() + 1e-4
                    wh_loss += (self.crit_wh(
                        output['wh'] * batch['dense_wh_mask'],
                        batch['dense_wh'] * batch['dense_wh_mask']) /
                                mask_weight) / opt.num_stacks
                else:  # L1 / Smooth-L1 loss on the box sizes
                    wh_loss += self.crit_reg(output['wh'], batch['reg_mask'],
                                             batch['ind'],
                                             batch['wh']) / opt.num_stacks

            if opt.reg_offset and opt.off_weight > 0:  # L1 loss on the box-center offsets
                off_loss += self.crit_reg(output['reg'], batch['reg_mask'],
                                          batch['ind'],
                                          batch['reg']) / opt.num_stacks

            # cross-entropy loss for identity classification of detected objects
            if opt.id_weight > 0:
                id_head = _tranpose_and_gather_feat(output['id'], batch['ind'])
                id_head = id_head[batch['reg_mask'] > 0].contiguous()  # compute the ID loss only at pixels that contain an object
                id_head = self.emb_scale * F.normalize(id_head)
                id_target = batch['ids'][batch['reg_mask'] > 0]  # track IDs at the object pixels
                id_output = self.classifier.forward(id_head).contiguous()  # final FC layer for ID classification
                id_loss += self.IDLoss(id_output, id_target)
                # id_loss += self.IDLoss(id_output, id_target) + self.TriLoss(id_head, id_target)

        # loss = opt.hm_weight * hm_loss + opt.wh_weight * wh_loss + opt.off_weight * off_loss + opt.id_weight * id_loss

        det_loss = opt.hm_weight * hm_loss \
                   + opt.wh_weight * wh_loss \
                   + opt.off_weight * off_loss

        # uncertainty-based weighting of the detection and ID tasks,
        # with self.s_det and self.s_id as learnable log-variances
        loss = torch.exp(-self.s_det) * det_loss \
               + torch.exp(-self.s_id) * id_loss \
               + (self.s_det + self.s_id)
        loss *= 0.5
        # print(loss, hm_loss, wh_loss, off_loss, id_loss)

        loss_stats = {
            'loss': loss,
            'hm_loss': hm_loss,
            'wh_loss': wh_loss,
            'off_loss': off_loss,
            'id_loss': id_loss
        }
        return loss, loss_stats
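
The weighted sum above is the uncertainty-based multi-task weighting of Kendall et al. (CVPR 2018). A minimal self-contained sketch of the same idea; the initialization values below are an assumption, not taken from this snippet:

import torch
import torch.nn as nn

class UncertaintyWeightedLoss(nn.Module):
    """Sketch: balance two task losses with learnable log-variances."""

    def __init__(self):
        super().__init__()
        # assumed initial values; tune per task
        self.s_det = nn.Parameter(torch.tensor(-1.85))
        self.s_id = nn.Parameter(torch.tensor(-1.05))

    def forward(self, det_loss, id_loss):
        # 0.5 * (exp(-s1) * L1 + exp(-s2) * L2 + s1 + s2)
        return 0.5 * (torch.exp(-self.s_det) * det_loss
                      + torch.exp(-self.s_id) * id_loss
                      + self.s_det + self.s_id)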
Example #2
    def update_tracking(self, im_blob, img_0):
        """
        :param im_blob:
        :param img_0:
        :return:
        """
        # update frame id
        self.frame_id += 1

        # record tracking results, key: class_id
        activated_starcks_dict = defaultdict(list)
        refind_stracks_dict = defaultdict(list)
        lost_stracks_dict = defaultdict(list)
        removed_stracks_dict = defaultdict(list)
        output_stracks_dict = defaultdict(list)

        height, width = img_0.shape[0], img_0.shape[1]  # H, W of the original input image
        net_height, net_width = im_blob.shape[2], im_blob.shape[3]  # H, W of the net input

        c = np.array([width * 0.5, height * 0.5], dtype=np.float32)
        s = max(float(net_width) / float(net_height) * height, width) * 1.0
        h_out = net_height // self.opt.down_ratio
        w_out = net_width // self.opt.down_ratio
        ''' Step 1: Network forward, get detections & embeddings'''
        with torch.no_grad():
            output = self.model.forward(im_blob)[-1]

            hm = output['hm'].sigmoid_()
            wh = output['wh']
            reg = output['reg'] if self.opt.reg_offset else None
            id_feature = output['id']
            id_feature = F.normalize(
                id_feature, dim=1)  # L2 normalize the reid feature vector

            #  detection decoding
            dets, inds, cls_inds_mask = mot_decode(
                heatmap=hm,
                wh=wh,
                reg=reg,
                num_classes=self.opt.num_classes,
                cat_spec_wh=self.opt.cat_spec_wh,
                K=self.opt.K)

            # ----- get ReID feature vector by object class
            cls_id_feats = []  # topK feature vectors of each object class
            for cls_id in range(self.opt.num_classes):  # cls_id starts from 0
                # get inds of each object class
                cls_inds = inds[:, cls_inds_mask[cls_id]]

                # gather feats for each object class
                cls_id_feature = _tranpose_and_gather_feat(
                    id_feature, cls_inds)  # inds: 1×128
                cls_id_feature = cls_id_feature.squeeze(0)  # n × FeatDim
                cls_id_feature = cls_id_feature.cpu().numpy()
                cls_id_feats.append(cls_id_feature)

        # detection post-processing (affine-based version, kept for reference)
        # meta = {'c': c,
        #         's': s,
        #         'out_height': h_out,
        #         'out_width': w_out}
        # dets = self.post_process(dets, meta)  # using affine matrix
        # dets = self.merge_outputs([dets])

        # translate and scale
        dets = map2orig(dets, h_out, w_out, height, width,
                        self.opt.num_classes)

        # ----- parse each object class
        for cls_id in range(self.opt.num_classes):  # cls_id starts from 0
            cls_dets = dets[cls_id]
            '''
            # visualize each class
            for i in range(0, cls_dets.shape[0]):
                bbox = cls_dets[i][0:4]
                cv2.rectangle(img0,
                              (bbox[0], bbox[1]),  # left-top point
                              (bbox[2], bbox[3]),  # right-down point
                              [0, 255, 255],  # yellow
                              2)
                cv2.putText(img0,
                            id2cls[cls_id],
                            (bbox[0], bbox[1]),
                            cv2.FONT_HERSHEY_PLAIN,
                            1.3,
                            [0, 0, 255],  # red
                            2)
            cv2.imshow('{}'.format(id2cls[cls_id]), img0)
            cv2.waitKey(0)
            '''

            # filter out low confidence detections
            remain_inds = cls_dets[:, 4] > self.opt.conf_thres
            cls_dets = cls_dets[remain_inds]
            cls_id_feature = cls_id_feats[cls_id][remain_inds]

            if len(cls_dets) > 0:
                '''Detections, tlbrs: top left bottom right score'''
                cls_detections = [
                    STrack(STrack.tlbr_to_tlwh(tlbrs[:4]),
                           tlbrs[4],
                           feat,
                           buff_size=30)
                    for (tlbrs, feat) in zip(cls_dets[:, :5], cls_id_feature)
                ]
            else:
                cls_detections = []

            # reset the track ids for each different object class
            if self.frame_id == 0:
                for track in cls_detections:
                    track.reset_track_id()
            ''' Add newly detected tracklets to tracked_stracks'''
            unconfirmed_dict = defaultdict(list)
            tracked_stracks_dict = defaultdict(list)
            for track in self.tracked_stracks_dict[cls_id]:
                if not track.is_activated:
                    unconfirmed_dict[cls_id].append(track)
                else:
                    tracked_stracks_dict[cls_id].append(track)
            ''' Step 2: First association, with embedding'''
            strack_pool_dict = defaultdict(list)
            strack_pool_dict[cls_id] = joint_stracks(
                tracked_stracks_dict[cls_id], self.lost_stracks_dict[cls_id])

            # Predict the current location with KF
            # for strack in strack_pool:
            STrack.multi_predict(strack_pool_dict[cls_id])
            dists = matching.embedding_distance(strack_pool_dict[cls_id],
                                                cls_detections)
            dists = matching.fuse_motion(self.kalman_filter, dists,
                                         strack_pool_dict[cls_id],
                                         cls_detections)
            matches, u_track, u_detection = matching.linear_assignment(
                dists, thresh=0.7)

            for i_tracked, i_det in matches:
                track = strack_pool_dict[cls_id][i_tracked]
                det = cls_detections[i_det]
                if track.state == TrackState.Tracked:
                    track.update(cls_detections[i_det], self.frame_id)
                    activated_starcks_dict[cls_id].append(
                        track)  # for multi-class
                else:
                    track.re_activate(det, self.frame_id, new_id=False)
                    refind_stracks_dict[cls_id].append(track)
            ''' Step 3: Second association, with IOU'''
            cls_detections = [cls_detections[i] for i in u_detection]
            r_tracked_stracks = [
                strack_pool_dict[cls_id][i] for i in u_track
                if strack_pool_dict[cls_id][i].state == TrackState.Tracked
            ]
            dists = matching.iou_distance(r_tracked_stracks, cls_detections)
            matches, u_track, u_detection = matching.linear_assignment(
                dists, thresh=0.5)

            for i_tracked, i_det in matches:
                track = r_tracked_stracks[i_tracked]
                det = cls_detections[i_det]
                if track.state == TrackState.Tracked:
                    track.update(det, self.frame_id)
                    activated_starcks_dict[cls_id].append(track)
                else:
                    track.re_activate(det, self.frame_id, new_id=False)
                    refind_stracks_dict[cls_id].append(track)

            for it in u_track:
                track = r_tracked_stracks[it]
                if not track.state == TrackState.Lost:
                    track.mark_lost()
                    lost_stracks_dict[cls_id].append(track)
            '''Deal with unconfirmed tracks, usually tracks with only one beginning frame'''
            cls_detections = [cls_detections[i] for i in u_detection]
            dists = matching.iou_distance(unconfirmed_dict[cls_id],
                                          cls_detections)
            matches, u_unconfirmed, u_detection = matching.linear_assignment(
                dists, thresh=0.7)
            for i_tracked, i_det in matches:
                unconfirmed_dict[cls_id][i_tracked].update(
                    cls_detections[i_det], self.frame_id)
                activated_starcks_dict[cls_id].append(
                    unconfirmed_dict[cls_id][i_tracked])
            for it in u_unconfirmed:
                track = unconfirmed_dict[cls_id][it]
                track.mark_removed()
                removed_stracks_dict[cls_id].append(track)
            """ Step 4: Init new stracks"""
            for i_new in u_detection:
                track = cls_detections[i_new]

                if track.score < self.det_thresh:
                    continue

                track.activate(self.kalman_filter, self.frame_id)
                activated_starcks_dict[cls_id].append(track)
            """ Step 5: Update state"""
            for track in self.lost_stracks_dict[cls_id]:
                if self.frame_id - track.end_frame > self.max_time_lost:
                    track.mark_removed()
                    removed_stracks_dict[cls_id].append(track)

            # print('Remaining match {} s'.format(t4 - t3))
            self.tracked_stracks_dict[cls_id] = [
                t for t in self.tracked_stracks_dict[cls_id]
                if t.state == TrackState.Tracked
            ]
            self.tracked_stracks_dict[cls_id] = joint_stracks(
                self.tracked_stracks_dict[cls_id],
                activated_starcks_dict[cls_id])
            self.tracked_stracks_dict[cls_id] = joint_stracks(
                self.tracked_stracks_dict[cls_id], refind_stracks_dict[cls_id])
            self.lost_stracks_dict[cls_id] = sub_stracks(
                self.lost_stracks_dict[cls_id],
                self.tracked_stracks_dict[cls_id])
            self.lost_stracks_dict[cls_id].extend(lost_stracks_dict[cls_id])
            self.lost_stracks_dict[cls_id] = sub_stracks(
                self.lost_stracks_dict[cls_id],
                self.removed_stracks_dict[cls_id])
            self.removed_stracks_dict[cls_id].extend(
                removed_stracks_dict[cls_id])
            self.tracked_stracks_dict[cls_id], self.lost_stracks_dict[
                cls_id] = remove_duplicate_stracks(
                    self.tracked_stracks_dict[cls_id],
                    self.lost_stracks_dict[cls_id])

            # collect the activated tracks as this frame's output
            output_stracks_dict[cls_id] = [
                track for track in self.tracked_stracks_dict[cls_id]
                if track.is_activated
            ]

            logger.debug('===========Frame {}=========='.format(self.frame_id))
            logger.debug('Activated: {}'.format(
                [track.track_id for track in activated_starcks_dict[cls_id]]))
            logger.debug('Refind: {}'.format(
                [track.track_id for track in refind_stracks_dict[cls_id]]))
            logger.debug('Lost: {}'.format(
                [track.track_id for track in lost_stracks_dict[cls_id]]))
            logger.debug('Removed: {}'.format(
                [track.track_id for track in removed_stracks_dict[cls_id]]))

        return output_stracks_dict
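
For context, a hypothetical per-frame driver for update_tracking. The letterbox preprocessing mirrors Example #4; the tracker is assumed to be constructed elsewhere:

import cv2
import numpy as np
import torch

def track_video(tracker, video_path, dev, net_h=608, net_w=1088):
    """Sketch: feed a video through the tracker frame by frame."""
    cap = cv2.VideoCapture(video_path)
    per_frame_tracks = []
    while True:
        ret, img_0 = cap.read()
        if not ret:
            break
        # padded resize, BGR -> RGB, HWC -> CHW, scale to [0, 1]
        img, _, _, _ = letterbox(img=img_0, height=net_h, width=net_w)
        img = img[:, :, ::-1].transpose(2, 0, 1)
        img = np.ascontiguousarray(img, dtype=np.float32) / 255.0
        im_blob = torch.from_numpy(img).unsqueeze(0).to(dev)
        # one dict {cls_id: [STrack, ...]} per frame
        per_frame_tracks.append(tracker.update_tracking(im_blob, img_0))
    cap.release()
    return per_frame_tracks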
Example #3
def test_emb(
        opt,
        batch_size=16,
        img_size=(1088, 608),
        print_interval=40,
):
    data_cfg = opt.data_cfg
    with open(data_cfg) as f:
        data_cfg_dict = json.load(f)
    nC = 1
    test_paths = data_cfg_dict['test_emb']
    dataset_root = data_cfg_dict['root']
    if opt.gpus[0] >= 0:
        opt.device = torch.device('cuda')
    else:
        opt.device = torch.device('cpu')
    print('Creating model...')
    model, criterion, postprocessors = build_model(opt)
    # model = create_model(opt.arch, opt.heads, opt.head_conv)
    model = load_model(model, opt.load_model)
    # model = torch.nn.DataParallel(model)
    model = model.to(opt.device)
    model.eval()

    # Get dataloader
    # transforms = T.Compose([T.ToTensor()])
    # img_pil = Image.open(path).convert('RGB')
    normalize = T.Compose([
        T.ToTensor(),
        T.Normalize([0.408, 0.447, 0.470], [0.289, 0.274, 0.278])
    ])
    transforms = T.Compose([T.RandomResize([800], max_size=1333), normalize])
    # img_norm = img_norm(img_pil)

    dataset = JointDataset(opt,
                           dataset_root,
                           test_paths,
                           img_size,
                           augment=False,
                           transforms=transforms)
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             num_workers=8,
                                             drop_last=False)
    embedding, id_labels = [], []
    print('Extracting pedestrian features...')
    for batch_i, batch in enumerate(dataloader):
        t = time.time()
        output = model(batch['input'].cuda())[-1]
        id_head = _tranpose_and_gather_feat(output['id'], batch['ind'].cuda())
        id_head = id_head[batch['reg_mask'].cuda() > 0].contiguous()
        emb_scale = math.sqrt(2) * math.log(opt.nID - 1)
        id_head = emb_scale * F.normalize(id_head)
        id_target = batch['ids'].cuda()[batch['reg_mask'].cuda() > 0]

        for i in range(id_head.shape[0]):
            feat, label = id_head[i], id_target[i].long()
            if label != -1:
                embedding.append(feat)
                id_labels.append(label)

        if batch_i % print_interval == 0:
            print(
                'Extracting {}/{}, # of instances {}, time {:.2f} sec.'.format(
                    batch_i, len(dataloader), len(id_labels),
                    time.time() - t))

    print('Computing pairwise similarity...')
    if len(embedding) < 1:
        return None
    embedding = torch.stack(embedding, dim=0).cuda()
    id_labels = torch.LongTensor(id_labels)
    n = len(id_labels)
    print(n, len(embedding))
    assert len(embedding) == n

    embedding = F.normalize(embedding, dim=1)
    pdist = torch.mm(embedding, embedding.t()).cpu().numpy()
    gt = id_labels.expand(n, n).eq(id_labels.expand(n, n).t()).numpy()

    up_triangle = np.where(np.triu(pdist) - np.eye(n) * pdist != 0)
    pdist = pdist[up_triangle]
    gt = gt[up_triangle]

    far_levels = [1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1]
    far, tar, threshold = metrics.roc_curve(gt, pdist)
    interp = interpolate.interp1d(far, tar)
    tar_at_far = [interp(x) for x in far_levels]
    for f, fa in enumerate(far_levels):
        print('TPR@FAR={:.7f}: {:.4f}'.format(fa, tar_at_far[f]))
    return tar_at_far
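
The TPR@FAR computation at the end can be factored into a standalone helper. A sketch assuming a 2-D tensor of embeddings and a matching 1-D LongTensor of identity labels:

import numpy as np
import torch
import torch.nn.functional as F
from scipy import interpolate
from sklearn import metrics

def tpr_at_far(embedding, id_labels,
               far_levels=(1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1)):
    """Sketch: verification TPR at fixed false-accept rates."""
    n = len(id_labels)
    embedding = F.normalize(embedding, dim=1)
    pdist = torch.mm(embedding, embedding.t()).cpu().numpy()  # cosine similarity
    gt = id_labels.expand(n, n).eq(id_labels.expand(n, n).t()).numpy()
    # keep the strict upper triangle (each pair counted once, no self-pairs)
    up = np.where(np.triu(pdist) - np.eye(n) * pdist != 0)
    far, tar, _ = metrics.roc_curve(gt[up], pdist[up])
    interp = interpolate.interp1d(far, tar)
    return [float(interp(x)) for x in far_levels]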
Example #4
def test_single(img_path, dev):
    """
    :param img_path:
    :param dev:
    :return:
    """
    if not os.path.isfile(img_path):
        print('[Err]: invalid image path.')
        return

    # Head dimensions of the net
    heads = {'hm': 5, 'reg': 2, 'wh': 2, 'id': 128}

    # Load model and put to device
    net = create_model(arch='resdcn_18', heads=heads, head_conv=256)
    model_path = '/mnt/diskb/even/MCMOT/exp/mot/default/mcmot_last_det_resdcn_18.pth'
    net = load_model(model=net, model_path=model_path)
    net = net.to(dev)
    net.eval()
    print(net)

    # Read image
    img_0 = cv2.imread(img_path)  # BGR
    assert img_0 is not None, 'Failed to load ' + img_path

    # Padded resize
    h_in, w_in = 608, 1088  # (608, 1088) (320, 640)
    img, _, _, _ = letterbox(img=img_0, height=h_in, width=w_in)

    # Preprocess image: BGR -> RGB and H×W×C -> C×H×W
    img = img[:, :, ::-1].transpose(2, 0, 1)
    img = np.ascontiguousarray(img, dtype=np.float32)
    img /= 255.0

    # Convert to tensor and put to device
    blob = torch.from_numpy(img).unsqueeze(0).to(dev)

    with torch.no_grad():
        # Network output
        output = net.forward(blob)[-1]

        # Tracking output
        hm = output['hm'].sigmoid_()
        reg = output['reg']
        wh = output['wh']
        id_feature = output['id']
        id_feature = F.normalize(id_feature,
                                 dim=1)  # L2 normalization for feature vector

        # Decode output
        dets, inds, cls_inds_mask = mot_decode(hm, wh, reg, 5, False, 128)

        # Get ReID feature vector by object class
        cls_id_feats = []  # topK feature vectors of each object class
        for cls_id in range(5):  # cls_id starts from 0
            # get inds of each object class
            cls_inds = inds[:, cls_inds_mask[cls_id]]

            # gather feats for each object class
            cls_id_feature = _tranpose_and_gather_feat(id_feature,
                                                       cls_inds)  # inds: 1×128
            cls_id_feature = cls_id_feature.squeeze(0)  # n × FeatDim
            if dev == 'cpu':
                cls_id_feature = cls_id_feature.numpy()
            else:
                cls_id_feature = cls_id_feature.cpu().numpy()
            cls_id_feats.append(cls_id_feature)

        # Convert back to original image coordinate system
        height_0, width_0 = img_0.shape[0], img_0.shape[1]  # H, W of the original input image
        dets = map2orig(dets, h_in // 4, w_in // 4, height_0, width_0, 5)  # translate and scale back to the original image

        # Parse detections of each class
        dets_dict = defaultdict(list)
        for cls_id in range(5):  # cls_id start from index 0
            cls_dets = dets[cls_id]

            # filter out low conf score dets
            remain_inds = cls_dets[:, 4] > 0.4
            cls_dets = cls_dets[remain_inds]
            # cls_id_feature = cls_id_feats[cls_id][remain_inds]  # if need re-id
            dets_dict[cls_id] = cls_dets

    # Visualize detection results
    img_draw = plot_detects(img_0, dets_dict, 5, frame_id=0, fps=30.0)
    # cv2.imshow('Detection', img_draw)
    # cv2.waitKey()
    cv2.imwrite('/mnt/diskb/even/MCMOT/results/00000.jpg', img_draw)
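
A hypothetical entry point for the test above; the image path is a placeholder:

if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    test_single(img_path='/path/to/test.jpg', dev=device)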
Example #5
    def update(self, im_blob, img0):
        self.frame_id += 1
        activated_starcks = []
        refind_stracks = []
        lost_stracks = []
        removed_stracks = []

        width = img0.shape[1]
        height = img0.shape[0]
        inp_height = im_blob.shape[2]
        inp_width = im_blob.shape[3]
        c = np.array([width / 2., height / 2.], dtype=np.float32)
        s = max(float(inp_width) / float(inp_height) * height, width) * 1.0
        meta = {
            'c': c,
            's': s,
            'out_height': inp_height // self.opt.down_ratio,
            'out_width': inp_width // self.opt.down_ratio
        }
        ''' Step 1: Network forward, get detections & embeddings'''

        with torch.no_grad():
            output = self.model(im_blob)[-1]

            hm = output['hm'].sigmoid_()
            wh = output['wh']

            id_feature = output['id']
            id_feature = F.normalize(id_feature, dim=1)

            reg = output['reg'] if self.opt.reg_offset else None

            dets, inds = mot_decode(hm,
                                    wh,
                                    reg=reg,
                                    cat_spec_wh=self.opt.cat_spec_wh,
                                    K=self.opt.K)
            id_feature = _tranpose_and_gather_feat(id_feature, inds)
            id_feature = id_feature.squeeze(0)
            id_feature = id_feature.cpu().numpy()

        dets = self.post_process(dets, meta)
        dets = self.merge_outputs([dets])[1]

        remain_inds = dets[:, 4] > self.opt.conf_thres
        dets = dets[remain_inds]
        id_feature = id_feature[remain_inds]

        # vis
        '''
        for i in range(0, dets.shape[0]):
            bbox = dets[i][0:4]
            cv2.rectangle(img0, (bbox[0], bbox[1]),
                          (bbox[2], bbox[3]),
                          (0, 255, 0), 2)
        cv2.imshow('dets', img0)
        cv2.waitKey(0)
        '''

        if len(dets) > 0:
            '''Detections'''
            detections = [
                STrack(STrack.tlbr_to_tlwh(tlbrs[:4]), tlbrs[4], f, 30)
                for (tlbrs, f) in zip(dets[:, :5], id_feature)
            ]
        else:
            detections = []
        ''' Add newly detected tracklets to tracked_stracks'''
        unconfirmed = []
        tracked_stracks = []  # type: list[STrack]
        for track in self.tracked_stracks:
            if not track.is_activated:
                unconfirmed.append(track)
            else:
                tracked_stracks.append(track)
        ''' Step 2: First association, with embedding'''
        strack_pool = joint_stracks(tracked_stracks, self.lost_stracks)
        # Predict the current location with KF
        # for strack in strack_pool:
        # strack.predict()
        STrack.multi_predict(strack_pool)
        dists = matching.embedding_distance(strack_pool, detections)
        # dists = matching.gate_cost_matrix(self.kalman_filter, dists, strack_pool, detections)
        dists = matching.fuse_motion(self.kalman_filter, dists, strack_pool,
                                     detections)
        matches, u_track, u_detection = matching.linear_assignment(dists,
                                                                   thresh=0.7)

        for itracked, idet in matches:
            track = strack_pool[itracked]
            det = detections[idet]
            if track.state == TrackState.Tracked:
                track.update(detections[idet], self.frame_id)
                activated_starcks.append(track)
            else:
                track.re_activate(det, self.frame_id, new_id=False)
                refind_stracks.append(track)
        ''' Step 3: Second association, with IOU'''
        detections = [detections[i] for i in u_detection]
        r_tracked_stracks = [
            strack_pool[i] for i in u_track
            if strack_pool[i].state == TrackState.Tracked
        ]
        dists = matching.iou_distance(r_tracked_stracks, detections)
        matches, u_track, u_detection = matching.linear_assignment(dists,
                                                                   thresh=0.5)

        for itracked, idet in matches:
            track = r_tracked_stracks[itracked]
            det = detections[idet]
            if track.state == TrackState.Tracked:
                track.update(det, self.frame_id)
                activated_starcks.append(track)
            else:
                track.re_activate(det, self.frame_id, new_id=False)
                refind_stracks.append(track)

        for it in u_track:
            track = r_tracked_stracks[it]
            if not track.state == TrackState.Lost:
                track.mark_lost()
                lost_stracks.append(track)
        '''Deal with unconfirmed tracks, usually tracks with only one beginning frame'''
        detections = [detections[i] for i in u_detection]
        dists = matching.iou_distance(unconfirmed, detections)
        matches, u_unconfirmed, u_detection = matching.linear_assignment(
            dists, thresh=0.7)
        for itracked, idet in matches:
            unconfirmed[itracked].update(detections[idet], self.frame_id)
            activated_starcks.append(unconfirmed[itracked])
        for it in u_unconfirmed:
            track = unconfirmed[it]
            track.mark_removed()
            removed_stracks.append(track)
        """ Step 4: Init new stracks"""
        for inew in u_detection:
            track = detections[inew]
            if track.score < self.det_thresh:
                continue
            track.activate(self.kalman_filter, self.frame_id)
            activated_starcks.append(track)
        """ Step 5: Update state"""
        for track in self.lost_stracks:
            if self.frame_id - track.end_frame > self.max_time_lost:
                track.mark_removed()
                removed_stracks.append(track)

        # print('Remaining match {} s'.format(t4 - t3))

        self.tracked_stracks = [
            t for t in self.tracked_stracks if t.state == TrackState.Tracked
        ]
        self.tracked_stracks = joint_stracks(self.tracked_stracks,
                                             activated_starcks)
        self.tracked_stracks = joint_stracks(self.tracked_stracks,
                                             refind_stracks)
        self.lost_stracks = sub_stracks(self.lost_stracks,
                                        self.tracked_stracks)
        self.lost_stracks.extend(lost_stracks)
        self.lost_stracks = sub_stracks(self.lost_stracks,
                                        self.removed_stracks)
        self.removed_stracks.extend(removed_stracks)
        self.tracked_stracks, self.lost_stracks = remove_duplicate_stracks(
            self.tracked_stracks, self.lost_stracks)
        # collect the activated tracks as this frame's output
        output_stracks = [
            track for track in self.tracked_stracks if track.is_activated
        ]

        logger.debug('===========Frame {}=========='.format(self.frame_id))
        logger.debug('Activated: {}'.format(
            [track.track_id for track in activated_starcks]))
        logger.debug('Refind: {}'.format(
            [track.track_id for track in refind_stracks]))
        logger.debug('Lost: {}'.format(
            [track.track_id for track in lost_stracks]))
        logger.debug('Removed: {}'.format(
            [track.track_id for track in removed_stracks]))

        return output_stracks
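
joint_stracks and sub_stracks, used throughout Steps 2-5, merge and subtract track lists by track_id. A sketch consistent with how they are called here; the real implementations may differ in detail:

def joint_stracks(tlista, tlistb):
    """Union of two track lists, deduplicated by track_id (first wins)."""
    seen = {t.track_id for t in tlista}
    res = list(tlista)
    for t in tlistb:
        if t.track_id not in seen:
            seen.add(t.track_id)
            res.append(t)
    return res

def sub_stracks(tlista, tlistb):
    """Tracks of tlista whose track_id does not appear in tlistb."""
    drop = {t.track_id for t in tlistb}
    return [t for t in tlista if t.track_id not in drop]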
Example #6
    def update(self, im_blob, img0):
        self.frame_id += 1

        # record tracking results; the original single-class lists are
        # generalized to multi-class via defaultdict(list), keyed by class id
        activated_starcks_dict = defaultdict(list)
        refind_stracks_dict = defaultdict(list)
        lost_stracks_dict = defaultdict(list)
        removed_stracks_dict = defaultdict(list)
        output_stracks_dict = defaultdict(list)

        width = img0.shape[1]
        height = img0.shape[0]
        inp_height = im_blob.shape[2]
        inp_width = im_blob.shape[3]

        c = np.array([width / 2., height / 2.], dtype=np.float32)
        s = max(float(inp_width) / float(inp_height) * height, width) * 1.0
        meta = {
            'c': c,
            's': s,
            'out_height': inp_height // self.opt.down_ratio,
            'out_width': inp_width // self.opt.down_ratio
        }
        ''' Step 1: Network forward, get detections & embeddings'''
        with torch.no_grad():  # no gradient bookkeeping needed for inference
            output = self.model.forward(im_blob)[-1]

            hm = output['hm'].sigmoid_()
            # print("hm shape ", hm.shape, "hm:\n", hm)

            wh = output['wh']
            # print("wh shape ", wh.shape, "wh:\n", wh)

            id_feature = output['id']
            id_feature = F.normalize(id_feature, dim=1)

            reg = output['reg'] if self.opt.reg_offset else None
            # print("reg shape ", reg.shape, "reg:\n", reg)

            # decode detection and classification results
            dets, inds, cls_inds_mask = mot_decode(
                heatmap=hm,
                wh=wh,
                reg=reg,
                num_classes=self.opt.num_classes,
                cat_spec_wh=self.opt.cat_spec_wh,
                K=self.opt.K)

            # ----- parse the output per detection class and keep intermediate results
            cls_id_feats = []  # top-K feature vectors for each object class
            for cls_id in range(self.opt.num_classes):  # cls_id starts from 0
                # indices belonging to this object class
                cls_inds = inds[:, cls_inds_mask[cls_id]]

                # gather the feature vectors used for Re-ID
                cls_id_feature = _tranpose_and_gather_feat(
                    id_feature, cls_inds)  # inds: 1×128
                cls_id_feature = cls_id_feature.squeeze(0)  # n × FeatDim
                cls_id_feature = cls_id_feature.cpu().numpy()  # finally move to the CPU
                cls_id_feats.append(cls_id_feature)

        # detection post-processing
        dets = self.post_process(dets, meta)
        dets = self.merge_outputs([dets])
        # dets = self.merge_outputs(dets)[1]

        # ----- parse each detection class
        for cls_id in range(self.opt.num_classes):  # cls_id starts from 0
            cls_dets = dets[cls_id + 1]
            '''
            # visualize intermediate detection results (per class)
            for i in range(0, cls_dets.shape[0]):
                bbox = cls_dets[i][0:4]
                cv2.rectangle(img0,
                              (bbox[0], bbox[1]),  # left-top point
                              (bbox[2], bbox[3]),  # right-down point
                              [0, 255, 255],  # yellow
                              2)
                cv2.putText(img0,
                            id2cls[cls_id],
                            (bbox[0], bbox[1]),
                            cv2.FONT_HERSHEY_PLAIN,
                            1.3,
                            [0, 0, 255],  # red
                            2)
            cv2.imshow('{}'.format(id2cls[cls_id]), img0)
            cv2.waitKey(0)
            '''

            # filter out detections with low confidence scores
            remain_inds = cls_dets[:, 4] > self.opt.conf_thres
            cls_dets = cls_dets[remain_inds]
            cls_id_feature = cls_id_feats[cls_id][remain_inds]

            if len(cls_dets) > 0:
                '''Detections, tlbrs: top left bottom right score'''
                cls_detections = [
                    STrack(STrack.tlbr_to_tlwh(tlbrs[:4]),
                           tlbrs[4],
                           feat,
                           buff_size=30)
                    for (tlbrs, feat) in zip(cls_dets[:, :5], cls_id_feature)
                ]
            else:
                cls_detections = []

            # reset the track ids for a different object class
            for track in cls_detections:
                track.reset_track_id()
            ''' Add newly detected tracklets to tracked_stracks'''
            unconfirmed_dict = defaultdict(list)
            tracked_stracks_dict = defaultdict(list)  # cls_id -> list[STrack]
            for track in self.tracked_stracks_dict[cls_id]:
                if not track.is_activated:
                    unconfirmed_dict[cls_id].append(track)
                else:
                    tracked_stracks_dict[cls_id].append(track)
            ''' Step 2: First association, with embedding'''
            strack_pool_dict = defaultdict(list)
            strack_pool_dict[cls_id] = joint_stracks(
                tracked_stracks_dict[cls_id], self.lost_stracks_dict[cls_id])

            # Predict the current location with KF
            # for strack in strack_pool:
            STrack.multi_predict(strack_pool_dict[cls_id])
            dists = matching.embedding_distance(strack_pool_dict[cls_id],
                                                cls_detections)
            dists = matching.fuse_motion(self.kalman_filter, dists,
                                         strack_pool_dict[cls_id],
                                         cls_detections)
            matches, u_track, u_detection = matching.linear_assignment(
                dists, thresh=0.7)

            for i_tracked, i_det in matches:
                track = strack_pool_dict[cls_id][i_tracked]
                det = cls_detections[i_det]
                if track.state == TrackState.Tracked:
                    track.update(cls_detections[i_det], self.frame_id)
                    activated_starcks_dict[cls_id].append(
                        track)  # for multi-class
                else:
                    track.re_activate(det, self.frame_id, new_id=False)
                    refind_stracks_dict[cls_id].append(track)
            ''' Step 3: Second association, with IOU'''
            cls_detections = [cls_detections[i] for i in u_detection]
            r_tracked_stracks = [
                strack_pool_dict[cls_id][i] for i in u_track
                if strack_pool_dict[cls_id][i].state == TrackState.Tracked
            ]
            dists = matching.iou_distance(r_tracked_stracks, cls_detections)
            matches, u_track, u_detection = matching.linear_assignment(
                dists, thresh=0.5)

            for i_tracked, i_det in matches:
                track = r_tracked_stracks[i_tracked]
                det = cls_detections[i_det]
                if track.state == TrackState.Tracked:
                    track.update(det, self.frame_id)
                    activated_starcks_dict[cls_id].append(track)
                else:
                    track.re_activate(det, self.frame_id, new_id=False)
                    refind_stracks_dict[cls_id].append(track)

            for it in u_track:
                track = r_tracked_stracks[it]
                if not track.state == TrackState.Lost:
                    track.mark_lost()
                    lost_stracks_dict[cls_id].append(track)
            '''Deal with unconfirmed tracks, usually tracks with only one beginning frame'''
            cls_detections = [cls_detections[i] for i in u_detection]
            dists = matching.iou_distance(unconfirmed_dict[cls_id],
                                          cls_detections)
            matches, u_unconfirmed, u_detection = matching.linear_assignment(
                dists, thresh=0.7)
            for i_tracked, i_det in matches:
                unconfirmed_dict[cls_id][i_tracked].update(
                    cls_detections[i_det], self.frame_id)
                activated_starcks_dict[cls_id].append(
                    unconfirmed_dict[cls_id][i_tracked])
            for it in u_unconfirmed:
                track = unconfirmed_dict[cls_id][it]
                track.mark_removed()
                removed_stracks_dict[cls_id].append(track)
            """ Step 4: Init new stracks"""
            for i_new in u_detection:
                track = cls_detections[i_new]

                if track.score < self.det_thresh:
                    continue

                track.activate(self.kalman_filter, self.frame_id)
                activated_starcks_dict[cls_id].append(track)
            """ Step 5: Update state"""
            for track in self.lost_stracks_dict[cls_id]:
                if self.frame_id - track.end_frame > self.max_time_lost:
                    track.mark_removed()
                    removed_stracks_dict[cls_id].append(track)

            # print('Remaining match {} s'.format(t4 - t3))
            self.tracked_stracks_dict[cls_id] = [
                t for t in self.tracked_stracks_dict[cls_id]
                if t.state == TrackState.Tracked
            ]
            self.tracked_stracks_dict[cls_id] = joint_stracks(
                self.tracked_stracks_dict[cls_id],
                activated_starcks_dict[cls_id])
            self.tracked_stracks_dict[cls_id] = joint_stracks(
                self.tracked_stracks_dict[cls_id], refind_stracks_dict[cls_id])
            self.lost_stracks_dict[cls_id] = sub_stracks(
                self.lost_stracks_dict[cls_id],
                self.tracked_stracks_dict[cls_id])
            self.lost_stracks_dict[cls_id].extend(lost_stracks_dict[cls_id])
            self.lost_stracks_dict[cls_id] = sub_stracks(
                self.lost_stracks_dict[cls_id],
                self.removed_stracks_dict[cls_id])
            self.removed_stracks_dict[cls_id].extend(
                removed_stracks_dict[cls_id])
            self.tracked_stracks_dict[cls_id], self.lost_stracks_dict[
                cls_id] = remove_duplicate_stracks(
                    self.tracked_stracks_dict[cls_id],
                    self.lost_stracks_dict[cls_id])

            # collect the activated tracks as this frame's output
            output_stracks_dict[cls_id] = [
                track for track in self.tracked_stracks_dict[cls_id]
                if track.is_activated
            ]

            logger.debug('===========Frame {}=========='.format(self.frame_id))
            logger.debug('Activated: {}'.format(
                [track.track_id for track in activated_starcks_dict[cls_id]]))
            logger.debug('Refind: {}'.format(
                [track.track_id for track in refind_stracks_dict[cls_id]]))
            logger.debug('Lost: {}'.format(
                [track.track_id for track in lost_stracks_dict[cls_id]]))
            logger.debug('Removed: {}'.format(
                [track.track_id for track in removed_stracks_dict[cls_id]]))

        return output_stracks_dict
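
matching.linear_assignment, called in every association step, solves a cost-thresholded assignment problem. A minimal sketch built on scipy; the actual implementation may use the lap package's lapjv with a cost limit:

import numpy as np
from scipy.optimize import linear_sum_assignment

def linear_assignment(cost_matrix, thresh):
    """Sketch: Hungarian matching; pairs costing more than thresh stay unmatched."""
    if cost_matrix.size == 0:
        return (np.empty((0, 2), dtype=int),
                tuple(range(cost_matrix.shape[0])),
                tuple(range(cost_matrix.shape[1])))
    # discourage above-threshold assignments, then filter them out
    cost = np.where(cost_matrix > thresh, thresh + 1e-4, cost_matrix)
    rows, cols = linear_sum_assignment(cost)
    matches = [(r, c) for r, c in zip(rows, cols)
               if cost_matrix[r, c] <= thresh]
    u_track = tuple(sorted(set(range(cost_matrix.shape[0]))
                           - {r for r, _ in matches}))
    u_detection = tuple(sorted(set(range(cost_matrix.shape[1]))
                               - {c for _, c in matches}))
    return np.asarray(matches, dtype=int).reshape(-1, 2), u_track, u_detection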