Ejemplo n.º 1
0
class TestCLIArgument(unittest.TestCase):
    """Tests for ``arguments.CLIArgument`` built from a minimal service model."""

    def setUp(self):
        self.service_name = "baz"
        operation_spec = {"name": "SampleOperation", "input": {"shape": "Input"}}
        shape_specs = {
            "StringShape": {"type": "string"},
            "Input": {"type": "structure", "members": {"Foo": {"shape": "StringShape"}}},
        }
        model_spec = {
            "metadata": {"endpointPrefix": "bad"},
            "operations": {"SampleOperation": operation_spec},
            "shapes": shape_specs,
        }
        self.service_model = ServiceModel(model_spec, self.service_name)
        self.operation_model = self.service_model.operation_model("SampleOperation")
        self.argument_model = self.operation_model.input_shape.members["Foo"]
        self.event_emitter = mock.Mock()

    def create_argument(self):
        """Return a CLIArgument for the ``Foo`` member of ``SampleOperation``."""
        model = self.argument_model
        return arguments.CLIArgument(model.name, model, self.operation_model, self.event_emitter)

    def test_unpack_uses_service_name_in_event(self):
        # Whatever emit() returns is what add_to_params works with.
        self.event_emitter.emit.return_value = ["value"]
        self.create_argument().add_to_params({}, "value")
        # First positional argument of the most recent emit() call.
        positional_args = self.event_emitter.emit.call_args[0]
        emitted_event = positional_args[0]
        wanted_event = "process-cli-arg.%s.%s" % (self.service_name, "sample-operation")
        self.assertEqual(emitted_event, wanted_event)
Ejemplo n.º 2
0
class TestBinaryTypes(unittest.TestCase):
    """Blob members are base64-encoded by the query-protocol serializer."""

    def setUp(self):
        self.model = {
            'metadata': {'protocol': 'query', 'apiVersion': '2014-01-01'},
            'documentation': '',
            'operations': {
                'TestOperation': {
                    'name': 'TestOperation',
                    'http': {
                        'method': 'POST',
                        'requestUri': '/',
                    },
                    'input': {'shape': 'InputShape'},
                }
            },
            'shapes': {
                'InputShape': {
                    'type': 'structure',
                    'members': {
                        'Blob': {'shape': 'BlobType'},
                    }
                },
                'BlobType': {
                    'type': 'blob',
                }
            }
        }
        self.service_model = ServiceModel(self.model)

    def serialize_to_request(self, input_params):
        """Serialize ``input_params`` as a ``TestOperation`` request.

        The serializer is chosen from the model's ``metadata['protocol']``.
        """
        request_serializer = serialize.create_serializer(
            self.service_model.metadata['protocol'])
        return request_serializer.serialize_to_request(
            input_params, self.service_model.operation_model('TestOperation'))

    def assert_serialized_blob_equals(self, request, blob_bytes):
        """Assert that ``request['body']['Blob']`` is base64 of ``blob_bytes``."""
        # This method handles all the details of the base64 decoding.
        encoded = base64.b64encode(blob_bytes)
        # Now the serializers actually have the base64 encoded contents
        # as str types so we need to decode back.  We know that this is
        # ascii so it's safe to use the ascii encoding.
        expected = encoded.decode('ascii')
        self.assertEqual(request['body']['Blob'], expected)

    def test_blob_accepts_bytes_type(self):
        body = b'bytes body'
        request = self.serialize_to_request(input_params={'Blob': body})
        # Check the serialized value rather than only that serialization
        # did not raise, matching the str/unicode tests below.
        self.assert_serialized_blob_equals(request, blob_bytes=body)

    def test_blob_accepts_str_type(self):
        body = u'ascii text'
        request = self.serialize_to_request(input_params={'Blob': body})
        self.assert_serialized_blob_equals(
            request, blob_bytes=body.encode('ascii'))

    def test_blob_handles_unicode_chars(self):
        body = u'\u2713'
        request = self.serialize_to_request(input_params={'Blob': body})
        self.assert_serialized_blob_equals(
            request, blob_bytes=body.encode('utf-8'))
Ejemplo n.º 3
0
    def test_no_output(self):
        """An operation that declares no output shape leaves ``parsed`` as ``{}``."""
        service_model = ServiceModel({
            'operations': {
                'SampleOperation': {
                    'name': 'SampleOperation',
                    # Must name a shape that is actually defined under
                    # 'shapes' so the fixture model is self-consistent.
                    'input': {
                        'shape': 'SampleOperationInput'
                    },
                }
            },
            'shapes': {
                'SampleOperationInput': {
                    'type': 'structure',
                    'members': {},
                },
                'String': {
                    'type': 'string'
                },
            },
        })
        operation_model = service_model.operation_model('SampleOperation')

        parsed = {}
        self.injector.inject_attribute_value_output(parsed=parsed,
                                                    model=operation_model)
        assert parsed == {}
Ejemplo n.º 4
0
    def test_validate_ignores_response_metadata(self):
        """The stubber validates an empty payload once ResponseMetadata is set aside.

        The response dict handed to ``add_response`` must also stay unmodified.
        """
        service_model = ServiceModel({
            'documentation': '',
            'operations': {
                'foo': {
                    'name': 'foo',
                    'input': {'shape': 'StringShape'},
                    'output': {'shape': 'StringShape'},
                },
            },
            'shapes': {'StringShape': {'type': 'string'}},
        })
        first_op = service_model.operation_names[0]
        expected_shape = service_model.operation_model(first_op).output_shape
        self.client.meta.service_model = service_model

        service_response = {'ResponseMetadata': {'foo': 'bar'}}
        self.stubber.add_response('TestOperation', service_response)

        # Only an empty dict should reach parameter validation.
        self.validate_parameters_mock.assert_called_with({}, expected_shape)
        # Make sure service response hasn't been mutated
        self.assertEqual(service_response, {'ResponseMetadata': {'foo': 'bar'}})
Ejemplo n.º 5
0
class TestCLIArgument(unittest.TestCase):
    """Exercise CLIArgument and ListArgument against a one-operation model."""

    def setUp(self):
        self.service_name = 'baz'
        operations = {
            'SampleOperation': {
                'name': 'SampleOperation',
                'input': {'shape': 'Input'},
            },
        }
        shapes = {
            'StringShape': {'type': 'string'},
            'Input': {
                'type': 'structure',
                'members': {'Foo': {'shape': 'StringShape'}},
            },
        }
        self.service_model = ServiceModel(
            {
                'metadata': {'endpointPrefix': 'bad'},
                'operations': operations,
                'shapes': shapes,
            },
            self.service_name,
        )
        self.operation_model = self.service_model.operation_model('SampleOperation')
        input_members = self.operation_model.input_shape.members
        self.argument_model = input_members['Foo']
        self.event_emitter = mock.Mock()

    def create_argument(self):
        """Return a CLIArgument wrapping the ``Foo`` member model."""
        return arguments.CLIArgument(
            self.argument_model.name,
            self.argument_model,
            self.operation_model,
            self.event_emitter,
        )

    def test_unpack_uses_service_name_in_event(self):
        emitter = self.event_emitter
        emitter.emit.return_value = ['value']
        self.create_argument().add_to_params({}, 'value')
        event_name = emitter.emit.call_args[0][0]
        expected = 'process-cli-arg.%s.%s' % (self.service_name,
                                              'sample-operation')
        self.assertEqual(event_name, expected)

    def test_list_type_has_correct_nargs_value(self):
        # Only the ListArgument type matters here; the other values are
        # arbitrary placeholders.
        list_arg = arguments.ListArgument(
            name='test-nargs',
            argument_model=self.argument_model,
            operation_model=None,
            event_emitter=self.event_emitter,
            is_required=True,
            serialized_name='TestNargs',
        )
        self.assertEqual(list_arg.nargs, '*')
Ejemplo n.º 6
0
    def test_validate_ignores_response_metadata(self):
        # Validation should receive the response without ResponseMetadata,
        # and the caller's response dict must survive unmodified.
        metadata_only = {'ResponseMetadata': {'foo': 'bar'}}
        model_spec = {
            'documentation': '',
            'operations': {
                'foo': {
                    'name': 'foo',
                    'input': {'shape': 'StringShape'},
                    'output': {'shape': 'StringShape'},
                },
            },
            'shapes': {'StringShape': {'type': 'string'}},
        }
        service_model = ServiceModel(model_spec)
        op_model = service_model.operation_model(service_model.operation_names[0])
        output_shape = op_model.output_shape

        self.client.meta.service_model = service_model
        self.stubber.add_response('TestOperation', metadata_only)
        self.validate_parameters_mock.assert_called_with({}, output_shape)

        expected_unchanged = {'ResponseMetadata': {'foo': 'bar'}}
        self.assertEqual(metadata_only, expected_unchanged)
Ejemplo n.º 7
0
    def test_RdbSerializer(self):
        """RdbSerializer builds a form-encoded POST with Action, params and Version."""
        # NOTE(review): the operation's "name" ("rdbOperation") differs from its
        # key ("RdbOperation"), and the expected Action uses the key -- confirm
        # this mismatch is intentional.
        operation = {
            "http": {"method": "POST", "requestUri": "/"},
            "input": {"shape": "RdbOperationRequest"},
            "name": "rdbOperation",
            "output": {"shape": "RdbOperationResult"},
        }
        shapes = {
            "RdbOperationRequest": {
                "members": {"Parameter": {"locationName": "Parameter", "shape": "String"}},
                "name": "RdbOperationRequest",
                "type": "structure",
            },
            "RdbOperationResult": {
                "members": {"Response": {"locationName": "Response", "shape": "String"}},
                "name": "RdbOperationResult",
                "type": "structure",
            },
            "String": {"name": "String", "type": "string"},
        }
        rdb_model = {
            "metadata": self.rdb_model_metadata,
            "operations": {"RdbOperation": operation},
            "shapes": shapes,
        }
        rdb_service_model = ServiceModel(rdb_model)
        rdb_serializer = serialize.RdbSerializer()
        res = rdb_serializer.serialize_to_request(
            {"Parameter": "test"}, rdb_service_model.operation_model("RdbOperation"))

        expected = {
            "body": {"Action": "RdbOperation", "Parameter": "test", "Version": "2013-05-15N2013-12-16"},
            "headers": {"Content-Type": "application/x-www-form-urlencoded; charset=utf-8"},
            "method": "POST",
            "query_string": "",
            "url_path": "/",
        }
        for field, value in expected.items():
            assert res[field] == value, field

    def test_ComputingSerializer(self):
        """ComputingSerializer builds a form-encoded POST to /api/ with Action and Version."""

        def structure(name, member):
            # Structure shape with one String member whose locationName equals its name.
            return {
                "members": {member: {"locationName": member, "shape": "String"}},
                "name": name,
                "type": "structure",
            }

        computing_model = {
            "metadata": self.computing_model_metadata,
            "operations": {
                "ComputingOperation": {
                    "http": {"method": "POST", "requestUri": "/api/"},
                    "input": {"shape": "ComputingOperationRequest"},
                    "name": "ComputingOperation",
                    "output": {"shape": "ComputingOperationResult"},
                }
            },
            "shapes": {
                "ComputingOperationRequest": structure("ComputingOperationRequest", "Parameter"),
                "ComputingOperationResult": structure("ComputingOperationResult", "Response"),
                "String": {"name": "String", "type": "string"},
            },
        }

        service_model = ServiceModel(computing_model)
        serializer = serialize.ComputingSerializer()
        operation_model = service_model.operation_model("ComputingOperation")
        request = serializer.serialize_to_request({"Parameter": "test"}, operation_model)

        assert request["body"] == {"Action": "ComputingOperation", "Parameter": "test", "Version": "3.0"}
        assert request["headers"] == {"Content-Type": "application/x-www-form-urlencoded; charset=utf-8"}
        assert request["method"] == "POST"
        assert request["query_string"] == ""
        assert request["url_path"] == "/api/"
Ejemplo n.º 9
0
class TestRestXMLUnicodeSerialization(unittest.TestCase):
    """The rest-xml serializer must handle non-ASCII text in a list payload."""

    def setUp(self):
        operation = {
            'name': 'TestOperation',
            'http': {'method': 'POST', 'requestUri': '/'},
            'input': {'shape': 'InputShape'},
        }
        input_shape = {
            'type': 'structure',
            'members': {'Foo': {'shape': 'FooShape', 'locationName': 'Foo'}},
            'payload': 'Foo',
        }
        list_shape = {'type': 'list', 'member': {'shape': 'StringShape'}}
        self.model = {
            'metadata': {'protocol': 'rest-xml', 'apiVersion': '2014-01-01'},
            'documentation': '',
            'operations': {'TestOperation': operation},
            'shapes': {
                'InputShape': input_shape,
                'FooShape': list_shape,
                'StringShape': {'type': 'string'},
            },
        }
        self.service_model = ServiceModel(self.model)

    def serialize_to_request(self, input_params):
        """Serialize ``input_params`` as a ``TestOperation`` request."""
        serializer = serialize.create_serializer(self.service_model.metadata['protocol'])
        operation = self.service_model.operation_model('TestOperation')
        return serializer.serialize_to_request(input_params, operation)

    def test_restxml_serializes_unicode(self):
        # Japanese text mixing kanji, hiragana and a full-width latin letter.
        japanese = u'\u65e5\u672c\u8a9e\u3067\u304a\uff4b'
        try:
            self.serialize_to_request({'Foo': [japanese]})
        except UnicodeEncodeError:
            self.fail("RestXML serializer failed to serialize unicode text.")
Ejemplo n.º 10
0
class TestJSONTimestampSerialization(unittest.TestCase):
    """Epoch-origin timestamp inputs serialize to 0 in the JSON protocol body."""

    def setUp(self):
        metadata = {
            'protocol': 'json',
            'apiVersion': '2014-01-01',
            'jsonVersion': '1.1',
            'targetPrefix': 'foo',
        }
        self.model = {
            'metadata': metadata,
            'documentation': '',
            'operations': {
                'TestOperation': {
                    'name': 'TestOperation',
                    'http': {'method': 'POST', 'requestUri': '/'},
                    'input': {'shape': 'InputShape'},
                },
            },
            'shapes': {
                'InputShape': {
                    'type': 'structure',
                    'members': {'Timestamp': {'shape': 'TimestampType'}},
                },
                'TimestampType': {'type': 'timestamp'},
            },
        }
        self.service_model = ServiceModel(self.model)

    def serialize_to_request(self, input_params):
        """Serialize ``input_params`` as a ``TestOperation`` request."""
        serializer = serialize.create_serializer(self.service_model.metadata['protocol'])
        operation = self.service_model.operation_model('TestOperation')
        return serializer.serialize_to_request(input_params, operation)

    def _serialized_timestamp(self, value):
        """Serialize ``value`` as the Timestamp member and read it back from the JSON body."""
        body_bytes = self.serialize_to_request({'Timestamp': value})['body']
        return json.loads(body_bytes.decode('utf-8'))['Timestamp']

    def test_accepts_iso_8601_format(self):
        self.assertEqual(self._serialized_timestamp('1970-01-01T00:00:00'), 0)

    def test_accepts_epoch(self):
        # Both a numeric string and an integer 0 are accepted.
        self.assertEqual(self._serialized_timestamp('0'), 0)
        self.assertEqual(self._serialized_timestamp(0), 0)

    def test_accepts_partial_iso_format(self):
        self.assertEqual(self._serialized_timestamp('1970-01-01'), 0)
class TestJSONTimestampSerialization(unittest.TestCase):
    """Timestamp inputs for the JSON protocol serialize to 0 at the epoch.

    NOTE(review): a class with this same name and identical tests is also
    defined earlier in this file; this definition rebinds the name, so only
    one copy is collected. Presumably a copy-paste duplicate -- TODO confirm
    and remove one of them.
    """

    def setUp(self):
        timestamp_member = {'shape': 'TimestampType'}
        self.model = {
            'metadata': {'protocol': 'json', 'apiVersion': '2014-01-01',
                         'jsonVersion': '1.1', 'targetPrefix': 'foo'},
            'documentation': '',
            'operations': {
                'TestOperation': {
                    'name': 'TestOperation',
                    'http': {'method': 'POST', 'requestUri': '/'},
                    'input': {'shape': 'InputShape'},
                }
            },
            'shapes': {
                'InputShape': {'type': 'structure',
                               'members': {'Timestamp': timestamp_member}},
                'TimestampType': {'type': 'timestamp'},
            },
        }
        self.service_model = ServiceModel(self.model)

    def serialize_to_request(self, input_params):
        """Serialize ``input_params`` for ``TestOperation`` with the model's protocol."""
        protocol = self.service_model.metadata['protocol']
        operation = self.service_model.operation_model('TestOperation')
        return serialize.create_serializer(protocol).serialize_to_request(
            input_params, operation)

    @staticmethod
    def _decode_body(request):
        """Parse the UTF-8 JSON body of a serialized request."""
        return json.loads(request['body'].decode('utf-8'))

    def test_accepts_iso_8601_format(self):
        request = self.serialize_to_request({'Timestamp': '1970-01-01T00:00:00'})
        self.assertEqual(self._decode_body(request)['Timestamp'], 0)

    def test_accepts_epoch(self):
        as_string = self.serialize_to_request({'Timestamp': '0'})
        self.assertEqual(self._decode_body(as_string)['Timestamp'], 0)
        # Can also be an integer 0.
        as_int = self.serialize_to_request({'Timestamp': 0})
        self.assertEqual(self._decode_body(as_int)['Timestamp'], 0)

    def test_accepts_partial_iso_format(self):
        request = self.serialize_to_request({'Timestamp': '1970-01-01'})
        self.assertEqual(self._decode_body(request)['Timestamp'], 0)
Ejemplo n.º 12
0
def handle_resource_action(
    client_name, class_name, action, method_path, fn_name, service_model: ServiceModel, shapes_path, resource_path
):
    """Build the documentation pieces for a single resource action.

    :param client_name: passed through to ``get_signature_string``.
    :param class_name: passed through to ``get_signature_string``.
    :param action: action definition; ``action.request.operation`` names the
        operation, ``action.request.params`` lists bound params (by ``.target``),
        and ``action.resource`` (may be falsy) describes the returned resource.
    :param method_path: link target used in the returned list item.
    :param fn_name: method name shown in the headline, list item and signature.
    :param service_model: model used to look up the operation and output shape.
    :param shapes_path: passed through to the parameter/signature helpers.
    :param resource_path: path of the owning resource; its last ``/`` segment
        names the sub-resource when ``is_sub_resource`` says it is one.
    :return: ``(list_item, signature, documentation, headline)`` where
        ``list_item`` and ``headline`` are Markdown strings.
    """
    operation_model = service_model.operation_model(action.request.operation)
    input_shape = operation_model.input_shape
    all_members = input_shape.members if input_shape else {}

    resource = action.resource
    # Only resolve an output shape if the resource's shape exists in the service model.
    has_output_shape = bool(resource) and resource.model.shape in service_model.shape_names
    output_shape = service_model.shape_for(resource.model.shape) if has_output_shape else None

    output_name = resource.model.name if resource else None
    if output_name:
        new_path = get_resource_path_for(output_name, resource_path)
        append_return_type = f' -> [{output_name}]({new_path})'
    else:
        append_return_type = ''

    sub_res_var_name = None
    if is_sub_resource(resource_path):
        sub_res_name = resource_path[resource_path.rindex('/') + 1 :]
        sub_res_var_name = get_variable_name_for(sub_res_name)
        # Members that the request already maps via its params are excluded;
        # a set keeps the per-member membership test O(1).
        request_params = {param.target for param in action.request.params}
        include_params = {name: value for name, value in all_members.items() if name not in request_params}
        param_str = get_param_str_params(input_shape, shapes_path, include_params)
    else:
        required = input_shape.required_members if input_shape else ()
        include_params = {name: value for name, value in all_members.items() if name in required}
        param_str = get_param_str(input_shape, shapes_path)

    signature = get_signature_string(
        client_name,
        class_name,
        input_shape,
        output_shape,
        fn_name,
        param_str,
        shapes_path,
        append_return_type,
        sub_res_var_name,
        all_members,
        include_params,
    )
    documentation = get_operation_documentation(operation_model, service_model)

    headline = f'# {fn_name} action'
    list_item = f'-  **[{fn_name}]({method_path})**({param_str}){append_return_type}'
    return list_item, signature, documentation, headline
