예제 #1
0
    def __init__(self):
        """
        Create the 'agent' logger and the bus connection, connecting at once.

        :raises Exception: when BusConnect.Connect() does not return True
        """
        self.logger = logging.getLogger('agent')
        self.__connect = BusConnect()

        # Only an explicit True from Connect() counts as success.
        if self.__connect.Connect() is True:
            return
        self.logger.error('Game env connect failed.')
        raise Exception('Game env connect failed.')
예제 #2
0
class ProgressReport(object):
    """
    Report imitation-learning training progress to MC over the bus.
    """

    def __init__(self):
        self.__logger = logging.getLogger('agent')
        self.__connect = BusConnect()

    def Init(self):
        """
        Open the tbus connection used for progress reports.

        :return: True when BusConnect.Connect() returned True, else False
        """
        connected = self.__connect.Connect() is True
        if not connected:
            self.__logger.error('Agent connect failed.')
        return connected

    def SendTrainProgress(self, progress):
        """
        Send an MSG_IM_TRAIN_STATE message carrying *progress* to MC.

        :param progress: training progress value (logged with ``%d``)
        :return: True when SendMsg returned 0, else False
        """
        self.__logger.info('im train progress: %d', progress)

        stateMsg = common_pb2.tagMessage()
        stateMsg.eMsgID = common_pb2.MSG_IM_TRAIN_STATE
        stateMsg.stIMTrainState.nProgress = progress

        sendRet = self.__connect.SendMsg(stateMsg, BusConnect.PEER_NODE_MC)
        return bool(sendRet == 0)
예제 #3
0
class ActionMgr(object):
    """
    ActionMgr implement for remote action.

    Serialises action dictionaries with msgpack, wraps them in an
    MSG_AI_ACTION ``common_pb2.tagMessage`` and sends them over the bus
    connection.
    """
    def __init__(self):
        self.__initialized = False
        self.__connect = BusConnect()

    def Initialize(self):
        """
        Initialize this module, init bus connection.

        The module only counts as initialized when ``Connect()`` returns
        True. After a failed connect, ``SendAction`` keeps refusing to send
        and ``Finish`` does not try to close the connection.

        :return: the value returned by ``BusConnect.Connect()``
        """
        ret = self.__connect.Connect()
        self.__initialized = ret is True
        return ret

    def Finish(self):
        """
        Finish this module, tbus disconnect.

        Does nothing when the module is not initialized.
        """
        if self.__initialized:
            LOG.info('Close connection...')
            self.__connect.Close()
            self.__initialized = False

    def SendAction(self, actionID, actionData, frameSeq=-1):
        """
        Send action to remote(client).

        :param actionID: the self-defined action ID
        :param actionData: the context data of the action ID; it is copied,
                           not modified
        :param frameSeq: the frame sequence, default is -1
        :return: True when the bus send returned 0; False on a send error or
                 when the module is not initialized
        """
        if not self.__initialized:
            LOG.warning('Call Initialize first!')
            return False

        # Shallow copy so the routing keys do not leak into the caller's dict.
        payload = dict(actionData)
        payload['msg_id'] = MSG_ID_AI_ACTION
        payload['action_id'] = actionID
        actionBuff = msgpack.packb(payload,
                                   default=mn.encode,
                                   use_bin_type=True)

        msg = common_pb2.tagMessage()
        msg.eMsgID = common_pb2.MSG_AI_ACTION
        msg.stAIAction.nFrameSeq = frameSeq
        msg.stAIAction.byAIActionBuff = actionBuff

        # isEnabledFor() honours the effective (inherited) level, so the JSON
        # dump is skipped whenever the debug record would be dropped anyway.
        if LOG_REGACTION.isEnabledFor(logging.DEBUG):
            LOG_REGACTION.debug('%s||action||%s', frameSeq,
                                json.dumps(payload))

        ret = self.__connect.SendMsg(msg)
        if ret != 0:
            LOG.warning('TBus Send To MC return code[%s]', ret)
            return False
        return True
예제 #4
0
 def __init__(self):
     """Set up the logger, empty AI placeholders, plugin loader and bus."""
     self.__logger = logging.getLogger('agent')
     # Placeholders, presumably filled in during a later init step - confirm.
     self.__aiModel = self.__agentEnv = self.__runAIFunc = None
     self.__aiPlugin = AIPlugin()
     self.__outputAIAction = True
     self.__connect = BusConnect()
예제 #5
0
 def __init__(self):
     """Set default plugin flags, empty AI placeholders and the bus object."""
     # Built-in env / AI model and the default run loop unless configured.
     self.__usePluginEnv = self.__usePluginAIModel = False
     self.__useDefaultRunFunc = True
     self.__logger = logging.getLogger('agent')
     self.__aiModel = self.__agentEnv = self.__RunAIFunc = None
     self.__outputAIAction = True
     self.__connect = BusConnect()
