Example 1
0
def test_rnn_diff_supernet_forward(rnn_diff_super_net):
    """Forward one candidate of the RNN diff supernet and check output shapes."""
    from aw_nas.controller import DiffController

    n_steps = 5
    n_batch = 2
    n_tokens = rnn_diff_super_net.num_tokens
    n_hid = rnn_diff_super_net.num_hid
    n_layers = rnn_diff_super_net._num_layers

    controller = DiffController(rnn_diff_super_net.search_space, "cuda")
    cand_net = rnn_diff_super_net.assemble_candidate(controller.sample(1)[0])

    # initial per-layer hidden states
    hiddens = rnn_diff_super_net.init_hidden(n_batch)

    data = _rnn_data(n_steps, n_batch, n_tokens)

    logits, _, outs, next_hiddens = cand_net.forward_data(
        data[0], mode="eval", hiddens=hiddens)

    assert tuple(logits.shape) == (n_steps, n_batch, n_tokens)
    assert tuple(outs.shape) == (n_steps, n_batch, n_hid)
    assert len(next_hiddens) == n_layers
    # `hiddens` is updated in place, so it must match the returned states
    assert (hiddens == next_hiddens).all()
Example 2
0
def test_rnn_diff_supernet_to_arch(rnn_diff_super_net):
    """Controller alphas get gradients only when forwarding with detach_arch=False."""
    from aw_nas.controller import DiffController

    controller = DiffController(rnn_diff_super_net.search_space, "cuda")
    cand_net = rnn_diff_super_net.assemble_candidate(controller.sample(1)[0])

    n_steps, n_batch = 5, 2
    data = _rnn_data(n_steps, n_batch, rnn_diff_super_net.num_tokens)

    hiddens = rnn_diff_super_net.init_hidden(n_batch)

    # detach_arch defaults to True: backward must not reach the controller params
    outputs = cand_net.forward_data(data[0], hiddens=hiddens)
    loss = _rnn_criterion(data[0], outputs, data[1].cuda())
    assert controller.cg_alphas[0].grad is None
    loss.backward()
    assert controller.cg_alphas[0].grad is None

    # with detach_arch=False the architecture parameters join the graph
    outputs = cand_net.forward_data(data[0], hiddens=hiddens, detach_arch=False)
    loss = _rnn_criterion(data[0], outputs, data[1].cuda())
    assert controller.cg_alphas[0].grad is None
    loss.backward()
    assert controller.cg_alphas[0].grad is not None
Example 3
0
def test_diff_supernet_to_arch(diff_super_net):
    """CNN variant: alphas receive gradients only with detach_arch=False."""
    from torch import nn
    from aw_nas.common import get_search_space
    from aw_nas.controller import DiffController

    controller = DiffController(get_search_space(cls="cnn"), "cuda")
    cand_net = diff_super_net.assemble_candidate(controller.sample(1)[0])

    data = _cnn_data()  #pylint: disable=not-callable
    criterion = nn.CrossEntropyLoss()

    # default detach_arch=True: no grad w.r.t. the controller parameters
    loss = criterion(cand_net.forward_data(data[0]), data[1].cuda())
    assert controller.cg_alphas[0].grad is None
    loss.backward()
    assert controller.cg_alphas[0].grad is None

    # detach_arch=False: gradients flow back to the alphas
    loss = criterion(cand_net.forward_data(data[0], detach_arch=False),
                     data[1].cuda())
    assert controller.cg_alphas[0].grad is None
    loss.backward()
    assert controller.cg_alphas[0].grad is not None
