def __init__(self, network_signature, replay_buffer_capacity=1000000,
             replay_buffer_sampling_hierarchy=(), sample_mode='priority',
             **kwargs):
    """Initializes the trainer with a hierarchical replay buffer.

    Args:
        network_signature: Signature forwarded to the base class; also used
            as the datapoint signature of the replay buffer built here.
        replay_buffer_capacity (int): Maximum size of the replay buffer.
        replay_buffer_sampling_hierarchy (tuple): Sequence of Episode
            attribute names defining the sampling hierarchy; its length sets
            the hierarchy depth of the buffer.
        sample_mode (str): Sampling mode. Only 'uniform' is supported;
            'priority' is still accepted by the signature but raises.
            NOTE(review): the default value 'priority' therefore always
            raises, so callers must pass sample_mode='uniform'.
        **kwargs: Forwarded to the base class constructor.

    Raises:
        NotImplementedError: If sample_mode == 'priority'.
        ValueError: If sample_mode is any other unsupported value.
    """
    super().__init__(
        network_signature,
        replay_buffer_capacity=replay_buffer_capacity,
        replay_buffer_sampling_hierarchy=replay_buffer_sampling_hierarchy,
        **kwargs)
    self._sample_mode = sample_mode

    if sample_mode == 'priority':
        raise NotImplementedError(
            "Prioritized replay buffers have been removed.")
    if sample_mode != 'uniform':
        # Only uniform sampling remains, so only suggest that mode.
        raise ValueError(
            f'Sample mode {sample_mode} is not supported. Please use '
            f'"uniform" sampling.')

    # Replaces whatever buffer the base class may have created.
    self._replay_buffer = replay_buffers.HierarchicalReplayBuffer(
        network_signature, replay_buffer_capacity,
        hierarchy_depth=len(replay_buffer_sampling_hierarchy),
        raw_buffer=replay_buffers.UniformReplayBuffer)
    # The priority error function was only used by the removed prioritized
    # mode; with uniform sampling there is none.
    self._priority_error_fn = None
def test_hierarchical_samples_buckets_uniformly():
    """Buckets are sampled uniformly, regardless of how full each one is."""
    buffer = replay_buffers.HierarchicalReplayBuffer(
        _test_datapoint_sig, capacity=10, hierarchy_depth=1)

    # Bucket 0 gets ten zeros, bucket 1 gets a single one (10:1 ratio).
    buffer.add(_TestTransition(np.zeros(10)), [0])
    buffer.add(_TestTransition(np.ones(1)), [1])

    # If both buckets are equally likely, about half of the sampled values
    # are ones, so their mean should sit near 0.5.
    sampled = buffer.sample(batch_size=1000)
    np.testing.assert_allclose(np.mean(sampled.test_field), 0.5, atol=0.1)
def __init__(self, network_signature, temporal_diff_n, gamma=1.0,
             batch_size=64, n_steps_per_epoch=1000,
             replay_buffer_capacity=1000000,
             replay_buffer_sampling_hierarchy=(), polyak_coeff=None):
    """Initializes TDTrainer.

    Args:
        network_signature (pytree): Input signature for the network.
        temporal_diff_n: Temporal difference distance; np.inf is supported.
        gamma: Discount rate.
        batch_size (int): Batch size.
        n_steps_per_epoch (int): Number of optimizer steps to do per epoch.
        replay_buffer_capacity (int): Maximum size of the replay buffer.
        replay_buffer_sampling_hierarchy (tuple): Sequence of Episode
            attribute names, defining the sampling hierarchy.
        polyak_coeff: Polyak averaging coefficient (may be None).
    """
    super().__init__(network_signature)

    def n_step_target(episode):
        # n-step return target computed from the episode.
        return target_n_return(episode, temporal_diff_n, gamma)

    self._target_fn = lambda episode: data.nested_map(
        lambda fn: fn(episode), n_step_target)
    self._batch_size = batch_size
    self._n_steps_per_epoch = n_steps_per_epoch

    # Stored datapoints are (input, TD target). The TD target holds the
    # accumulated reward, the observation to bootstrap from and the
    # bootstrap discount.
    target_sig = TDTargetData(
        cum_reward=network_signature.output,
        bootstrap_obs=network_signature.input,
        bootstrap_gamma=network_signature.output,
    )
    self._replay_buffer = replay_buffers.HierarchicalReplayBuffer(
        (network_signature.input, target_sig),
        capacity=replay_buffer_capacity,
        hierarchy_depth=len(replay_buffer_sampling_hierarchy),
    )
    self._sampling_hierarchy = replay_buffer_sampling_hierarchy

    self._polyak_coeff = polyak_coeff
    # Target network state starts empty.
    self._target_network_params = None
    self._target_network = None
