Example #1
  def testSourceTargetValues(self):
    max_length = 50
    p = self._CreatePunctuatorInputParams()
    with self.session(use_gpu=False) as sess:
      inp = input_generator.PunctuatorInput(p)
      tokenizer = inp.tokenizer_dict['default']

      fetched = py_utils.NestedMap(sess.run(inp.GetPreprocessedInputBatch()))
      source_ids = fetched.src.ids
      tgt_ids = fetched.tgt.ids
      tgt_labels = fetched.tgt.labels

      expected_ref = ('His approach was inquisitive , a meeting of artful '
                      'hesitation with fluid technique .')

      normalized_ref = expected_ref.lower().translate(None, string.punctuation)
      normalized_ref = ' '.join(normalized_ref.split())
      _, expected_src_ids, _ = sess.run(
          tokenizer.StringsToIds(
              tf.convert_to_tensor([normalized_ref]), max_length=max_length))
      expected_tgt_ids, expected_tgt_labels, _ = sess.run(
          tokenizer.StringsToIds(
              tf.convert_to_tensor([expected_ref]), max_length=max_length))

      self.assertAllEqual(expected_src_ids[0], source_ids[0, :max_length])
      self.assertAllEqual(expected_tgt_ids[0], tgt_ids[0, :max_length])
      self.assertAllEqual(expected_tgt_labels[0], tgt_labels[0, :max_length])
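
Note: the normalization in this example uses the Python 2 signature str.translate(None, string.punctuation); on Python 3, str.translate expects a mapping built with str.maketrans. Below is a minimal standalone sketch of the same source-side preprocessing (lowercase, delete punctuation, collapse whitespace); the normalize helper is purely illustrative and not part of the Lingvo API.

import string

def normalize(ref):
  # Lowercase, delete punctuation, and collapse whitespace runs, mirroring
  # how the test derives the unpunctuated source text from the reference.
  no_punct = ref.lower().translate(str.maketrans('', '', string.punctuation))
  return ' '.join(no_punct.split())

ref = ('His approach was inquisitive , a meeting of artful '
       'hesitation with fluid technique .')
print(normalize(ref))
# his approach was inquisitive a meeting of artful hesitation with fluid technique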
Example #2
    def testSourceTargetValues(self):
        max_length = 50
        p = self._CreatePunctuatorInputParams()
        with self.session(use_gpu=False):
            inp = input_generator.PunctuatorInput(p)

            fetched = py_utils.NestedMap(
                self.evaluate(inp.GetPreprocessedInputBatch()))
            source_ids = fetched.src.ids
            tgt_ids = fetched.tgt.ids
            tgt_labels = fetched.tgt.labels

            expected_ref = (
                b'Elk calling -- a skill that hunters perfected long ago to lure '
                b'game with the promise of a little romance -- is now its own sport .'
            )

            normalized_ref = expected_ref.lower().translate(
                None, string.punctuation.encode('utf-8'))
            normalized_ref = b' '.join(normalized_ref.split())
            _, expected_src_ids, _ = self.evaluate(
                inp.tokenizer.StringsToIds(
                    tf.convert_to_tensor([normalized_ref]),
                    max_length=max_length))
            expected_tgt_ids, expected_tgt_labels, _ = self.evaluate(
                inp.tokenizer.StringsToIds(
                    tf.convert_to_tensor([expected_ref]),
                    max_length=max_length))

            self.assertAllEqual(expected_src_ids[0],
                                source_ids[0, :max_length])
            self.assertAllEqual(expected_tgt_ids[0], tgt_ids[0, :max_length])
            self.assertAllEqual(expected_tgt_labels[0],
                                tgt_labels[0, :max_length])
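
Here the reference strings are bytes, so punctuation is removed with bytes.translate(None, delete), which remains valid on Python 3; the deletion table must itself be bytes, hence the .encode('utf-8'). A short standalone check of that bytes-side normalization, using the same sentence as the test above:

import string

ref = (b'Elk calling -- a skill that hunters perfected long ago to lure '
       b'game with the promise of a little romance -- is now its own sport .')
# bytes.translate(table, delete): with table=None the bytes are left unchanged
# and every byte listed in `delete` is removed.
no_punct = ref.lower().translate(None, string.punctuation.encode('utf-8'))
print(b' '.join(no_punct.split()))
# b'elk calling a skill that hunters perfected long ago to lure game with the
#   promise of a little romance is now its own sport'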
Example #3
  def testBasic(self):
    p = self._CreatePunctuatorInputParams()
    with self.session(use_gpu=False) as sess:
      inp = input_generator.PunctuatorInput(p)
      # Runs a few steps.
      for _ in range(10):
        sess.run(inp.GetPreprocessedInputBatch())
Example #4
    def testSourceTargetValues(self):
        max_length = 50
        p = self._CreatePunctuatorInputParams()
        with self.session(use_gpu=False) as sess:
            inp = input_generator.PunctuatorInput(p)
            tokenizer = inp.tokenizer_dict['default']

            fetched = py_utils.NestedMap(
                sess.run(inp.GetPreprocessedInputBatch()))
            source_ids = fetched.src.ids
            tgt_ids = fetched.tgt.ids
            tgt_labels = fetched.tgt.labels

            expected_ref = 'The internet is sort-of-40 this year .'

            # "the internet is sortof40 this year" - lower-case, no dashes, no dot.
            normalized_ref = expected_ref.lower().translate(
                None, string.punctuation)
            expected_src_ids, _, _ = sess.run(
                tokenizer.StringsToIds([normalized_ref],
                                       max_length=max_length))
            expected_tgt_ids, expected_tgt_labels, _ = sess.run(
                tokenizer.StringsToIds([expected_ref], max_length=max_length))

            self.assertAllEqual(expected_src_ids[0],
                                source_ids[0, :max_length])
            self.assertAllEqual(expected_tgt_ids[0], tgt_ids[0, :max_length])
            self.assertAllEqual(expected_tgt_labels[0],
                                tgt_labels[0, :max_length])