def __init__(self):
    """Create the 'agent' logger and connect the bus client.

    Raises:
        Exception: if BusConnect.Connect() does not return exactly True.
    """
    self.logger = logging.getLogger('agent')
    self.__connect = BusConnect()
    failMsg = 'Game env connect failed.'
    if self.__connect.Connect() is True:
        return
    self.logger.error(failMsg)
    raise Exception(failMsg)
class ProgressReport(object):
    """
    Sends imitation-learning training progress to MC over the bus.
    """

    def __init__(self):
        self.__logger = logging.getLogger('agent')
        self.__connect = BusConnect()

    def Init(self):
        """
        Open the bus connection used for progress reports.

        :return: True when Connect() returns exactly True, False otherwise
        """
        connected = self.__connect.Connect()
        if connected is True:
            return True
        self.__logger.error('Agent connect failed.')
        return False

    def SendTrainProgress(self, progress):
        """
        Send one MSG_IM_TRAIN_STATE message carrying ``progress`` to MC.

        :param progress: progress value stored in stIMTrainState.nProgress
        :return: True when SendMsg returns 0, False otherwise
        """
        self.__logger.info('im train progress: %d', progress)
        stateMsg = common_pb2.tagMessage()
        stateMsg.eMsgID = common_pb2.MSG_IM_TRAIN_STATE
        stateMsg.stIMTrainState.nProgress = progress
        sendRet = self.__connect.SendMsg(stateMsg, BusConnect.PEER_NODE_MC)
        return bool(sendRet == 0)
class ActionMgr(object):
    """
    ActionMgr implement for remote action: packs action data with msgpack and
    sends it over the bus connection as a MSG_AI_ACTION message.
    """

    def __init__(self):
        self.__initialized = False
        self.__connect = BusConnect()

    def Initialize(self):
        """
        Initialize this module, init bus connection

        :return: the value returned by BusConnect.Connect()
        """
        ret = self.__connect.Connect()
        # Only mark the module as usable when the connection really succeeded,
        # so SendAction() keeps refusing to send over a failed connection.
        self.__initialized = ret is True
        return ret

    def Finish(self):
        """
        Finish this module, tbus disconnect (only if initialized)
        """
        if self.__initialized:
            LOG.info('Close connection...')
            self.__connect.Close()
            self.__initialized = False

    def SendAction(self, actionID, actionData, frameSeq=-1):
        """
        Send action to remote(client)

        :param actionID: the self-defined action ID
        :param actionData: the context data of the action ID (a dict); it is
                           copied, the caller's dict is not modified
        :param frameSeq: the frame sequence, default is -1
        :return: True if SendMsg returned 0, False otherwise (including when
                 the module is not initialized)
        """
        if not self.__initialized:
            LOG.warning('Call Initialize first!')
            return False

        # Shallow copy so msg_id/action_id are not written into the caller's dict.
        payload = dict(actionData)
        payload['msg_id'] = MSG_ID_AI_ACTION
        payload['action_id'] = actionID
        actionBuff = msgpack.packb(payload, default=mn.encode, use_bin_type=True)

        msg = common_pb2.tagMessage()
        msg.eMsgID = common_pb2.MSG_AI_ACTION
        msg.stAIAction.nFrameSeq = frameSeq
        msg.stAIAction.byAIActionBuff = actionBuff

        # isEnabledFor() honours the effective level; the old `level <= DEBUG`
        # test was always true for a NOTSET (level 0) logger.
        # default=str keeps values that msgpack only packs via mn.encode from
        # making json.dumps raise inside a debug-log statement.
        if LOG_REGACTION.isEnabledFor(logging.DEBUG):
            actionStr = json.dumps(payload, default=str)
            LOG_REGACTION.debug('{}||action||{}'.format(frameSeq, actionStr))

        ret = self.__connect.SendMsg(msg)
        if ret != 0:
            LOG.warning('TBus Send To MC return code[{0}]'.format(ret))
            return False
        return True
def __init__(self):
    """Set the default state of the plugin-driven AI framework.

    The AI model, agent env and run function start as None and are not
    created here.
    """
    self.__logger = logging.getLogger('agent')
    # placeholders; populated later, outside the constructor
    self.__aiModel = self.__agentEnv = self.__runAIFunc = None
    self.__aiPlugin = AIPlugin()
    self.__outputAIAction = True  # AI actions are emitted by default
    self.__connect = BusConnect()
def __init__(self):
    """Set the default state of the config-driven AI framework."""
    # plugin switches start at their defaults (presumably overridden once a
    # plugin config is read -- confirm against the caller)
    self.__usePluginEnv, self.__usePluginAIModel = False, False
    self.__useDefaultRunFunc = True
    self.__logger = logging.getLogger('agent')
    self.__aiModel = self.__agentEnv = self.__RunAIFunc = None
    self.__outputAIAction = True
    self.__connect = BusConnect()
