def test_ConcatTable(self):
        """Exercise legacy nn.ConcatTable with a nested (table) input.

        Checks that forward() replicates the input once per branch, that
        backward() sums the per-branch gradOutputs while preserving the
        input's nesting structure, and that the module adapts the length
        of gradInput to variable-length inputs.
        """
        # Nested table input: two tensors plus a one-element sub-table.
        # NOTE(review): `input` shadows the builtin; kept as-is to avoid
        # touching test code.
        input = [
            torch.randn(3, 4).float(),
            torch.randn(3, 4).float(), [torch.randn(3, 4).float()]
        ]
        # Three stacked (3, 3, 4) tensors; _gradOutput[j][b] is the grad
        # for input position j coming from branch b.
        _gradOutput = [
            torch.randn(3, 3, 4).float(),
            torch.randn(3, 3, 4).float(),
            torch.randn(3, 3, 4).float()
        ]
        # Regroup per branch: gradOutput[b] mirrors the input structure,
        # holding slice b of each stacked tensor (sub-table nesting kept).
        gradOutput = [
            [_gradOutput[0][0], _gradOutput[1][0], [_gradOutput[2][0]]],
            [_gradOutput[0][1], _gradOutput[1][1], [_gradOutput[2][1]]],
            [_gradOutput[0][2], _gradOutput[1][2], [_gradOutput[2][2]]]
        ]
        # Three Identity branches: each should see (and pass through) the
        # whole input table unchanged.
        module = nn.ConcatTable()
        module.add(nn.Identity())
        module.add(nn.Identity())
        module.add(nn.Identity())
        module.float()

        # Check that these don't raise errors
        module.__repr__()
        str(module)

        # Forward replicates the input once per branch.
        output = module.forward(input)
        output2 = [input, input, input]
        self.assertEqual(output2, output)
        # Backward must sum gradients across branches: position j gets
        # _gradOutput[j].sum(0), with the sub-table nesting preserved.
        gradInput = module.backward(input, gradOutput)
        gradInput2 = [
            _gradOutput[0].sum(0, keepdim=False),
            _gradOutput[1].sum(0, keepdim=False),
            [_gradOutput[2].sum(0, keepdim=False)]
        ]
        # Structure checks: gradInput mirrors the input's table layout.
        self.assertTrue(isinstance(gradInput, list))
        self.assertFalse(isinstance(gradInput[0], list))
        self.assertFalse(isinstance(gradInput[1], list))
        self.assertTrue(isinstance(gradInput[2], list))
        self.assertEqual(len(gradInput), 3)
        self.assertEqual(len(gradInput[2]), 1)
        # Value check: compare every leaf tensor against the expected sums.
        for t1, t2 in zip(iter_tensors(gradInput), iter_tensors(gradInput2)):
            self.assertEqual(t1, t2)

        # test outputs for variable length inputs
        test = nn.ConcatTable()
        test.add(nn.Identity())
        test.add(nn.Identity())

        x = [torch.randn(5), torch.randn(5)]
        y = [torch.randn(5)]

        o1 = len(test.forward(x))
        go1 = len(test.backward(x, [x, x]))
        o2 = len(test.forward(y))
        go2 = len(test.backward(y, [y, y]))
        # Output length tracks the branch count (2); gradInput length
        # tracks the (possibly shorter) input table length.
        self.assertEqual(o1, 2)
        self.assertEqual(go1, 2)
        self.assertEqual(o2, 2)
        self.assertEqual(go2, 1)
# ---- Example #2 ----
    def build_net(self):
        """Assemble the full network and return it as an nn.Sequential.

        Stage one runs two branches in parallel: a feed-forward MLP that
        produces v1, v2, and a Narrow branch that keeps the first
        `output_size` inputs (r1, r2), discarding the pot-size entry.
        Stage two passes v1, v2 through unchanged and, in parallel,
        computes -0.5 * (r1*v1 + r2*v2) replicated across the output
        slots; the final CAddTable adds that correction so the
        counterfactual values satisfy the zero-sum property.
        """
        # Stage one: MLP branch and Narrow branch side by side.
        stage_one = nn.ConcatTable()

        # MLP: input -> hidden (repeated) -> output, PReLU after each
        # hidden Linear.
        mlp = nn.Sequential()
        mlp.add(nn.Linear(self.input_size, self.hidden_layer_size))
        mlp.add(nn.PReLU())
        for _ in range(self.hidden_layer_count - 1):
            mlp.add(nn.Linear(self.hidden_layer_size, self.hidden_layer_size))
            mlp.add(nn.PReLU())
        mlp.add(nn.Linear(self.hidden_layer_size, self.output_size))

        # Narrow branch: keep r1, r2 (drop the pot-size input).
        narrow_branch = nn.Sequential()
        narrow_branch.add(nn.Narrow(1, 0, self.output_size))

        stage_one.add(mlp)
        stage_one.add(narrow_branch)

        # Stage two: forward v1, v2 untouched and compute the -0.5k term.
        stage_two = nn.ConcatTable()

        # Pass-through branch: select the MLP output (v1, v2).
        passthrough = nn.Sequential()
        passthrough.add(nn.SelectTable(0))

        # Correction branch: -0.5 * dot((r1, r2), (v1, v2)), replicated
        # to one value per output slot.
        correction = nn.Sequential()
        correction.add(nn.DotProduct())
        correction.add(nn.Unsqueeze(1))
        correction.add(nn.Replicate(self.output_size, 1))
        correction.add(nn.Squeeze(2))
        correction.add(nn.MulConstant(-0.5))

        stage_two.add(passthrough)
        stage_two.add(correction)

        # Chain the stages; CAddTable yields v1 - 0.5k, v2 - 0.5k.
        net = nn.Sequential()
        net.add(stage_one)
        net.add(stage_two)
        net.add(nn.CAddTable())

        return net