Code example #1
File: eval.py  Project: zobeirraisi/stela
def do_eval(args):
    model = STELA(backbone=args.backbone, num_classes=2)
    model.load_state_dict(torch.load(args.weights))
    model.eval()
    if torch.cuda.is_available():
        model.cuda()
    evaluate(model, args)
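The evaluation entry point above expects an args namespace with at least backbone and weights. A minimal sketch of how that namespace might be built with argparse; the flag names and default are assumptions for illustration, and the real eval.py presumably also defines the test-set options that evaluate() reads:

import argparse

parser = argparse.ArgumentParser(description='Evaluate a trained STELA detector')
parser.add_argument('--backbone', type=str, default='res50')   # assumed default
parser.add_argument('--weights', type=str, required=True)      # path to a .pth checkpoint
args = parser.parse_args()
do_eval(args)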
Code example #2
def demo(args):
    # build the detector and load trained weights for inference
    model = STELA(backbone=args.backbone, num_classes=2)
    model.load_state_dict(torch.load(args.weights))
    model.eval()

    ims_list = [x for x in os.listdir(args.ims_dir) if is_image(x)]

    for _, im_name in enumerate(ims_list):
        im_path = os.path.join(args.ims_dir, im_name)
        src = cv2.imread(im_path, cv2.IMREAD_COLOR)
        im = cv2.cvtColor(src, cv2.COLOR_BGR2RGB)
        cls_dets = im_detect(model, im, target_sizes=args.target_size)
        for j in range(len(cls_dets)):
            cls, scores = cls_dets[j, 0], cls_dets[j, 1]
            bbox = cls_dets[j, 2:]
            if len(bbox) == 4:
                draw_caption(src, bbox, '{:1.3f}'.format(scores))
                cv2.rectangle(src, (int(bbox[0]), int(bbox[1])),
                              (int(bbox[2]), int(bbox[3])),
                              color=(0, 0, 255),
                              thickness=2)
            else:
                pts = np.array([rbox_2_quad(bbox[:5]).reshape((4, 2))],
                               dtype=np.int32)
                cv2.drawContours(src, pts, 0, color=(0, 255, 0), thickness=2)
                # display original anchors
                # if len(bbox) > 5:
                #     pts = np.array([rbox_2_quad(bbox[5:]).reshape((4, 2))], dtype=np.int32)
                #     cv2.drawContours(src, pts, 0, color=(0, 0, 255), thickness=2)
        # resize the visualisation for display
        im = cv2.resize(src, (800, 800), interpolation=cv2.INTER_LINEAR)
        cv2.imshow('Detection Results', im)
        cv2.waitKey(0)
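The demo filters the input directory with an is_image helper before running detection. A minimal sketch of what such a helper could look like, assuming a plain extension check; the project's own implementation may differ:

IMAGE_EXTS = ('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff')

def is_image(filename):
    # Treat a file as an image purely by its extension (assumed convention).
    return filename.lower().endswith(IMAGE_EXTS)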
Code example #3
def train_model(args):
    # build the training dataset and multi-scale data loader
    train_dataset = CustomDataset(args.train_img, args.train_gt, args.gt_type_train)
    print('Number of Training Images is: {}'.format(len(train_dataset)))
    scales = args.training_size + 32 * np.array([x for x in range(-5, 6)])
    collater = Collater(scales=scales, keep_ratio=False, multiple=32)
    train_loader = data.DataLoader(
        dataset=train_dataset,
        batch_size=args.batch_size,
        num_workers=8,
        collate_fn=collater,
        shuffle=True,
        drop_last=True
    )
    
    os.makedirs('./weights', exist_ok=True)
    
    if args.gt_type_test == 'json':
        parse_gt_json(args)
    elif args.gt_type_test == 'txt':
        parse_gt_txt(args)
        
    model = STELA(backbone=args.backbone, num_classes=2)
    if os.path.exists(args.pretrained):
        model.load_state_dict(torch.load(args.pretrained))
        print('Load pretrained model from {}.'.format(args.pretrained))
    if torch.cuda.is_available():
        model.cuda()
    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model).cuda()

    optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=0.1)
    iters_per_epoch = np.floor((len(train_dataset) / float(args.batch_size)))
    num_epochs = int(np.ceil(args.max_iter / iters_per_epoch))
    iter_idx = 0
    best_loss = sys.maxsize
    for _ in range(num_epochs):
        for _, batch in enumerate(train_loader):
            iter_idx += 1
            if iter_idx > args.max_iter:
                break
            _t.tic()
            model.train()
            
            if args.freeze_bn:
                if torch.cuda.device_count() > 1:
                    model.module.freeze_bn()
                else:
                    model.freeze_bn()

            optimizer.zero_grad()
            ims, gt_boxes = batch['image'], batch['boxes']
            if torch.cuda.is_available():
                ims, gt_boxes = ims.cuda(), gt_boxes.cuda()
            losses = model(ims, gt_boxes)
            loss_cls, loss_reg = losses['loss_cls'].mean(), losses['loss_reg'].mean()
            if 'loss_ref' in losses:
                loss_ref = losses['loss_ref'].mean()
                loss = loss_cls + (loss_reg + loss_ref) * 0.5
            else:
                loss = loss_cls + loss_reg
            if bool(loss == 0):
                continue
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 0.1)
            optimizer.step()
            scheduler.step()
            
            if iter_idx % args.display == 0:
                info = 'iter: [{}/{}], time: {:1.3f}'.format(iter_idx, args.max_iter, _t.toc())
                if 'loss_ref' in losses:
                    info = info + ', ref: {:1.3f}'.format(loss_ref.item())
                    mlflow_rest.log_metric(metric={'key': 'loss_ref', 'value': loss_ref.item(), 'step': iter_idx})
                print(info + ', loss_cls: {:1.3f}, loss_reg: {:1.3f}, total_loss: {:1.3f}'.format(
                    loss_cls.item(), loss_reg.item(), loss.item()))
                if loss.item() < best_loss:
                    best_loss = loss.item()
                    if torch.cuda.device_count() > 1:
                        torch.save(model.module.state_dict(), 'weights/weight_{}_{:1.3f}.pth'.format(iter_idx, loss.item()))
                    else:
                        torch.save(model.state_dict(), 'weights/weight_{}_{:1.3f}.pth'.format(iter_idx, loss.item()))
                mlflow_rest.log_metric(metric={'key': 'loss_cls', 'value': loss_cls.item(), 'step': iter_idx})
                mlflow_rest.log_metric(metric={'key': 'loss_reg', 'value': loss_reg.item(), 'step': iter_idx})
                mlflow_rest.log_metric(metric={'key': 'total_loss', 'value': loss.item(), 'step': iter_idx})
            
#             if (arg.eval_iter > 0) and (iter_idx % arg.eval_iter) == 0:
                
                ## mlflow_rest.log_metric(metric = {'key': 'accuracy', 'value': accuracy, 'step': iter_idx})
                ## mlflow_rest.log_metric(metric = {'key': 'IOU', 'value': avg_iou, 'step': iter_idx})
                ## mlflow_rest.log_metric(metric = {'key': 'confidence', 'value': confidence, 'step': iter_idx})
                ## print('IOU: {}, Score: {}'.format(avg_iou, confidence))
                ## print(f"precision: {precision*100}, recall: {recall*100}, f1: {f1*100}, accuracy: {accuracy*100}")
                
            if iter_idx % args.save_interval == 0:
                if torch.cuda.device_count() > 1:
                    torch.save(model.module.state_dict(), f'weights/check_{iter_idx}.pth')
                else:
                    torch.save(model.state_dict(), f'weights/check_{iter_idx}.pth')
    
    if torch.cuda.device_count() > 1:
        torch.save(model.module.state_dict(), f'weights/final_{args.max_iter}.pth')
    else:
        torch.save(model.state_dict(), f'weights/final_{args.max_iter}.pth')

    model.eval()
    if torch.cuda.device_count() > 1:
        result = evaluate(model.module, args)
    else:
        result = evaluate(model, args)

    mlflow_rest.log_metric(metric={'key': 'precision', 'value': result['precision'], 'step': iter_idx})
    mlflow_rest.log_metric(metric={'key': 'recall', 'value': result['recall'], 'step': iter_idx})
    mlflow_rest.log_metric(metric={'key': 'hmean', 'value': result['hmean'], 'step': iter_idx})
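The scales line above builds an 11-entry multi-scale list centred on args.training_size in steps of 32, the same multiple the collater enforces. A small worked example, assuming a training size of 768 purely for illustration:

import numpy as np

training_size = 768  # illustrative value, not taken from the project
scales = training_size + 32 * np.array([x for x in range(-5, 6)])
print(scales)  # [608 640 672 704 736 768 800 832 864 896 928]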
Code example #4
def train_model(args):
    # build the dataset and multi-scale data loader
    ds = VOCDataset(root_dir=args.train_dir)
    print('Number of Training Images is: {}'.format(len(ds)))
    scales = args.training_size + 32 * np.array([x for x in range(-5, 6)])
    collater = Collater(scales=scales, keep_ratio=False, multiple=32)
    loader = data.DataLoader(dataset=ds,
                             batch_size=args.batch_size,
                             num_workers=8,
                             collate_fn=collater,
                             shuffle=True,
                             drop_last=True)
    # build the model and load pretrained weights if available
    model = STELA(backbone=args.backbone, num_classes=2)
    if os.path.exists(args.pretrained):
        model.load_state_dict(torch.load(args.pretrained))
        print('Load pretrained model from {}.'.format(args.pretrained))
    if torch.cuda.is_available():
        model.cuda()
    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model).cuda()

    optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)
    scheduler = optim.lr_scheduler.StepLR(optimizer,
                                          step_size=args.step_size,
                                          gamma=0.1)
    iters_per_epoch = np.floor((len(ds) / float(args.batch_size)))
    num_epochs = int(np.ceil(args.max_iter / iters_per_epoch))
    iter_idx = 0
    for _ in range(num_epochs):
        for _, batch in enumerate(loader):
            iter_idx += 1
            if iter_idx > args.max_iter:
                break
            _t.tic()
            model.train()

            if args.freeze_bn:
                if torch.cuda.device_count() > 1:
                    model.module.freeze_bn()
                else:
                    model.freeze_bn()

            optimizer.zero_grad()
            ims, gt_boxes = batch['image'], batch['boxes']
            if torch.cuda.is_available():
                ims, gt_boxes = ims.cuda(), gt_boxes.cuda()
            losses = model(ims, gt_boxes)
            loss_cls, loss_reg = losses['loss_cls'].mean(), losses['loss_reg'].mean()
            if 'loss_ref' in losses:
                loss_ref = losses['loss_ref'].mean()
                loss = loss_cls + (loss_reg + loss_ref) * 0.5
            else:
                loss = loss_cls + loss_reg
            if bool(loss == 0):
                continue
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 0.1)
            optimizer.step()
            scheduler.step()
            if iter_idx % args.display == 0:
                info = 'iter: [{}/{}], time: {:1.3f}'.format(
                    iter_idx, args.max_iter, _t.toc())
                if 'loss_ref' in losses:
                    info = info + ', ref: {:1.3f}'.format(loss_ref.item())
                info = info + ', cls: {:1.3f}, reg: {:1.3f}'.format(
                    loss_cls.item(), loss_reg.item())
                print(info)
            # periodic evaluation on the test set
            if (args.eval_iter > 0) and (iter_idx % args.eval_iter) == 0:
                model.eval()
                if torch.cuda.device_count() > 1:
                    evaluate(model.module, args)
                else:
                    evaluate(model, args)
    # save the final weights
    if not os.path.exists('./weights'):
        os.mkdir('./weights')
    if torch.cuda.device_count() > 1:
        torch.save(model.module.state_dict(), './weights/deploy.pth')
    else:
        torch.save(model.state_dict(), './weights/deploy.pth')
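Both training loops derive the number of epochs from max_iter rather than taking an epoch count directly. A small worked example of that arithmetic, with purely illustrative numbers:

import numpy as np

num_images, batch_size, max_iter = 10_000, 8, 30_000  # assumed values
iters_per_epoch = np.floor(num_images / float(batch_size))  # 1250.0
num_epochs = int(np.ceil(max_iter / iters_per_epoch))       # 24
print(iters_per_epoch, num_epochs)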
Code example #5
def demo(backbone='eb2',
         weights='weights/deploy_eb_ship_15.pth',
         ims_dir='sample',
         target_size=768):
    # build the rotated-box detector and load trained weights
    model = STELA(backbone=backbone, num_classes=2)
    model.load_state_dict(torch.load(weights))
    model.eval()
    # print(model)

    # net_name and class_num are assumed to be defined at module level in the original project
    classifier = EfficientNet.from_name(net_name)
    num_ftrs = classifier._fc.in_features
    classifier._fc = nn.Linear(num_ftrs, class_num)

    classifier = classifier.cuda()
    best_model_wts = 'dataset/weismoke/model/efficientnet-b0.pth'
    classifier.load_state_dict(torch.load(best_model_wts))

    ims_list = [x for x in os.listdir(ims_dir) if is_image(x)]
    import shutil
    shutil.rmtree('output/', ignore_errors=True)
    os.makedirs('output/', exist_ok=True)
    os.makedirs('output_img/', exist_ok=True)  # annotated images are written here below
    for _, im_name in enumerate(ims_list):
        im_path = os.path.join(ims_dir, im_name)
        src = cv2.imread(im_path, cv2.IMREAD_COLOR)
        im = cv2.cvtColor(src, cv2.COLOR_BGR2RGB)
        import time
        start = time.perf_counter()
        cls_dets = im_detect(model, im, target_sizes=target_size)
        end = time.perf_counter()
        # print('********time*********', end - start)
        # val='/home/jd/projects/haha/chosename/val_plane_split/label_new/'
        """
        if(len(cls_dets)==0):
            print('*********no********',im_name)
            #image_path = os.path.join(img_path, name + ext)  # name of the sample image
            shutil.move(val+im_name[0:-4]+'.txt', 'hard')  # move this sample's label to blank_label_path
            shutil.move(im_path, 'hard/')     # move this sample image to blank_img_path
            continue
        """
        fw = open('output/' + im_name[:-4] + '.txt', 'w')
        for j in range(len(cls_dets)):
            cls, scores = cls_dets[j, 0], cls_dets[j, 1]
            # print ('cls,score',cls,scores)

            bbox = cls_dets[j, 2:]
            # print(bbox)
            if len(bbox) == 4:
                draw_caption(src, bbox, '{:1.3f}'.format(scores))
                cv2.rectangle(src, (int(bbox[0]), int(bbox[1])),
                              (int(bbox[2]), int(bbox[3])),
                              color=(0, 0, 255),
                              thickness=2)
            else:
                pts = np.array([rbox_2_quad(bbox[:5]).reshape((4, 2))],
                               dtype=np.int32)
                # print('####pts####',pts)
                cv2.drawContours(src, pts, 0, color=(0, 255, 0), thickness=2)
                # display original anchors
                # if len(bbox) > 5:
                #     pts = np.array([rbox_2_quad(bbox[5:]).reshape((4, 2))], dtype=np.int32)
                #     cv2.drawContours(src, pts, 0, color=(0, 0, 255), thickness=2)
                patch = crop_image(im, pts)
                pred = classify(classifier, patch)

                fw.write(
                    str(pts.flatten()[0]) + ' ' + str(pts.flatten()[1]) + ' ' +
                    str(pts.flatten()[2]) + ' ' + str(pts.flatten()[3]) + ' ' +
                    str(pts.flatten()[4]) + ' ' + str(pts.flatten()[5]) + ' ' +
                    str(pts.flatten()[6]) + ' ' + str(pts.flatten()[7]) + ' ' +
                    classes[pred] + '\n')
        fw.close()

        # resize the visualisation for display
        im = cv2.resize(src, (768, 768), interpolation=cv2.INTER_LINEAR)
        cv2.imwrite('output_img/' + im_name, im)

    train_img_dir = '/home/jd/projects/bifpn/sample_plane/'
    groundtruth_txt_dir = '/home/jd/projects/haha/chosename/val_plane_split/label_new/'
    detect_txt_dir = '/home/jd/projects/bifpn/output/'
    Recall, Precision, mAP = compute_acc(train_img_dir, groundtruth_txt_dir,
                                         detect_txt_dir)
    print('*******', Recall)
    return Recall, Precision, mAP
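The demo writes one text file per image, each line holding the eight quadrilateral corner coordinates followed by the predicted class name. A minimal sketch of reading those files back, based only on the write format shown above:

import glob

for path in glob.glob('output/*.txt'):
    with open(path) as f:
        for line in f:
            fields = line.split()
            coords = list(map(float, fields[:8]))  # x1 y1 x2 y2 x3 y3 x4 y4
            label = fields[8]
            print(path, label, coords)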