def sng_builder(category):
    """Construct the search-space optimizer named by ``cfg.SNG.NAME``.

    Which other ``cfg.SNG`` fields are read depends on the selected entry:
    ``THETA_LR`` for 'MDENAS', ``PRUNING_STEP`` for 'DDPNAS', ``PRUNING_STEP``
    and ``PRUNING`` for the dynamic variants, and additionally
    ``PROB_SAMPLING``, ``UTILITY``, ``UTILITY_FACTOR``, ``MOMENTUM`` and
    ``GAMMA`` for 'MIGO'. Fields belonging to other entries are never read.

    Args:
        category: Forwarded unchanged to the optimizer constructor
            (as ``categories=`` for 'MIGO', positionally otherwise).

    Returns:
        A newly constructed optimizer instance.

    Raises:
        NotImplementedError: If ``cfg.SNG.NAME`` matches none of the
            supported names.
    """
    # Every entry is wrapped in a lambda so that only the selected optimizer
    # is instantiated and only its configuration fields are looked up.
    builders = {
        'MDENAS': lambda: CategoricalMDENAS(category, cfg.SNG.THETA_LR),
        'DDPNAS': lambda: CategoricalDDPNAS(category, cfg.SNG.PRUNING_STEP),
        'SNG': lambda: SNG(category),
        'ASNG': lambda: ASNG(category),
        'dynamic_SNG': lambda: Dynamic_SNG(
            category,
            step=cfg.SNG.PRUNING_STEP,
            pruning=cfg.SNG.PRUNING,
        ),
        'dynamic_ASNG': lambda: Dynamic_ASNG(
            category,
            step=cfg.SNG.PRUNING_STEP,
            pruning=cfg.SNG.PRUNING,
        ),
        'MIGO': lambda: MIGO(
            categories=category,
            step=cfg.SNG.PRUNING_STEP,
            pruning=cfg.SNG.PRUNING,
            sample_with_prob=cfg.SNG.PROB_SAMPLING,
            utility_function=cfg.SNG.UTILITY,
            utility_function_hyper=cfg.SNG.UTILITY_FACTOR,
            momentum=cfg.SNG.MOMENTUM,
            gamma=cfg.SNG.GAMMA,
        ),
    }
    build = builders.get(cfg.SNG.NAME)
    if build is None:
        raise NotImplementedError
    return build()
def get_optimizer(name, category, step=4, gamma=0.9, sample_with_prob=True,
                  utility_function='log', utility_function_hyper=0.4):
    """Construct a search-space optimizer by name with explicit hyper-parameters.

    Args:
        name: Optimizer identifier; one of 'DDPNAS', 'MDENAS', 'SNG', 'ASNG',
            'dynamic_ASNG', 'dynamic_SNG', 'MIGO' or 'GridSearch'.
        category: Forwarded to the optimizer constructor (positionally, or as
            ``categories=`` for the SNG-family optimizers and 'MIGO').
        step: ``step`` argument for 'dynamic_ASNG', 'dynamic_SNG' and 'MIGO'.
            Not used by 'DDPNAS', which is always built with a fixed 3.
        gamma: ``gamma`` argument, used by 'MIGO' only.
        sample_with_prob: ``sample_with_prob`` argument for 'dynamic_ASNG',
            'dynamic_SNG' and 'MIGO'.
        utility_function: ``utility_function`` argument, used by 'MIGO' only.
        utility_function_hyper: ``utility_function_hyper`` argument, used by
            'MIGO' only.

    Returns:
        A newly constructed optimizer instance.

    Raises:
        NotImplementedError: If ``name`` is not one of the identifiers above.
    """
    if name == 'DDPNAS':
        # NOTE(review): fixed second argument; sng_builder passes
        # cfg.SNG.PRUNING_STEP in this position, so this is presumably a
        # pruning step -- `step` is intentionally(?) not used here. Confirm.
        return CategoricalDDPNAS(category, 3)
    elif name == 'MDENAS':
        # NOTE(review): sng_builder passes cfg.SNG.THETA_LR in this position,
        # so 0.01 is presumably a fixed theta learning rate. Confirm.
        return CategoricalMDENAS(category, 0.01)
    elif name == 'SNG':
        return SNG(categories=category)
    elif name == 'ASNG':
        return ASNG(categories=category)
    elif name == 'dynamic_ASNG':
        return Dynamic_ASNG(categories=category, step=step, pruning=True,
                            sample_with_prob=sample_with_prob)
    elif name == 'dynamic_SNG':
        return Dynamic_SNG(categories=category, step=step, pruning=True,
                           sample_with_prob=sample_with_prob)
    elif name == 'MIGO':
        return MIGO(
            categories=category,
            step=step,
            lam=6,
            pruning=True,
            sample_with_prob=sample_with_prob,
            # Forward the caller's choice (default 'log') instead of a literal.
            utility_function=utility_function,
            utility_function_hyper=utility_function_hyper,
            momentum=True,
            gamma=gamma,
            dynamic_sampling=False,
        )
    elif name == 'GridSearch':
        return GridSearch(category)
    raise NotImplementedError('Unknown optimizer name: %r' % (name,))