Esempio n. 1
0
    def _download(self):
        """
        Download the agent package, trying each URI of self.pkg in turn.

        Unless the host plugin is already the default channel, a URI is first
        fetched directly.  When that is skipped or fails, the same artifact
        is requested through the host plugin (if one is configured and its
        ensure_initialized() call succeeds).  A successful host plugin
        download makes the host plugin the default channel.  The loop stops
        at the first successful download.

        Raises:
            UpdateError: if no file exists at get_agent_pkg_path() once the
                loop is done; a failed Download event is recorded first.
        """
        for candidate in self.pkg.uris:
            fetched_directly = (not HostPluginProtocol.is_default_channel()
                                and self._fetch(candidate.uri))
            if fetched_directly:
                break

            host_usable = (self.host is not None
                           and self.host.ensure_initialized())
            if not host_usable:
                logger.error("No download channels available")
                continue

            if HostPluginProtocol.is_default_channel():
                logger.verbose("Using host plugin as default channel")
            else:
                logger.warn("Download unsuccessful, falling back to host plugin")

            host_uri, host_headers = self.host.get_artifact_request(
                candidate.uri, self.host.manifest_uri)
            if not self._fetch(host_uri, headers=host_headers):
                logger.warn("Host plugin download unsuccessful")
                continue

            # First success through the host plugin: prefer it from now on.
            if not HostPluginProtocol.is_default_channel():
                logger.verbose("Setting host plugin as default channel")
                HostPluginProtocol.set_default_channel(True)
            break

        if os.path.isfile(self.get_agent_pkg_path()):
            return

        msg = u"Unable to download Agent {0} from any URI".format(self.name)
        add_event(AGENT_NAME,
                  op=WALAEventOperation.Download,
                  version=CURRENT_VERSION,
                  is_success=False,
                  message=msg)
        raise UpdateError(msg)
Esempio n. 2
0
    def _download(self):
        """
        Fetch the agent package from the first URI in self.pkg.uris that works.

        Each URI is fetched directly unless the host plugin is already the
        default channel.  If the direct fetch is skipped or fails, and a host
        plugin exists and initializes, the URI is fetched through the host
        plugin instead; success there promotes the host plugin to the default
        channel.

        Raises:
            UpdateError: when, after trying the URIs, nothing exists at
                get_agent_pkg_path().  A failed Download event is recorded
                via add_event() before raising.
        """
        for entry in self.pkg.uris:
            if not HostPluginProtocol.is_default_channel():
                if self._fetch(entry.uri):
                    break

            if self.host is None or not self.host.ensure_initialized():
                logger.error("No download channels available")
                continue

            falling_back = not HostPluginProtocol.is_default_channel()
            if falling_back:
                logger.warn(
                    "Download unsuccessful, falling back to host plugin")
            else:
                logger.verbose("Using host plugin as default channel")

            artifact_uri, artifact_headers = self.host.get_artifact_request(
                entry.uri, self.host.manifest_uri)
            if self._fetch(artifact_uri, headers=artifact_headers):
                if not HostPluginProtocol.is_default_channel():
                    logger.verbose(
                        "Setting host plugin as default channel")
                    HostPluginProtocol.set_default_channel(True)
                break
            logger.warn("Host plugin download unsuccessful")

        if os.path.isfile(self.get_agent_pkg_path()):
            return

        msg = u"Unable to download Agent {0} from any URI".format(
            self.name)
        add_event(AGENT_NAME,
                  op=WALAEventOperation.Download,
                  version=CURRENT_VERSION,
                  is_success=False,
                  message=msg)
        raise UpdateError(msg)
Esempio n. 3
0
 def __init__(self, endpoint):
     """
     Create a client for the wire server at *endpoint*.

     Logs the endpoint, clears the document/state attributes, zeroes the
     request bookkeeping, and builds a StatusBlob bound to this client and
     a HostPluginProtocol for the same endpoint.
     """
     logger.info("Wire server endpoint:{0}", endpoint)
     self.endpoint = endpoint
     # Nothing has been retrieved yet.
     self.goal_state = self.updated = self.hosting_env = None
     self.shared_conf = self.certs = self.ext_conf = None
     # Request bookkeeping starts from zero.
     self.last_request = self.req_count = 0
     self.status_blob = StatusBlob(self)
     self.host_plugin = HostPluginProtocol(self.endpoint)
Esempio n. 4
0
 def __init__(self, endpoint):
     """
     Initialise the wire-server client for *endpoint*.

     All document/state attributes start as None and both request counters
     start at 0; the status blob and host plugin helpers are created last.
     """
     logger.info("Wire server endpoint:{0}", endpoint)
     self.endpoint = endpoint
     for attr_name in ("goal_state", "updated", "hosting_env",
                       "shared_conf", "certs", "ext_conf"):
         setattr(self, attr_name, None)
     for counter_name in ("last_request", "req_count"):
         setattr(self, counter_name, 0)
     self.status_blob = StatusBlob(self)
     self.host_plugin = HostPluginProtocol(self.endpoint)
Esempio n. 5
0
 def _fetch(self, uri, headers=None):
     """
     Attempt to download the agent package from *uri*.

     A 200 response body is written, in binary mode, to
     self.get_agent_pkg_path().  Any other status, or a restutil.HttpError
     raised while fetching, is logged at verbose level and counts as a
     failed attempt.

     :param uri: URL passed to restutil.http_get (with chk_proxy=True).
     :param headers: optional request headers, passed through unchanged.
     :return: True once a 200 response body has been read, else False.
     """
     payload = None
     try:
         resp = restutil.http_get(uri, chk_proxy=True, headers=headers)
         if resp.status != restutil.httpclient.OK:
             error_text = HostPluginProtocol.read_response_error(resp)
             logger.verbose("Fetch was unsuccessful [{0}]", error_text)
         else:
             payload = resp.read()
             target = self.get_agent_pkg_path()
             fileutil.write_file(target, bytearray(payload), asbin=True)
             logger.verbose(u"Agent {0} downloaded from {1}", self.name, uri)
     except restutil.HttpError as err:
         logger.verbose(u"Agent {0} download from {1} failed [{2}]",
                        self.name, uri, err)
     return payload is not None
