コード例 #1
0
def generate_det(args):
    """Run the ResNet RC3D detector over the test split and dump per-class JSON.

    For every class index ``i`` in ``1 .. num_classes - 1`` the file
    ``<args.json_path>/detection_<i>.json`` is written, containing
    ``{'object': [{'file_name', 'start', 'end', 'score'}, ...]}``.

    Args:
        args: namespace read for ``checkpoint_path``, ``feature_path``,
            ``image_path``, ``annotation_path`` and ``json_path``.

    Raises:
        SystemExit: if ``args.checkpoint_path`` cannot be listed or holds no
            file whose name contains ``ResNet_``.
    """
    import contextlib

    ckpt_dir = args.checkpoint_path
    try:
        names = sorted(os.listdir(ckpt_dir))
    except OSError:
        names = []
    # Use the whole file name (not a regex substring) so the path exists.
    matches = [name for name in names if "ResNet_" in name]
    if not matches:
        print("There is no checkpoint in ", ckpt_dir)
        raise SystemExit(1)
    # os.listdir order is arbitrary; take the greatest name so the choice is
    # deterministic (with zero-padded step suffixes this is the newest one).
    ckpt_path = os.path.join(ckpt_dir, matches[-1])

    model = RC3D_resnet.RC3D(num_classes, cfg.Test.Image_shape,
                             args.feature_path)
    model = model.cuda()
    model.zero_grad()
    model.load(ckpt_path)
    test_batch = utils.new_Batch_Generator(name_to_id, num_classes,
                                           args.image_path,
                                           args.annotation_path, 'test')

    # ExitStack guarantees every output file is closed even if inference
    # raises; files are still opened up front so a bad json_path fails fast.
    with contextlib.ExitStack() as stack:
        fp = [
            stack.enter_context(
                open(
                    os.path.join(args.json_path,
                                 "detection_{}.json".format(i)), 'w'))
            for i in range(1, num_classes)
        ]
        det = [{'object': []} for _ in fp]
        while True:
            try:
                data, _gt = next(test_batch)
            except StopIteration:
                break
            with torch.no_grad():
                _, _, object_cls_score, object_offset = model.forward(data)
                # bbox entries are sorted by score in descending order
                bbox = utils.nms(model.proposal_bbox, object_cls_score,
                                 object_offset, model.num_classes,
                                 model.im_info)
                if bbox is None:
                    continue
                for _cls, score, proposal in zip(bbox['cls'], bbox['score'],
                                                 bbox['bbox']):
                    start = float(proposal[:, 0])
                    end = float(proposal[:, 1])
                    # Skip degenerate zero-length segments.
                    if start == end:
                        continue
                    # NOTE(review): assumes utils.nms only yields foreground
                    # class ids 1..num_classes-1; id 0 would wrap to det[-1].
                    det[int(_cls[0]) - 1]['object'].append({
                        'file_name': data,
                        'start': start,
                        'end': end,
                        'score': float(score),
                    })
                torch.cuda.empty_cache()
        for class_det, f in zip(det, fp):
            json.dump(class_det, f)
    print("generate_gt Done!")
コード例 #2
0
def test(args):
    """Run the learned-NMS RC3D model on one batch and print its detections.

    The newest ``ResNetNMS_*`` checkpoint in ``args.checkpoint_path`` is
    loaded. One batch is drawn from the data generator and its ground truth is
    printed. Every (box, class) pair whose ``nms_score`` exceeds
    ``cfg.Network.nms_threshold[0]`` is then printed with its score and
    segment bounds.

    Args:
        args: namespace read for ``checkpoint_path``, ``feature_path``,
            ``image_path`` and ``annotation_path``.

    Raises:
        SystemExit: if no ``ResNetNMS_`` checkpoint can be found.
    """
    runtime = AverageMeter()
    ckpt_dir = args.checkpoint_path
    try:
        names = sorted(os.listdir(ckpt_dir))
    except OSError:
        names = []
    # Use the whole file name (not a regex substring) so the path exists.
    matches = [name for name in names if "ResNetNMS_" in name]
    if not matches:
        print("There is no checkpoint in ", ckpt_dir)
        raise SystemExit(1)
    # Deterministic pick; with zero-padded step suffixes this is the newest.
    ckpt_path = os.path.join(ckpt_dir, matches[-1])

    model = RC3D_resnet_learn_nms.RC3D(num_classes, cfg.Test.Image_shape,
                                       args.feature_path)
    model = model.cuda()
    model.zero_grad()
    model.load(ckpt_path)
    test_batch = utils.new_Batch_Generator(name_to_id, num_classes,
                                           args.image_path,
                                           args.annotation_path)
    tic = time.time()
    data, gt = next(test_batch)
    with torch.no_grad():
        print(gt)
        _, _, _, _, nms_score = model.forward(data)
        num_bbox = nms_score.shape[0]
        # (num_bbox, num_classes - 1) grid of foreground labels 1..C-1;
        # presumably aligned with nms_score's columns -- TODO confirm.
        label = torch.arange(1, num_classes).cuda().repeat(num_bbox, 1)
        idx = torch.nonzero(nms_score > cfg.Network.nms_threshold[0])
        if idx.shape[0] == 0:
            # Nothing passed the threshold: there is nothing to report.
            return
        bbox = utils.subscript_index(model.sorted_bbox, idx)
        cls_score = utils.subscript_index(nms_score, idx)
        cls_label = utils.subscript_index(label, idx)
        toc = time.time()
        torch.cuda.empty_cache()
        runtime.update(toc - tic)
        print('Time {runtime.val:.3f} ({runtime.avg:.3f})\t'.format(
            runtime=runtime))
        for _cls, score, proposal in zip(cls_label, cls_score, bbox):
            print(
                "class:{:}({:})\t   score:{:.6f}\t   start:{:.2f}\t  end:{:.2f}\t"
                .format(id_to_name[int(_cls)], _cls, score, proposal[0],
                        proposal[1]))
