Example #1
def build_model(configuration):
    model_list = ['UNet', 'LinkNet', 'PSPNet', 'FPN', 'PAN', 'Deeplab_v3', 'Deeplab_v3+']
    model_name = configuration.Model.model_name.lower()
    if model_name == 'unet':
        return smp.Unet(
            encoder_name=configuration.Model.encoder,
            encoder_weights=configuration.Model.encoder_weights,
            activation=None,
            classes=configuration.DataSet.number_of_classes,
            decoder_attention_type=None,
        )
    if model_name == 'linknet':
        return smp.Linknet(
            encoder_name=configuration.Model.encoder,
            encoder_weights=configuration.Model.encoder_weights,
            activation=None,
            classes=configuration.DataSet.number_of_classes
        )
    if model_name == 'pspnet':
        return smp.PSPNet(
            encoder_name=configuration.Model.encoder,
            encoder_weights=configuration.Model.encoder_weights,
            activation=None,
            classes=configuration.DataSet.number_of_classes
        )
    if model_name == 'fpn':
        return smp.FPN(
            encoder_name=configuration.Model.encoder,
            encoder_weights=configuration.Model.encoder_weights,
            activation=None,
            classes=configuration.DataSet.number_of_classes
        )
    if model_name == 'pan':
        return smp.PAN(
            encoder_name=configuration.Model.encoder,
            encoder_weights=configuration.Model.encoder_weights,
            activation=None,
            classes=configuration.DataSet.number_of_classes
        )
    if model_name == 'deeplab_v3+':
        return smp.DeepLabV3Plus(
            encoder_name=configuration.Model.encoder,
            encoder_weights=configuration.Model.encoder_weights,
            activation=None,
            classes=configuration.DataSet.number_of_classes
        )
    if model_name == 'deeplab_v3':
        return smp.DeepLabV3(
            encoder_name=configuration.Model.encoder,
            encoder_weights=configuration.Model.encoder_weights,
            activation=None,
            classes=configuration.DataSet.number_of_classes
        )
    raise KeyError(f'Model should be one of {model_list}')
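
A minimal usage sketch (not part of the original source): the nested configuration object is mocked with SimpleNamespace instances exposing the Model and DataSet attributes the factory reads.

import segmentation_models_pytorch as smp
from types import SimpleNamespace

configuration = SimpleNamespace(
    Model=SimpleNamespace(model_name='Unet',
                          encoder='resnet34',
                          encoder_weights='imagenet'),
    DataSet=SimpleNamespace(number_of_classes=2),
)
model = build_model(configuration)  # -> smp.Unet with a resnet34 encoder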
Example #2
def create_smp_model(arch, **kwargs):
    """Create a segmentation_models_pytorch model."""

    assert arch in ARCHITECTURES, f'Select one of {ARCHITECTURES}'

    if arch == "Unet": model = smp.Unet(**kwargs)
    elif arch == "UnetPlusPlus": model = smp.UnetPlusPlus(**kwargs)
    elif arch == "MAnet": model = smp.MAnet(**kwargs)
    elif arch == "FPN": model = smp.FPN(**kwargs)
    elif arch == "PAN": model = smp.PAN(**kwargs)
    elif arch == "PSPNet": model = smp.PSPNet(**kwargs)
    elif arch == "Linknet": model = smp.Linknet(**kwargs)
    elif arch == "DeepLabV3": model = smp.DeepLabV3(**kwargs)
    elif arch == "DeepLabV3Plus": model = smp.DeepLabV3Plus(**kwargs)
    else: raise NotImplementedError

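    # attach the constructor kwargs to the model so it can be re-created or logged later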
    setattr(model, 'kwargs', kwargs)
    return model
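
A hedged usage sketch: ARCHITECTURES is defined elsewhere in the original repository, so it is assumed here to be the list of supported smp class names.

import segmentation_models_pytorch as smp

ARCHITECTURES = ["Unet", "UnetPlusPlus", "MAnet", "FPN", "PAN",
                 "PSPNet", "Linknet", "DeepLabV3", "DeepLabV3Plus"]

model = create_smp_model("Unet",
                         encoder_name="resnet34",
                         encoder_weights="imagenet",
                         classes=1)
print(model.kwargs)  # the kwargs attached by setattr above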
Example #3
    def __init__(self):
        self.backbone = 'resnet34'
        self.lr = 1e-4
        self.batch = 64
        self.crop_size = 384
        logger = logging.getLogger('train')

        # print settings
        logger.info('\nTRAINING SETTINGS')
        logger.info('###########################')
        logger.info(datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
        logger.info('backbone:{}, lr:{}, batch:{}, crop size:{}'.format(self.backbone, self.lr, self.batch, self.crop_size))
        logger.info('###########################\n')

        # model define
        self.train_loader, self.test_loader = make_data_loader(batch=self.batch, crop_size=self.crop_size)
        model = smp.PAN(encoder_name=self.backbone, encoder_weights='imagenet', in_channels=3, classes=2)
        self.model = nn.DataParallel(model.cuda())
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
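        # CrossEntropyLoss expects raw logits; the PAN model above keeps smp's default activation=None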
        self.criterion = nn.CrossEntropyLoss()
        self.epoch = 0
Example #4
    def __init__(self,
                 encoder,
                 encoder_weights,
                 classes,
                 activation,
                 learning_rate=1e-3,
                 **kwargs):
        super().__init__()
        self.save_hyperparameters()

        self.classes = classes

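        # `architecture` is expected among **kwargs; save_hyperparameters()
        # captured it into self.hparams, which is why it is read from there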
        if self.hparams.architecture == 'fpn':
            self.model = smp.FPN(
                encoder_name=encoder,
                encoder_weights=encoder_weights,
                classes=len(classes),
                activation=activation,
            )
        elif self.hparams.architecture == 'pan':
            self.model = smp.PAN(
                encoder_name=encoder,
                encoder_weights=encoder_weights,
                classes=len(classes),
                activation=activation,
            )
        elif self.hparams.architecture == 'pspnet':
            self.model = smp.PSPNet(
                encoder_name=encoder,
                encoder_weights=encoder_weights,
                classes=len(classes),
                activation=activation,
            )
        else:
            raise NameError(f'Unknown architecture: {self.hparams.architecture}')

        self.loss = smp.utils.losses.DiceLoss()
Example #5
            encoder_name=ENCODER, 
            encoder_weights=ENCODER_WEIGHTS, 
            classes=len(CLASSES), 
            activation=ACTIVATION,
        )
    elif MODEL == 'deeplabv3plus':
        model = smp.DeepLabV3Plus(
            encoder_name=ENCODER, 
            encoder_weights=ENCODER_WEIGHTS, 
            classes=len(CLASSES), 
            activation=ACTIVATION,
        )
    elif MODEL == 'pannet':
        model = smp.PAN(
            encoder_name=ENCODER,
            encoder_weights=ENCODER_WEIGHTS,
            classes=len(CLASSES),
            activation=ACTIVATION,
        )
    elif MODEL == 'fpn':
        model = smp.FPN(
            encoder_name=ENCODER,
            encoder_weights=ENCODER_WEIGHTS,
            classes=len(CLASSES),
            activation=ACTIVATION,
        )
    else:
        raise RuntimeError(f'Unknown model name: {MODEL}')

    preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)

    # ------------------------------------------------------
Example #6
def get_model(config):
    """
    """
    arch = config.MODEL.ARCHITECTURE
    backbone = config.MODEL.BACKBONE
    encoder_weights = config.MODEL.ENCODER_PRETRAINED_FROM
    in_channels = config.MODEL.IN_CHANNELS
    n_classes = len(config.INPUT.CLASSES)
    activation = config.MODEL.ACTIVATION

    # unet specific
    decoder_attention_type = 'scse' if config.MODEL.UNET_ENABLE_DECODER_SCSE else None

    if arch == 'unet':
        model = smp.Unet(
            encoder_name=backbone,
            encoder_weights=encoder_weights,
            decoder_channels=config.MODEL.UNET_DECODER_CHANNELS,
            decoder_attention_type=decoder_attention_type,
            in_channels=in_channels,
            classes=n_classes,
            activation=activation
        )
    elif arch == 'fpn':
        model = smp.FPN(
            encoder_name=backbone,
            encoder_weights=encoder_weights,
            decoder_dropout=config.MODEL.FPN_DECODER_DROPOUT,
            in_channels=in_channels,
            classes=n_classes,
            activation=activation
        )
    elif arch == 'pan':
        model = smp.PAN(
            encoder_name=backbone,
            encoder_weights=encoder_weights,
            in_channels=in_channels,
            classes=n_classes,
            activation=activation
        )
    elif arch == 'pspnet':
        model = smp.PSPNet(
            encoder_name=backbone,
            encoder_weights=encoder_weights,
            psp_dropout=config.MODEL.PSPNET_DROPOUT,
            in_channels=in_channels,
            classes=n_classes,
            activation=activation
        )
    elif arch == 'deeplabv3':
        model = smp.DeepLabV3(
            encoder_name=backbone,
            encoder_weights=encoder_weights,
            in_channels=in_channels,
            classes=n_classes,
            activation=activation
        )
    elif arch == 'linknet':
        model = smp.Linknet(
            encoder_name=backbone,
            encoder_weights=encoder_weights,
            in_channels=in_channels,
            classes=n_classes,
            activation=activation
        )
    else:
        raise ValueError(f'Unknown architecture: {arch}')

    model = torch.nn.DataParallel(model)

    if config.MODEL.WEIGHT and config.MODEL.WEIGHT != 'none':
        # load weight from file
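        # the model is already wrapped in DataParallel above, so the checkpoint
        # keys are expected to carry the 'module.' prefix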
        model.load_state_dict(
            torch.load(
                config.MODEL.WEIGHT,
                map_location=torch.device('cpu')
            )
        )

    model = model.to(config.MODEL.DEVICE)
    return model
Example #7
def get_segmentation_model(
    arch: str,
    encoder_name: str,
    encoder_weights: Optional[str] = "imagenet",
    pretrained_checkpoint_path: Optional[str] = None,
    checkpoint_path: Optional[Union[str, List[str]]] = None,
    convert_bn: Optional[str] = None,
    convert_bottleneck: Tuple[int, int, int] = (0, 0, 0),
    **kwargs: Any,
) -> nn.Module:
    """
    Fetch segmentation model by its name
    :param arch:
    :param encoder_name:
    :param encoder_weights:
    :param checkpoint_path:
    :param pretrained_checkpoint_path:
    :param convert_bn:
    :param convert_bottleneck:
    :param kwargs:
    :return:
    """

    arch = arch.lower()
    if (encoder_name == "en_resnet34" or checkpoint_path is not None
            or pretrained_checkpoint_path is not None):
        encoder_weights = None

    if arch == "unet":
        model = smp.Unet(encoder_name=encoder_name,
                         encoder_weights=encoder_weights,
                         **kwargs)
    elif arch == "unetplusplus" or arch == "unet++":
        model = smp.UnetPlusPlus(encoder_name=encoder_name,
                                 encoder_weights=encoder_weights,
                                 **kwargs)
    elif arch == "linknet":
        model = smp.Linknet(encoder_name=encoder_name,
                            encoder_weights=encoder_weights,
                            **kwargs)
    elif arch == "pspnet":
        model = smp.PSPNet(encoder_name=encoder_name,
                           encoder_weights=encoder_weights,
                           **kwargs)
    elif arch == "pan":
        model = smp.PAN(encoder_name=encoder_name,
                        encoder_weights=encoder_weights,
                        **kwargs)
    elif arch == "deeplabv3":
        model = smp.DeepLabV3(encoder_name=encoder_name,
                              encoder_weights=encoder_weights,
                              **kwargs)
    elif arch == "deeplabv3plus" or arch == "deeplabv3+":
        model = smp.DeepLabV3Plus(encoder_name=encoder_name,
                                  encoder_weights=encoder_weights,
                                  **kwargs)
    elif arch == "manet":
        model = smp.MAnet(encoder_name=encoder_name,
                          encoder_weights=encoder_weights,
                          **kwargs)
    else:
        raise ValueError(f"Unknown architecture: {arch}")

    if pretrained_checkpoint_path is not None:
        print(f"Loading pretrained checkpoint {pretrained_checkpoint_path}")
        state_dict = torch.load(pretrained_checkpoint_path,
                                map_location=torch.device("cpu"))
        model.encoder.load_state_dict(state_dict)
        del state_dict

    # TODO fmap_size=16 hardcoded for input 256 (matters for positional encoding)
    botnet.convert_resnet(
        model.encoder,
        replacement=convert_bottleneck,
        fmap_size=16,
        position_encoding=None,
    )

    # TODO parametrize conversion
    print(f"Convert BN to {convert_bn}")
    if convert_bn == "instance":
        print("Converting BatchNorm2d to InstanceNorm2d")
        model = batch_norm2instance(model)
    elif convert_bn == "group":
        print("Converting BatchNorm2d to GroupNorm")
        model = batch_norm2group(model, channels_per_group=1)
    elif convert_bn == "bnet":
        print("Converting BatchNorm2d to BNet2d")
        model = batch_norm2bnet(model)
    elif convert_bn == "gnet":
        print("Converting BatchNorm2d to GNet2d")
        model = batch_norm2gnet(model, channels_per_group=1)
    elif not convert_bn:
        print("Do not convert BatchNorm2d")
    else:
        raise ValueError(f"Unknown convert_bn value: {convert_bn}")

    if checkpoint_path is not None:
        if not isinstance(checkpoint_path, list):
            checkpoint_path = [checkpoint_path]
        states = []
        for cp in checkpoint_path:
            # Load checkpoint
            print(f"\nLoading checkpoint {str(cp)}")
            state_dict = torch.load(
                cp, map_location=torch.device("cpu"))["model_state_dict"]
            states.append(state_dict)
        state_dict = average_weights(states)
        model.load_state_dict(state_dict)
        del state_dict

    return model
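
A hedged usage sketch with the defaults left in place: no checkpoints are loaded, convert_bn=None skips the norm conversion, and the default convert_bottleneck=(0, 0, 0) is assumed to leave the encoder unchanged; the extra keyword arguments are forwarded to the smp constructor.

model = get_segmentation_model(
    "unet++",
    encoder_name="resnet34",
    classes=1,
    in_channels=3,
)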
Example #8
    def __init__(self,
                 architecture="Unet",
                 encoder="resnet34",
                 depth=5,
                 in_channels=3,
                 classes=2,
                 activation="softmax"):
        super(SegmentationModels, self).__init__()
        self.architecture = architecture
        self.encoder = encoder
        self.depth = depth
        self.in_channels = in_channels
        self.classes = classes
        self.activation = activation

        # define model

        _ARCHITECTURES = [
            "Unet", "UnetPlusPlus", "Linknet", "MAnet", "FPN", "PSPNet", "PAN",
            "DeepLabV3", "DeepLabV3Plus"
        ]
        assert self.architecture in _ARCHITECTURES, \
            "architecture must be one of {0}, got '{1}'".format(
                _ARCHITECTURES, self.architecture)

        if self.architecture == "Unet":
            self.model = smp.Unet(
                encoder_name=self.encoder,
                encoder_weights=None,
                encoder_depth=self.depth,
                in_channels=self.in_channels,
                classes=self.classes,
                activation=self.activation,
            )
            self.pad_unit = 2**self.depth
        elif self.architecture == "UnetPlusPlus":
            self.model = smp.UnetPlusPlus(
                encoder_name=self.encoder,
                encoder_weights=None,
                encoder_depth=self.depth,
                in_channels=self.in_channels,
                classes=self.classes,
                activation=self.activation,
            )
            self.pad_unit = 2**self.depth
        elif self.architecture == "MAnet":
            self.model = smp.MAnet(
                encoder_name=self.encoder,
                encoder_weights=None,
                encoder_depth=self.depth,
                in_channels=self.in_channels,
                classes=self.classes,
                activation=self.activation,
            )
            self.pad_unit = 2**self.depth
        elif self.architecture == "Linknet":
            self.model = smp.Linknet(
                encoder_name=self.encoder,
                encoder_weights=None,
                encoder_depth=self.depth,
                in_channels=self.in_channels,
                classes=self.classes,
                activation=self.activation,
            )
            self.pad_unit = 2**self.depth
        elif self.architecture == "FPN":
            self.model = smp.FPN(
                encoder_name=self.encoder,
                encoder_weights=None,
                encoder_depth=self.depth,
                in_channels=self.in_channels,
                classes=self.classes,
                activation=self.activation,
            )
            self.pad_unit = 2**self.depth
        elif self.architecture == "PSPNet":
            self.model = smp.PSPNet(
                encoder_name=self.encoder,
                encoder_weights=None,
                encoder_depth=self.depth,
                in_channels=self.in_channels,
                classes=self.classes,
                activation=self.activation,
            )
            self.pad_unit = 2**self.depth
        elif self.architecture == "PAN":
            self.model = smp.PAN(
                encoder_name=self.encoder,
                encoder_weights=None,
                encoder_depth=self.depth,
                in_channels=self.in_channels,
                classes=self.classes,
                activation=self.activation,
            )
            self.pad_unit = 2**self.depth
        elif self.architecture == "DeepLabV3":
            self.model = smp.DeepLabV3(
                encoder_name=self.encoder,
                encoder_weights=None,
                encoder_depth=self.depth,
                in_channels=self.in_channels,
                classes=self.classes,
                activation=self.activation,
            )
            self.pad_unit = 2**self.depth
        elif self.architecture == "DeepLabV3Plus":
            self.model = smp.DeepLabV3Plus(
                encoder_name=self.encoder,
                encoder_weights=None,
                encoder_depth=self.depth,
                in_channels=self.in_channels,
                classes=self.classes,
                activation=self.activation,
            )
            self.pad_unit = 2**self.depth
Example #9
        classes=8,
        activation=args.activation,
    ),
    "deeplabv3plus":
    smp.DeepLabV3Plus(
        encoder_name=args.encoder,
        encoder_weights=args.weight,
        classes=8,
        activation=args.activation,
        aux_params=aux_params_dict,
    ),
    "pan":
    smp.PAN(
        encoder_name=args.encoder,
        encoder_weights=args.weight,
        classes=8,
        activation=args.activation,
        aux_params=aux_params_dict,
    ),
}


def save_checkpoint(state, filename):
    torch.save(state, filename)


def main():
    data_dir = Path(path_dict["train_data"])
    # images_dir = data_dir.joinpath("image")
    images_dir = Path(config["train"]["image_dir"])
    logger.info(f"Loading images from {images_dir}")
Example #10
def main():
    global args

    args = parse_args()
    print(args)

    torch.backends.cudnn.benchmark = True
    if args.deterministic:
        set_seed(args.seed)
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        torch.set_printoptions(precision=10)

    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1

    args.gpu = 0
    args.world_size = 1
    if args.distributed:
        args.gpu = args.local_rank
        torch.cuda.set_device(args.gpu)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()

    assert torch.backends.cudnn.enabled, 'Amp requires cudnn backend to be enabled.'

    # create model
    if args.cls:
        print('With classification')
    else:
        print('Without classification')

    # dpn92 is distributed with 'imagenet+5k' weights rather than plain 'imagenet'
    encoder_weights = 'imagenet+5k' if 'dpn92' in args.encoder else 'imagenet'

    if args.smodel == 'Unet':
        model = smp.Unet(encoder_name=args.encoder,
                         encoder_weights=encoder_weights,
                         classes=args.n_classes,
                         in_channels=args.in_channels,
                         decoder_attention_type=args.attention_type,
                         activation=None)
    elif args.smodel == 'FPN':
        model = smp.FPN(encoder_name=args.encoder,
                        encoder_weights=encoder_weights,
                        classes=args.n_classes,
                        in_channels=args.in_channels,
                        activation=None)
    elif args.smodel == 'PAN':
        model = smp.PAN(encoder_name=args.encoder,
                        encoder_weights=encoder_weights,
                        classes=args.n_classes,
                        in_channels=args.in_channels,
                        activation=None)
    elif args.smodel == 'PSPNet':
        model = smp.PSPNet(encoder_name=args.encoder,
                           encoder_weights=encoder_weights,
                           classes=args.n_classes,
                           in_channels=args.in_channels,
                           encoder_depth=3,
                           activation=None)
    else:
        raise ValueError(f'Unknown model: {args.smodel}')

    if args.sync_bn:
        print('using apex synced BN')
        model = apex.parallel.convert_syncbn_model(model)

    model.cuda()

    # Scale learning rate based on global batch size
    print(f'lr={args.lr}, opt={args.opt}')
    if args.opt == 'adam':
        opt = apex.optimizers.FusedAdam(
            model.parameters(
            ),  # add_weight_decay(model, args.weight_decay, ('bn', )),
            lr=args.lr,
            weight_decay=args.weight_decay,
        )
    elif args.opt == 'sgd':
        opt = torch.optim.SGD(
            add_weight_decay(model, args.weight_decay,
                             ('bn', )),  # model.parameters(),
            args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay)
    else:
        raise ValueError(f'Unknown optimizer: {args.opt}')

    # Initialize Amp. Amp accepts either values or strings for the optional override arguments,
    # for convenient interoperation with argparse.
    if args.fp16:
        model, opt = apex.amp.initialize(
            model,
            opt,
            opt_level=args.opt_level,
            keep_batchnorm_fp32=args.keep_batchnorm_fp32,
            loss_scale=args.loss_scale)

    # For distributed training, wrap the model with apex.parallel.DistributedDataParallel.
    # This must be done AFTER the call to amp.initialize.  If model = DDP(model) is called
    # before model, ... = amp.initialize(model, ...), the call to amp.initialize may alter
    # the types of model's parameters in a way that disrupts or destroys DDP's allreduce hooks.
    if args.distributed:
        # By default, apex.parallel.DistributedDataParallel overlaps communication with
        # computation in the backward pass.
        # model = DDP(model)
        # delay_allreduce delays all communication to the end of the backward pass.
        model = apex.parallel.DistributedDataParallel(model,
                                                      delay_allreduce=True)

    dice_loss = smp.utils.losses.DiceLoss(activation='sigmoid')
    bce_loss = nn.BCEWithLogitsLoss()
    if args.cls:

        def BCEBCE(logits, target):
            prediction_seg, prediction_cls = logits
            y_cls = (target.sum([2, 3]) > 0).float()

            return bce_loss(prediction_seg, target) + bce_loss(
                prediction_cls, y_cls)

        def symmetric_lovasz_fn(logits, target):
            prediction_seg, prediction_cls = logits
            y_cls = (target.sum([2, 3]) > 0).float()

            return symmetric_lovasz(prediction_seg, target) + bce_loss(
                prediction_cls, y_cls)
    else:
        if not args.use_softmax:

            def BCEBCE(logits, target):
                if args.loss == 'bce':
                    return bce_loss(logits, target)
                elif args.loss == 'dice':
                    return dice_loss(logits, target)
                elif args.loss == 'bce+dice':
                    return bce_loss(logits, target) + dice_loss(logits, target)

                return bce_loss(logits, target) + dice_loss(logits, target)

            symmetric_lovasz_fn = symmetric_lovasz
        else:

            def BCEBCE(logits, target):
                return nn.CrossEntropyLoss()(logits, target[:, 0].long())

            def symmetric_lovasz_fn(logits, target):
                return lovasz_softmax(torch.softmax(logits, dim=1),
                                      target[:, 0].long())

    criterion = BCEBCE

    history = {k: {k_: [] for k_ in ['train', 'dev']} for k in ['loss']}
    best_score = 0
    if not args.use_softmax:
        metrics = {
            'j55':
            JaccardMicro(n_classes=args.n_classes, thresh=0.55, w3m=args.w3m),
            'score':
            JaccardMicro(n_classes=args.n_classes, thresh=0.5, w3m=args.w3m),
            'j45':
            JaccardMicro(n_classes=args.n_classes, thresh=0.45, w3m=args.w3m),
            'd5':
            Dice(n_classes=args.n_classes, thresh=0.5, w3m=args.w3m),
        }
    else:
        metrics = {
            'score':
            JaccardMicro(n_classes=args.n_classes, thresh=None, w3m=args.w3m),
            'jaccard':
            Dice(n_classes=args.n_classes, thresh=None, w3m=args.w3m),
        }

    history.update({k: {v: [] for v in ['train', 'dev']} for k in metrics})

    base_name = f'{args.encoder}_b{args.batch_size}_{args.opt}_lr{args.lr}_w3m{int(args.w3m)}_f{args.fold}'
    work_dir = Path(args.work_dir) / base_name
    if args.local_rank == 0 and not work_dir.exists():
        work_dir.mkdir(parents=True)

    # Optionally load model from a checkpoint
    if args.load:

        def _load():
            path_to_load = Path(args.load)
            if path_to_load.is_file():
                print(f"=> loading model '{path_to_load}'")
                checkpoint = torch.load(
                    path_to_load,
                    map_location=lambda storage, loc: storage.cuda(args.gpu))
                model.load_state_dict(checkpoint['state_dict'])
                if args.fp16 and checkpoint['amp'] is not None:
                    apex.amp.load_state_dict(checkpoint['amp'])
                print(f"=> loaded model '{path_to_load}'")
            else:
                print(f"=> no model found at '{path_to_load}'")

        _load()

    # Optionally resume from a checkpoint
    if args.resume:
        # Use a local scope to avoid dangling references
        def _resume():
            nonlocal history, best_score
            path_to_resume = Path(args.resume)
            if path_to_resume.is_file():
                print(f"=> loading resume checkpoint '{path_to_resume}'")
                checkpoint = torch.load(
                    path_to_resume,
                    map_location=lambda storage, loc: storage.cuda(args.gpu))
                args.start_epoch = checkpoint['epoch'] + 1
                history = checkpoint['history']
                best_score = checkpoint['best_score']
                model.load_state_dict(checkpoint['state_dict'])
                opt.load_state_dict(checkpoint['opt_state_dict'])
                if args.fp16 and checkpoint['amp'] is not None:
                    apex.amp.load_state_dict(checkpoint['amp'])
                print(
                    f"=> resume from checkpoint '{path_to_resume}' (epoch {checkpoint['epoch']})"
                )
            else:
                print(f"=> no checkpoint found at '{args.resume}'")

        _resume()
    history.update({
        k: {v: []
            for v in ['train', 'dev']}
        for k in metrics if k not in history
    })

    scheduler = None
    if args.scheduler == 'cos':
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            opt,
            T_max=args.T_max,
            eta_min=max(args.lr * 1e-2, 1e-6),
            last_epoch=args.start_epoch if args.resume else -1)

    path_to_data = Path(args.data)
    train_gps, dev_gps = get_data_groups(path_to_data / args.csv, args)

    train_ds = CloudsDS(train_gps,
                        root=path_to_data,
                        transform=train_transform,
                        w3m=args.w3m,
                        use_softmax=args.use_softmax)
    dev_ds = CloudsDS(dev_gps,
                      root=path_to_data,
                      transform=dev_transform,
                      w3m=args.w3m,
                      use_softmax=args.use_softmax)

    train_sampler = None
    dev_sampler = None
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_ds)
        dev_sampler = torch.utils.data.distributed.DistributedSampler(dev_ds)

    batch_size = args.batch_size
    num_workers = args.workers
    train_loader = torch.utils.data.DataLoader(train_ds,
                                               batch_size=batch_size,
                                               shuffle=train_sampler is None,
                                               sampler=train_sampler,
                                               num_workers=num_workers,
                                               collate_fn=collate_fn,
                                               pin_memory=True)

    dev_loader = torch.utils.data.DataLoader(
        dev_ds,
        batch_size=batch_size,  # 20, 27
        shuffle=False,
        sampler=dev_sampler,
        num_workers=num_workers,
        collate_fn=collate_fn,
        pin_memory=True)

    def saver(path):
        torch.save(
            {
                'epoch': epoch,
                'best_score': best_score,
                'history': history,
                'state_dict': model.state_dict(),
                'opt_state_dict': opt.state_dict(),
                'amp': apex.amp.state_dict() if args.fp16 else None,
                'args': args,
            }, path)

    teachers = None
    if args.teachers is not None:
        teachers = [
            torch.jit.load(str(p)).cuda().eval()
            for p in Path(args.teachers).rglob('*.pt')
        ]

        if args.distributed:
            for i in range(len(teachers)):
                teachers[i] = apex.parallel.DistributedDataParallel(
                    teachers[i], delay_allreduce=True)

        print(f'#teachers: {len(teachers)}')

    for epoch in range(args.start_epoch, args.epochs + 1):
        if args.distributed:
            train_sampler.set_epoch(epoch)

        if epoch >= args.lovasz:
            criterion = symmetric_lovasz_fn

        for metric in metrics.values():
            metric.clean()
        loss = epoch_step(train_loader,
                          f'[ Training {epoch}/{args.epochs}.. ]',
                          model=model,
                          criterion=criterion,
                          metrics=metrics,
                          opt=opt,
                          batch_accum=args.batch_accum,
                          teachers=teachers)
        history['loss']['train'].append(loss)
        for k, metric in metrics.items():
            history[k]['train'].append(metric.evaluate())

        if not args.ft:
            with torch.no_grad():
                for metric in metrics.values():
                    metric.clean()
                loss = epoch_step(dev_loader,
                                  f'[ Validating {epoch}/{args.epochs}.. ]',
                                  model=model,
                                  criterion=criterion,
                                  metrics=metrics,
                                  opt=None)
                history['loss']['dev'].append(loss)
                for k, metric in metrics.items():
                    history[k]['dev'].append(metric.evaluate())
        else:
            history['loss']['dev'].append(loss)
            for k, metric in metrics.items():
                history[k]['dev'].append(metric.evaluate())

        if scheduler is not None:
            scheduler.step()

        if args.local_rank == 0:
            saver(work_dir / 'last.pth')
            if history['score']['dev'][-1] > best_score:
                best_score = history['score']['dev'][-1]
                saver(work_dir / 'best.pth')

            plot_hist(history, work_dir)
Example #11
        encoder_name=checkpoint["encoder"],
        encoder_weights=checkpoint["encoder_weight"],
        classes=8,
        activation=checkpoint["activation"],
    ),
    "deeplabv3plus":
    smp.DeepLabV3Plus(
        encoder_name=checkpoint["encoder"],
        encoder_weights=checkpoint["encoder_weight"],
        classes=8,
        activation=checkpoint["activation"],
    ),
    "pan":
    smp.PAN(
        encoder_name=checkpoint["encoder"],
        encoder_weights=checkpoint["encoder_weight"],
        classes=8,
        activation=checkpoint["activation"],
    ),
}

device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = arch_dict[checkpoint["arch"]]
# checkpoints saved from an nn.DataParallel model prefix every key with
# 'module.'; strip that prefix so the weights load into the unwrapped model
state_dict = OrderedDict()
for key, value in checkpoint["state_dict"].items():
    state_dict[key[len("module."):]] = value
model.load_state_dict(state_dict)
model.to(device)  # use the device chosen above instead of an unconditional .cuda()
model.eval()
preprocessing_fn = smp.encoders.get_preprocessing_fn(
    checkpoint["encoder"], checkpoint["encoder_weight"])
Example #12
    def __init__(self, architecture='Unet', encoder='resnet34', depth=5, in_channels=3, classes=2, activation='softmax'):
        super(SegmentationModels, self).__init__()
        self.architecture = architecture
        self.encoder = encoder
        self.depth = depth
        self.in_channels = in_channels
        self.classes = classes
        self.activation = activation

        # define model
        _ARCHITECTURES = ['Unet', 'Linknet', 'FPN', 'PSPNet', 'PAN', 'DeepLabV3', 'DeepLabV3Plus']
        assert self.architecture in _ARCHITECTURES, 'architecture must be one of {0}, got \'{1}\''.format(_ARCHITECTURES, self.architecture)
        # every branch of the original if/elif chain passed identical kwargs,
        # so dispatch on the class name (see Example #8 for the same pattern
        # and the smp.PAN encoder_depth caveat)
        model_class = getattr(smp, self.architecture)
        self.model = model_class(encoder_name=self.encoder,
                                 encoder_weights=None,
                                 encoder_depth=self.depth,
                                 in_channels=self.in_channels,
                                 classes=self.classes,
                                 activation=self.activation)
        # the encoder downsamples by a factor of 2 ** depth
        self.pad_unit = 2 ** self.depth