示例#1
0
        info['success'] = 1.0 if reward > 0 else 0.
        return obs, reward, done, info

    def reset(self, **kwargs):
        """Reset the wrapped environment and the per-episode mission state.

        Args:
            **kwargs: Forwarded unchanged to ``self.env.reset``.

        Returns:
            The first observation after ``_transform_observation`` has been
            applied to it.
        """
        first_observation = self.env.reset(**kwargs)
        # Forget the previous episode's instruction and any tokens still
        # queued from it before the first observation is transformed.
        self._last_mission, self._tokens = '', []
        return self._transform_observation(first_observation)

    def _transform_observation(self, observation):
        """Convert the ``direction`` and ``mission`` entries in place.

        ``direction`` is cast to ``np.int64``. ``mission`` is replaced either
        by ``self._vectorize(mission)`` or, in one-token-per-step mode, by the
        next queued token (``np.int64(0)`` once the queue is empty).

        Args:
            observation: Mutable mapping holding ``'direction'`` and
                ``'mission'`` entries.

        Returns:
            The same ``observation`` object, modified in place.
        """
        observation['direction'] = np.int64(observation['direction'])
        instruction = observation['mission']

        if not self._one_token_per_step:
            observation['mission'] = self._vectorize(instruction)
            return observation

        # Note: the original BabyAI environment gives the same instruction at
        # every step of an episode, so new tokens are queued only when the
        # instruction text differs from the last one seen.
        if instruction != self._last_mission:
            if instruction != '':
                self._tokens.extend(self._tokenize(instruction))
            self._last_mission = instruction

        # Emit one queued token per step; 0 once nothing is left to emit.
        if self._tokens:
            next_token = self._tokens.pop(0)
        else:
            next_token = np.int64(0)
        observation['mission'] = next_token
        return observation


# Expose the wrapper's vocabulary size to gin configs under the constant name
# 'BabyAIWrapper.VOCAB_SIZE'.
gin.constant('BabyAIWrapper.VOCAB_SIZE', BabyAIWrapper.VOCAB_SIZE)
示例#2
0
# Normalization layers, registered with gin under the 'th.nn' module name.
gin.external_configurable(nn.SyncBatchNorm, module='th.nn')
gin.external_configurable(nn.InstanceNorm1d, module='th.nn')
gin.external_configurable(nn.InstanceNorm2d, module='th.nn')
gin.external_configurable(nn.InstanceNorm3d, module='th.nn')
gin.external_configurable(nn.LayerNorm, module='th.nn')
gin.external_configurable(nn.LocalResponseNorm, module='th.nn')

# Optimizers, registered with gin under the 'th.optim' module name.
gin.external_configurable(optim.Adadelta, module='th.optim')
gin.external_configurable(optim.Adagrad, module='th.optim')
gin.external_configurable(optim.Adam, module='th.optim')
gin.external_configurable(optim.SparseAdam, module='th.optim')
gin.external_configurable(optim.Adamax, module='th.optim')
gin.external_configurable(optim.ASGD, module='th.optim')
gin.external_configurable(optim.LBFGS, module='th.optim')
gin.external_configurable(optim.RMSprop, module='th.optim')
gin.external_configurable(optim.Rprop, module='th.optim')
gin.external_configurable(optim.SGD, module='th.optim')

# Constants
# Dtype objects exposed as gin constants named after their ``th.`` attribute.
gin.constant('th.float', th.float)
gin.constant('th.float16', th.float16)
gin.constant('th.float32', th.float32)
gin.constant('th.float64', th.float64)
gin.constant('th.int', th.int)
gin.constant('th.int8', th.int8)
gin.constant('th.int16', th.int16)
gin.constant('th.int32', th.int32)
gin.constant('th.int64', th.int64)
gin.constant('th.uint8', th.uint8)
# DataLoader
# Registered with gin under the 'torch.utils.data' module name.
gin.config.external_configurable(data.DataLoader, module='torch.utils.data')

# Transforms
# Each transform is registered with gin under 'torchvision.transforms'.
gin.config.external_configurable(transforms.CenterCrop,
                                 module='torchvision.transforms')
gin.config.external_configurable(transforms.ColorJitter,
                                 module='torchvision.transforms')
gin.config.external_configurable(transforms.FiveCrop,
                                 module='torchvision.transforms')
gin.config.external_configurable(transforms.Grayscale,
                                 module='torchvision.transforms')
gin.config.external_configurable(transforms.Normalize,
                                 module='torchvision.transforms')
gin.config.external_configurable(transforms.RandomGrayscale,
                                 module='torchvision.transforms')
gin.config.external_configurable(transforms.RandomHorizontalFlip,
                                 module='torchvision.transforms')
gin.config.external_configurable(transforms.RandomRotation,
                                 module='torchvision.transforms')
