def batch_get_item(self, **kwargs):
    """Mock batch_get_item method and return a mimicked DynamoDB response.

    Keyword Arguments:
        RequestItems (dict): Mapping of table name to a request dict. Each
            request dict must contain a non-empty ``Keys`` entry, which is
            checked with ``self._validate_keys``.

    Returns:
        (dict): Response dictionary containing fake results. When
            ``self.has_unprocessed_keys`` is truthy, ``UnprocessedKeys`` is
            populated with two fake keys for ``test_table_name``.

    Raises:
        ClientError: If ``self.exception`` is truthy.
        ParamValidationError: If ``RequestItems`` is missing or empty, or a
            table entry has no ``Keys``.
    """
    if self.exception:
        err = {'Error': {'Code': 400, 'Message': 'raising test exception'}}
        raise ClientError(err, 'batch_get_item')

    if not kwargs.get('RequestItems'):
        err = {
            'Error': {
                'Code': 403,
                'Message': 'raising test exceptionParameter validation failed'
            }
        }
        raise ParamValidationError(report=err)

    # Validate query keys. dict.values() works on both Python 2 and 3,
    # whereas dict.iteritems() does not exist on Python 3.
    for item_value in kwargs['RequestItems'].values():
        if not item_value.get('Keys'):
            err = {'Error': {'Code': 400, 'Message': '[Keys] parameter is required'}}
            raise ParamValidationError(report=err)
        self._validate_keys(item_value['Keys'])

    response = {
        'UnprocessedKeys': {},
        'Responses': {
            'test_table_name': [
                {
                    'ioc_value': {'S': '1.1.1.2'},
                    'sub_type': {'S': 'mal_ip'}
                },
                {
                    'ioc_value': {'S': 'evil.com'},
                    'sub_type': {'S': 'c2_domain'}
                }
            ]
        },
        'ResponseMetadata': {
            'RetryAttempts': 0,
            'HTTPStatusCode': 200,
            'RequestId': 'ABCD1234',
            'HTTPHeaders': {}
        }
    }

    if self.has_unprocessed_keys:
        response['UnprocessedKeys'] = {
            'test_table_name': {
                'Keys': [
                    {'ioc_value': {'S': 'foo'}},
                    {'ioc_value': {'S': 'bar'}}
                ]
            }
        }

    return response
def _validate_connector_args(connector_args): if connector_args is None: return for k, v in connector_args.items(): # verify_ssl is handled by verify parameter to create_client if k == 'use_dns_cache': if not isinstance(v, bool): raise ParamValidationError( report='{} value must be a boolean'.format(k)) elif k in ['keepalive_timeout']: if not isinstance(v, (float, int)): raise ParamValidationError( report='{} value must be a float/int'.format(k)) elif k == 'force_close': if not isinstance(v, bool): raise ParamValidationError( report='{} value must be a boolean'.format(k)) # limit is handled by max_pool_connections elif k == 'ssl_context': import ssl if not isinstance(v, ssl.SSLContext): raise ParamValidationError( report='{} must be an SSLContext instance'.format(k)) else: raise ParamValidationError( report='invalid connector_arg:{}'.format(k))
def _validate_keys(dynamodb_data):
    """Check that the DynamoDB query keys are non-empty and not duplicated.

    Each key value is deserialized and lower-cased before comparison, so
    keys that differ only by case count as duplicates. The deserialized
    values are expected to be strings, because ``.lower()`` is called on
    them.

    Args:
        dynamodb_data (list): Key dicts in DynamoDB wire format (presumably
            like ``[{'ioc_value': {'S': '1.1.1.2'}}]`` — confirm with callers).

    Raises:
        ParamValidationError: If ``dynamodb_data`` is empty, or any key value
            is empty or repeated.
    """
    if not dynamodb_data:
        err_msg = {'Error': {'Code': 403, 'Message': 'Empty query keys'}}
        raise ParamValidationError(report=err_msg)

    deserializer = TypeDeserializer()
    # A set gives O(1) duplicate checks instead of scanning a list each time.
    seen = set()
    for raw_data in dynamodb_data:
        # dict.values() works on Python 2 and 3; dict.iteritems() is Py2-only.
        for val in raw_data.values():
            python_data = deserializer.deserialize(val).lower()
            if not python_data or python_data in seen:
                err_msg = {'Error': {'Code': 403, 'Message': 'Parameter Validation Error'}}
                raise ParamValidationError(report=err_msg)
            seen.add(python_data)
