def test_ConcatTable(self):
    """Exercise legacy nn.ConcatTable with nested-table input, nested
    gradOutput, and variable-length input tables."""
    # Mixed input table: two plain tensors plus a nested one-tensor table.
    input = [
        torch.randn(3, 4).float(),
        torch.randn(3, 4).float(),
        [torch.randn(3, 4).float()],
    ]
    _gradOutput = [
        torch.randn(3, 3, 4).float(),
        torch.randn(3, 3, 4).float(),
        torch.randn(3, 3, 4).float(),
    ]
    # One gradOutput table per ConcatTable branch, each mirroring the
    # nesting structure of the input.
    gradOutput = [
        [_gradOutput[0][j], _gradOutput[1][j], [_gradOutput[2][j]]]
        for j in range(3)
    ]

    module = nn.ConcatTable()
    for _ in range(3):
        module.add(nn.Identity())
    module.float()

    # Check that these don't raise errors
    module.__repr__()
    str(module)

    # Each Identity branch should echo the input table unchanged.
    output = module.forward(input)
    self.assertEqual([input, input, input], output)

    # backward accumulates the per-branch gradients, preserving nesting.
    gradInput = module.backward(input, gradOutput)
    expected = [
        _gradOutput[0].sum(0, keepdim=False),
        _gradOutput[1].sum(0, keepdim=False),
        [_gradOutput[2].sum(0, keepdim=False)],
    ]
    self.assertTrue(isinstance(gradInput, list))
    self.assertFalse(isinstance(gradInput[0], list))
    self.assertFalse(isinstance(gradInput[1], list))
    self.assertTrue(isinstance(gradInput[2], list))
    self.assertEqual(len(gradInput), 3)
    self.assertEqual(len(gradInput[2]), 1)
    for actual, wanted in zip(iter_tensors(gradInput), iter_tensors(expected)):
        self.assertEqual(actual, wanted)

    # test outputs for variable length inputs
    test = nn.ConcatTable()
    test.add(nn.Identity())
    test.add(nn.Identity())

    x = [torch.randn(5), torch.randn(5)]
    y = [torch.randn(5)]

    o1 = len(test.forward(x))
    go1 = len(test.backward(x, [x, x]))
    o2 = len(test.forward(y))
    go2 = len(test.backward(y, [y, y]))

    self.assertEqual(o1, 2)
    self.assertEqual(go1, 2)
    self.assertEqual(o2, 2)
    # A shorter input table yields a correspondingly shorter gradInput.
    self.assertEqual(go2, 1)
def build_net(self):
    """Assemble the value network as a legacy-nn module graph.

    Stage 1 runs the input through two parallel branches: a feed-forward
    MLP producing raw values (v1, v2) and a pass-through that keeps the
    leading slice of the input (r1, r2), discarding the rest (pot size).
    Stage 2 forces the counterfactual values to satisfy the zero-sum
    property by subtracting 0.5*k, where k = r1*v1 + r2*v2.

    Returns:
        nn.Sequential mapping the input to zero-sum-corrected values
        (v1 - 0.5k, v2 - 0.5k).
    """
    # --- Stage 1: parallel branches over the raw input ------------------
    stage_one = nn.ConcatTable()

    # Branch A: fully-connected net producing raw values v1, v2.
    mlp = nn.Sequential()
    mlp.add(nn.Linear(self.input_size, self.hidden_layer_size))
    mlp.add(nn.PReLU())
    # Additional hidden layers beyond the first.
    for _ in range(self.hidden_layer_count - 1):
        mlp.add(nn.Linear(self.hidden_layer_size, self.hidden_layer_size))
        mlp.add(nn.PReLU())
    mlp.add(nn.Linear(self.hidden_layer_size, self.output_size))

    # Branch B: keep only the first output_size columns (r1, r2),
    # discarding pot_size.
    passthrough = nn.Sequential()
    passthrough.add(nn.Narrow(1, 0, self.output_size))

    stage_one.add(mlp)
    stage_one.add(passthrough)

    # --- Stage 2: zero-sum correction -----------------------------------
    stage_two = nn.ConcatTable()

    # Keep the raw values (v1, v2); the ranges are only needed below.
    values_branch = nn.Sequential()
    values_branch.add(nn.SelectTable(0))

    # Correction term -0.5k = -0.5*(r1*v1 + r2*v2), replicated across the
    # output dimension so it can be added element-wise.
    correction_branch = nn.Sequential()
    correction_branch.add(nn.DotProduct())
    correction_branch.add(nn.Unsqueeze(1))
    correction_branch.add(nn.Replicate(self.output_size, 1))
    correction_branch.add(nn.Squeeze(2))
    correction_branch.add(nn.MulConstant(-0.5))

    stage_two.add(values_branch)
    stage_two.add(correction_branch)

    # --- Final graph: stage 1 -> stage 2 -> element-wise sum ------------
    net = nn.Sequential()
    net.add(stage_one)
    net.add(stage_two)
    # Produces v1 - 0.5k, v2 - 0.5k.
    net.add(nn.CAddTable())
    return net