gin.config.external_configurable(transforms.RandomVerticalFlip,
                                 module='torchvision.transforms')
gin.config.external_configurable(transforms.Resize,
                                 module='torchvision.transforms')
# ``Scale`` is the deprecated predecessor of ``Resize`` and is absent from
# newer torchvision releases. Register it only when the installed version
# still provides it, so importing this module does not raise AttributeError.
if hasattr(transforms, 'Scale'):
    gin.config.external_configurable(transforms.Scale,
                                     module='torchvision.transforms')

# Class weights
# Both arrays are read at import time from .npy files in the current working
# directory and exposed to gin configs as tensors. The tensors are placed on
# the GPU when CUDA is available and on the CPU otherwise, so that importing
# this module does not fail on CPU-only hosts.
_CLASS_WEIGHTS_DEVICE = torch.device(
    'cuda' if torch.cuda.is_available() else 'cpu')
gin.constant('RARE_CLASS_WEIGHTS', torch.from_numpy(np.load('./rare_class_weights.npy')).to(_CLASS_WEIGHTS_DEVICE))
gin.constant('CLASS_WEIGHTS', torch.from_numpy(np.load('./class_weights.npy')).to(_CLASS_WEIGHTS_DEVICE))
示例#4
0
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""MinAtar environment made compatible for Dopamine."""

from dopamine.discrete_domains import atari_lib
from flax import linen as nn
import gin
import jax
import jax.numpy as jnp
import minatar


# Per-game observation shapes and the observation dtype, exposed as gin
# constants. The games differ only in the last dimension.
# NOTE(review): presumably (height, width, channels) — confirm against MinAtar.
gin.constant('minatar_env.ASTERIX_SHAPE', (10, 10, 4))
gin.constant('minatar_env.BREAKOUT_SHAPE', (10, 10, 4))
gin.constant('minatar_env.FREEWAY_SHAPE', (10, 10, 7))
gin.constant('minatar_env.SEAQUEST_SHAPE', (10, 10, 10))
gin.constant('minatar_env.SPACE_INVADERS_SHAPE', (10, 10, 6))
gin.constant('minatar_env.DTYPE', jnp.float64)


class MinAtarEnv(object):
  """Wrapper class for MinAtar environments."""

  def __init__(self, game_name):
    self.env = minatar.Environment(env_name=game_name)
    self.env.n = self.env.num_actions()
    self.game_over = False
示例#5
0
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Various networks for Jax Dopamine agents."""

from typing import Tuple, Union

from dopamine.discrete_domains import atari_lib
from flax import linen as nn
import gin
import jax
import jax.numpy as jnp
import numpy as onp

# Observation dtypes and per-dimension MIN/MAX value tuples for the
# classic-control environments, exposed as gin constants.
# NOTE(review): the MIN/MAX tuples are presumably used to rescale
# observations inside the networks — confirm against their users.
gin.constant('jax_networks.CARTPOLE_OBSERVATION_DTYPE', jnp.float64)
gin.constant('jax_networks.CARTPOLE_MIN_VALS',
             (-2.4, -5., -onp.pi / 12., -onp.pi * 2.))
gin.constant('jax_networks.CARTPOLE_MAX_VALS',
             (2.4, 5., onp.pi / 12., onp.pi * 2.))
gin.constant('jax_networks.ACROBOT_OBSERVATION_DTYPE', jnp.float64)
gin.constant('jax_networks.ACROBOT_MIN_VALS', (-1., -1., -1., -1., -5., -5.))
gin.constant('jax_networks.ACROBOT_MAX_VALS', (1., 1., 1., 1., 5., 5.))
gin.constant('jax_networks.LUNAR_OBSERVATION_DTYPE', jnp.float64)
gin.constant('jax_networks.MOUNTAINCAR_OBSERVATION_DTYPE', jnp.float64)
gin.constant('jax_networks.MOUNTAINCAR_MIN_VALS', (-1.2, -0.07))
gin.constant('jax_networks.MOUNTAINCAR_MAX_VALS', (0.6, 0.07))


