def _Run(args, version):
    """Patch an existing Vertex AI endpoint from the command-line flags.

    Args:
      args: the parsed command-line arguments.
      version: API version constant. constants.GA_VERSION selects the v1
        endpoint messages and EndpointsClient.Patch; any other value selects
        the v1beta1 messages and EndpointsClient.PatchBeta.

    Returns:
      The operation returned by the patch call, or None when update flags
      were given but resulted in nothing to change.

    Raises:
      errors.NoFieldsSpecifiedError: re-raised when the patch call reports no
        fields and none of the recognized update flags was specified.
    """
    validation.ValidateDisplayName(args.display_name)

    endpoint_ref = args.CONCEPTS.endpoint.Parse()
    args.region = endpoint_ref.AsDict()['locationsId']
    with endpoint_util.AiplatformEndpointOverrides(version, region=args.region):
        endpoints_client = client.EndpointsClient(version=version)

        def _FetchCurrentLabels():
            # Handed to labels_util so the endpoint is only fetched on demand.
            return endpoints_client.Get(endpoint_ref).labels

        try:
            if version == constants.GA_VERSION:
                patch_fn = endpoints_client.Patch
                labels_cls = (endpoints_client.messages.
                              GoogleCloudAiplatformV1Endpoint.LabelsValue)
            else:
                patch_fn = endpoints_client.PatchBeta
                labels_cls = (endpoints_client.messages.
                              GoogleCloudAiplatformV1beta1Endpoint.LabelsValue)
            op = patch_fn(
                endpoint_ref,
                labels_util.ProcessUpdateArgsLazy(args, labels_cls,
                                                  _FetchCurrentLabels),
                display_name=args.display_name,
                description=args.description,
                traffic_split=args.traffic_split,
                clear_traffic_split=args.clear_traffic_split)
        except errors.NoFieldsSpecifiedError:
            update_flags = ('display_name', 'traffic_split',
                            'clear_traffic_split', 'update_labels',
                            'clear_labels', 'remove_labels', 'description')
            if not any(args.IsSpecified(flag) for flag in update_flags):
                raise
            log.status.Print('No update to perform.')
            return None
        log.UpdatedResource(op.name, kind='Vertex AI endpoint')
        return op
示例#2
0
    def PatchAlpha(self, tensorboard_run_ref, args):
        """Update a Tensorboard run through the v1alpha1 messages.

        Only the fields the user asked to change are set on the message and
        listed in the request's update mask.

        Args:
          tensorboard_run_ref: resource reference of the run to update.
          args: parsed arguments carrying the label flags, display_name and
            description.

        Returns:
          The response of self._service.Patch.

        Raises:
          errors.NoFieldsSpecifiedError: if no field would be updated.
        """
        run_cls = self.messages.GoogleCloudAiplatformV1alpha1TensorboardRun
        run_message = run_cls()
        mask_fields = []

        labels_update = labels_util.ProcessUpdateArgsLazy(
            args, run_cls.LabelsValue,
            lambda: self.Get(tensorboard_run_ref).labels)
        if labels_update.needs_update:
            run_message.labels = labels_update.labels
            mask_fields.append('labels')

        if args.display_name is not None:
            run_message.displayName = args.display_name
            mask_fields.append('display_name')

        if args.description is not None:
            run_message.description = args.description
            mask_fields.append('description')

        if not mask_fields:
            raise errors.NoFieldsSpecifiedError('No updates requested.')

        patch_request = (
            self.messages.
            AiplatformProjectsLocationsTensorboardsExperimentsRunsPatchRequest(
                name=tensorboard_run_ref.RelativeName(),
                googleCloudAiplatformV1alpha1TensorboardRun=run_message,
                updateMask=','.join(mask_fields)))
        return self._service.Patch(patch_request)
def ParseUpdateLabels(client, job_ref, args):
    """Compute the label update for a job from the label flags in args.

    The job's current labels are fetched through client.Get only if label
    processing asks for them.

    Returns:
      The result of labels_util.ProcessUpdateArgsLazy.
    """
    return labels_util.ProcessUpdateArgsLazy(
        args,
        client.job_class.LabelsValue,
        lambda: client.Get(job_ref).labels)
示例#4
0
  def Run(self, args):
    """Update a Cloud DNS managed zone's description, DNSSEC config and labels.

    Args:
      args: parsed command-line arguments; args.dns_zone identifies the zone.

    Returns:
      The response of the v1beta2 managedZones.Patch call.
    """
    dns = apis.GetClientInstance('dns', 'v1beta2')
    messages = apis.GetMessagesModule('dns', 'v1beta2')

    zone_ref = util.GetRegistry('v1beta2').Parse(
        args.dns_zone,
        params={
            'project': properties.VALUES.core.project.GetOrFail,
        },
        collection='dns.managedZones')
    # Use the parsed zone name everywhere (not the raw argument) so that a
    # zone given as a URL or relative resource name still addresses the same
    # zone in the resource body, the labels lookup and the PATCH request.
    zone_name = zone_ref.managedZone

    dnssec_config = command_util.ParseDnssecConfigArgs(args, messages)
    zone_args = {'name': zone_name}
    if dnssec_config is not None:
      zone_args['dnssecConfig'] = dnssec_config
    if args.description is not None:
      zone_args['description'] = args.description
    zone = messages.ManagedZone(**zone_args)

    def _GetCurrentLabels():
      # Only called by labels_util when the existing labels are needed.
      return dns.managedZones.Get(
          messages.DnsManagedZonesGetRequest(
              project=zone_ref.project,
              managedZone=zone_name)).labels

    labels_update = labels_util.ProcessUpdateArgsLazy(
        args, messages.ManagedZone.LabelsValue, _GetCurrentLabels)
    zone.labels = labels_update.GetOrNone()

    return dns.managedZones.Patch(
        messages.DnsManagedZonesPatchRequest(managedZoneResource=zone,
                                             project=zone_ref.project,
                                             managedZone=zone_name))
