class TD3MultiAgent:
    def __init__(self):
      
        # Hyperparameters
        self.max_action = 1
        self.policy_freq = 2            # delayed policy updates (TD3)
        self.policy_freq_it = 0
        self.batch_size = 512
        self.discount = 0.99
        self.buffer_size = int(1e5)

        self.device = 'cuda'

        self.state_dim = 24
        self.action_dim = 2
        self.policy_noise = 0.1         # std of the target policy smoothing noise
        self.agents = 1

        self.random_period = 1e4        # warm-up: act randomly until the buffer holds this many transitions

        self.tau = 5e-3                 # soft-update coefficient

        self.replay_buffer = ReplayBuffer(self.buffer_size)
        
        self.actor = Actor(self.state_dim, self.action_dim, self.max_action).to(self.device)
        self.actor_target = Actor(self.state_dim, self.action_dim, self.max_action).to(self.device)
        self.actor_target.load_state_dict(self.actor.state_dict())
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=1e-4)
        # Optionally resume from a saved checkpoint:
        # self.actor.load_state_dict(torch.load('actor2.pth'))
        # self.actor_target.load_state_dict(torch.load('actor2.pth'))

        # Ornstein-Uhlenbeck exploration noise over the 2-dimensional action space
        self.noise = OUNoise(2, 32)
        
        
        self.critic = Critic(48, self.action_dim).to(self.device)
        self.critic_target = Critic(48, self.action_dim).to(self.device)
        self.critic_target.load_state_dict(self.critic.state_dict())
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=3e-4)

    
    def select_action_with_noise(self, state, i):
        """Return an action for agent i, acting randomly until the warm-up period is over."""
        if len(self.replay_buffer) > self.random_period:
            state = torch.FloatTensor(state[i, :]).to(self.device)
            action = self.actor(state).cpu().data.numpy()

            if self.policy_noise != 0:
                action = action + self.noise.sample()
            return action.clip(-self.max_action, self.max_action)
        else:
            # Warm-up: return a purely exploratory action from the OU process
            return self.noise.sample()
   
    
    def step(self, i):
        if len(self.replay_buffer) > self.random_period / 2:
            # Sample a mini-batch of joint transitions
            s, a, r, s_, d = self.replay_buffer.sample(self.batch_size)

            # Per-agent views used by the actor
            state = torch.FloatTensor(s[:, i, :]).to(self.device)
            action = torch.FloatTensor(a[:, i, :]).to(self.device)
            next_state = torch.FloatTensor(s_[:, i, :]).to(self.device)

            # Joint (concatenated) observations used by the centralised critic
            a_state = torch.FloatTensor(s).to(self.device).reshape(-1, 48)
            a_next_state = torch.FloatTensor(s_).to(self.device).reshape(-1, 48)

            done = torch.FloatTensor(1 - d[:, i]).to(self.device)
            reward = torch.FloatTensor(r[:, i]).to(self.device)

            # Select the next action with the target actor and apply clipped smoothing noise (TD3)
            noise = torch.FloatTensor(a[:, i, :]).data.normal_(0, self.policy_noise).to(self.device)
            noise = noise.clamp(-0.1, 0.1)  # noise clip
            next_action = (self.actor_target(next_state) + noise).clamp(-self.max_action, self.max_action)
            # Compute the target Q value

            target_Q1, target_Q2 = self.critic_target(a_next_state, next_action)
            target_Q = torch.min(target_Q1, target_Q2)
            target_Q = reward.reshape(-1,1) + (done.reshape(-1,1) * self.discount * target_Q).detach()
            
            # Get current Q estimates
            current_Q1, current_Q2 = self.critic(a_state, action)

            # Compute critic loss
            critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)
            # Optimize the critic
            self.critic_optimizer.zero_grad()
            critic_loss.backward()
            self.critic_optimizer.step()

            # Delayed policy updates
            if self.policy_freq_it % self.policy_freq == 0:
                # Compute actor loss
                actor_loss = -self.critic.Q1(a_state, self.actor(state)).mean()
                # Optimize the actor 
                self.actor_optimizer.zero_grad()
                actor_loss.backward()
                self.actor_optimizer.step()

                # Update the frozen target models
                for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
                    target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
                for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
                    target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)


            self.policy_freq_it += 1
        
        return True
        
    
    def reset(self):
        self.policy_freq_it = 0
        self.noise.reset()
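
A minimal usage sketch, assuming a two-agent environment in the style of Unity Tennis (24-dimensional observation per agent, stacked row-wise); the env object, its reset/step interface, and the replay buffer's add() method are assumptions for illustration, not part of the class above.

import numpy as np

# Hypothetical training loop; `env` and its API are assumptions.
agent = TD3MultiAgent()
num_agents = 2  # the critic above is built for 2 x 24 = 48 concatenated observations

for episode in range(1000):
    states = env.reset()                                   # assumed shape (2, 24)
    agent.reset()
    dones = [False] * num_agents
    while not any(dones):
        actions = np.vstack([agent.select_action_with_noise(states, i)
                             for i in range(num_agents)])
        next_states, rewards, dones = env.step(actions)    # assumed env API
        agent.replay_buffer.add(states, actions, rewards, next_states, dones)  # assumed buffer API
        for i in range(num_agents):
            agent.step(i)
        states = next_states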
Example #2
class DyNODESacAgent(object):
    """DyNODE-SAC."""
    def __init__(self,
                 obs_shape,
                 action_shape,
                 device,
                 model_kind,
                 kind='D',
                 step_MVE=5,
                 hidden_dim=256,
                 discount=0.99,
                 init_temperature=0.01,
                 alpha_lr=1e-3,
                 alpha_beta=0.9,
                 actor_lr=1e-3,
                 actor_beta=0.9,
                 actor_log_std_min=-10,
                 actor_log_std_max=2,
                 critic_lr=1e-3,
                 critic_beta=0.9,
                 critic_tau=0.005,
                 critic_target_update_freq=2,
                 model_lr=1e-3,
                 log_interval=100):

        self.device = device
        self.discount = discount
        self.critic_tau = critic_tau
        self.critic_target_update_freq = critic_target_update_freq
        self.log_interval = log_interval
        self.step_MVE = step_MVE
        self.model_kind = model_kind

        self.actor = Actor(obs_shape, action_shape, hidden_dim,
                           actor_log_std_min, actor_log_std_max).to(device)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(),
                                                lr=actor_lr,
                                                betas=(actor_beta, 0.999))

        self.critic = Critic(obs_shape, action_shape, hidden_dim).to(device)
        self.critic_target = Critic(obs_shape, action_shape,
                                    hidden_dim).to(device)
        self.critic_target.load_state_dict(self.critic.state_dict())
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(),
                                                 lr=critic_lr,
                                                 betas=(critic_beta, 0.999))

        self.log_alpha = torch.tensor(np.log(init_temperature)).to(device)
        self.log_alpha.requires_grad = True
        self.target_entropy = -np.prod(
            action_shape)  # set target entropy to -|A|
        self.log_alpha_optimizer = torch.optim.Adam([self.log_alpha],
                                                    lr=alpha_lr,
                                                    betas=(alpha_beta, 0.999))

        if self.model_kind == 'dynode_model':
            self.model = DyNODE(obs_shape,
                                action_shape,
                                hidden_dim_p=200,
                                hidden_dim_r=200).to(device)
        elif self.model_kind == 'nn_model':
            self.model = NN_Model(obs_shape,
                                  action_shape,
                                  hidden_dim_p=200,
                                  hidden_dim_r=200,
                                  kind=kind).to(device)
        else:
            raise ValueError(f'model kind {model_kind} is not supported')

        self.model_optimizer = torch.optim.Adam(self.model.parameters(),
                                                lr=model_lr)

        self.train()
        self.critic_target.train()

    def train(self, training=True):
        self.training = training
        self.actor.train(training)
        self.critic.train(training)
        self.model.train(training)

    @property
    def alpha(self):
        return self.log_alpha.exp()

    def select_action(self, obs):
        with torch.no_grad():
            obs = torch.FloatTensor(obs).to(self.device)
            obs = obs.unsqueeze(0)
            mu, _, _, _ = self.actor(obs,
                                     compute_pi=False,
                                     compute_log_pi=False)
            return mu.cpu().data.numpy().flatten()

    def sample_action(self, obs):
        with torch.no_grad():
            obs = torch.FloatTensor(obs).to(self.device)
            obs = obs.unsqueeze(0)
            mu, pi, _, _ = self.actor(obs, compute_log_pi=False)
            return pi.cpu().data.numpy().flatten()

    def update_model(self, replay_buffer, L, step):

        if self.model_kind == 'dynode_model':
            obs_m, action_m, reward_m, next_obs_m, _ = replay_buffer.sample_dynode(
            )
            transition_loss, reward_loss = self.model.loss(
                obs_m, action_m, reward_m, next_obs_m)
            model_loss = transition_loss + reward_loss
        elif self.model_kind == 'nn_model':
            obs, action, reward, next_obs, _ = replay_buffer.sample()
            transition_loss, reward_loss = self.model.loss(
                obs, action, reward, next_obs)
            model_loss = transition_loss + reward_loss
        else:
            raise ValueError(f'model kind {self.model_kind} is not supported')

        # Optimize the Model
        self.model_optimizer.zero_grad()
        model_loss.backward()
        self.model_optimizer.step()

        if step % self.log_interval == 0:
            L.log('train/model_loss', model_loss, step)

    def MVE_prediction(self, replay_buffer, L, step):

        obs, action, reward, next_obs, not_done = replay_buffer.sample()

        trajectory = []
        next_ob = next_obs
        with torch.no_grad():
            while len(trajectory) < self.step_MVE:
                ob = next_ob
                _, act, _, _ = self.actor(ob)
                rew, next_ob = self.model(ob, act)
                trajectory.append([ob, act, rew, next_ob])

            _, next_action, log_pi, _ = self.actor(next_ob)
            target_Q1, target_Q2 = self.critic_target(next_ob, next_action)
            ret = torch.min(target_Q1,
                            target_Q2) - self.alpha.detach() * log_pi

        critic_loss = 0
        for ob, act, rew, _ in reversed(trajectory):
            current_Q1, current_Q2 = self.critic(ob, act)
            ret = rew + self.discount * ret
            # critic_loss = critic_loss + utils.huber(current_Q1 - ret).mean() + utils.huber(current_Q2 - ret).mean()
            critic_loss = critic_loss + F.mse_loss(
                current_Q1, ret) + F.mse_loss(current_Q2, ret)
        current_Q1, current_Q2 = self.critic(obs, action)
        ret = reward + self.discount * ret
        # critic_loss = critic_loss + utils.huber(current_Q1 - ret).mean() + utils.huber(current_Q2 - ret).mean()
        critic_loss = critic_loss + F.mse_loss(current_Q1, ret) + F.mse_loss(
            current_Q2, ret)
        critic_loss = critic_loss / (self.step_MVE + 1)

        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # actor
        _, pi, log_pi, log_std = self.actor(obs)
        actor_Q1, actor_Q2 = self.critic(obs.detach(), pi)
        actor_Q = torch.min(actor_Q1, actor_Q2)
        actor_loss = (self.alpha.detach() * log_pi - actor_Q).mean()

        # optimize the actor
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        self.log_alpha_optimizer.zero_grad()
        alpha_loss = (self.alpha *
                      (-log_pi - self.target_entropy).detach()).mean()
        alpha_loss.backward()
        self.log_alpha_optimizer.step()

    def update_critic(self, obs, action, reward, next_obs, not_done, L, step):
        with torch.no_grad():
            _, policy_action, log_pi, _ = self.actor(next_obs)
            target_Q1, target_Q2 = self.critic_target(next_obs, policy_action)
            target_V = torch.min(target_Q1,
                                 target_Q2) - self.alpha.detach() * log_pi
            target_Q = reward + (not_done * self.discount * target_V)

        # get current Q estimates
        current_Q1, current_Q2 = self.critic(obs, action)
        critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(
            current_Q2, target_Q)
        if step % self.log_interval == 0:
            L.log('train_critic/loss', critic_loss, step)

        # Optimize the critic
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        self.critic.log(L, step)

    def update_actor_and_alpha(self, obs, L, step):
        _, pi, log_pi, log_std = self.actor(obs)
        actor_Q1, actor_Q2 = self.critic(obs, pi)

        actor_Q = torch.min(actor_Q1, actor_Q2)
        actor_loss = (self.alpha.detach() * log_pi - actor_Q).mean()

        if step % self.log_interval == 0:
            L.log('train_actor/loss', actor_loss, step)
            L.log('train_actor/target_entropy', self.target_entropy, step)
        entropy = 0.5 * log_std.shape[1] * (
            1.0 + np.log(2 * np.pi)) + log_std.sum(dim=-1)
        if step % self.log_interval == 0:
            L.log('train_actor/entropy', entropy.mean(), step)

        # optimize the actor
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        self.actor.log(L, step)

        self.log_alpha_optimizer.zero_grad()
        alpha_loss = (self.alpha *
                      (-log_pi - self.target_entropy).detach()).mean()
        if step % self.log_interval == 0:
            L.log('train_alpha/loss', alpha_loss, step)
            L.log('train_alpha/value', self.alpha, step)
        alpha_loss.backward()
        self.log_alpha_optimizer.step()

    def update(self, replay_buffer, L, step):

        if step < 2000:
            for _ in range(2):
                obs, action, reward, next_obs, not_done = replay_buffer.sample(
                )
                self.update_critic(obs, action, reward, next_obs, not_done, L,
                                   step)
                self.update_actor_and_alpha(obs, L, step)

            if step % self.log_interval == 0:
                L.log('train/batch_reward', reward.mean(), step)

        else:
            obs, action, reward, next_obs, not_done = replay_buffer.sample()

            if step % self.log_interval == 0:
                L.log('train/batch_reward', reward.mean(), step)

            self.MVE_prediction(replay_buffer, L, step)
            self.update_critic(obs, action, reward, next_obs, not_done, L,
                               step)
            self.update_actor_and_alpha(obs, L, step)

        if step % self.critic_target_update_freq == 0:
            utils.soft_update_params(self.critic.Q1, self.critic_target.Q1,
                                     self.critic_tau)
            utils.soft_update_params(self.critic.Q2, self.critic_target.Q2,
                                     self.critic_tau)

    def save(self, model_dir, step):
        torch.save(self.actor.state_dict(),
                   '%s/actor_%s.pt' % (model_dir, step))
        torch.save(self.critic.state_dict(),
                   '%s/critic_%s.pt' % (model_dir, step))

    def save_model(self, model_dir, step):
        torch.save(self.model.state_dict(),
                   '%s/model_%s.pt' % (model_dir, step))

    def load(self, model_dir, step):
        self.actor.load_state_dict(
            torch.load('%s/actor_%s.pt' % (model_dir, step)))
        self.critic.load_state_dict(
            torch.load('%s/critic_%s.pt' % (model_dir, step)))
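
For context, a sketch of the per-step call pattern these methods expect, assuming an environment, a replay buffer exposing the add()/sample()/sample_dynode() interface used above, and a logger L with a log(key, value, step) method (all of which come from the surrounding project and are assumptions here); obs_shape, action_shape and num_train_steps are placeholders.

# Hypothetical outer loop; env, replay_buffer and L are assumed to exist.
agent = DyNODESacAgent(obs_shape, action_shape, device='cuda', model_kind='nn_model')

obs = env.reset()
for step in range(num_train_steps):
    action = agent.sample_action(obs)                       # stochastic action for exploration
    next_obs, reward, done, _ = env.step(action)
    replay_buffer.add(obs, action, reward, next_obs, done)  # assumed buffer API
    agent.update_model(replay_buffer, L, step)              # fit the dynamics/reward model
    agent.update(replay_buffer, L, step)                    # SAC update (MVE rollouts after step 2000)
    obs = env.reset() if done else next_obs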
Example #3
class Agent():
    """ Interacts with and learns from the environment. """
    def __init__(self, state_size, action_size, fc1_units, fc2_units):
        """Initialize an Agent object.

        Params
        ======
        state_size (int): dimension of each state
        action_size (int): dimension of each action
        fc1_units (int): number of units in the first hidden layer
        fc2_units (int): number of units in the second hidden layer
        """
        self.state_size = state_size
        self.action_size = action_size
        self.seed = torch.manual_seed(SEED)

        # Actor Network (w/ Target Network)
        self.actor_local = Actor(state_size, action_size, fc1_units,
                                 fc2_units).to(device)
        self.actor_target = Actor(state_size, action_size, fc1_units,
                                  fc2_units).to(device)
        self.actor_optimizer = optim.Adam(self.actor_local.parameters(),
                                          lr=LR_ACTOR)

        # Critic Network (w/ Target Network)
        self.critic_local = Critic(state_size, action_size, fc1_units,
                                   fc2_units).to(device)
        self.critic_target = Critic(state_size, action_size, fc1_units,
                                    fc2_units).to(device)
        self.critic_optimizer = optim.Adam(self.critic_local.parameters(),
                                           lr=LR_CRITIC,
                                           weight_decay=WEIGHT_DECAY)

        # Noise process
        self.noise = OrnsteinUhlenbeck(action_size, SEED)

        # Replay memory
        self.memory = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE, SEED,
                                   device)

    def step(self, time_step, state, action, reward, next_state, done):
        """Save experience in replay memory, and use random sample from buffer to learn."""
        self.memory.add(state, action, reward, next_state, done)

        # Learn only every N_TIME_STEPS
        if time_step % N_TIME_STEPS != 0:
            return

        # Learn if enough samples are available in replay buffer
        if len(self.memory) > BATCH_SIZE:
            for i in range(N_LEARN_UPDATES):
                experiences = self.memory.sample()
                self.learn(experiences, GAMMA)

    def act(self, state, add_noise=True):
        """ Returns actions for given state as per current policy. """
        state = torch.from_numpy(state).float().to(device)
        self.actor_local.eval()
        with torch.no_grad():
            action = self.actor_local(state).cpu().data.numpy()
        self.actor_local.train()
        if add_noise:
            action += self.noise.sample()
        return np.clip(action, -1, 1)

    def reset(self):
        self.noise.reset()

    def learn(self, experiences, gamma):
        """Update policy and value parameters using given batch of experience tuples.
        Q_targets = r + γ * critic_target(next_state, actor_target(next_state))
        where:
            actor_target(state) -> action
            critic_target(state, action) -> Q-value
        Params
        ======
            experiences (Tuple[torch.Tensor]): tuple of (s, a, r, s', done) tuples
            gamma (float): discount factor
        """
        states, actions, rewards, next_states, dones = experiences

        # ---------------------------- update critic ---------------------------- #
        # Get predicted next-state actions and Q values from target models
        actions_next = self.actor_target(next_states)
        Q_targets_next = self.critic_target(next_states, actions_next)
        # Compute Q targets from current states (y_i)
        Q_targets = rewards + (gamma * Q_targets_next * (1 - dones))
        # Compute critic loss
        Q_expected = self.critic_local(states, actions)
        critic_loss = F.mse_loss(Q_expected, Q_targets)
        # Minimize the loss
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.critic_local.parameters(), 1)
        self.critic_optimizer.step()

        # ---------------------------- update actor ---------------------------- #
        # Compute actor loss
        actions_pred = self.actor_local(states)
        actor_loss = -self.critic_local(states, actions_pred).mean()
        # Minimize the loss
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        # ----------------------- update target networks ----------------------- #
        self.soft_update(self.critic_local, self.critic_target, TAU)
        self.soft_update(self.actor_local, self.actor_target, TAU)

    def soft_update(self, local_model, target_model, tau):
        """Soft update model parameters.
        θ_target = τ*θ_local + (1 - τ)*θ_target
        Params
        ======
        local_model: PyTorch model (weights will be copied from)
        target_model: PyTorch model (weights will be copied to)
        tau (float): interpolation parameter
        """
        for target_param, local_param in zip(target_model.parameters(),
                                             local_model.parameters()):
            target_param.data.copy_(tau * local_param.data +
                                    (1.0 - tau) * target_param.data)

    def store(self):
        torch.save(self.actor_local.state_dict(), 'checkpoint_actor.pth')
        torch.save(self.critic_local.state_dict(), 'checkpoint_critic.pth')

    def load(self):
        if os.path.isfile('checkpoint_actor.pth') and os.path.isfile(
                'checkpoint_critic.pth'):
            print("=> loading checkpoints for Actor and Critic... ")
            self.actor_local.load_state_dict(torch.load('checkpoint_actor.pth'))
            self.critic_local.load_state_dict(torch.load('checkpoint_critic.pth'))
            print("done!")
        else:
            print("no checkpoints found for Actor and Critic...")
Example #4
class A2C():
    def __init__(self, state_dim, action_dim, action_lim, update_type='soft',
                lr_actor=1e-4, lr_critic=1e-3, tau=1e-3,
                mem_size=1e6, batch_size=256, gamma=0.99,
                other_cars=False, ego_dim=None):
        self.device = torch.device("cuda:0" if torch.cuda.is_available()
                                        else "cpu")

        self.joint_model = False
        if len(state_dim) == 3:
            self.model = ActorCriticCNN(state_dim, action_dim, action_lim)
            self.model_optim = optim.Adam(self.model.parameters(), lr=lr_actor)

            self.target_model = ActorCriticCNN(state_dim, action_dim, action_lim)
            self.target_model.load_state_dict(self.model.state_dict())

            self.model.to(self.device)
            self.target_model.to(self.device)

            self.joint_model = True
        else:
            self.actor = Actor(state_dim, action_dim, action_lim, other_cars=other_cars, ego_dim=ego_dim)
            self.actor_optim = optim.Adam(self.actor.parameters(), lr=lr_actor)
            self.target_actor = Actor(state_dim, action_dim, action_lim, other_cars=other_cars, ego_dim=ego_dim)
            self.target_actor.load_state_dict(self.actor.state_dict())
            self.target_actor.eval()

            self.critic = Critic(state_dim, action_dim, other_cars=other_cars, ego_dim=ego_dim)
            self.critic_optim = optim.Adam(self.critic.parameters(), lr=lr_critic, weight_decay=1e-2)
            self.target_critic = Critic(state_dim, action_dim, other_cars=other_cars, ego_dim=ego_dim)
            self.target_critic.load_state_dict(self.critic.state_dict())
            self.target_critic.eval()

            self.actor.to(self.device)
            self.target_actor.to(self.device)
            self.critic.to(self.device)
            self.target_critic.to(self.device)

        self.action_lim = action_lim
        self.tau = tau # hard update if tau is None
        self.update_type = update_type
        self.batch_size = batch_size
        self.gamma = gamma

        if self.joint_model:
            mem_size = mem_size//100
        self.memory = Memory(int(mem_size), action_dim, state_dim)

        mu = np.zeros(action_dim)
        sigma = np.array([0.5, 0.05])
        self.noise = OrnsteinUhlenbeckActionNoise(mu, sigma)
        self.target_noise = OrnsteinUhlenbeckActionNoise(mu, sigma)

        self.initialised = True
        self.training = False

    def select_action(self, obs):
        with torch.no_grad():
            obs = torch.FloatTensor(np.expand_dims(obs, axis=0)).to(self.device)
            if self.joint_model:
                action, _ = self.model(obs)
                action = action.data.cpu().numpy().flatten()
            else:
                action = self.actor(obs).data.cpu().numpy().flatten()

        if self.training:
            action += self.noise()
            return action
        else:
            return action

    def append(self, obs0, action, reward, obs1, terminal1):
        self.memory.append(obs0, action, reward, obs1, terminal1)

    def reset_noise(self):
        self.noise.reset()
        self.target_noise.reset()

    def train(self):
        if self.joint_model:
            self.model.train()
            self.target_model.train()
        else:
            self.actor.train()
            self.target_actor.train()
            self.critic.train()
            self.target_critic.train()

        self.training = True

    def eval(self):
        if self.joint_model:
            self.model.eval()
            self.target_model.eval()
        else:
            self.actor.eval()
            self.target_actor.eval()
            self.critic.eval()
            self.target_critic.eval()

        self.training = False

    def save(self, folder, episode, previous=None, solved=False):
        filename = lambda type, ep : folder + '%s' % type + \
                                    (not solved) * ('_ep%d' % (ep)) + \
                                    (solved * '_solved') + '.pth'

        if self.joint_model:
            torch.save(self.model.state_dict(), filename('model', episode))
            torch.save(self.target_model.state_dict(), filename('target_model', episode))
        else:
            torch.save(self.actor.state_dict(), filename('actor', episode))
            torch.save(self.target_actor.state_dict(), filename('target_actor', episode))

            torch.save(self.critic.state_dict(), filename('critic', episode))
            torch.save(self.target_critic.state_dict(), filename('target_critic', episode))

        if previous is not None and previous > 0:
            if self.joint_model:
                os.remove(filename('model', previous))
                os.remove(filename('target_model', previous))
            else:
                os.remove(filename('actor', previous))
                os.remove(filename('target_actor', previous))
                os.remove(filename('critic', previous))
                os.remove(filename('target_critic', previous))

    def load_actor(self, actor_filepath):
        qualifier = '_' + actor_filepath.split("_")[-1]
        folder = actor_filepath[:actor_filepath.rfind("/")+1]
        filename = lambda type : folder + '%s' % type + qualifier

        if self.joint_model:
            self.model.load_state_dict(torch.load(filename('model'),
                                                    map_location=self.device))
            self.target_model.load_state_dict(torch.load(filename('target_model'),
                                                    map_location=self.device))
        else:
            self.actor.load_state_dict(torch.load(filename('actor'),
                                                    map_location=self.device))
            self.target_actor.load_state_dict(torch.load(filename('target_actor'),
                                                    map_location=self.device))

    def load_all(self, actor_filepath):
        self.load_actor(actor_filepath)
        qualifier = '_' + actor_filepath.split("_")[-1]
        folder = actor_filepath[:actor_filepath.rfind("/")+1]
        filename = lambda type : folder + '%s' % type + qualifier

        if not self.joint_model:
            self.critic.load_state_dict(torch.load(filename('critic'),
                                                    map_location=self.device))
            self.target_critic.load_state_dict(torch.load(filename('target_critic'),
                                                    map_location=self.device))

    def update(self, target_noise=True):
        try:
            minibatch = self.memory.sample(self.batch_size) # dict of ndarrays
        except ValueError:
            print('Replay memory not yet large enough; skipping update.')
            return None, None

        # torch.autograd.Variable is deprecated; plain tensors are sufficient here
        states = torch.FloatTensor(minibatch['obs0']).to(self.device)
        actions = torch.FloatTensor(minibatch['actions']).to(self.device)
        rewards = torch.FloatTensor(minibatch['rewards']).to(self.device)
        next_states = torch.FloatTensor(minibatch['obs1']).to(self.device)
        terminals = torch.FloatTensor(minibatch['terminals1']).to(self.device)

        if self.joint_model:
            target_actions, _ = self.target_model(next_states)
            if target_noise:
                for sample in range(target_actions.shape[0]):
                    target_actions[sample] += self.target_noise()
                    target_actions[sample].clamp_(-self.action_lim, self.action_lim)  # clamp in place
            _, target_qvals = self.target_model(next_states, target_actions=target_actions)
            y = rewards + self.gamma * (1 - terminals) * target_qvals

            _, model_qvals = self.model(states, target_actions=actions)
            value_loss = F.mse_loss(y, model_qvals)
            model_actions, _ = self.model(states)
            _, model_qvals = self.model(states, target_actions=model_actions)
            action_loss = -model_qvals.mean()

            self.model_optim.zero_grad()
            (value_loss + action_loss).backward()
            self.model_optim.step()
        else:
            target_actions = self.target_actor(next_states)
            if target_noise:
                for sample in range(target_actions.shape[0]):
                    target_actions[sample] += self.target_noise()
                    target_actions[sample].clamp_(-self.action_lim, self.action_lim)  # clamp in place
            target_critic_qvals = self.target_critic(next_states, target_actions)
            y = rewards + self.gamma * (1 - terminals) * target_critic_qvals

            # optimise critic
            critic_qvals = self.critic(states, actions)
            value_loss = F.mse_loss(y, critic_qvals)
            self.critic_optim.zero_grad()
            value_loss.backward()
            self.critic_optim.step()

            # optimise actor
            action_loss = -self.critic(states, self.actor(states)).mean()
            self.actor_optim.zero_grad()
            action_loss.backward()
            self.actor_optim.step()

        # optimise target networks
        if self.update_type == 'soft':
            if self.joint_model:
                soft_update(self.target_model, self.model, self.tau)
            else:
                soft_update(self.target_actor, self.actor, self.tau)
                soft_update(self.target_critic, self.critic, self.tau)
        else:
            if self.joint_model:
                hard_update(self.target_model, self.model)
            else:
                hard_update(self.target_actor, self.actor)
                hard_update(self.target_critic, self.critic)

        return action_loss.item(), value_loss.item()
Example #5
class DDPG():
    """ Deep Deterministic Policy Gradients Agent used to interaction with and learn from an environment """
    def __init__(self, state_size: int, action_size: int, num_agents: int,
                 epsilon, random_seed: int):
        """ Initialize a DDPG Agent Object

        :param state_size: dimension of state (input)
        :param action_size: dimension of action (output)
        :param num_agents: number of concurrent agents in the environment
        :param epsilon: initial value of epsilon for exploration
        :param random_seed: random seed
        """
        self.state_size = state_size
        self.action_size = action_size
        self.num_agents = num_agents
        self.seed = random.seed(random_seed)
        self.device = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu")
        self.t_step = 0

        # Hyperparameters
        self.buffer_size = 1000000
        self.batch_size = 128
        self.update_every = 10
        self.num_updates = 10
        self.gamma = 0.99
        self.tau = 0.001
        self.lr_actor = 0.0001
        self.lr_critic = 0.001
        self.weight_decay = 0
        self.epsilon = epsilon
        self.epsilon_decay = 0.97
        self.epsilon_min = 0.005

        # Networks (Actor: State -> Action, Critic: (State,Action) -> Value)
        self.actor_local = Actor(self.state_size, self.action_size,
                                 random_seed).to(self.device)
        self.actor_target = Actor(self.state_size, self.action_size,
                                  random_seed).to(self.device)
        self.actor_optimizer = optim.Adam(self.actor_local.parameters(),
                                          lr=self.lr_actor)
        self.critic_local = Critic(self.state_size, self.action_size,
                                   random_seed).to(self.device)
        self.critic_target = Critic(self.state_size, self.action_size,
                                    random_seed).to(self.device)
        self.critic_optimizer = optim.Adam(self.critic_local.parameters(),
                                           lr=self.lr_critic,
                                           weight_decay=self.weight_decay)
        # Initialize actor and critic networks to start with same parameters
        self.soft_update(self.actor_local, self.actor_target, tau=1)
        self.soft_update(self.critic_local, self.critic_target, tau=1)

        # Noise Setup
        self.noise = OUNoise(self.action_size, random_seed)

        # Replay Buffer Setup
        self.memory = ReplayBuffer(self.buffer_size, self.batch_size)

    def __str__(self):
        return "DDPG_Agent"

    def train(self,
              env,
              brain_name,
              num_episodes=200,
              max_time=1000,
              print_every=10):
        """ Interacts with and learns from a given Unity Environment

        :param env: Unity Environment the agents is trying to learn
        :param brain_name: Brain for Environment
        :param num_episodes: Number of episodes to train
        :param max_time: How long each episode runs for
        :param print_every: How often in episodes to print a running average
        :return: Returns episodes scores and 100 episode averages as lists
        """
        # --------- Set Everything up --------#
        scores = []
        avg_scores = []
        scores_deque = deque(maxlen=print_every)

        # -------- Simulation Loop --------#
        for episode_num in range(1, num_episodes + 1):
            # Reset everything
            env_info = env.reset(train_mode=True)[brain_name]
            states = env_info.vector_observations
            episode_scores = np.zeros(self.num_agents)
            self.reset_noise()
            # Run the episode
            for t in range(max_time):
                actions = self.act(states, self.epsilon)
                env_info = env.step(actions)[brain_name]
                next_states, rewards, dones = env_info.vector_observations, env_info.rewards, env_info.local_done
                self.step(states, actions, rewards, next_states, dones)
                episode_scores += rewards
                states = next_states
                if np.any(dones):
                    break

            # -------- Episode Finished ---------#
            self.epsilon *= self.epsilon_decay
            self.epsilon = max(self.epsilon, self.epsilon_min)
            scores.append(np.mean(episode_scores))
            scores_deque.append(np.mean(episode_scores))
            avg_scores.append(np.mean(scores_deque))
            if episode_num % print_every == 0:
                print(
                    f'Episode: {episode_num} \tAverage Score: {round(np.mean(scores_deque), 2)}'
                )
                # forward slashes keep the paths portable and avoid invalid escape sequences
                torch.save(self.actor_local.state_dict(),
                           f'{PATH}/checkpoints/{self}_Actor_Multiple.pth')
                torch.save(self.critic_local.state_dict(),
                           f'{PATH}/checkpoints/{self}_Critic_Multiple.pth')

        # -------- All episodes finished: save parameters and scores --------#
        # Save model parameters
        torch.save(self.actor_local.state_dict(),
                   f'{PATH}/checkpoints/{self}_Actor_Multiple.pth')
        torch.save(self.critic_local.state_dict(),
                   f'{PATH}/checkpoints/{self}_Critic_Multiple.pth')
        # Save the mean score per episode (averaged over the agents)
        with open(f'{PATH}/scores/{self}_Multiple_Scores.txt', 'w') as f:
            f.write("\n".join(str(score) for score in scores))
        # Save the running-window average scores
        with open(f'{PATH}/scores/{self}_Multiple_AvgScores.txt', 'w') as f:
            f.write("\n".join(str(score) for score in avg_scores))
        return scores, avg_scores

    def step(self, states, actions, rewards, next_states, dones):
        """ what the agent needs to do for every time step that occurs in the environment. Takes
        in a (s,a,r,s',d) tuple and saves it to memeory and learns from experiences. Note: this is not
        the same as a step in the environment. Step is only called once per environment time step.

        :param states: array of states agent used to select actions
        :param actions: array of actions taken by agents
        :param rewards: array of rewards for last action taken in environment
        :param next_states: array of next states after actions were taken
        :param dones: array of bools representing if environment is finished or not
        """
        # Save experienced in replay memory
        for agent_num in range(self.num_agents):
            self.memory.add(states[agent_num], actions[agent_num],
                            rewards[agent_num], next_states[agent_num],
                            dones[agent_num])

        # Learn "num_updates" times every "update_every" time step
        self.t_step += 1
        if len(self.memory
               ) > self.batch_size and self.t_step % self.update_every == 0:
            self.t_step = 0
            for _ in range(self.num_updates):
                experiences = self.memory.sample()
                self.learn(experiences)

    def act(self, states, epsilon, add_noise=True):
        """ Returns actions for given states as per current policy. Policy comes from the actor network.

        :param states: array of states from the environment
        :param epsilon: probability of exploration
        :param add_noise: bool on whether or not to potentially have exploration for action
        :return: clipped actions
        """
        states = torch.from_numpy(states).float().to(self.device)
        self.actor_local.eval()  # Sets to eval mode (no gradients)
        with torch.no_grad():
            actions = self.actor_local(states).cpu().data.numpy()
        self.actor_local.train()  # Sets to train mode (gradients back on)
        if add_noise and epsilon > np.random.random():
            actions += [self.noise.sample() for _ in range(self.num_agents)]
        return np.clip(actions, -1, 1)

    def reset_noise(self):
        """ resets to noise parameters """
        self.noise.reset()

    def learn(self, experiences):
        """ Update actor and critic networks using a given batch of experiences
        Q_targets = r + γ * critic_target(next_state, actor_target(next_state))
        where:
            actor_target(states) -> actions
            critic_target(states, actions) -> Q-value
        :param experiences: tuple of arrays (states, actions, rewards, next_states, dones)  sampled from the replay buffer
        """

        states, actions, rewards, next_states, dones = experiences
        # -------------------- Update Critic -------------------- #
        # Use target networks for getting next actions and q values and calculate q_targets
        next_actions = self.actor_target(next_states)
        next_q_targets = self.critic_target(next_states, next_actions)
        q_targets = rewards + (self.gamma * next_q_targets * (1 - dones))
        # Compute critic loss (Same as DQN Loss)
        q_expected = self.critic_local(states, actions)
        critic_loss = F.mse_loss(q_expected, q_targets)
        # Minimize loss
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # -------------------- Update Actor --------------------- #
        # Computer actor loss (maximize mean of Q(states,actions))
        action_preds = self.actor_local(states)
        # Optimizer minimizes and we want to maximize so multiply by -1
        actor_loss = -1 * self.critic_local(states, action_preds).mean()
        # Minimize the loss
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        #---------------- Update Target Networks ---------------- #
        self.soft_update(self.critic_local, self.critic_target, self.tau)
        self.soft_update(self.actor_local, self.actor_target, self.tau)

    def soft_update(self, local_network, target_network, tau):
        """ soft update newtwork parametes
        θ_target = τ*θ_local + (1 - τ)*θ_target

        :param local_network: PyTorch Network that is always up to date
        :param target_network: PyTorch Network that is not up to date
        :param tau: update (interpolation) parameter
        """
        for target_param, local_param in zip(target_network.parameters(),
                                             local_network.parameters()):
            target_param.data.copy_(tau * local_param.data +
                                    (1.0 - tau) * target_param.data)
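
The intended entry point is the train() method above; a minimal sketch, assuming a Unity ML-Agents Reacher-style environment with 20 agents loaded elsewhere (env, brain_name, PATH, and the dimensions below are assumptions).

# Hypothetical: env and brain_name come from a Unity ML-Agents setup.
agent = DDPG(state_size=33, action_size=4, num_agents=20, epsilon=1.0, random_seed=0)
scores, avg_scores = agent.train(env, brain_name, num_episodes=200, max_time=1000, print_every=10)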
Example #6
class WGAN(object):
    def __init__(self, image_size, input_channels, hidden_channels, output_channels, latent_dimension, lr, device, clamp=0.01, gp_weight=10):
        self.image_size = image_size
        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        self.output_channels = output_channels
        self.latent_dimension = latent_dimension
        self.device = device
        self.clamp = clamp
        self.gp_weight = gp_weight

        self.critic = Critic(image_size, hidden_channels,
                             input_channels).to(device)
        self.generator = Generator(
            image_size, latent_dimension, hidden_channels, output_channels).to(device)

        self.critic.apply(self.weights_init)
        self.generator.apply(self.weights_init)

        # Adam with betas=(0, 0.9), as recommended for WGAN-GP
        # (the original weight-clipping WGAN would use RMSprop instead)
        self.optimizer_critic = torch.optim.Adam(
            self.critic.parameters(), lr, betas=(0, 0.9))
        self.optimizer_gen = torch.optim.Adam(
            self.generator.parameters(), lr, betas=(0, 0.9))

        self.critic_losses = []
        self.gen_losses = []

        self.losses = []

    def critique(self, x):
        return self.critic(x)

    def generate(self, z):
        return self.generator(z)

    def load_model(self, load_name):
        self.critic.load_state_dict(torch.load(
            load_name + '_critic', map_location='cpu'))
        self.generator.load_state_dict(torch.load(
            load_name + '_generator', map_location='cpu'))

    def _gradient_penalty(self, real_data, generated_data):
        batch_size = real_data.size()[0]

        # Random interpolation points between real and generated samples
        alpha = torch.rand(batch_size, 1, 1, 1, device=real_data.device)
        alpha = alpha.expand_as(real_data)

        interpolated = alpha * real_data.detach() + (1 - alpha) * generated_data.detach()
        interpolated.requires_grad_(True)

        # Critic scores at the interpolated points
        prob_interpolated = self.critique(interpolated)

        # Calculate gradients of probabilities with respect to examples
        gradients = torch_grad(outputs=prob_interpolated, inputs=interpolated,
                               grad_outputs=torch.ones(prob_interpolated.size()).to(self.device), create_graph=True, retain_graph=True)[0]

        # Gradients have shape (batch_size, num_channels, img_width, img_height),
        # so flatten to easily take norm per example in batch
        gradients = gradients.view(batch_size, -1)

        # Derivatives of the gradient close to 0 can cause problems because of
        # the square root, so manually calculate norm and add epsilon
        gradients_norm = torch.sqrt(torch.sum(gradients ** 2, dim=1) + 1e-12)

        # Return gradient penalty
        return self.gp_weight * ((gradients_norm - 1) ** 2).mean()

    def weights_init(self, m):
        if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
            nn.init.normal_(m.weight.data, 0.0, 0.02)

    def normalise_pixels(self, images):
        # map pixel values from [0, 1] to [-1, 1]
        normalised_images = (images - 0.5) / 0.5
        return normalised_images

    def denormalise_pixels(self, images):
        denormalised_images = (images * 0.5) + 0.5
        #denormalised_images = np.clip(denormalised_images, 0, 1)
        return denormalised_images

    def train(self, data_loader, num_epochs, n_iter=5, log_iter=1, test_iter=1, test_latents=None, save_name=None, load_name=None):
        if load_name is not None:
            self.load_model(load_name)

        d_err = 0
        g_err = 0
        for epoch in range(num_epochs):
            if epoch % test_iter == 0 and test_latents is not None:
                images = self.generate(test_latents)
                self.display_images(images)

            critic_iters = n_iter

            j = 0  # counts critic updates since the last generator update
            for i, batch in enumerate(data_loader):
                real_images, _ = batch
                real_images = self.normalise_pixels(real_images)

                real_images = real_images.to(self.device)
                d_real = self.critique(real_images)

                self.optimizer_critic.zero_grad()
                # Generate a batch of latents
                latent_zs = generate_latent(
                    self.latent_dimension, len(real_images)).to(self.device)

                # Transform latents into images using the generator
                fake_images = self.generate(latent_zs).to(self.device)

                # Score the fakes with the critic (detached so no generator gradients flow here)
                d_fake = self.critique(fake_images.detach())

                gradient_penalty = self._gradient_penalty(
                    real_images, fake_images)

                d_err = -(torch.mean(d_real) - torch.mean(d_fake)) + \
                    gradient_penalty
                d_err.backward()

                # Gradient descent step for discriminator
                self.optimizer_critic.step()

                # Weight clipping is only needed for the original WGAN; the gradient
                # penalty above replaces it, so it is left disabled:
                # for p in self.critic.parameters():
                #     p.data.clamp_(-self.clamp, self.clamp)

                j += 1

                if j % critic_iters == 0:
                    ### Generator ###
                    self.optimizer_gen.zero_grad()

                    # Generate a batch of latents
                    latent_zs = generate_latent(
                        self.latent_dimension, len(real_images)).to(self.device)
                    # Transform latents into images using the generator
                    fake_images = self.generate(latent_zs).to(self.device)
                    d_fake = self.critique(fake_images)
                    g_err = -torch.mean(d_fake)
                    g_err.backward()

                    self.optimizer_gen.step()

                    j = 0

            if epoch % log_iter == 0:
                # store plain floats so the pickled histories do not hold on to autograd graphs
                self.critic_losses.append(float(d_err))
                self.gen_losses.append(float(g_err))
                print('Epoch %d: Wasserstein Loss: %.4f\t Generator Loss: %.4f' % (
                    epoch, -d_err, g_err))

                if save_name is not None:
                    torch.save(self.critic.state_dict(), save_name + '_critic')
                    torch.save(self.generator.state_dict(),
                               save_name + '_generator')

                    with open(save_name + '_critic_losses.pkl', 'wb') as f:
                        pickle.dump(self.critic_losses, f)

                    with open(save_name + '_gen_losses.pkl', 'wb') as f:
                        pickle.dump(self.gen_losses, f)

    def display_images(self, images):
        k = 0

        fig, ax = plt.subplots(1, 5)

        for i in range(5):
            image = self.denormalise_pixels(
                images[k]).cpu().detach().permute(1, 2, 0)

            ax[i].imshow(image)

            ax[i].spines["top"].set_visible(False)
            ax[i].spines["right"].set_visible(False)
            ax[i].spines["bottom"].set_visible(False)
            ax[i].spines["left"].set_visible(False)
            ax[i].set_xticks([])
            ax[i].set_yticks([])

            k += 1

        fig.set_figheight(20)
        fig.set_figwidth(20)

        plt.show()
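
A minimal sketch of wiring the trainer to a dataset; the Generator, Critic, and generate_latent definitions come from the surrounding project, and the MNIST choice and hyperparameters below are illustrative assumptions.

# Hypothetical setup; dataset choice and hyperparameters are assumptions.
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

device = 'cuda' if torch.cuda.is_available() else 'cpu'
data = datasets.MNIST(root='.', train=True, download=True,
                      transform=transforms.ToTensor())
loader = DataLoader(data, batch_size=64, shuffle=True)

wgan = WGAN(image_size=28, input_channels=1, hidden_channels=64,
            output_channels=1, latent_dimension=100, lr=1e-4, device=device)

# Five fixed latents so the periodic previews (display_images shows 5 panels) stay comparable
test_latents = generate_latent(100, 5).to(device)
wgan.train(loader, num_epochs=50, n_iter=5, test_latents=test_latents, save_name='mnist_wgan')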