Ejemplo n.º 13
0
 def __init__(self, service: ServiceModel) -> None:
     super().__init__(service)
     # When parsing a request, we need to look up the operation based on the HTTP method and URI.
     # Therefore we create a mapping (method -> requestUri -> OperationModel) when the
     # parser is initialized.
     # The inner mapping is a plain dict. A defaultdict(OperationModel) would call
     # OperationModel() with no arguments on a missing URI, which presumably raises
     # TypeError (the model constructor needs arguments) instead of a clean KeyError.
     self.operation_lookup = defaultdict(dict)
     for operation in service.operation_names:
         operation_model = service.operation_model(operation)
         http = operation_model.http
         # Operations without an HTTP binding cannot be routed by method/URI.
         if http:
             method = http.get("method")
             request_uri = http.get("requestUri")
             self.operation_lookup[method][request_uri] = operation_model
Ejemplo n.º 14
0
class TestRestXMLUnicodeSerialization(unittest.TestCase):
    """Serializing non-ASCII list items with rest-xml must not raise UnicodeEncodeError."""

    def setUp(self):
        self.model = dict(
            metadata={'protocol': 'rest-xml', 'apiVersion': '2014-01-01'},
            documentation='',
            operations={
                'TestOperation': {
                    'name': 'TestOperation',
                    'http': {'method': 'POST', 'requestUri': '/'},
                    'input': {'shape': 'InputShape'},
                },
            },
            shapes={
                'InputShape': {
                    'type': 'structure',
                    'members': {
                        'Foo': {'shape': 'FooShape', 'locationName': 'Foo'},
                    },
                    'payload': 'Foo',
                },
                'FooShape': {'type': 'list', 'member': {'shape': 'StringShape'}},
                'StringShape': {'type': 'string'},
            },
        )
        self.service_model = ServiceModel(self.model)

    def serialize_to_request(self, input_params):
        """Serialize ``input_params`` for ``TestOperation`` with the model's protocol."""
        model = self.service_model
        return serialize.create_serializer(
            model.metadata['protocol']
        ).serialize_to_request(input_params, model.operation_model('TestOperation'))

    def test_restxml_serializes_unicode(self):
        params = dict(Foo=[u'\u65e5\u672c\u8a9e\u3067\u304a\uff4b'])
        try:
            self.serialize_to_request(params)
        except UnicodeEncodeError:
            self.fail("RestXML serializer failed to serialize unicode text.")
Ejemplo n.º 15
0
class TestCLIArgument(unittest.TestCase):
    """CLIArgument should include the service name in the events it emits."""

    def setUp(self):
        self.service_name = 'baz'
        model = {}
        model['metadata'] = {'endpointPrefix': 'bad'}
        model['operations'] = {
            'SampleOperation': {'name': 'SampleOperation', 'input': {'shape': 'Input'}},
        }
        model['shapes'] = {
            'StringShape': {'type': 'string'},
            'Input': {'type': 'structure', 'members': {'Foo': {'shape': 'StringShape'}}},
        }
        self.service_model = ServiceModel(model, self.service_name)
        self.operation_model = self.service_model.operation_model('SampleOperation')
        self.argument_model = self.operation_model.input_shape.members['Foo']
        self.event_emitter = mock.Mock()

    def create_argument(self):
        """Return a CLIArgument for the sample operation's ``Foo`` member."""
        arg_model = self.argument_model
        return arguments.CLIArgument(
            arg_model.name, arg_model, self.operation_model, self.event_emitter)

    def test_unpack_uses_service_name_in_event(self):
        self.event_emitter.emit.return_value = ['value']
        target_params = {}
        self.create_argument().add_to_params(target_params, 'value')
        last_call_args, _ = self.event_emitter.emit.call_args
        self.assertEqual(
            last_call_args[0],
            'process-cli-arg.%s.%s' % (self.service_name, 'sample-operation'),
        )
Ejemplo n.º 16
0
class TestBinaryTypesJSON(unittest.TestCase):
    """Blob members serialized with the JSON protocol."""

    def setUp(self):
        self.model = {
            'metadata': {
                'protocol': 'json',
                'apiVersion': '2014-01-01',
                'jsonVersion': '1.1',
                'targetPrefix': 'foo'
            },
            'documentation': '',
            'operations': {
                'TestOperation': {
                    'name': 'TestOperation',
                    'http': {
                        'method': 'POST',
                        'requestUri': '/',
                    },
                    'input': {
                        'shape': 'InputShape'
                    },
                }
            },
            'shapes': {
                'InputShape': {
                    'type': 'structure',
                    'members': {
                        'Blob': {
                            'shape': 'BlobType'
                        },
                    }
                },
                'BlobType': {
                    'type': 'blob',
                }
            }
        }
        self.service_model = ServiceModel(self.model)

    def serialize_to_request(self, input_params):
        """Serialize ``input_params`` as a ``TestOperation`` request."""
        request_serializer = serialize.create_serializer(
            self.service_model.metadata['protocol'])
        return request_serializer.serialize_to_request(
            input_params, self.service_model.operation_model('TestOperation'))

    def test_blob_accepts_bytes_type(self):
        import base64
        import json

        body = b'bytes body'
        request = self.serialize_to_request(input_params={'Blob': body})
        # Check the serialized value, not merely that serialization did not
        # raise: the blob is expected base64-encoded as text in the JSON body.
        serialized = json.loads(request['body'].decode('utf-8'))
        self.assertEqual(serialized['Blob'],
                         base64.b64encode(body).decode('ascii'))
Ejemplo n.º 17
0
def _create_service_map(service: ServiceModel) -> Map:
    """
    Creates a Werkzeug Map object with all rules necessary for the specific service.

    Operations are grouped by their (request path, HTTP method). A group holding a
    single operation gets a ``_StrictMethodRule``; a group holding several gets a
    ``_RequestMatchingRule`` so the ambiguity can be resolved with more request data.

    :param service: botocore service model to create the rules for
    :return: a Map instance which is used to perform the in-service operation routing
    """
    from typing import Tuple  # only needed for the path_index annotation below

    operation_models = [
        service.operation_model(op_name) for op_name in service.operation_names
    ]

    # group all operations by their path and method
    # (the key type is Tuple[str, str]; the former `Dict[(str, str), ...]` was not a valid type)
    path_index: Dict[Tuple[str, str], List[_HttpOperation]] = defaultdict(list)
    for operation_model in operation_models:
        http_op = _HttpOperation.from_operation(operation_model)
        path_index[(http_op.path, http_op.method)].append(http_op)

    rules = []
    # create a matching rule for each (path, method) combination
    for (path, method), http_ops in path_index.items():
        # translate the requestUri to a Werkzeug rule string
        rule_string = _path_param_regex.sub(
            _transform_path_params_to_rule_vars, path)

        if len(http_ops) == 1:
            # if there is only a single operation for a (path, method) combination,
            # the default Werkzeug rule can be used directly (this is the case for most rules)
            rules.append(
                _StrictMethodRule(string=rule_string,
                                  method=method,
                                  endpoint=http_ops[0].operation))  # type: ignore
        else:
            # if there is an ambiguity with only the (path, method) combination,
            # a custom rule - which can use additional request metadata - needs to be used
            rules.append(
                _RequestMatchingRule(string=rule_string,
                                     method=method,
                                     operations=http_ops))

    return Map(
        rules=rules,
        strict_slashes=False,
        merge_slashes=False,
        converters={"path": GreedyPathConverter},
    )
Ejemplo n.º 18
0
    def test_decode_json_policy(self):
        """json_decode_policies decodes the policy Document string in place.

        A parsed response that lacks the Document member is left unchanged.
        """
        policy_output = {
            'type': 'structure',
            'members': {
                'Document': {'shape': 'policyDocumentType'},
                'Other': {'shape': 'stringType'},
            },
        }
        service_def = {
            'operations': {'Foo': {'output': {'shape': 'PolicyOutput'}}},
            'shapes': {
                'PolicyOutput': policy_output,
                'policyDocumentType': {'type': 'string'},
                'stringType': {'type': 'string'},
            },
        }
        op_model = ServiceModel(service_def).operation_model('Foo')

        parsed = {'Document': '{"foo": "foobarbaz"}', 'Other': 'bar'}
        handlers.json_decode_policies(parsed, op_model)
        self.assertEqual(parsed['Document'], {'foo': 'foobarbaz'})

        no_document = {'Other': 'bar'}
        handlers.json_decode_policies(no_document, op_model)
        self.assertEqual(no_document, {'Other': 'bar'})
Ejemplo n.º 19
0
    def test_decode_json_policy(self):
        """Policy documents are JSON-decoded in place; responses without one are untouched."""
        model = ServiceModel(
            {
                "operations": {"Foo": {"output": {"shape": "PolicyOutput"}}},
                "shapes": {
                    "PolicyOutput": {
                        "type": "structure",
                        "members": {
                            "Document": {"shape": "policyDocumentType"},
                            "Other": {"shape": "stringType"},
                        },
                    },
                    "policyDocumentType": {"type": "string"},
                    "stringType": {"type": "string"},
                },
            }
        )
        op_model = model.operation_model("Foo")

        with_document = {"Document": '{"foo": "foobarbaz"}', "Other": "bar"}
        handlers.json_decode_policies(with_document, op_model)
        self.assertEqual(with_document["Document"], {"foo": "foobarbaz"})

        without_document = {"Other": "bar"}
        handlers.json_decode_policies(without_document, op_model)
        self.assertEqual(without_document, {"Other": "bar"})
Ejemplo n.º 20
0
    def test_decode_json_policy(self):
        # json_decode_policies should replace the JSON string held in
        # 'Document' with its decoded value, and leave a response that has
        # no 'Document' member as it was.
        members = {
            'Document': {'shape': 'policyDocumentType'},
            'Other': {'shape': 'stringType'},
        }
        shapes = {
            'PolicyOutput': {'type': 'structure', 'members': members},
            'policyDocumentType': {'type': 'string'},
            'stringType': {'type': 'string'},
        }
        operations = {'Foo': {'output': {'shape': 'PolicyOutput'}}}
        model = ServiceModel({'operations': operations, 'shapes': shapes})
        op_model = model.operation_model('Foo')

        response = dict(Document='{"foo": "foobarbaz"}', Other='bar')
        handlers.json_decode_policies(response, op_model)
        self.assertEqual(response['Document'], {'foo': 'foobarbaz'})

        response_without_policy = dict(Other='bar')
        handlers.json_decode_policies(response_without_policy, op_model)
        self.assertEqual(response_without_policy, {'Other': 'bar'})
Ejemplo n.º 21
0
class TestBinaryTypesJSON(unittest.TestCase):
    """Serialization of blob members under the JSON protocol."""

    def setUp(self):
        metadata = {
            'protocol': 'json',
            'apiVersion': '2014-01-01',
            'jsonVersion': '1.1',
            'targetPrefix': 'foo',
        }
        operation = {
            'name': 'TestOperation',
            'http': {'method': 'POST', 'requestUri': '/'},
            'input': {'shape': 'InputShape'},
        }
        shapes = {
            'InputShape': {
                'type': 'structure',
                'members': {'Blob': {'shape': 'BlobType'}},
            },
            'BlobType': {'type': 'blob'},
        }
        self.model = {
            'metadata': metadata,
            'documentation': '',
            'operations': {'TestOperation': operation},
            'shapes': shapes,
        }
        self.service_model = ServiceModel(self.model)

    def serialize_to_request(self, input_params):
        # Serialize input_params for 'TestOperation' with the model's protocol.
        protocol = self.service_model.metadata['protocol']
        serializer = serialize.create_serializer(protocol)
        operation_model = self.service_model.operation_model('TestOperation')
        return serializer.serialize_to_request(input_params, operation_model)

    def test_blob_accepts_bytes_type(self):
        # Passes as long as serializing a raw bytes blob does not raise.
        self.serialize_to_request(input_params={'Blob': b'bytes body'})
Ejemplo n.º 22
0
 def __init__(self, service: ServiceModel) -> None:
     super().__init__(service)
     # Incoming requests are matched to an operation by HTTP method and then
     # by request URI. URIs may contain path parameters, so each method maps
     # to an ordered dict keyed by a URI regex pattern (from
     # ``_get_request_uri_regex``) instead of a literal path.
     self.operation_lookup: DefaultDict[str, OrderedDict[
         Pattern[str], OperationModel]] = defaultdict(OrderedDict)

     all_operations = [
         service.operation_model(name) for name in service.operation_names
     ]
     # Register operations with the longest normalized request URI first so
     # that a more specific route (e.g. /fuu/{bar}/baz) precedes a shorter
     # one (/fuu/{bar}) whose greedy regex would otherwise match it too.
     all_operations.sort(key=self._get_normalized_request_uri_length,
                         reverse=True)
     for operation_model in all_operations:
         http_method = operation_model.http.get("method")
         uri_pattern = self._get_request_uri_regex(operation_model)
         self.operation_lookup[http_method][uri_pattern] = operation_model
Ejemplo n.º 23
0
    def test_no_output(self):
        # An operation that declares only an input (no output shape) must
        # leave the parsed response untouched.
        service_model = ServiceModel({
            'operations': {
                'SampleOperation': {
                    'name': 'SampleOperation',
                    # Must name a shape that is actually defined in 'shapes'.
                    'input': {'shape': 'SampleOperationInput'},
                }
            },
            'shapes': {
                'SampleOperationInput': {
                    'type': 'structure',
                    'members': {}
                },
                'String': {
                    'type': 'string'
                }
            }
        })
        operation_model = service_model.operation_model('SampleOperation')

        parsed = {}
        self.injector.inject_attribute_value_output(
            parsed=parsed, model=operation_model)
        self.assertEqual(parsed, {})
Ejemplo n.º 24
0
class TestEndpointDiscoveryHandler(BaseEndpointDiscoveryTest):
    """Unit tests for EndpointDiscoveryHandler driven by a mocked manager."""

    def setUp(self):
        super().setUp()
        self.manager = mock.Mock(spec=EndpointDiscoveryManager)
        self.handler = EndpointDiscoveryHandler(self.manager)
        self.service_model = ServiceModel(self.service_description)

    @staticmethod
    def _request_with_empty_identifiers():
        # A request whose context holds an empty discovery identifier map.
        req = AWSRequest()
        req.context = {'discovery': {'identifiers': {}}}
        return req

    def _assert_described_test_operation(self):
        self.manager.describe_endpoint.assert_called_with(
            Operation='TestOperation', Identifiers={})

    def test_register_handler(self):
        emitter = mock.Mock(spec=HierarchicalEmitter)
        self.handler.register(emitter, 'foo-bar')
        expected_registrations = [
            ('before-parameter-build.foo-bar',
             self.handler.gather_identifiers),
            ('needs-retry.foo-bar', self.handler.handle_retries),
        ]
        for event_name, callback in expected_registrations:
            emitter.register.assert_any_call(event_name, callback)
        emitter.register_first.assert_called_with(
            'request-created.foo-bar', self.handler.discover_endpoint)

    def test_discover_endpoint(self):
        req = self._request_with_empty_identifiers()
        self.manager.describe_endpoint.return_value = 'https://new.foo'
        self.handler.discover_endpoint(req, 'TestOperation')
        self.assertEqual(req.url, 'https://new.foo')
        self._assert_described_test_operation()

    def test_discover_endpoint_fails(self):
        req = self._request_with_empty_identifiers()
        req.url = 'old.com'
        self.manager.describe_endpoint.return_value = None
        self.handler.discover_endpoint(req, 'TestOperation')
        # No endpoint discovered: the original URL is kept.
        self.assertEqual(req.url, 'old.com')
        self._assert_described_test_operation()

    def test_discover_endpoint_no_protocol(self):
        req = self._request_with_empty_identifiers()
        self.manager.describe_endpoint.return_value = 'new.foo'
        self.handler.discover_endpoint(req, 'TestOperation')
        # A scheme-less endpoint is expected to gain an https:// prefix.
        self.assertEqual(req.url, 'https://new.foo')
        self._assert_described_test_operation()

    def test_inject_no_context(self):
        req = AWSRequest(url='https://original.foo')
        self.handler.discover_endpoint(req, 'TestOperation')
        self.assertEqual(req.url, 'https://original.foo')
        self.manager.describe_endpoint.assert_not_called()

    def test_gather_identifiers(self):
        ctx = {}
        expected_ids = {'Foo': 'value1', 'Bar': 'value2'}
        operation = self.service_model.operation_model('TestDiscoveryRequired')
        self.manager.gather_identifiers.return_value = expected_ids
        self.handler.gather_identifiers(
            {'Foo': 'value1', 'Nested': {'Bar': 'value2'}}, operation, ctx)
        self.assertEqual(ctx['discovery']['identifiers'], expected_ids)

    def test_gather_identifiers_not_discoverable(self):
        ctx = {}
        operation = self.service_model.operation_model('DescribeEndpoints')
        self.handler.gather_identifiers({}, operation, ctx)
        self.assertEqual(ctx, {})

    def test_discovery_disabled_but_required(self):
        operation = self.service_model.operation_model('TestDiscoveryRequired')
        with self.assertRaises(EndpointDiscoveryRequired):
            block_endpoint_discovery_required_operations(operation)

    def test_discovery_disabled_but_optional(self):
        ctx = {}
        operation = self.service_model.operation_model('TestDiscoveryOptional')
        block_endpoint_discovery_required_operations(operation, context=ctx)
        self.assertEqual(ctx, {})

    def test_does_not_retry_no_response(self):
        self.assertIsNone(self.handler.handle_retries(None, None, None))

    def test_does_not_retry_other_errors(self):
        ok_response = (None, {'ResponseMetadata': {'HTTPStatusCode': 200}})
        self.assertIsNone(self.handler.handle_retries(None, ok_response, None))

    def test_does_not_retry_if_no_context(self):
        status_421 = (None, {'ResponseMetadata': {'HTTPStatusCode': 421}})
        self.assertIsNone(
            self.handler.handle_retries({'context': {}}, status_421, None))

    def _assert_retries(self, parsed_response):
        # Expects a retry value of 0 and deletion of the cached endpoint.
        request_dict = {'context': {'discovery': {'identifiers': {}}}}
        operation = self.service_model.operation_model('TestDiscoveryOptional')
        retry = self.handler.handle_retries(
            request_dict, (None, parsed_response), operation)
        self.assertEqual(retry, 0)
        self.manager.delete_endpoints.assert_called_with(
            Operation='TestDiscoveryOptional', Identifiers={})

    def test_retries_421_status_code(self):
        self._assert_retries({'ResponseMetadata': {'HTTPStatusCode': 421}})

    def test_retries_invalid_endpoint_exception(self):
        self._assert_retries({'Error': {'Code': 'InvalidEndpointException'}})