예제 #6
0
class AIFrameWork(object):
    """
    Agent AI framework, run AI test.

    Connects to the bus and registers the AI service. It then builds the
    agent env, the AI model and the optional run function described by the
    plugin config file, and drives the episode loop.
    """
    def __init__(self):
        self.__usePluginEnv = False
        self.__usePluginAIModel = False
        self.__useDefaultRunFunc = True
        self.__logger = logging.getLogger('agent')
        self.__aiModel = None
        self.__agentEnv = None
        self.__RunAIFunc = None
        self.__outputAIAction = True
        self.__connect = BusConnect()

    def Init(self):
        """
        Init tbus connect, register service, create ai & env object.

        :return: True on success, False otherwise. Failures after the bus
                 connection is up are also reported as PB_TASK_INIT_FAILURE.
        """
        if self.__connect.Connect() is not True:
            self.__logger.error('Agent connect failed.')
            return False

        if self._RegisterService() is not True:
            self.__logger.error('Agent register service failed.')
            self._SendTaskReport(common_pb2.PB_TASK_INIT_FAILURE)
            return False

        if self._InitAITask() is True:
            self._SendTaskReport(common_pb2.PB_TASK_INIT_SUCCESS)
        else:
            self._SendTaskReport(common_pb2.PB_TASK_INIT_FAILURE)
            return False

        return True

    def _InitAITask(self):
        """
        Load the plugin config, then create and init env, AI model and run
        function.

        :return: True when every step succeeded, False otherwise
        """
        if self._GetPluginInConfig() is not True:
            return False

        # The creators log failures and leave their attribute unchanged.
        self._CreateAgentEnvObj()
        self._CreateAIModelObj()

        if self.__aiModel is None or self.__agentEnv is None:
            self.__logger.error('Create agent env or aimodel object failed.')
            return False

        self._CreateRunFunc()
        if self.__useDefaultRunFunc is not True and self.__RunAIFunc is None:
            self.__logger.error('Create run function failed.')
            return False

        if self.__agentEnv.Init() is not True:
            self.__logger.error('Agent env init failed.')
            return False

        if self.__aiModel.Init(self.__agentEnv) is not True:
            self.__logger.error('AI model init failed.')
            return False

        return True

    def Finish(self):
        """
        Disconnect tbus, unregister service, release ai & env object.
        """
        self._UnRegisterService()
        if self.__aiModel is not None:
            self.__aiModel.Finish()
        if self.__agentEnv is not None:
            self.__agentEnv.Finish()
        self.__connect.Close()

    def Run(self, isTestMode=True):
        """
        Main framework, run AI test.

        Uses the configured run function when one was loaded, otherwise the
        built-in episode loop.
        """
        if self.__RunAIFunc:
            self.__RunAIFunc(self.__agentEnv, self.__aiModel, isTestMode)
        else:
            self._DefaultRun(isTestMode)

    def StopAIAction(self):
        """
        Stop ai action when receive signal or msg.
        """
        self.__outputAIAction = False
        self.__agentEnv.StopAction()
        self.__logger.info('Stop ai action')
        self.__agentEnv.UpdateEnvState(ENV_STATE_PAUSE_PLAYING,
                                       'Pause ai playing')

    def RestartAIAction(self):
        """
        Restart ai action when receive signal or msg.
        """
        self.__outputAIAction = True
        self.__agentEnv.RestartAction()
        self.__logger.info('Restart ai action')
        self.__agentEnv.UpdateEnvState(ENV_STATE_RESTORE_PLAYING,
                                       'Resume ai playing')

    def _DefaultRun(self, isTestMode):
        """Loop forever: wait for an episode, play it, report its end."""
        self.__agentEnv.UpdateEnvState(ENV_STATE_WAITING_START,
                                       'Wait episode start')
        while True:
            # Wait for a new episode to start.
            self._WaitEpisode()
            self.__logger.info('Episode start')
            self.__agentEnv.UpdateEnvState(ENV_STATE_PLAYING,
                                           'Episode start, ai playing')
            self.__aiModel.OnEpisodeStart()

            # Run the episode according to the AI until it is over.
            self._RunEpisode(isTestMode)
            self.__logger.info('Episode over')
            self.__agentEnv.UpdateEnvState(ENV_STATE_OVER, 'Episode over')
            self.__aiModel.OnEpisodeOver()

            self.__agentEnv.UpdateEnvState(ENV_STATE_WAITING_START,
                                           'Wait episode start')

    def _WaitEpisode(self):
        """Poll bus messages until the env reports an episode start."""
        while True:
            self._HandleMsg()

            if self.__agentEnv.IsEpisodeStart() is True:
                break
            time.sleep(0.001)

    def _RunEpisode(self, isTestMode):
        """Step the AI until a game-over message or the env reports over."""
        while True:
            if self.__outputAIAction is True:
                if isTestMode is True:
                    self.__aiModel.TestOneStep()
                else:
                    self.__aiModel.TrainOneStep()
            else:
                # Actions are paused: keep consuming state without acting.
                self.__agentEnv.GetState()
                time.sleep(0.01)

            msgID = self._HandleMsg()
            if msgID == common_pb2.MSG_UI_GAME_OVER:
                break

            if self.__agentEnv.IsEpisodeOver() is True:
                break

    def _HandleMsg(self):
        """
        Receive one bus message and dispatch game start/over to the AI model.

        :return: the message id, or MSG_NONE when nothing was received
        """
        msg = self.__connect.RecvMsg()
        if msg is None:
            return common_pb2.MSG_NONE

        msgID = msg.eMsgID
        if msgID == common_pb2.MSG_UI_GAME_START:
            self.__logger.info('Enter new episode...')
            self.__aiModel.OnEnterEpisode()
        elif msgID == common_pb2.MSG_UI_GAME_OVER:
            self.__logger.info('Leave episode')
            self.__aiModel.OnLeaveEpisode()
        else:
            self.__logger.info('Unknown msg id')

        return msgID

    def _GetPluginInConfig(self):
        """Load the plugin config, preferring the new file over the old."""
        pluginCfgPath = util.ConvertToSDKFilePath(PLUGIN_CFG_FILE)
        if os.path.exists(pluginCfgPath):
            return self._LoadPlugInParams(pluginCfgPath)

        oldPluginCfgPath = util.ConvertToSDKFilePath(OLD_PLUGIN_CFG_FILE)
        if os.path.exists(oldPluginCfgPath):
            return self._LoadOldPlugInParams(oldPluginCfgPath)

        self.__logger.error(
            'agentai cfg file {0} not exist.'.format(pluginCfgPath))
        return False

    def _LoadPlugInParams(self, pluginCfgPath):
        """
        Read plugin settings from *pluginCfgPath*.

        Accepts both the new section names (AGENT_ENV / AI_MODEL /
        RUN_FUNCTION) and the old ones (AgentEnv / AIModel / RunFunc).

        :return: True on success, False (logged) on any read error
        """
        try:
            config = configparser.ConfigParser()
            config.read(pluginCfgPath)

            envSection = 'AGENT_ENV'
            aiSection = 'AI_MODEL'
            runSection = 'RUN_FUNCTION'

            if config.has_section('AgentEnv'):
                envSection = 'AgentEnv'

            if config.has_section('AIModel'):
                aiSection = 'AIModel'

            if config.has_section('RunFunc'):
                runSection = 'RunFunc'

            self.__usePluginEnv = config.getboolean(envSection, 'UsePluginEnv')
            self.__usePluginAIModel = config.getboolean(
                aiSection, 'UsePluginAIModel')
            self.__useDefaultRunFunc = config.getboolean(
                runSection, 'UseDefaultRunFunc')

            # Env and AI model names are always read: the creators import
            # them for both plugin and built-in implementations.
            self.__envPackage = config.get(envSection, 'EnvPackage')
            self.__envModule = config.get(envSection, 'EnvModule')
            self.__envClass = config.get(envSection, 'EnvClass')

            self.__aiModelPackage = config.get(aiSection, 'AIModelPackage')
            self.__aiModelModule = config.get(aiSection, 'AIModelModule')
            self.__aiModelClass = config.get(aiSection, 'AIModelClass')

            if self.__useDefaultRunFunc is not True:
                self.__runFuncPackage = config.get(runSection,
                                                   'RunFuncPackage')
                self.__runFuncModule = config.get(runSection, 'RunFuncModule')
                self.__runFuncName = config.get(runSection, 'RunFuncName')
        except Exception as e:
            self.__logger.error('Load file {} failed, error: {}.'.format(
                pluginCfgPath, e))
            return False

        return True

    def _LoadOldPlugInParams(self, pluginCfgPath):
        """
        Read plugin settings from an old-style config file.

        :return: True on success, False (logged) on any read error
        """
        try:
            config = configparser.ConfigParser()
            config.read(pluginCfgPath)

            self.__usePluginEnv = config.getboolean('AgentEnv', 'UsePluginEnv')
            self.__usePluginAIModel = config.getboolean(
                'AIModel', 'UsePluginAIModel')
            self.__useDefaultRunFunc = config.getboolean(
                'RunFunc', 'UseDefaultRunFunc')

            # Env and AI model names are always read: the creators import
            # them for both plugin and built-in implementations.
            self.__envPackage = config.get('AgentEnv', 'EnvPackage')
            self.__envModule = config.get('AgentEnv', 'EnvModule')
            self.__envClass = config.get('AgentEnv', 'EnvClass')

            self.__aiModelPackage = config.get('AIModel', 'AIModelPackage')
            self.__aiModelModule = config.get('AIModel', 'AIModelModule')
            self.__aiModelClass = config.get('AIModel', 'AIModelClass')

            if self.__useDefaultRunFunc is not True:
                self.__runFuncPackage = config.get('RunFunc', 'RunFuncPackage')
                self.__runFuncModule = config.get('RunFunc', 'RunFuncModule')
                self.__runFuncName = config.get('RunFunc', 'RunFuncName')
        except Exception as e:
            self.__logger.error('Load file {} failed, error: {}.'.format(
                pluginCfgPath, e))
            return False

        return True

    @staticmethod
    def _ImportAttr(packageName, moduleName, attrName):
        """Import ``packageName.moduleName`` and return its ``attrName``."""
        import importlib

        module = importlib.import_module(
            '{0}.{1}'.format(packageName, moduleName))
        return getattr(module, attrName)

    @staticmethod
    def _PopSysPath(searchPath):
        """
        Remove the last occurrence of *searchPath* from sys.path.

        This is safer than ``sys.path.pop()`` when the imported module
        appended its own entries. It does nothing if the entry is already
        gone.
        """
        try:
            lastIdx = len(sys.path) - 1 - sys.path[::-1].index(searchPath)
        except ValueError:
            return
        del sys.path[lastIdx]

    def _CreateAgentEnvObj(self):
        """
        Import the configured env class and instantiate it.

        'PlugIn/ai' is on sys.path during import and construction when the
        plugin env is used. On failure the error is logged and
        ``self.__agentEnv`` is left unchanged.
        """
        searchPath = 'PlugIn/ai' if self.__usePluginEnv is True else None
        if searchPath is not None:
            sys.path.append(searchPath)
        try:
            envClass = self._ImportAttr(self.__envPackage, self.__envModule,
                                        self.__envClass)
            self.__logger.info('agent env class: {0}'.format(envClass))
            self.__agentEnv = envClass()
        except Exception:
            self.__logger.exception('Create agent env object failed.')
        finally:
            if searchPath is not None:
                self._PopSysPath(searchPath)

    def _CreateAIModelObj(self):
        """
        Import the configured AI model class and instantiate it.

        'PlugIn/ai' (plugin model) or 'AgentAI/aimodel' (built-in model) is
        on sys.path during import and construction. On failure the error is
        logged and ``self.__aiModel`` is left unchanged.
        """
        if self.__usePluginAIModel is True:
            searchPath = 'PlugIn/ai'
        else:
            searchPath = 'AgentAI/aimodel'

        sys.path.append(searchPath)
        try:
            aiModelClass = self._ImportAttr(self.__aiModelPackage,
                                            self.__aiModelModule,
                                            self.__aiModelClass)
            self.__logger.info('aimodel class: {0}'.format(aiModelClass))
            self.__aiModel = aiModelClass()
        except Exception:
            self.__logger.exception('Create aimodel object failed.')
        finally:
            self._PopSysPath(searchPath)

    def _CreateRunFunc(self):
        """
        Import the configured run function unless the default loop is used.

        On failure the error is logged and ``self.__RunAIFunc`` is left
        unchanged.
        """
        if self.__useDefaultRunFunc is True:
            return

        searchPath = 'PlugIn/ai'
        sys.path.append(searchPath)
        try:
            self.__RunAIFunc = self._ImportAttr(self.__runFuncPackage,
                                                self.__runFuncModule,
                                                self.__runFuncName)
            self.__logger.info('run function: {0}'.format(self.__RunAIFunc))
        except Exception:
            self.__logger.exception('Create run function failed.')
        finally:
            self._PopSysPath(searchPath)

    def _RegisterService(self):
        """Send a PB_SERVICE_REGISTER message; True when SendMsg returns 0."""
        regMsg = common_pb2.tagMessage()
        regMsg.eMsgID = common_pb2.MSG_SERVICE_REGISTER
        regMsg.stServiceRegister.eRegisterType = common_pb2.PB_SERVICE_REGISTER
        regMsg.stServiceRegister.eServiceType = common_pb2.PB_SERVICE_TYPE_AI

        if self.__connect.SendMsg(regMsg) == 0:
            return True
        return False

    def _SendTaskReport(self, reportCode):
        """Send an MSG_TASK_REPORT with *reportCode*; True on SendMsg == 0."""
        taskMsg = common_pb2.tagMessage()
        taskMsg.eMsgID = common_pb2.MSG_TASK_REPORT
        taskMsg.stTaskReport.eTaskStatus = reportCode

        if self.__connect.SendMsg(taskMsg) == 0:
            return True
        return False

    def _UnRegisterService(self):
        """Send a PB_SERVICE_UNREGISTER message; True on SendMsg == 0."""
        unRegMsg = common_pb2.tagMessage()
        unRegMsg.eMsgID = common_pb2.MSG_SERVICE_REGISTER
        unRegMsg.stServiceRegister.eRegisterType = common_pb2.PB_SERVICE_UNREGISTER
        unRegMsg.stServiceRegister.eServiceType = common_pb2.PB_SERVICE_TYPE_AI

        if self.__connect.SendMsg(unRegMsg) == 0:
            return True
        return False
