Ejemplo n.º 1
0
            FcStack(fc_stack_ch, fc_stack_layers),
            Rnn(rnn_ch, rnn_type),
            FcStack(fc_stack_ch, fc_stack_layers),
        ]
        super().__init__(layers, **kwargs)


# ------------------ Utility Layers --------------------------------------------
@gin.register
class Identity(tfkl.Layer):
    """Pass-through layer: hands its input back without modification."""

    def call(self, x):
        passthrough = x
        return passthrough


# Make the stock Keras `Dense` layer available to gin configs, registered under
# this module's name.
gin.register(tfkl.Dense, module=__name__)


# ------------------ Embeddings ------------------------------------------------
def get_embedding(vocab_size=1024, n_dims=256):
    """Build an embedding layer that maps integer ids to dense vectors.

    Args:
        vocab_size: Number of distinct integer ids (Keras `input_dim`).
        n_dims: Size of each embedding vector (Keras `output_dim`).

    Returns:
        A `tfkl.Embedding` layer configured with `input_length=1`.
    """
    embedding_kwargs = dict(
        input_dim=vocab_size,
        output_dim=n_dims,
        input_length=1,
    )
    return tfkl.Embedding(**embedding_kwargs)


# ------------------ Normalization ---------------------------------------------
class ConditionalScaleAndShift(tfkl.Layer):
    """Conditional scaling and shifting after normalization."""
    def __init__(self, shift_only=False, **kwargs):
        super().__init__(**kwargs)
Ejemplo n.º 2
0
import gin
import tensorflow as tf

# Register each built-in Keras optimizer class with gin under the
# "tf.keras.optimizers" module name.
# NOTE: the loop variable `opt` remains bound at module scope after the loop.
for opt in (
        tf.keras.optimizers.Adadelta,
        tf.keras.optimizers.Adagrad,
        tf.keras.optimizers.Adam,
        tf.keras.optimizers.Adamax,
        tf.keras.optimizers.Ftrl,
        tf.keras.optimizers.Nadam,
        tf.keras.optimizers.RMSprop,
        tf.keras.optimizers.SGD,
):
    gin.register(opt, module="tf.keras.optimizers")

# Register the Keras L1 / L1L2 / L2 regularizer classes with gin under the
# "tf.keras.regularizers" module name.
# NOTE: the loop variable `reg` remains bound at module scope after the loop.
for reg in (
        tf.keras.regularizers.L1,
        tf.keras.regularizers.L1L2,
        tf.keras.regularizers.L2,
):
    gin.register(reg, module="tf.keras.regularizers")

for cb in (
        tf.keras.callbacks.CSVLogger,
        tf.keras.callbacks.EarlyStopping,
        tf.keras.callbacks.History,
        tf.keras.callbacks.LambdaCallback,
        tf.keras.callbacks.LearningRateScheduler,
        tf.keras.callbacks.ModelCheckpoint,
        tf.keras.callbacks.ProgbarLogger,
        tf.keras.callbacks.ReduceLROnPlateau,
Ejemplo n.º 3
0
# utility function registration


@gin.register(module="kb.utils")
def identity(x: T) -> T:
    """Return the argument unchanged (gin-configurable no-op)."""
    unchanged = x
    return unchanged


@gin.register(module="kb.utils")
def concat(a: Iterable[T], b: Iterable[T]) -> List[T]:
    """Return a new list with every element of `a` followed by every element of `b`.

    Both inputs are consumed in order (`a` first, then `b`); neither is mutated.
    """
    return [*a, *b]


# Expose the builtin `dict` constructor to gin configs under "kb.utils".
gin.register(dict, module="kb.utils")


@gin.register(name_or_fn="getattr", module="kb.utils")
def _getattr(object, name: str, default=None):  # pylint: disable=redefined-builtin
    """Gin-registered wrapper around the builtin `getattr` (exposed as "getattr").

    Because `default` is always forwarded, a missing attribute yields `default`
    (None unless overridden) rather than raising AttributeError.
    """
    attribute = getattr(object, name, default)
    return attribute


@gin.register(module="kb.utils")
def call(func: Callable, **kwargs):
    """Invoke `func` with the given keyword arguments and return its result.

    Lets a gin config express `func(**kwargs)` as a configurable value.
    """
    result = func(**kwargs)
    return result


class memoized_property(property):  # pylint: disable=invalid-name
    """Descriptor that mimics @property but caches output in member variable."""
Ejemplo n.º 4
0
import os
import tempfile
import uuid
from typing import Optional

import gin
import tensorflow as tf


@gin.configurable(module="os.path")
def join(a: str, p: str) -> str:
    """Gin-configurable counterpart of `os.path.join`, limited to two components.

    Args:
        a: Leading path component.
        p: Component appended to `a`.

    Returns:
        The joined path string.
    """
    combined = os.path.join(a, p)
    return combined


# Expose the stdlib home-directory and environment-variable expanders to gin
# configs under the "os.path" module name.
gin.register(os.path.expanduser, module="os.path")
gin.register(os.path.expandvars, module="os.path")


@gin.register(module="kb.path")
def expand(path):
    """Expand environment variables, then a leading `~`, in `path`.

    Variables are substituted first (`os.path.expandvars`), and the result is
    then passed through `os.path.expanduser`.
    """
    vars_expanded = os.path.expandvars(path)
    return os.path.expanduser(vars_expanded)


@gin.configurable(module="kb.path")
def run_subdir(run: int = 0):
    """Return the subdirectory name for a run index, e.g. 3 -> "run-03".

    The index is zero-padded to at least two digits.
    """
    subdir_name = f"run-{run:02d}"
    return subdir_name


@gin.register(module="kb.path")
def temp_dir(subdir: Optional[str] = "kblocks"):