Esempio n. 1
0
def SUE_TTA(model, batch: torch.Tensor, last_layer: bool) -> Tuple[np.ndarray, np.ndarray]:
    r"""Estimate binary-segmentation uncertainty for one 2D slice with Test-Time Augmentation (TTA).

    The batch is run through ``model`` once for every combination of the
    flips, rotations, scales and intensity factors listed below. Each output
    is mapped back to the input geometry, and the stacked predictions are
    handed to ``calc_aleatoric`` and ``calc_epistemic``.

    Inputs are supposed to be in range [0, data_range].

    Args:
        model: Trained model. It is switched to eval mode as a side effect.
        batch: Tensor with shape (1, C, H, W).
        last_layer: When true, ``torch.sigmoid`` is applied to every
            de-augmented model output before it is collected; when false the
            output is used unchanged.
            NOTE(review): this flag was previously documented as "whether
            there is Sigmoid as a last NN layer", which reads as the opposite
            of what the code does -- confirm the intended meaning with callers.

    Returns:
        Tuple ``(aleatoric, epistemic)`` as produced by ``calc_aleatoric`` and
        ``calc_epistemic`` (documented upstream as having the batch shape --
        not verifiable from this function alone).
    """
    model.eval()
    transforms = tta.Compose(
        [
            tta.VerticalFlip(),
            tta.HorizontalFlip(),
            tta.Rotate90(angles=[0, 180]),
            tta.Scale(scales=[1, 2, 4]),
            tta.Multiply(factors=[0.9, 1, 1.1]),
        ]
    )
    predicted = []
    # Pure inference: only NumPy copies of the outputs are kept, so there is
    # no reason to build an autograd graph for every augmented forward pass.
    with torch.no_grad():
        for transformer in transforms:
            augmented_image = transformer.augment_image(batch)
            deaug_mask = transformer.deaugment_mask(model(augmented_image))
            if last_layer:
                deaug_mask = torch.sigmoid(deaug_mask)
            predicted.append(deaug_mask.detach().cpu().numpy())

    # Leading axis indexes the augmentation that produced each prediction.
    p_hat = np.array(predicted)
    aleatoric = calc_aleatoric(p_hat)
    epistemic = calc_epistemic(p_hat)

    return aleatoric, epistemic
Esempio n. 2
0
    def predict(self, tta_aug=None, debug=None):
        """Yield one TTA-averaged prediction per sample of the prediction dataset.

        Every sample is augmented with each transformer in ``tta_aug``, padded,
        run through ``self.model``, un-padded, de-augmented and resized back
        to the sample's spatial size. The per-augmentation outputs are then
        averaged.

        Args:
            tta_aug: Iterable of ttach transformers. Defaults to a compose of
                ``Scale([0.95, 1, 1.05])`` and ``HorizontalFlip``.
            debug: Currently unused; kept for interface compatibility.

        Yields:
            NumPy array of the averaged model output, transposed from
            (N, C, H, W) to (N, H, W, C) and squeezed.

        Note:
            This is a generator: if ``self.settings`` is not a
            ``PredictorSettings`` a warning is logged on first iteration and
            nothing is yielded.
        """
        transforms = tta_aug
        if transforms is None:
            import ttach as tta
            transforms = tta.Compose([
                tta.Scale(scales=[0.95, 1, 1.05]),
                tta.HorizontalFlip(),
            ])
        # Aliased so the per-sample loop below cannot shadow the module.
        from torch.utils import data as torch_data

        self.model.eval()

        if not isinstance(self.settings, PredictorSettings):
            logger.warning(
                'Settings is of type: {}. Pass settings to network object of type Train to train'
                .format(str(type(self.settings))))
            return
        predict_loader = torch_data.DataLoader(
            dataset=self.settings.PREDICT_DATASET,
            batch_size=1,
            shuffle=False,
            num_workers=self.settings.PROCESSES)
        with torch.no_grad():
            # The dataset yields (image, target, id); only the image is needed
            # for prediction, so the other two are neither moved nor used.
            for batch, _target, _sample_id in predict_loader:
                batch = batch.to(self.device)
                original_shape = batch.shape
                original_size = list(original_shape)[2:]
                outputs = []
                for transformer in transforms:
                    augmented_image = transformer.augment_image(batch)
                    augmented_size = list(augmented_image.shape)[2:]
                    padded = pad(augmented_image, self.padding_value)  ## 2**5

                    output = self.model(padded.float())
                    output = unpad(output, augmented_size)
                    restored = transformer.deaugment_mask(output)
                    # Scale augmentations may not round-trip to the exact
                    # input size, so force the original spatial size.
                    restored = torch.nn.functional.interpolate(
                        restored, size=original_size, mode="nearest")
                    # Logged at debug level (as predict_single_image does)
                    # instead of printing for every transform of every sample.
                    logger.debug(
                        "original: %s input: %s, padded: %s unpadded %s output %s",
                        original_shape, augmented_size,
                        list(padded.shape), list(output.shape),
                        list(restored.shape))
                    outputs.append(restored)
                averaged = torch.mean(torch.stack(outputs), dim=0)
                out = averaged.data.cpu().numpy()
                out = np.transpose(out, (0, 2, 3, 1))
                out = np.squeeze(out)
                yield out
