Пример #1
0
def regnety(name, pretrained=False, nc=1000):
    """Construct a RegNetY model from the model zoo.

    Args:
        name: Model-zoo key; must be present in both ``_REGNETY_URLS`` and
            ``_REGNETY_CFGS``.
        pretrained: If True, fetch the checkpoint at
            ``_URL_PREFIX/_REGNETY_URLS[name]`` via ``cache_url`` and load it
            into the model with ``checkpoint.load_checkpoint``.
        nc: Forwarded to ``AnyNet`` as ``nc`` (presumably the number of
            output classes -- confirm against ``AnyNet``).

    Returns:
        The constructed ``AnyNet`` instance.

    Raises:
        AssertionError: If ``name`` is not in the model zoo.
    """
    import posixpath

    is_valid = name in _REGNETY_URLS and name in _REGNETY_CFGS
    assert is_valid, "RegNetY-{} not found in the model zoo.".format(name)
    # Construct the model: depths, widths and group width come from the zoo
    # config; the remaining settings are fixed for every RegNetY variant.
    cfg = _REGNETY_CFGS[name]
    kwargs = {
        "stem_type": "simple_stem_in",
        "stem_w": 32,
        "block_type": "res_bottleneck_block",
        "ss": [2, 2, 2, 2],
        "bms": [1.0, 1.0, 1.0, 1.0],
        "se_r": 0.25,
        "nc": nc,
        "ds": cfg["ds"],
        "ws": cfg["ws"],
        # The same group width is used for all four stages.
        "gws": [cfg["g"]] * 4,
    }
    model = AnyNet(**kwargs)
    # Download and load the weights
    if pretrained:
        # URLs always use "/" separators. os.path.join would insert "\\" on
        # Windows and produce an invalid URL, so join with posixpath instead.
        url = posixpath.join(_URL_PREFIX, _REGNETY_URLS[name])
        ws_path = cache_url(url, _DOWNLOAD_CACHE)
        checkpoint.load_checkpoint(ws_path, model)
    return model
Пример #2
0
def resnext(name, pretrained=False, nc=1000):
    """Construct a ResNeXt model from the model zoo.

    Args:
        name: Model-zoo key; must be present in both ``_RESNEXT_URLS`` and
            ``_RESNEXT_CFGS``.
        pretrained: If True, fetch the checkpoint at
            ``_URL_PREFIX/_RESNEXT_URLS[name]`` via ``cache_url`` and load it
            into the model with ``checkpoint.load_checkpoint``.
        nc: Forwarded to ``AnyNet`` as ``nc`` (presumably the number of
            output classes -- confirm against ``AnyNet``).

    Returns:
        The constructed ``AnyNet`` instance.

    Raises:
        AssertionError: If ``name`` is not in the model zoo.
    """
    import posixpath

    is_valid = name in _RESNEXT_URLS and name in _RESNEXT_CFGS
    assert is_valid, "ResNeXt-{} not found in the model zoo.".format(name)
    # Construct the model: only the per-stage depths vary between zoo
    # entries; widths, group widths and the rest are fixed here.
    cfg = _RESNEXT_CFGS[name]
    kwargs = {
        "stem_type": "res_stem_in",
        "stem_w": 64,
        "block_type": "res_bottleneck_block",
        "ss": [1, 2, 2, 2],
        "bms": [0.5, 0.5, 0.5, 0.5],
        "se_r": None,
        "nc": nc,
        "ds": cfg["ds"],
        "ws": [256, 512, 1024, 2048],
        "gws": [4, 8, 16, 32],
    }
    model = AnyNet(**kwargs)
    # Download and load the weights
    if pretrained:
        # URLs always use "/" separators. os.path.join would insert "\\" on
        # Windows and produce an invalid URL, so join with posixpath instead.
        url = posixpath.join(_URL_PREFIX, _RESNEXT_URLS[name])
        ws_path = cache_url(url, _DOWNLOAD_CACHE)
        checkpoint.load_checkpoint(ws_path, model)
    return model
Пример #3
0
 def get_network(self, uid):
     """Build the network described by the stored record ``self.data[uid]``.

     The record's ``'net'`` entry is the architecture config. If it has a
     ``'genotype'`` key, a ``Genotype`` is assembled from its normal/reduce
     cells and wrapped in ``NetworkImageNet`` (when ``'_in'`` occurs in
     ``self.searchspace``) or ``NetworkCIFAR`` otherwise. In all other cases
     the config is normalised into ``AnyNet`` keyword arguments and passed
     to ``AnyNet``. Either way, ``return_feature_layer`` is applied to the
     network before it is returned.

     The stored record is not modified; normalisation happens on a copy.

     Args:
         uid: Key into ``self.data``.

     Returns:
         The constructed network.
     """
     # Shallow copy: the normalisation below must not rewrite the stored
     # record in ``self.data`` (the original aliased and mutated it).
     config = dict(self.data[uid]['net'])
     if 'genotype' in config:
         gen = config['genotype']
         genotype = Genotype(normal=gen['normal'],
                             normal_concat=gen['normal_concat'],
                             reduce=gen['reduce'],
                             reduce_concat=gen['reduce_concat'])
         if '_in' in self.searchspace:
             network = NetworkImageNet(config['width'], 1, config['depth'],
                                       config['aux'], genotype)
         else:
             network = NetworkCIFAR(config['width'], 1, config['depth'],
                                    config['aux'], genotype)
         network.drop_path_prob = 0.
     else:
         # Rename legacy keys to the names AnyNet expects, unless the new
         # name is already present.
         if 'bot_muls' in config and 'bms' not in config:
             config['bms'] = config.pop('bot_muls')
         if 'num_gs' in config and 'gws' not in config:
             config['gws'] = config.pop('num_gs')
         # Fixed settings applied to every AnyNet-style entry.
         config['nc'] = 1
         config['se_r'] = None
         config['stem_w'] = 12
         # Available stem types: 'res_stem_cifar', 'res_stem_in',
         # 'simple_stem_in'.
         if 'ResN' in self.searchspace:
             config['stem_type'] = 'res_stem_in'
         else:
             config['stem_type'] = 'simple_stem_in'
         # Map the stored block name onto the one AnyNet recognises.
         if config['block_type'] == 'double_plain_block':
             config['block_type'] = 'vanilla_block'
         network = AnyNet(**config)
     return_feature_layer(network)
     return network
Пример #4
0
 def complexity(cx, params=None):
     """Return model complexity by delegating to ``AnyNet.complexity``.

     Keep this in sync with the model: if you alter it, update this too.

     Args:
         cx: Passed unchanged to ``AnyNet.complexity`` (its structure is not
             visible here -- presumably a running complexity accumulator).
         params: Model parameters. Any falsy value (``None``, empty) falls
             back to ``RegNet.get_params()``.

     Returns:
         Whatever ``AnyNet.complexity`` returns.
     """
     effective_params = params or RegNet.get_params()
     return AnyNet.complexity(cx, effective_params)
Пример #5
0
 def complexity(cx, **kwargs):
     """Return model complexity by delegating to ``AnyNet.complexity``.

     Keep this in sync with the model: if you alter it, update this too.

     Args:
         cx: Passed unchanged to ``AnyNet.complexity`` (its structure is not
             visible here -- presumably a running complexity accumulator).
         **kwargs: Model arguments. If none are given, the defaults from
             ``RegNet.get_args()`` are used instead.

     Returns:
         Whatever ``AnyNet.complexity`` returns.
     """
     model_args = kwargs or RegNet.get_args()
     return AnyNet.complexity(cx, **model_args)