def _should_split(self, script_tokenized):
  """Builds a per-token boolean mask marking tokens that should be split.

  A token is flagged when any of the following holds:
    * the script id of its leading code point is accepted by `self._is_cjk`,
    * its word shape is `HAS_EMOJI`,
    * its word shape is `IS_PUNCT_OR_SYMBOL`.

  Args:
    script_tokenized: Object exposing `flat_values` holding UTF-8 encoded
      token strings (presumably a RaggedTensor of tokens -- confirm against
      the caller).

  Returns:
    The element-wise OR of the three conditions above, computed over
    `script_tokenized.flat_values`.
  """
  tokens = script_tokenized.flat_values

  # Decode every token to code points, look up each code point's script id,
  # and keep at most the first id per token.
  codepoints = ragged_string_ops.unicode_decode(tokens, "UTF-8")
  leading_script_ids = string_ops.unicode_script(codepoints)[:, :1]
  cjk_mask = self._is_cjk(leading_script_ids.flat_values)

  emoji_mask = wordshape_ops.wordshape(
      tokens, wordshape_ops.WordShape.HAS_EMOJI)
  punct_mask = wordshape_ops.wordshape(
      tokens, wordshape_ops.WordShape.IS_PUNCT_OR_SYMBOL)

  return cjk_mask | emoji_mask | punct_mask
def testValidScripts(self):
  """Checks `unicode_script` returns the expected script code per code point."""
  # (code point, expected script code) pairs.
  cases = [
      (ord("a"), 25),   # USCRIPT_LATIN (LATN)
      (0x0411, 8),      # CYRILLIC CAPITAL LETTER BE -> USCRIPT_CYRILLIC (CYRL)
      (0x82b8, 17),     # CJK UNIFIED IDEOGRAPH-82B8 -> USCRIPT_HAN (HANI)
      (ord(","), 0),    # USCRIPT_COMMON (ZYYY)
  ]
  codepoints = [codepoint for codepoint, _ in cases]
  expected_scripts = [script for _, script in cases]

  with self.cached_session():
    codepoint_tensor = constant_op.constant(codepoints, dtypes.int32)
    actual_scripts = string_ops.unicode_script(codepoint_tensor).eval()
    self.assertAllEqual(actual_scripts, expected_scripts)
def benchmark_unicode_script(self):
  """Benchmarks the `unicode_script` op on generated input.

  The input comes from `self._generateBenchmarkInput(1000000)` and the op is
  run through `run_op_benchmark` with `min_iters=100`.
  """
  session_config = benchmark.benchmark_config()
  with session.Session(config=session_config) as sess:
    benchmark_input = self._generateBenchmarkInput(1000000)
    script_op = string_ops.unicode_script(benchmark_input).op
    self.run_op_benchmark(sess, script_op, min_iters=100)
def testInvalidScript(self):
  """Checks `unicode_script` yields -1 for the inputs -100 and 0xffffff."""
  invalid_codepoints = [-100, 0xffffff]
  expected_scripts = [-1, -1]

  with self.cached_session():
    codepoint_tensor = constant_op.constant(invalid_codepoints, dtypes.int32)
    actual_scripts = string_ops.unicode_script(codepoint_tensor).eval()
    self.assertAllEqual(actual_scripts, expected_scripts)