def tta_model_predict(X, model):
    """Average the de-augmented ``"out"`` outputs of ``model`` over TTA variants of ``X``.

    The variants are every combination of a horizontal flip and the scales
    0.5, 1 and 2. ``model`` must return a mapping with an ``"out"`` entry.
    """
    tta_transforms = tta.Compose(
        [tta.HorizontalFlip(),
         tta.Scale(scales=[0.5, 1, 2])])

    def _restored_output(transformer):
        # Augment, predict, then undo the augmentation on the prediction.
        raw = model(transformer.augment_image(X))["out"]
        return transformer.deaugment_mask(raw)

    restored = [_restored_output(transformer) for transformer in tta_transforms]
    return torch.sum(torch.stack(restored), dim=0) / len(restored)
Esempio n. 4
0
def test_compose_1():
    """Compose enumerates every parameter combination and each one round-trips."""
    pipeline = tta.Compose([
        tta.HorizontalFlip(),
        tta.VerticalFlip(),
        tta.Rotate90(angles=[0, 90, 180, 270]),
        tta.Scale(scales=[1, 2, 4], interpolation="nearest"),
    ])

    # all combinations for aug parameters
    expected_count = 2 * 2 * 4 * 3
    assert len(pipeline) == expected_count

    label = torch.ones(2).reshape(2, 1).float()
    image = torch.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5).float()

    def identity_model(x):
        # Echoes the input as the "mask" so de-augmentation must recover it.
        return {"label": label, "mask": x}

    for augmenter in pipeline:
        output = identity_model(augmenter.augment_image(image))
        restored_mask = augmenter.deaugment_mask(output["mask"])
        restored_label = augmenter.deaugment_label(output["label"])
        assert torch.allclose(restored_mask, image)
        assert torch.allclose(restored_label, label)