示例#5
0
    def Run(self, args):
        """Update the labels of a Pub/Sub topic.

        Args:
          args: an argparse namespace. All the arguments that were provided to
            this command invocation.

        Returns:
          The result of the Patch call, or None when label flags were given
          but produced nothing to change.

        Raises:
          An HttpException if there was a problem calling the
          API topics.Patch command.
        """
        client = topics.TopicsClient()
        topic_ref = args.CONCEPTS.topic.Parse()

        def _CurrentLabels():
            return client.Get(topic_ref).labels

        labels_update = labels_util.ProcessUpdateArgsLazy(
            args,
            client.messages.Topic.LabelsValue,
            orig_labels_thunk=_CurrentLabels)
        try:
            result = client.Patch(topic_ref, labels=labels_update.GetOrNone())
        except topics.NoFieldsSpecifiedError:
            label_flags = ('clear_labels', 'update_labels', 'remove_labels')
            if not any(args.IsSpecified(flag) for flag in label_flags):
                raise
            log.status.Print('No update to perform.')
            return None
        log.UpdatedResource(topic_ref.RelativeName(), kind='topic')
        return result
示例#6
0
def _Update(zones_client,
            args,
            private_visibility_config=None,
            forwarding_config=None,
            peering_config=None):
    """Patch a managed zone from args plus optional networking configs.

    Args:
      zones_client: the managed zones API client.
      args: parsed command-line arguments with a `zone` resource concept.
      private_visibility_config: forwarded to Patch only when truthy.
      forwarding_config: forwarded to Patch only when truthy.
      peering_config: forwarded to Patch only when truthy.

    Returns:
      The result of zones_client.Patch.
    """
    zone_ref = args.CONCEPTS.zone.Parse()
    messages = zones_client.messages

    dnssec_config = command_util.ParseDnssecConfigArgs(args, messages)
    labels_update = labels_util.ProcessUpdateArgsLazy(
        args, messages.ManagedZone.LabelsValue,
        lambda: zones_client.Get(zone_ref).labels)

    candidate_configs = (
        ('private_visibility_config', private_visibility_config),
        ('forwarding_config', forwarding_config),
        ('peering_config', peering_config),
    )
    # Falsy configs are left out entirely rather than passed as None.
    extra_kwargs = {key: value for key, value in candidate_configs if value}
    return zones_client.Patch(zone_ref,
                              dnssec_config=dnssec_config,
                              description=args.description,
                              labels=labels_update.GetOrNone(),
                              **extra_kwargs)
def ParseUpdateLabels(models_client, args):
    """Compute the label update for a model from the label flags in args.

    The model named by args.model is fetched only if label processing needs
    its current labels.

    Returns:
      The result of labels_util.ProcessUpdateArgsLazy.
    """
    labels_cls = models_client.messages.GoogleCloudMlV1Model.LabelsValue
    return labels_util.ProcessUpdateArgsLazy(
        args, labels_cls, lambda: models_client.Get(args.model).labels)
    def PatchBeta(self, index_endpoint_ref, args):
        """Update an index endpoint through the v1beta1 messages.

        Args:
          index_endpoint_ref: resource reference of the index endpoint.
          args: parsed arguments carrying display_name, description and the
            label flags.

        Returns:
          The response of self._service.Patch.

        Raises:
          errors.NoFieldsSpecifiedError: if no field would be updated.
        """
        endpoint_cls = self.messages.GoogleCloudAiplatformV1beta1IndexEndpoint
        index_endpoint = endpoint_cls()
        mask_fields = []

        if args.display_name is not None:
            index_endpoint.displayName = args.display_name
            mask_fields.append('display_name')

        if args.description is not None:
            index_endpoint.description = args.description
            mask_fields.append('description')

        labels_update = labels_util.ProcessUpdateArgsLazy(
            args, endpoint_cls.LabelsValue,
            lambda: self.Get(index_endpoint_ref).labels)
        if labels_update.needs_update:
            index_endpoint.labels = labels_update.labels
            mask_fields.append('labels')

        if not mask_fields:
            raise errors.NoFieldsSpecifiedError('No updates requested.')

        patch_request = (
            self.messages.AiplatformProjectsLocationsIndexEndpointsPatchRequest(
                name=index_endpoint_ref.RelativeName(),
                googleCloudAiplatformV1beta1IndexEndpoint=index_endpoint,
                updateMask=','.join(mask_fields)))
        return self._service.Patch(patch_request)
