def ParseX509Parameters(args, is_ca_command):
  """Parses the X509 parameters flags into an API X509Parameters.

  Args:
    args: The parsed argument values.
    is_ca_command: Whether the current command is on a CA. If so, certSign and
      crlSign key usages are added.

  Returns:
    An X509Parameters object. When --use-preset-profile is given, the preset's
    parameters are returned; otherwise the parameters are built from the inline
    flags.

  Raises:
    InvalidArgumentException: If --use-preset-profile is combined with any of
      the inline X509 flags.
  """
  preset_profile_set = args.IsKnownAndSpecified('use_preset_profile')
  # TODO(b/183243757): Change to args.IsSpecified once --use-preset-profile flag
  # is registered.
  has_inline_values = any(
      args.IsKnownAndSpecified(flag) for flag in
      ['key_usages', 'extended_key_usages', 'max_chain_length', 'is_ca_cert'])

  if preset_profile_set and has_inline_values:
    raise exceptions.InvalidArgumentException(
        '--use-preset-profile',
        '--use-preset-profile may not be specified if one or more of '
        '--key-usages, --extended-key-usages or --max-chain-length are '
        'specified.')
  if preset_profile_set:
    return preset_profiles.GetPresetX509Parameters(args.use_preset_profile)

  # Copy the user-supplied list: extending args.key_usages directly would
  # mutate the parsed arguments in place and leak the CA usages to any later
  # reader of args.key_usages.
  base_key_usages = list(args.key_usages or [])
  is_ca = is_ca_command or (args.IsKnownAndSpecified('is_ca_cert') and
                            args.is_ca_cert)
  if is_ca:
    # A CA should have these KeyUsages to be RFC 5280 compliant.
    base_key_usages.extend(['cert_sign', 'crl_sign'])

  # Flag values are snake_case; the API message fields are camelCase.
  key_usage_dict = {
      text_utils.SnakeCaseToCamelCase(key_usage): True
      for key_usage in base_key_usages
  }
  extended_key_usage_dict = {
      text_utils.SnakeCaseToCamelCase(extended_key_usage): True
      for extended_key_usage in args.extended_key_usages or []
  }

  messages = privateca_base.GetMessagesModule('v1')
  return messages.X509Parameters(
      keyUsage=messages.KeyUsage(
          baseKeyUsage=messages_util.DictToMessageWithErrorCheck(
              key_usage_dict, messages.KeyUsageOptions),
          extendedKeyUsage=messages_util.DictToMessageWithErrorCheck(
              extended_key_usage_dict, messages.ExtendedKeyUsageOptions)),
      caOptions=messages.CaOptions(
          isCa=is_ca,
          # Don't include maxIssuerPathLength if it's None.
          maxIssuerPathLength=int(args.max_chain_length)
          if is_ca and args.max_chain_length is not None else None))
def ParseReusableConfig(args, required=False):
  """Builds an API ReusableConfigWrapper from the reusable-config flags.

  Callers may reference an existing config with --reusable-config or supply
  inline values via --key-usages, --extended-key-usages and
  --max-chain-length, but not both.

  Args:
    args: The parsed argument values.
    required: Whether a reusable config is required.

  Returns:
    A ReusableConfigWrapper object. It is empty when nothing was specified and
    `required` is False.

  Raises:
    InvalidArgumentException: If both a resource and inline values are given,
      or if nothing is given while `required` is True.
  """
  config_ref = args.CONCEPTS.reusable_config.Parse()
  inline_flags = ('key_usages', 'extended_key_usages', 'max_chain_length')
  uses_inline_values = any(args.IsSpecified(flag) for flag in inline_flags)
  messages = privateca_base.GetMessagesModule()

  if config_ref and uses_inline_values:
    raise exceptions.InvalidArgumentException(
        '--reusable-config',
        '--reusable-config may not be specified if one or more of '
        '--key-usages, --extended-key-usages or --max-chain-length are '
        'specified.')

  if config_ref:
    return messages.ReusableConfigWrapper(
        reusableConfig=config_ref.RelativeName())

  if not uses_inline_values:
    if required:
      raise exceptions.InvalidArgumentException(
          '--reusable-config',
          'Either --reusable-config or one or more of --key-usages, '
          '--extended-key-usages and --max-chain-length must be specified.')
    return messages.ReusableConfigWrapper()

  # Flag values are snake_case; the API message fields are camelCase.
  base_usages = {
      text_utils.SnakeCaseToCamelCase(usage): True
      for usage in args.key_usages or []
  }
  extended_usages = {
      text_utils.SnakeCaseToCamelCase(usage): True
      for usage in args.extended_key_usages or []
  }

  if args.IsSpecified('max_chain_length'):
    path_length = int(args.max_chain_length)
  else:
    path_length = None

  key_usage = messages.KeyUsage(
      baseKeyUsage=messages_util.DictToMessageWithErrorCheck(
          base_usages, messages.KeyUsageOptions),
      extendedKeyUsage=messages_util.DictToMessageWithErrorCheck(
          extended_usages, messages.ExtendedKeyUsageOptions))
  ca_options = messages.CaOptions(maxIssuerPathLength=path_length)
  config_values = messages.ReusableConfigValues(
      keyUsage=key_usage, caOptions=ca_options)
  return messages.ReusableConfigWrapper(reusableConfigValues=config_values)
