Example #1
0
    def test_batch(self):
        """Tests batching.

        A lower block-triangular matrix (diagonal included) with 2 block rows
        of 2x2 blocks has 1 + 2 = 3 blocks of 4 entries, i.e. 12 inputs per
        example. With batch_size=2 the output holds one 4x4 matrix per
        example; in the expected values the second example's entries run
        from 12 to 23.
        """

        btm = block_matrix.BlockTriangularMatrix(block_shape=(2, 2),
                                                 block_rows=2,
                                                 upper=False)
        # Same structural checks the non-batched tests perform, so a wrong
        # block count is reported directly rather than as a value mismatch.
        self.assertEqual(btm.num_blocks, 3)
        self.assertEqual(btm.block_size, 4)
        self.assertEqual(btm.input_size, 12)

        # Size the input from the module instead of a magic constant.
        output = btm(create_input(btm.input_size, batch_size=2))
        with self.test_session() as sess:
            result = sess.run(output)

        self._check_output_size(btm, result, batch_size=2)

        expected = np.array([
            # Example 0.
            [[0, 1, 0, 0],
             [2, 3, 0, 0],
             [4, 5, 6, 7],
             [8, 9, 10, 11]],
            # Example 1.
            [[12, 13, 0, 0],
             [14, 15, 0, 0],
             [16, 17, 18, 19],
             [20, 21, 22, 23]],
        ])
        self.assertAllEqual(result, expected)
Example #2
0
  def test_upper(self):
    """Checks the layout of an upper block-triangular matrix.

    Three block rows of 2x3 blocks with the diagonal included give
    3 + 2 + 1 = 6 blocks of 6 entries each, i.e. 36 inputs. In the expected
    matrix, block row k is populated from column block k onwards, and the
    values run row-major across each block row.
    """

    matrix_module = block_matrix.BlockTriangularMatrix(
        block_shape=(2, 3), block_rows=3, upper=True)
    self.assertEqual(matrix_module.num_blocks, 6)
    self.assertEqual(matrix_module.block_size, 6)
    self.assertEqual(matrix_module.input_size, 36)

    matrix_tensor = matrix_module(create_input(matrix_module.input_size))
    with self.test_session() as session:
      actual = session.run(matrix_tensor)

    self._check_output_size(matrix_module, actual)

    # A single 6x9 matrix wrapped in an outer dimension of size 1.
    expected = np.array([[
        # Block row 0: column blocks 0, 1 and 2.
        [0, 1, 2, 3, 4, 5, 6, 7, 8],
        [9, 10, 11, 12, 13, 14, 15, 16, 17],
        # Block row 1: column blocks 1 and 2.
        [0, 0, 0, 18, 19, 20, 21, 22, 23],
        [0, 0, 0, 24, 25, 26, 27, 28, 29],
        # Block row 2: column block 2 only.
        [0, 0, 0, 0, 0, 0, 30, 31, 32],
        [0, 0, 0, 0, 0, 0, 33, 34, 35],
    ]])
    self.assertAllEqual(actual, expected)
Example #3
0
  def test_lower_no_diagonal(self):
    """Checks a lower block-triangular matrix that excludes the diagonal.

    Dropping the diagonal leaves three block rows of 2x3 blocks with only
    0 + 1 + 2 = 3 blocks (18 inputs), so the first block row is all zeros.
    """

    matrix_module = block_matrix.BlockTriangularMatrix(
        block_shape=(2, 3), block_rows=3, include_diagonal=False)
    self.assertEqual(matrix_module.num_blocks, 3)
    self.assertEqual(matrix_module.block_size, 6)
    self.assertEqual(matrix_module.input_size, 18)

    matrix_tensor = matrix_module(create_input(matrix_module.input_size))
    with self.test_session() as session:
      actual = session.run(matrix_tensor)

    self._check_output_size(matrix_module, actual)

    # A single 6x9 matrix wrapped in an outer dimension of size 1.
    expected = np.array([[
        # Block row 0: no blocks strictly below the diagonal.
        [0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0],
        # Block row 1: column block 0.
        [0, 1, 2, 0, 0, 0, 0, 0, 0],
        [3, 4, 5, 0, 0, 0, 0, 0, 0],
        # Block row 2: column blocks 0 and 1.
        [6, 7, 8, 9, 10, 11, 0, 0, 0],
        [12, 13, 14, 15, 16, 17, 0, 0, 0],
    ]])
    self.assertAllEqual(actual, expected)