def SUE_TTA(model, batch: torch.Tensor, last_layer: bool) -> Tuple[np.ndarray, np.ndarray]:
    r"""Binary segmentation uncertainty estimation with Test-Time Augmentation (TTA) for one 2D slice.

    Inputs are supposed to be in range [0, data_range].

    The model is run on every combination of the flip / rotate / scale / multiply
    transforms below. Each output is de-augmented back to the input geometry.
    The stacked per-transform predictions are then reduced by ``calc_aleatoric``
    and ``calc_epistemic``.

    Args:
        model: Trained model.
        batch: Tensor with shape (1, C, H, W).
        last_layer: If True, ``torch.sigmoid`` is applied to every de-augmented
            model output. If False, the raw outputs are used as-is.
            NOTE(review): the original docstring described this as "Flag whether
            there is Sigmoid as a last NN layer". That reads as the opposite of
            what the code does: sigmoid is applied when the flag is True, i.e.
            possibly a second time. Behavior is unchanged here; confirm the
            intended semantics with callers.

    Returns:
        Aleatoric and epistemic uncertainty maps produced by ``calc_aleatoric`` /
        ``calc_epistemic``. Presumably shaped like ``batch``; this depends on
        those helpers, which are not visible here.
    """
    model.eval()
    transforms = tta.Compose(
        [
            tta.VerticalFlip(),
            tta.HorizontalFlip(),
            tta.Rotate90(angles=[0, 180]),
            tta.Scale(scales=[1, 2, 4]),
            tta.Multiply(factors=[0.9, 1, 1.1]),
        ]
    )
    predicted = []
    # Pure inference. no_grad stops autograd from storing activations for a
    # backward pass that never happens; the outputs were only ever detached.
    with torch.no_grad():
        for transformer in transforms:
            augmented_image = transformer.augment_image(batch)
            model_output = model(augmented_image)
            deaug_mask = transformer.deaugment_mask(model_output)
            if last_layer:
                deaug_mask = torch.sigmoid(deaug_mask)
            predicted.append(deaug_mask.cpu().numpy())
    # Leading axis enumerates the TTA transform combinations.
    p_hat = np.array(predicted)
    aleatoric = calc_aleatoric(p_hat)
    epistemic = calc_epistemic(p_hat)
    return aleatoric, epistemic
def predict(self, tta_aug=None, debug=None):
    """Yield a TTA-averaged prediction for every sample of ``settings.PREDICT_DATASET``.

    Processing per sample:
      1. Augment with every transform in ``tta_aug``.
      2. Pad, run through ``self.model``, unpad and de-augment.
      3. Resize back to the input's spatial size (nearest neighbour).
      4. Average over all transforms.

    Args:
        tta_aug: Optional ``ttach`` transform collection. ``None`` selects scales
            0.95 / 1 / 1.05 combined with a horizontal flip.
        debug: Unused; kept for backward compatibility.

    Yields:
        The averaged output as a numpy array, transposed to channels-last and
        with all size-1 axes squeezed.

    If ``self.settings`` is not a ``PredictorSettings``, a warning is logged when
    iteration starts and nothing is yielded.
    """
    transforms = tta_aug
    if tta_aug is None:
        import ttach as tta
        transforms = tta.Compose([
            tta.Scale(scales=[0.95, 1, 1.05]),
            tta.HorizontalFlip(),
        ])
    # Aliased so that loop variables cannot shadow the module.
    from torch.utils import data as torch_data

    self.model.eval()
    if not isinstance(self.settings, PredictorSettings):
        logger.warning(
            'Settings is of type: {}. Pass settings to network object of type Train to train'
            .format(str(type(self.settings))))
        return

    predict_loader = torch_data.DataLoader(dataset=self.settings.PREDICT_DATASET,
                                           batch_size=1,
                                           shuffle=False,
                                           num_workers=self.settings.PROCESSES)
    # Each item is (image, target, sample_id); only the image is needed here.
    for batch, _target, _sample_id in predict_loader:
        # no_grad is scoped to the computation and exited BEFORE yielding.
        # Yielding inside it would leave autograd disabled in the caller's code
        # between iterations, because grad mode is global to the thread.
        with torch.no_grad():
            batch = batch.to(self.device)
            full_shape = batch.shape
            target_size = list(full_shape)[2:]
            outputs = []
            for transformer in transforms:
                augmented = transformer.augment_image(batch)
                augmented_size = list(augmented.shape)[2:]
                # Presumably pads H and W up to a multiple of padding_value
                # (original note: 2**5). pad/unpad are not visible here.
                padded = pad(augmented, self.padding_value)
                prediction = unpad(self.model(padded.float()), augmented_size)
                restored = transformer.deaugment_mask(prediction)
                # Snap every de-augmented mask back to the input's spatial size
                # so the masks can be stacked and averaged.
                restored = torch.nn.functional.interpolate(
                    restored, size=target_size, mode="nearest")
                logger.debug(
                    "original: %s input: %s, padded: %s unpadded %s output %s",
                    full_shape, augmented_size, list(padded.shape),
                    list(prediction.shape), list(restored.shape))
                outputs.append(restored)
            averaged = torch.mean(torch.stack(outputs), dim=0)
            out = np.transpose(averaged.cpu().numpy(), (0, 2, 3, 1))
        yield np.squeeze(out)