### DQN Networks ###
@gin.configurable
示例#6
0
    Args:
      target: The target Vocabulary instance to be "translated" to.

    Returns:
      A list, mapping the ints of the self vocabulary to the ones of the target.
    """
    return [target.get(token, target.padding_code) for token in self._voc]


# Vocabulary with two special tokens ('<', '>') and '_' as padding, exposed to
# gin configs as 'vocabulary.proteins'.
# NOTE(review): the meaning of `order` is defined by Vocabulary, which is not
# visible here — confirm before changing it.
proteins = Vocabulary(
    tokens='LAVGESIRDTKPFNQYHMWCUOBZX',
    specials=('<', '>'),
    padding='_',
    order=(2, 0, 1)
)
gin.constant('vocabulary.proteins', proteins)


# Vocabulary with four special tokens ('.', '-', '<', '>') and '_' as padding,
# exposed to gin configs as 'vocabulary.alternative'. This instance is the
# default argument of `get_default` below.
alternative = Vocabulary(
    tokens='ACDEFGHIKLMNPQRSTVWYBOUXZ',
    specials=('.', '-', '<', '>'),
    padding='_',
    order=(0, 1, 2)
)
gin.constant('vocabulary.alternative', alternative)


@gin.configurable
def get_default(vocab=alternative):
  """Returns `vocab` unchanged so that gin can bind the default vocabulary.

  Args:
    vocab: The vocabulary to hand back. Defaults to the module-level
      `alternative` vocabulary.

  Returns:
    The `vocab` argument itself.
  """
  return vocab
示例#7
0
from tf_agents.environments import atari_preprocessing
from tf_agents.environments import atari_wrappers
from tf_agents.environments import py_environment
from tf_agents.environments import suite_gym

from tf_agents.typing import types

# Typical Atari 2600 Gym environment with some basic preprocessing.
DEFAULT_ATARI_GYM_WRAPPERS = (atari_preprocessing.AtariPreprocessing, )
# The following is just AtariPreprocessing with frame stacking. Performance-wise
# it's much better to have stacking implemented as part of replay-buffer/agent.
# As soon as this functionality in TF-Agents is ready and verified, this set of
# wrappers will be removed.
DEFAULT_ATARI_GYM_WRAPPERS_WITH_STACKING = DEFAULT_ATARI_GYM_WRAPPERS + (
    atari_wrappers.FrameStack4, )
# Make the stacking wrapper tuple available to gin configs by name.
gin.constant('DEFAULT_ATARI_GYM_WRAPPERS_WITH_STACKING',
             DEFAULT_ATARI_GYM_WRAPPERS_WITH_STACKING)


@gin.configurable
def game(name: Text = 'Pong',
         obs_type: Text = 'image',
         mode: Text = 'NoFrameskip',
         version: Text = 'v0') -> Text:
    """Generates the full name for the game.

  Args:
    name: String. Ex. Pong, SpaceInvaders, ...
    obs_type: String, type of observation. Ex. 'image' or 'ram'.
    mode: String. Ex. '', 'NoFrameskip' or 'Deterministic'.
    version: String. Ex. 'v0' or 'v4'.
示例#8
0
# Names of the four legs.
LEG_ORDER = ["front_left", "back_left", "front_right", "back_right"]

# Names of the four knee joints, exported below as the end effector names.
# NOTE(review): the leg ordering here (front_right, front_left, back_right,
# back_left) differs from LEG_ORDER — confirm no consumer indexes both with
# the same leg index.
END_EFFECTOR_NAMES = (
    "knee_front_rightR_joint",
    "knee_front_leftL_joint",
    "knee_back_rightR_joint",
    "knee_back_leftL_joint",
)

# The motor names are the joint names; all of them form a single motor group
# called "body_motors".
MOTOR_NAMES = JOINT_NAMES
MOTOR_GROUP = collections.OrderedDict((("body_motors", JOINT_NAMES),))

# NOTE(review): three-component points; frame and units are not visible
# here — presumably [x, y, z] in meters, confirm against the URDF.
KNEE_CONSTRAINT_POINT_LONG = [0, 0.0045, 0.088]
KNEE_CONSTRAINT_POINT_SHORT = [0, 0.0045, 0.100]

# Add the gin constants to be used for gin binding in config.
gin.constant("minitaur_constants.MINITAUR_URDF_PATH", MINITAUR_URDF_PATH)
gin.constant("minitaur_constants.MINITAUR_INIT_POSITION", INIT_POSITION)
gin.constant("minitaur_constants.MINITAUR_INIT_ORIENTATION_QUAT",
             INIT_ORIENTATION_QUAT)
gin.constant("minitaur_constants.MINITAUR_INIT_ORIENTATION_RPY",
             INIT_ORIENTATION_RPY)
gin.constant("minitaur_constants.MINITAUR_INIT_JOINT_ANGLES", INIT_JOINT_ANGLES)
gin.constant("minitaur_constants.MINITAUR_JOINT_DIRECTIONS", JOINT_DIRECTIONS)
gin.constant("minitaur_constants.MINITAUR_JOINT_OFFSETS", JOINT_OFFSETS)
gin.constant("minitaur_constants.MINITAUR_MOTOR_NAMES", MOTOR_NAMES)
gin.constant("minitaur_constants.MINITAUR_END_EFFECTOR_NAMES",
             END_EFFECTOR_NAMES)
