Beispiel #1
0
def transform_vols(img1,img2,cls1,cls2,net):
    """Warp ``img1`` and ``cls1`` slice by slice using displacements from ``net``.

    For every index ``idx`` along the first axis, the slices ``img1[idx]`` and
    ``img2[idx]`` are stacked along a new leading axis and passed, with
    ``net``, to ``generate_grid`` to obtain a displacement ``disp``. ``disp``
    is added to base sampling grids from ``basegrid``, and the results drive
    ``F.grid_sample`` on:

    * the ``img1`` slice, and
    * the one-hot encoding of the ``cls1`` slice.

    The warped one-hot map is collapsed with ``torch.max`` over dim 0 and
    re-encoded with ``OneHotEncode``.

    Args:
        img1: volume to warp. Its first three dimensions are read as
            ``D, H, W`` and it is indexed per slice as ``img1[idx]``.
        img2: volume whose slices are paired with those of ``img1`` as input
            to ``generate_grid``.
        cls1: label volume whose slices are one-hot encoded and warped.
        cls2: accepted for interface compatibility; not used.
        net: model handed to ``generate_grid``.

    Returns:
        ``(out_img, out_cls)``:

        * ``out_img`` is a float tensor with ``img1``'s shape.
        * ``out_cls`` is a byte tensor shaped ``(nclasses,) + cls1.shape``.
    """
    D, H, W = img1.shape[0], img1.shape[1], img1.shape[2]
    out_img = torch.zeros(img1.shape).float()
    out_cls = torch.zeros((nclasses,) + cls1.shape).float()
    # Base sampling grids that ``disp`` is added to; presumably the identity
    # grid for grid_sample at this size -- confirm against ``basegrid``.
    basegrid_img = basegrid(torch.Size((1, 1, H, W)))
    basegrid_cls = basegrid(torch.Size((1, nclasses, H, W)))
    # Loop-invariant: query CUDA availability once, not per slice.
    use_cuda = torch.cuda.is_available()
    for idx in range(D):
        imgslice1 = img1[idx]
        imgslice2 = img2[idx]

        clsslice1_oh = OneHotEncode()(cls1[idx].unsqueeze(0)).float()
        if use_cuda:
            clsslice1_oh = clsslice1_oh.cuda()
        # NOTE(review): imgslice1 is not moved to the GPU while the label slice
        # is; confirm grid_img/disp end up on a device compatible with it.

        # Stack the slice pair on a new leading axis, then add a batch axis.
        combslice = torch.cat((imgslice1.unsqueeze(0), imgslice2.unsqueeze(0)), dim=0)
        disp = generate_grid(net, combslice.unsqueeze(0))

        grid_cls = basegrid_cls + disp
        grid_img = basegrid_img + disp

        imgslice1_t = F.grid_sample(imgslice1.unsqueeze(0).unsqueeze(0), grid_img)[0, 0]  # HxW
        # Presumably one channel per class (original note said 4xHxW).
        clsslice1_oh_t = F.grid_sample(clsslice1_oh.unsqueeze(0), grid_cls)[0]
        # Interpolation blurs the one-hot map; take the per-pixel argmax over
        # channels to get hard labels again, then re-encode them.
        _, clsslice1_oh_t = torch.max(clsslice1_oh_t, dim=0)
        clsslice1_oh_t = OneHotEncode()(clsslice1_oh_t.data.unsqueeze(0))
        out_img[idx] = imgslice1_t.data
        out_cls[:, idx] = clsslice1_oh_t

    return out_img, out_cls.byte()
Beispiel #2
0
    def __getitem__(self, index):
        """Return two co-transformed ``(image, label, name)`` samples for a pair.

        ``self.pairlist[index]`` yields two file stems. For each stem:

        * the image is read from ``<datadir>/img/<stem>.tif``;
        * the label is read from ``<datadir>/cls/<stem>.png`` and converted
          to mode 'P';
        * both are cropped to the module-level ``crop_size``.

        Each image/label pair then goes through ``self.co_transform``,
        followed by ``self.img_transform`` and ``self.label_transform``.

        Returns:
            ``((img1, label1, fname1), (img2, label2, fname2), ohlabel1)``,
            where ``ohlabel1`` is ``OneHotEncode()`` applied to the first
            label only.
        """
        first, second = self.pairlist[index]

        def read(subdir, stem, ext, mode=None):
            # The file handle closes when the ``with`` exits, so all PIL work
            # (optional convert, then crop) happens while it is still open.
            path = os.path.join(self.datadir, subdir, stem + ext)
            with open(path, 'rb') as handle:
                picture = Image.open(handle)
                if mode is not None:
                    picture = picture.convert(mode)
                return picture.crop(crop_size)

        # Same read order as before: both images, then both labels.
        img_a = read("img", first, ".tif")
        img_b = read("img", second, ".tif")
        lbl_a = read("cls", first, ".png", 'P')
        lbl_b = read("cls", second, ".png", 'P')

        # Transform calls stay in a fixed order so any random augmentation
        # consumes the RNG identically.
        img_a, lbl_a = self.co_transform((img_a, lbl_a))
        img_b, lbl_b = self.co_transform((img_b, lbl_b))
        img_a, img_b = self.img_transform(img_a), self.img_transform(img_b)
        lbl_a, lbl_b = self.label_transform(lbl_a), self.label_transform(lbl_b)

        return ((img_a, lbl_a, first), (img_b, lbl_b, second), OneHotEncode()(lbl_a))
