Example #1
    def learn_from_batch(self, batch):
        # batch contains a list of episodes to learn from
        network_keys = self.ap.network_wrappers['main'].input_embedders_parameters.keys()

        # get the values for the current states
        result = self.networks['main'].online_network.predict(batch.states(network_keys))
        current_state_values = result[0]

        self.state_values.add_sample(current_state_values)

        # the targets for the state value estimator
        num_transitions = batch.size
        state_value_head_targets = np.zeros((num_transitions, 1))

        # estimate the advantage function
        action_advantages = np.zeros((num_transitions, 1))

        if self.policy_gradient_rescaler == PolicyGradientRescaler.A_VALUE:
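            # A_VALUE rescaler: accumulate discounted returns backwards through the batch,
            # bootstrapping from the predicted value of the last next state unless the
            # episode terminated there; the advantage is the return minus the state value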
            if batch.game_overs()[-1]:
                R = 0
            else:
                R = self.networks['main'].online_network.predict(last_sample(batch.next_states(network_keys)))[0]

            for i in reversed(range(num_transitions)):
                R = batch.rewards()[i] + self.ap.algorithm.discount * R
                state_value_head_targets[i] = R
                action_advantages[i] = R - current_state_values[i]

        elif self.policy_gradient_rescaler == PolicyGradientRescaler.GAE:
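            # GAE rescaler: build generalized advantage estimates from TD residuals via
            # the helper below, which also returns the value-head targets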
            # get bootstraps
            bootstrapped_value = self.networks['main'].online_network.predict(last_sample(batch.next_states(network_keys)))[0]
            values = np.append(current_state_values, bootstrapped_value)
            if batch.game_overs()[-1]:
                values[-1] = 0

            # get general discounted returns table
            gae_values, state_value_head_targets = self.get_general_advantage_estimation_values(batch.rewards(), values)
            action_advantages = np.vstack(gae_values)
        else:
            screen.warning("WARNING: The requested policy gradient rescaler is not available")

        action_advantages = action_advantages.squeeze(axis=-1)
        actions = batch.actions()
        if not isinstance(self.spaces.action, DiscreteActionSpace) and len(actions.shape) < 2:
            actions = np.expand_dims(actions, -1)

        # train
        result = self.networks['main'].online_network.accumulate_gradients(
            {**batch.states(network_keys), 'output_1_0': actions},
            [state_value_head_targets, action_advantages])

        # logging
        total_loss, losses, unclipped_grads = result[:3]
        self.action_advantages.add_sample(action_advantages)
        self.unclipped_grads.add_sample(unclipped_grads)
        self.value_loss.add_sample(losses[0])
        self.policy_loss.add_sample(losses[1])

        return total_loss, losses, unclipped_grads
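
The helper get_general_advantage_estimation_values used in the GAE branch is not shown here. As a rough reference, a standalone sketch of GAE(lambda) over a batch of rewards and bootstrapped state values could look like the following (gamma, lam and the function name are illustrative assumptions, not Coach's API):

import numpy as np

def gae_sketch(rewards, values, gamma=0.99, lam=0.95):
    # rewards: shape (T,); values: shape (T + 1,), including the bootstrapped value
    # of the state after the last transition (set to 0 when the episode terminated)
    deltas = rewards + gamma * values[1:] - values[:-1]   # one-step TD residuals
    advantages = np.zeros_like(rewards, dtype=np.float64)
    gae = 0.0
    for t in reversed(range(len(rewards))):
        gae = deltas[t] + gamma * lam * gae
        advantages[t] = gae
    # value-head targets are the advantages added back onto the state values
    return advantages, advantages + values[:-1]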
Example #2
    def learn_from_batch(self, batch):
        # batch contains a list of episodes to learn from
        network_keys = self.ap.network_wrappers['main'].input_embedders_parameters.keys()

        # start the targets from the online network's current predictions for the
        # batch states, so only the entries of the taken actions are updated below
        state_value_head_targets = self.networks['main'].online_network.predict(
            batch.states(network_keys))

        # build the targets for the value head according to the configured horizon
        if self.ap.algorithm.targets_horizon == '1-Step':
            # 1-Step Q learning
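            # standard 1-step target: r + discount * max_a' Q_target(s', a') for
            # non-terminal transitions, and just r when the episode ended there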
            q_st_plus_1 = self.networks['main'].target_network.predict(
                batch.next_states(network_keys))

            for i in reversed(range(batch.size)):
                state_value_head_targets[i][batch.actions()[i]] = \
                    batch.rewards()[i] \
                    + (1.0 - batch.game_overs()[i]) * self.ap.algorithm.discount \
                    * np.max(q_st_plus_1[i])

        elif self.ap.algorithm.targets_horizon == 'N-Step':
            # N-Step Q learning
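            # accumulate discounted returns backwards through the batch, bootstrapping
            # from the target network's maximal Q-value at the last next state unless
            # the episode terminated there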
            if batch.game_overs()[-1]:
                R = 0
            else:
                R = np.max(self.networks['main'].target_network.predict(
                    last_sample(batch.next_states(network_keys))))

            for i in reversed(range(batch.size)):
                R = batch.rewards()[i] + self.ap.algorithm.discount * R
                state_value_head_targets[i][batch.actions()[i]] = R

        else:
            raise ValueError('The available values for targets_horizon are: 1-Step, N-Step')

        # train
        result = self.networks['main'].online_network.accumulate_gradients(
            batch.states(network_keys), [state_value_head_targets])

        # logging
        total_loss, losses, unclipped_grads = result[:3]
        self.value_loss.add_sample(losses[0])

        return total_loss, losses, unclipped_grads
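
As a sanity check on the N-Step branch above, here is a tiny self-contained example of the same backward return recursion with made-up numbers (plain numpy, purely illustrative):

import numpy as np

rewards = np.array([1.0, 0.0, 2.0])  # toy rewards for a 3-transition batch
discount = 0.9
R = 5.0                              # bootstrapped max-Q of the last next state
targets = np.zeros_like(rewards)
for i in reversed(range(len(rewards))):
    R = rewards[i] + discount * R
    targets[i] = R
# targets[0] == r0 + g*r1 + g^2*r2 + g^3*R, targets[1] == r1 + g*r2 + g^2*R, ...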
Example #3
    def _learn_from_batch(self, batch):

        fetches = [
            self.networks['main'].online_network.output_heads[1].probability_loss,
            self.networks['main'].online_network.output_heads[1].bias_correction_loss,
            self.networks['main'].online_network.output_heads[1].kl_divergence
        ]

        # batch contains a list of transitions to learn from
        network_keys = self.ap.network_wrappers['main'].input_embedders_parameters.keys()

        # get the values for the current states
        Q_values, policy_prob = self.networks['main'].online_network.predict(
            batch.states(network_keys))
        avg_policy_prob = self.networks['main'].target_network.predict(
            batch.states(network_keys))[1]
        current_state_values = np.sum(policy_prob * Q_values, axis=1)

        actions = batch.actions()
        num_transitions = batch.size
        # copy so that writing the Retrace targets below does not also overwrite
        # Q_values, which is passed to the network unchanged as 'output_1_3'
        Q_head_targets = Q_values.copy()

        Q_i = Q_values[np.arange(num_transitions), actions]

        mu = batch.info('all_action_probabilities')
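        # importance sampling ratios between the current policy and the behaviour policy
        # mu that generated the batch (eps is a small constant guarding against division
        # by zero); rho_i picks out the ratio of the action that was actually taken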
        rho = policy_prob / (mu + eps)
        rho_i = rho[np.arange(batch.size), actions]

        rho_bar = np.minimum(1.0, rho_i)

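        # Retrace-style recursion: bootstrap Q_ret from the expected Q-value of the state
        # following the last transition, then sweep backwards, correcting the running
        # return with the truncated importance weights rho_bar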
        if batch.game_overs()[-1]:
            Qret = 0
        else:
            result = self.networks['main'].online_network.predict(
                last_sample(batch.next_states(network_keys)))
            Qret = np.sum(result[0] * result[1], axis=1)[0]

        for i in reversed(range(num_transitions)):
            Qret = batch.rewards()[i] + self.ap.algorithm.discount * Qret
            Q_head_targets[i, actions[i]] = Qret
            Qret = rho_bar[i] * (Qret - Q_i[i]) + current_state_values[i]

        Q_retrace = Q_head_targets[np.arange(num_transitions), actions]

        # train
        result = self.networks['main'].train_and_sync_networks(
            {
                **batch.states(network_keys),
                'output_1_0': actions,
                'output_1_1': rho,
                'output_1_2': rho_i,
                'output_1_3': Q_values,
                'output_1_4': Q_retrace,
                'output_1_5': avg_policy_prob
            },
            [Q_head_targets, current_state_values],
            additional_fetches=fetches)

        for network in self.networks.values():
            network.update_target_network(
                self.ap.algorithm.rate_for_copying_weights_to_target)

        # logging
        total_loss, losses, unclipped_grads, fetch_result = result[:4]
        self.q_loss.add_sample(losses[0])
        self.policy_loss.add_sample(losses[1])
        self.probability_loss.add_sample(fetch_result[0])
        self.bias_correction_loss.add_sample(fetch_result[1])
        self.unclipped_grads.add_sample(unclipped_grads)
        self.V_Values.add_sample(current_state_values)
        self.kl_divergence.add_sample(fetch_result[2])

        return total_loss, losses, unclipped_grads
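
For intuition, the Retrace recursion above can be written as a stripped-down standalone function (shapes and names are illustrative assumptions, not the Coach API):

import numpy as np

def retrace_targets(rewards, q_taken, state_values, rho_taken, bootstrap, discount=0.99):
    # rewards, q_taken, state_values, rho_taken: arrays of shape (T,)
    # bootstrap: expected Q-value of the state after the last transition (0 if terminal)
    rho_bar = np.minimum(1.0, rho_taken)  # truncated importance weights
    targets = np.zeros_like(rewards, dtype=np.float64)
    q_ret = bootstrap
    for i in reversed(range(len(rewards))):
        q_ret = rewards[i] + discount * q_ret
        targets[i] = q_ret
        q_ret = rho_bar[i] * (q_ret - q_taken[i]) + state_values[i]
    return targets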