Esempio n. 1
0
def evaluation(args):
    """Score the detection + segmentation pipeline on the test set.

    Builds the detector and the segmentation network, loads their weights,
    wraps the test images in a ``NucleiCell`` dataset and hands everything to
    ``seg_eval_kaggle.do_python_eval`` (called with ``offline=True``).

    Args:
        args: Namespace-like object. This function reads ``num_classes``,
            ``dec_weights``, ``seg_weights``, ``img_height``, ``img_width``,
            ``testDir``, ``annoDir``, ``imgSuffix``, ``annoSuffix``,
            ``top_k``, ``conf_thresh`` and ``nms_thresh``; the whole object
            is also forwarded to ``do_python_eval``.

    Returns:
        None. The two values returned by ``do_python_eval`` are not
        propagated to the caller.
    """
    #-----------------load detection model -------------------------
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    dec_model = dec_net_seg.resnetssd50(pretrained=False,
                                        num_classes=args.num_classes)
    dec_model = load_dec_weights(dec_model, args.dec_weights)
    dec_model = dec_model.to(device)
    dec_model.eval()
    #-----------------load segmentation model -------------------------
    seg_model = seg_net.SEG_NET(num_classes=args.num_classes)
    # map_location lets a checkpoint saved on a GPU be restored when this
    # function has fallen back to the CPU device.
    seg_state = torch.load(args.seg_weights, map_location=device)
    seg_model.load_state_dict(seg_state)
    seg_model = seg_model.to(device)
    seg_model.eval()
    ##--------------------------------------------------------------
    # Deterministic preprocessing only: float conversion, resize, tensor.
    data_transforms = seg_transforms.Compose([
        seg_transforms.ConvertImgFloat(),
        seg_transforms.Resize(args.img_height, args.img_width),
        seg_transforms.ToTensor()
    ])

    dsets = seg_dataset_kaggle.NucleiCell(args.testDir,
                                          args.annoDir,
                                          data_transforms,
                                          imgSuffix=args.imgSuffix,
                                          annoSuffix=args.annoSuffix)

    # Detection post-processing and the anchors it decodes against.
    detector = Detect(num_classes=args.num_classes,
                      top_k=args.top_k,
                      conf_thresh=args.conf_thresh,
                      nms_thresh=args.nms_thresh,
                      variance=[0.1, 0.2])
    anchorGen = Anchors(args.img_height, args.img_width)
    anchors = anchorGen.forward()

    ap_05, ap_07 = seg_eval_kaggle.do_python_eval(dsets=dsets,
                                                  dec_model=dec_model,
                                                  seg_model=seg_model,
                                                  detector=detector,
                                                  anchors=anchors,
                                                  device=device,
                                                  args=args,
                                                  offline=True)

    print('Finish')