Ejemplo n.º 25
0
 def serialize_to_request(self, input_params):
     # Build a fresh ServiceModel from self.model on each call and serialize
     # input_params for its 'TestOperation' operation.
     model = ServiceModel(self.model)
     serializer = serialize.create_serializer(model.metadata['protocol'])
     operation = model.operation_model('TestOperation')
     return serializer.serialize_to_request(input_params, operation)
Ejemplo n.º 26
0
def generate_service_api(output, service: ServiceModel, doc=True):
    """Write a skeleton API class for ``service`` to the text stream ``output``.

    The generated class is named after the service (``<service>_api`` in camel
    case) and holds one ``@handler``-decorated stub per operation whose body
    raises ``NotImplementedError``.

    :param output: writable text stream that receives the generated source.
    :param service: botocore service model to render.
    :param doc: when true, convert each operation's HTML documentation to
        reStructuredText with ``pypandoc`` and emit it as a docstring.
    """
    service_name = service.service_name.replace("-", "_")
    class_name = snake_to_camel_case(service_name + "_api")

    output.write(f"class {class_name}:\n")
    output.write("\n")
    output.write(f'    service = "{service.service_name}"\n')
    output.write(f'    version = "{service.api_version}"\n')
    for op_name in service.operation_names:
        operation: OperationModel = service.operation_model(op_name)

        fn_name = camel_to_snake_case(op_name)

        if operation.output_shape:
            output_shape = operation.output_shape.name
        else:
            output_shape = "None"

        output.write("\n")
        parameters = OrderedDict()
        param_shapes = OrderedDict()

        input_shape = operation.input_shape
        if input_shape is not None:
            # Required members come first (no default); the remaining,
            # optional members follow with a ``= None`` default.
            members = list(input_shape.members)
            for m in input_shape.required_members:
                members.remove(m)
                m_shape = input_shape.members[m]
                parameters[xform_name(m)] = m_shape.name
                param_shapes[xform_name(m)] = m_shape
            for m in members:
                m_shape = input_shape.members[m]
                param_shapes[xform_name(m)] = m_shape
                parameters[xform_name(m)] = f"{m_shape.name} = None"

        if any(map(is_bad_param_name, parameters.keys())):
            # if we cannot render the parameter name, don't expand the parameters in the handler
            param_list = f"request: {input_shape.name}" if input_shape else ""
            output.write(f'    @handler("{operation.name}", expand=False)\n')
        else:
            param_list = ", ".join([f"{k}: {v}" for k, v in parameters.items()])
            output.write(f'    @handler("{operation.name}")\n')

        output.write(
            f"    def {fn_name}(self, context: RequestContext, {param_list}) -> {output_shape}:\n"
        )

        # convert html documentation to rst and print it into to the signature
        if doc:
            import pypandoc

            # Keep the rendered text in its own variable: re-binding ``doc``
            # would replace the flag with this operation's docstring, so an
            # operation whose documentation renders to "" would silently turn
            # off docstrings for every operation generated after it.
            op_doc = pypandoc.convert_text(operation.documentation, "rst", format="html")
            output.write('        """')
            output.write(f"{op_doc.strip()}\n")
            output.write("\n")

            # parameters
            for param_name, shape in param_shapes.items():
                # FIXME: this doesn't work properly
                pdoc = pypandoc.convert_text(shape.documentation, "rst", format="html")
                pdoc = pdoc.strip().split(".")[0] + "."
                output.write(f":param {param_name}: {pdoc}\n")

            # return value
            if operation.output_shape:
                output.write(f":returns: {operation.output_shape.name}\n")

            # errors
            for error in operation.error_shapes:
                output.write(f":raises {error.name}:\n")

            output.write('        """\n')

        output.write("        raise NotImplementedError\n")
Ejemplo n.º 27
0
class TestTimestamps(unittest.TestCase):
    """Query-protocol serialization of timestamp members."""

    def setUp(self):
        input_shape = {
            'type': 'structure',
            'members': {'Timestamp': {'shape': 'TimestampType'}},
        }
        self.model = {
            'metadata': {'protocol': 'query', 'apiVersion': '2014-01-01'},
            'documentation': '',
            'operations': {
                'TestOperation': {
                    'name': 'TestOperation',
                    'http': {'method': 'POST', 'requestUri': '/'},
                    'input': {'shape': 'InputShape'},
                }
            },
            'shapes': {
                'InputShape': input_shape,
                'TimestampType': {'type': 'timestamp'},
            },
        }
        self.service_model = ServiceModel(self.model)

    def serialize_to_request(self, input_params):
        serializer = serialize.create_serializer(
            self.service_model.metadata['protocol'])
        operation_model = self.service_model.operation_model('TestOperation')
        return serializer.serialize_to_request(input_params, operation_model)

    def _assert_serialized_as(self, value, expected):
        # Serialize a single Timestamp value and compare the body field.
        request = self.serialize_to_request({'Timestamp': value})
        self.assertEqual(request['body']['Timestamp'], expected)

    def test_accepts_datetime_object(self):
        aware = datetime.datetime(2014, 1, 1, 12, 12, 12,
                                  tzinfo=dateutil.tz.tzutc())
        self._assert_serialized_as(aware, '2014-01-01T12:12:12Z')

    def test_accepts_naive_datetime_object(self):
        naive = datetime.datetime(2014, 1, 1, 12, 12, 12)
        self._assert_serialized_as(naive, '2014-01-01T12:12:12Z')

    def test_accepts_iso_8601_format(self):
        self._assert_serialized_as('2014-01-01T12:12:12Z',
                                   '2014-01-01T12:12:12Z')

    def test_accepts_timestamp_without_tz_info(self):
        # A string without timezone information is treated as UTC; older
        # botocore versions behaved this way and that behavior is preserved.
        self._assert_serialized_as('2014-01-01T12:12:12',
                                   '2014-01-01T12:12:12Z')

    def test_microsecond_timestamp_without_tz_info(self):
        self._assert_serialized_as('2014-01-01T12:12:12.123456',
                                   '2014-01-01T12:12:12.123456Z')
Ejemplo n.º 28
0
class TestInstanceCreation(unittest.TestCase):
    """Checks that ``create_serializer``'s second argument toggles validation.

    ``StringTestType`` declares ``min: 15``; a shorter value must serialize
    unchanged when validation is off and raise ``ParamValidationError`` when
    validation is on.
    """

    def setUp(self):
        self.model = {
            'metadata': {'protocol': 'query', 'apiVersion': '2014-01-01'},
            'documentation': '',
            'operations': {
                'TestOperation': {
                    'name': 'TestOperation',
                    'http': {
                        'method': 'POST',
                        'requestUri': '/',
                    },
                    'input': {'shape': 'InputShape'},
                }
            },
            'shapes': {
                'InputShape': {
                    'type': 'structure',
                    'members': {
                        'Timestamp': {'shape': 'StringTestType'},
                    }
                },
                'StringTestType': {
                    'type': 'string',
                    'min': 15
                }
            }
        }
        self.service_model = ServiceModel(self.model)

    def assert_serialize_valid_parameter(self, request_serializer):
        # 30 characters, so it satisfies the ``min: 15`` constraint.
        valid_string = 'valid_string_with_min_15_chars'
        request = request_serializer.serialize_to_request(
            {'Timestamp': valid_string},
            self.service_model.operation_model('TestOperation'))

        self.assertEqual(request['body']['Timestamp'], valid_string)

    def assert_serialize_invalid_parameter(self, request_serializer):
        # 12 characters, so it violates the ``min: 15`` constraint.
        invalid_string = 'short string'
        request = request_serializer.serialize_to_request(
            {'Timestamp': invalid_string},
            self.service_model.operation_model('TestOperation'))

        self.assertEqual(request['body']['Timestamp'], invalid_string)

    def test_instantiate_without_validation(self):
        request_serializer = serialize.create_serializer(
            self.service_model.metadata['protocol'], False)

        try:
            self.assert_serialize_valid_parameter(request_serializer)
        except ParamValidationError:
            self.fail("Shouldn't fail serializing valid parameter without validation")

        try:
            self.assert_serialize_invalid_parameter(request_serializer)
        except ParamValidationError:
            self.fail("Shouldn't fail serializing invalid parameter without validation")

    def test_instantiate_with_validation(self):
        request_serializer = serialize.create_serializer(
            self.service_model.metadata['protocol'], True)
        try:
            self.assert_serialize_valid_parameter(request_serializer)
        except ParamValidationError:
            self.fail("Shouldn't fail serializing valid parameter with validation")

        with self.assertRaises(ParamValidationError):
            self.assert_serialize_invalid_parameter(request_serializer)
Ejemplo n.º 29
0
 def serialize_to_request(self, input_params):
     # Serialize input_params for 'TestOperation' using a ServiceModel built
     # fresh from self.model and a serializer for its declared protocol.
     model = ServiceModel(self.model)
     protocol = model.metadata['protocol']
     serializer = serialize.create_serializer(protocol)
     return serializer.serialize_to_request(
         input_params, model.operation_model('TestOperation'))
 def test_ComputingSerializer_fix_describe_load_balancers_params_with_loadbalancer_name(self):
     # Each member in this fixture uses its own name as its locationName.
     def member(name, shape):
         return {"locationName": name, "shape": shape}

     operations = {
         "DescribeLoadBalancers": {
             "http": {"method": "POST", "requestUri": "/api/"},
             "input": {"shape": "DescribeLoadBalancersRequest"},
             "name": "DescribeLoadBalancers",
             "output": {"shape": "ComputingOperationResult"},
         }
     }
     shapes = {
         "DescribeLoadBalancersRequest": {
             "members": {
                 "LoadBalancerNames": member(
                     "LoadBalancerNames", "ListOfRequestLoadBalancerNames"),
                 "Patamator": member("Patamator", "String"),
             },
             "name": "DescribeLoadBalancersRequest",
             "type": "structure",
         },
         "ListOfRequestLoadBalancerNames": {
             "member": member("member", "RequestLoadBalancerNames"),
             "name": "ListOfRequestLoadBalancerNames",
             "type": "list",
         },
         "RequestLoadBalancerNames": {
             "members": {
                 "InstancePort": member("InstancePort", "Integer"),
                 "LoadBalancerName": member("LoadBalancerName", "String"),
                 "LoadBalancerPort": member("LoadBalancerPort", "Integer"),
             },
             "name": "RequestLoadBalancerNames",
             "type": "structure",
         },
         "ComputingOperationResult": {
             "members": {"Response": member("Response", "String")},
             "name": "ComputingOperationResult",
             "type": "structure",
         },
         "String": {"name": "String", "type": "string"},
         "Integer": {"name": "Integer", "type": "integer"},
     }
     computing_service_model = ServiceModel({
         "metadata": self.computing_model_metadata,
         "operations": operations,
         "shapes": shapes,
     })

     load_balancer = {
         "LoadBalancerName": "test_load_balancer_name",
         "LoadBalancerPort": "test_load_balancer_port",
         "InstancePort": "test_instance_port",
     }
     params = {"LoadBalancerNames": [load_balancer]}

     operation = computing_service_model.operation_model("DescribeLoadBalancers")
     res = serialize.ComputingSerializer().serialize_to_request(params, operation)

     # The serializer flattens the first list entry into these query keys.
     assert res["body"] == {
         "Action": "DescribeLoadBalancers",
         "Version": "3.0",
         "LoadBalancerNames.member.1": "test_load_balancer_name",
         "LoadBalancerNames.LoadBalancerPort.1": "test_load_balancer_port",
         "LoadBalancerNames.InstancePort.1": "test_instance_port"
     }
     assert res["headers"] == {"Content-Type": "application/x-www-form-urlencoded; charset=utf-8"}
     assert res["method"] == "POST"
     assert res["query_string"] == ""
     assert res["url_path"] == "/api/"
Ejemplo n.º 31
0
class TestTimestamps(unittest.TestCase):
    """Timestamp inputs are serialized as ISO 8601 strings with a Z suffix."""

    def setUp(self):
        operations = {
            'TestOperation': {
                'name': 'TestOperation',
                'http': {'method': 'POST', 'requestUri': '/'},
                'input': {'shape': 'InputShape'},
            }
        }
        shapes = {
            'InputShape': {
                'type': 'structure',
                'members': {'Timestamp': {'shape': 'TimestampType'}},
            },
            'TimestampType': {'type': 'timestamp'},
        }
        self.model = {
            'metadata': {'protocol': 'query', 'apiVersion': '2014-01-01'},
            'documentation': '',
            'operations': operations,
            'shapes': shapes,
        }
        self.service_model = ServiceModel(self.model)

    def serialize_to_request(self, input_params):
        protocol = self.service_model.metadata['protocol']
        serializer = serialize.create_serializer(protocol)
        return serializer.serialize_to_request(
            input_params, self.service_model.operation_model('TestOperation'))

    def _serialized_timestamp(self, value):
        # Serialized body value for a request carrying only Timestamp=value.
        return self.serialize_to_request({'Timestamp': value})['body']['Timestamp']

    def test_accepts_datetime_object(self):
        utc_value = datetime.datetime(
            2014, 1, 1, 12, 12, 12, tzinfo=dateutil.tz.tzutc())
        self.assertEqual(self._serialized_timestamp(utc_value),
                         '2014-01-01T12:12:12Z')

    def test_accepts_naive_datetime_object(self):
        naive_value = datetime.datetime(2014, 1, 1, 12, 12, 12)
        self.assertEqual(self._serialized_timestamp(naive_value),
                         '2014-01-01T12:12:12Z')

    def test_accepts_iso_8601_format(self):
        self.assertEqual(self._serialized_timestamp('2014-01-01T12:12:12Z'),
                         '2014-01-01T12:12:12Z')

    def test_accepts_timestamp_without_tz_info(self):
        # Strings with no timezone are assumed to be UTC, matching the
        # behavior of older botocore versions that must be preserved.
        self.assertEqual(self._serialized_timestamp('2014-01-01T12:12:12'),
                         '2014-01-01T12:12:12Z')

    def test_microsecond_timestamp_without_tz_info(self):
        self.assertEqual(
            self._serialized_timestamp('2014-01-01T12:12:12.123456'),
            '2014-01-01T12:12:12.123456Z')
