Example 1
def test_context_last_conv_module():
    from aw_nas.ops import DilConv, SepConv, ResNetBlock
    from aw_nas.utils.common_utils import Context
    data = _cnn_data()
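    # Step DilConv one submodule at a time; op[1] and op[2] are its two
    # conv layers, which the Context should report as last_conv_module.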
    dil_conv = DilConv(3, 16, 3, 1, 1, 1).to("cuda")
    context = Context(0, 1, use_stem=False)
    _, context = dil_conv.forward_one_step(context=context, inputs=data[0])
    assert context.last_conv_module is dil_conv.op[1]
    _, context = dil_conv.forward_one_step(context=context)
    assert context.last_conv_module is dil_conv.op[2]

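    # SepConv stacks two separable stages; its conv layers sit at indices
    # 1, 2, 5, 6 of the underlying Sequential.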
    context = Context(0, 1, use_stem=False)
    sep_conv = SepConv(3, 16, 3, 1, 1).to("cuda")
    _, context = sep_conv.forward_one_step(context=context, inputs=data[0])
    assert context.last_conv_module is sep_conv.op[1]
    for expected_ind in [2, 5, 6]:
        _, context = sep_conv.forward_one_step(context=context)
        assert context.last_conv_module is sep_conv.op[expected_ind]

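    # ResNetBlock: the first two steps hit the convs inside op_1 and op_2,
    # the remaining steps add no new conv, and the fully stepped output
    # must match the one-shot forward.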
    context = Context(0, 1, use_stem=False)
    res_block = ResNetBlock(3, 3, 1, True).to("cuda")
    res_block.train()
    out_0 = res_block(data[0])
    for i, expected_mod in enumerate(
        [res_block.op_1.op[0], res_block.op_2.op[0], None, None]):
        state, context = res_block.forward_one_step(
            context=context, inputs=data[0] if i == 0 else None)
        assert context.last_conv_module is expected_mod
    assert context.is_end_of_op
    assert (state == out_0).all()

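    # With stride 2 the shortcut needs its own conv (skip_op.op[0]), which
    # appears as the third stepped conv module.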
    context = Context(0, 1, use_stem=False)
    res_block_stride = ResNetBlock(3, 16, 2, True).to("cuda")
    out_0 = res_block_stride(data[0])
    for i, expected_mod in enumerate([
            res_block_stride.op_1.op[0], res_block_stride.op_2.op[0],
            res_block_stride.skip_op.op[0], None
    ]):
        state, context = res_block_stride.forward_one_step(
            context=context, inputs=data[0] if i == 0 else None)
        assert context.last_conv_module is expected_mod
    assert (state == out_0).all()
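
The stepping pattern above generalizes. A hedged convenience sketch (not part of aw_nas, built only from the calls the test itself makes): drive forward_one_step until the Context reports the end of the op, at which point the accumulated state should equal the one-shot forward output.

# Hypothetical helper, assuming only the forward_one_step/Context API used
# in the test above (is_end_of_op, optional inputs on the first step).
def run_to_end(op, context, inputs):
    state, context = op.forward_one_step(context=context, inputs=inputs)
    while not context.is_end_of_op:
        state, context = op.forward_one_step(context=context)
    return state, context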
Example 2
import torch.nn as nn
from aw_nas.ops import SepConv, Identity, FactorizedReduce  # assumed import path

class ResSepConv(nn.Module):
    def __init__(self, C_in, C_out, kernel_size, stride, padding):
        super(ResSepConv, self).__init__()
        self.conv = SepConv(C_in, C_out, kernel_size, stride, padding)
        # identity shortcut at stride 1; a factorized reduce downsamples otherwise
        self.res = Identity() if stride == 1 else FactorizedReduce(
            C_in, C_out, stride)
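
Example 2 shows only the constructor. A minimal sketch of a matching forward pass and shape check, assuming the block computes conv(x) + res(x) (an assumption; the snippet does not include the real forward):

import torch

class _ResSepConvSketch(ResSepConv):
    # Assumed residual form; the actual ResSepConv.forward may differ.
    def forward(self, x):
        return self.conv(x) + self.res(x)

block = _ResSepConvSketch(3, 16, 3, 2, 1)
out = block(torch.randn(2, 3, 32, 32))
print(out.shape)  # stride 2 halves the spatial size: (2, 16, 16, 16)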
Example 3
def test_use_params():
    from aw_nas.ops import SepConv
    from aw_nas.utils.torch_utils import use_params

    sep_conv_1 = SepConv(3, 10, 3, 1, 1, affine=True).cuda()
    sep_conv_2 = SepConv(3, 10, 3, 1, 1, affine=True).cuda()
    parameters_1 = dict(sep_conv_1.named_parameters())
    parameters_2 = dict(sep_conv_2.named_parameters())

    # random init params
    for n in parameters_1:
        parameters_1[n].data.random_()
        parameters_2[n].data.random_()
    for n in parameters_1:
        assert (parameters_1[n] - parameters_2[n]).abs().mean() > 1e-4
    batch_size = 2
    inputs = _cnn_data(batch_size=batch_size)[0]

    # use train mode, do not use bn running statistics
    sep_conv_1.train()
    sep_conv_2.train()
    conv1_res = sep_conv_1(inputs)
    conv2_res = sep_conv_2(inputs)
    assert (conv1_res != conv2_res).any()
    with use_params(sep_conv_1, parameters_2):
        conv1_useparams_res = sep_conv_1(inputs)
    assert (conv1_useparams_res == conv2_res).all()
    for n, new_param in sep_conv_1.named_parameters():
        assert (new_param == parameters_1[n]).all()
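
For intuition, a minimal sketch of the contract this test pins down (an illustration, not the aw_nas implementation): use_params temporarily loads the given named parameters into the module and restores the originals on exit.

import contextlib

@contextlib.contextmanager
def use_params_sketch(module, params):
    # Back up the module's current parameter values, copy the replacement
    # values in, and restore the backups when the with-block exits.
    backup = {n: p.detach().clone() for n, p in module.named_parameters()}
    try:
        for n, p in module.named_parameters():
            p.data.copy_(params[n].data)
        yield module
    finally:
        for n, p in module.named_parameters():
            p.data.copy_(backup[n])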