예제 #1
0
def get_activation_fn(name: Optional[str] = None, framework: str = "tf"):
    """Returns a framework specific activation function, given a name string.

    Args:
        name (Optional[str]): One of "relu", "tanh", "elu", "swish", or
            "linear" (same as None, which is the default). If a callable is
            passed instead of a string, it is returned unchanged.
        framework (str): One of "jax", "tf|tfe|tf2" or "torch".

    Returns:
        A framework-specific activation function. e.g. tf.nn.tanh or
            torch.nn.ReLU. None if name in ["linear", None].

    Raises:
        ValueError: If name is an unknown activation function.
        AssertionError: If framework is none of the supported values.
    """
    # Already a callable, return as-is.
    if callable(name):
        return name

    # Infer the correct activation function from the string specifier.
    if framework == "torch":
        if name in ["linear", None]:
            return None
        if name == "swish":
            # Imported lazily so that torch is only required when needed.
            from ray.rllib.utils.torch_utils import Swish

            return Swish
        _, nn = try_import_torch()
        if name == "relu":
            return nn.ReLU
        elif name == "tanh":
            return nn.Tanh
        elif name == "elu":
            return nn.ELU
    elif framework == "jax":
        if name in ["linear", None]:
            return None
        jax, _ = try_import_jax()
        if name == "swish":
            return jax.nn.swish
        elif name == "relu":
            return jax.nn.relu
        elif name == "tanh":
            # Use the real hyperbolic tangent so that "tanh" means the same
            # thing as in the torch (nn.Tanh) and tf (tf.nn.tanh) branches.
            # `hard_tanh` is a clipped piecewise-linear function, not tanh.
            return jax.numpy.tanh
        elif name == "elu":
            return jax.nn.elu
    else:
        assert framework in ["tf", "tfe", "tf2"], "Unsupported framework `{}`!".format(
            framework
        )
        if name in ["linear", None]:
            return None
        _, tf, _ = try_import_tf()
        # Only strings can be attribute names; anything else falls through
        # to the ValueError below rather than a TypeError from getattr.
        fn = getattr(tf.nn, name, None) if isinstance(name, str) else None
        if fn is not None:
            return fn

    raise ValueError(
        "Unknown activation ({}) for framework={}!".format(name, framework)
    )
예제 #2
0
def get_initializer(name, framework="tf"):
    """Returns a framework specific initializer, given a name string.

    Args:
        name (str): One of "xavier_uniform" (also selected by None or
            "default") or "xavier_normal". If a callable is passed instead
            of a string, it is returned unchanged.
        framework (str): One of "jax", "tf|tfe|tf2" or "torch".

    Returns:
        A framework-specific initializer function, e.g.
            tf.keras.initializers.GlorotUniform or
            torch.nn.init.xavier_uniform_. For "jax", the flax initializer
            factory is called and its result is returned.

    Raises:
        ValueError: If name is an unknown initializer.
        AssertionError: If framework is none of the supported values or the
            required framework package is not installed.
    """
    # Already a callable, return as-is.
    if callable(name):
        return name

    if framework == "jax":
        _, flax = try_import_jax()
        assert flax is not None, "`flax` not installed. Try `pip install jax flax`."
        import flax.linen as nn

        if name in [None, "default", "xavier_uniform"]:
            return nn.initializers.xavier_uniform()
        elif name == "xavier_normal":
            return nn.initializers.xavier_normal()
    # `elif` (not `if`): an unknown name under "jax" must reach the
    # ValueError below instead of falling into the tf branch's framework
    # assertion and being misreported as an unsupported framework.
    elif framework == "torch":
        _, nn = try_import_torch()
        assert nn is not None, "`torch` not installed. Try `pip install torch`."
        if name in [None, "default", "xavier_uniform"]:
            return nn.init.xavier_uniform_
        elif name == "xavier_normal":
            return nn.init.xavier_normal_
    else:
        assert framework in ["tf", "tfe", "tf2"], "Unsupported framework `{}`!".format(
            framework
        )
        _, tf, _ = try_import_tf()
        assert (
            tf is not None
        ), "`tensorflow` not installed. Try `pip install tensorflow`."
        if name in [None, "default", "xavier_uniform"]:
            return tf.keras.initializers.GlorotUniform
        elif name == "xavier_normal":
            return tf.keras.initializers.GlorotNormal

    raise ValueError(
        "Unknown initializer ({}) for framework={}!".format(name, framework)
    )
