info['success'] = 1.0 if reward > 0 else 0. return obs, reward, done, info def reset(self, **kwargs): obs = self.env.reset(**kwargs) self._last_mission = '' self._tokens = [] return self._transform_observation(obs) def _transform_observation(self, observation): # Note: The original BabyAI environment give the same instruction at every # steps of an episode. observation['direction'] = np.int64(observation['direction']) mission = observation['mission'] if not self._one_token_per_step: observation['mission'] = self._vectorize(mission) else: if mission != self._last_mission: if mission != '': self._tokens.extend(self._tokenize(mission)) self._last_mission = mission if len(self._tokens) > 0: observation['mission'] = self._tokens.pop(0) else: observation['mission'] = np.int64(0) return observation gin.constant('BabyAIWrapper.VOCAB_SIZE', BabyAIWrapper.VOCAB_SIZE)
# Normalization layers, exposed to gin under the 'th.nn' module prefix.
for _configurable in (
    nn.SyncBatchNorm,
    nn.InstanceNorm1d,
    nn.InstanceNorm2d,
    nn.InstanceNorm3d,
    nn.LayerNorm,
    nn.LocalResponseNorm,
):
    gin.external_configurable(_configurable, module='th.nn')

# Optimizers, exposed to gin under the 'th.optim' module prefix.
for _configurable in (
    optim.Adadelta,
    optim.Adagrad,
    optim.Adam,
    optim.SparseAdam,
    optim.Adamax,
    optim.ASGD,
    optim.LBFGS,
    optim.RMSprop,
    optim.Rprop,
    optim.SGD,
):
    gin.external_configurable(_configurable, module='th.optim')

# Constants: torch dtypes published as gin constants so config files can
# reference them by name.
for _gin_name, _dtype in (
    ('th.float', th.float),
    ('th.float16', th.float16),
    ('th.float32', th.float32),
    ('th.float64', th.float64),
    ('th.int', th.int),
    ('th.int8', th.int8),
    ('th.int16', th.int16),
    ('th.int32', th.int32),
    ('th.int64', th.int64),
    ('th.uint8', th.uint8),
):
    gin.constant(_gin_name, _dtype)

# The loop variables are scaffolding only; keep them out of the module namespace.
del _configurable, _gin_name, _dtype
# DataLoader
gin.config.external_configurable(data.DataLoader, module='torch.utils.data')

# Transforms
gin.config.external_configurable(transforms.CenterCrop, module='torchvision.transforms')
gin.config.external_configurable(transforms.ColorJitter, module='torchvision.transforms')
gin.config.external_configurable(transforms.FiveCrop, module='torchvision.transforms')
gin.config.external_configurable(transforms.Grayscale, module='torchvision.transforms')
gin.config.external_configurable(transforms.Normalize, module='torchvision.transforms')
gin.config.external_configurable(transforms.RandomGrayscale, module='torchvision.transforms')
gin.config.external_configurable(transforms.RandomHorizontalFlip, module='torchvision.transforms')
gin.config.external_configurable(transforms.RandomRotation, module='torchvision.transforms')
gin.config.external_configurable(transforms.RandomVerticalFlip, module='torchvision.transforms')
gin.config.external_configurable(transforms.Resize, module='torchvision.transforms')
# ``transforms.Scale`` is a deprecated alias of ``Resize`` that newer
# torchvision releases no longer ship. Register it only when it exists so that
# importing this module does not fail with AttributeError; ``Resize`` (above)
# is the replacement.
if hasattr(transforms, 'Scale'):
    gin.config.external_configurable(transforms.Scale, module='torchvision.transforms')

# Class weights
# Place the weight tensors on the GPU when one is available, otherwise on the
# CPU, so that merely importing this module works on CPU-only hosts.
# NOTE(review): the .npy paths are resolved against the current working
# directory, not this file's location -- confirm that is intended.
_CLASS_WEIGHTS_DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def _load_class_weights(path):
    """Load the NumPy array stored at ``path`` as a tensor on the weights device."""
    return torch.from_numpy(np.load(path)).to(_CLASS_WEIGHTS_DEVICE)