Ejemplo n.º 32
0
class ShapeParser:
    """
    Parser for botocore shape files.

    Arguments:
        session -- Boto3 session.
        service_name -- ServiceName.
    """

    # Type map for shape types.
    SHAPE_TYPE_MAP: Mapping[str, FakeAnnotation] = {
        "integer":
        Type.int,
        "long":
        Type.int,
        "boolean":
        Type.bool,
        "double":
        Type.float,
        "float":
        Type.float,
        "timestamp":
        TypeSubscript(Type.Union, [Type.datetime, Type.str]),
        "blob":
        TypeSubscript(Type.Union,
                      [Type.bytes, Type.IOBytes,
                       TypeClass(StreamingBody)]),
        "blob_streaming":
        TypeSubscript(Type.Union,
                      [Type.bytes, Type.IOBytes,
                       TypeClass(StreamingBody)]),
    }

    OUTPUT_SHAPE_TYPE_MAP: Mapping[str, FakeAnnotation] = {
        "timestamp": Type.datetime,
        "blob": Type.bytes,
        "blob_streaming": TypeClass(StreamingBody),
    }

    # Argument aliases mirroring the fixes botocore applies for its docs build:
    # https://github.com/boto/botocore/blob/develop/botocore/handlers.py#L773
    # https://github.com/boto/botocore/blob/develop/botocore/handlers.py#L1055
    # Layout: service -> operation name ("*" = any operation without its own
    # entry) -> original argument name -> alias. An alias of "None" makes
    # _parse_arguments drop the argument entirely.
    ARGUMENT_ALIASES: dict[str, dict[str, dict[str, str]]] = {
        ServiceNameCatalog.cloudsearchdomain.boto3_name: {
            "Search": {"return": "returnFields"},
        },
        ServiceNameCatalog.logs.boto3_name: {
            "CreateExportTask": {"from": "fromTime"},
        },
        ServiceNameCatalog.ec2.boto3_name: {
            "*": {"Filter": "Filters"},
        },
        # Every listed S3 operation drops its ContentMD5 argument.
        ServiceNameCatalog.s3.boto3_name: {
            s3_operation: {"ContentMD5": "None"}
            for s3_operation in (
                "PutBucketAcl",
                "PutBucketCors",
                "PutBucketLifecycle",
                "PutBucketLogging",
                "PutBucketNotification",
                "PutBucketPolicy",
                "PutBucketReplication",
                "PutBucketRequestPayment",
                "PutBucketTagging",
                "PutBucketVersioning",
                "PutBucketWebsite",
                "PutObjectAcl",
            )
        },
    }

    def __init__(self, session: Session, service_name: ServiceName):
        boto3_name = service_name.boto3_name
        loader = session._loader
        botocore_session: BotocoreSession = session._session
        service_data = botocore_session.get_service_data(boto3_name)
        self.service_name = service_name
        self.service_model = ServiceModel(service_data, boto3_name)
        self._typed_dict_map: dict[str, TypeTypedDict] = {}
        # Waiter, paginator and resource definitions are optional; each one
        # is None when the loader does not know it for this service.
        self._waiters_shape: Mapping[str, Any] | None = self._load_optional_model(
            loader, boto3_name, "waiters-2")
        self._paginators_shape: Mapping[str, Any] | None = self._load_optional_model(
            loader, boto3_name, "paginators-1")
        self._resources_shape: Mapping[str, Any] | None = self._load_optional_model(
            loader, boto3_name, "resources-1")

        self.logger = get_logger()
        self.response_metadata_typed_dict = TypeTypedDict(
            "ResponseMetadataTypeDef",
            [
                TypedDictAttribute(attr_name, attr_type, True)
                for attr_name, attr_type in (
                    ("RequestId", Type.str),
                    ("HostId", Type.str),
                    ("HTTPStatusCode", Type.int),
                    ("HTTPHeaders", Type.DictStrStr),
                    ("RetryAttempts", Type.int),
                )
            ],
        )
        self.proxy_operation_model = OperationModel({}, self.service_model)

    @staticmethod
    def _load_optional_model(loader: Any, service_name: str,
                             type_name: str) -> Mapping[str, Any] | None:
        """
        Load a botocore model file, or return None when it is unavailable.

        Arguments:
            loader -- botocore loader taken from the boto3 session.
            service_name -- boto3 service name.
            type_name -- Model type, e.g. "waiters-2" or "paginators-1".

        Returns:
            The loaded model, or None if the loader raised UnknownServiceError.
        """
        try:
            return loader.load_service_model(service_name, type_name)
        except UnknownServiceError:
            return None

    def _get_operation(self, name: str) -> OperationModel:
        # Look up the botocore operation model called ``name``.
        model = self.service_model
        return model.operation_model(name)

    def _get_operation_names(self) -> list[str]:
        # Copy of the service's operation names, in the model's order.
        return [*self.service_model.operation_names]

    def _get_paginator(self, name: str) -> dict[str, Any]:
        # Paginator definition ``name``; ShapeParserError when the service has
        # no paginators file or it lacks that entry.
        paginators = self._paginators_shape
        if paginators:
            try:
                return paginators["pagination"][name]
            except KeyError as err:
                raise ShapeParserError(f"Unknown paginator: {name}") from err
        raise ShapeParserError(f"Unknown paginator: {name}")

    def _get_service_resource(self) -> dict[str, Any]:
        # The "service" section of the resources definition.
        resources = self._resources_shape
        if resources:
            return resources["service"]
        raise ShapeParserError("Resource shape not found")

    def _get_resource_shape(self, name: str) -> dict[str, Any]:
        # Definition of resource ``name``; ShapeParserError if missing.
        resources = self._resources_shape
        if resources:
            try:
                return resources["resources"][name]
            except KeyError as err:
                raise ShapeParserError(f"Unknown resource: {name}") from err
        raise ShapeParserError("Resource shape not found")

    def get_paginator_names(self) -> list[str]:
        """
        Get available paginator names.

        Returns:
            A sorted list of paginator names; empty when the service has no
            paginators definition.
        """
        if not self._paginators_shape:
            return []
        # Iterating the "pagination" mapping yields its keys (paginator names).
        return sorted(self._paginators_shape.get("pagination", []))

    def _get_argument_alias(self, operation_name: str,
                            argument_name: str) -> str:
        # Resolve the alias for ``argument_name`` via ARGUMENT_ALIASES. An
        # operation-specific map takes precedence over the "*" map; the name
        # is returned unchanged when no alias applies.
        service_map = self.ARGUMENT_ALIASES.get(self.service_name.boto3_name, {})
        operation_map: dict[str, str] = service_map.get(
            operation_name, service_map.get("*", {}))
        return operation_map.get(argument_name, argument_name)

    def _parse_arguments(
        self,
        class_name: str,
        method_name: str,
        operation_name: str,
        shape: StructureShape,
        exclude_names: Iterable[str] = tuple(),
        optional_only: bool = False,
    ) -> list[Argument]:
        """
        Build method arguments from the members of a structure shape.

        Arguments:
            class_name -- Generated class name, used for type stub lookup.
            method_name -- Generated method name, used for type stub lookup.
            operation_name -- botocore operation name, used for alias lookup.
            shape -- Structure shape whose members become arguments.
            exclude_names -- Member names to skip (any iterable).
            optional_only -- Skip arguments that are required.

        Returns:
            Arguments with required ones first; each group keeps member order.
        """
        # Materialize once: a one-shot iterator (e.g. a generator) passed as
        # ``exclude_names`` would otherwise be consumed by the first ``in``
        # test, letting later excluded names slip through.
        excluded = frozenset(exclude_names)
        required = set(shape.required_members)
        result: list[Argument] = []
        for argument_name, argument_shape in shape.members.items():
            if argument_name in excluded:
                continue
            argument_alias = self._get_argument_alias(operation_name,
                                                      argument_name)
            # "None" alias marks an argument that must not be exposed.
            if argument_alias == "None":
                continue

            argument_type_stub = get_method_type_stub(self.service_name,
                                                      class_name, method_name,
                                                      argument_name)
            if argument_type_stub is Type.RemoveArgument:
                continue
            if argument_type_stub is not None:
                argument_type = argument_type_stub
            else:
                argument_type = self.parse_shape(argument_shape)
            argument = Argument(argument_alias, argument_type)
            if argument_name not in required:
                argument.default = Type.Ellipsis
            if optional_only and argument.required:
                continue

            # FIXME: https://github.com/boto/boto3/issues/2813
            # if not argument.required and argument.type_annotation:
            #     argument.type_annotation = Type.get_optional(argument.type_annotation)

            result.append(argument)

        # Stable sort: required arguments first, original order within groups.
        result.sort(key=lambda x: not x.required)
        return result

    def _parse_return_type(self, class_name: str, method_name: str,
                           shape: Shape | None) -> FakeAnnotation:
        """
        Resolve the return annotation of a generated method.

        A registered ``"return"`` type stub wins. Otherwise the output shape is
        parsed in strict output mode, and a missing shape yields ``None``.
        """
        stub = get_method_type_stub(self.service_name, class_name,
                                    method_name, "return")
        if stub is not None:
            return stub
        return self.parse_shape(shape, output=True) if shape else Type.none

    @staticmethod
    def _get_kw_flags(method_name: str,
                      arguments: Sequence[Argument]) -> list[Argument]:
        """
        Return the keyword-only marker to put before ``arguments``.

        The marker is produced only when there are arguments and the method
        name does not start with an uppercase letter; otherwise it returns ``[]``.
        """
        if not arguments or method_name[0].isupper():
            return []
        return [Argument.kwflag()]

    def get_client_method_map(self) -> dict[str, Method]:
        """
        Get client methods from shape.

        The map always contains ``can_paginate`` and ``generate_presigned_url``,
        plus one method per service operation (named via ``xform_name``).

        Returns:
            A map of method name to Method.
        """
        can_paginate = Method(
            "can_paginate",
            [Argument("self", None), Argument("operation_name", Type.str)],
            Type.bool,
        )
        generate_presigned_url = Method(
            "generate_presigned_url",
            [
                Argument("self", None),
                Argument("ClientMethod", Type.str),
                Argument("Params", Type.MappingStrAny, Type.Ellipsis),
                Argument("ExpiresIn", Type.int, TypeConstant(3600)),
                Argument("HttpMethod", Type.str, Type.Ellipsis),
            ],
            Type.str,
        )
        result: dict[str, Method] = {
            "can_paginate": can_paginate,
            "generate_presigned_url": generate_presigned_url,
        }

        for operation_name in self._get_operation_names():
            operation_model = self._get_operation(operation_name)
            method_name = xform_name(operation_name)
            input_shape = operation_model.input_shape
            method_arguments: list[Argument] = [Argument("self", None)]

            if input_shape is not None:
                parsed_arguments = self._parse_arguments(
                    "Client", method_name, operation_name, input_shape)
                method_arguments += self._get_kw_flags(method_name,
                                                       parsed_arguments)
                method_arguments += parsed_arguments

            method = Method(
                name=method_name,
                arguments=method_arguments,
                return_type=self._parse_return_type(
                    "Client", method_name, operation_model.output_shape),
            )
            if input_shape:
                request_name = self._get_typed_dict_name(input_shape,
                                                         postfix="Request")
                method.request_type_annotation = (
                    method.get_request_type_annotation(request_name))
            result[method.name] = method
        return result

    @staticmethod
    def _get_typed_dict_name(shape: Shape, postfix: str = "") -> str:
        """Build the TypedDict name ``<shape name><postfix>TypeDef``."""
        return "{}{}TypeDef".format(shape.name, postfix)

    def _parse_shape_string(self, shape: StringShape) -> FakeAnnotation:
        """
        Map a string shape to ``str`` or, for enums, to a ``Literal`` type.

        A registered literal stub named ``<shape name>Type`` takes precedence
        over the generated literal.
        """
        if shape.enum:
            literal_name = f"{shape.name}Type"
            stub = get_literal_type_stub(self.service_name, literal_name)
            return stub if stub else TypeLiteral(literal_name, list(shape.enum))
        return Type.str

    def _parse_shape_map(
        self,
        shape: MapShape,
        output_child: bool = False,
        is_streaming: bool = False,
    ) -> FakeAnnotation:
        """
        Convert a map shape into ``Dict[K, V]`` (output) or ``Mapping[K, V]``.

        A missing key shape falls back to ``str`` and a missing value shape
        to ``Any``.
        """
        container = Type.Dict if output_child else Type.Mapping
        type_subscript = TypeSubscript(container)
        for attr_name, fallback in (("key", Type.str), ("value", Type.Any)):
            child_shape = getattr(shape, attr_name)
            if child_shape:
                child = self.parse_shape(child_shape,
                                         output_child=output_child,
                                         is_streaming=is_streaming)
            else:
                child = fallback
            type_subscript.add_child(child)
        return type_subscript

    def _parse_shape_structure(
        self,
        shape: StructureShape,
        output: bool = False,
        output_child: bool = False,
        is_streaming: bool = False,
    ) -> FakeAnnotation:
        """
        Convert a structure shape into a TypedDict annotation.

        Arguments:
            shape -- Botocore structure shape.
            output -- Whether the shape is a top-level output; such TypedDicts
                get every key required plus a ``ResponseMetadata`` key.
            output_child -- Whether a parent shape is an output.
            is_streaming -- Passed through when parsing member shapes.

        Returns:
            A generic dict annotation for memberless structures, a registered
            shape stub when one exists, otherwise a TypeTypedDict (reused from
            ``self._typed_dict_map`` when already built for the same role).
        """
        # Memberless structures have no TypedDict; use a plain dict/mapping.
        if not shape.members.items():
            return Type.DictStrAny if output_child else Type.MappingStrAny

        required = shape.required_members
        typed_dict_name = self._get_typed_dict_name(shape)
        shape_type_stub = get_shape_type_stub(self.service_name,
                                              typed_dict_name)
        if shape_type_stub:
            return shape_type_stub
        typed_dict = TypeTypedDict(typed_dict_name)

        if typed_dict.name in self._typed_dict_map:
            # Reuse the cached TypedDict when it was built for the same role.
            # The role is judged by whether it already has a ResponseMetadata
            # key, which is added only for outputs.
            old_typed_dict = self._typed_dict_map[typed_dict.name]
            child_names = {i.name for i in old_typed_dict.children}
            if output and "ResponseMetadata" in child_names:
                return self._typed_dict_map[typed_dict.name]
            if not output and "ResponseMetadata" not in child_names:
                return self._typed_dict_map[typed_dict.name]

            # Role conflict: the shape is used both as input and as output.
            # The output variant ends up under "<Shape>ResponseMetadataTypeDef"
            # and the input variant under the plain name.
            if output:
                typed_dict.name = self._get_typed_dict_name(
                    shape, postfix="ResponseMetadata")
                self.logger.debug(
                    f"Marking {typed_dict.name} as ResponseMetadataTypeDef")
            else:
                # The cached one is the output variant: re-register it under
                # the ResponseMetadata name; the new input TypedDict replaces
                # the plain-name entry below.
                old_typed_dict.name = self._get_typed_dict_name(
                    shape, postfix="ResponseMetadata")
                self._typed_dict_map[old_typed_dict.name] = old_typed_dict
                self.logger.debug(
                    f"Marking {old_typed_dict.name} as ResponseMetadataTypeDef"
                )

        # Register before parsing members, so a recursive reference to this
        # shape finds it in the map instead of rebuilding it.
        self._typed_dict_map[typed_dict.name] = typed_dict
        for attr_name, attr_shape in shape.members.items():
            typed_dict.add_attribute(
                attr_name,
                self.parse_shape(
                    attr_shape,
                    output_child=output or output_child,
                    is_streaming=is_streaming,
                ),
                attr_name in required,
            )
        if output:
            self._make_output_typed_dict(typed_dict)
        return typed_dict

    def _make_output_typed_dict(self, typed_dict: TypeTypedDict) -> None:
        """
        Turn ``typed_dict`` into an output TypedDict, in place.

        Every existing key becomes required. A required ``ResponseMetadata``
        key is appended unless one is already present.
        """
        has_response_metadata = False
        for attribute in typed_dict.children:
            attribute.required = True
            if attribute.name == "ResponseMetadata":
                has_response_metadata = True
        if not has_response_metadata:
            typed_dict.add_attribute(
                "ResponseMetadata",
                self.response_metadata_typed_dict,
                True,
            )

    def _parse_shape_list(self,
                          shape: ListShape,
                          output_child: bool = False) -> FakeAnnotation:
        """
        Convert a list shape into ``List[T]`` (output) or ``Sequence[T]``.

        ``T`` is ``Any`` when the list has no member shape.
        """
        container = Type.List if output_child else Type.Sequence
        type_subscript = TypeSubscript(container)
        if shape.member:
            item_type = self.parse_shape(shape.member, output_child=output_child)
        else:
            item_type = Type.Any
        type_subscript.add_child(item_type)
        return type_subscript

    def parse_shape(
        self,
        shape: Shape,
        output: bool = False,
        output_child: bool = False,
        is_streaming: bool = False,
    ) -> FakeAnnotation:
        """
        Parse any botocore shape to TypeAnnotation.

        Resolution order:
        1. ``OUTPUT_SHAPE_TYPE_MAP`` when ``output`` or ``output_child``.
        2. ``SHAPE_TYPE_MAP``.
        3. String, map, structure and list handlers.
        4. Resource names.
        Anything else is logged and annotated as ``Any``.

        Arguments:
            shape -- Botocore shape.
            output -- Whether shape should use strict output types.
            output_child -- Whether shape parent is marked as output.
            is_streaming -- Whether shape should be streaming.

        Returns:
            TypeAnnotation or similar class.
        """
        # Unless the caller already forced streaming, derive it from the shape.
        if not is_streaming:
            is_streaming = "streaming" in shape.serialization and shape.serialization[
                "streaming"]
            # For output shapes the serialization flag is overridden by the
            # result of botocore's private ``_get_streaming_body`` helper on the
            # proxy operation model.
            if output or output_child:
                is_streaming = self.proxy_operation_model._get_streaming_body(
                    shape) is not None  # type: ignore

        # Streaming blobs are looked up under a separate pseudo type name.
        type_name = shape.type_name
        if is_streaming and type_name == "blob":
            type_name = "blob_streaming"

        # Output-specific mappings take precedence over the generic ones.
        if output or output_child:
            if type_name in self.OUTPUT_SHAPE_TYPE_MAP:
                return self.OUTPUT_SHAPE_TYPE_MAP[type_name]

        if type_name in self.SHAPE_TYPE_MAP:
            return self.SHAPE_TYPE_MAP[type_name]

        if isinstance(shape, StringShape):
            return self._parse_shape_string(shape)

        # Container handlers receive ``output or output_child`` as their
        # ``output_child`` so the output flag propagates to nested shapes.
        if isinstance(shape, MapShape):
            return self._parse_shape_map(
                shape,
                output_child=output or output_child,
                is_streaming=is_streaming,
            )

        if isinstance(shape, StructureShape):
            return self._parse_shape_structure(
                shape,
                output=output,
                output_child=output or output_child,
                is_streaming=is_streaming,
            )

        if isinstance(shape, ListShape):
            return self._parse_shape_list(shape,
                                          output_child=output or output_child)

        # Shapes whose type name is a resource name become an internal import
        # alias of that resource.
        # NOTE(review): assumes a loaded resources shape always has a
        # "resources" key; if it were missing this raises KeyError.
        if self._resources_shape and shape.type_name in self._resources_shape[
                "resources"]:
            return AliasInternalImport(shape.type_name)

        # Fallback for shape types none of the handlers above recognise.
        self.logger.warning(f"Unknown shape: {shape} {type_name}")
        return Type.Any

    def get_paginate_method(self, paginator_name: str) -> Method:
        """
        Get Paginator `paginate` method.

        The paginator's input token(s) and ``limit_key`` are removed from the
        operation arguments, and an optional ``PaginationConfig`` is appended.

        Arguments:
            paginator_name -- Paginator name (also the operation name).

        Returns:
            Method returning ``_PageIterator[<output type>]``, or ``None`` when
            the operation has no output shape.
        """
        operation_name = paginator_name
        paginator_shape = self._get_paginator(paginator_name)
        operation_shape = self._get_operation(operation_name)

        input_token = paginator_shape["input_token"]
        skip_argument_names: list[str] = (list(input_token) if isinstance(
            input_token, list) else [input_token])
        if "limit_key" in paginator_shape:
            skip_argument_names.append(paginator_shape["limit_key"])

        arguments: list[Argument] = [Argument("self", None)]
        input_shape = operation_shape.input_shape
        if input_shape is not None:
            shape_arguments = self._parse_arguments(
                "Paginator",
                "paginate",
                operation_name,
                input_shape,
                exclude_names=skip_argument_names,
            )
            shape_arguments.append(
                Argument("PaginationConfig", paginator_config_type,
                         Type.Ellipsis))
            arguments += self._get_kw_flags("paginate", shape_arguments)
            arguments += shape_arguments

        output_shape = operation_shape.output_shape
        if output_shape is None:
            return Method("paginate", arguments, Type.none)

        page_iterator_import = InternalImport("_PageIterator", stringify=False)
        page_item = self._parse_return_type("Paginator", "paginate",
                                            output_shape)
        return Method("paginate", arguments,
                      TypeSubscript(page_iterator_import, [page_item]))

    def get_wait_method(self, waiter_name: str) -> Method:
        """
        Get Waiter `wait` method.

        Arguments:
            waiter_name -- Waiter name.

        Returns:
            Method taking the waiter operation's arguments plus an optional
            ``WaiterConfig``, returning ``None``.

        Raises:
            ShapeParserError -- If no waiters shape is loaded. An unknown
                ``waiter_name`` surfaces as the lookup's own KeyError.
        """
        waiters_shape = self._waiters_shape
        if not waiters_shape:
            raise ShapeParserError("Waiter not found")
        waiter_shape = waiters_shape["waiters"][waiter_name]
        operation_name = waiter_shape["operation"]
        operation_shape = self._get_operation(operation_name)

        arguments: list[Argument] = [Argument("self", None)]
        input_shape = operation_shape.input_shape
        if input_shape is not None:
            shape_arguments = self._parse_arguments("Waiter", "wait",
                                                    operation_name, input_shape)
            shape_arguments.append(
                Argument("WaiterConfig", waiter_config_type, Type.Ellipsis))
            arguments += self._get_kw_flags("wait", shape_arguments)
            arguments += shape_arguments

        return Method(name="wait", arguments=arguments, return_type=Type.none)

    def get_service_resource_method_map(self) -> dict[str, Method]:
        """
        Get methods for ServiceResource.

        Always contains ``get_available_subresources``, plus one method per
        action of the service resource.

        Returns:
            A map of method name to Method.
        """
        subresources_method = Method(
            "get_available_subresources",
            [Argument("self", None)],
            TypeSubscript(Type.Sequence, [Type.str]),
        )
        result: dict[str, Method] = {
            "get_available_subresources": subresources_method,
        }

        actions = self._get_service_resource().get("actions", {})
        for action_name, action_shape in actions.items():
            action_method = self._get_resource_method("ServiceResource",
                                                      action_name, action_shape)
            result[action_method.name] = action_method
        return result

    def get_resource_method_map(self, resource_name: str) -> dict[str, Method]:
        """
        Get methods for Resource.

        Always contains ``get_available_subresources``, ``load`` and
        ``reload``. Each action adds a method, and each waiter adds a
        ``wait_until_<name>`` method.

        Arguments:
            resource_name -- Resource name.

        Returns:
            A map of method name to Method.
        """
        resource_shape = self._get_resource_shape(resource_name)
        result: dict[str, Method] = {
            "get_available_subresources":
            Method(
                "get_available_subresources",
                [Argument("self", None)],
                TypeSubscript(Type.Sequence, [Type.str]),
            ),
        }
        for simple_name in ("load", "reload"):
            result[simple_name] = Method(simple_name, [Argument("self", None)],
                                         Type.none)

        actions = resource_shape.get("actions", {})
        for action_name, action_shape in actions.items():
            action_method = self._get_resource_method(resource_name,
                                                      action_name, action_shape)
            result[action_method.name] = action_method

        for waiter_name in resource_shape.get("waiters", {}):
            wait_method = Method(
                f"wait_until_{xform_name(waiter_name)}",
                [Argument("self", None)],
                Type.none,
            )
            result[wait_method.name] = wait_method

        return result

    @staticmethod
    def _get_arg_from_target(target: str) -> str:
        """
        Strip an index suffix from a request param target.

        Everything from the first ``[`` on is dropped, e.g. ``"Names[0]"``
        becomes ``"Names"``; targets without ``[`` are returned unchanged.
        """
        head, _separator, _rest = target.partition("[")
        return head

    def _get_resource_method(self, resource_name: str, action_name: str,
                             action_shape: dict[str, Any]) -> Method:
        """
        Build the method for one resource (or ServiceResource) action.

        Arguments:
            resource_name -- Owner class name; also the postfix of the request
                TypedDict name.
            action_name -- Action name from the resource shape.
            action_shape -- Action definition from the resource shape.

        Returns:
            A Method whose return type is chosen as follows:
            - the action's resource type, wrapped in ``List`` when the
              resource ``path`` ends with ``[]``;
            - otherwise the parsed operation output;
            - otherwise ``None``.
            Arguments bound from identifiers are not exposed.
        """
        method_name = xform_name(action_name)
        arguments: list[Argument] = [Argument("self", None)]
        return_type: FakeAnnotation = Type.none

        if "resource" in action_shape:
            resource = action_shape["resource"]
            return_type = self._parse_return_type(
                resource_name, method_name, Shape("resource", resource))
            if resource.get("path", "").endswith("[]"):
                return_type = TypeSubscript(Type.List, [return_type])

        operation_shape = None
        if "request" in action_shape:
            request = action_shape["request"]
            operation_name = request["operation"]
            operation_shape = self._get_operation(operation_name)
            # Params filled from the resource's identifiers are not method args.
            skip_argument_names = {
                self._get_arg_from_target(param["target"])
                for param in request.get("params", {})
                if param["source"] == "identifier"
            }
            input_shape = operation_shape.input_shape
            if input_shape is not None:
                shape_arguments = self._parse_arguments(
                    resource_name,
                    method_name,
                    operation_name,
                    input_shape,
                    exclude_names=skip_argument_names,
                )
                arguments += self._get_kw_flags(method_name, shape_arguments)
                arguments += shape_arguments
            output_shape = operation_shape.output_shape
            if output_shape is not None and return_type is Type.none:
                return_type = self.parse_shape(output_shape, output=True)

        method = Method(name=method_name,
                        arguments=arguments,
                        return_type=return_type)
        if operation_shape and operation_shape.input_shape is not None:
            request_name = self._get_typed_dict_name(
                operation_shape.input_shape, postfix=resource_name)
            method.request_type_annotation = method.get_request_type_annotation(
                request_name)
        return method

    def get_collection_filter_method(self, name: str, collection: Collection,
                                     self_type: FakeAnnotation) -> Method:
        """
        Get `filter` method for Resource collection.

        Only the optional arguments of the collection's request operation are
        exposed. A collection without a request gets a bare ``filter(self)``.

        Arguments:
            name -- Collection record name.
            collection -- Boto3 Collection.
            self_type -- Collection self type annotation, used as the return type.

        Returns:
            Filter Method record.
        """
        result = Method(
            name="filter",
            arguments=[Argument("self", None)],
            return_type=self_type,
        )
        if not collection.request:
            return result

        operation_name = collection.request.operation
        operation_model = self._get_operation(operation_name)
        if operation_model.input_shape is None:
            return result

        shape_arguments = self._parse_arguments(
            name,
            result.name,
            operation_name,
            operation_model.input_shape,
            optional_only=True,
        )
        result.arguments.extend(self._get_kw_flags(result.name, shape_arguments))
        result.arguments.extend(shape_arguments)
        return result

    def get_collection_batch_methods(self, name: str,
                                     collection: Collection) -> list[Method]:
        """
        Get batch operations for Resource collection.

        Each batch action becomes a method that takes the optional arguments
        of its request operation. The method returns a list of the parsed
        operation output, or ``None`` when the action has no request or the
        operation has no output shape.

        Arguments:
            name -- Collection record name.
            collection -- Boto3 Collection.

        Returns:
            List of Method records.
        """
        result: list[Method] = []
        for batch_action in collection.batch_actions:
            method = Method(
                name=batch_action.name,
                arguments=[Argument("self", None)],
                return_type=Type.none,
            )
            result.append(method)
            if not batch_action.request:
                continue

            operation_name = batch_action.request.operation
            operation_model = self._get_operation(operation_name)
            if operation_model.input_shape is not None:
                shape_arguments = self._parse_arguments(
                    name,
                    batch_action.name,
                    operation_name,
                    operation_model.input_shape,
                    optional_only=True,
                )
                method.arguments.extend(
                    self._get_kw_flags(batch_action.name, shape_arguments))
                method.arguments.extend(shape_arguments)
            if operation_model.output_shape is not None:
                item_return_type = self.parse_shape(
                    operation_model.output_shape, output=True)
                method.return_type = TypeSubscript(Type.List,
                                                   [item_return_type])

        return result