コード例 #3
0
ファイル: test_resnet_Net.py プロジェクト: peterzpy/TAL
def test(args):
    """Run the ResNet RC3D model on one batch and print its NMS detections.

    The newest ``ResNet_*`` checkpoint in ``args.checkpoint_path`` is loaded.
    One batch is drawn from the data generator and its ground truth is
    printed. The output of ``utils.nms`` is then printed as
    class / score / start / end lines.

    Args:
        args: namespace read for ``checkpoint_path``, ``feature_path``,
            ``image_path`` and ``annotation_path``.

    Raises:
        SystemExit: if no ``ResNet_`` checkpoint can be found.
    """
    runtime = AverageMeter()
    ckpt_dir = args.checkpoint_path
    try:
        names = sorted(os.listdir(ckpt_dir))
    except OSError:
        names = []
    # Use the whole file name (not a regex substring) so the path exists.
    matches = [name for name in names if "ResNet_" in name]
    if not matches:
        print("There is no checkpoint in ", ckpt_dir)
        raise SystemExit(1)
    # Deterministic pick; with zero-padded step suffixes this is the newest.
    ckpt_path = os.path.join(ckpt_dir, matches[-1])

    model = RC3D_resnet.RC3D(num_classes, cfg.Test.Image_shape,
                             args.feature_path)
    model = model.cuda()
    model.zero_grad()
    model.load(ckpt_path)
    test_batch = utils.new_Batch_Generator(name_to_id, num_classes,
                                           args.image_path,
                                           args.annotation_path)
    tic = time.time()
    data, gt = next(test_batch)
    with torch.no_grad():
        print(gt)
        _, _, object_cls_score, object_offset = model.forward(data)
        bbox = utils.nms(model.proposal_bbox, object_cls_score, object_offset,
                         model.num_classes, model.im_info)
        toc = time.time()
        torch.cuda.empty_cache()
        runtime.update(toc - tic)
        print('Time {runtime.val:.3f} ({runtime.avg:.3f})\t'.format(
            runtime=runtime))
        # utils.nms may return None (generate_det handles this too); indexing
        # None below would raise TypeError.
        if bbox is None:
            return
        for _cls, score, proposal in zip(bbox['cls'], bbox['score'],
                                         bbox['bbox']):
            print(
                "class:{:}({:})\t   score:{:.6f}\t   start:{:.2f}\t  end:{:.2f}\t"
                .format(id_to_name[int(_cls[0])], _cls[0], score[0],
                        proposal[0, 0], proposal[0, 1]))
