Пример #1
0
 def __init__(self, params=None):
     """Initialise both parent agents and reset this agent's counters.

     Args:
         params: Configuration object forwarded unchanged to
             ``BaseDynaAgent.__init__`` and ``MCTSAgent.__init__``.
             ``None`` (the default) is replaced by a fresh empty dict.
     """
     # A literal ``{}`` default is a single dict object shared by every
     # call; build a new one per call so a mutation made by either parent
     # constructor cannot leak into later instances.
     if params is None:
         params = {}
     BaseDynaAgent.__init__(self, params)
     MCTSAgent.__init__(self, params)
     # NOTE(review): starts at -1, presumably so the first increment in
     # start() yields episode 0 -- confirm against the rest of the class.
     self.episode_counter = -1
     # NOTE(review): the three fields below are only initialised here;
     # presumably a running TD-error average, its update rate, and a count
     # of MCTS uses -- confirm against the code that reads them.
     self.td_average = 0
     self.average_rate = 0.1
     self.mcts_count = 0
Пример #2
0
 def end(self, reward):
     """Dispatch the end-of-episode call to one of the two parent agents.

     The parent is chosen from ``self.episode_counter``:

     * below ``episodes_only_dqn``: ``BaseDynaAgent``;
     * below ``episodes_only_dqn + episodes_only_mcts``: ``MCTSAgent``;
     * otherwise even counters use ``BaseDynaAgent`` and odd ones
       ``MCTSAgent``.

     ``episodes_only_dqn`` and ``episodes_only_mcts`` are names resolved
     outside this method.

     Args:
         reward: Forwarded unchanged to the selected parent's ``end``.
     """
     episode = self.episode_counter
     if episode < episodes_only_dqn:
         delegate = BaseDynaAgent
     elif episode < episodes_only_dqn + episodes_only_mcts:
         delegate = MCTSAgent
     elif episode % 2 == 0:
         delegate = BaseDynaAgent
     else:
         delegate = MCTSAgent
     # The parent's result is intentionally not returned.
     delegate.end(self, reward)
Пример #3
0
 def step(self, reward, observation):
     """Dispatch one step to a parent agent and return its action.

     The parent is chosen from ``self.episode_counter``:

     * below ``episodes_only_dqn``: ``BaseDynaAgent``;
     * below ``episodes_only_dqn + episodes_only_mcts``: ``MCTSAgent``;
     * otherwise even counters use ``BaseDynaAgent`` and odd ones
       ``MCTSAgent``.

     ``episodes_only_dqn`` and ``episodes_only_mcts`` are names resolved
     outside this method.

     Args:
         reward: Forwarded unchanged to the selected parent's ``step``.
         observation: Forwarded unchanged to the selected parent's ``step``.

     Returns:
         Whatever the selected parent's ``step`` returns.
     """
     episode = self.episode_counter
     if episode < episodes_only_dqn:
         delegate = BaseDynaAgent
     elif episode < episodes_only_dqn + episodes_only_mcts:
         delegate = MCTSAgent
     elif episode % 2 == 0:
         delegate = BaseDynaAgent
     else:
         delegate = MCTSAgent
     return delegate.step(self, reward, observation)
Пример #4
0
 def start(self, observation):
     """Advance the episode counter and start via the matching parent.

     After the increment, an even ``self.episode_counter`` selects
     ``BaseDynaAgent.start`` and an odd one ``MCTSAgent.start``.

     Args:
         observation: Forwarded unchanged to the selected parent's ``start``.

     Returns:
         Whatever the selected parent's ``start`` returns.
     """
     self.episode_counter += 1
     even_episode = self.episode_counter % 2 == 0
     delegate = BaseDynaAgent if even_episode else MCTSAgent
     return delegate.start(self, observation)
Пример #5
0
 def start(self, observation):
     """Advance the episode counter and start via the matching parent.

     After the increment, an even ``self.episode_counter`` delegates to
     ``BaseDynaAgent.start``. An odd one delegates to ``MCTSAgent.start``
     and additionally records the state representation of ``observation``
     and the chosen action in ``self.mcts_prev_state`` and
     ``self.mcts_prev_action``.

     Args:
         observation: Forwarded unchanged to the selected parent's ``start``.

     Returns:
         Whatever the selected parent's ``start`` returns.
     """
     self.episode_counter += 1
     if self.episode_counter % 2 == 0:
         return BaseDynaAgent.start(self, observation)

     # MCTS episode: run the parent first, then snapshot state and action.
     first_action = MCTSAgent.start(self, observation)
     self.mcts_prev_state = self.getStateRepresentation(observation)
     self.mcts_prev_action = first_action
     return first_action