def ParseReusableConfig(args):
  """Parses the reusable config flags into an API ReusableConfigWrapper.

  Args:
    args: The parsed argument values.

  Returns:
    A ReusableConfigWrapper object. It either references the --reusable-config
    resource or carries inline values built from --key-usages,
    --extended-key-usages, --max-chain-length and (when that flag exists on
    the command) --is-ca-cert.

  Raises:
    InvalidArgumentException: If --reusable-config is combined with inline
      value flags.
  """
  resource = args.CONCEPTS.reusable_config.Parse()

  # If key_usages or extended_usages or max_chain_length or is_ca_cert are
  # provided OR nothing was provided, use inline values (with defaults).
  has_inline = args.IsSpecified('key_usages') or args.IsSpecified(
      'extended_key_usages') or args.IsSpecified('max_chain_length') or (
          'is_ca_cert' in vars(args) and args.IsSpecified('is_ca_cert'))

  messages = privateca_base.GetMessagesModule()
  if resource and has_inline:
    raise exceptions.InvalidArgumentException(
        '--reusable-config',
        '--reusable-config may not be specified if one or more of '
        '--key-usages, --extended-key-usages or --max-chain-length are '
        'specified.')

  if resource:
    return messages.ReusableConfigWrapper(
        reusableConfig=resource.RelativeName())

  # Flag values are snake_case; the API message fields are camelCase.
  key_usage_dict = {
      text_utils.SnakeCaseToCamelCase(key_usage): True
      for key_usage in args.key_usages or []
  }
  extended_key_usage_dict = {
      text_utils.SnakeCaseToCamelCase(extended_key_usage): True
      for extended_key_usage in args.extended_key_usages or []
  }

  if 'is_ca_cert' in vars(args):
    is_ca_val = args.is_ca_cert
  else:
    # For Reusable Configs in CA commands, the command is always creating a
    # CA certificate.
    is_ca_val = True

  # --max-chain-length is optional. Only convert it when it has a value;
  # int(None) would otherwise raise a TypeError.
  max_issuer_path_length = None
  if is_ca_val and args.max_chain_length is not None:
    max_issuer_path_length = int(args.max_chain_length)

  return messages.ReusableConfigWrapper(
      reusableConfigValues=messages.ReusableConfigValues(
          keyUsage=messages.KeyUsage(
              baseKeyUsage=messages_util.DictToMessageWithErrorCheck(
                  key_usage_dict, messages.KeyUsageOptions),
              extendedKeyUsage=messages_util.DictToMessageWithErrorCheck(
                  extended_key_usage_dict, messages.ExtendedKeyUsageOptions)),
          caOptions=messages.CaOptions(
              isCa=is_ca_val, maxIssuerPathLength=max_issuer_path_length)))
def Run(self, args):
  """Creates a platform policy from the policy file given on the command line.

  Args:
    args: argparse.Namespace with command-line arguments.

  Returns:
    The policy resource.
  """
  resource_name = args.CONCEPTS.policy_resource_name.Parse()

  # six.ensure_str avoids a 'u' prefix in Python 2 when this file path gets
  # embedded in error messages.
  policy_path = six.ensure_str(args.policy_file)
  policy_dict = parsing.LoadResourceFile(policy_path)

  # The API is only available in v1. DecodeErrors are intentionally left to
  # propagate to the user.
  policy_type = apis.GetMessagesModule('v1').PlatformPolicy
  policy = messages_util.DictToMessageWithErrorCheck(policy_dict, policy_type)

  client = platform_policy.Client('v1')
  return client.Create(resource_name, policy)
def ParseSubject(args):
  """Converts the --subject attribute dict into an API Subject message.

  The short attribute keys CN, C, ST, L, O and OU are renamed to their Subject
  field names. Any other key is passed through unchanged.

  Args:
    args: The argparse namespace that contains the flag values.

  Returns:
    Subject: the Subject type represented in the api.

  Raises:
    InvalidArgumentException: If an attribute does not match a Subject field.
  """
  short_to_field = {
      'CN': 'commonName',
      'C': 'countryCode',
      'ST': 'province',
      'L': 'locality',
      'O': 'organization',
      'OU': 'organizationalUnit'
  }
  subject_fields = {
      short_to_field.get(attribute, attribute): value
      for attribute, value in args.subject.items()
  }

  try:
    subject_type = privateca_base.GetMessagesModule('v1').Subject
    return messages_util.DictToMessageWithErrorCheck(subject_fields,
                                                     subject_type)
  except messages_util.DecodeError:
    raise exceptions.InvalidArgumentException(
        '--subject', 'Unrecognized subject attribute.')
