Example #1
# Imports needed by this snippet (module paths assumed to match Example #2).
from absl import flags

from hyperparameters import common_hparams_flags
from hyperparameters import common_tpu_flags


def initialize_common_flags():
    """Define the common flags across models."""
    common_hparams_flags.define_common_hparams_flags()
    common_tpu_flags.define_common_tpu_flags()
    # Parameters for MultiWorkerMirroredStrategy
    flags.DEFINE_string(
        'worker_hosts',
        default=None,
        help='Comma-separated list of worker ip:port pairs for running '
        'multi-worker models with distribution strategy.  The user would '
        'start the program on each host with identical value for this flag.')
    flags.DEFINE_integer(
        'task_index', 0,
        'If multi-worker training, the task_index of this worker.')
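
For reference, here is a minimal sketch of how initialize_common_flags() could be wired into an absl entry point; the main() body below is an assumption for illustration, not part of the original example. Flags must be defined before absl parses argv, so the call happens at import time.

from absl import app
from absl import flags

initialize_common_flags()  # define flags at import time, before app.run() parses argv
FLAGS = flags.FLAGS


def main(unused_argv):
    # Hypothetical usage: the multi-worker settings are available after parsing.
    print('worker_hosts:', FLAGS.worker_hosts)
    print('task_index:', FLAGS.task_index)


if __name__ == '__main__':
    app.run(main)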
Example #2
from absl import flags  # required for flags.DEFINE_* and flags.FLAGS below

from hyperparameters import common_hparams_flags
from hyperparameters import common_tpu_flags
from hyperparameters import flags_to_params
from hyperparameters import params_dict
import imagenet_input
import mnasnet_models
import utils
from configs import mnasnet_config
from tensorflow.contrib.tpu.python.tpu import async_checkpoint
from tensorflow.contrib.training.python.training import evaluation
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.estimator import estimator
from tensorflow.python.keras import backend as K

common_tpu_flags.define_common_tpu_flags()
common_hparams_flags.define_common_hparams_flags()

FLAGS = flags.FLAGS

FAKE_DATA_DIR = 'gs://cloud-tpu-test-datasets/fake_imagenet'

# Model specific flags
flags.DEFINE_string(
    'model_name',
    default=None,
    help=(
        'The model name to select models among existing MnasNet configurations.'
    ))

flags.DEFINE_enum(