Пример #1
0
    def train_data_transform(self, index):
        """Load training sample ``index`` and apply random augmentation.

        Rotation and cropping are each applied when a uniform draw exceeds 0.5,
        and only if enabled via ``config.use_rotate`` / ``config.use_crop``.
        A crop uses a scale drawn from [0.1, 1.0) and is followed by the "rai"
        label filter. The image is then resized to 512x512, mean-subtracted
        with the ``config`` channel means and transposed from HWC to CHW.

        Returns:
            (img, labels): the transformed image and its matching labels.
        """
        image = self.read_image(self.images_dir, index)
        sample_labels = self.all_labels[index]

        # Both coin flips are drawn up front, before any transform runs.
        do_rotate = (random.random() if config.use_rotate else 0) > 0.5
        do_crop = (random.random() if config.use_crop else 0) > 0.5

        if do_rotate:
            sample_labels, image, _ = ImgTransform.RotateImageWithLabel(
                sample_labels, data=image)

        if do_crop:
            crop_scale = 0.1 + random.random() * 0.9
            sample_labels, image, _ = ImgTransform.CropImageWithLabel(
                sample_labels, data=image, scale=crop_scale)
            sample_labels = PixelLinkIC15Dataset.filter_labels(
                sample_labels, method="rai")

        sample_labels, image, _ = ImgTransform.ResizeImageWithLabel(
            sample_labels, (512, 512), data=image)

        image = ImgTransform.ZeroMeanImage(
            image, config.r_mean, config.g_mean, config.b_mean)
        # HWC -> CHW
        return image.transpose(2, 0, 1), sample_labels
Пример #2
0
 def test_data_transform(self, index):
     """Load test sample ``index`` without any augmentation.

     The image is resized to 512x512, mean-subtracted with the ``config``
     channel means and transposed from HWC to CHW.

     Returns:
         (img, labels): the transformed image and its matching labels.
     """
     raw = self.read_image(self.images_dir, index)
     resized_labels, resized, _ = ImgTransform.ResizeImageWithLabel(
         self.all_labels[index], (512, 512), data=raw)
     centered = ImgTransform.ZeroMeanImage(
         resized, config.r_mean, config.g_mean, config.b_mean)
     return centered.transpose(2, 0, 1), resized_labels
    def train_data_transform(self, index):
        """Load training sample ``index`` and apply random augmentation.

        Rotation and cropping are each applied when a uniform draw exceeds 0.5,
        and only if enabled by ``self.use_rotate`` / ``self.use_crop``. A crop
        is followed by two label-filtering passes, "msi" then "rai". The image
        is then resized with target ``(image_size_train[1],
        image_size_train[0])``, mean-subtracted with ``self.mean`` and
        transposed from HWC to CHW.

        When ``DEBUG`` is set, the labels are drawn onto a copy of the resized
        image and sent to ``vis.image``. Ignored labels use (255, 255, 0) and
        the others use (255, 0, 0).

        Returns:
            (img, labels, img_np): the network-ready image, its labels, and,
            when ``DEBUG`` is set, a copy of the resized image taken before
            mean subtraction (otherwise ``None``).
        """
        image = self.read_image(self.images_dir, index)
        sample_labels = self.all_labels[index]

        # Draw both coin flips before any transform consumes randomness.
        do_rotate = (random.random() if self.use_rotate else 0) > 0.5
        do_crop = (random.random() if self.use_crop else 0) > 0.5

        if do_rotate:
            sample_labels, image, _ = ImgTransform.RotateImageWithLabel(
                sample_labels, data=image)

        if do_crop:
            sample_labels, image, _ = ImgTransform.CropImageWithLabel(
                sample_labels, data=image)
            for filter_method in ("msi", "rai"):
                sample_labels = PixelLinkIC15Dataset.filter_labels(
                    sample_labels, method=filter_method)

        target_size = (self.image_size_train[1], self.image_size_train[0])
        sample_labels, image, _ = ImgTransform.ResizeImageWithLabel(
            sample_labels, target_size, data=image)

        debug_copy = None
        if DEBUG:
            debug_copy = image.copy()
            overlay = image.copy()
            for polygon, ignored in zip(sample_labels['coor'],
                                        sample_labels['ignore']):
                outline = np.array(polygon, np.int32).reshape((-1, 1, 2))
                cv2.polylines(overlay, [outline], True,
                              (255, 255, 0) if ignored else (255, 0, 0))
            vis.image(overlay.transpose(2, 0, 1))

        image = ImgTransform.ZeroMeanImage(
            image, self.mean[0], self.mean[1], self.mean[2])
        # HWC -> CHW
        return image.transpose(2, 0, 1), sample_labels, debug_copy