def testTypeMismatch_Scalar(self):
  """A scalar of the wrong type for a field raises ScalarTypeMismatchError."""
  # assertRaisesRegexp is a deprecated alias that was removed in Python 3.12;
  # use assertRaisesRegex like the other tests in this suite.
  with self.assertRaisesRegex(
      messages_util.ScalarTypeMismatchError,
      r'Expected type <(type|class).* for field updateTime, found 1'):
    messages_util.DictToMessageWithErrorCheck({'updateTime': 1},
                                              self.messages.Policy)
def Run(self, args):
  """Create or Update service from YAML.

  Decodes the service definition in args.FILE into a Service and reconciles
  its namespace with the --namespace flag. It then releases the service as a
  replacement, either synchronously or asynchronously depending on
  args.async_.

  Args:
    args: The parsed argument values.

  Raises:
    exceptions.ConfigurationError: If the namespace in the file conflicts with
      the --namespace flag, or if the project check below fails for a managed
      platform.
  """
  conn_context = connection_context.GetConnectionContext(
      args, self.ReleaseTrack())

  with serverless_operations.Connect(conn_context) as client:
    # args.FILE is passed straight to the decoder; presumably the flag type
    # already parsed the YAML into a dict (TODO confirm against flag setup).
    new_service = service.Service(
        messages_util.DictToMessageWithErrorCheck(
            args.FILE, client.messages_module.Service),
        client.messages_module)

    # If managed, namespace must match project (or will default to project if
    # not specified).
    # If not managed, namespace simply must not conflict if specified in
    # multiple places (or will default to "default" if not specified).
    namespace = args.CONCEPTS.namespace.Parse().Name()  # From flag or default
    if new_service.metadata.namespace is not None:
      # The file's namespace wins, but only if it does not contradict an
      # explicitly passed --namespace flag.
      if (args.IsSpecified('namespace') and
          namespace != new_service.metadata.namespace):
        raise exceptions.ConfigurationError(
            'Namespace specified in file does not match passed flag.')
      namespace = new_service.metadata.namespace
      project = properties.VALUES.core.project.Get()
      if flags.IsManaged(args) and namespace != project:
        raise exceptions.ConfigurationError(
            'Namespace must be [{}] for Cloud Run (fully managed).'.format(
                project))
    new_service.metadata.namespace = namespace

    # The whole service definition is applied as a single replace change.
    changes = [config_changes.ReplaceServiceChange(new_service)]
    service_ref = resources.REGISTRY.Parse(
        new_service.metadata.name,
        params={'namespacesId': namespace},
        collection='run.namespaces.services')
    # Used only to choose the progress header below.
    original_service = client.GetService(service_ref)

    pretty_print.Info(deploy.GetStartDeployMessage(conn_context, service_ref))

    deployment_stages = stages.ServiceStages()
    header = (
        'Deploying...' if original_service else 'Deploying new service...')
    with progress_tracker.StagedProgressTracker(
        header,
        deployment_stages,
        failure_message='Deployment failed',
        suppress_output=args.async_) as tracker:
      client.ReleaseService(
          service_ref,
          changes,
          tracker,
          asyn=args.async_,
          allow_unauthenticated=None,
          for_replace=True)
    if args.async_:
      pretty_print.Success(
          'Service [{{bold}}{serv}{{reset}}] is deploying '
          'asynchronously.'.format(serv=service_ref.servicesId))
    else:
      pretty_print.Success(deploy.GetSuccessMessageForSynchronousDeploy(
          client, service_ref))
def ParseSubject(subject_args):
  """Splits subject attributes into a common name and an API Subject message.

  The CN entry is pulled out as the common name. The remaining short
  attribute keys C, ST, L, O and OU are renamed to their Subject field names,
  and any other key is passed through unchanged.

  Args:
    subject_args: A string->string dict with subject attributes and values.

  Returns:
    A tuple with (common_name, Subject) where common name is a string and
    Subject is the Subject type represented in the api.

  Raises:
    KeyError: If subject_args has no 'CN' entry.
    InvalidArgumentException: If an attribute does not match a Subject field.
  """
  common_name = subject_args['CN']
  short_to_field = {
      'C': 'countryCode',
      'ST': 'province',
      'L': 'locality',
      'O': 'organization',
      'OU': 'organizationalUnit'
  }
  subject_fields = {
      short_to_field.get(attribute, attribute): value
      for attribute, value in subject_args.items()
      if attribute != 'CN'
  }

  try:
    subject = messages_util.DictToMessageWithErrorCheck(
        subject_fields, privateca_base.GetMessagesModule().Subject)
  except messages_util.DecodeError:
    raise exceptions.InvalidArgumentException(
        '--subject', 'Unrecognized subject attribute.')
  return common_name, subject
