class QTree(QAbstract):
    """
    Q-function stored as a tree, for the case where the number of actions
    is not known in advance.

    Every node of the tree holds a state in ``node.data`` and a value in
    ``node.value``.

    :param: states are *terminal* states of options
    :param: actions are indices into the children of a state's node
    """

    def __init__(self, state):
        # The tree is rooted at the initial state and the cursor starts there.
        self.tree = Tree(state)
        self.current_node = self.tree.root
        # Raised by add_state whenever the current node ends up with more
        # children than this counter.
        self.number_options = 0

    def __len__(self):
        """Number of nodes registered in the underlying tree."""
        return len(self.tree.nodes)

    def __str__(self):
        """String rendering delegated to the underlying tree."""
        return self.tree.str_tree()

    def reset(self):
        """Move the cursor back to the root of the tree."""
        self.current_node = self.tree.root

    def get_node_from_state(self, state):
        """
        Search the whole tree (depth-first from the root) for a state.

        :param state: the node data to look for
        :return: the first visited node such that node.data == state
        :raises ValueError: if no node holds this state
        """
        for candidate in self.tree.root.depth_first():
            if candidate.data == state:
                return candidate
        else:
            raise ValueError("state does not exist in the tree")

    def get_child_node_from_current_state(self, state):
        """
        Search only the direct children of the current node for a state.

        :param state: the node data to look for
        :return: the first child of self.current_node with child.data == state
        :raises ValueError: if no child holds this state
        """
        for candidate in self.current_node.children:
            if candidate.data == state:
                return candidate
        else:
            raise ValueError("None of my children have this state")

    def add_state(self, next_state):
        """
        Move the cursor to ``next_state``, creating a node for it under the
        current node when the state is not in the tree yet.

        Nothing happens when ``no_return_update`` rejects the transition.

        :param next_state: the state to move to / to add
        :return: None
        """
        # Skip transitions that already exist the other way round.
        if not self.no_return_update(next_state):
            return

        # One more visit for the node we are leaving.
        self.current_node.number_visits += 1

        try:
            # The state is already known: just jump to its node.
            self.current_node = self.get_node_from_state(next_state)

        except ValueError:
            # Unknown state: attach a fresh node below the current one.
            # The value returned by add_tree becomes the new cursor.
            created_node = self.tree.add_tree(self.current_node, Node(next_state))

            # Keep the option counter in step with the parent's child count.
            if len(self.current_node.children) > self.number_options:
                self.number_options += 1

            self.current_node = created_node

    def get_random_action(self, state):
        """
        Placeholder: no random action is drawn, the method returns None.

        A possible implementation would be:
            node = self.get_node_from_state(state)
            return np.random.randint(len(node.children))
        but it is unclear whether random actions are worthwhile at the
        high level.
        """
        return None

    def get_number_visits(self):
        """Visit counter of the current node."""
        return self.current_node.number_visits

    def find_best_action(self, state=None):
        """
        Pick the best child of the current node according to its values.

        :param state: unused
        :return: best_option_index, terminal_state; (0, None) when the
            current node reports no values
        """
        values = self.current_node.get_values()

        if not values:
            return 0, None

        has_single_value = all(val == values[0] for val in values)
        if has_single_value:
            # No child stands out: let the Tree choose the index.
            best_option_index = Tree.get_random_next_option_index(self.current_node)
        else:
            best_option_index = values.index(max(values))

        chosen_child = self.current_node.children[best_option_index]
        return best_option_index, chosen_child.data

    def update_q_value(self, action, reward, new_state, learning_rate):
        """
        Performs the Q learning update:

        Q_{t+1}(current_position, action)
            = (1 - learning_rate) * Q_t(current_position, action)
            + learning_rate * [reward + max_{actions} Q_t(new_position, action)]

        :param action: data of the child of the current node whose value is
            Q_t(current_position, action)
        :param reward: reward received for this transition
        :param new_state: state reached; its node may differ from the
            updated child
        :param learning_rate: weight given to the new estimate
        :raises ValueError: if no child of the current node holds ``action``
        """
        node_activated = self.get_child_node_from_current_state(action)

        try:
            reached_node = self.get_node_from_state(new_state)
            # A node without children contributes 0 to the target.
            best_value = max(reached_node.get_values()) if reached_node.children else 0

        except ValueError:
            # new_state does not exist in the tree for the moment
            best_value = 0

        target = reward + best_value
        node_activated.value *= (1 - learning_rate)
        node_activated.value += learning_rate * target

    def no_return_update(self, new_state):
        """
        (no return option) Tell whether moving to ``new_state`` is allowed.

        :param new_state: the candidate next state
        :return: False if the node of ``new_state`` already has a child whose
            data equals the current node's data, True otherwise (including
            when ``new_state`` is not in the tree)
        """
        try:
            target_node = self.get_node_from_state(new_state)
            return not any(
                child.data == self.current_node.data
                for child in target_node.children
            )

        except ValueError:
            return True