def tta_model_predict(X, model):
    """Average a model's ``"out"`` prediction over flip and scale test-time augmentations.

    Args:
        X: Input batch handed to each transform's ``augment_image``.
        model: Callable whose result is indexed with ``"out"`` to get the mask.

    Returns:
        Element-wise mean of the de-augmented masks over all
        2 (flip) x 3 (scale) transform combinations.
    """
    tta_transforms = tta.Compose([
        tta.HorizontalFlip(),
        tta.Scale(scales=[0.5, 1, 2]),
    ])
    restored_masks = [
        t.deaugment_mask(model(t.augment_image(X))["out"])
        for t in tta_transforms
    ]
    stacked = torch.stack(restored_masks)
    return stacked.sum(dim=0) / len(restored_masks)
def test_compose_1():
    """Every combination of a composed TTA pipeline must round-trip masks and labels."""
    pipeline = tta.Compose([
        tta.HorizontalFlip(),
        tta.VerticalFlip(),
        tta.Rotate90(angles=[0, 90, 180, 270]),
        tta.Scale(scales=[1, 2, 4], interpolation="nearest"),
    ])
    # Compose enumerates all combinations of the individual transform parameters.
    expected_combinations = 2 * 2 * 4 * 3
    assert len(pipeline) == expected_combinations

    label = torch.ones(2).reshape(2, 1).float()
    image = torch.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5).float()

    def fake_model(x):
        # Identity "segmentation" head plus a constant classification head.
        return {"label": label, "mask": x}

    for step in pipeline:
        result = fake_model(step.augment_image(image))
        restored_mask = step.deaugment_mask(result["mask"])
        restored_label = step.deaugment_label(result["label"])
        assert torch.allclose(restored_mask, image)
        assert torch.allclose(restored_label, label)
