def Create(self, neg_ref, neg_type, network_endpoint_type, default_port=None,
           network=None, subnet=None):
  """Creates a network endpoint group.

  Args:
    neg_ref: reference to the group; its project, zone and Name() populate the
      insert request.
    neg_type: choice string converted to NetworkEndpointGroup.TypeValueValuesEnum.
    network_endpoint_type: choice string converted to
      NetworkEndpointGroup.NetworkEndpointTypeValueValuesEnum.
    default_port: optional default port for the load-balancer settings.
    network: optional network name, resolved to a self link in neg_ref.project.
    subnet: optional subnetwork name, resolved to a self link in the region
      derived from neg_ref.zone.

  Returns:
    The first element of the list returned by self.client.MakeRequests for the
    Insert call.
  """
  messages = self.messages
  neg_message = messages.NetworkEndpointGroup

  network_uri = None
  if network:
    network_uri = self.resources.Parse(
        network, {'project': neg_ref.project},
        collection='compute.networks').SelfLink()

  subnet_uri = None
  if subnet:
    subnet_params = {
        'project': neg_ref.project,
        'region': api_utils.ZoneNameToRegionName(neg_ref.zone),
    }
    subnet_uri = self.resources.Parse(
        subnet, subnet_params, collection='compute.subnetworks').SelfLink()

  group = neg_message(
      name=neg_ref.Name(),
      type=arg_utils.ChoiceToEnum(neg_type, neg_message.TypeValueValuesEnum),
      networkEndpointType=arg_utils.ChoiceToEnum(
          network_endpoint_type,
          neg_message.NetworkEndpointTypeValueValuesEnum),
      loadBalancer=messages.NetworkEndpointGroupLbNetworkEndpointGroup(
          defaultPort=default_port,
          network=network_uri,
          subnetwork=subnet_uri))
  insert_request = messages.ComputeNetworkEndpointGroupsInsertRequest(
      networkEndpointGroup=group,
      project=neg_ref.project,
      zone=neg_ref.zone)
  responses = self.client.MakeRequests(
      [(self._service, 'Insert', insert_request)])
  return responses[0]
def _MakeTextFindings(self, likelihood, info_types, count,
                      include_quote=False, exclude_info_types=False):
  """Make list of test text DlpV2Findings for text inspect response.

  Args:
    likelihood: likelihood choice string applied to every finding.
    info_types: info type names, assigned round-robin across the findings.
    count: number of findings to build; a falsy value means 1000.
    include_quote: if True, the i-th finding (1-based) gets quote 'finding i'.
    exclude_info_types: if True, every finding has infoType=None.

  Returns:
    A list of GooglePrivacyDlpV2Finding messages, each with a fixed byte and
    codepoint range of [11, 23).
  """
  msg = self.msg
  findings = []
  for index in range(count or 1000):
    info_type = msg.GooglePrivacyDlpV2InfoType(
        name=info_types[index % len(info_types)])
    # NOTE(review): '{}0' yields seconds like '60', '100' once index >= 6, which
    # is not a valid timestamp; harmless for fixtures but worth confirming.
    create_time = '2018-01-01T00:00:{}0.000Z'.format(index)
    findings.append(
        msg.GooglePrivacyDlpV2Finding(
            createTime=create_time,
            infoType=None if exclude_info_types else info_type,
            likelihood=arg_utils.ChoiceToEnum(
                likelihood,
                msg.GooglePrivacyDlpV2Finding.LikelihoodValueValuesEnum),
            location=msg.GooglePrivacyDlpV2Location(
                byteRange=msg.GooglePrivacyDlpV2Range(end=23, start=11),
                codepointRange=msg.GooglePrivacyDlpV2Range(end=23, start=11)),
            quote='finding {}'.format(index + 1) if include_quote else None))
  return findings
def testMappingProperty(self):
  """choice_mappings must match ChoiceToEnum applied over string_mapping."""
  mapper = arg_utils.ChoiceEnumMapper(
      '--test_arg', self.test_enum, help_str='Auxilio aliis.')
  expected_mapping = {}
  for source_key, choice in six.iteritems(self.string_mapping):
    expected_mapping[choice] = arg_utils.ChoiceToEnum(source_key,
                                                      self.test_enum)
  self.assertEqual(expected_mapping, mapper.choice_mappings)
def GetImageFromFile(path):
  """Builds a GooglePrivacyDlpV2ByteContentItem message from a path.

  The message type is derived from the lower-cased file extension; a path
  with no extension is checked as 'n_a'.

  Args:
    path: the path arg given to the command.

  Raises:
    ImageFileError: if the path is not an existing file or its extension is
      rejected by _ValidateExtension.

  Returns:
    GooglePrivacyDlpV2ByteContentItem: a message containing the image bytes
    and type for the API to analyze.
  """
  extension = os.path.splitext(path)[-1].lower() or 'n_a'
  image_item = _GetMessageClass('GooglePrivacyDlpV2ByteContentItem')
  if not (os.path.isfile(path) and _ValidateExtension(extension)):
    raise ImageFileError(
        'The image path [{}] does not exist or has an invalid extension. '
        'Must be one of [jpg, jpeg, png, bmp or svg]. '
        'Please double-check your input and try again.'.format(path))
  with io.open(path, 'rb') as content_file:
    image_type = arg_utils.ChoiceToEnum(
        VALID_IMAGE_EXTENSIONS[extension], image_item.TypeValueValuesEnum)
    return image_item(data=content_file.read(), type=image_type)
