def create_datasets(self, db_name, dir_prefix, train_percent=0.6,
                    validation_percent=0.2, test_percent=0.2):
    """
    Split a tegaki character database into train, validation and test
    datasets and build each of them.

    Characters are read from ``unipen_db/<db_name>.chardb``. One random
    character generator is created and handed to each ``_create_dataset``
    call in turn (presumably each call consumes its share of samples).

    @param db_name: database name, without directory and extension
    @param dir_prefix: prefix of the generated ``.nc`` file names
    @param train_percent: fraction of the database used for training
    @param validation_percent: fraction used for validation (0 skips it)
    @param test_percent: fraction used for testing (0 skips it)
    @precondition(train_percent + validation_percent + test_percent == 1.0)
    """
    def _percent_tag(fraction):
        # round() instead of truncation: int(0.29 * 100) == 28
        return str(int(round(fraction * 100)))

    db_file = "unipen_db/" + db_name + ".chardb"
    charcol = CharacterCollection(db_file)
    num_chars = charcol.get_total_n_characters()
    print("total chars %s" % num_chars)
    chars = charcol.get_random_characters_gen(num_chars)

    train_size = int(num_chars * train_percent)
    validation_size = int(num_chars * validation_percent)

    total_percent = train_percent + validation_percent + test_percent
    # Compare with a tolerance: in floating point 0.7 + 0.2 + 0.1 gives
    # 0.9999999999999999, which an exact == 1.0 test would reject.
    if abs(total_percent - 1.0) < 1e-9:
        # all the db is used: the test set also absorbs the rounding remainder
        test_size = num_chars - train_size - validation_size
    else:
        # only a fraction of the db is used
        test_size = int(num_chars * test_percent)

    print("train set size: %s" % train_size)
    self._create_dataset(chars, train_size,
                         dir_prefix + '_train_' + _percent_tag(train_percent) + '.nc')

    print("validation set size: %s" % validation_size)
    if validation_percent != 0.0:
        self._create_dataset(chars, validation_size,
                             dir_prefix + '_validation_' + _percent_tag(validation_percent) + '.nc')

    print("test set size: %s" % test_size)
    if test_percent != 0.0:
        self._create_dataset(chars, test_size,
                             dir_prefix + '_test_' + _percent_tag(test_percent) + '.nc')
def testValidate(self):
    """
    CharacterCollection.validate accepts the reference collection and
    rejects an invalid and a malformed document.
    """
    path = os.path.join(self.currdir, "data", "collection", "test.charcol")
    # `with` closes the file even if read() raises
    with open(path) as f:
        buf = f.read()

    invalid = """ <?xml version="1.0" encoding="UTF-8"?> <character> <utf8>防</utf8> <strokes> <stroke> </stroke> </strokes> </character> """

    malformed = """ <?xml version="1.0" encoding="UTF-8"?> <character> """

    try:
        self.assertTrue(CharacterCollection.validate(buf))
        self.assertFalse(CharacterCollection.validate(invalid))
        self.assertFalse(CharacterCollection.validate(malformed))
    except NotImplementedError:
        # presumably raised by validate() when lxml is unavailable; the test
        # then passes with a warning instead of failing
        sys.stderr.write("lxml missing!\n")
def setUp(self):
    """Load the fixtures: the reference collection and one stand-alone character."""
    self.currdir = os.path.dirname(os.path.abspath(__file__))
    data_dir = os.path.join(self.currdir, "data")

    self.cc = CharacterCollection()
    self.cc.read(os.path.join(data_dir, "collection", "test.charcol"))

    self.c = Character()
    self.c.read(os.path.join(data_dir, "character.xml"))
class KVGXmlDictionaryReader(_XmlBase):
    """
    Parser callbacks turning an XML dictionary rooted at <kanjis> into a
    CharacterCollection.

    Each <kanji midashi="..."> element becomes one Character whose set name
    and utf8 value is the ``midashi`` attribute. Each nested <stroke>
    element contributes one Stroke built from the SVG path data in its
    ``path`` attribute.
    """

    def __init__(self):
        # NOTE(review): _first_tag is read in _start_element but not set
        # here; presumably _XmlBase initialises it before parsing -- confirm.
        self._charcol = CharacterCollection()

    def get_character_collection(self):
        """Return the CharacterCollection filled by the parse callbacks."""
        return self._charcol

    def _start_element(self, name, attrs):
        self._tag = name

        if self._first_tag:
            self._first_tag = False
            if self._tag != "kanjis":
                raise ValueError("The very first tag should be <kanjis>")

        if self._tag == "kanji":
            self._writing = Writing()
            self._utf8 = attrs["midashi"].encode("UTF-8")

        if self._tag == "stroke":
            self._stroke = Stroke()
            if "path" in attrs:
                self._stroke_svg = attrs["path"].encode("UTF-8")
                try:
                    svg_parser = SVG_Parser(self._stroke_svg)
                    svg_parser.parse()
                    self._stroke.append_points(svg_parser.get_points())
                except Exception:
                    # Best effort: a broken path may lose some or all of this
                    # stroke's points (the stroke is still appended on
                    # </stroke>). Catching Exception rather than using a bare
                    # except keeps KeyboardInterrupt/SystemExit working.
                    sys.stderr.write("Something went wrong in this character: " + self._utf8 + "\n")
            else:
                # diagnostics go to stderr, like the parse error above
                sys.stderr.write("Missing path in <stroke> element: " + self._utf8 + "\n")

    def _end_element(self, name):
        if name == "kanji":
            char = Character()
            char.set_utf8(self._utf8)
            char.set_writing(self._writing)
            self._charcol.add_set(self._utf8)
            self._charcol.append_character(self._utf8, char)
            # drop the per-kanji parser state
            for s in ["_tag", "_stroke"]:
                if s in self.__dict__:
                    del self.__dict__[s]

        if name == "stroke":
            self._writing.append_stroke(self._stroke)
            self._stroke = None

        self._tag = None

    def _char_data(self, data):
        # NOTE(review): expat may deliver one text node in several chunks
        # unless buffer_text is enabled; each chunk overwrites the previous
        # value here -- confirm _XmlBase enables buffering.
        if self._tag == "utf8":
            self._utf8 = data.encode("UTF-8")
        elif self._tag == "width":
            self._writing.set_width(int(data))
        elif self._tag == "height":
            self._writing.set_height(int(data))