class AIFrameWork(object):
    """
    Agent AI framework, run AI test.

    The agent env, AI model and (optionally) a custom run function are loaded
    dynamically from the package/module/class names in the agentai config.
    """

    def __init__(self):
        self.__usePluginEnv = False
        self.__usePluginAIModel = False
        self.__useDefaultRunFunc = True
        self.__logger = logging.getLogger('agent')
        self.__aiModel = None
        self.__agentEnv = None
        self.__RunAIFunc = None
        self.__outputAIAction = True
        self.__connect = BusConnect()

    def Init(self):
        """
        Init tbus connect, register service, create ai & env object

        :return: True on success, False otherwise. A PB_TASK_INIT_FAILURE task
                 report is sent when service registration or AI task init
                 fails; PB_TASK_INIT_SUCCESS when AI task init succeeds.
        """
        if self.__connect.Connect() is not True:
            self.__logger.error('Agent connect failed.')
            return False

        if self._RegisterService() is not True:
            self.__logger.error('Agent register service failed.')
            self._SendTaskReport(common_pb2.PB_TASK_INIT_FAILURE)
            return False

        if self._InitAITask() is True:
            self._SendTaskReport(common_pb2.PB_TASK_INIT_SUCCESS)
        else:
            self._SendTaskReport(common_pb2.PB_TASK_INIT_FAILURE)
            return False

        return True

    def _InitAITask(self):
        """Load the plugin config, create env/model/run function and init them."""
        if self._GetPluginInConfig() is not True:
            return False

        self._CreateAgentEnvObj()
        self._CreateAIModelObj()
        if self.__aiModel is None or self.__agentEnv is None:
            self.__logger.error('Create agent env or aimodel object failed.')
            return False

        self._CreateRunFunc()
        if self.__useDefaultRunFunc is not True and self.__RunAIFunc is None:
            self.__logger.error('Create run function failed.')
            return False

        if self.__agentEnv.Init() is not True:
            self.__logger.error('Agent env init failed.')
            return False

        if self.__aiModel.Init(self.__agentEnv) is not True:
            self.__logger.error('AI model init failed.')
            return False

        return True

    def Finish(self):
        """
        Disconnect tbus, unregister service, release ai & env object
        """
        self._UnRegisterService()
        if self.__aiModel is not None:
            self.__aiModel.Finish()

        if self.__agentEnv is not None:
            self.__agentEnv.Finish()

        self.__connect.Close()

    def Run(self, isTestMode=True):
        """
        Main framework, run AI test (custom run function if one was loaded,
        otherwise the default episode loop)
        """
        if self.__RunAIFunc:
            self.__RunAIFunc(self.__agentEnv, self.__aiModel, isTestMode)
        else:
            self._DefaultRun(isTestMode)

    def StopAIAction(self):
        """
        Stop ai action when receive signal or msg
        """
        self.__outputAIAction = False
        self.__agentEnv.StopAction()
        self.__logger.info('Stop ai action')
        self.__agentEnv.UpdateEnvState(ENV_STATE_PAUSE_PLAYING,
                                       'Pause ai playing')

    def RestartAIAction(self):
        """
        Restart ai action when receive signal or msg
        """
        self.__outputAIAction = True
        self.__agentEnv.RestartAction()
        self.__logger.info('Restart ai action')
        self.__agentEnv.UpdateEnvState(ENV_STATE_RESTORE_PLAYING,
                                       'Resume ai playing')

    def _DefaultRun(self, isTestMode):
        # Never returns: loops over episodes forever.
        self.__agentEnv.UpdateEnvState(ENV_STATE_WAITING_START,
                                       'Wait episode start')
        while True:
            # wait new episode start
            self._WaitEpisode()
            self.__logger.info('Episode start')
            self.__agentEnv.UpdateEnvState(ENV_STATE_PLAYING,
                                           'Episode start, ai playing')
            self.__aiModel.OnEpisodeStart()

            # run episode according to AI, until episode over
            self._RunEpisode(isTestMode)
            self.__logger.info('Episode over')
            self.__agentEnv.UpdateEnvState(ENV_STATE_OVER, 'Episode over')
            self.__aiModel.OnEpisodeOver()
            self.__agentEnv.UpdateEnvState(ENV_STATE_WAITING_START,
                                           'Wait episode start')

    def _WaitEpisode(self):
        # Poll bus messages until the env reports that an episode started.
        while True:
            self._HandleMsg()
            if self.__agentEnv.IsEpisodeStart() is True:
                break

            time.sleep(0.001)

    def _RunEpisode(self, isTestMode):
        # Step the AI (or just poll state while paused) until game-over msg
        # or the env reports the episode is over.
        while True:
            if self.__outputAIAction is True:
                if isTestMode is True:
                    self.__aiModel.TestOneStep()
                else:
                    self.__aiModel.TrainOneStep()
            else:
                self.__agentEnv.GetState()
                time.sleep(0.01)

            msgID = self._HandleMsg()
            if msgID == common_pb2.MSG_UI_GAME_OVER:
                break

            if self.__agentEnv.IsEpisodeOver() is True:
                break

    def _HandleMsg(self):
        # Receive at most one message and dispatch game start/over events.
        msg = self.__connect.RecvMsg()
        if msg is None:
            return common_pb2.MSG_NONE

        msgID = msg.eMsgID
        if msgID == common_pb2.MSG_UI_GAME_START:
            self.__logger.info('Enter new episode...')
            self.__aiModel.OnEnterEpisode()
        elif msgID == common_pb2.MSG_UI_GAME_OVER:
            self.__logger.info('Leave episode')
            self.__aiModel.OnLeaveEpisode()
        else:
            self.__logger.info('Unknown msg id')

        return msgID

    def _GetPluginInConfig(self):
        # Prefer the current config file, fall back to the old one.
        pluginCfgPath = util.ConvertToSDKFilePath(PLUGIN_CFG_FILE)
        if os.path.exists(pluginCfgPath):
            return self._LoadPlugInParams(pluginCfgPath)

        oldPluginCfgPath = util.ConvertToSDKFilePath(OLD_PLUGIN_CFG_FILE)
        if os.path.exists(oldPluginCfgPath):
            return self._LoadOldPlugInParams(oldPluginCfgPath)

        self.__logger.error(
            'agentai cfg file {0} not exist.'.format(pluginCfgPath))
        return False

    def _LoadPlugInParams(self, pluginCfgPath):
        """Load plugin params, accepting CamelCase or upper-case section names.

        :return: True on success, False (after logging) on any read error
        """
        try:
            config = configparser.ConfigParser()
            config.read(pluginCfgPath)
            # CamelCase section names win when present.
            envSection = 'AgentEnv' if config.has_section('AgentEnv') else 'AGENT_ENV'
            aiSection = 'AIModel' if config.has_section('AIModel') else 'AI_MODEL'
            runSection = 'RunFunc' if config.has_section('RunFunc') else 'RUN_FUNCTION'
            self._ReadPlugInParams(config, envSection, aiSection, runSection)
        except Exception as e:
            self.__logger.error('Load file {} failed, error: {}.'.format(
                pluginCfgPath, e))
            return False

        return True

    def _LoadOldPlugInParams(self, pluginCfgPath):
        """Load plugin params from the old-format file (CamelCase sections only).

        :return: True on success, False (after logging) on any read error
        """
        try:
            config = configparser.ConfigParser()
            config.read(pluginCfgPath)
            self._ReadPlugInParams(config, 'AgentEnv', 'AIModel', 'RunFunc')
        except Exception as e:
            self.__logger.error('Load file {} failed, error: {}.'.format(
                pluginCfgPath, e))
            return False

        return True

    def _ReadPlugInParams(self, config, envSection, aiSection, runSection):
        """Copy env/AI-model/run-function settings from ``config`` onto self.

        Env and AI model entries are always required; run-function entries
        only when UseDefaultRunFunc is false. configparser raises on any
        missing section/option.
        """
        self.__usePluginEnv = config.getboolean(envSection, 'UsePluginEnv')
        self.__usePluginAIModel = config.getboolean(aiSection, 'UsePluginAIModel')
        self.__useDefaultRunFunc = config.getboolean(runSection, 'UseDefaultRunFunc')

        self.__envPackage = config.get(envSection, 'EnvPackage')
        self.__envModule = config.get(envSection, 'EnvModule')
        self.__envClass = config.get(envSection, 'EnvClass')

        self.__aiModelPackage = config.get(aiSection, 'AIModelPackage')
        self.__aiModelModule = config.get(aiSection, 'AIModelModule')
        self.__aiModelClass = config.get(aiSection, 'AIModelClass')

        if self.__useDefaultRunFunc is not True:
            self.__runFuncPackage = config.get(runSection, 'RunFuncPackage')
            self.__runFuncModule = config.get(runSection, 'RunFuncModule')
            self.__runFuncName = config.get(runSection, 'RunFuncName')

    @staticmethod
    def _GetModuleAttr(packageName, moduleName, attrName):
        """Import ``packageName.moduleName`` and return its ``attrName``."""
        package = __import__('{0}.{1}'.format(packageName, moduleName))
        module = getattr(package, moduleName)
        return getattr(module, attrName)

    def _CreateAgentEnvObj(self):
        # The plugin search path stays on sys.path while the env is constructed
        # (as before), and is always removed again, even when loading fails.
        usePluginPath = self.__usePluginEnv is True
        if usePluginPath:
            sys.path.append('PlugIn/ai')
        try:
            envClass = self._GetModuleAttr(self.__envPackage, self.__envModule,
                                           self.__envClass)
            self.__logger.info('agent env class: {0}'.format(envClass))
            self.__agentEnv = envClass()
        except Exception:
            # Plugin boundary: report the failure through Init()'s normal path.
            self.__logger.exception('Create agent env object failed.')
            self.__agentEnv = None
        finally:
            if usePluginPath:
                sys.path.pop()

    def _CreateAIModelObj(self):
        if self.__usePluginAIModel is True:
            sys.path.append('PlugIn/ai')
        else:
            sys.path.append('AgentAI/aimodel')
        try:
            aiModelClass = self._GetModuleAttr(self.__aiModelPackage,
                                               self.__aiModelModule,
                                               self.__aiModelClass)
            self.__logger.info('aimodel class: {0}'.format(aiModelClass))
            self.__aiModel = aiModelClass()
        except Exception:
            self.__logger.exception('Create aimodel object failed.')
            self.__aiModel = None
        finally:
            sys.path.pop()

    def _CreateRunFunc(self):
        if self.__useDefaultRunFunc is True:
            return

        sys.path.append('PlugIn/ai')
        try:
            self.__RunAIFunc = self._GetModuleAttr(self.__runFuncPackage,
                                                   self.__runFuncModule,
                                                   self.__runFuncName)
            self.__logger.info('run function: {0}'.format(self.__RunAIFunc))
        except Exception:
            self.__logger.exception('Create run function failed.')
            self.__RunAIFunc = None
        finally:
            sys.path.pop()

    def _RegisterService(self):
        regMsg = common_pb2.tagMessage()
        regMsg.eMsgID = common_pb2.MSG_SERVICE_REGISTER
        regMsg.stServiceRegister.eRegisterType = common_pb2.PB_SERVICE_REGISTER
        regMsg.stServiceRegister.eServiceType = common_pb2.PB_SERVICE_TYPE_AI

        if self.__connect.SendMsg(regMsg) == 0:
            return True
        return False

    def _SendTaskReport(self, reportCode):
        taskMsg = common_pb2.tagMessage()
        taskMsg.eMsgID = common_pb2.MSG_TASK_REPORT
        taskMsg.stTaskReport.eTaskStatus = reportCode

        if self.__connect.SendMsg(taskMsg) == 0:
            return True
        return False

    def _UnRegisterService(self):
        unRegMsg = common_pb2.tagMessage()
        unRegMsg.eMsgID = common_pb2.MSG_SERVICE_REGISTER
        unRegMsg.stServiceRegister.eRegisterType = common_pb2.PB_SERVICE_UNREGISTER
        unRegMsg.stServiceRegister.eServiceType = common_pb2.PB_SERVICE_TYPE_AI

        if self.__connect.SendMsg(unRegMsg) == 0:
            return True
        return False
