Beispiel #1
0
    def test_malformed_request_body(self, request_data, bad_request):
        """A malformed request body must be reported via ``bad_request``.

        ``request_data`` is made to raise ``MalformedRequestBody``, so the
        validator never reaches the wrapped function. The test then checks
        the ``bad_request`` calls against the MALFORMED_REQUEST_BODY spec.
        """
        # _raise(...) presumably builds a callable that raises the given
        # exception -- NOTE(review): confirm against its definition.
        failure = os_ex.MalformedRequestBody()
        request_data.side_effect = _raise(failure)

        wrapped = Mock()
        # Presumably needed because validate() copies the wrapped function's
        # name; a bare Mock does not provide __name__ on its own.
        wrapped.__name__ = "m_func"

        # The same mock serves as both the validator and the view; neither
        # is expected to run because request parsing fails first.
        guarded = v.validate(wrapped)(wrapped)
        guarded()

        expected_calls = (1, 'MALFORMED_REQUEST_BODY',
                          'Malformed message body: %(reason)s')
        self._assert_calls(bad_request, expected_calls)
Beispiel #2
0
    def _assert_cluster_scaling_validation(self,
                                           bad_req=None,
                                           req_data=None,
                                           data=None,
                                           bad_req_i=None):
        """Run CLUSTER_SCALING_SCHEMA validation on ``data`` and check it.

        :param bad_req: patched ``bad_request`` mock whose calls are checked.
        :param req_data: patched ``request_data`` mock; returns ``data``.
        :param data: request body fed to the validator.
        :param bad_req_i: expected-call spec handed to ``self._assert_calls``.
        """
        view = mock.Mock()
        # Presumably required because validate() copies the view's name.
        view.__name__ = "m_func"
        req_data.return_value = data

        validator = v.validate(c_s.CLUSTER_SCALING_SCHEMA,
                               self._create_object_fun)
        validator(view)(data=data, cluster_id='42')

        # The request body must have been read exactly once.
        self.assertEqual(req_data.call_count, 1)
        self._assert_calls(bad_req, bad_req_i)
Beispiel #3
0
    def _assert_create_object_validation(
            self, bad_req=None, request_data=None,
            data=None, bad_req_i=None):
        """Validate ``data`` against ``self.scheme`` and check the outcome.

        Patches started by ``start_patch()`` are always stopped, even when
        an assertion fails, so they cannot leak into later tests.

        :param bad_req: patched ``bad_request`` mock whose calls are checked.
        :param request_data: patched ``request_data`` mock; returns ``data``.
        :param data: request body fed to the validator.
        :param bad_req_i: expected-call spec handed to ``self._assert_calls``.
        """
        request_data.return_value = data
        patchers = start_patch()
        try:
            # mock function that should be validated
            m_func = mock.Mock()
            # Presumably required because validate() copies the name.
            m_func.__name__ = "m_func"
            decorator = v.validate(self.scheme, self._create_object_fun)
            decorator(m_func)(data=data)

            self.assertEqual(request_data.call_count, 1)
            self._assert_calls(bad_req, bad_req_i)
        finally:
            # Previously skipped on assertion failure, leaking the patches.
            stop_patch(patchers)

    def _assert_cluster_scaling_validation(self,
                                           bad_req=None,
                                           req_data=None,
                                           data=None,
                                           bad_req_i=None):
        """Check CLUSTER_SCALING_SCHEMA validation for a scaling request.

        The patched ``req_data`` hands ``data`` to the validator. The
        validator wraps a stub view, which is invoked for cluster id '42'.
        ``bad_req`` calls are then compared with ``bad_req_i``.
        """
        fake_view = mock.Mock()
        # Presumably required because validate() copies the view's name.
        fake_view.__name__ = "m_func"
        req_data.return_value = data

        scaling_validator = v.validate(c_s.CLUSTER_SCALING_SCHEMA,
                                       self._create_object_fun)
        guarded_view = scaling_validator(fake_view)
        guarded_view(data=data, cluster_id='42')

        # Exactly one read of the request body is expected.
        self.assertEqual(req_data.call_count, 1)
        self._assert_calls(bad_req, bad_req_i)
Beispiel #5
0
    def _assert_create_object_validation(
            self, bad_req=None, request_data=None,
            data=None, bad_req_i=None):
        """Validate ``data`` against ``self.scheme`` and check the outcome.

        Patches started by ``start_patch()`` are always stopped, even when
        an assertion fails, so they cannot leak into later tests.

        :param bad_req: patched ``bad_request`` mock whose calls are checked.
        :param request_data: patched ``request_data`` mock; returns ``data``.
        :param data: request body fed to the validator.
        :param bad_req_i: expected-call spec handed to ``self._assert_calls``.
        """
        request_data.return_value = data
        patchers = start_patch()
        try:
            # mock function that should be validated
            m_func = mock.Mock()
            # Presumably required because validate() copies the name.
            m_func.__name__ = "m_func"
            decorator = v.validate(self.scheme, self._create_object_fun)
            decorator(m_func)(data=data)

            self.assertEqual(request_data.call_count, 1)
            self._assert_calls(bad_req, bad_req_i)
        finally:
            # Previously skipped on assertion failure, leaking the patches.
            stop_patch(patchers)
