def ValidateFrameworkAndMachineTypeGa(framework, machine_type):
  """Checks that a non-TensorFlow framework is only paired with 'ml' machines.

  TensorFlow is accepted with any machine type; every other framework is
  accepted only when the machine type name begins with 'ml'.

  Args:
    framework: a GoogleCloudMlV1Version.FrameworkValueValuesEnum member.
    machine_type: str, name of the requested machine type.

  Raises:
    InvalidArgumentCombinationError: when framework is not TENSORFLOW and
      machine_type does not start with 'ml'.
  """
  tensorflow = (versions_api.GetMessagesModule().GoogleCloudMlV1Version
                .FrameworkValueValuesEnum.TENSORFLOW)
  # TensorFlow short-circuits, so machine_type is only inspected otherwise.
  if framework == tensorflow or machine_type.startswith('ml'):
    return
  raise InvalidArgumentCombinationError(
      'Machine type {0} is currently only supported with tensorflow.'.format(
          machine_type))
def testValidateFrameworkAndMachineTypeGa(self):
  """Exercises accepted and rejected framework/machine-type pairings."""
  frameworks = (versions_api.GetMessagesModule().GoogleCloudMlV1Version
                .FrameworkValueValuesEnum)
  validate = versions_util.ValidateFrameworkAndMachineTypeGa
  # Accepted: XGBoost on an 'mls1' machine, TensorFlow on an 'n1' machine.
  accepted = ((frameworks.XGBOOST, 'mls1-c1-m2'),
              (frameworks.TENSORFLOW, 'n1-standard-4'))
  for framework, machine_type in accepted:
    validate(framework, machine_type)
  # Rejected: XGBoost on an 'n1' machine.
  with self.assertRaises(versions_util.InvalidArgumentCombinationError):
    validate(frameworks.XGBOOST, 'n1-standard-4')
def ParseAcceleratorFlag(accelerator):
  """Validates and returns an accelerator config message object.

  Args:
    accelerator: dict-like value read via .get('type') and .get('count'),
      or None when no accelerator was requested.

  Returns:
    A GoogleCloudMlV1AcceleratorConfig message built from the type and count,
    or None if accelerator is None.

  Raises:
    ArgumentError: if the type is not one of _ACCELERATOR_TYPE_MAPPER's
      choices, or if the count is not greater than 0.
  """
  # Bail out before touching the mapper when there is nothing to parse.
  if accelerator is None:
    return None
  types = list(_ACCELERATOR_TYPE_MAPPER.choices)
  raw_type = accelerator.get('type')
  if raw_type not in types:
    raise ArgumentError("""\
The type of the accelerator can only be one of the following: {}.
""".format(', '.join("'{}'".format(c) for c in types)))
  # An omitted count means a single accelerator.
  accelerator_count = accelerator.get('count', 1)
  if accelerator_count <= 0:
    raise ArgumentError("""\
The count of the accelerator must be greater than 0.
""")
  accelerator_msg = (
      versions_api.GetMessagesModule().GoogleCloudMlV1AcceleratorConfig)
  accelerator_type = arg_utils.ChoiceToEnum(
      raw_type, accelerator_msg.TypeValueValuesEnum)
  return accelerator_msg(count=accelerator_count, type=accelerator_type)
