def _Run(args, version):
  """Update an existing Vertex AI endpoint.

  Args:
    args: the parsed command-line arguments of the update command.
    version: the API version. constants.GA_VERSION selects the v1 messages and
      EndpointsClient.Patch; any other value selects the v1beta1 messages and
      EndpointsClient.PatchBeta.

  Returns:
    The operation returned by the patch call, or None when update flags were
    given but the client reported that no fields needed changing.
  """
  validation.ValidateDisplayName(args.display_name)

  endpoint_ref = args.CONCEPTS.endpoint.Parse()
  # The region used for the API overrides comes from the endpoint resource.
  args.region = endpoint_ref.AsDict()['locationsId']
  with endpoint_util.AiplatformEndpointOverrides(version, region=args.region):
    endpoints_client = client.EndpointsClient(version=version)
    messages = endpoints_client.messages

    # Select the version-specific patch method and labels class once instead
    # of duplicating the whole call per version.
    if version == constants.GA_VERSION:
      patch_fn = endpoints_client.Patch
      labels_value_cls = messages.GoogleCloudAiplatformV1Endpoint.LabelsValue
    else:
      patch_fn = endpoints_client.PatchBeta
      labels_value_cls = (
          messages.GoogleCloudAiplatformV1beta1Endpoint.LabelsValue)

    def GetLabels():
      return endpoints_client.Get(endpoint_ref).labels

    try:
      op = patch_fn(
          endpoint_ref,
          labels_util.ProcessUpdateArgsLazy(args, labels_value_cls, GetLabels),
          display_name=args.display_name,
          description=args.description,
          traffic_split=args.traffic_split,
          clear_traffic_split=args.clear_traffic_split)
    except errors.NoFieldsSpecifiedError:
      updatable_flags = ('display_name', 'traffic_split',
                         'clear_traffic_split', 'update_labels',
                         'clear_labels', 'remove_labels', 'description')
      # If any update flag was given, treat "no fields" as a no-op; otherwise
      # propagate the error to the user.
      if any(args.IsSpecified(flag) for flag in updatable_flags):
        log.status.Print('No update to perform.')
        return None
      raise

    log.UpdatedResource(op.name, kind='Vertex AI endpoint')
    return op
def PatchAlpha(self, tensorboard_run_ref, args):
  """Update a Tensorboard run.

  Args:
    tensorboard_run_ref: resource reference of the Tensorboard run to patch.
    args: parsed arguments carrying the label flags, display_name and
      description.

  Returns:
    The response of the runs Patch call.

  Raises:
    errors.NoFieldsSpecifiedError: if no field would be changed.
  """
  run_cls = self.messages.GoogleCloudAiplatformV1alpha1TensorboardRun
  tensorboard_run = run_cls()
  mask_paths = []

  # The current labels are supplied through a thunk so they are only fetched
  # if the label flags need them.
  labels_update = labels_util.ProcessUpdateArgsLazy(
      args, run_cls.LabelsValue,
      lambda: self.Get(tensorboard_run_ref).labels)
  if labels_update.needs_update:
    tensorboard_run.labels = labels_update.labels
    mask_paths.append('labels')

  simple_fields = (
      (args.display_name, 'displayName', 'display_name'),
      (args.description, 'description', 'description'),
  )
  for new_value, message_attr, mask_path in simple_fields:
    if new_value is None:
      continue
    setattr(tensorboard_run, message_attr, new_value)
    mask_paths.append(mask_path)

  if not mask_paths:
    raise errors.NoFieldsSpecifiedError('No updates requested.')

  patch_request = (
      self.messages.AiplatformProjectsLocationsTensorboardsExperimentsRunsPatchRequest(
          name=tensorboard_run_ref.RelativeName(),
          googleCloudAiplatformV1alpha1TensorboardRun=tensorboard_run,
          updateMask=','.join(mask_paths)))
  return self._service.Patch(patch_request)
def ParseUpdateLabels(client, job_ref, args):
  """Compute the label update for a job from the label flags in args.

  Args:
    client: the jobs client; provides job_class and Get().
    job_ref: reference of the job whose current labels may be needed.
    args: parsed arguments carrying the label update flags.

  Returns:
    The result of labels_util.ProcessUpdateArgsLazy.
  """
  labels_value_cls = client.job_class.LabelsValue
  # Pass the Get call as a thunk so ProcessUpdateArgsLazy can defer it.
  return labels_util.ProcessUpdateArgsLazy(
      args, labels_value_cls, lambda: client.Get(job_ref).labels)