示例#9
0
  def Run(self, args):
    """Update a Pub/Sub subscription from the command-line flags.

    Args:
      args: an argparse namespace. All the arguments that were provided to this
        command invocation.

    Returns:
      The result of the Patch call (a resource described in the
      ResourceRegistry under 'pubsub.projects.subscriptions'), or None when
      label flags were given but produced nothing to change.

    Raises:
      An HttpException if there was a problem calling the
      API subscriptions.Patch command.
    """
    client = subscriptions.SubscriptionsClient()
    subscription_ref = args.CONCEPTS.subscription.Parse()
    # These flags may be absent from args; default them to None.
    dead_letter_topic = getattr(args, 'dead_letter_topic', None)
    max_delivery_attempts = getattr(args, 'max_delivery_attempts', None)
    clear_dead_letter_policy = getattr(args, 'clear_dead_letter_policy', None)

    labels_update = labels_util.ProcessUpdateArgsLazy(
        args, client.messages.Subscription.LabelsValue,
        orig_labels_thunk=lambda: client.Get(subscription_ref).labels)

    # The "never" sentinel is sent as no_expiration rather than as a period.
    expiration_period = getattr(args, 'expiration_period', None)
    no_expiration = bool(
        expiration_period and
        expiration_period == subscriptions.NEVER_EXPIRATION_PERIOD_VALUE)
    if no_expiration:
      expiration_period = None

    try:
      result = client.Patch(
          subscription_ref,
          ack_deadline=args.ack_deadline,
          push_config=util.ParsePushConfig(args),
          retain_acked_messages=args.retain_acked_messages,
          labels=labels_update.GetOrNone(),
          message_retention_duration=args.message_retention_duration,
          no_expiration=no_expiration,
          expiration_period=expiration_period,
          dead_letter_topic=dead_letter_topic,
          max_delivery_attempts=max_delivery_attempts,
          clear_dead_letter_policy=clear_dead_letter_policy)
    except subscriptions.NoFieldsSpecifiedError:
      label_flags = ('clear_labels', 'update_labels', 'remove_labels')
      if not any(args.IsSpecified(flag) for flag in label_flags):
        raise
      log.status.Print('No update to perform.')
      return None
    log.UpdatedResource(subscription_ref.RelativeName(), kind='subscription')
    return result
示例#10
0
def _Update(zones_client, args):
    """Patch a managed zone's DNSSEC config, description and labels.

    Args:
      zones_client: the managed zones API client.
      args: parsed command-line arguments with a `zone` resource concept.

    Returns:
      The result of zones_client.Patch.
    """
    zone_ref = args.CONCEPTS.zone.Parse()
    messages = zones_client.messages

    dnssec_config = command_util.ParseDnssecConfigArgs(args, messages)

    def _CurrentLabels():
        return zones_client.Get(zone_ref).labels

    labels_update = labels_util.ProcessUpdateArgsLazy(
        args, messages.ManagedZone.LabelsValue, _CurrentLabels)

    return zones_client.Patch(
        zone_ref,
        dnssec_config=dnssec_config,
        description=args.description,
        labels=labels_update.GetOrNone())
示例#11
0
    def Run(self, args):
        """Update a Pub/Sub topic's labels, KMS key, retention and storage policy.

        Args:
          args: an argparse namespace. All the arguments that were provided to
            this command invocation.

        Returns:
          The result of the Patch call, or None when update flags were given
          but produced nothing to change.

        Raises:
          An HttpException if there was a problem calling the
          API topics.Patch command.
        """
        client = topics.TopicsClient()
        topic_ref = args.CONCEPTS.topic.Parse()

        retention = getattr(args, 'message_retention_duration', None)
        if retention:
            retention = util.FormatDuration(retention)
        clear_retention = getattr(args, 'clear_message_retention_duration',
                                  None)

        labels_update = labels_util.ProcessUpdateArgsLazy(
            args,
            client.messages.Topic.LabelsValue,
            orig_labels_thunk=lambda: client.Get(topic_ref).labels)

        try:
            result = client.Patch(topic_ref,
                                  labels_update.GetOrNone(),
                                  _GetKmsKeyNameFromArgs(args),
                                  retention,
                                  clear_retention,
                                  args.recompute_message_storage_policy,
                                  args.message_storage_policy_allowed_regions)
        except topics.NoFieldsSpecifiedError:
            update_flags = ('clear_labels', 'update_labels', 'remove_labels',
                            'recompute_message_storage_policy',
                            'message_storage_policy_allowed_regions')
            if not any(args.IsSpecified(flag) for flag in update_flags):
                raise
            log.status.Print('No update to perform.')
            return None
        log.UpdatedResource(topic_ref.RelativeName(), kind='topic')
        return result
