def copy(self, existing_inputs: List[Tuple[str, "tf1.placeholder"]]) -> TFPolicy:
    """Creates a copy of self using existing input placeholders.

    Args:
        existing_inputs: Flat sequence of input tensors. By convention the
            (non-RNN) loss inputs come first, followed by the RNN state
            inputs and finally the seq-lens tensor.
            NOTE(review): the annotation advertises `(name, placeholder)`
            tuples, but the body reads `.shape` directly on each element,
            i.e. it treats them as tensors -- confirm against callers.

    Returns:
        A new instance of this policy's class that is wired to
        `existing_inputs` and shares this policy's model(s).

    Raises:
        ValueError: If the number of tensors or any tensor shape does not
            match this policy's loss inputs.
    """
    # Note that there might be RNN state inputs at the end of the list.
    if len(self._loss_input_dict) != len(existing_inputs):
        raise ValueError("Tensor list mismatch", self._loss_input_dict,
                         self._state_inputs, existing_inputs)
    for idx, (name, tensor) in enumerate(
            self._loss_input_dict_no_rnn.items()):
        if tensor.shape.as_list() != existing_inputs[idx].shape.as_list():
            raise ValueError("Tensor shape mismatch", idx, name,
                             tensor.shape, existing_inputs[idx].shape)

    # State inputs sit right behind the non-RNN loss inputs; the seq-lens
    # tensor is always the very last entry.
    rnn_inputs = [
        ("state_in_{}".format(idx),
         existing_inputs[len(self._loss_input_dict_no_rnn) + idx])
        for idx in range(len(self._state_inputs))
    ]
    if rnn_inputs:
        rnn_inputs.append((SampleBatch.SEQ_LENS, existing_inputs[-1]))

    # Assemble the full (ordered) input dict for the new instance.
    input_dict = OrderedDict()
    input_dict["is_exploring"] = self._is_exploring
    input_dict["timestep"] = self._timestep
    for idx, name in enumerate(self._loss_input_dict_no_rnn.keys()):
        input_dict[name] = existing_inputs[idx]
    input_dict.update(rnn_inputs)

    shared_models = [
        self.model,
        # Deprecated: Target models should all reside under
        # `policy.target_model` now.
        ("target_q_model", getattr(self, "target_q_model", None)),
        ("target_model", getattr(self, "target_model", None)),
    ]
    instance = self.__class__(
        self.observation_space,
        self.action_space,
        self.config,
        existing_inputs=input_dict,
        existing_model=shared_models)

    instance._loss_input_dict = input_dict
    loss = instance._do_loss_init(SampleBatch(input_dict))
    loss_inputs = []
    for idx, name in enumerate(self._loss_input_dict_no_rnn.keys()):
        loss_inputs.append((name, existing_inputs[idx]))

    TFPolicy._initialize_loss(instance, loss, loss_inputs)
    if instance._grad_stats_fn:
        grad_stats = instance._grad_stats_fn(instance, input_dict,
                                             instance._grads)
        instance._stats_fetches.update(grad_stats)
    return instance
def copy(self, existing_inputs):
    """Creates a copy of self using existing input placeholders.

    Args:
        existing_inputs: Flat sequence of input tensors: first one tensor
            per entry of `self._loss_inputs`, then (only if this policy has
            state inputs) the state tensors followed by one seq-lens tensor.

    Returns:
        A new instance of this policy's class built on `existing_inputs`
        and sharing `self.model`.

    Raises:
        ValueError: If the number of tensors or any tensor shape does not
            match this policy's loss inputs.
    """
    # Note that there might be RNN state inputs at the end of the list:
    # all state tensors plus one trailing seq-lens tensor.
    num_state_inputs = \
        len(self._state_inputs) + 1 if self._state_inputs else 0

    if len(self._loss_inputs) + num_state_inputs != len(existing_inputs):
        raise ValueError("Tensor list mismatch", self._loss_inputs,
                         self._state_inputs, existing_inputs)
    for idx, (name, tensor) in enumerate(self._loss_inputs):
        if tensor.shape.as_list() != existing_inputs[idx].shape.as_list():
            raise ValueError("Tensor shape mismatch", idx, name,
                             tensor.shape, existing_inputs[idx].shape)

    # By convention, the loss inputs are followed by state inputs and then
    # the seq len tensor.
    rnn_inputs = [
        ("state_in_{}".format(idx),
         existing_inputs[len(self._loss_inputs) + idx])
        for idx in range(len(self._state_inputs))
    ]
    if rnn_inputs:
        rnn_inputs.append(("seq_lens", existing_inputs[-1]))

    input_dict = OrderedDict()
    for idx, (name, _) in enumerate(self._loss_inputs):
        input_dict[name] = existing_inputs[idx]
    input_dict.update(rnn_inputs)

    instance = self.__class__(
        self.observation_space,
        self.action_space,
        self.config,
        existing_inputs=input_dict,
        existing_model=self.model)

    # NOTE(review): this runs the global initializer on *this* policy's
    # session right after the copy was constructed. It initializes every
    # global variable in the graph, not just newly created ones -- confirm
    # that resetting existing variables here is intended.
    self._sess.run(tf.global_variables_initializer())

    instance._loss_input_dict = input_dict
    loss = instance._do_loss_init(input_dict)
    loss_inputs = []
    for idx, (name, _) in enumerate(self._loss_inputs):
        loss_inputs.append((name, existing_inputs[idx]))

    TFPolicy._initialize_loss(instance, loss, loss_inputs)
    if instance._grad_stats_fn:
        grad_stats = instance._grad_stats_fn(instance, input_dict,
                                             instance._grads)
        instance._stats_fetches.update(grad_stats)
    return instance
