Example #1
File: ddpg.py Project: ypxie/chainerrl
    def compute_actor_loss(self, batch):
        """Compute loss for actor.

        Preconditions:
          q_function must have seen up to s_{t-1} and a_{t-1}.
          policy must have seen up to s_{t-1}.
        Postconditions:
          q_function must have seen up to s_t and a_t.
          policy must have seen up to s_t.
        """

        batch_state = batch['state']
        batch_action = batch['action']
        batch_size = len(batch_action)

        # Estimated policy observes s_t
        onpolicy_actions = self.policy(batch_state, test=False).sample()

        # Q(s_t, mu(s_t)) is evaluated.
        # This should not affect the internal state of Q.
        with state_kept(self.q_function):
            q = self.q_function(batch_state, onpolicy_actions, test=False)

        # Estimated Q-function observes s_t and a_t
        if isinstance(self.q_function, Recurrent):
            self.q_function.update_state(batch_state, batch_action, test=False)

        # Since we want to maximize Q, loss is negation of Q
        loss = -F.sum(q) / batch_size

        # Update stats
        self.average_actor_loss *= self.average_loss_decay
        self.average_actor_loss += ((1 - self.average_loss_decay) *
                                    float(loss.data))
        return loss
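All of these examples wrap model evaluations in state_kept so that a recurrent network can be queried for a loss computation without disturbing the hidden state it carries for acting. As a rough illustration of that contract (not the actual chainerrl implementation), a context manager with the same behavior could look like the sketch below, assuming the wrapped link exposes get_state/set_state as chainerrl's Recurrent interface does:

from contextlib import contextmanager

@contextmanager
def state_kept_sketch(link):
    # Snapshot the recurrent state before the body runs.
    state = link.get_state()
    try:
        yield
    finally:
        # Restore it afterwards, so evaluations inside the block do not
        # leak into the state used for subsequent online interaction.
        link.set_state(state)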
Example #2
File: al.py Project: pfnet-research/piekd
    def _compute_y_and_t(self, exp_batch):

        batch_state = exp_batch['state']
        batch_size = len(exp_batch['reward'])

        qout = self.q_function(batch_state)

        batch_actions = exp_batch['action']

        batch_q = qout.evaluate_actions(batch_actions)

        # Compute target values

        with chainer.no_backprop_mode():
            target_qout = self.target_q_function(batch_state)

            batch_next_state = exp_batch['next_state']

            with state_kept(self.target_q_function):
                target_next_qout = self.target_q_function(batch_next_state)
            next_q_max = F.reshape(target_next_qout.max, (batch_size, ))

            batch_rewards = exp_batch['reward']
            batch_terminal = exp_batch['is_state_terminal']

            # T Q: Bellman operator
            t_q = batch_rewards + exp_batch['discount'] * \
                (1.0 - batch_terminal) * next_q_max

            # T_AL Q: advantage learning operator
            cur_advantage = F.reshape(
                target_qout.compute_advantage(batch_actions), (batch_size, ))
            tal_q = t_q + self.alpha * cur_advantage

        return batch_q, tal_q
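Example #2 applies the advantage learning (AL) operator: the ordinary Bellman target is shifted by alpha times the (non-positive) advantage of the taken action under the target network. The same arithmetic on toy numpy arrays, with made-up values, just to make the two terms explicit:

import numpy as np

rewards = np.array([1.0, 0.0], dtype=np.float32)
terminal = np.array([0.0, 1.0], dtype=np.float32)
discount = 0.99
alpha = 0.9

q_target = np.array([[2.0, 3.0], [1.0, 4.0]], dtype=np.float32)       # Q'(s, .)
q_target_next = np.array([[1.5, 2.5], [0.5, 1.0]], dtype=np.float32)  # Q'(s', .)
actions = np.array([0, 1])

# T Q: ordinary Bellman target.
t_q = rewards + discount * (1.0 - terminal) * q_target_next.max(axis=1)

# T_AL Q: add alpha times the advantage Q'(s, a) - max_a Q'(s, a).
cur_advantage = q_target[np.arange(len(actions)), actions] - q_target.max(axis=1)
tal_q = t_q + alpha * cur_advantage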
Example #3
File: pcl.py Project: takuma-ynd/chainerrl
    def update_on_policy(self, statevar):
        assert self.t_start < self.t

        if not self.disable_online_update:
            next_values = {}
            for t in range(self.t_start + 1, self.t):
                next_values[t - 1] = self.past_values[t]
            if statevar is None:
                next_values[self.t - 1] = chainer.Variable(
                    self.xp.zeros_like(self.past_values[self.t - 1].array))
            else:
                with state_kept(self.model):
                    _, v = self.model(statevar)
                next_values[self.t - 1] = v
            log_probs = {
                t: self.past_action_distrib[t].log_prob(
                    self.xp.asarray(self.xp.expand_dims(a, 0)))
                for t, a in self.past_actions.items()
            }
            self.online_batch_losses.append(
                self.compute_loss(t_start=self.t_start,
                                  t_stop=self.t,
                                  rewards=self.past_rewards,
                                  values=self.past_values,
                                  next_values=next_values,
                                  log_probs=log_probs))
            if len(self.online_batch_losses) == self.batchsize:
                loss = chainerrl.functions.sum_arrays(
                    self.online_batch_losses) / self.batchsize
                self.update(loss)
                self.online_batch_losses = []

        self.init_history_data_for_online_update()
Example #4
File: acer.py Project: rhythm92/chainerrl
    def update_on_policy(self, statevar):
        assert self.t_start < self.t

        if not self.disable_online_update:
            if statevar is None:
                R = 0
            else:
                with chainer.no_backprop_mode():
                    with state_kept(self.model):
                        action_distrib, action_value = self.model(statevar)
                        v = compute_state_value_as_expected_action_value(
                            action_value, action_distrib)
                R = v
            self.update(t_start=self.t_start,
                        t_stop=self.t,
                        R=R,
                        states=self.past_states,
                        actions=self.past_actions,
                        rewards=self.past_rewards,
                        values=self.past_values,
                        action_values=self.past_action_values,
                        action_log_probs=self.past_action_log_prob,
                        action_distribs=self.past_action_distrib,
                        avg_action_distribs=self.past_avg_action_distrib)

        self.init_history_data_for_online_update()
Example #5
    def _compute_target_values(self, exp_batch):
        """Compute a batch of target return distributions."""

        batch_next_state = exp_batch['next_state']
        batch_rewards = exp_batch['reward']
        batch_terminal = exp_batch['is_state_terminal']

        with chainer.using_config('train', False), state_kept(self.q_function):
            next_qout = self.q_function(batch_next_state)

        target_next_qout = self.target_q_function(batch_next_state)

        batch_size = batch_rewards.shape[0]
        z_values = target_next_qout.z_values
        n_atoms = z_values.size

        # next_q_max: (batch_size, n_atoms)
        # The online network (queried above under state_kept) picks the greedy
        # action; the target network supplies its return distribution.
        next_q_max = target_next_qout.evaluate_actions_as_distribution(
            next_qout.greedy_actions).array
        assert next_q_max.shape == (batch_size, n_atoms), next_q_max.shape

        # Tz: (batch_size, n_atoms)
        Tz = (batch_rewards[..., None] + (1.0 - batch_terminal[..., None]) *
              self.xp.expand_dims(exp_batch['discount'], 1) * z_values[None])
        return _apply_categorical_projection(Tz, next_q_max, z_values)
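Example #5 hands the shifted atoms Tz and the next-state return distribution to _apply_categorical_projection, which is not shown here. For reference, a plain loop-based sketch of the standard C51 projection (assuming a uniformly spaced support z_values; this is an illustration, not the chainerrl implementation):

import numpy as np

def categorical_projection_sketch(Tz, probs, z_values):
    # Redistribute probability mass from the shifted atoms Tz onto the
    # fixed, uniformly spaced support z_values.
    v_min, v_max = z_values[0], z_values[-1]
    delta_z = z_values[1] - z_values[0]
    out = np.zeros_like(probs)
    for i in range(Tz.shape[0]):       # batch
        for j in range(Tz.shape[1]):   # atoms
            tz = np.clip(Tz[i, j], v_min, v_max)
            b = (tz - v_min) / delta_z
            lo = int(np.floor(b))
            hi = int(np.ceil(b))
            if lo == hi:
                out[i, lo] += probs[i, j]
            else:
                out[i, lo] += probs[i, j] * (hi - b)
                out[i, hi] += probs[i, j] * (b - lo)
    return out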
Example #6
    def _compute_y_and_t(self, exp_batch, gamma):

        batch_state = exp_batch['state']
        batch_size = len(exp_batch['reward'])

        qout = self.q_function(batch_state, test=False)

        batch_actions = exp_batch['action']
        batch_q = qout.evaluate_actions(batch_actions)

        # Compute target values

        with chainer.no_backprop_mode():
            target_qout = self.target_q_function(batch_state, test=True)

            batch_next_state = exp_batch['next_state']

            with state_kept(self.q_function):
                next_qout = self.q_function(batch_next_state, test=False)

            with state_kept(self.target_q_function):
                target_next_qout = self.target_q_function(batch_next_state,
                                                          test=True)
            next_q_max = F.reshape(
                target_next_qout.evaluate_actions(next_qout.greedy_actions),
                (batch_size, ))

            batch_rewards = exp_batch['reward']
            batch_terminal = exp_batch['is_state_terminal']

            # T Q: Bellman operator
            t_q = batch_rewards + self.gamma * \
                (1.0 - batch_terminal) * next_q_max

            # T_PAL Q: persistent advantage learning operator
            cur_advantage = F.reshape(
                target_qout.compute_advantage(batch_actions), (batch_size, ))
            next_advantage = F.reshape(
                target_next_qout.compute_advantage(batch_actions),
                (batch_size, ))
            tpal_q = t_q + self.alpha * \
                F.maximum(cur_advantage, next_advantage)

        return batch_q, tpal_q
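Example #6 is the persistent variant of the operator from Example #2: instead of always using the advantage at s_t, it keeps whichever of the current and next-state advantages of the taken action is larger. In toy numpy terms (made-up values):

import numpy as np

alpha = 0.9
t_q = np.array([3.475, 0.0], dtype=np.float32)  # ordinary Bellman targets (toy values)

# Advantage of the taken action under the target network at s_t and at s_{t+1}.
cur_advantage = np.array([-1.0, 0.0], dtype=np.float32)
next_advantage = np.array([-0.5, -2.0], dtype=np.float32)

# T_PAL Q keeps the larger (less negative) of the two advantages.
tpal_q = t_q + alpha * np.maximum(cur_advantage, next_advantage)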
Example #7
File: ddpg.py Project: ypxie/chainerrl
    def compute_critic_loss(self, batch):
        """Compute loss for critic.

        Preconditions:
          target_q_function must have seen up to s_t and a_t.
          target_policy must have seen up to s_t.
          q_function must have seen up to s_{t-1} and a_{t-1}.
        Postconditions:
          target_q_function must have seen up to s_{t+1} and a_{t+1}.
          target_policy must have seen up to s_{t+1}.
          q_function must have seen up to s_t and a_t.
        """

        batch_next_state = batch['next_state']
        batch_rewards = batch['reward']
        batch_terminal = batch['is_state_terminal']
        batch_state = batch['state']
        batch_actions = batch['action']
        batch_next_actions = batch['next_action']
        batchsize = len(batch_rewards)

        with chainer.no_backprop_mode():
            # Target policy observes s_{t+1}
            next_actions = self.target_policy(batch_next_state,
                                              test=True).sample()

            # Q(s_{t+1}, mu(s_{t+1})) is evaluated.
            # This should not affect the internal state of Q.
            with state_kept(self.target_q_function):
                next_q = self.target_q_function(batch_next_state,
                                                next_actions,
                                                test=True)

            # Target Q-function observes s_{t+1} and a_{t+1}
            if isinstance(self.target_q_function, Recurrent):
                self.target_q_function.update_state(batch_next_state,
                                                    batch_next_actions,
                                                    test=True)

            target_q = batch_rewards + self.gamma * \
                (1.0 - batch_terminal) * F.reshape(next_q, (batchsize,))

        # Estimated Q-function observes s_t and a_t
        predict_q = F.reshape(
            self.q_function(batch_state, batch_actions, test=False),
            (batchsize, ))

        loss = F.mean_squared_error(target_q, predict_q)

        # Update stats
        self.average_critic_loss *= self.average_loss_decay
        self.average_critic_loss += ((1 - self.average_loss_decay) *
                                     float(loss.data))

        return loss
Example #8
    def _compute_target_values(self, exp_batch, gamma):

        batch_next_state = exp_batch['next_state']

        with state_kept(self.q_function):
            next_qout = self.q_function(batch_next_state, test=True)

        target_next_qout = self.target_q_function(batch_next_state, test=True)

        next_q_max = target_next_qout.evaluate_actions(
            next_qout.greedy_actions)

        batch_rewards = exp_batch['reward']
        batch_terminal = exp_batch['is_state_terminal']

        return batch_rewards + self.gamma * (1.0 - batch_terminal) * next_q_max
Example #9
    def _compute_target_values(self, exp_batch, gamma):

        batch_next_state = exp_batch['next_state']

        with chainer.using_config('train', False), state_kept(self.q_function):
            next_qout = self.q_function(batch_next_state)

        target_next_qout = self.target_q_function(batch_next_state)

        next_q_max = target_next_qout.evaluate_actions(
            next_qout.greedy_actions)

        batch_rewards = exp_batch['reward']
        batch_terminal = exp_batch['is_state_terminal']

        return batch_rewards + self.gamma * (1.0 - batch_terminal) * next_q_max
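Examples #8 and #9 compute the same double DQN-style target against two generations of the Chainer API (the test= keyword versus chainer.using_config): the online network chooses the greedy next action under state_kept, and the target network scores it. The same arithmetic on toy numpy arrays (made-up values):

import numpy as np

rewards = np.array([0.0, 1.0], dtype=np.float32)
terminal = np.array([0.0, 1.0], dtype=np.float32)
gamma = 0.99

q_online_next = np.array([[1.0, 2.0], [3.0, 0.5]], dtype=np.float32)  # Q(s', .)
q_target_next = np.array([[0.8, 1.7], [2.5, 0.4]], dtype=np.float32)  # Q'(s', .)

greedy_actions = q_online_next.argmax(axis=1)             # chosen by the online network
next_q_max = q_target_next[np.arange(2), greedy_actions]  # evaluated by the target network
target_values = rewards + gamma * (1.0 - terminal) * next_q_max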
Example #10
    def update(self, statevar):
        assert self.t_start < self.t

        # Update
        if statevar is None:
            R = 0
        else:
            with state_kept(self.target_q_function):
                R = float(self.target_q_function(statevar).max.data)

        loss = 0
        for i in reversed(range(self.t_start, self.t)):
            R *= self.gamma
            R += self.past_rewards[i]
            q = F.reshape(self.past_action_values[i], (1, 1))
            # Accumulate gradients of Q-function
            loss += F.sum(
                F.huber_loss(q,
                             chainer.Variable(
                                 np.asarray([[R]], dtype=np.float32)),
                             delta=1.0))

        # Do we need to normalize losses by (self.t - self.t_start)?
        # Otherwise, loss scales can be different in case of self.t_max
        # and in case of termination.

        # I'm not sure but if we need to normalize losses...
        # loss /= self.t - self.t_start

        # Compute gradients using thread-specific model
        self.q_function.zerograds()
        loss.backward()
        # Copy the gradients to the globally shared model
        self.shared_q_function.zerograds()
        copy_param.copy_grad(self.shared_q_function, self.q_function)
        # Update the globally shared model
        self.optimizer.update()

        self.sync_parameters()
        if isinstance(self.q_function, Recurrent):
            self.q_function.unchain_backward()

        self.past_action_values = {}
        self.past_states = {}
        self.past_rewards = {}

        self.t_start = self.t
Example #11
    def compute_actor_loss(self, batch):
        """Compute loss for actor.

        Preconditions:
          q_function must have seen up to s_{t-1} and a_{t-1}.
          policy must have seen up to s_{t-1}.
        Postconditions:
          q_function must have seen up to s_t and a_t.
          policy must have seen up to s_t.
        """

        batch_state = batch['state']
        batch_action = batch['action']
        batch_size = len(batch_action)

        # Estimated policy observes s_t
        onpolicy_actions = self.policy(batch_state).sample()

        # Q(s_t, mu(s_t)) is evaluated.
        # This should not affect the internal state of Q.
        with state_kept(self.q_function):
            q = self.q_function(batch_state, onpolicy_actions)

        # Estimated Q-function observes s_t and a_t
        if isinstance(self.q_function, Recurrent):
            self.q_function.update_state(batch_state, batch_action)

        # Avoid the numpy #9165 bug (see also: chainer #2744)
        q = q[:, :]

        # Since we want to maximize Q, loss is negation of Q
        loss = -F.sum(q) / batch_size
        if self.l2_action_penalty:
            # Sum over the batch so the penalty stays a scalar like the loss.
            loss += self.l2_action_penalty \
                * F.sum(F.square(onpolicy_actions)) / batch_size

        # Update stats
        self.average_actor_loss *= self.average_loss_decay
        self.average_actor_loss += ((1 - self.average_loss_decay) *
                                    float(loss.array))
        return loss
Example #12
File: acer.py Project: uidilr/chainerrl
    def update_on_policy(self, statevar):
        assert self.t_start < self.t

        if not self.disable_online_update:
            if statevar is None:
                R = 0
            else:
                with chainer.no_backprop_mode():
                    with state_kept(self.model):
                        action_distrib, action_value, v = self.model(statevar)
                R = float(v.data)
            self.update(
                t_start=self.t_start, t_stop=self.t, R=R,
                states=self.past_states,
                actions=self.past_actions,
                rewards=self.past_rewards,
                values=self.past_values,
                action_values=self.past_action_values,
                action_distribs=self.past_action_distrib,
                action_distribs_mu=None,
                avg_action_distribs=self.past_avg_action_distrib)

        self.init_history_data_for_online_update()
Example #13
    def update(self, statevar):
        assert self.t_start < self.t

        if statevar is None:
            R = 0
        else:
            with state_kept(self.model):
                _, vout, __ = self.model.pi_and_v(statevar)
#######################
            R = F.cast(vout.data, 'float32')
            #R = float(vout.data)
#######################

        pi_loss = 0
        v_loss = 0
        for i in reversed(range(self.t_start, self.t)):
            R *= self.gamma
            R += self.past_rewards[i]
            if self.use_average_reward:
                R -= self.average_reward
            v = self.past_values[i]
            advantage = R - v
            if self.use_average_reward:
                self.average_reward += self.average_reward_tau * \
                    float(advantage.data)
            # Accumulate gradients of policy
            log_prob = self.past_action_log_prob[i]
            entropy = self.past_action_entropy[i]

            # Log probability is increased proportionally to advantage
##############################
            pi_loss -= log_prob * F.cast(advantage.data, 'float32')
            #pi_loss -= log_prob * float(advantage.data)
##############################
            # Entropy is maximized
            pi_loss -= self.beta * entropy
            # Accumulate gradients of value function
            v_loss += (v - R) ** 2 / 2

        if self.pi_loss_coef != 1.0:
            pi_loss *= self.pi_loss_coef

        if self.v_loss_coef != 1.0:
            v_loss *= self.v_loss_coef

        # Normalize the loss of sequences truncated by terminal states
        if self.keep_loss_scale_same and \
                self.t - self.t_start < self.t_max:
            factor = self.t_max / (self.t - self.t_start)
            pi_loss *= factor
            v_loss *= factor

        if self.normalize_grad_by_t_max:
            pi_loss /= self.t - self.t_start
            v_loss /= self.t - self.t_start

        if self.process_idx == 0:
            logger.debug('pi_loss:%s v_loss:%s', pi_loss.data, v_loss.data)

##########################
        #total_loss = pi_loss + F.reshape(v_loss, pi_loss.data.shape)
        total_loss = F.mean(pi_loss + F.reshape(v_loss, pi_loss.data.shape))
##########################

        # Compute gradients using thread-specific model
        self.model.zerograds()
        total_loss.backward()
        # Copy the gradients to the globally shared model
        self.shared_model.zerograds()
        copy_param.copy_grad(
            target_link=self.shared_model, source_link=self.model)
        # Update the globally shared model
        if self.process_idx == 0:
            norm = sum(np.sum(np.square(param.grad))
                       for param in self.optimizer.target.params())
            logger.debug('grad norm:%s', norm)
        self.optimizer.update()
        if self.process_idx == 0:
            logger.debug('update')

        self.sync_parameters()
        if isinstance(self.model, Recurrent):
            self.model.unchain_backward()

        self.past_action_log_prob = {}
        self.past_action_entropy = {}
        self.past_states = {}
        self.past_rewards = {}
        self.past_values = {}

        self.t_start = self.t
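Examples #10 and #13 both accumulate n-step returns with the same backward recursion R <- r_i + gamma * R, seeded with a bootstrapped value when the rollout is cut off rather than terminated. A standalone sketch of just that recursion (illustrative, not project code):

def n_step_returns(rewards, gamma, bootstrap_value=0.0):
    # Backward recursion R <- r_i + gamma * R over a finite rollout.
    R = bootstrap_value
    returns = []
    for r in reversed(rewards):
        R = r + gamma * R
        returns.append(R)
    return list(reversed(returns))

# For example, n_step_returns([0.0, 0.0, 1.0], gamma=0.5) gives [0.25, 0.5, 1.0].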