コード例 #1
0
ファイル: common.py プロジェクト: fabiandevia/home
def async_fire_mqtt_message(hass, topic, payload, qos=0, retain=False):
    """Inject an MQTT message into the client stored in hass.data['mqtt'].

    A ``str`` payload is UTF-8 encoded before being wrapped in an
    ``mqtt.Message``; any other payload is passed through unchanged.
    """
    encoded = payload.encode('utf-8') if isinstance(payload, str) else payload
    message = mqtt.Message(topic, encoded, qos, retain)
    client = hass.data['mqtt']
    client._mqtt_handle_message(message)
コード例 #2
0
ファイル: __init__.py プロジェクト: golfr32/http2mqtt2hass
async def async_setup_entry(hass, entry):
    """Load a config entry.

    Merges ``entry.data`` over the configuration.yaml config, connects to the
    MQTT broker and subscribes to
    ``ai-home/http2mqtt2hass/<app_key>/request/#``. Every message on that
    topic is AES-decrypted, parsed as JSON, forwarded to the local HA HTTP
    API at ``https://localhost:8123`` and the encrypted reply is published on
    the matching ``/response/`` topic.

    Returns True on success, False when the entry was removed or the MQTT
    connection failed.
    """
    conf = hass.data.get(DATA_HTTP2MQTT2HASS_CONFIG)

    # Config entry was created because user had configuration.yaml entry
    # They removed that, so remove entry.
    if conf is None and entry.source == config_entries.SOURCE_IMPORT:
        hass.async_create_task(hass.config_entries.async_remove(
            entry.entry_id))
        return False

    # If user didn't have configuration.yaml config, generate defaults
    if conf is None:
        conf = CONFIG_SCHEMA({
            DOMAIN: entry.data,
        })[DOMAIN]
    elif any(key in conf for key in entry.data):
        _LOGGER.warning(
            "Data in your config entry is going to override your "
            "configuration.yaml: %s", entry.data)

    conf.update(entry.data)

    broker = conf[CONF_BROKER]
    port = conf[CONF_PORT]
    client_id = conf.get(CONF_CLIENT_ID)
    keepalive = conf[CONF_KEEPALIVE]
    app_key = conf.get(CONF_APP_KEY)
    app_secret = conf.get(CONF_APP_SECRET)
    certificate = conf.get(CONF_CERTIFICATE)
    client_key = conf.get(CONF_CLIENT_KEY)
    client_cert = conf.get(CONF_CLIENT_CERT)
    tls_insecure = conf.get(CONF_TLS_INSECURE)
    protocol = conf[CONF_PROTOCOL]
    allowed_uri = conf.get(CONF_ALLOWED_URI)
    # AES key: the first 16 bytes of SHA-1(app_secret).
    decrypt_key = sha1(app_secret.encode("utf-8")).digest()[:16]

    # For cloudmqtt.com, secured connection, auto fill in certificate
    if (certificate is None and 19999 < port < 30000
            and broker.endswith('.cloudmqtt.com')):
        certificate = os.path.join(os.path.dirname(__file__),
                                   'addtrustexternalcaroot.crt')

    # When the certificate is set to auto, use bundled certs from requests
    elif certificate == 'auto':
        certificate = requests.certs.where()

    if CONF_WILL_MESSAGE in conf:
        will_message = mqtt.Message(**conf[CONF_WILL_MESSAGE])
    else:
        will_message = None

    if CONF_BIRTH_MESSAGE in conf:
        birth_message = mqtt.Message(**conf[CONF_BIRTH_MESSAGE])
    else:
        birth_message = None

    # Be able to override versions other than TLSv1.0 under Python3.6
    conf_tls_version = conf.get(CONF_TLS_VERSION)  # type: str
    if conf_tls_version == '1.2':
        tls_version = ssl.PROTOCOL_TLSv1_2
    elif conf_tls_version == '1.1':
        tls_version = ssl.PROTOCOL_TLSv1_1
    elif conf_tls_version == '1.0':
        tls_version = ssl.PROTOCOL_TLSv1
    else:
        import sys
        # Python3.6 supports automatic negotiation of highest TLS version
        if sys.hexversion >= 0x03060000:
            tls_version = ssl.PROTOCOL_TLS  # pylint: disable=no-member
        else:
            tls_version = ssl.PROTOCOL_TLSv1

    hass.data[DATA_HTTP2MQTT2HASS_MQTT] = mqtt.MQTT(
        hass,
        broker=broker,
        port=port,
        client_id=client_id,
        keepalive=keepalive,
        username=app_key,
        password=app_secret,
        certificate=certificate,
        client_key=client_key,
        client_cert=client_cert,
        tls_insecure=tls_insecure,
        protocol=protocol,
        will_message=will_message,
        birth_message=birth_message,
        tls_version=tls_version,
    )

    success = await hass.data[DATA_HTTP2MQTT2HASS_MQTT].async_connect(
    )  # type: bool

    if not success:
        return False

    async def async_stop_mqtt(event: Event):
        """Stop MQTT component."""
        await hass.data[DATA_HTTP2MQTT2HASS_MQTT].async_disconnect()

    hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, async_stop_mqtt)

    async def localHttp(resData, topic):
        """Proxy one request to the local HA HTTP API and publish the reply.

        ``resData`` is the decrypted request dict: ``uri`` (required),
        optional ``content`` (sent as a POST body; GET is used when absent),
        optional ``headers`` and ``msgId``. The reply is encrypted and
        published on ``topic`` with ``/request/`` replaced by ``/response/``.
        When the local request times out or fails, a status-500
        ``{"error":"time_out"}`` reply is published so the remote caller is
        not left waiting.
        """
        url = 'https://localhost:8123' + resData['uri']
        # Stays None when the request below fails.
        response = None
        try:
            session = async_get_clientsession(hass, verify_ssl=False)
            with async_timeout.timeout(5, loop=hass.loop):
                if 'content' in resData:
                    response = await session.post(
                        url,
                        data=resData['content'],
                        headers=resData.get('headers'))
                else:
                    response = await session.get(
                        url, headers=resData.get('headers'))
        except (asyncio.TimeoutError, aiohttp.ClientError):
            _LOGGER.error("Error while accessing: %s", url)

        if response is None:
            content_type = None
            res = {
                'status': 500,
                'content': '{"error":"time_out"}',
                'msgId': resData.get('msgId')
            }
        else:
            if response.status != 200:
                _LOGGER.error("Error while accessing: %s, status=%d", url,
                              response.status)
            content_type = response.headers.get('Content-Type', '')
            if 'image' in content_type or 'stream' in content_type:
                # Binary bodies are base64-encoded so they fit in JSON.
                result = b64encode(await response.read()).decode()
            else:
                result = await response.text()
            res = {
                'headers': {'Content-Type': content_type},
                'status': response.status,
                # NOTE(review): unicode_escape turns \uXXXX escapes into
                # characters but also mangles raw non-ASCII text and other
                # backslash escapes; kept because the remote side may rely
                # on this wire format.
                'content': result.encode('utf-8').decode('unicode_escape'),
                'msgId': resData.get('msgId')
            }
        _LOGGER.debug(
            "%s response[%s]: [%s]",
            resData['uri'].split('/')[-1].split('?')[0],
            resData.get('msgId'),
            content_type,
        )
        encrypted = AESCipher(decrypt_key).encrypt(
            json.dumps(res, ensure_ascii=False).encode('utf8'))

        await hass.data[DATA_HTTP2MQTT2HASS_MQTT].async_publish(
            topic.replace('/request/', '/response/'), encrypted, 2, False)

    @callback
    def message_received(topic, payload, qos):
        """Handle new MQTT state messages.

        Decrypts and parses the payload, drops malformed or disallowed
        requests, and schedules ``localHttp`` for the rest.
        """
        _LOGGER.debug('get encrypt message: \n %s', payload)
        try:
            payload = AESCipher(decrypt_key).decrypt(payload)
            req = json.loads(payload)
        except (JSONDecodeError, UnicodeDecodeError, binascii.Error):
            import sys
            log = ''.join(
                str(stack)
                for stack in traceback.extract_tb(sys.exc_info()[2]))
            _LOGGER.debug('decrypt failure, abandon:%s', log)
            return
        _LOGGER.debug("raw message: %s", req)
        # Guard against non-object JSON or a missing/non-string uri, which
        # would otherwise raise inside this callback or the proxy task.
        uri = req.get('uri') if isinstance(req, dict) else None
        if not isinstance(uri, str):
            _LOGGER.debug('missing uri in request: %s', req)
            return
        if allowed_uri and uri.split('?')[0] not in allowed_uri:
            _LOGGER.debug('uri not allowed: %s', uri)
            return
        hass.add_job(localHttp(req, topic))

    await hass.data[DATA_HTTP2MQTT2HASS_MQTT].async_subscribe(
        "ai-home/http2mqtt2hass/" + app_key + "/request/#", message_received,
        2, 'utf-8')
    return True
