示例#1
0
    def test_Concat(self):
        """Check nn.Concat(0) forward and backward over several Linear(2, 5) branches.

        Every branch gets its weights filled with 1 and its bias filled with 0,
        so the expected tensors can be built directly from the input: the
        forward result is compared against the per-row input sums expanded to
        5 columns and repeated once per branch, and the gradients are compared
        against closed-form values derived from an all-ones output gradient.
        """
        sample = torch.randn(4, 2)
        branch_count = random.randint(2, 5)
        branches = [nn.Linear(2, 5) for _ in range(branch_count)]

        concat = nn.Concat(0)
        for branch in branches:
            concat.add(branch)
            # Reset accumulated gradients and pin the parameters to known
            # constants so the expected values below are deterministic.
            branch.zeroGradParameters()
            branch.weight.fill_(1)
            branch.bias.fill_(0)

        # Smoke check: building the string representations must not raise.
        concat.__repr__()
        str(concat)

        actual_output = concat.forward(sample)
        row_sums = sample.sum(1)
        expected_output = row_sums.expand(4, 5).repeat(branch_count, 1)
        self.assertEqual(expected_output, actual_output)

        # Back-propagate an all-ones gradient shaped like the expected output.
        upstream_grad = torch.ones(expected_output.size())
        actual_grad_input = concat.backward(sample, upstream_grad)
        expected_grad_input = torch.ones(4, 2).fill_(branch_count * 5)
        self.assertEqual(actual_grad_input, expected_grad_input)

        # Every branch is expected to accumulate the same weight gradient.
        expected_grad_weight = sample.sum(0).expand(5, 2)
        for branch in branches:
            self.assertEqual(expected_grad_weight, branch.gradWeight)
示例#2
0
 def _build_net(self):
     """Assemble the small network used as a fixture.

     Layout: a Sequential container holding, in order, a Concat(0) of two
     Linear(2, 5) modules, a ReLU, and a Linear(10, 20).
     """
     # Each step rebinds to the value returned by ``add`` so the result is
     # exactly what the equivalent fluent call chain would produce.
     net = nn.Sequential()

     # Parallel part: two Linear(2, 5) branches joined by Concat(0).
     parallel = nn.Concat(0)
     parallel = parallel.add(nn.Linear(2, 5))
     parallel = parallel.add(nn.Linear(2, 5))

     net = net.add(parallel)
     net = net.add(nn.ReLU())
     return net.add(nn.Linear(10, 20))