コード例 #1
0
 def test_predict_fp16(self):
     """Runs TransformerTask.predict() after preparing `--dtype=fp16`.

     Skipped when two or more GPUs are visible, because this test does not
     configure a distribution strategy.
     """
     gpu_count = context.num_gpus()
     if gpu_count >= 2:
         self.skipTest(
             'No need to test 2+ GPUs without a distribution strategy.')
     self._prepare_files_and_flags('--dtype=fp16')
     predictor = transformer_main.TransformerTask(FLAGS)
     predictor.predict()
コード例 #2
0
 def test_eval(self):
     """Runs TransformerTask.eval() after preparing files and flags.

     Skipped when two or more GPUs are visible (no distribution strategy is
     configured here), and when sys.argv[0] contains 'test_xla' (the test is
     marked TODO to be made faster under XLA).
     """
     if context.num_gpus() >= 2:
         self.skipTest(
             'No need to test 2+ GPUs without a distribution strategy.')
     elif 'test_xla' in sys.argv[0]:
         # skipTest raises, so `elif` is equivalent to a second plain `if`.
         self.skipTest('TODO(xla): Make this test faster under XLA.')
     self._prepare_files_and_flags()
     evaluator = transformer_main.TransformerTask(FLAGS)
     evaluator.eval()
コード例 #3
0
 def test_train_2_gpu(self):
     """Trains the 'base' param set with the 'mirrored' strategy on 2 GPUs.

     Skipped unless at least two GPUs are visible.
     """
     available = context.num_gpus()
     if available < 2:
         self.skipTest(
             '{} GPUs are not available for this test. {} GPUs are available'
             .format(2, available))
     FLAGS.distribution_strategy = 'mirrored'
     FLAGS.num_gpus = 2
     FLAGS.param_set = 'base'
     transformer_main.TransformerTask(FLAGS).train()
コード例 #4
0
 def test_train_static_batch(self):
     """Trains with static_batch enabled under the 'one_device' strategy.

     FLAGS.num_gpus is 1 when TensorFlow was built with CUDA and 0 otherwise.
     Skipped when two or more GPUs are visible.
     """
     if context.num_gpus() >= 2:
         self.skipTest(
             'No need to test 2+ GPUs without a distribution strategy.')
     FLAGS.distribution_strategy = 'one_device'
     FLAGS.num_gpus = 1 if tf.test.is_built_with_cuda() else 0
     FLAGS.static_batch = True
     trainer = transformer_main.TransformerTask(FLAGS)
     trainer.train()
コード例 #5
0
 def test_train_no_dist_strat(self):
     """Trains using the flag values as already configured, with no overrides.

     Per the test name, the pre-set flags presumably select no distribution
     strategy (verify against setUp). Skipped when two or more GPUs are
     visible.
     """
     if context.num_gpus() >= 2:
         self.skipTest('No need to test 2+ GPUs without a distribution '
                       'strategy.')
     transformer_main.TransformerTask(FLAGS).train()
コード例 #6
0
 def test_train_fp16(self):
     """Trains with dtype 'fp16' under the 'one_device' distribution strategy.

     NOTE(review): this body has no GPU-availability guard; if the
     one_device/fp16 path requires a GPU, confirm a skip decorator exists.
     """
     FLAGS.distribution_strategy = 'one_device'
     FLAGS.dtype = 'fp16'
     fp16_task = transformer_main.TransformerTask(FLAGS)
     fp16_task.train()
コード例 #7
0
 def test_train_1_gpu_with_dist_strat(self):
     """Trains under the 'one_device' distribution strategy on a single GPU.

     Skipped when no GPU is visible. This uses the same context.num_gpus()
     availability check as test_train_2_gpu. Without the guard, the test ran
     regardless of hardware despite its name promising a GPU.
     """
     if context.num_gpus() < 1:
         self.skipTest('This test requires a GPU, but none is available.')
     FLAGS.distribution_strategy = 'one_device'
     t = transformer_main.TransformerTask(FLAGS)
     t.train()