コード例 #1
0
def ValidateFrameworkAndMachineTypeGa(framework, machine_type):
    """Rejects non-TensorFlow frameworks paired with non-'ml' machine types.

    Args:
      framework: a GoogleCloudMlV1Version.FrameworkValueValuesEnum member.
      machine_type: machine type name; it is only inspected when framework is
        not TENSORFLOW, in which case it must start with 'ml'.

    Raises:
      InvalidArgumentCombinationError: if framework is not TENSORFLOW and
        machine_type does not start with 'ml'.
    """
    version_cls = versions_api.GetMessagesModule().GoogleCloudMlV1Version
    tensorflow = version_cls.FrameworkValueValuesEnum.TENSORFLOW
    if framework != tensorflow:
        # Only TensorFlow may run on machine types outside the 'ml*' family.
        if not machine_type.startswith('ml'):
            message = ('Machine type {0} is currently only supported with '
                       'tensorflow.').format(machine_type)
            raise InvalidArgumentCombinationError(message)
コード例 #2
0
  def testValidateFrameworkAndMachineTypeGa(self):
    """Non-TensorFlow frameworks are rejected only on non-'ml' machines."""
    version_cls = versions_api.GetMessagesModule().GoogleCloudMlV1Version
    frameworks = version_cls.FrameworkValueValuesEnum
    validate = versions_util.ValidateFrameworkAndMachineTypeGa

    # Accepted: XGBoost on an 'ml*' machine, TensorFlow on any machine.
    validate(frameworks.XGBOOST, 'mls1-c1-m2')
    validate(frameworks.TENSORFLOW, 'n1-standard-4')

    # Rejected: a non-TensorFlow framework on a non-'ml*' machine.
    expected_error = versions_util.InvalidArgumentCombinationError
    with self.assertRaises(expected_error):
      validate(frameworks.XGBOOST, 'n1-standard-4')
コード例 #3
0
ファイル: flags.py プロジェクト: novousernx/google-cloud-sdk
def ParseAcceleratorFlag(accelerator):
    """Validates and returns an accelerator config message object.

    Args:
      accelerator: dict-like parsed flag value, read via ``.get('type')`` and
        ``.get('count')``; None when the flag was not supplied.

    Returns:
      A ``GoogleCloudMlV1AcceleratorConfig`` message built from the type and
      count, or None if ``accelerator`` is None.

    Raises:
      ArgumentError: if the type is not one of
        ``_ACCELERATOR_TYPE_MAPPER.choices``, or if the count (default 1) is
        not greater than 0.
    """
    if accelerator is None:
        return None
    # Build the list of valid choices only when there is something to check.
    types = list(_ACCELERATOR_TYPE_MAPPER.choices)
    raw_type = accelerator.get('type', None)
    if raw_type not in types:
        raise ArgumentError("""\
The type of the accelerator can only be one of the following: {}.
""".format(', '.join(["'{}'".format(c) for c in types])))
    accelerator_count = accelerator.get('count', 1)
    if accelerator_count <= 0:
        raise ArgumentError("""\
The count of the accelerator must be greater than 0.
""")
    accelerator_msg = (
        versions_api.GetMessagesModule().GoogleCloudMlV1AcceleratorConfig)
    # Convert the user-facing choice string into the API enum value.
    accelerator_type = arg_utils.ChoiceToEnum(
        raw_type, accelerator_msg.TypeValueValuesEnum)
    return accelerator_msg(count=accelerator_count, type=accelerator_type)