def _MakeImageFindings(self, likelihood, info_types, count,
                       include_quote=False, exclude_info_types=False):
  """Make list of test image DlpV2Findings for image inspect response.

  Args:
    likelihood: likelihood choice string applied to every finding.
    info_types: info type names, assigned round-robin across the findings.
    count: number of findings to build; a falsy value means 1000.
    include_quote: if True, the i-th finding (1-based) gets quote 'finding i'.
    exclude_info_types: if True, every finding has infoType=None.

  Returns:
    A list of GooglePrivacyDlpV2Finding messages, each located by a single
    bounding box (height=46, left=150, top=179, width=122).
  """
  msg = self.msg
  findings = []
  for index in range(count or 1000):
    info_type = msg.GooglePrivacyDlpV2InfoType(
        name=info_types[index % len(info_types)])
    likelihood_enum = arg_utils.ChoiceToEnum(
        likelihood, msg.GooglePrivacyDlpV2Finding.LikelihoodValueValuesEnum)
    box = msg.GooglePrivacyDlpV2BoundingBox(
        height=46, left=150, top=179, width=122)
    content_location = msg.GooglePrivacyDlpV2ContentLocation(
        imageLocation=msg.GooglePrivacyDlpV2ImageLocation(boundingBoxes=[box]))
    findings.append(
        msg.GooglePrivacyDlpV2Finding(
            createTime='2018-01-01T00:00:{}0.000Z'.format(index),
            infoType=None if exclude_info_types else info_type,
            likelihood=likelihood_enum,
            location=msg.GooglePrivacyDlpV2Location(
                contentLocations=[content_location]),
            quote='finding {}'.format(index + 1) if include_quote else None))
  return findings
def MakeImageRedactRequest(self, file_type, info_types, min_likelihood,
                           include_quote, remove_text=False,
                           redact_color_string=None):
  """Create ImageRedactRequest message for testing.

  Args:
    file_type: choice string for the ByteContentItem type enum.
    info_types: info types passed to self._GetInspectConfig.
    min_likelihood: minimum likelihood passed to self._GetInspectConfig.
    include_quote: include-quote flag passed to self._GetInspectConfig.
    remove_text: value for ImageRedactionConfig.redactAllText.
    redact_color_string: passed to self._MakeRedactColor.

  Returns:
    A DlpProjectsImageRedactRequest for 'projects/<self.Project()>', whose
    inspect config has excludeInfoTypes and limits cleared.
  """
  msg = self.msg
  byte_item = msg.GooglePrivacyDlpV2ByteContentItem(
      data=self.TEST_IMG_CONTENT,
      type=arg_utils.ChoiceToEnum(
          file_type, msg.GooglePrivacyDlpV2ByteContentItem.TypeValueValuesEnum))

  config = self._GetInspectConfig(info_types, min_likelihood, None,
                                  include_quote, None)
  config.excludeInfoTypes = None

  redaction_config = msg.GooglePrivacyDlpV2ImageRedactionConfig(
      redactAllText=remove_text,
      redactionColor=self._MakeRedactColor(redact_color_string))
  redact_body = msg.GooglePrivacyDlpV2RedactImageRequest(
      byteItem=byte_item,
      inspectConfig=config,
      imageRedactionConfigs=[redaction_config])
  # Clear limits on the inspect config as held by the request.
  redact_body.inspectConfig.limits = None

  return msg.DlpProjectsImageRedactRequest(
      googlePrivacyDlpV2RedactImageRequest=redact_body,
      parent='projects/' + self.Project())
