def __call__(self):
        """Build the HRNet-family backbone named by ``('network', 'backbone')``.

        ``hrnet18``, ``hrnet32``, ``hrnet48`` and ``hrnet64`` build a
        ``HighResolutionNet`` from the matching ``MODEL_CONFIGS`` entry with
        ``bn_type='inplace_abn'`` and ``bn_momentum=0.1``.  ``hrnet2x20``
        builds a ``HighResolutionNext`` whose ``bn_type`` is read from
        ``('network', 'bn_type')``.  For every supported name the value of
        ``('network', 'pretrained')`` is then passed to
        ``ModuleHelper.load_model`` with ``all_match=False`` and
        ``network='hrnet'``.

        Returns:
            The model returned by ``ModuleHelper.load_model``.

        Raises:
            Exception: if the configured backbone name is not supported.
        """
        arch = self.configer.get('network', 'backbone')
        from lib.models.backbones.hrnet.hrnet_config import MODEL_CONFIGS

        # The four HighResolutionNet variants differ only in their config
        # entry, so they share a single construction path.
        if arch in ('hrnet18', 'hrnet32', 'hrnet48', 'hrnet64'):
            arch_net = HighResolutionNet(MODEL_CONFIGS[arch],
                                         bn_type='inplace_abn',
                                         bn_momentum=0.1)
        elif arch == 'hrnet2x20':
            arch_net = HighResolutionNext(MODEL_CONFIGS['hrnet2x20'],
                                          bn_type=self.configer.get(
                                              'network', 'bn_type'))
        else:
            raise Exception('Architecture undefined!')

        # Weight loading is identical for every supported variant.
        arch_net = ModuleHelper.load_model(arch_net,
                                           pretrained=self.configer.get(
                                               'network', 'pretrained'),
                                           all_match=False,
                                           network='hrnet')
        return arch_net
Пример #2
0
def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    """Build a ``ResNet`` and optionally load pretrained weights into it.

    Args:
        arch: Architecture name; kept for signature compatibility, not used.
        block: Block class forwarded to ``ResNet``.
        layers: Per-stage block counts forwarded to ``ResNet``.
        pretrained: Falsy to skip weight loading.  Otherwise, if it names an
            existing local path it is loaded via ``ModuleHelper.load_model``
            (``all_match=False``); if not, it is treated as a URL, fetched
            with ``load_state_dict_from_url`` and loaded non-strictly.
        progress: Forwarded to ``load_state_dict_from_url`` (URL case only).
        **kwargs: Extra keyword arguments forwarded to ``ResNet``.

    Returns:
        The constructed model, with weights loaded when requested.
    """
    model = ResNet(block, layers, **kwargs)
    if not pretrained:
        return model

    import os
    from lib.models.tools.module_helper import ModuleHelper
    if os.path.exists(pretrained):
        # Keep the returned model, consistent with every other
        # ModuleHelper.load_model call site in this code base.
        model = ModuleHelper.load_model(model, pretrained, all_match=False)
    else:
        state_dict = load_state_dict_from_url(pretrained, progress=progress)
        # strict=False: missing/unexpected keys are tolerated instead of raising.
        model.load_state_dict(state_dict, strict=False)
    return model
Пример #3
0
 def mobilenetv2(self):
     """Create a ``MobileNetV2`` backbone and load its weights.

     The checkpoint comes from ``('network', 'pretrained')`` and is applied
     through ``ModuleHelper.load_model`` with ``all_match=False``.

     Returns:
         The model returned by ``ModuleHelper.load_model``.
     """
     backbone = MobileNetV2()
     weights = self.configer.get('network', 'pretrained')
     return ModuleHelper.load_model(backbone, pretrained=weights, all_match=False)
