コード例 #1
0
def _strategy_factory(alias, space_type):
    # Some search space needs extra hooks
    extra_mutation_hooks = []
    nds_need_shape_alignment = '_smalldepth' in space_type
    if nds_need_shape_alignment:
        if alias in ['enas', 'random']:
            extra_mutation_hooks.append(NDSStagePathSampling.mutate)
        else:
            extra_mutation_hooks.append(NDSStageDifferentiable.mutate)

    if alias == 'darts':
        return stg.DARTS(mutation_hooks=extra_mutation_hooks)
    if alias == 'gumbel':
        return stg.GumbelDARTS(mutation_hooks=extra_mutation_hooks)
    if alias == 'proxyless':
        return stg.Proxyless()
    if alias == 'enas':
        return stg.ENAS(mutation_hooks=extra_mutation_hooks,
                        reward_metric_name='val_acc')
    if alias == 'random':
        return stg.RandomOneShot(mutation_hooks=extra_mutation_hooks)

    raise ValueError(f'Unrecognized strategy: {alias}')
コード例 #2
0
ファイル: test_oneshot.py プロジェクト: maxpark/nni
def test_darts():
    """Run the shared ``_test_strategy`` check on a default DARTS strategy."""
    darts_strategy = strategy.DARTS()
    _test_strategy(darts_strategy)
コード例 #3
0
def test_darts_multi_gpu():
    """Run ``_test_strategy`` on a default DARTS strategy with ``multi_gpu=True``."""
    darts_strategy = strategy.DARTS()
    _test_strategy(darts_strategy, multi_gpu=True)