Example #1
0
    def _mock_http_get(self, *_, **kwargs):
        """
        Fake IMDS HTTP GET driven by the ``_mock_imds_*`` flags on the test.

        :param kwargs: must contain 'endpoint'; 'resource_path' is read only
            when an HttpError message has to be built.
        :return: a MagicMock response carrying ``reason``, ``status`` and a
            ``read()`` that yields a canned body.
        :raises Exception: the fallback endpoint is hit while no fallback is expected.
        :raises HttpError: a simulated I/O timeout or throttling.
        :raises ResourceGoneError: a simulated resource-gone condition.
        """
        endpoint = kwargs['endpoint']
        timeout_template = "[HTTP Failed] GET http://{0}/metadata/{1} -- IOError timed out -- 6 attempts made"

        if endpoint == "foo.bar" and not self._mock_imds_expect_fallback:
            raise Exception("Unexpected endpoint called")

        # Primary and secondary timeouts produce the same message; only the
        # endpoint they apply to differs.
        if (self._mock_imds_primary_ioerror and endpoint == "169.254.169.254") or \
                (self._mock_imds_secondary_ioerror and endpoint == "foo.bar"):
            raise HttpError(timeout_template.format(endpoint, kwargs['resource_path']))

        if self._mock_imds_gone_error:
            raise ResourceGoneError("Resource is gone")

        if self._mock_imds_throttled:
            throttle_template = "[HTTP Retry] GET http://{0}/metadata/{1} -- Status Code 429 -- 25 attempts made"
            raise HttpError(throttle_template.format(endpoint, kwargs['resource_path']))

        response = MagicMock()
        response.reason = 'reason'
        if self._mock_imds_bad_request:
            response.status, body = httpclient.NOT_FOUND, 'Mock not found'
        else:
            response.status, body = httpclient.OK, 'Mock success response'
        response.read.return_value = body
        return response
