def get_data(data_info, data_path=args.data_path, image_folder=args.images_folder,
             mask_folder=args.masks_folder, image_type=args.image_type,
             mask_type=args.mask_type):
    """Read every image/mask pair referenced by ``data_info`` into two arrays.

    Args:
        data_info: object iterated with ``iterrows()`` (a pandas DataFrame);
            each row must provide ``'name'`` and ``'position'`` fields.
        data_path, image_folder, mask_folder: path components forwarded to
            ``get_filepath``. Their defaults are read from the module-level
            ``args`` once, when this function is defined.
        image_type, mask_type: ``file_type`` passed to ``get_filepath`` for the
            image and the mask of each row, respectively.

    Returns:
        ``(x, y)``: ``x`` is ``np.array`` of all loaded images; ``y`` is
        ``np.array`` of all loaded masks with an extra trailing axis of size 1.
    """
    images = []
    masks = []
    for _, record in data_info.iterrows():
        sample = record['name']
        fullname = get_fullname(sample, record['position'])
        image_file = get_filepath(data_path, sample, image_folder, fullname,
                                  file_type=image_type)
        mask_file = get_filepath(data_path, sample, mask_folder, fullname,
                                 file_type=mask_type)
        images.append(read_tensor(image_file))
        masks.append(read_tensor(mask_file))

    x = np.array(images)
    y = np.array(masks)
    # Append a trailing singleton axis to the mask stack.
    return x, np.expand_dims(y, axis=-1)
def get_input_pair(self, data_info_row):
    """Load one image/mask sample, augment it, and package it for training.

    The sample file name is ``<name>_<position>``. Both files are found with
    ``get_filepath`` under ``self.dataset_path`` / ``<name>``.

    Args:
        data_info_row: row providing ``'name'`` and ``'position'`` fields.

    Returns:
        With ``self.classification_head`` set: ``(image, [mask, has_mask])``.
        ``has_mask`` is a float tensor that is 1 when the augmented mask sums
        to more than 0, else 0, with a trailing axis added by ``unsqueeze(-1)``.
        Otherwise: ``{'features': image, 'targets': mask}``.

    Raises:
        Exception: if ``self.channels`` is empty.
    """
    if len(self.channels) < 1:
        raise Exception('You have to specify at least one channel.')

    sample_name = data_info_row['name']
    instance_name = '_'.join([sample_name, str(data_info_row['position'])])

    def _path_in(folder, file_type):
        # Shared path layout for the image and the mask of this sample.
        return get_filepath(self.dataset_path, sample_name, folder,
                            instance_name, file_type=file_type)

    image_path = _path_in(self.images_folder, self.image_type)
    mask_path = _path_in(self.masks_folder, self.mask_type)

    images_array = filter_by_channels(read_tensor(image_path), self.channels,
                                      self.neighbours)
    # A single-channel result comes back 2-D; give it an explicit channel axis.
    if images_array.ndim == 2:
        images_array = np.expand_dims(images_array, -1)
    masks_array = read_tensor(mask_path)

    side = self.image_size
    rotate, flip = RandomRotate90(), Flip()
    one_of = OneOf(
        [
            RandomSizedCrop(min_max_height=(int(side * 0.7), side),
                            height=side, width=side),
            RandomBrightnessContrast(brightness_limit=0.15,
                                     contrast_limit=0.15),
            MaskDropout(p=0.6),
            ElasticTransform(alpha=15, sigma=5, alpha_affine=5),
            GridDistortion(p=0.6),
        ],
        p=0.8,
    )
    aug = Compose([rotate, flip, one_of, ToTensor()])

    augmented = aug(image=images_array, mask=masks_array)
    features = augmented['image']
    targets = augmented['mask']

    if not self.classification_head:
        return {'features': features, 'targets': targets}

    # Binary "mask is non-empty" label for the auxiliary classification head.
    has_mask = ((targets.sum() > 0) * 1).unsqueeze(-1).float()
    return features, [targets, has_mask]
