Example #1
0
    def testMlPerfPackedInputPadToMax(self):
        """Packed MlPerf input with ``pad_to_max_seq_length`` enabled.

        Configures the packed-input params with source/target max length 300,
        fetches one preprocessed batch, and runs ``checkPadShape`` on src ids
        and tgt ids/segment_ids/segment_pos with pad=0, batch_size=4,
        actual_max=240 and pad_length=300 (see ``checkPadShape`` for the exact
        meaning of these arguments).
        """
        params = self._CreateMlPerfPackedInputParams()
        params.source_max_length = 300
        params.target_max_length = 300
        params.pad_to_max_seq_length = True

        with self.session(use_gpu=False) as sess:
            generator = input_generator.MlPerfInput(params)
            # A single fetch is enough to observe the padded shapes.
            fetched = py_utils.NestedMap(
                sess.run(generator.GetPreprocessedInputBatch()))

        # All checked tensors share the same expected padding layout.
        padded_tensors = (
            fetched.src.ids,
            fetched.tgt.ids,
            fetched.tgt.segment_ids,
            fetched.tgt.segment_pos,
        )
        for tensor in padded_tensors:
            self.checkPadShape(tensor,
                               pad=0,
                               batch_size=4,
                               actual_max=240,
                               pad_length=300)
Example #2
0
    def testMlPerfPadToMax(self):
        """Bucketed MlPerf input padded out to the configured max lengths.

        Uses a single bucket (upper bound 20, batch limit 4) with
        ``pad_to_max_seq_length`` enabled. After several steps, the last
        fetched batch is checked: each src/tgt tensor must have shape
        (batch_limit, max_length), and every column past the bucket upper bound
        must equal the padding value.
        """
        p = self._CreateMlPerfInputParams()
        p.bucket_upper_bound = [20]
        p.bucket_batch_limit = [4]
        p.source_max_length = 30
        p.target_max_length = 30
        p.pad_to_max_seq_length = True

        # Derive the expectations from the params instead of repeating
        # literals, so the assertions stay in sync if the config changes.
        batch_size = p.bucket_batch_limit[0]
        bucket_len = p.bucket_upper_bound[0]

        with self.session(use_gpu=False) as sess:
            inp = input_generator.MlPerfInput(p)
            # Runs a few steps; only the last fetched batch is checked below.
            for _ in range(10):
                fetched = py_utils.NestedMap(
                    sess.run(inp.GetPreprocessedInputBatch()))

        def Check(x, pad, max_len):
            """Asserts x is (batch_size, max_len) and padded past bucket_len."""
            # Check the shape: (batch, maxlen)
            self.assertEqual(x.shape, (batch_size, max_len))
            # Everything beyond the bucket upper bound must be padding.
            self.assertAllEqual(
                x[:, bucket_len:],
                np.full((batch_size, max_len - bucket_len), pad))

        # Source-side tensors are padded to the source max length.
        Check(fetched.src.ids, 0, p.source_max_length)
        Check(fetched.src.paddings, 1, p.source_max_length)
        # Target-side tensors are padded to the target max length.
        Check(fetched.tgt.ids, 0, p.target_max_length)
        Check(fetched.tgt.labels, 0, p.target_max_length)
        Check(fetched.tgt.weights, 0, p.target_max_length)
        Check(fetched.tgt.paddings, 1, p.target_max_length)
 def testMlPerf(self):
   """Smoke-tests MlPerfInput by fetching and logging ten batches."""
   params = self._CreateMlPerfInputParams()
   with self.session(use_gpu=False) as sess:
     generator = input_generator.MlPerfInput(params)
     # Each step fetches one preprocessed batch and logs it; nothing is
     # asserted beyond the fetches completing.
     for _step in range(10):
       batch = sess.run(generator.GetPreprocessedInputBatch())
       fetched = py_utils.NestedMap(batch)
       tf.logging.info(fetched)
Example #4
0
 def testMlPerfPackedInput(self):
     """Fetches one packed MlPerf batch and logs the shapes of its tensors."""
     params = self._CreateMlPerfPackedInputParams()
     with self.session(use_gpu=False) as sess:
         generator = input_generator.MlPerfInput(params)
         # A single fetch is enough to inspect the packed layout.
         fetched = py_utils.NestedMap(
             sess.run(generator.GetPreprocessedInputBatch()))
         # Log shapes in a fixed order: src ids/segment info, then tgt.
         for tensor in (fetched.src.ids,
                        fetched.src.segment_ids,
                        fetched.src.segment_pos,
                        fetched.tgt.segment_ids,
                        fetched.tgt.segment_pos):
             tf.logging.info(tensor.shape)