Esempio n. 1
0
    def create_cluster(self, project_id, region, cluster):
        """Register a new mock cluster in the ``CREATING`` state.

        *cluster* may be a ``Cluster`` or a dict of keyword arguments for
        one. Raises ``InvalidArgument`` if the cluster names a different
        project or has no name, and ``AlreadyExists`` if a cluster is
        already stored under the same (project, region, name) key.
        """
        self._check_region_matches_endpoint(region)

        # accept plain dicts as well as Cluster objects
        if not isinstance(cluster, Cluster):
            cluster = Cluster(**cluster)

        # fill in a missing project ID; reject a conflicting one
        if not cluster.project_id:
            cluster.project_id = project_id
        elif cluster.project_id != project_id:
            raise InvalidArgument(
                'If provided, CreateClusterRequest.cluster.project_id must'
                ' match CreateClusterRequest.project_id')

        if not cluster.cluster_name:
            raise InvalidArgument('Cluster name is required')

        # every new cluster starts out as CREATING
        cluster.status.state = _cluster_state_value('CREATING')

        key = (project_id, region, cluster.cluster_name)
        if key in self.mock_clusters:
            raise AlreadyExists(
                'Already exists: Cluster ' + _cluster_path(*key))

        self.mock_clusters[key] = cluster
def _parse_attributes(values: AttributeValues) -> str:
    """Decode the single value held in *values* as UTF-8 text.

    Raises InvalidArgument when *values* does not hold exactly one value,
    or when that value is not valid UTF-8.
    """
    entries = values.values
    if len(entries) != 1:
        raise InvalidArgument(
            "Received an unparseable message with multiple values for an attribute."
        )
    raw: bytes = entries[0]
    try:
        text = raw.decode("utf-8")
    except UnicodeError:
        raise InvalidArgument(
            "Received an unparseable message with a non-utf8 attribute.")
    return text
Esempio n. 3
0
    def create_cluster(self, project_id, region, cluster):
        """Register a new mock cluster after applying disk and scope defaults.

        *cluster* may be a ``Cluster`` or a dict of keyword arguments for
        one. Populated instance-group configs get a default boot disk size,
        and the service account scopes are completed and sorted. Raises
        ``InvalidArgument`` if the cluster names a different project or has
        no name, and ``AlreadyExists`` if the (project, region, name) key is
        already taken.
        """
        self._check_region_matches_endpoint(region)

        # accept plain dicts as well as Cluster objects
        if not isinstance(cluster, Cluster):
            cluster = Cluster(**cluster)

        # fill in a missing project ID; reject a conflicting one
        if not cluster.project_id:
            cluster.project_id = project_id
        elif cluster.project_id != project_id:
            raise InvalidArgument(
                'If provided, CreateClusterRequest.cluster.project_id must'
                ' match CreateClusterRequest.project_id')

        if not cluster.cluster_name:
            raise InvalidArgument('Cluster name is required')

        # make sure each populated instance group has a boot disk size
        for role in ('master', 'worker', 'secondary_worker'):
            group = getattr(cluster.config, role + '_config', None)
            # empty configs can still be truthy, hence the extra str() check
            if not (group and str(group)):
                continue
            if not group.disk_config:
                group.disk_config = DiskConfig()
            if not group.disk_config.boot_disk_size_gb:
                group.disk_config.boot_disk_size_gb = _DEFAULT_DISK_SIZE_GB

        gce_config = cluster.config.gce_cluster_config

        # defaults apply only when no scopes were requested; mandatory
        # scopes are always added. Stored sorted, replacing in place.
        requested = set(gce_config.service_account_scopes)
        scopes = (requested or set(_DEFAULT_SCOPES)) | set(_MANDATORY_SCOPES)
        gce_config.service_account_scopes[:] = sorted(scopes)

        # every new cluster starts out as CREATING
        cluster.status.state = _cluster_state_value('CREATING')

        key = (project_id, region, cluster.cluster_name)
        if key in self.mock_clusters:
            raise AlreadyExists(
                'Already exists: Cluster ' + _cluster_path(*key))

        self.mock_clusters[key] = cluster
