Пример #1
0
def test_net(tdcnn_demo, dataloader, args):
    """Run the temporal detector over ``dataloader`` and print per-class detections.

    For every video and every foreground class, RoI scores above a threshold
    are sorted, suppressed with ``nms(..., cfg.TEST.NMS)`` and printed. The
    surviving detections are also collected per (class, video).

    Args:
        tdcnn_demo: model called as ``tdcnn_demo(video_data, gt_twins)`` and
            returning ``(rois, cls_prob, twin_pred)``.
        dataloader: iterable of ``(video_data, gt_twins, num_gt, video_info)``.
        args: namespace providing ``vis``, ``num_videos`` and ``num_classes``.

    Returns:
        ``all_twins`` where ``all_twins[j][v]`` is a numpy array of
        ``[start, end, score]`` rows for class ``j`` in video ``v``
        (class 0, the background, is never filled).
    """
    start = time.time()
    # TODO: Add restriction for max_per_video
    max_per_video = 0

    # Visualization uses a stricter score threshold.
    thresh = 0.05 if args.vis else 0.005

    # all_twins[class][video] -> array of [start, end, score] rows.
    all_twins = [[[] for _ in range(args.num_videos)]
                 for _ in range(args.num_classes)]

    tdcnn_demo.eval()
    # Shape (0, 3): "no detections" placeholder that still supports [:, -1].
    empty_array = np.transpose(np.array([[], [], []]), (1, 0))

    # Global index of the first video in the current batch. Indexing with
    # batch_idx * batch_size would be wrong (and overwrite earlier results)
    # whenever the final batch is smaller than the others.
    video_offset = 0
    data_tic = time.time()
    for video_data, gt_twins, _, video_info in dataloader:
        video_data = video_data.cuda()
        gt_twins = gt_twins.cuda()
        batch_size = video_data.shape[0]
        data_time = time.time() - data_tic

        det_tic = time.time()
        rois, cls_prob, twin_pred = tdcnn_demo(video_data, gt_twins)

        # scores_all is indexed as scores_all[b][:, class] below, i.e. the
        # class axis is the last one.
        scores_all = cls_prob.data
        twins = rois.data[:, :, 1:3]

        if cfg.TEST.TWIN_REG:
            # Apply twin (segment) regression deltas.
            twin_deltas = twin_pred.data
            if cfg.TRAIN.TWIN_NORMALIZE_TARGETS_PRECOMPUTED:
                # Undo the target normalization with the precomputed mean/std.
                stds = torch.FloatTensor(
                    cfg.TRAIN.TWIN_NORMALIZE_STDS).type_as(twin_deltas)
                means = torch.FloatTensor(
                    cfg.TRAIN.TWIN_NORMALIZE_MEANS).type_as(twin_deltas)
                twin_deltas = twin_deltas.view(-1, 2) * stds + means
                twin_deltas = twin_deltas.view(batch_size, -1,
                                               2 * args.num_classes)

            pred_twins_all = twin_transform_inv(twins, twin_deltas)
            pred_twins_all = clip_twins(pred_twins_all, cfg.TRAIN.LENGTH[0])
        else:
            # Repeat each RoI twin once per class along the last axis so that
            # columns [2j, 2j + 2) hold class j's segment. scores_all.shape[1]
            # is the RoI axis (not the class axis) and np.tile cannot take a
            # CUDA tensor, so tile with torch over the class count instead.
            pred_twins_all = twins.repeat(1, 1, scores_all.shape[-1])

        detect_time = time.time() - det_tic

        for b in range(batch_size):
            video_idx = video_offset + b
            misc_tic = time.time()
            print(video_info[b])
            scores = scores_all[b]
            pred_twins = pred_twins_all[b]

            # Skip j = 0, because it's the background class.
            for j in range(1, args.num_classes):
                inds = torch.nonzero(scores[:, j] > thresh).view(-1)
                if inds.numel() == 0:
                    all_twins[j][video_idx] = empty_array
                    continue

                cls_scores = scores[:, j][inds]
                _, order = torch.sort(cls_scores, 0, True)
                cls_twins = pred_twins[inds][:, j * 2:(j + 1) * 2]
                cls_dets = torch.cat((cls_twins, cls_scores.unsqueeze(1)), 1)
                cls_dets = cls_dets[order]
                keep = nms(cls_dets, cfg.TEST.NMS)
                if len(keep) > 0:
                    cls_dets = cls_dets[keep.view(-1).long()]
                    # Log-analysis scripts may parse detections from these prints.
                    print("activity: ", j)
                    print(cls_dets.cpu().numpy())
                all_twins[j][video_idx] = cls_dets.cpu().numpy()

            # Limit to max_per_video detections *over all classes*
            # (inactive while max_per_video == 0).
            if max_per_video > 0:
                video_scores = np.hstack([all_twins[j][video_idx][:, -1]
                                          for j in range(1, args.num_classes)])
                if len(video_scores) > max_per_video:
                    video_thresh = np.sort(video_scores)[-max_per_video]
                    for j in range(1, args.num_classes):
                        dets = all_twins[j][video_idx]
                        all_twins[j][video_idx] = dets[dets[:, -1] >= video_thresh]

            nms_time = time.time() - misc_tic
            print('im_detect: {:d}/{:d} {:.3f}s {:.3f}s {:.3f}s'
                  .format(video_idx + 1, args.num_videos,
                          data_time / batch_size, detect_time / batch_size,
                          nms_time))

        video_offset += batch_size
        data_tic = time.time()

    end = time.time()
    print("test time: %0.4fs" % (end - start))
    return all_twins