def test(model, data_loader, save_path=""):
    """Evaluate ``model`` on ``data_loader`` and return per-video and average J / F scores.

    For convenience, validation and testing during training compute the J and F
    metrics directly, instead of first generating and writing out predictions.
    The numbers here are therefore only a relative reference; the true metrics
    must be computed with the dedicated test code.

    Args:
        model: Network evaluated through ``tta_aug``.
        data_loader: Yields dicts with "image", "flow" and "mask_path" entries.
        save_path: If non-empty, binarized predictions are written to
            ``save_path/<video_name>/<mask_name>``.

    Returns:
        The result of ``pretty_print`` on a mapping from video name to the
        per-video mean [J, F], plus an "average" entry over all videos.
    """
    model.eval()
    tqdm_iter = tqdm(enumerate(data_loader), total=len(data_loader), leave=False)

    if arg_config['use_tta']:
        construct_print("We will use Test Time Augmentation!")
        transforms = tta.Compose([  # 2*3 combinations
            tta.HorizontalFlip(),
            tta.Scale(scales=[0.75, 1, 1.5], interpolation='bilinear', align_corners=False)
        ])
    else:
        transforms = None

    results = defaultdict(list)
    for test_batch_id, test_data in tqdm_iter:
        tqdm_iter.set_description(f"te=>{test_batch_id + 1}")
        with torch.no_grad():
            curr_jpegs = test_data["image"].to(DEVICES, non_blocking=True)
            curr_flows = test_data["flow"].to(DEVICES, non_blocking=True)
            preds_logits = tta_aug(model=model,
                                   transforms=transforms,
                                   data=dict(curr_jpeg=curr_jpegs, curr_flow=curr_flows))
            # Squeeze only the channel axis; presumably preds_logits is (B, 1, H, W),
            # since tta_aug is not visible here (TODO confirm). A bare squeeze()
            # also drops the batch axis when the batch holds one sample (e.g. a
            # short last batch). The loop below would then iterate over image rows
            # and index mask_path out of range.
            preds_prob = preds_logits.sigmoid().squeeze(1).cpu().detach()  # float32
            for i, pred_prob in enumerate(preds_prob.numpy()):
                curr_mask_path = test_data["mask_path"][i]
                video_name, mask_name = curr_mask_path.split(os.sep)[-2:]
                mask = read_binary_array(curr_mask_path, thr=0)
                mask_h, mask_w = mask.shape

                pred_prob = cv2.resize(pred_prob, dsize=(mask_w, mask_h),
                                       interpolation=cv2.INTER_LINEAR)
                pred_prob = clip_to_normalize(data_array=pred_prob,
                                              clip_range=arg_config["clip_range"])
                pred_seg = np.where(pred_prob > 0.5, 255, 0).astype(np.uint8)

                results[video_name].append((
                    jaccard.db_eval_iou(annotation=mask, segmentation=pred_seg),
                    f_boundary.db_eval_boundary(annotation=mask, segmentation=pred_seg),
                ))

                if save_path:
                    pred_video_path = os.path.join(save_path, video_name)
                    # exist_ok avoids the check-then-create race of exists() + makedirs().
                    os.makedirs(pred_video_path, exist_ok=True)
                    pred_frame_path = os.path.join(pred_video_path, mask_name)
                    cv2.imwrite(pred_frame_path, pred_seg)

    j_f_collection = []
    for video_name, video_scores in results.items():
        # Mean over frames -> [J, F] for this video.
        j_f_for_video = np.mean(np.array(video_scores), axis=0).tolist()
        results[video_name] = j_f_for_video
        j_f_collection.append(j_f_for_video)

    results['average'] = np.mean(np.array(j_f_collection), axis=0).tolist()
    return pretty_print(results)
