def test_batch(self):
    """Tests batching of a block lower-triangular matrix.

    Uses 2x2 blocks arranged in 2 block rows, lower triangle including the
    diagonal. That gives 3 blocks of 4 entries, so each batch element
    consumes 12 inputs. The expected arrays below show create_input filling
    consecutive integers across the whole batch (0..11, then 12..23).
    """
    btm = block_matrix.BlockTriangularMatrix(
        block_shape=(2, 2), block_rows=2, upper=False)
    # Pin the derived sizes, as the sibling tests do. A change in block
    # accounting then fails here with a clear message, not later as a shape
    # mismatch when the graph runs.
    self.assertEqual(btm.num_blocks, 3)
    self.assertEqual(btm.block_size, 4)
    self.assertEqual(btm.input_size, 12)
    # Derive the input length from the module instead of hard-coding 12.
    output = btm(create_input(btm.input_size, batch_size=2))
    with self.test_session() as sess:
        result = sess.run(output)
    self._check_output_size(btm, result, batch_size=2)
    expected = np.array([[[0, 1, 0, 0],
                          [2, 3, 0, 0],
                          [4, 5, 6, 7],
                          [8, 9, 10, 11]],
                         [[12, 13, 0, 0],
                          [14, 15, 0, 0],
                          [16, 17, 18, 19],
                          [20, 21, 22, 23]]])
    self.assertAllEqual(result, expected)
def test_upper(self):
    """Checks a block upper-triangular matrix built from 2x3 blocks.

    There are 3 block rows, and the output is 6x9, i.e. 3 block columns.
    Only the blocks on or above the block diagonal are populated. As the
    expected array shows, inputs fill each block row in row-major order and
    everything below the block diagonal is zero.
    """
    zeros3 = [0, 0, 0]
    zeros6 = zeros3 + zeros3
    expected = np.array([[
        [0, 1, 2, 3, 4, 5, 6, 7, 8],
        [9, 10, 11, 12, 13, 14, 15, 16, 17],
        zeros3 + [18, 19, 20, 21, 22, 23],
        zeros3 + [24, 25, 26, 27, 28, 29],
        zeros6 + [30, 31, 32],
        zeros6 + [33, 34, 35],
    ]])

    triangular = block_matrix.BlockTriangularMatrix(
        block_shape=(2, 3), block_rows=3, upper=True)
    for attribute, wanted in (("num_blocks", 6),
                              ("block_size", 6),
                              ("input_size", 36)):
        self.assertEqual(getattr(triangular, attribute), wanted)

    matrix_tensor = triangular(create_input(triangular.input_size))
    with self.test_session() as session:
        actual = session.run(matrix_tensor)
    self._check_output_size(triangular, actual)
    self.assertAllEqual(actual, expected)
def test_lower_no_diagonal(self):
    """Checks a block lower-triangular matrix that excludes the diagonal.

    Uses 2x3 blocks in 3 block rows, and the output is 6x9. With the diagonal
    excluded, the first block row is entirely zero. The second block row
    holds one block, and the third holds two. As the expected array shows,
    inputs fill each block row in row-major order.
    """
    blank_row = [0] * 9
    expected = np.array([[
        blank_row,
        blank_row,
        [0, 1, 2] + [0] * 6,
        [3, 4, 5] + [0] * 6,
        [6, 7, 8, 9, 10, 11] + [0] * 3,
        [12, 13, 14, 15, 16, 17] + [0] * 3,
    ]])

    strict_lower = block_matrix.BlockTriangularMatrix(
        block_shape=(2, 3), block_rows=3, include_diagonal=False)
    for attribute, wanted in (("num_blocks", 3),
                              ("block_size", 6),
                              ("input_size", 18)):
        self.assertEqual(getattr(strict_lower, attribute), wanted)

    matrix_tensor = strict_lower(create_input(strict_lower.input_size))
    with self.test_session() as session:
        actual = session.run(matrix_tensor)
    self._check_output_size(strict_lower, actual)
    self.assertAllEqual(actual, expected)