gin.constant('RARE_CLASS_WEIGHTS', _load_class_weights('./rare_class_weights.npy'))
gin.constant('CLASS_WEIGHTS', _load_class_weights('./class_weights.npy'))
# distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """MinAtar environment made compatible for Dopamine.""" from dopamine.discrete_domains import atari_lib from flax import linen as nn import gin import jax import jax.numpy as jnp import minatar gin.constant('minatar_env.ASTERIX_SHAPE', (10, 10, 4)) gin.constant('minatar_env.BREAKOUT_SHAPE', (10, 10, 4)) gin.constant('minatar_env.FREEWAY_SHAPE', (10, 10, 7)) gin.constant('minatar_env.SEAQUEST_SHAPE', (10, 10, 10)) gin.constant('minatar_env.SPACE_INVADERS_SHAPE', (10, 10, 6)) gin.constant('minatar_env.DTYPE', jnp.float64) class MinAtarEnv(object): """Wrapper class for MinAtar environments.""" def __init__(self, game_name): self.env = minatar.Environment(env_name=game_name) self.env.n = self.env.num_actions() self.game_over = False
# distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Various networks for Jax Dopamine agents.""" from typing import Tuple, Union from dopamine.discrete_domains import atari_lib from flax import linen as nn import gin import jax import jax.numpy as jnp import numpy as onp gin.constant('jax_networks.CARTPOLE_OBSERVATION_DTYPE', jnp.float64) gin.constant('jax_networks.CARTPOLE_MIN_VALS', (-2.4, -5., -onp.pi / 12., -onp.pi * 2.)) gin.constant('jax_networks.CARTPOLE_MAX_VALS', (2.4, 5., onp.pi / 12., onp.pi * 2.)) gin.constant('jax_networks.ACROBOT_OBSERVATION_DTYPE', jnp.float64) gin.constant('jax_networks.ACROBOT_MIN_VALS', (-1., -1., -1., -1., -5., -5.)) gin.constant('jax_networks.ACROBOT_MAX_VALS', (1., 1., 1., 1., 5., 5.)) gin.constant('jax_networks.LUNAR_OBSERVATION_DTYPE', jnp.float64) gin.constant('jax_networks.MOUNTAINCAR_OBSERVATION_DTYPE', jnp.float64) gin.constant('jax_networks.MOUNTAINCAR_MIN_VALS', (-1.2, -0.07)) gin.constant('jax_networks.MOUNTAINCAR_MAX_VALS', (0.6, 0.07)) ### DQN Networks ### @gin.configurable
Args: target: The target Vocabulary instance to be "translated" to. Returns: A list, mapping the ints of the self vocabulary to the ones of the target. """ return [target.get(token, target.padding_code) for token in self._voc] proteins = Vocabulary( tokens='LAVGESIRDTKPFNQYHMWCUOBZX', specials=('<', '>'), padding='_', order=(2, 0, 1) ) gin.constant('vocabulary.proteins', proteins) alternative = Vocabulary( tokens='ACDEFGHIKLMNPQRSTVWYBOUXZ', specials=('.', '-', '<', '>'), padding='_', order=(0, 1, 2) ) gin.constant('vocabulary.alternative', alternative) @gin.configurable def get_default(vocab = alternative): """A convenient function to gin configure the default vocabulary.""" return vocab
from tf_agents.environments import atari_preprocessing from tf_agents.environments import atari_wrappers from tf_agents.environments import py_environment from tf_agents.environments import suite_gym from tf_agents.typing import types # Typical Atari 2600 Gym environment with some basic preprocessing. DEFAULT_ATARI_GYM_WRAPPERS = (atari_preprocessing.AtariPreprocessing, ) # The following is just AtariPreprocessing with frame stacking. Performance wise # it's much better to have stacking implemented as part of replay-buffer/agent. # As soon as this functionality in TF-Agents is ready and verified, this set of # wrappers will be removed. DEFAULT_ATARI_GYM_WRAPPERS_WITH_STACKING = DEFAULT_ATARI_GYM_WRAPPERS + ( atari_wrappers.FrameStack4, ) gin.constant('DEFAULT_ATARI_GYM_WRAPPERS_WITH_STACKING', DEFAULT_ATARI_GYM_WRAPPERS_WITH_STACKING) @gin.configurable def game(name: Text = 'Pong', obs_type: Text = 'image', mode: Text = 'NoFrameskip', version: Text = 'v0') -> Text: """Generates the full name for the game. Args: name: String. Ex. Pong, SpaceInvaders, ... obs_type: String, type of observation. Ex. 'image' or 'ram'. mode: String. Ex. '', 'NoFrameskip' or 'Deterministic'. version: String. Ex. 'v0' or 'v4'.
LEG_ORDER = ["front_left", "back_left", "front_right", "back_right"]

