Example #1
    def update_from_episodes(self, episodes, errors_out=None):
        has_weights = isinstance(episodes, tuple)
        if has_weights:
            episodes, weights = episodes
            if errors_out is None:
                errors_out = []
        if errors_out is None:
            errors_out_step = None
        else:
            del errors_out[:]
            for _ in episodes:
                errors_out.append(0.0)
            errors_out_step = []

        with state_reset(self.model), state_reset(self.target_model):
            loss = 0
            tmp = list(reversed(sorted(
                enumerate(episodes), key=lambda x: len(x[1]))))
            sorted_episodes = [elem[1] for elem in tmp]
            indices = [elem[0] for elem in tmp]  # argsort
            max_epi_len = len(sorted_episodes[0])
            for i in range(max_epi_len):
                transitions = []
                weights_step = []
                for ep, index in zip(sorted_episodes, indices):
                    if len(ep) <= i:
                        break
                    transitions.append([ep[i]])
                    if has_weights:
                        weights_step.append(weights[index])
                batch = batch_experiences(
                    transitions,
                    xp=self.xp,
                    phi=self.phi,
                    gamma=self.gamma,
                    batch_states=self.batch_states)
                assert len(batch['state']) == len(transitions)
                if i == 0:
                    self.input_initial_batch_to_target_model(batch)
                if has_weights:
                    batch['weights'] = self.xp.asarray(
                        weights_step, dtype=self.xp.float32)
                loss += self._compute_loss(batch,
                                           errors_out=errors_out_step)
                if errors_out is not None:
                    for err, index in zip(errors_out_step, indices):
                        errors_out[index] += err
            loss /= max_epi_len

            # Update stats
            self.average_loss *= self.average_loss_decay
            self.average_loss += \
                (1 - self.average_loss_decay) * float(loss.array)

            self.model.cleargrads()
            loss.backward()
            self.optimizer.update()
        if has_weights:
            self.replay_buffer.update_errors(errors_out)
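
A note on the bookkeeping in Example #1: episodes are processed in descending-length order, but the per-episode errors have to be written back to errors_out in the caller's original order, which is what the indices argsort is for. Below is a standalone sketch of that scatter-back pattern, using toy lists only (no ChainerRL dependency; all names are made up for illustration).

# Toy illustration of the argsort bookkeeping above: episodes are sorted by
# length in descending order, and `indices` remembers each episode's original
# position so per-episode errors can be scattered back in the caller's order.
episodes = [[0] * 3, [0] * 5, [0] * 2]      # stand-in episodes of lengths 3, 5, 2
tmp = sorted(enumerate(episodes), key=lambda x: len(x[1]), reverse=True)
sorted_episodes = [e for _, e in tmp]       # longest episode first
indices = [i for i, _ in tmp]               # original positions: [1, 0, 2]
errors_out = [0.0 for _ in episodes]
step_errors = [0.3, 0.1, 0.2]               # pretend per-episode errors for one step
for err, index in zip(step_errors, indices):
    errors_out[index] += err                # accumulate in original episode order
print(errors_out)                           # -> [0.1, 0.3, 0.2]
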
Example #2
    def update_from_replay(self):

        if self.replay_buffer is None:
            return

        if len(self.replay_buffer) < self.replay_start_size:
            return

        episode = self.replay_buffer.sample_episodes(1, self.t_max)[0]

        with state_reset(self.model):
            with state_reset(self.shared_average_model):
                rewards = {}
                states = {}
                actions = {}
                action_distribs = {}
                action_distribs_mu = {}
                avg_action_distribs = {}
                action_values = {}
                values = {}
                for t, transition in enumerate(episode):
                    s = self.phi(transition['state'])
                    a = transition['action']
                    bs = np.expand_dims(s, 0)
                    action_distrib, action_value, v = self.model(bs)
                    with chainer.no_backprop_mode():
                        avg_action_distrib, _, _ = \
                            self.shared_average_model(bs)
                    states[t] = s
                    actions[t] = a
                    values[t] = v
                    action_distribs[t] = action_distrib
                    avg_action_distribs[t] = avg_action_distrib
                    rewards[t] = transition['reward']
                    action_distribs_mu[t] = transition['mu']
                    action_values[t] = action_value
                last_transition = episode[-1]
                if last_transition['is_state_terminal']:
                    R = 0
                else:
                    with chainer.no_backprop_mode():
                        last_s = last_transition['next_state']
                        action_distrib, action_value, last_v = self.model(
                            np.expand_dims(self.phi(last_s), 0))
                    R = float(last_v.array)
                return self.update(R=R,
                                   t_start=0,
                                   t_stop=len(episode),
                                   states=states,
                                   rewards=rewards,
                                   actions=actions,
                                   values=values,
                                   action_distribs=action_distribs,
                                   action_distribs_mu=action_distribs_mu,
                                   avg_action_distribs=avg_action_distribs,
                                   action_values=action_values)
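
The only branching in Example #2 is the choice of the initial return estimate R: zero when the sampled episode ended in a terminal state, otherwise the critic's value of the final next_state. A minimal sketch of that bootstrap rule, with made-up values and no Chainer dependency:

# Minimal sketch of the bootstrap rule used above: start the backward return
# computation from 0 after a terminal state, otherwise from the model's value
# estimate of the last next_state.
def initial_return(is_state_terminal, last_v):
    return 0.0 if is_state_terminal else float(last_v)

print(initial_return(True, 3.2))    # -> 0.0  (episode ended)
print(initial_return(False, 3.2))   # -> 3.2  (bootstrap from V of the last state)
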
Example #3
    def update_from_episodes(self, episodes, errors_out=None):
        with state_reset(self.model):
            with state_reset(self.target_model):
                loss = 0
                sorted_episodes = list(reversed(sorted(episodes, key=len)))
                max_epi_len = len(sorted_episodes[0])
                for i in range(max_epi_len):
                    transitions = []
                    for ep in sorted_episodes:
                        if len(ep) <= i:
                            break
                        transitions.append(ep[i])
                    batch = batch_experiences(transitions,
                                              xp=self.xp,
                                              phi=self.phi,
                                              batch_states=self.batch_states)
                    if i == 0:
                        self.input_initial_batch_to_target_model(batch)
                    loss += self._compute_loss(batch, self.gamma)
                loss /= max_epi_len
                self.optimizer.zero_grads()
                loss.backward()
                self.optimizer.update()
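
Examples #1, #3, and #4 all rely on the same batching trick: because episodes are sorted by length in descending order, the episodes still running at step i always form a contiguous prefix of sorted_episodes, so the inner loop can stop at the first episode that is too short. A self-contained sketch of that prefix property, using plain lists in place of transition dicts:

# Standalone sketch of the descending-length batching used above. At each
# step t, only episodes with len(ep) > t contribute a transition, and sorting
# longest-first guarantees they sit at the front, so `break` is safe.
episodes = [list(range(n)) for n in (3, 5, 2)]   # toy episodes of lengths 3, 5, 2
sorted_episodes = sorted(episodes, key=len, reverse=True)
max_epi_len = len(sorted_episodes[0])
for t in range(max_epi_len):
    step_batch = []
    for ep in sorted_episodes:
        if len(ep) <= t:
            break                                # remaining episodes are shorter
        step_batch.append(ep[t])
    print(t, len(step_batch))                    # batch shrinks: 3, 3, 2, 1, 1
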
Example #4
    def update_from_replay(self):

        if self.replay_buffer is None:
            return

        if len(self.replay_buffer) < self.replay_start_size:
            return

        if self.replay_buffer.n_episodes < self.batchsize:
            return

        if self.process_idx == 0:
            self.logger.debug('update_from_replay')

        episodes = self.replay_buffer.sample_episodes(self.batchsize,
                                                      max_len=self.t_max)
        if isinstance(episodes, tuple):
            # Prioritized replay
            episodes, weights = episodes
        else:
            weights = [1] * len(episodes)
        sorted_episodes = list(reversed(sorted(episodes, key=len)))
        max_epi_len = len(sorted_episodes[0])

        with state_reset(self.model):
            # Batch computation of multiple episodes
            rewards = {}
            values = {}
            next_values = {}
            log_probs = {}
            next_action_distrib = None
            next_v = None
            for t in range(max_epi_len):
                transitions = []
                for ep in sorted_episodes:
                    if len(ep) <= t:
                        break
                    transitions.append([ep[t]])
                batch = batch_experiences(transitions,
                                          xp=self.xp,
                                          phi=self.phi,
                                          gamma=self.gamma,
                                          batch_states=self.batch_states)
                batchsize = batch['action'].shape[0]
                if next_action_distrib is not None:
                    action_distrib = next_action_distrib[0:batchsize]
                    v = next_v[0:batchsize]
                else:
                    action_distrib, v = self.model(batch['state'])
                next_action_distrib, next_v = self.model(batch['next_state'])
                values[t] = v
                next_values[t] = next_v * \
                    (1 - batch['is_state_terminal'].reshape(next_v.shape))
                rewards[t] = chainer.cuda.to_cpu(batch['reward'])
                log_probs[t] = action_distrib.log_prob(batch['action'])
            # Loss is computed one by one episode
            losses = []
            for i, ep in enumerate(sorted_episodes):
                e_values = {}
                e_next_values = {}
                e_rewards = {}
                e_log_probs = {}
                for t in range(len(ep)):
                    assert values[t].shape[0] > i
                    assert next_values[t].shape[0] > i
                    assert rewards[t].shape[0] > i
                    assert log_probs[t].shape[0] > i
                    e_values[t] = values[t][i:i + 1]
                    e_next_values[t] = next_values[t][i:i + 1]
                    e_rewards[t] = float(rewards[t][i:i + 1])
                    e_log_probs[t] = log_probs[t][i:i + 1]
                losses.append(
                    self.compute_loss(t_start=0,
                                      t_stop=len(ep),
                                      rewards=e_rewards,
                                      values=e_values,
                                      next_values=e_next_values,
                                      log_probs=e_log_probs))
            loss = chainerrl.functions.weighted_sum_arrays(
                losses, weights) / self.batchsize
            self.update(loss)
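
Example #4 finishes by combining the per-episode losses with the importance weights returned by the prioritized buffer (all ones otherwise) and dividing by the batch size. In plain Python, that aggregation reduces to the following sketch with toy numbers (the real code does the same over Chainer variables via chainerrl.functions.weighted_sum_arrays):

# Toy version of the final aggregation above: a weighted sum of per-episode
# losses divided by the batch size.
losses = [1.5, 0.5, 2.0]                         # stand-in per-episode losses
weights = [1.0, 1.0, 1.0]                        # importance weights from the buffer
batchsize = len(losses)
loss = sum(w * l for w, l in zip(weights, losses)) / batchsize
print(loss)                                      # -> 1.333...
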
Example #5
    def update_from_replay(self):

        if len(self.replay_buffer) < self.replay_start_size:
            return

        episode = self.replay_buffer.sample_episodes(1, self.t_max)[0]

        with state_reset(self.model):
            with state_reset(self.shared_average_model):
                rewards = {}
                states = {}
                actions = {}
                action_log_probs = {}
                action_distribs = {}
                avg_action_distribs = {}
                rho = {}
                rho_all = {}
                action_values = {}
                values = {}
                for t, transition in enumerate(episode):
                    s = self.phi(transition['state'])
                    a = transition['action']
                    ba = np.expand_dims(a, 0)
                    bs = np.expand_dims(s, 0)
                    action_distrib, action_value = self.model(bs)
                    v = compute_state_value_as_expected_action_value(
                        action_value, action_distrib)
                    with chainer.no_backprop_mode():
                        avg_action_distrib, _ = self.shared_average_model(bs)
                    states[t] = s
                    actions[t] = a
                    action_log_probs[t] = action_distrib.log_prob(ba)
                    values[t] = v
                    action_distribs[t] = action_distrib
                    avg_action_distribs[t] = avg_action_distrib
                    rewards[t] = transition['reward']
                    mu = transition['mu']
                    action_values[t] = action_value
                    rho[t] = (action_distrib.prob(ba).data /
                              (mu.prob(ba).data + self.eps_division))
                    rho_all[t] = (action_distrib.all_prob.data /
                                  (mu.all_prob.data + self.eps_division))
                last_transition = episode[-1]
                if last_transition['is_state_terminal']:
                    R = 0
                else:
                    with chainer.no_backprop_mode():
                        last_s = last_transition['next_state']
                        action_distrib, action_value = self.model(
                            np.expand_dims(self.phi(last_s), 0))
                        last_v = compute_state_value_as_expected_action_value(
                            action_value, action_distrib)
                    R = last_v
                return self.update(R=R,
                                   t_start=0,
                                   t_stop=len(episode),
                                   states=states,
                                   rewards=rewards,
                                   actions=actions,
                                   values=values,
                                   action_log_probs=action_log_probs,
                                   action_distribs=action_distribs,
                                   avg_action_distribs=avg_action_distribs,
                                   rho=rho,
                                   rho_all=rho_all,
                                   action_values=action_values)
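
The rho terms in Example #5 are the importance ratios between the current policy and the behaviour policy mu stored with each transition, later used for truncation and bias correction; eps_division guards against division by zero. A small NumPy sketch with made-up probabilities (toy arrays standing in for the distribution objects):

# Toy NumPy sketch of the importance ratios computed above: rho compares the
# current policy pi with the behaviour policy mu for the replayed action, and
# rho_all holds the per-action ratios used for the bias-correction term.
import numpy as np

pi = np.array([0.7, 0.2, 0.1])    # current policy over 3 discrete actions
mu = np.array([0.5, 0.3, 0.2])    # behaviour policy stored in the transition
eps_division = 1e-8               # analogue of self.eps_division
rho_all = pi / (mu + eps_division)
a = 0                             # index of the replayed action
rho = rho_all[a]                  # ratio for the action actually taken
print(rho, rho_all)               # roughly 1.4 and [1.4, 0.67, 0.5]
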