Beispiel #1
0
 def __init__(self):
     """Wrap an smp FPN segmentation network.

     The FPN uses an ImageNet-pretrained ``se_resnet50`` encoder, predicts
     10 classes and applies no output activation (``activation=None``).
     The network is stored on ``self.model``.
     """
     super().__init__()
     self.model = smp.FPN(
         encoder_name='se_resnet50',
         encoder_weights='imagenet',
         classes=10,
         activation=None,
     )
Beispiel #2
0
    ## seg model
    parser.add_argument('--gpu_ids',
                        type=str,
                        default='0',
                        help="device id to run")
    parser.add_argument(
        '--seg_model_path',
        type=str,
        default='results/random_take_8obj_1e-4_ok/checkpoint_3.pth')

    args = parser.parse_args()

    if args.use_typeVector:
        args.network = 'cnn_type'
    print('Load RL model.')
    model, env = train(args)

    print('Load Segmentation model.')
    seg_models = []
    test_loaders = []
    gpus = args.gpu_ids.split(',')
    ckpt = torch.load(args.seg_model_path)
    for i in range(4):
        seg_model = smp.FPN('resnet50', in_channels=4, classes=3).cuda()
        seg_model = nn.DataParallel(seg_model,
                                    device_ids=[int(_) for _ in gpus])
        seg_model.load_state_dict(ckpt['FPN_' + str(i)])

        seg_models.append(seg_model)

    make_video(model, env, seg_models, args)
def main():
    """Train or evaluate an smp FPN for single-class MRI tumour segmentation.

    ``--mode train`` builds train/validation loaders from ``data_dir`` (via
    ``pp.prepare_csv``), optionally resumes from a checkpoint, trains with
    Nesterov SGD plus a StepLR schedule, logs the training loss/IoU to
    TensorBoard and saves a checkpoint to ``./best MRI models/<comment>.pth``
    whenever the validation IoU improves.  Training stops early after 8
    consecutive epochs without a validation-IoU improvement.

    ``--mode test`` asks for a saved checkpoint through a file dialog,
    evaluates it on the test split (unless ``--just_show`` is given) and
    visualises predictions for randomly chosen test images.

    Raises:
        ValueError: if ``--loss`` is not one of ``dl``, ``gdl``, ``dl+bce``.
        NotImplementedError: for the unfinished ``dl+log(bce)`` loss.
    """
    import os  # only needed for checkpoint path handling

    # Training settings
    parser = argparse.ArgumentParser(
        description='Preston\'s MRI-segmentation Arguments')
    parser.add_argument(
        '--mode',
        type=str,
        default="train",
        metavar='M',  # need to wrap the argument in double quotation
        help='Select either train or test (default: "train")')
    parser.add_argument(
        '--checkpoint',
        type=str,
        default=None,
        help="Path to a checkpoint to resume training from; if it is not an "
        "existing file, a file dialog is opened instead")
    parser.add_argument(
        '--num_files',
        type=int,
        default=4000,
        metavar='N',
        help='Number of files in training; split using tv_ratio (default: 4000)')
    # `type=list` would split a command-line value into single characters
    # ("0.8" -> ['0', '.', '8']), so parse exactly two floats instead.
    parser.add_argument(
        '--tv_ratio',
        type=float,
        nargs=2,
        default=[0.8, 0.2],
        metavar=('TRAIN', 'VALID'),
        help='Train-validate ratio as two numbers that must add up to 1 '
        '(default: 0.8 0.2)')
    parser.add_argument('--train_batch_size',
                        type=int,
                        default=4,
                        metavar='N',
                        help='input batch size for training (default: 4)')
    parser.add_argument('--valid_batch_size',
                        type=int,
                        default=4,
                        metavar='N',
                        help='input batch size for validating (default: 4)')
    parser.add_argument('--epochs',
                        type=int,
                        default=40,
                        metavar='N',
                        help='number of epochs to train (default: 40)')
    parser.add_argument('--lr',
                        type=float,
                        default=3e-3,
                        metavar='LR',
                        help='learning rate (default: 3e-3)')
    parser.add_argument('--sch_step',
                        type=int,
                        default=10,
                        metavar='sch_step',
                        help='Scheduler step size (default: 10)')
    parser.add_argument('--loss',
                        type=str,
                        default="dl",
                        metavar='LOSS',
                        help='Type of loss - dl, gdl or dl+bce (default: dl)')
    parser.add_argument('--beta',
                        type=float,
                        default=1,
                        metavar='B',
                        help='Beta in dice loss (default: 1)')
    parser.add_argument('--pw',
                        type=float,
                        default=1000,
                        metavar='PW',
                        help='Positive weight in the bce function')
    parser.add_argument('--gamma',
                        type=float,
                        default=0.7,
                        metavar='M',
                        help='Learning rate step gamma (default: 0.7)')
    parser.add_argument(
        '--no-cuda',
        action='store_true',
        # Pass --no-cuda to force the CPU; otherwise CUDA is used when available.
        default=False,
        help='disables CUDA training')
    parser.add_argument(
        '--just_show',
        type=int,
        default=None,
        help='test mode only: skip the test metrics and just show this many '
        'predictions')
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        metavar='S',
                        help='random seed (default: 1)')

    start_time = time.time()
    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    torch.manual_seed(args.seed)

    # Model Definition
    ENCODER = 'se_resnext50_32x4d'
    ENCODER_WEIGHTS = 'imagenet'  # None or 'imagenet' (if None, then weights are randomly initialized)
    CLASSES = ['tumor']  # only 1 class in this example
    # None for logits, 'softmax2d' for multiclass segmentation, 'sigmoid' for binary.
    ACTIVATION = 'sigmoid'
    DEVICE = "cuda" if use_cuda else "cpu"

    # create segmentation model
    model = smp.FPN(
        encoder_name=ENCODER,
        encoder_weights=ENCODER_WEIGHTS,
        classes=len(CLASSES),  # only 1 class in this case
        activation=ACTIVATION,
    )

    preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER,
                                                         pretrained='imagenet')

    if args.loss == "dl":
        loss = smp.utils.losses.DiceLoss(beta=args.beta)
    elif args.loss == "gdl":
        loss = smp.utils.losses.GeneralizedDiceLoss()
    elif args.loss == "dl+bce":
        # A large positive weight compensates for the rarity of positive pixels
        # in this pruned dataset; reshaped so it broadcasts over the predictions.
        pw = torch.reshape(torch.FloatTensor([args.pw]), (1, 1, 1, 1))
        # NOTE(review): the model already ends in a sigmoid (ACTIVATION) while
        # BCEWithLogitsLoss applies its own sigmoid -- confirm this is intended.
        loss = smp.utils.losses.DiceLoss(
            beta=args.beta) + smp.utils.losses.BCEWithLogitsLoss(pos_weight=pw)
    elif args.loss == "dl+log(bce)":
        # torch.log cannot be applied to a loss *module*; fail with a clear message.
        raise NotImplementedError("Loss 'dl+log(bce)' is not implemented yet")
    else:
        raise ValueError("Loss can only be dl, gdl or dl+bce for now")

    metrics = [
        smp.utils.metrics.IoU(threshold=0.5),
    ]

    data_dir = 'D:\\MRI Segmentation\\data'
    model_dir = r"C:\Users\prestonpan\PycharmProjects\Segmentation_example\runs"

    if args.mode == "train":
        comment = f'lr={args.lr}, loss={args.loss}, epochs={args.epochs}, ' \
                  f'num_files={args.num_files}, sch_step={args.sch_step}, train_batch_size={args.train_batch_size}, ' \
                  f'beta={args.beta}, pw={args.pw}'
        tb = SummaryWriter(comment=comment)
        print(comment)  # to verify that the hyperparameter values are set correctly

        dfTrain, dfVal = pp.prepare_csv(data_dir,
                                        args.tv_ratio,
                                        num_files=args.num_files,
                                        mode=args.mode)

        train_dataset = MRIDataset(
            dfTrain,
            classes=CLASSES,
            augmentation=pp.get_training_augmentation(),
            preprocessing=pp.get_preprocessing(preprocessing_fn))
        valid_dataset = MRIDataset(
            dfVal,
            classes=CLASSES,
            # NOTE(review): validation reuses the *training* augmentation, which
            # can make the validation IoU noisy; consider a deterministic pipeline.
            augmentation=pp.get_training_augmentation(),
            preprocessing=pp.get_preprocessing(preprocessing_fn))

        train_loader = DataLoader(train_dataset,
                                  batch_size=args.train_batch_size,
                                  shuffle=True,
                                  num_workers=12)
        valid_loader = DataLoader(valid_dataset,
                                  batch_size=args.valid_batch_size,
                                  shuffle=False,
                                  num_workers=4)

        optimizer = optim.SGD(model.parameters(),
                              lr=args.lr,
                              momentum=0.90,
                              weight_decay=1e-6,
                              nesterov=True)

        start_epoch = 0
        max_score = 0

        # Every `sch_step` epochs the learning rate is multiplied by `gamma`.
        scheduler = optim.lr_scheduler.StepLR(optimizer,
                                              step_size=args.sch_step,
                                              gamma=args.gamma)

        # Resume from a checkpoint: use the given path when it exists,
        # otherwise let the user pick a file interactively.
        if args.checkpoint is not None:
            if os.path.isfile(args.checkpoint):
                checkpoint_path = args.checkpoint
            else:
                root = tk.Tk()
                checkpoint_path = filedialog.askopenfilename(
                    parent=root,
                    initialdir=model_dir,
                    title='Please select a model (.pth)')
                root.destroy()
            model, optimizer, scheduler, max_score, start_epoch = \
                pp.load_checkpoint(model, optimizer, scheduler, filename=checkpoint_path)

        train_epoch = smp.utils.train.TrainEpoch(
            model,
            loss=loss,
            metrics=metrics,
            optimizer=optimizer,
            device=DEVICE,
            verbose=True,
        )

        valid_epoch = smp.utils.train.ValidEpoch(
            model,
            loss=loss,
            metrics=metrics,
            device=DEVICE,
            verbose=True,
        )

        save_dir = './best MRI models'
        os.makedirs(save_dir, exist_ok=True)

        patience = 8  # epochs without validation improvement before stopping
        stagnant_epoch = 0
        try:
            for i in range(start_epoch, args.epochs):
                print('\nEpoch: {}'.format(i + 1))  # 1-index the epoch number
                train_logs = train_epoch.run(train_loader)
                valid_logs = valid_epoch.run(valid_loader)

                tb.add_scalar('Loss', train_logs[loss.__name__], i)
                tb.add_scalar('IOU', train_logs['iou_score'], i)

                scheduler.step()

                # Select on the validation metric: the training metric can be
                # overfit and is therefore unrepresentative.
                if max_score < valid_logs['iou_score']:
                    stagnant_epoch = 0
                    max_score = valid_logs['iou_score']
                    print('New best IOU: %.5f' % max_score)
                    print('Saving checkpoint...')
                    state = {
                        'model': model,
                        # Next epoch to run, so a resume continues from here
                        # (assumes pp.load_checkpoint reads 'epoch' as the
                        # start epoch -- confirm).
                        'epoch': i + 1,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'scheduler': scheduler.state_dict(),
                        'iou_score': max_score
                    }
                    torch.save(state,
                               os.path.join(save_dir, '{}.pth'.format(comment)))
                else:
                    stagnant_epoch += 1
                    if stagnant_epoch >= patience:
                        print(f'max iou remained stagnant for the past '
                              f'{patience} epochs, returning early')
                        break
        finally:
            # Flush pending TensorBoard events even on early stop or error.
            tb.close()

        total_time = time.time() - start_time
        print('Training took %.2f seconds' % total_time)

    elif args.mode == "test":

        dfTest = pp.prepare_csv(
            data_dir, args.tv_ratio, num_files=-1,
            mode=args.mode)  # number of testing images is fixed

        root = tk.Tk()
        model_path = filedialog.askopenfilename(
            parent=root,
            initialdir=model_dir,
            title='Please select a model (.pth)')
        root.destroy()

        # The checkpoint is a dict that also holds the whole pickled model;
        # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
        best_model_dict = torch.load(model_path, map_location=DEVICE)
        best_model = best_model_dict['model']
        print('model loaded!')

        # Pre-processing (needed for correct predictions) but no augmentation,
        # because we are not training anymore.
        test_dataset = MRIDataset(
            dfTest,
            classes=CLASSES,
            preprocessing=pp.get_preprocessing(preprocessing_fn))

        test_dataloader = DataLoader(test_dataset)
        test_epoch = smp.utils.train.ValidEpoch(
            model=best_model,
            loss=loss,
            metrics=metrics,
            device=DEVICE,
        )

        if args.just_show is not None:
            num_to_display = args.just_show
        else:
            # ValidEpoch is verbose by default, so the metrics are printed.
            test_epoch.run(test_dataloader)
            num_to_display = 5

        # Dataset for visualization: no preprocessing, i.e. the native image.
        test_dataset_vis = MRIDataset(dfTest, classes=CLASSES)
        iou_metric = metrics[0]

        for _ in range(num_to_display):
            n = np.random.choice(len(test_dataset_vis))

            # H * W * 3 native image for display.
            image_vis = test_dataset_vis[n][0].astype('int16')
            # image: 3 * H * W after preprocessing; gt_masks: C * H * W.
            image, gt_masks = test_dataset[n]

            x_tensor = torch.from_numpy(image).to(DEVICE).unsqueeze(0)  # 1 * 3 * H * W
            pr_masks = best_model.predict(x_tensor)  # 1 * C * H * W
            pr_masks = pr_masks.squeeze(0).cpu().numpy().round()  # C * H * W

            # Index through the classes and look at them one at a time.
            for j in range(len(CLASSES)):
                gt_mask = gt_masks[j].squeeze()
                pr_mask = pr_masks[j].squeeze()

                # Report the actual IoU metric (1 - loss is only a Dice score
                # for DiceLoss and meaningless for combined losses).
                pp.visualize_2(image=image_vis,
                               gt=gt_mask,
                               pr=pr_mask,
                               iou=iou_metric(torch.from_numpy(pr_mask),
                                              torch.from_numpy(gt_mask)))