def Run(self, args):
  """Update a managed zone's DNSSEC config, description and labels.

  Args:
    args: parsed arguments; dns_zone names the zone (any form accepted by the
      dns.managedZones resource parser), plus DNSSEC, description and label
      flags.

  Returns:
    The response of the managedZones Patch call.
  """
  dns = apis.GetClientInstance('dns', 'v1beta2')
  messages = apis.GetMessagesModule('dns', 'v1beta2')
  zone_ref = util.GetRegistry('v1beta2').Parse(
      args.dns_zone,
      params={
          'project': properties.VALUES.core.project.GetOrFail,
      },
      collection='dns.managedZones')
  # Use the parsed zone name everywhere. The raw flag value may be a full
  # resource path or URI, which is not a valid zone name or managedZone
  # request parameter.
  zone_name = zone_ref.managedZone

  dnssec_config = command_util.ParseDnssecConfigArgs(args, messages)
  zone_args = {'name': zone_name}
  if dnssec_config is not None:
    zone_args['dnssecConfig'] = dnssec_config
  if args.description is not None:
    zone_args['description'] = args.description
  zone = messages.ManagedZone(**zone_args)

  def GetLabels():
    return dns.managedZones.Get(
        messages.DnsManagedZonesGetRequest(
            project=zone_ref.project, managedZone=zone_name)).labels

  labels_update = labels_util.ProcessUpdateArgsLazy(
      args, messages.ManagedZone.LabelsValue, GetLabels)
  zone.labels = labels_update.GetOrNone()

  return dns.managedZones.Patch(
      messages.DnsManagedZonesPatchRequest(
          managedZoneResource=zone,
          project=zone_ref.project,
          managedZone=zone_name))
def Run(self, args):
  """Runs the topic update command.

  Args:
    args: an argparse namespace. All the arguments that were provided to this
      command invocation.

  Returns:
    A serialized object (dict) describing the results of the operation, or
    None when label flags were given but nothing needed to change.

  Raises:
    An HttpException if there was a problem calling the
    API topics.Patch command.
  """
  client = topics.TopicsClient()
  topic_ref = args.CONCEPTS.topic.Parse()

  def _OriginalLabels():
    return client.Get(topic_ref).labels

  labels_update = labels_util.ProcessUpdateArgsLazy(
      args, client.messages.Topic.LabelsValue,
      orig_labels_thunk=_OriginalLabels)

  try:
    result = client.Patch(topic_ref, labels=labels_update.GetOrNone())
  except topics.NoFieldsSpecifiedError:
    label_flags = ('clear_labels', 'update_labels', 'remove_labels')
    if not any(args.IsSpecified(flag) for flag in label_flags):
      raise
    log.status.Print('No update to perform.')
    return None

  log.UpdatedResource(topic_ref.RelativeName(), kind='topic')
  return result
def _Update(zones_client,
            args,
            private_visibility_config=None,
            forwarding_config=None,
            peering_config=None):
  """Helper function to perform the update.

  Args:
    zones_client: the managed zones API client.
    args: parsed arguments (zone resource, DNSSEC, description, label flags).
    private_visibility_config: optional zone visibility config.
    forwarding_config: optional zone forwarding config.
    peering_config: optional zone peering config.

  Returns:
    The result of zones_client.Patch.
  """
  zone_ref = args.CONCEPTS.zone.Parse()
  messages = zones_client.messages

  dnssec_config = command_util.ParseDnssecConfigArgs(args, messages)
  labels_update = labels_util.ProcessUpdateArgsLazy(
      args, messages.ManagedZone.LabelsValue,
      lambda: zones_client.Get(zone_ref).labels)

  # Only forward the optional configs that were actually provided (truthy).
  optional_configs = (
      ('private_visibility_config', private_visibility_config),
      ('forwarding_config', forwarding_config),
      ('peering_config', peering_config),
  )
  extra_kwargs = {name: value for name, value in optional_configs if value}

  return zones_client.Patch(
      zone_ref,
      dnssec_config=dnssec_config,
      description=args.description,
      labels=labels_update.GetOrNone(),
      **extra_kwargs)
def ParseUpdateLabels(models_client, args):
  """Compute the label update for a model from the label flags in args.

  Args:
    models_client: the models client; provides messages and Get().
    args: parsed arguments; args.model identifies the model.

  Returns:
    The result of labels_util.ProcessUpdateArgsLazy.
  """
  labels_value_cls = models_client.messages.GoogleCloudMlV1Model.LabelsValue
  # The Get call is wrapped in a thunk so it can be deferred.
  return labels_util.ProcessUpdateArgsLazy(
      args, labels_value_cls, lambda: models_client.Get(args.model).labels)
def PatchBeta(self, index_endpoint_ref, args):
  """Update an index endpoint.

  Args:
    index_endpoint_ref: resource reference of the index endpoint to patch.
    args: parsed arguments carrying display_name, description and label flags.

  Returns:
    The response of the index endpoints Patch call.

  Raises:
    errors.NoFieldsSpecifiedError: if no field would be changed.
  """
  endpoint_cls = self.messages.GoogleCloudAiplatformV1beta1IndexEndpoint
  index_endpoint = endpoint_cls()
  mask_paths = []

  for new_value, message_attr, mask_path in (
      (args.display_name, 'displayName', 'display_name'),
      (args.description, 'description', 'description'),
  ):
    if new_value is not None:
      setattr(index_endpoint, message_attr, new_value)
      mask_paths.append(mask_path)

  labels_update = labels_util.ProcessUpdateArgsLazy(
      args, endpoint_cls.LabelsValue,
      lambda: self.Get(index_endpoint_ref).labels)
  if labels_update.needs_update:
    index_endpoint.labels = labels_update.labels
    mask_paths.append('labels')

  if not mask_paths:
    raise errors.NoFieldsSpecifiedError('No updates requested.')

  patch_request = (
      self.messages.AiplatformProjectsLocationsIndexEndpointsPatchRequest(
          name=index_endpoint_ref.RelativeName(),
          googleCloudAiplatformV1beta1IndexEndpoint=index_endpoint,
          updateMask=','.join(mask_paths)))
  return self._service.Patch(patch_request)
