Exemplo n.º 1
0
    def forward_pass(batch):
        """Embed the tokens, run them through a stacked LSTM, and score the vocabulary.

        Args:
            batch: Mapping whose 'x' entry holds the input token ids
                (presumably [batch_size, time_steps] -- confirm against the
                data pipeline).

        Returns:
            The vocabulary scores with the first two axes of the time-major
            RNN output swapped back, i.e. [batch_size, time_steps,
            full_vocab_size] under the layout assumed above.
        """
        # jnp.transpose without axes reverses the axis order; for 2-D input
        # this yields the time-major layout that static_unroll iterates over.
        tokens = jnp.transpose(batch['x'])
        embed = hk.Embed(full_vocab_size, embed_size)
        # [time_steps, batch_size, embed_size].
        embedded = embed(tokens)

        # Each layer group is LSTM -> tanh -> Linear; the Linear maps the
        # lstm_hidden_size output back to embed_size. Modules are built in the
        # order LSTM, Linear within every group.
        stack = [
            layer
            for _ in range(lstm_num_layers)
            for layer in (
                hk.LSTM(hidden_size=lstm_hidden_size),
                jnp.tanh,
                hk.Linear(embed_size),
            )
        ]
        core = hk.DeepRNN(stack)
        # Axis 1 of the time-major embeddings is used as the batch size.
        state0 = core.initial_state(batch_size=embedded.shape[1])
        # The final state is not needed. Because every group ends in
        # Linear(embed_size), the trailing dimension here is embed_size.
        unrolled, _ = hk.static_unroll(core, embedded, state0)

        if not share_input_output_embeddings:
            # Untied output layer: a separate projection onto the vocabulary.
            logits = hk.Linear(full_vocab_size)(unrolled)
        else:
            # Tied output layer: reuse the (transposed) input embedding matrix
            # as the projection, then add a bias over the last axis only.
            tied = jnp.dot(unrolled, jnp.transpose(embed.embeddings))
            logits = hk.Bias(bias_dims=[-1])(tied)

        # Swap the first two axes back (time-major -> batch-major).
        return jnp.transpose(logits, axes=(1, 0, 2))
Exemplo n.º 2
0
    def __init__(self,
                 model='lstm',
                 ntoken=10000,
                 nhid=650,
                 nlayers=1,
                 dropoute=0.0,
                 dropouti=0.0,
                 dropouth=0.0,
                 dropouto=0.0,
                 tie_weights=False,
                 use_embeddings=True,
                 with_bias=True):
        """Store the hyperparameters and build the sub-modules.

        Args:
            model: Recurrent cell type. Only 'lstm' is handled here; for any
                other value this constructor does not set ``self.layers``.
            ntoken: First argument to ``hk.Embed`` and output size of the
                decoder ``hk.Linear``.
            nhid: Size passed to each ``LSTMCell`` and second argument to
                ``hk.Embed``.
            nlayers: Number of ``LSTMCell`` instances to create.
            dropoute: Stored on the instance only; not used in this method.
            dropouti: Stored on the instance only; not used in this method.
            dropouth: Stored on the instance only; not used in this method.
            dropouto: Stored on the instance only; not used in this method.
                (The four dropout values are presumably embedding / input /
                hidden / output rates -- confirm where they are consumed.)
            tie_weights: If truthy, only a ``hk.Bias`` (``decoder_bias``) is
                created; otherwise a full ``hk.Linear`` (``decoder``).
            use_embeddings: Whether to create ``self.embedding``.
            with_bias: Forwarded to the decoder ``hk.Linear``; has no effect
                when ``tie_weights`` is set.
        """
        super().__init__()

        # Keep the configuration on the instance.
        self.nhid, self.ntoken, self.nlayers = nhid, ntoken, nlayers
        self.dropoute, self.dropouti = dropoute, dropouti
        self.dropouth, self.dropouto = dropouth, dropouto
        self.tie_weights, self.use_embeddings = tie_weights, use_embeddings

        # Recurrent stack: one cell per layer, all with nhid units.
        if model == 'lstm':
            self.layers = [LSTMCell(nhid) for _ in range(nlayers)]

        # Embedding weights are drawn uniformly from [-0.1, 0.1].
        if use_embeddings:
            init_range = 0.1
            uniform_init = hk.initializers.RandomUniform(-init_range,
                                                         init_range)
            self.embedding = hk.Embed(ntoken, nhid, w_init=uniform_init)

        # Both decoder variants start from a zero bias.
        zero_init = hk.initializers.Constant(0.0)
        if not self.tie_weights:
            # w_init is left at the hk.Linear default. Alternatives that were
            # left commented out here: RandomUniform(-0.1, 0.1) and
            # RandomNormal(0.01).
            self.decoder = hk.Linear(ntoken,
                                     with_bias=with_bias,
                                     b_init=zero_init)
        else:
            # With tied weights only the bias is created here.
            self.decoder_bias = hk.Bias(b_init=zero_init)
Exemplo n.º 3
0
        shape=(BATCH_SIZE, 2, 2)),
    ModuleDescriptor(
        name="nets.MLP",
        create=lambda: hk.nets.MLP([3, 4, 5]),
        shape=(BATCH_SIZE, 3)),
)

# Modules that require input to have a batch dimension.
BATCH_MODULES = (
    ModuleDescriptor(
        name="BatchNorm",
        create=lambda: Training(hk.BatchNorm(True, True, 0.9)),
        shape=(BATCH_SIZE, 2, 2, 3)),
    ModuleDescriptor(
        name="Bias",
        create=lambda: hk.Bias(),
        shape=(BATCH_SIZE, 3, 3, 3)),
    ModuleDescriptor(
        name="Flatten",
        create=lambda: hk.Flatten(),
        shape=(BATCH_SIZE, 3, 3, 3)),
    ModuleDescriptor(
        name="InstanceNorm",
        create=lambda: hk.InstanceNorm(True, True),
        shape=(BATCH_SIZE, 3, 2)),
    ModuleDescriptor(
        name="LayerNorm",
        create=lambda: hk.LayerNorm(1, True, True),
        shape=(BATCH_SIZE, 3, 2)),
    ModuleDescriptor(
        name="SpectralNorm",