def get_exploration_optimizer(self, optimizers):
    # Create an Adam optimizer for updating the curiosity NNs, but do NOT add
    # it to the Policy's optimizers. If we added and returned it here, it
    # would be used in the Policy's update loop, which we don't want
    # (curiosity updating happens inside `postprocess_trajectory`).
    if self.framework == "torch":
        feature_params = list(self._curiosity_feature_net.parameters())
        inverse_params = list(self._curiosity_inverse_fcnet.parameters())
        forward_params = list(self._curiosity_forward_fcnet.parameters())

        # Now that the Policy's own optimizer(s) have been created (from the
        # Model's parameters only, i.e. WITHOUT the curiosity params), we can
        # add our curiosity sub-modules to the Policy's Model.
        self.model._curiosity_feature_net = self._curiosity_feature_net.to(
            self.device
        )
        self.model._curiosity_inverse_fcnet = self._curiosity_inverse_fcnet.to(
            self.device
        )
        self.model._curiosity_forward_fcnet = self._curiosity_forward_fcnet.to(
            self.device
        )
        self._optimizer = torch.optim.Adam(
            forward_params + inverse_params + feature_params, lr=self.lr
        )
    else:
        self.model._curiosity_feature_net = self._curiosity_feature_net
        self.model._curiosity_inverse_fcnet = self._curiosity_inverse_fcnet
        self.model._curiosity_forward_fcnet = self._curiosity_forward_fcnet
        # Feature net is an RLlib ModelV2; the other two are Keras Models.
        self._optimizer_var_list = (
            self._curiosity_feature_net.base_model.variables
            + self._curiosity_inverse_fcnet.variables
            + self._curiosity_forward_fcnet.variables
        )
        self._optimizer = tf1.train.AdamOptimizer(learning_rate=self.lr)

        # Create placeholders and initialize the loss (static-graph tf only).
        if self.framework == "tf":
            self._obs_ph = get_placeholder(
                space=self.model.obs_space, name="_curiosity_obs"
            )
            self._next_obs_ph = get_placeholder(
                space=self.model.obs_space, name="_curiosity_next_obs"
            )
            self._action_ph = get_placeholder(
                space=self.model.action_space, name="_curiosity_action"
            )
            (
                self._forward_l2_norm_sqared,
                self._update_op,
            ) = self._postprocess_helper_tf(
                self._obs_ph, self._next_obs_ph, self._action_ph
            )

    return optimizers
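# Usage sketch (illustrative, not part of the module above): because the
# curiosity nets are trained inside `postprocess_trajectory`, enabling them is
# purely a config matter. The keys mirror the Curiosity constructor's
# arguments; the concrete values below are assumptions, not tuned settings.
example_curiosity_config = {
    "exploration_config": {
        "type": "Curiosity",
        "eta": 1.0,          # Weight of the intrinsic (curiosity) reward.
        "lr": 0.001,         # Learning rate for the Adam optimizer built above.
        "feature_dim": 288,  # Dimensionality of the learned feature vectors.
        "sub_exploration": {"type": "StochasticSampling"},
    }
}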
def _init_state_inputs(self, existing_inputs: Dict[str, "tf1.placeholder"]):
    """Initializes the RNN-state input placeholders.

    Args:
        existing_inputs: Dict of already existing input placeholders to
            re-use (e.g. when building a tower copy of this Policy).
    """
    # Re-use the given placeholders.
    if existing_inputs:
        self._state_inputs = [
            v for k, v in existing_inputs.items() if k.startswith("state_in_")
        ]
        # Placeholder for RNN time-chunk valid lengths.
        if self._state_inputs:
            self._seq_lens = existing_inputs[SampleBatch.SEQ_LENS]
    # Create new input placeholders.
    else:
        self._state_inputs = [
            get_placeholder(
                space=vr.space,
                time_axis=not isinstance(vr.shift, int),
                name=k,
            )
            for k, vr in self.model.view_requirements.items()
            if k.startswith("state_in_")
        ]
        # Placeholder for RNN time-chunk valid lengths.
        if self._state_inputs:
            self._seq_lens = tf1.placeholder(
                dtype=tf.int32, shape=[None], name="seq_lens"
            )
def _initialize_loss_from_dummy_batch(
    self, auto_remove_unneeded_view_reqs: bool = True
) -> None:
    # Test calls depend on variable init, so initialize model first.
    self.get_session().run(tf1.global_variables_initializer())

    # Fields that have not been accessed are not needed for action
    # computations -> Tag them as `used_for_compute_actions=False`.
    for key, view_req in self.view_requirements.items():
        if (
            not key.startswith("state_in_")
            and key not in self._input_dict.accessed_keys
        ):
            view_req.used_for_compute_actions = False

    # Add placeholders and view-requirements for all extra action fetches.
    for key, value in self.extra_action_out_fn().items():
        self._dummy_batch[key] = get_dummy_batch_for_space(
            gym.spaces.Box(
                -1.0, 1.0, shape=value.shape.as_list()[1:], dtype=value.dtype.name
            ),
            batch_size=len(self._dummy_batch),
        )
        self._input_dict[key] = get_placeholder(value=value, name=key)
        if key not in self.view_requirements:
            logger.info("Adding extra-action-fetch `{}` to view-reqs.".format(key))
            self.view_requirements[key] = ViewRequirement(
                space=gym.spaces.Box(
                    -1.0, 1.0, shape=value.shape[1:], dtype=value.dtype.name
                ),
                used_for_compute_actions=False,
            )
    dummy_batch = self._dummy_batch

    logger.info("Testing `postprocess_trajectory` w/ dummy batch.")
    self.exploration.postprocess_trajectory(self, dummy_batch, self.get_session())
    _ = self.postprocess_trajectory(dummy_batch)

    # Add new columns automatically to the (loss) input_dict.
    for key in dummy_batch.added_keys:
        if key not in self._input_dict:
            self._input_dict[key] = get_placeholder(value=dummy_batch[key], name=key)
        if key not in self.view_requirements:
            self.view_requirements[key] = ViewRequirement(
                space=gym.spaces.Box(
                    -1.0,
                    1.0,
                    shape=dummy_batch[key].shape[1:],
                    dtype=dummy_batch[key].dtype,
                ),
                used_for_compute_actions=False,
            )

    train_batch = SampleBatch(
        dict(self._input_dict, **self._loss_input_dict),
        _is_training=True,
    )

    if self._state_inputs:
        train_batch[SampleBatch.SEQ_LENS] = self._seq_lens
        self._loss_input_dict.update(
            {SampleBatch.SEQ_LENS: train_batch[SampleBatch.SEQ_LENS]}
        )

    self._loss_input_dict.update({k: v for k, v in train_batch.items()})

    if log_once("loss_init"):
        logger.debug(
            "Initializing loss function with dummy input:\n\n{}\n".format(
                summarize(train_batch)
            )
        )

    losses = self._do_loss_init(train_batch)

    all_accessed_keys = (
        train_batch.accessed_keys
        | dummy_batch.accessed_keys
        | dummy_batch.added_keys
        | set(self.model.view_requirements.keys())
    )

    TFPolicy._initialize_loss(
        self,
        losses,
        [(k, v) for k, v in train_batch.items() if k in all_accessed_keys]
        + (
            [(SampleBatch.SEQ_LENS, train_batch[SampleBatch.SEQ_LENS])]
            if SampleBatch.SEQ_LENS in train_batch
            else []
        ),
    )

    if "is_training" in self._loss_input_dict:
        del self._loss_input_dict["is_training"]

    # Call the grads stats fn.
    # TODO: (sven) rename to simply stats_fn to match eager and torch.
    self._stats_fetches.update(self.grad_stats_fn(train_batch, self._grads))

    # Automatically remove or re-tag view-requirements that turned out not to
    # be needed (if enabled).
    if auto_remove_unneeded_view_reqs:
        # Keys needed for postprocessing and training.
        all_accessed_keys = train_batch.accessed_keys | dummy_batch.accessed_keys
        # Tag those only needed for post-processing (with some exceptions).
        for key in dummy_batch.accessed_keys:
            if (
                key not in train_batch.accessed_keys
                and key not in self.model.view_requirements
                and key
                not in [
                    SampleBatch.EPS_ID,
                    SampleBatch.AGENT_INDEX,
                    SampleBatch.UNROLL_ID,
                    SampleBatch.DONES,
                    SampleBatch.REWARDS,
                    SampleBatch.INFOS,
                    SampleBatch.OBS_EMBEDS,
                ]
            ):
                if key in self.view_requirements:
                    self.view_requirements[key].used_for_training = False
                if key in self._loss_input_dict:
                    del self._loss_input_dict[key]

        # Remove those not needed at all (leave those that are needed by the
        # Sampler to properly execute sample collection). Also always leave
        # DONES, REWARDS, and INFOS, no matter what.
        for key in list(self.view_requirements.keys()):
            if (
                key not in all_accessed_keys
                and key
                not in [
                    SampleBatch.EPS_ID,
                    SampleBatch.AGENT_INDEX,
                    SampleBatch.UNROLL_ID,
                    SampleBatch.DONES,
                    SampleBatch.REWARDS,
                    SampleBatch.INFOS,
                ]
                and key not in self.model.view_requirements
            ):
                # If the user deleted this key manually in the postprocessing
                # fn, warn about it and do not remove it from the
                # view-requirements.
                if key in dummy_batch.deleted_keys:
                    logger.warning(
                        "SampleBatch key '{}' was deleted manually in "
                        "postprocessing function! RLlib will "
                        "automatically remove non-used items from the "
                        "data stream. Remove the `del` from your "
                        "postprocessing function.".format(key)
                    )
                # If we are not writing output to disk, it is safe to erase
                # this key to save space in the sample batch.
                elif self.config["output"] is None:
                    del self.view_requirements[key]

                if key in self._loss_input_dict:
                    del self._loss_input_dict[key]

        # Add those data_cols (again) that are missing, but that other
        # view_cols depend on.
        for key in list(self.view_requirements.keys()):
            vr = self.view_requirements[key]
            if vr.data_col is not None and vr.data_col not in self.view_requirements:
                used_for_training = vr.data_col in train_batch.accessed_keys
                self.view_requirements[vr.data_col] = ViewRequirement(
                    space=vr.space, used_for_training=used_for_training
                )

    self._loss_input_dict_no_rnn = {
        k: v
        for k, v in self._loss_input_dict.items()
        if (v not in self._state_inputs and v != self._seq_lens)
    }
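# Standalone illustration of the two flags the pruning logic above toggles
# (imports reflect recent RLlib versions): a column may be needed only at
# training time, only at action-computation time, both, or neither; in the
# last case it can be dropped from the data stream entirely.
import gym
from ray.rllib.policy.view_requirement import ViewRequirement

value_targets_req = ViewRequirement(
    space=gym.spaces.Box(-1.0, 1.0, shape=(1,)),
    used_for_compute_actions=False,  # Not fed to the model during rollouts.
    used_for_training=True,          # Still required by the loss.
)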
def _create_input_dict_and_dummy_batch(self, view_requirements, existing_inputs):
    """Creates the input_dict and dummy_batch for loss initialization.

    Used for managing the Policy's input placeholders and for loss
    initialization.
    input_dict maps str -> tf placeholder, dummy_batch maps str -> np.ndarray.

    Args:
        view_requirements: The view requirements dict.
        existing_inputs (Dict[str, tf.placeholder]): A dict of already
            existing placeholders.

    Returns:
        Tuple[Dict[str, tf.placeholder], Dict[str, np.ndarray]]: The
            input_dict/dummy_batch tuple.
    """
    input_dict = {}
    for view_col, view_req in view_requirements.items():
        # Point state_in to the already existing self._state_inputs.
        mo = re.match(r"state_in_(\d+)", view_col)
        if mo is not None:
            input_dict[view_col] = self._state_inputs[int(mo.group(1))]
        # State-outs (no placeholders needed).
        elif view_col.startswith("state_out_"):
            continue
        # Skip action dist inputs placeholder (handled later).
        elif view_col == SampleBatch.ACTION_DIST_INPUTS:
            continue
        # This is a tower: Input placeholders already exist.
        elif view_col in existing_inputs:
            input_dict[view_col] = existing_inputs[view_col]
        # All others.
        else:
            # Create a +time-axis placeholder if the shift is not an int
            # (i.e. a range or list of ints).
            time_axis = not isinstance(view_req.shift, int)
            if view_req.used_for_training:
                # Do not flatten actions if action flattening is disabled.
                if self.config.get("_disable_action_flattening") and view_col in [
                    SampleBatch.ACTIONS,
                    SampleBatch.PREV_ACTIONS,
                ]:
                    flatten = False
                # Do not flatten observations if no preprocessor API is used.
                elif (
                    view_col in [SampleBatch.OBS, SampleBatch.NEXT_OBS]
                    and self.config["_disable_preprocessor_api"]
                ):
                    flatten = False
                # Flatten everything else.
                else:
                    flatten = True
                input_dict[view_col] = get_placeholder(
                    space=view_req.space,
                    name=view_col,
                    time_axis=time_axis,
                    flatten=flatten,
                )
    dummy_batch = self._get_dummy_batch_from_view_requirements(batch_size=32)

    return SampleBatch(input_dict, seq_lens=self._seq_lens), dummy_batch
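# Toy, standalone sketch of the str -> placeholder / str -> np.ndarray pairing
# the method above produces (simplified: the real placeholders come from
# `get_placeholder` and respect each view-requirement's space, shift, and
# flattening settings; column names and shapes here are made up).
import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1
tf1.disable_eager_execution()

toy_input_dict = {
    "obs": tf1.placeholder(tf.float32, shape=[None, 4], name="obs"),
    "actions": tf1.placeholder(tf.int64, shape=[None], name="actions"),
}
# The matching dummy batch holds numpy arrays with concrete batch dimensions.
toy_dummy_batch = {
    "obs": np.zeros((32, 4), np.float32),
    "actions": np.zeros((32,), np.int64),
}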
def __init__(
    self,
    action_space: Space,
    *,
    framework: str,
    model: ModelV2,
    embeds_dim: int = 128,
    encoder_net_config: Optional[ModelConfigDict] = None,
    beta: float = 0.2,
    beta_schedule: str = "constant",
    rho: float = 0.1,
    k_nn: int = 50,
    random_timesteps: int = 10000,
    sub_exploration: Optional[FromConfigSpec] = None,
    **kwargs
):
    """Initializes RE3.

    Args:
        action_space: The action space in which to explore.
        framework: Supports "tf" only; this implementation does not
            support torch.
        model: The Policy's model.
        embeds_dim: The dimensionality of the observation embedding
            vectors in latent space.
        encoder_net_config: Optional model configuration for the encoder
            network, producing embedding vectors from observations. This
            can be used to configure fcnet- or conv_net setups to properly
            process any observation space.
        beta: Hyperparameter to choose between exploration and
            exploitation.
        beta_schedule: Schedule to use for beta decay, one of "constant"
            or "linear_decay".
        rho: Beta decay factor, used for on-policy algorithms.
        k_nn: Number of neighbours to use for the k-NN entropy estimation.
        random_timesteps: The number of timesteps to act completely
            randomly (see [1]).
        sub_exploration: The config dict for the underlying Exploration
            to use (e.g. epsilon-greedy for DQN). If None, uses the
            FromSpecDict provided in the Policy's default config.

    Raises:
        ValueError: If the input framework is torch.
    """
    # TODO(gjoliver): Add support for PyTorch.
    if framework == "torch":
        raise ValueError("This RE3 implementation does not support Torch.")
    super().__init__(action_space, model=model, framework=framework, **kwargs)

    self.beta = beta
    self.rho = rho
    self.k_nn = k_nn
    self.embeds_dim = embeds_dim
    if encoder_net_config is None:
        encoder_net_config = self.policy_config["model"].copy()
    self.encoder_net_config = encoder_net_config

    # Auto-detect the underlying exploration functionality.
    if sub_exploration is None:
        # For discrete action spaces, use an underlying EpsilonGreedy with
        # a special schedule.
        if isinstance(self.action_space, Discrete):
            sub_exploration = {
                "type": "EpsilonGreedy",
                "epsilon_schedule": {
                    "type": "PiecewiseSchedule",
                    # Step function (see [2]).
                    "endpoints": [
                        (0, 1.0),
                        (random_timesteps + 1, 1.0),
                        (random_timesteps + 2, 0.01),
                    ],
                    "outside_value": 0.01,
                },
            }
        elif isinstance(self.action_space, Box):
            sub_exploration = {
                "type": "OrnsteinUhlenbeckNoise",
                "random_timesteps": random_timesteps,
            }
        else:
            raise NotImplementedError
    self.sub_exploration = sub_exploration

    # Create the ModelV2 embedding module / layers.
    self._encoder_net = ModelCatalog.get_model_v2(
        self.model.obs_space,
        self.action_space,
        self.embeds_dim,
        model_config=self.encoder_net_config,
        framework=self.framework,
        name="encoder_net",
    )
    if self.framework == "tf":
        self._obs_ph = get_placeholder(
            space=self.model.obs_space, name="_encoder_obs"
        )
        self._obs_embeds = tf.stop_gradient(
            self._encoder_net({SampleBatch.OBS: self._obs_ph})[0]
        )

    # This sub-module is only used to select the correct action.
    self.exploration_submodule = from_config(
        cls=Exploration,
        config=self.sub_exploration,
        action_space=self.action_space,
        framework=self.framework,
        policy_config=self.policy_config,
        model=self.model,
        num_workers=self.num_workers,
        worker_index=self.worker_index,
    )
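# Usage sketch (illustrative, not part of the class above): RE3 is enabled
# through the Policy's exploration config; the keys mirror the constructor
# arguments documented above. The concrete values are examples/assumptions,
# not recommended settings.
example_re3_config = {
    "exploration_config": {
        "type": "RE3",
        "embeds_dim": 128,            # Latent size produced by the encoder net.
        "beta": 0.2,                  # Weight of the intrinsic (entropy) reward.
        "beta_schedule": "constant",  # Or "linear_decay" (uses `rho`).
        "k_nn": 50,                   # Neighbours for the k-NN entropy estimate.
        "sub_exploration": {"type": "StochasticSampling"},
    }
}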