Ejemplo n.º 33
0
def handle_method(fragment):
    """
    Expand an ``AWS::ApiGateway::Method`` fragment into an AWS service integration.

    Steps:
    - Pop ``Service``, ``Action`` and ``ResponseMaps`` from
      ``Properties.Integration``.
    - Serialize the action with botocore for the service's protocol.
    - Build a fresh integration from ``integration_template()``.
    - Copy the existing integration keys into it.
    - Write the request headers, query string, body, URI and HTTP method
      into it.

    Arguments:
        fragment -- CloudFormation resource fragment; mutated in place.

    Returns:
        The updated fragment.

    Raises:
        InvalidTypeException -- If the fragment is not an API Gateway method.
    """
    if fragment["Type"] != "AWS::ApiGateway::Method":
        response_string = "Macro only supports \"AWS::ApiGateway::Method\", user supplied {}"
        raise InvalidTypeException(response_string.format(fragment["Type"]))

    integration = fragment["Properties"]["Integration"]
    service_name = integration.pop("Service").lower()
    action = integration.pop("Action")
    # ResponseMaps is macro input only. Pop it so it is not copied into the
    # generated integration; its value is not used here.
    integration.pop("ResponseMaps")
    # Fn::Transform may or may not be present on the fragment.
    fragment.pop("Fn::Transform", None)

    loader = Loader()
    service_description = loader.load_service_model(service_name=service_name, type_name='service-2')
    service_model = ServiceModel(service_description)
    protocol = service_model.protocol
    op_model = service_model.operation_model(action["Name"])

    # Parameters is presumably a list of mappings; with ChainMap the first
    # mapping containing a key wins.
    request_parameters = action.get("Parameters", {})
    params = dict(ChainMap(*request_parameters))
    print("params: {}".format(params))
    serializer = create_serializer(protocol)

    print(service_model.protocol)
    request = serializer.serialize_to_request(params, op_model)
    request_object = AWSRequest(
        method=request['method'],
        url=get_endpoint(service_model.service_name),
        data=request['body'],
        headers=request['headers'])

    prepared = request_object.prepare()

    print("Raw request: {}".format(request))
    print("Prepared request: {}".format(prepared))

    new_integration = integration_template()
    # Copy the existing values to the new template
    new_integration.update(integration)

    # Add headers (except Content-Length) to cfn template as static values
    headers = prepared.headers
    if headers is not None and callable(getattr(headers, "keys", None)):
        for header in headers.keys():
            if header.lower() != "content-length":
                param_key = "integration.request.header.{}".format(header)
                new_integration["RequestParameters"][param_key] = "'{}'".format(headers[header])

    # Add Query Strings to cfn template
    query_string = request.get('query_string')
    if callable(getattr(query_string, "keys", None)):
        for query in query_string.keys():
            param_key = "integration.request.querystring.{}".format(query)
            new_integration["RequestParameters"][param_key] = "'{}'".format(query_string[query])

    # Set the body
    body = prepared.body
    if isinstance(body, str):
        request_template = body
    else:
        request_template = str(body, "utf-8") if body else ''
    new_integration["RequestTemplates"]["application/json"] = request_template

    # AWS service integration URIs look like
    # arn:aws:apigateway:{region}:{service}:path/{service_api}. The serialized
    # url_path is typically rooted at "/", so strip it to avoid "path//...".
    new_integration["Uri"] = ":".join([
        "arn",
        "aws",
        "apigateway",
        REGION,
        service_model.endpoint_prefix,
        "path/" + request["url_path"].lstrip("/"),
    ])
    new_integration["IntegrationHttpMethod"] = prepared.method

    fragment["Properties"]["Integration"] = new_integration
    print(fragment)
    return fragment
Ejemplo n.º 34
0
class TestEndpointDiscoveryManager(BaseEndpointDiscoveryTest):
    def setUp(self):
        """Run the base fixture setup, then build a default manager."""
        super().setUp()
        self.construct_manager()

    def construct_manager(self, cache=None, time=None, side_effect=None):
        """
        (Re)build ``self.manager`` around a mocked client.

        Arguments:
            cache -- Endpoint cache handed to the manager.
            time -- Clock callable passed as ``current_time``.
            side_effect -- Side effect for ``client.describe_endpoints``.
                Defaults to one response with a single ``new.com`` endpoint
                whose ``CachePeriodInMinutes`` is 2.
        """
        if side_effect is None:
            default_response = {
                'Endpoints': [
                    {'Address': 'new.com', 'CachePeriodInMinutes': 2},
                ],
            }
            side_effect = [default_response]

        self.service_model = ServiceModel(self.service_description)
        self.meta = Mock(spec=ClientMeta)
        self.meta.service_model = self.service_model

        self.client = Mock()
        self.client.describe_endpoints.side_effect = side_effect
        self.client.meta = self.meta

        self.manager = EndpointDiscoveryManager(
            self.client, cache=cache, current_time=time
        )

    def test_injects_api_version_if_endpoint_operation(self):
        """DescribeEndpoints requests get the x-amz-api-version header."""
        operation = self.service_model.operation_model('DescribeEndpoints')
        request_dict = {'headers': {}}
        inject_api_version_header_if_needed(operation, request_dict)
        api_version = request_dict['headers'].get('x-amz-api-version')
        self.assertEqual(api_version, '2018-08-31')

    def test_no_inject_api_version_if_not_endpoint_operation(self):
        """Other operations do not get the x-amz-api-version header."""
        operation = self.service_model.operation_model('TestDiscoveryRequired')
        request_dict = {'headers': {}}
        inject_api_version_header_if_needed(operation, request_dict)
        self.assertNotIn('x-amz-api-version', request_dict['headers'])

    def test_gather_identifiers(self):
        """Top-level and nested identifier values are both collected."""
        operation = self.service_model.operation_model('TestDiscoveryRequired')
        request_params = {'Foo': 'value1', 'Nested': {'Bar': 'value2'}}
        identifiers = self.manager.gather_identifiers(operation, request_params)
        self.assertEqual(identifiers, {'Foo': 'value1', 'Bar': 'value2'})

    def test_gather_identifiers_none(self):
        """Empty params for TestDiscovery yield no identifiers."""
        operation = self.service_model.operation_model('TestDiscovery')
        self.assertEqual(self.manager.gather_identifiers(operation, {}), {})

    def test_describe_endpoint(self):
        """Operation and Identifiers are forwarded to DescribeEndpoints."""
        identifiers = {'Foo': 'value1', 'Bar': 'value2'}
        self.manager.describe_endpoint(Operation='FooBar', Identifiers=identifiers)
        self.client.describe_endpoints.assert_called_with(
            Operation='FooBar', Identifiers=identifiers,
        )

    def test_describe_endpoint_no_input(self):
        """Without an input shape, DescribeEndpoints is called with no args."""
        operations = self.service_description['operations']
        del operations['DescribeEndpoints']['input']
        self.construct_manager()
        self.manager.describe_endpoint(Operation='FooBar', Identifiers={})
        self.client.describe_endpoints.assert_called_with()

    def test_describe_endpoint_empty_input(self):
        """With an empty input structure, DescribeEndpoints gets no args."""
        operations = self.service_description['operations']
        operations['DescribeEndpoints']['input'] = {'shape': 'EmptyStruct'}
        self.construct_manager()
        self.manager.describe_endpoint(Operation='FooBar', Identifiers={})
        self.client.describe_endpoints.assert_called_with()

    def test_describe_endpoint_ids_and_operation(self):
        """Results are cached per (identifiers, operation) and then reused."""
        cache = {}
        self.construct_manager(cache=cache)
        request = {
            'Operation': 'TestDiscoveryRequired',
            'Identifiers': {'Foo': 'value1', 'Bar': 'value2'},
        }
        describe = self.client.describe_endpoints

        self.manager.describe_endpoint(**request)
        describe.assert_called_with(**request)
        expected_key = (
            (('Bar', 'value2'), ('Foo', 'value1')),
            'TestDiscoveryRequired',
        )
        self.assertIn(expected_key, cache)
        self.assertEqual(cache[expected_key][0]['Address'], 'new.com')

        # The second lookup is served from the cache.
        self.manager.describe_endpoint(**request)
        self.assertEqual(describe.call_count, 1)

    def test_describe_endpoint_no_ids_or_operation(self):
        """Without Operation/Identifiers inputs, the cache key is ``()``."""
        cache = {}
        operations = self.service_description['operations']
        operations['DescribeEndpoints']['input'] = {'shape': 'EmptyStruct'}
        self.construct_manager(cache=cache)
        describe = self.client.describe_endpoints

        self.manager.describe_endpoint(
            Operation='TestDiscoveryRequired', Identifiers={}
        )
        describe.assert_called_with()
        self.assertIn((), cache)
        self.assertEqual(cache[()][0]['Address'], 'new.com')

        # The second lookup is served from the cache.
        self.manager.describe_endpoint(
            Operation='TestDiscoveryRequired', Identifiers={}
        )
        self.assertEqual(describe.call_count, 1)

    def test_describe_endpoint_expired_entry(self):
        """An expired entry is refreshed once; the fresh entry is then reused."""
        key = ()
        stale_entry = {'Address': 'old.com', 'Expiration': time.time() - 10}
        cache = {key: [stale_entry]}
        self.construct_manager(cache=cache)
        lookup = {'Identifiers': {}, 'Operation': 'TestDiscoveryRequired'}
        describe = self.client.describe_endpoints

        self.manager.describe_endpoint(**lookup)
        describe.assert_called_with()
        self.assertIn(key, cache)
        self.assertEqual(cache[key][0]['Address'], 'new.com')

        self.manager.describe_endpoint(**lookup)
        self.assertEqual(describe.call_count, 1)

    def test_describe_endpoint_cache_expiration(self):
        """With a clock fixed at 0, a 2-minute cache period expires at 120."""
        cache = {}
        self.construct_manager(cache=cache, time=lambda: float(0))
        self.manager.describe_endpoint(
            Operation='TestDiscoveryRequired', Identifiers={}
        )
        self.assertIn((), cache)
        self.assertEqual(cache[()][0]['Expiration'], float(120))

    def test_delete_endpoints_present(self):
        """delete_endpoints removes an existing cache entry."""
        cache = {(): [{'Address': 'old.com', 'Expiration': 0}]}
        self.construct_manager(cache=cache)
        self.manager.delete_endpoints(
            Identifiers={}, Operation='TestDiscoveryRequired'
        )
        self.assertEqual(cache, {})

    def test_delete_endpoints_absent(self):
        """delete_endpoints on a missing entry leaves the cache empty."""
        cache = {}
        self.construct_manager(cache=cache)
        self.manager.delete_endpoints(
            Identifiers={}, Operation='TestDiscoveryRequired'
        )
        self.assertEqual(cache, {})

    def test_describe_endpoint_optional_fails_no_cache(self):
        """Optional discovery failing with no cache yields None, without retrying."""
        side_effect = [ConnectionError(error=None)]
        self.construct_manager(side_effect=side_effect)
        kwargs = {'Operation': 'TestDiscoveryOptional'}
        endpoint = self.manager.describe_endpoint(**kwargs)
        self.assertIsNone(endpoint)
        # This second call should be blocked as we just failed
        endpoint = self.manager.describe_endpoint(**kwargs)
        self.assertIsNone(endpoint)
        # Only the first lookup may reach DescribeEndpoints.
        self.assertEqual(
            self.client.describe_endpoints.call_args_list, [call()]
        )

    def test_describe_endpoint_optional_fails_stale_cache(self):
        """Failed optional discovery falls back to the stale cached address.

        The second lookup must reuse the stale entry without issuing another
        DescribeEndpoints call.
        """
        key = ()
        cache = {
            key: [{'Address': 'old.com', 'Expiration': 0}]
        }
        side_effect = [ConnectionError(error=None)] * 2
        self.construct_manager(cache=cache, side_effect=side_effect)
        kwargs = {'Operation': 'TestDiscoveryOptional'}
        endpoint = self.manager.describe_endpoint(**kwargs)
        self.assertEqual(endpoint, 'old.com')
        # This second call shouldn't go through as we just failed
        endpoint = self.manager.describe_endpoint(**kwargs)
        self.assertEqual(endpoint, 'old.com')
        # Assert (not merely compare) that only one refresh was attempted.
        self.assertEqual(self.client.describe_endpoints.call_args_list,
                         [call()])

    def test_describe_endpoint_required_fails_no_cache(self):
        """Required discovery with nothing cached raises on every attempt."""
        self.construct_manager(side_effect=[ConnectionError(error=None)] * 2)
        lookup = {'Operation': 'TestDiscoveryRequired'}
        for _ in range(2):
            # With no cache to fall back on, each attempt must go through.
            with self.assertRaises(EndpointDiscoveryRefreshFailed):
                self.manager.describe_endpoint(**lookup)
        self.assertEqual(self.client.describe_endpoints.call_count, 2)

    def test_describe_endpoint_required_fails_stale_cache(self):
        """Failed required discovery still serves the stale cached address.

        A stale entry exists, so the manager must neither raise nor force a
        second DescribeEndpoints call.
        """
        key = ()
        cache = {
            key: [{'Address': 'old.com', 'Expiration': 0}]
        }
        side_effect = [ConnectionError(error=None)] * 2
        self.construct_manager(cache=cache, side_effect=side_effect)
        kwargs = {'Operation': 'TestDiscoveryRequired'}
        endpoint = self.manager.describe_endpoint(**kwargs)
        self.assertEqual(endpoint, 'old.com')
        # We have a stale endpoint, so this shouldn't fail or force a refresh
        endpoint = self.manager.describe_endpoint(**kwargs)
        self.assertEqual(endpoint, 'old.com')
        # Assert (not merely compare) that only one refresh was attempted.
        self.assertEqual(self.client.describe_endpoints.call_args_list,
                         [call()])

    def test_describe_endpoint_required_force_refresh_success(self):
        """Required discovery with an empty cache refreshes after a failure."""
        side_effect = [
            ConnectionError(error=None),
            {'Endpoints': [{
                'Address': 'new.com',
                'CachePeriodInMinutes': 2,
            }]},
        ]
        self.construct_manager(side_effect=side_effect)
        kwargs = {'Operation': 'TestDiscoveryRequired'}
        # First call will fail
        with self.assertRaises(EndpointDiscoveryRefreshFailed):
            self.manager.describe_endpoint(**kwargs)
        # Assert (not merely compare) that exactly one call has happened so far.
        self.assertEqual(self.client.describe_endpoints.call_args_list,
                         [call()])
        # Force a refresh if the cache is empty but discovery is required
        endpoint = self.manager.describe_endpoint(**kwargs)
        self.assertEqual(endpoint, 'new.com')

    def test_describe_endpoint_retries_after_failing(self):
        """Optional discovery retries once the injected clock moves forward."""
        fake_time = Mock()
        # Successive clock readings consumed by the manager.
        fake_time.side_effect = [0, 100, 200]
        side_effect = [
            ConnectionError(error=None),
            {'Endpoints': [{
                'Address': 'new.com',
                'CachePeriodInMinutes': 2,
            }]},
        ]
        self.construct_manager(side_effect=side_effect, time=fake_time)
        kwargs = {'Operation': 'TestDiscoveryOptional'}
        endpoint = self.manager.describe_endpoint(**kwargs)
        self.assertIsNone(endpoint)
        # Assert (not merely compare) that exactly one call has happened so far.
        self.assertEqual(self.client.describe_endpoints.call_args_list,
                         [call()])
        # Second time should try again as enough time has elapsed
        endpoint = self.manager.describe_endpoint(**kwargs)
        self.assertEqual(endpoint, 'new.com')
    def test_ScriptSerializer(self):
        """ScriptSerializer puts every ExecuteScript member into the body."""
        member_names = ("Body", "Header", "Method", "Query", "ScriptIdentifier")
        # All request members are plain strings located under their own name.
        request_members = {
            name: {"locationName": name, "shape": "String"}
            for name in member_names
        }
        operations = {
            "ExecuteScript": {
                "http": {"method": "POST", "requestUri": "/2015-09-01/"},
                "input": {"shape": "ExecuteScriptRequest"},
                "name": "ExecuteScript",
                "output": {
                    "resultWrapper": "ExecuteScriptResult",
                    "shape": "ExecuteScriptResult",
                },
            }
        }
        shapes = {
            "ExecuteScriptRequest": {
                "members": request_members,
                "name": "ExecuteScriptRequest",
                "required": ["Method", "ScriptIdentifier"],
                "type": "structure",
            },
            "ExecuteScriptResult": {
                "members": {
                    "Response": {"locationName": "Response", "shape": "String"}
                },
                "name": "ExecuteScriptResult",
                "type": "structure",
            },
            "String": {"name": "String", "type": "string"},
        }
        service_model = ServiceModel({
            "metadata": self.script_model_metadata,
            "operations": operations,
            "shapes": shapes,
        })
        params = {
            "Body": "test_body",
            "Header": "test_header",
            "Method": "test_method",
            "Query": "test_query",
            "ScriptIdentifier": "test_script_identifier",
        }
        serializer = serialize.ScriptSerializer()
        operation = service_model.operation_model("ExecuteScript")
        serialized = serializer.serialize_to_request(params, operation)
        expected = {
            "body": {
                "Body": "test_body",
                "Header": "test_header",
                "Method": "test_method",
                "Query": "test_query",
                "ScriptIdentifier": "test_script_identifier",
            },
            "headers": {
                "Content-Type": "application/x-www-form-urlencoded; charset=utf-8",
                "X-Amz-Target": "2015-09-01.ExecuteScript",
            },
            "method": "POST",
            "query_string": "",
            "url_path": "/2015-09-01/",
        }
        for field, value in expected.items():
            assert serialized[field] == value
