def _get_dummy_batch_from_view_requirements(
        self, batch_size: int = 1) -> SampleBatch:
    """Builds a numpy dummy SampleBatch from `self.view_requirements`.

    One column is produced per view requirement. Tuple/Dict observation or
    action columns that are subject to flattening become a float32 zeros
    array; every other column is generated from its (unflattened) space.

    Args:
        batch_size (int): Number of rows to create for each column.

    Returns:
        SampleBatch: The dummy batch, keyed by view column name.
    """
    complex_space_types = (gym.spaces.Tuple, gym.spaces.Dict)
    dummy_columns = {}

    for col_name, requirement in self.view_requirements.items():
        source_col = requirement.data_col or col_name
        space = requirement.space

        # Decide whether this column is stored in flattened form. The config
        # flags are only looked up for complex OBS / ACTIONS columns.
        flatten = False
        if isinstance(space, complex_space_types):
            if source_col == SampleBatch.OBS:
                flatten = not self.config["_disable_preprocessor_api"]
            if not flatten and source_col == SampleBatch.ACTIONS:
                flatten = not self.config.get("_disable_action_flattening")

        # Flattened column: plain float32 zeros with the flat trailing dims.
        if flatten:
            _, flat_shape = ModelCatalog.get_action_shape(
                space, framework=self.config["framework"])
            dummy_columns[col_name] = np.zeros(
                (batch_size, ) + flat_shape[1:], np.float32)
            continue

        # Non-flattened column.
        if requirement.shift_from is not None:
            # Range of indices on the time axis, e.g. "-50:-1".
            column = get_dummy_batch_for_space(
                space,
                batch_size=batch_size,
                time_size=requirement.shift_to - requirement.shift_from + 1,
            )
        elif isinstance(requirement.shift, (list, tuple)):
            # Sequence of (possibly non-consecutive) time indices.
            column = get_dummy_batch_for_space(
                space,
                batch_size=batch_size,
                time_size=len(requirement.shift),
            )
        elif isinstance(space, gym.spaces.Space):
            # Single int shift on a proper gym space.
            column = get_dummy_batch_for_space(
                space, batch_size=batch_size, fill_value=0.0)
        else:
            # `space` is not a gym Space: repeat the object itself per row.
            column = [space] * batch_size
        dummy_columns[col_name] = column

    # Because the view requirements differ per column, the columns of the
    # resulting batch may not all have the same batch size.
    return SampleBatch(dummy_columns)
