Beispiel #1
0
class TestLongTensorEncoder(unittest.TestCase):
    """TreeEncoder tests where each node value is encoded as a LongTensor."""

    def __init__(self, *args, **kwargs):
        super(TestLongTensorEncoder, self).__init__(*args, **kwargs)
        # A node (v, children) is encoded as the LongTensor [v, 2 * v, v + 1];
        # its children are read from the second tuple slot.
        self.encoder = TreeEncoder(
            lambda x: torch.LongTensor([x[0], x[0] * 2, x[0] + 1]),
            lambda x: x[1])

    def test_encode_tree(self):
        """Rows come out in reverse pre-order: every node after its subtree."""
        tree = (42, [(17, []), (12, [(10, [(53, []), (15, [(8, [])])])]),
                     (28, [])])
        values, arities = self.encoder.encode(tree)

        # The encoder must hand back plain Tensors, not a subclass.
        self.assertIs(type(values), torch.Tensor)
        self.assertIs(type(arities), torch.Tensor)

        self.assertEqual(values.size(), torch.Size((8, 3)))
        self.assertEqual(
            values.tolist(),
            [[28, 56, 29], [8, 16, 9], [15, 30, 16], [53, 106, 54],
             [10, 20, 11], [12, 24, 13], [17, 34, 18], [42, 84, 43]])

        # Arities follow the same node order as the rows of `values`.
        self.assertEqual(arities.size(), torch.Size((8, )))
        self.assertEqual(arities.tolist(), [0, 0, 1, 0, 2, 1, 0, 3])

    def test_encode_batch(self):
        """Batch encoding yields plain Tensors laid out (nodes, batch, ...)."""
        tree1 = (42, [(17, []), (12, [(10, [(53, []), (15, [(8, [])])])]),
                      (28, [])])
        tree2 = (12, [(17, [(8, [])]), (16, [(19, [(61, []), (10, [])])]),
                      (56, [])])
        values, arities = self.encoder.encode_batch([tree1, tree2])

        self.assertIs(type(values), torch.Tensor)
        self.assertIs(type(arities), torch.Tensor)

        # Both trees have 8 nodes, and each node encodes to 3 numbers.
        self.assertEqual(values.size(), torch.Size((8, 2, 3)))
        self.assertEqual(arities.size(), torch.Size((8, 2)))
Beispiel #2
0
    def test_batch(self):
        """A batch of four trees yields one 7-wide output row per tree."""
        net = TreeLSTM(3, 7, 2)
        # Node payloads are already 3-tuples, so they pass through unchanged.
        encoder = TreeEncoder(lambda x: x[0], lambda x: x[1])

        # Trees of 3, 4, 2 and 1 nodes, so the batch is ragged.
        tree1 = ((1, 2, 3), [((4, 5, 6), []), ((7, 8, 9), [])])
        tree2 = ((11, 12, 13), [((14, 15, 16), [((17, 18, 19), [])]),
                                ((20, 21, 22), [])])
        tree3 = ((21, 22, 23), [((24, 25, 26), [])])
        tree4 = ((31, 32, 33), [])

        inputs, arities = encoder.encode_batch([tree1, tree2, tree3, tree4])

        # Call the module itself rather than .forward() so that any
        # registered hooks run as in normal use (assumes TreeLSTM is an
        # nn.Module, like the other nets in this file).
        result = net(inputs, arities)

        self.assertEqual(result.size(), torch.Size([4, 7]))
Beispiel #3
0
class TestTreeNetLinearSumUnit(unittest.TestCase):
    """TreeNet driven by a LinearSumUnit on scalar-valued trees."""

    def __init__(self, *args, **kwargs):
        super(TestTreeNetLinearSumUnit, self).__init__(*args, **kwargs)

        # Each node's number is wrapped in a 1-tuple so it encodes as a
        # single feature; children are read from the second tuple slot.
        self.encoder = TreeEncoder(lambda x: (x[0],), lambda x: x[1])

    def test_single(self):
        """One tree gives a (1, 3) output and backward() reaches fc1's bias."""
        net = TreeNet(3, 4, LinearSumUnit(1, 3))
        tree = (12, [(1, []),
                     (14, [(17, []), (29, [])]),
                     (12, []),
                     (70, [])])

        inputs, arities = self.encoder.encode_batch([tree])
        outputs = net(inputs, arities)

        self.assertEqual(outputs.size(), torch.Size([1, 3]))

        # Note: 155 equals the sum of every node value in `tree`.
        target = torch.Tensor([[155, 12, 72]])
        loss = nn.functional.mse_loss(outputs, target)
        net.zero_grad()
        # No backward pass has run yet, so the bias is expected to be unset.
        self.assertIsNone(net.unit.fc1.bias.grad)
        loss.backward()
        self.assertIsNotNone(net.unit.fc1.bias.grad)
