Example #1
def write_detection_results(detector, model, dsets, device, output_dir,
                            datatype):
    num_imgs = len(dsets)

    det_file = os.path.join(output_dir, 'detections.pkl')
    # all_boxes[cls][img] -> (N, 5) array: four box coordinates then the score.
    all_boxes = [[[] for _ in range(num_imgs)] for _ in range(cfg.num_classes)]

    anchorGen = Anchors(cfg.img_sizes)
    anchors = anchorGen.forward()
    evaluate_time = []
    for i in range(num_imgs):
        img, bboxes, labels = dsets[i]

        x = img.unsqueeze(0)
        x = x.to(device)

        begin = time.time()
        locs, conf = model(x)
        detect_time = time.time() - begin

        # Skip the first image so GPU warm-up does not skew the average.
        if i > 0:
            evaluate_time.append(detect_time)

        # locs: torch.Size([1, 30080, 4])   device(type='cuda', index=0)
        # conf: torch.Size([1, 30080, 2])   device(type='cuda', index=0)
        detections = detector(locs, conf, anchors)

        # Class 0 is the background, so iterate classes from 1.
        for j in range(1, detections.size(1)):
            dets = detections[0, j, :]
            # Keep rows whose confidence (column 0) is positive.
            mask = dets[:, 0].gt(0.).expand(5, dets.size(0)).t()
            dets = torch.masked_select(dets, mask).view(-1, 5)
            if dets.shape[0] == 0:
                continue
            boxes = dets[:, 1:]
            scores = dets[:, 0].cpu().numpy()
            cls_dets = np.hstack((boxes.cpu().numpy(), scores[:, np.newaxis])) \
                .astype(np.float32, copy=False)
            all_boxes[j][i] = cls_dets

        # print('img-detect: {:d}/{:d} {:.3f}s'.format(i + 1, num_imgs, detect_time))

    print("average time is {:.4f}".format(np.mean(evaluate_time)))

    with open(det_file, 'wb') as f:
        # The context manager closes the file; no explicit close() needed.
        pickle.dump(all_boxes, f, pickle.HIGHEST_PROTOCOL)

    for cls_ind, cls in enumerate(cfg.labelmap):
        # print('Writing {:s} VOC results file'.format(cls))
        filename = python_evaluation.get_voc_results_file_template(
            datatype, cls)
        with open(filename, 'wt') as f:
            for im_ind, index in enumerate(dsets.img_ids):
                dets = all_boxes[cls_ind + 1][im_ind]
                if len(dets) == 0:
                    continue
                for k in range(dets.shape[0]):
                    f.write('{:s} {:.3f} {:.1f} {:.1f} {:.1f} {:.1f}\n'.format(
                        index[1], dets[k, -1], dets[k, 0] + 1, dets[k, 1] + 1,
                        dets[k, 2] + 1, dets[k, 3] + 1))
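
Since the detections are pickled above, they can be read back with pickle.load; a minimal read-back sketch (output_dir is a placeholder):

import os
import pickle

output_dir = 'output'  # placeholder; use the directory passed to the function above
with open(os.path.join(output_dir, 'detections.pkl'), 'rb') as f:
    all_boxes = pickle.load(f)
# all_boxes[cls][img] is either [] or an (N, 5) float32 array.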
Example #2
def __init__(self, img_size, num_classes, device):
    super(LossModule, self).__init__()
    self.device = device
    self.num_classes = num_classes
    # Hard negative mining: keep at most 3 negatives per positive anchor.
    self.neg_pos_ratio = 3
    anchorGen = Anchors(img_size)
    self.anchors = anchorGen.forward()
    self.num_anchors = self.anchors.shape[0]
    # Minimum IoU for an anchor to match a ground-truth box.
    self.match_thresh = 0.5
    # SSD-style box-encoding variances for the center and size terms.
    self.variances = [0.1, 0.2]
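
The [0.1, 0.2] pair matches the standard SSD box-encoding variances. Assuming that convention holds here, a minimal decode sketch (not code from this repo):

import torch

def decode(loc, anchors, variances=(0.1, 0.2)):
    # Assumes loc and anchors are (num_anchors, 4) tensors in center-size form.
    centers = anchors[:, :2] + loc[:, :2] * variances[0] * anchors[:, 2:]
    sizes = anchors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])
    # Convert center-size form back to corner coordinates.
    return torch.cat((centers - sizes / 2, centers + sizes / 2), dim=1)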
Example #3
def train(dec_weights, seg_weights=None):

    if not os.path.exists('weights'):
        os.mkdir('weights')

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    #-----------------load detection model -------------------------
    dec_model = dec_net.resnetssd50(pretrained=False,
                                    num_classes=cfg.num_classes)
    resume_dict = torch.load(dec_weights)
    # Strip the 'module.' prefix left by nn.DataParallel checkpoints.
    resume_dict = {k[7:]: v for k, v in resume_dict.items()}
    dec_model.load_state_dict(resume_dict)
    dec_model.to(device)
    #-----------------load segmentation model -------------------------
    seg_model = SEG_NET(num_classes=cfg.num_classes)
    if seg_weights:
        seg_model.load_state_dict(torch.load(seg_weights))
    seg_model.to(device)
    ##--------------------------------------------------------------

    data_transforms = {
        'train':
        seg_transforms.Compose([
            seg_transforms.ConvertImgFloat(),
            seg_transforms.PhotometricDistort(),
            seg_transforms.Expand(max_scale=2, mean=(0.485, 0.456, 0.406)),
            seg_transforms.RandomSampleCrop(),
            seg_transforms.RandomMirror_w(),
            seg_transforms.RandomMirror_h(),
            seg_transforms.Resize(cfg.img_sizes),
            seg_transforms.ToTensor()
        ]),
        'val':
        seg_transforms.Compose(
            [seg_transforms.ConvertImgFloat(),
             seg_transforms.ToTensor()])
    }

    dsets = {
        x: CellDataset(root=cfg.root, datatype=x, transform=data_transforms[x])
        for x in ['train', 'val']
    }

    ## Visualization of input data and GT ###################
    # viewDatasets.view_dataset(dsets['train'])
    #########################################################

    dataloader_train = torch.utils.data.DataLoader(
        dsets['train'],
        batch_size=cfg.batch_size,
        shuffle=True,
        num_workers=4,
        collate_fn=seg_collater.collater,
        pin_memory=True)

    optimizer = optim.Adam(params=filter(lambda p: p.requires_grad,
                                         seg_model.parameters()),
                           lr=cfg.init_lr)
    scheduler = lr_scheduler.ExponentialLR(optimizer,
                                           gamma=0.98,
                                           last_epoch=-1)
    criterion = SEG_loss()

    detector = Detect(num_classes=cfg.num_classes,
                      top_k=cfg.top_k,
                      conf_thresh=cfg.conf_thresh,
                      nms_thresh=cfg.nms_thresh,
                      variance=cfg.variances)

    anchorGen = Anchors(cfg.img_sizes)
    anchors = anchorGen.forward()

    #-------------------------------------------------------------------
    dec_model.eval()  # detector set to 'evaluation' mode
    for param in dec_model.parameters():
        param.requires_grad = False
    #-------------------------------------------------------------------
    train_loss_dict = []
    ap05_dict = []
    ap07_dict = []
    for epoch in range(cfg.num_epochs):
        print('Epoch {}/{}'.format(epoch, cfg.num_epochs - 1))
        print('-' * 10)

        for phase in ['train', 'val']:
            if phase == 'train':
                seg_model.train()
                running_loss = 0.0
                for inputs, bboxes, labels, masks in dataloader_train:
                    inputs = inputs.to(device)
                    # Run the frozen detector without tracking gradients.
                    with torch.no_grad():
                        locs, conf, feat_seg = dec_model(inputs)
                        detections = detector(locs, conf, anchors)
                        # viewDatasets.view_detections(inputs, detections)

                    optimizer.zero_grad()
                    with torch.enable_grad():
                        outputs = seg_model(detections, feat_seg)
                        loss = criterion(outputs, bboxes, labels, masks)
                        loss.backward()
                        optimizer.step()

                    # statistics
                    running_loss += loss.item() * inputs.size(0)

                epoch_loss = running_loss / len(dsets[phase])
                # Decay the learning rate once per epoch, after the
                # optimizer updates.
                scheduler.step()

                print('{} Loss: {:.4f}'.format(phase, epoch_loss))
                train_loss_dict.append(epoch_loss)
                np.savetxt('train_loss.txt', train_loss_dict, fmt='%.6f')
                if epoch % 5 == 0:
                    torch.save(
                        seg_model.state_dict(),
                        os.path.join(
                            'weights',
                            '{:d}_{:.4f}_model.pth'.format(epoch, epoch_loss)))
                torch.save(seg_model.state_dict(),
                           os.path.join('weights', 'end_model.pth'))

            else:
                seg_model.eval()  # Set model to evaluate mode
                with torch.no_grad():
                    ap05, ap07 = seg_eval.eval(dec_model=dec_model,
                                               seg_model=seg_model,
                                               dsets=dsets[phase],
                                               device=device,
                                               detector=detector,
                                               anchors=anchors)
                    ap05_dict.append(ap05)
                    np.savetxt('ap_05.txt', ap05_dict, fmt='%.6f')
                    ap07_dict.append(ap07)
                    np.savetxt('ap_07.txt', ap07_dict, fmt='%.6f')
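
A usage sketch for train; the checkpoint path is a placeholder, not a file shipped with this code:

if __name__ == '__main__':
    # Placeholder path; point this at your trained detector checkpoint.
    train(dec_weights='weights/dec_model.pth', seg_weights=None)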
Example #4
def test(dec_weights, seg_weights):
    np.random.seed(0)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    data_transforms = seg_transforms.Compose(
        [seg_transforms.ConvertImgFloat(),
         seg_transforms.ToTensor()])

    datatype = 'test'

    dsets = CellDataset(root=cfg.root,
                        datatype=datatype,
                        transform=data_transforms)

    #-----------------load detection model -------------------------
    dec_model = dec_net.resnetssd50(pretrained=False,
                                    num_classes=cfg.num_classes)
    resume_dict = torch.load(dec_weights)
    # Strip the 'module.' prefix left by nn.DataParallel checkpoints.
    resume_dict = {k[7:]: v for k, v in resume_dict.items()}
    dec_model.load_state_dict(resume_dict)
    dec_model.to(device)
    #-----------------load segmentation model -------------------------
    seg_model = SEG_NET(num_classes=cfg.num_classes)
    if seg_weights:
        seg_model.load_state_dict(torch.load(seg_weights))
    seg_model.to(device)

    dec_model.eval()
    seg_model.eval()

    detector = Detect(num_classes=cfg.num_classes,
                      top_k=cfg.top_k,
                      conf_thresh=cfg.conf_thresh,
                      nms_thresh=cfg.nms_thresh,
                      variance=cfg.variances)

    anchorGen = Anchors(cfg.img_sizes)
    anchors = anchorGen.forward()

    # print(dec_model)
    # print(seg_model)

    for idx_img in range(len(dsets)):
        inputs, gt_boxes, gt_classes, gt_masks = dsets[idx_img]
        ori_img = dsets.load_image(idx_img)
        img_copy = ori_img.copy()
        h, w, c = ori_img.shape

        x = inputs.unsqueeze(0)
        x = x.to(device)
        with torch.no_grad():
            locs, conf, feat_seg = dec_model(x)
            detections = detector(locs, conf, anchors)
            outputs = seg_model(detections, feat_seg)

        mask_patches, mask_dets = outputs

        # Loop over images in the (size-1) batch.
        for idx in range(len(mask_patches)):
            batch_mask_patches = mask_patches[idx]
            batch_mask_dets = mask_dets[idx]
            # Loop over the detected objects in this image.
            for idx_obj in range(len(batch_mask_patches)):
                # ori_img = img_copy
                # dets layout: [y1, x1, y2, x2, confidence, class index].
                dets = batch_mask_dets[idx_obj].data.cpu().numpy()
                box = dets[0:4]
                conf = dets[4]
                if conf < cfg.conf_thresh:
                    continue
                class_obj = dets[5]

                mask_patch = batch_mask_patches[idx_obj].data.cpu().numpy()

                # Clip the box to the image bounds.
                y1, x1, y2, x2 = box
                y1 = np.maximum(0, np.int32(np.round(y1)))
                x1 = np.maximum(0, np.int32(np.round(x1)))
                y2 = np.minimum(np.int32(np.round(y2)), h - 1)
                x2 = np.minimum(np.int32(np.round(x2)), w - 1)

                mask = np.zeros((h, w), dtype=np.float32)
                # cv2.resize expects (width, height).
                mask_patch = cv2.resize(mask_patch, (x2 - x1, y2 - y1))

                # Binarize the predicted patch at the segmentation threshold.
                mask_patch = np.where(mask_patch >= cfg.seg_thresh, 1, 0)

                mask[y1:y2, x1:x2] = mask_patch
                # Blend a random per-instance color into the masked region.
                color = np.random.rand(3)
                mask = np.repeat(mask[:, :, np.newaxis], 3, axis=2)
                mskd = ori_img * mask

                clmsk = np.ones(mask.shape) * mask
                clmsk[:, :, 0] = clmsk[:, :, 0] * color[0] * 256
                clmsk[:, :, 1] = clmsk[:, :, 1] * color[1] * 256
                clmsk[:, :, 2] = clmsk[:, :, 2] * color[2] * 256
                ori_img = ori_img + 0.7 * clmsk - 0.7 * mskd
                cv2.rectangle(ori_img, (x1, y1), (x2, y2), (255, 0, 0), 2, 1)
                cv2.putText(ori_img,
                            dsets.classes[int(class_obj)] + "%.2f" % conf,
                            (x1, y1 + 20), cv2.FONT_HERSHEY_SIMPLEX, 0.6,
                            (255, 255, 255))
        cv2.imwrite("{}.jpg".format(dsets.img_ids[idx_img][1]),
                    np.uint8(ori_img))
        cv2.imshow('img', np.uint8(ori_img))
        k = cv2.waitKey(0)
        if k & 0xFF == ord('q'):
            cv2.destroyAllWindows()
            exit(1)

    cv2.destroyAllWindows()
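
A matching usage sketch for test; the detector path is a placeholder, while 'weights/end_model.pth' is the file that train() saves:

if __name__ == '__main__':
    # Placeholder detector checkpoint; end_model.pth is written by train().
    test(dec_weights='weights/dec_model.pth',
         seg_weights='weights/end_model.pth')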
Example #5
    # Fragment: 'resume' and 'data_transforms' are defined earlier in the
    # enclosing function (compare the fuller version in Example #6).
    datatype = 'test'

    dsets = CellDataset(root=cfg.root,
                        datatype=datatype,
                        transform=data_transforms)

    model = dec_net.resnetssd50(pretrained=True, num_classes=cfg.num_classes)
    if resume is not None:
        resume_dict = torch.load(resume)
        # Strip the 'module.' prefix left by nn.DataParallel checkpoints.
        resume_dict = {k[7:]: v for k, v in resume_dict.items()}
        model.load_state_dict(resume_dict)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()

    detector = Detect(num_classes=cfg.num_classes,
                      top_k=cfg.top_k,
                      conf_thresh=cfg.conf_thresh,
                      nms_thresh=cfg.nms_thresh,
                      variance=cfg.variances)

    num_imgs = len(dsets)
    anchorGen = Anchors(cfg.img_sizes)
    anchors = anchorGen.forward()

    # 'eval' here is the project's evaluation routine (imported elsewhere),
    # not the Python builtin.
    eval(model=model,
         dsets=dsets,
         device=device,
         detector=detector,
         datatype=datatype)
Example #6
def test(resume=None):
    data_transforms = dec_transform.Compose([dec_transform.ConvertImgFloat(),
                                             dec_transform.ToTensor()])

    datatype = 'test'

    dsets = CellDataset(root=cfg.root,
                        datatype=datatype,
                        transform=data_transforms)

    model = dec_net.resnetssd50(pretrained=True, num_classes=cfg.num_classes)
    if resume is not None:
        resume_dict = torch.load(resume)
        # Strip the 'module.' prefix left by nn.DataParallel checkpoints.
        resume_dict = {k[7:]: v for k, v in resume_dict.items()}
        model.load_state_dict(resume_dict)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()

    detector = Detect(num_classes=cfg.num_classes,
                      top_k=cfg.top_k,
                      conf_thresh=cfg.conf_thresh,
                      nms_thresh=cfg.nms_thresh,
                      variance=cfg.variances)

    num_imgs = len(dsets)
    anchorGen = Anchors(cfg.img_sizes)
    anchors = anchorGen.forward()

    cv2.namedWindow('img')
    for i in range(num_imgs):
        img, bboxes, labels = dsets[i]
        ori_img = dsets.load_image(i)
        x = img.unsqueeze(0)
        x = x.to(device)

        locs, conf = model(x)

        detections = detector(locs, conf, anchors)
        # Class 0 is the background, so iterate classes from 1.
        for j in range(1, detections.size(1)):
            dets = detections[0, j, :]
            # Keep rows whose confidence (column 0) is positive.
            mask = dets[:, 0].gt(0.).expand(5, dets.size(0)).t()
            # An empty selection becomes a (0, 5) tensor after view(-1, 5),
            # so test the row count.
            dets = torch.masked_select(dets, mask).view(-1, 5)
            if dets.shape[0] == 0:
                continue
            boxes = dets[:, 1:]
            scores = dets[:, 0].cpu().numpy()
            for box, score in zip(boxes, scores):
                y1, x1, y2, x2 = (int(v) for v in box)
                cv2.rectangle(ori_img, (x1, y1), (x2, y2), (0, 255, 0), 2, 2)
                cv2.putText(ori_img,
                            dsets.classes[int(j)] + "%.2f" % score,
                            (x1, y1 + 20),
                            cv2.FONT_HERSHEY_SIMPLEX,
                            0.6,
                            (255, 255, 255))
        cv2.imshow('img', ori_img)
        k = cv2.waitKey(0)
        if k & 0xFF == ord('q'):
            break
    cv2.destroyAllWindows()
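
A usage sketch for this detection-only test; the path is a placeholder, and the k[7:] strip above assumes a checkpoint saved from an nn.DataParallel-wrapped model:

if __name__ == '__main__':
    # Placeholder checkpoint path (keys expected to carry the 'module.' prefix).
    test(resume='weights/dec_model.pth')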