Esempio n. 4
0
def decode_attribute_event_time(attr: str) -> datetime.datetime:
    """Convert an event time attribute string to a datetime.

    The string is loaded with ``Timestamp.FromJsonString`` and converted
    with ``ToDatetime``. A ValueError from either step is reported as
    InvalidArgument.
    """
    timestamp = Timestamp()
    try:
        timestamp.FromJsonString(attr)
        return timestamp.ToDatetime()
    except ValueError:
        raise InvalidArgument("Invalid value for event time attribute.")
Esempio n. 5
0
 def add_id_to_message(source: SequencedMessage):
     """Transform *source* and stamp the result with an encoded message ID.

     The ID is built from the enclosing scope's ``partition`` and the
     source cursor. Raises InvalidArgument if the transformer already set
     ``message_id``.
     """
     transformed: PubsubMessage = transformer.transform(source)
     if not transformed.message_id:
         metadata = MessageMetadata(partition, source.cursor)
         transformed.message_id = metadata.encode()
         return transformed
     raise InvalidArgument(
         "Message after transforming has the message_id field set.")
Esempio n. 6
0
    def _check_region_matches_endpoint(self, region):
        """Raise InvalidArgument unless *region* is this endpoint's region.

        The permitted region comes from ``self._expected_region()``.
        """
        permitted = self._expected_region()
        if region == permitted:
            return

        raise InvalidArgument(
            "Region '%s' invalid or not supported by this endpoint;"
            " permitted regions: '[%s]'" % (region, permitted))
Esempio n. 7
0
 def __init__(
     self,
     instance_id,
     client,
     configuration_name=None,
     node_count=None,
     display_name=None,
     emulator_host=None,
     labels=None,
     processing_units=None,
 ):
     """Store the instance settings and reconcile its compute capacity.

     Capacity may be given as ``node_count``, as ``processing_units``, or
     as both when they agree (units == nodes * PROCESSING_UNITS_PER_NODE).
     With neither, DEFAULT_NODE_COUNT is used. ``display_name`` falls back
     to ``instance_id`` and ``labels`` to a new empty dict.

     Raises InvalidArgument when both capacities are given but disagree.
     """
     self.instance_id = instance_id
     self._client = client
     self.configuration_name = configuration_name

     have_nodes = node_count is not None
     have_units = processing_units is not None

     # both given: they must describe the same capacity
     if have_nodes and have_units:
         if processing_units != node_count * PROCESSING_UNITS_PER_NODE:
             raise InvalidArgument(
                 "Only one of node count and processing units can be set."
             )

     if have_nodes:
         self._node_count = node_count
         self._processing_units = node_count * PROCESSING_UNITS_PER_NODE
     elif have_units:
         self._processing_units = processing_units
         self._node_count = processing_units // PROCESSING_UNITS_PER_NODE
     else:
         self._node_count = DEFAULT_NODE_COUNT
         self._processing_units = DEFAULT_NODE_COUNT * PROCESSING_UNITS_PER_NODE

     self.display_name = display_name or instance_id
     self.emulator_host = emulator_host
     self.labels = {} if labels is None else labels
Esempio n. 8
0
 def parse(to_parse: str) -> "TopicPath":
     """Build a TopicPath from ``projects/<p>/locations/<zone>/topics/<name>``.

     The location segment is parsed with ``CloudZone.parse``. Raises
     InvalidArgument when *to_parse* does not have that six-segment layout.
     """
     parts = to_parse.split("/")
     well_formed = (len(parts) == 6 and parts[0] == "projects"
                    and parts[2] == "locations" and parts[4] == "topics")
     if not well_formed:
         raise InvalidArgument(
             "Topic path must be formatted like projects/{project_number}/locations/{location}/topics/{name} but was instead "
             + to_parse)
     _, project, _, location, _, name = parts
     return TopicPath(project, CloudZone.parse(location), name)