def test_run_query_botocore_error(self) -> None:
    """A ParamValidationError from start_query_execution surfaces as RunQueryException.

    The raised exception's first argument must include the botocore report.
    """
    report = 'Unknown parameter in QueryExecutionContext: "banana", must be one of: Database, Catalog'
    failing_start = Mock(side_effect=ParamValidationError(report=report))
    athena = Mock(start_query_execution=failing_start)

    with self.assertRaises(exception.RunQueryException) as raised:
        AwsAthenaAsyncClient(athena).run_query("some query")

    first_arg = raised.exception.args[0]
    self.assertIn(report, first_arg)
def _ensure_header_is_valid_host(self, header): match = self._VALID_HOSTNAME.match(header) if not match: raise ParamValidationError(report=( 'Hostnames must contain only - and alphanumeric characters, ' 'and between 1 and 63 characters long.' ))
def batch_get_image(self, repository_name, registry_id=None, image_ids=None, accepted_media_types=None):
    """Mock ECR ``BatchGetImage``: look up images by digest or tag.

    Args:
        repository_name (str): Key into ``self.repositories``.
        registry_id (str): Only used in the not-found error; falls back to
            ``DEFAULT_REGISTRY_ID``.
        image_ids (list): Dicts with an ``imageDigest`` and/or ``imageTag``.
        accepted_media_types: Accepted for API compatibility; not used here.

    Returns:
        dict: ``images`` holds ``response_batch_get_image`` for every
        repository image matching a requested id. ``failures`` holds one
        ``ImageNotFound`` entry per id that matched nothing.

    Raises:
        RepositoryNotFoundException: If the repository does not exist.
        ParamValidationError: If ``image_ids`` is empty or missing.
    """
    if repository_name in self.repositories:
        repository = self.repositories[repository_name]
    else:
        raise RepositoryNotFoundException(repository_name, registry_id or DEFAULT_REGISTRY_ID)

    if not image_ids:
        # botocore's ParamValidationError formats its message from the
        # ``report`` keyword (as every other raise of it does); ``msg=`` would
        # make constructing the exception itself fail.
        raise ParamValidationError(report='Missing required parameter in input: "imageIds"')

    response = {
        'images': [],
        'failures': [],
    }

    for image_id in image_ids:
        matches = [
            image for image in repository.images
            if ('imageDigest' in image_id and image.get_image_digest() == image_id['imageDigest'])
            or ('imageTag' in image_id and image.image_tag == image_id['imageTag'])
        ]
        response['images'].extend(image.response_batch_get_image for image in matches)

        if not matches:
            response['failures'].append({
                'imageId': {
                    'imageTag': image_id.get('imageTag', 'null')
                },
                'failureCode': 'ImageNotFound',
                'failureReason': 'Requested image not found'
            })

    return response
async def test_insert_post(self):
    """Check both insert_post outcomes.

    On success, put_item is awaited once with the serialized item and the
    result is True. When put_item raises ParamValidationError, the result is
    False.
    """
    api_res = {'ResponseMetadata': {'HTTPStatusCode': 200, 'RequestId': 'd92c4314-439a-4bc7-90f5-125a615dfaa2'}}
    db = await self.client.db
    db.put_item = asynctest.CoroutineMock(return_value=api_res)

    now = datetime.utcnow()
    now_str = datetime.strftime(now, self.client.date_fmt)
    post = {
        "target_id": "target-id",
        "post_id": "post-id",
        "source_id": "source-id",
        "text": "Text",
        "sticky": True,
        "created": now,
        "updated": now,
        "target_doc": {"foo": "doc"},
    }

    # Success path.
    outcome = await self.client.insert_post(**post)
    assert outcome is True
    assert type(outcome) == bool

    expected_item = {
        'source_id': {'S': 'source-id'},
        'text': {'S': 'Text'},
        'target_id': {'S': 'target-id'},
        'post_id': {'S': 'post-id'},
        'created': {'S': now_str},
        'updated': {'S': now_str},
        'target_doc': {'S': '{"foo": "doc"}'},
        'sticky': {'N': '1'},
    }
    db.put_item.assert_called_once_with(Item=expected_item, TableName='livebridge_test')

    # Failure path: put_item raising ParamValidationError makes insert_post return False.
    db.put_item = asynctest.CoroutineMock(side_effect=ParamValidationError(report="Exception raised"))
    outcome = await self.client.insert_post(**post)
    assert outcome is False
    assert type(outcome) == bool