def __init__(self):
    """Build the logger, the bus client and the imitation-learning trainer."""
    agentLogger = logging.getLogger('agent')
    self.__logger = agentLogger
    self.__connect = BusConnect()
    self.__imTrain = MainImitationLearning()
class IMTrainFrameWork(object):
    """
    Imitation learning train framework
    """

    def __init__(self):
        self.__logger = logging.getLogger('agent')
        self.__connect = BusConnect()
        self.__imTrain = MainImitationLearning()

    def Init(self):
        """
        Init tbus connect, register service to manager center and send the
        device source info.

        :return: True on success, False otherwise (a PB_TASK_INIT_FAILURE task
                 report is sent when registration or the source info fails)
        """
        if self.__connect.Connect() is not True:
            self.__logger.error('Agent connect failed.')
            return False

        if self._RegisterService() is not True:
            self.__logger.error('Agent register service failed.')
            self._SendTaskReport(common_pb2.PB_TASK_INIT_FAILURE)
            return False

        if self._send_resource_info() is not True:
            self.__logger.error('send the source info failed.')
            self._SendTaskReport(common_pb2.PB_TASK_INIT_FAILURE)
            return False

        return True

    def Train(self):
        """
        Load samples and train im model
        """
        self.__imTrain.GenerateImageSamples()
        self.__imTrain.TrainNetwork()

    def Finish(self):
        """
        Disconnect tbus, unregister service
        """
        self._UnRegisterService()
        self.__connect.Close()

    def _SendTaskReport(self, reportCode):
        taskMsg = common_pb2.tagMessage()
        taskMsg.eMsgID = common_pb2.MSG_TASK_REPORT
        taskMsg.stTaskReport.eTaskStatus = reportCode

        if self.__connect.SendMsg(taskMsg, BusConnect.PEER_NODE_MC) == 0:
            return True
        return False

    def _RegisterService(self):
        regMsg = common_pb2.tagMessage()
        regMsg.eMsgID = common_pb2.MSG_SERVICE_REGISTER
        regMsg.stServiceRegister.eRegisterType = common_pb2.PB_SERVICE_REGISTER
        regMsg.stServiceRegister.eServiceType = common_pb2.PB_SERVICE_TYPE_AI

        if self.__connect.SendMsg(regMsg, BusConnect.PEER_NODE_MC) == 0:
            return True
        return False

    def _UnRegisterService(self):
        unRegMsg = common_pb2.tagMessage()
        unRegMsg.eMsgID = common_pb2.MSG_SERVICE_REGISTER
        unRegMsg.stServiceRegister.eRegisterType = common_pb2.PB_SERVICE_UNREGISTER
        unRegMsg.stServiceRegister.eServiceType = common_pb2.PB_SERVICE_TYPE_AI

        if self.__connect.SendMsg(unRegMsg, BusConnect.PEER_NODE_MC) == 0:
            return True
        return False

    def _send_resource_info(self):
        """
        Send the 'source' section of the project config to MC.

        :return: True if the message was sent, False if the config has no
                 usable 'source' or the send failed
        :raises Exception: if AI_SDK_PROJECT_FILE_PATH is unset or empty
        """
        project_config_path = os.environ.get('AI_SDK_PROJECT_FILE_PATH')
        self.__logger.info('send source info to mc, project_path: %s',
                           project_config_path)
        if not project_config_path:
            raise Exception('environ var(AI_SDK_PROJECT_FILE_PATH) is invalid')

        content = get_configure(project_config_path)
        # An empty/None config or a missing 'source' key is reported as a
        # failure instead of raising TypeError/KeyError.
        if not content or content.get('source') is None:
            self.__logger.info("invalid the source in the project config, content: %s", content)
            return False

        self.__logger.info("the project config is %s, project_config_path: %s",
                           str(content), project_config_path)
        source = content['source']
        source_res_message = create_source_response(source)
        if self.__connect.SendMsg(source_res_message, BusConnect.PEER_NODE_MC) == 0:
            self.__logger.info("send the source info to mc service success")
            return True

        self.__logger.warning("send the source info to mc service failed, please check")
        return False
