Exemple #1
0
def evaluate(val_loader, model, loss_fn, device, use_tta=False):
    """Compute the loss and accuracy of ``model`` over ``val_loader``.

    The model is switched to eval mode for the pass and is always put back
    into training mode afterwards, including when a batch raises.

    Args:
        val_loader: Iterable yielding ``(input_data, labels)`` batches.
        model: Module called on each batch; its output is arg-maxed along
            axis 1 to obtain the predicted class.
        loss_fn: Callable ``loss_fn(predictions, labels)`` whose result
            supports ``.item()``.
        device: Device every batch is moved to.
        use_tta: When true, predictions come from a ttach
            ``ClassificationTTAWrapper`` around ``model`` (rotations by
            0/90/180/270 degrees combined with horizontal and vertical flips).

    Returns:
        Tuple ``(total_loss / total, correct / total)`` where ``total`` is the
        number of labels seen.
        NOTE(review): the per-batch losses are summed and divided by the
        sample count, which is a per-sample mean only if ``loss_fn`` uses sum
        reduction -- confirm against callers.

    Raises:
        ZeroDivisionError: If ``val_loader`` yields no samples.
    """
    model.eval()

    correct = 0
    total = 0
    total_loss = 0
    try:
        predictor = model
        if use_tta:
            transformations = tta.Compose([
                tta.Rotate90(angles=[0, 90, 180, 270]),
                tta.HorizontalFlip(),
                tta.VerticalFlip()
            ])
            predictor = tta.ClassificationTTAWrapper(model, transformations)

        with torch.no_grad():
            for input_data, labels in val_loader:
                input_data, labels = input_data.to(device), labels.to(device)
                predictions = predictor(input_data)
                total_loss += loss_fn(predictions, labels).item()
                correct += (predictions.argmax(axis=1) == labels).sum().item()
                total += len(labels)
                torch.cuda.empty_cache()
    finally:
        # Restore training mode even if a batch failed, so a caller that
        # catches the error does not keep training a model stuck in eval mode.
        model.train()

    return total_loss / total, correct / total
    def forward_augmentation_smoothing(
            self,
            input_tensor: torch.Tensor,
            targets: List[torch.nn.Module],
            eigen_smooth: bool = False) -> np.ndarray:
        """Average the CAM over a fixed set of test-time augmentations.

        The input is augmented with every combination of a horizontal flip
        and a brightness multiplication by 0.9, 1 or 1.1. ``self.forward`` is
        run on each augmented input, the resulting map is de-augmented, and
        the maps are averaged.

        Args:
            input_tensor: Input passed to each transform's ``augment_image``.
            targets: Forwarded unchanged to ``self.forward``.
            eigen_smooth: Forwarded unchanged to ``self.forward``.

        Returns:
            The float32 mean over all augmentations of the de-augmented maps.
        """
        augmentations = tta.Compose([
            tta.HorizontalFlip(),
            tta.Multiply(factors=[0.9, 1, 1.1]),
        ])

        def _deaugmented_cam(augmentation):
            # Run the CAM on the augmented input.
            raw_cam = self.forward(
                augmentation.augment_image(input_tensor), targets,
                eigen_smooth)
            # ttach expects a BxCxHxW tensor, so insert a channel axis
            # before undoing the augmentation.
            as_tensor = torch.from_numpy(raw_cam[:, None, :, :])
            restored = augmentation.deaugment_mask(as_tensor)
            # Back to numpy with the channel axis dropped again.
            return restored.numpy()[:, 0, :, :]

        per_augmentation = [_deaugmented_cam(aug) for aug in augmentations]
        return np.mean(np.float32(per_augmentation), axis=0)
Exemple #3
0
def process_folder():
    """Classify every image in ``IMAGE_FOLDER`` with a 5-fold ensemble.

    One ``AppleClassification`` system is created per fold checkpoint
    (``fold0.ckpt`` .. ``fold4.ckpt`` under ``MODEL_FOLDER``), each with its
    own threshold and horizontal-flip TTA. The per-fold labels are combined
    with a weighted average (all weights are currently 1), thresholded at
    0.5, and written to ``OUTPUT_FILE`` as a CSV with the columns ``name``
    and ``disease_flag``.
    """
    image_folder = add_backslash_if_needed(IMAGE_FOLDER)
    model_folder = add_backslash_if_needed(MODEL_FOLDER)
    image_names = sorted(os.listdir(image_folder))

    flip_tta = tta.Compose([tta.HorizontalFlip()])
    fold_thresholds = [0.18, 0.22, 0.22, 0.2, 0.2]

    fold_labels = []
    for fold, threshold in enumerate(fold_thresholds):
        system = AppleClassification(
            model_path=f'{model_folder}fold{fold}.ckpt',
            device='cuda:0',
            transforms=flip_tta,
            th=threshold,
        )
        labels, _probs = system.process_folder(
            image_folder, image_names, num_workers=NUM_WORKERS)
        # Stored as float because the labels are multiplied by the weights.
        fold_labels.append(np.array(labels, dtype=float))

    weights = np.array([1, 1, 1, 1, 1], dtype=float)
    weighted_sum = weights[0] * fold_labels[0]
    for weight, labels in zip(weights[1:], fold_labels[1:]):
        weighted_sum += weight * labels
    averaged = weighted_sum / np.sum(weights)
    final_preds = (averaged > 0.5).astype(int)

    report = pd.DataFrame({'name': image_names,
                           'disease_flag': final_preds})
    report.to_csv(OUTPUT_FILE, index=False)
Exemple #4
0
def get_predictions(model_chosen, tta=False):
    """Predict class indices for every batch of ``loader_of_test``.

    Args:
        model_chosen: Module to evaluate. It is moved to the GPU and put in
            eval mode.
        tta: When true, the model is wrapped in a ttach
            ``ClassificationTTAWrapper`` (horizontal flip, vertical flip and
            rotations by 0/90/180/270 degrees) before predicting.

    Returns:
        A list with one NumPy array per batch, holding the arg-max class
        index (along dim 1) of each sample's output.
    """
    # ``cuda`` is a method: it has to be called before chaining ``eval()``.
    # The original ``model_chosen.cuda.eval()`` raised AttributeError.
    model_chosen.cuda().eval()

    predictor = model_chosen
    if tta:
        transformation = ttach.Compose(
            [
                ttach.HorizontalFlip(),
                ttach.VerticalFlip(),
                ttach.Rotate90(angles=[0, 90, 180, 270])
            ]
        )
        predictor = ttach.ClassificationTTAWrapper(model_chosen, transformation)

    predicted_values = []
    with torch.no_grad():
        # The labels in each batch are not needed: only predictions are returned.
        for test_image, _ in loader_of_test:
            scores = predictor(test_image.cuda())
            predicted_values.append(
                torch.argmax(scores, dim=1).detach().cpu().numpy())

    return predicted_values
Exemple #5
0
def single_model_predict_tta():
    """Predict the test set with one model under TTA and write a submission.

    Exactly one name must be present in ``model_name_list``. Its weights are
    loaded from ``<config.dir_weight>/<model_name>.bin``, the model is wrapped
    with horizontal-flip and five-crop (224x224) augmentation, and the
    predicted class of every test sample is written to
    ``config.dir_csv_test`` as a header-less CSV of ``id`` and ``label``.
    """
    assert len(model_name_list) == 1
    model_name = model_name_list[0]
    model = Net(model_name).to(device)
    model_save_path = os.path.join(
        config.dir_weight, '{}.bin'.format(model_name))
    # map_location lets a checkpoint saved on another device load onto the
    # device used here.
    model.load_state_dict(torch.load(model_save_path, map_location=device))
    # Inference must run in eval mode so that layers such as Dropout and
    # BatchNorm do not use their training behaviour.
    model.eval()

    transforms = tta.Compose([
        tta.HorizontalFlip(),
        tta.FiveCrops(224, 224)
    ])

    tta_model = tta.ClassificationTTAWrapper(model, transforms)

    pred_list = []
    with torch.no_grad():
        for batch_x, _ in tqdm(test_loader):
            batch_x = batch_x.to(device)
            logits = tta_model(batch_x)
            # torch.max over dim 1 returns (values, indices); only the index
            # of the most probable class is kept.
            _, class_ids = torch.max(torch.softmax(logits, dim=1), dim=1)
            pred_list += class_ids.cpu().numpy().tolist()

    submission = pd.DataFrame({
        "id": range(len(pred_list)),
        "label": [int2label(x) for x in pred_list]
    })
    submission.to_csv(config.dir_csv_test, index=False, header=False)