multilabel_resnet34 = models.resnet34(pretrained=False) multilabel_resnet34.fc = torch.nn.Linear(multilabel_resnet34.fc.in_features, len(CLASSES)) multilabel_resnet34 = load_dataparallel_model( multilabel_resnet34, torch.load(model_multilabel_path)) multilabel_resnet34 = torch.nn.DataParallel(multilabel_resnet34, device_ids=range( torch.cuda.device_count())) multilabel_resnet34.to(DEVICE) multilabel_resnet34.eval() multilabel_resnet34 = tta.ClassificationTTAWrapper(multilabel_resnet34, tta.Compose([ tta.HorizontalFlip(), tta.VerticalFlip(), tta.Scale([0.85, 1.15]), ]), merge_mode="mean") # #### Inference DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") class_th = torch.tensor([0.0005, 0.2, 0.005, 0.02, 0.005, 0.005, 0.008, 0.005]).to(DEVICE) total_preds, total_paths = [], [] @toma.batch(initial_batchsize=batch_size_multilabel) def run_multilabel(batch_size): inference_loader = DataLoader(inference_dataset, batch_size=batch_size,
def testing(num_split, class_params, encoder, decoder):
    """Run test-set inference (with ttach TTA) and write a submission CSV.

    Args:
        num_split: Fold whose checkpoint ``./logs/log_{encoder}_{decoder}/log_{num_split}``
            is loaded. It also names the output file.
        class_params: Per-class ``(threshold, min_size)`` pairs passed to ``post_process``.
        encoder: segmentation_models_pytorch encoder name.
        decoder: 'unet' builds ``smp.Unet``; any other value builds ``smp.FPN``.

    Output:
        Writes ``./sub/{encoder}_{decoder}/tta_submission_{num_split}.csv`` with
        columns Image_Label and EncodedPixels. The directory is created if missing.
    """
    import gc
    # Collect unreachable Python objects first so the CUDA caching allocator can
    # actually release the memory they held.
    gc.collect()
    torch.cuda.empty_cache()

    # Pass the path directly so pandas opens and closes the file itself.
    sub = pd.read_csv("./data/Clouds_Classify/sample_submission.csv")
    sub['label'] = sub['Image_Label'].apply(lambda x: x.split('_')[1])
    sub['im_id'] = sub['Image_Label'].apply(lambda x: x.split('_')[0])

    preprocessing_fn = smp.encoders.get_preprocessing_fn(encoder, 'imagenet')
    model_cls = smp.Unet if decoder == 'unet' else smp.FPN
    model = model_cls(
        encoder_name=encoder,
        encoder_weights='imagenet',
        classes=4,
        activation=None,
    )

    # encoded_pixels is filled in loader order and later assigned positionally to
    # sub's rows, so images must be fed in submission order. os.listdir() returns
    # files in arbitrary order and may include files the submission lacks.
    test_ids = sub['im_id'].drop_duplicates().values
    test_dataset = CloudDataset(
        df=sub,
        transforms=get_validation_augmentation(),
        datatype='test',
        img_ids=test_ids,
        preprocessing=get_preprocessing(preprocessing_fn))
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=2)

    logdir = "./logs/log_{}_{}/log_{}".format(encoder, decoder, num_split)
    checkpoint_path = logdir + '/checkpoints/best.pth'
    model.load_state_dict(torch.load(checkpoint_path)['model_state_dict'])

    # Predict with pytorch TTA (ttach).
    use_TTA = True
    if use_TTA:
        transforms = tta.Compose([
            tta.HorizontalFlip(),
            tta.VerticalFlip(),
            tta.Scale(scales=[5 / 6, 1, 7 / 6]),
        ])
        tta_model = tta.SegmentationTTAWrapper(model, transforms, merge_mode='mean')
    else:
        tta_model = model
    tta_model = tta_model.cuda()
    tta_model.eval()

    runner_out = []
    with torch.no_grad():
        for img, _ in tqdm.tqdm(test_loader):
            batch_preds = tta_model(img.cuda()).cpu().numpy()
            runner_out.extend(batch_preds)
    runner_out = np.array(runner_out)

    encoded_pixels = []
    for output in tqdm.tqdm(runner_out):
        # One logit map per class channel; channel j pairs with class_params[j].
        for j, probability in enumerate(output):
            if probability.shape != (350, 525):
                probability = cv2.resize(probability, dsize=(525, 350),
                                         interpolation=cv2.INTER_LINEAR)
            # The model is built with activation=None, so outputs are logits.
            prob = sigmoid(probability)
            predict, num_predict = post_process(prob, class_params[j][0], class_params[j][1])
            encoded_pixels.append('' if num_predict == 0 else mask2rle(predict))

    sub['EncodedPixels'] = encoded_pixels
    out_dir = './sub/{}_{}'.format(encoder, decoder)
    os.makedirs(out_dir, exist_ok=True)
    sub.to_csv('{}/tta_submission_{}.csv'.format(out_dir, num_split),
               columns=['Image_Label', 'EncodedPixels'],
               index=False)