class KVGXmlDictionaryReader(_XmlBase):
    """
    Parser callbacks turning a KanjiVG document (root element <kanjivg>)
    into a CharacterCollection.

    A <kanji id="..."> element becomes one Character. The second
    '_'-separated field of its ``id``, read as hexadecimal, is the code
    point. Every <path d="..."> inside it contributes one Stroke parsed from
    the SVG path data.
    """

    def __init__(self):
        self._charcol = CharacterCollection()

    def get_character_collection(self):
        """Return the collection accumulated so far."""
        return self._charcol

    def _start_element(self, name, attrs):
        self._tag = name

        if self._first_tag:
            self._first_tag = False
            if name != "kanjivg":
                raise ValueError("The very first tag should be <kanjivg>")

        if name == "kanji":
            self._writing = Writing()
            code_point = int(attrs["id"].split('_')[1], 16)
            self._utf8 = unichr(code_point).encode("UTF-8")
        elif name == "path":
            self._stroke = Stroke()
            if "d" not in attrs:
                sys.stderr.write("Missing data in <path> element: " + self._utf8 + "\n")
                return
            self._stroke_svg = attrs["d"].encode("UTF-8")
            parser = SVG_Parser(self._stroke_svg)
            parser.parse()
            self._stroke.append_points(parser.get_points())

    def _end_element(self, name):
        if name == "kanji":
            character = Character()
            character.set_utf8(self._utf8)
            character.set_writing(self._writing)
            self._charcol.add_set(self._utf8)
            self._charcol.append_character(self._utf8, character)
            # forget the per-kanji state
            for attr_name in ("_tag", "_stroke"):
                self.__dict__.pop(attr_name, None)
        elif name == "path":
            self._writing.append_stroke(self._stroke)
            self._stroke = None
        self._tag = None

    def _char_data(self, data):
        tag = self._tag
        if tag == "utf8":
            self._utf8 = data.encode("UTF-8")
        elif tag == "width":
            self._writing.set_width(int(data))
        elif tag == "height":
            self._writing.set_height(int(data))
def changeDatabase(self):
    """
    Ask the user for a tegaki ``.chardb`` database and start drawing random
    characters from it.

    If the dialog is cancelled, or the chosen file does not have the
    ``.chardb`` extension, the random character generator is cleared
    (``self.char_gen`` becomes None).
    """
    # NOTE(review): str() on a QString can fail for non-ASCII paths under
    # Python 2 -- consider unicode() if such paths must be supported.
    db_file = str(QtGui.QFileDialog.getOpenFileName(
        self, "Open database", QtCore.QDir.currentPath()))

    if db_file and os.path.splitext(db_file)[1] == '.chardb':
        charcol = CharacterCollection(db_file)
        # query the total once; it is both reported and used to size the generator
        n_chars = charcol.get_total_n_characters()
        print("chars in db: %s" % n_chars)
        self.char_gen = charcol.get_random_characters_gen(n_chars)
        self.random()
    else:
        self.char_gen = None
class KVGXmlDictionaryReader(_XmlBase):
    """
    Parser callbacks building a CharacterCollection from a KanjiVG file
    whose root element is <kanjivg>.

    <kanji id="..."> elements become Characters; the code point is the
    hexadecimal second '_'-separated field of ``id``. Each <path d="...">
    element is parsed as SVG path data into one Stroke.
    """

    def __init__(self):
        self._charcol = CharacterCollection()

    def get_character_collection(self):
        """Return the CharacterCollection built by the callbacks."""
        return self._charcol

    def _start_element(self, name, attrs):
        self._tag = name

        if self._first_tag:
            self._first_tag = False
            if self._tag != "kanjivg":
                raise ValueError("The very first tag should be <kanjivg>")

        if name == "kanji":
            self._writing = Writing()
            hex_field = attrs["id"].split("_")[1]
            self._utf8 = unichr(int(hex_field, 16)).encode("UTF-8")
        elif name == "path":
            self._stroke = Stroke()
            if "d" in attrs:
                self._stroke_svg = attrs["d"].encode("UTF-8")
                svg = SVG_Parser(self._stroke_svg)
                svg.parse()
                self._stroke.append_points(svg.get_points())
            else:
                sys.stderr.write("Missing data in <path> element: " + self._utf8 + "\n")

    def _end_element(self, name):
        if name == "path":
            self._writing.append_stroke(self._stroke)
            self._stroke = None
        elif name == "kanji":
            kanji = Character()
            kanji.set_utf8(self._utf8)
            kanji.set_writing(self._writing)
            self._charcol.add_set(self._utf8)
            self._charcol.append_character(self._utf8, kanji)
            # clear per-kanji state before the next <kanji>
            for key in ("_tag", "_stroke"):
                if key in self.__dict__:
                    del self.__dict__[key]
        self._tag = None

    def _char_data(self, data):
        current = self._tag
        if current == "utf8":
            self._utf8 = data.encode("UTF-8")
        elif current == "width":
            self._writing.set_width(int(data))
        elif current == "height":
            self._writing.set_height(int(data))
def _get_charcol(charcol_type, charcol_path):
    """
    Load one character collection.

    charcol_type: TYPE_DIRECTORY, TYPE_CHARCOL, TYPE_CHARCOL_DB, TYPE_TOMOE
                  or TYPE_KUCHIBUE
    charcol_path: path of the collection (a directory for TYPE_DIRECTORY)

    Raises ValueError for any other type instead of implicitly returning
    None.
    """
    if charcol_type == TYPE_DIRECTORY:
        # charcol_path is actually a directory here
        return CharacterCollection.from_character_directory(charcol_path)
    elif charcol_type in (TYPE_CHARCOL, TYPE_CHARCOL_DB):
        return CharacterCollection(charcol_path)
    elif charcol_type == TYPE_TOMOE:
        return tomoe_dict_to_character_collection(charcol_path)
    elif charcol_type == TYPE_KUCHIBUE:
        return kuchibue_to_character_collection(charcol_path)

    raise ValueError("Unknown character collection type: %r" % (charcol_type,))