Exemple #6
0
def multi_model_predict_tta():
    """Predict the test set with every model/fold under TTA and ensemble them.

    For each name in ``model_name_list`` and each of 5 folds, the weights
    ``<config.dir_weight>/<model_name>_fold<k>.bin`` are loaded, the model is
    wrapped with resize + horizontal-flip + five-crop augmentation, and
    ``predict`` is run on the wrapped model. Each fold's raw predictions are
    written to ``config.dir_submission``; the predictions of all models and
    folds are then merged by ``get_pred_list`` and written to
    ``config.dir_csv_test`` as a header-less CSV of ``id`` and ``label``.
    """
    # The augmentation pipeline does not depend on the model or fold, so it
    # is built once.
    # NOTE(review): ttach documents ``Resize(sizes=...)`` as a list of
    # (height, width) pairs, but a flat [size, size] list is passed here --
    # confirm this is intended.
    transforms = tta.Compose([
        tta.Resize([int(config.size_test_image), int(config.size_test_image)]),
        tta.HorizontalFlip(),
        tta.FiveCrops(config.size_test_image, config.size_test_image)
    ])

    preds_dict = dict()
    for model_name in model_name_list:
        for fold_idx in range(5):
            model = Net(model_name).to(device)
            model_save_path = os.path.join(
                config.dir_weight, '{}_fold{}.bin'.format(model_name, fold_idx))
            model.load_state_dict(torch.load(model_save_path))
            # Make sure inference does not run with training-mode layers.
            model.eval()
            tta_model = tta.ClassificationTTAWrapper(model, transforms)

            pred_list = predict(tta_model)
            # NOTE(review): the per-fold file name uses
            # config.save_model_name rather than model_name, so with several
            # entries in model_name_list the files may overwrite each other
            # -- confirm.
            submission = pd.DataFrame(pred_list)
            submission.to_csv(
                '{}/{}_fold{}_submission.csv'.format(config.dir_submission, config.save_model_name, fold_idx),
                index=False,
                header=False
            )
            preds_dict['{}_{}'.format(model_name, fold_idx)] = pred_list

    pred_list = get_pred_list(preds_dict)
    submission = pd.DataFrame(
        {"id": range(len(pred_list)), "label": [int2label(x) for x in pred_list]})
    submission.to_csv(config.dir_csv_test, index=False, header=False)
Exemple #7
0
def SUE_TTA(model, batch: torch.Tensor, last_layer: bool) -> Tuple[np.ndarray, np.ndarray]:
    r"""Estimate segmentation uncertainty with test-time augmentations (TTA).

    The batch is run through ``model`` once per combination of vertical flip,
    horizontal flip, rotation (0/180 degrees), scale (1/2/4) and intensity
    multiplication (0.9/1/1.1). Every output is de-augmented, and the stack
    of outputs is passed to ``calc_aleatoric`` and ``calc_epistemic``.

    Args:
        model: Trained model; it is switched to eval mode.
        batch: Input tensor; the original documentation describes it as
            having shape (1, C, H, W) for one 2D slice.
        last_layer: When true, ``torch.sigmoid`` is applied to each
            de-augmented output before it is collected.
            NOTE(review): the original description read "whether there is
            Sigmoid as a last NN layer", which suggests the opposite meaning
            -- confirm which is intended.

    Returns:
        Tuple ``(aleatoric, epistemic)`` as returned by ``calc_aleatoric``
        and ``calc_epistemic`` on the stacked predictions.
    """
    model.eval()
    transforms = tta.Compose(
        [
            tta.VerticalFlip(),
            tta.HorizontalFlip(),
            tta.Rotate90(angles=[0, 180]),
            tta.Scale(scales=[1, 2, 4]),
            tta.Multiply(factors=[0.9, 1, 1.1]),
        ]
    )
    predicted = []
    # Every output is detached and converted to NumPy, so gradients are never
    # needed; disabling autograd avoids keeping a graph for each pass.
    with torch.no_grad():
        for transformer in transforms:
            augmented_image = transformer.augment_image(batch)
            model_output = model(augmented_image)
            deaug_mask = transformer.deaugment_mask(model_output)
            if last_layer:
                deaug_mask = torch.sigmoid(deaug_mask)
            predicted.append(deaug_mask.cpu().detach().numpy())

    p_hat = np.array(predicted)
    aleatoric = calc_aleatoric(p_hat)
    epistemic = calc_epistemic(p_hat)

    return aleatoric, epistemic
 def build_tta_model(self, model, config, device):
     """Wrap ``model`` in the ttach wrapper class named by ``config["tta"]``.

     The wrapper applies horizontal and vertical flips, merges the results
     with ``merge_mode="mean"``, and is moved to ``device`` before being
     returned.
     """
     wrapper_cls = getattr(tta, config["tta"])
     flips = tta.Compose([tta.HorizontalFlip(), tta.VerticalFlip()])
     wrapped = wrapper_cls(model, flips, merge_mode="mean")
     wrapped.to(device)
     return wrapped
Exemple #9
0
 def __init__(self):
     """Build an ``InnerNet`` wrapped in a mean-merging classification TTA wrapper.

     The augmentations are a horizontal flip combined with an intensity
     multiplication by 0.95, 1 or 1.05. They are also kept on
     ``self.transforms``.
     """
     super(Net, self).__init__()
     augmentations = ttach.Compose([
         ttach.HorizontalFlip(),
         ttach.Multiply(factors=[0.95, 1, 1.05]),
     ])
     self.transforms = augmentations
     inner = InnerNet()
     self.model = ttach.ClassificationTTAWrapper(
         inner, transforms=augmentations, merge_mode="mean")
Exemple #10
0
def main():
    """Run ``predict`` with fixed thresholds and horizontal-flip TTA.

    CUDA is used when available, otherwise the CPU.
    """
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device('cpu')

    flip_tta = tta.Compose([tta.HorizontalFlip()])

    # Fixed values, used in place of a search_threshold(device, transforms) run.
    thresholds = [0.8, 0.7, 0.8, 0.7]
    min_size_threshold = 0

    predict(thresholds, min_size_threshold, device, flip_tta)
Exemple #11
0
def init_model(model_path):
    """Create an ``AppleClassification`` system with ``gradcam=True`` and no TTA.

    Args:
        model_path: Checkpoint path forwarded to ``AppleClassification``.

    Returns:
        The constructed system, configured for ``cuda:0`` with threshold 0.2.
    """
    # No TTA pipeline is built here: the system is created with
    # transforms=None, so the Compose the original code constructed was
    # never used.
    return AppleClassification(model_path=model_path,
                               device='cuda:0',
                               transforms=None,
                               th=0.2,
                               gradcam=True)
Exemple #12
0
def init_model():
    """Create an ``AppleClassification`` system for ``MODEL_PATH``.

    The system runs on ``cuda:0`` with horizontal-flip TTA and threshold 0.2.
    """
    flip_tta = tta.Compose([tta.HorizontalFlip()])
    return AppleClassification(
        model_path=MODEL_PATH,
        device='cuda:0',
        transforms=flip_tta,
        th=0.2,
    )
    def predict(self, tta_aug=None, debug=None):
        """Yield one TTA-averaged prediction per item of the prediction dataset.

        Each item is augmented with every transform in ``tta_aug``, padded
        with ``pad(..., self.padding_value)``, run through ``self.model``,
        unpadded, de-augmented and resized (nearest neighbour) back to the
        spatial size of the original input. The results are averaged over
        the augmentations.

        Args:
            tta_aug: Iterable of ttach transformers. Defaults to the
                combination of ``Scale(scales=[0.95, 1, 1.05])`` and
                ``HorizontalFlip``.
            debug: Not used by this method.

        Yields:
            The averaged output as a NumPy array, transposed with axes
            ``(0, 2, 3, 1)`` and then squeezed.

        If ``self.settings`` is not a ``PredictorSettings``, a warning is
        logged and nothing is yielded.
        """
        augmentations = tta_aug
        if augmentations is None:
            import ttach as tta
            augmentations = tta.Compose([
                tta.Scale(scales=[0.95, 1, 1.05]),
                tta.HorizontalFlip(),
            ])
        from torch.utils import data as torch_data

        self.model.eval()

        if not isinstance(self.settings, PredictorSettings):
            logger.warning(
                'Settings is of type: {}. Pass settings to network object of type Train to train'
                .format(str(type(self.settings))))
            return

        loader = torch_data.DataLoader(dataset=self.settings.PREDICT_DATASET,
                                       batch_size=1,
                                       shuffle=False,
                                       num_workers=self.settings.PROCESSES)
        with torch.no_grad():
            for batch, target, _sample_id in loader:
                batch = batch.to(self.device)
                target = target.to(self.device, dtype=torch.int64)
                original_shape = batch.shape
                target_size = list(original_shape)[2:]

                restored_outputs = []
                for transformer in augmentations:
                    augmented = transformer.augment_image(batch)
                    augmented_size = list(augmented.shape)[2:]
                    # Pad before the forward pass, then crop the output back
                    # to the augmented image's spatial size.
                    model_input = pad(augmented, self.padding_value).float()
                    raw_output = unpad(self.model(model_input), augmented_size)
                    restored = transformer.deaugment_mask(raw_output)
                    restored = torch.nn.functional.interpolate(
                        restored, size=target_size, mode="nearest")
                    print(
                        "original: {} input: {}, padded: {} unpadded {} output {}"
                        .format(str(original_shape), str(augmented_size),
                                str(list(augmented.shape)),
                                str(list(raw_output.shape)),
                                str(list(restored.shape))))
                    restored_outputs.append(restored)

                averaged = torch.mean(torch.stack(restored_outputs), dim=0)
                result = averaged.data.cpu().numpy()
                result = np.transpose(result, (0, 2, 3, 1))
                yield np.squeeze(result)