示例#12
0
    def ProcessUpdates(self, api_cofig, args):
        """Apply the label and display-name flags to an API config message.

        Args:
          api_cofig: the API config message, modified in place; its current
            labels are the baseline for the label flags.
          args: parsed arguments carrying the label flags and display_name.

        Returns:
          A (api_cofig, update_mask) tuple, where update_mask is a
          comma-separated string of the changed field names.
        """
        changed_fields = []

        def _ExistingLabels():
            return api_cofig.labels

        labels_update = labels_util.ProcessUpdateArgsLazy(
            args, api_cofig.LabelsValue, _ExistingLabels)
        if labels_update.needs_update:
            api_cofig.labels = labels_update.labels
            changed_fields.append('labels')

        # A truthy check: an empty display name is treated as "not given".
        if args.display_name:
            api_cofig.displayName = args.display_name
            changed_fields.append('displayName')

        update_mask = ','.join(changed_fields)
        return api_cofig, update_mask
  def testProcessUpdateArgsLazy(self, args_string, original_labels,
                                expected_labels, needs_update):
    """Check ProcessUpdateArgsLazy's labels, needs_update flag and laziness.

    When original_labels is None, the labels thunk fails the test if called,
    asserting that these flags must not require fetching the old labels.
    """
    parser = argparse.ArgumentParser()
    labels_util.AddUpdateLabelsFlags(parser)
    parsed_args = parser.parse_args(args_string.split())

    def _OrigLabelsThunk():
      if original_labels is None:
        self.fail('Should not call the orig_labels_thunk.')
      return self._MakeLabels(original_labels)

    update = labels_util.ProcessUpdateArgsLazy(parsed_args, self.labels_cls,
                                               _OrigLabelsThunk)
    want_labels = self._MakeLabels(expected_labels)
    self.assertEqual(update._labels, want_labels)
    self.assertEqual(update.needs_update, needs_update)
示例#14
0
    def PatchBeta(self, endpoint_ref, args):
        """Update an endpoint through the v1beta1 messages.

        Args:
          endpoint_ref: resource reference of the endpoint to update.
          args: parsed arguments carrying the label flags, display_name,
            traffic_split (a mapping whose items become TrafficSplitValue
            entries, sorted by key), clear_traffic_split and description.

        Returns:
          The response of client.projects_locations_endpoints.Patch.

        Raises:
          errors.NoFieldsSpecifiedError: if no field would be updated.
        """
        endpoint_cls = self.messages.GoogleCloudAiplatformV1beta1Endpoint
        endpoint = endpoint_cls()
        mask_fields = []

        labels_update = labels_util.ProcessUpdateArgsLazy(
            args, endpoint_cls.LabelsValue,
            lambda: self.Get(endpoint_ref).labels)
        if labels_update.needs_update:
            endpoint.labels = labels_update.labels
            mask_fields.append('labels')

        if args.display_name is not None:
            endpoint.displayName = args.display_name
            mask_fields.append('display_name')

        if args.traffic_split is not None:
            split_cls = endpoint.TrafficSplitValue
            endpoint.trafficSplit = split_cls(additionalProperties=[
                split_cls.AdditionalProperty(key=key, value=value)
                for key, value in sorted(args.traffic_split.items())
            ])
            mask_fields.append('traffic_split')

        if args.clear_traffic_split:
            # Clearing is expressed as an unset field named in the mask.
            endpoint.trafficSplit = None
            mask_fields.append('traffic_split')

        if args.description is not None:
            endpoint.description = args.description
            mask_fields.append('description')

        if not mask_fields:
            raise errors.NoFieldsSpecifiedError('No updates requested.')

        patch_request = (
            self.messages.AiplatformProjectsLocationsEndpointsPatchRequest(
                name=endpoint_ref.RelativeName(),
                googleCloudAiplatformV1beta1Endpoint=endpoint,
                updateMask=','.join(mask_fields)))
        return self.client.projects_locations_endpoints.Patch(patch_request)
示例#15
0
    def Run(self, args):
        """Update a Pub/Sub subscription from the command-line flags.

        Args:
          args: an argparse namespace. All the arguments that were provided to
            this command invocation.

        Returns:
          The result of the Patch call (a resource described in the
          ResourceRegistry under 'pubsub.projects.subscriptions'), or None
          when label flags were given but produced nothing to change.

        Raises:
          An HttpException if there was a problem calling the
          API subscriptions.Patch command.
        """
        client = subscriptions.SubscriptionsClient()
        subscription_ref = args.CONCEPTS.subscription.Parse()

        def _CurrentLabels():
            return client.Get(subscription_ref).labels

        labels_update = labels_util.ProcessUpdateArgsLazy(
            args,
            client.messages.Subscription.LabelsValue,
            orig_labels_thunk=_CurrentLabels)
        try:
            result = client.Patch(
                subscription_ref,
                ack_deadline=args.ack_deadline,
                push_config=util.ParsePushConfig(args.push_endpoint),
                retain_acked_messages=args.retain_acked_messages,
                labels=labels_update.GetOrNone(),
                message_retention_duration=args.message_retention_duration)
        except subscriptions.NoFieldsSpecifiedError:
            label_flags = ('clear_labels', 'update_labels', 'remove_labels')
            if not any(args.IsSpecified(flag) for flag in label_flags):
                raise
            log.status.Print('No update to perform.')
            return None
        log.UpdatedResource(subscription_ref.RelativeName(),
                            kind='subscription')
        return result