Example #2
0
    def _put_page_blob_status(self, sas_url, status_blob):
        """
        Upload the VM status to a page blob through the HostGAPlugin.

        The blob is first (re)created with the padded status size, then the
        status bytes are written to it page by page.

        :param sas_url: SAS URL of the target status blob
        :param status_blob: status blob whose text ``data`` is uploaded
        :raises HttpError: if creating the blob or writing any page fails
        """
        url = URI_FORMAT_PUT_VM_STATUS.format(self.endpoint, HOST_PLUGIN_PORT)

        # Convert the status into a blank-padded byte string whose length is a
        # multiple of 512. Pad the *encoded bytes*, not the text: padding the
        # text to a byte count overshoots whenever the status contains
        # multi-byte UTF-8 characters, making the content larger than the
        # size the blob is created with.
        status = bytearray(status_blob.data, encoding='utf-8')
        status_size = ((len(status) + 511) // 512) * 512
        status = status.ljust(status_size)  # pads with ASCII blanks

        # First, initialize an empty blob
        response = restutil.http_put(url,
                                     data=self._build_status_data(
                                         sas_url,
                                         status_blob.get_page_blob_create_headers(status_size)),
                                     headers=self._build_status_headers())

        if restutil.request_failed(response):
            error_response = restutil.read_response_error(response)
            is_healthy = not restutil.request_failed_at_hostplugin(response)
            self.report_status_health(is_healthy=is_healthy, response=error_response)
            raise HttpError("HostGAPlugin: Failed PageBlob clean-up: {0}"
                            .format(error_response))

        self.report_status_health(is_healthy=True)
        logger.verbose("HostGAPlugin: PageBlob clean-up succeeded")

        # Then, upload the blob in pages
        if "?" not in sas_url:
            sas_url = "{0}?comp=page".format(sas_url)
        else:
            sas_url = "{0}&comp=page".format(sas_url)

        start = 0
        while start < len(status):
            # Create the next page (at most MAXIMUM_PAGEBLOB_PAGE_SIZE bytes),
            # zero-filled up to a 512-byte boundary in case the page size
            # constant is not itself 512-aligned
            end = min(start + MAXIMUM_PAGEBLOB_PAGE_SIZE, len(status))
            page_size = ((end - start + 511) // 512) * 512
            buf = bytearray(page_size)
            buf[0: end - start] = status[start: end]

            # Send the page
            response = restutil.http_put(url,
                                         data=self._build_status_data(
                                             sas_url,
                                             status_blob.get_page_blob_page_headers(start, end),
                                             buf),
                                         headers=self._build_status_headers())

            if restutil.request_failed(response):
                error_response = restutil.read_response_error(response)
                is_healthy = not restutil.request_failed_at_hostplugin(response)
                self.report_status_health(is_healthy=is_healthy, response=error_response)
                raise HttpError(
                    "HostGAPlugin Error: Put PageBlob bytes "
                    "[{0},{1}]: {2}".format(start, end, error_response))

            # Advance to the next page (if any)
            start = end
Example #3
0
    def put_vm_log(self, content):
        """
        Upload the VM logs (the binary content of a zip file) through the
        HostGAPlugin /vmAgentLog channel.

        :param content: binary content of the zip file; must not be None
        :return: the HTTP response of the upload
        :raises ProtocolError: if the HostGAPlugin is not available or no content is given
        :raises HttpError: if the upload request fails
        """
        if not self.ensure_initialized():
            raise ProtocolError("HostGAPlugin: HostGAPlugin is not available")
        if content is None:
            raise ProtocolError(
                "HostGAPlugin: Invalid argument passed to upload VM logs. Content was not provided.")

        # redact_data=True is forwarded to http_put for this binary upload
        response = restutil.http_put(URI_FORMAT_PUT_LOG.format(self.endpoint, HOST_PLUGIN_PORT),
                                     data=content,
                                     headers=self._build_log_headers(),
                                     redact_data=True)
        if not restutil.request_failed(response):
            return response

        failure = restutil.read_response_error(response)
        raise HttpError("HostGAPlugin: Upload VM logs failed: {0}".format(failure))
Example #4
0
def http_request(method, url, data, headers=None, max_retry=3,
                 chk_proxy=False):
    """
    Send an HTTP request, retrying when it raises an HTTP or I/O error.

    Up to ``max_retry`` attempts are made, sleeping RETRY_WAITING_INTERVAL
    seconds between consecutive attempts.

    :return: the response of the first attempt that does not raise
    :raises HttpError: when every attempt failed
    """
    logger.verbose("HTTP Req: {0} {1}", method, url)
    logger.verbose("    Data={0}", data)
    logger.verbose("    Header={0}", headers)
    host, port, secure, rel_uri = _parse_url(url)

    proxy_host, proxy_port = get_http_proxy() if chk_proxy else (None, None)

    # Without SSL support in httplib, fall back to plain HTTP
    if secure and not hasattr(httpclient, "HTTPSConnection"):
        logger.warn("httplib is not built with ssl support")
        secure = False

    # Without HTTPS tunnelling support, a proxied HTTPS request falls back to HTTP
    behind_proxy = proxy_host is not None and proxy_port is not None
    if secure and behind_proxy and \
            not hasattr(httpclient.HTTPSConnection, "set_tunnel"):
        logger.warn("httplib does not support https tunnelling "
                    "(new in python 2.7)")
        secure = False

    last_attempt = max_retry - 1
    for attempt in range(max_retry):
        try:
            resp = _http_request(method, host, rel_uri, port=port, data=data,
                                 secure=secure, headers=headers,
                                 proxy_host=proxy_host, proxy_port=proxy_port)
            logger.verbose("HTTP Resp: Status={0}", resp.status)
            logger.verbose("    Header={0}", resp.getheaders())
            return resp
        except httpclient.HTTPException as err:
            logger.warn('HTTPException {0}, args:{1}', err, repr(err.args))
        except IOError as err:
            logger.warn('Socket IOError {0}, args:{1}', err, repr(err.args))

        # No sleep after the final failed attempt
        if attempt < last_attempt:
            logger.info("Retry={0}, {1} {2}", attempt, method, url)
            time.sleep(RETRY_WAITING_INTERVAL)

    # Keep the error message bounded when the url is very long
    url_log = url[0: 100] if url is not None and len(url) > 100 else url
    raise HttpError("HTTP Err: {0} {1}".format(method, url_log))
Example #5
0
    def _put_block_blob_status(self, sas_url, status_blob):
        """
        Upload the VM status to a block blob through the HostGAPlugin.

        :param sas_url: SAS URL of the target status blob
        :param status_blob: status blob whose text ``data`` is uploaded
        :raises HttpError: if the upload request fails
        """
        url = URI_FORMAT_PUT_VM_STATUS.format(self.endpoint, HOST_PLUGIN_PORT)

        # The payload is the UTF-8 encoding of the status, so the size given
        # to get_block_blob_headers must be its byte length; len() of the text
        # counts characters and undercounts multi-byte content.
        content = bytearray(status_blob.data, encoding='utf-8')
        response = restutil.http_put(url,
                                     data=self._build_status_data(
                                         sas_url,
                                         status_blob.get_block_blob_headers(len(content)),
                                         content),
                                     headers=self._build_status_headers())

        if restutil.request_failed(response):
            raise HttpError("HostGAPlugin: Put BlockBlob failed: {0}".format(
                restutil.read_response_error(response)))

        logger.verbose("HostGAPlugin: Put BlockBlob status succeeded")
Example #6
0
    def get_compute(self):
        """
        Retrieve the 'instance/compute' metadata.

        :return: instance of a ComputeInfo populated from the JSON response
        :rtype: ComputeInfo
        :raises HttpError: when the metadata request is not successful
        """
        # is_health=False is passed through to get_metadata for this request
        result = self.get_metadata('instance/compute', is_health=False)
        if result.success:
            payload = json.loads(ustr(result.response, encoding="utf-8"))
            info = ComputeInfo()
            set_properties('compute', info, payload)
            return info

        raise HttpError(result.response)
Example #7
0
    def get_compute(self):
        """
        Retrieve compute metadata from ``self.compute_url``.

        :return: instance of a ComputeInfo populated from the JSON body
        :rtype: ComputeInfo
        :raises HttpError: when the GET request fails
        """
        response = restutil.http_get(self.compute_url, headers=self._headers)
        if restutil.request_failed(response):
            raise HttpError("{0} - GET: {1}".format(response.status, self.compute_url))

        payload = json.loads(ustr(response.read(), encoding="utf-8"))
        info = ComputeInfo()
        set_properties('compute', info, payload)
        return info
    def test_read_response_bytes(self):
        """
        read_response_error should render a body given as single-byte text as
        '[status: reason] <body>', and the result should survive being wrapped
        in an HttpError.

        The hex octets spell a JSON error document. The six octets
        c3 af c2 bb c2 bf that precede the embedded '<?xml' do not appear in
        the expected text.
        """
        hex_octets = ('7b:0a:20:20:20:20:22:65:72:72:6f:72:43:6f:64:65:22:'
                      '3a:20:22:54:68:65:20:62:6c:6f:62:20:74:79:70:65:20:'
                      '69:73:20:69:6e:76:61:6c:69:64:20:66:6f:72:20:74:68:'
                      '69:73:20:6f:70:65:72:61:74:69:6f:6e:2e:22:2c:0a:20:'
                      '20:20:20:22:6d:65:73:73:61:67:65:22:3a:20:22:c3:af:'
                      'c2:bb:c2:bf:3c:3f:78:6d:6c:20:76:65:72:73:69:6f:6e:'
                      '3d:22:31:2e:30:22:20:65:6e:63:6f:64:69:6e:67:3d:22:'
                      '75:74:66:2d:38:22:3f:3e:3c:45:72:72:6f:72:3e:3c:43:'
                      '6f:64:65:3e:49:6e:76:61:6c:69:64:42:6c:6f:62:54:79:'
                      '70:65:3c:2f:43:6f:64:65:3e:3c:4d:65:73:73:61:67:65:'
                      '3e:54:68:65:20:62:6c:6f:62:20:74:79:70:65:20:69:73:'
                      '20:69:6e:76:61:6c:69:64:20:66:6f:72:20:74:68:69:73:'
                      '20:6f:70:65:72:61:74:69:6f:6e:2e:0a:52:65:71:75:65:'
                      '73:74:49:64:3a:63:37:34:32:39:30:63:62:2d:30:30:30:'
                      '31:2d:30:30:62:35:2d:30:36:64:61:2d:64:64:36:36:36:'
                      '61:30:30:30:22:2c:0a:20:20:20:20:22:64:65:74:61:69:'
                      '6c:73:22:3a:20:22:22:0a:7d').split(':')
        expected_response = ('[status: reason] {\n    "errorCode": "The blob '
                             'type is invalid for this operation.",\n    '
                             '"message": "<?xml version="1.0" '
                             'encoding="utf-8"?>'
                             '<Error><Code>InvalidBlobType</Code><Message>The '
                             'blob type is invalid for this operation.\n'
                             'RequestId:c74290cb-0001-00b5-06da-dd666a000",'
                             '\n    "details": ""\n}')

        # Each hex octet becomes one character with that code point
        body_text = ''.join([chr(int(octet, 16)) for octet in hex_octets])

        fake_response = MagicMock()
        fake_response.status = 'status'
        fake_response.reason = 'reason'

        with patch.object(fake_response, 'read') as mocked_read:
            mocked_read.return_value = body_text
            rendered = hostplugin.HostPluginProtocol.read_response_error(fake_response)
            self.assertEqual(rendered, expected_response)

            # The rendered text must be preserved inside an HttpError message
            try:
                raise HttpError("{0}".format(rendered))
            except HttpError as err:
                self.assertTrue(rendered in ustr(err))
Example #9
0
    def _put_block_blob_status(self, sas_url, status_blob):
        """
        Upload the VM status to a block blob through the HostGAPlugin and
        report the outcome via report_status_health.

        :param sas_url: SAS URL of the target status blob
        :param status_blob: status blob whose text ``data`` is uploaded
        :raises HttpError: if the upload request fails
        """
        url = URI_FORMAT_PUT_VM_STATUS.format(self.endpoint, HOST_PLUGIN_PORT)

        # The payload is the UTF-8 encoding of the status, so the size given
        # to get_block_blob_headers must be its byte length; len() of the text
        # counts characters and undercounts multi-byte content.
        content = bytearray(status_blob.data, encoding='utf-8')
        response = restutil.http_put(url,
                                     data=self._build_status_data(
                                         sas_url,
                                         status_blob.get_block_blob_headers(len(content)),
                                         content),
                                     headers=self._build_status_headers())

        if restutil.request_failed(response):
            error_response = restutil.read_response_error(response)
            # Health is reported as bad only when the failure is attributed to
            # the HostGAPlugin itself
            is_healthy = not restutil.request_failed_at_hostplugin(response)
            self.report_status_health(is_healthy=is_healthy,
                                      response=error_response)
            raise HttpError("HostGAPlugin: Put BlockBlob failed: {0}".format(
                error_response))

        self.report_status_health(is_healthy=True)
        logger.verbose("HostGAPlugin: Put BlockBlob status succeeded")
Example #10
0
 def http_post_handler(url, _, **__):
     # Telemetry POSTs get an HttpError *object* as the handler result (it is
     # returned, not raised); any other request gets None.
     # NOTE(review): presumably the mock framework raises a returned exception — confirm.
     if not self.is_telemetry_request(url):
         return None
     return HttpError("A test exception")
def http_request(method,
                 url,
                 data,
                 headers=None,
                 use_proxy=False,
                 max_retry=DEFAULT_RETRIES,
                 retry_codes=RETRY_CODES,
                 retry_delay=DELAY_IN_SECONDS):
    """
    Send an HTTP request, retrying on retryable status codes, retryable HTTP
    exceptions and I/O errors.

    :param method: HTTP verb
    :param url: target URL; SAS tokens are redacted in every failure message
    :param data: request body
    :param headers: request headers
    :param use_proxy: route the request through the configured HTTP(S) proxy
    :param max_retry: maximum number of attempts; raised to at least
        THROTTLE_RETRIES once the server throttles the request
    :param retry_codes: status codes that trigger a retry
    :param retry_delay: base delay, in seconds, for the retry back-off
    :return: the first response that is neither retried nor mapped to an error
    :raises ResourceGoneError: on a RESOURCE_GONE_CODES status or an invalid
        container configuration response
    :raises HttpError: if HTTPS is required but unavailable, or when the
        attempts are exhausted or a non-retryable HTTP exception occurs
    """
    global SECURE_WARNING_EMITTED

    host, port, secure, rel_uri = _parse_url(url)

    # Use the HTTP(S) proxy
    proxy_host, proxy_port = (None, None)
    if use_proxy:
        proxy_host, proxy_port = _get_http_proxy(secure=secure)

        if proxy_host or proxy_port:
            logger.verbose("HTTP proxy: [{0}:{1}]", proxy_host, proxy_port)

    # If httplib module is not built with ssl support,
    # fallback to HTTP if allowed
    if secure and not hasattr(httpclient, "HTTPSConnection"):
        if not conf.get_allow_http():
            raise HttpError("HTTPS is unavailable and required")

        secure = False
        if not SECURE_WARNING_EMITTED:
            logger.warn("Python does not include SSL support")
            SECURE_WARNING_EMITTED = True

    # If httplib module doesn't support HTTPS tunnelling,
    # fallback to HTTP if allowed
    if (secure and
            proxy_host is not None and
            proxy_port is not None and
            not hasattr(httpclient.HTTPSConnection, "set_tunnel")):

        if not conf.get_allow_http():
            raise HttpError("HTTPS tunnelling is unavailable and required")

        secure = False
        if not SECURE_WARNING_EMITTED:
            logger.warn("Python does not support HTTPS tunnelling")
            SECURE_WARNING_EMITTED = True

    # Every failure message below embeds the URL. Redact SAS tokens once so
    # that none of them leaks credentials, including the retry-status
    # message, which is logged and ends up in the final HttpError.
    clean_url = redact_sas_tokens_in_urls(url)

    msg = ''
    attempt = 0
    was_throttled = False

    while attempt < max_retry:
        if attempt > 0:
            # Compute the request delay
            # -- Use a fixed delay if the server ever rate-throttles the request
            #    (with a safe, minimum number of retry attempts)
            # -- Otherwise, compute a delay that is the product of the next
            #    item in the Fibonacci series and the initial delay value
            delay = THROTTLE_DELAY_IN_SECONDS \
                if was_throttled \
                else _compute_delay(retry_attempt=attempt,
                                    delay=retry_delay)

            logger.verbose(
                "[HTTP Retry] "
                "Attempt {0} of {1} will delay {2} seconds: {3}", attempt + 1,
                max_retry, delay, msg)

            time.sleep(delay)

        attempt += 1

        try:
            resp = _http_request(method,
                                 host,
                                 rel_uri,
                                 port=port,
                                 data=data,
                                 secure=secure,
                                 headers=headers,
                                 proxy_host=proxy_host,
                                 proxy_port=proxy_port)
            logger.verbose("[HTTP Response] Status Code {0}", resp.status)

            if request_failed(resp):
                if _is_retry_status(resp.status, retry_codes=retry_codes):
                    msg = '[HTTP Retry] {0} {1} -- Status Code {2}'.format(
                        method, clean_url, resp.status)
                    # Note if throttled and ensure a safe, minimum number of
                    # retry attempts
                    if _is_throttle_status(resp.status):
                        was_throttled = True
                        max_retry = max(max_retry, THROTTLE_RETRIES)
                    continue

            if resp.status in RESOURCE_GONE_CODES:
                raise ResourceGoneError()

            # Map invalid container configuration errors to resource gone in
            # order to force a goal state refresh, which in turn updates the
            # container-id header passed to HostGAPlugin.
            # See #1294.
            if _is_invalid_container_configuration(resp):
                raise ResourceGoneError()

            return resp

        except httpclient.HTTPException as e:
            msg = '[HTTP Failed] {0} {1} -- HttpException {2}'.format(
                method, clean_url, e)
            if _is_retry_exception(e):
                continue
            break

        except IOError as e:
            IOErrorCounter.increment(host=host, port=port)
            msg = '[HTTP Failed] {0} {1} -- IOError {2}'.format(
                method, clean_url, e)
            continue

    raise HttpError("{0} -- {1} attempts made".format(msg, attempt))
Example #12
0
def http_request(method,
                 url,
                 data,
                 headers=None,
                 max_retry=3,
                 chk_proxy=False):
    """
    Send an HTTP request, retrying when it raises an HTTP or I/O error.

    Up to ``max_retry`` attempts are made, sleeping RETRY_WAITING_INTERVAL
    seconds between consecutive attempts; an error is logged after the last.

    :return: the response of the first attempt that does not raise
    :raises HttpError: when every attempt failed
    """
    host, port, secure, rel_uri = _parse_url(url)

    proxy_host, proxy_port = get_http_proxy() if chk_proxy else (None, None)

    # Without SSL support in httplib, fall back to plain HTTP
    if secure and not hasattr(httpclient, "HTTPSConnection"):
        logger.warn("httplib is not built with ssl support")
        secure = False

    # Without HTTPS tunnelling support, a proxied HTTPS request falls back to HTTP
    behind_proxy = proxy_host is not None and proxy_port is not None
    if secure and behind_proxy and \
            not hasattr(httpclient.HTTPSConnection, "set_tunnel"):
        logger.warn("httplib does not support https tunnelling "
                    "(new in python 2.7)")
        secure = False

    request_details = (
        ("HTTP method: [{0}]", (method,)),
        ("HTTP host: [{0}]", (host,)),
        ("HTTP uri: [{0}]", (rel_uri,)),
        ("HTTP port: [{0}]", (port,)),
        ("HTTP data: [{0}]", (data,)),
        ("HTTP secure: [{0}]", (secure,)),
        ("HTTP headers: [{0}]", (headers,)),
        ("HTTP proxy: [{0}:{1}]", (proxy_host, proxy_port)),
    )
    for template, args in request_details:
        logger.verbose(template, *args)

    final_attempt = max_retry - 1
    for attempt in range(max_retry):
        try:
            resp = _http_request(method,
                                 host,
                                 rel_uri,
                                 port=port,
                                 data=data,
                                 secure=secure,
                                 headers=headers,
                                 proxy_host=proxy_host,
                                 proxy_port=proxy_port)
            logger.verbose("HTTP response status: [{0}]", resp.status)
            return resp
        except httpclient.HTTPException as err:
            logger.warn('HTTPException: [{0}]', err)
        except IOError as err:
            logger.warn('IOError: [{0}]', err)

        if attempt >= final_attempt:
            logger.error("All retries failed")
        else:
            logger.info("Retry {0}", attempt)
            time.sleep(RETRY_WAITING_INTERVAL)

    # Keep the error message bounded when the url is very long
    url_log = url[0:100] if url is not None and len(url) > 100 else url
    raise HttpError("HTTPError: {0} {1}".format(method, url_log))
Example #13
0
def http_request(method,
                 url,
                 data,
                 headers=None,
                 use_proxy=False,
                 max_retry=None,
                 retry_codes=None,
                 retry_delay=DELAY_IN_SECONDS,
                 redact_data=False):
    """
    Send an HTTP request, retrying on retryable status codes, retryable HTTP
    exceptions and I/O errors.

    :param method: HTTP verb
    :param url: target URL; its parameters are trimmed in every failure message
    :param data: request body
    :param headers: request headers
    :param use_proxy: route the request through the configured HTTP(S) proxy
        unless bypass_proxy(host) says otherwise
    :param max_retry: maximum number of attempts (None -> DEFAULT_RETRIES);
        raised to at least THROTTLE_RETRIES once the server throttles
    :param retry_codes: status codes that trigger a retry (None -> RETRY_CODES)
    :param retry_delay: base delay, in seconds, for the retry back-off
    :param redact_data: forwarded unchanged to _http_request
    :return: the first response that is neither retried nor mapped to an error
    :raises ResourceGoneError: on a RESOURCE_GONE_CODES status
    :raises InvalidContainerError: on a 400 whose body mentions
        INVALID_CONTAINER_CONFIGURATION
    :raises HttpError: if HTTPS is required but unavailable, or when the
        attempts are exhausted or a non-retryable HTTP exception occurs
    """
    global SECURE_WARNING_EMITTED  # pylint: disable=W0603

    if max_retry is None:
        max_retry = DEFAULT_RETRIES
    if retry_codes is None:
        retry_codes = RETRY_CODES

    host, port, secure, rel_uri = _parse_url(url)

    # Use the HTTP(S) proxy
    proxy_host, proxy_port = (None, None)
    if use_proxy and not bypass_proxy(host):
        proxy_host, proxy_port = _get_http_proxy(secure=secure)

        if proxy_host or proxy_port:
            logger.verbose("HTTP proxy: [{0}:{1}]", proxy_host, proxy_port)

    # If httplib module is not built with ssl support,
    # fallback to HTTP if allowed
    if secure and not hasattr(httpclient, "HTTPSConnection"):
        if not conf.get_allow_http():
            raise HttpError("HTTPS is unavailable and required")

        secure = False
        if not SECURE_WARNING_EMITTED:
            logger.warn("Python does not include SSL support")
            SECURE_WARNING_EMITTED = True

    # If httplib module doesn't support HTTPS tunnelling,
    # fallback to HTTP if allowed
    if (secure and
            proxy_host is not None and
            proxy_port is not None and
            not hasattr(httpclient.HTTPSConnection, "set_tunnel")):

        if not conf.get_allow_http():
            raise HttpError("HTTPS tunnelling is unavailable and required")

        secure = False
        if not SECURE_WARNING_EMITTED:
            logger.warn("Python does not support HTTPS tunnelling")
            SECURE_WARNING_EMITTED = True

    # Every failure message below embeds the URL. Trim its parameters once
    # so that none of them can carry secrets such as SAS tokens, including
    # the retry-status message, which is logged and ends up in the final
    # HttpError.
    clean_url = _trim_url_parameters(url)

    msg = ''
    attempt = 0
    was_throttled = False

    while attempt < max_retry:
        if attempt > 0:
            # Compute the request delay
            # -- Use a fixed delay if the server ever rate-throttles the request
            #    (with a safe, minimum number of retry attempts)
            # -- Otherwise, compute a delay that is the product of the next
            #    item in the Fibonacci series and the initial delay value
            delay = THROTTLE_DELAY_IN_SECONDS \
                if was_throttled \
                else _compute_delay(retry_attempt=attempt,
                                    delay=retry_delay)

            logger.verbose(
                "[HTTP Retry] "
                "Attempt {0} of {1} will delay {2} seconds: {3}", attempt + 1,
                max_retry, delay, msg)

            time.sleep(delay)

        attempt += 1

        try:
            resp = _http_request(method,
                                 host,
                                 rel_uri,
                                 port=port,
                                 data=data,
                                 secure=secure,
                                 headers=headers,
                                 proxy_host=proxy_host,
                                 proxy_port=proxy_port,
                                 redact_data=redact_data)
            logger.verbose("[HTTP Response] Status Code {0}", resp.status)

            if request_failed(resp):
                if _is_retry_status(resp.status, retry_codes=retry_codes):
                    msg = '[HTTP Retry] {0} {1} -- Status Code {2}'.format(
                        method, clean_url, resp.status)
                    # Note if throttled and ensure a safe, minimum number of
                    # retry attempts
                    if _is_throttle_status(resp.status):
                        was_throttled = True
                        max_retry = max(max_retry, THROTTLE_RETRIES)
                    continue

            # If we got a 410 (resource gone) for any reason, raise an exception. The caller will handle it by
            # forcing a goal state refresh and retrying the call.
            if resp.status in RESOURCE_GONE_CODES:
                response_error = read_response_error(resp)
                raise ResourceGoneError(response_error)

            # If we got a 400 (bad request) because the container id is invalid, it could indicate a stale goal
            # state. The caller will handle this exception by forcing a goal state refresh and retrying the call.
            if resp.status == httpclient.BAD_REQUEST:
                response_error = read_response_error(resp)
                if INVALID_CONTAINER_CONFIGURATION in response_error:
                    raise InvalidContainerError(response_error)

            return resp

        except httpclient.HTTPException as e:
            msg = '[HTTP Failed] {0} {1} -- HttpException {2}'.format(
                method, clean_url, e)
            if _is_retry_exception(e):
                continue
            break

        except IOError as e:
            IOErrorCounter.increment(host=host, port=port)
            msg = '[HTTP Failed] {0} {1} -- IOError {2}'.format(
                method, clean_url, e)
            continue

    raise HttpError("{0} -- {1} attempts made".format(msg, attempt))
Example #14
0
    def mock_http_get(self, url, *_, **kwargs):
        """
        Fake HTTP GET serving canned WireServer / HostPlugin documents.

        The URL is matched against known fragments; the matching entry of
        ``self.call_counts`` is incremented and a MagicMock response (status
        200) is returned whose ``read()`` yields the canned content, UTF-8
        encoded except for the ExampleHandlerLinux package.

        :param url: requested URL
        :param kwargs: for HostPlugin extensionArtifact requests, must carry
            ``headers`` with an 'x-ms-artifact-location' entry
        :raises ResourceGoneError: once for an extensionArtifact request while
            a stale goal state is emulated (the flag is then cleared)
        :raises HttpError: for any other unlisted URL while a stale goal state
            is emulated
        :raises ValueError: if an extensionArtifact request lacks its headers
        :raises Exception: for an unrecognized URL
        """
        content = None
        response_headers = []

        resp = MagicMock()
        resp.status = httpclient.OK

        if "comp=versions" in url:  # wire server versions
            content = self.version_info
            self.call_counts["comp=versions"] += 1
        elif "/versions" in url:  # HostPlugin versions
            content = '["2015-09-01"]'
            self.call_counts["/versions"] += 1
        elif url.endswith("/health"):  # HostPlugin health
            content = ''
            self.call_counts["/health"] += 1
        elif "goalstate" in url:
            content = self.goal_state
            self.call_counts["goalstate"] += 1
        elif "hostingenvuri" in url:
            content = self.hosting_env
            self.call_counts["hostingenvuri"] += 1
        elif "sharedconfiguri" in url:
            content = self.shared_config
            self.call_counts["sharedconfiguri"] += 1
        elif "certificatesuri" in url:
            content = self.certs
            self.call_counts["certificatesuri"] += 1
        elif "extensionsconfiguri" in url:
            content = self.ext_conf
            self.call_counts["extensionsconfiguri"] += 1
        elif "remoteaccessinfouri" in url:
            content = self.remote_access
            self.call_counts["remoteaccessinfouri"] += 1
        elif ".vmSettings" in url or ".settings" in url:
            content = self.in_vm_artifacts_profile
            self.call_counts["in_vm_artifacts_profile"] += 1
        elif "/vmSettings" in url:
            content = self.vm_settings
            response_headers = [('ETag', self.etag)]
            self.call_counts["vm_settings"] += 1

        else:
            # A stale GoalState results in a 400 from the HostPlugin
            # for which the HTTP handler in restutil raises ResourceGoneError
            if self.emulate_stale_goal_state:
                if "extensionArtifact" in url:
                    self.emulate_stale_goal_state = False
                    self.call_counts["extensionArtifact"] += 1
                    raise ResourceGoneError()
                else:
                    raise HttpError()

            # For HostPlugin requests, replace the URL with that passed
            # via the x-ms-artifact-location header
            if "extensionArtifact" in url:
                self.call_counts["extensionArtifact"] += 1
                if "headers" not in kwargs:
                    raise ValueError(
                        "HostPlugin request is missing the HTTP headers: {0}".format(kwargs))
                if "x-ms-artifact-location" not in kwargs["headers"]:
                    raise ValueError(
                        "HostPlugin request is missing the x-ms-artifact-location header: {0}".format(kwargs))
                url = kwargs["headers"]["x-ms-artifact-location"]

            if "manifest.xml" in url:
                content = self.manifest
                self.call_counts["manifest.xml"] += 1
            elif "manifest_of_ga.xml" in url:
                content = self.ga_manifest
                self.call_counts["manifest_of_ga.xml"] += 1
            elif "ExampleHandlerLinux" in url:
                # The extension package content is returned as-is (not encoded)
                content = self.ext
                self.call_counts["ExampleHandlerLinux"] += 1
                resp.read = Mock(return_value=content)
                return resp
            elif ".vmSettings" in url or ".settings" in url:
                content = self.in_vm_artifacts_profile
                self.call_counts["in_vm_artifacts_profile"] += 1
            else:
                raise Exception("Bad url {0}".format(url))

        resp.read = Mock(return_value=content.encode("utf-8"))
        resp.getheaders = Mock(return_value=response_headers)
        return resp
Example #15
0
    def mock_http_get(self, url, *_, **kwargs):
        """
        Fake HTTP GET serving canned WireServer / HostPlugin documents.

        The URL is matched against known fragments; the matching entry of
        ``self.call_counts`` is incremented and a MagicMock response (status
        200) is returned whose ``read()`` yields the canned content, UTF-8
        encoded except for the ExampleHandlerLinux package.

        :param url: requested URL
        :param kwargs: for HostPlugin extensionArtifact requests, must carry
            ``headers`` with an 'x-ms-artifact-location' entry
        :raises ResourceGoneError: once for an extensionArtifact request while
            a stale goal state is emulated (the flag is then cleared)
        :raises HttpError: for any other unlisted URL while a stale goal state
            is emulated
        :raises Exception: for missing HostPlugin headers or an unrecognized URL
        """
        content = None

        resp = MagicMock()
        resp.status = httpclient.OK

        # wire server versions
        if "comp=versions" in url:
            content = self.version_info
            self.call_counts["comp=versions"] += 1

        # HostPlugin versions
        elif "/versions" in url:
            content = '["2015-09-01"]'
            self.call_counts["/versions"] += 1
        elif "goalstate" in url:
            content = self.goal_state
            self.call_counts["goalstate"] += 1
        elif "hostingenvuri" in url:
            content = self.hosting_env
            self.call_counts["hostingenvuri"] += 1
        elif "sharedconfiguri" in url:
            content = self.shared_config
            self.call_counts["sharedconfiguri"] += 1
        elif "certificatesuri" in url:
            content = self.certs
            self.call_counts["certificatesuri"] += 1
        elif "extensionsconfiguri" in url:
            content = self.ext_conf
            self.call_counts["extensionsconfiguri"] += 1

        else:
            # A stale GoalState results in a 400 from the HostPlugin
            # for which the HTTP handler in restutil raises ResourceGoneError
            if self.emulate_stale_goal_state:
                if "extensionArtifact" in url:
                    self.emulate_stale_goal_state = False
                    self.call_counts["extensionArtifact"] += 1
                    raise ResourceGoneError()
                else:
                    raise HttpError()

            # For HostPlugin requests, replace the URL with that passed
            # via the x-ms-artifact-location header
            if "extensionArtifact" in url:
                self.call_counts["extensionArtifact"] += 1
                if "headers" not in kwargs or \
                        "x-ms-artifact-location" not in kwargs["headers"]:
                    raise Exception("Bad HEADERS passed to HostPlugin: {0}".format(kwargs))
                url = kwargs["headers"]["x-ms-artifact-location"]

            if "manifest.xml" in url:
                content = self.manifest
                self.call_counts["manifest.xml"] += 1
            elif "manifest_of_ga.xml" in url:
                content = self.ga_manifest
                self.call_counts["manifest_of_ga.xml"] += 1
            elif "ExampleHandlerLinux" in url:
                # The extension package content is returned as-is (not encoded)
                content = self.ext
                self.call_counts["ExampleHandlerLinux"] += 1
                resp.read = Mock(return_value=content)
                return resp
            else:
                raise Exception("Bad url {0}".format(url))

        resp.read = Mock(return_value=content.encode("utf-8"))
        return resp
Example #16
0
class TestHostPlugin(AgentTestCase):
    def _init_host(self):
        """
        Create a HostPluginProtocol wired to the test goal state.

        Asserts that the new protocol object carries a health service.

        :return: the HostPluginProtocol instance
        """
        fixture = WireProtocolData(DATA_FILE)
        goal_state = wire.GoalState(fixture.goal_state)
        plugin = wire.HostPluginProtocol(
            wireserver_url,
            goal_state.container_id,
            goal_state.role_config_name)
        self.assertFalse(plugin.health_service is None)
        return plugin

    def _init_status_blob(self):
        """
        Return the wire client's status blob, pre-filled for upload tests.

        The blob data is set to ``faux_status`` and its VM status to "Ready".
        """
        client = wire.WireProtocol(wireserver_url).client
        blob = client.status_blob
        blob.data = faux_status
        ready = restapi.VMStatus(message="Ready", status="Ready")
        blob.vm_status = ready
        return blob

    def _compare_data(self, actual, expected):
        for k in iter(expected.keys()):
            if k == 'content' or k == 'requestUri':
                if actual[k] != expected[k]:
                    print("Mismatch: Actual '{0}'='{1}', "
                          "Expected '{0}'='{2}'".format(
                              k, actual[k], expected[k]))
                    return False
            elif k == 'headers':
                for h in expected['headers']:
                    if not (h in actual['headers']):
                        print("Missing Header: '{0}'".format(h))
                        return False
            else:
                print("Unexpected Key: '{0}'".format(k))
                return False
        return True

    def _hostplugin_data(self, blob_headers, content=None):
        """
        Build the JSON payload the HostPlugin is expected to receive.

        :param blob_headers: mapping of blob header name -> header value
        :param content: optional raw blob content (e.g. a bytearray); when
            given it is base64-encoded and stored under 'content'
        :return: dict with 'requestUri' (the test SAS url), 'headers' (a list
            of {'headerName', 'headerValue'} entries) and, optionally,
            'content'
        """
        headers = [{'headerName': name, 'headerValue': blob_headers[name]}
                   for name in blob_headers]

        data = {'requestUri': sas_url, 'headers': headers}
        if content is not None:
            encoded = base64.b64encode(bytes(content))
            if PY_VERSION_MAJOR > 2:
                # b64encode returns bytes on Python 3; the payload holds text
                encoded = encoded.decode('utf-8')
            data['content'] = encoded
        return data

    def _hostplugin_headers(self, goal_state):
        return {
            'x-ms-version': '2015-09-01',
            'Content-type': 'application/json',
            'x-ms-containerid': goal_state.container_id,
            'x-ms-host-config-name': goal_state.role_config_name
        }

    def _validate_hostplugin_args(self, args, goal_state, exp_method, exp_url,
                                  exp_data):
        """
        Assert that a recorded ``http_request`` call targeted the HostPlugin.

        :param args: a recorded mock call, unpacked as (positional, keyword)
        :param goal_state: goal state whose container id and role config
            name must appear in the request headers
        :param exp_method: expected HTTP method (first positional argument)
        :param exp_url: expected url (second positional argument)
        :param exp_data: expected payload, compared via _compare_data against
            the JSON-decoded third positional argument
        """
        call_args, call_kwargs = args
        self.assertEqual(exp_method, call_args[0])
        self.assertEqual(exp_url, call_args[1])
        payload = json.loads(call_args[2])
        self.assertTrue(self._compare_data(payload, exp_data))

        sent_headers = call_kwargs['headers']
        expected_headers = (
            ('x-ms-containerid', goal_state.container_id),
            ('x-ms-host-config-name', goal_state.role_config_name),
        )
        for header_name, header_value in expected_headers:
            self.assertEqual(sent_headers[header_name], header_value)

    @patch(
        "azurelinuxagent.common.protocol.healthservice.HealthService.report_host_plugin_versions"
    )
    @patch("azurelinuxagent.ga.update.restutil.http_get")
    @patch("azurelinuxagent.common.event.report_event")
    def assert_ensure_initialized(self, patch_event, patch_http_get,
                                  patch_report_health, response_body,
                                  response_status_code, should_initialize,
                                  should_report_healthy):
        """
        Drive HostPluginProtocol.ensure_initialized against a canned response.

        Mocks are injected bottom-up by the decorators: report_event,
        restutil.http_get and HealthService.report_host_plugin_versions.

        :param response_body: body returned by the mocked http_get
        :param response_status_code: status code returned by the mocked http_get
        :param should_initialize: expected value of ``is_initialized`` and of
            the InitializeHostPlugin event's ``is_success``
        :param should_report_healthy: expected ``is_healthy`` value passed to
            report_host_plugin_versions
        """
        plugin = hostplugin.HostPluginProtocol(endpoint='ws',
                                               container_id='cid',
                                               role_config_name='rcf')
        plugin.is_initialized = False
        patch_http_get.return_value = MockResponse(
            body=response_body,
            reason='reason',
            status_code=response_status_code)

        initialized = plugin.ensure_initialized()
        self.assertEqual(initialized, plugin.is_available)
        self.assertEqual(should_initialize, plugin.is_initialized)

        # Exactly one telemetry event, describing the initialization outcome
        self.assertEqual(1, patch_event.call_count)
        event_args, event_kwargs = patch_event.call_args
        self.assertEqual('InitializeHostPlugin', event_args[0])
        self.assertEqual(should_initialize, event_kwargs['is_success'])

        # Exactly one health report for the /versions request
        self.assertEqual(1, patch_report_health.call_count)
        health_kwargs = patch_report_health.call_args[1]
        self.assertEqual(should_report_healthy, health_kwargs['is_healthy'])

        reported = health_kwargs['response']
        if not should_initialize:
            # On failure the reported response describes the HTTP error
            for fragment in ('HTTP Failed', response_body,
                             ustr(response_status_code)):
                self.assertTrue(fragment in reported)
        else:
            self.assertEqual('', reported)

    def test_ensure_initialized(self):
        """
        Exercise ensure_initialized across success and various HTTP failures.
        """
        # (response body, status code, should initialize, should report healthy)
        scenarios = [
            (api_versions, 200, True, True),
            ('invalid ip', 400, False, True),
            ('generic bad request', 400, False, True),
            ('resource gone', 410, False, True),
            ('generic error', 500, False, False),
            ('upstream error', 502, False, True),
        ]
        for body, status_code, initialized, healthy in scenarios:
            self.assert_ensure_initialized(response_body=body,
                                           response_status_code=status_code,
                                           should_initialize=initialized,
                                           should_report_healthy=healthy)

    @patch(
        "azurelinuxagent.common.protocol.hostplugin.HostPluginProtocol.ensure_initialized",
        return_value=True)
    @patch("azurelinuxagent.common.protocol.wire.StatusBlob.upload",
           return_value=False)
    @patch(
        "azurelinuxagent.common.protocol.hostplugin.HostPluginProtocol._put_page_blob_status"
    )
    @patch("azurelinuxagent.common.protocol.wire.WireClient.update_goal_state")
    def test_default_channel(self, patch_update, patch_put, patch_upload, _):
        """
        Status upload goes through the HostPlugin by default.

        The direct StatusBlob.upload is mocked to fail, yet it must not even
        be attempted. Uploading status must not change the default channel.
        """
        goal_state = wire.GoalState(WireProtocolData(DATA_FILE).goal_state)
        vm_status = restapi.VMStatus(status="Ready",
                                     message="Guest Agent is running")

        client = wire.WireProtocol(wireserver_url).client
        client.get_goal_state = Mock(return_value=goal_state)
        client.ext_conf = wire.ExtensionsConfig(None)
        client.ext_conf.status_upload_blob = sas_url
        client.ext_conf.status_upload_blob_type = page_blob_type
        client.status_blob.set_vm_status(vm_status)

        client.upload_status_blob()

        # The direct channel is never touched...
        self.assertEqual(0, patch_upload.call_count, "Direct channel was used")
        # ...while the HostPlugin page-blob upload runs exactly once
        self.assertEqual(1, patch_put.call_count, "Host plugin was not used")

        # A single goal state refresh, without keyword arguments (not forced)
        self.assertEqual(1, patch_update.call_count, "Unexpected call count")
        self.assertEqual(0, len(patch_update.call_args[1]),
                         "Unexpected parameters")

        # The HostPlugin upload targets the SAS url from the extensions config
        self.assertEqual(sas_url, patch_put.call_args[0][0])

        # The HostPlugin does not become the default channel
        self.assertFalse(wire.HostPluginProtocol.is_default_channel())

    @patch(
        "azurelinuxagent.common.protocol.hostplugin.HostPluginProtocol.ensure_initialized",
        return_value=True)
    @patch("azurelinuxagent.common.protocol.wire.StatusBlob.upload",
           return_value=True)
    @patch(
        "azurelinuxagent.common.protocol.hostplugin.HostPluginProtocol._put_page_blob_status",
        side_effect=HttpError("503"))
    @patch("azurelinuxagent.common.protocol.wire.WireClient.update_goal_state")
    def test_fallback_channel_503(self, patch_update, patch_put, patch_upload,
                                  _):
        """
        When host plugin returns a 503, we should fall back to the direct channel.

        The HostPlugin page-blob upload raises HttpError("503") and the direct
        StatusBlob.upload is mocked to succeed.
        """
        test_goal_state = wire.GoalState(
            WireProtocolData(DATA_FILE).goal_state)
        status = restapi.VMStatus(status="Ready",
                                  message="Guest Agent is running")

        wire_protocol_client = wire.WireProtocol(wireserver_url).client
        wire_protocol_client.get_goal_state = Mock(
            return_value=test_goal_state)
        wire_protocol_client.ext_conf = wire.ExtensionsConfig(None)
        wire_protocol_client.ext_conf.status_upload_blob = sas_url
        wire_protocol_client.ext_conf.status_upload_blob_type = page_blob_type
        wire_protocol_client.status_blob.set_vm_status(status)

        # act
        wire_protocol_client.upload_status_blob()

        # assert direct route is called
        self.assertEqual(1, patch_upload.call_count,
                         "Direct channel was not used")

        # assert host plugin route is called
        self.assertEqual(1, patch_put.call_count, "Host plugin was not used")

        # assert update goal state is only called once, non-forced
        self.assertEqual(1, patch_update.call_count,
                         "Update goal state unexpected call count")
        self.assertEqual(0, len(patch_update.call_args[1]),
                         "Update goal state unexpected parameters")

        # ensure the correct url is used
        self.assertEqual(sas_url, patch_put.call_args[0][0])

        # ensure host plugin is not set as default
        self.assertFalse(wire.HostPluginProtocol.is_default_channel())

    @patch(
        "azurelinuxagent.common.protocol.hostplugin.HostPluginProtocol.ensure_initialized",
        return_value=True)
    @patch("azurelinuxagent.common.protocol.wire.StatusBlob.upload",
           return_value=True)
    @patch(
        "azurelinuxagent.common.protocol.hostplugin.HostPluginProtocol._put_page_blob_status",
        side_effect=ResourceGoneError("410"))
    @patch("azurelinuxagent.common.protocol.wire.WireClient.update_goal_state")
    def test_fallback_channel_410(self, patch_update, patch_put, patch_upload,
                                  _):
        """
        When host plugin returns a 410, we should force the goal state update and return.

        The HostPlugin page-blob upload raises ResourceGoneError("410"); the
        direct channel must not be used in that case.
        """
        test_goal_state = wire.GoalState(
            WireProtocolData(DATA_FILE).goal_state)
        status = restapi.VMStatus(status="Ready",
                                  message="Guest Agent is running")

        wire_protocol_client = wire.WireProtocol(wireserver_url).client
        wire_protocol_client.get_goal_state = Mock(
            return_value=test_goal_state)
        wire_protocol_client.ext_conf = wire.ExtensionsConfig(None)
        wire_protocol_client.ext_conf.status_upload_blob = sas_url
        wire_protocol_client.ext_conf.status_upload_blob_type = page_blob_type
        wire_protocol_client.status_blob.set_vm_status(status)

        # act
        wire_protocol_client.upload_status_blob()

        # assert direct route is not called
        self.assertEqual(0, patch_upload.call_count, "Direct channel was used")

        # assert host plugin route is called
        self.assertEqual(1, patch_put.call_count, "Host plugin was not used")

        # assert update goal state is called twice, forced=True on the second
        self.assertEqual(2, patch_update.call_count,
                         "Update goal state unexpected call count")
        self.assertEqual(1, len(patch_update.call_args[1]),
                         "Update goal state unexpected parameters")
        self.assertTrue(patch_update.call_args[1]['forced'],
                        "Update goal state was not forced")

        # ensure the correct url is used
        self.assertEqual(sas_url, patch_put.call_args[0][0])

        # ensure host plugin is not set as default
        self.assertFalse(wire.HostPluginProtocol.is_default_channel())

    @patch(
        "azurelinuxagent.common.protocol.hostplugin.HostPluginProtocol.ensure_initialized",
        return_value=True)
    @patch("azurelinuxagent.common.protocol.wire.StatusBlob.upload",
           return_value=False)
    @patch(
        "azurelinuxagent.common.protocol.hostplugin.HostPluginProtocol._put_page_blob_status",
        side_effect=HttpError("500"))
    @patch("azurelinuxagent.common.protocol.wire.WireClient.update_goal_state")
    def test_fallback_channel_failure(self, patch_update, patch_put,
                                      patch_upload, _):
        """
        When host plugin returns a 500, and direct fails, we should raise a ProtocolError.

        The HostPlugin page-blob upload raises HttpError("500") and the direct
        StatusBlob.upload is mocked to return False.
        """
        test_goal_state = wire.GoalState(
            WireProtocolData(DATA_FILE).goal_state)
        status = restapi.VMStatus(status="Ready",
                                  message="Guest Agent is running")

        wire_protocol_client = wire.WireProtocol(wireserver_url).client
        wire_protocol_client.get_goal_state = Mock(
            return_value=test_goal_state)
        wire_protocol_client.ext_conf = wire.ExtensionsConfig(None)
        wire_protocol_client.ext_conf.status_upload_blob = sas_url
        wire_protocol_client.ext_conf.status_upload_blob_type = page_blob_type
        wire_protocol_client.status_blob.set_vm_status(status)

        # act
        self.assertRaises(wire.ProtocolError,
                          wire_protocol_client.upload_status_blob)

        # assert direct route is called (as the fallback)
        self.assertEqual(1, patch_upload.call_count,
                         "Direct channel was not used")

        # assert host plugin route is called
        self.assertEqual(1, patch_put.call_count, "Host plugin was not used")

        # assert update goal state is only called once, non-forced
        self.assertEqual(1, patch_update.call_count,
                         "Update goal state unexpected call count")
        self.assertEqual(0, len(patch_update.call_args[1]),
                         "Update goal state unexpected parameters")

        # ensure the correct url is used
        self.assertEqual(sas_url, patch_put.call_args[0][0])

        # ensure host plugin is not set as default
        self.assertFalse(wire.HostPluginProtocol.is_default_channel())

    @patch("azurelinuxagent.common.protocol.wire.WireClient.update_goal_state")
    @patch("azurelinuxagent.common.event.add_event")
    def test_put_status_error_reporting(self, patch_add_event, _):
        """
        Validate the telemetry when uploading status fails.

        The HostPlugin upload fails because restutil.http_put raises, and the
        direct upload fails because StatusBlob.upload returns False. The
        resulting ProtocolError is expected, together with a single
        'ReportStatus' event about falling back to direct.
        """
        test_goal_state = wire.GoalState(
            WireProtocolData(DATA_FILE).goal_state)
        status = restapi.VMStatus(status="Ready",
                                  message="Guest Agent is running")
        # Make sure the HostPlugin is not marked as the default channel
        wire.HostPluginProtocol.set_default_channel(False)
        with patch.object(wire.StatusBlob, "upload", return_value=False):
            wire_protocol_client = wire.WireProtocol(wireserver_url).client
            wire_protocol_client.get_goal_state = Mock(
                return_value=test_goal_state)
            wire_protocol_client.ext_conf = wire.ExtensionsConfig(None)
            wire_protocol_client.ext_conf.status_upload_blob = sas_url
            wire_protocol_client.status_blob.set_vm_status(status)
            put_error = wire.HttpError("put status http error")
            # Nested (not multi-item) with-statements keep Python 2.6 support
            with patch.object(restutil, "http_put", side_effect=put_error):
                with patch.object(wire.HostPluginProtocol,
                                  "ensure_initialized",
                                  return_value=True):
                    self.assertRaises(wire.ProtocolError,
                                      wire_protocol_client.upload_status_blob)

                    # The agent tries to upload via HostPlugin and that fails
                    # because http_put raises "put_error".
                    #
                    # The agent then tries a direct upload, which also fails
                    # (StatusBlob.upload is mocked to return False).
                    self.assertEqual(
                        1, wire_protocol_client.status_blob.upload.call_count)
                    # The agent never touches the default channel in this code
                    # path, so no change.
                    self.assertFalse(
                        wire.HostPluginProtocol.is_default_channel())
                    # Exactly one telemetry event is logged: a successful
                    # 'ReportStatus' event announcing the fallback to direct.
                    self.assertEqual(1, patch_add_event.call_count)
                    self.assertEqual('ReportStatus',
                                     patch_add_event.call_args[1]['op'])
                    self.assertTrue('Falling back to direct' in
                                    patch_add_event.call_args[1]['message'])
                    self.assertEqual(
                        True, patch_add_event.call_args[1]['is_success'])

    def test_validate_http_request(self):
        """Validate correct set of data is sent to HostGAPlugin when reporting VM status"""

        wire_protocol_client = wire.WireProtocol(wireserver_url).client
        test_goal_state = wire.GoalState(
            WireProtocolData(DATA_FILE).goal_state)

        status_blob = wire_protocol_client.status_blob
        status_blob.data = faux_status
        status_blob.vm_status = restapi.VMStatus(message="Ready",
                                                 status="Ready")

        exp_method = 'PUT'
        exp_url = hostplugin_status_url
        exp_data = self._hostplugin_data(
            status_blob.get_block_blob_headers(len(faux_status)),
            bytearray(faux_status, encoding='utf-8'))

        with patch.object(restutil, "http_request") as patch_http:
            patch_http.return_value = Mock(status=httpclient.OK)

            wire_protocol_client.get_goal_state = Mock(
                return_value=test_goal_state)
            plugin = wire_protocol_client.get_host_plugin()

            with patch.object(plugin, 'get_api_versions') as patch_api:
                patch_api.return_value = API_VERSION
                plugin.put_vm_status(status_blob, sas_url, block_blob_type)

                # assertEqual reports the actual count on failure
                self.assertEqual(2, patch_http.call_count)

                # first call is to host plugin
                self._validate_hostplugin_args(patch_http.call_args_list[0],
                                               test_goal_state, exp_method,
                                               exp_url, exp_data)

                # second call is to health service
                self.assertEqual('POST', patch_http.call_args_list[1][0][0])
                self.assertEqual(health_service_url,
                                 patch_http.call_args_list[1][0][1])

    @patch("azurelinuxagent.common.protocol.wire.WireClient.update_goal_state")
    def test_no_fallback(self, _):
        """
        Validate that status is not uploaded through HostGAPlugin (no fallback)
        when status reporting via the default method succeeds.
        """
        vmstatus = restapi.VMStatus(message="Ready", status="Ready")
        with patch.object(wire.HostPluginProtocol,
                          "put_vm_status") as patch_put:
            with patch.object(wire.StatusBlob, "upload") as patch_upload:
                patch_upload.return_value = True
                wire_protocol_client = wire.WireProtocol(wireserver_url).client
                wire_protocol_client.ext_conf = wire.ExtensionsConfig(None)
                wire_protocol_client.ext_conf.status_upload_blob = sas_url
                wire_protocol_client.status_blob.vm_status = vmstatus
                wire_protocol_client.upload_status_blob()
                # assertEqual reports the actual count on failure
                self.assertEqual(0, patch_put.call_count,
                                 "Fallback was engaged")

    def test_validate_block_blob(self):
        """Validate correct set of data is sent to HostGAPlugin when reporting VM status"""
        wire_protocol_client = wire.WireProtocol(wireserver_url).client
        test_goal_state = wire.GoalState(
            WireProtocolData(DATA_FILE).goal_state)

        host_client = wire.HostPluginProtocol(wireserver_url,
                                              test_goal_state.container_id,
                                              test_goal_state.role_config_name)
        self.assertFalse(host_client.is_initialized)
        self.assertTrue(host_client.api_versions is None)
        self.assertTrue(host_client.health_service is not None)

        status_blob = wire_protocol_client.status_blob
        status_blob.data = faux_status
        status_blob.type = block_blob_type
        status_blob.vm_status = restapi.VMStatus(message="Ready",
                                                 status="Ready")

        exp_method = 'PUT'
        exp_url = hostplugin_status_url
        exp_data = self._hostplugin_data(
            status_blob.get_block_blob_headers(len(faux_status)),
            bytearray(faux_status, encoding='utf-8'))

        with patch.object(restutil, "http_request") as patch_http:
            patch_http.return_value = Mock(status=httpclient.OK)

            with patch.object(wire.HostPluginProtocol,
                              "get_api_versions") as patch_get:
                patch_get.return_value = api_versions
                host_client.put_vm_status(status_blob, sas_url)

                # assertEqual reports the actual count on failure
                self.assertEqual(2, patch_http.call_count)

                # first call is to host plugin
                self._validate_hostplugin_args(patch_http.call_args_list[0],
                                               test_goal_state, exp_method,
                                               exp_url, exp_data)

                # second call is to health service
                self.assertEqual('POST', patch_http.call_args_list[1][0][0])
                self.assertEqual(health_service_url,
                                 patch_http.call_args_list[1][0][1])

    def test_validate_page_blobs(self):
        """Validate correct set of data is sent for page blobs"""
        wire_protocol_client = wire.WireProtocol(wireserver_url).client
        test_goal_state = wire.GoalState(
            WireProtocolData(DATA_FILE).goal_state)

        host_client = wire.HostPluginProtocol(wireserver_url,
                                              test_goal_state.container_id,
                                              test_goal_state.role_config_name)

        self.assertFalse(host_client.is_initialized)
        self.assertTrue(host_client.api_versions is None)

        status_blob = wire_protocol_client.status_blob
        status_blob.data = faux_status
        status_blob.type = page_blob_type
        status_blob.vm_status = restapi.VMStatus(message="Ready",
                                                 status="Ready")

        exp_method = 'PUT'
        exp_url = hostplugin_status_url

        # Blank-pad the status so its length is a multiple of 512, as the
        # page-blob upload path does
        page_status = bytearray(status_blob.data, encoding='utf-8')
        page_size = int((len(page_status) + 511) / 512) * 512
        page_status = bytearray(status_blob.data.ljust(page_size),
                                encoding='utf-8')
        mock_response = MockResponse('', httpclient.OK)

        with patch.object(restutil, "http_request",
                          return_value=mock_response) as patch_http:
            with patch.object(wire.HostPluginProtocol,
                              "get_api_versions") as patch_get:
                patch_get.return_value = api_versions
                host_client.put_vm_status(status_blob, sas_url)

                # assertEqual reports the actual count on failure
                self.assertEqual(3, patch_http.call_count)

                # first call is to host plugin
                exp_data = self._hostplugin_data(
                    status_blob.get_page_blob_create_headers(page_size))
                self._validate_hostplugin_args(patch_http.call_args_list[0],
                                               test_goal_state, exp_method,
                                               exp_url, exp_data)

                # second call is to health service
                self.assertEqual('POST', patch_http.call_args_list[1][0][0])
                self.assertEqual(health_service_url,
                                 patch_http.call_args_list[1][0][1])

                # last call is to host plugin
                exp_data = self._hostplugin_data(
                    status_blob.get_page_blob_page_headers(0, page_size),
                    page_status)
                exp_data['requestUri'] += "?comp=page"
                self._validate_hostplugin_args(patch_http.call_args_list[2],
                                               test_goal_state, exp_method,
                                               exp_url, exp_data)

    def test_validate_get_extension_artifacts(self):
        """
        get_artifact_request initializes the HostPlugin and returns the
        extension-artifact url plus headers carrying the artifact location.
        """
        test_goal_state = wire.GoalState(
            WireProtocolData(DATA_FILE).goal_state)
        expected_url = hostplugin.URI_FORMAT_GET_EXTENSION_ARTIFACT.format(
            wireserver_url, hostplugin.HOST_PLUGIN_PORT)
        expected_headers = {
            'x-ms-version': '2015-09-01',
            "x-ms-containerid": test_goal_state.container_id,
            "x-ms-host-config-name": test_goal_state.role_config_name,
            "x-ms-artifact-location": sas_url
        }

        host_client = wire.HostPluginProtocol(wireserver_url,
                                              test_goal_state.container_id,
                                              test_goal_state.role_config_name)
        self.assertFalse(host_client.is_initialized)
        self.assertTrue(host_client.api_versions is None)
        self.assertTrue(host_client.health_service is not None)

        with patch.object(wire.HostPluginProtocol,
                          "get_api_versions",
                          return_value=api_versions):
            actual_url, actual_headers = host_client.get_artifact_request(
                sas_url)
            self.assertTrue(host_client.is_initialized)
            self.assertFalse(host_client.api_versions is None)
            self.assertEqual(expected_url, actual_url)
            for k in expected_headers:
                self.assertTrue(k in actual_headers)
                self.assertEqual(expected_headers[k], actual_headers[k])

    @patch("azurelinuxagent.common.utils.restutil.http_get")
    def test_health(self, patch_http_get):
        """
        get_health is True for a 200, False for a 500, and propagates IOError.
        """
        host_plugin = self._init_host()

        patch_http_get.return_value = MockResponse('', 200)
        result = host_plugin.get_health()
        self.assertEqual(1, patch_http_get.call_count)
        self.assertTrue(result)

        patch_http_get.return_value = MockResponse('', 500)
        result = host_plugin.get_health()
        self.assertFalse(result)

        # Transport errors are not swallowed by get_health
        patch_http_get.side_effect = IOError('client IO error')
        self.assertRaises(IOError, host_plugin.get_health)

    @patch("azurelinuxagent.common.utils.restutil.http_get",
           return_value=MockResponse(status_code=200, body=b''))
    @patch(
        "azurelinuxagent.common.protocol.healthservice.HealthService.report_host_plugin_versions"
    )
    def test_ensure_health_service_called(self, patch_report_versions,
                                          patch_http_get):
        """
        get_api_versions issues one GET and reports it to the health service.

        Stacked @patch decorators inject mocks bottom-up, so the
        report_host_plugin_versions mock comes first and http_get second.
        """
        host_plugin = self._init_host()

        host_plugin.get_api_versions()
        self.assertEqual(1, patch_http_get.call_count)
        self.assertEqual(1, patch_report_versions.call_count)

    @patch("azurelinuxagent.common.utils.restutil.http_get")
    @patch("azurelinuxagent.common.utils.restutil.http_post")
    @patch("azurelinuxagent.common.utils.restutil.http_put")
    def test_put_status_healthy_signal(self, patch_http_put, patch_http_post,
                                       patch_http_get):
        """
        A successful status upload sends healthy signals for /versions and /status.
        """
        host_plugin = self._init_host()
        status_blob = self._init_status_blob()
        patch_http_get.return_value = MockResponse(api_versions, 200)  # get_api_versions
        patch_http_put.return_value = MockResponse(None, 201)  # put status blob

        host_plugin.put_vm_status(status_blob=status_blob, sas_url=sas_url)

        self.assertEqual(1, patch_http_get.call_count)
        self.assertEqual(hostplugin_versions_url,
                         patch_http_get.call_args[0][0])

        # Both PUTs go to the HostPlugin status endpoint
        self.assertEqual(2, patch_http_put.call_count)
        for put_call in patch_http_put.call_args_list:
            self.assertEqual(hostplugin_status_url, put_call[0][0])

        # One health signal per HostPlugin call, in order: /versions, /status
        self.assertEqual(2, patch_http_post.call_count)
        expected_signals = ['GuestAgentPluginVersions', 'GuestAgentPluginStatus']
        for post_call, signal_name in zip(patch_http_post.call_args_list,
                                          expected_signals):
            self.assertEqual(health_service_url, post_call[0][0])
            observations = json.loads(post_call[0][1])['Observations']
            self.assertEqual(1, len(observations))
            self.assertTrue(observations[0]['IsHealthy'])
            self.assertEqual(signal_name, observations[0]['ObservationName'])

    @patch("azurelinuxagent.common.utils.restutil.http_get")
    @patch("azurelinuxagent.common.utils.restutil.http_post")
    @patch("azurelinuxagent.common.utils.restutil.http_put")
    def test_put_status_unhealthy_signal_transient(self, patch_http_put,
                                                   patch_http_post,
                                                   patch_http_get):
        """
        A single failed status PUT (HTTP 500) raises HttpError.

        Both health signals (/versions and /status) are still reported as
        healthy, because the error state is not triggered.
        """
        host_plugin = self._init_host()
        status_blob = self._init_status_blob()
        patch_http_get.return_value = MockResponse(api_versions, 200)  # get_api_versions
        patch_http_put.return_value = MockResponse(None, 500)  # put status blob

        # The callable form of assertRaises works on every supported Python
        # version, including 2.6 (no context-manager form there)
        self.assertRaises(HttpError, host_plugin.put_vm_status,
                          status_blob=status_blob, sas_url=sas_url)

        self.assertEqual(1, patch_http_get.call_count)
        self.assertEqual(hostplugin_versions_url,
                         patch_http_get.call_args[0][0])

        self.assertEqual(1, patch_http_put.call_count)
        self.assertEqual(hostplugin_status_url, patch_http_put.call_args[0][0])

        # One health signal per HostPlugin call, in order: /versions, /status
        self.assertEqual(2, patch_http_post.call_count)
        expected_signals = ['GuestAgentPluginVersions', 'GuestAgentPluginStatus']
        for post_call, signal_name in zip(patch_http_post.call_args_list,
                                          expected_signals):
            self.assertEqual(health_service_url, post_call[0][0])
            observations = json.loads(post_call[0][1])['Observations']
            self.assertEqual(1, len(observations))
            self.assertTrue(observations[0]['IsHealthy'])
            self.assertEqual(signal_name, observations[0]['ObservationName'])

    @patch("azurelinuxagent.common.utils.restutil.http_get")
    @patch("azurelinuxagent.common.utils.restutil.http_post")
    @patch("azurelinuxagent.common.utils.restutil.http_put")
    def test_put_status_unhealthy_signal_permanent(self, patch_http_put,
                                                   patch_http_post,
                                                   patch_http_get):
        """
        With the status error state forced to report itself as triggered, a
        failed (HTTP 500) status PUT raises HttpError. The /status health
        signal is then sent as unhealthy, while /versions stays healthy.
        """
        plugin = self._init_host()
        blob = self._init_status_blob()

        # /versions answers normally, the status PUT fails with a 500
        patch_http_get.return_value = MockResponse(api_versions, 200)
        patch_http_put.return_value = MockResponse(None, 500)

        # force the status error state to claim it has been triggered
        plugin.status_error_state.is_triggered = Mock(return_value=True)

        if sys.version_info >= (2, 7):
            with self.assertRaises(HttpError):
                plugin.put_vm_status(status_blob=blob, sas_url=sas_url)
        else:
            # Python 2.6: assertRaises is not usable as a context manager
            self.assertRaises(HttpError, plugin.put_vm_status, blob, sas_url)

        self.assertEqual(1, patch_http_get.call_count)
        self.assertEqual(hostplugin_versions_url,
                         patch_http_get.call_args[0][0])

        self.assertEqual(1, patch_http_put.call_count)
        self.assertEqual(hostplugin_status_url, patch_http_put.call_args[0][0])

        # one POST per health signal, in this order: (observation name, healthy?)
        self.assertEqual(2, patch_http_post.call_count)
        expected_signals = (('GuestAgentPluginVersions', True),
                            ('GuestAgentPluginStatus', False))
        for index, (name, healthy) in enumerate(expected_signals):
            positional = patch_http_post.call_args_list[index][0]
            self.assertEqual(health_service_url, positional[0])
            observations = json.loads(positional[1])['Observations']
            self.assertEqual(1, len(observations))
            check_health = self.assertTrue if healthy else self.assertFalse
            check_health(observations[0]['IsHealthy'])
            self.assertEqual(name, observations[0]['ObservationName'])

    @patch(
        "azurelinuxagent.common.protocol.hostplugin.HostPluginProtocol.should_report",
        return_value=True)
    @patch(
        "azurelinuxagent.common.protocol.healthservice.HealthService.report_host_plugin_extension_artifact"
    )
    def test_report_fetch_health(self, patch_report_artifact,
                                 patch_should_report):
        """
        report_fetch_health ignores an empty URI, a direct 169.254.169.254
        extensionArtifact URI and a port-32526 /status URI: should_report is
        not consulted for any of them. Only the port-32526 extensionArtifact
        URI updates fetch_last_timestamp and produces one artifact report.
        """
        plugin = self._init_host()

        ignored_uris = ('',
                        'http://169.254.169.254/extensionArtifact',
                        'http://168.63.129.16:32526/status')
        for uri in ignored_uris:
            plugin.report_fetch_health(uri=uri, is_healthy=True)
            self.assertEqual(0, patch_should_report.call_count)

        self.assertEqual(None, plugin.fetch_last_timestamp)
        plugin.report_fetch_health(
            uri='http://168.63.129.16:32526/extensionArtifact',
            is_healthy=True)
        self.assertNotEqual(None, plugin.fetch_last_timestamp)
        self.assertEqual(1, patch_should_report.call_count)
        self.assertEqual(1, patch_report_artifact.call_count)

    @patch(
        "azurelinuxagent.common.protocol.hostplugin.HostPluginProtocol.should_report",
        return_value=True)
    @patch(
        "azurelinuxagent.common.protocol.healthservice.HealthService.report_host_plugin_status"
    )
    def test_report_status_health(self, patch_report_status,
                                  patch_should_report):
        """
        report_status_health sets status_last_timestamp, consults
        should_report once and forwards exactly one status report.
        """
        plugin = self._init_host()
        self.assertEqual(None, plugin.status_last_timestamp)

        plugin.report_status_health(is_healthy=True)

        self.assertNotEqual(None, plugin.status_last_timestamp)
        for patched in (patch_should_report, patch_report_status):
            self.assertEqual(1, patched.call_count)

    def test_should_report(self):
        """
        Drive should_report through five measurements with a one-minute period.
        Healthy measurements are reported only once the period has elapsed
        since the last report. An unhealthy measurement is reported and
        increments the error counter. The next healthy one resets the counter.
        """
        plugin = self._init_host()
        errors = ErrorState(min_timedelta=datetime.timedelta(minutes=5))
        period = datetime.timedelta(minutes=1)

        def seconds_ago(seconds):
            return datetime.datetime.utcnow() - datetime.timedelta(
                seconds=seconds)

        # 1st: never reported before -> report
        self.assertEqual(True,
                         plugin.should_report(True, errors, None, period))

        # 2nd: last report 30s ago, inside the period -> skip
        last_timestamp = seconds_ago(30)
        self.assertEqual(False,
                         plugin.should_report(True, errors, last_timestamp,
                                              period))

        # 3rd: last report 60s ago, period elapsed -> report
        last_timestamp = seconds_ago(60)
        self.assertEqual(True,
                         plugin.should_report(True, errors, last_timestamp,
                                              period))

        # 4th: unhealthy (same timestamp) -> report and count the error
        self.assertEqual(0, errors.count)
        reported = plugin.should_report(False, errors, last_timestamp, period)
        self.assertEqual(1, errors.count)
        self.assertEqual(True, reported)

        # 5th: healthy again, 30s after the last report -> skip, counter reset
        last_timestamp = seconds_ago(30)
        self.assertEqual(1, errors.count)
        reported = plugin.should_report(True, errors, last_timestamp, period)
        self.assertEqual(0, errors.count)
        self.assertEqual(False, reported)
    # NOTE(review): patch targets assume HealthService resolves add_event from
    # azurelinuxagent.common.event at call time and posts through
    # restutil.http_post -- confirm against healthservice.py.
    @patch("azurelinuxagent.common.event.add_event")
    @patch("azurelinuxagent.common.utils.restutil.http_post")
    def test_reporting(self, patch_post, patch_add_event):
        """
        Every HealthService.report_* call POSTs its observation right away and
        leaves no observations queued. Unhealthy observations additionally
        emit one telemetry event via add_event. A POST failing with HttpError
        does not propagate out of the report call.

        The mocks are injected by the @patch decorators bottom-up: http_post
        is the first mock argument, add_event the second. Without the
        decorators unittest would invoke this method with only `self`.
        """
        health_service = HealthService('endpoint')
        health_service.report_host_plugin_status(is_healthy=True, response='response')
        self.assertEqual(1, patch_post.call_count)
        self.assertEqual(0, patch_add_event.call_count)
        self.assert_observation(call_args=patch_post.call_args,
                                name=HealthService.HOST_PLUGIN_STATUS_OBSERVATION_NAME,
                                is_healthy=True,
                                value='response',
                                description='')
        self.assertEqual(0, len(health_service.observations))

        health_service.report_host_plugin_status(is_healthy=False, response='error')
        self.assertEqual(2, patch_post.call_count)
        self.assertEqual(1, patch_add_event.call_count)
        self.assert_telemetry(call_args=patch_add_event.call_args, response='error')
        self.assert_observation(call_args=patch_post.call_args,
                                name=HealthService.HOST_PLUGIN_STATUS_OBSERVATION_NAME,
                                is_healthy=False,
                                value='error',
                                description='')
        self.assertEqual(0, len(health_service.observations))

        health_service.report_host_plugin_extension_artifact(is_healthy=True, source='source', response='response')
        self.assertEqual(3, patch_post.call_count)
        self.assertEqual(1, patch_add_event.call_count)
        self.assert_observation(call_args=patch_post.call_args,
                                name=HealthService.HOST_PLUGIN_ARTIFACT_OBSERVATION_NAME,
                                is_healthy=True,
                                value='response',
                                description='source')
        self.assertEqual(0, len(health_service.observations))

        health_service.report_host_plugin_extension_artifact(is_healthy=False, source='source', response='response')
        self.assertEqual(4, patch_post.call_count)
        self.assertEqual(2, patch_add_event.call_count)
        self.assert_telemetry(call_args=patch_add_event.call_args, response='response')
        self.assert_observation(call_args=patch_post.call_args,
                                name=HealthService.HOST_PLUGIN_ARTIFACT_OBSERVATION_NAME,
                                is_healthy=False,
                                value='response',
                                description='source')
        self.assertEqual(0, len(health_service.observations))

        health_service.report_host_plugin_heartbeat(is_healthy=True)
        self.assertEqual(5, patch_post.call_count)
        self.assertEqual(2, patch_add_event.call_count)
        self.assert_observation(call_args=patch_post.call_args,
                                name=HealthService.HOST_PLUGIN_HEARTBEAT_OBSERVATION_NAME,
                                is_healthy=True,
                                value='',
                                description='')
        self.assertEqual(0, len(health_service.observations))

        health_service.report_host_plugin_heartbeat(is_healthy=False)
        self.assertEqual(3, patch_add_event.call_count)
        self.assert_telemetry(call_args=patch_add_event.call_args)
        self.assertEqual(6, patch_post.call_count)
        self.assert_observation(call_args=patch_post.call_args,
                                name=HealthService.HOST_PLUGIN_HEARTBEAT_OBSERVATION_NAME,
                                is_healthy=False,
                                value='',
                                description='')
        self.assertEqual(0, len(health_service.observations))

        health_service.report_host_plugin_versions(is_healthy=True, response='response')
        self.assertEqual(7, patch_post.call_count)
        self.assertEqual(3, patch_add_event.call_count)
        self.assert_observation(call_args=patch_post.call_args,
                                name=HealthService.HOST_PLUGIN_VERSIONS_OBSERVATION_NAME,
                                is_healthy=True,
                                value='response',
                                description='')
        self.assertEqual(0, len(health_service.observations))

        health_service.report_host_plugin_versions(is_healthy=False, response='response')
        self.assertEqual(8, patch_post.call_count)
        self.assertEqual(4, patch_add_event.call_count)
        self.assert_telemetry(call_args=patch_add_event.call_args, response='response')
        self.assert_observation(call_args=patch_post.call_args,
                                name=HealthService.HOST_PLUGIN_VERSIONS_OBSERVATION_NAME,
                                is_healthy=False,
                                value='response',
                                description='')
        self.assertEqual(0, len(health_service.observations))

        # A failing POST is swallowed: the call returns normally, it was still
        # attempted, and no telemetry is emitted for a healthy observation.
        patch_post.side_effect = HttpError()
        health_service.report_host_plugin_versions(is_healthy=True, response='')

        self.assertEqual(9, patch_post.call_count)
        self.assertEqual(4, patch_add_event.call_count)
        self.assertEqual(0, len(health_service.observations))
Example #18
0
class TestHostPlugin(HttpRequestPredicates, AgentTestCase):  # pylint: disable=too-many-public-methods
    def _init_host(self):
        """Build a HostPluginProtocol from the goal state of the DATA_FILE mock wire protocol."""
        with mock_wire_protocol(DATA_FILE) as protocol:
            goal_state = protocol.client.get_goal_state()
            plugin = wire.HostPluginProtocol(wireserver_url,
                                             goal_state.container_id,
                                             goal_state.role_config_name)
            self.assertTrue(plugin.health_service is not None)
        return plugin

    def _init_status_blob(self):
        """Return a fresh wire client's status blob carrying faux_status data and a Ready VM status."""
        blob = wire.WireProtocol(wireserver_url).client.status_blob
        blob.data = faux_status
        blob.vm_status = restapi.VMStatus(message="Ready", status="Ready")
        return blob

    def _relax_timestamp(self, headers):
        """
        Turn [{'headerName': n, 'headerValue': v}, ...] into [{n: v}, ...].

        The 'x-ms-date' value is cut at its last ':' so the seconds part does
        not take part in later comparisons.
        """
        def relaxed(header):
            value = header['headerValue']
            name = header['headerName']
            if name == 'x-ms-date':
                value = value[:value.rfind(":")]
            return {name: value}

        return [relaxed(header) for header in headers]

    def _compare_data(self, actual, expected):
        """
        Return True if *actual* (a decoded host plugin request body) matches
        *expected*. Otherwise print the first difference and return False.

        'content' and 'requestUri' must be equal (a key missing from *actual*
        counts as a mismatch rather than raising). Every expected header must
        be present in actual's headers. Any other key in *expected* is
        reported as unexpected.

        Timestamps are compared without their seconds (via _relax_timestamp),
        because that level of granularity makes the test flaky. Neither
        argument is modified.
        """
        # Work on relaxed copies instead of overwriting the callers' 'headers'
        # entries, so the same expected dict can safely be compared again.
        actual_headers = self._relax_timestamp(actual['headers'])
        expected_headers = self._relax_timestamp(expected['headers'])

        for key in expected:
            if key in ('content', 'requestUri'):
                actual_value = actual.get(key)
                if actual_value != expected[key]:
                    print("Mismatch: Actual '{0}'='{1}', "
                          "Expected '{0}'='{2}'".format(
                              key, actual_value, expected[key]))
                    return False
            elif key == 'headers':
                for header in expected_headers:
                    if header not in actual_headers:
                        print("Missing Header: '{0}'".format(header))
                        return False
            else:
                print("Unexpected Key: '{0}'".format(key))
                return False
        return True

    def _hostplugin_data(self, blob_headers, content=None):
        """
        Build the request body the host plugin is expected to receive.

        The body holds the SAS URL and the blob headers as
        headerName/headerValue records. When *content* is given, it also holds
        the base64-encoded content, decoded to text on Python 3.
        """
        data = {
            'requestUri': sas_url,
            'headers': [{'headerName': name, 'headerValue': value}
                        for name, value in blob_headers.items()],
        }
        if content is not None:
            encoded = base64.b64encode(bytes(content))
            if PY_VERSION_MAJOR > 2:
                encoded = encoded.decode('utf-8')
            data['content'] = encoded
        return data

    def _hostplugin_headers(self, goal_state):
        """Return the HTTP headers sent to the host plugin for *goal_state*."""
        headers = {
            'x-ms-version': '2015-09-01',
            'Content-type': 'application/json',
        }
        headers['x-ms-containerid'] = goal_state.container_id
        headers['x-ms-host-config-name'] = goal_state.role_config_name
        return headers

    def _validate_hostplugin_args(self, args, goal_state, exp_method, exp_url,
                                  exp_data):  # pylint: disable=too-many-arguments
        """
        Check one recorded http_request call.

        *args* is a mock (positional, keyword) call pair. Its positionals must
        be (exp_method, exp_url, JSON body matching exp_data). Its 'headers'
        keyword must carry the goal state's container id and role config name.
        """
        positional, keyword = args
        self.assertEqual(exp_method, positional[0])
        self.assertEqual(exp_url, positional[1])
        self.assertTrue(self._compare_data(json.loads(positional[2]), exp_data))

        sent_headers = keyword['headers']
        expected_headers = (
            ('x-ms-containerid', goal_state.container_id),
            ('x-ms-host-config-name', goal_state.role_config_name),
        )
        for header_name, expected_value in expected_headers:
            self.assertEqual(sent_headers[header_name], expected_value)

    @staticmethod
    @contextlib.contextmanager
    def create_mock_protocol():
        """
        Yield a mock wire protocol built from DATA_FILE_NO_EXT.

        Before yielding, the goal state's status upload blob is set to the
        page-blob sas_url and an initial Ready VM status is set. The client's
        update_goal_state is replaced by a Mock so tests can assert on how it
        is called.
        """
        with mock_wire_protocol(DATA_FILE_NO_EXT) as protocol:
            client = protocol.client

            # This mock data has no extensions, so the extension config is
            # empty and the upload blob must be filled in by hand.
            ext_conf = client._goal_state.ext_conf  # pylint: disable=protected-access
            ext_conf.status_upload_blob = sas_url
            ext_conf.status_upload_blob_type = page_blob_type

            initial_status = restapi.VMStatus(status="Ready",
                                              message="Guest Agent is running")
            client.status_blob.set_vm_status(initial_status)

            client.update_goal_state = Mock()

            yield protocol

    @patch(
        "azurelinuxagent.common.protocol.healthservice.HealthService.report_host_plugin_versions"
    )
    @patch("azurelinuxagent.ga.update.restutil.http_get")
    @patch("azurelinuxagent.common.protocol.hostplugin.add_event")
    def assert_ensure_initialized(
            self,
            patch_event,
            patch_http_get,
            patch_report_health,  # pylint: disable=too-many-arguments
            response_body,
            response_status_code,
            should_initialize,
            should_report_healthy):
        """
        Run HostPluginProtocol.ensure_initialized against a mocked /versions
        response and verify the outcome:

        * the return value equals is_available, and is_initialized equals
          *should_initialize*;
        * one 'InitializeHostPlugin' event is added, whose is_success equals
          *should_initialize*;
        * one versions health report is sent, whose is_healthy equals
          *should_report_healthy*. Its response is '' on success; on failure
          it contains 'HTTP Failed', the body and the status code.
        """
        plugin = hostplugin.HostPluginProtocol(endpoint='ws',
                                               container_id='cid',
                                               role_config_name='rcf')
        plugin.is_initialized = False
        patch_http_get.return_value = MockResponse(
            body=response_body,
            reason='reason',
            status_code=response_status_code)

        result = plugin.ensure_initialized()
        self.assertEqual(result, plugin.is_available)
        self.assertEqual(should_initialize, plugin.is_initialized)

        self.assertEqual(1, patch_event.call_count)
        event_args, event_kwargs = patch_event.call_args
        self.assertEqual('InitializeHostPlugin', event_args[0])
        self.assertEqual(should_initialize, event_kwargs['is_success'])

        self.assertEqual(1, patch_report_health.call_count)
        report_kwargs = patch_report_health.call_args[1]
        self.assertEqual(should_report_healthy, report_kwargs['is_healthy'])

        reported_response = report_kwargs['response']
        if should_initialize:
            self.assertEqual('', reported_response)
        else:
            for fragment in ('HTTP Failed', response_body,
                             ustr(response_status_code)):
                self.assertTrue(fragment in reported_response)

    def test_ensure_initialized(self):
        """
        Exercise ensure_initialized with a successful /versions response and
        a set of failing status codes. Of these cases only the 500 response is
        expected to be reported as unhealthy.
        """
        # (response body, status code, should initialize, should report healthy)
        cases = (
            (api_versions, 200, True, True),
            ('invalid ip', 400, False, True),
            ('generic bad request', 400, False, True),
            ('resource gone', 410, False, True),
            ('generic error', 500, False, False),
            ('upstream error', 502, False, True),
        )
        for body, status_code, initializes, reports_healthy in cases:
            self.assert_ensure_initialized(  # pylint: disable=no-value-for-parameter
                response_body=body,
                response_status_code=status_code,
                should_initialize=initializes,
                should_report_healthy=reports_healthy)

    @patch(
        "azurelinuxagent.common.protocol.hostplugin.HostPluginProtocol.ensure_initialized",
        return_value=True)
    @patch("azurelinuxagent.common.protocol.wire.StatusBlob.upload",
           return_value=False)
    @patch(
        "azurelinuxagent.common.protocol.hostplugin.HostPluginProtocol._put_page_blob_status"
    )
    def test_default_channel(self, patch_put, patch_upload, _):
        """
        Status upload goes through the HostPlugin by default.

        The direct (public) channel is patched to fail, but it is never used.
        The host plugin PUT happens once, to the sas_url. The goal state is
        updated once, without keyword arguments. The default channel is left
        unchanged.
        """
        with self.create_mock_protocol() as wire_protocol:
            wire_protocol.update_goal_state()
            client = wire_protocol.client

            client.upload_status_blob()

            # the direct route is not touched...
            self.assertEqual(0, patch_upload.call_count,
                             "Direct channel was used")
            # ...and the host plugin route is used once
            self.assertEqual(1, patch_put.call_count,
                             "Host plugin was not used")

            # a single, non-forced goal state update
            gs_update = client.update_goal_state
            self.assertEqual(1, gs_update.call_count, "Unexpected call count")
            self.assertEqual(0, len(gs_update.call_args[1]),
                             "Unexpected parameters")

            self.assertEqual(sas_url, patch_put.call_args[0][0])
            self.assertFalse(wire.HostPluginProtocol.is_default_channel())

    @patch(
        "azurelinuxagent.common.protocol.hostplugin.HostPluginProtocol.ensure_initialized",
        return_value=True)
    @patch("azurelinuxagent.common.protocol.wire.StatusBlob.upload",
           return_value=True)
    @patch(
        "azurelinuxagent.common.protocol.hostplugin.HostPluginProtocol._put_page_blob_status",
        side_effect=HttpError("503"))
    def test_fallback_channel_503(self, patch_put, patch_upload, _):
        """
        When the host plugin returns a 503, status falls back to the direct channel.

        The host plugin PUT (to the sas_url) is attempted once and fails. The
        direct upload is then used once. The goal state is updated once,
        without keyword arguments. The host plugin does not become the
        default channel.
        """
        with self.create_mock_protocol() as wire_protocol:
            wire_protocol.update_goal_state()

            # act
            wire_protocol.client.upload_status_blob()

            # assert direct route is called (as the fallback)
            self.assertEqual(1, patch_upload.call_count,
                             "Direct channel was not used")

            # assert host plugin route is called
            self.assertEqual(1, patch_put.call_count,
                             "Host plugin was not used")

            # assert update goal state is only called once, non-forced
            self.assertEqual(1,
                             wire_protocol.client.update_goal_state.call_count,
                             "Update goal state unexpected call count")
            self.assertEqual(
                0, len(wire_protocol.client.update_goal_state.call_args[1]),
                "Update goal state unexpected argument count")

            # ensure the correct url is used
            self.assertEqual(sas_url, patch_put.call_args[0][0])

            # ensure host plugin is not set as default
            self.assertFalse(wire.HostPluginProtocol.is_default_channel())

    @patch(
        "azurelinuxagent.common.protocol.hostplugin.HostPluginProtocol.ensure_initialized",
        return_value=True)
    @patch("azurelinuxagent.common.protocol.wire.StatusBlob.upload",
           return_value=True)
    @patch(
        "azurelinuxagent.common.protocol.hostplugin.HostPluginProtocol._put_page_blob_status",
        side_effect=ResourceGoneError("410"))
    @patch(
        "azurelinuxagent.common.protocol.wire.WireClient.update_host_plugin_from_goal_state"
    )
    def test_fallback_channel_410(self, patch_refresh_host_plugin, patch_put,
                                  patch_upload, _):
        """
        A 410 (ResourceGoneError) from the host plugin does not fall back to
        the direct channel. Instead, the host plugin is refreshed from the goal
        state once. The goal state is updated once, with no keyword arguments.
        The host plugin does not become the default channel.
        """
        with self.create_mock_protocol() as wire_protocol:
            wire_protocol.update_goal_state()
            client = wire_protocol.client

            client.upload_status_blob()

            self.assertEqual(0, patch_upload.call_count,
                             "Direct channel was used")
            self.assertEqual(1, patch_put.call_count,
                             "Host plugin was not used")

            gs_update = client.update_goal_state
            self.assertEqual(1, gs_update.call_count,
                             "Update goal state unexpected call count")
            self.assertEqual(0, len(gs_update.call_args[1]),
                             "Update goal state unexpected argument count")
            self.assertEqual(1, patch_refresh_host_plugin.call_count,
                             "Refresh host plugin unexpected call count")

            self.assertEqual(sas_url, patch_put.call_args[0][0])
            self.assertFalse(wire.HostPluginProtocol.is_default_channel())

    @patch(
        "azurelinuxagent.common.protocol.hostplugin.HostPluginProtocol.ensure_initialized",
        return_value=True)
    @patch("azurelinuxagent.common.protocol.wire.StatusBlob.upload",
           return_value=False)
    @patch(
        "azurelinuxagent.common.protocol.hostplugin.HostPluginProtocol._put_page_blob_status",
        side_effect=HttpError("500"))
    def test_fallback_channel_failure(self, patch_put, patch_upload, _):
        """
        When the host plugin returns a 500 and the direct channel also fails,
        upload_status_blob raises ProtocolError.

        Both routes are tried once each. The goal state is updated once,
        without keyword arguments. The host plugin does not become the
        default channel.
        """
        with self.create_mock_protocol() as wire_protocol:
            wire_protocol.update_goal_state()

            # act
            self.assertRaises(wire.ProtocolError,
                              wire_protocol.client.upload_status_blob)

            # assert direct route is called (as the fallback)
            self.assertEqual(1, patch_upload.call_count,
                             "Direct channel was not used")

            # assert host plugin route is called
            self.assertEqual(1, patch_put.call_count,
                             "Host plugin was not used")

            # assert update goal state is only called once, non-forced
            self.assertEqual(1,
                             wire_protocol.client.update_goal_state.call_count,
                             "Update goal state unexpected call count")
            self.assertEqual(
                0, len(wire_protocol.client.update_goal_state.call_args[1]),
                "Update goal state unexpected argument count")

            # ensure the correct url is used
            self.assertEqual(sas_url, patch_put.call_args[0][0])

            # ensure host plugin is not set as default
            self.assertFalse(wire.HostPluginProtocol.is_default_channel())

    @patch("azurelinuxagent.common.event.add_event")
    def test_put_status_error_reporting(self, patch_add_event):
        """
        Validate the telemetry when uploading status fails.

        The host plugin PUT raises HttpError and the direct upload returns
        False, so upload_status_blob raises ProtocolError. Exactly one
        successful 'ReportStatus' event announcing the fallback to the direct
        channel is expected.
        """
        wire.HostPluginProtocol.set_default_channel(False)
        with patch.object(wire.StatusBlob, "upload", return_value=False):
            with self.create_mock_protocol() as wire_protocol:
                wire_protocol_client = wire_protocol.client

                put_error = wire.HttpError("put status http error")
                with patch.object(restutil, "http_put", side_effect=put_error):
                    with patch.object(wire.HostPluginProtocol,
                                      "ensure_initialized",
                                      return_value=True):
                        self.assertRaises(
                            wire.ProtocolError,
                            wire_protocol_client.upload_status_blob)

                        # The agent tries to upload via HostPlugin and that fails due to
                        # http_put having a side effect of "put_error".
                        #
                        # The agent then tries to upload using a direct connection, which
                        # also fails (StatusBlob.upload is patched to return False); hence
                        # the ProtocolError above.
                        self.assertEqual(
                            1,
                            wire_protocol_client.status_blob.upload.call_count)  # pylint: disable=no-member
                        # The agent never touches the default protocol in this code path, so no change.
                        self.assertFalse(
                            wire.HostPluginProtocol.is_default_channel())
                        # The agent logs exactly one telemetry event: a successful
                        # 'ReportStatus' event announcing the fallback to direct.
                        self.assertEqual(1, patch_add_event.call_count)
                        self.assertEqual('ReportStatus',
                                         patch_add_event.call_args[1]['op'])
                        self.assertTrue(
                            'Falling back to direct' in
                            patch_add_event.call_args[1]['message'])
                        self.assertEqual(
                            True, patch_add_event.call_args[1]['is_success'])

    def test_validate_http_request_when_uploading_status(self):
        """
        Reporting VM status as a block blob issues exactly two HTTP requests:
        the status PUT to the HostGAPlugin (with the expected body and goal
        state headers), then a POST to the health service.
        """
        with mock_wire_protocol(DATA_FILE) as protocol:
            client = protocol.client
            goal_state = client._goal_state  # pylint: disable=protected-access
            plugin = client.get_host_plugin()

            status_blob = client.status_blob
            status_blob.data = faux_status
            status_blob.vm_status = restapi.VMStatus(message="Ready",
                                                     status="Ready")

            expected_body = self._hostplugin_data(
                status_blob.get_block_blob_headers(len(faux_status)),
                bytearray(faux_status, encoding='utf-8'))

            with patch.object(restutil, "http_request") as patch_http:
                patch_http.return_value = Mock(status=httpclient.OK)

                with patch.object(plugin, 'get_api_versions') as patch_api:
                    patch_api.return_value = API_VERSION
                    plugin.put_vm_status(status_blob, sas_url, block_blob_type)

                    self.assertEqual(2, patch_http.call_count)
                    host_plugin_call, health_call = patch_http.call_args_list

                    # first request: the status PUT to the host plugin
                    self._validate_hostplugin_args(host_plugin_call,
                                                   goal_state, 'PUT',
                                                   hostplugin_status_url,
                                                   expected_body)

                    # second request: the health signal POST
                    health_positional = health_call[0]
                    self.assertEqual('POST', health_positional[0])
                    self.assertEqual(health_service_url, health_positional[1])

    def test_validate_block_blob(self):
        """Validate the requests sent when uploading a block-blob status via HostGAPlugin."""
        with mock_wire_protocol(DATA_FILE) as protocol:
            test_goal_state = protocol.client._goal_state  # pylint: disable=protected-access

            host_client = wire.HostPluginProtocol(
                wireserver_url, test_goal_state.container_id,
                test_goal_state.role_config_name)
            # A freshly constructed client is not initialized and has no
            # API versions yet, but already owns a health service.
            self.assertFalse(host_client.is_initialized)
            self.assertIsNone(host_client.api_versions)
            self.assertIsNotNone(host_client.health_service)

            status_blob = protocol.client.status_blob
            status_blob.data = faux_status
            status_blob.type = block_blob_type
            status_blob.vm_status = restapi.VMStatus(message="Ready",
                                                     status="Ready")

            # Expected first request: a PUT to the HostGAPlugin status URL
            # built from the block-blob headers and the UTF-8 encoded status.
            exp_method = 'PUT'
            exp_url = hostplugin_status_url
            exp_data = self._hostplugin_data(
                status_blob.get_block_blob_headers(len(faux_status)),
                bytearray(faux_status, encoding='utf-8'))

            with patch.object(restutil, "http_request") as patch_http:
                patch_http.return_value = Mock(status=httpclient.OK)

                with patch.object(wire.HostPluginProtocol,
                                  "get_api_versions") as patch_get:
                    patch_get.return_value = api_versions
                    host_client.put_vm_status(status_blob, sas_url)

                    # Status PUT followed by the health-service POST.
                    self.assertEqual(2, patch_http.call_count)

                    # first call is to host plugin
                    self._validate_hostplugin_args(
                        patch_http.call_args_list[0], test_goal_state,
                        exp_method, exp_url, exp_data)

                    # second call is to health service
                    self.assertEqual('POST',
                                     patch_http.call_args_list[1][0][0])
                    self.assertEqual(health_service_url,
                                     patch_http.call_args_list[1][0][1])

    def test_validate_page_blobs(self):
        """Validate correct set of data is sent for page blobs"""
        with mock_wire_protocol(DATA_FILE) as protocol:
            test_goal_state = protocol.client._goal_state  # pylint: disable=protected-access

            host_client = wire.HostPluginProtocol(
                wireserver_url, test_goal_state.container_id,
                test_goal_state.role_config_name)

            self.assertFalse(host_client.is_initialized)
            self.assertIsNone(host_client.api_versions)

            status_blob = protocol.client.status_blob
            status_blob.data = faux_status
            status_blob.type = page_blob_type
            status_blob.vm_status = restapi.VMStatus(message="Ready",
                                                     status="Ready")

            exp_method = 'PUT'
            exp_url = hostplugin_status_url

            # Expected page payload: the status text padded with spaces
            # (str.ljust) up to the next multiple of 512, copied into a
            # bytearray of that size. Integer division keeps the rounding
            # in ints rather than going through a float.
            page_status = bytearray(status_blob.data, encoding='utf-8')
            page_size = (len(page_status) + 511) // 512 * 512
            page_status = bytearray(status_blob.data.ljust(page_size),
                                    encoding='utf-8')
            page = bytearray(page_size)
            page[0:page_size] = page_status[0:len(page_status)]
            mock_response = MockResponse('', httpclient.OK)

            with patch.object(restutil,
                              "http_request",
                              return_value=mock_response) as patch_http:
                with patch.object(wire.HostPluginProtocol,
                                  "get_api_versions") as patch_get:
                    patch_get.return_value = api_versions
                    host_client.put_vm_status(status_blob, sas_url)

                    # Create-blob PUT, health-service POST, then page PUT.
                    self.assertEqual(3, patch_http.call_count)

                    # first call is to host plugin
                    exp_data = self._hostplugin_data(
                        status_blob.get_page_blob_create_headers(page_size))
                    self._validate_hostplugin_args(
                        patch_http.call_args_list[0], test_goal_state,
                        exp_method, exp_url, exp_data)

                    # second call is to health service
                    self.assertEqual('POST',
                                     patch_http.call_args_list[1][0][0])
                    self.assertEqual(health_service_url,
                                     patch_http.call_args_list[1][0][1])

                    # last call is to host plugin
                    exp_data = self._hostplugin_data(
                        status_blob.get_page_blob_page_headers(0, page_size),
                        page)
                    exp_data['requestUri'] += "?comp=page"
                    self._validate_hostplugin_args(
                        patch_http.call_args_list[2], test_goal_state,
                        exp_method, exp_url, exp_data)

    def test_validate_http_request_for_put_vm_log(self):
        """Validate the URL, body and headers of the HostGAPlugin PUT used to upload VM logs."""
        def http_put_handler(url, *args, **kwargs):
            # Capture the arguments of the host-plugin log upload only; for any
            # other PUT return None (presumably letting the mock protocol
            # fall back to its default handling — confirm in mock_wire_protocol).
            if not self.is_host_plugin_put_logs_request(url):
                return None
            http_put_handler.args, http_put_handler.kwargs = args, kwargs
            return MockResponse(body=b'', status_code=200)

        http_put_handler.args, http_put_handler.kwargs = [], {}

        with mock_wire_protocol(DATA_FILE,
                                http_put_handler=http_put_handler) as protocol:
            test_goal_state = protocol.client.get_goal_state()

            expected_url = hostplugin.URI_FORMAT_PUT_LOG.format(
                wireserver_url, hostplugin.HOST_PLUGIN_PORT)
            expected_headers = {
                'x-ms-version':
                '2015-09-01',
                "x-ms-containerid":
                test_goal_state.container_id,
                "x-ms-vmagentlog-deploymentid":
                test_goal_state.role_config_name.split(".")[0],
                "x-ms-client-name":
                AGENT_NAME,
                "x-ms-client-version":
                AGENT_VERSION
            }

            host_client = wire.HostPluginProtocol(
                wireserver_url, test_goal_state.container_id,
                test_goal_state.role_config_name)

            self.assertFalse(host_client.is_initialized,
                             "Host plugin should not be initialized!")

            # Uploading a log initializes the host plugin as a side effect.
            content = b"test"
            host_client.put_vm_log(content)
            self.assertTrue(host_client.is_initialized,
                            "Host plugin is not initialized!")

            urls = protocol.get_tracked_urls()

            self.assertEqual(expected_url, urls[0], "Unexpected request URL!")
            self.assertEqual(content, http_put_handler.args[0],
                             "Unexpected content for HTTP PUT request!")

            headers = http_put_handler.kwargs['headers']
            for k in expected_headers:
                self.assertIn(k, headers,
                              "Header {0} not found in headers!".format(k))
                self.assertEqual(expected_headers[k], headers[k],
                                 "Request headers don't match!")

            # Special check for correlation id header value, check for pattern, not exact value
            self.assertIn("x-ms-client-correlationid", headers,
                          "Correlation id not found in headers!")
            self.assertTrue(
                UUID_PATTERN.match(headers["x-ms-client-correlationid"]),
                "Correlation id is not in GUID form!")

    def test_put_vm_log_should_raise_an_exception_when_request_fails(self):
        """put_vm_log() surfaces a 410 (Gone) response from HostGAPlugin as an HttpError."""
        def http_put_handler(url, *args, **kwargs):
            # Answer the host-plugin log upload with 410 Gone; other PUTs get None.
            if not self.is_host_plugin_put_logs_request(url):
                return None
            http_put_handler.args, http_put_handler.kwargs = args, kwargs
            return MockResponse(body=ustr('Gone'), status_code=410)

        http_put_handler.args, http_put_handler.kwargs = [], {}

        with mock_wire_protocol(DATA_FILE,
                                http_put_handler=http_put_handler) as protocol:
            goal_state = protocol.client.get_goal_state()

            host_client = wire.HostPluginProtocol(
                wireserver_url, goal_state.container_id,
                goal_state.role_config_name)
            self.assertFalse(host_client.is_initialized,
                             "Host plugin should not be initialized!")

            log_content = b"test"
            with self.assertRaises(HttpError) as raised:
                host_client.put_vm_log(log_content)

            error = raised.exception
            self.assertIsInstance(error, HttpError)
            # The error text must mention both the status code and its reason.
            error_text = ustr(error)
            self.assertIn("410", error_text)
            self.assertIn("Gone", error_text)

    def test_validate_get_extension_artifacts(self):
        """Validate the URL and headers HostGAPlugin uses to fetch an extension artifact."""
        with mock_wire_protocol(DATA_FILE) as protocol:
            test_goal_state = protocol.client._goal_state  # pylint: disable=protected-access

            expected_url = hostplugin.URI_FORMAT_GET_EXTENSION_ARTIFACT.format(
                wireserver_url, hostplugin.HOST_PLUGIN_PORT)
            expected_headers = {
                'x-ms-version': '2015-09-01',
                "x-ms-containerid": test_goal_state.container_id,
                "x-ms-host-config-name": test_goal_state.role_config_name,
                "x-ms-artifact-location": sas_url
            }

            host_client = wire.HostPluginProtocol(
                wireserver_url, test_goal_state.container_id,
                test_goal_state.role_config_name)
            self.assertFalse(host_client.is_initialized)
            self.assertIsNone(host_client.api_versions)
            self.assertIsNotNone(host_client.health_service)

            # The patched method is only stubbed, never inspected, so no
            # binding for the mock is needed.
            with patch.object(wire.HostPluginProtocol,
                              "get_api_versions",
                              return_value=api_versions):
                actual_url, actual_headers = host_client.get_artifact_request(
                    sas_url)
                # Building the request initializes the client.
                self.assertTrue(host_client.is_initialized)
                self.assertIsNotNone(host_client.api_versions)
                self.assertEqual(expected_url, actual_url)
                for k in expected_headers:
                    self.assertIn(k, actual_headers)
                    self.assertEqual(expected_headers[k], actual_headers[k])

    def test_health(self):
        """get_health() is truthy for HTTP 200, falsy for 500, and lets IOError propagate."""
        host_plugin = self._init_host()

        with patch("azurelinuxagent.common.utils.restutil.http_get"
                   ) as patch_http_get:
            patch_http_get.return_value = MockResponse('', 200)
            result = host_plugin.get_health()
            self.assertEqual(1, patch_http_get.call_count)
            self.assertTrue(result)

            patch_http_get.return_value = MockResponse('', 500)
            result = host_plugin.get_health()
            self.assertFalse(result)

            # assertRaises replaces the manual try / self.fail / except IOError
            # pattern and reports a clear failure if nothing is raised.
            patch_http_get.side_effect = IOError('client IO error')
            with self.assertRaises(IOError):
                host_plugin.get_health()

    def test_ensure_health_service_called(self):
        """get_api_versions() issues one GET and reports plugin versions to the health service once."""
        host_plugin = self._init_host()
        ok_response = TestWireMockResponse(status_code=200, body=b'')
        report_versions_target = (
            "azurelinuxagent.common.protocol.healthservice.HealthService.report_host_plugin_versions"
        )

        with patch("azurelinuxagent.common.utils.restutil.http_get",
                   return_value=ok_response) as mock_get:
            with patch(report_versions_target) as mock_report:
                host_plugin.get_api_versions()
                self.assertEqual(1, mock_get.call_count)
                self.assertEqual(1, mock_report.call_count)

    def test_put_status_healthy_signal(self):
        """A successful status upload posts healthy /versions and /status observations."""
        host_plugin = self._init_host()

        def check_observation(post_call, observation_name, healthy):
            # post_call is a mock call record: positional arg 0 is the URL and
            # positional arg 1 the JSON document posted to the health service.
            positional = post_call[0]
            self.assertEqual(health_service_url, positional[0])
            observations = json.loads(positional[1])['Observations']
            self.assertEqual(1, len(observations))
            if healthy:
                self.assertTrue(observations[0]['IsHealthy'])
            else:
                self.assertFalse(observations[0]['IsHealthy'])
            self.assertEqual(observation_name,
                             observations[0]['ObservationName'])

        with patch("azurelinuxagent.common.utils.restutil.http_get"
                   ) as mock_get:
            with patch("azurelinuxagent.common.utils.restutil.http_post"
                       ) as mock_post:
                with patch("azurelinuxagent.common.utils.restutil.http_put"
                           ) as mock_put:
                    status_blob = self._init_status_blob()
                    mock_get.return_value = MockResponse(api_versions, 200)
                    mock_put.return_value = MockResponse(None, 201)

                    host_plugin.put_vm_status(status_blob=status_blob,
                                              sas_url=sas_url)

                    # API versions fetched exactly once from HostGAPlugin.
                    self.assertEqual(1, mock_get.call_count)
                    self.assertEqual(hostplugin_versions_url,
                                     mock_get.call_args[0][0])

                    # The status went up in two PUTs to the status endpoint.
                    self.assertEqual(2, mock_put.call_count)
                    for put_call in mock_put.call_args_list[:2]:
                        self.assertEqual(hostplugin_status_url, put_call[0][0])

                    # One health signal per endpoint, both healthy.
                    self.assertEqual(2, mock_post.call_count)
                    versions_call, status_call = mock_post.call_args_list
                    check_observation(versions_call,
                                      'GuestAgentPluginVersions', True)
                    check_observation(status_call,
                                      'GuestAgentPluginStatus', True)

    def test_put_status_unhealthy_signal_transient(self):
        """A single failed status upload raises HttpError but still reports /status as healthy."""
        host_plugin = self._init_host()

        def check_observation(post_call, observation_name, healthy):
            # post_call is a mock call record: positional arg 0 is the URL and
            # positional arg 1 the JSON document posted to the health service.
            positional = post_call[0]
            self.assertEqual(health_service_url, positional[0])
            observations = json.loads(positional[1])['Observations']
            self.assertEqual(1, len(observations))
            if healthy:
                self.assertTrue(observations[0]['IsHealthy'])
            else:
                self.assertFalse(observations[0]['IsHealthy'])
            self.assertEqual(observation_name,
                             observations[0]['ObservationName'])

        with patch("azurelinuxagent.common.utils.restutil.http_get"
                   ) as mock_get:
            with patch("azurelinuxagent.common.utils.restutil.http_post"
                       ) as mock_post:
                with patch("azurelinuxagent.common.utils.restutil.http_put"
                           ) as mock_put:
                    status_blob = self._init_status_blob()
                    mock_get.return_value = MockResponse(api_versions, 200)
                    mock_put.return_value = MockResponse(None, 500)

                    with self.assertRaises(HttpError):
                        host_plugin.put_vm_status(status_blob=status_blob,
                                                  sas_url=sas_url)

                    self.assertEqual(1, mock_get.call_count)
                    self.assertEqual(hostplugin_versions_url,
                                     mock_get.call_args[0][0])

                    # The upload stops after the first failing PUT.
                    self.assertEqual(1, mock_put.call_count)
                    self.assertEqual(hostplugin_status_url,
                                     mock_put.call_args[0][0])

                    # status_error_state is not forced to trigger here, so
                    # the /status observation is still reported as healthy.
                    self.assertEqual(2, mock_post.call_count)
                    versions_call, status_call = mock_post.call_args_list
                    check_observation(versions_call,
                                      'GuestAgentPluginVersions', True)
                    check_observation(status_call,
                                      'GuestAgentPluginStatus', True)

    def test_put_status_unhealthy_signal_permanent(self):
        """With the status error state triggered, a failed upload reports /status as unhealthy."""
        host_plugin = self._init_host()

        def check_observation(post_call, observation_name, healthy):
            # post_call is a mock call record: positional arg 0 is the URL and
            # positional arg 1 the JSON document posted to the health service.
            positional = post_call[0]
            self.assertEqual(health_service_url, positional[0])
            observations = json.loads(positional[1])['Observations']
            self.assertEqual(1, len(observations))
            if healthy:
                self.assertTrue(observations[0]['IsHealthy'])
            else:
                self.assertFalse(observations[0]['IsHealthy'])
            self.assertEqual(observation_name,
                             observations[0]['ObservationName'])

        with patch("azurelinuxagent.common.utils.restutil.http_get"
                   ) as mock_get:
            with patch("azurelinuxagent.common.utils.restutil.http_post"
                       ) as mock_post:
                with patch("azurelinuxagent.common.utils.restutil.http_put"
                           ) as mock_put:
                    status_blob = self._init_status_blob()
                    mock_get.return_value = MockResponse(api_versions, 200)
                    mock_put.return_value = MockResponse(None, 500)

                    # Pretend enough errors have accumulated to trigger the
                    # status error state.
                    host_plugin.status_error_state.is_triggered = Mock(
                        return_value=True)

                    with self.assertRaises(HttpError):
                        host_plugin.put_vm_status(status_blob=status_blob,
                                                  sas_url=sas_url)

                    self.assertEqual(1, mock_get.call_count)
                    self.assertEqual(hostplugin_versions_url,
                                     mock_get.call_args[0][0])

                    self.assertEqual(1, mock_put.call_count)
                    self.assertEqual(hostplugin_status_url,
                                     mock_put.call_args[0][0])

                    # /versions stays healthy; /status is now unhealthy.
                    self.assertEqual(2, mock_post.call_count)
                    versions_call, status_call = mock_post.call_args_list
                    check_observation(versions_call,
                                      'GuestAgentPluginVersions', True)
                    check_observation(status_call,
                                      'GuestAgentPluginStatus', False)

    @patch(
        "azurelinuxagent.common.protocol.hostplugin.HostPluginProtocol.should_report",
        return_value=True)
    @patch(
        "azurelinuxagent.common.protocol.healthservice.HealthService.report_host_plugin_extension_artifact"
    )
    def test_report_fetch_health(self, patch_report_artifact,
                                 patch_should_report):
        """Only HostGAPlugin extensionArtifact fetches are reported to the health service."""
        host_plugin = self._init_host()

        # Empty, non-HostGAPlugin and non-artifact URIs are ignored entirely.
        ignored_uris = [
            '',
            'http://169.254.169.254/extensionArtifact',
            'http://168.63.129.16:32526/status',
        ]
        for ignored_uri in ignored_uris:
            host_plugin.report_fetch_health(uri=ignored_uri, is_healthy=True)
            self.assertEqual(0, patch_should_report.call_count)

        # A HostGAPlugin artifact fetch records a timestamp and is reported.
        self.assertEqual(None, host_plugin.fetch_last_timestamp)
        artifact_uri = 'http://168.63.129.16:32526/extensionArtifact'
        host_plugin.report_fetch_health(uri=artifact_uri, is_healthy=True)
        self.assertNotEqual(None, host_plugin.fetch_last_timestamp)
        self.assertEqual(1, patch_should_report.call_count)
        self.assertEqual(1, patch_report_artifact.call_count)

    @patch(
        "azurelinuxagent.common.protocol.hostplugin.HostPluginProtocol.should_report",
        return_value=True)
    @patch(
        "azurelinuxagent.common.protocol.healthservice.HealthService.report_host_plugin_status"
    )
    def test_report_status_health(self, patch_report_status,
                                  patch_should_report):
        """report_status_health() stamps status_last_timestamp and reports once."""
        plugin = self._init_host()

        self.assertEqual(None, plugin.status_last_timestamp)
        plugin.report_status_health(is_healthy=True)
        self.assertNotEqual(None, plugin.status_last_timestamp)

        for reporter in (patch_should_report, patch_report_status):
            self.assertEqual(1, reporter.call_count)

    def test_should_report(self):
        """should_report() rate-limits healthy reports to one per period, reports an
        unhealthy measurement (incrementing the error count), and resets the error
        count on the next healthy measurement."""
        host_plugin = self._init_host()
        error_state = ErrorState(min_timedelta=datetime.timedelta(minutes=5))
        period = datetime.timedelta(minutes=1)

        def seconds_ago(seconds):
            return datetime.datetime.utcnow() - datetime.timedelta(
                seconds=seconds)

        def decide(is_healthy, last_timestamp):
            return host_plugin.should_report(is_healthy, error_state,
                                             last_timestamp, period)

        # 1st measurement at 0s (never reported before): report.
        self.assertEqual(True, decide(True, None))

        # 2nd measurement at 30s, inside the 1-minute period: don't report.
        self.assertEqual(False, decide(True, seconds_ago(30)))

        # 3rd measurement at 60s, period elapsed: report.
        last_report = seconds_ago(60)
        self.assertEqual(True, decide(True, last_report))

        # 4th measurement unhealthy: report and bump the error counter.
        self.assertEqual(0, error_state.count)
        reported = decide(False, last_report)
        self.assertEqual(1, error_state.count)
        self.assertEqual(True, reported)

        # 5th measurement healthy at 30s: don't report, counter resets.
        recent_report = seconds_ago(30)
        self.assertEqual(1, error_state.count)
        reported = decide(True, recent_report)
        self.assertEqual(0, error_state.count)
        self.assertEqual(False, reported)
 def http_post_handler(url, _, **__):
     # Intercept telemetry POSTs only; everything else gets None.
     # NOTE(review): the HttpError is returned, not raised — presumably the
     # mock protocol raises exception instances returned by a handler; confirm.
     if not self.is_telemetry_request(url):
         return None
     return HttpError(test_str)
Example #20
0
def http_request(method,
                 url,
                 data,
                 headers=None,
                 use_proxy=False,
                 max_retry=DEFAULT_RETRIES,
                 retry_codes=RETRY_CODES,
                 retry_delay=SHORT_DELAY_IN_SECONDS):
    """
    Send an HTTP request, retrying on retryable status codes and I/O errors.

    :param method: HTTP verb; passed to _http_request and used in messages.
    :param url: target URL, split by _parse_url into host, port, a "secure"
        flag and the relative URI.
    :param data: request body passed through to _http_request.
    :param headers: optional request headers.
    :param use_proxy: when True, look up a proxy via _get_http_proxy().
    :param max_retry: maximum number of attempts (the loop runs while
        attempt < max_retry).
    :param retry_codes: failed status codes to retry, checked via
        _is_retry_status.
    :param retry_delay: seconds to sleep before each retry; throttling
        statuses (_is_throttle_status) use LONG_DELAY_IN_SECONDS instead.
    :return: the response of the first attempt that is not retried. A failed
        status that is not retryable (and not a "gone" code) is returned to
        the caller as-is.
    :raises ResourceGoneError: when the status is in RESOURCE_GONE_CODES.
    :raises HttpError: when HTTPS or HTTPS tunnelling is needed but
        unavailable and conf.get_allow_http() is False; after a non-retryable
        HTTPException; or when all attempts are exhausted. In the last two
        cases the message is the last failure description.
    """

    # Module-level flag so each SSL-fallback warning is logged at most once.
    global SECURE_WARNING_EMITTED

    host, port, secure, rel_uri = _parse_url(url)

    # Use the HTTP(S) proxy
    proxy_host, proxy_port = (None, None)
    if use_proxy:
        proxy_host, proxy_port = _get_http_proxy(secure=secure)

        if proxy_host or proxy_port:
            logger.verbose("HTTP proxy: [{0}:{1}]", proxy_host, proxy_port)

    # If httplib module is not built with ssl support,
    # fallback to HTTP if allowed
    if secure and not hasattr(httpclient, "HTTPSConnection"):
        if not conf.get_allow_http():
            raise HttpError("HTTPS is unavailable and required")

        secure = False
        if not SECURE_WARNING_EMITTED:
            logger.warn("Python does not include SSL support")
            SECURE_WARNING_EMITTED = True

    # If httplib module doesn't support HTTPS tunnelling,
    # fallback to HTTP if allowed
    if secure and \
        proxy_host is not None and \
        proxy_port is not None \
        and not hasattr(httpclient.HTTPSConnection, "set_tunnel"):

        if not conf.get_allow_http():
            raise HttpError("HTTPS tunnelling is unavailable and required")

        secure = False
        if not SECURE_WARNING_EMITTED:
            logger.warn("Python does not support HTTPS tunnelling")
            SECURE_WARNING_EMITTED = True

    # msg describes the most recent failure and becomes the HttpError message
    # if every attempt fails. NOTE(review): if max_retry <= 0 no request is
    # made and HttpError('') is raised.
    msg = ''
    attempt = 0
    delay = retry_delay

    while attempt < max_retry:
        if attempt > 0:
            logger.info("[HTTP Retry] Attempt {0} of {1}: {2}", attempt + 1,
                        max_retry, msg)
            # `delay` was set by the previous iteration (LONG_DELAY_IN_SECONDS
            # after a throttling status); it is reset to the default below.
            time.sleep(delay)

        attempt += 1
        delay = retry_delay

        try:
            resp = _http_request(method,
                                 host,
                                 rel_uri,
                                 port=port,
                                 data=data,
                                 secure=secure,
                                 headers=headers,
                                 proxy_host=proxy_host,
                                 proxy_port=proxy_port)
            logger.verbose("[HTTP Response] Status Code {0}", resp.status)

            if request_failed(resp):
                if _is_retry_status(resp.status, retry_codes=retry_codes):
                    msg = '[HTTP Retry] HTTP {0} Status Code {1}'.format(
                        method, resp.status)
                    if _is_throttle_status(resp.status):
                        delay = LONG_DELAY_IN_SECONDS
                        logger.info("[HTTP Delay] Delay {0} seconds for " \
                                    "Status Code {1}".format(
                                        delay, resp.status))
                    # Retry; on the last attempt the loop ends and
                    # HttpError(msg) is raised instead of returning resp.
                    continue

            # Checked for every non-retried response, failed or not.
            if resp.status in RESOURCE_GONE_CODES:
                raise ResourceGoneError()

            return resp

        except httpclient.HTTPException as e:
            msg = '[HTTP Failed] HTTP {0} HttpException {1}'.format(method, e)
            # Only exceptions accepted by _is_retry_exception are retried.
            if _is_retry_exception(e):
                continue
            break

        except IOError as e:
            # I/O errors are always retried while attempts remain.
            msg = '[HTTP Failed] HTTP {0} IOError {1}'.format(method, e)
            continue

    raise HttpError(msg)
Example #21
0
def http_request(method,
                 url,
                 data,
                 headers=None,
                 max_retry=3,
                 chk_proxy=False):
    """
    Send an HTTP request to the server, retrying on failure.

    :param method: HTTP verb; passed to _http_request and used in messages.
    :param url: target URL, split by _parse_url into host, port, a "secure"
        flag and the relative URI.
    :param data: request body passed through to _http_request.
    :param headers: optional request headers.
    :param max_retry: maximum number of attempts.
    :param chk_proxy: when True, use the proxy returned by get_http_proxy().
    :return: the response of the first attempt that does not raise an
        HTTPException or IOError (no status-code check is done here).
    :raises HttpError: when every permitted attempt fails.

    An HTTPException is retried after 5 seconds. An IOError with errno 101
    (network unreachable) is retried at most once more, after
    RETRY_WAITING_INTERVAL. Any other IOError is not retried.
    """
    global secure_warning
    host, port, secure, rel_uri = _parse_url(url)

    # Check proxy
    proxy_host, proxy_port = (None, None)
    if chk_proxy:
        proxy_host, proxy_port = get_http_proxy()

    # If httplib module is not built with ssl support. Fallback to http
    if secure and not hasattr(httpclient, "HTTPSConnection"):
        secure = False
        if secure_warning:
            logger.warn("httplib is not built with ssl support")
            secure_warning = False

    # If httplib module doesn't support https tunnelling. Fallback to http
    if secure and proxy_host is not None and proxy_port is not None \
            and not hasattr(httpclient.HTTPSConnection, "set_tunnel"):
        secure = False
        if secure_warning:
            logger.warn("httplib does not support https tunnelling "
                        "(new in python 2.7)")
            secure_warning = False

    if proxy_host or proxy_port:
        logger.verbose("HTTP proxy: [{0}:{1}]", proxy_host, proxy_port)

    retry_msg = ''
    log_msg = "HTTP {0}".format(method)
    # Reassigning max_retry inside `for retry in range(max_retry)` does not
    # change how many times the loop runs, so the IOError handling below
    # could never shorten it. Track the attempt budget explicitly instead.
    max_attempts = max_retry
    attempt = 0
    while attempt < max_attempts:
        retry_interval = RETRY_WAITING_INTERVAL
        try:
            resp = _http_request(method,
                                 host,
                                 rel_uri,
                                 port=port,
                                 data=data,
                                 secure=secure,
                                 headers=headers,
                                 proxy_host=proxy_host,
                                 proxy_port=proxy_port)
            logger.verbose("HTTP response status: [{0}]", resp.status)
            return resp
        except httpclient.HTTPException as e:
            retry_msg = 'HTTP exception: {0} {1}'.format(log_msg, e)
            retry_interval = 5
        except IOError as e:
            retry_msg = 'IO error: {0} {1}'.format(log_msg, e)
            # error 101: network unreachable; when the adapter resets we may
            # see this transient error for a short time, retry once.
            if e.errno == 101:
                retry_interval = RETRY_WAITING_INTERVAL
                max_attempts = min(max_attempts, attempt + 2)
            else:
                # Any other I/O error is not retried.
                retry_interval = 0
                max_attempts = attempt + 1

        attempt += 1
        # Sleep only if another attempt follows, so the final failure is
        # raised without an extra delay. NOTE: the format string skips {2},
        # so retry_interval is passed but not rendered.
        if attempt < max_attempts:
            logger.info("Retry [{0}/{1} - {3}]", attempt, max_attempts,
                        retry_interval, retry_msg)
            time.sleep(retry_interval)

    raise HttpError("{0} failed".format(log_msg))