def add(self, stacked_datapoints):
    """Adds datapoints to the buffer.

    The buffer is written as a ring: insertion starts at
    ``self._insert_index`` and wraps around to the beginning of the buffer,
    overwriting the oldest entries once the buffer is full.

    Args:
        stacked_datapoints (pytree): Transition object containing the
            datapoints, stacked along axis 0. Every leaf must have the same
            size along axis 0.

    Raises:
        ValueError: If the leaves disagree on the number of datapoints
            (their size along axis 0).
    """
    n_elems = data.choose_leaf(data.nested_map(
        lambda x: x.shape[0], stacked_datapoints
    ))

    def check_batch_size(x):
        if x.shape[0] != n_elems:
            raise ValueError(
                'Inconsistent number of datapoints along axis 0: '
                'got {}, expected {}.'.format(x.shape[0], n_elems))

    # Validate every leaf before writing anything, so a malformed batch
    # cannot leave the buffer partially updated.
    data.nested_map(check_batch_size, stacked_datapoints)

    def insert_to_array(buf, elems):
        buf_size = buf.shape[0]
        index = self._insert_index
        if n_elems > buf_size:
            # A sequential insertion would keep only the last buf_size
            # datapoints. Write just those, starting at the slot the first
            # of them would have landed in, so the ring order is the same.
            elems = elems[n_elems - buf_size:]
            index = (index + n_elems - buf_size) % buf_size
        n_written = elems.shape[0]
        # Insert up to the end of the buffer at the current index.
        n_head = min(n_written, buf_size - index)
        buf[index:index + n_head] = elems[:n_head]
        # Insert whatever's left at the beginning of the buffer.
        buf[:n_written - n_head] = elems[n_head:]

    # Insert to all arrays in the pytree.
    data.nested_zip_with(
        insert_to_array, (self._data_buffer, stacked_datapoints)
    )
    if self._size < self._capacity:
        self._size = min(self._insert_index + n_elems, self._capacity)
    self._insert_index = (self._insert_index + n_elems) % self._capacity
def add(self, stacked_datapoints):
    """Adds datapoints to the buffer.

    The buffer is written as a ring: insertion starts at
    ``self._insert_index`` and wraps around to the beginning of the buffer,
    overwriting the oldest entries once the buffer is full.

    Args:
        stacked_datapoints (pytree): Datapoints stacked along axis 0. Every
            leaf must have the same size along axis 0. The remaining axes
            must match ``self._datapoint_shape``.

    Raises:
        ValueError: If the per-datapoint shapes (all axes but the first) do
            not match ``self._datapoint_shape``, or if the leaves disagree on
            the number of datapoints.
    """
    datapoint_shape = data.nested_map(lambda x: x.shape[1:], stacked_datapoints)
    if datapoint_shape != self._datapoint_shape:
        raise ValueError(
            'Datapoint shape mismatch: got {}, expected {}.'.format(
                datapoint_shape, self._datapoint_shape))

    n_elems = data.choose_leaf(
        data.nested_map(lambda x: x.shape[0], stacked_datapoints))

    def check_batch_size(x):
        if x.shape[0] != n_elems:
            raise ValueError(
                'Inconsistent number of datapoints along axis 0: '
                'got {}, expected {}.'.format(x.shape[0], n_elems))

    # Validate every leaf before writing anything, so a malformed batch
    # cannot leave the buffer partially updated.
    data.nested_map(check_batch_size, stacked_datapoints)

    def insert_to_array(buf, elems):
        buf_size = buf.shape[0]
        index = self._insert_index
        if n_elems > buf_size:
            # A sequential insertion would keep only the last buf_size
            # datapoints. Write just those, starting at the slot the first
            # of them would have landed in, so the ring order is the same.
            elems = elems[n_elems - buf_size:]
            index = (index + n_elems - buf_size) % buf_size
        n_written = elems.shape[0]
        # Insert up to the end of the buffer at the current index.
        n_head = min(n_written, buf_size - index)
        buf[index:index + n_head] = elems[:n_head]
        # Insert whatever's left at the beginning of the buffer.
        buf[:n_written - n_head] = elems[n_head:]

    data.nested_zip_with(insert_to_array, (self._data_buffer, stacked_datapoints))
    if self._size < self._capacity:
        self._size = min(self._insert_index + n_elems, self._capacity)
    self._insert_index = (self._insert_index + n_elems) % self._capacity
def _make_output_heads(hidden, output_signature, output_activation):
    """Builds one Dense output head per leaf of the output signature.

    Args:
        hidden (tf.Tensor): Output of the last hidden layer. Every head is
            applied to this same tensor.
        output_signature (pytree of TensorSignatures): Output signature. Each
            leaf must have dtype float32 and a one-dimensional shape, whose
            single entry becomes the head's width.
        output_activation (pytree of activations): Activation of every head,
            in the same structure as ``output_signature``. See
            tf.keras.layers.Activation docstring for possible values.

    Returns:
        Pytree of head output tensors.
    """
    def build_head(head_signature, head_activation, layer_name):
        assert head_signature.dtype == np.float32
        # Unpacking raises if the head signature is not one-dimensional.
        depth, = head_signature.shape
        dense = keras.layers.Dense(
            depth, name=layer_name, activation=head_activation)
        return dense(hidden)

    # For dict signatures, each head's layer is named after its key.
    # Otherwise None is passed in place of the names pytree.
    if isinstance(output_signature, dict):
        layer_names = {key: key for key in output_signature}
    else:
        layer_names = None
    return data.nested_zip_with(
        build_head, (output_signature, output_activation, layer_names))
