def get_exploration_optimizer(self, optimizers):
    # Create the Adam optimizer for the curiosity NNs here, but do NOT add
    # it to the Policy's returned optimizers. If we added and returned it
    # here, it would be used in the Policy's own update loop, which we don't
    # want (curiosity updating happens inside `postprocess_trajectory`).
    if self.framework == "torch":
        feature_params = list(self._curiosity_feature_net.parameters())
        inverse_params = list(self._curiosity_inverse_fcnet.parameters())
        forward_params = list(self._curiosity_forward_fcnet.parameters())

        # Now that the Policy's own optimizer(s) have been created (from the
        # Model's parameters only, IMPORTANT: w/o the curiosity params), we
        # can add our curiosity sub-modules to the Policy's Model.
        self.model._curiosity_feature_net = \
            self._curiosity_feature_net.to(self.device)
        self.model._curiosity_inverse_fcnet = \
            self._curiosity_inverse_fcnet.to(self.device)
        self.model._curiosity_forward_fcnet = \
            self._curiosity_forward_fcnet.to(self.device)
        self._optimizer = torch.optim.Adam(
            forward_params + inverse_params + feature_params, lr=self.lr)
    else:
        self.model._curiosity_feature_net = self._curiosity_feature_net
        self.model._curiosity_inverse_fcnet = self._curiosity_inverse_fcnet
        self.model._curiosity_forward_fcnet = self._curiosity_forward_fcnet
        # The feature net is an RLlib ModelV2, the other two are keras
        # Models.
        self._optimizer_var_list = \
            self._curiosity_feature_net.base_model.variables + \
            self._curiosity_inverse_fcnet.variables + \
            self._curiosity_forward_fcnet.variables
        self.model.register_variables(self._optimizer_var_list)
        self._optimizer = tf1.train.AdamOptimizer(learning_rate=self.lr)

        # Create placeholders and initialize the loss (static-graph tf only).
        if self.framework == "tf":
            self._obs_ph = get_placeholder(
                space=self.model.obs_space, name="_curiosity_obs")
            self._next_obs_ph = get_placeholder(
                space=self.model.obs_space, name="_curiosity_next_obs")
            self._action_ph = get_placeholder(
                space=self.model.action_space, name="_curiosity_action")
            self._forward_l2_norm_sqared, self._update_op = \
                self._postprocess_helper_tf(
                    self._obs_ph, self._next_obs_ph, self._action_ph)

    return optimizers
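# --- Illustrative sketch (not RLlib source code): the reason the curiosity
# optimizer above is held in `self._optimizer` rather than returned is that
# it gets stepped from `postprocess_trajectory()`, separately from the
# Policy's own update loop. A minimal, self-contained torch version of that
# pattern follows; all module names, sizes and the loss below are made-up
# placeholders, not RLlib's actual curiosity implementation.
import torch
import torch.nn as nn

feature_net = nn.Linear(4, 8)    # phi(obs)
inverse_net = nn.Linear(16, 2)   # [phi(obs), phi(next_obs)] -> action logits
forward_net = nn.Linear(10, 8)   # [phi(obs), one-hot action] -> phi(next_obs)

# Separate Adam over the curiosity params only (policy params excluded).
curiosity_optim = torch.optim.Adam(
    list(feature_net.parameters()) + list(inverse_net.parameters()) +
    list(forward_net.parameters()), lr=1e-3)


def curiosity_postprocess_step(obs, next_obs, actions_onehot, action_ids):
    """One curiosity update, as it would run inside postprocess_trajectory."""
    phi, next_phi = feature_net(obs), feature_net(next_obs)
    # Forward-dynamics loss: predict phi(next_obs) from phi(obs) + action.
    pred_next_phi = forward_net(torch.cat([phi, actions_onehot], dim=-1))
    forward_loss = ((pred_next_phi - next_phi.detach()) ** 2).mean()
    # Inverse-dynamics loss: predict the action from phi(obs), phi(next_obs).
    inverse_logits = inverse_net(torch.cat([phi, next_phi], dim=-1))
    inverse_loss = nn.functional.cross_entropy(inverse_logits, action_ids)
    loss = forward_loss + inverse_loss
    curiosity_optim.zero_grad()
    loss.backward()
    curiosity_optim.step()
    return loss.item()


# Example call with random data (batch of 5, obs dim 4, 2 discrete actions).
_action_ids = torch.randint(0, 2, (5, ))
curiosity_postprocess_step(
    torch.randn(5, 4), torch.randn(5, 4),
    nn.functional.one_hot(_action_ids, 2).float(), _action_ids)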
def _get_input_dict_and_dummy_batch(self, view_requirements,
                                    existing_inputs):
    """Creates input_dict and dummy_batch for loss initialization.

    Used for managing the Policy's input placeholders and for loss
    initialization.
    input_dict: str -> tf placeholder; dummy_batch: str -> np.ndarray.

    Args:
        view_requirements (ViewReqs): The view requirements dict.
        existing_inputs (Dict[str, tf.placeholder]): A dict of already
            existing placeholders.

    Returns:
        Tuple[Dict[str, tf.placeholder], Dict[str, np.ndarray]]: The
            input_dict/dummy_batch tuple.
    """
    input_dict = {}
    dummy_batch = {}
    for view_col, view_req in view_requirements.items():
        # Point state_in to the already existing self._state_inputs.
        mo = re.match(r"state_in_(\d+)", view_col)
        if mo is not None:
            input_dict[view_col] = self._state_inputs[int(mo.group(1))]
            dummy_batch[view_col] = np.zeros_like(
                [view_req.space.sample()])
        # State-outs (no placeholders needed).
        elif view_col.startswith("state_out_"):
            dummy_batch[view_col] = np.zeros_like(
                [view_req.space.sample()])
        # Skip action dist inputs placeholder (do later).
        elif view_col == SampleBatch.ACTION_DIST_INPUTS:
            continue
        elif view_col in existing_inputs:
            input_dict[view_col] = existing_inputs[view_col]
            dummy_batch[view_col] = np.zeros(
                shape=[
                    1 if s is None else s
                    for s in existing_inputs[view_col].shape.as_list()
                ],
                dtype=existing_inputs[view_col].dtype.as_numpy_dtype)
        # All others.
        else:
            if view_req.used_for_training:
                input_dict[view_col] = get_placeholder(
                    space=view_req.space, name=view_col)
                dummy_batch[view_col] = np.zeros_like(
                    [view_req.space.sample()])
    return input_dict, dummy_batch
def _get_input_dict_and_dummy_batch(self, view_requirements,
                                    existing_inputs):
    """Creates input_dict and dummy_batch for loss initialization.

    Used for managing the Policy's input placeholders and for loss
    initialization.
    input_dict: str -> tf placeholder; dummy_batch: str -> np.ndarray.

    Args:
        view_requirements (ViewReqs): The view requirements dict.
        existing_inputs (Dict[str, tf.placeholder]): A dict of already
            existing placeholders.

    Returns:
        Tuple[Dict[str, tf.placeholder], Dict[str, np.ndarray]]: The
            input_dict/dummy_batch tuple.
    """
    input_dict = {}
    for view_col, view_req in view_requirements.items():
        # Point state_in to the already existing self._state_inputs.
        mo = re.match(r"state_in_(\d+)", view_col)
        if mo is not None:
            input_dict[view_col] = self._state_inputs[int(mo.group(1))]
        # State-outs (no placeholders needed).
        elif view_col.startswith("state_out_"):
            continue
        # Skip action dist inputs placeholder (do later).
        elif view_col == SampleBatch.ACTION_DIST_INPUTS:
            continue
        elif view_col in existing_inputs:
            input_dict[view_col] = existing_inputs[view_col]
        # All others.
        else:
            time_axis = not isinstance(view_req.shift, int)
            if view_req.used_for_training:
                # Create a +time-axis placeholder if the shift is not an
                # int (range or list of ints).
                input_dict[view_col] = get_placeholder(
                    space=view_req.space, name=view_col,
                    time_axis=time_axis)
    dummy_batch = self._get_dummy_batch_from_view_requirements(
        batch_size=32)
    return input_dict, dummy_batch
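# --- Illustrative sketch (not RLlib source code): what the dummy-batch side
# produced by the helpers above roughly looks like for a simple CartPole-like
# setup. The column names ("obs", "new_obs", ...) mirror SampleBatch
# constants; the placeholder side is omitted so this runs without TensorFlow.
import numpy as np
import gym

obs_space = gym.spaces.Box(-1.0, 1.0, shape=(4, ), dtype=np.float32)
act_space = gym.spaces.Discrete(2)

dummy_batch_example = {
    "obs": np.zeros_like([obs_space.sample()]),      # shape (1, 4)
    "new_obs": np.zeros_like([obs_space.sample()]),  # shape (1, 4)
    "actions": np.zeros_like([act_space.sample()]),  # shape (1, )
    "prev_actions": np.zeros_like([act_space.sample()]),
    "rewards": np.zeros((1, ), dtype=np.float32),
    "prev_rewards": np.zeros((1, ), dtype=np.float32),
    "dones": np.zeros((1, ), dtype=np.bool_),
}
print({k: v.shape for k, v in dummy_batch_example.items()})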
def __init__(
        self,
        obs_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        config: TrainerConfigDict,
        loss_fn: Callable[
            [Policy, ModelV2, Type[TFActionDistribution], SampleBatch],
            TensorType],
        *,
        stats_fn: Optional[Callable[[Policy, SampleBatch],
                                    Dict[str, TensorType]]] = None,
        grad_stats_fn: Optional[Callable[
            [Policy, SampleBatch, ModelGradients],
            Dict[str, TensorType]]] = None,
        before_loss_init: Optional[Callable[[
            Policy, gym.spaces.Space, gym.spaces.Space, TrainerConfigDict
        ], None]] = None,
        make_model: Optional[Callable[[
            Policy, gym.spaces.Space, gym.spaces.Space, TrainerConfigDict
        ], ModelV2]] = None,
        action_sampler_fn: Optional[Callable[
            [TensorType, List[TensorType]],
            Tuple[TensorType, TensorType]]] = None,
        action_distribution_fn: Optional[Callable[
            [Policy, ModelV2, TensorType, TensorType, TensorType],
            Tuple[TensorType, type, List[TensorType]]]] = None,
        existing_inputs: Optional[Dict[str, "tf1.placeholder"]] = None,
        existing_model: Optional[ModelV2] = None,
        get_batch_divisibility_req: Optional[Callable[[Policy],
                                                      int]] = None,
        obs_include_prev_action_reward: bool = True):
    """Initialize a dynamic TF policy.

    Args:
        obs_space (gym.spaces.Space): Observation space of the policy.
        action_space (gym.spaces.Space): Action space of the policy.
        config (TrainerConfigDict): Policy-specific configuration data.
        loss_fn (Callable[[Policy, ModelV2, Type[TFActionDistribution],
            SampleBatch], TensorType]): Function that returns a loss tensor
            for the policy graph.
        stats_fn (Optional[Callable[[Policy, SampleBatch],
            Dict[str, TensorType]]]): Optional function that returns a dict
            of TF fetches given the policy and batch input tensors.
        grad_stats_fn (Optional[Callable[[Policy, SampleBatch,
            ModelGradients], Dict[str, TensorType]]]): Optional function
            that returns a dict of TF fetches given the policy, sample
            batch, and loss gradient tensors.
        before_loss_init (Optional[Callable[[Policy, gym.spaces.Space,
            gym.spaces.Space, TrainerConfigDict], None]]): Optional function
            to run prior to loss init that takes the same arguments as
            __init__.
        make_model (Optional[Callable[[Policy, gym.spaces.Space,
            gym.spaces.Space, TrainerConfigDict], ModelV2]]): Optional
            function that returns a ModelV2 object given policy, obs_space,
            action_space, and policy config. All policy variables should be
            created in this function. If not specified, a default model will
            be created.
        action_sampler_fn (Optional[Callable[[Policy, ModelV2, Dict[str,
            TensorType], TensorType, TensorType], Tuple[TensorType,
            TensorType]]]): A callable returning a sampled action and its
            log-likelihood given Policy, ModelV2, input_dict, explore,
            timestep, and is_training.
        action_distribution_fn (Optional[Callable[[Policy, ModelV2,
            Dict[str, TensorType], TensorType, TensorType],
            Tuple[TensorType, type, List[TensorType]]]]): A callable
            returning distribution inputs (parameters), a dist-class to
            generate an action distribution object from, and internal-state
            outputs (or an empty list if not applicable). Note: No
            Exploration hooks have to be called from within
            `action_distribution_fn`. It should only perform a simple
            forward pass through some model. If None, pass inputs through
            `self.model()` to get distribution inputs. The callable takes as
            inputs: Policy, ModelV2, input_dict, explore, timestep,
            is_training.
        existing_inputs (Optional[Dict[str, tf1.placeholder]]): When copying
            a policy, this specifies an existing dict of placeholders to use
            instead of defining new ones.
        existing_model (Optional[ModelV2]): When copying a policy, this
            specifies an existing model to clone and share weights with.
        get_batch_divisibility_req (Optional[Callable[[Policy], int]]):
            Optional callable that returns the divisibility requirement for
            sample batches. If None, will assume a value of 1.
        obs_include_prev_action_reward (bool): Whether to include the
            previous action and reward in the model input (default: True).
    """
    self.observation_space = obs_space
    self.action_space = action_space
    self.config = config
    self.framework = "tf"
    self._loss_fn = loss_fn
    self._stats_fn = stats_fn
    self._grad_stats_fn = grad_stats_fn
    self._obs_include_prev_action_reward = obs_include_prev_action_reward

    dist_class = dist_inputs = None
    if action_sampler_fn or action_distribution_fn:
        if not make_model:
            raise ValueError(
                "`make_model` is required if `action_sampler_fn` OR "
                "`action_distribution_fn` is given")
    else:
        dist_class, logit_dim = ModelCatalog.get_action_dist(
            action_space, self.config["model"])

    # Setup self.model.
    if existing_model:
        if isinstance(existing_model, list):
            self.model = existing_model[0]
            # TODO: (sven) hack, but works for `target_[q_]?model`.
            for i in range(1, len(existing_model)):
                setattr(self, existing_model[i][0], existing_model[i][1])
    elif make_model:
        self.model = make_model(self, obs_space, action_space, config)
    else:
        self.model = ModelCatalog.get_model_v2(
            obs_space=obs_space,
            action_space=action_space,
            num_outputs=logit_dim,
            model_config=self.config["model"],
            framework="tf")

    # Auto-update model's inference view requirements, if recurrent.
    self._update_model_view_requirements_from_init_state()

    if existing_inputs:
        self._state_inputs = [
            v for k, v in existing_inputs.items()
            if k.startswith("state_in_")
        ]
        if self._state_inputs:
            self._seq_lens = existing_inputs["seq_lens"]
    else:
        if self.config["_use_trajectory_view_api"]:
            self._state_inputs = [
                get_placeholder(
                    space=vr.space,
                    time_axis=not isinstance(vr.shift, int),
                ) for k, vr in self.model.view_requirements.items()
                if k.startswith("state_in_")
            ]
        else:
            self._state_inputs = [
                tf1.placeholder(shape=(None, ) + s.shape, dtype=s.dtype)
                for s in self.model.get_initial_state()
            ]

    # Use default settings.
    # Add NEXT_OBS, STATE_IN_0.., and others.
    self.view_requirements = self._get_default_view_requirements()
    # Combine view_requirements for Model and Policy.
    self.view_requirements.update(self.model.view_requirements)

    # Setup standard placeholders.
    if existing_inputs is not None:
        timestep = existing_inputs["timestep"]
        explore = existing_inputs["is_exploring"]
        self._input_dict, self._dummy_batch = \
            self._get_input_dict_and_dummy_batch(
                self.view_requirements, existing_inputs)
    else:
        action_ph = ModelCatalog.get_action_placeholder(action_space)
        if self.config["_use_trajectory_view_api"]:
            prev_action_ph = {}
            if SampleBatch.PREV_ACTIONS not in self.view_requirements:
                prev_action_ph = {
                    SampleBatch.PREV_ACTIONS:
                        ModelCatalog.get_action_placeholder(
                            action_space, "prev_action")
                }
            self._input_dict, self._dummy_batch = \
                self._get_input_dict_and_dummy_batch(
                    self.view_requirements,
                    dict({SampleBatch.ACTIONS: action_ph},
                         **prev_action_ph))
        else:
            prev_action_ph = ModelCatalog.get_action_placeholder(
                action_space, "prev_action")
            self._input_dict = {
                SampleBatch.CUR_OBS: tf1.placeholder(
                    tf.float32,
                    shape=[None] + list(obs_space.shape),
                    name="observation")
            }
            self._input_dict[SampleBatch.ACTIONS] = action_ph
            if self._obs_include_prev_action_reward:
                self._input_dict.update({
                    SampleBatch.PREV_ACTIONS: prev_action_ph,
                    SampleBatch.PREV_REWARDS: tf1.placeholder(
                        tf.float32, [None], name="prev_reward"),
                })
        # Placeholder for (sampling steps) timestep (int).
        timestep = tf1.placeholder_with_default(
            tf.zeros((), dtype=tf.int64), (), name="timestep")
        # Placeholder for `is_exploring` flag.
        explore = tf1.placeholder_with_default(
            True, (), name="is_exploring")

    # Placeholder for RNN time-chunk valid lengths.
    self._seq_lens = tf1.placeholder(
        dtype=tf.int32, shape=[None], name="seq_lens")
    # Placeholder for `is_training` flag.
    self._input_dict["is_training"] = self._get_is_training_placeholder()

    # Create the Exploration object to use for this Policy.
    self.exploration = self._create_exploration()

    # Fully customized action generation (e.g., custom policy).
    if action_sampler_fn:
        sampled_action, sampled_action_logp = action_sampler_fn(
            self,
            self.model,
            obs_batch=self._input_dict[SampleBatch.CUR_OBS],
            state_batches=self._state_inputs,
            seq_lens=self._seq_lens,
            prev_action_batch=self._input_dict.get(
                SampleBatch.PREV_ACTIONS),
            prev_reward_batch=self._input_dict.get(
                SampleBatch.PREV_REWARDS),
            explore=explore,
            is_training=self._input_dict["is_training"])
    # Distribution generation is customized, e.g., DQN, DDPG.
    else:
        if action_distribution_fn:
            # Try new action_distribution_fn signature, supporting
            # state_batches and seq_lens.
            in_dict = self._input_dict
            try:
                dist_inputs, dist_class, self._state_out = \
                    action_distribution_fn(
                        self,
                        self.model,
                        input_dict=in_dict,
                        state_batches=self._state_inputs,
                        seq_lens=self._seq_lens,
                        explore=explore,
                        timestep=timestep,
                        is_training=in_dict["is_training"])
            # Trying the old way (to stay backward compatible).
            # TODO: Remove in future.
            except TypeError as e:
                if "positional argument" in e.args[0] or \
                        "unexpected keyword argument" in e.args[0]:
                    dist_inputs, dist_class, self._state_out = \
                        action_distribution_fn(
                            self,
                            self.model,
                            obs_batch=in_dict[SampleBatch.CUR_OBS],
                            state_batches=self._state_inputs,
                            seq_lens=self._seq_lens,
                            prev_action_batch=in_dict.get(
                                SampleBatch.PREV_ACTIONS),
                            prev_reward_batch=in_dict.get(
                                SampleBatch.PREV_REWARDS),
                            explore=explore,
                            is_training=in_dict["is_training"])
                else:
                    raise e
        # Default distribution generation behavior:
        # Pass through model. E.g., PG, PPO.
        else:
            dist_inputs, self._state_out = self.model(
                self._input_dict, self._state_inputs, self._seq_lens)

        action_dist = dist_class(dist_inputs, self.model)

        # Using exploration to get final action (e.g. via sampling).
        sampled_action, sampled_action_logp = \
            self.exploration.get_exploration_action(
                action_distribution=action_dist,
                timestep=timestep,
                explore=explore)

    # Phase 1 init.
    sess = tf1.get_default_session() or tf1.Session()

    batch_divisibility_req = get_batch_divisibility_req(self) if \
        callable(get_batch_divisibility_req) else \
        (get_batch_divisibility_req or 1)

    super().__init__(
        observation_space=obs_space,
        action_space=action_space,
        config=config,
        sess=sess,
        obs_input=self._input_dict[SampleBatch.OBS],
        action_input=self._input_dict[SampleBatch.ACTIONS],
        sampled_action=sampled_action,
        sampled_action_logp=sampled_action_logp,
        dist_inputs=dist_inputs,
        dist_class=dist_class,
        loss=None,  # dynamically initialized on run
        loss_inputs=[],
        model=self.model,
        state_inputs=self._state_inputs,
        state_outputs=self._state_out,
        prev_action_input=self._input_dict.get(SampleBatch.PREV_ACTIONS),
        prev_reward_input=self._input_dict.get(SampleBatch.PREV_REWARDS),
        seq_lens=self._seq_lens,
        max_seq_len=config["model"]["max_seq_len"],
        batch_divisibility_req=batch_divisibility_req,
        explore=explore,
        timestep=timestep)

    # Phase 2 init.
    if before_loss_init is not None:
        before_loss_init(self, obs_space, action_space, config)

    # Loss initialization and model/postprocessing test calls.
    if not existing_inputs:
        self._initialize_loss_from_dummy_batch(
            auto_remove_unneeded_view_reqs=True)
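# --- Illustrative sketch (not RLlib source code): a user-supplied
# `action_distribution_fn` matching the keyword signature the constructor
# above tries first (policy, model, input_dict=..., state_batches=...,
# seq_lens=..., explore=..., timestep=..., is_training=...). It only does a
# plain forward pass through the model and returns (dist_inputs, dist_class,
# state_out), as the docstring requires; `Categorical` is just an example
# choice for a discrete action space.
from ray.rllib.models.tf.tf_action_dist import Categorical


def example_action_distribution_fn(policy, model, *, input_dict,
                                   state_batches, seq_lens, explore,
                                   timestep, is_training):
    # No exploration hooks here; just compute the distribution parameters.
    dist_inputs, state_out = model(input_dict, state_batches, seq_lens)
    return dist_inputs, Categorical, state_out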
def _initialize_loss_from_dummy_batch(
        self, auto_remove_unneeded_view_reqs: bool = True,
        stats_fn=None) -> None:
    # Create the optimizer/exploration optimizer here. Some initialization
    # steps (e.g. exploration postprocessing) may need this.
    self._optimizer = self.optimizer()

    # Test calls depend on variable init, so initialize model first.
    self._sess.run(tf1.global_variables_initializer())

    if self.config["_use_trajectory_view_api"]:
        logger.info("Testing `compute_actions` w/ dummy batch.")
        actions, state_outs, extra_fetches = \
            self.compute_actions_from_input_dict(
                self._dummy_batch, explore=False, timestep=0)
        for key, value in extra_fetches.items():
            self._dummy_batch[key] = np.zeros_like(value)
            self._input_dict[key] = get_placeholder(value=value, name=key)
            if key not in self.view_requirements:
                logger.info("Adding extra-action-fetch `{}` to "
                            "view-reqs.".format(key))
                self.view_requirements[key] = \
                    ViewRequirement(space=gym.spaces.Box(
                        -1.0, 1.0, shape=value.shape[1:],
                        dtype=value.dtype))
        dummy_batch = self._dummy_batch
    else:

        def fake_array(tensor):
            shape = tensor.shape.as_list()
            shape = [s if s is not None else 1 for s in shape]
            return np.zeros(shape, dtype=tensor.dtype.as_numpy_dtype)

        dummy_batch = {
            SampleBatch.CUR_OBS: fake_array(self._obs_input),
            SampleBatch.NEXT_OBS: fake_array(self._obs_input),
            SampleBatch.DONES: np.array([False], dtype=np.bool),
            SampleBatch.ACTIONS: fake_array(
                ModelCatalog.get_action_placeholder(self.action_space)),
            SampleBatch.REWARDS: np.array([0], dtype=np.float32),
        }
        if self._obs_include_prev_action_reward:
            dummy_batch.update({
                SampleBatch.PREV_ACTIONS: fake_array(
                    self._prev_action_input),
                SampleBatch.PREV_REWARDS: fake_array(
                    self._prev_reward_input),
            })
        state_init = self.get_initial_state()
        state_batches = []
        for i, h in enumerate(state_init):
            dummy_batch["state_in_{}".format(i)] = np.expand_dims(h, 0)
            dummy_batch["state_out_{}".format(i)] = np.expand_dims(h, 0)
            state_batches.append(np.expand_dims(h, 0))
        if state_init:
            dummy_batch["seq_lens"] = np.array([1], dtype=np.int32)
        for k, v in self.extra_compute_action_fetches().items():
            dummy_batch[k] = fake_array(v)
        dummy_batch = SampleBatch(dummy_batch)

    batch_for_postproc = UsageTrackingDict(dummy_batch)
    batch_for_postproc.count = dummy_batch.count
    logger.info("Testing `postprocess_trajectory` w/ dummy batch.")
    self.exploration.postprocess_trajectory(self, batch_for_postproc,
                                            self._sess)
    postprocessed_batch = self.postprocess_trajectory(batch_for_postproc)
    # Add new columns automatically to (loss) input_dict.
    if self.config["_use_trajectory_view_api"]:
        for key in batch_for_postproc.added_keys:
            if key not in self._input_dict:
                self._input_dict[key] = get_placeholder(
                    value=batch_for_postproc[key], name=key)
            if key not in self.view_requirements:
                self.view_requirements[key] = \
                    ViewRequirement(space=gym.spaces.Box(
                        -1.0, 1.0,
                        shape=batch_for_postproc[key].shape[1:],
                        dtype=batch_for_postproc[key].dtype))

    if not self.config["_use_trajectory_view_api"]:
        train_batch = UsageTrackingDict(
            dict({
                SampleBatch.CUR_OBS: self._obs_input,
            }, **self._loss_input_dict))
        if self._obs_include_prev_action_reward:
            train_batch.update({
                SampleBatch.PREV_ACTIONS: self._prev_action_input,
                SampleBatch.PREV_REWARDS: self._prev_reward_input,
                SampleBatch.CUR_OBS: self._obs_input,
            })

        for k, v in postprocessed_batch.items():
            if k in train_batch:
                continue
            elif v.dtype == np.object:
                continue  # can't handle arbitrary objects in TF
            elif k == "seq_lens" or k.startswith("state_in_"):
                continue
            shape = (None, ) + v.shape[1:]
            dtype = np.float32 if v.dtype == np.float64 else v.dtype
            placeholder = tf1.placeholder(dtype, shape=shape, name=k)
            train_batch[k] = placeholder

        for i, si in enumerate(self._state_inputs):
            train_batch["state_in_{}".format(i)] = si
    else:
        train_batch = UsageTrackingDict(
            dict(self._input_dict, **self._loss_input_dict))

    if self._state_inputs:
        train_batch["seq_lens"] = self._seq_lens

    if log_once("loss_init"):
        logger.debug(
            "Initializing loss function with dummy input:\n\n{}\n".format(
                summarize(train_batch)))

    self._loss_input_dict.update({k: v for k, v in train_batch.items()})
    loss = self._do_loss_init(train_batch)

    all_accessed_keys = \
        train_batch.accessed_keys | batch_for_postproc.accessed_keys | \
        batch_for_postproc.added_keys | set(
            self.model.view_requirements.keys())

    TFPolicy._initialize_loss(self, loss, [
        (k, v) for k, v in train_batch.items() if k in all_accessed_keys
    ])

    if "is_training" in self._loss_input_dict:
        del self._loss_input_dict["is_training"]

    # Call the grads stats fn.
    # TODO: (sven) rename to simply stats_fn to match eager and torch.
    if self._grad_stats_fn:
        self._stats_fetches.update(
            self._grad_stats_fn(self, train_batch, self._grads))

    # Add new columns automatically to view-reqs.
    if self.config["_use_trajectory_view_api"] and \
            auto_remove_unneeded_view_reqs:
        # Add those needed for postprocessing and training.
        all_accessed_keys = train_batch.accessed_keys | \
            batch_for_postproc.accessed_keys
        # Tag those only needed for post-processing (with some exceptions).
        for key in batch_for_postproc.accessed_keys:
            if key not in train_batch.accessed_keys and \
                    key not in self.model.view_requirements and \
                    key not in [
                        SampleBatch.EPS_ID, SampleBatch.AGENT_INDEX,
                        SampleBatch.UNROLL_ID, SampleBatch.DONES,
                        SampleBatch.REWARDS, SampleBatch.INFOS]:
                if key in self.view_requirements:
                    self.view_requirements[key].used_for_training = False
                if key in self._loss_input_dict:
                    del self._loss_input_dict[key]
        # Remove those not needed at all (leave those that are needed
        # by Sampler to properly execute sample collection).
        # Also always leave DONES, REWARDS, and INFOS, no matter what.
        for key in list(self.view_requirements.keys()):
            if key not in all_accessed_keys and key not in [
                SampleBatch.EPS_ID, SampleBatch.AGENT_INDEX,
                SampleBatch.UNROLL_ID, SampleBatch.DONES,
                SampleBatch.REWARDS, SampleBatch.INFOS] and \
                    key not in self.model.view_requirements:
                # If user deleted this key manually in postprocessing
                # fn, warn about it and do not remove from
                # view-requirements.
                if key in batch_for_postproc.deleted_keys:
                    logger.warning(
                        "SampleBatch key '{}' was deleted manually in "
                        "postprocessing function! RLlib will "
                        "automatically remove non-used items from the "
                        "data stream. Remove the `del` from your "
                        "postprocessing function.".format(key))
                else:
                    del self.view_requirements[key]
                if key in self._loss_input_dict:
                    del self._loss_input_dict[key]
        # Re-add those data_cols that are missing, but that other
        # view_cols depend on.
        for key in list(self.view_requirements.keys()):
            vr = self.view_requirements[key]
            if (vr.data_col is not None
                    and vr.data_col not in self.view_requirements):
                used_for_training = \
                    vr.data_col in train_batch.accessed_keys
                self.view_requirements[vr.data_col] = ViewRequirement(
                    space=vr.space, used_for_training=used_for_training)

    self._loss_input_dict_no_rnn = {
        k: v
        for k, v in self._loss_input_dict.items()
        if (v not in self._state_inputs and v != self._seq_lens)
    }

    # Initialize again after loss init.
    self._sess.run(tf1.global_variables_initializer())
def _initialize_loss_from_dummy_batch(
        self, auto_remove_unneeded_view_reqs: bool = True,
        stats_fn=None) -> None:
    # Create the optimizer/exploration optimizer here. Some initialization
    # steps (e.g. exploration postprocessing) may need this.
    self._optimizer = self.optimizer()

    # Test calls depend on variable init, so initialize model first.
    self.get_session().run(tf1.global_variables_initializer())

    logger.info("Testing `compute_actions` w/ dummy batch.")
    actions, state_outs, extra_fetches = \
        self.compute_actions_from_input_dict(
            self._dummy_batch, explore=False, timestep=0)
    for key, value in extra_fetches.items():
        self._dummy_batch[key] = value
        self._input_dict[key] = get_placeholder(value=value, name=key)
        if key not in self.view_requirements:
            logger.info("Adding extra-action-fetch `{}` to "
                        "view-reqs.".format(key))
            self.view_requirements[key] = ViewRequirement(
                space=gym.spaces.Box(
                    -1.0, 1.0, shape=value.shape[1:], dtype=value.dtype),
                used_for_compute_actions=False,
            )
    dummy_batch = self._dummy_batch

    logger.info("Testing `postprocess_trajectory` w/ dummy batch.")
    self.exploration.postprocess_trajectory(self, dummy_batch,
                                            self.get_session())
    _ = self.postprocess_trajectory(dummy_batch)

    # Add new columns automatically to (loss) input_dict.
    for key in dummy_batch.added_keys:
        if key not in self._input_dict:
            self._input_dict[key] = get_placeholder(
                value=dummy_batch[key], name=key)
        if key not in self.view_requirements:
            self.view_requirements[key] = ViewRequirement(
                space=gym.spaces.Box(
                    -1.0, 1.0,
                    shape=dummy_batch[key].shape[1:],
                    dtype=dummy_batch[key].dtype),
                used_for_compute_actions=False,
            )

    train_batch = SampleBatch(
        dict(self._input_dict, **self._loss_input_dict))

    if self._state_inputs:
        train_batch["seq_lens"] = self._seq_lens
        self._loss_input_dict.update({"seq_lens": train_batch["seq_lens"]})

    self._loss_input_dict.update({k: v for k, v in train_batch.items()})

    if log_once("loss_init"):
        logger.debug(
            "Initializing loss function with dummy input:\n\n{}\n".format(
                summarize(train_batch)))

    loss = self._do_loss_init(train_batch)

    all_accessed_keys = \
        train_batch.accessed_keys | dummy_batch.accessed_keys | \
        dummy_batch.added_keys | set(
            self.model.view_requirements.keys())

    TFPolicy._initialize_loss(self, loss, [
        (k, v) for k, v in train_batch.items() if k in all_accessed_keys
    ] + ([("seq_lens", train_batch["seq_lens"])]
         if "seq_lens" in train_batch else []))

    if "is_training" in self._loss_input_dict:
        del self._loss_input_dict["is_training"]

    # Call the grads stats fn.
    # TODO: (sven) rename to simply stats_fn to match eager and torch.
    if self._grad_stats_fn:
        self._stats_fetches.update(
            self._grad_stats_fn(self, train_batch, self._grads))

    # Add new columns automatically to view-reqs.
    if auto_remove_unneeded_view_reqs:
        # Add those needed for postprocessing and training.
        all_accessed_keys = train_batch.accessed_keys | \
            dummy_batch.accessed_keys
        # Tag those only needed for post-processing (with some exceptions).
        for key in dummy_batch.accessed_keys:
            if key not in train_batch.accessed_keys and \
                    key not in self.model.view_requirements and \
                    key not in [
                        SampleBatch.EPS_ID, SampleBatch.AGENT_INDEX,
                        SampleBatch.UNROLL_ID, SampleBatch.DONES,
                        SampleBatch.REWARDS, SampleBatch.INFOS]:
                if key in self.view_requirements:
                    self.view_requirements[key].used_for_training = False
                if key in self._loss_input_dict:
                    del self._loss_input_dict[key]
        # Remove those not needed at all (leave those that are needed
        # by Sampler to properly execute sample collection).
        # Also always leave DONES, REWARDS, and INFOS, no matter what.
        for key in list(self.view_requirements.keys()):
            if key not in all_accessed_keys and key not in [
                SampleBatch.EPS_ID, SampleBatch.AGENT_INDEX,
                SampleBatch.UNROLL_ID, SampleBatch.DONES,
                SampleBatch.REWARDS, SampleBatch.INFOS] and \
                    key not in self.model.view_requirements:
                # If user deleted this key manually in postprocessing
                # fn, warn about it and do not remove from
                # view-requirements.
                if key in dummy_batch.deleted_keys:
                    logger.warning(
                        "SampleBatch key '{}' was deleted manually in "
                        "postprocessing function! RLlib will "
                        "automatically remove non-used items from the "
                        "data stream. Remove the `del` from your "
                        "postprocessing function.".format(key))
                else:
                    del self.view_requirements[key]
                if key in self._loss_input_dict:
                    del self._loss_input_dict[key]
        # Re-add those data_cols that are missing, but that other
        # view_cols depend on.
        for key in list(self.view_requirements.keys()):
            vr = self.view_requirements[key]
            if (vr.data_col is not None
                    and vr.data_col not in self.view_requirements):
                used_for_training = \
                    vr.data_col in train_batch.accessed_keys
                self.view_requirements[vr.data_col] = ViewRequirement(
                    space=vr.space, used_for_training=used_for_training)

    self._loss_input_dict_no_rnn = {
        k: v
        for k, v in self._loss_input_dict.items()
        if (v not in self._state_inputs and v != self._seq_lens)
    }

    # Initialize again after loss init.
    self.get_session().run(tf1.global_variables_initializer())
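# --- Illustrative sketch (not RLlib source code): what a "data_col
# dependency" handled by the re-adding loop above looks like. A view column
# such as "prev_actions" does not collect its own data; it is built from the
# "actions" data column shifted by -1, so the "actions" entry must remain (or
# be re-added) in the view-requirements even if the loss never accesses
# "actions" directly. The action space below is just an example.
import gym
from ray.rllib.policy.view_requirement import ViewRequirement

prev_actions_req = ViewRequirement(
    data_col="actions",
    shift=-1,
    space=gym.spaces.Discrete(2))
print(prev_actions_req.data_col, prev_actions_req.shift)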
def __init__(
        self,
        obs_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        config: TrainerConfigDict,
        loss_fn: Callable[
            [Policy, ModelV2, Type[TFActionDistribution], SampleBatch],
            TensorType],
        *,
        stats_fn: Optional[Callable[[Policy, SampleBatch],
                                    Dict[str, TensorType]]] = None,
        grad_stats_fn: Optional[Callable[
            [Policy, SampleBatch, ModelGradients],
            Dict[str, TensorType]]] = None,
        before_loss_init: Optional[Callable[[
            Policy, gym.spaces.Space, gym.spaces.Space, TrainerConfigDict
        ], None]] = None,
        make_model: Optional[Callable[[
            Policy, gym.spaces.Space, gym.spaces.Space, TrainerConfigDict
        ], ModelV2]] = None,
        action_sampler_fn: Optional[Callable[
            [TensorType, List[TensorType]],
            Tuple[TensorType, TensorType]]] = None,
        action_distribution_fn: Optional[Callable[
            [Policy, ModelV2, TensorType, TensorType, TensorType],
            Tuple[TensorType, type, List[TensorType]]]] = None,
        existing_inputs: Optional[Dict[str, "tf1.placeholder"]] = None,
        existing_model: Optional[ModelV2] = None,
        get_batch_divisibility_req: Optional[Callable[[Policy],
                                                      int]] = None,
        obs_include_prev_action_reward=DEPRECATED_VALUE):
    """Initializes a DynamicTFPolicy instance.

    Args:
        obs_space (gym.spaces.Space): Observation space of the policy.
        action_space (gym.spaces.Space): Action space of the policy.
        config (TrainerConfigDict): Policy-specific configuration data.
        loss_fn (Callable[[Policy, ModelV2, Type[TFActionDistribution],
            SampleBatch], TensorType]): Function that returns a loss tensor
            for the policy graph.
        stats_fn (Optional[Callable[[Policy, SampleBatch],
            Dict[str, TensorType]]]): Optional function that returns a dict
            of TF fetches given the policy and batch input tensors.
        grad_stats_fn (Optional[Callable[[Policy, SampleBatch,
            ModelGradients], Dict[str, TensorType]]]): Optional function
            that returns a dict of TF fetches given the policy, sample
            batch, and loss gradient tensors.
        before_loss_init (Optional[Callable[[Policy, gym.spaces.Space,
            gym.spaces.Space, TrainerConfigDict], None]]): Optional function
            to run prior to loss init that takes the same arguments as
            __init__.
        make_model (Optional[Callable[[Policy, gym.spaces.Space,
            gym.spaces.Space, TrainerConfigDict], ModelV2]]): Optional
            function that returns a ModelV2 object given policy, obs_space,
            action_space, and policy config. All policy variables should be
            created in this function. If not specified, a default model will
            be created.
        action_sampler_fn (Optional[Callable[[Policy, ModelV2, Dict[str,
            TensorType], TensorType, TensorType], Tuple[TensorType,
            TensorType]]]): A callable returning a sampled action and its
            log-likelihood given Policy, ModelV2, input_dict, explore,
            timestep, and is_training.
        action_distribution_fn (Optional[Callable[[Policy, ModelV2,
            Dict[str, TensorType], TensorType, TensorType],
            Tuple[TensorType, type, List[TensorType]]]]): A callable
            returning distribution inputs (parameters), a dist-class to
            generate an action distribution object from, and internal-state
            outputs (or an empty list if not applicable). Note: No
            Exploration hooks have to be called from within
            `action_distribution_fn`. It should only perform a simple
            forward pass through some model. If None, pass inputs through
            `self.model()` to get distribution inputs. The callable takes as
            inputs: Policy, ModelV2, input_dict, explore, timestep,
            is_training.
        existing_inputs (Optional[Dict[str, tf1.placeholder]]): When copying
            a policy, this specifies an existing dict of placeholders to use
            instead of defining new ones.
        existing_model (Optional[ModelV2]): When copying a policy, this
            specifies an existing model to clone and share weights with.
        get_batch_divisibility_req (Optional[Callable[[Policy], int]]):
            Optional callable that returns the divisibility requirement for
            sample batches. If None, will assume a value of 1.
    """
    if obs_include_prev_action_reward != DEPRECATED_VALUE:
        deprecation_warning(
            old="obs_include_prev_action_reward", error=False)
    self.observation_space = obs_space
    self.action_space = action_space
    self.config = config
    self.framework = "tf"
    self._loss_fn = loss_fn
    self._stats_fn = stats_fn
    self._grad_stats_fn = grad_stats_fn
    self._seq_lens = None
    self._is_tower = existing_inputs is not None

    dist_class = None
    if action_sampler_fn or action_distribution_fn:
        if not make_model:
            raise ValueError(
                "`make_model` is required if `action_sampler_fn` OR "
                "`action_distribution_fn` is given")
    else:
        dist_class, logit_dim = ModelCatalog.get_action_dist(
            action_space, self.config["model"])

    # Setup self.model.
    if existing_model:
        if isinstance(existing_model, list):
            self.model = existing_model[0]
            # TODO: (sven) hack, but works for `target_[q_]?model`.
            for i in range(1, len(existing_model)):
                setattr(self, existing_model[i][0], existing_model[i][1])
    elif make_model:
        self.model = make_model(self, obs_space, action_space, config)
    else:
        self.model = ModelCatalog.get_model_v2(
            obs_space=obs_space,
            action_space=action_space,
            num_outputs=logit_dim,
            model_config=self.config["model"],
            framework="tf")

    # Auto-update model's inference view requirements, if recurrent.
    self._update_model_view_requirements_from_init_state()

    # Input placeholders already given -> Use these.
    if existing_inputs:
        self._state_inputs = [
            v for k, v in existing_inputs.items()
            if k.startswith("state_in_")
        ]
        # Placeholder for RNN time-chunk valid lengths.
        if self._state_inputs:
            self._seq_lens = existing_inputs[SampleBatch.SEQ_LENS]
    # Create new input placeholders.
    else:
        self._state_inputs = [
            get_placeholder(
                space=vr.space,
                time_axis=not isinstance(vr.shift, int),
            ) for k, vr in self.model.view_requirements.items()
            if k.startswith("state_in_")
        ]
        # Placeholder for RNN time-chunk valid lengths.
        if self._state_inputs:
            self._seq_lens = tf1.placeholder(
                dtype=tf.int32, shape=[None], name="seq_lens")

    # Use default settings.
    # Add NEXT_OBS, STATE_IN_0.., and others.
    self.view_requirements = self._get_default_view_requirements()
    # Combine view_requirements for Model and Policy.
    self.view_requirements.update(self.model.view_requirements)
    # Disable env-info placeholder.
    if SampleBatch.INFOS in self.view_requirements:
        self.view_requirements[SampleBatch.INFOS].used_for_training = False

    # Setup standard placeholders.
    if self._is_tower:
        timestep = existing_inputs["timestep"]
        explore = False
        self._input_dict, self._dummy_batch = \
            self._get_input_dict_and_dummy_batch(
                self.view_requirements, existing_inputs)
    else:
        action_ph = ModelCatalog.get_action_placeholder(action_space)
        prev_action_ph = {}
        if SampleBatch.PREV_ACTIONS not in self.view_requirements:
            prev_action_ph = {
                SampleBatch.PREV_ACTIONS:
                    ModelCatalog.get_action_placeholder(
                        action_space, "prev_action")
            }
        self._input_dict, self._dummy_batch = \
            self._get_input_dict_and_dummy_batch(
                self.view_requirements,
                dict({SampleBatch.ACTIONS: action_ph},
                     **prev_action_ph))
        # Placeholder for (sampling steps) timestep (int).
        timestep = tf1.placeholder_with_default(
            tf.zeros((), dtype=tf.int64), (), name="timestep")
        # Placeholder for `is_exploring` flag.
        explore = tf1.placeholder_with_default(
            True, (), name="is_exploring")

    # Placeholder for `is_training` flag.
    self._input_dict.is_training = self._get_is_training_placeholder()

    # Multi-GPU towers do not need any action computing/exploration
    # graphs.
    sampled_action = None
    sampled_action_logp = None
    dist_inputs = None
    self._state_out = None
    if not self._is_tower:
        # Create the Exploration object to use for this Policy.
        self.exploration = self._create_exploration()

        # Fully customized action generation (e.g., custom policy).
        if action_sampler_fn:
            sampled_action, sampled_action_logp = action_sampler_fn(
                self,
                self.model,
                obs_batch=self._input_dict[SampleBatch.CUR_OBS],
                state_batches=self._state_inputs,
                seq_lens=self._seq_lens,
                prev_action_batch=self._input_dict.get(
                    SampleBatch.PREV_ACTIONS),
                prev_reward_batch=self._input_dict.get(
                    SampleBatch.PREV_REWARDS),
                explore=explore,
                is_training=self._input_dict.is_training)
        # Distribution generation is customized, e.g., DQN, DDPG.
        else:
            if action_distribution_fn:
                # Try new action_distribution_fn signature, supporting
                # state_batches and seq_lens.
                in_dict = self._input_dict
                try:
                    dist_inputs, dist_class, self._state_out = \
                        action_distribution_fn(
                            self,
                            self.model,
                            input_dict=in_dict,
                            state_batches=self._state_inputs,
                            seq_lens=self._seq_lens,
                            explore=explore,
                            timestep=timestep,
                            is_training=in_dict.is_training)
                # Trying the old way (to stay backward compatible).
                # TODO: Remove in future.
                except TypeError as e:
                    if "positional argument" in e.args[0] or \
                            "unexpected keyword argument" in e.args[0]:
                        dist_inputs, dist_class, self._state_out = \
                            action_distribution_fn(
                                self,
                                self.model,
                                obs_batch=in_dict[SampleBatch.CUR_OBS],
                                state_batches=self._state_inputs,
                                seq_lens=self._seq_lens,
                                prev_action_batch=in_dict.get(
                                    SampleBatch.PREV_ACTIONS),
                                prev_reward_batch=in_dict.get(
                                    SampleBatch.PREV_REWARDS),
                                explore=explore,
                                is_training=in_dict.is_training)
                    else:
                        raise e
            # Default distribution generation behavior:
            # Pass through model. E.g., PG, PPO.
            else:
                if isinstance(self.model, tf.keras.Model):
                    dist_inputs, self._state_out, \
                        self._extra_action_fetches = \
                        self.model(self._input_dict)
                else:
                    dist_inputs, self._state_out = self.model(
                        self._input_dict, self._state_inputs,
                        self._seq_lens)

            action_dist = dist_class(dist_inputs, self.model)

            # Using exploration to get final action (e.g. via sampling).
            sampled_action, sampled_action_logp = \
                self.exploration.get_exploration_action(
                    action_distribution=action_dist,
                    timestep=timestep,
                    explore=explore)

    # Phase 1 init.
    sess = tf1.get_default_session() or tf1.Session(
        config=tf1.ConfigProto(**self.config["tf_session_args"]))

    batch_divisibility_req = get_batch_divisibility_req(self) if \
        callable(get_batch_divisibility_req) else \
        (get_batch_divisibility_req or 1)

    super().__init__(
        observation_space=obs_space,
        action_space=action_space,
        config=config,
        sess=sess,
        obs_input=self._input_dict[SampleBatch.OBS],
        action_input=self._input_dict[SampleBatch.ACTIONS],
        sampled_action=sampled_action,
        sampled_action_logp=sampled_action_logp,
        dist_inputs=dist_inputs,
        dist_class=dist_class,
        loss=None,  # dynamically initialized on run
        loss_inputs=[],
        model=self.model,
        state_inputs=self._state_inputs,
        state_outputs=self._state_out,
        prev_action_input=self._input_dict.get(SampleBatch.PREV_ACTIONS),
        prev_reward_input=self._input_dict.get(SampleBatch.PREV_REWARDS),
        seq_lens=self._seq_lens,
        max_seq_len=config["model"]["max_seq_len"],
        batch_divisibility_req=batch_divisibility_req,
        explore=explore,
        timestep=timestep)

    # Phase 2 init.
    if before_loss_init is not None:
        before_loss_init(self, obs_space, action_space, config)

    # Loss initialization and model/postprocessing test calls.
    if not self._is_tower:
        self._initialize_loss_from_dummy_batch(
            auto_remove_unneeded_view_reqs=True)

        # Create MultiGPUTowerStacks, if we have at least one actual
        # GPU or >1 CPUs (fake GPUs).
        if len(self.devices) > 1 or any("gpu" in d for d in self.devices):
            # Per-GPU graph copies created here must share vars with the
            # policy. Therefore, `reuse` is set to tf1.AUTO_REUSE because
            # Adam nodes are created after all of the device copies are
            # created.
            with tf1.variable_scope("", reuse=tf1.AUTO_REUSE):
                self.multi_gpu_tower_stacks = [
                    TFMultiGPUTowerStack(policy=self) for i in range(
                        self.config.get("num_multi_gpu_tower_stacks", 1))
                ]

        # Initialize again after loss and tower init.
        self.get_session().run(tf1.global_variables_initializer())
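# --- Illustrative sketch (not RLlib source code): the variable-sharing
# behavior relied on above when building the per-GPU tower copies under
# `tf1.variable_scope("", reuse=tf1.AUTO_REUSE)`. In graph mode, a second
# `get_variable()` call with the same name under an AUTO_REUSE scope returns
# the already-created variable instead of a new one, so the tower graphs
# share the policy's weights.
import tensorflow as tf

tf1 = tf.compat.v1
tf1.disable_eager_execution()

with tf1.variable_scope("", reuse=tf1.AUTO_REUSE):
    w_policy = tf1.get_variable("shared_w", shape=(4, 2))
with tf1.variable_scope("", reuse=tf1.AUTO_REUSE):
    w_tower = tf1.get_variable("shared_w", shape=(4, 2))

print(w_policy is w_tower)  # -> True: the tower reuses the policy variable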