Exemple #1
0
def train(dec_weights, seg_weights=None):
    """Train the segmentation network on top of a frozen detection network.

    The detector (``dec_net.resnetssd50``) is restored from ``dec_weights``,
    switched to eval mode and frozen; only the trainable parameters of
    ``SEG_NET`` are optimised (Adam, exponential LR decay with gamma 0.98).

    Each epoch runs a 'train' phase followed by a 'val' phase:

    * train: the epoch loss (batch losses weighted by batch size, divided by
      the dataset size) is appended to ``train_loss.txt``; the weights are
      saved to ``weights/end_model.pth`` and, every 5th epoch, also to
      ``weights/<epoch>_<loss>_model.pth``.
    * val: ``seg_eval.eval`` is run and its two return values are appended
      to ``ap_05.txt`` and ``ap_07.txt``.

    Args:
        dec_weights: path to the detection-model ``state_dict`` checkpoint.
            A leading ``module.`` key prefix (present when the model was
            saved while wrapped in ``nn.DataParallel``) is stripped.
        seg_weights: optional path to a segmentation-model ``state_dict`` to
            start from; when falsy the segmentation model is not restored.
    """
    if not os.path.exists('weights'):
        os.mkdir('weights')

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    #-----------------load detection model -------------------------
    dec_model = dec_net.resnetssd50(pretrained=False,
                                    num_classes=cfg.num_classes)
    # map_location lets a checkpoint saved on GPU load on a CPU-only host.
    resume_dict = torch.load(dec_weights, map_location=device)
    # Strip the DataParallel "module." prefix only when it is present; an
    # unconditional k[7:] mangles the keys of single-GPU checkpoints.
    resume_dict = {
        (k[len('module.'):] if k.startswith('module.') else k): v
        for k, v in resume_dict.items()
    }
    dec_model.load_state_dict(resume_dict)
    dec_model.to(device)
    #-----------------load segmentation model -------------------------
    seg_model = SEG_NET(num_classes=cfg.num_classes)
    if seg_weights:
        seg_model.load_state_dict(torch.load(seg_weights, map_location=device))
    seg_model.to(device)
    ##--------------------------------------------------------------

    data_transforms = {
        'train':
        seg_transforms.Compose([
            seg_transforms.ConvertImgFloat(),
            seg_transforms.PhotometricDistort(),
            seg_transforms.Expand(max_scale=2, mean=(0.485, 0.456, 0.406)),
            seg_transforms.RandomSampleCrop(),
            seg_transforms.RandomMirror_w(),
            seg_transforms.RandomMirror_h(),
            seg_transforms.Resize(cfg.img_sizes),
            seg_transforms.ToTensor()
        ]),
        'val':
        seg_transforms.Compose(
            [seg_transforms.ConvertImgFloat(),
             seg_transforms.ToTensor()])
    }

    dsets = {
        x: CellDataset(root=cfg.root, datatype=x, transform=data_transforms[x])
        for x in ['train', 'val']
    }

    ## Visualization of input data and GT ###################
    # viewDatasets.view_dataset(dsets['train'])
    #########################################################

    dataloader_train = torch.utils.data.DataLoader(
        dsets['train'],
        batch_size=cfg.batch_size,
        shuffle=True,
        num_workers=4,
        collate_fn=seg_collater.collater,
        pin_memory=True)

    optimizer = optim.Adam(params=filter(lambda p: p.requires_grad,
                                         seg_model.parameters()),
                           lr=cfg.init_lr)
    scheduler = lr_scheduler.ExponentialLR(optimizer,
                                           gamma=0.98,
                                           last_epoch=-1)
    criterion = SEG_loss()

    detector = Detect(num_classes=cfg.num_classes,
                      top_k=cfg.top_k,
                      conf_thresh=cfg.conf_thresh,
                      nms_thresh=cfg.nms_thresh,
                      variance=cfg.variances)

    anchorGen = Anchors(cfg.img_sizes)
    anchors = anchorGen.forward()

    #-------------------------------------------------------------------
    # The detector only supplies detections and features: keep it frozen.
    dec_model.eval()
    for param in dec_model.parameters():
        param.requires_grad = False
    #-------------------------------------------------------------------
    train_loss_dict = []
    ap05_dict = []
    ap07_dict = []
    for epoch in range(cfg.num_epochs):
        print('Epoch {}/{}'.format(epoch, cfg.num_epochs - 1))
        print('-' * 10)

        for phase in ['train', 'val']:
            if phase == 'train':
                seg_model.train()
                running_loss = 0.0
                for inputs, bboxes, labels, masks in dataloader_train:
                    inputs = inputs.to(device)
                    with torch.no_grad():
                        locs, conf, feat_seg = dec_model(inputs)
                        detections = detector(locs, conf, anchors)
                        # viewDatasets.view_detections(inputs, detections)

                    optimizer.zero_grad()
                    with torch.enable_grad():
                        outputs = seg_model(detections, feat_seg)
                        loss = criterion(outputs, bboxes, labels, masks)
                        loss.backward()
                        optimizer.step()

                    # statistics
                    running_loss += loss.item() * inputs.size(0)

                # Since PyTorch 1.1 the scheduler must step *after* the
                # optimizer updates; stepping first skipped the initial LR.
                scheduler.step()

                epoch_loss = running_loss / len(dsets[phase])

                print('{} Loss: {:.4f}'.format(phase, epoch_loss))
                train_loss_dict.append(epoch_loss)
                np.savetxt('train_loss.txt', train_loss_dict, fmt='%.6f')
                if epoch % 5 == 0:
                    torch.save(
                        seg_model.state_dict(),
                        os.path.join(
                            'weights',
                            '{:d}_{:.4f}_model.pth'.format(epoch, epoch_loss)))
                torch.save(seg_model.state_dict(),
                           os.path.join('weights', 'end_model.pth'))

            else:
                seg_model.eval()  # Set model to evaluate mode
                with torch.no_grad():
                    ap05, ap07 = seg_eval.eval(dec_model=dec_model,
                                               seg_model=seg_model,
                                               dsets=dsets[phase],
                                               device=device,
                                               detector=detector,
                                               anchors=anchors)
                    ap05_dict.append(ap05)
                    np.savetxt('ap_05.txt', ap05_dict, fmt='%.6f')
                    ap07_dict.append(ap07)
                    np.savetxt('ap_07.txt', ap07_dict, fmt='%.6f')