示例#16
0
    def ProcessUpdates(self, gateway, args):
        """Apply the label, API-config and display-name flags to a gateway.

        Args:
          gateway: the gateway message, modified in place; its current labels
            are the baseline for the label flags.
          args: parsed arguments with an api_config resource concept, the
            label flags and display_name.

        Returns:
          A (gateway, update_mask) tuple, where update_mask is a
          comma-separated string of the changed field names.
        """
        api_config_ref = args.CONCEPTS.api_config.Parse()
        changed_fields = []

        def _ExistingLabels():
            return gateway.labels

        labels_update = labels_util.ProcessUpdateArgsLazy(
            args, gateway.LabelsValue, _ExistingLabels)
        if labels_update.needs_update:
            gateway.labels = labels_update.labels
            changed_fields.append('labels')

        if api_config_ref:
            gateway.apiConfig = api_config_ref.RelativeName()
            changed_fields.append('apiConfig')

        # A truthy check: an empty display name is treated as "not given".
        if args.display_name:
            gateway.displayName = args.display_name
            changed_fields.append('displayName')

        update_mask = ','.join(changed_fields)
        return gateway, update_mask
示例#17
0
  def Patch(self, tensorboard_ref, args):
    """Update a Tensorboard's labels, display name and description.

    The v1alpha1 messages are used when self._version is
    constants.ALPHA_VERSION; otherwise the v1beta1 messages are used.

    Args:
      tensorboard_ref: resource reference of the Tensorboard to update.
      args: parsed arguments carrying the label flags, display_name and
        description.

    Returns:
      The response of self._service.Patch.

    Raises:
      errors.NoFieldsSpecifiedError: if no field would be updated.
    """
    if self._version == constants.ALPHA_VERSION:
      tensorboard_cls = self.messages.GoogleCloudAiplatformV1alpha1Tensorboard
      body_field = 'googleCloudAiplatformV1alpha1Tensorboard'
    else:
      tensorboard_cls = self.messages.GoogleCloudAiplatformV1beta1Tensorboard
      body_field = 'googleCloudAiplatformV1beta1Tensorboard'
    tensorboard = tensorboard_cls()

    mask_fields = []

    labels_update = labels_util.ProcessUpdateArgsLazy(
        args, tensorboard_cls.LabelsValue,
        lambda: self.Get(tensorboard_ref).labels)

    if labels_update.needs_update:
      tensorboard.labels = labels_update.labels
      mask_fields.append('labels')

    if args.display_name is not None:
      tensorboard.displayName = args.display_name
      mask_fields.append('display_name')

    if args.description is not None:
      tensorboard.description = args.description
      mask_fields.append('description')

    if not mask_fields:
      raise errors.NoFieldsSpecifiedError('No updates requested.')

    # The request message names its body field after the API version.
    patch_request = (
        self.messages.AiplatformProjectsLocationsTensorboardsPatchRequest(
            name=tensorboard_ref.RelativeName(),
            updateMask=','.join(mask_fields),
            **{body_field: tensorboard}))
    return self._service.Patch(patch_request)
示例#18
0
def Patch(args):
    """Update a Spanner instance config's display name and labels.

    Args:
      args: parsed arguments carrying config, display_name, etag,
        validate_only and the label flags.

    Returns:
      The response of client.projects_instanceConfigs.Patch.

    Raises:
      errors.NoFieldsSpecifiedError: if neither display name nor labels would
        change.
    """
    client = apis.GetClientInstance('spanner', 'v1')
    msgs = apis.GetMessagesModule('spanner', 'v1')
    config_ref = resources.REGISTRY.Parse(
        args.config,
        params={'projectsId': properties.VALUES.core.project.GetOrFail},
        collection='spanner.projects.instanceConfigs')
    config_name = config_ref.RelativeName()
    instance_config = msgs.InstanceConfig(name=config_name)

    mask_fields = []

    if args.display_name is not None:
        instance_config.displayName = args.display_name
        mask_fields.append('display_name')

    # The etag is sent on the message but is not listed in the update mask.
    if args.etag is not None:
        instance_config.etag = args.etag

    def _CurrentLabels():
        get_request = msgs.SpannerProjectsInstanceConfigsGetRequest(
            name=config_name)
        return client.projects_instanceConfigs.Get(get_request).labels

    labels_update = labels_util.ProcessUpdateArgsLazy(
        args, msgs.InstanceConfig.LabelsValue, _CurrentLabels)
    if labels_update.needs_update:
        instance_config.labels = labels_update.labels
        mask_fields.append('labels')

    if not mask_fields:
        raise errors.NoFieldsSpecifiedError('No updates requested.')

    update_request = msgs.UpdateInstanceConfigRequest(
        instanceConfig=instance_config,
        updateMask=','.join(mask_fields),
        validateOnly=args.validate_only)
    patch_request = msgs.SpannerProjectsInstanceConfigsPatchRequest(
        name=config_name, updateInstanceConfigRequest=update_request)
    return client.projects_instanceConfigs.Patch(patch_request)
示例#19
0
def ParseUpdateLabels(client, get_result, args):
  """Compute the label update for a version from the label flags in args.

  The current labels come from get_result.GetAttrThunk('labels'), which is
  handed to labels_util as the lazy original-labels thunk.

  Returns:
    The result of labels_util.ProcessUpdateArgsLazy.
  """
  labels_cls = client.version_class.LabelsValue
  labels_thunk = get_result.GetAttrThunk('labels')
  return labels_util.ProcessUpdateArgsLazy(args, labels_cls, labels_thunk)