コード例 #4
0
ファイル: flags.py プロジェクト: oarcia/cherrybit.io
def ParseAcceleratorFlag(accelerator):
    """Validates and returns an accelerator config message object.

    Args:
      accelerator: dict-like parsed flag value, read via ``.get('type')`` and
        ``.get('count')``; None when the flag was not supplied.

    Returns:
      A ``GoogleCloudMlV1AcceleratorConfig`` message built from the type and
      count, or None if ``accelerator`` is None.

    Raises:
      ArgumentError: if the type is not one of the supported accelerator
        types, or if the count is missing or not greater than 0.
    """
    if accelerator is None:
        return None
    types = ('nvidia-tesla-k80', 'nvidia-tesla-p100', 'nvidia-tesla-v100',
             'nvidia-tesla-p4')
    raw_type = accelerator.get('type', None)
    if raw_type not in types:
        # Derive the listed choices from `types` so the message cannot drift
        # from the accepted values. Renders as "'a', 'b', 'c' and 'd'".
        quoted = ["'{}'".format(t) for t in types]
        raise ArgumentError("""\
The type of the accelerator can only be one of the following: {} and {}.
""".format(', '.join(quoted[:-1]), quoted[-1]))
    # A missing count falls back to 0 and is rejected below, so callers must
    # supply the count explicitly.
    accelerator_count = accelerator.get('count', 0)
    if accelerator_count <= 0:
        raise ArgumentError("""\
The count of the accelerator must be greater than 0.
""")
    accelerator_msg = (
        versions_api.GetMessagesModule().GoogleCloudMlV1AcceleratorConfig)
    # Convert the user-facing choice string into the API enum value.
    accelerator_type = arg_utils.ChoiceToEnum(
        raw_type, accelerator_msg.TypeValueValuesEnum)
    return accelerator_msg(count=accelerator_count, type=accelerator_type)
コード例 #5
0
ファイル: flags.py プロジェクト: barber223/AudioApp
    '--allow-multiline-logs',
    action='store_true',
    help='Output multiline log messages as single records.')
# Optional --task-name flag; when given, only that task's logs are shown.
TASK_NAME = base.Argument(
    '--task-name',
    help='If set, display only the logs for this particular task.',
    default=None,
    required=False)

# custom_mappings for the --framework mapper: enum member name -> choice string.
_FRAMEWORK_CHOICES = dict(
    TENSORFLOW='tensorflow',
    SCIKIT_LEARN='scikit-learn',
    XGBOOST='xgboost')
FRAMEWORK_MAPPER = arg_utils.ChoiceEnumMapper(
    '--framework',
    (versions_api.GetMessagesModule()
     .GoogleCloudMlV1Version
     .FrameworkValueValuesEnum),
    help_str=('The ML framework used to train this version of the model. '
              'If not specified, defaults to `tensorflow`'),
    custom_mappings=_FRAMEWORK_CHOICES)


def AddPythonVersionFlag(parser, context):
    """Registers the --python-version flag on ``parser``.

    Args:
      parser: the parser the flag is attached to via ``AddToParser``.
      context: phrase inserted into the help text right after "used".
    """
    template = ('The version of Python used {context}. If not set, the default '
                'version is 2.7. Python 3.5 is available when `runtime_version` is '
                'set to 1.4 and above. Python 2.7 works with all supported runtime '
                'versions.')
    flag = base.Argument('--python-version',
                         help=template.format(context=context))
    flag.AddToParser(parser)

コード例 #6
0
    help='Output multiline log messages as single records.')
# Optional --task-name flag; when given, only that task's logs are shown.
TASK_NAME = base.Argument(
    '--task-name',
    help='If set, display only the logs for this particular task.',
    default=None,
    required=False)


# custom_mappings for the --framework mapper: enum member name -> choice string.
_FRAMEWORK_CHOICES = dict(
    TENSORFLOW='tensorflow',
    SCIKIT_LEARN='scikit-learn',
    XGBOOST='xgboost')
FRAMEWORK_MAPPER = arg_utils.ChoiceEnumMapper(
    '--framework',
    (versions_api.GetMessagesModule()
     .GoogleCloudMlV1Version
     .FrameworkValueValuesEnum),
    help_str=("ML framework used to train this version of the model. "
              "If not specified, defaults to 'tensorflow'"),
    custom_mappings=_FRAMEWORK_CHOICES)


def AddKmsKeyFlag(parser, resource):
  """Adds a KMS key resource argument for ``resource`` to ``parser``.

  The help text notes which service account needs which KMS permission.
  """
  account = "The 'AI Platform Service Agent' service account"
  permission = "'Cloud KMS CryptoKey Encrypter/Decrypter'"
  info = '{} must hold permission {}'.format(account, permission)
  kms_resource_args.AddKmsKeyResourceArg(parser, resource,
                                         permission_info=info)


def AddPythonVersionFlag(parser, context):
  help_str = """\