def test_ss():
    """Round-trip dense_rob genotypes through their string representation.

    Checks three things:
    * a randomly sampled genotype survives ``str`` -> ``genotype_from_str``;
    * the msrobnet-1560M genotype string parses under its 8-cell search space;
    * a genotype with a different cell list (5 cells, one init node) is
      rejected with ``TypeError`` by that same search space.
    """
    from aw_nas.common import get_search_space, genotype_from_str

    ss = get_search_space("dense_rob", cell_layout=[0, 0, 1, 0, 0, 1, 0, 0],
                          reduce_cell_groups=[1])
    rollout = ss.random_sample()
    print(rollout.genotype)
    # Smoke check: converting the genotype back into a rollout must not raise.
    ss.rollout_from_genotype(rollout.genotype)
    genotype_rec = genotype_from_str(str(rollout.genotype), ss)
    assert genotype_rec == rollout.genotype

    # test msrobnet-1560M and its search space
    ss = get_search_space(
        "dense_rob",
        num_cell_groups=8,
        num_init_nodes=2,
        cell_layout=[0, 1, 2, 3, 4, 5, 6, 7],
        reduce_cell_groups=[2, 5],
        num_steps=4,
        primitives=["none", "skip_connect", "sep_conv_3x3", "ResSepConv"],
    )
    msrobnet_genotype_str = (
        "DenseRobGenotype("
        "normal_0='init_node~2+|skip_connect~0|sep_conv_3x3~1|+|none~0|skip_connect~1|skip_connect~2|+|none~0|sep_conv_3x3~1|ResSepConv~2|skip_connect~3|+|skip_connect~0|sep_conv_3x3~1|sep_conv_3x3~2|sep_conv_3x3~3|sep_conv_3x3~4|', "
        "normal_1='init_node~2+|ResSepConv~0|sep_conv_3x3~1|+|none~0|sep_conv_3x3~1|skip_connect~2|+|ResSepConv~0|ResSepConv~1|ResSepConv~2|none~3|+|none~0|skip_connect~1|sep_conv_3x3~2|sep_conv_3x3~3|skip_connect~4|', "
        "reduce_2='init_node~2+|ResSepConv~0|skip_connect~1|+|sep_conv_3x3~0|none~1|none~2|+|skip_connect~0|ResSepConv~1|sep_conv_3x3~2|none~3|+|ResSepConv~0|skip_connect~1|ResSepConv~2|none~3|skip_connect~4|', "
        "normal_3='init_node~2+|skip_connect~0|skip_connect~1|+|ResSepConv~0|none~1|ResSepConv~2|+|ResSepConv~0|none~1|ResSepConv~2|skip_connect~3|+|none~0|sep_conv_3x3~1|none~2|skip_connect~3|sep_conv_3x3~4|', "
        "normal_4='init_node~2+|ResSepConv~0|sep_conv_3x3~1|+|skip_connect~0|skip_connect~1|none~2|+|sep_conv_3x3~0|skip_connect~1|sep_conv_3x3~2|sep_conv_3x3~3|+|ResSepConv~0|ResSepConv~1|none~2|skip_connect~3|sep_conv_3x3~4|', "
        "reduce_5='init_node~2+|sep_conv_3x3~0|sep_conv_3x3~1|+|ResSepConv~0|sep_conv_3x3~1|ResSepConv~2|+|none~0|sep_conv_3x3~1|ResSepConv~2|sep_conv_3x3~3|+|ResSepConv~0|skip_connect~1|skip_connect~2|skip_connect~3|sep_conv_3x3~4|', "
        "normal_6='init_node~2+|none~0|ResSepConv~1|+|none~0|sep_conv_3x3~1|ResSepConv~2|+|skip_connect~0|ResSepConv~1|ResSepConv~2|none~3|+|ResSepConv~0|skip_connect~1|ResSepConv~2|none~3|sep_conv_3x3~4|', "
        "normal_7='init_node~2+|none~0|none~1|+|none~0|ResSepConv~1|none~2|+|sep_conv_3x3~0|skip_connect~1|ResSepConv~2|none~3|+|ResSepConv~0|skip_connect~1|skip_connect~2|none~3|ResSepConv~4|')"
    )
    # Parsing must succeed; the parsed genotype itself is not inspected further.
    genotype_from_str(msrobnet_genotype_str, ss)

    # test a genotype does not fit for this search space
    # (5 cells named normal_0 .. normal_4 with a single init node).
    mismatched_genotype_str = (
        "DenseRobGenotype("
        "normal_0='init_node~1+|sep_conv_3x3~0|+|skip_connect~0|none~1|+|sep_conv_3x3~0|sep_conv_3x3~1|skip_connect~2|+|skip_connect~0|skip_connect~1|sep_conv_3x3~2|none~3|', "
        "reduce_1='init_node~1+|none~0|+|none~0|sep_conv_3x3~1|+|none~0|none~1|sep_conv_3x3~2|+|sep_conv_3x3~0|none~1|none~2|none~3|', "
        "normal_2='init_node~1+|sep_conv_3x3~0|+|sep_conv_3x3~0|skip_connect~1|+|skip_connect~0|none~1|skip_connect~2|+|skip_connect~0|sep_conv_3x3~1|sep_conv_3x3~2|skip_connect~3|', "
        "reduce_3='init_node~1+|skip_connect~0|+|skip_connect~0|skip_connect~1|+|none~0|sep_conv_3x3~1|sep_conv_3x3~2|+|skip_connect~0|skip_connect~1|sep_conv_3x3~2|skip_connect~3|', "
        "normal_4='init_node~1+|sep_conv_3x3~0|+|skip_connect~0|skip_connect~1|+|skip_connect~0|sep_conv_3x3~1|sep_conv_3x3~2|+|sep_conv_3x3~0|none~1|skip_connect~2|skip_connect~3|')"
    )
    # Only the parse call sits inside `raises`, so a TypeError from anywhere
    # else cannot make this check pass by accident.
    with pytest.raises(TypeError):
        genotype_from_str(mismatched_genotype_str, ss)