Ejemplo n.º 36
0
class TestTimestampHeadersWithRestXML(unittest.TestCase):
    """rest-xml timestamp header members are rendered as RFC 822 GMT text."""

    def setUp(self):
        # Single timestamp member bound to the ``x-timestamp`` HTTP header.
        timestamp_member = {
            'shape': 'TimestampType',
            'location': 'header',
            'locationName': 'x-timestamp'
        }
        operation = {
            'name': 'TestOperation',
            'http': {'method': 'POST', 'requestUri': '/'},
            'input': {'shape': 'InputShape'},
        }
        self.model = {
            'metadata': {'protocol': 'rest-xml', 'apiVersion': '2014-01-01'},
            'documentation': '',
            'operations': {'TestOperation': operation},
            'shapes': {
                'InputShape': {
                    'type': 'structure',
                    'members': {'TimestampHeader': timestamp_member},
                },
                'TimestampType': {'type': 'timestamp'},
            },
        }
        self.service_model = ServiceModel(self.model)

    def serialize_to_request(self, input_params):
        """Serialize *input_params* for TestOperation with the model protocol."""
        protocol = self.service_model.metadata['protocol']
        serializer = serialize.create_serializer(protocol)
        operation_model = self.service_model.operation_model('TestOperation')
        return serializer.serialize_to_request(input_params, operation_model)

    def _x_timestamp_for(self, value):
        """Serialize *value* as TimestampHeader; return the header string."""
        request = self.serialize_to_request({'TimestampHeader': value})
        return request['headers']['x-timestamp']

    def test_accepts_datetime_object(self):
        moment = datetime.datetime(2014, 1, 1, 12, 12, 12,
                                   tzinfo=dateutil.tz.tzutc())
        self.assertEqual(self._x_timestamp_for(moment),
                         'Wed, 01 Jan 2014 12:12:12 GMT')

    def test_accepts_iso_8601_format(self):
        self.assertEqual(self._x_timestamp_for('2014-01-01T12:12:12+00:00'),
                         'Wed, 01 Jan 2014 12:12:12 GMT')

    def test_accepts_iso_8601_format_non_utc(self):
        # A -05:00 offset must be normalized to the same GMT instant.
        self.assertEqual(self._x_timestamp_for('2014-01-01T07:12:12-05:00'),
                         'Wed, 01 Jan 2014 12:12:12 GMT')

    def test_accepts_rfc_822_format(self):
        self.assertEqual(
            self._x_timestamp_for('Wed, 01 Jan 2014 12:12:12 GMT'),
            'Wed, 01 Jan 2014 12:12:12 GMT')

    def test_accepts_unix_timestamp_integer(self):
        self.assertEqual(self._x_timestamp_for(1388578332),
                         'Wed, 01 Jan 2014 12:12:12 GMT')
Ejemplo n.º 37
0
class TestBinaryTypes(unittest.TestCase):
    """Blob members are base64-encoded by the query serializer.

    Covers bytes input, ASCII text input and non-ASCII text input (which is
    expected to be UTF-8 encoded before base64).
    """

    def setUp(self):
        self.model = {
            'metadata': {
                'protocol': 'query',
                'apiVersion': '2014-01-01'
            },
            'documentation': '',
            'operations': {
                'TestOperation': {
                    'name': 'TestOperation',
                    'http': {
                        'method': 'POST',
                        'requestUri': '/',
                    },
                    'input': {
                        'shape': 'InputShape'
                    },
                }
            },
            'shapes': {
                'InputShape': {
                    'type': 'structure',
                    'members': {
                        'Blob': {
                            'shape': 'BlobType'
                        },
                    }
                },
                'BlobType': {
                    'type': 'blob',
                }
            }
        }
        self.service_model = ServiceModel(self.model)

    def serialize_to_request(self, input_params):
        """Serialize *input_params* for TestOperation with the model protocol."""
        request_serializer = serialize.create_serializer(
            self.service_model.metadata['protocol'])
        return request_serializer.serialize_to_request(
            input_params, self.service_model.operation_model('TestOperation'))

    def assert_serialized_blob_equals(self, request, blob_bytes):
        """Assert the serialized Blob is the base64 text of *blob_bytes*."""
        # This method handles all the details of the base64 decoding.
        encoded = base64.b64encode(blob_bytes)
        # Now the serializers actually have the base64 encoded contents
        # as str types so we need to decode back.  We know that this is
        # ascii so it's safe to use the ascii encoding.
        expected = encoded.decode('ascii')
        self.assertEqual(request['body']['Blob'], expected)

    def test_blob_accepts_bytes_type(self):
        body = b'bytes body'
        request = self.serialize_to_request(input_params={'Blob': body})
        # Previously the request was built but never checked.
        self.assert_serialized_blob_equals(request, blob_bytes=body)

    def test_blob_accepts_str_type(self):
        body = u'ascii text'
        request = self.serialize_to_request(input_params={'Blob': body})
        self.assert_serialized_blob_equals(request,
                                           blob_bytes=body.encode('ascii'))

    def test_blob_handles_unicode_chars(self):
        body = u'\u2713'
        request = self.serialize_to_request(input_params={'Blob': body})
        self.assert_serialized_blob_equals(request,
                                           blob_bytes=body.encode('utf-8'))
Ejemplo n.º 38
0
class TestInstanceCreation(unittest.TestCase):
    """create_serializer honours its second (validation) argument.

    The model's only member is a string with ``min: 15``. The short value
    must raise ParamValidationError only when validation is enabled.
    """

    def setUp(self):
        self.model = {
            'metadata': {
                'protocol': 'query',
                'apiVersion': '2014-01-01'
            },
            'documentation': '',
            'operations': {
                'TestOperation': {
                    'name': 'TestOperation',
                    'http': {
                        'method': 'POST',
                        'requestUri': '/',
                    },
                    'input': {
                        'shape': 'InputShape'
                    },
                }
            },
            'shapes': {
                'InputShape': {
                    'type': 'structure',
                    'members': {
                        'Timestamp': {
                            'shape': 'StringTestType'
                        },
                    }
                },
                'StringTestType': {
                    'type': 'string',
                    'min': 15
                }
            }
        }
        self.service_model = ServiceModel(self.model)

    def assert_serialize_valid_parameter(self, request_serializer):
        """Serialize a string that satisfies ``min: 15`` and check the body."""
        valid_string = 'valid_string_with_min_15_chars'
        request = request_serializer.serialize_to_request(
            {'Timestamp': valid_string},
            self.service_model.operation_model('TestOperation'))

        self.assertEqual(request['body']['Timestamp'], valid_string)

    def assert_serialize_invalid_parameter(self, request_serializer):
        """Serialize a string shorter than ``min: 15`` and check the body."""
        invalid_string = 'short string'
        request = request_serializer.serialize_to_request(
            {'Timestamp': invalid_string},
            self.service_model.operation_model('TestOperation'))

        self.assertEqual(request['body']['Timestamp'], invalid_string)

    def test_instantiate_without_validation(self):
        request_serializer = serialize.create_serializer(
            self.service_model.metadata['protocol'], False)

        try:
            self.assert_serialize_valid_parameter(request_serializer)
        except ParamValidationError:
            self.fail(
                "Shouldn't fail serializing valid parameter without validation"
            )

        try:
            self.assert_serialize_invalid_parameter(request_serializer)
        except ParamValidationError:
            self.fail(
                "Shouldn't fail serializing invalid parameter without validation"
            )

    def test_instantiate_with_validation(self):
        request_serializer = serialize.create_serializer(
            self.service_model.metadata['protocol'], True)
        try:
            self.assert_serialize_valid_parameter(request_serializer)
        except ParamValidationError:
            self.fail(
                "Shouldn't fail serializing valid parameter with validation")

        with self.assertRaises(ParamValidationError):
            self.assert_serialize_invalid_parameter(request_serializer)
    def test_EssSerializer_GetDeliveryLog(self):
        """Empty GetDeliveryLog params serialize to only Action and Version."""
        def member(name, shape):
            # Each request member is located under its own name.
            return {"locationName": name, "shape": shape}

        request_shape = {
            "members": {
                "EndDate": member("EndDate", "TStamp"),
                "MaxItems": member("MaxItems", "Integer"),
                "NextToken": member("NextToken", "String"),
                "StartDate": member("StartDate", "TStamp"),
                "Status": member("Status", "Integer"),
            },
            "name": "GetDeliveryLogRequest",
            "required": ["EndDate", "StartDate"],
            "type": "structure",
        }
        result_shape = {
            "members": {"Response": member("Response", "String")},
            "name": "EssOperationResult",
            "type": "structure",
        }
        operation = {
            "http": {"method": "POST", "requestUri": "/"},
            "input": {"shape": "GetDeliveryLogRequest"},
            "name": "essOperation",
            "output": {"shape": "EssOperationResult"},
        }
        ess_service_model = ServiceModel({
            "metadata": self.ess_model_metadata,
            "operations": {"GetDeliveryLog": operation},
            "shapes": {
                "GetDeliveryLogRequest": request_shape,
                "EssOperationResult": result_shape,
                "Integer": {"name": "Integer", "type": "integer"},
                "TStamp": {"name": "TStamp", "type": "timestamp"},
                "String": {"name": "String", "type": "string"},
            },
        })
        serializer = serialize.EssSerializer()
        serialized = serializer.serialize_to_request(
            {}, ess_service_model.operation_model("GetDeliveryLog"))
        expectations = (
            ("body", {"Action": "GetDeliveryLog",
                      "Version": "2010-12-01N2014-05-28"}),
            ("headers", {"Content-Type":
                         "application/x-www-form-urlencoded; charset=utf-8"}),
            ("method", "POST"),
            ("query_string", ""),
            ("url_path", "/"),
        )
        for field, value in expectations:
            assert serialized[field] == value
Ejemplo n.º 40
0
class TestTimestampHeadersWithRestXML(unittest.TestCase):
    """Header-bound timestamps serialize to RFC 822 GMT for rest-xml."""

    def setUp(self):
        shapes = {
            'InputShape': {
                'type': 'structure',
                'members': {
                    'TimestampHeader': {
                        'shape': 'TimestampType',
                        'location': 'header',
                        'locationName': 'x-timestamp'
                    },
                }
            },
            'TimestampType': {'type': 'timestamp'},
        }
        http_binding = {'method': 'POST', 'requestUri': '/'}
        self.model = {
            'metadata': {'protocol': 'rest-xml', 'apiVersion': '2014-01-01'},
            'documentation': '',
            'operations': {
                'TestOperation': {
                    'name': 'TestOperation',
                    'http': http_binding,
                    'input': {'shape': 'InputShape'},
                }
            },
            'shapes': shapes,
        }
        self.service_model = ServiceModel(self.model)

    def serialize_to_request(self, input_params):
        """Serialize *input_params* for TestOperation with the model protocol."""
        model = self.service_model
        serializer = serialize.create_serializer(model.metadata['protocol'])
        return serializer.serialize_to_request(
            input_params, model.operation_model('TestOperation'))

    def _assert_header(self, timestamp_value, expected):
        """Check that *timestamp_value* serializes to header *expected*."""
        request = self.serialize_to_request(
            {'TimestampHeader': timestamp_value})
        self.assertEqual(request['headers']['x-timestamp'], expected)

    def test_accepts_datetime_object(self):
        utc_moment = datetime.datetime(2014, 1, 1, 12, 12, 12,
                                       tzinfo=dateutil.tz.tzutc())
        self._assert_header(utc_moment, 'Wed, 01 Jan 2014 12:12:12 GMT')

    def test_accepts_iso_8601_format(self):
        self._assert_header('2014-01-01T12:12:12+00:00',
                            'Wed, 01 Jan 2014 12:12:12 GMT')

    def test_accepts_iso_8601_format_non_utc(self):
        # The -05:00 offset must be converted to the same UTC instant.
        self._assert_header('2014-01-01T07:12:12-05:00',
                            'Wed, 01 Jan 2014 12:12:12 GMT')

    def test_accepts_rfc_822_format(self):
        self._assert_header('Wed, 01 Jan 2014 12:12:12 GMT',
                            'Wed, 01 Jan 2014 12:12:12 GMT')

    def test_accepts_unix_timestamp_integer(self):
        self._assert_header(1388578332, 'Wed, 01 Jan 2014 12:12:12 GMT')
Ejemplo n.º 41
0
    def test_DnsSerializer(self):
        """DnsSerializer writes an XML body carrying the input's namespace."""
        def parameter_member():
            # Fresh dict each time so the two shapes don't share an object.
            return {"locationName": "Parameter", "shape": "String"}

        operation = {
            "http": {
                "method": "POST",
                "requestUri": "/2012-12-12N2013-12-16/operation"
            },
            "input": {
                "locationName": "DnsOperationRequest",
                "shape": "DnsOperationRequest",
                "xmlNamespace": {
                    "uri": "https://route53.amazonaws.com/doc/2012-12-12/"
                }
            },
            "name": "DnsOperation",
            "output": {"shape": "DnsOperationResult"},
        }
        shapes = {
            "DnsOperationRequest": {
                "members": {"Parameter": parameter_member()},
                "name": "DnsOperationRequest",
                "type": "structure",
            },
            "DnsOperationResult": {
                "members": {"Parameter": parameter_member()},
                "name": "DnsOperationResult",
                "type": "structure",
            },
            "String": {"name": "String", "type": "string"},
        }
        service_model = ServiceModel({
            "metadata": self.dns_model_metadata,
            "operations": {"DnsOperation": operation},
            "shapes": shapes,
        })
        serializer = serialize.DnsSerializer()
        serialized = serializer.serialize_to_request(
            {"Parameter": "test"},
            service_model.operation_model("DnsOperation"))
        expected_body = b'<DnsOperationRequest xmlns="https://route53.amazonaws.com/doc/2012-12-12/"><Parameter>test</Parameter></DnsOperationRequest>'  # noqa: E501
        assert serialized["body"] == expected_body
        assert serialized["headers"] == {}
        assert serialized["method"] == "POST"
        assert serialized["query_string"] == {}
        assert serialized["url_path"] == "/2012-12-12N2013-12-16/operation"