Пример #2
0
        twins = rois.data[:, :, 1:3]
        if cfg.TEST.TWIN_REG:
            # Apply bounding-twin regression deltas
            twin_deltas = twin_pred.data
            if cfg.TRAIN.TWIN_NORMALIZE_TARGETS_PRECOMPUTED:
                # Optionally normalize targets by a precomputed mean and stdev
                if cfg.AGNOSTIC:
                    twin_deltas = twin_deltas.view(-1, 2) * torch.FloatTensor(cfg.TRAIN.TWIN_NORMALIZE_STDS).cuda() \
                               + torch.FloatTensor(cfg.TRAIN.TWIN_NORMALIZE_MEANS).cuda()
                    twin_deltas = twin_deltas.view(1, -1, 2)
                else:
                    twin_deltas = twin_deltas.view(-1, 2) * torch.FloatTensor(cfg.TRAIN.TWIN_NORMALIZE_STDS).cuda() \
                               + torch.FloatTensor(cfg.TRAIN.TWIN_NORMALIZE_MEANS).cuda()
                    twin_deltas = twin_deltas.view(1, -1, 2 * args.num_classes)

            pred_twins = twin_transform_inv(twins, twin_deltas, 1)
            pred_twins = clip_twins(pred_twins, cfg.TRAIN.LENGTH[0], 1)
        else:
            # Simply repeat the twins, once for each class
            pred_twins = np.tile(twins, (1, scores.shape[1]))

        # pred_twins /= data[1][0][2]

        scores = scores.squeeze()
        pred_twins = pred_twins.squeeze()
        det_toc = time.time()
        detect_time = det_toc - det_tic
        misc_tic = time.time()
        if vis:
            #im = cv2.imread(imdb.video_path_at(i))
            #im2show = np.copy(im)