Esempio n. 9
0
 def parse(to_parse: str) -> "LocationPath":
     """Parse a path of the form ``projects/{project_number}/locations/{location}``.

     Args:
         to_parse: Path string made of exactly four "/"-separated segments.

     Returns:
         A LocationPath built from the project segment and the location
         segment, the latter resolved by ``_parse_location``.

     Raises:
         InvalidArgument: If the path does not have the documented layout.
     """
     splits = to_parse.split("/")
     # The documented format has exactly four segments (only indexes 1 and 3
     # are consumed below). Requiring six rejected every valid location path.
     if len(splits) != 4 or splits[0] != "projects" or splits[2] != "locations":
         raise InvalidArgument(
             "Location path must be formatted like projects/{project_number}/locations/{location} but was instead "
             + to_parse
         )
     return LocationPath(splits[1], _parse_location(splits[3]))
Esempio n. 10
0
def to_cps_publish_message(source: PubSubMessage) -> PubsubMessage:
    """Convert a wire-format message into a publishable PubsubMessage.

    The key becomes the ordering key (decoded as UTF-8), the data is copied,
    and each attribute is reduced to a single string by
    ``_parse_attributes``. If *source* has an event time, it is encoded into
    the PUBSUB_LITE_EVENT_TIME attribute.

    Raises InvalidArgument if the key is not UTF-8 or if the wire message
    already carries the PUBSUB_LITE_EVENT_TIME attribute.
    """
    message = PubsubMessage()
    try:
        message.ordering_key = source.key.decode("utf-8")
    except UnicodeError:
        raise InvalidArgument(
            "Received an unparseable message with a non-utf8 key.")

    wire_attributes = source.attributes
    # the reserved attribute is written below, so it must not come in
    if PUBSUB_LITE_EVENT_TIME in wire_attributes:
        raise InvalidArgument(
            "Special timestamp attribute exists in wire message. Unable to parse message."
        )

    message.data = source.data
    for name, attribute_values in wire_attributes.items():
        message.attributes[name] = _parse_attributes(attribute_values)

    if "event_time" in source:
        encoded_time = encode_attribute_event_time(source.event_time)
        message.attributes[PUBSUB_LITE_EVENT_TIME] = encoded_time
    return message
Esempio n. 11
0
        def count_write_wrapper(self, *args, **kwargs):
            """Count this write, then delegate to the wrapped ``func``.

            Raises InvalidArgument once the running count exceeds
            ``MAX_WRITES_PER_BATCH``. The count is incremented even when the
            call is rejected. The wrapped function's return value is not
            propagated.
            """
            self.__total_writes += 1

            limit_exceeded = self.__total_writes > self.MAX_WRITES_PER_BATCH
            if not limit_exceeded:
                func(self, *args, **kwargs)
                return

            raise InvalidArgument(
                f"Maximum {self.MAX_WRITES_PER_BATCH} writes allowed per request"
            )
Esempio n. 12
0
 def parse(to_parse: str) -> "ReservationPath":
     """Build a ReservationPath from its six-segment string form.

     Expected layout:
     ``projects/<p>/locations/<region>/reservations/<name>``. The location
     segment is parsed with ``CloudRegion.parse``. Raises InvalidArgument
     when *to_parse* does not have that layout.
     """
     parts = to_parse.split("/")
     well_formed = (len(parts) == 6 and parts[0] == "projects"
                    and parts[2] == "locations"
                    and parts[4] == "reservations")
     if not well_formed:
         raise InvalidArgument(
             "Reservation path must be formatted like projects/{project_number}/locations/{location}/reservations/{name} but was instead "
             + to_parse)
     _, project, _, region, _, name = parts
     return ReservationPath(project, CloudRegion.parse(region), name)
