Code Example #1
    def plan(self):
        return 0  # NOTE: planning is disabled here; the rollout below is never executed
        with torch.no_grad():
            current_state = self.prev_state.clone().detach()
            true_current_state = self.prev_state.cpu().numpy()[0]
            current_action = np.copy(self.prev_action)
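            # roll the learned forward model out for plan_horizon steps and store
            # the imagined transitions in the transition buffer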
            for h in range(self.model['forward']['plan_horizon']):
                next_state, is_terminal = self.rolloutWithModel(
                    current_state, current_action, self.model['forward'], 1)
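                # the model's terminal output is a probability; sample a Bernoulli
                # outcome from it to decide whether the imagined episode ends here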
                is_terminal = np.random.binomial(
                    n=1, p=float(is_terminal.data.cpu().numpy()), size=1)
                true_next_state, _, reward = self.true_model(
                    true_current_state, current_action)
                reward = torch.tensor(reward).unsqueeze(0).to(self.device)
                next_action = self.forwardRolloutPolicy(next_state)

                if is_terminal:
                    self.updateTransitionBuffer(
                        utils.transition(current_state, current_action, reward,
                                         None, None, True, self.time_step))
                else:
                    self.updateTransitionBuffer(
                        utils.transition(current_state, current_action, reward,
                                         next_state, next_action, False,
                                         self.time_step))
                current_state = next_state
                current_action = next_action
                true_current_state = true_next_state
Code Example #2
    def step(self, reward, observation):
        self.time_step += 1

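        # encode the new observation and select the next action with the current policy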
        self.state = self.getStateRepresentation(observation)

        reward = torch.tensor(reward).unsqueeze(0).to(self.device)
        self.action = self.policy(self.state)

        # store the new transition in buffer
        self.updateTransitionBuffer(utils.transition(self.prev_state, self.prev_action, reward,
                                                     self.state, self.action, False, self.time_step, 0))
        # periodically sync the target network with the online Q-network
        if self._target_vf['counter'] >= self._target_vf['update_rate']:
            self.setTargetValueFunction(self._vf['q'], 'q')
            # self.setTargetValueFunction(self._vf['s'], 's')
            self._target_vf['counter'] = 0

        # update value function with the buffer
        if self._vf['q']['training']:
            if len(self.transition_buffer) >= self._vf['q']['batch_size']:
                transition_batch = self.getTransitionFromBuffer(n=self._vf['q']['batch_size'])
                self.updateValueFunction(transition_batch, 'q')
        if self._vf['s']['training']:
            if len(self.transition_buffer) >= self._vf['s']['batch_size']:
                transition_batch = self.getTransitionFromBuffer(n=self._vf['s']['batch_size'])
                self.updateValueFunction(transition_batch, 's')

        self.prev_state = self.getStateRepresentation(observation)
        self.prev_action = self.action  # alternatively, self.policy could be called again here

        return self.prev_action
Code Example #3
    def plan(self):
        return 0  # NOTE: planning is disabled here; the code below is never executed
        if self._vf['q']['training']:
            if len(self.planning_transition_buffer) >= self._vf['q']['batch_size']:
                transition_batch = self.getTransitionFromPlanningBuffer(n=self._vf['q']['batch_size'])
                self.updateValueFunction(transition_batch, 'q')

        with torch.no_grad():
            self.updatePlanningBuffer(self.model['backward'], self.state)
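            # roll the backward model out from each buffered state, generating
            # predecessor transitions that are pushed onto the planning buffer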
            for state in self.getStateFromPlanningBuffer(self.model['backward']):
                action = self.policy(state)
                for j in range(self.model['backward']['plan_horizon']):
                    prev_action = self.backwardRolloutPolicy(state)
                    prev_state = self.rolloutWithModel(state, prev_action, self.model['backward'])
                    reward = -1
                    terminal = self.isTerminal(state)
                    if terminal:
                        reward = 10
                    reward = torch.tensor(reward).unsqueeze(0).to(self.device)
                    x_old = prev_state.float().to(self.device)
                    x_new = state.float().to(self.device) if not terminal else None

                    error = 0
                    if self.is_using_error:
                        error = self.calculateTrueError(state, prev_action)
                    self.update_planning_transition_buffer(utils.transition(x_old, prev_action, reward,
                                                         x_new, action, terminal, self.time_step, error))
                    action = prev_action
                    state = prev_state
Code Example #4
    def updateValueFunction(self, transition_batch, vf_type):
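        # transpose the batch: a list of transition tuples becomes a single
        # transition whose fields are tuples of batched values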
        batch = utils.transition(*zip(*transition_batch))

        non_final_mask = torch.tensor(
            tuple(map(lambda s: s is not None,
                      batch.state)), device=self.device, dtype=torch.bool)
        non_final_next_states = torch.cat([s for s in batch.state
                                           if s is not None])
        prev_state_batch = torch.cat(batch.prev_state)
        prev_action_batch = torch.cat(batch.prev_action)
        reward_batch = torch.cat(batch.reward)

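        # Q(s, a) for the actions actually taken; non-terminal next states get a
        # bootstrap value of max_a' Q_target(s', a'), terminal ones stay at zero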
        state_action_values = self._vf['q']['network'](prev_state_batch).gather(1, prev_action_batch)
        next_state_values = torch.zeros(self._vf['q']['batch_size'], device=self.device)
        next_state_values[non_final_mask] = self._target_vf['network'](non_final_next_states).max(1)[0].detach()

        expected_state_action_values = (next_state_values * self.gamma) + reward_batch
        loss = F.mse_loss(state_action_values,
                          expected_state_action_values.unsqueeze(1))
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        self._target_vf['counter'] += 1  # counts value-function updates since the last target sync
Code Example #5
    def learn_from_transition(self, state, a, reward, next_state, is_terminal):
        torch_state = self.getStateRepresentation(state)
        torch_next_state = self.getStateRepresentation(next_state)
        torch_reward = torch.tensor([reward], device=self.device)
        torch_action = torch.tensor([self.getActionIndex(a)],
                                    device=self.device).view(1, 1)
        transition = utils.transition(torch_state, torch_action, torch_reward,
                                      torch_next_state, None, is_terminal,
                                      self.time_step, 0)
        self.updateTransitionBuffer(transition)
Code Example #6
    def end(self, reward):
        reward = torch.tensor(reward).unsqueeze(0).to(self.device)

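        # terminal transition: there is no next state or next action to store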
        self.updateTransitionBuffer(utils.transition(self.prev_state, self.prev_action, reward,
                                                     None, None, True, self.time_step, 0))

        if self._vf['q']['training']:
            if len(self.transition_buffer) >= self._vf['q']['batch_size']:
                transition_batch = self.getTransitionFromBuffer(n=self._vf['q']['batch_size'])
                self.updateValueFunction(transition_batch, 'q')
        if self._vf['s']['training']:
            if len(self.transition_buffer) >= self._vf['s']['batch_size']:
                transition_batch = self.getTransitionFromBuffer(n=self._vf['s']['batch_size'])
                self.updateValueFunction(transition_batch, 's')

        self.updateStateRepresentation()
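
Note on utils.transition: every example above constructs or unpacks this record, and updateValueFunction reads it through attribute access (batch.prev_state, batch.prev_action, batch.reward, batch.state), so it is presumably a collections.namedtuple. The sketch below is an inferred reconstruction, not the definition from utils; the field names and the default for the trailing error field are assumptions (some call sites pass seven arguments, others eight).

from collections import namedtuple

# Hypothetical reconstruction of the transition record used in the examples.
# Field order follows the positional arguments seen above; `error` gets a
# default so that the seven-argument call sites still construct correctly.
transition = namedtuple(
    'transition',
    ['prev_state', 'prev_action', 'reward',
     'state', 'action', 'is_terminal', 'time_step', 'error'],
    defaults=[0],
)

With such a definition, utils.transition(*zip(*transition_batch)) in Code Example #4 transposes a list of transitions into a single transition whose fields are tuples of batched values.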