Esempio n. 6
0
 def _fetch(self, uri, headers=None):
     """
     Download the agent package from *uri* into get_agent_pkg_path().

     Only an HTTP 200 response is treated as a download; its body is saved
     as binary.  Non-200 responses and restutil.HttpError are logged at
     verbose level.

     :return: True if a 200 body was obtained, False otherwise.
     """
     succeeded = False
     try:
         response = restutil.http_get(uri, chk_proxy=True, headers=headers)
         if response.status == restutil.httpclient.OK:
             body = response.read()
             succeeded = body is not None
             fileutil.write_file(self.get_agent_pkg_path(),
                                 bytearray(body),
                                 asbin=True)
             logger.verbose(u"Agent {0} downloaded from {1}",
                            self.name, uri)
         else:
             logger.verbose(
                 "Fetch was unsuccessful [{0}]",
                 HostPluginProtocol.read_response_error(response))
     except restutil.HttpError as exc:
         logger.verbose(u"Agent {0} download from {1} failed [{2}]",
                        self.name, uri, exc)
     return succeeded
Esempio n. 7
0
    def _download(self):
        """
        Download the agent package, trying the package URIs in random order.

        For each URI a direct fetch is tried first, unless the host plugin is
        already the default channel.  When that is skipped or fails, and a
        host plugin is configured and its ensure_initialized() succeeds, the
        artifact is fetched through the host plugin (with use_proxy=False);
        a successful host plugin download makes it the default channel.  The
        loop stops at the first successful download.

        Raises:
            ResourceGoneError: raised while fetching through the host plugin;
                the host plugin is set as the default channel before it is
                re-raised.
            UpdateError: if no file exists at get_agent_pkg_path() once the
                loop is done; a failed Download event is recorded first.
        """
        # random.shuffle() works in place; shuffle a copy so that downloading
        # does not reorder self.pkg.uris itself as a side effect.
        uris_shuffled = list(self.pkg.uris)
        random.shuffle(uris_shuffled)
        for pkg_uri in uris_shuffled:
            if not HostPluginProtocol.is_default_channel() and self._fetch(
                    pkg_uri.uri):
                break

            elif self.host is not None and self.host.ensure_initialized():
                if not HostPluginProtocol.is_default_channel():
                    logger.warn("Download failed, switching to host plugin")
                else:
                    logger.verbose("Using host plugin as default channel")

                # Distinct names: do not rebind the loop variable to the
                # host plugin's request URL.
                host_uri, headers = self.host.get_artifact_request(
                    pkg_uri.uri, self.host.manifest_uri)
                try:
                    if self._fetch(host_uri, headers=headers, use_proxy=False):
                        if not HostPluginProtocol.is_default_channel():
                            logger.verbose(
                                "Setting host plugin as default channel")
                            HostPluginProtocol.set_default_channel(True)
                        break
                    else:
                        logger.warn("Host plugin download failed")

                # If the HostPlugin rejects the request, let the error
                # propagate, but keep using the HostPlugin from now on.
                except ResourceGoneError:
                    HostPluginProtocol.set_default_channel(True)
                    raise

            else:
                logger.error("No download channels available")

        if not os.path.isfile(self.get_agent_pkg_path()):
            msg = u"Unable to download Agent {0} from any URI".format(
                self.name)
            add_event(AGENT_NAME,
                      op=WALAEventOperation.Download,
                      version=CURRENT_VERSION,
                      is_success=False,
                      message=msg)
            raise UpdateError(msg)
Esempio n. 8
0
    def _download(self):
        """
        Download the agent package, trying the package URIs in random order.

        For each URI a direct fetch is tried first, unless the host plugin is
        already the default channel.  When that is skipped or fails, and a
        host plugin is configured and its ensure_initialized() succeeds, the
        artifact is fetched through the host plugin (with use_proxy=False);
        a successful host plugin download makes it the default channel.  The
        loop stops at the first successful download.

        Raises:
            ResourceGoneError: raised while fetching through the host plugin;
                the host plugin is set as the default channel before it is
                re-raised.
            UpdateError: if no file exists at get_agent_pkg_path() once the
                loop is done; a failed Download event is recorded first.
        """
        # random.shuffle() works in place; shuffle a copy so that downloading
        # does not reorder self.pkg.uris itself as a side effect.
        uris_shuffled = list(self.pkg.uris)
        random.shuffle(uris_shuffled)
        for pkg_uri in uris_shuffled:
            if not HostPluginProtocol.is_default_channel() and self._fetch(pkg_uri.uri):
                break

            elif self.host is not None and self.host.ensure_initialized():
                if not HostPluginProtocol.is_default_channel():
                    logger.warn("Download failed, switching to host plugin")
                else:
                    logger.verbose("Using host plugin as default channel")

                # Distinct names: do not rebind the loop variable to the
                # host plugin's request URL.
                host_uri, headers = self.host.get_artifact_request(pkg_uri.uri, self.host.manifest_uri)
                try:
                    if self._fetch(host_uri, headers=headers, use_proxy=False):
                        if not HostPluginProtocol.is_default_channel():
                            logger.verbose("Setting host plugin as default channel")
                            HostPluginProtocol.set_default_channel(True)
                        break
                    else:
                        logger.warn("Host plugin download failed")

                # If the HostPlugin rejects the request, let the error
                # propagate, but keep using the HostPlugin from now on.
                except ResourceGoneError:
                    HostPluginProtocol.set_default_channel(True)
                    raise

            else:
                logger.error("No download channels available")

        if not os.path.isfile(self.get_agent_pkg_path()):
            msg = u"Unable to download Agent {0} from any URI".format(self.name)
            add_event(
                AGENT_NAME,
                op=WALAEventOperation.Download,
                version=CURRENT_VERSION,
                is_success=False,
                message=msg)
            raise UpdateError(msg)
