Exemplo n.º 1
0
    def __init__(self):
        """Create the action manager and cache the imitation-learning config."""
        self.logger = logging.getLogger('agent')
        self.__actionMgr = ActionAPIMgr()

        # Initialise the learner once and keep shortcuts to the configuration
        # entries that are read when actions are executed.
        learner = self.mainImitationLearning = MainImitationLearning()
        learner.Init()
        config = self.cfgData = learner.cfgData
        self.actionsContextDict = config['actionsContextDict']
        self.actionName = learner.actionName
        self.actionDefine = config['actionDefine']
        self.taskActionDict = learner.taskActionDict
        # One "previous action" slot per entry of taskActionDict.
        self.preAction = [None] * len(self.taskActionDict)

        self.resetWaitTime = 5
        self.timeInit = -1

        # -1 marks the timestamp / joystick geometry as not yet set.
        self.timeNow = -1
        self.centerx = -1
        self.centery = -1
        self.radius = -1
        self.contactJoyStick = -1

        # At most 3 touch contacts: [0, 1, 2]; none is pressed initially.
        self.downStateDict = {contact: False for contact in range(3)}
Exemplo n.º 2
0
    def __init__(self):
        """Set up the action controller, agent API and learner settings.

        The terminal flag and the frame size are given defaults here so that
        methods reading them before the first frame has been received do not
        raise ``AttributeError``.
        """
        GameEnv.__init__(self)
        self.actionCtrl = ImitationAction()
        self.__frameIndex = -1

        self.__agentAPI = AgentAPIMgr.AgentAPIMgr()

        self.mainImitationLearning = MainImitationLearning()
        self.mainImitationLearning.Init()

        # Size of the image handed to the network.
        self.__inputImgWidth = self.mainImitationLearning.inputWidth
        self.__inputImgHeight = self.mainImitationLearning.inputHeight

        self.__timeMs = self.mainImitationLearning.actionTimeMs

        self.__gameState = GAME_STATE_FINISH
        self.__isTerminal = False

        # Size of the raw game frame; 0 until a frame has been seen.
        self.__imgHeight = 0
        self.__imgWidth = 0
Exemplo n.º 3
0
    def __init__(self):
        """Initialise the base environment, action controller and agent API."""
        GameEnv.__init__(self)
        self.__actionCtrl = ImitationAction()
        # Task id lists start empty.
        self.__beginTaskID, self.__overTaskID = [], []

        self.__frameIndex = -1
        self.__agentAPI = AgentAPIMgr.AgentAPIMgr()

        # Initialise the learner and copy the settings this object needs.
        learner = self.mainImitationLearning = MainImitationLearning()
        learner.Init()

        self.__inputImgWidth = learner.inputWidth
        self.__inputImgHeight = learner.inputHeight

        self.__timeMs = learner.actionTimeMs

        self.__gameState = GAME_STATE_OVER
        self.__isTerminal = False

        # Frame size defaults to 0.
        self.__imgHeight = self.__imgWidth = 0
Exemplo n.º 4
0
    def __init__(self):
        """Create the action manager and cache the imitation-learning config."""
        self.logger = logging.getLogger('agent')
        self.__actionMgr = ActionAPIMgr()

        # Initialise the learner once and keep shortcuts to the configuration
        # entries that are read when actions are executed.
        learner = self.mainImitationLearning = MainImitationLearning()
        learner.Init()
        config = self.cfgData = learner.cfgData
        self.actionsContextList = config['actionsContextList']
        self.actionName = learner.actionName
        self.actionDefine = config['actionDefine']
        self.taskActionDict = learner.taskActionDict

        # One "previous action" slot per entry of taskActionDict.
        self.preAction = [None] * len(self.taskActionDict)
        self.resetWaitTime = 5
        self.timeInit = -1

        # -1 marks the timestamp / joystick geometry as not yet set.
        self.timeNow = -1
        self.centerx = -1
        self.centery = -1
        self.radius = -1
        self.contactJoyStick = -1