# NOTE(review): this tuple is ordered front-right, front-left, back-right,
# back-left, which differs from LEG_ORDER above -- confirm consumers do not
# index the two in lockstep.
END_EFFECTOR_NAMES = (
    "knee_front_rightR_joint",
    "knee_front_leftL_joint",
    "knee_back_rightR_joint",
    "knee_back_leftL_joint",
)

# Every joint is driven by a motor, so motor names mirror the joint names and
# all of them form a single "body_motors" group.
MOTOR_NAMES = JOINT_NAMES
MOTOR_GROUP = collections.OrderedDict([("body_motors", JOINT_NAMES)])

KNEE_CONSTRAINT_POINT_LONG = [0, 0.0045, 0.088]
KNEE_CONSTRAINT_POINT_SHORT = [0, 0.0045, 0.100]

# Add the gin constants to be used for gin binding in config.
for _gin_name, _gin_value in (
    ("minitaur_constants.MINITAUR_URDF_PATH", MINITAUR_URDF_PATH),
    ("minitaur_constants.MINITAUR_INIT_POSITION", INIT_POSITION),
    ("minitaur_constants.MINITAUR_INIT_ORIENTATION_QUAT", INIT_ORIENTATION_QUAT),
    ("minitaur_constants.MINITAUR_INIT_ORIENTATION_RPY", INIT_ORIENTATION_RPY),
    ("minitaur_constants.MINITAUR_INIT_JOINT_ANGLES", INIT_JOINT_ANGLES),
    ("minitaur_constants.MINITAUR_JOINT_DIRECTIONS", JOINT_DIRECTIONS),
    ("minitaur_constants.MINITAUR_JOINT_OFFSETS", JOINT_OFFSETS),
    ("minitaur_constants.MINITAUR_MOTOR_NAMES", MOTOR_NAMES),
    ("minitaur_constants.MINITAUR_END_EFFECTOR_NAMES", END_EFFECTOR_NAMES),
    ("minitaur_constants.MINITAUR_MOTOR_GROUP", MOTOR_GROUP),
):
    gin.constant(_gin_name, _gin_value)

# The loop variables are scaffolding only; keep them out of the module namespace.
del _gin_name, _gin_value
shape=[self.REWARD_DIMENSION]) def step(self, action): obs, reward, done, info = self.env.step(action) # Get the second and third reward from ``info`` cost_reward = -info["cost"] success_reward = float(info["goal_met"]) return obs, np.array([reward, cost_reward, success_reward], dtype=np.float32), done, info @property def reward_space(self): return self._reward_space gin.constant('SafetyGym.REWARD_DIMENSION', VectorReward.REWARD_DIMENSION) @gin.configurable def load(environment_name, env_id=None, discount=1.0, max_episode_steps=None, unconstrained=False, gym_env_wrappers=(), alf_env_wrappers=()): """Loads the selected environment and wraps it with the specified wrappers. Note that by default a ``TimeLimit`` wrapper is used to limit episode lengths to the default benchmarks defined by the registered environments.
import gin
import numpy as np

from events_tfds.events import asl_dvs, cifar10_dvs, mnist_dvs, ncaltech101, nmnist