Esempio n. 5
0
def test(model, data_loader, save_path=""):
    """
    Evaluate ``model`` on ``data_loader`` and report per-video J and F scores.

    For convenience, validation and testing during training compute the J and
    F metrics directly instead of first generating and exporting predictions,
    so the numbers here are only a relative reference; the real metrics must
    be produced with the dedicated test code.

    Args:
        model: Network to evaluate; switched to eval mode as a side effect.
        data_loader: Yields dicts with ``"image"``, ``"flow"`` and
            ``"mask_path"`` entries.
        save_path: If non-empty, binarized predictions are written to
            ``save_path/<video_name>/<mask_name>``.

    Returns:
        The result of ``pretty_print`` on a mapping of video name (plus
        ``'average'``) to the mean ``[J, F]`` pair.
    """
    model.eval()
    tqdm_iter = tqdm(enumerate(data_loader),
                     total=len(data_loader),
                     leave=False)

    if arg_config['use_tta']:
        construct_print("We will use Test Time Augmentation!")
        transforms = tta.Compose([  # 2*3
            tta.HorizontalFlip(),
            tta.Scale(scales=[0.75, 1, 1.5],
                      interpolation='bilinear',
                      align_corners=False)
        ])
    else:
        transforms = None

    results = defaultdict(list)
    for test_batch_id, test_data in tqdm_iter:
        tqdm_iter.set_description(f"te=>{test_batch_id + 1}")

        with torch.no_grad():
            curr_jpegs = test_data["image"].to(DEVICES, non_blocking=True)
            curr_flows = test_data["flow"].to(DEVICES, non_blocking=True)
            preds_logits = tta_aug(model=model,
                                   transforms=transforms,
                                   data=dict(curr_jpeg=curr_jpegs,
                                             curr_flow=curr_flows))
            preds_prob = preds_logits.sigmoid().squeeze().cpu().detach(
            )  # float32
            if preds_prob.dim() == 2:
                # A single-sample batch loses its batch axis in squeeze();
                # restore it so the loop below walks samples, not image rows.
                preds_prob = preds_prob.unsqueeze(0)

        for i, pred_prob in enumerate(preds_prob.numpy()):
            curr_mask_path = test_data["mask_path"][i]
            video_name, mask_name = curr_mask_path.split(os.sep)[-2:]
            mask = read_binary_array(curr_mask_path, thr=0)
            mask_h, mask_w = mask.shape

            # Score at the ground-truth resolution.
            pred_prob = cv2.resize(pred_prob,
                                   dsize=(mask_w, mask_h),
                                   interpolation=cv2.INTER_LINEAR)
            pred_prob = clip_to_normalize(data_array=pred_prob,
                                          clip_range=arg_config["clip_range"])
            pred_seg = np.where(pred_prob > 0.5, 255, 0).astype(np.uint8)

            results[video_name].append(
                (jaccard.db_eval_iou(annotation=mask, segmentation=pred_seg),
                 f_boundary.db_eval_boundary(annotation=mask,
                                             segmentation=pred_seg)))

            if save_path:
                pred_video_path = os.path.join(save_path, video_name)
                os.makedirs(pred_video_path, exist_ok=True)
                pred_frame_path = os.path.join(pred_video_path, mask_name)
                cv2.imwrite(pred_frame_path, pred_seg)

    # Collapse the per-frame (J, F) pairs to one mean pair per video, then
    # average those per-video means.
    j_f_collection = []
    for video_name, video_scores in results.items():
        j_f_for_video = np.mean(np.array(video_scores), axis=0).tolist()
        results[video_name] = j_f_for_video
        j_f_collection.append(j_f_for_video)
    results['average'] = np.mean(np.array(j_f_collection), axis=0).tolist()
    return pretty_print(results)
Esempio n. 6
0
# Build a ResNet-34 whose final layer has one output per entry of CLASSES,
# then load the saved weights through load_dataparallel_model.
multilabel_resnet34 = models.resnet34(pretrained=False)
multilabel_resnet34.fc = torch.nn.Linear(multilabel_resnet34.fc.in_features,
                                         len(CLASSES))
multilabel_resnet34 = load_dataparallel_model(
    multilabel_resnet34, torch.load(model_multilabel_path))
# Replicate across every CUDA device torch reports.
multilabel_resnet34 = torch.nn.DataParallel(multilabel_resnet34,
                                            device_ids=range(
                                                torch.cuda.device_count()))
# NOTE(review): DEVICE is used here but the only assignment visible in this
# excerpt comes further down -- confirm it is defined earlier in the file.
multilabel_resnet34.to(DEVICE)
multilabel_resnet34.eval()

# Wrap the model so each forward pass averages ("mean") the predictions over
# all combinations of horizontal flip, vertical flip and the two scales.
multilabel_resnet34 = tta.ClassificationTTAWrapper(multilabel_resnet34,
                                                   tta.Compose([
                                                       tta.HorizontalFlip(),
                                                       tta.VerticalFlip(),
                                                       tta.Scale([0.85, 1.15]),
                                                   ]),
                                                   merge_mode="mean")

# #### Inference
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Eight values, presumably one decision threshold per class -- TODO confirm
# that the order matches CLASSES.
class_th = torch.tensor([0.0005, 0.2, 0.005, 0.02, 0.005, 0.005, 0.008,
                         0.005]).to(DEVICE)

# Accumulators, presumably filled by the inference routine that follows.
total_preds, total_paths = [], []


