def test_pugail(self):

  def dummy_discriminator(
      state: losses.State,
      transition: types.Transition) -> losses.DiscriminatorOutput:
    return transition.observation, state

  zero_transition = types.Transition(.1, 0., 0., 0., 0.)
  zero_transition = tree.map_structure(
      lambda x: jnp.expand_dims(x, axis=0), zero_transition)

  one_transition = types.Transition(1., 0., 0., 0., 0.)
  one_transition = tree.map_structure(
      lambda x: jnp.expand_dims(x, axis=0), one_transition)

  prior = .7
  loss_fn = losses.pugail_loss(
      positive_class_prior=prior, entropy_coefficient=0.)

  loss, _ = loss_fn(dummy_discriminator, {}, one_transition, zero_transition,
                    ())

  d_one = jax.nn.sigmoid(dummy_discriminator({}, one_transition)[0])
  d_zero = jax.nn.sigmoid(dummy_discriminator({}, zero_transition)[0])
  expected_loss = -prior * jnp.log(
      d_one) + -jnp.log(1. - d_zero) - prior * -jnp.log(1 - d_one)

  self.assertAlmostEqual(loss, expected_loss, places=6)

def test_weighted_generator(self):
  data0 = types.Transition(np.array([[1], [2], [3]]), (), _REWARD, (), ())
  it0 = iter([data0])

  data1 = types.Transition(np.array([[4], [5], [6]]), (), _REWARD, (), ())
  data2 = types.Transition(np.array([[7], [8], [9]]), (), _REWARD, (), ())
  it1 = iter([
      reverb.ReplaySample(
          info=reverb.SampleInfo(
              *[() for _ in reverb.SampleInfo.tf_dtypes()]),
          data=data1),
      reverb.ReplaySample(
          info=reverb.SampleInfo(
              *[() for _ in reverb.SampleInfo.tf_dtypes()]),
          data=data2)
  ])

  weighted_it = builder._generate_samples_with_demonstrations(
      it0, it1, policy_to_expert_data_ratio=2, batch_size=3)
  np.testing.assert_array_equal(
      next(weighted_it).data.observation, np.array([[1], [4], [5]]))
  np.testing.assert_array_equal(
      next(weighted_it).data.observation, np.array([[7], [8], [2]]))
  self.assertRaises(StopIteration, lambda: next(weighted_it))

def replay_sample_to_sars_transition(
    sample: reverb.ReplaySample, is_sequence: bool) -> types.Transition:
  """Converts the replay sample to a types.Transition.

  NB: If is_sequence is True then the last next_observation of each sequence
  is rubbish. Don't train on it.

  Args:
    sample: The replay sample.
    is_sequence: If False we expect the sample data to match the
      types.Transition already. Otherwise we expect a batch of sequences of
      steps.

  Returns:
    A types.Transition built from the sample data. The number of leading
    dimensions will be unchanged, so expect 2 for sequence based
    ([Batch, Time]) and 1 ([Batch]) otherwise.
    NB: If is_sequence is True then the last next_observation of each sequence
    is rubbish. Don't train on it.
  """
  if not is_sequence:
    return types.Transition(*sample.data)
  # Note that the last next_observation is invalid.
  steps = sample.data
  return types.Transition(
      observation=steps.observation,
      action=steps.action,
      reward=steps.reward,
      discount=steps.discount,
      next_observation=jnp.roll(steps.observation, shift=-1, axis=1))

def replay_sample_to_sars_transition(
    sample: reverb.ReplaySample,
    is_sequence: bool,
    strip_last_transition: bool = False,
    flatten_batch: bool = False) -> types.Transition:
  """Converts the replay sample to a types.Transition.

  NB: If is_sequence is True then the last next_observation of each sequence
  is rubbish. Don't train on it.

  Args:
    sample: The replay sample.
    is_sequence: If False we expect the sample data to match the
      types.Transition already. Otherwise we expect a batch of sequences of
      steps.
    strip_last_transition: If True and is_sequence, the last transition will
      be stripped as its next_observation field is incorrect.
    flatten_batch: If True and is_sequence, the two leading batch dimensions
      will be flattened into one.

  Returns:
    A types.Transition built from the sample data. If is_sequence and
    strip_last_transition are both True, the output will be smaller than the
    input as the last transition of every sequence will have been removed.
  """
  if not is_sequence:
    return types.Transition(*sample.data)
  # Note that the last next_observation is invalid.
  steps = sample.data

  def roll(observation):
    return np.roll(observation, shift=-1, axis=1)

  transitions = types.Transition(
      observation=steps.observation,
      action=steps.action,
      reward=steps.reward,
      discount=steps.discount,
      next_observation=tree.map_structure(roll, steps.observation),
      extras=steps.extras)
  if strip_last_transition:
    # We remove the last transition as its next_observation field is
    # incorrect. It has been obtained by rolling the observation field, such
    # that transitions.next_observation[:, -1] is transitions.observation[:, 0].
    transitions = jax.tree_map(lambda x: x[:, :-1, ...], transitions)
  if flatten_batch:
    # Merge the two leading batch dimensions into one.
    transitions = jax.tree_map(
        lambda x: np.reshape(x, (-1,) + x.shape[2:]), transitions)
  return transitions

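# A minimal numpy sketch (illustration only, not library code) of what the
# roll / strip_last_transition / flatten_batch steps above do to a
# [batch=2, time=3] observation array. Shapes and values are made up.
import numpy as np

example_obs = np.arange(6).reshape(2, 3, 1)             # [B=2, T=3, obs_dim=1]
example_next_obs = np.roll(example_obs, shift=-1, axis=1)  # [:, -1] wraps to obs[:, 0] (invalid)
example_obs = example_obs[:, :-1]                       # strip_last_transition
example_next_obs = example_next_obs[:, :-1]
example_flat = example_obs.reshape((-1,) + example_obs.shape[2:])  # flatten_batch -> [4, 1]
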
def test_step(self):
  simple_spec = specs.Array(shape=(), dtype=float)

  spec = specs.EnvironmentSpec(simple_spec, simple_spec, simple_spec,
                               simple_spec)

  discriminator = _make_discriminator(spec)
  ail_network = ail_networks.AILNetworks(
      discriminator, imitation_reward_fn=lambda x: x, direct_rl_networks=None)

  loss = losses.gail_loss()

  optimizer = optax.adam(.01)

  step = jax.jit(functools.partial(
      ail_learning.ail_update_step,
      optimizer=optimizer,
      ail_network=ail_network,
      loss_fn=loss))

  zero_transition = types.Transition(
      np.array([0.]), np.array([0.]), 0., 0., np.array([0.]))
  zero_transition = utils.add_batch_dim(zero_transition)

  one_transition = types.Transition(
      np.array([1.]), np.array([0.]), 0., 0., np.array([0.]))
  one_transition = utils.add_batch_dim(one_transition)

  key = jax.random.PRNGKey(0)
  discriminator_params, discriminator_state = discriminator.init(key)

  state = ail_learning.DiscriminatorTrainingState(
      optimizer_state=optimizer.init(discriminator_params),
      discriminator_params=discriminator_params,
      discriminator_state=discriminator_state,
      policy_params=None,
      key=key,
      steps=0,
  )

  expected_loss = [1.062, 1.057, 1.052]

  for i in range(3):
    state, loss = step(state, (one_transition, zero_transition))
    self.assertAlmostEqual(loss['total_loss'], expected_loss[i], places=3)

def test_sqil_iterator(self):
  demonstrations = [
      types.Transition(np.array([[1], [2], [3]]), (), (), (), ())
  ]
  replay = [
      reverb.ReplaySample(
          info=(),
          data=types.Transition(np.array([[4], [5], [6]]), (), (), (), ()))
  ]
  sqil_it = builder._generate_sqil_samples(iter(demonstrations), iter(replay))

  np.testing.assert_array_equal(
      next(sqil_it).data.observation, np.array([[1], [3], [5]]))
  np.testing.assert_array_equal(
      next(sqil_it).data.observation, np.array([[2], [4], [6]]))
  self.assertRaises(StopIteration, lambda: next(sqil_it))

def transition_dataset(environment: dm_env.Environment) -> tf.data.Dataset:
  """Fake dataset of Reverb N-step transition samples.

  Args:
    environment: Used to create a fake transition by looking at the
      observation, action, discount and reward specs.

  Returns:
    tf.data.Dataset that produces the same fake N-step transition ReverbSample
    object indefinitely.
  """
  observation = environment.observation_spec().generate_value()
  action = environment.action_spec().generate_value()
  reward = environment.reward_spec().generate_value()
  discount = environment.discount_spec().generate_value()
  data = types.Transition(observation, action, reward, discount, observation)

  key = np.array(0, np.uint64)
  probability = np.array(1.0, np.float64)
  table_size = np.array(1, np.int64)
  priority = np.array(1.0, np.float64)
  info = reverb.SampleInfo(
      key=key,
      probability=probability,
      table_size=table_size,
      priority=priority)

  sample = reverb.ReplaySample(info=info, data=data)

  return tf.data.Dataset.from_tensors(sample).repeat()

def signature(cls, environment_spec: specs.EnvironmentSpec,
              extras_spec: types.NestedSpec = ()):
  # This function currently assumes that self._discount is a scalar.
  # If it ever becomes a nested structure and/or a np.ndarray, this method
  # will need to know its structure / shape. This is because the signature
  # discount shape is the environment's discount shape and this adder's
  # discount shape broadcasted together. Also, the reward shape is this
  # signature discount shape broadcasted together with the environment
  # reward shape. As long as self._discount is a scalar, it will not affect
  # either the signature discount shape nor the signature reward shape, so we
  # can ignore it.
  rewards_spec, step_discounts_spec = tree_utils.broadcast_structures(
      environment_spec.rewards, environment_spec.discounts)
  rewards_spec = tree.map_structure(_broadcast_specs, rewards_spec,
                                    step_discounts_spec)
  step_discounts_spec = tree.map_structure(copy.deepcopy, step_discounts_spec)

  transition_spec = types.Transition(
      environment_spec.observations,
      environment_spec.actions,
      rewards_spec,
      step_discounts_spec,
      environment_spec.observations,  # next_observation
      extras_spec)

  return tree.map_structure_with_path(base.spec_like_to_tensor_spec,
                                      transition_spec)

def _create_dummy_transitions(self):
  return types.Transition(
      observation=self._DUMMY_OBS,
      action=self._DUMMY_ACTION,
      reward=self._DUMMY_REWARD,
      discount=self._DUMMY_DISCOUNT,
      next_observation=self._DUMMY_NEXT_OBS,
      extras={'return': self._DUMMY_RETURN})

def step(self):
  sample = next(self._iterator)
  transitions = types.Transition(*sample.data)

  counts = self._counter.get_counts()
  if 'learner_steps' not in counts:
    cur_step = 0
  else:
    cur_step = counts['learner_steps']
  in_initial_bc_iters = cur_step < self._num_bc_iters

  if in_initial_bc_iters:
    self._state, metrics = self._update_step_in_initial_bc_iters(
        self._state, transitions)
  else:
    self._state, metrics = self._update_step_rest(self._state, transitions)

  # self._state, metrics = self._update_step(self._state, transitions)

  # Compute elapsed time.
  timestamp = time.time()
  elapsed_time = timestamp - self._timestamp if self._timestamp else 0
  self._timestamp = timestamp

  # Increment counts and record the current time
  counts = self._counter.increment(
      steps=self._num_sgd_steps_per_step, walltime=elapsed_time)

  # Attempts to write the logs.
  self._logger.write({**metrics, **counts})

def _n_step_transition_from_episode(observations: acme_types.NestedTensor,
                                    actions: tf.Tensor,
                                    rewards: tf.Tensor,
                                    discounts: tf.Tensor,
                                    n_step: int,
                                    discount: float):
  """Produce Reverb-like N-step transition from a full episode.

  Observations, actions, rewards and discounts have the same length. This
  function will ignore the first reward and discount and the last action.

  Args:
    observations: [L, ...] Tensor.
    actions: [L, ...] Tensor.
    rewards: [L] Tensor.
    discounts: [L] Tensor.
    n_step: number of steps to squash into a single transition.
    discount: discount to use for TD updates.

  Returns:
    (o_t, a_t, r_t, d_t, o_tp1) tuple.
  """
  max_index = tf.shape(rewards)[0] - 1
  first = tf.random.uniform(
      shape=(), minval=0, maxval=max_index - 1, dtype=tf.int32)
  last = tf.minimum(first + n_step, max_index)

  o_t = tree.map_structure(operator.itemgetter(first), observations)
  a_t = tree.map_structure(operator.itemgetter(first), actions)
  o_tp1 = tree.map_structure(operator.itemgetter(last), observations)

  # 0, 1, ..., n-1.
  discount_range = tf.cast(tf.range(last - first), tf.float32)
  # 1, g, ..., g^{n-1}.
  additional_discounts = tf.pow(discount, discount_range)
  # 1, d_t, d_t * d_{t+1}, ..., d_t * ... * d_{t+n-2}.
  discounts = tf.concat([[1.], tf.math.cumprod(discounts[first:last - 1])], 0)
  # 1, g * d_t, ..., g^{n-1} * d_t * ... * d_{t+n-2}.
  discounts *= additional_discounts
  # r_t + g * d_t * r_{t+1} + ... + g^{n-1} * d_t * ... * d_{t+n-2} * r_{t+n-1}
  # We have to shift rewards by one so last=max_index corresponds to
  # transitions that include the last reward.
  r_t = tf.reduce_sum(rewards[first + 1:last + 1] * discounts)

  # g^{n-1} * d_{t} * ... * d_{t+n-1}.
  d_t = discounts[-1]

  key = tf.constant(0, tf.uint64)
  probability = tf.constant(1.0, tf.float64)
  table_size = tf.constant(1, tf.int64)
  priority = tf.constant(1.0, tf.float64)
  info = reverb.SampleInfo(
      key=key,
      probability=probability,
      table_size=table_size,
      priority=priority)
  return reverb.ReplaySample(
      info=info, data=acme_types.Transition(o_t, a_t, r_t, d_t, o_tp1))

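# Worked numeric example (plain numpy, illustration only) of the discount and
# return accumulation above for first=0, last=3 and gamma=0.9. The reward and
# discount values below are arbitrary.
import numpy as np

gamma = 0.9
example_rewards = np.array([0.0, 1.0, 2.0, 3.0])
example_discounts = np.array([1.0, 1.0, 1.0, 0.0])
first, last = 0, 3

cum = np.concatenate([[1.0], np.cumprod(example_discounts[first:last - 1])])
cum = cum * gamma ** np.arange(last - first)             # [1.0, 0.9, 0.81]
r_t = np.sum(example_rewards[first + 1:last + 1] * cum)  # 1.0 + 1.8 + 2.43 = 5.23
d_t = cum[-1]                                            # 0.81
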
def _episode_to_transition(step: Dict[str, Any]) -> types.Transition:
  return types.Transition(
      observation=step['observation'][:-1],
      action=step['action'][:-1],
      reward=step['reward'][:-1],
      discount=1.0 - tf.cast(step['is_terminal'][1:], dtype=tf.float32),
      # If next step is terminal, then the observation may be arbitrary.
      next_observation=step['observation'][1:],
  )

def _batched_step_to_transition(step: rlds.BatchedStep) -> types.Transition:
  return types.Transition(
      observation=tf.nest.map_structure(lambda x: x[0],
                                        step[rlds.OBSERVATION]),
      action=tf.nest.map_structure(lambda x: x[0], step[rlds.ACTION]),
      reward=tf.nest.map_structure(lambda x: x[0], step[rlds.REWARD]),
      discount=1.0 - tf.cast(step[rlds.IS_TERMINAL][1], dtype=tf.float32),
      # If next step is terminal, then the observation may be arbitrary.
      next_observation=tf.nest.map_structure(lambda x: x[1],
                                             step[rlds.OBSERVATION]))

def test_gradient_penalty(self):

  def dummy_discriminator(
      transition: types.Transition) -> networks_lib.Logits:
    return transition.observation + jnp.square(transition.action)

  zero_transition = types.Transition(0., 0., 0., 0., 0.)
  zero_transition = tree.map_structure(
      lambda x: jnp.expand_dims(x, axis=0), zero_transition)
  self.assertEqual(
      losses._compute_gradient_penalty(zero_transition, dummy_discriminator,
                                       0.), 1**2 + 0**2)

  one_transition = types.Transition(1., 1., 0., 0., 0.)
  one_transition = tree.map_structure(
      lambda x: jnp.expand_dims(x, axis=0), one_transition)
  self.assertEqual(
      losses._compute_gradient_penalty(one_transition, dummy_discriminator,
                                       0.), 1**2 + 2**2)

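# Hedged sketch of the quantity the test above checks: with a zero penalty
# target, the gradient penalty reduces to the squared L2 norm of the
# discriminator gradient with respect to its (observation, action) inputs.
# The toy discriminator below mirrors the dummy one in the test.
import jax
import jax.numpy as jnp


def _toy_discriminator(observation, action):
  return observation + jnp.square(action)

toy_grads = jax.grad(_toy_discriminator, argnums=(0, 1))(1.0, 1.0)  # (1.0, 2.0)
toy_penalty = sum(jnp.square(g) for g in toy_grads)                 # 1**2 + 2**2 = 5.0
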
def test_discrete_actions(self, loss_name):
  with chex.fake_pmap_and_jit():
    num_sgd_steps_per_step = 1
    num_steps = 5

    # Create a fake environment to test with.
    environment = fakes.DiscreteEnvironment(
        num_actions=10,
        num_observations=100,
        obs_shape=(10,),
        obs_dtype=np.float32)

    spec = specs.make_environment_spec(environment)
    dataset_demonstration = fakes.transition_dataset(environment)
    dataset_demonstration = dataset_demonstration.map(
        lambda sample: types.Transition(*sample.data))
    dataset_demonstration = dataset_demonstration.batch(8).as_numpy_iterator()

    # Construct the agent.
    network = make_networks(spec, discrete_actions=True)

    def logp_fn(logits, actions):
      max_logits = jnp.max(logits, axis=-1, keepdims=True)
      logits = logits - max_logits
      logits_actions = jnp.sum(
          jax.nn.one_hot(actions, spec.actions.num_values) * logits, axis=-1)
      log_prob = logits_actions - special.logsumexp(logits, axis=-1)
      return log_prob

    if loss_name == 'logp':
      loss_fn = bc.logp(logp_fn=logp_fn)
    elif loss_name == 'rcal':
      base_loss_fn = bc.logp(logp_fn=logp_fn)
      loss_fn = bc.rcal(base_loss_fn, discount=0.99, alpha=0.1)
    else:
      raise ValueError

    learner = bc.BCLearner(
        network=network,
        random_key=jax.random.PRNGKey(0),
        loss_fn=loss_fn,
        optimizer=optax.adam(0.01),
        demonstrations=dataset_demonstration,
        num_sgd_steps_per_step=num_sgd_steps_per_step)

    # Train the agent
    for _ in range(num_steps):
      learner.step()

def _step_to_transition(rlds_step: rlds.BatchedStep) -> types.Transition:
  """Converts batched RLDS steps to batched transitions."""
  return types.Transition(
      observation=rlds_step[rlds.OBSERVATION],
      action=rlds_step[rlds.ACTION],
      reward=rlds_step[rlds.REWARD],
      discount=rlds_step[rlds.DISCOUNT],
      # We provide next_observation if an algorithm needs it, however note
      # that it will only contain s_t and s_t+1, so will be one element short
      # of all other attributes (which contain s_t-1, s_t, s_t+1).
      next_observation=tree.map_structure(lambda x: x[1:],
                                          rlds_step[rlds.OBSERVATION]),
      extras={
          N_STEP_RETURN: rlds_step[N_STEP_RETURN],
      })

def step(self):
  sample = next(self._iterator)
  transitions = types.Transition(*sample.data)

  self._state, metrics = self._update_step(self._state, transitions)

  # Compute elapsed time.
  timestamp = time.time()
  elapsed_time = timestamp - self._timestamp if self._timestamp else 0
  self._timestamp = timestamp

  # Increment counts and record the current time
  counts = self._counter.increment(steps=1, walltime=elapsed_time)

  # Attempts to write the logs.
  self._logger.write({**metrics, **counts})

def test_make_dataset_transition_adder(self):
  environment = fakes.ContinuousEnvironment()
  environment_spec = specs.make_environment_spec(environment)
  dataset = reverb_dataset.make_dataset(
      server_address=self.server_address,
      environment_spec=environment_spec,
      transition_adder=True)

  environment_spec = types.Transition(
      observation=environment_spec.observations,
      action=environment_spec.actions,
      reward=environment_spec.rewards,
      discount=environment_spec.discounts,
      next_observation=environment_spec.observations,
      extras=())

  self.assertTrue(_check_specs(environment_spec, dataset.element_spec.data))

def step(self):
  # Get data from replay (dropping extras if any). Note there is no
  # extra data here because we do not insert any into Reverb.
  sample = next(self._iterator)
  transitions = types.Transition(*sample.data)

  self._state, metrics = self._sgd_step(self._state, transitions)

  # Compute elapsed time.
  timestamp = time.time()
  elapsed_time = timestamp - self._timestamp if self._timestamp else 0
  self._timestamp = timestamp

  # Increment counts and record the current time
  counts = self._counter.increment(steps=1, walltime=elapsed_time)

  # Attempts to write the logs.
  self._logger.write({**metrics, **counts})

def test_continuous_actions(self, loss_name):
  with chex.fake_pmap_and_jit():
    num_sgd_steps_per_step = 1
    num_steps = 5

    # Create a fake environment to test with.
    environment = fakes.ContinuousEnvironment(
        episode_length=10, bounded=True, action_dim=6)

    spec = specs.make_environment_spec(environment)
    dataset_demonstration = fakes.transition_dataset(environment)
    dataset_demonstration = dataset_demonstration.map(
        lambda sample: types.Transition(*sample.data))
    dataset_demonstration = dataset_demonstration.batch(8).as_numpy_iterator()

    # Construct the agent.
    network = make_networks(spec)

    if loss_name == 'logp':
      loss_fn = bc.logp(
          logp_fn=lambda dist_params, actions: dist_params.log_prob(actions))
    elif loss_name == 'mse':
      loss_fn = bc.mse(
          sample_fn=lambda dist_params, key: dist_params.sample(seed=key))
    elif loss_name == 'peerbc':
      base_loss_fn = bc.logp(
          logp_fn=lambda dist_params, actions: dist_params.log_prob(actions))
      loss_fn = bc.peerbc(base_loss_fn, zeta=0.1)
    else:
      raise ValueError

    learner = bc.BCLearner(
        network=network,
        random_key=jax.random.PRNGKey(0),
        loss_fn=loss_fn,
        optimizer=optax.adam(0.01),
        demonstrations=dataset_demonstration,
        num_sgd_steps_per_step=num_sgd_steps_per_step)

    # Train the agent
    for _ in range(num_steps):
      learner.step()

def step(self):
  with jax.profiler.StepTraceAnnotation('sampling batch'):
    sample = next(self._iterator)
    transitions = types.Transition(*sample.data)

  with jax.profiler.StepTraceAnnotation('train step'):
    self._state, metrics = self._update_step(self._state, transitions)

  # Compute elapsed time.
  timestamp = time.time()
  elapsed_time = timestamp - self._timestamp if self._timestamp else 0
  self._timestamp = timestamp

  # Increment counts and record the current time
  counts = self._counter.increment(
      steps=self._num_sgd_steps_per_step, walltime=elapsed_time)

  # Attempts to write the logs.
  self._logger.write({**metrics, **counts})

def transition_iterator_from_spec(
    spec: specs.EnvironmentSpec
) -> Callable[[int], Iterator[types.Transition]]:
  """Constructs a fake iterator of transitions.

  Args:
    spec: Constructed fake transitions match the provided specification.

  Returns:
    A callable that given a batch_size returns an iterator of transitions.
  """
  observation = _generate_from_spec(spec.observations)
  action = _generate_from_spec(spec.actions)
  reward = _generate_from_spec(spec.rewards)
  discount = _generate_from_spec(spec.discounts)
  data = types.Transition(observation, action, reward, discount, observation)

  dataset = tf.data.Dataset.from_tensors(data).repeat()

  return lambda batch_size: dataset.batch(batch_size).as_numpy_iterator()

def flatten_fn(sample):
  seq_len = tf.shape(sample.data.observation)[0]
  arange = tf.range(seq_len)
  is_future_mask = tf.cast(arange[:, None] < arange[None], tf.float32)
  discount = self._config.discount**tf.cast(arange[None] - arange[:, None], tf.float32)  # pylint: disable=line-too-long
  probs = is_future_mask * discount
  # The indexing changes the shape from [seq_len, 1] to [seq_len].
  goal_index = tf.random.categorical(
      logits=tf.math.log(probs), num_samples=1)[:, 0]
  state = sample.data.observation[:-1, :self._config.obs_dim]
  next_state = sample.data.observation[1:, :self._config.obs_dim]

  # Create the goal observations in three steps.
  # 1. Take all future states (not future goals).
  # 2. Apply obs_to_goal.
  # 3. Sample one of the future states. Note that we don't look for a goal
  #    for the final state, because there are no future states.
  goal = sample.data.observation[:, :self._config.obs_dim]
  goal = contrastive_utils.obs_to_goal_2d(
      goal,
      start_index=self._config.start_index,
      end_index=self._config.end_index)
  goal = tf.gather(goal, goal_index[:-1])
  new_obs = tf.concat([state, goal], axis=1)
  new_next_obs = tf.concat([next_state, goal], axis=1)
  transition = types.Transition(
      observation=new_obs,
      action=sample.data.action[:-1],
      reward=sample.data.reward[:-1],
      discount=sample.data.discount[:-1],
      next_observation=new_next_obs,
      extras={
          'next_action': sample.data.action[1:],
      })
  # Shift for the transpose_shuffle.
  shift = tf.random.uniform((), 0, seq_len, tf.int32)
  transition = tree.map_structure(
      lambda t: tf.roll(t, shift, axis=0), transition)
  return transition

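# Illustrative numpy sketch (not the TF code above) of the future-goal index
# distribution it builds: for timestep i, a future index j > i is drawn with
# probability proportional to discount**(j - i); the final timestep has no
# future state and is dropped. Sequence length and discount are made up.
import numpy as np

example_seq_len, example_discount = 5, 0.9
idx = np.arange(example_seq_len)
future_mask = (idx[:, None] < idx[None]).astype(np.float64)
goal_probs = future_mask * example_discount ** (idx[None] - idx[:, None])
goal_probs = goal_probs[:-1]                         # drop the final timestep.
goal_probs /= goal_probs.sum(axis=1, keepdims=True)  # each row is a distribution.
example_goal_index = np.array(
    [np.random.choice(example_seq_len, p=p) for p in goal_probs])
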
def transition_iterator(
    environment: dm_env.Environment
) -> Callable[[int], Iterator[types.Transition]]:
  """Fake dataset of Reverb N-step transition samples.

  Args:
    environment: Used to create a fake transition by looking at the
      observation, action, discount and reward specs.

  Returns:
    A callable that given a batch_size returns an iterator with
    demonstrations.
  """
  observation = environment.observation_spec().generate_value()
  action = environment.action_spec().generate_value()
  reward = environment.reward_spec().generate_value()
  discount = environment.discount_spec().generate_value()
  data = types.Transition(observation, action, reward, discount, observation)

  dataset = tf.data.Dataset.from_tensors(data).repeat()

  return lambda batch_size: dataset.batch(batch_size).as_numpy_iterator()

def step(self):
  # Get data from replay (dropping extras if any). Note there is no
  # extra data here because we do not insert any into Reverb.
  # TODO(raveman): Add a support for offline training, where we do not consume
  # data from the replay buffer.
  sample = next(self._iterator_replay)
  replay_transitions = types.Transition(*sample.data)

  # Get a batch of Transitions from the demonstration.
  demonstration_transitions = next(self._iterator_demonstrations)

  self._state, metrics = self._sgd_step(
      self._state, (replay_transitions, demonstration_transitions))

  # Compute elapsed time.
  timestamp = time.time()
  elapsed_time = timestamp - self._timestamp if self._timestamp else 0
  self._timestamp = timestamp

  # Increment counts and record the current time
  counts = self._counter.increment(steps=1, walltime=elapsed_time)

  # Attempts to write the logs.
  self._logger.write({**metrics, **counts})

def transition_dataset_from_spec(
    spec: specs.EnvironmentSpec) -> tf.data.Dataset:
  """Constructs fake dataset of Reverb N-step transition samples.

  Args:
    spec: Constructed fake transitions match the provided specification.

  Returns:
    tf.data.Dataset that produces the same fake N-step transition ReverbSample
    object indefinitely.
  """
  observation = _generate_from_spec(spec.observations)
  action = _generate_from_spec(spec.actions)
  reward = _generate_from_spec(spec.rewards)
  discount = _generate_from_spec(spec.discounts)
  data = types.Transition(observation, action, reward, discount, observation)

  info = tree.map_structure(
      lambda tf_dtype: tf.ones([], tf_dtype.as_numpy_dtype),
      reverb.SampleInfo.tf_dtypes())
  sample = reverb.ReplaySample(info=info, data=data)

  return tf.data.Dataset.from_tensors(sample).repeat()

def transition_dataset(environment: dm_env.Environment) -> tf.data.Dataset:
  """Fake dataset of Reverb N-step transition samples.

  Args:
    environment: Used to create a fake transition by looking at the
      observation, action, discount and reward specs.

  Returns:
    tf.data.Dataset that produces the same fake N-step transition ReverbSample
    object indefinitely.
  """
  observation = environment.observation_spec().generate_value()
  action = environment.action_spec().generate_value()
  reward = environment.reward_spec().generate_value()
  discount = environment.discount_spec().generate_value()
  data = types.Transition(observation, action, reward, discount, observation)

  info = tree.map_structure(
      lambda tf_dtype: tf.ones([], tf_dtype.as_numpy_dtype),
      reverb.SampleInfo.tf_dtypes())
  sample = reverb.ReplaySample(info=info, data=data)

  return tf.data.Dataset.from_tensors(sample).repeat()

# expected transitions that should result from this trajectory. The expected
# transitions are of the form: (observation, action, reward, discount,
# next_observation, extras).
TEST_CASES = [
    dict(
        testcase_name='OneStepFinalReward',
        n_step=1,
        additional_discount=1.0,
        first=dm_env.restart(1),
        steps=(
            (0, dm_env.transition(reward=0.0, observation=2)),
            (0, dm_env.transition(reward=0.0, observation=3)),
            (0, dm_env.termination(reward=1.0, observation=4)),
        ),
        expected_transitions=(
            types.Transition(1, 0, 0.0, 1.0, 2),
            types.Transition(2, 0, 0.0, 1.0, 3),
            types.Transition(3, 0, 1.0, 0.0, 4),
        )),
    dict(
        testcase_name='OneStepDict',
        n_step=1,
        additional_discount=1.0,
        first=dm_env.restart({'foo': 1}),
        steps=(
            (0, dm_env.transition(reward=0.0, observation={'foo': 2})),
            (0, dm_env.transition(reward=0.0, observation={'foo': 3})),
            (0, dm_env.termination(reward=1.0, observation={'foo': 4})),
        ),
        expected_transitions=(
            types.Transition({'foo': 1}, 0, 0.0, 1.0, {'foo': 2}),
def _spec_to_shapes_and_dtypes(transition_adder: bool,
                               environment_spec: specs.EnvironmentSpec,
                               extra_spec: Optional[types.NestedSpec],
                               sequence_length: Optional[int],
                               convert_zero_size_to_none: bool,
                               using_deprecated_adder: bool):
  """Creates the shapes and dtypes needed to describe the Reverb dataset.

  This takes an `environment_spec`, `extra_spec`, and additional information
  and returns a tuple (shapes, dtypes) that describes the data contained in
  Reverb.

  Args:
    transition_adder: A boolean, describing if a `TransitionAdder` was used to
      add data.
    environment_spec: A `specs.EnvironmentSpec`, describing the shapes and
      dtypes of the data produced by the environment (and the action).
    extra_spec: A nested structure of objects with a `.shape` and `.dtype`
      property. This describes any additional data the Actor adds into Reverb.
    sequence_length: An optional integer for how long the added sequences are,
      only used with `SequenceAdder`.
    convert_zero_size_to_none: If True, then all shape dimensions that are 0
      are converted to None. A None dimension is only set at runtime.
    using_deprecated_adder: True if the adder used to generate the data is
      from acme/adders/reverb/deprecated.

  Returns:
    A tuple (shapes, dtypes) that describes the data that has been added into
    Reverb.
  """
  # The *transition* adder is special in that it also adds an arrival state.
  if transition_adder:
    adder_spec = types.Transition(
        observation=environment_spec.observations,
        action=environment_spec.actions,
        reward=environment_spec.rewards,
        discount=environment_spec.discounts,
        next_observation=environment_spec.observations,
        extras=() if not extra_spec else extra_spec)
  elif using_deprecated_adder and deprecated_base is not None:
    adder_spec = deprecated_base.Step(
        observation=environment_spec.observations,
        action=environment_spec.actions,
        reward=environment_spec.rewards,
        discount=environment_spec.discounts,
        extras=() if not extra_spec else extra_spec)
  else:
    adder_spec = adders.Step(
        observation=environment_spec.observations,
        action=environment_spec.actions,
        reward=environment_spec.rewards,
        discount=environment_spec.discounts,
        start_of_episode=specs.Array(shape=(), dtype=bool),
        extras=() if not extra_spec else extra_spec)

  # Extract the shapes and dtypes from these specs.
  get_dtype = lambda x: tf.as_dtype(x.dtype)
  get_shape = lambda x: tf.TensorShape(x.shape)
  if sequence_length:
    get_shape = lambda x: tf.TensorShape([sequence_length, *x.shape])

  if convert_zero_size_to_none:
    # TODO(b/143692455): Consider making this default behaviour.
    get_shape = lambda x: tf.TensorShape([s if s else None for s in x.shape])

  shapes = tree.map_structure(get_shape, adder_spec)
  dtypes = tree.map_structure(get_dtype, adder_spec)
  return shapes, dtypes

def _write(self):
  # NOTE: we do not check that the buffer is of length N here. This means
  # that at the beginning of an episode we will add the initial N-1
  # transitions (of size 1, 2, ...) and at the end of an episode (when
  # called from write_last) we will write the final transitions of size (N,
  # N-1, ...). See the Note in the docstring.

  # Form the n-step transition given the steps.
  observation = self._buffer[0].observation
  action = self._buffer[0].action
  extras = self._buffer[0].extras
  next_observation = self._next_observation

  # Give the same tree structure to the n-step return accumulator,
  # n-step discount accumulator, and self.discount, so that they can be
  # iterated in parallel using tree.map_structure.
  (n_step_return, total_discount,
   self_discount) = tree_utils.broadcast_structures(
       self._buffer[0].reward, self._buffer[0].discount, self._discount)

  # Copy total_discount, so that accumulating into it doesn't affect
  # _buffer[0].discount.
  total_discount = tree.map_structure(np.copy, total_discount)

  # Broadcast n_step_return to have the broadcasted shape of
  # reward * discount. Also copy, to avoid accumulating into
  # _buffer[0].reward.
  n_step_return = tree.map_structure(
      lambda r, d: np.copy(np.broadcast_to(r, np.broadcast(r, d).shape)),
      n_step_return, total_discount)

  # NOTE: total discount will have one less discount than it does
  # step.discounts. This is so that when the learner/update uses an
  # additional discount we don't apply it twice. Inside the following loop
  # we will apply this right before summing up the n_step_return.
  for step in itertools.islice(self._buffer, 1, None):
    (step_discount, step_reward,
     total_discount) = tree_utils.broadcast_structures(
         step.discount, step.reward, total_discount)

    # Equivalent to: `total_discount *= self._discount`.
    tree.map_structure(operator.imul, total_discount, self_discount)

    # Equivalent to: `n_step_return += step.reward * total_discount`.
    tree.map_structure(lambda nsr, sr, td: operator.iadd(nsr, sr * td),
                       n_step_return, step_reward, total_discount)

    # Equivalent to: `total_discount *= step.discount`.
    tree.map_structure(operator.imul, total_discount, step_discount)

  transition = types.Transition(
      observation=observation,
      action=action,
      reward=n_step_return,
      discount=total_discount,
      next_observation=next_observation,
      extras=extras)

  # Create a list of steps.
  if self._final_step_placeholder is None:
    # utils.final_step_like is expensive (around 0.085ms) to run every time
    # so we cache its output.
    self._final_step_placeholder = utils.final_step_like(
        self._buffer[0], next_observation)
  final_step: base.Step = self._final_step_placeholder._replace(
      observation=next_observation)
  steps = list(self._buffer) + [final_step]

  # Calculate the priority for this transition.
  table_priorities = utils.calculate_priorities(self._priority_fns, steps)

  # Insert the transition into replay along with its priority.
  self._writer.append(transition)
  for table, priority in table_priorities.items():
    self._writer.create_item(table=table, num_timesteps=1, priority=priority)
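
# Worked scalar example (illustration only, not library code) of the
# accumulation loop above for a two-step buffer with rewards r, environment
# discounts d, and additional discount g. The numbers are arbitrary.
g = 0.99
r = [1.0, 2.0]
d = [0.9, 0.8]

n_step_return, total_discount = r[0], d[0]
total_discount *= g                       # g * d_0
n_step_return += r[1] * total_discount    # r_0 + g * d_0 * r_1
total_discount *= d[1]                    # g * d_0 * d_1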