gin.constant("minitaur_constants.MINITAUR_MOTOR_GROUP", MOTOR_GROUP)
示例#9
0
                                            shape=[self.REWARD_DIMENSION])

    def step(self, action):
        """Step the wrapped environment and return a three-component reward.

        Args:
            action: Passed through unchanged to ``self.env.step``.

        Returns:
            A tuple ``(obs, reward, done, info)`` where ``reward`` is a
            float32 array ``[env_reward, -info["cost"], success]`` and the
            other three items come straight from the wrapped environment.
        """
        obs, reward, done, info = self.env.step(action)
        # Get the second and third reward from ``info``
        cost_reward = -info["cost"]
        # ``goal_met`` may be missing from ``info`` on steps where the goal
        # was not reached (Safety Gym only sets it when the goal is met), so
        # a missing key counts as "no success" instead of raising KeyError.
        success_reward = float(info.get("goal_met", False))
        return obs, np.array([reward, cost_reward, success_reward],
                             dtype=np.float32), done, info

    @property
    def reward_space(self):
        """The reward space object stored in ``self._reward_space``."""
        return self._reward_space


# Expose the length of the vector reward to gin configs under the constant
# name 'SafetyGym.REWARD_DIMENSION'.
gin.constant('SafetyGym.REWARD_DIMENSION', VectorReward.REWARD_DIMENSION)


