Ejemplo n.º 1
0
    def add(self, stacked_datapoints):
        """Adds datapoints to the buffer.

        Writes start at the current insert index and wrap around to the
        beginning of the buffer, overwriting the oldest entries. A batch
        larger than the buffer capacity keeps only its last `self._capacity`
        datapoints.

        Args:
            stacked_datapoints (pytree): Transition object containing the
                datapoints, stacked along axis 0.
        """
        n_elems = data.choose_leaf(data.nested_map(
            lambda x: x.shape[0], stacked_datapoints
        ))

        # Only the newest `capacity` datapoints of an oversized batch can
        # survive. Skip the rest and start writing where the kept ones would
        # land after a full wrap-around, so the insert index computed at the
        # end of this method still points just past the newest datapoint.
        n_skipped = max(n_elems - self._capacity, 0)
        start = (self._insert_index + n_skipped) % self._capacity

        def insert_to_array(buf, elems):
            assert elems.shape[0] == n_elems
            elems = elems[n_skipped:]
            n_written = elems.shape[0]
            # How many datapoints fit between `start` and the end of `buf`.
            # NOTE(review): assumes buf.shape[0] == self._capacity.
            n_head = min(n_written, buf.shape[0] - start)
            buf[start:start + n_head] = elems[:n_head]
            # Wrap whatever's left around to the beginning of the buffer.
            buf[:n_written - n_head] = elems[n_head:]

        # Insert to all arrays in the pytree.
        data.nested_zip_with(
            insert_to_array, (self._data_buffer, stacked_datapoints)
        )
        # Assumes the insert index equals the size until the buffer first
        # fills up, so the new size is the new end position (clamped).
        if self._size < self._capacity:
            self._size = min(self._insert_index + n_elems, self._capacity)
        self._insert_index = (self._insert_index + n_elems) % self._capacity
Ejemplo n.º 2
0
    def add(self, stacked_datapoints):
        """Adds a batch of datapoints to the circular buffer.

        Writes start at the current insert index and wrap around to the
        beginning of the buffer, overwriting the oldest entries. A batch
        larger than the buffer capacity keeps only its last `self._capacity`
        datapoints.

        Args:
            stacked_datapoints (pytree): Datapoints stacked along axis 0. The
                per-datapoint shapes (everything after axis 0) must equal
                `self._datapoint_shape`.

        Raises:
            ValueError: If the per-datapoint shapes do not match
                `self._datapoint_shape`.
        """
        datapoint_shape = data.nested_map(lambda x: x.shape[1:],
                                          stacked_datapoints)
        if datapoint_shape != self._datapoint_shape:
            raise ValueError(
                'Datapoint shape mismatch: got {}, expected {}.'.format(
                    datapoint_shape, self._datapoint_shape))

        n_elems = data.choose_leaf(
            data.nested_map(lambda x: x.shape[0], stacked_datapoints))

        # Only the newest `capacity` datapoints of an oversized batch can
        # survive. Skip the rest and start writing where the kept ones would
        # land after a full wrap-around, so the insert index computed at the
        # end of this method still points just past the newest datapoint.
        n_skipped = max(n_elems - self._capacity, 0)
        start = (self._insert_index + n_skipped) % self._capacity

        def insert_to_array(buf, elems):
            assert elems.shape[0] == n_elems
            elems = elems[n_skipped:]
            n_written = elems.shape[0]
            # How many datapoints fit between `start` and the end of `buf`.
            # NOTE(review): assumes buf.shape[0] == self._capacity.
            n_head = min(n_written, buf.shape[0] - start)
            buf[start:start + n_head] = elems[:n_head]
            # Wrap whatever's left around to the beginning of the buffer.
            buf[:n_written - n_head] = elems[n_head:]

        data.nested_zip_with(insert_to_array,
                             (self._data_buffer, stacked_datapoints))
        # Assumes the insert index equals the size until the buffer first
        # fills up, so the new size is the new end position (clamped).
        if self._size < self._capacity:
            self._size = min(self._insert_index + n_elems, self._capacity)
        self._insert_index = (self._insert_index + n_elems) % self._capacity
Ejemplo n.º 3
0
def _make_output_heads(hidden, output_signature, output_activation):
    """Builds one Dense output head per leaf of the output signature.

    Every head is applied to the same `hidden` tensor.

    Args:
        hidden (tf.Tensor): Output of the last hidden layer.
        output_signature (pytree of TensorSignatures): Output signature. Each
            leaf must be a rank-1 float32 signature; its single dimension is
            the head width.
        output_activation (pytree of activations): Activation of every head. See
            tf.keras.layers.Activation docstring for possible values.

    Returns:
        Pytree of head output tensors.
    """
    def init_head(signature, activation, name):
        assert signature.dtype == np.float32
        depth, = signature.shape
        dense = keras.layers.Dense(depth, name=name, activation=activation)
        return dense(hidden)

    # For a dict signature, each head's layer is named after its key.
    # Otherwise None is zipped in place of the names; presumably this yields
    # name=None (Keras default naming). NOTE(review): confirm that
    # data.nested_zip_with accepts None against non-dict containers.
    if isinstance(output_signature, dict):
        head_names = {key: key for key in output_signature}
    else:
        head_names = None

    return data.nested_zip_with(
        init_head, (output_signature, output_activation, head_names)
    )