Example 4
0
def test_rnn_diff_supernet_forward(rnn_diff_super_net):
    """Forward one candidate of the RNN diff supernet and check output shapes.

    Expected to fail on torch >= 1.7.0 due to an unresolved in-place
    modification error during autograd (see message below).
    """
    # BUG FIX: the original check `version.parse(torch.__version__).minor >= 7`
    # compared only the minor component, so e.g. torch 2.0 (minor == 0) was
    # wrongly treated as older than 1.7. Compare the full parsed versions.
    if version.parse(torch.__version__) >= version.parse("1.7.0"):
        pytest.xfail(
            "FIXME: We currently do not fix this bug yet. When using torch>=1.7.0, "
            "we encountered: Warning: Error detected in SplitBackward. "
            "RuntimeError: one of the variables needed for gradient computation "
            "has been modified by an inplace operation")
    from aw_nas.controller import DiffController

    time_steps = 5
    batch_size = 2
    _num_tokens = rnn_diff_super_net.num_tokens
    _num_hid = rnn_diff_super_net.num_hid
    _num_layers = rnn_diff_super_net._num_layers
    search_space = rnn_diff_super_net.search_space
    device = "cuda"
    controller = DiffController(search_space, device)
    rollout = controller.sample(1)[0]
    cand_net = rnn_diff_super_net.assemble_candidate(rollout)

    # init hiddens
    hiddens = rnn_diff_super_net.init_hidden(batch_size)

    data = _rnn_data(time_steps, batch_size, _num_tokens)

    logits, _, outs, next_hiddens = cand_net.forward_data(data[0],
                                                          mode="eval",
                                                          hiddens=hiddens)
    assert tuple(logits.shape) == (time_steps, batch_size, _num_tokens)
    assert tuple(outs.shape) == (time_steps, batch_size, _num_hid)
    assert len(next_hiddens) == _num_layers
    # the value is equal to the calculated results, the hidden is modified in-place
    assert (hiddens == next_hiddens).all()
Example 5
0
def test_diff_controller():
    """A fresh DiffController has correctly-shaped alphas and samples valid genotypes."""
    from aw_nas.controller import DiffController

    space = get_search_space(cls="cnn")
    controller = DiffController(space, "cuda")

    # 14 edges in the DARTS-style cell, one logit per shared primitive
    expected_shape = (14, len(space.shared_primitives))
    assert controller.cg_alphas[0].shape == expected_shape

    sampled = controller.sample(3)
    assert isinstance(sampled[0].genotype, space.genotype_type)
Example 6
0
def test_diff_controller_rollout_batch_size():
    """sample(batch_size=4) yields batched sampled weights but unbatched logits."""
    # FIX: removed unused `import numpy as np` — numpy is never referenced here.
    from aw_nas.controller import DiffController

    search_space = get_search_space(cls="cnn")
    device = "cuda"
    controller = DiffController(search_space, device)

    n_prims = len(search_space.shared_primitives)
    rollout = controller.sample(1, batch_size=4)[0]
    # sampled op weights carry the batch dimension in the middle...
    assert rollout.sampled[0].shape == (14, 4, n_prims)
    # ...while the raw logits stay per-edge only
    assert rollout.logits[0].shape == (14, n_prims)
    print(rollout.genotype)
Example 7
0
def test_diff_supernet_forward(diff_super_net, controller_cfg):
    """Forward a CNN supernet candidate and check the logits shape."""
    from aw_nas.common import get_search_space
    from aw_nas.controller import DiffController

    controller = DiffController(get_search_space(cls="cnn"), "cuda",
                                **controller_cfg)
    cand_net = diff_super_net.assemble_candidate(controller.sample(1)[0])

    inputs = _cnn_data()
    logits = cand_net.forward_data(inputs[0])
    # default batch of 2 images, 10 classes
    assert tuple(logits.shape) == (2, 10)
Example 8
0
def test_diff_controller_force_uniform():
    """force_uniform + use_prob must yield uniform op probabilities."""
    import numpy as np
    from aw_nas.controller import DiffController

    search_space = get_search_space(cls="cnn")
    controller = DiffController(search_space,
                                "cuda",
                                force_uniform=True,
                                use_prob=True)

    n_prims = len(search_space.shared_primitives)
    # every edge assigns identical probability 1/n_prims to each primitive
    expected = 1. / n_prims * np.ones((14, n_prims))
    rollout = controller.sample(1)[0]
    assert np.equal(rollout.sampled[0].data, expected).all()