Exemplo n.º 5
0
class ImitationEnv(GameEnv):
    """
    Game environment for imitation learning.

    Reads frames and recognition results from the agent API, derives the game
    state from the configured begin/over tasks and forwards actions to an
    ImitationAction controller.
    """
    def __init__(self):
        GameEnv.__init__(self)
        self.__actionCtrl = ImitationAction()
        # Filled from the env config file by _LoadGameState().
        self.__beginTaskID, self.__overTaskID = [], []

        self.__frameIndex = -1
        self.__agentAPI = AgentAPIMgr.AgentAPIMgr()

        learner = self.mainImitationLearning = MainImitationLearning()
        learner.Init()

        # Size of the image returned by GetState().
        self.__inputImgWidth = learner.inputWidth
        self.__inputImgHeight = learner.inputHeight

        self.__timeMs = learner.actionTimeMs

        self.__gameState = GAME_STATE_OVER
        self.__isTerminal = False

        # Size of the raw game frame, refreshed by _GetGameInfo().
        self.__imgHeight = self.__imgWidth = 0

    def Init(self):
        """
        Initialise the agent API, send the recognition group id and load the
        begin/over task ids. Returns True on success, False otherwise.
        """
        cfgPath = util.ConvertToSDKFilePath(TASK_CFG_FILE)
        referPath = util.ConvertToSDKFilePath(TASK_REFER_CFG_FILE)

        apiReady = self.__agentAPI.Initialize(cfgPath, referFile=referPath)
        if not apiReady:
            self.logger.error('Agent API Init Failed')
            return False

        groupSent = self.__agentAPI.SendCmd(AgentAPIMgr.MSG_SEND_GROUP_ID,
                                            REG_GROUP_ID)
        if not groupSent:
            self.logger.error('send message failed')
            return False

        return True if self._LoadGameState() else False

    def Finish(self):
        """
        Release the agent API and shut down the action controller
        """
        self.__agentAPI.Release()
        self.__actionCtrl.Finish()

    def GetActionSpace(self):
        """
        Default implementation: only logs that it was called
        """
        self.logger.info('execute the default get action space in the imitation env')

    def Reset(self):
        """
        Default implementation: only logs that it was called
        """
        self.logger.info('execute the default reset in the imitation env')

    def RestartAction(self):
        """
        Default implementation: only logs that it was called
        """
        self.logger.info('execute the default restart action in the imitation env')

    def StopAction(self):
        """
        Default implementation: only logs that it was called
        """
        self.logger.info('execute the default stop action in the imitation env')

    def DoAction(self, action, *args, **kwargs):
        """
        Execute the given action; extra arguments are ignored
        """
        self._OutPutAction(action)

    def _OutPutAction(self, actionIndex):
        # Hand the action to the controller together with the current frame
        # size, action duration and frame sequence number.
        self.__actionCtrl.DoAction(
            actionIndex,
            self.__imgHeight,
            self.__imgWidth,
            self.__timeMs,
            self.__frameIndex,
        )

    def GetState(self):
        """
        Fetch the next frame, resized to the network input size, together
        with the terminal flag (False only while the game is running)
        """
        info = self._GetGameInfo()
        frame = info['image']
        self.__frameIndex = info['frameSeq']
        running = self.__gameState == GAME_STATE_RUN
        resized = cv2.resize(frame,
                             (self.__inputImgWidth, self.__inputImgHeight))
        self.__isTerminal = not running

        return resized, self.__isTerminal

    def IsTrainable(self):
        """
        Always reports the model as trainable
        """
        return True

    def IsEpisodeStart(self):
        """
        Poll the game once; True (and terminal flag cleared) if it is running
        """
        self._GetGameInfo()
        if self.__gameState != GAME_STATE_RUN:
            return False

        self.__isTerminal = False
        return True

    def IsEpisodeOver(self):
        """
        Return the terminal flag computed by the last state query
        """
        return self.__isTerminal

    def OnEpisodeStart(self):
        """
        Initialise the action controller with the current frame size
        """
        height, width = self.__imgHeight, self.__imgWidth
        self.__actionCtrl.Initialize(height, width)
        self.logger.info('init:  height: {}  width: {}'.format(height, width))

    def OnEpisodeOver(self):
        """
        Nothing to do when an episode ends
        """
        pass

    def _GetBtnPostion(self, resultDict, taskID):
        # Returns (found, centerX, centerY) for the first result of taskID
        # whose ROI has x > 0 and y > 0; (False, -1, -1) when there is none.
        notFound = (False, -1, -1)

        regResults = resultDict.get(taskID)
        if regResults is None:
            return notFound

        for result in regResults:
            roi = result['ROI']
            x, y, w, h = roi['x'], roi['y'], roi['w'], roi['h']
            if x > 0 and y > 0:
                return (True, int(x + w / 2), int(y + h / 2))

        return notFound

    def _GetGameInfo(self):
        # Poll the agent API until a message carrying a recognition result
        # arrives, then refresh frame size, game state and button positions.
        while True:
            gameInfo = self.__agentAPI.GetInfo(AgentAPIMgr.GAME_RESULT_INFO)
            result = None if gameInfo is None else gameInfo['result']
            if result is None:
                time.sleep(0.002)
                continue

            frame = gameInfo['image']
            self.__imgHeight = frame.shape[0]
            self.__imgWidth = frame.shape[1]

            self.logger.debug(
                "the result of game reg is %s, beginTask: %s, endTask: %s",
                str(result), str(self.__beginTaskID), str(self.__overTaskID))

            self._ParseGameState(result)
            self._ParseBtnPostion(result)
            self._ParseSceneInfo(result)

            return gameInfo

    def _ParseGameState(self, resultDict):
        # Begin tasks are checked first, so a matching over task wins when
        # both kinds are detected in the same result.
        transitions = (
            (self.__beginTaskID, GAME_STATE_RUN,
             "the game state set game run state"),
            (self.__overTaskID, GAME_STATE_OVER,
             "the game state set game over state"),
        )
        for taskIDs, newState, message in transitions:
            for taskID in taskIDs:
                flag, _, _ = util.get_button_state(resultDict, taskID)
                if flag is True:
                    self.__gameState = newState
                    self.logger.debug(message)

    def _ParseBtnPostion(self, resultDict):
        # For every click action bound to a scene task, store the detected
        # button centre in its context and mark that task as disabled; all
        # other tasks present in the result are marked enabled.
        totalTask = list(resultDict.keys())
        disableTask = []

        for actionContext in self.__actionCtrl.actionsContextDict.values():
            sceneTaskID = actionContext.get('sceneTask')
            if sceneTaskID is None:
                continue
            if actionContext['type'] != 'click':
                continue

            flag, btnX, btnY = self._GetBtnPostion(resultDict, sceneTaskID)
            if flag is not True:
                continue

            actionContext['updateBtn'] = True
            actionContext['updateBtnX'] = btnX
            actionContext['updateBtnY'] = btnY
            disableTask.append(sceneTaskID)

        enableTask = [taskID for taskID in totalTask
                      if taskID not in disableTask]
        self.logger.debug("the enable_task is %s and disable_task is %s",
                          str(enableTask), str(disableTask))
        self.SendUpdateTask(disableTask, enableTask)

    def _ParseSceneInfo(self, resultDict):
        # Hook with no behaviour in this class.
        pass

    def _LoadGameState(self):
        # Read beginTaskID / overTaskID from the env config file. Any error
        # is logged and reported as False.
        imEnvFile = util.ConvertToSDKFilePath(IM_ENV_CFG_FILE)
        try:
            with open(imEnvFile, 'r', encoding='utf-8') as cfgFile:
                gameStateCfg = json.loads(cfgFile.read())
                self.logger.info("the config of env is {}".format(gameStateCfg))
                self.__beginTaskID.extend(gameStateCfg['beginTaskID'])
                self.__overTaskID.extend(gameStateCfg['overTaskID'])
        except Exception as err:
            self.logger.error('Load game state file %s error! Error msg: %s',
                              imEnvFile, str(err))
            return False

        return True

    def SendUpdateTask(self, disableTask, enableTask):
        """
        Send the enable/disable flag of each task to the agent API.
        Returns True when the command was sent, False otherwise.
        """
        taskFlagDict = {taskID: False for taskID in disableTask}
        # Enabled ids are written last, so they win if an id is in both lists.
        taskFlagDict.update((taskID, True) for taskID in enableTask)

        if self.__agentAPI.SendCmd(AgentAPIMgr.MSG_SEND_TASK_FLAG,
                                   taskFlagDict):
            return True

        self.logger.error('AgentAPI MSG_SEND_TASK_FLAG failed')
        return False
