def get_ims(slice_idx, vol, spacing, slice_intv, mask_list):
    """Load and preprocess a multi-slice CT image stack centered at ``slice_idx``.

    Loads ``num_slice`` slices from ``vol`` via ``load_prep_img``, subtracts the
    configured pixel mean, converts to a float CHW tensor, and splits it into
    ``NUM_SLICES`` groups of ``NUM_IMAGES_3DCE`` channels each (3DCE grouping).

    Returns:
        tuple: (list of per-group tensors, raw center-slice image,
        image scale factor, crop window, updated mask_list).
    """
    num_slice = cfg.INPUT.NUM_SLICES * cfg.INPUT.NUM_IMAGES_3DCE
    im_np, im_scale, crop, mask_list = load_prep_img(
        vol, slice_idx, spacing, slice_intv, mask_list,
        cfg.INPUT.IMG_DO_CLIP, num_slice=num_slice)

    # Mean-subtract, then reorder HWC -> CHW for the network.
    shifted = im_np - cfg.INPUT.PIXEL_MEAN
    stack = torch.from_numpy(shifted.transpose((2, 0, 1))).to(dtype=torch.float)
    ims = stack.split(cfg.INPUT.NUM_IMAGES_3DCE)

    # NOTE(review): center-slice index uses int(num_slice/2)+1, not
    # num_slice//2 — looks off-by-one for a 0-indexed center; confirm
    # against load_prep_img's slice ordering before changing.
    center_slice = im_np[:, :, int(num_slice / 2) + 1]
    return ims, center_slice, im_scale, crop, mask_list