def validation(valid_ids, num_split, encoder, decoder):
    """Validate the model on one fold and grid-search post-processing parameters.

    Steps:
      1. The fold's best checkpoint (wrapped with ttach TTA) predicts every image
         in ``valid_ids``.
      2. Predictions and ground-truth masks are resized to 350x525.
      3. For each of the 4 classes, every (threshold, min size) pair in that
         class's range is scored by the mean per-image Dice. An image where both
         prediction and ground truth are empty scores 1.
      4. Scores are written best-first to
         ``./params/{encoder}_{decoder}_par/params_{num_split}/tta_params_{class_id}.csv``.

    Args:
        valid_ids: Image ids of the validation fold.
        num_split: Fold index used for the checkpoint and output directories.
        encoder: segmentation_models_pytorch encoder name.
        decoder: 'unet' builds ``smp.Unet``; any other value builds ``smp.FPN``.
    """
    # Data overview. Pass the path so pandas closes the file itself.
    train = pd.read_csv("./data/Clouds_Classify/train.csv")
    train['label'] = train['Image_Label'].apply(lambda x: x.split('_')[1])
    train['im_id'] = train['Image_Label'].apply(lambda x: x.split('_')[0])

    ENCODER = encoder
    ENCODER_WEIGHTS = 'imagenet'
    model_cls = smp.Unet if decoder == 'unet' else smp.FPN
    model = model_cls(
        encoder_name=ENCODER,
        encoder_weights=ENCODER_WEIGHTS,
        classes=4,
        activation=None,
    )
    preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)

    num_workers = 4
    valid_bs = 32
    valid_dataset = CloudDataset(
        df=train,
        transforms=get_validation_augmentation(),
        datatype='valid',
        img_ids=valid_ids,
        preprocessing=get_preprocessing(preprocessing_fn))
    valid_loader = DataLoader(valid_dataset, batch_size=valid_bs,
                              shuffle=False, num_workers=num_workers)

    logdir = "./logs/log_{}_{}/log_{}".format(encoder, decoder, num_split)
    checkpoint_path = logdir + '/checkpoints/best.pth'
    model.load_state_dict(torch.load(checkpoint_path)['model_state_dict'])

    ############### TTA prediction ####################
    use_TTA = True
    if use_TTA:
        transforms = tta.Compose([
            tta.HorizontalFlip(),
            tta.VerticalFlip(),
            tta.Scale(scales=[5 / 6, 1, 7 / 6]),
        ])
        tta_model = tta.SegmentationTTAWrapper(model, transforms, merge_mode='mean')
    else:
        tta_model = model
    tta_model = tta_model.cuda()
    tta_model.eval()

    runner_out = []
    with torch.no_grad():
        for img, _ in tqdm.tqdm(valid_loader):
            batch_preds = tta_model(img.cuda()).cpu().numpy()
            runner_out.extend(batch_preds)
    runner_out = np.array(runner_out)
    ######################END##########################

    # Bring masks and logits to 350x525. Row i * 4 + j of `probabilities` holds
    # class j of image i; valid_masks uses the same interleaving.
    valid_masks = []
    probabilities = np.zeros((len(valid_ids) * 4, 350, 525))
    for i, ((_, mask), output) in enumerate(tqdm.tqdm(zip(valid_dataset, runner_out))):
        for m in mask:
            if m.shape != (350, 525):
                m = cv2.resize(m, dsize=(525, 350), interpolation=cv2.INTER_LINEAR)
            valid_masks.append(m)

        for j, probability in enumerate(output):
            if probability.shape != (350, 525):
                probability = cv2.resize(probability, dsize=(525, 350),
                                         interpolation=cv2.INTER_LINEAR)
            probabilities[i * 4 + j, :, :] = probability

    # Find optimal values.
    print('searching for optimal param...')
    # Per class: [[threshold start, stop) in percent, [min size start, stop)].
    params_0 = [[35, 76], [12000, 19001]]
    params_1 = [[35, 76], [12000, 19001]]
    params_2 = [[35, 76], [12000, 19001]]
    params_3 = [[35, 76], [8000, 15001]]
    param = [params_0, params_1, params_2, params_3]

    # Create the output directory before the (long) search, not after it.
    params_dir = './params/{}_{}_par/params_{}'.format(encoder, decoder, num_split)
    os.makedirs(params_dir, exist_ok=True)

    for class_id in range(4):
        par = param[class_id]
        class_masks = valid_masks[class_id::4]
        attempts = []
        for t in range(par[0][0], par[0][1], 5):
            t /= 100
            for ms in range(par[1][0], par[1][1], 2000):
                masks = []
                print('==> searching [class_id:%d threshold:%.3f ms:%d]' % (class_id, t, ms))
                for i in tqdm.tqdm(range(class_id, len(probabilities), 4)):
                    predict, _ = post_process(sigmoid(probabilities[i]), t, ms)
                    masks.append(predict)

                d = []
                for pred_mask, true_mask in zip(masks, class_masks):
                    # Both empty is counted as a perfect match.
                    if pred_mask.sum() == 0 and true_mask.sum() == 0:
                        d.append(1)
                    else:
                        d.append(dice(pred_mask, true_mask))

                attempts.append((t, ms, np.mean(d)))

        attempts_df = pd.DataFrame(attempts, columns=['threshold', 'size', 'dice'])
        attempts_df = attempts_df.sort_values('dice', ascending=False)
        attempts_df.to_csv('{}/tta_params_{}.csv'.format(params_dir, class_id),
                           columns=['threshold', 'size', 'dice'],
                           index=False)