Exemplo n.º 6
0
class ImitationEnv(GameEnv):
    """
    Game environment for imitation learning.

    Reads frames from the agent API and forwards actions to an
    ImitationAction controller.
    """

    def __init__(self):
        GameEnv.__init__(self)
        self.actionCtrl = ImitationAction()
        self.__frameIndex = -1

        self.__agentAPI = AgentAPIMgr.AgentAPIMgr()

        self.mainImitationLearning = MainImitationLearning()
        self.mainImitationLearning.Init()

        # Size of the image returned by GetState().
        self.__inputImgWidth = self.mainImitationLearning.inputWidth
        self.__inputImgHeight = self.mainImitationLearning.inputHeight

        self.__timeMs = self.mainImitationLearning.actionTimeMs

        self.__gameState = GAME_STATE_FINISH

        # These are read by IsEpisodeOver(), OnEpisodeStart() and
        # _OutPutAction(); give them defaults so calling those before the
        # first frame arrives does not raise AttributeError.
        self.__isTerminal = False
        self.__imgHeight = 0
        self.__imgWidth = 0

    def Init(self):
        """
        Initialise the agent API and send the recognition group id.
        Returns True on success, False otherwise.
        """
        taskCfgFile = util.ConvertToSDKFilePath(TASK_CFG_FILE)
        taskReferCfgFile = util.ConvertToSDKFilePath(TASK_REFER_CFG_FILE)
        ret = self.__agentAPI.Initialize(taskCfgFile, referFile=taskReferCfgFile)
        if not ret:
            self.logger.error('Agent API Init Failed')
            return False

        ret = self.__agentAPI.SendCmd(AgentAPIMgr.MSG_SEND_GROUP_ID, REG_GROUP_ID)
        if not ret:
            self.logger.error('send message failed')
            return False

        return True

    def Finish(self):
        """
        Release the agent API and shut down the action controller
        """
        self.__agentAPI.Release()
        self.actionCtrl.Finish()

    def GetActionSpace(self):
        """
        Not implemented for this env; returns None
        """
        pass

    def Reset(self):
        """
        Nothing to reset in this env
        """
        pass

    def RestartAction(self):
        """
        No restart action in this env
        """
        pass

    def StopAction(self):
        """
        No stop action in this env
        """
        pass

    def DoAction(self, action, *args, **kwargs):
        """
        Execute the given action; extra arguments are ignored
        """
        self._OutPutAction(action)

    def _OutPutAction(self, actionIndex):
        # Hand the action to the controller together with the current frame
        # size, action duration and frame sequence number.
        self.actionCtrl.DoAction(actionIndex,
                                 self.__imgHeight,
                                 self.__imgWidth,
                                 self.__timeMs,
                                 self.__frameIndex)

    def GetState(self):
        """
        Fetch the next frame, resized to the network input size, together
        with the terminal flag (False only while the game is running)
        """
        gameInfo = self._GetGameInfo()
        image = gameInfo['image']
        self.__frameIndex = gameInfo['frameSeq']
        img = cv2.resize(image, (self.__inputImgWidth, self.__inputImgHeight))
        self.__isTerminal = self.__gameState != GAME_STATE_RUN

        return img, self.__isTerminal

    def IsTrainable(self):
        """
        Always reports the model as trainable
        """
        return True

    def IsEpisodeStart(self):
        """
        Poll the game once; True (and terminal flag cleared) if it is running
        """
        self._GetGameInfo()
        if self.__gameState == GAME_STATE_RUN:
            self.__isTerminal = False
            return True

        return False

    def IsEpisodeOver(self):
        """
        Return the terminal flag computed by the last state query
        """
        return self.__isTerminal

    def OnEpisodeStart(self):
        """
        Initialise the action controller with the current frame size
        """
        self.actionCtrl.Initialize(self.__imgHeight, self.__imgWidth)
        self.logger.info('init:  height: {}  width: {}'.format(self.__imgHeight, self.__imgWidth))

    def OnEpisodeOver(self):
        """
        Nothing to do when an episode ends
        """
        pass

    def _GetBtnState(self, resultDict, taskID):
        # Returns (found, centerX, centerY) for the first result of taskID
        # whose flag is True; (False, -1, -1) when there is none.
        # Each result item is indexed as (flag, x, y, w, h).
        state = False
        px = -1
        py = -1

        regResults = resultDict.get(taskID)
        if regResults is None:
            return (state, px, py)

        for item in regResults:
            flag = item[0]
            x = item[1]
            y = item[2]
            w = item[3]
            h = item[4]

            if flag is True:
                state = True
                px = int(x + w/2)
                py = int(y + h/2)
                break

        return (state, px, py)

    def _GetGameInfo(self):
        # Poll the agent API until a message carrying a recognition result
        # arrives, then record the frame size and mark the game as running.
        gameInfo = None

        while True:
            gameInfo = self.__agentAPI.GetInfo(AgentAPIMgr.GAME_RESULT_INFO)
            if gameInfo is None:
                time.sleep(0.002)
                continue

            result = gameInfo['result']
            if result is None:
                time.sleep(0.002)
                continue

            image = gameInfo['image']
            self.__imgHeight = image.shape[0]
            self.__imgWidth = image.shape[1]

            self.__gameState = GAME_STATE_RUN
            break

        return gameInfo
