Example 1
def get_model(name, classification_head, model_weights_path=None):
    if name == 'unet34':
        return smp.Unet('resnet34', encoder_weights='imagenet')
    elif name == 'unet18':
        print('classification_head:', classification_head)
        if classification_head:
            aux_params = dict(
                pooling='max',  # one of 'avg', 'max'
                dropout=0.1,  # dropout ratio, default is None
                activation='sigmoid',  # activation function, default is None
                classes=1,  # define number of output labels
            )
            return smp.Unet('resnet18',
                            aux_params=aux_params,
                            encoder_weights=None,
                            encoder_depth=2,
                            decoder_channels=(256, 128))
        else:
            return smp.Unet('resnet18',
                            encoder_weights='imagenet',
                            encoder_depth=2,
                            decoder_channels=(256, 128))
    elif name == 'unet50':
        return smp.Unet('resnet50', encoder_weights='imagenet')
    elif name == 'unet101':
        return smp.Unet('resnet101', encoder_weights='imagenet')
    elif name == 'linknet34':
        return smp.Linknet('resnet34', encoder_weights='imagenet')
    elif name == 'linknet50':
        return smp.Linknet('resnet50', encoder_weights='imagenet')
    elif name == 'fpn34':
        return smp.FPN('resnet34', encoder_weights='imagenet')
    elif name == 'fpn50':
        return smp.FPN('resnet50', encoder_weights='imagenet')
    elif name == 'fpn101':
        return smp.FPN('resnet101', encoder_weights='imagenet')
    elif name == 'pspnet34':
        return smp.PSPNet('resnet34', encoder_weights='imagenet', classes=1)
    elif name == 'pspnet50':
        return smp.PSPNet('resnet50', encoder_weights='imagenet', classes=1)
    elif name == 'fpn50_season':
        from clearcut_research.pytorch import FPN_double_output
        return FPN_double_output('resnet50', encoder_weights='imagenet')
    elif name == 'fpn50_satellite':
        fpn_resnet50 = smp.FPN('resnet50', encoder_weights=None)
        fpn_resnet50.encoder = get_satellite_pretrained_resnet(
            model_weights_path)
        return fpn_resnet50
    elif name == 'fpn50_multiclass':
        return smp.FPN('resnet50',
                       encoder_weights='imagenet',
                       classes=3,
                       activation='softmax')
    else:
        raise ValueError("Unknown network")
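A minimal usage sketch for the factory above (not part of the original source; assumes segmentation_models_pytorch is installed and importable as smp):

import torch

model = get_model('fpn50', classification_head=False)
model.eval()
with torch.no_grad():
    x = torch.randn(1, 3, 256, 256)  # dummy RGB batch; H and W must be divisible by 32
    logits = model(x)                # -> (1, 1, 256, 256), one output channel by default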
Example 2
def init(config):
    # ---- Model Initialization  ----
    if config["model"] == "UNet":
        model = smp.Unet(activation=None)
    elif config["model"] == "PSPNet":
        model = smp.PSPNet(activation=None)
    elif config["model"] == "FPN":
        model = smp.FPN(activation=None)
    elif config["model"] == "Linknet":
        model = smp.Linknet(activation=None)
    else:
        raise Exception('Incorrect model name!')

    # ---- Loss Initialization  ----
    if config["mode"] == 'train':
        if config["loss"] == "DiceBCE":
            loss = LossBinaryDice(dice_weight=config["dice_weight"])
        elif config["loss"] == "FocalTversky":
            loss = FocalTverskyLoss()
        elif config["loss"] == "Focal":
            loss = FocalLoss()
        elif config["loss"] == "Tversky":
            loss = TverskyLoss()
        else:
            raise Exception('Incorrect loss name!')

        return model, loss
    else:
        return model
Example 3
def make_model(model_name='unet_resnet34',
               weights='imagenet',
               n_classes=2,
               input_channels=4):

    if model_name.split('_')[0] == 'unet':

        model = smp.Unet('_'.join(model_name.split('_')[1:]),
                         classes=n_classes,
                         activation=None,
                         encoder_weights=weights,
                         in_channels=input_channels)

    elif model_name.split('_')[0] == 'fpn':
        model = smp.FPN('_'.join(model_name.split('_')[1:]),
                        classes=n_classes,
                        activation=None,
                        encoder_weights=weights,
                        in_channels=input_channels)

    elif model_name.split('_')[0] == 'linknet':
        model = smp.Linknet('_'.join(model_name.split('_')[1:]),
                            classes=n_classes,
                            activation=None,
                            encoder_weights=weights,
                            in_channels=input_channels)
    else:
        raise ValueError('Model not implemented')

    return model
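The '_'.join(model_name.split('_')[1:]) idiom above exists so that encoder names containing underscores survive the split. A hypothetical call illustrating it (weights=None skips the pretrained download in this sketch):

# 'unet_se_resnext50_32x4d' -> architecture 'unet', encoder 'se_resnext50_32x4d'
model = make_model('unet_se_resnext50_32x4d',
                   weights=None,
                   n_classes=2,
                   input_channels=4)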
Example 4
    def create_model(self):
        kwargs = {
            'encoder_name': self.encoder_name,
            'encoder_weights': self.encoder_weights,
            'classes': self.num_classes
        }
        if self.model_architecture == 'Unet':
            model = smp.Unet(**kwargs)
        elif self.model_architecture == 'FPN':
            model = smp.FPN(**kwargs)
        elif self.model_architecture == 'Linknet':
            model = smp.Linknet(**kwargs)
        elif self.model_architecture == 'PSPNet':
            model = smp.PSPNet(**kwargs)
        else:
            raise ValueError(f'Unknown architecture: {self.model_architecture}')

        return model
Example 5
    def Linknet(self, img_ch, output_ch):
        return smp.Linknet(encoder_name=self.encoder,
                           encoder_depth=self.en_depth,
                           encoder_weights=self.en_weights,
                           decoder_use_batchnorm=False,
                           in_channels=img_ch,
                           classes=output_ch,
                           activation=None,
                           aux_params=None)
Example 6
def resnet50_Linknet_noclassification(**kwargs):
    # in_channels, classes and activation are module-level globals in the
    # original source; they are not parameters of this function.
    model = smp.Linknet('resnet50',
                        in_channels=in_channels,
                        classes=classes,
                        activation=activation,
                        **kwargs)
    print("Just segmentation Model args:")
    print("in_channels:%d,classes:%d,activation:%s" %
          (in_channels, classes, activation))
    print("kwargs", kwargs)
    return model
Example 7
def build_model(configuration):
    model_list = ['UNet', 'LinkNet', 'PSPNet', 'FPN', 'PAN', 'Deeplab_v3', 'Deeplab_v3+']
    if configuration.Model.model_name.lower() == 'unet':
        return smp.Unet(
            encoder_name=configuration.Model.encoder,
            encoder_weights=configuration.Model.encoder_weights,
            activation=None,
            classes=configuration.DataSet.number_of_classes,
            decoder_attention_type=None,
        )
    if configuration.Model.model_name.lower() == 'linknet':
        return smp.Linknet(
            encoder_name=configuration.Model.encoder,
            encoder_weights=configuration.Model.encoder_weights,
            activation=None,
            classes=configuration.DataSet.number_of_classes
        )
    if configuration.Model.model_name.lower() == 'pspnet':
        return smp.PSPNet(
            encoder_name=configuration.Model.encoder,
            encoder_weights=configuration.Model.encoder_weights,
            activation=None,
            classes=configuration.DataSet.number_of_classes
        )
    if configuration.Model.model_name.lower() == 'fpn':
        return smp.FPN(
            encoder_name=configuration.Model.encoder,
            encoder_weights=configuration.Model.encoder_weights,
            activation=None,
            classes=configuration.DataSet.number_of_classes
        )
    if configuration.Model.model_name.lower() == 'pan':
        return smp.PAN(
            encoder_name=configuration.Model.encoder,
            encoder_weights=configuration.Model.encoder_weights,
            activation=None,
            classes=configuration.DataSet.number_of_classes
        )
    if configuration.Model.model_name.lower() == 'deeplab_v3+':
        return smp.DeepLabV3Plus(
            encoder_name=configuration.Model.encoder,
            encoder_weights=configuration.Model.encoder_weights,
            activation=None,
            classes=configuration.DataSet.number_of_classes
        )
    if configuration.Model.model_name.lower() == 'deeplab_v3':
        return smp.DeepLabV3(
            encoder_name=configuration.Model.encoder,
            encoder_weights=configuration.Model.encoder_weights,
            activation=None,
            classes=configuration.DataSet.number_of_classes
        )
    raise KeyError(f'Model should be one of {model_list}')
Example 8
def create_segmentation_models(encoder,
                               arch,
                               num_classes=4,
                               encoder_weights=None,
                               activation=None):
    '''
    segmentation_models_pytorch https://github.com/qubvel/segmentation_models.pytorch
    has the following architectures:
    - Unet
    - Linknet
    - FPN
    - PSPNet
    encoders: many; see the GitHub page above.

    Deeplabv3+ https://github.com/jfzhang95/pytorch-deeplab-xception
    has the following encoders:
    - resnet (resnet101)
    - mobilenet
    - xception
    - drn
    '''
    if arch == "Unet":
        return smp.Unet(encoder,
                        encoder_weights=encoder_weights,
                        classes=num_classes,
                        activation=activation)
    elif arch == "Linknet":
        return smp.Linknet(encoder,
                           encoder_weights=encoder_weights,
                           classes=num_classes,
                           activation=activation)
    elif arch == "FPN":
        return smp.FPN(encoder,
                       encoder_weights=encoder_weights,
                       classes=num_classes,
                       activation=activation)
    elif arch == "PSPNet":
        return smp.PSPNet(encoder,
                          encoder_weights=encoder_weights,
                          classes=num_classes,
                          activation=activation)
    elif arch == "deeplabv3plus":
        if 'deeplabv3plus_PATH' in os.environ:
            sys.path.append(os.environ['deeplabv3plus_PATH'])
            from modeling.deeplab import DeepLab
            return DeepLab(encoder, num_classes=num_classes)
        else:
            raise ValueError('Set the deeplabv3plus path via the deeplabv3plus_PATH environment variable.')
    else:
        raise ValueError(
            'arch {} is not found, set the correct arch'.format(arch))
Example 9
def get_model(name='fpn50', model_weights_path=None):
    if name == 'unet34':
        return smp.Unet('resnet34', encoder_weights='imagenet')
    elif name == 'unet50':
        return smp.Unet('resnet50', encoder_weights='imagenet')
    elif name == 'unet101':
        return smp.Unet('resnet101', encoder_weights='imagenet')
    elif name == 'linknet34':
        return smp.Linknet('resnet34', encoder_weights='imagenet')
    elif name == 'linknet50':
        return smp.Linknet('resnet50', encoder_weights='imagenet')
    elif name == 'fpn34':
        return smp.FPN('resnet34', encoder_weights='imagenet')
    elif name == 'fpn50':
        return smp.FPN('resnet50', encoder_weights='imagenet')
    elif name == 'fpn101':
        return smp.FPN('resnet101', encoder_weights='imagenet')
    elif name == 'pspnet34':
        return smp.PSPNet('resnet34', encoder_weights='imagenet', classes=1)
    elif name == 'pspnet50':
        return smp.PSPNet('resnet50', encoder_weights='imagenet', classes=1)
    elif name == 'fpn50_season':
        from clearcut_research.pytorch import FPN_double_output
        return FPN_double_output('resnet50', encoder_weights='imagenet')
    elif name == 'fpn50_satellite':
        fpn_resnet50 = smp.FPN('resnet50', encoder_weights=None)
        fpn_resnet50.encoder = get_satellite_pretrained_resnet(
            model_weights_path)
        return fpn_resnet50
    elif name == 'fpn50_multiclass':
        return smp.FPN('resnet50',
                       encoder_weights='imagenet',
                       classes=3,
                       activation='softmax')
    else:
        raise ValueError("Unknown network")
Example 10
def main():
    args = parseArgs()
    data = DataTest(rootPth=args.rootPth)
    dataLoader = DataLoader(data,
                            batch_size=args.batchSize,
                            shuffle=False,
                            pin_memory=False,
                            num_workers=args.numWorkers
                            )
    model = smp.Linknet(classes=1, encoder_name='se_resnext101_32x4d').to(device)
    model.load_state_dict(torch.load(args.modelPth))
    if not osp.exists(args.savePth):
        os.makedirs(args.savePth)
    inference(model, dataLoader, args)
    print('--Done--')
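One caveat on the loading above: torch.load without map_location fails when the checkpoint was saved on a device that is unavailable at inference time. A hedged variant of that line:

# Load the checkpoint onto whatever device the model lives on.
model.load_state_dict(torch.load(args.modelPth, map_location=device))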
Example 11
def create_smp_model(arch, **kwargs):
    'Create segmentation_models_pytorch model'

    assert arch in ARCHITECTURES, f'Select one of {ARCHITECTURES}'

    if arch == "Unet": model = smp.Unet(**kwargs)
    elif arch == "UnetPlusPlus": model = smp.UnetPlusPlus(**kwargs)
    elif arch == "MAnet": model = smp.MAnet(**kwargs)
    elif arch == "FPN": model = smp.FPN(**kwargs)
    elif arch == "PAN": model = smp.PAN(**kwargs)
    elif arch == "PSPNet": model = smp.PSPNet(**kwargs)
    elif arch == "Linknet": model = smp.Linknet(**kwargs)
    elif arch == "DeepLabV3": model = smp.DeepLabV3(**kwargs)
    elif arch == "DeepLabV3Plus": model = smp.DeepLabV3Plus(**kwargs)
    else: raise NotImplementedError

    setattr(model, 'kwargs', kwargs)
    return model
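Since each branch above only maps a string to the smp class of the same name, the chain can be collapsed with getattr. A sketch, assuming the same smp import as above and that every name in ARCHITECTURES is an attribute of smp (true for the nine classes listed):

def create_smp_model_v2(arch, **kwargs):
    'Equivalent dispatch via getattr instead of the if/elif chain'
    assert arch in ARCHITECTURES, f'Select one of {ARCHITECTURES}'
    model = getattr(smp, arch)(**kwargs)
    setattr(model, 'kwargs', kwargs)
    return model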
Example 12
    def evaluate(self, architecture="FPN", encoder="resnet34",
                 encoder_weights="imagenet", activation="sigmoid",
                 resize_width=224, resize_height=224):
        """
        Args:

            architecture (str): One of the following: "FPN", "UNet", "Linknet".
            encoder (str): Any encoder supported by SMP.
            encoder_weights (str): Any encoder weights supported by SMP.
            activation (str): Any SMP activation.
            resize_width (int): Preprocessing image resize width.
            resize_height (int): Preprocessing image resize height.
        """
        import segmentation_models_pytorch as smp

        from plantseg.inference import Preprocessing

        if architecture == "FPN":
            model = smp.FPN(
                encoder_name=encoder,
                encoder_weights=encoder_weights,
                classes=1,
                activation=activation)
        elif architecture == "UNet":
            model = smp.Unet(
                encoder_name=encoder,
                encoder_weights=encoder_weights,
                classes=1,
                activation=activation)
        elif architecture == "Linknet":
            model = smp.Linknet(
                encoder_name=encoder,
                encoder_weights=encoder_weights,
                classes=1,
                activation=activation)
        else:
            raise RuntimeError(f"Undefined architecture {architecture}")

        preproc_fun = smp.encoders.get_preprocessing_fn(
            encoder, encoder_weights)

        return model, Preprocessing(preproc_fun, resize_width, resize_height)
Example 13
def get_model(encoder='resnet18',
              type='unet',
              encoder_weights='imagenet',
              classes=4):
    # My own simple wrapper around smp
    if type == 'unet':
        model = smp.Unet(
            encoder_name=encoder,
            encoder_weights=encoder_weights,
            classes=classes,
            activation=None,
        )
    elif type == 'fpn':
        model = smp.FPN(
            encoder_name=encoder,
            encoder_weights=encoder_weights,
            classes=classes,
            activation=None,
        )
    elif type == 'pspnet':
        model = smp.PSPNet(
            encoder_name=encoder,
            encoder_weights=encoder_weights,
            classes=classes,
            activation=None,
        )
    elif type == 'linknet':
        model = smp.Linknet(
            encoder_name=encoder,
            encoder_weights=encoder_weights,
            classes=classes,
            activation=None,
        )
    else:
        raise "weird architecture"
    print(f"Training on {type} architecture with {encoder} encoder")
    preprocessing_fn = smp.encoders.get_preprocessing_fn(
        encoder, encoder_weights)
    return model, preprocessing_fn
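A sketch of how the returned pair is typically consumed (an assumption; the calling code is not shown). preprocessing_fn normalizes an HWC numpy image with the encoder's ImageNet statistics; the array still has to be converted to an NCHW tensor afterwards:

import numpy as np
import torch

model, preprocessing_fn = get_model(encoder='resnet18', type='unet')
img = np.random.rand(256, 256, 3).astype('float32')    # stand-in HWC image
x = preprocessing_fn(img)                               # mean/std normalization
x = torch.from_numpy(x).permute(2, 0, 1).unsqueeze(0)   # HWC -> NCHW
with torch.no_grad():
    out = model(x.float())                              # (1, 4, 256, 256) logits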
Example 14
def get_model(
    model_type: str = "Unet",
    encoder: str = "Resnet18",
    encoder_weights: str = "imagenet",
    activation: str = None,
    n_classes: int = 4,
    task: str = "segmentation",
    source: str = "pretrainedmodels",
    head: str = "simple",
):
    """
    Get model for training or inference.

    Returns a loaded model that is ready to be used.

    Args:
        model_type: segmentation model architecture
        encoder: encoder of the model
        encoder_weights: pre-trained weights to use
        activation: activation function for the output layer
        n_classes: number of classes in the output layer
        task: segmentation or classification
        source: source of model for classification
        head: simply change number of outputs or use better output head

    Returns:
        The constructed model (or None for an unknown segmentation model_type).
    """
    if task == "segmentation":
        if model_type == "Unet":
            model = smp.Unet(
                # attention_type='scse',
                encoder_name=encoder,
                encoder_weights=encoder_weights,
                classes=n_classes,
                activation=activation,
            )

        elif model_type == "Linknet":
            model = smp.Linknet(
                encoder_name=encoder,
                encoder_weights=encoder_weights,
                classes=n_classes,
                activation=activation,
            )

        elif model_type == "FPN":
            model = smp.FPN(
                encoder_name=encoder,
                encoder_weights=encoder_weights,
                classes=n_classes,
                activation=activation,
            )

        elif model_type == "resnet34_fpn":
            model = resnet34_fpn(num_classes=n_classes, fpn_features=128)

        elif model_type == "effnetB4_fpn":
            model = effnetB4_fpn(num_classes=n_classes, fpn_features=128)

        else:
            model = None

    elif task == "classification":
        if source == "pretrainedmodels":
            model_fn = pretrainedmodels.__dict__[encoder]
            model = model_fn(num_classes=1000, pretrained=encoder_weights)
        elif source == "torchvision":
            model = torchvision.models.__dict__[encoder](
                pretrained=encoder_weights)

        if head == "simple":
            model.last_linear = nn.Linear(model.last_linear.in_features,
                                          n_classes)
        else:
            model = Net(net=model)

    return model
Example 15
def main():

    fold_path = args.fold_path
    fold_num = args.fold_num
    model_name = args.model_name
    train_csv = args.train_csv
    sub_csv = args.sub_csv
    encoder = args.encoder
    num_workers = args.num_workers
    batch_size = args.batch_size
    num_epochs = args.num_epochs
    learn_late = args.learn_late
    attention_type = args.attention_type

    train = pd.read_csv(train_csv)
    sub = pd.read_csv(sub_csv)

    train['label'] = train['Image_Label'].apply(lambda x: x.split('_')[-1])
    train['im_id'] = train['Image_Label'].apply(
        lambda x: x.replace('_' + x.split('_')[-1], ''))

    sub['label'] = sub['Image_Label'].apply(lambda x: x.split('_')[-1])
    sub['im_id'] = sub['Image_Label'].apply(
        lambda x: x.replace('_' + x.split('_')[-1], ''))

    train_fold = pd.read_csv(f'{fold_path}/train_file_fold_{fold_num}.csv')
    val_fold = pd.read_csv(f'{fold_path}/valid_file_fold_{fold_num}.csv')

    train_ids = np.array(train_fold.file_name)
    valid_ids = np.array(val_fold.file_name)

    encoder_weights = 'imagenet'
    attention_type = None if attention_type == 'None' else attention_type

    if model_name == 'Unet':
        model = smp.Unet(
            encoder_name=encoder,
            encoder_weights=encoder_weights,
            classes=4,
            activation='softmax',
            attention_type=attention_type,
        )
    if model_name == 'Linknet':
        model = smp.Linknet(
            encoder_name=encoder,
            encoder_weights=encoder_weights,
            classes=4,
            activation='softmax',
        )
    if model_name == 'FPN':
        model = smp.FPN(
            encoder_name=encoder,
            encoder_weights=encoder_weights,
            classes=4,
            activation='softmax',
        )
    if model_name == 'ORG':
        model = Linknet_resnet18_ASPP()

    preprocessing_fn = smp.encoders.get_preprocessing_fn(
        encoder, encoder_weights)

    train_dataset = CloudDataset(
        df=train,
        datatype='train',
        img_ids=train_ids,
        transforms=get_training_augmentation(),
        preprocessing=get_preprocessing(preprocessing_fn))

    valid_dataset = CloudDataset(
        df=train,
        datatype='valid',
        img_ids=valid_ids,
        transforms=get_validation_augmentation(),
        preprocessing=get_preprocessing(preprocessing_fn))

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        drop_last=True,
        pin_memory=True,
    )
    valid_loader = DataLoader(valid_dataset,
                              batch_size=batch_size,
                              shuffle=False,
                              num_workers=num_workers)

    loaders = {"train": train_loader, "valid": valid_loader}

    logdir = f"./log/logs_{model_name}_fold_{fold_num}_{encoder}/segmentation"

    #for batch_idx, (data, target) in enumerate(loaders['train']):
    #    print(batch_idx)

    print(logdir)

    if model_name == 'ORG':
        optimizer = NAdam([
            {
                'params': model.parameters(),
                'lr': learn_late
            },
        ])
    else:
        optimizer = NAdam([
            {
                'params': model.decoder.parameters(),
                'lr': learn_late
            },
            {
                'params': model.encoder.parameters(),
                'lr': learn_late
            },
        ])

    scheduler = ReduceLROnPlateau(optimizer, factor=0.5, patience=0)
    criterion = smp.utils.losses.BCEDiceLoss()

    runner = SupervisedRunner()

    runner.train(model=model,
                 criterion=criterion,
                 optimizer=optimizer,
                 scheduler=scheduler,
                 loaders=loaders,
                 callbacks=[
                     DiceCallback(),
                     EarlyStoppingCallback(patience=5, min_delta=1e-7)
                 ],
                 logdir=logdir,
                 num_epochs=num_epochs,
                 verbose=1)
Example 16
cuda_id = 2
DEVICE = 'cuda'
NUM_EPOCH = 20
SAVE_PRE = 1
EVAL_PRE = 1
PRINT_PRE = 1
NUM_STEPS_STOP = 100000
SAVE_DIR = project_path + r'/../checkpoints/' + str(localtime)
if not os.path.exists(SAVE_DIR):
    os.makedirs(SAVE_DIR)

# Datasets
train_loader, valid_loader = create_valley_data_loader()

# Model
net = smp.Linknet('resnet50', in_channels=1, classes=1).cuda(cuda_id)

# Loss function (note: CrossEntropyLoss with a single output channel is an
# unusual pairing; binary masks are more commonly trained with BCEWithLogitsLoss)
loss = torch.nn.CrossEntropyLoss().cuda(cuda_id)

# Validation metrics
metrics = [
    smp.utils.metrics.IoU(threshold=0.5),
]

# Optimizer
optimizer = torch.optim.Adam(params=net.parameters(), lr=0.0001)

# tensorboardX
writer = SummaryWriter(project_path + r'/../runs/' + str(localtime))  # logs are written to this folder
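The excerpt stops after the setup; a minimal training-loop sketch wired to the objects defined above (an assumption: neither the original loop nor the batch format of create_valley_data_loader is shown in this excerpt):

for epoch in range(NUM_EPOCH):
    net.train()
    for step, (images, masks) in enumerate(train_loader):
        images = images.cuda(cuda_id)
        masks = masks.cuda(cuda_id)   # CrossEntropyLoss expects long class indices
        optimizer.zero_grad()
        outputs = net(images)
        batch_loss = loss(outputs, masks)
        batch_loss.backward()
        optimizer.step()
        if step % PRINT_PRE == 0:
            writer.add_scalar('train/loss', batch_loss.item(),
                              epoch * len(train_loader) + step)
    if (epoch + 1) % SAVE_PRE == 0:
        torch.save(net.state_dict(),
                   os.path.join(SAVE_DIR, 'epoch_{}.pth'.format(epoch + 1)))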
Example 17
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--encoder', type=str, default='efficientnet-b0')
    parser.add_argument('--model', type=str, default='unet')
    parser.add_argument('--loc', type=str)
    parser.add_argument('--data_folder', type=str, default='../input/')
    parser.add_argument('--batch_size', type=int, default=2)
    # NOTE: argparse's type=bool treats any non-empty string (including
    # 'False') as True; action='store_true' flags are the usual fix.
    parser.add_argument('--optimize', type=bool, default=False)
    parser.add_argument('--tta_pre', type=bool, default=False)
    parser.add_argument('--tta_post', type=bool, default=False)
    parser.add_argument('--merge', type=str, default='mean')
    parser.add_argument('--min_size', type=int, default=10000)
    parser.add_argument('--thresh', type=float, default=0.5)
    parser.add_argument('--name', type=str)

    args = parser.parse_args()
    encoder = args.encoder
    model = args.model
    loc = args.loc
    data_folder = args.data_folder
    bs = args.batch_size
    optimize = args.optimize
    tta_pre = args.tta_pre
    tta_post = args.tta_post
    merge = args.merge
    min_size = args.min_size
    thresh = args.thresh
    name = args.name

    if model == 'unet':
        model = smp.Unet(encoder_name=encoder,
                         encoder_weights='imagenet',
                         classes=4,
                         activation=None)
    if model == 'fpn':
        model = smp.FPN(
            encoder_name=encoder,
            encoder_weights='imagenet',
            classes=4,
            activation=None,
        )
    if model == 'pspnet':
        model = smp.PSPNet(
            encoder_name=encoder,
            encoder_weights='imagenet',
            classes=4,
            activation=None,
        )
    if model == 'linknet':
        model = smp.Linknet(
            encoder_name=encoder,
            encoder_weights='imagenet',
            classes=4,
            activation=None,
        )

    preprocessing_fn = smp.encoders.get_preprocessing_fn(encoder, 'imagenet')

    test_df = get_dataset(train=False)
    test_df = prepare_dataset(test_df)
    test_ids = test_df['Image_Label'].apply(
        lambda x: x.split('_')[0]).drop_duplicates().values
    test_dataset = CloudDataset(
        df=test_df,
        datatype='test',
        img_ids=test_ids,
        transforms=valid1(),
        preprocessing=get_preprocessing(preprocessing_fn))
    test_loader = DataLoader(test_dataset, batch_size=bs, shuffle=False)

    val_df = get_dataset(train=True)
    val_df = prepare_dataset(val_df)
    _, val_ids = get_train_test(val_df)
    valid_dataset = CloudDataset(
        df=val_df,
        datatype='train',
        img_ids=val_ids,
        transforms=valid1(),
        preprocessing=get_preprocessing(preprocessing_fn))
    valid_loader = DataLoader(valid_dataset, batch_size=bs, shuffle=False)

    model.load_state_dict(torch.load(loc)['model_state_dict'])

    class_params = {
        0: (thresh, min_size),
        1: (thresh, min_size),
        2: (thresh, min_size),
        3: (thresh, min_size)
    }

    if optimize:
        print("OPTIMIZING")
        print(tta_pre)
        if tta_pre:
            opt_model = tta.SegmentationTTAWrapper(
                model,
                tta.Compose([
                    tta.HorizontalFlip(),
                    tta.VerticalFlip(),
                    tta.Rotate90(angles=[0, 180])
                ]),
                merge_mode=merge)
        else:
            opt_model = model
        tta_runner = SupervisedRunner()
        print("INFERRING ON VALID")
        tta_runner.infer(
            model=opt_model,
            loaders={'valid': valid_loader},
            callbacks=[InferCallback()],
            verbose=True,
        )

        valid_masks = []
        probabilities = np.zeros((4 * len(valid_dataset), 350, 525))
        for i, (batch, output) in enumerate(
                tqdm(
                    zip(valid_dataset,
                        tta_runner.callbacks[0].predictions["logits"]))):
            _, mask = batch
            for m in mask:
                if m.shape != (350, 525):
                    m = cv2.resize(m,
                                   dsize=(525, 350),
                                   interpolation=cv2.INTER_LINEAR)
                valid_masks.append(m)

            for j, probability in enumerate(output):
                if probability.shape != (350, 525):
                    probability = cv2.resize(probability,
                                             dsize=(525, 350),
                                             interpolation=cv2.INTER_LINEAR)
                probabilities[(i * 4) + j, :, :] = probability

        print("RUNNING GRID SEARCH")
        for class_id in range(4):
            print(class_id)
            attempts = []
            for t in range(30, 70, 5):
                t /= 100
                for ms in [7500, 10000, 12500, 15000, 17500]:
                    masks = []
                    for i in range(class_id, len(probabilities), 4):
                        probability = probabilities[i]
                        predict, num_predict = post_process(
                            sigmoid(probability), t, ms)
                        masks.append(predict)

                    d = []
                    for i, j in zip(masks, valid_masks[class_id::4]):
                        if (i.sum() == 0) & (j.sum() == 0):
                            d.append(1)
                        else:
                            d.append(dice(i, j))

                    attempts.append((t, ms, np.mean(d)))

            attempts_df = pd.DataFrame(attempts,
                                       columns=['threshold', 'size', 'dice'])

            attempts_df = attempts_df.sort_values('dice', ascending=False)
            print(attempts_df.head())
            best_threshold = attempts_df['threshold'].values[0]
            best_size = attempts_df['size'].values[0]

            class_params[class_id] = (best_threshold, best_size)

        del opt_model
        del tta_runner
        del valid_masks
        del probabilities
    gc.collect()

    if tta_post:
        model = tta.SegmentationTTAWrapper(model,
                                           tta.Compose([
                                               tta.HorizontalFlip(),
                                               tta.VerticalFlip(),
                                               tta.Rotate90(angles=[0, 180])
                                           ]),
                                           merge_mode=merge)
    else:
        model = model
    print(tta_post)

    runner = SupervisedRunner()
    runner.infer(
        model=model,
        loaders={'test': test_loader},
        callbacks=[InferCallback()],
        verbose=True,
    )

    encoded_pixels = []
    image_id = 0

    for i, image in enumerate(tqdm(runner.callbacks[0].predictions['logits'])):
        for j, prob in enumerate(image):
            if prob.shape != (350, 525):
                prob = cv2.resize(prob,
                                  dsize=(525, 350),
                                  interpolation=cv2.INTER_LINEAR)
            predict, num_predict = post_process(sigmoid(prob),
                                                class_params[image_id % 4][0],
                                                class_params[image_id % 4][1])
            if num_predict == 0:
                encoded_pixels.append('')
            else:
                r = mask2rle(predict)
                encoded_pixels.append(r)
            image_id += 1

    test_df['EncodedPixels'] = encoded_pixels
    test_df.to_csv(name, columns=['Image_Label', 'EncodedPixels'], index=False)
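The grid search and submission code above rely on two helpers defined elsewhere in the project; plausible sketches of both follow (assumptions, not the original implementations):

import numpy as np
import cv2

def dice(pred, truth, eps=1e-9):
    # Dice coefficient of two binary masks: 2*|A & B| / (|A| + |B|)
    pred = np.asarray(pred, dtype=bool)
    truth = np.asarray(truth, dtype=bool)
    return (2.0 * (pred & truth).sum() + eps) / (pred.sum() + truth.sum() + eps)

def post_process(probability, threshold, min_size):
    # Binarize the probability map, then keep only connected components
    # larger than min_size pixels; also return how many survived.
    mask = (probability > threshold).astype(np.uint8)
    num_component, component = cv2.connectedComponents(mask)
    predictions = np.zeros(probability.shape, np.float32)
    num = 0
    for c in range(1, num_component):
        p = component == c
        if p.sum() > min_size:
            predictions[p] = 1
            num += 1
    return predictions, num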
Example 18
checkpoint = torch.load(args.checkpoint)
arch_dict = {
    "unet":
    smp.Unet(
        encoder_name=checkpoint["encoder"],
        encoder_weights=checkpoint["encoder_weight"],
        classes=8,
        activation=checkpoint["activation"],
        decoder_attention_type="scse",
        decoder_use_batchnorm=True,
    ),
    "linknet":
    smp.Linknet(
        encoder_name=checkpoint["encoder"],
        encoder_weights=checkpoint["encoder_weight"],
        classes=8,
        activation=checkpoint["activation"],
    ),
    "fpn":
    smp.FPN(
        encoder_name=checkpoint["encoder"],
        encoder_weights=checkpoint["encoder_weight"],
        classes=8,
        activation=checkpoint["activation"],
    ),
    "pspnet":
    smp.PSPNet(
        encoder_name=checkpoint["encoder"],
        encoder_weights=checkpoint["encoder_weight"],
        classes=8,
        activation=checkpoint["activation"],
Example 19
        print("fold:  {}    ----------------------------------------".format(
            fold))
        best = 0
        trainloader, validloader = prepare_train_valid_dataloader(
            dir_df, [fold])

        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        path = f'/mnt/result1/{CFG.data}/{CFG.encoder}-{CFG.base_model}-{CFG.criterion}-{CFG.data}-FOLD-{fold}-model.pth'
        state_dict = torch.load(path, map_location=torch.device('cpu'))
        if CFG.base_model == 'unet':
            model = smp.Unet(CFG.encoder, encoder_weights=None, classes=1)
        elif CFG.base_model == 'linknet':
            model = smp.Linknet(CFG.encoder,
                                encoder_weights='imagenet',
                                classes=1)

        model.load_state_dict(state_dict)
        del state_dict

        scaler = GradScaler()

        for epoch in range(CFG.epoch):
            if epoch < CFG.freeze_epoch:
                for p in model.encoder.parameters():
                    p.requires_grad = False

                if args.op == 'adam':
                    optimizer1 = Adam(filter(lambda p: p.requires_grad,
                                             model.parameters()),
Example 20
def get_model(config):
    """
    """
    arch = config.MODEL.ARCHITECTURE
    backbone = config.MODEL.BACKBONE
    encoder_weights = config.MODEL.ENCODER_PRETRAINED_FROM
    in_channels = config.MODEL.IN_CHANNELS
    n_classes = len(config.INPUT.CLASSES)
    activation = config.MODEL.ACTIVATION

    # unet specific
    decoder_attention_type = 'scse' if config.MODEL.UNET_ENABLE_DECODER_SCSE else None

    if arch == 'unet':
        model = smp.Unet(
            encoder_name=backbone,
            encoder_weights=encoder_weights,
            decoder_channels=config.MODEL.UNET_DECODER_CHANNELS,
            decoder_attention_type=decoder_attention_type,
            in_channels=in_channels,
            classes=n_classes,
            activation=activation
        )
    elif arch == 'fpn':
        model = smp.FPN(
            encoder_name=backbone,
            encoder_weights=encoder_weights,
            decoder_dropout=config.MODEL.FPN_DECODER_DROPOUT,
            in_channels=in_channels,
            classes=n_classes,
            activation=activation
        )
    elif arch == 'pan':
        model = smp.PAN(
            encoder_name=backbone,
            encoder_weights=encoder_weights,
            in_channels=in_channels,
            classes=n_classes,
            activation=activation
        )
    elif arch == 'pspnet':
        model = smp.PSPNet(
            encoder_name=backbone,
            encoder_weights=encoder_weights,
            psp_dropout=config.MODEL.PSPNET_DROPOUT,
            in_channels=in_channels,
            classes=n_classes,
            activation=activation
        )
    elif arch == 'deeplabv3':
        model = smp.DeepLabV3(
            encoder_name=backbone,
            encoder_weights=encoder_weights,
            in_channels=in_channels,
            classes=n_classes,
            activation=activation
        )
    elif arch == 'linknet':
        model = smp.Linknet(
            encoder_name=backbone,
            encoder_weights=encoder_weights,
            in_channels=in_channels,
            classes=n_classes,
            activation=activation
        )
    else:
        raise ValueError(f'Unknown architecture: {arch}')

    model = torch.nn.DataParallel(model)

    if config.MODEL.WEIGHT and config.MODEL.WEIGHT != 'none':
        # load weight from file
        model.load_state_dict(
            torch.load(
                config.MODEL.WEIGHT,
                map_location=torch.device('cpu')
            )
        )

    model = model.to(config.MODEL.DEVICE)
    return model
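A note on the ordering above: because load_state_dict is called after the torch.nn.DataParallel wrap, the checkpoint keys are expected to carry the 'module.' prefix. A hedged sketch for the opposite case, where the checkpoint was saved from a bare, unwrapped model:

state = torch.load(config.MODEL.WEIGHT, map_location=torch.device('cpu'))
model.module.load_state_dict(state)  # load into the inner model, bypassing the wrapper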
Example 21
def main():
    fold_path = args.fold_path
    fold_num = args.fold_num
    model_name = args.model_name
    train_csv = args.train_csv
    sub_csv = args.sub_csv
    encoder = args.encoder
    num_workers = args.num_workers
    batch_size = args.batch_size
    log_path = args.log_path
    is_tta = args.is_tta
    test_batch_size = args.test_batch_size
    attention_type = args.attention_type
    print(log_path)

    train = pd.read_csv(train_csv)
    train['label'] = train['Image_Label'].apply(lambda x: x.split('_')[-1])
    train['im_id'] = train['Image_Label'].apply(lambda x: x.replace('_' + x.split('_')[-1], ''))

    val_fold = pd.read_csv(f'{fold_path}/valid_file_fold_{fold_num}.csv')
    valid_ids = np.array(val_fold.file_name)

    attention_type = None if attention_type == 'None' else attention_type

    encoder_weights = 'imagenet'

    if model_name == 'Unet':
        model = smp.Unet(
            encoder_name=encoder,
            encoder_weights=encoder_weights,
            classes=CLASS,
            activation='softmax',
            attention_type=attention_type,
        )
    if model_name == 'Linknet':
        model = smp.Linknet(
            encoder_name=encoder,
            encoder_weights=encoder_weights,
            classes=CLASS,
            activation='softmax',
        )
    if model_name == 'FPN':
        model = smp.FPN(
            encoder_name=encoder,
            encoder_weights=encoder_weights,
            classes=CLASS,
            activation='softmax',
        )
    if model_name == 'ORG':
        model = Linknet_resnet18_ASPP()


    preprocessing_fn = smp.encoders.get_preprocessing_fn(encoder, encoder_weights)

    valid_dataset = CloudDataset(df=train,
                                 datatype='valid',
                                 img_ids=valid_ids,
                                 transforms=get_validation_augmentation(),
                                 preprocessing=get_preprocessing(preprocessing_fn))

    valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    loaders = {"infer": valid_loader}
    runner = SupervisedRunner()

    checkpoint = torch.load(f"{log_path}/checkpoints/best.pth")
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    transforms = tta.Compose(
        [
            tta.HorizontalFlip(),
            tta.VerticalFlip(),
        ]
    )
    model = tta.SegmentationTTAWrapper(model, transforms)
    runner.infer(
        model=model,
        loaders=loaders,
        callbacks=[InferCallback()],
    )
    callbacks_num = 0

    valid_masks = []
    probabilities = np.zeros((len(valid_dataset) * CLASS, IMG_SIZE[0], IMG_SIZE[1]))

    # ========
    # val predict
    #

    for batch in tqdm(valid_dataset):  # collect the per-class ground-truth masks
        _, mask = batch
        for m in mask:
            m = resize_img(m)
            valid_masks.append(m)

    for i, output in enumerate(tqdm(runner.callbacks[callbacks_num].predictions["logits"])):
        for j, probability in enumerate(output):
            probability = resize_img(probability)  # one probability map is extracted per class; j runs over the CLASS channels (0..3)
            probabilities[i * CLASS + j, :, :] = probability

    # ========
    # search best size and threshold
    #

    class_params = {}
    for class_id in range(CLASS):
        attempts = []
        for threshold in range(20, 90, 5):
            threshold /= 100
            for min_size in [10000, 15000, 20000]:
                masks = class_masks(class_id, probabilities, threshold, min_size)
                dices = class_dices(class_id, masks, valid_masks)
                attempts.append((threshold, min_size, np.mean(dices)))

        attempts_df = pd.DataFrame(attempts, columns=['threshold', 'size', 'dice'])
        attempts_df = attempts_df.sort_values('dice', ascending=False)

        print(attempts_df.head())

        best_threshold = attempts_df['threshold'].values[0]
        best_size = attempts_df['size'].values[0]

        class_params[class_id] = (best_threshold, best_size)

    # ========
    # gc
    #
    torch.cuda.empty_cache()
    gc.collect()

    # ========
    # predict
    #
    sub = pd.read_csv(sub_csv)
    sub['label'] = sub['Image_Label'].apply(lambda x: x.split('_')[-1])
    sub['im_id'] = sub['Image_Label'].apply(lambda x: x.replace('_' + x.split('_')[-1], ''))

    test_ids = sub['Image_Label'].apply(lambda x: x.split('_')[0]).drop_duplicates().values

    test_dataset = CloudDataset(df=sub,
                                datatype='test',
                                img_ids=test_ids,
                                transforms=get_validation_augmentation(),
                                preprocessing=get_preprocessing(preprocessing_fn))

    encoded_pixels = get_test_encoded_pixels(test_dataset, runner, class_params, test_batch_size)
    sub['EncodedPixels'] = encoded_pixels

    # ========
    # val dice
    #

    val_Image_Label = []
    for i, row in val_fold.iterrows():
        val_Image_Label.append(row.file_name + '_Fish')
        val_Image_Label.append(row.file_name + '_Flower')
        val_Image_Label.append(row.file_name + '_Gravel')
        val_Image_Label.append(row.file_name + '_Sugar')

    val_encoded_pixels = get_test_encoded_pixels(valid_dataset, runner, class_params, test_batch_size)
    val = pd.DataFrame(val_encoded_pixels, columns=['EncodedPixels'])
    val['Image_Label'] = val_Image_Label

    sub.to_csv(f'./sub/sub_{model_name}_fold_{fold_num}_{encoder}.csv', columns=['Image_Label', 'EncodedPixels'], index=False)
    val.to_csv(f'./val/val_{model_name}_fold_{fold_num}_{encoder}.csv', columns=['Image_Label', 'EncodedPixels'], index=False)
Example 22
def get_segmentation_model(
    arch: str,
    encoder_name: str,
    encoder_weights: Optional[str] = "imagenet",
    pretrained_checkpoint_path: Optional[str] = None,
    checkpoint_path: Optional[Union[str, List[str]]] = None,
    convert_bn: Optional[str] = None,
    convert_bottleneck: Tuple[int, int, int] = (0, 0, 0),
    **kwargs: Any,
) -> nn.Module:
    """
    Fetch segmentation model by its name
    :param arch:
    :param encoder_name:
    :param encoder_weights:
    :param checkpoint_path:
    :param pretrained_checkpoint_path:
    :param convert_bn:
    :param convert_bottleneck:
    :param kwargs:
    :return:
    """

    arch = arch.lower()
    if (encoder_name == "en_resnet34" or checkpoint_path is not None
            or pretrained_checkpoint_path is not None):
        encoder_weights = None

    if arch == "unet":
        model = smp.Unet(encoder_name=encoder_name,
                         encoder_weights=encoder_weights,
                         **kwargs)
    elif arch == "unetplusplus" or arch == "unet++":
        model = smp.UnetPlusPlus(encoder_name=encoder_name,
                                 encoder_weights=encoder_weights,
                                 **kwargs)
    elif arch == "linknet":
        model = smp.Linknet(encoder_name=encoder_name,
                            encoder_weights=encoder_weights,
                            **kwargs)
    elif arch == "pspnet":
        model = smp.PSPNet(encoder_name=encoder_name,
                           encoder_weights=encoder_weights,
                           **kwargs)
    elif arch == "pan":
        model = smp.PAN(encoder_name=encoder_name,
                        encoder_weights=encoder_weights,
                        **kwargs)
    elif arch == "deeplabv3":
        model = smp.DeepLabV3(encoder_name=encoder_name,
                              encoder_weights=encoder_weights,
                              **kwargs)
    elif arch == "deeplabv3plus" or arch == "deeplabv3+":
        model = smp.DeepLabV3Plus(encoder_name=encoder_name,
                                  encoder_weights=encoder_weights,
                                  **kwargs)
    elif arch == "manet":
        model = smp.MAnet(encoder_name=encoder_name,
                          encoder_weights=encoder_weights,
                          **kwargs)
    else:
        raise ValueError(f"Unknown architecture: {arch}")

    if pretrained_checkpoint_path is not None:
        print(f"Loading pretrained checkpoint {pretrained_checkpoint_path}")
        state_dict = torch.load(pretrained_checkpoint_path,
                                map_location=torch.device("cpu"))
        model.encoder.load_state_dict(state_dict)
        del state_dict

    # TODO fmap_size=16 hardcoded for input 256 (matters for positional encoding)
    botnet.convert_resnet(
        model.encoder,
        replacement=convert_bottleneck,
        fmap_size=16,
        position_encoding=None,
    )

    # TODO parametrize conversion
    print(f"Convert BN to {convert_bn}")
    if convert_bn == "instance":
        print("Converting BatchNorm2d to InstanceNorm2d")
        model = batch_norm2instance(model)
    elif convert_bn == "group":
        print("Converting BatchNorm2d to GroupNorm")
        model = batch_norm2group(model, channels_per_group=1)
    elif convert_bn == "bnet":
        print("Converting BatchNorm2d to BNet2d")
        model = batch_norm2bnet(model)
    elif convert_bn == "gnet":
        print("Converting BatchNorm2d to GNet2d")
        model = batch_norm2gnet(model, channels_per_group=1)
    elif not convert_bn:
        print("Do not convert BatchNorm2d")
    else:
        raise ValueError(f"Unknown convert_bn value: {convert_bn}")

    if checkpoint_path is not None:
        if not isinstance(checkpoint_path, list):
            checkpoint_path = [checkpoint_path]
        states = []
        for cp in checkpoint_path:
            # Load checkpoint
            print(f"\nLoading checkpoint {str(cp)}")
            state_dict = torch.load(
                cp, map_location=torch.device("cpu"))["model_state_dict"]
            states.append(state_dict)
        state_dict = average_weights(states)
        model.load_state_dict(state_dict)
        del state_dict

    return model
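A plausible implementation of the average_weights helper used in the checkpoint branch (an assumption; the real helper is defined elsewhere in this project). It averages the checkpoints key by key:

from typing import Dict, List
import torch

def average_weights(states: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
    # Average floating-point tensors across checkpoints; integer buffers
    # (e.g. num_batches_tracked) are taken from the first state as-is.
    avg = {}
    for key, value in states[0].items():
        if value.is_floating_point():
            avg[key] = torch.stack([s[key] for s in states]).mean(dim=0)
        else:
            avg[key] = value
    return avg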
Example 23
    def __init__(self, architecture='Unet', encoder='resnet34', depth=5, in_channels=3, classes=2, activation='softmax'):
        super(SegmentationModels, self).__init__()
        self.architecture = architecture
        self.encoder = encoder
        self.depth = depth
        self.in_channels = in_channels
        self.classes = classes
        self.activation = activation

        # define model
        _ARCHITECTURES = ['Unet', 'Linknet', 'FPN', 'PSPNet', 'PAN', 'DeepLabV3', 'DeepLabV3Plus']
        assert self.architecture in _ARCHITECTURES, "architecture must be one of {0}, got '{1}'".format(_ARCHITECTURES, self.architecture)
        if self.architecture == 'Unet':
            self.model = smp.Unet(encoder_name=self.encoder,
                                  encoder_weights=None,
                                  encoder_depth=self.depth,
                                  in_channels=self.in_channels,
                                  classes=self.classes,
                                  activation=self.activation)
            self.pad_unit = 2 ** self.depth
        elif self.architecture == 'Linknet':
            self.model = smp.Linknet(encoder_name=self.encoder,
                                     encoder_weights=None,
                                     encoder_depth=self.depth,
                                     in_channels=self.in_channels,
                                     classes=self.classes,
                                     activation=self.activation)
            self.pad_unit = 2 ** self.depth
        elif self.architecture == 'FPN':
            self.model = smp.FPN(encoder_name=self.encoder,
                                 encoder_weights=None,
                                 encoder_depth=self.depth,
                                 in_channels=self.in_channels,
                                 classes=self.classes,
                                 activation=self.activation)
            self.pad_unit = 2 ** self.depth
        elif self.architecture == 'PSPNet':
            self.model = smp.PSPNet(encoder_name=self.encoder,
                                    encoder_weights=None,
                                    encoder_depth=self.depth,
                                    in_channels=self.in_channels,
                                    classes=self.classes,
                                    activation=self.activation)
            self.pad_unit = 2 ** self.depth
        elif self.architecture == 'PAN':
            self.model = smp.PAN(encoder_name=self.encoder,
                                 encoder_weights=None,
                                 encoder_depth=self.depth,
                                 in_channels=self.in_channels,
                                 classes=self.classes,
                                 activation=self.activation)
            self.pad_unit = 2 ** self.depth
        elif self.architecture == 'DeepLabV3':
            self.model = smp.DeepLabV3(encoder_name=self.encoder,
                                       encoder_weights=None,
                                       encoder_depth=self.depth,
                                       in_channels=self.in_channels,
                                       classes=self.classes,
                                       activation=self.activation)
            self.pad_unit = 2 ** self.depth
        elif self.architecture == 'DeepLabV3Plus':
            self.model = smp.DeepLabV3Plus(encoder_name=self.encoder,
                                           encoder_weights=None,
                                           encoder_depth=self.depth,
                                           in_channels=self.in_channels,
                                           classes=self.classes,
                                           activation=self.activation)
            self.pad_unit = 2 ** self.depth
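The pad_unit attribute records the input-size divisibility constraint implied by encoder_depth. A hypothetical helper (not in the original) showing how it would be used to pad an arbitrary input:

import torch.nn.functional as F

def pad_to_unit(x, unit):
    # Right/bottom-pad an NCHW tensor so that H and W are multiples of unit.
    h, w = x.shape[-2:]
    ph = (unit - h % unit) % unit
    pw = (unit - w % unit) % unit
    return F.pad(x, (0, pw, 0, ph))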
Example 24
test_dataset = SeverstalSteelData(img_dir=DATASET_PATH + '/train_images',
                                  split_csv=DATASET_PATH + '/steel_valid.csv',
                                  rle_csv=DATASET_PATH + '/train.csv',
                                  device=device)
test_dataloader = DataLoader(test_dataset,
                             batch_size=batch_size,
                             shuffle=False,
                             num_workers=0)

# print(train_dataset.__getitem__(2))

##===== Model Configuration ====================================================

model = smp.Linknet('se_resnext101_32x4d',
                    classes=4,
                    activation=None,
                    encoder_weights=None)
model = model.to(device)

TEST_MODEL(model, (3, 1600, 256))
##====== Optimizer Zone ========================================================

optimizer = torch.optim.Adam(model.parameters(),
                             lr=learning_rate,
                             weight_decay=1e-5)

criterionDBCE = metrics.DiceBCELoss()
criterionFTversky = metrics.FocalTverskyLoss()
criterionFocal = metrics.FocalLoss()

Example 25
def create_model(model_name,
                 encoder_name,
                 pretrained=False,
                 num_classes=6,
                 in_chans=3,
                 checkpoint_path='',
                 **kwargs):
    """Create a model
    Args:
        model_name (str): name of model to instantiate
        encoder_name (str): name of encoder to instantiate
        pretrained (bool): load pretrained ImageNet-1k weights if true
        num_classes (int): number of classes for final layer (default 6)
        in_chans (int): number of input channels / colors (default: 3)
        checkpoint_path (str): path of checkpoint to load after model is initialized
    Keyword Args:
        **: other kwargs are model specific
    """
    # I should probably rewrite it
    weights = None
    if pretrained:
        weights = 'imagenet'
        _logger.info('Using pre-trained imagenet weights')

    if model_name == 'unetplusplus':
        model = smp.UnetPlusPlus(encoder_name=encoder_name,
                                 encoder_weights=weights,
                                 classes=num_classes,
                                 in_channels=in_chans,
                                 **kwargs)
    elif model_name == 'unet':
        model = smp.Unet(encoder_name=encoder_name,
                         encoder_weights=weights,
                         classes=num_classes,
                         in_channels=in_chans,
                         **kwargs)
    elif model_name == 'fpn':
        model = smp.FPN(encoder_name=encoder_name,
                        encoder_weights=weights,
                        classes=num_classes,
                        in_channels=in_chans,
                        **kwargs)
    elif model_name == 'linknet':
        model = smp.Linknet(encoder_name=encoder_name,
                            encoder_weights=weights,
                            classes=num_classes,
                            in_channels=in_chans,
                            **kwargs)
    elif model_name == 'pspnet':
        model = smp.PSPNet(encoder_name=encoder_name,
                           encoder_weights=weights,
                           classes=num_classes,
                           in_channels=in_chans,
                           **kwargs)
    else:
        raise NotImplementedError()

    if checkpoint_path:
        load_checkpoint(model, checkpoint_path)

    return model
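Because the extra keyword arguments fall through to smp, architecture-specific options can be passed without changing the factory. A hypothetical call (decoder_attention_type is a real smp.Unet argument):

model = create_model('unet', 'resnet34',
                     pretrained=True,
                     num_classes=6,
                     decoder_attention_type='scse')  # forwarded via **kwargs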
Example 26
def main():
    with timer('load data'):
        df = pd.read_csv(FOLD_PATH)

    with timer('preprocessing'):
        train_df, val_df = df[df.fold_id != FOLD_ID], df[df.fold_id == FOLD_ID]

        train_augmentation = Compose([
            Flip(p=0.5),
            OneOf([
                #ElasticTransform(p=0.5, alpha=120, sigma=120 * 0.05, alpha_affine=120 * 0.03),
                GridDistortion(p=0.5),
                OpticalDistortion(p=0.5, distort_limit=2, shift_limit=0.5)
            ], p=0.5),
            #OneOf([
            #    ShiftScaleRotate(p=0.5),
            ##    RandomRotate90(p=0.5),
            #    Rotate(p=0.5)
            #], p=0.5),
            OneOf([
                Blur(blur_limit=8, p=0.5),
                MotionBlur(blur_limit=8, p=0.5),
                MedianBlur(blur_limit=8, p=0.5),
                GaussianBlur(blur_limit=8, p=0.5)
            ], p=0.5),
            OneOf([
                #CLAHE(clip_limit=4, tile_grid_size=(4, 4), p=0.5),
                RandomGamma(gamma_limit=(100,140), p=0.5),
                RandomBrightnessContrast(p=0.5),
                RandomBrightness(p=0.5),
                RandomContrast(p=0.5)
            ], p=0.5),
            OneOf([
                GaussNoise(p=0.5),
                Cutout(num_holes=10, max_h_size=10, max_w_size=20, p=0.5)
            ], p=0.5)
        ])
        # NOTE: this second assignment replaces the heavier augmentation
        # pipeline defined above with flips only.
        train_augmentation = Compose([
            Flip(p=0.5)
        ])
        val_augmentation = None

        train_dataset = SeverDataset(train_df, IMG_DIR, IMG_SIZE, N_CLASSES, id_colname=ID_COLUMNS,
                                    transforms=train_augmentation)
        val_dataset = SeverDataset(val_df, IMG_DIR, IMG_SIZE, N_CLASSES, id_colname=ID_COLUMNS,
                                  transforms=val_augmentation)
        train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
        val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

        del train_df, val_df, df, train_dataset, val_dataset
        gc.collect()

    with timer('create model'):
        # encoder_se_module, decoder_semodule and h_columns are not arguments of
        # stock segmentation_models_pytorch; this Linknet comes from a customized fork.
        model = smp.Linknet('se_resnext101_32x4d', encoder_weights='imagenet', classes=N_CLASSES, encoder_se_module=True,
                            decoder_semodule=True, h_columns=False)
        model.load_state_dict(torch.load(model_path))
        model.to(device)

        #criterion = torch.nn.BCEWithLogitsLoss()
        criterion = FocalLovaszLoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
        scheduler = CosineAnnealingLR(optimizer, T_max=CLR_CYCLE, eta_min=3e-5)
        #scheduler = GradualWarmupScheduler(optimizer, multiplier=1.1, total_epoch=CLR_CYCLE*2, after_scheduler=scheduler_cosine)

        model, optimizer = amp.initialize(model, optimizer, opt_level="O1", verbosity=0)

    with timer('train'):
        train_losses = []
        valid_losses = []

        best_model_loss = 999
        best_model_ep = 0
        checkpoint = 0

        for epoch in range(1, EPOCHS + 1):
            # epoch starts at 1, so by the end of a cosine cycle y_val and
            # best_pred are already populated; report per-class thresholds,
            # then reset the best loss for the next snapshot
            if epoch % (CLR_CYCLE * 2) == 0:
                y_val = y_val.reshape(-1, N_CLASSES, IMG_SIZE[0], IMG_SIZE[1])
                best_pred = best_pred.reshape(-1, N_CLASSES, IMG_SIZE[0], IMG_SIZE[1])
                for i in range(N_CLASSES):
                    th, score, _, _ = search_threshold(y_val[:, i, :, :], best_pred[:, i, :, :])
                    LOGGER.info('Best loss: {} Best Dice: {} on epoch {} th {} class {}'.format(
                        round(best_model_loss, 5), round(score, 5), best_model_ep, th, i))
                checkpoint += 1
                best_model_loss = 999

            LOGGER.info("Starting {} epoch...".format(epoch))
            tr_loss = train_one_epoch(model, train_loader, criterion, optimizer, device)
            train_losses.append(tr_loss)
            LOGGER.info('Mean train loss: {}'.format(round(tr_loss, 5)))

            valid_loss, val_pred, y_val = validate(model, val_loader, criterion, device)
            valid_losses.append(valid_loss)
            LOGGER.info('Mean valid loss: {}'.format(round(valid_loss, 5)))

            scheduler.step()

            if valid_loss < best_model_loss:
                torch.save(model.state_dict(), '{}_fold{}_ckpt{}.pth'.format(EXP_ID, FOLD_ID, checkpoint))
                best_model_loss = valid_loss
                best_model_ep = epoch
                best_pred = val_pred

            del val_pred
            gc.collect()

    with timer('eval'):
        y_val = y_val.reshape(-1, N_CLASSES, IMG_SIZE[0], IMG_SIZE[1])
        best_pred = best_pred.reshape(-1, N_CLASSES, IMG_SIZE[0], IMG_SIZE[1])
        for i in range(N_CLASSES):
            th, score, _, _ = search_threshold(y_val[:, i, :, :], best_pred[:, i, :, :])
            LOGGER.info('Best loss: {} Best Dice: {} on epoch {} th {} class {}'.format(
                round(best_model_loss, 5), round(score, 5), best_model_ep, th, i))

    xs = list(range(1, len(train_losses) + 1))
    plt.plot(xs, train_losses, label='Train loss')
    plt.plot(xs, valid_losses, label='Val loss')
    plt.legend()
    plt.xticks(xs)
    plt.xlabel('Epochs')
    plt.savefig("loss.png")
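
`search_threshold` above is also project-specific. A minimal sketch of the usual pattern, assuming binary masks and the four-value return shape used at the call sites (threshold, score, plus two placeholders):

import numpy as np

def search_threshold(y_true, y_pred, candidates=np.arange(0.3, 0.71, 0.05)):
    # Sweep candidate thresholds on (N, H, W) arrays and keep the one
    # with the best mean Dice; a sketch, not this project's actual helper.
    best_th, best_score = 0.5, -1.0
    for th in candidates:
        pred = (y_pred > th).astype(np.uint8)
        inter = (pred * y_true).sum(axis=(1, 2))
        denom = pred.sum(axis=(1, 2)) + y_true.sum(axis=(1, 2))
        dice = np.where(denom > 0, 2.0 * inter / np.maximum(denom, 1), 1.0)
        if dice.mean() > best_score:
            best_th, best_score = th, float(dice.mean())
    return best_th, best_score, None, None
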
Example #27
class SegmentationModels(torch.nn.Module):  # header inferred from the super() call below; base class assumed
    def __init__(self,
                 architecture="Unet",
                 encoder="resnet34",
                 depth=5,
                 in_channels=3,
                 classes=2,
                 activation="softmax"):
        super(SegmentationModels, self).__init__()
        self.architecture = architecture
        self.encoder = encoder
        self.depth = depth
        self.in_channels = in_channels
        self.classes = classes
        self.activation = activation

        # define model

        _ARCHITECTURES = [
            "Unet", "UnetPlusPlus", "Linknet", "MAnet", "FPN", "PSPNet", "PAN",
            "DeepLabV3", "DeepLabV3Plus"
        ]
        assert self.architecture in _ARCHITECTURES, \
            "architecture must be one of {0}, got '{1}'".format(
                _ARCHITECTURES, self.architecture)

        if self.architecture == "Unet":
            self.model = smp.Unet(
                encoder_name=self.encoder,
                encoder_weights=None,
                encoder_depth=self.depth,
                in_channels=self.in_channels,
                classes=self.classes,
                activation=self.activation,
            )
            self.pad_unit = 2**self.depth
        elif self.architecture == "UnetPlusPlus":
            self.model = smp.UnetPlusPlus(
                encoder_name=self.encoder,
                encoder_weights=None,
                encoder_depth=self.depth,
                in_channels=self.in_channels,
                classes=self.classes,
                activation=self.activation,
            )
            self.pad_unit = 2**self.depth
        elif self.architecture == "MAnet":
            self.model = smp.MAnet(
                encoder_name=self.encoder,
                encoder_weights=None,
                encoder_depth=self.depth,
                in_channels=self.in_channels,
                classes=self.classes,
                activation=self.activation,
            )
            self.pad_unit = 2**self.depth
        elif self.architecture == "Linknet":
            self.model = smp.Linknet(
                encoder_name=self.encoder,
                encoder_weights=None,
                encoder_depth=self.depth,
                in_channels=self.in_channels,
                classes=self.classes,
                activation=self.activation,
            )
            self.pad_unit = 2**self.depth
        elif self.architecture == "FPN":
            self.model = smp.FPN(
                encoder_name=self.encoder,
                encoder_weights=None,
                encoder_depth=self.depth,
                in_channels=self.in_channels,
                classes=self.classes,
                activation=self.activation,
            )
            self.pad_unit = 2**self.depth
        elif self.architecture == "PSPNet":
            self.model = smp.PSPNet(
                encoder_name=self.encoder,
                encoder_weights=None,
                encoder_depth=self.depth,
                in_channels=self.in_channels,
                classes=self.classes,
                activation=self.activation,
            )
            self.pad_unit = 2**self.depth
        elif self.architecture == "PAN":
            self.model = smp.PAN(
                encoder_name=self.encoder,
                encoder_weights=None,
                encoder_depth=self.depth,
                in_channels=self.in_channels,
                classes=self.classes,
                activation=self.activation,
            )
            self.pad_unit = 2**self.depth
        elif self.architecture == "DeepLabV3":
            self.model = smp.DeepLabV3(
                encoder_name=self.encoder,
                encoder_weights=None,
                encoder_depth=self.depth,
                in_channels=self.in_channels,
                classes=self.classes,
                activation=self.activation,
            )
            self.pad_unit = 2**self.depth
        elif self.architecture == "DeepLabV3Plus":
            self.model = smp.DeepLabV3Plus(
                encoder_name=self.encoder,
                encoder_weights=None,
                encoder_depth=self.depth,
                in_channels=self.in_channels,
                classes=self.classes,
                activation=self.activation,
            )
            self.pad_unit = 2**self.depth
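
The `pad_unit` attribute records the divisibility the encoder needs (each of the `depth` stages halves the spatial size). A short sketch of how it is typically consumed, using a hypothetical `pad_to_unit` helper rather than anything from this class:

import torch.nn.functional as F

def pad_to_unit(x, unit):
    # Zero-pad H and W up to the next multiple of `unit` so every
    # downsampling stage divides evenly; x is an (N, C, H, W) tensor.
    h, w = x.shape[-2:]
    ph = (unit - h % unit) % unit
    pw = (unit - w % unit) % unit
    return F.pad(x, (0, pw, 0, ph))  # (left, right, top, bottom)
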
Example #28
arch_dict = {
    "unet":
    smp.Unet(
        encoder_name=args.encoder,
        encoder_weights=args.weight,
        classes=8,
        activation=args.activation,
        decoder_attention_type="scse",
        decoder_use_batchnorm=True,
        aux_params=aux_params_dict,
    ),
    "linknet":
    smp.Linknet(
        encoder_name=args.encoder,
        encoder_weights=args.weight,
        classes=8,
        activation=args.activation,
    ),
    "fpn":
    smp.FPN(
        encoder_name=args.encoder,
        encoder_weights=args.weight,
        classes=8,
        activation=args.activation,
    ),
    "pspnet":
    smp.PSPNet(
        encoder_name=args.encoder,
        encoder_weights=args.weight,
        classes=8,
        activation=args.activation,
    ),
}
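
One caveat: `arch_dict` builds all four networks up front, so every set of weights is allocated even though only one entry is used. Wrapping the constructors in lambdas defers that cost; a sketch, reusing the surrounding script's `args` and `aux_params_dict` (and a hypothetical `args.arch` selector):

# Hypothetical lazy variant: only the selected architecture is built.
arch_factories = {
    "unet": lambda: smp.Unet(encoder_name=args.encoder,
                             encoder_weights=args.weight,
                             classes=8,
                             activation=args.activation,
                             decoder_attention_type="scse",
                             decoder_use_batchnorm=True,
                             aux_params=aux_params_dict),
    "fpn": lambda: smp.FPN(encoder_name=args.encoder,
                           encoder_weights=args.weight,
                           classes=8,
                           activation=args.activation),
}
model = arch_factories[args.arch]()  # args.arch is an assumed selector
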
Example #29
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--encoder', type=str, default='efficientnet-b0')
    parser.add_argument('--model', type=str, default='unet')
    parser.add_argument('--pretrained', type=str, default='imagenet')
    parser.add_argument('--logdir', type=str, default='../logs/')
    parser.add_argument('--exp_name', type=str, required=True)
    parser.add_argument('--data_folder', type=str, default='../input/')
    parser.add_argument('--height', type=int, default=320)
    parser.add_argument('--width', type=int, default=640)
    parser.add_argument('--batch_size', type=int, default=2)
    parser.add_argument('--accumulate', type=int, default=8)
    parser.add_argument('--epochs', type=int, default=20)
    parser.add_argument('--enc_lr', type=float, default=1e-2)
    parser.add_argument('--dec_lr', type=float, default=1e-3)
    parser.add_argument('--optim', type=str, default="radam")
    parser.add_argument('--loss', type=str, default="bcedice")
    parser.add_argument('--schedule', type=str, default="rlop")
    # argparse's type=bool would turn any non-empty string into True,
    # so parse the flag value explicitly
    parser.add_argument('--early_stopping',
                        type=lambda s: str(s).lower() in ('1', 'true', 'yes'),
                        default=True)

    args = parser.parse_args()

    encoder = args.encoder
    model = args.model
    pretrained = args.pretrained
    logdir = args.logdir
    name = args.exp_name
    data_folder = args.data_folder
    height = args.height
    width = args.width
    bs = args.batch_size
    accumulate = args.accumulate
    epochs = args.epochs
    enc_lr = args.enc_lr
    dec_lr = args.dec_lr
    optim = args.optim
    loss = args.loss
    schedule = args.schedule
    early_stopping = args.early_stopping

    if model == 'unet':
        model = smp.Unet(encoder_name=encoder,
                         encoder_weights=pretrained,
                         classes=4,
                         activation=None)
    elif model == 'fpn':
        model = smp.FPN(
            encoder_name=encoder,
            encoder_weights=pretrained,
            classes=4,
            activation=None,
        )
    elif model == 'pspnet':
        model = smp.PSPNet(
            encoder_name=encoder,
            encoder_weights=pretrained,
            classes=4,
            activation=None,
        )
    elif model == 'linknet':
        model = smp.Linknet(
            encoder_name=encoder,
            encoder_weights=pretrained,
            classes=4,
            activation=None,
        )
    elif model == 'aspp':
        # elif chain is required here: once model is rebound to a network,
        # comparing it against the remaining name strings is meaningless
        print('aspp can only be used with resnet34')
        model = aspp(num_class=4)

    preprocessing_fn = smp.encoders.get_preprocessing_fn(encoder, pretrained)
    log = os.path.join(logdir, name)

    ds = get_dataset(path=data_folder)
    prepared_ds = prepare_dataset(ds)

    train_set, valid_set = get_train_test(ds)

    train_ds = CloudDataset(df=prepared_ds,
                            datatype='train',
                            img_ids=train_set,
                            transforms=training1(h=height, w=width),
                            preprocessing=get_preprocessing(preprocessing_fn),
                            folder=data_folder)
    valid_ds = CloudDataset(df=prepared_ds,
                            datatype='train',
                            img_ids=valid_set,
                            transforms=valid1(h=height, w=width),
                            preprocessing=get_preprocessing(preprocessing_fn),
                            folder=data_folder)

    train_loader = DataLoader(train_ds,
                              batch_size=bs,
                              shuffle=True,
                              num_workers=multiprocessing.cpu_count())
    valid_loader = DataLoader(valid_ds,
                              batch_size=bs,
                              shuffle=False,
                              num_workers=multiprocessing.cpu_count())

    loaders = {
        'train': train_loader,
        'valid': valid_loader,
    }

    num_epochs = epochs

    if args.model != "aspp":
        if optim == "radam":
            optimizer = RAdam([
                {
                    'params': model.encoder.parameters(),
                    'lr': enc_lr
                },
                {
                    'params': model.decoder.parameters(),
                    'lr': dec_lr
                },
            ])
        if optim == "adam":
            optimizer = Adam([
                {
                    'params': model.encoder.parameters(),
                    'lr': enc_lr
                },
                {
                    'params': model.decoder.parameters(),
                    'lr': dec_lr
                },
            ])
        if optim == "adamw":
            optimizer = AdamW([
                {
                    'params': model.encoder.parameters(),
                    'lr': enc_lr
                },
                {
                    'params': model.decoder.parameters(),
                    'lr': dec_lr
                },
            ])
        if optim == "sgd":
            optimizer = SGD([
                {
                    'params': model.encoder.parameters(),
                    'lr': enc_lr
                },
                {
                    'params': model.decoder.parameters(),
                    'lr': dec_lr
                },
            ])
    elif args.model == 'aspp':
        if optim == "radam":
            optimizer = RAdam([
                {
                    'params': model.parameters(),
                    'lr': enc_lr
                },
            ])
        if optim == "adam":
            optimizer = Adam([
                {
                    'params': model.parameters(),
                    'lr': enc_lr
                },
            ])
        if optim == "adamw":
            optimizer = AdamW([
                {
                    'params': model.parameters(),
                    'lr': enc_lr
                },
            ])
        if optim == "sgd":
            optimizer = SGD([
                {
                    'params': model.parameters(),
                    'lr': enc_lr
                },
            ])

    # default schedule; overridden below when a known schedule name matches
    scheduler = ReduceLROnPlateau(optimizer, factor=0.1, patience=5)
    if schedule == "rlop":
        scheduler = ReduceLROnPlateau(optimizer, factor=0.15, patience=3)
    if schedule == "noam":
        scheduler = NoamLR(optimizer, 10)

    if loss == "bcedice":
        criterion = smp.utils.losses.BCEDiceLoss(eps=1.)
    if loss == "dice":
        criterion = smp.utils.losses.DiceLoss(eps=1.)
    if loss == "bcejaccard":
        criterion = smp.utils.losses.BCEJaccardLoss(eps=1.)
    if loss == "jaccard":
        criterion == smp.utils.losses.JaccardLoss(eps=1.)
    if loss == 'bce':
        criterion = NewBCELoss()

    callbacks = [NewDiceCallback(), CriterionCallback()]

    callbacks.append(OptimizerCallback(accumulation_steps=accumulate))
    if early_stopping:
        callbacks.append(EarlyStoppingCallback(patience=5, min_delta=0.001))

    runner = SupervisedRunner()
    runner.train(
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        loaders=loaders,
        callbacks=callbacks,
        logdir=log,
        num_epochs=num_epochs,
        verbose=True,
    )
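
`NoamLR` is not part of PyTorch or Catalyst; assuming it implements the Noam warmup schedule with the `(optimizer, warmup_steps)` signature used above, a sketch might look like:

from torch.optim.lr_scheduler import _LRScheduler

class NoamLR(_LRScheduler):
    # Noam schedule: lr rises linearly for warmup_steps, then decays as
    # step**-0.5; normalized so lr equals the base lr at the warmup point.
    def __init__(self, optimizer, warmup_steps, last_epoch=-1):
        self.warmup_steps = warmup_steps
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        step = max(1, self.last_epoch + 1)
        scale = self.warmup_steps ** 0.5 * min(
            step ** -0.5, step * self.warmup_steps ** -1.5)
        return [base_lr * scale for base_lr in self.base_lrs]
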
Example #30
def main(args, logger):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    writer = SummaryWriter(logdir=args.subTensorboardDir)
    trainSet = Data(mode='train')
    trainLoader = DataLoader(trainSet,
                             batch_size=args.batchSizeTrain,
                             shuffle=True,
                             pin_memory=False,
                             drop_last=False,
                             num_workers=args.numWorkers)
    testSet = Data(mode='test')
    testLoader = DataLoader(testSet,
                            batch_size=args.batchSizeTest,
                            shuffle=False,
                            pin_memory=False,
                            num_workers=args.numWorkers)
    # net = smp.Unet(classes=2).to(device)
    net = smp.Linknet(classes=1,
                      activation='sigmoid',
                      encoder_name='se_resnext101_32x4d').to(device)
    # criterion = nn.CrossEntropyLoss().to(device)
    criterion = smploss.DiceLoss(eps=sys.float_info.min).to(device)
    # criterion = DiceLoss().to(device)
    optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9)
    scheduler = optim.lr_scheduler.StepLR(optimizer,
                                          step_size=args.scheduleStep,
                                          gamma=0.1)
    runningLoss = []
    st = stGlobal = time.time()
    totalIter = len(trainLoader) * args.epoch
    nIter = 0  # avoid shadowing the builtin iter()
    for epoch in range(args.epoch):
        if epoch != 0 and epoch % args.evalFrequency == 0:
            pass  # evaluation hook intentionally left empty in this example
        if epoch != 0 and epoch % args.saveFrequency == 0:
            modelName = osp.join(args.subModelDir, 'out_{}.pth'.format(epoch))
            state_dict = net.module.state_dict() if hasattr(
                net, 'module') else net.state_dict()
            torch.save(state_dict, modelName)

        for img, mask in trainLoader:
            nIter += 1
            img = img.to(device)
            mask = mask.to(device, dtype=torch.int64).unsqueeze(1)
            optimizer.zero_grad()
            outputs = net(img)
            # print(outputs.shape, mask.shape)
            # break
            loss = criterion(outputs, mask)
            loss.backward()
            optimizer.step()
            runningLoss.append(loss.item())
            if nIter % args.msgFrequency == 0:
                # writer.add_images('img', img, iter)
                # writer.add_images('mask', mask.unsqueeze(1), iter)
                ed = time.time()
                spend = ed - st
                spendGlobal = ed - stGlobal
                st = ed
                eta = int((totalIter - nIter) * (spendGlobal / nIter))
                spendGlobal = str(datetime.timedelta(seconds=spendGlobal))
                eta = str(datetime.timedelta(seconds=eta))
                avgLoss = np.mean(runningLoss)
                runningLoss = []
                lr = optimizer.param_groups[0]['lr']
                msg = '. '.join([
                    'epoch:{epoch}', 'iter/total_iter:{iter}/{totalIter}',
                    'lr:{lr:.5f}', 'loss:{loss:.4f}',
                    'spend/global_spend:{spend:.4f}/{global_spend}',
                    'eta:{eta}'
                ]).format(epoch=epoch,
                          loss=avgLoss,
                          iter=nIter,
                          totalIter=totalIter,
                          spend=spend,
                          global_spend=spendGlobal,
                          lr=lr,
                          eta=eta)

                logger.info(msg)
                writer.add_scalar('loss', avgLoss, nIter)
                writer.add_scalar('lr', lr, nIter)

        scheduler.step()

    outName = osp.join(args.subModelDir, 'final.pth')
    torch.save(net.cpu().state_dict(), outName)
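
The `evalFrequency` branch in this example is left as `pass`. A minimal sketch of an evaluation pass that would fit there, assuming binary masks and the sigmoid-activated Linknet above:

import torch

@torch.no_grad()
def evaluate(net, loader, device, th=0.5):
    # Mean Dice over a loader; the model already applies a sigmoid,
    # so thresholding its output directly is enough.
    net.eval()
    total, n = 0.0, 0
    for img, mask in loader:
        img = img.to(device)
        mask = mask.to(device).float().unsqueeze(1)
        pred = (net(img) > th).float()
        inter = (pred * mask).sum(dim=(1, 2, 3))
        denom = pred.sum(dim=(1, 2, 3)) + mask.sum(dim=(1, 2, 3))
        total += ((2 * inter + 1e-7) / (denom + 1e-7)).sum().item()
        n += img.size(0)
    net.train()
    return total / max(n, 1)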