def testAddSame(self):
    """
    Concatenating a collection with a copy of itself, using
    check_duplicate=True, does not add any duplicate samples.
    """
    path = os.path.join(self.currdir, "data", "collection", "test.charcol")
    charcol = CharacterCollection()
    charcol.read(path)
    charcol2 = CharacterCollection()
    charcol2.read(path)

    charcol3 = charcol.concatenate(charcol2, check_duplicate=True)

    # assertEqual: the assertEquals alias is deprecated (removed in Python 3.12)
    self.assertEqual(charcol3.get_set_list(), ["一", "三", "二", "四"])
    self.assertEqual(len(charcol3.get_characters("一")), 3)
    self.assertEqual(len(charcol3.get_characters("三")), 2)
    self.assertEqual(len(charcol3.get_characters("二")), 1)
    self.assertEqual(len(charcol3.get_characters("四")), 0)
# Candidate selections of UNIPEN databases; group_best() merges BEST_1C by default.
BEST_1A = ('aga', 'apa', 'apb', 'app', 'art', 'ced', 'gmd', 'ibm', 'ipm',
           'pri', 'syn', 'uqb', 'val')
BEST_1B = ('aga', 'apa', 'apb', 'app', 'art', 'ced', 'ibm', 'kai', 'lou',
           'pri', 'syn', 'tos', 'upb', 'val', 'cea', 'ceb')
BEST_1C = ('aga', 'apa', 'apb', 'app', 'art', 'cea', 'ceb', 'ced', 'gmd',
           'ibm', 'imp', 'kai', 'lav', 'lou', 'mot', 'pri', 'sie', 'syn',
           'tos', 'val')


def group_best(db_names=None, output_path="unipen_db/best_1c.chardb",
               db_dir="unipen_db"):
    """
    Merge a selection of UNIPEN tegaki databases into one database.

    db_names: names (no directory, no extension) of the databases to merge;
              defaults to BEST_1C
    output_path: file the merged collection is saved to
    db_dir: directory containing the ``<name>.chardb`` files

    Prints the total number of characters of the merged collection.
    """
    if db_names is None:
        db_names = BEST_1C

    charcols = [CharacterCollection(db_dir + '/' + db_name + '.chardb')
                for db_name in db_names]

    charcol_best = CharacterCollection()
    charcol_best.merge(charcols)
    print(charcol_best.get_total_n_characters())
    charcol_best.save(output_path)

# unipen_to_sqlite()
# group_unipen_db()
# group_best()
def get_character_collection(self):
    """
    Build a CharacterCollection from the parsed strokes, delineations and
    labels.

    Sample ``i`` is assembled from the stroke pieces listed in
    ``self._delineations[i]`` and labelled ``self._labels[i]``. Samples
    sharing a label go into the same set; sets are added in the order their
    label first appears.
    """
    charcol = CharacterCollection()

    sets = {}
    set_order = []

    for i, utf8 in enumerate(self._labels):
        writing = Writing()
        if self.height and self.width:
            writing.set_height(self.height)
            writing.set_width(self.width)

        for delin_range in self._delineations[i]:
            start_comp = delin_range.start_comp
            # end_comp appears to be exclusive (the original code always used end_comp - 1)
            last_comp = delin_range.end_comp - 1

            if start_comp == last_comp:
                stroke_points = self._strokes[start_comp][delin_range.start_point:delin_range.end_point]
                writing.append_stroke(Stroke.from_list(stroke_points))
                continue

            # strokes are appended in drawing order: first, middle, last

            # first (partial) stroke: from start_point up to and including
            # its last point ([start_point:-1] used to drop that point)
            start_stroke_points = self._strokes[start_comp][delin_range.start_point:]
            if len(start_stroke_points) > 0:
                writing.append_stroke(Stroke.from_list(start_stroke_points))

            # complete strokes in between
            # NOTE(review): appended as stored, while partial strokes go
            # through Stroke.from_list -- confirm self._strokes holds
            # Stroke-compatible objects.
            for stroke in self._strokes[start_comp + 1:last_comp]:
                writing.append_stroke(stroke)

            # last (partial) stroke: from its beginning up to end_point
            end_stroke_points = self._strokes[last_comp][0:delin_range.end_point]
            if len(end_stroke_points) > 0:
                writing.append_stroke(Stroke.from_list(end_stroke_points))

        character = Character()
        character.set_writing(writing)
        character.set_utf8(utf8)

        # append in place (list + [x] copied the whole list every time)
        if utf8 not in sets:
            sets[utf8] = []
            set_order.append(utf8)
        sets[utf8].append(character)

    charcol.add_sets(set_order)
    for set_name in set_order:
        charcol.append_characters(set_name, sets[set_name])
    return charcol
def get_aggregated_charcol(tuples, dbpath=None):
    """
    Create a character collection out of other character collections,
    character directories, tomoe dictionaries or kuchibue databases.

    tuples: a list of tuples (TYPE, path list)
    dbpath: optional path ending in ``.chardb`` in which the result is
            stored; with any other value the result is kept in memory.
            If the file already exists the user is asked interactively
            whether to continue, and whether to overwrite it.
    """
    import sys

    # number of files for each character collection type
    n_files = [len(t[1]) for t in tuples]

    # we don't need to merge character collections if only one is provided
    # this can save a lot of time for large collections
    if sum(n_files) == 1 and dbpath is None:
        idx = n_files.index(1)
        return _get_charcol(tuples[idx][0], tuples[idx][1][0])

    if dbpath is not None and dbpath.endswith(".chardb"):
        if os.path.exists(dbpath):
            print("%s exists already." % dbpath)
            print("Continuing will modify it...")
            answer = raw_input("Continue anyway? (y/N)")
            if answer == "y":
                print("Overwrite to concatenate collections together "
                      "in a new database")
                print("Don't overwrite to append new characters or "
                      "filter (-i,-e,-m) existing database")
                answer = raw_input("Overwrite it? (y/N)")
                if answer == "y":
                    os.unlink(dbpath)
            else:
                # sys.exit(), not the site-provided exit() helper, which is
                # meant for the interactive shell and may be unavailable
                sys.exit()

        charcol = CharacterCollection(dbpath)
        #charcol.WRITE_BACK = False
        #charcol.AUTO_COMMIT = True
    else:
        charcol = CharacterCollection()  # in memory db

    charcols = [_get_charcol(typ, path)
                for typ, paths in tuples for path in paths]

    charcol.merge(charcols)

    return charcol