# Publish each event dataset's grid shape and class count, plus two angle
# constants (+/- pi/8), as gin constants so config files can bind them by name.
for _name, _value in (
    ("ASL_DVS_GRID_SHAPE", asl_dvs.GRID_SHAPE),
    ("CIFAR10_DVS_GRID_SHAPE", cifar10_dvs.GRID_SHAPE),
    ("MNIST_DVS_GRID_SHAPE", mnist_dvs.GRID_SHAPE),
    ("NCALTECH101_GRID_SHAPE", ncaltech101.GRID_SHAPE),
    ("NMNIST_GRID_SHAPE", nmnist.GRID_SHAPE),
    ("ASL_DVS_NUM_CLASSES", asl_dvs.NUM_CLASSES),
    ("CIFAR10_DVS_NUM_CLASSES", cifar10_dvs.NUM_CLASSES),
    ("MNIST_DVS_NUM_CLASSES", mnist_dvs.NUM_CLASSES),
    ("NCALTECH101_NUM_CLASSES", ncaltech101.NUM_CLASSES),
    ("NMNIST_NUM_CLASSES", nmnist.NUM_CLASSES),
    ("PI_ON_8", np.pi / 8),
    ("NEG_PI_ON_8", -np.pi / 8),
):
    gin.constant(_name, _value)