def __init__(
    self,
    network_signature,
    input=input_observation,
    target=target_solved,
    mask=None,
    batch_size=64,
    n_steps_per_epoch=1000,
    replay_buffer_capacity=1000000,
    replay_buffer_sampling_hierarchy=(),
):
    """Initializes the trainer.

    Args:
        network_signature (pytree): Network signature; its `input` and
            `output` parts define the layout of stored datapoints.
        input (pytree): Pytree of functions episode -> network input.
        target (pytree): Pytree of functions episode -> training target.
        mask (pytree or None): Pytree of functions (episode, target) -> mask,
            zipped leaf-wise with the targets. If None, `mask_one` is used
            for every leaf of `target`.
        batch_size (int): Batch size.
        n_steps_per_epoch (int): Number of optimizer steps to do per epoch.
        replay_buffer_capacity (int): Maximum size of the replay buffer.
        replay_buffer_sampling_hierarchy (tuple): Sequence of Episode
            attribute names, defining the sampling hierarchy.
    """
    super().__init__(network_signature)

    def pytree_applier(fn_tree):
        # Builds episode -> pytree of fn(episode) for every leaf fn.
        def apply_to(episode):
            return data.nested_map(lambda fn: fn(episode), fn_tree)
        return apply_to

    self._input_fn = pytree_applier(input)
    self._target_fn = pytree_applier(target)

    mask_fns = mask
    if mask_fns is None:
        mask_fns = data.nested_map(lambda _: mask_one, target)
    self._mask_fn = lambda episode: data.nested_zip_with(
        lambda mask_f, episode_target: mask_f(episode, episode_target),
        (mask_fns, self._target_fn(episode)),
    )

    self._batch_size = batch_size
    self._n_steps_per_epoch = n_steps_per_epoch

    # Presumably (input, target, mask); mask shares the output signature.
    datapoint_sig = (
        network_signature.input,
        network_signature.output,
        network_signature.output,
    )
    self._replay_buffer = replay_buffers.HierarchicalReplayBuffer(
        datapoint_sig,
        capacity=replay_buffer_capacity,
        hierarchy_depth=len(replay_buffer_sampling_hierarchy),
    )
    self._sampling_hierarchy = replay_buffer_sampling_hierarchy
def __init__(
    self,
    input_shape,
    target_fn=target_solved,
    batch_size=64,
    n_steps_per_epoch=1000,
    replay_buffer_capacity=1000000,
    replay_buffer_sampling_hierarchy=(),
):
    """Initializes SupervisedTrainer.

    Args:
        input_shape (tuple): Input shape for the network.
        target_fn (callable): Function episode -> target, used to compute
            the training target of each episode.
        batch_size (int): Batch size.
        n_steps_per_epoch (int): Number of optimizer steps to do per epoch.
        replay_buffer_capacity (int): Maximum size of the replay buffer.
        replay_buffer_sampling_hierarchy (tuple): Sequence of Episode
            attribute names, defining the sampling hierarchy.
    """
    super().__init__(input_shape)
    self._target_fn = target_fn
    self._batch_size = batch_size
    self._n_steps_per_epoch = n_steps_per_epoch
    self._sampling_hierarchy = replay_buffer_sampling_hierarchy

    # Datapoints are (input, target). Every network is assumed to return a
    # single output, hence the `()` target spec.
    # TODO(koz4k): Lift this restriction.
    self._replay_buffer = replay_buffers.HierarchicalReplayBuffer(
        (input_shape, ()),
        capacity=replay_buffer_capacity,
        hierarchy_depth=len(self._sampling_hierarchy),
    )
def __init__(
    self,
    network_signature,
    target=target_solved,
    batch_size=64,
    n_steps_per_epoch=1000,
    replay_buffer_capacity=1000000,
    replay_buffer_sampling_hierarchy=(),
):
    """Initializes SupervisedTrainer.

    Args:
        network_signature (pytree): Input signature for the network.
        target (pytree): Pytree of functions episode -> target. The tree
            structure mirrors the structure of a training target.
        batch_size (int): Batch size.
        n_steps_per_epoch (int): Number of optimizer steps to do per epoch.
        replay_buffer_capacity (int): Maximum size of the replay buffer.
        replay_buffer_sampling_hierarchy (tuple): Sequence of Episode
            attribute names, defining the sampling hierarchy.
    """
    super().__init__(network_signature)

    def compute_targets(episode):
        # Evaluates every leaf function of `target` on the episode.
        return data.nested_map(lambda fn: fn(episode), target)

    self._target_fn = compute_targets
    self._batch_size = batch_size
    self._n_steps_per_epoch = n_steps_per_epoch
    self._sampling_hierarchy = replay_buffer_sampling_hierarchy

    # Stored datapoints are (input, target) pairs.
    self._replay_buffer = replay_buffers.HierarchicalReplayBuffer(
        (network_signature.input, network_signature.output),
        capacity=replay_buffer_capacity,
        hierarchy_depth=len(self._sampling_hierarchy),
    )
