Esempio n. 1
0
def train(args):
    """Train the SSD nuclei detector and validate it after every epoch.

    Builds the train/val ``NucleiCell`` datasets, trains
    ``dec_net.resnetssd50`` with SGD + ``MultiStepLR`` on ``DecLoss``, and
    after each epoch runs ``Detect`` over the validation set, writes the
    per-class result files and scores them with ``dec_eval.do_python_eval``.

    Side effects:
        * creates ``args.weightDst`` and ``args.cacheDir`` if missing;
        * saves a checkpoint every 5 epochs, whenever ``ap05`` exceeds 0.71,
          and ``end_model.pth`` after every epoch;
        * writes per-epoch train loss / ``ap05`` / ``ap07`` to text files and
          TensorBoard scalars under ``args.logDir`` (optional attribute;
          defaults to the historical ``/data2/coldplay/dsb_cell``);
        * with ``args.vis`` set, first shows every augmented training sample
          in an OpenCV window (any key = next, ``q`` = exit the process).

    Args:
        args: parsed command-line namespace. Reads weightDst, trainDir,
            valDir, annoDir, imgSuffix, annoSuffix, img_height, img_width,
            batch_size, num_workers, num_classes, multi_gpu, init_lr,
            decayEpoch, num_epochs, vis, top_k, conf_thresh, nms_thresh,
            cacheDir and (optionally) logDir.
    """
    os.makedirs(args.weightDst, exist_ok=True)
    # Root for the loss/AP text logs and TensorBoard scalars. Optional so
    # existing argument parsers keep writing to the original location.
    log_dir = getattr(args, 'logDir', '/data2/coldplay/dsb_cell')
    scalar_dir = os.path.join(log_dir, 'scalar')

    data_transforms = {
        'train':
        dec_transforms.Compose([
            dec_transforms.ConvertImgFloat(),
            dec_transforms.PhotometricDistort(),
            dec_transforms.Expand(),
            dec_transforms.RandomSampleCrop(),
            dec_transforms.RandomMirror_w(),
            dec_transforms.RandomMirror_h(),
            dec_transforms.Resize(args.img_height, args.img_width),
            dec_transforms.ToTensor()
        ]),
        'val':
        dec_transforms.Compose([
            dec_transforms.ConvertImgFloat(),
            dec_transforms.Resize(args.img_height, args.img_width),
            dec_transforms.ToTensor()
        ])
    }

    dsets = {
        'train':
        NucleiCell(args.trainDir,
                   args.annoDir,
                   data_transforms['train'],
                   imgSuffix=args.imgSuffix,
                   annoSuffix=args.annoSuffix),
        'val':
        NucleiCell(args.valDir,
                   args.annoDir,
                   data_transforms['val'],
                   imgSuffix=args.imgSuffix,
                   annoSuffix=args.annoSuffix)
    }

    dataloader = torch.utils.data.DataLoader(dsets['train'],
                                             batch_size=args.batch_size,
                                             shuffle=True,
                                             num_workers=args.num_workers,
                                             collate_fn=collater,
                                             pin_memory=True)

    model = dec_net.resnetssd50(pretrained=True, num_classes=args.num_classes)
    if torch.cuda.is_available():
        # nn.DataParallel requires the wrapped module to live on device_ids[0],
        # which defaults to cuda:0; single-GPU runs keep using cuda:1.
        device = torch.device("cuda:0" if args.multi_gpu else "cuda:1")
    else:
        device = torch.device("cpu")

    if args.multi_gpu:
        model = nn.DataParallel(model)
    model = model.to(device)

    optimizer = optim.SGD(model.parameters(), lr=args.init_lr, momentum=0.9)
    scheduler = lr_scheduler.MultiStepLR(
        optimizer, milestones=[args.decayEpoch, args.num_epochs], gamma=0.1)
    criterion = DecLoss(img_height=args.img_height,
                        img_width=args.img_width,
                        num_classes=args.num_classes,
                        variances=[0.1, 0.2])

    if args.vis:
        cv2.namedWindow('img')
        for idx in range(len(dsets['train'])):
            img, bboxes, _ = dsets['train'][idx]
            # CHW tensor -> HWC array; cv2 drawing needs a C-contiguous buffer,
            # which a transposed view (and arithmetic on it) is not.
            img = np.ascontiguousarray(img.numpy().transpose(1, 2, 0)) * 255
            for y1, x1, y2, x2 in bboxes.numpy():
                # OpenCV only accepts integer pixel coordinates for points.
                cv2.rectangle(img, (int(x1), int(y1)), (int(x2), int(y2)),
                              (255, 255, 255),
                              2,
                              lineType=1)
            cv2.imshow('img', np.uint8(img))
            k = cv2.waitKey(0)
            if k & 0xFF == ord('q'):
                cv2.destroyAllWindows()
                raise SystemExit
        cv2.destroyAllWindows()

    # for validation data -----------------------------------
    detector = Detect(num_classes=args.num_classes,
                      top_k=args.top_k,
                      conf_thresh=args.conf_thresh,
                      nms_thresh=args.nms_thresh,
                      variance=[0.1, 0.2])
    anchorGen = Anchors(args.img_height, args.img_width)
    anchors = anchorGen.forward()
    os.makedirs(args.cacheDir, exist_ok=True)
    det_file = os.path.join(args.cacheDir, 'detections.pkl')
    val_set = dsets['val']
    # --------------------------------------------------------
    train_losses = []
    ap05_history = []
    ap07_history = []
    writer = SummaryWriter(scalar_dir)
    for epoch in range(args.num_epochs):
        print('Epoch {}/{}'.format(epoch, args.num_epochs - 1))
        print('-' * 10)

        # ---------------- train phase ----------------
        model.train()
        running_loss = 0.0
        for inputs, bboxes, labels in dataloader:
            inputs = inputs.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss_locs, loss_conf = criterion(outputs, bboxes, labels)
            loss = loss_locs + loss_conf
            loss.backward()
            optimizer.step()
            # Weight by batch size so the epoch value is a per-sample mean.
            running_loss += loss.item() * inputs.size(0)
        # Since PyTorch 1.1 the scheduler must step *after* optimizer.step();
        # stepping first skips the initial LR and decays one epoch early.
        scheduler.step()

        epoch_loss = running_loss / len(dsets['train'])
        print('{} Loss: {:.4f}'.format('train', epoch_loss))
        train_losses.append(epoch_loss)
        np.savetxt(os.path.join(log_dir, 'train_loss.txt'),
                   train_losses,
                   fmt='%.6f')
        writer.add_scalar(os.path.join(scalar_dir, 'train'), epoch_loss,
                          epoch)
        if epoch % 5 == 0:
            torch.save(
                model.state_dict(),
                os.path.join(args.weightDst,
                             '{:d}_{:.4f}_model.pth'.format(epoch, epoch_loss)))
        torch.save(model.state_dict(),
                   os.path.join(args.weightDst, 'end_model.pth'))

        # ---------------- validation phase ----------------
        model.eval()
        # all_boxes[cls_idx][img_idx]: float32 array of rows
        # [y1, x1, y2, x2, score] in original-image pixels, or [] if empty.
        all_boxes = [[[] for _ in range(len(val_set))]
                     for _ in range(args.num_classes)]
        for img_idx in range(len(val_set)):
            h, w = val_set.load_img(img_idx).shape[:2]
            inputs, _, _ = val_set[img_idx]
            inputs = inputs.unsqueeze(0).to(device)
            with torch.no_grad():
                locs, conf = model(inputs)
            detections = detector(locs, conf, anchors)
            # Class 0 is skipped (the loop starts at 1).
            for cls_idx in range(1, detections.size(1)):
                dets = detections[0, cls_idx, :]
                # Keep only rows whose score (column 0) is positive.
                mask = dets[:, 0].gt(0.).expand(5, dets.size(0)).t()
                dets = torch.masked_select(dets, mask).view(-1, 5)
                if dets.shape[0] == 0:
                    continue
                pred_boxes = dets[:, 1:].cpu().numpy()
                pred_score = dets[:, 0].cpu().numpy()
                # Rescale from network input size to the original image:
                # columns 0/2 scale with height, columns 1/3 with width.
                pred_boxes[:, 0::2] /= args.img_height
                pred_boxes[:, 0::2] *= h
                pred_boxes[:, 1::2] /= args.img_width
                pred_boxes[:, 1::2] *= w
                all_boxes[cls_idx][img_idx] = np.hstack(
                    (pred_boxes,
                     pred_score[:, np.newaxis])).astype(np.float32, copy=False)

        with open(det_file, 'wb') as f:
            pickle.dump(all_boxes, f, pickle.HIGHEST_PROTOCOL)

        for cls_ind, cls in enumerate(val_set.labelmap):
            filename = dec_eval.get_voc_results_file_template(
                'test', cls, args.cacheDir)
            with open(filename, 'wt') as f:
                for im_ind, index in enumerate(val_set.img_files):
                    dets = all_boxes[cls_ind + 1][im_ind]
                    # `dets == []` misbehaves on ndarrays (NumPy >= 1.25 raises
                    # on the shape mismatch); len() handles [] and arrays.
                    if len(dets) == 0:
                        continue
                    for det in dets:
                        # format: [img_file  confidence, y1, x1, y2, x2]
                        f.write(
                            '{:s} {:.3f} {:.1f} {:.1f} {:.1f} {:.1f}\n'.format(
                                index, det[-1], det[0], det[1], det[2],
                                det[3]))
        ap05, ap07 = dec_eval.do_python_eval(dsets=val_set,
                                             output_dir=args.cacheDir,
                                             offline=False,
                                             use_07=True)
        print('ap05:{:.4f}, ap07:{:.4f}'.format(ap05, ap07))
        writer.add_scalar(os.path.join(scalar_dir, 'val_ap05'), ap05, epoch)
        writer.add_scalar(os.path.join(scalar_dir, 'val_ap07'), ap07, epoch)
        if ap05 > 0.71:
            print('ap05:{:.4f}'.format(ap05))
            torch.save(
                model.state_dict(),
                os.path.join(args.weightDst,
                             '{:d}_{:.4f}_model.pth'.format(epoch, epoch_loss)))
        ap05_history.append(ap05)
        np.savetxt(os.path.join(log_dir, 'ap_05.txt'),
                   ap05_history,
                   fmt='%.6f')
        ap07_history.append(ap07)
        np.savetxt(os.path.join(log_dir, 'ap_07.txt'),
                   ap07_history,
                   fmt='%.6f')
    # Flush any buffered TensorBoard events to disk.
    writer.close()
    print('Finish')
