def test_logsumexp(self):
    """lmu.logsumexp matches torch.logsumexp, with and without add_one / keep_mask."""
    mat = torch.FloatTensor(
        [
            [-1, 0, 1, 10, 50],
            [-30, -20, 0, 20, 30],
            [10, 20, 30, 40, 50],
            [0, 0, 0, 0, 0],
        ]
    )

    # No mask, no extra zero column: must equal torch.logsumexp exactly.
    self.assertTrue(
        torch.equal(
            lmu.logsumexp(mat, keep_mask=None, add_one=False, dim=1),
            torch.logsumexp(mat, dim=1, keepdim=True),
        )
    )

    # add_one=True behaves as if a column of zeros were appended first.
    zeros_col = torch.zeros(mat.size(0)).unsqueeze(1)
    self.assertTrue(
        torch.equal(
            lmu.logsumexp(mat, keep_mask=None, add_one=True, dim=1),
            torch.logsumexp(torch.cat([mat, zeros_col], dim=1), dim=1, keepdim=True),
        )
    )

    # keep_mask restricts the reduction to the masked-in entries of each
    # row; an all-zero mask row is expected to produce 0.
    keep_mask = torch.FloatTensor(
        [
            [1, 1, 0, 0, 0],
            [0, 1, 1, 1, 0],
            [0, 0, 0, 0, 0],
            [0, 1, 1, 0, 0],
        ]
    )
    masked_result = lmu.logsumexp(mat, keep_mask=keep_mask, add_one=False, dim=1)
    expected_rows = [
        torch.log(torch.sum(torch.exp(torch.FloatTensor([-1, 0])))).unsqueeze(0),
        torch.log(torch.sum(torch.exp(torch.FloatTensor([-20, 0, 20])))).unsqueeze(0),
        torch.FloatTensor([0.0]),
        torch.log(torch.sum(torch.exp(torch.FloatTensor([0, 0])))).unsqueeze(0),
    ]
    self.assertTrue(
        torch.allclose(masked_result, torch.stack(expected_rows, dim=0))
    )
def test_logsumexp(self):
    """Same logsumexp checks across dtypes/devices, with fp16-friendly scaling."""
    for dtype in [torch.float16, torch.float32, torch.float64]:
        # Shrink the inputs for fp16 so exp() stays in range.
        scale = 10 if dtype == torch.float16 else 1
        mat = torch.tensor(
            [
                [-1, 0, 1, 10, 50],
                [-30, -20, 0, 20, 30],
                [10, 20, 30, 40, 50],
                [0, 0, 0, 0, 0],
            ],
            dtype=dtype,
        ).to(self.device)
        mat /= scale

        # No mask: must equal torch.logsumexp exactly.
        out = lmu.logsumexp(mat, keep_mask=None, add_one=False, dim=1)
        self.assertTrue(torch.equal(out, torch.logsumexp(mat, dim=1, keepdim=True)))

        # add_one=True behaves as if a column of zeros were appended first.
        out = lmu.logsumexp(mat, keep_mask=None, add_one=True, dim=1)
        zeros_col = torch.zeros(mat.size(0), dtype=dtype).to(self.device).unsqueeze(1)
        reference = torch.logsumexp(
            torch.cat([mat, zeros_col], dim=1), dim=1, keepdim=True
        )
        self.assertTrue(torch.equal(out, reference))

        # keep_mask restricts the reduction to the masked-in entries of
        # each row; an all-zero mask row is expected to produce 0.
        keep_mask = torch.tensor(
            [
                [1, 1, 0, 0, 0],
                [0, 1, 1, 1, 0],
                [0, 0, 0, 0, 0],
                [0, 1, 1, 0, 0],
            ],
            dtype=dtype,
        ).to(self.device)
        out = lmu.logsumexp(mat, keep_mask=keep_mask, add_one=False, dim=1)

        def row_lse(values):
            # naive log-sum-exp over the kept entries of one row
            t = torch.tensor(values, dtype=dtype).to(self.device) / scale
            return torch.log(torch.sum(torch.exp(t))).unsqueeze(0)

        expected = torch.stack(
            [
                row_lse([-1, 0]),
                row_lse([-20, 0, 20]),
                torch.tensor([0.0], dtype=dtype).to(self.device),
                row_lse([0, 0]),
            ],
            dim=0,
        )
        # fp16 needs a looser tolerance.
        rtol = 1e-2 if dtype == torch.float16 else 1e-5
        self.assertTrue(torch.allclose(out, expected, rtol=rtol))
