示例#1
0
文件: builders.py 项目: taoari/XNAS
def sng_builder(category):
    """Construct the search optimizer selected by ``cfg.SNG.NAME``.

    Each branch reads only the ``cfg.SNG`` fields its optimizer needs, so
    fields belonging to other optimizers are never accessed.

    Args:
        category: Passed through unchanged as the first constructor argument
            (as ``categories=`` for MIGO). Presumably describes the categorical
            search space (e.g. choices per variable) -- confirm against the
            optimizer classes.

    Returns:
        A new CategoricalMDENAS, CategoricalDDPNAS, SNG, ASNG, Dynamic_SNG,
        Dynamic_ASNG or MIGO instance.

    Raises:
        NotImplementedError: If ``cfg.SNG.NAME`` is not a supported name.
    """
    # Read the selector once rather than re-evaluating it in every branch.
    name = cfg.SNG.NAME
    if name == 'MDENAS':
        return CategoricalMDENAS(category, cfg.SNG.THETA_LR)
    elif name == 'DDPNAS':
        return CategoricalDDPNAS(category, cfg.SNG.PRUNING_STEP)
    elif name == 'SNG':
        return SNG(category)
    elif name == 'ASNG':
        return ASNG(category)
    elif name == 'dynamic_SNG':
        return Dynamic_SNG(category,
                           step=cfg.SNG.PRUNING_STEP,
                           pruning=cfg.SNG.PRUNING)
    elif name == 'dynamic_ASNG':
        return Dynamic_ASNG(category,
                            step=cfg.SNG.PRUNING_STEP,
                            pruning=cfg.SNG.PRUNING)
    elif name == 'MIGO':
        return MIGO(categories=category,
                    step=cfg.SNG.PRUNING_STEP,
                    pruning=cfg.SNG.PRUNING,
                    sample_with_prob=cfg.SNG.PROB_SAMPLING,
                    utility_function=cfg.SNG.UTILITY,
                    utility_function_hyper=cfg.SNG.UTILITY_FACTOR,
                    momentum=cfg.SNG.MOMENTUM,
                    gamma=cfg.SNG.GAMMA)
    else:
        # Name the bad value so a misconfigured cfg.SNG.NAME is easy to diagnose.
        raise NotImplementedError(
            'Unsupported cfg.SNG.NAME {!r}; expected one of: MDENAS, DDPNAS, '
            'SNG, ASNG, dynamic_SNG, dynamic_ASNG, MIGO'.format(name))
示例#2
0
def get_optimizer(name,
                  category,
                  step=4,
                  gamma=0.9,
                  sample_with_prob=True,
                  utility_function='log',
                  utility_function_hyper=0.4):
    """Construct a search optimizer by name with explicit hyper-parameters.

    Args:
        name: One of 'DDPNAS', 'MDENAS', 'SNG', 'ASNG', 'dynamic_ASNG',
            'dynamic_SNG', 'MIGO' or 'GridSearch'.
        category: Passed through unchanged to the optimizer constructor.
        step: Forwarded as ``step`` to dynamic_ASNG, dynamic_SNG and MIGO.
            NOTE: DDPNAS does not use it and is always built with 3.
        gamma: Forwarded to MIGO only.
        sample_with_prob: Forwarded to dynamic_ASNG, dynamic_SNG and MIGO.
        utility_function: Forwarded to MIGO only.
        utility_function_hyper: Forwarded to MIGO only.

    Returns:
        A new CategoricalDDPNAS, CategoricalMDENAS, SNG, ASNG, Dynamic_ASNG,
        Dynamic_SNG, MIGO or GridSearch instance.

    Raises:
        NotImplementedError: If ``name`` is not a supported optimizer name.
    """
    if name == 'DDPNAS':
        return CategoricalDDPNAS(category, 3)
    elif name == 'MDENAS':
        return CategoricalMDENAS(category, 0.01)
    elif name == 'SNG':
        return SNG(categories=category)
    elif name == 'ASNG':
        return ASNG(categories=category)
    elif name == 'dynamic_ASNG':
        return Dynamic_ASNG(categories=category,
                            step=step,
                            pruning=True,
                            sample_with_prob=sample_with_prob)
    elif name == 'dynamic_SNG':
        return Dynamic_SNG(categories=category,
                           step=step,
                           pruning=True,
                           sample_with_prob=sample_with_prob)
    elif name == 'MIGO':
        return MIGO(categories=category,
                    step=step,
                    lam=6,
                    pruning=True,
                    sample_with_prob=sample_with_prob,
                    # Honour the caller's choice; the default 'log' matches
                    # the value that used to be hard-coded here.
                    utility_function=utility_function,
                    utility_function_hyper=utility_function_hyper,
                    momentum=True,
                    gamma=gamma,
                    dynamic_sampling=False)
    elif name == 'GridSearch':
        return GridSearch(category)
    else:
        # Name the bad value so a typo in the optimizer name is easy to spot.
        raise NotImplementedError(
            'Unsupported optimizer name {!r}; expected one of: DDPNAS, MDENAS, '
            'SNG, ASNG, dynamic_ASNG, dynamic_SNG, MIGO, GridSearch'.format(name))