def test_rnn_diff_supernet_forward(rnn_diff_super_net):
    """Check output shapes of a differentiable RNN super-net forward pass.

    NOTE(review): a later definition with this exact name appears in the file
    and will shadow this one under pytest collection — confirm intent.
    """
    from aw_nas.controller import DiffController

    time_steps = 5
    batch_size = 2
    n_tokens = rnn_diff_super_net.num_tokens
    n_hid = rnn_diff_super_net.num_hid
    n_layers = rnn_diff_super_net._num_layers

    controller = DiffController(rnn_diff_super_net.search_space, "cuda")
    rollout = controller.sample(1)[0]
    cand_net = rnn_diff_super_net.assemble_candidate(rollout)

    # Initial hidden states; the forward pass updates them in place.
    hiddens = rnn_diff_super_net.init_hidden(batch_size)
    data = _rnn_data(time_steps, batch_size, n_tokens)
    logits, _, outs, next_hiddens = cand_net.forward_data(
        data[0], mode="eval", hiddens=hiddens)

    assert tuple(logits.shape) == (time_steps, batch_size, n_tokens)
    assert tuple(outs.shape) == (time_steps, batch_size, n_hid)
    assert len(next_hiddens) == n_layers
    # `hiddens` was modified in place, so it must now equal `next_hiddens`.
    assert (hiddens == next_hiddens).all()
def test_rnn_diff_supernet_to_arch(rnn_diff_super_net):
    """Controller alphas receive gradients only when detach_arch=False."""
    from aw_nas.controller import DiffController

    controller = DiffController(rnn_diff_super_net.search_space, "cuda")
    rollout = controller.sample(1)[0]
    cand_net = rnn_diff_super_net.assemble_candidate(rollout)

    time_steps = 5
    batch_size = 2
    n_tokens = rnn_diff_super_net.num_tokens
    data = _rnn_data(time_steps, batch_size, n_tokens)
    hiddens = rnn_diff_super_net.init_hidden(batch_size)

    # Default detach_arch=True: backward must not reach the controller params.
    results = cand_net.forward_data(data[0], hiddens=hiddens)
    loss = _rnn_criterion(data[0], results, data[1].cuda())
    assert controller.cg_alphas[0].grad is None
    loss.backward()
    assert controller.cg_alphas[0].grad is None

    # detach_arch=False: the same loss should backprop into the alphas.
    results = cand_net.forward_data(data[0], hiddens=hiddens, detach_arch=False)
    loss = _rnn_criterion(data[0], results, data[1].cuda())
    assert controller.cg_alphas[0].grad is None
    loss.backward()
    assert controller.cg_alphas[0].grad is not None
def test_diff_supernet_to_arch(diff_super_net):
    """CNN super-net: alphas get gradients only when the arch is not detached."""
    from torch import nn
    from aw_nas.common import get_search_space
    from aw_nas.controller import DiffController

    controller = DiffController(get_search_space(cls="cnn"), "cuda")
    cand_net = diff_super_net.assemble_candidate(controller.sample(1)[0])

    data = _cnn_data()  # pylint: disable=not-callable
    criterion = nn.CrossEntropyLoss()

    # detach_arch defaults to True: no gradient flows to the controller.
    logits = cand_net.forward_data(data[0])
    loss = criterion(logits, data[1].cuda())
    assert controller.cg_alphas[0].grad is None
    loss.backward()
    assert controller.cg_alphas[0].grad is None

    # detach_arch=False: the backward pass must populate alpha gradients.
    logits = cand_net.forward_data(data[0], detach_arch=False)
    loss = criterion(logits, data[1].cuda())
    assert controller.cg_alphas[0].grad is None
    loss.backward()
    assert controller.cg_alphas[0].grad is not None
