GetMasterImageUri().AddToParser(parser) GetParameterServerMachineTypeConfig().AddToParser(parser) GetParameterServerAccelerator().AddToParser(parser) GetParameterServerImageUri().AddToParser(parser) GetWorkerMachineConfig().AddToParser(parser) GetWorkerAccelerator().AddToParser(parser) GetWorkerImageUri().AddToParser(parser) GetUseChiefInTfConfig().AddToParser(parser) if support_tpu_tf_version: GetTpuTfVersion().AddToParser(parser) # Custom Container Flags _ACCELERATOR_TYPE_MAPPER = arg_utils.ChoiceEnumMapper( 'generic-accelerator', jobs.GetMessagesModule( ).GoogleCloudMlV1AcceleratorConfig.TypeValueValuesEnum, help_str='Available types of accelerators.', include_filter=lambda x: x != 'ACCELERATOR_TYPE_UNSPECIFIED', required=False) _OP_ACCELERATOR_TYPE_MAPPER = arg_utils.ChoiceEnumMapper( 'generic-accelerator', jobs.GetMessagesModule( ).GoogleCloudMlV1AcceleratorConfig.TypeValueValuesEnum, help_str='Available types of accelerators.', include_filter=lambda x: x.startswith('NVIDIA'), required=False) _OP_AUTOSCALING_METRIC_NAME_MAPPER = arg_utils.ChoiceEnumMapper( 'autoscaling-metric-name', versions_api.GetMessagesModule(
_LOGS_URL = ('https://console.cloud.google.com/logs?' 'resource=ml.googleapis.com%2Fjob_id%2F{job_id}' '&project={project}') JOB_FORMAT = 'yaml(jobId,state,startTime.date(tz=LOCAL),endTime.date(tz=LOCAL))' # Check every 10 seconds if the job is complete (if we didn't fetch any logs the # last time) _CONTINUE_INTERVAL = 10 _TEXT_FILE_URL = ('https://www.tensorflow.org/guide/datasets' '#consuming_text_data') _TF_RECORD_URL = ('https://www.tensorflow.org/guide/datasets' '#consuming_tfrecord_data') _PREDICTION_DATA_FORMAT_MAPPER = arg_utils.ChoiceEnumMapper( '--data-format', jobs.GetMessagesModule( ).GoogleCloudMlV1PredictionInput.DataFormatValueValuesEnum, custom_mappings={ 'TEXT': ('text', ('Text files; see {}'.format(_TEXT_FILE_URL))), 'TF_RECORD': ('tf-record', 'TFRecord files; see {}'.format(_TF_RECORD_URL)), 'TF_RECORD_GZIP': ('tf-record-gzip', 'GZIP-compressed TFRecord files.') }, help_str='Data format of the input files.', required=True) _ACCELERATOR_MAP = arg_utils.ChoiceEnumMapper( '--accelerator-type', jobs.GetMessagesModule( ).GoogleCloudMlV1AcceleratorConfig.TypeValueValuesEnum, custom_mappings={ 'NVIDIA_TESLA_K80': ('nvidia-tesla-k80', 'NVIDIA Tesla K80 GPU'),
# Cloud Console URL template for viewing a job's logs; fill in with
# .format(job_id=..., project=...).
_LOGS_URL = ('https://console.cloud.google.com/logs?'
             'resource=ml.googleapis.com%2Fjob_id%2F{job_id}'
             '&project={project}')
# Output format string: YAML showing jobId, state, and the start/end times
# rendered with date(tz=LOCAL).
JOB_FORMAT = 'yaml(jobId,state,startTime.date(tz=LOCAL),endTime.date(tz=LOCAL))'
# Check every 10 seconds if the job is complete (if we didn't fetch any logs the
# last time).
_CONTINUE_INTERVAL = 10

# Documentation link for the TFRecord input format, shown in the --data-format
# help text. Points at the TensorFlow datasets guide instead of the archived
# r0.12 "how_tos/reading_data" page, which no longer documents current usage.
_TF_RECORD_URL = ('https://www.tensorflow.org/guide/datasets'
                  '#consuming_tfrecord_data')

# Maps the --data-format flag's choices (text, tf-record, tf-record-gzip) onto
# the PredictionInput DataFormat enum values; the flag is required.
_PREDICTION_DATA_FORMAT_MAPPER = arg_utils.ChoiceEnumMapper(
    '--data-format',
    jobs.GetMessagesModule(
    ).GoogleCloudMlV1PredictionInput.DataFormatValueValuesEnum,
    custom_mappings={
        'TEXT': ('text',
                 ('Text files with instances separated '
                  'by the new-line character.')),
        'TF_RECORD': ('tf-record',
                      'TFRecord files; see {}'.format(_TF_RECORD_URL)),
        'TF_RECORD_GZIP': ('tf-record-gzip',
                           'GZIP-compressed TFRecord files.')
    },
    help_str='Data format of the input files.',
    required=True)


def DataFormatFlagMap():
  """Return the ChoiceEnumMapper for the --data-format flag.

  Returns:
    The module-level mapper shared by all callers (not a fresh copy), which
    translates between --data-format choices and PredictionInput DataFormat
    enum values.
  """
  return _PREDICTION_DATA_FORMAT_MAPPER