Example #1
 def __init__(self, params={}):
     BaseDynaAgent.__init__(self, params)
     MCTSAgent.__init__(self, params)
     self.episode_counter = -1
     self.td_average = 0
     self.average_rate = 0.1
     self.mcts_count = 0
Example #2
 def end(self, reward):
     # episodes_only_dqn / episodes_only_mcts are assumed to be module-level
     # constants defining how many initial episodes use each agent exclusively
     if self.episode_counter < episodes_only_dqn:
         BaseDynaAgent.end(self, reward)
     elif self.episode_counter < episodes_only_dqn + episodes_only_mcts:
         MCTSAgent.end(self, reward)
     else:
         if self.episode_counter % 2 == 0:
             BaseDynaAgent.end(self, reward)
         else:
             MCTSAgent.end(self, reward)
Example #3
 def step(self, reward, observation):
     # episodes_only_dqn / episodes_only_mcts: assumed schedule constants (see Example #2)
     if self.episode_counter < episodes_only_dqn:
         action = BaseDynaAgent.step(self, reward, observation)
     elif self.episode_counter < episodes_only_dqn + episodes_only_mcts:
         action = MCTSAgent.step(self, reward, observation)
     else:
         if self.episode_counter % 2 == 0:
             action = BaseDynaAgent.step(self, reward, observation)
         else:
             action = MCTSAgent.step(self, reward, observation)
     return action
Example #4
 def start(self, observation):
     self.episode_counter += 1
     if self.episode_counter % 2 == 0:
         action = BaseDynaAgent.start(self, observation)
     else:
         action = MCTSAgent.start(self, observation)
     return action
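The start/step/end methods shown in these examples are typically driven by an outer episode loop. A minimal sketch of such a loop, assuming a combined agent exposing the interface above and an environment whose step() returns (observation, reward, done); the names env, agent, and run are illustrative and not part of the original code:

    def run(env, agent, num_episodes):
        # drive the agent through full episodes; start() alternates between
        # the DQN-style and MCTS-style behaviour via episode_counter
        for _ in range(num_episodes):
            observation = env.reset()
            action = agent.start(observation)
            done = False
            while not done:
                observation, reward, done = env.step(action)
                if done:
                    agent.end(reward)
                else:
                    action = agent.step(reward, observation)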
Example #5
 def start(self, observation):
     self.episode_counter += 1
     if self.episode_counter % 2 == 0:
         action = BaseDynaAgent.start(self, observation)
     else:
         action = MCTSAgent.start(self, observation)
         self.mcts_prev_state = self.getStateRepresentation(observation)
         self.mcts_prev_action = action
     return action
Example #6
    def step(self, reward, observation):
        self.time_step += 1
        if self.episode_counter % 2 == 0:

            self.state = self.getStateRepresentation(observation)
            self.action = self.policy(self.state)

            reward = torch.tensor([reward], device=self.device)

            with torch.no_grad():
                real_prev_action = self.action_list[self.prev_action.item()]
                prev_state_value = self.getStateActionValue(
                    self.prev_state, real_prev_action).item()
                # bootstrap from the max Q-value of the new state
                # (values, i.e. max(1)[0], not the argmax indices)
                state_value = self._vf['q']['network'](
                    self.state).max(1)[0].item()
                td_error = (reward.item() + self.gamma * state_value
                            - prev_state_value)
                self.update_average_td_error(td_error)

            # store the new transition in buffer
            self.updateTransitionBuffer(
                utils.transition(self.prev_state, self.prev_action, reward,
                                 self.state, self.action, False,
                                 self.time_step, 0))

            # update target
            if self._target_vf['counter'] >= self._target_vf['update_rate']:
                self.setTargetValueFunction(self._vf['q'], 'q')
                # self.setTargetValueFunction(self._vf['s'], 's')

            # update value function with the buffer
            if self._vf['q']['training']:
                if len(self.transition_buffer) >= self._vf['q']['batch_size']:
                    transition_batch = self.getTransitionFromBuffer(
                        n=self._vf['q']['batch_size'])
                    self.updateValueFunction(transition_batch, 'q')
            if self._vf['s']['training']:
                if len(self.transition_buffer) >= self._vf['s']['batch_size']:
                    transition_batch = self.getTransitionFromBuffer(
                        n=self._vf['s']['batch_size'])
                    self.updateValueFunction(transition_batch, 's')

            # train/plan with model
            self.trainModel()
            self.plan()

            self.updateStateRepresentation()

            self.prev_state = self.getStateRepresentation(observation)
            self.prev_action = self.action  # another option: call self.policy again here

            action = self.action_list[self.prev_action.item()]
        else:
            action = MCTSAgent.step(self, reward, observation)

        return action
Example #7
    def step(self, reward, observation):
        if self.episode_counter % 2 == 0:
            self.time_step += 1

            self.state = self.getStateRepresentation(observation)
            self.action = self.policy(self.state)

            # update target
            if self._target_vf['counter'] >= self._target_vf['update_rate']:
                self.setTargetValueFunction(self._vf['q'], 'q')
                # self.setTargetValueFunction(self._vf['s'], 's')

            # update value function with the buffer
            if self._vf['q']['training']:
                if len(self.transition_buffer) >= self._vf['q']['batch_size']:
                    transition_batch = self.getTransitionFromBuffer(
                        n=self._vf['q']['batch_size'])
                    self.updateValueFunction(transition_batch, 'q')
            if self._vf['s']['training']:
                if len(self.transition_buffer) >= self._vf['s']['batch_size']:
                    transition_batch = self.getTransitionFromBuffer(
                        n=self._vf['s']['batch_size'])
                    self.updateValueFunction(transition_batch, 's')

            # train/plan with model
            self.trainModel()
            self.plan()

            self.updateStateRepresentation()

            self.prev_state = self.getStateRepresentation(observation)
            self.prev_action = self.action  # another option: call self.policy again here

            action = self.action_list[self.prev_action.item()]
        else:
            action = MCTSAgent.step(self, reward, observation)

        return action
Example #8
    def rollout(self, node):
        if self.episode_counter > 200:
            sum_returns = 0
            for i in range(self.num_rollouts):
                depth = 0
                single_return = 0
                is_terminal = False
                state = node.get_state()
                while not is_terminal and depth < self.rollout_depth:
                    a = random.choice(self.action_list)
                    next_state, is_terminal, reward = self.true_model(state, a)
                    single_return += reward
                    depth += 1
                    state = next_state
                if not is_terminal:
                    state_representation = self.getStateRepresentation(state)
                    bootstrap_value = self.getStateActionValue(
                        state_representation)
                    single_return += bootstrap_value.item()
                sum_returns += single_return

            return sum_returns / self.num_rollouts
        else:
            return MCTSAgent.rollout(self, node)
Example #9
 def step(self, reward, observation):
     action = MCTSAgent.step(self, reward, observation)
     return action
Example #10
 def __init__(self, params={}):
     BaseDynaAgent.__init__(self, params)
     MCTSAgent.__init__(self, params)
     self.episode_counter = -1
Example #11
 def start(self, observation):
     self.episode_counter += 1
     if self._sr['network'] is None:
         self.init_s_representation_network(observation)
     action = MCTSAgent.start(self, observation)
     return action
Example #12
 def __init__(self, params={}):
     BaseDynaAgent.__init__(self, params)
     MCTSAgent.__init__(self, params)
     with open("dqn_vf_4by4.p", 'rb') as file:
         self._vf = pickle.load(file)
     self.episode_counter = -1
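Example #12 restores a pre-trained value function from a pickle file. A minimal sketch of the matching save step, assuming the agent's _vf dictionary is picklable; save_value_function is an illustrative helper, not part of the original code:

    import pickle

    def save_value_function(agent, path="dqn_vf_4by4.p"):
        # persist the trained value-function dictionary so it can be
        # reloaded later as in Example #12
        with open(path, "wb") as file:
            pickle.dump(agent._vf, file)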
Example #13
    def step(self, reward, observation):
        if self.episode_counter % 2 == 0:
            self.time_step += 1

            self.state = self.getStateRepresentation(observation)
            self.action = self.policy(self.state)

            reward = torch.tensor([reward], device=self.device)

            # store the new transition in buffer
            self.updateTransitionBuffer(
                utils.transition(self.prev_state, self.prev_action, reward,
                                 self.state, self.action, False,
                                 self.time_step, 0))

            # update target
            if self._target_vf['counter'] >= self._target_vf['update_rate']:
                self.setTargetValueFunction(self._vf['q'], 'q')
                # self.setTargetValueFunction(self._vf['s'], 's')

            # update value function with the buffer
            if self._vf['q']['training']:
                if len(self.transition_buffer) >= self._vf['q']['batch_size']:
                    transition_batch = self.getTransitionFromBuffer(
                        n=self._vf['q']['batch_size'])
                    self.updateValueFunction(transition_batch, 'q')
            if self._vf['s']['training']:
                if len(self.transition_buffer) >= self._vf['s']['batch_size']:
                    transition_batch = self.getTransitionFromBuffer(
                        n=self._vf['s']['batch_size'])
                    self.updateValueFunction(transition_batch, 's')

            # train/plan with model
            self.trainModel()
            self.plan()

            self.updateStateRepresentation()

            self.prev_state = self.getStateRepresentation(observation)
            self.prev_action = self.action  # another option: call self.policy again here

            action = self.action_list[self.prev_action.item()]
        else:
            action = MCTSAgent.step(self, reward, observation)
            prev_action_index = self.getActionIndex(self.mcts_prev_action)
            prev_action_torch = torch.tensor([prev_action_index],
                                             device=self.device,
                                             dtype=int).view(1, 1)
            reward = torch.tensor([reward], device=self.device).float()
            state_torch = self.getStateRepresentation(observation)
            self.updateTransitionBuffer(
                utils.transition(self.mcts_prev_state, prev_action_torch,
                                 reward, state_torch, None, False,
                                 self.time_step, 0))
            self.mcts_prev_state = state_torch
            self.mcts_prev_action = action

            # update target
            if self._target_vf['counter'] >= self._target_vf['update_rate']:
                self.setTargetValueFunction(self._vf['q'], 'q')
                # self.setTargetValueFunction(self._vf['s'], 's')

            # update value function with the buffer
            if self._vf['q']['training']:
                if len(self.transition_buffer) >= self._vf['q']['batch_size']:
                    transition_batch = self.getTransitionFromBuffer(
                        n=self._vf['q']['batch_size'])
                    self.updateValueFunction(transition_batch, 'q')
            if self._vf['s']['training']:
                if len(self.transition_buffer) >= self._vf['s']['batch_size']:
                    transition_batch = self.getTransitionFromBuffer(
                        n=self._vf['s']['batch_size'])
                    self.updateValueFunction(transition_batch, 's')
        return action
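In the MCTS branch of Example #13, transitions are stored with the next action set to None. With a Q-learning style update that is enough, since the bootstrap target only needs the maximum Q-value of the next state. A minimal sketch of such a target computation, assuming batched tensors; q_learning_targets is illustrative and not the repository's updateValueFunction:

    import torch

    def q_learning_targets(rewards, next_states, q_network, gamma=0.99):
        # rewards: (B, 1) tensor; next_states: batched state representations
        # q_network: callable returning (B, num_actions) Q-values
        with torch.no_grad():
            next_values = q_network(next_states).max(1)[0].unsqueeze(1)
        return rewards + gamma * next_values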