def _strategy_factory(alias, space_type): # Some search space needs extra hooks extra_mutation_hooks = [] nds_need_shape_alignment = '_smalldepth' in space_type if nds_need_shape_alignment: if alias in ['enas', 'random']: extra_mutation_hooks.append(NDSStagePathSampling.mutate) else: extra_mutation_hooks.append(NDSStageDifferentiable.mutate) if alias == 'darts': return stg.DARTS(mutation_hooks=extra_mutation_hooks) if alias == 'gumbel': return stg.GumbelDARTS(mutation_hooks=extra_mutation_hooks) if alias == 'proxyless': return stg.Proxyless() if alias == 'enas': return stg.ENAS(mutation_hooks=extra_mutation_hooks, reward_metric_name='val_acc') if alias == 'random': return stg.RandomOneShot(mutation_hooks=extra_mutation_hooks) raise ValueError(f'Unrecognized strategy: {alias}')
def strategy_fn(base_model, evaluator):
    """Return an ENAS strategy whose reward metric suits ``base_model``.

    ``MultiHeadAttentionNet`` instances are rewarded on ``'val_mse'``. Every
    other model is rewarded on ``'val_acc'``.

    Args:
        base_model: The model whose type selects the reward metric.
        evaluator: Not used here. Presumably it is kept to match the caller's
            callback signature; verify against the call site.

    Returns:
        A new ``strategy.ENAS`` instance.
    """
    metric = 'val_mse' if isinstance(base_model, MultiHeadAttentionNet) else 'val_acc'
    return strategy.ENAS(reward_metric_name=metric)