コード例 #3
0
ファイル: __init__.py プロジェクト: xuejike/havcs
async def async_setup_entry(hass, entry):
    """Load a havcs config entry.

    ``entry.data`` is merged over the configuration.yaml config (if any) and
    the device config file is created when missing. Then:

    * if neither the http_proxy nor the skill section is configured, only the
      device-config load and the reload service are set up (HTTP-only mode);
    * otherwise the MQTT client is connected and subscribed to
      ``ai-home/http2mqtt2hass/<app_key>/request/#``. Each encrypted request
      on that topic is routed to the local HTTP proxy handler or to the
      matching skill platform handler. The encrypted reply is published on
      the corresponding ``/response/`` topic.

    Returns True on success, False when the entry was removed or the MQTT
    connection could not be established.
    """
    conf = hass.data.get(DATA_HAVCS_CONFIG)

    # Config entry was created because user had configuration.yaml entry
    # They removed that, so remove entry.
    if conf is None and entry.source == config_entries.SOURCE_IMPORT:
        hass.async_create_task(
            hass.config_entries.async_remove(entry.entry_id))
        return False

    # If user didn't have configuration.yaml config, generate defaults
    if conf is None:
        conf = CONFIG_SCHEMA({
            DOMAIN: entry.data,
        })[DOMAIN]
    elif any(key in conf for key in entry.data):
        _LOGGER.warning(
            "[init] Data in your config entry is going to override your "
            "configuration.yaml: %s", entry.data)

    conf.update(entry.data)

    havcd_config_path = conf.get(CONF_DEVICE_CONFIG, os.path.join(hass.config.config_dir, 'havcs.yaml'))
    # Create an empty device config file on first run so later loads succeed.
    if not os.path.isfile(havcd_config_path):
        with open(havcd_config_path, "wt") as havcd_config_file:
            havcd_config_file.write('')

    # Iterable of enabled platform names, used as keys into HANDLER.
    platform = conf[CONF_PLATFORM]

    async def _async_read_device_config():
        """Load the device YAML into hass.data[DOMAIN]['devices'].

        Returns True on success; logs and returns False on HomeAssistantError.
        """
        try:
            hass.data[DOMAIN]['devices'] = await hass.async_add_executor_job(
                conf_util.load_yaml_config_file, havcd_config_path
            )
        except HomeAssistantError as err:
            _LOGGER.error("Error loading %s: %s", havcd_config_path, err)
            return False
        return True

    async def async_load_config():
        """Load device config at startup and refresh every platform's devices."""
        if not await _async_read_device_config():
            return
        for p in platform:
            HANDLER[p].vcdm.all(hass, True)
        _LOGGER.info('[init] load config after startup')

    async def async_reload_config():
        """Reload device config for the reload service, logging incomplete devices.

        Returns True when the config file was loaded successfully.
        """
        if not await _async_read_device_config():
            return False
        for p in platform:
            devices = HANDLER[p].vcdm.all(hass, True)
            _LOGGER.info('[service] ------------%s 平台加载设备信息------------\n%s', p, devices)
            # A device with a None or empty-list attribute is incomplete.
            mind_devices = [device for device in devices if None in device.values() or [] in device.values()]
            if mind_devices:
                _LOGGER.debug('!!!!!!!! 以下设备信息不完整,检查值为None的属性 !!!!!!!!')
                for mind_device in mind_devices:
                    _LOGGER.debug('%s', mind_device)
            _LOGGER.info('[service] ------------%s 平台加载设备信息------------\n', p)
        return True

    if CONF_HTTP_PROXY not in conf and CONF_SKILL not in conf:
        _LOGGER.debug('[init] havcs only run in http mode, skip mqtt initialization')
        _LOGGER.info('[init] initialization finished.')

        async def start_havcs_without_mqtt(event: Event):
            """Load device config and register the reload service after HA starts."""
            await async_load_config()

            async def async_handler_service(service):
                """Handle havcs services (only reload is supported)."""
                if service.service == SERVICE_RELOAD:
                    await async_reload_config()

            hass.services.async_register(DOMAIN, SERVICE_RELOAD, async_handler_service, schema=HAVCS_SERVICE_SCHEMA)

        hass.bus.async_listen_once(EVENT_HOMEASSISTANT_START, start_havcs_without_mqtt)
        return True

    setting_conf = conf.get(CONF_SETTING)
    app_key = setting_conf.get(CONF_APP_KEY)
    app_secret = setting_conf.get(CONF_APP_SECRET)
    # AES key: the first 16 bytes of SHA-1(app_secret).
    decrypt_key = sha1(app_secret.encode("utf-8")).digest()[:16]

    allowed_uri = conf.get(CONF_HTTP_PROXY, {}).get(CONF_ALLOWED_URI)
    ha_url = conf.get(CONF_HTTP_PROXY, {}).get(CONF_HA_URL, hass.config.api.base_url)

    sync_device = conf.get(CONF_SKILL, {}).get(CONF_SYNC_DEVICE)
    bind_device = conf.get(CONF_SKILL, {}).get(CONF_BIND_DEVICE)

    broker = setting_conf[CONF_BROKER]
    port = setting_conf[CONF_PORT]
    client_id = setting_conf.get(CONF_CLIENT_ID)
    keepalive = setting_conf[CONF_KEEPALIVE]
    # A ca.crt shipped next to this module takes precedence over the
    # configured certificate.
    certificate = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'ca.crt')
    if os.path.exists(certificate):
        _LOGGER.debug('[init] sucess to autoload ca.crt from %s', certificate)
    else:
        certificate = setting_conf.get(CONF_CERTIFICATE)
    client_key = setting_conf.get(CONF_CLIENT_KEY)
    client_cert = setting_conf.get(CONF_CLIENT_CERT)
    tls_insecure = setting_conf.get(CONF_TLS_INSECURE)
    protocol = setting_conf[CONF_PROTOCOL]

    # For cloudmqtt.com, secured connection, auto fill in certificate.
    # Uses the port parsed from the settings section above.
    if (certificate is None and 19999 < port < 30000 and
            broker.endswith('.cloudmqtt.com')):
        certificate = os.path.join(
            os.path.dirname(__file__), 'addtrustexternalcaroot.crt')

    # When the certificate is set to auto, use bundled certs from requests
    elif certificate == 'auto':
        certificate = requests.certs.where()

    # Will/birth messages are read from the same section they are checked in.
    if CONF_WILL_MESSAGE in setting_conf:
        will_message = mqtt.Message(**setting_conf[CONF_WILL_MESSAGE])
    else:
        will_message = None

    if CONF_BIRTH_MESSAGE in setting_conf:
        birth_message = mqtt.Message(**setting_conf[CONF_BIRTH_MESSAGE])
    else:
        birth_message = None

    # Be able to override versions other than TLSv1.0 under Python3.6
    conf_tls_version = setting_conf.get(CONF_TLS_VERSION)  # type: str
    if conf_tls_version == '1.2':
        tls_version = ssl.PROTOCOL_TLSv1_2
    elif conf_tls_version == '1.1':
        tls_version = ssl.PROTOCOL_TLSv1_1
    elif conf_tls_version == '1.0':
        tls_version = ssl.PROTOCOL_TLSv1
    else:
        import sys
        # Python3.6 supports automatic negotiation of highest TLS version
        if sys.hexversion >= 0x03060000:
            tls_version = ssl.PROTOCOL_TLS  # pylint: disable=no-member
        else:
            tls_version = ssl.PROTOCOL_TLSv1

    hass.data[DATA_HAVCS_MQTT] = mqtt.MQTT(
        hass,
        broker=broker,
        port=port,
        client_id=client_id,
        keepalive=keepalive,
        username=app_key,
        password=app_secret,
        certificate=certificate,
        client_key=client_key,
        client_cert=client_cert,
        tls_insecure=tls_insecure,
        protocol=protocol,
        will_message=will_message,
        birth_message=birth_message,
        tls_version=tls_version,
    )

    success = await hass.data[DATA_HAVCS_MQTT].async_connect()  # type: bool

    # Depending on the MQTT component version, success is True or a status string.
    if not (success is True or success == 'connection_success'):
        _LOGGER.error('[init] can not connect to mqtt server (code = %s), check mqtt server\'s address and port.', success)
        return False

    def _log_task_finish(start_time):
        """Log completion of an MQTT-triggered task, with its duration when known."""
        end_time = datetime.now()
        if start_time is None:
            _LOGGER.debug('[mqtt] -------- mqtt task finish at %s --------', end_time.strftime('%Y-%m-%d %H:%M:%S'))
            return
        _LOGGER.debug('[mqtt] -------- mqtt task finish at %s, Running time: %ss --------', end_time.strftime('%Y-%m-%d %H:%M:%S'), (end_time - start_time).total_seconds())

    async def _async_post_to_skill_server(url, payload, target_platform, action):
        """Encrypt payload as JSON, POST it to url and log the server's reply.

        ``action`` ('bind' or 'report') only selects the wording of the logs.
        """
        data = havcs_util.AESCipher(decrypt_key).encrypt(json.dumps(payload, ensure_ascii = False).encode('utf8'))
        try:
            session = async_get_clientsession(hass, verify_ssl=False)
            with async_timeout.timeout(5, loop=hass.loop):
                response = await session.post(url, data=data)
                _LOGGER.debug('[skill] get %s device result from %s: %s', action, target_platform, await response.text())
        except (asyncio.TimeoutError, aiohttp.ClientError):
            _LOGGER.error("[skill] fail to access %s, %s device fail: timeout", url, action)

    async def _async_publish_response(topic, res):
        """Encrypt a response dict and publish it on the matching /response/ topic."""
        encrypted = havcs_util.AESCipher(decrypt_key).encrypt(json.dumps(res, ensure_ascii = False).encode('utf8'))
        await hass.data[DATA_HAVCS_MQTT].async_publish(topic.replace('/request/','/response/'), encrypted, 2, False)

    async def start_havcs(event: Event):
        """Finish start-up after HA has started.

        Loads devices, optionally binds and reports devices, and registers
        the reload service.
        """
        await async_load_config()

        async def async_bind_device():
            """Push discovered devices for every bound '<user>@<platform>' entry."""
            for uuid in hass.data['havcs_bind_manager'].discovery:
                parts = uuid.split('@')
                if len(parts) < 2:
                    _LOGGER.debug('[skill] skip malformed bind entry: %s', uuid)
                    continue
                p_user_id, target_platform = parts[0], parts[1]
                handler = HANDLER.get(target_platform)
                if (handler is None
                        or not getattr(handler, 'should_report_when_starup', False)
                        or not hasattr(handler, 'bind_device')):
                    continue
                err_result, devices, entity_ids = handler.process_discovery_command()
                if err_result:
                    # Skip this entry but keep binding the remaining ones.
                    continue
                bind_entity_ids, unbind_entity_ids = await hass.data['havcs_bind_manager'].async_save_changed_devices(entity_ids, target_platform, p_user_id, True)
                payload = await handler.bind_device(p_user_id, entity_ids, unbind_entity_ids, devices)
                _LOGGER.debug('[skill] bind device to %s:\nbind_entity_ids = %s, unbind_entity_ids = %s', target_platform, bind_entity_ids, unbind_entity_ids)

                if payload:
                    url = 'https://ai-home.ljr.im/skill/smarthome.php?v=update&AppKey=' + app_key
                    await _async_post_to_skill_server(url, payload, target_platform, 'bind')

        if bind_device:
            await async_bind_device()

        @callback
        def report_device(event):
            """State-change listener: schedule an async device report."""
            hass.add_job(async_report_device(event))

        async def async_report_device(event):
            """Report a changed havcs device to every platform that supports reporting."""
            entity = hass.states.get(event.data[ATTR_ENTITY_ID])
            if entity is None or not entity.attributes.get('havcs_device', False):
                return
            for target_platform, handler in HANDLER.items():
                if not hasattr(handler, 'report_device'):
                    continue
                payload = handler.report_device(entity.entity_id)
                _LOGGER.debug('[skill] report device to %s: platform = %s, entity_id = %s, data = %s', target_platform, target_platform, event.data[ATTR_ENTITY_ID], payload)
                if payload:
                    url = 'https://ai-home.ljr.im/skill/' + target_platform + '.php?v=report&AppKey=' + app_key
                    await _async_post_to_skill_server(url, payload, target_platform, 'report')

        if sync_device:
            hass.bus.async_listen(EVENT_STATE_CHANGED, report_device)

        await hass.data[DATA_HAVCS_MQTT].async_publish("ai-home/http2mqtt2hass/"+app_key+"/response/test", 'init', 2, False)

        async def async_handler_service(service):
            """Handle havcs services: reload device config, then re-bind if enabled."""
            if service.service == SERVICE_RELOAD:
                if await async_reload_config() and bind_device:
                    await async_bind_device()

        hass.services.async_register(DOMAIN, SERVICE_RELOAD, async_handler_service, schema=HAVCS_SERVICE_SCHEMA)

    hass.bus.async_listen_once(EVENT_HOMEASSISTANT_START, start_havcs)

    async def async_stop_mqtt(event: Event):
        """Stop MQTT component."""
        await hass.data[DATA_HAVCS_MQTT].async_disconnect()

    hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, async_stop_mqtt)

    async def async_http_proxy_handler(resData, topic, start_time = None):
        """Forward a proxied request to the local HA HTTP API and publish the reply.

        A request with ``content`` is sent as POST, otherwise as GET. If the
        local request fails (timeout or client error), a status-500
        ``{"error":"time_out"}`` reply is published instead.
        """
        url = ha_url + resData['uri']
        _LOGGER.debug('[http_proxy] request: url = %s', url)
        # Stays None when the local request fails.
        response = None
        if 'content' in resData:
            _LOGGER.debug('[http_proxy] use POST method')
            req_platform = resData.get('platform', havcs_util.get_platform_from_command(resData['content']))
            # partition() also tolerates an Authorization header without a space.
            auth_type, _, auth_value = (resData.get('headers') or {}).get('Authorization', ' ').partition(' ')
            _LOGGER.debug('[http_proxy] platform = %s, auth_type = %s, access_token = %s', req_platform, auth_type, auth_value)

            try:
                session = async_get_clientsession(hass, verify_ssl=False)
                with async_timeout.timeout(5, loop=hass.loop):
                    response = await session.post(url, data=resData['content'], headers = resData.get('headers'))
            except (asyncio.TimeoutError, aiohttp.ClientError):
                _LOGGER.error("[http_proxy] fail to access %s in local network: timeout", url)
        else:
            _LOGGER.debug('[http_proxy] use GET method')
            try:
                session = async_get_clientsession(hass, verify_ssl=False)
                with async_timeout.timeout(5, loop=hass.loop):
                    response = await session.get(url, headers = resData.get('headers'))
            except (asyncio.TimeoutError, aiohttp.ClientError):
                _LOGGER.error("[http_proxy] fail to access %s in local network: timeout", url)

        if response is not None:
            if response.status != 200:
                _LOGGER.error("[http_proxy] fail to access %s in local network: status=%d", url, response.status)
            content_type = response.headers.get('Content-Type', '')
            if 'image' in content_type or 'stream' in content_type:
                # Binary bodies are base64-encoded so they fit in JSON.
                result = b64encode(await response.read()).decode()
            else:
                result = await response.text()
            res = {
                'headers': {'Content-Type': content_type + ';charset=utf-8'},
                'status': response.status,
                # NOTE(review): unicode_escape also mangles raw non-ASCII text;
                # kept because the remote side may rely on this wire format.
                'content': result.encode('utf-8').decode('unicode_escape'),
                'msgId': resData.get('msgId')
            }
        else:
            content_type = None
            res = {
                'status': 500,
                'content': '{"error":"time_out"}',
                'msgId': resData.get('msgId')
            }
        _LOGGER.debug("[http_proxy] response: uri = %s, msgid = %s, type = %s", resData['uri'].split('?')[0], resData.get('msgId'), content_type)

        await _async_publish_response(topic, res)
        _log_task_finish(start_time)

    async def async_module_handler(resData, topic, start_time = None):
        """Run a skill command through its platform handler and publish the reply."""
        if 'platform' in resData:
            req_platform = resData['platform']
        else:
            req_platform = havcs_util.get_platform_from_command(resData['content'])
        if req_platform == 'unknown':
            _LOGGER.error('[skill] receive command from unsupport platform "%s".', req_platform)
            return
        if req_platform not in HANDLER:
            _LOGGER.error('[skill] receive command from uninitialized platform "%s" , check up your configuration.yaml.', req_platform)
            return
        try:
            response = await HANDLER[req_platform].handleRequest(json.loads(resData['content']), auth = True)
        except Exception:  # pylint: disable=broad-except
            response = '{"error":"service error"}'
            import traceback
            _LOGGER.error('[skill] %s', traceback.format_exc())
        res = {
            'headers': {'Content-Type': 'application/json;charset=utf-8'},
            'status': 200,
            'content': json.dumps(response).encode('utf-8').decode('unicode_escape'),
            'msgId': resData.get('msgId')
        }
        await _async_publish_response(topic, res)
        _log_task_finish(start_time)

    async def async_publish_error(resData, topic):
        """Publish an empty 404 reply for a request that was rejected."""
        res = {
            'headers': {'Content-Type': 'application/json;charset=utf-8'},
            'status': 404,
            'content': '',
            'msgId': resData.get('msgId')
        }
        await _async_publish_response(topic, res)

    def _dispatch_request(topic, payload, start_time):
        """Decrypt and parse payload, then schedule the matching handler.

        Returns True when a handler task was scheduled (that task logs its
        own completion). Returns False when the message was fully consumed here.
        """
        req = json.loads(havcs_util.AESCipher(decrypt_key).decrypt(payload))
        if req.get('msgType') == 'hello':
            _LOGGER.info('[mqtt] get hello message: %s', req.get('content'))
            return False

        _LOGGER.debug("[mqtt] raw message: %s", req)
        if req.get('platform') == 'h2m2h':
            if 'http_proxy' not in MODE:
                _LOGGER.info('[http_proxy] havcs not run in http_proxy mode, ignore request: %s', req)
                return False
            if allowed_uri and req.get('uri','/').split('?')[0] not in allowed_uri:
                _LOGGER.info('[http_proxy] uri not allowed: %s', req.get('uri','/'))
                hass.add_job(async_publish_error(req, topic))
                return False
            hass.add_job(async_http_proxy_handler(req, topic, start_time))
            return True

        if 'skill' not in MODE:
            _LOGGER.info('[skill] havcs not run in skill mode, ignore request: %s', req)
            return False
        hass.add_job(async_module_handler(req, topic, start_time))
        return True

    @callback
    def message_received(*args):  # HA 0.90 changed the callback arguments
        """Handle new MQTT state messages.

        Accepts both the ``(topic, payload, qos)`` signature and a single
        message object carrying ``topic`` and ``payload`` attributes.
        """
        if isinstance(args[0], str):
            topic, payload = args[0], args[1]
        else:
            topic, payload = args[0].topic, args[0].payload

        start_time = datetime.now()
        _LOGGER.debug('[mqtt] -------- start handle task from mqtt at %s --------', start_time.strftime('%Y-%m-%d %H:%M:%S'))
        try:
            scheduled = _dispatch_request(topic, payload, start_time)
        except (JSONDecodeError, UnicodeDecodeError, binascii.Error):
            import sys
            ex_type, ex_val, ex_stack = sys.exc_info()
            log = ''.join(str(stack) for stack in traceback.extract_tb(ex_stack))
            _LOGGER.debug('[mqtt] fail to decrypt message, abandon[%s][%s]: %s', ex_type, ex_val, log)
            scheduled = False
        except Exception:  # pylint: disable=broad-except
            _LOGGER.error('[mqtt] fail to handle %s', traceback.format_exc())
            scheduled = False
        if not scheduled:
            _log_task_finish(start_time)

    await hass.data[DATA_HAVCS_MQTT].async_subscribe("ai-home/http2mqtt2hass/"+app_key+"/request/#", message_received, 2, 'utf-8')
    _LOGGER.info('[init] initialization finished, waiting for welcome message of mqtt server.')

    return True