def Run(self, args):
  """Runs the subscription update command.

  Args:
    args: an argparse namespace. All the arguments that were provided to this
      command invocation.

  Returns:
    A serialized object (dict) describing the results of the operation. This
    description fits the Resource described in the ResourceRegistry under
    'pubsub.projects.subscriptions'. None when label flags were given but
    nothing needed to change.

  Raises:
    An HttpException if there was a problem calling the
    API subscriptions.Patch command.
  """
  client = subscriptions.SubscriptionsClient()
  subscription_ref = args.CONCEPTS.subscription.Parse()

  # Read with getattr so an absent attribute yields None.
  dead_letter_topic = getattr(args, 'dead_letter_topic', None)
  max_delivery_attempts = getattr(args, 'max_delivery_attempts', None)
  clear_dead_letter_policy = getattr(args, 'clear_dead_letter_policy', None)

  def _CurrentLabels():
    return client.Get(subscription_ref).labels

  labels_update = labels_util.ProcessUpdateArgsLazy(
      args, client.messages.Subscription.LabelsValue,
      orig_labels_thunk=_CurrentLabels)

  # The "never expire" sentinel is sent as no_expiration=True with no period.
  expiration_period = getattr(args, 'expiration_period', None)
  no_expiration = False
  if (expiration_period and
      expiration_period == subscriptions.NEVER_EXPIRATION_PERIOD_VALUE):
    no_expiration = True
    expiration_period = None

  try:
    result = client.Patch(
        subscription_ref,
        ack_deadline=args.ack_deadline,
        push_config=util.ParsePushConfig(args),
        retain_acked_messages=args.retain_acked_messages,
        labels=labels_update.GetOrNone(),
        message_retention_duration=args.message_retention_duration,
        no_expiration=no_expiration,
        expiration_period=expiration_period,
        dead_letter_topic=dead_letter_topic,
        max_delivery_attempts=max_delivery_attempts,
        clear_dead_letter_policy=clear_dead_letter_policy)
  except subscriptions.NoFieldsSpecifiedError:
    label_flags = ('clear_labels', 'update_labels', 'remove_labels')
    if not any(args.IsSpecified(flag) for flag in label_flags):
      raise
    log.status.Print('No update to perform.')
    return None

  log.UpdatedResource(subscription_ref.RelativeName(), kind='subscription')
  return result
def _Update(zones_client, args):
  """Patch a managed zone's DNSSEC config, description and labels.

  Args:
    zones_client: the managed zones API client.
    args: parsed arguments (zone resource, DNSSEC, description, label flags).

  Returns:
    The result of zones_client.Patch.
  """
  zone_ref = args.CONCEPTS.zone.Parse()
  messages = zones_client.messages
  dnssec_config = command_util.ParseDnssecConfigArgs(args, messages)

  def _CurrentLabels():
    return zones_client.Get(zone_ref).labels

  labels_update = labels_util.ProcessUpdateArgsLazy(
      args, messages.ManagedZone.LabelsValue, _CurrentLabels)
  new_labels = labels_update.GetOrNone()

  return zones_client.Patch(
      zone_ref,
      dnssec_config=dnssec_config,
      description=args.description,
      labels=new_labels)
def Run(self, args):
  """Runs the topic update command.

  Args:
    args: an argparse namespace. All the arguments that were provided to this
      command invocation.

  Returns:
    A serialized object (dict) describing the results of the operation, or
    None when update flags were given but nothing needed to change.

  Raises:
    An HttpException if there was a problem calling the
    API topics.Patch command.
  """
  client = topics.TopicsClient()
  topic_ref = args.CONCEPTS.topic.Parse()

  # A truthy retention duration is normalized via util.FormatDuration;
  # falsy values are passed through unchanged.
  raw_retention = getattr(args, 'message_retention_duration', None)
  message_retention_duration = (
      util.FormatDuration(raw_retention) if raw_retention else raw_retention)
  clear_message_retention_duration = getattr(
      args, 'clear_message_retention_duration', None)

  labels_update = labels_util.ProcessUpdateArgsLazy(
      args, client.messages.Topic.LabelsValue,
      orig_labels_thunk=lambda: client.Get(topic_ref).labels)

  try:
    result = client.Patch(topic_ref, labels_update.GetOrNone(),
                          _GetKmsKeyNameFromArgs(args),
                          message_retention_duration,
                          clear_message_retention_duration,
                          args.recompute_message_storage_policy,
                          args.message_storage_policy_allowed_regions)
  except topics.NoFieldsSpecifiedError:
    no_op_flags = (
        'clear_labels', 'update_labels', 'remove_labels',
        'recompute_message_storage_policy',
        'message_storage_policy_allowed_regions'
    )
    if not any(args.IsSpecified(flag) for flag in no_op_flags):
      raise
    log.status.Print('No update to perform.')
    return None

  log.UpdatedResource(topic_ref.RelativeName(), kind='topic')
  return result