Пример #4
0
def test_on_train_dataset(vis_per_img=10):
    """Evaluate a saved snapshot on the training images and print metrics.

    Builds the training split with ``train=False`` and loads the weights
    stored at ``config.saving_model_dir + '<config.test_model_index>.mdl'``
    into a new ``net.Net``. It then runs the network image by image. After
    each image it prints the cumulative TP / FP / FN counts (from
    ``comp_gt_and_output`` with a 0.5 threshold), together with precision
    and recall. Every ``vis_per_img``-th image is written to
    ``test_output/img_<i>.jpg``, with detections drawn in (0, 255, 0) and
    ground truth in (255, 0, 0).

    Args:
        vis_per_img: save a visualization for every ``vis_per_img``-th image.
    """
    import os

    dataset = datasets.PixelLinkIC15Dataset(config.train_images_dir,
                                            config.train_labels_dir,
                                            train=False)
    # dataloader = DataLoader(dataset, batch_size=config.batch_size, shuffle=False)
    my_net = net.Net()
    if config.gpu:
        device = torch.device("cuda:0")
        my_net = my_net.cuda()
        if config.multi_gpu:
            my_net = nn.DataParallel(my_net)
    else:
        device = torch.device("cpu")
    # map_location lets a checkpoint saved from GPU tensors be loaded when
    # config.gpu is False; without it torch.load fails on a CPU-only machine.
    my_net.load_state_dict(
        torch.load(config.saving_model_dir +
                   '%d.mdl' % config.test_model_index,
                   map_location=device))
    # Inference only: switch the network to evaluation mode (affects layers
    # such as dropout / batch-norm, if the model has any).
    my_net.eval()

    # cv2.imwrite does not create missing directories; it just returns False.
    if not os.path.exists("test_output"):
        os.makedirs("test_output")

    true_pos, false_pos, false_neg = 0, 0, 0
    for i in range(len(dataset)):
        sample = dataset[i]
        image = sample['image'].to(device)
        image = image.unsqueeze(0)
        my_labels = cal_label_on_batch(my_net, image)[0]
        # res = (TP, FP, FN) for this image at IoU-style threshold 0.5
        res = comp_gt_and_output(my_labels, sample["label"], 0.5)
        if i % vis_per_img == 0:
            image = image.squeeze(0).cpu().numpy()
            image = ImgFormat.ImgOrderFormat(image,
                                             from_order="CHW",
                                             to_order="HWC")
            image = ImgTransform.UnzeroMeanImage(image, config.r_mean,
                                                 config.g_mean, config.b_mean)
            image = ImgFormat.ImgColorFormat(image,
                                             from_color="RGB",
                                             to_color="BGR")
            image = visualize_label(image, my_labels, color=(0, 255, 0))
            image = visualize_label(image,
                                    sample["label"]["coor"],
                                    color=(255, 0, 0))
            cv2.imwrite("test_output/img_%d.jpg" % i, image)
        true_pos += res[0]
        false_pos += res[1]
        false_neg += res[2]
        if (true_pos + false_pos) > 0:
            precision = true_pos / (true_pos + false_pos)
        else:
            precision = 0
        if (true_pos + false_neg) > 0:
            recall = true_pos / (true_pos + false_neg)
        else:
            recall = 0
        print("i: %d, TP: %d, FP: %d, FN: %d, precision: %f, recall: %f" %
              (i, true_pos, false_pos, false_neg, precision, recall))
    def test_data_transform(self, index):
        """Load test sample ``index`` without any augmentation.

        The image is resized with target ``(image_size_test[1],
        image_size_test[0])``, mean-subtracted with ``self.mean`` and
        transposed from HWC to CHW.

        When ``DEBUG`` is set, the labels are drawn onto a copy of the resized
        image and sent to ``vis.image``. Ignored labels use (255, 255, 0) and
        the others use (255, 0, 0).

        Returns:
            (img, labels, img_np): the network-ready image, its labels, and,
            when ``DEBUG`` is set, a copy of the resized image taken before
            mean subtraction (otherwise ``None``).
        """
        image = self.read_image(self.images_dir, index)
        sample_labels = self.all_labels[index]
        target_size = (self.image_size_test[1], self.image_size_test[0])
        sample_labels, image, _ = ImgTransform.ResizeImageWithLabel(
            sample_labels, target_size, data=image)

        debug_copy = None
        if DEBUG:
            debug_copy = image.copy()
            overlay = image.copy()
            for polygon, ignored in zip(sample_labels['coor'],
                                        sample_labels['ignore']):
                outline = np.array(polygon, np.int32).reshape((-1, 1, 2))
                cv2.polylines(overlay, [outline], True,
                              (255, 255, 0) if ignored else (255, 0, 0))
            vis.image(overlay.transpose(2, 0, 1))

        image = ImgTransform.ZeroMeanImage(
            image, self.mean[0], self.mean[1], self.mean[2])
        # HWC -> CHW
        return image.transpose(2, 0, 1), sample_labels, debug_copy
 def read_image(self, dir, index):
     """Read the image for 0-based ``index`` from directory ``dir``.

     Files on disk use 1-based numbering, so index 0 maps to ``img_1.jpg``.
     """
     path = os.path.join(dir, "img_{}.jpg".format(index + 1))
     return ImgTransform.ReadImage(path)