def UpdateThresholdRules(ref, args, req):
  """Add threshold rule to budget.

  Fetches the budget's current threshold rules, then:
    * --clear-threshold-rules empties them on the request;
    * --add-threshold-rule appends the given rules to the current (or just
      cleared) rules and returns immediately;
    * --threshold-rules-from-file replaces them with the rules in the file and
      adds thresholdRules to the request's update mask.

  Args:
    ref: The budget resource reference; its RelativeName() is used to fetch
      the current budget.
    args: The parsed argument values.
    req: The update request to modify.

  Returns:
    The (possibly modified) request, on every path.
  """
  messages = GetMessagesModule()
  client = apis.GetClientInstance('billingbudgets', 'v1beta1')
  budgets = client.billingAccounts_budgets
  update_request = req.googleCloudBillingBudgetsV1beta1UpdateBudgetRequest

  get_request_type = messages.BillingbudgetsBillingAccountsBudgetsGetRequest
  get_request = get_request_type(name=six.text_type(ref.RelativeName()))
  old_threshold_rules = budgets.Get(get_request).thresholdRules

  if args.IsSpecified('clear_threshold_rules'):
    old_threshold_rules = []
    update_request.budget.thresholdRules = old_threshold_rules

  if args.IsSpecified('add_threshold_rule'):
    added_threshold_rules = args.add_threshold_rule
    final_rules = AddRules(old_threshold_rules, added_threshold_rules)
    update_request.budget.thresholdRules = final_rules
    return req

  if args.IsSpecified('threshold_rules_from_file'):
    # NOTE(review): confirm `yaml` here is a safe loader (e.g. the SDK's own
    # yaml wrapper), since the file content is user-supplied.
    rules_from_file = yaml.load(args.threshold_rules_from_file)
    # create a mock budget with updated threshold rules
    budget = messages_util.DictToMessageWithErrorCheck(
        {'thresholdRules': rules_from_file},
        messages.GoogleCloudBillingBudgetsV1beta1Budget)
    # update the request with the new threshold rules
    update_request.budget.thresholdRules = budget.thresholdRules
    # Appending with `+=` fails on a None mask and leaves a leading comma on an
    # empty one, so handle both.
    update_mask = update_request.updateMask
    update_request.updateMask = ((update_mask + ',thresholdRules')
                                 if update_mask else 'thresholdRules')
    return req

  # Always hand the request back, even when neither --add-threshold-rule nor
  # --threshold-rules-from-file was given.
  return req
def WithServiceYaml(self, yaml_path):
  """Overrides settings with service.yaml and returns a new Settings object.

  Reads the file at yaml_path as a Service, then takes the container's env
  vars and (if set) the template's service account name as overrides.

  Args:
    yaml_path: Path to the service YAML file.

  Returns:
    The result of self.replace() with the collected overrides.

  Raises:
    exceptions.Error: If the service template does not have exactly one
      container.
  """
  service_message = messages_util.DictToMessageWithErrorCheck(
      yaml.load_path(yaml_path), RUN_MESSAGES_MODULE.Service)
  knative_service = k8s_service.Service(service_message, RUN_MESSAGES_MODULE)

  overrides = {}
  # Planned attributes in
  # http://doc/1ah6LB9we-FSEhcBZ7_4XQlnOPClTyyQW_O3Q5WNUuJc#bookmark=id.j3st2l8a3s19
  try:
    [only_container] = knative_service.spec.template.spec.containers
  except ValueError:
    raise exceptions.Error(
        'knative Service must have exactly one container.')

  env_overrides = {env.name: env.value for env in only_container.env}
  if env_overrides:
    overrides['env_vars'] = env_overrides

  account_name = knative_service.spec.template.spec.serviceAccountName
  if account_name:
    overrides['credential'] = ServiceAccountSetting(name=account_name)

  return self.replace(**overrides)
def testRepeatedField(self):
  """An unknown field inside a repeated message is reported with its index."""
  expected_path = r'\.admissionWhitelistPatterns\[0\]\.foo'
  bad_policy = {'admissionWhitelistPatterns': [{'foo': 'bar'}]}
  with self.assertRaisesRegex(messages_util.DecodeError, expected_path):
    messages_util.DictToMessageWithErrorCheck(bad_policy,
                                              self.messages.Policy)
def testTypeMismatch_HeterogeneousRepeated(self):
  """A repeated field with mixed element types raises a DecodeError."""
  # assertRaisesRegexp is a deprecated alias that was removed in Python 3.12;
  # use assertRaisesRegex like the other tests in this suite.
  with self.assertRaisesRegex(
      messages_util.DecodeError,
      r'\.admissionWhitelistPatterns\[0\]\.namePatterns'):
    messages_util.DictToMessageWithErrorCheck(
        {'admissionWhitelistPatterns': [{
            'namePatterns': ['a', 1]
        }]}, self.messages.Policy)