def copy(self, existing_inputs):
    """Creates a copy of self using existing input placeholders.

    Args:
        existing_inputs: Flat sequence of input tensors: first one tensor
            per entry of `self._loss_inputs`, then (only if this policy has
            state inputs) the state tensors followed by one seq-lens tensor.

    Returns:
        A new instance of this policy's class built on `existing_inputs`.

    Raises:
        ValueError: If `config["use_eager"]` is set, or if the number of
            tensors or any tensor shape does not match this policy's loss
            inputs.
    """
    if self.config["use_eager"]:
        raise ValueError(
            "eager not implemented for multi-GPU, try setting "
            "`simple_optimizer: true`")

    # Note that there might be RNN state inputs at the end of the list:
    # all state tensors plus one trailing seq-lens tensor.
    num_state_inputs = \
        len(self._state_inputs) + 1 if self._state_inputs else 0

    if len(self._loss_inputs) + num_state_inputs != len(existing_inputs):
        raise ValueError("Tensor list mismatch", self._loss_inputs,
                         self._state_inputs, existing_inputs)
    for idx, (name, tensor) in enumerate(self._loss_inputs):
        if tensor.shape.as_list() != existing_inputs[idx].shape.as_list():
            raise ValueError("Tensor shape mismatch", idx, name,
                             tensor.shape, existing_inputs[idx].shape)

    # By convention, the loss inputs are followed by state inputs and then
    # the seq len tensor.
    rnn_inputs = [
        ("state_in_{}".format(idx),
         existing_inputs[len(self._loss_inputs) + idx])
        for idx in range(len(self._state_inputs))
    ]
    if rnn_inputs:
        rnn_inputs.append(("seq_lens", existing_inputs[-1]))

    input_dict = OrderedDict()
    for idx, (name, _) in enumerate(self._loss_inputs):
        input_dict[name] = existing_inputs[idx]
    input_dict.update(rnn_inputs)

    instance = self.__class__(
        self.observation_space,
        self.action_space,
        self.config,
        existing_inputs=input_dict)

    loss = instance._do_loss_init(input_dict)
    loss_inputs = []
    for idx, (name, _) in enumerate(self._loss_inputs):
        loss_inputs.append((name, existing_inputs[idx]))
    TFPolicy._initialize_loss(instance, loss, loss_inputs)

    if instance._grad_stats_fn:
        grad_stats = instance._grad_stats_fn(instance, instance._grads)
        instance._stats_fetches.update(grad_stats)
    return instance
def _initialize_loss_from_dummy_batch(
        self,
        auto_remove_unneeded_view_reqs: bool = True,
        stats_fn=None) -> None:
    """Initializes the loss by test-driving the policy with a dummy batch.

    Runs postprocessing on a dummy batch, builds the (placeholder) train
    batch from the result, constructs the loss via `_do_loss_init` and
    registers the loss inputs with `TFPolicy._initialize_loss`. With the
    trajectory view API enabled, the policy's view requirements and
    `_loss_input_dict` are additionally extended/pruned according to which
    columns were actually accessed.

    Args:
        auto_remove_unneeded_view_reqs: Whether to prune view requirements
            (and loss inputs) that were never accessed. Only has an effect
            if `config["_use_trajectory_view_api"]` is set.
        stats_fn: Not used by this method.
    """
    # Columns that are never pruned/untagged, no matter what was accessed.
    always_keep = (
        SampleBatch.EPS_ID, SampleBatch.AGENT_INDEX, SampleBatch.UNROLL_ID,
        SampleBatch.DONES, SampleBatch.REWARDS, SampleBatch.INFOS,
    )
    use_view_api = self.config["_use_trajectory_view_api"]

    # Create the optimizer/exploration optimizer here. Some initialization
    # steps (e.g. exploration postprocessing) may need this.
    self._optimizer = self.optimizer()

    # Test calls depend on variable init, so initialize model first.
    self._sess.run(tf1.global_variables_initializer())

    if use_view_api:
        logger.info("Testing `compute_actions` w/ dummy batch.")
        _, _, extra_fetches = self.compute_actions_from_input_dict(
            self._dummy_batch, explore=False, timestep=0)
        for key, value in extra_fetches.items():
            self._dummy_batch[key] = np.zeros_like(value)
            self._input_dict[key] = get_placeholder(value=value, name=key)
            if key not in self.view_requirements:
                logger.info("Adding extra-action-fetch `{}` to "
                            "view-reqs.".format(key))
                self.view_requirements[key] = ViewRequirement(
                    space=gym.spaces.Box(
                        -1.0, 1.0, shape=value.shape[1:],
                        dtype=value.dtype))
        dummy_batch = self._dummy_batch
    else:

        def fake_array(tensor):
            # Unknown (None) dimensions are filled in with size 1.
            shape = [
                1 if dim is None else dim
                for dim in tensor.shape.as_list()
            ]
            return np.zeros(shape, dtype=tensor.dtype.as_numpy_dtype)

        dummy_batch = {
            SampleBatch.CUR_OBS: fake_array(self._obs_input),
            SampleBatch.NEXT_OBS: fake_array(self._obs_input),
            # `np.bool` was removed in NumPy 1.24; `np.bool_` is the
            # equivalent dtype on all NumPy versions.
            SampleBatch.DONES: np.array([False], dtype=np.bool_),
            SampleBatch.ACTIONS: fake_array(
                ModelCatalog.get_action_placeholder(self.action_space)),
            SampleBatch.REWARDS: np.array([0], dtype=np.float32),
        }
        if self._obs_include_prev_action_reward:
            dummy_batch.update({
                SampleBatch.PREV_ACTIONS: fake_array(
                    self._prev_action_input),
                SampleBatch.PREV_REWARDS: fake_array(
                    self._prev_reward_input),
            })
        state_init = self.get_initial_state()
        for i, h in enumerate(state_init):
            dummy_batch["state_in_{}".format(i)] = np.expand_dims(h, 0)
            dummy_batch["state_out_{}".format(i)] = np.expand_dims(h, 0)
        if state_init:
            dummy_batch["seq_lens"] = np.array([1], dtype=np.int32)
        for k, v in self.extra_compute_action_fetches().items():
            dummy_batch[k] = fake_array(v)
        dummy_batch = SampleBatch(dummy_batch)

    # Wrap the batch so that accessed/added/deleted keys can be inspected
    # after postprocessing.
    batch_for_postproc = UsageTrackingDict(dummy_batch)
    batch_for_postproc.count = dummy_batch.count
    logger.info("Testing `postprocess_trajectory` w/ dummy batch.")
    self.exploration.postprocess_trajectory(self, batch_for_postproc,
                                            self._sess)
    postprocessed_batch = self.postprocess_trajectory(batch_for_postproc)

    if use_view_api:
        # Add new columns automatically to (loss) input_dict.
        for key in batch_for_postproc.added_keys:
            if key not in self._input_dict:
                self._input_dict[key] = get_placeholder(
                    value=batch_for_postproc[key], name=key)
            if key not in self.view_requirements:
                self.view_requirements[key] = ViewRequirement(
                    space=gym.spaces.Box(
                        -1.0, 1.0,
                        shape=batch_for_postproc[key].shape[1:],
                        dtype=batch_for_postproc[key].dtype))
        train_batch = UsageTrackingDict(
            dict(self._input_dict, **self._loss_input_dict))
    else:
        train_batch = UsageTrackingDict(
            dict({
                SampleBatch.CUR_OBS: self._obs_input,
            }, **self._loss_input_dict))
        if self._obs_include_prev_action_reward:
            train_batch.update({
                SampleBatch.PREV_ACTIONS: self._prev_action_input,
                SampleBatch.PREV_REWARDS: self._prev_reward_input,
                SampleBatch.CUR_OBS: self._obs_input,
            })

        # Create one placeholder per postprocessed column that is not
        # already covered.
        for k, v in postprocessed_batch.items():
            if k in train_batch:
                continue
            # `np.object` was removed in NumPy 1.24; use `np.object_`.
            if v.dtype == np.object_:
                continue  # can't handle arbitrary objects in TF
            if k == "seq_lens" or k.startswith("state_in_"):
                continue
            shape = (None, ) + v.shape[1:]
            dtype = np.float32 if v.dtype == np.float64 else v.dtype
            train_batch[k] = tf1.placeholder(dtype, shape=shape, name=k)

        for i, si in enumerate(self._state_inputs):
            train_batch["state_in_{}".format(i)] = si

    if self._state_inputs:
        train_batch["seq_lens"] = self._seq_lens

    if log_once("loss_init"):
        logger.debug(
            "Initializing loss function with dummy input:\n\n{}\n".format(
                summarize(train_batch)))

    self._loss_input_dict.update({k: v for k, v in train_batch.items()})
    loss = self._do_loss_init(train_batch)

    all_accessed_keys = \
        train_batch.accessed_keys | batch_for_postproc.accessed_keys | \
        batch_for_postproc.added_keys | set(
            self.model.view_requirements.keys())

    TFPolicy._initialize_loss(
        self, loss,
        [(k, v) for k, v in train_batch.items() if k in all_accessed_keys])

    if "is_training" in self._loss_input_dict:
        del self._loss_input_dict["is_training"]

    # Call the grads stats fn.
    # TODO: (sven) rename to simply stats_fn to match eager and torch.
    if self._grad_stats_fn:
        self._stats_fetches.update(
            self._grad_stats_fn(self, train_batch, self._grads))

    # Add new columns automatically to view-reqs.
    if use_view_api and auto_remove_unneeded_view_reqs:
        # Add those needed for postprocessing and training.
        all_accessed_keys = train_batch.accessed_keys | \
            batch_for_postproc.accessed_keys
        # Tag those only needed for post-processing (with some exceptions).
        for key in batch_for_postproc.accessed_keys:
            if key not in train_batch.accessed_keys and \
                    key not in self.model.view_requirements and \
                    key not in always_keep:
                if key in self.view_requirements:
                    self.view_requirements[key].used_for_training = False
                if key in self._loss_input_dict:
                    del self._loss_input_dict[key]
        # Remove those not needed at all (leave those that are needed
        # by Sampler to properly execute sample collection).
        # Also always leave DONES, REWARDS, and INFOS, no matter what.
        for key in list(self.view_requirements.keys()):
            if key not in all_accessed_keys and \
                    key not in always_keep and \
                    key not in self.model.view_requirements:
                # If user deleted this key manually in postprocessing
                # fn, warn about it and do not remove from
                # view-requirements.
                if key in batch_for_postproc.deleted_keys:
                    logger.warning(
                        "SampleBatch key '{}' was deleted manually in "
                        "postprocessing function! RLlib will "
                        "automatically remove non-used items from the "
                        "data stream. Remove the `del` from your "
                        "postprocessing function.".format(key))
                else:
                    del self.view_requirements[key]
                    if key in self._loss_input_dict:
                        del self._loss_input_dict[key]
        # Add those data_cols (again) that are missing and have
        # dependencies by view_cols.
        for key in list(self.view_requirements.keys()):
            vr = self.view_requirements[key]
            if (vr.data_col is not None
                    and vr.data_col not in self.view_requirements):
                used_for_training = \
                    vr.data_col in train_batch.accessed_keys
                self.view_requirements[vr.data_col] = ViewRequirement(
                    space=vr.space, used_for_training=used_for_training)

    # Loss inputs without the RNN state tensors and the seq-lens tensor.
    self._loss_input_dict_no_rnn = {
        k: v
        for k, v in self._loss_input_dict.items()
        if (v not in self._state_inputs and v != self._seq_lens)
    }

    # Initialize again after loss init.
    self._sess.run(tf1.global_variables_initializer())