def test_hierarchical_samples_added_transitions(hierarchy_depth):
    """The single transition added to the buffer is what gets sampled back."""
    buf = replay_buffers.HierarchicalReplayBuffer(
        _test_datapoint_sig, capacity=10, hierarchy_depth=hierarchy_depth)
    stacked_transitions = _TestTransition(np.array([123]))
    buf.add(stacked_transitions, [0] * hierarchy_depth)

    sampled = buf.sample(batch_size=1)
    # Compare field by field instead of with `==`: tuple equality calls
    # bool() on each element-wise ndarray comparison, which only works for
    # single-element arrays and gives no useful diff when it fails.
    assert len(sampled) == len(stacked_transitions)
    for actual, expected in zip(sampled, stacked_transitions):
        np.testing.assert_array_equal(actual, expected)
def __init__(
    self,
    network_signature,
    inputs=input_observation,
    target=target_solved,
    batch_size=64,
    n_steps_per_epoch=1000,
    replay_buffer_capacity=1000000,
    replay_buffer_sampling_hierarchy=(),
    validation_split=None,
    validate_every_n_epochs=None,
    validation_replay_buffer_capacity=None,
):
    """Initializes SupervisedTrainer.

    Args:
        network_signature (pytree): Input signature for the network.
        inputs (callable): Function Episode -> network_signature.input.
            Preprocesses episodes into network inputs for training.
        target (pytree): Pytree of functions episode -> target for
            determining the targets for network training. The structure of
            the tree should reflect the structure of a target.
        batch_size (int): Batch size.
        n_steps_per_epoch (int): Number of optimizer steps to do per epoch.
        replay_buffer_capacity (int): Maximum size of the replay buffer.
        replay_buffer_sampling_hierarchy (tuple): Sequence of Episode
            attribute names, defining the sampling hierarchy.
        validation_split (Optional[float]): Fraction of episodes which
            should be placed in the validation replay buffer. Validation is
            disabled when None.
        validate_every_n_epochs (Optional[int]): Validation frequency in
            epochs. Required when validation_split is given.
        validation_replay_buffer_capacity (Optional[int]): Maximum size of
            the validation replay buffer. Defaults to
            `replay_buffer_capacity` when None.

    Raises:
        ValueError: When validation_split is provided but
            validate_every_n_epochs is missing or not positive.
    """
    super().__init__(network_signature)
    self._input_fn = inputs
    self._target_fn = lambda episode: data.nested_map(
        lambda f: f(episode), target)
    self._batch_size = batch_size
    self._n_steps_per_epoch = n_steps_per_epoch

    # (input, target)
    datapoint_sig = (network_signature.input, network_signature.output)
    self._replay_buffer = replay_buffers.HierarchicalReplayBuffer(
        datapoint_sig,
        capacity=replay_buffer_capacity,
        hierarchy_depth=len(replay_buffer_sampling_hierarchy),
    )
    self._sampling_hierarchy = replay_buffer_sampling_hierarchy

    self._validation_split = validation_split
    self._validate_every_n_epochs = validate_every_n_epochs
    self._validation_replay_buffer = None
    if self._validation_split is not None:
        if self._validate_every_n_epochs is None:
            raise ValueError(
                f'Argument validate_every_n_epochs should be specified '
                f'when validation_split is provided: {validation_split}')
        if self._validate_every_n_epochs <= 0:
            raise ValueError(
                f'Argument validate_every_n_epochs should be positive '
                f'integer, got {validate_every_n_epochs}')
        # Fall back only when the capacity was truly not provided; `or`
        # would also discard an explicitly passed falsy value.
        if validation_replay_buffer_capacity is None:
            validation_replay_buffer_capacity = replay_buffer_capacity
        # Keyword arguments keep this call consistent with the training
        # buffer above and independent of the constructor's positional order.
        self._validation_replay_buffer = (
            replay_buffers.HierarchicalReplayBuffer(
                datapoint_sig,
                capacity=validation_replay_buffer_capacity,
                hierarchy_depth=len(self._sampling_hierarchy),
            ))
    self._epoch = 0