def predict(data_path, model_weights_path, network, test_df_path, save_path,
            size, channels, neighbours, classification_head):
    """Run a trained segmentation model on every row of a test CSV and save PNGs.

    Args:
        data_path: root passed to ``get_filepath`` when locating input images.
        model_weights_path: file loaded with ``torch.load``. It is a bare
            state dict when ``classification_head`` is set. Otherwise it is a
            checkpoint dict with a ``'model_state_dict'`` key.
        network: architecture identifier forwarded to ``get_model``.
        test_df_path: CSV with ``'name'``, ``'position'`` and
            ``'dataset_folder'`` columns.
        save_path: predictions are written to ``<save_path>/predictions``.
        size: input images are viewed as ``(1, C, size, size)``; predictions
            are viewed as ``(size, size)``.
        channels, neighbours: channel selection for ``filter_by_channels``.
            The model's first conv takes ``count_channels(channels) * neighbours``
            input channels.
        classification_head: whether the model also returns a label from
            ``predict``. The label is discarded.
    """
    model = get_model(network, classification_head)
    in_channels = count_channels(channels) * neighbours
    # Replace the first encoder conv so it accepts the configured channel count.
    model.encoder.conv1 = nn.Conv2d(in_channels, 64, kernel_size=(7, 7),
                                    stride=(2, 2), padding=(3, 3), bias=False)
    model, device = UtilsFactory.prepare_model(model)

    # map_location='cpu' lets GPU-saved weights load on a CPU-only host;
    # load_state_dict copies them onto whatever device the model lives on.
    if classification_head:
        model.load_state_dict(torch.load(model_weights_path, map_location='cpu'))
    else:
        checkpoint = torch.load(model_weights_path, map_location='cpu')
        model.load_state_dict(checkpoint['model_state_dict'])

    test_df = pd.read_csv(test_df_path)

    predictions_path = os.path.join(save_path, "predictions")
    if not os.path.exists(predictions_path):
        os.makedirs(predictions_path, exist_ok=True)
        print("Prediction directory created.")

    for _, image_info in tqdm(test_df.iterrows(), total=len(test_df)):
        # read_csv parses numeric columns (e.g. position) as ints, which
        # str.join rejects; cast explicitly as get_input_pair does.
        filename = '_'.join([image_info['name'], str(image_info['position'])])
        image_path = get_filepath(data_path, image_info['dataset_folder'],
                                  'images', filename, file_type='tiff')

        image_tensor = filter_by_channels(read_tensor(image_path), channels,
                                          neighbours)
        # A single-channel result is 2-D; add the channel axis ToTensor expects.
        if image_tensor.ndim == 2:
            image_tensor = np.expand_dims(image_tensor, -1)
        image = transforms.ToTensor()(image_tensor)
        batch = image.view(1, in_channels, size, size).to(device,
                                                           dtype=torch.float)

        if classification_head:
            prediction, _ = model.predict(batch)
        else:
            prediction = model.predict(batch)

        result = prediction.view(size, size).detach().cpu().numpy()
        cv.imwrite(get_filepath(predictions_path, filename, file_type='png'),
                   result * 255)
explainer = RISE(model, args.input_size, args.gpu_batch)

# Generate masks for RISE or use the saved ones.
maskspath = 'masks.npy'
# Manual toggle: while True, masks are always regenerated (and saved to
# `maskspath`); set to False to reuse an existing masks file.
generate_new = True

if generate_new or not os.path.isfile(maskspath):
    explainer.generate_masks(N=6000, s=8, p1=0.1, savepath=maskspath)
else:
    explainer.load_masks(maskspath)
    print('Masks are loaded.')

# ## Explaining one instance
# Producing saliency maps for top $k$ predicted classes.
example(read_tensor('catdog.png'), 5)

explanations = explain_all(data_loader, explainer)
# Save explanations if needed.
# explanations.tofile('exp_{:05}-{:05}.npy'.format(args.range[0], args.range[-1]))

for i, (img, _) in enumerate(data_loader):
    # One forward pass per batch: the top-1 score/class and the softmax
    # probabilities are both derived from the same logits.
    logits = model(img.to(device))
    p, c = torch.max(logits, dim=-1)
    p, c = p[0].item(), c[0].item()
    prob = torch.softmax(logits, dim=-1)
    pred_prob = prob[0][c]
    plt.figure(figsize=(10, 5))
    plt.suptitle('RISE Explanation for model {}'.format(args.model))
