Пример #1
0
    def write(self, digest_hash, digest_size, first_block, other_blocks):
        """Store a blob received as a first block plus follow-up blocks.

        Args:
            digest_hash: Expected hash of the blob, as a hex string.
            digest_size: Expected blob size, as a string of decimal digits.
            first_block: Data carried by the first write request.
            other_blocks: Iterable of the data of the subsequent requests.

        Returns:
            A ``WriteResponse`` whose ``committed_size`` is the number of
            bytes received.

        Raises:
            InvalidArgumentError: If the digest is malformed or the received
                data does not hash to ``digest_hash``.
            NotImplementedError: If the number of bytes received differs
                from ``digest_size``.
        """
        well_formed = len(digest_hash) == HASH_LENGTH and digest_size.isdigit()
        if not well_formed:
            raise InvalidArgumentError(
                "Invalid digest [{}/{}]".format(digest_hash, digest_size))

        digest = re_pb2.Digest(hash=digest_hash, size_bytes=int(digest_size))
        session = self.__storage.begin_write(digest)

        # The first block opens the session and seeds the running hash.
        session.write(first_block)
        hasher = HASH(first_block)
        total = len(first_block)

        # Every following block is written and folded into the hash.
        for block in other_blocks:
            session.write(block)
            hasher.update(block)
            total += len(block)

        # Nothing is committed unless both size and hash match the digest.
        if total != digest.size_bytes:
            raise NotImplementedError(
                "Cannot close stream before finishing write")
        if hasher.hexdigest() != digest.hash:
            raise InvalidArgumentError("Data does not match hash")

        self.__storage.commit_write(digest, session)
        return bytestream_pb2.WriteResponse(committed_size=total)
Пример #2
0
    def read(self, digest_hash, digest_size, read_offset, read_limit):
        """Stream the content of a stored blob as ``ReadResponse`` messages.

        This is a generator, so the checks below (and the exceptions they
        raise) only run once the first item is requested.

        Args:
            digest_hash: Hash of the blob, as a hex string.
            digest_size: Size of the blob, as a string of decimal digits.
            read_offset: Byte offset at which to start reading.
            read_limit: Maximum number of bytes to return; ``0`` means
                "everything from ``read_offset`` to the end".

        Yields:
            ``ReadResponse`` messages carrying at most ``BLOCK_SIZE`` bytes
            of data each.

        Raises:
            InvalidArgumentError: If the digest is malformed or
                ``read_limit`` is negative.
            OutOfRangeError: If ``read_offset`` lies outside the blob.
            NotFoundError: If the storage does not hold the blob.
        """
        if len(digest_hash) != HASH_LENGTH or not digest_size.isdigit():
            raise InvalidArgumentError("Invalid digest [{}/{}]"
                                       .format(digest_hash, digest_size))

        digest = re_pb2.Digest(hash=digest_hash, size_bytes=int(digest_size))

        # Check the given read offset and limit.
        if read_offset < 0 or read_offset > digest.size_bytes:
            raise OutOfRangeError("Read offset out of range")

        if read_limit < 0:
            raise InvalidArgumentError("Negative read_limit is invalid")

        # Never plan to send more than what lies between the offset and the
        # end of the blob: a larger limit used to produce a run of empty
        # responses once the blob was exhausted.
        bytes_available = digest.size_bytes - read_offset
        if read_limit == 0:
            bytes_remaining = bytes_available
        else:
            bytes_remaining = min(read_limit, bytes_available)

        # Read the blob from storage and send its contents to the client.
        result = self.__storage.get_blob(digest)
        if result is None:
            raise NotFoundError("Blob not found")

        elif result.seekable():
            result.seek(read_offset)

        else:
            # Non-seekable stream: consume and discard the skipped prefix.
            result.read(read_offset)

        while bytes_remaining > 0:
            data = result.read(min(self.BLOCK_SIZE, bytes_remaining))
            if not data:
                # The stream ended early; there is nothing left to send.
                break

            yield bytestream_pb2.ReadResponse(data=data)
            # Count what was actually read: read() may return fewer bytes
            # than requested, and assuming a full block would drop data.
            bytes_remaining -= len(data)