def invoke(cls, **kwargs):
    """Mocked invoke function that returns a response mimicking boto3's response.

    Keyword Args:
        FunctionName (str): The AWS Lambda function name being invoked.
        InvocationType (str): Type of invocation (typically 'Event').
        Payload (str): Payload in string or file format to send to lambda.
        Qualifier (str): Alias for fully qualified AWS ARN.

    Returns:
        dict: Response dictionary containing a fake RequestId.

    Raises:
        ClientError: If ``cls._raise_exception`` is set (the flag is then
            cleared, so only one call raises), or if FunctionName,
            InvocationType or Payload is missing.
        ParamValidationError: If Payload is not a str/bytearray and has no
            ``read`` method.
    """
    if cls._raise_exception:
        # Turn off the raise-exception flag so the next call succeeds.
        cls._raise_exception = not cls._raise_exception
        err = {'Error': {'Code': 400, 'Message': 'raising test exception'}}
        raise ClientError(err, 'invoke')

    req_keywords = {'FunctionName', 'InvocationType', 'Payload'}
    key_diff = req_keywords.difference(kwargs)
    if key_diff:
        # Sort so the message is deterministic regardless of set ordering.
        message = 'required keyword missing: {}'.format(', '.join(sorted(key_diff)))
        err = {'Error': {'Code': 400, 'Message': message}}
        raise ClientError(err, 'invoke')

    payload = kwargs['Payload']
    if not isinstance(payload, (str, bytearray)) and not hasattr(payload, 'read'):
        err = ('Invalid type for parameter Payload, value: {}, type: {}, '
               'valid types: <type \'str\'>, <type \'bytearray\'>, '
               'file-like object').format(payload, type(payload))
        # botocore's ParamValidationError formats its message from ``report``.
        raise ParamValidationError(report=err)

    return {'ResponseMetadata': {'RequestId': '9af88643-7b3c-43cd-baae-addb73bb4d27'}}
def _add_tags(self, operation_name, Tags, cluster):
    """Validate *Tags* and merge them into ``cluster['Tags']``.

    Existing cluster tags are kept unless a new tag has the same Key. The
    result is stored back as a list of ``{'Key': ..., 'Value': ...}`` dicts
    sorted by key.

    Args:
        operation_name (str): Operation name passed to raised errors.
        Tags (list or tuple): Dicts with a required ``Key`` and an optional
            ``Value``; a missing, None or empty Value becomes ``''``.
        cluster (dict): Mutated in place; must hold a ``'Tags'`` list.

    Raises:
        ParamValidationError: If a tag has a field other than Key/Value.
        _InvalidRequestException: If a key is missing or not 1-128
            characters, or a value is longer than 256 characters.
        (Type errors are raised by ``_validate_param_type``, defined
        elsewhere.)
    """
    _validate_param_type(Tags, (list, tuple))

    new_tags = {}
    for tag in Tags:
        _validate_param_type(tag, dict)

        # Only Key and Value are allowed. A strict-superset test would miss
        # unknown fields on tags lacking Key or Value (e.g. {'Key', 'Foo'}).
        if not set(tag) <= {'Key', 'Value'}:
            raise ParamValidationError(report='Unknown parameter in Tags')

        key = tag.get('Key')
        if not key or not 1 <= len(key) <= 128:
            raise _InvalidRequestException(
                operation_name,
                "Invalid tag key: '%s'. Tag keys must be between 1 and 128"
                " characters in length." %
                ('null' if key is None else key))

        value = tag.get('Value') or ''
        # NOTE(review): the check allows up to 256 characters, but the message
        # says 1-128. The message text is left unchanged pending confirmation.
        if len(value) > 256:
            raise _InvalidRequestException(
                operation_name,
                "Invalid tag value: '%s'. Tag values must be between 1 and"
                " 128 characters in length." % value)

        new_tags[key] = value

    tags_dict = {t['Key']: t['Value'] for t in cluster['Tags']}
    tags_dict.update(new_tags)

    cluster['Tags'] = [
        dict(Key=k, Value=v) for k, v in sorted(tags_dict.items())
    ]
