Example #1
    def execution_plan(workers, config, **kwargs):
        assert (
            len(kwargs) == 0
        ), "IMPALA execution_plan does NOT take any additional parameters"

        if config["num_aggregation_workers"] > 0:
            train_batches = gather_experiences_tree_aggregation(workers, config)
        else:
            train_batches = gather_experiences_directly(workers, config)

        # Start the learner thread.
        learner_thread = make_learner_thread(workers.local_worker(), config)
        learner_thread.start()

        # This sub-flow sends experiences to the learner.
        enqueue_op = train_batches.for_each(Enqueue(learner_thread.inqueue))
        # Only need to update workers if there are remote workers.
        if workers.remote_workers():
            enqueue_op = enqueue_op.zip_with_source_actor().for_each(
                BroadcastUpdateLearnerWeights(
                    learner_thread,
                    workers,
                    broadcast_interval=config["broadcast_interval"],
                )
            )

        def record_steps_trained(item):
            count, fetches, _ = item
            metrics = _get_shared_metrics()
            # Manually update the steps trained counter since the learner
            # thread is executing outside the pipeline.
            metrics.counters[STEPS_TRAINED_THIS_ITER_COUNTER] = count
            metrics.counters[STEPS_TRAINED_COUNTER] += count
            return item

        # This sub-flow updates the steps trained counter based on learner
        # output.
        dequeue_op = Dequeue(
            learner_thread.outqueue, check=learner_thread.is_alive
        ).for_each(record_steps_trained)

        merged_op = Concurrently(
            [enqueue_op, dequeue_op], mode="async", output_indexes=[1]
        )

        # Callback used by APPO to periodically update the KL coefficient and
        # the target network. The input to the callback is the learner fetches
        # dict.
        if config["after_train_step"]:
            merged_op = merged_op.for_each(lambda t: t[1]).for_each(
                config["after_train_step"](workers, config)
            )

        return StandardMetricsReporting(merged_op, workers, config).for_each(
            learner_thread.add_learner_metrics
        )
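
Here config["after_train_step"] is a hook that APPO configures: the plan calls it once as config["after_train_step"](workers, config) and then applies the returned callable to every learner fetches dict coming out of the learner thread. A minimal sketch of a compatible callback factory is shown below; make_after_train_step and apply_updates are hypothetical names used only for illustration, not RLlib API.

def make_after_train_step(workers, config):
    # Hypothetical factory for illustration: the execution plan calls it once
    # with (workers, config) and then runs the returned function on each
    # learner fetches dict.
    def apply_updates(fetches):
        # A real APPO hook would adjust the KL coefficient and sync the target
        # network using `fetches`; this sketch just passes the dict through.
        return fetches

    return apply_updates

# Hypothetical wiring: config["after_train_step"] = make_after_train_step
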
Example #2
def execution_plan(workers, config):
    if config["num_aggregation_workers"] > 0:
        train_batches = gather_experiences_tree_aggregation(workers, config)
    else:
        train_batches = gather_experiences_directly(workers, config)

    # Start the learner thread.
    learner_thread = make_learner_thread(workers.local_worker(), config)
    learner_thread.start()

    # This sub-flow sends experiences to the learner.
    enqueue_op = train_batches.for_each(Enqueue(learner_thread.inqueue))
    # Only need to update workers if there are remote workers.
    if workers.remote_workers():
        enqueue_op = enqueue_op.zip_with_source_actor().for_each(
            BroadcastUpdateLearnerWeights(
                learner_thread,
                workers,
                broadcast_interval=config["broadcast_interval"],
            )
        )

    # This sub-flow updates the steps trained counter based on learner output.
    dequeue_op = Dequeue(
        learner_thread.outqueue, check=learner_thread.is_alive
    ).for_each(record_steps_trained)

    merged_op = Concurrently([enqueue_op, dequeue_op],
                             mode="async",
                             output_indexes=[1])

    # Callback used by APPO to periodically update the KL coefficient and the
    # target network. The input to the callback is the learner fetches dict.
    if config["after_train_step"]:
        merged_op = merged_op.for_each(lambda t: t[1]).for_each(
            config["after_train_step"](workers, config))

    return StandardMetricsReporting(merged_op, workers, config).for_each(
        learner_thread.add_learner_metrics
    )
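
Example #2 calls record_steps_trained without defining it; in that layout it is a helper defined elsewhere in the same module, while Example #1 inlines the same logic. Reconstructed from the inlined version above, it would look roughly like the sketch below; the import path is an assumption based on where RLlib 1.x kept these execution helpers.

from ray.rllib.execution.common import (
    STEPS_TRAINED_COUNTER,
    STEPS_TRAINED_THIS_ITER_COUNTER,
    _get_shared_metrics,
)


def record_steps_trained(item):
    count, fetches, _ = item
    metrics = _get_shared_metrics()
    # Manually update the steps-trained counters since the learner thread
    # executes outside the pipeline.
    metrics.counters[STEPS_TRAINED_THIS_ITER_COUNTER] = count
    metrics.counters[STEPS_TRAINED_COUNTER] += count
    return item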