Exemplo n.º 1
0
def load_model() -> nn.Module:
    """Build a single-class U-Net (timm-efficientnet-b3 encoder) and load
    trained weights from MODEL_PATH.

    Returns:
        nn.Module: the U-Net with the checkpoint's state dict applied.
    """
    model = Unet(encoder_name="timm-efficientnet-b3", classes=1, encoder_weights=None)
    # NOTE(review): a stale TODO here said "don't load for now", yet the
    # checkpoint IS loaded below -- confirm which behavior is intended.
    state_dict = torch.load(MODEL_PATH)["state_dict"]
    # Checkpoint keys carry a "model." prefix (presumably from a training
    # wrapper); strip it so keys match this bare Unet's parameter names.
    state_dict = rename_layers(state_dict, {"model.": ""})
    model.load_state_dict(state_dict)
    return model
Exemplo n.º 2
0
def add_dropout(model: smp.Unet,
                decoder_channels: List[int] = (256, 128, 64, 32, 16),
                classes: int = 1,
                activation=None):
    """Replace *model*'s segmentation head with a fresh one and re-run
    smp's weight initialization.

    NOTE(review): despite the name, no dropout layer is added here -- the
    function only swaps in a new SegmentationHead; confirm intent.

    Args:
        model: the smp.Unet to modify in place.
        decoder_channels: decoder channel widths; only the last entry is
            used, as the head's input channel count.
        classes: number of output channels for the head.
        activation: optional activation passed to SegmentationHead.
    """
    seg_head = SegmentationHead(
        in_channels=decoder_channels[-1],
        out_channels=classes,
        activation=activation,
        kernel_size=3,
    )
    # Overwrites any existing 'segmentation_head' module, then re-initializes.
    model.add_module('segmentation_head', seg_head)
    model.initialize()
Exemplo n.º 3
0
 def __init__(self, cfg: Mapping = None):
     """Build a U-Net from an optional nested config mapping.

     Every setting is read via get_value(cfg, [...], fallback), so the
     defaults appear to be: mobilenet_v2 encoder, 1 input channel,
     imagenet weights, identity activation -- confirm get_value semantics.
     """
     # NOTE(review): default of None relies on get_value tolerating a None
     # mapping; the annotation would more precisely be Optional[Mapping].
     self.model = Unet(
         encoder_name=get_value(cfg, ["model", "encoder_name"],
                                "mobilenet_v2"),
         in_channels=get_value(cfg, ["model", "in_channels"], 1),
         encoder_weights=get_value(cfg, ["model", "encoder_weights"],
                                   "imagenet"),
         activation=get_value(cfg, ["model", "activation"], "identity"),
     )