def test_gen_model(case):
    """Train a hardware performance model on profiled primitive latencies.

    ``case`` supplies ``search_space_cfg``, ``prof_prim_latencies``,
    ``hwperfmodel_type`` and ``hwperfmodel_cfg``. The "regression" and "mlp"
    model types depend on scikit-learn, and the test is marked xfail when the
    needed scikit-learn component cannot be imported. After training, every
    key of the model's ``_table`` must be a ``Prim``.
    """
    from aw_nas.common import get_search_space
    from aw_nas.hardware.base import MixinProfilingSearchSpace
    from aw_nas.hardware.utils import Prim

    ss = get_search_space("ofa_mixin", **case["search_space_cfg"])
    assert isinstance(ss, MixinProfilingSearchSpace)
    prof_prim_latencies = case["prof_prim_latencies"]

    if case["hwperfmodel_type"] == "regression":
        try:
            from sklearn import linear_model  # noqa: F401  (availability check only)
        except ImportError:
            pytest.xfail(
                "Package 'scikit-learn' not found, this test case should fail")
    if case["hwperfmodel_type"] == "mlp":
        try:
            # MLPRegressor lives in `sklearn.neural_network`. Importing it from
            # the top-level `sklearn` package always raises ImportError, which
            # used to turn every "mlp" case into an xfail even when
            # scikit-learn was installed.
            from sklearn.neural_network import MLPRegressor  # noqa: F401
        except ImportError:
            pytest.xfail(
                "Package 'scikit-learn' not found, this test case should fail")

    hwobj_model = ss.parse_profiling_primitives(case["hwperfmodel_type"],
                                                case["hwperfmodel_cfg"])
    hwobj_model.train(prof_prim_latencies)
    for prim in hwobj_model._table:
        assert isinstance(prim, Prim)
def test_hardware(case):
    """Predicted latency of a candidate must lie strictly between 0 and the profile sum.

    The upper bound is the summed ``latency`` of the primitives profiled for
    the first entry of ``case["prof_nets"]``. The "regression" objective type
    needs scikit-learn and is marked xfail when it is missing.
    """
    from collections import namedtuple
    from aw_nas.objective.hardware import HardwareObjective
    from aw_nas.common import get_search_space

    profiled_prims = case["prof_nets"][0][0]["primitives"]
    latency = [prim["performances"]["latency"] for prim in profiled_prims]

    ss = get_search_space("ofa_mixin", **case["search_space_cfg"])

    needs_sklearn = case["hardware_obj_type"] == "regression"
    if needs_sklearn:
        try:
            from sklearn import linear_model
        except ImportError:
            pytest.xfail("Do not install scikit-learn, this should fail")

    objective = HardwareObjective(
        ss,
        case["prof_prims_cfg"],
        case["hardware_obj_type"],
        case["hardware_obj_cfg"],
    )
    objective.hwobj_models[0].train(case["prof_nets"])

    # get_perfs only needs an object exposing `.rollout`, so a namedtuple stands in.
    FakeCandNet = namedtuple("cand_net", ["rollout"])
    fake_cand_net = FakeCandNet(ss.rollout_from_genotype(case["genotypes"]))
    perfs = objective.get_perfs(None, None, None, fake_cand_net)
    assert 0 < perfs[0] < sum(latency)
def test_rob_weights_manager():
    """A dense_rob supernet candidate's state dict loads into the final model.

    The candidate's parameter keys contain a ``p_ops.<num>.`` segment that the
    final model does not have; it is stripped before ``load_state_dict``.
    """
    import re
    from aw_nas.common import get_search_space
    from aw_nas.weights_manager.base import BaseWeightsManager
    from aw_nas.final.base import FinalModel

    ss = get_search_space(
        "dense_rob",
        cell_layout=[0, 1, 2, 3, 4, 5],
        num_cell_groups=6,
        reduce_cell_groups=[1, 3],
    )
    wm = BaseWeightsManager.get_class_("dense_rob_wm")(ss, "cuda")
    rollout = ss.random_sample()
    cand_net = wm.assemble_candidate(rollout)
    print("len parameters, all supernet params: ",
          len(list(cand_net.named_parameters())))
    state_dict = cand_net.state_dict()
    print("partial statedict:", len(state_dict))

    geno_str = str(rollout.genotype)
    model = FinalModel.get_class_("dense_rob_final_model")(ss, "cuda", geno_str)
    # remove `p_ops.<num>.`
    # Raw string: in a plain literal "\." and "\d" are invalid escape sequences
    # (DeprecationWarning, SyntaxWarning since Python 3.12). The pattern is unchanged.
    p_ops_segment = re.compile(r"p_ops\.\d+\.")
    final_state_dict = {
        p_ops_segment.sub("", key): value
        for key, value in state_dict.items()
    }
    model.load_state_dict(final_state_dict)