class AIFrameWork(object):
    """
    Agent AI framework, run AI test.

    Agent env, AI model and optional run function are created through
    AIPlugin.
    """

    def __init__(self):
        self.__logger = logging.getLogger('agent')
        self.__aiModel = None
        self.__agentEnv = None
        self.__runAIFunc = None
        self.__aiPlugin = AIPlugin()
        self.__outputAIAction = True
        self.__connect = BusConnect()

    def Init(self):
        """
        Init tbus connect, register service, send source info, create ai & env
        object.

        :return: True on success, False otherwise (a PB_TASK_INIT_FAILURE task
                 report is sent for every failure after the connect step)
        """
        if self.__connect.Connect() is not True:
            self.__logger.error('Agent connect failed.')
            return False

        if self._RegisterService() is not True:
            self.__logger.error('Agent register service failed.')
            self._SendTaskReport(common_pb2.PB_TASK_INIT_FAILURE)
            return False

        if self._send_resource_info() is not True:
            self.__logger.error('send the source info failed.')
            self._SendTaskReport(common_pb2.PB_TASK_INIT_FAILURE)
            return False

        if self._InitAIObject() is True:
            self.__logger.info(
                'AIFrameWork.Init, _SendTaskReport PB_TASK_INIT_SUCCESS to MC.'
            )
            self._SendTaskReport(common_pb2.PB_TASK_INIT_SUCCESS)
        else:
            self.__logger.error(
                'AIFrameWork.Init, _SendTaskReport PB_TASK_INIT_FAILURE to MC.'
            )
            self._SendTaskReport(common_pb2.PB_TASK_INIT_FAILURE)
            return False

        return True

    def _InitAIObject(self):
        """Create env, AI model and run function via AIPlugin, then init them."""
        if self.__aiPlugin.Init() is not True:
            return False

        self.__agentEnv = self.__aiPlugin.CreateAgentEnvObj()
        self.__aiModel = self.__aiPlugin.CreateAIModelObj()

        if self.__aiModel is None or self.__agentEnv is None:
            self.__logger.error('Create agent env or aimodel object failed.')
            return False

        self.__runAIFunc = self.__aiPlugin.CreateRunFunc()
        if self.__aiPlugin.UseDefaultRun() is not True and self.__runAIFunc is None:
            self.__logger.error('Create run function failed.')
            return False

        if self.__agentEnv.Init() is not True:
            self.__logger.error('Agent env init failed.')
            return False

        self.__logger.info('AIFrameWork.InitAIObject agentEnv init success')

        if self.__aiModel.Init(self.__agentEnv) is not True:
            self.__logger.error('AI model init failed.')
            return False

        self.__logger.info('AIFrameWork.InitAIObject aiModel init success')
        return True

    def Finish(self):
        """
        Disconnect tbus, unregister service, release ai & env object
        """
        self._UnRegisterService()
        if self.__aiModel is not None:
            self.__aiModel.Finish()

        if self.__agentEnv is not None:
            self.__agentEnv.Finish()

        self.__connect.Close()

    def Run(self, isTestMode=True):
        """
        Main framework, run AI test (custom run function if one was created,
        otherwise the default episode loop)
        """
        if self.__runAIFunc:
            # Use the 'agent' logger rather than the root logger.
            self.__logger.debug("execute the run ai func")
            self.__runAIFunc(self.__agentEnv, self.__aiModel, isTestMode)
        else:
            self._DefaultRun(isTestMode)

    def StopAIAction(self):
        """
        Stop ai action when receive signal or msg
        """
        self.__outputAIAction = False
        self.__agentEnv.StopAction()
        self.__logger.info('Stop ai action')
        self.__agentEnv.UpdateEnvState(ENV_STATE_PAUSE_PLAYING,
                                       'Pause ai playing')

    def RestartAIAction(self):
        """
        Restart ai action when receive signal or msg
        """
        self.__outputAIAction = True
        self.__agentEnv.RestartAction()
        self.__logger.info('Restart ai action')
        self.__agentEnv.UpdateEnvState(ENV_STATE_RESTORE_PLAYING,
                                       'Resume ai playing')

    def _DefaultRun(self, isTestMode):
        # Never returns: loops over episodes forever.
        self.__logger.debug("execute the default run")
        self.__agentEnv.UpdateEnvState(ENV_STATE_WAITING_START,
                                       'Wait episode start')
        while True:
            # wait new episode start
            self.__logger.debug("begin to wait the start")
            self._WaitEpisode()
            self.__logger.info('Episode start')
            self.__agentEnv.UpdateEnvState(ENV_STATE_PLAYING,
                                           'Episode start, ai playing')
            self.__aiModel.OnEpisodeStart()

            # run episode according to AI, until episode over
            self._RunEpisode(isTestMode)
            self.__logger.info('Episode over')
            self.__agentEnv.UpdateEnvState(ENV_STATE_OVER, 'Episode over')
            self.__aiModel.OnEpisodeOver()
            self.__agentEnv.UpdateEnvState(ENV_STATE_WAITING_START,
                                           'Wait episode start')

    def _WaitEpisode(self):
        # Poll MC messages until the env reports that an episode started.
        while True:
            self._HandleMsg()
            if self.__agentEnv.IsEpisodeStart() is True:
                break

            time.sleep(0.1)

    def _RunEpisode(self, isTestMode):
        # Step the AI (or just poll state while paused) until a game-over
        # message arrives or the env reports the episode is over.
        while True:
            if self.__outputAIAction is True:
                if isTestMode is True:
                    self.__aiModel.TestOneStep()
                else:
                    self.__aiModel.TrainOneStep()
            else:
                self.__agentEnv.GetState()
                time.sleep(0.01)

            msgID = self._HandleMsg()
            if msgID == common_pb2.MSG_UI_GAME_OVER:
                break

            if self.__agentEnv.IsEpisodeOver() is True:
                break

    def _HandleMsg(self):
        # Receive at most one message from MC and dispatch start/over events.
        msg = self.__connect.RecvMsg(BusConnect.PEER_NODE_MC)
        if msg is None:
            return common_pb2.MSG_NONE

        msgID = msg.eMsgID
        # Lazy %-args: the old '%s'.format(...) logged the placeholders verbatim.
        self.__logger.info('the message from mc is %s, msgID: %s', msg, msgID)
        if msgID == common_pb2.MSG_UI_GAME_START:
            self.__logger.info('Enter new episode...')
            self.__aiModel.OnEnterEpisode()
        elif msgID == common_pb2.MSG_UI_GAME_OVER:
            self.__logger.info('Leave episode')
            self.__aiModel.OnLeaveEpisode()
        else:
            self.__logger.info('Unknown msg id')

        return msgID

    def _send_resource_info(self):
        """
        Send the 'source' section of the project config to MC. A missing or
        empty 'source' is replaced by a default local Android source.

        :return: True if the message was sent, False if the config could not
                 be read or the send failed
        :raises Exception: if AI_SDK_PROJECT_FILE_PATH is unset or empty
        """
        project_config_path = os.environ.get('AI_SDK_PROJECT_FILE_PATH')
        self.__logger.info('send source info to mc, project_path:%s',
                           project_config_path)
        if not project_config_path:
            raise Exception('environ var(AI_SDK_PROJECT_FILE_PATH) is invalid')

        content = util.get_configure(project_config_path)
        if not content:
            self.__logger.warning(
                "failed to get project config content, file:%s",
                project_config_path)
            return False

        if not content.get('source'):
            self.__logger.warning(
                "invalid the source in the project config, content:%s", content)
            content['source'] = {}
            content['source']['device_type'] = "Android"
            content['source']['platform'] = "Local"
            content['source']['long_edge'] = 1280

        self.__logger.info(
            "the project config is {}, project_config_path:{}".format(
                content, project_config_path))
        source = content['source']
        source_res_message = util.create_source_response(source)
        self.__logger.info(
            "send the source message to mc, source_res_message:{}".format(
                source_res_message))
        # Send the device source info to MC, which caches it.
        if self.__connect.SendMsg(source_res_message,
                                  BusConnect.PEER_NODE_MC) == 0:
            self.__logger.info("send the source info to mc service success")
            return True

        self.__logger.warning(
            "send the source info to mc service failed, please check")
        return False

    def _RegisterService(self):
        regMsg = common_pb2.tagMessage()
        regMsg.eMsgID = common_pb2.MSG_SERVICE_REGISTER
        regMsg.stServiceRegister.eRegisterType = common_pb2.PB_SERVICE_REGISTER
        regMsg.stServiceRegister.eServiceType = common_pb2.PB_SERVICE_TYPE_AI

        if self.__connect.SendMsg(regMsg, BusConnect.PEER_NODE_MC) == 0:
            return True
        return False

    def _SendTaskReport(self, reportCode):
        taskMsg = common_pb2.tagMessage()
        taskMsg.eMsgID = common_pb2.MSG_TASK_REPORT
        taskMsg.stTaskReport.eTaskStatus = reportCode

        if self.__connect.SendMsg(taskMsg, BusConnect.PEER_NODE_MC) == 0:
            return True
        return False

    def _UnRegisterService(self):
        unRegMsg = common_pb2.tagMessage()
        unRegMsg.eMsgID = common_pb2.MSG_SERVICE_REGISTER
        unRegMsg.stServiceRegister.eRegisterType = common_pb2.PB_SERVICE_UNREGISTER
        unRegMsg.stServiceRegister.eServiceType = common_pb2.PB_SERVICE_TYPE_AI

        if self.__connect.SendMsg(unRegMsg, BusConnect.PEER_NODE_MC) == 0:
            return True
        return False