def ParseAcceleratorFlag(accelerator, version):
  """Validates and returns an accelerator config message object.

  Args:
    accelerator: dict-like with optional 'type' and 'count' keys, or None.
    version: API version constant. ALPHA_VERSION and BETA_VERSION select the
      v1alpha1/v1beta1 MachineSpec; any other value selects the v1 MachineSpec.

  Returns:
    A MachineSpec message with acceleratorCount and acceleratorType set, or
    None when accelerator is None.

  Raises:
    errors.ArgumentError: if the type is not one of the mapper's choices, or
      the count (default 1) is not greater than 0.
  """
  if accelerator is None:
    return None

  allowed_types = list(GetAcceleratorTypeMapper(version).choices)
  raw_type = accelerator.get('type', None)
  if raw_type not in allowed_types:
    raise errors.ArgumentError("""\
The type of the accelerator can only be one of the following: {}.
""".format(', '.join(["'{}'".format(c) for c in allowed_types])))

  accelerator_count = accelerator.get('count', 1)
  if accelerator_count <= 0:
    raise errors.ArgumentError("""\
The count of the accelerator must be greater than 0.
""")

  # Every branch uses the same messages module; only the MachineSpec class
  # differs by version.
  messages = apis.GetMessagesModule(constants.AI_PLATFORM_API_NAME,
                                    constants.AI_PLATFORM_API_VERSION[version])
  if version == constants.ALPHA_VERSION:
    machine_spec = messages.GoogleCloudAiplatformV1alpha1MachineSpec
  elif version == constants.BETA_VERSION:
    machine_spec = messages.GoogleCloudAiplatformV1beta1MachineSpec
  else:
    machine_spec = messages.GoogleCloudAiplatformV1MachineSpec

  accelerator_type = arg_utils.ChoiceToEnum(
      raw_type, machine_spec.AcceleratorTypeValueValuesEnum)
  return machine_spec(
      acceleratorCount=accelerator_count, acceleratorType=accelerator_type)
def _CreatePatchRollout(args, messages): """Creates a PatchRollout message from input arguments.""" if not any([ args.rollout_mode, args.rollout_disruption_budget, args.rollout_disruption_budget_percent ]): return None if args.rollout_mode and not (args.rollout_disruption_budget or args.rollout_disruption_budget_percent): raise exceptions.InvalidArgumentException( 'rollout-mode', '[rollout-disruption-budget] or [rollout-disruption-budget-percent] ' 'must also be specified.') if args.rollout_disruption_budget and not args.rollout_mode: raise exceptions.InvalidArgumentException( 'rollout-disruption-budget', '[rollout-mode] must also be specified.') if args.rollout_disruption_budget_percent and not args.rollout_mode: raise exceptions.InvalidArgumentException( 'rollout-disruption-budget-percent', '[rollout-mode] must also be specified.') rollout_modes = messages.PatchRollout.ModeValueValuesEnum return messages.PatchRollout( mode=arg_utils.ChoiceToEnum(args.rollout_mode, rollout_modes), disruptionBudget=messages.FixedOrPercent( fixed=int(args.rollout_disruption_budget) if args.rollout_disruption_budget else None, percent=int(args.rollout_disruption_budget_percent) if args.rollout_disruption_budget_percent else None))
def MakeGetAssetsHistoryHttpRequests(args,
                                     service,
                                     api_version=DEFAULT_API_VERSION):
  """Manually make the get assets history request.

  Generator: issues one BatchGetAssetsHistory call and yields each entry of
  the response's assets.

  Args:
    args: parsed arguments (content_type, organization, project, start_time,
      end_time, asset_names, relationship_types).
    service: client exposing BatchGetAssetsHistory.
    api_version: API version passed to GetMessages.
  """
  messages = GetMessages(api_version)
  request_type = messages.CloudassetBatchGetAssetsHistoryRequest

  # Map the flattened python field names onto dotted JSON field names.
  encoding.AddCustomJsonFieldMapping(request_type, 'readTimeWindow_startTime',
                                     'readTimeWindow.startTime')
  encoding.AddCustomJsonFieldMapping(request_type, 'readTimeWindow_endTime',
                                     'readTimeWindow.endTime')

  content_type = arg_utils.ChoiceToEnum(
      args.content_type, request_type.ContentTypeValueValuesEnum)
  parent = asset_utils.GetParentNameForGetHistory(args.organization,
                                                  args.project)
  start_time = times.FormatDateTime(args.start_time)
  if args.IsSpecified('end_time'):
    end_time = times.FormatDateTime(args.end_time)
  else:
    end_time = None

  request = request_type(
      assetNames=args.asset_names,
      relationshipTypes=args.relationship_types,
      contentType=content_type,
      parent=parent,
      readTimeWindow_endTime=end_time,
      readTimeWindow_startTime=start_time,
  )
  response = service.BatchGetAssetsHistory(request)
  for asset in response.assets:
    yield asset
def SetExecutionConfig(messages, target, execution_configs):
  """Sets the executionConfigs field of cloud deploy resource message.

  Each config dict becomes one messages.ExecutionConfig appended to
  target.executionConfigs. Every key except 'usages' is copied with setattr;
  'usages' entries are converted to enum values.

  Args:
    messages: module containing the definitions of messages for Cloud Deploy.
    target: googlecloudsdk.third_party.apis.clouddeploy.Target message.
    execution_configs:
      [googlecloudsdk.third_party.apis.clouddeploy.ExecutionConfig], list of
      ExecutionConfig messages.

  Raises:
    arg_parsers.ArgumentTypeError: if usage is not a valid enum.
  """
  for config in execution_configs:
    config_message = messages.ExecutionConfig()
    plain_fields = [name for name in config if name != 'usages']
    for name in plain_fields:
      setattr(config_message, name, config.get(name))
    # 'usages' holds enum strings and needs conversion rather than a copy.
    for usage in config.get('usages') or []:
      config_message.usages.append(
          arg_utils.ChoiceToEnum(
              usage,
              messages.ExecutionConfig.UsagesValueListEntryValuesEnum,
              valid_choices=USAGE_CHOICES))
    target.executionConfigs.append(config_message)