Пример #4
0
 def squeezenet(self):
     """Construct a ``SqueezeNet`` backbone and load its weights.

     Takes no arguments: the checkpoint is read from
     ``('network', 'pretrained')`` and applied via
     ``ModuleHelper.load_model`` with ``all_match=False``.

     Returns:
         The model returned by ``ModuleHelper.load_model``.
     """
     model = SqueezeNet()
     model = ModuleHelper.load_model(model,
                                     pretrained=self.configer.get('network', 'pretrained'),
                                     all_match=False)
     return model
 def deepbase_resnet101(self, pretrained=None, **kwargs):
     """Build a ResNet-101 (``Bottleneck``, stages 3-4-23-3, ``deep_base=True``).

     Args:
         pretrained: Forwarded unchanged to ``ModuleHelper.load_model``
             (presumably a checkpoint path or ``None`` -- TODO confirm).
         **kwargs: Extra keyword arguments forwarded to ``ResNet``.

     Returns:
         The model returned by ``ModuleHelper.load_model``.
     """
     stage_depths = [3, 4, 23, 3]
     norm = self.configer.get('network', 'norm_type')
     net = ResNet(Bottleneck, stage_depths, deep_base=True, norm_type=norm, **kwargs)
     return ModuleHelper.load_model(net, pretrained=pretrained)
 def resnet34(self, pretrained=None, **kwargs):
     """Build a ResNet-34 (``BasicBlock``, stages 3-4-6-3, ``deep_base=False``).

     Args:
         pretrained: Forwarded unchanged to ``ModuleHelper.load_model``
             (presumably a checkpoint path or ``None`` -- TODO confirm).
         **kwargs: Extra keyword arguments forwarded to ``ResNet``.

     Returns:
         The model returned by ``ModuleHelper.load_model``.
     """
     stage_depths = [3, 4, 6, 3]
     norm = self.configer.get('network', 'norm_type')
     net = ResNet(BasicBlock, stage_depths, deep_base=False, norm_type=norm, **kwargs)
     return ModuleHelper.load_model(net, pretrained=pretrained)
Пример #7
0
 def wide_resnet38(self, **kwargs):
     """Build a ``WiderResNetA2`` with block counts [3, 3, 6, 3, 1, 1].

     ``bn_type`` is read from ``('network', 'bn_type')``.  The checkpoint at
     ``('network', 'pretrained')`` is applied through
     ``ModuleHelper.load_model`` with ``all_match=False`` and
     ``network="wide_resnet"``.

     Args:
         **kwargs: Extra keyword arguments forwarded to ``WiderResNetA2``.

     Returns:
         The model returned by ``ModuleHelper.load_model``.
     """
     norm = self.configer.get('network', 'bn_type')
     net = WiderResNetA2([3, 3, 6, 3, 1, 1], bn_type=norm, **kwargs)
     checkpoint = self.configer.get('network', 'pretrained')
     return ModuleHelper.load_model(net, pretrained=checkpoint,
                                    all_match=False, network="wide_resnet")
Пример #8
0
 def deepbase_resnet50(self, **kwargs):
     """Construct a ResNet-50 (``Bottleneck``, [3, 4, 6, 3]) with ``deep_base=True``.

     The norm layer type is read from ``('network', 'bn_type')``.  Weights
     are not selected by an argument: the checkpoint at
     ``('network', 'pretrained')`` is passed to ``ModuleHelper.load_model``,
     whose default ``all_match``/``network`` settings are used.

     Args:
         **kwargs: Extra keyword arguments forwarded to ``ResNet``.

     Returns:
         The model returned by ``ModuleHelper.load_model``.
     """
     model = ResNet(Bottleneck, [3, 4, 6, 3],
                    deep_base=True,
                    bn_type=self.configer.get('network', 'bn_type'),
                    **kwargs)
     model = ModuleHelper.load_model(model,
                                     pretrained=self.configer.get(
                                         'network', 'pretrained'))
     return model
Пример #9
0
 def resnet18(self, **kwargs):
     """Construct a ResNet-18 (``BasicBlock``, [2, 2, 2, 2]) with ``deep_base=False``.

     The norm layer type is read from ``('network', 'bn_type')``.  Weights
     are not selected by an argument: the checkpoint at
     ``('network', 'pretrained')`` is passed to ``ModuleHelper.load_model``,
     whose default ``all_match``/``network`` settings are used.

     Args:
         **kwargs: Extra keyword arguments forwarded to ``ResNet``.

     Returns:
         The model returned by ``ModuleHelper.load_model``.
     """
     model = ResNet(BasicBlock, [2, 2, 2, 2],
                    deep_base=False,
                    bn_type=self.configer.get('network', 'bn_type'),
                    **kwargs)
     model = ModuleHelper.load_model(model,
                                     pretrained=self.configer.get(
                                         'network', 'pretrained'))
     return model
Пример #10
0
 def deepbase_dcn_resnet101(self, **kwargs):
     """Construct a ``DCNResNet``-101 (``Bottleneck``, [3, 4, 23, 3]), ``deep_base=True``.

     The norm layer type is read from ``('network', 'bn_type')``.  Weights
     are not selected by an argument: the checkpoint at
     ``('network', 'pretrained')`` is applied via ``ModuleHelper.load_model``
     with ``all_match=False`` and ``network="dcnet"``.

     Args:
         **kwargs: Extra keyword arguments forwarded to ``DCNResNet``.

     Returns:
         The model returned by ``ModuleHelper.load_model``.
     """
     model = DCNResNet(Bottleneck, [3, 4, 23, 3],
                       deep_base=True,
                       bn_type=self.configer.get('network', 'bn_type'),
                       **kwargs)
     model = ModuleHelper.load_model(model,
                                     all_match=False,
                                     pretrained=self.configer.get(
                                         'network', 'pretrained'),
                                     network="dcnet")
     return model
