# RLlib execution-op imports (ray 1.x module paths; the remaining helpers --
# gather_experiences_directly, gather_experiences_tree_aggregation,
# make_learner_thread, BroadcastUpdateLearnerWeights -- are defined elsewhere
# in the IMPALA module).
from ray.rllib.execution.common import (
    STEPS_TRAINED_COUNTER,
    STEPS_TRAINED_THIS_ITER_COUNTER,
    _get_shared_metrics,
)
from ray.rllib.execution.concurrency_ops import Concurrently, Dequeue, Enqueue
from ray.rllib.execution.metric_ops import StandardMetricsReporting


def execution_plan(workers, config, **kwargs):
    assert (
        len(kwargs) == 0
    ), "IMPALA execution_plan does NOT take any additional parameters"

    if config["num_aggregation_workers"] > 0:
        train_batches = gather_experiences_tree_aggregation(workers, config)
    else:
        train_batches = gather_experiences_directly(workers, config)

    # Start the learner thread.
    learner_thread = make_learner_thread(workers.local_worker(), config)
    learner_thread.start()

    # This sub-flow sends experiences to the learner.
    enqueue_op = train_batches.for_each(Enqueue(learner_thread.inqueue))
    # Only need to update workers if there are remote workers.
    if workers.remote_workers():
        enqueue_op = enqueue_op.zip_with_source_actor().for_each(
            BroadcastUpdateLearnerWeights(
                learner_thread,
                workers,
                broadcast_interval=config["broadcast_interval"],
            )
        )

    def record_steps_trained(item):
        count, fetches, _ = item
        metrics = _get_shared_metrics()
        # Manually update the steps trained counter since the learner
        # thread is executing outside the pipeline.
        metrics.counters[STEPS_TRAINED_THIS_ITER_COUNTER] = count
        metrics.counters[STEPS_TRAINED_COUNTER] += count
        return item

    # This sub-flow updates the steps trained counter based on learner
    # output.
    dequeue_op = Dequeue(
        learner_thread.outqueue, check=learner_thread.is_alive
    ).for_each(record_steps_trained)

    merged_op = Concurrently(
        [enqueue_op, dequeue_op], mode="async", output_indexes=[1]
    )

    # Callback for APPO to use to update KL, target network periodically.
    # The input to the callback is the learner fetches dict.
    if config["after_train_step"]:
        merged_op = merged_op.for_each(lambda t: t[1]).for_each(
            config["after_train_step"](workers, config)
        )

    return StandardMetricsReporting(merged_op, workers, config).for_each(
        learner_thread.add_learner_metrics
    )
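
# --- Usage sketch (illustrative, not part of the source above) -------------
# A minimal sketch, assuming ray 1.x's trainer-template API: the plan is
# passed to build_trainer(), which calls it once at setup with the trainer's
# WorkerSet and config, and each trainer.train() call then pulls one result
# from the resulting iterator. DEFAULT_CONFIG and VTraceTFPolicy are IMPALA's
# usual defaults; the import paths below are ray 1.x assumptions, so verify
# them against your installed version.
import ray
from ray.rllib.agents.impala.impala import DEFAULT_CONFIG
from ray.rllib.agents.impala.vtrace_tf_policy import VTraceTFPolicy
from ray.rllib.agents.trainer_template import build_trainer

MyImpalaTrainer = build_trainer(
    name="IMPALA",
    default_config=DEFAULT_CONFIG,
    default_policy=VTraceTFPolicy,
    execution_plan=execution_plan,
)

if __name__ == "__main__":
    ray.init()
    trainer = MyImpalaTrainer(config={"env": "CartPole-v0", "num_workers": 2})
    # One train() call drives one pass through the merged async dataflow.
    print(trainer.train())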