Beispiel #6
0
    def _assert_create_object_validation(
            self, scheme, data, bad_req_i=None,
            not_found_i=None, int_err_i=None):
        """Validate ``data`` against ``scheme`` and check the error helpers.

        ``self.start_patch(data)`` returns the patched ``bad_request``,
        ``internal_error``, ``not_found`` and ``request_data`` mocks plus
        the patchers. The patchers are always released via
        ``self.stop_patch``, even when an assertion fails.

        :param scheme: schema passed to ``v.validate``.
        :param data: request body fed to the validator.
        :param bad_req_i: expected-call spec for the bad-request helper.
        :param not_found_i: expected-call spec for the not-found helper.
        :param int_err_i: expected-call spec for the internal-error helper.
        """
        bad_req, int_err, not_found, request_data, patchers = \
            self.start_patch(data)
        try:
            # mock function that should be validated
            m_func = mock.Mock()
            # Presumably required because validate() copies the name.
            m_func.__name__ = "m_func"
            v.validate(scheme, self._create_object_fun)(m_func)(data=data)

            self.assertEqual(request_data.call_count, 1)
            self._assert_calls(bad_req, bad_req_i)
            self._assert_calls(not_found, not_found_i)
            self._assert_calls(int_err, int_err_i)
        finally:
            # Previously skipped on assertion failure, leaking the patches.
            self.stop_patch(patchers)
Beispiel #7
0
    def _assert_create_object_validation(
            self, data, bad_req_i=None, not_found_i=None, int_err_i=None):
        """Validate ``data`` with the API/service layer fully stubbed out.

        Patches the ``savanna.utils.api`` response helpers and the
        ``savanna.service.api`` lookups, then installs a small fixture for
        the lookups. The fixture can be overridden through the optional
        ``_clusters_data``, ``_templates_data`` and ``_types_data``
        attributes on ``self``. The method then runs
        ``self._create_object_fun`` through ``v.validate`` and checks the
        calls made to the error helpers.

        Every patch that was started is stopped in a ``finally`` block, so
        a failing assertion (or a failed ``start()``) cannot leak patches
        into later tests.

        :param data: request body returned by the patched ``request_data``.
        :param bad_req_i: expected-call spec for ``bad_request``.
        :param not_found_i: expected-call spec for ``not_found``.
        :param int_err_i: expected-call spec for ``internal_error``.
        """
        # Order matters: it must match the unpacking below.
        patchers = [patch(target) for target in (
            "savanna.utils.api.request_data",
            "savanna.utils.api.bad_request",
            "savanna.utils.api.not_found",
            "savanna.utils.api.internal_error",
            "savanna.service.api.get_clusters",
            "savanna.service.api.get_node_templates",
            "savanna.service.api.get_node_template",
            "savanna.service.api.get_node_types",
            "savanna.service.api.get_node_type_required_params",
            "savanna.service.api.get_node_type_all_params",
        )]
        mocks = []
        try:
            for patcher in patchers:
                mocks.append(patcher.start())
            (request_data, bad_req, not_found, int_err, get_clusters,
             get_templates, get_template, get_types,
             get_node_type_required_params,
             get_node_type_all_params) = mocks

            # stub clusters list
            get_clusters.return_value = getattr(self, "_clusters_data", [
                Resource("cluster", {
                    "name": "some-cluster-1"
                })
            ])

            # stub node templates
            get_templates.return_value = getattr(self, "_templates_data", [
                Resource("node_template", {
                    "name": "jt_nn.small",
                    "node_type": {
                        "name": "JT+NN",
                        "processes": ["job_tracker", "name_node"]
                    }
                }),
                Resource("node_template", {
                    "name": "nn.small",
                    "node_type": {
                        "name": "NN",
                        "processes": ["name_node"]
                    }
                })
            ])

            def _get_template(name):
                # Look the template up in whatever the templates stub
                # currently returns; None when no name matches.
                for template in get_templates():
                    if template.name == name:
                        return template
                return None

            get_template.side_effect = _get_template

            get_types.return_value = getattr(self, "_types_data", [
                Resource("node_type", {
                    "name": "JT+NN",
                    "processes": ["job_tracker", "name_node"]
                })
            ])

            def _get_params(name):
                # Required and all params are stubbed identically: only the
                # "JT+NN" type declares a job_tracker parameter.
                if name == "JT+NN":
                    return {"job_tracker": ["jt_param"]}
                return {}

            get_node_type_required_params.side_effect = _get_params
            get_node_type_all_params.side_effect = _get_params

            # mock function that should be validated
            m_func = Mock()
            # Presumably required because validate() copies the name.
            m_func.__name__ = "m_func"

            # request data to validate
            request_data.return_value = data

            v.validate(self._create_object_fun)(m_func)()

            self.assertEqual(request_data.call_count, 1)

            self._assert_calls(bad_req, bad_req_i)
            self._assert_calls(not_found, not_found_i)
            self._assert_calls(int_err, int_err_i)
        finally:
            # Stop only the patches that actually started, newest first.
            for patcher in reversed(patchers[:len(mocks)]):
                patcher.stop()