Exemple #14
0
def test_time_aug(net, merge_mode='mean'):
    """
    Wrap ``net`` in a ttach segmentation wrapper using a horizontal flip.

    ``merge_mode`` is forwarded to ``tta.SegmentationTTAWrapper``.
    For the other available operations see https://github.com/qubvel/ttach
    """
    print("Using the test time augmentation! [Default: HorizontalFlip]")
    augmentations = tta.Compose([tta.HorizontalFlip()])
    return tta.SegmentationTTAWrapper(net, augmentations,
                                      merge_mode=merge_mode)
def tta_model_predict(X, model):
    """Return the mean de-augmented ``"out"`` mask of ``model`` over TTA.

    ``X`` is augmented with every combination of a horizontal flip and a
    scale of 0.5, 1 or 2. ``model(...)["out"]`` is computed for each
    augmented input, the mask is de-augmented, and the masks are averaged.
    """
    augmentations = tta.Compose([
        tta.HorizontalFlip(),
        tta.Scale(scales=[0.5, 1, 2]),
    ])
    restored_masks = [
        aug.deaugment_mask(model(aug.augment_image(X))["out"])
        for aug in augmentations
    ]
    return torch.sum(torch.stack(restored_masks), dim=0) / len(restored_masks)
def predict(model_path, test_loader, saveFileName, iftta):
    """Predict the class of every test batch and write a submission CSV.

    Args:
        model_path: Checkpoint whose state dict is loaded into a model
            created by ``initialize_model(num_classes=176)``.
        test_loader: Iterable yielding image batches.
        saveFileName: Path of the CSV to write.
        iftta: When true, the model is wrapped in a ttach
            ``ClassificationTTAWrapper`` (horizontal flip, vertical flip and
            rotations by 0/180 degrees).

    The predicted class names are attached as a ``label`` column to
    ``leaves_data/test.csv``, and the ``image`` and ``label`` columns are
    saved to ``saveFileName``.
    """
    # Create the model and load the weights from the checkpoint.
    network = initialize_model(num_classes=176)
    network = network.to(device)
    network.load_state_dict(torch.load(model_path))

    if iftta:
        print("Using TTA")
        augmentations = tta.Compose([
            tta.HorizontalFlip(),
            tta.VerticalFlip(),
            tta.Rotate90(angles=[0, 180]),
        ])
        network = tta.ClassificationTTAWrapper(network, augmentations)

    # Modules such as Dropout or BatchNorm behave differently in training
    # mode, so switch to eval mode before predicting.
    network.eval()

    class_indices = []
    for imgs in tqdm(test_loader):
        with torch.no_grad():
            logits = network(imgs.to(device))
        # The class with the greatest logit is the prediction.
        class_indices.extend(logits.argmax(dim=-1).cpu().numpy().tolist())

    class_names = [num_to_class[index] for index in class_indices]

    test_data = pd.read_csv('leaves_data/test.csv')
    test_data['label'] = pd.Series(class_names)
    submission = pd.concat([test_data['image'], test_data['label']], axis=1)
    submission.to_csv(saveFileName, index=False)
    print("Done!!!!!!!!!!!!!!!!!!!!!!!!!!!")
Exemple #17
0
def init_model():
    """Load the DeepLab weights stored next to this file and wrap them in TTA.

    The weights are read from ``net.pth`` in this module's directory. The
    returned model is a ttach ``SegmentationTTAWrapper`` (horizontal flip
    combined with rotations by 0/180 degrees) moved to the GPU.
    """
    weights_path = os.path.join(os.path.dirname(__file__), 'net.pth')
    network = DeepLab(
        output_stride=16,
        class_num=17,
        pretrained=False,
        bn_momentum=0.1,
        freeze_bn=False,
    )
    network.load_state_dict(torch.load(weights_path))

    augmentations = tta.Compose([
        tta.HorizontalFlip(),
        tta.Rotate90(angles=[0, 180]),
    ])

    wrapped = tta.SegmentationTTAWrapper(network, augmentations)
    return wrapped.cuda()
Exemple #18
0
def get_tta_model(model: nn.Module, crop_method: str,
                  input_size: List[int]) -> nn.Module:
    """Wrap ``model`` in a ttach classification TTA wrapper.

    A horizontal flip is always applied. When ``crop_method`` is ``"crop"``,
    a ``ThreeCrops`` transform sized from ``input_size`` is added; any other
    value adds no crop transform.

    Args:
        model: input model without TTA
        crop_method: cropping method of the input images; documented values
            are 'resize' and 'crop'
        input_size: model's input size; index 0 is used as the crop height
            and index 1 as the crop width

    Returns:
        Model with TTA
    """
    augmentations = [ttach.HorizontalFlip()]
    if crop_method == "crop":
        crop_height, crop_width = input_size[0], input_size[1]
        augmentations.append(
            ThreeCrops(crop_height=crop_height, crop_width=crop_width))

    return ttach.ClassificationTTAWrapper(model,
                                          ttach.Compose(augmentations))
Exemple #19
0
def test_compose_1():
    """A composed transform must round-trip masks and labels for every parameter combination."""
    transform = tta.Compose([
        tta.HorizontalFlip(),
        tta.VerticalFlip(),
        tta.Rotate90(angles=[0, 90, 180, 270]),
        tta.Scale(scales=[1, 2, 4], interpolation="nearest"),
    ])

    # One augmenter per combination of the parameters of all transforms.
    expected_count = 2 * 2 * 4 * 3
    assert len(transform) == expected_count

    dummy_label = torch.ones(2).reshape(2, 1).float()
    dummy_image = torch.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5).float()

    def dummy_model(x):
        # Identity "segmentation" output plus a constant label.
        return {"label": dummy_label, "mask": x}

    for augmenter in transform:
        model_output = dummy_model(augmenter.augment_image(dummy_image))
        restored_mask = augmenter.deaugment_mask(model_output["mask"])
        restored_label = augmenter.deaugment_label(model_output["label"])
        assert torch.allclose(restored_mask, dummy_image)
        assert torch.allclose(restored_label, dummy_label)