Пример #3
0
def test_net(tdcnn_demo, dataloader, args):
    """Run the temporal detector over ``dataloader`` and print per-class detections.

    For every video and every foreground class, RoI scores above a threshold
    are sorted, suppressed with ``nms(..., cfg.TEST.NMS)`` and printed. The
    surviving detections are also collected per (class, video).

    Args:
        tdcnn_demo: model called as ``tdcnn_demo(video_data, gt_twins)`` and
            returning ``(rois, cls_prob, twin_pred)``.
        dataloader: iterable of ``(video_data, gt_twins, num_gt, video_info)``.
        args: namespace providing ``vis``, ``num_videos`` and ``num_classes``.

    Returns:
        ``all_twins`` where ``all_twins[j][v]`` is a numpy array of
        ``[start, end, score]`` rows for class ``j`` in video ``v``
        (class 0, the background, is never filled).
    """
    start = time.time()
    # TODO: Add restriction for max_per_video
    max_per_video = 0

    # Visualization uses a stricter score threshold.
    thresh = 0.05 if args.vis else 0.005

    # Nested list of shape num_classes x num_videos (e.g. 21 x 10563):
    # all_twins[class][video] -> array of [start, end, score] rows.
    all_twins = [[[] for _ in range(args.num_videos)]
                 for _ in range(args.num_classes)]

    tdcnn_demo.eval()
    # Shape (0, 3): "no detections" placeholder that still supports [:, -1].
    empty_array = np.transpose(np.array([[], [], []]), (1, 0))

    # Global index of the first video in the current batch. Indexing with
    # batch_idx * batch_size would be wrong (and overwrite earlier results)
    # whenever the final batch is smaller than the others.
    video_offset = 0
    data_tic = time.time()  # start of detection
    for video_data, gt_twins, _, video_info in dataloader:
        video_data = video_data.cuda()
        gt_twins = gt_twins.cuda()
        batch_size = video_data.shape[0]
        data_time = time.time() - data_tic  # data loading time

        det_tic = time.time()  # start of model prediction
        # Example shapes (from the original notes):
        # rois: [1, 900, 3], cls_prob: [1, 900, 21], twin_pred: [1, 900, 42]
        rois, cls_prob, twin_pred = tdcnn_demo(video_data, gt_twins)

        scores_all = cls_prob.data
        twins = rois.data[:, :, 1:3]

        if cfg.TEST.TWIN_REG:
            # Apply twin (segment) regression deltas.
            twin_deltas = twin_pred.data
            if cfg.TRAIN.TWIN_NORMALIZE_TARGETS_PRECOMPUTED:
                # Undo the target normalization with the precomputed mean/std.
                stds = torch.FloatTensor(
                    cfg.TRAIN.TWIN_NORMALIZE_STDS).type_as(twin_deltas)
                means = torch.FloatTensor(
                    cfg.TRAIN.TWIN_NORMALIZE_MEANS).type_as(twin_deltas)
                twin_deltas = twin_deltas.view(-1, 2) * stds + means
                twin_deltas = twin_deltas.view(batch_size, -1,
                                               2 * args.num_classes)

            pred_twins_all = twin_transform_inv(twins, twin_deltas,
                                                batch_size)  # e.g. [1, 900, 42]
            pred_twins_all = clip_twins(pred_twins_all, cfg.TRAIN.LENGTH[0],
                                        batch_size)
        else:
            # Repeat each RoI twin once per class along the last axis so that
            # columns [2j, 2j + 2) hold class j's segment. scores_all.shape[1]
            # is the RoI axis (not the class axis) and np.tile cannot take a
            # CUDA tensor, so tile with torch over the class count instead.
            pred_twins_all = twins.repeat(1, 1, scores_all.shape[-1])

        detect_time = time.time() - det_tic  # model prediction time

        # Extract detections for each video in the batch.
        for b in range(batch_size):
            video_idx = video_offset + b
            misc_tic = time.time()  # start of result extraction
            print(video_info[b])
            # Drop the leading batch axis.
            scores = scores_all[b]
            pred_twins = pred_twins_all[b]

            # Collect detections class by class.
            # Skip j = 0, because it's the background class.
            for j in range(1, args.num_classes):
                # Indices of RoIs whose class-j score exceeds the threshold.
                inds = torch.nonzero(scores[:, j] > thresh).view(-1)
                if inds.numel() == 0:
                    all_twins[j][video_idx] = empty_array
                    continue

                cls_scores = scores[:, j][inds]
                _, order = torch.sort(cls_scores, 0, True)  # descending scores
                # Class-j segment columns of the selected RoIs.
                cls_twins = pred_twins[inds][:, j * 2:(j + 1) * 2]
                # Rows of [start, end, score].
                cls_dets = torch.cat((cls_twins, cls_scores.unsqueeze(1)), 1)
                cls_dets = cls_dets[order]
                keep = nms(cls_dets, cfg.TEST.NMS)
                if len(keep) > 0:
                    cls_dets = cls_dets[keep.view(-1).long()]
                    # This print matters: later log-analysis scripts read the
                    # detection results from this output.
                    print("activity: ", j)
                    print(cls_dets.cpu().numpy())
                all_twins[j][video_idx] = cls_dets.cpu().numpy()

            # Limit to max_per_video detections *over all classes*
            # (inactive while max_per_video == 0).
            if max_per_video > 0:
                video_scores = np.hstack([all_twins[j][video_idx][:, -1]
                                          for j in range(1, args.num_classes)])
                if len(video_scores) > max_per_video:
                    video_thresh = np.sort(video_scores)[-max_per_video]
                    for j in range(1, args.num_classes):
                        dets = all_twins[j][video_idx]
                        all_twins[j][video_idx] = dets[dets[:, -1] >= video_thresh]

            nms_time = time.time() - misc_tic  # end of result extraction
            # Per-video data-loading, prediction and extraction times.
            print('im_detect: {:d}/{:d} {:.3f}s {:.3f}s {:.3f}s'
                  .format(video_idx + 1, args.num_videos,
                          data_time / batch_size, detect_time / batch_size,
                          nms_time))

        video_offset += batch_size
        data_tic = time.time()

    end = time.time()
    print("test time: %0.4fs" % (end - start))  # total test time
    return all_twins
