def testMlPerfPackedInputPadToMax(self):
  """Packed MLPerf batches are padded out to the configured max length.

  Both max lengths are raised to 300 with pad_to_max_seq_length enabled, and
  a single fetched batch is validated with checkPadShape (defined elsewhere in
  this test class) using pad value 0, batch_size=4, actual_max=240 and
  pad_length=300.
  """
  params = self._CreateMlPerfPackedInputParams()
  params.source_max_length = 300
  params.target_max_length = 300
  params.pad_to_max_seq_length = True
  # (side, field) pairs whose padded shape is validated, in check order.
  padded_fields = (
      ('src', 'ids'),
      ('tgt', 'ids'),
      ('tgt', 'segment_ids'),
      ('tgt', 'segment_pos'),
  )
  with self.session(use_gpu=False) as sess:
    generator = input_generator.MlPerfInput(params)
    # One batch is enough to verify the padded shapes.
    fetched = py_utils.NestedMap(
        sess.run(generator.GetPreprocessedInputBatch()))
    for side, field in padded_fields:
      # Look the tensor up lazily so fields are touched in the same order.
      tensor = getattr(getattr(fetched, side), field)
      self.checkPadShape(
          tensor, pad=0, batch_size=4, actual_max=240, pad_length=300)
def testMlPerfPadToMax(self):
  """Bucketed MLPerf batches are padded to max length with pad values.

  A single bucket (upper bound 20, batch limit 4) is configured, and
  pad_to_max_seq_length stretches every checked tensor to length 30. Over
  ten fetched batches, every position past the bucket bound must equal the
  pad value: 0 for ids/labels/weights, 1 for paddings.
  """
  batch_size = 4
  bucket_bound = 20
  max_len = 30

  params = self._CreateMlPerfInputParams()
  params.bucket_upper_bound = [bucket_bound]
  params.bucket_batch_limit = [batch_size]
  params.source_max_length = max_len
  params.target_max_length = max_len
  params.pad_to_max_seq_length = True

  def _AssertPaddedTail(x, pad):
    # The whole tensor is stretched to (batch, max_len).
    self.assertEqual(x.shape, (batch_size, max_len))
    # Everything beyond the bucket bound must be filler.
    self.assertAllEqual(x[:, bucket_bound:],
                        np.full((batch_size, max_len - bucket_bound), pad))

  # (side, field, expected pad value), checked in this order.
  padded_fields = (
      ('src', 'ids', 0),
      ('src', 'paddings', 1),
      ('tgt', 'ids', 0),
      ('tgt', 'labels', 0),
      ('tgt', 'weights', 0),
      ('tgt', 'paddings', 1),
  )

  with self.session(use_gpu=False) as sess:
    generator = input_generator.MlPerfInput(params)
    # Check several consecutive batches, not just the first.
    for _ in range(10):
      fetched = py_utils.NestedMap(
          sess.run(generator.GetPreprocessedInputBatch()))
      for side, field, pad in padded_fields:
        _AssertPaddedTail(getattr(getattr(fetched, side), field), pad)
def testMlPerf(self):
  """Smoke test: pulls ten preprocessed MLPerf batches and logs each one."""
  params = self._CreateMlPerfInputParams()
  num_steps = 10
  with self.session(use_gpu=False) as sess:
    generator = input_generator.MlPerfInput(params)
    for _ in range(num_steps):
      raw_batch = sess.run(generator.GetPreprocessedInputBatch())
      tf.logging.info(py_utils.NestedMap(raw_batch))
def testMlPerfPackedInput(self):
  """Fetches one packed MLPerf batch and logs the shapes of its tensors."""
  params = self._CreateMlPerfPackedInputParams()
  # (side, field) pairs whose shapes are logged, in logging order.
  logged_fields = (
      ('src', 'ids'),
      ('src', 'segment_ids'),
      ('src', 'segment_pos'),
      ('tgt', 'segment_ids'),
      ('tgt', 'segment_pos'),
  )
  with self.session(use_gpu=False) as sess:
    generator = input_generator.MlPerfInput(params)
    fetched = py_utils.NestedMap(
        sess.run(generator.GetPreprocessedInputBatch()))
    for side, field in logged_fields:
      tf.logging.info(getattr(getattr(fetched, side), field).shape)