Exemple #20
0
def segmentation(
    param,
    input_image,
    label_arr,
    num_classes: int,
    gpkg_name,
    model,
    chunk_size: int,
    device,
    scale: List,
    BGR_to_RGB: bool,
    tp_mem,
    debug=False,
):
    """
    Run tiled inference with horizontal-flip TTA over a raster and, when
    labels are given, compute per-tile metrics.

    Window predictions are weighted by a 2D spline window and accumulated in
    a float16 memory-mapped array stored at ``tp_mem``. The class map is then
    obtained by an arg-max over the accumulated scores. ``input_image`` is
    closed before returning.

    Args:
        param: parameter dict, forwarded to ``augmentation.compose_transforms``
        input_image: opened image (rasterio object); its ``bounds``,
            ``transform``, ``height``, ``width`` and ``crs`` are read
        label_arr: numpy array of label if available; when not None, per-tile
            pixel metrics are computed against the prediction
        num_classes: number of classes
        gpkg_name: geo-package name if available; stored in the ``id_image``
            column of the metrics table
        model: model called on every window (switched to eval mode)
        chunk_size: image tile size
        device: cuda/cpu device the inputs and the window are moved to
        scale: scale range, forwarded to ``augmentation.compose_transforms``
        BGR_to_RGB: True/False, forwarded as ``input_space`` to
            ``augmentation.compose_transforms``
        tp_mem: memory temp file for saving numpy array to disk
        debug: True/False; when true, the value counts of the final
            prediction are logged

    Returns:
        Tuple ``(pred_img, gdf)``: ``pred_img`` is the uint8 class map cropped
        to the image height and width; ``gdf`` is a GeoDataFrame of per-tile
        metrics reprojected to EPSG:4326, or None when ``label_arr`` is None.

    """
    xmin, ymin, xmax, ymax = (input_image.bounds.left,
                              input_image.bounds.bottom,
                              input_image.bounds.right, input_image.bounds.top)
    xres, yres = (abs(input_image.transform.a), abs(input_image.transform.e))
    # Tile extent in the raster's coordinate units.
    mx = chunk_size * xres
    my = chunk_size * yres
    # Side length used for the smoothing window and for reshaping each
    # window's output: twice the tile size.
    padded = chunk_size * 2
    h = input_image.height
    w = input_image.width
    # The accumulation buffer is larger than the image by one window side.
    h_ = h + padded
    w_ = w + padded
    # Margin (half a tile) cropped from each side of a window prediction.
    dist_samples = int(round(chunk_size * (1 - 1.0 / 2.0)))

    # switch to evaluate mode
    model.eval()

    # initialize test time augmentation
    transforms = tta.Compose([
        tta.HorizontalFlip(),
    ])
    # construct window for smoothing
    WINDOW_SPLINE_2D = _window_2D(window_size=padded, power=2.0)
    # Move axis 2 of the window to the front and convert to a float tensor.
    WINDOW_SPLINE_2D = torch.as_tensor(np.moveaxis(WINDOW_SPLINE_2D, 2,
                                                   0), ).type(torch.float)
    WINDOW_SPLINE_2D = WINDOW_SPLINE_2D.to(device)

    # Disk-backed accumulator for the windowed class scores.
    fp = np.memmap(tp_mem,
                   dtype='float16',
                   mode='w+',
                   shape=(h_, w_, num_classes))
    sample = {'sat_img': None, 'map_img': None, 'metadata': None}
    cnt = 0
    img_gen = gen_img_samples(input_image, chunk_size)
    start_seg = time.time()
    for img in tqdm(img_gen,
                    position=1,
                    leave=False,
                    desc='inferring on window slices'):
        # Each generated item is indexed as (sub_image, row, col).
        row = img[1]
        col = img[2]
        sub_image = img[0]
        image_metadata = add_metadata_from_raster_to_sample(
            sat_img_arr=sub_image,
            raster_handle=input_image,
            meta_map={},
            raster_info={})

        sample['metadata'] = image_metadata
        totensor_transform = augmentation.compose_transforms(
            param,
            dataset="tst",
            input_space=BGR_to_RGB,
            scale=scale,
            aug_type='totensor')
        sample['sat_img'] = sub_image
        sample = totensor_transform(sample)
        # Add a leading batch axis in place and move the input to the device.
        inputs = sample['sat_img'].unsqueeze_(0)
        inputs = inputs.to(device)
        # Taken only for 4-channel input when the model's state dict has keys
        # containing "module.modelNIR".
        if inputs.shape[1] == 4 and any("module.modelNIR" in s
                                        for s in model.state_dict().keys()):
            ############################
            # Test Implementation of the NIR
            ############################
            # Init NIR   TODO: make a proper way to read the NIR channel
            #                  and put an option to be able to give the index of the NIR channel
            # Extract the NIR channel -> [batch size, H, W] since it's only one channel
            inputs_NIR = inputs[:, -1, ...]
            # add a channel to get the good size -> [:, 1, :, :]
            inputs_NIR.unsqueeze_(1)
            # take out the NIR channel and take only the RGB for the inputs
            inputs = inputs[:, :-1, ...]
            # Suggestion of implementation
            # inputs_NIR = data['NIR'].to(device)
            # NOTE(review): from here on `inputs` is a list, and it is passed
            # to transformer.augment_image below -- confirm the TTA
            # transforms accept a list.
            inputs = [inputs, inputs_NIR]
            # outputs = model(inputs, inputs_NIR)
            ############################
            # End of the test implementation module
            ############################
        output_lst = []
        for transformer in transforms:
            # augment inputs
            augmented_input = transformer.augment_image(inputs)
            augmented_output = model(augmented_input)
            # Some models return an OrderedDict; the prediction is then under 'out'.
            if isinstance(augmented_output,
                          OrderedDict) and 'out' in augmented_output.keys():
                augmented_output = augmented_output['out']
            logging.debug(
                f'Shape of augmented output: {augmented_output.shape}')
            # reverse augmentation for outputs
            deaugmented_output = transformer.deaugment_mask(augmented_output)
            # Softmax over dim 1, then drop the leading batch axis.
            deaugmented_output = F.softmax(deaugmented_output,
                                           dim=1).squeeze(dim=0)
            output_lst.append(deaugmented_output)
        outputs = torch.stack(output_lst)
        # Weight every augmented prediction by the spline window, then keep
        # the element-wise maximum over the augmentations (dim 0).
        outputs = torch.mul(outputs, WINDOW_SPLINE_2D)
        outputs, _ = torch.max(outputs, dim=0)
        # Reorder so the class axis is last before converting to NumPy.
        outputs = outputs.permute(1, 2, 0)
        outputs = outputs.reshape(padded, padded,
                                  num_classes).cpu().numpy().astype('float16')
        # Crop the half-tile margin on every side of the window.
        outputs = outputs[dist_samples:-dist_samples,
                          dist_samples:-dist_samples, :]
        # Accumulate the windowed scores into the disk-backed buffer.
        fp[row:row + chunk_size, col:col + chunk_size, :] = \
            fp[row:row + chunk_size, col:col + chunk_size, :] + outputs
        cnt += 1
    fp.flush()
    del fp

    # Reopen the accumulated scores read-only for the arg-max pass.
    fp = np.memmap(tp_mem,
                   dtype='float16',
                   mode='r',
                   shape=(h_, w_, num_classes))
    subdiv = 2.0
    step = int(chunk_size / subdiv)
    pred_img = np.zeros((h_, w_), dtype=np.uint8)
    for row in tqdm(range(0, input_image.height, step),
                    position=2,
                    leave=False):
        for col in tqdm(range(0, input_image.width, step),
                        position=3,
                        leave=False):
            # NOTE(review): the division by 2**2 presumably normalises for
            # the overlapping windows; it does not change the arg-max --
            # confirm intent.
            arr1 = fp[row:row + chunk_size, col:col + chunk_size, :] / (2**2)
            arr1 = arr1.argmax(axis=-1).astype('uint8')
            pred_img[row:row + chunk_size, col:col + chunk_size] = arr1
    # Drop the extra rows and columns beyond the image extent.
    pred_img = pred_img[:h, :w]
    end_seg = time.time() - start_seg
    logging.info('Segmentation operation completed in {:.0f}m {:.0f}s'.format(
        end_seg // 60, end_seg % 60))

    if debug:
        logging.debug(
            f'Bin count of final output: {np.unique(pred_img, return_counts=True)}'
        )
    gdf = None
    if label_arr is not None:
        # Benchmark: compare prediction and label tile by tile.
        start_seg_ = time.time()
        feature = defaultdict(list)
        cnt = 0
        for row in tqdm(range(0, h, chunk_size), position=2, leave=False):
            for col in tqdm(range(0, w, chunk_size), position=3, leave=False):
                label = label_arr[row:row + chunk_size, col:col + chunk_size]
                pred = pred_img[row:row + chunk_size, col:col + chunk_size]
                pixelMetrics = ComputePixelMetrics(label.flatten(),
                                                   pred.flatten(), num_classes)
                # NOTE: this local name shadows the builtin `eval`.
                eval = pixelMetrics.update(pixelMetrics.iou)
                feature['id_image'].append(gpkg_name)
                # Per-class pixel counts in label and prediction, plus the IoU.
                for c_num in range(num_classes):
                    feature['L_count_' + str(c_num)].append(
                        int(np.count_nonzero(label == c_num)))
                    feature['P_count_' + str(c_num)].append(
                        int(np.count_nonzero(pred == c_num)))
                    feature['IoU_' + str(c_num)].append(eval['iou_' +
                                                             str(c_num)])
                feature['mIoU'].append(eval['macro_avg_iou'])
                # Corners of the tile in raster coordinates: start at
                # (xmin, ymax) offset by col/row, then move right, down, left.
                x_1, y_1 = (xmin + (col * xres)), (ymax - (row * yres))
                x_2, y_2 = (xmin + ((col * xres) + mx)), y_1
                x_3, y_3 = x_2, (ymax - ((row * yres) + my))
                x_4, y_4 = x_1, y_3
                geom = Polygon([(x_1, y_1), (x_2, y_2), (x_3, y_3),
                                (x_4, y_4)])
                feature['geometry'].append(geom)
                feature['length'].append(geom.length)
                feature['pointx'].append(geom.centroid.x)
                feature['pointy'].append(geom.centroid.y)
                feature['area'].append(geom.area)
                cnt += 1
        gdf = gpd.GeoDataFrame(feature, crs=input_image.crs)
        gdf.to_crs(crs="EPSG:4326", inplace=True)
        end_seg_ = time.time() - start_seg_
        logging.info('Benchmark operation completed in {:.0f}m {:.0f}s'.format(
            end_seg_ // 60, end_seg_ % 60))
    input_image.close()
    return pred_img, gdf
Exemple #21
0
def test(model, data_loader, save_path=""):
    """
    Evaluate ``model`` on ``data_loader`` and report the J and F scores.

    For convenience, validation and testing during training compute the J
    and F metrics directly instead of first generating and writing out the
    predictions. The numbers reported here are therefore only a relative
    reference; the real metrics must be obtained with the dedicated test
    code.

    When ``save_path`` is non-empty, the binarised prediction of every frame
    is written to ``<save_path>/<video_name>/<mask_name>``.

    Returns the result of ``pretty_print`` on the per-video mean (J, F)
    scores plus their mean under the key ``'average'``.
    """
    model.eval()
    progress = tqdm(enumerate(data_loader),
                    total=len(data_loader),
                    leave=False)

    transforms = None
    if arg_config['use_tta']:
        construct_print("We will use Test Time Augmentation!")
        # 2 flip states x 3 scales
        transforms = tta.Compose([
            tta.HorizontalFlip(),
            tta.Scale(scales=[0.75, 1, 1.5],
                      interpolation='bilinear',
                      align_corners=False)
        ])

    results = defaultdict(list)
    for batch_idx, batch in progress:
        progress.set_description(f"te=>{batch_idx + 1}")

        with torch.no_grad():
            jpegs = batch["image"].to(DEVICES, non_blocking=True)
            flows = batch["flow"].to(DEVICES, non_blocking=True)
            logits = tta_aug(model=model,
                             transforms=transforms,
                             data=dict(curr_jpeg=jpegs, curr_flow=flows))
            probs = logits.sigmoid().squeeze().cpu().detach()  # float32

        for i, prob_map in enumerate(probs.numpy()):
            mask_path = batch["mask_path"][i]
            video_name, mask_name = mask_path.split(os.sep)[-2:]
            mask = read_binary_array(mask_path, thr=0)
            mask_h, mask_w = mask.shape

            # Resize the prediction to the ground-truth mask size before
            # normalising and thresholding it.
            prob_map = cv2.resize(prob_map,
                                  dsize=(mask_w, mask_h),
                                  interpolation=cv2.INTER_LINEAR)
            prob_map = clip_to_normalize(data_array=prob_map,
                                         clip_range=arg_config["clip_range"])
            pred_seg = np.where(prob_map > 0.5, 255, 0).astype(np.uint8)

            j_score = jaccard.db_eval_iou(annotation=mask,
                                          segmentation=pred_seg)
            f_score = f_boundary.db_eval_boundary(annotation=mask,
                                                  segmentation=pred_seg)
            results[video_name].append((j_score, f_score))

            if save_path:
                video_dir = os.path.join(save_path, video_name)
                if not os.path.exists(video_dir):
                    os.makedirs(video_dir)
                cv2.imwrite(os.path.join(video_dir, mask_name), pred_seg)

    # Replace each video's list of (J, F) pairs with its mean, then average
    # those means over all videos.
    per_video_means = []
    for video_name, video_scores in results.items():
        video_mean = np.mean(np.array(video_scores), axis=0).tolist()
        results[video_name] = video_mean
        per_video_means.append(video_mean)
    results['average'] = np.mean(np.array(per_video_means), axis=0).tolist()
    return pretty_print(results)
Exemple #22
0
def predict_test(CropStage=False,
                 TestStage=True,
                 toMask=True,
                 toZip=True,
                 newTH=0.05):
    """Run the staged test pipeline: crop, predict with TTA, merge, mask, zip.

    Every stage can be switched on or off independently so that intermediate
    results left on disk by an earlier run can be reused.

    Args:
        CropStage: when truthy, crop the raw test images (stage 1).
        TestStage: when truthy, predict every crop with the TTA-wrapped model
            (stage 2) and merge the crop predictions per image (stage 3).
        toMask: when truthy, convert the merged outputs to masks (stage 4).
        toZip: when truthy, pack the mask PNGs into ``result/result.zip``
            (stage 5).
        newTH: value forwarded to ``prob_to_mask`` as ``th``.
    """
    root_path = '/media/totem_disk/totem/weitang/project'

    # Alternative backbone that was used previously:
    # model = smp.Unet('se_resnext101_32x4d', activation=None).cuda()
    # dir_model = root_path + '/model/unet_se_resnext101_32x4d_2_1_best.pth'
    model = smp.Unet('densenet161', activation=None).cuda()
    i_size = 512
    i_scale = 0.25
    dir_model = root_path + '/model/unet_densenet161_2_1_best_0.73.pth'
    model.load_state_dict(torch.load(dir_model)['state_dict'])
    tta_transforms = tta.Compose([
        tta.HorizontalFlip(),
        # tta.Scale(scales=[1,2,4])
        # tta.Rotate90(angles=[0,180])
    ])
    tta_model = tta.SegmentationTTAWrapper(model,
                                           tta_transforms,
                                           merge_mode='mean')
    # A freshly built module is in training mode; switch to eval so that
    # batch-norm / dropout layers behave deterministically during inference.
    tta_model.eval()

    # Directory holding the cropped test images.
    # crop_test_images_path = root_path + '/temp_data_test/0.4crop_test_set1024'
    crop_test_images_path = root_path + '/temp_data_test/crop_test_set'

    test_path_list = glob.glob(
        '/media/totem_disk/totem/weitang/competition/test2/test/*jpg')
    print("Total {} images for testing.".format(len(test_path_list)))

    # Stage 1: cropping.
    if CropStage:
        print("Stage 1: ")
        crop(test_path_list,
             crop_test_images_path,
             scale=i_scale,
             image_size=i_size,
             mode="test")

    # prob_save_path = root_path + '/temp_data_test/resprob1024'
    prob_save_path = root_path + '/temp_data_test/dense101_t'
    # crop_predict = root_path + '/temp_data_test/crop_predict_1024'
    crop_predict = root_path + '/temp_data_test/predict512_resnet101_t'

    if TestStage:
        print("Stage 2: ")
        # Predict the cropped images.
        test_images_path_list = glob.glob(crop_test_images_path + '/*.jpg')
        os.makedirs(crop_predict, exist_ok=True)
        test_loader = get_test_loader(test_images_path_list,
                                      image_size=i_size,
                                      batch_size=2)
        test(test_loader, crop_predict, model=tta_model)

        print("Stage 3: ")
        # Merge the per-crop predictions. The scale factor (0.25) is the one
        # the original image was resized by before it was cropped.
        os.makedirs(prob_save_path, exist_ok=True)
        merge_hot_pic(test_path_list, crop_predict, i_scale, prob_save_path)

    mask_save_path = root_path + '/temp_data_test/mask/'
    if toMask:
        print("Stage 4: ")
        # Convert probabilities to masks.
        os.makedirs(mask_save_path, exist_ok=True)
        prob_to_mask(prob_save_path,
                     mask_save_dir=mask_save_path,
                     th=newTH,
                     pad_white=True)

    if toZip:
        print("Stage 5: ")
        # Zip the masks. The archive directory may not exist yet, and the
        # context manager guarantees the archive is finalised even when
        # adding a member fails.
        os.makedirs(f'{root_path}/result', exist_ok=True)
        with zipfile.ZipFile(f'{root_path}/result/result.zip', 'w') as zf:
            for mask_file in glob.glob(f"{mask_save_path}/*.png"):
                basename = os.path.basename(mask_file)
                zf.write(mask_file, f'result/{basename}')
Exemple #23
0
    def test_tta(self, mode='train', unet_path=None):
        """Test the model with test-time augmentation & calculate performances.

        Args:
            mode: which loader to evaluate: 'train', 'test' or 'valid'.
            unet_path: optional checkpoint path; when it names an existing
                file, its ``'state_dict'`` entry is loaded into ``self.unet``.

        Returns:
            Tuple ``(accuracy, sensitivity, specificity, precision, dice,
            iou)``, each averaged over all evaluated samples.

        Raises:
            ValueError: if ``mode`` is not one of the supported splits.
        """
        print(char_color('@,,@   %s with TTA' % (mode)))
        if unet_path is not None and os.path.isfile(unet_path):
            checkpoint = torch.load(unet_path)
            self.unet.load_state_dict(checkpoint['state_dict'])
            self.myprint('Successfully Loaded from %s' % (unet_path))

        # eval() is the same as train(False); one call is enough.
        self.unet.eval()

        if mode == 'train':
            data_loader = self.train_loader
        elif mode == 'test':
            data_loader = self.test_loader
        elif mode == 'valid':
            data_loader = self.valid_loader
        else:
            # Previously an unknown mode surfaced later as an unbound local.
            raise ValueError(
                "mode must be 'train', 'test' or 'valid', got %r" % (mode, ))

        # The TTA wrapper does not depend on the batch, so build it once
        # instead of once per iteration.
        tta_trans = tta.Compose([
            tta.VerticalFlip(),
            tta.HorizontalFlip(),
            tta.Rotate90(angles=[0, 180])
        ])
        tta_model = tta.SegmentationTTAWrapper(self.unet, tta_trans)

        acc = 0.  # Accuracy
        SE = 0.  # Sensitivity (Recall)
        SP = 0.  # Specificity
        PC = 0.  # Precision
        DC = 0.  # Dice Coefficient
        IOU = 0.  # IOU
        length = 0

        # Per-image results: [id, acc, SE, SP, PC, dsc, IOU]
        detail_result = []
        with torch.no_grad():
            for i, sample in enumerate(data_loader):
                (image_paths, images, GT) = sample
                images_path = list(image_paths)
                images = images.to(self.device)
                GT = GT.to(self.device)

                # torch.sigmoid replaces the deprecated F.sigmoid.
                SR = torch.sigmoid(tta_model(images))

                if self.save_image:
                    images_all = torch.cat((images, SR, GT), 0)
                    torchvision.utils.save_image(
                        images_all.data.cpu(),
                        os.path.join(self.result_path, 'images',
                                     '%s_%d_image.png' % (mode, i)),
                        nrow=self.batch_size)

                for ii in range(SR.shape[0]):
                    # Flatten each sample directly on the device; the former
                    # tensor -> numpy -> tensor round trip produced the same
                    # values at the cost of two extra copies.
                    SR_tmp = SR[ii, :].reshape(-1)
                    GT_tmp = GT[ii, :].reshape(-1)
                    # The numeric file stem is used as the sample id.
                    file_name = images_path[ii].split(sep)[-1]
                    tmp_index = int(file_name.split('.')[0])

                    result_tmp = np.array([
                        tmp_index,
                        get_accuracy(SR_tmp, GT_tmp),
                        get_sensitivity(SR_tmp, GT_tmp),
                        get_specificity(SR_tmp, GT_tmp),
                        get_precision(SR_tmp, GT_tmp),
                        get_DC(SR_tmp, GT_tmp),
                        get_IOU(SR_tmp, GT_tmp)
                    ])

                    acc += result_tmp[1]
                    SE += result_tmp[2]
                    SP += result_tmp[3]
                    PC += result_tmp[4]
                    DC += result_tmp[5]
                    IOU += result_tmp[6]
                    detail_result.append(result_tmp)

                    length += 1

        accuracy = acc / length
        sensitivity = SE / length
        specificity = SP / length
        precision = PC / length
        disc = DC / length
        iou = IOU / length
        detail_result = np.array(detail_result)

        if self.save_detail_result:
            # Columns: [id, acc, SE, SP, PC, dsc, IOU]
            excel_save_path = os.path.join(self.result_path,
                                           mode + '_pre_detial_result.xlsx')
            # The context manager saves and closes the workbook; the explicit
            # ExcelWriter.save() call no longer exists in pandas >= 2.0.
            with pd.ExcelWriter(excel_save_path) as writer:
                pd.DataFrame(detail_result).to_excel(writer,
                                                     sheet_name=mode,
                                                     float_format='%.5f')

        return accuracy, sensitivity, specificity, precision, disc, iou

# Call train() once for each index 0..4 (presumably the cross-validation
# fold index — confirm against train()).
for i in range(5):
    train(i)


def load_model(fold: int, epoch: int, device: torch.device = 'cuda'):
    """Build an ``EfficientNetModel`` and load the weights saved for a fold/epoch.

    Args:
        fold: fold index used in the checkpoint file name.
        epoch: epoch number used in the checkpoint file name.
        device: device the model is moved to and the checkpoint is mapped to.

    Returns:
        The model on ``device`` with the checkpoint's state dict loaded.
    """
    model = EfficientNetModel().to(device)
    # map_location makes the checkpoint deserialise onto the requested device
    # instead of the device it was saved from, so e.g. device='cpu' works for
    # a checkpoint that was written on a GPU machine.
    state_dict = torch.load(
        f'models/effinet_b4_SAM_CosLR-f{fold}-{epoch}.pth',
        map_location=device)
    model.load_state_dict(state_dict)

    return model


# Flip-only test-time augmentation shared by the wrappers built in this module.
transforms = tta.Compose([tta.HorizontalFlip(), tta.VerticalFlip()])
# NOTE(review): `model` is not defined in the code visible above this line —
# confirm where the module-level name comes from.
tta_model = tta.ClassificationTTAWrapper(model, transforms)


def test(device: torch.device = 'cuda'):
    """Load the sample submission and the five fold models, then wrap them with TTA.

    NOTE(review): the visible body stops after wrapping the first three
    models and never uses ``submit``, ``model4``, ``model5`` or ``device`` —
    the function looks truncated; confirm against the original source.
    """
    submit = pd.read_csv('data/sample_submission.csv')

    # One checkpoint per fold (0-4), all taken from epoch 19.
    model1 = load_model(0, 19)
    model2 = load_model(1, 19)
    model3 = load_model(2, 19)
    model4 = load_model(3, 19)
    model5 = load_model(4, 19)

    # Wrap each fold model with the module-level flip transforms.
    tta_model1 = tta.ClassificationTTAWrapper(model1, transforms)
    tta_model2 = tta.ClassificationTTAWrapper(model2, transforms)
    tta_model3 = tta.ClassificationTTAWrapper(model3, transforms)
def validation(valid_ids, num_split, encoder, decoder):
    """
    Validate the model with TTA and grid-search the post-processing parameters.

    For each of the 4 classes, every (threshold, min size) pair of the search
    grid is scored by mean dice on the validation masks, and the ranked
    results are written to one CSV per class.

    Args:
        valid_ids: image ids that make up the validation split.
        num_split: split index; selects the log directory and output folder.
        encoder: encoder name passed to segmentation_models_pytorch.
        decoder: 'unet' builds ``smp.Unet``; anything else builds ``smp.FPN``.
    """
    import os  # function-scope import: needed only to create the output dir

    # Data overview. Passing the path lets pandas open *and close* the file;
    # the previous read_csv(open(...)) leaked the file handle.
    train = pd.read_csv("./data/Clouds_Classify/train.csv")

    train['label'] = train['Image_Label'].apply(lambda x: x.split('_')[1])
    train['im_id'] = train['Image_Label'].apply(lambda x: x.split('_')[0])

    ENCODER = encoder
    ENCODER_WEIGHTS = 'imagenet'
    if decoder == 'unet':
        model = smp.Unet(
            encoder_name=ENCODER,
            encoder_weights=ENCODER_WEIGHTS,
            classes=4,
            activation=None,
        )
    else:
        model = smp.FPN(
            encoder_name=ENCODER,
            encoder_weights=ENCODER_WEIGHTS,
            classes=4,
            activation=None,
        )
    preprocessing_fn = smp.encoders.get_preprocessing_fn(
        ENCODER, ENCODER_WEIGHTS)

    num_workers = 4
    valid_bs = 32
    valid_dataset = CloudDataset(
        df=train,
        transforms=get_validation_augmentation(),
        datatype='valid',
        img_ids=valid_ids,
        preprocessing=get_preprocessing(preprocessing_fn))
    valid_loader = DataLoader(valid_dataset,
                              batch_size=valid_bs,
                              shuffle=False,
                              num_workers=num_workers)

    loaders = {"valid": valid_loader}
    logdir = "./logs/log_{}_{}/log_{}".format(encoder, decoder, num_split)

    valid_masks = []
    probabilities = np.zeros((len(valid_ids) * 4, 350, 525))

    ############### TTA prediction ####################
    use_TTA = True
    checkpoint_path = logdir + '/checkpoints/best.pth'
    runner_out = []
    model.load_state_dict(torch.load(checkpoint_path)['model_state_dict'])

    if use_TTA:
        transforms = tta.Compose([
            tta.HorizontalFlip(),
            tta.VerticalFlip(),
            tta.Scale(scales=[5 / 6, 1, 7 / 6]),
        ])
        tta_model = tta.SegmentationTTAWrapper(model,
                                               transforms,
                                               merge_mode='mean')
    else:
        tta_model = model

    tta_model = tta_model.cuda()
    tta_model.eval()

    with torch.no_grad():
        for img, _ in tqdm.tqdm(loaders['valid']):
            batch_preds = tta_model(img.cuda()).cpu().numpy()
            runner_out.extend(batch_preds)
    runner_out = np.array(runner_out)
    ###################### END ##########################

    for i, ((_, mask),
            output) in enumerate(tqdm.tqdm(zip(valid_dataset, runner_out))):
        for m in mask:
            if m.shape != (350, 525):
                m = cv2.resize(m,
                               dsize=(525, 350),
                               interpolation=cv2.INTER_LINEAR)
            valid_masks.append(m)

        for j, probability in enumerate(output):
            if probability.shape != (350, 525):
                probability = cv2.resize(probability,
                                         dsize=(525, 350),
                                         interpolation=cv2.INTER_LINEAR)
            slot = i * 4 + j
            probabilities[slot, :, :] = probability
            # Apply the sigmoid once here (on the stored float64 values)
            # instead of re-computing it for every grid point below.
            probabilities[slot, :, :] = sigmoid(probabilities[slot])

    # Find optimal values.
    # Per class: [[threshold start, stop) in percent, [min size start, stop)].
    print('searching for optimal param...')
    params_0 = [[35, 76], [12000, 19001]]
    params_1 = [[35, 76], [12000, 19001]]
    params_2 = [[35, 76], [12000, 19001]]
    params_3 = [[35, 76], [8000, 15001]]
    param = [params_0, params_1, params_2, params_3]

    for class_id in range(4):
        par = param[class_id]
        class_masks = valid_masks[class_id::4]
        attempts = []
        for t in range(par[0][0], par[0][1], 5):
            t /= 100
            for ms in range(par[1][0], par[1][1], 2000):
                masks = []
                print('==> searching [class_id:%d threshold:%.3f ms:%d]' %
                      (class_id, t, ms))
                for i in tqdm.tqdm(range(class_id, len(probabilities), 4)):
                    predict, _ = post_process(probabilities[i], t, ms)
                    masks.append(predict)

                d = []
                for pred_mask, true_mask in zip(masks, class_masks):
                    # Two empty masks agree perfectly; dice() is skipped for
                    # that case.
                    if pred_mask.sum() == 0 and true_mask.sum() == 0:
                        d.append(1)
                    else:
                        d.append(dice(pred_mask, true_mask))

                attempts.append((t, ms, np.mean(d)))

        attempts_df = pd.DataFrame(attempts,
                                   columns=['threshold', 'size', 'dice'])

        attempts_df = attempts_df.sort_values('dice', ascending=False)
        out_path = './params/{}_{}_par/params_{}/tta_params_{}.csv'.format(
            encoder, decoder, num_split, class_id)
        # to_csv does not create missing directories.
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        attempts_df.to_csv(out_path,
                           columns=['threshold', 'size', 'dice'],
                           index=False)
def main():
    """Predict a mask for every test image with a TTA-wrapped model and save it.

    For each image the prediction is scaled to 0..255, binarised, shown next
    to the ground truth via ``visualize`` and written to ``./output/`` at the
    original image resolution.
    """
    DATA_DIR = './input/test_images_png/'
    output_dir = './output/'
    x_valid_dir = DATA_DIR
    y_valid_dir = DATA_DIR

    # NOTE(review): the encoder below is 'inceptionv4' while the checkpoint
    # file name mentions resnet18 — confirm the preprocessing matches.
    ENCODER = 'inceptionv4'
    ENCODER_WEIGHTS = 'imagenet'
    CLASSES = ['coastline']
    DEVICE = 'cuda'

    preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)

    # Load the best saved checkpoint (a whole pickled model) onto DEVICE.
    # NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True,
    # which rejects whole-model pickles; pass weights_only=False there.
    best_model = torch.load('./best_model_Unet_resnet18.pth',
                            map_location=DEVICE)
    # Alternative: tta.aliases.d4_transform() with merge_mode='mean'.
    transforms = tta.Compose(
        [
            tta.HorizontalFlip(),
            tta.Rotate90(angles=[0, 180])
        ]
    )
    tta_model = tta.SegmentationTTAWrapper(best_model, transforms)
    # Inference only: make dropout / batch-norm layers deterministic.
    tta_model.eval()

    # Test dataset with preprocessing, used as model input.
    test_dataset = Dataset(
        x_valid_dir,  # x_test_dir
        y_valid_dir,  # y_test_dir
        augmentation=get_validation_augmentation(),
        preprocessing=get_preprocessing(preprocessing_fn),
        classes=CLASSES,
    )

    # Same dataset without preprocessing, used only for visualization.
    test_dataset_vis2 = Dataset(
        x_valid_dir, y_valid_dir,  # x_test_dir, y_test_dir,
        augmentation=get_validation_augmentation(),
        classes=CLASSES,
    )

    # cv.imwrite returns False instead of raising when the target directory
    # is missing, so make sure it exists up front.
    os.makedirs(output_dir, exist_ok=True)

    for i in range(len(test_dataset)):
        image_vis = test_dataset_vis2[i][0].astype('uint8')
        image, gt_mask = test_dataset[i]
        file_name = test_dataset.images_fps[i]
        base_name = os.path.basename(file_name)
        gt_mask = gt_mask.squeeze()

        x_tensor = torch.from_numpy(image).to(DEVICE).unsqueeze(0)
        # Call the module (not .forward) so hooks run, and skip autograd
        # bookkeeping since no gradients are needed.
        with torch.no_grad():
            pr_mask = tta_model(x_tensor)
        pr_mask = pr_mask.squeeze().to('cpu').detach().numpy().copy()
        pr_mask = (pr_mask * 255).astype(np.uint8)
        # Every pixel with a scaled value above 1 becomes foreground (255).
        _, pr_mask = cv.threshold(pr_mask, 1, 255, cv.THRESH_BINARY)

        visualize(
            image=image_vis,
            ground_truth_mask=gt_mask,
            predicted_mask=pr_mask
        )

        # Resize the mask back to the size of the original image on disk.
        org_image = cv.imread(file_name)
        if org_image is None:
            # cv.imread signals failure by returning None.
            raise FileNotFoundError('cannot read image: %s' % file_name)
        h = org_image.shape[0]
        w = org_image.shape[1]
        pr_mask = cv.resize(pr_mask, (w, h))
        cv.imwrite(os.path.join(output_dir, base_name), pr_mask)
Exemple #27
0
def six_crop_transform(crop_height, crop_width):
    """Compose a horizontal flip with ``ThreeCrops(crop_height, crop_width)``.

    Presumably two flip states times three crops give the six views the name
    refers to — confirm against ``ThreeCrops``.
    """
    steps = [
        tta.HorizontalFlip(),
        ThreeCrops(crop_height, crop_width),
    ]
    return Compose(steps)
def testing(num_split, class_params, encoder, decoder):
    """
    Run TTA inference on the test set and write the submission CSV.

    Args:
        num_split: split index; selects the checkpoint directory and is part
            of the output file name.
        class_params: per-class ``(threshold, min_size)`` pairs, indexed by
            class id, passed to ``post_process``.
        encoder: encoder name passed to segmentation_models_pytorch.
        decoder: 'unet' builds ``smp.Unet``; anything else builds ``smp.FPN``.
    """
    import gc
    torch.cuda.empty_cache()
    gc.collect()

    # Passing the path lets pandas open and close the file itself; the
    # previous read_csv(open(...)) leaked the file handle.
    sub = pd.read_csv("./data/Clouds_Classify/sample_submission.csv")

    sub['label'] = sub['Image_Label'].apply(lambda x: x.split('_')[1])
    sub['im_id'] = sub['Image_Label'].apply(lambda x: x.split('_')[0])

    preprocessing_fn = smp.encoders.get_preprocessing_fn(encoder, 'imagenet')
    if decoder == 'unet':
        model = smp.Unet(
            encoder_name=encoder,
            encoder_weights='imagenet',
            classes=4,
            activation=None,
        )
    else:
        model = smp.FPN(
            encoder_name=encoder,
            encoder_weights='imagenet',
            classes=4,
            activation=None,
        )

    # Take the image order from the submission file itself. The predictions
    # are written back positionally below, so they must follow the row order
    # of `sub`; os.listdir() returns entries in arbitrary order and could
    # silently pair masks with the wrong Image_Label rows.
    test_ids = sub['im_id'].drop_duplicates().tolist()

    test_dataset = CloudDataset(
        df=sub,
        transforms=get_validation_augmentation(),
        datatype='test',
        img_ids=test_ids,
        preprocessing=get_preprocessing(preprocessing_fn))
    test_loader = DataLoader(test_dataset,
                             batch_size=1,
                             shuffle=False,
                             num_workers=2)

    loaders = {"test": test_loader}
    logdir = "./logs/log_{}_{}/log_{}".format(encoder, decoder, num_split)

    encoded_pixels = []

    ############### Predict with PyTorch TTA ####################
    use_TTA = True
    checkpoint_path = logdir + '/checkpoints/best.pth'
    model.load_state_dict(torch.load(checkpoint_path)['model_state_dict'])
    if use_TTA:
        transforms = tta.Compose([
            tta.HorizontalFlip(),
            tta.VerticalFlip(),
            tta.Scale(scales=[5 / 6, 1, 7 / 6]),
        ])
        tta_model = tta.SegmentationTTAWrapper(model,
                                               transforms,
                                               merge_mode='mean')
    else:
        tta_model = model

    tta_model = tta_model.cuda()
    tta_model.eval()

    with torch.no_grad():
        for img, _ in tqdm.tqdm(loaders['test']):
            batch_preds = tta_model(img.cuda()).cpu().numpy()
            # Post-process each batch right away instead of first stacking
            # the raw outputs of the whole test set in memory.
            for output in batch_preds:
                for j, class_logits in enumerate(output):
                    if class_logits.shape != (350, 525):
                        class_logits = cv2.resize(
                            class_logits,
                            dsize=(525, 350),
                            interpolation=cv2.INTER_LINEAR)
                    probability = sigmoid(class_logits)
                    predict, num_predict = post_process(
                        probability, class_params[j][0], class_params[j][1])

                    if num_predict == 0:
                        encoded_pixels.append('')
                    else:
                        encoded_pixels.append(mask2rle(predict))

    sub['EncodedPixels'] = encoded_pixels
    out_path = './sub/{}_{}/tta_submission_{}.csv'.format(
        encoder, decoder, num_split)
    # to_csv does not create missing directories.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    sub.to_csv(out_path,
               columns=['Image_Label', 'EncodedPixels'],
               index=False)
    # NOTE(review): this is a headerless fragment of another snippet; the
    # enclosing function and the names it reads (fold_flag, csv_file,
    # img_path, mask_path, weight_c1, c1_tta, device, ...) are not visible.
    # Collect the test image paths (and mask paths when mask_path is given),
    # either from the fold split in the CSV file or from the image folder.
    if fold_flag:
        _, test = get_fold_filelist(csv_file, K=fold_K, fold=fold_index)
        test_img_list = [img_path+sep+i[0] for i in test]
        if mask_path is not None:
            test_mask_list = [mask_path+sep+i[0] for i in test]
    else:
        test_img_list = get_filelist_frompath(img_path,'PNG')
        if mask_path is not None:
            test_mask_list = [mask_path + sep + i.split(sep)[-1] for i in test_img_list]

    # Build the two models
    with torch.no_grad():
        # TTA settings
        tta_trans = tta.Compose([
            tta.VerticalFlip(),
            tta.HorizontalFlip(),
            tta.Rotate90(angles=[0,180]),
        ])
        # Build the models
        # cascade1
        model_cascade1 = smp.DeepLabV3Plus(encoder_name="efficientnet-b6", encoder_weights=None, in_channels=1, classes=1)
        model_cascade1.to(device)
        model_cascade1.load_state_dict(torch.load(weight_c1))
        if c1_tta:
            model_cascade1 = tta.SegmentationTTAWrapper(model_cascade1, tta_trans,merge_mode='mean')
        model_cascade1.eval()
        # cascade2
        model_cascade2 = smp.DeepLabV3Plus(encoder_name="efficientnet-b6", encoder_weights=None, in_channels=1, classes=1)
        # model_cascade2 = smp.Unet(encoder_name="efficientnet-b6", encoder_weights=None, in_channels=1, classes=1, encoder_depth=5, decoder_attention_type='scse')
        # model_cascade2 = smp.PAN(encoder_name="efficientnet-b6",encoder_weights='imagenet',	in_channels=1, classes=1)
        model_cascade2.to(device)
Exemple #30
0
def _str2bool(value):
    """Parse a command-line boolean such as 'True', 'false', '1' or '0'."""
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value, got %r' % (value, ))


def _wrap_with_flip_rot_tta(model, merge_mode):
    """Wrap ``model`` with flip (H/V) and 0/180 degree rotation TTA."""
    return tta.SegmentationTTAWrapper(
        model,
        tta.Compose([
            tta.HorizontalFlip(),
            tta.VerticalFlip(),
            tta.Rotate90(angles=[0, 180])
        ]),
        merge_mode=merge_mode)


def main():
    """Command-line entry point: optional threshold search, then test inference.

    Loads the checkpoint given by ``--loc``, optionally grid-searches the
    per-class ``(threshold, min_size)`` post-processing parameters on the
    validation split (``--optimize``), predicts the test set (with TTA when
    ``--tta_post`` is set) and writes the submission CSV to ``--name``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--encoder', type=str, default='efficientnet-b0')
    parser.add_argument('--model', type=str, default='unet')
    # --loc / --name have no usable default: without --name the final
    # to_csv(None) would just return a string and write nothing.
    parser.add_argument('--loc', type=str, required=True)
    parser.add_argument('--data_folder', type=str, default='../input/')
    parser.add_argument('--batch_size', type=int, default=2)
    # type=bool would turn any non-empty string (including "False") into
    # True, so booleans are parsed explicitly.
    parser.add_argument('--optimize', type=_str2bool, default=False)
    parser.add_argument('--tta_pre', type=_str2bool, default=False)
    parser.add_argument('--tta_post', type=_str2bool, default=False)
    parser.add_argument('--merge', type=str, default='mean')
    parser.add_argument('--min_size', type=int, default=10000)
    parser.add_argument('--thresh', type=float, default=0.5)
    parser.add_argument('--name', type=str, required=True)

    args = parser.parse_args()
    encoder = args.encoder
    loc = args.loc
    data_folder = args.data_folder
    bs = args.batch_size
    optimize = args.optimize
    tta_pre = args.tta_pre
    tta_post = args.tta_post
    merge = args.merge
    min_size = args.min_size
    thresh = args.thresh
    name = args.name

    architectures = {
        'unet': smp.Unet,
        'fpn': smp.FPN,
        'pspnet': smp.PSPNet,
        'linknet': smp.Linknet,
    }
    if args.model not in architectures:
        # Previously an unknown name left a plain string in `model`, which
        # only failed later at load_state_dict with an obscure error.
        parser.error('unknown --model %r; choose from %s' %
                     (args.model, ', '.join(sorted(architectures))))
    model = architectures[args.model](
        encoder_name=encoder,
        encoder_weights='imagenet',
        classes=4,
        activation=None,
    )

    preprocessing_fn = smp.encoders.get_preprocessing_fn(encoder, 'imagenet')

    test_df = get_dataset(train=False)
    test_df = prepare_dataset(test_df)
    test_ids = test_df['Image_Label'].apply(
        lambda x: x.split('_')[0]).drop_duplicates().values
    test_dataset = CloudDataset(
        df=test_df,
        datatype='test',
        img_ids=test_ids,
        transforms=valid1(),
        preprocessing=get_preprocessing(preprocessing_fn))
    test_loader = DataLoader(test_dataset, batch_size=bs, shuffle=False)

    val_df = get_dataset(train=True)
    val_df = prepare_dataset(val_df)
    _, val_ids = get_train_test(val_df)
    valid_dataset = CloudDataset(
        df=val_df,
        datatype='train',
        img_ids=val_ids,
        transforms=valid1(),
        preprocessing=get_preprocessing(preprocessing_fn))
    valid_loader = DataLoader(valid_dataset, batch_size=bs, shuffle=False)

    model.load_state_dict(torch.load(loc)['model_state_dict'])

    # Default post-processing parameters, replaced per class when optimizing.
    class_params = {
        0: (thresh, min_size),
        1: (thresh, min_size),
        2: (thresh, min_size),
        3: (thresh, min_size)
    }

    if optimize:
        print("OPTIMIZING")
        print(tta_pre)
        if tta_pre:
            opt_model = _wrap_with_flip_rot_tta(model, merge)
        else:
            opt_model = model
        tta_runner = SupervisedRunner()
        print("INFERRING ON VALID")
        tta_runner.infer(
            model=opt_model,
            loaders={'valid': valid_loader},
            callbacks=[InferCallback()],
            verbose=True,
        )

        valid_masks = []
        probabilities = np.zeros((4 * len(valid_dataset), 350, 525))
        for i, (batch, output) in enumerate(
                tqdm(
                    zip(valid_dataset,
                        tta_runner.callbacks[0].predictions["logits"]))):
            _, mask = batch
            for m in mask:
                if m.shape != (350, 525):
                    m = cv2.resize(m,
                                   dsize=(525, 350),
                                   interpolation=cv2.INTER_LINEAR)
                valid_masks.append(m)

            for j, probability in enumerate(output):
                if probability.shape != (350, 525):
                    probability = cv2.resize(probability,
                                             dsize=(525, 350),
                                             interpolation=cv2.INTER_LINEAR)
                slot = (i * 4) + j
                probabilities[slot, :, :] = probability
                # Apply the sigmoid once here rather than once per grid
                # point in the search below.
                probabilities[slot, :, :] = sigmoid(probabilities[slot])

        print("RUNNING GRID SEARCH")
        for class_id in range(4):
            print(class_id)
            class_masks = valid_masks[class_id::4]
            attempts = []
            for t in range(30, 70, 5):
                t /= 100
                # The last value used to be 175000, which breaks the
                # 2500-step progression and is almost the whole 350x525
                # image; it is treated as a typo for 17500.
                for ms in [7500, 10000, 12500, 15000, 17500]:
                    masks = []
                    for i in range(class_id, len(probabilities), 4):
                        predict, num_predict = post_process(
                            probabilities[i], t, ms)
                        masks.append(predict)

                    d = []
                    for pred_mask, true_mask in zip(masks, class_masks):
                        # Two empty masks agree perfectly; dice() is skipped
                        # for that case.
                        if pred_mask.sum() == 0 and true_mask.sum() == 0:
                            d.append(1)
                        else:
                            d.append(dice(pred_mask, true_mask))

                    attempts.append((t, ms, np.mean(d)))

            attempts_df = pd.DataFrame(attempts,
                                       columns=['threshold', 'size', 'dice'])

            attempts_df = attempts_df.sort_values('dice', ascending=False)
            print(attempts_df.head())
            best_threshold = attempts_df['threshold'].values[0]
            best_size = attempts_df['size'].values[0]

            class_params[class_id] = (best_threshold, best_size)

        # Free the validation artefacts before test inference.
        del opt_model
        del tta_runner
        del valid_masks
        del probabilities
    gc.collect()

    if tta_post:
        model = _wrap_with_flip_rot_tta(model, merge)
    print(tta_post)

    runner = SupervisedRunner()
    runner.infer(
        model=model,
        loaders={'test': test_loader},
        callbacks=[InferCallback()],
        verbose=True,
    )

    encoded_pixels = []

    for image in tqdm(runner.callbacks[0].predictions['logits']):
        # One output channel per class; the channel index selects the
        # parameters (the inner loop used to shadow the outer index `i`).
        for class_id, prob in enumerate(image):
            if prob.shape != (350, 525):
                prob = cv2.resize(prob,
                                  dsize=(525, 350),
                                  interpolation=cv2.INTER_LINEAR)
            predict, num_predict = post_process(sigmoid(prob),
                                                class_params[class_id][0],
                                                class_params[class_id][1])
            if num_predict == 0:
                encoded_pixels.append('')
            else:
                encoded_pixels.append(mask2rle(predict))

    test_df['EncodedPixels'] = encoded_pixels
    test_df.to_csv(name, columns=['Image_Label', 'EncodedPixels'], index=False)