def _decode_attribute_event_time_proto(attr: str) -> Timestamp:
    """Rebuild a Timestamp from an event time attribute string.

    ``fast_serialize.load`` is expected to yield an indexable value whose
    first two items are the seconds and nanos. Any failure while loading or
    assigning is reported as InvalidArgument.
    """
    try:
        parts = fast_serialize.load(attr)
        timestamp = Timestamp()
        timestamp.seconds = parts[0]
        timestamp.nanos = parts[1]
    except Exception:  # noqa: E722
        raise InvalidArgument("Invalid value for event time attribute.")
    return timestamp
 def add_id_to_message(source: SequencedMessage):
     """Transform *source* and stamp the result with an encoded message ID.

     Works on the underlying ``_pb`` objects of both messages. The ID is
     encoded from the enclosing scope's ``partition.value`` and the source
     cursor offset. Raises InvalidArgument if the transformer already set
     ``message_id``.
     """
     raw_source = source._pb
     transformed: PubsubMessage = transformer.transform(source)
     raw_transformed = transformed._pb
     if not raw_transformed.message_id:
         offset = raw_source.cursor.offset
         raw_transformed.message_id = MessageMetadata._encode_parts(
             partition.value, offset)
         return transformed
     raise InvalidArgument(
         "Message after transforming has the message_id field set.")
Esempio n. 15
0
def _parse_location(to_parse: str) -> Union[CloudRegion, CloudZone]:
    """Parse *to_parse* as a zone, falling back to a region.

    Raises InvalidArgument when neither ``CloudZone.parse`` nor
    ``CloudRegion.parse`` accepts the string.
    """
    # zone first: it is the more specific form
    for location_type in (CloudZone, CloudRegion):
        try:
            return location_type.parse(to_parse)
        except InvalidArgument:
            continue
    raise InvalidArgument("Invalid location name: " + to_parse)
Esempio n. 16
0
    def submit_job(self, project_id, region, job):
        """Store a mock job in the ``SETUP_DONE`` state and return a copy.

        *job* may be a ``Job`` or a dict of keyword arguments for one. The
        job must carry its own job ID, name an existing mock cluster, and be
        a hadoop job.

        Raises:
            NotImplementedError: no project/job ID, or not a hadoop job.
            InvalidArgument: no cluster name, or a conflicting project ID
                in the job reference.
            NotFound: the named cluster is not in ``self.mock_clusters``.
            AlreadyExists: a job with the same (project, region, ID) key
                is already stored.
        """
        self._check_region_matches_endpoint(region)

        # accept plain dicts as well as Job objects
        if not isinstance(job, Job):
            job = Job(**job)

        if not project_id or not job.reference.job_id:
            raise NotImplementedError('generation of job IDs not implemented')
        job_id = job.reference.job_id

        target_cluster = job.placement.cluster_name
        if not target_cluster:
            raise InvalidArgument('Cluster name is required')

        # the job can only be placed on a cluster we know about
        cluster_key = (project_id, region, target_cluster)
        if cluster_key not in self.mock_clusters:
            raise NotFound('Not Found: Cluster ' + _cluster_path(*cluster_key))

        if not job.hadoop_job:
            raise NotImplementedError('only hadoop jobs are supported')

        # fill in a missing project ID; reject a conflicting one
        if not job.reference.project_id:
            job.reference.project_id = project_id
        elif job.reference.project_id != project_id:
            raise InvalidArgument(
                'If provided, SubmitJobRequest.job.job_reference'
                '.project_id must match SubmitJobRequest.project_id')

        job.status.state = _job_state_value('SETUP_DONE')

        job_key = (project_id, region, job_id)
        if job_key in self.mock_jobs:
            raise AlreadyExists(
                'Already exists: Job ' + _job_path(*job_key))

        self.mock_jobs[job_key] = job

        # hand back a copy so callers cannot mutate the stored job
        return deepcopy(job)