def __init__(self):
    """Keep the 'agent' logger and an (unconnected) bus client."""
    self.__logger, self.__connect = logging.getLogger('agent'), BusConnect()
def __init__(self):
    """Start uninitialised; the bus client is created but not connected here."""
    self.__initialized, self.__connect = False, BusConnect()
class GameEnv(object):
    """
    Game environment interface class, define abstract interface.

    Subclasses implement the abstract methods; this base class owns the bus
    connection used to send actions and agent state to MC.
    """
    # NOTE(review): ``__metaclass__`` is only honoured by Python 2. Under
    # Python 3 it is a plain class attribute, so @abstractmethod is not
    # enforced and subclasses can rely on the default bodies below.
    __metaclass__ = ABCMeta

    def __init__(self):
        self.logger = logging.getLogger('agent')
        self.__connect = BusConnect()
        if self.__connect.Connect() is not True:
            self.logger.error('Game env connect failed.')
            raise Exception('Game env connect failed.')

    def SendAction(self, actionMsg):
        """
        Send action msg to MC to do action
        """
        return self.__connect.SendMsg(actionMsg, BusConnect.PEER_NODE_MC)

    def UpdateEnvState(self, stateCode, stateDescription):
        """
        Send agent state msg to MC when state change
        """
        stateMsg = common_pb2.tagMessage()
        stateMsg.eMsgID = common_pb2.MSG_AGENT_STATE
        stateMsg.stAgentState.eAgentState = int(stateCode)
        stateMsg.stAgentState.strAgentState = stateDescription
        return self.__connect.SendMsg(stateMsg, BusConnect.PEER_NODE_MC)

    @abstractmethod
    def Init(self):
        """
        Abstract interface, Init game env object
        """
        raise NotImplementedError()

    @abstractmethod
    def Finish(self):
        """
        Abstract interface, Exit game env object
        """
        raise NotImplementedError()

    @abstractmethod
    def GetActionSpace(self):
        """
        Abstract interface, return number of game action
        """
        raise NotImplementedError()

    @abstractmethod
    def DoAction(self, action, *args, **kwargs):
        """
        Abstract interface, do game action in game env
        """
        raise NotImplementedError()

    @abstractmethod
    def StopAction(self):
        """
        Abstract interface, stop game action when receive special msg or signal
        """
        self.logger.info("execute the default stop action")

    @abstractmethod
    def RestartAction(self):
        """
        Abstract interface, restart output game action when receive special msg or signal
        """
        self.logger.info("execute the default restart action")

    @abstractmethod
    def GetState(self):
        """
        Abstract interface, return game state usually means game image or game data
        """
        raise NotImplementedError()

    @abstractmethod
    def Reset(self):
        """
        Abstract interface, reset date or state in game env
        """
        self.logger.info("execute the default reset action")

    @abstractmethod
    def IsEpisodeStart(self):
        """
        Abstract interface, check whether episode start or not
        """
        return False

    @abstractmethod
    def IsEpisodeOver(self):
        """
        Abstract interface, check whether episode over or not
        """
        return True

    def update_scene_task(self, result_dict, action_dict, agent_api):
        """
        Point 'click' actions at the position of their recognised scene task
        and tell GameReg which tasks to enable or disable.

        For every action of type 'click' whose 'sceneTask' has a valid result
        in ``result_dict``, the action context is updated in place with
        updateBtn/updateBtnX/updateBtnY and that task is disabled. Every other
        task key of ``result_dict`` is enabled.

        :param result_dict: recognition results keyed by task id
        :param action_dict: action contexts keyed by action id (mutated)
        :param agent_api: object exposing SendCmd(cmd, data)
        """
        total_task = list(result_dict.keys())
        disable_task = []
        for action_id, action_context in action_dict.items():
            self.logger.debug(
                "update the action action_id: %s, action_context: %s",
                str(action_id), str(action_context))
            scene_task_id = action_context.get('sceneTask')
            if scene_task_id is None:
                self.logger.debug("the action has no scene task, action_id:%s",
                                  str(action_id))
                continue

            # .get(): actions without a 'type' are skipped instead of raising.
            if action_context.get('type') != 'click':
                continue

            flag, px, py = self._get_position(result_dict, scene_task_id)
            self.logger.debug("get result of scene task id is %s, %s, %s",
                              str(flag), str(px), str(py))
            if flag is True:
                action_context['updateBtn'] = True
                action_context['updateBtnX'] = px
                action_context['updateBtnY'] = py
                disable_task.append(scene_task_id)

        # Send the task flags to GameReg; a set gives O(1) membership tests.
        disabled = set(disable_task)
        enable_task = [task for task in total_task if task not in disabled]
        self.logger.debug("the enable_task is %s and disable_task is %s",
                          str(enable_task), str(disable_task))
        self.__send_update_task(disable_task, enable_task, agent_api)

    @staticmethod
    def _get_position(result_dict, task_id):
        """
        Return (found, center_x, center_y) of the first result of ``task_id``
        whose ROI has x > 0 and y > 0, or (False, -1, -1) if there is none.
        """
        reg_results = result_dict.get(task_id)
        if reg_results is None:
            return False, -1, -1

        for result in reg_results:
            roi = result['ROI']
            x, y, w, h = roi['x'], roi['y'], roi['w'], roi['h']
            if x > 0 and y > 0:
                return True, int(x + w / 2), int(y + h / 2)

        return False, -1, -1

    def __send_update_task(self, disable_task, enable_task, agent_api):
        # Disabled ids first, enabled ids second, so an id in both lists ends
        # up enabled (same precedence as the original two loops).
        task_flag_dict = dict.fromkeys(disable_task, False)
        task_flag_dict.update(dict.fromkeys(enable_task, True))

        ret = agent_api.SendCmd(AgentAPIMgr.MSG_SEND_TASK_FLAG, task_flag_dict)
        if not ret:
            self.logger.error(
                'AgentAPI MSG_SEND_TASK_FLAG failed, task_flag_dict:{}'.format(
                    task_flag_dict))
            return False

        self.logger.debug(
            "AgentAPI MSG_SEND_TASK_FLAG success, task_flag_dict:{}".format(
                task_flag_dict))
        return True
