def _get_dummy_batch_from_view_requirements(
        self, batch_size: int = 1) -> SampleBatch:
    """Creates a numpy dummy batch based on the Policy's view requirements.

    Args:
        batch_size (int): The size of the batch to create.

    Returns:
        SampleBatch: The dummy batch containing all zero values.
    """
    ret = {}
    for view_col, view_req in self.view_requirements.items():
        data_col = view_req.data_col or view_col
        # Flattened dummy batch.
        if (isinstance(view_req.space,
                       (gym.spaces.Tuple, gym.spaces.Dict))) and (
                (data_col == SampleBatch.OBS
                 and not self.config["_disable_preprocessor_api"]) or
                (data_col == SampleBatch.ACTIONS
                 and not self.config.get("_disable_action_flattening"))):
            _, shape = ModelCatalog.get_action_shape(
                view_req.space, framework=self.config["framework"])
            ret[view_col] = np.zeros((batch_size, ) + shape[1:], np.float32)
        # Non-flattened dummy batch.
        else:
            # Range of indices on time-axis, e.g. "-50:-1".
            if view_req.shift_from is not None:
                ret[view_col] = get_dummy_batch_for_space(
                    view_req.space,
                    batch_size=batch_size,
                    time_size=view_req.shift_to - view_req.shift_from + 1,
                )
            # Sequence of (probably non-consecutive) indices.
            elif isinstance(view_req.shift, (list, tuple)):
                ret[view_col] = get_dummy_batch_for_space(
                    view_req.space,
                    batch_size=batch_size,
                    time_size=len(view_req.shift),
                )
            # Single shift int value.
            else:
                if isinstance(view_req.space, gym.spaces.Space):
                    ret[view_col] = get_dummy_batch_for_space(
                        view_req.space, batch_size=batch_size,
                        fill_value=0.0)
                else:
                    ret[view_col] = [
                        view_req.space for _ in range(batch_size)
                    ]

    # Due to different view requirements for the different columns,
    # columns in the resulting batch may not all have the same batch size.
    return SampleBatch(ret)
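
# The non-flattened branches above delegate to RLlib's
# `get_dummy_batch_for_space` utility. As an illustrative sketch only (a
# hypothetical stand-in, not RLlib's actual implementation, and covering
# only simple non-nested spaces), its behavior for the arguments used
# above could be approximated like this:
import gym
import numpy as np

def _dummy_batch_for_space_sketch(space, batch_size=1, fill_value=0.0,
                                  time_size=None):
    # Derive the per-sample shape and dtype from one sample of the space.
    sample = np.asarray(space.sample())
    # With a time axis, the layout is (batch, time) + sample-shape,
    # matching the shift_from/shift_to and shift-list branches above.
    shape = ((batch_size, time_size) + sample.shape
             if time_size is not None else (batch_size, ) + sample.shape)
    return np.full(shape, fill_value, dtype=sample.dtype)

# E.g. a (2, 3, 4)-shaped zero batch for a 4-dim Box with time_size=3:
assert _dummy_batch_for_space_sketch(
    gym.spaces.Box(-1.0, 1.0, (4, )), batch_size=2,
    time_size=3).shape == (2, 3, 4)
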
def _get_dummy_batch_from_view_requirements(
        self, batch_size: int = 1) -> SampleBatch:
    """Creates a numpy dummy batch based on the Policy's view requirements.

    Args:
        batch_size (int): The size of the batch to create.

    Returns:
        SampleBatch: The dummy batch containing all zero values.
    """
    ret = {}
    for view_col, view_req in self.view_requirements.items():
        if isinstance(view_req.space, (gym.spaces.Dict, gym.spaces.Tuple)):
            _, shape = ModelCatalog.get_action_shape(
                view_req.space, framework=self.config["framework"])
            ret[view_col] = \
                np.zeros((batch_size, ) + shape[1:], np.float32)
        else:
            # Range of indices on time-axis, e.g. "-50:-1".
            if view_req.shift_from is not None:
                ret[view_col] = np.zeros_like([[
                    view_req.space.sample()
                    for _ in range(view_req.shift_to -
                                   view_req.shift_from + 1)
                ] for _ in range(batch_size)])
            # Set of (probably non-consecutive) indices.
            elif isinstance(view_req.shift, (list, tuple)):
                ret[view_col] = np.zeros_like(
                    [[view_req.space.sample()
                      for _ in range(len(view_req.shift))]
                     for _ in range(batch_size)])
            # Single shift int value.
            else:
                if isinstance(view_req.space, gym.spaces.Space):
                    ret[view_col] = np.zeros_like([
                        view_req.space.sample() for _ in range(batch_size)
                    ])
                else:
                    ret[view_col] = [
                        view_req.space for _ in range(batch_size)
                    ]

    # Due to different view requirements for the different columns,
    # columns in the resulting batch may not all have the same batch size.
    return SampleBatch(ret)
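
# Illustration of the `np.zeros_like([space.sample() ...])` pattern the
# version above relies on: sampling the space yields arrays with the
# right shape and dtype, and `zeros_like` then produces a matching
# all-zero batch without computing shapes by hand.
import gym
import numpy as np

_obs_space = gym.spaces.Box(-1.0, 1.0, shape=(3, ))
# Plain batch: (batch_size=4, 3).
assert np.zeros_like(
    [_obs_space.sample() for _ in range(4)]).shape == (4, 3)
# With a time axis, as in the shift_from/shift_to branch:
# (batch_size=2, time_size=5, 3).
assert np.zeros_like(
    [[_obs_space.sample() for _ in range(5)]
     for _ in range(2)]).shape == (2, 5, 3)
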
def _get_dummy_batch_from_view_requirements(
        self, batch_size: int = 1) -> SampleBatch:
    """Creates a numpy dummy batch based on the Policy's view requirements.

    Args:
        batch_size (int): The size of the batch to create.

    Returns:
        SampleBatch: The dummy batch containing all zero values.
    """
    ret = {}
    for view_col, view_req in self.view_requirements.items():
        if isinstance(view_req.space, (gym.spaces.Dict, gym.spaces.Tuple)):
            _, shape = ModelCatalog.get_action_shape(view_req.space)
            ret[view_col] = \
                np.zeros((batch_size, ) + shape[1:], np.float32)
        else:
            ret[view_col] = np.zeros_like(
                [view_req.space.sample() for _ in range(batch_size)])
    return SampleBatch(ret)
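
# All three versions above fall back to `ModelCatalog.get_action_shape`
# for nested (Dict/Tuple) spaces. It returns a (dtype, shape) pair whose
# leading dimension is the batch placeholder, which is why the code
# prepends the real batch size via `(batch_size, ) + shape[1:]`. A
# hedged usage sketch (requires `ray[rllib]`; the exact dtype and
# flattened size depend on the RLlib version installed):
import gym
from ray.rllib.models import ModelCatalog

_dtype, _shape = ModelCatalog.get_action_shape(
    gym.spaces.Tuple([gym.spaces.Discrete(2),
                      gym.spaces.Box(-1.0, 1.0, (3, ))]))
print(_dtype, _shape)  # _shape[0] is the (batch) placeholder dim
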
def _initialize_loss_with_dummy_batch(self):
    # Dummy forward pass to initialize any policy attributes, etc.
    action_dtype, action_shape = ModelCatalog.get_action_shape(
        self.action_space)
    dummy_batch = {
        SampleBatch.CUR_OBS: np.array(
            [self.observation_space.sample()]),
        SampleBatch.NEXT_OBS: np.array(
            [self.observation_space.sample()]),
        SampleBatch.DONES: np.array([False], dtype=bool),
        SampleBatch.ACTIONS: tf.nest.map_structure(
            lambda c: np.array([c]), self.action_space.sample()),
        SampleBatch.REWARDS: np.array([0], dtype=np.float32),
    }
    # `obs_include_prev_action_reward` is a closure variable of the
    # enclosing policy-builder function.
    if obs_include_prev_action_reward:
        dummy_batch.update({
            SampleBatch.PREV_ACTIONS: dummy_batch[SampleBatch.ACTIONS],
            SampleBatch.PREV_REWARDS: dummy_batch[SampleBatch.REWARDS],
        })
    for i, h in enumerate(self._state_in):
        dummy_batch["state_in_{}".format(i)] = h
        dummy_batch["state_out_{}".format(i)] = h
    if self._state_in:
        dummy_batch["seq_lens"] = np.array([1], dtype=np.int32)

    # Convert everything to tensors.
    dummy_batch = tf.nest.map_structure(tf.convert_to_tensor, dummy_batch)

    # For IMPALA, which expects a certain sample batch size.
    def tile_to(tensor, n):
        return tf.tile(tensor,
                       [n] + [1 for _ in tensor.shape.as_list()[1:]])

    if get_batch_divisibility_req:
        dummy_batch = tf.nest.map_structure(
            lambda c: tile_to(c, get_batch_divisibility_req(self)),
            dummy_batch)

    # Execute a forward pass to get self.action_dist etc. initialized,
    # and also obtain the extra action fetches.
    _, _, fetches = self.compute_actions(
        dummy_batch[SampleBatch.CUR_OBS], self._state_in,
        dummy_batch.get(SampleBatch.PREV_ACTIONS),
        dummy_batch.get(SampleBatch.PREV_REWARDS))
    dummy_batch.update(fetches)

    postprocessed_batch = self.postprocess_trajectory(
        SampleBatch(dummy_batch))

    # Model forward pass for the loss (needed after postprocess to
    # overwrite any tensor state from that call).
    self.model.from_batch(dummy_batch)

    postprocessed_batch = tf.nest.map_structure(
        lambda c: tf.convert_to_tensor(c), postprocessed_batch.data)

    loss_fn(self, self.model, self.dist_class, postprocessed_batch)
    if stats_fn:
        stats_fn(self, postprocessed_batch)
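
# Illustration of the wholesale numpy -> tensor conversion used above:
# `tf.nest.map_structure` applies a function to every leaf of an
# arbitrarily nested structure (dicts, lists, tuples), which is what
# lets the ACTIONS entry hold a nested structure produced by
# `action_space.sample()`.
import numpy as np
import tensorflow as tf

_batch = {
    "obs": np.zeros((1, 4), np.float32),
    "actions": (np.array([0]), np.array([0.5], np.float32)),  # nested
}
_tensors = tf.nest.map_structure(tf.convert_to_tensor, _batch)
assert isinstance(_tensors["obs"], tf.Tensor)
assert isinstance(_tensors["actions"][1], tf.Tensor)
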
def _initialize_loss_with_dummy_batch(self):
    # Dummy forward pass to initialize any policy attributes, etc.
    action_dtype, action_shape = ModelCatalog.get_action_shape(
        self.action_space)
    dummy_batch = {
        SampleBatch.CUR_OBS: tf.convert_to_tensor(
            np.array([self.observation_space.sample()])),
        SampleBatch.NEXT_OBS: tf.convert_to_tensor(
            np.array([self.observation_space.sample()])),
        SampleBatch.DONES: tf.convert_to_tensor(
            np.array([False], dtype=bool)),
        SampleBatch.ACTIONS: tf.convert_to_tensor(
            np.zeros(
                (1, ) + action_shape[1:],
                dtype=action_dtype.as_numpy_dtype())),
        SampleBatch.REWARDS: tf.convert_to_tensor(
            np.array([0], dtype=np.float32)),
    }
    # `obs_include_prev_action_reward` is a closure variable of the
    # enclosing policy-builder function.
    if obs_include_prev_action_reward:
        dummy_batch.update({
            SampleBatch.PREV_ACTIONS: dummy_batch[SampleBatch.ACTIONS],
            SampleBatch.PREV_REWARDS: dummy_batch[SampleBatch.REWARDS],
        })
    state_init = self.get_initial_state()
    state_batches = []
    for i, h in enumerate(state_init):
        dummy_batch["state_in_{}".format(i)] = tf.convert_to_tensor(
            np.expand_dims(h, 0))
        dummy_batch["state_out_{}".format(i)] = tf.convert_to_tensor(
            np.expand_dims(h, 0))
        state_batches.append(tf.convert_to_tensor(np.expand_dims(h, 0)))
    if state_init:
        dummy_batch["seq_lens"] = tf.convert_to_tensor(
            np.array([1], dtype=np.int32))

    # For IMPALA, which expects a certain sample batch size.
    def tile_to(tensor, n):
        return tf.tile(tensor,
                       [n] + [1 for _ in tensor.shape.as_list()[1:]])

    if get_batch_divisibility_req:
        dummy_batch = {
            k: tile_to(v, get_batch_divisibility_req(self))
            for k, v in dummy_batch.items()
        }

    # Execute a forward pass to get self.action_dist etc. initialized,
    # and also obtain the extra action fetches.
    _, _, fetches = self.compute_actions(
        dummy_batch[SampleBatch.CUR_OBS], state_batches,
        dummy_batch.get(SampleBatch.PREV_ACTIONS),
        dummy_batch.get(SampleBatch.PREV_REWARDS))
    dummy_batch.update(fetches)

    postprocessed_batch = self.postprocess_trajectory(
        SampleBatch(dummy_batch))

    # Model forward pass for the loss (needed after postprocess to
    # overwrite any tensor state from that call).
    self.model.from_batch(dummy_batch)

    postprocessed_batch = {
        k: tf.convert_to_tensor(v)
        for k, v in postprocessed_batch.items()
    }

    loss_fn(self, self.model, self.dist_class, postprocessed_batch)
    if stats_fn:
        stats_fn(self, postprocessed_batch)
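
# Illustration of the `tile_to` helper defined in both versions above:
# it repeats a tensor along the batch axis only, leaving all other
# dimensions unchanged, so the size-1 dummy batch can be blown up to
# whatever batch-divisibility requirement an algorithm (e.g. IMPALA)
# imposes.
import tensorflow as tf

def _tile_to(tensor, n):
    return tf.tile(tensor, [n] + [1 for _ in tensor.shape.as_list()[1:]])

assert _tile_to(tf.zeros((1, 4)), 8).shape == (8, 4)
assert _tile_to(tf.zeros((1, 2, 3)), 5).shape == (5, 2, 3)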