Example #1
0
def execution_plan(workers, config):
    """Build the A3C dataflow: async remote gradients, applied locally.

    Args:
        workers: The WorkerSet holding the rollout workers.
        config: The trainer's configuration dict.

    Returns:
        A local iterator over training metrics.
    """
    # Policy gradients are computed asynchronously on the remote rollout
    # workers rather than locally.
    async_grads = AsyncGradients(workers)

    # Apply each gradient as soon as it arrives. With update_all=False,
    # refreshed weights go only to the worker that produced the gradient,
    # not to the entire worker set.
    applied_op = async_grads.for_each(
        ApplyGradients(workers, update_all=False))

    return StandardMetricsReporting(applied_op, workers, config)
Example #2
0
    def execution_plan(workers: WorkerSet, config: TrainerConfigDict,
                       **kwargs) -> LocalIterator[dict]:
        """Build the A3C training dataflow.

        Args:
            workers: WorkerSet containing the rollout workers.
            config: The trainer's configuration dict.

        Returns:
            A local iterator over training metrics.

        Raises:
            AssertionError: If any extra keyword arguments are passed.
        """
        assert len(kwargs) == 0, \
            "A3C execution_plan does NOT take any additional parameters"

        # Policy gradients are computed asynchronously on the remote
        # rollout workers.
        remote_grads = AsyncGradients(workers)

        # Apply each gradient locally as it arrives. update_all=False
        # means only the worker that sent the gradient gets the refreshed
        # weights.
        applied_op = remote_grads.for_each(
            ApplyGradients(workers, update_all=False))

        return StandardMetricsReporting(applied_op, workers, config)
Example #3
0
def test_async_grads(ray_start_regular_shared):
    """AsyncGradients yields (grad, info) pairs and counts sampled steps."""
    workers = make_workers(2)
    grad_iter = AsyncGradients(workers)
    first_item = next(grad_iter)
    assert isinstance(first_item, tuple) and len(first_item) == 2, first_item
    metric_counters = grad_iter.shared_metrics.get().counters
    assert metric_counters["num_steps_sampled"] == 100, metric_counters
    workers.stop()
Example #4
0
 def test_async_grads(self):
     """AsyncGradients yields (grad, info) pairs and counts sampled steps."""
     workers = make_workers(2)
     grad_iter = AsyncGradients(workers)
     first_item = next(grad_iter)
     assert isinstance(first_item, tuple) and len(first_item) == 2, first_item
     metric_counters = grad_iter.shared_metrics.get().counters
     assert metric_counters[STEPS_SAMPLED_COUNTER] == 100, metric_counters
     workers.stop()
Example #5
0
def execution_plan(workers, config):
    """Experimental A3C plan: split the remote workers into two pools and
    train each pool with synchronous multi-GPU SGD.

    NOTE(review): This reads as work-in-progress debugging code — it keeps
    print() calls, commented-out lines from the original async-gradients
    plan, and an unused `grads` iterator.

    Args:
        workers: The WorkerSet holding the rollout workers.
        config: The trainer's configuration dict.

    Returns:
        The union of two StandardMetricsReporting iterators, one per pool.
    """
    # For A3C, compute policy gradients remotely on the rollout workers.
    # rollouts = ParallelRollouts(workers, mode="bulk_sync")

    # NOTE(review): `grads` is created but never consumed below — dead code?
    grads = AsyncGradients(workers)
    
    # Apply the gradients as they arrive. We set update_all to False so that
    # only the worker sending the gradient is updated with new weights.
    #train_op = grads.for_each(ApplyGradients(workers, update_all=False))
    print("_____")
    print(workers)
    # NOTE(review): plain assignment does not copy — temp1, temp2 and
    # workers are all the SAME WorkerSet object. The second reset() below
    # therefore overwrites the first, so both rollout streams presumably
    # read from rem2 only. Confirm WorkerSet.reset semantics and clone the
    # set if two independent pools are intended.
    temp1 = workers
    temp2 = workers
    # Partition the remote workers into two pools (first 6 / next 5).
    rem1 = workers.remote_workers()[0:6]
    rem2 = workers.remote_workers()[6:11]
    temp1.reset(rem1)
    temp2.reset(rem2)
  
    # One bulk-sync rollout stream per pool.
    rollouts1 = ParallelRollouts(temp1, mode="bulk_sync")
    rollouts2 = ParallelRollouts(temp2, mode="bulk_sync")


    # Multi-GPU SGD step for pool 1: one SGD iteration per incoming batch,
    # minibatch size equal to the full train batch.
    train_step_op1 = TrainTFMultiGPU(
                workers=temp1,
                sgd_minibatch_size=config["train_batch_size"],
                num_sgd_iter=1,
                num_gpus=config["num_gpus"],
                shuffle_sequences=True,
                _fake_gpus=config["_fake_gpus"],
                framework=config.get("framework"))

    # Multi-GPU SGD step for pool 2 (identical settings to pool 1).
    train_step_op2 = TrainTFMultiGPU(
                    workers=temp2,
                    sgd_minibatch_size=config["train_batch_size"],
                    num_sgd_iter=1,
                    num_gpus=config["num_gpus"],
                    shuffle_sequences=True,
                    _fake_gpus=config["_fake_gpus"],
                    framework=config.get("framework"))

    # Concatenate rollout fragments up to train_batch_size, then run the
    # corresponding pool's SGD step on each concatenated batch.
    train_op1 = rollouts1.combine(
            ConcatBatches(
                min_batch_size=config["train_batch_size"],
                count_steps_by=config["multiagent"][
                    "count_steps_by"])).for_each(train_step_op1)
    train_op2 = rollouts2.combine(
            ConcatBatches(
                min_batch_size=config["train_batch_size"],
                count_steps_by=config["multiagent"][
                    "count_steps_by"])).for_each(train_step_op2)
    
    #train_op = grads.for_each(ApplyGradients(workers, update_all=False))
    
    
    # Merge the metrics reporting of both training streams into a single
    # local iterator.
    return StandardMetricsReporting(train_op1, temp1, config).union(StandardMetricsReporting(train_op2, temp2, config))
Example #6
0
File: a3c.py Project: yncxcw/ray
def execution_plan(workers: WorkerSet,
                   config: TrainerConfigDict) -> LocalIterator[dict]:
    """Execution plan of the A3C algorithm. Defines the distributed
    dataflow.

    Args:
        workers (WorkerSet): The WorkerSet for training the Polic(y/ies)
            of the Trainer.
        config (TrainerConfigDict): The trainer's configuration dict.

    Returns:
        LocalIterator[dict]: A local iterator over training metrics.
    """
    # For A3C, compute policy gradients remotely on the rollout workers.
    grads = AsyncGradients(workers)

    # Apply the gradients as they arrive. We set update_all to False so that
    # only the worker sending the gradient is updated with new weights.
    train_op = grads.for_each(ApplyGradients(workers, update_all=False))

    return StandardMetricsReporting(train_op, workers, config)