def main(args):
    """Segment every volume in ``args.data_dir`` with a trained model and save the results.

    Each regular file in ``args.data_dir`` is read with ``np.load`` and processed
    tile by tile, using the tiles produced by
    ``SubvolCorners(vol_size, crop_size, border=args.border)``. For each tile the
    network sees the outer (bordered) region, but only the inner ``crop_size``
    region of its output is written back. Two arrays are saved per input file
    into ``args.save_dir``:

    * ``pred`` -- float32 network output with 2 channels per voxel;
    * ``mask`` -- boolean, True where channel 1 is the argmax.

    Args:
        args: Parsed command-line namespace. Fields read:
            ``model`` (``'vnet'``; any other value loads the UNet trainer),
            ``checkpoint``, ``crop_size``, ``data_dir``, ``border``,
            ``save_dir`` and ``file_type`` (``'npy'`` or ``'raw'``).

    Raises:
        ValueError: If ``args.file_type`` is neither ``'npy'`` nor ``'raw'``.
    """
    # Validate the output format before any work is done, so that a typo does
    # not surface only after a full inference pass over the first volume.
    if args.file_type not in ('npy', 'raw'):
        raise ValueError('Invalid file type: {}'.format(args.file_type))

    print('Loading model')
    if args.model == 'vnet':
        model = vnet.VNet.load_from_checkpoint(args.checkpoint)
    else:  # args.model == 'unet'
        model = unet.UNet3dTrainer.load_from_checkpoint(args.checkpoint)

    size = np.asarray(args.crop_size)

    # os.listdir returns entries in arbitrary order; sort so files are
    # processed (and numbered in the log) deterministically.
    files = sorted(
        f for f in listdir(args.data_dir) if isfile(join(args.data_dir, f))
    )

    # Decide the device once and use it for both the model and its inputs.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    model.eval()

    for i, filename in enumerate(files):
        print('File', i + 1, ':', filename)
        data = torch.from_numpy(np.load(join(args.data_dir, filename)))
        vol_size = data.shape
        # Prepend two singleton dims (presumably batch and channel) for the model.
        data = data.unsqueeze(0).unsqueeze(0)

        pred = torch.zeros((2, ) + vol_size, dtype=torch.float32)
        mask = torch.zeros((1, ) + vol_size, dtype=torch.bool)

        tiles = SubvolCorners(vol_size, size, border=args.border)
        for outer_corner, outer_size, inner_corner in tqdm(tiles):
            sub_data = F.crop(data, outer_corner, outer_size)
            # NOTE(review): the in-place `[:]` writes below only reach
            # `pred`/`mask` if F.crop returns a view of its input -- confirm.
            sub_pred = F.crop(pred, outer_corner + inner_corner, size)
            sub_mask = F.crop(mask, outer_corner + inner_corner, size)
            with torch.no_grad():
                res = model(sub_data.to(device)).cpu()
            # Keep only the inner region; the border was context for the net.
            sub_pred[:] = F.crop(res, inner_corner, size)
            sub_mask[:] = sub_pred.argmax(dim=0) == 1

        pred = pred.squeeze()
        mask = mask.squeeze()
        out_filename = join(args.save_dir,
                            basename(args.checkpoint) + '.' + filename)
        if args.file_type == 'npy':
            np.save(out_filename + '.pred', pred.numpy())
            np.save(out_filename + '.mask', mask.numpy())
        else:  # 'raw' -- the only other value allowed by the check above
            pred.numpy().tofile(out_filename + '.pred.raw')
            mask.numpy().tofile(out_filename + '.mask.raw')
