Example #1
def get_predictions(model_chosen, tta=False):

    model_chosen.cuda().eval()
    actual_values, predicted_values = [], []

    if tta:
        transformation = ttach.Compose(
            [
                ttach.HorizontalFlip(),
                ttach.VerticalFlip(),
                ttach.Rotate90(angles=[0, 90, 180, 270])
            ]
        )
        test_time_augmentation_wrapper = ttach.ClassificationTTAWrapper(model_chosen, transformation)
        with torch.no_grad():
            for batch in loader_of_test:
                test_image, test_label = batch
                predicted_value = test_time_augmentation_wrapper(test_image.cuda())
                predicted_value = torch.argmax(predicted_value, dim=1).detach().cpu().numpy()
                actual_values.append(test_label.cpu().numpy())
                predicted_values.append(predicted_value)
    else:
        with torch.no_grad():
            for batch in loader_of_test:
                test_image, test_label = batch
                predicted_value = model_chosen(test_image.cuda())
                predicted_value = torch.argmax(predicted_value, dim=1).detach().cpu().numpy()
                actual_values.append(test_label.cpu().numpy())
                predicted_values.append(predicted_value)

    return predicted_values
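
A minimal usage sketch (not part of the original example): the function returns one NumPy array per batch, so the results need to be concatenated before scoring. `model` and `loader_of_test` are assumed to be defined as in the snippet above.

import numpy as np

# Flatten the per-batch prediction arrays into a single vector of class indices.
batched_preds = get_predictions(model, tta=True)
all_preds = np.concatenate(batched_preds)  # shape: (num_test_samples,)
print(all_preds.shape)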
Example #2
def SUE_TTA(model, batch: torch.Tensor, last_layer: bool) -> Tuple[np.ndarray, np.ndarray]:
    r"""Interface of the Binary Segmentation Uncertainty Estimation with Test-Time Augmentations (TTA) method for a single 2D slice.

    Inputs are supposed to be in the range [0, data_range].

    Args:
        model: Trained model.
        batch: Tensor with shape (1, C, H, W).
        last_layer: Flag indicating whether Sigmoid is the last layer of the network.
    Returns:
        Aleatoric and epistemic uncertainty maps with the same shape as the batch.
    """
    model.eval()
    transforms = tta.Compose(
        [
            tta.VerticalFlip(),
            tta.HorizontalFlip(),
            tta.Rotate90(angles=[0, 180]),
            tta.Scale(scales=[1, 2, 4]),
            tta.Multiply(factors=[0.9, 1, 1.1]),
        ]
    )
    predicted = []
    for transformer in transforms:
        augmented_image = transformer.augment_image(batch)
        model_output = model(augmented_image)
        deaug_mask = transformer.deaugment_mask(model_output)
        prediction = (torch.sigmoid(deaug_mask) if last_layer else deaug_mask).cpu().detach().numpy()
        predicted.append(prediction)

    p_hat = np.array(predicted)
    aleatoric = calc_aleatoric(p_hat)
    epistemic = calc_epistemic(p_hat)

    return aleatoric, epistemic
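
`calc_aleatoric` and `calc_epistemic` are not shown in this example. A common way to compute them from the stacked TTA predictions `p_hat` (shape `(T,) + batch.shape`) is the decomposition below; this is an assumption about what those helpers do, not code from the original repository.

import numpy as np

def calc_aleatoric(p_hat: np.ndarray) -> np.ndarray:
    # Expected per-pixel Bernoulli variance over the T augmented predictions.
    return np.mean(p_hat * (1.0 - p_hat), axis=0)

def calc_epistemic(p_hat: np.ndarray) -> np.ndarray:
    # Variance of the predicted probabilities across the T augmentations.
    return np.mean(p_hat ** 2, axis=0) - np.mean(p_hat, axis=0) ** 2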
Example #3
def evaluate(val_loader, model, loss_fn, device, use_tta=False):
    model.eval()

    if use_tta:
        transformations = tta.Compose([
            tta.Rotate90(angles=[0, 90, 180, 270]),
            tta.HorizontalFlip(),
            tta.VerticalFlip()
        ])

        tta_model = tta.ClassificationTTAWrapper(model, transformations)

    correct = 0
    total = 0
    total_loss = 0
    for i, batch in enumerate(val_loader):
        input_data, labels = batch
        input_data, labels = input_data.to(device), labels.to(device)
        with torch.no_grad():
            if use_tta:
                predictions = tta_model(input_data)
            else:
                predictions = model(input_data)
            total_loss += loss_fn(predictions, labels).item()
            correct += (predictions.argmax(axis=1) == labels).sum().item()
            total += len(labels)
            torch.cuda.empty_cache()

    model.train()
    return total_loss / total, correct / total
    def build_tta_model(self, model, config, device):
        tta_model = getattr(tta, config["tta"])(
            model,
            tta.Compose([tta.HorizontalFlip(),
                         tta.VerticalFlip()]),
            merge_mode="mean",
        )
        tta_model.to(device)
        return tta_model
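
Since `config["tta"]` is resolved by name on the `ttach` module, it has to be the name of a wrapper class such as `SegmentationTTAWrapper` or `ClassificationTTAWrapper`. The sketch below (hypothetical config and a stand-in model) shows the equivalent of what `build_tta_model` does, written without the surrounding class.

import torch
import ttach as tta

config = {"tta": "SegmentationTTAWrapper"}  # hypothetical config entry
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.nn.Conv2d(3, 1, kernel_size=3, padding=1)  # stand-in segmentation model

wrapper_cls = getattr(tta, config["tta"])  # resolve the wrapper class by name
tta_model = wrapper_cls(
    model,
    tta.Compose([tta.HorizontalFlip(), tta.VerticalFlip()]),
    merge_mode="mean",
).to(device)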
Example #5
def predict(model_path, test_loader, saveFileName, iftta):

    ## predict
    model = initialize_model(num_classes=176)

    # create model and load weights from checkpoint
    model = model.to(device)
    model.load_state_dict(torch.load(model_path))

    if iftta:
        print("Using TTA")
        transforms = tta.Compose(
            [
                tta.HorizontalFlip(),
                tta.VerticalFlip(),
                tta.Rotate90(angles=[0, 180]),
                # tta.Scale(scales=[1, 0.3]), 
            ]
        )
        model = tta.ClassificationTTAWrapper(model, transforms)

    # Make sure the model is in eval mode.
    # Some modules like Dropout or BatchNorm affect if the model is in training mode.
    model.eval()
    
    # Initialize a list to store the predictions.
    predictions = []
    # Iterate the testing set by batches.
    for batch in tqdm(test_loader):

        imgs = batch
        with torch.no_grad():
            logits = model(imgs.to(device))
        
        # Take the class with greatest logit as prediction and record it.
        predictions.extend(logits.argmax(dim=-1).cpu().numpy().tolist())

    preds = []
    for i in predictions:
        preds.append(num_to_class[i])

    test_data = pd.read_csv('leaves_data/test.csv')
    test_data['label'] = pd.Series(preds)
    submission = pd.concat([test_data['image'], test_data['label']], axis=1)
    submission.to_csv(saveFileName, index=False)
    print("Done!!!!!!!!!!!!!!!!!!!!!!!!!!!")
Example #6
def test_compose_1():
    transform = tta.Compose([
        tta.HorizontalFlip(),
        tta.VerticalFlip(),
        tta.Rotate90(angles=[0, 90, 180, 270]),
        tta.Scale(scales=[1, 2, 4], interpolation="nearest"),
    ])

    assert len(transform) == 2 * 2 * 4 * 3  # all combinations for aug parameters

    dummy_label = torch.ones(2).reshape(2, 1).float()
    dummy_image = torch.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5).float()
    dummy_model = lambda x: {"label": dummy_label, "mask": x}

    for augmenter in transform:
        augmented_image = augmenter.augment_image(dummy_image)
        model_output = dummy_model(augmented_image)
        deaugmented_mask = augmenter.deaugment_mask(model_output["mask"])
        deaugmented_label = augmenter.deaugment_label(model_output["label"])
        assert torch.allclose(deaugmented_mask, dummy_image)
        assert torch.allclose(deaugmented_label, dummy_label)
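
For reference, a small sketch of what the `ttach` wrappers do with mean merging: run the model on every augmented view, de-augment each output, and average. The identity model is only for illustration.

import torch
import ttach as tta

transform = tta.Compose([tta.HorizontalFlip(), tta.VerticalFlip()])
identity_model = lambda x: x  # stand-in "segmentation" model
image = torch.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5).float()

deaugmented = []
for augmenter in transform:
    augmented = augmenter.augment_image(image)
    output = identity_model(augmented)
    deaugmented.append(augmenter.deaugment_mask(output))

merged = torch.stack(deaugmented).mean(dim=0)  # "mean" merge
assert torch.allclose(merged, image)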
Example #7
    def test_tta(self, mode='train', unet_path=None):
        """Test model & Calculate performances."""
        print(char_color('@,,@   %s with TTA' % (mode)))
        if unet_path is not None:
            if os.path.isfile(unet_path):
                checkpoint = torch.load(unet_path)
                self.unet.load_state_dict(checkpoint['state_dict'])
                self.myprint('Successfully Loaded from %s' % (unet_path))

        self.unet.train(False)
        self.unet.eval()

        if mode == 'train':
            data_loader = self.train_loader
        elif mode == 'test':
            data_loader = self.test_loader
        elif mode == 'valid':
            data_loader = self.valid_loader

        acc = 0.  # Accuracy
        SE = 0.  # Sensitivity (Recall)
        SP = 0.  # Specificity
        PC = 0.  # Precision
        DC = 0.  # Dice Coefficient
        IOU = 0.  # IOU
        length = 0

        # model prediction for each image
        detail_result = []  # detail_result = [id, acc, SE, SP, PC, dsc, IOU]
        with torch.no_grad():
            for i, sample in enumerate(data_loader):
                (image_paths, images, GT) = sample
                images_path = list(image_paths)
                images = images.to(self.device)
                GT = GT.to(self.device)

                tta_trans = tta.Compose([
                    tta.VerticalFlip(),
                    tta.HorizontalFlip(),
                    tta.Rotate90(angles=[0, 180])
                ])

                tta_model = tta.SegmentationTTAWrapper(self.unet, tta_trans)
                SR = tta_model(images)

                # SR = self.unet(images)
                SR = torch.sigmoid(SR)

                if self.save_image:
                    images_all = torch.cat((images, SR, GT), 0)
                    torchvision.utils.save_image(
                        images_all.data.cpu(),
                        os.path.join(self.result_path, 'images',
                                     '%s_%d_image.png' % (mode, i)),
                        nrow=self.batch_size)

                SR = SR.data.cpu().numpy()
                GT = GT.data.cpu().numpy()

                for ii in range(SR.shape[0]):
                    SR_tmp = SR[ii, :].reshape(-1)
                    GT_tmp = GT[ii, :].reshape(-1)
                    tmp_index = images_path[ii].split(sep)[-1]
                    tmp_index = int(tmp_index.split('.')[0])

                    SR_tmp = torch.from_numpy(SR_tmp).to(self.device)
                    GT_tmp = torch.from_numpy(GT_tmp).to(self.device)

                    result_tmp = np.array([
                        tmp_index,
                        get_accuracy(SR_tmp, GT_tmp),
                        get_sensitivity(SR_tmp, GT_tmp),
                        get_specificity(SR_tmp, GT_tmp),
                        get_precision(SR_tmp, GT_tmp),
                        get_DC(SR_tmp, GT_tmp),
                        get_IOU(SR_tmp, GT_tmp)
                    ])

                    acc += result_tmp[1]
                    SE += result_tmp[2]
                    SP += result_tmp[3]
                    PC += result_tmp[4]
                    DC += result_tmp[5]
                    IOU += result_tmp[6]
                    detail_result.append(result_tmp)

                    length += 1

        accuracy = acc / length
        sensitivity = SE / length
        specificity = SP / length
        precision = PC / length
        disc = DC / length
        iou = IOU / length
        detail_result = np.array(detail_result)

        if self.save_detail_result:  # detail_result = [id, acc, SE, SP, PC, dsc, IOU]
            excel_save_path = os.path.join(self.result_path,
                                           mode + '_pre_detial_result.xlsx')
            writer = pd.ExcelWriter(excel_save_path)
            detail_result = pd.DataFrame(detail_result)
            detail_result.to_excel(writer, mode, float_format='%.5f')
            writer.save()
            writer.close()

        return accuracy, sensitivity, specificity, precision, disc, iou
Example #8
def testing(num_split, class_params, encoder, decoder):
    """
    测试推理
    """
    import gc
    torch.cuda.empty_cache()
    gc.collect()

    sub = "./data/Clouds_Classify/sample_submission.csv"
    sub = pd.read_csv(sub)
    sub.head()

    sub['label'] = sub['Image_Label'].apply(lambda x: x.split('_')[1])
    sub['im_id'] = sub['Image_Label'].apply(lambda x: x.split('_')[0])

    preprocessing_fn = smp.encoders.get_preprocessing_fn(encoder, 'imagenet')
    if decoder == 'unet':
        model = smp.Unet(
            encoder_name=encoder,
            encoder_weights='imagenet',
            classes=4,
            activation=None,
        )
    else:
        model = smp.FPN(
            encoder_name=encoder,
            encoder_weights='imagenet',
            classes=4,
            activation=None,
        )
    test_ids = os.listdir(test_imgs_folder)

    test_dataset = CloudDataset(
        df=sub,
        transforms=get_validation_augmentation(),
        datatype='test',
        img_ids=test_ids,
        preprocessing=get_preprocessing(preprocessing_fn))
    test_loader = DataLoader(test_dataset,
                             batch_size=1,
                             shuffle=False,
                             num_workers=2)

    loaders = {"test": test_loader}
    logdir = "./logs/log_{}_{}/log_{}".format(encoder, decoder, num_split)

    encoded_pixels = []

    ############### Predict with PyTorch TTA ####################
    use_TTA = True
    checkpoint_path = logdir + '/checkpoints/best.pth'
    runner_out = []
    model.load_state_dict(torch.load(checkpoint_path)['model_state_dict'])
    # Predict with TTA
    if use_TTA:
        transforms = tta.Compose([
            tta.HorizontalFlip(),
            tta.VerticalFlip(),
            tta.Scale(scales=[5 / 6, 1, 7 / 6]),
        ])
        tta_model = tta.SegmentationTTAWrapper(model,
                                               transforms,
                                               merge_mode='mean')
    else:
        tta_model = model

    tta_model = tta_model.cuda()
    tta_model.eval()

    with torch.no_grad():
        for i, data in enumerate(tqdm.tqdm(loaders['test'])):
            img, _ = data
            img = img.cuda()
            batch_preds = tta_model(img).cpu().numpy()
            runner_out.extend(batch_preds)
    runner_out = np.array(runner_out)

    for i, output in tqdm.tqdm(enumerate(runner_out)):
        for j, probability in enumerate(output):
            if probability.shape != (350, 525):
                probability = cv2.resize(probability,
                                         dsize=(525, 350),
                                         interpolation=cv2.INTER_LINEAR)
            logit = sigmoid(probability)
            predict, num_predict = post_process(logit, class_params[j][0],
                                                class_params[j][1])

            if num_predict == 0:
                encoded_pixels.append('')
            else:
                r = mask2rle(predict)
                encoded_pixels.append(r)

    sub['EncodedPixels'] = encoded_pixels
    sub.to_csv('./sub/{}_{}/tta_submission_{}.csv'.format(
        encoder, decoder, num_split),
               columns=['Image_Label', 'EncodedPixels'],
               index=False)
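
`sigmoid` and `post_process` are project helpers that are not shown here. A common implementation for this kind of pipeline (threshold the probability map, then drop connected components smaller than `min_size`) is sketched below; the actual helpers may differ.

import cv2
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def post_process(probability, threshold, min_size):
    """Binarize a probability map and keep only connected components larger than min_size."""
    mask = (probability > threshold).astype(np.uint8)
    num_components, components = cv2.connectedComponents(mask)
    predictions = np.zeros(probability.shape, np.float32)
    num_predict = 0
    for c in range(1, num_components):
        component = components == c
        if component.sum() > min_size:
            predictions[component] = 1
            num_predict += 1
    return predictions, num_predict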
Example #9
def validation(valid_ids, num_split, encoder, decoder):
    """
    模型验证,并选择后处理参数
    """
    train = "./data/Clouds_Classify/train.csv"

    # Data overview
    train = pd.read_csv(train)
    train.head()

    train['label'] = train['Image_Label'].apply(lambda x: x.split('_')[1])
    train['im_id'] = train['Image_Label'].apply(lambda x: x.split('_')[0])

    ENCODER = encoder
    ENCODER_WEIGHTS = 'imagenet'
    if decoder == 'unet':
        model = smp.Unet(
            encoder_name=ENCODER,
            encoder_weights=ENCODER_WEIGHTS,
            classes=4,
            activation=None,
        )
    else:
        model = smp.FPN(
            encoder_name=ENCODER,
            encoder_weights=ENCODER_WEIGHTS,
            classes=4,
            activation=None,
        )
    preprocessing_fn = smp.encoders.get_preprocessing_fn(
        ENCODER, ENCODER_WEIGHTS)

    num_workers = 4
    valid_bs = 32
    valid_dataset = CloudDataset(
        df=train,
        transforms=get_validation_augmentation(),
        datatype='valid',
        img_ids=valid_ids,
        preprocessing=get_preprocessing(preprocessing_fn))
    valid_loader = DataLoader(valid_dataset,
                              batch_size=valid_bs,
                              shuffle=False,
                              num_workers=num_workers)

    loaders = {"valid": valid_loader}
    logdir = "./logs/log_{}_{}/log_{}".format(encoder, decoder, num_split)

    valid_masks = []
    probabilities = np.zeros((len(valid_ids) * 4, 350, 525))

    ############### TTA prediction ####################
    use_TTA = True
    checkpoint_path = logdir + '/checkpoints/best.pth'
    runner_out = []
    model.load_state_dict(torch.load(checkpoint_path)['model_state_dict'])

    if use_TTA:
        transforms = tta.Compose([
            tta.HorizontalFlip(),
            tta.VerticalFlip(),
            tta.Scale(scales=[5 / 6, 1, 7 / 6]),
        ])
        tta_model = tta.SegmentationTTAWrapper(model,
                                               transforms,
                                               merge_mode='mean')
    else:
        tta_model = model

    tta_model = tta_model.cuda()
    tta_model.eval()

    with torch.no_grad():
        for i, data in enumerate(tqdm.tqdm(loaders['valid'])):
            img, _ = data
            img = img.cuda()
            batch_preds = tta_model(img).cpu().numpy()
            runner_out.extend(batch_preds)
    runner_out = np.array(runner_out)
    ######################END##########################

    for i, ((_, mask),
            output) in enumerate(tqdm.tqdm(zip(valid_dataset, runner_out))):
        for m in mask:
            if m.shape != (350, 525):
                m = cv2.resize(m,
                               dsize=(525, 350),
                               interpolation=cv2.INTER_LINEAR)
            valid_masks.append(m)

        for j, probability in enumerate(output):
            if probability.shape != (350, 525):
                probability = cv2.resize(probability,
                                         dsize=(525, 350),
                                         interpolation=cv2.INTER_LINEAR)
            probabilities[i * 4 + j, :, :] = probability

    # Find optimal values
    print('searching for optimal param...')
    params_0 = [[35, 76], [12000, 19001]]
    params_1 = [[35, 76], [12000, 19001]]
    params_2 = [[35, 76], [12000, 19001]]
    params_3 = [[35, 76], [8000, 15001]]
    param = [params_0, params_1, params_2, params_3]

    for class_id in range(4):
        par = param[class_id]
        attempts = []
        for t in range(par[0][0], par[0][1], 5):
            t /= 100
            for ms in range(par[1][0], par[1][1], 2000):
                masks = []
                print('==> searching [class_id:%d threshold:%.3f ms:%d]' %
                      (class_id, t, ms))
                for i in tqdm.tqdm(range(class_id, len(probabilities), 4)):
                    probability = probabilities[i]
                    predict, _ = post_process(sigmoid(probability), t, ms)
                    masks.append(predict)

                d = []
                for i, j in zip(masks, valid_masks[class_id::4]):
                    if (i.sum() == 0) & (j.sum() == 0):
                        d.append(1)
                    else:
                        d.append(dice(i, j))

                attempts.append((t, ms, np.mean(d)))

        attempts_df = pd.DataFrame(attempts,
                                   columns=['threshold', 'size', 'dice'])

        attempts_df = attempts_df.sort_values('dice', ascending=False)
        attempts_df.to_csv(
            './params/{}_{}_par/params_{}/tta_params_{}.csv'.format(
                encoder, decoder, num_split, class_id),
            columns=['threshold', 'size', 'dice'],
            index=False)
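
The `dice` helper is also defined elsewhere; a standard definition consistent with how it is used above (the caller already treats the both-empty case separately) would be:

import numpy as np

def dice(pred, target):
    """Dice coefficient between two binary masks (assumed helper)."""
    pred = np.asarray(pred, dtype=bool)
    target = np.asarray(target, dtype=bool)
    intersection = np.logical_and(pred, target).sum()
    return 2.0 * intersection / (pred.sum() + target.sum())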
Example #10
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--encoder', type=str, default='efficientnet-b0')
    parser.add_argument('--model', type=str, default='unet')
    parser.add_argument('--loc', type=str)
    parser.add_argument('--data_folder', type=str, default='../input/')
    parser.add_argument('--batch_size', type=int, default=2)
    # Boolean flags use store_true: argparse's type=bool would treat any non-empty string (even "False") as True.
    parser.add_argument('--optimize', action='store_true')
    parser.add_argument('--tta_pre', action='store_true')
    parser.add_argument('--tta_post', action='store_true')
    parser.add_argument('--merge', type=str, default='mean')
    parser.add_argument('--min_size', type=int, default=10000)
    parser.add_argument('--thresh', type=float, default=0.5)
    parser.add_argument('--name', type=str)

    args = parser.parse_args()
    encoder = args.encoder
    model = args.model
    loc = args.loc
    data_folder = args.data_folder
    bs = args.batch_size
    optimize = args.optimize
    tta_pre = args.tta_pre
    tta_post = args.tta_post
    merge = args.merge
    min_size = args.min_size
    thresh = args.thresh
    name = args.name

    if model == 'unet':
        model = smp.Unet(encoder_name=encoder,
                         encoder_weights='imagenet',
                         classes=4,
                         activation=None)
    elif model == 'fpn':
        model = smp.FPN(
            encoder_name=encoder,
            encoder_weights='imagenet',
            classes=4,
            activation=None,
        )
    elif model == 'pspnet':
        model = smp.PSPNet(
            encoder_name=encoder,
            encoder_weights='imagenet',
            classes=4,
            activation=None,
        )
    elif model == 'linknet':
        model = smp.Linknet(
            encoder_name=encoder,
            encoder_weights='imagenet',
            classes=4,
            activation=None,
        )

    preprocessing_fn = smp.encoders.get_preprocessing_fn(encoder, 'imagenet')

    test_df = get_dataset(train=False)
    test_df = prepare_dataset(test_df)
    test_ids = test_df['Image_Label'].apply(
        lambda x: x.split('_')[0]).drop_duplicates().values
    test_dataset = CloudDataset(
        df=test_df,
        datatype='test',
        img_ids=test_ids,
        transforms=valid1(),
        preprocessing=get_preprocessing(preprocessing_fn))
    test_loader = DataLoader(test_dataset, batch_size=bs, shuffle=False)

    val_df = get_dataset(train=True)
    val_df = prepare_dataset(val_df)
    _, val_ids = get_train_test(val_df)
    valid_dataset = CloudDataset(
        df=val_df,
        datatype='train',
        img_ids=val_ids,
        transforms=valid1(),
        preprocessing=get_preprocessing(preprocessing_fn))
    valid_loader = DataLoader(valid_dataset, batch_size=bs, shuffle=False)

    model.load_state_dict(torch.load(loc)['model_state_dict'])

    class_params = {
        0: (thresh, min_size),
        1: (thresh, min_size),
        2: (thresh, min_size),
        3: (thresh, min_size)
    }

    if optimize:
        print("OPTIMIZING")
        print(tta_pre)
        if tta_pre:
            opt_model = tta.SegmentationTTAWrapper(
                model,
                tta.Compose([
                    tta.HorizontalFlip(),
                    tta.VerticalFlip(),
                    tta.Rotate90(angles=[0, 180])
                ]),
                merge_mode=merge)
        else:
            opt_model = model
        tta_runner = SupervisedRunner()
        print("INFERRING ON VALID")
        tta_runner.infer(
            model=opt_model,
            loaders={'valid': valid_loader},
            callbacks=[InferCallback()],
            verbose=True,
        )

        valid_masks = []
        probabilities = np.zeros((4 * len(valid_dataset), 350, 525))
        for i, (batch, output) in enumerate(
                tqdm(
                    zip(valid_dataset,
                        tta_runner.callbacks[0].predictions["logits"]))):
            _, mask = batch
            for m in mask:
                if m.shape != (350, 525):
                    m = cv2.resize(m,
                                   dsize=(525, 350),
                                   interpolation=cv2.INTER_LINEAR)
                valid_masks.append(m)

            for j, probability in enumerate(output):
                if probability.shape != (350, 525):
                    probability = cv2.resize(probability,
                                             dsize=(525, 350),
                                             interpolation=cv2.INTER_LINEAR)
                probabilities[(i * 4) + j, :, :] = probability

        print("RUNNING GRID SEARCH")
        for class_id in range(4):
            print(class_id)
            attempts = []
            for t in range(30, 70, 5):
                t /= 100
                for ms in [7500, 10000, 12500, 15000, 17500]:
                    masks = []
                    for i in range(class_id, len(probabilities), 4):
                        probability = probabilities[i]
                        predict, num_predict = post_process(
                            sigmoid(probability), t, ms)
                        masks.append(predict)

                    d = []
                    for i, j in zip(masks, valid_masks[class_id::4]):
                        if (i.sum() == 0) & (j.sum() == 0):
                            d.append(1)
                        else:
                            d.append(dice(i, j))

                    attempts.append((t, ms, np.mean(d)))

            attempts_df = pd.DataFrame(attempts,
                                       columns=['threshold', 'size', 'dice'])

            attempts_df = attempts_df.sort_values('dice', ascending=False)
            print(attempts_df.head())
            best_threshold = attempts_df['threshold'].values[0]
            best_size = attempts_df['size'].values[0]

            class_params[class_id] = (best_threshold, best_size)

        del opt_model
        del tta_runner
        del valid_masks
        del probabilities
    gc.collect()

    if tta_post:
        model = tta.SegmentationTTAWrapper(model,
                                           tta.Compose([
                                               tta.HorizontalFlip(),
                                               tta.VerticalFlip(),
                                               tta.Rotate90(angles=[0, 180])
                                           ]),
                                           merge_mode=merge)
    print(tta_post)

    runner = SupervisedRunner()
    runner.infer(
        model=model,
        loaders={'test': test_loader},
        callbacks=[InferCallback()],
        verbose=True,
    )

    encoded_pixels = []
    image_id = 0

    for image in tqdm(runner.callbacks[0].predictions['logits']):
        for prob in image:
            if prob.shape != (350, 525):
                prob = cv2.resize(prob,
                                  dsize=(525, 350),
                                  interpolation=cv2.INTER_LINEAR)
            predict, num_predict = post_process(sigmoid(prob),
                                                class_params[image_id % 4][0],
                                                class_params[image_id % 4][1])
            if num_predict == 0:
                encoded_pixels.append('')
            else:
                r = mask2rle(predict)
                encoded_pixels.append(r)
            image_id += 1

    test_df['EncodedPixels'] = encoded_pixels
    test_df.to_csv(name, columns=['Image_Label', 'EncodedPixels'], index=False)
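
# `mask2rle` run-length encodes a binary mask for the submission file. The widely
# used column-major implementation from this competition is sketched below as an
# assumption; the repository's own helper may differ slightly.
import numpy as np

def mask2rle(img):
    """Run-length encode a binary mask in column-major order: 'start length start length ...'."""
    pixels = img.T.flatten()
    pixels = np.concatenate([[0], pixels, [0]])
    runs = np.where(pixels[1:] != pixels[:-1])[0] + 1
    runs[1::2] -= runs[::2]
    return ' '.join(str(x) for x in runs)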
def get_test_logits(encoder, decoder):
    '''
    Predict and save the test-set logits.
    '''
    sub = "./data/Clouds_Classify/sample_submission.csv"
    sub = pd.read_csv(sub)
    
    sub['label'] = sub['Image_Label'].apply(lambda x: x.split('_')[1])
    sub['im_id'] = sub['Image_Label'].apply(lambda x: x.split('_')[0])
   
    # Build the model
    preprocessing_fn = smp.encoders.get_preprocessing_fn(encoder, 'imagenet')
    if decoder == 'unet':
        model = smp.Unet(
            encoder_name=encoder, 
            encoder_weights='imagenet', 
            classes=4, 
            activation=None,
        )
    else:
        model = smp.FPN(
            encoder_name=encoder, 
            encoder_weights='imagenet', 
            classes=4, 
            activation=None,
        )
    
    # Load model weights
    logdir = "./logs/log_{}_{}/log_{}".format(encoder, 'fpn', 4)
    checkpoint_path = logdir + '/checkpoints/best.pth'
    model.load_state_dict(torch.load(checkpoint_path)['model_state_dict'])
    
    # Predict with TTA
    use_TTA = True
    if use_TTA:
        print('using TTA!!!')
        transforms = tta.Compose([
            tta.HorizontalFlip(),
            tta.VerticalFlip(),
            tta.Scale(scales=[5/6, 1, 7/6]),
            ])
        tta_model = tta.SegmentationTTAWrapper(model, transforms, merge_mode='mean')
    else:
        tta_model = model

    test_ids = os.listdir(test_imgs_folder)

    test_dataset = CloudDataset(df=sub, transforms = get_validation_augmentation(), datatype='test', img_ids=test_ids, preprocessing=get_preprocessing(preprocessing_fn))
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=2)

    tta_model = tta_model.cuda()
    tta_model.eval() 

    with torch.no_grad(): 
        for i, data in enumerate(tqdm.tqdm(test_loader)):
            img, _, img_name = data
            img = img.cuda()
            batch_preds = tta_model(img).cpu().numpy()
            batch_preds = batch_preds.astype(np.float16)


            save_dir = './logits/test/' + encoder + '_fpn' + '/log_{}'.format(4)
            if not os.path.exists(save_dir):
                os.makedirs(save_dir)

            # Save as pickle
            file_name = img_name[0].split('.')[0] + '.plk'
            file_path = os.path.join(save_dir, file_name)

            with open(file_path, 'wb') as wf:
                plk.dump(batch_preds, wf)
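
# A small sketch of reading one of the saved logit files back, assuming `plk`
# above is an alias for the standard `pickle` module; the encoder name in the
# path is only a placeholder.
import os
import pickle

save_dir = './logits/test/efficientnet-b0_fpn/log_4'  # hypothetical encoder name, layout as above
file_name = os.listdir(save_dir)[0]
with open(os.path.join(save_dir, file_name), 'rb') as rf:
    logits = pickle.load(rf)  # float16 array of shape (1, 4, H, W)
print(file_name, logits.shape)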
def main():
    fold_path = args.fold_path
    fold_num = args.fold_num
    model_name = args.model_name
    train_csv = args.train_csv
    sub_csv = args.sub_csv
    encoder = args.encoder
    num_workers = args.num_workers
    batch_size = args.batch_size
    log_path = args.log_path
    is_tta = args.is_tta
    test_batch_size = args.test_batch_size
    attention_type = args.attention_type
    print(log_path)

    train = pd.read_csv(train_csv)
    train['label'] = train['Image_Label'].apply(lambda x: x.split('_')[-1])
    train['im_id'] = train['Image_Label'].apply(lambda x: x.replace('_' + x.split('_')[-1], ''))

    val_fold = pd.read_csv(f'{fold_path}/valid_file_fold_{fold_num}.csv')
    valid_ids = np.array(val_fold.file_name)

    attention_type = None if attention_type == 'None' else attention_type

    encoder_weights = 'imagenet'

    if model_name == 'Unet':
        model = smp.Unet(
            encoder_name=encoder,
            encoder_weights=encoder_weights,
            classes=CLASS,
            activation='softmax',
            attention_type=attention_type,
        )
    if model_name == 'Linknet':
        model = smp.Linknet(
            encoder_name=encoder,
            encoder_weights=encoder_weights,
            classes=CLASS,
            activation='softmax',
        )
    if model_name == 'FPN':
        model = smp.FPN(
            encoder_name=encoder,
            encoder_weights=encoder_weights,
            classes=CLASS,
            activation='softmax',
        )
    if model_name == 'ORG':
        model = Linknet_resnet18_ASPP()


    preprocessing_fn = smp.encoders.get_preprocessing_fn(encoder, encoder_weights)

    valid_dataset = CloudDataset(df=train,
                                 datatype='valid',
                                 img_ids=valid_ids,
                                 transforms=get_validation_augmentation(),
                                 preprocessing=get_preprocessing(preprocessing_fn))

    valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    loaders = {"infer": valid_loader}
    runner = SupervisedRunner()

    checkpoint = torch.load(f"{log_path}/checkpoints/best.pth")
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    transforms = tta.Compose(
        [
            tta.HorizontalFlip(),
            tta.VerticalFlip(),
        ]
    )
    model = tta.SegmentationTTAWrapper(model, transforms)
    runner.infer(
        model=model,
        loaders=loaders,
        callbacks=[InferCallback()],
    )
    callbacks_num = 0

    valid_masks = []
    probabilities = np.zeros((len(valid_dataset) * CLASS, IMG_SIZE[0], IMG_SIZE[1]))

    # ========
    # val predict
    #

    for batch in tqdm(valid_dataset):  # collect the ground-truth masks for each class
        _, mask = batch
        for m in mask:
            m = resize_img(m)
            valid_masks.append(m)

    for i, output in enumerate(tqdm(runner.callbacks[callbacks_num].predictions["logits"])):
        for j, probability in enumerate(output):
            probability = resize_img(probability)  # the probability map for each class is extracted here; j runs from 0 to 3
            probabilities[i * CLASS + j, :, :] = probability

    # ========
    # search best size and threshold
    #

    class_params = {}
    for class_id in range(CLASS):
        attempts = []
        for threshold in range(20, 90, 5):
            threshold /= 100
            for min_size in [10000, 15000, 20000]:
                masks = class_masks(class_id, probabilities, threshold, min_size)
                dices = class_dices(class_id, masks, valid_masks)
                attempts.append((threshold, min_size, np.mean(dices)))

        attempts_df = pd.DataFrame(attempts, columns=['threshold', 'size', 'dice'])
        attempts_df = attempts_df.sort_values('dice', ascending=False)

        print(attempts_df.head())

        best_threshold = attempts_df['threshold'].values[0]
        best_size = attempts_df['size'].values[0]

        class_params[class_id] = (best_threshold, best_size)

    # ========
    # gc
    #
    torch.cuda.empty_cache()
    gc.collect()

    # ========
    # predict
    #
    sub = pd.read_csv(sub_csv)
    sub['label'] = sub['Image_Label'].apply(lambda x: x.split('_')[-1])
    sub['im_id'] = sub['Image_Label'].apply(lambda x: x.replace('_' + x.split('_')[-1], ''))

    test_ids = sub['Image_Label'].apply(lambda x: x.split('_')[0]).drop_duplicates().values

    test_dataset = CloudDataset(df=sub,
                                datatype='test',
                                img_ids=test_ids,
                                transforms=get_validation_augmentation(),
                                preprocessing=get_preprocessing(preprocessing_fn))

    encoded_pixels = get_test_encoded_pixels(test_dataset, runner, class_params, test_batch_size)
    sub['EncodedPixels'] = encoded_pixels

    # ========
    # val dice
    #

    val_Image_Label = []
    for i, row in val_fold.iterrows():
        val_Image_Label.append(row.file_name + '_Fish')
        val_Image_Label.append(row.file_name + '_Flower')
        val_Image_Label.append(row.file_name + '_Gravel')
        val_Image_Label.append(row.file_name + '_Sugar')

    val_encoded_pixels = get_test_encoded_pixels(valid_dataset, runner, class_params, test_batch_size)
    val = pd.DataFrame(val_encoded_pixels, columns=['EncodedPixels'])
    val['Image_Label'] = val_Image_Label

    sub.to_csv(f'./sub/sub_{model_name}_fold_{fold_num}_{encoder}.csv', columns=['Image_Label', 'EncodedPixels'], index=False)
    val.to_csv(f'./val/val_{model_name}_fold_{fold_num}_{encoder}.csv', columns=['Image_Label', 'EncodedPixels'], index=False)
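
`resize_img` is not shown in this snippet; given that `probabilities` is allocated with shape `(len(valid_dataset) * CLASS, IMG_SIZE[0], IMG_SIZE[1])`, a plausible implementation is the one below (an assumption, not the project's actual helper).

import cv2

def resize_img(img):
    """Resize a mask/probability map to (IMG_SIZE[0], IMG_SIZE[1]) if needed (assumed helper)."""
    if img.shape != (IMG_SIZE[0], IMG_SIZE[1]):
        img = cv2.resize(img, dsize=(IMG_SIZE[1], IMG_SIZE[0]), interpolation=cv2.INTER_LINEAR)
    return img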
Example #13
import pytest
import torch
import ttach as tta


@pytest.mark.parametrize(
    "transform",
    [
        tta.HorizontalFlip(),
        tta.VerticalFlip(),
        tta.Rotate90(angles=[0, 90, 180, 270]),
        tta.Scale(scales=[1, 2, 4], interpolation="nearest"),
        tta.Resize(sizes=[(4, 5), (8, 10)],
                   original_size=(4, 5),
                   interpolation="nearest")
    ],
)
def test_aug_deaug_mask(transform):
    a = torch.arange(20).reshape(1, 1, 4, 5).float()
    for p in transform.params:
        aug = transform.apply_aug_image(a, **{transform.pname: p})
        deaug = transform.apply_deaug_mask(aug, **{transform.pname: p})
        assert torch.allclose(a, deaug)


@pytest.mark.parametrize(
    "transform",
    [
        tta.HorizontalFlip(),
        tta.VerticalFlip(),
        tta.Rotate90(angles=[0, 90, 180, 270]),
Example #14
def main(MODEL_TYPE, LEARNING_RATE, CROP_SIZE, UPSCALE_FACTOR, NUM_EPOCHS, BATCH_SIZE, IMAGE_DIR, LAST_EPOCH, MODEL_NAME='', TV_LOSS_RATE=1e-3):
    
    global net, history, lr_img, hr_img, test_lr_img, test_sr_img, valid_hr_img, valid_sr_img, valid_lr_img, scheduler, optimizer
    
    train_set    = TrainDatasetFromFolder(IMAGE_DIR, crop_size=CROP_SIZE, upscale_factor=UPSCALE_FACTOR)
    train_loader = DataLoader(dataset=train_set, num_workers=8, batch_size=BATCH_SIZE, shuffle=True)
    n_iter = (len(train_set) // BATCH_SIZE + 1) * NUM_EPOCHS

    net = eval(f"{MODEL_TYPE}({UPSCALE_FACTOR})")
    criterion = TotalLoss(TV_LOSS_RATE)
    optimizer = optim.RAdam(net.parameters(), lr=LEARNING_RATE)
    scheduler = schedule.StepLR(optimizer, int(n_iter * 0.3), gamma=0.5, last_epoch=LAST_EPOCH)
    if LAST_EPOCH == -1:
        scheduler = GradualWarmupScheduler(optimizer, 1, n_iter // 50, after_scheduler=scheduler)
        
    # plot_scheduler(scheduler, n_iter)

    if MODEL_NAME:
        net.load_state_dict(torch.load('epochs/' + MODEL_NAME))
        print(f"# Loaded model: [{MODEL_NAME}]")
    
    print(f'# {MODEL_TYPE} parameters:', sum(param.numel() for param in net.parameters()))
    
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    net.to(device)
    criterion.to(device)
    
    tta_transform = tta.Compose([
        tta.HorizontalFlip(),
        tta.VerticalFlip()
    ])
    
    # Train model:
    tta_net = tta.SegmentationTTAWrapper(net, tta_transform)
    history = []
    img_test  = plt.imread(r'data\testing_lr_images\09.png')
    img_valid = plt.imread(r'data\valid_hr_images\t20.png')
    test_lr_img = torch.from_numpy(img_test.transpose(2, 0, 1)).unsqueeze(0).to(device)
    valid_hr_img = ToTensor()(img_valid)
    valid_lr_img = valid_hr_transform(img_valid.shape, UPSCALE_FACTOR)(img_valid).to(device)
    for epoch in range(1, NUM_EPOCHS + 1):
        train_bar = tqdm(train_loader)
        running_results = {'batch_sizes': 0, 'loss': 0, 'psnr': 0}
        
        # Train one epoch:
        net.train()
        for lr_img, hr_img in train_bar:
            running_results['batch_sizes'] += BATCH_SIZE
            optimizer.zero_grad()
            
            lr_img, hr_img = MixUp()(lr_img, hr_img, lr_img, hr_img)
            lr_img = lr_img.type(torch.FloatTensor).to(device)
            hr_img = hr_img.type(torch.FloatTensor).to(device)
            sr_img = tta_net(lr_img)
            
            loss = criterion(sr_img, hr_img)
            loss.backward()
            optimizer.step()
            scheduler.step()

            running_results['loss'] += loss.item() * BATCH_SIZE
            running_results['psnr'] += psnr(hr_img, sr_img) * BATCH_SIZE
            train_bar.set_description(desc='[%d/%d] Loss: %.4f, PSNR: %.4f' % (
                epoch, NUM_EPOCHS,
                running_results['loss'] / running_results['batch_sizes'],
                running_results['psnr'] / running_results['batch_sizes']
            ))
    
        # Save model parameters:
        psnr_now = running_results['psnr'] / running_results['batch_sizes']
        filename = f'epochs/{MODEL_TYPE}_x%d_epoch=%d_PSNR=%.4f.pth' % (UPSCALE_FACTOR, epoch, psnr_now)
        torch.save(net.state_dict(), filename)
        history.append(running_results)
        
        # Test model:
        if epoch % 5 == 0:
            with torch.no_grad():
                net.eval()
                
                # Plot up-scaled testing image:
                test_sr_img = net(test_lr_img)
                plot_hr_lr(test_sr_img, test_lr_img)
                
                # Compute PSNR of validation image:
                valid_sr_img = net(valid_lr_img)
                psnr_valid = psnr(valid_hr_img, valid_sr_img)
                
                # Print PSNR:
                print('\n' + '-' * 50)
                print(f"PSNR of Validation = {psnr_valid}")
                print('-' * 50 + '\n')
                
    torch.save(optimizer.state_dict(), f'optimizer_{MODEL_TYPE}_epoch={NUM_EPOCHS}.pth')
    torch.save(scheduler.state_dict(), f'scheduler_{MODEL_TYPE}_epoch={NUM_EPOCHS}.pth')
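
# The `psnr` metric used above is defined elsewhere; for images scaled to [0, 1]
# a typical definition is sketched here as an assumption.
import torch

def psnr(img1, img2, data_range=1.0):
    """Peak signal-to-noise ratio in dB for tensors scaled to [0, data_range] (assumed helper)."""
    mse = torch.mean((img1.float() - img2.float()) ** 2)
    return (10 * torch.log10(data_range ** 2 / mse)).item()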

for i in range(5):
    train(i)


def load_model(fold: int, epoch: int, device: torch.device = 'cuda'):

    model = EfficientNetModel().to(device)
    model.load_state_dict(
        torch.load(f'models/effinet_b4_SAM_CosLR-f{fold}-{epoch}.pth'))

    return model


transforms = tta.Compose([tta.HorizontalFlip(), tta.VerticalFlip()])
tta_model = tta.ClassificationTTAWrapper(model, transforms)


def test(device: torch.device = 'cuda'):
    submit = pd.read_csv('data/sample_submission.csv')

    model1 = load_model(0, 19)
    model2 = load_model(1, 19)
    model3 = load_model(2, 19)
    model4 = load_model(3, 19)
    model5 = load_model(4, 19)

    tta_model1 = tta.ClassificationTTAWrapper(model1, transforms)
    tta_model2 = tta.ClassificationTTAWrapper(model2, transforms)
    tta_model3 = tta.ClassificationTTAWrapper(model3, transforms)
    # Load the test set
    if fold_flag:
        _, test = get_fold_filelist(csv_file, K=fold_K, fold=fold_index)
        test_img_list = [img_path+sep+i[0] for i in test]
        if mask_path is not None:
            test_mask_list = [mask_path+sep+i[0] for i in test]
    else:
        test_img_list = get_filelist_frompath(img_path,'PNG')
        if mask_path is not None:
            test_mask_list = [mask_path + sep + i.split(sep)[-1] for i in test_img_list]

    # Build the two models
    with torch.no_grad():
        # TTA settings
        tta_trans = tta.Compose([
            tta.VerticalFlip(),
            tta.HorizontalFlip(),
            tta.Rotate90(angles=[0,180]),
        ])
        # Build the models
        # cascade1
        model_cascade1 = smp.DeepLabV3Plus(encoder_name="efficientnet-b6", encoder_weights=None, in_channels=1, classes=1)
        model_cascade1.to(device)
        model_cascade1.load_state_dict(torch.load(weight_c1))
        if c1_tta:
            model_cascade1 = tta.SegmentationTTAWrapper(model_cascade1, tta_trans,merge_mode='mean')
        model_cascade1.eval()
        # cascade2
        model_cascade2 = smp.DeepLabV3Plus(encoder_name="efficientnet-b6", encoder_weights=None, in_channels=1, classes=1)
        # model_cascade2 = smp.Unet(encoder_name="efficientnet-b6", encoder_weights=None, in_channels=1, classes=1, encoder_depth=5, decoder_attention_type='scse')
        # model_cascade2 = smp.PAN(encoder_name="efficientnet-b6",encoder_weights='imagenet',	in_channels=1, classes=1)
def get_val_logits(valid_ids, num_split, encoder, decoder):
    # valid
    train = "./data/Clouds_Classify/train.csv"

    # Data overview
    train = pd.read_csv(train)
    train.head()

    train['label'] = train['Image_Label'].apply(lambda x: x.split('_')[1])
    train['im_id'] = train['Image_Label'].apply(lambda x: x.split('_')[0])

    ENCODER = encoder
    ENCODER_WEIGHTS = 'imagenet'

    if decoder == 'unet':
        # Build the model
        model = smp.Unet(
            encoder_name=ENCODER, 
            encoder_weights=ENCODER_WEIGHTS, 
            classes=4, 
            activation=None,
        )
    else:
        model = smp.FPN(
            encoder_name=ENCODER, 
            encoder_weights=ENCODER_WEIGHTS, 
            classes=4, 
            activation=None,
        )
    preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)

    num_workers = 4
    valid_bs = 1
    valid_dataset = CloudDataset(df=train, transforms = get_validation_augmentation(), datatype='valid', img_ids=valid_ids, preprocessing=get_preprocessing(preprocessing_fn))
    valid_loader = DataLoader(valid_dataset, batch_size=valid_bs, shuffle=False, num_workers=num_workers)    

    loaders = {"valid": valid_loader}
    logdir = "./logs/log_{}_{}/log_{}".format(encoder, decoder, num_split)

    print('predicting for validation data...')
    ############### Predict with PyTorch TTA ####################
    use_TTA = True
    checkpoint_path = logdir + '/checkpoints/best.pth'
    model.load_state_dict(torch.load(checkpoint_path)['model_state_dict'])
    # Predict with TTA
    if use_TTA:
        transforms = tta.Compose([
            tta.HorizontalFlip(),
            tta.VerticalFlip(),
            tta.Scale(scales=[5/6, 1, 7/6]),
        ])
        tta_model = tta.SegmentationTTAWrapper(model, transforms, merge_mode='mean')
    else:
        tta_model = model
    
    tta_model = tta_model.cuda()
    tta_model.eval()

    with torch.no_grad(): 
        for _, data in enumerate(tqdm.tqdm(loaders['valid'])):
            img, _, img_name = data
            img = img.cuda()
            batch_preds = tta_model(img).cpu().numpy()
            batch_preds = batch_preds.astype(np.float16)
            
            save_dir = './logits/valid/' + encoder + '_' + decoder + '/log_{}'.format(num_split)
            if not os.path.exists(save_dir):
                os.makedirs(save_dir)

            file_name = img_name[0].split('.')[0] + '.plk'
            file_path = os.path.join(save_dir, file_name)

            with open(file_path, 'wb') as wf:
                plk.dump(batch_preds, wf)