def learn_from_batch(self, batch):
    # batch contains a list of episodes to learn from
    network_keys = self.ap.network_wrappers['main'].input_embedders_parameters.keys()

    # get the values for the current states
    result = self.networks['main'].online_network.predict(batch.states(network_keys))
    current_state_values = result[0]

    self.state_values.add_sample(current_state_values)

    # the targets for the state value estimator
    num_transitions = batch.size
    state_value_head_targets = np.zeros((num_transitions, 1))

    # estimate the advantage function
    action_advantages = np.zeros((num_transitions, 1))

    if self.policy_gradient_rescaler == PolicyGradientRescaler.A_VALUE:
        if batch.game_overs()[-1]:
            R = 0
        else:
            R = self.networks['main'].online_network.predict(
                last_sample(batch.next_states(network_keys)))[0]

        for i in reversed(range(num_transitions)):
            R = batch.rewards()[i] + self.ap.algorithm.discount * R
            state_value_head_targets[i] = R
            action_advantages[i] = R - current_state_values[i]

    elif self.policy_gradient_rescaler == PolicyGradientRescaler.GAE:
        # get bootstraps
        bootstrapped_value = self.networks['main'].online_network.predict(
            last_sample(batch.next_states(network_keys)))[0]
        values = np.append(current_state_values, bootstrapped_value)
        if batch.game_overs()[-1]:
            values[-1] = 0

        # get general discounted returns table
        gae_values, state_value_head_targets = \
            self.get_general_advantage_estimation_values(batch.rewards(), values)
        action_advantages = np.vstack(gae_values)
    else:
        screen.warning("WARNING: The requested policy gradient rescaler is not available")

    action_advantages = action_advantages.squeeze(axis=-1)
    actions = batch.actions()
    if not isinstance(self.spaces.action, DiscreteActionSpace) and len(actions.shape) < 2:
        actions = np.expand_dims(actions, -1)

    # train
    result = self.networks['main'].online_network.accumulate_gradients(
        {**batch.states(network_keys), 'output_1_0': actions},
        [state_value_head_targets, action_advantages])

    # logging
    total_loss, losses, unclipped_grads = result[:3]
    self.action_advantages.add_sample(action_advantages)
    self.unclipped_grads.add_sample(unclipped_grads)
    self.value_loss.add_sample(losses[0])
    self.policy_loss.add_sample(losses[1])

    return total_loss, losses, unclipped_grads
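# The GAE branch above delegates to get_general_advantage_estimation_values, whose body
# is not shown here. The following is a minimal standalone sketch of the standard GAE
# recursion that such a helper is assumed to implement: one-step TD errors, a backward
# lambda-discounted sum for the advantages, and advantages + values as the value-head
# targets. The function name and the `discount` / `gae_lambda` parameters are
# illustrative, not the agent's actual API.
import numpy as np

def gae_sketch(rewards, values, discount=0.99, gae_lambda=0.95):
    # `values` carries one extra bootstrapped entry compared to `rewards`
    rewards = np.asarray(rewards, dtype=np.float64)
    values = np.asarray(values, dtype=np.float64)
    deltas = rewards + discount * values[1:] - values[:-1]   # one-step TD errors
    advantages = np.zeros_like(deltas)
    gae = 0.0
    for i in reversed(range(len(deltas))):
        gae = deltas[i] + discount * gae_lambda * gae        # backward recursion
        advantages[i] = gae
    value_targets = advantages + values[:-1]                 # targets for the value head
    return advantages, value_targets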
def learn_from_batch(self, batch):
    # batch contains a list of episodes to learn from
    network_keys = self.ap.network_wrappers['main'].input_embedders_parameters.keys()

    # initialize the targets from the online network's Q predictions for the current states
    state_value_head_targets = self.networks['main'].online_network.predict(batch.states(network_keys))

    # the targets for the state value estimator
    if self.ap.algorithm.targets_horizon == '1-Step':
        # 1-Step Q learning
        q_st_plus_1 = self.networks['main'].target_network.predict(batch.next_states(network_keys))

        for i in reversed(range(batch.size)):
            state_value_head_targets[i][batch.actions()[i]] = \
                batch.rewards()[i] \
                + (1.0 - batch.game_overs()[i]) * self.ap.algorithm.discount * np.max(q_st_plus_1[i], 0)

    elif self.ap.algorithm.targets_horizon == 'N-Step':
        # N-Step Q learning
        if batch.game_overs()[-1]:
            R = 0
        else:
            R = np.max(self.networks['main'].target_network.predict(
                last_sample(batch.next_states(network_keys))))

        for i in reversed(range(batch.size)):
            R = batch.rewards()[i] + self.ap.algorithm.discount * R
            state_value_head_targets[i][batch.actions()[i]] = R

    else:
        assert False, 'The available values for targets_horizon are: 1-Step, N-Step'

    # train
    result = self.networks['main'].online_network.accumulate_gradients(
        batch.states(network_keys), [state_value_head_targets])

    # logging
    total_loss, losses, unclipped_grads = result[:3]
    self.value_loss.add_sample(losses[0])

    return total_loss, losses, unclipped_grads
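# A self-contained sketch of the target computed in the 'N-Step' branch above: bootstrap
# from max_a Q_target(s_T, a) unless the episode ended, then roll the discounted return
# backwards and write it into the target only for the taken action. All inputs are plain
# NumPy arrays; the function and argument names are illustrative, not the agent's API.
import numpy as np

def n_step_q_targets(q_online, q_bootstrap, rewards, actions, terminal, discount=0.99):
    targets = q_online.copy()                     # start from the online Q estimates
    R = 0.0 if terminal else np.max(q_bootstrap)  # bootstrap from the last next-state
    for i in reversed(range(len(rewards))):
        R = rewards[i] + discount * R             # accumulate the discounted return
        targets[i, actions[i]] = R                # only the taken action gets the target
    return targets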
def _learn_from_batch(self, batch):
    fetches = [self.networks['main'].online_network.output_heads[1].probability_loss,
               self.networks['main'].online_network.output_heads[1].bias_correction_loss,
               self.networks['main'].online_network.output_heads[1].kl_divergence]

    # batch contains a list of transitions to learn from
    network_keys = self.ap.network_wrappers['main'].input_embedders_parameters.keys()

    # get the values for the current states
    Q_values, policy_prob = self.networks['main'].online_network.predict(batch.states(network_keys))
    avg_policy_prob = self.networks['main'].target_network.predict(batch.states(network_keys))[1]
    current_state_values = np.sum(policy_prob * Q_values, axis=1)

    actions = batch.actions()
    num_transitions = batch.size

    Q_head_targets = Q_values

    Q_i = Q_values[np.arange(num_transitions), actions]

    # importance sampling ratios between the current policy and the behavior policy
    mu = batch.info('all_action_probabilities')
    rho = policy_prob / (mu + eps)
    rho_i = rho[np.arange(batch.size), actions]
    rho_bar = np.minimum(1.0, rho_i)

    # bootstrap Q_ret from V(s_T) = sum_a pi(a|s_T) * Q(s_T, a), unless the episode ended
    if batch.game_overs()[-1]:
        Qret = 0
    else:
        result = self.networks['main'].online_network.predict(
            last_sample(batch.next_states(network_keys)))
        Qret = np.sum(result[0] * result[1], axis=1)[0]

    # Retrace recursion: roll the target backwards through the batch
    for i in reversed(range(num_transitions)):
        Qret = batch.rewards()[i] + self.ap.algorithm.discount * Qret
        Q_head_targets[i, actions[i]] = Qret
        Qret = rho_bar[i] * (Qret - Q_i[i]) + current_state_values[i]

    Q_retrace = Q_head_targets[np.arange(num_transitions), actions]

    # train
    result = self.networks['main'].train_and_sync_networks(
        {**batch.states(network_keys),
         'output_1_0': actions,
         'output_1_1': rho,
         'output_1_2': rho_i,
         'output_1_3': Q_values,
         'output_1_4': Q_retrace,
         'output_1_5': avg_policy_prob},
        [Q_head_targets, current_state_values],
        additional_fetches=fetches)

    for network in self.networks.values():
        network.update_target_network(self.ap.algorithm.rate_for_copying_weights_to_target)

    # logging
    total_loss, losses, unclipped_grads, fetch_result = result[:4]
    self.q_loss.add_sample(losses[0])
    self.policy_loss.add_sample(losses[1])
    self.probability_loss.add_sample(fetch_result[0])
    self.bias_correction_loss.add_sample(fetch_result[1])
    self.unclipped_grads.add_sample(unclipped_grads)
    self.V_Values.add_sample(current_state_values)
    self.kl_divergence.add_sample(fetch_result[2])

    return total_loss, losses, unclipped_grads
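# A standalone sketch of the backward recursion in the loop above (the ACER-style
# retrace target): Qret <- r_i + gamma * Qret is written as the target for the taken
# action, then Qret is corrected toward V(s_i) with the truncated importance weight
# rho_bar before stepping to the previous transition. Inputs are plain NumPy arrays
# and all names here are illustrative, not the agent's API.
import numpy as np

def q_retrace_targets(q_values, state_values, rewards, actions, rho_bar,
                      bootstrap_value, terminal, discount=0.99):
    targets = q_values.copy()
    q_taken = q_values[np.arange(len(actions)), actions]   # Q(s_i, a_i) under the online net
    Qret = 0.0 if terminal else bootstrap_value             # V(s_T) unless the episode ended
    for i in reversed(range(len(rewards))):
        Qret = rewards[i] + discount * Qret                  # discounted return so far
        targets[i, actions[i]] = Qret                        # target only for the taken action
        # truncated importance-weighted correction toward the state value
        Qret = rho_bar[i] * (Qret - q_taken[i]) + state_values[i]
    return targets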