def test_rnn_diff_supernet_forward(rnn_diff_super_net):
    """Shape checks for a diff RNN super-net forward; expected-fail on torch>=1.7.

    NOTE(review): an earlier test with this exact name exists in the file;
    under pytest collection this later definition shadows it — confirm intent.
    """
    # Bug fix: the previous guard compared only the minor version component
    # (`version.parse(torch.__version__).minor >= 7`), which misclassified
    # torch 2.x releases (minor 0..6) as older than 1.7 and skipped the xfail.
    # Comparing full Version objects handles 1.x and 2.x correctly.
    if version.parse(torch.__version__) >= version.parse("1.7"):
        pytest.xfail(
            "FIXME: We currently do not fix this bug yet. When using torch>=1.7.0, "
            "we encountered: Warning: Error detected in SplitBackward. "
            "RuntimeError: one of the variables needed for gradient computation "
            "has been modified by an inplace operation")
    from aw_nas.controller import DiffController

    time_steps = 5
    batch_size = 2
    _num_tokens = rnn_diff_super_net.num_tokens
    _num_hid = rnn_diff_super_net.num_hid
    _num_layers = rnn_diff_super_net._num_layers

    controller = DiffController(rnn_diff_super_net.search_space, "cuda")
    rollout = controller.sample(1)[0]
    cand_net = rnn_diff_super_net.assemble_candidate(rollout)

    # Initial hidden states; the forward pass updates them in place.
    hiddens = rnn_diff_super_net.init_hidden(batch_size)
    data = _rnn_data(time_steps, batch_size, _num_tokens)
    logits, _, outs, next_hiddens = cand_net.forward_data(
        data[0], mode="eval", hiddens=hiddens)

    assert tuple(logits.shape) == (time_steps, batch_size, _num_tokens)
    assert tuple(outs.shape) == (time_steps, batch_size, _num_hid)
    assert len(next_hiddens) == _num_layers
    # The value equals the calculated results; the hidden is modified in-place.
    assert (hiddens == next_hiddens).all()
def test_diff_controller():
    """A DiffController on the CNN space owns one (14, n_primitives) alpha set."""
    from aw_nas.controller import DiffController

    search_space = get_search_space(cls="cnn")
    controller = DiffController(search_space, "cuda")

    n_prims = len(search_space.shared_primitives)
    assert controller.cg_alphas[0].shape == (14, n_prims)

    rollouts = controller.sample(3)
    assert isinstance(rollouts[0].genotype, search_space.genotype_type)
def test_diff_controller_rollout_batch_size():
    """Sampling with batch_size=4 yields batched samples but unbatched logits."""
    # Removed the unused `import numpy as np` the original carried.
    from aw_nas.controller import DiffController

    search_space = get_search_space(cls="cnn")
    controller = DiffController(search_space, "cuda")
    n_prims = len(search_space.shared_primitives)

    rollout = controller.sample(1, batch_size=4)[0]
    # The sampled architecture carries the batch dimension; the logits do not.
    assert rollout.sampled[0].shape == (14, 4, n_prims)
    assert rollout.logits[0].shape == (14, n_prims)
    print(rollout.genotype)
def test_diff_supernet_forward(diff_super_net, controller_cfg):
    """Forward a sampled candidate through the CNN super-net and check logits."""
    from aw_nas.common import get_search_space
    from aw_nas.controller import DiffController

    controller = DiffController(
        get_search_space(cls="cnn"), "cuda", **controller_cfg)
    cand_net = diff_super_net.assemble_candidate(controller.sample(1)[0])

    data = _cnn_data()
    logits = cand_net.forward_data(data[0])
    assert tuple(logits.shape) == (2, 10)
def test_diff_controller_force_uniform():
    """force_uniform + use_prob makes every operation probability exactly uniform."""
    import numpy as np
    from aw_nas.controller import DiffController

    search_space = get_search_space(cls="cnn")
    controller = DiffController(
        search_space, "cuda", force_uniform=True, use_prob=True)
    rollouts = controller.sample(1)

    n_prims = len(search_space.shared_primitives)
    expected = np.full((14, n_prims), 1. / n_prims)
    assert np.equal(rollouts[0].sampled[0].data, expected).all()
