Example #1
0
class DDPG(object):
    """DDPG agent wiring an actor, a critic, their target copies and a replay memory.

    ``Actor``, ``Critic``, ``PrioritizedReplayMemory`` and ``OUProcess`` are
    defined elsewhere in the project; this class only composes them.
    """

    def __init__(self,
                 n_states,
                 n_actions,
                 model_name='',
                 alr=0.001,
                 clr=0.001,
                 gamma=0.9,
                 batch_size=32,
                 tau=0.002,
                 shift=0,
                 memory_size=100000,
                 a_hidden_sizes=None,
                 c_hidden_sizes=None,
                 use_default=False):
        """Build the networks, the replay memory and the exploration noise.

        Args:
            n_states: forwarded to ``Actor``/``Critic`` (presumably the state
                dimension).
            n_actions: forwarded to ``Actor``/``Critic`` and ``OUProcess``
                (presumably the action dimension).
            model_name: stored on the instance; not otherwise used by this class.
            alr: learning rate of the actor's Adam optimizer.
            clr: learning rate of the critic's Adam optimizer.
            gamma: discount applied to the target critic's value.
            batch_size: number of transitions requested per ``update`` call.
            tau: interpolation factor of the soft target-network update.
            shift: constant added to the critic's regression target in ``update``.
            memory_size: capacity passed to ``PrioritizedReplayMemory``.
            a_hidden_sizes: hidden sizes forwarded to ``Actor``;
                ``None`` means ``[128, 128, 64]``.
            c_hidden_sizes: hidden sizes forwarded to ``Critic``;
                ``None`` means ``[128, 256, 64]``.
            use_default: forwarded unchanged to ``Actor`` and ``Critic``.
        """
        self.n_states = n_states
        self.n_actions = n_actions
        self.alr = alr
        self.clr = clr
        self.model_name = model_name
        self.batch_size = batch_size
        self.gamma = gamma
        self.tau = tau
        # Fresh lists per instance: list literals in the signature would be
        # shared (and mutable) across every DDPG object.
        self.a_hidden_sizes = ([128, 128, 64]
                               if a_hidden_sizes is None else a_hidden_sizes)
        self.c_hidden_sizes = ([128, 256, 64]
                               if c_hidden_sizes is None else c_hidden_sizes)
        self.shift = shift
        self.use_default = use_default

        self._build_network()

        self.replay_memory = PrioritizedReplayMemory(capacity=memory_size)
        self.noise = OUProcess(n_actions)

    @staticmethod
    def totensor(x):
        """Wrap ``x`` in a float tensor ``Variable``."""
        return Variable(torch.FloatTensor(x))

    def _build_network(self):
        """Create actor/critic, their targets, the loss and both optimizers."""
        self.actor = Actor(self.n_states, self.n_actions, self.a_hidden_sizes,
                           self.use_default)
        self.target_actor = Actor(self.n_states, self.n_actions,
                                  self.a_hidden_sizes, self.use_default)
        self.critic = Critic(self.n_states, self.n_actions,
                             self.c_hidden_sizes, self.use_default)
        self.target_critic = Critic(self.n_states, self.n_actions,
                                    self.c_hidden_sizes, self.use_default)

        # tau=1.0 turns the soft update into a plain parameter copy.
        self._update_target(self.target_actor, self.actor, tau=1.0)
        self._update_target(self.target_critic, self.critic, tau=1.0)

        self.loss_criterion = nn.MSELoss()
        self.actor_optimizer = optimizer.Adam(lr=self.alr,
                                              params=self.actor.parameters(),
                                              weight_decay=1e-5)
        self.critic_optimizer = optimizer.Adam(lr=self.clr,
                                               params=self.critic.parameters(),
                                               weight_decay=1e-5)

    @staticmethod
    def _update_target(target, source, tau):
        """Blend ``source`` parameters into ``target`` in place.

        Each target parameter becomes ``(1 - tau) * target + tau * source``.
        """
        for (target_param, param) in zip(target.parameters(),
                                         source.parameters()):
            target_param.data.copy_(target_param.data * (1 - tau) +
                                    param.data * tau)

    def reset(self, sigma, theta):
        """Reset the exploration noise process with the given parameters."""
        self.noise.reset(sigma, theta)

    def _sample_batch(self):
        """Draw ``batch_size`` transitions and split them into column lists.

        Returns:
            ``(idx, states, next_states, actions, rewards)`` where ``idx`` is
            whatever the replay memory returned next to the batch.
        """
        batch, idx = self.replay_memory.sample(self.batch_size)
        states = [sample[0].tolist() for sample in batch]
        actions = [sample[1].tolist() for sample in batch]
        rewards = [sample[2] for sample in batch]
        next_states = [sample[3].tolist() for sample in batch]

        return idx, states, next_states, actions, rewards

    def _td_target(self, rewards, next_values):
        """Return ``rewards + gamma * next_values + shift`` with aligned shapes.

        A 1-D reward batch combined with a column of values would broadcast
        to a square matrix, so the rewards get a trailing axis first.
        """
        if rewards.dim() + 1 == next_values.dim():
            rewards = rewards.unsqueeze(-1)
        return rewards + next_values * self.gamma + self.shift

    def add_sample(self, state, action, reward, next_state):
        """Score a transition by its absolute TD error and store it.

        Args:
            state: array-like with ``tolist()``; state before the action.
            action: array-like with ``tolist()``; action taken.
            reward: reward of the transition.
            next_state: array-like with ``tolist()``; state after the action.
        """
        # Score in eval mode, then restore train mode below.
        self.critic.eval()
        self.actor.eval()
        self.target_critic.eval()
        self.target_actor.eval()
        batch_state = self.totensor([state.tolist()])
        batch_next_state = self.totensor([next_state.tolist()])
        current_value = self.critic(batch_state,
                                    self.totensor([action.tolist()]))
        target_action = self.target_actor(batch_next_state)
        # NOTE(review): unlike update(), this target omits self.shift - confirm
        # that is intended.
        target_value = self.totensor([reward]) \
            + self.target_critic(batch_next_state, target_action) * self.gamma
        # .item() extracts the single element without relying on NumPy's
        # deprecated implicit array-to-scalar conversion.
        error = float(torch.abs(current_value - target_value).data.numpy().item())

        self.target_actor.train()
        self.actor.train()
        self.critic.train()
        self.target_critic.train()
        self.replay_memory.add(error, (state, action, reward, next_state))

    def update(self):
        """Run one critic step, one actor step and a soft target update.

        Returns:
            ``(critic_loss.data, policy_loss.data)``.
        """
        idxs, states, next_states, actions, rewards = self._sample_batch()
        batch_states = self.totensor(states)
        batch_next_states = self.totensor(next_states)
        batch_actions = self.totensor(actions)
        batch_rewards = self.totensor(rewards)

        target_next_actions = self.target_actor(batch_next_states).detach()
        target_next_value = self.target_critic(batch_next_states,
                                               target_next_actions).detach()
        current_value = self.critic(batch_states, batch_actions)
        next_value = self._td_target(batch_rewards, target_next_value)

        # Refresh the stored errors of the sampled transitions.
        if isinstance(self.replay_memory, PrioritizedReplayMemory):
            errors = torch.abs(current_value - next_value).data.numpy()
            # zip instead of range(batch_size): no IndexError if the memory
            # returned fewer samples than requested.
            for idx, sample_error in zip(idxs, errors):
                self.replay_memory.update(idx, sample_error[0])

        # Update Critic
        loss = self.loss_criterion(current_value, next_value)
        self.critic_optimizer.zero_grad()
        loss.backward()
        self.critic_optimizer.step()

        # Update Actor: maximize the critic's value of the actor's actions.
        self.critic.eval()
        policy_loss = -self.critic(batch_states, self.actor(batch_states))
        policy_loss = policy_loss.mean()
        self.actor_optimizer.zero_grad()
        policy_loss.backward()

        self.actor_optimizer.step()
        self.critic.train()

        self._update_target(self.target_critic, self.critic, tau=self.tau)
        self._update_target(self.target_actor, self.actor, tau=self.tau)

        return loss.data, policy_loss.data

    def choose_action(self, states):
        """Return the actor's action for ``states`` plus noise, clipped to [0, 1].

        Args:
            states: array-like with ``tolist()``; a single state.
        """
        self.actor.eval()
        act = self.actor(self.totensor([states.tolist()])).squeeze(0)
        self.actor.train()
        action = act.data.numpy()
        action += self.noise.noise()
        return action.clip(0, 1)

    def set_model(self, actor_dict, critic_dict):
        """Load actor and critic weights from pickled state dicts.

        SECURITY: ``pickle.loads`` executes arbitrary code embedded in the
        payload; only pass bytes produced by a trusted ``get_model`` call.
        """
        self.actor.load_state_dict(pickle.loads(actor_dict))
        self.critic.load_state_dict(pickle.loads(critic_dict))

    def get_model(self):
        """Return ``(actor_bytes, critic_bytes)``: the pickled state dicts."""
        return pickle.dumps(self.actor.state_dict()), pickle.dumps(
            self.critic.state_dict())