def test_mutate_and_evo():
    """Mutate a nasbench-201 rollout, then drive a small EvoController.

    Two train-mode steps are run with random rewards, then samples are drawn
    in eval mode with the default strategy and with ``"all"``.
    """
    from aw_nas.common import get_search_space
    from aw_nas.controller.evo import EvoController

    search_space = get_search_space(cls="nasbench-201", load_nasbench=False)
    parent = search_space.random_sample()
    child = search_space.mutate(parent)
    print("before mutate: ", parent)
    print("after mutate: ", child)

    controller = EvoController(
        search_space,
        device="cpu",
        rollout_type="nasbench-201",
        mode="train",
        population_size=4,
        parent_pool_size=2,
    )

    # random sample 4
    initial = controller.sample(4)
    for sampled in initial:
        sampled.perf["reward"] = np.random.rand()
    controller.step(initial)

    offspring = controller.sample(2)
    print(offspring)
    for sampled in offspring:
        sampled.perf["reward"] = np.random.rand()
    controller.step(offspring)

    with controller.begin_mode("eval"):
        picked = controller.sample(2)
        seen_rewards = [r.perf["reward"] for r in initial + offspring]
        print("all rollout rewards ever seen: ", seen_rewards)
        print("eval sample (population): ", [r.perf["reward"] for r in picked])

        controller.eval_sample_strategy = "all"
        picked = controller.sample(2)
        print("eval sample (all): ", [r.perf["reward"] for r in picked])
def test_arch_comparator(case):
    """Fit a PointwiseComparator on fixed target scores using ``case["method"]``.

    ``case["method"]`` selects the update API:
    * "predict": regress the target scores of all archs;
    * "compare": pairwise labels "score of archs_2[i] > score of archs_1[i]";
    * "argsort": a descending-score ranking of all archs.
    The sizes of ``predict`` and ``compare`` outputs are asserted.

    Raises:
        ValueError: if ``case["method"]`` is none of the above.
    """
    from aw_nas.common import get_search_space
    from aw_nas.evaluator.arch_network import PointwiseComparator

    search_space = get_search_space(cls="cnn")
    device = "cuda"
    comparator = PointwiseComparator(search_space)
    comparator.to(device)
    batch_size = 4
    method = case["method"]
    archs_1 = [search_space.random_sample().arch for _ in range(batch_size)]
    archs_2 = [search_space.random_sample().arch for _ in range(batch_size)]

    # forward
    # 2 * batch_size evenly spaced targets in [0, 1]. np.linspace fixes the
    # element count exactly. The previous np.arange(0, 1.01, 1/7) relied on
    # float-step rounding plus a fudged stop value to produce 8 elements.
    true_scores = np.linspace(0., 1., 2 * batch_size)
    scores = comparator.predict(archs_1 + archs_2)
    print("true scores:", true_scores)
    print("scores before {}:".format(method), scores)
    assert len(scores) == batch_size * 2
    compare_res = comparator.compare(archs_1, archs_2)
    assert len(compare_res) == batch_size

    # update
    for _ in range(5):
        if method == "predict":
            comparator.update_predict(archs_1 + archs_2, true_scores)
        elif method == "compare":
            comparator.update_compare(
                archs_1, archs_2,
                true_scores[batch_size:] > true_scores[:batch_size])
        elif method == "argsort":
            comparator.update_argsort([archs_1 + archs_2],
                                      [np.argsort(true_scores)[::-1]])
        else:
            # Previously an unknown method silently skipped every update.
            raise ValueError("Unknown comparator update method: {}".format(method))
    scores = comparator.predict(archs_1 + archs_2)
    print("scores after {}:".format(method), scores)
def test_pairwise_arch_comparator(case):
    """Train a PairwiseComparator for 20 steps on fixed pairwise labels.

    ``case`` is forwarded as keyword arguments to ``PairwiseComparator``.
    Before and after training, the argsort of all 2 * batch_size archs and the
    pairwise comparison of ``archs_1`` vs ``archs_2`` are printed, and the
    comparison length is checked.
    """
    from aw_nas.common import get_search_space
    from aw_nas.evaluator.arch_network import PairwiseComparator

    search_space = get_search_space(cls="cnn")
    comparator = PairwiseComparator(search_space, **case)
    comparator.to("cuda")

    batch_size = 4
    archs_1 = [search_space.random_sample().arch for _ in range(batch_size)]
    archs_2 = [search_space.random_sample().arch for _ in range(batch_size)]

    def _probe(header):
        # forward: print the ranking of all archs, then the pairwise result.
        print(comparator.argsort_list(archs_1 + archs_2, batch_size=2))
        result = comparator.compare(archs_1, archs_2)
        assert len(result) == batch_size
        print(header, result)

    _probe("[before] compare res: ")
    # update
    for _ in range(20):
        comparator.update_compare(archs_1, archs_2, [0, 1, 1, 1])
    _probe("[after] compare res: ")