Beispiel #4
0
def get_model(
    model_type: str = "Unet",
    encoder: str = "Resnet18",
    encoder_weights: str = "imagenet",
    activation: str = None,
    n_classes: int = 4,
    task: str = "segmentation",
    source: str = "pretrainedmodels",
    head: str = "simple",
):
    """
    Get model for training or inference.

    Builds and returns a new model instance; no checkpoint is loaded here.

    Args:
        model_type: segmentation architecture -- "Unet", "Linknet" or "FPN"
            (segmentation_models_pytorch), or "resnet34_fpn" / "effnetB4_fpn"
            (these two ignore encoder, encoder_weights and activation).
        encoder: encoder name for the smp models, or the model name looked up
            in pretrainedmodels / torchvision for classification.
            NOTE(review): the default "Resnet18" is capitalised while the
            name lookups are case-sensitive -- confirm the intended default.
        encoder_weights: pre-trained weights to use (passed as ``pretrained``
            for classification models).
        activation: activation function for the output layer (smp models).
        n_classes: number of classes in the output layer.
        task: "segmentation" or "classification".
        source: source of the classification model -- "pretrainedmodels" or
            "torchvision".
        head: "simple" replaces the final linear layer with one producing
            ``n_classes`` outputs; any other value wraps the backbone in ``Net``.

    Returns:
        The constructed model, or ``None`` for an unrecognised segmentation
        ``model_type``.

    Raises:
        ValueError: if ``task`` is unknown, or ``source`` is unknown for the
            classification task.
    """
    if task == "segmentation":
        if model_type == "Unet":
            model = smp.Unet(
                # attention_type='scse',
                encoder_name=encoder,
                encoder_weights=encoder_weights,
                classes=n_classes,
                activation=activation,
            )

        elif model_type == "Linknet":
            model = smp.Linknet(
                encoder_name=encoder,
                encoder_weights=encoder_weights,
                classes=n_classes,
                activation=activation,
            )

        elif model_type == "FPN":
            model = smp.FPN(
                encoder_name=encoder,
                encoder_weights=encoder_weights,
                classes=n_classes,
                activation=activation,
            )

        elif model_type == "resnet34_fpn":
            model = resnet34_fpn(num_classes=n_classes, fpn_features=128)

        elif model_type == "effnetB4_fpn":
            model = effnetB4_fpn(num_classes=n_classes, fpn_features=128)

        else:
            model = None

    elif task == "classification":
        if source == "pretrainedmodels":
            model_fn = pretrainedmodels.__dict__[encoder]
            model = model_fn(num_classes=1000, pretrained=encoder_weights)
        elif source == "torchvision":
            model = torchvision.models.__dict__[encoder](
                pretrained=encoder_weights)
        else:
            raise ValueError(
                f"Unknown classification source {source!r}; "
                "expected 'pretrainedmodels' or 'torchvision'")

        if head == "simple":
            if not hasattr(model, "last_linear") and hasattr(model, "fc"):
                # torchvision ResNet-style models name their classifier `fc`.
                model.fc = nn.Linear(model.fc.in_features, n_classes)
            else:
                # pretrainedmodels expose the classifier as `last_linear`.
                model.last_linear = nn.Linear(model.last_linear.in_features,
                                              n_classes)
        else:
            model = Net(net=model)

    else:
        raise ValueError(
            f"Unknown task {task!r}; expected 'segmentation' or 'classification'")

    return model
Beispiel #5
0
def main():
    """Train an FPN (SE-ResNet50 encoder, 2 classes) from command-line settings.

    Parses hyper-parameters, writes ``train.log`` into ``--output_dir``,
    splits the images listed in ``<data_path>/train_mask.json`` into a
    training part (first ``val_split`` fraction) and a validation part (the
    rest), then hands everything to ``train``.  On Ctrl-C the current weights
    are saved to ``<output_dir>/<model>_INTERRUPTED.pth`` and the process
    exits with status 0.
    """
    parser = ArgumentParser()
    parser.add_argument('-d',
                        '--data_path',
                        dest='data_path',
                        type=str,
                        default=None,
                        # Required: the data path is joined with
                        # 'train_mask.json' below and cannot be None.
                        required=True,
                        help='path to the data')
    parser.add_argument('-e',
                        '--epochs',
                        dest='epochs',
                        default=20,
                        type=int,
                        help='number of epochs')
    parser.add_argument('-b',
                        '--batch_size',
                        dest='batch_size',
                        default=40,
                        type=int,
                        help='batch size')
    parser.add_argument('-s',
                        '--image_size',
                        dest='image_size',
                        default=256,
                        type=int,
                        help='input image size')
    parser.add_argument('-lr',
                        '--learning_rate',
                        dest='lr',
                        default=0.0001,
                        type=float,
                        help='learning rate')
    parser.add_argument('-wd',
                        '--weight_decay',
                        dest='weight_decay',
                        default=5e-4,
                        type=float,
                        help='weight decay')
    parser.add_argument('-lrs',
                        '--learning_rate_step',
                        dest='lr_step',
                        default=10,
                        type=int,
                        help='learning rate step')
    parser.add_argument('-lrg',
                        '--learning_rate_gamma',
                        dest='lr_gamma',
                        default=0.5,
                        type=float,
                        help='learning rate gamma')
    parser.add_argument(
        '-m',
        '--model',
        dest='model',
        default='fpn',
    )
    parser.add_argument('-w',
                        '--weight_bce',
                        default=0.5,
                        type=float,
                        help='weight BCE loss')
    parser.add_argument('-l',
                        '--load',
                        dest='load',
                        default=False,
                        help='load file model')
    # type=float is required: without it a command-line value stays a str and
    # `len(train_dataset) * args.val_split` below fails.
    parser.add_argument('-v',
                        '--val_split',
                        dest='val_split',
                        default=0.7,
                        type=float,
                        help='train/val split (fraction of images used for '
                        'training; the rest is validation)')
    parser.add_argument('-o',
                        '--output_dir',
                        dest='output_dir',
                        default='./output',
                        help='dir to save log and models')
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)
    logger = get_logger(os.path.join(args.output_dir, 'train.log'))
    logger.info('Start training with params:')
    for arg, value in sorted(vars(args).items()):
        logger.info("Argument %s: %r", arg, value)

    # TODO: try a novel arch and/or lighter blocks (e.g. UNet() or a
    # 'mobilenet_v2' encoder) to allow a larger batch_size.
    net = smp.FPN('se_resnet50', encoder_weights='imagenet', classes=2)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    if args.load:
        # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
        net.load_state_dict(torch.load(args.load, map_location=device))
    logger.info('Model type: {}'.format(net.__class__.__name__))

    net.to(device)

    optimizer = optim.Adam(net.parameters(),
                           lr=args.lr,
                           weight_decay=args.weight_decay)
    # Returns the two weighted terms separately as a (bce, dice) tuple.
    # NOTE(review): `train` is presumably expected to combine them -- confirm.
    criterion = lambda x, y: (args.weight_bce * nn.BCELoss()(x, y),
                              (1. - args.weight_bce) * dice_loss(x, y))
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step, gamma=args.lr_gamma) \
        if args.lr_step > 0 else None

    train_transforms = Compose([
        Crop(min_size=1 - 1 / 3., min_ratio=1.0, max_ratio=1.0, p=0.5),
        Flip(p=0.05),
        RandomRotate(),
        Pad(max_size=0.6, p=0.25),
        Resize(size=(args.image_size, args.image_size), keep_aspect=True),
        ScaleToZeroOne(),
    ])
    val_transforms = Compose([
        Resize(size=(args.image_size, args.image_size)),
        ScaleToZeroOne(),
    ])

    train_dataset = DetectionDataset(args.data_path,
                                     os.path.join(args.data_path,
                                                  'train_mask.json'),
                                     transforms=train_transforms)
    val_dataset = DetectionDataset(args.data_path,
                                   None,
                                   transforms=val_transforms)

    # Both datasets start from the same file list; the validation one takes
    # the tail and the training one keeps the head.
    train_size = int(len(train_dataset) * args.val_split)
    val_dataset.image_names = train_dataset.image_names[train_size:]
    val_dataset.mask_names = train_dataset.mask_names[train_size:]
    train_dataset.image_names = train_dataset.image_names[:train_size]
    train_dataset.mask_names = train_dataset.mask_names[:train_size]
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=args.batch_size,
                                  num_workers=8,
                                  shuffle=True,
                                  drop_last=True)
    val_dataloader = DataLoader(val_dataset,
                                batch_size=args.batch_size,
                                num_workers=4,
                                shuffle=False,
                                drop_last=False)
    logger.info('Number of batches of train/val=%d/%d', len(train_dataloader),
                len(val_dataloader))

    try:
        train(net,
              optimizer,
              criterion,
              scheduler,
              train_dataloader,
              val_dataloader,
              logger=logger,
              args=args,
              device=device)
    except KeyboardInterrupt:
        torch.save(
            net.state_dict(),
            os.path.join(args.output_dir, f'{args.model}_INTERRUPTED.pth'))
        logger.info('Saved interrupt')
        sys.exit(0)