def ProcessUpdates(self, api_cofig, args):
  """Apply label and display-name flags to an API config message.

  NOTE: the parameter name 'api_cofig' (sic) is kept for keyword callers.

  Args:
    api_cofig: the API config message; it is modified in place.
    args: parsed arguments carrying the label flags and display_name.

  Returns:
    A tuple (api_cofig, update_mask) where update_mask is the comma-joined
    list of changed field names.
  """
  changed = []

  # The current labels come straight from the message passed in.
  labels_update = labels_util.ProcessUpdateArgsLazy(
      args, api_cofig.LabelsValue, lambda: api_cofig.labels)
  if labels_update.needs_update:
    api_cofig.labels = labels_update.labels
    changed.append('labels')

  # Only a truthy display name is applied.
  new_display_name = args.display_name
  if new_display_name:
    api_cofig.displayName = new_display_name
    changed.append('displayName')

  update_mask = ','.join(changed)
  return api_cofig, update_mask
def testProcessUpdateArgsLazy(self, args_string, original_labels,
                              expected_labels, needs_update):
  """Checks the labels and needs_update flag from ProcessUpdateArgsLazy.

  When original_labels is None, the test fails if the thunk is invoked.
  """
  parser = argparse.ArgumentParser()
  labels_util.AddUpdateLabelsFlags(parser)
  parsed_args = parser.parse_args(args_string.split())

  def _OrigLabelsThunk():
    if original_labels is None:
      self.fail('Should not call the orig_labels_thunk.')
    return self._MakeLabels(original_labels)

  update = labels_util.ProcessUpdateArgsLazy(parsed_args, self.labels_cls,
                                             _OrigLabelsThunk)

  self.assertEqual(update._labels, self._MakeLabels(expected_labels))
  self.assertEqual(update.needs_update, needs_update)
def PatchBeta(self, endpoint_ref, args):
  """Update an endpoint.

  Args:
    endpoint_ref: resource reference of the endpoint to patch.
    args: parsed arguments carrying the label flags, display_name,
      traffic_split (a dict of key -> value), clear_traffic_split and
      description.

  Returns:
    The response of the endpoints Patch call.

  Raises:
    errors.NoFieldsSpecifiedError: if no field would be changed.
  """
  endpoint_cls = self.messages.GoogleCloudAiplatformV1beta1Endpoint
  endpoint = endpoint_cls()
  update_mask = []

  def GetLabels():
    return self.Get(endpoint_ref).labels

  labels_update = labels_util.ProcessUpdateArgsLazy(
      args, endpoint_cls.LabelsValue, GetLabels)
  if labels_update.needs_update:
    endpoint.labels = labels_update.labels
    update_mask.append('labels')

  if args.display_name is not None:
    endpoint.displayName = args.display_name
    update_mask.append('display_name')

  if args.traffic_split is not None:
    # The nested AdditionalProperty class is reachable directly; there is no
    # need to instantiate TrafficSplitValue just to get at it.
    traffic_split_cls = endpoint.TrafficSplitValue
    endpoint.trafficSplit = traffic_split_cls(additionalProperties=[
        traffic_split_cls.AdditionalProperty(key=key, value=value)
        for key, value in sorted(args.traffic_split.items())
    ])
  if args.clear_traffic_split:
    # Clearing takes precedence when both flags are present.
    endpoint.trafficSplit = None
  if args.traffic_split is not None or args.clear_traffic_split:
    # Add the path once, even if both traffic-split flags were given.
    update_mask.append('traffic_split')

  if args.description is not None:
    endpoint.description = args.description
    update_mask.append('description')

  if not update_mask:
    raise errors.NoFieldsSpecifiedError('No updates requested.')

  req = self.messages.AiplatformProjectsLocationsEndpointsPatchRequest(
      name=endpoint_ref.RelativeName(),
      googleCloudAiplatformV1beta1Endpoint=endpoint,
      updateMask=','.join(update_mask))
  return self.client.projects_locations_endpoints.Patch(req)
def Run(self, args):
  """Runs the subscription update command.

  Args:
    args: an argparse namespace. All the arguments that were provided to this
      command invocation.

  Returns:
    A serialized object (dict) describing the results of the operation. This
    description fits the Resource described in the ResourceRegistry under
    'pubsub.projects.subscriptions'. None when label flags were given but
    nothing needed to change.

  Raises:
    An HttpException if there was a problem calling the
    API subscriptions.Patch command.
  """
  client = subscriptions.SubscriptionsClient()
  subscription_ref = args.CONCEPTS.subscription.Parse()

  def _CurrentLabels():
    return client.Get(subscription_ref).labels

  labels_update = labels_util.ProcessUpdateArgsLazy(
      args, client.messages.Subscription.LabelsValue,
      orig_labels_thunk=_CurrentLabels)

  try:
    result = client.Patch(
        subscription_ref,
        ack_deadline=args.ack_deadline,
        push_config=util.ParsePushConfig(args.push_endpoint),
        retain_acked_messages=args.retain_acked_messages,
        labels=labels_update.GetOrNone(),
        message_retention_duration=args.message_retention_duration)
  except subscriptions.NoFieldsSpecifiedError:
    label_flags = ('clear_labels', 'update_labels', 'remove_labels')
    if not any(args.IsSpecified(flag) for flag in label_flags):
      raise
    log.status.Print('No update to perform.')
    return None

  log.UpdatedResource(subscription_ref.RelativeName(), kind='subscription')
  return result
