def test_population_controller_mutate():
    """Stepping a population controller once should yield a different genotype.

    For each (search space, rollout type) case, both an ``EvoController`` and
    a ``ParetoEvoController`` go through the same sequence:
    1. sample one rollout;
    2. assign it a random ``"reward"``;
    3. step the controller with it;
    4. sample again.
    The genotype string of the second sample must not equal the first one's.
    """
    from aw_nas.controller import EvoController, ParetoEvoController

    device = "cuda"

    def _check_sample_differs_after_step(ctrl):
        # Mode is set explicitly even though the controller is built with mode="train".
        ctrl.mode = "train"
        first_batch = ctrl.sample(1)
        for candidate in first_batch:
            candidate.set_perf(np.random.random(), "reward")
        ctrl.step(first_batch)
        follow_up = ctrl.sample(1)[0]
        assert str(first_batch[0].genotype) != str(follow_up.genotype)

    cases = (("cnn", "discrete"), ("ofa", "ofa"))
    for ss_name, r_type in cases:
        search_space = get_search_space(cls=ss_name)

        _check_sample_differs_after_step(
            EvoController(
                search_space,
                device,
                rollout_type=r_type,
                mode="train",
                population_size=1,
                parent_pool_size=1,
                mutate_kwargs={},
                avoid_mutate_repeat=True,
                avoid_mutate_repeat_worst_threshold=1,
                eval_sample_strategy="population",
                elimination_strategy="regularized",
            )
        )

        _check_sample_differs_after_step(
            ParetoEvoController(
                search_space,
                device,
                rollout_type=r_type,
                mode="train",
                init_population_size=1,
                perf_names=["reward"],
                eval_sample_strategy="all",
            )
        )
def test_population_controller_avoid_repeat():
    """Exercise ``EvoController``'s ``avoid_repeat_fallback`` options.

    Setup:
    - Build a small CNN search space.
    - Seed the controller with one rollout scored 1.0 and 99 mutants of it
      scored 0.3.
    - Pin ``population_size`` to the resulting population length.

    Expectations:
    - With ``avoid_repeat_fallback="raise"``, at least one of three
      ``sample(n=1)`` calls raises.
    - After switching to ``"return"``, three ``sample(n=1)`` calls complete
      without raising.
    """
    from aw_nas.controller import EvoController

    ss_cfgs = {
        "cell_layout": [0, 1, 0, 1, 0],
        "num_layers": 5,
        "shared_primitives": ["skip_connect", "sep_conv_3x3"],
        "num_init_nodes": 1,
        "num_steps": 3,
        "num_cell_groups": 2,
        "reduce_cell_groups": [1],
    }
    ss = get_search_space("cnn", **ss_cfgs)
    controller = EvoController(
        ss,
        "cuda",
        rollout_type="discrete",
        avoid_mutate_repeat=True,
        avoid_mutate_repeat_worst_threshold=3,
        avoid_repeat_fallback="raise",
        population_size=100,
        parent_pool_size=1,
    )

    rollout = controller.sample(n=1)[0]
    rollout.set_perf(1.0)  # make it the highest reward (mutants below get 0.3)
    controller.set_mode("train")

    # 99 mutants of the best rollout. The value returned by ``set_perf`` is
    # what gets collected and stepped (presumably the rollout itself).
    rollouts = [ss.mutate(rollout).set_perf(0.3) for _ in range(99)]
    controller.step([rollout] + rollouts)
    controller.population_size = len(controller.population)

    # NOTE(review): the concrete exception type raised by the controller is
    # not visible here, hence the broad ``Exception``.
    with pytest.raises(Exception):
        for _ in range(3):
            controller.sample(n=1)

    controller.avoid_repeat_fallback = "return"  # let's fallback to return, not raise
    for _ in range(3):
        controller.sample(n=1)