Exemplo n.º 7
0
 def __init__(self):
     """Create the 'agent' logger, the bus connection and the learner object."""
     self.__logger = logging.getLogger('agent')
     self.__connect = BusConnect()
     self.__imTrain = MainImitationLearning()
Exemplo n.º 8
0
class IMTrainFrameWork(object):
    """
    Imitation learning train framework
    """

    def __init__(self):
        self.__logger = logging.getLogger('agent')
        self.__connect = BusConnect()
        self.__imTrain = MainImitationLearning()

    def Init(self):
        """
        Connect to the bus, register the service and send the resource info.
        Returns True when every step succeeds, False otherwise.
        """
        if self.__connect.Connect() is not True:
            self.__logger.error('Agent connect failed.')
            return False

        # After the connection is up, every failing step is reported to the
        # manager center as an init failure.
        failureMsg = None
        if self._RegisterService() is not True:
            failureMsg = 'Agent register service failed.'
        elif self._send_resource_info() is not True:
            failureMsg = 'send the source info failed.'

        if failureMsg is None:
            return True

        self.__logger.error(failureMsg)
        self._SendTaskReport(common_pb2.PB_TASK_INIT_FAILURE)
        return False

    def Train(self):
        """
        Generate the image samples, then train the network
        """
        trainer = self.__imTrain
        trainer.GenerateImageSamples()
        trainer.TrainNetwork()

    def Finish(self):
        """
        Unregister the service, then close the bus connection
        """
        self._UnRegisterService()
        self.__connect.Close()

    def _SendTaskReport(self, reportCode):
        # Build a task-report message carrying reportCode and send it to the
        # manager center node. True when SendMsg returns 0.
        taskMsg = common_pb2.tagMessage()
        taskMsg.eMsgID = common_pb2.MSG_TASK_REPORT
        taskMsg.stTaskReport.eTaskStatus = reportCode

        sendStatus = self.__connect.SendMsg(taskMsg, BusConnect.PEER_NODE_MC)
        return True if sendStatus == 0 else False

    def _RegisterService(self):
        # Send a "register AI service" message. True when SendMsg returns 0.
        regMsg = common_pb2.tagMessage()
        regMsg.eMsgID = common_pb2.MSG_SERVICE_REGISTER
        regMsg.stServiceRegister.eRegisterType = common_pb2.PB_SERVICE_REGISTER
        regMsg.stServiceRegister.eServiceType = common_pb2.PB_SERVICE_TYPE_AI

        sendStatus = self.__connect.SendMsg(regMsg, BusConnect.PEER_NODE_MC)
        return True if sendStatus == 0 else False

    def _UnRegisterService(self):
        # Send an "unregister AI service" message. True when SendMsg returns 0.
        unRegMsg = common_pb2.tagMessage()
        unRegMsg.eMsgID = common_pb2.MSG_SERVICE_REGISTER
        unRegMsg.stServiceRegister.eRegisterType = common_pb2.PB_SERVICE_UNREGISTER
        unRegMsg.stServiceRegister.eServiceType = common_pb2.PB_SERVICE_TYPE_AI

        sendStatus = self.__connect.SendMsg(unRegMsg, BusConnect.PEER_NODE_MC)
        return True if sendStatus == 0 else False

    def _send_resource_info(self):
        # Read the project config named by AI_SDK_PROJECT_FILE_PATH and send
        # its 'source' entry to the manager center node.
        # Raises Exception when the environment variable is unset or empty.
        project_config_path = os.environ.get('AI_SDK_PROJECT_FILE_PATH')
        self.__logger.info('send source info to mc, project_path: %s', project_config_path)
        if not project_config_path:
            raise Exception('environ var(AI_SDK_PROJECT_FILE_PATH) is invalid')

        content = get_configure(project_config_path)
        source = content['source']
        if source is None:
            self.__logger.info("invalid the source in the project config, content: %s", content)
            return False

        self.__logger.info("the project config is %s, project_config_path: %s", str(content), project_config_path)
        source_res_message = create_source_response(source)

        if self.__connect.SendMsg(source_res_message, BusConnect.PEER_NODE_MC) != 0:
            self.__logger.warning("send the source info to mc service failed, please check")
            return False

        self.__logger.info("send the source info to mc service success")
        return True