コード例 #4
0
ファイル: __init__.py プロジェクト: zhangziran/aihome
async def async_setup_entry(hass, entry):
    """Load a config entry and start the AI-Home MQTT bridge.

    Merges the ``mqtt`` section of the YAML configuration (kept under
    ``hass.data[DATA_AIHOME_CONFIG]``) with ``entry.data``, connects an
    ``mqtt.MQTT`` client that authenticates with the app key/secret, and
    subscribes to ``ai-home/http2mqtt2hass/<app_key>/request/#``.  Incoming
    messages are decrypted with ``aihome_util.AESCipher`` and either proxied
    to the local Home Assistant HTTP API (``platform == 'h2m2h'``) or
    dispatched to a skill handler in ``HANDLER``; every reply is encrypted
    and published on the matching ``/response/`` topic.

    Args:
        hass: The Home Assistant instance.
        entry: The config entry being set up.

    Returns:
        ``True`` once the client is connected and subscribed; ``False`` when
        the entry is an orphaned YAML import (its removal is scheduled) or
        the MQTT connection fails.
    """
    import sys
    import traceback

    # DATA_AIHOME_CONFIG is missing when configuration.yaml has no aihome
    # section; fall back to {} so conf becomes None instead of raising.
    conf = (hass.data.get(DATA_AIHOME_CONFIG) or {}).get(CONF_MQTT)

    # Config entry was created because user had configuration.yaml entry
    # They removed that, so remove entry.
    if conf is None and entry.source == config_entries.SOURCE_IMPORT:
        hass.async_create_task(hass.config_entries.async_remove(
            entry.entry_id))
        return False

    # If user didn't have configuration.yaml config, generate defaults
    if conf is None:
        conf = CONFIG_SCHEMA({
            DOMAIN: entry.data,
        })[DOMAIN].get(CONF_MQTT)
    elif any(key in conf for key in entry.data):
        _LOGGER.warning(
            "Data in your config entry is going to override your "
            "configuration.yaml: %s", entry.data)

    conf.update(entry.data)

    broker = conf[CONF_BROKER]
    port = conf[CONF_PORT]
    client_id = conf.get(CONF_CLIENT_ID)
    keepalive = conf[CONF_KEEPALIVE]
    app_key = conf.get(CONF_APP_KEY)
    app_secret = conf.get(CONF_APP_SECRET)
    # A ca.crt shipped next to this module takes precedence over any
    # certificate given in the configuration.
    certificate = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                               'ca.crt')
    if os.path.exists(certificate):
        _LOGGER.info('auto load ca.crt from %s', certificate)
    else:
        certificate = conf.get(CONF_CERTIFICATE)
    client_key = conf.get(CONF_CLIENT_KEY)
    client_cert = conf.get(CONF_CLIENT_CERT)
    tls_insecure = conf.get(CONF_TLS_INSECURE)
    protocol = conf[CONF_PROTOCOL]
    allowed_uri = conf.get(CONF_ALLOWED_URI)
    _LOGGER.info('allowed_uri: %s', allowed_uri)
    ha_url = conf.get(CONF_HA_URL)
    sync = conf.get(CONF_SYNC)
    # AESCipher key for every payload: first 16 bytes of SHA-1(app_secret).
    decrypt_key = sha1(app_secret.encode("utf-8")).digest()[0:16]

    # For cloudmqtt.com, secured connection, auto fill in certificate
    if (certificate is None and 19999 < conf[CONF_PORT] < 30000
            and broker.endswith('.cloudmqtt.com')):
        certificate = os.path.join(os.path.dirname(__file__),
                                   'addtrustexternalcaroot.crt')

    # When the certificate is set to auto, use bundled certs from requests
    elif certificate == 'auto':
        certificate = requests.certs.where()

    if CONF_WILL_MESSAGE in conf:
        will_message = mqtt.Message(**conf[CONF_WILL_MESSAGE])
    else:
        will_message = None

    if CONF_BIRTH_MESSAGE in conf:
        birth_message = mqtt.Message(**conf[CONF_BIRTH_MESSAGE])
    else:
        birth_message = None

    # Be able to override versions other than TLSv1.0 under Python3.6
    conf_tls_version = conf.get(CONF_TLS_VERSION)  # type: str
    if conf_tls_version == '1.2':
        tls_version = ssl.PROTOCOL_TLSv1_2
    elif conf_tls_version == '1.1':
        tls_version = ssl.PROTOCOL_TLSv1_1
    elif conf_tls_version == '1.0':
        tls_version = ssl.PROTOCOL_TLSv1
    else:
        # Python3.6 supports automatic negotiation of highest TLS version
        if sys.hexversion >= 0x03060000:
            tls_version = ssl.PROTOCOL_TLS  # pylint: disable=no-member
        else:
            tls_version = ssl.PROTOCOL_TLSv1

    hass.data[DATA_AIHOME_MQTT] = mqtt.MQTT(
        hass,
        broker=broker,
        port=port,
        client_id=client_id,
        keepalive=keepalive,
        username=app_key,
        password=app_secret,
        certificate=certificate,
        client_key=client_key,
        client_cert=client_cert,
        tls_insecure=tls_insecure,
        protocol=protocol,
        will_message=will_message,
        birth_message=birth_message,
        tls_version=tls_version,
    )

    success = await hass.data[DATA_AIHOME_MQTT].async_connect()  # type: bool

    if not success:
        _LOGGER.error(
            'can not connect to mqtt server, check mqtt server\'s address and port.'
        )
        return False

    def _encrypt(payload):
        """Serialize *payload* as UTF-8 JSON and encrypt it with decrypt_key."""
        return aihome_util.AESCipher(decrypt_key).encrypt(
            json.dumps(payload, ensure_ascii=False).encode('utf8'))

    async def _async_post_encrypted(platform, url, payload):
        """POST the encrypted *payload* to *url* and log the response body.

        Timeouts (5 s) and aiohttp client errors are logged, not raised.
        """
        data = _encrypt(payload)
        try:
            session = async_get_clientsession(hass, verify_ssl=False)
            with async_timeout.timeout(5, loop=hass.loop):
                response = await session.post(url, data=data)
                _LOGGER.debug('[%s] report response:%s',
                              platform, await response.text())
        except (asyncio.TimeoutError, aiohttp.ClientError):
            _LOGGER.error("Error while accessing: %s", url)

    async def _async_publish_response(res, topic):
        """Encrypt *res* and publish it on the /response/ twin of *topic*."""
        await hass.data[DATA_AIHOME_MQTT].async_publish(
            topic.replace('/request/', '/response/'), _encrypt(res), 2, False)

    async def start_aihome(event: Event):
        """Report bound devices, optionally start state sync, send 'init'."""
        async def async_bind_device():
            """Send a bind report for each discovered platform that wants one."""
            for uuid in hass.data['aihome_bind_manager'].discovery:
                # Discovery keys look like "<user_id>@<platform>".
                parts = uuid.split('@')
                p_user_id, platform = parts[0], parts[1]
                handler = HANDLER.get(platform)
                if (platform in HANDLER
                        and getattr(handler, 'should_report_when_starup', False)
                        and hasattr(handler, 'bind_device')):
                    devices, entity_ids = handler._discoveryDevice()
                    bind_entity_ids, unbind_entity_ids = await hass.data[
                        'aihome_bind_manager'].async_save_changed_devices(
                            entity_ids, platform, p_user_id, True)
                    # NOTE(review): passes all entity_ids, not bind_entity_ids;
                    # confirm against the handler's bind_device signature.
                    payload = await handler.bind_device(
                        p_user_id, entity_ids, unbind_entity_ids, devices)
                    _LOGGER.debug(
                        '[%s] report request: bind_entity_ids:%s, unbind_entity_ids:%s',
                        platform, bind_entity_ids, unbind_entity_ids)

                    if payload:
                        url = 'https://ai-home.ljr.im/skill/smarthome.php?v=update&AppKey=' + app_key
                        await _async_post_encrypted(platform, url, payload)

        await async_bind_device()

        @callback
        def report_device(event):
            """Schedule a state report for the entity of a state_changed event."""
            _LOGGER.debug('%s changed', event.data[ATTR_ENTITY_ID])
            hass.add_job(async_report_device(event))

        async def async_report_device(event):
            """Report an aihome device's state to every handler supporting it."""
            entity = hass.states.get(event.data[ATTR_ENTITY_ID])
            # hass.states.get returns None when the entity no longer exists.
            if entity is None or not entity.attributes.get('aihome_device', False):
                return
            for platform, handler in HANDLER.items():
                if not hasattr(handler, 'report_device'):
                    continue
                payload = handler.report_device(entity.entity_id)
                _LOGGER.debug('[%s] report device: %s, %s, %s', platform,
                              event.data[ATTR_ENTITY_ID], platform,
                              payload)
                if payload:
                    url = 'https://ai-home.ljr.im/skill/' + platform + '.php?v=report&AppKey=' + app_key
                    await _async_post_encrypted(platform, url, payload)

        if sync:
            hass.bus.async_listen(EVENT_STATE_CHANGED, report_device)

        await hass.data[DATA_AIHOME_MQTT].async_publish(
            "ai-home/http2mqtt2hass/" + app_key + "/response/test", 'init', 2,
            False)

    hass.bus.async_listen_once(EVENT_HOMEASSISTANT_START, start_aihome)

    async def async_stop_mqtt(event: Event):
        """Stop MQTT component."""
        await hass.data[DATA_AIHOME_MQTT].async_disconnect()

    hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, async_stop_mqtt)

    async def async_http_handler(resData, topic):
        """Forward a proxied HTTP request to Home Assistant and publish the reply.

        ``resData`` carries ``uri`` and optionally ``headers``, ``content``
        and ``msgId``; the presence of ``content`` selects POST, otherwise
        GET is used.  The reply (status, Content-Type and body, with image or
        stream bodies base64-encoded) is encrypted and published; when the
        local request fails a 500 ``time_out`` reply is published instead.
        """
        # Initialised up front so a failed request leaves it None, not unbound.
        response = None
        url = ha_url + resData['uri']
        if 'content' in resData:
            _LOGGER.debug('---POST---')
            if 'AliGenie' in resData['content']:
                platform = 'aligenie'
            elif 'DuerOS' in resData['content']:
                platform = 'dueros'
            elif 'Alpha' in resData['content']:
                platform = 'jdwhale'
            else:
                platform = 'unknown'
            if platform in EXPIRATION:
                # Apply this platform's access-token lifetime to the refresh
                # token that issued the bearer token; the EXPIRATION entry is
                # popped once it has been applied.
                authorization = (resData.get('headers')
                                 or {}).get('Authorization') or ''
                # Header is "<type> <token>"; a missing or malformed header
                # yields an empty token, which jwt.decode rejects with an
                # InvalidTokenError subclass.
                auth_value = authorization.partition(' ')[2]
                try:
                    unverif_claims = jwt.decode(auth_value, verify=False)
                    refresh_token = await hass.auth.async_get_refresh_token(
                        cast(str, unverif_claims.get('iss')))
                    if refresh_token is not None:
                        refresh_token.access_token_expiration = EXPIRATION[
                            platform]
                        for user in hass.auth._store._users.values():
                            if refresh_token.id in user.refresh_tokens:
                                user.refresh_tokens[
                                    refresh_token.id] = refresh_token
                                hass.auth._store._async_schedule_save()
                                EXPIRATION.pop(platform)
                                break
                except jwt.InvalidTokenError:
                    pass

            try:
                session = async_get_clientsession(hass, verify_ssl=False)
                with async_timeout.timeout(5, loop=hass.loop):
                    response = await session.post(
                        url,
                        data=resData['content'],
                        headers=resData.get('headers'))
            except (asyncio.TimeoutError, aiohttp.ClientError):
                _LOGGER.error("Error while accessing: %s", url)

        else:
            _LOGGER.debug('---GET---')
            try:
                session = async_get_clientsession(hass, verify_ssl=False)
                with async_timeout.timeout(5, loop=hass.loop):
                    response = await session.get(
                        url, headers=resData.get('headers'))
            except (asyncio.TimeoutError, aiohttp.ClientError):
                _LOGGER.error("Error while accessing: %s", url)

        if response is not None:
            if response.status != 200:
                _LOGGER.error("Error while accessing: %s, status=%d", url,
                              response.status)
            # Tolerate responses that carry no Content-Type header.
            content_type = response.headers.get('Content-Type', '')
            if 'image' in content_type or 'stream' in content_type:
                # Binary bodies travel as base64 text inside the JSON envelope.
                result = b64encode(await response.read()).decode()
            else:
                result = await response.text()
            res = {
                'headers': {
                    'Content-Type': content_type + ';charset=utf-8'
                },
                'status': response.status,
                # NOTE(review): unicode_escape re-decoding mangles non-ASCII
                # text; kept as-is for compatibility with the remote side.
                'content': result.encode('utf-8').decode('unicode_escape'),
                'msgId': resData.get('msgId')
            }
        else:
            content_type = None
            res = {
                'status': 500,
                'content': '{"error":"time_out"}',
                'msgId': resData.get('msgId')
            }
        _LOGGER.debug(
            "%s response[%s]: [%s]",
            resData['uri'].split('/')[-1].split('?')[0],
            resData.get('msgId'),
            content_type,
        )
        await _async_publish_response(res, topic)

    async def async_module_handler(resData, topic):
        """Dispatch a skill command to its platform handler and publish the reply.

        The platform is taken from ``resData['platform']`` or sniffed from the
        command content; unknown or unconfigured platforms are logged and
        dropped without a reply.
        """
        if 'platform' in resData:
            platform = resData['platform']
        elif 'AliGenie' in resData['content']:
            platform = 'aligenie'
        elif 'DuerOS' in resData['content']:
            platform = 'dueros'
        elif 'Alpha' in resData['content']:
            platform = 'jdwhale'
        else:
            platform = 'unknown'
            _LOGGER.error('receive command from unsupport platform "%s".',
                          platform)
            return
        if platform not in HANDLER:
            _LOGGER.error(
                'receive command from uninitialized platform "%s" , check up your configuration.yaml.',
                platform)
            return
        try:
            response = await HANDLER[platform].handleRequest(
                json.loads(resData['content']), ignoreToken=True)
        except Exception:  # handler failures become a "service error" reply
            response = '{"error":"service error"}'
            _LOGGER.error(traceback.format_exc())
        res = {
            'headers': {
                'Content-Type': 'application/json;charset=utf-8'
            },
            'status': 200,
            'content':
            json.dumps(response).encode('utf-8').decode('unicode_escape'),
            'msgId': resData.get('msgId')
        }
        await _async_publish_response(res, topic)

    async def async_publish_error(resData, topic):
        """Publish an empty 404 reply for a request that is not allowed."""
        res = {
            'headers': {
                'Content-Type': 'application/json;charset=utf-8'
            },
            'status': 404,
            'content': '',
            'msgId': resData.get('msgId')
        }
        await _async_publish_response(res, topic)

    @callback
    def message_received(*args):  # callback arguments changed in HA 0.90
        """Handle new MQTT state messages.

        Accepts either positional ``(topic, payload, qos)`` arguments or a
        single message object exposing ``.topic`` and ``.payload``.  The
        payload is decrypted and parsed as JSON; "hello" messages are only
        logged, everything else is scheduled on a handler.  Payloads that
        fail to decrypt or decode are dropped with a debug log.
        """
        if isinstance(args[0], str):
            topic, payload = args[0], args[1]
        else:
            topic, payload = args[0].topic, args[0].payload
        try:
            payload = aihome_util.AESCipher(decrypt_key).decrypt(payload)
            req = json.loads(payload)
            if req.get('msgType') == 'hello':
                _LOGGER.info(req.get('content'))
                return
            _LOGGER.debug("[%s] raw message: %s", req.get('platform'), req)
            if req.get('platform') == 'h2m2h':
                if (allowed_uri and req.get('uri', '/').split('?')[0]
                        not in allowed_uri):
                    _LOGGER.debug('uri not allowed: %s', req.get('uri', '/'))
                    hass.add_job(async_publish_error(req, topic))
                    return
                hass.add_job(async_http_handler(req, topic))
            else:
                hass.add_job(async_module_handler(req, topic))

        except (JSONDecodeError, UnicodeDecodeError, binascii.Error) as err:
            log = ''.join(
                str(frame) for frame in traceback.extract_tb(err.__traceback__))
            _LOGGER.debug('decrypt failure, abandon:%s', log)

    await hass.data[DATA_AIHOME_MQTT].async_subscribe(
        "ai-home/http2mqtt2hass/" + app_key + "/request/#", message_received,
        2, 'utf-8')
    return True