def testFromCharDirRecursive(self):
    """Recursively loading the data directory yields the expected sets and counts."""
    directory = os.path.join(self.currdir, "data")
    charcol = CharacterCollection.from_character_directory(directory,
                                                           check_duplicate=True)
    # The set order presumably follows the directory walk, and os.listdir()
    # guarantees no order, so compare the sets order-insensitively.
    self.assertEqual(sorted(charcol.get_set_list()),
                     sorted(["防", "三", "一", "二"]))
    self.assertEqual(len(charcol.get_characters("一")), 3)
    self.assertEqual(len(charcol.get_characters("三")), 2)
    self.assertEqual(len(charcol.get_characters("二")), 1)
    self.assertEqual(len(charcol.get_characters("防")), 1)
def testFromCharDirRecursive(self):
    """Recursively loading the data directory yields the expected sets and counts."""
    directory = os.path.join(self.currdir, "data")
    charcol = CharacterCollection.from_character_directory(
        directory, check_duplicate=True)

    # assertEqual: the assertEquals alias is deprecated (removed in Python 3.12)
    self.assertEqual(sorted(charcol.get_set_list()),
                     sorted(["yo", "防", "三", "一", "二"]))
    self.assertEqual(len(charcol.get_characters("一")), 3)
    self.assertEqual(len(charcol.get_characters("三")), 2)
    self.assertEqual(len(charcol.get_characters("二")), 1)
    self.assertEqual(len(charcol.get_characters("防")), 1)
def testAdd(self):
    """``charcol + charcol2`` appends the second collection's sets after the first's."""
    collection_dir = os.path.join(self.currdir, "data", "collection")

    charcol = CharacterCollection()
    charcol.read(os.path.join(collection_dir, "test.charcol"))
    charcol2 = CharacterCollection()
    charcol2.read(os.path.join(collection_dir, "test2.charcol"))

    charcol3 = charcol + charcol2

    # assertEqual: the assertEquals alias is deprecated (removed in Python 3.12)
    self.assertEqual(charcol3.get_set_list(),
                     ["一", "三", "二", "四", "a", "b", "c", "d"])
    self.assertEqual(len(charcol3.get_characters("一")), 3)
    self.assertEqual(len(charcol3.get_characters("三")), 2)
    self.assertEqual(len(charcol3.get_characters("二")), 1)
    self.assertEqual(len(charcol3.get_characters("四")), 0)
    self.assertEqual(len(charcol3.get_characters("a")), 3)
    self.assertEqual(len(charcol3.get_characters("b")), 2)
    self.assertEqual(len(charcol3.get_characters("c")), 1)
    self.assertEqual(len(charcol3.get_characters("d")), 0)
def get_character_collection(self):
    """
    Build a CharacterCollection from the parsed characters and labels.

    ``self._characters[i]`` is labelled with ``self._labels[i]`` (its utf8
    value is set in place). Characters sharing a label are grouped into one
    set; sets are added in the order their label first appears.
    """
    charcol = CharacterCollection()

    assert len(self._labels) == len(self._characters)

    sets = {}
    set_order = []
    for utf8, character in zip(self._labels, self._characters):
        character.set_utf8(utf8)
        # append in place (list + [x] copied the whole list every time)
        if utf8 not in sets:
            sets[utf8] = []
            set_order.append(utf8)
        sets[utf8].append(character)

    charcol.add_sets(set_order)
    for set_name in set_order:
        charcol.append_characters(set_name, sets[set_name])
    return charcol
def __init__(self):
    """Create the fresh CharacterCollection held in ``self._charcol``."""
    collection = CharacterCollection()
    self._charcol = collection
