class Benchmark2PeriodRadarsIMM2TestCase(Benchmark2PeriodRadarsTestEnv,
                                         unittest.TestCase):
    """Benchmark tests for a 2-model IMM over multiple periodic radars.

    The IMM mixes a ``MultiplePeriodRadarsFilterCV`` and a
    ``MultiplePeriodRadarsFilterCA``, both built on ``self.radars``. The
    radars, ``self.states`` and ``self.len_elements`` come from the
    ``Benchmark2PeriodRadarsTestEnv`` mixin (via ``setUp_radar_states``).
    """

    def setUp(self):
        # Radar & states generation (defined by the test-environment mixin).
        self.setUp_radar_states()

        # Filter definition
        ## Classical models
        self.radar_filter_ca = MultiplePeriodRadarsFilterCA(dim_x=9,
                                                            dim_z=3,
                                                            q=100.,
                                                            radars=self.radars)
        self.radar_filter_cv = MultiplePeriodRadarsFilterCV(dim_x=9,
                                                            dim_z=3,
                                                            q=100.,
                                                            radars=self.radars)
        ## IMM with cv and ca models.
        # NOTE(review): filter order is assumed to match the row/column order
        # of ``trans`` below — confirm against RadarIMM.
        filters = [self.radar_filter_cv, self.radar_filter_ca]
        # Equal initial model probabilities.
        mu = [0.5, 0.5]
        # Markov transition matrix: each model stays with probability 0.999;
        # every row sums to 1.
        trans = np.array([[0.999, 0.001],
                          [0.001, 0.999]])
        self.radar_filter = RadarIMM(filters=filters, mu=mu, M=trans)
        # Number of IMM models, i.e. the expected width of ``benchmark.probs``.
        self.n_models = len(filters)

        # Benchmark definition
        self.benchmark = Benchmark(radars=self.radars,
                                   radar_filter=self.radar_filter,
                                   states=self.states)

    def test_initialization_is_imm(self):
        """The benchmark reports that its filter is an IMM."""
        self.assertTrue(self.benchmark.filter_is_imm)

    def test_process_filter_computes_probs(self):
        """``probs`` has shape (len_elements, number of IMM models)."""
        self.benchmark.gen_data_set()
        self.benchmark.process_filter(with_nees=True)
        self.assertEqual(np.shape(self.benchmark.probs),
                         (self.len_elements, self.n_models))


class Benchmark1RadarIMM4TestCase(Benchmark1RadarTestEnv, unittest.TestCase):
    """Benchmark tests for a 4-model IMM over a single radar.

    The IMM mixes ``RadarFilterCV``, ``RadarFilterCA``, ``RadarFilterCT`` and
    ``RadarFilterTA``, all built on ``self.radar``. The radar and
    ``self.states`` come from the ``Benchmark1RadarTestEnv`` mixin (via
    ``setUp_radar_states``).
    """

    def setUp(self):
        # Radar & states generation (defined by the test-environment mixin).
        self.setUp_radar_states()

        # Filter definition
        ## Classical models
        self.radar_filter_ca = RadarFilterCA(dim_x=9, dim_z=3, q=100.,
                                             radar=self.radar)
        self.radar_filter_cv = RadarFilterCV(dim_x=9, dim_z=3, q=100.,
                                             radar=self.radar)
        self.radar_filter_ct = RadarFilterCT(dim_x=9, dim_z=3, q=100.,
                                             radar=self.radar)
        self.radar_filter_ta = RadarFilterTA(dim_x=9, dim_z=3, q=100.,
                                             radar=self.radar)
        ## IMM with cv, ca, ct and ta models.
        # NOTE(review): filter order is assumed to match the row/column order
        # of ``trans`` below — confirm against RadarIMM.
        filters = [self.radar_filter_cv, self.radar_filter_ca,
                   self.radar_filter_ct, self.radar_filter_ta]
        # Equal initial model probabilities.
        mu = [0.25, 0.25, 0.25, 0.25]
        # Markov transition matrix; every row sums to 1. The CA row (index 1)
        # is less persistent (0.85 self-transition) than the others (0.997).
        trans = np.array([[0.997, 0.001, 0.001, 0.001],
                          [0.050, 0.850, 0.050, 0.050],
                          [0.001, 0.001, 0.997, 0.001],
                          [0.001, 0.001, 0.001, 0.997]])
        self.radar_filter = RadarIMM(filters=filters, mu=mu, M=trans)
        # Number of IMM models, i.e. the expected width of ``benchmark.probs``.
        self.n_models = len(filters)

        # Benchmark definitions
        # NOTE(review): a single radar is passed to the ``radars`` argument;
        # confirm Benchmark accepts a bare radar as well as a collection.
        self.benchmark = Benchmark(radars=self.radar,
                                   radar_filter=self.radar_filter,
                                   states=self.states)

    def test_initialization_is_imm(self):
        """The benchmark reports that its filter is an IMM."""
        self.assertTrue(self.benchmark.filter_is_imm)

    def test_process_filter_computes_probs(self):
        """``probs`` has shape (100, number of IMM models)."""
        self.benchmark.gen_data_set()
        self.benchmark.process_filter(with_nees=True)
        # NOTE(review): 100 is hard-coded; presumably the number of elements
        # generated by Benchmark1RadarTestEnv — confirm, or use that env's
        # element count if it exposes one.
        self.assertEqual(np.shape(self.benchmark.probs),
                         (100, self.n_models))