def __getitem__(self, idx):
    """Load, jointly augment, and split one multi-frame sample.

    For row ``idx`` of ``self.df``, reads ``self.num_images`` image/mask pairs
    named ``<name>_<position>_<k>``, iterating k from ``num_images`` down to 1.
    All frames are concatenated along the last axis so a single augmentation
    draw applies to every frame. The result is then split back per frame.

    Returns:
        dict with:
            ``'features'``: ``torch.stack`` of the per-frame image slices.
            ``'targets'``: all mask frames (if ``self.all_masks``) or only the
                last one.
            ``'name'`` and ``'position'``: copied from the row.

    Raises:
        Exception: if fewer than two channels are configured.
    """
    if len(self.channels) < 2:
        raise Exception('You have to specify at least two channels.')

    data_info_row = self.df.iloc[idx]
    # str(): numeric position columns come back as ints and str.join would
    # raise TypeError (same guard as get_input_pair).
    instance_name = '_'.join(
        [data_info_row['name'], str(data_info_row['position'])])

    images_array, masks_array = [], []
    for k in range(self.num_images, 0, -1):
        frame_name = instance_name + f'_{k}'
        image_path = get_filepath(self.dataset_path,
                                  data_info_row['dataset_folder'],
                                  self.images_folder, frame_name,
                                  file_type=self.image_type)
        images_array.append(
            filter_by_channels(read_tensor(image_path), self.channels, 1))

        mask_path = get_filepath(self.dataset_path,
                                 data_info_row['dataset_folder'],
                                 self.masks_folder, frame_name,
                                 file_type=self.mask_type)
        # Give each mask its own trailing axis so masks stack per frame.
        masks_array.append(np.expand_dims(read_tensor(mask_path), axis=-1))

    aug = Compose([
        RandomRotate90(),
        Flip(),
        OneOf([
            RandomSizedCrop(min_max_height=(int(self.image_size * 0.7),
                                            self.image_size),
                            height=self.image_size, width=self.image_size),
            RandomBrightnessContrast(brightness_limit=0.15,
                                     contrast_limit=0.15),
            ElasticTransform(alpha=15, sigma=5, alpha_affine=5),
            GridDistortion(p=0.6),
        ], p=0.8),
        ToTensor(),
    ])

    augmented = aug(image=np.concatenate(images_array, axis=-1),
                    mask=np.concatenate(masks_array, axis=-1))

    # The slicing below treats dim 0 of the augmented image as channels; each
    # frame occupies `n_ch` consecutive channels.
    n_ch = count_channels(self.channels)
    augmented_images = torch.stack([
        augmented['image'][num_img * n_ch:(num_img + 1) * n_ch, :, :]
        for num_img in range(self.num_images)
    ])

    mask = augmented['mask']
    if self.all_masks:
        augmented_masks = torch.stack(
            [mask[:, :, :, i] for i in range(mask.shape[-1])]).squeeze()
    else:
        augmented_masks = torch.stack([mask[:, :, :, -1]])

    return {
        'features': augmented_images,
        'targets': augmented_masks,
        'name': data_info_row['name'],
        'position': data_info_row['position']
    }
def __getitem__(self, idx):
    """Load a two-frame sample with one shared mask and augment them jointly.

    Reads ``self.num_images`` images named ``<name>_<position>_<k>``
    (k = 1..num_images) and a single mask named ``<name>_<position>``. When
    ``self.phase == 'train'``, a random geometric/photometric augmentation is
    applied; otherwise only ``ToTensor``.

    Returns:
        dict with:
            ``'features'``: a list of two image tensors, one per frame.
            ``'targets'``: a one-element list holding the mask tensor.
            ``'name'`` and ``'position'``: copied from the row.

    Raises:
        Exception: if fewer than two channels are configured.
    """
    if len(self.channels) < 2:
        raise Exception('You have to specify at least two channels.')

    data_info_row = self.df.iloc[idx]
    # str(): numeric position columns come back as ints and str.join would
    # raise TypeError (same guard as get_input_pair).
    instance_name = '_'.join(
        [data_info_row['name'], str(data_info_row['position'])])

    images_array = []
    for k in range(1, self.num_images + 1):
        image_path = get_filepath(self.dataset_path,
                                  data_info_row['dataset_folder'],
                                  self.images_folder, instance_name + f'_{k}',
                                  file_type=self.image_type)
        images_array.append(
            filter_by_channels(read_tensor(image_path), self.channels, 1))

    # One mask (no frame suffix) shared by all frames of the sample.
    mask_path = get_filepath(self.dataset_path,
                             data_info_row['dataset_folder'],
                             self.masks_folder, instance_name,
                             file_type=self.mask_type)
    masks_array = read_tensor(mask_path)

    if self.phase == 'train':
        aug = Compose([
            RandomRotate90(),
            Flip(),
            OneOf([
                RandomSizedCrop(min_max_height=(int(self.image_size * 0.7),
                                                self.image_size),
                                height=self.image_size, width=self.image_size),
                RandomBrightnessContrast(brightness_limit=0.15,
                                         contrast_limit=0.15),
                MaskDropout(p=0.6),
                ElasticTransform(alpha=15, sigma=5, alpha_affine=5),
                GridDistortion(p=0.6),
            ], p=0.8),
            ToTensor(),
        ])
    else:
        aug = ToTensor()

    # NOTE(review): only the first two frames are augmented and returned;
    # frames 3..num_images are read above but ignored, and num_images < 2
    # raises IndexError here. TODO confirm this is intended.
    # Both frames are concatenated so they receive the same random transform.
    augmented = aug(image=np.concatenate(
        (images_array[0], images_array[1]), axis=-1), mask=masks_array)

    # Split the augmented stack back into frames along dim 0, as the original
    # slicing does: the first `n_ch` channels are frame 1, the rest frame 2.
    n_ch = count_channels(self.channels)
    augmented_images = [
        augmented['image'][:n_ch, :, :],
        augmented['image'][n_ch:, :, :]
    ]
    augmented_masks = [augmented['mask']]
    return {
        'features': augmented_images,
        'targets': augmented_masks,
        'name': data_info_row['name'],
        'position': data_info_row['position']
    }