def _initialize_loss_from_dummy_batch(
        self, auto_remove_unneeded_view_reqs: bool = True) -> None:
    """Initializes the loss using the dummy batch and prunes unused columns.

    Steps performed, in order:

    1. Runs the global variables initializer in the policy's session.
    2. Tags view requirements whose key was not accessed in
       `self._input_dict` (other than `state_in_*`) as not used for
       computing actions.
    3. For each entry returned by `self.extra_action_out_fn()`, adds a dummy
       column, a placeholder and (if missing) a view requirement.
    4. Runs the exploration's and the policy's `postprocess_trajectory` on
       the dummy batch and registers placeholders / view requirements for
       every key that was added to the batch.
    5. Builds the train batch, calls `self._do_loss_init()` on it and passes
       the resulting losses to `TFPolicy._initialize_loss()`.
    6. Updates `self._stats_fetches` with the output of `grad_stats_fn`.
    7. Optionally removes view requirements and loss inputs that were not
       accessed, then sets `self._loss_input_dict_no_rnn`.

    Args:
        auto_remove_unneeded_view_reqs: If True, view requirements and loss
            inputs that were not accessed by postprocessing or the loss are
            tagged as not used for training or removed.
    """
    # Test calls depend on variable init, so initialize model first.
    self.get_session().run(tf1.global_variables_initializer())

    # Fields that have not been accessed are not needed for action
    # computations -> Tag them as `used_for_compute_actions=False`.
    for key, view_req in self.view_requirements.items():
        if (not key.startswith("state_in_")
                and key not in self._input_dict.accessed_keys):
            view_req.used_for_compute_actions = False

    # Register every extra action output as a dummy-batch column, an input
    # placeholder and (if not yet known) a view requirement.
    for key, value in self.extra_action_out_fn().items():
        # The leading dim of `value.shape` is dropped when defining the space.
        self._dummy_batch[key] = get_dummy_batch_for_space(
            gym.spaces.Box(-1.0,
                           1.0,
                           shape=value.shape.as_list()[1:],
                           dtype=value.dtype.name),
            batch_size=len(self._dummy_batch),
        )
        self._input_dict[key] = get_placeholder(value=value, name=key)
        if key not in self.view_requirements:
            logger.info(
                "Adding extra-action-fetch `{}` to view-reqs.".format(key))
            self.view_requirements[key] = ViewRequirement(
                space=gym.spaces.Box(-1.0,
                                     1.0,
                                     shape=value.shape[1:],
                                     dtype=value.dtype.name),
                used_for_compute_actions=False,
            )
    dummy_batch = self._dummy_batch

    # Run both postprocessing steps on the dummy batch; the batch's
    # `accessed_keys` / `added_keys` / `deleted_keys` are consulted below.
    logger.info("Testing `postprocess_trajectory` w/ dummy batch.")
    self.exploration.postprocess_trajectory(self, dummy_batch,
                                            self.get_session())
    _ = self.postprocess_trajectory(dummy_batch)

    # Add new columns automatically to (loss) input_dict.
    for key in dummy_batch.added_keys:
        if key not in self._input_dict:
            self._input_dict[key] = get_placeholder(value=dummy_batch[key],
                                                    name=key)
        if key not in self.view_requirements:
            self.view_requirements[key] = ViewRequirement(
                space=gym.spaces.Box(
                    -1.0,
                    1.0,
                    shape=dummy_batch[key].shape[1:],
                    dtype=dummy_batch[key].dtype,
                ),
                used_for_compute_actions=False,
            )

    # Merge inputs; on key collisions `_loss_input_dict` entries win.
    train_batch = SampleBatch(
        dict(self._input_dict, **self._loss_input_dict),
        _is_training=True,
    )

    # With state inputs present, also feed the sequence lengths.
    if self._state_inputs:
        train_batch[SampleBatch.SEQ_LENS] = self._seq_lens
        self._loss_input_dict.update(
            {SampleBatch.SEQ_LENS: train_batch[SampleBatch.SEQ_LENS]})

    self._loss_input_dict.update({k: v for k, v in train_batch.items()})

    if log_once("loss_init"):
        logger.debug(
            "Initializing loss function with dummy input:\n\n{}\n".format(
                summarize(train_batch)))

    losses = self._do_loss_init(train_batch)

    # Union of: keys read from the train batch, keys read from or added to
    # the dummy batch, and all keys of the model's view requirements.
    all_accessed_keys = (train_batch.accessed_keys
                         | dummy_batch.accessed_keys
                         | dummy_batch.added_keys
                         | set(self.model.view_requirements.keys()))

    # Pass only the accessed columns as loss inputs; the SEQ_LENS entry is
    # appended whenever it is present in the train batch.
    TFPolicy._initialize_loss(
        self,
        losses,
        [(k, v) for k, v in train_batch.items() if k in all_accessed_keys] + ([
            (SampleBatch.SEQ_LENS, train_batch[SampleBatch.SEQ_LENS])
        ] if SampleBatch.SEQ_LENS in train_batch else []),
    )

    if "is_training" in self._loss_input_dict:
        del self._loss_input_dict["is_training"]

    # Call the grads stats fn.
    # TODO: (sven) rename to simply stats_fn to match eager and torch.
    self._stats_fetches.update(self.grad_stats_fn(train_batch, self._grads))

    # Prune view requirements / loss inputs that were not accessed.
    if auto_remove_unneeded_view_reqs:
        # Add those needed for postprocessing and training.
        all_accessed_keys = train_batch.accessed_keys | dummy_batch.accessed_keys
        # Tag those only needed for post-processing (with some exceptions).
        for key in dummy_batch.accessed_keys:
            if (key not in train_batch.accessed_keys
                    and key not in self.model.view_requirements
                    and key not in [
                        SampleBatch.EPS_ID,
                        SampleBatch.AGENT_INDEX,
                        SampleBatch.UNROLL_ID,
                        SampleBatch.DONES,
                        SampleBatch.REWARDS,
                        SampleBatch.INFOS,
                        SampleBatch.OBS_EMBEDS,
                    ]):
                if key in self.view_requirements:
                    self.view_requirements[key].used_for_training = False
                if key in self._loss_input_dict:
                    del self._loss_input_dict[key]
        # Remove those not needed at all (leave those that are needed
        # by Sampler to properly execute sample collection).
        # Also always leave DONES, REWARDS, and INFOS, no matter what.
        # Iterate over a copy of the keys since entries may be deleted.
        for key in list(self.view_requirements.keys()):
            if (key not in all_accessed_keys and key not in [
                    SampleBatch.EPS_ID,
                    SampleBatch.AGENT_INDEX,
                    SampleBatch.UNROLL_ID,
                    SampleBatch.DONES,
                    SampleBatch.REWARDS,
                    SampleBatch.INFOS,
            ] and key not in self.model.view_requirements):
                # If user deleted this key manually in postprocessing
                # fn, warn about it and do not remove from
                # view-requirements.
                if key in dummy_batch.deleted_keys:
                    logger.warning(
                        "SampleBatch key '{}' was deleted manually in "
                        "postprocessing function! RLlib will "
                        "automatically remove non-used items from the "
                        "data stream. Remove the `del` from your "
                        "postprocessing function.".format(key))
                # If we are not writing output to disk, safe to erase
                # this key to save space in the sample batch.
                elif self.config["output"] is None:
                    del self.view_requirements[key]

                # NOTE(review): assumed to run for every unneeded key, not
                # only in the `elif` branch above -- confirm the nesting.
                if key in self._loss_input_dict:
                    del self._loss_input_dict[key]
        # Add those data_cols (again) that are missing and have
        # dependencies by view_cols.
        for key in list(self.view_requirements.keys()):
            vr = self.view_requirements[key]
            if (vr.data_col is not None
                    and vr.data_col not in self.view_requirements):
                used_for_training = vr.data_col in train_batch.accessed_keys
                self.view_requirements[vr.data_col] = ViewRequirement(
                    space=vr.space, used_for_training=used_for_training)

    # Loss inputs whose value is neither one of the state inputs nor the
    # seq-lens value.
    # NOTE(review): assumed to be set unconditionally, i.e. outside the
    # `auto_remove_unneeded_view_reqs` branch -- confirm.
    self._loss_input_dict_no_rnn = {
        k: v
        for k, v in self._loss_input_dict.items()
        if (v not in self._state_inputs and v != self._seq_lens)
    }
