def initialize_dataset_with_logits(server_port, table_name, observations_shape, batch_size, n_points, is_episode=False):
    """Build a batched tf.data pipeline over a Reverb table whose samples
    include policy logits alongside (action, obs, reward, done) fields.

    Args:
        server_port: port of a locally running Reverb server.
        table_name: Reverb table to sample from.
        observations_shape: pair of shapes ``(maps_shape, scalars_shape)``
            -- assumed from the indexing below; TODO confirm against callers.
        batch_size: outer batch dimension of the returned dataset.
        n_points: number of time steps per trajectory.
        is_episode: if True, each sampled item is a whole trajectory of
            ``n_points`` steps (and carries no total-reward/progress fields).

    Returns:
        A ``tf.data.Dataset`` batched as ``.batch(n_points).batch(batch_size)``.
    """
    maps_shape = tf.TensorShape(observations_shape[0])
    scalars_shape = tf.TensorShape(observations_shape[1])
    actions_tf_shape = tf.TensorShape([])
    # 4 logits -- presumably a 4-action discrete policy; verify against the
    # actor that writes into this table.
    logits_tf_shape = tf.TensorShape([4, ])
    rewards_tf_shape = tf.TensorShape([])
    dones_tf_shape = tf.TensorShape([])
    total_rewards_tf_shape = tf.TensorShape([])
    progress_tf_shape = tf.TensorShape([])
    # episode_dones_tf_shape = tf.TensorShape([])
    # episode_steps_tf_shape = tf.TensorShape([])
    if is_episode:
        # ``[n_points] + shape`` relies on TensorShape.__radd__, which
        # concatenates the list as leading dimensions (TF >= 2.x).
        observations_tf_shape = ([n_points] + maps_shape, [n_points] + scalars_shape)
        obs_dtypes = tf.nest.map_structure(lambda x: tf.uint8, observations_tf_shape)
        dataset = reverb.ReplayDataset(
            server_address=f'localhost:{server_port}',
            table=table_name,
            max_in_flight_samples_per_worker=2 * batch_size,
            dtypes=(tf.int32, tf.float32, obs_dtypes, tf.float32, tf.float32),
            shapes=([n_points] + actions_tf_shape,
                    [n_points] + logits_tf_shape,
                    observations_tf_shape,
                    [n_points] + rewards_tf_shape,
                    [n_points] + dones_tf_shape))
    else:
        observations_tf_shape = (maps_shape, scalars_shape)
        obs_dtypes = tf.nest.map_structure(lambda x: tf.uint8, observations_tf_shape)
        # Per-step samples additionally carry total_rewards and progress.
        dataset = reverb.ReplayDataset(
            server_address=f'localhost:{server_port}',
            table=table_name,
            max_in_flight_samples_per_worker=2 * batch_size,
            dtypes=(tf.int32, tf.float32, obs_dtypes, tf.float32, tf.float32, tf.float32, tf.float32),
            shapes=(actions_tf_shape, logits_tf_shape, observations_tf_shape,
                    rewards_tf_shape, dones_tf_shape,
                    total_rewards_tf_shape, progress_tf_shape))
    # NOTE(review): in the is_episode branch the sampled items already have a
    # leading n_points dimension, so batching by n_points again yields shape
    # [batch_size, n_points, n_points, ...] -- looks suspicious; confirm the
    # intended layout with the consumer of this dataset.
    dataset = dataset.batch(n_points)
    dataset = dataset.batch(batch_size)
    return dataset
def initialize_dataset(server_port, table_name, observations_shape, batch_size, n_steps):
    """Create a batched ReplayDataset of (action, obs, reward, done) samples.

    Note: ``batch_size`` in fact equals the minimum size of the buffer.
    """
    # A multi-entry shape spec means a (maps, scalars) observation pair
    # (the halite case); otherwise treat it as a plain nest of shapes.
    if len(observations_shape) > 1:
        obs_shapes = (tf.TensorShape(observations_shape[0]),
                      tf.TensorShape(observations_shape[1]))
    else:
        obs_shapes = tf.nest.map_structure(tf.TensorShape, observations_shape)

    scalar_shape = tf.TensorShape([])
    obs_dtypes = tf.nest.map_structure(lambda _: tf.float32, obs_shapes)

    dataset = reverb.ReplayDataset(
        server_address=f'localhost:{server_port}',
        table=table_name,
        max_in_flight_samples_per_worker=10,
        dtypes=(tf.int32, obs_dtypes, tf.float32, tf.float32),
        shapes=(scalar_shape, obs_shapes, scalar_shape, scalar_shape))
    # Inner batch groups n_steps consecutive samples; outer batch forms minibatches.
    return dataset.batch(n_steps).batch(batch_size)
def _make_dataset(unused_idx: tf.Tensor) -> tf.data.Dataset:
    """Construct one per-worker ReplayDataset and finish it with batching."""
    if environment_spec is None:
        # No explicit spec: let Reverb infer dtypes/shapes from the table
        # signature registered on the server.
        dataset = reverb.ReplayDataset.from_table_signature(
            server_address=server_address,
            table=table,
            max_in_flight_samples_per_worker=max_in_flight_samples_per_worker,
            sequence_length=sequence_length,
            emit_timesteps=sequence_length is None)
    else:
        # Derive explicit dtypes/shapes from the environment spec.
        shapes, dtypes = _spec_to_shapes_and_dtypes(
            transition_adder,
            environment_spec,
            extra_spec=extra_spec,
            sequence_length=sequence_length,
            convert_zero_size_to_none=convert_zero_size_to_none,
            using_deprecated_adder=using_deprecated_adder)
        dataset = reverb.ReplayDataset(
            server_address=server_address,
            table=table,
            dtypes=dtypes,
            shapes=shapes,
            max_in_flight_samples_per_worker=max_in_flight_samples_per_worker,
            sequence_length=sequence_length,
            emit_timesteps=sequence_length is None)

    # Finish the pipeline: batch (dropping ragged remainders) when requested.
    if batch_size:
        dataset = dataset.batch(batch_size, drop_remainder=True)
    return dataset