def validate_ascii_metadata(params, **kwargs):
    """Reject S3 user metadata whose keys or values are not pure ASCII.

    From: http://docs.aws.amazon.com/AmazonS3/latest/dev/UsingMetadata.html

        "Amazon S3 stores user-defined metadata in lowercase. Each
        name, value pair must conform to US-ASCII when using REST and
        UTF-8 when using SOAP or browser-based uploads via POST."

    Raises:
        ParamValidationError: On the first key/value pair that cannot be
            encoded as ASCII.
    """
    metadata = params.get('Metadata')
    # This handler runs before parameter validation, so anything that is not
    # a non-empty dict is skipped here. The param validator will report a
    # descriptive error for a wrong type later.
    if not (metadata and isinstance(metadata, dict)):
        return

    for meta_key, meta_value in metadata.items():
        try:
            meta_key.encode('ascii')
            meta_value.encode('ascii')
        except UnicodeEncodeError:
            raise ParamValidationError(report=(
                'Non ascii characters found in S3 metadata '
                'for key "%s", value: "%s". \nS3 metadata can only '
                'contain ASCII characters. ' % (meta_key, meta_value)))
def esearch(event, context):
    """Run an Elasticsearch search described by the JSON body of *event*.

    Recognised body fields: ``include`` (passed to ``build_match_query``;
    takes precedence), ``simple_query`` (multi_match query), ``size``
    (default 10) and ``from`` (default 0). With neither include nor
    simple_query, all documents are matched. The request's ``exclude``
    field is currently not applied.

    Returns:
        dict: ``statusCode`` 200 and a JSON ``body`` of ``{"items": [...]}``.

    Raises:
        ParamValidationError: If ``validate_body`` rejects the body.
    """
    if not validate_body(event['body']):
        raise ParamValidationError(report='Invalid request body')

    body = json.loads(event['body'])
    query = body.get('simple_query')
    include = body.get('include')
    size = body.get('size', 10)
    from_index = body.get('from', 0)

    s = Search(using=es, index='api-data')
    if include:
        s = build_match_query(s, include)
    elif query:
        s = s.query('multi_match', query=query)
    else:
        s = s.query('match_all')

    # Search slicing is [start:stop] with an absolute stop, so the page end
    # must be offset by from_index. Using size alone would return too few,
    # or no, hits once from_index > 0.
    hits = s[from_index:from_index + size].execute()
    results = [hit.to_dict() for hit in hits]

    return {
        'statusCode': 200,
        'body': json.dumps({"items": results})
    }
def _validate_fixed_response_action(self, action, i, index): status_code = action.data.get("fixed_response_config._status_code") if status_code is None: raise ParamValidationError( report='Missing required parameter in Actions[%s].FixedResponseConfig: "StatusCode"' % i ) expression = r"^(2|4|5)\d\d$" if not re.match(expression, status_code): raise InvalidStatusCodeActionTypeError( "1 validation error detected: Value '{}' at 'actions.{}.member.fixedResponseConfig.statusCode' failed to satisfy constraint: \ Member must satisfy regular expression pattern: {}".format( status_code, index, expression ) ) content_type = action.data["fixed_response_config._content_type"] if content_type and content_type not in [ "text/plain", "text/css", "text/html", "application/javascript", "application/json", ]: raise InvalidLoadBalancerActionException( "The ContentType must be one of:'text/html', 'application/json', 'application/javascript', 'text/css', 'text/plain'" )
def test_s3_upload_paramvalidationerror(self, mtransaction, mclient):
    """POST /file_export/update returns 500 when S3 upload_fileobj raises ParamValidationError."""
    failing_s3 = mock.Mock()
    failing_s3.upload_fileobj.side_effect = ParamValidationError(report='Some validation error')
    mclient.return_value = failing_s3

    payload = json.dumps(self.location)
    response = self.app_client.post(
        '/file_export/update',
        content_type='application/json',
        data=payload,
    )

    self.assertEqual(response.status_code, 500)
