Example #1
    def __init__(self, config, localNet, env, globalNets, globalOptimizer, netLossFunc, nbAction, rank,
                 globalEpisodeCount, globalEpisodeReward, globalRunningAvgReward, resultQueue, logFolder,
                stateProcessor = None, lock = None):

        self.globalPolicyNet = globalNets[0]
        self.globalTargetNet = globalNets[1]
        self.rank = rank
        self.globalOptimizer = globalOptimizer
        self.localNet = localNet

        mp.Process.__init__(self)
        DQNAgent.__init__(self, config, localNet, None, None, netLossFunc, nbAction, stateProcessor)


        self.totalStep = 0
        self.updateGlobalFrequency = 10
        if 'updateGlobalFrequency' in self.config:
            self.updateGlobalFrequency = self.config['updateGlobalFrequency']


        self.globalEpisodeCount = globalEpisodeCount
        self.globalEpisodeReward = globalEpisodeReward
        self.globalRunningAvgReward = globalRunningAvgReward
        self.resultQueue = resultQueue
        self.dirName = logFolder

        self.randomSeed = 1 + self.rank
        if 'randomSeed' in self.config:
            self.randomSeed = self.config['randomSeed'] + self.rank
        torch.manual_seed(self.randomSeed)

        self.nStepForward = 1
        if 'nStepForward' in self.config:
            self.nStepForward = self.config['nStepForward']
        self.targetNetUpdateEpisode = 10
        if 'targetNetUpdateEpisode' in self.config:
            self.targetNetUpdateEpisode = self.config['targetNetUpdateEpisode']

        self.nStepBuffer = []

        # only use vanilla replay memory
        self.memory = ReplayMemory(self.memoryCapacity)

        self.priorityMemoryOption = False

        # whether to use a synchronization lock
        self.synchLock = False
        if 'synchLock' in self.config:
            self.synchLock = self.config['synchLock']

        self.lock = lock

        self.device = 'cpu'
        if 'device' in self.config and torch.cuda.is_available():
            self.device = self.config['device']
            torch.cuda.manual_seed(self.randomSeed)
            self.localNet = self.localNet.cuda()
Example #2
    def init_memory(self):

        if self.priorityMemoryOption:
            self.memory = PrioritizedReplayMemory(self.memoryCapacity, self.config)
        else:
            # the most commonly used experience replay memory
            if self.memoryOption == 'natural':
                self.memory = ReplayMemory(self.memoryCapacity)
            elif self.memoryOption == 'reward':
                self.memory = ReplayMemoryReward(self.memoryCapacity, self.config['rewardMemoryBackupStep'],
                                                 self.gamma, self.config['rewardMemoryTerminalRatio'] )
Example #3
class DQNAsynERWorker(DQNAgent, mp.Process):
    def __init__(self, config, localNet, env, globalNets, globalOptimizer, netLossFunc, nbAction, rank,
                 globalEpisodeCount, globalEpisodeReward, globalRunningAvgReward, resultQueue, logFolder,
                stateProcessor = None, lock = None):

        self.globalPolicyNet = globalNets[0]
        self.globalTargetNet = globalNets[1]
        self.rank = rank
        self.globalOptimizer = globalOptimizer
        self.localNet = localNet

        mp.Process.__init__(self)
        DQNAgent.__init__(self, config, localNet, None, None, netLossFunc, nbAction, stateProcessor)


        self.totalStep = 0
        self.updateGlobalFrequency = 10
        if 'updateGlobalFrequency' in self.config:
            self.updateGlobalFrequency = self.config['updateGlobalFrequency']


        self.globalEpisodeCount = globalEpisodeCount
        self.globalEpisodeReward = globalEpisodeReward
        self.globalRunningAvgReward = globalRunningAvgReward
        self.resultQueue = resultQueue
        self.dirName = logFolder

        self.randomSeed = 1 + self.rank
        if 'randomSeed' in self.config:
            self.randomSeed = self.config['randomSeed'] + self.rank
        torch.manual_seed(self.randomSeed)

        self.nStepForward = 1
        if 'nStepForward' in self.config:
            self.nStepForward = self.config['nStepForward']
        self.targetNetUpdateEpisode = 10
        if 'targetNetUpdateEpisode' in self.config:
            self.targetNetUpdateEpisode = self.config['targetNetUpdateEpisode']

        self.nStepBuffer = []

        # only use vanilla replay memory
        self.memory = ReplayMemory(self.memoryCapacity)

        self.priorityMemoryOption = False

        # whether to use a synchronization lock
        self.synchLock = False
        if 'synchLock' in self.config:
            self.synchLock = self.config['synchLock']

        self.lock = lock

        self.device = 'cpu'
        if 'device' in self.config and torch.cuda.is_available():
            self.device = self.config['device']
            torch.cuda.manual_seed(self.randomSeed)
            self.localNet = self.localNet.cuda()

    def epsilon_by_episode(self, step):
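        # Exponential epsilon decay from epsilon_start toward epsilon_final with
        # time constant epsilon_decay; step is the shared global episode count.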
        return self.epsilon_final + (
                self.epsilon_start - self.epsilon_final) * math.exp(-1. * step / self.epsilon_decay)


    def run(self):
        torch.set_num_threads(1)
        bufferState, bufferAction, bufferReward, bufferNextState = [], [], [], []
        for self.epIdx in range(self.trainStep):

            print("episode index:" + str(self.epIdx) + " from" + current_process().name + "\n")
            state = self.env.reset()
            done = False
            rewardSum = 0


            # clear the nstep buffer
            self.nStepBuffer.clear()

            for stepCount in range(self.episodeLength):

                epsilon = self.epsilon_by_episode(self.globalEpisodeCount.value)
                action = self.select_action(self.localNet, state, epsilon)
                nextState, reward, done, info = self.env.step(action)

                if stepCount == 0:
                    print("at step 0: from " + current_process().name + "\n")
                    print(info)

                if done:
                    nextState = None

                self.update_net_and_sync(state, action, nextState, reward)

                state = nextState
                rewardSum += reward * pow(self.gamma, stepCount)

                self.totalStep += 1
                if done:
                    break

            self.recordInfo(rewardSum, stepCount)


        self.resultQueue.put(None)

    def recordInfo(self, reward, stepCount):
        with self.globalEpisodeReward.get_lock():
            self.globalEpisodeReward.value = reward
        with self.globalRunningAvgReward.get_lock():
            self.globalRunningAvgReward.value = (self.globalRunningAvgReward.value * self.globalEpisodeCount.value + reward) / (
                        self.globalEpisodeCount.value + 1)
        with self.globalEpisodeCount.get_lock():
            self.globalEpisodeCount.value += 1
            if self.config['logFlag'] and self.globalEpisodeCount.value % self.config['logFrequency'] == 0:
                self.save_checkpoint()
                # sync global target to global policy net
            if self.globalEpisodeCount.value % self.targetNetUpdateEpisode == 0:
                self.globalTargetNet.load_state_dict(self.globalPolicyNet.state_dict())

        # resultQueue.put(globalEpisodeReward.value)
        self.resultQueue.put(
            [self.globalEpisodeCount.value, stepCount, self.globalEpisodeReward.value, self.globalRunningAvgReward.value])
        print(self.name)
        print("Episode: ", self.globalEpisodeCount.value)
        print("stepCount: ", stepCount)
        print("Episode Reward: ", self.globalEpisodeReward.value)
        print("Episode Running Average Reward: ", self.globalRunningAvgReward.value)

    def save_checkpoint(self):
        prefix = self.dirName + 'Epoch' + str(self.globalEpisodeCount.value)
        torch.save({
            'epoch': self.globalEpisodeCount.value + 1,
            'model_state_dict': self.globalPolicyNet.state_dict(),
            'optimizer_state_dict': self.globalOptimizer.state_dict(),
        }, prefix + '_checkpoint.pt')

    def update_net_and_sync(self, state, action, nextState, reward):
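        # Two synchronization modes, selected by self.synchLock:
        #   True:  hold the lock for the whole forward/backward/optimizer step on
        #          the global policy net, then copy it into the local net.
        #   False: compute gradients on the local net without the lock, then
        #          briefly lock to copy the gradients to the global net, step its
        #          optimizer, and refresh the local net.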

        self.store_experience(state, action, nextState, reward)

        if self.priorityMemoryOption:
            if len(self.memory) < self.config['memoryCapacity']:
                return
        else:
            if len(self.memory) < self.trainBatchSize:
                return

        if self.totalStep % self.updateGlobalFrequency == 0:
            transitions_raw = self.memory.sample(self.trainBatchSize)
            transitions = Transition(*zip(*transitions_raw))
            action = torch.tensor(transitions.action, device=self.device, dtype=torch.long).unsqueeze(
                -1)  # shape(batch, 1)
            reward = torch.tensor(transitions.reward, device=self.device, dtype=torch.float32).unsqueeze(
                -1)  # shape(batch, 1)
            batchSize = reward.shape[0]


            # for some env, the output state requires further processing before feeding to neural network
            if self.stateProcessor is not None:
                state, _ = self.stateProcessor(transitions.state, self.device)
                nonFinalNextState, nonFinalMask = self.stateProcessor(transitions.next_state, self.device)
            else:
                state = torch.tensor(transitions.state, device=self.device, dtype=torch.float32)
                nonFinalMask = torch.tensor([s is not None for s in transitions.next_state],
                                            device=self.device, dtype=torch.uint8)
                nonFinalNextState = torch.tensor([s for s in transitions.next_state if s is not None],
                                                 device=self.device, dtype=torch.float32)
            if self.synchLock:

                self.lock.acquire()
                QValues = self.globalPolicyNet(state).gather(1, action)

                if self.netUpdateOption == 'targetNet':
                    # Here we detach because we do not want gradient flow from target values to net parameters
                    QNext = torch.zeros(batchSize, device=self.device, dtype=torch.float32)
                    QNext[nonFinalMask] = self.globalTargetNet(nonFinalNextState).max(1)[0].detach()
                    targetValues = reward + self.gamma * QNext.unsqueeze(-1)
                if self.netUpdateOption == 'policyNet':
                    raise NotImplementedError
                    targetValues = reward + self.gamma * torch.max(self.globalPolicyNet(nextState).detach(), dim=1)[0].unsqueeze(-1)
                if self.netUpdateOption == 'doubleQ':
                    # select optimal action from policy net
                    with torch.no_grad():
                        batchAction = self.globalPolicyNet(nonFinalNextState).max(dim=1)[1].unsqueeze(-1)
                        QNext = torch.zeros(batchSize, device=self.device, dtype=torch.float32).unsqueeze(-1)
                        QNext[nonFinalMask] = self.globalTargetNet(nonFinalNextState).gather(1, batchAction)
                        targetValues = reward + self.gamma * QNext

                loss = self.netLossFunc(QValues, targetValues)

                self.globalOptimizer.zero_grad()

                loss.backward()

                if self.netGradClip is not None:
                    torch.nn.utils.clip_grad_norm_(self.globalPolicyNet.parameters(), self.netGradClip)

                # global net update
                self.globalOptimizer.step()

                # update local net
                self.localNet.load_state_dict(self.globalPolicyNet.state_dict())

                self.lock.release()
            else:

                # update local net
                self.localNet.load_state_dict(self.globalPolicyNet.state_dict())

                QValues = self.localNet(state).gather(1, action)

                if self.netUpdateOption == 'targetNet':
                    # Here we detach because we do not want gradient flow from target values to net parameters
                    QNext = torch.zeros(batchSize, device=self.device, dtype=torch.float32)
                    QNext[nonFinalMask] = self.globalTargetNet(nonFinalNextState).max(1)[0].detach()
                    targetValues = reward + self.gamma * QNext.unsqueeze(-1)
                if self.netUpdateOption == 'policyNet':
                    raise NotImplementedError
                    targetValues = reward + self.gamma * torch.max(self.globalPolicyNet(nextState).detach(), dim=1)[
                        0].unsqueeze(-1)
                if self.netUpdateOption == 'doubleQ':
                    # select optimal action from policy net
                    with torch.no_grad():
                        batchAction = self.localNet(nonFinalNextState).max(dim=1)[1].unsqueeze(-1)
                        QNext = torch.zeros(batchSize, device=self.device, dtype=torch.float32).unsqueeze(-1)
                        QNext[nonFinalMask] = self.globalTargetNet(nonFinalNextState).gather(1, batchAction)
                        targetValues = reward + self.gamma * QNext

                loss = self.netLossFunc(QValues, targetValues)

                loss.backward()

                self.lock.acquire()

                self.globalOptimizer.zero_grad()

                for lp, gp in zip(self.localNet.parameters(), self.globalPolicyNet.parameters()):
                    if self.device == 'cpu':
                        gp._grad = lp._grad
                    else:
                        gp._grad = lp._grad.cpu()

                if self.netGradClip is not None:
                    torch.nn.utils.clip_grad_norm_(self.globalPolicyNet.parameters(), self.netGradClip)

                # global net update
                self.globalOptimizer.step()

                self.lock.release()

                # update local net
                self.localNet.load_state_dict(self.globalPolicyNet.state_dict())

    def test_multiProcess(self):
        print("Hello, World! from " + current_process().name + "\n")
        print(self.globalPolicyNet.state_dict())
        for gp in self.globalPolicyNet.parameters():
            gp.grad = torch.ones_like(gp)
            #gp.grad.fill_(1)

        self.globalOptimizer.step()

        print('globalNetID:')
        print(id(self.globalPolicyNet))
        print('globalOptimizer:')
        print(id(self.globalOptimizer))
        print('localNetID:')
        print(id(self.localNet))
Example #4
    def init_memory(self):
        self.memory = ReplayMemory(self.memoryCapacity)
Example #5
class TDDDPGAgent(DDPGAgent):
    def __init__(self,
                 config,
                 actorNets,
                 criticNets,
                 env,
                 optimizers,
                 netLossFunc,
                 nbAction,
                 stateProcessor=None,
                 experienceProcessor=None):

        super(TDDDPGAgent, self).__init__(config, actorNets, criticNets, env,
                                          optimizers, netLossFunc, nbAction,
                                          stateProcessor, experienceProcessor)

    def initalizeNets(self, actorNets, criticNets, optimizers):
        self.actorNet = actorNets['actor']
        self.actorNet_target = actorNets[
            'target'] if 'target' in actorNets else None
        self.criticNetOne = criticNets['criticOne']
        self.criticNet_targetOne = criticNets[
            'targetOne'] if 'targetOne' in criticNets else None
        self.criticNetTwo = criticNets['criticTwo']
        self.criticNet_targetTwo = criticNets[
            'targetTwo'] if 'targetTwo' in criticNets else None

        self.actor_optimizer = optimizers['actor']
        self.criticOne_optimizer = optimizers['criticOne']
        self.criticTwo_optimizer = optimizers['criticTwo']

        self.net_to_device()

    def init_memory(self):
        self.memory = ReplayMemory(self.memoryCapacity)

    def read_config(self):
        super(TDDDPGAgent, self).read_config()

        self.policyUpdateFreq = 2
        if 'policyUpdateFreq' in self.config:
            self.policyUpdateFreq = self.config['policyUpdateFreq']
        self.policySmoothNoise = 0.01
        if 'policySmoothNoise' in self.config:
            self.policySmoothNoise = self.config['policySmoothNoise']

    def net_to_device(self):
        # move model to correct device
        self.actorNet = self.actorNet.to(self.device)
        self.criticNetOne = self.criticNetOne.to(self.device)
        self.criticNetTwo = self.criticNetTwo.to(self.device)

        # in case targetNet is None
        if self.actorNet_target is not None:
            self.actorNet_target = self.actorNet_target.to(self.device)
        # in case targetNet is None
        if self.criticNet_targetOne is not None:
            self.criticNet_targetOne = self.criticNet_targetOne.to(self.device)
        if self.criticNet_targetTwo is not None:
            self.criticNet_targetTwo = self.criticNet_targetTwo.to(self.device)

    def prepare_minibatch(self, state, action, nextState, reward, info):
        # first store memory

        self.store_experience(state, action, nextState, reward, info)
        if len(self.memory) < self.trainBatchSize:
            return
        transitions_raw = self.memory.sample(self.trainBatchSize)
        transitions = Transition(*zip(*transitions_raw))
        action = torch.tensor(transitions.action,
                              device=self.device,
                              dtype=torch.float32)  # shape(batch, numActions)
        reward = torch.tensor(transitions.reward,
                              device=self.device,
                              dtype=torch.float32)  # shape(batch)

        # for some env, the output state requires further processing before feeding to neural network
        if self.stateProcessor is not None:
            state, _ = self.stateProcessor(transitions.state, self.device)
            nonFinalNextState, nonFinalMask = self.stateProcessor(
                transitions.next_state, self.device)
        else:
            state = torch.tensor(transitions.state,
                                 device=self.device,
                                 dtype=torch.float32)
            nonFinalMask = torch.tensor(tuple(
                map(lambda s: s is not None, transitions.next_state)),
                                        device=self.device,
                                        dtype=torch.uint8)
            nonFinalNextState = torch.tensor(
                [s for s in transitions.next_state if s is not None],
                device=self.device,
                dtype=torch.float32)

        return state, nonFinalMask, nonFinalNextState, action, reward

    def update_net(self, state, action, nextState, reward, info):

        # state, nonFinalMask, nonFinalNextState, action, reward = self.prepare_minibatch(state, action, nextState, reward, info)
        self.store_experience(state, action, nextState, reward, info)
        if len(self.memory) < self.trainBatchSize:
            return
        transitions_raw = self.memory.sample(self.trainBatchSize)
        transitions = Transition(*zip(*transitions_raw))
        action = torch.tensor(transitions.action,
                              device=self.device,
                              dtype=torch.float32)  # shape(batch, numActions)
        reward = torch.tensor(transitions.reward,
                              device=self.device,
                              dtype=torch.float32)  # shape(batch)

        # for some env, the output state requires further processing before feeding to neural network
        if self.stateProcessor is not None:
            state, _ = self.stateProcessor(transitions.state, self.device)
            nonFinalNextState, nonFinalMask = self.stateProcessor(
                transitions.next_state, self.device)
        else:
            state = torch.tensor(transitions.state,
                                 device=self.device,
                                 dtype=torch.float32)
            nonFinalMask = torch.tensor(tuple(
                map(lambda s: s is not None, transitions.next_state)),
                                        device=self.device,
                                        dtype=torch.uint8)
            nonFinalNextState = torch.tensor(
                [s for s in transitions.next_state if s is not None],
                device=self.device,
                dtype=torch.float32)

        batchSize = reward.shape[0]
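
        # TD3-style clipped double-Q target computed below:
        #   a' = actorNet_target(s') + policySmoothNoise * eps,  eps ~ N(0, I)
        #   y  = r + gamma * min(Q1_target(s', a'), Q2_target(s', a'))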

        # Critic loss
        QValuesOne = self.criticNetOne.forward(state, action).squeeze()
        QValuesTwo = self.criticNetTwo.forward(state, action).squeeze()

        actionNoise = torch.randn((nonFinalNextState.shape[0], self.numAction),
                                  dtype=torch.float32,
                                  device=self.device)
        next_actions = self.actorNet_target.forward(
            nonFinalNextState) + actionNoise * self.policySmoothNoise

        # next_actions = self.actorNet_target.forward(nonFinalNextState)

        QNext = torch.zeros(batchSize, device=self.device, dtype=torch.float32)
        QNextCriticOne = self.criticNet_targetOne.forward(
            nonFinalNextState, next_actions.detach()).squeeze()
        QNextCriticTwo = self.criticNet_targetTwo.forward(
            nonFinalNextState, next_actions.detach()).squeeze()

        QNext[nonFinalMask] = torch.min(QNextCriticOne, QNextCriticTwo)

        targetValues = reward + self.gamma * QNext

        criticOne_loss = self.netLossFunc(QValuesOne, targetValues)
        criticTwo_loss = self.netLossFunc(QValuesTwo, targetValues)

        self.criticOne_optimizer.zero_grad()
        self.criticTwo_optimizer.zero_grad()

        # https://jdhao.github.io/2017/11/12/pytorch-computation-graph/
        criticOne_loss.backward(retain_graph=True)
        criticTwo_loss.backward()

        if self.netGradClip is not None:
            torch.nn.utils.clip_grad_norm_(self.criticNetOne.parameters(),
                                           self.netGradClip)
            torch.nn.utils.clip_grad_norm_(self.criticNetTwo.parameters(),
                                           self.netGradClip)

        self.criticOne_optimizer.step()
        self.criticTwo_optimizer.step()

        if self.learnStepCounter % self.policyUpdateFreq == 0:
            # Actor loss
            # we try to maximize the criticNet output (the state-action value Q(s, a))
            policy_loss = -self.criticNetOne.forward(
                state, self.actorNet.forward(state)).mean()

            # update networks
            self.actor_optimizer.zero_grad()
            policy_loss.backward()
            if self.netGradClip is not None:
                torch.nn.utils.clip_grad_norm_(self.actorNet.parameters(),
                                               self.netGradClip)

            self.actor_optimizer.step()

            if self.globalStepCount % self.lossRecordStep == 0:
                self.losses.append([
                    self.globalStepCount, self.epIdx,
                    criticOne_loss.item(),
                    criticTwo_loss.item(),
                    policy_loss.item()
                ])

            # update target networks
            for target_param, param in zip(
                    self.actorNet_target.parameters(),
                    self.actorNet.parameters()):
                target_param.data.copy_(param.data * self.tau +
                                        target_param.data * (1.0 - self.tau))

            for target_param, param in zip(
                    self.criticNet_targetOne.parameters(),
                    self.criticNetOne.parameters()):
                target_param.data.copy_(param.data * self.tau +
                                        target_param.data * (1.0 - self.tau))

            for target_param, param in zip(
                    self.criticNet_targetTwo.parameters(),
                    self.criticNetTwo.parameters()):
                target_param.data.copy_(param.data * self.tau +
                                        target_param.data * (1.0 - self.tau))

        self.learnStepCounter += 1

    def save_all(self):
        prefix = self.dirName + self.identifier + 'Finalepoch' + str(
            self.epIdx)
        self.saveLosses(prefix + '_loss.txt')
        self.saveRewards(prefix + '_reward.txt')
        with open(prefix + '_memory.pickle', 'wb') as file:
            pickle.dump(self.memory, file)

        torch.save(
            {
                'epoch':
                self.epIdx,
                'globalStep':
                self.globalStepCount,
                'actorNet_state_dict':
                self.actorNet.state_dict(),
                'criticNetOne_state_dict':
                self.criticNetOne.state_dict(),
                'criticNetTwo_state_dict':
                self.criticNetTwo.state_dict(),
                'actor_optimizer_state_dict':
                self.actor_optimizer.state_dict(),
                'criticOne_optimizer_state_dict':
                self.criticOne_optimizer.state_dict(),
                'criticTwo_optimizer_state_dict':
                self.criticTwo_optimizer.state_dict()
            }, prefix + '_checkpoint.pt')

    def save_checkpoint(self):
        prefix = self.dirName + self.identifier + 'Epoch' + str(self.epIdx)
        self.saveLosses(prefix + '_loss.txt')
        self.saveRewards(prefix + '_reward.txt')
        with open(prefix + '_memory.pickle', 'wb') as file:
            pickle.dump(self.memory, file)

        torch.save(
            {
                'epoch':
                self.epIdx,
                'globalStep':
                self.globalStepCount,
                'actorNet_state_dict':
                self.actorNet.state_dict(),
                'criticNetOne_state_dict':
                self.criticNetOne.state_dict(),
                'criticNetTwo_state_dict':
                self.criticNetTwo.state_dict(),
                'actor_optimizer_state_dict':
                self.actor_optimizer.state_dict(),
                'criticOne_optimizer_state_dict':
                self.criticOne_optimizer.state_dict(),
                'criticTwo_optimizer_state_dict':
                self.criticTwo_optimizer.state_dict()
            }, prefix + '_checkpoint.pt')

    def load_checkpoint(self, prefix):
        self.loadLosses(prefix + '_loss.txt')
        self.loadRewards(prefix + '_reward.txt')
        with open(prefix + '_memory.pickle', 'rb') as file:
            self.memory = pickle.load(file)

        checkpoint = torch.load(prefix + '_checkpoint.pt')
        self.epIdx = checkpoint['epoch']
        self.globalStepCount = checkpoint['globalStep']
        self.actorNet.load_state_dict(checkpoint['actorNet_state_dict'])
        self.actorNet_target.load_state_dict(checkpoint['actorNet_state_dict'])
        self.criticNetOne.load_state_dict(
            checkpoint['criticNetOne_state_dict'])
        self.criticNet_targetOne.load_state_dict(
            checkpoint['criticNetOne_state_dict'])
        self.criticNetTwo.load_state_dict(
            checkpoint['criticNetTwo_state_dict'])
        self.criticNet_targetTwo.load_state_dict(
            checkpoint['criticNetTwo_state_dict'])

        self.actor_optimizer.load_state_dict(
            checkpoint['actor_optimizer_state_dict'])
        self.criticOne_optimizer.load_state_dict(
            checkpoint['criticOne_optimizer_state_dict'])
        self.criticTwo_optimizer.load_state_dict(
            checkpoint['criticTwo_optimizer_state_dict'])
Example #6
    def init_memory(self):
        self.memories = [ReplayMemory(self.memoryCapacity) for _ in range(self.episodeLength)]
Example #7
from Agents.Core.ReplayMemory import ReplayMemory, Transition
#from ..Agents.Core.ReplayMemory import ReplayMemory, Transition
import torch
import numpy as np
import pickle

state1 = np.random.rand(5, 5)
state2 = np.random.rand(5, 5)
state3 = np.random.rand(5, 5)
state4 = np.random.rand(5, 5)

tran1 = Transition(state1, 1, state2, 1)
tran2 = Transition(state3, 2, state4, 2)
memory = ReplayMemory(10)
memory.push(tran1)
memory.push(tran2)
print(memory)

file = open('memory.pickle', 'wb')
pickle.dump(memory, file)
file.close()

with open('memory.pickle', 'rb') as file:
    memory2 = pickle.load(file)

print(memory2)
Example #8
from Agents.Core.ReplayMemory import ReplayMemory, Transition
#from ..Agents.Core.ReplayMemory import ReplayMemory, Transition
import torch

tran1 = Transition(1, 1, 1, 1)
tran2 = Transition(2, 2, 2, 2)
memory = ReplayMemory(10)
memory.push(tran1)
memory.push(tran2)
memory.push(3, 3, 3, 3)
print(memory)

memory.write_to_text('memoryOut.txt')

toTensor = memory.totensor()
toTensor2 = torch.tensor(memory.sample(2))
for i in range(5, 50):
    tran = Transition(i, i, i, i)
    memory.push(tran)

print(memory)
memory.clear()
print(memory)
print(toTensor)
print(toTensor2)
Example #9
                 env,
                 optimizer,
                 torch.nn.MSELoss(reduction='none'),
                 N_A,
                 config=config)

xSet = np.linspace(-1, 1, 100)
policy = np.zeros_like(xSet)
for i, x in enumerate(xSet):
    policy[i] = agent.getPolicy(np.array([x]))

np.savetxt('StabilizerPolicyBeforeTrain.txt', policy, fmt='%d')

#agent.perform_random_exploration(10)
agent.train()
storeMemory = ReplayMemory(100000)
agent.testPolicyNet(100, storeMemory)
storeMemory.write_to_text('testPolicyMemory.txt')


def customPolicy(state):
    x = state[0]
    # move towards negative
    if x > 0.1:
        action = 2
    # move towards positive
    elif x < -0.1:
        action = 1
    # do not move
    else:
        action = 0
    return action
Example #10
class DDPGAgent:
    """class for DDPG agents.
        This class contains an implementation of DDPG learning, with enhancements such as experience augmentation and hindsight experience replay.
        # Arguments
            config: a dictionary for training parameters
            actorNets: actor net and its target net
            criticNets: critic net and its target net
            env: environment for the agent to interact. env should implement same interface of a gym env
            optimizers: network optimizers for both the actor net and the critic net
            netLossFunc: loss function of the network, e.g., mse
            nbAction: number of actions
            stateProcessor: a function to process the output from env; the processed state is used as input to the networks
            experienceProcessor: additional steps to process an experience
        """
    def __init__(self,
                 config,
                 actorNets,
                 criticNets,
                 env,
                 optimizers,
                 netLossFunc,
                 nbAction,
                 stateProcessor=None,
                 experienceProcessor=None):

        self.config = config
        self.read_config()
        self.env = env
        self.numAction = nbAction
        self.stateProcessor = stateProcessor
        self.netLossFunc = netLossFunc
        self.experienceProcessor = experienceProcessor

        self.initialization()
        self.init_memory()
        self.initalizeNets(actorNets, criticNets, optimizers)

    def initalizeNets(self, actorNets, criticNets, optimizers):
        '''
        initialize networks and their optimizers; move them to specified device (i.e., cpu or cuda)
        '''
        self.actorNet = actorNets['actor']
        self.actorNet_target = actorNets[
            'target'] if 'target' in actorNets else None
        self.criticNet = criticNets['critic']
        self.criticNet_target = criticNets[
            'target'] if 'target' in criticNets else None
        self.actor_optimizer = optimizers['actor']
        self.critic_optimizer = optimizers['critic']

        self.net_to_device()

    def init_memory(self):
        '''
        initialize replay memory
        '''
        self.memory = ReplayMemory(self.memoryCapacity)

    def read_config(self):
        '''
        read parameters from self.config object
        initialize various flags and parameters
        trainStep: number of episodes to train
        targetNetUpdateStep: frequency in terms of training steps/episodes to reset target net
        trainBatchSize: mini-batch size for gradient descent.
        gamma: discount factor
        tau: soft update parameter
        memoryCapacity: memory capacity for experience storage
        netGradClip: gradient clipping parameter
        netUpdateOption: allowed strings are targetNet, policyNet, doubleQ
        verbose: bool, default false.
        nStepForward: multiple-step forward Q learning, default 1
        lossRecordStep: frequency to store loss.
        episodeLength: maximum steps in an episode
        netUpdateFrequency: frequency to perform gradient descent
        netUpdateStep: number of steps for gradient descent
        device: cpu or cuda
        randomSeed
        hindSightER: bool variable for hindsight experience replay
        hindSightERFreq: frequency to perform hindsight experience replay
        experienceAugmentation: additional experience augmentation function
        return: None
        '''
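        # Minimal config sketch; the values are illustrative assumptions, not
        # defaults of this module:
        #   config = {'trainStep': 1000, 'trainBatchSize': 64, 'gamma': 0.99,
        #             'tau': 0.005, 'memoryCapacity': 100000,
        #             'logFlag': False, 'logFrequency': 100}
        # The remaining keys listed above are optional and fall back to the
        # defaults set below.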
        self.trainStep = self.config['trainStep']
        self.targetNetUpdateStep = 10000
        if 'targetNetUpdateStep' in self.config:
            self.targetNetUpdateStep = self.config['targetNetUpdateStep']

        self.trainBatchSize = self.config['trainBatchSize']
        self.gamma = self.config['gamma']
        self.tau = self.config['tau']

        self.netGradClip = None
        if 'netGradClip' in self.config:
            self.netGradClip = self.config['netGradClip']
        self.netUpdateOption = 'targetNet'
        if 'netUpdateOption' in self.config:
            self.netUpdateOption = self.config['netUpdateOption']
        self.verbose = False
        if 'verbose' in self.config:
            self.verbose = self.config['verbose']
        self.netUpdateFrequency = 1
        if 'netUpdateFrequency' in self.config:
            self.netUpdateFrequency = self.config['netUpdateFrequency']
        self.nStepForward = 1
        if 'nStepForward' in self.config:
            self.nStepForward = self.config['nStepForward']
        self.lossRecordStep = 10
        if 'lossRecordStep' in self.config:
            self.lossRecordStep = self.config['lossRecordStep']
        self.episodeLength = 500
        if 'episodeLength' in self.config:
            self.episodeLength = self.config['episodeLength']

        self.verbose = False
        if 'verbose' in self.config:
            self.verbose = self.config['verbose']

        self.device = 'cpu'
        if 'device' in self.config and torch.cuda.is_available():
            self.device = self.config['device']

        self.randomSeed = 1
        if 'randomSeed' in self.config:
            self.randomSeed = self.config['randomSeed']

        self.memoryCapacity = self.config['memoryCapacity']

        self.hindSightER = False
        if 'hindSightER' in self.config:
            self.hindSightER = self.config['hindSightER']
            self.hindSightERFreq = self.config['hindSightERFreq']

        self.experienceAugmentation = False
        if 'experienceAugmentation' in self.config:
            self.experienceAugmentation = self.config['experienceAugmentation']
            self.experienceAugmentationFreq = self.config[
                'experienceAugmentationFreq']

        self.policyUpdateFreq = 1
        if 'policyUpdateFreq' in self.config:
            self.policyUpdateFreq = self.config['policyUpdateFreq']

    def net_to_device(self):
        '''
         move model to the specified devices
        '''

        self.actorNet = self.actorNet.to(self.device)
        self.criticNet = self.criticNet.to(self.device)

        # in case targetNet is None
        if self.actorNet_target is not None:
            self.actorNet_target = self.actorNet_target.to(self.device)
        # in case targetNet is None
        if self.criticNet_target is not None:
            self.criticNet_target = self.criticNet_target.to(self.device)

    def initialization(self):

        self.dirName = 'Log/'
        if 'dataLogFolder' in self.config:
            self.dirName = self.config['dataLogFolder']
        if not os.path.exists(self.dirName):
            os.makedirs(self.dirName)

        self.identifier = ''
        self.epIdx = 0
        self.learnStepCounter = 0  # for target net update
        self.globalStepCount = 0
        self.losses = []
        self.rewards = []

        self.runningAvgEpisodeReward = 0.0

    def select_action(self, net=None, state=None, noiseFlag=False):
        '''
        select an action from net. Action selection is delegated to the network, which must implement a 'select_action' method.
        # Arguments
        net: the network used for action selection; defaults to actorNet
        state: observation or state used as the input to the net
        noiseFlag: if False, perform greedy selection; if True, add exploration noise (e.g., from an OU process)
        return: numpy array of actions
        '''
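        # Example call (hedged; obs is assumed to be a gym-style observation array):
        #   action = agent.select_action(state=obs, noiseFlag=True)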

        if net is None:
            net = self.actorNet

        with torch.no_grad():
            # self.policyNet(torch.from_numpy(state.astype(np.float32)).unsqueeze(0))
            # here state[np.newaxis,:] is to add a batch dimension
            if self.stateProcessor is not None:
                state, _ = self.stateProcessor([state], self.device)
                action = net.select_action(state, noiseFlag)
            else:
                stateTorch = torch.from_numpy(
                    np.array(state[np.newaxis, :], dtype=np.float32))
                action = net.select_action(stateTorch.to(self.device),
                                           noiseFlag)

        return action.cpu().data.numpy()[0]

    def process_hindSightExperience(self, state, action, nextState, reward,
                                    info):
        if nextState is not None and self.globalStepCount % self.hindSightERFreq == 0:
            stateNew, actionNew, nextStateNew, rewardNew = self.env.getHindSightExperience(
                state, action, nextState, info)
            if stateNew is not None:
                transition = Transition(stateNew, actionNew, nextStateNew,
                                        rewardNew)
                self.memory.push(transition)
                if self.experienceAugmentation:
                    self.process_experienceAugmentation(
                        state, action, nextState, reward, info)

    def process_experienceAugmentation(self, state, action, nextState, reward,
                                       info):
        if self.globalStepCount % self.experienceAugmentationFreq == 0:
            state_Augs, action_Augs, nextState_Augs, reward_Augs = self.env.getExperienceAugmentation(
                state, action, nextState, reward, info)
            for i in range(len(state_Augs)):
                transition = Transition(state_Augs[i], action_Augs[i],
                                        nextState_Augs[i], reward_Augs[i])
                self.memory.push(transition)

    def store_experience(self, state, action, nextState, reward, info):
        if self.experienceProcessor is not None:
            state, action, nextState, reward = self.experienceProcessor(
                state, action, nextState, reward, info)

        transition = Transition(state, action, nextState, reward)
        self.memory.push(transition)

        if self.experienceAugmentation:
            self.process_experienceAugmentation(state, action, nextState,
                                                reward, info)

        if self.hindSightER:
            self.process_hindSightExperience(state, action, nextState, reward,
                                             info)

    def prepare_minibatch(self, transitions_raw):
        '''
        do some preprocessing work on transitions_raw
        order the data
        convert transition list to torch tensors
        use trick from https://pytorch.org/tutorials/intermediate/reinforcement_q_learning.html
        https://stackoverflow.com/questions/19339/transpose-unzip-function-inverse-of-zip/19343#19343
        '''
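        # The zip(*...) idiom below transposes a list of Transition tuples into
        # a Transition of tuples, e.g.
        #   [(s1, a1, ns1, r1), (s2, a2, ns2, r2)]
        #   -> Transition(state=(s1, s2), action=(a1, a2),
        #                 next_state=(ns1, ns2), reward=(r1, r2))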

        transitions = Transition(*zip(*transitions_raw))
        action = torch.tensor(transitions.action,
                              device=self.device,
                              dtype=torch.float32)  # shape(batch, numActions)
        reward = torch.tensor(transitions.reward,
                              device=self.device,
                              dtype=torch.float32)  # shape(batch)

        # for some env, the output state requires further processing before feeding to neural network
        if self.stateProcessor is not None:
            state, _ = self.stateProcessor(transitions.state, self.device)
            nonFinalNextState, nonFinalMask = self.stateProcessor(
                transitions.next_state, self.device)
        else:
            state = torch.tensor(transitions.state,
                                 device=self.device,
                                 dtype=torch.float32)
            nonFinalMask = torch.tensor(tuple(
                map(lambda s: s is not None, transitions.next_state)),
                                        device=self.device,
                                        dtype=torch.bool)
            nonFinalNextState = torch.tensor(
                [s for s in transitions.next_state if s is not None],
                device=self.device,
                dtype=torch.float32)

        return state, nonFinalMask, nonFinalNextState, action, reward

    def update_net(self, state, action, nextState, reward, info):
        '''
        This routine stores, transforms, and augments experiences, then samples a mini-batch of experiences for gradient descent.
        '''
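        # Called once per environment step from train_one_episode with the
        # freshly observed (state, action, nextState, reward, info) tuple.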

        self.store_experience(state, action, nextState, reward, info)

        # prepare mini-batch
        if len(self.memory) < self.trainBatchSize:
            return

        transitions_raw = self.memory.sample(self.trainBatchSize)

        self.update_net_on_transitions(transitions_raw)

        self.copy_nets()

        self.learnStepCounter += 1

    def copy_nets(self):
        '''
        soft update target networks
        '''
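        # soft update rule: target <- tau * param + (1 - tau) * target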
        # update networks
        if self.learnStepCounter % self.policyUpdateFreq == 0:
            # update target networks
            for target_param, param in zip(self.actorNet_target.parameters(),
                                           self.actorNet.parameters()):
                target_param.data.copy_(param.data * self.tau +
                                        target_param.data * (1.0 - self.tau))

            for target_param, param in zip(self.criticNet_target.parameters(),
                                           self.criticNet.parameters()):
                target_param.data.copy_(param.data * self.tau +
                                        target_param.data * (1.0 - self.tau))

    def update_net_on_transitions(self, transitions_raw):
        '''
        This function performs gradient descent on the network
        '''
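        # Targets computed below:
        #   critic: y = r + gamma * Q_target(s', mu_target(s')), with Q = 0 for terminal s'
        #   actor:  maximize Q(s, mu(s)), implemented as minimizing -Q(s, mu(s))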
        state, nonFinalMask, nonFinalNextState, action, reward = self.prepare_minibatch(
            transitions_raw)

        # Critic loss
        QValues = self.criticNet.forward(state, action).squeeze()
        QNext = torch.zeros(self.trainBatchSize,
                            device=self.device,
                            dtype=torch.float32)

        if len(nonFinalNextState):
            # next action is calculated using target actor network
            next_actions = self.actorNet_target.forward(nonFinalNextState)
            QNext[nonFinalMask] = self.criticNet_target.forward(
                nonFinalNextState, next_actions.detach()).squeeze()

        targetValues = reward + self.gamma * QNext
        critic_loss = self.netLossFunc(QValues, targetValues)

        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        if self.netGradClip is not None:
            torch.nn.utils.clip_grad_norm_(self.criticNet.parameters(),
                                           self.netGradClip)

        self.critic_optimizer.step()

        # update networks
        if self.learnStepCounter % self.policyUpdateFreq == 0:

            # Actor loss
            # we try to maximize the criticNet output (the state-action value Q(s, a))
            policy_loss = -self.criticNet.forward(
                state, self.actorNet.forward(state)).mean()

            self.actor_optimizer.zero_grad()
            policy_loss.backward()
            if self.netGradClip is not None:
                torch.nn.utils.clip_grad_norm_(self.actorNet.parameters(),
                                               self.netGradClip)

            self.actor_optimizer.step()

            if self.globalStepCount % self.lossRecordStep == 0:
                self.losses.append([
                    self.globalStepCount, self.epIdx,
                    critic_loss.item(),
                    policy_loss.item()
                ])

    def work_before_step(self, state=None):
        pass

    def train_one_episode(self):

        print("episode index:" + str(self.epIdx))
        state = self.env.reset()
        done = False
        rewardSum = 0

        for stepCount in range(self.episodeLength):

            # any work to be done before select actions
            self.work_before_step(state)

            action = self.select_action(self.actorNet, state, noiseFlag=True)

            nextState, reward, done, info = self.env.step(action)

            if stepCount == 0:
                print("at step 0:")
                print(info)

            if done:
                nextState = None

            # learn the transition
            self.update_net(state, action, nextState, reward, info)

            state = nextState
            rewardSum += reward * pow(self.gamma, stepCount)
            self.globalStepCount += 1

            if self.verbose:
                print('action: ' + str(action))
                print('state:')
                print(nextState)
                print('reward:')
                print(reward)
                print('info')
                print(info)

            if done:
                break

        self.runningAvgEpisodeReward = (
            self.runningAvgEpisodeReward * self.epIdx +
            rewardSum) / (self.epIdx + 1)
        print("done in step count: {}".format(stepCount))
        print("reward sum = " + str(rewardSum))
        print("running average episode reward sum: {}".format(
            self.runningAvgEpisodeReward))
        print(info)

        self.rewards.append([
            self.epIdx, stepCount, self.globalStepCount, rewardSum,
            self.runningAvgEpisodeReward
        ])
        if self.config[
                'logFlag'] and self.epIdx % self.config['logFrequency'] == 0:
            self.save_checkpoint()

        self.epIdx += 1

    def train(self):

        # continue on historical training
        if len(self.rewards) > 0:
            self.runningAvgEpisodeReward = self.rewards[-1][-1]

        for trainStepCount in range(self.trainStep):
            self.train_one_episode()

        self.save_all()

    def saveLosses(self, fileName):
        np.savetxt(fileName, np.array(self.losses), fmt='%.5f', delimiter='\t')

    def saveRewards(self, fileName):
        np.savetxt(fileName,
                   np.array(self.rewards),
                   fmt='%.5f',
                   delimiter='\t')

    def loadLosses(self, fileName):
        self.losses = np.genfromtxt(fileName).tolist()

    def loadRewards(self, fileName):
        self.rewards = np.genfromtxt(fileName).tolist()

    def save_all(self, identifier=None):
        if identifier is None:
            identifier = self.identifier
        prefix = self.dirName + identifier + 'Finalepoch' + str(self.epIdx)
        self.saveLosses(prefix + '_loss.txt')
        self.saveRewards(prefix + '_reward.txt')
        with open(prefix + '_memory.pickle', 'wb') as file:
            pickle.dump(self.memory, file)

        torch.save(
            {
                'epoch': self.epIdx,
                'globalStep': self.globalStepCount,
                'actorNet_state_dict': self.actorNet.state_dict(),
                'criticNet_state_dict': self.criticNet.state_dict(),
                'actor_optimizer_state_dict':
                self.actor_optimizer.state_dict(),
                'critic_optimizer_state_dict':
                self.critic_optimizer.state_dict()
            }, prefix + '_checkpoint.pt')

    def save_checkpoint(self, identifier=None):
        if identifier is None:
            identifier = self.identifier

        prefix = self.dirName + identifier + 'Epoch' + str(self.epIdx)
        self.saveLosses(prefix + '_loss.txt')
        self.saveRewards(prefix + '_reward.txt')
        with open(prefix + '_memory.pickle', 'wb') as file:
            pickle.dump(self.memory, file)

        torch.save(
            {
                'epoch': self.epIdx,
                'globalStep': self.globalStepCount,
                'actorNet_state_dict': self.actorNet.state_dict(),
                'criticNet_state_dict': self.criticNet.state_dict(),
                'actor_optimizer_state_dict':
                self.actor_optimizer.state_dict(),
                'critic_optimizer_state_dict':
                self.critic_optimizer.state_dict()
            }, prefix + '_checkpoint.pt')

    def load_checkpoint(self, prefix):
        self.loadLosses(prefix + '_loss.txt')
        self.loadRewards(prefix + '_reward.txt')
        with open(prefix + '_memory.pickle', 'rb') as file:
            self.memory = pickle.load(file)

        checkpoint = torch.load(prefix + '_checkpoint.pt')
        self.epIdx = checkpoint['epoch']
        self.globalStepCount = checkpoint['globalStep']
        self.actorNet.load_state_dict(checkpoint['actorNet_state_dict'])
        self.actorNet_target.load_state_dict(checkpoint['actorNet_state_dict'])
        self.criticNet.load_state_dict(checkpoint['criticNet_state_dict'])
        self.criticNet_target.load_state_dict(
            checkpoint['criticNet_state_dict'])

        self.actor_optimizer.load_state_dict(
            checkpoint['actor_optimizer_state_dict'])
        self.critic_optimizer.load_state_dict(
            checkpoint['critic_optimizer_state_dict'])
Example #11
    def init_memory(self):
        '''
        initialize replay memory
        '''
        self.memory = ReplayMemory(self.memoryCapacity)
Example #12
class DDPGAgent:
    def __init__(self,
                 config,
                 actorNets,
                 criticNets,
                 env,
                 optimizers,
                 netLossFunc,
                 nbAction,
                 stateProcessor=None,
                 experienceProcessor=None):

        self.config = config
        self.read_config()
        self.env = env
        self.numAction = nbAction
        self.stateProcessor = stateProcessor
        self.netLossFunc = netLossFunc
        self.experienceProcessor = experienceProcessor

        self.initialization()
        self.init_memory()
        self.initalizeNets(actorNets, criticNets, optimizers)

    def initalizeNets(self, actorNets, criticNets, optimizers):
        self.actorNet = actorNets['actor']
        self.actorNet_target = actorNets[
            'target'] if 'target' in actorNets else None
        self.criticNet = criticNets['critic']
        self.criticNet_target = criticNets[
            'target'] if 'target' in criticNets else None
        self.actor_optimizer = optimizers['actor']
        self.critic_optimizer = optimizers['critic']

        self.net_to_device()

    def init_memory(self):
        self.memory = ReplayMemory(self.memoryCapacity)

    def read_config(self):
        self.trainStep = self.config['trainStep']
        self.targetNetUpdateStep = 10000
        if 'targetNetUpdateStep' in self.config:
            self.targetNetUpdateStep = self.config['targetNetUpdateStep']

        self.trainBatchSize = self.config['trainBatchSize']
        self.gamma = self.config['gamma']
        self.tau = self.config['tau']

        self.netGradClip = None
        if 'netGradClip' in self.config:
            self.netGradClip = self.config['netGradClip']
        self.netUpdateOption = 'targetNet'
        if 'netUpdateOption' in self.config:
            self.netUpdateOption = self.config['netUpdateOption']
        self.verbose = False
        if 'verbose' in self.config:
            self.verbose = self.config['verbose']
        self.netUpdateFrequency = 1
        if 'netUpdateFrequency' in self.config:
            self.netUpdateFrequency = self.config['netUpdateFrequency']
        self.nStepForward = 1
        if 'nStepForward' in self.config:
            self.nStepForward = self.config['nStepForward']
        self.lossRecordStep = 10
        if 'lossRecordStep' in self.config:
            self.lossRecordStep = self.config['lossRecordStep']
        self.episodeLength = 500
        if 'episodeLength' in self.config:
            self.episodeLength = self.config['episodeLength']

        self.verbose = False
        if 'verbose' in self.config:
            self.verbose = self.config['verbose']

        self.device = 'cpu'
        if 'device' in self.config and torch.cuda.is_available():
            self.device = self.config['device']

        self.randomSeed = 1
        if 'randomSeed' in self.config:
            self.randomSeed = self.config['randomSeed']

        self.memoryCapacity = self.config['memoryCapacity']

        self.hindSightER = False
        if 'hindSightER' in self.config:
            self.hindSightER = self.config['hindSightER']
            self.hindSightERFreq = self.config['hindSightERFreq']

        self.experienceAugmentation = False
        if 'experienceAugmentation' in self.config:
            self.experienceAugmentation = self.config['experienceAugmentation']
            self.experienceAugmentationFreq = self.config[
                'experienceAugmentationFreq']

        self.policyUpdateFreq = 1
        if 'policyUpdateFreq' in self.config:
            self.policyUpdateFreq = self.config['policyUpdateFreq']

    def net_to_device(self):
        # move model to correct device
        self.actorNet = self.actorNet.to(self.device)
        self.criticNet = self.criticNet.to(self.device)

        # in case targetNet is None
        if self.actorNet_target is not None:
            self.actorNet_target = self.actorNet_target.to(self.device)
        # in case targetNet is None
        if self.criticNet_target is not None:
            self.criticNet_target = self.criticNet_target.to(self.device)

    def initialization(self):

        self.dirName = 'Log/'
        if 'dataLogFolder' in self.config:
            self.dirName = self.config['dataLogFolder']
        if not os.path.exists(self.dirName):
            os.makedirs(self.dirName)

        self.identifier = ''
        self.epIdx = 0
        self.learnStepCounter = 0  # for target net update
        self.globalStepCount = 0
        self.losses = []
        self.rewards = []

    def select_action(self, net, state, noiseFlag=False):

        with torch.no_grad():
            # self.policyNet(torch.from_numpy(state.astype(np.float32)).unsqueeze(0))
            # here state[np.newaxis,:] is to add a batch dimension
            if self.stateProcessor is not None:
                state, _ = self.stateProcessor([state], self.device)
                action = net.select_action(state, noiseFlag)
            else:
                stateTorch = torch.from_numpy(
                    np.array(state[np.newaxis, :], dtype=np.float32))
                action = net.select_action(stateTorch.to(self.device),
                                           noiseFlag)

        return action.cpu().data.numpy()[0]

    def process_hindSightExperience(self, state, action, nextState, reward,
                                    info):
        if nextState is not None and self.globalStepCount % self.hindSightERFreq == 0:
            stateNew, actionNew, nextStateNew, rewardNew = self.env.getHindSightExperience(
                state, action, nextState, info)
            if stateNew is not None:
                transition = Transition(stateNew, actionNew, nextStateNew,
                                        rewardNew)
                self.memory.push(transition)
                if self.experienceAugmentation:
                    self.process_experienceAugmentation(
                        state, action, nextState, reward, info)

    def process_experienceAugmentation(self, state, action, nextState, reward,
                                       info):
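        """Every `experienceAugmentationFreq` steps, ask the environment for augmented
        copies of the transition and push each copy into replay memory."""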
        if self.globalStepCount % self.experienceAugmentationFreq == 0:
            state_Augs, action_Augs, nextState_Augs, reward_Augs = self.env.getExperienceAugmentation(
                state, action, nextState, reward, info)
            for i in range(len(state_Augs)):
                transition = Transition(state_Augs[i], action_Augs[i],
                                        nextState_Augs[i], reward_Augs[i])
                self.memory.push(transition)

    def store_experience(self, state, action, nextState, reward, info):
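        """Optionally pre-process the transition, push it into replay memory, and then
        apply experience augmentation and/or hindsight relabeling if enabled."""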
        if self.experienceProcessor is not None:
            state, action, nextState, reward = self.experienceProcessor(
                state, action, nextState, reward, info)

        transition = Transition(state, action, nextState, reward)
        self.memory.push(transition)

        if self.experienceAugmentation:
            self.process_experienceAugmentation(state, action, nextState,
                                                reward, info)

        if self.hindSightER:
            self.process_hindSightExperience(state, action, nextState, reward,
                                             info)

    def prepare_minibatch(self, transitions_raw):
        # unpack the sampled transitions and convert them to batched tensors

        transitions = Transition(*zip(*transitions_raw))
        action = torch.tensor(transitions.action,
                              device=self.device,
                              dtype=torch.float32)  # shape(batch, numActions)
        reward = torch.tensor(transitions.reward,
                              device=self.device,
                              dtype=torch.float32)  # shape(batch)

        # for some env, the output state requires further processing before feeding to neural network
        if self.stateProcessor is not None:
            state, _ = self.stateProcessor(transitions.state, self.device)
            nonFinalNextState, nonFinalMask = self.stateProcessor(
                transitions.next_state, self.device)
        else:
            state = torch.tensor(transitions.state,
                                 device=self.device,
                                 dtype=torch.float32)
            nonFinalMask = torch.tensor(tuple(
                map(lambda s: s is not None, transitions.next_state)),
                                        device=self.device,
                                        dtype=torch.bool)
            nonFinalNextState = torch.tensor(
                [s for s in transitions.next_state if s is not None],
                device=self.device,
                dtype=torch.float32)

        return state, nonFinalMask, nonFinalNextState, action, reward

    def update_net(self, state, action, nextState, reward, info):
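        """Store the new transition; once the memory holds at least `trainBatchSize`
        samples, draw a minibatch, run one critic/actor update, and soft-update the
        target networks."""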

        # first store memory

        self.store_experience(state, action, nextState, reward, info)

        # prepare mini-batch
        if len(self.memory) < self.trainBatchSize:
            return

        transitions_raw = self.memory.sample(self.trainBatchSize)

        self.update_net_on_transitions(transitions_raw)

        self.copy_nets()

        self.learnStepCounter += 1

    def copy_nets(self):
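        """Soft-update the target networks toward the online networks (Polyak averaging):

            theta_target <- tau * theta_online + (1 - tau) * theta_target
        """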
        # only update the target networks every policyUpdateFreq learning steps
        if self.learnStepCounter % self.policyUpdateFreq == 0:
            for target_param, param in zip(self.actorNet_target.parameters(),
                                           self.actorNet.parameters()):
                target_param.data.copy_(param.data * self.tau +
                                        target_param.data * (1.0 - self.tau))

            for target_param, param in zip(self.criticNet_target.parameters(),
                                           self.criticNet.parameters()):
                target_param.data.copy_(param.data * self.tau +
                                        target_param.data * (1.0 - self.tau))

    def update_net_on_transitions(self, transitions_raw):
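        """One DDPG-style update on a sampled minibatch.

        The critic is regressed toward the target
        r + gamma * Q_target(s', actor_target(s')), with the bootstrap term zeroed for
        terminal states. Every `policyUpdateFreq` learning steps the actor is updated
        by minimizing -Q(s, actor(s)), and losses are recorded every `lossRecordStep`
        global steps.
        """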

        state, nonFinalMask, nonFinalNextState, action, reward = self.prepare_minibatch(
            transitions_raw)

        # Critic loss
        QValues = self.criticNet.forward(state, action).squeeze()
        QNext = torch.zeros(self.trainBatchSize,
                            device=self.device,
                            dtype=torch.float32)

        if len(nonFinalNextState):
            # next action is calculated using target actor network
            next_actions = self.actorNet_target.forward(nonFinalNextState)
            QNext[nonFinalMask] = self.criticNet_target.forward(
                nonFinalNextState, next_actions.detach()).squeeze()

        targetValues = reward + self.gamma * QNext
        critic_loss = self.netLossFunc(QValues, targetValues)

        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        if self.netGradClip is not None:
            torch.nn.utils.clip_grad_norm_(self.criticNet.parameters(),
                                           self.netGradClip)

        self.critic_optimizer.step()

        # Actor loss
        # maximize the critic output Q(s, actor(s)) by minimizing its negative mean

        # update networks
        if self.learnStepCounter % self.policyUpdateFreq == 0:
            policy_loss = -self.criticNet.forward(
                state, self.actorNet.forward(state)).mean()

            self.actor_optimizer.zero_grad()
            policy_loss.backward()
            if self.netGradClip is not None:
                torch.nn.utils.clip_grad_norm_(self.actorNet.parameters(),
                                               self.netGradClip)

            self.actor_optimizer.step()

            if self.globalStepCount % self.lossRecordStep == 0:
                self.losses.append([
                    self.globalStepCount, self.epIdx,
                    critic_loss.item(),
                    policy_loss.item()
                ])

    def work_before_step(self, state):
        pass

    def train(self):
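        """Run `trainStep` episodes: in each episode, select noisy actions from the
        actor, update the networks after every environment step, accumulate the
        discounted reward sum, track a running average episode reward, and save a
        checkpoint every `logFrequency` episodes when logging is enabled."""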

        runningAvgEpisodeReward = 0.0
        if len(self.rewards) > 0:
            runningAvgEpisodeReward = self.rewards[-1][-1]

        for trainStepCount in range(self.trainStep):

            print("episode index:" + str(self.epIdx))
            state = self.env.reset()
            done = False
            rewardSum = 0

            for stepCount in range(self.episodeLength):

                # any work to be done before select actions
                self.work_before_step(state)

                action = self.select_action(self.actorNet,
                                            state,
                                            noiseFlag=True)

                nextState, reward, done, info = self.env.step(action)

                if stepCount == 0:
                    print("at step 0:")
                    print(info)

                if done:
                    nextState = None

                # learn the transition
                self.update_net(state, action, nextState, reward, info)

                state = nextState
                rewardSum += reward * pow(self.gamma, stepCount)
                self.globalStepCount += 1

                if self.verbose:
                    print('action: ' + str(action))
                    print('state:')
                    print(nextState)
                    print('reward:')
                    print(reward)
                    print('info')
                    print(info)

                if done:
                    break

            runningAvgEpisodeReward = (runningAvgEpisodeReward * self.epIdx +
                                       rewardSum) / (self.epIdx + 1)
            print("done in step count: {}".format(stepCount))
            print("reward sum = " + str(rewardSum))
            print("running average episode reward sum: {}".format(
                runningAvgEpisodeReward))
            print(info)

            self.rewards.append([
                self.epIdx, stepCount, self.globalStepCount, rewardSum,
                runningAvgEpisodeReward
            ])
            if self.config['logFlag'] and self.epIdx % self.config[
                    'logFrequency'] == 0:
                self.save_checkpoint()

            self.epIdx += 1
        self.save_all()

    def saveLosses(self, fileName):
        np.savetxt(fileName, np.array(self.losses), fmt='%.5f', delimiter='\t')

    def saveRewards(self, fileName):
        np.savetxt(fileName,
                   np.array(self.rewards),
                   fmt='%.5f',
                   delimiter='\t')

    def loadLosses(self, fileName):
        self.losses = np.genfromtxt(fileName).tolist()

    def loadRewards(self, fileName):
        self.rewards = np.genfromtxt(fileName).tolist()

    def save_all(self):
        prefix = self.dirName + self.identifier + 'Finalepoch' + str(
            self.epIdx)
        self.saveLosses(prefix + '_loss.txt')
        self.saveRewards(prefix + '_reward.txt')
        with open(prefix + '_memory.pickle', 'wb') as file:
            pickle.dump(self.memory, file)

        torch.save(
            {
                'epoch': self.epIdx,
                'globalStep': self.globalStepCount,
                'actorNet_state_dict': self.actorNet.state_dict(),
                'criticNet_state_dict': self.criticNet.state_dict(),
                'actor_optimizer_state_dict':
                self.actor_optimizer.state_dict(),
                'critic_optimizer_state_dict':
                self.critic_optimizer.state_dict()
            }, prefix + '_checkpoint.pt')

    def save_checkpoint(self):
        prefix = self.dirName + self.identifier + 'Epoch' + str(self.epIdx)
        self.saveLosses(prefix + '_loss.txt')
        self.saveRewards(prefix + '_reward.txt')
        with open(prefix + '_memory.pickle', 'wb') as file:
            pickle.dump(self.memory, file)

        torch.save(
            {
                'epoch': self.epIdx,
                'globalStep': self.globalStepCount,
                'actorNet_state_dict': self.actorNet.state_dict(),
                'criticNet_state_dict': self.criticNet.state_dict(),
                'actor_optimizer_state_dict':
                self.actor_optimizer.state_dict(),
                'critic_optimizer_state_dict':
                self.critic_optimizer.state_dict()
            }, prefix + '_checkpoint.pt')

    def load_checkpoint(self, prefix):
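        """Restore losses, rewards, replay memory, counters, network weights and
        optimizer states written by `save_checkpoint`/`save_all`; the target networks
        are re-initialized from the saved online weights."""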
        self.loadLosses(prefix + '_loss.txt')
        self.loadRewards(prefix + '_reward.txt')
        with open(prefix + '_memory.pickle', 'rb') as file:
            self.memory = pickle.load(file)

        checkpoint = torch.load(prefix + '_checkpoint.pt')
        self.epIdx = checkpoint['epoch']
        self.globalStepCount = checkpoint['globalStep']
        self.actorNet.load_state_dict(checkpoint['actorNet_state_dict'])
        self.actorNet_target.load_state_dict(checkpoint['actorNet_state_dict'])
        self.criticNet.load_state_dict(checkpoint['criticNet_state_dict'])
        self.criticNet_target.load_state_dict(
            checkpoint['criticNet_state_dict'])

        self.actor_optimizer.load_state_dict(
            checkpoint['actor_optimizer_state_dict'])
        self.critic_optimizer.load_state_dict(
            checkpoint['critic_optimizer_state_dict'])

    def init_memory(self):

        if self.memoryOption != 'natural':
            raise NotImplementedError

        self.memory = ReplayMemory(self.memoryCapacity)
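

# --- Illustrative sketch (not part of the original example) ---
# The agent above relies on two optional environment hooks,
# env.getHindSightExperience(...) and env.getExperienceAugmentation(...).
# A minimal environment stub satisfying that interface could look like the
# class below; the class name and the relabeled reward are assumptions made
# purely for illustration.
class _HookedEnvSketch:

    def getHindSightExperience(self, state, action, nextState, info):
        # Relabel the transition against an achieved goal (HER-style); return
        # (None, None, None, None) when no relabeling applies.
        rewardNew = 0.0  # assumed reward for "reaching" the achieved goal
        return state, action, nextState, rewardNew

    def getExperienceAugmentation(self, state, action, nextState, reward, info):
        # Return lists of augmented transitions, e.g. symmetry-transformed copies.
        return [state], [action], [nextState], [reward]

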
class NAFAgent(DQNAgent):
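    """Continuous-action Q-learning agent in the spirit of NAF (normalized advantage
    function): the policy network exposes `eval_Q_value(state, action)` and
    `eval_state_value(state)`, so the bootstrapped next-state value is available
    without an explicit max over actions. Uses a plain replay memory and hard
    target-network copies every `targetNetUpdateStep` learning steps."""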

    def __init__(self, config, policyNet, targetNet, env, optimizer, netLossFunc, nbAction, stateProcessor = None, experienceProcessor=None):
        super(NAFAgent, self).__init__(config, policyNet, targetNet, env, optimizer, netLossFunc, nbAction, stateProcessor, experienceProcessor)

        self.init_memory()


    def init_memory(self):

        if self.memoryOption != 'natural':
            raise NotImplementedError

        self.memory = ReplayMemory(self.memoryCapacity)

    def read_config(self):
        super(NAFAgent, self).read_config()
        # read additional parameters
        self.tau = self.config['tau']

    def work_At_Episode_Begin(self):
        pass

    def work_before_step(self, state=None):
        self.epsThreshold = self.epsilon_by_step(self.globalStepCount)

    def train_one_episode(self):
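        """Run one episode: select noisy actions from the policy net, update the
        network after every step, accumulate the discounted reward sum, update the
        running average episode reward, and checkpoint at the configured frequency.
        Returns (stepCount, rewardSum)."""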

        print("episode index:" + str(self.epIdx))
        state = self.env.reset()
        done = False
        rewardSum = 0

        # any work to be done at episode begin
        self.work_At_Episode_Begin()

        for stepCount in range(self.episodeLength):

            # any work to be done before select actions
            self.work_before_step(state)

            action = self.select_action(self.policyNet, state, noiseFlag=True)

            nextState, reward, done, info = self.env.step(action)

            if stepCount == 0:
                print("at step 0:")
                print(info)

            if done:
                nextState = None

            # learn the transition
            self.update_net(state, action, nextState, reward, info)

            state = nextState
            rewardSum += reward * pow(self.gamma, stepCount)
            self.globalStepCount += 1

            if self.verbose:
                print('action: ' + str(action))
                print('state:')
                print(nextState)
                print('reward:')
                print(reward)
                print('info')
                print(info)

            if done:
                break

        self.runningAvgEpisodeReward = (self.runningAvgEpisodeReward * self.epIdx + rewardSum) / (self.epIdx + 1)
        print("done in step count: {}".format(stepCount))
        print("reward sum = " + str(rewardSum))
        print("running average episode reward sum: {}".format(self.runningAvgEpisodeReward))
        print(info)

        self.rewards.append([self.epIdx, stepCount, self.globalStepCount, rewardSum, self.runningAvgEpisodeReward])
        if self.config['logFlag'] and self.epIdx % self.config['logFrequency'] == 0:
            self.save_checkpoint()

        self.epIdx += 1

        return stepCount, rewardSum

    def train(self):

        if len(self.rewards) > 0:
            self.runningAvgEpisodeReward = self.rewards[-1][-1]

        for trainStepCount in range(self.trainStep):
            self.train_one_episode()
        self.save_all()

    def store_experience(self, state, action, nextState, reward, info):

        if self.experienceProcessor is not None:
            state, action, nextState, reward = self.experienceProcessor(state, action, nextState, reward, info)
        transition = Transition(state, action, nextState, reward)
        self.memory.push(transition)

    def update_net(self, state, action, nextState, reward, info):
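        """Store the transition (plus an optional hindsight-relabeled copy); once the
        memory is sufficiently filled, perform `netUpdateStep` gradient updates every
        `netUpdateFrequency` environment steps and hard-copy the target network every
        `targetNetUpdateStep` learning steps. With prioritized replay, the sampled
        indices and importance weights are passed along in `info`."""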

        # first store memory

        self.store_experience(state, action, nextState, reward, info)

        if self.hindSightER and nextState is not None and self.globalStepCount % self.hindSightERFreq == 0:
            stateNew, actionNew, nextStateNew, rewardNew = self.env.getHindSightExperience(state, action, nextState, info)
            if stateNew is not None:
                self.store_experience(stateNew, actionNew, nextStateNew, rewardNew, info)


        if self.priorityMemoryOption:
            if len(self.memory) < self.config['memoryCapacity']:
                return
        else:
            if len(self.memory) < self.trainBatchSize:
                return


        # update net with specified frequency
        if self.globalStepCount % self.netUpdateFrequency == 0:
            # sample experience
            for nStep in range(self.netUpdateStep):
                info = {}
                if self.priorityMemoryOption:
                    transitions_raw, b_idx, ISWeights = self.memory.sample(self.trainBatchSize)
                    info['batchIdx'] = b_idx
                    info['ISWeights'] = torch.from_numpy(ISWeights.astype(np.float32)).to(self.device)
                else:
                    transitions_raw = self.memory.sample(self.trainBatchSize)

                loss = self.update_net_on_transitions(transitions_raw, self.netLossFunc, 1, updateOption=self.netUpdateOption, netGradClip=self.netGradClip, info=info)

                if self.globalStepCount % self.lossRecordStep == 0:
                    self.losses.append([self.globalStepCount, self.epIdx, loss])

                if self.learnStepCounter % self.targetNetUpdateStep == 0:
                    self.targetNet.load_state_dict(self.policyNet.state_dict())

                self.learnStepCounter += 1

    def prepare_minibatch(self, transitions_raw):
        # unpack the sampled transitions and convert them to batched tensors

        transitions = Transition(*zip(*transitions_raw))
        action = torch.tensor(transitions.action, device=self.device, dtype=torch.float32)  # shape(batch, numActions)
        reward = torch.tensor(transitions.reward, device=self.device, dtype=torch.float32)  # shape(batch)

        # for some env, the output state requires further processing before feeding to neural network
        if self.stateProcessor is not None:
            state, _ = self.stateProcessor(transitions.state, self.device)
            nonFinalNextState, nonFinalMask = self.stateProcessor(transitions.next_state, self.device)
        else:
            state = torch.tensor(transitions.state, device=self.device, dtype=torch.float32)
            nonFinalMask = torch.tensor(tuple(map(lambda s: s is not None, transitions.next_state)), device=self.device,
                                        dtype=torch.bool)
            nonFinalNextState = torch.tensor([s for s in transitions.next_state if s is not None], device=self.device,
                                             dtype=torch.float32)

        return state, nonFinalMask, nonFinalNextState, action, reward


    def select_action(self, net=None, state=None, noiseFlag=True):

        if net is None:
            net = self.policyNet

        with torch.no_grad():
            # self.policyNet(torch.from_numpy(state.astype(np.float32)).unsqueeze(0))
            # here state[np.newaxis,:] is to add a batch dimension
            if self.stateProcessor is not None:
                state, _ = self.stateProcessor([state], self.device)
                action = net.select_action(state, noiseFlag)
            else:
                stateTorch = torch.from_numpy(np.array(state[np.newaxis, :], dtype=np.float32))
                action = net.select_action(stateTorch.to(self.device), noiseFlag)

        return action.cpu().data.numpy()[0]

    def update_net_on_transitions(self, transitions_raw, loss_fun, gradientStep = 1, updateOption='policyNet', netGradClip=None, info=None):
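        """Q-learning update on a minibatch: `policyNet.eval_Q_value(s, a)` is regressed
        toward r + gamma * V_target(s'), where V is `targetNet.eval_state_value` and the
        bootstrap term is zero for terminal next states. Gradients are clipped when
        `netGradClip` is set; returns the scalar loss."""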

        # order the data
        # convert transition list to torch tensors
        # use trick from https://pytorch.org/tutorials/intermediate/reinforcement_q_learning.html
        # https://stackoverflow.com/questions/19339/transpose-unzip-function-inverse-of-zip/19343#19343

        state, nonFinalMask, nonFinalNextState, action, reward = self.prepare_minibatch(transitions_raw)

        for step in range(gradientStep):
            # calculate Qvalues based on selected action batch
            QValues = self.policyNet.eval_Q_value(state, action).squeeze()

            # Here we detach because we do not want gradient flow from target values to net parameters
            QNext = torch.zeros(self.trainBatchSize, device=self.device, dtype=torch.float32)
            QNext[nonFinalMask] = self.targetNet.eval_state_value(nonFinalNextState).squeeze().detach()
            targetValues = reward + self.gamma * QNext

            # Compute loss
            loss_single = loss_fun(QValues, targetValues)
            loss = torch.mean(loss_single)

            # Optimize the model
            # zero gradient
            self.optimizer.zero_grad()

            loss.backward()
            if netGradClip is not None:
                torch.nn.utils.clip_grad_norm_(self.policyNet.parameters(), netGradClip)
            self.optimizer.step()

        return loss.item()