Beispiel #3
0
    def __getitem__(self, idx):
        """Load one sample and binarize its segmentation mask.

        ``self.imgs[idx]`` unpacks to ``(new_id, image_path, seg_path)``. The
        mask is mapped through a 256-entry lookup table into mode '1':

        * input value 0 maps to 0;
        * every other value maps to 1.

        Image and mask then go through ``self.co_transform``, followed by
        ``self.img_transform`` and ``self.label_transform``.

        Returns:
            ``(img, label, ohlabel, new_id)``, where ``ohlabel`` is
            ``OneHotEncode()`` applied to the transformed label.
        """
        sample_id, image_path, seg_path = self.imgs[idx]
        img = load_image(image_path)
        mask = load_image(seg_path)

        background = 0
        # Lookup table: the background value -> 0, anything else -> 1.
        lut = [int(level != background) for level in range(256)]
        mask = mask.point(lut, '1')

        img, mask = self.co_transform((img, mask))
        img = self.img_transform(img)
        mask = self.label_transform(mask)
        one_hot = OneHotEncode()(mask)

        return img, mask, one_hot, sample_id
Beispiel #4
0
    def __getitem__(self, index):
        """Return ``(image, label, ohlabel)`` for entry ``index`` of ``self.img_list``.

        The image is read from ``<images_root>/<name>.jpg`` and converted to
        RGB. The label is read from ``<labels_root>/<name>.png`` and converted
        to mode 'P'. Both go through ``self.co_transform`` and then their own
        transform. ``ohlabel`` is ``OneHotEncode()`` applied to the
        transformed label.
        """
        stem = self.img_list[index]
        image_path = os.path.join(self.images_root, stem + '.jpg')
        label_path = os.path.join(self.labels_root, stem + '.png')

        # Conversion runs inside each ``with`` while the file is still open.
        with open(image_path, 'rb') as fh:
            image = load_image(fh).convert('RGB')
        with open(label_path, 'rb') as fh:
            label = load_image(fh).convert('P')

        image, label = self.co_transform((image, label))
        image, label = self.img_transform(image), self.label_transform(label)
        return image, label, OneHotEncode()(label)
Beispiel #5
0
def load_data(start, end, data_dir, img_transform, label_transform,
              co_transform):
    """Load the numbered samples ``start`` .. ``end`` (inclusive) from ``data_dir``.

    Sample ``k`` is read from ``<data_dir>/img/<k>.tif`` and
    ``<data_dir>/cls/<k>.png``; the label is converted to mode 'P'. Each pair
    goes through ``co_transform``, then ``img_transform`` and
    ``label_transform``. The transformed label is one-hot encoded with
    ``OneHotEncode()``.

    Returns:
        Three parallel lists ``(images, labels, one_hot_labels)`` in index
        order. All three are empty when ``start > end``.
    """
    img_dir = os.path.join(data_dir, "img")
    cls_dir = os.path.join(data_dir, "cls")

    def load_one(k):
        # Read, co-transform, transform and encode a single numbered sample.
        stem = str(k)
        image = Image.open(os.path.join(img_dir, stem + '.tif'))
        label = Image.open(os.path.join(cls_dir, stem + '.png')).convert('P')
        image, label = co_transform((image, label))
        image = img_transform(image)
        label = label_transform(label)
        return image, label, OneHotEncode()(label)

    samples = [load_one(k) for k in range(start, end + 1)]
    images = [sample[0] for sample in samples]
    labels = [sample[1] for sample in samples]
    one_hot_labels = [sample[2] for sample in samples]
    return images, labels, one_hot_labels
Beispiel #6
0
def load_data(filename,datadir,co_transform,img_transform,label_transform):
    """Load the samples listed in ``filename`` into dicts keyed by sample name.

    ``filename`` is a text file with one sample name per line. Surrounding
    whitespace is stripped and blank lines are skipped. For each name:

    * the image is ``<datadir>/img/<name>.tif``;
    * the label is ``<datadir>/cls/<name>.png``, converted to mode 'P'.

    Both are passed through ``co_transform``, then ``img_transform`` and
    ``label_transform``. The transformed label is one-hot encoded with
    ``OneHotEncode()``.

    Returns:
        ``(img_dict, label_dict, oh_label_dict)``, mapping each name to its
        transformed image, transformed label and one-hot label respectively.

    Raises:
        FileNotFoundError: if the list file, or any listed image or label
            file, does not exist.
    """
    # Use ``with`` so the list file is closed instead of leaking the handle.
    with open(filename, 'r') as name_file:
        names = [line.strip() for line in name_file]
    # Blank (or whitespace-only) lines do not name a sample.
    names = [name for name in names if name]

    img_dict = {}
    label_dict = {}
    oh_label_dict = {}
    for name in names:
        img_path = os.path.join(datadir, "img", name + ".tif")
        label_path = os.path.join(datadir, "cls", name + ".png")
        # Explicit checks instead of ``assert`` so validation survives ``python -O``.
        for path in (img_path, label_path):
            if not os.path.exists(path):
                raise FileNotFoundError(f"missing data file: {path}")

        img = Image.open(img_path)
        label = Image.open(label_path).convert('P')
        img, label = co_transform((img, label))
        img = img_transform(img)
        label = label_transform(label)
        ohlabel = OneHotEncode()(label)

        img_dict[name] = img
        label_dict[name] = label
        # Store the one-hot encoding here, not the plain label.
        oh_label_dict[name] = ohlabel
    return img_dict, label_dict, oh_label_dict