def detection_evaluation(self, args, ov_thresh=0.5, use_07_metric=True):
        """Evaluate bounding-box detection on the Kaggle ``'test'`` phase.

        Loads the weights named by ``args.resume``, runs ``test_inference``
        with ``bbox_flag=True`` on every test image, rescales the predicted
        boxes by ``image size / network input size``, accumulates the
        per-detection TP/FP flags produced by ``evaluation.bbox_evaluation``
        and computes the average precision with ``evaluation.voc_ap``.

        Args:
            args: Options object; this method reads ``resume``, ``data_dir``,
                ``input_h`` and ``input_w``.
            ov_thresh: Overlap threshold forwarded to
                ``evaluation.bbox_evaluation``.
            use_07_metric: Forwarded unchanged to ``evaluation.voc_ap``.

        Returns:
            The value returned by ``evaluation.voc_ap``. It is also printed,
            so existing callers that ignore the return value are unaffected.
        """
        self.load_weights(resume=args.resume)
        self.model.eval()
        self.model = self.model.to(self.device)

        dsets = Kaggle(data_dir=args.data_dir, phase='test')
        num_images = len(dsets)

        all_tp = []
        all_fp = []
        all_scores = []
        npos = 0
        for index in range(num_images):
            print('processing {}/{} images'.format(index, num_images))
            img = dsets.load_image(index)
            height, width, _ = img.shape

            bboxes = self.test_inference(args, img, bbox_flag=True)
            if bboxes is None:
                # Nothing was predicted for this image, but its ground-truth
                # boxes must still be counted as positives.
                npos += len(dsets.load_annotation(index, type='bbox'))
                continue

            bboxes = np.asarray(bboxes, np.float32)

            # Columns 0 and 2 are rescaled along the image height, columns 1
            # and 3 along the width (presumably a (y1, x1, y2, x2, ...) layout
            # -- confirm against test_inference).
            bboxes[:, 0:4:2] = bboxes[:, 0:4:2] / args.input_h * height
            bboxes[:, 1:4:2] = bboxes[:, 1:4:2] / args.input_w * width

            # all_scores and npos are threaded through the helper, which
            # returns their updated values.
            fp, tp, all_scores, npos = evaluation.bbox_evaluation(
                index=index,
                dsets=dsets,
                BB_bboxes=bboxes,
                all_scores=all_scores,
                npos=npos,
                ov_thresh=ov_thresh)
            all_fp.extend(fp)
            all_tp.extend(tp)

        # Compute precision/recall over all detections, highest score first.
        all_fp = np.asarray(all_fp)
        all_tp = np.asarray(all_tp)
        all_scores = np.asarray(all_scores)
        order = np.argsort(-all_scores)
        cum_fp = np.cumsum(all_fp[order])
        cum_tp = np.cumsum(all_tp[order])
        rec = cum_tp / float(npos)
        # avoid divide by zero in case the first detection matches a difficult
        # ground truth
        prec = cum_tp / np.maximum(cum_tp + cum_fp, np.finfo(np.float64).eps)
        ap = evaluation.voc_ap(rec, prec, use_07_metric=use_07_metric)
        print("ap@{} is {}".format(ov_thresh, ap))
        return ap
    def instance_segmentation_evaluation(self,
                                         args,
                                         ov_thresh=0.5,
                                         use_07_metric=True):
        """Evaluate instance segmentation on the Kaggle ``'test'`` phase.

        Loads the weights named by ``args.resume``, runs ``test_inference`` on
        every test image, accumulates the per-detection TP/FP flags produced
        by ``evaluation.seg_evaluation`` and computes the average precision
        with ``evaluation.voc_ap``. The mean of the overlaps collected by
        ``evaluation.seg_evaluation`` is printed as well.

        Args:
            args: Options object; this method reads ``resume`` and
                ``data_dir`` and forwards ``args`` to ``test_inference``.
            ov_thresh: Overlap threshold forwarded to
                ``evaluation.seg_evaluation``.
            use_07_metric: Forwarded unchanged to ``evaluation.voc_ap``.

        Returns:
            The value returned by ``evaluation.voc_ap``. It is also printed,
            so existing callers that ignore the return value are unaffected.
        """
        self.load_weights(resume=args.resume)
        self.model.eval()
        self.model = self.model.to(self.device)

        dsets = Kaggle(data_dir=args.data_dir, phase='test')
        num_images = len(dsets)

        all_tp = []
        all_fp = []
        all_scores = []
        temp_overlaps = []
        npos = 0
        for index in range(num_images):
            print('processing {}/{} images'.format(index, num_images))
            img = dsets.load_image(index)
            predictions = self.test_inference(args, img)
            if predictions is None:
                # Nothing was predicted for this image, but its ground-truth
                # annotations must still be counted as positives.
                npos += len(dsets.load_annotation(index, type='bbox'))
                continue
            pr_masks, pr_dets = predictions

            # all_scores, npos and temp_overlaps are threaded through the
            # helper, which returns their updated values.
            fp, tp, all_scores, npos, temp_overlaps = evaluation.seg_evaluation(
                index=index,
                dsets=dsets,
                BB_masks=pr_masks,
                BB_dets=pr_dets,
                all_scores=all_scores,
                npos=npos,
                temp_overlaps=temp_overlaps,
                ov_thresh=ov_thresh)

            all_fp.extend(fp)
            all_tp.extend(tp)

        # Compute precision/recall over all detections, highest score first.
        all_fp = np.asarray(all_fp)
        all_tp = np.asarray(all_tp)
        all_scores = np.asarray(all_scores)
        order = np.argsort(-all_scores)
        cum_fp = np.cumsum(all_fp[order])
        cum_tp = np.cumsum(all_tp[order])
        rec = cum_tp / float(npos)
        # avoid divide by zero in case the first detection matches a difficult
        # ground truth
        prec = cum_tp / np.maximum(cum_tp + cum_fp, np.finfo(np.float64).eps)
        ap = evaluation.voc_ap(rec, prec, use_07_metric=use_07_metric)
        print("ap@{} is {}".format(ov_thresh, ap))
        print("temp overlaps = {}".format(np.mean(temp_overlaps)))
        return ap
    def test(self, args, save_flag=False):
        """Run inference on every Kaggle ``'test'`` image and visualise it.

        Loads the weights named by ``args.resume``, then for each test image
        calls ``test_inference`` and hands the predicted mask patches and
        detections to ``imshow_instance_segmentation``. Images for which
        ``test_inference`` returns ``None`` are skipped.

        Args:
            args: Options object; this method reads ``resume`` and
                ``data_dir`` and forwards ``args`` to ``test_inference``.
            save_flag: When true, the ``save_result`` directory is created
                if missing and the flag is forwarded to
                ``imshow_instance_segmentation`` (presumably making it write
                its output there -- confirm against that method).
        """
        self.load_weights(resume=args.resume)
        self.model = self.model.to(self.device)
        self.model.eval()

        if save_flag and not os.path.exists("save_result"):
            os.mkdir("save_result")

        dsets = Kaggle(data_dir=args.data_dir, phase='test')

        for index in range(len(dsets)):
            img = dsets.load_image(index)
            predictions = self.test_inference(args, img)
            if predictions is None:
                continue
            mask_patches, mask_dets = predictions
            # Forward the caller's choice: the flag used to be hard-coded to
            # False here, so save_flag=True created the directory but never
            # reached the visualisation call.
            self.imshow_instance_segmentation(mask_patches,
                                              mask_dets,
                                              out_img=img.copy(),
                                              img_id=dsets.img_ids[index],
                                              save_flag=save_flag)
    # Example no. 4
    # 0
    def train(self, args):
        """Train the model on the Kaggle data and validate after every epoch.

        Builds an Adam optimizer over the parameters that require gradients,
        an exponential learning-rate schedule (``gamma=0.96``), the detection
        and segmentation losses, and train/val data loaders. For each epoch
        in ``range(args.start_epoch, args.epochs)`` it runs ``self.training``
        and ``self.validating``, rewrites ``train_loss.txt`` and
        ``val_loss.txt`` with the loss history, and saves checkpoints under
        ``weights/``.

        Args:
            args: Options object; this method reads ``lr``, ``input_h``,
                ``input_w``, ``data_dir``, ``batch_size``, ``workers``,
                ``start_epoch``, ``epochs`` and ``resume``.
        """
        if not os.path.exists("weights"):
            os.mkdir("weights")

        self.model = self.model.to(self.device)
        self.model.train()

        trainable_params = filter(lambda p: p.requires_grad, self.model.parameters())
        optimizer = torch.optim.Adam(trainable_params, lr=args.lr)
        scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.96, last_epoch=-1)

        loss_dec = DetectionLossAll(kp_radius=cfg.KP_RADIUS)
        loss_seg = seg_loss.SEG_loss(height=args.input_h, width=args.input_w)

        # The 'train' phase gets the augmenting transforms; 'val' is only
        # converted and resized.
        data_trans = {
            'train': transforms.Compose([transforms.ConvertImgFloat(),
                                         transforms.PhotometricDistort(),
                                         transforms.Expand(max_scale=2, mean=(0, 0, 0)),
                                         transforms.RandomMirror_w(),
                                         transforms.RandomMirror_h(),
                                         transforms.Resize(h=args.input_h, w=args.input_w)]),
            'val': transforms.Compose([transforms.ConvertImgFloat(),
                                       transforms.Resize(h=args.input_h, w=args.input_w)]),
        }

        dsets = {phase: Kaggle(data_dir=args.data_dir,
                               phase=phase,
                               transform=data_trans[phase])
                 for phase in ['train', 'val']}

        def make_loader(phase, shuffle):
            # Both loaders share every setting except the dataset and shuffling.
            return torch.utils.data.DataLoader(dsets[phase],
                                               batch_size=args.batch_size,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               shuffle=shuffle,
                                               collate_fn=collater)

        train_loader = make_loader('train', shuffle=True)
        val_loader = make_loader('val', shuffle=False)

        train_loss_history = []
        val_loss_history = []
        for epoch in range(args.start_epoch, args.epochs):
            print('Epoch {}/{}'.format(epoch, args.epochs - 1))
            print('-' * 10)

            train_epoch_loss = self.training(train_loader, loss_dec, loss_seg, optimizer, epoch, dsets['train'])
            train_loss_history.append(train_epoch_loss)

            # Step the schedule only after this epoch's optimizer updates
            # (self.training receives the optimizer and is expected to run
            # them). Since PyTorch 1.1, stepping the scheduler before the
            # optimizer skips the first value of the learning-rate schedule.
            scheduler.step()

            val_epoch_loss = self.validating(val_loader, loss_dec, loss_seg, epoch, dsets['val'])
            val_loss_history.append(val_epoch_loss)

            # Rewrite the full loss history every epoch so that an
            # interrupted run still leaves usable files behind.
            np.savetxt('train_loss.txt', train_loss_history, fmt='%.6f')
            np.savetxt('val_loss.txt', val_loss_history, fmt='%.6f')

            # Periodic snapshot every 5 epochs (never at epoch 0) ...
            if epoch % 5 == 0 and epoch > 0:
                snapshot_name = '{:d}_{:.4f}_model.pth'.format(epoch, train_epoch_loss)
                torch.save(self.model.state_dict(), os.path.join('weights', snapshot_name))
            # ... plus the rolling checkpoint, overwritten after every epoch.
            torch.save(self.model.state_dict(), os.path.join('weights', args.resume))