def transform_vols(img1, img2, cls1, cls2, net):
    """Warp volume 1 slice by slice with displacements predicted from (img1, img2).

    For every index along the first axis, the matching slices of ``img1`` and
    ``img2`` are stacked and handed to ``generate_grid(net, ...)``. The
    returned displacement is added to a base grid and used with
    ``F.grid_sample`` to resample both the ``img1`` slice and the one-hot
    encoding of the ``cls1`` slice. The warped class scores are reduced to a
    hard label with ``torch.max`` over the class axis and one-hot encoded
    again before being stored.

    Args:
        img1: Volume to warp. It is indexed as ``img1[idx]`` for
            ``img1.shape[0]`` slices of size ``img1.shape[1] x img1.shape[2]``.
        img2: Volume paired with ``img1`` slice by slice as network input.
        cls1: Label volume belonging to ``img1``; its one-hot encoding is
            warped with the same displacement as the image.
        cls2: Label volume belonging to ``img2``. Kept for interface
            compatibility; it does not influence the result.
        net: Network forwarded to ``generate_grid``.

    Returns:
        A tuple ``(out_img, out_cls)``. ``out_img`` is a float tensor shaped
        like ``img1`` holding the warped slices. ``out_cls`` is a byte tensor
        of shape ``(nclasses,) + cls1.shape`` holding the warped one-hot labels.
    """
    depth, height, width = img1.shape[0], img1.shape[1], img1.shape[2]
    out_img = torch.zeros(img1.shape).float()
    out_cls = torch.zeros((nclasses,) + cls1.shape).float()

    # Everything below is identical for every slice, so it is built once.
    # NOTE(review): basegrid presumably returns an identity sampling grid for
    # the given size -- confirm against its definition.
    basegrid_img = basegrid(torch.Size((1, 1, height, width)))
    basegrid_cls = basegrid(torch.Size((1, nclasses, height, width)))
    one_hot = OneHotEncode()
    use_cuda = torch.cuda.is_available()

    for idx in range(depth):
        imgslice1 = img1[idx]
        imgslice2 = img2[idx]

        clsslice1_oh = one_hot(cls1[idx].unsqueeze(0)).float()
        if use_cuda:
            clsslice1_oh = clsslice1_oh.cuda()

        # The network sees both image slices stacked along a new leading axis.
        combslice = torch.cat(
            (imgslice1.unsqueeze(0), imgslice2.unsqueeze(0)), dim=0
        )
        disp = generate_grid(net, combslice.unsqueeze(0))
        grid_cls = basegrid_cls + disp
        grid_img = basegrid_img + disp

        imgslice1_t = F.grid_sample(
            imgslice1.unsqueeze(0).unsqueeze(0), grid_img
        )[0, 0]  # H x W
        clsslice1_oh_t = F.grid_sample(
            clsslice1_oh.unsqueeze(0), grid_cls
        )[0]  # nclasses x H x W

        # Interpolation blends the one-hot planes; take the strongest class
        # per pixel and re-encode so the stored labels are hard again.
        _, warped_labels = torch.max(clsslice1_oh_t, dim=0)

        out_img[idx] = imgslice1_t.data
        out_cls[:, idx] = one_hot(warped_labels.data.unsqueeze(0))

    return out_img, out_cls.byte()
def __getitem__(self, index):
    """Return the image/label pair registered at ``index`` of ``self.pairlist``.

    Both images are read from ``<datadir>/img/<name>.tif`` and both labels
    from ``<datadir>/cls/<name>.png`` (converted to mode ``'P'``); all four
    are cropped with ``crop_size``. Each image/label couple then goes through
    ``self.co_transform``, followed by ``self.img_transform`` for the images
    and ``self.label_transform`` for the labels.

    Args:
        index: Position in ``self.pairlist``.

    Returns:
        ``((img1, label1, fname1), (img2, label2, fname2), ohlabel1)`` where
        ``ohlabel1`` is ``OneHotEncode()`` applied to the first label only.
    """
    name_a, name_b = self.pairlist[index]

    def _open_binary(subdir, stem, suffix):
        # Build <datadir>/<subdir>/<stem><suffix> and open it for reading.
        return open(os.path.join(self.datadir, subdir, stem + suffix), 'rb')

    with _open_binary("img", name_a, ".tif") as handle:
        image_a = Image.open(handle).crop(crop_size)
    with _open_binary("img", name_b, ".tif") as handle:
        image_b = Image.open(handle).crop(crop_size)
    with _open_binary("cls", name_a, ".png") as handle:
        target_a = Image.open(handle).convert('P').crop(crop_size)
    with _open_binary("cls", name_b, ".png") as handle:
        target_b = Image.open(handle).convert('P').crop(crop_size)

    # Joint transforms first (first pair, then second pair), then the
    # per-modality transforms in the same first/second order.
    image_a, target_a = self.co_transform((image_a, target_a))
    image_b, target_b = self.co_transform((image_b, target_b))

    image_a = self.img_transform(image_a)
    image_b = self.img_transform(image_b)
    target_a = self.label_transform(target_a)
    target_b = self.label_transform(target_b)

    onehot_a = OneHotEncode()(target_a)

    first = (image_a, target_a, name_a)
    second = (image_b, target_b, name_b)
    return (first, second, onehot_a)
