Esempio n. 1
0
def _hub_factory(alias):
    if alias == 'nasbench101':
        return ss.NasBench101()
    if alias == 'nasbench201':
        return ss.NasBench201()

    if alias == 'mobilenetv3':
        return ss.MobileNetV3Space()

    if alias == 'mobilenetv3_small':
        return ss.MobileNetV3Space(width_multipliers=(0.75, 1, 1.5),
                                   expand_ratios=(4, 6))
    if alias == 'proxylessnas':
        return ss.ProxylessNAS()
    if alias == 'shufflenet':
        return ss.ShuffleNetSpace()
    if alias == 'autoformer':
        return ss.AutoformerSpace()

    if '_smalldepth' in alias:
        num_cells = (4, 8)
    elif '_depth' in alias:
        num_cells = (8, 12)
    else:
        num_cells = 8

    if '_width' in alias:
        width = (8, 16)
    else:
        width = 16

    if '_imagenet' in alias:
        dataset = 'imagenet'
    else:
        dataset = 'cifar'

    if alias.startswith('nasnet'):
        return ss.NASNet(width=width, num_cells=num_cells, dataset=dataset)
    if alias.startswith('enas'):
        return ss.ENAS(width=width, num_cells=num_cells, dataset=dataset)
    if alias.startswith('amoeba'):
        return ss.AmoebaNet(width=width, num_cells=num_cells, dataset=dataset)
    if alias.startswith('pnas'):
        return ss.PNAS(width=width, num_cells=num_cells, dataset=dataset)
    if alias.startswith('darts'):
        return ss.DARTS(width=width, num_cells=num_cells, dataset=dataset)

    raise ValueError(f'Unrecognized space: {alias}')
Esempio n. 2
0
def test_nasnet_imagenet():
    """Smoke-test the ENAS and PNAS spaces when configured for ImageNet."""
    for space_cls in (searchspace.ENAS, searchspace.PNAS):
        # Build the space for ImageNet and validate it on the same dataset.
        space = space_cls(dataset='imagenet')
        _test_searchspace_on_dataset(space, dataset='imagenet')
Esempio n. 3
0
def test_nasnet():
    """Smoke-test every NASNet-family space with its default configuration."""
    nasnet_family = (
        searchspace.NASNet,
        searchspace.ENAS,
        searchspace.AmoebaNet,
        searchspace.PNAS,
        searchspace.DARTS,
    )
    for space_cls in nasnet_family:
        # Each space is built and checked before the next one is constructed.
        _test_searchspace_on_dataset(space_cls())