def test_Concat(self):
    """Exercise legacy nn.Concat: forward output, gradInput, and per-branch gradWeight."""
    input = torch.randn(4, 2)
    num_modules = random.randint(2, 5)
    linears = [nn.Linear(2, 5) for _ in range(num_modules)]

    # Build the Concat container and pin every branch to weight=1, bias=0
    # so the expected outputs are analytically known.
    m = nn.Concat(0)
    for layer in linears:
        m.add(layer)
        layer.zeroGradParameters()
        layer.weight.fill_(1)
        layer.bias.fill_(0)

    # Check that these don't raise errors
    m.__repr__()
    str(m)

    # With all-ones weights each branch computes row-sums of the input,
    # broadcast to width 5; Concat(0) stacks the branches along dim 0.
    output = m.forward(input)
    expected_output = input.sum(1).expand(4, 5).repeat(num_modules, 1)
    self.assertEqual(expected_output, output)

    # Each branch backprops 5 per input element, summed over all branches.
    gradInput = m.backward(input, torch.ones(expected_output.size()))
    expected_gradInput = torch.ones(4, 2).fill_(num_modules * 5)
    self.assertEqual(gradInput, expected_gradInput)

    # With an all-ones gradOutput, every branch accumulates the same
    # gradWeight: the column-sums of the input, replicated per output unit.
    expected_gradWeight = input.sum(0).expand(5, 2)
    for layer in linears:
        self.assertEqual(expected_gradWeight, layer.gradWeight)
def _build_net(self):
    """Build a small test network: two parallel Linear(2, 5) branches
    joined by Concat(0), followed by ReLU and Linear(10, 20).

    NOTE(review): Linear(10, 20) suggests the concatenated width is
    expected to be 10, but Concat(0) stacks along dim 0 — confirm the
    intended concat dimension against legacy Concat semantics.
    """
    # Container.add returns self in legacy nn, so statement style and
    # the original fluent chaining build the identical module tree.
    branches = nn.Concat(0)
    branches.add(nn.Linear(2, 5))
    branches.add(nn.Linear(2, 5))

    net = nn.Sequential()
    net.add(branches)
    net.add(nn.ReLU())
    net.add(nn.Linear(10, 20))
    return net