Beispiel #4
0
class TestTreeNetInterpreterUnit(unittest.TestCase):
    """TreeNet with an InterpreterUnit evaluating small arithmetic trees."""

    def __init__(self, *args, **kwargs):
        super(TestTreeNetInterpreterUnit, self).__init__(*args, **kwargs)

        def value(tree):
            """Encode one node as a 4-slot int vector.

            Slots 0-2 are one-hot flags for '+', '*' and '-'. Any other node
            is treated as a number literal and stored in slot 3.
            """
            out = torch.zeros(4, dtype=torch.int)

            op = tree[0]
            if op == '+':
                out[0] = 1
            elif op == '*':
                out[1] = 1
            elif op == '-':
                out[2] = 1
            else:
                out[3] = op

            return out

        def children(tree):
            # Children are the second element of each (op, children) node.
            return tree[1]

        self.encoder = TreeEncoder(value, children)

    def test_batch(self):
        """Each tree in the batch evaluates to its arithmetic result."""

        # 42  ==>  42
        tree1 = (42, [])

        # (3 * -4) + (10 - 9)  ==>  -11
        # A '-' node with one child is negation, with two is subtraction.
        tree2 = ('+', [('*', [(3, []), ('-', [(4, [])])]),
                       ('-', [(10, []),
                              (9, [])])])

        # -(17 - (3 + 4)) * 2  ==>  -20
        tree3 = ('*', [('-', [('-', [(17, []),
                                     ('+', [(3, []),
                                            (4, [])])])]),
                       (2, [])])

        net = TreeNet(1, 2, InterpreterUnit())

        inputs, arities = self.encoder.encode_batch([tree1, tree2, tree3])
        outputs = net(inputs, arities)

        self.assertEqual(outputs.size(), torch.Size([3, 1]))

        # .item() gives plain Python numbers instead of 0-dim tensors
        # reached through the deprecated .data attribute.
        self.assertEqual(outputs[0, 0].item(), 42)
        self.assertEqual(outputs[1, 0].item(), -11)
        self.assertEqual(outputs[2, 0].item(), -20)
Beispiel #5
0
    def __init__(self, *args, **kwargs):
        """Attach a TreeEncoder for (op, children) arithmetic trees."""
        super(TestTreeNetInterpreterUnit, self).__init__(*args, **kwargs)

        def value(tree):
            """Encode one node as a 4-slot int vector.

            Slots 0-2 are one-hot flags for '+', '*' and '-'. Any other node
            is treated as a number literal and stored in slot 3.
            """
            # Zero-initialised int32 vector; replaces the legacy
            # torch.IntTensor(4) + fill_(0) pair.
            out = torch.zeros(4, dtype=torch.int)

            op = tree[0]
            if op == '+':
                out[0] = 1
            elif op == '*':
                out[1] = 1
            elif op == '-':
                out[2] = 1
            else:
                out[3] = op

            return out

        def children(tree):
            # Children are the second element of each (op, children) node.
            return tree[1]

        self.encoder = TreeEncoder(value, children)
