def predict_on_image(model, args, data, label_file, img_info):
    """Run ``model`` on one image and write prediction / diff visualisations.

    The input image is resized to ``args.crop_height`` x ``args.crop_width``,
    normalised and fed to the model as a batch of one. The argmax prediction
    is colour-coded with the palette from ``args.csv_path`` and written to
    ``res/pred_<img_info>.png``. The output of ``plot_diff`` (prediction vs.
    ground-truth class map) is written to ``res/diff_<img_info>.png``.

    Args:
        model: segmentation network; it is switched to eval mode here.
        args: namespace providing ``csv_path``, ``crop_height`` and
            ``crop_width``.
        data: path of the input image (read with OpenCV).
        label_file: path of the colour-coded ground-truth label image.
        img_info: identifier embedded in the two output file names.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``data``.
    """
    import torch  # only needed for torch.no_grad below

    # read csv label path
    label_info = get_label_info(args.csv_path)

    # ground truth: colour label -> one-hot -> [C, H, W] -> class-index map
    label = np.array(Image.open(label_file))
    label = one_hot_it_v11_dice(label, label_info).astype(np.uint8)
    label = np.transpose(label, [2, 0, 1]).astype(np.float32)
    label = label.squeeze()
    label = np.argmax(label, axis=0)

    image = cv2.imread(data, -1)
    if image is None:
        # cv2.imread reports failure by returning None instead of raising
        raise FileNotFoundError('could not read image: {}'.format(data))
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    resize = iaa.Scale({'height': args.crop_height, 'width': args.crop_width})
    image = resize.to_deterministic().augment_image(image)
    image = Image.fromarray(image).convert('RGB')
    image = transforms.ToTensor()(image)
    # per-channel mean/std (presumably the usual ImageNet statistics)
    image = transforms.Normalize((0.485, 0.456, 0.406),
                                 (0.229, 0.224, 0.225))(image).unsqueeze(0)

    # predict; no_grad avoids building an autograd graph during inference
    model.eval()
    with torch.no_grad():
        predict = model(image).squeeze()
    predict = reverse_one_hot(predict)

    # colour-coded prediction, resized to a fixed 512x512 output
    predict_ = colour_code_segmentation(np.array(predict), label_info)
    predict_ = cv2.resize(np.uint8(predict_), (512, 512))
    cv2.imwrite('res/pred_' + str(img_info) + '.png',
                cv2.cvtColor(predict_, cv2.COLOR_RGB2BGR))

    diff = plot_diff(np.array(predict), label)
    cv2.imwrite('res/diff_' + str(img_info) + '.png',
                cv2.cvtColor(np.uint8(diff), cv2.COLOR_RGB2BGR))
Example #2
0
    def __getitem__(self, index):
        """Load, augment and encode the sample at ``index``.

        The image and its (RGB) label are resized by a scale factor drawn
        from ``self.scale`` relative to ``self.shape`` and then cropped to
        ``self.shape``. The label is mapped to the CamVid palette, and both go
        through ``augmentation``. Pixel-level augmentation is applied to the
        image on half of the calls.

        Returns:
            ``(img, label)``. ``img`` is a float tensor from
            ``self.to_tensor``. For ``self.loss == 'dice'`` the label is a
            float32 one-hot tensor of shape [num_classes, H, W]. For
            ``'crossentropy'`` it is the int64 tensor built from
            ``one_hot_it_v11``.

        Raises:
            ValueError: if ``self.loss`` is neither ``'dice'`` nor
                ``'crossentropy'``.
        """
        # open image and label
        img = Image.open(self.images[index])
        label = Image.open(self.labels[index]).convert("RGB")

        # resize image and label by a random factor, then crop them
        scale = random.choice(self.scale)
        scale = (int(self.shape[0] * scale), int(self.shape[1] * scale))

        # the same seed is passed to both RandomCrop calls, presumably so the
        # image and label receive the same crop window
        seed = random.random()
        img = transforms.Resize(scale, Image.BILINEAR)(img)
        img = RandomCrop(self.shape, seed, pad_if_needed=True)(img)
        img = np.array(img)

        # NEAREST avoids interpolating new colours into the label map
        label = transforms.Resize(scale, Image.NEAREST)(label)
        label = RandomCrop(self.shape, seed, pad_if_needed=True)(label)
        label = np.array(label)

        # translate to CamVid color palette
        label = self.__toCamVid(label)

        # joint augmentation, then pixel-level augmentation half of the time
        img, label = augmentation(img, label)
        if random.randint(0, 1) == 1:
            img = augmentation_pixel(img)

        img = Image.fromarray(img)
        img = self.to_tensor(img).float()

        if self.loss == 'dice':
            # label -> [num_classes, H, W]
            label = one_hot_it_v11_dice(label,
                                        self.label_info).astype(np.uint8)
            label = np.transpose(label, [2, 0, 1]).astype(np.float32)
            label = torch.from_numpy(label)
            return img, label

        if self.loss == 'crossentropy':
            label = one_hot_it_v11(label, self.label_info).astype(np.uint8)
            label = torch.from_numpy(label).long()
            return img, label

        # fail loudly instead of returning None into the DataLoader
        raise ValueError(
            "unsupported loss {!r}: expected 'dice' or 'crossentropy'".format(
                self.loss))