def test_diff_supernet_data_parallel_forward_rolloutsize(diff_super_net):
    """A batch-9 rollout forwards a batch of 9 inputs to 10 logits each."""
    from aw_nas.common import get_search_space
    from aw_nas.controller import DiffController

    batch_size = 9
    controller = DiffController(get_search_space(cls="cnn"), "cuda")
    rollout = controller.sample(1, batch_size=batch_size)[0]
    cand_net = diff_super_net.assemble_candidate(rollout)

    data = _cnn_data(batch_size=batch_size)
    logits = cand_net.forward_data(data[0])
    assert tuple(logits.shape) == (batch_size, 10)
def test_diff_controller_use_prob():
    """With use_prob=True, the sampled arch equals the softmax of the logits."""
    import numpy as np
    from aw_nas import utils
    from aw_nas.controller import DiffController

    search_space = get_search_space(cls="cnn")
    controller = DiffController(search_space, "cuda", use_prob=True)
    assert controller.cg_alphas[0].shape == (
        14, len(search_space.shared_primitives))

    rollouts = controller.sample(3)
    sampled = utils.get_numpy(rollouts[0].sampled[0])
    probs = utils.softmax(rollouts[0].logits[0])
    assert np.abs(sampled - probs).mean() < 1e-6
    assert isinstance(rollouts[0].genotype, search_space.genotype_type)
def test_diff_controller_cellwise_num_steps():
    """Per-cell-group num_steps produce alphas sized by each group's edge count."""
    from aw_nas.controller import DiffController

    num_steps = [4, 6]
    search_space = get_search_space(
        cls="cnn", num_cell_groups=len(num_steps), num_steps=num_steps)
    controller = DiffController(search_space, "cuda")

    for i, num_step in enumerate(num_steps):
        # Edge count: internal edges among steps (n*(n-1)/2, always an even
        # product so `//` is exact) plus one edge per (init node, step) pair.
        # Integer `//` replaces the original float `/` so the shape comparison
        # is int-to-int rather than int-to-float.
        expected_edges = (num_step * (num_step - 1) // 2
                          + search_space.num_init_nodes * num_step)
        assert controller.cg_alphas[i].shape[0] == expected_edges

    rollout = controller.sample(1)[0]
    assert isinstance(rollout.genotype, search_space.genotype_type)
    print(rollout.genotype)
def test_diff_supernet_data_parallel_backward_rolloutsize(diff_super_net):
    """Backward with detach_arch=False fills alpha grads for a batched rollout."""
    from torch import nn
    from aw_nas.common import get_search_space
    from aw_nas.controller import DiffController

    batch_size = 9
    controller = DiffController(get_search_space(cls="cnn"), "cuda")
    rollout = controller.sample(1, batch_size=batch_size)[0]
    cand_net = diff_super_net.assemble_candidate(rollout)

    data = _cnn_data(batch_size=batch_size)
    logits = cand_net.forward_data(data[0], detach_arch=False)
    loss = nn.CrossEntropyLoss()(logits, data[1].cuda())

    # Grads appear on the controller alphas only after backward().
    assert controller.cg_alphas[0].grad is None
    loss.backward()
    assert controller.cg_alphas[0].grad is not None
def test_diff_controller_cellwise_primitives():
    """Per-cell-group primitive sets give per-group alpha widths and genotypes."""
    from aw_nas.controller import DiffController

    group0_prims = ("none", "avg_pool_3x3", "max_pool_3x3", "skip_connect")
    group1_prims = ("skip_connect", "avg_pool_3x3", "dil_conv_3x3")
    search_space = get_search_space(
        cls="cnn", num_cell_groups=2,
        cell_shared_primitives=[group0_prims, group1_prims])
    controller = DiffController(search_space, "cuda")

    # Alpha width follows each group's own primitive count (4 and 3).
    assert controller.cg_alphas[0].shape == (14, len(group0_prims))
    assert controller.cg_alphas[1].shape == (14, len(group1_prims))

    rollout = controller.sample(1)[0]
    assert isinstance(rollout.genotype, search_space.genotype_type)
    ops_normal = {conn[0] for conn in rollout.genotype.normal_0}
    assert ops_normal.issubset(group0_prims)
    ops_reduce = {conn[0] for conn in rollout.genotype.reduce_1}
    assert ops_reduce.issubset(group1_prims)
    print(rollout.genotype)