def forward_pass(batch):
  x = batch['x']  # [batch_size, time_steps].
  x = jnp.transpose(x)  # [time_steps, batch_size].
  embedding_layer = hk.Embed(full_vocab_size, embed_size)
  embeddings = embedding_layer(x)  # [time_steps, batch_size, embed_size].
  lstm_layers = []
  for _ in range(lstm_num_layers):
    lstm_layers.extend([
        hk.LSTM(hidden_size=lstm_hidden_size),
        jnp.tanh,
        # Projection changes dimension from lstm_hidden_size to embed_size.
        hk.Linear(embed_size),
    ])
  rnn_core = hk.DeepRNN(lstm_layers)
  initial_state = rnn_core.initial_state(batch_size=embeddings.shape[1])
  # [time_steps, batch_size, embed_size].
  output, _ = hk.static_unroll(rnn_core, embeddings, initial_state)
  if share_input_output_embeddings:
    # Tie the output projection to the input embedding matrix.
    output = jnp.dot(output, jnp.transpose(embedding_layer.embeddings))
    output = hk.Bias(bias_dims=[-1])(output)
  else:
    output = hk.Linear(full_vocab_size)(output)
  # [batch_size, time_steps, full_vocab_size].
  output = jnp.transpose(output, axes=(1, 0, 2))
  return output
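A minimal usage sketch, not part of the snippet above: wrapping a forward_pass like this with hk.transform and running it on a dummy batch. The hyperparameter values and batch shape below are illustrative assumptions.

import haiku as hk
import jax
import jax.numpy as jnp

# Assumed hyperparameters, for illustration only.
full_vocab_size, embed_size = 10_000, 128
lstm_hidden_size, lstm_num_layers = 256, 2
share_input_output_embeddings = True

model = hk.transform(forward_pass)
rng = jax.random.PRNGKey(0)
dummy_batch = {'x': jnp.zeros((8, 32), dtype=jnp.int32)}  # [batch_size, time_steps].
params = model.init(rng, dummy_batch)
logits = model.apply(params, rng, dummy_batch)  # [8, 32, full_vocab_size].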
def unroll(self, inputs: observation_action_reward.OAR,
           state: hk.LSTMState) -> Tuple[base.QValues, hk.LSTMState]:
  """Efficient unroll that applies torso, core, and duelling mlp in one pass."""
  embeddings = self._embed(inputs)  # [B?, T, D+A+1]
  core_outputs, new_states = hk.static_unroll(self._core, embeddings, state)
  q_values = self._duelling_head(core_outputs)  # [B?, T, A]
  return q_values, new_states
def unroll(self, inputs: observation_action_reward.OAR,
           state: hk.LSTMState) -> base.LSTMOutputs:
  """Efficient unroll that applies embeddings, MLP, & convnet in one pass."""
  embeddings = self._embed(inputs)
  embeddings, new_states = hk.static_unroll(self._core, embeddings, state)
  logits, values = self._head(embeddings)
  return (logits, values), new_states
def loss(params: hk.Params, sample: reverb.ReplaySample) -> jnp.ndarray:
  """Entropy-regularised actor-critic loss."""
  # Extract the data.
  observations, actions, rewards, discounts, extra = sample.data
  initial_state = tree.map_structure(lambda s: s[0], extra['core_state'])
  behaviour_logits = extra['logits']

  # Truncate to length T-1 so targets align with the bootstrapped values below.
  actions = actions[:-1]      # [T-1]
  rewards = rewards[:-1]      # [T-1]
  discounts = discounts[:-1]  # [T-1]
  rewards = jnp.clip(rewards, -max_abs_reward, max_abs_reward)

  # Unroll current policy over observations.
  net = functools.partial(network.apply, params)
  (logits, values), _ = hk.static_unroll(net, observations, initial_state)

  # Compute importance sampling weights: current policy / behaviour policy.
  rhos = rlax.categorical_importance_sampling_ratios(
      logits[:-1], behaviour_logits[:-1], actions)

  # Critic loss.
  vtrace_returns = rlax.vtrace_td_error_and_advantage(
      v_tm1=values[:-1],
      v_t=values[1:],
      r_t=rewards,
      discount_t=discounts * discount,
      rho_t=rhos)
  critic_loss = jnp.square(vtrace_returns.errors)

  # Policy gradient loss.
  policy_gradient_loss = rlax.policy_gradient_loss(
      logits_t=logits[:-1],
      a_t=actions,
      adv_t=vtrace_returns.pg_advantage,
      w_t=jnp.ones_like(rewards))

  # Entropy regulariser.
  entropy_loss = rlax.entropy_loss(logits[:-1], jnp.ones_like(rewards))

  # Combine weighted sum of actor & critic losses.
  mean_loss = jnp.mean(policy_gradient_loss +
                       baseline_cost * critic_loss +
                       entropy_cost * entropy_loss)
  return mean_loss
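A sketch, not from the source above, of how such a loss is typically consumed in a learner step: differentiate it with jax.grad and apply an optax update. `optimiser` is an assumed optax optimiser; all names here are illustrative.

import jax
import optax

def sgd_step(params, opt_state, sample):
  # Gradient of the scalar loss with respect to the parameters.
  gradients = jax.grad(loss)(params, sample)
  # Transform the gradients and apply them to the parameters.
  updates, new_opt_state = optimiser.update(gradients, opt_state)
  new_params = optax.apply_updates(params, updates)
  return new_params, new_opt_state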
def forward_pass(batch):
  x = batch['x']  # [batch_size, time_steps].
  x = jnp.transpose(x)  # [time_steps, batch_size].
  embedding_layer = hk.Embed(full_vocab_size, embed_size)
  embeddings = embedding_layer(x)  # [time_steps, batch_size, embed_size].
  lstm_layers = []
  for _ in range(lstm_num_layers):
    lstm_layers.extend([hk.LSTM(hidden_size=lstm_hidden_size), jnp.tanh])
  rnn_core = hk.DeepRNN(lstm_layers)
  initial_state = rnn_core.initial_state(batch_size=embeddings.shape[1])
  # [time_steps, batch_size, hidden_size].
  output, _ = hk.static_unroll(rnn_core, embeddings, initial_state)
  output = hk.Linear(full_vocab_size)(output)
  # [batch_size, time_steps, full_vocab_size].
  output = jnp.transpose(output, axes=(1, 0, 2))
  return output
def forward(batch, is_training):
  x, _ = batch
  batch_size = x.shape[0]
  x = hk.Embed(vocab_size=max_features, embed_dim=embedding_size)(x)
  x = hk.Conv1D(output_channels=num_filters, kernel_shape=kernel_size,
                padding="VALID")(x)
  if use_swish:
    x = jax.nn.swish(x)
  else:
    x = jax.nn.relu(x)
  if use_maxpool:
    x = hk.MaxPool(
        window_shape=pool_size, strides=pool_size, padding='VALID',
        channel_axis=2)(x)
  x = jnp.moveaxis(x, 1, 0)  # [T, B, F]
  lstm_layer = hk.LSTM(hidden_size=cell_size)
  init_state = lstm_layer.initial_state(batch_size)
  x, state = hk.static_unroll(lstm_layer, x, init_state)
  x = x[-1]  # Keep only the final timestep's output.
  logits = hk.Linear(num_classes)(x)
  return logits
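A brief sketch of one common way to drive a forward fn like the one above: bind is_training with functools.partial before hk.transform, then initialise and apply on a dummy batch. The hyperparameter values and shapes below are assumptions for illustration.

import functools
import haiku as hk
import jax
import jax.numpy as jnp

# Assumed hyperparameters, for illustration only.
max_features, embedding_size, num_filters, kernel_size = 20_000, 128, 64, 5
use_swish, use_maxpool, pool_size, cell_size, num_classes = False, True, 2, 128, 2

train_forward = hk.transform(functools.partial(forward, is_training=True))
dummy_batch = (jnp.zeros((16, 200), dtype=jnp.int32),  # Token ids, [B, T].
               jnp.zeros((16,), dtype=jnp.int32))      # Labels, [B].
rng = jax.random.PRNGKey(0)
params = train_forward.init(rng, dummy_batch)
logits = train_forward.apply(params, rng, dummy_batch)  # [16, num_classes].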
def unroll_fn(inputs, state):
  model = MyNetwork(spec.actions.num_values)
  return hk.static_unroll(model, inputs, state)
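A sketch of how an unroll_fn like this is commonly paired with an initial-state function and transformed. MyNetwork and spec come from the snippet above; the assumption here is that MyNetwork is an hk.RNNCore-style module exposing initial_state, and the wrapping below is illustrative.

import haiku as hk

def initial_state_fn(batch_size=None):
  model = MyNetwork(spec.actions.num_values)
  return model.initial_state(batch_size)

# Neither function draws random keys at apply time, so drop the rng argument.
unroll_hk = hk.without_apply_rng(hk.transform(unroll_fn))
initial_state_hk = hk.without_apply_rng(hk.transform(initial_state_fn))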
def lstm_model(x, vocab_size=10_000, seq_len=256, args=None, **_):
  embed_init = hk.initializers.TruncatedNormal(stddev=0.02)
  token_embedding_map = hk.Embed(vocab_size + 4, 100, w_init=embed_init)
  o2 = token_embedding_map(x)
  # Switch to the time-major layout, [T, B, E], expected by the unroll helpers.
  o2 = jnp.transpose(o2, (1, 0, 2))

  # LSTM part of the network.
  core = hk.LSTM(100)
  if args and args.dynamic_unroll:
    outs, state = hk.dynamic_unroll(core, o2, core.initial_state(x.shape[0]))
  else:
    outs, state = hk.static_unroll(core, o2, core.initial_state(x.shape[0]))
  # Back to batch-major, [B, T, H].
  outs = jnp.transpose(outs, (1, 0, 2))

  # Avg pool -> linear.
  red_dim_outs = hk.avg_pool(outs, seq_len, seq_len, "SAME").squeeze()
  final_layer = hk.Linear(2)
  ret = final_layer(red_dim_outs)
  return ret


def embedding_model(arr, vocab_size=10_000, seq_len=256, **_):
  x = arr
  embed_init = hk.initializers.TruncatedNormal(stddev=0.02)
  token_embedding_map = hk.Embed(vocab_size + 4, 16, w_init=embed_init)