class CharacterCollectionTest(unittest.TestCase):
    """
    Unit tests for CharacterCollection, run against the fixtures under the
    test's ``data`` directory.

    Uses assertEqual throughout: the assertEquals alias is deprecated and
    was removed in Python 3.12.
    """

    def setUp(self):
        self.currdir = os.path.dirname(os.path.abspath(__file__))
        path = os.path.join(self.currdir, "data", "collection", "test.charcol")
        self.cc = CharacterCollection()
        self.cc.read(path)
        f = os.path.join(self.currdir, "data", "character.xml")
        self.c = Character()
        self.c.read(f)

    def testValidate(self):
        path = os.path.join(self.currdir, "data", "collection", "test.charcol")
        # `with` closes the file even if read() raises
        with open(path) as f:
            buf = f.read()

        invalid = """ <?xml version="1.0" encoding="UTF-8"?> <character> <utf8>防</utf8> <strokes> <stroke> </stroke> </strokes> </character> """

        malformed = """ <?xml version="1.0" encoding="UTF-8"?> <character> """

        try:
            self.assertTrue(CharacterCollection.validate(buf))
            self.assertFalse(CharacterCollection.validate(invalid))
            self.assertFalse(CharacterCollection.validate(malformed))
        except NotImplementedError:
            # presumably raised by validate() when lxml is unavailable
            sys.stderr.write("lxml missing!\n")

    def _testReadXML(self, charcol):
        """Check charcol matches the reference test.charcol fixture."""
        self.assertEqual(charcol.get_set_list(), ["一", "三", "二", "四"])

        c = {}
        for k in ["19968_1", "19968_2", "19968_3", "19977_1", "19977_2",
                  "20108_1"]:
            c[k] = Character()
            c[k].read(os.path.join(self.currdir, "data", "collection",
                                   k + ".xml"))

        self.assertEqual(charcol.get_characters("一"),
                         [c["19968_1"], c["19968_2"], c["19968_3"]])
        self.assertEqual(charcol.get_characters("三"),
                         [c["19977_1"], c["19977_2"]])
        self.assertEqual(charcol.get_characters("二"), [c["20108_1"]])
        self.assertEqual(charcol.get_characters("四"), [])
        self.assertEqual(charcol.get_all_characters(),
                         [c["19968_1"], c["19968_2"], c["19968_3"],
                          c["19977_1"], c["19977_2"], c["20108_1"]])

    def testReadXMLFile(self):
        self._testReadXML(self.cc)

    def testToXML(self):
        charcol2 = CharacterCollection()
        charcol2.read_string(self.cc.to_xml())
        self.assertEqual(self.cc.get_set_list(), charcol2.get_set_list())
        self.assertEqual(self.cc.get_all_characters(),
                         charcol2.get_all_characters())

    def testWriteGzipString(self):
        charcol2 = CharacterCollection()
        charcol2.read_string(self.cc.write_string(gzip=True), gzip=True)
        self.assertEqual(self.cc.get_set_list(), charcol2.get_set_list())
        self.assertEqual(self.cc.get_all_characters(),
                         charcol2.get_all_characters())

    def testWriteBz2String(self):
        charcol2 = CharacterCollection()
        charcol2.read_string(self.cc.write_string(bz2=True), bz2=True)
        self.assertEqual(self.cc.get_set_list(), charcol2.get_set_list())
        self.assertEqual(self.cc.get_all_characters(),
                         charcol2.get_all_characters())

    def testAddSame(self):
        path = os.path.join(self.currdir, "data", "collection", "test.charcol")
        charcol = CharacterCollection()
        charcol.read(path)
        charcol2 = CharacterCollection()
        charcol2.read(path)
        charcol3 = charcol.concatenate(charcol2, check_duplicate=True)
        self.assertEqual(charcol3.get_set_list(), ["一", "三", "二", "四"])
        self.assertEqual(len(charcol3.get_characters("一")), 3)
        self.assertEqual(len(charcol3.get_characters("三")), 2)
        self.assertEqual(len(charcol3.get_characters("二")), 1)
        self.assertEqual(len(charcol3.get_characters("四")), 0)

    def testGetChars(self):
        all_ = self.cc.get_characters("一")
        self.assertEqual(self.cc.get_characters("一", limit=2), all_[0:2])
        self.assertEqual(self.cc.get_characters("一", offset=2), all_[2:])
        self.assertEqual(self.cc.get_characters("一", limit=1, offset=1),
                         all_[1:2])

    def testAdd(self):
        path = os.path.join(self.currdir, "data", "collection", "test.charcol")
        charcol = CharacterCollection()
        charcol.read(path)
        path2 = os.path.join(self.currdir, "data", "collection", "test2.charcol")
        charcol2 = CharacterCollection()
        charcol2.read(path2)
        charcol3 = charcol + charcol2
        self.assertEqual(charcol3.get_set_list(),
                         ["一", "三", "二", "四", "a", "b", "c", "d"])
        self.assertEqual(len(charcol3.get_characters("一")), 3)
        self.assertEqual(len(charcol3.get_characters("三")), 2)
        self.assertEqual(len(charcol3.get_characters("二")), 1)
        self.assertEqual(len(charcol3.get_characters("四")), 0)
        self.assertEqual(len(charcol3.get_characters("a")), 3)
        self.assertEqual(len(charcol3.get_characters("b")), 2)
        self.assertEqual(len(charcol3.get_characters("c")), 1)
        self.assertEqual(len(charcol3.get_characters("d")), 0)

    def testFromCharDirRecursive(self):
        directory = os.path.join(self.currdir, "data")
        charcol = CharacterCollection.from_character_directory(
            directory, check_duplicate=True)
        self.assertEqual(sorted(charcol.get_set_list()),
                         sorted(["yo", "防", "三", "一", "二"]))
        self.assertEqual(len(charcol.get_characters("一")), 3)
        self.assertEqual(len(charcol.get_characters("三")), 2)
        self.assertEqual(len(charcol.get_characters("二")), 1)
        self.assertEqual(len(charcol.get_characters("防")), 1)

    def testFromCharDirNotRecursive(self):
        directory = os.path.join(self.currdir, "data")
        charcol = CharacterCollection.from_character_directory(
            directory, recursive=False, check_duplicate=True)
        self.assertEqual(charcol.get_set_list(), ["防"])
        self.assertEqual(len(charcol.get_characters("防")), 1)

    def testIncludeChars(self):
        self.cc.include_characters_from_text("一三")
        self.assertEqual(self.cc.get_set_list(), ["一", "三"])

    def testExcludeChars(self):
        self.cc.exclude_characters_from_text("三")
        self.assertEqual(self.cc.get_set_list(), ["一", "二"])

    def testProxy(self):
        char = self.cc.get_all_characters()[0]
        writing = char.get_writing()
        writing.normalize()
        strokes = writing.get_strokes(full=True)
        stroke = strokes[0]
        stroke.smooth()
        p = stroke[0]
        p.x = 10

        char2 = self.cc.get_all_characters()[0]
        self.assertEqual(char, char2)

    def testNoProxy(self):
        self.cc.WRITE_BACK = False

        char = self.cc.get_all_characters()[0]
        writing = char.get_writing()
        writing.normalize()
        strokes = writing.get_strokes(full=True)
        stroke = strokes[0]
        stroke.smooth()
        p = stroke[0]
        p.x = 10

        char2 = self.cc.get_all_characters()[0]
        self.assertNotEqual(char, char2)

        # manually update the object
        self.cc.update_character_object(char)

        char2 = self.cc.get_all_characters()[0]
        self.assertEqual(char, char2)

    def testAddSet(self):
        self.cc.add_set("toto")
        self.assertEqual(self.cc.get_set_list()[-1], "toto")

    def testRemoveSet(self):
        before = self.cc.get_set_list()
        self.cc.remove_set(before[-1])
        after = self.cc.get_set_list()
        self.assertEqual(len(before) - 1, len(after))
        self.assertEqual(before[0:-1], after)

    def testGetNSets(self):
        self.assertEqual(len(self.cc.get_set_list()), self.cc.get_n_sets())
        self.assertEqual(4, self.cc.get_n_sets())

    def testGetTotalNCharacters(self):
        self.assertEqual(len(self.cc.get_all_characters()),
                         self.cc.get_total_n_characters())
        self.assertEqual(6, self.cc.get_total_n_characters())

    def testGetNCharacters(self):
        for set_name in self.cc.get_set_list():
            self.assertEqual(len(self.cc.get_characters(set_name)),
                             self.cc.get_n_characters(set_name))

        self.assertEqual(self.cc.get_n_characters("一"), 3)
        self.assertEqual(self.cc.get_n_characters("三"), 2)
        self.assertEqual(self.cc.get_n_characters("二"), 1)

    def testSetCharacters(self):
        before = self.cc.get_characters("一")[0:2]
        self.cc.set_characters("一", before)
        after = self.cc.get_characters("一")
        self.assertEqual(before, after)

    def testAppendCharacter(self):
        len_before = len(self.cc.get_characters("一"))
        self.cc.append_character("一", self.c)
        len_after = len(self.cc.get_characters("一"))
        self.assertEqual(len_before + 1, len_after)

    def testInsertCharacter(self):
        before = self.cc.get_characters("一")[0]
        len_before = len(self.cc.get_characters("一"))
        self.cc.insert_character("一", 0, self.c)
        after = self.cc.get_characters("一")[0]
        self.assertNotEqual(before, after)
        len_after = len(self.cc.get_characters("一"))
        self.assertEqual(len_before + 1, len_after)

    def testReplaceCharacter(self):
        before = self.cc.get_characters("一")[0]
        len_before = len(self.cc.get_characters("一"))
        self.cc.replace_character("一", 0, self.c)
        after = self.cc.get_characters("一")[0]
        self.assertNotEqual(before, after)
        len_after = len(self.cc.get_characters("一"))
        self.assertEqual(len_before, len_after)

    def testRemoveCharacter(self):
        before = self.cc.get_characters("一")[0]
        len_before = len(self.cc.get_characters("一"))
        self.cc.remove_character("一", 0)
        after = self.cc.get_characters("一")[0]
        self.assertNotEqual(before, after)
        len_after = len(self.cc.get_characters("一"))
        self.assertEqual(len_before - 1, len_after)

    def testRemoveLastCharacter(self):
        before = self.cc.get_characters("一")[-1]
        len_before = len(self.cc.get_characters("一"))
        self.cc.remove_last_character("一")
        after = self.cc.get_characters("一")[-1]
        self.assertNotEqual(before, after)
        len_after = len(self.cc.get_characters("一"))
        self.assertEqual(len_before - 1, len_after)

    def testRemoveSamples(self):
        self.cc.remove_samples(keep_at_most=2)
        self.assertEqual(self.cc.get_n_characters("一"), 2)
        self.assertEqual(self.cc.get_n_characters("三"), 2)
        self.assertEqual(self.cc.get_n_characters("二"), 1)

        self.cc.remove_samples(keep_at_most=1)
        self.assertEqual(self.cc.get_n_characters("一"), 1)
        self.assertEqual(self.cc.get_n_characters("三"), 1)
        self.assertEqual(self.cc.get_n_characters("二"), 1)

    def testRemoveEmptySets(self):
        self.cc.remove_empty_sets()
        self.assertEqual(self.cc.get_set_list(), ["一", "三", "二"])