Exemplo n.º 9
0
class ImitationAction(object):
    """
    Action class for imitation learning: turns the action ids chosen by the
    model into touch commands sent through the action manager.
    """

    def __init__(self):
        self.logger = logging.getLogger('agent')
        self.__actionMgr = ActionAPIMgr()

        self.mainImitationLearning = MainImitationLearning()
        self.mainImitationLearning.Init()
        self.cfgData = self.mainImitationLearning.cfgData
        self.actionsContextDict = self.mainImitationLearning.cfgData['actionsContextDict']
        self.actionName = self.mainImitationLearning.actionName
        self.actionDefine = self.mainImitationLearning.cfgData['actionDefine']
        self.taskActionDict = self.mainImitationLearning.taskActionDict
        # Last action id executed for each entry of taskActionDict.
        self.preAction = [None] * len(self.taskActionDict)

        # Seconds after which the joystick is re-initialised.
        self.resetWaitTime = 5
        self.timeInit = -1

        self.timeNow = -1
        self.centerx = -1
        self.centery = -1
        self.radius = -1
        self.contactJoyStick = -1

        # At most 3 touch contacts: [0, 1, 2]; True while a contact is down.
        self.downStateDict = {contact: False for contact in range(3)}

    def Initialize(self, height, width):
        """
        Log the frame resolution and initialise the action manager.
        Returns whatever the action manager's Initialize() returns.
        """
        self.logger.info('the resolution of action, height:%d, width:%d',  height, width)
        return self.__actionMgr.Initialize()

    def Finish(self):
        """
        Shut down the action manager
        """
        self.__actionMgr.Finish()

    def ActionInit(self):
        """
        Initialise joystick movement with the stored centre, radius and contact
        """
        self.__actionMgr.MovingInit(self.centerx, self.centery,
                                    self.radius, contact=self.contactJoyStick,
                                    frameSeq=-1, waitTime=100)

    def ActionFinish(self, frameIndex):
        """
        Finish joystick movement
        """
        self.__actionMgr.MovingFinish(frameSeq=frameIndex)

    def ActionResetContact(self, frameIndex):
        """
        Lift every pressed non-joystick contact and send Moving(-1) for the
        joystick contact
        """
        for n in range(3):
            if n != self.contactJoyStick:
                if self.downStateDict[n]:
                    self.__actionMgr.Up(contact=n, frameSeq=frameIndex)
                    self.downStateDict[n] = False
            else:
                self.__actionMgr.Moving(-1, frameSeq=frameIndex)

    def DoAction(self, actionIdListInput, imgHeight, imgWidth, timeMs, frameIndex):
        """
        Execute the chosen action(s) on the current frame.

        actionIdListInput is a single action id when taskActionDict has one
        entry, otherwise a sequence with one action id per entry. imgHeight
        and imgWidth are the frame size used to scale configured coordinates.
        timeMs is accepted but not used here. frameIndex is passed to the
        action manager as frameSeq.
        """
        ratioX = imgWidth * 1. / self.cfgData['inputWidth']
        ratioY = imgHeight * 1. / self.cfgData['inputHeight']
        if len(self.taskActionDict) == 1:
            actionIdList = [actionIdListInput]
        else:
            actionIdList = actionIdListInput

        # Pressed contacts are lifted here, once per call. The former
        # per-action release inside the loop could never run, because
        # preAction[ind] was overwritten right before it was compared.
        self.ActionResetContact(frameIndex)

        for ind, actionId in enumerate(actionIdList):
            if self.actionDefine is None:
                self.logger.error('Should define actionDefine in imitationLearning.json')
                continue

            actionIdOriList = self.taskActionDict[ind][actionId]["actionIDGroup"]
            self.preAction[ind] = actionId

            for actionIdOri in actionIdOriList:
                actionType = self.actionsContextDict[actionIdOri]['type']
                if actionType == 'none':
                    continue
                self.DoSpecificAction(actionIdOri, actionType, ratioX, ratioY, frameIndex)

    def DoSpecificAction(self, actionId, actionType, ratioX, ratioY, frameIndex):
        """
        Do specific action
        actionType == none: no action
        actionType == click: press at the button position
        actionType == key: simulated key action
        actionType == swipe: swipe
        actionType == joystick: use joystick

        ratioX / ratioY scale the configured coordinates to the frame size.
        """
        if actionType == 'none':
            return

        context = self.actionsContextDict[actionId]

        if actionType == 'click':
            contact = context['contact']
            # 'updateBtn' is optional: without it, fall back to the
            # configured button position instead of raising KeyError.
            if context.get('updateBtn') is True:
                self.downStateDict[contact] = True
                self.__actionMgr.Down(context['updateBtnX'],
                                      context['updateBtnY'],
                                      contact=contact,
                                      frameSeq=frameIndex)
                self.logger.info('Use the updated button position based on task for actionId %d', actionId)
            else:
                actionX = int(context['buttonX'] * ratioX)
                actionY = int(context['buttonY'] * ratioY)
                self.downStateDict[contact] = True
                self.logger.info('execute the click action for actionId %d', actionId)
                self.__actionMgr.Down(actionX,
                                      actionY,
                                      contact=contact,
                                      frameSeq=frameIndex)

        if actionType == 'key':
            actionX = int(context['buttonX'] * ratioX)
            actionY = int(context['buttonY'] * ratioY)
            alphabet = context['alphabet']
            action_type = context['action_type']
            action_text = context['action_text']
            contact = context['contact']

            self.logger.info('execute the key action for actionId %d', actionId)
            self.logger.info('key action, actionId:%d, actionX:%d, actionY:%d, contact:%d',
                             actionId, actionX, actionY, contact)
            self.logger.info('key action, actionId: %d, alphabet: %s, type: %s, text: %s',
                             actionId, alphabet, str(action_type), action_text)

            self.__actionMgr.SimulatorKeyAction(actionX, actionY, contact=contact, frameSeq=frameIndex,
                                                alphabet=alphabet, action_type=action_type, action_text=action_text)

        if actionType == 'swipe':
            swipeStartX = int(context['swipeStartX'] * ratioX)
            swipeStartY = int(context['swipeStartY'] * ratioY)
            swipeEndX = int(context['swipeEndX'] * ratioX)
            swipeEndY = int(context['swipeEndY'] * ratioY)

            self.__actionMgr.Swipe(swipeStartX, swipeStartY, swipeEndX, swipeEndY,
                                   contact=context['contact'],
                                   frameSeq=frameIndex, durationMS=80, needUp=False)

        if actionType == 'joystick':
            self.timeNow = time.time()
            # Re-initialise the joystick when more than resetWaitTime seconds
            # have passed since the last initialisation.
            if self.timeNow - self.timeInit > self.resetWaitTime:
                self.centerx = int(context['centerx'] * ratioX)
                self.centery = int(context['centery'] * ratioY)
                # Radius is the mean of the inner and outer range, scaled by X.
                self.radius = int(0.5 * (context['rangeInner'] +
                                         context['rangeOuter']) * ratioX)

                self.contactJoyStick = context['contact']

                self.ActionInit()
                self.timeInit = self.timeNow

            self.__actionMgr.Moving(context['angle'], frameSeq=frameIndex)