Ejemplo n.º 42
0
class TestEndpointDiscoveryManager(BaseEndpointDiscoveryTest):
    """Exercise EndpointDiscoveryManager against a mocked DescribeEndpoints.

    Tests rebuild the manager through ``construct_manager`` to inject their
    own cache dict, clock, and sequence of DescribeEndpoints outcomes.
    """

    def setUp(self):
        super(TestEndpointDiscoveryManager, self).setUp()
        self.construct_manager()

    def construct_manager(self, cache=None, time=None, side_effect=None):
        """Build ``self.manager`` around a freshly mocked client.

        :param cache: Cache dict passed through to the manager, or None.
        :param time: Clock callable passed as ``current_time``, or None.
            (This parameter shadows the ``time`` module inside this method.)
        :param side_effect: Successive results/exceptions for
            ``describe_endpoints``. Defaults to one response holding a single
            ``new.com`` endpoint with ``CachePeriodInMinutes`` of 2.
        """
        self.service_model = ServiceModel(self.service_description)
        self.meta = mock.Mock(spec=ClientMeta)
        self.meta.service_model = self.service_model
        self.client = mock.Mock()
        if side_effect is None:
            side_effect = [{
                'Endpoints': [{
                    'Address': 'new.com',
                    'CachePeriodInMinutes': 2,
                }]
            }]
        self.client.describe_endpoints.side_effect = side_effect
        self.client.meta = self.meta
        self.manager = EndpointDiscoveryManager(self.client,
                                                cache=cache,
                                                current_time=time)

    def test_injects_api_version_if_endpoint_operation(self):
        model = self.service_model.operation_model('DescribeEndpoints')
        params = {'headers': {}}
        inject_api_version_header_if_needed(model, params)
        self.assertEqual(params['headers'].get('x-amz-api-version'),
                         '2018-08-31')

    def test_no_inject_api_version_if_not_endpoint_operation(self):
        model = self.service_model.operation_model('TestDiscoveryRequired')
        params = {'headers': {}}
        inject_api_version_header_if_needed(model, params)
        self.assertNotIn('x-amz-api-version', params['headers'])

    def test_gather_identifiers(self):
        params = {'Foo': 'value1', 'Nested': {'Bar': 'value2'}}
        operation = self.service_model.operation_model('TestDiscoveryRequired')
        ids = self.manager.gather_identifiers(operation, params)
        self.assertEqual(ids, {'Foo': 'value1', 'Bar': 'value2'})

    def test_gather_identifiers_none(self):
        operation = self.service_model.operation_model('TestDiscovery')
        ids = self.manager.gather_identifiers(operation, {})
        self.assertEqual(ids, {})

    def test_describe_endpoint(self):
        kwargs = {
            'Operation': 'FooBar',
            'Identifiers': {
                'Foo': 'value1',
                'Bar': 'value2'
            },
        }
        self.manager.describe_endpoint(**kwargs)
        self.client.describe_endpoints.assert_called_with(**kwargs)

    def test_describe_endpoint_no_input(self):
        describe = self.service_description['operations']['DescribeEndpoints']
        del describe['input']
        self.construct_manager()
        self.manager.describe_endpoint(Operation='FooBar', Identifiers={})
        self.client.describe_endpoints.assert_called_with()

    def test_describe_endpoint_empty_input(self):
        describe = self.service_description['operations']['DescribeEndpoints']
        describe['input'] = {'shape': 'EmptyStruct'}
        self.construct_manager()
        self.manager.describe_endpoint(Operation='FooBar', Identifiers={})
        self.client.describe_endpoints.assert_called_with()

    def test_describe_endpoint_ids_and_operation(self):
        cache = {}
        self.construct_manager(cache=cache)
        ids = {'Foo': 'value1', 'Bar': 'value2'}
        kwargs = {
            'Operation': 'TestDiscoveryRequired',
            'Identifiers': ids,
        }
        self.manager.describe_endpoint(**kwargs)
        self.client.describe_endpoints.assert_called_with(**kwargs)
        key = ((('Bar', 'value2'), ('Foo', 'value1')), 'TestDiscoveryRequired')
        self.assertIn(key, cache)
        self.assertEqual(cache[key][0]['Address'], 'new.com')
        self.manager.describe_endpoint(**kwargs)
        call_count = self.client.describe_endpoints.call_count
        self.assertEqual(call_count, 1)

    def test_describe_endpoint_no_ids_or_operation(self):
        cache = {}
        describe = self.service_description['operations']['DescribeEndpoints']
        describe['input'] = {'shape': 'EmptyStruct'}
        self.construct_manager(cache=cache)
        self.manager.describe_endpoint(Operation='TestDiscoveryRequired',
                                       Identifiers={})
        self.client.describe_endpoints.assert_called_with()
        key = ()
        self.assertIn(key, cache)
        self.assertEqual(cache[key][0]['Address'], 'new.com')
        self.manager.describe_endpoint(Operation='TestDiscoveryRequired',
                                       Identifiers={})
        call_count = self.client.describe_endpoints.call_count
        self.assertEqual(call_count, 1)

    def test_describe_endpoint_expired_entry(self):
        current_time = time.time()
        key = ()
        cache = {
            key: [{
                'Address': 'old.com',
                'Expiration': current_time - 10
            }]
        }
        self.construct_manager(cache=cache)
        kwargs = {
            'Identifiers': {},
            'Operation': 'TestDiscoveryRequired',
        }
        self.manager.describe_endpoint(**kwargs)
        self.client.describe_endpoints.assert_called_with()
        self.assertIn(key, cache)
        self.assertEqual(cache[key][0]['Address'], 'new.com')
        self.manager.describe_endpoint(**kwargs)
        call_count = self.client.describe_endpoints.call_count
        self.assertEqual(call_count, 1)

    def test_describe_endpoint_cache_expiration(self):
        def _time():
            return float(0)

        cache = {}
        self.construct_manager(cache=cache, time=_time)
        self.manager.describe_endpoint(Operation='TestDiscoveryRequired',
                                       Identifiers={})
        key = ()
        self.assertIn(key, cache)
        self.assertEqual(cache[key][0]['Expiration'], float(120))

    def test_delete_endpoints_present(self):
        key = ()
        cache = {key: [{'Address': 'old.com', 'Expiration': 0}]}
        self.construct_manager(cache=cache)
        kwargs = {
            'Identifiers': {},
            'Operation': 'TestDiscoveryRequired',
        }
        self.manager.delete_endpoints(**kwargs)
        self.assertEqual(cache, {})

    def test_delete_endpoints_absent(self):
        cache = {}
        self.construct_manager(cache=cache)
        kwargs = {
            'Identifiers': {},
            'Operation': 'TestDiscoveryRequired',
        }
        self.manager.delete_endpoints(**kwargs)
        self.assertEqual(cache, {})

    def test_describe_endpoint_optional_fails_no_cache(self):
        side_effect = [ConnectionError(error=None)]
        self.construct_manager(side_effect=side_effect)
        kwargs = {'Operation': 'TestDiscoveryOptional'}
        endpoint = self.manager.describe_endpoint(**kwargs)
        self.assertIsNone(endpoint)
        # This second call should be blocked as we just failed
        endpoint = self.manager.describe_endpoint(**kwargs)
        self.assertIsNone(endpoint)
        # Assert (not merely compare) that only one call was made.
        self.assertEqual(self.client.describe_endpoints.call_args_list,
                         [mock.call()])

    def test_describe_endpoint_optional_fails_stale_cache(self):
        key = ()
        cache = {key: [{'Address': 'old.com', 'Expiration': 0}]}
        side_effect = [ConnectionError(error=None)] * 2
        self.construct_manager(cache=cache, side_effect=side_effect)
        kwargs = {'Operation': 'TestDiscoveryOptional'}
        endpoint = self.manager.describe_endpoint(**kwargs)
        self.assertEqual(endpoint, 'old.com')
        # This second call shouldn't go through as we just failed
        endpoint = self.manager.describe_endpoint(**kwargs)
        self.assertEqual(endpoint, 'old.com')
        # Assert (not merely compare) that only one refresh was attempted.
        self.assertEqual(self.client.describe_endpoints.call_args_list,
                         [mock.call()])

    def test_describe_endpoint_required_fails_no_cache(self):
        side_effect = [ConnectionError(error=None)] * 2
        self.construct_manager(side_effect=side_effect)
        kwargs = {'Operation': 'TestDiscoveryRequired'}
        with self.assertRaises(EndpointDiscoveryRefreshFailed):
            self.manager.describe_endpoint(**kwargs)
        # This second call should go through, as we have no cache
        with self.assertRaises(EndpointDiscoveryRefreshFailed):
            self.manager.describe_endpoint(**kwargs)
        describe_count = self.client.describe_endpoints.call_count
        self.assertEqual(describe_count, 2)

    def test_describe_endpoint_required_fails_stale_cache(self):
        key = ()
        cache = {key: [{'Address': 'old.com', 'Expiration': 0}]}
        side_effect = [ConnectionError(error=None)] * 2
        self.construct_manager(cache=cache, side_effect=side_effect)
        kwargs = {'Operation': 'TestDiscoveryRequired'}
        endpoint = self.manager.describe_endpoint(**kwargs)
        self.assertEqual(endpoint, 'old.com')
        # We have a stale endpoint, so this shouldn't fail or force a refresh
        endpoint = self.manager.describe_endpoint(**kwargs)
        self.assertEqual(endpoint, 'old.com')
        # Assert (not merely compare) that only one refresh was attempted.
        self.assertEqual(self.client.describe_endpoints.call_args_list,
                         [mock.call()])

    def test_describe_endpoint_required_force_refresh_success(self):
        side_effect = [
            ConnectionError(error=None),
            {
                'Endpoints': [{
                    'Address': 'new.com',
                    'CachePeriodInMinutes': 2,
                }]
            },
        ]
        self.construct_manager(side_effect=side_effect)
        kwargs = {'Operation': 'TestDiscoveryRequired'}
        # First call will fail
        with self.assertRaises(EndpointDiscoveryRefreshFailed):
            self.manager.describe_endpoint(**kwargs)
        # Assert (not merely compare) that exactly one call happened so far.
        self.assertEqual(self.client.describe_endpoints.call_args_list,
                         [mock.call()])
        # Force a refresh if the cache is empty but discovery is required
        endpoint = self.manager.describe_endpoint(**kwargs)
        self.assertEqual(endpoint, 'new.com')

    def test_describe_endpoint_retries_after_failing(self):
        fake_time = mock.Mock()
        # Successive clock readings consumed by the manager.
        fake_time.side_effect = [0, 100, 200]
        side_effect = [
            ConnectionError(error=None),
            {
                'Endpoints': [{
                    'Address': 'new.com',
                    'CachePeriodInMinutes': 2,
                }]
            },
        ]
        self.construct_manager(side_effect=side_effect, time=fake_time)
        kwargs = {'Operation': 'TestDiscoveryOptional'}
        endpoint = self.manager.describe_endpoint(**kwargs)
        self.assertIsNone(endpoint)
        # Assert (not merely compare) that exactly one call happened so far.
        self.assertEqual(self.client.describe_endpoints.call_args_list,
                         [mock.call()])
        # Second time should try again as enough time has elapsed
        endpoint = self.manager.describe_endpoint(**kwargs)
        self.assertEqual(endpoint, 'new.com')
Ejemplo n.º 43
0
class TestEndpointDiscoveryHandler(BaseEndpointDiscoveryTest):
    """Tests for ``EndpointDiscoveryHandler`` and its event callbacks.

    The handler wraps a ``Mock`` specced on ``EndpointDiscoveryManager``, so
    each test can choose what the manager returns and inspect how it was
    called.
    """

    def setUp(self):
        super(TestEndpointDiscoveryHandler, self).setUp()
        self.manager = Mock(spec=EndpointDiscoveryManager)
        self.handler = EndpointDiscoveryHandler(self.manager)
        self.service_model = ServiceModel(self.service_description)

    @staticmethod
    def _request_with_discovery_context():
        """Build an ``AWSRequest`` whose context holds empty identifiers."""
        req = AWSRequest()
        req.context = {'discovery': {'identifiers': {}}}
        return req

    def _assert_manager_described_test_operation(self):
        """Check the last manager lookup was for ``TestOperation``."""
        self.manager.describe_endpoint.assert_called_with(
            Operation='TestOperation', Identifiers={}
        )

    def test_register_handler(self):
        emitter = Mock(spec=HierarchicalEmitter)
        self.handler.register(emitter, 'foo-bar')
        expected_registrations = [
            ('before-parameter-build.foo-bar', self.handler.gather_identifiers),
            ('needs-retry.foo-bar', self.handler.handle_retries),
        ]
        for event_name, callback in expected_registrations:
            emitter.register.assert_any_call(event_name, callback)
        emitter.register_first.assert_called_with(
            'request-created.foo-bar', self.handler.discover_endpoint
        )

    def test_discover_endpoint(self):
        req = self._request_with_discovery_context()
        self.manager.describe_endpoint.return_value = 'https://new.foo'
        self.handler.discover_endpoint(req, 'TestOperation')
        self.assertEqual(req.url, 'https://new.foo')
        self._assert_manager_described_test_operation()

    def test_discover_endpoint_fails(self):
        req = self._request_with_discovery_context()
        req.url = 'old.com'
        # A manager returning None must leave the request URL untouched.
        self.manager.describe_endpoint.return_value = None
        self.handler.discover_endpoint(req, 'TestOperation')
        self.assertEqual(req.url, 'old.com')
        self._assert_manager_described_test_operation()

    def test_discover_endpoint_no_protocol(self):
        req = self._request_with_discovery_context()
        # A bare host is expected to be prefixed with https://.
        self.manager.describe_endpoint.return_value = 'new.foo'
        self.handler.discover_endpoint(req, 'TestOperation')
        self.assertEqual(req.url, 'https://new.foo')
        self._assert_manager_described_test_operation()

    def test_inject_no_context(self):
        req = AWSRequest(url='https://original.foo')
        self.handler.discover_endpoint(req, 'TestOperation')
        self.assertEqual(req.url, 'https://original.foo')
        self.manager.describe_endpoint.assert_not_called()

    def test_gather_identifiers(self):
        ctx = {}
        request_params = {'Foo': 'value1', 'Nested': {'Bar': 'value2'}}
        expected_ids = {'Foo': 'value1', 'Bar': 'value2'}
        op_model = self.service_model.operation_model('TestDiscoveryRequired')
        self.manager.gather_identifiers.return_value = expected_ids
        self.handler.gather_identifiers(request_params, op_model, ctx)
        self.assertEqual(ctx['discovery']['identifiers'], expected_ids)

    def test_gather_identifiers_not_discoverable(self):
        ctx = {}
        op_model = self.service_model.operation_model('DescribeEndpoints')
        self.handler.gather_identifiers({}, op_model, ctx)
        self.assertEqual(ctx, {})

    def test_discovery_disabled_but_required(self):
        op_model = self.service_model.operation_model('TestDiscoveryRequired')
        with self.assertRaises(EndpointDiscoveryRequired):
            block_endpoint_discovery_required_operations(op_model)

    def test_discovery_disabled_but_optional(self):
        ctx = {}
        op_model = self.service_model.operation_model('TestDiscoveryOptional')
        block_endpoint_discovery_required_operations(op_model, context=ctx)
        self.assertEqual(ctx, {})

    def test_does_not_retry_no_response(self):
        self.assertIsNone(self.handler.handle_retries(None, None, None))

    def test_does_not_retry_other_errors(self):
        ok_response = (None, {'ResponseMetadata': {'HTTPStatusCode': 200}})
        self.assertIsNone(self.handler.handle_retries(None, ok_response, None))

    def test_does_not_retry_if_no_context(self):
        misdirected = (None, {'ResponseMetadata': {'HTTPStatusCode': 421}})
        outcome = self.handler.handle_retries(
            {'context': {}}, misdirected, None
        )
        self.assertIsNone(outcome)

    def _assert_retries(self, parsed_response):
        """Assert ``parsed_response`` triggers an immediate retry.

        The retry must be immediate (``0``) and the cached endpoints for
        ``TestDiscoveryOptional`` must be deleted.
        """
        request_dict = {'context': {'discovery': {'identifiers': {}}}}
        op_model = self.service_model.operation_model('TestDiscoveryOptional')
        retry_delay = self.handler.handle_retries(
            request_dict, (None, parsed_response), op_model
        )
        self.assertEqual(retry_delay, 0)
        self.manager.delete_endpoints.assert_called_with(
            Operation='TestDiscoveryOptional', Identifiers={}
        )

    def test_retries_421_status_code(self):
        self._assert_retries({'ResponseMetadata': {'HTTPStatusCode': 421}})

    def test_retries_invalid_endpoint_exception(self):
        self._assert_retries({'Error': {'Code': 'InvalidEndpointException'}})
