def testRolloutWithHeuristic(self, max_rollout_step, init_state, gt_sumValue):
    """Check the averaged return of a heuristic-backed RollOut.

    A RollOut is built with:
    - a policy that picks uniformly at random from ``self.action_space``;
    - a RewardFunction with step penalty 0 and catch reward 1, keyed on
      Terminal(6);
    - a constant heuristic value of 2.

    The rollout is invoked 1000 times on a leaf node at ``init_state``. The
    mean of the returned values must equal ``gt_sumValue`` to one decimal
    place.
    """
    goal_state = 6
    num_trials = 1000
    zero_penalty = 0
    reward_on_catch = 1

    terminal_check = Terminal(goal_state)
    rewarder = RewardFunction(zero_penalty, reward_on_catch, terminal_check)

    def constant_heuristic(state):
        # Ignores the state; RollOut receives a fixed estimate of 2.
        return 2

    def random_policy(state):
        # Unseeded global NumPy RNG, so individual rollouts are stochastic.
        return np.random.choice(self.action_space)

    leaf = Node(
        id={1: init_state},
        numVisited=1,
        sumValue=0,
        actionPrior=self.default_actionPrior,
        isExpanded=True,
    )
    rollout = RollOut(
        random_policy,
        max_rollout_step,
        self.transition,
        rewarder,
        terminal_check,
        constant_heuristic,
    )

    # Monte-Carlo estimate over repeated rollouts from the same leaf.
    returns = [rollout(leaf) for _ in range(num_trials)]
    self.assertAlmostEqual(gt_sumValue, np.mean(returns), places=1)
def setUp(self):
    """Create the terminal checker and reward function used by these tests.

    Sets ``target_state`` (7), ``isTerminal``, ``step_penalty`` (-1),
    ``catch_reward`` (1) and ``reward`` on the test instance.
    """
    self.step_penalty, self.catch_reward = -1, 1
    self.target_state = 7
    self.isTerminal = Terminal(self.target_state)
    self.reward = RewardFunction(
        self.step_penalty, self.catch_reward, self.isTerminal
    )
def setUp(self):
    """Build the environment pieces and MCTS components shared by these tests.

    Creates a TransitionFunction over bounds (0, 7), the action space [-1, 1]
    with a uniform prior, a Terminal whose target is the upper bound, and the
    scoring/selection/expansion helpers. It also builds a root Node at state 3
    with two hand-made children that have preset visit counts and value sums.
    """
    # Env param
    bound_low = 0
    bound_high = 7
    self.transition = TransitionFunction(bound_low, bound_high)
    self.action_space = [-1, 1]
    self.num_action_space = len(self.action_space)

    # One uniform probability, reused both as the per-action prior mapping and
    # as the scalar prior on the hand-built nodes below.
    self.default_actionPrior = 1 / self.num_action_space
    self.uniformActionPrior = {
        action: self.default_actionPrior for action in self.action_space
    }

    self.target_state = bound_high
    self.isTerminal = Terminal(self.target_state)

    self.c_init = 0
    self.c_base = 1
    self.scoreChild = ScoreChild(self.c_init, self.c_base)
    self.selectAction = SelectAction(self.scoreChild)
    self.selectNextState = SelectNextState(self.selectAction)
    self.growNextState = GrowNextState(self.transition)

    init_state = 3
    # NOTE(review): action 0 is not a member of self.action_space ([-1, 1]);
    # confirm that using it for the first child's state is intentional.
    level1_0_state = self.transition(init_state, action=0)
    level1_1_state = self.transition(init_state, action=1)

    self.root = Node(
        id={1: init_state},
        numVisited=1,
        sumValue=0,
        actionPrior=self.default_actionPrior,
        isExpanded=True,
    )
    self.level1_0 = Node(
        parent=self.root,
        id={0: level1_0_state},
        numVisited=2,
        sumValue=5,
        actionPrior=self.default_actionPrior,
        isExpanded=False,
    )
    self.level1_1 = Node(
        parent=self.root,
        id={1: level1_1_state},
        numVisited=3,
        sumValue=10,
        actionPrior=self.default_actionPrior,
        isExpanded=False,
    )

    # Every state gets the same uniform prior mapping.
    self.getActionPrior = lambda state: self.uniformActionPrior
    self.initializeChildren = InitializeChildren(
        self.action_space, self.transition, self.getActionPrior
    )
    self.expand = Expand(self.isTerminal, self.initializeChildren)
def setUp(self):
    """Create the Terminal checker (target state 7) used by these tests."""
    goal = 7
    self.target_state = goal
    self.isTerminal = Terminal(goal)