Exemplo n.º 10
0
class ImitationAction(object):
    """
    Action class for imitation learning: turns the action ids chosen by the
    model into touch commands sent through the action manager.
    """

    def __init__(self):
        self.logger = logging.getLogger('agent')
        self.__actionMgr = ActionAPIMgr()

        self.mainImitationLearning = MainImitationLearning()
        self.mainImitationLearning.Init()
        self.cfgData = self.mainImitationLearning.cfgData
        self.actionsContextList = self.mainImitationLearning.cfgData['actionsContextList']
        self.actionName = self.mainImitationLearning.actionName
        self.actionDefine = self.mainImitationLearning.cfgData['actionDefine']
        self.taskActionDict = self.mainImitationLearning.taskActionDict

        # Last action id executed for each entry of taskActionDict.
        self.preAction = [None] * len(self.taskActionDict)
        # Seconds after which the joystick is re-initialised.
        self.resetWaitTime = 5
        self.timeInit = -1

        self.timeNow = -1
        self.centerx = -1
        self.centery = -1
        self.radius = -1
        self.contactJoyStick = -1

    def Initialize(self, height, width):
        """
        Initialise the action manager; height and width are not used.
        Returns whatever the action manager's Initialize() returns.
        """
        return self.__actionMgr.Initialize()

    def Finish(self):
        """
        Shut down the action manager
        """
        self.__actionMgr.Finish()

    def ActionInit(self):
        """
        Initialise joystick movement with the stored centre, radius and contact
        """
        self.__actionMgr.MovingInit(self.centerx, self.centery,
                                    self.radius, contact=self.contactJoyStick,
                                    frameSeq=-1, waitTime=100)

    def ActionFinish(self, frameIndex):
        """
        Finish joystick movement
        """
        self.__actionMgr.MovingFinish(frameSeq=frameIndex)

    def ActionResetContact(self, frameIndex):
        """
        Send Up for contacts 0-2 except the joystick contact, which gets
        Moving(-1) instead
        """
        for n in range(3):
            if n != self.contactJoyStick:
                self.__actionMgr.Up(contact=n, frameSeq=frameIndex)
            else:
                self.__actionMgr.Moving(-1, frameSeq=frameIndex)

    def DoAction(self, actionIdListInput, imgHeight, imgWidth, timeMs, frameIndex):
        """
        Execute the chosen action(s) on the current frame.

        actionIdListInput is a single action id when taskActionDict has one
        entry, otherwise a sequence with one action id per entry. imgHeight
        and imgWidth are the frame size used to scale configured coordinates.
        timeMs is accepted but not used here. frameIndex is passed to the
        action manager as frameSeq.
        """
        ratioX = imgWidth * 1. / self.cfgData['inputWidth']
        ratioY = imgHeight * 1. / self.cfgData['inputHeight']
        if len(self.taskActionDict) == 1:
            actionIdList = [actionIdListInput]
        else:
            actionIdList = actionIdListInput

        # Contacts are released here, once per call. The former per-action
        # release inside the loop could never run, because preAction[ind]
        # was overwritten right before it was compared.
        self.ActionResetContact(frameIndex)

        for ind, actionId in enumerate(actionIdList):
            if self.actionDefine is None:
                self.logger.error('Should define actionDefine in imitationLearning.json')
                continue

            actionIdOriList = self.taskActionDict[ind][actionId]["actionIDGroup"]
            self.preAction[ind] = actionId

            for actionIdOri in actionIdOriList:
                actionType = self.actionsContextList[actionIdOri]['type']
                if actionType == 0:
                    continue
                self.DoSpecificAction(actionIdOri, actionType, ratioX, ratioY, frameIndex)

    def DoSpecificAction(self, actionId, actionType, ratioX, ratioY, frameIndex):
        """
        Do specific action
        actionType == 0: no action
        actionType == 3: press at the centre of the configured region
        actionType == 4: swipe
        actionType == 5: use joystick

        ratioX / ratioY scale the configured coordinates to the frame size.
        """
        if actionType == 0:
            return

        context = self.actionsContextList[actionId]

        if actionType == 3:
            # Press at the centre of the region, scaled to the frame size.
            actionX = int((context['regionX1'] + context['regionX2']) / 2 * ratioX)
            actionY = int((context['regionY1'] + context['regionY2']) / 2 * ratioY)
            self.__actionMgr.Down(actionX, actionY,
                                  contact=context['contact'],
                                  frameSeq=frameIndex)

        if actionType == 4:
            swipeStartX = int(context['swipeStartX'] * ratioX)
            swipeStartY = int(context['swipeStartY'] * ratioY)
            swipeEndX = int(context['swipeEndX'] * ratioX)
            swipeEndY = int(context['swipeEndY'] * ratioY)

            self.__actionMgr.Swipe(swipeStartX, swipeStartY, swipeEndX, swipeEndY,
                                   contact=context['contact'],
                                   frameSeq=frameIndex, durationMS=80, needUp=False)

        if actionType == 5:
            self.timeNow = time.time()
            # Re-initialise the joystick when more than resetWaitTime seconds
            # have passed since the last initialisation.
            if self.timeNow - self.timeInit > self.resetWaitTime:
                self.centerx = int(context['centerx'] * ratioX)
                self.centery = int(context['centery'] * ratioY)
                # Radius is the mean of the inner and outer range, scaled by X.
                self.radius = int(0.5 * (context['rangeInner'] +
                                         context['rangeOuter']) * ratioX)

                self.contactJoyStick = context['contact']

                self.ActionInit()
                self.timeInit = self.timeNow

            self.__actionMgr.Moving(context['angle'], frameSeq=frameIndex)
