def test_dict_properties_of_sample_batches(self):
    base_dict = {
        "a": np.array([1, 2, 3]),
        "b": np.array([[0.1, 0.2], [0.3, 0.4]]),
        "c": True,
    }
    batch = SampleBatch(base_dict)
    try:
        SampleBatch(base_dict)
    except AssertionError:
        pass  # expected

    keys_ = list(base_dict.keys())
    values_ = list(base_dict.values())
    items_ = list(base_dict.items())
    assert list(batch.keys()) == keys_
    assert list(batch.values()) == values_
    assert list(batch.items()) == items_

    # Add an item and check whether it's in the "added" list.
    batch["d"] = np.array(1)
    assert batch.added_keys == {"d"}, batch.added_keys
    # Access two keys and check whether they are in the "accessed" list.
    print(batch["a"], batch["b"])
    assert batch.accessed_keys == {"a", "b"}, batch.accessed_keys
    # Delete a key and check whether it's in the "deleted" list.
    del batch["c"]
    assert batch.deleted_keys == {"c"}, batch.deleted_keys

def add_postprocessed_batch_for_training(
        self, batch: SampleBatch,
        view_requirements: ViewRequirementsDict) -> None:
    """Adds a postprocessed SampleBatch (single agent) to our buffers.

    Args:
        batch (SampleBatch): An individual agent's (one trajectory)
            SampleBatch to be added to the Policy's buffers.
        view_requirements (ViewRequirementsDict): The view
            requirements for the policy. This is so we know whether a
            view-column needs to be copied at all (not needed for
            training).
    """
    for view_col, data in batch.items():
        # 1) If col is not in view_requirements, we must have a direct
        # child of the base Policy that doesn't do auto-view req creation.
        # 2) Col is in view-reqs and needed for training.
        view_req = view_requirements.get(view_col)
        if view_req is None or view_req.used_for_training:
            self.buffers[view_col].extend(data)
    # Add the agent's trajectory length to our count.
    self.agent_steps += batch.count
    # Adjust the seq-lens array depending on the incoming agent sequences.
    if self.seq_lens is not None:
        max_seq_len = self.policy.config["model"]["max_seq_len"]
        count = batch.count
        while count > 0:
            self.seq_lens.append(min(count, max_seq_len))
            count -= max_seq_len

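# A minimal, self-contained sketch (not RLlib API) of the seq-lens chunking
# loop above: a trajectory of length `count` is split into chunks of at most
# `max_seq_len`, with the final (possibly shorter) remainder kept as its own
# sequence. E.g., count=10, max_seq_len=4 -> [4, 4, 2].
def split_into_seq_lens(count: int, max_seq_len: int) -> list:
    seq_lens = []
    while count > 0:
        seq_lens.append(min(count, max_seq_len))
        count -= max_seq_len
    return seq_lens

assert split_into_seq_lens(10, 4) == [4, 4, 2]
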
def call(
        self, input_dict: SampleBatch
) -> (TensorType, List[TensorType], Dict[str, TensorType]):
    assert input_dict[SampleBatch.SEQ_LENS] is not None
    # Push obs through "unwrapped" net's `forward()` first.
    wrapped_out, _, _ = self.wrapped_keras_model(input_dict)

    # Concat. prev-action/reward if required.
    prev_a_r = []
    if self.use_n_prev_actions:
        if isinstance(self.action_space, Discrete):
            for i in range(self.use_n_prev_actions):
                prev_a_r.append(
                    one_hot(
                        input_dict[SampleBatch.PREV_ACTIONS][:, i],
                        self.action_space,
                    ))
        elif isinstance(self.action_space, MultiDiscrete):
            for i in range(0, self.use_n_prev_actions,
                           self.action_space.shape[0]):
                prev_a_r.append(
                    one_hot(
                        tf.cast(
                            input_dict[SampleBatch.PREV_ACTIONS]
                            [:, i:i + self.action_space.shape[0]],
                            tf.float32,
                        ),
                        self.action_space,
                    ))
        else:
            prev_a_r.append(
                tf.reshape(
                    tf.cast(input_dict[SampleBatch.PREV_ACTIONS],
                            tf.float32),
                    [-1, self.use_n_prev_actions * self.action_dim],
                ))
    if self.use_n_prev_rewards:
        prev_a_r.append(
            tf.reshape(
                tf.cast(input_dict[SampleBatch.PREV_REWARDS], tf.float32),
                [-1, self.use_n_prev_rewards],
            ))

    if prev_a_r:
        wrapped_out = tf.concat([wrapped_out] + prev_a_r, axis=1)

    memory_ins = [
        s for k, s in input_dict.items() if k.startswith("state_in_")
    ]
    model_out, memory_outs, value_outs = self.base_model(
        [wrapped_out] + memory_ins)
    return (
        model_out,
        memory_outs,
        {SampleBatch.VF_PREDS: tf.reshape(value_outs, [-1])},
    )

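# A minimal NumPy sketch (an assumed helper, not RLlib's `one_hot` util) of
# the Discrete branch above: each of the last `n` previous actions is one-hot
# encoded and the results are concatenated along the feature axis before
# being appended to the wrapped model's output.
import numpy as np

def one_hot_prev_actions(prev_actions: np.ndarray,
                         n_categories: int) -> np.ndarray:
    """[batch, n_prev] int actions -> [batch, n_prev * n_categories]."""
    batch, n_prev = prev_actions.shape
    out = np.zeros((batch, n_prev, n_categories), dtype=np.float32)
    rows = np.arange(batch)[:, None]
    cols = np.arange(n_prev)[None, :]
    out[rows, cols, prev_actions] = 1.0
    return out.reshape(batch, n_prev * n_categories)

# E.g., batch of 2, last 2 actions each from a Discrete(3) space:
print(one_hot_prev_actions(np.array([[0, 2], [1, 1]]), 3))
# [[1. 0. 0. 0. 0. 1.]
#  [0. 1. 0. 0. 1. 0.]]
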
def compute_actions_from_input_dict(
        self,
        input_dict: SampleBatch,
        explore: bool = None,
        timestep: Optional[int] = None,
        episodes: Optional[List["MultiAgentEpisode"]] = None,
        **kwargs) -> \
        Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
    """Computes actions from collected samples (across multiple agents).

    Uses the currently "forward-pass-registered" samples from the
    collector to construct the input_dict for the Model.

    Args:
        input_dict (SampleBatch): A SampleBatch containing the Tensors
            to compute actions from. `input_dict` already abides by the
            Policy's as well as the Model's view requirements and can
            thus be passed to the Model as-is.
        explore (bool): Whether to pick an exploitation or exploration
            action (default: None -> use self.config["explore"]).
        timestep (Optional[int]): The current (sampling) time step.
        kwargs: Forward compatibility placeholder.

    Returns:
        Tuple:
            actions (TensorType): Batch of output actions, with shape
                like [BATCH_SIZE, ACTION_SHAPE].
            state_outs (List[TensorType]): List of RNN state output
                batches, if any, each with shape [BATCH_SIZE, STATE_SIZE].
            info (dict): Dictionary of extra feature batches, if any,
                with shape like
                {"f1": [BATCH_SIZE, ...], "f2": [BATCH_SIZE, ...]}.
    """
    # Default implementation just passes obs, prev-a/r, and states on to
    # `self.compute_actions()`.
    state_batches = [
        s for k, s in input_dict.items() if k.startswith("state_in_")
    ]
    return self.compute_actions(
        input_dict[SampleBatch.OBS],
        state_batches,
        prev_action_batch=input_dict.get(SampleBatch.PREV_ACTIONS),
        prev_reward_batch=input_dict.get(SampleBatch.PREV_REWARDS),
        info_batch=input_dict.get(SampleBatch.INFOS),
        explore=explore,
        timestep=timestep,
        episodes=episodes,
        **kwargs,
    )

def call(self, input_dict: SampleBatch) -> \
        (TensorType, List[TensorType], Dict[str, TensorType]):
    obs = input_dict["obs"]
    if self.data_format == "channels_first":
        obs = tf.transpose(obs, [0, 2, 3, 1])
    # Explicit cast to float32 needed in eager.
    model_out, self._value_out = self.base_model(tf.cast(obs, tf.float32))
    state = [v for k, v in input_dict.items() if k.startswith("state_in_")]
    extra_outs = {SampleBatch.VF_PREDS: tf.reshape(self._value_out, [-1])}
    # Our last layer is already flat.
    if self.last_layer_is_flattened:
        return model_out, state, extra_outs
    # Last layer is an n x [1,1] Conv2D -> Flatten.
    else:
        return tf.squeeze(model_out, axis=[1, 2]), state, extra_outs

def add_postprocessed_batch_for_training(
        self, batch: SampleBatch,
        view_requirements: Dict[str, ViewRequirement]) -> None:
    """Adds a postprocessed SampleBatch (single agent) to our buffers.

    Args:
        batch (SampleBatch): A single agent (one trajectory) SampleBatch
            to be added to the Policy's buffers.
        view_requirements (Dict[str, ViewRequirement]): The view
            requirements for the policy. This is so we know whether a
            view-column needs to be copied at all (not needed for
            training).
    """
    for view_col, data in batch.items():
        # Skip columns that are not used for training.
        if view_col not in view_requirements or \
                not view_requirements[view_col].used_for_training:
            continue
        self.buffers[view_col].extend(data)
    # Add the agent's trajectory length to our count.
    self.count += batch.count

def add_postprocessed_batch_for_training(
        self, batch: SampleBatch,
        view_requirements: ViewRequirementsDict) -> None:
    """Adds a postprocessed SampleBatch (single agent) to our buffers.

    Args:
        batch (SampleBatch): An individual agent's (one trajectory)
            SampleBatch to be added to the Policy's buffers.
        view_requirements (ViewRequirementsDict): The view
            requirements for the policy. This is so we know whether a
            view-column needs to be copied at all (not needed for
            training).
    """
    for view_col, data in batch.items():
        # 1) If col is not in view_requirements, we must have a direct
        # child of the base Policy that doesn't do auto-view req creation.
        # 2) Col is in view-reqs and needed for training.
        if view_col not in view_requirements or \
                view_requirements[view_col].used_for_training:
            self.buffers[view_col].extend(data)
    # Add the agent's trajectory length to our count.
    self.agent_steps += batch.count

def pad_batch_to_sequences_of_same_size(
        batch: SampleBatch,
        max_seq_len: int,
        shuffle: bool = False,
        batch_divisibility_req: int = 1,
        feature_keys: Optional[List[str]] = None,
        view_requirements: Optional[ViewRequirementsDict] = None,
):
    """Applies padding to `batch` so it's choppable into same-size sequences.

    Shuffles `batch` (if desired), makes sure the divisibility requirement
    is met, then pads the batch ([B, ...]) into same-size chunks ([B, ...])
    w/o adding a time dimension (yet). Padding depends on the episodes
    found in the batch and on `max_seq_len`.

    Args:
        batch (SampleBatch): The SampleBatch object. All values in here
            have the shape [B, ...].
        max_seq_len (int): The max. sequence length to use for chopping.
        shuffle (bool): Whether to shuffle batch sequences. Shuffling may
            be done in-place. This only makes sense if you're further
            applying minibatch SGD after getting the outputs.
        batch_divisibility_req (int): The int by which the batch dimension
            must be divisible.
        feature_keys (Optional[List[str]]): An optional list of keys to
            apply sequence-chopping to. If None, use all keys in batch
            that are not "state_in/out_"-type keys.
        view_requirements (Optional[ViewRequirementsDict]): An optional
            Policy ViewRequirements dict to be able to infer whether
            e.g. dynamic max'ing should be applied over the seq_lens.
    """
    if batch_divisibility_req > 1:
        meets_divisibility_reqs = (
            len(batch[SampleBatch.CUR_OBS]) % batch_divisibility_req == 0
            # not multiagent
            and max(batch[SampleBatch.AGENT_INDEX]) == 0)
    else:
        meets_divisibility_reqs = True

    states_already_reduced_to_init = False

    # RNN/attention net case. Figure out whether we should apply dynamic
    # max'ing over the list of sequence lengths.
    if "state_in_0" in batch or "state_out_0" in batch:
        # Check whether the state inputs have already been reduced to their
        # init values at the beginning of each max_seq_len chunk.
        if batch.seq_lens is not None and \
                len(batch["state_in_0"]) == len(batch.seq_lens):
            states_already_reduced_to_init = True

        # RNN (or single timestep state-in): Set the max dynamically.
        if view_requirements["state_in_0"].shift_from is None:
            dynamic_max = True
        # Attention Nets (state inputs are over some range): No dynamic
        # maxing possible.
        else:
            dynamic_max = False
    # Multi-agent case.
    elif not meets_divisibility_reqs:
        max_seq_len = batch_divisibility_req
        dynamic_max = False
    # Simple case: No RNN/attention net, nor do we need to pad.
    else:
        if shuffle:
            batch.shuffle()
        return

    # RNN, attention net, or multi-agent case.
    state_keys = []
    feature_keys_ = feature_keys or []
    for k, v in batch.items():
        if k.startswith("state_in_"):
            state_keys.append(k)
        elif not feature_keys and not k.startswith("state_out_") and \
                k not in ["infos", "seq_lens"] and isinstance(v, np.ndarray):
            feature_keys_.append(k)

    feature_sequences, initial_states, seq_lens = \
        chop_into_sequences(
            feature_columns=[batch[k] for k in feature_keys_],
            state_columns=[batch[k] for k in state_keys],
            episode_ids=batch.get(SampleBatch.EPS_ID),
            unroll_ids=batch.get(SampleBatch.UNROLL_ID),
            agent_indices=batch.get(SampleBatch.AGENT_INDEX),
            seq_lens=getattr(batch, "seq_lens", batch.get("seq_lens")),
            max_seq_len=max_seq_len,
            dynamic_max=dynamic_max,
            states_already_reduced_to_init=states_already_reduced_to_init,
            shuffle=shuffle)

    for i, k in enumerate(feature_keys_):
        batch[k] = feature_sequences[i]
    for i, k in enumerate(state_keys):
        batch[k] = initial_states[i]
    batch["seq_lens"] = np.array(seq_lens)

    if log_once("rnn_ma_feed_dict"):
        logger.info("Padded input for RNN/Attn.Nets/MA:\n\n{}\n".format(
            summarize({
                "features": feature_sequences,
                "initial_states": initial_states,
                "seq_lens": seq_lens,
                "max_seq_len": max_seq_len,
            })))

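# A minimal NumPy sketch (an assumed helper, not RLlib's `chop_into_sequences`)
# of the padding idea above: rows belonging to different episodes are chopped
# into chunks of at most `max_seq_len`, and each chunk is zero-padded to the
# same length, yielding the flat [B, ...] layout plus a seq_lens array.
import numpy as np

def pad_episodes(feats: np.ndarray, eps_ids: np.ndarray, max_seq_len: int):
    chunks, seq_lens = [], []
    # Split the flat batch at episode boundaries, then into max_seq_len chunks.
    for eps_id in np.unique(eps_ids):
        rows = feats[eps_ids == eps_id]
        for start in range(0, len(rows), max_seq_len):
            chunk = rows[start:start + max_seq_len]
            seq_lens.append(len(chunk))
            pad = np.zeros((max_seq_len - len(chunk),) + rows.shape[1:],
                           dtype=rows.dtype)
            chunks.append(np.concatenate([chunk, pad]))
    return np.concatenate(chunks), np.array(seq_lens)

feats = np.arange(5, dtype=np.float32).reshape(5, 1)
padded, seq_lens = pad_episodes(feats, np.array([0, 0, 0, 1, 1]), 2)
print(seq_lens)  # [2 1 2] -> episode 0 chopped into [2, 1], episode 1 into [2].
print(padded.ravel())  # [0. 1. 2. 0. 3. 4.] (zero-padded after row 2).
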
def _initialize_loss_from_dummy_batch(
        self, auto_remove_unneeded_view_reqs: bool = True) -> None:
    # Test calls depend on variable init, so initialize model first.
    self.get_session().run(tf1.global_variables_initializer())

    # Fields that have not been accessed are not needed for action
    # computations -> Tag them as `used_for_compute_actions=False`.
    for key, view_req in self.view_requirements.items():
        if (not key.startswith("state_in_")
                and key not in self._input_dict.accessed_keys):
            view_req.used_for_compute_actions = False
    for key, value in self.extra_action_out_fn().items():
        self._dummy_batch[key] = get_dummy_batch_for_space(
            gym.spaces.Box(
                -1.0,
                1.0,
                shape=value.shape.as_list()[1:],
                dtype=value.dtype.name),
            batch_size=len(self._dummy_batch),
        )
        self._input_dict[key] = get_placeholder(value=value, name=key)
        if key not in self.view_requirements:
            logger.info(
                "Adding extra-action-fetch `{}` to view-reqs.".format(key))
            self.view_requirements[key] = ViewRequirement(
                space=gym.spaces.Box(
                    -1.0,
                    1.0,
                    shape=value.shape[1:],
                    dtype=value.dtype.name),
                used_for_compute_actions=False,
            )
    dummy_batch = self._dummy_batch

    logger.info("Testing `postprocess_trajectory` w/ dummy batch.")
    self.exploration.postprocess_trajectory(self, dummy_batch,
                                            self.get_session())
    _ = self.postprocess_trajectory(dummy_batch)
    # Add new columns automatically to (loss) input_dict.
    for key in dummy_batch.added_keys:
        if key not in self._input_dict:
            self._input_dict[key] = get_placeholder(
                value=dummy_batch[key], name=key)
        if key not in self.view_requirements:
            self.view_requirements[key] = ViewRequirement(
                space=gym.spaces.Box(
                    -1.0,
                    1.0,
                    shape=dummy_batch[key].shape[1:],
                    dtype=dummy_batch[key].dtype,
                ),
                used_for_compute_actions=False,
            )

    train_batch = SampleBatch(
        dict(self._input_dict, **self._loss_input_dict),
        _is_training=True,
    )

    if self._state_inputs:
        train_batch[SampleBatch.SEQ_LENS] = self._seq_lens
        self._loss_input_dict.update(
            {SampleBatch.SEQ_LENS: train_batch[SampleBatch.SEQ_LENS]})

    self._loss_input_dict.update({k: v for k, v in train_batch.items()})

    if log_once("loss_init"):
        logger.debug(
            "Initializing loss function with dummy input:\n\n{}\n".format(
                summarize(train_batch)))

    losses = self._do_loss_init(train_batch)

    all_accessed_keys = (
        train_batch.accessed_keys | dummy_batch.accessed_keys
        | dummy_batch.added_keys | set(self.model.view_requirements.keys()))

    TFPolicy._initialize_loss(
        self, losses,
        [(k, v) for k, v in train_batch.items() if k in all_accessed_keys] +
        ([(SampleBatch.SEQ_LENS, train_batch[SampleBatch.SEQ_LENS])]
         if SampleBatch.SEQ_LENS in train_batch else []),
    )

    if "is_training" in self._loss_input_dict:
        del self._loss_input_dict["is_training"]

    # Call the grads stats fn.
    # TODO: (sven) rename to simply stats_fn to match eager and torch.
    self._stats_fetches.update(
        self.grad_stats_fn(train_batch, self._grads))

    # Add new columns automatically to view-reqs.
    if auto_remove_unneeded_view_reqs:
        # Add those needed for postprocessing and training.
        all_accessed_keys = (
            train_batch.accessed_keys | dummy_batch.accessed_keys)
        # Tag those only needed for post-processing (with some exceptions).
        for key in dummy_batch.accessed_keys:
            if (key not in train_batch.accessed_keys
                    and key not in self.model.view_requirements
                    and key not in [
                        SampleBatch.EPS_ID,
                        SampleBatch.AGENT_INDEX,
                        SampleBatch.UNROLL_ID,
                        SampleBatch.DONES,
                        SampleBatch.REWARDS,
                        SampleBatch.INFOS,
                        SampleBatch.OBS_EMBEDS,
                    ]):
                if key in self.view_requirements:
                    self.view_requirements[key].used_for_training = False
                if key in self._loss_input_dict:
                    del self._loss_input_dict[key]
        # Remove those not needed at all (leave those that are needed
        # by Sampler to properly execute sample collection).
        # Also always leave DONES, REWARDS, and INFOS, no matter what.
        for key in list(self.view_requirements.keys()):
            if (key not in all_accessed_keys and key not in [
                    SampleBatch.EPS_ID,
                    SampleBatch.AGENT_INDEX,
                    SampleBatch.UNROLL_ID,
                    SampleBatch.DONES,
                    SampleBatch.REWARDS,
                    SampleBatch.INFOS,
            ] and key not in self.model.view_requirements):
                # If user deleted this key manually in postprocessing
                # fn, warn about it and do not remove from
                # view-requirements.
                if key in dummy_batch.deleted_keys:
                    logger.warning(
                        "SampleBatch key '{}' was deleted manually in "
                        "postprocessing function! RLlib will "
                        "automatically remove non-used items from the "
                        "data stream. Remove the `del` from your "
                        "postprocessing function.".format(key))
                # If we are not writing output to disk, safe to erase
                # this key to save space in the sample batch.
                elif self.config["output"] is None:
                    del self.view_requirements[key]
                if key in self._loss_input_dict:
                    del self._loss_input_dict[key]
        # Add those data_cols (again) that are missing and have
        # dependencies by view_cols.
        for key in list(self.view_requirements.keys()):
            vr = self.view_requirements[key]
            if (vr.data_col is not None
                    and vr.data_col not in self.view_requirements):
                used_for_training = vr.data_col in train_batch.accessed_keys
                self.view_requirements[vr.data_col] = ViewRequirement(
                    space=vr.space, used_for_training=used_for_training)

    self._loss_input_dict_no_rnn = {
        k: v
        for k, v in self._loss_input_dict.items()
        if (v not in self._state_inputs and v != self._seq_lens)
    }

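# A toy sketch (plain Python dicts, not the Policy/ViewRequirement API) of
# the pruning logic above: columns accessed only during postprocessing are
# tagged as not used for training, and columns accessed nowhere are dropped
# entirely, modulo the always-kept keys such as "dones"/"rewards"/"infos".
ALWAYS_KEEP = {"eps_id", "agent_index", "unroll_id",
               "dones", "rewards", "infos"}

def prune_view_requirements(view_reqs: dict, train_accessed: set,
                            postproc_accessed: set) -> dict:
    all_accessed = train_accessed | postproc_accessed
    pruned = {}
    for key, vr in view_reqs.items():
        if key in all_accessed or key in ALWAYS_KEEP:
            # Tag postprocessing-only keys as not needed for training.
            pruned[key] = dict(vr, used_for_training=key in train_accessed)
    return pruned

reqs = {k: {"used_for_training": True}
        for k in ("obs", "vf_preds", "t", "dones")}
print(prune_view_requirements(reqs, {"obs"}, {"obs", "vf_preds"}))
# -> "t" dropped; "vf_preds" kept with used_for_training=False;
#    "dones" kept because it is in ALWAYS_KEEP.
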
def add_batch(self, batch: SampleBatch) -> None:
    """Add the given batch of values to this batch."""
    for k, column in batch.items():
        self.buffers[k].extend(column)
    self.count += batch.count

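# A minimal sketch of the buffer structure assumed by `add_batch` above
# (hypothetical stand-in, not the RLlib collector class): per-column lists
# that grow as batches arrive, keyed by the SampleBatch column name.
from collections import defaultdict

class _ToyCollector:
    def __init__(self):
        self.buffers = defaultdict(list)  # column name -> list of values
        self.count = 0

    def add_batch(self, batch: dict, count: int) -> None:
        for k, column in batch.items():
            self.buffers[k].extend(column)
        self.count += count

c = _ToyCollector()
c.add_batch({"obs": [1, 2], "rewards": [0.0, 1.0]}, count=2)
c.add_batch({"obs": [3], "rewards": [0.5]}, count=1)
print(c.count, dict(c.buffers))
# 3 {'obs': [1, 2, 3], 'rewards': [0.0, 1.0, 0.5]}
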
def _initialize_loss_from_dummy_batch(
        self,
        auto_remove_unneeded_view_reqs: bool = True,
        stats_fn=None) -> None:
    # Create the optimizer/exploration optimizer here. Some initialization
    # steps (e.g. exploration postprocessing) may need this.
    self._optimizer = self.optimizer()

    # Test calls depend on variable init, so initialize model first.
    self.get_session().run(tf1.global_variables_initializer())

    logger.info("Testing `compute_actions` w/ dummy batch.")
    actions, state_outs, extra_fetches = \
        self.compute_actions_from_input_dict(
            self._dummy_batch, explore=False, timestep=0)
    for key, value in extra_fetches.items():
        self._dummy_batch[key] = value
        self._input_dict[key] = get_placeholder(value=value, name=key)
        if key not in self.view_requirements:
            logger.info("Adding extra-action-fetch `{}` to "
                        "view-reqs.".format(key))
            self.view_requirements[key] = ViewRequirement(
                space=gym.spaces.Box(
                    -1.0, 1.0, shape=value.shape[1:], dtype=value.dtype),
                used_for_compute_actions=False,
            )
    dummy_batch = self._dummy_batch

    logger.info("Testing `postprocess_trajectory` w/ dummy batch.")
    self.exploration.postprocess_trajectory(self, dummy_batch,
                                            self.get_session())
    _ = self.postprocess_trajectory(dummy_batch)
    # Add new columns automatically to (loss) input_dict.
    for key in dummy_batch.added_keys:
        if key not in self._input_dict:
            self._input_dict[key] = get_placeholder(
                value=dummy_batch[key], name=key)
        if key not in self.view_requirements:
            self.view_requirements[key] = ViewRequirement(
                space=gym.spaces.Box(
                    -1.0,
                    1.0,
                    shape=dummy_batch[key].shape[1:],
                    dtype=dummy_batch[key].dtype),
                used_for_compute_actions=False,
            )

    train_batch = SampleBatch(
        dict(self._input_dict, **self._loss_input_dict))

    if self._state_inputs:
        train_batch["seq_lens"] = self._seq_lens
        self._loss_input_dict.update({"seq_lens": train_batch["seq_lens"]})

    self._loss_input_dict.update({k: v for k, v in train_batch.items()})

    if log_once("loss_init"):
        logger.debug(
            "Initializing loss function with dummy input:\n\n{}\n".format(
                summarize(train_batch)))

    loss = self._do_loss_init(train_batch)

    all_accessed_keys = \
        train_batch.accessed_keys | dummy_batch.accessed_keys | \
        dummy_batch.added_keys | set(self.model.view_requirements.keys())

    TFPolicy._initialize_loss(
        self, loss,
        [(k, v) for k, v in train_batch.items() if k in all_accessed_keys] +
        ([("seq_lens", train_batch["seq_lens"])]
         if "seq_lens" in train_batch else []))

    if "is_training" in self._loss_input_dict:
        del self._loss_input_dict["is_training"]

    # Call the grads stats fn.
    # TODO: (sven) rename to simply stats_fn to match eager and torch.
    if self._grad_stats_fn:
        self._stats_fetches.update(
            self._grad_stats_fn(self, train_batch, self._grads))

    # Add new columns automatically to view-reqs.
    if auto_remove_unneeded_view_reqs:
        # Add those needed for postprocessing and training.
        all_accessed_keys = train_batch.accessed_keys | \
            dummy_batch.accessed_keys
        # Tag those only needed for post-processing (with some exceptions).
        for key in dummy_batch.accessed_keys:
            if key not in train_batch.accessed_keys and \
                    key not in self.model.view_requirements and \
                    key not in [
                        SampleBatch.EPS_ID, SampleBatch.AGENT_INDEX,
                        SampleBatch.UNROLL_ID, SampleBatch.DONES,
                        SampleBatch.REWARDS, SampleBatch.INFOS]:
                if key in self.view_requirements:
                    self.view_requirements[key].used_for_training = False
                if key in self._loss_input_dict:
                    del self._loss_input_dict[key]
        # Remove those not needed at all (leave those that are needed
        # by Sampler to properly execute sample collection).
        # Also always leave DONES, REWARDS, and INFOS, no matter what.
        for key in list(self.view_requirements.keys()):
            if key not in all_accessed_keys and key not in [
                    SampleBatch.EPS_ID, SampleBatch.AGENT_INDEX,
                    SampleBatch.UNROLL_ID, SampleBatch.DONES,
                    SampleBatch.REWARDS, SampleBatch.INFOS] and \
                    key not in self.model.view_requirements:
                # If user deleted this key manually in postprocessing
                # fn, warn about it and do not remove from
                # view-requirements.
                if key in dummy_batch.deleted_keys:
                    logger.warning(
                        "SampleBatch key '{}' was deleted manually in "
                        "postprocessing function! RLlib will "
                        "automatically remove non-used items from the "
                        "data stream. Remove the `del` from your "
                        "postprocessing function.".format(key))
                else:
                    del self.view_requirements[key]
                if key in self._loss_input_dict:
                    del self._loss_input_dict[key]
        # Add those data_cols (again) that are missing and have
        # dependencies by view_cols.
        for key in list(self.view_requirements.keys()):
            vr = self.view_requirements[key]
            if (vr.data_col is not None
                    and vr.data_col not in self.view_requirements):
                used_for_training = \
                    vr.data_col in train_batch.accessed_keys
                self.view_requirements[vr.data_col] = ViewRequirement(
                    space=vr.space, used_for_training=used_for_training)

    self._loss_input_dict_no_rnn = {
        k: v
        for k, v in self._loss_input_dict.items()
        if (v not in self._state_inputs and v != self._seq_lens)
    }

    # Initialize again after loss init.
    self.get_session().run(tf1.global_variables_initializer())

def _initialize_loss_from_dummy_batch(
        self,
        auto_remove_unneeded_view_reqs: bool = True,
        stats_fn=None) -> None:
    # Create the optimizer/exploration optimizer here. Some initialization
    # steps (e.g. exploration postprocessing) may need this.
    self._optimizer = self.optimizer()

    # Test calls depend on variable init, so initialize model first.
    self._sess.run(tf1.global_variables_initializer())

    if self.config["_use_trajectory_view_api"]:
        logger.info("Testing `compute_actions` w/ dummy batch.")
        actions, state_outs, extra_fetches = \
            self.compute_actions_from_input_dict(
                self._dummy_batch, explore=False, timestep=0)
        for key, value in extra_fetches.items():
            self._dummy_batch[key] = np.zeros_like(value)
            self._input_dict[key] = get_placeholder(value=value, name=key)
            if key not in self.view_requirements:
                logger.info("Adding extra-action-fetch `{}` to "
                            "view-reqs.".format(key))
                self.view_requirements[key] = \
                    ViewRequirement(space=gym.spaces.Box(
                        -1.0, 1.0, shape=value.shape[1:],
                        dtype=value.dtype))
        dummy_batch = self._dummy_batch
    else:

        def fake_array(tensor):
            shape = tensor.shape.as_list()
            shape = [s if s is not None else 1 for s in shape]
            return np.zeros(shape, dtype=tensor.dtype.as_numpy_dtype)

        dummy_batch = {
            SampleBatch.CUR_OBS: fake_array(self._obs_input),
            SampleBatch.NEXT_OBS: fake_array(self._obs_input),
            SampleBatch.DONES: np.array([False], dtype=np.bool),
            SampleBatch.ACTIONS: fake_array(
                ModelCatalog.get_action_placeholder(self.action_space)),
            SampleBatch.REWARDS: np.array([0], dtype=np.float32),
        }
        if self._obs_include_prev_action_reward:
            dummy_batch.update({
                SampleBatch.PREV_ACTIONS: fake_array(
                    self._prev_action_input),
                SampleBatch.PREV_REWARDS: fake_array(
                    self._prev_reward_input),
            })
        state_init = self.get_initial_state()
        state_batches = []
        for i, h in enumerate(state_init):
            dummy_batch["state_in_{}".format(i)] = np.expand_dims(h, 0)
            dummy_batch["state_out_{}".format(i)] = np.expand_dims(h, 0)
            state_batches.append(np.expand_dims(h, 0))
        if state_init:
            dummy_batch["seq_lens"] = np.array([1], dtype=np.int32)
        for k, v in self.extra_compute_action_fetches().items():
            dummy_batch[k] = fake_array(v)
        dummy_batch = SampleBatch(dummy_batch)

    logger.info("Testing `postprocess_trajectory` w/ dummy batch.")
    self.exploration.postprocess_trajectory(self, dummy_batch, self._sess)
    postprocessed_batch = self.postprocess_trajectory(dummy_batch)
    # Add new columns automatically to (loss) input_dict.
if self.config["_use_trajectory_view_api"]: for key in dummy_batch.added_keys: if key not in self._input_dict: self._input_dict[key] = get_placeholder( value=dummy_batch[key], name=key) if key not in self.view_requirements: self.view_requirements[key] = \ ViewRequirement(space=gym.spaces.Box( -1.0, 1.0, shape=dummy_batch[key].shape[1:], dtype=dummy_batch[key].dtype)) if not self.config["_use_trajectory_view_api"]: train_batch = SampleBatch( dict({ SampleBatch.CUR_OBS: self._obs_input, }, **self._loss_input_dict)) if self._obs_include_prev_action_reward: train_batch.update({ SampleBatch.PREV_ACTIONS: self._prev_action_input, SampleBatch.PREV_REWARDS: self._prev_reward_input, SampleBatch.CUR_OBS: self._obs_input, }) for k, v in postprocessed_batch.items(): if k in train_batch: continue elif v.dtype == np.object: continue # can't handle arbitrary objects in TF elif k == "seq_lens" or k.startswith("state_in_"): continue shape = (None, ) + v.shape[1:] dtype = np.float32 if v.dtype == np.float64 else v.dtype placeholder = tf1.placeholder(dtype, shape=shape, name=k) train_batch[k] = placeholder for i, si in enumerate(self._state_inputs): train_batch["state_in_{}".format(i)] = si else: train_batch = SampleBatch( dict(self._input_dict, **self._loss_input_dict)) if self._state_inputs: train_batch["seq_lens"] = self._seq_lens if log_once("loss_init"): logger.debug( "Initializing loss function with dummy input:\n\n{}\n".format( summarize(train_batch))) self._loss_input_dict.update({k: v for k, v in train_batch.items()}) loss = self._do_loss_init(train_batch) all_accessed_keys = \ train_batch.accessed_keys | dummy_batch.accessed_keys | \ dummy_batch.added_keys | set( self.model.view_requirements.keys()) TFPolicy._initialize_loss(self, loss, [(k, v) for k, v in train_batch.items() if k in all_accessed_keys]) if "is_training" in self._loss_input_dict: del self._loss_input_dict["is_training"] # Call the grads stats fn. # TODO: (sven) rename to simply stats_fn to match eager and torch. if self._grad_stats_fn: self._stats_fetches.update( self._grad_stats_fn(self, train_batch, self._grads)) # Add new columns automatically to view-reqs. if self.config["_use_trajectory_view_api"] and \ auto_remove_unneeded_view_reqs: # Add those needed for postprocessing and training. all_accessed_keys = train_batch.accessed_keys | \ dummy_batch.accessed_keys # Tag those only needed for post-processing (with some exceptions). for key in dummy_batch.accessed_keys: if key not in train_batch.accessed_keys and \ key not in self.model.view_requirements and \ key not in [ SampleBatch.EPS_ID, SampleBatch.AGENT_INDEX, SampleBatch.UNROLL_ID, SampleBatch.DONES, SampleBatch.REWARDS, SampleBatch.INFOS]: if key in self.view_requirements: self.view_requirements[key].used_for_training = False if key in self._loss_input_dict: del self._loss_input_dict[key] # Remove those not needed at all (leave those that are needed # by Sampler to properly execute sample collection). # Also always leave DONES, REWARDS, and INFOS, no matter what. for key in list(self.view_requirements.keys()): if key not in all_accessed_keys and key not in [ SampleBatch.EPS_ID, SampleBatch.AGENT_INDEX, SampleBatch.UNROLL_ID, SampleBatch.DONES, SampleBatch.REWARDS, SampleBatch.INFOS] and \ key not in self.model.view_requirements: # If user deleted this key manually in postprocessing # fn, warn about it and do not remove from # view-requirements. if key in dummy_batch.deleted_keys: logger.warning( "SampleBatch key '{}' was deleted manually in " "postprocessing function! 
RLlib will " "automatically remove non-used items from the " "data stream. Remove the `del` from your " "postprocessing function.".format(key)) else: del self.view_requirements[key] if key in self._loss_input_dict: del self._loss_input_dict[key] # Add those data_cols (again) that are missing and have # dependencies by view_cols. for key in list(self.view_requirements.keys()): vr = self.view_requirements[key] if (vr.data_col is not None and vr.data_col not in self.view_requirements): used_for_training = \ vr.data_col in train_batch.accessed_keys self.view_requirements[vr.data_col] = ViewRequirement( space=vr.space, used_for_training=used_for_training) self._loss_input_dict_no_rnn = { k: v for k, v in self._loss_input_dict.items() if (v not in self._state_inputs and v != self._seq_lens) } # Initialize again after loss init. self._sess.run(tf1.global_variables_initializer())
def compute_advantages(rollout: SampleBatch,
                       last_r: float,
                       gamma: float = 0.9,
                       lambda_: float = 1.0,
                       use_gae: bool = True,
                       use_critic: bool = True):
    """Given a rollout, compute its value targets and the advantages.

    Args:
        rollout (SampleBatch): SampleBatch of a single trajectory.
        last_r (float): Value estimation for the last observation.
        gamma (float): Discount factor.
        lambda_ (float): Parameter for GAE.
        use_gae (bool): Whether to use Generalized Advantage Estimation.
        use_critic (bool): Whether to use critic (value estimates).
            Setting this to False will use 0 as the baseline.

    Returns:
        SampleBatch: Object with experience from rollout and
            processed rewards.
    """
    rollout_size = len(rollout[SampleBatch.ACTIONS])

    assert SampleBatch.VF_PREDS in rollout or not use_critic, \
        "use_critic=True but values not found"
    assert use_critic or not use_gae, \
        "Can't use GAE without using a value function"

    if use_gae:
        vpred_t = np.concatenate(
            [rollout[SampleBatch.VF_PREDS],
             np.array([last_r])])
        delta_t = (rollout[SampleBatch.REWARDS] + gamma * vpred_t[1:] -
                   vpred_t[:-1])
        # This formula for the advantage comes from:
        # "Generalized Advantage Estimation":
        # https://arxiv.org/abs/1506.02438
        rollout[Postprocessing.ADVANTAGES] = discount(
            delta_t, gamma * lambda_)
        rollout[Postprocessing.VALUE_TARGETS] = (
            rollout[Postprocessing.ADVANTAGES] +
            rollout[SampleBatch.VF_PREDS]).copy().astype(np.float32)
    else:
        rewards_plus_v = np.concatenate(
            [rollout[SampleBatch.REWARDS],
             np.array([last_r])])
        discounted_returns = discount(rewards_plus_v,
                                      gamma)[:-1].copy().astype(np.float32)

        if use_critic:
            rollout[Postprocessing.ADVANTAGES] = (
                discounted_returns - rollout[SampleBatch.VF_PREDS])
            rollout[Postprocessing.VALUE_TARGETS] = discounted_returns
        else:
            rollout[Postprocessing.ADVANTAGES] = discounted_returns
            rollout[Postprocessing.VALUE_TARGETS] = np.zeros_like(
                rollout[Postprocessing.ADVANTAGES])

    rollout[Postprocessing.ADVANTAGES] = rollout[
        Postprocessing.ADVANTAGES].copy().astype(np.float32)

    assert all(val.shape[0] == rollout_size
               for key, val in rollout.items()), \
        "Rollout stacked incorrectly!"
    return rollout

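# A self-contained NumPy sketch of the GAE math used above. The `discount`
# helper is assumed to be a reverse discounted cumulative sum; the numbers
# here are toy values, not real model outputs.
import numpy as np

def discount_cumsum(x: np.ndarray, gamma: float) -> np.ndarray:
    """y[t] = x[t] + gamma * y[t + 1], computed right-to-left."""
    out = np.zeros_like(x)
    running = 0.0
    for t in reversed(range(len(x))):
        running = x[t] + gamma * running
        out[t] = running
    return out

rewards = np.array([1.0, 0.0, 1.0], dtype=np.float32)
vf_preds = np.array([0.5, 0.4, 0.3], dtype=np.float32)
gamma, lambda_, last_r = 0.99, 0.95, 0.0

# delta_t = r_t + gamma * V(s_{t+1}) - V(s_t); the advantages are the
# (gamma * lambda)-discounted cumulative sum of the deltas, and the value
# targets are advantages + V(s_t), matching the use_gae branch above.
vpred_t = np.concatenate([vf_preds, [last_r]])
delta_t = rewards + gamma * vpred_t[1:] - vpred_t[:-1]
advantages = discount_cumsum(delta_t, gamma * lambda_)
value_targets = advantages + vf_preds
print(advantages, value_targets)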