def _AppendReplicas(msgs, add_replicas_arg, replica_info_list): """Appends each in add_replicas_arg to the given ReplicaInfo list.""" for replica in add_replicas_arg: replica_type = arg_utils.ChoiceToEnum( replica['type'], msgs.ReplicaInfo.TypeValueValuesEnum) replica_info_list.append( msgs.ReplicaInfo(location=replica['location'], type=replica_type))
def GetNotificationCategories(args, notification_category_enum_message):
  """Converts args.notification_categories into a list of enum values.

  Returns an empty list when no categories were given.
  """
  if not args.notification_categories:
    return []
  converted = []
  for category_choice in args.notification_categories:
    converted.append(
        arg_utils.ChoiceToEnum(category_choice,
                               notification_category_enum_message))
  return converted
def Run(self, args):
  """Updates an instance from an imported configuration.

  Reads the configuration from args.source (or '-' when unset), requires a
  fingerprint in it, and sends a ComputeInstancesUpdateRequest built from the
  update-constraint flags.

  Raises:
    exceptions.InvalidUserInputError: if the imported instance has no
      fingerprint.
  """
  holder = base_classes.ComputeApiHolder(self.ReleaseTrack())
  client = holder.client
  messages = client.messages

  # Import the virtual machine instance configuration specification.
  # '-' is passed when no source is given; presumably this means stdin for
  # ReadFromFileOrStdin.
  data = console_io.ReadFromFileOrStdin(args.source or '-', binary=False)
  instance = export_util.Import(
      message_type=messages.Instance,
      stream=data,
      schema_path=self.GetSchemaPath(for_help=False))

  # Confirm imported instance has base64 fingerprint.
  if not instance.fingerprint:
    # NOTE(review): when args.source is unset this renders its raw value
    # (e.g. "None") rather than naming stdin.
    raise exceptions.InvalidUserInputError(
        '"{}" is missing the instance\'s base64 fingerprint field.'.format(
            args.source))

  # Retrieve specified instance reference.
  instance_ref = flags.INSTANCE_ARG.ResolveAsResource(
      args,
      holder.resources,
      scope_lister=compute_flags.GetDefaultScopeLister(client))

  # Process update-constraint args.
  update_request_type = messages.ComputeInstancesUpdateRequest
  most_disruptive_allowed_action = arg_utils.ChoiceToEnum(
      args.most_disruptive_allowed_action,
      update_request_type.MostDisruptiveAllowedActionValueValuesEnum)
  minimal_action = arg_utils.ChoiceToEnum(
      args.minimal_action, update_request_type.MinimalActionValueValuesEnum)

  # Prepare and send the update request.
  request = update_request_type(
      instance=instance.name,
      project=instance_ref.project,
      zone=instance_ref.zone,
      instanceResource=instance,
      minimalAction=minimal_action,
      mostDisruptiveAllowedAction=most_disruptive_allowed_action)
  if self._support_secure_tag and args.clear_secure_tag:
    request.clearSecureTag = True
  client.MakeRequests([(client.apitools_client.instances, 'Update', request)])
def Run(self, args):
  """Runs a cluster diagnose operation and streams its output to stderr.

  Args:
    args: parsed arguments; reads CONCEPTS.cluster, tarball_access and timeout.

  Returns:
    The outputUri from the finished operation's response.

  Raises:
    exceptions.OperationError: if the operation has no response, or the
      response has no (or an empty) outputUri.
  """
  dataproc = dp.Dataproc(self.ReleaseTrack())
  messages = dataproc.messages
  cluster_ref = args.CONCEPTS.cluster.Parse()

  request_fields = {
      'clusterName': cluster_ref.clusterName,
      'region': cluster_ref.region,
      'projectId': cluster_ref.projectId,
  }
  # Only attach a DiagnoseClusterRequest body when tarball access was given.
  if args.tarball_access is not None:
    tarball_access = arg_utils.ChoiceToEnum(
        args.tarball_access,
        messages.DiagnoseClusterRequest.TarballAccessValueValuesEnum)
    request_fields['diagnoseClusterRequest'] = messages.DiagnoseClusterRequest(
        tarballAccess=tarball_access)
  request = messages.DataprocProjectsRegionsClustersDiagnoseRequest(
      **request_fields)

  operation = dataproc.client.projects_regions_clusters.Diagnose(request)
  # TODO(b/36052522): Stream output during polling.
  operation = util.WaitForOperation(
      dataproc,
      operation,
      message='Waiting for cluster diagnose operation',
      timeout_s=args.timeout)

  if not operation.response:
    raise exceptions.OperationError('Operation is missing response')

  properties = encoding.MessageToDict(operation.response)
  # Use .get() so that a response without outputUri reaches the intended
  # OperationError below instead of escaping as a bare KeyError.
  output_uri = properties.get('outputUri')
  if not output_uri:
    raise exceptions.OperationError('Response is missing outputUri')

  log.err.Print('Output from diagnostic:')
  log.err.Print('-----------------------------------------------')
  driver_log_stream = storage_helpers.StorageObjectSeriesStream(output_uri)
  # A single read might not read whole stream. Try a few times, retrying
  # while the stream still reports itself open.
  read_retrier = retry.Retryer(max_retrials=4, jitter_ms=None)
  try:
    read_retrier.RetryOnResult(
        lambda: driver_log_stream.ReadIntoWritable(log.err),
        sleep_ms=100,
        should_retry_if=lambda *_: driver_log_stream.open)
  except retry.MaxRetrialsException:
    log.warning('Diagnostic finished successfully, '
                'but output did not finish streaming.')
  log.err.Print('-----------------------------------------------')
  return output_uri