예제 #7
0
 def __init__(self):
     """Create the 'agent' logger, bus connection object and IL trainer."""
     self.__logger = logging.getLogger('agent')
     # Bus connection object; presumably connected later by an Init step - confirm.
     self.__connect = BusConnect()
     # Imitation-learning trainer instance.
     self.__imTrain = MainImitationLearning()
예제 #8
0
class IMTrainFrameWork(object):
    """
    Imitation learning train framework.

    Connects to MC over the bus, registers the service, reports the device
    source info and runs the imitation-learning training.
    """

    def __init__(self):
        self.__logger = logging.getLogger('agent')
        self.__connect = BusConnect()
        self.__imTrain = MainImitationLearning()

    def Init(self):
        """
        Init tbus connect, register service to manager center.

        :return: True on success, False otherwise. Registration and
                 source-info failures are also reported as
                 PB_TASK_INIT_FAILURE.
        """
        if self.__connect.Connect() is not True:
            self.__logger.error('Agent connect failed.')
            return False

        if self._RegisterService() is not True:
            self.__logger.error('Agent register service failed.')
            self._SendTaskReport(common_pb2.PB_TASK_INIT_FAILURE)
            return False

        if self._send_resource_info() is not True:
            self.__logger.error('send the source info failed.')
            self._SendTaskReport(common_pb2.PB_TASK_INIT_FAILURE)
            return False

        return True

    def Train(self):
        """
        Load samples and train im model.
        """
        self.__imTrain.GenerateImageSamples()
        self.__imTrain.TrainNetwork()

    def Finish(self):
        """
        Disconnect tbus, unregister service.
        """
        self._UnRegisterService()
        self.__connect.Close()

    def _SendTaskReport(self, reportCode):
        """Send an MSG_TASK_REPORT with *reportCode* to MC."""
        taskMsg = common_pb2.tagMessage()
        taskMsg.eMsgID = common_pb2.MSG_TASK_REPORT
        taskMsg.stTaskReport.eTaskStatus = reportCode

        if self.__connect.SendMsg(taskMsg, BusConnect.PEER_NODE_MC) == 0:
            return True
        return False

    def _RegisterService(self):
        """Send a PB_SERVICE_REGISTER message to MC."""
        regMsg = common_pb2.tagMessage()
        regMsg.eMsgID = common_pb2.MSG_SERVICE_REGISTER
        regMsg.stServiceRegister.eRegisterType = common_pb2.PB_SERVICE_REGISTER
        regMsg.stServiceRegister.eServiceType = common_pb2.PB_SERVICE_TYPE_AI

        if self.__connect.SendMsg(regMsg, BusConnect.PEER_NODE_MC) == 0:
            return True
        return False

    def _UnRegisterService(self):
        """Send a PB_SERVICE_UNREGISTER message to MC."""
        unRegMsg = common_pb2.tagMessage()
        unRegMsg.eMsgID = common_pb2.MSG_SERVICE_REGISTER
        unRegMsg.stServiceRegister.eRegisterType = common_pb2.PB_SERVICE_UNREGISTER
        unRegMsg.stServiceRegister.eServiceType = common_pb2.PB_SERVICE_TYPE_AI

        if self.__connect.SendMsg(unRegMsg, BusConnect.PEER_NODE_MC) == 0:
            return True
        return False

    def _send_resource_info(self):
        """
        Send the device 'source' entry of the project config to MC.

        :return: True when SendMsg returned 0; False when the config could
                 not be read, has no 'source' entry, or the send failed
        :raises Exception: when AI_SDK_PROJECT_FILE_PATH is unset or empty
        """
        project_config_path = os.environ.get('AI_SDK_PROJECT_FILE_PATH')
        self.__logger.info('send source info to mc, project_path: %s',
                           project_config_path)
        if not project_config_path:
            raise Exception('environ var(AI_SDK_PROJECT_FILE_PATH) is invalid')

        content = get_configure(project_config_path)
        if not content:
            # Guard against an unreadable/empty config instead of crashing.
            self.__logger.warning(
                "failed to get project config content, file: %s",
                project_config_path)
            return False

        # .get(): a missing key is treated like an explicit None.
        source = content.get('source')
        if source is None:
            self.__logger.info(
                "invalid the source in the project config, content: %s",
                content)
            return False
        self.__logger.info("the project config is %s, project_config_path: %s",
                           str(content), project_config_path)
        source_res_message = create_source_response(source)

        if self.__connect.SendMsg(source_res_message,
                                  BusConnect.PEER_NODE_MC) == 0:
            self.__logger.info("send the source info to mc service success")
            return True
        self.__logger.warning(
            "send the source info to mc service failed, please check")
        return False
