Example #1
        WRITER.add_scalar('training/generator/loss', average_generator_loss,
                          epoch)
        WRITER.add_scalar('training/critic/loss', average_critic_loss, epoch)
        WRITER.add_scalar('training/critic/real-performance',
                          average_critic_real_performance, epoch)
        WRITER.add_scalar('training/critic/generated-performance',
                          average_critic_generated_performance, epoch)
        WRITER.add_scalar('training/discernability-score',
                          discernability_score, epoch)
        WRITER.add_scalar('training/epoch-duration', time_elapsed, epoch)

        # Save the model parameters at a specified interval.
        if (epoch > 0 and (epoch % args.model_save_frequency == 0
                           or epoch == args.num_epochs - 1)):

            # Create a backup of tensorboard data each time model is saved.
            shutil.copytree(LIVE_TENSORBOARD_DIR,
                            f'{TENSORBOARD_DIR}/{epoch:03d}')

            save_critic_model_path = f'{args.save_model_dir}/critic_{EXPERIMENT_ID}-{epoch}.pth'
            print(f'\nSaving critic model as "{save_critic_model_path}"...')
            torch.save(critic_model.state_dict(), save_critic_model_path)

            save_generator_model_path = f'{args.save_model_dir}/generator_{EXPERIMENT_ID}-{epoch}.pth'
            print(
                f'Saving generator model as "{save_generator_model_path}"...\n'
            )
            torch.save(generator_model.state_dict(), save_generator_model_path)

print('Finished training!')
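A note on the fragment above: it assumes module-level names (WRITER, EXPERIMENT_ID, LIVE_TENSORBOARD_DIR, TENSORBOARD_DIR) and an argparse `args` object defined earlier in the original script. A minimal sketch of one possible setup; every concrete value here is an assumption:

# Hypothetical setup for the logging/checkpoint fragment above; names and values are assumptions.
import shutil
import time

import torch
from torch.utils.tensorboard import SummaryWriter

EXPERIMENT_ID = time.strftime('%Y%m%d-%H%M%S')        # assumed: timestamp-based run id
TENSORBOARD_DIR = f'tensorboard/{EXPERIMENT_ID}'      # assumed: per-run backup root
LIVE_TENSORBOARD_DIR = f'{TENSORBOARD_DIR}/live'      # assumed: live event-file directory
WRITER = SummaryWriter(log_dir=LIVE_TENSORBOARD_DIR)  # scalar logging used in the loop above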
Example #2
class Agent():
    """Interacts with and learns from the environment."""
    def __init__(self, state_size: int, action_size: int, seed: int,
                 n_agent: int):
        """Initialize an Agent object.

        Params
        ======
            state_size (int): dimension of each state
            action_size (int): dimension of each action
            seed (int): random seed
            n_agent (int): number of agents
        """
        self.state_size = state_size
        self.action_size = action_size
        self.n_agent = n_agent
        self.seed = random.seed(seed)
        self.global_step = 0
        self.update_step = 0

        # Initialize actor and critic local and target networks
        self.actor = Actor(state_size,
                           action_size,
                           seed,
                           ACTOR_NETWORK_LINEAR_SIZES,
                           batch_normalization=ACTOR_BATCH_NORM).to(device)
        self.actor_target = Actor(
            state_size,
            action_size,
            seed,
            ACTOR_NETWORK_LINEAR_SIZES,
            batch_normalization=ACTOR_BATCH_NORM).to(device)
        self.critic = Critic(state_size,
                             action_size,
                             seed,
                             CRITIC_NETWORK_LINEAR_SIZES,
                             batch_normalization=CRITIC_BATCH_NORM).to(device)
        self.critic_second = Critic(
            state_size,
            action_size,
            seed,
            CRITIC_SECOND_NETWORK_LINEAR_SIZES,
            batch_normalization=CRITIC_BATCH_NORM).to(device)
        self.critic_second_target = Critic(
            state_size,
            action_size,
            seed,
            CRITIC_SECOND_NETWORK_LINEAR_SIZES,
            batch_normalization=CRITIC_BATCH_NORM).to(device)
        self.critic_target = Critic(
            state_size,
            action_size,
            seed,
            CRITIC_NETWORK_LINEAR_SIZES,
            batch_normalization=CRITIC_BATCH_NORM).to(device)
        self.actor_optimizer = optim.Adam(self.actor.parameters(),
                                          lr=ACTOR_LEARNING_RATE)
        self.critic_optimizer = optim.Adam(self.critic.parameters(),
                                           lr=CRITIC_LEARNING_RATE)
        self.critic_second_optimizer = optim.Adam(
            self.critic_second.parameters(), lr=CRITIC_LEARNING_RATE)
        # Replay memory
        self.memory = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE, seed)
        # Initialize time step (for updating every UPDATE_EVERY steps)
        self.t_step = [0] * n_agent
        self.noise = OUNoise(action_size, seed, decay_period=50)

        # Copy parameters from local network to target network
        for target_param, param in zip(self.actor_target.parameters(),
                                       self.actor.parameters()):
            target_param.data.copy_(param.data)
        for target_param, param in zip(self.critic_target.parameters(),
                                       self.critic.parameters()):
            target_param.data.copy_(param.data)
        for target_param, param in zip(self.critic_second_target.parameters(),
                                       self.critic_second.parameters()):
            target_param.data.copy_(param.data)

    def step(self, state: np.ndarray, action, reward, next_state, done, i_agent):
        # Save experience in replay memory
        self.memory.add(state, action, reward, next_state, done)

        # Learn every UPDATE_EVERY time steps.
        self.t_step[i_agent] = (self.t_step[i_agent] + 1) % UPDATE_EVERY
        # Learn once every agent's counter wraps to zero and enough samples are in memory
        if len(self.memory) > BATCH_SIZE and (not any(self.t_step)):
            for _ in range(LEARN_TIMES):
                experiences = self.memory.sample()
                self.learn(experiences, GAMMA)

    def noise_reset(self):
        self.noise.reset()

    def save_model(self, checkpoint_path: str = "./checkpoints/"):
        torch.save(self.actor.state_dict(), f"{checkpoint_path}/actor.pt")
        torch.save(self.critic.state_dict(), f"{checkpoint_path}/critic.pt")

    def load_model(self, checkpoint_path: str = "./checkpoints/checkpoint.pt"):
        self.actor.load_state_dict(torch.load(f"{checkpoint_path}/actor.pt"))
        self.critic.load_state_dict(torch.load(f"{checkpoint_path}/critic.pt"))

    def act(self, states: np.ndarray, step: int):
        """Returns actions for the given states as per the current policy.

        Params
        ======
            states (array_like): current states, one row per agent
            step (int): current time step, passed to the exploration noise
        """
        self.global_step += 1
        if self.global_step < WARM_UP_STEPS:
            action_values = np.random.rand(self.n_agent, self.action_size)
            return action_values
        states = torch.from_numpy(states).float().to(device)
        self.actor.eval()
        with torch.no_grad():
            action_values = self.actor(states).cpu().data.numpy()
        self.actor.train()
        action_values = [
            self.noise.get_action(action, t=step) for action in action_values
        ]
        return action_values

    def learn(self, experiences: tuple, gamma=GAMMA):
        """Update value parameters using given batch of experience tuples.

        Params
        ======
            experiences (Tuple[torch.Tensor]): tuple of (s, a, r, s', done) tensors
            gamma (float): discount factor
        """
        self.update_step += 1
        states, actions, rewards, next_states, dones = experiences
        # Critic loss
        mask = torch.tensor(1 - dones).detach().to(device)
        Q_values = self.critic(states, actions)
        Q_values_second = self.critic_second(states, actions)
        next_actions = self.actor_target(next_states)
        next_Q = torch.min(
            self.critic_target(next_states, next_actions.detach()),
            self.critic_second_target(next_states, next_actions.detach()))
        Q_prime = rewards + gamma * next_Q * mask

        if self.update_step % CRITIC_UPDATE_EVERY == 0:
            critic_loss = F.mse_loss(Q_values, Q_prime.detach())
            critic_second_loss = F.mse_loss(Q_values_second, Q_prime.detach())
            # Update first critic network
            self.critic_optimizer.zero_grad()
            critic_loss.backward()
            if CRITIC_GRADIENT_CLIPPING_VALUE:
                torch.nn.utils.clip_grad_norm_(self.critic.parameters(),
                                               CRITIC_GRADIENT_CLIPPING_VALUE)
            self.critic_optimizer.step()

            # Update second critic network
            self.critic_second_optimizer.zero_grad()
            critic_second_loss.backward()
            if CRITIC_GRADIENT_CLIPPING_VALUE:
                torch.nn.utils.clip_grad_norm_(self.critic_second.parameters(),
                                               CRITIC_GRADIENT_CLIPPING_VALUE)
            self.critic_second_optimizer.step()

        # Actor loss

        if self.update_step % POLICY_UPDATE_EVERY == 0:
            policy_loss = -self.critic(states, self.actor(states)).mean()

            # Update actor network
            self.actor_optimizer.zero_grad()
            policy_loss.backward()
            if ACTOR_GRADIENT_CLIPPING_VALUE:
                torch.nn.utils.clip_grad_norm_(self.actor.parameters(),
                                               ACTOR_GRADIENT_CLIPPING_VALUE)
            self.actor_optimizer.step()

        self.actor_soft_update()
        self.critic_soft_update()
        self.critic_second_soft_update()

    def actor_soft_update(self, tau: float = TAU):
        """Soft update for actor target network

        Args:
            tau (float, optional). Defaults to TAU.
        """
        for target_param, param in zip(self.actor_target.parameters(),
                                       self.actor.parameters()):
            target_param.data.copy_(tau * param.data +
                                    (1.0 - tau) * target_param.data)

    def critic_soft_update(self, tau: float = TAU):
        """Soft update for critic target network

        Args:
            tau (float, optional). Defaults to TAU.
        """
        for target_param, param in zip(self.critic_target.parameters(),
                                       self.critic.parameters()):
            target_param.detach_()
            target_param.data.copy_(tau * param.data +
                                    (1.0 - tau) * target_param.data)

    def critic_second_soft_update(self, tau: float = TAU):
        """Soft update for critic target network

        Args:
            tau (float, optional). Defaults to TAU.
        """
        for target_param, param in zip(self.critic_second_target.parameters(),
                                       self.critic_second.parameters()):
            target_param.detach_()
            target_param.data.copy_(tau * param.data +
                                    (1.0 - tau) * target_param.data)
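A minimal sketch of how this Agent might be driven from a training loop. The multi-agent environment `env`, its reset/step API, and the sizes below are assumptions, not part of the original code:

# Hypothetical training loop for the Agent above; `env` and its API are assumptions.
import numpy as np

agent = Agent(state_size=33, action_size=4, seed=0, n_agent=2)  # sizes are placeholders

for episode in range(200):
    states = env.reset()                  # assumed: array of shape (n_agent, state_size)
    agent.noise_reset()
    for t in range(1000):
        actions = np.array(agent.act(states, step=t))
        next_states, rewards, dones = env.step(actions)  # assumed environment API
        for i in range(agent.n_agent):
            agent.step(states[i], actions[i], rewards[i], next_states[i], dones[i], i)
        states = next_states
        if np.any(dones):
            break

agent.save_model('./checkpoints')  # writes actor.pt and critic.pt into this directory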
Example #3
File: ddpg.py Project: juandd18/Tenis
class DDPGAgent(object):
    """
    General class for DDPG agents (policy, critic, target policy, target
    critic, exploration noise)
    """
    def __init__(self, num_in_pol, num_out_pol, num_in_critic,
                 hidden_dim_actor=120, hidden_dim_critic=64,
                 lr_actor=0.01, lr_critic=0.01, batch_size=64,
                 max_episode_len=100, tau=0.02, gamma=0.99,
                 agent_name='one', discrete_action=False):
        """
        Inputs:
            num_in_pol (int): number of dimensions for policy input
            num_out_pol (int): number of dimensions for policy output
            num_in_critic (int): number of dimensions for critic input
        """
        self.policy = Actor(num_in_pol, num_out_pol,
                            hidden_dim=hidden_dim_actor,
                            discrete_action=discrete_action)
        self.critic = Critic(num_in_pol, 1, num_out_pol,
                             hidden_dim=hidden_dim_critic)
        self.target_policy = Actor(num_in_pol, num_out_pol,
                                   hidden_dim=hidden_dim_actor,
                                   discrete_action=discrete_action)
        self.target_critic = Critic(num_in_pol, 1, num_out_pol,
                                    hidden_dim=hidden_dim_critic)
        hard_update(self.target_policy, self.policy)
        hard_update(self.target_critic, self.critic)
        self.policy_optimizer = Adam(self.policy.parameters(), lr=lr_actor)
        self.critic_optimizer = Adam(self.critic.parameters(), lr=lr_critic,
                                     weight_decay=0)
        
        self.policy = self.policy.float()
        self.critic = self.critic.float()
        self.target_policy = self.target_policy.float()
        self.target_critic = self.target_critic.float()

        self.agent_name = agent_name
        self.gamma = gamma
        self.tau = tau
        self.batch_size = batch_size
        # self.replay_buffer = ReplayBuffer(1e7)
        self.replay_buffer = ReplayBufferOption(500000, self.batch_size, 12)
        self.max_replay_buffer_len = batch_size * max_episode_len
        self.replay_sample_index = None
        self.niter = 0
        self.eps = 5.0
        self.eps_decay = 1/(250*5)

        self.exploration = OUNoise(num_out_pol)
        self.discrete_action = discrete_action

        self.num_history = 2
        self.states = []
        self.actions = []
        self.rewards = []
        self.next_states = []
        self.dones = []

    def reset_noise(self):
        if not self.discrete_action:
            self.exploration.reset()

    def scale_noise(self, scale):
        if self.discrete_action:
            self.exploration = scale
        else:
            self.exploration.scale = scale

    def act(self, obs, explore=False):
        """
        Take a step forward in environment for a minibatch of observations
        Inputs:
            obs : Observations for this agent
            explore (boolean): Whether or not to add exploration noise
        Outputs:
            action (PyTorch Variable): Actions for this agent
        """
        #obs = obs.reshape(1,48)
        state = Variable(torch.Tensor(obs),requires_grad=False)

        self.policy.eval()
        with torch.no_grad():
            action = self.policy(state)
        self.policy.train()
        # continuous action
        if explore:
            action += Variable(Tensor(self.eps * self.exploration.sample()),requires_grad=False)
            action = torch.clamp(action, min=-1, max=1)
        return action

    def step(self, agent_id, state, action, reward, next_state, done, t_step):
        
        self.states.append(state)
        self.actions.append(action)
        self.rewards.append(reward)
        self.next_states.append(next_state)
        self.dones.append(done)

        #self.replay_buffer.add(state, action, reward, next_state, done)
        if t_step % self.num_history == 0:
            # Save experience / reward
            
            self.replay_buffer.add(self.states, self.actions, self.rewards, self.next_states, self.dones)
            self.states = []
            self.actions = []
            self.rewards = []
            self.next_states = []
            self.dones = []

        # Learn, if enough samples are available in memory
        if len(self.replay_buffer) > self.batch_size:
            obs, acs, rews, next_obs, don = self.replay_buffer.sample()
            self.update(agent_id, obs, acs, rews, next_obs, don, t_step)
        


    def update(self, agent_id, obs, acs, rews, next_obs, dones, t_step,
               logger=None):

        obs = torch.from_numpy(obs).float()
        acs = torch.from_numpy(acs).float()
        rews = torch.from_numpy(rews[:, agent_id]).float()
        next_obs = torch.from_numpy(next_obs).float()
        dones = torch.from_numpy(dones[:, agent_id]).float()

        acs = acs.view(-1, 2)
                
        # --------- update critic ------------ #
        self.critic_optimizer.zero_grad()

        all_trgt_acs = self.target_policy(next_obs)

        target_value = (rews + self.gamma *
                        self.target_critic(next_obs, all_trgt_acs) *
                        (1 - dones))

        actual_value = self.critic(obs, acs)
        vf_loss = MSELoss(actual_value, target_value.detach())

        # Minimize the loss
        vf_loss.backward()
        #torch.nn.utils.clip_grad_norm_(self.critic.parameters(), 1)
        self.critic_optimizer.step()

        # --------- update actor --------------- #
        self.policy_optimizer.zero_grad()

        if self.discrete_action:
            curr_pol_out = self.policy(obs)
            curr_pol_vf_in = gumbel_softmax(curr_pol_out, hard=True)
        else:
            curr_pol_out = self.policy(obs)
            curr_pol_vf_in = curr_pol_out


        pol_loss = -self.critic(obs, curr_pol_vf_in).mean()
        #pol_loss += (curr_pol_out**2).mean() * 1e-3
        pol_loss.backward()

        torch.nn.utils.clip_grad_norm_(self.policy.parameters(), 1)
        self.policy_optimizer.step()

        self.update_all_targets()
        self.eps -= self.eps_decay
        self.eps = max(self.eps, 0)
        

        if logger is not None:
            logger.add_scalars('agent%s/losses' % self.agent_name,
                               {'vf_loss': vf_loss,
                                'pol_loss': pol_loss},
                               self.niter)

    def update_all_targets(self):
        """
        Update all target networks (called after normal updates have been
        performed for each agent)
        """
        
        soft_update(self.critic, self.target_critic, self.tau)
        soft_update(self.policy, self.target_policy, self.tau)
   
    def get_params(self):
        return {'policy': self.policy.state_dict(),
                'critic': self.critic.state_dict(),
                'target_policy': self.target_policy.state_dict(),
                'target_critic': self.target_critic.state_dict(),
                'policy_optimizer': self.policy_optimizer.state_dict(),
                'critic_optimizer': self.critic_optimizer.state_dict()}

    def load_params(self, params):
        self.policy.load_state_dict(params['policy'])
        self.critic.load_state_dict(params['critic'])
        self.target_policy.load_state_dict(params['target_policy'])
        self.target_critic.load_state_dict(params['target_critic'])
        self.policy_optimizer.load_state_dict(params['policy_optimizer'])
        self.critic_optimizer.load_state_dict(params['critic_optimizer'])
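A minimal sketch of the call pattern this DDPGAgent expects. The environment `env`, the agent id, and the dimensions are assumptions; per-agent reward/done arrays are assumed so that the rews[:, agent_id] indexing in update() works:

# Hypothetical interaction loop; `env`, the dimensions and the agent id are assumptions.
agent = DDPGAgent(num_in_pol=24, num_out_pol=2, num_in_critic=26)

obs = env.reset()                                # assumed environment API
agent.reset_noise()
for t_step in range(1, 1001):
    action = agent.act(obs, explore=True).detach().numpy()
    next_obs, rewards, dones = env.step(action)  # assumed: per-agent reward/done arrays
    agent.step(0, obs, action, rewards, next_obs, dones, t_step)
    obs = env.reset() if dones.any() else next_obs

params = agent.get_params()                      # state dicts, e.g. for checkpointing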
Example #4
class DDPG(object):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    veer = 'straight'
    origin = 'west'

    # state_dim : state dimension
    # action_dim: action dimension
    # max_action: vector of action limits
    def __init__(self, state_dim, action_dim, max_action, origin_str, veer_str,
                 logger):

        # Neural networks kept on the GPU
        self.actor = Actor(state_dim, action_dim,
                           max_action).to(self.device)  # origin_network
        self.actor_target = Actor(state_dim, action_dim,
                                  max_action).to(self.device)  # target_network
        self.actor_target.load_state_dict(self.actor.state_dict(
        ))  # initiate actor_target with actor's parameters
        # A PyTorch tensor has requires_grad=False by default, i.e. it takes no part in gradient
        # propagation; the model parameters handed to the optimizer, however, are optimized by gradient.
        self.actor_optimizer = optim.Adam(
            self.actor.parameters(),
            pdata.LEARNING_RATE)  # optimize the actor's parameters with learning rate pdata.LEARNING_RATE

        self.critic = Critic(state_dim, action_dim).to(self.device)
        self.critic_target = Critic(state_dim, action_dim).to(self.device)
        self.critic_target.load_state_dict(self.critic.state_dict())
        self.critic_optimizer = optim.Adam(self.critic.parameters(),
                                           pdata.LEARNING_RATE)

        # self.replay_buffer = Replay_buffer()    # initiate replay-buffer
        self.replay_buffer = FilterReplayBuffer()
        self.writer = SummaryWriter(pdata.DIRECTORY + 'runs')
        self.num_critic_update_iteration = 0
        self.num_actor_update_iteration = 0
        self.num_training = 0

        self.veer = veer_str
        self.origin = origin_str

        self.logger = logger

    def select_action(self, state):
        state = torch.FloatTensor(state.reshape(1, -1)).to(self.device)
        # numpy.ndarray.flatten(): returns a copy of the ndarray collapsed into one dimension
        action = self.actor(state).cpu().data.numpy().flatten()
        return action

    # update the parameters in actor network and critic network
    # update() is called only once the replay_buffer's storage exceeds the required sample count
    def update(self):
        critic_loss_list = []
        actor_performance_list = []
        for it in range(pdata.UPDATE_ITERATION):
            # Sample replay buffer
            x, y, u, r, d = self.replay_buffer.sample(
                pdata.BATCH_SIZE
            )  # randomly draw batch_size five-tuple samples (sample a random minibatch)
            state = torch.FloatTensor(x).to(self.device)
            action = torch.FloatTensor(u).to(self.device)
            next_state = torch.FloatTensor(y).to(self.device)
            done = torch.FloatTensor(d).to(self.device)
            reward = torch.FloatTensor(r).to(self.device)

            # Compute the target Q value: Q(S', A') is a value evaluated from next_state and the predicted action
            # target_Q here is a 1-D tensor with batch_size entries
            target_Q = self.critic_target(next_state,
                                          self.actor_target(next_state))
            # detach(): Return a new tensor, detached from the current graph
            # evaluate Q: targetQ = R + γQ'(s', a')
            target_Q = reward + (
                (1 - done) * pdata.GAMMA * target_Q).detach()  # batch_size entries

            # Get current Q estimate
            current_Q = self.critic(state, action)  # 1-D tensor

            # Compute critic loss: a mean-squared error
            # Per the paper, critic_loss is the mean squared error between each sample's target
            # estimate and the critic network's output
            # torch.nn.functional.mse_loss computes the mean squared error over the tensors' elements
            critic_loss = F.mse_loss(current_Q, target_Q)
            self.writer.add_scalar(
                'critic_loss',
                critic_loss,
                global_step=self.num_critic_update_iteration)
            # self.logger.write_to_log('critic_loss:{loss}'.format(loss=critic_loss))
            # self.logger.add_to_critic_buffer(critic_loss.item())
            critic_loss_list.append(critic_loss.item())

            # Optimize the critic
            self.critic_optimizer.zero_grad()  # zeros the gradient buffer
            critic_loss.backward()  # back propagation on a dynamic graph
            self.critic_optimizer.step()

            # Compute actor loss
            # actor_loss: see the discussion of Eq. (6) in the paper
            # mean(): averages over all elements of the tensor
            # backward() updates parameters by gradient descent, so actor_loss is negated
            # so that the parameters move in the direction of gradient ascent
            actor_loss = -self.critic(state, self.actor(state)).mean()
            self.writer.add_scalar('actor_performance',
                                   actor_loss,
                                   global_step=self.num_actor_update_iteration)
            # self.logger.write_to_log('actor_loss:{loss}'.format(loss=actor_loss))
            # self.logger.add_to_actor_buffer(actor_loss.item())
            actor_performance_list.append(actor_loss.item())

            # Optimize the actor
            self.actor_optimizer.zero_grad(
            )  # Clears the gradients of all optimized torch.Tensor
            actor_loss.backward()
            self.actor_optimizer.step()  # perform a single optimization step

            # Soft update of the two target networks
            for param, target_param in zip(self.critic.parameters(),
                                           self.critic_target.parameters()):
                target_param.data.copy_(pdata.TAU * param.data +
                                        (1 - pdata.TAU) * target_param.data)

            for param, target_param in zip(self.actor.parameters(),
                                           self.actor_target.parameters()):
                target_param.data.copy_(pdata.TAU * param.data +
                                        (1 - pdata.TAU) * target_param.data)

            self.num_actor_update_iteration += 1
            self.num_critic_update_iteration += 1

        actor_performance = np.mean(np.array(actor_performance_list)).item()
        self.logger.add_to_actor_buffer(actor_performance)
        critic_loss = np.mean(critic_loss_list).item()
        self.logger.add_to_critic_buffer(critic_loss)

    def save(self, mark_str):
        torch.save(self.actor.state_dict(),
                   pdata.DIRECTORY + mark_str + '_actor.pth')
        torch.save(self.critic.state_dict(),
                   pdata.DIRECTORY + mark_str + '_critic.pth')

    def load(self, mark_str):
        file_actor = pdata.DIRECTORY + mark_str + '_actor.pth'
        file_critic = pdata.DIRECTORY + mark_str + '_critic.pth'
        if os.path.exists(file_actor) and os.path.exists(file_critic):
            self.actor.load_state_dict(torch.load(file_actor))
            self.critic.load_state_dict(torch.load(file_critic))
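A minimal sketch of how this DDPG wrapper is typically driven. The environment `env`, the `logger` object and the replay buffer's push API are assumptions; the pdata constants come from the original project:

# Hypothetical control loop; `env`, `logger` and the buffer's push API are assumptions.
ddpg = DDPG(state_dim=14, action_dim=2, max_action=1.0,
            origin_str='west', veer_str='straight', logger=logger)

state = env.reset()                                # assumed environment API
for t in range(10000):
    action = ddpg.select_action(state)
    next_state, reward, done = env.step(action)
    # assumed buffer API; sample() above unpacks (state, next_state, action, reward, done)
    ddpg.replay_buffer.push((state, next_state, action, reward, float(done)))
    state = env.reset() if done else next_state
    if t > pdata.BATCH_SIZE:                       # update only once the buffer holds enough samples
        ddpg.update()

ddpg.save('final')                                 # writes final_actor.pth / final_critic.pth under pdata.DIRECTORY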
Example #5
class DDPGAgent:
    def __init__(self, state_size=24, action_size=2, seed=1, num_agents=2):
        self.state_size = state_size
        self.action_size = action_size
        self.seed = random.seed(seed)
        self.num_agents = num_agents

        # DDPG specific configuration
        hidden_size = 512
        self.CHECKPOINT_FOLDER = './'

        # Defining networks
        self.actor = Actor(state_size, hidden_size, action_size).to(device)
        self.actor_target = Actor(state_size, hidden_size, action_size).to(device)

        self.critic = Critic(state_size, self.action_size, hidden_size, 1).to(device)
        self.critic_target = Critic(state_size, self.action_size, hidden_size, 1).to(device)

        self.optimizer_actor = optim.Adam(self.actor.parameters(), lr=ACTOR_LR)
        self.optimizer_critic = optim.Adam(self.critic.parameters(), lr=CRITIC_LR)

        # Noise
        self.noises = OUNoise((num_agents, action_size), seed)

        # Initialize replay buffer
        self.memory = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE, seed)

    def act(self, state, add_noise=True):
        '''
        Returns action to be taken based on state provided as the input
        '''
        state = torch.from_numpy(state).float().to(device)
        actions = np.zeros((self.num_agents, self.action_size))

        self.actor.eval()

        with torch.no_grad():
            for agent_num, agent_state in enumerate(state):
                action = self.actor(agent_state).cpu().data.numpy()
                actions[agent_num, :] = action

        self.actor.train()

        if add_noise:
            actions += self.noises.sample()

        return np.clip(actions, -1, 1)

    def reset(self):
        self.noises.reset()

    def learn(self, experiences):
        '''
        Trains the actor critic network using experiences
        '''
        states, actions, rewards, next_states, dones = experiences

        # Update Critic
        actions_next = self.actor_target(next_states)
        # print(next_states.shape, actions_next.shape)
        Q_targets_next = self.critic_target(next_states, actions_next)
        Q_targets = rewards + GAMMA*Q_targets_next*(1-dones)

        Q_expected = self.critic(states, actions)
        critic_loss = F.mse_loss(Q_expected, Q_targets)

        self.optimizer_critic.zero_grad()
        critic_loss.backward()
        self.optimizer_critic.step()

        # Update Actor
        actions_pred = self.actor(states)
        actor_loss = -self.critic(states, actions_pred).mean()

        self.optimizer_actor.zero_grad()
        actor_loss.backward()
        self.optimizer_actor.step()

        # Updating the local networks
        self.soft_update(self.critic, self.critic_target)
        self.soft_update(self.actor, self.actor_target)


    def soft_update(self, model, model_target):
        tau = TAU
        for target_param, local_param in zip(model_target.parameters(), model.parameters()):
            target_param.data.copy_(tau * local_param.data + (1.0 - tau) * target_param.data)

    def step(self, state, action, reward, next_state, done):
        '''
        Adds experience to memory and learns if the memory contains sufficient samples
        '''
        for i in range(self.num_agents):
            self.memory.add(state[i, :], action[i, :], reward[i], next_state[i, :], done[i])

        if len(self.memory) > BATCH_SIZE:
            # print("Now Learning")
            experiences = self.memory.sample()
            self.learn(experiences)

    def checkpoint(self):
        '''
        Saves the actor critic network on disk
        '''
        torch.save(self.actor.state_dict(), self.CHECKPOINT_FOLDER + 'checkpoint_actor.pth')
        torch.save(self.critic.state_dict(), self.CHECKPOINT_FOLDER + 'checkpoint_critic.pth')

    def load(self):
        '''
        Loads the actor critic network from disk
        '''
        self.actor.load_state_dict(torch.load(self.CHECKPOINT_FOLDER + 'checkpoint_actor.pth'))
        self.actor_target.load_state_dict(torch.load(self.CHECKPOINT_FOLDER + 'checkpoint_actor.pth'))
        self.critic.load_state_dict(torch.load(self.CHECKPOINT_FOLDER + 'checkpoint_critic.pth'))
        self.critic_target.load_state_dict(torch.load(self.CHECKPOINT_FOLDER + 'checkpoint_critic.pth'))
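A minimal sketch of an episode loop for this agent, assuming a two-agent environment that returns stacked per-agent arrays (the environment itself is an assumption):

# Hypothetical episode loop; `env` and its stacked-array API are assumptions.
agent = DDPGAgent(state_size=24, action_size=2, seed=1, num_agents=2)

for episode in range(500):
    states = env.reset()                    # assumed: shape (num_agents, state_size)
    agent.reset()
    while True:
        actions = agent.act(states)         # shape (num_agents, action_size), clipped to [-1, 1]
        next_states, rewards, dones = env.step(actions)  # assumed environment API
        agent.step(states, actions, rewards, next_states, dones)
        states = next_states
        if any(dones):
            break

agent.checkpoint()                          # writes checkpoint_actor.pth / checkpoint_critic.pth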
Example #6
class SAC:
    def __init__(self,
                 env,
                 lr=3e-4,
                 gamma=0.99,
                 polyak=5e-3,
                 alpha=0.2,
                 reward_scale=1.0,
                 cuda=True,
                 writer=None):
        state_size = env.observation_space.shape[0]
        action_size = env.action_space.shape[0]
        self.actor = Actor(state_size, action_size)
        self.critic = Critic(state_size, action_size)
        self.target_critic = Critic(state_size, action_size).eval()
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=lr)
        self.q1_optimizer = optim.Adam(self.critic.q1.parameters(), lr=lr)
        self.q2_optimizer = optim.Adam(self.critic.q2.parameters(), lr=lr)

        self.target_critic.load_state_dict(self.critic.state_dict())
        for param in self.target_critic.parameters():
            param.requires_grad = False

        self.memory = ReplayMemory()

        self.gamma = gamma
        self.alpha = alpha
        self.polyak = polyak  # Always between 0 and 1, usually close to 1
        self.reward_scale = reward_scale

        self.writer = writer

        self.cuda = cuda
        if cuda:
            self.actor = self.actor.to('cuda')
            self.critic = self.critic.to('cuda')
            self.target_critic = self.target_critic.to('cuda')

    def explore(self, state):
        if self.cuda:
            state = torch.tensor(state).unsqueeze(0).to('cuda', torch.float)
        action, _, _ = self.actor.sample(state)
        # action, _ = self.actor(state)
        return action.cpu().detach().numpy().reshape(-1)

    def exploit(self, state):
        if self.cuda:
            state = torch.tensor(state).unsqueeze(0).to('cuda', torch.float)
        _, _, action = self.actor.sample(state)
        return action.cpu().detach().numpy().reshape(-1)

    def store_step(self, state, action, next_state, reward, terminal):
        state = to_tensor_unsqueeze(state)
        if action.dtype == np.float32:
            action = torch.from_numpy(action)
        next_state = to_tensor_unsqueeze(next_state)
        reward = torch.from_numpy(np.array([reward]).astype(float))  # np.float was removed from NumPy; use the builtin
        terminal = torch.from_numpy(np.array([terminal]).astype(np.uint8))
        self.memory.push(state, action, next_state, reward, terminal)

    def target_update(self, target_net, net):
        for t, s in zip(target_net.parameters(), net.parameters()):
            # t.data.copy_(t.data * (1.0 - self.polyak) + s.data * self.polyak)
            t.data.mul_(1.0 - self.polyak)
            t.data.add_(self.polyak * s.data)

    def calc_target_q(self, next_states, rewards, terminals):
        with torch.no_grad():
            next_action, entropy, _ = self.actor.sample(
                next_states)  # penalty term
            next_q1, next_q2 = self.target_critic(next_states, next_action)
            next_q = torch.min(next_q1, next_q2) - self.alpha * entropy
        target_q = rewards * self.reward_scale + (
            1. - terminals) * self.gamma * next_q
        return target_q

    def calc_critic_loss(self, states, actions, next_states, rewards,
                         terminals):
        q1, q2 = self.critic(states, actions)
        target_q = self.calc_target_q(next_states, rewards, terminals)

        q1_loss = torch.mean((q1 - target_q).pow(2))
        q2_loss = torch.mean((q2 - target_q).pow(2))
        return q1_loss, q2_loss

    def calc_actor_loss(self, states):
        action, entropy, _ = self.actor.sample(states)

        q1, q2 = self.critic(states, action)
        q = torch.min(q1, q2)

        # actor_loss = torch.mean(-q - self.alpha * entropy)
        actor_loss = (self.alpha * entropy - q).mean()
        return actor_loss, entropy

    def train(self, timestep, batch_size=256):
        if len(self.memory) < batch_size:
            return

        transitions = self.memory.sample(batch_size)
        transitions = Transition(*zip(*transitions))
        if self.cuda:
            states = torch.cat(transitions.state).to('cuda')
            actions = torch.stack(transitions.action).to('cuda')
            next_states = torch.cat(transitions.next_state).to('cuda')
            rewards = torch.stack(transitions.reward).to('cuda')
            terminals = torch.stack(transitions.terminal).to('cuda')
        else:
            states = torch.cat(transitions.state)
            actions = torch.stack(transitions.action)
            next_states = torch.cat(transitions.next_state)
            rewards = torch.stack(transitions.reward)
            terminals = torch.stack(transitions.terminal)
        # Compute target Q func
        q1_loss, q2_loss = self.calc_critic_loss(states, actions, next_states,
                                                 rewards, terminals)
        # Compute actor loss
        actor_loss, mean = self.calc_actor_loss(states)
        update_params(self.q1_optimizer, self.critic.q1, q1_loss)
        update_params(self.q2_optimizer, self.critic.q2, q2_loss)
        update_params(self.actor_optimizer, self.actor, actor_loss)
        # target update
        self.target_update(self.target_critic, self.critic)

        if timestep % 100 == 0 and self.writer:  # log every 100 timesteps
            self.writer.add_scalar('Loss/Actor', actor_loss.item(), timestep)
            self.writer.add_scalar('Loss/Critic', q1_loss.item(), timestep)

    def save_weights(self, path):
        self.actor.save(os.path.join(path, 'actor'))
        self.critic.save(os.path.join(path, 'critic'))
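The class above relies on an `update_params` helper and a `Transition` namedtuple that are defined elsewhere in the original module, and on a continuous-control environment. A minimal sketch of how those pieces might look, together with a driver loop; all of this is an assumption about the surrounding code, not the original implementation:

# Hypothetical helpers and driver for the SAC class above; everything here is an assumption.
from collections import namedtuple

# Field order matches memory.push(state, action, next_state, reward, terminal) above.
Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward', 'terminal'))

def update_params(optimizer, network, loss):
    # One optimizer step: clear gradients, backpropagate, apply the update.
    # `network` mirrors the call sites above; it could be used for gradient clipping.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

agent = SAC(env)                                    # `env` is assumed Gym-like; cuda=True by default
state = env.reset()
for timestep in range(1, 100001):
    action = agent.explore(state)
    next_state, reward, done, _ = env.step(action)  # classic Gym step signature assumed
    agent.store_step(state, action, next_state, reward, done)
    agent.train(timestep)
    state = env.reset() if done else next_state

agent.save_weights('./checkpoints')                 # calls Actor.save / Critic.save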