def get_crop(self, index):
    """Return the crop addressed by a flat dataset index.

    Flat indices are laid out volume-major: each volume contributes
    ``self.samples_per_volume`` consecutive indices. The quotient selects the
    volume and the remainder selects an entry of ``self.subvol_corners``.
    That corner is cropped to ``self.size`` from the volume's data/mask.
    """
    per_volume = self.samples_per_volume
    crop_corner = self.subvol_corners[index % per_volume]
    volume = self.get_volume(index // per_volume)
    return F.crop(volume, crop_corner, self.size)
def _compute_supported_crop_corners(self, labels):
    """Collect the corners of all ``self.size`` crops that overlap a label.

    Candidate corners come from ``SubvolCorners(self.vol_size, self.size)``.
    A corner is kept when the crop of ``labels > 0`` at that corner contains
    at least one positive entry. Progress is shown with ``tqdm``.
    """
    labelled = labels > 0
    candidates = tqdm(SubvolCorners(self.vol_size, self.size))
    return [
        corner
        for corner in candidates
        if (F.crop(labelled, corner, self.size) > 0).any()
    ]
def get_crop(self, index):
    """Return a random crop of volume ``index // self.samples_per_volume`` that contains labels.

    Crops are drawn with ``F.random_crop(..., self.dist, return_corner=True)``
    until one has a positive voxel in its mask channel (``sample[1]``). The
    corner is then adjusted with ``self._move_to_center_of_mass``, presumably
    recentring it on the labelled voxels (confirm). The crop at the adjusted
    corner is returned.

    Raises:
        ValueError: If the volume's mask has no positive voxels, in which case
            no crop could ever satisfy the loop (it previously spun forever).
    """
    vol_index = index // self.samples_per_volume
    data_and_mask = self.get_volume(vol_index)
    mask = data_and_mask[1]
    mask_checked = False
    # For now, just keep sampling random subvolumes until we find one with
    # labels. Since F.random_crop is fast, this is okay.
    while True:
        sample, corner = F.random_crop(data_and_mask, self.size, self.dist,
                                       return_corner=True)
        if torch.any(sample[1] > 0):
            corner = self._move_to_center_of_mass(corner, sample[1],
                                                  mask.size())
            return F.crop(data_and_mask, corner, self.size)
        # On the first miss, scan the whole mask once. An all-empty mask would
        # otherwise make this loop spin forever. Hits on the first draw skip
        # the scan entirely.
        if not mask_checked:
            if not torch.any(mask > 0):
                raise ValueError(
                    'Volume {} has no labelled voxels; cannot sample a crop '
                    'containing labels'.format(vol_index))
            mask_checked = True
def crop(data: dict, region: tuple):
    """Crop the image in ``data["data"]`` and its annotations to ``region``.

    Args:
        data: Sample dict. ``"data"`` holds the image passed to ``RF.crop``.
            Optional ``"boxes"`` and ``"masks"`` entries are cropped alongside
            it. Boxes are treated as ``[x0, y0, x1, y1]``: the column offset
            ``j`` is subtracted from x and the row offset ``i`` from y. Masks
            are sliced as ``masks[:, rows, cols]``.
        region: ``(i, j, h, w)`` -- top row, left column, height and width of
            the crop, forwarded positionally to ``RF.crop``.

    Returns:
        A shallow copy of ``data`` (the input dict is not modified) with:

        * a cropped ``"data"``;
        * ``"size"`` set to ``tensor([h, w])``;
        * cropped boxes and masks, and a recomputed ``"area"`` when boxes
          are present.

        If boxes or masks are present, the per-instance fields (``"labels"``,
        ``"area"``, ``"iscrowd"``, ``"boxes"``, ``"masks"``) are filtered to
        instances that are non-empty after cropping. Fields absent from the
        target are skipped.
    """
    image = data["data"]
    cropped_image = RF.crop(image, *region)

    target = data.copy()
    i, j, h, w = region

    # should we do something wrt the original size?
    target["size"] = torch.tensor([h, w])

    fields = ["labels", "area", "iscrowd"]

    if "boxes" in target:
        boxes = target["boxes"]
        max_size = torch.as_tensor([w, h], dtype=torch.float32)
        # Shift boxes into the crop's frame, then clip both corners to [0, (w, h)].
        cropped_boxes = boxes - torch.as_tensor([j, i, j, i])
        cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size)
        cropped_boxes = cropped_boxes.clamp(min=0)
        area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1)
        target["boxes"] = cropped_boxes.reshape(-1, 4)
        target["area"] = area
        fields.append("boxes")

    if "masks" in target:
        target["masks"] = target["masks"][:, i:i + h, j:j + w]
        fields.append("masks")

    # remove elements for which the boxes or masks that have zero area
    if "boxes" in target or "masks" in target:
        # favor boxes selection when defining which elements to keep
        # this is compatible with previous implementation
        if "boxes" in target:
            cropped_boxes = target["boxes"].reshape(-1, 2, 2)
            keep = torch.all(cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1)
        else:
            keep = target["masks"].flatten(1).any(1)

        for field in fields:
            # Optional per-instance fields may be missing, e.g. no "iscrowd",
            # or no "area" on a masks-only target. Skip them rather than
            # raising KeyError.
            if field in target:
                target[field] = target[field][keep]

    target["data"] = cropped_image
    return target
def get_crop(self, index):
    """Return the crop addressed by a flat dataset index.

    ``self.index_to_vol`` maps the flat index to its volume.
    ``self.corner_index_offsets`` holds the first flat index belonging to that
    volume, and the difference selects an entry of that volume's corner list
    in ``self.corner_lists``. That corner is cropped to ``self.size`` from the
    volume's data/mask.
    """
    volume_id = self.index_to_vol[index]
    local_offset = index - self.corner_index_offsets[volume_id]
    crop_corner = self.corner_lists[volume_id][local_offset]
    volume = self.get_volume(volume_id)
    return F.crop(volume, crop_corner, self.size)