def CreateNodeTemplate(node_template_ref, args, project, region, messages,
                       resource_parser, enable_disk=False,
                       enable_accelerator=False):
  """Creates a Node Template message from args.

  Args:
    node_template_ref: reference whose Name() becomes the template name.
    args: parsed arguments.
    project: project passed to GetAccelerators.
    region: region passed to GetAccelerators.
    messages: compute messages module.
    resource_parser: resource parser passed to GetAccelerators.
    enable_disk: whether to read the disk flag into template.disks.
    enable_accelerator: whether to populate template.accelerators.

  Returns:
    A messages.NodeTemplate built from args.
  """
  affinity_labels = None
  if args.node_affinity_labels:
    affinity_labels = _ParseNodeAffinityLabels(args.node_affinity_labels,
                                               messages)

  type_flexibility = None
  if args.IsSpecified('node_requirements'):
    requirements = args.node_requirements
    type_flexibility = messages.NodeTemplateNodeTypeFlexibility(
        cpus=six.text_type(requirements.get('vCPU', 'any')),
        # local SSD is unique because the user may omit the local SSD
        # constraint entirely to include the possibility of node types with
        # no local SSD. "any" corresponds to "greater than zero".
        localSsd=requirements.get('localSSD', None),
        memory=requirements.get('memory', 'any'))

  template = messages.NodeTemplate(
      name=node_template_ref.Name(),
      description=args.description,
      nodeAffinityLabels=affinity_labels,
      nodeType=args.node_type,
      nodeTypeFlexibility=type_flexibility)

  if enable_disk and args.IsSpecified('disk'):
    disk_spec = args.disk
    template.disks = [
        messages.LocalDisk(
            diskCount=disk_spec.get('count'),
            diskSizeGb=disk_spec.get('size'),
            diskType=disk_spec.get('type'))
    ]

  if args.IsSpecified('cpu_overcommit_type'):
    template.cpuOvercommitType = arg_utils.ChoiceToEnum(
        args.cpu_overcommit_type,
        messages.NodeTemplate.CpuOvercommitTypeValueValuesEnum)

  if enable_accelerator:
    template.accelerators = GetAccelerators(args, messages, resource_parser,
                                            project, region)

  binding_mapper = flags.GetServerBindingMapperFlag(messages)
  template.serverBinding = messages.ServerBinding(
      type=binding_mapper.GetEnumForChoice(args.server_binding))
  return template
def _GetWindowsUpdateSettings(args, messages): """Create WindowsUpdateSettings from input arguments.""" if args.windows_classifications or args.windows_excludes: enums = messages.WindowsUpdateSettings.ClassificationsValueListEntryValuesEnum classifications = [ arg_utils.ChoiceToEnum(c, enums) for c in args.windows_classifications ] if args.windows_classifications else [] return messages.WindowsUpdateSettings( classifications=classifications, excludes=args.windows_excludes if args.windows_excludes else []) else: return None
def Run(self, args):
  """Patches the existing autoscaler of a managed instance group.

  Raises:
    NoMatchingAutoscalerFoundError: if the group has no existing autoscaler.
  """
  holder = base_classes.ComputeApiHolder(self.ReleaseTrack())
  client = holder.client
  messages = client.messages

  igm_ref = instance_groups_flags.CreateGroupReference(
      client, holder.resources, args)
  # Assert that Instance Group Manager exists.
  mig_utils.GetInstanceGroupManagerOrThrow(igm_ref, client)

  old_autoscaler = mig_utils.AutoscalerForMigByRef(client, holder.resources,
                                                   igm_ref)
  if mig_utils.IsAutoscalerNew(old_autoscaler):
    raise NoMatchingAutoscalerFoundError(
        'Instance group manager [{}] has no existing autoscaler; '
        'cannot update.'.format(igm_ref.Name()))

  autoscalers_client = autoscalers_api.GetClient(client, igm_ref)
  new_autoscaler = autoscalers_client.message_type(
      name=old_autoscaler.name,  # PATCH needs this
      autoscalingPolicy=messages.AutoscalingPolicy())
  policy = new_autoscaler.autoscalingPolicy

  if args.IsSpecified('mode'):
    policy.mode = mig_utils.ParseModeString(args.mode, messages)

  if args.IsSpecified('clear_scale_in_control'):
    policy.scaleInControl = None
  else:
    policy.scaleInControl = mig_utils.BuildScaleIn(args, messages)

  if self.clear_scale_down and args.IsSpecified('clear_scale_down_control'):
    policy.scaleDownControl = None

  if args.IsSpecified('cpu_utilization_predictive_method'):
    cpu_utilization_type = messages.AutoscalingPolicyCpuUtilization
    predictive_enum = cpu_utilization_type.PredictiveMethodValueValuesEnum
    policy.cpuUtilization = cpu_utilization_type()
    policy.cpuUtilization.predictiveMethod = arg_utils.ChoiceToEnum(
        args.cpu_utilization_predictive_method, predictive_enum)

  schedules = mig_utils.BuildSchedules(args, messages)
  if schedules:
    policy.scalingSchedules = schedules

  if args.IsSpecified('min_num_replicas'):
    policy.minNumReplicas = args.min_num_replicas
  if args.IsSpecified('max_num_replicas'):
    policy.maxNumReplicas = args.max_num_replicas

  return self._SendPatchRequest(args, client, autoscalers_client, igm_ref,
                                new_autoscaler)
