Esempio n. 1
0
def gather_experiences_directly(workers, config):
    """Build the iterator of train batches sampled directly from workers.

    Args:
        workers: Passed as-is to ``ParallelRollouts``.
        config: Mapping that must provide
            ``"max_requests_in_flight_per_sampler_worker"``,
            ``"replay_buffer_num_slots"``, ``"replay_proportion"``,
            ``"train_batch_size"`` and ``["multiagent"]["count_steps_by"]``.

    Returns:
        The iterator produced by the final ``combine(ConcatBatches(...))``
        step of the pipeline.
    """
    sample_stream = ParallelRollouts(
        workers,
        mode="async",
        num_async=config["max_requests_in_flight_per_sampler_worker"],
    )

    # Each incoming batch is asked to decompress itself before further use.
    decompressed = sample_stream.for_each(
        lambda batch: batch.decompress_if_needed()
    )

    # Augment with replay ...
    replay_mixer = MixInReplay(
        num_slots=config["replay_buffer_num_slots"],
        replay_proportion=config["replay_proportion"],
    )
    flattened = decompressed.for_each(replay_mixer).flatten()

    # ... and concat to the desired train batch size.
    concatenator = ConcatBatches(
        min_batch_size=config["train_batch_size"],
        count_steps_by=config["multiagent"]["count_steps_by"],
    )
    return flattened.combine(concatenator)
Esempio n. 2
0
        def generator():
            """Yield train batches built from asynchronously gathered samples.

            Relies on ``rollout_group``, ``config`` and ``self`` from the
            enclosing scope.
            """

            def update_worker(item):
                # Update the rollout worker with our latest policy weights,
                # then hand the batch on unchanged.
                source_actor, sample_batch = item
                if self.weights:
                    source_actor.set_weights.remote(
                        self.weights, self.global_vars)
                return sample_batch

            gathered = rollout_group.gather_async(
                num_async=config["max_sample_requests_in_flight_per_worker"])

            # Pair every item with the actor it came from so that actor can
            # be updated before the batch is processed further.
            synced = gathered.zip_with_source_actor().for_each(update_worker)
            decompressed = synced.for_each(
                lambda batch: batch.decompress_if_needed())

            # Augment with replay and concat to desired train batch size.
            replay_mixer = MixInReplay(
                num_slots=config["replay_buffer_num_slots"],
                replay_proportion=config["replay_proportion"])
            flattened = decompressed.for_each(replay_mixer).flatten()
            train_batches = flattened.combine(
                ConcatBatches(min_batch_size=config["train_batch_size"]))

            for ready_batch in train_batches:
                yield ready_batch
Esempio n. 3
0
def execution_plan(workers, config):
    """Assemble the training pipeline and wrap it in metrics reporting.

    Args:
        workers: Passed to ``ParallelRollouts``, ``TrainOneStep``,
            ``UpdateTargetNetwork`` and ``StandardMetricsReporting``.
        config: Mapping that must provide ``"buffer_size"``,
            ``"train_batch_size"`` and ``"target_network_update_freq"``;
            it is also forwarded whole to ``StandardMetricsReporting``.

    Returns:
        The result of ``StandardMetricsReporting(train_op, workers, config)``.
    """
    sampled = ParallelRollouts(workers, mode="bulk_sync")

    # Stage 1: pass every sampled item through the replay mixer.
    with_replay = sampled.for_each(MixInReplay(config["buffer_size"]))

    # Stage 2: combine items until the configured minimum batch size is met.
    batched = with_replay.combine(
        ConcatBatches(min_batch_size=config["train_batch_size"])
    )

    # Stage 3: run a training step, then the target-network update op.
    trained = batched.for_each(TrainOneStep(workers))
    train_op = trained.for_each(
        UpdateTargetNetwork(workers, config["target_network_update_freq"])
    )

    return StandardMetricsReporting(train_op, workers, config)