Example #1
    def test_logsumexp(self):
        mat = torch.FloatTensor([[-1, 0, 1, 10, 50], [-30, -20, 0, 20, 30],
                                 [10, 20, 30, 40, 50], [0, 0, 0, 0, 0]])
        result = lmu.logsumexp(mat, keep_mask=None, add_one=False, dim=1)
        correct_result = torch.logsumexp(mat, dim=1, keepdim=True)
        self.assertTrue(torch.equal(result, correct_result))

        result = lmu.logsumexp(mat, keep_mask=None, add_one=True, dim=1)
        correct_result = torch.logsumexp(torch.cat(
            [mat, torch.zeros(mat.size(0)).unsqueeze(1)], dim=1),
                                         dim=1,
                                         keepdim=True)
        self.assertTrue(torch.equal(result, correct_result))

        keep_mask = torch.FloatTensor([[1, 1, 0, 0, 0], [0, 1, 1, 1, 0],
                                       [0, 0, 0, 0, 0], [0, 1, 1, 0, 0]])
        result = lmu.logsumexp(mat, keep_mask=keep_mask, add_one=False, dim=1)

        row0 = torch.log(torch.sum(torch.exp(torch.FloatTensor(
            [-1, 0])))).unsqueeze(0)
        row1 = torch.log(torch.sum(torch.exp(torch.FloatTensor(
            [-20, 0, 20])))).unsqueeze(0)
        row2 = torch.FloatTensor([0.])
        row3 = torch.log(torch.sum(torch.exp(torch.FloatTensor(
            [0, 0])))).unsqueeze(0)
        correct_result = torch.stack([row0, row1, row2, row3], dim=0)
        self.assertTrue(torch.allclose(result, correct_result))
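All of the snippets on this page exercise lmu.logsumexp (in pytorch-metric-learning, lmu is the loss_and_miner_utils module). As a rough orientation, here is a minimal sketch of the behavior the assertions imply; it is an assumption for illustration, not the library's actual implementation: masked-out entries are pushed to -inf so they vanish from the sum, add_one=True folds an extra exp(0) = 1 into it, and rows whose keep_mask is all zeros come back as 0.

    import torch

    def masked_logsumexp(x, keep_mask=None, add_one=False, dim=1):
        # Hypothetical sketch of what the tests above assert; the real
        # lmu.logsumexp may be implemented differently. Assumes a 2-D
        # input reduced along dim=1.
        if keep_mask is not None:
            # exp(-inf) = 0, so masked-out entries drop out of the sum.
            x = x.masked_fill(keep_mask == 0, float("-inf"))
        if add_one:
            # Appending a zero column turns the result into log(1 + sum(exp(x))).
            zeros = torch.zeros(x.size(0), 1, dtype=x.dtype, device=x.device)
            x = torch.cat([x, zeros], dim=dim)
        out = torch.logsumexp(x, dim=dim, keepdim=True)
        if keep_mask is not None:
            # A row with an all-zero mask reduces to -inf; the tests expect 0.
            out = out.masked_fill(out == float("-inf"), 0)
        return out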
Example #2
    def test_logsumexp(self):
        for dtype in [torch.float16, torch.float32, torch.float64]:
            division_factor = 10 if dtype == torch.float16 else 1
            mat = torch.tensor([[-1, 0, 1, 10, 50],
                                [-30, -20, 0, 20, 30],
                                [10, 20, 30, 40, 50],
                                [0, 0, 0, 0, 0]], dtype=dtype).to(self.device)
            mat /= division_factor
            result = lmu.logsumexp(mat, keep_mask=None, add_one=False, dim=1)
            correct_result = torch.logsumexp(mat, dim=1, keepdim=True)
            self.assertTrue(torch.equal(result, correct_result))

            result = lmu.logsumexp(mat, keep_mask=None, add_one=True, dim=1)
            correct_result = torch.logsumexp(torch.cat(
                [mat, torch.zeros(mat.size(0), dtype=dtype).to(self.device).unsqueeze(1)],
                dim=1), dim=1, keepdim=True)
            self.assertTrue(torch.equal(result, correct_result))

            keep_mask = torch.tensor([[1, 1, 0, 0, 0],
                                      [0, 1, 1, 1, 0],
                                      [0, 0, 0, 0, 0],
                                      [0, 1, 1, 0, 0]], dtype=dtype).to(self.device)
            result = lmu.logsumexp(mat, keep_mask=keep_mask, add_one=False, dim=1)

            row0_input = torch.tensor([-1, 0], dtype=dtype).to(self.device) / division_factor
            row1_input = torch.tensor([-20, 0, 20], dtype=dtype).to(self.device) / division_factor
            row3_input = torch.tensor([0, 0], dtype=dtype).to(self.device) / division_factor

            row0 = torch.log(torch.sum(torch.exp(row0_input))).unsqueeze(0)
            row1 = torch.log(torch.sum(torch.exp(row1_input))).unsqueeze(0)
            row2 = torch.tensor([0.], dtype=dtype).to(self.device)
            row3 = torch.log(torch.sum(torch.exp(row3_input))).unsqueeze(0)
            correct_result = torch.stack([row0, row1, row2, row3], dim=0)
            rtol = 1e-2 if dtype == torch.float16 else 1e-5
            self.assertTrue(torch.allclose(result, correct_result, rtol=rtol))
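The division_factor above is what lets the float16 pass survive: the hand-built reference exponentiates directly, and exp(20) ≈ 4.9e8 overflows half precision, whose largest finite value is 65504, whereas torch.logsumexp stays finite by subtracting the row max before exponentiating. A quick standalone check (plain PyTorch, no test fixture assumed):

    import torch

    x = torch.tensor([20.0], dtype=torch.float16)
    print(torch.exp(x))       # tensor([inf], dtype=torch.float16): 4.9e8 > 65504
    print(torch.exp(x / 10))  # finite: exp(2) ≈ 7.39 is representable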
Example #3
    def test_logsumexp(self):
        for dtype in TEST_DTYPES:
            rtol = 1e-2 if dtype == torch.float16 else 1e-5
            mat = torch.tensor([[-1, 0, 1, 10, 50],
                                [-300, -200, -100, -50, -20],
                                [-300, -200, 0, 200, 300],
                                [100, 200, 300, 400, 500],
                                [0, 0, 0, 0, 0]], dtype=dtype, requires_grad=True).to(self.device)
            result = lmu.logsumexp(mat, keep_mask=None, add_one=False, dim=1)
            torch.mean(result).backward(retain_graph=True)
            correct_result = torch.logsumexp(mat, dim=1, keepdim=True)
            self.assertTrue(torch.allclose(result, correct_result, rtol=rtol))

            result = lmu.logsumexp(mat, keep_mask=None, add_one=True, dim=1)
            torch.mean(result).backward(retain_graph=True)
            correct_result = torch.logsumexp(torch.cat(
                [mat, torch.zeros(mat.size(0), dtype=dtype).to(self.device).unsqueeze(1)],
                dim=1), dim=1, keepdim=True)
            self.assertTrue(torch.allclose(result, correct_result, rtol=rtol))

            keep_mask = torch.tensor([[1, 1, 0, 0, 0],
                                      [1, 1, 1, 1, 1],
                                      [0, 1, 1, 1, 0],
                                      [0, 0, 0, 0, 0],
                                      [0, 1, 1, 0, 0]], dtype=torch.bool).to(self.device)
            result = lmu.logsumexp(mat, keep_mask=keep_mask, add_one=False, dim=1)
            torch.mean(result).backward()

            row0_input = torch.tensor([-1, 0], dtype=dtype).to(self.device)
            row1_input = torch.tensor([-300, -200, -100, -50, -20], dtype=dtype).to(self.device)
            row2_input = torch.tensor([-200, 0, 200], dtype=dtype).to(self.device)
            row4_input = torch.tensor([0, 0], dtype=dtype).to(self.device)

            row0 = torch.logsumexp(row0_input, dim=0).unsqueeze(0)
            row1 = torch.logsumexp(row1_input, dim=0).unsqueeze(0)
            row2 = torch.logsumexp(row2_input, dim=0).unsqueeze(0)
            row3 = torch.tensor([0.], dtype=dtype).to(self.device)
            row4 = torch.logsumexp(row4_input, dim=0).unsqueeze(0)
            correct_result = torch.stack([row0, row1, row2, row3, row4], dim=0)
            self.assertTrue(torch.allclose(result, correct_result, rtol=rtol))
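As a closing sanity check on the add_one construction used in all three examples: appending a zero entry before logsumexp equals log(1 + Σ exp(xᵢ)), which is softplus applied to the plain logsumexp, since exp(0) = 1. A small verification in plain PyTorch:

    import torch
    import torch.nn.functional as F

    x = torch.tensor([[-1.0, 0.0, 1.0, 10.0, 50.0]])
    with_zero = torch.logsumexp(torch.cat([x, torch.zeros(1, 1)], dim=1), dim=1)
    plain = torch.logsumexp(x, dim=1)
    print(torch.allclose(with_zero, F.softplus(plain)))  # True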