示例#1
0
    def test_image_resize(self):
        """Test ETLDataReader.read_dataset_file with different resize options.

        Each entry of ``size_in`` is passed as the ``resize`` argument when
        reading file 1 of ETL1, and the shape of the first returned image is
        compared with the matching entry of ``correct_out``. The last two
        cases contain a non-positive dimension and expect (63, 64, 1),
        presumably the unresized image shape.
        """

        print("started: test_image_resize")

        size_in = [(12, 12), (12, 37), (-1, 12), (35, 0)]

        correct_out = [(12, 12, 1), (12, 37, 1), (63, 64, 1), (63, 64, 1)]

        reader = ETLDataReader(os.path.join(os.getcwd(), "etl_data_set"))

        # Check every (resize, expected shape) pair, including the last one.
        for size, expected_shape in zip(size_in, correct_out):
            with self.subTest(resize=size):
                imgs, _ = reader.read_dataset_file(1,
                                                   ETLDataNames.ETL1,
                                                   [ETLCharacterGroups.all],
                                                   resize=size)
                # compare the shape of the first image of the file
                self.assertEqual(imgs[0].shape, expected_shape)

        print("finished: test_image_resize")
示例#2
0
    def test_read_dataset_file(self):
        """Test ETLDataReader.read_dataset_file on file 1 of every data set part.

        For every member of ETLDataNames (in iteration order), file 1 is read
        with all character groups included, and the first label is compared
        with the matching entry of ``correct_labels``.
        """

        print("started: test_read_dataset_file")

        reader = ETLDataReader(os.path.join(os.getcwd(), "etl_data_set"))

        labels = []

        for name in ETLDataNames:
            _, _labels = reader.read_dataset_file(1, name,
                                                  [ETLCharacterGroups.all])
            labels.append(_labels)

        correct_labels = [
            "0", "上", "0", "あ", "ア", "ア", "ア", "あ", "あ", "あ", "亜"
        ]
        # Every expected label must have a data set part to compare against.
        self.assertGreaterEqual(len(labels), len(correct_labels))
        for name, part_labels, expected in zip(ETLDataNames, labels,
                                               correct_labels):
            with self.subTest(name=name):
                # plain string comparison gives a readable failure message
                self.assertEqual(part_labels[0], expected)

        print("finished: test_read_dataset_file")
示例#3
0
    def test_read_dataset_whole_parallel(self):
        """Test the ETLDataReader.read_dataset_whole method in parallel mode.

        Reads the whole data set once without a process count (reported as
        the 1-process run) and once with ``mp.cpu_count()`` processes, prints
        timing statistics, and checks that both runs return the same number
        of labels.
        """

        print("started: test_read_dataset_whole_parallel")

        reader = ETLDataReader(os.path.join(os.getcwd(), "etl_data_set"))
        n_processes = mp.cpu_count()

        t_1_1 = time.perf_counter()
        _imgs_1, _labels_1 = reader.read_dataset_whole(
            [ETLCharacterGroups.all])
        t_1_2 = time.perf_counter()

        t_2_1 = time.perf_counter()
        _imgs_2, _labels_2 = reader.read_dataset_whole(
            [ETLCharacterGroups.all], n_processes)
        t_2_2 = time.perf_counter()

        time_1 = t_1_2 - t_1_1
        time_2 = t_2_2 - t_2_1
        speedup = time_1 / time_2

        print("running with 1 process in", time_1)
        print("running with", n_processes, "processes in", time_2)
        print("absolute difference:", time_1 - time_2)
        print("speedup:", speedup)
        # parallel efficiency is the speedup divided by the process count
        print("efficiency:", speedup / n_processes)

        self.assertEqual(len(_labels_2), len(_labels_1))

        print("finished: test_read_dataset_whole_parallel")
示例#4
0
    def test_read_dataset_part(self):
        """Check ETLDataReader.read_dataset_part on the ETL1 part.

        Every character group of ETL1 is loaded; 141251 labels are expected.
        """
        print("started: test_read_dataset_part")

        data_root = os.path.join(os.getcwd(), "etl_data_set")
        etl_reader = ETLDataReader(data_root)

        groups = [ETLCharacterGroups.all]
        _, part_labels = etl_reader.read_dataset_part(ETLDataNames.ETL1,
                                                      groups)

        self.assertEqual(len(part_labels), 141251)

        print("finished: test_read_dataset_part")