def __getitem__(self, index):
    """
    Args:
        index (int): Index
    Returns:
        tuple: (image, target, info).
    """
    # One dataset item is an image (slice stack) with all lesions grouped on it.
    image_fn = self.image_fn_list[index]
    lesion_idx_grouped = self.lesion_idx_grouped[index]
    boxes0 = self.boxes[lesion_idx_grouped]
    # slice_no = self.slice_idx[lesion_idx_grouped][0]
    slice_intv = self.slice_intv[lesion_idx_grouped][0]
    spacing = self.spacing[lesion_idx_grouped][0]
    recists = self.d_coordinate[lesion_idx_grouped]
    diameters = self.diameter[lesion_idx_grouped]
    window = self.DICOM_window[lesion_idx_grouped][0]
    gender = float(self.gender[lesion_idx_grouped][0] == 'M')
    age = self.age[lesion_idx_grouped][0]/100  # presumably normalized to [0, 1] by dividing by 100 years
    if np.isnan(age) or age == 0:
        age = .5  # fall back to a neutral mid-range value when age is missing
    z_coord = self.norm_location[lesion_idx_grouped[0], 2]

    num_slice = cfg.INPUT.NUM_SLICES * cfg.INPUT.NUM_IMAGES_3DCE
    is_train = self.split=='train'
    # 3D augmentation: randomly shift the key slice along z, bounded so the
    # (smallest) lesion is still expected to appear on the shifted slice.
    if is_train and cfg.INPUT.DATA_AUG_3D is not False:
        slice_radius = diameters.min() / 2 * spacing / slice_intv * abs(cfg.INPUT.DATA_AUG_3D)  # lesion should not be too small
        slice_radius = int(slice_radius)
        if slice_radius > 0:
            if cfg.INPUT.DATA_AUG_3D > 0:
                # uniform offset in [0, slice_radius]
                delta = np.random.randint(0, slice_radius+1)
            else:  # give central slice higher prob
                # linearly decaying probability away from the central slice
                ar = np.arange(slice_radius+1)
                p = slice_radius-ar.astype(float)
                delta = np.random.choice(ar, p=p/p.sum())
            if np.random.rand(1) > .5:
                delta = -delta

            # Filenames look like "<dir><sep><slice>.png"; shift the slice number.
            dirname, slicename = image_fn.split(os.sep)
            slice_idx = int(slicename[:-4])
            image_fn1 = '%s%s%03d.png' % (dirname, os.sep, slice_idx + delta)
            if os.path.exists(os.path.join(self.data_path, image_fn1)):
                image_fn = image_fn1

    im, im_scale, crop = load_prep_img(self.data_path, image_fn, spacing, slice_intv,
                                       cfg.INPUT.IMG_DO_CLIP, num_slice=num_slice,
                                       is_train=is_train)
    im -= cfg.INPUT.PIXEL_MEAN
    im = torch.from_numpy(im.transpose((2, 0, 1))).to(dtype=torch.float)

    # Adjust boxes for the crop offset (if clipping) and the resize scale.
    boxes_new = boxes0.copy()
    if cfg.INPUT.IMG_DO_CLIP:
        # [x_off, y_off] repeated twice to cover (x1, y1, x2, y2)
        offset = [crop[2], crop[0]]
        boxes_new -= offset*2
    boxes_new *= im_scale
    boxes = torch.as_tensor(boxes_new).reshape(-1, 4)  # guard against no boxes
    target = BoxList(boxes, (im.shape[2], im.shape[1]), mode="xyxy")
    num_boxes = boxes.shape[0]
    classes = torch.ones(num_boxes, dtype=torch.int)  # lesion/nonlesion
    target.add_field("labels", classes)

    if cfg.MODEL.TAG_ON:
        # Per-box multi-label tag vectors; -1 marks "no annotation" so the
        # loss can zero out those entries.
        tags = torch.zeros(num_boxes, self.num_tags, dtype=torch.int)
        reliable_neg_tags = torch.zeros(num_boxes, self.num_tags, dtype=torch.int)
        for p in range(num_boxes):
            if lesion_idx_grouped[p] in self.lesion_tags.keys():
                pos_tags = self.lesion_tags[lesion_idx_grouped[p]]
                tags[p, pos_tags] = 1
                # Tags mutually exclusive with a positive tag are reliable negatives.
                ex_tags = [e for l in pos_tags for e in self.exclusive_list[l]
                           if e not in pos_tags]
                reliable_neg_tags[p, ex_tags] = 1
            else:
                tags[p] = -1  # no tag exist for this lesion, the loss weights for this lesion should be zero
                reliable_neg_tags[p] = -1  # no tag exist for this lesion, the loss weights for this lesion should be zero
        target.add_field("tags", tags)
        target.add_field("reliable_neg_tags", reliable_neg_tags)

        # NOTE(review): reconstructed as nested under TAG_ON — confirm nesting.
        if self.split == 'test':
            # Separate hand-annotated tags used only for evaluation.
            tags = torch.zeros(num_boxes, self.num_tags, dtype=torch.int)
            for p in range(num_boxes):
                if lesion_idx_grouped[p] in self.manual_annot_test_tags.keys():
                    tags[p, self.manual_annot_test_tags[lesion_idx_grouped[p]]] = 1
                else:
                    tags[p] = -1  # no tag exist for this lesion, the loss weights for this lesion should be zero
            target.add_field("manual_annot_test_tags", tags)

    # RECIST endpoints get the same crop/scale adjustment as the boxes
    # (offset repeated 4x to cover the 8 coordinates).
    if cfg.INPUT.IMG_DO_CLIP:
        recists -= offset * 4
    recists *= im_scale
    if cfg.MODEL.MASK_ON:
        # Build a pseudo-mask polygon from each RECIST measurement.
        masks = []
        for recist in recists:
            mask = gen_mask_polygon_from_recist(recist)
            masks.append([mask])
        masks = SegmentationMask(masks, (im.shape[-1], im.shape[-2]))
        target.add_field("masks", masks)

    target = target.clip_to_image(remove_empty=False)

    if self.transforms is not None:
        im, target = self.transforms(im, target)

    # Side-channel metadata for evaluation/visualization; not used by the loss.
    infos = {'im_index': index, 'lesion_idxs': lesion_idx_grouped, 'image_fn': image_fn,
             'diameters': diameters*spacing, 'crop': crop, 'recists': recists,
             'window': window, 'spacing': spacing, 'im_scale': im_scale,
             'gender': gender, 'age': age, 'z_coord': z_coord}
    return im, target, infos