def CreateNetworkInterfaceMessages(resources,
                                   compute_client,
                                   network_interface_arg,
                                   project,
                                   location,
                                   scope,
                                   network_interface_json=None):
  """Create network interface messages.

  Interfaces come from network_interface_arg when it is non-empty. Otherwise
  they come from network_interface_json, if given. Empty JSON and empty
  per-interface entries produce no messages.

  Args:
    resources: generates resource references.
    compute_client: creates resources.
    network_interface_arg: CLI argument specifying network interfaces.
    project: project of the instance that will own the generated network
      interfaces.
    location: Location of the instance that will own the new network
      interfaces.
    scope: Location type of the instance that will own the new network
      interfaces.
    network_interface_json: CLI argument value specifying network interfaces
      in a JSON string directly in the command or in a file.

  Returns:
    list, items are NetworkInterfaceMessages.
  """
  if network_interface_arg:
    return [
        CreateNetworkInterfaceMessage(
            resources=resources,
            compute_client=compute_client,
            network=nic.get('network', None),
            subnet=nic.get('subnet', None),
            private_network_ip=nic.get('private-network-ip', None),
            nic_type=nic.get('nic-type', None),
            no_address='no-address' in nic,
            address=nic.get('address', None),
            project=project,
            location=location,
            scope=scope,
            alias_ip_ranges_string=nic.get('aliases', None),
            network_tier=nic.get('network-tier', None))
        for nic in network_interface_arg
    ]

  if network_interface_json is None:
    return []

  parsed_interfaces = yaml.load(network_interface_json)
  if not parsed_interfaces:
    # Empty json.
    return []

  # Empty dicts are skipped.
  return [
      messages_util.DictToMessageWithErrorCheck(
          nic, compute_client.messages.NetworkInterface)
      for nic in parsed_interfaces
      if nic
  ]
def Create(self,
           config_path,
           display_name,
           parent=None,
           max_trial_count=None,
           parallel_trial_count=None,
           algorithm=None,
           kms_key_name=None):
  """Creates a hyperparameter tuning job with given parameters.

  Values from the config file form the base job. The other arguments override
  them only when provided.

  Args:
    config_path: str, the file path of the hyperparameter tuning job
      configuration.
    display_name: str, the display name of the created hyperparameter tuning
      job. Ignored when empty.
    parent: str, parent of the created hyperparameter tuning job. e.g.
      /projects/xxx/locations/xxx/
    max_trial_count: int, the desired total number of Trials. If None, the
      value from the config file (if any) is left unchanged.
    parallel_trial_count: int, the desired number of Trials to run in
      parallel. If None, the value from the config file (if any) is left
      unchanged.
    algorithm: AlgorithmValueValuesEnum, the search algorithm specified for
      the Study. Only applied when the job already has a studySpec.
    kms_key_name: A customer-managed encryption key to use for the
      hyperparameter tuning job.

  Returns:
    Created hyperparameter tuning job.
  """
  job_type = self.messages.GoogleCloudAiplatformV1beta1HyperparameterTuningJob
  job_spec = job_type()
  if config_path:
    data = yaml.load_path(config_path)
    if data:
      job_spec = messages_util.DictToMessageWithErrorCheck(data, job_type)

  # Assigning None unconditionally would wipe out values loaded from the
  # config file, so only override when a value was supplied.
  if max_trial_count is not None:
    job_spec.maxTrialCount = max_trial_count
  if parallel_trial_count is not None:
    job_spec.parallelTrialCount = parallel_trial_count

  if display_name:
    job_spec.displayName = display_name

  if algorithm and job_spec.studySpec:
    job_spec.studySpec.algorithm = algorithm

  if kms_key_name is not None:
    job_spec.encryptionSpec = (
        self.messages.GoogleCloudAiplatformV1beta1EncryptionSpec(
            kmsKeyName=kms_key_name))

  return self._service.Create(
      self.messages.
      AiplatformProjectsLocationsHyperparameterTuningJobsCreateRequest(
          parent=parent,
          googleCloudAiplatformV1beta1HyperparameterTuningJob=job_spec))