def ParseReplacementMethod(method_type_str, messages):
  """Converts a replacement-method string to its update-policy enum value.

  Args:
    method_type_str: string naming the replacement method (e.g. substitute or
      recreate).
    messages: module containing message classes.

  Returns:
    InstanceGroupManagerUpdatePolicy.ReplacementMethodValueValuesEnum value.
  """
  update_policy = messages.InstanceGroupManagerUpdatePolicy
  return arg_utils.ChoiceToEnum(
      method_type_str, update_policy.ReplacementMethodValueValuesEnum)
def testChoiceToEnumErrors(self):
  """ChoiceToEnum reports invalid choices with the expected message."""
  cases = [
      # With valid choices specified, validate against those.
      (r'Invalid choice: badchoice. Valid choices are: \[a, b\].',
       {'valid_choices': ['a', 'b']}),
      # With valid choices specified, and custom item type.
      (r'Invalid sproket: badchoice. Valid choices are: \[a, b\].',
       {'item_type': 'sproket', 'valid_choices': ['a', 'b']}),
      # With no valid choices specified, validate against the enum.
      (r'Invalid choice: badchoice. Valid choices are: \[thing-one, '
       r'thing-two\].', {}),
  ]
  for expected_pattern, extra_kwargs in cases:
    with self.assertRaisesRegex(arg_parsers.ArgumentTypeError,
                                expected_pattern):
      arg_utils.ChoiceToEnum('badchoice', fm.FakeMessage.FakeEnum,
                             **extra_kwargs)
def GetLocationPolicyLocations(self, args, messages):
  """Builds LocationPolicy.LocationsValue from args.location_policy.

  Each (zone, preference) pair becomes an entry keyed 'zones/<zone>' whose
  LocationPolicyLocation carries the preference converted to its enum.
  """

  def _ZoneEntry(zone, preference):
    # One AdditionalProperty per zone in the policy mapping.
    preference_enum = arg_utils.ChoiceToEnum(
        preference, messages.LocationPolicyLocation.PreferenceValueValuesEnum)
    return messages.LocationPolicy.LocationsValue.AdditionalProperty(
        key='zones/{}'.format(zone),
        value=messages.LocationPolicyLocation(preference=preference_enum))

  entries = [
      _ZoneEntry(zone, preference)
      for zone, preference in args.location_policy.items()
  ]
  return messages.LocationPolicy.LocationsValue(additionalProperties=entries)
def ParseSecuritypolicy(securitypolicy, message):
  """Converts a security policy string to its enum representation.

  Args:
    securitypolicy: string representation of the security policy.
    message: message module client.

  Returns:
    The matching CloudBuildMembershipConfig.SecurityPolicyValueValuesEnum
    value.
  """
  policy_enum = (
      message.CloudBuildMembershipConfig.SecurityPolicyValueValuesEnum)
  return arg_utils.ChoiceToEnum(securitypolicy, policy_enum)
