from typing import Tuple

import numpy as np
import torch
import ttach as tta


def SUE_TTA(model, batch: torch.Tensor, last_layer: bool) -> Tuple[np.ndarray, np.ndarray]:
    r"""Interface of Binary Segmentation Uncertainty Estimation with Test-Time
    Augmentations (TTA) method for one 2D slice.

    Inputs are expected to be in range [0, data_range].

    Args:
        model: Trained model.
        batch: Tensor with shape (1, C, H, W).
        last_layer: Flag for whether a Sigmoid should be applied as the last NN layer.

    Returns:
        Aleatoric and epistemic uncertainty maps with shapes equal to the batch shape.
    """
    model.eval()
    transforms = tta.Compose(
        [
            tta.VerticalFlip(),
            tta.HorizontalFlip(),
            tta.Rotate90(angles=[0, 180]),
            tta.Scale(scales=[1, 2, 4]),
            tta.Multiply(factors=[0.9, 1, 1.1]),
        ]
    )
    predicted = []
    for transformer in transforms:
        # Augment the input, run the model, then undo the augmentation on the mask.
        augmented_image = transformer.augment_image(batch)
        model_output = model(augmented_image)
        deaug_mask = transformer.deaugment_mask(model_output)
        prediction = (
            torch.sigmoid(deaug_mask).cpu().detach().numpy()
            if last_layer
            else deaug_mask.cpu().detach().numpy()
        )
        predicted.append(prediction)

    # Stack the per-transform predictions and reduce them to uncertainty maps.
    # calc_aleatoric / calc_epistemic are helpers defined elsewhere in this module.
    p_hat = np.array(predicted)
    aleatoric = calc_aleatoric(p_hat)
    epistemic = calc_epistemic(p_hat)
    return aleatoric, epistemic
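
# Hedged usage sketch (not part of the original module): calling SUE_TTA with a tiny
# stand-in network and a random slice. The Conv2d stand-in, the input size, and
# last_layer=True are illustrative assumptions; calc_aleatoric / calc_epistemic from
# the same module must be importable for this to run.
import torch.nn as nn

stand_in_model = nn.Conv2d(3, 1, kernel_size=3, padding=1)  # emits raw logits of shape (1, 1, H, W)
slice_batch = torch.rand(1, 3, 64, 64)                      # one 2D slice in [0, 1]
# last_layer=True makes SUE_TTA apply the sigmoid to the de-augmented logits.
aleatoric_map, epistemic_map = SUE_TTA(stand_in_model, slice_batch, last_layer=True)
print(aleatoric_map.shape, epistemic_map.shape)
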
def forward_augmentation_smoothing(self,
                                   input_tensor: torch.Tensor,
                                   targets: List[torch.nn.Module],
                                   eigen_smooth: bool = False) -> np.ndarray:
    transforms = tta.Compose(
        [
            tta.HorizontalFlip(),
            tta.Multiply(factors=[0.9, 1, 1.1]),
        ]
    )
    cams = []
    for transform in transforms:
        augmented_tensor = transform.augment_image(input_tensor)
        cam = self.forward(augmented_tensor, targets, eigen_smooth)

        # The ttach library expects a tensor of size BxCxHxW
        cam = cam[:, None, :, :]
        cam = torch.from_numpy(cam)
        cam = transform.deaugment_mask(cam)

        # Back to numpy float32, HxW
        cam = cam.numpy()
        cam = cam[:, 0, :, :]
        cams.append(cam)

    cam = np.mean(np.float32(cams), axis=0)
    return cam
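
# Hedged sketch (not part of the original file): how this augmentation-smoothed path
# is typically reached through pytorch-grad-cam's public API, assuming the method above
# belongs to a BaseCAM-style class. The argument names (targets, aug_smooth) match
# recent library versions and may differ in older ones.
import torch
from torchvision.models import resnet18
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

model = resnet18().eval()
cam = GradCAM(model=model, target_layers=[model.layer4[-1]])

input_tensor = torch.rand(1, 3, 224, 224)   # a normalized image batch in practice
targets = [ClassifierOutputTarget(281)]     # class index whose activation map we want

# aug_smooth=True routes the call through forward_augmentation_smoothing above,
# averaging the CAMs produced under horizontal flips and small intensity multiplies.
grayscale_cam = cam(input_tensor=input_tensor, targets=targets, aug_smooth=True)
print(grayscale_cam.shape)                  # (1, 224, 224)
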
def __init__(self):
    super(Net, self).__init__()
    self.transforms = ttach.Compose([
        ttach.HorizontalFlip(),
        # ttach.Scale(scales=[1, 1.05], interpolation="linear"),
        ttach.Multiply(factors=[0.95, 1, 1.05]),
    ])
    self.model = ttach.ClassificationTTAWrapper(
        InnerNet(), transforms=self.transforms, merge_mode="mean"
    )
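
# Hedged usage sketch (not from the original repo): what ClassificationTTAWrapper does
# with the transforms above at inference time. TinyClassifier is a hypothetical
# stand-in for InnerNet, which is defined elsewhere in that codebase.
import torch
import torch.nn as nn
import ttach

class TinyClassifier(nn.Module):
    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d(1), nn.Flatten(),
        )
        self.head = nn.Linear(8, num_classes)

    def forward(self, x):
        return self.head(self.features(x))

transforms = ttach.Compose([
    ttach.HorizontalFlip(),
    ttach.Multiply(factors=[0.95, 1, 1.05]),
])
wrapped = ttach.ClassificationTTAWrapper(TinyClassifier(), transforms, merge_mode="mean")

with torch.no_grad():
    logits = wrapped(torch.rand(4, 3, 32, 32))  # one forward pass per transform, outputs averaged
print(logits.shape)                             # torch.Size([4, 10])
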
face_model4 = get_face_model(coeff=1)
text_model4 = get_text_model(nhead=8, num_layers=12)
model4 = multimodal_model(speech_model4, face_model4, text_model4)
model4.load_state_dict(torch.load(weight_path4))
model4.cuda()
model4 = nn.DataParallel(model4)
model4.eval()

Emotions = []
Emotions3 = []
temperature = 0.1  # 1 for arithmetic mean, >1 for temperature sharpening, <1 for temperature smoothing

transforms = tta.Compose([
    tta.HorizontalFlip(),
    tta.VerticalFlip(),
    # tta.Multiply(factors=[0.8, 0.9, 1, 1.1, 1.2]),
    tta.Multiply(factors=[0.9, 1, 1.1]),
])

with torch.no_grad():
    for idx, data in enumerate(tqdm(test_loader)):
        speech = data['speech'].cuda()
        text = data['text'].cuda()
        inte_pred = None
        for transformer in transforms:
            # ttach transforms operate on BxCxHxW tensors, so the 5D face clip
            # (B, C, T, H, W) is augmented one frame at a time.
            face = torch.zeros_like(data['face'])
            for i in range(face.shape[2]):
                face[:, :, i, :, :] = transformer.augment_image(
                    data['face'][:, :, i, :, :])
            face = face.cuda()
            _inte_pred, _, _, _ = model1(speech, face, text)

def test_multiply_transform():
    transform = tta.Multiply(factors=[-1, 0, 1])
    a = torch.arange(20).reshape(1, 1, 4, 5).float()
    for p in transform.params:
        aug = transform.apply_aug_image(a, **{transform.pname: p})
        assert torch.allclose(aug, a * p)

    a = torch.arange(20).reshape(1, 1, 4, 5).float()
    for p in transform.params:
        aug = transform.apply_aug_image(a, **{transform.pname: p})
        deaug = transform.apply_deaug_mask(aug, **{transform.pname: p})
        assert torch.allclose(a, deaug)


@pytest.mark.parametrize(
    "transform",
    [
        tta.HorizontalFlip(),
        tta.VerticalFlip(),
        tta.Rotate90(angles=[0, 90, 180, 270]),
        tta.Scale(scales=[1, 2, 4], interpolation="nearest"),
        tta.Add(values=[-1, 0, 1, 2]),
        tta.Multiply(factors=[-1, 0, 1, 2]),
        tta.FiveCrops(crop_height=3, crop_width=5),
        tta.Resize(sizes=[(4, 5), (8, 10), (2, 2)], interpolation="nearest"),
    ],
)
def test_label_is_same(transform):
    a = torch.arange(20).reshape(1, 1, 4, 5).float()
    for p in transform.params:
        aug = transform.apply_aug_image(a, **{transform.pname: p})
        deaug = transform.apply_deaug_label(aug, **{transform.pname: p})
        assert torch.allclose(aug, deaug)


def test_add_transform():
    transform = tta.Add(values=[-1, 0, 1])
    a = torch.arange(20).reshape(1, 1, 4, 5).float()
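
# The aug/deaug primitives exercised by these tests are what ttach's high-level
# wrappers drive internally. A minimal sketch, assuming a toy Conv2d stands in
# for a real segmentation model:
import torch
import torch.nn as nn
import ttach as tta

seg_model = nn.Conv2d(3, 1, kernel_size=3, padding=1)     # stand-in segmentation net
tta_model = tta.SegmentationTTAWrapper(
    seg_model,
    tta.Compose([tta.HorizontalFlip(), tta.Multiply(factors=[0.9, 1, 1.1])]),
    merge_mode="mean",
)

x = torch.rand(2, 3, 16, 16)
with torch.no_grad():
    mask = tta_model(x)        # augment -> forward -> deaugment mask -> merge
print(mask.shape)              # torch.Size([2, 1, 16, 16])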