예제 #3
0
import logging
import numpy as np
import random
import re
import time
import tree  # pip install dm_tree
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import yaml

import ray
from ray.rllib.utils.framework import try_import_jax, try_import_tf, \
    try_import_torch
from ray.rllib.utils.typing import PartialTrainerConfigDict
from ray.tune import CLIReporter, run_experiments

# Resolve the optional deep-learning frameworks through the RLlib helper
# functions instead of importing them directly.
# NOTE(review): presumably each helper returns None placeholders when the
# framework is not installed -- confirm against ray.rllib.utils.framework.
jax, _ = try_import_jax()
tf1, tf, tfv = try_import_tf()
if tf1:
    # `eager_mode` is only defined when `tf1` is truthy. It stays None if
    # the private TensorFlow import below is not available.
    eager_mode = None
    try:
        from tensorflow.python.eager.context import eager_mode
    except (ImportError, ModuleNotFoundError):
        pass

torch, _ = try_import_torch()

# Module-level logger named after this module.
logger = logging.getLogger(__name__)


def framework_iterator(
    config: Optional[PartialTrainerConfigDict] = None,
예제 #4
0
파일: misc.py 프로젝트: vakker/ray
import time
from typing import Callable, Optional

from ray.rllib.utils.framework import get_activation_fn, try_import_jax

# Resolve jax/flax through the RLlib helper.
jax, flax = try_import_jax()
# Bind placeholders first so `nn` and `np` always exist at module level,
# even when flax is unavailable.
nn = np = None
if flax:
    import flax.linen as nn
    # Note: `np` refers to `jax.numpy` in this module, not to NumPy itself.
    import jax.numpy as np


class SlimFC:
    """Simple JAX version of a fully connected layer."""
    def __init__(
        self,
        in_size,
        out_size,
        initializer: Optional[Callable] = None,
        activation_fn: Optional[str] = None,
        use_bias: bool = True,
        prng_key: Optional[jax.random.PRNGKey] = None,
        name: Optional[str] = None,
    ):
        """Initializes a SlimFC instance.

        Args:
            in_size (int): The input size of the input data that will be passed
                into this layer.
            out_size (int): The number of nodes in this FC layer.
            initializer (flax.:
예제 #5
0
import gym
import logging
import numpy as np

from ray.rllib.utils.framework import try_import_jax, try_import_tf, \
    try_import_torch

# Resolve the optional deep-learning frameworks through the RLlib helper
# functions instead of importing them directly.
# NOTE(review): the full return value of try_import_jax() is bound to `jax`
# here without unpacking -- confirm the helper returns a single module in
# this version rather than a tuple.
jax = try_import_jax()
tf1, tf, tfv = try_import_tf()
if tf1:
    # `eager_mode` is only defined when `tf1` is truthy. It stays None if
    # the private TensorFlow import below is not available.
    eager_mode = None
    try:
        from tensorflow.python.eager.context import eager_mode
    except (ImportError, ModuleNotFoundError):
        pass

torch, _ = try_import_torch()

# Module-level logger named after this module.
logger = logging.getLogger(__name__)


def framework_iterator(config=None,
                       frameworks=("tf2", "tf", "tfe", "torch"),
                       session=False):
    """An generator that allows for looping through n frameworks for testing.

    Provides the correct config entries ("framework") as well
    as the correct eager/non-eager contexts for tfe/tf.

    Args:
        config (Optional[dict]): An optional config dict to alter in place