def _initialize_loss(self):
    """Initializes the loss by running postprocessing on a dummy batch.

    A single-row dummy batch is pushed through `postprocess_trajectory`
    to discover which columns (with which shapes/dtypes) the loss may
    consume. One placeholder is created per discovered column, the loss is
    built from these via `_do_loss_init`, and every column the loss
    actually accessed is registered as a loss input with
    `TFPolicy._initialize_loss`.
    """

    def fake_array(tensor):
        # Batch dimension is forced to 1; any other unknown (None)
        # dimension is filled in with 1 as well, since `np.zeros` cannot
        # handle None sizes.
        shape = tensor.shape.as_list()
        shape[0] = 1
        shape = [1 if dim is None else dim for dim in shape]
        return np.zeros(shape, dtype=tensor.dtype.as_numpy_dtype)

    dummy_batch = {
        SampleBatch.CUR_OBS: fake_array(self._obs_input),
        SampleBatch.NEXT_OBS: fake_array(self._obs_input),
        # `np.bool` was removed in NumPy 1.24; `np.bool_` is the
        # equivalent dtype on all NumPy versions.
        SampleBatch.DONES: np.array([False], dtype=np.bool_),
        SampleBatch.ACTIONS: fake_array(
            ModelCatalog.get_action_placeholder(self.action_space)),
        SampleBatch.REWARDS: np.array([0], dtype=np.float32),
    }
    if self._obs_include_prev_action_reward:
        dummy_batch.update({
            SampleBatch.PREV_ACTIONS: fake_array(self._prev_action_input),
            SampleBatch.PREV_REWARDS: fake_array(self._prev_reward_input),
        })
    state_init = self.get_initial_state()
    for i, h in enumerate(state_init):
        dummy_batch["state_in_{}".format(i)] = np.expand_dims(h, 0)
        dummy_batch["state_out_{}".format(i)] = np.expand_dims(h, 0)
    if state_init:
        dummy_batch["seq_lens"] = np.array([1], dtype=np.int32)
    for k, v in self.extra_compute_action_fetches().items():
        dummy_batch[k] = fake_array(v)

    # postprocessing might depend on variable init, so run it first here
    self._sess.run(tf.global_variables_initializer())
    postprocessed_batch = self.postprocess_trajectory(
        SampleBatch(dummy_batch))

    # The observation (and optionally prev-action/-reward) inputs are
    # always part of the loss inputs.
    if self._obs_include_prev_action_reward:
        loss_inputs = [
            (SampleBatch.PREV_ACTIONS, self._prev_action_input),
            (SampleBatch.PREV_REWARDS, self._prev_reward_input),
            (SampleBatch.CUR_OBS, self._obs_input),
        ]
    else:
        loss_inputs = [
            (SampleBatch.CUR_OBS, self._obs_input),
        ]
    batch_tensors = UsageTrackingDict(dict(loss_inputs))

    for k, v in postprocessed_batch.items():
        if k in batch_tensors:
            continue
        # `np.object` was removed in NumPy 1.24; use `np.object_`.
        if v.dtype == np.object_:
            continue  # can't handle arbitrary objects in TF
        shape = (None, ) + v.shape[1:]
        dtype = np.float32 if v.dtype == np.float64 else v.dtype
        batch_tensors[k] = tf.placeholder(dtype, shape=shape, name=k)

    if log_once("loss_init"):
        logger.info(
            "Initializing loss function with dummy input:\n\n{}\n".format(
                summarize(batch_tensors)))

    loss = self._do_loss_init(batch_tensors)
    # Only columns the loss function actually read become loss inputs.
    for k in sorted(batch_tensors.accessed_keys):
        loss_inputs.append((k, batch_tensors[k]))

    TFPolicy._initialize_loss(self, loss, loss_inputs)
    if self._grad_stats_fn:
        self._stats_fetches.update(self._grad_stats_fn(self, self._grads))
    # Initialize again, now that the loss has been built.
    self._sess.run(tf.global_variables_initializer())