예제 #9
0
class AIFrameWork(object):
    """
    Agent AI framework, run AI test.

    Connects to MC over the bus, registers the AI service and reports the
    device source info. It then creates the agent env, AI model and
    optional run function through AIPlugin, and drives the episode loop.
    """
    def __init__(self):
        self.__logger = logging.getLogger('agent')
        self.__aiModel = None
        self.__agentEnv = None
        self.__runAIFunc = None
        self.__aiPlugin = AIPlugin()
        self.__outputAIAction = True
        self.__connect = BusConnect()

    def Init(self):
        """
        Init tbus connect, register service, create ai & env object.

        :return: True on success, False otherwise. Failures after the bus
                 connection is up are also reported as PB_TASK_INIT_FAILURE.
        """
        if self.__connect.Connect() is not True:
            self.__logger.error('Agent connect failed.')
            return False

        if self._RegisterService() is not True:
            self.__logger.error('Agent register service failed.')
            self._SendTaskReport(common_pb2.PB_TASK_INIT_FAILURE)
            return False

        if self._send_resource_info() is not True:
            self.__logger.error('send the source info failed.')
            self._SendTaskReport(common_pb2.PB_TASK_INIT_FAILURE)
            return False

        if self._InitAIObject() is True:
            self.__logger.info(
                'AIFrameWork.Init, _SendTaskReport PB_TASK_INIT_SUCCESS to MC.'
            )
            self._SendTaskReport(common_pb2.PB_TASK_INIT_SUCCESS)
        else:
            self.__logger.error(
                'AIFrameWork.Init, _SendTaskReport PB_TASK_INIT_FAILURE to MC.'
            )
            self._SendTaskReport(common_pb2.PB_TASK_INIT_FAILURE)
            return False

        return True

    def _InitAIObject(self):
        """
        Create and init the env, AI model and run function via AIPlugin.

        :return: True when every step succeeded, False otherwise
        """
        if self.__aiPlugin.Init() is not True:
            return False

        self.__agentEnv = self.__aiPlugin.CreateAgentEnvObj()
        self.__aiModel = self.__aiPlugin.CreateAIModelObj()

        if self.__aiModel is None or self.__agentEnv is None:
            self.__logger.error('Create agent env or aimodel object failed.')
            return False

        self.__runAIFunc = self.__aiPlugin.CreateRunFunc()
        if (self.__aiPlugin.UseDefaultRun() is not True
                and self.__runAIFunc is None):
            self.__logger.error('Create run function failed.')
            return False

        if self.__agentEnv.Init() is not True:
            self.__logger.error('Agent env init failed.')
            return False
        self.__logger.info('AIFrameWork.InitAIObject agentEnv init success')

        if self.__aiModel.Init(self.__agentEnv) is not True:
            self.__logger.error('AI model init failed.')
            return False

        self.__logger.info('AIFrameWork.InitAIObject aiModel init success')
        return True

    def Finish(self):
        """
        Disconnect tbus, unregister service, release ai & env object.
        """
        self._UnRegisterService()
        if self.__aiModel is not None:
            self.__aiModel.Finish()
        if self.__agentEnv is not None:
            self.__agentEnv.Finish()
        self.__connect.Close()

    def Run(self, isTestMode=True):
        """
        Main framework, run AI test.

        Uses the plugin run function when one was created, otherwise the
        built-in episode loop.
        """
        if self.__runAIFunc:
            # Log through the framework logger like every other message here.
            self.__logger.debug("execute the run ai func")
            self.__runAIFunc(self.__agentEnv, self.__aiModel, isTestMode)
        else:
            self._DefaultRun(isTestMode)

    def StopAIAction(self):
        """
        Stop ai action when receive signal or msg.
        """
        self.__outputAIAction = False
        self.__agentEnv.StopAction()
        self.__logger.info('Stop ai action')
        self.__agentEnv.UpdateEnvState(ENV_STATE_PAUSE_PLAYING,
                                       'Pause ai playing')

    def RestartAIAction(self):
        """
        Restart ai action when receive signal or msg.
        """
        self.__outputAIAction = True
        self.__agentEnv.RestartAction()
        self.__logger.info('Restart ai action')
        self.__agentEnv.UpdateEnvState(ENV_STATE_RESTORE_PLAYING,
                                       'Resume ai playing')

    def _DefaultRun(self, isTestMode):
        """Loop forever: wait for an episode, play it, report its end."""
        self.__logger.debug("execute the default run")
        self.__agentEnv.UpdateEnvState(ENV_STATE_WAITING_START,
                                       'Wait episode start')
        while True:
            # Wait for a new episode to start.
            self.__logger.debug("begin to wait the start")
            self._WaitEpisode()
            self.__logger.info('Episode start')
            self.__agentEnv.UpdateEnvState(ENV_STATE_PLAYING,
                                           'Episode start, ai playing')
            self.__aiModel.OnEpisodeStart()

            # Run the episode according to the AI until it is over.
            self._RunEpisode(isTestMode)
            self.__logger.info('Episode over')
            self.__agentEnv.UpdateEnvState(ENV_STATE_OVER, 'Episode over')
            self.__aiModel.OnEpisodeOver()

            self.__agentEnv.UpdateEnvState(ENV_STATE_WAITING_START,
                                           'Wait episode start')

    def _WaitEpisode(self):
        """Poll MC messages until the env reports an episode start."""
        while True:
            self._HandleMsg()

            if self.__agentEnv.IsEpisodeStart() is True:
                break
            time.sleep(0.1)

    def _RunEpisode(self, isTestMode):
        """Step the AI until a game-over message or the env reports over."""
        while True:
            if self.__outputAIAction is True:
                if isTestMode is True:
                    self.__aiModel.TestOneStep()
                else:
                    self.__aiModel.TrainOneStep()
            else:
                # Actions are paused: keep consuming state without acting.
                self.__agentEnv.GetState()
                time.sleep(0.01)

            msgID = self._HandleMsg()
            if msgID == common_pb2.MSG_UI_GAME_OVER:
                break

            if self.__agentEnv.IsEpisodeOver() is True:
                break

    def _HandleMsg(self):
        """
        Receive one message from MC and dispatch game start/over to the AI.

        :return: the message id, or MSG_NONE when nothing was received
        """
        msg = self.__connect.RecvMsg(BusConnect.PEER_NODE_MC)
        if msg is None:
            return common_pb2.MSG_NONE

        msgID = msg.eMsgID
        # Lazy %-style args: the previous .format() call left %s/%d unfilled.
        self.__logger.info('the message from mc is %s, msgID: %d', msg, msgID)

        if msgID == common_pb2.MSG_UI_GAME_START:
            self.__logger.info('Enter new episode...')
            self.__aiModel.OnEnterEpisode()
        elif msgID == common_pb2.MSG_UI_GAME_OVER:
            self.__logger.info('Leave episode')
            self.__aiModel.OnLeaveEpisode()
        else:
            self.__logger.info('Unknown msg id')

        return msgID

    def _send_resource_info(self):
        """
        Send the device source info from the project config to MC.

        Falls back to a default Android / Local / long_edge 1280 source when
        the config has no usable 'source' entry.

        :return: True when SendMsg returned 0; False when the config is
                 unreadable or the send failed
        :raises Exception: when AI_SDK_PROJECT_FILE_PATH is unset or empty
        """
        project_config_path = os.environ.get('AI_SDK_PROJECT_FILE_PATH')
        self.__logger.info('send source info to mc, project_path:%s',
                           project_config_path)
        if not project_config_path:
            raise Exception('environ var(AI_SDK_PROJECT_FILE_PATH) is invalid')

        content = util.get_configure(project_config_path)
        if not content:
            self.__logger.warning(
                "failed to get project config content, file:%s",
                project_config_path)
            return False

        if not content.get('source'):
            self.__logger.warning(
                "invalid the source in the project config, content:%s",
                content)
            content['source'] = {}
            content['source']['device_type'] = "Android"
            content['source']['platform'] = "Local"
            content['source']['long_edge'] = 1280

        self.__logger.info("the project config is %s, project_config_path:%s",
                           content, project_config_path)
        source = content['source']
        source_res_message = util.create_source_response(source)
        self.__logger.info(
            "send the source message to mc, source_res_message:%s",
            source_res_message)

        # Send the device source info to MC, which caches it.
        if self.__connect.SendMsg(source_res_message,
                                  BusConnect.PEER_NODE_MC) == 0:
            self.__logger.info("send the source info to mc service success")
            return True

        self.__logger.warning(
            "send the source info to mc service failed, please check")
        return False

    def _RegisterService(self):
        """Send a PB_SERVICE_REGISTER message to MC."""
        regMsg = common_pb2.tagMessage()
        regMsg.eMsgID = common_pb2.MSG_SERVICE_REGISTER
        regMsg.stServiceRegister.eRegisterType = common_pb2.PB_SERVICE_REGISTER
        regMsg.stServiceRegister.eServiceType = common_pb2.PB_SERVICE_TYPE_AI

        if self.__connect.SendMsg(regMsg, BusConnect.PEER_NODE_MC) == 0:
            return True
        return False

    def _SendTaskReport(self, reportCode):
        """Send an MSG_TASK_REPORT with *reportCode* to MC."""
        taskMsg = common_pb2.tagMessage()
        taskMsg.eMsgID = common_pb2.MSG_TASK_REPORT
        taskMsg.stTaskReport.eTaskStatus = reportCode

        if self.__connect.SendMsg(taskMsg, BusConnect.PEER_NODE_MC) == 0:
            return True
        return False

    def _UnRegisterService(self):
        """Send a PB_SERVICE_UNREGISTER message to MC."""
        unRegMsg = common_pb2.tagMessage()
        unRegMsg.eMsgID = common_pb2.MSG_SERVICE_REGISTER
        unRegMsg.stServiceRegister.eRegisterType = common_pb2.PB_SERVICE_UNREGISTER
        unRegMsg.stServiceRegister.eServiceType = common_pb2.PB_SERVICE_TYPE_AI

        if self.__connect.SendMsg(unRegMsg, BusConnect.PEER_NODE_MC) == 0:
            return True
        return False
