Example no. 1
def get_basenet(basenet,
                backbone,
                encoder_weights,
                classes,
                decoder_channels,
                activation='sigmoid'):
    if basenet == 'fpn':
        return smp.FPN(backbone,
                       encoder_weights=encoder_weights,
                       classes=classes,
                       activation=activation)
    elif basenet == 'psp':
        return smp.PSPNet(backbone,
                          encoder_weights=encoder_weights,
                          classes=classes,
                          activation=activation)
    elif basenet == 'deeplabv3':
        return smp.DeepLabV3(backbone,
                             encoder_weights=encoder_weights,
                             classes=classes,
                             activation=activation)

    return smp.Unet(backbone,
                    encoder_weights=encoder_weights,
                    encoder_depth=len(decoder_channels),
                    classes=classes,
                    decoder_channels=decoder_channels,
                    activation=activation)
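
A minimal usage sketch (argument values are illustrative; any basenet other than 'fpn', 'psp' or 'deeplabv3' falls through to the smp.Unet branch, so len(decoder_channels) must match the intended encoder depth):

import segmentation_models_pytorch as smp

# Hypothetical call: builds the U-Net fallback with a five-stage decoder
# (encoder_depth is derived from len(decoder_channels)).
model = get_basenet(
    basenet='unet',
    backbone='resnet34',
    encoder_weights='imagenet',
    classes=2,
    decoder_channels=(256, 128, 64, 32, 16),
)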
Example no. 2
    def __init__(
        self,
        num_points=100,
    ):
        super().__init__()
        self.num_points = num_points
        self.model = smp.DeepLabV3("efficientnet-b0", in_channels=3)

    def __init__(self, in_channels=2, classes=2, multi_output=False):
        super(reconstruction_deeplab, self).__init__()
        # se_resnet101
        # resnet101
        # arctan
        # self.model = smp.Unet('se_resnet101',in_channels=in_channels,classes=classes,activation='sigmoid',encoder_weights=None)
        # NOTE: 'arctan' is not one of the activation names built into
        # segmentation_models_pytorch, so this presumably relies on a
        # project-specific activation registered elsewhere.
        self.model = smp.DeepLabV3('se_resnet50',
                                   in_channels=in_channels,
                                   classes=classes,
                                   activation='arctan',
                                   encoder_weights=None)
Example no. 4
def build_model(configuration):
    model_list = ['UNet', 'LinkNet', 'PSPNet', 'FPN', 'PAN', 'Deeplab_v3', 'Deeplab_v3+']
    if configuration.Model.model_name.lower() == 'unet':
        return smp.Unet(
            encoder_name=configuration.Model.encoder,
            encoder_weights=configuration.Model.encoder_weights,
            activation=None,
            classes=configuration.DataSet.number_of_classes,
            decoder_attention_type=None,
        )
    if configuration.Model.model_name.lower() == 'linknet':
        return smp.Linknet(
            encoder_name=configuration.Model.encoder,
            encoder_weights=configuration.Model.encoder_weights,
            activation=None,
            classes=configuration.DataSet.number_of_classes
        )
    if configuration.Model.model_name.lower() == 'pspnet':
        return smp.PSPNet(
            encoder_name=configuration.Model.encoder,
            encoder_weights=configuration.Model.encoder_weights,
            activation=None,
            classes=configuration.DataSet.number_of_classes
        )
    if configuration.Model.model_name.lower() == 'fpn':
        return smp.FPN(
            encoder_name=configuration.Model.encoder,
            encoder_weights=configuration.Model.encoder_weights,
            activation=None,
            classes=configuration.DataSet.number_of_classes
        )
    if configuration.Model.model_name.lower() == 'pan':
        return smp.PAN(
            encoder_name=configuration.Model.encoder,
            encoder_weights=configuration.Model.encoder_weights,
            activation=None,
            classes=configuration.DataSet.number_of_classes
        )
    if configuration.Model.model_name.lower() == 'deeplab_v3+':
        return smp.DeepLabV3Plus(
            encoder_name=configuration.Model.encoder,
            encoder_weights=configuration.Model.encoder_weights,
            activation=None,
            classes=configuration.DataSet.number_of_classes
        )
    if configuration.Model.model_name.lower() == 'deeplab_v3':
        return smp.DeepLabV3(
            encoder_name=configuration.Model.encoder,
            encoder_weights=configuration.Model.encoder_weights,
            activation=None,
            classes=configuration.DataSet.number_of_classes
        )
    raise KeyError(f'Model should be one of {model_list}')
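
For illustration, a throwaway stand-in for the configuration object this function expects (the real config class is not shown here; attribute access on Model and DataSet is all that matters):

from types import SimpleNamespace

# Hypothetical config object with only the attributes build_model reads.
configuration = SimpleNamespace(
    Model=SimpleNamespace(model_name='deeplab_v3',
                          encoder='resnet50',
                          encoder_weights='imagenet'),
    DataSet=SimpleNamespace(number_of_classes=4),
)
model = build_model(configuration)   # returns an smp.DeepLabV3 instance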
Example no. 5
def create_smp_model(arch, **kwargs):
    'Create segmentation_models_pytorch model'

    assert arch in ARCHITECTURES, f'Select one of {ARCHITECTURES}'

    if arch == "Unet": model = smp.Unet(**kwargs)
    elif arch == "UnetPlusPlus": model = smp.UnetPlusPlus(**kwargs)
    elif arch == "MAnet": model = smp.MAnet(**kwargs)
    elif arch == "FPN": model = smp.FPN(**kwargs)
    elif arch == "PAN": model = smp.PAN(**kwargs)
    elif arch == "PSPNet": model = smp.PSPNet(**kwargs)
    elif arch == "Linknet": model = smp.Linknet(**kwargs)
    elif arch == "DeepLabV3": model = smp.DeepLabV3(**kwargs)
    elif arch == "DeepLabV3Plus": model = smp.DeepLabV3Plus(**kwargs)
    else: raise NotImplementedError

    setattr(model, 'kwargs', kwargs)
    return model
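
A usage sketch (assuming the module-level ARCHITECTURES constant contains the names checked above; all keyword arguments are forwarded verbatim to the chosen smp class):

# Illustrative only.
model = create_smp_model(
    'DeepLabV3',
    encoder_name='resnet34',
    encoder_weights='imagenet',
    in_channels=3,
    classes=2,
)
print(model.kwargs)   # kwargs are stored on the model, e.g. for re-creating it later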

    def __init__(self,
                 num_classes=12,
                 encoder="resnext101_32x8d",
                 pretrain_weight="imagenet",
                 decoder="DeepLabV3Plus"):
        super(smpModel, self).__init__()
        if decoder == "DeepLabV3Plus":
            self.backbone = smp.DeepLabV3Plus(encoder_name=encoder,
                                              encoder_weights=pretrain_weight,
                                              in_channels=3,
                                              classes=num_classes)
        elif decoder == "DeepLabV3":
            self.backbone = smp.DeepLabV3(encoder_name=encoder,
                                          encoder_weights=pretrain_weight,
                                          in_channels=3,
                                          classes=num_classes)
        elif decoder == "UnetPlusPlus":
            self.backbone = smp.UnetPlusPlus(encoder_name=encoder,
                                             encoder_weights=pretrain_weight,
                                             in_channels=3,
                                             classes=num_classes)
Example no. 7
def load_model(arch):
    if arch == 'fcn8s':
        VGG_model = VGGNet(requires_grad=True, remove_fc=True)
        net = FCN8s(pretrained_net=VGG_model, n_class=2)
    elif arch == 'unet_resnet34':
        net = smp.Unet('resnet34', encoder_weights=None, classes=2)
    elif arch == 'deeplab':
        net = smp.DeepLabV3('resnet50', encoder_weights='imagenet', classes=2)
    elif arch == 'unet_resnet50_pre':
        net = smp.Unet('resnet50', encoder_weights='imagenet', classes=2)
    elif arch == 'unet_resnet101_pre':
        net = smp.Unet('resnet101', encoder_weights='imagenet', classes=2)
    elif arch == 'unet_resnet50':
        net = smp.Unet('resnet50', encoder_weights=None, classes=2)
    elif arch == 'unet_resnet101':
        net = smp.Unet('resnet101', encoder_weights=None, classes=2)
    elif arch == 'unet_vgg_pre':
        net = smp.Unet('vgg16_bn', encoder_weights='imagenet', classes=2)
    elif arch == 'unet_vgg':
        net = smp.Unet('vgg16_bn', encoder_weights=None, classes=2)
    return net
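
A usage sketch; note that an unrecognized arch would leave net unbound, so the return statement would raise UnboundLocalError:

# Illustrative only.
net = load_model('unet_resnet34')   # smp.Unet with a randomly initialised resnet34 encoder
net_pre = load_model('deeplab')     # smp.DeepLabV3 with an ImageNet-pretrained resnet50 encoder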
Example no. 8
def get_model(config):
    """
    """
    arch = config.MODEL.ARCHITECTURE
    backbone = config.MODEL.BACKBONE
    encoder_weights = config.MODEL.ENCODER_PRETRAINED_FROM
    in_channels = config.MODEL.IN_CHANNELS
    n_classes = len(config.INPUT.CLASSES)
    activation = config.MODEL.ACTIVATION

    # unet specific
    decoder_attention_type = 'scse' if config.MODEL.UNET_ENABLE_DECODER_SCSE else None

    if arch == 'unet':
        model = smp.Unet(
            encoder_name=backbone,
            encoder_weights=encoder_weights,
            decoder_channels=config.MODEL.UNET_DECODER_CHANNELS,
            decoder_attention_type=decoder_attention_type,
            in_channels=in_channels,
            classes=n_classes,
            activation=activation
        )
    elif arch == 'fpn':
        model = smp.FPN(
            encoder_name=backbone,
            encoder_weights=encoder_weights,
            decoder_dropout=config.MODEL.FPN_DECODER_DROPOUT,
            in_channels=in_channels,
            classes=n_classes,
            activation=activation
        )
    elif arch == 'pan':
        model = smp.PAN(
            encoder_name=backbone,
            encoder_weights=encoder_weights,
            in_channels=in_channels,
            classes=n_classes,
            activation=activation
        )
    elif arch == 'pspnet':
        model = smp.PSPNet(
            encoder_name=backbone,
            encoder_weights=encoder_weights,
            psp_dropout=config.MODEL.PSPNET_DROPOUT,
            in_channels=in_channels,
            classes=n_classes,
            activation=activation
        )
    elif arch == 'deeplabv3':
        model = smp.DeepLabV3(
            encoder_name=backbone,
            encoder_weights=encoder_weights,
            in_channels=in_channels,
            classes=n_classes,
            activation=activation
        )
    elif arch == 'linknet':
        model = smp.Linknet(
            encoder_name=backbone,
            encoder_weights=encoder_weights,
            in_channels=in_channels,
            classes=n_classes,
            activation=activation
        )
    else:
        raise ValueError(f'Unsupported architecture: {arch}')

    model = torch.nn.DataParallel(model)

    if config.MODEL.WEIGHT and config.MODEL.WEIGHT != 'none':
        # load weight from file
        model.load_state_dict(
            torch.load(
                config.MODEL.WEIGHT,
                map_location=torch.device('cpu')
            )
        )

    model = model.to(config.MODEL.DEVICE)
    return model
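
A rough sketch of the configuration this function reads (the original project appears to use a yacs-style config; SimpleNamespace stands in here purely for illustration, and only the fields touched by the 'unet' branch are filled in):

from types import SimpleNamespace

config = SimpleNamespace(
    MODEL=SimpleNamespace(
        ARCHITECTURE='unet',
        BACKBONE='resnet34',
        ENCODER_PRETRAINED_FROM='imagenet',
        IN_CHANNELS=3,
        ACTIVATION=None,
        UNET_ENABLE_DECODER_SCSE=False,
        UNET_DECODER_CHANNELS=(256, 128, 64, 32, 16),
        WEIGHT='none',      # skip checkpoint loading
        DEVICE='cpu',
    ),
    INPUT=SimpleNamespace(CLASSES=['background', 'building']),
)
model = get_model(config)   # DataParallel-wrapped smp.Unet on the CPU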
Example no. 9
    def __init__(self, architecture='Unet', encoder='resnet34', depth=5, in_channels=3, classes=2, activation='softmax'):
        super(SegmentationModels, self).__init__()
        self.architecture = architecture
        self.encoder = encoder
        self.depth = depth
        self.in_channels = in_channels
        self.classes = classes
        self.activation = activation

        # define model
        _ARCHITECTURES = ['Unet', 'Linknet', 'FPN', 'PSPNet', 'PAN', 'DeepLabV3', 'DeepLabV3Plus']
        assert self.architecture in _ARCHITECTURES, "architecture must be one of {0}, got '{1}'".format(_ARCHITECTURES, self.architecture)
        if self.architecture == 'Unet':
            self.model = smp.Unet(encoder_name=self.encoder,
                                  encoder_weights=None,
                                  encoder_depth=self.depth,
                                  in_channels=self.in_channels,
                                  classes=self.classes,
                                  activation=self.activation)
            self.pad_unit = 2 ** self.depth
        elif self.architecture == 'Linknet':
            self.model = smp.Linknet(encoder_name=self.encoder,
                                     encoder_weights=None,
                                     encoder_depth=self.depth,
                                     in_channels=self.in_channels,
                                     classes=self.classes,
                                     activation=self.activation)
            self.pad_unit = 2 ** self.depth
        elif self.architecture == 'FPN':
            self.model = smp.FPN(encoder_name=self.encoder,
                                 encoder_weights=None,
                                 encoder_depth=self.depth,
                                 in_channels=self.in_channels,
                                 classes=self.classes,
                                 activation=self.activation)
            self.pad_unit = 2 ** self.depth
        elif self.architecture == 'PSPNet':
            self.model = smp.PSPNet(encoder_name=self.encoder,
                                    encoder_weights=None,
                                    encoder_depth=self.depth,
                                    in_channels=self.in_channels,
                                    classes=self.classes,
                                    activation=self.activation)
            self.pad_unit = 2 ** self.depth
        elif self.architecture == 'PAN':
            self.model = smp.PAN(encoder_name=self.encoder,
                                 encoder_weights=None,
                                 encoder_depth=self.depth,
                                 in_channels=self.in_channels,
                                 classes=self.classes,
                                 activation=self.activation)
            self.pad_unit = 2 ** self.depth
        elif self.architecture == 'DeepLabV3':
            self.model = smp.DeepLabV3(encoder_name=self.encoder,
                                       encoder_weights=None,
                                       encoder_depth=self.depth,
                                       in_channels=self.in_channels,
                                       classes=self.classes,
                                       activation=self.activation)
            self.pad_unit = 2 ** self.depth
        elif self.architecture == 'DeepLabV3Plus':
            self.model = smp.DeepLabV3Plus(encoder_name=self.encoder,
                                           encoder_weights=None,
                                           encoder_depth=self.depth,
                                           in_channels=self.in_channels,
                                           classes=self.classes,
                                           activation=self.activation)
            self.pad_unit = 2 ** self.depth
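
The pad_unit attribute records the spatial factor the input should be divisible by (the encoder halves the resolution at each of its depth stages). A minimal sketch of how such a value might be used before the forward pass; pad_to_unit is a hypothetical helper, not part of the class above:

import torch
import torch.nn.functional as F

def pad_to_unit(x, pad_unit):
    # zero-pad H and W up to the next multiple of pad_unit
    h, w = x.shape[-2:]
    pad_h = (pad_unit - h % pad_unit) % pad_unit
    pad_w = (pad_unit - w % pad_unit) % pad_unit
    return F.pad(x, (0, pad_w, 0, pad_h))   # (left, right, top, bottom)

x = torch.randn(1, 3, 300, 500)
x = pad_to_unit(x, 2 ** 5)   # depth=5 -> pad to multiples of 32
print(x.shape)               # torch.Size([1, 3, 320, 512])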
Example no. 10
    def __init__(self, encoder, num_classes):
        super(DeepLabV3, self).__init__()
        self.model = smp.DeepLabV3(encoder, classes=num_classes if num_classes != 2 else 1, in_channels=4)
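
The classes=num_classes if num_classes != 2 else 1 expression collapses the binary case to a single logit channel (typically paired with a sigmoid / BCE-with-logits loss). The equivalent direct constructions, with illustrative values:

import segmentation_models_pytorch as smp

binary_model = smp.DeepLabV3('resnet34', classes=1, in_channels=4)   # binary: one channel
multi_model = smp.DeepLabV3('resnet34', classes=5, in_channels=4)    # multi-class: one channel per class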
Example no. 11
    def __init__(self,
                 architecture="Unet",
                 encoder="resnet34",
                 depth=5,
                 in_channels=3,
                 classes=2,
                 activation="softmax"):
        super(SegmentationModels, self).__init__()
        self.architecture = architecture
        self.encoder = encoder
        self.depth = depth
        self.in_channels = in_channels
        self.classes = classes
        self.activation = activation

        # define model

        _ARCHITECTURES = [
            "Unet", "UnetPlusPlus", "Linknet", "MAnet", "FPN", "PSPNet", "PAN",
            "DeepLabV3", "DeepLabV3Plus"
        ]
        assert self.architecture in _ARCHITECTURES, "architecture must be one of {0}, got '{1}'".format(
            _ARCHITECTURES, self.architecture)

        if self.architecture == "Unet":
            self.model = smp.Unet(
                encoder_name=self.encoder,
                encoder_weights=None,
                encoder_depth=self.depth,
                in_channels=self.in_channels,
                classes=self.classes,
                activation=self.activation,
            )
            self.pad_unit = 2**self.depth
        elif self.architecture == "UnetPlusPlus":
            self.model = smp.UnetPlusPlus(
                encoder_name=self.encoder,
                encoder_weights=None,
                encoder_depth=self.depth,
                in_channels=self.in_channels,
                classes=self.classes,
                activation=self.activation,
            )
            self.pad_unit = 2**self.depth
        elif self.architecture == "MAnet":
            self.model = smp.MAnet(
                encoder_name=self.encoder,
                encoder_weights=None,
                encoder_depth=self.depth,
                in_channels=self.in_channels,
                classes=self.classes,
                activation=self.activation,
            )
            self.pad_unit = 2**self.depth
        elif self.architecture == "Linknet":
            self.model = smp.Linknet(
                encoder_name=self.encoder,
                encoder_weights=None,
                encoder_depth=self.depth,
                in_channels=self.in_channels,
                classes=self.classes,
                activation=self.activation,
            )
            self.pad_unit = 2**self.depth
        elif self.architecture == "FPN":
            self.model = smp.FPN(
                encoder_name=self.encoder,
                encoder_weights=None,
                encoder_depth=self.depth,
                in_channels=self.in_channels,
                classes=self.classes,
                activation=self.activation,
            )
            self.pad_unit = 2**self.depth
        elif self.architecture == "PSPNet":
            self.model = smp.PSPNet(
                encoder_name=self.encoder,
                encoder_weights=None,
                encoder_depth=self.depth,
                in_channels=self.in_channels,
                classes=self.classes,
                activation=self.activation,
            )
            self.pad_unit = 2**self.depth
        elif self.architecture == "PAN":
            self.model = smp.PAN(
                encoder_name=self.encoder,
                encoder_weights=None,
                encoder_depth=self.depth,
                in_channels=self.in_channels,
                classes=self.classes,
                activation=self.activation,
            )
            self.pad_unit = 2**self.depth
        elif self.architecture == "DeepLabV3":
            self.model = smp.DeepLabV3(
                encoder_name=self.encoder,
                encoder_weights=None,
                encoder_depth=self.depth,
                in_channels=self.in_channels,
                classes=self.classes,
                activation=self.activation,
            )
            self.pad_unit = 2**self.depth
        elif self.architecture == "DeepLabV3Plus":
            self.model = smp.DeepLabV3Plus(
                encoder_name=self.encoder,
                encoder_weights=None,
                encoder_depth=self.depth,
                in_channels=self.in_channels,
                classes=self.classes,
                activation=self.activation,
            )
            self.pad_unit = 2**self.depth
Example no. 12
def smpdeeplab(config):
    return smp.DeepLabV3(config.encoder, encoder_weights='imagenet', in_channels=config.in_channels, classes=config.num_classes)
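
A usage sketch with a throwaway config; any object exposing encoder, in_channels and num_classes attributes would do:

from types import SimpleNamespace

config = SimpleNamespace(encoder='resnet34', in_channels=3, num_classes=2)
model = smpdeeplab(config)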
Example no. 13
     encoder_name=args.encoder,
     encoder_weights=args.weight,
     classes=8,
     activation=args.activation,
 ),
 "pspnet":
 smp.PSPNet(
     encoder_name=args.encoder,
     encoder_weights=args.weight,
     classes=8,
     activation=args.activation,
 ),
 "deeplabv3":
 smp.DeepLabV3(
     encoder_name=args.encoder,
     encoder_weights=args.weight,
     classes=8,
     activation=args.activation,
 ),
 "deeplabv3plus":
 smp.DeepLabV3Plus(
     encoder_name=args.encoder,
     encoder_weights=args.weight,
     classes=8,
     activation=args.activation,
     aux_params=aux_params_dict,
 ),
 "pan":
 smp.PAN(
     encoder_name=args.encoder,
     encoder_weights=args.weight,
     classes=8,
Example no. 14
     encoder_name=checkpoint["encoder"],
     encoder_weights=checkpoint["encoder_weight"],
     classes=8,
     activation=checkpoint["activation"],
 ),
 "pspnet":
 smp.PSPNet(
     encoder_name=checkpoint["encoder"],
     encoder_weights=checkpoint["encoder_weight"],
     classes=8,
     activation=checkpoint["activation"],
 ),
 "deeplabv3":
 smp.DeepLabV3(
     encoder_name=checkpoint["encoder"],
     encoder_weights=checkpoint["encoder_weight"],
     classes=8,
     activation=checkpoint["activation"],
 ),
 "deeplabv3plus":
 smp.DeepLabV3Plus(
     encoder_name=checkpoint["encoder"],
     encoder_weights=checkpoint["encoder_weight"],
     classes=8,
     activation=checkpoint["activation"],
 ),
 "pan":
 smp.PAN(
     encoder_name=checkpoint["encoder"],
     encoder_weights=checkpoint["encoder_weight"],
     classes=8,
     activation=checkpoint["activation"],

    def __init__(self, in_channels=3, classes=12):
        super(ResNextDeepLabV3AllTrain, self).__init__()
        self.backbone = smp.DeepLabV3(encoder_name="resnext101_32x8d",
                                      encoder_weights="imagenet",
                                      in_channels=in_channels,
                                      classes=classes)
Example no. 16
def main():
    args = parse_args()

    torch.backends.cudnn.benchmark = True

    args.distributed = False

    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1

    args.world_size = 1

    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()

    # print(args.world_size, args.local_rank, args.distributed)

    cfg.merge_from_file(args.cfg)

    cfg.DIR = os.path.join(
        cfg.DIR,
        args.cfg.split('/')[-1].split('.')[0] +
        datetime.now().strftime('-%Y-%m-%d-%a-%H:%M:%S:%f'))

    # Output directory
    # if not os.path.isdir(cfg.DIR):
    if args.local_rank == 0:
        os.makedirs(cfg.DIR, exist_ok=True)
        os.makedirs(os.path.join(cfg.DIR, 'weight'), exist_ok=True)
        os.makedirs(os.path.join(cfg.DIR, 'history'), exist_ok=True)
        shutil.copy(args.cfg, cfg.DIR)

    if os.path.exists(os.path.join(cfg.DIR, 'log.txt')):
        os.remove(os.path.join(cfg.DIR, 'log.txt'))
    logger = setup_logger(distributed_rank=args.local_rank,
                          filename=os.path.join(cfg.DIR, 'log.txt'))
    logger.info("Loaded configuration file {}".format(args.cfg))
    logger.info("Running with config:\n{}".format(cfg))

    if cfg.MODEL.arch == 'deeplab':
        model = DeepLab(
            num_classes=cfg.DATASET.num_class,
            backbone=cfg.MODEL.backbone,  # resnet101
            output_stride=cfg.MODEL.os,
            ibn_mode=cfg.MODEL.ibn_mode,
            freeze_bn=False,
            num_low_level_feat=cfg.MODEL.num_low_level_feat)
    elif cfg.MODEL.arch == 'smp-deeplab':
        model = smp.DeepLabV3(encoder_name='resnet101', classes=7)
    elif cfg.MODEL.arch == 'FPN':
        model = smp.FPN(encoder_name='resnet101', classes=7)
    elif cfg.MODEL.arch == 'Unet':
        model = smp.Unet(encoder_name='resnet101', classes=7)

    convert_model(model, 4)
    from pytorch_model_summary import summary
    print(summary(model, torch.zeros((1, 4, 512, 512)), show_input=True))
    # NOTE: this early return (apparently left in from a model-summary check)
    # makes everything below in this function unreachable.
    return
    model = apex.parallel.convert_syncbn_model(model)
    model = model.cuda()

    model = amp.initialize(model, opt_level="O1")

    if args.distributed:
        model = DDP(model, delay_allreduce=True)

    if cfg.TEST.checkpoint != "":
        if args.local_rank == 0:
            logger.info("Loading weight from {}".format(cfg.TEST.checkpoint))

        weight = torch.load(
            cfg.TEST.checkpoint,
            map_location=lambda storage, loc: storage.cuda(args.local_rank))

        if not args.distributed:
            # strip the 'module.' prefix that (Distributed)DataParallel adds to parameter names
            weight = {k[7:]: v for k, v in weight.items()}

        model.load_state_dict(weight)

    dataset_test = AgriTestDataset(cfg.DATASET.root_dataset,
                                   cfg.DATASET.list_test, cfg.DATASET)

    test_sampler = None

    if args.distributed:
        test_sampler = torch.utils.data.distributed.DistributedSampler(
            dataset_test, num_replicas=args.world_size, rank=args.local_rank)

    loader_test = torch.utils.data.DataLoader(
        dataset_test,
        batch_size=cfg.TEST.batch_size_per_gpu,
        shuffle=False,  # we do not use this param
        drop_last=False,
        pin_memory=True,
        sampler=test_sampler)

    cfg.TEST.epoch_iters = len(loader_test)

    logger.info("World Size: {}".format(args.world_size))
    logger.info("TEST.epoch_iters: {}".format(cfg.TEST.epoch_iters))
    logger.info("TEST.sum_bs: {}".format(cfg.TEST.batch_size_per_gpu *
                                         args.world_size))

    test(loader_test, model, args, logger)
Example no. 17
def get_segmentation_model(
    arch: str,
    encoder_name: str,
    encoder_weights: Optional[str] = "imagenet",
    pretrained_checkpoint_path: Optional[str] = None,
    checkpoint_path: Optional[Union[str, List[str]]] = None,
    convert_bn: Optional[str] = None,
    convert_bottleneck: Tuple[int, int, int] = (0, 0, 0),
    **kwargs: Any,
) -> nn.Module:
    """
    Fetch segmentation model by its name
    :param arch:
    :param encoder_name:
    :param encoder_weights:
    :param checkpoint_path:
    :param pretrained_checkpoint_path:
    :param convert_bn:
    :param convert_bottleneck:
    :param kwargs:
    :return:
    """

    arch = arch.lower()
    if (encoder_name == "en_resnet34" or checkpoint_path is not None
            or pretrained_checkpoint_path is not None):
        encoder_weights = None

    if arch == "unet":
        model = smp.Unet(encoder_name=encoder_name,
                         encoder_weights=encoder_weights,
                         **kwargs)
    elif arch == "unetplusplus" or arch == "unet++":
        model = smp.UnetPlusPlus(encoder_name=encoder_name,
                                 encoder_weights=encoder_weights,
                                 **kwargs)
    elif arch == "linknet":
        model = smp.Linknet(encoder_name=encoder_name,
                            encoder_weights=encoder_weights,
                            **kwargs)
    elif arch == "pspnet":
        model = smp.PSPNet(encoder_name=encoder_name,
                           encoder_weights=encoder_weights,
                           **kwargs)
    elif arch == "pan":
        model = smp.PAN(encoder_name=encoder_name,
                        encoder_weights=encoder_weights,
                        **kwargs)
    elif arch == "deeplabv3":
        model = smp.DeepLabV3(encoder_name=encoder_name,
                              encoder_weights=encoder_weights,
                              **kwargs)
    elif arch == "deeplabv3plus" or arch == "deeplabv3+":
        model = smp.DeepLabV3Plus(encoder_name=encoder_name,
                                  encoder_weights=encoder_weights,
                                  **kwargs)
    elif arch == "manet":
        model = smp.MAnet(encoder_name=encoder_name,
                          encoder_weights=encoder_weights,
                          **kwargs)
    else:
        raise ValueError(f"Unknown architecture: {arch}")

    if pretrained_checkpoint_path is not None:
        print(f"Loading pretrained checkpoint {pretrained_checkpoint_path}")
        state_dict = torch.load(pretrained_checkpoint_path,
                                map_location=torch.device("cpu"))
        model.encoder.load_state_dict(state_dict)
        del state_dict

    # TODO fmap_size=16 hardcoded for input 256 (matters for positional encoding)
    botnet.convert_resnet(
        model.encoder,
        replacement=convert_bottleneck,
        fmap_size=16,
        position_encoding=None,
    )

    # TODO parametrize conversion
    print(f"Convert BN to {convert_bn}")
    if convert_bn == "instance":
        print("Converting BatchNorm2d to InstanceNorm2d")
        model = batch_norm2instance(model)
    elif convert_bn == "group":
        print("Converting BatchNorm2d to GroupNorm")
        model = batch_norm2group(model, channels_per_group=1)
    elif convert_bn == "bnet":
        print("Converting BatchNorm2d to BNet2d")
        model = batch_norm2bnet(model)
    elif convert_bn == "gnet":
        print("Converting BatchNorm2d to GNet2d")
        model = batch_norm2gnet(model, channels_per_group=1)
    elif not convert_bn:
        print("Do not convert BatchNorm2d")
    else:
        raise ValueError(f"Unknown convert_bn value: {convert_bn}")

    if checkpoint_path is not None:
        if not isinstance(checkpoint_path, list):
            checkpoint_path = [checkpoint_path]
        states = []
        for cp in checkpoint_path:
            # Load checkpoint
            print(f"\nLoading checkpoint {str(cp)}")
            state_dict = torch.load(
                cp, map_location=torch.device("cpu"))["model_state_dict"]
            states.append(state_dict)
        state_dict = average_weights(states)
        model.load_state_dict(state_dict)
        del state_dict

    return model
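
A usage sketch; note that botnet.convert_resnet and the batch_norm2* helpers are project-specific utilities, so this only runs inside that codebase. Extra keyword arguments are forwarded to the smp constructor:

# Illustrative only: plain U-Net, ImageNet encoder weights, no checkpoints;
# convert_bn=None and the default convert_bottleneck=(0, 0, 0) leave the network unchanged.
model = get_segmentation_model(
    arch='unet',
    encoder_name='resnet34',
    encoder_weights='imagenet',
    classes=1,
    in_channels=3,
)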
Example no. 18
def main():
    args = parse_args()

    torch.backends.cudnn.benchmark = True

    args.distributed = False

    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1

    args.world_size = 1

    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()

    # print(args.world_size, args.local_rank, args.distributed)

    cfg.merge_from_file(args.cfg)

    # NOTE: rstrip('.yaml') strips any trailing characters drawn from the set
    # {'.', 'y', 'a', 'm', 'l'}, not the literal '.yaml' suffix, so some config
    # file names get over-trimmed here.
    cfg.DIR = os.path.join(cfg.DIR,
                           args.cfg.split('/')[-1].rstrip('.yaml') +
                           datetime.now().strftime('-%Y-%m-%d-%a-%H:%M:%S:%f'))

    # Output directory
    # if not os.path.isdir(cfg.DIR):
    if args.local_rank == 0:
        os.makedirs(cfg.DIR, exist_ok=True)
        os.makedirs(os.path.join(cfg.DIR, 'weight'), exist_ok=True)
        os.makedirs(os.path.join(cfg.DIR, 'history'), exist_ok=True)
        shutil.copy(args.cfg, cfg.DIR)

    if os.path.exists(os.path.join(cfg.DIR, 'log.txt')):
        os.remove(os.path.join(cfg.DIR, 'log.txt'))
    logger = setup_logger(distributed_rank=args.local_rank,
                          filename=os.path.join(cfg.DIR, 'log.txt'))
    logger.info("Loaded configuration file {}".format(args.cfg))
    logger.info("Running with config:\n{}".format(cfg))


    if cfg.MODEL.arch == 'deeplab':
        model = DeepLab(num_classes=cfg.DATASET.num_class,
                        backbone=cfg.MODEL.backbone,                  # resnet101
                        output_stride=cfg.MODEL.os,
                        ibn_mode=cfg.MODEL.ibn_mode,
                        freeze_bn=False,
                        num_low_level_feat=cfg.MODEL.num_low_level_feat,
                        interpolate_before_lastconv=cfg.MODEL.interpolate_before_lastconv)
    elif cfg.MODEL.arch == 'smp-deeplab':
        model = smp.DeepLabV3(encoder_name='resnet101', classes=7)
    elif cfg.MODEL.arch == 'FPN':
        model = smp.FPN(encoder_name='resnet101',classes=7)
    elif cfg.MODEL.arch == 'Unet':
        model = smp.Unet(encoder_name='resnet101',classes=7)
    elif cfg.MODEL.arch == 'Dinknet':
        if cfg.MODEL.backbone == 'resnet34':
            assert cfg.MODEL.ibn_mode == 'none'
            model = DinkNet34(num_classes=7)
            # weight = torch.load('pretrained/dinknet34.pth')
            # weight['finalconv3.weight'] = torch.Tensor(model.finalconv3.weight)
            # weight['finalconv3.bias'] = torch.Tensor(model.finalconv3.bias)
            # # weight.pop('finalconv3.weight')
            # # weight.pop('finalconv3.bias')
            # model.load_state_dict(weight)
        elif cfg.MODEL.backbone == 'resnet50':
            model = DinkNet50(num_classes=7, ibn_mode=cfg.MODEL.ibn_mode)
        elif cfg.MODEL.backbone == 'resnet101':
            model = DinkNet101(num_classes=7, ibn_mode=cfg.MODEL.ibn_mode)

    if cfg.DATASET.train_channels in ['rgbn', 'rgbr']:
        convert_model(model, 4)

    model = apex.parallel.convert_syncbn_model(model)
    model = model.cuda()

    loss_fn = ComposedLossWithLogits(dict(cfg.LOSS)).cuda()

    assert cfg.TRAIN.optim in ['SGD', 'Adam']

    if cfg.TRAIN.optim == 'SGD':
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=cfg.TRAIN.lr,
                                    weight_decay=cfg.TRAIN.weight_decay,
                                    momentum=cfg.TRAIN.beta1)
    else:
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=cfg.TRAIN.lr)

    if cfg.MODEL.fp16:
        model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

    if args.distributed:
        model = DDP(model, delay_allreduce=True)

    if cfg.TRAIN.resume_checkpoint != "":
        if args.local_rank == 0:
            logger.info("Loading weight from {}".format(
                cfg.TRAIN.resume_checkpoint))

        weight = torch.load(cfg.TRAIN.resume_checkpoint,
                            map_location=lambda storage, loc: storage.cuda(args.local_rank))
        model.load_state_dict(weight)

    dataset_train = AgriTrainDataset(
        cfg.DATASET.root_dataset,
        cfg.DATASET.list_train,
        cfg.DATASET,
        channels=cfg.DATASET.train_channels)

    dataset_mixup = None

    if cfg.TRAIN.mixup_alpha > 0:
        dataset_mixup = AgriTrainDataset(
            cfg.DATASET.root_dataset,
            cfg.DATASET.list_train,
            cfg.DATASET,
            channels=cfg.DATASET.train_channels,
            reverse=True)

    dataset_vals = []

    for channels in cfg.DATASET.val_channels:
        dataset_vals.append(AgriValDataset(
            cfg.DATASET.root_dataset,
            cfg.DATASET.list_val,
            cfg.DATASET,
            channels=channels))

    # train_sampler, val_sampler = None, None

    # if args.distributed:
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        dataset_train,
        num_replicas=args.world_size,
        rank=args.local_rank
    )

    loader_train = torch.utils.data.DataLoader(
        dataset_train,
        batch_size=cfg.TRAIN.batch_size_per_gpu,
        shuffle=False,  # we do not use this param
        num_workers=cfg.TRAIN.workers,
        drop_last=True,
        pin_memory=True,
        sampler=train_sampler
    )

    loader_mixup = None

    if cfg.TRAIN.mixup_alpha > 0:
        mixup_sampler = torch.utils.data.distributed.DistributedSampler(
            dataset_mixup,
            num_replicas=args.world_size,
            rank=args.local_rank
        )

        loader_mixup = torch.utils.data.DataLoader(
            dataset_mixup,
            batch_size=cfg.TRAIN.batch_size_per_gpu,
            shuffle=False,  # we do not use this param
            num_workers=cfg.TRAIN.workers,
            drop_last=True,
            pin_memory=True,
            # NOTE: mixup_sampler is created above but not used here; passing
            # train_sampler instead looks unintentional.
            sampler=train_sampler
        )

    loader_vals = []

    for dataset_val in dataset_vals:
        val_sampler = torch.utils.data.distributed.DistributedSampler(
            dataset_val,
            num_replicas=args.world_size,
            rank=args.local_rank
        )

        loader_vals.append(torch.utils.data.DataLoader(
            dataset_val,
            batch_size=cfg.VAL.batch_size_per_gpu,
            shuffle=False,  # we do not use this param
            num_workers=cfg.VAL.batch_size_per_gpu,
            drop_last=True,
            pin_memory=True,
            sampler=val_sampler
        ))

    cfg.TRAIN.epoch_iters = len(loader_train)
    cfg.VAL.epoch_iters = len(loader_vals[0])

    cfg.TRAIN.running_lr = cfg.TRAIN.lr
    # if cfg.TRAIN.lr_pow > 0:

    cfg.TRAIN.num_epoch = (cfg.TRAIN.iter_warmup + cfg.TRAIN.iter_static + cfg.TRAIN.iter_decay) \
                          // cfg.TRAIN.epoch_iters

    cfg.TRAIN.log_fmt = 'TRAIN >> Epoch: [{}][{}/{}], Time: {:.2f}, Data: {:.2f}, ' \
                        'lr: {:.6f}, Loss: {:.6f}'

    cfg.VAL.log_fmt = 'Mean IoU: {:.4f}\nMean Loss: {:.6f}'

    for name in loss_fn.names:
        cfg.TRAIN.log_fmt += ', {}_Loss: '.format(name) + '{:.6f}'
        cfg.VAL.log_fmt += '\nMean {} Loss: '.format(name) + '{:.6f}'

    # print(cfg.TRAIN.log_fmt)
    # print(cfg.VAL.log_fmt)

    logger.info("World Size: {}".format(args.world_size))
    logger.info("TRAIN.epoch_iters: {}".format(cfg.TRAIN.epoch_iters))
    logger.info("TRAIN.sum_bs: {}".format(cfg.TRAIN.batch_size_per_gpu *
                                          args.world_size))

    logger.info("VAL.epoch_iters: {}".format(cfg.VAL.epoch_iters))
    logger.info("VAL.sum_bs: {}".format(cfg.VAL.batch_size_per_gpu *
                                        args.world_size))

    logger.info("TRAIN.num_epoch: {}".format(cfg.TRAIN.num_epoch))

    history = init_history(cfg)

    for i in range(cfg.TRAIN.start_epoch,
                   cfg.TRAIN.start_epoch + cfg.TRAIN.num_epoch):
        # print(i, args.local_rank)
        train(i + 1, loader_train, loader_mixup, model, loss_fn, optimizer,
              history, args, logger)

        for loader_val in loader_vals:
            val(i + 1, loader_val, model, loss_fn,
                history, args, logger)

        if args.local_rank == 0:    
            checkpoint(model, history, cfg, i + 1, args, logger)
Example no. 19
def main():
    my_parser = argparse.ArgumentParser(description='Convert pytorch model to ONNX format.')
    
    my_parser.add_argument('Model',
                       metavar='model',
                       type=str,
                       help='The model to be converted.')
    
    my_parser.add_argument('Encoder',
                       metavar='encoder',
                       type=str,
                       help='The encoder for the model.')
    
    my_parser.add_argument('Encoder_Weights',
                       metavar='encoder_weights',
                       type=str,
                       help='The encoder weights for the model.')
    
    my_parser.add_argument('Classes',
                       metavar='classes',
                       type=str,
                       help='The number of classes the model can predict.')
        
    my_parser.add_argument('Model_Path',
                       metavar='model_path',
                       type=str,
                       help='The path to the saved pytorch model.')
      
    
    my_parser.add_argument('Output_Path',
                       metavar='output_path',
                       type=str,
                       help='The path where the onnx model will be saved.')
    
    args = my_parser.parse_args()
    
    model_name = args.Model
    ENCODER = args.Encoder
    ENCODER_WEIGHTS = args.Encoder_Weights
    CLASSES = int(args.Classes)
    ckpt_path = args.Model_Path
    output_path = args.Output_Path
    print(f"Model name :{model_name}")
    print(f"Model checkpoint path is :{ckpt_path}")
    print(f"Model output path is :{output_path}")
    
    ACTIVATION = None
  
    
    if model_name == 'Unet':
        # create segmentation model with pretrained encoder
        model = smp.Unet(
            encoder_name=ENCODER, 
            encoder_weights=ENCODER_WEIGHTS, 
            in_channels=3,
            classes=CLASSES, 
            activation=ACTIVATION,
            )
    elif model_name == 'FPN':
        # create segmentation model with pretrained encoder
        model = smp.FPN(
            encoder_name=ENCODER, 
            encoder_weights=ENCODER_WEIGHTS, 
            in_channels=3,
            classes=CLASSES, 
            activation=ACTIVATION,
            )
        
    elif model_name == 'DeepLab_v3':
        # create segmentation model with pretrained encoder
        model = smp.DeepLabV3(
            encoder_name=ENCODER, 
            encoder_weights=ENCODER_WEIGHTS, 
            in_channels=3,
            classes=CLASSES, 
            activation=ACTIVATION,
            )
    else:
        print("Unknown model. Please try other model.")
        exit()
        
    state = torch.load(ckpt_path, map_location=lambda storage, loc: storage)
    new_state_dict = OrderedDict()
    for k, v in state["state_dict"].items():
        name = k[7:]  # strip the leading 'module.' prefix added by DataParallel
        new_state_dict[name] = v
    model.load_state_dict(new_state_dict)
    model.eval()

    input_var = torch.rand(1, 3, 128, 128)  # Use half of the original resolution.
    batch_size = 5
    # Export the model
    torch.onnx.export(model,                 # model being run
                  input_var,                 # model input (or a tuple for multiple inputs)
                  output_path, # where to save the model (can be a file or file-like object)
                  export_params=True,        # store the trained parameter weights inside the model file
                  opset_version=11,          # the ONNX version to export the model to
                  do_constant_folding=True,  # whether to execute constant folding for optimization
                  input_names = ['input'],   # the model's input names
                  output_names = ['output']) # the model's output names
    
    print("Successfully converted the model to onnx format.")