def validate_bucket_name(params, **kwargs):
    """Reject a ``Bucket`` parameter that does not match ``VALID_BUCKET``.

    Params without a ``Bucket`` key are left alone.

    Raises:
        ParamValidationError: If the bucket name fails the regex.
    """
    if 'Bucket' in params:
        bucket = params['Bucket']
        if not VALID_BUCKET.search(bucket):
            raise ParamValidationError(report=(
                'Invalid bucket name "%s": Bucket name must match '
                'the regex "%s"' % (bucket, VALID_BUCKET.pattern)))
def get_query_execution(**kwargs):
    """Mock ``get_query_execution`` that always fails parameter validation.

    All keyword arguments are ignored.

    Raises:
        ParamValidationError: Always, with the report ``'invalid_parameter'``.
    """
    # The previously built (and never used) error-response dict was dead
    # code and has been dropped; only the exception matters to callers.
    raise ParamValidationError(report='invalid_parameter')
def serialize_to_request(self, parameters, operation_model):
    """Validate *parameters* against the operation's input shape, then serialize them.

    Validation is skipped when the operation has no input shape.

    Returns:
        Whatever ``self._serializer.serialize_to_request`` returns for
        *parameters* and *operation_model*.

    Raises:
        ParamValidationError: If the validation report has errors.
    """
    shape = operation_model.input_shape
    if shape is not None:
        validation = self._param_validator.validate(parameters, shape)
        if validation.has_errors():
            raise ParamValidationError(report=validation.generate_report())
    return self._serializer.serialize_to_request(parameters, operation_model)
def validate_bucket_name(params, **kwargs):
    """Require ``Bucket`` to be a valid bucket name or an S3 ARN.

    Params without a ``Bucket`` key are left alone. A value is accepted if it
    matches ``VALID_BUCKET``; otherwise it must match ``VALID_S3_ARN``.

    Raises:
        ParamValidationError: If neither regex matches.
    """
    if 'Bucket' not in params:
        return

    bucket = params['Bucket']
    if VALID_BUCKET.search(bucket) or VALID_S3_ARN.search(bucket):
        return

    raise ParamValidationError(report=(
        'Invalid bucket name "%s": Bucket name must match '
        'the regex "%s" or be an ARN matching the regex "%s"' % (
            bucket, VALID_BUCKET.pattern, VALID_S3_ARN.pattern)))
def test_process_ioc_values_parameter_error(self, log_mock):
    """ThreatIntel - Process IOC Values, ParamValidationError"""
    # Every _query call fails validation, so no IOC values come back and
    # the failure is logged.
    iocs = ['1.1.1.1', '2.2.2.2']
    with patch.object(self._threat_intel, '_query') as patched_query:
        patched_query.side_effect = ParamValidationError(report='BadParams')
        processed = list(self._threat_intel._process_ioc_values(iocs))
        assert_equal(processed, [])
        log_mock.assert_called_with('An error occurred while querying dynamodb table')
def test_crawl_boto_param_exception(self, mock_session):
    """A ParamValidationError from list_roots is caught and logged at WARNING or above."""
    # Re-enable logging in case an earlier test disabled it.
    logging.disable(logging.NOTSET)
    mock_session.client = MagicMock()

    crawler = AWSOrgUnitCrawler(self.account)
    crawler._init_session()
    crawler._client.list_roots.side_effect = ParamValidationError(report="Bad Param")

    with self.assertLogs(logger=crawler_log, level=logging.WARNING):
        crawler.crawl_account_hierarchy()
def test_select_propegates_param_validation_exception(mock):
    """storage_client.select re-raises a ParamValidationError with the S3 CLIENT ERROR prefix.

    The test fails if select() raises nothing, or raises with a different
    message.
    """
    mock.side_effect = ParamValidationError(report='error report')
    expected_error_message = '[S3 CLIENT ERROR]: Parameter validation failed:\nerror report'

    try:
        storage_client.select(STUDY_GUIDE_ID, EXPRESSION)
    except Exception as error:
        assert str(error) == expected_error_message
    else:
        # Without this branch the test would pass silently when select()
        # swallowed the error instead of propagating it.
        raise AssertionError('storage_client.select did not raise')