Example 9
0
def test_diff_supernet_data_parallel_forward_rolloutsize(diff_super_net):
    """Forward with a rollout whose batch_size matches the data batch size."""
    from aw_nas.common import get_search_space
    from aw_nas.controller import DiffController

    n_batch = 9

    controller = DiffController(get_search_space(cls="cnn"), "cuda")
    rollout = controller.sample(1, batch_size=n_batch)[0]
    cand_net = diff_super_net.assemble_candidate(rollout)

    inputs = _cnn_data(batch_size=n_batch)
    logits = cand_net.forward_data(inputs[0])
    assert tuple(logits.shape) == (n_batch, 10)
Example 10
0
def test_diff_controller_use_prob():
    """With use_prob=True the sampled weights equal softmax of the logits."""
    import numpy as np
    from aw_nas import utils
    from aw_nas.controller import DiffController

    search_space = get_search_space(cls="cnn")
    controller = DiffController(search_space, "cuda", use_prob=True)

    assert controller.cg_alphas[0].shape == (
        14, len(search_space.shared_primitives))

    rollouts = controller.sample(3)
    # sampled values should be numerically equal to softmax(logits)
    diff = utils.get_numpy(rollouts[0].sampled[0]) \
           - utils.softmax(rollouts[0].logits[0])
    assert np.abs(diff).mean() < 1e-6
    assert isinstance(rollouts[0].genotype, search_space.genotype_type)
Example 11
0
def test_diff_controller_cellwise_num_steps():
    """Alpha row count per cell group follows the edge count implied by num_steps."""
    from aw_nas.controller import DiffController

    steps_per_group = [4, 6]
    search_space = get_search_space(cls="cnn",
                                    num_cell_groups=len(steps_per_group),
                                    num_steps=steps_per_group)
    controller = DiffController(search_space, "cuda")

    for idx, steps in enumerate(steps_per_group):
        # each step connects to every earlier step and to all init nodes
        expected_edges = steps * (steps - 1) / 2 \
            + search_space.num_init_nodes * steps
        assert controller.cg_alphas[idx].shape[0] == expected_edges

    rollout = controller.sample(1)[0]
    assert isinstance(rollout.genotype, search_space.genotype_type)
    print(rollout.genotype)
Example 12
0
def test_diff_supernet_data_parallel_backward_rolloutsize(diff_super_net):
    """Backward through a batched rollout populates the controller alphas' grads."""
    from torch import nn
    from aw_nas.common import get_search_space
    from aw_nas.controller import DiffController

    n_batch = 9

    controller = DiffController(get_search_space(cls="cnn"), "cuda")
    rollout = controller.sample(1, batch_size=n_batch)[0]
    cand_net = diff_super_net.assemble_candidate(rollout)

    inputs = _cnn_data(batch_size=n_batch)

    # detach_arch=False keeps the alphas in the autograd graph
    logits = cand_net.forward_data(inputs[0], detach_arch=False)
    loss = nn.CrossEntropyLoss()(logits, inputs[1].cuda())
    assert controller.cg_alphas[0].grad is None
    loss.backward()
    assert controller.cg_alphas[0].grad is not None
Example 13
0
def test_diff_controller_cellwise_primitives():
    """Per-cell-group primitive sets give per-group alpha widths and genotype ops."""
    from aw_nas.controller import DiffController

    group_primitives = [
        ("none", "avg_pool_3x3", "max_pool_3x3", "skip_connect"),
        ("skip_connect", "avg_pool_3x3", "dil_conv_3x3"),
    ]
    search_space = get_search_space(cls="cnn",
                                    num_cell_groups=2,
                                    cell_shared_primitives=group_primitives)
    controller = DiffController(search_space, "cuda")

    # alpha width tracks each group's primitive count
    assert controller.cg_alphas[0].shape == (14, 4)
    assert controller.cg_alphas[1].shape == (14, 3)

    rollout = controller.sample(1)[0]
    assert isinstance(rollout.genotype, search_space.genotype_type)
    # each cell's chosen ops come only from that group's primitive set
    assert set(conn[0] for conn in rollout.genotype.normal_0).issubset(
        ["none", "avg_pool_3x3", "max_pool_3x3", "skip_connect"])
    assert set(conn[0] for conn in rollout.genotype.reduce_1).issubset(
        ["avg_pool_3x3", "dil_conv_3x3", "skip_connect"])
    print(rollout.genotype)