Esempio n. 2
0
def train(args):
    """Train the segmentation network on top of a frozen detector.

    The detector is loaded from ``args.dec_weights``, put in evaluation mode
    and has all its parameters frozen; only the segmentation network is
    optimised (Adam with an exponential learning-rate decay of 0.98 per
    epoch). After every training epoch the model is evaluated on the
    validation set with ``seg_eval_kaggle.do_python_eval``.

    Side effects:
        * creates ``args.weightDst`` if it does not exist;
        * rewrites ``seg_train_loss.txt``, ``seg_ap_05.txt`` and
          ``seg_ap_07.txt`` in the current working directory every epoch;
        * saves a checkpoint into ``args.weightDst`` whenever the first
          value returned by the evaluation exceeds 0.7;
        * when ``args.vis`` is truthy, shows every training sample in an
          OpenCV window first (pressing ``q`` terminates the process).

    Args:
        args: Namespace-like object. This function reads ``weightDst``,
            ``num_classes``, ``dec_weights``, ``img_height``, ``img_width``,
            ``trainDir``, ``valDir``, ``annoDir``, ``imgSuffix``,
            ``annoSuffix``, ``batch_size``, ``num_workers``, ``init_lr``,
            ``vis``, ``top_k``, ``conf_thresh``, ``nms_thresh`` and
            ``num_epochs``; the whole object is also forwarded to
            ``do_python_eval``.

    Returns:
        None.
    """
    # makedirs also creates missing parent directories and tolerates an
    # already existing target.
    os.makedirs(args.weightDst, exist_ok=True)

    #-----------------load detection model -------------------------
    # Keep preferring the second GPU, but fall back to the first GPU (or the
    # CPU) instead of failing with an invalid device ordinal on hosts that
    # expose fewer than two CUDA devices.
    if torch.cuda.device_count() > 1:
        device = torch.device("cuda:1")
    elif torch.cuda.is_available():
        device = torch.device("cuda:0")
    else:
        device = torch.device("cpu")
    dec_model = dec_net_seg.resnetssd50(pretrained=False, num_classes=args.num_classes)
    resume_dict = torch.load(args.dec_weights, map_location='cpu')
    # NOTE: strict=False ignores missing/unexpected keys, so a checkpoint
    # whose key names do not match loads without raising.
    dec_model.load_state_dict(resume_dict, strict=False)
    dec_model = dec_model.to(device)
    #-------------------------------------------------------------------
    dec_model.eval()        # detector set to 'evaluation' mode
    for param in dec_model.parameters():
        param.requires_grad = False
    #-----------------load segmentation model -------------------------
    seg_model = seg_net.SEG_NET(num_classes=args.num_classes)
    seg_model = seg_model.to(device)
    ##--------------------------------------------------------------
    data_transforms = {
        # Training uses the augmentation pipeline; validation only resizes.
        'train': seg_transforms.Compose([seg_transforms.ConvertImgFloat(),
                                         seg_transforms.PhotometricDistort(),
                                         seg_transforms.Expand(),
                                         seg_transforms.RandomSampleCrop(),
                                         seg_transforms.RandomMirror_w(),
                                         seg_transforms.RandomMirror_h(),
                                         seg_transforms.Resize(args.img_height, args.img_width),
                                         seg_transforms.ToTensor()]),

        'val': seg_transforms.Compose([seg_transforms.ConvertImgFloat(),
                                       seg_transforms.Resize(args.img_height, args.img_width),
                                       seg_transforms.ToTensor()])
    }

    dsets = {'train': seg_dataset_kaggle.NucleiCell(args.trainDir, args.annoDir, data_transforms['train'],
                                 imgSuffix=args.imgSuffix, annoSuffix=args.annoSuffix),
             'val': seg_dataset_kaggle.NucleiCell(args.valDir, args.annoDir, data_transforms['val'],
                                 imgSuffix=args.imgSuffix, annoSuffix=args.annoSuffix)}

    dataloader = torch.utils.data.DataLoader(dsets['train'],
                                             batch_size=args.batch_size,
                                             shuffle=True,
                                             num_workers=args.num_workers,
                                             collate_fn=collater,
                                             pin_memory=True)

    optimizer = optim.Adam(params=filter(lambda p: p.requires_grad, seg_model.parameters()), lr=args.init_lr)
    scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=0.98, last_epoch=-1)
    criterion = SEG_loss(height=args.img_height, width=args.img_width)

    if args.vis:
        # Optional visual sanity check of the augmented training samples.
        cv2.namedWindow('img')
        for idx in range(len(dsets['train'])):
            img, bboxes, _labels, masks = dsets['train'][idx]
            img = img.numpy().transpose(1, 2, 0).copy()*255
            print(img.shape)
            bboxes = bboxes.numpy()
            masks = masks.numpy()
            # A distinct index name keeps the sample index from being shadowed.
            for box_idx in range(bboxes.shape[0]):
                y1, x1, y2, x2 = (int(v) for v in bboxes[box_idx, :])
                cv2.rectangle(img, (x1, y1), (x2, y2), (255, 255, 255), 2, lineType=1)
                mask = masks[box_idx, :, :]
                img = map_mask_to_image(mask, img, color=np.random.rand(3))
            cv2.imshow('img', img)
            k = cv2.waitKey(0)
            if k & 0xFF == ord('q'):
                cv2.destroyAllWindows()
                exit()
        cv2.destroyAllWindows()

    # for validation data -----------------------------------
    detector = Detect(num_classes=args.num_classes,
                      top_k=args.top_k,
                      conf_thresh=args.conf_thresh,
                      nms_thresh=args.nms_thresh,
                      variance=[0.1, 0.2])
    anchorGen = Anchors(args.img_height, args.img_width)
    anchors = anchorGen.forward()
    # --------------------------------------------------------
    train_loss_dict = []
    ap05_dict = []
    ap07_dict = []
    for epoch in range(args.num_epochs):
        print('Epoch {}/{}'.format(epoch, args.num_epochs - 1))
        print('-' * 10)

        for phase in ['train', 'val']:
            if phase == 'train':
                seg_model.train()
                running_loss = 0.0
                for inputs, bboxes, labels, masks in dataloader:
                    inputs = inputs.to(device)
                    # The detector is frozen: no graph is needed for it.
                    with torch.no_grad():
                        locs, conf, feat_seg = dec_model(inputs)
                        detections = detector(locs, conf, anchors)

                    optimizer.zero_grad()
                    with torch.enable_grad():
                        outputs = seg_model(detections, feat_seg)
                        loss = criterion(outputs, bboxes, labels, masks)
                        loss.backward()
                        optimizer.step()

                    # statistics
                    running_loss += loss.item() * inputs.size(0)

                # Decay the learning rate only after this epoch's optimizer
                # steps; stepping the scheduler first skips the initial
                # learning rate on PyTorch >= 1.1.
                scheduler.step()

                epoch_loss = running_loss / len(dsets[phase])

                print('{} Loss: {:.4f}'.format(phase, epoch_loss))
                train_loss_dict.append(epoch_loss)
                np.savetxt('seg_train_loss.txt', train_loss_dict, fmt='%.6f')

            else:
                seg_model.eval()   # Set model to evaluate mode
                ap_05, ap_07 = seg_eval_kaggle.do_python_eval(dsets=dsets[phase], dec_model=dec_model, seg_model=seg_model,
                                                           detector=detector, anchors=anchors, device=device,
                                                           args=args, offline=False)
                # Only keep checkpoints that clear the 0.7 threshold.
                if ap_05 > 0.7:
                    torch.save(seg_model.state_dict(),
                               os.path.join(args.weightDst, '{:d}_{:.4f}_model.pth'.format(epoch, epoch_loss)))
                ap05_dict.append(ap_05)
                np.savetxt('seg_ap_05.txt', ap05_dict, fmt='%.6f')
                ap07_dict.append(ap_07)
                np.savetxt('seg_ap_07.txt', ap07_dict, fmt='%.6f')

    print('Finish')
