Exemplo n.º 1
0
  def test_properties(self):
    """Checks the size and shape accessors of a BlockDiagonalMatrix.

    Uses seven 3x5 blocks. 7 blocks of 15 entries each account for the 105
    inputs, and the dense matrix is (7 * 3) x (7 * 5) = 21 x 35.
    """
    matrix = block_matrix.BlockDiagonalMatrix(block_shape=(3, 5), block_rows=7)

    # (attribute name, expected value), checked in this order.
    expected_attrs = (
        ('num_blocks', 7),
        ('block_size', 15),
        ('input_size', 105),
        ('output_shape', (21, 35)),
        ('block_shape', (3, 5)),
    )
    for attr_name, want in expected_attrs:
      self.assertEqual(getattr(matrix, attr_name), want)
Exemplo n.º 2
0
  def test_default(self):
    """Tests that BlockDiagonalMatrix lays its input out as diagonal blocks.

    Builds a matrix of three 2x3 blocks and feeds it ``create_input`` of the
    matching size. It then checks both the advertised size/shape attributes
    and the dense result against a hand-written block-diagonal array.
    """

    bdm = block_matrix.BlockDiagonalMatrix(block_shape=(2, 3), block_rows=3)
    self.assertEqual(bdm.num_blocks, 3)
    self.assertEqual(bdm.block_size, 6)
    self.assertEqual(bdm.input_size, 18)
    # Mirror test_properties so the advertised shapes are tied to the dense
    # layout checked below: 3 blocks of 2 rows x 3 cols -> 6 x 9.
    self.assertEqual(bdm.block_shape, (2, 3))
    self.assertEqual(bdm.output_shape, (6, 9))

    # create_input is defined elsewhere in this file. Judging by `expected`,
    # it presumably yields the values 0..input_size-1 in order -- confirm.
    output = bdm(create_input(bdm.input_size))
    # NOTE(review): newer TF releases deprecate test_session() in favour of
    # cached_session(); switch once the pinned TF version supports it.
    with self.test_session() as sess:
      result = sess.run(output)

    # The leading axis of length 1 is presumably a batch dimension coming
    # from create_input -- confirm against that helper.
    expected = np.array([[[0, 1, 2, 0, 0, 0, 0, 0, 0],
                          [3, 4, 5, 0, 0, 0, 0, 0, 0],
                          [0, 0, 0, 6, 7, 8, 0, 0, 0],
                          [0, 0, 0, 9, 10, 11, 0, 0, 0],
                          [0, 0, 0, 0, 0, 0, 12, 13, 14],
                          [0, 0, 0, 0, 0, 0, 15, 16, 17]]])
    self.assertAllEqual(result, expected)