def Start(
    self,
    cnf="BestLog/BestLog180126082522.cnf",
    wgt="BestLog/BestLog180126082522.wgt",
    taskName="TaskEval/EvalTask114.task",
    timeLimit=0.9,
):
    """Search an action sequence for one task and replay it in the renderer.

    Loads a network from ``cnf``/``wgt``, lets an agent configured with
    ``self.Config.ViewerAgent`` search an action sequence for the task,
    then repeatedly resets the environment to the task's start state and
    steps through that sequence, rendering after every step.

    Args:
        cnf: First path handed to ``Network.Load``.
        wgt: Second path handed to ``Network.Load``.
        taskName: Task file handed to ``MujocoTask``.
        timeLimit: Value assigned to ``net.TimeLimit`` (presumably the
            search time budget; units are not visible here -- confirm).

    The defaults reproduce the previously hard-coded values, so calling
    ``Start()`` with no arguments behaves as before.

    This method never returns: the replay loop has no exit condition.
    """
    net = Network()
    net.Load(cnf, wgt)
    net.TimeLimit = timeLimit

    model = Model()
    # Alternative task source left disabled by the original author:
    #   task = MujocoTask.LoadRandom(model, self.Config.Task.EvalDir)
    task = MujocoTask(model, taskName)
    env = MujocoEnv(model)

    agent = Agent(self.Config.ViewerAgent, net, model, task)
    bestAction = agent.SearchBestAction()

    while True:
        # Rewind to the task's start state before each replay.
        env.SetSimState(task.StartState)

        for action in bestAction:
            env.Step(action)
            env.Render()
def CalcScore(self, net, filePath):
    """Return the score of the action sequence ``net`` finds for a task.

    A fresh humanoid model, the task loaded from ``filePath`` and an
    environment are created; an agent configured with
    ``self.Config.CheckerAgent`` searches an action sequence, and the
    value of ``self.GetScore`` for that sequence is returned.

    Args:
        net: Network handed to the agent.
        filePath: Task file handed to ``MujocoTask``.

    Returns:
        Whatever ``self.GetScore`` returns for the searched actions.
    """
    model = MujocoModelHumanoid()
    task = MujocoTask(model, filePath)
    env = MujocoEnv(model)

    checker = Agent(self.Config.CheckerAgent, net, model, task)
    actions = checker.SearchBestAction()

    return self.GetScore(env, task, actions)
def CalcScores(self, best, next, filePath):
    """Score two networks on the same task file.

    Each network gets its own humanoid model, task and environment, all
    built from ``filePath``. Both agents use ``self.Config.EvaluateAgent``.
    Both searches run before either result is scored.

    Args:
        best: Network evaluated first.
        next: Network evaluated second. The name shadows the builtin but
            is part of the method's interface, so it is kept.
        filePath: Task file handed to ``MujocoTask``.

    Returns:
        A ``(bestScore, nextScore)`` tuple of ``self.GetScore`` results.
    """
    def buildSetup():
        # Independent model/task/environment triple for one network.
        model = MujocoModelHumanoid()
        return model, MujocoTask(model, filePath), MujocoEnv(model)

    incumbentModel, incumbentTask, incumbentEnv = buildSetup()
    challengerModel, challengerTask, challengerEnv = buildSetup()

    incumbentAgent = Agent(
        self.Config.EvaluateAgent, best, incumbentModel, incumbentTask
    )
    challengerAgent = Agent(
        self.Config.EvaluateAgent, next, challengerModel, challengerTask
    )

    incumbentActions = incumbentAgent.SearchBestAction()
    challengerActions = challengerAgent.SearchBestAction()

    # Disabled by the original author:
    #   challengerAgent.SaveTrainData(self.Config.GetTrainPath("next"))

    return (
        self.GetScore(incumbentEnv, incumbentTask, incumbentActions),
        self.GetScore(challengerEnv, challengerTask, challengerActions),
    )
def Start(self):
    """Run one search with the next-generation network and save its data.

    Loads the network from the paths in
    ``self.Config.FilePath.NextGeneration``, builds a task from the file
    returned by ``self.GetRandomFile()``, searches an action sequence
    with an agent configured by ``self.Config.SelfPlayAgent``, prints
    that sequence and saves the agent's training data to
    ``self.Config.GetTrainPath()``.
    """
    paths = self.Config.FilePath.NextGeneration

    net = Network()
    net.Load(paths.Config, paths.Weight)

    model = Model()
    task = MujocoTask(model, self.GetRandomFile())
    # Not referenced below; kept because constructing the environment
    # may have side effects -- TODO confirm whether it can be dropped.
    env = MujocoEnv(model)

    selfPlayAgent = Agent(self.Config.SelfPlayAgent, net, model, task)
    actions = selfPlayAgent.SearchBestAction()
    print(actions)

    selfPlayAgent.SaveTrainData(self.Config.GetTrainPath())