def create_model_and_train(epochs=5, lr_drop_epoch=3):
    """Train an smp FPN on the CARLA lanes dataset, keeping the best model.

    Relies on module-level configuration defined elsewhere: ``ENCODER``,
    ``ENCODER_WEIGHTS``, ``ACTIVATION``, ``DEVICE``, ``preprocessing_fn``,
    the ``x/y_train_dir`` and ``x/y_valid_dir`` paths, ``loss``, ``metrics``
    and ``loss_string`` (the key of the loss value in the epoch logs).

    After each epoch the whole model is saved to
    ``./best_model_<loss_string>.pth`` if its validation loss improved.

    Args:
        epochs: number of training epochs (default 5).
        lr_drop_epoch: 0-based epoch index after which the learning rate is
            lowered from 1e-4 to 1e-5 (default 3).
    """
    # create segmentation model with pretrained encoder
    model = smp.FPN(
        encoder_name=ENCODER,
        encoder_weights=ENCODER_WEIGHTS,
        classes=len(CarlaLanesDataset.CLASSES),
        activation=ACTIVATION,
    )

    train_dataset = CarlaLanesDataset(
        x_train_dir,
        y_train_dir,
        augmentation=get_training_augmentation(),
        preprocessing=get_preprocessing(preprocessing_fn),
        classes=CarlaLanesDataset.CLASSES,
    )

    valid_dataset = CarlaLanesDataset(
        x_valid_dir,
        y_valid_dir,
        augmentation=get_validation_augmentation(),
        preprocessing=get_preprocessing(preprocessing_fn),
        classes=CarlaLanesDataset.CLASSES,
    )

    bs_train = 2
    bs_valid = 2
    train_loader = DataLoader(train_dataset, batch_size=bs_train, shuffle=True)
    valid_loader = DataLoader(valid_dataset,
                              batch_size=bs_valid,
                              shuffle=False)

    # A single parameter group holding every model parameter.
    optimizer = torch.optim.Adam([
        dict(params=model.parameters(), lr=1e-4),
    ])

    # create epoch runners
    # it is a simple loop of iterating over dataloader`s samples
    train_epoch = smp.utils.train.TrainEpoch(
        model,
        loss=loss,
        metrics=metrics,
        optimizer=optimizer,
        device=DEVICE,
        verbose=True,
    )

    valid_epoch = smp.utils.train.ValidEpoch(
        model,
        loss=loss,
        metrics=metrics,
        device=DEVICE,
        verbose=True,
    )

    # train model
    best_loss = 1e10

    for i in range(epochs):

        print('\nEpoch: {}'.format(i))
        train_epoch.run(train_loader)
        valid_logs = valid_epoch.run(valid_loader)

        # Keep the model with the lowest validation loss so far.
        if best_loss > valid_logs[loss_string]:
            best_loss = valid_logs[loss_string]
            torch.save(model, './best_model_{}.pth'.format(loss_string))
            print('Model saved!')

        if i == lr_drop_epoch:
            # The only param group covers the whole model, not just the decoder.
            optimizer.param_groups[0]['lr'] = 1e-5
            print('Decrease learning rate to 1e-5!')
Beispiel #7
0
                         collate_fn=collate_fn)

# Training loader: shuffled, and the last incomplete batch is dropped.
train_loader = DataLoader(dataset=train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          num_workers=4,
                          collate_fn=collate_fn,
                          drop_last=True)

# Validation loader: fixed order, keeps every sample.
val_loader = DataLoader(dataset=val_dataset,
                        batch_size=batch_size,
                        shuffle=False,
                        num_workers=4,
                        collate_fn=collate_fn)

# FPN with an ImageNet-pretrained EfficientNet-B3 encoder, RGB input and 12
# output classes.
model = smp.FPN(encoder_name="efficientnet-b3", encoder_weights="imagenet", in_channels=3, classes=12)
model = model.to(device)

# Stage 1: supervised training.
# NOTE(review): assumes train() writes its checkpoint to 'saved/test.pt',
# the path read back below -- confirm.
train(num_epochs, model, train_loader, val_loader, val_every, device, 'test.pt')

# Stage 2: reload the stage-1 weights and fine-tune with pseudo-labels
# (test_loader is defined elsewhere in this script). The model is already on
# `device`, so the repeated .to(device) calls here are harmless no-ops.
checkpoint = torch.load('saved/test.pt', map_location=device)
model = model.to(device)
model.load_state_dict(checkpoint)
pseudo_labeling(num_epochs, model, train_loader, val_loader, test_loader, device, val_every, 'test_sudo.pt')

# Stage 3: load the pseudo-label checkpoint and predict on the test set;
# `submission` is the sample submission template read for later filling.
submission = pd.read_csv('./submission/sample_submission.csv', index_col=None)
checkpoint = torch.load('saved/test_sudo.pt', map_location=device)
model = model.to(device)
model.load_state_dict(checkpoint)

file_names, preds = test(model, test_loader, device)
Beispiel #8
0
############################################
# Disabled alternative: an EfficientNet-B5 FPN (5 output channels) with
# per-fold weights chosen by FOLD. Kept for reference.
# Seg_model = smp.FPN('efficientnet-b5',encoder_weights=None,classes=5).cuda()
#
# weights_name = ['f0/Heng-efficient_f0_1076.pth',
#                 'f1/Heng-efficient_f1_10719.pth',
#                 'f2/Heng-efficient_f2_1085.pth',
#                 'f3/Heng-efficient_f3_1086.pth',
#                 'f4/Heng-efficient_f4_10818.pth',]
# seg_path = '../kaggle-steel-0911/weights/'
#
# load_weights(Seg_model,os.path.join(seg_path,'Heng-efficient',weights_name[FOLD]))#resnet34_f4_10611

# Seg_model = Model([seg4])
# for net in Seg_model.models:
#     net.eval()

# Active segmentation model: ResNet34-FPN with 4 output channels and a
# sigmoid output activation. No ImageNet weights are downloaded
# (encoder_weights=None) because a trained checkpoint from the f0 directory
# is loaded right after.
Seg_model = smp.FPN('resnet34',
                    encoder_weights=None,
                    classes=4,
                    activation='sigmoid').cuda()
load_weights(
    Seg_model,
    "../kaggle-256crop-4cls-seg/weights/res34FPN/f0/res34FPN_f0_10126.pth")

print('load weights done!!')

# Run the grid search with the segmentation model and the classification
# model `Cls_model` (defined elsewhere), then save the resulting grid.
# NOTE(review): Trainer.start() is assumed to return the dice score grid -- confirm.
model_trainer = Trainer(Seg_model, Cls_model)
dice_grid = model_trainer.start()
print('grid search done!!')
np.save('grid_search.npy', dice_grid)
Beispiel #9
0
seg_probs = []
cloud_mask = 0

with torch.no_grad():
    for fold in range(K):
        print('Fold{}:'.format(fold))
        if MODEL == 'UNET':
            seg_model = smp.Unet(encoder_name=ENCODER,
                                 classes=4,
                                 encoder_weights=None,
                                 ED_drop=ED_drop)
        elif MODEL == 'FPN':
            seg_model = smp.FPN(encoder_name=ENCODER,
                                classes=4,
                                encoder_weights=None,
                                ED_drop=ED_drop,
                                dropout=dropout)
        else:
            seg_model = None
        seg_model.load_state_dict(
            torch.load(os.path.join(save_dir, 'model_{}.pth'.format(fold))))
        seg_model.cuda()
        seg_model.eval()

        preprocessing_fn = smp.encoders.get_preprocessing_fn(
            ENCODER, encoder_weights)

        # validate
        seg_probs_fold = 0
        for tt in range(4):
Beispiel #10
0
            f"({train_parameters['downsample_mask_factor']})"
        )

    # A non-zero remainder means the crop width is not a multiple of the
    # downsampling factor. (True division would only be falsy for a zero
    # width, so `%` is required for this check to ever fire correctly.)
    if train_parameters["width_crop_size"] % train_parameters["downsample_mask_factor"]:
        raise ValueError(
            f"Width crop size ({train_parameters['width_crop_size']}) "
            f"should be divisible by the downsample_mask_factor"
            f"({train_parameters['downsample_mask_factor']})"
        )

    final_upsampling = None
else:
    final_upsampling = 4

# FPN with an ImageNet-pretrained encoder and raw outputs (no activation);
# final_upsampling is chosen above depending on mask downsampling.
model = smp.FPN(
    encoder_type,
    encoder_weights="imagenet",
    classes=num_classes,
    activation=None,
    final_upsampling=final_upsampling,
)

pad_factor = 64
imread_library = "cv2"  # either "cv2" or "jpeg4py"

# Two parameter groups: the freshly initialised decoder trains at the base
# learning rate, while the pre-trained encoder uses a 100x smaller one so
# large gradients early in training do not perturb its weights.
optimizer = RAdam(
    [
        dict(params=model.decoder.parameters(), lr=train_parameters["lr"]),
        dict(params=model.encoder.parameters(), lr=train_parameters["lr"] / 100),
    ],
    weight_decay=1e-2,
)
Beispiel #11
0
    np.random.seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

    #segmodel = smp.FPN('efficientnet-b5',encoder_weights=None,classes=5).cuda()
    # segmodel = smp.Unet('resnet34',encoder_weights=None,classes=4).cuda()
    # seg_path = '../kaggle-256crop-4cls-seg/weights/resnet34/'
    # fold_mdoel = [
    #     'f0/resnet34_f0_10422.pth',
    #     'f1/resnet34_f1_10512.pth',
    #     'f2/resnet34_f2_1060.pth',
    #     'f3/resnet34_f3_1063.pth',
    #     'f4/resnet34_f4_10611.pth'
    # ]

    segmodel = smp.FPN('efficientnet-b5', encoder_weights=None,
                       classes=4).cuda()  #se_resnet50
    seg_path = '../kaggle-256crop-4cls-seg/weights/effb5FPN/'
    fold_mdoel = [
        'f0/effb5FPN_f0_101312.pth', 'f1/effb5FPN_f1_10148.pth',
        'f2/effb5FPN_f2_10148.pth', 'f3/effb5FPN_f3_10153.pth',
        'f4/effb5FPN_f4_10164.pth'
    ]

    load_weights(segmodel, seg_path + fold_mdoel[FOLD])  #<<<<<<<<<  set fold
    segmodel.eval()

    clsmodel = Net().cuda()
    #print(model)
    model_trainer = Trainer(segmodel, clsmodel)
    model_trainer.start()
    # n_channels=3 for RGB images
    # n_classes is the number of probabilities you want to get per pixel
    #   - For 1 class and background, use n_classes=1
    #   - For 2 classes, use n_classes=1
    #   - For N > 2 classes, use n_classes=N
    if config.MODEL_SELECTION == 'og_unet':
        net = UNet(n_channels=config.NUM_CHANNELS,
                   n_classes=config.NUM_CLASSES,
                   bilinear=True)
    elif config.MODEL_SELECTION == 'smp_unet':
        net = smp.Unet(config.BACKBONE,
                       encoder_weights=config.ENCODER_WEIGHTS,
                       classes=config.NUM_CLASSES)
    elif config.MODEL_SELECTION == 'smp_fpn':
        net = smp.FPN(config.BACKBONE,
                      encoder_weights=config.ENCODER_WEIGHTS,
                      classes=config.NUM_CLASSES)
    elif config.MODEL_SELECTION == 'pytorch_deeplab':
        net = DeepLabv3_plus(nInputChannels=config.NUM_CHANNELS,
                             n_classes=config.NUM_CLASSES,
                             os=16,
                             pretrained=True,
                             _print=False)
    elif config.MODEL_SELECTION == 'og_deeplab':
        net = Res_Deeplab(num_classes=config.NUM_CLASSES)
    else:
        raise NotImplementedError

    upsample = nn.Upsample(size=(config.CROP_H, config.CROP_W),
                           mode='bilinear',
                           align_corners=True)