import pytest import torch import ttach as tta @pytest.mark.parametrize( "transform", [ tta.HorizontalFlip(), tta.VerticalFlip(), tta.Rotate90(angles=[0, 90, 180, 270]), tta.Scale(scales=[1, 2, 4], interpolation="nearest"), tta.Resize(sizes=[(4, 5), (8, 10)], original_size=(4, 5), interpolation="nearest") ], ) def test_aug_deaug_mask(transform): a = torch.arange(20).reshape(1, 1, 4, 5).float() for p in transform.params: aug = transform.apply_aug_image(a, **{transform.pname: p}) deaug = transform.apply_deaug_mask(aug, **{transform.pname: p}) assert torch.allclose(a, deaug) @pytest.mark.parametrize( "transform", [ tta.HorizontalFlip(), tta.VerticalFlip(), tta.Rotate90(angles=[0, 90, 180, 270]), tta.Scale(scales=[1, 2, 4], interpolation="nearest"), tta.Add(values=[-1, 0, 1, 2]),
def get_val_logits(valid_ids, num_split, encoder, decoder):
    """Predict one validation fold (with ttach TTA) and pickle the per-image logits.

    Each image's output (cast to float16, batch of one) is pickled to
    ``./logits/valid/{encoder}_{decoder}/log_{num_split}/<image stem>.plk``.

    Args:
        valid_ids: Image ids of the validation fold.
        num_split: Fold index used for the checkpoint and output directories.
        encoder: segmentation_models_pytorch encoder name.
        decoder: 'unet' builds ``smp.Unet``; any other value builds ``smp.FPN``.
    """
    # Pass the path so pandas closes the file itself.
    train = pd.read_csv("./data/Clouds_Classify/train.csv")
    train['label'] = train['Image_Label'].apply(lambda x: x.split('_')[1])
    train['im_id'] = train['Image_Label'].apply(lambda x: x.split('_')[0])

    ENCODER = encoder
    ENCODER_WEIGHTS = 'imagenet'
    # Build the model.
    model_cls = smp.Unet if decoder == 'unet' else smp.FPN
    model = model_cls(
        encoder_name=ENCODER,
        encoder_weights=ENCODER_WEIGHTS,
        classes=4,
        activation=None,
    )
    preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)

    num_workers = 4
    # Batch size 1: each batch is pickled under its own image name below.
    valid_bs = 1
    valid_dataset = CloudDataset(
        df=train,
        transforms=get_validation_augmentation(),
        datatype='valid',
        img_ids=valid_ids,
        preprocessing=get_preprocessing(preprocessing_fn))
    valid_loader = DataLoader(valid_dataset, batch_size=valid_bs,
                              shuffle=False, num_workers=num_workers)
    logdir = "./logs/log_{}_{}/log_{}".format(encoder, decoder, num_split)

    print('predicting for validation data...')
    checkpoint_path = logdir + '/checkpoints/best.pth'
    model.load_state_dict(torch.load(checkpoint_path)['model_state_dict'])

    # Predict with pytorch TTA (ttach).
    use_TTA = True
    if use_TTA:
        transforms = tta.Compose([
            tta.HorizontalFlip(),
            tta.VerticalFlip(),
            tta.Scale(scales=[5 / 6, 1, 7 / 6]),
        ])
        tta_model = tta.SegmentationTTAWrapper(model, transforms, merge_mode='mean')
    else:
        tta_model = model
    tta_model = tta_model.cuda()
    tta_model.eval()

    # Created once, instead of an exists()/makedirs() check per image.
    save_dir = './logits/valid/' + encoder + '_' + decoder + '/log_{}'.format(num_split)
    os.makedirs(save_dir, exist_ok=True)

    with torch.no_grad():
        for img, _, img_name in tqdm.tqdm(valid_loader):
            batch_preds = tta_model(img.cuda()).cpu().numpy().astype(np.float16)
            file_name = img_name[0].split('.')[0] + '.plk'
            with open(os.path.join(save_dir, file_name), 'wb') as wf:
                plk.dump(batch_preds, wf)
def get_test_logits(encoder, decoder, num_split=4):
    """Predict the test set (with ttach TTA) and pickle the per-image logits.

    Args:
        encoder: segmentation_models_pytorch encoder name; initial weights are 'imagenet'.
        decoder: 'unet' builds ``smp.Unet``; any other value builds ``smp.FPN``.
            It also selects the checkpoint directory ``./logs/log_{encoder}_{decoder}``
            and the output directory ``./logits/test/{encoder}_{decoder}``.
            Previously both paths hard-coded 'fpn', so ``decoder='unet'`` loaded
            an FPN checkpoint into a Unet.
        num_split: Fold whose checkpoint is loaded (``.../log_{num_split}``).
            Defaults to 4, the fold that was previously hard-coded.

    Each image's output (float16, batch of one) is pickled to
    ``./logits/test/{encoder}_{decoder}/log_{num_split}/<image stem>.plk``.
    """
    # Pass the path so pandas closes the file itself.
    sub = pd.read_csv("./data/Clouds_Classify/sample_submission.csv")
    sub['label'] = sub['Image_Label'].apply(lambda x: x.split('_')[1])
    sub['im_id'] = sub['Image_Label'].apply(lambda x: x.split('_')[0])

    # Build the model.
    preprocessing_fn = smp.encoders.get_preprocessing_fn(encoder, 'imagenet')
    model_cls = smp.Unet if decoder == 'unet' else smp.FPN
    model = model_cls(
        encoder_name=encoder,
        encoder_weights='imagenet',
        classes=4,
        activation=None,
    )

    # Load the trained weights that match this encoder / decoder / fold.
    logdir = "./logs/log_{}_{}/log_{}".format(encoder, decoder, num_split)
    checkpoint_path = logdir + '/checkpoints/best.pth'
    model.load_state_dict(torch.load(checkpoint_path)['model_state_dict'])

    # Predict with TTA.
    use_TTA = True
    if use_TTA:
        print('using TTA!!!')
        transforms = tta.Compose([
            tta.HorizontalFlip(),
            tta.VerticalFlip(),
            tta.Scale(scales=[5 / 6, 1, 7 / 6]),
        ])
        tta_model = tta.SegmentationTTAWrapper(model, transforms, merge_mode='mean')
    else:
        tta_model = model

    # Order does not matter here: each output is saved under its own file name.
    test_ids = os.listdir(test_imgs_folder)
    test_dataset = CloudDataset(
        df=sub,
        transforms=get_validation_augmentation(),
        datatype='test',
        img_ids=test_ids,
        preprocessing=get_preprocessing(preprocessing_fn))
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=2)

    tta_model = tta_model.cuda()
    tta_model.eval()

    # Created once, instead of an exists()/makedirs() check per image.
    save_dir = './logits/test/' + encoder + '_' + decoder + '/log_{}'.format(num_split)
    os.makedirs(save_dir, exist_ok=True)

    with torch.no_grad():
        for img, _, img_name in tqdm.tqdm(test_loader):
            batch_preds = tta_model(img.cuda()).cpu().numpy().astype(np.float16)
            # Save as pickle; batch_size=1, so img_name holds exactly one name.
            file_name = img_name[0].split('.')[0] + '.plk'
            with open(os.path.join(save_dir, file_name), 'wb') as wf:
                plk.dump(batch_preds, wf)