示例#20
0
    def Run(self, args):
        """Update a Dataproc cluster from the command-line flags.

        Worker counts and labels can always be updated. On the BETA release
        track, autoscaling and lifecycle (auto-delete) settings can also be
        changed. Only the fields that change are listed in the update mask.

        Args:
          args: parsed command-line arguments.

        Returns:
          The cluster as fetched after the update operation completes, or None
          when the command was run asynchronously.

        Raises:
          exceptions.ArgumentError: if no updatable cluster parameter was
            given.
        """
        dataproc = dp.Dataproc(self.ReleaseTrack())

        cluster_ref = util.ParseCluster(args.name, dataproc)

        cluster_config = dataproc.messages.ClusterConfig()
        changed_fields = []

        has_changes = False

        if args.num_workers is not None:
            worker_config = dataproc.messages.InstanceGroupConfig(
                numInstances=args.num_workers)
            cluster_config.workerConfig = worker_config
            changed_fields.append('config.worker_config.num_instances')
            has_changes = True

        if args.num_preemptible_workers is not None:
            worker_config = dataproc.messages.InstanceGroupConfig(
                numInstances=args.num_preemptible_workers)
            cluster_config.secondaryWorkerConfig = worker_config
            changed_fields.append(
                'config.secondary_worker_config.num_instances')
            has_changes = True

        if self.ReleaseTrack() == base.ReleaseTrack.BETA:
            if args.autoscaling_policy:
                cluster_config.autoscalingConfig = dataproc.messages.AutoscalingConfig(
                    policyUri=args.CONCEPTS.autoscaling_policy.Parse(
                    ).RelativeName())
                changed_fields.append('config.autoscaling_config.policy_uri')
                has_changes = True
            elif args.autoscaling_policy == '' or args.disable_autoscaling:  # pylint: disable=g-explicit-bool-comparison
                # Disabling autoscaling. Don't need to explicitly set
                # cluster_config.autoscaling_config to None.
                changed_fields.append('config.autoscaling_config.policy_uri')
                has_changes = True

            lifecycle_config = dataproc.messages.LifecycleConfig()
            changed_config = False
            if args.max_age is not None:
                lifecycle_config.autoDeleteTtl = str(args.max_age) + 's'
                changed_fields.append(
                    'config.lifecycle_config.auto_delete_ttl')
                changed_config = True
            if args.expiration_time is not None:
                lifecycle_config.autoDeleteTime = times.FormatDateTime(
                    args.expiration_time)
                changed_fields.append(
                    'config.lifecycle_config.auto_delete_time')
                changed_config = True
            if args.max_idle is not None:
                lifecycle_config.idleDeleteTtl = str(args.max_idle) + 's'
                changed_fields.append(
                    'config.lifecycle_config.idle_delete_ttl')
                changed_config = True
            if args.no_max_age:
                lifecycle_config.autoDeleteTtl = None
                changed_fields.append(
                    'config.lifecycle_config.auto_delete_ttl')
                changed_config = True
            if args.no_max_idle:
                lifecycle_config.idleDeleteTtl = None
                changed_fields.append(
                    'config.lifecycle_config.idle_delete_ttl')
                changed_config = True
            if changed_config:
                cluster_config.lifecycleConfig = lifecycle_config
                has_changes = True

        # Put in a thunk so we only make this call if needed
        def _GetCurrentLabels():
            # We need to fetch cluster first so we know what the labels look like. The
            # labels_util will fill out the proto for us with all the updates and
            # removals, but first we need to provide the current state of the labels
            get_cluster_request = (
                dataproc.messages.DataprocProjectsRegionsClustersGetRequest(
                    projectId=cluster_ref.projectId,
                    region=cluster_ref.region,
                    clusterName=cluster_ref.clusterName))
            current_cluster = dataproc.client.projects_regions_clusters.Get(
                get_cluster_request)
            return current_cluster.labels

        labels_update = labels_util.ProcessUpdateArgsLazy(
            args,
            dataproc.messages.Cluster.LabelsValue,
            orig_labels_thunk=_GetCurrentLabels)
        if labels_update.needs_update:
            has_changes = True
            changed_fields.append('labels')
        labels = labels_update.GetOrNone()

        if not has_changes:
            raise exceptions.ArgumentError(
                'Must specify at least one cluster parameter to update.')

        cluster = dataproc.messages.Cluster(
            config=cluster_config,
            clusterName=cluster_ref.clusterName,
            labels=labels,
            projectId=cluster_ref.projectId)

        request = dataproc.messages.DataprocProjectsRegionsClustersPatchRequest(
            clusterName=cluster_ref.clusterName,
            region=cluster_ref.region,
            projectId=cluster_ref.projectId,
            cluster=cluster,
            updateMask=','.join(changed_fields),
            requestId=util.GetUniqueId())

        if args.graceful_decommission_timeout is not None:
            request.gracefulDecommissionTimeout = (
                str(args.graceful_decommission_timeout) + 's')

        operation = dataproc.client.projects_regions_clusters.Patch(request)

        # `async` is a reserved keyword since Python 3.7, so the attribute
        # syntax `args.async` no longer parses. Read the flag via getattr,
        # accepting either the `async_` or the legacy `async` destination.
        is_async = getattr(args, 'async_', None)
        if is_async is None:
            is_async = getattr(args, 'async', False)
        if is_async:
            log.status.write('Updating [{0}] with operation [{1}].'.format(
                cluster_ref, operation.name))
            return

        util.WaitForOperation(dataproc,
                              operation,
                              message='Waiting for cluster update operation',
                              timeout_s=args.timeout)

        request = dataproc.messages.DataprocProjectsRegionsClustersGetRequest(
            projectId=cluster_ref.projectId,
            region=cluster_ref.region,
            clusterName=cluster_ref.clusterName)
        cluster = dataproc.client.projects_regions_clusters.Get(request)
        log.UpdatedResource(cluster_ref)
        return cluster