class IMTrainFrameWork(object):
    """
    Imitation-learning training framework: registers with the manager
    center over the bus and drives sample generation plus network training.
    """

    def __init__(self):
        self.__logger = logging.getLogger('agent')
        self.__connect = BusConnect()
        self.__imTrain = MainImitationLearning()

    def Init(self):
        """
        Connect the bus and register the AI service.

        :return: True on success; False otherwise (a PB_TASK_INIT_FAILURE
                 report is sent when registration fails)
        """
        if self.__connect.Connect() is not True:
            self.__logger.error('Agent connect failed.')
            return False

        if self._RegisterService() is True:
            return True

        self.__logger.error('Agent register service failed.')
        self._SendTaskReport(common_pb2.PB_TASK_INIT_FAILURE)
        return False

    def Train(self):
        """
        Generate image samples, then train the network.
        """
        trainer = self.__imTrain
        trainer.GenerateImageSamples()
        trainer.TrainNetwork()

    def Finish(self):
        """
        Unregister the service and close the bus connection.
        """
        self._UnRegisterService()
        self.__connect.Close()

    def _SendTaskReport(self, reportCode):
        taskMsg = common_pb2.tagMessage()
        taskMsg.eMsgID = common_pb2.MSG_TASK_REPORT
        taskMsg.stTaskReport.eTaskStatus = reportCode
        return bool(self.__connect.SendMsg(taskMsg) == 0)

    def _RegisterService(self):
        return self._SendServiceRegister(common_pb2.PB_SERVICE_REGISTER)

    def _UnRegisterService(self):
        return self._SendServiceRegister(common_pb2.PB_SERVICE_UNREGISTER)

    def _SendServiceRegister(self, registerType):
        """Send a MSG_SERVICE_REGISTER message of ``registerType`` for the AI service."""
        regMsg = common_pb2.tagMessage()
        regMsg.eMsgID = common_pb2.MSG_SERVICE_REGISTER
        regMsg.stServiceRegister.eRegisterType = registerType
        regMsg.stServiceRegister.eServiceType = common_pb2.PB_SERVICE_TYPE_AI
        return bool(self.__connect.SendMsg(regMsg) == 0)
