def test_ConcatTable(self):
        """Check forward/backward of ``nn.ConcatTable`` built from ``nn.Identity`` branches.

        Part one feeds a nested table (two tensors plus a one-element list)
        through three Identity branches and verifies that:

        * ``forward`` yields the input table once per branch;
        * ``backward`` yields a table with the same nesting as the input whose
          leaves equal the per-branch gradients summed over the branch axis.

        Part two reuses a two-branch table on inputs of different lengths and
        checks only the lengths of the returned tables.
        """
        def rand_float(*size):
            # Random float tensor of the requested size.
            return torch.randn(*size).float()

        table_in = [rand_float(3, 4), rand_float(3, 4), [rand_float(3, 4)]]

        # One (branch, 3, 4) tensor per input leaf; dim 0 indexes the branch.
        stacked_grads = [rand_float(3, 3, 4) for _ in range(3)]

        # Re-slice the stacked gradients into one nested table per branch,
        # mirroring the nesting of ``table_in``.
        branch_grads = [
            [stacked_grads[0][b], stacked_grads[1][b], [stacked_grads[2][b]]]
            for b in range(3)
        ]

        concat = nn.ConcatTable()
        for _ in range(3):
            concat.add(nn.Identity())
        concat.float()

        # Neither string conversion may raise.
        concat.__repr__()
        str(concat)

        actual_out = concat.forward(table_in)
        expected_out = [table_in] * 3
        self.assertEqual(expected_out, actual_out)

        actual_grad = concat.backward(table_in, branch_grads)
        summed = [g.sum(0, keepdim=False) for g in stacked_grads]
        expected_grad = [summed[0], summed[1], [summed[2]]]

        # The gradient table must keep the input's nesting: tensor, tensor, [tensor].
        self.assertTrue(isinstance(actual_grad, list))
        self.assertFalse(isinstance(actual_grad[0], list))
        self.assertFalse(isinstance(actual_grad[1], list))
        self.assertTrue(isinstance(actual_grad[2], list))
        self.assertEqual(len(actual_grad), 3)
        self.assertEqual(len(actual_grad[2]), 1)
        for got, want in zip(iter_tensors(actual_grad), iter_tensors(expected_grad)):
            self.assertEqual(got, want)

        # test outputs for variable length inputs
        two_way = nn.ConcatTable()
        two_way.add(nn.Identity())
        two_way.add(nn.Identity())

        pair = [torch.randn(5), torch.randn(5)]
        single = [torch.randn(5)]

        # Run forward then backward for each input, in order, recording only
        # the lengths of the returned tables.
        lengths = []
        for table in (pair, single):
            lengths.append(len(two_way.forward(table)))
            lengths.append(len(two_way.backward(table, [table, table])))
        out_pair, grad_pair, out_single, grad_single = lengths

        self.assertEqual(out_pair, 2)
        self.assertEqual(grad_pair, 2)
        self.assertEqual(out_single, 2)
        self.assertEqual(grad_single, 1)
# Example #2
# 0
    def test_ConcatTable(self):
        """Verify ``nn.ConcatTable`` with ``nn.Identity`` branches on nested tables.

        First a three-branch table is run on a nested input (two tensors and
        a single-element list): the forward result must be the input repeated
        once per branch, and the backward result must match the input's
        nesting with each leaf equal to the branch gradients summed along
        dim 0. Then a two-branch table is run on a two-element and a
        one-element input, and only the lengths of the results are checked.
        """
        leaf_shape = (3, 4)
        nested_input = [
            torch.randn(*leaf_shape).float(),
            torch.randn(*leaf_shape).float(),
            [torch.randn(*leaf_shape).float()],
        ]

        # Gradients are drawn stacked: one tensor per input leaf, with the
        # leading dimension selecting the branch.
        grads_a = torch.randn(3, *leaf_shape).float()
        grads_b = torch.randn(3, *leaf_shape).float()
        grads_c = torch.randn(3, *leaf_shape).float()

        # Per-branch gradient tables, nested the same way as the input.
        grad_output = [
            [grads_a[k], grads_b[k], [grads_c[k]]] for k in range(3)
        ]

        table = nn.ConcatTable()
        table.add(nn.Identity())
        table.add(nn.Identity())
        table.add(nn.Identity())
        table.float()

        # Check that these don't raise errors
        table.__repr__()
        str(table)

        forward_result = table.forward(nested_input)
        self.assertEqual([nested_input, nested_input, nested_input], forward_result)

        grad_input = table.backward(nested_input, grad_output)
        reference_grad = [
            grads_a.sum(0, keepdim=False),
            grads_b.sum(0, keepdim=False),
            [grads_c.sum(0, keepdim=False)],
        ]

        # Structure checks: list of (tensor, tensor, [tensor]).
        self.assertTrue(isinstance(grad_input, list))
        self.assertFalse(isinstance(grad_input[0], list))
        self.assertFalse(isinstance(grad_input[1], list))
        self.assertTrue(isinstance(grad_input[2], list))
        self.assertEqual(len(grad_input), 3)
        self.assertEqual(len(grad_input[2]), 1)

        # Value checks, leaf by leaf.
        leaf_pairs = zip(iter_tensors(grad_input), iter_tensors(reference_grad))
        for computed, reference in leaf_pairs:
            self.assertEqual(computed, reference)

        # test outputs for variable length inputs
        probe = nn.ConcatTable()
        probe.add(nn.Identity())
        probe.add(nn.Identity())

        two_items = [torch.randn(5), torch.randn(5)]
        one_item = [torch.randn(5)]

        fwd_len_two = len(probe.forward(two_items))
        bwd_len_two = len(probe.backward(two_items, [two_items, two_items]))
        fwd_len_one = len(probe.forward(one_item))
        bwd_len_one = len(probe.backward(one_item, [one_item, one_item]))

        # Output length follows the branch count; gradInput length follows
        # the input length.
        length_checks = (
            (fwd_len_two, 2),
            (bwd_len_two, 2),
            (fwd_len_one, 2),
            (bwd_len_one, 1),
        )
        for observed, expected in length_checks:
            self.assertEqual(observed, expected)