Example #1
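Builds a batched ReplayDataset that samples actions, 4-way action logits, (maps, scalars) uint8 observations, rewards, and done flags (plus total rewards and a progress scalar in the per-step variant). The is_episode flag switches between table items that already hold full n_points-step trajectories and single-timestep items that are grouped into n_points-step windows before the final batch_size batching. (reverb and tensorflow as tf are assumed to be imported at module level.)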
def initialize_dataset_with_logits(server_port,
                                   table_name,
                                   observations_shape,
                                   batch_size,
                                   n_points,
                                   is_episode=False):
    maps_shape = tf.TensorShape(observations_shape[0])
    scalars_shape = tf.TensorShape(observations_shape[1])

    actions_tf_shape = tf.TensorShape([])
    logits_tf_shape = tf.TensorShape([
        4,
    ])
    rewards_tf_shape = tf.TensorShape([])
    dones_tf_shape = tf.TensorShape([])
    total_rewards_tf_shape = tf.TensorShape([])
    progress_tf_shape = tf.TensorShape([])
    # episode_dones_tf_shape = tf.TensorShape([])
    # episode_steps_tf_shape = tf.TensorShape([])

    # each table item already holds a full n_points-step trajectory
    if is_episode:
        observations_tf_shape = ([n_points] + maps_shape,
                                 [n_points] + scalars_shape)
        obs_dtypes = tf.nest.map_structure(lambda x: tf.uint8,
                                           observations_tf_shape)

        dataset = reverb.ReplayDataset(
            server_address=f'localhost:{server_port}',
            table=table_name,
            max_in_flight_samples_per_worker=2 * batch_size,
            dtypes=(tf.int32, tf.float32, obs_dtypes, tf.float32, tf.float32),
            shapes=([n_points] + actions_tf_shape,
                    [n_points] + logits_tf_shape, observations_tf_shape,
                    [n_points] + rewards_tf_shape,
                    [n_points] + dones_tf_shape))
    else:
        # table items are single timesteps; they are grouped into
        # n_points-step windows by dataset.batch(n_points) below
        observations_tf_shape = (maps_shape, scalars_shape)
        obs_dtypes = tf.nest.map_structure(lambda x: tf.uint8,
                                           observations_tf_shape)

        dataset = reverb.ReplayDataset(
            server_address=f'localhost:{server_port}',
            table=table_name,
            max_in_flight_samples_per_worker=2 * batch_size,
            dtypes=(tf.int32, tf.float32, obs_dtypes, tf.float32, tf.float32,
                    tf.float32, tf.float32),
            shapes=(actions_tf_shape, logits_tf_shape, observations_tf_shape,
                    rewards_tf_shape, dones_tf_shape, total_rewards_tf_shape,
                    progress_tf_shape))
        dataset = dataset.batch(n_points)

    dataset = dataset.batch(batch_size)

    return dataset
Example #2
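A simpler helper along the same lines: it accepts either a single observation shape or a Halite-style (maps, scalars) pair, samples (action, observation, reward, done) timesteps, and batches them first into n_steps-long sequences and then into batches of batch_size.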
def initialize_dataset(server_port, table_name, observations_shape, batch_size,
                       n_steps):
    """
    batch_size in fact equals min size of a buffer
    """
    # if the shape spec has more than one entry, assume Halite-style
    # (maps, scalars) observations
    if len(observations_shape) > 1:
        maps_shape = tf.TensorShape(observations_shape[0])
        scalars_shape = tf.TensorShape(observations_shape[1])
        observations_shape = (maps_shape, scalars_shape)
    else:
        observations_shape = tf.nest.map_structure(lambda x: tf.TensorShape(x),
                                                   observations_shape)

    actions_shape = tf.TensorShape([])
    rewards_shape = tf.TensorShape([])
    dones_shape = tf.TensorShape([])

    obs_dtypes = tf.nest.map_structure(lambda x: tf.float32,
                                       observations_shape)

    dataset = reverb.ReplayDataset(server_address=f'localhost:{server_port}',
                                   table=table_name,
                                   max_in_flight_samples_per_worker=10,
                                   dtypes=(tf.int32, obs_dtypes, tf.float32,
                                           tf.float32),
                                   shapes=(actions_shape, observations_shape,
                                           rewards_shape, dones_shape))

    dataset = dataset.batch(n_steps)
    dataset = dataset.batch(batch_size)

    return dataset
Example #3
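A dataset factory with two construction paths: when an environment_spec is available, shapes and dtypes are derived from it via _spec_to_shapes_and_dtypes; otherwise ReplayDataset.from_table_signature reads them from the table's signature. emit_timesteps is enabled only when no sequence_length is requested, and batching uses drop_remainder=True so batch shapes stay static.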
    def _make_dataset(unused_idx: tf.Tensor) -> tf.data.Dataset:
        if environment_spec is not None:
            shapes, dtypes = _spec_to_shapes_and_dtypes(
                transition_adder,
                environment_spec,
                extra_spec=extra_spec,
                sequence_length=sequence_length,
                convert_zero_size_to_none=convert_zero_size_to_none,
                using_deprecated_adder=using_deprecated_adder)
            dataset = reverb.ReplayDataset(
                server_address=server_address,
                table=table,
                dtypes=dtypes,
                shapes=shapes,
                max_in_flight_samples_per_worker=
                max_in_flight_samples_per_worker,
                sequence_length=sequence_length,
                emit_timesteps=sequence_length is None)
        else:
            dataset = reverb.ReplayDataset.from_table_signature(
                server_address=server_address,
                table=table,
                max_in_flight_samples_per_worker=
                max_in_flight_samples_per_worker,
                sequence_length=sequence_length,
                emit_timesteps=sequence_length is None)

        # Finish the pipeline: batch and prefetch.
        if batch_size:
            dataset = dataset.batch(batch_size, drop_remainder=True)

        return dataset
Example #4
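A minimal factory for what looks like Atari-style transitions: stacked (4, 84, 84) frames for the current and next state plus three scalars (likely the action, reward, and done flag), with the batch size taken from a config dict and drop_remainder=True for static batch shapes.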
def _make_dataset(_):
    dataset = reverb.ReplayDataset(
        f"localhost:{PORT}",
        TABLE_NAME,
        max_in_flight_samples_per_worker=config["common"]["batch_size"],
        dtypes=(tf.float32, tf.int64, tf.float32, tf.float32, tf.float32),
        shapes=(
            tf.TensorShape((4, 84, 84)),
            tf.TensorShape([]),
            tf.TensorShape([]),
            tf.TensorShape((4, 84, 84)),
            tf.TensorShape([]),
        ),
    )
    dataset = dataset.batch(config["common"]["batch_size"], drop_remainder=True)
    return dataset
Example #5
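A replay-buffer class that either wraps an existing reverb.Server or starts its own (uniform sampling, FIFO removal, a MinSize rate limiter), then builds the sampling pipeline in the constructor: sample, map reduce_trajectory over each item, batch, prefetch, and keep an iterator for the training loop. The dtypes and shapes come from a Params config object.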
    def __init__(self, name="ReverbUniformReplayBuffer", reverb_server=None):
        super().__init__(name=name)
        self.device = Params.DEVICE

        with tf.device(self.device), self.name_scope:

            self.buffer_size = tf.cast(Params.BUFFER_SIZE, tf.int64)
            self.batch_size = tf.cast(Params.MINIBATCH_SIZE, tf.int64)
            self.batch_size_float = tf.cast(Params.MINIBATCH_SIZE, tf.float64)
            self.sequence_length = tf.cast(Params.N_STEP_RETURNS, tf.int64)

            # Create the Reverb server unless an existing one was provided
            if not reverb_server:
                self.reverb_server = reverb.Server(tables=[
                    reverb.Table(
                        name=Params.BUFFER_TYPE,
                        sampler=reverb.selectors.Uniform(),
                        remover=reverb.selectors.Fifo(),
                        max_size=self.buffer_size,
                        rate_limiter=reverb.rate_limiters.MinSize(
                            self.batch_size),
                    )
                ], )
            else:
                self.reverb_server = reverb_server

            dataset = reverb.ReplayDataset(
                server_address=f'localhost:{self.reverb_server.port}',
                table=Params.BUFFER_TYPE,
                max_in_flight_samples_per_worker=2 * self.batch_size,
                dtypes=Params.BUFFER_DATA_SPEC_DTYPES,
                shapes=Params.BUFFER_DATA_SPEC_SHAPES,
            )

            dataset = dataset.map(
                map_func=reduce_trajectory,
                num_parallel_calls=tf.data.experimental.AUTOTUNE,
                deterministic=True,
            )
            dataset = dataset.batch(self.batch_size)
            dataset = dataset.prefetch(5)
            self.iterator = iter(dataset)
Example #6
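A small throughput check: it starts a server with a prioritized table, fills it with CAPACITY items carrying random priorities, and then times how long it takes to pull TEST_CNT batches of BATCH_SIZE samples through a ReplayDataset.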
def test_reverb(data):
    import reverb
    import tensorflow as tf
    import numpy as np  # used below for the random priorities
    import time  # used below to time the sampling loop

    print("TEST REVERB")
    print("initializing...")
    reverb_server = reverb.Server(
        tables=[
            reverb.Table(
                name="req",
                sampler=reverb.selectors.Prioritized(0.6),
                remover=reverb.selectors.Fifo(),
                max_size=CAPACITY,
                rate_limiter=reverb.rate_limiters.MinSize(100),
            )
        ],
        port=15867,
    )
    client = reverb_server.in_process_client()
    for i in range(CAPACITY):
        client.insert([col[i] for col in data], {"req": np.random.rand()})
    dataset = reverb.ReplayDataset(
        server_address="localhost:15867",
        table="req",
        dtypes=(tf.float64, tf.float64, tf.float64, tf.float64),
        shapes=(
            tf.TensorShape([1, 84, 84]),
            tf.TensorShape([1, 84, 84]),
            tf.TensorShape([]),
            tf.TensorShape([]),
        ),
        max_in_flight_samples_per_worker=10,
    )
    dataset = dataset.batch(BATCH_SIZE)
    print("ready")
    t0 = time.perf_counter()
    for sample in dataset.take(TEST_CNT):
        pass
    t1 = time.perf_counter()
    print(TEST_CNT, t1 - t0)