def Create(self, location_ref, args):
  """Creates a model deployment monitoring job.

  The job spec comes from --monitoring-config-from-file when given. Otherwise
  it starts empty and its objective configs are built from the individual
  flags. The endpoint, display name, alert, sampling and schedule settings
  are then always taken from args.

  Args:
    location_ref: The location resource reference; its RelativeName() is the
      request parent.
    args: The parsed argument values.

  Returns:
    The result of the service's Create call.
  """
  endpoint_ref = _ParseEndpoint(args.endpoint, location_ref)
  job_spec = self.messages.GoogleCloudAiplatformV1beta1ModelDeploymentMonitoringJob(
  )
  if args.monitoring_config_from_file:
    data = yaml.load_path(args.monitoring_config_from_file)
    # An empty file leaves the empty job spec in place.
    if data:
      job_spec = messages_util.DictToMessageWithErrorCheck(
          data, self.messages.
          GoogleCloudAiplatformV1beta1ModelDeploymentMonitoringJob)
  else:
    job_spec.modelDeploymentMonitoringObjectiveConfigs = self._ConstructObjectiveConfigForCreate(
        location_ref, endpoint_ref.RelativeName(), args.feature_thresholds,
        args.dataset, args.bigquery_uri, args.data_format, args.gcs_uris,
        args.target_field, args.training_sampling_rate)

  # The fields below are set from flags even when a config file was given.
  job_spec.endpoint = endpoint_ref.RelativeName()
  job_spec.displayName = args.display_name

  job_spec.modelMonitoringAlertConfig = self.messages.GoogleCloudAiplatformV1beta1ModelMonitoringAlertConfig(
      emailAlertConfig=self.messages.
      GoogleCloudAiplatformV1beta1ModelMonitoringAlertConfigEmailAlertConfig(
          userEmails=args.emails))

  job_spec.loggingSamplingStrategy = self.messages.GoogleCloudAiplatformV1beta1SamplingStrategy(
      randomSampleConfig=self.messages.
      GoogleCloudAiplatformV1beta1SamplingStrategyRandomSampleConfig(
          sampleRate=args.prediction_sampling_rate))

  # monitoring_frequency is multiplied by 3600, i.e. given in hours and sent
  # to the API as a seconds duration string such as '3600s'.
  job_spec.modelDeploymentMonitoringScheduleConfig = self.messages.GoogleCloudAiplatformV1beta1ModelDeploymentMonitoringScheduleConfig(
      monitorInterval='{}s'.format(
          six.text_type(3600 * int(args.monitoring_frequency))))

  if args.predict_instance_schema:
    job_spec.predictInstanceSchemaUri = args.predict_instance_schema

  if args.analysis_instance_schema:
    job_spec.analysisInstanceSchemaUri = args.analysis_instance_schema

  # log_ttl is multiplied by 86400, i.e. given in days and sent as seconds.
  if args.log_ttl:
    job_spec.logTtl = '{}s'.format(
        six.text_type(86400 * int(args.log_ttl)))

  if args.sample_predict_request:
    instance_json = model_monitoring_jobs_util.ReadInstanceFromArgs(
        args.sample_predict_request)
    # Wrap the arbitrary Python value as a generic JsonValue message.
    job_spec.samplePredictInstance = encoding.PyValueToMessage(
        extra_types.JsonValue, instance_json)

  return self._service.Create(
      self.messages.
      AiplatformProjectsLocationsModelDeploymentMonitoringJobsCreateRequest(
          parent=location_ref.RelativeName(),
          googleCloudAiplatformV1beta1ModelDeploymentMonitoringJob=
          job_spec))
def testMultiple_SameMessage(self):
  """Two bad fields in one message are reported together in brace syntax."""
  expected_path = (r'\.defaultAdmissionRule'
                   r'\.\{evaluationMode,nonConformanceAction\}')
  bad_rule = {
      'evaluationMode': 'NOT_A_REAL_ENUM',
      'nonConformanceAction': 'NOT_A_REAL_ENUM',
  }
  bad_policy = {'defaultAdmissionRule': bad_rule}
  with self.assertRaisesRegex(messages_util.DecodeError, expected_path):
    messages_util.DictToMessageWithErrorCheck(bad_policy,
                                              self.messages.Policy)
def Run(self, args):
  """Updates a platform policy from the policy file given on the command line.

  Args:
    args: argparse.Namespace with command-line arguments.

  Returns:
    The result of the v1 platform policy client's Update call.
  """
  # The API is only available in v1.
  v1_messages = apis.GetMessagesModule('v1')
  policy_name = args.CONCEPTS.policy_resource_name.Parse().RelativeName()

  # DecodeErrors are intentionally left to propagate to the user.
  policy_dict = parsing.LoadResourceFile(args.policy_file)
  policy = messages_util.DictToMessageWithErrorCheck(
      policy_dict, v1_messages.PlatformPolicy)

  client = platform_policy.Client('v1')
  return client.Update(policy_name, policy)
def ParseIssuancePolicy(args):
  """Builds a CertificateAuthorityPolicy message from --issuance-policy.

  Args:
    args: The parsed argument values.

  Returns:
    The decoded CertificateAuthorityPolicy, or None if the flag was not given.

  Raises:
    InvalidArgumentException: If the policy contains an unrecognized field.
  """
  if args.IsSpecified('issuance_policy'):
    try:
      policy_type = privateca_base.GetMessagesModule(
      ).CertificateAuthorityPolicy
      return messages_util.DictToMessageWithErrorCheck(args.issuance_policy,
                                                       policy_type)
    except messages_util.DecodeError:
      raise exceptions.InvalidArgumentException(
          '--issuance-policy', 'Unrecognized field in the Issuance Policy.')
  return None
def _ReadExplanationMetadata(self, explanation_metadata_file):
  """Loads an ExplanationMetadata message from a JSON or YAML file.

  Args:
    explanation_metadata_file: Path to the metadata file.

  Returns:
    The decoded ExplanationMetadata message, or None if the file is empty.

  Raises:
    BadArgumentException: If no file path was provided.
  """
  if not explanation_metadata_file:
    raise gcloud_exceptions.BadArgumentException(
        '--explanation-metadata-file',
        'Explanation metadata file must be specified.')

  # Yaml is a superset of json, so parse json file as yaml.
  contents = yaml.load_path(explanation_metadata_file)
  if not contents:
    return None
  return messages_util.DictToMessageWithErrorCheck(
      contents, self.messages.GoogleCloudAiplatformV1beta1ExplanationMetadata)