示例#5
0
    def test_read_dataset_selection(self):
        """Test the ETLDataReader.read_dataset_file method with filtering.

        Each case reads one file of one data set part, optionally restricted
        to a list of character groups. It checks that the number of images and
        the number of labels both equal the expected count. An include list
        of ``None`` means the argument is omitted, which exercises the implicit
        default filter.
        """

        print("started: test_read_dataset_selection")

        reader = ETLDataReader(os.path.join(os.getcwd(), "etl_data_set"))

        # (description, file number, data set part, include list, expected count)
        cases = [
            ("number filter", 1, ETLDataNames.ETL1,
             [ETLCharacterGroups.number], 11530),
            ("all filter with mixed data set file", 1, ETLDataNames.ETL1,
             [ETLCharacterGroups.all], 11530),
            ("roman letter filter", 3, ETLDataNames.ETL1,
             [ETLCharacterGroups.roman], 11558),
            ("symbol filter", 6, ETLDataNames.ETL1,
             [ETLCharacterGroups.symbols], 11554),
            ("kanji filter", 1, ETLDataNames.ETL8G,
             [ETLCharacterGroups.kanji], 4405),
            ("hiragana filter", 1, ETLDataNames.ETL4,
             [ETLCharacterGroups.hiragana], 6120),
            ("katakana filter", 1, ETLDataNames.ETL5,
             [ETLCharacterGroups.katakana], 10608),
            ("implicit all filter with mixed data set file", 5,
             ETLDataNames.ETL1, None, 11545),
        ]

        for description, file_number, part, include, expected in cases:
            with self.subTest(description):
                if include is None:
                    # call without an include list to test the default
                    _imgs, _labels = reader.read_dataset_file(file_number,
                                                              part)
                else:
                    _imgs, _labels = reader.read_dataset_file(
                        file_number, part, include)
                self.assertEqual(len(_labels), expected)
                self.assertEqual(len(_imgs), expected)

        print("finished: test_read_dataset_selection")
示例#6
0
    def test_image_normalizing(self):
        """Check that read_dataset_file(normalize=True) bounds pixel values.

        Only the first returned image is inspected: its maximum value must
        not exceed 1.0.
        """
        print("started: test_image_normalizing")

        data_root = os.path.join(os.getcwd(), "etl_data_set")
        etl_reader = ETLDataReader(data_root)

        normalized_imgs, _ = etl_reader.read_dataset_file(
            1, ETLDataNames.ETL1, [ETLCharacterGroups.all], normalize=True)

        first_max = normalized_imgs[0].max()
        self.assertTrue(first_max <= 1.0)

        print("finished: test_image_normalizing")
示例#7
0
def load_the_whole_data_set_parallel(reader: ETLDataReader, processes: int = 16):
    """The third example of the README: load the whole data set in parallel.

    Args:
        reader : ETLDataReader instance used to load the data set.
        processes : number of processes passed to
            ``reader.read_dataset_whole`` (16 by default, as in the README).

    Returns:
        The ``(imgs, labels)`` pair returned by ``reader.read_dataset_whole``
        for the roman and symbol character groups.
    """

    from etldr.etl_character_groups import ETLCharacterGroups

    include = [ETLCharacterGroups.roman, ETLCharacterGroups.symbols]

    imgs, labels = reader.read_dataset_whole(include, processes)
    return imgs, labels
示例#8
0
def load_one_data_set_file(reader: ETLDataReader):
    """The first example of the README: load a single data set file.

    Reads file 2 of the ETL7 part, keeping only katakana and number
    characters.

    Args:
        reader : ETLDataReader instance used to load the data set file.

    Returns:
        The ``(imgs, labels)`` pair returned by ``reader.read_dataset_file``.
    """

    from etldr.etl_data_names import ETLDataNames
    from etldr.etl_character_groups import ETLCharacterGroups

    include = [ETLCharacterGroups.katakana, ETLCharacterGroups.number]

    imgs, labels = reader.read_dataset_file(2, ETLDataNames.ETL7, include)
    return imgs, labels
示例#9
0
def load_one_data_set_part_parallel(reader: ETLDataReader, processes: int = 16):
    """The second example of the README: load one data set part in parallel.

    Reads the ETL2 part, keeping only kanji and hiragana characters.

    Args:
        reader : ETLDataReader instance used to load the data set part.
        processes : number of processes passed to
            ``reader.read_dataset_part`` (16 by default, as in the README).

    Returns:
        The ``(imgs, labels)`` pair returned by ``reader.read_dataset_part``.
    """

    from etldr.etl_data_names import ETLDataNames
    from etldr.etl_character_groups import ETLCharacterGroups

    include = [ETLCharacterGroups.kanji, ETLCharacterGroups.hiragana]

    imgs, labels = reader.read_dataset_part(ETLDataNames.ETL2, include,
                                            processes)
    return imgs, labels
示例#10
0
    imgs, labels = reader.read_dataset_whole(include)


def load_the_whole_data_set_parallel(reader: ETLDataReader, processes: int = 16):
    """The third example of the README: load the whole data set in parallel.

    Args:
        reader : ETLDataReader instance used to load the data set.
        processes : number of processes passed to
            ``reader.read_dataset_whole`` (16 by default, as in the README).

    Returns:
        The ``(imgs, labels)`` pair returned by ``reader.read_dataset_whole``
        for the roman and symbol character groups.
    """

    from etldr.etl_character_groups import ETLCharacterGroups

    include = [ETLCharacterGroups.roman, ETLCharacterGroups.symbols]

    imgs, labels = reader.read_dataset_whole(include, processes)
    return imgs, labels



if __name__ == "__main__":
    import sys

    # The data set directory may be given as the first command-line
    # argument; otherwise the default path below is used.
    path_to_data_set = (sys.argv[1] if len(sys.argv) > 1
                        else r"F:\data_sets\ETL_kanji")

    reader = ETLDataReader(path_to_data_set)

    # uncomment one of these examples
    #load_one_data_set_file(reader)
    #load_one_data_set_part(reader)
    #load_the_whole_data_set(reader)
    #load_one_data_set_part_parallel(reader)
    #load_the_whole_data_set_parallel(reader)