def main(args):
    torch.manual_seed(1)

    # crop input image and ground truth and save on disk
    cropped_input_images_path = os.path.join(args.save_cropped, 'input_images')
    cropped_gt_images_path = os.path.join(args.save_cropped, 'gt_images')

    if args.crop_images:
        crop_and_save(args, cropped_input_images_path, cropped_gt_images_path)

    seg_dataset = SegmentationData(cropped_input_images_path, cropped_gt_images_path, args.n_classes, args.phase)
    train_loader = DataLoader(seg_dataset, shuffle=True, num_workers=4, batch_size=args.batch_size)

    model = FCN(args.n_classes)
    use_gpu = torch.cuda.is_available()
    num_gpu = list(range(torch.cuda.device_count()))
    if use_gpu:
        model = model.cuda()
        model = nn.DataParallel(model, device_ids=num_gpu)

    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    criterion = nn.BCEWithLogitsLoss()
    losses = []
    for epoch in range(args.n_epoch):
        for i, (image, segment_im) in enumerate(train_loader):
            image = image.float()
            images = Variable(image.cuda() if use_gpu else image)
            labels = Variable(segment_im.cuda() if use_gpu else segment_im)

            optimizer.zero_grad()

            outputs = model(images)

            loss = criterion(outputs, labels.float())
            loss.backward()
            optimizer.step()

            # record the epoch's first-iteration loss for plotting later;
            # .item() detaches it so the autograd graph is not kept alive
            if i == 0:
                losses.append(loss.item())
            print("epoch {} iteration {} loss: {}".format(epoch, i, loss.item()))

            # save a decoded prediction every 5th epoch for visual inspection
            if epoch % 5 == 0:
                pred = outputs.data.max(1)[1].cpu().numpy()[0]

                decoded = decode_segmap(pred)
                decoded = Image.fromarray(decoded)

                if not os.path.exists(args.output_path):
                    os.makedirs(args.output_path)
                path = os.path.join(args.output_path, 'output_%d_%d.png' % (epoch, i))

                decoded.save(path)

    # plot loss
    plot(losses, args)

    # save model
    if not os.path.exists(args.model_path):
        os.makedirs(args.model_path)
    model_name = os.path.join(args.model_path, 'fcn.pt')
    torch.save(model, model_name)
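Saving the whole module with torch.save(model, ...) pickles the class itself, so loading later requires the identical source layout. A minimal sketch of the more portable state_dict convention, assuming the same FCN class is importable at load time:

# save only the learned parameters; unwrap DataParallel first so the
# keys are not prefixed with "module."
state = model.module.state_dict() if isinstance(model, nn.DataParallel) else model.state_dict()
torch.save(state, model_name)

# later: rebuild the architecture, then restore the weights
restored = FCN(args.n_classes)
restored.load_state_dict(torch.load(model_name))
restored.eval()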
Example #2
def processImage(infile, args):
    n_classes = 12
    model = LinkNet(n_classes)
    model.load_state_dict(torch.load(args.model_path))
    if torch.cuda.is_available():
        model = model.cuda(0)
    model.eval()
    gif = cv2.VideoCapture(infile)
    cv2.namedWindow('camvid')
    while gif.isOpened():
        ret, frame = gif.read()
        if not ret:
            break
        frame = cv2.resize(frame, (768, 576))
        images = get_tensor(frame)
        if torch.cuda.is_available():
            images = Variable(images.cuda(0))
        else:
            images = Variable(images)
        outputs = model(images)
        pred = outputs.data.max(1)[1].cpu().numpy().reshape(576, 768)
        pred = decode_segmap(pred)
        vis = np.zeros((576, 1536, 3), np.uint8)
        vis[:576, :768, :3] = frame
        vis[:576, 768:1536, :3] = pred
        cv2.imshow('camvid', vis)
        cv2.waitKey(10)
    gif.release()
    cv2.destroyAllWindows()
    def save_mask(self):
        DATA_DIR = '/home/apex/chendixi/Experiment/data/CityScapes'
        dataset = CityscapesDataset(DATA_DIR, split='train')
        dataloader = DataLoader(dataset, batch_size=2, shuffle=False)
        it = iter(dataloader)
        images, masks = next(it)
        # the masks have not been divided by 255

        mask = masks[1].numpy()

        tmp = np.array(mask).astype(np.uint8)
        segmap = decode_segmap(tmp, dataset='cityscapes')
        segmap = np.array(segmap * 255).astype(np.uint8)
        im = Image.fromarray(segmap)
        im.save("label.png")
def main(args):
    # The input is a URL
    if args.in_path.startswith("http"):
        dl_request = requests.get(args.in_path, stream=True)
        dl_request.raise_for_status()
        input_image = dl_request.content
    # The input is a local file
    elif os.path.exists(args.in_path):
        input_image = open(args.in_path, "rb").read()
    else:
        raise Exception("No such file or URL: %s" % args.in_path)

    # Compose a JSON Predict request (send JPEG image in base64).
    # According to https://www.tensorflow.org/tfx/serving/api_rest
    # JSON uses UTF-8 encoding. If you have input feature or tensor values that need to be binary
    # (like image bytes), you must Base64 encode the data and encapsulate it in a JSON object
    # having b64 as the key as follows:
    # { "b64": <base64 encoded string> }
    jpeg_bytes = base64.b64encode(input_image).decode("utf-8")
    predict_request = '{"instances" : [{"b64": "%s"}]}' % jpeg_bytes

    SERVER_URL = "http://{}:{}/v1/models/{}:predict".format(
        args.host, args.port, args.model)

    # # Send a few requests to warm up the model.
    # for _ in range(3):
    #     response = requests.post(SERVER_URL, data=predict_request)
    #     response.raise_for_status()
    response = requests.post(SERVER_URL, data=predict_request)
    response.raise_for_status()

    # Parse the JSON response
    response = json.loads(response.text)

    # Interpret bitstring output
    response_string = response["predictions"][0]["b64"]

    # Decode bitstring
    encoded_response_string = response_string.encode("utf-8")
    image_bitstring = base64.b64decode(encoded_response_string)

    img_np = np.frombuffer(image_bitstring, dtype=np.uint8)
    img_np = cv.imdecode(img_np, flags=1)[:, :, 0]
    img_bgr = utils.decode_segmap(img_np)
    cv.imwrite(args.out_path, img_bgr)
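Building the request body with string interpolation works here because base64 output is JSON-safe, but json.dumps is the more robust way to construct it. A one-line equivalent of the predict_request above:

predict_request = json.dumps({"instances": [{"b64": jpeg_bytes}]})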
    def save_img8bit_mask(self):
        DATA_DIR = '/home/apex/chendixi/Experiment/data/CityScapes'
        dataset = CityscapesDataset(DATA_DIR, split='train')
        dataloader = DataLoader(dataset, batch_size=2, shuffle=False)
        it = iter(dataloader)
        images, masks = next(it)
        # the masks have not been divided by 255
        # ImageNet statistics used to undo the dataset normalization
        mean = torch.tensor((0.485, 0.456, 0.406), dtype=torch.float32)
        std = torch.tensor((0.229, 0.224, 0.225), dtype=torch.float32)
        for i in range(images.size()[0]):
            # de-normalize: x * std + mean
            image = images[i].mul(std[:, None, None]).add(mean[:, None, None])
            vutils.save_image(image, "image" + str(i) + ".png")
            mask = masks[i].numpy()

            tmp = np.array(mask).astype(np.uint8)
            segmap = decode_segmap(tmp, dataset='cityscapes')
            segmap = torch.from_numpy(segmap).float().permute(2, 0, 1)
            vutils.save_image(segmap, "label" + str(i) + ".png")
def main(args):
    # crop input image and save on disk
    cropped_input_images_path = os.path.join(args.save_cropped,
                                             'input_images_test')
    crop_and_save(args, cropped_input_images_path)

    seg_dataset = SegmentationData(cropped_input_images_path, phase=args.phase)
    test_loader = DataLoader(seg_dataset, shuffle=False)

    # load model and switch to inference mode
    model = torch.load(args.model_path)
    model.eval()

    # create temp folder for saving prediction for each cropped input images
    temp_name = 'temp'
    if not os.path.exists(os.path.join(args.output_path, temp_name)):
        os.makedirs(os.path.join(args.output_path, temp_name))

    for i, (image, im_name) in enumerate(test_loader):
        image = image.float()

        if torch.cuda.is_available():
            image = image.cuda()
        image = Variable(image)

        # predict
        with torch.no_grad():
            output = model(image)
        pred = np.squeeze(output.data.max(1)[1].cpu().numpy(), axis=0)
        decoded_im = decode_segmap(pred)

        # save image
        output_name = os.path.basename(im_name[0])

        output_name = os.path.join(args.output_path, temp_name, output_name)
        decoded_im = Image.fromarray(decoded_im)
        decoded_im.save(output_name)

    stitch_predicted(args)
    shutil.rmtree(os.path.join(args.output_path, temp_name))
Example #7
    composed_transforms_tr = transforms.Compose([tr.RandomHorizontalFlip(),
                                                 tr.RandomCrop(512),
                                                 tr.RandomRotate(15),
                                                 tr.ToTensor()]
                                                )

    isprs_train = ISPRSSegmentation(split='train', transform=composed_transforms_tr)

    dataloader = DataLoader(isprs_train, batch_size=5, shuffle=True, num_workers=2)

    for ii, sample in enumerate(dataloader):
        for jj in range(sample["image"].size()[0]):
            img = sample['image'].numpy()
            gt = sample['label'].numpy()
            tmp = np.array(gt[jj]).astype(np.uint8)
            tmp = np.squeeze(tmp, axis=0)
            segmap = decode_segmap(tmp, dataset='ISPRS')
            img_tmp = np.transpose(img[jj], axes=[1, 2, 0]).astype(np.uint8)
            plt.figure()
            plt.title('display')
            plt.subplot(211)
            plt.imshow(img_tmp)
            plt.subplot(212)
            plt.imshow(segmap)

        if ii == 1:
            break
    plt.show(block=True)


import os

import selectivesearch
from scipy import misc  # scipy.misc.imread was removed in SciPy 1.2; imageio.imread is a drop-in replacement

import utils

if __name__ == '__main__':
    MovingObjSavePath = '/tmp/trainannot_moving/'
    # bail out if the annotation folder is missing
    if not os.path.isdir(MovingObjSavePath):
        exit(0)
    CamVidRootTrainAnnotList = os.listdir(MovingObjSavePath)
    for CamVidRootTrainAnnotItem in CamVidRootTrainAnnotList:
        print(CamVidRootTrainAnnotItem)
        CamVidRootTrainAnnotItemFullPath = os.path.join(MovingObjSavePath, CamVidRootTrainAnnotItem)
        img = misc.imread(CamVidRootTrainAnnotItemFullPath)
        height, width = img.shape
        img = utils.decode_segmap(img)
        img_lbl, regions = selectivesearch.selective_search(img, scale=500, sigma=0.9, min_size=10)

        candidates = set()
        for r in regions:
            # skip duplicate rectangles
            if r['rect'] in candidates:
                continue
            # skip regions smaller than 2000 pixels
            if r['size'] < 2000:
                continue
            # skip degenerate or strongly elongated boxes
            x, y, w, h = r['rect']
            if h == 0 or w == 0:
                continue
            if w / h > 1.2 or h / w > 1.2:
                continue
            candidates.add(r['rect'])

    voc_train = VOCSegmentation(split='train',
                                transform=composed_transforms_tr)

    dataloader = DataLoader(voc_train,
                            batch_size=5,
                            shuffle=True,
                            num_workers=2)

    for ii, sample in enumerate(dataloader):
        for jj in range(sample["image"].size()[0]):
            img = sample['image'].numpy()
            gt = sample['label'].numpy()
            tmp = np.array(gt[jj]).astype(np.uint8)
            tmp = np.squeeze(tmp, axis=0)
            segmap = decode_segmap(tmp, dataset='pascal')
            img_tmp = np.transpose(img[jj], axes=[1, 2, 0]).astype(np.uint8)
            plt.figure()
            plt.title('display')
            plt.subplot(211)
            plt.imshow(img_tmp)
            plt.subplot(212)
            plt.imshow(segmap)

        if ii == 1:
            break
    plt.show(block=True)
Example #10
    def sample_syn(self, x_a, x_b, m_a, m_b):
        """ 
        Infer the model on a batch of image
        
        Arguments:
            x_a {torch.Tensor} -- batch of image from domain A
            x_b {[type]} -- batch of image from domain B
        
        Returns:
            A list of torch images -- columnwise :x_a, autoencode(x_a), x_ab_1, x_ab_2
            Or if self.semantic_w is true: x_a, autoencode(x_a), Semantic segmentation x_a, 
            x_ab_1,semantic segmentation x_ab_1, x_ab_2
        """
        self.eval()

        x_a_recon, x_b_recon, x_ba1, x_ba2, x_ab1, x_ab2 = [], [], [], [], [], []
        x_a_augment = torch.cat([x_a, m_a], dim=1)
        x_b_augment = torch.cat([x_b, m_b], dim=1)

        for i in range(x_a.size(0)):
            c_a = self.gen.encode(x_a[i].unsqueeze(0), 1)
            c_b = self.gen.encode(x_b[i].unsqueeze(0), 2)
            x_a_recon.append(self.gen.decode(c_a, m_a[i].unsqueeze(0), 1))
            x_b_recon.append(self.gen.decode(c_b, m_b[i].unsqueeze(0), 2))

            x_ba1.append(self.gen.decode(c_b, m_b[i].unsqueeze(0), 1))
            x_ba2.append(self.gen.decode(c_b, m_b[i].unsqueeze(0), 1))
            x_ab1.append(self.gen.decode(c_a, m_a[i].unsqueeze(0), 2))
            x_ab2.append(self.gen.decode(c_a, m_a[i].unsqueeze(0), 2))

        x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
        x_ba1, x_ba2 = torch.cat(x_ba1), torch.cat(x_ba2)
        x_ab1, x_ab2 = torch.cat(x_ab1), torch.cat(x_ab2)

        if self.semantic_w:
            rgb_a_list, rgb_b_list, rgb_ab_list, rgb_ba_list = [], [], [], []

            for i in range(x_a.size(0)):

                # Inference semantic segmentation on original images
                im_a = (x_a[i].squeeze() + 1) / 2.0
                im_b = (x_b[i].squeeze() + 1) / 2.0

                input_transformed_a = seg_transform()(im_a).unsqueeze(0)
                input_transformed_b = seg_transform()(im_b).unsqueeze(0)
                output_a = self.segmentation_model(
                    input_transformed_a).squeeze().max(0)[1]
                output_b = self.segmentation_model(
                    input_transformed_b).squeeze().max(0)[1]

                rgb_a = decode_segmap(output_a.cpu().numpy())
                rgb_b = decode_segmap(output_b.cpu().numpy())
                rgb_a = Image.fromarray(rgb_a).resize(
                    (x_a.size(3), x_a.size(3)))
                rgb_b = Image.fromarray(rgb_b).resize(
                    (x_a.size(3), x_a.size(3)))

                rgb_a_list.append(transforms.ToTensor()(rgb_a).unsqueeze(0))
                rgb_b_list.append(transforms.ToTensor()(rgb_b).unsqueeze(0))

                # Inference semantic segmentation on fake images
                image_ab = (x_ab1[i].squeeze() + 1) / 2.0
                image_ba = (x_ba1[i].squeeze() + 1) / 2.0

                input_transformed_ab = seg_transform()(image_ab).unsqueeze(
                    0).to("cuda")
                input_transformed_ba = seg_transform()(image_ba).unsqueeze(
                    0).to("cuda")

                output_ab = self.segmentation_model(
                    input_transformed_ab).squeeze().max(0)[1]
                output_ba = self.segmentation_model(
                    input_transformed_ba).squeeze().max(0)[1]

                rgb_ab = decode_segmap(output_ab.cpu().numpy())
                rgb_ba = decode_segmap(output_ba.cpu().numpy())

                rgb_ab = Image.fromarray(rgb_ab).resize(
                    (x_a.size(3), x_a.size(3)))
                rgb_ba = Image.fromarray(rgb_ba).resize(
                    (x_a.size(3), x_a.size(3)))

                rgb_ab_list.append(transforms.ToTensor()(rgb_ab).unsqueeze(0))
                rgb_ba_list.append(transforms.ToTensor()(rgb_ba).unsqueeze(0))

            rgb1_a, rgb1_b, rgb1_ab, rgb1_ba = (
                torch.cat(rgb_a_list).cuda(),
                torch.cat(rgb_b_list).cuda(),
                torch.cat(rgb_ab_list).cuda(),
                torch.cat(rgb_ba_list).cuda(),
            )

        self.train()
        if self.semantic_w:
            self.segmentation_model.eval()
            return (
                x_a,
                x_a_recon,
                rgb1_a,
                x_ab1,
                rgb1_ab,
                x_ab2,
                x_b,
                x_b_recon,
                rgb1_b,
                x_ba1,
                rgb1_ba,
                x_ba2,
            )
        else:
            return x_a, x_a_recon, x_ab1, x_ab2, x_b, x_b_recon, x_ba1, x_ba2
Example #11
    def sample(self, x_a, x_b):
        """ 
        Infer the model on a batch of image
        
        Arguments:
            x_a {torch.Tensor} -- batch of image from domain A
            x_b {[type]} -- batch of image from domain B
        
        Returns:
            A list of torch images -- columnwise :x_a, autoencode(x_a), x_ab_1, x_ab_2
            Or if self.semantic_w is true: x_a, autoencode(x_a), Semantic segmentation x_a, 
            x_ab_1,semantic segmentation x_ab_1, x_ab_2
        """
        self.eval()
        s_a1 = Variable(self.s_a)
        s_b1 = Variable(self.s_b)
        s_a2 = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b2 = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        x_a_recon, x_b_recon, x_ba1, x_ba2, x_ab1, x_ab2 = [], [], [], [], [], []

        if self.gen_state == 0:
            for i in range(x_a.size(0)):
                c_a, s_a_fake = self.gen_a.encode(x_a[i].unsqueeze(0))
                c_b, s_b_fake = self.gen_b.encode(x_b[i].unsqueeze(0))
                x_a_recon.append(self.gen_a.decode(c_a, s_a_fake))
                x_b_recon.append(self.gen_b.decode(c_b, s_b_fake))
                if self.guided == 0:
                    x_ba1.append(self.gen_a.decode(c_b, s_a1[i].unsqueeze(0)))
                    x_ba2.append(self.gen_a.decode(c_b, s_a2[i].unsqueeze(0)))
                    x_ab1.append(self.gen_b.decode(c_a, s_b1[i].unsqueeze(0)))
                    x_ab2.append(self.gen_b.decode(c_a, s_b2[i].unsqueeze(0)))
                elif self.guided == 1:
                    x_ba1.append(self.gen_a.decode(
                        c_b, s_a_fake))  # s_a1[i].unsqueeze(0)))
                    x_ba2.append(self.gen_a.decode(
                        c_b, s_a_fake))  # s_a2[i].unsqueeze(0)))
                    x_ab1.append(self.gen_b.decode(
                        c_a, s_b_fake))  # s_b1[i].unsqueeze(0)))
                    x_ab2.append(self.gen_b.decode(
                        c_a, s_b_fake))  # s_b2[i].unsqueeze(0)))
                else:
                    print("self.guided unknown value:", self.guided)

        elif self.gen_state == 1:
            for i in range(x_a.size(0)):
                c_a, s_a_fake = self.gen.encode(x_a[i].unsqueeze(0), 1)
                c_b, s_b_fake = self.gen.encode(x_b[i].unsqueeze(0), 2)
                x_a_recon.append(self.gen.decode(c_a, s_a_fake, 1))
                x_b_recon.append(self.gen.decode(c_b, s_b_fake, 2))
                if self.guided == 0:
                    x_ba1.append(self.gen.decode(c_b, s_a1[i].unsqueeze(0), 1))
                    x_ba2.append(self.gen.decode(c_b, s_a2[i].unsqueeze(0), 1))
                    x_ab1.append(self.gen.decode(c_a, s_b1[i].unsqueeze(0), 2))
                    x_ab2.append(self.gen.decode(c_a, s_b2[i].unsqueeze(0), 2))
                elif self.guided == 1:
                    x_ba1.append(self.gen.decode(c_b, s_a_fake,
                                                 1))  # s_a1[i].unsqueeze(0)))
                    x_ba2.append(self.gen.decode(c_b, s_a_fake,
                                                 1))  # s_a2[i].unsqueeze(0)))
                    x_ab1.append(self.gen.decode(c_a, s_b_fake,
                                                 2))  # s_b1[i].unsqueeze(0)))
                    x_ab2.append(self.gen.decode(c_a, s_b_fake,
                                                 2))  # s_b2[i].unsqueeze(0)))
                else:
                    print("self.guided unknown value:", self.guided)

        else:
            print("self.gen_state unknown value:", self.gen_state)

        x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
        x_ba1, x_ba2 = torch.cat(x_ba1), torch.cat(x_ba2)
        x_ab1, x_ab2 = torch.cat(x_ab1), torch.cat(x_ab2)

        if self.semantic_w:
            rgb_a_list, rgb_b_list, rgb_ab_list, rgb_ba_list = [], [], [], []

            for i in range(x_a.size(0)):

                # Inference semantic segmentation on original images
                im_a = (x_a[i].squeeze() + 1) / 2.0
                im_b = (x_b[i].squeeze() + 1) / 2.0

                input_transformed_a = seg_transform()(im_a).unsqueeze(0)
                input_transformed_b = seg_transform()(im_b).unsqueeze(0)
                output_a = (self.segmentation_model(
                    input_transformed_a).squeeze().max(0)[1])
                output_b = (self.segmentation_model(
                    input_transformed_b).squeeze().max(0)[1])

                rgb_a = decode_segmap(output_a.cpu().numpy())
                rgb_b = decode_segmap(output_b.cpu().numpy())
                rgb_a = Image.fromarray(rgb_a).resize(
                    (x_a.size(3), x_a.size(3)))
                rgb_b = Image.fromarray(rgb_b).resize(
                    (x_a.size(3), x_a.size(3)))

                rgb_a_list.append(transforms.ToTensor()(rgb_a).unsqueeze(0))
                rgb_b_list.append(transforms.ToTensor()(rgb_b).unsqueeze(0))

                # Inference semantic segmentation on fake images
                image_ab = (x_ab1[i].squeeze() + 1) / 2.0
                image_ba = (x_ba1[i].squeeze() + 1) / 2.0

                input_transformed_ab = seg_transform()(image_ab).unsqueeze(
                    0).to("cuda")
                input_transformed_ba = seg_transform()(image_ba).unsqueeze(
                    0).to("cuda")

                output_ab = (self.segmentation_model(
                    input_transformed_ab).squeeze().max(0)[1])
                output_ba = (self.segmentation_model(
                    input_transformed_ba).squeeze().max(0)[1])

                rgb_ab = decode_segmap(output_ab.cpu().numpy())
                rgb_ba = decode_segmap(output_ba.cpu().numpy())

                rgb_ab = Image.fromarray(rgb_ab).resize(
                    (x_a.size(3), x_a.size(3)))
                rgb_ba = Image.fromarray(rgb_ba).resize(
                    (x_a.size(3), x_a.size(3)))

                rgb_ab_list.append(transforms.ToTensor()(rgb_ab).unsqueeze(0))
                rgb_ba_list.append(transforms.ToTensor()(rgb_ba).unsqueeze(0))

            rgb1_a, rgb1_b, rgb1_ab, rgb1_ba = (
                torch.cat(rgb_a_list).cuda(),
                torch.cat(rgb_b_list).cuda(),
                torch.cat(rgb_ab_list).cuda(),
                torch.cat(rgb_ba_list).cuda(),
            )

        self.train()
        if self.semantic_w:
            self.segmentation_model.eval()
            return (
                x_a,
                x_a_recon,
                rgb1_a,
                x_ab1,
                rgb1_ab,
                x_ab2,
                x_b,
                x_b_recon,
                rgb1_b,
                x_ba1,
                rgb1_ba,
                x_ba2,
            )
        else:
            return x_a, x_a_recon, x_ab1, x_ab2, x_b, x_b_recon, x_ba1, x_ba2
    while True:
        ret, frame = cam1.read()
        if not ret:
            print("failed to grab frame")
            break
        cv2.imshow("test", frame)

        tmg_ = cv2.resize(frame, (512, 512), interpolation=cv2.INTER_NEAREST)
        tmg = torch.tensor(tmg_).unsqueeze(0).float()
        tmg = tmg.transpose(2, 3).transpose(1, 2).to(device)

        with torch.no_grad():
            out1 = enet(tmg.float()).squeeze(0)

        # collapse per-class scores to a single label map before decoding
        out2 = out1.argmax(dim=0).cpu().numpy()
        segmented = decode_segmap(out2)

        final = cv2.vconcat([frame, segmented])
        cv2.imshow("Comparison", final)

        k = cv2.waitKey(10)
        if k % 256 == 27:
            # ESC pressed
            print("Escape hit, closing...")
            break
        elif k % 256 == 32:
            # SPACE pressed
            img_name = "opencv_frame_{}.png".format(img_counter)
            cv2.imwrite(img_name, frame)
            print("{} written!".format(img_name))
            img_counter += 1
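The loop above never releases the capture device or the display windows; a minimal cleanup sketch, assuming cam1 is the cv2.VideoCapture opened before the loop:

cam1.release()
cv2.destroyAllWindows()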
Example #13
from utils import decode_segmap
import numpy as np
from PIL import Image

mask = Image.open('mask.png')
mask_np = np.asarray(mask, dtype=np.uint8)
mask_color = decode_segmap(mask_np, 'pascal') * 255
mask_color = mask_color.astype(np.uint8)
mask_c_img = Image.fromarray(mask_color, mode='RGB')
mask_c_img.save('color.png')
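As used throughout these examples, decode_segmap maps each class index in an (H, W) label mask to an RGB triple from a fixed per-dataset palette. A minimal sketch of that idea, with a hypothetical two-color palette rather than the real Pascal colormap:

import numpy as np

# hypothetical palette: class 0 -> black, class 1 -> dark red
PALETTE = np.array([[0, 0, 0], [128, 0, 0]], dtype=np.uint8)

def decode_segmap_sketch(label_mask):
    # fancy indexing looks up one RGB triple per pixel: (H, W) -> (H, W, 3)
    return PALETTE[label_mask]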
        image, mask = safe_crop(image, mask, size=512)
        image = transforms.ToPILImage()(image.copy().astype(np.uint8))
        image = self.transformer(image)

        mask = torch.from_numpy(mask)

        return image, mask

    def __len__(self):
        return len(self.images)
        # return len(self.names)


if __name__ == "__main__":
    dltrain = DLDataset('val', "./data/pascal_voc_seg/tfrecord/")
    # dltrain = DLDataset('trainval', "./data/pascal_voc_seg/VOCdevkit/VOC2012/")
    dataloader = DataLoader(dltrain, batch_size=1, num_workers=8, shuffle=True)
    for image, mask in dataloader:
        image = image.numpy()
        image = image[0]
        # print(image.shape)
        image = np.transpose(image, (1, 2, 0))
        plt.imshow(image)
        plt.show()

        mask = mask.numpy()
        mask = mask[0]
        print(type(mask))
        segmap = decode_segmap(mask, 'pascal', plot=True)