def _ReadIndexMetadata(self, metadata_file):
  """Parses an index metadata file (JSON or YAML) into a JsonValue message.

  Args:
    metadata_file: Path to the metadata file.

  Returns:
    The decoded JsonValue, or None if the file is empty.

  Raises:
    BadArgumentException: If no file path was provided.
  """
  if metadata_file:
    # Yaml is a superset of json, so parse json file as yaml.
    contents = yaml.load_path(metadata_file)
    if contents:
      return messages_util.DictToMessageWithErrorCheck(
          contents, extra_types.JsonValue)
    return None
  raise gcloud_exceptions.BadArgumentException(
      '--metadata-file', 'Index metadata file must be specified.')
def testMap(self):
  """A bad field under a map entry is reported with the map key in brackets."""
  expected_path = (r'\.clusterAdmissionRules\[us-east1-b.my-cluster-1\]'
                   r'\.evaluationMode')
  cluster_rule = {'evaluationMode': 'NOT_A_REAL_ENUM'}
  bad_policy = {
      'clusterAdmissionRules': {
          'us-east1-b.my-cluster-1': cluster_rule
      }
  }
  with self.assertRaisesRegex(messages_util.DecodeError, expected_path):
    messages_util.DictToMessageWithErrorCheck(bad_policy,
                                              self.messages.Policy)
def ParseIssuancePolicy(args):
  """Builds a CertificateAuthorityPolicy message from --issuance-policy.

  Args:
    args: The parsed argument values.

  Returns:
    The decoded CertificateAuthorityPolicy, or None if the flag was not given.

  Raises:
    InvalidArgumentException: If the policy contains an unrecognized field.
  """
  if args.IsSpecified('issuance_policy'):
    try:
      policy_type = privateca_base.GetMessagesModule(
      ).CertificateAuthorityPolicy
      return messages_util.DictToMessageWithErrorCheck(args.issuance_policy,
                                                       policy_type)
    # TODO(b/77547931): Catch `AttributeError` until upstream library takes the
    # fix.
    except (messages_util.DecodeError, AttributeError):
      raise exceptions.InvalidArgumentException(
          '--issuance-policy', 'Unrecognized field in the Issuance Policy.')
  return None
def ParsePredefinedValues(args):
  """Builds an X509Parameters message from --predefined-values-file.

  Args:
    args: The parsed argument values.

  Returns:
    The decoded X509Parameters, or None if the flag was not given.

  Raises:
    InvalidArgumentException: If the file contains an unrecognized field.
  """
  if args.IsSpecified('predefined_values_file'):
    try:
      x509_type = privateca_base.GetMessagesModule('v1').X509Parameters
      return messages_util.DictToMessageWithErrorCheck(
          args.predefined_values_file, x509_type)
    # TODO(b/77547931): Catch `AttributeError` until upstream library takes the
    # fix.
    except (messages_util.DecodeError, AttributeError):
      raise exceptions.InvalidArgumentException(
          '--predefined-values-file',
          'Unrecognized field in the X509Parameters file.')
  return None
def GetPresetX509Parameters(profile_name):
  """Looks up a preset profile by name and decodes it into X509Parameters.

  Args:
    profile_name: The preset profile name; must be a key of _PRESET_PROFILES.

  Returns:
    An X509Parameters object.

  Raises:
    InvalidArgumentException: If no preset profile has that name.
  """
  if profile_name in _PRESET_PROFILES:
    v1_messages = privateca_base.GetMessagesModule('v1')
    return messages_util.DictToMessageWithErrorCheck(
        _PRESET_PROFILES[profile_name], v1_messages.X509Parameters)
  raise exceptions.InvalidArgumentException(
      '--use-preset-profile',
      'The preset profile that was specified does not exist.')
def testMultiple_DifferentMessages(self):
  """Bad fields in two separate map entries are both reported."""
  # assertRaisesRegexp is a deprecated alias that was removed in Python 3.12;
  # use assertRaisesRegex like the other tests in this suite.
  with self.assertRaisesRegex(
      messages_util.DecodeError,
      r'\.clusterAdmissionRules\[cluster-[12]\]\.evaluationMode[\w\W]*'
      r'\.clusterAdmissionRules\[cluster-[12]\]\.evaluationMode'):
    messages_util.DictToMessageWithErrorCheck(
        {
            'clusterAdmissionRules': {
                'cluster-1': {
                    'evaluationMode': 'NOT_A_REAL_ENUM'
                },
                'cluster-2': {
                    'evaluationMode': 'NOT_A_REAL_ENUM'
                }
            }
        }, self.messages.Policy)
def ReadConfig(config_file, message_type):
  """Decodes a JSON or YAML config file into a message.

  Args:
    config_file: file path of the config file.
    message_type: The protorpc Message type.

  Returns:
    A message of type "message_type", or None if the file is empty.
  """
  # Yaml is a superset of json, so parse json file as yaml.
  contents = yaml.load_path(config_file)
  if not contents:
    return None
  return messages_util.DictToMessageWithErrorCheck(contents, message_type)