def _initialize_loss(self):
    """Initializes the loss by running postprocessing on a dummy batch.

    A dummy batch is pushed through `postprocess_trajectory` to discover
    which columns (with which shapes/dtypes) the loss may consume. One
    placeholder is created per discovered column, the state inputs and the
    seq-lens tensor are added, and the loss is built via `_do_loss_init`.
    Every non-RNN column the loss actually accessed is then registered as
    a loss input with `TFPolicy._initialize_loss`.
    """

    def fake_array(tensor):
        # Unknown (None) dimensions are filled in with size 1.
        shape = [1 if dim is None else dim for dim in tensor.shape.as_list()]
        return np.zeros(shape, dtype=tensor.dtype.as_numpy_dtype)

    dummy_batch = {
        SampleBatch.CUR_OBS: fake_array(self._obs_input),
        SampleBatch.NEXT_OBS: fake_array(self._obs_input),
        # `np.bool` was removed in NumPy 1.24; `np.bool_` is the
        # equivalent dtype on all NumPy versions.
        SampleBatch.DONES: np.array([False], dtype=np.bool_),
        SampleBatch.ACTIONS: fake_array(
            ModelCatalog.get_action_placeholder(self.action_space)),
        SampleBatch.REWARDS: np.array([0], dtype=np.float32),
    }
    if self._obs_include_prev_action_reward:
        dummy_batch.update({
            SampleBatch.PREV_ACTIONS: fake_array(self._prev_action_input),
            SampleBatch.PREV_REWARDS: fake_array(self._prev_reward_input),
        })
    state_init = self.get_initial_state()
    for i, h in enumerate(state_init):
        dummy_batch["state_in_{}".format(i)] = np.expand_dims(h, 0)
        dummy_batch["state_out_{}".format(i)] = np.expand_dims(h, 0)
    if state_init:
        dummy_batch["seq_lens"] = np.array([1], dtype=np.int32)
    for k, v in self.extra_compute_action_fetches().items():
        dummy_batch[k] = fake_array(v)

    # postprocessing might depend on variable init, so run it first here
    self._sess.run(tf.global_variables_initializer())

    postprocessed_batch = self.postprocess_trajectory(
        SampleBatch(dummy_batch))

    # model forward pass for the loss (needed after postprocess to
    # overwrite any tensor state from that call)
    self.model(self._input_dict, self._state_in, self._seq_lens)

    # The observation (and optionally prev-action/-reward) inputs are
    # always part of the loss inputs.
    if self._obs_include_prev_action_reward:
        loss_inputs = [
            (SampleBatch.PREV_ACTIONS, self._prev_action_input),
            (SampleBatch.PREV_REWARDS, self._prev_reward_input),
            (SampleBatch.CUR_OBS, self._obs_input),
        ]
    else:
        loss_inputs = [
            (SampleBatch.CUR_OBS, self._obs_input),
        ]
    train_batch = UsageTrackingDict(dict(loss_inputs))

    for k, v in postprocessed_batch.items():
        if k in train_batch:
            continue
        # `np.object` was removed in NumPy 1.24; use `np.object_`.
        if v.dtype == np.object_:
            continue  # can't handle arbitrary objects in TF
        if k == "seq_lens" or k.startswith("state_in_"):
            continue
        shape = (None, ) + v.shape[1:]
        dtype = np.float32 if v.dtype == np.float64 else v.dtype
        train_batch[k] = tf.placeholder(dtype, shape=shape, name=k)

    # RNN inputs reuse the policy's existing tensors instead of fresh
    # placeholders.
    for i, si in enumerate(self._state_in):
        train_batch["state_in_{}".format(i)] = si
    train_batch["seq_lens"] = self._seq_lens

    if log_once("loss_init"):
        logger.debug(
            "Initializing loss function with dummy input:\n\n{}\n".format(
                summarize(train_batch)))

    self._loss_input_dict = train_batch
    loss = self._do_loss_init(train_batch)
    # Only non-RNN columns the loss function actually read become loss
    # inputs.
    for k in sorted(train_batch.accessed_keys):
        if k != "seq_lens" and not k.startswith("state_in_"):
            loss_inputs.append((k, train_batch[k]))

    TFPolicy._initialize_loss(self, loss, loss_inputs)
    if self._grad_stats_fn:
        self._stats_fetches.update(
            self._grad_stats_fn(self, train_batch, self._grads))
    # Initialize again, now that the loss has been built.
    self._sess.run(tf.global_variables_initializer())