def __getitem__(self, idx):
    """Return one sample with its segmentation reduced to a two-level mask.

    ``self.imgs[idx]`` supplies ``(id, image_path, seg_path)``. Both files
    are read with ``load_image``. The segmentation is remapped through a
    256-entry lookup table that sends value 0 to 0 and every other value to
    1, requesting mode ``'1'`` from ``point``. The pair is then passed
    through ``self.co_transform``, and image and label through their own
    transforms.

    Args:
        idx: Position in ``self.imgs``.

    Returns:
        ``(img, label, ohlabel, new_id)`` where ``ohlabel`` is
        ``OneHotEncode()`` applied to the transformed label.
    """
    sample_id, image_path, seg_path = self.imgs[idx]
    image = load_image(image_path)
    mask = load_image(seg_path)

    # Only the background level stays 0; everything else becomes foreground.
    background = 0
    lookup = [0 if level == background else 1 for level in range(256)]
    mask = mask.point(lookup, '1')

    image, mask = self.co_transform((image, mask))
    image = self.img_transform(image)
    mask = self.label_transform(mask)
    onehot = OneHotEncode()(mask)

    return image, mask, onehot, sample_id
def __getitem__(self, index):
    """Return the transformed image, label and one-hot label for ``index``.

    The image is read from ``<images_root>/<name>.jpg`` and converted to
    ``'RGB'``; the label is read from ``<labels_root>/<name>.png`` and
    converted to ``'P'``. Both pass through ``self.co_transform`` together,
    then through ``self.img_transform`` / ``self.label_transform``.

    Args:
        index: Position in ``self.img_list``.

    Returns:
        ``(image, label, ohlabel)`` where ``ohlabel`` is ``OneHotEncode()``
        applied to the transformed label.
    """
    stem = self.img_list[index]
    image_path = os.path.join(self.images_root, stem + '.jpg')
    label_path = os.path.join(self.labels_root, stem + '.png')

    with open(image_path, 'rb') as handle:
        picture = load_image(handle).convert('RGB')
    with open(label_path, 'rb') as handle:
        target = load_image(handle).convert('P')

    picture, target = self.co_transform((picture, target))
    picture = self.img_transform(picture)
    target = self.label_transform(target)

    return picture, target, OneHotEncode()(target)
def load_data(start, end, data_dir, img_transform, label_transform, co_transform):
    """Load the consecutively numbered samples ``start`` .. ``end`` (inclusive).

    Sample ``i`` is read from ``<data_dir>/img/<i>.tif`` and
    ``<data_dir>/cls/<i>.png`` (the label converted to mode ``'P'``). Each
    pair goes through ``co_transform`` and then through ``img_transform`` /
    ``label_transform``; the transformed label is additionally one-hot
    encoded with ``OneHotEncode``.

    Args:
        start: First sample number.
        end: Last sample number, included in the range.
        data_dir: Directory containing the ``img`` and ``cls`` sub-directories.
        img_transform: Callable applied to each image after ``co_transform``.
        label_transform: Callable applied to each label after ``co_transform``.
        co_transform: Callable applied to the ``(image, label)`` tuple.

    Returns:
        Three lists in sample order: images, labels, one-hot labels. All
        three are empty when ``start > end``.
    """
    image_root = os.path.join(data_dir, "img")
    label_root = os.path.join(data_dir, "cls")

    images, labels, onehots = [], [], []
    for number in range(start, end + 1):
        stem = str(number)
        picture = Image.open(os.path.join(image_root, stem + '.tif'))
        target = Image.open(os.path.join(label_root, stem + '.png')).convert('P')

        picture, target = co_transform((picture, target))
        picture = img_transform(picture)
        target = label_transform(target)
        encoded = OneHotEncode()(target)

        images.append(picture)
        labels.append(target)
        onehots.append(encoded)

    return images, labels, onehots
def load_data(filename, datadir, co_transform, img_transform, label_transform):
    """Load every sample named in a list file into three name-keyed dicts.

    ``filename`` is a text file with one sample name per line. For each name
    the image ``<datadir>/img/<name>.tif`` and the label
    ``<datadir>/cls/<name>.png`` (converted to mode ``'P'``) are opened,
    passed together through ``co_transform`` and then individually through
    ``img_transform`` / ``label_transform``. The transformed label is also
    one-hot encoded with ``OneHotEncode``.

    Args:
        filename: Path of the text file listing sample names.
        datadir: Directory containing the ``img`` and ``cls`` sub-directories.
        co_transform: Callable applied to the ``(image, label)`` tuple.
        img_transform: Callable applied to the image after ``co_transform``.
        label_transform: Callable applied to the label after ``co_transform``.

    Returns:
        ``(img_dict, label_dict, oh_label_dict)``, each mapping sample name to
        the transformed image, the transformed label and the one-hot encoded
        label respectively.

    Raises:
        FileNotFoundError: If the image or label file of a listed sample does
            not exist. Raised explicitly so the check survives ``python -O``.
    """
    # The context manager closes the list file as soon as it has been read.
    with open(filename, 'r') as listing:
        stripped = (line.strip() for line in listing)
        # Blank lines carry no sample name; skip them instead of failing on
        # a lookup for an empty name.
        names = [name for name in stripped if name]

    img_dict = {}
    label_dict = {}
    oh_label_dict = {}
    for name in names:
        img_path = os.path.join(datadir, "img", name + ".tif")
        label_path = os.path.join(datadir, "cls", name + ".png")
        # Check both files before opening either one.
        for path in (img_path, label_path):
            if not os.path.exists(path):
                raise FileNotFoundError(
                    f"missing file for sample {name!r}: {path}"
                )

        img = Image.open(img_path)
        label = Image.open(label_path).convert('P')
        img, label = co_transform((img, label))
        img = img_transform(img)
        label = label_transform(label)
        ohlabel = OneHotEncode()(label)

        img_dict[name] = img
        label_dict[name] = label
        # Store the one-hot encoding here, not the plain label.
        oh_label_dict[name] = ohlabel

    return img_dict, label_dict, oh_label_dict