def test_parameter_dict(self):
    """Check the names and values returned by ``parameter_dict()``.

    First a nested container is inspected: one ``Linear`` registered under
    two names, a sub-container holding a ``Conv2d`` built with
    ``bias=False``, and an entry set to ``None``. Then a standalone
    ``Linear`` is inspected on its own.
    """
    shared_linear = nn.Linear(5, 5)
    conv_block = nn.Container(conv=nn.Conv2d(3, 3, 3, bias=False))
    net = nn.Container(
        linear1=shared_linear,
        linear2=shared_linear,
        block=conv_block,
        empty=None,
    )
    param_dict = net.parameter_dict()

    # Two names for the shared Linear (weight + bias each) plus the conv
    # weight gives five entries.
    self.assertEqual(len(param_dict), 5)
    expected_keys = (
        'linear1.weight',
        'linear1.bias',
        'linear2.weight',
        'linear2.bias',
        'block.conv.weight',
    )
    for expected_key in expected_keys:
        self.assertIn(expected_key, param_dict)
    self.assertNotIn('block.conv.bias', param_dict)
    self.assertFalse(
        any(key.startswith('empty') for key in param_dict.keys()))

    # Every dotted key must resolve, attribute by attribute, to the very
    # object stored in the dict.
    for dotted_name, value in param_dict.items():
        target = net
        for attribute in dotted_name.split('.'):
            target = getattr(target, attribute)
        self.assertIs(value, target)

    standalone = nn.Linear(5, 5)
    param_dict = standalone.parameter_dict()
    self.assertEqual(len(param_dict), 2)
    self.assertIs(param_dict['weight'], standalone.weight)
    self.assertIs(param_dict['bias'], standalone.bias)
def test_load_parameter_dict(self):
    """Check that ``load_parameter_dict()`` installs the given objects.

    A dict with two dotted names is loaded into a nested container, and
    the test expects the matching attributes to be the exact objects that
    were passed in.
    """
    shared_linear = nn.Linear(5, 5)
    conv_block = nn.Container(conv=nn.Conv2d(3, 3, 3, bias=False))
    net = nn.Container(
        linear1=shared_linear,
        linear2=shared_linear,
        block=conv_block,
        empty=None,
    )

    new_weight = Variable(torch.ones(5, 5))
    # The conv was built with bias=False; the test still expects a bias
    # to be attached by the load.
    new_bias = Variable(torch.range(1, 3))
    param_dict = {
        'linear1.weight': new_weight,
        'block.conv.bias': new_bias,
    }
    net.load_parameter_dict(param_dict)

    self.assertIs(net.linear1.weight, new_weight)
    self.assertIs(net.block.conv.bias, new_bias)
def test_replicate_buffers(self):
    """Check that replicated batch-norm buffers sit on the replica's device.

    The network is moved to CUDA and replicated onto device ids 0 and 1;
    the i-th replica's ``running_mean`` and ``running_var`` must report
    device ``i``.
    """
    # NOTE(review): presumably needs at least two CUDA devices - confirm
    # that a skip guard exists where this test is collected.
    net = nn.Container()
    net.bn = nn.BatchNorm2d(10)
    net.cuda()

    device_ids = (0, 1)
    replicas = dp.replicate(net, device_ids)
    for index, replica in enumerate(replicas):
        for buffer_name in ('running_mean', 'running_var'):
            buffer = getattr(replica.bn, buffer_name)
            self.assertEqual(buffer.get_device(), index,
                             'buffer on wrong device')
def test_type(self):
    """Check that ``float()`` / ``double()`` on a container convert its modules.

    The same ``Linear`` is registered under two names alongside a ``None``
    entry; after each conversion call on the container, the ``Linear``'s
    weight and bias data must be of the matching tensor type.
    """
    linear = nn.Linear(10, 20)
    net = nn.Container(
        l=linear,
        l2=linear,
        empty=None,
    )

    conversions = (
        (net.float, torch.FloatTensor),
        (net.double, torch.DoubleTensor),
    )
    for convert, tensor_type in conversions:
        convert()
        self.assertIsInstance(linear.weight.data, tensor_type)
        self.assertIsInstance(linear.bias.data, tensor_type)
def test_add_module(self):
    """Check module registration through the constructor and ``add_module()``.

    Modules passed as constructor keywords, or added later with
    ``add_module()``, must be readable back as attributes and be the very
    same object. A ``None`` entry must read back as ``None``. The test
    expects ``add_module()`` to raise ``KeyError`` for a name that is
    already taken and ``TypeError`` for a value that is not a module.
    """
    linear = nn.Linear(10, 20)
    net = nn.Container(
        l=linear,
        l2=linear,
        empty=None,
    )

    # Identity, not equality: the container must hand back the registered
    # object itself.
    self.assertIs(net.l, linear)
    self.assertIs(net.l2, linear)
    self.assertIsNone(net.empty)

    net.add_module('l3', linear)
    self.assertIs(net.l3, linear)

    # 'l' was already registered through the constructor.
    with self.assertRaises(KeyError):
        net.add_module('l', linear)
    with self.assertRaises(TypeError):
        net.add_module('x', 'non-module')