def _initialize_loss_from_dummy_batch(
        self, auto_remove_unneeded_view_reqs: bool = True) -> None:
    """Initializes the loss by test-driving the policy with a dummy batch.

    Runs exploration- and trajectory-postprocessing on `self._dummy_batch`,
    builds the train batch from the policy's input placeholders, constructs
    the loss(es) via `_do_loss_init` and registers the loss inputs with
    `TFPolicy._initialize_loss`. Along the way, view requirements and
    `_loss_input_dict` are extended with newly discovered columns and
    (optionally) pruned of columns that were never accessed.

    Args:
        auto_remove_unneeded_view_reqs: Whether to prune view requirements
            (and loss inputs) for columns that were never accessed.
    """
    # Columns exempt from being tagged as "not used for training".
    keep_for_training = (
        SampleBatch.EPS_ID,
        SampleBatch.AGENT_INDEX,
        SampleBatch.UNROLL_ID,
        SampleBatch.DONES,
        SampleBatch.REWARDS,
        SampleBatch.INFOS,
        SampleBatch.OBS_EMBEDS,
    )
    # Columns exempt from being removed from the view requirements.
    keep_in_view_reqs = (
        SampleBatch.EPS_ID,
        SampleBatch.AGENT_INDEX,
        SampleBatch.UNROLL_ID,
        SampleBatch.DONES,
        SampleBatch.REWARDS,
        SampleBatch.INFOS,
    )

    # Test calls depend on variable init, so initialize model first.
    self.get_session().run(tf1.global_variables_initializer())

    # Fields that have not been accessed are not needed for action
    # computations -> Tag them as `used_for_compute_actions=False`.
    for col, view_req in self.view_requirements.items():
        if col.startswith("state_in_") or \
                col in self._input_dict.accessed_keys:
            continue
        view_req.used_for_compute_actions = False

    # Register every extra action output as dummy column + placeholder.
    for col, tensor in self.extra_action_out_fn().items():
        self._dummy_batch[col] = get_dummy_batch_for_space(
            gym.spaces.Box(
                -1.0,
                1.0,
                shape=tensor.shape.as_list()[1:],
                dtype=tensor.dtype.name),
            batch_size=len(self._dummy_batch),
        )
        self._input_dict[col] = get_placeholder(value=tensor, name=col)
        if col in self.view_requirements:
            continue
        logger.info(
            "Adding extra-action-fetch `{}` to view-reqs.".format(col))
        self.view_requirements[col] = ViewRequirement(
            space=gym.spaces.Box(
                -1.0, 1.0, shape=tensor.shape[1:], dtype=tensor.dtype.name),
            used_for_compute_actions=False,
        )
    dummy_batch = self._dummy_batch

    logger.info("Testing `postprocess_trajectory` w/ dummy batch.")
    self.exploration.postprocess_trajectory(self, dummy_batch,
                                            self.get_session())
    _ = self.postprocess_trajectory(dummy_batch)

    # Add new columns automatically to (loss) input_dict.
    for col in dummy_batch.added_keys:
        if col not in self._input_dict:
            self._input_dict[col] = get_placeholder(
                value=dummy_batch[col], name=col)
        if col not in self.view_requirements:
            self.view_requirements[col] = ViewRequirement(
                space=gym.spaces.Box(
                    -1.0,
                    1.0,
                    shape=dummy_batch[col].shape[1:],
                    dtype=dummy_batch[col].dtype,
                ),
                used_for_compute_actions=False,
            )

    merged_inputs = dict(self._input_dict, **self._loss_input_dict)
    train_batch = SampleBatch(merged_inputs, _is_training=True)

    if self._state_inputs:
        train_batch[SampleBatch.SEQ_LENS] = self._seq_lens
        self._loss_input_dict[SampleBatch.SEQ_LENS] = \
            train_batch[SampleBatch.SEQ_LENS]

    self._loss_input_dict.update(dict(train_batch.items()))

    if log_once("loss_init"):
        logger.debug(
            "Initializing loss function with dummy input:\n\n{}\n".format(
                summarize(train_batch)))

    losses = self._do_loss_init(train_batch)

    all_accessed_keys = (
        train_batch.accessed_keys | dummy_batch.accessed_keys
        | dummy_batch.added_keys
        | set(self.model.view_requirements.keys()))

    # Loss inputs: every accessed column, plus (once more) the seq-lens
    # tensor if the train batch has one.
    loss_inputs = [(col, tensor) for col, tensor in train_batch.items()
                   if col in all_accessed_keys]
    if SampleBatch.SEQ_LENS in train_batch:
        loss_inputs.append(
            (SampleBatch.SEQ_LENS, train_batch[SampleBatch.SEQ_LENS]))
    TFPolicy._initialize_loss(self, losses, loss_inputs)

    if "is_training" in self._loss_input_dict:
        del self._loss_input_dict["is_training"]

    # Call the grads stats fn.
    # TODO: (sven) rename to simply stats_fn to match eager and torch.
    self._stats_fetches.update(
        self.grad_stats_fn(train_batch, self._grads))

    # Add new columns automatically to view-reqs.
    if auto_remove_unneeded_view_reqs:
        # Add those needed for postprocessing and training.
        all_accessed_keys = \
            train_batch.accessed_keys | dummy_batch.accessed_keys

        # Tag those only needed for post-processing (with some exceptions).
        for col in dummy_batch.accessed_keys:
            if (col in train_batch.accessed_keys
                    or col in self.model.view_requirements
                    or col in keep_for_training):
                continue
            if col in self.view_requirements:
                self.view_requirements[col].used_for_training = False
            if col in self._loss_input_dict:
                del self._loss_input_dict[col]

        # Remove those not needed at all (leave those that are needed
        # by Sampler to properly execute sample collection).
        # Also always leave DONES, REWARDS, and INFOS, no matter what.
        for col in list(self.view_requirements.keys()):
            if (col in all_accessed_keys or col in keep_in_view_reqs
                    or col in self.model.view_requirements):
                continue
            # If user deleted this key manually in postprocessing
            # fn, warn about it and do not remove from
            # view-requirements.
            if col in dummy_batch.deleted_keys:
                logger.warning(
                    "SampleBatch key '{}' was deleted manually in "
                    "postprocessing function! RLlib will "
                    "automatically remove non-used items from the "
                    "data stream. Remove the `del` from your "
                    "postprocessing function.".format(col))
            # If we are not writing output to disk, safe to erase
            # this key to save space in the sample batch.
            elif self.config["output"] is None:
                del self.view_requirements[col]
                if col in self._loss_input_dict:
                    del self._loss_input_dict[col]

        # Add those data_cols (again) that are missing and have
        # dependencies by view_cols.
        for col in list(self.view_requirements.keys()):
            view_req = self.view_requirements[col]
            if view_req.data_col is None or \
                    view_req.data_col in self.view_requirements:
                continue
            used_for_training = \
                view_req.data_col in train_batch.accessed_keys
            self.view_requirements[view_req.data_col] = ViewRequirement(
                space=view_req.space, used_for_training=used_for_training)

    # Loss inputs without the RNN state tensors and the seq-lens tensor.
    no_rnn_inputs = {}
    for col, tensor in self._loss_input_dict.items():
        if (tensor not in self._state_inputs and tensor != self._seq_lens):
            no_rnn_inputs[col] = tensor
    self._loss_input_dict_no_rnn = no_rnn_inputs