# The loop variables are scaffolding only; keep them out of the module namespace.
del _name, _value
(LEG_NAMES[0], JOINT_NAMES[0:3]), (LEG_NAMES[1], JOINT_NAMES[3:6]), (LEG_NAMES[2], JOINT_NAMES[6:9]), (LEG_NAMES[3], JOINT_NAMES[9:12]), )) # Regulates the joint angle change when in position control mode. MAX_MOTOR_ANGLE_CHANGE_PER_STEP = 0.12 # The hip joint location in the CoM frame. HIP_POSITIONS = collections.OrderedDict(( (LEG_NAMES[0], (0.21, -0.1157, 0)), (LEG_NAMES[1], (0.21, 0.1157, 0)), (LEG_NAMES[2], (-0.21, -0.1157, 0)), (LEG_NAMES[3], (-0.21, 0.1157, 0)), )) # Add the gin constants to be used for gin binding in config. Append "LAIKAGO_" # for unique binding names. gin.constant("laikago_constants.LAIKAGO_NUM_MOTORS", NUM_MOTORS) gin.constant("laikago_constants.LAIKAGO_URDF_PATH", URDF_PATH) gin.constant("laikago_constants.LAIKAGO_INIT_POSITION", INIT_POSITION) gin.constant("laikago_constants.LAIKAGO_INIT_ORIENTATION", INIT_ORIENTATION) gin.constant("laikago_constants.LAIKAGO_INIT_JOINT_ANGLES", INIT_JOINT_ANGLES) gin.constant("laikago_constants.LAIKAGO_JOINT_DIRECTIONS", JOINT_DIRECTIONS) gin.constant("laikago_constants.LAIKAGO_JOINT_OFFSETS", JOINT_OFFSETS) gin.constant("laikago_constants.LAIKAGO_MOTOR_NAMES", MOTOR_NAMES) gin.constant("laikago_constants.LAIKAGO_END_EFFECTOR_NAMES", END_EFFECTOR_NAMES) gin.constant("laikago_constants.LAIKAGO_MOTOR_GROUP", MOTOR_GROUP)
from pybullet_envs.minitaur.envs_v2.evaluation import metric_logger from pybullet_envs.minitaur.envs_v2.scenes import scene_base from pybullet_envs.minitaur.envs_v2.sensors import sensor from pybullet_envs.minitaur.envs_v2.sensors import space_utils from pybullet_envs.minitaur.envs_v2.utilities import rendering_utils from pybullet_envs.minitaur.robots import autonomous_object from pybullet_envs.minitaur.robots import robot_base _ACTION_EPS = 0.01 _NUM_SIMULATION_ITERATION_STEPS = 300 _LOG_BUFFER_LENGTH = 5000 SIM_CLOCK = 'SIM_CLOCK' # Exports this symbol so we can use it in the config file. gin.constant('locomotion_gym_env.SIM_CLOCK', SIM_CLOCK) # This allows us to bind @time.time in the gin configuration. gin.external_configurable(time.time, module='time') # TODO(b/122048194): Enable position/torque/hybrid control mode. @gin.configurable class LocomotionGymEnv(gym.Env): """The gym environment for the locomotion tasks.""" metadata = { 'render.modes': ['human', 'rgb_array', 'topdown'], 'video.frames_per_second': 100 } def __init__(self,
return tf.data.Dataset.from_generator( generator, tf.uint8, tf.TensorShape([64,64,1]), ).map(utils.normalize_uint8, num_parallel_calls=num_parallel_calls).prefetch(prefetch_batches) @staticmethod @gin.configurable(module="DSprites") def supervised(num_parallel_calls=tf.data.experimental.AUTOTUNE, shuffle=None): dataset = ( DSprites.load() .map(utils.image_float32, num_parallel_calls=num_parallel_calls) .map(DSprites.label_map, num_parallel_calls=num_parallel_calls) ) if shuffle is None: return dataset return shuffle(dataset) @staticmethod def label_map(element): labels = tf.convert_to_tensor( [element[f] for f in DSprites.factors], dtype=tf.uint8 ) return {"image": element["image"], "label": labels} gin.constant('DSprites.num_values_per_factor', DSprites.num_values_per_factor)
"Leaky": jax.nn.leaky_relu(x, negative_slope=0.01), "ReLU": jax.nn.relu(x), "Sigmoid": jax.nn.sigmoid(x), "Swish": jax.nn.swish(x) }, "initializers_layers": { "lecun_normal": nn.initializers.lecun_normal(), "lecun_uniform": nn.initializers.lecun_uniform(), "xavier_normal": nn.initializers.xavier_normal(), "xavier_uniform": nn.initializers.xavier_uniform() } } #--------------------------------------------------------------------------------------------------------------------- gin.constant('jax_networks.LUNALANDER_OBSERVATION_DTYPE', jnp.float64) gin.constant('jax_networks.MOUNTAINCAR_OBSERVATION_DTYPE', jnp.float64) #--------------------------------------------------------------------------------------------------------------------- @gin.configurable class NoisyNetwork(nn.Module): def apply(self, x, features, bias=True, kernel_init=None): def sample_noise(shape): noise = jax.random.normal(random.PRNGKey(0), shape) return noise def f(x): return jnp.multiply(jnp.sign(x), jnp.power(jnp.abs(x), 0.5))
L -1 -2 -3 -4 -1 -2 -3 -4 -3 2 4 -2 2 0 -3 -2 -1 -2 -1 1 -4 -3 -1 -4 K -1 2 0 -1 -3 1 1 -2 -1 -3 -2 5 -1 -3 -1 0 -1 -3 -2 -2 0 1 -1 -4 M -1 -1 -2 -3 -1 0 -2 -3 -2 1 2 -1 5 0 -2 -1 -1 -1 -1 1 -3 -1 -1 -4 F -2 -3 -3 -3 -2 -3 -3 -3 -1 0 0 -3 0 6 -4 -2 -2 1 3 -1 -3 -3 -1 -4 P -1 -2 -2 -1 -3 -1 -1 -2 -2 -3 -3 -1 -2 -4 7 -1 -1 -4 -3 -2 -2 -1 -2 -4 S 1 -1 1 0 -1 0 0 0 -1 -2 -2 0 -1 -2 -1 4 1 -3 -2 -2 0 0 0 -4 T 0 -1 0 -1 -1 -1 -1 -2 -2 -1 -1 -1 -1 -2 -1 1 5 -2 -2 0 -1 -1 0 -4 W -3 -3 -4 -4 -2 -2 -3 -2 -2 -3 -2 -3 -1 1 -4 -3 -2 11 2 -3 -4 -3 -2 -4 Y -2 -2 -2 -3 -2 -1 -2 -3 2 -1 -1 -2 -1 3 -3 -2 -2 2 7 -1 -3 -2 -1 -4 V 0 -3 -3 -3 -1 -2 -2 -3 -3 3 1 -2 1 -1 -2 -2 0 -3 -1 4 -3 -2 -1 -4 B -2 -1 3 4 -3 0 1 -1 0 -3 -4 0 -3 -3 -2 0 -1 -4 -3 -3 4 1 -1 -4 Z -1 0 0 1 -3 3 4 -2 0 -3 -3 1 -1 -3 -1 0 -1 -3 -2 -2 1 4 -1 -4 X 0 -1 -1 -1 -2 -1 -1 -1 -1 -1 -1 -1 -1 -1 -2 0 0 -2 -1 -1 -1 -1 -1 -4 * -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 1 """ gin.constant('BLOSUM_62', BLOSUM_62) @gin.configurable class HarmonicEmbeddings(tf.initializers.Initializer): """Initializes weights for sinusoidal positional embeddings. Attributes: scale_factor: angular frequencies for sinusoidal embeddings will be logarithmically spaced between max_freq x scale_factor and max_freq, with base equal to scale_factor. max_freq: the largest angular frequency to be used. """ def __init__(self, scale_factor=1e-4, max_freq=1.0, **kwargs): super().__init__(**kwargs) self._scale_factor = scale_factor
import os

import gin
import tensorflow as tf

from kblocks.gin_utils.config import try_register_config_dir

# Symlink-resolved absolute path of the directory holding this module.
KB_CONFIG_DIR = os.path.realpath(os.path.dirname(__file__))

# Expose tf.data's AUTOTUNE sentinel to gin configs under the name "AUTOTUNE".
gin.constant("AUTOTUNE", tf.data.experimental.AUTOTUNE)

# Hand the directory to kblocks under the "KB_CONFIG" key; presumably this lets
# gin configs in this directory be located through that name (semantics live
# in kblocks.gin_utils.config -- confirm there).
try_register_config_dir("KB_CONFIG", KB_CONFIG_DIR)
import tensorflow as tf try: logging.warning( ('Setting tf to CPU only, to avoid OOM. ' 'See https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html ' 'for more information.')) tf.config.set_visible_devices([], 'GPU') except tf.errors.NotFoundError: logging.info( ('Unable to modify visible devices. ' 'If you don\'t have a GPU, this is expected.')) gin.constant('sac_agent.IMAGE_DTYPE', onp.uint8) gin.constant('sac_agent.STATE_DTYPE', onp.float32) @functools.partial(jax.jit, static_argnums=(0, 1, 2)) def train(network_def: nn.Module, optim: optax.GradientTransformation, alpha_optim: optax.GradientTransformation, optimizer_state: jnp.ndarray, alpha_optimizer_state: jnp.ndarray, network_params: flax.core.FrozenDict, target_params: flax.core.FrozenDict, log_alpha: jnp.ndarray, key: jnp.ndarray, states: jnp.ndarray, actions: jnp.ndarray,
# Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Various networks for Jax Dopamine agents.""" from dopamine.discrete_domains import atari_lib from dopamine.discrete_domains import gym_lib from flax import nn import gin import jax import jax.numpy as jnp import numpy as onp gin.constant('jax_networks.CARTPOLE_OBSERVATION_DTYPE', jnp.float64) gin.constant('jax_networks.ACROBOT_OBSERVATION_DTYPE', jnp.float64) @gin.configurable class CartpoleDQNNetwork(nn.Module): """Jax DQN network for Cartpole.""" def apply(self, x, num_actions): initializer = nn.initializers.xavier_uniform() # We need to add a "batch dimension" as nn.Conv expects it, yet vmap will # have removed the true batch dimension. x = x[None, ...] x = x.astype(jnp.float32) x = x.reshape((x.shape[0], -1)) # flatten x -= gym_lib.CARTPOLE_MIN_VALS x /= gym_lib.CARTPOLE_MAX_VALS - gym_lib.CARTPOLE_MIN_VALS
self._surface.blit(info_surface, (0, 0)) v_offset = 4 for item in texts: surface = self._font.render(item, True, (255, 255, 255)) self._surface.blit(surface, (8, v_offset)) v_offset += 18 def _exec(command): stream = os.popen(command) ret = stream.read() stream.close() return ret gin.constant('CarlaEnvironment.REWARD_DIMENSION', Player.REWARD_DIMENSION) @gin.configurable class CarlaServer(object): """CarlaServer for doing the simulation.""" def __init__(self, rpc_port=2000, streaming_port=2001, docker_image="horizonrobotics/alf:0.0.3-carla", quality_level="Low", carla_root="/home/carla", use_opengl=True): """ Args:
"""Enum constants for DQN target network variants. Attributes: notarget: No target network used. Next-step action-value computed using using online Q network. normal: Target network used to select action and evaluate next-step action-value. doubleq: Double-Q Learning as proposed by https://arxiv.org/abs/1509.06461. Action is selected by online Q network but evaluated using target network. """ notarget = 'notarget' normal = 'normal' doubleq = 'doubleq' gin.constant('DQNTarget.notarget', DQNTarget.notarget) gin.constant('DQNTarget.normal', DQNTarget.normal) gin.constant('DQNTarget.doubleq', DQNTarget.doubleq) @gin.configurable def discrete_q_graph(q_func, transition, target_network_type=DQNTarget.normal, gamma=1.0, loss_fn=tf.losses.huber_loss, extra_callback=None): """Construct loss/summary graph for discrete Q-Learning (DQN). This Q-function loss implementation is derived from OpenAI baselines. This function supports dynamic batch sizes.
from __future__ import print_function import itertools import math from dopamine.discrete_domains import atari_lib import gin import gym import numpy as np import tensorflow as tf CARTPOLE_MIN_VALS = np.array([-2.4, -5., -math.pi / 12., -math.pi * 2.]) CARTPOLE_MAX_VALS = np.array([2.4, 5., math.pi / 12., math.pi * 2.]) ACROBOT_MIN_VALS = np.array([-1., -1., -1., -1., -5., -5.]) ACROBOT_MAX_VALS = np.array([1., 1., 1., 1., 5., 5.]) gin.constant('gym_lib.CARTPOLE_OBSERVATION_SHAPE', (4, 1)) gin.constant('gym_lib.CARTPOLE_OBSERVATION_DTYPE', tf.float64) gin.constant('gym_lib.CARTPOLE_STACK_SIZE', 1) gin.constant('gym_lib.ACROBOT_OBSERVATION_SHAPE', (6, 1)) gin.constant('gym_lib.ACROBOT_OBSERVATION_DTYPE', tf.float64) gin.constant('gym_lib.ACROBOT_STACK_SIZE', 1) @gin.configurable def create_gym_environment(environment_name=None, version='v0'): """Wraps a Gym environment with some basic preprocessing. Args: environment_name: str, the name of the environment to run. version: str, version of the environment to run.
@staticmethod @gin.configurable(module="Shapes3d") def ordered(chunk_size=1000, prefetch_batches=1, num_parallel_calls=tf.data.experimental.AUTOTUNE): fname = 'shapes3d.h5' file = disentangled.utils.get_data_path() / 'downloads' / fname # file = tf.keras.utils.get_file(str(path), 'https://storage.cloud.google.com/3d-shapes/3dshapes.h5') def generator(): with h5py.File(file, 'r') as data: chunk_idx = np.arange(0, len(data['images']) + 1, chunk_size) chunk_idx = np.concatenate([chunk_idx, [len(data['images'])]]) for i in range(len(chunk_idx) - 1): start = chunk_idx[i] end = chunk_idx[i + 1] for im in data["images"][start:end]: yield im return tf.data.Dataset.from_generator( generator, tf.uint8, tf.TensorShape([64, 64, 3]), ).map(utils.normalize_uint8, num_parallel_calls=num_parallel_calls).prefetch(prefetch_batches) gin.constant('Shapes3d.num_values_per_factor', Shapes3d.num_values_per_factor)
strides=(2, 2), activation="relu"), tf.keras.layers.Conv2DTranspose(64, kernel_size=(4, 4), strides=(2, 2), activation="relu"), ]) @gin.configurable def discriminator(latents, activation): return tf.keras.Sequential([ tf.keras.layers.Dense(1000, activation=activation, input_shape=(latents, )), tf.keras.layers.Dense(1000, activation=activation), tf.keras.layers.Dense(1000, activation=activation), tf.keras.layers.Dense(1000, activation=activation), tf.keras.layers.Dense(1000, activation=activation), tf.keras.layers.Dense(1000, activation=activation), tf.keras.layers.Dense(2, activation="softmax"), ]) gin.constant("disentangled.model.networks.conv_2", conv_2) gin.constant("disentangled.model.networks.conv_2_transpose", conv_2_transpose) gin.constant("disentangled.model.networks.conv_4", conv_4) gin.constant("disentangled.model.networks.conv_4_transpose", conv_4_transpose) gin.constant("disentangled.model.networks.conv_4_transpose_padded", conv_4_transpose_padded)