class TreeTest(unittest.TestCase):
    """
    Unit tests for Tree, run against the following fixture
    (numbers are node data):

        0 (root)
        |-- 1
        |   |-- 4
        |   |   `-- 7
        |   `-- 5
        |-- 2
        `-- 3
            `-- 6

    node_8 is created but left unattached so individual tests can add it.
    """

    def setUp(self):
        """
        We define here a Tree to test its functions
        """
        self.tree = Tree(root_data=0)
        self.node_1 = Node(data=1)
        self.node_2 = Node(data=2)
        self.node_3 = Node(data=3)
        self.node_4 = Node(data=4)
        self.node_5 = Node(data=5)
        self.node_6 = Node(data=6)
        self.node_7 = Node(data=7)
        self.node_8 = Node(data=8)

        self.set_parents_children()
        self.set_values()

    def set_values(self):
        """Give a distinct value to the root and to nodes 1 to 7."""
        self.tree.root.value = 0
        self.node_1.value = 1
        self.node_2.value = 10
        self.node_3.value = 11
        self.node_4.value = 100
        self.node_5.value = 101
        self.node_6.value = 111
        self.node_7.value = 1000

    def set_parents_children(self):
        """
        Defines a Tree with the nodes (see the class docstring for the shape)
        :return:
        """
        self.tree.add_tree(self.tree.root, self.node_1)
        self.tree.add_tree(self.tree.root, self.node_2)
        self.tree.add_tree(self.tree.root, self.node_3)

        self.tree.add_tree(self.node_1, self.node_4)
        self.tree.add_tree(self.node_1, self.node_5)

        self.tree.add_tree(self.node_3, self.node_6)

        self.tree.add_tree(self.node_4, self.node_7)

    # ------------- The tests are defined here --------------

    def test_print_tree(self):
        """Smoke test: rendering the tree as a string must not fail."""
        print(self.tree.str_tree())

    def test_new_root(self):
        """Re-rooting at node_3 must leave only node_3 and node_6."""
        self.tree.new_root(self.node_3)

        # Expected tree, assembled by hand.
        tree = Tree(0)
        tree.root = self.node_3
        tree.nodes = [self.node_3, self.node_6]
        tree.depth[0].append(self.node_3)
        tree.depth[1].append(self.node_6)
        tree.max_depth = 1

        self.assertEqual(self.tree.root, tree.root)
        self.assertEqual(self.tree.nodes, tree.nodes)
        self.assertEqual(self.tree.depth, tree.depth)
        self.assertEqual(self.tree.max_depth, tree.max_depth)

    def test_update(self):
        """Registering node_8 at depth 3 must place it next to node_7."""
        self.node_8.depth = 3
        self.tree.update(self.node_8)
        self.assertEqual(self.tree.depth[3], [self.node_7, self.node_8])

    def test_add_tree(self):
        """Attaching node_8 under node_6 must place it at depth 3."""
        self.tree.add_tree(parent_node=self.node_6, node=self.node_8)
        self.assertEqual(self.tree.depth[3], [self.node_7, self.node_8])

    def test_get_leaves(self):
        """The leaves under the root are nodes 7, 5, 2 and 6, in that order."""
        leaves = self.tree.get_leaves(node=self.tree.root)
        self.assertEqual(leaves, [self.node_7, self.node_5, self.node_2, self.node_6])

    def test_get_next_option_index(self):
        """Check the child index to follow from a node towards a target node."""
        next_node_index_1 = Tree.get_next_option_index(self.tree.root, self.node_4)
        next_node_index_4 = Tree.get_next_option_index(self.node_1, self.node_7)
        next_node_index_3 = Tree.get_next_option_index(self.tree.root, self.node_6)

        self.assertEqual(next_node_index_1, 0)
        self.assertEqual(next_node_index_4, 0)
        self.assertEqual(next_node_index_3, 2)

    def test_get_probability_leaves(self):
        """
        Internal nodes yield the expected probability vector; every leaf of
        the fixture must make get_probability_leaves fail.
        """
        leaves_0, _ = Tree.get_probability_leaves(self.tree.root)
        leaves_1, _ = Tree.get_probability_leaves(self.node_1)
        leaves_3, _ = Tree.get_probability_leaves(self.node_3)
        leaves_4, _ = Tree.get_probability_leaves(self.node_4)

        # All four leaves are exercised. The previous version checked node_7
        # twice and never checked node_6.
        leaf_nodes = (self.node_2, self.node_5, self.node_6, self.node_7)
        for leaf in leaf_nodes:
            with self.subTest(leaf=leaf.data):
                with self.assertRaises(Exception):
                    # The unpacking is kept on purpose so that the accepted
                    # failure modes stay the same as before.
                    _, _ = self.tree.get_probability_leaves(leaf)

        np.testing.assert_array_equal(leaves_0, np.array([3 / 8, 2 / 8, 1 / 8, 2 / 8]))
        np.testing.assert_array_equal(leaves_1, np.array([2 / 3, 1 / 3]))
        np.testing.assert_array_equal(leaves_3, np.array([1]))
        np.testing.assert_array_equal(leaves_4, np.array([1]))

    def test_get_random_next_option_index(self):
        """
        Not implemented yet.
        :return:
        """
        pass