def ProcessUpdates(self, gateway, args):
  """Apply label, API-config and display-name flags to a gateway message.

  Args:
    gateway: the gateway message; it is modified in place.
    args: parsed arguments with the api_config concept, label flags and
      display_name.

  Returns:
    A tuple (gateway, update_mask) where update_mask is the comma-joined
    list of changed field names.
  """
  api_config_ref = args.CONCEPTS.api_config.Parse()
  changed = []

  labels_update = labels_util.ProcessUpdateArgsLazy(
      args, gateway.LabelsValue, lambda: gateway.labels)
  if labels_update.needs_update:
    gateway.labels = labels_update.labels
    changed.append('labels')

  if api_config_ref:
    gateway.apiConfig = api_config_ref.RelativeName()
    changed.append('apiConfig')

  # Only a truthy display name is applied.
  new_display_name = args.display_name
  if new_display_name:
    gateway.displayName = new_display_name
    changed.append('displayName')

  update_mask = ','.join(changed)
  return gateway, update_mask
def Patch(self, tensorboard_ref, args):
  """Update a Tensorboard.

  Args:
    tensorboard_ref: resource reference of the Tensorboard to patch.
    args: parsed arguments carrying the label flags, display_name and
      description.

  Returns:
    The response of the Tensorboards Patch call.

  Raises:
    errors.NoFieldsSpecifiedError: if no field would be changed.
  """
  # constants.ALPHA_VERSION uses the v1alpha1 message and request field; any
  # other version uses v1beta1.
  if self._version == constants.ALPHA_VERSION:
    tensorboard_cls = self.messages.GoogleCloudAiplatformV1alpha1Tensorboard
    body_field = 'googleCloudAiplatformV1alpha1Tensorboard'
  else:
    tensorboard_cls = self.messages.GoogleCloudAiplatformV1beta1Tensorboard
    body_field = 'googleCloudAiplatformV1beta1Tensorboard'
  tensorboard = tensorboard_cls()
  mask_paths = []

  labels_update = labels_util.ProcessUpdateArgsLazy(
      args, tensorboard_cls.LabelsValue,
      lambda: self.Get(tensorboard_ref).labels)
  if labels_update.needs_update:
    tensorboard.labels = labels_update.labels
    mask_paths.append('labels')

  if args.display_name is not None:
    tensorboard.displayName = args.display_name
    mask_paths.append('display_name')

  if args.description is not None:
    tensorboard.description = args.description
    mask_paths.append('description')

  if not mask_paths:
    raise errors.NoFieldsSpecifiedError('No updates requested.')

  request_kwargs = {
      'name': tensorboard_ref.RelativeName(),
      body_field: tensorboard,
      'updateMask': ','.join(mask_paths),
  }
  req = self.messages.AiplatformProjectsLocationsTensorboardsPatchRequest(
      **request_kwargs)
  return self._service.Patch(req)
def Patch(args):
  """Update an instance config.

  Args:
    args: parsed arguments (config, display_name, etag, validate_only and the
      label flags).

  Returns:
    The response of the instanceConfigs Patch call.

  Raises:
    errors.NoFieldsSpecifiedError: if neither the display name nor the labels
      would change.
  """
  client = apis.GetClientInstance('spanner', 'v1')
  msgs = apis.GetMessagesModule('spanner', 'v1')
  ref = resources.REGISTRY.Parse(
      args.config,
      params={'projectsId': properties.VALUES.core.project.GetOrFail},
      collection='spanner.projects.instanceConfigs')
  config_name = ref.RelativeName()

  instance_config = msgs.InstanceConfig(name=config_name)
  mask_paths = []

  if args.display_name is not None:
    instance_config.displayName = args.display_name
    mask_paths.append('display_name')

  # The etag is sent with the message but is not part of the update mask.
  if args.etag is not None:
    instance_config.etag = args.etag

  def _FetchLabels():
    get_request = msgs.SpannerProjectsInstanceConfigsGetRequest(
        name=config_name)
    return client.projects_instanceConfigs.Get(get_request).labels

  labels_update = labels_util.ProcessUpdateArgsLazy(
      args, msgs.InstanceConfig.LabelsValue, _FetchLabels)
  if labels_update.needs_update:
    instance_config.labels = labels_update.labels
    mask_paths.append('labels')

  if not mask_paths:
    raise errors.NoFieldsSpecifiedError('No updates requested.')

  update_request = msgs.UpdateInstanceConfigRequest(
      instanceConfig=instance_config,
      updateMask=','.join(mask_paths),
      validateOnly=args.validate_only)
  patch_request = msgs.SpannerProjectsInstanceConfigsPatchRequest(
      name=config_name, updateInstanceConfigRequest=update_request)
  return client.projects_instanceConfigs.Patch(patch_request)
def ParseUpdateLabels(client, get_result, args):
  """Compute the label update for a version from the label flags in args.

  Args:
    client: client exposing version_class.
    get_result: object whose GetAttrThunk('labels') supplies the current
      labels on demand.
    args: parsed arguments carrying the label update flags.

  Returns:
    The result of labels_util.ProcessUpdateArgsLazy.
  """
  labels_value_cls = client.version_class.LabelsValue
  current_labels_thunk = get_result.GetAttrThunk('labels')
  return labels_util.ProcessUpdateArgsLazy(
      args, labels_value_cls, current_labels_thunk)
