Example 1
    def __init__(self, word_vocab_size: int, doc_vocab_size: int,
                 embedding_size: int, window_size: int, context_mode: str,
                 name: str):
        super().__init__(name=name)
        self.word_embedder = hk.Embed(vocab_size=word_vocab_size,
                                      embed_dim=embedding_size,
                                      name='word_embeddings')
        self.doc_embedder = hk.Embed(vocab_size=doc_vocab_size,
                                     embed_dim=embedding_size,
                                     name='doc_embeddings')
        self.fc = hk.Linear(output_size=word_vocab_size,
                            name='fully_connected')

        self.window_size = window_size
        self.embedding_size = embedding_size
        self.context_mode = context_mode
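
A hypothetical __call__ for this module (an assumption; the source shows only the constructor): in a doc2vec-style (PV-DM) model, the context word embeddings and the document embedding are combined and projected onto the word vocabulary.

    # Hypothetical sketch; assumes context_mode selects between 'sum' and 'concat'.
    def __call__(self, context_words, doc_ids):
        word_embs = self.word_embedder(context_words)  # [B, 2 * window_size, E]
        doc_embs = self.doc_embedder(doc_ids)          # [B, E]
        if self.context_mode == 'sum':
            hidden = jnp.sum(word_embs, axis=1) + doc_embs
        else:  # 'concat'
            flat = jnp.reshape(word_embs, (word_embs.shape[0], -1))
            hidden = jnp.concatenate([flat, doc_embs], axis=-1)
        return self.fc(hidden)  # [B, word_vocab_size]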
Example 2
    def forward_fn(data: Mapping[str, jnp.ndarray],
                   is_training: bool = True) -> jnp.ndarray:
        """Forward pass."""
        tokens = data['obs']
        input_mask = jnp.greater(tokens, 0)
        seq_length = tokens.shape[1]

        # Embed the input tokens and positions.
        embed_init = hk.initializers.TruncatedNormal(stddev=0.02)
        token_embedding_map = hk.Embed(vocab_size, d_model, w_init=embed_init)
        token_embs = token_embedding_map(tokens)
        positional_embeddings = hk.get_parameter('pos_embs',
                                                 [seq_length, d_model],
                                                 init=embed_init)
        input_embeddings = token_embs + positional_embeddings

        # Run the transformer over the inputs.
        transformer = model.Transformer(num_heads=num_heads,
                                        num_layers=num_layers,
                                        dropout_rate=dropout_rate)
        output_embeddings = transformer(input_embeddings, input_mask,
                                        is_training)

        # Reverse the embeddings (untied).
        return hk.Linear(vocab_size)(output_embeddings)
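
A minimal usage sketch (not part of the original snippet): a Haiku forward function like this one must be wrapped with hk.transform before it can be initialized or applied. The free variables (vocab_size, d_model, num_heads, num_layers, dropout_rate, model) are assumed to be bound in the enclosing scope, as in the snippet above.

import jax
import jax.numpy as jnp
import haiku as hk

forward = hk.transform(forward_fn)
rng = jax.random.PRNGKey(42)
data = {'obs': jnp.ones([8, 128], dtype=jnp.int32)}  # [batch, seq_length]
params = forward.init(rng, data)           # materialize all parameters
logits = forward.apply(params, rng, data)  # [batch, seq_length, vocab_size]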
Example 3
    def forward_pass(batch):
        x = batch['x']
        # [time_steps, batch_size, ...].
        x = jnp.transpose(x)
        # [time_steps, batch_size, embed_dim].
        embedding_layer = hk.Embed(full_vocab_size, embed_size)
        embeddings = embedding_layer(x)

        lstm_layers = []
        for _ in range(lstm_num_layers):
            lstm_layers.extend([
                hk.LSTM(hidden_size=lstm_hidden_size),
                jnp.tanh,
                # Projection changes dimension from lstm_hidden_size to embed_size.
                hk.Linear(embed_size)
            ])
        rnn_core = hk.DeepRNN(lstm_layers)
        initial_state = rnn_core.initial_state(batch_size=embeddings.shape[1])
        # [time_steps, batch_size, embed_size]: the trailing Linear in each
        # DeepRNN block projects back down to embed_size.
        output, _ = hk.static_unroll(rnn_core, embeddings, initial_state)

        if share_input_output_embeddings:
            output = jnp.dot(output, jnp.transpose(embedding_layer.embeddings))
            output = hk.Bias(bias_dims=[-1])(output)
        else:
            output = hk.Linear(full_vocab_size)(output)
        # [batch_size, time_steps, full_vocab_size].
        output = jnp.transpose(output, axes=(1, 0, 2))
        return output
Example 4
        def embed_forward(x):
            embed_init = hk.initializers.TruncatedNormal(stddev=0.02)

            seq_length = x.shape[1]
            positional_embeddings = hk.get_parameter('pos_embs', [seq_length, d_model], init=embed_init)

            o = hk.Embed(vocab, d_model, w_init=embed_init, name="embedding")(x) + positional_embeddings

            return o
Example 5
    def __init__(
            self,
            embed_dim: int,
            # vision
            image_resolution: int,
            vision_layers: int,
            vision_width: int,
            vision_patch_size: int,
            # text
            context_length: int,
            vocab_size: int,
            transformer_width: int,
            transformer_heads: int,
            transformer_layers: int):
        super().__init__()

        self.context_length = context_length

        vision_heads = vision_width // 64

        self.visual = VisualTransformer(input_resolution=image_resolution,
                                        patch_size=vision_patch_size,
                                        width=vision_width,
                                        layers=vision_layers,
                                        heads=vision_heads,
                                        output_dim=embed_dim,
                                        name="visual")

        self.transformer = Transformer(width=transformer_width,
                                       layers=transformer_layers,
                                       heads=transformer_heads,
                                       attn_mask=self.build_attention_mask(),
                                       name="transformer")

        self.vocab_size = vocab_size
        self.token_embedding = hk.Embed(vocab_size,
                                        transformer_width,
                                        name="token_embedding")

        scale = transformer_width**-0.5
        # Use the scale directly as the stddev; 1 / sqrt(scale) would grow
        # with transformer_width instead of shrinking.
        w_init = hk.initializers.TruncatedNormal(stddev=scale)
        self.positional_embedding = hk.get_parameter(
            "positional_embedding",
            shape=[self.context_length, transformer_width],
            init=w_init)
        self.ln_final = LayerNorm(-1,
                                  create_scale=True,
                                  create_offset=True,
                                  name="ln_final")

        self.text_projection = hk.get_parameter(
            "text_projection",
            shape=[transformer_width, embed_dim],
            init=w_init)
        self.logit_scale = hk.get_parameter("logit_scale",
                                            shape=[],
                                            init=hk.initializers.Constant(1))
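
The constructor above calls self.build_attention_mask(), which is not shown. A plausible sketch, assuming it mirrors CLIP's causal text mask: an additive mask that is 0 on and below the diagonal and -inf above it.

    def build_attention_mask(self):
        # Hypothetical: block attention to future positions by putting
        # -inf strictly above the diagonal (added before the softmax).
        mask = jnp.full((self.context_length, self.context_length), -jnp.inf)
        return jnp.triu(mask, k=1)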
Example 6
    def __init__(self, doc_vocab_size: int, word_vocab_size: int,
                 embedding_size: int, name: str):
        super().__init__(name=name)
        self.doc_embedder = hk.Embed(vocab_size=doc_vocab_size,
                                     embed_dim=embedding_size,
                                     name='doc_embeddings')
        self.fc = hk.Linear(output_size=word_vocab_size,
                            name='fully_connected')

        self.embedding_size = embedding_size
Example 7
 def forward(inputs, labels, extra):
   input_mask = (labels != 0).astype(jnp.float32)
   extra_mask = (extra != 0).astype(jnp.float32)
   extra = hk.Embed(vocab_size=extra_vocab_size, embed_dim=16)(extra)
   model = models.TransformerXL(
       vocab_size=vocab_size,
       emb_dim=16,
       num_layers=2,
       num_heads=4,
       cutoffs=[],
   )
   return model.loss(inputs, labels, mask=input_mask,
                     extra=extra, extra_mask=extra_mask)
Example 8
 def __init__(self, vocab_size, lstm_dim, dropout_rate, is_training=True):
   super().__init__()
   self.is_training = is_training
   self.embed = hk.Embed(vocab_size, lstm_dim)
   self.conv1 = hk.Conv1D(lstm_dim, 3, padding='SAME')
   self.conv2 = hk.Conv1D(lstm_dim, 3, padding='SAME')
   self.conv3 = hk.Conv1D(lstm_dim, 3, padding='SAME')
   self.bn1 = hk.BatchNorm(True, True, 0.9)
   self.bn2 = hk.BatchNorm(True, True, 0.9)
   self.bn3 = hk.BatchNorm(True, True, 0.9)
   self.lstm_fwd = hk.LSTM(lstm_dim)
   self.lstm_bwd = hk.ResetCore(hk.LSTM(lstm_dim))
   self.dropout_rate = dropout_rate
Example 9
def embeddings(data: Mapping[str, jnp.ndarray], vocab_size: int):
    tokens = data['obs']
    input_mask = jnp.greater(tokens, 0)
    seq_length = tokens.shape[1]

    # Embed the input tokens and positions.
    embed_init = hk.initializers.TruncatedNormal(stddev=0.02)
    token_embedding_map = hk.Embed(vocab_size, d_model, w_init=embed_init)
    token_embs = token_embedding_map(tokens)
    positional_embeddings = hk.get_parameter('pos_embs', [seq_length, d_model],
                                             init=embed_init)
    input_embeddings = token_embs + positional_embeddings
    return input_embeddings, input_mask
Example 10
  def __init__(self,
               word_embedding_matrix,
               sentence_dim=1024,
               name="text_module"):
    """Initialize text module.

    Args:
      word_embedding_matrix: 2d matrix [vocab_size, embed_size] to embed words.
      sentence_dim: dimension of sentence representation.
      name: module name.
    """
    super(TextModule, self).__init__(name=name)
    self._word_embedding_module = hk.Embed(
        embedding_matrix=word_embedding_matrix)
    self._conv1d_module = hk.Conv1D(sentence_dim, 1, name="text_conv1")
Example 11
    def __init__(self,
                 model='lstm',
                 ntoken=10000,
                 nhid=650,
                 nlayers=1,
                 dropoute=0.0,
                 dropouti=0.0,
                 dropouth=0.0,
                 dropouto=0.0,
                 tie_weights=False,
                 use_embeddings=True,
                 with_bias=True):
        super().__init__()
        self.nhid = nhid
        self.ntoken = ntoken
        self.nlayers = nlayers
        self.dropoute = dropoute
        self.dropouti = dropouti
        self.dropouth = dropouth
        self.dropouto = dropouto
        self.tie_weights = tie_weights
        self.use_embeddings = use_embeddings

        if model == 'lstm':
            self.layers = [LSTMCell(nhid) for _ in range(nlayers)]

        initrange = 0.1
        if use_embeddings:
            self.embedding = hk.Embed(ntoken,
                                      nhid,
                                      w_init=hk.initializers.RandomUniform(
                                          -initrange, initrange))

        if self.tie_weights:
            self.decoder_bias = hk.Bias(b_init=hk.initializers.Constant(0.0))
        else:
            self.decoder = hk.Linear(
                ntoken,
                with_bias=with_bias,
                # w_init=hk.initializers.RandomUniform(-initrange, initrange),
                # w_init=hk.initializers.RandomNormal(0.01),
                b_init=hk.initializers.Constant(0.0),
            )
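
For reference, the tie_weights branch would be used at decode time roughly as follows (a hypothetical sketch, reusing the embedding matrix as the output projection in the same style as Example 3):

        # Hypothetical decode step inside __call__ (not from the source):
        if self.tie_weights:
            # Reuse the input embedding matrix [ntoken, nhid] as the decoder.
            logits = jnp.dot(output, self.embedding.embeddings.T)
            logits = self.decoder_bias(logits)
        else:
            logits = self.decoder(output)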
Example 12
  def forward_pass(batch):
    x = batch['x']
    # [time_steps, batch_size, ...].
    x = jnp.transpose(x)
    # [time_steps, batch_size, embed_dim].
    embedding_layer = hk.Embed(full_vocab_size, embed_size)
    embeddings = embedding_layer(x)

    lstm_layers = []
    for _ in range(lstm_num_layers):
      lstm_layers.extend([hk.LSTM(hidden_size=lstm_hidden_size), jnp.tanh])
    rnn_core = hk.DeepRNN(lstm_layers)
    initial_state = rnn_core.initial_state(batch_size=embeddings.shape[1])
    # [time_steps, batch_size, hidden_size].
    output, _ = hk.static_unroll(rnn_core, embeddings, initial_state)

    output = hk.Linear(full_vocab_size)(output)
    # [batch_size, time_steps, full_vocab_size].
    output = jnp.transpose(output, axes=(1, 0, 2))
    return output
Example 13
    def forward_fn(data: Mapping[str, jnp.ndarray],
                   is_training: bool = True) -> jnp.ndarray:
        """Forward pass."""
        tokens = data['obs']
        input_mask = jnp.greater(tokens, 0)
        batch_size, seq_length = tokens.shape

        # Embed the input tokens and positions.
        embed_init = hk.initializers.TruncatedNormal(stddev=0.02)
        token_embedding_map = hk.Embed(vocab_size, d_model, w_init=embed_init)
        input_embeddings = token_embedding_map(tokens)
        positional_embeddings = hk.get_parameter('pos_embs',
                                                 [seq_length, d_model],
                                                 init=embed_init)

        x = input_embeddings + positional_embeddings
        h = jnp.zeros_like(x)

        # Create transformer block
        transformer_block = model.UTBlock(num_heads=num_heads,
                                          num_layers=num_layers,
                                          dropout_rate=dropout_rate)

        transformed_net = hk.transform(transformer_block)

        # lift params
        inner_params = hk.experimental.lift(transformed_net.init)(
            hk.next_rng_key(), h, x, input_mask, is_training)

        def f(_params, _rng, _z, *args):
            return transformed_net.apply(_params,
                                         _rng,
                                         _z,
                                         *args,
                                         is_training=is_training)

        z_star = deq(inner_params, hk.next_rng_key(), h, f, max_iter, x,
                     input_mask)

        # Reverse the embeddings (untied).
        return hk.Linear(vocab_size)(z_star)
Example 14
 def forward(batch, is_training):
   x, _ = batch
   batch_size = x.shape[0]
   x = hk.Embed(vocab_size=max_features, embed_dim=embedding_size)(x)
   x = hk.Conv1D(output_channels=num_filters, kernel_shape=kernel_size,
                 padding="VALID")(x)
   if use_swish:
       x = jax.nn.swish(x)
   else:
       x = jax.nn.relu(x)
   if use_maxpool:
       x = hk.MaxPool(
           window_shape=pool_size, strides=pool_size, padding='VALID',
           channel_axis=2)(x)
   x = jnp.moveaxis(x, 1, 0)  # [T, B, F]
   lstm_layer = hk.LSTM(hidden_size=cell_size)
   init_state = lstm_layer.initial_state(batch_size)
   x, state = hk.static_unroll(lstm_layer, x, init_state)
   x = x[-1]
   logits = hk.Linear(num_classes)(x)
   return logits
Example 15
    def __init__(self,
                 input_dim,
                 hidden_dim,
                 latent_dim,
                 max_num_segments,
                 temp_b=1.,
                 temp_z=1.,
                 latent_dist='gaussian',
                 name='compile'):
        super().__init__(name=name)

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim
        self.max_num_segments = max_num_segments
        self.temp_b = temp_b
        self.temp_z = temp_z
        self.latent_dist = latent_dist

        self.embed = hk.Embed(input_dim, hidden_dim)
        self.lstm_cell = hk.LSTM(hidden_dim)

        # LSTM output heads.
        self.head_z_1 = hk.Linear(hidden_dim)  # Latents (z).

        if latent_dist == 'gaussian':
            self.head_z_2 = hk.Linear(latent_dim * 2)
        elif latent_dist == 'concrete':
            self.head_z_2 = hk.Linear(latent_dim)
        else:
            raise ValueError('Invalid argument for `latent_dist`.')

        self.head_b_1 = hk.Linear(hidden_dim)  # Boundaries (b).
        self.head_b_2 = hk.Linear(1)

        # Decoder MLP.
        self.decode_1 = hk.Linear(hidden_dim)
        self.decode_2 = hk.Linear(input_dim)
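
The 'gaussian' head emits latent_dim * 2 units, conventionally split into a mean and a log-variance. A hedged sketch of the corresponding sampling step (an assumption; the forward pass is not shown in the source):

        # Hypothetical reparameterization step in the forward pass:
        mu, logvar = jnp.split(self.head_z_2(h), 2, axis=-1)
        eps = jax.random.normal(hk.next_rng_key(), mu.shape)
        z = mu + jnp.exp(0.5 * logvar) * eps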
Example 16
class ModuleDescriptor(NamedTuple):
  name: Any
  create: ModuleFn
  shape: Shape
  dtype: DType = jnp.float32


BATCH_SIZE = 8

# pylint: disable=unnecessary-lambda
# Modules that have equivalent behaviour with or without a batch dimension.
OPTIONAL_BATCH_MODULES = (
    ModuleDescriptor(
        name="Embed",
        create=lambda: hk.Embed(vocab_size=6, embed_dim=12),
        shape=(BATCH_SIZE,),
        dtype=jnp.int32),
    ModuleDescriptor(
        name="Linear",
        create=lambda: hk.Linear(10),
        shape=(BATCH_SIZE, 1)),
    ModuleDescriptor(
        name="Sequential",
        create=lambda: hk.Sequential([lambda x: x]),
        shape=(BATCH_SIZE, 2, 2)),
    ModuleDescriptor(
        name="nets.MLP",
        create=lambda: hk.nets.MLP([3, 4, 5]),
        shape=(BATCH_SIZE, 3)),
)
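
A sketch of how these descriptors might be consumed, e.g. in a parameterized test (hypothetical; the real harness is not shown): each module is built inside a transform and initialized on a dummy input of the declared shape and dtype.

import jax
import jax.numpy as jnp
import haiku as hk

for descriptor in OPTIONAL_BATCH_MODULES:
  # Bind via a default argument to avoid Python's late binding of closures.
  f = hk.transform_with_state(lambda x, d=descriptor: d.create()(x))
  x = jnp.zeros(descriptor.shape, descriptor.dtype)
  params, state = f.init(jax.random.PRNGKey(0), x)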
Example 17
        hk.Conv2D(16, (8, 8), padding='SAME', stride=(2, 2)),
        jax.nn.relu,
        hk.MaxPool(2, 1, padding='VALID'),  # matches stax
        hk.Conv2D(32, (4, 4), padding='VALID', stride=(2, 2)),
        jax.nn.relu,
        hk.MaxPool(2, 1, padding='VALID'),  # matches stax
        hk.Flatten(),
        hk.Linear(32),
        jax.nn.relu,
        hk.Linear(10),
    ])(features)


def lstm_model(x, vocab_size=10_000, seq_len=256, args=None, **_):
    embed_init = hk.initializers.TruncatedNormal(stddev=0.02)
    token_embedding_map = hk.Embed(vocab_size + 4, 100, w_init=embed_init)
    o2 = token_embedding_map(x)

    # Swap batch and time axes: [B, T, F] -> [T, B, F]. A reshape here would
    # scramble the data rather than transpose it.
    o2 = jnp.transpose(o2, (1, 0, 2))

    # LSTM Part of Network
    core = hk.LSTM(100)
    if args and args.dynamic_unroll:
        outs, state = hk.dynamic_unroll(core, o2,
                                        core.initial_state(x.shape[0]))
    else:
        outs, state = hk.static_unroll(core, o2,
                                       core.initial_state(x.shape[0]))
    # Swap back to [B, T, F] for pooling; again transpose, not reshape.
    outs = jnp.transpose(outs, (1, 0, 2))

    # Avg Pool -> Linear