コード例 #1
0
    def validate_config(self, config: AlgorithmConfigDict) -> None:
        """Validates the CQL config dict, fixing up some entries in place.

        The deprecated `timesteps_per_iteration` key is translated into
        `min_train_timesteps_per_iteration` *before* the parent class'
        validation runs; afterwards CQL-specific constraints are checked.

        Args:
            config: The Algorithm's config dict to check. It may be modified
                in place (deprecated keys, `simple_optimizer`).

        Raises:
            ValueError: If `num_gpus` > 1.
            If a TF framework is configured but `tfp` is unavailable,
            `try_import_tfp(error=True)` is called after logging a warning;
            NOTE(review): that helper is expected to raise — exception type
            is determined by it.
        """
        # First check whether the old `timesteps_per_iteration` is used. If
        # so, convert it right away: for CQL we must measure in training
        # timesteps, never in sampling timesteps (CQL does not sample).
        if config.get("timesteps_per_iteration",
                      DEPRECATED_VALUE) != DEPRECATED_VALUE:
            deprecation_warning(
                old="timesteps_per_iteration",
                new="min_train_timesteps_per_iteration",
                error=False,
            )
            config["min_train_timesteps_per_iteration"] = config[
                "timesteps_per_iteration"]
            config["timesteps_per_iteration"] = DEPRECATED_VALUE

        # Call super's validation method.
        super().validate_config(config)

        if config["num_gpus"] > 1:
            raise ValueError("`num_gpus` > 1 not yet supported for CQL!")

        # CQL-torch performs the optimizer steps inside the loss function.
        # Using the multi-GPU optimizer will therefore not work (see multi-GPU
        # check above) and we must use the simple optimizer for now.
        if (config["simple_optimizer"] is not True
                and config["framework"] == "torch"):
            config["simple_optimizer"] = True

        if config["framework"] in ["tf", "tf2", "tfe"] and tfp is None:
            # Explain the problem first, then let `try_import_tfp` surface the
            # underlying import error.
            logger.warning(
                "You need `tensorflow_probability` in order to run CQL! "
                "Install it via `pip install tensorflow_probability`. Your "
                # Adjacent literals are joined without a separator, so the
                # trailing space here is required.
                f"tf.__version__={tf.__version__ if tf else None}. "
                "Trying to import tfp results in the following error:")
            try_import_tfp(error=True)
コード例 #2
0
    def validate_config(self, config: AlgorithmConfigDict) -> None:
        """Validates the SAC config dict, fixing up deprecated keys in place.

        Runs the parent class' validation first, then maps deprecated keys
        onto their replacements and checks SAC-specific constraints.

        Args:
            config: The Algorithm's config dict to check. It may be modified
                in place (`use_state_preprocessor`, `policy_model_config`,
                `q_model_config`).

        Raises:
            ValueError: If `grad_clip` is set to a value <= 0.0.
            If a TF framework is configured but `tfp` is unavailable,
            `try_import_tfp(error=True)` is called after logging a warning;
            NOTE(review): that helper is expected to raise — exception type
            is determined by it.
        """
        # Call super's validation method.
        super().validate_config(config)

        # `use_state_preprocessor` has no replacement: warn and neutralize.
        if config["use_state_preprocessor"] != DEPRECATED_VALUE:
            deprecation_warning(old="config['use_state_preprocessor']",
                                error=False)
            config["use_state_preprocessor"] = DEPRECATED_VALUE

        # Old-style model keys: copy their value onto the new key names.
        if config.get("policy_model", DEPRECATED_VALUE) != DEPRECATED_VALUE:
            deprecation_warning(
                old="config['policy_model']",
                new="config['policy_model_config']",
                error=False,
            )
            config["policy_model_config"] = config["policy_model"]

        if config.get("Q_model", DEPRECATED_VALUE) != DEPRECATED_VALUE:
            deprecation_warning(
                old="config['Q_model']",
                new="config['q_model_config']",
                error=False,
            )
            config["q_model_config"] = config["Q_model"]

        # `None` disables gradient clipping; any explicit value must be > 0.
        if config["grad_clip"] is not None and config["grad_clip"] <= 0.0:
            raise ValueError("`grad_clip` value must be > 0.0!")

        if config["framework"] in ["tf", "tf2", "tfe"] and tfp is None:
            # Explain the problem first, then let `try_import_tfp` surface the
            # underlying import error.
            logger.warning(
                "You need `tensorflow_probability` in order to run SAC! "
                "Install it via `pip install tensorflow_probability`. Your "
                # Adjacent literals are joined without a separator, so the
                # trailing space here is required.
                f"tf.__version__={tf.__version__ if tf else None}. "
                "Trying to import tfp results in the following error:")
            try_import_tfp(error=True)