Example #3
0
    def __getitem__(self, index):
        """Load, augment and encode the sample at ``index``.

        In ``'train'`` and ``'adversarial_train'`` modes, the image and label
        are resized by a scale factor drawn from ``self.scale`` relative to
        ``self.image_size`` and cropped to ``self.image_size``. They then
        receive joint and pixel-level augmentation, each with probability
        0.5. In ``'adversarial_train'`` mode no label file is read; an
        all-zero RGBA placeholder is used instead.

        Returns:
            ``(img, label)``. ``img`` is a float tensor from
            ``self.to_tensor``. For ``self.loss == 'dice'`` the label is a
            float32 tensor with the channel axis first: one-hot
            [num_classes, H, W] normally, or the 4-channel zero placeholder
            in ``'adversarial_train'`` mode. For ``'crossentropy'`` it is the
            int64 tensor built from ``one_hot_it_v11``.

        Raises:
            ValueError: if ``self.loss`` is neither ``'dice'`` nor
                ``'crossentropy'``.
        """
        training = self.mode in ('train', 'adversarial_train')

        # the seed is drawn before the scale choice; keep this order so the
        # sequence of random draws is unchanged
        seed = random.random()
        img = Image.open(self.image_list[index])

        scale = random.choice(self.scale)
        scale = (int(self.image_size[0] * scale),
                 int(self.image_size[1] * scale))

        # randomly resize image and random crop
        if training:
            img = transforms.Resize(scale, Image.BILINEAR)(img)
            img = RandomCrop(self.image_size, seed, pad_if_needed=True)(img)
        img = np.array(img)

        if self.mode != 'adversarial_train':
            label = Image.open(self.label_list[index])
        else:
            # no label file is read in this mode: a 2x2 all-zero RGBA image
            # stands in and is expanded by the resize/crop below
            placeholder = np.zeros((2, 2, 4), dtype=np.uint8)
            label = Image.fromarray(placeholder).convert('RGBA')

        # resize/crop the label with the same scale and seed as the image;
        # NEAREST avoids interpolating new colours into the label map
        if training:
            label = transforms.Resize(scale, Image.NEAREST)(label)
            label = RandomCrop(self.image_size, seed,
                               pad_if_needed=True)(label)
        label = np.array(label)

        if training:
            # joint image/label augmentation with probability 0.5
            if random.random() < 0.5:
                img, label = augmentation(img, label)
            # pixel-level image augmentation with probability 0.5
            if random.random() < 0.5:
                img = augmentation_pixel(img)

        # image -> [C, H, W]
        img = Image.fromarray(img)
        img = self.to_tensor(img).float()

        if self.loss == 'dice':
            # label -> [num_classes, H, W]; the adversarial placeholder is
            # not one-hot encoded, only transposed
            if self.mode != 'adversarial_train':
                label = one_hot_it_v11_dice(label,
                                            self.label_info).astype(np.uint8)
            label = np.transpose(label, [2, 0, 1]).astype(np.float32)
            label = torch.from_numpy(label)
            return img, label

        if self.loss == 'crossentropy':
            # NOTE(review): unlike the dice branch, the adversarial_train
            # placeholder is one-hot encoded here too; confirm that
            # one_hot_it_v11 accepts the 4-channel zero label.
            label = one_hot_it_v11(label, self.label_info).astype(np.uint8)
            label = torch.from_numpy(label).long()
            return img, label

        # fail loudly instead of returning None into the DataLoader
        raise ValueError(
            "unsupported loss {!r}: expected 'dice' or 'crossentropy'".format(
                self.loss))
Example #4
0
    def __getitem__(self, index):
        """Load, augment and encode the sample at ``index``.

        In ``'train'`` mode the image and label are resized by a scale factor
        drawn from ``self.scale`` relative to ``self.image_size``, cropped to
        ``self.image_size`` and passed through ``self.fliplr``. In other
        modes they are used at their original size.

        Returns:
            ``(img, label)``. ``img`` is a float tensor from
            ``self.to_tensor``. For ``self.loss == 'dice'`` the label is a
            float32 one-hot tensor of shape [num_classes, H, W]. For
            ``'crossentropy'`` it is the int64 tensor built from
            ``one_hot_it_v11``.

        Raises:
            ValueError: if ``self.loss`` is neither ``'dice'`` nor
                ``'crossentropy'``.
        """
        training = self.mode == 'train'

        # the seed is drawn before the scale choice; keep this order so the
        # sequence of random draws is unchanged
        seed = random.random()
        img = Image.open(self.image_list[index])

        scale = random.choice(self.scale)
        scale = (int(self.image_size[0] * scale), int(self.image_size[1] * scale))

        # randomly resize image and random crop
        if training:
            img = transforms.Resize(scale, Image.BILINEAR)(img)
            img = RandomCrop(self.image_size, seed, pad_if_needed=True)(img)
        img = np.array(img)

        # load label and apply the same resize/crop (same scale and seed);
        # NEAREST avoids interpolating new colours into the label map
        label = Image.open(self.label_list[index])
        if training:
            label = transforms.Resize(scale, Image.NEAREST)(label)
            label = RandomCrop(self.image_size, seed, pad_if_needed=True)(label)
        label = np.array(label)

        # augment image and label: to_deterministic() makes both
        # augment_image calls apply the same random transform
        if training:
            seq_det = self.fliplr.to_deterministic()
            img = seq_det.augment_image(img)
            label = seq_det.augment_image(label)

        # image -> [C, H, W]
        img = Image.fromarray(img)
        img = self.to_tensor(img).float()

        if self.loss == 'dice':
            # label -> [num_classes, H, W]
            label = one_hot_it_v11_dice(label, self.label_info).astype(np.uint8)
            label = np.transpose(label, [2, 0, 1]).astype(np.float32)
            label = torch.from_numpy(label)
            return img, label

        if self.loss == 'crossentropy':
            label = one_hot_it_v11(label, self.label_info).astype(np.uint8)
            label = torch.from_numpy(label).long()
            return img, label

        # fail loudly instead of returning None into the DataLoader
        raise ValueError(
            "unsupported loss {!r}: expected 'dice' or 'crossentropy'".format(
                self.loss))