Exemple #2
0
def test(dec_weights, seg_weights):
    """Visualise instance-segmentation results on the 'test' split.

    For every test image the frozen detector produces detections and
    features, ``SEG_NET`` turns them into per-detection mask patches, and each
    detection whose score is at least ``cfg.conf_thresh`` is drawn onto the
    original image. The drawing is a randomly coloured mask, thresholded at
    ``cfg.seg_thresh``, plus a box and a class/score label. Each result is
    written to ``<img_id>.jpg`` and shown in an OpenCV window. Press any key
    for the next image, or ``q`` to quit.

    The process exits with status 0 when all images have been shown or ``q``
    is pressed.

    Args:
        dec_weights: path to the detection-model checkpoint; a leading
            ``module.`` key prefix (from ``nn.DataParallel``) is stripped.
        seg_weights: path to the segmentation-model checkpoint, or a falsy
            value to skip restoring it.
    """
    import sys

    np.random.seed(0)  # reproducible mask colours across runs
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    data_transforms = seg_transforms.Compose(
        [seg_transforms.ConvertImgFloat(),
         seg_transforms.ToTensor()])

    datatype = 'test'

    dsets = CellDataset(root=cfg.root,
                        datatype=datatype,
                        transform=data_transforms)

    #-----------------load detection model -------------------------
    dec_model = dec_net.resnetssd50(pretrained=False,
                                    num_classes=cfg.num_classes)
    # map_location lets a checkpoint saved on GPU load on a CPU-only host.
    resume_dict = torch.load(dec_weights, map_location=device)
    # Strip the DataParallel "module." prefix only when it is present.
    resume_dict = {
        (k[len('module.'):] if k.startswith('module.') else k): v
        for k, v in resume_dict.items()
    }
    dec_model.load_state_dict(resume_dict)
    dec_model.to(device)
    #-----------------load segmentation model -------------------------
    seg_model = SEG_NET(num_classes=cfg.num_classes)
    if seg_weights:
        seg_model.load_state_dict(torch.load(seg_weights, map_location=device))
    seg_model.to(device)

    dec_model.eval()
    seg_model.eval()

    detector = Detect(num_classes=cfg.num_classes,
                      top_k=cfg.top_k,
                      conf_thresh=cfg.conf_thresh,
                      nms_thresh=cfg.nms_thresh,
                      variance=cfg.variances)

    anchorGen = Anchors(cfg.img_sizes)
    anchors = anchorGen.forward()

    for idx_img in range(len(dsets)):
        inputs, _gt_boxes, _gt_classes, _gt_masks = dsets[idx_img]
        ori_img = dsets.load_image(idx_img)
        h, w = ori_img.shape[:2]

        x = inputs.unsqueeze(0).to(device)
        with torch.no_grad():
            locs, conf, feat_seg = dec_model(x)
            detections = detector(locs, conf, anchors)
            mask_patches, mask_dets = seg_model(detections, feat_seg)

        # Outer level: one entry per image in the batch; inner: per object.
        for batch_mask_patches, batch_mask_dets in zip(mask_patches, mask_dets):
            for patch_t, det_t in zip(batch_mask_patches, batch_mask_dets):
                # det row layout used below: [y1, x1, y2, x2, score, class]
                dets = det_t.data.cpu().numpy()
                score = dets[4]
                if score < cfg.conf_thresh:
                    continue
                class_obj = dets[5]

                y1, x1, y2, x2 = dets[0:4]
                # Plain ints: cv2 drawing calls reject some numpy scalars.
                y1 = max(0, int(np.round(y1)))
                x1 = max(0, int(np.round(x1)))
                y2 = min(int(np.round(y2)), h - 1)
                x2 = min(int(np.round(x2)), w - 1)
                if x2 <= x1 or y2 <= y1:
                    # Degenerate or out-of-image box: cv2.resize rejects an
                    # empty/negative target size, so there is nothing to draw.
                    continue

                mask_patch = patch_t.data.cpu().numpy()
                mask_patch = cv2.resize(mask_patch, (x2 - x1, y2 - y1))

                mask = np.zeros((h, w), dtype=np.float32)
                mask[y1:y2, x1:x2] = mask_patch >= cfg.seg_thresh

                color = np.random.rand(3)
                mask = np.repeat(mask[:, :, np.newaxis], 3, axis=2)
                mskd = ori_img * mask
                clmsk = mask * (color * 256)
                # Blend: 70% colour / 30% original inside the mask.
                ori_img = ori_img + 0.7 * clmsk - 0.7 * mskd
                cv2.rectangle(ori_img, (x1, y1), (x2, y2), (255, 0, 0), 2, 1)
                cv2.putText(ori_img,
                            dsets.classes[int(class_obj)] + "%.2f" % score,
                            (x1, y1 + 20), cv2.FONT_HERSHEY_SIMPLEX, 0.6,
                            (255, 255, 255))
        vis = np.uint8(ori_img)
        cv2.imwrite("{}.jpg".format(dsets.img_ids[idx_img][1]), vis)
        cv2.imshow('img', vis)
        k = cv2.waitKey(0)
        if k & 0xFF == ord('q'):
            cv2.destroyAllWindows()
            sys.exit(0)

    cv2.destroyAllWindows()
    # Normal completion: exit status 0 (previously exit(1) reported failure).
    sys.exit(0)