def _SetRunOptionInRequest(run_option, run_schedule, request, messages): """Returns request with the run option set.""" if run_option == 'manual': arg_utils.SetFieldInMessage( request, 'googleCloudDatacatalogV1alpha3Crawler.config.adHocRun', messages.GoogleCloudDatacatalogV1alpha3AdhocRun()) elif run_option == 'scheduled': scheduled_run_option = arg_utils.ChoiceToEnum( run_schedule, (messages.GoogleCloudDatacatalogV1alpha3ScheduledRun. ScheduledRunOptionValueValuesEnum)) arg_utils.SetFieldInMessage( request, 'googleCloudDatacatalogV1alpha3Crawler.config.scheduledRun.scheduledRunOption', scheduled_run_option) return request
def _BuildCpuUtilization(args, messages, predictive=False):
  """Builds the CPU Utilization message given relevant arguments.

  Returns None when none of the checked CPU flags is specified. The
  predictive-method flag is only considered when predictive is True.
  """
  checked_flags = ['target_cpu_utilization', 'scale_based_on_cpu']
  if predictive:
    checked_flags.append('cpu_utilization_predictive_method')
  if not instance_utils.IsAnySpecified(args, *checked_flags):
    return None

  cpu_message = messages.AutoscalingPolicyCpuUtilization()
  if args.target_cpu_utilization:
    cpu_message.utilizationTarget = args.target_cpu_utilization
  if predictive and args.cpu_utilization_predictive_method:
    predictive_enum = (
        messages.AutoscalingPolicyCpuUtilization.PredictiveMethodValueValuesEnum)
    cpu_message.predictiveMethod = arg_utils.ChoiceToEnum(
        args.cpu_utilization_predictive_method, predictive_enum)
  return cpu_message
def IndexTypeToEnum(index_type):
  """Converts an Index Type String Literal to an Enum.

  Args:
    index_type: The index type e.g INDEX_TYPE_STRING.

  Returns:
    The mapped IndexConfig.TypeValueValuesEnum value, e.g.
    TypeValueValuesEnum(INDEX_TYPE_INTEGER, 2).

  Raises:
    The error raised by arg_utils.ChoiceToEnum for an invalid choice (an
    ArgumentTypeError) if index_type is not INDEX_TYPE_STRING or
    INDEX_TYPE_INTEGER.
  """
  allowed_types = ['INDEX_TYPE_STRING', 'INDEX_TYPE_INTEGER']
  type_enum = GetMessages().IndexConfig.TypeValueValuesEnum
  return arg_utils.ChoiceToEnum(index_type, type_enum,
                                valid_choices=allowed_types)
def _CreateWindowsUpdateSettings(args, messages): """Creates a WindowsUpdateSettings message from input arguments.""" if not any([ args.windows_classifications, args.windows_excludes, args.windows_exclusive_patches ]): return None enums = messages.WindowsUpdateSettings.ClassificationsValueListEntryValuesEnum classifications = [ arg_utils.ChoiceToEnum(c, enums) for c in args.windows_classifications ] if args.windows_classifications else [] return messages.WindowsUpdateSettings( classifications=classifications, excludes=args.windows_excludes if args.windows_excludes else [], exclusivePatches=args.windows_exclusive_patches if args.windows_exclusive_patches else [], )
def GetLocationPolicy(self, args, messages):
  """Builds a LocationPolicy from the location flags, or None if none apply.

  Locations come from --location-policy. The target shape is read from
  --target-distribution-shape only when self._support_enable_target_shape is
  set.
  """
  wants_locations = args.IsSpecified('location_policy')
  wants_target_shape = (
      self._support_enable_target_shape and
      args.IsSpecified('target_distribution_shape'))
  if not wants_locations and not wants_target_shape:
    return None

  location_policy = messages.LocationPolicy()
  if wants_locations:
    location_policy.locations = self.GetLocationPolicyLocations(args, messages)
  if wants_target_shape:
    location_policy.targetShape = arg_utils.ChoiceToEnum(
        args.target_distribution_shape,
        messages.LocationPolicy.TargetShapeValueValuesEnum)
  return location_policy
def _SkipReplicas(msgs, skip_replicas_arg, replica_info_list): """Skips each in skip_replicas_arg from the given ReplicaInfo list.""" for replica_to_skip in skip_replicas_arg: index_to_delete = None replica_type = arg_utils.ChoiceToEnum( replica_to_skip['type'], msgs.ReplicaInfo.TypeValueValuesEnum) for index, replica in enumerate(replica_info_list): # Only skip the first found matching replica. if (replica.location == replica_to_skip['location'] and replica.type == replica_type): index_to_delete = index pass if index_to_delete is None: raise MissingReplicaError(replica_to_skip['location'], replica_type) replica_info_list.pop(index_to_delete)
def Run(self, args):
  """Creates an HP tuning job via HpTuningJobsClient and prints its status.

  Returns:
    The response returned by HpTuningJobsClient().Create.
  """
  region_ref = args.CONCEPTS.region.Parse()
  region = region_ref.AsDict()['locationsId']
  with endpoint_util.AiplatformEndpointOverrides(
      version=constants.BETA_VERSION, region=region):
    jobs_client_class = client.HpTuningJobsClient
    algorithm = arg_utils.ChoiceToEnum(args.algorithm,
                                       jobs_client_class.GetAlgorithmEnum())
    response = jobs_client_class().Create(
        parent=region_ref.RelativeName(),
        config_path=args.config,
        display_name=args.display_name,
        max_trial_count=args.max_trial_count,
        parallel_trial_count=args.parallel_trial_count,
        algorithm=algorithm)
    job_id = hp_tuning_jobs_util.ParseJobName(response.name)
    log.status.Print(
        constants.HPTUNING_JOB_CREATION_DISPLAY_MESSAGE.format(
            id=job_id, state=response.state))
    return response
