Example #1
    def train_value_network(self, dataset, epochs):
        loss = []
        batch = Batch(dataset)
        network_keys = self.ap.network_wrappers['critic'].input_embedders_parameters.keys()

        # * Found not to have any impact *
        # add a timestep to the observation
        # current_states_with_timestep = self.concat_state_and_timestep(dataset)

        mix_fraction = self.ap.algorithm.value_targets_mix_fraction
        for j in range(epochs):
            # LBFGS is a full-batch optimizer, so it consumes the whole dataset at once;
            # first-order optimizers are fed mini-batches of the configured batch size
            curr_batch_size = batch.size
            if self.networks['critic'].online_network.optimizer_type != 'LBFGS':
                curr_batch_size = self.ap.network_wrappers['critic'].batch_size
            for i in range(batch.size // curr_batch_size):
                # split to batches for first order optimization techniques
                current_states_batch = {
                    k: v[i * curr_batch_size:(i + 1) * curr_batch_size]
                    for k, v in batch.states(network_keys).items()
                }
                total_return_batch = batch.total_returns(True)[
                    i * curr_batch_size:(i + 1) * curr_batch_size]
                old_policy_values = force_list(
                    self.networks['critic'].target_network.predict(current_states_batch).squeeze())
                if self.networks['critic'].online_network.optimizer_type != 'LBFGS':
                    # first-order optimizers regress the critic directly towards the total returns
                    targets = total_return_batch
                else:
                    # for LBFGS, blend the current value predictions with the returns
                    # according to the configured mix fraction
                    current_values = self.networks['critic'].online_network.predict(current_states_batch)
                    targets = current_values * (1 - mix_fraction) + total_return_batch * mix_fraction

                # pass the target network's old value predictions as extra inputs when the
                # online network declares matching placeholders (named 'output_0_<index>')
                inputs = copy.copy(current_states_batch)
                for input_index, input in enumerate(old_policy_values):
                    name = 'output_0_{}'.format(input_index)
                    if name in self.networks['critic'].online_network.inputs:
                        inputs[name] = input

                # accumulate gradients on the online network, apply them (also to the global
                # network when running distributed), then clear the accumulators
                value_loss = self.networks['critic'].online_network.accumulate_gradients(
                    inputs, targets)

                self.networks['critic'].apply_gradients_to_online_network()
                if isinstance(self.ap.task_parameters, DistributedTaskParameters):
                    self.networks['critic'].apply_gradients_to_global_network()
                self.networks['critic'].online_network.reset_accumulated_gradients()

                loss.append([value_loss[0]])
        loss = np.mean(loss, 0)
        return loss
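
For reference, below is a minimal standalone sketch of the value-target mixing performed in the LBFGS branch above, assuming plain NumPy arrays. The function name mixed_value_targets and the sample numbers are illustrative, not part of the Coach API.

import numpy as np

def mixed_value_targets(current_values, total_returns, mix_fraction):
    """Blend the critic's current predictions with the empirical returns.

    mix_fraction=1.0 regresses purely towards the returns; smaller values
    keep the regression targets closer to the current predictions.
    """
    return current_values * (1.0 - mix_fraction) + total_returns * mix_fraction

# example: move the targets 10% of the way from the predictions to the returns
targets = mixed_value_targets(np.array([0.5, 1.2]), np.array([1.0, 2.0]), 0.1)
# -> array([0.55, 1.28])
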
Example #2
    def fill_advantages(self, batch):
        batch = Batch(batch)
        network_keys = self.ap.network_wrappers['critic'].input_embedders_parameters.keys()

        # * Found not to have any impact *
        # current_states_with_timestep = self.concat_state_and_timestep(batch)

        current_state_values = self.networks['critic'].online_network.predict(
            batch.states(network_keys)).squeeze()

        # calculate advantages
        advantages = []
        if self.policy_gradient_rescaler == PolicyGradientRescaler.A_VALUE:
            advantages = batch.total_returns() - current_state_values
        elif self.policy_gradient_rescaler == PolicyGradientRescaler.GAE:
            # get bootstraps
            episode_start_idx = 0
            advantages = np.array([])
            # current_state_values[batch.game_overs()] = 0
            for idx, game_over in enumerate(batch.game_overs()):
                if game_over:
                    # get advantages for the rollout, bootstrapping the value after the
                    # terminal state with 0
                    value_bootstrapping = np.zeros((1,))
                    rollout_state_values = np.append(
                        current_state_values[episode_start_idx:idx + 1], value_bootstrapping)

                    rollout_advantages, _ = self.get_general_advantage_estimation_values(
                        batch.rewards()[episode_start_idx:idx + 1], rollout_state_values)
                    episode_start_idx = idx + 1
                    advantages = np.append(advantages, rollout_advantages)
        else:
            screen.warning("WARNING: The requested policy gradient rescaler is not available")

        # standardize
        advantages = (advantages - np.mean(advantages)) / np.std(advantages)

        # TODO: this will be problematic with a shared memory
        for transition, advantage in zip(self.memory.transitions, advantages):
            transition.info['advantage'] = advantage

        self.action_advantages.add_sample(advantages)
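
For reference, below is a minimal standalone sketch of generalized advantage estimation, the quantity unpacked as rollout_advantages above. The function name gae and the gamma/lam defaults are illustrative assumptions, not Coach's exact implementation of get_general_advantage_estimation_values.

import numpy as np

def gae(rewards, state_values, gamma=0.99, lam=0.95):
    """Generalized advantage estimation for a single rollout.

    rewards:      shape (T,)   rewards r_0 .. r_{T-1}
    state_values: shape (T+1,) value estimates V(s_0) .. V(s_T); the last entry is the
                  bootstrap value (0 when the rollout ended in a terminal state)
    """
    # one-step TD errors: delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
    deltas = rewards + gamma * state_values[1:] - state_values[:-1]
    advantages = np.zeros_like(deltas)
    running = 0.0
    # accumulate discounted TD errors backwards: A_t = delta_t + gamma * lam * A_{t+1}
    for t in reversed(range(len(deltas))):
        running = deltas[t] + gamma * lam * running
        advantages[t] = running
    return advantages
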