Esempio n. 3
0
def test(args):
    """Run the pipeline on every test image and show the predicted masks.

    For each image the detector and the segmentation network are run, every
    detection whose confidence reaches ``args.conf_thresh`` gets its mask
    patch resized to the box, thresholded at ``args.seg_thresh``, pasted
    into a network-sized canvas, rescaled to the original image size and
    blended into the image with ``map_mask_to_image``. The result is shown
    in an OpenCV window: ``q`` terminates the process, ``s`` saves the image
    under ``kaggle_imgs/img/``, any other key moves to the next image.

    Args:
        args: Namespace-like object. This function reads ``num_classes``,
            ``dec_weights``, ``seg_weights``, ``img_height``, ``img_width``,
            ``testDir``, ``annoDir``, ``imgSuffix``, ``annoSuffix``,
            ``top_k``, ``conf_thresh``, ``nms_thresh`` and ``seg_thresh``.

    Returns:
        None.
    """
    import os  # local import: only needed to create the output directory

    #-----------------load detection model -------------------------
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    dec_model = dec_net_seg.resnetssd50(pretrained=False,
                                        num_classes=args.num_classes)
    dec_model = load_dec_weights(dec_model, args.dec_weights)
    dec_model = dec_model.to(device)
    dec_model.eval()
    #-----------------load segmentation model -------------------------
    seg_model = seg_net.SEG_NET(num_classes=args.num_classes)
    # map_location lets a GPU-saved checkpoint load on the CPU fallback.
    seg_model.load_state_dict(
        torch.load(args.seg_weights, map_location=device))
    seg_model = seg_model.to(device)
    seg_model.eval()
    ##--------------------------------------------------------------
    data_transforms = seg_transforms.Compose([
        seg_transforms.ConvertImgFloat(),
        seg_transforms.Resize(args.img_height, args.img_width),
        seg_transforms.ToTensor()
    ])

    dsets = seg_dataset_kaggle.NucleiCell(args.testDir,
                                          args.annoDir,
                                          data_transforms,
                                          imgSuffix=args.imgSuffix,
                                          annoSuffix=args.annoSuffix)

    # for validation data -----------------------------------
    detector = Detect(num_classes=args.num_classes,
                      top_k=args.top_k,
                      conf_thresh=args.conf_thresh,
                      nms_thresh=args.nms_thresh,
                      variance=[0.1, 0.2])
    anchorGen = Anchors(args.img_height, args.img_width)
    anchors = anchorGen.forward()

    save_dir = 'kaggle_imgs/img'
    num_images = len(dsets)
    for img_idx in range(num_images):
        print('loading {}/{} image'.format(img_idx, num_images))
        # Only the image tensor is needed; ground truth is not used here.
        inputs, _, _, _ = dsets[img_idx]
        ori_img = dsets.load_img(img_idx)
        h, w = ori_img.shape[:2]

        # Pure inference: skip building the autograd graph.
        with torch.no_grad():
            x = inputs.unsqueeze(0).to(device)
            locs, conf, feat_seg = dec_model(x)
            detections = detector(locs, conf, anchors)
            mask_patches, mask_dets = seg_model(detections, feat_seg)

        # For batches
        for b_mask_patches, b_mask_dets in zip(mask_patches, mask_dets):
            # Step1: rearrange mask_patches and mask_dets
            for d in range(len(b_mask_dets)):
                d_mask_det = b_mask_dets[d].detach().cpu().numpy()
                d_conf = d_mask_det[4]
                if d_conf < args.conf_thresh:
                    continue

                # Round the box to pixels and clamp it to the network input.
                y1, x1, y2, x2 = d_mask_det[0:4]
                y1 = max(0, int(np.round(y1)))
                x1 = max(0, int(np.round(x1)))
                y2 = min(int(np.round(y2)), args.img_height - 1)
                x2 = min(int(np.round(x2)), args.img_width - 1)
                if y2 < y1 or x2 < x1:
                    # Clamping left an empty box; cv2.resize would reject
                    # the resulting non-positive target size.
                    continue

                d_mask_patch = b_mask_patches[d].detach().cpu().numpy()
                d_mask_patch = cv2.resize(d_mask_patch,
                                          (x2 - x1 + 1, y2 - y1 + 1))
                d_mask_patch = np.where(d_mask_patch >= args.seg_thresh, 1.,
                                        0.)
                d_mask = np.zeros((args.img_height, args.img_width),
                                  dtype=np.float32)
                d_mask[y1:y2 + 1, x1:x2 + 1] = d_mask_patch
                # Nearest-neighbour keeps the mask binary when rescaling.
                d_mask = cv2.resize(d_mask,
                                    dsize=(w, h),
                                    interpolation=cv2.INTER_NEAREST)
                ori_img = map_mask_to_image(d_mask,
                                            ori_img,
                                            color=np.random.rand(3))

        cv2.imshow('img', ori_img)
        k = cv2.waitKey(0)
        if k & 0xFF == ord('q'):
            cv2.destroyAllWindows()
            exit()
        elif k & 0xFF == ord('s'):
            # cv2.imwrite does not create directories and only returns False
            # on failure, so make sure the target exists first.
            os.makedirs(save_dir, exist_ok=True)
            cv2.imwrite('kaggle_imgs/img/{}_gt.png'.format(img_idx), ori_img)
    cv2.destroyAllWindows()
    print('Finish')