def ModifyAlertPolicy(base_policy, messages, display_name=None, combiner=None,
                      documentation_content=None, documentation_format=None,
                      enabled=None, channels=None, field_masks=None):
  """Override and/or add fields from other flags to an Alert Policy.

  base_policy is modified in place. For every argument that is not None the
  corresponding field is overwritten and its path is appended to field_masks.
  When field_masks is None a fresh local list is used, so callers that need
  the paths must pass their own list. Nothing is returned.
  """
  masks = [] if field_masks is None else field_masks

  if display_name is not None:
    masks.append('display_name')
    base_policy.displayName = display_name

  touches_documentation = (documentation_content is not None or
                           documentation_format is not None)
  if touches_documentation and not base_policy.documentation:
    base_policy.documentation = messages.Documentation()

  if documentation_content is not None:
    masks.append('documentation.content')
    base_policy.documentation.content = documentation_content
  if documentation_format is not None:
    masks.append('documentation.mime_type')
    base_policy.documentation.mimeType = documentation_format

  if enabled is not None:
    masks.append('enabled')
    base_policy.enabled = enabled

  # None means "no update"; an empty list explicitly sets no channels.
  if channels is not None:
    masks.append('notification_channels')
    base_policy.notificationChannels = channels

  if combiner is not None:
    masks.append('combiner')
    base_policy.combiner = arg_utils.ChoiceToEnum(
        combiner, base_policy.CombinerValueValuesEnum, item_type='combiner')
def ValidateAndAddPortSpecificationToHealthCheck(args, x_health_check):
  """Modifies the health check as needed and adds port spec to the check.

  Checks --port-specification (or --use-serving-port) against --port and
  --port-name, then sets x_health_check.portSpecification and, for named and
  serving ports, clears x_health_check.port. If neither flag is present or
  set, x_health_check is left unchanged.

  Args:
    args: parsed arguments. port_specification and use_serving_port are read
      only when present on args (hasattr checks); port and port_name are
      tested via args.IsSpecified.
    x_health_check: health check message, updated in place. Its type must
      define PortSpecificationValueValuesEnum.
  """
  enum_class = type(x_health_check).PortSpecificationValueValuesEnum
  if hasattr(args, 'port_specification') and args.port_specification:
    enum_value = arg_utils.ChoiceToEnum(args.port_specification, enum_class)
    # USE_FIXED_PORT conflicts with --port-name.
    if enum_value == enum_class.USE_FIXED_PORT:
      if args.IsSpecified('port_name'):
        # _RaiseBadPortSpecificationError is defined elsewhere; presumably it
        # always raises (not visible here — verify).
        _RaiseBadPortSpecificationError('--port-name', '--port-specification',
                                        'USE_FIXED_PORT')
    # USE_NAMED_PORT conflicts with --port; the port field is cleared.
    if enum_value == enum_class.USE_NAMED_PORT:
      if args.IsSpecified('port'):
        _RaiseBadPortSpecificationError('--port', '--port-specification',
                                        'USE_NAMED_PORT')
      # TODO(b/77489293): Stop overriding default values here.
      x_health_check.port = None
    # USE_SERVING_PORT conflicts with both --port and --port-name.
    if enum_value == enum_class.USE_SERVING_PORT:
      if args.IsSpecified('port_name'):
        _RaiseBadPortSpecificationError('--port-name', '--port-specification',
                                        'USE_SERVING_PORT')
      if args.IsSpecified('port'):
        _RaiseBadPortSpecificationError('--port', '--port-specification',
                                        'USE_SERVING_PORT')
      # TODO(b/77489293): Stop overriding default values here.
      x_health_check.port = None
    # --use-serving-port cannot be combined with --port-specification.
    # NOTE(review): the third argument here is the enum value itself, whereas
    # the other calls pass a plain string; confirm the error text renders it
    # as intended.
    if hasattr(args, 'use_serving_port') and args.use_serving_port:
      _RaiseBadPortSpecificationError('--use-serving-port',
                                      '--port-specification', enum_value)
    x_health_check.portSpecification = enum_value
  elif hasattr(args, 'use_serving_port') and args.use_serving_port:
    # --use-serving-port alone behaves like USE_SERVING_PORT and likewise
    # conflicts with --port and --port-name.
    if args.IsSpecified('port_name'):
      _RaiseBadPortSpecificationError('--port-name', '--use-serving-port',
                                      '--use-serving-port')
    if args.IsSpecified('port'):
      _RaiseBadPortSpecificationError('--port', '--use-serving-port',
                                      '--use-serving-port')
    x_health_check.portSpecification = enum_class.USE_SERVING_PORT
    x_health_check.port = None