コード例 #1
0
    def testRolloutWithHeuristic(self, max_rollout_step, init_state,
                                 gt_sumValue):
        """Check that the mean rollout value from a fixed leaf matches gt_sumValue.

        A RollOut is built from:
          * a policy that picks uniformly at random from ``self.action_space``;
          * ``self.transition``;
          * a RewardFunction with step_penalty=0 and catch_reward=1, sharing a
            Terminal check for target state 6;
          * a heuristic that always returns 2, regardless of state.

        The rollout is run ``max_iteration`` times on a leaf node whose id maps
        key 1 to ``init_state``. The mean of the returned values is then
        compared with ``gt_sumValue`` to one decimal place.

        Args:
            max_rollout_step: Step limit passed through to RollOut.
            init_state: State stored in the leaf node's id.
            gt_sumValue: Expected mean of the rollout values.
        """
        max_iteration = 1000

        target_state = 6
        isTerminal = Terminal(target_state)

        catch_reward = 1
        step_penalty = 0
        reward_func = RewardFunction(step_penalty, catch_reward, isTerminal)

        def rolloutHeuristic(state):
            # Constant heuristic value; the state is ignored.
            return 2

        def rollout_policy(state):
            # Uniform random choice over the action space; the state is ignored.
            return np.random.choice(self.action_space)

        leaf_node = Node(id={1: init_state},
                         numVisited=1,
                         sumValue=0,
                         actionPrior=self.default_actionPrior,
                         isExpanded=True)
        rollout = RollOut(rollout_policy, max_rollout_step, self.transition,
                          reward_func, isTerminal, rolloutHeuristic)

        # NOTE(review): the global NumPy RNG is not seeded here, so this is a
        # statistical check on a sample mean and could fail intermittently;
        # consider seeding if it proves flaky.
        stored_reward = [rollout(leaf_node) for _ in range(max_iteration)]

        calc_sumValue = np.mean(stored_reward)

        self.assertAlmostEqual(gt_sumValue, calc_sumValue, places=1)
コード例 #2
0
ファイル: test1DEnv.py プロジェクト: Chenfei1129/PWAndMCTS
    def setUp(self):
        """Create the fixtures shared by the tests in this case.

        Stores:
          * the reward constants (step penalty -1, catch reward 1);
          * a target state of 7 and its Terminal check;
          * a RewardFunction built from those constants and that check.
        """
        self.step_penalty, self.catch_reward = -1, 1

        self.target_state = 7
        self.isTerminal = Terminal(self.target_state)

        self.reward = RewardFunction(
            self.step_penalty,
            self.catch_reward,
            self.isTerminal,
        )
コード例 #3
0
ファイル: testMCTSNew.py プロジェクト: Chenfei1129/PWAndMCTS
    def setUp(self):
        """Build the shared MCTS fixtures used by the tests in this case.

        Creates:
          * a TransitionFunction over the state range [0, 7];
          * a two-action space with a uniform prior;
          * a Terminal check for the upper bound state;
          * the selection helpers (ScoreChild, SelectAction, SelectNextState,
            GrowNextState);
          * a hand-built root with two children;
          * an Expand step wired to InitializeChildren.
        """
        # Env param
        bound_low = 0
        bound_high = 7
        self.transition = TransitionFunction(bound_low, bound_high)

        self.action_space = [-1, 1]
        self.num_action_space = len(self.action_space)
        self.uniformActionPrior = {
            action: 1 / self.num_action_space
            for action in self.action_space
        }

        # The terminal target is placed at the upper bound of the state range.
        self.target_state = bound_high
        self.isTerminal = Terminal(self.target_state)

        # Constants for ScoreChild.
        self.c_init = 0
        self.c_base = 1
        self.scoreChild = ScoreChild(self.c_init, self.c_base)

        self.selectAction = SelectAction(self.scoreChild)
        self.selectNextState = SelectNextState(self.selectAction)
        self.growNextState = GrowNextState(self.transition)

        init_state = 3
        # NOTE(review): action=0 is not a member of self.action_space ([-1, 1]);
        # confirm whether level1_0 was meant to use action -1.
        level1_0_state = self.transition(init_state, action=0)
        level1_1_state = self.transition(init_state, action=1)
        self.default_actionPrior = 1 / self.num_action_space

        # Hand-built tree: an expanded root with two unexpanded children.
        # The children are given different visit counts and value sums.
        self.root = Node(id={1: init_state},
                         numVisited=1,
                         sumValue=0,
                         actionPrior=self.default_actionPrior,
                         isExpanded=True)
        self.level1_0 = Node(parent=self.root,
                             id={0: level1_0_state},
                             numVisited=2,
                             sumValue=5,
                             actionPrior=self.default_actionPrior,
                             isExpanded=False)
        self.level1_1 = Node(parent=self.root,
                             id={1: level1_1_state},
                             numVisited=3,
                             sumValue=10,
                             actionPrior=self.default_actionPrior,
                             isExpanded=False)

        # Every state gets the same uniform prior over actions.
        self.getActionPrior = lambda state: self.uniformActionPrior
        self.initializeChildren = InitializeChildren(self.action_space,
                                                     self.transition,
                                                     self.getActionPrior)
        self.expand = Expand(self.isTerminal, self.initializeChildren)
コード例 #4
0
ファイル: test1DEnv.py プロジェクト: Chenfei1129/PWAndMCTS
 def setUp(self):
     self.target_state = 7
     self.isTerminal = Terminal(self.target_state)