def Run(self, args):
  """Update a Dataproc cluster according to the command-line flags.

  Args:
    args: the parsed command-line arguments.

  Returns:
    The cluster fetched after the update operation finishes, or None when the
    update is started asynchronously.

  Raises:
    exceptions.ArgumentError: if no updatable cluster parameter was given.
  """
  dataproc = dp.Dataproc(self.ReleaseTrack())
  cluster_ref = util.ParseCluster(args.name, dataproc)

  cluster_config = dataproc.messages.ClusterConfig()
  changed_fields = []
  has_changes = False

  if args.num_workers is not None:
    worker_config = dataproc.messages.InstanceGroupConfig(
        numInstances=args.num_workers)
    cluster_config.workerConfig = worker_config
    changed_fields.append('config.worker_config.num_instances')
    has_changes = True

  if args.num_preemptible_workers is not None:
    worker_config = dataproc.messages.InstanceGroupConfig(
        numInstances=args.num_preemptible_workers)
    cluster_config.secondaryWorkerConfig = worker_config
    changed_fields.append(
        'config.secondary_worker_config.num_instances')
    has_changes = True

  if self.ReleaseTrack() == base.ReleaseTrack.BETA:
    if args.autoscaling_policy:
      cluster_config.autoscalingConfig = dataproc.messages.AutoscalingConfig(
          policyUri=args.CONCEPTS.autoscaling_policy.Parse().RelativeName())
      changed_fields.append('config.autoscaling_config.policy_uri')
      has_changes = True
    elif args.autoscaling_policy == '' or args.disable_autoscaling:  # pylint: disable=g-explicit-bool-comparison
      # Disabling autoscaling. Don't need to explicitly set
      # cluster_config.autoscaling_config to None.
      changed_fields.append('config.autoscaling_config.policy_uri')
      has_changes = True

  lifecycle_config = dataproc.messages.LifecycleConfig()
  changed_config = False
  if args.max_age is not None:
    lifecycle_config.autoDeleteTtl = str(args.max_age) + 's'
    changed_fields.append('config.lifecycle_config.auto_delete_ttl')
    changed_config = True
  if args.expiration_time is not None:
    lifecycle_config.autoDeleteTime = times.FormatDateTime(
        args.expiration_time)
    changed_fields.append('config.lifecycle_config.auto_delete_time')
    changed_config = True
  if args.max_idle is not None:
    lifecycle_config.idleDeleteTtl = str(args.max_idle) + 's'
    changed_fields.append('config.lifecycle_config.idle_delete_ttl')
    changed_config = True
  if args.no_max_age:
    lifecycle_config.autoDeleteTtl = None
    changed_fields.append('config.lifecycle_config.auto_delete_ttl')
    changed_config = True
  if args.no_max_idle:
    lifecycle_config.idleDeleteTtl = None
    changed_fields.append('config.lifecycle_config.idle_delete_ttl')
    changed_config = True
  if changed_config:
    cluster_config.lifecycleConfig = lifecycle_config
    has_changes = True

  # Put in a thunk so we only make this call if needed.
  def _GetCurrentLabels():
    # We need to fetch the cluster first so we know what the labels look
    # like. labels_util fills out the proto with all the updates and
    # removals, but first we need to provide the current state of the labels.
    get_cluster_request = (
        dataproc.messages.DataprocProjectsRegionsClustersGetRequest(
            projectId=cluster_ref.projectId,
            region=cluster_ref.region,
            clusterName=cluster_ref.clusterName))
    current_cluster = dataproc.client.projects_regions_clusters.Get(
        get_cluster_request)
    return current_cluster.labels

  labels_update = labels_util.ProcessUpdateArgsLazy(
      args, dataproc.messages.Cluster.LabelsValue,
      orig_labels_thunk=_GetCurrentLabels)
  if labels_update.needs_update:
    has_changes = True
    changed_fields.append('labels')
  labels = labels_update.GetOrNone()

  if not has_changes:
    raise exceptions.ArgumentError(
        'Must specify at least one cluster parameter to update.')

  cluster = dataproc.messages.Cluster(
      config=cluster_config,
      clusterName=cluster_ref.clusterName,
      labels=labels,
      projectId=cluster_ref.projectId)

  request = dataproc.messages.DataprocProjectsRegionsClustersPatchRequest(
      clusterName=cluster_ref.clusterName,
      region=cluster_ref.region,
      projectId=cluster_ref.projectId,
      cluster=cluster,
      updateMask=','.join(changed_fields),
      requestId=util.GetUniqueId())

  if args.graceful_decommission_timeout is not None:
    request.gracefulDecommissionTimeout = (
        str(args.graceful_decommission_timeout) + 's')

  operation = dataproc.client.projects_regions_clusters.Patch(request)

  # `async` has been a reserved keyword since Python 3.7, so the attribute
  # access `args.async` is a SyntaxError there. getattr reads the same
  # attribute without spelling the keyword.
  if getattr(args, 'async'):
    log.status.write('Updating [{0}] with operation [{1}].'.format(
        cluster_ref, operation.name))
    return

  util.WaitForOperation(
      dataproc,
      operation,
      message='Waiting for cluster update operation',
      timeout_s=args.timeout)

  request = dataproc.messages.DataprocProjectsRegionsClustersGetRequest(
      projectId=cluster_ref.projectId,
      region=cluster_ref.region,
      clusterName=cluster_ref.clusterName)
  cluster = dataproc.client.projects_regions_clusters.Get(request)
  log.UpdatedResource(cluster_ref)
  return cluster