def test_population_controller_avoid_repeat():
    """EvoController's repeat-avoidance either raises or falls back to returning.

    One sampled rollout gets the highest reward (1.0) and 99 mutations of it
    get 0.3. With ``avoid_repeat_fallback="raise"``, some of three samples is
    expected to raise. After switching the fallback to ``"return"``, three
    samples must succeed.
    """
    # Local import, matching the other tests, instead of relying on a module-level name.
    from aw_nas.common import get_search_space
    from aw_nas.controller import EvoController

    ss_cfgs = {
        "cell_layout": [0, 1, 0, 1, 0],
        "num_layers": 5,
        "shared_primitives": ["skip_connect", "sep_conv_3x3"],
        "num_init_nodes": 1,
        "num_steps": 3,
        "num_cell_groups": 2,
        "reduce_cell_groups": [1],
    }
    ss = get_search_space("cnn", **ss_cfgs)
    controller = EvoController(ss, "cuda", rollout_type="discrete",
                               avoid_mutate_repeat=True,
                               avoid_mutate_repeat_worst_threshold=3,
                               avoid_repeat_fallback="raise",
                               population_size=100, parent_pool_size=1)
    rollout = controller.sample(n=1)[0]
    rollout.set_perf(1.0)  # make it the highest reward
    controller.set_mode("train")
    # NOTE(review): assumes `set_perf` returns the rollout itself, so the
    # list holds the mutated rollouts — confirm.
    rollouts = [ss.mutate(rollout).set_perf(0.3) for _ in range(99)]
    controller.step([rollout] + rollouts)
    controller.population_size = len(controller.population)

    with pytest.raises(Exception):
        for _ in range(3):
            controller.sample(n=1)

    controller.avoid_repeat_fallback = "return"  # let's fallback to return, not raise
    for _ in range(3):
        controller.sample(n=1)
def test_layer2_ss(tmp_path):
    """Round-trip a layer2 rollout through its genotype and plot it.

    The rollout rebuilt via ``ss.rollout_from_genotype`` and via
    ``rollout_from_genotype_str`` must equal the sampled one. The plot is
    written under ``tmp_path``.
    """
    from aw_nas.common import get_search_space, genotype_from_str, rollout_from_genotype_str

    macro_cfg = {
        "num_cell_groups": 2,
        "cell_layout": [0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0],
        "width_choice": [0.25, 0.5, 0.75, 1.0],
        "reduce_cell_groups": [1],
    }
    micro_cfg = {"num_cell_groups": 2, "num_steps": 4}
    ss = get_search_space("layer2",
                          macro_search_space_cfg=macro_cfg,
                          micro_search_space_cfg=micro_cfg)

    rollout = ss.random_sample()
    print(rollout.genotype)
    assert ss.rollout_from_genotype(rollout.genotype) == rollout

    # genotype from str: parsing must succeed (the result is not compared)
    genotype_from_str(str(rollout.genotype), ss)
    assert rollout_from_genotype_str(str(rollout.genotype), ss) == rollout

    # plot
    path = os.path.join(str(tmp_path), "layer2")
    rollout.plot_arch(path, label="layer2 rollout example")
    print("Plot save to path: ", path)
def test_tfnas_macro_controller():
    """Sample from a Layer2DiffController and backprop through a macro arch tensor.

    The controller is built in eval mode with the "macro-sink-connect-diff"
    and "micro-dense-diff" sub-controllers. Three rollouts are sampled, and
    ``.sum().backward()`` on the first rollout's first macro arch entry must
    not raise.
    """
    from aw_nas.common import get_search_space
    from aw_nas.btcs.layer2.controller import Layer2DiffController

    ss = get_search_space(
        "layer2",
        macro_search_space_cfg={
            "num_cell_groups": 2,
            "cell_layout": [0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0],
            "width_choice": [0.25, 0.5, 0.75, 1.],
            "reduce_cell_groups": [1],
        },
        micro_search_space_cfg={
            "num_cell_groups": 2,
            "num_steps": 4,
        },
    )
    macro_controller_cfg = {}
    micro_controller_cfg = {
        # "use_edge_normalization": True,
    }
    # NOTE(review): "layer2" is passed as the second positional argument,
    # presumably the rollout type — confirm against Layer2DiffController.
    controller = Layer2DiffController(
        ss, "layer2",
        mode="eval",
        macro_controller_type="macro-sink-connect-diff",
        micro_controller_type="micro-dense-diff",
        macro_controller_cfg=macro_controller_cfg,
        micro_controller_cfg=micro_controller_cfg,
    )
    rollouts = controller.sample(3)
    rollouts[0].macro.arch[0].sum().backward()
def test_layer2_final_model(genotype_str, tmp_path):
    """Build a MacroSinkConnectFinalModel and check its output has 10 classes.

    Uses ``genotype_str`` when provided, otherwise a randomly sampled rollout.
    ``tmp_path`` is accepted but not used.
    """
    from aw_nas.common import get_search_space, rollout_from_genotype_str
    from aw_nas.btcs.layer2.final_model import MacroSinkConnectFinalModel

    macro_cfg = {
        "num_cell_groups": 2,
        "cell_layout": [0, 0, 1, 0, 0, 1, 0, 0],
        "width_choice": [0.25, 0.5, 0.75, 1.0],
        "reduce_cell_groups": [1],
    }
    micro_cfg = {"num_cell_groups": 2, "num_steps": 4}
    ss = get_search_space("layer2",
                          macro_search_space_cfg=macro_cfg,
                          micro_search_space_cfg=micro_cfg)

    if genotype_str is not None:
        rollout = rollout_from_genotype_str(genotype_str, ss)
    else:
        rollout = ss.random_sample()

    final_model = MacroSinkConnectFinalModel(
        ss,
        "cuda",
        str(rollout.genotype),
        micro_model_type="micro-dense-model",
        micro_model_cfg={"process_op_type": "nor_conv_1x1"},
        init_channels=12,
        use_stem="conv_bn_3x3",
    )
    batch = _cnn_data(device="cuda", batch_size=2)
    logits = final_model(batch[0])
    assert logits.shape[-1] == 10
