Example #1
def gin_register_keras_layers():
    """Registers all keras layers and Sequential to be referenceable in gin."""
    # Register sequential model.
    gin.external_configurable(tf.keras.Sequential, 'tf.keras.Sequential')

    # Register all the layers.
    for k, v in inspect.getmembers(tf.keras.layers):
        # Duck typing for tf.keras.layers.Layer since keras uses metaclasses.
        if hasattr(v, 'variables'):
            gin.external_configurable(v, f'tf.keras.layers.{k}')
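Note that the registration above does not modify tf.keras.layers in place; it makes each layer addressable from gin configs. A minimal sketch of how such a registration is typically consumed, assuming gin_register_keras_layers from the example above is in scope (build_model and its parameter are hypothetical):

import gin
import tensorflow as tf

@gin.configurable
def build_model(hidden_layer=gin.REQUIRED):
    # gin resolves hidden_layer to the registered Dense configurable.
    return tf.keras.Sequential([hidden_layer(), tf.keras.layers.Dense(1)])

gin_register_keras_layers()
gin.parse_config("""
build_model.hidden_layer = @tf.keras.layers.Dense
tf.keras.layers.Dense.units = 128
tf.keras.layers.Dense.activation = 'relu'
""")
model = build_model()  # Sequential([Dense(128, relu), Dense(1)])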
Example #2
def configure_and_register_env(env_class):
    register(
        id="{}-v0".format(env_class.__name__),
        entry_point="tensor2tensor.trax.rlax.envs:{}".format(
            env_class.__name__),
    )
    return gin.external_configurable(env_class, module="trax.rlax.envs")
Example #3
def main():
    args = parse()
    config_path = args.config_path

    gin.external_configurable(keras.optimizers.Adam,
                              module='tensorflow.python.keras.optimizers')
    gin.external_configurable(keras.losses.categorical_crossentropy,
                              module='tensorflow.python.keras.losses')
    gin.parse_config_file(config_path)

    # args = RunConfig()
    # args.config_path = config_path

    data = create_load_data(args)

    model = train_model(data, args)

    evaluate_model(data, model, args)
Example #4
def wrapped_items(
    src_module: ModuleType,
    gin_module: str,
    blacklist: Iterable[str] = BLACKLIST,
) -> Iterable[Tuple[str, Any]]:
    for k in dir(src_module):
        v = getattr(src_module, k)
        if k not in blacklist and callable(v):
            yield k, gin.external_configurable(v, name=k, module=gin_module)
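A hedged usage sketch for the generator above: rather than wrap C builtins (whose signatures gin cannot introspect), it wraps a tiny synthetic module built on the fly (the demo module and scale function are hypothetical):

import types
import gin

demo = types.ModuleType('demo')
def scale(x, factor=1.0):
    return x * factor
demo.scale = scale

# Registers every callable found in the module under gin module 'demo'.
wrapped = dict(wrapped_items(demo, gin_module='demo', blacklist=()))
gin.parse_config('demo.scale.factor = 10.0')
assert wrapped['scale'](3) == 30.0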
Example #5
def add_external_configurables():
    from .models.utils import get_class_weight
    gin.external_configurable(get_class_weight)

    gin.external_configurable(ReduceLROnPlateau)

    gin.external_configurable(Adam, blacklist=["params", "lr", "weight_decay"])
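The blacklist argument (renamed denylist in later gin releases) excludes parameters from gin control: params, lr, and weight_decay must still be supplied at the call site, while the remaining arguments stay bindable. A self-contained sketch of that split (make_optimizer is hypothetical):

import gin
import torch
from torch.optim import Adam

gin.external_configurable(Adam, blacklist=['params', 'lr', 'weight_decay'])

@gin.configurable
def make_optimizer(model, optimizer_cls=gin.REQUIRED, lr=1e-3):
    # Blacklisted args are passed in code; betas comes from the config.
    return optimizer_cls(model.parameters(), lr=lr)

gin.parse_config("""
make_optimizer.optimizer_cls = @Adam
Adam.betas = (0.9, 0.98)
""")
optimizer = make_optimizer(torch.nn.Linear(4, 2))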
Example #6
import asyncio
import math

import gin
import gym
import numpy as np

from alpacka import batch_steppers
from alpacka import data
from alpacka import metric_logging
from alpacka.agents import base
from alpacka.agents import core
from alpacka.utils.transformations import discount_cumsum

# Basic returns aggregators.
gin.external_configurable(np.max, module='np')
gin.external_configurable(np.mean, module='np')


@gin.configurable
@asyncio.coroutine
def truncated_return(episodes, discount=1.):
    """Returns sum of rewards up to the truncation of the episode."""
    return np.array([
        discount_cumsum(episode.transition_batch.reward, discount)[0]
        for episode in episodes
    ])


@gin.configurable
def bootstrap_return_with_value(episodes, discount=1.):
Example #7
def configure_agent(agent_class):
    return gin.external_configurable(
        agent_class, module='alpacka.agents'
    )
Example #8
def configure_env(env_class):
    return gin.external_configurable(env_class, module='alpacka.envs')
Example #9
def model_configure(*args, **kwargs):
    kwargs["module"] = "trax.models"
    return gin.external_configurable(*args, **kwargs)
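Wrappers like this pin every registration to a single gin namespace, so configs address symbols uniformly as trax.models.<Name>. A hedged sketch with a hypothetical class:

import gin

class MyModel:  # hypothetical
    def __init__(self, depth=2, width=64):
        self.depth, self.width = depth, width

# Rebinding the name means direct calls also observe gin bindings.
MyModel = model_configure(MyModel)

gin.parse_config('trax.models.MyModel.depth = 6')
model = MyModel()  # depth=6 from the config, width=64 by default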
Example #10
def data_configure(*args, **kwargs):
    kwargs['module'] = 'trax.data'
    return gin.external_configurable(*args, **kwargs)
Example #11
def layer_configure(*args, **kwargs):
    kwargs["module"] = "trax.layers"
    return gin.external_configurable(*args, **kwargs)
Example #12
import torch.nn as nn
import torch.nn.functional as F
from chk import checkpoint_sequential_step, checkpoint

import math
import numpy as np
from torchvision.utils import save_image

import gin


def ginM(n):
    return gin.query_parameter(f'%{n}')


gin.external_configurable(nn.MaxPool2d, module='nn')
gin.external_configurable(nn.Upsample, module='nn')


class LN(nn.Module):
    def forward(self, x):
        return F.layer_norm(x, x.size()[1:], weight=None, bias=None, eps=1e-05)


@gin.configurable
class PadPool(nn.Module):
    def forward(self, x):
        x = F.pad(x, [0, 0, 0, 1])
        x = F.max_pool2d(x, (2, 2), stride=(1, 2))
        return x
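The ginM helper above reads gin macros at call time. A minimal sketch of the macro round-trip it relies on (assuming gin.query_parameter resolves %name macro references as documented):

import gin

# A bare name = value line in a config defines a macro.
gin.parse_config('hidden = 256')
assert gin.query_parameter('%hidden') == 256  # what ginM('hidden') returns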
Example #13
                                    units_scale=4,
                                    unit_expansion_factor=2,
                                    network_depth=2):
    return [
        functools.partial(mlp,
                          units=[4 * units_scale * unit_expansion_factor**i] *
                          network_depth) for i in range(num_factories)
    ]


@gin.configurable
def get_base_global_network(units=(512, 256), dropout_impl=None):
    return mlp(units=units, dropout_impl=dropout_impl)


EmbeddingSpec = gin.external_configurable(
    collections.namedtuple('EmbeddingSpec', ['input_dim', 'output_dim']))


def as_residual(inp, output_dim, dense_factory=Dense):
    if inp.shape[-1] != output_dim:
        inp = dense_factory(output_dim)(inp)
    return inp


def residual_output(original, previous, current, dense_factory):
    terms = []
    if original is not None:
        original = as_residual(original, current.shape[-1], dense_factory)
        terms.append(original)
    terms.extend(previous)
    terms.append(current)
Example #14
def configure_trainer(trainer_class):
    return gin.external_configurable(
        trainer_class, module='alpacka.trainers'
    )
Example #15
def configure_rl(*args, **kwargs):
    kwargs['module'] = 'trax.rl'
    return gin.external_configurable(*args, **kwargs)
Example #16
# pylint: disable=g-bad-import-order, unused-import, g-multiple-import
# pylint: disable=line-too-long, missing-docstring, g-importing-member
# pylint: disable=g-wrong-blank-lines, missing-super-argument
import gin
import tensorflow.compat.v1 as tf
import tensorflow_probability as tfp
from functools import partial
from collections import OrderedDict
import numpy as np

from weak_disentangle import tensorsketch as ts
from weak_disentangle import utils as ut

tfd = tfp.distributions
dense = gin.external_configurable(ts.Dense)
conv = gin.external_configurable(ts.Conv2d)
deconv = gin.external_configurable(ts.ConvTranspose2d)
add_wn = gin.external_configurable(ts.WeightNorm.add)
add_bn = gin.external_configurable(ts.BatchNorm.add)


@gin.configurable
class Encoder(ts.Module):
  def __init__(self, x_shape, z_dim, width=1, spectral_norm=True):
    super().__init__()
    self.net = ts.Sequential(
        conv(32 * width, 4, 2, "same"), ts.LeakyReLU(),
        conv(32 * width, 4, 2, "same"), ts.LeakyReLU(),
        conv(64 * width, 4, 2, "same"), ts.LeakyReLU(),
        conv(64 * width, 4, 2, "same"), ts.LeakyReLU(),
Example #17
def model_configure(*args, **kwargs):
  kwargs['module'] = 'trax.models'
  return gin.external_configurable(*args, **kwargs)
Example #18
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Wrap external code in gin."""

import gin
import gin.tf.external_configurables
import tensorflow as tf

# Tensorflow.
gin.external_configurable(tf.keras.layers.experimental.SyncBatchNormalization)
Example #19
def configure_rl(*args, **kwargs):
    kwargs['module'] = 'trax.rl'
    kwargs['blacklist'] = ['task', 'output_dir']
    return gin.external_configurable(*args, **kwargs)
Example #20
def configure_and_register_env(env_class):
    register(
        id='{}-v0'.format(env_class.__name__),
        entry_point='trax.rl.envs:{}'.format(env_class.__name__),
    )
    return gin.external_configurable(env_class, module='trax.rl.envs')
Example #21
    Dict[Text, tf.Tensor], Dict[Text, tf.estimator.export.PredictOutput]]]
InferenceNetworkOutputsType = Union[DictOrSpec,
                                    Tuple[DictOrSpec,
                                          Optional[Sequence[tf.Tensor]]]]

try:
  flags.DEFINE_string('master', '', 'Master for TPU RunConfig')
except flags.DuplicateFlagError:
  pass

DEVICE_TYPE_CPU = 'cpu'
DEVICE_TYPE_GPU = 'gpu'
DEVICE_TYPE_TPU = 'tpu'

gin_configurable_run_config_cls = gin.external_configurable(
    tf.estimator.RunConfig,
    name='tf.estimator.RunConfig',
    blacklist=['model_dir'])

gin_configurable_tpu_run_config_cls = gin.external_configurable(
    contrib_tpu.RunConfig,
    name='tf.contrib.tpu.RunConfig',
    blacklist=['model_dir', 'tpu_config'])

gin_configurable_tpu_config_cls = gin.external_configurable(
    contrib_tpu.TPUConfig, name='tf.contrib.tpu.TPUConfig')

# Expose the tf.train.Saver to gin.
gin_configurable_saver = gin.external_configurable(
    tf.train.Saver, name='tf.train.Saver', whitelist=['save_relative_paths'])
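whitelist (later renamed allowlist) is the inverse restriction: only the listed parameters may be bound, keeping a wide API like tf.train.Saver otherwise opaque to configs. A hedged sketch against the registration above:

# Only the whitelisted parameter can be bound; binding any other
# tf.train.Saver argument fails when the config is parsed.
gin.parse_config('tf.train.Saver.save_relative_paths = True')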

Example #22
class ConstantLearningRateSchedule(LearningRateSchedule):
    """Learning rate schedule that just uses a constant value.

  Attributes:
    learning_rate: The constant learning rate.
  """
    learning_rate: float

    def learning_rate_for_step(self, step):
        return self.learning_rate


# Gin tries to wrap classes in subclasses, which flax can't serialize. So we
# wrap the constructor instead.
gin.external_configurable(functools.partial(ConstantLearningRateSchedule),
                          "ConstantLearningRateSchedule",
                          module="learning_rate_schedules")


@jax_util.register_dataclass_pytree
@dataclasses.dataclass
class InverseTimeLearningRateSchedule(LearningRateSchedule):
    """Learning rate schedule that decays as 1/(1 + t/c).

  Attributes:
    initial_rate: Initial learning rate.
    time_scale: Scale of the learning rate.
  """
    initial_rate: float
    time_scale: float
Example #23
    AbstractExportGenerator
], List[tf.estimator.Exporter]]

FLAGS = flags.FLAGS

try:
    flags.DEFINE_list(
        'gin_configs', None, 'A comma-separated list of paths to Gin '
        'configuration files.')
    flags.DEFINE_multi_string(
        'gin_bindings', [],
        'A newline separated list of Gin parameter bindings.')
except flags.DuplicateFlagError:
    pass

gin_configurable_eval_spec = gin.external_configurable(
    tf.estimator.EvalSpec, name='tf.estimator.EvalSpec')


def print_spec(tensor_spec):
    """Iterate over a spec and print its values in sorted order.

  Args:
    tensor_spec: A dict, (named)tuple, list or a hierarchy thereof filled by
      TensorSpecs(subclasses) or Tensors.
  """
    for key, value in sorted(
            tensorspec_utils.flatten_spec_structure(tensor_spec).items()):
        logging.info('%s: %s', key, value)


def print_specification(t2r_model):
Example #24
def configure_batch_stepper(batch_stepper_class):
    return gin.external_configurable(batch_stepper_class,
                                     module='alpacka.batch_steppers')
Example #25
def layer_configure(*args, **kwargs):
  kwargs['module'] = 'trax.layers'
  return gin.external_configurable(*args, **kwargs)
Example #26
        'resources_per_trial': {
            'cpu': cpus_per_trial,
            'gpu': gpus_per_trial,
        },
        'stop': {
            'training_iteration': max_epochs,
        },
        'config': tune_config,
        'num_samples': num_samples,
        'checkpoint_freq': checkpoint_freq,
        'local_dir': local_dir,
        'max_failures': max_failures,
    }


PopulationBasedTraining = gin.external_configurable(
    schedulers.PopulationBasedTraining)
AsyncHyperBandScheduler = gin.external_configurable(
    schedulers.AsyncHyperBandScheduler)
MedianStoppingRule = gin.external_configurable(schedulers.MedianStoppingRule)
HyperOptSearch = gin.external_configurable(HyperOptSearch)


@gin.configurable
def scheduler(cls,
              metric='val_sparse_categorical_accuracy',
              mode='max',
              time_attr='training_iteration',
              **kwargs):

    return cls(
        metric=metric,
Example #27
"""Make various external gin-configurable objects."""

import gin
import gin.tf.external_configurables
import gym
import tensorflow as tf
from tf_agents.specs.tensor_spec import BoundedTensorSpec
from tf_agents.networks.utils import mlp_layers
from tf_agents.networks.sequential_layer import SequentialLayer

from tf_agents.environments import atari_wrappers

from alf.utils import math_ops

tf.keras.layers.Conv2D = gin.external_configurable(tf.keras.layers.Conv2D,
                                                   'tf.keras.layers.Conv2D')
tf.optimizers.Adam = gin.external_configurable(tf.optimizers.Adam,
                                               'tf.optimizers.Adam')
gin.external_configurable(tf.keras.layers.Concatenate,
                          'tf.keras.layers.Concatenate')

# This allows the environment creation arguments to be configurable by supplying
# gym.envs.registration.EnvSpec.make.ARG_NAME=VALUE
gym.envs.registration.EnvSpec.make = gin.external_configurable(
    gym.envs.registration.EnvSpec.make, 'gym.envs.registration.EnvSpec.make')

# Activation functions.
gin.external_configurable(tf.math.exp, 'tf.math.exp')

gin.external_configurable(tf.TensorSpec, 'tf.TensorSpec')
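Unlike the examples that discard the return value, this file rebinds the module attributes (e.g. tf.keras.layers.Conv2D), so even direct Python calls route through gin. A hedged sketch:

import gin
import tensorflow as tf

gin.parse_config("tf.keras.layers.Conv2D.padding = 'same'")
# The rebound attribute is the gin wrapper, so the binding applies here:
conv = tf.keras.layers.Conv2D(filters=32, kernel_size=3)
assert conv.padding == 'same'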
Example #28

"""Optimizer factory functions to be used with tensor2robot models."""

from __future__ import absolute_import
from __future__ import division

from __future__ import print_function

import gin
import tensorflow as tf  # tf
from typing import Callable


gin_configurable_adam_optimizer = gin.external_configurable(
    tf.train.AdamOptimizer,
    name='tf.train.AdamOptimizer',
    blacklist=['learning_rate'])

gin_configurable_gradient_descent_optimizer = gin.external_configurable(
    tf.train.GradientDescentOptimizer,
    name='tf.train.GradientDescentOptimizer',
    blacklist=['learning_rate'])

gin_configurable_momentum_optimizer = gin.external_configurable(
    tf.train.MomentumOptimizer,
    name='tf.train.MomentumOptimizer',
    blacklist=['learning_rate'])


@gin.configurable
def create_constant_learning_rate(initial_learning_rate=0.0001):
Example #29
def opt_configure(*args, **kwargs):
    kwargs["module"] = "trax.optimizers"
    return gin.external_configurable(*args, **kwargs)
Example #30
# pylint: disable=invalid-name, missing-function-docstring, line-too-long

import abc
from collections.abc import Iterable  # pylint: disable=g-importing-member
import functools
from absl import logging
import gin
import jax
from jax import lax
from jax import random
import jax.numpy as jnp

import numpy as onp

# Nonlinear mappings encoding different attention kernels.
gin.external_configurable(jnp.cos, 'jcos')
gin.external_configurable(jnp.sin, 'jsin')
gin.external_configurable(jnp.tanh, 'jtanh')
gin.external_configurable(jax.nn.sigmoid, 'jsigmoid')
gin.external_configurable(
    lambda x: jax.nn.gelu(x, approximate=False), 'jgelu'
)  # Needs to be exact, although might be slower. See https://github.com/google/jax/issues/4428.
gin.external_configurable(lambda x: x * x * (x > 0.0), 'jrequ')
gin.external_configurable(jnp.exp, 'jexp')
gin.external_configurable(lambda x: x, 'jidentity')
gin.external_configurable(
    lambda x: (jnp.exp(x)) * (x <= 0.0) + (x + 1.0) * (x > 0.0), 'jshiftedelu'
)  # Nonlinearity used in "Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention" (https://arxiv.org/abs/2006.16236).


def nonnegative_softmax_kernel_feature_creator(data,