예제 #10
0
 def __init__(self):
     """Create the shared 'agent' logger and the bus connection object."""
     self.__logger = logging.getLogger('agent')
     self.__connect = BusConnect()
예제 #11
0
 def __init__(self):
     """Start uninitialized and create the bus connection object."""
     # Presumably set to True by a later Initialize() step - confirm.
     self.__initialized = False
     self.__connect = BusConnect()
예제 #12
0
class GameEnv(object):
    """
    Game environment interface class, define abstract interface.

    Opens a bus connection to MC in ``__init__`` and uses it to send action
    and agent-state messages. Also provides helpers that move click actions
    onto positions found by scene tasks.
    """

    # NOTE(review): ``__metaclass__`` only takes effect on Python 2. Under
    # Python 3 it is a plain class attribute, so GameEnv is not an ABC and
    # the @abstractmethod markers below are not enforced.
    __metaclass__ = ABCMeta

    def __init__(self):
        self.logger = logging.getLogger('agent')
        self.__connect = BusConnect()

        if self.__connect.Connect() is not True:
            self.logger.error('Game env connect failed.')
            raise Exception('Game env connect failed.')

    def SendAction(self, actionMsg):
        """
        Send action msg to MC to do action.

        :return: the value returned by BusConnect.SendMsg
        """
        return self.__connect.SendMsg(actionMsg, BusConnect.PEER_NODE_MC)

    def UpdateEnvState(self, stateCode, stateDescription):
        """
        Send an MSG_AGENT_STATE msg to MC when the state changes.

        :return: the value returned by BusConnect.SendMsg
        """
        stateMsg = common_pb2.tagMessage()
        stateMsg.eMsgID = common_pb2.MSG_AGENT_STATE
        stateMsg.stAgentState.eAgentState = int(stateCode)
        stateMsg.stAgentState.strAgentState = stateDescription
        return self.__connect.SendMsg(stateMsg, BusConnect.PEER_NODE_MC)

    @abstractmethod
    def Init(self):
        """
        Abstract interface, Init game env object.
        """
        raise NotImplementedError()

    @abstractmethod
    def Finish(self):
        """
        Abstract interface, Exit game env object.
        """
        raise NotImplementedError()

    @abstractmethod
    def GetActionSpace(self):
        """
        Abstract interface, return number of game action.
        """
        raise NotImplementedError()

    @abstractmethod
    def DoAction(self, action, *args, **kwargs):
        """
        Abstract interface, do game action in game env.
        """
        raise NotImplementedError()

    @abstractmethod
    def StopAction(self):
        """
        Abstract interface, stop game action when receive special msg or signal.

        The default implementation only logs.
        """
        self.logger.info("execute the default stop action")

    @abstractmethod
    def RestartAction(self):
        """
        Abstract interface, restart output game action when receive special msg or signal.

        The default implementation only logs.
        """
        self.logger.info("execute the default restart action")

    @abstractmethod
    def GetState(self):
        """
        Abstract interface, return game state usually means game image or game data.
        """
        raise NotImplementedError()

    @abstractmethod
    def Reset(self):
        """
        Abstract interface, reset data or state in game env.

        The default implementation only logs.
        """
        self.logger.info("execute the default reset action")

    @abstractmethod
    def IsEpisodeStart(self):
        """
        Abstract interface, check whether episode start or not (default False).
        """
        return False

    @abstractmethod
    def IsEpisodeOver(self):
        """
        Abstract interface, check whether episode over or not (default True).
        """
        return True

    def update_scene_task(self, result_dict, action_dict, agent_api):
        """
        Move click actions onto the positions found by their scene tasks.

        For each 'click' action whose ``sceneTask`` has a result with a
        positive ROI origin in *result_dict*, the ROI centre is stored in the
        action context (``updateBtn``/``updateBtnX``/``updateBtnY``) and that
        task is disabled. Every other task key of *result_dict* is enabled,
        and the flags are sent through *agent_api*.

        :param result_dict: task id -> list of results, each with an 'ROI'
                            dict holding x/y/w/h
        :param action_dict: action id -> action context dict (mutated in place)
        :param agent_api: object whose SendCmd() delivers the task flags
        """
        total_task = list(result_dict.keys())
        disable_task = list()

        for action_id, action_context in action_dict.items():
            self.logger.debug(
                "update the action action_id: %s, action_context: %s",
                str(action_id), str(action_context))
            scene_task_id = action_context.get('sceneTask')

            if scene_task_id is None:
                self.logger.debug("the action has no scene task, action_id:%s",
                                  str(action_id))
                continue

            if action_context['type'] == 'click':
                flag, px, py = self._get_position(result_dict, scene_task_id)
                self.logger.debug("get result of scene task id is %s, %s, %s",
                                  str(flag), str(px), str(py))
                if flag is True:
                    action_context['updateBtn'] = True
                    action_context['updateBtnX'] = px
                    action_context['updateBtnY'] = py
                    disable_task.append(scene_task_id)

        # Send the task flags to GameReg. A set gives O(1) membership tests.
        disabled = set(disable_task)
        enable_task = [task for task in total_task if task not in disabled]
        self.logger.debug("the enable_task is %s and disable_task is %s",
                          str(enable_task), str(disable_task))
        self.__send_update_task(disable_task, enable_task, agent_api)

    @staticmethod
    def _get_position(result_dict, task_id):
        """
        Return ``(found, px, py)`` for the first result of *task_id* whose ROI
        has x > 0 and y > 0. px/py are the integer ROI centre; otherwise
        ``(False, -1, -1)``.
        """
        state = False
        px = -1
        py = -1

        reg_results = result_dict.get(task_id)
        if reg_results is None:
            return state, px, py

        for result in reg_results:
            x = result['ROI']['x']
            y = result['ROI']['y']
            w = result['ROI']['w']
            h = result['ROI']['h']

            if x > 0 and y > 0:
                state = True
                px = int(x + w / 2)
                py = int(y + h / 2)
                break

        return state, px, py

    def __send_update_task(self, disable_task, enable_task, agent_api):
        """
        Send MSG_SEND_TASK_FLAG with disabled tasks False and enabled True.

        :return: True when SendCmd reported success, False otherwise
        """
        task_flag_dict = dict()
        for taskID in disable_task:
            task_flag_dict[taskID] = False

        for taskID in enable_task:
            task_flag_dict[taskID] = True

        ret = agent_api.SendCmd(AgentAPIMgr.MSG_SEND_TASK_FLAG, task_flag_dict)

        if not ret:
            self.logger.error(
                'AgentAPI MSG_SEND_TASK_FLAG failed, task_flag_dict:{}'.format(
                    task_flag_dict))
            return False

        self.logger.debug(
            "AgentAPI MSG_SEND_TASK_FLAG success, task_flag_dict:{}".format(
                task_flag_dict))
        return True
