# Imports assumed from RLlib's execution ops modules (Ray 1.x API).
from ray.rllib.execution.rollout_ops import ParallelRollouts, ConcatBatches
from ray.rllib.execution.replay_ops import MixInReplay


def gather_experiences_directly(workers, config):
    # Collect experiences asynchronously from all rollout workers, keeping up
    # to `max_requests_in_flight_per_sampler_worker` requests in flight each.
    rollouts = ParallelRollouts(
        workers,
        mode="async",
        num_async=config["max_requests_in_flight_per_sampler_worker"],
    )

    # Augment with replay and concat to desired train batch size.
    train_batches = (
        rollouts.for_each(lambda batch: batch.decompress_if_needed())
        .for_each(
            MixInReplay(
                num_slots=config["replay_buffer_num_slots"],
                replay_proportion=config["replay_proportion"],
            )
        )
        .flatten()
        .combine(
            ConcatBatches(
                min_batch_size=config["train_batch_size"],
                count_steps_by=config["multiagent"]["count_steps_by"],
            )
        )
    )

    return train_batches
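For context, here is a minimal sketch of how the returned operator chain could be wired into a trainer's execution plan. TrainOneStep and StandardMetricsReporting are real RLlib execution ops, but this particular wiring is illustrative (IMPALA itself hands the batches to a separate learner thread rather than training inline), and the function name is hypothetical:

from ray.rllib.execution.train_ops import TrainOneStep
from ray.rllib.execution.metric_ops import StandardMetricsReporting


def impala_like_execution_plan(workers, config):
    # Reuse the gathering chain defined above.
    train_batches = gather_experiences_directly(workers, config)

    # Apply one SGD update on the local worker per concatenated batch.
    train_op = train_batches.for_each(TrainOneStep(workers))

    # Wrap with standard metrics reporting; this returns a
    # LocalIterator[metrics_dict] with one metrics dict per train step.
    return StandardMetricsReporting(train_op, workers, config)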
# This generator runs inside an aggregation actor in the older RLlib API, so
# `self`, `rollout_group`, and `config` are supplied by the enclosing class.
def generator():
    it = rollout_group.gather_async(
        num_async=config["max_sample_requests_in_flight_per_worker"])

    # Update the rollout worker with our latest policy weights.
    def update_worker(item):
        worker, batch = item
        if self.weights:
            worker.set_weights.remote(self.weights, self.global_vars)
        return batch

    # Augment with replay and concat to desired train batch size.
    it = it.zip_with_source_actor() \
        .for_each(update_worker) \
        .for_each(lambda batch: batch.decompress_if_needed()) \
        .for_each(MixInReplay(
            num_slots=config["replay_buffer_num_slots"],
            replay_proportion=config["replay_proportion"])) \
        .flatten() \
        .combine(ConcatBatches(
            min_batch_size=config["train_batch_size"]))

    for train_batch in it:
        yield train_batch
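A hedged sketch of the consuming side: the trainer pulls batches from the generator, learns on them locally, and refreshes self.weights so that update_worker() above broadcasts fresh weights on the next pass. learn_on_batch() and get_weights() are real RolloutWorker methods, but the loop itself and the `local_worker` name are assumptions, not the exact RLlib source:

# Hypothetical consumer loop; `local_worker` and `self` are assumed to come
# from the enclosing aggregator class.
for train_batch in generator():
    # Train on the replay-mixed, concatenated batch.
    fetches = local_worker.learn_on_batch(train_batch)
    # Refresh the published weights so update_worker() broadcasts them
    # to the remote rollout workers on the next iteration.
    self.weights = local_worker.get_weights()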
# Imports assumed from RLlib's execution ops modules (Ray 1.x API).
from ray.rllib.execution.rollout_ops import ParallelRollouts, ConcatBatches
from ray.rllib.execution.replay_ops import MixInReplay
from ray.rllib.execution.train_ops import TrainOneStep, UpdateTargetNetwork
from ray.rllib.execution.metric_ops import StandardMetricsReporting


def execution_plan(workers, config):
    # Collect experiences from all workers synchronously, as one batch.
    rollouts = ParallelRollouts(workers, mode="bulk_sync")

    # Mix in replayed experiences, build up to `train_batch_size`, train one
    # step, then periodically sync the target network.
    train_op = rollouts \
        .for_each(MixInReplay(num_slots=config["buffer_size"])) \
        .combine(ConcatBatches(
            min_batch_size=config["train_batch_size"])) \
        .for_each(TrainOneStep(workers)) \
        .for_each(UpdateTargetNetwork(
            workers, config["target_network_update_freq"]))

    # Report standard metrics (episode reward, etc.) for each train step.
    return StandardMetricsReporting(train_op, workers, config)
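To run this plan, it can be passed to build_trainer(), which constructs a Trainer class around it. The sketch below assumes the RLlib 1.x trainer template with DQN's TF policy and default config; the trainer name "MyReplayTrainer" is illustrative:

from ray.rllib.agents.trainer_template import build_trainer
from ray.rllib.agents.dqn.dqn_tf_policy import DQNTFPolicy
from ray.rllib.agents.dqn.dqn import DEFAULT_CONFIG

# Build a Trainer class that drives the execution plan above.
MyReplayTrainer = build_trainer(
    name="MyReplayTrainer",
    default_policy=DQNTFPolicy,
    default_config=DEFAULT_CONFIG,
    execution_plan=execution_plan,
)

# Example usage: tune.run(MyReplayTrainer, config={"env": "CartPole-v0"})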