Example #1
    def __init__(self,
                 network_signature,
                 replay_buffer_capacity=1000000,
                 replay_buffer_sampling_hierarchy=(),
                 sample_mode='priority',
                 **kwargs):
        super().__init__(
            network_signature,
            replay_buffer_capacity=replay_buffer_capacity,
            replay_buffer_sampling_hierarchy=replay_buffer_sampling_hierarchy,
            **kwargs)
        self._sample_mode = sample_mode
        if sample_mode == 'priority':
            # NOTE: 'priority' is still the default, but prioritized replay has
            # been removed, so constructing with the default sample_mode raises.
            raise NotImplementedError(
                "Prioritized replay buffers have been removed.")
        elif sample_mode == 'uniform':
            raw_buffer = replay_buffers.UniformReplayBuffer
        else:
            raise ValueError(
                f'Sample mode {sample_mode} is not supported. Please use '
                f'"priority" or "uniform" sampling.')
        self._replay_buffer = replay_buffers.HierarchicalReplayBuffer(
            network_signature,
            replay_buffer_capacity,
            hierarchy_depth=len(replay_buffer_sampling_hierarchy),
            raw_buffer=raw_buffer)
        if sample_mode == 'priority':
            # NOTE: unreachable -- 'priority' raises NotImplementedError above,
            # so this branch only documents the per-sample MSE previously used
            # as the priority error.

            def mse(y_pred, y_true):
                loss = (y_pred - y_true)**2
                return loss.mean(axis=1)

            self._priority_error_fn = mse
        else:
            self._priority_error_fn = None
Example #2
def test_hierarchical_samples_buckets_uniformly():
    buf = replay_buffers.HierarchicalReplayBuffer(_test_datapoint_sig,
                                                  capacity=10,
                                                  hierarchy_depth=1)
    # Add zeros and ones at a 10:1 ratio.
    buf.add(_TestTransition(np.zeros(10)), [0])
    buf.add(_TestTransition(np.ones(1)), [1])
    # With uniform sampling over buckets, the sampled mean should be ~0.5
    # despite the 10:1 ratio of transitions.
    mean_value = np.mean(buf.sample(batch_size=1000).test_field)
    np.testing.assert_allclose(mean_value, 0.5, atol=0.1)
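The test above relies on the hierarchical buffer sampling buckets uniformly rather than in proportion to their size. A minimal, self-contained sketch of that idea (ToyHierarchicalBuffer is a hypothetical stand-in, not the project's HierarchicalReplayBuffer):

import collections
import random


class ToyHierarchicalBuffer:
    """Toy stand-in: one FIFO bucket per key; sampling first picks a bucket
    uniformly, then a transition uniformly within that bucket."""

    def __init__(self, capacity):
        self._buckets = collections.defaultdict(list)
        self._capacity = capacity

    def add(self, transition, bucket_key):
        bucket = self._buckets[bucket_key]
        bucket.append(transition)
        if len(bucket) > self._capacity:
            bucket.pop(0)

    def sample(self):
        bucket = random.choice(list(self._buckets.values()))
        return random.choice(bucket)


buf = ToyHierarchicalBuffer(capacity=10)
for _ in range(10):
    buf.add(0.0, bucket_key=0)  # ten "zero" transitions
buf.add(1.0, bucket_key=1)      # one "one" transition

samples = [buf.sample() for _ in range(1000)]
print(sum(samples) / len(samples))  # ~0.5: both buckets are weighted equally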
Example #3
    def __init__(self,
                 network_signature,
                 temporal_diff_n,
                 gamma=1.0,
                 batch_size=64,
                 n_steps_per_epoch=1000,
                 replay_buffer_capacity=1000000,
                 replay_buffer_sampling_hierarchy=(),
                 polyak_coeff=None):
        """Initializes TDTrainer.

        Args:
            network_signature (pytree): Input signature for the network.
            temporal_diff_n (int or np.inf): Temporal difference distance
                (number of steps); np.inf is supported.
            gamma (float): Discount factor.
            batch_size (int): Batch size.
            n_steps_per_epoch (int): Number of optimizer steps to do per
                epoch.
            replay_buffer_capacity (int): Maximum size of the replay buffer.
            replay_buffer_sampling_hierarchy (tuple): Sequence of Episode
                attribute names, defining the sampling hierarchy.
            polyak_coeff (Optional[float]): Polyak averaging coefficient for
                the target network.
        """
        super().__init__(network_signature)
        # `target` is a single-leaf pytree of target functions: nested_map
        # applies each leaf to the episode, yielding the n-step TD target data.
        target = lambda episode: target_n_return(episode, temporal_diff_n,
                                                 gamma)
        self._target_fn = lambda episode: data.nested_map(
            lambda f: f(episode), target)
        self._batch_size = batch_size
        self._n_steps_per_epoch = n_steps_per_epoch

        td_target_signature = TDTargetData(
            cum_reward=network_signature.output,
            bootstrap_obs=network_signature.input,
            bootstrap_gamma=network_signature.output)
        # Datapoint: (network input, TD target data).
        datapoint_sig = (network_signature.input, td_target_signature)
        self._replay_buffer = replay_buffers.HierarchicalReplayBuffer(
            datapoint_sig,
            capacity=replay_buffer_capacity,
            hierarchy_depth=len(replay_buffer_sampling_hierarchy),
        )
        self._sampling_hierarchy = replay_buffer_sampling_hierarchy
        self._polyak_coeff = polyak_coeff
        self._target_network_params = None
        self._target_network = None
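The TDTargetData fields above (cum_reward, bootstrap_obs, bootstrap_gamma) suggest the usual n-step temporal-difference decomposition: a discounted reward sum, the observation to bootstrap from, and the discount applied to the bootstrapped value. A self-contained sketch of that decomposition for the first time step of an episode (standard n-step TD; the project's target_n_return may differ in details):

def n_step_td_target_parts(rewards, n, gamma):
    """Returns (cum_reward, bootstrap_index, bootstrap_gamma) for t=0."""
    n = min(n, len(rewards))  # n = float('inf') degenerates to the full return
    cum_reward = sum(gamma**k * r for k, r in enumerate(rewards[:n]))
    bootstrap_index = n  # index of the observation to bootstrap from
    # No bootstrap term if the episode ends before n steps.
    bootstrap_gamma = gamma**n if n < len(rewards) else 0.0
    return cum_reward, bootstrap_index, bootstrap_gamma


# 3-step target with gamma=0.99: 1.0 + 0.99 * 0.0 + 0.99**2 * 2.0, bootstrap at t=3.
print(n_step_td_target_parts([1.0, 0.0, 2.0, 0.5], n=3, gamma=0.99))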
Example #4
    def __init__(
            self,
            network_signature,
            input=input_observation,
            target=target_solved,
            mask=None,
            batch_size=64,
            n_steps_per_epoch=1000,
            replay_buffer_capacity=1000000,
            replay_buffer_sampling_hierarchy=(),
    ):
        super().__init__(network_signature)

        def build_episode_to_pytree_mapper(functions_pytree):
            return lambda episode: data.nested_map(lambda f: f(episode),
                                                   functions_pytree)

        self._input_fn = build_episode_to_pytree_mapper(input)
        self._target_fn = build_episode_to_pytree_mapper(target)

        if mask is None:
            # Default: use mask_one as the mask function for every leaf of the
            # target pytree.
            mask = data.nested_map(lambda _: mask_one, target)

        # Each mask function receives the episode and the corresponding target.
        self._mask_fn = lambda episode: data.nested_zip_with(
            lambda f, target: f(episode, target),
            (mask, self._target_fn(episode)))

        self._batch_size = batch_size
        self._n_steps_per_epoch = n_steps_per_epoch

        # Datapoint: (input, target, mask); the mask shares the target's signature.
        datapoint_sig = (
            network_signature.input,
            network_signature.output,
            network_signature.output,
        )
        self._replay_buffer = replay_buffers.HierarchicalReplayBuffer(
            datapoint_sig,
            capacity=replay_buffer_capacity,
            hierarchy_depth=len(replay_buffer_sampling_hierarchy),
        )
        self._sampling_hierarchy = replay_buffer_sampling_hierarchy
Example #5
    def __init__(
            self,
            input_shape,
            target_fn=target_solved,
            batch_size=64,
            n_steps_per_epoch=1000,
            replay_buffer_capacity=1000000,
            replay_buffer_sampling_hierarchy=(),
    ):
        """Initializes SupervisedTrainer.

        Args:
            input_shape (tuple): Input shape for the network.
            target_fn (callable): Function episode -> target for
                determining the target for network training.
            batch_size (int): Batch size.
            n_steps_per_epoch (int): Number of optimizer steps to do per
                epoch.
            replay_buffer_capacity (int): Maximum size of the replay buffer.
            replay_buffer_sampling_hierarchy (tuple): Sequence of Episode
                attribute names, defining the sampling hierarchy.
        """
        super().__init__(input_shape)
        self._target_fn = target_fn
        self._batch_size = batch_size
        self._n_steps_per_epoch = n_steps_per_epoch

        # (input, target). For now we assume that all networks return a single
        # output.
        # TODO(koz4k): Lift this restriction.
        datapoint_spec = (input_shape, ())
        self._replay_buffer = replay_buffers.HierarchicalReplayBuffer(
            datapoint_spec,
            capacity=replay_buffer_capacity,
            hierarchy_depth=len(replay_buffer_sampling_hierarchy),
        )
        self._sampling_hierarchy = replay_buffer_sampling_hierarchy
Example #6
    def __init__(
            self,
            network_signature,
            target=target_solved,
            batch_size=64,
            n_steps_per_epoch=1000,
            replay_buffer_capacity=1000000,
            replay_buffer_sampling_hierarchy=(),
    ):
        """Initializes SupervisedTrainer.

        Args:
            network_signature (pytree): Input signature for the network.
            target (pytree): Pytree of functions episode -> target for
                determining the targets for network training. The structure of
                the tree should reflect the structure of a target.
            batch_size (int): Batch size.
            n_steps_per_epoch (int): Number of optimizer steps to do per
                epoch.
            replay_buffer_capacity (int): Maximum size of the replay buffer.
            replay_buffer_sampling_hierarchy (tuple): Sequence of Episode
                attribute names, defining the sampling hierarchy.
        """
        super().__init__(network_signature)
        self._target_fn = lambda episode: data.nested_map(
            lambda f: f(episode), target)
        self._batch_size = batch_size
        self._n_steps_per_epoch = n_steps_per_epoch

        # (input, target)
        datapoint_sig = (network_signature.input, network_signature.output)
        self._replay_buffer = replay_buffers.HierarchicalReplayBuffer(
            datapoint_sig,
            capacity=replay_buffer_capacity,
            hierarchy_depth=len(replay_buffer_sampling_hierarchy),
        )
        self._sampling_hierarchy = replay_buffer_sampling_hierarchy
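The target argument here is a pytree of per-episode functions, and _target_fn applies every leaf to the same episode. A small self-contained illustration of that mapping (the toy nested_map below only handles dicts and leaves and is not the project's data.nested_map; the target functions are hypothetical):

def nested_map(fn, pytree):
    """Toy nested_map: recurse into dicts, apply fn to leaves."""
    if isinstance(pytree, dict):
        return {key: nested_map(fn, value) for key, value in pytree.items()}
    return fn(pytree)


# Hypothetical per-episode target functions (the real ones, e.g. target_solved,
# live in the project).
target = {
    'solved': lambda episode: float(episode['solved']),
    'return': lambda episode: sum(episode['rewards']),
}

target_fn = lambda episode: nested_map(lambda f: f(episode), target)

episode = {'solved': True, 'rewards': [1.0, 0.0, 2.0]}
print(target_fn(episode))  # {'solved': 1.0, 'return': 3.0}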
Example #7
def test_hierarchical_samples_added_transitions(hierarchy_depth):
    buf = replay_buffers.HierarchicalReplayBuffer(
        _test_datapoint_sig, capacity=10, hierarchy_depth=hierarchy_depth)
    stacked_transitions = _TestTransition(np.array([123]))
    buf.add(stacked_transitions, [0] * hierarchy_depth)
    assert buf.sample(batch_size=1) == stacked_transitions
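The hierarchy_depth argument of this test is presumably supplied by pytest; a hypothetical parametrization covering a flat buffer and one or two hierarchy levels could look like the sketch below (the mark and the value list are assumptions, not part of the original test module):

import pytest


@pytest.mark.parametrize('hierarchy_depth', [0, 1, 2])
def test_hierarchical_samples_added_transitions(hierarchy_depth):
    ...  # body as in Example #7 above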
Example #8
    def __init__(
        self,
        network_signature,
        inputs=input_observation,
        target=target_solved,
        batch_size=64,
        n_steps_per_epoch=1000,
        replay_buffer_capacity=1000000,
        replay_buffer_sampling_hierarchy=(),
        validation_split=None,
        validate_every_n_epochs=None,
        validation_replay_buffer_capacity=None,
    ):
        """Initializes SupervisedTrainer.

        Args:
            network_signature (pytree): Input signature for the network.
            inputs (callable): Function episode -> network input (matching
                network_signature.input), used to preprocess episodes into
                the inputs the network is trained on.
            target (pytree): Pytree of functions episode -> target for
                determining the targets for network training. The structure of
                the tree should reflect the structure of a target.
            batch_size (int): Batch size.
            n_steps_per_epoch (int): Number of optimizer steps to do per
                epoch.
            replay_buffer_capacity (int): Maximum size of the replay buffer.
            replay_buffer_sampling_hierarchy (tuple): Sequence of Episode
                attribute names, defining the sampling hierarchy.
            validation_split (Optional[float]): Fraction of episodes which
                should be placed in the validation replay buffer.
            validate_every_n_epochs (Optional[int]): Validation frequency in
                epochs.
            validation_replay_buffer_capacity (Optional[int]): Maximum size
                of the validation replay buffer. Defaults to
                `replay_buffer_capacity` if not provided.

        Raises:
            ValueError: When validation parameters are provided only partially.
        """
        super().__init__(network_signature)
        self._input_fn = inputs
        self._target_fn = lambda episode: data.nested_map(
            lambda f: f(episode), target)
        self._batch_size = batch_size
        self._n_steps_per_epoch = n_steps_per_epoch

        # (input, target)
        datapoint_sig = (network_signature.input, network_signature.output)
        self._replay_buffer = replay_buffers.HierarchicalReplayBuffer(
            datapoint_sig,
            capacity=replay_buffer_capacity,
            hierarchy_depth=len(replay_buffer_sampling_hierarchy),
        )
        self._sampling_hierarchy = replay_buffer_sampling_hierarchy

        self._validation_split = validation_split
        self._validate_every_n_epochs = validate_every_n_epochs
        self._validation_replay_buffer = None
        if self._validation_split is not None:
            if self._validate_every_n_epochs is None:
                raise ValueError(
                    f'Argument validate_every_n_epochs should be specified '
                    f'when validation_split is provided: {validation_split}')
            if self._validate_every_n_epochs <= 0:
                raise ValueError(
                    f'Argument validate_every_n_epochs should be a positive '
                    f'integer, got {validate_every_n_epochs}')
            self._validation_replay_buffer = (
                replay_buffers.HierarchicalReplayBuffer(
                    datapoint_sig,
                    capacity=(validation_replay_buffer_capacity or
                              replay_buffer_capacity),
                    hierarchy_depth=len(self._sampling_hierarchy),
                ))
        self._epoch = 0
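When validation_split is set, some fraction of incoming data presumably ends up in self._validation_replay_buffer instead of the training buffer. A self-contained sketch of that routing under a simple random-split assumption (the real add-episode logic lives outside this constructor and may differ):

import random


def split_datapoints(datapoints, validation_split):
    """Routes roughly `validation_split` of datapoints to the validation set."""
    train, validation = [], []
    for point in datapoints:
        goes_to_validation = (validation_split is not None
                              and random.random() < validation_split)
        (validation if goes_to_validation else train).append(point)
    return train, validation


train, validation = split_datapoints(range(1000), validation_split=0.1)
print(len(train), len(validation))  # roughly 900 / 100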