Esempio n. 2
0
def evaluation(args):
    """Run the detector over the test set and score it with ``dec_eval``.

    Loads weights from ``args.resume`` via ``load_dec_weights``, detects on
    every image of ``args.testDir``, pickles the raw detections to
    ``<cacheDir>/detections.pkl``, writes one results file per class (path
    from ``dec_eval.get_voc_results_file_template``) and evaluates them with
    ``dec_eval.do_python_eval(offline=True, use_07=True)``.

    Args:
        args: parsed command-line namespace. Reads testDir, annoDir,
            imgSuffix, annoSuffix, img_height, img_width, num_classes, resume,
            top_k, conf_thresh, nms_thresh and cacheDir.

    Returns:
        tuple: ``(ap05, ap07)`` exactly as returned by
        ``dec_eval.do_python_eval`` (the function previously discarded them).
    """
    data_transforms = dec_transforms.Compose([dec_transforms.ConvertImgFloat(),
                                              dec_transforms.Resize(args.img_height, args.img_width),
                                              dec_transforms.ToTensor()])

    dsets = dec_dataset_kaggle.NucleiCell(args.testDir, args.annoDir, data_transforms,
                                          imgSuffix=args.imgSuffix, annoSuffix=args.annoSuffix)

    model = dec_net.resnetssd50(pretrained=True, num_classes=args.num_classes)
    model = load_dec_weights(model, args.resume)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()
    detector = Detect(num_classes=args.num_classes,
                      top_k=args.top_k,
                      conf_thresh=args.conf_thresh,
                      nms_thresh=args.nms_thresh,
                      variance=[0.1, 0.2])
    anchorGen = Anchors(args.img_height, args.img_width)
    anchors = anchorGen.forward()

    det_file = os.path.join(args.cacheDir, 'detections.pkl')
    os.makedirs(args.cacheDir, exist_ok=True)

    # all_boxes[cls_idx][img_idx]: float32 array of rows
    # [y1, x1, y2, x2, score] in original-image pixels, or [] if empty.
    all_boxes = [[[] for _ in range(len(dsets))] for _ in range(args.num_classes)]
    for img_idx in range(len(dsets)):
        print('loading {}/{} image'.format(img_idx, len(dsets)))
        h, w = dsets.load_img(img_idx).shape[:2]
        inputs, _, _ = dsets[img_idx]
        inputs = inputs.unsqueeze(0).to(device)
        with torch.no_grad():
            locs, conf = model(inputs)
        detections = detector(locs, conf, anchors)
        # Class 0 is skipped (the loop starts at 1).
        for cls_idx in range(1, detections.size(1)):
            dets = detections[0, cls_idx, :]
            # Keep only rows whose score (column 0) is positive.
            mask = dets[:, 0].gt(0.).expand(5, dets.size(0)).t()
            dets = torch.masked_select(dets, mask).view(-1, 5)
            if dets.shape[0] == 0:
                continue
            pred_boxes = dets[:, 1:].cpu().numpy().astype(np.float32)
            pred_score = dets[:, 0].cpu().numpy()

            # Rescale from network input size to the original image:
            # columns 0/2 scale with height, columns 1/3 with width.
            pred_boxes[:, 0::2] /= args.img_height
            pred_boxes[:, 0::2] *= h
            pred_boxes[:, 1::2] /= args.img_width
            pred_boxes[:, 1::2] *= w

            cls_dets = np.hstack((pred_boxes, pred_score[:, np.newaxis])).astype(np.float32, copy=False)
            all_boxes[cls_idx][img_idx] = cls_dets

    with open(det_file, 'wb') as f:
        pickle.dump(all_boxes, f, pickle.HIGHEST_PROTOCOL)

    for cls_ind, cls in enumerate(dsets.labelmap):
        filename = dec_eval.get_voc_results_file_template('test', cls, args.cacheDir)
        with open(filename, 'wt') as f:
            for im_ind, index in enumerate(dsets.img_files):
                dets = all_boxes[cls_ind + 1][im_ind]
                # `dets == []` misbehaves on ndarrays (NumPy >= 1.25 raises on
                # the shape mismatch); len() handles both [] and arrays.
                if len(dets) == 0:
                    continue
                for det in dets:
                    # format: [img_file  confidence, y1, x1, y2, x2]
                    f.write('{:s} {:.3f} {:.1f} {:.1f} {:.1f} {:.1f}\n'.format(index,
                                                                               det[-1],
                                                                               det[0],
                                                                               det[1],
                                                                               det[2],
                                                                               det[3]))
    ap05, ap07 = dec_eval.do_python_eval(dsets=dsets,
                                         output_dir=args.cacheDir,
                                         offline=True,
                                         use_07=True)
    print('ap05:{:.4f}, ap07:{:.4f}'.format(ap05, ap07))
    return ap05, ap07
