def test_constantinit():
    """test ConstantInit class."""
    model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2))
    func = ConstantInit(val=1, bias=2, layer='Conv2d')
    func(model)
    assert torch.equal(model[0].weight, torch.full(model[0].weight.shape, 1.))
    assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 2.))
    assert not torch.equal(model[2].weight,
                           torch.full(model[2].weight.shape, 1.))
    assert not torch.equal(model[2].bias, torch.full(model[2].bias.shape, 2.))

    func = ConstantInit(val=3, bias_prob=0.01, layer='Linear')
    func(model)
    res = bias_init_with_prob(0.01)
    assert torch.equal(model[0].weight, torch.full(model[0].weight.shape, 1.))
    assert torch.equal(model[2].weight, torch.full(model[2].weight.shape, 3.))
    assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 2.))
    assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, res))

    # test bias input type
    with pytest.raises(TypeError):
        func = ConstantInit(val=1, bias='1')

    # test bias_prob type
    with pytest.raises(TypeError):
        func = ConstantInit(val=1, bias_prob='1')

    # test layer input type
    with pytest.raises(TypeError):
        func = ConstantInit(val=1, layer=1)

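# `FooModule` is used by the tests below but is not defined in this section.
# A minimal definition consistent with the assertions (the submodule names,
# and the Linear(1, 2) shape implied by `modelC` in test_pretrainedinit);
# the conv channel/kernel sizes are assumptions:
class FooModule(nn.Module):

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(1, 2)
        self.conv2d = nn.Conv2d(3, 1, 3)
        self.conv2d_2 = nn.Conv2d(3, 2, 3)
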
def test_pretrainedinit():
    """test PretrainedInit class."""
    import os

    modelA = FooModule()
    constant_func = ConstantInit(val=1, bias=2, layer=['Conv2d', 'Linear'])
    modelA.apply(constant_func)

    modelB = FooModule()
    modelC = nn.Linear(1, 2)
    # Bind the temporary directory with `as` and save the checkpoint inside
    # it, so no `modelA.pth` leaks into the working directory.
    with TemporaryDirectory() as tmp_dir:
        checkpoint_path = os.path.join(tmp_dir, 'modelA.pth')
        torch.save(modelA.state_dict(), checkpoint_path)

        funcB = PretrainedInit(checkpoint=checkpoint_path)
        funcB(modelB)
        assert torch.equal(modelB.linear.weight,
                           torch.full(modelB.linear.weight.shape, 1.))
        assert torch.equal(modelB.linear.bias,
                           torch.full(modelB.linear.bias.shape, 2.))
        assert torch.equal(modelB.conv2d.weight,
                           torch.full(modelB.conv2d.weight.shape, 1.))
        assert torch.equal(modelB.conv2d.bias,
                           torch.full(modelB.conv2d.bias.shape, 2.))
        assert torch.equal(modelB.conv2d_2.weight,
                           torch.full(modelB.conv2d_2.weight.shape, 1.))
        assert torch.equal(modelB.conv2d_2.bias,
                           torch.full(modelB.conv2d_2.bias.shape, 2.))

        funcC = PretrainedInit(checkpoint=checkpoint_path, prefix='linear.')
        funcC(modelC)
        assert torch.equal(modelC.weight, torch.full(modelC.weight.shape, 1.))
        assert torch.equal(modelC.bias, torch.full(modelC.bias.shape, 2.))

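# PretrainedInit with prefix='linear.' loads only the checkpoint entries
# whose keys start with the prefix, with the prefix stripped, which is why a
# bare nn.Linear can consume weights saved inside FooModule. A rough
# illustration of that key filtering (a sketch of the idea; mmcv's actual
# loading code may differ):
def test_prefix_key_filtering_sketch():
    state = {'linear.weight': 0, 'linear.bias': 1, 'conv2d.weight': 2}
    prefix = 'linear.'
    sub = {
        k[len(prefix):]: v
        for k, v in state.items() if k.startswith(prefix)
    }
    assert sub == {'weight': 0, 'bias': 1}
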
def test_xavierinit():
    """test XavierInit class."""
    model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2))
    func = XavierInit(bias=0.1, layer='Conv2d')
    func(model)
    assert model[0].bias.allclose(torch.full_like(model[0].bias, 0.1))
    assert not model[2].bias.allclose(torch.full_like(model[2].bias, 0.1))

    constant_func = ConstantInit(val=0, bias=0)
    func = XavierInit(gain=100, bias_prob=0.01)
    model.apply(constant_func)
    assert torch.equal(model[0].weight, torch.full(model[0].weight.shape, 0.))
    assert torch.equal(model[2].weight, torch.full(model[2].weight.shape, 0.))
    assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 0.))
    assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 0.))

    res = bias_init_with_prob(0.01)
    func(model)
    assert not torch.equal(model[0].weight,
                           torch.full(model[0].weight.shape, 0.))
    assert not torch.equal(model[2].weight,
                           torch.full(model[2].weight.shape, 0.))
    assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, res))
    assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, res))

    # test bias input type
    with pytest.raises(TypeError):
        func = XavierInit(bias='0.1', layer='Conv2d')

    # test layer input type
    with pytest.raises(TypeError):
        func = XavierInit(bias=0.1, layer=1)

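# Sanity check for the helper used above: bias_init_with_prob(p) should
# return the bias that makes a sigmoid output the prior probability p,
# i.e. -log((1 - p) / p). A small hedged check of that formula:
def test_bias_init_with_prob_value():
    import math
    res = bias_init_with_prob(0.01)
    assert res == pytest.approx(-math.log((1 - 0.01) / 0.01))
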
def test_kaiminginit():
    """test KaimingInit class."""
    model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2))
    func = KaimingInit(bias=0.1, layer='Conv2d')
    func(model)
    assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 0.1))
    assert not torch.equal(model[2].bias,
                           torch.full(model[2].bias.shape, 0.1))

    func = KaimingInit(a=100, bias=10, layer=['Conv2d', 'Linear'])
    constant_func = ConstantInit(val=0, bias=0, layer=['Conv2d', 'Linear'])
    model.apply(constant_func)
    assert torch.equal(model[0].weight, torch.full(model[0].weight.shape, 0.))
    assert torch.equal(model[2].weight, torch.full(model[2].weight.shape, 0.))
    assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 0.))
    assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 0.))

    func(model)
    assert not torch.equal(model[0].weight,
                           torch.full(model[0].weight.shape, 0.))
    assert not torch.equal(model[2].weight,
                           torch.full(model[2].weight.shape, 0.))
    assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 10.))
    assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 10.))

    # test layer key with base class name
    model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Conv1d(1, 2, 1))
    func = KaimingInit(bias=0.1, layer='_ConvNd')
    func(model)
    assert torch.all(model[0].bias == 0.1)
    assert torch.all(model[2].bias == 0.1)

    func = KaimingInit(a=100, bias=10, layer='_ConvNd')
    constant_func = ConstantInit(val=0, bias=0, layer='_ConvNd')
    model.apply(constant_func)
    assert torch.equal(model[0].weight, torch.full(model[0].weight.shape, 0.))
    assert torch.equal(model[2].weight, torch.full(model[2].weight.shape, 0.))
    assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 0.))
    assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 0.))

    func(model)
    assert not torch.equal(model[0].weight,
                           torch.full(model[0].weight.shape, 0.))
    assert not torch.equal(model[2].weight,
                           torch.full(model[2].weight.shape, 0.))
    assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 10.))
    assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 10.))

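# Matching layer='_ConvNd' above works because the initializers compare
# against a module's base-class names as well as its own class name. A
# minimal illustration of that lookup (mmcv's internal helper may differ):
def test_base_class_name_matching():
    def get_bases_name(m):
        return [b.__name__ for b in m.__class__.__bases__]

    assert '_ConvNd' in get_bases_name(nn.Conv1d(1, 2, 1))
    assert '_ConvNd' in get_bases_name(nn.Conv2d(3, 1, 3))
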
def test_kaiminginit_no_layer():
    """test KaimingInit when no layer key is specified.

    Without ``layer``, the initializer is applied to the whole module.
    """
    model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2))
    func = KaimingInit(a=100, bias=10)
    constant_func = ConstantInit(val=0, bias=0)
    model.apply(constant_func)
    assert torch.equal(model[0].weight, torch.full(model[0].weight.shape, 0.))
    assert torch.equal(model[2].weight, torch.full(model[2].weight.shape, 0.))
    assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 0.))
    assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 0.))

    func(model)
    assert not torch.equal(model[0].weight,
                           torch.full(model[0].weight.shape, 0.))
    assert not torch.equal(model[2].weight,
                           torch.full(model[2].weight.shape, 0.))
    assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 10.))
    assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 10.))

def test_initialize():
    import os

    model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2))
    foonet = FooModule()

    # test layer key
    init_cfg = dict(
        type='Constant', layer=['Conv2d', 'Linear'], val=1, bias=2)
    initialize(model, init_cfg)
    assert torch.equal(model[0].weight, torch.full(model[0].weight.shape, 1.))
    assert torch.equal(model[2].weight, torch.full(model[2].weight.shape, 1.))
    assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 2.))
    assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 2.))
    # initialize() must not mutate the config it is given
    assert init_cfg == dict(
        type='Constant', layer=['Conv2d', 'Linear'], val=1, bias=2)

    # test init_cfg with list type
    init_cfg = [
        dict(type='Constant', layer='Conv2d', val=1, bias=2),
        dict(type='Constant', layer='Linear', val=3, bias=4)
    ]
    initialize(model, init_cfg)
    assert torch.equal(model[0].weight, torch.full(model[0].weight.shape, 1.))
    assert torch.equal(model[2].weight, torch.full(model[2].weight.shape, 3.))
    assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 2.))
    assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 4.))
    assert init_cfg == [
        dict(type='Constant', layer='Conv2d', val=1, bias=2),
        dict(type='Constant', layer='Linear', val=3, bias=4)
    ]

    # test layer key and override key
    init_cfg = dict(
        type='Constant',
        val=1,
        bias=2,
        layer=['Conv2d', 'Linear'],
        override=dict(type='Constant', name='conv2d_2', val=3, bias=4))
    initialize(foonet, init_cfg)
    assert torch.equal(foonet.linear.weight,
                       torch.full(foonet.linear.weight.shape, 1.))
    assert torch.equal(foonet.linear.bias,
                       torch.full(foonet.linear.bias.shape, 2.))
    assert torch.equal(foonet.conv2d.weight,
                       torch.full(foonet.conv2d.weight.shape, 1.))
    assert torch.equal(foonet.conv2d.bias,
                       torch.full(foonet.conv2d.bias.shape, 2.))
    assert torch.equal(foonet.conv2d_2.weight,
                       torch.full(foonet.conv2d_2.weight.shape, 3.))
    assert torch.equal(foonet.conv2d_2.bias,
                       torch.full(foonet.conv2d_2.bias.shape, 4.))
    assert init_cfg == dict(
        type='Constant',
        val=1,
        bias=2,
        layer=['Conv2d', 'Linear'],
        override=dict(type='Constant', name='conv2d_2', val=3, bias=4))

    # test override key
    init_cfg = dict(
        type='Constant', val=5, bias=6, override=dict(name='conv2d_2'))
    initialize(foonet, init_cfg)
    assert not torch.equal(foonet.linear.weight,
                           torch.full(foonet.linear.weight.shape, 5.))
    assert not torch.equal(foonet.linear.bias,
                           torch.full(foonet.linear.bias.shape, 6.))
    assert not torch.equal(foonet.conv2d.weight,
                           torch.full(foonet.conv2d.weight.shape, 5.))
    assert not torch.equal(foonet.conv2d.bias,
                           torch.full(foonet.conv2d.bias.shape, 6.))
    assert torch.equal(foonet.conv2d_2.weight,
                       torch.full(foonet.conv2d_2.weight.shape, 5.))
    assert torch.equal(foonet.conv2d_2.bias,
                       torch.full(foonet.conv2d_2.bias.shape, 6.))
    assert init_cfg == dict(
        type='Constant', val=5, bias=6, override=dict(name='conv2d_2'))

    modelA = FooModule()
    constant_func = ConstantInit(val=1, bias=2, layer=['Conv2d', 'Linear'])
    modelA.apply(constant_func)
    # Bind the temporary directory with `as` and save the checkpoint inside
    # it, so no `modelA.pth` leaks into the working directory.
    with TemporaryDirectory() as tmp_dir:
        checkpoint_path = os.path.join(tmp_dir, 'modelA.pth')
        torch.save(modelA.state_dict(), checkpoint_path)
        init_cfg = dict(
            type='Pretrained',
            checkpoint=checkpoint_path,
            override=dict(type='Constant', name='conv2d_2', val=3, bias=4))
        initialize(foonet, init_cfg)
        assert torch.equal(foonet.linear.weight,
                           torch.full(foonet.linear.weight.shape, 1.))
        assert torch.equal(foonet.linear.bias,
                           torch.full(foonet.linear.bias.shape, 2.))
        assert torch.equal(foonet.conv2d.weight,
                           torch.full(foonet.conv2d.weight.shape, 1.))
        assert torch.equal(foonet.conv2d.bias,
                           torch.full(foonet.conv2d.bias.shape, 2.))
        assert torch.equal(foonet.conv2d_2.weight,
                           torch.full(foonet.conv2d_2.weight.shape, 3.))
        assert torch.equal(foonet.conv2d_2.bias,
                           torch.full(foonet.conv2d_2.bias.shape, 4.))
        assert init_cfg == dict(
            type='Pretrained',
            checkpoint=checkpoint_path,
            override=dict(type='Constant', name='conv2d_2', val=3, bias=4))

    # test init_cfg type
    with pytest.raises(TypeError):
        init_cfg = 'init_cfg'
        initialize(foonet, init_cfg)

    # test override value type
    with pytest.raises(TypeError):
        init_cfg = dict(
            type='Constant',
            val=1,
            bias=2,
            layer=['Conv2d', 'Linear'],
            override='conv')
        initialize(foonet, init_cfg)

    # test override name
    with pytest.raises(RuntimeError):
        init_cfg = dict(
            type='Constant',
            val=1,
            bias=2,
            layer=['Conv2d', 'Linear'],
            override=dict(type='Constant', name='conv2d_3', val=3, bias=4))
        initialize(foonet, init_cfg)

    # test list override name
    with pytest.raises(RuntimeError):
        init_cfg = dict(
            type='Constant',
            val=1,
            bias=2,
            layer=['Conv2d', 'Linear'],
            override=[
                dict(type='Constant', name='conv2d', val=3, bias=4),
                dict(type='Constant', name='conv2d_3', val=5, bias=6)
            ])
        initialize(foonet, init_cfg)

    # test override with extra args but without type key
    with pytest.raises(ValueError):
        init_cfg = dict(
            type='Constant',
            val=1,
            bias=2,
            override=dict(name='conv2d_2', val=3, bias=4))
        initialize(foonet, init_cfg)

    # test override without name
    with pytest.raises(ValueError):
        init_cfg = dict(
            type='Constant',
            val=1,
            bias=2,
            override=dict(type='Constant', val=3, bias=4))
        initialize(foonet, init_cfg)

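# initialize() builds the initializer named by cfg['type'] (e.g. 'Constant'
# resolving to ConstantInit) and applies it, so a dict-based call should
# match a direct call. A rough equivalence check (the registry-based
# dispatch is assumed, not shown in this section):
def test_initialize_matches_direct_call():
    m1 = nn.Linear(1, 2)
    m2 = nn.Linear(1, 2)
    initialize(m1, dict(type='Constant', layer='Linear', val=1., bias=2.))
    ConstantInit(val=1., bias=2., layer='Linear')(m2)
    assert torch.equal(m1.weight, m2.weight)
    assert torch.equal(m1.bias, m2.bias)
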