Ejemplo n.º 4
0
def _make_output_heads(hidden, output_signature, output_activation, zero_init):
    """Builds one Dense output head per leaf of the output signature.

    Every head is applied to the same `hidden` tensor.

    Args:
        hidden (tf.Tensor): Output of the last hidden layer.
        output_signature (pytree of TensorSignatures): Output signature. Each
            leaf must be a rank-1 float32 signature; its single dimension is
            the head width.
        output_activation (pytree of activations): Activation of every head. See
            tf.keras.layers.Activation docstring for possible values.
        zero_init (bool): Whether to initialize the heads with zeros. Useful for
            ensuring proper exploration in the initial stages of RL training.

    Returns:
        Pytree of head output tensors.
    """
    # Extra Dense constructor arguments shared by every head.
    if zero_init:
        init_kwargs = {
            'kernel_initializer': 'zeros',
            'bias_initializer': 'zeros',
        }
    else:
        init_kwargs = {}

    def init_head(signature, activation):
        assert signature.dtype == np.float32
        depth, = signature.shape
        dense = keras.layers.Dense(depth, activation=activation, **init_kwargs)
        return dense(hidden)

    return data.nested_zip_with(
        init_head, (output_signature, output_activation)
    )
Ejemplo n.º 5
0
def _make_output_heads(hidden, output_signature, output_activation, zero_init):
    """Builds masked Dense output heads, one per leaf of the output signature.

    Args:
        hidden: Either a single tensor shared by all heads, or (presumably) a
            pytree with the same structure as `output_signature`, giving a
            separate input per head. It is zipped with the signature.
        output_signature (pytree of TensorSignatures): Output signature. Each
            leaf must be float32; its last dimension is the head width.
        output_activation (pytree of activations): Activation of every head.
        zero_init (bool): Whether to initialize head kernels and biases with
            zeros.

    Returns:
        Tuple (heads, masks). `masks` is the pytree produced by
        `_make_inputs(output_signature)`. `heads` holds, for every leaf, the
        Dense head output combined with its mask by an `AddMask` layer.
    """
    # The mask inputs are created before any head layer.
    masks = _make_inputs(output_signature)

    def init_head(layer, signature, activation, mask):
        assert signature.dtype == np.float32
        depth = signature.shape[-1]
        if zero_init:
            dense = keras.layers.Dense(
                depth,
                activation=activation,
                kernel_initializer='zeros',
                bias_initializer='zeros',
            )
        else:
            dense = keras.layers.Dense(depth, activation=activation)
        head = dense(layer)
        return AddMask()((head, mask))

    # A single tensor is fed to every head: broadcast it over the signature.
    if tf.is_tensor(hidden):
        shared_hidden = hidden
        hidden = data.nested_map(lambda _: shared_hidden, output_signature)

    heads = data.nested_zip_with(
        init_head, (hidden, output_signature, output_activation, masks)
    )
    return heads, masks
Ejemplo n.º 6
0
    def __init__(
            self,
            network_signature,
            input=input_observation,
            target=target_solved,
            mask=None,
            batch_size=64,
            n_steps_per_epoch=1000,
            replay_buffer_capacity=1000000,
            replay_buffer_sampling_hierarchy=(),
    ):
        """Sets up episode-to-datapoint mappers and a hierarchical replay buffer.

        Args:
            network_signature: Signature of the network, passed to the base
                class. Its `input` and `output` attributes define the
                signature of replay buffer datapoints.
            input (pytree of callables): Functions of an episode; each one is
                applied to the episode to compute the corresponding leaf of
                the input pytree.
            target (pytree of callables): Functions of an episode; each one is
                applied to the episode to compute the corresponding leaf of
                the target pytree.
            mask (pytree of callables or None): Functions called as
                f(episode, target_leaf), zipped leaf-by-leaf with the computed
                targets. If None, `mask_one` is used for every leaf of
                `target`.
            batch_size (int): Stored as `self._batch_size`; presumably the
                training batch size.
            n_steps_per_epoch (int): Stored as `self._n_steps_per_epoch`.
            replay_buffer_capacity (int): Capacity of the replay buffer.
            replay_buffer_sampling_hierarchy (sequence): Stored as
                `self._sampling_hierarchy`; its length is used as the replay
                buffer's hierarchy depth.
        """
        super().__init__(network_signature)

        # Returns a function that applies every function in `functions_pytree`
        # to the same episode, producing a pytree of results of the same
        # structure.
        def build_episode_to_pytree_mapper(functions_pytree):
            return lambda episode: data.nested_map(lambda f: f(episode),
                                                   functions_pytree)

        self._input_fn = build_episode_to_pytree_mapper(input)
        self._target_fn = build_episode_to_pytree_mapper(target)

        # Default: the same `mask_one` function for every target leaf.
        if mask is None:
            mask = data.nested_map(lambda _: mask_one, target)

        # Each mask function receives the episode and its matching target
        # leaf. The inner lambda's `target` parameter shadows the `target`
        # argument of __init__; it is the computed target value, not the
        # target function.
        self._mask_fn = lambda episode: data.nested_zip_with(
            lambda f, target: f(episode, target),
            (mask, self._target_fn(episode)))

        self._batch_size = batch_size
        self._n_steps_per_epoch = n_steps_per_epoch

        # Datapoint signature: the network input signature, then the output
        # signature twice. Presumably (input, target, mask), since masks are
        # built to mirror the target structure (see self._mask_fn).
        datapoint_sig = (
            network_signature.input,
            network_signature.output,
            network_signature.output,
        )
        self._replay_buffer = replay_buffers.HierarchicalReplayBuffer(
            datapoint_sig,
            capacity=replay_buffer_capacity,
            hierarchy_depth=len(replay_buffer_sampling_hierarchy),
        )
        self._sampling_hierarchy = replay_buffer_sampling_hierarchy