def predict_single_image(self, image: np.array, rgb=True, preprocessing=True, tta_aug=None):
    """Predict one image, averaging the model output over test-time augmentations.

    Args:
        image: Raw image handed to ``segmentation.dataset.process``. It is also
            passed as a placeholder mask, whose processed result is discarded.
        rgb: Forwarded to ``process``.
        preprocessing: Forwarded to ``process`` as ``apply_preprocessing``. The
            encoder's preprocessing function is always supplied.
        tta_aug: Optional ``ttach`` transform collection. ``None`` selects scales
            0.95 / 1 / 1.05 combined with a horizontal flip.

    Returns:
        The TTA-averaged output as a channels-last numpy array with size-1 axes
        squeezed. Returns ``None`` (after logging a warning) when
        ``self.settings`` is not a ``PredictorSettings``.
    """
    from segmentation.dataset import process

    if not isinstance(self.settings, PredictorSettings):
        logger.warning(
            'Settings is of type: {}. Pass settings to network object of type Train to train'
            .format(str(type(self.settings))))
        return

    if tta_aug is not None:
        transforms = tta_aug
    else:
        import ttach as tta
        transforms = tta.Compose([
            tta.Scale(scales=[0.95, 1, 1.05]),
            tta.HorizontalFlip(),
        ])

    self.model.eval()
    preprocessing_fn = sm.encoders.get_preprocessing_fn(self.encoder)
    processed_image, _ = process(image=image,
                                 mask=image,
                                 rgb=rgb,
                                 preprocessing=preprocessing_fn,
                                 apply_preprocessing=preprocessing,
                                 augmentation=None,
                                 color_map=None,
                                 binary_augmentation=False)
    batch = processed_image.unsqueeze(0)  # prepend a batch axis

    with torch.no_grad():
        batch = batch.to(self.device)
        full_shape = batch.shape
        target_size = list(full_shape)[2:]
        per_transform = []
        for transformer in transforms:
            augmented = transformer.augment_image(batch)
            augmented_size = list(augmented.shape)[2:]
            # Presumably pads H and W up to a multiple of padding_value
            # (original note: 2**5). pad/unpad are not visible here.
            padded = pad(augmented, self.padding_value)
            prediction = unpad(self.model(padded.float()), augmented_size)
            restored = transformer.deaugment_mask(prediction)
            # Resize back to the input's spatial size so the masks can be stacked.
            restored = torch.nn.functional.interpolate(
                restored, size=target_size, mode="nearest")
            logger.debug(
                "original: {} input: {}, padded: {} unpadded {} output {} \n"
                .format(str(full_shape), str(augmented_size),
                        str(list(augmented.shape)), str(list(prediction.shape)),
                        str(list(restored.shape))))
            per_transform.append(restored)
        averaged = torch.stack(per_transform).mean(dim=0)

    channels_last = np.transpose(averaged.cpu().numpy(), (0, 2, 3, 1))
    return np.squeeze(channels_last)