def ParseAcceleratorFlag(accelerator):
  """Validates and returns an accelerator config message object.

  Args:
    accelerator: dict-like value read via .get('type') and .get('count'),
      or None when no accelerator was requested.

  Returns:
    A GoogleCloudMlV1AcceleratorConfig message built from the type and count,
    or None if accelerator is None.

  Raises:
    ArgumentError: if the type is not one of the supported NVIDIA accelerator
      names, or if the count is not greater than 0.
  """
  if accelerator is None:
    return None
  types = ('nvidia-tesla-k80', 'nvidia-tesla-p100', 'nvidia-tesla-v100',
           'nvidia-tesla-p4')
  raw_type = accelerator.get('type')
  if raw_type not in types:
    # Build the list from `types` so the message cannot drift from the check.
    quoted = ["'{}'".format(t) for t in types]
    raise ArgumentError("""\
The type of the accelerator can only be one of the following: {} and {}.
""".format(', '.join(quoted[:-1]), quoted[-1]))
  # Default to a single accelerator. A default of 0 would make an omitted
  # count always fail the positivity check below with a misleading message.
  accelerator_count = accelerator.get('count', 1)
  if accelerator_count <= 0:
    raise ArgumentError("""\
The count of the accelerator must be greater than 0.
""")
  accelerator_msg = (
      versions_api.GetMessagesModule().GoogleCloudMlV1AcceleratorConfig)
  accelerator_type = arg_utils.ChoiceToEnum(
      raw_type, accelerator_msg.TypeValueValuesEnum)
  return accelerator_msg(count=accelerator_count, type=accelerator_type)
'--allow-multiline-logs', action='store_true', help='Output multiline log messages as single records.') TASK_NAME = base.Argument( '--task-name', required=False, default=None, help='If set, display only the logs for this particular task.') _FRAMEWORK_CHOICES = { 'TENSORFLOW': 'tensorflow', 'SCIKIT_LEARN': 'scikit-learn', 'XGBOOST': 'xgboost' } FRAMEWORK_MAPPER = arg_utils.ChoiceEnumMapper( '--framework', (versions_api.GetMessagesModule().GoogleCloudMlV1Version. FrameworkValueValuesEnum), custom_mappings=_FRAMEWORK_CHOICES, help_str=('The ML framework used to train this version of the model. ' 'If not specified, defaults to `tensorflow`')) def AddPythonVersionFlag(parser, context): help_str = ( 'The version of Python used {context}. If not set, the default ' 'version is 2.7. Python 3.5 is available when `runtime_version` is ' 'set to 1.4 and above. Python 2.7 works with all supported runtime ' 'versions.').format(context=context) version = base.Argument('--python-version', help=help_str) version.AddToParser(parser)
help='Output multiline log messages as single records.')

# Optional --task-name flag (default None); its help text says it limits the
# displayed logs to one task.
TASK_NAME = base.Argument(
    '--task-name',
    required=False,
    default=None,
    help='If set, display only the logs for this particular task.')

# custom_mappings for FRAMEWORK_MAPPER. Keys match FrameworkValueValuesEnum
# member names (e.g. TENSORFLOW); values are the choice strings shown to users.
# NOTE(review): confirm against arg_utils.ChoiceEnumMapper's custom_mappings
# contract.
_FRAMEWORK_CHOICES = {
    'TENSORFLOW': 'tensorflow',
    'SCIKIT_LEARN': 'scikit-learn',
    'XGBOOST': 'xgboost'
}

# Maps the --framework flag onto GoogleCloudMlV1Version.FrameworkValueValuesEnum.
FRAMEWORK_MAPPER = arg_utils.ChoiceEnumMapper(
    '--framework',
    (versions_api.GetMessagesModule().
     GoogleCloudMlV1Version.FrameworkValueValuesEnum),
    custom_mappings=_FRAMEWORK_CHOICES,
    help_str=('ML framework used to train this version of the model. '
              'If not specified, defaults to \'tensorflow\''))


def AddKmsKeyFlag(parser, resource):
  """Adds a KMS key resource argument to parser.

  The attached permission note states that the 'AI Platform Service Agent'
  service account must hold 'Cloud KMS CryptoKey Encrypter/Decrypter'.

  Args:
    parser: passed through unchanged to kms_resource_args.AddKmsKeyResourceArg.
    resource: passed through unchanged to kms_resource_args.AddKmsKeyResourceArg;
      presumably a description of the resource being protected (confirm with
      callers).
  """
  permission_info = '{} must hold permission {}'.format(
      "The 'AI Platform Service Agent' service account",
      "'Cloud KMS CryptoKey Encrypter/Decrypter'")
  kms_resource_args.AddKmsKeyResourceArg(
      parser, resource, permission_info=permission_info)


def AddPythonVersionFlag(parser, context):
  help_str = """\