Exemplo n.º 1
0
    def test_parameter_dict(self):
        """Check ``parameter_dict()`` naming, contents and object identity.

        A container that holds one ``Linear`` under two names, a nested
        bias-free ``Conv2d`` and a ``None`` child must yield exactly five
        dotted entries. Each entry must be the very object reached by walking
        its attribute path. A bare ``Linear`` yields just its two parameters.
        """
        shared_linear = nn.Linear(5, 5)
        inner = nn.Container(conv=nn.Conv2d(3, 3, 3, bias=False))
        net = nn.Container(
            linear1=shared_linear,
            linear2=shared_linear,
            block=inner,
            empty=None,
        )
        params = net.parameter_dict()
        self.assertEqual(len(params), 5)
        expected_names = (
            'linear1.weight',
            'linear1.bias',
            'linear2.weight',
            'linear2.bias',
            'block.conv.weight',
        )
        for name in expected_names:
            self.assertIn(name, params)
        # The conv was built with bias=False, so no bias entry may appear.
        self.assertNotIn('block.conv.bias', params)
        # The None child must contribute no entries at all.
        self.assertFalse(
            any(name.startswith('empty') for name in params.keys()))
        # Each dotted name must resolve, attribute by attribute, to the
        # exact object stored in the dict.
        for dotted_name, value in params.items():
            target = net
            for attr in dotted_name.split('.'):
                target = getattr(target, attr)
            self.assertIs(value, target)

        single = nn.Linear(5, 5)
        params = single.parameter_dict()
        self.assertEqual(len(params), 2)
        self.assertIs(params['weight'], single.weight)
        self.assertIs(params['bias'], single.bias)
Exemplo n.º 2
0
 def test_load_parameter_dict(self):
     """Loading a partial dict installs the given Variables by identity.

     The dict includes an entry for ``block.conv.bias`` even though that
     conv was constructed with ``bias=False``. The test expects the loaded
     Variable to be reachable at that attribute afterwards.
     """
     shared_linear = nn.Linear(5, 5)
     inner = nn.Container(conv=nn.Conv2d(3, 3, 3, bias=False))
     net = nn.Container(
         linear1=shared_linear,
         linear2=shared_linear,
         block=inner,
         empty=None,
     )
     to_load = {
         'linear1.weight': Variable(torch.ones(5, 5)),
         'block.conv.bias': Variable(torch.range(1, 3)),
     }
     net.load_parameter_dict(to_load)
     # Identity, not equality: the loaded objects themselves must be installed.
     self.assertIs(net.linear1.weight, to_load['linear1.weight'])
     self.assertIs(net.block.conv.bias, to_load['block.conv.bias'])
Exemplo n.º 3
0
 def test_replicate_buffers(self):
     """Each replica of a CUDA container keeps its buffers on its own device.

     ``dp.replicate`` is asked for copies on devices 0 and 1. Replica ``i``
     must hold ``bn.running_mean`` and ``bn.running_var`` on device ``i``.
     At least two CUDA devices are required, so the test is skipped otherwise.
     """
     device_ids = (0, 1)
     # Without enough GPUs, net.cuda()/replicate would error rather than
     # exercise the behavior under test.
     if torch.cuda.device_count() < len(device_ids):
         self.skipTest('test requires at least 2 CUDA devices')
     net = nn.Container()
     net.bn = nn.BatchNorm2d(10)
     net.cuda()
     replicas = dp.replicate(net, device_ids)
     # Guard the zip below: a short result must fail, not silently pass.
     self.assertEqual(len(replicas), len(device_ids))
     for device_id, replica in zip(device_ids, replicas):
         self.assertEqual(replica.bn.running_mean.get_device(), device_id,
                          'buffer on wrong device')
         self.assertEqual(replica.bn.running_var.get_device(), device_id,
                          'buffer on wrong device')
Exemplo n.º 4
0
 def test_type(self):
     """``float()``/``double()`` on a container convert a shared child's params.

     The same ``Linear`` is registered twice and a ``None`` child is present.
     After each conversion, the shared layer's weight and bias data must have
     the requested tensor type.
     """
     linear = nn.Linear(10, 20)
     net = nn.Container(
         l=linear,
         l2=linear,
         empty=None,
     )
     # The trailing float() step converts back from double. Without it, the
     # first float() check could pass with no conversion happening (a fresh
     # layer may already be float).
     steps = (
         (net.float, torch.FloatTensor),
         (net.double, torch.DoubleTensor),
         (net.float, torch.FloatTensor),
     )
     for convert, tensor_type in steps:
         convert()
         self.assertIsInstance(linear.weight.data, tensor_type)
         self.assertIsInstance(linear.bias.data, tensor_type)
Exemplo n.º 5
0
 def test_add_module(self):
     """Registered children are exposed as attributes; bad registrations raise.

     Children passed to the constructor or to ``add_module`` are reachable as
     attributes (a ``None`` child reads back as ``None``). Re-registering an
     existing name raises ``KeyError``. Registering a non-module raises
     ``TypeError``.
     """
     linear = nn.Linear(10, 20)
     net = nn.Container(
         l=linear,
         l2=linear,
         empty=None,
     )
     # Attribute access must return the very object registered, so check
     # identity rather than equality.
     self.assertIs(net.l, linear)
     self.assertIs(net.l2, linear)
     self.assertIsNone(net.empty)
     net.add_module('l3', linear)
     self.assertIs(net.l3, linear)
     with self.assertRaises(KeyError):
         net.add_module('l', linear)
     with self.assertRaises(TypeError):
         net.add_module('x', 'non-module')