def testSourceTargetValues(self):
  """Verifies src/tgt ids equal tokenizing the known reference sentence.

  The punctuator task reads a punctuated target sentence; the source is the
  same sentence lower-cased with punctuation stripped. This test re-derives
  both forms with the input's own tokenizer and checks the batch matches.
  """
  max_length = 50
  p = self._CreatePunctuatorInputParams()
  with self.session(use_gpu=False) as sess:
    inp = input_generator.PunctuatorInput(p)
    tokenizer = inp.tokenizer_dict['default']
    fetched = py_utils.NestedMap(sess.run(inp.GetPreprocessedInputBatch()))
    source_ids = fetched.src.ids
    tgt_ids = fetched.tgt.ids
    tgt_labels = fetched.tgt.labels
    expected_ref = ('His approach was inquisitive , a meeting of artful '
                    'hesitation with fluid technique .')
    # Bug fix: str.translate(None, chars) is the Python 2 two-argument form
    # and raises TypeError on Python 3 str. Build a deletion table instead.
    normalized_ref = expected_ref.lower().translate(
        str.maketrans('', '', string.punctuation))
    # Collapse whitespace left behind by stripped punctuation tokens.
    normalized_ref = ' '.join(normalized_ref.split())
    _, expected_src_ids, _ = sess.run(
        tokenizer.StringsToIds(
            tf.convert_to_tensor([normalized_ref]), max_length=max_length))
    expected_tgt_ids, expected_tgt_labels, _ = sess.run(
        tokenizer.StringsToIds(
            tf.convert_to_tensor([expected_ref]), max_length=max_length))
    self.assertAllEqual(expected_src_ids[0], source_ids[0, :max_length])
    self.assertAllEqual(expected_tgt_ids[0], tgt_ids[0, :max_length])
    self.assertAllEqual(expected_tgt_labels[0], tgt_labels[0, :max_length])
def testSourceTargetValues(self):
  """Checks that a fetched batch tokenizes the expected reference sentence.

  The source side is the reference lower-cased with punctuation removed;
  the target side is the punctuated reference. Both are re-tokenized here
  and compared against the ids produced by the input pipeline.
  """
  max_len = 50
  params = self._CreatePunctuatorInputParams()
  with self.session(use_gpu=False):
    generator = input_generator.PunctuatorInput(params)
    batch = py_utils.NestedMap(
        self.evaluate(generator.GetPreprocessedInputBatch()))

    punctuated = (
        b'Elk calling -- a skill that hunters perfected long ago to lure '
        b'game with the promise of a little romance -- is now its own sport .'
    )
    # Derive the unpunctuated source form: lower-case, delete punctuation
    # bytes, then re-join on single spaces.
    stripped = punctuated.lower().translate(
        None, string.punctuation.encode('utf-8'))
    stripped = b' '.join(stripped.split())

    _, want_src_ids, _ = self.evaluate(
        generator.tokenizer.StringsToIds(
            tf.convert_to_tensor([stripped]), max_length=max_len))
    want_tgt_ids, want_tgt_labels, _ = self.evaluate(
        generator.tokenizer.StringsToIds(
            tf.convert_to_tensor([punctuated]), max_length=max_len))

    self.assertAllEqual(want_src_ids[0], batch.src.ids[0, :max_len])
    self.assertAllEqual(want_tgt_ids[0], batch.tgt.ids[0, :max_len])
    self.assertAllEqual(want_tgt_labels[0], batch.tgt.labels[0, :max_len])
def testBasic(self):
  """Smoke test: the punctuator input pipeline yields batches without error."""
  params = self._CreatePunctuatorInputParams()
  with self.session(use_gpu=False) as sess:
    generator = input_generator.PunctuatorInput(params)
    # Pull several batches to exercise the pipeline end to end.
    for _ in range(10):
      sess.run(generator.GetPreprocessedInputBatch())
def testSourceTargetValues(self):
  """Verifies src/tgt ids equal tokenizing the known reference sentence.

  The source side is the reference lower-cased with punctuation stripped;
  the target side keeps the punctuation. Both are re-tokenized with the
  input's tokenizer and compared against the fetched batch.
  """
  max_length = 50
  p = self._CreatePunctuatorInputParams()
  with self.session(use_gpu=False) as sess:
    inp = input_generator.PunctuatorInput(p)
    tokenizer = inp.tokenizer_dict['default']
    fetched = py_utils.NestedMap(
        sess.run(inp.GetPreprocessedInputBatch()))
    source_ids = fetched.src.ids
    tgt_ids = fetched.tgt.ids
    tgt_labels = fetched.tgt.labels
    expected_ref = 'The internet is sort-of-40 this year .'
    # "the internet is sortof40 this year" - lower-case, no dashes, no dot.
    # Bug fix: str.translate(None, chars) is the Python 2 two-argument form
    # and raises TypeError on Python 3 str. Build a deletion table instead.
    normalized_ref = expected_ref.lower().translate(
        str.maketrans('', '', string.punctuation))
    expected_src_ids, _, _ = sess.run(
        tokenizer.StringsToIds([normalized_ref], max_length=max_length))
    expected_tgt_ids, expected_tgt_labels, _ = sess.run(
        tokenizer.StringsToIds([expected_ref], max_length=max_length))
    self.assertAllEqual(expected_src_ids[0], source_ids[0, :max_length])
    self.assertAllEqual(expected_tgt_ids[0], tgt_ids[0, :max_length])
    self.assertAllEqual(expected_tgt_labels[0], tgt_labels[0, :max_length])