def testToXML(self):
    """Round-tripping through to_xml()/read_string() preserves sets and characters."""
    charcol2 = CharacterCollection()
    charcol2.read_string(self.cc.to_xml())
    # assertEqual: the assertEquals alias is deprecated (removed in Python 3.12)
    self.assertEqual(self.cc.get_set_list(), charcol2.get_set_list())
    self.assertEqual(self.cc.get_all_characters(),
                     charcol2.get_all_characters())
def testWriteBz2String(self):
    """A bz2-compressed write_string()/read_string() round trip preserves the collection."""
    charcol2 = CharacterCollection()
    charcol2.read_string(self.cc.write_string(bz2=True), bz2=True)
    # assertEqual: the assertEquals alias is deprecated (removed in Python 3.12)
    self.assertEqual(self.cc.get_set_list(), charcol2.get_set_list())
    self.assertEqual(self.cc.get_all_characters(),
                     charcol2.get_all_characters())
class TomoeXmlDictionaryReader(_XmlBase):
    """
    Parser callbacks turning a Tomoe XML dictionary (root element
    <dictionary>) into a CharacterCollection.

    Each <character> becomes one Character, labelled by the text of its
    <utf8> child. <stroke> elements group <point> elements. A point's x, y,
    pressure, xtilt, ytilt and timestamp attributes are copied onto a Point:
    pressure/xtilt/ytilt as floats, the others as ints, and missing ones as
    None.
    """

    def __init__(self):
        self._charcol = CharacterCollection()

    def get_character_collection(self):
        """Return the collection accumulated so far."""
        return self._charcol

    def _start_element(self, name, attrs):
        self._tag = name

        if self._first_tag:
            self._first_tag = False
            if name != "dictionary":
                raise ValueError("The very first tag should be <dictionary>")

        if name == "character":
            self._writing = Writing()

        if name == "stroke":
            self._stroke = Stroke()
        elif name == "point":
            float_keys = ("pressure", "xtilt", "ytilt")
            point = Point()
            for key in ("x", "y", "pressure", "xtilt", "ytilt", "timestamp"):
                if key not in attrs:
                    value = None
                else:
                    raw = attrs[key].encode("UTF-8")
                    if key in float_keys:
                        value = float(raw)
                    else:
                        value = int(float(raw))
                setattr(point, key, value)
            self._stroke.append_point(point)

    def _end_element(self, name):
        if name == "character":
            character = Character()
            character.set_utf8(self._utf8)
            character.set_writing(self._writing)
            self._charcol.add_set(self._utf8)
            self._charcol.append_character(self._utf8, character)
            # forget per-character state
            for attr_name in ("_tag", "_stroke"):
                self.__dict__.pop(attr_name, None)
        elif name == "stroke":
            self._writing.append_stroke(self._stroke)
            self._stroke = None
        self._tag = None

    def _char_data(self, data):
        tag = self._tag
        if tag == "utf8":
            self._utf8 = data.encode("UTF-8")
        elif tag == "width":
            self._writing.set_width(int(data))
        elif tag == "height":
            self._writing.set_height(int(data))
