コード例 #1
0
ファイル: test_controller.py プロジェクト: zzzDavid/aw_nas
def _assert_sample_mutates(controller):
    """Reward one sampled rollout, step ``controller``, and check it mutates.

    Switches the controller to train mode, samples a single rollout, assigns
    it a random ``"reward"``, feeds it back through ``controller.step`` and
    asserts that the next sampled rollout has a different genotype (compared
    via ``str``).

    Args:
        controller: an aw_nas population controller (e.g. ``EvoController``
            or ``ParetoEvoController``) to exercise.
    """
    controller.mode = "train"
    rollouts = controller.sample(1)
    for rollout in rollouts:
        rollout.set_perf(np.random.random(), "reward")
    controller.step(rollouts)
    new_rollout = controller.sample(1)[0]
    assert str(rollouts[0].genotype) != str(new_rollout.genotype)


def test_population_controller_mutate():
    """Population controllers must propose a new genotype after one step.

    Runs for the ``"cnn"`` search space with ``"discrete"`` rollouts and the
    ``"ofa"`` search space with ``"ofa"`` rollouts, checking both
    ``EvoController`` and ``ParetoEvoController`` with a population of one.
    """
    from aw_nas.controller import EvoController, ParetoEvoController

    # NOTE(review): device is hard-coded to "cuda"; confirm these controllers
    # really need a GPU, otherwise this test cannot run on CPU-only machines.
    device = "cuda"

    for ss, rollout_type in zip(("cnn", "ofa"), ("discrete", "ofa")):
        search_space = get_search_space(cls=ss)

        # Population of one: presumably the post-step sample has to be a
        # mutation of the single rewarded member — TODO confirm.
        controller = EvoController(search_space,
                                   device,
                                   rollout_type=rollout_type,
                                   mode="train",
                                   population_size=1,
                                   parent_pool_size=1,
                                   mutate_kwargs={},
                                   avoid_mutate_repeat=True,
                                   avoid_mutate_repeat_worst_threshold=1,
                                   eval_sample_strategy="population",
                                   elimination_strategy="regularized")
        _assert_sample_mutates(controller)

        controller = ParetoEvoController(search_space,
                                         device,
                                         rollout_type=rollout_type,
                                         mode="train",
                                         init_population_size=1,
                                         perf_names=["reward"],
                                         eval_sample_strategy="all")
        _assert_sample_mutates(controller)
コード例 #2
0
ファイル: test_controller.py プロジェクト: zzzDavid/aw_nas
def test_population_controller_avoid_repeat():
    """Check ``EvoController``'s ``avoid_repeat_fallback`` behaviour.

    A small CNN search space (two shared primitives) is seeded with one
    rollout of reward 1.0 plus ``population_size - 1`` mutations of it with
    reward 0.3. With ``avoid_repeat_fallback="raise"`` at least one of three
    subsequent ``sample`` calls is expected to raise. After switching to
    ``"return"``, three ``sample`` calls must complete without raising.
    """
    from aw_nas.controller import EvoController
    ss_cfgs = {
        "cell_layout": [0, 1, 0, 1, 0],
        "num_layers": 5,
        "shared_primitives": ["skip_connect", "sep_conv_3x3"],
        "num_init_nodes": 1,
        "num_steps": 3,
        "num_cell_groups": 2,
        "reduce_cell_groups": [1],
    }
    ss = get_search_space("cnn", **ss_cfgs)
    population_size = 100
    # NOTE(review): device is hard-coded to "cuda"; confirm a GPU is required.
    controller = EvoController(ss,
                               "cuda",
                               rollout_type="discrete",
                               avoid_mutate_repeat=True,
                               avoid_mutate_repeat_worst_threshold=3,
                               avoid_repeat_fallback="raise",
                               population_size=population_size,
                               parent_pool_size=1)
    rollout = controller.sample(n=1)[0]
    # Make the seed rollout the highest reward (1.0 vs. 0.3 for its mutants).
    rollout.set_perf(1.0)
    controller.set_mode("train")
    # Relies on `set_perf` returning the rollout itself so it can be chained.
    mutants = [ss.mutate(rollout).set_perf(0.3)
               for _ in range(population_size - 1)]
    controller.step([rollout] + mutants)
    # Presumably `step` may drop repeated genotypes, so pin the size to the
    # population that was actually kept — TODO confirm.
    controller.population_size = len(controller.population)
    with pytest.raises(Exception):
        for _ in range(3):
            controller.sample(n=1)

    # Fall back to returning a (possibly repeated) rollout instead of raising.
    controller.avoid_repeat_fallback = "return"
    for _ in range(3):
        controller.sample(n=1)