Пример #4
0
def test_net(tdcnn_demo, dataloader, args, split, max_per_video=0, thresh=0):
    """Evaluate a support-conditioned temporal detector and return its mAP.

    Detections are collected per (class, video). Class-1 detections are
    written to ``<logger dir>/<split>_pred.json`` and scored against
    ``preprocess/<dataset>/<split>_gt.json`` with ``ANETdetection``.

    Args:
        tdcnn_demo: model called as
            ``tdcnn_demo(video_data, gt_twins, support_data)`` returning
            ``(rois, cls_prob, twin_pred)``.
        dataloader: iterable of
            ``(support_data, video_data, gt_twins, num_gt, video_info)``.
        args: namespace providing ``batch_size``, ``num_classes``, ``shot``,
            ``test_nms`` and ``dataset``.
        split: split name used for the prediction / ground-truth file names.
        max_per_video: if > 0, keep at most this many detections per video
            over all classes.
        thresh: minimum class score for an RoI to be considered.

    Returns:
        ``(mAP, ap)`` where ``ap`` is ``anet_detection.mAP`` and ``mAP`` is
        ``ap[0]``.
    """
    np.random.seed(cfg.RNG_SEED)
    # Capacity for per-video results; the last batch may be smaller, so the
    # actual number of videos is tracked separately in num_videos_seen.
    total_video_num = len(dataloader) * args.batch_size

    # all_twins[class][video] -> array of [start, end, score] rows.
    all_twins = [[[] for _ in range(total_video_num)]
                 for _ in range(args.num_classes)]
    tdcnn_demo.eval()
    # Shape (0, 3): "no detections" placeholder that still supports [:, -1].
    empty_array = np.transpose(np.array([[], [], []]), (1, 0))

    # Videos processed so far; also the global index of the first video of
    # the current batch. data_idx * batch_size would mis-index (and
    # overwrite earlier results) when the final batch is smaller.
    num_videos_seen = 0
    for data_idx, (support_data, video_data, gt_twins, _,
                   video_info) in tqdm(enumerate(dataloader),
                                       total=len(dataloader),
                                       desc="evaluation"):
        if is_debug and data_idx > fast_eval_samples:
            break

        video_data = video_data.cuda()
        for i in range(args.shot):
            support_data[i] = support_data[i].cuda()
        gt_twins = gt_twins.cuda()
        batch_size = video_data.shape[0]
        # Example shapes: [1, 300, 3], [1, 300, 2], [1, 300, 4]
        rois, cls_prob, twin_pred = tdcnn_demo(video_data, gt_twins,
                                               support_data)
        scores_all = cls_prob.data
        twins = rois.data[:, :, 1:3]

        if cfg.TEST.TWIN_REG:
            # Apply twin (segment) regression deltas.
            twin_deltas = twin_pred.data
            if cfg.TRAIN.TWIN_NORMALIZE_TARGETS_PRECOMPUTED:
                # Undo the target normalization with the precomputed mean/std.
                stds = torch.FloatTensor(
                    cfg.TRAIN.TWIN_NORMALIZE_STDS).type_as(twin_deltas)
                means = torch.FloatTensor(
                    cfg.TRAIN.TWIN_NORMALIZE_MEANS).type_as(twin_deltas)
                twin_deltas = twin_deltas.view(-1, 2) * stds + means
                twin_deltas = twin_deltas.view(batch_size, -1,
                                               2 * args.num_classes)

            pred_twins_all = twin_transform_inv(twins, twin_deltas, batch_size)
            pred_twins_all = clip_twins(pred_twins_all, cfg.TRAIN.LENGTH[0],
                                        batch_size)
        else:
            # Repeat each RoI twin once per class along the last axis so that
            # columns [2j, 2j + 2) hold class j's segment. scores_all.shape[1]
            # is the RoI axis (not the class axis) and np.tile cannot take a
            # CUDA tensor, so tile with torch over the class count instead.
            pred_twins_all = twins.repeat(1, 1, scores_all.shape[-1])

        for b in range(batch_size):
            video_idx = num_videos_seen + b
            if is_debug:
                logger.info(video_info)
            scores = scores_all[b]
            pred_twins = pred_twins_all[b]

            # Skip j = 0, because it's the background class.
            for j in range(1, args.num_classes):
                inds = torch.nonzero(scores[:, j] > thresh).view(-1)
                if inds.numel() == 0:
                    all_twins[j][video_idx] = empty_array
                    continue

                cls_scores = scores[:, j][inds]
                _, order = torch.sort(cls_scores, 0, True)
                cls_twins = pred_twins[inds][:, j * 2:(j + 1) * 2]
                cls_dets = torch.cat((cls_twins, cls_scores.unsqueeze(1)), 1)
                cls_dets = cls_dets[order]
                keep = nms_cpu(cls_dets.cpu(), args.test_nms)
                if len(keep) > 0:
                    if is_debug:
                        print("after nms, keep {}".format(len(keep)))
                    cls_dets = cls_dets[keep.view(-1).long()]
                else:
                    print("warning, after nms, none of the rois is kept!!!")
                all_twins[j][video_idx] = cls_dets.cpu().numpy()

            # Limit to max_per_video detections *over all classes*
            # (inactive with the default max_per_video = 0).
            if max_per_video > 0:
                video_scores = np.hstack([all_twins[j][video_idx][:, -1]
                                          for j in range(1, args.num_classes)])
                if len(video_scores) > max_per_video:
                    video_thresh = np.sort(video_scores)[-max_per_video]
                    for j in range(1, args.num_classes):
                        dets = all_twins[j][video_idx]
                        all_twins[j][video_idx] = dets[dets[:, -1] >= video_thresh]

        num_videos_seen += batch_size

    pred = dict()
    pred['external_data'] = ''
    pred['version'] = ''
    pred['results'] = dict()
    # Only videos that were actually processed get an entry; padded slots in
    # all_twins (beyond a short final batch) are not real videos.
    for i_video in tqdm(range(num_videos_seen),
                        desc="generating prediction json.."):
        if is_debug and i_video > fast_eval_samples * batch_size - 2:
            break
        # Binary problem: only class 1 is exported; class 0 is background.
        item_pre = []
        for det in all_twins[1][i_video]:
            item_pre.append({
                'score': det[2].item(),
                'label': 'c1',
                'segment': [det[0].item(), det[1].item()],
            })
        pred['results']["query_%05d" % i_video] = item_pre

    predict_filename = os.path.join(logger.get_logger_dir(),
                                    '{}_pred.json'.format(split))
    ground_truth_filename = os.path.join('preprocess/{}'.format(args.dataset),
                                         '{}_gt.json'.format(split))

    with open(predict_filename, 'w') as f:
        json.dump(pred, f)
        logger.info('dump pred.json complete..')

    # Avoid growing sys.path by one entry on every evaluation call.
    if "evaluation" not in sys.path:
        sys.path.insert(0, "evaluation")
    from eval_detection import ANETdetection

    anet_detection = ANETdetection(ground_truth_filename,
                                   predict_filename,
                                   subset="test",
                                   tiou_thresholds=tiou_thresholds,
                                   verbose=True,
                                   check_status=False)
    anet_detection.evaluate()
    ap = anet_detection.mAP
    mAP = ap[0]
    return mAP, ap