Exemplo n.º 4
0
    def __init__(self, base_name, heads, pretrained, down_ratio, final_kernel,
                 last_level, head_conv, out_channel=0, no_down=False):
        """DLA-style segmentation head network.

        Builds an smp efficientnet-b5 encoder, DLA-style upsampling
        (DLAUp + IDAUp), and one output conv head per entry in *heads*.

        Args:
            base_name: unused here (the original DLA base is commented out;
                the encoder is hard-coded to efficientnet-b5).
            heads: mapping of head name -> number of output channels.
            pretrained: unused (see base_name note).
            down_ratio: output stride; must be 2, 4, 8 or 16.
            final_kernel: kernel size of each head's final conv.
            last_level: last feature level fed to IDAUp.
            head_conv: hidden channels for two-layer heads; <= 0 gives a
                single conv per head.
            out_channel: output channels of IDAUp; 0 means "use the channel
                count at first_level".
            no_down: if True, adds a conv+BN+ReLU+2x-upsample block (self.up).
        """
        super(DLASeg, self).__init__()
        assert down_ratio in [2, 4, 8, 16]
        # first_level indexes the pyramid level matching the output stride.
        self.first_level = int(np.log2(down_ratio))
        self.last_level = last_level
        # self.base = globals()[base_name](pretrained=pretrained)
        self.encoder = Unet("efficientnet-b5").encoder
        # Drop any classification fc so the encoder is purely a feature extractor.
        self.encoder.fc = torch.nn.Sequential(*[])
        # channels = self.base.channels
        # Hard-coded per-level channel counts of the efficientnet-b5 encoder
        # features -- NOTE(review): must stay in sync with the encoder above.
        channels = [3,48,40,64,176,512]
        scales = [2 ** i for i in range(len(channels[self.first_level:]))]
        self.dla_up = DLAUp(self.first_level, channels[self.first_level:], scales)

        if out_channel == 0:
            out_channel = channels[self.first_level]

        self.ida_up = IDAUp(out_channel, channels[self.first_level:self.last_level], 
                            [2 ** i for i in range(self.last_level - self.first_level)])
        
        self.heads = heads
        if no_down:
            self.up = nn.Sequential(*[
                # nn.UpsamplingBilinear2d(scale_factor=2),
                nn.Conv2d(64, 32, (3, 3), padding=1),
                nn.BatchNorm2d(32, momentum=BN_MOMENTUM),
                nn.ReLU(inplace=True),
                nn.UpsamplingBilinear2d(scale_factor=2),
                # nn.Conv2d(32, 32, (3, 3), padding=1),
                # nn.BatchNorm2d(32, momentum=BN_MOMENTUM),
                # nn.ReLU(inplace=True),
            ])
        else:
            self.up = None
        # One conv head per requested output (e.g. heatmaps, offsets, ...).
        for head in self.heads:
            classes = self.heads[head]
            if head_conv > 0:
              fc = nn.Sequential(
                  nn.Conv2d(channels[self.first_level] , head_conv,
                    kernel_size=3, padding=1, bias=True),
                  nn.ReLU(inplace=True),
                  nn.Conv2d(head_conv, classes, 
                    kernel_size=final_kernel, stride=1, 
                    padding=final_kernel // 2, bias=True))
              # Bias prior of -1.19 for 'center' heads -- presumably a focal-loss
              # style prior on the heatmap output; confirm the intended value.
              if 'center' in head:
                fc[-1].bias.data.fill_(-1.19)
              else:
                fill_fc_weights(fc)
            else:
              fc = nn.Conv2d(channels[self.first_level], classes,
                  kernel_size=final_kernel, stride=1, 
                  padding=final_kernel // 2, bias=True)
              if 'center' in head:
                fc.bias.data.fill_(-1.19)
              else:
                fill_fc_weights(fc)
            # Register each head as an attribute named after its key.
            self.__setattr__(head, fc)
Exemplo n.º 5
0
def get_model(config):
    """Build a segmentation model described by ``config.MODEL``.

    Known ``config.MODEL.NAME`` values select a dedicated constructor;
    any other name falls through to a generic smp architecture chosen by
    ``config.MODEL.ARCHITECTURE`` / ``ENCODER`` / ``PRETRAINED``.

    Args:
        config: object with ``MODEL.NAME``, ``MODEL.ARCHITECTURE``,
            ``MODEL.ENCODER``, ``MODEL.PRETRAINED`` and ``PARALLEL``.

    Returns:
        The constructed model, wrapped in nn.DataParallel when
        ``config.PARALLEL`` is truthy.

    Raises:
        ValueError: if the generic architecture name is not recognized
            (previously this fell through and crashed later with an
            unhelpful NameError on the unbound ``model``).
    """
    if config.MODEL.NAME == 'hrnetv2':
        model = get_hrnetv2()
        print('model: hrnetv2')

    elif config.MODEL.NAME == 'resnet50_upernet':
        model = get_resnet50_upernet()
        print('model: resnet50_upernet')

    elif config.MODEL.NAME == 'resnet101_upernet':
        model = get_resnet101_upernet()
        print('model: resnet101_upernet')

    elif config.MODEL.NAME == 'acnet':
        model = ACNet(num_class=4, pretrained=True)
        print('model: acnet')

    elif config.MODEL.NAME == 'deeplabv3':
        model = get_deeplabv3()
        print('model: deeplabv3')

    elif config.MODEL.NAME == 'deeplab_xception':
        model = DeepLab(backbone='xception',
                        output_stride=16,
                        num_classes=4,
                        sync_bn=False,
                        freeze_bn=False)

    else:
        model_architecture = config.MODEL.ARCHITECTURE
        model_encoder = config.MODEL.ENCODER
        model_pretrained = config.MODEL.PRETRAINED

        if model_architecture == 'Unet':
            model = Unet(model_encoder,
                         encoder_weights=model_pretrained,
                         classes=4,
                         attention_type='scse')
        elif model_architecture == 'Linknet':
            model = Linknet(model_encoder,
                            encoder_weights=model_pretrained,
                            classes=4)
        elif model_architecture == 'FPN' or model_architecture == 'PSPNet':
            # NOTE(review): 'PSPNet' deliberately(?) builds an FPN here --
            # behavior preserved, but confirm this is intended.
            model = FPN(model_encoder,
                        encoder_weights=model_pretrained,
                        classes=4)
        else:
            raise ValueError(
                'Unsupported model architecture: {!r}'.format(model_architecture))

        print('architecture:', model_architecture, 'encoder:', model_encoder,
              'pretrained on:', model_pretrained)

    if config.PARALLEL:
        model = nn.DataParallel(model)

    print('[*] num parameters:', count_parameters(model))

    return model
Exemplo n.º 6
0
 def __init__(self, config, num_of_classes):
     """Generator backed by an smp U-Net: resnet18 encoder with imagenet
     weights, 3 input channels, sigmoid output over num_of_classes maps."""
     super(Generator, self).__init__()
     self.config = config
     # decoder_use_batchnorm='inplace' selects smp's in-place activated
     # batchnorm decoder variant -- NOTE(review): that mode requires the
     # inplace_abn package to be installed; confirm.
     self.unet = Unet(encoder_name='resnet18',
                      encoder_weights='imagenet',
                      in_channels=3,
                      classes=num_of_classes,
                      activation='sigmoid',
                      decoder_use_batchnorm='inplace')
Exemplo n.º 7
0
def load_models(model_path, config, n_gpus=1, unet=True):
    """
    Loads segmentation models with a given architecture.

    Args:
        model_path: directory containing ``*fold_*.pth`` checkpoint files.
        config: dict with a 'model' section (backbone / EncoderDecoder kwargs).
        n_gpus: if > 1, models are wrapped in nn.DataParallel *before*
            loading (checkpoints are assumed to carry 'module.'-prefixed keys).
        unet: build smp Unet when True, otherwise EncoderDecoder.

    Returns:
        List of loaded models, one per checkpoint, in sorted path order.
    """
    checkpoint_paths = sorted(glob(model_path + '/*fold_*.pth'))

    model_list = []
    # Iterate checkpoints directly instead of indexing by range(len(...)).
    for checkpoint_path in checkpoint_paths:
        if unet:
            model = Unet(config['model']['backbone'],
                         encoder_weights="imagenet",
                         activation='sigmoid')
        else:
            model = EncoderDecoder(**config['model'])
        # Wrap before load_state_dict so 'module.'-prefixed keys match.
        if n_gpus > 1:
            model = nn.DataParallel(model)
        model.load_state_dict(torch.load(checkpoint_path))
        model_list.append(model)

    return model_list
Exemplo n.º 8
0
def get_model(config):
    """Build a 4-class smp segmentation model.

    Args:
        config: object with ``ARCHITECTURE`` ('Unet' or 'FPN') and ``ENCODER``.

    Returns:
        The constructed model.

    Raises:
        ValueError: for an unknown architecture (previously this fell
        through and crashed with NameError on the unbound ``model``).
    """
    model_architecture = config.ARCHITECTURE
    model_encoder = config.ENCODER

    # The activation is only applied in eval mode, so sigmoid must be
    # applied manually during training. (translated from Korean)
    if model_architecture == 'Unet':
        model = Unet(model_encoder, encoder_weights='imagenet', classes=4, attention_type='scse')
    elif model_architecture == 'FPN':
        model = FPN(model_encoder, encoder_weights='imagenet', classes=4)
    else:
        raise ValueError(
            'Unsupported model architecture: {!r}'.format(model_architecture))

    print('architecture:', model_architecture, 'encoder:', model_encoder)

    return model
Exemplo n.º 9
0
def load_background_network(config_path, checkpoint_path):
    """Loads the background/tissue segmentation network.

    Args:
        config_path: YAML file with keys pooling, dropout, encoder_name,
            use_batchnorm and architecture.
        checkpoint_path: state-dict file loadable by torch.load.

    Returns:
        The selected network in eval mode with weights loaded.
    """
    # load config file
    with open(config_path, 'r') as yaml_file:
        config = yaml.load(yaml_file, Loader=yaml.FullLoader)

    num_classes = 1
    activation = 'sigmoid'

    aux_params = dict(
        pooling=config['pooling'],  # one of 'avg', 'max'
        dropout=config['dropout'],  # dropout ratio, default is None
        activation=activation,  # activation function, default is None
        classes=num_classes)  # define number of output labels

    # Factories instead of instances: previously all four networks were
    # constructed eagerly even though only one is used.
    model_factories = {
        'unet':
        lambda: Unet(encoder_name=config['encoder_name'],
                     encoder_weights=None,
                     decoder_use_batchnorm=config['use_batchnorm'],
                     classes=num_classes,
                     activation=activation,
                     aux_params=aux_params),
        'pspnet':
        lambda: PSPNet(encoder_name=config['encoder_name'],
                       encoder_weights=None,
                       psp_use_batchnorm=config['use_batchnorm'],
                       classes=num_classes,
                       activation=activation,
                       aux_params=aux_params),
        'pan':
        lambda: PAN(encoder_name=config['encoder_name'],
                    encoder_weights=None,
                    classes=num_classes,
                    activation=activation,
                    aux_params=aux_params),
        'deeplabv3plus':
        lambda: DeepLabV3Plus(encoder_name=config['encoder_name'],
                              encoder_weights=None,
                              classes=num_classes,
                              activation=activation,
                              aux_params=aux_params)
    }

    network = model_factories[config['architecture']]()
    network.load_state_dict(torch.load(checkpoint_path))
    network.eval()

    return network
Exemplo n.º 10
0
    def __init__(self, hparams):
        """hparams must be a dict of {weight_decay, lr, num_classes}"""
        super().__init__()
        self.save_hyperparameters(hparams)

        # U-Net with a frozen pre-trained encoder; only the decoder and
        # head remain trainable.
        self.model = Unet(
            encoder_name="efficientnet-b4",
            encoder_weights="imagenet",
            in_channels=3,
            classes=self.hparams.num_classes,
        )
        self.model.requires_grad_(True)
        self.model.encoder.requires_grad_(False)

        # Loss function and metrics
        ignore_index = self.hparams.get("ignore_index")
        self.focal_tversky_loss = FocalTverskyMetric(
            self.hparams.num_classes,
            alpha=0.7,
            beta=0.3,
            gamma=4.0 / 3.0,
            ignore_index=ignore_index,
        )
        self.accuracy_metric = Accuracy(ignore_index=ignore_index)
        self.iou_metric = JaccardIndex(
            num_classes=self.hparams.num_classes,
            reduction="none",
            ignore_index=ignore_index,
        )
        # BUG FIX: these previously read self.num_classes / self.ignore_index,
        # which this method never defines (the values live on self.hparams) --
        # that would raise AttributeError unless set elsewhere; confirmed the
        # surrounding snippets use self.hparams consistently.
        self.precision_metric = Precision(num_classes=self.hparams.num_classes,
                                          ignore_index=ignore_index,
                                          average='weighted',
                                          mdmc_average='global')
        self.recall_metric = Recall(num_classes=self.hparams.num_classes,
                                    ignore_index=ignore_index,
                                    average='weighted',
                                    mdmc_average='global')
Exemplo n.º 11
0
def Net():
    """Return a 4-class U-Net: resnet152 encoder, randomly initialized
    (no pretrained weights), sigmoid output activation."""
    # Commented-out alternative FPN/efficientnet-b5 configuration removed.
    return Unet(
        encoder_name='resnet152',
        encoder_weights=None,
        classes=4,
        activation='sigmoid')
Exemplo n.º 12
0
def initialize_artifacts_network(config):
    """Build the artifact-segmentation network described by *config*,
    wrap it in DataParallel, load its checkpoint, and return it on the
    GPU in eval mode.

    config keys used: num_classes, pooling, dropout, encoder_name,
    use_batchnorm, architecture, artifact_network (checkpoint path).
    """
    num_classes = config['num_classes']
    # sigmoid for a single binary map, per-pixel softmax for multi-class
    activation = 'sigmoid' if num_classes == 1 else 'softmax2d'

    aux_params = dict(
        pooling=config['pooling'],  # one of 'avg', 'max'
        dropout=config['dropout'],  # dropout ratio, default is None
        activation=activation,  # activation function, default is None
        classes=num_classes)  # define number of output labels

    # Factories instead of instances: previously all four networks were
    # constructed eagerly even though only one is used.
    model_factories = {
        'unet':
        lambda: Unet(encoder_name=config['encoder_name'],
                     encoder_weights=None,
                     decoder_use_batchnorm=config['use_batchnorm'],
                     classes=num_classes,
                     activation=activation,
                     aux_params=aux_params),
        'pspnet':
        lambda: PSPNet(encoder_name=config['encoder_name'],
                       encoder_weights=None,
                       psp_use_batchnorm=config['use_batchnorm'],
                       classes=num_classes,
                       activation=activation,
                       aux_params=aux_params),
        'pan':
        lambda: PAN(encoder_name=config['encoder_name'],
                    encoder_weights=None,
                    classes=num_classes,
                    activation=activation,
                    aux_params=aux_params),
        'deeplabv3plus':
        lambda: DeepLabV3Plus(encoder_name=config['encoder_name'],
                              encoder_weights=None,
                              classes=num_classes,
                              activation=activation,
                              aux_params=aux_params)
    }

    # prepare network pretrained weights
    assert config['architecture'] in model_factories
    network = model_factories[config['architecture']]()
    # Wrap before load_state_dict -- presumably the checkpoint was saved
    # from a DataParallel model ('module.'-prefixed keys); confirm.
    network = torch.nn.DataParallel(network, device_ids=[0])
    network.load_state_dict(torch.load(config['artifact_network']))
    network.to(torch.device('cuda'))
    network.eval()

    return network
Exemplo n.º 13
0
def unet(backbone, pretrained_weights=None, classes=1, activation='sigmoid', device='cuda'):
    """Build an smp U-Net and move it to *device* in eval mode.

    Args:
        backbone: encoder name passed to smp.Unet.
        pretrained_weights: encoder weights (e.g. 'imagenet') or None.
        classes: number of output channels.
        activation: output activation name.
        device: target device string; defaults to 'cuda' for backward
            compatibility (previously hard-coded, which crashed on
            CPU-only machines).

    Returns:
        The model on *device*, in eval mode.
    """
    target_device = torch.device(device)
    model = Unet(encoder_name=backbone,
                 encoder_weights=pretrained_weights,
                 classes=classes,
                 activation=activation)
    model.to(target_device)
    model.eval()  # NOTE(review): original comment (Korean) said to double-check this placement

    return model
Exemplo n.º 14
0
class UnetEfficientnet(GeoTiffPredictionMixin, pl.LightningModule):
    """Lightning module wrapping an smp U-Net (efficientnet-b4 encoder,
    frozen) for multi-class segmentation with focal-Tversky loss."""

    def __init__(self, hparams):
        """hparams must be a dict of {weight_decay, lr, num_classes}"""
        super().__init__()
        self.save_hyperparameters(hparams)

        # U-Net with a frozen pre-trained encoder; only the decoder and
        # head remain trainable.
        self.model = Unet(
            encoder_name="efficientnet-b4",
            encoder_weights="imagenet",
            in_channels=3,
            classes=self.hparams.num_classes,
        )
        self.model.requires_grad_(True)
        self.model.encoder.requires_grad_(False)

        # Loss function and metrics
        ignore_index = self.hparams.get("ignore_index")
        self.focal_tversky_loss = FocalTverskyMetric(
            self.hparams.num_classes,
            alpha=0.7,
            beta=0.3,
            gamma=4.0 / 3.0,
            ignore_index=ignore_index,
        )
        self.accuracy_metric = Accuracy(ignore_index=ignore_index)
        self.iou_metric = JaccardIndex(
            num_classes=self.hparams.num_classes,
            reduction="none",
            ignore_index=ignore_index,
        )
        # BUG FIX: these previously read self.num_classes / self.ignore_index,
        # which are never assigned anywhere in this class (the values live on
        # self.hparams) and would raise AttributeError.
        self.precision_metric = Precision(num_classes=self.hparams.num_classes,
                                          ignore_index=ignore_index,
                                          average='weighted',
                                          mdmc_average='global')
        self.recall_metric = Recall(num_classes=self.hparams.num_classes,
                                    ignore_index=ignore_index,
                                    average='weighted',
                                    mdmc_average='global')

    @property
    def example_input_array(self) -> Any:
        # Example batch used by Lightning for graph tracing / summaries.
        return torch.rand(2, 3, 512, 512)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model.forward(x)

    def configure_optimizers(self):
        """Init optimizer and scheduler"""
        # Only parameters left trainable (decoder/head) are optimized.
        optimizer = torch.optim.SGD(
            filter(lambda p: p.requires_grad, self.model.parameters()),
            lr=self.hparams.lr,
            weight_decay=self.hparams.weight_decay,
        )
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=self.hparams.max_epochs)
        return [optimizer], [{"scheduler": lr_scheduler, "interval": "epoch"}]

    def training_step(self, batch, batch_idx):
        """One training step: loss plus per-class IoU and accuracy logging."""
        x, y = batch
        logits = self.model(x)
        probs = torch.softmax(logits, dim=1)
        loss = self.focal_tversky_loss(probs, y)

        preds = logits.argmax(dim=1)
        ious = self.iou_metric(preds, y)
        acc = self.accuracy_metric(preds, y)

        self.log("train_loss", loss, on_epoch=True, sync_dist=True)
        self.log("train_miou", ious.mean(), on_epoch=True, sync_dist=True)
        self.log("train_accuracy", acc, on_epoch=True, sync_dist=True)
        for c in range(len(ious)):
            self.log(f"train_c{c}_iou", ious[c], on_epoch=True, sync_dist=True)

        return loss

    def val_test_step(self, batch, batch_idx, phase="val"):
        """Shared validation/test step; *phase* prefixes the logged metrics."""
        x, y = batch
        logits = self.model(x)
        probs = torch.softmax(logits, dim=1)
        loss = self.focal_tversky_loss(probs, y)

        preds = logits.argmax(dim=1)
        ious = self.iou_metric(preds, y)
        miou = ious.mean()
        acc = self.accuracy_metric(preds, y)
        precision = self.precision_metric(preds, y)
        recall = self.recall_metric(preds, y)

        if phase == 'val':
            # hp_metric drives TensorBoard's hyperparameter comparison view.
            self.log("hp_metric", miou)
        self.log(f"{phase}_loss", loss, sync_dist=True)
        self.log(f"{phase}_miou", miou, sync_dist=True)
        self.log(f"{phase}_accuracy", acc, sync_dist=True)
        self.log(f"{phase}_precision", precision, sync_dist=True)
        self.log(f"{phase}_recall", recall, sync_dist=True)
        for c in range(len(ious)):
            self.log(f"{phase}_cls{c}_iou", ious[c], sync_dist=True)

        return loss

    def validation_step(self, batch, batch_idx):
        return self.val_test_step(batch, batch_idx, phase="val")

    def test_step(self, batch, batch_idx):
        return self.val_test_step(batch, batch_idx, phase="test")

    @staticmethod
    def ckpt2pt(ckpt_file, pt_path):
        """Extract a Lightning checkpoint's state_dict into a plain .pt file."""
        checkpoint = torch.load(ckpt_file, map_location=torch.device("cpu"))
        torch.save(checkpoint["state_dict"], pt_path)

    @staticmethod
    def add_argparse_args(parser):
        """Register this module's hyperparameters on an argparse parser."""
        group = parser.add_argument_group("UnetEfficientnet")

        group.add_argument(
            "--num_classes",
            type=int,
            default=2,
            help="The number of image classes, including background.",
        )
        group.add_argument("--lr",
                           type=float,
                           default=0.001,
                           help="the learning rate")
        group.add_argument(
            "--weight_decay",
            type=float,
            default=1e-3,
            help="The weight decay factor for L2 regularization.",
        )
        group.add_argument("--ignore_index",
                           type=int,
                           help="Label of any class to ignore.")
        group.add_argument(
            "--aux_loss_factor",
            type=float,
            default=0.3,
            help=
            "The proportion of loss backpropagated to classifier built only on early layers.",
        )

        return parser
Exemplo n.º 15
0
def load_model_unet(_model_weights, is_inference=False):
    """Build a 4-class U-Net and optionally load a training checkpoint.

    Args:
        _model_weights: "imagenet" for a fresh imagenet-initialized model;
            a checkpoint path to restore model and optimizer state; or
            None for a randomly initialized model.
        is_inference: if True, put the model in eval mode.

    Returns:
        model when _model_weights is "imagenet" or None;
        (model, optimizer_state) when a checkpoint path is given.
        NOTE(review): the mixed return type is preserved for backward
        compatibility -- callers must handle both shapes.
    """
    print("Using weights {}".format(_model_weights))
    if _model_weights == "imagenet":
        model = Unet(
            unet_encoder,
            encoder_weights="imagenet",
            classes=4,
            activation=None,
            attention_type=ATTENTION_TYPE,
        )
        if is_inference:
            model.eval()
        return model

    # Any non-"imagenet" value: build with random encoder weights.
    model = Unet(
        unet_encoder,
        encoder_weights=None,
        classes=4,
        activation=None,
        attention_type=ATTENTION_TYPE,
    )
    if is_inference:
        model.eval()
    if _model_weights is not None:
        # Move to GPU before restoring so the loaded tensors land on cuda.
        device = torch.device("cuda")
        model.to(device)
        state = torch.load(_model_weights)
        model.load_state_dict(state["state_dict"])
        optimizer_state = state["optimizer"]
        return model, optimizer_state
    # _model_weights is None: randomly initialized model, no checkpoint.
    return model
def main(cfg):
    """Runs main training procedure."""

    # fix random seeds for reproducibility
    seed_everything(seed=cfg['seed'])

    # neptune logging
    neptune.init(project_qualified_name=cfg['neptune_project_name'],
                 api_token=cfg['neptune_api_token'])

    neptune.create_experiment(name=cfg['neptune_experiment'], params=cfg)

    print('Preparing model and data...')
    print('Using SMP version:', smp.__version__)

    num_classes = 1 if len(cfg['classes']) == 1 else (len(cfg['classes']) + 1)
    activation = 'sigmoid' if num_classes == 1 else 'softmax2d'
    background = False if cfg['ignore_channels'] else True
    binary = True if num_classes == 1 else False
    softmax = False if num_classes == 1 else True
    sigmoid = True if num_classes == 1 else False

    aux_params = dict(
        pooling=cfg['pooling'],  # one of 'avg', 'max'
        dropout=cfg['dropout'],  # dropout ratio, default is None
        activation='sigmoid',  # activation function, default is None
        classes=num_classes)  # define number of output labels

    # configure model
    models = {
        'unet':
        Unet(encoder_name=cfg['encoder_name'],
             encoder_weights=cfg['encoder_weights'],
             decoder_use_batchnorm=cfg['use_batchnorm'],
             classes=num_classes,
             activation=activation,
             aux_params=aux_params),
        'pspnet':
        PSPNet(encoder_name=cfg['encoder_name'],
               encoder_weights=cfg['encoder_weights'],
               classes=num_classes,
               activation=activation,
               aux_params=aux_params),
        'pan':
        PAN(encoder_name=cfg['encoder_name'],
            encoder_weights=cfg['encoder_weights'],
            classes=num_classes,
            activation=activation,
            aux_params=aux_params),
        'deeplabv3plus':
        DeepLabV3Plus(encoder_name=cfg['encoder_name'],
                      encoder_weights=cfg['encoder_weights'],
                      classes=num_classes,
                      activation=activation,
                      aux_params=aux_params)
    }

    assert cfg['architecture'] in models.keys()
    model = models[cfg['architecture']]

    # configure loss
    losses = {
        'dice_loss':
        DiceLoss(include_background=background,
                 softmax=softmax,
                 sigmoid=sigmoid,
                 batch=cfg['combine']),
        'generalized_dice':
        GeneralizedDiceLoss(include_background=background,
                            softmax=softmax,
                            sigmoid=sigmoid,
                            batch=cfg['combine'])
    }

    assert cfg['loss'] in losses.keys()
    loss = losses[cfg['loss']]

    # configure optimizer
    optimizers = {
        'adam': Adam([dict(params=model.parameters(), lr=cfg['lr'])]),
        'adamw': AdamW([dict(params=model.parameters(), lr=cfg['lr'])]),
        'rmsprop': RMSprop([dict(params=model.parameters(), lr=cfg['lr'])])
    }

    assert cfg['optimizer'] in optimizers.keys()
    optimizer = optimizers[cfg['optimizer']]

    # configure metrics
    metrics = {
        'dice_score':
        DiceMetric(include_background=background, reduction='mean'),
        'dice_smp':
        Fscore(threshold=cfg['rounding'],
               ignore_channels=cfg['ignore_channels']),
        'iou_smp':
        IoU(threshold=cfg['rounding'], ignore_channels=cfg['ignore_channels']),
        'generalized_dice':
        GeneralizedDiceLoss(include_background=background,
                            softmax=softmax,
                            sigmoid=sigmoid,
                            batch=cfg['combine']),
        'dice_loss':
        DiceLoss(include_background=background,
                 softmax=softmax,
                 sigmoid=sigmoid,
                 batch=cfg['combine']),
        'cross_entropy':
        BCELoss(reduction='mean'),
        'accuracy':
        Accuracy(ignore_channels=cfg['ignore_channels'])
    }

    assert all(m['name'] in metrics.keys() for m in cfg['metrics'])
    metrics = [(metrics[m['name']], m['name'], m['type'])
               for m in cfg['metrics']]  # tuple of (metric, name, type)

    # TODO: Fix metric names

    # configure scheduler
    schedulers = {
        'steplr':
        StepLR(optimizer, step_size=cfg['step_size'], gamma=0.5),
        'cosine':
        CosineAnnealingLR(optimizer,
                          cfg['epochs'],
                          eta_min=cfg['eta_min'],
                          last_epoch=-1)
    }

    assert cfg['scheduler'] in schedulers.keys()
    scheduler = schedulers[cfg['scheduler']]

    # configure augmentations
    train_transform = load_train_transform(transform_type=cfg['transform'],
                                           patch_size=cfg['patch_size_train'])
    valid_transform = load_valid_transform(
        patch_size=cfg['patch_size_valid'])  # manually selected patch size

    train_dataset = ArtifactDataset(df_path=cfg['train_data'],
                                    classes=cfg['classes'],
                                    transform=train_transform,
                                    normalize=cfg['normalize'],
                                    ink_filters=cfg['ink_filters'])

    valid_dataset = ArtifactDataset(df_path=cfg['valid_data'],
                                    classes=cfg['classes'],
                                    transform=valid_transform,
                                    normalize=cfg['normalize'],
                                    ink_filters=cfg['ink_filters'])

    test_dataset = ArtifactDataset(df_path=cfg['test_data'],
                                   classes=cfg['classes'],
                                   transform=valid_transform,
                                   normalize=cfg['normalize'],
                                   ink_filters=cfg['ink_filters'])

    # load pre-sampled patch arrays
    train_image, train_mask = train_dataset[0]
    valid_image, valid_mask = valid_dataset[0]
    print('Shape of image patch', train_image.shape)
    print('Shape of mask patch', train_mask.shape)
    print('Train dataset shape:', len(train_dataset))
    print('Valid dataset shape:', len(valid_dataset))
    assert train_image.shape[1] == cfg[
        'patch_size_train'] and train_image.shape[2] == cfg['patch_size_train']
    assert valid_image.shape[1] == cfg[
        'patch_size_valid'] and valid_image.shape[2] == cfg['patch_size_valid']

    # save intermediate augmentations
    if cfg['eval_dir']:
        default_dataset = ArtifactDataset(df_path=cfg['train_data'],
                                          classes=cfg['classes'],
                                          transform=None,
                                          normalize=None,
                                          ink_filters=cfg['ink_filters'])

        transform_dataset = ArtifactDataset(df_path=cfg['train_data'],
                                            classes=cfg['classes'],
                                            transform=train_transform,
                                            normalize=None,
                                            ink_filters=cfg['ink_filters'])

        for idx in range(0, min(500, len(train_dataset)), 10):
            image_input, image_mask = default_dataset[idx]
            image_input = image_input.transpose((1, 2, 0)).astype(np.uint8)

            image_mask = image_mask.transpose(1, 2, 0)
            image_mask = np.argmax(
                image_mask, axis=2) if not binary else image_mask.squeeze()
            image_mask = image_mask.astype(np.uint8)

            image_transform, _ = transform_dataset[idx]
            image_transform = image_transform.transpose(
                (1, 2, 0)).astype(np.uint8)

            idx_str = str(idx).zfill(3)
            skimage.io.imsave(os.path.join(cfg['eval_dir'],
                                           f'{idx_str}a_image_input.png'),
                              image_input,
                              check_contrast=False)
            plt.imsave(os.path.join(cfg['eval_dir'],
                                    f'{idx_str}b_image_mask.png'),
                       image_mask,
                       vmin=0,
                       vmax=6,
                       cmap='Spectral')
            skimage.io.imsave(os.path.join(cfg['eval_dir'],
                                           f'{idx_str}c_image_transform.png'),
                              image_transform,
                              check_contrast=False)

        del transform_dataset

    # update process
    print('Starting training...')
    print('Available GPUs for training:', torch.cuda.device_count())

    # pytorch module wrapper
    class DataParallelModule(torch.nn.DataParallel):
        def __getattr__(self, name):
            try:
                return super().__getattr__(name)
            except AttributeError:
                return getattr(self.module, name)

    # data parallel training
    if torch.cuda.device_count() > 1:
        model = DataParallelModule(model)

    train_loader = DataLoader(train_dataset,
                              batch_size=cfg['batch_size'],
                              num_workers=cfg['workers'],
                              shuffle=True)

    valid_loader = DataLoader(valid_dataset,
                              batch_size=int(cfg['batch_size'] / 4),
                              num_workers=cfg['workers'],
                              shuffle=False)

    test_loader = DataLoader(test_dataset,
                             batch_size=int(cfg['batch_size'] / 4),
                             num_workers=cfg['workers'],
                             shuffle=False)

    trainer = Trainer(model=model,
                      device=cfg['device'],
                      save_checkpoints=cfg['save_checkpoints'],
                      checkpoint_dir=cfg['checkpoint_dir'],
                      checkpoint_name=cfg['checkpoint_name'])

    trainer.compile(optimizer=optimizer,
                    loss=loss,
                    metrics=metrics,
                    num_classes=num_classes)

    trainer.fit(train_loader,
                valid_loader,
                epochs=cfg['epochs'],
                scheduler=scheduler,
                verbose=cfg['verbose'],
                loss_weight=cfg['loss_weight'],
                test_loader=test_loader,
                binary=binary)

    # validation inference: reload the best checkpoint written during fit()
    model.load_state_dict(
        torch.load(os.path.join(cfg['checkpoint_dir'],
                                cfg['checkpoint_name'])))
    model.to(cfg['device'])
    model.eval()

    # save best checkpoint to neptune
    neptune.log_artifact(
        os.path.join(cfg['checkpoint_dir'], cfg['checkpoint_name']))

    # setup directory to save plots (wiped and recreated on every run)
    if os.path.isdir(cfg['plot_dir_valid']):
        shutil.rmtree(cfg['plot_dir_valid'])
    os.makedirs(cfg['plot_dir_valid'], exist_ok=True)

    # valid dataset without transformations and normalization for image visualization
    valid_dataset_vis = ArtifactDataset(df_path=cfg['valid_data'],
                                        classes=cfg['classes'],
                                        ink_filters=cfg['ink_filters'])

    # keep track of valid masks
    valid_preds = []
    valid_masks = []

    # NOTE(review): the lists above are only filled inside this guard, yet
    # np.stack below runs unconditionally — with save_checkpoints disabled
    # np.stack([]) raises; confirm save_checkpoints is always on here.
    if cfg['save_checkpoints']:
        print('Predicting valid patches...')
        for n in range(len(valid_dataset)):
            # visualization copy (CHW -> HWC, raw uint8 pixels)
            image_vis = valid_dataset_vis[n][0].astype('uint8')
            image_vis = image_vis.transpose(1, 2, 0)
            image, gt_mask = valid_dataset[n]
            gt_mask = gt_mask.transpose(1, 2, 0)
            # multi-class: collapse one-hot channels via argmax; binary: squeeze
            gt_mask = np.argmax(gt_mask,
                                axis=2) if not binary else gt_mask.squeeze()
            gt_mask = gt_mask.astype(np.uint8)
            valid_masks.append(gt_mask)

            x_tensor = torch.from_numpy(image).to(cfg['device']).unsqueeze(0)
            pr_mask, _ = model.predict(x_tensor)
            pr_mask = pr_mask.squeeze(axis=0).cpu().numpy().round()
            pr_mask = pr_mask.transpose(1, 2, 0)
            pr_mask = np.argmax(pr_mask,
                                axis=2) if not binary else pr_mask.squeeze()
            pr_mask = pr_mask.astype(np.uint8)
            valid_preds.append(pr_mask)

            save_predictions(out_path=cfg['plot_dir_valid'],
                             index=n + 1,
                             image=image_vis,
                             ground_truth_mask=gt_mask,
                             predicted_mask=pr_mask)

    # free dataset/loader memory before the flattening below
    del train_dataset, valid_dataset
    del train_loader, valid_loader

    # calculate dice per class over all flattened pixels
    valid_masks = np.stack(valid_masks, axis=0)
    valid_masks = valid_masks.flatten()
    valid_preds = np.stack(valid_preds, axis=0)
    valid_preds = valid_preds.flatten()
    dice_score = f1_score(y_true=valid_masks, y_pred=valid_preds, average=None)
    neptune.log_text('valid_dice_class', str(dice_score))
    print('Valid dice score (class):', str(dice_score))

    # Optional final evaluation on the held-out test set.
    if cfg['evaluate_test_set']:
        print('Predicting test patches...')

        # setup directory to save plots (wiped and recreated on every run)
        if os.path.isdir(cfg['plot_dir_test']):
            shutil.rmtree(cfg['plot_dir_test'])
        os.makedirs(cfg['plot_dir_test'], exist_ok=True)

        # test dataset without transformations and normalization for image visualization
        test_dataset_vis = ArtifactDataset(df_path=cfg['test_data'],
                                           classes=cfg['classes'],
                                           ink_filters=cfg['ink_filters'])

        # keep track of test masks
        test_masks = []
        test_preds = []

        for n in range(len(test_dataset)):
            # visualization copy (CHW -> HWC, raw uint8 pixels)
            image_vis = test_dataset_vis[n][0].astype('uint8')
            image_vis = image_vis.transpose(1, 2, 0)
            image, gt_mask = test_dataset[n]
            gt_mask = gt_mask.transpose(1, 2, 0)
            # multi-class: collapse one-hot channels via argmax; binary: squeeze
            gt_mask = np.argmax(gt_mask,
                                axis=2) if not binary else gt_mask.squeeze()
            gt_mask = gt_mask.astype(np.uint8)
            test_masks.append(gt_mask)

            x_tensor = torch.from_numpy(image).to(cfg['device']).unsqueeze(0)
            pr_mask, _ = model.predict(x_tensor)
            pr_mask = pr_mask.squeeze(axis=0).cpu().numpy().round()
            pr_mask = pr_mask.transpose(1, 2, 0)
            pr_mask = np.argmax(pr_mask,
                                axis=2) if not binary else pr_mask.squeeze()
            pr_mask = pr_mask.astype(np.uint8)
            test_preds.append(pr_mask)

            save_predictions(out_path=cfg['plot_dir_test'],
                             index=n + 1,
                             image=image_vis,
                             ground_truth_mask=gt_mask,
                             predicted_mask=pr_mask)

        # calculate dice per class over all flattened pixels.
        # BUG FIX: this block was indented inside the per-sample loop above,
        # so after the first iteration test_masks/test_preds were rebound to
        # flattened ndarrays and the second iteration crashed on .append();
        # it now runs once, after the loop, matching the validation branch.
        test_masks = np.stack(test_masks, axis=0)
        test_masks = test_masks.flatten()
        test_preds = np.stack(test_preds, axis=0)
        test_preds = test_preds.flatten()
        dice_score = f1_score(y_true=test_masks,
                              y_pred=test_preds,
                              average=None)
        # BUG FIX: was str({dice_score}) which logged a one-element set;
        # log the score array itself, consistent with 'valid_dice_class'.
        neptune.log_text('test_dice_class', str(dice_score))
        print('Test dice score (class):', str(dice_score))

    # end of training process
    print('Finished training!')
    val_batch_size=val_batch_size,
    pin_memory=True,
    train_sampler=train_sampler,
    limit_train_num_samples=100 if debug else None,
    limit_val_num_samples=100 if debug else None)

# accumulation_steps = 8

# fp32 batch preparation (no mixed precision)
prepare_batch = prepare_batch_fp32

# Image denormalization function to plot predictions with images
img_denormalize = partial(denormalize, mean=mean, std=std)

#################### Model ####################

# binary segmentation head (2 logit channels) on a ResNet-50 encoder
model = Unet(encoder_name='resnet50', classes=2)

#################### Solver ####################

num_epochs = 100

# class-weighted cross-entropy: class 0 weighted 0.5, class 1 weighted 1.0
criterion = nn.CrossEntropyLoss(weight=torch.tensor([0.5, 1.0]))

# NOTE(review): lr is set to 0.05 here, but the SGD constructor below passes
# lr=1.0 — presumably an LR scheduler rescales it; confirm, otherwise this
# value is dead.
lr = 0.05
weight_decay = 5e-4
momentum = 0.9
nesterov = True
optimizer = optim.SGD(model.parameters(),
                      lr=1.0,
                      momentum=momentum,
                      weight_decay=weight_decay,
Exemplo n.º 18
0
# batched loader over the inference set
data_loader = get_inference_dataloader(
    test_dataset,
    transforms=transforms,
    batch_size=batch_size,
    num_workers=num_workers,
    pin_memory=True,
)

# fp32 batch preparation for inference
prepare_batch = inference_prepare_batch_f32

# Image denormalization function to plot predictions with images
img_denormalize = partial(denormalize, mean=mean, std=std)

#################### Model ####################

# binary segmentation head (2 logit channels) on a ResNet-50 encoder
model = Unet(encoder_name='resnet50', classes=2)
# presumably an experiment-tracking run id and its best-checkpoint filename,
# used elsewhere to locate the trained weights — verify against the caller
run_uuid = "48bddd1fcabb41acaedce035e16247cb"
weights_filename = "best_model_1_val_miou_bg=0.7517839.pth"


def custom_weights_loading(model, model_weights_filepath):
    """Load a checkpoint into ``model``, normalizing key prefixes.

    Handles two checkpoint quirks:
    * weights wrapped under a top-level ``'model'`` key are unwrapped;
    * keys missing the ``'module.'`` prefix (DataParallel naming) get it
      added so they match the wrapped model's state dict.

    :param model: target ``nn.Module`` whose state-dict keys start with
        ``module.``.
    :param model_weights_filepath: path to a ``torch.save``-d checkpoint.
    """
    state_dict = torch.load(model_weights_filepath)

    # some checkpoints nest the weights under a 'model' key
    if 'model' in state_dict:
        state_dict = state_dict['model']

    if not all(k.startswith("module.") for k in state_dict):
        # BUG FIX: previously every key was prefixed unconditionally, so a
        # dict with mixed prefixes ended up with 'module.module.*' keys;
        # only prefix the keys that actually lack it.
        state_dict = {
            (k if k.startswith("module.") else f"module.{k}"): v
            for k, v in state_dict.items()
        }

    model.load_state_dict(state_dict)
    '/kaggle/input/severstalmodels/unet_mobilenet2.pth').cuda()
# unet_resnet34 = load('/kaggle/input/severstalmodels/unet_resnet34.pth').cuda()

import os
from segmentation_models_pytorch import Unet, FPN

# encoder/runtime configuration for the inference ensemble
ENCODER = 'resnet34'
ENCODER_WEIGHTS = 'imagenet'
DEVICE = 'cuda'

CLASSES = ['0', '1', '2', '3', '4']
ACTIVATION = 'softmax'

# NOTE(review): ENCODER_WEIGHTS, DEVICE, CLASSES and ACTIVATION are not
# referenced in this excerpt — the models below hard-code their own values;
# confirm against the full script.
unet_resnet34 = Unet(
    encoder_name=ENCODER,
    encoder_weights=None,
    classes=4,
    activation='sigmoid',
)

# restore trained weights and switch to inference mode on GPU
state = torch.load("../input/bce-clf/unet_res34_525.pth")
unet_resnet34.load_state_dict(state['model_state_dict'])

unet_resnet34 = unet_resnet34.cuda()
unet_resnet34 = unet_resnet34.eval()

device = torch.device("cuda")
# second ensemble member; its weights are presumably loaded elsewhere — verify
model_senet = Unet('se_resnext50_32x4d',
                   encoder_weights=None,
                   classes=4,
                   activation=None)
model_senet.to(device)
Exemplo n.º 20
0
# standard ImageNet normalization statistics
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
df = pd.read_csv(sample_submission_path)
testset = DataLoader(
    TestDataset(test_data_folder, df, mean, std),
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=True
)

# Initialize mode and load trained weights
ckpt_path = "../input/model_dump1/model.pth"
device = torch.device("cuda")
model = Unet("efficientnet-b5", encoder_weights=None, classes=4, activation=None)
model.to(device)
model.eval()
# map_location keeps the checkpoint on CPU; .to(device) above handles placement
state = torch.load(ckpt_path, map_location=lambda storage, loc: storage)
model.load_state_dict(state["state_dict"])

# start prediction
predictions = []
for i, batch in enumerate(tqdm(testset)):
    fnames, images = batch
    # raw logits -> per-class probabilities
    batch_preds = torch.sigmoid(model(images.to(device)))
    batch_preds = batch_preds.detach().cpu().numpy()
    for fname, preds in zip(fnames, batch_preds):
        for cls, pred in enumerate(preds):
            # per-class thresholding + small-component removal, then RLE encoding
            pred, num = post_process(pred, best_threshold[cls], min_size[cls])
            rle = mask2rle(pred)
            # NOTE(review): this excerpt never appends to `predictions`;
            # the loop body appears truncated — check the full source.
Exemplo n.º 21
0
from collections import namedtuple

from iglovikov_helper_functions.dl.pytorch.utils import rename_layers
from segmentation_models_pytorch import Unet
from torch import nn
from torch.utils import model_zoo

# registry entry: release-download URL plus an uninitialized architecture.
# NOTE(review): the lowercase namedtuple name `model` is easy to shadow
# (create_model below rebinds `model` locally) — consider `Model`.
model = namedtuple("model", ["url", "model"])

models = {
    "Unet_resnet34_2020-05-19":
    model(
        url=
        "https://github.com/ternaus/midv-500-models/releases/download/0.0.1/unet_resnet34_2020-05-19.zip",
        model=Unet(encoder_name="resnet34", classes=1, encoder_weights=None),
    )
}


def create_model(model_name: str) -> nn.Module:
    """Look up *model_name* in the registry, fetch its released weights,
    and return the network with those weights loaded."""
    entry = models[model_name]
    checkpoint = model_zoo.load_url(entry.url,
                                    progress=True,
                                    map_location="cpu")
    # released checkpoints prefix every layer with "model."; strip it
    weights = rename_layers(checkpoint["state_dict"], {"model.": ""})
    entry.model.load_state_dict(weights)
    return entry.model
Exemplo n.º 22
0
        columns=["Image_Label", "EncodedPixels"],
        index=False,
    )