class GameEnv(object):
    """
    Abstract game-environment interface.

    Subclasses supply the game-specific behaviour; this base class connects
    the bus and offers helpers to send actions and agent state.
    """
    __metaclass__ = ABCMeta

    def __init__(self):
        self.logger = logging.getLogger('agent')
        self.__connect = BusConnect()
        connected = self.__connect.Connect()
        if connected is True:
            return
        self.logger.error('Game env connect failed.')
        raise Exception('Game env connect failed.')

    def SendAction(self, actionMsg):
        """
        Forward an already-built action message over the bus; returns the
        SendMsg result.
        """
        connection = self.__connect
        return connection.SendMsg(actionMsg)

    def UpdateEnvState(self, stateCode, stateDescription):
        """
        Build a MSG_AGENT_STATE message from the code/description and send it;
        returns the SendMsg result.
        """
        stateMsg = common_pb2.tagMessage()
        stateMsg.eMsgID = common_pb2.MSG_AGENT_STATE
        stateMsg.stAgentState.eAgentState = int(stateCode)
        stateMsg.stAgentState.strAgentState = stateDescription
        sendResult = self.__connect.SendMsg(stateMsg)
        return sendResult

    @abstractmethod
    def Init(self):
        """Must be overridden: initialise the game env object."""
        raise NotImplementedError()

    @abstractmethod
    def Finish(self):
        """Must be overridden: release the game env object."""
        raise NotImplementedError()

    @abstractmethod
    def GetActionSpace(self):
        """Must be overridden: return the number of game actions."""
        raise NotImplementedError()

    @abstractmethod
    def DoAction(self, action, *args, **kwargs):
        """Must be overridden: perform ``action`` in the game env."""
        raise NotImplementedError()

    @abstractmethod
    def StopAction(self):
        """Hook: stop emitting game actions on a special msg or signal (default: no-op)."""

    @abstractmethod
    def RestartAction(self):
        """Hook: resume emitting game actions on a special msg or signal (default: no-op)."""

    @abstractmethod
    def GetState(self):
        """Must be overridden: return the game state (image or game data)."""
        raise NotImplementedError()

    @abstractmethod
    def Reset(self):
        """Hook: reset data or state in the game env (default: no-op)."""

    @abstractmethod
    def IsEpisodeStart(self):
        """Report whether an episode has started (default: False)."""
        return False

    @abstractmethod
    def IsEpisodeOver(self):
        """Report whether the episode is over (default: True)."""
        return True