Пример #3
0
 def _check_bot_ids(self, bot_id, name=None):
     """Validate a bot id against the registered bot sessions.

     With a ``name``, the id registered under that name must equal
     ``bot_id``; on a mismatch the session is closed. Without a ``name``,
     any session already registered with ``bot_id`` is closed.

     Raises:
         InvalidArgumentError: If the name is unknown (including recently
             evicted names), if the id does not match the one registered
             for the name, or if the id is already registered.
     """
     if name is None:
         for known_name, known_id in self._bot_ids.items():
             if bot_id != known_id:
                 continue
             # Raising right after closing also ends the iteration, so the
             # mapping is not walked again after being modified.
             self._close_bot_session(
                 known_name, reason="bot already registered and given name")
             raise InvalidArgumentError(
                 'Bot id already registered. ID sent: bot_id=[{}].'
                 'Id registered: bgd_bot_id=[{}] with bgd_bot_name=[{}]'
                 .format(bot_id, known_id, known_name))
         return

     registered_id = self._bot_ids.get(name)

     if registered_id is None:
         # Unknown name: report a recent eviction if there is a record.
         record = self._evicted_bot_sessions.get(name)
         if record:
             raise InvalidArgumentError(
                 "Server has recently evicted the bot_name=[{}] at "
                 "timestamp=[{}], reason=[{}]".format(
                     name, record[0], record[1]))
         raise InvalidArgumentError(
             'Name not registered on server: bot_name=[{}]'.format(name))

     if registered_id != bot_id:
         self._close_bot_session(
             name, reason="bot_id mismatch between worker and bgd")
         raise InvalidArgumentError(
             'Bot id invalid. ID sent: bot_id=[{}] with name: bot_name[{}].'
             'ID registered: bgd_bot_id[{}] for that name'.format(
                 bot_id, name, registered_id))
Пример #4
0
    def Write(self, requests, context):
        """Handle a streamed write request.

        The resource name of the first request selects the instance and the
        digest; the data of all requests is handed to the instance's
        ``write``. Known errors are logged and reported through ``context``,
        in which case an empty ``WriteResponse`` is returned.
        """
        self.__logger.debug("Write request from [%s]", context.peer())

        request = next(requests)
        resource_name = request.resource_name
        names = resource_name.split('/')

        try:
            instance_name = ''
            # Format: "{instance_name}/uploads/{uuid}/blobs/{hash}/{size}/{anything}":
            too_short = len(names) < 5
            if too_short or 'uploads' not in names or 'blobs' not in names:
                raise InvalidArgumentError(
                    "Invalid resource name: [{}]".format(resource_name))

            if names[0] != 'uploads':
                # Everything before 'uploads' forms the instance name.
                uploads_at = names.index('uploads')
                instance_name = '/'.join(names[:uploads_at])
                names = names[uploads_at:]

            if len(names) < 5:
                raise InvalidArgumentError(
                    "Invalid resource name: [{}]".format(resource_name))

            hash_ = names[3]
            size_bytes = names[4]

            instance = self._get_instance(instance_name)

            return instance.write(hash_, size_bytes, request.data,
                                  [later.data for later in requests])

        except (NotImplementedError, InvalidArgumentError, NotFoundError) as e:
            # Map the error type to the gRPC status reported to the client.
            if isinstance(e, NotImplementedError):
                status_code = grpc.StatusCode.UNIMPLEMENTED
            elif isinstance(e, InvalidArgumentError):
                status_code = grpc.StatusCode.INVALID_ARGUMENT
            else:
                status_code = grpc.StatusCode.NOT_FOUND

            self.__logger.error(e)
            context.set_details(str(e))
            context.set_code(status_code)

        return bytestream_pb2.WriteResponse()
Пример #5
0
    def _get_instance(self, instance_name):
        """Return the instance registered under ``instance_name``.

        Raises:
            InvalidArgumentError: If no instance has that name.
        """
        try:
            instance = self._instances[instance_name]
        except KeyError:
            message = "Invalid instance name: [{}]".format(instance_name)
            raise InvalidArgumentError(message)
        return instance