Exemplo n.º 11
0
class IMTrainFrameWork(object):
    """
    Imitation learning train framework
    """

    def __init__(self):
        self.__logger = logging.getLogger('agent')
        self.__connect = BusConnect()
        self.__imTrain = MainImitationLearning()

    def Init(self):
        """
        Connect to the bus and register the service.
        Returns True when both steps succeed, False otherwise.
        """
        if self.__connect.Connect() is not True:
            self.__logger.error('Agent connect failed.')
            return False

        registered = self._RegisterService()
        if registered is True:
            return True

        # Registration failed: log it and report the init failure.
        self.__logger.error('Agent register service failed.')
        self._SendTaskReport(common_pb2.PB_TASK_INIT_FAILURE)
        return False

    def Train(self):
        """
        Generate the image samples, then train the network
        """
        trainer = self.__imTrain
        trainer.GenerateImageSamples()
        trainer.TrainNetwork()

    def Finish(self):
        """
        Unregister the service, then close the bus connection
        """
        self._UnRegisterService()
        self.__connect.Close()

    def _SendTaskReport(self, reportCode):
        # Build a task-report message carrying reportCode and send it.
        # True when SendMsg returns 0.
        taskMsg = common_pb2.tagMessage()
        taskMsg.eMsgID = common_pb2.MSG_TASK_REPORT
        taskMsg.stTaskReport.eTaskStatus = reportCode

        sendStatus = self.__connect.SendMsg(taskMsg)
        return True if sendStatus == 0 else False

    def _RegisterService(self):
        # Send a "register AI service" message. True when SendMsg returns 0.
        regMsg = common_pb2.tagMessage()
        regMsg.eMsgID = common_pb2.MSG_SERVICE_REGISTER
        regMsg.stServiceRegister.eRegisterType = common_pb2.PB_SERVICE_REGISTER
        regMsg.stServiceRegister.eServiceType = common_pb2.PB_SERVICE_TYPE_AI

        sendStatus = self.__connect.SendMsg(regMsg)
        return True if sendStatus == 0 else False

    def _UnRegisterService(self):
        # Send an "unregister AI service" message. True when SendMsg returns 0.
        unRegMsg = common_pb2.tagMessage()
        unRegMsg.eMsgID = common_pb2.MSG_SERVICE_REGISTER
        unRegMsg.stServiceRegister.eRegisterType = common_pb2.PB_SERVICE_UNREGISTER
        unRegMsg.stServiceRegister.eServiceType = common_pb2.PB_SERVICE_TYPE_AI

        sendStatus = self.__connect.SendMsg(unRegMsg)
        return True if sendStatus == 0 else False