def Run(self, args):
  """Create or Update service from YAML.

  Decodes the service definition in args.FILE into a Service and releases it
  as a replacement, either synchronously or asynchronously depending on the
  async flag.

  Args:
    args: The parsed argument values.
  """
  conn_context = connection_context.GetConnectionContext(args)
  if conn_context.supports_one_platform:
    flags.VerifyOnePlatformFlags(args)
  else:
    flags.VerifyGKEFlags(args)

  with serverless_operations.Connect(conn_context) as client:
    message_dict = yaml.load_path(args.FILE)
    new_service = service.Service(
        messages_util.DictToMessageWithErrorCheck(
            message_dict, client.messages_module.Service),
        client.messages_module)

    changes = [config_changes.ReplaceServiceChange(new_service)]
    service_ref = resources.REGISTRY.Parse(
        new_service.metadata.name,
        params={'namespacesId': new_service.metadata.namespace},
        collection='run.namespaces.services')
    # Used only to choose the progress header below.
    original_service = client.GetService(service_ref)

    pretty_print.Info(deploy.GetStartDeployMessage(conn_context, service_ref))

    deployment_stages = stages.ServiceStages()
    header = ('Deploying...'
              if original_service else 'Deploying new service...')

    # `async` is a reserved keyword since Python 3.7, so `args.async` no longer
    # parses. Read the flag via getattr, preferring the `async_` destination
    # used by newer flag definitions.
    if hasattr(args, 'async_'):
      is_async = getattr(args, 'async_')
    else:
      is_async = getattr(args, 'async')

    with progress_tracker.StagedProgressTracker(
        header,
        deployment_stages,
        failure_message='Deployment failed',
        suppress_output=is_async) as tracker:
      client.ReleaseService(
          service_ref,
          changes,
          tracker,
          asyn=is_async,
          allow_unauthenticated=None,
          for_replace=True)
    if is_async:
      pretty_print.Success(
          'Service [{{bold}}{serv}{{reset}}] is deploying '
          'asynchronously.'.format(serv=service_ref.servicesId))
    else:
      pretty_print.Success(
          deploy.GetSuccessMessageForSynchronousDeploy(client, service_ref))
def GetMessageFromResponse(response, message_type):
  """Converts an Operation's ResponseValue into an instance of message_type.

  An Operation's response field holds a generic ResponseValue that is awkward
  to work with. This converts it to a dict, converts the dict's keys with
  text_utils.ToSnakeCaseDict, and decodes the result as message_type.

  Args:
    response: The ResponseValue object that resulted from an Operation.
    message_type: The type of the message that should be returned.

  Returns:
    An instance of message_type with the values from the response filled in.
  """
  return messages_util.DictToMessageWithErrorCheck(
      text_utils.ToSnakeCaseDict(encoding.MessageToDict(response)),
      message_type)
def GetMessageFromResponse(response, message_type):
  """Converts an Operation's ResponseValue into an instance of message_type.

  An Operation's response field holds a generic ResponseValue that is awkward
  to work with. This converts it to a dict, drops the '@type' entry, and
  decodes the rest as message_type.

  Args:
    response: The ResponseValue object that resulted from an Operation.
    message_type: The type of the message that should be returned.

  Returns:
    An instance of message_type with the values from the response filled in.
  """
  response_fields = encoding.MessageToDict(response)
  # '@type' is not needed and not present in messages.
  response_fields.pop('@type', None)
  return messages_util.DictToMessageWithErrorCheck(response_fields,
                                                   message_type)
def ConstructCustomJobSpec(aiplatform_client,
                           config_path=None,
                           network=None,
                           service_account=None,
                           specs=None,
                           **kwargs):
  """Constructs the spec of a custom job to be used in job creation request.

  The YAML config file (if any) provides the base spec. network,
  service_account and specs override it only when provided.

  Args:
    aiplatform_client: The AI Platform API client used.
    config_path: str, Local path of a YAML file which contains the worker pool
    network: user network to which the job should be peered with (overrides
      YAML file when not None)
    service_account: A service account (email address string) to use for the
      job (overrides YAML file when not None).
    specs: A dictionary of worker pool specifications, supposedly derived from
      the gcloud command flags.
    **kwargs: The keyword args to pass to construct the worker pool specs.

  Returns:
    A CustomJobSpec message instance for creating a custom job.
  """
  job_spec_msg = aiplatform_client.GetMessage('CustomJobSpec')
  job_spec = job_spec_msg()

  if config_path:
    data = yaml.load_path(config_path)
    if data:
      job_spec = messages.DictToMessageWithErrorCheck(data, job_spec_msg)

  # Unconditionally assigning these would replace any value from the YAML file
  # with None when the flag was not passed.
  if network is not None:
    job_spec.network = network
  if service_account is not None:
    job_spec.serviceAccount = service_account

  if specs:
    job_spec.workerPoolSpecs = _ConstructWorkerPoolSpecs(
        aiplatform_client, specs, **kwargs)

  return job_spec