Пример #6
0
    def __init__(self, auth_token=None, auth_secret=None):
        """Initialises a new :class:`AuthMetadataClientInterceptor`.

        Important:
            One of `auth_token` or `auth_secret` must be provided. When both
            are given, `auth_token` takes precedence.

        Args:
            auth_token (str, optional): Authorization token as a string.
            auth_secret (str, optional): Authorization secret as a string
                (``bytes`` are accepted too). It is sent base64-encoded.

        Raises:
            InvalidArgumentError: If neither `auth_token` or `auth_secret` are
                provided.
        """
        if auth_token:
            self.__secret = auth_token.strip()

        elif auth_secret:
            raw_secret = auth_secret.strip()
            # b64encode() only accepts bytes-like input and returns bytes:
            # encode text first and decode the result, so the header value
            # is plain text rather than the repr of a bytes object.
            if isinstance(raw_secret, str):
                raw_secret = raw_secret.encode('utf-8')
            self.__secret = base64.b64encode(raw_secret).decode('ascii')

        else:
            raise InvalidArgumentError("A secret or token must be provided")

        self.__header_field_name = 'authorization'
        self.__header_field_value = 'Bearer {}'.format(self.__secret)
Пример #7
0
    def batch_read_blobs(self, digests):
        """Read several blobs at once and report a status for each.

        Args:
            digests: The digests of the blobs to read.

        Returns:
            A ``BatchReadBlobsResponse`` with one entry per digest, holding
            either the blob data with an ``OK`` status or a ``NOT_FOUND``
            status.

        Raises:
            InvalidArgumentError: If the summed size of the requested blobs
                exceeds ``max_batch_total_size_bytes()``.
        """
        total_size = 0
        for digest in digests:
            total_size += digest.size_bytes

        size_limit = self.max_batch_total_size_bytes()
        if total_size > size_limit:
            raise InvalidArgumentError('Combined total size of blobs exceeds '
                                       'server limit. '
                                       '({} > {} [byte])'.format(total_size,
                                                                 size_limit))

        blobs_by_hash = self.__storage.bulk_read_blobs(digests)

        response = re_pb2.BatchReadBlobsResponse()
        for digest in digests:
            entry = response.responses.add()
            entry.digest.CopyFrom(digest)

            # A missing key and a None value both mean "not found".
            blob = None
            if digest.hash in blobs_by_hash:
                blob = blobs_by_hash[digest.hash]

            if blob is None:
                outcome = code_pb2.NOT_FOUND
            else:
                entry.data = blob.read()
                outcome = code_pb2.OK

            entry.status.CopyFrom(status_pb2.Status(code=outcome))

        return response
Пример #8
0
    def _get_instance(self, name):
        """Return the instance registered under ``name``.

        Raises:
            InvalidArgumentError: If no instance has that name.
        """
        try:
            instance = self._instances[name]
        except KeyError:
            message = "Instance doesn't exist on server: [{}]".format(name)
            raise InvalidArgumentError(message)
        return instance
Пример #9
0
    def Read(self, request, context):
        """Handle a read request and stream back the requested blob.

        The resource name selects the instance and the digest. Known errors
        are logged and reported through ``context``, followed by a single
        empty ``ReadResponse``.
        """
        self.__logger.debug("Read request from [%s]", context.peer())

        resource_name = request.resource_name
        names = resource_name.split('/')

        try:
            instance_name = ''
            # Format: "{instance_name}/blobs/{hash}/{size}":
            if len(names) < 3 or names[-3] != 'blobs':
                raise InvalidArgumentError(
                    "Invalid resource name: [{}]".format(resource_name))

            if names[0] != 'blobs':
                # Everything before 'blobs' forms the instance name.
                blobs_at = names.index('blobs')
                instance_name = '/'.join(names[:blobs_at])
                names = names[blobs_at:]

            if len(names) < 3:
                raise InvalidArgumentError(
                    "Invalid resource name: [{}]".format(resource_name))

            hash_ = names[1]
            size_bytes = names[2]

            instance = self._get_instance(instance_name)

            yield from instance.read(hash_, size_bytes, request.read_offset,
                                     request.read_limit)

        except (InvalidArgumentError, NotFoundError, OutOfRangeError) as e:
            # Map the error type to the gRPC status reported to the client.
            if isinstance(e, InvalidArgumentError):
                status_code = grpc.StatusCode.INVALID_ARGUMENT
            elif isinstance(e, NotFoundError):
                status_code = grpc.StatusCode.NOT_FOUND
            else:
                status_code = grpc.StatusCode.OUT_OF_RANGE

            self.__logger.error(e)
            context.set_details(str(e))
            context.set_code(status_code)
            yield bytestream_pb2.ReadResponse()