def _Update(zones_client,
            args,
            private_visibility_config=None,
            forwarding_config=None,
            peering_config=None,
            reverse_lookup_config=None,
            cloud_logging_config=None,
            api_version='v1',
            cleared_fields=None):
    """Helper function to perform the update.

    Label changes are sent with zones_client.UpdateLabels. Every other change
    (DNSSEC config, description and the optional configs below) is sent in a
    single zones_client.Patch call. Each call is skipped when it has nothing
    to send.

    Args:
      zones_client: the managed zones API client.
      args: the args provided by the user on the command line.
      private_visibility_config: zone visibility config.
      forwarding_config: zone forwarding config.
      peering_config: zone peering config.
      reverse_lookup_config: zone reverse lookup config.
      cloud_logging_config: Stackdriver logging config.
      api_version: the API version of this request.
      cleared_fields: the fields that should be included in the request JSON as
        their default value (fields that are their default value will be
        omitted otherwise).

    Returns:
      A list with the UpdateLabels response (if labels changed) followed by
      the Patch response (if anything else changed). The list is empty when
      neither call was made.
    """
    registry = util.GetRegistry(api_version)

    zone_ref = registry.Parse(args.zone,
                              util.GetParamsForRegistry(api_version, args),
                              collection='dns.managedZones')

    dnssec_config = command_util.ParseDnssecConfigArgs(args,
                                                       zones_client.messages,
                                                       api_version)
    # The current labels are supplied through a thunk, so the zone GET is only
    # made if label processing needs it.
    labels_update = labels_util.ProcessUpdateArgsLazy(
        args, zones_client.messages.ManagedZone.LabelsValue,
        lambda: zones_client.Get(zone_ref).labels)

    update_results = []

    # Decide on the explicit needs_update flag rather than the truthiness of
    # the labels message, and fetch the new labels only once.
    if labels_update.needs_update:
        update_results.append(
            zones_client.UpdateLabels(zone_ref, labels_update.GetOrNone()))

    # Forward only the optional configs that were actually provided.
    optional_configs = (
        ('private_visibility_config', private_visibility_config),
        ('forwarding_config', forwarding_config),
        ('peering_config', peering_config),
        ('reverse_lookup_config', reverse_lookup_config),
        ('cloud_logging_config', cloud_logging_config),
    )
    kwargs = {name: value for name, value in optional_configs if value}

    # NOTE(review): an empty description is falsy, so on its own it does not
    # trigger a Patch. Confirm whether clearing the description is meant to go
    # through cleared_fields instead.
    if dnssec_config or args.description or kwargs:
        update_results.append(
            zones_client.Patch(zone_ref,
                               args.async_,
                               dnssec_config=dnssec_config,
                               description=args.description,
                               labels=None,
                               cleared_fields=cleared_fields,
                               **kwargs))

    return update_results
