示例#1
0
def test_nasnet_fixwd():
    """Smoke-test NDS-family spaces built with a single integer width and cell count.

    Runs two configurations through ``_test_searchspace_on_dataset``: a minimal
    DARTS space (4 cells) and a medium-sized NASNet space (12 cells), both with
    ``width=16``.
    """
    configurations = (
        (searchspace.DARTS, 4),    # minimum
        (searchspace.NASNet, 12),  # medium
    )
    for space_cls, cell_count in configurations:
        space = space_cls(width=16, num_cells=cell_count)
        _test_searchspace_on_dataset(space)
示例#2
0
def test_nasnet_corner_case():
    """Run a fixed NASNet architecture whose normal and reduction cells differ.

    The case is that the output channels of the reduction cell and the normal
    cell are different, so CellPreprocessor needs to know whether each of its
    predecessors is a normal cell or a reduction cell.
    """
    # For each cell type: node index -> ((op, input) for slot 0, (op, input) for slot 1).
    cell_choices = {
        "normal": {
            2: (("max_pool_7x7", 0), ("conv_1x1", 0)),
            3: (("sep_conv_5x5", 0), ("max_pool_7x7", 1)),
            4: (("sep_conv_5x5", 1), ("conv_1x1", 2)),
            5: (("max_pool_3x3", 0), ("sep_conv_5x5", 1)),
            6: (("max_pool_7x7", 0), ("sep_conv_5x5", 2)),
        },
        "reduce": {
            2: (("dil_conv_3x3", 1), ("max_pool_7x7", 1)),
            3: (("dil_conv_3x3", 0), ("dil_conv_3x3", 1)),
            4: (("conv_7x1_1x7", 2), ("conv_7x1_1x7", 1)),
            5: (("max_pool_3x3", 0), ("conv_1x1", 4)),
            6: (("sep_conv_7x7", 3), ("sep_conv_3x3", 3)),
        },
    }

    arch = {"width": 32, "depth": 8}
    for cell_type, nodes in cell_choices.items():
        # Emit every "op" key of a cell before its "input" keys, giving
        # keys such as "normal/op_2_0" and "reduce/input_6_1".
        for field, position in (("op", 0), ("input", 1)):
            for node, slots in nodes.items():
                for slot, choice in enumerate(slots):
                    arch[f"{cell_type}/{field}_{node}_{slot}"] = choice[position]

    _test_searchspace_on_dataset(searchspace.NASNet(), arch=arch)
示例#3
0
def _hub_factory(alias):
    """Construct the search space identified by ``alias``.

    Exact aliases (``nasbench101``, ``nasbench201``, ``mobilenetv3``,
    ``mobilenetv3_small``, ``proxylessnas``, ``shufflenet``, ``autoformer``)
    map to fixed constructor calls. Any other alias must start with one of the
    family prefixes ``nasnet``, ``enas``, ``amoeba``, ``pnas`` or ``darts``.
    Substring flags then select the arguments passed to that family:

    * ``num_cells``: ``(4, 8)`` if ``_smalldepth`` is present, else ``(8, 12)``
      if ``_depth`` is present, else ``8``.
    * ``width``: ``(8, 16)`` if ``_width`` is present, else ``16``.
    * ``dataset``: ``'imagenet'`` if ``_imagenet`` is present, else ``'cifar'``.

    Raises:
        ValueError: if ``alias`` matches no exact alias and no family prefix.
    """
    # Lambdas keep both the attribute lookup on ``ss`` and the construction
    # lazy, so only the requested space is ever built.
    exact_builders = (
        ('nasbench101', lambda: ss.NasBench101()),
        ('nasbench201', lambda: ss.NasBench201()),
        ('mobilenetv3', lambda: ss.MobileNetV3Space()),
        ('mobilenetv3_small', lambda: ss.MobileNetV3Space(width_multipliers=(0.75, 1, 1.5),
                                                          expand_ratios=(4, 6))),
        ('proxylessnas', lambda: ss.ProxylessNAS()),
        ('shufflenet', lambda: ss.ShuffleNetSpace()),
        ('autoformer', lambda: ss.AutoformerSpace()),
    )
    for known_alias, build in exact_builders:
        if alias == known_alias:
            return build()

    # '_smalldepth' does not contain the substring '_depth', so the order of
    # these two checks only documents precedence.
    if '_smalldepth' in alias:
        cell_counts = (4, 8)
    elif '_depth' in alias:
        cell_counts = (8, 12)
    else:
        cell_counts = 8
    widths = (8, 16) if '_width' in alias else 16
    dataset_name = 'imagenet' if '_imagenet' in alias else 'cifar'

    family_resolvers = (
        ('nasnet', lambda: ss.NASNet),
        ('enas', lambda: ss.ENAS),
        ('amoeba', lambda: ss.AmoebaNet),
        ('pnas', lambda: ss.PNAS),
        ('darts', lambda: ss.DARTS),
    )
    for prefix, resolve_cls in family_resolvers:
        if alias.startswith(prefix):
            space_cls = resolve_cls()
            return space_cls(width=widths, num_cells=cell_counts, dataset=dataset_name)

    raise ValueError(f'Unrecognized space: {alias}')
示例#4
0
def test_nasnet():
    """Run the dataset smoke test on each NDS-family space with default arguments."""
    nds_spaces = (
        searchspace.NASNet,
        searchspace.ENAS,
        searchspace.AmoebaNet,
        searchspace.PNAS,
        searchspace.DARTS,
    )
    for space_cls in nds_spaces:
        _test_searchspace_on_dataset(space_cls())