@toma.batch(initial_batchsize=batch_size_multilabel)
def run_multilabel(batch_size):
    inference_loader = DataLoader(inference_dataset,
                                  batch_size=batch_size,
Esempio n. 7
0
def testing(num_split, class_params, encoder, decoder):
    """
    Run TTA inference on the test images and write a submission CSV.

    Args:
        num_split: Fold index; selects the checkpoint directory and is part of
            the output file name.
        class_params: Indexable by class id; items ``[0]`` and ``[1]`` are
            forwarded to ``post_process`` (presumably threshold and minimum
            component size -- confirm against ``post_process``).
        encoder: Encoder name for segmentation_models_pytorch.
        decoder: ``'unet'`` builds a Unet; any other value builds an FPN.

    Side effects:
        Writes ``./sub/<encoder>_<decoder>/tta_submission_<num_split>.csv``,
        creating the directory if needed.
    """
    import gc
    torch.cuda.empty_cache()
    gc.collect()

    sub_path = "./data/Clouds_Classify/sample_submission.csv"
    # Close the CSV handle deterministically instead of leaking it.
    with open(sub_path) as sub_file:
        sub = pd.read_csv(sub_file)

    sub['label'] = sub['Image_Label'].apply(lambda x: x.split('_')[1])
    sub['im_id'] = sub['Image_Label'].apply(lambda x: x.split('_')[0])

    preprocessing_fn = smp.encoders.get_preprocessing_fn(encoder, 'imagenet')
    if decoder == 'unet':
        model = smp.Unet(
            encoder_name=encoder,
            encoder_weights='imagenet',
            classes=4,
            activation=None,
        )
    else:
        model = smp.FPN(
            encoder_name=encoder,
            encoder_weights='imagenet',
            classes=4,
            activation=None,
        )
    # NOTE(review): predictions are later assigned to `sub` row by row, so the
    # directory-listing order must match the row order of the sample
    # submission unless CloudDataset reorders by `df` -- verify.
    test_ids = os.listdir(test_imgs_folder)

    test_dataset = CloudDataset(
        df=sub,
        transforms=get_validation_augmentation(),
        datatype='test',
        img_ids=test_ids,
        preprocessing=get_preprocessing(preprocessing_fn))
    test_loader = DataLoader(test_dataset,
                             batch_size=1,
                             shuffle=False,
                             num_workers=2)

    loaders = {"test": test_loader}
    logdir = "./logs/log_{}_{}/log_{}".format(encoder, decoder, num_split)

    encoded_pixels = []

    ############### Predict with PyTorch TTA ####################
    use_TTA = True
    checkpoint_path = logdir + '/checkpoints/best.pth'
    runner_out = []
    model.load_state_dict(torch.load(checkpoint_path)['model_state_dict'])
    # Predict with TTA
    if use_TTA:
        transforms = tta.Compose([
            tta.HorizontalFlip(),
            tta.VerticalFlip(),
            tta.Scale(scales=[5 / 6, 1, 7 / 6]),
        ])
        tta_model = tta.SegmentationTTAWrapper(model,
                                               transforms,
                                               merge_mode='mean')
    else:
        tta_model = model

    tta_model = tta_model.cuda()
    tta_model.eval()

    with torch.no_grad():
        for img, _ in tqdm.tqdm(loaders['test']):
            batch_preds = tta_model(img.cuda()).cpu().numpy()
            runner_out.extend(batch_preds)
    runner_out = np.array(runner_out)

    for output in tqdm.tqdm(runner_out):
        for class_id, class_logits in enumerate(output):
            if class_logits.shape != (350, 525):
                class_logits = cv2.resize(class_logits,
                                          dsize=(525, 350),
                                          interpolation=cv2.INTER_LINEAR)
            # The model has no activation, so squash the logits here.
            class_prob = sigmoid(class_logits)
            predict, num_predict = post_process(class_prob,
                                                class_params[class_id][0],
                                                class_params[class_id][1])

            # An empty string encodes "no mask" for this image/class pair.
            if num_predict == 0:
                encoded_pixels.append('')
            else:
                encoded_pixels.append(mask2rle(predict))

    sub['EncodedPixels'] = encoded_pixels
    submission_path = './sub/{}_{}/tta_submission_{}.csv'.format(
        encoder, decoder, num_split)
    # Create the target directory so the write cannot fail after inference.
    os.makedirs(os.path.dirname(submission_path), exist_ok=True)
    sub.to_csv(submission_path,
               columns=['Image_Label', 'EncodedPixels'],
               index=False)
Esempio n. 8
0
def validation(valid_ids, num_split, encoder, decoder):
    """
    Validate the model and grid-search the post-processing parameters.

    Predictions for ``valid_ids`` are produced with TTA. Then, for each of the
    4 classes, every (threshold, minimum size) pair in the search grid is
    scored by mean Dice against the ground-truth masks.

    Args:
        valid_ids: Image ids of the validation fold.
        num_split: Fold index; selects the checkpoint directory and the output
            directory.
        encoder: Encoder name for segmentation_models_pytorch.
        decoder: ``'unet'`` builds a Unet; any other value builds an FPN.

    Side effects:
        Writes one CSV per class, sorted by descending Dice, to
        ``./params/<encoder>_<decoder>_par/params_<num_split>/tta_params_<class_id>.csv``
        (the directory is created if needed).
    """
    import os

    train_csv = "./data/Clouds_Classify/train.csv"

    # Data overview
    # Close the CSV handle deterministically instead of leaking it.
    with open(train_csv) as train_file:
        train = pd.read_csv(train_file)

    train['label'] = train['Image_Label'].apply(lambda x: x.split('_')[1])
    train['im_id'] = train['Image_Label'].apply(lambda x: x.split('_')[0])

    ENCODER = encoder
    ENCODER_WEIGHTS = 'imagenet'
    if decoder == 'unet':
        model = smp.Unet(
            encoder_name=ENCODER,
            encoder_weights=ENCODER_WEIGHTS,
            classes=4,
            activation=None,
        )
    else:
        model = smp.FPN(
            encoder_name=ENCODER,
            encoder_weights=ENCODER_WEIGHTS,
            classes=4,
            activation=None,
        )
    preprocessing_fn = smp.encoders.get_preprocessing_fn(
        ENCODER, ENCODER_WEIGHTS)

    num_workers = 4
    valid_bs = 32
    valid_dataset = CloudDataset(
        df=train,
        transforms=get_validation_augmentation(),
        datatype='valid',
        img_ids=valid_ids,
        preprocessing=get_preprocessing(preprocessing_fn))
    valid_loader = DataLoader(valid_dataset,
                              batch_size=valid_bs,
                              shuffle=False,
                              num_workers=num_workers)

    loaders = {"valid": valid_loader}
    logdir = "./logs/log_{}_{}/log_{}".format(encoder, decoder, num_split)

    valid_masks = []
    # Four class maps per image, stored flat as image_index * 4 + class_id.
    probabilities = np.zeros((len(valid_ids) * 4, 350, 525))

    ############### TTA prediction ####################
    use_TTA = True
    checkpoint_path = logdir + '/checkpoints/best.pth'
    runner_out = []
    model.load_state_dict(torch.load(checkpoint_path)['model_state_dict'])

    if use_TTA:
        transforms = tta.Compose([
            tta.HorizontalFlip(),
            tta.VerticalFlip(),
            tta.Scale(scales=[5 / 6, 1, 7 / 6]),
        ])
        tta_model = tta.SegmentationTTAWrapper(model,
                                               transforms,
                                               merge_mode='mean')
    else:
        tta_model = model

    tta_model = tta_model.cuda()
    tta_model.eval()

    with torch.no_grad():
        for img, _ in tqdm.tqdm(loaders['valid']):
            batch_preds = tta_model(img.cuda()).cpu().numpy()
            runner_out.extend(batch_preds)
    runner_out = np.array(runner_out)
    ######################END##########################

    for image_idx, ((_, mask), output) in enumerate(
            tqdm.tqdm(zip(valid_dataset, runner_out))):
        for class_mask in mask:
            if class_mask.shape != (350, 525):
                class_mask = cv2.resize(class_mask,
                                        dsize=(525, 350),
                                        interpolation=cv2.INTER_LINEAR)
            valid_masks.append(class_mask)

        for class_idx, class_logits in enumerate(output):
            if class_logits.shape != (350, 525):
                class_logits = cv2.resize(class_logits,
                                          dsize=(525, 350),
                                          interpolation=cv2.INTER_LINEAR)
            probabilities[image_idx * 4 + class_idx, :, :] = class_logits

    # Find optimal values
    print('searching for optimal param...')
    # Per class: [threshold range in percent, minimum-size range], both used
    # as range(start, stop) bounds below.
    params_0 = [[35, 76], [12000, 19001]]
    params_1 = [[35, 76], [12000, 19001]]
    params_2 = [[35, 76], [12000, 19001]]
    params_3 = [[35, 76], [8000, 15001]]
    param = [params_0, params_1, params_2, params_3]

    for class_id in range(4):
        par = param[class_id]
        attempts = []
        for t_percent in range(par[0][0], par[0][1], 5):
            t = t_percent / 100
            for ms in range(par[1][0], par[1][1], 2000):
                masks = []
                print('==> searching [class_id:%d threshold:%.3f ms:%d]' %
                      (class_id, t, ms))
                for idx in tqdm.tqdm(range(class_id, len(probabilities), 4)):
                    predict, _ = post_process(sigmoid(probabilities[idx]),
                                              t, ms)
                    masks.append(predict)

                d = []
                for pred_mask, true_mask in zip(masks,
                                                valid_masks[class_id::4]):
                    # Two empty masks count as a perfect match.
                    if pred_mask.sum() == 0 and true_mask.sum() == 0:
                        d.append(1)
                    else:
                        d.append(dice(pred_mask, true_mask))

                attempts.append((t, ms, np.mean(d)))

        attempts_df = pd.DataFrame(attempts,
                                   columns=['threshold', 'size', 'dice'])

        attempts_df = attempts_df.sort_values('dice', ascending=False)
        params_path = './params/{}_{}_par/params_{}/tta_params_{}.csv'.format(
            encoder, decoder, num_split, class_id)
        # Create the target directory so the write cannot fail after the
        # whole search has run.
        os.makedirs(os.path.dirname(params_path), exist_ok=True)
        attempts_df.to_csv(params_path,
                           columns=['threshold', 'size', 'dice'],
                           index=False)
Esempio n. 9
0
import pytest
import torch
import ttach as tta


@pytest.mark.parametrize(
    "transform",
    [
        tta.HorizontalFlip(),
        tta.VerticalFlip(),
        tta.Rotate90(angles=[0, 90, 180, 270]),
        tta.Scale(scales=[1, 2, 4], interpolation="nearest"),
        tta.Resize(sizes=[(4, 5), (8, 10)], original_size=(4, 5), interpolation="nearest")
    ],
)
def test_aug_deaug_mask(transform):
    """Augmenting an image and de-augmenting it as a mask restores the input."""
    original = torch.arange(20).reshape(1, 1, 4, 5).float()
    for value in transform.params:
        # Each transform names its own parameter via ``pname``.
        kwargs = {transform.pname: value}
        augmented = transform.apply_aug_image(original, **kwargs)
        restored = transform.apply_deaug_mask(augmented, **kwargs)
        assert torch.allclose(original, restored)


@pytest.mark.parametrize(
    "transform",
    [
        tta.HorizontalFlip(),
        tta.VerticalFlip(),
        tta.Rotate90(angles=[0, 90, 180, 270]),
        tta.Scale(scales=[1, 2, 4], interpolation="nearest"),
        tta.Add(values=[-1, 0, 1, 2]),
def get_val_logits(valid_ids, num_split, encoder, decoder):
    """Predict the validation fold with TTA and pickle the logits per image.

    Args:
        valid_ids: Image ids of the validation fold.
        num_split: Fold index; selects the checkpoint directory and the output
            directory.
        encoder: Encoder name for segmentation_models_pytorch.
        decoder: ``'unet'`` builds a Unet; any other value builds an FPN.

    Side effects:
        Writes one float16 array per image to
        ``./logits/valid/<encoder>_<decoder>/log_<num_split>/<image stem>.plk``.
    """
    # valid
    train_csv = "./data/Clouds_Classify/train.csv"

    # Data overview
    # Close the CSV handle deterministically instead of leaking it.
    with open(train_csv) as train_file:
        train = pd.read_csv(train_file)

    train['label'] = train['Image_Label'].apply(lambda x: x.split('_')[1])
    train['im_id'] = train['Image_Label'].apply(lambda x: x.split('_')[0])

    ENCODER = encoder
    ENCODER_WEIGHTS = 'imagenet'

    if decoder == 'unet':
        # Build the model
        model = smp.Unet(
            encoder_name=ENCODER,
            encoder_weights=ENCODER_WEIGHTS,
            classes=4,
            activation=None,
        )
    else:
        model = smp.FPN(
            encoder_name=ENCODER,
            encoder_weights=ENCODER_WEIGHTS,
            classes=4,
            activation=None,
        )
    preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)

    num_workers = 4
    # Batch size 1: each batch is saved as one file named after img_name[0].
    valid_bs = 1
    valid_dataset = CloudDataset(df=train, transforms=get_validation_augmentation(), datatype='valid', img_ids=valid_ids, preprocessing=get_preprocessing(preprocessing_fn))
    valid_loader = DataLoader(valid_dataset, batch_size=valid_bs, shuffle=False, num_workers=num_workers)

    loaders = {"valid": valid_loader}
    logdir = "./logs/log_{}_{}/log_{}".format(encoder, decoder, num_split)

    print('predicting for validation data...')
    ############### Predict with PyTorch TTA ####################
    use_TTA = True
    checkpoint_path = logdir + '/checkpoints/best.pth'
    model.load_state_dict(torch.load(checkpoint_path)['model_state_dict'])
    # Predict with TTA
    if use_TTA:
        transforms = tta.Compose([
            tta.HorizontalFlip(),
            tta.VerticalFlip(),
            tta.Scale(scales=[5/6, 1, 7/6]),
        ])
        tta_model = tta.SegmentationTTAWrapper(model, transforms, merge_mode='mean')
    else:
        tta_model = model

    tta_model = tta_model.cuda()
    tta_model.eval()

    # The output directory does not depend on the batch: build it once.
    save_dir = './logits/valid/' + encoder + '_' + decoder + '/log_{}'.format(num_split)
    os.makedirs(save_dir, exist_ok=True)

    with torch.no_grad():
        for img, _, img_name in tqdm.tqdm(loaders['valid']):
            batch_preds = tta_model(img.cuda()).cpu().numpy()
            # Stored as float16 (lossy).
            batch_preds = batch_preds.astype(np.float16)

            file_name = img_name[0].split('.')[0] + '.plk'
            file_path = os.path.join(save_dir, file_name)

            with open(file_path, 'wb') as wf:
                plk.dump(batch_preds, wf)
def get_test_logits(encoder, decoder, num_split=4):
    '''
    Predict the test set with TTA and save the logits per image.

    Args:
        encoder: Encoder name for segmentation_models_pytorch.
        decoder: ``'unet'`` builds a Unet; any other value builds an FPN.
        num_split: Fold index used for the checkpoint directory and the output
            directory. Defaults to 4, the value that used to be hard-coded.

    Side effects:
        Writes one float16 array per image to
        ``./logits/test/<encoder>_fpn/log_<num_split>/<image stem>.plk``.

    NOTE(review): both the checkpoint path and the output path use the literal
    'fpn' regardless of ``decoder``. With ``decoder='unet'`` an FPN checkpoint
    would be loaded into a Unet -- confirm whether this is intentional.
    '''
    sub_path = "./data/Clouds_Classify/sample_submission.csv"
    # Close the CSV handle deterministically instead of leaking it.
    with open(sub_path) as sub_file:
        sub = pd.read_csv(sub_file)

    sub['label'] = sub['Image_Label'].apply(lambda x: x.split('_')[1])
    sub['im_id'] = sub['Image_Label'].apply(lambda x: x.split('_')[0])

    # Build the model
    preprocessing_fn = smp.encoders.get_preprocessing_fn(encoder, 'imagenet')
    if decoder == 'unet':
        model = smp.Unet(
            encoder_name=encoder,
            encoder_weights='imagenet',
            classes=4,
            activation=None,
        )
    else:
        model = smp.FPN(
            encoder_name=encoder,
            encoder_weights='imagenet',
            classes=4,
            activation=None,
        )

    # Load the model parameters
    logdir = "./logs/log_{}_{}/log_{}".format(encoder, 'fpn', num_split)
    checkpoint_path = logdir + '/checkpoints/best.pth'
    model.load_state_dict(torch.load(checkpoint_path)['model_state_dict'])

    # Predict with TTA
    use_TTA = True
    if use_TTA:
        print('using TTA!!!')
        transforms = tta.Compose([
            tta.HorizontalFlip(),
            tta.VerticalFlip(),
            tta.Scale(scales=[5/6, 1, 7/6]),
            ])
        tta_model = tta.SegmentationTTAWrapper(model, transforms, merge_mode='mean')
    else:
        tta_model = model

    test_ids = os.listdir(test_imgs_folder)

    test_dataset = CloudDataset(df=sub, transforms=get_validation_augmentation(), datatype='test', img_ids=test_ids, preprocessing=get_preprocessing(preprocessing_fn))
    # Batch size 1: each batch is saved as one file named after img_name[0].
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=2)

    tta_model = tta_model.cuda()
    tta_model.eval()

    # The output directory does not depend on the batch: build it once.
    save_dir = './logits/test/' + encoder + '_fpn' + '/log_{}'.format(num_split)
    os.makedirs(save_dir, exist_ok=True)

    with torch.no_grad():
        for img, _, img_name in tqdm.tqdm(test_loader):
            batch_preds = tta_model(img.cuda()).cpu().numpy()
            # Stored as float16 (lossy).
            batch_preds = batch_preds.astype(np.float16)

            # Save as pickle
            file_name = img_name[0].split('.')[0] + '.plk'
            file_path = os.path.join(save_dir, file_name)

            with open(file_path, 'wb') as wf:
                plk.dump(batch_preds, wf)
Esempio n. 12
0
    def predict_single_image(self,
                             image: np.array,
                             rgb=True,
                             preprocessing=True,
                             tta_aug=None):
        """Return the TTA-averaged model output for a single image.

        The image goes through ``process`` (with itself as a placeholder
        mask) and gets a batch axis. It is then predicted once per
        transformer in ``tta_aug`` and the restored outputs are averaged.

        Returns the averaged output as a NumPy array, transposed from
        (N, C, H, W) to (N, H, W, C) and squeezed. Returns ``None`` after
        logging a warning when ``self.settings`` is not a
        ``PredictorSettings``.
        """
        from segmentation.dataset import process
        if not isinstance(self.settings, PredictorSettings):
            settings_type = str(type(self.settings))
            logger.warning(
                'Settings is of type: {}. Pass settings to network object of type Train to train'
                .format(settings_type))
            return
        # from torch.utils import data
        if tta_aug is not None:
            transforms = tta_aug
        else:
            import ttach as tta
            transforms = tta.Compose([
                tta.Scale(scales=[0.95, 1, 1.05]),
                tta.HorizontalFlip(),
            ])
        self.model.eval()
        preprocessing_fn = sm.encoders.get_preprocessing_fn(self.encoder)
        # The image doubles as its own mask; the processed mask is discarded.
        image, _ = process(image=image,
                           mask=image,
                           rgb=rgb,
                           preprocessing=preprocessing_fn,
                           apply_preprocessing=preprocessing,
                           augmentation=None,
                           color_map=None,
                           binary_augmentation=False)
        batch = image.unsqueeze(0)

        with torch.no_grad():
            batch = batch.to(self.device)
            original_shape = batch.shape
            target_size = list(original_shape)[2:]

            def _predict_variant(transformer):
                # Augment -> pad -> predict -> unpad -> undo augmentation.
                augmented = transformer.augment_image(batch)
                augmented_size = list(augmented.shape)[2:]
                padded = pad(augmented, self.padding_value)  ## 2**5
                prediction = unpad(self.model(padded.float()), augmented_size)
                restored = transformer.deaugment_mask(prediction)
                restored = torch.nn.functional.interpolate(
                    restored, size=target_size, mode="nearest")
                message = "original: {} input: {}, padded: {} unpadded {} output {} \n".format(
                    str(original_shape), str(augmented_size),
                    str(list(augmented.shape)),
                    str(list(prediction.shape)),
                    str(list(restored.shape)))
                logger.debug(message)
                return restored

            restored_outputs = [_predict_variant(t) for t in transforms]
            averaged = torch.mean(torch.stack(restored_outputs), dim=0)
            result = averaged.data.cpu().numpy()

        result = np.transpose(result, (0, 2, 3, 1))
        return np.squeeze(result)