Пример #10
0
    def register_operation_peer(self, operation_name, peer, message_queue):
        """Register ``peer`` and its message queue for an operation.

        Raises:
            InvalidArgumentError: If the scheduler reports the operation as
                not found.
        """
        scheduler = self._scheduler
        try:
            scheduler.register_job_operation_peer(
                operation_name, peer, message_queue)
        except NotFoundError:
            message = "Operation name does not exist: [{}]".format(
                operation_name)
            raise InvalidArgumentError(message)
Пример #11
0
    def register_job_peer(self, job_name, peer, message_queue):
        """Register ``peer`` and its message queue for a job.

        Returns:
            Whatever the scheduler's ``register_job_peer`` returns.

        Raises:
            InvalidArgumentError: If the scheduler reports the job as not
                found.
        """
        try:
            registration = self._scheduler.register_job_peer(
                job_name, peer, message_queue)
        except NotFoundError:
            message = "Job name does not exist: [{}]".format(job_name)
            raise InvalidArgumentError(message)
        else:
            return registration
Пример #12
0
    def get_operation(self, job_name):
        """Return the operation the scheduler holds for ``job_name``.

        Raises:
            InvalidArgumentError: If the scheduler reports it as not found.
        """
        try:
            return self._scheduler.get_job_operation(job_name)
        except NotFoundError:
            message = "Operation name does not exist: [{}]".format(job_name)
            raise InvalidArgumentError(message)
Пример #13
0
    def _check_jwt_support(self, algorithm=AuthMetadataAlgorithm.UNSPECIFIED):
        """Ensures JWT and possible dependencies are available.

        Args:
            algorithm (AuthMetadataAlgorithm): Algorithm to probe; nothing
                is probed when left unspecified.

        Raises:
            InvalidArgumentError: If PyJWT is unavailable, or if registering
                `algorithm` with it raises a ``TypeError``.
        """
        if not HAVE_JWT:
            raise InvalidArgumentError(
                "JWT authorization method requires PyJWT")

        try:
            if algorithm != AuthMetadataAlgorithm.UNSPECIFIED:
                jwt.register_algorithm(algorithm.value.upper(), None)

        except TypeError:
            # Report the algorithm that was actually probed, not whatever
            # happens to be stored on the instance.
            raise InvalidArgumentError(
                "Algorithm not supported for JWT decoding: [{}]".format(
                    algorithm))

        except ValueError:
            # Deliberately ignored: a ValueError from the registration
            # probe is not treated as a lack of support.
            pass
Пример #14
0
    def _setup_bot_session_reaper_loop(self):
        """Schedule the expired-session reaper when a keepalive timeout is set.

        Does nothing when no timeout is configured (falsy value).

        Raises:
            InvalidArgumentError: If the configured timeout is not positive.
        """
        timeout = self._bot_session_keepalive_timeout
        if not timeout:
            return

        if timeout <= 0:
            # Exceptions do not interpolate logging-style arguments, so the
            # message has to be formatted before it is raised.
            raise InvalidArgumentError(
                "[bot_session_keepalive_timeout] set to [{}], "
                "must be > 0, in seconds".format(timeout))

        # Add the expired session reaper in the event loop
        main_loop = asyncio.get_event_loop()
        main_loop.create_task(self._reap_expired_sessions_loop())