Ejemplo n.º 44
0
class ShapeParser:
    """
    Parser for botocore shape files.

    Loads the service model for `service_name` from a boto3 session. It also
    loads the waiters, paginators and resources models when the loader has
    them. Botocore shapes are converted into type annotations and `Method`
    records.

    Arguments:
        session -- Boto3 session.
        service_name -- ServiceName.
    """

    # Type map for shape types.
    SHAPE_TYPE_MAP = {
        "integer": Type.int,
        "long": Type.int,
        "boolean": Type.bool,
        "double": Type.float,
        "float": Type.float,
        "timestamp": ExternalImport(ImportString("datetime"), "datetime"),
        "blob": TypeSubscript(Type.Union, [Type.bytes, Type.IO]),
    }

    # Alias map fixes added by botocore for documentation build.
    # https://github.com/boto/botocore/blob/develop/botocore/handlers.py#L773
    # https://github.com/boto/botocore/blob/develop/botocore/handlers.py#L1055
    # Layout: service -> operation (or "*" for all operations) -> argument -> alias.
    # An alias of "None" removes the argument from generated signatures.
    ARGUMENT_ALIASES: Dict[str, Dict[str, Dict[str, str]]] = {
        ServiceNameCatalog.cloudsearchdomain.boto3_name: {
            "Search": {"return": "returnFields"}
        },
        ServiceNameCatalog.logs.boto3_name: {"CreateExportTask": {"from": "fromTime"}},
        ServiceNameCatalog.ec2.boto3_name: {"*": {"Filter": "Filters"}},
        ServiceNameCatalog.s3.boto3_name: {
            "PutBucketAcl": {"ContentMD5": "None"},
            "PutBucketCors": {"ContentMD5": "None"},
            "PutBucketLifecycle": {"ContentMD5": "None"},
            "PutBucketLogging": {"ContentMD5": "None"},
            "PutBucketNotification": {"ContentMD5": "None"},
            "PutBucketPolicy": {"ContentMD5": "None"},
            "PutBucketReplication": {"ContentMD5": "None"},
            "PutBucketRequestPayment": {"ContentMD5": "None"},
            "PutBucketTagging": {"ContentMD5": "None"},
            "PutBucketVersioning": {"ContentMD5": "None"},
            "PutBucketWebsite": {"ContentMD5": "None"},
            "PutObjectAcl": {"ContentMD5": "None"},
        },
    }

    def __init__(self, session: Session, service_name: ServiceName):
        loader = session._loader  # pylint: disable=protected-access
        botocore_session: BotocoreSession = session._session  # pylint: disable=protected-access
        service_data = botocore_session.get_service_data(service_name.boto3_name)
        self.service_name = service_name
        self.service_model = ServiceModel(service_data, service_name.boto3_name)
        # TypedDicts built so far, keyed by TypedDict name.
        self._typed_dict_map: Dict[str, TypeTypedDict] = {}
        # Each model below is an empty dict when the loader reports it unknown.
        self._waiters_shape: Dict[str, Any] = self._load_optional_model(
            loader, service_name, "waiters-2"
        )
        self._paginators_shape: Dict[str, Any] = self._load_optional_model(
            loader, service_name, "paginators-1"
        )
        self._resources_shape: Dict[str, Any] = self._load_optional_model(
            loader, service_name, "resources-1"
        )

        self.logger = get_logger()

    @staticmethod
    def _load_optional_model(
        loader: Any, service_name: ServiceName, type_name: str
    ) -> Dict[str, Any]:
        """
        Load a model file that a service may not have.

        Arguments:
            loader -- Botocore loader taken from the boto3 session.
            service_name -- ServiceName.
            type_name -- Model type, e.g. "waiters-2".

        Returns:
            Loaded model data, or an empty dict if the loader raises
            `UnknownServiceError`.
        """
        try:
            return loader.load_service_model(service_name.boto3_name, type_name)
        except UnknownServiceError:
            return {}

    def _get_operation(self, name: str) -> OperationModel:
        """Return the botocore operation model for operation `name`."""
        return self.service_model.operation_model(name)

    def _get_operation_names(self) -> List[str]:
        """Return all operation names defined by the service model."""
        return list(self.service_model.operation_names)  # pylint: disable=not-an-iterable

    def _get_paginator(self, name: str) -> Shape:
        """
        Return the paginator definition named `name`.

        Raises:
            ShapeParserError -- If no paginator with this name is loaded.
        """
        try:
            return self._paginators_shape["pagination"][name]
        except KeyError:
            raise ShapeParserError(f"Unknown paginator: {name}") from None

    def _get_service_resource(self) -> Shape:
        """Return the `service` section of the resources model."""
        return self._resources_shape["service"]

    def _get_resource_shape(self, name: str) -> Shape:
        """
        Return the resource definition named `name`.

        Raises:
            ShapeParserError -- If no resource with this name is loaded.
        """
        try:
            return self._resources_shape["resources"][name]
        except KeyError:
            raise ShapeParserError(f"Unknown resource: {name}") from None

    def get_paginator_names(self) -> List[str]:
        """
        Get available paginator names.

        Returns:
            A sorted list of paginator names, empty if none are loaded.
        """
        return sorted(self._paginators_shape.get("pagination", []))

    def _get_argument_alias(self, operation_name: str, argument_name: str) -> str:
        """
        Resolve the public name of an operation argument.

        Arguments:
            operation_name -- Botocore operation name.
            argument_name -- Argument name from the input shape.

        Returns:
            The alias from `ARGUMENT_ALIASES`, the string "None" if the argument
            should be dropped, or `argument_name` when no alias applies.
        """
        service_map = self.ARGUMENT_ALIASES.get(self.service_name.boto3_name)
        if not service_map:
            return argument_name

        # "*" aliases apply to every operation. Operation-specific entries
        # override them for the same argument instead of discarding them all.
        operation_map: Dict[str, str] = {
            **service_map.get("*", {}),
            **service_map.get(operation_name, {}),
        }
        return operation_map.get(argument_name, argument_name)

    def _parse_arguments(
        self,
        class_name: str,
        method_name: str,
        operation_name: str,
        shape: StructureShape,
    ) -> List[Argument]:
        """
        Convert the members of an input `shape` into method arguments.

        Members whose alias is "None" are skipped. A stub returned by
        `get_method_type_stub` takes precedence over the parsed member type.
        Members not listed as required get a `None` default. The result is
        stably sorted so that required arguments come first.
        """
        result: List[Argument] = []
        required = shape.required_members
        for argument_name, argument_shape in shape.members.items():
            argument_alias = self._get_argument_alias(operation_name, argument_name)
            if argument_alias == "None":
                continue

            argument_type_stub = get_method_type_stub(
                self.service_name, class_name, method_name, argument_name
            )
            if argument_type_stub is not None:
                argument_type = argument_type_stub
            else:
                argument_type = self._parse_shape(argument_shape)
            argument = Argument(argument_alias, argument_type)
            if argument_name not in required:
                argument.default = Type.none
            result.append(argument)

        result.sort(key=lambda x: not x.required)
        return result

    def _parse_return_type(
        self, class_name: str, method_name: str, shape: Optional[Shape]
    ) -> FakeAnnotation:
        """
        Resolve a method return type.

        A "return" stub from `get_method_type_stub` wins. Otherwise `shape` is
        parsed when given, and `None` type is returned when it is not.
        """
        argument_type_stub = get_method_type_stub(
            self.service_name, class_name, method_name, "return"
        )
        if argument_type_stub is not None:
            return argument_type_stub

        if shape:
            return self._parse_shape(shape)

        return Type.none

    def get_client_method_map(self) -> Dict[str, Method]:
        """
        Get client methods from shape.

        Returns:
            A map of method name to Method, including `can_paginate` and
            `generate_presigned_url`.
        """
        result: Dict[str, Method] = {
            "can_paginate": Method(
                "can_paginate",
                [Argument("self", None), Argument("operation_name", Type.str)],
                Type.bool,
            ),
            "generate_presigned_url": Method(
                "generate_presigned_url",
                [
                    Argument("self", None),
                    Argument("ClientMethod", Type.str),
                    Argument("Params", Type.DictStrAny, Type.none),
                    Argument("ExpiresIn", Type.int, TypeConstant(3600)),
                    Argument("HttpMethod", Type.str, Type.none),
                ],
                Type.str,
            ),
        }
        for operation_name in self._get_operation_names():
            operation_model = self._get_operation(operation_name)
            arguments: List[Argument] = [Argument("self", None)]
            method_name = xform_name(operation_name)

            if operation_model.input_shape is not None:
                arguments.extend(
                    self._parse_arguments(
                        "Client",
                        method_name,
                        operation_name,
                        operation_model.input_shape,
                    )
                )

            return_type = self._parse_return_type(
                "Client", method_name, operation_model.output_shape
            )

            method = Method(
                name=method_name, arguments=arguments, return_type=return_type
            )
            result[method.name] = method
        return result

    @staticmethod
    def _parse_shape_string(shape: StringShape) -> FakeAnnotation:
        """Return `str`, or a Literal of the allowed values for enum strings."""
        if not shape.enum:
            return Type.str

        type_literal = TypeLiteral()
        for option in shape.enum:
            type_literal.add_literal_child(option)

        return type_literal

    def _parse_shape_map(self, shape: MapShape) -> FakeAnnotation:
        """Return `Dict[key, value]`, defaulting to `str` keys and `Any` values."""
        type_subscript = TypeSubscript(Type.Dict)
        if shape.key:
            type_subscript.add_child(self._parse_shape(shape.key))
        else:
            type_subscript.add_child(Type.str)
        if shape.value:
            type_subscript.add_child(self._parse_shape(shape.value))
        else:
            type_subscript.add_child(Type.Any)
        return type_subscript

    def _parse_shape_structure(self, shape: StructureShape) -> FakeAnnotation:
        """
        Return a TypedDict named `<ShapeName>TypeDef` for a structure.

        Memberless structures become `Dict[str, Any]`. A registered shape stub
        takes precedence over the parsed TypedDict.
        """
        if not shape.members:
            return Type.DictStrAny

        required = shape.required_members
        typed_dict_name = f"{shape.name}TypeDef"
        shape_type_stub = get_shape_type_stub(self.service_name, typed_dict_name)
        if shape_type_stub:
            return shape_type_stub

        if typed_dict_name in self._typed_dict_map:
            return self._typed_dict_map[typed_dict_name]
        typed_dict = TypeTypedDict(typed_dict_name)
        # Register before parsing members so a self-referencing shape finds
        # this TypedDict in the cache instead of recursing forever.
        self._typed_dict_map[typed_dict_name] = typed_dict
        for attr_name, attr_shape in shape.members.items():
            typed_dict.add_attribute(
                attr_name, self._parse_shape(attr_shape), attr_name in required,
            )
        return typed_dict

    def _parse_shape_list(self, shape: ListShape) -> FakeAnnotation:
        """Return `List[member]`, defaulting to `List[Any]`."""
        type_subscript = TypeSubscript(Type.List)
        if shape.member:
            type_subscript.add_child(self._parse_shape(shape.member))
        else:
            type_subscript.add_child(Type.Any)
        return type_subscript

    def _parse_shape(self, shape: Shape) -> FakeAnnotation:
        """
        Convert a botocore shape into a type annotation.

        Unrecognized shapes are logged and mapped to `Any`.
        """
        if shape.type_name in self.SHAPE_TYPE_MAP:
            return self.SHAPE_TYPE_MAP[shape.type_name]

        if isinstance(shape, StringShape):
            return self._parse_shape_string(shape)

        if isinstance(shape, MapShape):
            return self._parse_shape_map(shape)

        if isinstance(shape, StructureShape):
            return self._parse_shape_structure(shape)

        if isinstance(shape, ListShape):
            return self._parse_shape_list(shape)

        # `_resources_shape` is an empty dict for services without a
        # resources-1 model, so the "resources" key may be missing.
        if shape.type_name in self._resources_shape.get("resources", {}):
            return AliasInternalImport(shape.type_name)

        self.logger.warning(f"Unknown shape: {shape}")
        return Type.Any

    def get_paginate_method(self, paginator_name: str) -> Method:
        """
        Get Paginator `paginate` method.

        Pagination input tokens and the limit key are left out of the
        arguments. A `PaginationConfig` argument is appended instead.

        Arguments:
            paginator_name -- Paginator name.

        Returns:
            Method.
        """
        operation_name = paginator_name
        paginator_shape = self._get_paginator(paginator_name)
        operation_shape = self._get_operation(operation_name)
        skip_argument_names: List[str] = []
        input_token = paginator_shape["input_token"]
        if isinstance(input_token, list):
            skip_argument_names.extend(input_token)
        else:
            skip_argument_names.append(input_token)
        if "limit_key" in paginator_shape:
            skip_argument_names.append(paginator_shape["limit_key"])

        arguments: List[Argument] = [Argument("self", None)]

        if operation_shape.input_shape is not None:
            for argument in self._parse_arguments(
                "Paginator", "paginate", operation_name, operation_shape.input_shape
            ):
                if argument.name in skip_argument_names:
                    continue
                arguments.append(argument)

        arguments.append(Argument("PaginationConfig", paginator_config_type, Type.none))

        return_type: FakeAnnotation = Type.none
        if operation_shape.output_shape is not None:
            return_type = TypeSubscript(
                Type.Iterator,
                [
                    self._parse_return_type(
                        "Paginator", "paginate", operation_shape.output_shape
                    ),
                ],
            )

        return Method("paginate", arguments, return_type)

    def get_wait_method(self, waiter_name: str) -> Method:
        """
        Get Waiter `wait` method.

        Arguments:
            waiter_name -- Waiter name.

        Returns:
            Method.
        """
        operation_name = self._waiters_shape["waiters"][waiter_name]["operation"]
        operation_shape = self._get_operation(operation_name)

        arguments: List[Argument] = [Argument("self", None)]

        if operation_shape.input_shape is not None:
            arguments.extend(
                self._parse_arguments(
                    "Waiter", "wait", operation_name, operation_shape.input_shape
                )
            )

        arguments.append(Argument("WaiterConfig", waiter_config_type, Type.none))

        return Method(name="wait", arguments=arguments, return_type=Type.none)

    def get_service_resource_method_map(self) -> Dict[str, Method]:
        """
        Get methods for ServiceResource.

        Returns:
            A map of method name to Method.
        """
        result: Dict[str, Method] = {
            "get_available_subresources": Method(
                "get_available_subresources",
                [Argument("self", None)],
                TypeSubscript(Type.List, [Type.str]),
            ),
        }
        service_resource_shape = self._get_service_resource()
        for action_name, action_shape in service_resource_shape.get(
            "actions", {}
        ).items():
            method = self._get_resource_method(
                "ServiceResource", action_name, action_shape
            )
            result[method.name] = method

        return result

    def get_resource_method_map(self, resource_name: str) -> Dict[str, Method]:
        """
        Get methods for Resource.

        Arguments:
            resource_name -- Resource name.

        Returns:
            A map of method name to Method.
        """
        resource_shape = self._get_resource_shape(resource_name)
        result: Dict[str, Method] = {
            "get_available_subresources": Method(
                "get_available_subresources",
                [Argument("self", None)],
                TypeSubscript(Type.List, [Type.str]),
            ),
            "load": Method("load", [Argument("self", None)], Type.none),
            "reload": Method("reload", [Argument("self", None)], Type.none),
        }

        for action_name, action_shape in resource_shape.get("actions", {}).items():
            method = self._get_resource_method(resource_name, action_name, action_shape)
            result[method.name] = method

        for waiter_name in resource_shape.get("waiters", {}):
            method = Method(
                f"wait_until_{xform_name(waiter_name)}",
                [Argument("self", None)],
                Type.none,
            )
            result[method.name] = method

        return result

    def _get_resource_method(
        self, resource_name: str, action_name: str, action_shape: Dict[str, Any]
    ) -> Method:
        """
        Build a Method for a resource action.

        A `resource` entry sets the return type, wrapped in `List` when its
        path ends with "[]". A `request` entry adds the operation's input
        arguments, except those filled from resource identifiers. If no return
        type was set, the operation output becomes the return type.
        """
        return_type: FakeAnnotation = Type.none
        method_name = xform_name(action_name)
        arguments: List[Argument] = [Argument("self", None)]
        if "resource" in action_shape:
            return_type = self._parse_return_type(
                resource_name, method_name, Shape("resource", action_shape["resource"])
            )
            path = action_shape["resource"].get("path", "")
            if path.endswith("[]"):
                return_type = TypeSubscript(Type.List, [return_type])

        if "request" in action_shape:
            operation_name = action_shape["request"]["operation"]
            operation_shape = self._get_operation(operation_name)
            # Arguments supplied from the resource's identifiers are not
            # exposed on the generated method.
            skip_argument_names: List[str] = [
                i["target"]
                for i in action_shape["request"].get("params", [])
                if i["source"] == "identifier"
            ]
            if operation_shape.input_shape is not None:
                for argument in self._parse_arguments(
                    resource_name,
                    method_name,
                    operation_name,
                    operation_shape.input_shape,
                ):
                    if argument.name not in skip_argument_names:
                        arguments.append(argument)
            if operation_shape.output_shape is not None and return_type is Type.none:
                return_type = self._parse_shape(operation_shape.output_shape)

        return Method(name=method_name, arguments=arguments, return_type=return_type)

    def get_collection_filter_method(
        self, name: str, collection: Collection, self_type: FakeAnnotation
    ) -> Method:
        """
        Get `filter` method for Resource collection.

        Only optional arguments of the collection's request operation are
        included.

        Arguments:
            name -- Collection record name.
            collection -- Boto3 Collection.
            self_type -- Collection self type annotation, used as return type.

        Returns:
            Filter Method record.
        """
        method_name = "filter"
        arguments: List[Argument] = [Argument("self", None)]
        if collection.request:
            operation_name = collection.request.operation
            operation_model = self._get_operation(operation_name)

            if operation_model.input_shape is not None:
                arguments.extend(
                    argument
                    for argument in self._parse_arguments(
                        name, method_name, operation_name, operation_model.input_shape,
                    )
                    if not argument.required
                )

        # Build the Method only once `arguments` is complete, instead of
        # relying on Method keeping a live reference to the list.
        return Method(method_name, arguments, self_type)

    def get_collection_batch_methods(
        self, name: str, collection: Collection
    ) -> List[Method]:
        """
        Get batch operations for Resource collection.

        Only optional arguments of each batch action's request operation are
        included. The return type is the operation output when it has one.

        Arguments:
            name -- Collection record name.
            collection -- Boto3 Collection.

        Returns:
            List of Method records.
        """
        result: List[Method] = []
        for batch_action in collection.batch_actions:
            method = Method(batch_action.name, [Argument("self", None)], Type.none)
            result.append(method)
            if batch_action.request:
                operation_name = batch_action.request.operation
                operation_model = self._get_operation(operation_name)
                if operation_model.input_shape is not None:
                    for argument in self._parse_arguments(
                        name,
                        batch_action.name,
                        operation_name,
                        operation_model.input_shape,
                    ):
                        if argument.required:
                            continue
                        method.arguments.append(argument)
                if operation_model.output_shape is not None:
                    return_type = self._parse_shape(operation_model.output_shape)
                    method.return_type = return_type

        return result
Ejemplo n.º 45
0
    def test_RdbSerializer_fix_get_metrics_statistics_params(self):
        rdb_model = {
            "metadata": self.rdb_model_metadata,
            "operations": {
                "NiftyGetMetricStatistics": {
                    "http": {
                        "method": "POST",
                        "requestUri": "/"
                    },
                    "input": {
                        "shape": "NiftyGetMetricStatisticsRequest"
                    },
                    "name": "NiftyGetMetricStatistics",
                    "output": {
                        "resultWrapper": "RdbOperationResult",
                        "shape": "RdbOperationResult"
                    }
                },
            },
            "shapes": {
                "NiftyGetMetricStatisticsRequest": {
                    "members": {
                        "Dimensions": {
                            "locationName": "Dimensions",
                            "shape": "ListOfRequestDimensions"
                        },
                        "EndTime": {
                            "locationName": "EndTime",
                            "shape": "TStamp"
                        },
                        "MetricName": {
                            "locationName": "MetricName",
                            "shape": "String"
                        },
                        "StartTime": {
                            "locationName": "StartTime",
                            "shape": "TStamp"
                        }
                    },
                    "name": "NiftyGetMetricStatisticsRequest",
                    "type": "structure"
                },
                "ListOfRequestDimensions": {
                    "member": {
                        "locationName": "member",
                        "shape": "RequestDimensions"
                    },
                    "name": "ListOfRequestDimensions",
                    "type": "list"
                },
                "RequestDimensions": {
                    "members": {
                        "Name": {
                            "locationName": "Name",
                            "shape": "String"
                        },
                        "Value": {
                            "locationName": "Value",
                            "shape": "String"
                        }
                    },
                    "name": "RequestDimensions",
                    "required": [
                        "Name",
                        "Value"
                    ],
                    "type": "structure"
                },
                "RdbOperationResult": {
                    "members": {
                        "Response": {
                            "locationName": "Response",
                            "shape": "String"
                        }
                    },
                    "name": "RdbOperationResult",
                    "type": "structure"
                },
                "String": {
                    "name": "String",
                    "type": "string"
                },
                "TStamp": {
                    "name": "TStamp",
                    "type": "timestamp"
                }
            }
        }

        rdb_service_model = ServiceModel(rdb_model)
        params = {}
        rdb_serializer = serialize.RdbSerializer()
        res = rdb_serializer.serialize_to_request(
            params, rdb_service_model.operation_model("NiftyGetMetricStatistics"))
        assert res["body"] == {
            "Action": "NiftyGetMetricStatistics",
            "Version": "2013-05-15N2013-12-16"
        }
        assert res["headers"] == {"Content-Type": "application/x-www-form-urlencoded; charset=utf-8"}
        assert res["method"] == "POST"
        assert res["query_string"] == ""
        assert res["url_path"] == "/"