Esempio n. 9
0
class WireClient(object):
    """
    Client for the wire server reachable at ``endpoint``.

    Downloads goal-state documents (goal state, hosting environment, shared
    config, certificates, extensions config, manifests), caches them as
    files under ``conf.get_lib_dir()`` and keeps parsed copies on the
    instance.  It also uploads status and posts role properties, health
    reports and telemetry events back to the wire server.
    """

    def __init__(self, endpoint):
        logger.info("Wire server endpoint:{0}", endpoint)
        self.endpoint = endpoint
        # Parsed documents: None until fetched by an update_* method or
        # loaded from the on-disk cache by the matching get_* method.
        self.goal_state = None
        # NOTE(review): `updated` is not read anywhere in this class.
        self.updated = None
        self.hosting_env = None
        self.shared_conf = None
        self.certs = None
        self.ext_conf = None
        # Bookkeeping for prevent_throttling().
        self.last_request = 0
        self.req_count = 0
        self.status_blob = StatusBlob(self)
        self.host_plugin = HostPluginProtocol(self.endpoint)

    def prevent_throttling(self):
        """
        Try to avoid throttling of wire server.

        Sleeps SHORT_WAITING_INTERVAL when the previous call was less than
        one second ago, and again on every third call (the counter is then
        reset).
        """
        now = time.time()
        if now - self.last_request < 1:
            logger.verbose("Last request issued less than 1 second ago")
            logger.verbose("Sleep {0} second to avoid throttling.",
                           SHORT_WAITING_INTERVAL)
            time.sleep(SHORT_WAITING_INTERVAL)
        self.last_request = now

        self.req_count += 1
        if self.req_count % 3 == 0:
            logger.verbose("Sleep {0} second to avoid throttling.",
                           SHORT_WAITING_INTERVAL)
            time.sleep(SHORT_WAITING_INTERVAL)
            self.req_count = 0

    def call_wireserver(self, http_req, *args, **kwargs):
        """
        Call wire server. Handle throttling(403) and Resource Gone(410)

        ``http_req(*args, **kwargs)`` is attempted up to three times.  A 403
        response sleeps LONG_WAITING_INTERVAL and retries; a 410 response
        raises immediately; any other response is returned unchanged.
        HttpError raised by ``http_req`` propagates to the caller.

        Raises:
            WireProtocolResourceGone: on 410; the message is the first
                positional argument (the URI, for the callers in this class).
            ProtocolError: if every attempt returned 403.
        """
        self.prevent_throttling()
        for _ in range(3):
            resp = http_req(*args, **kwargs)
            if resp.status == httpclient.FORBIDDEN:
                logger.warn("Sending too much request to wire server")
                logger.info("Sleep {0} second to avoid throttling.",
                            LONG_WAITING_INTERVAL)
                time.sleep(LONG_WAITING_INTERVAL)
            elif resp.status == httpclient.GONE:
                msg = args[0] if len(args) > 0 else ""
                raise WireProtocolResourceGone(msg)
            else:
                return resp
        raise ProtocolError(
            "Calling wire server failed: {0}".format(resp.status))

    def decode_config(self, data):
        """Strip a BOM from *data* and decode it as UTF-8; None passes through."""
        if data is None:
            return None
        data = remove_bom(data)
        xml_text = ustr(data, encoding='utf-8')
        return xml_text

    def fetch_config(self, uri, headers):
        """
        GET *uri* from the wire server and return the decoded body.

        Raises:
            ProtocolError: on HttpError or any non-200 status.
        """
        try:
            resp = self.call_wireserver(restutil.http_get,
                                        uri,
                                        headers=headers)
        except HttpError as e:
            raise ProtocolError(ustr(e))

        if resp.status != httpclient.OK:
            raise ProtocolError("{0} - {1}".format(resp.status, uri))

        return self.decode_config(resp.read())

    def fetch_cache(self, local_file):
        """
        Return the contents of a cached file.

        Raises:
            ProtocolError: if the file is missing or cannot be read.
        """
        if not os.path.isfile(local_file):
            raise ProtocolError("{0} is missing.".format(local_file))
        try:
            return fileutil.read_file(local_file)
        except IOError as e:
            raise ProtocolError("Failed to read cache: {0}".format(e))

    def save_cache(self, local_file, data):
        """
        Write *data* to *local_file*.

        Raises:
            ProtocolError: if the write fails with IOError.
        """
        try:
            fileutil.write_file(local_file, data)
        except IOError as e:
            raise ProtocolError("Failed to write cache: {0}".format(e))

    def call_storage_service(self, http_req, *args, **kwargs):
        """
        Call storage service, handle SERVICE_UNAVAILABLE(503)

        ``http_req(*args, **kwargs)`` is attempted up to three times; a 503
        response sleeps LONG_WAITING_INTERVAL and retries, anything else is
        returned unchanged.

        Raises:
            ProtocolError: if every attempt returned 503.
        """
        for _ in range(3):
            resp = http_req(*args, **kwargs)
            if resp.status == httpclient.SERVICE_UNAVAILABLE:
                logger.warn("Storage service is not avaible temporaryly")
                logger.info("Will retry later, in {0} seconds",
                            LONG_WAITING_INTERVAL)
                time.sleep(LONG_WAITING_INTERVAL)
            else:
                return resp
        raise ProtocolError(
            "Calling storage endpoint failed: {0}".format(resp.status))

    def fetch_manifest(self, version_uris):
        """
        Return the first manifest that can be fetched from *version_uris*.

        URIs are tried in order.  After every non-200 response the method
        sleeps LONG_WAITING_INTERVAL before moving on.
        NOTE(review): this also sleeps after the last URI, just before
        raising.

        Raises:
            ProtocolError: on HttpError, or when no URI returned 200.
        """
        for version_uri in version_uris:
            logger.verbose("Fetch ext handler manifest: {0}", version_uri.uri)
            try:
                resp = self.call_storage_service(restutil.http_get,
                                                 version_uri.uri,
                                                 None,
                                                 chk_proxy=True)
            except HttpError as e:
                raise ProtocolError(ustr(e))

            if resp.status == httpclient.OK:
                return self.decode_config(resp.read())
            logger.warn("Failed to fetch ExtensionManifest: {0}, {1}",
                        resp.status, version_uri.uri)
            logger.info("Will retry later, in {0} seconds",
                        LONG_WAITING_INTERVAL)
            time.sleep(LONG_WAITING_INTERVAL)
        raise ProtocolError("Failed to fetch ExtensionManifest from "
                            "all sources")

    def update_hosting_env(self, goal_state):
        """Fetch, cache and parse the HostingEnvironmentConfig document."""
        if goal_state.hosting_env_uri is None:
            raise ProtocolError("HostingEnvironmentConfig uri is empty")
        local_file = os.path.join(conf.get_lib_dir(), HOSTING_ENV_FILE_NAME)
        xml_text = self.fetch_config(goal_state.hosting_env_uri,
                                     self.get_header())
        self.save_cache(local_file, xml_text)
        self.hosting_env = HostingEnv(xml_text)

    def update_shared_conf(self, goal_state):
        """Fetch, cache and parse the SharedConfig document."""
        if goal_state.shared_conf_uri is None:
            raise ProtocolError("SharedConfig uri is empty")
        local_file = os.path.join(conf.get_lib_dir(), SHARED_CONF_FILE_NAME)
        xml_text = self.fetch_config(goal_state.shared_conf_uri,
                                     self.get_header())
        self.save_cache(local_file, xml_text)
        self.shared_conf = SharedConfig(xml_text)

    def update_certs(self, goal_state):
        """
        Fetch, cache and parse the certificates document.

        Does nothing when the goal state has no certificates URI.
        """
        if goal_state.certs_uri is None:
            return
        local_file = os.path.join(conf.get_lib_dir(), CERTS_FILE_NAME)
        xml_text = self.fetch_config(goal_state.certs_uri,
                                     self.get_header_for_cert())
        self.save_cache(local_file, xml_text)
        self.certs = Certificates(self, xml_text)

    def update_ext_conf(self, goal_state):
        """
        Fetch, cache and parse the ExtensionsConfig for this incarnation.

        With no extensions URI, an empty ExtensionsConfig(None) is stored.
        """
        if goal_state.ext_uri is None:
            logger.info("ExtensionsConfig.xml uri is empty")
            self.ext_conf = ExtensionsConfig(None)
            return
        incarnation = goal_state.incarnation
        local_file = os.path.join(conf.get_lib_dir(),
                                  EXT_CONF_FILE_NAME.format(incarnation))
        xml_text = self.fetch_config(goal_state.ext_uri, self.get_header())
        self.save_cache(local_file, xml_text)
        self.ext_conf = ExtensionsConfig(xml_text)

    def update_goal_state(self, forced=False, max_retry=3):
        """
        Fetch the goal state and refresh every document derived from it.

        Unless *forced*, returns early when the fetched incarnation equals
        the one stored in the incarnation cache file.  Otherwise the goal
        state and incarnation are cached and the hosting environment, shared
        config, certificates and extensions config are refreshed.  If the
        wire server answers 410 while doing so, the goal state is re-fetched
        and the refresh retried, up to *max_retry* attempts.

        Raises:
            ProtocolError: when all *max_retry* attempts hit a 410.
        """
        uri = GOAL_STATE_URI.format(self.endpoint)
        xml_text = self.fetch_config(uri, self.get_header())
        goal_state = GoalState(xml_text)

        incarnation_file = os.path.join(conf.get_lib_dir(),
                                        INCARNATION_FILE_NAME)

        if not forced:
            last_incarnation = None
            if os.path.isfile(incarnation_file):
                last_incarnation = fileutil.read_file(incarnation_file)
            new_incarnation = goal_state.incarnation
            if (last_incarnation is not None
                    and last_incarnation == new_incarnation):
                # Goal state is not updated.
                return

        # Start updating goal state; a 410 means the incarnation is out of
        # date, so re-fetch the goal state and try again.
        for _ in range(max_retry):
            try:
                self.goal_state = goal_state
                file_name = GOAL_STATE_FILE_NAME.format(goal_state.incarnation)
                goal_state_file = os.path.join(conf.get_lib_dir(), file_name)
                self.save_cache(goal_state_file, xml_text)
                self.save_cache(incarnation_file, goal_state.incarnation)
                self.update_hosting_env(goal_state)
                self.update_shared_conf(goal_state)
                self.update_certs(goal_state)
                self.update_ext_conf(goal_state)
                return
            except WireProtocolResourceGone:
                logger.info("Incarnation is out of date. Update goalstate.")
                xml_text = self.fetch_config(uri, self.get_header())
                goal_state = GoalState(xml_text)

        raise ProtocolError("Exceeded max retry updating goal state")

    def get_goal_state(self):
        """Return the goal state, loading it from the cache files if needed."""
        if self.goal_state is None:
            incarnation_file = os.path.join(conf.get_lib_dir(),
                                            INCARNATION_FILE_NAME)
            incarnation = self.fetch_cache(incarnation_file)

            file_name = GOAL_STATE_FILE_NAME.format(incarnation)
            goal_state_file = os.path.join(conf.get_lib_dir(), file_name)
            xml_text = self.fetch_cache(goal_state_file)
            self.goal_state = GoalState(xml_text)
        return self.goal_state

    def get_hosting_env(self):
        """Return the hosting environment, loading it from cache if needed."""
        if self.hosting_env is None:
            local_file = os.path.join(conf.get_lib_dir(),
                                      HOSTING_ENV_FILE_NAME)
            xml_text = self.fetch_cache(local_file)
            self.hosting_env = HostingEnv(xml_text)
        return self.hosting_env

    def get_shared_conf(self):
        """Return the shared config, loading it from cache if needed."""
        if self.shared_conf is None:
            local_file = os.path.join(conf.get_lib_dir(),
                                      SHARED_CONF_FILE_NAME)
            xml_text = self.fetch_cache(local_file)
            self.shared_conf = SharedConfig(xml_text)
        return self.shared_conf

    def get_certs(self):
        """Return the certificates, loading them from cache if needed."""
        if self.certs is None:
            local_file = os.path.join(conf.get_lib_dir(), CERTS_FILE_NAME)
            xml_text = self.fetch_cache(local_file)
            self.certs = Certificates(self, xml_text)
        return self.certs

    def get_ext_conf(self):
        """
        Return the extensions config, loading it from cache if needed.

        When the goal state has no extensions URI an empty
        ExtensionsConfig(None) is used instead of reading the cache.
        """
        if self.ext_conf is None:
            goal_state = self.get_goal_state()
            if goal_state.ext_uri is None:
                self.ext_conf = ExtensionsConfig(None)
            else:
                local_file = EXT_CONF_FILE_NAME.format(goal_state.incarnation)
                local_file = os.path.join(conf.get_lib_dir(), local_file)
                xml_text = self.fetch_cache(local_file)
                self.ext_conf = ExtensionsConfig(xml_text)
        return self.ext_conf

    def get_ext_manifest(self, ext_handler, goal_state):
        """Fetch, cache and parse the manifest for *ext_handler*."""
        local_file = MANIFEST_FILE_NAME.format(ext_handler.name,
                                               goal_state.incarnation)
        local_file = os.path.join(conf.get_lib_dir(), local_file)
        xml_text = self.fetch_manifest(ext_handler.versionUris)
        self.save_cache(local_file, xml_text)
        return ExtensionManifest(xml_text)

    def get_gafamily_manifest(self, vmagent_manifest, goal_state):
        """Fetch, cache and parse the manifest for a VM agent family."""
        local_file = MANIFEST_FILE_NAME.format(vmagent_manifest.family,
                                               goal_state.incarnation)
        local_file = os.path.join(conf.get_lib_dir(), local_file)
        xml_text = self.fetch_manifest(vmagent_manifest.versionsManifestUris)
        # NOTE(review): unlike get_ext_manifest this writes directly, so an
        # IOError is not wrapped in ProtocolError.
        fileutil.write_file(local_file, xml_text)
        return ExtensionManifest(xml_text)

    def check_wire_protocol_version(self):
        """
        Check that PROTOCOL_VERSION is preferred or supported by the server.

        Raises:
            ProtocolNotFoundError: if it is neither.
        """
        uri = VERSION_INFO_URI.format(self.endpoint)
        version_info_xml = self.fetch_config(uri, None)
        version_info = VersionInfo(version_info_xml)

        preferred = version_info.get_preferred()
        if PROTOCOL_VERSION == preferred:
            logger.info("Wire protocol version:{0}", PROTOCOL_VERSION)
        elif PROTOCOL_VERSION in version_info.get_supported():
            logger.info("Wire protocol version:{0}", PROTOCOL_VERSION)
            logger.warn("Server prefered version:{0}", preferred)
        else:
            error = ("Agent supported wire protocol version: {0} was not "
                     "advised by Fabric.").format(PROTOCOL_VERSION)
            raise ProtocolNotFoundError(error)

    def upload_status_blob(self):
        """
        Upload the status blob, falling back to the host plugin.

        Nothing is uploaded when the extensions config has no status upload
        blob; if the direct upload returns a falsy value, the status is sent
        through the host plugin instead.
        """
        ext_conf = self.get_ext_conf()
        if ext_conf.status_upload_blob is not None:
            if not self.status_blob.upload(ext_conf.status_upload_blob):
                self.host_plugin.put_vm_status(self.status_blob,
                                               ext_conf.status_upload_blob)

    def report_role_prop(self, thumbprint):
        """
        POST role properties for *thumbprint* to the wire server.

        Raises:
            ProtocolError: on HttpError or a status other than 202 Accepted.
        """
        goal_state = self.get_goal_state()
        role_prop = _build_role_properties(goal_state.container_id,
                                           goal_state.role_instance_id,
                                           thumbprint)
        role_prop = role_prop.encode("utf-8")
        role_prop_uri = ROLE_PROP_URI.format(self.endpoint)
        headers = self.get_header_for_xml_content()
        try:
            resp = self.call_wireserver(restutil.http_post,
                                        role_prop_uri,
                                        role_prop,
                                        headers=headers)
        except HttpError as e:
            raise ProtocolError((u"Failed to send role properties: {0}"
                                 u"").format(e))
        if resp.status != httpclient.ACCEPTED:
            raise ProtocolError((u"Failed to send role properties: {0}"
                                 u", {1}").format(resp.status, resp.read()))

    def report_health(self, status, substatus, description):
        """
        POST a health report to the wire server.

        ``max_retry=8`` is forwarded to restutil.http_post.

        Raises:
            ProtocolError: on HttpError or a status other than 200 OK.
        """
        goal_state = self.get_goal_state()
        health_report = _build_health_report(goal_state.incarnation,
                                             goal_state.container_id,
                                             goal_state.role_instance_id,
                                             status, substatus, description)
        health_report = health_report.encode("utf-8")
        health_report_uri = HEALTH_REPORT_URI.format(self.endpoint)
        headers = self.get_header_for_xml_content()
        try:
            resp = self.call_wireserver(restutil.http_post,
                                        health_report_uri,
                                        health_report,
                                        headers=headers,
                                        max_retry=8)
        except HttpError as e:
            raise ProtocolError((u"Failed to send provision status: {0}"
                                 u"").format(e))
        if resp.status != httpclient.OK:
            raise ProtocolError((u"Failed to send provision status: {0}"
                                 u", {1}").format(resp.status, resp.read()))

    def send_event(self, provider_id, event_str):
        """
        POST serialized events for one provider as a TelemetryData document.

        Raises:
            ProtocolError: on HttpError or a status other than 200 OK.
        """
        uri = TELEMETRY_URI.format(self.endpoint)
        data_format = ('<?xml version="1.0"?>'
                       '<TelemetryData version="1.0">'
                       '<Provider id="{0}">{1}'
                       '</Provider>'
                       '</TelemetryData>')
        data = data_format.format(provider_id, event_str)
        try:
            header = self.get_header_for_xml_content()
            resp = self.call_wireserver(restutil.http_post, uri, data, header)
        except HttpError as e:
            raise ProtocolError("Failed to send events:{0}".format(e))

        if resp.status != httpclient.OK:
            logger.verbose(resp.read())
            raise ProtocolError("Failed to send events:{0}".format(
                resp.status))

    def report_event(self, event_list):
        """
        Send ``event_list.events`` to the wire server, grouped by provider.

        Each provider's events are serialized with event_to_v1() and
        concatenated into a buffer.  The buffer is flushed with send_event()
        before appending an event would make it reach 63 * 1024 characters,
        so every payload stays below that size.  An event that reaches that
        size on its own is logged (first 300 characters only) and dropped.
        """
        max_payload = 63 * 1024
        buf = {}
        # Group events by providerId
        for event in event_list.events:
            if event.providerId not in buf:
                buf[event.providerId] = ""
            event_str = event_to_v1(event)
            if len(event_str) >= max_payload:
                # Log a short prefix only; the event itself is huge.
                logger.warn("Single event too large: {0}", event_str[:300])
                continue
            if len(buf[event.providerId] + event_str) >= max_payload:
                self.send_event(event.providerId, buf[event.providerId])
                buf[event.providerId] = ""
            buf[event.providerId] = buf[event.providerId] + event_str

        # Send out all events left in buffer.
        for provider_id, pending in buf.items():
            if pending:
                self.send_event(provider_id, pending)

    def get_header(self):
        """Return the standard wire-server request headers."""
        return {
            "x-ms-agent-name": "WALinuxAgent",
            "x-ms-version": PROTOCOL_VERSION
        }

    def get_header_for_xml_content(self):
        """Return the standard headers plus an XML UTF-8 Content-Type."""
        return {
            "x-ms-agent-name": "WALinuxAgent",
            "x-ms-version": PROTOCOL_VERSION,
            "Content-Type": "text/xml;charset=utf-8"
        }

    def get_header_for_cert(self):
        """
        Return headers for the certificates request.

        Adds the cipher name and the transport certificate read from the
        cached TRANSPORT_CERT_FILE_NAME (via get_bytes_from_pem).

        Raises:
            ProtocolError: if the transport certificate cannot be read.
        """
        trans_cert_file = os.path.join(conf.get_lib_dir(),
                                       TRANSPORT_CERT_FILE_NAME)
        content = self.fetch_cache(trans_cert_file)
        cert = get_bytes_from_pem(content)
        return {
            "x-ms-agent-name": "WALinuxAgent",
            "x-ms-version": PROTOCOL_VERSION,
            "x-ms-cipher-name": "DES_EDE3_CBC",
            "x-ms-guest-agent-public-x509-cert": cert
        }