def get_inference_input_dict(
        self, policy_id: PolicyID) -> Dict[str, TensorType]:
    """Assembles the batched forward-pass input for one policy.

    For every view requirement of the policy that is used for computing
    actions, the requested time slice (or index) is read from each
    forward-pass agent's buffers and stacked into one numpy array per
    buffer component. Missing buffers are created on the fly from dummy
    data. Afterwards the inference-call bookkeeping for the policy is reset.

    Args:
        policy_id: ID of the policy to build the input for.

    Returns:
        A SampleBatch keyed by view column; an empty SampleBatch if no agent
        of this policy is waiting for a forward pass. `seq_lens` is set to
        all ones if the input contains a "state_in_0" column.
    """
    policy = self.policy_map[policy_id]
    agent_keys = self.forward_pass_agent_keys[policy_id]
    num_agents = len(agent_keys)

    # No forward pass to do -> hand back an empty batch.
    if num_agents == 0:
        return SampleBatch()

    agent_buffers = {
        agent_key: self.agent_collectors[agent_key].buffers
        for agent_key in agent_keys
    }
    # Use one agent's buffer_structs (they should all be the same).
    buffer_structs = self.agent_collectors[agent_keys[0]].buffer_structs

    input_dict = {}
    for view_col, view_req in policy.view_requirements.items():
        # Skip columns that are not needed for action computations.
        if not view_req.used_for_compute_actions:
            continue

        data_col = view_req.data_col or view_col
        # These columns are read with their time indices shifted by -1.
        if data_col in (
                SampleBatch.OBS,
                SampleBatch.ENV_ID,
                SampleBatch.EPS_ID,
                SampleBatch.AGENT_INDEX,
                SampleBatch.T,
        ):
            delta = -1
        else:
            delta = 0

        if view_req.shift_from is not None:
            # Range of shifts, e.g. "-100:0". Note: This includes index 0!
            time_indices = (view_req.shift_from + delta,
                            view_req.shift_to + delta)
        else:
            # Single shift (e.g. -1) or list of shifts, e.g. [-4, -1, 0].
            time_indices = view_req.shift + delta

        # Translate the time indices into one reusable index/slice object.
        if not isinstance(time_indices, tuple):
            # Single index (or set of indices).
            selector = time_indices
        elif time_indices[1] == -1:
            # `shift_to` == -1: Until the end (including(!) the last item).
            selector = slice(time_indices[0], None)
        else:
            # `shift_to` != -1: "Normal" range with inclusive upper bound.
            selector = slice(time_indices[0], time_indices[1] + 1)

        # Collect the selected data of all agents, one list per component.
        per_component = None
        for agent_key in agent_keys:
            # Buffer for the data does not exist yet: Build it from dummy
            # (zero) data.
            if data_col not in agent_buffers[agent_key]:
                if view_req.data_col is not None:
                    space = policy.view_requirements[view_req.data_col].space
                else:
                    space = view_req.space

                if isinstance(space, Space):
                    fill_value = get_dummy_batch_for_space(
                        space,
                        batch_size=0,
                    )
                else:
                    fill_value = space

                self.agent_collectors[agent_key]._build_buffers(
                    {data_col: fill_value})

            if per_component is None:
                num_components = len(agent_buffers[agent_keys[0]][data_col])
                per_component = [[] for _ in range(num_components)]

            for component, buffer in zip(per_component,
                                         agent_buffers[agent_key][data_col]):
                component.append(buffer[selector])

        np_data = [np.array(component) for component in per_component]
        if data_col in buffer_structs:
            # Restore the original (nested) structure of this column.
            input_dict[view_col] = tree.unflatten_as(
                buffer_structs[data_col], np_data)
        else:
            input_dict[view_col] = np_data[0]

    self._reset_inference_calls(policy_id)

    seq_lens = None
    if "state_in_0" in input_dict:
        seq_lens = np.ones(num_agents, dtype=np.int32)
    return SampleBatch(input_dict, seq_lens=seq_lens)