def test_context_last_conv_module():
    from aw_nas.ops import DilConv, SepConv, ResNetBlock
    from aw_nas.utils.common_utils import Context

    data = _cnn_data()

    # DilConv: the conv modules sit at indices 1 and 2 of the inner Sequential
    dil_conv = DilConv(3, 16, 3, 1, 1, 1).to("cuda")
    context = Context(0, 1, use_stem=False)
    _, context = dil_conv.forward_one_step(context=context, inputs=data[0])
    assert context.last_conv_module is dil_conv.op[1]
    _, context = dil_conv.forward_one_step(context=context)
    assert context.last_conv_module is dil_conv.op[2]

    # SepConv: conv modules at indices 1, 2, 5, 6
    context = Context(0, 1, use_stem=False)
    sep_conv = SepConv(3, 16, 3, 1, 1).to("cuda")
    _, context = sep_conv.forward_one_step(context=context, inputs=data[0])
    assert context.last_conv_module is sep_conv.op[1]
    for expected_ind in [2, 5, 6]:
        _, context = sep_conv.forward_one_step(context=context)
        assert context.last_conv_module is sep_conv.op[expected_ind]

    # ResNetBlock without stride: stepping through all sub-modules must
    # reproduce the one-shot forward result
    context = Context(0, 1, use_stem=False)
    res_block = ResNetBlock(3, 3, 1, True).to("cuda")
    res_block.train()
    out_0 = res_block(data[0])
    for i, expected_mod in enumerate(
            [res_block.op_1.op[0], res_block.op_2.op[0], None, None]):
        state, context = res_block.forward_one_step(
            context=context, inputs=data[0] if i == 0 else None)
        assert context.last_conv_module is expected_mod
    assert context.is_end_of_op
    assert (state == out_0).all()

    # ResNetBlock with stride: the skip connection also contains a conv
    context = Context(0, 1, use_stem=False)
    res_block_stride = ResNetBlock(3, 16, 2, True).to("cuda")
    out_0 = res_block_stride(data[0])
    for i, expected_mod in enumerate([
            res_block_stride.op_1.op[0], res_block_stride.op_2.op[0],
            res_block_stride.skip_op.op[0], None
    ]):
        state, context = res_block_stride.forward_one_step(
            context=context, inputs=data[0] if i == 0 else None)
        assert context.last_conv_module is expected_mod
    assert (state == out_0).all()
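# Hedged usage sketch (not part of the original tests): the stepping protocol
# the test above exercises. It assumes only what the assertions show --
# forward_one_step returns a (state, context) pair, takes `inputs` on the
# first call, and context.is_end_of_op turns True once the op has finished.
def _run_op_to_end(op, context, inputs):
    state, context = op.forward_one_step(context=context, inputs=inputs)
    while not context.is_end_of_op:
        state, context = op.forward_one_step(context=context)
    return state, context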
# (class statement reconstructed from the super() call; the nn.Module base is
# an assumption consistent with the other ops in this file)
class ResSepConv(nn.Module):
    def __init__(self, C_in, C_out, kernel_size, stride, padding):
        super(ResSepConv, self).__init__()
        self.conv = SepConv(C_in, C_out, kernel_size, stride, padding)
        # identity shortcut when the spatial size is kept, otherwise a strided
        # factorized reduction to match the output shape
        self.res = Identity() if stride == 1 else FactorizedReduce(
            C_in, C_out, stride)
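    # Hedged sketch (an assumption, not confirmed by the excerpt above): a
    # block wired this way would typically sum the conv path and the shortcut.
    def forward(self, x):
        return self.conv(x) + self.res(x)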
def test_use_params():
    from aw_nas.ops import SepConv
    from aw_nas.utils.torch_utils import use_params

    sep_conv_1 = SepConv(3, 10, 3, 1, 1, affine=True).cuda()
    sep_conv_2 = SepConv(3, 10, 3, 1, 1, affine=True).cuda()
    parameters_1 = dict(sep_conv_1.named_parameters())
    parameters_2 = dict(sep_conv_2.named_parameters())

    # randomly init the params so the two modules differ
    for n in parameters_1:
        parameters_1[n].data.random_()
        parameters_2[n].data.random_()
    for n in parameters_1:
        assert (parameters_1[n] - parameters_2[n]).abs().mean() > 1e-4

    batch_size = 2
    inputs = _cnn_data(batch_size=batch_size)[0]

    # use train mode, so the bn layers use batch statistics rather than the
    # (differing) running statistics
    sep_conv_1.train()
    sep_conv_2.train()
    conv1_res = sep_conv_1(inputs)
    conv2_res = sep_conv_2(inputs)
    assert (conv1_res != conv2_res).any()

    # inside the context, sep_conv_1 computes with sep_conv_2's parameters
    with use_params(sep_conv_1, parameters_2):
        conv1_useparams_res = sep_conv_1(inputs)
    assert (conv1_useparams_res == conv2_res).all()

    # after the context exits, the original parameters are restored
    for n, new_param in sep_conv_1.named_parameters():
        assert (new_param == parameters_1[n]).all()
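# The tests above rely on a module-level helper `_cnn_data` that is not shown
# in this excerpt. A minimal sketch of a fixture with the behavior the tests
# assume (a CUDA image batch at index 0; the 28x28 input shape and the label
# tensor are assumptions):
import numpy as np
import torch

def _cnn_data(device="cuda", batch_size=2):
    inputs = torch.rand(batch_size, 3, 28, 28, dtype=torch.float, device=device)
    labels = torch.tensor(
        np.random.randint(0, high=10, size=batch_size), device=device).long()
    return inputs, labels

# Hedged sketch (an assumption, not aw_nas's actual implementation): a
# use_params-style context manager satisfying the assertions in
# test_use_params would swap the donor tensors in and restore the originals
# on exit, roughly like this.
import contextlib

@contextlib.contextmanager
def _use_params_sketch(module, params):
    backup = {n: p.detach().clone() for n, p in module.named_parameters()}
    for n, p in module.named_parameters():
        p.data.copy_(params[n].data)  # load donor parameter values
    try:
        yield
    finally:
        for n, p in module.named_parameters():
            p.data.copy_(backup[n])  # restore the original values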