@gin.configurable
def load(environment_name,
         env_id=None,
         discount=1.0,
         max_episode_steps=None,
         unconstrained=False,
         gym_env_wrappers=(),
         alf_env_wrappers=()):
    """Loads the selected environment and wraps it with the specified wrappers.

    Note that by default a ``TimeLimit`` wrapper is used to limit episode lengths
    to the default benchmarks defined by the registered environments.
示例#10
0
import gin
import numpy as np
from events_tfds.events import asl_dvs, cifar10_dvs, mnist_dvs, ncaltech101, nmnist

# Grid shape of each event dataset, taken from the dataset module's
# GRID_SHAPE attribute.
gin.constant("ASL_DVS_GRID_SHAPE", asl_dvs.GRID_SHAPE)
gin.constant("CIFAR10_DVS_GRID_SHAPE", cifar10_dvs.GRID_SHAPE)
gin.constant("MNIST_DVS_GRID_SHAPE", mnist_dvs.GRID_SHAPE)
gin.constant("NCALTECH101_GRID_SHAPE", ncaltech101.GRID_SHAPE)
gin.constant("NMNIST_GRID_SHAPE", nmnist.GRID_SHAPE)

# Number of classes of each event dataset, taken from the dataset module's
# NUM_CLASSES attribute.
gin.constant("ASL_DVS_NUM_CLASSES", asl_dvs.NUM_CLASSES)
gin.constant("CIFAR10_DVS_NUM_CLASSES", cifar10_dvs.NUM_CLASSES)
gin.constant("MNIST_DVS_NUM_CLASSES", mnist_dvs.NUM_CLASSES)
gin.constant("NCALTECH101_NUM_CLASSES", ncaltech101.NUM_CLASSES)
gin.constant("NMNIST_NUM_CLASSES", nmnist.NUM_CLASSES)

# +/- pi/8, for configs that need the value as a named constant.
gin.constant("PI_ON_8", np.pi / 8)
gin.constant("NEG_PI_ON_8", -np.pi / 8)
示例#11
0
    (LEG_NAMES[0], JOINT_NAMES[0:3]),
    (LEG_NAMES[1], JOINT_NAMES[3:6]),
    (LEG_NAMES[2], JOINT_NAMES[6:9]),
    (LEG_NAMES[3], JOINT_NAMES[9:12]),
))

# Regulates the joint angle change when in position control mode.
# NOTE(review): unit is presumably radians per control step — confirm.
MAX_MOTOR_ANGLE_CHANGE_PER_STEP = 0.12

# The hip joint location in the CoM frame.
# Keyed by leg name; each value is a three-component position tuple.
HIP_POSITIONS = collections.OrderedDict((
    (LEG_NAMES[0], (0.21, -0.1157, 0)),
    (LEG_NAMES[1], (0.21, 0.1157, 0)),
    (LEG_NAMES[2], (-0.21, -0.1157, 0)),
    (LEG_NAMES[3], (-0.21, 0.1157, 0)),
))

# Add the gin constants to be used for gin binding in config. Append "LAIKAGO_"
# for unique binding names.
gin.constant("laikago_constants.LAIKAGO_NUM_MOTORS", NUM_MOTORS)
gin.constant("laikago_constants.LAIKAGO_URDF_PATH", URDF_PATH)
gin.constant("laikago_constants.LAIKAGO_INIT_POSITION", INIT_POSITION)
gin.constant("laikago_constants.LAIKAGO_INIT_ORIENTATION", INIT_ORIENTATION)
gin.constant("laikago_constants.LAIKAGO_INIT_JOINT_ANGLES", INIT_JOINT_ANGLES)
gin.constant("laikago_constants.LAIKAGO_JOINT_DIRECTIONS", JOINT_DIRECTIONS)
gin.constant("laikago_constants.LAIKAGO_JOINT_OFFSETS", JOINT_OFFSETS)
gin.constant("laikago_constants.LAIKAGO_MOTOR_NAMES", MOTOR_NAMES)
gin.constant("laikago_constants.LAIKAGO_END_EFFECTOR_NAMES",
             END_EFFECTOR_NAMES)
gin.constant("laikago_constants.LAIKAGO_MOTOR_GROUP", MOTOR_GROUP)
示例#12
0
from pybullet_envs.minitaur.envs_v2.evaluation import metric_logger
from pybullet_envs.minitaur.envs_v2.scenes import scene_base
from pybullet_envs.minitaur.envs_v2.sensors import sensor
from pybullet_envs.minitaur.envs_v2.sensors import space_utils
from pybullet_envs.minitaur.envs_v2.utilities import rendering_utils
from pybullet_envs.minitaur.robots import autonomous_object
from pybullet_envs.minitaur.robots import robot_base

# Module-private tuning constants.
# NOTE(review): their users are not visible from here; the names suggest an
# action tolerance, a solver iteration count and a log buffer size — confirm.
_ACTION_EPS = 0.01
_NUM_SIMULATION_ITERATION_STEPS = 300
_LOG_BUFFER_LENGTH = 5000

# String identifier for the simulation clock.
SIM_CLOCK = 'SIM_CLOCK'

# Exports this symbol so we can use it in the config file.
gin.constant('locomotion_gym_env.SIM_CLOCK', SIM_CLOCK)

# This allows us to bind @time.time in the gin configuration.
gin.external_configurable(time.time, module='time')


# TODO(b/122048194): Enable position/torque/hybrid control mode.
@gin.configurable
class LocomotionGymEnv(gym.Env):
    """The gym environment for the locomotion tasks."""
    metadata = {
        'render.modes': ['human', 'rgb_array', 'topdown'],
        'video.frames_per_second': 100
    }

    def __init__(self,
        return tf.data.Dataset.from_generator(
                generator,
                tf.uint8,
                tf.TensorShape([64,64,1]),
                ).map(utils.normalize_uint8, num_parallel_calls=num_parallel_calls).prefetch(prefetch_batches)

    @staticmethod
    @gin.configurable(module="DSprites")
    def supervised(num_parallel_calls=tf.data.experimental.AUTOTUNE, shuffle=None):
        """Build the labelled DSprites dataset.

        The dataset returned by ``DSprites.load()`` is mapped through
        ``utils.image_float32`` and then through ``DSprites.label_map``.

        Args:
            num_parallel_calls: Forwarded to both ``map`` calls.
            shuffle: Optional callable applied to the finished dataset; its
                return value is returned. ``None`` leaves the dataset as is.

        Returns:
            The mapped dataset, or ``shuffle(dataset)`` when ``shuffle`` is
            given.
        """
        float_images = DSprites.load().map(
            utils.image_float32, num_parallel_calls=num_parallel_calls)
        labelled = float_images.map(
            DSprites.label_map, num_parallel_calls=num_parallel_calls)
        return labelled if shuffle is None else shuffle(labelled)


    @staticmethod
    def label_map(element):
        """Reduce a dataset element to its image and a uint8 label tensor.

        Args:
            element: Mapping holding an ``"image"`` entry and one entry per
                name in ``DSprites.factors``.

        Returns:
            A dict with the original ``"image"`` and a ``"label"`` tensor
            built from the factor entries, in ``DSprites.factors`` order.
        """
        factor_values = [element[name] for name in DSprites.factors]
        label_tensor = tf.convert_to_tensor(factor_values, dtype=tf.uint8)
        return {"image": element["image"], "label": label_tensor}

# Expose DSprites.num_values_per_factor to gin configs under the same name.
gin.constant('DSprites.num_values_per_factor', DSprites.num_values_per_factor)
        "Leaky": jax.nn.leaky_relu(x, negative_slope=0.01),
        "ReLU": jax.nn.relu(x),
        "Sigmoid": jax.nn.sigmoid(x),
        "Swish": jax.nn.swish(x)
    },
    "initializers_layers": {
        "lecun_normal": nn.initializers.lecun_normal(),
        "lecun_uniform": nn.initializers.lecun_uniform(),
        "xavier_normal": nn.initializers.xavier_normal(),
        "xavier_uniform": nn.initializers.xavier_uniform()
    }
}

#---------------------------------------------------------------------------------------------------------------------

# Observation dtypes for the LunarLander and MountainCar environments,
# exposed as gin constants.
gin.constant('jax_networks.LUNALANDER_OBSERVATION_DTYPE', jnp.float64)
gin.constant('jax_networks.MOUNTAINCAR_OBSERVATION_DTYPE', jnp.float64)

#---------------------------------------------------------------------------------------------------------------------


@gin.configurable
class NoisyNetwork(nn.Module):
    def apply(self, x, features, bias=True, kernel_init=None):
        def sample_noise(shape):
            noise = jax.random.normal(random.PRNGKey(0), shape)
            return noise

        def f(x):
            return jnp.multiply(jnp.sign(x), jnp.power(jnp.abs(x), 0.5))
示例#15
0
L -1 -2 -3 -4 -1 -2 -3 -4 -3  2  4 -2  2  0 -3 -2 -1 -2 -1  1 -4 -3 -1 -4
K -1  2  0 -1 -3  1  1 -2 -1 -3 -2  5 -1 -3 -1  0 -1 -3 -2 -2  0  1 -1 -4
M -1 -1 -2 -3 -1  0 -2 -3 -2  1  2 -1  5  0 -2 -1 -1 -1 -1  1 -3 -1 -1 -4
F -2 -3 -3 -3 -2 -3 -3 -3 -1  0  0 -3  0  6 -4 -2 -2  1  3 -1 -3 -3 -1 -4
P -1 -2 -2 -1 -3 -1 -1 -2 -2 -3 -3 -1 -2 -4  7 -1 -1 -4 -3 -2 -2 -1 -2 -4
S  1 -1  1  0 -1  0  0  0 -1 -2 -2  0 -1 -2 -1  4  1 -3 -2 -2  0  0  0 -4
T  0 -1  0 -1 -1 -1 -1 -2 -2 -1 -1 -1 -1 -2 -1  1  5 -2 -2  0 -1 -1  0 -4
W -3 -3 -4 -4 -2 -2 -3 -2 -2 -3 -2 -3 -1  1 -4 -3 -2 11  2 -3 -4 -3 -2 -4
Y -2 -2 -2 -3 -2 -1 -2 -3  2 -1 -1 -2 -1  3 -3 -2 -2  2  7 -1 -3 -2 -1 -4
V  0 -3 -3 -3 -1 -2 -2 -3 -3  3  1 -2  1 -1 -2 -2  0 -3 -1  4 -3 -2 -1 -4
B -2 -1  3  4 -3  0  1 -1  0 -3 -4  0 -3 -3 -2  0 -1 -4 -3 -3  4  1 -1 -4
Z -1  0  0  1 -3  3  4 -2  0 -3 -3  1 -1 -3 -1  0 -1 -3 -2 -2  1  4 -1 -4
X  0 -1 -1 -1 -2 -1 -1 -1 -1 -1 -1 -1 -1 -1 -2  0  0 -2 -1 -1 -1 -1 -1 -4
* -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4  1
"""
# Expose the BLOSUM_62 matrix text defined above to gin configs.
gin.constant('BLOSUM_62', BLOSUM_62)


