def test_merge(self):
     """Check that ``merge`` moves every example from one dataset to another.

     Two datasets are created in the temporary directory and filled with
     ``self.data_raw`` and ``self.data_raw2``. After merging the second
     into the first, the test expects the first to hold both payloads in
     order and the second to be empty.
     """
     file1 = os.path.join(self.tmpdir, "mergeA.bin")
     file2 = os.path.join(self.tmpdir, "mergeB.bin")
     dataset1 = BinaryDs(file1, features=14).open()
     # try/finally guarantees the datasets are closed even when an
     # assertion fails, so a failing test does not leak open handles.
     try:
         dataset1.write(self.data_raw)
         dataset2 = BinaryDs(file2, features=14).open()
         try:
             dataset2.write(self.data_raw2)
             self.assertEqual(dataset1.get_examples_no(), 3)
             self.assertEqual(dataset2.get_examples_no(), 8)
             dataset1.merge(dataset2)
             self.assertEqual(dataset1.get_examples_no(), 11)
             self.assertEqual(dataset2.get_examples_no(), 0)
             self.assertEqual(dataset1.read(0, 11),
                              self.data_raw + self.data_raw2)
         finally:
             dataset2.close()
     finally:
         dataset1.close()
 def test_split(self):
     """Check that ``split`` moves part of a dataset into another one.

     A dataset filled with ``self.data_raw2`` is split with ratio 0.5
     into an initially empty dataset. The test expects each dataset to
     end up with four examples: the first four stay in the source and
     the last four land in the destination.
     """
     file1 = os.path.join(self.tmpdir, "splitA.bin")
     file2 = os.path.join(self.tmpdir, "splitB.bin")
     dataset1 = BinaryDs(file1, features=14).open()
     # try/finally guarantees the datasets are closed even when an
     # assertion fails, so a failing test does not leak open handles.
     try:
         dataset1.write(self.data_raw2)
         dataset2 = BinaryDs(file2, features=14).open()
         try:
             self.assertEqual(dataset1.get_examples_no(), 8)
             self.assertEqual(dataset2.get_examples_no(), 0)
             dataset1.split(dataset2, 0.5)
             self.assertEqual(dataset1.get_examples_no(), 4)
             self.assertEqual(dataset2.get_examples_no(), 4)
             self.assertEqual(dataset1.read(0, 4), self.data_raw2[:4])
             self.assertEqual(dataset2.read(0, 4), self.data_raw2[4:])
         finally:
             dataset2.close()
     finally:
         dataset1.close()
# Example no. 3
0
def count_categories(dataset: BinaryDs, chunk_size: int = 1000) -> List[int]:
    """Count how many examples of ``dataset`` belong to each category.

    The dataset is scanned in chunks of at most ``chunk_size`` examples.
    The first element of every example returned by ``dataset.read`` is
    used as its category index.

    :param dataset: Dataset to scan. It must provide ``get_examples_no()``,
        ``read(start, amount)`` and ``get_categories()``.
    :param chunk_size: Maximum number of examples requested per ``read``
        call. Defaults to 1000.
    :return: A list whose element at index ``c`` is the number of examples
        with category ``c``.
    :raises ValueError: If ``chunk_size`` is not positive.
    :raises AssertionError: If the number of counted categories differs
        from ``dataset.get_categories()``.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer")
    examples = dataset.get_examples_no()
    categories: List[int] = []
    # Integer stepping avoids the float round-trip of int(examples / amount)
    # and handles full chunks and the final partial chunk in a single loop.
    for start in range(0, examples, chunk_size):
        # The last chunk may contain fewer than chunk_size examples.
        count = min(chunk_size, examples - start)
        for val in dataset.read(start, count):
            category = val[0]
            # Grow the counters so that index `category` exists.
            while len(categories) <= category:
                categories.append(0)
            categories[category] += 1
    assert len(categories) == dataset.get_categories()
    return categories