Exemple #3
0
def test(resume=None):
    """Visualise detector output on the 'test' split.

    Runs ``dec_net.resnetssd50`` plus ``Detect`` on every test image and draws
    a box and a "<class><score>" label for every detection whose score is
    positive. Class index 0 is skipped. Each image is shown in an OpenCV
    window; press any key for the next one or ``q`` to quit. The process
    exits when done or when ``q`` is pressed.

    Args:
        resume: optional path to a detector ``state_dict`` checkpoint; a
            leading ``module.`` key prefix (from ``nn.DataParallel``) is
            stripped. When ``None`` only the ``pretrained=True`` weights are
            used.
    """
    import sys

    data_transforms = dec_transform.Compose([dec_transform.ConvertImgFloat(),
                                             dec_transform.ToTensor()])

    datatype = 'test'

    dsets = CellDataset(root=cfg.root,
                        datatype=datatype,
                        transform=data_transforms)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    model = dec_net.resnetssd50(pretrained=True, num_classes=cfg.num_classes)
    if resume is not None:
        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        resume_dict = torch.load(resume, map_location=device)
        # Strip the DataParallel "module." prefix only when it is present.
        resume_dict = {
            (k[len('module.'):] if k.startswith('module.') else k): v
            for k, v in resume_dict.items()
        }
        model.load_state_dict(resume_dict)

    model = model.to(device)
    model.eval()

    detector = Detect(num_classes=cfg.num_classes,
                      top_k=cfg.top_k,
                      conf_thresh=cfg.conf_thresh,
                      nms_thresh=cfg.nms_thresh,
                      variance=cfg.variances)

    anchorGen = Anchors(cfg.img_sizes)
    anchors = anchorGen.forward()

    cv2.namedWindow('img')
    for i in range(len(dsets)):
        img, _bboxes, _labels = dsets[i]
        ori_img = dsets.load_image(i)
        x = img.unsqueeze(0).to(device)

        # Inference only: no_grad avoids building an autograd graph.
        with torch.no_grad():
            locs, conf = model(x)
            detections = detector(locs, conf, anchors)

        # Index 0 of dim 1 is skipped (presumably background — confirm).
        for j in range(1, detections.size(1)):
            dets = detections[0, j, :]
            # Rows are [score, y1, x1, y2, x2]; keep those with score > 0.
            dets = dets[dets[:, 0] > 0]
            if dets.numel() == 0:
                continue
            scores = dets[:, 0].cpu().numpy()
            for box, score in zip(dets[:, 1:], scores):
                y1, x1, y2, x2 = (int(v) for v in box)

                cv2.rectangle(ori_img, (x1, y1), (x2, y2), (0, 255, 0), 2, 2)
                cv2.putText(ori_img,
                            dsets.classes[int(j)] + "%.2f" % score,
                            (x1, y1 + 20),
                            cv2.FONT_HERSHEY_SIMPLEX,
                            0.6,
                            (255, 255, 255))
        cv2.imshow('img', ori_img)
        k = cv2.waitKey(0)
        if k & 0xFF == ord('q'):
            cv2.destroyAllWindows()
            sys.exit()
    cv2.destroyAllWindows()
    sys.exit()
