def compute_gradients(
    self, postprocessed_batch: SampleBatch
) -> Tuple[ModelGradients, Dict[str, TensorType]]:
    assert self.loss_initialized()

    # Switch on is_training flag in our batch.
    postprocessed_batch.set_training(True)

    builder = _TFRunBuilder(self.get_session(), "compute_gradients")
    fetches = self._build_compute_gradients(builder, postprocessed_batch)
    return builder.get(fetches)
def compute_log_likelihoods(
    self,
    actions: Union[List[TensorType], TensorType],
    obs_batch: Union[List[TensorType], TensorType],
    state_batches: Optional[List[TensorType]] = None,
    prev_action_batch: Optional[Union[List[TensorType], TensorType]] = None,
    prev_reward_batch: Optional[Union[List[TensorType], TensorType]] = None,
    actions_normalized: bool = True,
) -> TensorType:
    if self._log_likelihood is None:
        raise ValueError(
            "Cannot compute log-prob/likelihood without a "
            "self._log_likelihood op!"
        )

    # Exploration hook before each forward pass.
    self.exploration.before_compute_actions(
        explore=False, tf_sess=self.get_session()
    )

    builder = _TFRunBuilder(self.get_session(), "compute_log_likelihoods")

    # Normalize actions if they are not normalized yet (and the config
    # requires normalized actions).
    if not actions_normalized and self.config["normalize_actions"]:
        actions = normalize_action(actions, self.action_space_struct)

    # Feed actions (for which we want logp values) into graph.
    builder.add_feed_dict({self._action_input: actions})
    # Feed observations.
    builder.add_feed_dict({self._obs_input: obs_batch})
    # Internal states.
    state_batches = state_batches or []
    if len(self._state_inputs) != len(state_batches):
        raise ValueError(
            "Must pass in RNN state batches for placeholders {}, "
            "got {}".format(self._state_inputs, state_batches)
        )
    builder.add_feed_dict(dict(zip(self._state_inputs, state_batches)))
    if state_batches:
        builder.add_feed_dict({self._seq_lens: np.ones(len(obs_batch))})
    # Prev-action and prev-reward, if the model requires them.
    if self._prev_action_input is not None and prev_action_batch is not None:
        builder.add_feed_dict({self._prev_action_input: prev_action_batch})
    if self._prev_reward_input is not None and prev_reward_batch is not None:
        builder.add_feed_dict({self._prev_reward_input: prev_reward_batch})

    # Fetch the log_likelihoods output and return.
    fetches = builder.add_fetches([self._log_likelihood])
    return builder.get(fetches)[0]
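# Usage sketch (illustrative only, not part of this class): assuming an
# already built, loss-initialized policy instance `policy` and a collected
# rollout `batch` (both hypothetical names), one might query the
# log-likelihoods of the actions actually taken:
#
#     logp = policy.compute_log_likelihoods(
#         actions=batch[SampleBatch.ACTIONS],
#         obs_batch=batch[SampleBatch.OBS],
#         actions_normalized=True,
#     )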
def compute_actions(
    self,
    obs_batch: Union[List[TensorType], TensorType],
    state_batches: Optional[List[TensorType]] = None,
    prev_action_batch: Optional[Union[List[TensorType], TensorType]] = None,
    prev_reward_batch: Optional[Union[List[TensorType], TensorType]] = None,
    info_batch: Optional[Dict[str, list]] = None,
    episodes: Optional[List["Episode"]] = None,
    explore: Optional[bool] = None,
    timestep: Optional[int] = None,
    **kwargs,
):
    explore = explore if explore is not None else self.config["explore"]
    timestep = timestep if timestep is not None else self.global_timestep

    builder = _TFRunBuilder(self.get_session(), "compute_actions")

    input_dict = {SampleBatch.OBS: obs_batch, "is_training": False}
    if state_batches:
        for i, s in enumerate(state_batches):
            input_dict[f"state_in_{i}"] = s
    if prev_action_batch is not None:
        input_dict[SampleBatch.PREV_ACTIONS] = prev_action_batch
    if prev_reward_batch is not None:
        input_dict[SampleBatch.PREV_REWARDS] = prev_reward_batch

    to_fetch = self._build_compute_actions(
        builder, input_dict=input_dict, explore=explore, timestep=timestep
    )

    # Execute session run to get action (and other fetches).
    fetched = builder.get(to_fetch)

    # Update our global timestep by the batch size.
    self.global_timestep += (
        len(obs_batch)
        if isinstance(obs_batch, list)
        else tree.flatten(obs_batch)[0].shape[0]
    )

    return fetched
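# Usage sketch (illustrative; `policy` and `obs` are hypothetical names):
# a single inference step on a batch of observations with exploration on.
# The returned tuple mirrors the fetches built by `_build_compute_actions`,
# typically (actions, state_outs, extra_fetches):
#
#     actions, state_outs, extra_fetches = policy.compute_actions(
#         obs_batch=obs, explore=True)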
def learn_on_batch(
    self, postprocessed_batch: SampleBatch
) -> Dict[str, TensorType]:
    assert self.loss_initialized()

    # Switch on is_training flag in our batch.
    postprocessed_batch.set_training(True)

    builder = _TFRunBuilder(self.get_session(), "learn_on_batch")

    # Callback handling.
    learn_stats = {}
    self.callbacks.on_learn_on_batch(
        policy=self, train_batch=postprocessed_batch, result=learn_stats
    )

    fetches = self._build_learn_on_batch(builder, postprocessed_batch)
    stats = builder.get(fetches)
    stats.update(
        {
            "custom_metrics": learn_stats,
            NUM_AGENT_STEPS_TRAINED: postprocessed_batch.count,
        }
    )
    return stats
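# Usage sketch (illustrative; `policy` and `postprocessed_batch` are
# hypothetical): one update pass over a postprocessed train batch. The
# returned dict carries the learner stats plus the callback's custom metrics:
#
#     stats = policy.learn_on_batch(postprocessed_batch)
#     print(stats[NUM_AGENT_STEPS_TRAINED], stats["custom_metrics"])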
def compute_actions_from_input_dict(
    self,
    input_dict: Union[SampleBatch, Dict[str, TensorType]],
    explore: Optional[bool] = None,
    timestep: Optional[int] = None,
    episodes: Optional[List["Episode"]] = None,
    **kwargs,
) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
    explore = explore if explore is not None else self.config["explore"]
    timestep = timestep if timestep is not None else self.global_timestep

    # Switch off is_training flag in our batch.
    if isinstance(input_dict, SampleBatch):
        input_dict.set_training(False)
    else:
        # Deprecated dict input.
        input_dict["is_training"] = False

    builder = _TFRunBuilder(
        self.get_session(), "compute_actions_from_input_dict"
    )
    obs_batch = input_dict[SampleBatch.OBS]
    to_fetch = self._build_compute_actions(
        builder, input_dict=input_dict, explore=explore, timestep=timestep
    )

    # Execute session run to get action (and other fetches).
    fetched = builder.get(to_fetch)

    # Update our global timestep by the batch size.
    self.global_timestep += (
        len(obs_batch)
        if isinstance(obs_batch, list)
        else len(input_dict)
        if isinstance(input_dict, SampleBatch)
        else obs_batch.shape[0]
    )

    return fetched
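# Usage sketch (illustrative; `policy` and `obs` are hypothetical): the
# dict-based entry point equivalent to `compute_actions`, with all model
# inputs packed into a single SampleBatch:
#
#     fetched = policy.compute_actions_from_input_dict(
#         SampleBatch({SampleBatch.OBS: obs}), explore=False)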
def apply_gradients(self, gradients: ModelGradients) -> None:
    assert self.loss_initialized()

    builder = _TFRunBuilder(self.get_session(), "apply_gradients")
    fetches = self._build_apply_gradients(builder, gradients)
    builder.get(fetches)
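# Usage sketch (illustrative; `policy` and `postprocessed_batch` are
# hypothetical): the manual two-step alternative to `learn_on_batch`, useful
# e.g. when gradients should be aggregated across workers before applying:
#
#     grads, grad_info = policy.compute_gradients(postprocessed_batch)
#     # ... optionally aggregate/average `grads` across workers here ...
#     policy.apply_gradients(grads)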