예제 #1
0
def _strategy_factory(alias, space_type):
    # Some search space needs extra hooks
    extra_mutation_hooks = []
    nds_need_shape_alignment = '_smalldepth' in space_type
    if nds_need_shape_alignment:
        if alias in ['enas', 'random']:
            extra_mutation_hooks.append(NDSStagePathSampling.mutate)
        else:
            extra_mutation_hooks.append(NDSStageDifferentiable.mutate)

    if alias == 'darts':
        return stg.DARTS(mutation_hooks=extra_mutation_hooks)
    if alias == 'gumbel':
        return stg.GumbelDARTS(mutation_hooks=extra_mutation_hooks)
    if alias == 'proxyless':
        return stg.Proxyless()
    if alias == 'enas':
        return stg.ENAS(mutation_hooks=extra_mutation_hooks,
                        reward_metric_name='val_acc')
    if alias == 'random':
        return stg.RandomOneShot(mutation_hooks=extra_mutation_hooks)

    raise ValueError(f'Unrecognized strategy: {alias}')
예제 #2
0
파일: test_oneshot.py 프로젝트: maxpark/nni
 def strategy_fn(base_model, evaluator):
     """Return an ENAS strategy whose reward metric depends on the model type.

     ``MultiHeadAttentionNet`` instances are rewarded on ``'val_mse'``; every
     other model is rewarded on ``'val_acc'``. ``evaluator`` is accepted but
     not used here (presumably required by the caller's callback signature --
     confirm against the caller).
     """
     reward_name = (
         'val_mse' if isinstance(base_model, MultiHeadAttentionNet)
         else 'val_acc'
     )
     return strategy.ENAS(reward_metric_name=reward_name)