コード例 #3
0
ファイル: cql.py プロジェクト: stefanbschneider/ray
def validate_config(config: TrainerConfigDict):
    """Validates the CQL Trainer's config dict, fixing up entries in place.

    Args:
        config (TrainerConfigDict): The Trainer's config to check. May be
            modified in place (`simple_optimizer`).

    Raises:
        ValueError: If `num_gpus` > 1.
        If a TF framework is configured but `tfp` is unavailable,
        `try_import_tfp(error=True)` is called after logging a warning;
        NOTE(review): that helper is expected to raise — exception type is
        determined by it.
    """
    if config["num_gpus"] > 1:
        raise ValueError("`num_gpus` > 1 not yet supported for CQL!")

    # CQL-torch performs the optimizer steps inside the loss function.
    # Using the multi-GPU optimizer will therefore not work (see multi-GPU
    # check above) and we must use the simple optimizer for now.
    if (config["simple_optimizer"] is not True
            and config["framework"] == "torch"):
        config["simple_optimizer"] = True

    if config["framework"] in ["tf", "tf2", "tfe"] and tfp is None:
        # Explain the problem first, then let `try_import_tfp` surface the
        # underlying import error.
        logger.warning(
            "You need `tensorflow_probability` in order to run CQL! "
            "Install it via `pip install tensorflow_probability`. Your "
            # Adjacent literals are joined without a separator, so the
            # trailing space here is required.
            f"tf.__version__={tf.__version__ if tf else None}. "
            "Trying to import tfp results in the following error:")
        try_import_tfp(error=True)
コード例 #4
0
ファイル: sac.py プロジェクト: kaushikb11/ray
def validate_config(config: TrainerConfigDict) -> None:
    """Validates the Trainer's config dict.

    Args:
        config (TrainerConfigDict): The Trainer's config to check. May be
            modified in place (`use_state_preprocessor`).

    Raises:
        ValueError: In case something is wrong with the config (currently:
            `grad_clip` set to a value <= 0.0).
        If a TF framework is configured but `tfp` is unavailable,
        `try_import_tfp(error=True)` is called after logging a warning;
        NOTE(review): that helper is expected to raise — exception type is
        determined by it.
    """
    # `use_state_preprocessor` has no replacement: warn and neutralize.
    if config["use_state_preprocessor"] != DEPRECATED_VALUE:
        deprecation_warning(old="config['use_state_preprocessor']",
                            error=False)
        config["use_state_preprocessor"] = DEPRECATED_VALUE

    # `None` disables gradient clipping; any explicit value must be > 0.
    if config["grad_clip"] is not None and config["grad_clip"] <= 0.0:
        raise ValueError("`grad_clip` value must be > 0.0!")

    if config["framework"] in ["tf", "tf2", "tfe"] and tfp is None:
        # Explain the problem first, then let `try_import_tfp` surface the
        # underlying import error.
        logger.warning(
            "You need `tensorflow_probability` in order to run SAC! "
            "Install it via `pip install tensorflow_probability`. Your "
            # Adjacent literals are joined without a separator, so the
            # trailing space here is required.
            f"tf.__version__={tf.__version__ if tf else None}. "
            "Trying to import tfp results in the following error:")
        try_import_tfp(error=True)
コード例 #5
0
from math import log
import numpy as np
import functools
import tree

from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.utils import MIN_LOG_NN_OUTPUT, MAX_LOG_NN_OUTPUT, \
    SMALL_NUMBER
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.utils.framework import try_import_tf, try_import_tfp
from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space
from ray.rllib.utils.typing import TensorType, List

# Optional framework handles obtained via RLlib's import helpers.
# NOTE(review): these are presumably None / placeholders when TensorFlow or
# tensorflow_probability is not installed — confirm against
# `try_import_tf` / `try_import_tfp`.
tf1, tf, tfv = try_import_tf()
tfp = try_import_tfp()


@DeveloperAPI
class TFActionDistribution(ActionDistribution):
    """TF-specific extensions for building action distributions."""
    @override(ActionDistribution)
    def __init__(self, inputs: List[TensorType], model: ModelV2):
        """Initializes the distribution and pre-builds its sampling ops.

        Args:
            inputs: Input tensor(s) handed to the base ActionDistribution
                (presumably the model outputs that parameterize the
                distribution — TODO confirm).
            model: The ModelV2 these inputs come from.
        """
        super().__init__(inputs, model)
        # Build the sample op once and derive its log-prob op from it, so
        # both ops can be reused instead of rebuilt on every call.
        built_sample = self._build_sample_op()
        self.sample_op = built_sample
        self.sampled_action_logp_op = self.logp(built_sample)

    @DeveloperAPI
    def _build_sample_op(self) -> TensorType:
        """Implement this instead of sample(), to enable op reuse.