def _initialize_loss(self):
    """Initializes the loss by running postprocessing on a dummy batch.

    A single-row dummy batch is pushed through `postprocess_trajectory`
    to discover which columns (with which shapes/dtypes) the loss may
    consume. One placeholder is created per discovered column, the loss is
    built from these via `_do_loss_init`, and every column the loss
    actually accessed is registered as a loss input with
    `TFPolicy._initialize_loss`.

    If `config["use_eager"]` is set, the graph loss is replaced by a
    `tf.py_function` that re-evaluates `self._loss_fn` on eager tensors.

    Raises:
        ValueError: If `config["use_eager"]` is set but the policy has no
            model.
    """

    def fake_array(tensor):
        # Batch dimension is forced to 1; any other unknown (None)
        # dimension is filled in with 1 as well, since `np.zeros` cannot
        # handle None sizes.
        shape = tensor.shape.as_list()
        shape[0] = 1
        shape = [1 if dim is None else dim for dim in shape]
        return np.zeros(shape, dtype=tensor.dtype.as_numpy_dtype)

    dummy_batch = {
        SampleBatch.CUR_OBS: fake_array(self._obs_input),
        SampleBatch.NEXT_OBS: fake_array(self._obs_input),
        # `np.bool` was removed in NumPy 1.24; `np.bool_` is the
        # equivalent dtype on all NumPy versions.
        SampleBatch.DONES: np.array([False], dtype=np.bool_),
        SampleBatch.ACTIONS: fake_array(
            ModelCatalog.get_action_placeholder(self.action_space)),
        SampleBatch.REWARDS: np.array([0], dtype=np.float32),
    }
    if self._obs_include_prev_action_reward:
        dummy_batch.update({
            SampleBatch.PREV_ACTIONS: fake_array(self._prev_action_input),
            SampleBatch.PREV_REWARDS: fake_array(self._prev_reward_input),
        })
    state_init = self.get_initial_state()
    for i, h in enumerate(state_init):
        dummy_batch["state_in_{}".format(i)] = np.expand_dims(h, 0)
        dummy_batch["state_out_{}".format(i)] = np.expand_dims(h, 0)
    if state_init:
        dummy_batch["seq_lens"] = np.array([1], dtype=np.int32)
    for k, v in self.extra_compute_action_fetches().items():
        dummy_batch[k] = fake_array(v)

    # postprocessing might depend on variable init, so run it first here
    self._sess.run(tf.global_variables_initializer())
    postprocessed_batch = self.postprocess_trajectory(
        SampleBatch(dummy_batch))

    # The observation (and optionally prev-action/-reward) inputs are
    # always part of the loss inputs.
    if self._obs_include_prev_action_reward:
        loss_inputs = [
            (SampleBatch.PREV_ACTIONS, self._prev_action_input),
            (SampleBatch.PREV_REWARDS, self._prev_reward_input),
            (SampleBatch.CUR_OBS, self._obs_input),
        ]
    else:
        loss_inputs = [
            (SampleBatch.CUR_OBS, self._obs_input),
        ]
    batch_tensors = UsageTrackingDict(dict(loss_inputs))

    for k, v in postprocessed_batch.items():
        if k in batch_tensors:
            continue
        # `np.object` was removed in NumPy 1.24; use `np.object_`.
        if v.dtype == np.object_:
            continue  # can't handle arbitrary objects in TF
        shape = (None, ) + v.shape[1:]
        dtype = np.float32 if v.dtype == np.float64 else v.dtype
        batch_tensors[k] = tf.placeholder(dtype, shape=shape, name=k)

    if log_once("loss_init"):
        logger.info(
            "Initializing loss function with dummy input:\n\n{}\n".format(
                summarize(batch_tensors)))

    loss = self._do_loss_init(batch_tensors)
    # Only columns the loss function actually read become loss inputs.
    for k in sorted(batch_tensors.accessed_keys):
        loss_inputs.append((k, batch_tensors[k]))

    # XXX experimental support for automatically eagerifying the loss.
    # The main limitation right now is that TF doesn't support mixing eager
    # and non-eager tensors, so losses that read non-eager tensors through
    # `policy` need to use `policy.convert_to_eager(tensor)`.
    if self.config["use_eager"]:
        if not self.model:
            raise ValueError("eager not implemented in this case")
        graph_tensors = list(self._needs_eager_conversion)

        def gen_loss(model_outputs, *args):
            # `args` holds the loss inputs first, then the graph tensors,
            # in the same order they are passed to `tf.py_function` below.
            # fill in the batch tensor dict with eager tensors
            eager_inputs = dict(
                zip([k for (k, v) in loss_inputs],
                    args[:len(loss_inputs)]))
            # fill in the eager versions of all accessed graph tensors
            self._eager_tensors = dict(
                zip(graph_tensors, args[len(loss_inputs):]))
            # patch the action dist to use eager mode tensors
            self.action_dist.inputs = model_outputs
            return self._loss_fn(self, eager_inputs)

        # TODO(ekl) also handle the stats funcs
        loss = tf.py_function(
            gen_loss,
            # cast works around TypeError: Cannot convert provided value
            # to EagerTensor. Provided value: 0.0 Requested dtype: int64
            [self.model.outputs] +
            [tf.cast(v, tf.float32) for (k, v) in loss_inputs] +
            [tf.cast(t, tf.float32) for t in graph_tensors],
            tf.float32)

    TFPolicy._initialize_loss(self, loss, loss_inputs)
    if self._grad_stats_fn:
        self._stats_fetches.update(self._grad_stats_fn(self, self._grads))
    # Initialize again, now that the loss has been built.
    self._sess.run(tf.global_variables_initializer())