Beispiel #13
0
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=4,
                                               collate_fn=collate_fn)

    val_loader = torch.utils.data.DataLoader(dataset=val_dataset,
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=4,
                                             collate_fn=collate_fn)

    # model 불러오기
    # 출력 레이블 수 정의 (classes = 12)
    model = smp.FPN(encoder_name=encoder_name,
                    classes=12,
                    encoder_weights="noisy-student",
                    activation=None)
    model = model.to(device)
    wandb.watch(model, log_freq=100)

    def train(num_epochs, model, data_loader, val_loader, criterion, optimizer,
              scheduler, saved_dir, val_every, device):
        print(f'Start training..fold{fold+1}')
        best_mIoU = 0
        best_epoch = 0
        for epoch in tqdm(range(num_epochs)):
            model.train()
            for step, (images, masks, _) in enumerate(data_loader):
                images = torch.stack(images)  # (batch, channel, height, width)
                masks = torch.stack(
                    masks).long()  # (batch, channel, height, width)
Beispiel #14
0
def main():
    """Evaluate a segmentation model on the validation set described by ``args.cfg``.

    Parses CLI arguments, sets up NCCL distributed mode when more than one
    process was launched, merges ``args.cfg`` into the global ``cfg``, creates
    a timestamped output directory (rank 0 only), builds the model selected by
    ``cfg.MODEL.arch`` (converted to apex SyncBN, initialised with amp and,
    when distributed, wrapped in apex DDP), optionally loads
    ``cfg.VAL.checkpoint`` and finally calls ``val`` on the validation loader.

    Raises:
        ValueError: if ``cfg.MODEL.arch`` is not 'deeplab', 'smp-deeplab',
            'FPN' or 'Unet'.
    """
    args = parse_args()

    torch.backends.cudnn.benchmark = True

    # WORLD_SIZE is exported by the distributed launcher; absent means a
    # single process.
    args.distributed = int(os.environ.get('WORLD_SIZE', '1')) > 1
    args.world_size = 1

    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()

    cfg.merge_from_file(args.cfg)

    # Output directory: <cfg.DIR>/<config file stem>-<timestamp>.
    # NOTE(review): every process takes its own timestamp, so different ranks
    # may compute different directories — confirm only rank 0's one matters.
    cfg.DIR = os.path.join(
        cfg.DIR,
        args.cfg.split('/')[-1].split('.')[0] +
        datetime.now().strftime('-%Y-%m-%d-%a-%H:%M:%S:%f'))

    if args.local_rank == 0:
        os.makedirs(cfg.DIR, exist_ok=True)
        os.makedirs(os.path.join(cfg.DIR, 'weight'), exist_ok=True)
        os.makedirs(os.path.join(cfg.DIR, 'history'), exist_ok=True)
        shutil.copy(args.cfg, cfg.DIR)

    log_path = os.path.join(cfg.DIR, 'log.txt')
    if os.path.exists(log_path):
        os.remove(log_path)
    logger = setup_logger(distributed_rank=args.local_rank, filename=log_path)
    logger.info("Loaded configuration file {}".format(args.cfg))
    logger.info("Running with config:\n{}".format(cfg))

    arch = cfg.MODEL.arch
    if arch == 'deeplab':
        model = DeepLab(
            num_classes=cfg.DATASET.num_class,
            backbone=cfg.MODEL.backbone,  # resnet101
            output_stride=cfg.MODEL.os,
            ibn_mode=cfg.MODEL.ibn_mode,
            freeze_bn=False)
    elif arch == 'smp-deeplab':
        model = smp.DeepLabV3(encoder_name='resnet101', classes=7)
    elif arch == 'FPN':
        model = smp.FPN(encoder_name='resnet101', classes=7)
    elif arch == 'Unet':
        model = smp.Unet(encoder_name='resnet101', classes=7)
    else:
        # Fail fast instead of crashing later on an unbound `model`.
        raise ValueError("Unsupported MODEL.arch: {!r}".format(arch))

    # 'rgbn' validation data: adapt the model to 4 input channels.
    if cfg.DATASET.val_channels[0] == 'rgbn':
        convert_model(model, 4)

    model = apex.parallel.convert_syncbn_model(model)
    model = model.cuda()

    model = amp.initialize(model, opt_level="O1")

    if args.distributed:
        model = DDP(model, delay_allreduce=True)

    if cfg.VAL.checkpoint != "":
        if args.local_rank == 0:
            logger.info("Loading weight from {}".format(cfg.VAL.checkpoint))

        weight = torch.load(
            cfg.VAL.checkpoint,
            map_location=lambda storage, loc: storage.cuda(args.local_rank))

        if not args.distributed:
            # Checkpoints saved from a wrapped model prefix every key with
            # 'module.'; the model here is unwrapped, so drop the prefix —
            # but only where it is actually present.
            prefix = 'module.'
            weight = {
                (k[len(prefix):] if k.startswith(prefix) else k): v
                for k, v in weight.items()
            }

        model.load_state_dict(weight)

    dataset_val = AgriValDataset(cfg.DATASET.root_dataset,
                                 cfg.DATASET.list_val,
                                 cfg.DATASET,
                                 channels=cfg.DATASET.val_channels[0])

    val_sampler = None
    if args.distributed:
        val_sampler = torch.utils.data.distributed.DistributedSampler(
            dataset_val, num_replicas=args.world_size, rank=args.local_rank)

    loader_val = torch.utils.data.DataLoader(
        dataset_val,
        batch_size=cfg.VAL.batch_size_per_gpu,
        shuffle=False,  # we do not use this param
        num_workers=cfg.VAL.batch_size_per_gpu,
        drop_last=True,
        pin_memory=True,
        sampler=val_sampler)

    cfg.VAL.epoch_iters = len(loader_val)

    cfg.VAL.log_fmt = 'Mean IoU: {:.4f}\n'

    logger.info("World Size: {}".format(args.world_size))

    logger.info("VAL.epoch_iters: {}".format(cfg.VAL.epoch_iters))
    logger.info("VAL.sum_bs: {}".format(cfg.VAL.batch_size_per_gpu *
                                        args.world_size))

    os.makedirs(cfg.VAL.visualized_pred, exist_ok=True)
    os.makedirs(cfg.VAL.visualized_label, exist_ok=True)

    val(loader_val, model, args, logger)