if __name__ == "__main__":
    # CLI: <model_name> <test_data_path> <class_params_path> <output_path>
    model_name = sys.argv[1]
    test_data_path = sys.argv[2]
    class_params_path = sys.argv[3]
    # BUG FIX: was sys.argv[3], which made output_path identical to
    # class_params_path and ignored the fourth CLI argument.
    output_path = sys.argv[4]

    # BUG FIX: the filename belongs inside os.path.join; previously it was
    # passed to read_csv as the second positional argument (the separator).
    df_test = pd.read_csv(os.path.join(test_data_path, "sample_submission.csv"))
    # unique image ids, e.g. "img.jpg" from "img.jpg_Fish"
    test_ids = (
        df_test["Image_Label"].apply(lambda x: x.split("_")[0]).drop_duplicates().values
    )
    preprocess_fn = get_preprocessing_fn(model_name, "imagenet")
    test_dataset = CloudDataset(
        df=df_test,
        path=test_data_path,
        img_ids=test_ids,
        image_folder="test_images",
        transforms=get_transforms("valid"),
        preprocessing_fn=preprocess_fn,
    )
    test_loader = DataLoader(
        dataset=test_dataset, batch_size=32, shuffle=False, num_workers=4,
    )
    # NOTE(review): no weights are loaded here — presumably infer() restores
    # a checkpoint internally; confirm.
    model = Unet(model_name, classes=4, activation=None)
    # the file stores a pickled dict, so allow_pickle is required (numpy >= 1.16.4)
    class_params = np.load(class_params_path, allow_pickle=True).item()
    infer(model, test_loader, class_params, output_path)
    val_batch_size=val_batch_size,
    pin_memory=True,
    train_sampler=train_sampler,
    limit_train_num_samples=100 if debug else None,
    limit_val_num_samples=100 if debug else None)