Exemple #4
0
    return ap05, ap07


if __name__ == '__main__':
    # Script entry: build the 'test' split, restore the detector from
    # end_model.pth and create the Detect post-processing module.
    resume = "end_model.pth"
    data_transforms = dec_transform.Compose(
        [dec_transform.ConvertImgFloat(),
         dec_transform.ToTensor()])

    datatype = 'test'

    dsets = CellDataset(root=cfg.root,
                        datatype=datatype,
                        transform=data_transforms)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    model = dec_net.resnetssd50(pretrained=True, num_classes=cfg.num_classes)
    if resume is not None:
        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        resume_dict = torch.load(resume, map_location=device)
        # Strip the DataParallel "module." prefix only when it is present;
        # single-GPU checkpoints have no prefix to remove.
        resume_dict = {
            (k[len('module.'):] if k.startswith('module.') else k): v
            for k, v in resume_dict.items()
        }
        model.load_state_dict(resume_dict)

    model = model.to(device)
    model.eval()

    detector = Detect(num_classes=cfg.num_classes,
                      top_k=cfg.top_k,
                      conf_thresh=cfg.conf_thresh,
                      nms_thresh=cfg.nms_thresh,
                      variance=cfg.variances)
Exemple #5
0
def train(resume=None):
    """Train the SSD detector (``dec_net.resnetssd50``) on CellDataset.

    Uses SGD (momentum 0.9) with a MultiStepLR schedule (gamma 0.1 at
    ``cfg.milestones``), and wraps the model in ``nn.DataParallel`` when more
    than one GPU is visible. Each epoch runs a 'train' phase and a 'val'
    phase:

    * train: the epoch loss (batch losses weighted by batch size, divided by
      the dataset size) is appended to ``train_loss.txt``; the weights are
      saved to ``weights/end_model.pth`` and, every 5th epoch, also to
      ``weights/<epoch>_<loss>_model.pth``.
    * val: ``eval`` is run and its two return values are appended to
      ``ap_05.txt`` and ``ap_07.txt``.

    Note: the saved ``state_dict`` keys carry a ``module.`` prefix only when
    the model was wrapped in ``nn.DataParallel``.

    Args:
        resume: optional path to a checkpoint to start from; a leading
            ``module.`` key prefix is stripped before loading.
    """
    if not os.path.exists('weights'):
        os.mkdir('weights')

    data_transforms = {
        'train':
        dec_transform.Compose([
            dec_transform.ConvertImgFloat(),
            dec_transform.PhotometricDistort(),
            dec_transform.Expand(max_scale=2, mean=(0.485, 0.456, 0.406)),
            dec_transform.RandomSampleCrop(),
            dec_transform.RandomMirror_w(),
            dec_transform.RandomMirror_h(),
            dec_transform.Resize(cfg.img_sizes),
            dec_transform.ToTensor()
        ]),
        'val':
        dec_transform.Compose(
            [dec_transform.ConvertImgFloat(),
             dec_transform.ToTensor()])
    }

    dsets = {
        x: CellDataset(root=cfg.root, datatype=x, transform=data_transforms[x])
        for x in ['train', 'val']
    }

    # view_dataset(dsets['train'])

    dataloader_train = torch.utils.data.DataLoader(
        dsets['train'],
        batch_size=cfg.batch_size,
        shuffle=True,
        num_workers=4,
        collate_fn=dec_collater.collater,
        pin_memory=True)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    model = dec_net.resnetssd50(pretrained=True, num_classes=cfg.num_classes)

    if resume is not None:
        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        resume_dict = torch.load(resume, map_location=device)
        # Checkpoints written below carry "module." only when DataParallel
        # was used, so strip the prefix only when it is actually present.
        resume_dict = {
            (k[len('module.'):] if k.startswith('module.') else k): v
            for k, v in resume_dict.items()
        }
        model.load_state_dict(resume_dict)

    # data parallel
    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model)
    model = model.to(device)

    optimizer = optim.SGD(model.parameters(), lr=cfg.init_lr, momentum=0.9)
    scheduler = lr_scheduler.MultiStepLR(optimizer,
                                         milestones=cfg.milestones,
                                         gamma=0.1)

    criterion = LossModule(img_size=cfg.img_sizes,
                           num_classes=cfg.num_classes,
                           device=device)

    # for validation data
    detector = Detect(num_classes=cfg.num_classes,
                      top_k=cfg.top_k,
                      conf_thresh=cfg.conf_thresh,
                      nms_thresh=cfg.nms_thresh,
                      variance=cfg.variances)

    train_loss_dict = []
    ap05_dict = []
    ap07_dict = []
    for epoch in range(cfg.num_epochs):
        print('Epoch {}/{}'.format(epoch, cfg.num_epochs - 1))
        print('-' * 10)

        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
                running_loss = 0.0
                for inputs, bboxes, labels in dataloader_train:
                    inputs = inputs.to(device)
                    optimizer.zero_grad()

                    # Gradients are always needed in this branch; enable_grad
                    # also guards against a caller-level no_grad context.
                    with torch.enable_grad():
                        outputs = model(inputs)
                        loss_locs, loss_conf = criterion(
                            outputs, bboxes, labels)
                        loss = loss_locs + loss_conf
                        loss.backward()
                        optimizer.step()

                    # statistics
                    running_loss += loss.item() * inputs.size(0)

                # Since PyTorch 1.1 the scheduler must step *after* the
                # optimizer updates; stepping first skipped the initial LR
                # and made every milestone fire one epoch early.
                scheduler.step()

                epoch_loss = running_loss / len(dsets[phase])

                print('{} Loss: {:.4f}'.format(phase, epoch_loss))
                train_loss_dict.append(epoch_loss)
                np.savetxt('train_loss.txt', train_loss_dict, fmt='%.6f')
                if epoch % 5 == 0:
                    torch.save(
                        model.state_dict(),
                        os.path.join(
                            'weights',
                            '{:d}_{:.4f}_model.pth'.format(epoch, epoch_loss)))
                torch.save(model.state_dict(),
                           os.path.join('weights', 'end_model.pth'))

            else:
                model.eval()  # Set model to evaluate mode
                with torch.no_grad():
                    # Presumably a module-level evaluation function shadowing
                    # the builtin (the builtin eval takes no such keywords).
                    ap05, ap07 = eval(model=model,
                                      dsets=dsets[phase],
                                      device=device,
                                      detector=detector,
                                      datatype=phase)
                    ap05_dict.append(ap05)
                    np.savetxt('ap_05.txt', ap05_dict, fmt='%.6f')
                    ap07_dict.append(ap07)
                    np.savetxt('ap_07.txt', ap07_dict, fmt='%.6f')