Esempio n. 1
0
    def test_upsample(self):
        """Exercise StratifiedSampler in "upsample" mode, then replay the ids.

        First pass: sample ``self.table`` with per-label fractions and check
        that every sampled row is keyed by a fresh integer in
        ``[0, len(sample_ids))``, that it carries the label/features of the
        original row named by ``sample_ids[key]``, and that the per-label
        counts land within 10 of ``250 * fraction``.

        Second pass: hand the same ``sample_ids`` to a new sampler running on
        ``self.table_trans`` and check that the result is keyed ``0..n-1`` and
        that each row carries the features of the matching original row.
        """
        fractions = [(0, 1.3), (1, 0.5), (2, 0.8), (3, 9)]
        sampler = StratifiedSampler(fractions=fractions, method="upsample")
        tracker = Tracking("jobid", "guest", 9999, "abc", "123")
        sampler.set_tracker(tracker)
        sample_data, sample_ids = sampler.sample(self.table)
        new_data = list(sample_data.collect())
        count_label = [0] * len(fractions)
        data_dict = dict(self.data)

        for new_id, inst in new_data:
            count_label[inst.label] += 1
            self.assertTrue(type(new_id).__name__ == 'int' and 0 <= new_id < len(sample_ids))
            # sample_ids maps the new positional key back to the original key.
            source_inst = data_dict[sample_ids[new_id]]
            self.assertEqual(inst.label, source_inst.label)
            self.assertEqual(inst.features, source_inst.features)

        # NOTE(review): 250 presumably is the number of rows per label in the
        # fixture built by setUp -- confirm against setUp.
        for label, fraction in fractions:
            self.assertLess(np.abs(count_label[label] - 250 * fraction), 10)

        trans_sampler = StratifiedSampler(method="upsample")
        trans_sampler.set_tracker(tracker)
        trans_sample_data = trans_sampler.sample(self.table_trans, sample_ids)
        # Materialize the result: it is traversed twice below, and a one-shot
        # iterator would leave the second loop with nothing to check.
        trans_data = list(trans_sample_data.collect())
        trans_sample_ids = [new_id for new_id, _ in trans_data]
        data_to_trans_dict = dict(self.data_to_trans)

        self.assertEqual(sorted(trans_sample_ids), list(range(len(sample_ids))))
        for new_id, inst in trans_data:
            real_id = sample_ids[new_id]
            # The dict values are the instances themselves, so no extra index.
            self.assertEqual(inst.features, data_to_trans_dict[real_id].features)
Esempio n. 2
0
    def test_downsample(self):
        """Exercise StratifiedSampler in "downsample" mode, then replay the ids.

        First pass: sample ``self.table`` with per-label fractions and check
        that the returned ids are all keys of ``self.data``, that every
        sampled row keeps its integer key and matches the original row at
        that key, and that per-label counts land within 10 of
        ``250 * fraction``.

        Second pass: hand the same ``sample_ids`` to a new sampler running on
        ``self.table_trans`` and check that exactly those keys come back with
        the features stored in ``self.data_to_trans``.
        """
        fractions = [(0, 0.3), (1, 0.4), (2, 0.5), (3, 0.8)]
        sampler = StratifiedSampler(fractions=fractions, method="downsample")
        tracker = Tracking("jobid", "guest", 9999, "abc", "123")
        sampler.set_tracker(tracker)
        sample_data, sample_ids = sampler.sample(self.table)
        per_label_count = [0] * 4
        sampled_rows = list(sample_data.collect())
        originals = dict(self.data)
        # Every sampled id must be a key of the source data.
        self.assertTrue(set(sample_ids).issubset(set(originals.keys())))

        for key, row in sampled_rows:
            per_label_count[row.label] += 1
            # NOTE(review): 1000 presumably is the fixture size -- confirm
            # against setUp.
            self.assertTrue(type(key).__name__ == 'int' and 0 <= key < 1000)
            expected = self.data[key][1]
            self.assertTrue(row.label == expected.label and row.features == expected.features)

        # NOTE(review): 250 presumably is the number of rows per label in the
        # fixture -- confirm against setUp.
        for observed, (_, fraction) in zip(per_label_count, fractions):
            self.assertTrue(np.abs(observed - 250 * fraction) < 10)

        replay_sampler = StratifiedSampler(method="downsample")
        replay_sampler.set_tracker(tracker)
        replayed = replay_sampler.sample(self.table_trans, sample_ids)
        replayed_rows = list(replayed.collect())
        replayed_keys = [key for key, _ in replayed_rows]
        trans_lookup = dict(self.data_to_trans)

        self.assertTrue(set(replayed_keys) == set(sample_ids))
        for key, row in replayed_rows:
            self.assertTrue(row.features == trans_lookup.get(key).features)