def _initialize_loss_from_dummy_batch(
        self,
        auto_remove_unneeded_view_reqs: bool = True,
        stats_fn=None) -> None:
    """Initializes the loss by test-driving the policy with a dummy batch.

    Runs action computation and exploration-/trajectory-postprocessing on
    `self._dummy_batch`, builds the train batch from the policy's input
    placeholders, constructs the loss via `_do_loss_init` and registers the
    loss inputs with `TFPolicy._initialize_loss`. Along the way, view
    requirements and `_loss_input_dict` are extended with newly discovered
    columns and (optionally) pruned of columns that were never accessed.

    Args:
        auto_remove_unneeded_view_reqs: Whether to prune view requirements
            (and loss inputs) for columns that were never accessed.
        stats_fn: Not used by this method.
    """
    # Columns that are never pruned/untagged, no matter what was accessed.
    always_keep = (
        SampleBatch.EPS_ID, SampleBatch.AGENT_INDEX, SampleBatch.UNROLL_ID,
        SampleBatch.DONES, SampleBatch.REWARDS, SampleBatch.INFOS,
    )

    # Create the optimizer/exploration optimizer here. Some initialization
    # steps (e.g. exploration postprocessing) may need this.
    self._optimizer = self.optimizer()

    # Test calls depend on variable init, so initialize model first.
    self.get_session().run(tf1.global_variables_initializer())

    logger.info("Testing `compute_actions` w/ dummy batch.")
    _actions, _state_outs, fetches = self.compute_actions_from_input_dict(
        self._dummy_batch, explore=False, timestep=0)
    # Register every extra action fetch as dummy column + placeholder.
    for col, fetched in fetches.items():
        self._dummy_batch[col] = fetched
        self._input_dict[col] = get_placeholder(value=fetched, name=col)
        if col in self.view_requirements:
            continue
        logger.info(
            "Adding extra-action-fetch `{}` to view-reqs.".format(col))
        self.view_requirements[col] = ViewRequirement(
            space=gym.spaces.Box(
                -1.0, 1.0, shape=fetched.shape[1:], dtype=fetched.dtype),
            used_for_compute_actions=False,
        )
    dummy_batch = self._dummy_batch

    logger.info("Testing `postprocess_trajectory` w/ dummy batch.")
    self.exploration.postprocess_trajectory(self, dummy_batch,
                                            self.get_session())
    _ = self.postprocess_trajectory(dummy_batch)

    # Add new columns automatically to (loss) input_dict.
    for col in dummy_batch.added_keys:
        if col not in self._input_dict:
            self._input_dict[col] = get_placeholder(
                value=dummy_batch[col], name=col)
        if col not in self.view_requirements:
            self.view_requirements[col] = ViewRequirement(
                space=gym.spaces.Box(
                    -1.0,
                    1.0,
                    shape=dummy_batch[col].shape[1:],
                    dtype=dummy_batch[col].dtype),
                used_for_compute_actions=False,
            )

    merged_inputs = dict(self._input_dict, **self._loss_input_dict)
    train_batch = SampleBatch(merged_inputs)

    if self._state_inputs:
        train_batch["seq_lens"] = self._seq_lens
        self._loss_input_dict["seq_lens"] = train_batch["seq_lens"]

    self._loss_input_dict.update(dict(train_batch.items()))

    if log_once("loss_init"):
        logger.debug(
            "Initializing loss function with dummy input:\n\n{}\n".format(
                summarize(train_batch)))

    loss = self._do_loss_init(train_batch)

    all_accessed_keys = (
        train_batch.accessed_keys | dummy_batch.accessed_keys
        | dummy_batch.added_keys
        | set(self.model.view_requirements.keys()))

    # Loss inputs: every accessed column, plus (once more) the seq-lens
    # tensor if the train batch has one.
    loss_inputs = [(col, tensor) for col, tensor in train_batch.items()
                   if col in all_accessed_keys]
    if "seq_lens" in train_batch:
        loss_inputs.append(("seq_lens", train_batch["seq_lens"]))
    TFPolicy._initialize_loss(self, loss, loss_inputs)

    if "is_training" in self._loss_input_dict:
        del self._loss_input_dict["is_training"]

    # Call the grads stats fn.
    # TODO: (sven) rename to simply stats_fn to match eager and torch.
    if self._grad_stats_fn:
        grad_stats = self._grad_stats_fn(self, train_batch, self._grads)
        self._stats_fetches.update(grad_stats)

    # Add new columns automatically to view-reqs.
    if auto_remove_unneeded_view_reqs:
        # Add those needed for postprocessing and training.
        all_accessed_keys = \
            train_batch.accessed_keys | dummy_batch.accessed_keys

        # Tag those only needed for post-processing (with some exceptions).
        for col in dummy_batch.accessed_keys:
            if (col in train_batch.accessed_keys
                    or col in self.model.view_requirements
                    or col in always_keep):
                continue
            if col in self.view_requirements:
                self.view_requirements[col].used_for_training = False
            if col in self._loss_input_dict:
                del self._loss_input_dict[col]

        # Remove those not needed at all (leave those that are needed
        # by Sampler to properly execute sample collection).
        # Also always leave DONES, REWARDS, and INFOS, no matter what.
        for col in list(self.view_requirements.keys()):
            if (col in all_accessed_keys or col in always_keep
                    or col in self.model.view_requirements):
                continue
            # If user deleted this key manually in postprocessing
            # fn, warn about it and do not remove from
            # view-requirements.
            if col in dummy_batch.deleted_keys:
                logger.warning(
                    "SampleBatch key '{}' was deleted manually in "
                    "postprocessing function! RLlib will "
                    "automatically remove non-used items from the "
                    "data stream. Remove the `del` from your "
                    "postprocessing function.".format(col))
            else:
                del self.view_requirements[col]
                if col in self._loss_input_dict:
                    del self._loss_input_dict[col]

        # Add those data_cols (again) that are missing and have
        # dependencies by view_cols.
        for col in list(self.view_requirements.keys()):
            view_req = self.view_requirements[col]
            if view_req.data_col is None or \
                    view_req.data_col in self.view_requirements:
                continue
            used_for_training = \
                view_req.data_col in train_batch.accessed_keys
            self.view_requirements[view_req.data_col] = ViewRequirement(
                space=view_req.space, used_for_training=used_for_training)

    # Loss inputs without the RNN state tensors and the seq-lens tensor.
    no_rnn_inputs = {}
    for col, tensor in self._loss_input_dict.items():
        if (tensor not in self._state_inputs and tensor != self._seq_lens):
            no_rnn_inputs[col] = tensor
    self._loss_input_dict_no_rnn = no_rnn_inputs

    # Initialize again after loss init.
    self.get_session().run(tf1.global_variables_initializer())
