import os
import time
from typing import Callable, Text, Tuple

from absl import logging
import tensorflow.compat.v2 as tf
from tf_agents.agents import tf_agent
from tf_agents.policies import py_tf_eager_policy
from tf_agents.typing import types
from tf_agents.utils import lazy_loader

# Lazy loading since not all users have the reverb package installed.
reverb = lazy_loader.LazyLoader('reverb', globals(), 'reverb')

# By default the implementation of wait functions blocks with relatively large
# number of frequent retries assuming that the event usually happens soon, but
# occasionally takes longer.
_WAIT_DEFAULT_SLEEP_TIME_SECS = 1
_WAIT_DEFAULT_NUM_RETRIES = 60 * 60 * 24  # 1 day


def create_train_step() -> tf.Variable:
    """Build a new scalar counter variable that starts at zero.

    Judging by the name, the variable is meant to serve as a train step
    counter; nothing in this function increments it.

    Returns:
      A non-trainable `tf.Variable` with scalar shape `()`, dtype `tf.int64`,
      initial value 0 and `ONLY_FIRST_REPLICA` aggregation.
    """
    initial_step = 0
    first_replica_only = tf.VariableAggregation.ONLY_FIRST_REPLICA
    return tf.Variable(
        initial_step,
        shape=(),
        dtype=tf.int64,
        trainable=False,
        aggregation=first_replica_only,
    )
"""Utilities for working with a reverb replay buffer.""" from __future__ import absolute_import from __future__ import division # Using Type Annotations. from __future__ import print_function from typing import Text, Union from absl import logging from tf_agents.typing import types from tf_agents.utils import lazy_loader # Lazy loading since not all users have the reverb package installed. reverb = lazy_loader.LazyLoader("reverb", globals(), "reverb") class ReverbAddEpisodeObserver(object): """Observer for writing episodes to the Reverb replay buffer.""" def __init__(self, py_client: types.ReverbClient, table_name: Text, max_sequence_length: int, priority: Union[float, int] = 1, bypass_partial_episodes: bool = False): """Creates an instance of the ReverbAddEpisodeObserver. **Note**: This observer is designed to work with py_drivers only, and does not support batches.