def initialize_common_flags():
  """Registers the command-line flags shared by every model.

  Registration order:
    1. the common hyperparameter flags (`common_hparams_flags`),
    2. the common TPU flags (`common_tpu_flags`),
    3. `worker_hosts` (string, default None) and `task_index` (integer,
       default 0), used for multi-worker training with
       MultiWorkerMirroredStrategy.

  NOTE(review): absl generally rejects registering the same flag name twice,
  so this is presumably meant to be called once per process — confirm against
  callers.
  """
  # The two shared flag groups come from the hyperparameters package; the
  # hyperparameter group is registered before the TPU group.
  for define_shared_flags in (common_hparams_flags.define_common_hparams_flags,
                              common_tpu_flags.define_common_tpu_flags):
    define_shared_flags()

  # Flags for MultiWorkerMirroredStrategy. Every host is expected to receive
  # the same `worker_hosts` value; `task_index` identifies this worker.
  worker_hosts_help = (
      'Comma-separated list of worker ip:port pairs for running '
      'multi-worker models with distribution strategy. The user would '
      'start the program on each host with identical value for this flag.')
  flags.DEFINE_string(
      name='worker_hosts', default=None, help=worker_hosts_help)
  flags.DEFINE_integer(
      name='task_index',
      default=0,
      help='If multi-worker training, the task_index of this worker.')
from hyperparameters import common_hparams_flags from hyperparameters import common_tpu_flags from hyperparameters import flags_to_params from hyperparameters import params_dict import imagenet_input import mnasnet_models import utils from configs import mnasnet_config from tensorflow.contrib.tpu.python.tpu import async_checkpoint from tensorflow.contrib.training.python.training import evaluation from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python.estimator import estimator from tensorflow.python.keras import backend as K common_tpu_flags.define_common_tpu_flags() common_hparams_flags.define_common_hparams_flags() FLAGS = flags.FLAGS FAKE_DATA_DIR = 'gs://cloud-tpu-test-datasets/fake_imagenet' # Model specific flags flags.DEFINE_string( 'model_name', default=None, help=( 'The model name to select models among existing MnasNet configurations.' )) flags.DEFINE_enum(