def test_derive_and_parse_derive():
    """Dump rollouts with ``_dump_with_perf`` and parse them back.

    Six rollouts are dumped. The first three are written with an explicit
    index and the rest without one. Only the first four carry ``perf`` data,
    and exactly those four must appear in the parsed result.
    """
    import io
    import numpy as np
    from aw_nas.utils.common_utils import _parse_derive_file, _dump_with_perf
    from aw_nas.common import get_search_space

    ss = get_search_space("cnn")
    rollouts = [ss.random_sample() for _ in range(6)]
    for with_perf in rollouts[:4]:
        with_perf.perf = {
            "reward": np.random.rand(),
            "other_perf": np.random.rand(),
        }

    buffer = io.StringIO()
    for idx, item in enumerate(rollouts):
        if idx < 3:
            _dump_with_perf(item, "str", buffer, index=idx)
        else:
            _dump_with_perf(item, "str", buffer)

    parsed = _parse_derive_file(io.StringIO(buffer.getvalue()))
    assert len(parsed) == 4  # only 4 rollouts have performance information
    print(parsed)
def test_diff_supernet_to_arch(diff_super_net):
    """Controller alphas receive gradients only when ``detach_arch=False``.

    With the default ``forward_data`` call, backprop leaves
    ``cg_alphas[0].grad`` as None. With ``detach_arch=False``, it is populated.
    """
    from torch import nn
    from aw_nas.common import get_search_space
    from aw_nas.controller import DiffController

    search_space = get_search_space(cls="cnn")
    controller = DiffController(search_space, "cuda")
    cand_net = diff_super_net.assemble_candidate(controller.sample(1)[0])
    data = _cnn_data()  # pylint: disable=not-callable
    criterion = nn.CrossEntropyLoss()
    alphas = controller.cg_alphas[0]

    # default detach_arch=True, no grad w.r.t the controller param
    loss = criterion(cand_net.forward_data(data[0]), data[1].cuda())
    assert alphas.grad is None
    loss.backward()
    assert alphas.grad is None

    # detach_arch=False lets the gradient flow back to the alphas
    loss = criterion(cand_net.forward_data(data[0], detach_arch=False), data[1].cuda())
    assert alphas.grad is None
    loss.backward()
    assert alphas.grad is not None
def test_embedder(case):
    """Embed two hand-written NAS-Bench-101 architectures and print the result.

    ``case["embedder_type"]`` selects the ``ArchEmbedder`` subclass, and the
    optional ``case["embedder_cfg"]`` is passed to its constructor. Each arch
    is a ``(7x7 int8 adjacency matrix, op list)`` tuple.

    NOTE(review): arch_1 lists 5 op indices directly, while arch_2 converts a
    7-name list that includes 'input'/'output' via ``op_to_idx``. Presumably
    ``op_to_idx`` handles those endpoints — confirm.
    """
    from aw_nas.evaluator.arch_network import ArchEmbedder
    from aw_nas.common import get_search_space

    nasbench_search_space = get_search_space("nasbench-101", load_nasbench=False)
    embedder_cls = ArchEmbedder.get_class_(case["embedder_type"])
    embedder = embedder_cls(nasbench_search_space, **case.get("embedder_cfg", {}))
    embedder.to("cuda")

    adjacency_1 = np.array(
        [[0, 1, 0, 0, 1, 1, 0],
         [0, 0, 1, 0, 0, 0, 0],
         [0, 0, 0, 1, 0, 0, 1],
         [0, 0, 0, 0, 0, 1, 0],
         [0, 0, 0, 0, 0, 1, 0],
         [0, 0, 0, 0, 0, 0, 1],
         [0, 0, 0, 0, 0, 0, 0]],
        dtype=np.int8)
    arch_1 = (adjacency_1, [1, 2, 1, 1, 0])

    adjacency_2 = np.array(
        [[0, 1, 0, 1, 1, 0, 0],
         [0, 0, 1, 0, 0, 0, 1],
         [0, 0, 0, 1, 0, 0, 0],
         [0, 0, 0, 0, 1, 0, 1],
         [0, 0, 0, 0, 0, 0, 1],
         [0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0]],
        dtype=np.int8)
    ops_2 = nasbench_search_space.op_to_idx([
        'input', 'conv3x3-bn-relu', 'conv3x3-bn-relu', 'maxpool3x3',
        'conv3x3-bn-relu', 'none', 'output'
    ])
    arch_2 = (adjacency_2, ops_2)

    print(arch_1)
    print(arch_2)
    print(embedder.forward([arch_1, arch_2]))
def test_controller_network_cellwise_num_steps(case):
    """Per-cell RL controller nets honour a cellwise ``num_steps``.

    With ``num_steps=[3, 6]``, the net for cell group 0 samples
    ``num_node_inputs * 3`` node-input actions and the net for cell group 1
    samples ``num_node_inputs * 6``. Shapes of the LSTM state, log-probs and
    entropies are checked for cell group 0. A shared net (``cell_index=None``)
    must raise ``NasException``.
    """
    # Local import, matching the other tests, instead of relying on a module-level name.
    from aw_nas.common import get_search_space
    from aw_nas.controller.rl_networks import BaseRLControllerNet
    from aw_nas.utils.exception import NasException

    search_space = get_search_space(cls="cnn", num_cell_groups=2, num_steps=[3, 6])
    device = "cuda"
    cls = BaseRLControllerNet.get_class_(case["type"])
    net0 = cls(search_space, device, cell_index=0)
    net1 = cls(search_space, device, cell_index=1)

    batch_size = 3
    arch, log_probs, entropies, (hx, cx) = net0.sample(batch_size)
    assert len(hx) == net0.num_lstm_layers
    assert len(cx) == net0.num_lstm_layers
    assert hx[0].shape == (batch_size, net0.controller_hid)
    assert len(arch) == batch_size
    num_actions = len(arch[0][0]) + len(arch[0][1])
    assert log_probs.shape == (batch_size, num_actions)
    assert entropies.shape == (batch_size, num_actions)
    assert len(arch[0][0]) == search_space.num_node_inputs * 3

    arch, _, _, (_, _) = net1.sample(batch_size)
    assert len(arch[0][0]) == search_space.num_node_inputs * 6

    with pytest.raises(NasException):
        # cannot use shared network when search space have cellwise `num_steps`
        cls(search_space, device, cell_index=None)
def test_dense_rollout(case):
    """The cnn_dense genotype for ``case["arch"]`` has the expected stem and transitions.

    ``case["trans"][i]`` is compared against the genotype's ``transition_<i>``
    field.
    """
    from aw_nas.common import get_search_space

    arch = case["arch"]
    ss = get_search_space("cnn_dense", num_dense_blocks=len(arch))
    genotype = ss.genotype(arch)
    assert genotype.stem == case["stem"]

    expected_transitions = {
        "transition_{}".format(idx): value
        for idx, value in enumerate(case["trans"])
    }
    for field_name, value in expected_transitions.items():
        assert getattr(genotype, field_name) == value
def test_ss_plot(tmp_path):
    """Plot dense_rob cells via ``Rollout.plot_arch`` and via ``plot_genotype``."""
    from aw_nas.common import get_search_space, plot_genotype

    ss_cfgs = {
        "cell_layout": [0, 1, 2, 3, 4, 5],
        "num_init_nodes": 2,
        "num_cell_groups": 6,
        "reduce_cell_groups": [1, 3],
    }
    ss = get_search_space("dense_rob", **ss_cfgs)
    out_dir = str(tmp_path)

    # 1) plot through the rollout object
    first = ss.random_sample()
    path = os.path.join(out_dir, "cell")
    first.plot_arch(path, label="robnas cell example")
    print(first.genotype)
    print("Plot save to path: ", path)

    # 2) plot from the genotype string with the search-space config
    second = ss.random_sample()
    path_2 = os.path.join(out_dir, "cell_2")
    plot_genotype(str(second.genotype), dest=path_2, cls="dense_rob",
                  label="robnas cell example", **ss_cfgs)
    print(second.genotype)
    print("Plot save to path: ", path_2)
def test_ss():
    """Smoke test: sample a dense_rob rollout and rebuild it from its genotype.

    NOTE(review): another, longer ``test_ss`` also appears in this listing.
    If both live in the same module, the later definition shadows the earlier
    one and pytest collects only one of them — confirm, and rename if so.
    """
    from aw_nas.common import get_search_space

    ss = get_search_space("dense_rob",
                          cell_layout=[0, 0, 1, 0, 0, 1, 0, 0],
                          reduce_cell_groups=[1])
    sampled = ss.random_sample()
    print(sampled.genotype)
    # only checks that the conversion does not raise
    ss.rollout_from_genotype(sampled.genotype)
def diff_super_net(request):
    """Build a ``DiffSuperNet`` on a "cnn" search space on "cuda".

    When parametrized indirectly, ``request.param`` is a dict of
    ``DiffSuperNet`` keyword arguments plus an optional ``search_space_cfg``
    sub-dict that is forwarded to ``get_search_space``.
    """
    from aw_nas.common import get_search_space
    from aw_nas.weights_manager import DiffSuperNet

    # Copy before popping. `request.param` is the very object stored in the
    # parametrize list, so popping from it directly would strip
    # `search_space_cfg` for any later use of the same parameter.
    cfg = dict(getattr(request, "param", {}))
    scfg = cfg.pop("search_space_cfg", {})
    search_space = get_search_space(cls="cnn", **scfg)
    device = "cuda"
    return DiffSuperNet(search_space, device, **cfg)
def ofa_super_net(request):
    """Build an ``OFASupernet`` (rollout type "ofa") on an "ofa" search space on "cuda".

    When parametrized indirectly, ``request.param`` is a dict of
    ``OFASupernet`` keyword arguments plus an optional ``search_space_cfg``
    sub-dict that is forwarded to ``get_search_space``.
    """
    from aw_nas.common import get_search_space
    from aw_nas.weights_manager import OFASupernet

    # Copy before popping so the shared parametrize dict is not mutated.
    cfg = dict(getattr(request, "param", {}))
    scfg = cfg.pop("search_space_cfg", {})
    search_space = get_search_space(cls="ofa", **scfg)
    device = "cuda"
    return OFASupernet(search_space, device, rollout_type="ofa", **cfg)
