def train_value_network(self, dataset, epochs):
    """Fit the critic to the total discounted returns of the given dataset over several epochs."""
    loss = []
    batch = Batch(dataset)
    network_keys = self.ap.network_wrappers['critic'].input_embedders_parameters.keys()

    # * Found not to have any impact *
    # add a timestep to the observation
    # current_states_with_timestep = self.concat_state_and_timestep(dataset)

    mix_fraction = self.ap.algorithm.value_targets_mix_fraction
    for j in range(epochs):
        curr_batch_size = batch.size
        if self.networks['critic'].online_network.optimizer_type != 'LBFGS':
            curr_batch_size = self.ap.network_wrappers['critic'].batch_size
        for i in range(batch.size // curr_batch_size):
            # split to batches for first order optimization techniques
            current_states_batch = {
                k: v[i * curr_batch_size:(i + 1) * curr_batch_size]
                for k, v in batch.states(network_keys).items()
            }
            total_return_batch = batch.total_returns(True)[i * curr_batch_size:(i + 1) * curr_batch_size]
            old_policy_values = force_list(
                self.networks['critic'].target_network.predict(current_states_batch).squeeze())
            if self.networks['critic'].online_network.optimizer_type != 'LBFGS':
                targets = total_return_batch
            else:
                # for LBFGS, mix the empirical returns with the current value predictions
                current_values = self.networks['critic'].online_network.predict(current_states_batch)
                targets = current_values * (1 - mix_fraction) + total_return_batch * mix_fraction

            # feed the old policy values as additional inputs where the network expects them
            inputs = copy.copy(current_states_batch)
            for input_index, input in enumerate(old_policy_values):
                name = 'output_0_{}'.format(input_index)
                if name in self.networks['critic'].online_network.inputs:
                    inputs[name] = input

            value_loss = self.networks['critic'].online_network.accumulate_gradients(inputs, targets)

            self.networks['critic'].apply_gradients_to_online_network()
            if isinstance(self.ap.task_parameters, DistributedTaskParameters):
                self.networks['critic'].apply_gradients_to_global_network()
            self.networks['critic'].online_network.reset_accumulated_gradients()

            loss.append([value_loss[0]])

    loss = np.mean(loss, 0)
    return loss
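# --- Illustrative sketch (not part of the agent above) ----------------------
# For the LBFGS branch in train_value_network, the critic targets interpolate
# between the critic's current predictions and the empirical total returns, so
# a single full-batch update cannot pull the value function arbitrarily far
# from its previous estimate. A minimal stand-in for that target computation;
# the function name and inputs are hypothetical, and the arrays are assumed to
# be NumPy ndarrays of equal shape:

def mixed_value_targets_sketch(current_values, total_returns, mix_fraction):
    # move each target only `mix_fraction` of the way from the current
    # prediction towards the observed return
    return current_values * (1.0 - mix_fraction) + total_returns * mix_fraction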
def fill_advantages(self, batch):
    """Compute an advantage estimate for every transition in the batch and store it in transition.info."""
    batch = Batch(batch)
    network_keys = self.ap.network_wrappers['critic'].input_embedders_parameters.keys()

    # * Found not to have any impact *
    # current_states_with_timestep = self.concat_state_and_timestep(batch)

    current_state_values = self.networks['critic'].online_network.predict(batch.states(network_keys)).squeeze()

    # calculate advantages
    advantages = []
    if self.policy_gradient_rescaler == PolicyGradientRescaler.A_VALUE:
        advantages = batch.total_returns() - current_state_values
    elif self.policy_gradient_rescaler == PolicyGradientRescaler.GAE:
        # get bootstraps
        episode_start_idx = 0
        advantages = np.array([])
        # current_state_values[batch.game_overs()] = 0
        for idx, game_over in enumerate(batch.game_overs()):
            if game_over:
                # get advantages for the rollout
                value_bootstrapping = np.zeros((1,))
                rollout_state_values = np.append(
                    current_state_values[episode_start_idx:idx + 1], value_bootstrapping)

                rollout_advantages, _ = \
                    self.get_general_advantage_estimation_values(batch.rewards()[episode_start_idx:idx + 1],
                                                                 rollout_state_values)
                episode_start_idx = idx + 1
                advantages = np.append(advantages, rollout_advantages)
    else:
        screen.warning("WARNING: The requested policy gradient rescaler is not available")

    # standardize
    advantages = (advantages - np.mean(advantages)) / np.std(advantages)

    # TODO: this will be problematic with a shared memory
    for transition, advantage in zip(self.memory.transitions, advantages):
        transition.info['advantage'] = advantage

    self.action_advantages.add_sample(advantages)
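# --- Illustrative sketch (not part of the agent above) ----------------------
# get_general_advantage_estimation_values is assumed to implement standard
# GAE(lambda) (Schulman et al., 2016): A_t = sum_l (gamma * lam)^l * delta_{t+l},
# with delta_t = r_t + gamma * V(s_{t+1}) - V(s_t). A minimal NumPy version is
# sketched below; `values` carries one extra bootstrap entry beyond `rewards`,
# matching rollout_state_values in fill_advantages. The function name and the
# default gamma/lam values are illustrative, not taken from the agent's
# parameters.

import numpy as np


def gae_advantages_sketch(rewards, values, gamma=0.99, lam=0.95):
    # rewards: shape (T,); values: shape (T + 1,), including the bootstrap value
    deltas = rewards + gamma * values[1:] - values[:-1]
    advantages = np.zeros_like(deltas)
    running = 0.0
    # accumulate the exponentially-weighted sum of TD errors backwards in time
    for t in reversed(range(len(deltas))):
        running = deltas[t] + gamma * lam * running
        advantages[t] = running
    return advantages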