def get_activation_fn(name: Optional[str] = None, framework: str = "tf"):
    """Returns a framework-specific activation function, given a name string.

    Args:
        name (Optional[str]): One of "relu" (default), "tanh", "elu",
            "swish", or "linear" (same as None). May also already be a
            callable, in which case it is returned as-is.
        framework (str): One of "jax", "tf|tfe|tf2" or "torch".

    Returns:
        A framework-specific activation function, e.g. tf.nn.tanh or
        torch.nn.ReLU. None if name in ["linear", None].

    Raises:
        ValueError: If name is an unknown activation function.
    """
    # Already a callable, return as-is.
    if callable(name):
        return name

    # Infer the correct activation function from the string specifier.
    if framework == "torch":
        if name in ["linear", None]:
            return None
        if name == "swish":
            # Lazy import: Swish lives in torch_utils and pulls in torch.
            from ray.rllib.utils.torch_utils import Swish

            return Swish
        _, nn = try_import_torch()
        if name == "relu":
            return nn.ReLU
        elif name == "tanh":
            return nn.Tanh
        elif name == "elu":
            return nn.ELU
    elif framework == "jax":
        if name in ["linear", None]:
            return None
        jax, _ = try_import_jax()
        if name == "swish":
            return jax.nn.swish
        if name == "relu":
            return jax.nn.relu
        elif name == "tanh":
            # NOTE(review): "tanh" intentionally(?) maps to jax.nn.hard_tanh,
            # not jax.nn.tanh — confirm this asymmetry vs. tf/torch is wanted.
            return jax.nn.hard_tanh
        elif name == "elu":
            return jax.nn.elu
    else:
        assert framework in ["tf", "tfe", "tf2"], \
            "Unsupported framework `{}`!".format(framework)
        if name in ["linear", None]:
            return None
        tf1, tf, tfv = try_import_tf()
        # tf activations are looked up dynamically on tf.nn (e.g. tf.nn.relu).
        fn = getattr(tf.nn, name, None)
        if fn is not None:
            return fn

    # Unknown name for the given framework.
    raise ValueError(
        "Unknown activation ({}) for framework={}!".format(name, framework)
    )
def get_initializer(name, framework="tf"):
    """Returns a framework-specific initializer, given a name string.

    Args:
        name (str): One of "xavier_uniform" (default), "xavier_normal".
            May also already be a callable, in which case it is returned
            as-is.
        framework (str): One of "jax", "tf|tfe|tf2" or "torch".

    Returns:
        A framework-specific initializer function, e.g.
        tf.keras.initializers.GlorotUniform or
        torch.nn.init.xavier_uniform_.

    Raises:
        ValueError: If name is an unknown initializer.
    """
    # Already a callable, return as-is.
    if callable(name):
        return name

    # Fix: branches must be a single if/elif chain. Previously, an unknown
    # name under framework="jax" fell through into the torch/tf dispatch and
    # raised AssertionError ("Unsupported framework `jax`!") instead of the
    # documented ValueError below.
    if framework == "jax":
        _, flax = try_import_jax()
        assert flax is not None, \
            "`flax` not installed. Try `pip install jax flax`."
        import flax.linen as nn

        if name in [None, "default", "xavier_uniform"]:
            return nn.initializers.xavier_uniform()
        elif name == "xavier_normal":
            return nn.initializers.xavier_normal()
    elif framework == "torch":
        _, nn = try_import_torch()
        assert nn is not None, "`torch` not installed. Try `pip install torch`."
        if name in [None, "default", "xavier_uniform"]:
            return nn.init.xavier_uniform_
        elif name == "xavier_normal":
            return nn.init.xavier_normal_
    else:
        assert framework in ["tf", "tfe", "tf2"], \
            "Unsupported framework `{}`!".format(framework)
        tf1, tf, tfv = try_import_tf()
        assert tf is not None, \
            "`tensorflow` not installed. Try `pip install tensorflow`."
        if name in [None, "default", "xavier_uniform"]:
            return tf.keras.initializers.GlorotUniform
        elif name == "xavier_normal":
            return tf.keras.initializers.GlorotNormal

    # Fix: this is an initializer lookup, not an activation lookup — the
    # original message said "Unknown activation".
    raise ValueError(
        "Unknown initializer ({}) for framework={}!".format(name, framework)
    )
import logging import numpy as np import random import re import time import tree # pip install dm_tree from typing import Any, Dict, List, Optional, Sequence, Tuple, Union import yaml import ray from ray.rllib.utils.framework import try_import_jax, try_import_tf, \ try_import_torch from ray.rllib.utils.typing import PartialTrainerConfigDict from ray.tune import CLIReporter, run_experiments jax, _ = try_import_jax() tf1, tf, tfv = try_import_tf() if tf1: eager_mode = None try: from tensorflow.python.eager.context import eager_mode except (ImportError, ModuleNotFoundError): pass torch, _ = try_import_torch() logger = logging.getLogger(__name__) def framework_iterator( config: Optional[PartialTrainerConfigDict] = None,
import time from typing import Callable, Optional from ray.rllib.utils.framework import get_activation_fn, try_import_jax jax, flax = try_import_jax() nn = np = None if flax: import flax.linen as nn import jax.numpy as np class SlimFC: """Simple JAX version of a fully connected layer.""" def __init__( self, in_size, out_size, initializer: Optional[Callable] = None, activation_fn: Optional[str] = None, use_bias: bool = True, prng_key: Optional[jax.random.PRNGKey] = None, name: Optional[str] = None, ): """Initializes a SlimFC instance. Args: in_size (int): The input size of the input data that will be passed into this layer. out_size (int): The number of nodes in this FC layer. initializer (flax.:
import gym import logging import numpy as np from ray.rllib.utils.framework import try_import_jax, try_import_tf, \ try_import_torch jax = try_import_jax() tf1, tf, tfv = try_import_tf() if tf1: eager_mode = None try: from tensorflow.python.eager.context import eager_mode except (ImportError, ModuleNotFoundError): pass torch, _ = try_import_torch() logger = logging.getLogger(__name__) def framework_iterator(config=None, frameworks=("tf2", "tf", "tfe", "torch"), session=False): """An generator that allows for looping through n frameworks for testing. Provides the correct config entries ("framework") as well as the correct eager/non-eager contexts for tfe/tf. Args: config (Optional[dict]): An optional config dict to alter in place