def morphism(request):
    """Build a ``MorphismWeightsManager`` ("mutation") on a "cnn" search space on "cuda".

    When parametrized indirectly, ``request.param`` may contain a
    ``search_space_cfg`` sub-dict forwarded to ``get_search_space``. Any other
    keys are ignored.
    """
    from aw_nas.common import get_search_space
    from aw_nas.weights_manager import MorphismWeightsManager

    # Copy before popping so the shared parametrize dict is not mutated.
    cfg = dict(getattr(request, "param", {}))
    scfg = cfg.pop("search_space_cfg", {})
    search_space = get_search_space(cls="cnn", **scfg)
    device = "cuda"
    return MorphismWeightsManager(search_space, device, "mutation")
def rnn_diff_super_net(request):
    """Build an ``RNNDiffSuperNet`` on the default "rnn" search space on "cuda".

    When parametrized indirectly, ``request.param`` may set ``num_tokens``
    (default 10). The remaining keys are passed as ``RNNDiffSuperNet``
    keyword arguments.
    """
    from aw_nas.common import get_search_space
    from aw_nas.weights_manager import RNNDiffSuperNet

    # Copy before popping so the shared parametrize dict keeps `num_tokens`.
    cfg = dict(getattr(request, "param", {}))
    num_tokens = cfg.pop("num_tokens", 10)
    search_space = get_search_space(cls="rnn")
    device = "cuda"
    return RNNDiffSuperNet(search_space, device, num_tokens, **cfg)
def nasbench_201(request):
    """Build an ``NB201SharedNet`` on a "nasbench-201" search space on "cuda".

    When parametrized indirectly, ``request.param`` is a dict of
    ``NB201SharedNet`` keyword arguments plus an optional ``search_space_cfg``
    sub-dict that is forwarded to ``get_search_space``.
    """
    from aw_nas.common import get_search_space
    from aw_nas.btcs.nasbench_201 import NB201SharedNet

    # Copy before popping so the shared parametrize dict is not mutated.
    cfg = dict(getattr(request, "param", {}))
    scfg = cfg.pop("search_space_cfg", {})
    search_space = get_search_space(cls="nasbench-201", **scfg)
    device = "cuda"
    return NB201SharedNet(search_space, device, **cfg)
def init_from_dirs(cls, dirs, search_space=None, cfg_template_file=None):
    """
    Init population from directories.

    Args:
        dirs: [directory paths]
        search_space: SearchSpace. If None, it is built from the template's
            "search_space_type" and "search_space_cfg" entries.
        cfg_template_file: if not specified, default: "template.yaml" under `dirs[0]`

    Returns:
        Population

    There should be multiple meta-info (yaml) files named as "`<number>.yaml`"
    under each directory. Each of them specifies the meta information for a
    model in the population, with `<number>` representing its index.
    Note there should not be duplicate indices. If there are duplicate indices,
    rename or soft-link the files.

    In each meta-info file, the possible meta information fields are:
    * genotype
    * train_config
    * checkpoint_path
    * (optional) confidence
    * perfs: a dict of performance name to performance value

    "template.yaml" under the first dir will be used as the template training
    config for new candidate models. Any file whose basename contains
    "template.yaml" is skipped when collecting meta-info files. Within each
    directory, meta-info files are read in sorted path order, so the record
    order is deterministic.

    Raises:
        AssertionError: if `dirs` is empty.
        ValueError: if a meta-info file's basename (without extension) is not
            an integer.
        Duplicate indices are rejected via `expect`.
    """
    assert dirs, "No dirs specified!"
    if cfg_template_file is None:
        cfg_template_file = os.path.join(dirs[0], "template.yaml")
    with open(cfg_template_file, "r") as cfg_f:
        cfg_template = ConfigTemplate(yaml.safe_load(cfg_f))
    _logger.getChild("population").info("Read the template config from %s",
                                        cfg_template_file)
    model_records = collections.OrderedDict()
    if search_space is None:
        # assume can parse search space from config template
        from aw_nas.common import get_search_space
        search_space = get_search_space(cfg_template["search_space_type"],
                                        **cfg_template["search_space_cfg"])
    for dir_ in dirs:
        # glob order is arbitrary and filesystem-dependent; sort for a stable order.
        for fname in sorted(glob.glob(os.path.join(dir_, "*.yaml"))):
            basename = os.path.basename(fname)
            # Check the basename only, so a directory path that happens to
            # contain "template.yaml" does not hide every file inside it.
            if "template.yaml" in basename:
                # do not parse template.yaml
                continue
            index = int(basename.rsplit(".", 1)[0])
            expect(
                index not in model_records,
                "There are duplicate index: {}. rename or soft-link the files"
                .format(index))
            model_records[index] = ModelRecord.init_from_file(
                fname, search_space)
    _logger.getChild("population").info(
        "Parsed %d directories, total %d model records loaded.", len(dirs),
        len(model_records))
    return Population(search_space, model_records, cfg_template)