Пример #15
0
    def _close_bot_session(self, name, *, reason=None):
        """Close the bot session registered under ``name``.

        Before removing the session, its leases are handed back to the
        scheduler for retry, its deadline is untracked and the eviction is
        recorded together with ``reason``.

        Args:
            name: Name of the bot session to close.
            reason: Optional text stored with the eviction record.

        Raises:
            InvalidArgumentError: If no bot is registered under ``name``.
        """
        bot_id = self._bot_ids.get(name)

        if bot_id is None:
            raise InvalidArgumentError(
                "Bot id does not exist: [{}]".format(name))

        self.__logger.debug("Attempting to close [%s] with name: [%s]", bot_id,
                            name)
        for lease_id in self._assigned_leases[name]:
            try:
                self._scheduler.retry_job_lease(lease_id)
            except NotFoundError:
                # The scheduler no longer knows this lease; nothing to retry.
                pass
        self._assigned_leases.pop(name)

        # If we had assigned an expire_time for this botsession, make sure to
        # clean up, regardless of the reason we end up closing this BotSession
        self._untrack_deadline_for_botsession(name)

        # Make sure we're only keeping the last N evicted sessions
        # NOTE: there could be some rare race conditions when the length of the OrderedDict is
        # only 1 below the limit; multiple threads could check the size simultaneously before
        # they get to add their items in the OrderedDict, resulting in a size bigger than initially intended
        # (with a very unlikely upper bound of:
        #   O(n) = `remember_last_n_evicted_bot_sessions`
        #             + min(number_of_threads, number_of_concurrent_threads_cpu_can_handle)).
        #   The size being only 1 below the limit could also happen when the OrderedDict contains
        # exactly `n` items and a thread trying to insert sees the limit has been reached and makes
        # just enough space to add its own item.
        #   The cost of locking vs using a bit more memory for a few more items in-memory is high, thus
        # we opt for the unlikely event of the OrderedDict growing a bit more and
        # make the next thread which tries to insert an item clean up `while len > n`.
        while len(self._evicted_bot_sessions
                  ) > self._remember_last_n_evicted_bot_sessions:
            # Drop the OLDEST record: a bare popitem() removes the most
            # recently added one, which would discard exactly the evictions
            # that are meant to be remembered.
            self._evicted_bot_sessions.popitem(last=False)
        # Record this eviction
        self._evicted_bot_sessions[name] = (datetime.utcnow(), reason)

        self.__logger.debug("Closing bot session: [%s]", name)
        self._bot_ids.pop(name)
        self.__logger.info("Closed bot [%s] with name: [%s]", bot_id, name)
Пример #16
0
    def __init__(self,
                 event_loop,
                 endpoint_type=MonitoringOutputType.SOCKET,
                 endpoint_location=None,
                 metric_prefix="",
                 serialisation_format=MonitoringOutputFormat.BINARY):
        """Set up the message queue and the output configuration.

        Args:
            event_loop: Event loop the message queue is bound to.
            endpoint_type: One of the ``MonitoringOutputType`` values FILE,
                SOCKET, STDOUT or UDP.
            endpoint_location: Output location, kept for every endpoint
                type except STDOUT.
            metric_prefix: Prefix stored for later use.
            serialisation_format: JSON and STATSD each enable their own
                output flag; other values enable neither.

        Raises:
            InvalidArgumentError: If ``endpoint_type`` is not one of the
                four supported values.
        """
        self.__event_loop = event_loop
        self.__streaming_task = None

        self.__message_queue = asyncio.Queue(loop=self.__event_loop)
        self.__sequence_number = 1

        # FILE, SOCKET and UDP endpoints need a location; STDOUT does not.
        location_based = (MonitoringOutputType.FILE,
                          MonitoringOutputType.SOCKET,
                          MonitoringOutputType.UDP)
        if endpoint_type in location_based:
            self.__output_location = endpoint_location
        elif endpoint_type == MonitoringOutputType.STDOUT:
            self.__output_location = None
        else:
            raise InvalidArgumentError(
                "Invalid endpoint output type: [{}]".format(endpoint_type))

        self.__async_output = endpoint_type == MonitoringOutputType.SOCKET
        self.__print_output = endpoint_type == MonitoringOutputType.STDOUT
        self.__udp_output = endpoint_type == MonitoringOutputType.UDP

        self.__metric_prefix = metric_prefix

        self.__json_output = (
            serialisation_format == MonitoringOutputFormat.JSON)
        self.__statsd_output = (
            serialisation_format == MonitoringOutputFormat.STATSD)