Esempio n. 10
0
class WireClient(object):
    """
    Client for the wire server protocol endpoint.

    Fetches the goal state and the documents it links to (hosting
    environment, shared config, certificates, extensions config), caches
    them as files under ``conf.get_lib_dir()``, and posts role properties,
    health reports, telemetry events and status back.
    """

    def __init__(self, endpoint):
        logger.info("Wire server endpoint:{0}", endpoint)
        self.endpoint = endpoint
        # Parsed documents: filled by update_goal_state(), or lazily from
        # the on-disk cache by the matching get_*() accessor.
        self.goal_state = None
        self.updated = None
        self.hosting_env = None
        self.shared_conf = None
        self.certs = None
        self.ext_conf = None
        # Bookkeeping for prevent_throttling().
        self.last_request = 0
        self.req_count = 0
        self.status_blob = StatusBlob(self)
        self.host_plugin = HostPluginProtocol(self.endpoint)

    def prevent_throttling(self):
        """
        Try to avoid throttling of wire server.

        Sleeps SHORT_WAITING_INTERVAL when the previous request was issued
        less than one second ago, and again on every third request.
        """
        now = time.time()
        if now - self.last_request < 1:
            logger.verbose("Last request issued less than 1 second ago")
            logger.verbose("Sleep {0} second to avoid throttling.",
                           SHORT_WAITING_INTERVAL)
            time.sleep(SHORT_WAITING_INTERVAL)
        self.last_request = now

        self.req_count += 1
        if self.req_count % 3 == 0:
            logger.verbose("Sleep {0} second to avoid throttling.",
                           SHORT_WAITING_INTERVAL)
            time.sleep(SHORT_WAITING_INTERVAL)
            self.req_count = 0

    def call_wireserver(self, http_req, *args, **kwargs):
        """
        Call wire server. Handle throttling(403) and Resource Gone(410).

        :param http_req: request function, called as
            ``http_req(*args, **kwargs)``.
        :return: the first response whose status is neither 403 nor 410.
        :raises WireProtocolResourceGone: on a 410 response; the message is
            the first positional argument (the URI), if any.
        :raises ProtocolError: if all 3 attempts were answered with 403.
        """
        self.prevent_throttling()
        for _ in range(0, 3):
            resp = http_req(*args, **kwargs)
            if resp.status == httpclient.FORBIDDEN:
                logger.warn("Sending too much request to wire server")
                logger.info("Sleep {0} second to avoid throttling.",
                            LONG_WAITING_INTERVAL)
                time.sleep(LONG_WAITING_INTERVAL)
            elif resp.status == httpclient.GONE:
                msg = args[0] if len(args) > 0 else ""
                raise WireProtocolResourceGone(msg)
            else:
                return resp
        raise ProtocolError(("Calling wire server failed: {0}"
                             "").format(resp.status))

    def decode_config(self, data):
        """Strip a BOM from ``data`` and decode it as UTF-8; None passes through."""
        if data is None:
            return None
        data = remove_bom(data)
        xml_text = ustr(data, encoding='utf-8')
        return xml_text

    def fetch_config(self, uri, headers):
        """
        GET ``uri`` from the wire server and return the decoded body.

        :raises ProtocolError: on an HttpError or a non-200 response.
        """
        try:
            resp = self.call_wireserver(restutil.http_get, uri,
                                        headers=headers)
        except HttpError as e:
            raise ProtocolError(ustr(e))

        if resp.status != httpclient.OK:
            raise ProtocolError("{0} - {1}".format(resp.status, uri))

        return self.decode_config(resp.read())

    def fetch_cache(self, local_file):
        """
        Return the content of a cached file.

        :raises ProtocolError: if the file is missing or cannot be read.
        """
        if not os.path.isfile(local_file):
            raise ProtocolError("{0} is missing.".format(local_file))
        try:
            return fileutil.read_file(local_file)
        except IOError as e:
            raise ProtocolError("Failed to read cache: {0}".format(e))

    def save_cache(self, local_file, data):
        """
        Write ``data`` to a cache file.

        :raises ProtocolError: if the write fails with IOError.
        """
        try:
            fileutil.write_file(local_file, data)
        except IOError as e:
            raise ProtocolError("Failed to write cache: {0}".format(e))

    def call_storage_service(self, http_req, *args, **kwargs):
        """
        Call storage service, handle SERVICE_UNAVAILABLE(503).

        :return: the first response whose status is not 503.
        :raises ProtocolError: if all 3 attempts were answered with 503.
        """
        for _ in range(0, 3):
            resp = http_req(*args, **kwargs)
            if resp.status == httpclient.SERVICE_UNAVAILABLE:
                logger.warn("Storage service is not avaible temporaryly")
                logger.info("Will retry later, in {0} seconds",
                            LONG_WAITING_INTERVAL)
                time.sleep(LONG_WAITING_INTERVAL)
            else:
                return resp
        raise ProtocolError(("Calling storage endpoint failed: {0}"
                             "").format(resp.status))

    def fetch_manifest(self, version_uris):
        """
        Fetch a manifest from the first source that answers with HTTP 200.

        Sources are tried in order. A transport error or a non-200 status
        on one source moves on to the next one, waiting
        LONG_WAITING_INTERVAL between sources.

        :param version_uris: iterable of objects with a ``uri`` attribute.
        :return: the decoded manifest text.
        :raises ProtocolError: if no source returned the manifest.
        """
        uris = list(version_uris)
        for index, version_uri in enumerate(uris):
            logger.verbose("Fetch ext handler manifest: {0}", version_uri.uri)
            try:
                resp = self.call_storage_service(restutil.http_get,
                                                 version_uri.uri, None,
                                                 chk_proxy=True)
            except HttpError as e:
                # Don't abort on the first transport error: the remaining
                # sources may still be reachable.
                logger.warn("Failed to fetch ExtensionManifest: {0}, {1}",
                            ustr(e), version_uri.uri)
            else:
                if resp.status == httpclient.OK:
                    return self.decode_config(resp.read())
                logger.warn("Failed to fetch ExtensionManifest: {0}, {1}",
                            resp.status, version_uri.uri)
            # Waiting only makes sense if there is another source to try.
            if index < len(uris) - 1:
                logger.info("Will retry later, in {0} seconds",
                            LONG_WAITING_INTERVAL)
                time.sleep(LONG_WAITING_INTERVAL)
        raise ProtocolError(("Failed to fetch ExtensionManifest from "
                             "all sources"))

    def update_hosting_env(self, goal_state):
        """Fetch, cache and parse HostingEnvironmentConfig for ``goal_state``."""
        if goal_state.hosting_env_uri is None:
            raise ProtocolError("HostingEnvironmentConfig uri is empty")
        local_file = os.path.join(conf.get_lib_dir(), HOSTING_ENV_FILE_NAME)
        xml_text = self.fetch_config(goal_state.hosting_env_uri,
                                     self.get_header())
        self.save_cache(local_file, xml_text)
        self.hosting_env = HostingEnv(xml_text)

    def update_shared_conf(self, goal_state):
        """Fetch, cache and parse SharedConfig for ``goal_state``."""
        if goal_state.shared_conf_uri is None:
            raise ProtocolError("SharedConfig uri is empty")
        local_file = os.path.join(conf.get_lib_dir(), SHARED_CONF_FILE_NAME)
        xml_text = self.fetch_config(goal_state.shared_conf_uri,
                                     self.get_header())
        self.save_cache(local_file, xml_text)
        self.shared_conf = SharedConfig(xml_text)

    def update_certs(self, goal_state):
        """Fetch, cache and parse certificates; no-op when there is no certs URI."""
        if goal_state.certs_uri is None:
            return
        local_file = os.path.join(conf.get_lib_dir(), CERTS_FILE_NAME)
        xml_text = self.fetch_config(goal_state.certs_uri,
                                     self.get_header_for_cert())
        self.save_cache(local_file, xml_text)
        self.certs = Certificates(self, xml_text)

    def update_ext_conf(self, goal_state):
        """
        Fetch, cache and parse the extensions config of ``goal_state``.

        With no extensions URI, an empty ``ExtensionsConfig(None)`` is used.
        The cache file name is derived from the goal-state incarnation.
        """
        if goal_state.ext_uri is None:
            logger.info("ExtensionsConfig.xml uri is empty")
            self.ext_conf = ExtensionsConfig(None)
            return
        incarnation = goal_state.incarnation
        local_file = os.path.join(conf.get_lib_dir(),
                                  EXT_CONF_FILE_NAME.format(incarnation))
        xml_text = self.fetch_config(goal_state.ext_uri, self.get_header())
        self.save_cache(local_file, xml_text)
        self.ext_conf = ExtensionsConfig(xml_text)

    def update_goal_state(self, forced=False, max_retry=3):
        """
        Fetch the goal state and refresh every dependent document.

        Unless ``forced``, nothing is updated when the fetched incarnation
        equals the one stored in the cached incarnation file. If a dependent
        document answers 410 (WireProtocolResourceGone), the goal state is
        re-fetched and the update restarted, up to ``max_retry`` times.

        :raises ProtocolError: when the retries are exhausted.
        """
        uri = GOAL_STATE_URI.format(self.endpoint)
        xml_text = self.fetch_config(uri, self.get_header())
        goal_state = GoalState(xml_text)

        incarnation_file = os.path.join(conf.get_lib_dir(),
                                        INCARNATION_FILE_NAME)

        if not forced:
            last_incarnation = None
            if os.path.isfile(incarnation_file):
                last_incarnation = fileutil.read_file(incarnation_file)
            new_incarnation = goal_state.incarnation
            if last_incarnation is not None and \
                    last_incarnation == new_incarnation:
                # Goalstate is not updated.
                return

        # Start updating goalstate, retry on 410
        for _ in range(0, max_retry):
            try:
                self.goal_state = goal_state
                file_name = GOAL_STATE_FILE_NAME.format(goal_state.incarnation)
                goal_state_file = os.path.join(conf.get_lib_dir(), file_name)
                self.save_cache(goal_state_file, xml_text)
                self.save_cache(incarnation_file, goal_state.incarnation)
                self.update_hosting_env(goal_state)
                self.update_shared_conf(goal_state)
                self.update_certs(goal_state)
                self.update_ext_conf(goal_state)
                return
            except WireProtocolResourceGone:
                logger.info("Incarnation is out of date. Update goalstate.")
                xml_text = self.fetch_config(uri, self.get_header())
                goal_state = GoalState(xml_text)

        raise ProtocolError("Exceeded max retry updating goal state")

    def get_goal_state(self):
        """Return the goal state, loading it from the cached files if needed."""
        if self.goal_state is None:
            incarnation_file = os.path.join(conf.get_lib_dir(),
                                            INCARNATION_FILE_NAME)
            incarnation = self.fetch_cache(incarnation_file)

            file_name = GOAL_STATE_FILE_NAME.format(incarnation)
            goal_state_file = os.path.join(conf.get_lib_dir(), file_name)
            xml_text = self.fetch_cache(goal_state_file)
            self.goal_state = GoalState(xml_text)
        return self.goal_state

    def get_hosting_env(self):
        """Return the hosting environment, loading it from cache if needed."""
        if self.hosting_env is None:
            local_file = os.path.join(conf.get_lib_dir(), HOSTING_ENV_FILE_NAME)
            xml_text = self.fetch_cache(local_file)
            self.hosting_env = HostingEnv(xml_text)
        return self.hosting_env

    def get_shared_conf(self):
        """Return the shared config, loading it from cache if needed."""
        if self.shared_conf is None:
            local_file = os.path.join(conf.get_lib_dir(), SHARED_CONF_FILE_NAME)
            xml_text = self.fetch_cache(local_file)
            self.shared_conf = SharedConfig(xml_text)
        return self.shared_conf

    def get_certs(self):
        """Return the certificates, loading them from cache if needed."""
        if self.certs is None:
            local_file = os.path.join(conf.get_lib_dir(), CERTS_FILE_NAME)
            xml_text = self.fetch_cache(local_file)
            self.certs = Certificates(self, xml_text)
        return self.certs

    def get_ext_conf(self):
        """
        Return the extensions config, loading it from cache if needed.

        An empty ``ExtensionsConfig(None)`` is used when the goal state has
        no extensions URI.
        """
        if self.ext_conf is None:
            goal_state = self.get_goal_state()
            if goal_state.ext_uri is None:
                self.ext_conf = ExtensionsConfig(None)
            else:
                local_file = EXT_CONF_FILE_NAME.format(goal_state.incarnation)
                local_file = os.path.join(conf.get_lib_dir(), local_file)
                xml_text = self.fetch_cache(local_file)
                self.ext_conf = ExtensionsConfig(xml_text)
        return self.ext_conf

    def get_ext_manifest(self, ext_handler, goal_state):
        """Fetch, cache and parse the manifest of ``ext_handler``."""
        local_file = MANIFEST_FILE_NAME.format(ext_handler.name,
                                               goal_state.incarnation)
        local_file = os.path.join(conf.get_lib_dir(), local_file)
        xml_text = self.fetch_manifest(ext_handler.versionUris)
        self.save_cache(local_file, xml_text)
        return ExtensionManifest(xml_text)

    def get_gafamily_manifest(self, vmagent_manifest, goal_state):
        """Fetch, cache and parse the manifest of an agent family."""
        local_file = MANIFEST_FILE_NAME.format(vmagent_manifest.family,
                                               goal_state.incarnation)
        local_file = os.path.join(conf.get_lib_dir(), local_file)
        xml_text = self.fetch_manifest(vmagent_manifest.versionsManifestUris)
        # Go through save_cache like the other cache writers so write
        # failures surface as ProtocolError rather than a raw IOError.
        self.save_cache(local_file, xml_text)
        return ExtensionManifest(xml_text)

    def check_wire_protocol_version(self):
        """
        Verify that PROTOCOL_VERSION is preferred or supported by the server.

        :raises ProtocolNotFoundError: if the server does not list it.
        """
        uri = VERSION_INFO_URI.format(self.endpoint)
        version_info_xml = self.fetch_config(uri, None)
        version_info = VersionInfo(version_info_xml)

        preferred = version_info.get_preferred()
        if PROTOCOL_VERSION == preferred:
            logger.info("Wire protocol version:{0}", PROTOCOL_VERSION)
        elif PROTOCOL_VERSION in version_info.get_supported():
            logger.info("Wire protocol version:{0}", PROTOCOL_VERSION)
            logger.warn("Server prefered version:{0}", preferred)
        else:
            error = ("Agent supported wire protocol version: {0} was not "
                     "advised by Fabric.").format(PROTOCOL_VERSION)
            raise ProtocolNotFoundError(error)

    def upload_status_blob(self):
        """
        Upload the status blob when the extensions config names a target.

        If ``status_blob.upload`` returns a falsy value, the status is sent
        through the host plugin instead.
        """
        ext_conf = self.get_ext_conf()
        if ext_conf.status_upload_blob is not None:
            if not self.status_blob.upload(ext_conf.status_upload_blob):
                self.host_plugin.put_vm_status(self.status_blob,
                                               ext_conf.status_upload_blob)

    def report_role_prop(self, thumbprint):
        """
        POST role properties carrying ``thumbprint``.

        :raises ProtocolError: on an HttpError or a non-202 response.
        """
        goal_state = self.get_goal_state()
        role_prop = _build_role_properties(goal_state.container_id,
                                           goal_state.role_instance_id,
                                           thumbprint)
        role_prop = role_prop.encode("utf-8")
        role_prop_uri = ROLE_PROP_URI.format(self.endpoint)
        headers = self.get_header_for_xml_content()
        try:
            resp = self.call_wireserver(restutil.http_post, role_prop_uri,
                                        role_prop, headers=headers)
        except HttpError as e:
            raise ProtocolError((u"Failed to send role properties: {0}"
                                 u"").format(e))
        if resp.status != httpclient.ACCEPTED:
            raise ProtocolError((u"Failed to send role properties: {0}"
                                 u", {1}").format(resp.status, resp.read()))

    def report_health(self, status, substatus, description):
        """
        POST a health report for the current goal state.

        :raises ProtocolError: on an HttpError or a non-200 response.
        """
        goal_state = self.get_goal_state()
        health_report = _build_health_report(goal_state.incarnation,
                                             goal_state.container_id,
                                             goal_state.role_instance_id,
                                             status,
                                             substatus,
                                             description)
        health_report = health_report.encode("utf-8")
        health_report_uri = HEALTH_REPORT_URI.format(self.endpoint)
        headers = self.get_header_for_xml_content()
        try:
            resp = self.call_wireserver(restutil.http_post, health_report_uri,
                                        health_report, headers=headers,
                                        max_retry=8)
        except HttpError as e:
            raise ProtocolError((u"Failed to send provision status: {0}"
                                 u"").format(e))
        if resp.status != httpclient.OK:
            raise ProtocolError((u"Failed to send provision status: {0}"
                                 u", {1}").format(resp.status, resp.read()))

    def send_event(self, provider_id, event_str):
        """
        POST one TelemetryData document for ``provider_id``.

        ``event_str`` is inserted verbatim into the Provider element.

        :raises ProtocolError: on an HttpError or a non-200 response.
        """
        uri = TELEMETRY_URI.format(self.endpoint)
        data_format = ('<?xml version="1.0"?>'
                       '<TelemetryData version="1.0">'
                       '<Provider id="{0}">{1}'
                       '</Provider>'
                       '</TelemetryData>')
        data = data_format.format(provider_id, event_str)
        try:
            header = self.get_header_for_xml_content()
            resp = self.call_wireserver(restutil.http_post, uri, data, header)
        except HttpError as e:
            raise ProtocolError("Failed to send events:{0}".format(e))

        if resp.status != httpclient.OK:
            logger.verbose(resp.read())
            raise ProtocolError("Failed to send events:{0}".format(resp.status))

    def report_event(self, event_list):
        """
        Send telemetry events, batched per provider.

        Events are serialized with ``event_to_v1`` and concatenated per
        ``providerId``. A batch is flushed before it would reach 63 * 1024
        characters. A single event of that size or more is logged (first
        300 characters only) and skipped.
        """
        max_payload = 63 * 1024
        buf = {}
        # Group events by providerId
        for event in event_list.events:
            if event.providerId not in buf:
                buf[event.providerId] = ""
            event_str = event_to_v1(event)
            if len(event_str) >= max_payload:
                # Log a bounded prefix only; the event can never fit into a
                # payload and is dropped.
                logger.warn("Single event too large: {0}", event_str[:300])
                continue
            if len(buf[event.providerId]) + len(event_str) >= max_payload:
                self.send_event(event.providerId, buf[event.providerId])
                buf[event.providerId] = ""
            buf[event.providerId] = buf[event.providerId] + event_str

        # Send out all events left in buffer.
        for provider_id in list(buf.keys()):
            if len(buf[provider_id]) > 0:
                self.send_event(provider_id, buf[provider_id])

    def get_header(self):
        """Return the common agent-name and protocol-version headers."""
        return {
            "x-ms-agent-name": "WALinuxAgent",
            "x-ms-version": PROTOCOL_VERSION
        }

    def get_header_for_xml_content(self):
        """Return the common headers plus a UTF-8 XML Content-Type."""
        headers = self.get_header()
        headers["Content-Type"] = "text/xml;charset=utf-8"
        return headers

    def get_header_for_cert(self):
        """
        Return the common headers plus the transport certificate headers.

        The cached transport certificate file is read and its content,
        converted with ``get_bytes_from_pem``, is sent in
        ``x-ms-guest-agent-public-x509-cert``.

        :raises ProtocolError: if the certificate file cannot be read.
        """
        trans_cert_file = os.path.join(conf.get_lib_dir(),
                                       TRANSPORT_CERT_FILE_NAME)
        content = self.fetch_cache(trans_cert_file)
        cert = get_bytes_from_pem(content)
        headers = self.get_header()
        headers["x-ms-cipher-name"] = "DES_EDE3_CBC"
        headers["x-ms-guest-agent-public-x509-cert"] = cert
        return headers
Esempio n. 11
0
 def get_host_plugin(self):
     """Return the cached host plugin, creating it on first use from the goal state."""
     if self.host_plugin is not None:
         return self.host_plugin
     self.host_plugin = HostPluginProtocol(self.endpoint,
                                           self.get_goal_state())
     return self.host_plugin