Esempio n. 3
0
def test(args):
    """Show the detector's boxes on every test image in an OpenCV window.

    Loads the checkpoint at ``args.resume`` into ``dec_net.resnetssd50``,
    runs ``Detect`` on each image of ``args.testDir`` and draws every
    positive-score box with its score on the original image. Any key advances
    to the next image; ``q`` exits the process. The process also exits after
    the last image.

    Args:
        args: parsed command-line namespace. Reads testDir, annoDir,
            imgSuffix, annoSuffix, img_height, img_width, num_classes, resume,
            top_k, conf_thresh and nms_thresh.
    """
    data_transforms = dec_transforms.Compose([
        dec_transforms.ConvertImgFloat(),
        dec_transforms.Resize(args.img_height, args.img_width),
        dec_transforms.ToTensor()
    ])

    dsets = dec_dataset_kaggle.NucleiCell(args.testDir,
                                          args.annoDir,
                                          data_transforms,
                                          imgSuffix=args.imgSuffix,
                                          annoSuffix=args.annoSuffix)

    model = dec_net.resnetssd50(pretrained=True, num_classes=args.num_classes)
    print('Resuming training weights from {} ...'.format(args.resume))
    # Load onto the CPU first: a checkpoint may reference a GPU (e.g. cuda:1)
    # that does not exist here; load_state_dict copies into the model anyway.
    pretrained_dict = torch.load(args.resume, map_location='cpu')
    # Checkpoints saved from an nn.DataParallel wrapper prefix every key with
    # 'module.'. Strip it only where present so plain checkpoints load too
    # (blindly dropping 7 characters corrupted their keys).
    prefix = 'module.'
    trained_dict = {(k[len(prefix):] if k.startswith(prefix) else k): v
                    for k, v in pretrained_dict.items()}
    model_dict = model.state_dict()
    model_dict.update(trained_dict)
    model.load_state_dict(model_dict)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()
    detector = Detect(num_classes=args.num_classes,
                      top_k=args.top_k,
                      conf_thresh=args.conf_thresh,
                      nms_thresh=args.nms_thresh,
                      variance=[0.1, 0.2])
    anchorGen = Anchors(args.img_height, args.img_width)
    anchors = anchorGen.forward()
    cv2.namedWindow('img')
    for img_idx in range(len(dsets)):
        ori_img = dsets.load_img(img_idx)
        h, w = ori_img.shape[:2]
        inputs, _, _ = dsets[img_idx]
        inputs = inputs.unsqueeze(0).to(device)
        with torch.no_grad():
            locs, conf = model(inputs)
        detections = detector(locs, conf, anchors)
        # Class 0 is skipped (the loop starts at 1).
        for cls_idx in range(1, detections.size(1)):
            dets = detections[0, cls_idx, :]
            # Keep only rows whose score (column 0) is positive.
            mask = dets[:, 0].gt(0.).expand(5, dets.size(0)).t()
            dets = torch.masked_select(dets, mask).view(-1, 5)
            if dets.shape[0] == 0:
                continue
            for score, y1, x1, y2, x2 in dets.cpu().numpy():
                # Rescale from network input size to original-image pixels.
                y1 = int(float(y1) / args.img_height * h)
                x1 = int(float(x1) / args.img_width * w)
                y2 = int(float(y2) / args.img_height * h)
                x2 = int(float(x2) / args.img_width * w)
                cv2.rectangle(ori_img, (x1, y1), (x2, y2), (0, 255, 0), 2, 2)
                cv2.putText(ori_img, "%.2f" % score, (x1, y1 + 20),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 0, 255))
        cv2.imshow('img', ori_img)
        k = cv2.waitKey(0)
        if k & 0xFF == ord('q'):
            cv2.destroyAllWindows()
            # `exit()` is a site-module convenience; SystemExit is always there.
            raise SystemExit
    cv2.destroyAllWindows()
    raise SystemExit