def _make_dataset(_):
    """Build a batched ReplayDataset of 5-tuples from the local Reverb server.

    Fields are presumably (obs, action, reward, next_obs, done) with 4x84x84
    stacked frames -- verify against the writer that fills TABLE_NAME.
    """
    frame_shape = tf.TensorShape((4, 84, 84))
    scalar_shape = tf.TensorShape([])
    batch = config["common"]["batch_size"]

    dataset = reverb.ReplayDataset(
        f"localhost:{PORT}",
        TABLE_NAME,
        max_in_flight_samples_per_worker=batch,
        dtypes=(tf.float32, tf.int64, tf.float32, tf.float32, tf.float32),
        shapes=(frame_shape, scalar_shape, scalar_shape, frame_shape, scalar_shape),
    )
    return dataset.batch(batch, drop_remainder=True)
def __init__(self, name="ReverbUniformReplayBuffer", reverb_server=None):
    """Set up a uniform-sampling Reverb replay buffer and its sample iterator.

    Args:
        name: module name scope.
        reverb_server: optional pre-built ``reverb.Server`` to reuse; when
            falsy, a fresh server with a single uniform/FIFO table is started.
    """
    super().__init__(name=name)
    self.device = Params.DEVICE

    with tf.device(self.device), self.name_scope:
        self.buffer_size = tf.cast(Params.BUFFER_SIZE, tf.int64)
        self.batch_size = tf.cast(Params.MINIBATCH_SIZE, tf.int64)
        self.batch_size_float = tf.cast(Params.MINIBATCH_SIZE, tf.float64)
        self.sequence_length = tf.cast(Params.N_STEP_RETURNS, tf.int64)

        # Reuse the injected server when given, otherwise start our own with
        # one uniform-sampled, FIFO-evicted table that blocks sampling until
        # at least a minibatch worth of items has been inserted.
        if reverb_server:
            self.reverb_server = reverb_server
        else:
            replay_table = reverb.Table(
                name=Params.BUFFER_TYPE,
                sampler=reverb.selectors.Uniform(),
                remover=reverb.selectors.Fifo(),
                max_size=self.buffer_size,
                rate_limiter=reverb.rate_limiters.MinSize(self.batch_size),
            )
            self.reverb_server = reverb.Server(tables=[replay_table], )

        pipeline = reverb.ReplayDataset(
            server_address=f'localhost:{self.reverb_server.port}',
            table=Params.BUFFER_TYPE,
            max_in_flight_samples_per_worker=2 * self.batch_size,
            dtypes=Params.BUFFER_DATA_SPEC_DTYPES,
            shapes=Params.BUFFER_DATA_SPEC_SHAPES,
        )
        pipeline = pipeline.map(
            map_func=reduce_trajectory,
            num_parallel_calls=tf.data.experimental.AUTOTUNE,
            deterministic=True,
        )
        pipeline = pipeline.batch(self.batch_size).prefetch(5)
        self.iterator = iter(pipeline)
def test_reverb(data):
    """Benchmark sampling throughput of a prioritized Reverb table.

    Inserts CAPACITY items built from ``data`` (assumed to be a sequence of
    columns, each indexable up to CAPACITY: two 1x84x84 frames and two
    scalars -- TODO confirm), then times iterating TEST_CNT batches.
    """
    import reverb
    import tensorflow as tf

    print("TEST REVERB")
    print("initializing...")
    reverb_server = reverb.Server(
        tables=[
            reverb.Table(
                name="req",
                sampler=reverb.selectors.Prioritized(0.6),
                remover=reverb.selectors.Fifo(),
                max_size=CAPACITY,
                rate_limiter=reverb.rate_limiters.MinSize(100),
            )
        ],
        port=15867,
    )
    try:
        client = reverb_server.in_process_client()
        for i in range(CAPACITY):
            # Random priority per item so Prioritized(0.6) sampling is exercised.
            client.insert([col[i] for col in data], {"req": np.random.rand()})
        dataset = reverb.ReplayDataset(
            server_address="localhost:15867",
            table="req",
            dtypes=(tf.float64, tf.float64, tf.float64, tf.float64),
            shapes=(
                tf.TensorShape([1, 84, 84]),
                tf.TensorShape([1, 84, 84]),
                tf.TensorShape([]),
                tf.TensorShape([]),
            ),
            max_in_flight_samples_per_worker=10,
        )
        dataset = dataset.batch(BATCH_SIZE)
        print("ready")
        t0 = time.perf_counter()
        for sample in dataset.take(TEST_CNT):
            pass
        t1 = time.perf_counter()
        print(TEST_CNT, t1 - t0)
    finally:
        # BUGFIX: stop the server so port 15867 is released. Previously the
        # server lived until process exit, so a second call to this function
        # would fail to bind the fixed port.
        reverb_server.stop()