예제 #13
0
class IMTrainFrameWork(object):
    """
    Imitation learning train framework.

    Holds a tbus connection to the manager center and a MainImitationLearning
    trainer. Init() connects and registers the AI service, Train() runs the
    trainer, and Finish() unregisters the service and closes the connection.
    """

    def __init__(self):
        self.__logger = logging.getLogger('agent')
        self.__connect = BusConnect()
        self.__imTrain = MainImitationLearning()

    def Init(self):
        """
        Init tbus connect, register service to manager center.

        :return: True on success, False if connecting or registering fails.
                 On a registration failure a PB_TASK_INIT_FAILURE task report
                 is also sent (best effort; its result is not checked).
        """
        # Only a literal True from Connect() counts as success.
        if self.__connect.Connect() is not True:
            self.__logger.error('Agent connect failed.')
            return False

        if not self._RegisterService():
            self.__logger.error('Agent register service failed.')
            self._SendTaskReport(common_pb2.PB_TASK_INIT_FAILURE)
            return False

        return True

    def Train(self):
        """
        Load samples and train im model: generate image samples first, then
        train the network.
        """
        self.__imTrain.GenerateImageSamples()
        self.__imTrain.TrainNetwork()

    def Finish(self):
        """
        Unregister the service, then close the tbus connection.
        """
        # NOTE(review): runs even if Init() failed; the unregister result is
        # ignored.
        self._UnRegisterService()
        self.__connect.Close()

    def _SendMsg(self, msg):
        """
        Send msg over the tbus connection.

        :return: True if SendMsg returned 0, False otherwise
        """
        return self.__connect.SendMsg(msg) == 0

    def _SendTaskReport(self, reportCode):
        """
        Send a MSG_TASK_REPORT message with reportCode as its eTaskStatus.

        :return: True if the message was sent successfully
        """
        taskMsg = common_pb2.tagMessage()
        taskMsg.eMsgID = common_pb2.MSG_TASK_REPORT
        taskMsg.stTaskReport.eTaskStatus = reportCode
        return self._SendMsg(taskMsg)

    def _RegisterService(self):
        """
        Register the AI service with the manager center.

        :return: True if the register message was sent successfully
        """
        return self._SendServiceRegister(common_pb2.PB_SERVICE_REGISTER)

    def _UnRegisterService(self):
        """
        Unregister the AI service from the manager center.

        :return: True if the unregister message was sent successfully
        """
        return self._SendServiceRegister(common_pb2.PB_SERVICE_UNREGISTER)

    def _SendServiceRegister(self, registerType):
        """
        Build and send a MSG_SERVICE_REGISTER message for an AI service.

        :param registerType: common_pb2.PB_SERVICE_REGISTER or
                             common_pb2.PB_SERVICE_UNREGISTER
        :return: True if the message was sent successfully
        """
        regMsg = common_pb2.tagMessage()
        regMsg.eMsgID = common_pb2.MSG_SERVICE_REGISTER
        regMsg.stServiceRegister.eRegisterType = registerType
        regMsg.stServiceRegister.eServiceType = common_pb2.PB_SERVICE_TYPE_AI
        return self._SendMsg(regMsg)
