Example 1
import torch

# DDPGAgent, ExtendedReplayMemory and ExtendedTransition are assumed to be
# provided by the surrounding project (a plausible reconstruction of the two
# support types is sketched after Example 3)
class DDPGMultiStageUnit(DDPGAgent):
    def __init__(self,
                 config,
                 actorNets,
                 criticNets,
                 env,
                 optimizers,
                 netLossFunc,
                 nbAction,
                 stateProcessor=None,
                 experienceProcessor=None):

        super(DDPGMultiStageUnit,
              self).__init__(config, actorNets, criticNets, env, optimizers,
                             netLossFunc, nbAction, stateProcessor,
                             experienceProcessor)

    def train(self):
        raise NotImplementedError

    def init_memory(self):
        self.memory = ExtendedReplayMemory(self.memoryCapacity)

    def store_experience(self, state, action, nextState, reward, done, info):
        # store a single-step transition
        transition = ExtendedTransition(state, action, nextState, reward, done)
        self.memory.push(transition)

    def prepare_minibatch(self, transitions_raw):
        # unpack the sampled transitions into batched fields

        transitions = ExtendedTransition(*zip(*transitions_raw))
        action = torch.tensor(transitions.action,
                              device=self.device,
                              dtype=torch.float32)
        reward = torch.tensor(transitions.reward,
                              device=self.device,
                              dtype=torch.float32)

        # for some envs, the raw state requires further processing before it can be fed to the neural network
        if self.stateProcessor is not None:
            state, _ = self.stateProcessor(transitions.state, self.device)
            nonFinalNextState, nonFinalMask, finalNextState, finalMask = self.stateProcessor(
                transitions.next_state, self.device)
        else:
            state = torch.tensor(transitions.state,
                                 device=self.device,
                                 dtype=torch.float32)
            nextState = torch.tensor(transitions.next_state,
                                     device=self.device,
                                     dtype=torch.float32)
            # the final mask marks transitions whose stage is done
            finalMask = torch.tensor(transitions.done,
                                     device=self.device,
                                     dtype=torch.bool)
            nonFinalMask = ~finalMask
            finalNextState = [
                nextState[i] for i in range(self.trainBatchSize)
                if finalMask[i]
            ]
            nonFinalNextState = [
                nextState[i] for i in range(self.trainBatchSize)
                if nonFinalMask[i]
            ]

            # stack inside this branch only: torch.stack rejects an empty
            # list, and the stateProcessor branch already returns stacked
            # tensors
            if len(nonFinalNextState):
                nonFinalNextState = torch.stack(nonFinalNextState)

            if len(finalNextState):
                finalNextState = torch.stack(finalNextState)

        return state, nonFinalMask, nonFinalNextState, finalMask, finalNextState, action, reward

    def update_net_on_memory_given_target(self, targetAgent=None):

        if len(self.memory) < self.trainBatchSize:
            return

        transitions_raw = self.memory.sample(self.trainBatchSize)

        state, nonFinalMask, nonFinalNextState, finalMask, finalNextState, action, reward = self.prepare_minibatch(
            transitions_raw)

        # now do net update

        # Critic loss
        QValues = self.criticNet.forward(state, action).squeeze()
        QNext = torch.zeros(self.trainBatchSize,
                            device=self.device,
                            dtype=torch.float32)

        numNonFinalNextState = nonFinalMask.sum().item()
        numFinalNextState = finalMask.sum().item()

        if numNonFinalNextState:

            next_actions = self.actorNet_target.forward(nonFinalNextState)

            # if the stage is not done, we bootstrap with our own target nets;
            # detach so the critic update cannot backpropagate into them
            QNext[nonFinalMask] = self.criticNet_target.forward(
                nonFinalNextState, next_actions).squeeze().detach()

        if numFinalNextState:
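            # if the stage is done:
            # 1) if it is not the last stage, bootstrap from the external target agent
            # 2) if it is the last stage (no target agent), do not bootstrap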
            if targetAgent is not None:
                QNext[finalMask] = targetAgent.evaluate_state_value(
                    finalNextState)

        targetValues = reward + self.gamma * QNext
        critic_loss = self.netLossFunc(QValues, targetValues)

        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        if self.netGradClip is not None:
            torch.nn.utils.clip_grad_norm_(self.criticNet.parameters(),
                                           self.netGradClip)

        self.critic_optimizer.step()

        # delayed actor update and target-network soft update
        if self.learnStepCounter % self.policyUpdateFreq == 0:

            # Actor loss
            # maximize the critic's Q-value estimate by minimizing its negative
            policy_loss = -self.criticNet.forward(
                state, self.actorNet.forward(state)).mean()

            self.actor_optimizer.zero_grad()
            policy_loss.backward()
            if self.netGradClip is not None:
                torch.nn.utils.clip_grad_norm_(self.actorNet.parameters(),
                                               self.netGradClip)

            self.actor_optimizer.step()

            if self.globalStepCount % self.lossRecordStep == 0:
                self.losses.append([
                    self.globalStepCount, self.epIdx,
                    critic_loss.item(),
                    policy_loss.item()
                ])

            # soft-update the target networks: Polyak averaging with rate tau
            for target_param, param in zip(self.actorNet_target.parameters(),
                                           self.actorNet.parameters()):
                target_param.data.copy_(param.data * self.tau +
                                        target_param.data * (1.0 - self.tau))

            for target_param, param in zip(self.criticNet_target.parameters(),
                                           self.criticNet.parameters()):
                target_param.data.copy_(param.data * self.tau +
                                        target_param.data * (1.0 - self.tau))

        self.learnStepCounter += 1

    def evaluate_state_value(self, state):

        optimal_actions = self.actorNet_target.forward(state)
        QNext = self.criticNet_target.forward(
            state, optimal_actions).squeeze()
        # detach so callers cannot backpropagate into this agent's target nets
        return QNext.detach()
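
A minimal wiring sketch for two chained units. Every name below (config, the net and optimizer bundles, env, lossFunc, nbAction) is a placeholder, not part of the excerpt, and the real project may wire stages differently. Each stage bootstraps its stage-terminal transitions from the next stage's value estimate, while the last stage passes no target agent:

# hypothetical wiring; all names here are placeholders
stage0 = DDPGMultiStageUnit(config, actorNets0, criticNets0, env,
                            optimizers0, lossFunc, nbAction)
stage1 = DDPGMultiStageUnit(config, actorNets1, criticNets1, env,
                            optimizers1, lossFunc, nbAction)

# stage 0's stage-done transitions bootstrap from stage 1's value estimate
stage0.update_net_on_memory_given_target(targetAgent=stage1)
# the final stage has no successor, so its done transitions use reward only
stage1.update_net_on_memory_given_target(targetAgent=None)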
Example 3
import torch

# DQNAgent, ExtendedReplayMemory and ExtendedTransition are assumed to come
# from the surrounding project
class DQNMultiStageUnit(DQNAgent):
    def __init__(self,
                 config,
                 policyNet,
                 targetNet,
                 env,
                 optimizer,
                 netLossFunc,
                 nbAction,
                 stateProcessor=None,
                 experienceProcessor=None):
        super(DQNMultiStageUnit,
              self).__init__(config, policyNet, targetNet, env, optimizer,
                             netLossFunc, nbAction, stateProcessor,
                             experienceProcessor)

    def train(self):
        raise NotImplementedError

    def init_memory(self):
        self.memory = ExtendedReplayMemory(self.memoryCapacity)

    def store_experience(self, state, action, nextState, reward, done, info):
        # store a single-step transition
        transition = ExtendedTransition(state, action, nextState, reward, done)
        self.memory.push(transition)

    def prepare_minibatch(self, transitions_raw):
        # unpack the sampled transitions into batched fields

        transitions = ExtendedTransition(*zip(*transitions_raw))
        action = torch.tensor(transitions.action,
                              device=self.device,
                              dtype=torch.long).unsqueeze(
                                  -1)  # shape(batch, 1)
        reward = torch.tensor(transitions.reward,
                              device=self.device,
                              dtype=torch.float32).unsqueeze(
                                  -1)  # shape(batch, 1)

        # for some envs, the raw state requires further processing before it can be fed to the neural network
        if self.stateProcessor is not None:
            state, _ = self.stateProcessor(transitions.state, self.device)
            nonFinalNextState, nonFinalMask, finalNextState, finalMask = self.stateProcessor(
                transitions.next_state, self.device, transitions.done)
        else:
            state = torch.tensor(transitions.state,
                                 device=self.device,
                                 dtype=torch.float32)
            nextState = torch.tensor(transitions.next_state,
                                     device=self.device,
                                     dtype=torch.float32)
            # the final mask marks transitions whose stage is done
            finalMask = torch.tensor(transitions.done,
                                     device=self.device,
                                     dtype=torch.bool)
            nonFinalMask = ~finalMask
            finalNextState = [
                nextState[i] for i in range(self.trainBatchSize)
                if finalMask[i]
            ]
            nonFinalNextState = [
                nextState[i] for i in range(self.trainBatchSize)
                if nonFinalMask[i]
            ]

            if len(nonFinalNextState):
                nonFinalNextState = torch.stack(nonFinalNextState)

            if len(finalNextState):
                finalNextState = torch.stack(finalNextState)

        return state, nonFinalMask, nonFinalNextState, finalMask, finalNextState, action, reward

    def update_net_on_memory_given_target(self, targetAgent=None):

        if len(self.memory) < self.trainBatchSize:
            return

        transitions_raw = self.memory.sample(self.trainBatchSize)

        state, nonFinalMask, nonFinalNextState, finalMask, finalNextState, action, reward = self.prepare_minibatch(
            transitions_raw)

        # Q-values of the actions actually taken, shape (batch, 1)
        QValues = self.policyNet(state).gather(1, action)

        # target values are detached below so no gradient flows from them into the policy net parameters
        QNext = torch.zeros(self.trainBatchSize,
                            device=self.device,
                            dtype=torch.float32)

        numNonFinalNextState = nonFinalMask.sum().item()
        numFinalNextState = finalMask.sum().item()

        if numNonFinalNextState:
            # if the stage is not done,
            # bootstrap with our own target net
            QNext[nonFinalMask] = self.targetNet(nonFinalNextState).max(
                1)[0].detach()

        if numFinalNextState:
            # if the stage is done:
            # 1) if it is not the last stage, bootstrap from the external target agent
            # 2) if it is the last stage, do not bootstrap
            if targetAgent is not None:
                QNext[finalMask] = targetAgent.evaluate_state_value(
                    finalNextState)

        targetValues = reward + self.gamma * QNext.unsqueeze(-1)

        # Compute loss
        loss_single = self.netLossFunc(QValues, targetValues)
        loss = torch.mean(loss_single)

        # Optimize the model
        # zero gradient
        self.optimizer.zero_grad()

        loss.backward()
        if self.netGradClip is not None:
            torch.nn.utils.clip_grad_norm_(self.policyNet.parameters(),
                                           self.netGradClip)
        self.optimizer.step()

        if self.globalStepCount % self.lossRecordStep == 0:
            self.losses.append([self.globalStepCount, self.epIdx, loss.item()])

        if self.learnStepCounter % self.targetNetUpdateStep == 0:
            self.targetNet.load_state_dict(self.policyNet.state_dict())

        self.learnStepCounter += 1

    def evaluate_state_value(self, state):
        return self.targetNet(state).max(1)[0].detach()
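
Both excerpts above lean on the same two support types, which are not shown. A plausible minimal reconstruction, consistent with how the snippets use them; the project's real definitions may differ (e.g. extra fields for n-step returns):

import random
from collections import deque, namedtuple

# reconstruction sketch, not the project's actual code:
# a transition tuple extended with a `done` flag
ExtendedTransition = namedtuple(
    'ExtendedTransition', ('state', 'action', 'next_state', 'reward', 'done'))


class ExtendedReplayMemory:
    """Fixed-capacity FIFO buffer of ExtendedTransition tuples."""

    def __init__(self, capacity):
        self.memory = deque(maxlen=capacity)

    def push(self, transition):
        self.memory.append(transition)

    def sample(self, batchSize):
        # uniform sampling without replacement
        return random.sample(self.memory, batchSize)

    def __len__(self):
        return len(self.memory)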
Example 4
    def init_memory(self):

        self.memories = [
            ExtendedReplayMemory(self.memoryCapacity)
            for _ in range(len(self.policyNets))
        ]
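
The enclosing class (one replay memory per policy net) is not shown. A hedged guess at the matching store step; the method name mirrors the examples above, but the per-net index argument `idx` is an assumption:

    def store_experience(self, idx, state, action, nextState, reward, done,
                         info=None):
        # assumption: each transition is routed to the memory of the net
        # (indexed by idx) that produced it
        transition = ExtendedTransition(state, action, nextState, reward, done)
        self.memories[idx].push(transition)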