Esempio n. 4
0
def test(args):
    """Segment every test image, save the masks and report the speed.

    Each image is loaded, resized to the network input size, passed through
    the detector and the segmentation network, and every detection whose
    confidence reaches ``args.conf_thresh`` is drawn with
    ``map_mask_to_image`` onto a black canvas of the original image size.
    The canvas is written to ``TCGA_imgs/<index>_gt.png``. The time from
    loading an image to the end of the forward pass is recorded per image;
    the first measurement is dropped before the average time and FPS are
    printed.

    Args:
        args: Namespace-like object. This function reads ``num_classes``,
            ``dec_weights``, ``seg_weights``, ``img_height``, ``img_width``,
            ``testDir``, ``annoDir``, ``imgSuffix``, ``annoSuffix``,
            ``top_k``, ``conf_thresh``, ``nms_thresh`` and ``seg_thresh``.

    Returns:
        None.
    """
    import os  # local import: only needed to create the output directory

    #-----------------load detection model -------------------------
    # Keep preferring the second GPU, but fall back to the first GPU (or the
    # CPU) instead of failing with an invalid device ordinal on hosts that
    # expose fewer than two CUDA devices.
    if torch.cuda.device_count() > 1:
        device = torch.device("cuda:1")
    elif torch.cuda.is_available():
        device = torch.device("cuda:0")
    else:
        device = torch.device("cpu")
    dec_model = dec_net_seg.resnetssd50(pretrained=False,
                                        num_classes=args.num_classes)
    resume_dict = torch.load(args.dec_weights, map_location='cpu')
    # NOTE: strict=False ignores missing/unexpected keys, so a checkpoint
    # whose key names do not match loads without raising.
    dec_model.load_state_dict(resume_dict, strict=False)
    dec_model = dec_model.to(device)
    dec_model.eval()
    #-----------------load segmentation model -------------------------
    seg_model = seg_net.SEG_NET(num_classes=args.num_classes)
    seg_model.load_state_dict(torch.load(args.seg_weights, map_location='cpu'),
                              strict=False)

    seg_model = seg_model.to(device)
    seg_model.eval()
    ##--------------------------------------------------------------
    data_transforms = seg_transforms.Compose([
        seg_transforms.ConvertImgFloat(),
        seg_transforms.Resize(args.img_height, args.img_width),
        seg_transforms.ToTensor()
    ])

    dsets = seg_dataset_kaggle.NucleiCell(args.testDir,
                                          args.annoDir,
                                          data_transforms,
                                          imgSuffix=args.imgSuffix,
                                          annoSuffix=args.annoSuffix)

    # for validation data -----------------------------------
    detector = Detect(num_classes=args.num_classes,
                      top_k=args.top_k,
                      conf_thresh=args.conf_thresh,
                      nms_thresh=args.nms_thresh,
                      variance=[0.1, 0.2])
    anchorGen = Anchors(args.img_height, args.img_width)
    anchors = anchorGen.forward()

    # cv2.imwrite does not create directories and only returns False on
    # failure, so make sure the target exists before the loop.
    os.makedirs('TCGA_imgs', exist_ok=True)

    all_time = []
    num_images = len(dsets)
    for img_idx in range(num_images):
        time_begin = time.perf_counter()
        print('loading {}/{} image'.format(img_idx, num_images))
        ori_img = dsets.load_img(img_idx)
        h, w = ori_img.shape[:2]
        # Black 3-channel canvas the predicted masks are drawn on.
        black = np.zeros((h, w, 3), dtype=np.uint8)

        # Resize to the same size the anchors and mask canvas are built for
        # (cv2 expects dsize as (width, height)).
        img = ori_img.astype(np.float32)
        img = cv2.resize(img, dsize=(args.img_width, args.img_height))
        img = torch.from_numpy(img.copy().transpose((2, 0, 1)))
        inputs = img / 255

        # Pure inference: skip building the autograd graph.
        with torch.no_grad():
            x = inputs.unsqueeze(0).to(device)
            locs, conf, feat_seg = dec_model(x)
            detections = detector(locs, conf, anchors)
            mask_patches, mask_dets = seg_model(detections, feat_seg)
        if device.type == 'cuda':
            # CUDA kernels run asynchronously; wait for them so the
            # measured interval covers the actual computation.
            torch.cuda.synchronize(device)
        all_time.append(time.perf_counter() - time_begin)

        # For batches
        for b_mask_patches, b_mask_dets in zip(mask_patches, mask_dets):
            # Step1: rearrange mask_patches and mask_dets
            for d in range(len(b_mask_dets)):
                d_mask_det = b_mask_dets[d].detach().cpu().numpy()
                d_conf = d_mask_det[4]
                if d_conf < args.conf_thresh:
                    continue

                # Round the box to pixels and clamp it to the network input.
                y1, x1, y2, x2 = d_mask_det[0:4]
                y1 = max(0, int(np.round(y1)))
                x1 = max(0, int(np.round(x1)))
                y2 = min(int(np.round(y2)), args.img_height - 1)
                x2 = min(int(np.round(x2)), args.img_width - 1)
                if y2 < y1 or x2 < x1:
                    # Clamping left an empty box; cv2.resize would reject
                    # the resulting non-positive target size.
                    continue

                d_mask_patch = b_mask_patches[d].detach().cpu().numpy()
                d_mask_patch = cv2.resize(d_mask_patch,
                                          (x2 - x1 + 1, y2 - y1 + 1))
                d_mask_patch = np.where(d_mask_patch >= args.seg_thresh, 1.,
                                        0.)
                d_mask = np.zeros((args.img_height, args.img_width),
                                  dtype=np.float32)
                d_mask[y1:y2 + 1, x1:x2 + 1] = d_mask_patch
                # Nearest-neighbour keeps the mask binary when rescaling.
                d_mask = cv2.resize(d_mask,
                                    dsize=(w, h),
                                    interpolation=cv2.INTER_NEAREST)

                black = map_mask_to_image(d_mask,
                                          black,
                                          color=np.random.rand(3))
        cv2.imwrite('TCGA_imgs/{}_gt.png'.format(img_idx), black)

    # The first measurement is discarded before averaging.
    timed = all_time[1:]
    if timed:
        avg_time = np.mean(timed)
        print('avg time is {}'.format(avg_time))
        print('FPS is {}'.format(1. / avg_time))
    else:
        # np.mean([]) would emit a warning and print NaN.
        print('not enough images to report timing')
    print('Finish')