def _initialize_loss(self):
    """Build the loss graph by tracing the loss function on dummy inputs.

    Steps, in order:

    1. Assemble an all-zeros dummy batch from the policy's input
       placeholders (unknown dimensions are replaced by 1), the initial
       RNN state and the extra action fetches.
    2. Run the variable initializer, then pass the dummy batch through
       ``postprocess_trajectory`` to discover which columns
       postprocessing produces.
    3. Re-run the model forward pass, then create one placeholder per
       postprocessed column that is not already a training input.
    4. Call ``_do_loss_init`` on a usage-tracking dict of those
       placeholders; every key the loss accessed becomes a loss input.
    5. Hand the loss and loss inputs to ``TFPolicy._initialize_loss``,
       register gradient stats (if a stats fn is set) and run the
       variable initializer once more.

    Side effects: sets ``self._loss_input_dict``, may update
    ``self._stats_fetches`` and runs ops in ``self._sess``.
    """

    def fake_array(tensor):
        # Zero-filled numpy array matching the tensor's static shape, with
        # every unknown (None) dimension replaced by 1.
        shape = [
            dim if dim is not None else 1
            for dim in tensor.shape.as_list()
        ]
        return np.zeros(shape, dtype=tensor.dtype.as_numpy_dtype)

    dummy_batch = {
        SampleBatch.CUR_OBS: fake_array(self._obs_input),
        SampleBatch.NEXT_OBS: fake_array(self._obs_input),
        # `np.bool` was a deprecated alias that NumPy 1.24 removed;
        # `np.bool_` is available in every NumPy version.
        SampleBatch.DONES: np.array([False], dtype=np.bool_),
        SampleBatch.ACTIONS: fake_array(
            ModelCatalog.get_action_placeholder(self.action_space)),
        SampleBatch.REWARDS: np.array([0], dtype=np.float32),
    }

    # Disabled (PENGZHENGHAO): add dummy values for the mask placeholders.
    # for name, val in self.model.mask_placeholder_dict.items():
    #     shape = val.shape.as_list()
    #     shape = [1] + [s if s is not None else 1 for s in shape]
    #     dummy_batch[name] = \
    #         np.zeros(shape, dtype=val.dtype.as_numpy_dtype)

    if self._obs_include_prev_action_reward:
        dummy_batch.update({
            SampleBatch.PREV_ACTIONS: fake_array(self._prev_action_input),
            SampleBatch.PREV_REWARDS: fake_array(self._prev_reward_input),
        })

    # Each initial RNN state gets a leading dimension of size 1 and is used
    # both as the state input and as the state output of the dummy step.
    state_init = self.get_initial_state()
    for i, h in enumerate(state_init):
        dummy_batch["state_in_{}".format(i)] = np.expand_dims(h, 0)
        dummy_batch["state_out_{}".format(i)] = np.expand_dims(h, 0)
    if state_init:
        dummy_batch["seq_lens"] = np.array([1], dtype=np.int32)

    for k, v in self.extra_compute_action_fetches().items():
        dummy_batch[k] = fake_array(v)

    # postprocessing might depend on variable init, so run it first here
    self._sess.run(tf.global_variables_initializer())

    postprocessed_batch = self.postprocess_trajectory(
        SampleBatch(dummy_batch))

    # model forward pass for the loss (needed after postprocess to
    # overwrite any tensor state from that call)
    self.model(self._input_dict, self._state_in, self._seq_lens)

    # Seed the training batch and the loss inputs with the placeholders
    # that already exist on the policy.
    if self._obs_include_prev_action_reward:
        train_batch = UsageTrackingDict({
            SampleBatch.PREV_ACTIONS: self._prev_action_input,
            SampleBatch.PREV_REWARDS: self._prev_reward_input,
            SampleBatch.CUR_OBS: self._obs_input,
        })
        loss_inputs = [
            (SampleBatch.PREV_ACTIONS, self._prev_action_input),
            (SampleBatch.PREV_REWARDS, self._prev_reward_input),
            (SampleBatch.CUR_OBS, self._obs_input),
        ]
    else:
        train_batch = UsageTrackingDict({
            SampleBatch.CUR_OBS: self._obs_input,
        })
        loss_inputs = [
            (SampleBatch.CUR_OBS, self._obs_input),
        ]

    # Create a placeholder for every remaining postprocessed column.
    # Keys observed in postprocessed_batch while debugging (the same list
    # was recorded with and without the mask):
    #   'obs', 'new_obs', 'dones', 'actions', 'rewards', 'fc_1_mask',
    #   'fc_2_mask', 'prev_actions', 'prev_rewards', 'action_prob',
    #   'action_logp', 'vf_preds', 'behaviour_logits', 'layer0', 'layer1',
    #   'advantages', 'value_targets'
    for k, v in postprocessed_batch.items():
        if k in train_batch:
            continue
        elif v.dtype == np.object_:
            # can't handle arbitrary objects in TF
            # (`np.object` was removed in NumPy 1.24; `np.object_` is the
            # portable spelling of the same dtype.)
            continue
        elif k == "seq_lens" or k.startswith("state_in_"):
            # RNN inputs are wired in from the existing tensors below.
            continue
        # Leave the leading dimension unspecified in the placeholder.
        shape = (None, ) + v.shape[1:]
        # float64 columns are fed through float32 placeholders.
        dtype = np.float32 if v.dtype == np.float64 else v.dtype
        placeholder = tf.placeholder(dtype, shape=shape, name=k)
        train_batch[k] = placeholder

    # Debugging note: with the mask enabled, train_batch held 17 entries
    # at this point (the keys listed above).

    for i, si in enumerate(self._state_in):
        train_batch["state_in_{}".format(i)] = si
    train_batch["seq_lens"] = self._seq_lens

    if log_once("loss_init"):
        logger.debug(
            "Initializing loss function with dummy input:\n\n{}\n".format(
                summarize(train_batch)))

    self._loss_input_dict = train_batch

    # Tracing the loss records which keys it reads in
    # train_batch.accessed_keys. Debugging note: after this call the set
    # was {'action_logp', 'actions', 'advantages', 'behaviour_logits',
    # 'obs', 'prev_actions', 'prev_rewards', 'value_targets', 'vf_preds'}
    # with and without the mask, only in a different iteration order.
    loss = self._do_loss_init(train_batch)

    # Iterate the accessed keys in sorted order so the loss inputs come
    # out in a deterministic order. RNN inputs are excluded here.
    # NOTE(review): loss_inputs already holds obs (and prev_actions /
    # prev_rewards when enabled); if the loss accessed those keys they are
    # appended a second time here. Left unchanged on purpose.
    for k in sorted(train_batch.accessed_keys):
        if k != "seq_lens" and not k.startswith("state_in_"):
            loss_inputs.append((k, train_batch[k]))

    # Disabled (PENGZHENGHAO): also feed the mask placeholders to the loss.
    # for name, ph in self.model.mask_placeholder_dict.items():
    #     loss_inputs.append((name, ph))

    TFPolicy._initialize_loss(self, loss, loss_inputs)
    if self._grad_stats_fn:
        self._stats_fetches.update(
            self._grad_stats_fn(self, train_batch, self._grads))

    # Run the initializer again now that the loss has been set up
    # (presumably to cover variables created after the first run above;
    # confirm before removing).
    self._sess.run(tf.global_variables_initializer())