Пример #17
0
    def create_bot_session(self, parent, bot_session):
        """Open a new bot session under a server-chosen name.

        The name is ``parent`` plus a fresh UUID. The bot id is first run
        through ``_check_bot_ids``; an ``InvalidArgumentError`` from that
        check is ignored here.

        Returns:
            The given ``bot_session``, with its name set.

        Raises:
            InvalidArgumentError: If the session carries no bot id.
        """
        bot_id = bot_session.bot_id
        if not bot_id:
            raise InvalidArgumentError("Bot's id must be set by client.")

        try:
            self._check_bot_ids(bot_id)
        except InvalidArgumentError:
            # Deliberately ignored: registration of the new session proceeds.
            pass

        # Bot session name, selected by the server
        name = "{}/{}".format(parent, uuid.uuid4())
        bot_session.name = name

        self._bot_ids[name] = bot_session.bot_id

        # We want to keep a copy of lease ids we have assigned
        self._assigned_leases[name] = set()

        self._request_leases(bot_session, name=name)
        self._assign_deadline_for_botsession(bot_session, name)

        if not self.__debug:
            self.__logger.info("Opened session, bot_name=[%s] for bot_id=[%s]",
                               bot_session.name, bot_session.bot_id)
        else:
            # Debug mode also lists the shortened ids of the leases.
            short_lease_ids = ",".join(
                lease.id[:8] for lease in bot_session.leases)
            self.__logger.info(
                "Opened session bot_name=[%s] for bot_id=[%s], leases=[%s]",
                bot_session.name, bot_session.bot_id, short_lease_ids)

        return bot_session
Пример #18
0
def setup_channel(remote_url,
                  auth_token=None,
                  client_key=None,
                  client_cert=None,
                  server_cert=None,
                  action_id=None,
                  tool_invocation_id=None,
                  correlated_invocations_id=None):
    """Creates a new gRPC client communication channel.

    If `remote_url` does not specify a port number, it defaults to 50051.

    Args:
        remote_url (str): URL for the remote, including port and protocol.
        auth_token (str): Authorization token file path.
        client_key (str): TLS client private key file path.
        client_cert (str): TLS client certificate file path.
        server_cert (str): TLS server certificate file path.
        action_id (str): Action identifier to which the request belongs to.
        tool_invocation_id (str): Identifier for a related group of Actions.
        correlated_invocations_id (str): Identifier that ties invocations together.

    Returns:
        tuple: The client channel used to access the server at `remote_url`,
            followed by the TLS details reported when loading the
            credentials (``(None, None, None)`` for plain `http`).

    Raises:
        InvalidArgumentError: On any input parsing error.
    """
    url = urlparse(remote_url)
    remote = '{}:{}'.format(url.hostname, url.port or 50051)
    details = None, None, None

    if url.scheme == 'http':
        channel = grpc.insecure_channel(remote,
                                        options=[
                                            ('grpc.max_send_message_length',
                                             MAX_REQUEST_SIZE),
                                            ('grpc.max_receive_message_length',
                                             MAX_REQUEST_SIZE),
                                        ])

    elif url.scheme == 'https':
        credentials, details = load_tls_channel_credentials(
            client_key, client_cert, server_cert)
        if not credentials:
            # The message used to claim the details "could be loaded".
            raise InvalidArgumentError(
                "Given TLS details (or defaults) could not be loaded")

        channel = grpc.secure_channel(remote, credentials)

    else:
        raise InvalidArgumentError("Given remote does not specify a protocol")

    request_metadata_interceptor = RequestMetadataInterceptor(
        action_id=action_id,
        tool_invocation_id=tool_invocation_id,
        correlated_invocations_id=correlated_invocations_id)

    channel = grpc.intercept_channel(channel, request_metadata_interceptor)

    if auth_token is not None:
        token = load_channel_authorization_token(auth_token)
        if not token:
            # The message used to claim the token "could be loaded".
            raise InvalidArgumentError(
                "Given authorization token could not be loaded")

        auth_interceptor = AuthMetadataClientInterceptor(auth_token=token)

        channel = grpc.intercept_channel(channel, auth_interceptor)

    return channel, details
