class TestHMM(unittest.TestCase):
    """Regression tests for ``HMM`` inference methods.

    Each test runs one method on the model built in ``setUp`` and compares
    the result with a stored reference array (``predicted_probs``,
    ``updated_probs``, ``monitoring_probs``, ``postdict_probs``,
    ``backwards_probs``, ``hindsight_probs``).

    NOTE(review): those reference arrays and ``HMM`` are not defined in this
    class; presumably they are module-level names elsewhere in this file —
    confirm.
    """

    # Maximum root-mean-square error allowed between the normalized
    # model output and the normalized reference array.
    TOLERANCE = 1e-9

    def setUp(self):
        """Seed NumPy's global RNG, build the model and sample a path.

        The seed is set *before* the model is constructed. Any draws that
        ``HMM`` or ``generate_path`` make from NumPy's global RNG are
        therefore identical on every run. The stored reference arrays
        presumably rely on this.
        """
        np.random.seed(0)
        kwargs = {
            'width': 16,
            'length': 16,
            'rssi_range': 4,
            'n_beacons': 4,
            'init_pos': (0, 0),
        }
        self.model = HMM(**kwargs)
        self.T = 1 << 7  # Time steps (128)
        self.actual_path, self.observations = self.model.generate_path(self.T)

    def _assert_distributions_close(self, actual, expected, delta=None):
        """Assert that two probability arrays agree up to normalization.

        Each array is divided by its own sum. The root-mean-square
        difference of the two normalized arrays must then lie within
        ``delta`` of zero.

        Args:
            actual: Probabilities produced by the model under test.
            expected: Reference probabilities.
            delta: Allowed RMS error; defaults to ``self.TOLERANCE``.
        """
        if delta is None:
            delta = self.TOLERANCE
        actual = np.asarray(actual)
        expected = np.asarray(expected)
        # Normalizing makes the comparison scale-invariant: only relative
        # weights matter. An unnormalized output can therefore match a
        # normalized reference.
        diff = actual / np.sum(actual) - expected / np.sum(expected)
        err = np.sqrt(np.mean(diff ** 2))
        self.assertAlmostEqual(0, err, delta=delta)

    def test_predict(self):
        """Check ``predict`` applied to the initial distribution."""
        probs = self.model.predict(self.model.init_probs)
        self._assert_distributions_close(probs, predicted_probs)

    def test_update(self):
        """Check ``update`` of the initial distribution with observations[1]."""
        probs = self.model.update(self.model.init_probs, self.observations[1])
        self._assert_distributions_close(probs, updated_probs)

    def test_monitor(self):
        """Check the last element returned by ``monitor`` over T steps."""
        probs = self.model.monitor(self.T, self.observations)[-1]
        self._assert_distributions_close(probs, monitoring_probs)

    def test_postdict(self):
        """Check ``postdict`` applied to the initial distribution."""
        probs = self.model.postdict(self.model.init_probs)
        self._assert_distributions_close(probs, postdict_probs)

    def test_backwards(self):
        """Check the first element (index 0) returned by ``backwards``."""
        probs = self.model.backwards(self.T, self.observations)[0]
        self._assert_distributions_close(probs, backwards_probs)

    def test_hindsight(self):
        """Check the ``hindsight`` result at the middle step ``T // 2``."""
        probs = self.model.hindsight(self.T, self.observations)[self.T // 2]
        self._assert_distributions_close(probs, hindsight_probs)