def _Update(zones_client,
            args,
            private_visibility_config=None,
            forwarding_config=None,
            peering_config=None,
            reverse_lookup_config=None,
            cloud_logging_config=None,
            api_version='v1',
            cleared_fields=None):
  """Helper function to perform the update.

  Labels are sent with zones_client.UpdateLabels. All other fields go through
  a single zones_client.Patch call, which is skipped when no non-label field
  was requested.

  Args:
    zones_client: the managed zones API client.
    args: the args provided by the user on the command line.
    private_visibility_config: zone visibility config.
    forwarding_config: zone forwarding config.
    peering_config: zone peering config.
    reverse_lookup_config: zone reverse lookup config.
    cloud_logging_config: Stackdriver logging config.
    api_version: the API version of this request.
    cleared_fields: the fields that should be included in the request JSON as
      their default value (fields that are their default value will be omitted
      otherwise).

  Returns:
    A list with the UpdateLabels and/or Patch responses that were issued
    (possibly empty).
  """
  registry = util.GetRegistry(api_version)

  zone_ref = registry.Parse(
      args.zone,
      util.GetParamsForRegistry(api_version, args),
      collection='dns.managedZones')

  dnssec_config = command_util.ParseDnssecConfigArgs(args,
                                                     zones_client.messages,
                                                     api_version)
  labels_update = labels_util.ProcessUpdateArgsLazy(
      args, zones_client.messages.ManagedZone.LabelsValue,
      lambda: zones_client.Get(zone_ref).labels)

  update_results = []

  # Gate on needs_update, the explicit "labels changed" signal, rather than
  # on the truthiness of the returned LabelsValue message.
  if labels_update.needs_update:
    update_results.append(
        zones_client.UpdateLabels(zone_ref, labels_update.GetOrNone()))

  kwargs = {}
  if private_visibility_config:
    kwargs['private_visibility_config'] = private_visibility_config
  if forwarding_config:
    kwargs['forwarding_config'] = forwarding_config
  if peering_config:
    kwargs['peering_config'] = peering_config
  if reverse_lookup_config:
    kwargs['reverse_lookup_config'] = reverse_lookup_config
  if cloud_logging_config:
    kwargs['cloud_logging_config'] = cloud_logging_config

  # An empty description ('') is a deliberate request to clear it, so test
  # for None instead of truthiness. Otherwise `--description ""` would be
  # silently ignored.
  if dnssec_config or args.description is not None or kwargs:
    update_results.append(
        zones_client.Patch(
            zone_ref,
            args.async_,
            dnssec_config=dnssec_config,
            description=args.description,
            labels=None,
            cleared_fields=cleared_fields,
            **kwargs))

  return update_results
