Example #1
import os

import gym
import torch

# `parser` (an argparse.ArgumentParser), Actor and SharedAdam are assumed to be
# defined earlier in this module / project; they are not shown in this snippet.

if __name__ == '__main__':
    os.environ['OMP_NUM_THREADS'] = '1'

    args = parser.parse_args()
    env = gym.make("FetchPickAndPlace-v1")
    shared_model = Actor()
    if args.use_cuda:
        shared_model.cuda()
    torch.cuda.manual_seed_all(30)

    shared_model.share_memory()

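    # Load pretrained A3C weights, keeping only the parameters whose names
    # match the current model's state dict.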
    if os.path.isfile(args.save_path1):
        print('Loading A3C parameters ...')
        pretrained_dict = torch.load(args.save_path1)
        model_dict = shared_model.state_dict()
        pretrained_dict = {
            k: v
            for k, v in pretrained_dict.items() if k in model_dict
        }
        model_dict.update(pretrained_dict)
        shared_model.load_state_dict(model_dict)

    optimizer = SharedAdam(shared_model.parameters(), lr=args.lr)
    optimizer.share_memory()

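    # Freeze the first two fully connected layers; only the remaining layers
    # are updated during fine-tuning.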
    for p in shared_model.fc1.parameters():
        p.requires_grad = False
    for p in shared_model.fc2.parameters():
        p.requires_grad = False
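
The snippet above only prepares the shared actor and the SharedAdam optimizer. A minimal sketch of how the workers would typically be launched with torch.multiprocessing is given below; the worker function train(rank, args, shared_model, optimizer), the launch_workers helper and the num_processes argument are assumptions for illustration, not part of the original code.

import torch.multiprocessing as mp

def launch_workers(args, shared_model, optimizer):
    # Spawn one worker per process; all workers operate on the same
    # shared-memory model parameters and the same shared Adam statistics.
    processes = []
    for rank in range(args.num_processes):
        p = mp.Process(target=train, args=(rank, args, shared_model, optimizer))
        p.start()
        processes.append(p)
    # Wait for every worker to finish before exiting.
    for p in processes:
        p.join()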
Example #2
import os
from copy import deepcopy

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from IPython import display
from torch.autograd import Variable

# Actor, Critic, OUNoise, Buffer, obs2state and the hyperparameter constants
# (ACTOR_LR, CRITIC_LR, SIGMA, BUFFER_SIZE, MINIBATCH_SIZE, CHECKPOINT_DIR,
# DISCOUNT, WARMUP, EPSILON, EPSILON_DECAY, NUM_EPISODES, TAU) are assumed to
# be defined elsewhere in the project.


class DDPG:
    def __init__(self, env):
        self.env = env
        self.stateDim = obs2state(env.reset().observation).size()[1]
        self.actionDim = env.action_spec().shape[0]
        self.actor = Actor(self.env).cuda()
        self.critic = Critic(self.env).cuda()
        self.targetActor = deepcopy(self.actor)
        self.targetCritic = deepcopy(self.critic)
        self.actorOptim = optim.Adam(self.actor.parameters(), lr=ACTOR_LR)
        self.criticOptim = optim.Adam(self.critic.parameters(), lr=CRITIC_LR)
        self.criticLoss = nn.MSELoss()
        self.noise = OUNoise(mu=np.zeros(self.actionDim), sigma=SIGMA)
        self.replayBuffer = Buffer(BUFFER_SIZE)
        self.batchSize = MINIBATCH_SIZE
        self.checkpoint_dir = CHECKPOINT_DIR
        self.discount = DISCOUNT
        self.warmup = WARMUP
        self.epsilon = EPSILON
        self.epsilon_decay = EPSILON_DECAY
        self.rewardgraph = []
        self.start = 0
        self.end = NUM_EPISODES

    def getQTarget(self, nextStateBatch, rewardBatch, terminalBatch):
        """Inputs: Batch of next states, rewards and terminal flags of size self.batchSize
            Computes the target Q-value from the reward and the bootstrapped Q-value
            of the next state, using the target actor and target critic.
           Outputs: Batch of Q-value targets"""
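        # Bellman target: y = r + discount * Q'(s', mu'(s')) for non-terminal
        # next states, and y = r when the episode has terminated.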

        targetBatch = torch.FloatTensor(rewardBatch).cuda()
        nonFinalMask = torch.ByteTensor(
            tuple(map(lambda s: not s, terminalBatch)))
        nextStateBatch = torch.cat(nextStateBatch)
        nextActionBatch = self.targetActor(nextStateBatch)
        nextActionBatch.volatile = True
        qNext = self.targetCritic(nextStateBatch, nextActionBatch)

        nonFinalMask = self.discount * nonFinalMask.type(
            torch.cuda.FloatTensor)
        targetBatch += nonFinalMask * qNext.squeeze().data

        return Variable(targetBatch, volatile=False)

    def updateTargets(self, target, original):
        """Weighted average update of the target network and original network
            Inputs: target actor(critic) and original actor(critic)"""
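        # Polyak (soft) update: theta_target <- (1 - TAU) * theta_target + TAU * theta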

        for targetParam, orgParam in zip(target.parameters(),
                                         original.parameters()):
            targetParam.data.copy_((1 - TAU)*targetParam.data + \
                                          TAU*orgParam.data)

    def getMaxAction(self, curState):
        """Inputs: Current state of the episode
            Returns the actor's (approximately Q-maximizing) action for the current
            state, with exploration noise added"""

        spec = self.env.action_spec()
        minAct = Variable(torch.cuda.FloatTensor(spec.minimum),
                          requires_grad=False)
        maxAct = Variable(torch.cuda.FloatTensor(spec.maximum),
                          requires_grad=False)
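        # Epsilon-scaled Ornstein-Uhlenbeck noise is added to the deterministic
        # actor output for exploration; minAct/maxAct are computed above but the
        # returned action is not clamped to them.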
        noise = self.epsilon * Variable(torch.FloatTensor(self.noise()),
                                        volatile=True).cuda()
        action = self.actor(curState)
        actionNoise = action + noise
        return actionNoise

    def train(self):
        if not os.path.exists(self.checkpoint_dir):
            os.makedirs(self.checkpoint_dir)

        print('Training started...')

        for i in range(self.start, self.end):
            time_step = self.env.reset()
            ep_reward = 0

            while not time_step.last():

                # Visualize training
                display.clear_output(wait=True)
                plt.imshow(self.env.physics.render())
                plt.show()

                # Get maximizing action
                curState = Variable(obs2state(time_step.observation),
                                    volatile=True).cuda()
                self.actor.eval()
                action = self.getMaxAction(curState)
                curState.volatile = False
                action.volatile = False
                self.actor.train()

                # Step episode
                time_step = self.env.step(action.data)
                nextState = Variable(obs2state(time_step.observation),
                                     volatile=True).cuda()
                reward = time_step.reward
                ep_reward += reward
                terminal = time_step.last()

                # Update replay buffer
                self.replayBuffer.append(
                    (curState, action, nextState, reward, terminal))

                # Training loop
                if len(self.replayBuffer) >= self.warmup:

                    curStateBatch, actionBatch, nextStateBatch, \
                    rewardBatch, terminalBatch = self.replayBuffer.sample_batch(self.batchSize)
                    curStateBatch = torch.cat(curStateBatch)
                    actionBatch = torch.cat(actionBatch)

                    qPredBatch = self.critic(curStateBatch, actionBatch)
                    qTargetBatch = self.getQTarget(nextStateBatch, rewardBatch,
                                                   terminalBatch)

                    # Critic update
                    self.criticOptim.zero_grad()
                    criticLoss = self.criticLoss(qPredBatch, qTargetBatch)
                    print('Critic Loss: {}'.format(criticLoss))
                    criticLoss.backward()
                    self.criticOptim.step()

                    # Actor update
                    self.actorOptim.zero_grad()
                    actorLoss = -torch.mean(
                        self.critic(curStateBatch, self.actor(curStateBatch)))
                    print('Actor Loss: {}'.format(actorLoss))
                    actorLoss.backward()
                    self.actorOptim.step()

                    # Update Targets
                    self.updateTargets(self.targetActor, self.actor)
                    self.updateTargets(self.targetCritic, self.critic)
                    self.epsilon -= self.epsilon_decay

            if i % 20 == 0:
                self.save_checkpoint(i)
            self.rewardgraph.append(ep_reward)

    def save_checkpoint(self, episode_num):
        checkpointName = self.checkpoint_dir + 'ep{}.pth.tar'.format(
            episode_num)
        checkpoint = {
            'episode': episode_num,
            'actor': self.actor.state_dict(),
            'critic': self.critic.state_dict(),
            'targetActor': self.targetActor.state_dict(),
            'targetCritic': self.targetCritic.state_dict(),
            'actorOpt': self.actorOptim.state_dict(),
            'criticOpt': self.criticOptim.state_dict(),
            'replayBuffer': self.replayBuffer,
            'rewardgraph': self.rewardgraph,
            'epsilon': self.epsilon
        }
        torch.save(checkpoint, checkpointName)

    def loadCheckpoint(self, checkpointName):
        if os.path.isfile(checkpointName):
            print("Loading checkpoint...")
            checkpoint = torch.load(checkpointName)
            self.start = checkpoint['episode'] + 1
            self.actor.load_state_dict(checkpoint['actor'])
            self.critic.load_state_dict(checkpoint['critic'])
            self.targetActor.load_state_dict(checkpoint['targetActor'])
            self.targetCritic.load_state_dict(checkpoint['targetCritic'])
            self.actorOptim.load_state_dict(checkpoint['actorOpt'])
            self.criticOptim.load_state_dict(checkpoint['criticOpt'])
            self.replayBuffer = checkpoint['replayBuffer']
            self.rewardgraph = checkpoint['rewardgraph']
            self.epsilon = checkpoint['epsilon']
            print('Checkpoint loaded')
        else:
            raise OSError('Checkpoint not found')
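
A hedged usage sketch for the class above: the environment interface it expects (action_spec(), time_step.last(), physics.render()) matches the dm_control suite, so a dm_control environment is assumed here; the task and checkpoint file names are purely illustrative.

from dm_control import suite

env = suite.load(domain_name="reacher", task_name="easy")
agent = DDPG(env)

# Optionally resume from a previous run (the file name is illustrative).
# agent.loadCheckpoint(CHECKPOINT_DIR + 'ep100.pth.tar')

agent.train()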