コード例 #4
0
def train(args):
    """Train the RC3D model (with NMS head) until ``args.iters`` steps.

    Training resumes from the highest-numbered ``ResNetNMS_<step>.ckpt`` in
    ``args.checkpoint_path`` unless ``args.use_resnet_pth == 'True'``. When no
    snapshot exists, ``ResNetNMS.ckpt`` is loaded if
    ``args.pretrained == 'True'``. Every ``args.snapshot_per_iters`` steps a
    new ``ResNetNMS_<step>.ckpt`` is saved and the previous snapshot is
    removed.

    Args:
        args: namespace of string flags ('True'/'False') and paths/ints read
            below (checkpoint_path, iters, display_per_iters,
            snapshot_per_iters, clear_per_iters, ...).
    """
    focal_loss = args.focal_loss != 'False'
    runtime = AverageMeter()
    cost, cost1, cost2, cost3, cost4, cost5 = (AverageMeter()
                                               for _ in range(6))
    loss_meters = (cost, cost1, cost2, cost3, cost4, cost5)

    # NOTE(review): preprocessing runs when the flag is 'False' while feature
    # extraction runs when its flag is 'True' -- confirm this is intended.
    if args.preprocess == 'False':
        utils.ner_preprocess(args.video_path, args.image_path,
                             args.video_annotation_path, args.annotation_path)
    if args.feature_preprocess == 'True':
        extract_feature(args.image_path, args.feature_path, num_classes,
                        args.pth_path)
    # NOTE(review): forward() is unpacked into five outputs (incl. nms_score)
    # below; confirm RC3D_resnet.RC3D (not a learned-NMS variant) is intended.
    model = RC3D_resnet.RC3D(num_classes, cfg.Train.Image_shape,
                             args.feature_path)
    model = model.cuda()
    model.zero_grad()

    step = 0
    ckpt_path = os.path.join(args.checkpoint_path, "ResNetNMS.ckpt")
    # Only snapshots written by this loop (or resumed from) are rotated away;
    # the base ResNetNMS.ckpt weights are never deleted.
    prev_snapshot = None
    if args.use_resnet_pth != 'True':
        try:
            names = os.listdir(args.checkpoint_path)
        except OSError:
            names = []
        snapshots = {}
        for name in names:
            match = re.fullmatch(r"ResNetNMS_(\d+)\.ckpt", name)
            if match:
                snapshots[int(match.group(1))] = name
        if snapshots:
            step = max(snapshots)
            ckpt_path = os.path.join(args.checkpoint_path, snapshots[step])
            prev_snapshot = ckpt_path
        if step or args.pretrained == 'True':
            model.load(ckpt_path)

    train_batch = utils.new_Batch_Generator(name_to_id, num_classes,
                                            args.image_path,
                                            args.annotation_path)
    optimizer = MyOptim(model.parameters())
    while step < args.iters:
        optimizer.zero_grad()
        tic = time.time()
        data, gt_boxes = next(train_batch)
        if gt_boxes.shape[0] == 0:
            continue
        gt_boxes = torch.tensor(gt_boxes, device='cuda', dtype=torch.float32)
        cls_score, proposal_offset, object_cls_score, object_offset, nms_score = model.forward(
            data)
        loss, loss1, loss2, loss3, loss4, loss5 = model.get_loss(
            cls_score, proposal_offset, object_cls_score, object_offset,
            nms_score, gt_boxes, focal_loss)
        # Record plain floats: handing the meters live autograd tensors can
        # keep each iteration's graph alive.
        for meter, value in zip(loss_meters,
                                (loss, loss1, loss2, loss3, loss4, loss5)):
            meter.update(float(value))
        toc = time.time()
        runtime.update(toc - tic)
        loss.backward()
        optimizer.step()
        torch.cuda.empty_cache()
        step += 1
        if step % args.display_per_iters == 0:
            print(
                'iter: [{0}]\t'
                'Loss {loss.avg:.4f}\t'
                'Time {runtime.val:.3f} ({runtime.avg:.3f})\n'
                'RPN:\nCls_Loss {loss1.avg:.4f}\t Bbox_Loss {loss2.avg:.4f}\nProposal:\nCls_Loss {loss3.avg:.4f}\t Bbox_Loss {loss4.avg:.4f}\n'
                'NMS:\nLoss {loss5.avg:.4f}\n'.format(step,
                                                      runtime=runtime,
                                                      loss=cost,
                                                      loss1=cost1,
                                                      loss2=cost2,
                                                      loss3=cost3,
                                                      loss4=cost4,
                                                      loss5=cost5))
        if step % args.snapshot_per_iters == 0:
            new_snapshot = os.path.join(args.checkpoint_path,
                                        "ResNetNMS_{:05d}.ckpt".format(step))
            # Save first, then drop the old snapshot, so a failed save never
            # leaves us without any checkpoint.
            model.save(new_snapshot)
            if prev_snapshot is not None and prev_snapshot != new_snapshot:
                try:
                    os.remove(prev_snapshot)
                except OSError:
                    pass  # best-effort cleanup; a stale file is harmless
            prev_snapshot = new_snapshot
        if step % args.clear_per_iters == 0:
            for meter in loss_meters:
                meter.reset()
            runtime.reset()