Пример #5
0
def test_net(tdcnn_demo,
             dataloader,
             args,
             trimmed_support_set_roidb,
             thresh=0.7,
             use_softmax_fewshot_score=False,
             fewshot_features_path='/home/vltava/fewshot_features_5_shot.pkl'):
    """Run few-shot temporal detection and print per-label detections.

    Support-set features are loaded from ``fewshot_features_path`` if it
    exists. Otherwise they are computed with ``tdcnn_demo`` from a sampled
    support set and cached there. For each test batch, a support set of
    ``batch_support_set_size`` labels is built from the labels present in the
    batch plus randomly sampled other labels. RoIs are then scored against
    those labels and thresholded, and the results are suppressed with NMS and
    printed.

    Args:
        tdcnn_demo: few-shot detection model.
        dataloader: iterable of
            ``(video_data, gt_twins, num_gt, video_info, fewshot_label)``.
        args: namespace providing ``batch_size``, ``num_workers``,
            ``num_classes`` and ``num_videos``.
        trimmed_support_set_roidb: roidb used to build the support set when
            no cached features exist.
        thresh: minimum few-shot score for an RoI to count as a detection.
        use_softmax_fewshot_score: threshold the summed softmax scores instead
            of the mean raw scores.
        fewshot_features_path: pickle cache of ``(features, labels)``.
    """
    start = time.time()

    tdcnn_demo.eval()

    if exists(fewshot_features_path):
        # pickle can execute arbitrary code: only load a trusted cache such as
        # the one written below.
        with open(fewshot_features_path, 'rb') as f:
            all_fewshot_features = pickle.load(f)
    else:
        support_set_dataloader = create_sampled_support_set_dataset(
            trimmed_support_set_roidb,
            [2, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 18, 19],
            batch_size=args.batch_size,
            num_workers=args.num_workers,
            samples_per_class=5)
        print('Loading fewshot trimmed videos')
        feature_vectors = []
        feature_labels = []
        for video_data, gt_twins, _ in support_set_dataloader:
            support_set_features, support_set_labels = \
                tdcnn_demo(video_data, None, None, gt_twins, True)
            feature_vectors.append(support_set_features)
            feature_labels.append(support_set_labels)
            del video_data
        all_fewshot_features = (torch.cat(feature_vectors),
                                torch.cat(feature_labels))

        with open(fewshot_features_path, 'wb') as f:
            pickle.dump(all_fewshot_features, f)

    support_set_features = all_fewshot_features[0]
    support_set_labels = all_fewshot_features[1]

    unique_support_set_labels = support_set_labels.cpu().unique(
        sorted=True).cuda()

    batch_support_set_size = 5

    # Global index of the first video in the current batch (correct even
    # when the final batch is smaller than the others).
    video_offset = 0
    data_tic = time.time()
    for video_data, gt_twins, _, video_info, fewshot_label in dataloader:
        video_data = video_data.cuda()
        gt_twins = gt_twins.cuda()
        batch_size = video_data.shape[0]
        data_time = time.time() - data_tic

        # Support labels for this batch: those present in the batch, padded
        # with randomly sampled other labels up to batch_support_set_size.
        unique_labels_in_test = fewshot_label.cpu().unique().numpy().tolist()
        batch_support_set_labels = unique_labels_in_test
        if len(unique_labels_in_test) < batch_support_set_size:
            other_labels = torch.cat([
                l.unsqueeze(0) for l in unique_support_set_labels
                if l not in unique_labels_in_test
            ]).cpu().numpy().tolist()
            batch_support_set_labels = unique_labels_in_test + random.sample(
                other_labels,
                batch_support_set_size - len(unique_labels_in_test))
        batch_support_set_indices = [
            feat_idx for feat_idx, label in enumerate(support_set_labels)
            if label in batch_support_set_labels
        ]
        batch_support_set_features = support_set_features[
            batch_support_set_indices]
        batch_support_set_labels = support_set_labels[
            batch_support_set_indices]
        unique_batch_support_set_labels = batch_support_set_labels.cpu(
        ).unique(sorted=True).cuda()

        det_tic = time.time()
        # Replicate the support set once per video in the *actual* batch;
        # the final batch may be smaller than args.batch_size.
        rois, cls_prob, twin_pred, fewshot_scores, fewshot_scores_softmax = \
            tdcnn_demo(video_data,
                       torch.cat(batch_size * [batch_support_set_features.unsqueeze(0)]),
                       torch.cat(batch_size * [batch_support_set_labels.unsqueeze(0)]),
                       gt_twins)

        scores_all = fewshot_scores.data
        twins = rois.data[:, :, 1:3]

        if cfg.TEST.TWIN_REG:
            # Apply twin (segment) regression deltas.
            twin_deltas = twin_pred.data
            if cfg.TRAIN.TWIN_NORMALIZE_TARGETS_PRECOMPUTED:
                # Undo the target normalization with the precomputed mean/std.
                stds = torch.FloatTensor(
                    cfg.TRAIN.TWIN_NORMALIZE_STDS).type_as(twin_deltas)
                means = torch.FloatTensor(
                    cfg.TRAIN.TWIN_NORMALIZE_MEANS).type_as(twin_deltas)
                twin_deltas = twin_deltas.view(-1, 2) * stds + means
                twin_deltas = twin_deltas.view(batch_size, -1,
                                               2 * args.num_classes)

            pred_twins_all = twin_transform_inv(twins, twin_deltas, batch_size)
            pred_twins_all = clip_twins(pred_twins_all, cfg.TRAIN.LENGTH[0],
                                        batch_size)
        else:
            # Repeat each RoI twin once per class so the layout matches the
            # regressed path (2 * num_classes columns indexed by label id).
            # scores_all.shape[1] is the RoI axis, and np.tile cannot take a
            # CUDA tensor.
            pred_twins_all = twins.repeat(1, 1, args.num_classes)

        detect_time = time.time() - det_tic

        for b in range(batch_size):
            misc_tic = time.time()
            print(video_info[b])
            # cls_prob scores are not helpful for fewshot since these are 21-class output (0 = background)
            # and the new example doesn't appear in the training data
            pred_twins = pred_twins_all[b]

            fewshot_scores_of_batch = torch.zeros(
                fewshot_scores.shape[1], batch_support_set_size).cuda()
            fewshot_scores_softmax_of_batch = torch.zeros(
                fewshot_scores.shape[1], batch_support_set_size).cuda()

            has_detections = False

            for j in range(batch_support_set_size):
                # view(-1) (not squeeze()) keeps a 1-D index even when a label
                # has a single support sample, so reducing over dim=1 works.
                idx_of_label_with_id = (
                    batch_support_set_labels ==
                    unique_batch_support_set_labels[j]).nonzero().view(-1)
                # Average raw scores / sum softmax scores over the support
                # samples that share this label.
                fewshot_scores_of_batch[:, j] = fewshot_scores[
                    b, :, idx_of_label_with_id].mean(dim=1)
                fewshot_scores_softmax_of_batch[:, j] = fewshot_scores_softmax[
                    b, :, idx_of_label_with_id].sum(dim=1)

                if use_softmax_fewshot_score:
                    inds = torch.nonzero(
                        fewshot_scores_softmax_of_batch[:, j] > thresh).view(-1)
                else:
                    inds = torch.nonzero(
                        fewshot_scores_of_batch[:, j] > thresh).view(-1)

                label_id = unique_batch_support_set_labels[j].item()

                if inds.numel() == 0:
                    continue

                has_detections = True
                cls_scores = fewshot_scores_of_batch[:, j][inds]
                _, order = torch.sort(cls_scores, 0, True)
                # This doesn't quite make sense because label_id column has no meaning to the network
                cls_twins = pred_twins[inds][:, label_id * 2:(label_id + 1) * 2]

                cls_dets = torch.cat((cls_twins, cls_scores.unsqueeze(1)), 1)
                cls_dets = cls_dets[order]
                keep = nms(cls_dets, cfg.TEST.NMS)
                if len(keep) > 0:
                    cls_dets = cls_dets[keep.view(-1).long()]
                    print("activity: ", label_id)
                    print(cls_dets.cpu().numpy())

            if not has_detections:
                # Fallback: take the 10 RoIs with the highest best-label
                # softmax score.
                # NOTE(review): the resulting cls_dets is never stored or
                # printed, and label_idx here is a support-set column index
                # rather than a label id (unlike the branch above). Confirm
                # the intent before relying on this code.
                sorted_scores, sorted_scores_idx = torch.sort(
                    fewshot_scores_softmax_of_batch, descending=True)
                sorted_scores = sorted_scores[:, 0]
                sorted_scores_idx = sorted_scores_idx[:, 0]
                sorted_scores_rows, sorted_scores_rows_idx = torch.sort(
                    sorted_scores, descending=True)
                sorted_scores_rows = sorted_scores_rows[:10]
                sorted_scores_rows_idx = sorted_scores_rows_idx[:10]
                sorted_scores_cols_idx = sorted_scores_idx[
                    sorted_scores_rows_idx]
                unique_cols_idx = sorted_scores_cols_idx.cpu().unique(
                    sorted=True).cuda()
                for label_idx in unique_cols_idx:
                    cls_twins = pred_twins[sorted_scores_rows_idx][
                        :, label_idx * 2:(label_idx + 1) * 2]
                    cls_dets = torch.cat(
                        (cls_twins, sorted_scores_rows.unsqueeze(1)), 1)

            nms_time = time.time() - misc_tic
            print('im_detect: {:d}/{:d} {:.3f}s {:.3f}s {:.3f}s'
                  .format(video_offset + b + 1, args.num_videos,
                          data_time / batch_size, detect_time / batch_size,
                          nms_time))

        video_offset += batch_size
        data_tic = time.time()

    end = time.time()
    print("test time: %0.4fs" % (end - start))