Esempio n. 17
0
 def parse(to_parse: str) -> "SubscriptionPath":
     """Build a SubscriptionPath from its six-segment string form.

     Expected layout:
     ``projects/<p>/locations/<location>/subscriptions/<name>``. The
     location segment is resolved by ``_parse_location``. Raises
     InvalidArgument when *to_parse* does not have that layout.
     """
     parts = to_parse.split("/")
     well_formed = (
         len(parts) == 6
         and parts[0] == "projects"
         and parts[2] == "locations"
         and parts[4] == "subscriptions"
     )
     if not well_formed:
         raise InvalidArgument(
             "Subscription path must be formatted like projects/{project_number}/locations/{location}/subscriptions/{name} but was instead "
             + to_parse
         )
     _, project, _, location, _, name = parts
     return SubscriptionPath(project, _parse_location(location), name)
Esempio n. 18
0
 def seek_subscription(
     self,
     subscription_path: SubscriptionPath,
     target: Union[BacklogLocation, PublishTime, EventTime],
 ) -> Operation:
     """Issue a seek on *subscription_path* via the underlying client.

     A PublishTime or EventTime target becomes a time target. A
     BacklogLocation becomes a named target: ``END`` maps to ``HEAD`` and
     any other location to ``TAIL``.

     Raises InvalidArgument when *target* is none of those types.
     """
     seek_request = SeekSubscriptionRequest(name=str(subscription_path))

     if isinstance(target, PublishTime):
         seek_request.time_target = TimeTarget(publish_time=target.value)
     elif isinstance(target, EventTime):
         seek_request.time_target = TimeTarget(event_time=target.value)
     elif isinstance(target, BacklogLocation):
         named = SeekSubscriptionRequest.NamedTarget
         at_end = target == BacklogLocation.END
         seek_request.named_target = named.HEAD if at_end else named.TAIL
     else:
         raise InvalidArgument("A valid seek target must be specified.")

     return self._underlying.seek_subscription(request=seek_request)
Esempio n. 19
0
def test_span_creation_error(setup):
    """An error raised inside create_span() propagates out of the span.

    The span is opened with a client for project "test_project" and
    location "test_location". The test checks the span's name and
    attributes, then raises InvalidArgument inside it and expects a
    GoogleAPICallError to escape.
    """
    import google.auth.credentials
    from google.cloud.bigquery import client
    from google.api_core.exceptions import GoogleAPICallError, InvalidArgument

    credentials = mock.Mock(spec=google.auth.credentials.Credentials)
    bq_client = client.Client(
        project="test_project",
        credentials=credentials,
        location="test_location",
    )

    expected_attributes = {
        "foo": "baz",
        "db.system": "BigQuery",
        "db.name": "test_project",
        "location": "test_location",
    }

    with pytest.raises(GoogleAPICallError), opentelemetry_tracing.create_span(
            TEST_SPAN_NAME,
            attributes=TEST_SPAN_ATTRIBUTES,
            client=bq_client) as span:
        assert span.name == TEST_SPAN_NAME
        assert span.attributes == expected_attributes
        raise InvalidArgument("test_error")
Esempio n. 20
0
 def test_invalid_argument(self, lang):
     """Expect analyze_sentiment("foo") to return "fo" on InvalidArgument("fo").

     *lang* is presumably a patched client factory; its return value's
     ``analyze_sentiment`` is made to raise the error.
     """
     patched_client = lang.return_value
     patched_client.analyze_sentiment.side_effect = InvalidArgument("fo")
     outcome = analyze_sentiment("foo")
     assert outcome == "fo"
Esempio n. 21
0
 def classify_text(self, document=None):
     '''Always raise InvalidArgument reporting too few tokens; *document* is ignored.'''
     message = 'Invalid text content: too few tokens (words) to process.'
     raise InvalidArgument(message)