示例#22
0
    def Run(self, args):
        """Send a PATCH for the cluster described by the command-line flags.

        Each supplied flag writes its value into the cluster message and adds
        its path to the update mask. Label changes and args.driver_pool_size
        depend on the cluster's current state, which is fetched with a GET
        on demand.

        Args:
          args: the parsed command-line arguments.

        Returns:
          The cluster as read back after the update operation finishes, or
          None when args.async_ is set (the operation is only started).

        Raises:
          exceptions.ArgumentError: if no updatable flag was given, or if the
            current cluster already has more than one driver node pool.
        """
        dataproc = dp.Dataproc(self.ReleaseTrack())
        cluster_ref = args.CONCEPTS.cluster.Parse()
        messages = dataproc.messages

        def _FetchCluster():
            """Return the cluster as currently stored on the server."""
            get_request = messages.DataprocProjectsRegionsClustersGetRequest(
                projectId=cluster_ref.projectId,
                region=cluster_ref.region,
                clusterName=cluster_ref.clusterName)
            return dataproc.client.projects_regions_clusters.Get(get_request)

        config = messages.ClusterConfig()
        # Paths for the PATCH updateMask, in the order the flags are handled.
        # If the list is still empty at the end, the user asked for no change.
        mask_paths = []

        # Primary worker count.
        if args.num_workers is not None:
            config.workerConfig = messages.InstanceGroupConfig(
                numInstances=args.num_workers)
            mask_paths.append('config.worker_config.num_instances')

        # Secondary worker count. _FirstNonNone receives
        # num_preemptible_workers first, then num_secondary_workers.
        secondary_count = _FirstNonNone(args.num_preemptible_workers,
                                        args.num_secondary_workers)
        if secondary_count is not None:
            config.secondaryWorkerConfig = messages.InstanceGroupConfig(
                numInstances=secondary_count)
            mask_paths.append('config.secondary_worker_config.num_instances')

        # Autoscaling: a non-empty policy is attached. An explicitly empty
        # policy or disable_autoscaling detaches it, and in that case only the
        # mask path is needed; autoscalingConfig is deliberately left unset.
        if args.autoscaling_policy:
            policy = args.CONCEPTS.autoscaling_policy.Parse()
            config.autoscalingConfig = messages.AutoscalingConfig(
                policyUri=policy.RelativeName())
            mask_paths.append('config.autoscaling_config.policy_uri')
        elif args.autoscaling_policy == '' or args.disable_autoscaling:  # pylint: disable=g-explicit-bool-comparison
            mask_paths.append('config.autoscaling_config.policy_uri')

        # Lifecycle settings. The no_max_* flags are processed after the value
        # flags, so they take effect last when both are present.
        lifecycle = messages.LifecycleConfig()
        lifecycle_paths = []
        if args.max_age is not None:
            lifecycle.autoDeleteTtl = six.text_type(args.max_age) + 's'
            lifecycle_paths.append('config.lifecycle_config.auto_delete_ttl')
        if args.expiration_time is not None:
            lifecycle.autoDeleteTime = times.FormatDateTime(
                args.expiration_time)
            lifecycle_paths.append('config.lifecycle_config.auto_delete_time')
        if args.max_idle is not None:
            lifecycle.idleDeleteTtl = six.text_type(args.max_idle) + 's'
            lifecycle_paths.append('config.lifecycle_config.idle_delete_ttl')
        if args.no_max_age:
            lifecycle.autoDeleteTtl = None
            lifecycle_paths.append('config.lifecycle_config.auto_delete_ttl')
        if args.no_max_idle:
            lifecycle.idleDeleteTtl = None
            lifecycle_paths.append('config.lifecycle_config.idle_delete_ttl')
        if lifecycle_paths:
            config.lifecycleConfig = lifecycle
            mask_paths.extend(lifecycle_paths)

        # Labels: the current labels are provided through a thunk, so the GET
        # is only issued when label processing needs them.
        labels_update = labels_util.ProcessUpdateArgsLazy(
            args,
            messages.Cluster.LabelsValue,
            orig_labels_thunk=lambda: _FetchCluster().labels)
        if labels_update.needs_update:
            mask_paths.append('labels')
        labels = labels_update.GetOrNone()

        # Driver pool size. The full auxiliary node pool list is sent back, so
        # it is read from the current cluster first. Combining a driver pool
        # update with other updates is expected to be rejected by frontend
        # validation, not here.
        if args.driver_pool_size is not None:
            node_pools = _FetchCluster().config.auxiliaryNodePoolConfigs
            driver_role = (
                messages.NodePoolConfig.RolesValueListEntryValuesEnum.DRIVER)
            driver_indices = [
                position for position, pool in enumerate(node_pools)
                if driver_role in pool.roles
            ]

            if len(driver_indices) > 1:
                raise exceptions.ArgumentError(
                    'At most one driver pool can be specified per cluster.')
            if driver_indices:
                node_pools[driver_indices[0]].nodePoolConfig.numInstances = (
                    args.driver_pool_size)
            else:
                # Scaling from 0 -> N driver nodes: add a new driver pool and
                # let frontend validation decide whether it is allowed.
                node_pools.append(
                    messages.NodePoolConfig(
                        nodePoolConfig=messages.InstanceGroupConfig(
                            numInstances=args.driver_pool_size),
                        roles=[driver_role]))

            config.auxiliaryNodePoolConfigs = node_pools
            mask_paths.append('config.auxiliary_node_pool_configs')

        if not mask_paths:
            raise exceptions.ArgumentError(
                'Must specify at least one cluster parameter to update.')

        patch_request = messages.DataprocProjectsRegionsClustersPatchRequest(
            clusterName=cluster_ref.clusterName,
            region=cluster_ref.region,
            projectId=cluster_ref.projectId,
            cluster=messages.Cluster(
                config=config,
                clusterName=cluster_ref.clusterName,
                labels=labels,
                projectId=cluster_ref.projectId),
            updateMask=','.join(mask_paths),
            requestId=util.GetUniqueId())

        decommission_timeout = args.graceful_decommission_timeout
        if decommission_timeout is not None:
            patch_request.gracefulDecommissionTimeout = (
                six.text_type(decommission_timeout) + 's')

        operation = dataproc.client.projects_regions_clusters.Patch(
            patch_request)

        if args.async_:
            log.status.write('Updating [{0}] with operation [{1}].'.format(
                cluster_ref, operation.name))
            return None

        util.WaitForOperation(dataproc,
                              operation,
                              message='Waiting for cluster update operation',
                              timeout_s=args.timeout)

        updated_cluster = _FetchCluster()
        log.UpdatedResource(cluster_ref)
        return updated_cluster