def test_logsumexp(self):
    """Check lmu.logsumexp forward values and that backward runs, per dtype.

    Covers three configurations: no mask, add_one=True, and a boolean
    keep_mask (including an all-False row, which should produce 0).
    """
    for dtype in TEST_DTYPES:
        rtol = 1e-2 if dtype == torch.float16 else 1e-5
        # Build the tensor directly on the target device so `mat` is a
        # leaf tensor and the backward() calls below actually accumulate
        # into mat.grad. The previous
        # torch.tensor(..., requires_grad=True).to(self.device) produced
        # a non-leaf copy on non-CPU devices, silently discarding the
        # gradients this test is meant to exercise.
        mat = torch.tensor(
            [
                [-1, 0, 1, 10, 50],
                [-300, -200, -100, -50, -20],
                [-300, -200, 0, 200, 300],
                [100, 200, 300, 400, 500],
                [0, 0, 0, 0, 0],
            ],
            dtype=dtype,
            device=self.device,
            requires_grad=True,
        )

        # No mask: must match torch.logsumexp; backward must not raise.
        result = lmu.logsumexp(mat, keep_mask=None, add_one=False, dim=1)
        torch.mean(result).backward(retain_graph=True)
        correct_result = torch.logsumexp(mat, dim=1, keepdim=True)
        self.assertTrue(torch.allclose(result, correct_result, rtol=rtol))

        # add_one=True behaves as if a column of zeros were appended first.
        result = lmu.logsumexp(mat, keep_mask=None, add_one=True, dim=1)
        torch.mean(result).backward(retain_graph=True)
        zeros_col = torch.zeros(
            mat.size(0), dtype=dtype, device=self.device
        ).unsqueeze(1)
        correct_result = torch.logsumexp(
            torch.cat([mat, zeros_col], dim=1), dim=1, keepdim=True
        )
        self.assertTrue(torch.allclose(result, correct_result, rtol=rtol))

        # keep_mask restricts the reduction to the masked-in entries of
        # each row; an all-False mask row is expected to produce 0.
        keep_mask = torch.tensor(
            [
                [1, 1, 0, 0, 0],
                [1, 1, 1, 1, 1],
                [0, 1, 1, 1, 0],
                [0, 0, 0, 0, 0],
                [0, 1, 1, 0, 0],
            ],
            dtype=torch.bool,
            device=self.device,
        )
        result = lmu.logsumexp(mat, keep_mask=keep_mask, add_one=False, dim=1)
        torch.mean(result).backward()

        def masked_lse(values):
            # reference log-sum-exp over the kept entries of one row
            t = torch.tensor(values, dtype=dtype, device=self.device)
            return torch.logsumexp(t, dim=0).unsqueeze(0)

        correct_result = torch.stack(
            [
                masked_lse([-1, 0]),
                masked_lse([-300, -200, -100, -50, -20]),
                masked_lse([-200, 0, 200]),
                torch.tensor([0.0], dtype=dtype, device=self.device),
                masked_lse([0, 0]),
            ],
            dim=0,
        )
        self.assertTrue(torch.allclose(result, correct_result, rtol=rtol))
        # mat is a leaf, so the three backward passes populated its grad.
        self.assertIsNotNone(mat.grad)