# accumulation_steps = 8

# fp32 batch preparation (no mixed precision)
prepare_batch = prepare_batch_fp32

# Image denormalization function to plot predictions with images
img_denormalize = partial(denormalize, mean=mean, std=std)

#################### Model ####################

# binary segmentation head (2 logit channels) on an EfficientNet-B3 encoder
model = Unet(encoder_name='efficientnet-b3', classes=2)

#################### Solver ####################

num_epochs = 100

# class-weighted cross-entropy: class 0 weighted 0.1, class 1 weighted 2.0
criterion = nn.CrossEntropyLoss(weight=torch.tensor([0.1, 2.0]))

# NOTE(review): lr is set to 0.05 here, but the SGD constructor below passes
# lr=1.0 — presumably an LR scheduler rescales it; confirm, otherwise this
# value is dead.
lr = 0.05
weight_decay = 5e-4
momentum = 0.9
nesterov = True
optimizer = optim.SGD(model.parameters(),
                      lr=1.0,
                      momentum=momentum,
                      weight_decay=weight_decay,
Exemplo n.º 24
0
from collections import namedtuple
from torch import nn
from torch.utils import model_zoo
from iglovikov_helper_functions.dl.pytorch.utils import rename_layers

from segmentation_models_pytorch import Unet

# registry entry: release-download URL plus an uninitialized architecture.
# NOTE(review): the lowercase namedtuple name `model` is easy to shadow
# (create_model below rebinds `model` locally) — consider `Model`.
model = namedtuple("model", ["url", "model"])

models = {
    "Unet_2020-10-30":
    model(
        url=
        "https://github.com/ternaus/cloths_segmentation/releases/download/0.0.1/weights.zip",
        model=Unet(encoder_name="timm-efficientnet-b3",
                   classes=1,
                   encoder_weights=None),
    )
}


def create_model(model_name: str) -> nn.Module:
    """Build the registered architecture for *model_name* and load the
    weights published at its release URL."""
    url, net = models[model_name].url, models[model_name].model
    raw_state = model_zoo.load_url(url, progress=True, map_location="cpu")
    # published checkpoints prefix every layer with "model."; strip it
    cleaned = rename_layers(raw_state["state_dict"], {"model.": ""})
    net.load_state_dict(cleaned)
    return net
Exemplo n.º 25
0
def main(cfg):
    """Runs main training procedure.

    The model, loss, optimizer, metrics, scheduler and datasets are all
    assembled from the ``cfg`` mapping; training is delegated to ``Trainer``,
    after which the best checkpoint is reloaded to plot per-sample
    predictions on the validation set.
    """

    print('Starting training...')
    print('Current working directory is:', os.getcwd())

    # fix random seeds for reproducibility
    seed_everything(seed=cfg['seed'])

    # neptune logging
    neptune.init(project_qualified_name=cfg['neptune_project_name'],
                 api_token=cfg['neptune_api_token'])

    neptune.create_experiment(name=cfg['neptune_experiment'], params=cfg)

    # one output channel for binary tasks, otherwise one channel per class
    # plus an extra background channel
    num_classes = 1 if len(cfg['classes']) == 1 else (len(cfg['classes']) + 1)
    activation = 'sigmoid' if num_classes == 1 else 'softmax2d'
    background = False if cfg['ignore_channels'] else True

    # auxiliary classification head settings for the smp constructors
    aux_params = dict(
        pooling=cfg['pooling'],  # one of 'avg', 'max'
        dropout=cfg['dropout'],  # dropout ratio, default is None
        activation='sigmoid',  # activation function, default is None
        classes=num_classes)  # define number of output labels

    # configure model
    # NOTE(review): all three candidate architectures are instantiated
    # eagerly, so pretrained encoder weights may be fetched even for the two
    # models that are discarded below.
    models = {
        'unet':
        Unet(encoder_name=cfg['encoder_name'],
             encoder_weights=cfg['encoder_weights'],
             decoder_use_batchnorm=cfg['use_batchnorm'],
             classes=num_classes,
             activation=activation,
             aux_params=aux_params),
        'unetplusplus':
        UnetPlusPlus(encoder_name=cfg['encoder_name'],
                     encoder_weights=cfg['encoder_weights'],
                     decoder_use_batchnorm=cfg['use_batchnorm'],
                     classes=num_classes,
                     activation=activation,
                     aux_params=aux_params),
        'deeplabv3plus':
        DeepLabV3Plus(encoder_name=cfg['encoder_name'],
                      encoder_weights=cfg['encoder_weights'],
                      classes=num_classes,
                      activation=activation,
                      aux_params=aux_params)
    }

    assert cfg['architecture'] in models.keys()
    model = models[cfg['architecture']]

    # configure loss
    # softmax=False: the model already applies the activation chosen above —
    # TODO confirm against Trainer's loss handling
    losses = {
        'dice_loss':
        DiceLoss(include_background=background,
                 softmax=False,
                 batch=cfg['combine']),
        'generalized_dice':
        GeneralizedDiceLoss(include_background=background,
                            softmax=False,
                            batch=cfg['combine'])
    }

    assert cfg['loss'] in losses.keys()
    loss = losses[cfg['loss']]

    # configure optimizer (all three are built eagerly; only one is used)
    optimizers = {
        'adam': Adam([dict(params=model.parameters(), lr=cfg['lr'])]),
        'adamw': AdamW([dict(params=model.parameters(), lr=cfg['lr'])]),
        'rmsprop': RMSprop([dict(params=model.parameters(), lr=cfg['lr'])])
    }

    assert cfg['optimizer'] in optimizers.keys()
    optimizer = optimizers[cfg['optimizer']]

    # configure metrics
    metrics = {
        'dice_score':
        DiceMetric(include_background=background, reduction='mean'),
        'dice_smp':
        Fscore(threshold=cfg['rounding'],
               ignore_channels=cfg['ignore_channels']),
        'iou_smp':
        IoU(threshold=cfg['rounding'], ignore_channels=cfg['ignore_channels']),
        'generalized_dice':
        GeneralizedDiceLoss(include_background=background,
                            softmax=False,
                            batch=cfg['combine']),
        'dice_loss':
        DiceLoss(include_background=background,
                 softmax=False,
                 batch=cfg['combine']),
        'cross_entropy':
        BCELoss(reduction='mean'),
        'accuracy':
        Accuracy(ignore_channels=cfg['ignore_channels'])
    }

    assert all(m['name'] in metrics.keys() for m in cfg['metrics'])
    metrics = [(metrics[m['name']], m['name'], m['type'])
               for m in cfg['metrics']]  # tuple of (metric, name, type)

    # configure scheduler
    schedulers = {
        'steplr':
        StepLR(optimizer, step_size=cfg['step_size'], gamma=0.5),
        'cosine':
        CosineAnnealingLR(optimizer,
                          cfg['epochs'],
                          eta_min=cfg['eta_min'],
                          last_epoch=-1)
    }

    assert cfg['scheduler'] in schedulers.keys()
    scheduler = schedulers[cfg['scheduler']]

    # configure augmentations
    train_transform = load_train_transform(transform_type=cfg['transform'],
                                           patch_size=cfg['patch_size'])
    valid_transform = load_valid_transform(patch_size=cfg['patch_size'])

    train_dataset = SegmentationDataset(df_path=cfg['train_data'],
                                        transform=train_transform,
                                        normalize=cfg['normalize'],
                                        tissuemix=cfg['tissuemix'],
                                        probability=cfg['probability'],
                                        blending=cfg['blending'],
                                        warping=cfg['warping'],
                                        color=cfg['color'])

    valid_dataset = SegmentationDataset(df_path=cfg['valid_data'],
                                        transform=valid_transform,
                                        normalize=cfg['normalize'])

    # save intermediate augmentations
    if cfg['eval_dir']:
        default_dataset = SegmentationDataset(df_path=cfg['train_data'],
                                              transform=None,
                                              normalize=None)

        transform_dataset = SegmentationDataset(df_path=cfg['train_data'],
                                                transform=None,
                                                normalize=None,
                                                tissuemix=cfg['tissuemix'],
                                                probability=cfg['probability'],
                                                blending=cfg['blending'],
                                                warping=cfg['warping'],
                                                color=cfg['color'])

        # dump every 10th of the first 500 samples: raw input, mask, and the
        # augmented version of the same input
        for idx in range(0, min(500, len(default_dataset)), 10):
            image_input, image_mask = default_dataset[idx]
            image_input = image_input.transpose((1, 2, 0))
            image_input = image_input.astype(np.uint8)

            image_mask = image_mask.transpose(
                1, 2, 0)  # Why do we need transpose here?
            image_mask = image_mask.astype(np.uint8)
            image_mask = image_mask.squeeze()
            image_mask = image_mask * 255

            image_transform, _ = transform_dataset[idx]
            image_transform = image_transform.transpose(
                (1, 2, 0)).astype(np.uint8)

            idx_str = str(idx).zfill(3)
            skimage.io.imsave(os.path.join(cfg['eval_dir'],
                                           f'{idx_str}a_image_input.png'),
                              image_input,
                              check_contrast=False)
            plt.imsave(os.path.join(cfg['eval_dir'],
                                    f'{idx_str}b_image_mask.png'),
                       image_mask,
                       vmin=0,
                       vmax=1)
            skimage.io.imsave(os.path.join(cfg['eval_dir'],
                                           f'{idx_str}c_image_transform.png'),
                              image_transform,
                              check_contrast=False)

        del transform_dataset

    train_loader = DataLoader(train_dataset,
                              batch_size=cfg['batch_size'],
                              num_workers=cfg['workers'],
                              shuffle=True)

    valid_loader = DataLoader(valid_dataset,
                              batch_size=cfg['batch_size'],
                              num_workers=cfg['workers'],
                              shuffle=False)

    trainer = Trainer(model=model,
                      device=cfg['device'],
                      save_checkpoints=cfg['save_checkpoints'],
                      checkpoint_dir=cfg['checkpoint_dir'],
                      checkpoint_name=cfg['checkpoint_name'])

    trainer.compile(optimizer=optimizer,
                    loss=loss,
                    metrics=metrics,
                    num_classes=num_classes)

    trainer.fit(train_loader,
                valid_loader,
                epochs=cfg['epochs'],
                scheduler=scheduler,
                verbose=cfg['verbose'],
                loss_weight=cfg['loss_weight'])

    # validation inference: reload the best checkpoint written by the Trainer
    best_model = model
    best_model.load_state_dict(
        torch.load(os.path.join(cfg['checkpoint_dir'],
                                cfg['checkpoint_name'])))
    best_model.to(cfg['device'])
    best_model.eval()

    # setup directory to save plots
    if os.path.isdir(cfg['plot_dir']):
        # remove existing dir and content
        shutil.rmtree(cfg['plot_dir'])
    # create absolute destination
    os.makedirs(cfg['plot_dir'])

    # valid dataset without transformations and normalization for image visualization
    # NOTE(review): despite the comment above, valid_transform IS applied
    # here (only normalization is disabled) — confirm this is intended.
    valid_dataset_vis = SegmentationDataset(df_path=cfg['valid_data'],
                                            transform=valid_transform,
                                            normalize=None)

    if cfg['save_checkpoints']:
        # plot ground truth vs. prediction for every validation sample;
        # indexing assumes valid_dataset and valid_dataset_vis share order
        for n in range(len(valid_dataset)):
            image_vis = valid_dataset_vis[n][0].astype('uint8')
            image_vis = image_vis.transpose((1, 2, 0))

            image, gt_mask = valid_dataset[n]
            gt_mask = gt_mask.transpose((1, 2, 0))
            gt_mask = gt_mask.squeeze()

            x_tensor = torch.from_numpy(image).to(cfg['device']).unsqueeze(0)
            pr_mask, _ = best_model.predict(x_tensor)
            pr_mask = pr_mask.cpu().numpy().round()
            pr_mask = pr_mask.squeeze()

            save_predictions(out_path=cfg['plot_dir'],
                             index=n,
                             image=image_vis,
                             ground_truth_mask=gt_mask,
                             predicted_mask=pr_mask,
                             average='macro')
import torch
from segmentation_models_pytorch import Unet, Linknet, FPN, PSPNet


# 4-class sigmoid U-Net on an ImageNet-pretrained SE-ResNet152 encoder;
# the commented lines are alternative encoders tried during experimentation
model = Unet('se_resnet152', encoder_weights='imagenet', classes=4, activation='sigmoid')
# model = Unet('resnext101_32x8d', encoder_weights='instagram', classes=4, activation='sigmoid')
# model = Unet('dpn92', encoder_weights='imagenet+5k', classes=4, activation='sigmoid')
print(model)

def count_parameters(model):
    """Return the total number of trainable (requires_grad) parameters."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# report the trainable parameter count of the selected model
print(count_parameters(model))

# weights_dict = {
#     'state_dict': model.state_dict(),
# }
# torch.save(weights_dict, 'check.pth')
Exemplo n.º 27
0
# one output channel for binary tasks, otherwise one channel per class plus
# an extra background channel
num_classes = 1 if len(cfg['classes']) == 1 else (len(cfg['classes']) + 1)
activation = 'sigmoid' if num_classes == 1 else 'softmax2d'
# include the background channel only when no channels are being ignored
background = False if cfg['ignore_channels'] else True

# auxiliary classification head settings for the smp constructors
aux_params = dict(
    pooling=cfg['pooling'],  # one of 'avg', 'max'
    dropout=cfg['dropout'],  # dropout ratio, default is None
    activation='sigmoid',  # activation function, default is None
    classes=num_classes)  # define number of output labels

# configure model
models = {
    'unet':
    Unet(encoder_name=cfg['encoder_name'],
         encoder_weights=cfg['encoder_weights'],
         decoder_use_batchnorm=cfg['use_batchnorm'],
         classes=num_classes,
         activation=activation,
         aux_params=aux_params),
    'unetplusplus':
    UnetPlusPlus(encoder_name=cfg['encoder_name'],
                 encoder_weights=cfg['encoder_weights'],
                 decoder_use_batchnorm=cfg['use_batchnorm'],
                 classes=num_classes,
                 activation=activation,
                 aux_params=aux_params),
    'deeplabv3plus':
    DeepLabV3Plus(encoder_name=cfg['encoder_name'],
                  encoder_weights=cfg['encoder_weights'],
                  classes=num_classes,
                  activation=activation,
                  aux_params=aux_params)