class CharacterCollectionTest(unittest.TestCase):
    """
    Unit tests for CharacterCollection, run against the fixtures under the
    test's ``data`` directory.

    Uses assertEqual throughout: the assertEquals alias is deprecated and
    was removed in Python 3.12.
    """

    def setUp(self):
        self.currdir = os.path.dirname(os.path.abspath(__file__))
        path = os.path.join(self.currdir, "data", "collection", "test.charcol")
        self.cc = CharacterCollection()
        self.cc.read(path)
        f = os.path.join(self.currdir, "data", "character.xml")
        self.c = Character()
        self.c.read(f)

    def testValidate(self):
        path = os.path.join(self.currdir, "data", "collection", "test.charcol")
        # `with` closes the file even if read() raises
        with open(path) as f:
            buf = f.read()

        invalid = """ <?xml version="1.0" encoding="UTF-8"?> <character> <utf8>防</utf8> <strokes> <stroke> </stroke> </strokes> </character> """

        malformed = """ <?xml version="1.0" encoding="UTF-8"?> <character> """

        try:
            self.assertTrue(CharacterCollection.validate(buf))
            self.assertFalse(CharacterCollection.validate(invalid))
            self.assertFalse(CharacterCollection.validate(malformed))
        except NotImplementedError:
            # presumably raised by validate() when lxml is unavailable
            sys.stderr.write("lxml missing!\n")

    def _testReadXML(self, charcol):
        """Check charcol matches the reference test.charcol fixture."""
        self.assertEqual(charcol.get_set_list(), ["一", "三", "二", "四"])

        c = {}
        for k in ["19968_1", "19968_2", "19968_3", "19977_1", "19977_2",
                  "20108_1"]:
            c[k] = Character()
            c[k].read(os.path.join(self.currdir, "data", "collection",
                                   k + ".xml"))

        self.assertEqual(charcol.get_characters("一"),
                         [c["19968_1"], c["19968_2"], c["19968_3"]])
        self.assertEqual(charcol.get_characters("三"),
                         [c["19977_1"], c["19977_2"]])
        self.assertEqual(charcol.get_characters("二"), [c["20108_1"]])
        self.assertEqual(charcol.get_characters("四"), [])
        self.assertEqual(charcol.get_all_characters(),
                         [c["19968_1"], c["19968_2"], c["19968_3"],
                          c["19977_1"], c["19977_2"], c["20108_1"]])

    def testReadXMLFile(self):
        self._testReadXML(self.cc)

    def testToXML(self):
        charcol2 = CharacterCollection()
        charcol2.read_string(self.cc.to_xml())
        self.assertEqual(self.cc.get_set_list(), charcol2.get_set_list())
        self.assertEqual(self.cc.get_all_characters(),
                         charcol2.get_all_characters())

    def testWriteGzipString(self):
        charcol2 = CharacterCollection()
        charcol2.read_string(self.cc.write_string(gzip=True), gzip=True)
        self.assertEqual(self.cc.get_set_list(), charcol2.get_set_list())
        self.assertEqual(self.cc.get_all_characters(),
                         charcol2.get_all_characters())

    def testWriteBz2String(self):
        charcol2 = CharacterCollection()
        charcol2.read_string(self.cc.write_string(bz2=True), bz2=True)
        self.assertEqual(self.cc.get_set_list(), charcol2.get_set_list())
        self.assertEqual(self.cc.get_all_characters(),
                         charcol2.get_all_characters())

    def testAddSame(self):
        path = os.path.join(self.currdir, "data", "collection", "test.charcol")
        charcol = CharacterCollection()
        charcol.read(path)
        charcol2 = CharacterCollection()
        charcol2.read(path)
        charcol3 = charcol.concatenate(charcol2, check_duplicate=True)
        self.assertEqual(charcol3.get_set_list(), ["一", "三", "二", "四"])
        self.assertEqual(len(charcol3.get_characters("一")), 3)
        self.assertEqual(len(charcol3.get_characters("三")), 2)
        self.assertEqual(len(charcol3.get_characters("二")), 1)
        self.assertEqual(len(charcol3.get_characters("四")), 0)

    def testGetChars(self):
        all_ = self.cc.get_characters("一")
        self.assertEqual(self.cc.get_characters("一", limit=2), all_[0:2])
        self.assertEqual(self.cc.get_characters("一", offset=2), all_[2:])
        self.assertEqual(self.cc.get_characters("一", limit=1, offset=1),
                         all_[1:2])

    def testAdd(self):
        path = os.path.join(self.currdir, "data", "collection", "test.charcol")
        charcol = CharacterCollection()
        charcol.read(path)
        path2 = os.path.join(self.currdir, "data", "collection", "test2.charcol")
        charcol2 = CharacterCollection()
        charcol2.read(path2)
        charcol3 = charcol + charcol2
        self.assertEqual(charcol3.get_set_list(),
                         ["一", "三", "二", "四", "a", "b", "c", "d"])
        self.assertEqual(len(charcol3.get_characters("一")), 3)
        self.assertEqual(len(charcol3.get_characters("三")), 2)
        self.assertEqual(len(charcol3.get_characters("二")), 1)
        self.assertEqual(len(charcol3.get_characters("四")), 0)
        self.assertEqual(len(charcol3.get_characters("a")), 3)
        self.assertEqual(len(charcol3.get_characters("b")), 2)
        self.assertEqual(len(charcol3.get_characters("c")), 1)
        self.assertEqual(len(charcol3.get_characters("d")), 0)

    def testFromCharDirRecursive(self):
        directory = os.path.join(self.currdir, "data")
        charcol = CharacterCollection.from_character_directory(
            directory, check_duplicate=True)
        self.assertEqual(sorted(charcol.get_set_list()),
                         sorted(["yo", "防", "三", "一", "二"]))
        self.assertEqual(len(charcol.get_characters("一")), 3)
        self.assertEqual(len(charcol.get_characters("三")), 2)
        self.assertEqual(len(charcol.get_characters("二")), 1)
        self.assertEqual(len(charcol.get_characters("防")), 1)

    def testFromCharDirNotRecursive(self):
        directory = os.path.join(self.currdir, "data")
        charcol = CharacterCollection.from_character_directory(
            directory, recursive=False, check_duplicate=True)
        self.assertEqual(charcol.get_set_list(), ["防"])
        self.assertEqual(len(charcol.get_characters("防")), 1)

    def testIncludeChars(self):
        self.cc.include_characters_from_text("一三")
        self.assertEqual(self.cc.get_set_list(), ["一", "三"])

    def testExcludeChars(self):
        self.cc.exclude_characters_from_text("三")
        self.assertEqual(self.cc.get_set_list(), ["一", "二"])

    def testProxy(self):
        char = self.cc.get_all_characters()[0]
        writing = char.get_writing()
        writing.normalize()
        strokes = writing.get_strokes(full=True)
        stroke = strokes[0]
        stroke.smooth()
        p = stroke[0]
        p.x = 10

        char2 = self.cc.get_all_characters()[0]
        self.assertEqual(char, char2)

    def testNoProxy(self):
        self.cc.WRITE_BACK = False

        char = self.cc.get_all_characters()[0]
        writing = char.get_writing()
        writing.normalize()
        strokes = writing.get_strokes(full=True)
        stroke = strokes[0]
        stroke.smooth()
        p = stroke[0]
        p.x = 10

        char2 = self.cc.get_all_characters()[0]
        self.assertNotEqual(char, char2)

        # manually update the object
        self.cc.update_character_object(char)

        char2 = self.cc.get_all_characters()[0]
        self.assertEqual(char, char2)

    def testAddSet(self):
        self.cc.add_set("toto")
        self.assertEqual(self.cc.get_set_list()[-1], "toto")

    def testRemoveSet(self):
        before = self.cc.get_set_list()
        self.cc.remove_set(before[-1])
        after = self.cc.get_set_list()
        self.assertEqual(len(before) - 1, len(after))
        self.assertEqual(before[0:-1], after)

    def testGetNSets(self):
        self.assertEqual(len(self.cc.get_set_list()), self.cc.get_n_sets())
        self.assertEqual(4, self.cc.get_n_sets())

    def testGetTotalNCharacters(self):
        self.assertEqual(len(self.cc.get_all_characters()),
                         self.cc.get_total_n_characters())
        self.assertEqual(6, self.cc.get_total_n_characters())

    def testGetNCharacters(self):
        for set_name in self.cc.get_set_list():
            self.assertEqual(len(self.cc.get_characters(set_name)),
                             self.cc.get_n_characters(set_name))

        self.assertEqual(self.cc.get_n_characters("一"), 3)
        self.assertEqual(self.cc.get_n_characters("三"), 2)
        self.assertEqual(self.cc.get_n_characters("二"), 1)

    def testSetCharacters(self):
        before = self.cc.get_characters("一")[0:2]
        self.cc.set_characters("一", before)
        after = self.cc.get_characters("一")
        self.assertEqual(before, after)

    def testAppendCharacter(self):
        len_before = len(self.cc.get_characters("一"))
        self.cc.append_character("一", self.c)
        len_after = len(self.cc.get_characters("一"))
        self.assertEqual(len_before + 1, len_after)

    def testInsertCharacter(self):
        before = self.cc.get_characters("一")[0]
        len_before = len(self.cc.get_characters("一"))
        self.cc.insert_character("一", 0, self.c)
        after = self.cc.get_characters("一")[0]
        self.assertNotEqual(before, after)
        len_after = len(self.cc.get_characters("一"))
        self.assertEqual(len_before + 1, len_after)

    def testReplaceCharacter(self):
        before = self.cc.get_characters("一")[0]
        len_before = len(self.cc.get_characters("一"))
        self.cc.replace_character("一", 0, self.c)
        after = self.cc.get_characters("一")[0]
        self.assertNotEqual(before, after)
        len_after = len(self.cc.get_characters("一"))
        self.assertEqual(len_before, len_after)

    def testRemoveCharacter(self):
        before = self.cc.get_characters("一")[0]
        len_before = len(self.cc.get_characters("一"))
        self.cc.remove_character("一", 0)
        after = self.cc.get_characters("一")[0]
        self.assertNotEqual(before, after)
        len_after = len(self.cc.get_characters("一"))
        self.assertEqual(len_before - 1, len_after)

    def testRemoveLastCharacter(self):
        before = self.cc.get_characters("一")[-1]
        len_before = len(self.cc.get_characters("一"))
        self.cc.remove_last_character("一")
        after = self.cc.get_characters("一")[-1]
        self.assertNotEqual(before, after)
        len_after = len(self.cc.get_characters("一"))
        self.assertEqual(len_before - 1, len_after)

    def testRemoveSamples(self):
        self.cc.remove_samples(keep_at_most=2)
        self.assertEqual(self.cc.get_n_characters("一"), 2)
        self.assertEqual(self.cc.get_n_characters("三"), 2)
        self.assertEqual(self.cc.get_n_characters("二"), 1)

        self.cc.remove_samples(keep_at_most=1)
        self.assertEqual(self.cc.get_n_characters("一"), 1)
        self.assertEqual(self.cc.get_n_characters("三"), 1)
        self.assertEqual(self.cc.get_n_characters("二"), 1)

    def testRemoveEmptySets(self):
        self.cc.remove_empty_sets()
        self.assertEqual(self.cc.get_set_list(), ["一", "三", "二"])
def testFromCharDirNotRecursive(self):
    """A non-recursive load only sees the characters at the top of the data directory."""
    directory = os.path.join(self.currdir, "data")
    charcol = CharacterCollection.from_character_directory(
        directory, recursive=False, check_duplicate=True)
    # assertEqual: the assertEquals alias is deprecated (removed in Python 3.12)
    self.assertEqual(charcol.get_set_list(), ["防"])
    self.assertEqual(len(charcol.get_characters("防")), 1)