def Run(self, args):
  """Update a Dataproc cluster according to the command-line flags.

  Args:
    args: the parsed command-line arguments.

  Returns:
    The cluster fetched after the update operation finishes, or None when the
    update is started asynchronously.

  Raises:
    exceptions.ArgumentError: if no updatable cluster parameter was given, or
      if the current cluster has more than one driver pool.
  """
  dataproc = dp.Dataproc(self.ReleaseTrack())

  cluster_ref = args.CONCEPTS.cluster.Parse()

  cluster_config = dataproc.messages.ClusterConfig()
  changed_fields = []

  has_changes = False

  if args.num_workers is not None:
    worker_config = dataproc.messages.InstanceGroupConfig(
        numInstances=args.num_workers)
    cluster_config.workerConfig = worker_config
    changed_fields.append('config.worker_config.num_instances')
    has_changes = True

  num_secondary_workers = _FirstNonNone(args.num_preemptible_workers,
                                        args.num_secondary_workers)
  if num_secondary_workers is not None:
    worker_config = dataproc.messages.InstanceGroupConfig(
        numInstances=num_secondary_workers)
    cluster_config.secondaryWorkerConfig = worker_config
    changed_fields.append('config.secondary_worker_config.num_instances')
    has_changes = True

  if args.autoscaling_policy:
    cluster_config.autoscalingConfig = dataproc.messages.AutoscalingConfig(
        policyUri=args.CONCEPTS.autoscaling_policy.Parse().RelativeName())
    changed_fields.append('config.autoscaling_config.policy_uri')
    has_changes = True
  elif args.autoscaling_policy == '' or args.disable_autoscaling:  # pylint: disable=g-explicit-bool-comparison
    # Disabling autoscaling. Don't need to explicitly set
    # cluster_config.autoscaling_config to None.
    changed_fields.append('config.autoscaling_config.policy_uri')
    has_changes = True

  lifecycle_config = dataproc.messages.LifecycleConfig()
  changed_config = False
  if args.max_age is not None:
    lifecycle_config.autoDeleteTtl = six.text_type(args.max_age) + 's'
    changed_fields.append('config.lifecycle_config.auto_delete_ttl')
    changed_config = True
  if args.expiration_time is not None:
    lifecycle_config.autoDeleteTime = times.FormatDateTime(
        args.expiration_time)
    changed_fields.append('config.lifecycle_config.auto_delete_time')
    changed_config = True
  if args.max_idle is not None:
    lifecycle_config.idleDeleteTtl = six.text_type(args.max_idle) + 's'
    changed_fields.append('config.lifecycle_config.idle_delete_ttl')
    changed_config = True
  if args.no_max_age:
    lifecycle_config.autoDeleteTtl = None
    changed_fields.append('config.lifecycle_config.auto_delete_ttl')
    changed_config = True
  if args.no_max_idle:
    lifecycle_config.idleDeleteTtl = None
    changed_fields.append('config.lifecycle_config.idle_delete_ttl')
    changed_config = True
  if changed_config:
    cluster_config.lifecycleConfig = lifecycle_config
    has_changes = True

  # The current cluster is needed both for labels and for the auxiliary node
  # pool configs. Memoize it so at most one GET is issued and both updates
  # work from the same snapshot. A dict is used instead of `nonlocal` to stay
  # Python 2 compatible, matching the use of six in this block.
  current_cluster_cache = {}

  def _GetCurrentCluster():
    """Return the current cluster, issuing the GET request at most once."""
    if 'cluster' not in current_cluster_cache:
      get_cluster_request = (
          dataproc.messages.DataprocProjectsRegionsClustersGetRequest(
              projectId=cluster_ref.projectId,
              region=cluster_ref.region,
              clusterName=cluster_ref.clusterName))
      current_cluster_cache['cluster'] = (
          dataproc.client.projects_regions_clusters.Get(get_cluster_request))
    return current_cluster_cache['cluster']

  # Put in a thunk so we only make this call if needed.
  def _GetCurrentLabels():
    # We need to fetch the cluster first so we know what the labels look
    # like. labels_util fills out the proto with all the updates and
    # removals, but first we need to provide the current state of the labels.
    return _GetCurrentCluster().labels

  labels_update = labels_util.ProcessUpdateArgsLazy(
      args, dataproc.messages.Cluster.LabelsValue,
      orig_labels_thunk=_GetCurrentLabels)
  if labels_update.needs_update:
    has_changes = True
    changed_fields.append('labels')
  labels = labels_update.GetOrNone()

  if args.driver_pool_size is not None:
    # Node pool ids and other attributes not exposed to the user come from the
    # current cluster's node pools.
    # Driver pools can currently only be updated with NO other updates; we
    # rely on frontend validation to prevent this until the change is made to
    # allow driver pools to be updated with other fields.
    auxiliary_node_pools = (
        _GetCurrentCluster().config.auxiliaryNodePoolConfigs)

    # Indices of the current cluster's driver pools in the auxiliary node
    # pools list; at most one is expected.
    driver_role = (
        dataproc.messages.NodePoolConfig.RolesValueListEntryValuesEnum.DRIVER)
    index_driver_pools = [
        i for i, n in enumerate(auxiliary_node_pools) if driver_role in n.roles
    ]

    if len(index_driver_pools) > 1:
      raise exceptions.ArgumentError(
          'At most one driver pool can be specified per cluster.')
    elif len(index_driver_pools) == 1:
      index = index_driver_pools[0]
      auxiliary_node_pools[
          index].nodePoolConfig.numInstances = args.driver_pool_size
    else:
      # Only relevant for scaling from 0 -> N nodes. This is not supported
      # initially; frontend validation decides whether it is allowed.
      worker_config = dataproc.messages.InstanceGroupConfig(
          numInstances=args.driver_pool_size)
      node_config = dataproc.messages.NodePoolConfig(
          nodePoolConfig=worker_config, roles=[driver_role])
      auxiliary_node_pools.append(node_config)

    cluster_config.auxiliaryNodePoolConfigs = auxiliary_node_pools
    changed_fields.append('config.auxiliary_node_pool_configs')
    has_changes = True

  if not has_changes:
    raise exceptions.ArgumentError(
        'Must specify at least one cluster parameter to update.')

  cluster = dataproc.messages.Cluster(
      config=cluster_config,
      clusterName=cluster_ref.clusterName,
      labels=labels,
      projectId=cluster_ref.projectId)

  request = dataproc.messages.DataprocProjectsRegionsClustersPatchRequest(
      clusterName=cluster_ref.clusterName,
      region=cluster_ref.region,
      projectId=cluster_ref.projectId,
      cluster=cluster,
      updateMask=','.join(changed_fields),
      requestId=util.GetUniqueId())

  if args.graceful_decommission_timeout is not None:
    request.gracefulDecommissionTimeout = (
        six.text_type(args.graceful_decommission_timeout) + 's')

  operation = dataproc.client.projects_regions_clusters.Patch(request)

  if args.async_:
    log.status.write('Updating [{0}] with operation [{1}].'.format(
        cluster_ref, operation.name))
    return

  util.WaitForOperation(
      dataproc,
      operation,
      message='Waiting for cluster update operation',
      timeout_s=args.timeout)

  request = dataproc.messages.DataprocProjectsRegionsClustersGetRequest(
      projectId=cluster_ref.projectId,
      region=cluster_ref.region,
      clusterName=cluster_ref.clusterName)
  cluster = dataproc.client.projects_regions_clusters.Get(request)
  log.UpdatedResource(cluster_ref)
  return cluster