def test_nasnet_fixwd():
    """Run the dataset check on two cell spaces built with explicit sizes.

    Both spaces use ``width=16``; only the space class and ``num_cells``
    differ between the two configurations.
    """
    configurations = (
        ('DARTS', 4),    # minimum
        ('NASNet', 12),  # medium
    )
    for class_name, cell_count in configurations:
        # Build each space right before checking it, one at a time.
        space = getattr(searchspace, class_name)(width=16, num_cells=cell_count)
        _test_searchspace_on_dataset(space)
def _hub_factory(alias): if alias == 'nasbench101': return ss.NasBench101() if alias == 'nasbench201': return ss.NasBench201() if alias == 'mobilenetv3': return ss.MobileNetV3Space() if alias == 'mobilenetv3_small': return ss.MobileNetV3Space(width_multipliers=(0.75, 1, 1.5), expand_ratios=(4, 6)) if alias == 'proxylessnas': return ss.ProxylessNAS() if alias == 'shufflenet': return ss.ShuffleNetSpace() if alias == 'autoformer': return ss.AutoformerSpace() if '_smalldepth' in alias: num_cells = (4, 8) elif '_depth' in alias: num_cells = (8, 12) else: num_cells = 8 if '_width' in alias: width = (8, 16) else: width = 16 if '_imagenet' in alias: dataset = 'imagenet' else: dataset = 'cifar' if alias.startswith('nasnet'): return ss.NASNet(width=width, num_cells=num_cells, dataset=dataset) if alias.startswith('enas'): return ss.ENAS(width=width, num_cells=num_cells, dataset=dataset) if alias.startswith('amoeba'): return ss.AmoebaNet(width=width, num_cells=num_cells, dataset=dataset) if alias.startswith('pnas'): return ss.PNAS(width=width, num_cells=num_cells, dataset=dataset) if alias.startswith('darts'): return ss.DARTS(width=width, num_cells=num_cells, dataset=dataset) raise ValueError(f'Unrecognized space: {alias}')
def test_nasnet():
    """Run the dataset check on each cell space built with default arguments.

    The spaces are constructed and checked one at a time, in the listed order.
    """
    for class_name in ('NASNet', 'ENAS', 'AmoebaNet', 'PNAS', 'DARTS'):
        space_cls = getattr(searchspace, class_name)
        _test_searchspace_on_dataset(space_cls())