import haiku as hk
import jax.numpy as jnp


def forward_pass(batch):
  x = batch['x']  # [batch_size, time_steps] token ids.
  x = jnp.transpose(x)  # [time_steps, batch_size].
  embedding_layer = hk.Embed(full_vocab_size, embed_size)
  embeddings = embedding_layer(x)  # [time_steps, batch_size, embed_size].

  lstm_layers = []
  for _ in range(lstm_num_layers):
    lstm_layers.extend([
        hk.LSTM(hidden_size=lstm_hidden_size),
        jnp.tanh,
        # Projection changes dimension from lstm_hidden_size to embed_size.
        hk.Linear(embed_size)
    ])
  rnn_core = hk.DeepRNN(lstm_layers)
  initial_state = rnn_core.initial_state(batch_size=embeddings.shape[1])
  # [time_steps, batch_size, embed_size].
  output, _ = hk.static_unroll(rnn_core, embeddings, initial_state)

  if share_input_output_embeddings:
    output = jnp.dot(output, jnp.transpose(embedding_layer.embeddings))
    output = hk.Bias(bias_dims=[-1])(output)
  else:
    output = hk.Linear(full_vocab_size)(output)

  # [batch_size, time_steps, full_vocab_size].
  output = jnp.transpose(output, axes=(1, 0, 2))
  return output
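# --- Hedged usage sketch (not from the original source) ---
# The snippet below shows how forward_pass above would typically be wrapped
# with hk.transform to obtain pure init/apply functions. The hyperparameter
# values and the batch shape are illustrative assumptions, not values taken
# from the original configuration.
import jax

full_vocab_size = 10000                # placeholder assumption.
embed_size = 128                       # placeholder assumption.
lstm_num_layers = 2                    # placeholder assumption.
lstm_hidden_size = 256                 # placeholder assumption.
share_input_output_embeddings = True   # placeholder assumption.


def example_usage():
  model = hk.transform(forward_pass)
  # Token ids of shape [batch_size, time_steps]; forward_pass transposes them.
  batch = {'x': jnp.zeros((8, 32), dtype=jnp.int32)}
  params = model.init(jax.random.PRNGKey(0), batch)
  # forward_pass uses no RNG at apply time, so rng can be None.
  logits = model.apply(params, None, batch)  # [batch_size, time_steps, full_vocab_size].
  return logits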
def __init__(self,
             model='lstm',
             ntoken=10000,
             nhid=650,
             nlayers=1,
             dropoute=0.0,
             dropouti=0.0,
             dropouth=0.0,
             dropouto=0.0,
             tie_weights=False,
             use_embeddings=True,
             with_bias=True):
  super().__init__()
  self.nhid = nhid
  self.ntoken = ntoken
  self.nlayers = nlayers
  # Dropout rates, stored here and applied in the forward pass.
  self.dropoute = dropoute
  self.dropouti = dropouti
  self.dropouth = dropouth
  self.dropouto = dropouto
  self.tie_weights = tie_weights
  self.use_embeddings = use_embeddings

  if model == 'lstm':
    self.layers = [LSTMCell(nhid) for _ in range(nlayers)]

  initrange = 0.1
  if use_embeddings:
    self.embedding = hk.Embed(
        ntoken, nhid,
        w_init=hk.initializers.RandomUniform(-initrange, initrange))

  if self.tie_weights:
    # With tied weights only an output bias is created here; the projection
    # matrix is shared with the embedding (see the sketch below).
    self.decoder_bias = hk.Bias(b_init=hk.initializers.Constant(0.0))
  else:
    self.decoder = hk.Linear(
        ntoken,
        with_bias=with_bias,
        # w_init=hk.initializers.RandomUniform(-initrange, initrange),
        # w_init=hk.initializers.RandomNormal(0.01),
        b_init=hk.initializers.Constant(0.0),
    )
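# --- Hedged sketch (not from the original source) ---
# The constructor above only builds parameters. When tie_weights is set, only
# decoder_bias is created, which suggests the output projection reuses the
# embedding matrix. The standalone helper below illustrates that decoding step
# under this assumption; `model` stands for an instance of the (unnamed) class
# above, and the function name and signature are hypothetical.
def tied_decode(model, hidden):
  """Maps hidden states of shape [batch, nhid] to logits of shape [batch, ntoken]."""
  if model.tie_weights:
    # Reuse the input embedding matrix, transposed, as the output projection.
    logits = jnp.dot(hidden, jnp.transpose(model.embedding.embeddings))
    return model.decoder_bias(logits)
  return model.decoder(hidden)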
        shape=(BATCH_SIZE, 2, 2)),
    ModuleDescriptor(
        name="nets.MLP",
        create=lambda: hk.nets.MLP([3, 4, 5]),
        shape=(BATCH_SIZE, 3)),
)

# Modules that require input to have a batch dimension.
BATCH_MODULES = (
    ModuleDescriptor(
        name="BatchNorm",
        create=lambda: Training(hk.BatchNorm(True, True, 0.9)),
        shape=(BATCH_SIZE, 2, 2, 3)),
    ModuleDescriptor(
        name="Bias",
        create=lambda: hk.Bias(),
        shape=(BATCH_SIZE, 3, 3, 3)),
    ModuleDescriptor(
        name="Flatten",
        create=lambda: hk.Flatten(),
        shape=(BATCH_SIZE, 3, 3, 3)),
    ModuleDescriptor(
        name="InstanceNorm",
        create=lambda: hk.InstanceNorm(True, True),
        shape=(BATCH_SIZE, 3, 2)),
    ModuleDescriptor(
        name="LayerNorm",
        create=lambda: hk.LayerNorm(1, True, True),
        shape=(BATCH_SIZE, 3, 2)),
    ModuleDescriptor(
        name="SpectralNorm",
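        # (The SpectralNorm descriptor continues in the original file.)
        # Hedged note (not from the original source): each ModuleDescriptor
        # pairs a name, a constructor, and an example input shape, so a test
        # harness could build each module inside a Haiku transform and feed it
        # a zero input of that shape, e.g. roughly:
        #
        #   for d in BATCH_MODULES:
        #     f = hk.transform_with_state(lambda x: d.create()(x))
        #     params, state = f.init(jax.random.PRNGKey(0), jnp.zeros(d.shape))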