Пример #6
0
    def rollout_policy(self, state):
        """Choose a rollout action using ``BaseDynaAgent.policy``.

        ``state`` is first converted with ``self.getStateRepresentation``;
        the index returned by ``BaseDynaAgent.policy`` is then unwrapped
        with ``.item()`` and looked up in ``self.action_list``.

        Args:
            state: Passed to ``self.getStateRepresentation``.

        Returns:
            The entry of ``self.action_list`` at the selected index.
        """
        representation = self.getStateRepresentation(state)
        chosen_index = BaseDynaAgent.policy(self, representation)
        return self.action_list[chosen_index.item()]
Пример #7
0
 def policy(self, state):
     """Select an action via MCTS on odd episodes, else via BaseDynaAgent.

     When ``self.episode_counter % 2`` equals 1, ``self.MCTS_iteration`` is
     called ``self.num_iterations`` times. The action and sub-tree from the
     last call are kept: the sub-tree is stored in ``self.subtree_node``
     and the action's index is wrapped in a tensor on ``self.device``.
     Otherwise the call is delegated to ``BaseDynaAgent.policy``.

     Args:
         state: Only used on the delegated (non-MCTS) path.

     Returns:
         On the MCTS path, a tensor holding the action index, built from a
         one-element array and given a leading dimension with
         ``unsqueeze(0)``; otherwise whatever ``BaseDynaAgent.policy``
         returns.
     """
     if self.episode_counter % 2 != 1:
         return BaseDynaAgent.policy(self, state)

     # Both stay None if the loop body never runs.
     best_action = best_subtree = None
     for _ in range(self.num_iterations):
         best_action, best_subtree = self.MCTS_iteration()
     self.subtree_node = best_subtree

     index_array = np.array([self.getActionIndex(best_action)])
     index_tensor = torch.from_numpy(index_array).unsqueeze(0)
     return index_tensor.to(self.device)
Пример #8
0
    def start(self, observation):
        """Advance the episode counter, prepare the search tree, and start.

        With ``self.keep_tree`` set, ``self.root`` is created and expanded
        only if it is still ``None``, and ``self.subtree_node`` is pointed
        at it. Without it, a new ``Node`` is built from ``observation`` on
        every call, stored in ``self.subtree_node`` and expanded. The call
        then delegates to ``BaseDynaAgent.start``.

        Args:
            observation: Used to build a new ``Node`` when one is needed and
                forwarded to ``BaseDynaAgent.start``.

        Returns:
            Whatever ``BaseDynaAgent.start`` returns.
        """
        self.episode_counter += 1

        if self.keep_tree:
            if self.root is None:
                # First use: build the persistent root once.
                self.root = Node(None, observation)
                self.expansion(self.root)
            self.subtree_node = self.root
        else:
            # Assign before expanding, as expansion receives the stored node.
            self.subtree_node = Node(None, observation)
            self.expansion(self.subtree_node)

        return BaseDynaAgent.start(self, observation)
Пример #9
0
 def end(self, reward):
     """Delegate to ``BaseDynaAgent.end`` with the given reward.

     Args:
         reward: Forwarded unchanged to ``BaseDynaAgent.end``.

     Nothing is returned; the parent's result is discarded.
     """
     BaseDynaAgent.end(self, reward)
Пример #10
0
 def __init__(self, params=None):
     """Initialise both parent agents and reset the episode counter.

     Args:
         params: Configuration object forwarded unchanged to
             ``BaseDynaAgent.__init__`` and ``MCTSAgent.__init__``.
             ``None`` (the default) is replaced by a fresh empty dict.
     """
     # A literal ``{}`` default is a single dict object shared by every
     # call; build a new one per call so a mutation made by either parent
     # constructor cannot leak into later instances.
     if params is None:
         params = {}
     BaseDynaAgent.__init__(self, params)
     MCTSAgent.__init__(self, params)
     # NOTE(review): starts at -1, presumably so the first increment in
     # start() yields episode 0 -- confirm against the rest of the class.
     self.episode_counter = -1
Пример #11
0
 def __init__(self, params=None):
     """Initialise both parent agents and load a pickled object into ``_vf``.

     Args:
         params: Configuration object forwarded unchanged to
             ``BaseDynaAgent.__init__`` and ``MCTSAgent.__init__``.
             ``None`` (the default) is replaced by a fresh empty dict.

     Raises:
         OSError: If ``dqn_vf_4by4.p`` cannot be opened.
     """
     # A literal ``{}`` default is a single dict object shared by every
     # call; build a new one per call so a mutation made by either parent
     # constructor cannot leak into later instances.
     if params is None:
         params = {}
     BaseDynaAgent.__init__(self, params)
     MCTSAgent.__init__(self, params)
     # NOTE(review): the path is relative to the current working directory,
     # and pickle.load can execute arbitrary code from the file -- only
     # load files from a trusted source.
     with open("dqn_vf_4by4.p", 'rb') as vf_file:
         self._vf = pickle.load(vf_file)
     # NOTE(review): starts at -1, presumably so the first increment in
     # start() yields episode 0 -- confirm against the rest of the class.
     self.episode_counter = -1