예제 #1
0
def get_model_params(model_name, override_params):
    """Build the block args and global params for a named EfficientNet model.

    Args:
        model_name: Model identifier. Only names starting with
            ``'efficientnet'`` are recognised.
        override_params: Optional mapping of field name -> value applied on
            top of the generated global params. A falsy value (``None`` or an
            empty dict) leaves the generated params untouched.

    Returns:
        A ``(blocks_args, global_params)`` pair as produced by
        ``efficientnet()``, with ``global_params`` possibly overridden.

    Raises:
        NotImplementedError: if ``model_name`` is not an EfficientNet name.
        ValueError: raised by ``global_params._replace`` when
            ``override_params`` names a field ``global_params`` lacks
            (assumes ``global_params`` is a namedtuple -- TODO confirm).
    """
    # Reject anything that is not an EfficientNet variant before doing work.
    if not model_name.startswith('efficientnet'):
        raise NotImplementedError('model name is not pre-defined: %s' % model_name)

    # efficientnet_params yields (width, depth, resolution, dropout) coefficients.
    width, depth, resolution, dropout = efficientnet_params(model_name)
    # drop_connect_rate is not passed here, so efficientnet()'s own default
    # applies (an earlier note said 0.2 for all models -- confirm in efficientnet()).
    blocks_args, global_params = efficientnet(
        width_coefficient=width,
        depth_coefficient=depth,
        dropout_rate=dropout,
        image_size=resolution,
    )

    if override_params:
        global_params = global_params._replace(**override_params)
    return blocks_args, global_params
예제 #2
0
파일: beng_eff_net.py 프로젝트: dodler/kgl
    def __init__(self,
                 name='efficientnet-b0',
                 pretrained=True,
                 input_bn=True,
                 dropout=0.3,
                 head='V1'):
        """Build an EfficientNet backbone adapted to single-channel input.

        Args:
            name: EfficientNet variant, ``'efficientnet-b0'`` through
                ``'efficientnet-b7'``.
            pretrained: if true, load weights via
                ``EfficientNet.from_pretrained``; otherwise build the net via
                ``EfficientNet.from_name``.
            input_bn: if true, also create ``self.bn_in``, a 1-channel
                ``BatchNorm2d``.
            dropout: not used by this constructor; kept for interface
                compatibility. NOTE(review): confirm whether it should be
                forwarded to ``get_head``.
            head: head variant name forwarded to ``get_head``.

        Raises:
            KeyError: if ``name`` is accepted by ``EfficientNet`` but is not
                in the channel tables below (raised after the backbone has
                already been built).
        """
        super().__init__()
        self.input_bn = input_bn
        self.name = name
        if pretrained:
            self.net = EfficientNet.from_pretrained(model_name=name)
        else:
            self.net = EfficientNet.from_name(model_name=name)

        # Element 2 of efficientnet_params(...) is the model's input
        # resolution; the same-padding conv factory is built for that size.
        params = efficientnet_params(model_name=name)
        Conv2d = get_same_padding_conv2d(image_size=params[2])

        # Per-variant widths: stem = round_filters(32, width_coefficient),
        # head = round_filters(1280, width_coefficient). They must match the
        # backbone so the replacement stem and the head line up with the
        # pretrained layers. Previously only b0/b4/b7 were listed, so every
        # other standard variant failed with KeyError.
        conv_stem_filts = {
            'efficientnet-b0': 32,
            'efficientnet-b1': 32,
            'efficientnet-b2': 32,
            'efficientnet-b3': 40,
            'efficientnet-b4': 48,
            'efficientnet-b5': 48,
            'efficientnet-b6': 56,
            'efficientnet-b7': 64,
        }

        linear_size = {
            'efficientnet-b0': 1280,
            'efficientnet-b1': 1280,
            'efficientnet-b2': 1408,
            'efficientnet-b3': 1536,
            'efficientnet-b4': 1792,
            'efficientnet-b5': 2048,
            'efficientnet-b6': 2304,
            'efficientnet-b7': 2560,
        }

        # Swap the stem for a freshly initialised 1-input-channel conv with
        # the same output width. Any pretrained stem weights are discarded.
        self.net._conv_stem = Conv2d(1,
                                     conv_stem_filts[name],
                                     kernel_size=(3, 3),
                                     stride=(2, 2),
                                     bias=False)

        self.cls1, self.cls2, self.cls3 = get_head(False,
                                                   head,
                                                   in_size=linear_size[name])

        if input_bn:
            # Normalises the single input channel; only created when requested.
            self.bn_in = nn.BatchNorm2d(1)
예제 #3
0
 def get_image_size(cls, model_name):
     """Return the input image resolution for ``model_name``.

     The name is first validated with ``cls._check_model_name_is_valid``.
     The resolution is the third value returned by ``efficientnet_params``.
     """
     cls._check_model_name_is_valid(model_name)
     _width, _depth, image_size, _dropout = efficientnet_params(model_name)
     return image_size
예제 #4
0
import torch