def define_resnet_flags(resnet_size_choices=None):
    """Define the command-line flags used by the ResNet models.

    First installs the shared flag groups from ``flags_core`` (base,
    performance, image, benchmark) and adopts them as key flags of this
    module. It then defines the ResNet-specific flags, lets ``lottery``
    register its flags, and finally defines ``resnet_size``.

    Args:
      resnet_size_choices: Optional collection of allowed ``resnet_size``
        values. When None, ``resnet_size`` is a free-form string flag.
        Otherwise it is an enum flag restricted to these values.
    """
    # Shared flag groups supplied by flags_core.
    flags_core.define_base()
    flags_core.define_performance(
        num_parallel_calls=False,
        tf_gpu_thread_mode=True,
        datasets_num_private_threads=True,
        datasets_num_parallel_batches=True,
    )
    flags_core.define_image()
    flags_core.define_benchmark()
    flags.adopt_module_key_flags(flags_core)

    wrap = flags_core.help_wrap

    flags.DEFINE_enum(
        name='resnet_version',
        short_name='rv',
        default='1',
        enum_values=['1', '2'],
        help=wrap('Version of ResNet. (1 or 2) See README.md for details.'),
    )
    flags.DEFINE_bool(
        name='fine_tune',
        short_name='ft',
        default=False,
        help=wrap(
            'If True do not train any parameters except for the final layer.'),
    )
    flags.DEFINE_string(
        name='pretrained_model_checkpoint_path',
        short_name='pmcp',
        default=None,
        help=wrap(
            'If not None initialize all the network except the final layer with '
            'these values'),
    )

    # The remaining boolean switches all default to False and have no short
    # name. They are defined in the same order as listed here.
    boolean_switches = (
        ('eval_only',
         'Skip training and only perform evaluation on '
         'the latest checkpoint.'),
        ('image_bytes_as_serving_input',
         'If True exports savedmodel with serving signature that accepts '
         'JPEG image bytes instead of a fixed size [HxWxC] tensor that '
         'represents the image. The former is easier to use for serving at '
         'the expense of image resize/cropping being done as part of model '
         'inference. Note, this flag only applies to ImageNet and cannot '
         'be used for CIFAR.'),
        ('turn_off_distribution_strategy',
         'Set to True to not use distribution '
         'strategies.'),
    )
    for switch_name, switch_help in boolean_switches:
        flags.DEFINE_boolean(
            name=switch_name, default=False, help=wrap(switch_help))

    lottery.add_flags(flags)

    # resnet_size is a plain string unless the caller restricts the choices.
    size_kwargs = {
        'name': 'resnet_size',
        'short_name': 'rs',
        'default': '50',
        'help': wrap('The size of the ResNet model to use.'),
    }
    if resnet_size_choices is not None:
        flags.DEFINE_enum(enum_values=resnet_size_choices, **size_kwargs)
    else:
        flags.DEFINE_string(**size_kwargs)
from tensorflow.contrib import summary
from tensorflow.contrib.tpu.python.tpu import async_checkpoint
from tensorflow.contrib.training.python.training import evaluation
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.estimator import estimator
from lottery import lottery

# Register the shared TPU flags and the common hyperparameter flags. Both
# helper modules are imported outside this view.
common_tpu_flags.define_common_tpu_flags()
common_hparams_flags.define_common_hparams_flags()

# Module-level handle to the flag values.
FLAGS = flags.FLAGS

# GCS location of a fake ImageNet dataset.
# NOTE(review): presumably used when no real data directory is supplied
# (e.g. smoke tests). Confirm at the use site.
FAKE_DATA_DIR = 'gs://cloud-tpu-test-datasets/fake_imagenet'

# Let the lottery module add its own flags to the shared `flags` registry.
lottery.add_flags(flags)

# Optional source (default None) of model-parameter overrides.
# NOTE(review): "mparams" in the help text looks like a typo for "hparams".
# The expected file format is not visible here; confirm against the loader.
flags.DEFINE_string(
    'hparams_file', default=None,
    help=('Set of model parameters to override the default mparams.'))

# Multi-string flag holding model-hyperparameter overrides in name=value
# form, e.g. --hparams=train_steps=28152. Per its help text it must not be
# used for TPU-specific or other non-model flags.
# NOTE(review): presumably may be passed more than once. Confirm how the
# values are merged.
flags.DEFINE_multi_string(
    'hparams', default=None,
    help=('This is used to override only the model hyperparameters. It should '
          'not be used to override the other parameters like the tpu specific '
          'flags etc. For example, if experimenting with larger numbers of '
          'train_steps, a possible value is '
          '--hparams=train_steps=28152.'))