def __init__(self, output_dir=None, default_flags=None, tpu=None):
  """Stores flag defaults and delegates construction to the parent class.

  Args:
    output_dir: Forwarded unchanged to the parent constructor.
    default_flags: Optional flag defaults. The instance attribute
      `self.default_flags` falls back to an empty dict when this is falsy;
      the parent constructor receives the argument exactly as given.
    tpu: Forwarded unchanged to the parent constructor.
  """
  # Falsy inputs (None, an empty container) become a fresh empty dict.
  self.default_flags = default_flags if default_flags else {}
  # NOTE(review): trainer.define_flags() is *invoked* here and its return
  # value is what is forwarded as `flag_methods`; confirm this is intended
  # rather than passing the function itself as a flag method.
  parent_kwargs = {
      "output_dir": output_dir,
      "default_flags": default_flags,
      "flag_methods": trainer.define_flags(),
      "tpu": tpu,
  }
  super(NHNetBenchmark, self).__init__(**parent_kwargs)
import os from absl import flags from absl.testing import parameterized import tensorflow as tf # pylint: disable=g-direct-tensorflow-import from tensorflow.python.distribute import combinations from tensorflow.python.distribute import strategy_combinations # pylint: enable=g-direct-tensorflow-import from official.nlp.nhnet import trainer from official.nlp.nhnet import utils FLAGS = flags.FLAGS trainer.define_flags() def all_strategy_combinations(): return combinations.combine( distribution=[ strategy_combinations.one_device_strategy, strategy_combinations.one_device_strategy_gpu, strategy_combinations.mirrored_strategy_with_gpu_and_cpu, strategy_combinations.cloud_tpu_strategy, ], mode="eager", ) def get_trivial_data(config) -> tf.data.Dataset: