예제 #1
0
파일: garnet_test.py 프로젝트: hbhzwj/librl
    def setUp(self):
        """Build a small deterministic model on disk and load an environment from it.

        A 3-state, 2-action model with one successor per (action, state)
        pair is dumped with ``zdump`` into a fresh temporary directory, and
        a ``GarnetEnvironment`` is then constructed with ``loadPath``
        pointing at that file.
        """
        # Function-scope import keeps this fixture self-contained.
        import os

        # Fixed seed so any randomness is repeatable between runs.
        # NOTE(review): scipy.random / scipy.array are NumPy aliases that
        # newer SciPy releases deprecate -- confirm the pinned SciPy version.
        scipy.random.seed(0)
        self.testDir = tempfile.mkdtemp()

        self.numStates = 3
        self.numActions = 2
        self.branching = 1
        self.feaDim = 5
        self.feaSum = 2

        # For action 0
        # state 0 -> 1
        # state 1 -> 0
        # state 2 -> 2
        ts0 = [[1], [0], [2]]
        tp0 = [[1], [1], [1]]
        # For action 1
        # state 0 -> 2
        # state 1 -> 1
        # state 2 -> 0
        ts1 = [[2], [1], [0]]
        tp1 = [[1], [1], [1]]
        # Both arrays have shape (numActions, numStates, branching).
        self.transitionStates = scipy.array([ts0, ts1])
        self.transitionProb = scipy.array([tp0, tp1],
                                          dtype=float)
        # One observation per state: feaDim entries that sum to feaSum.
        self.stateObs = [(1, 1, 0, 0, 0),
                         (1, 0, 1, 0, 0),
                         (0, 1, 0, 1, 0)]

        message = {
            'transitionStates': self.transitionStates,
            'transitionProb': self.transitionProb,
            'stateObs': self.stateObs,
        }
        # mkdtemp() returns a path without a trailing separator, so plain
        # string concatenation would place the file *beside* the temporary
        # directory instead of inside it; join the components properly.
        self.loadPath = os.path.join(self.testDir, 'transition_model.pkz')
        zdump(message, self.loadPath)

        self.env = GarnetEnvironment(numStates=self.numStates,
                                     numActions=self.numActions,
                                     branching=self.branching,
                                     feaDim=self.feaDim,
                                     feaSum=self.feaSum,
                                     loadPath=self.loadPath)
예제 #2
0
파일: garnet.py 프로젝트: hbhzwj/librl
 def _save(self, savePath):
     """Dump the transition model and state observations to ``savePath``.

     The payload is a dict with the keys ``'transitionStates'``,
     ``'transitionProb'`` and ``'stateObs'``, taken from the attributes of
     the same names, and is written with ``zdump``.
     """
     payload = {
         'transitionStates': self.transitionStates,
         'transitionProb': self.transitionProb,
         'stateObs': self.stateObs,
     }
     zdump(payload, savePath)