Beispiel #6
0
class TestEquivalenceTreeNetBasic(unittest.TestCase):
    """TreeNet and BasicTreeNet built on one shared unit must agree.

    Both the forward outputs and the gradients they induce in the shared
    unit's parameters are compared.
    """

    def __init__(self, *args, **kwargs):
        super(TestEquivalenceTreeNetBasic, self).__init__(*args, **kwargs)
        # Nodes carry their feature tensor in slot 0 and children in slot 1.
        self.encoder = TreeEncoder(lambda x: x[0], lambda x: x[1])

    def test_equivalence(self):
        # Tolerance for float32 results, which the two nets may compute with
        # a different operation order.
        epsilon = 0.001

        def t(elems):
            return torch.Tensor(elems)

        tree = (t([1, 2, 3]), [(t([4, 5, 6]), []),
                               (t([7, 8, 9]), [(t([10, 11, 12]), [])]),
                               (t([13, 14, 15]), [(t([16, 17, 18]), []),
                                                  (t([19, 20, 21]),
                                                   [(t([22, 23, 24]), [])])])])

        # Sharing one unit instance gives both nets identical parameters.
        unit = LinearSumUnit(3, 5)

        net = TreeNet(5, 3, unit)
        inputs, arities = self.encoder.encode_batch([tree])
        res_net = net(inputs, arities)

        basic = BasicTreeNet(5, 3, unit)
        res_basic = basic(tree)

        # Test forward direction. Compare Python floats via .item():
        # assertAlmostEqual rounds the difference with round(), which is not
        # reliably supported on 0-dim tensors.
        self.assertEqual(res_net.size(), res_basic.size())
        for i in range(res_net.size(1)):
            self.assertAlmostEqual(res_net[0, i].item(),
                                   res_basic[0, i].item(),
                                   delta=epsilon)

        # Test backward direction.
        target = torch.randn(1, 5)

        # Gradients after backpropagating through net.
        unit.zero_grad()
        loss_net = nn.functional.mse_loss(res_net, target)
        loss_net.backward()
        grads_net = [p.grad.detach().clone() for p in unit.parameters()]

        # Gradients after backpropagating through basic.
        unit.zero_grad()
        loss_basic = nn.functional.mse_loss(res_basic, target)
        loss_basic.backward()
        grads_basic = [p.grad.detach().clone() for p in unit.parameters()]

        # Every gradient entry must agree within epsilon; reporting the max
        # difference makes a failure actionable.
        for grad_net, grad_basic in zip(grads_net, grads_basic):
            max_diff = (grad_net - grad_basic).abs().max().item()
            self.assertLess(max_diff, epsilon)
Beispiel #7
0
class TestTupleEncoder(unittest.TestCase):
    """TreeEncoder tests where each node value is encoded from a tuple."""

    def __init__(self, *args, **kwargs):
        super(TestTupleEncoder, self).__init__(*args, **kwargs)
        # A node (v, children) is encoded as the 3-tuple (v, 2 * v, v + 1).
        self.encoder = TreeEncoder(lambda x: (x[0], x[0] * 2, x[0] + 1),
                                   lambda x: x[1])

    def test_encode_tree(self):
        """Rows come out in reverse pre-order: every node after its subtree."""
        tree = (42, [(17, []), (12, [(10, [(53, []), (15, [(8, [])])])]),
                     (28, [])])
        values, arities = self.encoder.encode(tree)
        self.assertEqual(values.size(), torch.Size((8, 3)))
        self.assertEqual(
            values.tolist(),
            [[28, 56, 29], [8, 16, 9], [15, 30, 16], [53, 106, 54],
             [10, 20, 11], [12, 24, 13], [17, 34, 18], [42, 84, 43]])

        # Arities follow the same node order as the rows of `values`.
        self.assertEqual(arities.size(), torch.Size((8, )))
        self.assertEqual(arities.tolist(), [0, 0, 1, 0, 2, 1, 0, 3])