def test_account_not_found(self) -> None:
    """find_account_by_id logs an ERROR naming the account id and the botocore report."""
    account_id, error_msg = "123456789012", "boom"
    failing_describe = Mock(side_effect=ParamValidationError(report=error_msg))
    boto_orgs = Mock(describe_account=failing_describe)

    with self.assertLogs("AwsOrganizationsClient", level="ERROR") as captured:
        AwsOrganizationsClient(boto_orgs).find_account_by_id(account_id)

    first_line = captured.output[0]
    self.assertIn(account_id, first_line)
    self.assertIn(error_msg, first_line)
def test_thingname_nostr(self, mock):
    """A non-string thing name makes the handler raise ParamValidationError.

    The device shadow must not be updated.
    """
    shadow_config = self.config_shadowget(ParamValidationError(report='UnitTest'))
    mock.configure_mock(**shadow_config)

    with self.assertRaises(ParamValidationError):
        lf.lambda_handler(event=self.lambdaevent, context=None)

    mock.client.return_value.update_thing_shadow.assert_not_called()
def _quote_source_header_from_dict(source_dict):
    """Build a ``bucket/encoded-key[?versionId=...]`` copy-source string.

    Args:
        source_dict (dict): Must contain ``Bucket`` and ``Key``; ``VersionId``
            is optional. The key is percent-encoded with ``/`` kept as a safe
            character.

    Raises:
        ParamValidationError: If ``Bucket`` or ``Key`` is missing.
    """
    try:
        bucket = source_dict['Bucket']
        key = percent_encode(source_dict['Key'], safe=SAFE_CHARS + '/')
        version_id = source_dict.get('VersionId')
    except KeyError as missing:
        raise ParamValidationError(
            report='Missing required parameter: %s' % str(missing))

    header = '%s/%s' % (bucket, key)
    if version_id is None:
        return header
    return header + '?versionId=%s' % version_id
def _validate_connector_args(connector_args): if connector_args is None: return for k, v in connector_args.items(): if k in ['use_dns_cache', 'verify_ssl']: if not isinstance(v, bool): raise ParamValidationError( report='{} value must be a boolean'.format(k)) elif k in ['keepalive_timeout']: if not isinstance(v, float) and not isinstance(v, int): raise ParamValidationError( report='{} value must be a float/int'.format(k)) elif k == 'force_close': if not isinstance(v, bool): raise ParamValidationError( report='{} value must be a boolean'.format(k)) elif k == 'limit': if not isinstance(v, int): raise ParamValidationError( report='{} value must be an int'.format(k)) elif k == 'ssl_context': import ssl if not isinstance(v, ssl.SSLContext): raise ParamValidationError( report='{} must be an SSLContext instance'.format(k)) else: raise ParamValidationError( report='invalid connector_arg:{}'.format(k))
def test_parm_val_exception(self, mock_boto3_client):
    """Test _get_sts_access fail.

    When assume_role raises ParamValidationError, a CRITICAL log is emitted
    and all three credential fields are present but None.
    """
    logging.disable(logging.NOTSET)
    sts_client = Mock()
    sts_client.assume_role.side_effect = ParamValidationError(report="test")
    mock_boto3_client.return_value = sts_client

    with self.assertLogs(level=logging.CRITICAL):
        credentials = _get_sts_access("BAD")

    credential_fields = ("aws_access_key_id", "aws_secret_access_key", "aws_session_token")
    for field in credential_fields:
        self.assertIn(field, credentials)
    for field in credential_fields:
        self.assertIsNone(credentials.get(field))