async def test_wait_until_empty_completes_on_failure(
    committer: Committer,
    default_connection,
    initial_request,
    asyncio_sleep,
    sleep_queues,
):
    """Check that a pending wait_until_empty() fails with the connection.

    Flow: connect, confirm a fresh committer is already empty, start a
    commit plus a wait_until_empty() call, let one flush write go through,
    then feed InvalidArgument("permanent") to the connection read and
    expect the pending wait_until_empty() future to raise InvalidArgument.
    """
    # Queues tied to sleeps of FLUSH_SECONDS: `called` is awaited to learn
    # that a sleep started, `results` is fed to let it finish.
    # NOTE(review): exact semantics come from the sleep_queues fixture,
    # which is not visible here.
    sleep_called = sleep_queues[FLUSH_SECONDS].called
    sleep_results = sleep_queues[FLUSH_SECONDS].results
    cursor1 = Cursor(offset=1)
    # Route connection writes/reads through queue pairs so the test can
    # observe each call and decide when (and with what) it completes.
    # NOTE(review): presumably make_queue_waiter reports a call on the first
    # queue and blocks on the second — confirm against its definition.
    write_called_queue = asyncio.Queue()
    write_result_queue = asyncio.Queue()
    default_connection.write.side_effect = make_queue_waiter(
        write_called_queue, write_result_queue
    )
    read_called_queue = asyncio.Queue()
    read_result_queue = asyncio.Queue()
    default_connection.read.side_effect = make_queue_waiter(
        read_called_queue, read_result_queue
    )
    # Pre-load the results for the initial handshake (one read, one write).
    read_result_queue.put_nowait(StreamingCommitCursorResponse(initial={}))
    write_result_queue.put_nowait(None)
    async with committer:
        # Set up connection
        await write_called_queue.get()
        await read_called_queue.get()
        default_connection.write.assert_has_calls([call(initial_request)])

        # New committer is empty.
        await committer.wait_until_empty()

        # Write message 1
        commit_fut1 = asyncio.ensure_future(committer.commit(cursor1))
        empty_fut = asyncio.ensure_future(committer.wait_until_empty())
        assert not commit_fut1.done()
        assert not empty_fut.done()

        # Wait for writes to be waiting
        await sleep_called.get()
        asyncio_sleep.assert_called_with(FLUSH_SECONDS)

        # Handle the connection write
        await sleep_results.put(None)
        await write_called_queue.get()
        await write_result_queue.put(None)
        default_connection.write.assert_has_calls(
            [call(initial_request), call(as_request(cursor1))]
        )
        # The write went out but no response was read yet, so both the
        # commit and the wait are still expected to be pending.
        assert not commit_fut1.done()
        assert not empty_fut.done()

        # Wait for writes to be waiting
        await sleep_called.get()
        asyncio_sleep.assert_has_calls([call(FLUSH_SECONDS), call(FLUSH_SECONDS)])

        # Fail the connection with a permanent error
        await read_called_queue.get()
        await read_result_queue.put(InvalidArgument("permanent"))

        # The pending wait must finish by raising rather than hang.
        with pytest.raises(InvalidArgument):
            await empty_fut