Beispiel #8
0
class TestListEncoder(unittest.TestCase):
    """TreeEncoder tests where each node value is encoded from a list."""

    # Expected per-tree node rows for the trees from _batch_trees(). Shorter
    # trees are padded to 8 nodes with [0, 0] rows.
    _BATCH_VALUES = (
        [[9, 18], [81, 162], [11, 22], [0, 0], [0, 0], [0, 0], [0, 0],
         [0, 0]],
        [[28, 56], [8, 16], [15, 30], [53, 106], [10, 20], [12, 24],
         [17, 34], [42, 84]],
        [[70, 140], [99, 198], [79, 158], [11, 22], [14, 28], [19, 38],
         [18, 36], [0, 0]],
    )

    # Expected per-tree arities for the same trees; padding slots hold -1.
    _BATCH_ARITIES = (
        [0, 0, 2, -1, -1, -1, -1, -1],
        [0, 0, 1, 0, 2, 1, 0, 3],
        [0, 0, 0, 1, 0, 0, 5, -1],
    )

    def __init__(self, *args, **kwargs):
        super(TestListEncoder, self).__init__(*args, **kwargs)
        # A node (v, children) is encoded as the list [v, 2 * v].
        self.encoder = TreeEncoder(lambda x: [x[0], x[0] * 2], lambda x: x[1])

    @staticmethod
    def _batch_trees():
        """Build the three trees shared by the batch tests (fresh per call)."""
        shallow = (11, [(81, []), (9, [])])
        deep = (42, [(17, []), (12, [(10, [(53, []), (15, [(8, [])])])]),
                     (28, [])])
        wide = (18, [(19, []), (14, []), (11, [(79, [])]), (99, []),
                     (70, [])])
        return [shallow, deep, wide]

    def test_encode_single_node(self):
        """A childless tree encodes to one row with arity 0."""
        values, arities = self.encoder.encode((42, []))

        self.assertEqual(values.size(), torch.Size((1, 2)))
        self.assertEqual(values[0, 0].item(), 42)
        self.assertEqual(values[0, 1].item(), 84)

        self.assertEqual(arities.size(), torch.Size((1, )))
        self.assertEqual(arities[0].item(), 0)

    def test_encode_tree(self):
        """Rows come out in reverse pre-order: every node after its subtree."""
        nested = (42, [(17, []), (12, [(10, [(53, []), (15, [(8, [])])])]),
                       (28, [])])
        values, arities = self.encoder.encode(nested)

        self.assertEqual(values.size(), torch.Size((8, 2)))
        self.assertEqual(values.tolist(), self._BATCH_VALUES[1])

        self.assertEqual(arities.size(), torch.Size((8, )))
        self.assertEqual(arities.tolist(), self._BATCH_ARITIES[1])

    def test_encode_batch(self):
        """Default layout puts the node axis first: (nodes, batch, ...)."""
        values, arities = self.encoder.encode_batch(self._batch_trees())

        self.assertEqual(values.size(), torch.Size((8, 3, 2)))
        for column, expected in enumerate(self._BATCH_VALUES):
            self.assertEqual(values[:, column, :].tolist(), expected)

        self.assertEqual(arities.size(), torch.Size((8, 3)))
        for column, expected in enumerate(self._BATCH_ARITIES):
            self.assertEqual(arities[:, column].tolist(), expected)

    def test_encode_batch_batch_first(self):
        """With batch_first=True the batch axis leads: (batch, nodes, ...)."""
        values, arities = self.encoder.encode_batch(self._batch_trees(),
                                                    batch_first=True)

        self.assertEqual(values.size(), torch.Size((3, 8, 2)))
        for row, expected in enumerate(self._BATCH_VALUES):
            self.assertEqual(values[row].tolist(), expected)

        self.assertEqual(arities.size(), torch.Size((3, 8)))
        for row, expected in enumerate(self._BATCH_ARITIES):
            self.assertEqual(arities[row].tolist(), expected)
Beispiel #9
0
 def __init__(self, *args, **kwargs):
     """Attach a TreeEncoder that encodes node values as LongTensors."""
     super(TestLongTensorEncoder, self).__init__(*args, **kwargs)

     def as_long_tensor(node):
         # A node's number n becomes the LongTensor [n, 2 * n, n + 1].
         number = node[0]
         return torch.LongTensor([number, number * 2, number + 1])

     def child_list(node):
         return node[1]

     self.encoder = TreeEncoder(as_long_tensor, child_list)
Beispiel #10
0
 def __init__(self, *args, **kwargs):
     """Attach a TreeEncoder that encodes node values as 3-tuples."""
     super(TestTupleEncoder, self).__init__(*args, **kwargs)

     def as_triple(node):
         # A node's number n becomes the tuple (n, 2 * n, n + 1).
         number = node[0]
         return (number, number * 2, number + 1)

     def child_list(node):
         return node[1]

     self.encoder = TreeEncoder(as_triple, child_list)
Beispiel #11
0
 def __init__(self, *args, **kwargs):
     """Attach a TreeEncoder that encodes node values as 2-element lists."""
     super(TestListEncoder, self).__init__(*args, **kwargs)

     def as_pair(node):
         # A node's number n becomes the list [n, 2 * n].
         number = node[0]
         return [number, number * 2]

     def child_list(node):
         return node[1]

     self.encoder = TreeEncoder(as_pair, child_list)
Beispiel #12
0
    def __init__(self, *args, **kwargs):
        """Attach a TreeEncoder that wraps each node value in a 1-tuple."""
        super(TestTreeNetLinearSumUnit, self).__init__(*args, **kwargs)

        def as_singleton(node):
            return (node[0],)

        def child_list(node):
            return node[1]

        self.encoder = TreeEncoder(as_singleton, child_list)
Beispiel #13
0
 def __init__(self, *args, **kwargs):
     """Attach a TreeEncoder that passes node payloads through unchanged."""
     super(TestEquivalenceTreeNetBasic, self).__init__(*args, **kwargs)

     def payload(node):
         return node[0]

     def child_list(node):
         return node[1]

     self.encoder = TreeEncoder(payload, child_list)