コード例 #5
0
ファイル: __init__.py プロジェクト: pc3806135/hass
async def async_setup_entry(hass, config_entry):
    """Load a config entry."""

    hass.data.setdefault(DOMAIN, {})
    hass.data[DOMAIN].setdefault(DATA_HAVCS_HANDLER, {})
    hass.data[DOMAIN].setdefault(DATA_HAVCS_CONFIG, {})
    hass.data[DOMAIN].setdefault(DATA_HAVCS_ITEMS, {})
    conf = hass.data[DOMAIN].get(DATA_HAVCS_CONFIG)

    # Config entry was created because user had configuration.yaml entry
    # They removed that, so remove entry.
    if config_entry.source == config_entries.SOURCE_IMPORT:
        if conf is None:
            hass.async_create_task(
                hass.config_entries.async_remove(config_entry.entry_id))
            return False

    elif config_entry.source == SOURCE_PLATFORM:
        if conf is None:
            if [
                    entry
                    for entry in hass.config_entries.async_entries(DOMAIN)
                    if entry.source == config_entries.SOURCE_USER
            ]:
                return True
            else:
                hass.async_create_task(
                    hass.config_entries.async_remove(config_entry.entry_id))
                return False
        else:
            return True

    # If user didn't have configuration.yaml config, generate defaults
    elif config_entry.source == config_entries.SOURCE_USER:
        if not conf:
            conf = CONFIG_SCHEMA({DOMAIN: dict(config_entry.data)})[DOMAIN]
        elif any(key in conf for key in config_entry.data):
            _LOGGER.warning(
                "[init] Data in your config entry is going to override your "
                "configuration.yaml: %s", config_entry.data)

        # conf.update(config_entry.data)

        for key in config_entry.data:
            if key in conf:
                if isinstance(conf[key], dict):
                    conf[key].update(config_entry.data[key])
                else:
                    conf[key] = config_entry.data[key]
            else:
                conf[key] = config_entry.data[key]
        if CONF_HTTP not in config_entry.data and CONF_HTTP in conf:
            conf.pop(CONF_HTTP)
        if CONF_HTTP_PROXY not in config_entry.data and CONF_HTTP_PROXY in conf:
            conf.pop(CONF_HTTP_PROXY)
        if CONF_SKILL not in config_entry.data and CONF_SKILL in conf:
            conf.pop(CONF_SKILL)

    http_manager = hass.data[DOMAIN][
        DATA_HAVCS_HTTP_MANAGER] = HavcsHttpManager(
            hass,
            conf.get(CONF_HTTP, {}).get(CONF_HA_URL, hass.config.api.base_url),
            DEVICE_CONFIG_SCHEMA)
    if CONF_HTTP in conf:
        if conf.get(CONF_HTTP) is None:
            conf[CONF_HTTP] = HTTP_SCHEMA({})
        http_manager.set_expiration(
            timedelta(hours=conf.get(CONF_HTTP).get(CONF_EXPIRE_IN_HOURS,
                                                    DEFAULT_EXPIRE_IN_HOURS)))
        http_manager.register_auth_authorize()
        http_manager.register_auth_token()
        http_manager.register_service()
        _LOGGER.info("[init] havcs enable \"http mode\"")

        MODE.append('http')
    if CONF_HTTP_PROXY in conf:
        if conf.get(CONF_HTTP_PROXY) is None:
            conf[CONF_HTTP_PROXY] = HTTP_PROXY({})
        _LOGGER.info("[init] havcs enable \"http_proxy mode\"")
        if CONF_SETTING not in conf:
            _LOGGER.error(
                "[init] fail to start havcs: http_proxy mode require mqtt congfiguration"
            )
            return False
        MODE.append('http_proxy')
    if CONF_SKILL in conf:
        if conf.get(CONF_SKILL) is None:
            conf[CONF_SKILL] = SKILL_SCHEMA({})
        _LOGGER.info("[init] havcs enable \"skill mode\"")
        if CONF_SETTING not in conf:
            _LOGGER.error(
                "[init] fail to start havcs: skill mode require mqtt congfiguration"
            )
            return False
        MODE.append('skill')

    havcs_util.ENTITY_KEY = conf.get(CONF_SETTING, {}).get(CONF_ENTITY_KEY)
    havcs_util.CONTEXT_HAVCS = Context(
        conf.get(CONF_SETTING, {}).get(CONF_USER_ID))

    platforms = conf.get(CONF_PLATFORM)

    device_config = conf.get(CONF_DEVICE_CONFIG)
    if device_config == 'text':
        havcd_config_path = os.path.join(hass.config.config_dir, 'havcs.yaml')
        if not os.path.isfile(havcd_config_path):
            with open(havcd_config_path, "wt") as havcd_config_file:
                havcd_config_file.write('')
        hass.components.frontend.async_remove_panel(DOMAIN)
    else:
        havcd_config_path = os.path.join(hass.config.config_dir,
                                         'havcs-ui.yaml')
        if not os.path.isfile(havcd_config_path):
            if os.path.isfile(
                    os.path.join(hass.config.config_dir, 'havcs.yaml')):
                shutil.copyfile(
                    os.path.join(hass.config.config_dir, 'havcs.yaml'),
                    havcd_config_path)
            else:
                with open(havcd_config_path, "wt") as havcd_config_file:
                    havcd_config_file.write('')
        http_manager.register_deivce_manager()
    hass.data[DOMAIN][CONF_DEVICE_CONFIG_PATH] = havcd_config_path

    sync_device = conf.get(CONF_SKILL, {}).get(CONF_SYNC_DEVICE)
    bind_device = conf.get(CONF_SKILL, {}).get(CONF_BIND_DEVICE)

    if CONF_HTTP_PROXY not in conf and CONF_SKILL not in conf:
        _LOGGER.debug(
            "[init] havcs only run in http mode, skip mqtt initialization")

        ha_url = conf.get(CONF_HTTP, {}).get(CONF_HA_URL,
                                             hass.config.api.base_url)

        try:
            session = async_get_clientsession(hass, verify_ssl=False)
            with async_timeout.timeout(5, loop=hass.loop):
                response = await session.post(ha_url + '/havcs_auth')
            console.log(response.status)
        except (asyncio.TimeoutError, aiohttp.ClientError):
            _LOGGER.error(
                "[auth] fail to get token, access %s in local network: timeout",
                ha_url)
        except:
            _LOGGER.error(
                "[auth] fail to get token, access %s in local network: %s",
                ha_url, traceback.format_exc())
    else:
        setting_conf = conf.get(CONF_SETTING)
        app_key = setting_conf.get(CONF_APP_KEY)
        app_secret = setting_conf.get(CONF_APP_SECRET)
        decrypt_key = bytes().fromhex(
            sha1(app_secret.encode("utf-8")).hexdigest())[0:16]

        if platforms:
            bind_manager = hass.data[DOMAIN][
                DATA_HAVCS_BIND_MANAGER] = HavcsBindManager(
                    hass, platforms, bind_device, sync_device, app_key,
                    decrypt_key)
            await bind_manager.async_init()

        allowed_uri = conf.get(CONF_HTTP_PROXY, {}).get(CONF_ALLOWED_URI)
        ha_url = conf.get(CONF_HTTP_PROXY, {}).get(CONF_HA_URL,
                                                   hass.config.api.base_url)

        broker = setting_conf[CONF_BROKER]
        port = setting_conf[CONF_PORT]
        client_id = setting_conf.get(CONF_CLIENT_ID)
        keepalive = setting_conf[CONF_KEEPALIVE]
        certificate = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                                   'ca.crt')
        if os.path.exists(certificate):
            _LOGGER.debug("[init] sucess to autoload ca.crt from %s",
                          certificate)
        else:
            certificate = setting_conf.get(CONF_CERTIFICATE)
        client_key = setting_conf.get(CONF_CLIENT_KEY)
        client_cert = setting_conf.get(CONF_CLIENT_CERT)
        tls_insecure = setting_conf.get(CONF_TLS_INSECURE)
        protocol = setting_conf[CONF_PROTOCOL]

        # For cloudmqtt.com, secured connection, auto fill in certificate
        if (certificate is None and 19999 < conf[CONF_PORT] < 30000
                and broker.endswith('.cloudmqtt.com')):
            certificate = os.path.join(os.path.dirname(__file__),
                                       'addtrustexternalcaroot.crt')

        # When the certificate is set to auto, use bundled certs from requests
        elif certificate == 'auto':
            certificate = requests.certs.where()

        if CONF_WILL_MESSAGE in setting_conf:
            will_message = mqtt.Message(**conf[CONF_WILL_MESSAGE])
        else:
            will_message = None

        if CONF_BIRTH_MESSAGE in setting_conf:
            birth_message = mqtt.Message(**conf[CONF_BIRTH_MESSAGE])
        else:
            birth_message = None

        # Be able to override versions other than TLSv1.0 under Python3.6
        conf_tls_version = setting_conf.get(CONF_TLS_VERSION)  # type: str
        if conf_tls_version == '1.2':
            tls_version = ssl.PROTOCOL_TLSv1_2
        elif conf_tls_version == '1.1':
            tls_version = ssl.PROTOCOL_TLSv1_1
        elif conf_tls_version == '1.0':
            tls_version = ssl.PROTOCOL_TLSv1
        else:
            import sys
            # Python3.6 supports automatic negotiation of highest TLS version
            if sys.hexversion >= 0x03060000:
                tls_version = ssl.PROTOCOL_TLS  # pylint: disable=no-member
            else:
                tls_version = ssl.PROTOCOL_TLSv1

        hass.data[DOMAIN][DATA_HAVCS_MQTT] = mqtt.MQTT(
            hass,
            broker=broker,
            port=port,
            client_id=client_id,
            keepalive=keepalive,
            username=app_key,
            password=app_secret,
            certificate=certificate,
            client_key=client_key,
            client_cert=client_cert,
            tls_insecure=tls_insecure,
            protocol=protocol,
            will_message=will_message,
            birth_message=birth_message,
            tls_version=tls_version,
        )
        _LOGGER.debug("[init] connecting to mqtt server")
        success = await hass.data[DOMAIN][DATA_HAVCS_MQTT].async_connect(
        )  # type: bool

        if success is True or success == 'connection_success':
            pass
        else:
            import hashlib
            md5_l = hashlib.md5()
            with open(certificate, mode="rb") as f:
                by = f.read()
            md5_l.update(by)
            local_ca_md5 = md5_l.hexdigest()
            _LOGGER.debug("[init] local ca.crt md5 %s", local_ca_md5)
            from urllib.request import urlopen
            try:
                response = urlopen(
                    'https://raw.githubusercontent.com/cnk700i/havcs/master/custom_components/havcs/ca.crt',
                    timeout=5)
                ca_bytes = response.read()
                md5_l = hashlib.md5()
                md5_l.update(ca_bytes)
                latest_ca_md5 = md5_l.hexdigest()
                if local_ca_md5 != latest_ca_md5:
                    _LOGGER.error(
                        "[init] can not connect to mqtt server(host = %s, port = %s, error_code = %s), try update ca.crt file ",
                        broker, port, success)
                else:
                    _LOGGER.error(
                        "[init] can not connect to mqtt server(host = %s, port = %s, error_code = %s), check mqtt server's address and port ",
                        broker, port, success)
            except:
                _LOGGER.error("[init] fail to  check whether ca.crt is latest")
            await async_unload_entry(hass, config_entry)
            return False

        async def async_stop_mqtt(event: Event):
            """Stop MQTT component."""
            await hass.data[DOMAIN][DATA_HAVCS_MQTT].async_disconnect()

        hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, async_stop_mqtt)

        async def async_http_proxy_handler(resData, topic, start_time=None):
            response = None
            url = ha_url + resData['uri']
            _LOGGER.debug("[http_proxy] request: url = %s", url)
            if ('content' in resData):
                _LOGGER.debug("[http_proxy] use POST method")
                platform = resData.get(
                    'platform',
                    havcs_util.get_platform_from_command(resData['content']))
                auth_type, auth_value = resData.get('headers', {}).get(
                    'Authorization', ' ').split(' ', 1)
                _LOGGER.debug(
                    "[http_proxy] platform = %s, auth_type = %s, access_token = %s",
                    platform, auth_type, auth_value)

                try:
                    session = async_get_clientsession(hass, verify_ssl=False)
                    with async_timeout.timeout(5, loop=hass.loop):
                        response = await session.post(
                            url,
                            data=resData['content'],
                            headers=resData.get('headers'))
                except (asyncio.TimeoutError, aiohttp.ClientError):
                    _LOGGER.error(
                        "[http_proxy] fail to access %s in local network: timeout",
                        url)
                except:
                    _LOGGER.error(
                        "[http_proxy] fail to access %s in local network: %s",
                        url, traceback.format_exc())
            else:
                _LOGGER.debug("[http_proxy] use GET method")
                try:
                    session = async_get_clientsession(hass, verify_ssl=False)
                    with async_timeout.timeout(5, loop=hass.loop):
                        response = await session.get(
                            url, headers=resData.get('headers'))
                except (asyncio.TimeoutError, aiohttp.ClientError):
                    _LOGGER.error(
                        "[http_proxy] fail to access %s in local network: timeout",
                        url)
                except:
                    _LOGGER.error(
                        "[http_proxy] fail to access %s in local network: %s",
                        url, traceback.format_exc())
                # _LOGGER.debug("[http_proxy] %s", response.history) #查看重定向信息
            if response is not None:
                if response.status != 200:
                    _LOGGER.error(
                        "[http_proxy] fail to access %s in local network: status=%d",
                        url, response.status)
                if ('image' in response.headers['Content-Type']
                        or 'stream' in response.headers['Content-Type']):
                    result = await response.read()
                    result = b64encode(result).decode()
                else:
                    result = await response.text()
                headers = {
                    'Content-Type':
                    response.headers['Content-Type'] + ';charset=utf-8'
                }
                res = {
                    'headers': headers,
                    'status': response.status,
                    'content': result.encode('utf-8').decode('unicode_escape'),
                    'msgId': resData.get('msgId')
                }
                _LOGGER.debug(
                    "[http_proxy] response: uri = %s, msgid = %s, type = %s",
                    resData['uri'].split('?')[0], resData.get('msgId'),
                    response.headers['Content-Type'])
            else:
                res = {
                    'status': 500,
                    'content': '{"error":"time_out"}',
                    'msgId': resData.get('msgId')
                }
                _LOGGER.debug("[http_proxy] response: uri = %s, msgid = %s",
                              resData['uri'].split('?')[0],
                              resData.get('msgId'))
            res = havcs_util.AESCipher(decrypt_key).encrypt(
                json.dumps(res, ensure_ascii=False).encode('utf8'))

            await hass.data[DOMAIN][DATA_HAVCS_MQTT].async_publish(
                topic.replace('/request/', '/response/'), res, 2, False)
            end_time = datetime.now()
            _LOGGER.debug(
                "[mqtt] -------- mqtt task finish at %s, Running time: %ss --------",
                end_time.strftime('%Y-%m-%d %H:%M:%S'),
                (end_time - start_time).total_seconds())

        async def async_module_handler(resData, topic, start_time=None):
            """Handle a skill request received over MQTT and publish the reply.

            The target platform is ``resData['platform']`` when present,
            otherwise it is derived from the command in ``resData['content']``.
            The decoded command is passed to that platform's handler and the
            result is published, AES-encrypted with ``decrypt_key``, on the
            topic obtained by replacing ``/request/`` with ``/response/``.

            Args:
                resData: Decrypted request dict; ``platform``, ``content`` and
                    ``msgId`` are read from it.
                topic: MQTT topic the request arrived on.
                start_time: When handling started; only used for the timing
                    log. Defaults to the current time when not given.
            """
            if start_time is None:
                start_time = datetime.now()

            # Derive the platform from the command only when the request does
            # not name it, so a command the helper cannot parse does not fail
            # a request that already states its platform.
            if 'platform' in resData:
                platform = resData['platform']
            else:
                platform = havcs_util.get_platform_from_command(
                    resData['content'])
            if platform == 'unknown':
                _LOGGER.error(
                    "[skill] receive command from unsupport platform \"%s\"",
                    platform)
                return
            handlers = hass.data[DOMAIN][DATA_HAVCS_HANDLER]
            if platform not in handlers:
                _LOGGER.error(
                    "[skill] receive command from uninitialized platform \"%s\" , check up your configuration.yaml",
                    platform)
                return
            try:
                response = await handlers[platform].handleRequest(
                    json.loads(resData['content']), auth=True)
            except Exception:  # any handler failure is reported to the caller
                # A dict (not a pre-serialized string) so the reply content is
                # a JSON object instead of a doubly-encoded JSON string.
                response = {'error': 'service error'}
                _LOGGER.error("[skill] %s", traceback.format_exc())
            res = {
                'headers': {
                    'Content-Type': 'application/json;charset=utf-8'
                },
                'status': 200,
                # ensure_ascii=False keeps non-ASCII text readable while the
                # output stays valid JSON (a 'unicode_escape' round-trip would
                # un-escape quotes/backslashes and split surrogate pairs).
                'content': json.dumps(response, ensure_ascii=False),
                'msgId': resData.get('msgId')
            }
            res = havcs_util.AESCipher(decrypt_key).encrypt(
                json.dumps(res, ensure_ascii=False).encode('utf8'))

            await hass.data[DOMAIN][DATA_HAVCS_MQTT].async_publish(
                topic.replace('/request/', '/response/'), res, 2, False)
            end_time = datetime.now()
            _LOGGER.debug(
                "[mqtt] -------- mqtt task finish at %s, Running time: %ss --------",
                end_time.strftime('%Y-%m-%d %H:%M:%S'),
                (end_time - start_time).total_seconds())

        async def async_publish_error(resData, topic):
            """Reply to a rejected request with an encrypted 404 response.

            The reply has an empty body and echoes the request's ``msgId``.
            It is published on the topic obtained by replacing ``/request/``
            with ``/response/``, with the arguments ``2, False`` (presumably
            QoS 2, not retained).
            """
            json_headers = {'Content-Type': 'application/json;charset=utf-8'}
            reply = dict(headers=json_headers,
                         status=404,
                         content='',
                         msgId=resData.get('msgId'))
            cipher = havcs_util.AESCipher(decrypt_key)
            plaintext = json.dumps(reply, ensure_ascii=False).encode('utf8')
            ciphertext = cipher.encrypt(plaintext)
            mqtt_client = hass.data[DOMAIN][DATA_HAVCS_MQTT]
            response_topic = topic.replace('/request/', '/response/')
            await mqtt_client.async_publish(response_topic, ciphertext, 2,
                                            False)

        @callback
        def message_received(*args):  # argument passing changed in 0.90
            """Handle an encrypted request message received over MQTT.

            Accepts both callback forms: positional ``(topic, payload, qos)``
            or a single message object exposing ``topic`` and ``payload``.

            The payload is AES-decrypted with ``decrypt_key`` and parsed as
            JSON. ``hello`` messages are only logged. Requests whose
            ``platform`` is ``h2m2h`` are scheduled on the HTTP proxy handler;
            all other requests are scheduled on the skill handler. Those
            handlers log their own completion. Requests for a mode missing
            from ``MODE`` are ignored, and proxy requests for a URI outside
            ``allowed_uri`` get a 404 reply.
            """
            if isinstance(args[0], str):
                topic, payload = args[0], args[1]
            else:
                topic, payload = args[0].topic, args[0].payload

            start_time = datetime.now()
            # Set only when handling finishes inside this callback; requests
            # handed off to a handler job are timed by that job.
            end_time = None
            _LOGGER.debug(
                "[mqtt] -------- start handle task from mqtt at %s --------",
                start_time.strftime('%Y-%m-%d %H:%M:%S'))
            try:
                payload = havcs_util.AESCipher(decrypt_key).decrypt(payload)
                req = json.loads(payload)
                if req.get('msgType') == 'hello':
                    _LOGGER.info("[mqtt] get hello message: %s",
                                 req.get('content'))
                    end_time = datetime.now()
                else:
                    _LOGGER.debug("[mqtt] raw message: %s", req)
                    if req.get('platform') == 'h2m2h':
                        if 'http_proxy' not in MODE:
                            _LOGGER.info(
                                "[http_proxy] havcs not run in http_proxy mode, ignore request: %s",
                                req)
                            end_time = datetime.now()
                        elif (allowed_uri and req.get('uri', '/').split('?')[0]
                              not in allowed_uri):
                            _LOGGER.info("[http_proxy] uri not allowed: %s",
                                         req.get('uri', '/'))
                            hass.add_job(async_publish_error(req, topic))
                            end_time = datetime.now()
                        else:
                            hass.add_job(
                                async_http_proxy_handler(req, topic, start_time))
                    elif 'skill' not in MODE:
                        _LOGGER.info(
                            "[skill] havcs not run in skill mode, ignore request: %s",
                            req)
                        end_time = datetime.now()
                    else:
                        hass.add_job(
                            async_module_handler(req, topic, start_time))
            except (json.decoder.JSONDecodeError, UnicodeDecodeError,
                    binascii.Error) as err:
                # Payloads that fail decryption or JSON parsing are dropped and
                # only logged at debug level.
                frames = ''.join(
                    str(frame)
                    for frame in traceback.extract_tb(err.__traceback__))
                _LOGGER.debug(
                    "[mqtt] fail to decrypt message, abandon[%s][%s]: %s",
                    type(err), err, frames)
                end_time = datetime.now()
            except Exception:  # keep the MQTT callback alive on any failure
                _LOGGER.error("[mqtt] fail to handle %s",
                              traceback.format_exc())
                end_time = datetime.now()
            if end_time is not None:
                _LOGGER.debug(
                    "[mqtt] -------- mqtt task finish at %s, Running time: %ss --------",
                    end_time.strftime('%Y-%m-%d %H:%M:%S'),
                    (end_time - start_time).total_seconds())

        await hass.data[DOMAIN][DATA_HAVCS_MQTT].async_subscribe(
            "ai-home/http2mqtt2hass/" + app_key + "/request/#",
            message_received, 2, 'utf-8')
        _LOGGER.info(
            "[init] mqtt initialization finished, waiting for welcome message of mqtt server."
        )

    async def start_havcs(event: Event):
        """Finish havcs setup.

        Synchronizes the per-platform config entries with ``conf`` and builds
        a request handler for each platform. Publishes an ``init`` message on
        the MQTT test topic when MQTT is set up. Registers the reload and
        debug-discovery services and requests an initial reload. Schedules
        the HTTP OAuth check when HTTP or the HTTP proxy is configured.

        Args:
            event: The triggering Home Assistant event, or ``None`` when
                called directly; it is not used.
        """

        async def async_load_device_info():
            """Load device settings from ``havcd_config_path`` into hass.data.

            Errors are logged rather than raised. On a schema validation error
            a warning is logged and the stored settings are not replaced.
            """
            _LOGGER.info("load device info from file")
            try:
                device_config = await hass.async_add_executor_job(
                    conf_util.load_yaml_config_file, havcd_config_path)
                hass.data[DOMAIN][DATA_HAVCS_ITEMS] = DEVICE_CONFIG_SCHEMA(
                    device_config)
            except HomeAssistantError as err:
                _LOGGER.error("Error loading %s: %s", havcd_config_path, err)
                return None
            except vol.error.Error as exception:
                _LOGGER.warning(
                    "failed to load all devices from file, find invalid data: %s",
                    exception)
            except Exception:  # best effort: log and continue
                _LOGGER.error("Error loading %s: %s", havcd_config_path,
                              traceback.format_exc())
                return None

        async def async_init_sub_entry():
            """Sync per-platform config entries with ``conf``; build handlers."""
            # Connection modes enabled in the configuration; stored in each
            # new platform entry's data and shown in every entry's title.
            mode = []
            if CONF_HTTP in conf:
                mode.append(CONF_HTTP)
                if CONF_HTTP_PROXY in conf:
                    mode.append(CONF_HTTP_PROXY)
            if CONF_SKILL in conf:
                mode.append(CONF_SKILL)

            havcs_entries = hass.config_entries.async_entries(DOMAIN)
            # There is one sub entry per platform.
            platform_entries = [
                entry for entry in havcs_entries
                if entry.source == SOURCE_PLATFORM
            ]
            entry_platforms = {
                entry.data.get('platform') for entry in platform_entries
            }
            conf_platforms = set(conf.get(CONF_PLATFORM))
            new_platforms = conf_platforms - entry_platforms
            _LOGGER.debug("[post-task] load new platform entry %s",
                          new_platforms)
            for platform in new_platforms:
                # Awaited here: inside async_setup_entry it could not be
                # awaited, and the component setup triggered by async_init
                # would keep waiting on the earlier component setup task.
                await hass.async_create_task(
                    hass.config_entries.flow.async_init(
                        DOMAIN,
                        context={'source': SOURCE_PLATFORM},
                        data={
                            'platform': platform,
                            'mode': mode
                        }))
            remove_platforms = entry_platforms - conf_platforms
            _LOGGER.debug("[post-task] remove old platform entry %s",
                          remove_platforms)
            for entry in platform_entries:
                platform = entry.data.get('platform')
                if platform in remove_platforms:
                    await hass.async_create_task(
                        hass.config_entries.async_remove(entry.entry_id))
                else:
                    entry.title = f"接入平台[{platform}-{DEVICE_PLATFORM_DICT[platform]['cn_name']}],接入方式{mode}"
                    hass.config_entries.async_update_entry(entry)

            # Build a handler for every platform that has a sub entry. A
            # failure only skips that platform; the others are still set up.
            for platform in platforms:
                ent = next(
                    (candidate
                     for candidate in hass.config_entries.async_entries(DOMAIN)
                     if candidate.source == SOURCE_PLATFORM
                     and candidate.data.get('platform') == platform), None)
                if ent is None:
                    continue
                try:
                    module = importlib.import_module(
                        'custom_components.{}.{}'.format(DOMAIN, platform))
                    _LOGGER.info("[post-task] import %s.%s", DOMAIN,
                                 platform)
                    hass.data[DOMAIN][DATA_HAVCS_HANDLER][
                        platform] = module.createHandler(hass, ent)
                except ImportError as err:
                    _LOGGER.error(
                        "[post-task] Unable to import %s.%s, %s",
                        DOMAIN, platform, err)
                except Exception:
                    _LOGGER.error(
                        "[post-task] fail to create %s handler: %s",
                        platform, traceback.format_exc())

        await async_init_sub_entry()

        if DATA_HAVCS_MQTT in hass.data[DOMAIN]:
            await hass.data[DOMAIN][DATA_HAVCS_MQTT].async_publish(
                "ai-home/http2mqtt2hass/" + app_key + "/response/test", 'init',
                2, False)

        async def async_handler_service(service):
            """Handle the reload and debug-discovery services."""
            if service.service == SERVICE_RELOAD:
                await async_load_device_info()
                # Snapshot: the loop awaits, and the handler dict must not be
                # iterated while it may be modified.
                handlers = list(
                    hass.data[DOMAIN][DATA_HAVCS_HANDLER].items())
                for platform, handler in handlers:
                    devices = handler.vcdm.all(hass, True)
                    await handler.vcdm.async_reregister_devices(hass)
                    _LOGGER.info(
                        "[service] ------------%s 平台加载设备信息------------\n%s",
                        platform, [device.attributes for device in devices])
                    # Devices with any attribute equal to None or [] are
                    # incompletely configured.
                    incomplete_devices = [
                        device.attributes for device in devices
                        if None in device.attributes.values()
                        or [] in device.attributes.values()
                    ]
                    if incomplete_devices:
                        _LOGGER.debug(
                            "!!!!!!!! 以下设备信息不完整,检查值为None或[]的属性并进行设置 !!!!!!!!")
                        for attributes in incomplete_devices:
                            _LOGGER.debug("%s", attributes)
                    _LOGGER.info(
                        "[service] ------------%s 平台加载设备信息------------\n",
                        platform)
                if bind_device:
                    await hass.data[DOMAIN][DATA_HAVCS_BIND_MANAGER
                                            ].async_bind_device()

            elif service.service == SERVICE_DEBUG_DISCOVERY:
                for platform, handler in hass.data[DOMAIN][
                        DATA_HAVCS_HANDLER].items():
                    _, discovery_devices, _ = handler.process_discovery_command()
                    _LOGGER.info(
                        "[service][%s] trigger discovery command, response: %s",
                        platform, discovery_devices)

        hass.services.async_register(DOMAIN,
                                     SERVICE_RELOAD,
                                     async_handler_service,
                                     schema=HAVCS_SERVICE_SCHEMA)
        hass.services.async_register(DOMAIN,
                                     SERVICE_DEBUG_DISCOVERY,
                                     async_handler_service,
                                     schema=HAVCS_SERVICE_SCHEMA)

        await hass.services.async_call(DOMAIN, SERVICE_RELOAD)

        if CONF_HTTP in conf or CONF_HTTP_PROXY in conf:
            hass.async_create_task(http_manager.async_check_http_oauth())

    if config_entry.source == config_entries.SOURCE_IMPORT:
        hass.bus.async_listen_once(EVENT_HOMEASSISTANT_START, start_havcs)
    elif config_entry.source == config_entries.SOURCE_USER:
        hass.async_create_task(start_havcs(None))

    _LOGGER.info("[init] havcs initialization finished.")
    return True
コード例 #6
0
ファイル: common.py プロジェクト: sander76/custom_components
def async_fire_mqtt_message(hass, topic, payload, qos=0, retain=False):
    """Inject a synthetic MQTT message into the MQTT integration.

    A ``str`` payload is UTF-8 encoded first. The message is then handed to
    ``hass.data["mqtt"]._mqtt_on_message`` through ``hass.async_run_job``,
    with ``None`` for that handler's first two positional arguments.
    """
    raw = payload.encode("utf-8") if isinstance(payload, str) else payload
    message = mqtt.Message(topic, raw, qos, retain)
    on_message = hass.data["mqtt"]._mqtt_on_message
    hass.async_run_job(on_message, None, None, message)