@gin.configurable
class HarmonicEmbeddings(tf.initializers.Initializer):
    """Initializes weights for sinusoidal positional embeddings.

  Attributes:
    scale_factor: angular frequencies for sinusoidal embeddings will be
      logarithmically spaced between max_freq x scale_factor and max_freq,
      with base equal to scale_factor.
    max_freq: the largest angular frequency to be used.
  """
    def __init__(self, scale_factor=1e-4, max_freq=1.0, **kwargs):
        super().__init__(**kwargs)
        self._scale_factor = scale_factor
示例#16
0
import os

import gin
import tensorflow as tf

from kblocks.gin_utils.config import try_register_config_dir

# Let gin configs refer to tf.data's AUTOTUNE sentinel by name.
gin.constant("AUTOTUNE", tf.data.experimental.AUTOTUNE)

# Absolute, symlink-resolved directory containing this file, registered as a
# config directory under the name "KB_CONFIG".
KB_CONFIG_DIR = os.path.realpath(os.path.dirname(__file__))
try_register_config_dir("KB_CONFIG", KB_CONFIG_DIR)
示例#17
0
import tensorflow as tf


# Hide every GPU from TensorFlow at import time. The warning is emitted
# unconditionally, before the attempt; only the device call sits inside the
# ``try`` because it is the only statement expected to raise NotFoundError.
logging.warning(
    'Setting tf to CPU only, to avoid OOM. '
    'See https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html '
    'for more information.')
try:
  tf.config.set_visible_devices([], 'GPU')
except tf.errors.NotFoundError:
  logging.info(
      'Unable to modify visible devices. '
      'If you don\'t have a GPU, this is expected.')


