Example #1
                               buffer_size=config["buffer_size"],
                               train_batch_size=config["train_batch_size"],
                               sample_batch_size=config["sample_batch_size"],
                               **extra_config)
    workers.add_workers(config["num_workers"])
    opt._set_workers(workers.remote_workers())
    return opt


def update_target_based_on_num_steps_trained(trainer, fetches):
    # Ape-X updates based on num steps trained, not sampled
    if (trainer.optimizer.num_steps_trained -
            trainer.state["last_target_update_ts"] >
            trainer.config["target_network_update_freq"]):
        trainer.workers.local_worker().foreach_trainable_policy(
            lambda p, _: p.update_target())
        trainer.state["last_target_update_ts"] = (
            trainer.optimizer.num_steps_trained)
        trainer.state["num_target_updates"] += 1


APEX_TRAINER_PROPERTIES = {
    "make_workers": defer_make_workers,
    "make_policy_optimizer": make_async_optimizer,
    "after_optimizer_step": update_target_based_on_num_steps_trained,
}

ApexTrainer = DQNTrainer.with_updates(name="APEX",
                                      default_config=APEX_DEFAULT_CONFIG,
                                      **APEX_TRAINER_PROPERTIES)
Example #2
    if config["simple_optimizer"]:
        train_step_op = TrainOneStep(workers)
    else:
        train_step_op = TrainTFMultiGPU(
            workers=workers,
            sgd_minibatch_size=config["train_batch_size"],
            num_sgd_iter=1,
            num_gpus=config["num_gpus"],
            shuffle_sequences=True,
            _fake_gpus=config["_fake_gpus"],
            framework=config.get("framework"))

    # (2) Read and train on experiences from the replay buffer.
    replay_op = Replay(local_buffer=local_replay_buffer) \
        .for_each(train_step_op) \
        .for_each(UpdateTargetNetwork(
            workers, config["target_network_update_freq"]))

    # Alternate deterministically between (1) and (2).
    train_op = Concurrently([store_op, replay_op],
                            mode="round_robin",
                            output_indexes=[1])

    return StandardMetricsReporting(train_op, workers, config)


SimpleQTrainer = DQNTrainer.with_updates(default_policy=SimpleQTFPolicy,
                                         get_policy_class=get_policy_class,
                                         execution_plan=execution_plan,
                                         default_config=DEFAULT_CONFIG)
Example #3
File: apex.py  Project: AmeerHajAli/ray2
    # Add in extra replay and learner metrics to the training result.
    def add_apex_metrics(result):
        replay_stats = ray.get(replay_actors[0].stats.remote(
            config["optimizer"].get("debug")))
        exploration_infos = workers.foreach_trainable_policy(
            lambda p, _: p.get_exploration_info())
        result["info"].update({
            "exploration_infos":
            exploration_infos,
            "learner_queue":
            learner_thread.learner_queue_size.stats(),
            "learner":
            copy.deepcopy(learner_thread.stats),
            "replay_shard_0":
            replay_stats,
        })
        return result

    # Only report metrics from the workers with the lowest 1/3 of epsilons.
    selected_workers = workers.remote_workers(
    )[-len(workers.remote_workers()) // 3:]

    return StandardMetricsReporting(
        merged_op, workers, config,
        selected_workers=selected_workers).for_each(add_apex_metrics)


ApexTrainer = DQNTrainer.with_updates(name="APEX",
                                      default_config=APEX_DEFAULT_CONFIG,
                                      execution_plan=apex_execution_plan)
Example #4
            learner_thread.learner_queue_size.stats(),
            "learner":
            copy.deepcopy(learner_thread.stats),
            "replay_shard_0":
            replay_stats,
        })
        return result

    # Only report metrics from the workers with the lowest 1/3 of epsilons.
    selected_workers = workers.remote_workers(
    )[-len(workers.remote_workers()) // 3:]

    return StandardMetricsReporting(
        merged_op, workers, config,
        selected_workers=selected_workers).for_each(add_apex_metrics)


def apex_validate_config(config):
    if config["num_gpus"] > 1:
        raise ValueError("`num_gpus` > 1 not yet supported for APEX-DQN!")
    validate_config(config)


ApexTrainer = DQNTrainer.with_updates(
    name="APEX",
    default_config=APEX_DEFAULT_CONFIG,
    validate_config=apex_validate_config,
    execution_plan=apex_execution_plan,
    mixins=[OverrideDefaultResourceRequest],
)
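Because with_updates is also available on the class it returns, the ApexTrainer built here can be specialized further without touching the DQN internals. A minimal sketch of such chaining follows; the derived class name, the extra config check, and its threshold are hypothetical, not part of the original example.

def my_validate_config(config):
    # Reuse the APEX check from above, then add a project-specific (hypothetical) constraint.
    apex_validate_config(config)
    if config["num_workers"] < 2:
        raise ValueError("This variant expects at least 2 rollout workers.")


MyApexTrainer = ApexTrainer.with_updates(
    name="MY_APEX",
    validate_config=my_validate_config,
)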
Example #5
    #     Path(__file__).parent / "../dataset_public/mixed_loop/its_merge_a"
    # ).resolve(), (
    #     Path(__file__).parent / "../dataset/intersection_4lane_sv_right"
    #     Path(__file__).parent / "../dataset_public/mixed_loop/roundabout_its_a"
    # ).resolve(), (
    #     Path(__file__).parent / "../dataset_public/mixed_loop/roundabout_merge_a"
    Path(__file__).parent / "../dataset/simple").resolve()]
print(f"training on {scenario_paths}")

from ray.rllib.agents.trainer_template import build_trainer
from ray.rllib.agents.dqn.dqn import DEFAULT_CONFIG, DQNTrainer, validate_config, execution_plan, get_policy_class

config = DEFAULT_CONFIG.copy()
config["decompose_num"] = 3
DQN = DQNTrainer.with_updates(name="DQN_TORCH",
                              default_policy=DQNTorchPolicy,
                              default_config=config,
                              get_policy_class=None)


def parse_args():
    parser = argparse.ArgumentParser("train on multi scenarios")

    # env setting
    parser.add_argument("--scenario",
                        type=str,
                        default=None,
                        help="Scenario name")
    parser.add_argument("--exper", type=str, default="multi_scenarios")
    parser.add_argument("--headless",
                        default=False,
                        action="store_true",