Example #2
0
class DDPG(object):
    """DDPG agent with input normalization and an optional supervised mode.

    ``Actor``, ``Critic``, ``Normalizer``, ``PrioritizedReplayMemory``,
    ``OUProcess`` and ``LOG`` are defined elsewhere in the project.
    """

    def __init__(self,
                 n_states,
                 n_actions,
                 model_name='',
                 alr=0.001,
                 clr=0.001,
                 gamma=0.9,
                 batch_size=32,
                 tau=0.002,
                 memory_size=100000,
                 ouprocess=True,
                 mean_var_path=None,
                 supervised=False):
        """Build the normalizer, the networks, the replay memory and the noise.

        Args:
            n_states: forwarded to ``Actor``/``Critic`` and used as the length
                of the fallback mean/var vectors.
            n_actions: forwarded to ``Actor``/``Critic`` and ``OUProcess``.
            model_name: if non-empty, path prefix from which actor and critic
                weights are loaded (see ``load_model``).
            alr: learning rate of the actor's Adam optimizer.
            clr: learning rate of the critic's Adam optimizer.
            gamma: discount applied to the target critic's value.
            batch_size: number of transitions requested per ``update`` call.
            tau: interpolation factor of the soft target-network update.
            memory_size: capacity passed to ``PrioritizedReplayMemory``.
            ouprocess: when true, ``choose_action`` adds ``OUProcess`` noise
                and the actor is built with ``noisy=False``; otherwise the
                actor is built with ``noisy=True``.
            mean_var_path: optional path to a pickled ``(mean, var)`` pair.
            supervised: when true, only the actor is built (for
                ``train_actor``); critic and target networks are not created.
        """
        self.n_states = n_states
        self.n_actions = n_actions
        self.alr = alr
        self.clr = clr
        self.model_name = model_name
        self.batch_size = batch_size
        self.gamma = gamma
        self.tau = tau
        self.ouprocess = ouprocess

        if mean_var_path is not None and os.path.exists(mean_var_path):
            # SECURITY: pickle.load runs arbitrary code from the file; the
            # statistics file must come from a trusted source.
            with open(mean_var_path, 'rb') as f:
                mean, var = pickle.load(f)
        else:
            # No statistics file available: fall back to all-zero vectors.
            mean = np.zeros(n_states)
            var = np.zeros(n_states)

        self.normalizer = Normalizer(mean, var)

        if supervised:
            self._build_actor()
            LOG.info("Supervised Learning Initialized")
        else:
            # Build Network
            self._build_network()

        self.replay_memory = PrioritizedReplayMemory(capacity=memory_size)
        self.noise = OUProcess(n_actions)

    @staticmethod
    def totensor(x):
        """Wrap ``x`` in a float tensor ``Variable``."""
        return Variable(torch.FloatTensor(x))

    @staticmethod
    def _loss_to_float(loss):
        """Return the scalar value of ``loss`` as a Python number.

        ``loss.data[0]`` raises ``IndexError`` on 0-dim tensors, so ``item()``
        is preferred; indexing stays as the fallback for tensor types that
        lack ``item()``.
        """
        data = loss.data
        if hasattr(data, 'item'):
            return data.item()
        return data[0]

    def _td_target(self, rewards, next_values):
        """Return ``rewards + gamma * next_values`` with aligned shapes.

        A 1-D reward batch combined with a column of values would broadcast
        to a square matrix, so the rewards get a trailing axis first.
        """
        if rewards.dim() + 1 == next_values.dim():
            rewards = rewards.unsqueeze(-1)
        return rewards + next_values * self.gamma

    def _build_actor(self):
        """Create only the actor with its MSE criterion and Adam optimizer."""
        self.actor = Actor(self.n_states, self.n_actions,
                           noisy=not self.ouprocess)
        self.actor_criterion = nn.MSELoss()
        self.actor_optimizer = optimizer.Adam(lr=self.alr,
                                              params=self.actor.parameters())

    def _build_network(self):
        """Create actor/critic, their targets, the loss and both optimizers."""
        self.actor = Actor(self.n_states, self.n_actions,
                           noisy=not self.ouprocess)
        self.target_actor = Actor(self.n_states, self.n_actions)
        self.critic = Critic(self.n_states, self.n_actions)
        self.target_critic = Critic(self.n_states, self.n_actions)

        # If a model path prefix was provided, load the saved parameters.
        # Truthiness (not len()) so that model_name=None is also accepted.
        if self.model_name:
            self.load_model(model_name=self.model_name)
            LOG.info("Loading model from file: %s", self.model_name)

        # tau=1.0 turns the soft update into a plain parameter copy.
        self._update_target(self.target_actor, self.actor, tau=1.0)
        self._update_target(self.target_critic, self.critic, tau=1.0)

        self.loss_criterion = nn.MSELoss()
        self.actor_optimizer = optimizer.Adam(lr=self.alr,
                                              params=self.actor.parameters(),
                                              weight_decay=1e-5)
        self.critic_optimizer = optimizer.Adam(lr=self.clr,
                                               params=self.critic.parameters(),
                                               weight_decay=1e-5)

    @staticmethod
    def _update_target(target, source, tau):
        """Blend ``source`` parameters into ``target`` in place.

        Each target parameter becomes ``(1 - tau) * target + tau * source``.
        """
        for (target_param, param) in zip(target.parameters(),
                                         source.parameters()):
            target_param.data.copy_(target_param.data * (1 - tau) +
                                    param.data * tau)

    def reset(self, sigma):
        """Reset the exploration noise process with the given ``sigma``."""
        self.noise.reset(sigma)

    def _sample_batch(self):
        """Draw ``batch_size`` transitions and split them into column lists.

        Returns:
            ``(idx, states, next_states, actions, rewards, terminates)`` where
            ``idx`` is whatever the replay memory returned next to the batch.
        """
        batch, idx = self.replay_memory.sample(self.batch_size)
        states = [sample[0].tolist() for sample in batch]
        next_states = [sample[3].tolist() for sample in batch]
        actions = [sample[1].tolist() for sample in batch]
        rewards = [sample[2] for sample in batch]
        terminates = [sample[4] for sample in batch]

        return idx, states, next_states, actions, rewards, terminates

    def add_sample(self, state, action, reward, next_state, terminate):
        """Score a transition by its absolute TD error and store it.

        Args:
            state: array-like with ``tolist()``; state before the action.
            action: array-like with ``tolist()``; action taken.
            reward: reward of the transition.
            next_state: array-like with ``tolist()``; state after the action.
            terminate: truthy if the transition ended the episode; the
                bootstrapped value is then zeroed in the target.
        """
        # Score in eval mode, then restore train mode below.
        self.critic.eval()
        self.actor.eval()
        self.target_critic.eval()
        self.target_actor.eval()
        batch_state = self.normalizer([state.tolist()])
        batch_next_state = self.normalizer([next_state.tolist()])
        current_value = self.critic(batch_state,
                                    self.totensor([action.tolist()]))
        target_action = self.target_actor(batch_next_state)
        target_value = self.totensor([reward]) \
            + self.totensor([0 if x else 1 for x in [terminate]]) \
            * self.target_critic(batch_next_state, target_action) * self.gamma
        # .item() extracts the single element without relying on NumPy's
        # deprecated implicit array-to-scalar conversion.
        error = float(torch.abs(current_value - target_value).data.numpy().item())

        self.target_actor.train()
        self.actor.train()
        self.critic.train()
        self.target_critic.train()
        self.replay_memory.add(error,
                               (state, action, reward, next_state, terminate))

    def update(self):
        """Run one critic step, one actor step and a soft target update.

        Returns:
            ``(critic_loss.data, policy_loss.data)``.
        """
        idxs, states, next_states, actions, rewards, terminates = \
            self._sample_batch()
        batch_states = self.normalizer(states)
        batch_next_states = self.normalizer(next_states)
        batch_actions = self.totensor(actions)
        batch_rewards = self.totensor(rewards)

        target_next_actions = self.target_actor(batch_next_states).detach()
        target_next_value = self.target_critic(batch_next_states,
                                               target_next_actions).detach()
        current_value = self.critic(batch_states, batch_actions)
        # TODO (dongshen): the original target multiplied the bootstrapped
        # value by a termination mask, but that clause had mistakes. Since
        # terminate is always false, the mask is left out here, so
        # `terminates` is intentionally unused.
        next_value = self._td_target(batch_rewards, target_next_value)

        # Refresh the stored errors of the sampled transitions.
        errors = torch.abs(current_value - next_value).data.numpy()
        # zip instead of range(batch_size): no IndexError if the memory
        # returned fewer samples than requested.
        for idx, sample_error in zip(idxs, errors):
            self.replay_memory.update(idx, sample_error[0])

        # Update Critic
        loss = self.loss_criterion(current_value, next_value)
        self.critic_optimizer.zero_grad()
        loss.backward()
        self.critic_optimizer.step()

        # Update Actor: maximize the critic's value of the actor's actions.
        self.critic.eval()
        policy_loss = -self.critic(batch_states, self.actor(batch_states))
        policy_loss = policy_loss.mean()
        self.actor_optimizer.zero_grad()
        policy_loss.backward()

        self.actor_optimizer.step()
        self.critic.train()

        self._update_target(self.target_critic, self.critic, tau=self.tau)
        self._update_target(self.target_actor, self.actor, tau=self.tau)

        return loss.data, policy_loss.data

    def choose_action(self, x):
        """ Select Action according to the current state
        Args:
            x: np.array, current state
        Return:
            np.array, the actor's output (plus OU noise when ``ouprocess`` is
            set) clipped to [0, 1]
        """
        self.actor.eval()
        act = self.actor(self.normalizer([x.tolist()])).squeeze(0)
        self.actor.train()
        action = act.data.numpy()
        if self.ouprocess:
            action += self.noise.noise()
        return action.clip(0, 1)

    def sample_noise(self):
        """Delegate to the actor's ``sample_noise``."""
        self.actor.sample_noise()

    def load_model(self, model_name):
        """ Load Torch Model from files
        Args:
            model_name: str, path prefix; reads ``<model_name>_actor.pth``
                and ``<model_name>_critic.pth``
        """
        # SECURITY: torch.load unpickles the file; only load trusted models.
        self.actor.load_state_dict(
            torch.load('{}_actor.pth'.format(model_name)))
        self.critic.load_state_dict(
            torch.load('{}_critic.pth'.format(model_name)))

    def save_model(self, model_name):
        """ Save Torch Model to files
        Args:
            model_name: str, path prefix; writes ``<model_name>_actor.pth``
                and ``<model_name>_critic.pth``
        """
        torch.save(self.actor.state_dict(), '{}_actor.pth'.format(model_name))

        torch.save(self.critic.state_dict(),
                   '{}_critic.pth'.format(model_name))

    def set_model(self, actor_dict, critic_dict):
        """ Load actor and critic weights from pickled state dicts

        SECURITY: ``pickle.loads`` executes arbitrary code embedded in the
        payload; only pass bytes produced by a trusted ``get_model`` call.
        """
        self.actor.load_state_dict(pickle.loads(actor_dict))
        self.critic.load_state_dict(pickle.loads(critic_dict))

    def get_model(self):
        """ Return ``(actor_bytes, critic_bytes)``: the pickled state dicts """
        return pickle.dumps(self.actor.state_dict()), pickle.dumps(
            self.critic.state_dict())

    def save_actor(self, path):
        """ save actor network
        Args:
             path, str, path to save
        """
        torch.save(self.actor.state_dict(), path)

    def load_actor(self, path):
        """ load actor network
        Args:
             path, str, path to load
        """
        # SECURITY: torch.load unpickles the file; only load trusted models.
        self.actor.load_state_dict(torch.load(path))

    def train_actor(self, batch_data, is_train=True):
        """ Train the actor separately with data
        Args:
            batch_data: tuple, (states, actions)
            is_train: bool, when False the loss is only evaluated (actor in
                eval mode, no optimizer step)
        Return:
            _loss: float, training loss
        """
        states, action = batch_data

        if is_train:
            self.actor.train()
        else:
            self.actor.eval()

        pred = self.actor(self.normalizer(states))
        action = self.totensor(action)
        _loss = self.actor_criterion(pred, action)

        if is_train:
            self.actor_optimizer.zero_grad()
            _loss.backward()
            self.actor_optimizer.step()

        return self._loss_to_float(_loss)