def random_derive(cfg, seed, out_file, n):
    """Randomly sample ``n`` genotypes from the configured search space into a YAML file.

    Args:
        cfg: path to a YAML config containing ``search_space_type`` and
            ``search_space_cfg``.
        seed: random seed. The original code accepted it but never used it, so
            runs were not reproducible. When not None, it now seeds Python's
            ``random`` module and NumPy's global RNG before sampling.
        out_file: path of the YAML file to (over)write. Each genotype string is
            dumped as a single-item YAML list, so the items form one sequence.
        n: number of genotypes to sample.

    Returns:
        out_file
    """
    import random

    import numpy as np

    with open(cfg, "r") as cfg_f:
        cfg_dct = yaml.safe_load(cfg_f)
    from aw_nas.common import get_search_space
    ss = get_search_space(cfg_dct["search_space_type"], **cfg_dct["search_space_cfg"])

    if seed is not None:
        # NOTE(review): assumes `ss.random_sample` draws from `random`/NumPy's
        # global RNG; also seed torch if some search space samples with it.
        random.seed(seed)
        np.random.seed(seed)

    with open(out_file, "w") as of:
        for _ in range(n):
            rollout = ss.random_sample()
            yaml.safe_dump([str(rollout.genotype)], of)
    return out_file
def test_rollout_from_genotype_str(case):
    """Round-trip a rollout through its genotype string.

    ``case`` holds ``get_search_space`` keyword arguments plus an optional
    ``genotype_str``. When ``genotype_str`` is given, it is parsed into the
    reference rollout; otherwise a random rollout is sampled. The reference
    rollout's genotype string must parse back to the same ``arch``.
    """
    from aw_nas.common import get_search_space, rollout_from_genotype_str

    # Work on a copy so popping does not mutate the shared parametrized dict.
    case = dict(case)
    genotype_str = case.pop("genotype_str", None)
    ss = get_search_space(**case)
    if genotype_str:
        # Previously `rollout` was never bound on this branch, so the final
        # assertion raised NameError.
        rollout = rollout_from_genotype_str(genotype_str, ss)
    else:
        rollout = ss.random_sample()
    rec_rollout = rollout_from_genotype_str(str(rollout.genotype), ss)
    assert np.all(np.array(rec_rollout.arch) == np.array(rollout.arch))
def test_plot_arch(tmp_path):
    """Plot a hand-written NAS-Bench-201 cell to a file under ``tmp_path``."""
    from aw_nas.common import get_search_space
    from aw_nas.btcs.nasbench_201 import NasBench201Rollout

    nasbench_ss = get_search_space("nasbench-201", load_nasbench=False)
    prefix = os.path.join(str(tmp_path), "nb201-cell")

    # 4x4 strictly lower-triangular matrix; nonzero entries presumably encode
    # the op on edge (row <- column) — confirm against NasBench201Rollout.
    cell_matrix = np.array([
        [0., 0., 0., 0.],
        [4., 0., 0., 0.],
        [2., 4., 0., 0.],
        [0., 0., 2., 0.],
    ])
    cell_rollout = NasBench201Rollout(cell_matrix, search_space=nasbench_ss)
    print("genotype: ", cell_rollout.genotype, "save to: ", prefix)
    cell_rollout.plot_arch(prefix, label="test plot")
def test_diff_controller():
    """DiffController's alphas have the expected shape and its rollouts carry genotypes.

    ``cg_alphas[0]`` must be ``(14, len(shared_primitives))`` for the default
    "cnn" search space, and sampled rollouts must expose a genotype of the
    search space's ``genotype_type``.
    """
    # Local import, matching the other tests, instead of relying on a module-level name.
    from aw_nas.common import get_search_space
    from aw_nas.controller import DiffController

    search_space = get_search_space(cls="cnn")
    device = "cuda"
    controller = DiffController(search_space, device)
    # NOTE(review): 14 is presumably the number of candidate edges in the
    # default cnn cell — confirm.
    assert controller.cg_alphas[0].shape == (
        14, len(search_space.shared_primitives))
    rollouts = controller.sample(3)
    assert isinstance(rollouts[0].genotype, search_space.genotype_type)
def test_dense_mutation_rollout(case):
    """The parent cnn_dense genotype for ``case["arch"]`` has the expected stem/transitions.

    ``case`` may carry a ``search_space_cfg`` dict for ``get_search_space``.
    ``case["trans"][i]`` is compared against the genotype's
    ``transition_<i>`` field.
    """
    from aw_nas.common import get_search_space
    from aw_nas.rollout.mutation import ModelRecord
    from aw_nas.rollout.dense import DenseMutation, DenseMutationRollout

    # Work on a copy so popping does not mutate the shared parametrized dict.
    case = dict(case)
    ss_cfg = case.pop("search_space_cfg", {})
    search_space = get_search_space("cnn_dense", **ss_cfg)
    par_genotype = search_space.genotype(case["arch"])
    assert par_genotype.stem == case["stem"]
    for i, trans_c in enumerate(case["trans"]):
        assert getattr(par_genotype, "transition_{}".format(i)) == trans_c
def test_construct_final_densenet():
    """Build a DenseGenotypeModel from a cnn_dense genotype and check its 10-way output."""
    from aw_nas.common import get_search_space
    from aw_nas.final.dense import DenseGenotypeModel

    # Four dense blocks of length 6, 12, 24 and 6; every entry is 4
    # (presumably the per-layer growth — confirm).
    arch = [[4] * block_len for block_len in (6, 12, 24, 6)]
    ss = get_search_space("cnn_dense", num_dense_blocks=len(arch))
    model = DenseGenotypeModel(ss, torch.device("cuda"),
                               genotypes=str(ss.genotype(arch)))
    batch = _cnn_data()
    logits = model(batch[0])
    assert logits.shape[-1] == 10