Пример #11
0
    def resnet152(self, **kwargs):
        """Construct a ResNet-152 (``Bottleneck``, [3, 8, 36, 3]) with ``deep_base=False``.

        The norm layer type is read from ``('network', 'bn_type')``.  Weights
        are not selected by an argument: the checkpoint at
        ``('network', 'pretrained')`` is applied via
        ``ModuleHelper.load_model`` with ``all_match=False`` and
        ``network="resnet152"``.

        Args:
            **kwargs: Extra keyword arguments forwarded to ``ResNet``.

        Returns:
            The model returned by ``ModuleHelper.load_model``.
        """
        model = ResNet(Bottleneck, [3, 8, 36, 3],
                       deep_base=False,
                       bn_type=self.configer.get('network', 'bn_type'),
                       **kwargs)
        model = ModuleHelper.load_model(model,
                                        all_match=False,
                                        pretrained=self.configer.get(
                                            'network', 'pretrained'),
                                        network="resnet152")
        return model
Пример #12
0
 def deepbase_resnest269(self, **kwargs):
     """Build a ``ResNeSt`` with ``Bottleneck`` blocks and stage depths [3, 30, 48, 8].

     Fixed options: ``radix=2``, ``groups=1``, ``bottleneck_width=64``,
     ``dilated=True``, ``dilation=4``, ``deep_stem=True``, ``stem_width=64``,
     ``avg_down=True``, ``avd=True``, ``avd_first=False``.  ``bn_type`` is
     read from ``('network', 'bn_type')``.  The checkpoint at
     ``('network', 'pretrained')`` is applied via ``ModuleHelper.load_model``
     with ``all_match=False`` and ``network="resnest"``.

     Args:
         **kwargs: Extra keyword arguments forwarded to ``ResNeSt``.  Passing
             any of the fixed options again raises ``TypeError``.

     Returns:
         The model returned by ``ModuleHelper.load_model``.
     """
     fixed_options = dict(
         radix=2,
         groups=1,
         bottleneck_width=64,
         dilated=True,
         dilation=4,
         deep_stem=True,
         stem_width=64,
         avg_down=True,
         avd=True,
         avd_first=False,
         bn_type=self.configer.get('network', 'bn_type'),
     )
     # Two ** unpackings keep the "multiple values for keyword argument"
     # TypeError if kwargs repeats a fixed option.
     net = ResNeSt(Bottleneck, [3, 30, 48, 8], **fixed_options, **kwargs)
     checkpoint = self.configer.get('network', 'pretrained')
     return ModuleHelper.load_model(net, pretrained=checkpoint,
                                    all_match=False, network="resnest")
Пример #13
0
    def resnext101_32x48d(self, **kwargs):
        """Construct a ResNeXt-101 32x48d model (``Bottleneck``, [3, 4, 23, 3]).

        ``groups`` is forced to 32 and ``width_per_group`` to 48, overriding
        any values the caller put in ``kwargs``.  ``ResNext`` always receives
        ``False`` for its ``pretrained`` and ``progress`` positional
        arguments.  Weights are loaded afterwards from the checkpoint at
        ``('network', 'pretrained')`` via ``ModuleHelper.load_model`` with
        ``all_match=False`` and ``network="resnext"``.

        Args:
            **kwargs: Extra keyword arguments forwarded to ``ResNext``
                (``groups``/``width_per_group`` are overwritten).

        Returns:
            The model returned by ``ModuleHelper.load_model``.
        """
        # Weight loading is handled by ModuleHelper below, so the
        # constructor's own pretrained/progress path is disabled.
        pretrained = False
        progress = False
        kwargs['groups'] = 32
        kwargs['width_per_group'] = 48
        model = ResNext('resnext101_32x48d',
                        Bottleneck, [3, 4, 23, 3],
                        pretrained,
                        progress,
                        bn_type=self.configer.get('network', 'bn_type'),
                        **kwargs)
        model = ModuleHelper.load_model(model,
                                        pretrained=self.configer.get(
                                            'network', 'pretrained'),
                                        all_match=False,
                                        network="resnext")
        return model
Пример #14
0
 def squeezenet_dilated8(self):
     """Create a ``DilatedSqueezeNet`` backbone and load its weights.

     The checkpoint is read from ``('network', 'pretrained')`` and applied
     through ``ModuleHelper.load_model`` with ``all_match=False``.

     Returns:
         The model returned by ``ModuleHelper.load_model``.
     """
     net = DilatedSqueezeNet()
     checkpoint = self.configer.get('network', 'pretrained')
     return ModuleHelper.load_model(net, pretrained=checkpoint, all_match=False)