    def test_always_single_batch_shape(self):
        a = ltd.TermDynamic(torch.randn(1, 100, 2))
        b = ltd.TermDynamic(torch.randn(1, 100, 2))
        con = ltd.Always(ltd.EQ(a, b), 100)
        self.assertEqual(con.loss(0).shape, torch.Size([1]))
        self.assertEqual(con.satisfy(0).shape, torch.Size([1]))
    def test_always_equal(self):
        a = ltd.TermDynamic(torch.ones(64, 10000, 7))
        b = ltd.TermDynamic(torch.ones(64, 10000, 7))
        con = ltd.Always(ltd.EQ(a, b), 10000)
        actual_loss = con.loss(0)
        self.assertEqual(actual_loss[0], 0.0)
        self.assertEqual(con.satisfy(0).all(), True)
    def test_always_unequal_huge(self):
        a = ltd.TermDynamic(torch.ones(64, 10000, 7))
        b = ltd.TermDynamic(torch.ones(64, 10000, 7) + 1000)
        con = ltd.Always(ltd.EQ(a, b), 10000)
        expected_loss = torch.ones(64) * np.sqrt(7000000)
        actual_loss = con.loss(0)
        # With values this large (~2645.75), the absolute floating-point
        # error is noticeable, so we only check to 1 decimal place.
        self.assertAlmostEqual(actual_loss[0].item(), expected_loss[0].item(), places=1)
        self.assertEqual(con.satisfy(0).all(), False)
    def test_always_unequal_stress(self):
        a = ltd.TermDynamic(torch.ones(64, 10000, 7))
        b = ltd.TermDynamic(torch.ones(64, 10000, 7) + 1)
        con = ltd.Always(ltd.EQ(a, b), 10000)
        expected_loss = torch.ones(64) * np.sqrt(7)
        actual_loss = con.loss(0)
        # Note that the numerical errors start to mount up over the 10000
        # timesteps; 4 decimal places is about the accuracy we can expect here.
        self.assertAlmostEqual(actual_loss[0].item(), expected_loss[0].item(), places=4)
        self.assertEqual(con.satisfy(0).all(), False)
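    # A sketch of where the expected-loss constants above come from, assuming
    # (as the constants themselves suggest, rather than as documented ltd
    # behaviour) that ltd.EQ measures L2 distance across the trailing feature
    # dimension. With a uniform per-feature gap `delta` over 7 features, every
    # timestep then contributes the same loss:
    #
    #     np.sqrt(7 * 1000 ** 2)  # sqrt(7_000_000) ~= 2645.75 (unequal_huge)
    #     np.sqrt(7 * 1 ** 2)     # sqrt(7)         ~= 2.6458  (unequal_stress)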