Пример #19
0
def _print_operation_status(operation, print_details=False):
    """Print a one-line status for ``operation``, optionally with timings.

    An unfinished operation is reported by its current stage. A finished
    one is reported as a failure or a success depending on the response
    status and exit code; with ``print_details`` the worker name and the
    durations taken from the execution metadata follow.

    Raises:
        InvalidArgumentError: If the operation metadata is not an
            ``ExecuteOperationMetadata`` message.
    """
    metadata = remote_execution_pb2.ExecuteOperationMetadata()
    # The metadata is expected to be an ExecuteOperationMetadata message:
    if not operation.metadata.Is(metadata.DESCRIPTOR):
        raise InvalidArgumentError(
            'Metadata is not an ExecuteOperationMetadata '
            'message')
    operation.metadata.Unpack(metadata)

    stage = OperationStage(metadata.stage)

    if not operation.done:
        # One message template per stage an unfinished operation may be in.
        progress_templates = {
            OperationStage.CACHE_CHECK:
                'CacheCheck: {}: Querying action-cache (stage={})',
            OperationStage.QUEUED:
                'Queued: {}: Waiting for execution (stage={})',
            OperationStage.EXECUTING:
                'Executing: {}: Currently running (stage={})',
        }
        template = progress_templates.get(stage)
        if template is None:
            click.echo('Error: {}: In an invalid state (stage={})'.format(
                operation.name, metadata.stage), err=True)
        else:
            click.echo(template.format(operation.name, metadata.stage))
        return

    assert stage == OperationStage.COMPLETED

    response = remote_execution_pb2.ExecuteResponse()
    # The response is expected to be an ExecutionResponse message:
    assert operation.response.Is(response.DESCRIPTOR)
    operation.response.Unpack(response)

    if response.status.code != code_pb2.OK:
        click.echo('Failure: {}: {} (code={})'.format(
            operation.name, response.status.message, response.status.code))
    elif response.result.exit_code != 0:
        click.echo(
            'Success: {}: Completed with failure (stage={}, exit_code={})'
            .format(operation.name, metadata.stage,
                    response.result.exit_code))
    else:
        click.echo(
            'Success: {}: Completed succesfully (stage={}, exit_code={})'
            .format(operation.name, metadata.stage,
                    response.result.exit_code))

    if not print_details:
        return

    timings = response.result.execution_metadata

    def emit(text, prefix):
        # Print one indented detail line.
        click.echo(indent(text, prefix))

    outer, inner = '  ', '    '

    emit('worker={}'.format(timings.worker), outer)

    queued = timings.queued_timestamp.ToDatetime()
    emit('queued_at={}'.format(queued), outer)

    worker_start = timings.worker_start_timestamp.ToDatetime()
    worker_completed = timings.worker_completed_timestamp.ToDatetime()
    emit('work_duration={}'.format(worker_completed - worker_start), outer)

    fetch_start = timings.input_fetch_start_timestamp.ToDatetime()
    fetch_completed = timings.input_fetch_completed_timestamp.ToDatetime()
    emit('fetch_duration={}'.format(fetch_completed - fetch_start), inner)

    execution_start = timings.execution_start_timestamp.ToDatetime()
    execution_completed = timings.execution_completed_timestamp.ToDatetime()
    emit('exection_duration={}'.format(
        execution_completed - execution_start), inner)

    upload_start = timings.output_upload_start_timestamp.ToDatetime()
    upload_completed = timings.output_upload_completed_timestamp.ToDatetime()
    emit('upload_duration={}'.format(upload_completed - upload_start), inner)

    emit('total_duration={}'.format(worker_completed - queued), outer)
Пример #20
0
    def cancel_operation(self, job_name):
        """Ask the scheduler to cancel the operation named ``job_name``.

        Raises:
            InvalidArgumentError: If the scheduler reports it as not found.
        """
        scheduler = self._scheduler
        try:
            scheduler.cancel_job_operation(job_name)
        except NotFoundError:
            message = "Operation name does not exist: [{}]".format(job_name)
            raise InvalidArgumentError(message)