예제 #14
0
class GameEnv(object):
    """
    Base interface for a game environment.

    Construction opens a BusConnect connection and raises if it cannot be
    established. SendAction and UpdateEnvState are concrete helpers that push
    messages over that connection; every other method is a hook that a
    concrete environment is expected to override.

    NOTE(review): the ``__metaclass__`` attribute is only honoured by
    Python 2. On Python 3 it has no effect, so the ``@abstractmethod``
    markers below are not enforced when a subclass is instantiated.
    """

    __metaclass__ = ABCMeta

    def __init__(self):
        self.logger = logging.getLogger('agent')
        # Stored name-mangled as _GameEnv__connect; subclasses cannot reach
        # it as self.__connect.
        self.__connect = BusConnect()

        failure = 'Game env connect failed.'
        if self.__connect.Connect() is True:
            return
        self.logger.error(failure)
        raise Exception(failure)

    def SendAction(self, actionMsg):
        """
        Forward an already-built action message to MC.

        :return: the raw SendMsg result (0 appears to mean success)
        """
        connection = self.__connect
        return connection.SendMsg(actionMsg)

    def UpdateEnvState(self, stateCode, stateDescription):
        """
        Report an agent state change to MC as a MSG_AGENT_STATE message.

        :param stateCode: state value, converted with int()
        :param stateDescription: human-readable state text
        :return: the raw SendMsg result
        """
        msg = common_pb2.tagMessage()
        msg.eMsgID = common_pb2.MSG_AGENT_STATE
        agentState = msg.stAgentState
        agentState.eAgentState = int(stateCode)
        agentState.strAgentState = stateDescription
        return self.__connect.SendMsg(msg)

    @abstractmethod
    def Init(self):
        """
        Hook: initialise the concrete game env. The base version raises
        NotImplementedError.
        """
        raise NotImplementedError

    @abstractmethod
    def Finish(self):
        """
        Hook: shut the concrete game env down. The base version raises
        NotImplementedError.
        """
        raise NotImplementedError

    @abstractmethod
    def GetActionSpace(self):
        """
        Hook: return the number of game actions. The base version raises
        NotImplementedError.
        """
        raise NotImplementedError

    @abstractmethod
    def DoAction(self, action, *args, **kwargs):
        """
        Hook: perform a game action in the env. The base version raises
        NotImplementedError.
        """
        raise NotImplementedError

    @abstractmethod
    def StopAction(self):
        """
        Hook: stop emitting game actions (e.g. on a special msg or signal).
        The base version does nothing.
        """
        return None

    @abstractmethod
    def RestartAction(self):
        """
        Hook: resume emitting game actions (e.g. on a special msg or signal).
        The base version does nothing.
        """
        return None

    @abstractmethod
    def GetState(self):
        """
        Hook: return the current game state (typically an image or game
        data). The base version raises NotImplementedError.
        """
        raise NotImplementedError

    @abstractmethod
    def Reset(self):
        """
        Hook: reset data or state held by the env. The base version does
        nothing.
        """
        return None

    @abstractmethod
    def IsEpisodeStart(self):
        """
        Hook: report whether an episode has started. The base version always
        answers False.
        """
        return False

    @abstractmethod
    def IsEpisodeOver(self):
        """
        Hook: report whether the episode is over. The base version always
        answers True.
        """
        return True