def setUp(self):
    """Called before each test.

    NOTE(review): this function is defined at module level, not inside a
    ``unittest.TestCase`` subclass, so unittest never invokes it as a
    fixture. It is a verbatim copy of ``SVDMimoTestCase.setUp``, which looks
    like a stray paste; consider removing it once no caller depends on it.
    """
    self.svdmimo_object = SVDMimo()
class SVDMimoTestCase(unittest.TestCase):
    """Unit tests for the ``SVDMimo`` precoding scheme.

    Every test draws a fresh random channel with ``randn_c``. The checks
    therefore compare ``SVDMimo`` against a reference computed from that same
    channel, not against fixed numbers.
    """

    def setUp(self):
        """Called before each test."""
        self.svdmimo_object = SVDMimo()

    def _check_encode_against_svd(self, Nt, Nr):
        """Encode ``15 * Nt`` symbols over a random ``Nr x Nt`` channel and
        compare the result with the SVD precoder computed directly.

        The reference precoder is ``V / sqrt(Nt)``. ``V`` is the conjugate
        transpose of the ``V_H`` factor returned by ``np.linalg.svd``.

        Side effect: ``self.svdmimo_object`` is left configured with the
        random ``Nr x Nt`` channel drawn here.
        """
        data = np.r_[0:15 * Nt]
        # Reshape the flat input to Nt rows, so the expected output is
        # W (Nt x Nt) times an (Nt x 15) matrix.
        data_aux = data.reshape(Nt, -1)
        channel = randn_c(Nr, Nt)
        self.svdmimo_object.set_channel_matrix(channel)

        encoded_data = self.svdmimo_object.encode(data)

        _, _, V_H = np.linalg.svd(channel)
        W = V_H.conj().T / math.sqrt(Nt)
        expected_encoded_data = W.dot(data_aux)
        np.testing.assert_array_almost_equal(
            expected_encoded_data, encoded_data,
            err_msg='encode mismatch for Nt={0}, Nr={1}'.format(Nt, Nr))

    def test_encode(self):
        """encode() must apply the SVD precoder and reject bad input sizes."""
        # xxxxxxxxxx test the cases with Ntx=2 and Ntx=4 xxxxxxxxxxxxxxxxxxx
        for Nt, Nr in ((2, 2), (4, 4)):
            self._check_encode_against_svd(Nt, Nr)
        # xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx

        # xxxxx Test if an exception is raised for wrong size xxxxxxxxxxxxx
        # The exception is raised if the input array size is not a multiple
        # of the number of transmit antennas. The object still holds the
        # 4x4 channel set by the last case above, so Nt must be 4 here.
        Nt = 4
        data2 = np.r_[0:15 * Nt + 1]
        with self.assertRaises(ValueError):
            self.svdmimo_object.encode(data2)
        # xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx

    def test_decode(self):
        """decode() must recover the original symbols over a noiseless channel."""
        # xxxxxxxxxx test the case with Ntx=2, NRx=2 xxxxxxxxxxxxxxxxxxxxxx
        Nt = 2
        Nr = 2
        data = np.r_[0:15 * Nt]
        channel = randn_c(Nr, Nt)
        self.svdmimo_object.set_channel_matrix(channel)

        encoded_data = self.svdmimo_object.encode(data)
        # Noiseless transmission: the receiver sees only channel @ encoded.
        received_data = channel.dot(encoded_data)
        decoded_data = self.svdmimo_object.decode(received_data)
        np.testing.assert_array_almost_equal(data, decoded_data)
        # xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx

    def test_calc_post_processing_SINRs(self):
        """Post-processing SINRs must agree with a direct computation."""
        Nr = 3
        Nt = 3
        noise_var = 0.01
        channel = randn_c(Nr, Nt)
        self.svdmimo_object.set_channel_matrix(channel)

        W = self.svdmimo_object._calc_precoder(channel)
        G_H = self.svdmimo_object._calc_receive_filter(channel, noise_var)
        # Reference SINRs, converted to dB.
        expected_sinrs = linear2dB(calc_SINRs(channel, W, G_H, noise_var))

        # Calculate the SINR using the function in the mimo module. Note
        # that we need to pass the channel, the precoder, the receive
        # filter and the noise variance.
        sinrs = calc_post_processing_SINRs(channel, W, G_H, noise_var)
        np.testing.assert_array_almost_equal(sinrs, expected_sinrs, 2)

        # Calculate the SINR using method in the MIMO class. Note that we
        # only need to pass the noise variance, since the mimo object knows
        # the channel and it can calculate the precoder and receive filter.
        # NOTE(review): ``expected_sinrs`` is in dB (it goes through
        # linear2dB), but the name ``calc_linear_SINRs`` suggests a
        # linear-scale result. If so, this comparison mixes units; confirm
        # what calc_linear_SINRs returns.
        sinrs_other = self.svdmimo_object.calc_linear_SINRs(noise_var)
        np.testing.assert_array_almost_equal(sinrs_other, expected_sinrs, 2)