def testWorksCorrectlyNoBatches(self): with self.test_session(): self.assertAllEqual( [[4., 8., 0., 0.], [1., 5., 9., 0.], [0., 2., 6., 10.], [0., 0., 3, 7.]], distribution_util.tridiag([1., 2., 3.], [4., 5., 6., 7.], [8., 9., 10.]).eval())
def testHandlesNone(self): with self.test_session(): self.assertAllClose( [[[4., 0., 0., 0.], [0., 5., 0., 0.], [0., 0., 6., 0.], [0., 0., 0, 7.]], [[0.7, 0.0, 0.0, 0.0], [0.0, 0.6, 0.0, 0.0], [0.0, 0.0, 0.5, 0.0], [0.0, 0.0, 0.0, 0.4]]], distribution_util.tridiag( diag=[[4., 5., 6., 7.], [0.7, 0.6, 0.5, 0.4]]).eval(), rtol=1e-5, atol=0.)
def testWorksCorrectlyNoBatches(self): with self.test_session(): self.assertAllEqual( [[4., 8., 0., 0.], [1., 5., 9., 0.], [0., 2., 6., 10.], [0., 0., 3, 7.]], distribution_util.tridiag( [1., 2., 3.], [4., 5., 6., 7.], [8., 9., 10.]).eval())
def testWorksCorrectlyBatches(self): with self.test_session(): self.assertAllClose([[[4., 8., 0., 0.], [1., 5., 9., 0.], [0., 2., 6., 10.], [0., 0., 3, 7.]], [[0.7, 0.1, 0.0, 0.0], [0.8, 0.6, 0.2, 0.0], [0.0, 0.9, 0.5, 0.3], [0.0, 0.0, 1.0, 0.4]]], distribution_util.tridiag( [[1., 2., 3.], [0.8, 0.9, 1.]], [[4., 5., 6., 7.], [0.7, 0.6, 0.5, 0.4]], [[8., 9., 10.], [0.1, 0.2, 0.3]]).eval(), rtol=1e-5, atol=0.)
def testHandlesNone(self): with self.test_session(): self.assertAllClose( [[[4., 0., 0., 0.], [0., 5., 0., 0.], [0., 0., 6., 0.], [0., 0., 0, 7.]], [[0.7, 0.0, 0.0, 0.0], [0.0, 0.6, 0.0, 0.0], [0.0, 0.0, 0.5, 0.0], [0.0, 0.0, 0.0, 0.4]]], distribution_util.tridiag( diag=[[4., 5., 6., 7.], [0.7, 0.6, 0.5, 0.4]]).eval(), rtol=1e-5, atol=0.)
def testWorksCorrectlyBatches(self): with self.test_session(): self.assertAllClose( [[[4., 8., 0., 0.], [1., 5., 9., 0.], [0., 2., 6., 10.], [0., 0., 3, 7.]], [[0.7, 0.1, 0.0, 0.0], [0.8, 0.6, 0.2, 0.0], [0.0, 0.9, 0.5, 0.3], [0.0, 0.0, 1.0, 0.4]]], distribution_util.tridiag( [[1., 2., 3.], [0.8, 0.9, 1.]], [[4., 5., 6., 7.], [0.7, 0.6, 0.5, 0.4]], [[8., 9., 10.], [0.1, 0.2, 0.3]]).eval(), rtol=1e-5, atol=0.)