def execution_plan(workers, config, **kwargs):
    """Build the AlphaZero training dataflow.

    With ``config["simple_optimizer"]`` truthy, rollouts are batched and
    trained on directly. Otherwise rollouts are written into a
    ``SimpleReplayBuffer``. A second branch then replays from it (gated on
    ``config["learning_starts"]``), batches, and trains. The two branches
    are interleaved round-robin.

    Args:
        workers: Worker set that produces rollouts and is trained.
        config: Trainer config dict. Keys read: "simple_optimizer",
            "train_batch_size", "multiagent" -> "count_steps_by",
            "num_sgd_iter", and on the replay path also "buffer_size"
            and "learning_starts".
        **kwargs: Not supported; must be empty.

    Returns:
        Whatever ``StandardMetricsReporting`` returns for the train op.
    """
    assert len(kwargs) == 0, (
        "Alpha zero execution_plan does NOT take any additional parameters")

    rollouts = ParallelRollouts(workers, mode="bulk_sync")

    def _batch_and_train(source):
        # Shared tail of both paths: concatenate incoming batches
        # (min_batch_size=train_batch_size), then apply one TrainOneStep
        # per concatenated batch.
        batched = source.combine(
            ConcatBatches(
                min_batch_size=config["train_batch_size"],
                count_steps_by=config["multiagent"]["count_steps_by"],
            ))
        return batched.for_each(
            TrainOneStep(workers, num_sgd_iter=config["num_sgd_iter"]))

    if config["simple_optimizer"]:
        train_op = _batch_and_train(rollouts)
    else:
        replay_buffer = SimpleReplayBuffer(config["buffer_size"])
        # Writer branch: store every rollout batch into the buffer.
        store_op = rollouts.for_each(
            StoreToReplayBuffer(local_buffer=replay_buffer))
        # Reader branch: replay from the buffer, held back until
        # "learning_starts" timesteps have elapsed.
        gated_replay = Replay(local_buffer=replay_buffer).filter(
            WaitUntilTimestepsElapsed(config["learning_starts"]))
        replay_op = _batch_and_train(gated_replay)
        # output_indexes=[1]: presumably only replay_op's outputs are
        # emitted downstream — confirm against Concurrently's docs.
        train_op = Concurrently(
            [store_op, replay_op], mode="round_robin", output_indexes=[1])

    return StandardMetricsReporting(train_op, workers, config)
def execution_plan(workers, config):
    """Build the replay-capable training dataflow (older ConcatBatches API).

    A falsy ``config["simple_optimizer"]`` routes rollouts through a
    ``SimpleReplayBuffer``. The buffer is read back once
    ``config["learning_starts"]`` timesteps have elapsed. Otherwise
    rollouts are batched and trained on directly.

    Args:
        workers: Worker set that produces rollouts and is trained.
        config: Trainer config dict. Keys read: "simple_optimizer",
            "train_batch_size", "num_sgd_iter", and on the replay path
            also "buffer_size" and "learning_starts".

    Returns:
        Whatever ``StandardMetricsReporting`` returns for the train op.
    """
    rollouts = ParallelRollouts(workers, mode="bulk_sync")

    def _batch_and_train(stream):
        # Concatenate to at least train_batch_size, then one TrainOneStep
        # per concatenated batch.
        batched = stream.combine(
            ConcatBatches(min_batch_size=config["train_batch_size"]))
        return batched.for_each(
            TrainOneStep(workers, num_sgd_iter=config["num_sgd_iter"]))

    if not config["simple_optimizer"]:
        replay_buffer = SimpleReplayBuffer(config["buffer_size"])
        store_op = rollouts.for_each(
            StoreToReplayBuffer(local_buffer=replay_buffer))
        warmed_up = Replay(local_buffer=replay_buffer).filter(
            WaitUntilTimestepsElapsed(config["learning_starts"]))
        replay_op = _batch_and_train(warmed_up)
        train_op = Concurrently(
            [store_op, replay_op], mode="round_robin", output_indexes=[1])
    else:
        train_op = _batch_and_train(rollouts)

    return StandardMetricsReporting(train_op, workers, config)
def execution_plan(
    workers: WorkerSet, config: TrainerConfigDict, **kwargs
) -> LocalIterator[dict]:
    """Build the QMIX training dataflow.

    Rollouts are stored into a ``SimpleReplayBuffer``. A second branch
    replays from it, concatenates batches, runs a training step, and then
    applies ``UpdateTargetNetwork``. The branches are interleaved
    round-robin.

    Args:
        workers: Worker set that produces rollouts and is trained.
        config: Trainer config. Keys read: "buffer_size",
            "train_batch_size", "multiagent" -> "count_steps_by",
            "target_network_update_freq".
        **kwargs: Not supported; must be empty.

    Returns:
        The metrics iterator from ``StandardMetricsReporting``.
    """
    assert (
        len(kwargs) == 0
    ), "QMIX execution_plan does NOT take any additional parameters"

    rollouts = ParallelRollouts(workers, mode="bulk_sync")
    replay_buffer = SimpleReplayBuffer(config["buffer_size"])

    # Writer branch.
    store_op = rollouts.for_each(StoreToReplayBuffer(local_buffer=replay_buffer))

    # Learner branch, built stage by stage.
    replayed = Replay(local_buffer=replay_buffer)
    batched = replayed.combine(
        ConcatBatches(
            min_batch_size=config["train_batch_size"],
            count_steps_by=config["multiagent"]["count_steps_by"],
        )
    )
    trained = batched.for_each(TrainOneStep(workers))
    train_op = trained.for_each(
        UpdateTargetNetwork(workers, config["target_network_update_freq"])
    )

    merged_op = Concurrently(
        [store_op, train_op], mode="round_robin", output_indexes=[1]
    )
    return StandardMetricsReporting(merged_op, workers, config)