def test(my_net,
         dataset,
         epoch,
         exp_dir,
         results_dir,
         test_file,
         gpu=True,
         multi_gpu=False,
         vis_per_img=10,
         weights_preloaded=False):
    # dataloader = DataLoader(dataset, batch_size=config.batch_size, shuffle=False)
    if gpu:
        device = torch.device("cuda:0")
    else:
        device = torch.device("cpu")

    if not weights_preloaded:
        if gpu:
            my_net = my_net.cuda()
            if multi_gpu:
                my_net = nn.DataParallel(my_net)
        checkpoint = torch.load(
            os.path.join(exp_dir, 'snapshots', 'epoch_%08d.mdl' % epoch))
        my_net.load_state_dict(checkpoint['state_dict'])

    my_net.eval()

    if not os.path.exists(results_dir):
        os.makedirs(results_dir)

    true_pos, true_neg, false_pos, false_neg = [0] * 4
    for i in range(len(dataset)):
        sample = dataset[i]
        image = sample['image'].to(device)
        image = image.unsqueeze(0)
        my_labels = cal_label_on_batch(my_net, image)[0]
        # print("my labels num: %d" % len(my_labels))
        res = comp_gt_and_output(my_labels, sample["label"], 0.5)
        if i % vis_per_img == 0:
            image = image.squeeze(0).cpu().numpy()
            image = ImgFormat.ImgOrderFormat(image,
                                             from_order="CHW",
                                             to_order="HWC")
            image = ImgTransform.UnzeroMeanImage(image, dataset.mean[0],
                                                 dataset.mean[1],
                                                 dataset.mean[2])
            image = ImgFormat.ImgColorFormat(image,
                                             from_color="RGB",
                                             to_color="BGR")
            # color : gt = red, ignore = yellow, detection = blue
            image = visualize_label(image,
                                    sample["label"]["coor"],
                                    color=(0, 0, 255),
                                    ignore=sample["label"]["ignore"])
            image = visualize_label(image,
                                    my_labels,
                                    color=(255, 0, 0),
                                    thickness=2)
            cv2.imwrite("%s/img_%d.jpg" % (results_dir, i), image)
        true_pos += res[0]
        false_pos += res[1]
        false_neg += res[2]
        if (true_pos + false_pos) > 0:
            precision = true_pos / (true_pos + false_pos)
        else:
            precision = 0
        if (true_pos + false_neg) > 0:
            recall = true_pos / (true_pos + false_neg)
        else:
            recall = 0
        F = 0
        if (precision + recall) > 0:
            F = 2 * precision * recall / (precision + recall)
        print(
            "i: %d, TP: %d, FP: %d, FN: %d, precision: %f, recall: %f, F=%f" %
            (i, true_pos, false_pos, false_neg, precision, recall, F))

    perf_str2 = "%d, %d,%d,%d,%f,%f,%f" % (epoch, true_pos, false_pos,
                                           false_neg, precision, recall, F)
    if not os.path.exists(test_file):
        os.system('echo "epoch,TP,FP,FN,precision,recall,F1" > %s' % test_file)
    os.system('echo "%s" >> %s' % (perf_str2, test_file))