def _validate_param(params, name, type=None): """Check that the param *name* is found in *params*, and if *type* is set, validate that it has the proper type. *type* may also be a tuple (multiple types) or a list (multiple values to match) """ if name not in params: raise ParamValidationError( report='Missing required parameter in input: "%s"' % name) if type: if isinstance(type, list): _validate_param_enum(params[name], type) else: _validate_param_type(params[name], type)
def build_parameters(self, **kwargs):
    """Validate *kwargs* for this operation and serialize them into a request dict.

    If the operation has an input shape, *kwargs* are first converted to the
    correct casing and validated with ``ParamValidator``. The serializer is
    chosen from the model's ``protocol`` metadata.

    Returns:
        dict: The serialized request.

    Raises:
        ParamValidationError: If validation reports errors.
    """
    model = self._model
    protocol = model.metadata['protocol']

    if model.input_shape is not None:
        self._convert_kwargs_to_correct_casing(kwargs)
        report = ParamValidator().validate(kwargs, model.input_shape)
        if report.has_errors():
            raise ParamValidationError(report=report.generate_report())

    serializer = serialize.create_serializer(protocol)
    return serializer.serialize_to_request(kwargs, model)
def test_s3_upload_paramvalidationerror(self, mtransaction, mclient):
    """POST /file_export/change returns 500 when S3 upload_fileobj raises ParamValidationError."""
    good_token = jwt.encode({'authorities': ['developer', 'tester']}, 'secret', algorithm='HS256')
    # PyJWT < 2 returns bytes from encode(); PyJWT >= 2 returns str, which
    # has no .decode(). Normalise to str for either version.
    if isinstance(good_token, bytes):
        good_token = good_token.decode('utf-8')

    s3_connection_mock = mock.Mock()
    s3_connection_mock.upload_fileobj.side_effect = ParamValidationError(
        report='Some validation error')
    mclient.return_value = s3_connection_mock

    response = self.app_client.post(
        '/file_export/change',
        content_type='application/json',
        headers={'Authorization': 'Bearer {0}'.format(good_token)},
        data=json.dumps(self.location_change))

    self.assertEqual(response.status_code, 500)
def batch_get_image(
    self,
    repository_name,
    registry_id=None,
    image_ids=None,
    accepted_media_types=None,
):
    """Mock ECR ``BatchGetImage``: look up images by digest or tag.

    Args:
        repository_name (str): Key into ``self.repositories``.
        registry_id (str): Only used in the not-found error; falls back to
            ``DEFAULT_REGISTRY_ID``.
        image_ids (list): Dicts with an ``imageDigest`` and/or ``imageTag``.
        accepted_media_types: Accepted for API compatibility; not used here.

    Returns:
        dict: ``images`` holds ``response_batch_get_image`` for every
        repository image matching a requested id. ``failures`` holds one
        ``ImageNotFound`` entry per id that matched nothing.

    Raises:
        RepositoryNotFoundException: If the repository does not exist.
        ParamValidationError: If ``image_ids`` is empty or missing.
    """
    if repository_name in self.repositories:
        repository = self.repositories[repository_name]
    else:
        raise RepositoryNotFoundException(
            repository_name, registry_id or DEFAULT_REGISTRY_ID
        )

    if not image_ids:
        # botocore's ParamValidationError formats its message from the
        # ``report`` keyword; ``msg=`` would make constructing it fail.
        raise ParamValidationError(
            report='Missing required parameter in input: "imageIds"'
        )

    response = {"images": [], "failures": []}

    for image_id in image_ids:
        matches = [
            image
            for image in repository.images
            if (
                "imageDigest" in image_id
                and image.get_image_digest() == image_id["imageDigest"]
            )
            or ("imageTag" in image_id and image.image_tag == image_id["imageTag"])
        ]
        response["images"].extend(image.response_batch_get_image for image in matches)

        if not matches:
            response["failures"].append(
                {
                    "imageId": {"imageTag": image_id.get("imageTag", "null")},
                    "failureCode": "ImageNotFound",
                    "failureReason": "Requested image not found",
                }
            )

    return response
def _validate_response(self, operation_name, service_response): service_model = self.client.meta.service_model operation_model = service_model.operation_model(operation_name) output_shape = operation_model.output_shape # Remove ResponseMetadata so that the validator doesn't attempt to # perform validation on it. response = service_response if 'ResponseMetadata' in response: response = copy.copy(service_response) del response['ResponseMetadata'] if output_shape is not None: validate_parameters(response, output_shape) elif response: # If the output shape is None, that means the response should be # empty apart from ResponseMetadata raise ParamValidationError(report=( "Service response should only contain ResponseMetadata."))