def execution_plan(workers: WorkerSet,
                   config: TrainerConfigDict) -> LocalIterator[dict]:
    """Assemble the MARWIL/BC training dataflow.

    Each rollout batch is stored into a ``SimpleReplayBuffer``. A parallel
    branch replays from that buffer, concatenates samples into batches of
    at least ``train_batch_size``, and applies ``TrainOneStep`` to each.
    The two branches are interleaved round-robin.

    Args:
        workers (WorkerSet): Workers that produce rollouts and hold the
            policies being trained.
        config (TrainerConfigDict): Trainer config. Keys read:
            "replay_buffer_size", "train_batch_size", and
            "multiagent" -> "count_steps_by".

    Returns:
        LocalIterator[dict]: Iterator over training metrics, produced by
            ``StandardMetricsReporting``.
    """
    rollouts = ParallelRollouts(workers, mode="bulk_sync")
    replay_buffer = SimpleReplayBuffer(config["replay_buffer_size"])

    store_op = rollouts.for_each(
        StoreToReplayBuffer(local_buffer=replay_buffer))

    replayed = Replay(local_buffer=replay_buffer)
    batched = replayed.combine(
        ConcatBatches(
            min_batch_size=config["train_batch_size"],
            count_steps_by=config["multiagent"]["count_steps_by"],
        ))
    replay_op = batched.for_each(TrainOneStep(workers))

    train_op = Concurrently(
        [store_op, replay_op], mode="round_robin", output_indexes=[1])
    return StandardMetricsReporting(train_op, workers, config)
def execution_plan(workers, config):
    """Assemble the MARWIL training dataflow (older ConcatBatches API).

    Rollouts are written into a ``SimpleReplayBuffer``. A second branch
    replays from it, concatenates to at least ``train_batch_size``, and
    applies ``TrainOneStep``. The branches are interleaved round-robin.

    Args:
        workers: Worker set that produces rollouts and is trained.
        config: Trainer config dict. Keys read: "replay_buffer_size",
            "train_batch_size".

    Returns:
        Whatever ``StandardMetricsReporting`` returns for the train op.
    """
    rollouts = ParallelRollouts(workers, mode="bulk_sync")
    replay_buffer = SimpleReplayBuffer(config["replay_buffer_size"])

    writer = StoreToReplayBuffer(local_buffer=replay_buffer)
    store_op = rollouts.for_each(writer)

    replayed = Replay(local_buffer=replay_buffer)
    batched = replayed.combine(
        ConcatBatches(min_batch_size=config["train_batch_size"]))
    replay_op = batched.for_each(TrainOneStep(workers))

    train_op = Concurrently(
        [store_op, replay_op], mode="round_robin", output_indexes=[1])
    return StandardMetricsReporting(train_op, workers, config)
def execution_plan(workers, config):
    """Build the QMIX training dataflow (untyped variant).

    Rollouts are stored into a ``SimpleReplayBuffer``. The learner branch
    replays from it, concatenates batches, applies ``TrainOneStep``, and
    then applies ``UpdateTargetNetwork``. The branches are interleaved
    round-robin.

    Args:
        workers: Worker set that produces rollouts and is trained.
        config: Trainer config dict. Keys read: "buffer_size",
            "train_batch_size", "multiagent" -> "count_steps_by",
            "target_network_update_freq".

    Returns:
        Whatever ``StandardMetricsReporting`` returns for the merged op.
    """
    rollouts = ParallelRollouts(workers, mode="bulk_sync")
    replay_buffer = SimpleReplayBuffer(config["buffer_size"])

    store_op = rollouts.for_each(
        StoreToReplayBuffer(local_buffer=replay_buffer))

    replayed = Replay(local_buffer=replay_buffer)
    batched = replayed.combine(
        ConcatBatches(
            min_batch_size=config["train_batch_size"],
            count_steps_by=config["multiagent"]["count_steps_by"]
        ))
    trained = batched.for_each(TrainOneStep(workers))
    train_op = trained.for_each(
        UpdateTargetNetwork(workers, config["target_network_update_freq"]))

    merged_op = Concurrently(
        [store_op, train_op], mode="round_robin", output_indexes=[1])
    return StandardMetricsReporting(merged_op, workers, config)
def new_buffer():
    """Return a new ``SimpleReplayBuffer`` with ``num_slots=capacity``.

    ``capacity`` is a free name resolved from an enclosing or module scope
    that is not visible here.
    """
    buffer = SimpleReplayBuffer(num_slots=capacity)
    return buffer