def setUp(self):
    """Prepare the terminal check and reward function used by the tests.

    Stores the target state (7), the step penalty (-1) and the catch
    reward (1) on the instance, then builds ``Terminal`` and
    ``RewardFunction`` from those values.
    """
    goal_state, per_step_cost, reward_on_catch = 7, -1, 1

    self.target_state = goal_state
    self.isTerminal = Terminal(goal_state)

    self.step_penalty = per_step_cost
    self.catch_reward = reward_on_catch
    self.reward = RewardFunction(per_step_cost, reward_on_catch, self.isTerminal)
def testRollout(self, max_rollout_step, init_state, gt_sum_value):
    """Check that the mean rollout result matches the expected value.

    Calls the rollout 1000 times on a leaf node holding ``init_state``,
    choosing actions with ``np.random.choice`` over ``self.action_space``,
    and compares the mean of the results with ``gt_sum_value`` to one
    decimal place.

    Args:
        max_rollout_step: forwarded to ``RollOut`` as its second argument
            (presumably the step cap of a single rollout -- confirm).
        init_state: state stored in the leaf node the rollouts start from.
        gt_sum_value: expected mean of the rollout results.
    """
    num_trials = 1000
    goal_state = 6
    reached_goal = Terminal(goal_state)

    # This test uses its own reward settings: no step penalty, reward 1 on catch.
    reward_on_catch = 1
    per_step_cost = 0
    get_reward = RewardFunction(per_step_cost, reward_on_catch, reached_goal)

    def random_policy(state):
        # The state is ignored; every call draws an action at random.
        return np.random.choice(self.action_space)

    start_node = Node(
        id={1: init_state},
        num_visited=1,
        sum_value=0,
        action_prior=self.default_action_prior,
        is_expanded=True,
    )
    simulate = RollOut(
        random_policy,
        max_rollout_step,
        self.transition,
        get_reward,
        reached_goal,
    )

    results = [simulate(start_node) for _ in range(num_trials)]

    # The outcome is stochastic, so only one decimal place is compared.
    self.assertAlmostEqual(gt_sum_value, np.mean(results), places=1)
def setUp(self):
    """Build the environment helpers and a small three-node tree fixture.

    Creates, in order:
      * a ``TransitionFunction`` bounded to [0, 7] and the action space
        ``[-1, 1]`` with its ``GetActionPrior``;
      * a ``Terminal`` check whose target state is the upper bound (7);
      * ``CalculateScore`` (``c_init=0``, ``c_base=1``) and the
        ``SelectChild`` built on it;
      * a root node at state 3 plus two unexpanded children with preset
        visit counts and value sums;
      * ``InitializeChildren`` and ``Expand``.
    """
    # Env param
    bound_low = 0
    bound_high = 7
    self.transition = TransitionFunction(bound_low, bound_high)

    self.action_space = [-1, 1]
    self.num_action_space = len(self.action_space)
    self.action_prior_func = GetActionPrior(self.action_space)

    # The target state coincides with the upper bound of the environment.
    self.target_state = bound_high
    self.isTerminal = Terminal(self.target_state)

    self.c_init = 0
    self.c_base = 1
    self.calculateScore = CalculateScore(self.c_init, self.c_base)
    self.selectChild = SelectChild(self.calculateScore)

    init_state = 3
    # NOTE(review): action=0 is not a member of self.action_space ([-1, 1]);
    # kept as in the original fixture -- confirm whether -1 was intended.
    level1_0_state = self.transition(init_state, action=0)
    level1_1_state = self.transition(init_state, action=1)
    self.default_action_prior = 0.5

    # Each node id is a one-entry dict; the children use the action passed
    # to the transition as the key and the resulting state as the value.
    self.root = Node(
        id={1: init_state},
        num_visited=1,
        sum_value=0,
        action_prior=self.default_action_prior,
        is_expanded=True,
    )
    self.level1_0 = Node(
        parent=self.root,
        id={0: level1_0_state},
        num_visited=2,
        sum_value=5,
        action_prior=self.default_action_prior,
        is_expanded=False,
    )
    self.level1_1 = Node(
        parent=self.root,
        id={1: level1_1_state},
        num_visited=3,
        sum_value=10,
        action_prior=self.default_action_prior,
        is_expanded=False,
    )

    self.initializeChildren = InitializeChildren(
        self.action_space, self.transition, self.action_prior_func
    )
    self.expand = Expand(self.isTerminal, self.initializeChildren)
def setUp(self):
    """Create the terminal check for target state 7.

    Stores the target state on the instance and builds ``Terminal`` from it.
    """
    goal_state = 7
    self.target_state = goal_state
    self.isTerminal = Terminal(goal_state)