Example #1
from ray.rllib.agents.ppo import ppo  # assumed import; provides the base PPO validator

def validate_config(config):
    """Validates the Trainer's config dict.

    Args:
        config (TrainerConfigDict): The Trainer's config to check.

    Raises:
        ValueError: In case something is wrong with the config.
    """

    # Auto-train_batch_size: Calculate from rollout len and envs-per-worker.
    if config["train_batch_size"] == -1:
        config["train_batch_size"] = (
            config["rollout_fragment_length"] * config["num_envs_per_worker"])
    # Users should not define `train_batch_size` directly (always -1).
    else:
        raise ValueError(
            "Set rollout_fragment_length instead of train_batch_size "
            "for DDPPO.")

    # Only supported for PyTorch so far.
    if config["framework"] != "torch":
        raise ValueError(
            "Distributed data parallel is only supported for PyTorch")
    # `num_gpus` must be 0/None, since all optimization happens on Workers.
    if config["num_gpus"]:
        raise ValueError(
            "When using distributed data parallel, you should set "
            "num_gpus=0 since all optimization "
            "is happening on workers. Enable GPUs for workers by setting "
            "num_gpus_per_worker=1.")
    # `batch_mode` must be "truncate_episodes".
    if config["batch_mode"] != "truncate_episodes":
        raise ValueError(
            "Distributed data parallel requires truncate_episodes "
            "batch mode.")
    # Call (base) PPO's config validation function.
    ppo.validate_config(config)
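
A quick usage sketch: starting from DDPPO's default config, the validator fills in train_batch_size automatically. The import path and the two default values below are assumptions for illustration; adjust for your Ray version.

from ray.rllib.agents.ppo.ddppo import DEFAULT_CONFIG  # assumed module path

config = DEFAULT_CONFIG.copy()           # DDPPO defaults to train_batch_size=-1
config["rollout_fragment_length"] = 100  # hypothetical values
config["num_envs_per_worker"] = 5

validate_config(config)
# The -1 placeholder was replaced by the product of the two settings:
assert config["train_batch_size"] == 100 * 5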
Example #2
# Older variant of the same validator: here the PyTorch switch is the boolean
# `use_pytorch` flag, which later RLlib releases replaced with the
# config["framework"] string checked in Example #1.
from ray.rllib.agents.ppo import ppo  # assumed import; provides the base PPO validator

def validate_config(config):
    # Auto-set train_batch_size from rollout length and envs per worker.
    if config["train_batch_size"] == -1:
        config["train_batch_size"] = (
            config["rollout_fragment_length"] * config["num_envs_per_worker"])
    # Users must leave train_batch_size at -1 and tune the rollout settings.
    else:
        raise ValueError(
            "Set rollout_fragment_length instead of train_batch_size "
            "for DDPPO.")
    # Only supported for PyTorch (old-style boolean flag).
    if not config["use_pytorch"]:
        raise ValueError(
            "Distributed data parallel is only supported for PyTorch")
    # All optimization happens on the workers, so the driver needs no GPU.
    if config["num_gpus"]:
        raise ValueError(
            "When using distributed data parallel, you should set "
            "num_gpus=0 since all optimization "
            "is happening on workers. Enable GPUs for workers by setting "
            "num_gpus_per_worker=1.")
    # DDPPO requires fixed-length rollout fragments.
    if config["batch_mode"] != "truncate_episodes":
        raise ValueError(
            "Distributed data parallel requires truncate_episodes "
            "batch mode.")
    # Hand off to the base PPO validation.
    ppo.validate_config(config)
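
To see the error path, a sketch of what happens when a user sets train_batch_size explicitly (the value 4000 is arbitrary; pytest is used only for the illustration):

import pytest  # assumed test dependency

config = {"train_batch_size": 4000}  # anything other than -1 is rejected
with pytest.raises(ValueError, match="rollout_fragment_length"):
    validate_config(config)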
Example #3
def validate_config_basic(config):
    # Check the novelty-search-specific keys before anything else.
    assert "joint_dataset_sample_batch_size" in config
    assert "use_joint_dataset" in config
    assert "novelty_mode" in config
    assert config["novelty_mode"] in ["mean", "min", "max"]
    # Delegate all remaining checks to the full validator.
    validate_config(config)
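
A minimal sketch of a config that satisfies the extra assertions. All values here are hypothetical, and the final line still delegates to whatever full validate_config the project defines, so a real config needs every key that validator expects as well.

config = {
    "joint_dataset_sample_batch_size": 200,  # hypothetical value
    "use_joint_dataset": True,               # hypothetical value
    "novelty_mode": "mean",                  # must be "mean", "min", or "max"
    # plus every key the delegated validate_config() call expects
}
validate_config_basic(config)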