# Dtypes for image and state observations, exposed as gin constants.
gin.constant('sac_agent.IMAGE_DTYPE', onp.uint8)
gin.constant('sac_agent.STATE_DTYPE', onp.float32)


@functools.partial(jax.jit, static_argnums=(0, 1, 2))
def train(network_def: nn.Module,
          optim: optax.GradientTransformation,
          alpha_optim: optax.GradientTransformation,
          optimizer_state: jnp.ndarray,
          alpha_optimizer_state: jnp.ndarray,
          network_params: flax.core.FrozenDict,
          target_params: flax.core.FrozenDict,
          log_alpha: jnp.ndarray,
          key: jnp.ndarray,
          states: jnp.ndarray,
          actions: jnp.ndarray,
示例#18
0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Various networks for Jax Dopamine agents."""

from dopamine.discrete_domains import atari_lib
from dopamine.discrete_domains import gym_lib
from flax import nn
import gin
import jax
import jax.numpy as jnp
import numpy as onp

# Observation dtypes for the Cartpole and Acrobot environments, exposed as
# gin constants.
gin.constant('jax_networks.CARTPOLE_OBSERVATION_DTYPE', jnp.float64)
gin.constant('jax_networks.ACROBOT_OBSERVATION_DTYPE', jnp.float64)


@gin.configurable
class CartpoleDQNNetwork(nn.Module):
    """Jax DQN network for Cartpole."""
    def apply(self, x, num_actions):
        initializer = nn.initializers.xavier_uniform()
        # We need to add a "batch dimension" as nn.Conv expects it, yet vmap will
        # have removed the true batch dimension.
        x = x[None, ...]
        x = x.astype(jnp.float32)
        x = x.reshape((x.shape[0], -1))  # flatten
        x -= gym_lib.CARTPOLE_MIN_VALS
        x /= gym_lib.CARTPOLE_MAX_VALS - gym_lib.CARTPOLE_MIN_VALS
示例#19
0
        self._surface.blit(info_surface, (0, 0))
        v_offset = 4
        for item in texts:
            surface = self._font.render(item, True, (255, 255, 255))
            self._surface.blit(surface, (8, v_offset))
            v_offset += 18


def _exec(command):
    """Run ``command`` through ``os.popen`` and return what it printed.

    Args:
        command: Command line string handed to ``os.popen``. It is
            interpreted by the system shell, so it must not be built from
            untrusted input.

    Returns:
        The command's standard output as a single string.
    """
    # The context manager closes the pipe even when ``read`` raises.
    with os.popen(command) as stream:
        return stream.read()


# Expose Player.REWARD_DIMENSION to gin configs under the constant name
# 'CarlaEnvironment.REWARD_DIMENSION'.
gin.constant('CarlaEnvironment.REWARD_DIMENSION', Player.REWARD_DIMENSION)


@gin.configurable
class CarlaServer(object):
    """CarlaServer for doing the simulation."""
    def __init__(self,
                 rpc_port=2000,
                 streaming_port=2001,
                 docker_image="horizonrobotics/alf:0.0.3-carla",
                 quality_level="Low",
                 carla_root="/home/carla",
                 use_opengl=True):
        """

        Args:
示例#20
0
    """Enum constants for DQN target network variants.

  Attributes:
    notarget: No target network used. Next-step action-value computed using
      using online Q network.
    normal: Target network used to select action and evaluate next-step
      action-value.
    doubleq: Double-Q Learning as proposed by https://arxiv.org/abs/1509.06461.
      Action is selected by online Q network but evaluated using target network.
  """
    notarget = 'notarget'
    normal = 'normal'
    doubleq = 'doubleq'


# Make each DQNTarget variant selectable from gin configs by name.
gin.constant('DQNTarget.notarget', DQNTarget.notarget)
gin.constant('DQNTarget.normal', DQNTarget.normal)
gin.constant('DQNTarget.doubleq', DQNTarget.doubleq)


@gin.configurable
def discrete_q_graph(q_func,
                     transition,
                     target_network_type=DQNTarget.normal,
                     gamma=1.0,
                     loss_fn=tf.losses.huber_loss,
                     extra_callback=None):
    """Construct loss/summary graph for discrete Q-Learning (DQN).

  This Q-function loss implementation is derived from OpenAI baselines.
  This function supports dynamic batch sizes.
示例#21
0
from __future__ import print_function

import itertools
import math

from dopamine.discrete_domains import atari_lib
import gin
import gym
import numpy as np
import tensorflow as tf

# Per-dimension MIN/MAX value arrays: four entries for Cartpole, six for
# Acrobot.
# NOTE(review): presumably used to rescale observations — confirm at the
# call sites.
CARTPOLE_MIN_VALS = np.array([-2.4, -5., -math.pi / 12., -math.pi * 2.])
CARTPOLE_MAX_VALS = np.array([2.4, 5., math.pi / 12., math.pi * 2.])
ACROBOT_MIN_VALS = np.array([-1., -1., -1., -1., -5., -5.])
ACROBOT_MAX_VALS = np.array([1., 1., 1., 1., 5., 5.])
# Observation shape, observation dtype and stack size of each environment,
# exposed as gin constants.
gin.constant('gym_lib.CARTPOLE_OBSERVATION_SHAPE', (4, 1))
gin.constant('gym_lib.CARTPOLE_OBSERVATION_DTYPE', tf.float64)
gin.constant('gym_lib.CARTPOLE_STACK_SIZE', 1)
gin.constant('gym_lib.ACROBOT_OBSERVATION_SHAPE', (6, 1))
gin.constant('gym_lib.ACROBOT_OBSERVATION_DTYPE', tf.float64)
gin.constant('gym_lib.ACROBOT_STACK_SIZE', 1)


@gin.configurable
def create_gym_environment(environment_name=None, version='v0'):
    """Wraps a Gym environment with some basic preprocessing.

  Args:
    environment_name: str, the name of the environment to run.
    version: str, version of the environment to run.
    @staticmethod
    @gin.configurable(module="Shapes3d")
    def ordered(chunk_size=1000,
                prefetch_batches=1,
                num_parallel_calls=tf.data.experimental.AUTOTUNE):
        """Stream the Shapes3d images from disk in their stored order.

        Images are read from ``downloads/shapes3d.h5`` under the project data
        path, ``chunk_size`` images per HDF5 read, and yielded one at a time.

        Args:
            chunk_size: Number of images fetched from the file per read.
            prefetch_batches: Forwarded to ``Dataset.prefetch``.
            num_parallel_calls: Forwarded to the ``map`` call.

        Returns:
            A ``tf.data.Dataset`` of ``[64, 64, 3]`` images passed through
            ``utils.normalize_uint8``.
        """
        # The file must already exist locally. The original code carried a
        # disabled download from
        # https://storage.cloud.google.com/3d-shapes/3dshapes.h5
        h5_path = (disentangled.utils.get_data_path() / 'downloads' /
                   'shapes3d.h5')

        def generator():
            with h5py.File(h5_path, 'r') as data:
                total = len(data['images'])
                # Chunk boundaries 0, chunk_size, 2*chunk_size, ..., closed
                # off with ``total`` so the final partial chunk is included.
                boundaries = np.concatenate(
                    [np.arange(0, total + 1, chunk_size), [total]])
                for start, end in zip(boundaries[:-1], boundaries[1:]):
                    yield from data["images"][start:end]

        raw_images = tf.data.Dataset.from_generator(
            generator,
            tf.uint8,
            tf.TensorShape([64, 64, 3]),
        )
        normalized = raw_images.map(utils.normalize_uint8,
                                    num_parallel_calls=num_parallel_calls)
        return normalized.prefetch(prefetch_batches)


# Expose Shapes3d.num_values_per_factor to gin configs under the same name.
gin.constant('Shapes3d.num_values_per_factor', Shapes3d.num_values_per_factor)
示例#23
0
                                    strides=(2, 2),
                                    activation="relu"),
    tf.keras.layers.Conv2DTranspose(64,
                                    kernel_size=(4, 4),
                                    strides=(2, 2),
                                    activation="relu"),
])


@gin.configurable
def discriminator(latents, activation):
    """Build the discriminator: six 1000-unit dense layers, then a softmax.

    Args:
        latents: Length of the input vector; becomes ``input_shape=(latents,)``
            on the first layer.
        activation: Activation used by every 1000-unit layer.

    Returns:
        A ``tf.keras.Sequential`` model ending in a 2-unit softmax layer.
    """
    dense = tf.keras.layers.Dense
    # Layers are created in stacking order: the input layer, five more hidden
    # layers, then the two-way softmax output.
    stack = [dense(1000, activation=activation, input_shape=(latents, ))]
    stack.extend(dense(1000, activation=activation) for _ in range(5))
    stack.append(dense(2, activation="softmax"))
    return tf.keras.Sequential(stack)


# Expose the module-level network objects to gin configs under their fully
# qualified names.
gin.constant("disentangled.model.networks.conv_2", conv_2)
gin.constant("disentangled.model.networks.conv_2_transpose", conv_2_transpose)
gin.constant("disentangled.model.networks.conv_4", conv_4)
gin.constant("disentangled.model.networks.conv_4_transpose", conv_4_transpose)
gin.constant("disentangled.model.networks.conv_4_transpose_padded",
             conv_4_transpose_padded)