Esempio n. 23
0
    def create_cluster(self, project_id, region, cluster):
        """Create a mock cluster, filling in defaults before storing it.

        Args:
            project_id: project the request is made against.
            region: region of the request, checked against this endpoint.
                With ``'global'``, the cluster's actual region is derived
                from the zone in its GCE config.
            cluster: a ``Cluster`` or a dict of keyword arguments for one.
                It is updated in place and stored in ``self.mock_clusters``
                under the key (project_id, region, cluster name).

        Raises:
            InvalidArgument: the cluster names a different project, has no
                name, uses ``'global'`` without a zone, or sets both a
                network URI and a subnetwork URI.
            AlreadyExists: a cluster is already stored under the same key.
        """
        self._check_region_matches_endpoint(region)

        # convert dict to object
        if not isinstance(cluster, Cluster):
            cluster = Cluster(**cluster)

        if cluster.project_id:
            if cluster.project_id != project_id:
                raise InvalidArgument(
                    'If provided, CreateClusterRequest.cluster.project_id must'
                    ' match CreateClusterRequest.project_id')
        else:
            cluster.project_id = project_id

        if not cluster.cluster_name:
            raise InvalidArgument('Cluster name is required')

        # add in default disk config
        for x in ('master', 'worker', 'secondary_worker'):
            field = x + '_config'
            conf = getattr(cluster.config, field, None)
            if conf and str(conf):  # empty DiskConfigs are still true-ish
                if not conf.disk_config:
                    conf.disk_config = DiskConfig()
                if not conf.disk_config.boot_disk_size_gb:
                    conf.disk_config.boot_disk_size_gb = _DEFAULT_DISK_SIZE_GB

        # update gce_cluster_config
        gce_config = cluster.config.gce_cluster_config

        # check region and zone_uri; 'global' is not a real region, so the
        # cluster's region has to come from its zone
        if region == 'global':
            if gce_config.zone_uri:
                cluster_region = _zone_to_region(gce_config.zone_uri)
            else:
                raise InvalidArgument(
                    "Must specify a zone in GCE configuration"
                    " when using 'regions/global'")
        else:
            cluster_region = region

        # add in default scopes and sort
        scopes = set(gce_config.service_account_scopes)

        if not scopes:
            scopes.update(_DEFAULT_SCOPES)
        scopes.update(_MANDATORY_SCOPES)

        gce_config.service_account_scopes[:] = sorted(scopes)

        # handle network_uri and subnetwork_uri
        if gce_config.network_uri and gce_config.subnetwork_uri:
            raise InvalidArgument('GceClusterConfiguration cannot contain both'
                                  ' Network URI and Subnetwork URI')

        if not (gce_config.network_uri or gce_config.subnetwork_uri):
            gce_config.network_uri = 'default'

        if gce_config.network_uri:
            gce_config.network_uri = _fully_qualify_network_uri(
                gce_config.network_uri, project_id)

        if gce_config.subnetwork_uri:
            # Subnetworks are regional, so qualify with the cluster's actual
            # region. It differs from *region* on the global endpoint, where
            # using *region* would produce a 'regions/global' URI.
            gce_config.subnetwork_uri = _fully_qualify_subnetwork_uri(
                gce_config.subnetwork_uri, project_id, cluster_region)

        # add in default cluster properties (never overriding given ones)
        props = cluster.config.software_config.properties

        for k, v in _DEFAULT_CLUSTER_PROPERTIES.items():
            if k not in props:
                props[k] = v

        # initialize cluster status
        cluster.status.state = _cluster_state_value('CREATING')

        # clusters stay keyed by the request region, not the derived one
        cluster_key = (project_id, region, cluster.cluster_name)

        if cluster_key in self.mock_clusters:
            raise AlreadyExists('Already exists: Cluster ' +
                                _cluster_path(*cluster_key))

        self.mock_clusters[cluster_key] = cluster
Esempio n. 24
0
 def parse(to_parse: str):
     """Build a CloudZone from a name shaped like ``<a>-<b>-<letter>``.

     The first two dash-separated parts form the region name and the third,
     which must be a single character, is the zone ID. Raises
     InvalidArgument for any other shape.
     """
     parts = to_parse.split("-")
     if not (len(parts) == 3 and len(parts[2]) == 1):
         raise InvalidArgument("Invalid zone name: " + to_parse)
     first, second, zone_letter = parts
     region = CloudRegion(name=first + "-" + second)
     return CloudZone(region, zone_id=zone_letter)
 async def reinit_action(conn, last_error):
     """Check the reinitialize arguments, then fail with InvalidArgument.

     Asserts that *conn* is the enclosing scope's ``default_connection`` and
     that no previous error was passed in.
     """
     assert conn == default_connection
     assert last_error is None
     failure = InvalidArgument("abc")
     raise failure
Esempio n. 26
0
 def parse(to_parse: str):
     """Build a CloudRegion from a name with exactly two dash-separated parts.

     Raises InvalidArgument for any other shape.
     """
     parts = to_parse.split("-")
     if len(parts) == 2:
         return CloudRegion(name="-".join(parts))
     raise InvalidArgument("Invalid region name: " + to_parse)
 async def reinit_action(conn):
     """Check the connection argument, then fail with InvalidArgument.

     Asserts that *conn* is the enclosing scope's ``default_connection``.
     """
     assert conn == default_connection
     failure = InvalidArgument("abc")
     raise failure