def _make_output_heads(hidden, output_signature, output_activation, zero_init):
    """Builds one Dense output head per leaf of the output signature.

    Args:
        hidden (tf.Tensor): Output of the last hidden layer. Every head is
            applied to this same tensor.
        output_signature (pytree of TensorSignatures): Output signature. Each
            leaf must have dtype float32 and a one-dimensional shape, whose
            single entry becomes the head's width.
        output_activation (pytree of activations): Activation of every head,
            in the same structure as ``output_signature``. See
            tf.keras.layers.Activation docstring for possible values.
        zero_init (bool): Whether to initialize the heads' kernels and biases
            with zeros. Useful for ensuring proper exploration in the initial
            stages of RL training.

    Returns:
        Pytree of head output tensors.
    """
    def build_head(head_signature, head_activation):
        assert head_signature.dtype == np.float32
        # Unpacking raises if the head signature is not one-dimensional.
        depth, = head_signature.shape
        if zero_init:
            # With zero weights and bias, every head initially outputs
            # activation(0), independent of the hidden features.
            initializers = {
                'kernel_initializer': 'zeros',
                'bias_initializer': 'zeros',
            }
        else:
            initializers = {}
        layer = keras.layers.Dense(
            depth, activation=head_activation, **initializers)
        return layer(hidden)

    return data.nested_zip_with(
        build_head, (output_signature, output_activation))
def _make_output_heads(hidden, output_signature, output_activation, zero_init):
    """Builds masked Dense output heads, one per leaf of the output signature.

    Args:
        hidden (tf.Tensor or pytree of tf.Tensor): Output of the last hidden
            layer. A single tensor is shared by all heads. Otherwise it must
            be a pytree with the same structure as ``output_signature``,
            giving each head its own input.
        output_signature (pytree of TensorSignatures): Output signature. Each
            leaf must have dtype float32. A head's width is the last entry
            of its shape.
        output_activation (pytree of activations): Activation of every head,
            in the same structure as ``output_signature``.
        zero_init (bool): Whether to initialize the heads' kernels and biases
            with zeros.

    Returns:
        Tuple ``(heads, masks)``. ``masks`` is the pytree built by
        ``_make_inputs(output_signature)``. ``heads`` is the pytree of Dense
        outputs, each combined with its corresponding mask through
        ``AddMask``.
    """
    mask_inputs = _make_inputs(output_signature)

    def build_head(head_input, head_signature, head_activation, head_mask):
        assert head_signature.dtype == np.float32
        layer_kwargs = dict(activation=head_activation)
        if zero_init:
            layer_kwargs.update(
                kernel_initializer='zeros', bias_initializer='zeros')
        head_out = keras.layers.Dense(
            head_signature.shape[-1], **layer_kwargs)(head_input)
        return AddMask()((head_out, head_mask))

    if tf.is_tensor(hidden):
        # A single shared trunk: replicate it so every head has an entry.
        head_inputs = data.nested_map(lambda _: hidden, output_signature)
    else:
        head_inputs = hidden

    masked_heads = data.nested_zip_with(
        build_head,
        (head_inputs, output_signature, output_activation, mask_inputs),
    )
    return (masked_heads, mask_inputs)
def __init__(
    self,
    network_signature,
    input=input_observation,
    target=target_solved,
    mask=None,
    batch_size=64,
    n_steps_per_epoch=1000,
    replay_buffer_capacity=1000000,
    replay_buffer_sampling_hierarchy=(),
):
    """Initializes the trainer.

    Args:
        network_signature: Signature of the trained network. It is passed to
            the base class constructor, and its ``input`` and ``output``
            fields define the replay buffer's datapoint layout.
        input (pytree of callables): Each leaf maps an episode to the
            corresponding network input.
        target (pytree of callables): Each leaf maps an episode to the
            corresponding target value.
        mask (pytree of callables or None): Same structure as ``target``.
            Each leaf maps ``(episode, target_value)`` to a mask. If None,
            ``mask_one`` is used for every leaf of ``target``.
        batch_size: Stored as ``self._batch_size`` (presumably the training
            batch size).
        n_steps_per_epoch: Stored as ``self._n_steps_per_epoch``.
        replay_buffer_capacity: Capacity passed to the replay buffer.
        replay_buffer_sampling_hierarchy: Stored as
            ``self._sampling_hierarchy``. Its length sets the replay buffer's
            ``hierarchy_depth``.
    """
    super().__init__(network_signature)

    def episode_mapper(fn_tree):
        # Returns a function that applies every callable in fn_tree to an
        # episode, producing a pytree of the same structure.
        def apply_to(episode):
            return data.nested_map(lambda fn: fn(episode), fn_tree)
        return apply_to

    self._input_fn = episode_mapper(input)
    self._target_fn = episode_mapper(target)

    if mask is None:
        mask_fns = data.nested_map(lambda _: mask_one, target)
    else:
        mask_fns = mask

    def compute_mask(episode):
        # Each mask function receives the episode and its own target value.
        return data.nested_zip_with(
            lambda mask_fn, target_value: mask_fn(episode, target_value),
            (mask_fns, self._target_fn(episode)),
        )

    self._mask_fn = compute_mask

    self._batch_size = batch_size
    self._n_steps_per_epoch = n_steps_per_epoch

    # Presumably (input, target, mask): the mask reuses the output signature.
    datapoint_sig = (
        network_signature.input,
        network_signature.output,
        network_signature.output,
    )
    hierarchy_depth = len(replay_buffer_sampling_hierarchy)
    self._replay_buffer = replay_buffers.HierarchicalReplayBuffer(
        datapoint_sig,
        capacity=replay_buffer_capacity,
        hierarchy_depth=hierarchy_depth,
    )
    self._sampling_hierarchy = replay_buffer_sampling_hierarchy