Beispiel #15
0
        in_channels (int): bnumber of input channels
        num_classes (int): number of classes + 1 for background        
        activation (srt): output activation function, default is None
    """
    model = smp.Unet(
        encoder_name=encoder,
        encoder_weights='imagenet',
        classes=num_classes,
        in_channels=in_channels,
        activation=activation,
    )

    return model


# FPN with default settings apart from a single-channel input.
model = smp.FPN(encoder_name='resnet34', in_channels=1)


def get_fpn(encoder: str = 'resnet50',
            in_channels: int = 4,
            num_classes: int = 1,
            activation: str = None):
    """
    Get an FPN model from qubvel's segmentation_models.pytorch library:
    a segmentation model whose encoder starts from ImageNet weights.

    Args:
        encoder (str): encoder backbone, e.g. 'resnext101_32x8d', 'resnet18', 'resnet50', 'resnet101'...
        in_channels (int): number of input channels
        num_classes (int): number of classes + 1 for background
        activation (str): output activation function, default is None

    Returns:
        The constructed ``smp.FPN`` model (previously this function had no
        body and returned None).
    """
    model = smp.FPN(
        encoder_name=encoder,
        encoder_weights='imagenet',
        classes=num_classes,
        in_channels=in_channels,
        activation=activation,
    )

    return model
Beispiel #16
0
def get_model(config):
    """Build the segmentation model described by ``config.MODEL``.

    The architecture (``config.MODEL.ARCHITECTURE``) is one of 'unet', 'fpn',
    'pan', 'pspnet', 'deeplabv3' or 'linknet'; the number of output classes is
    ``len(config.INPUT.CLASSES)``. The model is wrapped in
    ``torch.nn.DataParallel``, optionally loaded from ``config.MODEL.WEIGHT``
    (skipped when empty or 'none'), and moved to ``config.MODEL.DEVICE``.

    Args:
        config: configuration object exposing the ``MODEL`` / ``INPUT``
            attributes read below.

    Returns:
        torch.nn.DataParallel: the wrapped model.

    Raises:
        ValueError: if the architecture name is not supported.
    """
    arch = config.MODEL.ARCHITECTURE
    backbone = config.MODEL.BACKBONE
    encoder_weights = config.MODEL.ENCODER_PRETRAINED_FROM
    in_channels = config.MODEL.IN_CHANNELS
    n_classes = len(config.INPUT.CLASSES)
    activation = config.MODEL.ACTIVATION

    if arch == 'unet':
        # Unet-specific: optional scSE attention in the decoder blocks.
        decoder_attention_type = 'scse' if config.MODEL.UNET_ENABLE_DECODER_SCSE else None
        model = smp.Unet(
            encoder_name=backbone,
            encoder_weights=encoder_weights,
            decoder_channels=config.MODEL.UNET_DECODER_CHANNELS,
            decoder_attention_type=decoder_attention_type,
            in_channels=in_channels,
            classes=n_classes,
            activation=activation
        )
    elif arch == 'fpn':
        model = smp.FPN(
            encoder_name=backbone,
            encoder_weights=encoder_weights,
            decoder_dropout=config.MODEL.FPN_DECODER_DROPOUT,
            in_channels=in_channels,
            classes=n_classes,
            activation=activation
        )
    elif arch == 'pan':
        model = smp.PAN(
            encoder_name=backbone,
            encoder_weights=encoder_weights,
            in_channels=in_channels,
            classes=n_classes,
            activation=activation
        )
    elif arch == 'pspnet':
        model = smp.PSPNet(
            encoder_name=backbone,
            encoder_weights=encoder_weights,
            psp_dropout=config.MODEL.PSPNET_DROPOUT,
            in_channels=in_channels,
            classes=n_classes,
            activation=activation
        )
    elif arch == 'deeplabv3':
        model = smp.DeepLabV3(
            encoder_name=backbone,
            encoder_weights=encoder_weights,
            in_channels=in_channels,
            classes=n_classes,
            activation=activation
        )
    elif arch == 'linknet':
        model = smp.Linknet(
            encoder_name=backbone,
            encoder_weights=encoder_weights,
            in_channels=in_channels,
            classes=n_classes,
            activation=activation
        )
    else:
        raise ValueError(
            f"Unsupported MODEL.ARCHITECTURE {arch!r}; expected one of "
            "'unet', 'fpn', 'pan', 'pspnet', 'deeplabv3', 'linknet'"
        )

    model = torch.nn.DataParallel(model)

    if config.MODEL.WEIGHT and config.MODEL.WEIGHT != 'none':
        # Load on CPU first; the final .to() below places it on the target device.
        model.load_state_dict(
            torch.load(
                config.MODEL.WEIGHT,
                map_location=torch.device('cpu')
            )
        )

    model = model.to(config.MODEL.DEVICE)
    return model
# > You can read more about them in [our blog post](https://github.com/catalyst-team/catalyst-info#catalyst-info-1-segmentation-models).
#
# But for now let's take the model from [segmentation_models.pytorch](https://github.com/qubvel/segmentation_models.pytorch) (SMP for short). The same segmentation architectures have been implemented in this repository, but there are many more pre-trained encoders.
#
# [![Segmentation Models logo](https://raw.githubusercontent.com/qubvel/segmentation_models.pytorch/master/pics/logo-small-w300.png)](https://github.com/qubvel/segmentation_models.pytorch)

# In[22]:

aux_params = {
    'dropout': 0.5,  # dropout ratio of the auxiliary head (smp default is None)
    'classes': 6,  # number of labels predicted by the auxiliary head
}

# Feature Pyramid Network on a pre-trained ResNeXt50 (32x4d) backbone; the
# segmentation head predicts the same label set as the auxiliary head.
model = smp.FPN(encoder_name="resnext50_32x4d",
                classes=aux_params['classes'],
                aux_params=aux_params)

# ### Model training
#
# We will optimize loss as the sum of IoU, Dice and BCE, specifically this function: $IoU + Dice + 0.8*BCE$.
#

# In[23]:

# we have multiple criterions
criterion = {
    "dice": DiceLoss(),
    "iou": IoULoss(),
    #"bce": nn.BCEWithLogitsLoss()
    "ce": nn.CrossEntropyLoss()
def cutHorizontal(x, tile_width=400, num_tiles=4):
    """Cut the last axis of ``x`` into consecutive tiles and stack them on dim 0.

    Tile ``i`` covers columns ``[i * tile_width, (i + 1) * tile_width)``. With
    the defaults, a tensor of shape ``(B, ..., 1600)`` becomes
    ``(4 * B, ..., 400)``, ordered tile-major (all rows of tile 0 first).
    Columns beyond ``num_tiles * tile_width`` are ignored; the last axis is
    assumed to be at least that long.

    Args:
        x: input tensor.
        tile_width: width of each tile along the last axis (default 400).
        num_tiles: number of tiles to take (default 4).
    """
    tiles = [x[..., i * tile_width:(i + 1) * tile_width] for i in range(num_tiles)]
    return torch.cat(tiles, dim=0)

def to416(x):
    """Zero-pad the last axis of ``x`` from 400 to 416 columns (8 on each side).

    The output is allocated with ``x.new_zeros`` so it keeps ``x``'s dtype and
    device; ``torch.zeros(size)`` always produced a CPU float32 tensor.
    Expects ``x.size(-1) == 400`` (the width of the ``8:-8`` slice).
    """
    size = list(x.size())
    size[-1] = 416
    padded = x.new_zeros(size)
    padded[..., 8:-8] = x
    return padded


# In[9]:


model = smp.FPN(
    encoder_name=MODEL_NAME,
    encoder_weights="imagenet",
    classes=4,
    activation=None,  # no output activation: the head emits raw logits
)


# ### Training and Validation

# In[10]:


class Trainer:
    """Holds the training / validation settings used for ``model``.

    Hyper-parameters come from the module-level ``BS``, ``LR`` and ``EPOCHS``
    settings; the same batch size is used for both phases.
    """

    def __init__(self, model):
        self.num_workers = 6
        phase_bs = BS
        self.batch_size = {"train": phase_bs, "val": phase_bs}
        # Number of mini-batches per optimiser step: 32 // training batch size.
        self.accumulation_steps = 32 // phase_bs
        self.lr = LR
        self.num_epochs = EPOCHS
Beispiel #19
0
def _widen_first_conv(model, num_channels):
    """Replace ``model.encoder.conv1`` with a copy taking ``num_channels`` inputs.

    The original (3-channel) filters are kept for channels 0-2 and every extra
    channel is initialised with a copy of the channel-0 filters. Output
    channels, kernel size, stride and padding are taken from the old conv.
    """
    old_conv = model.encoder.conv1
    weight = old_conv.weight.clone()
    model.encoder.conv1 = torch.nn.Conv2d(num_channels,
                                          old_conv.out_channels,
                                          kernel_size=old_conv.kernel_size,
                                          stride=old_conv.stride,
                                          padding=old_conv.padding,
                                          bias=old_conv.bias is not None)
    with torch.no_grad():
        print(f"using {num_channels}c")
        new_weight = model.encoder.conv1.weight
        new_weight[:, :3] = weight
        new_weight[:, 3:] = new_weight[:, :1]


def get_model(num_classes, model_name):
    """Build the segmentation network named ``model_name``.

    For the smp models ('UNet', 'PSPNet', 'FPN'; ResNet-50 encoder) the stem
    convolution is widened to ``args.num_channels`` inputs when more than
    three channels are requested (reads the module-level ``args``).

    Returns:
        The model, or None (after printing "error in model") for an unknown
        ``model_name``.
    """
    if model_name == "UNet":
        print("using UNet")
        model = smp.Unet(encoder_name='resnet50',
                         classes=num_classes,
                         activation='softmax')
    elif model_name == "PSPNet":
        print("using PSPNet")
        model = smp.PSPNet(encoder_name="resnet50",
                           classes=num_classes,
                           activation='softmax')
    elif model_name == "FPN":
        print("using FPN")
        model = smp.FPN(encoder_name='resnet50', classes=num_classes)
    elif model_name == "AlbuNet":
        print("using AlbuNet")
        return AlbuNet(pretrained=True, num_classes=num_classes)
    elif model_name == "YpUnet":
        print("using YpUnet")
        return UNet(pretrained=True, num_classes=num_classes)
    else:
        print("error in model")
        return None

    # Only the smp models reach this point; they share the widenable stem.
    if args.num_channels > 3:
        _widen_first_conv(model, args.num_channels)
    return model
Beispiel #20
0
def _str2bool(value):
    """Parse a command-line boolean such as "True", "false", "1" or "0".

    ``argparse`` with ``type=bool`` calls ``bool("False")``, which is True, so
    ``--optimize False`` used to enable optimisation.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if text in ('0', 'false', 'f', 'no', 'n', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value, got {!r}'.format(value))


def _build_cloud_model(model_name, encoder):
    """Instantiate the 4-class smp model ``model_name`` with no output activation."""
    constructors = {
        'unet': smp.Unet,
        'fpn': smp.FPN,
        'pspnet': smp.PSPNet,
        'linknet': smp.Linknet,
    }
    if model_name not in constructors:
        raise ValueError('unknown --model {!r}; expected one of {}'.format(
            model_name, sorted(constructors)))
    return constructors[model_name](encoder_name=encoder,
                                    encoder_weights='imagenet',
                                    classes=4,
                                    activation=None)


def _wrap_with_tta(model, merge):
    """Wrap ``model`` with flip / 180-degree rotation TTA merged by ``merge``."""
    return tta.SegmentationTTAWrapper(model,
                                      tta.Compose([
                                          tta.HorizontalFlip(),
                                          tta.VerticalFlip(),
                                          tta.Rotate90(angles=[0, 180])
                                      ]),
                                      merge_mode=merge)


def _to_350x525(array):
    """Bilinearly resize a 2-D array to 350x525 unless it already has that shape."""
    if array.shape != (350, 525):
        array = cv2.resize(array,
                           dsize=(525, 350),
                           interpolation=cv2.INTER_LINEAR)
    return array


def _predict_cloud_valid(model, valid_loader, valid_dataset):
    """Run ``model`` on the validation split.

    Returns:
        (probabilities, valid_masks): ``probabilities[4 * i + c]`` holds the
        raw class-``c`` output for image ``i`` (sigmoid is applied later) and
        ``valid_masks`` lists the ground-truth masks in the same order, all
        resized to 350x525.
    """
    runner = SupervisedRunner()
    print("INFERRING ON VALID")
    runner.infer(
        model=model,
        loaders={'valid': valid_loader},
        callbacks=[InferCallback()],
        verbose=True,
    )

    valid_masks = []
    probabilities = np.zeros((4 * len(valid_dataset), 350, 525))
    logits = runner.callbacks[0].predictions["logits"]
    for i, ((_, mask), output) in enumerate(tqdm(zip(valid_dataset, logits))):
        valid_masks.extend(_to_350x525(m) for m in mask)
        for j, probability in enumerate(output):
            probabilities[4 * i + j, :, :] = _to_350x525(probability)
    return probabilities, valid_masks


def _grid_search_cloud_params(probabilities, valid_masks):
    """For each of the 4 classes pick the (threshold, size) with the best mean Dice."""
    class_params = {}
    for class_id in range(4):
        print(class_id)
        attempts = []
        for t in range(30, 70, 5):
            t /= 100
            # Size candidates step by 2500 (was 175000, a typo for 17500).
            for ms in [7500, 10000, 12500, 15000, 17500]:
                masks = [post_process(sigmoid(probabilities[i]), t, ms)[0]
                         for i in range(class_id, len(probabilities), 4)]
                # Both empty counts as a perfect match.
                scores = [
                    1 if pred.sum() == 0 and truth.sum() == 0 else dice(pred, truth)
                    for pred, truth in zip(masks, valid_masks[class_id::4])
                ]
                attempts.append((t, ms, np.mean(scores)))

        attempts_df = pd.DataFrame(attempts,
                                   columns=['threshold', 'size', 'dice'])
        attempts_df = attempts_df.sort_values('dice', ascending=False)
        print(attempts_df.head())
        class_params[class_id] = (attempts_df['threshold'].values[0],
                                  attempts_df['size'].values[0])
    return class_params


def main():
    """Predict cloud masks for the test set and write a submission CSV.

    CLI flags select the smp architecture and encoder, the checkpoint to load
    (``--loc``), test-time augmentation for the validation search
    (``--tta_pre``) and for test inference (``--tta_post``), and whether to
    grid-search per-class (threshold, size) post-processing parameters on the
    validation split (``--optimize``); otherwise every class uses
    ``(--thresh, --min_size)``. The CSV is written to ``--name``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--encoder', type=str, default='efficientnet-b0')
    parser.add_argument('--model', type=str, default='unet')
    parser.add_argument('--loc', type=str)
    parser.add_argument('--data_folder', type=str, default='../input/')
    parser.add_argument('--batch_size', type=int, default=2)
    parser.add_argument('--optimize', type=_str2bool, default=False)
    parser.add_argument('--tta_pre', type=_str2bool, default=False)
    parser.add_argument('--tta_post', type=_str2bool, default=False)
    parser.add_argument('--merge', type=str, default='mean')
    parser.add_argument('--min_size', type=int, default=10000)
    parser.add_argument('--thresh', type=float, default=0.5)
    parser.add_argument('--name', type=str)

    args = parser.parse_args()
    encoder = args.encoder
    bs = args.batch_size
    merge = args.merge

    model = _build_cloud_model(args.model, encoder)

    preprocessing_fn = smp.encoders.get_preprocessing_fn(encoder, 'imagenet')

    test_df = prepare_dataset(get_dataset(train=False))
    test_ids = test_df['Image_Label'].apply(
        lambda x: x.split('_')[0]).drop_duplicates().values
    test_dataset = CloudDataset(
        df=test_df,
        datatype='test',
        img_ids=test_ids,
        transforms=valid1(),
        preprocessing=get_preprocessing(preprocessing_fn))
    test_loader = DataLoader(test_dataset, batch_size=bs, shuffle=False)

    val_df = prepare_dataset(get_dataset(train=True))
    _, val_ids = get_train_test(val_df)
    valid_dataset = CloudDataset(
        df=val_df,
        datatype='train',
        img_ids=val_ids,
        transforms=valid1(),
        preprocessing=get_preprocessing(preprocessing_fn))
    valid_loader = DataLoader(valid_dataset, batch_size=bs, shuffle=False)

    model.load_state_dict(torch.load(args.loc)['model_state_dict'])

    # Default post-processing: the same (threshold, size) for all 4 classes.
    class_params = {class_id: (args.thresh, args.min_size)
                    for class_id in range(4)}

    if args.optimize:
        print("OPTIMIZING")
        print(args.tta_pre)
        opt_model = _wrap_with_tta(model, merge) if args.tta_pre else model
        probabilities, valid_masks = _predict_cloud_valid(
            opt_model, valid_loader, valid_dataset)

        print("RUNNING GRID SEARCH")
        class_params = _grid_search_cloud_params(probabilities, valid_masks)

        del opt_model
        del valid_masks
        del probabilities
    gc.collect()

    if args.tta_post:
        model = _wrap_with_tta(model, merge)
    print(args.tta_post)

    runner = SupervisedRunner()
    runner.infer(
        model=model,
        loaders={'test': test_loader},
        callbacks=[InferCallback()],
        verbose=True,
    )

    encoded_pixels = []
    image_id = 0
    for image in tqdm(runner.callbacks[0].predictions['logits']):
        for prob in image:
            # Outputs are ordered image-major, class-minor.
            threshold, min_size = class_params[image_id % 4]
            predict, num_predict = post_process(sigmoid(_to_350x525(prob)),
                                                threshold, min_size)
            if num_predict == 0:
                encoded_pixels.append('')
            else:
                encoded_pixels.append(mask2rle(predict))
            image_id += 1

    test_df['EncodedPixels'] = encoded_pixels
    test_df.to_csv(args.name, columns=['Image_Label', 'EncodedPixels'],
                   index=False)
def _add_cloud_label_columns(df):
    """Add ``label`` / ``im_id`` columns parsed from ``Image_Label`` ("<im_id>_<label>").

    Splits on the last underscore. The previous ``str.replace`` approach also
    removed every other occurrence of ``'_<label>'`` inside the image id.
    """
    df['label'] = df['Image_Label'].apply(lambda x: x.rsplit('_', 1)[-1])
    df['im_id'] = df['Image_Label'].apply(lambda x: x.rsplit('_', 1)[0])
    return df


def main():
    """Train one cross-validation fold of a cloud segmentation model with Catalyst.

    Settings are read from the module-level ``args`` namespace. Logs and
    checkpoints go to ``./log/logs_{model_name}_fold_{fold_num}_{encoder}/segmentation``.

    Raises:
        ValueError: if ``args.model_name`` is not 'Unet', 'Linknet', 'FPN' or 'ORG'.
    """
    fold_path = args.fold_path
    fold_num = args.fold_num
    model_name = args.model_name
    encoder = args.encoder
    num_workers = args.num_workers
    batch_size = args.batch_size
    num_epochs = args.num_epochs
    learn_late = args.learn_late  # learning rate
    # The string 'None' disables decoder attention.
    attention_type = None if args.attention_type == 'None' else args.attention_type

    train = _add_cloud_label_columns(pd.read_csv(args.train_csv))

    train_fold = pd.read_csv(f'{fold_path}/train_file_fold_{fold_num}.csv')
    val_fold = pd.read_csv(f'{fold_path}/valid_file_fold_{fold_num}.csv')
    train_ids = np.array(train_fold.file_name)
    valid_ids = np.array(val_fold.file_name)

    encoder_weights = 'imagenet'

    if model_name == 'Unet':
        model = smp.Unet(
            encoder_name=encoder,
            encoder_weights=encoder_weights,
            classes=4,
            activation='softmax',
            attention_type=attention_type,
        )
    elif model_name == 'Linknet':
        model = smp.Linknet(
            encoder_name=encoder,
            encoder_weights=encoder_weights,
            classes=4,
            activation='softmax',
        )
    elif model_name == 'FPN':
        model = smp.FPN(
            encoder_name=encoder,
            encoder_weights=encoder_weights,
            classes=4,
            activation='softmax',
        )
    elif model_name == 'ORG':
        model = Linknet_resnet18_ASPP()
    else:
        raise ValueError(
            f"unknown model_name {model_name!r}; expected 'Unet', 'Linknet', "
            "'FPN' or 'ORG'")

    preprocessing_fn = smp.encoders.get_preprocessing_fn(
        encoder, encoder_weights)

    train_dataset = CloudDataset(
        df=train,
        datatype='train',
        img_ids=train_ids,
        transforms=get_training_augmentation(),
        preprocessing=get_preprocessing(preprocessing_fn))

    valid_dataset = CloudDataset(
        df=train,
        datatype='valid',
        img_ids=valid_ids,
        transforms=get_validation_augmentation(),
        preprocessing=get_preprocessing(preprocessing_fn))

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        drop_last=True,
        pin_memory=True,
    )
    valid_loader = DataLoader(valid_dataset,
                              batch_size=batch_size,
                              shuffle=False,
                              num_workers=num_workers)

    loaders = {"train": train_loader, "valid": valid_loader}

    logdir = f"./log/logs_{model_name}_fold_{fold_num}_{encoder}/segmentation"
    print(logdir)

    if model_name == 'ORG':
        param_groups = [{'params': model.parameters(), 'lr': learn_late}]
    else:
        # Separate decoder / encoder groups (currently with the same lr).
        param_groups = [
            {'params': model.decoder.parameters(), 'lr': learn_late},
            {'params': model.encoder.parameters(), 'lr': learn_late},
        ]
    optimizer = NAdam(param_groups)

    scheduler = ReduceLROnPlateau(optimizer, factor=0.5, patience=0)
    criterion = smp.utils.losses.BCEDiceLoss()

    runner = SupervisedRunner()

    runner.train(model=model,
                 criterion=criterion,
                 optimizer=optimizer,
                 scheduler=scheduler,
                 loaders=loaders,
                 callbacks=[
                     DiceCallback(),
                     EarlyStoppingCallback(patience=5, min_delta=1e-7)
                 ],
                 logdir=logdir,
                 num_epochs=num_epochs,
                 verbose=1)
Beispiel #22
0
def generate_class_params(i_dont_know_how_to_return_values_without_map):
    """Grid-search per-class post-processing parameters on the validation split.

    Loads ``{logdir}/checkpoints/best.pth`` into an FPN, predicts every
    validation image, resizes outputs and masks to 350x525 and, for each of
    the 4 classes, keeps the (threshold, size) pair with the best mean Dice.

    Args:
        i_dont_know_how_to_return_values_without_map: unused; presumably only
            present so the function can be called through ``map`` — confirm.

    Returns:
        dict: ``class_id -> (best_threshold, best_size)``.
    """
    preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)
    valid_dataset = CloudDataset(df=train, datatype='valid', img_ids=valid_ids, transforms=get_validation_augmentation(), preprocessing=get_preprocessing(preprocessing_fn))
    valid_loader = DataLoader(valid_dataset, batch_size=1, shuffle=False, num_workers=0)

    model = smp.FPN(
        encoder_name=ENCODER,
        encoder_weights=ENCODER_WEIGHTS,
        classes=4,
        activation=ACTIVATION,
    )

    runner = SupervisedRunner()
    # Generate validation predictions
    loaders = {"infer": valid_loader}
    runner.infer(
        model=model,
        loaders=loaders,
        callbacks=[
            CheckpointCallback(
                resume=f"{logdir}/checkpoints/best.pth"),
            InferCallback()
        ],
    )

    valid_masks = []
    # One slot per (image, class). Sized from the dataset: the former
    # hard-coded 2220 raised IndexError for larger validation splits.
    probabilities = np.zeros((4 * len(valid_dataset), 350, 525))
    for i, ((_, mask), output) in enumerate(tqdm.tqdm(zip(
            valid_dataset, runner.callbacks[0].predictions["logits"]))):
        for m in mask:
            if m.shape != (350, 525):
                m = cv2.resize(m, dsize=(525, 350), interpolation=cv2.INTER_LINEAR)
            valid_masks.append(m)

        for j, probability in enumerate(output):
            if probability.shape != (350, 525):
                probability = cv2.resize(probability, dsize=(525, 350), interpolation=cv2.INTER_LINEAR)
            probabilities[i * 4 + j, :, :] = probability

    class_params = {}
    for class_id in range(4):
        print(class_id)
        attempts = []
        for t in range(30, 100, 5):
            t /= 100
            for ms in [1200, 5000, 10000]:
                masks = []
                for i in range(class_id, len(probabilities), 4):
                    probability = probabilities[i]
                    predict, num_predict = post_process(sigmoid(probability), t, ms)
                    masks.append(predict)

                # Both prediction and ground truth empty counts as a perfect match.
                d = []
                for pred, truth in zip(masks, valid_masks[class_id::4]):
                    if pred.sum() == 0 and truth.sum() == 0:
                        d.append(1)
                    else:
                        d.append(dice(pred, truth))

                attempts.append((t, ms, np.mean(d)))

        attempts_df = pd.DataFrame(attempts, columns=['threshold', 'size', 'dice'])

        attempts_df = attempts_df.sort_values('dice', ascending=False)
        print(attempts_df.head())
        best_threshold = attempts_df['threshold'].values[0]
        best_size = attempts_df['size'].values[0]

        class_params[class_id] = (best_threshold, best_size)

    return class_params
Beispiel #23
0
#     visualize(image=image, mask=mask)

import torch
import numpy as np
import segmentation_models_pytorch as smp

# Encoder / head configuration.
ENCODER = 'efficientnet-b0'
ENCODER_WEIGHTS = 'imagenet'
CLASSES = CarlaDataset.CLASSES
ACTIVATION = 'softmax2d'
DEVICE = 'cuda'

# FPN with a pretrained encoder and one output channel per dataset class.
model = smp.FPN(encoder_name=ENCODER,
                encoder_weights=ENCODER_WEIGHTS,
                classes=len(CLASSES),
                activation=ACTIVATION)

# Input preprocessing matching the selected encoder weights.
preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)

train_dataset = CarlaDataset(x_train_dir,
                             y_train_dir,
                             augmentation=get_training_augmentation(),
                             preprocessing=get_preprocessing(preprocessing_fn),
                             classes=CLASSES)

valid_dataset = CarlaDataset(
    x_valid_dir,
Beispiel #24
0
def validation(valid_ids, num_split, encoder, decoder):
    """
    Validate a trained model and search for post-processing parameters.

    Loads ``./logs/log_{encoder}_{decoder}/log_{num_split}/checkpoints/best.pth``,
    predicts the validation images (with flip/scale TTA), resizes outputs and
    masks to 350x525 and, for each of the 4 classes, scores every
    (threshold, size) pair by mean Dice. Results are written, best first, to
    ``./params/{encoder}_{decoder}_par/params_{num_split}/tta_params_{class_id}.csv``
    (the directory is created if missing).

    Args:
        valid_ids: validation image ids.
        num_split: fold index used in the log / output paths.
        encoder: smp encoder name.
        decoder: 'unet' for smp.Unet, anything else for smp.FPN.
    """
    import os

    train = pd.read_csv("./data/Clouds_Classify/train.csv")

    train['label'] = train['Image_Label'].apply(lambda x: x.split('_')[1])
    train['im_id'] = train['Image_Label'].apply(lambda x: x.split('_')[0])

    ENCODER = encoder
    ENCODER_WEIGHTS = 'imagenet'
    if decoder == 'unet':
        model = smp.Unet(
            encoder_name=ENCODER,
            encoder_weights=ENCODER_WEIGHTS,
            classes=4,
            activation=None,
        )
    else:
        model = smp.FPN(
            encoder_name=ENCODER,
            encoder_weights=ENCODER_WEIGHTS,
            classes=4,
            activation=None,
        )
    preprocessing_fn = smp.encoders.get_preprocessing_fn(
        ENCODER, ENCODER_WEIGHTS)

    num_workers = 4
    valid_bs = 32
    valid_dataset = CloudDataset(
        df=train,
        transforms=get_validation_augmentation(),
        datatype='valid',
        img_ids=valid_ids,
        preprocessing=get_preprocessing(preprocessing_fn))
    valid_loader = DataLoader(valid_dataset,
                              batch_size=valid_bs,
                              shuffle=False,
                              num_workers=num_workers)

    loaders = {"valid": valid_loader}
    logdir = "./logs/log_{}_{}/log_{}".format(encoder, decoder, num_split)

    valid_masks = []
    probabilities = np.zeros((len(valid_ids) * 4, 350, 525))

    ############### TTA prediction ####################
    use_TTA = True
    checkpoint_path = logdir + '/checkpoints/best.pth'
    runner_out = []
    model.load_state_dict(torch.load(checkpoint_path)['model_state_dict'])

    if use_TTA:
        transforms = tta.Compose([
            tta.HorizontalFlip(),
            tta.VerticalFlip(),
            tta.Scale(scales=[5 / 6, 1, 7 / 6]),
        ])
        tta_model = tta.SegmentationTTAWrapper(model,
                                               transforms,
                                               merge_mode='mean')
    else:
        tta_model = model

    tta_model = tta_model.cuda()
    tta_model.eval()

    with torch.no_grad():
        for img, _ in tqdm.tqdm(loaders['valid']):
            img = img.cuda()
            batch_preds = tta_model(img).cpu().numpy()
            runner_out.extend(batch_preds)
    runner_out = np.array(runner_out)
    ######################END##########################

    for i, ((_, mask),
            output) in enumerate(tqdm.tqdm(zip(valid_dataset, runner_out))):
        for m in mask:
            if m.shape != (350, 525):
                m = cv2.resize(m,
                               dsize=(525, 350),
                               interpolation=cv2.INTER_LINEAR)
            valid_masks.append(m)

        for j, probability in enumerate(output):
            if probability.shape != (350, 525):
                probability = cv2.resize(probability,
                                         dsize=(525, 350),
                                         interpolation=cv2.INTER_LINEAR)
            probabilities[i * 4 + j, :, :] = probability

    # Find optimal values.
    # Per class: [[threshold% start, stop), [size start, stop)].
    print('searching for optimal param...')
    params_0 = [[35, 76], [12000, 19001]]
    params_1 = [[35, 76], [12000, 19001]]
    params_2 = [[35, 76], [12000, 19001]]
    params_3 = [[35, 76], [8000, 15001]]
    param = [params_0, params_1, params_2, params_3]

    # to_csv fails if the output directory does not exist yet.
    out_dir = './params/{}_{}_par/params_{}'.format(encoder, decoder, num_split)
    os.makedirs(out_dir, exist_ok=True)

    for class_id in range(4):
        par = param[class_id]
        attempts = []
        for t in range(par[0][0], par[0][1], 5):
            t /= 100
            for ms in range(par[1][0], par[1][1], 2000):
                masks = []
                print('==> searching [class_id:%d threshold:%.3f ms:%d]' %
                      (class_id, t, ms))
                for i in tqdm.tqdm(range(class_id, len(probabilities), 4)):
                    probability = probabilities[i]
                    predict, _ = post_process(sigmoid(probability), t, ms)
                    masks.append(predict)

                # Both prediction and ground truth empty counts as a perfect match.
                d = []
                for pred, truth in zip(masks, valid_masks[class_id::4]):
                    if pred.sum() == 0 and truth.sum() == 0:
                        d.append(1)
                    else:
                        d.append(dice(pred, truth))

                attempts.append((t, ms, np.mean(d)))

        attempts_df = pd.DataFrame(attempts,
                                   columns=['threshold', 'size', 'dice'])

        attempts_df = attempts_df.sort_values('dice', ascending=False)
        attempts_df.to_csv(
            '{}/tta_params_{}.csv'.format(out_dir, class_id),
            columns=['threshold', 'size', 'dice'],
            index=False)
Beispiel #25
0
def _read_list_file(path):
    """Return every line of the text file at ``path``, closing the file afterwards."""
    with open(path) as f:
        return f.readlines()


def _evaluate_and_checkpoint(args, logger, models, test_loaders, iteration):
    """Evaluate all per-type models, log their accuracies and save one checkpoint.

    The saved state maps ``'FPN_<i>'`` to ``models[i].state_dict()`` and is
    written through ``logger.save_ckpt_iter`` tagged with ``iteration``.
    """
    accs = test(args, models, test_loaders)
    state = {}
    for i in range(args.type):
        state['FPN_' + str(i)] = models[i].state_dict()
        logger.add_scalar_print('acc_' + str(i), accs[i])

        logger.log('Acc {:d}: {:.3f}'.format(i, accs[i]))

    logger.save_ckpt_iter(state=state, iter=iteration)


def train(args):
    """Train one FPN segmentation model per object type.

    For every ``i`` in ``range(args.type)`` an ``smp.FPN`` (4 input channels,
    3 output classes) is wrapped in ``nn.DataParallel`` over ``args.gpu_ids``
    and optimised with Nesterov SGD on the images listed in
    ``<args.data_list_path>/label_<i>_train.txt``. The models are evaluated on
    ``label_<i>_test.txt`` periodically (see ``args.test_interval``) and once
    more after the last epoch; each evaluation saves a checkpoint holding the
    ``'FPN_<i>'`` state dicts.

    Args:
        args: parsed options. Reads ckpt_path, vis_path, gpu_ids, type,
            encoder, lr, data_list_path, data_path, batch_size, test_batch,
            total_epochs and test_interval; ``args.logger`` is set as a side
            effect.
    """
    logger = Logger(ckpt_path=args.ckpt_path, tsbd_path=args.vis_path)
    args.logger = logger

    train_transforms = image_train()
    test_transforms = image_test()

    models = []
    optimizers = []
    train_loaders = []
    test_loaders = []
    # Parse the device list once; every model is replicated on the same GPUs.
    device_ids = [int(gpu) for gpu in args.gpu_ids.split(',')]
    for i in range(args.type):
        model = smp.FPN(args.encoder, in_channels=4, classes=3).cuda()
        model = nn.DataParallel(model, device_ids=device_ids)

        optimizer = optim.SGD(model.parameters(),
                              lr=args.lr,
                              momentum=0.9,
                              weight_decay=0.0005,
                              nesterov=True)

        models.append(model)
        optimizers.append(optimizer)

        train_list = os.path.join(args.data_list_path,
                                  'label_' + str(i) + '_train.txt')
        test_list = os.path.join(args.data_list_path,
                                 'label_' + str(i) + '_test.txt')

        # Read the list files with a context manager so the handles are closed.
        train_dset = ImageList(_read_list_file(train_list),
                               datadir=args.data_path,
                               transform=train_transforms)
        test_dset = ImageList(_read_list_file(test_list),
                              datadir=args.data_path,
                              transform=test_transforms)

        train_loader = DataLoader(train_dset,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=4,
                                  drop_last=False)
        test_loader = DataLoader(test_dset,
                                 batch_size=args.test_batch,
                                 shuffle=False,
                                 num_workers=4,
                                 drop_last=False)

        train_loaders.append(train_loader)
        test_loaders.append(test_loader)

    total_epochs = args.total_epochs
    total_progress_bar = tqdm.tqdm(desc='Train iter', total=total_epochs)

    for epoch in range(total_epochs):
        total_progress_bar.update(1)
        # Train each per-type model for one epoch on its own data.
        for tp in range(args.type):
            train_loader = train_loaders[tp]
            model = models[tp]
            optimizer = optimizers[tp]

            model.train()

            temp_progress_bar = tqdm.tqdm(desc='Train iter for type ' +
                                          str(tp),
                                          total=len(train_loader))
            for imgs, pixels, _obj_types, rewards in train_loader:
                temp_progress_bar.update(1)

                imgs = imgs.cuda()
                pixels = pixels.cuda()
                rewards = rewards.cuda()
                masks = model(imgs)

                # TODO: better format
                loss = CE_pixel(masks, pixels, rewards)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # Visualisation / logging.
                loss_value = loss.item()
                logger.add_scalar('loss_' + str(tp),
                                  loss_value / args.batch_size)
                logger.step(1)

                logger.log('Loss {:d}: {:.3f}'.format(tp, loss_value))
            temp_progress_bar.close()

        # Periodic evaluation. For any interval > 1, `1 % interval` is 1, so
        # this still fires at epochs 1, 1 + interval, ... as before; for an
        # interval of 1 it is 0, so evaluation runs every epoch instead of
        # never (the old `== 1` test could not match `epoch % 1`).
        if epoch % args.test_interval == 1 % args.test_interval:
            _evaluate_and_checkpoint(args, logger, models, test_loaders,
                                     epoch)

    total_progress_bar.close()

    # Final evaluation and checkpoint after the last epoch.
    _evaluate_and_checkpoint(args, logger, models, test_loaders, total_epochs)
Beispiel #26
0
def testing(num_split, class_params, encoder, decoder, use_tta=True):
    """Run test-set inference and write an RLE submission CSV.

    Loads ``./logs/log_<encoder>_<decoder>/log_<num_split>/checkpoints/best.pth``
    into a 4-class ``smp.Unet`` (when ``decoder == 'unet'``) or ``smp.FPN``
    (any other value). Optionally it wraps the model with horizontal-flip,
    vertical-flip and 5/6, 1, 7/6 scale test-time augmentation (mean merge).
    Every per-class output map is resized to 350 (h) x 525 (w), passed through
    ``sigmoid`` and ``post_process``, and RLE-encoded. The result is written to
    ``./sub/<encoder>_<decoder>/tta_submission_<num_split>.csv``.

    Args:
        num_split: fold index used in the log and output paths.
        class_params: per-class ``(threshold, size)`` pairs forwarded to
            ``post_process``, indexed by output channel.
        encoder: segmentation_models_pytorch encoder name.
        decoder: ``'unet'`` selects ``smp.Unet``; anything else ``smp.FPN``.
        use_tta: wrap the model with flip/scale TTA (default ``True``, the
            previously hard-coded behaviour).
    """
    import gc
    torch.cuda.empty_cache()
    gc.collect()

    sub_path = "./data/Clouds_Classify/sample_submission.csv"
    # Let pandas open (and close) the file itself.
    sub = pd.read_csv(sub_path)

    sub['label'] = sub['Image_Label'].apply(lambda x: x.split('_')[1])
    sub['im_id'] = sub['Image_Label'].apply(lambda x: x.split('_')[0])

    preprocessing_fn = smp.encoders.get_preprocessing_fn(encoder, 'imagenet')
    model_cls = smp.Unet if decoder == 'unet' else smp.FPN
    model = model_cls(
        encoder_name=encoder,
        encoder_weights='imagenet',
        classes=4,
        activation=None,
    )

    # The encoded masks are assigned positionally to `sub` below, so the
    # images must be processed in the order they appear in the submission
    # file (one block of per-class rows per image). `os.listdir` returns an
    # arbitrary order and could misalign predictions with rows.
    test_ids = sub['im_id'].drop_duplicates().tolist()

    test_dataset = CloudDataset(
        df=sub,
        transforms=get_validation_augmentation(),
        datatype='test',
        img_ids=test_ids,
        preprocessing=get_preprocessing(preprocessing_fn))
    test_loader = DataLoader(test_dataset,
                             batch_size=1,
                             shuffle=False,
                             num_workers=2)

    logdir = "./logs/log_{}_{}/log_{}".format(encoder, decoder, num_split)
    checkpoint_path = logdir + '/checkpoints/best.pth'
    # Load on CPU first; the model is moved to the GPU below.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])

    # Predict with PyTorch TTA.
    if use_tta:
        transforms = tta.Compose([
            tta.HorizontalFlip(),
            tta.VerticalFlip(),
            tta.Scale(scales=[5 / 6, 1, 7 / 6]),
        ])
        tta_model = tta.SegmentationTTAWrapper(model,
                                               transforms,
                                               merge_mode='mean')
    else:
        tta_model = model

    tta_model = tta_model.cuda()
    tta_model.eval()

    runner_out = []
    with torch.no_grad():
        for img, _ in tqdm.tqdm(test_loader):
            batch_preds = tta_model(img.cuda()).cpu().numpy()
            runner_out.extend(batch_preds)
    runner_out = np.array(runner_out)

    encoded_pixels = []
    for output in tqdm.tqdm(runner_out):
        for j, probability in enumerate(output):
            if probability.shape != (350, 525):
                # cv2 takes dsize as (width, height).
                probability = cv2.resize(probability,
                                         dsize=(525, 350),
                                         interpolation=cv2.INTER_LINEAR)
            logit = sigmoid(probability)
            predict, num_predict = post_process(logit, class_params[j][0],
                                                class_params[j][1])

            if num_predict == 0:
                encoded_pixels.append('')
            else:
                encoded_pixels.append(mask2rle(predict))

    out_dir = './sub/{}_{}'.format(encoder, decoder)
    # Create the output directory so to_csv does not fail on a fresh checkout.
    os.makedirs(out_dir, exist_ok=True)
    sub['EncodedPixels'] = encoded_pixels
    sub.to_csv(os.path.join(out_dir,
                            'tta_submission_{}.csv'.format(num_split)),
               columns=['Image_Label', 'EncodedPixels'],
               index=False)
Beispiel #27
0
 def __init__(self):
     """Build the wrapped segmentation network.

     ``self.model`` is a 4-class ``smp.FPN`` with a ``dpn131`` encoder whose
     weights are initialised from ImageNet. ``activation=None`` leaves the
     outputs unactivated.
     """
     super().__init__()
     backbone = 'dpn131'
     self.model = smp.FPN(backbone,
                          classes=4,
                          encoder_weights='imagenet',
                          activation=None)
Beispiel #28
0
def training(train_ids, valid_ids, num_split, encoder, decoder,
             num_epochs=50, batch_size=12, num_workers=4):
    """Train a 4-class cloud-segmentation model with Catalyst's SupervisedRunner.

    Builds an ``smp.Unet`` (when ``decoder == 'unet'``) or ``smp.FPN`` (any
    other value) with an ImageNet-initialised ``encoder``. It trains the model
    with Adam (decoder lr 1e-2, encoder lr 1e-3), ``ReduceLROnPlateau`` and a
    BCE+Dice loss. Logs and checkpoints go under
    ``./logs/log_<encoder>_<decoder>/log_<num_split>``. Finally it runs
    inference on the validation split from ``checkpoints/best.pth``.

    Args:
        train_ids: image ids used for training.
        valid_ids: image ids used for validation.
        num_split: fold index used in the log path.
        encoder: segmentation_models_pytorch encoder name.
        decoder: ``'unet'`` selects ``smp.Unet``; anything else ``smp.FPN``.
        num_epochs: training epochs (default 50, as previously hard-coded).
        batch_size: batch size of both loaders (default 12).
        num_workers: DataLoader worker processes (default 4).
    """
    train_csv = "./data/Clouds_Classify/train.csv"

    # Data overview. Pass the path so pandas closes the file; the local is
    # named train_df so it does not shadow the module-level train() function.
    train_df = pd.read_csv(train_csv)

    train_df['label'] = train_df['Image_Label'].apply(
        lambda x: x.split('_')[1])
    train_df['im_id'] = train_df['Image_Label'].apply(
        lambda x: x.split('_')[0])

    encoder_weights = 'imagenet'

    model_cls = smp.Unet if decoder == 'unet' else smp.FPN
    model = model_cls(
        encoder_name=encoder,
        encoder_weights=encoder_weights,
        classes=4,
        activation=None,
    )
    preprocessing_fn = smp.encoders.get_preprocessing_fn(
        encoder, encoder_weights)

    train_dataset = CloudDataset(
        df=train_df,
        transforms=get_training_augmentation(),
        datatype='train',
        img_ids=train_ids,
        preprocessing=get_preprocessing(preprocessing_fn))
    valid_dataset = CloudDataset(
        df=train_df,
        transforms=get_validation_augmentation(),
        datatype='valid',
        img_ids=valid_ids,
        preprocessing=get_preprocessing(preprocessing_fn))

    train_loader = DataLoader(train_dataset,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=num_workers)
    valid_loader = DataLoader(valid_dataset,
                              batch_size=batch_size,
                              shuffle=False,
                              num_workers=num_workers)

    loaders = {"train": train_loader, "valid": valid_loader}

    logdir = "./logs/log_{}_{}/log_{}".format(encoder, decoder, num_split)

    # Model, criterion, optimizer: the decoder learns 10x faster than the
    # encoder.
    optimizer = torch.optim.Adam([
        {
            'params': model.decoder.parameters(),
            'lr': 1e-2
        },
        {
            'params': model.encoder.parameters(),
            'lr': 1e-3
        },
    ])
    scheduler = ReduceLROnPlateau(optimizer, factor=0.35, patience=4)
    criterion = smp.utils.losses.BCEDiceLoss(eps=1.)
    runner = SupervisedRunner()

    runner.train(model=model,
                 criterion=criterion,
                 optimizer=optimizer,
                 scheduler=scheduler,
                 loaders=loaders,
                 callbacks=[DiceCallback()],
                 logdir=logdir,
                 num_epochs=num_epochs,
                 verbose=True)

    # Exploring predictions: reload the best checkpoint and infer on the
    # validation split.
    runner.infer(
        model=model,
        loaders={"infer": valid_loader},
        callbacks=[
            CheckpointCallback(resume=f"{logdir}/checkpoints/best.pth"),
            InferCallback()
        ],
    )
    if not train_parameters["width_crop_size"] / train_parameters[
            "downsample_mask_factor"]:
        raise ValueError(
            f"Width crop size ({train_parameters['width_crop_size']}) "
            f"should be divisible by the downsample_mask_factor"
            f"({train_parameters['downsample_mask_factor']})")

    final_upsampling = None
else:
    final_upsampling = 4

model = smp.FPN(
    encoder_type,
    encoder_weights="imagenet",
    classes=num_classes,
    activation=None,
    final_upsampling=final_upsampling,
    dropout=0.5,
    decoder_merge_policy="cat",
)

pad_factor = 64
imread_library = "cv2"  # can be cv2 or jpeg4py

optimizer = RAdam(
    [
        {
            "params": model.decoder.parameters(),
            "lr": train_parameters["lr"]
        },
        # decrease lr for encoder in order not to permute
Beispiel #30
0
def _predict_with_flip_tta(runner, images, beta=0.4):
    """Return logits for ``images`` blended with their flip-TTA predictions.

    Flips are applied on dims 2 and 3 of ``images`` and undone on the outputs.
    These are vertical/horizontal flips assuming NCHW input, which the caller
    is presumed to provide. The result is the NumPy array
    ``beta * original + (1 - beta) * mean(vflip, hflip)``.
    """
    base_out = runner.predict_batch(
        {"features": images.cuda()})['logits'].cpu().detach().numpy()

    # Apply TTA transforms.
    v_flip = images.flip(dims=(2, ))
    h_flip = images.flip(dims=(3, ))
    v_flip_out = runner.predict_batch({"features": v_flip.cuda()
                                       })['logits'].cpu().detach()
    h_flip_out = runner.predict_batch({"features": h_flip.cuda()
                                       })['logits'].cpu().detach()
    # Undo the transforms so outputs line up with the original image.
    v_flip_out = v_flip_out.flip(dims=(2, )).numpy()
    h_flip_out = h_flip_out.flip(dims=(3, )).numpy()
    tta_avg_out = (v_flip_out + h_flip_out) / 2

    # Combine with the original predictions (weighting taken from fastai TTA).
    return beta * base_out + (1 - beta) * tta_avg_out


def generate_test_preds(ensemble_info, class_params=None):
    """Average sigmoid predictions of several models and write a submission.

    Each model is rebuilt, its ``<logdir>/checkpoints/best.pth`` weights are
    loaded through Catalyst, and it predicts every test image with flip TTA.
    Per-class sigmoid maps, resized to 350 (h) x 525 (w), are averaged over
    the ensemble into ``test_preds``. Image ``k`` owns rows ``4k .. 4k+3``
    (fish, flower, gravel, sugar). Those maps are binarised with
    ``post_process``, RLE-encoded and saved to ``ensembled_submission.csv``.

    Args:
        ensemble_info: non-empty list of dicts with keys ``'class_params'``,
            ``'encoder'``, ``'model_type'`` (``'unet'`` or ``'fpn'``) and
            ``'logdir'``.
        class_params: optional per-class ``(threshold, size)`` pairs forwarded
            to ``post_process``. Defaults to the ``'class_params'`` of the
            LAST entry of ``ensemble_info``, which is what the previous
            implementation used implicitly through a leaked loop variable.

    Raises:
        ValueError: if ``ensemble_info`` is empty.
        NotImplementedError: if a ``model_type`` is neither 'unet' nor 'fpn'.
    """
    if not ensemble_info:
        raise ValueError("ensemble_info must contain at least one model")
    if class_params is None:
        class_params = ensemble_info[-1]['class_params']

    test_preds = np.zeros((len(sub), 350, 525), dtype=np.float32)
    num_models = len(ensemble_info)

    for model_info in ensemble_info:
        encoder = model_info['encoder']
        model_type = model_info['model_type']
        logdir = model_info['logdir']

        preprocessing_fn = smp.encoders.get_preprocessing_fn(
            encoder, ENCODER_WEIGHTS)

        if model_type == 'unet':
            model = smp.Unet(
                encoder_name=encoder,
                encoder_weights=ENCODER_WEIGHTS,
                classes=4,
                activation=ACTIVATION,
            )
        elif model_type == 'fpn':
            model = smp.FPN(
                encoder_name=encoder,
                encoder_weights=ENCODER_WEIGHTS,
                classes=4,
                activation=ACTIVATION,
            )
        else:
            raise NotImplementedError("We only support FPN and UNet")

        runner = SupervisedRunner(model)

        # HACK: load a few examples from a dummy loader so Catalyst will
        # properly load the weights from our checkpoint.
        dummy_dataset = CloudDataset(
            df=sub,
            datatype='test',
            img_ids=test_ids[:1],
            transforms=get_validation_augmentation(),
            preprocessing=get_preprocessing(preprocessing_fn))
        dummy_loader = DataLoader(dummy_dataset,
                                  batch_size=1,
                                  shuffle=False,
                                  num_workers=0)
        runner.infer(
            model=model,
            loaders={"test": dummy_loader},
            callbacks=[
                CheckpointCallback(resume=f"{logdir}/checkpoints/best.pth"),
                InferCallback()
            ],
        )

        # Now do the real inference on the full dataset.
        test_dataset = CloudDataset(
            df=sub,
            datatype='test',
            img_ids=test_ids,
            transforms=get_validation_augmentation(),
            preprocessing=get_preprocessing(preprocessing_fn))
        test_loader = DataLoader(test_dataset,
                                 batch_size=1,
                                 shuffle=False,
                                 num_workers=0)

        # Running image counter, so the row offset stays correct even if the
        # loader's batch size is ever raised above 1.
        image_index = 0
        for test_batch in tqdm.tqdm(test_loader):
            runner_out = _predict_with_flip_tta(runner, test_batch[0],
                                                beta=0.4)

            for preds in runner_out:
                preds = preds.transpose((1, 2, 0))
                # cv2 takes dsize as (width, height).
                preds = cv2.resize(preds, (525, 350))
                preds = preds.transpose((2, 0, 1))

                idx = image_index * 4
                # Channels: 0 fish, 1 flower, 2 gravel, 3 sugar.
                for class_idx in range(4):
                    test_preds[idx + class_idx] += (
                        sigmoid(preds[class_idx]) / num_models)
                image_index += 1

    # Convert ensembled predictions to RLE predictions.
    encoded_pixels = []
    for row_index, preds in enumerate(test_preds):
        params = class_params[row_index % 4]
        predict, num_predict = post_process(preds, params[0], params[1])
        if num_predict == 0:
            encoded_pixels.append('')
        else:
            encoded_pixels.append(mask2rle(predict))

    print("Saving submission...")
    sub['EncodedPixels'] = encoded_pixels
    sub.to_csv('ensembled_submission.csv',
               columns=['Image_Label', 'EncodedPixels'],
               index=False)
    print("Saved.")