Example #1
    def forward_pass(batch):
        x = batch['x']
        # [time_steps, batch_size, ...].
        x = jnp.transpose(x)
        # [time_steps, batch_size, embed_dim].
        embedding_layer = hk.Embed(full_vocab_size, embed_size)
        embeddings = embedding_layer(x)

        lstm_layers = []
        for _ in range(lstm_num_layers):
            lstm_layers.extend([
                hk.LSTM(hidden_size=lstm_hidden_size),
                jnp.tanh,
                # Projection changes dimension from lstm_hidden_size to embed_size.
                hk.Linear(embed_size)
            ])
        rnn_core = hk.DeepRNN(lstm_layers)
        initial_state = rnn_core.initial_state(batch_size=embeddings.shape[1])
        # [time_steps, batch_size, embed_size] (the last DeepRNN layer is the projection).
        output, _ = hk.static_unroll(rnn_core, embeddings, initial_state)

        if share_input_output_embeddings:
            output = jnp.dot(output, jnp.transpose(embedding_layer.embeddings))
            output = hk.Bias(bias_dims=[-1])(output)
        else:
            output = hk.Linear(full_vocab_size)(output)
        # [batch_size, time_steps, full_vocab_size].
        output = jnp.transpose(output, axes=(1, 0, 2))
        return output
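A minimal sketch of how a forward pass like this is typically wrapped with hk.transform and run; the hyperparameter values and the dummy batch below are placeholders, not part of the original example:

import haiku as hk
import jax
import jax.numpy as jnp

# Hypothetical hyperparameters, chosen only for illustration.
full_vocab_size, embed_size = 10_000, 64
lstm_num_layers, lstm_hidden_size = 2, 128
share_input_output_embeddings = False

forward = hk.transform(forward_pass)
dummy_batch = {'x': jnp.zeros((8, 20), dtype=jnp.int32)}  # [batch_size, time_steps]
params = forward.init(jax.random.PRNGKey(0), dummy_batch)
logits = forward.apply(params, None, dummy_batch)  # [8, 20, full_vocab_size]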
Example #2
 def unroll(self, inputs: observation_action_reward.OAR,
            state: hk.LSTMState) -> Tuple[base.QValues, hk.LSTMState]:
     """Efficient unroll that applies torso, core, and duelling mlp in one pass."""
     embeddings = self._embed(inputs)  # [B?, T, D+A+1]
     core_outputs, new_states = hk.static_unroll(self._core, embeddings,
                                                 state)
     q_values = self._duelling_head(core_outputs)  # [B?, T, A]
     return q_values, new_states
Example #3
  def unroll(self, inputs: observation_action_reward.OAR,
             state: hk.LSTMState) -> base.LSTMOutputs:
    """Efficient unroll that applies embeddings, MLP, & convnet in one pass."""
    embeddings = self._embed(inputs)
    embeddings, new_states = hk.static_unroll(self._core, embeddings, state)
    logits, values = self._head(embeddings)

    return (logits, values), new_states
Example #4
        def loss(params: hk.Params,
                 sample: reverb.ReplaySample) -> jnp.ndarray:
            """Entropy-regularised actor-critic loss."""

            # Extract the data.
            observations, actions, rewards, discounts, extra = sample.data
            initial_state = tree.map_structure(lambda s: s[0],
                                               extra['core_state'])
            behaviour_logits = extra['logits']

            # Keep the first T-1 steps (so v_tm1/v_t below align) and clip rewards.
            actions = actions[:-1]  # [T-1]
            rewards = rewards[:-1]  # [T-1]
            discounts = discounts[:-1]  # [T-1]
            rewards = jnp.clip(rewards, -max_abs_reward, max_abs_reward)

            # Unroll current policy over observations.
            net = functools.partial(network.apply, params)
            (logits, values), _ = hk.static_unroll(net, observations,
                                                   initial_state)

            # Compute importance sampling weights: current policy / behavior policy.
            rhos = rlax.categorical_importance_sampling_ratios(
                logits[:-1], behaviour_logits[:-1], actions)

            # Critic loss.
            vtrace_returns = rlax.vtrace_td_error_and_advantage(
                v_tm1=values[:-1],
                v_t=values[1:],
                r_t=rewards,
                discount_t=discounts * discount,
                rho_t=rhos)
            critic_loss = jnp.square(vtrace_returns.errors)

            # Policy gradient loss.
            policy_gradient_loss = rlax.policy_gradient_loss(
                logits_t=logits[:-1],
                a_t=actions,
                adv_t=vtrace_returns.pg_advantage,
                w_t=jnp.ones_like(rewards))

            # Entropy regulariser.
            entropy_loss = rlax.entropy_loss(logits[:-1],
                                             jnp.ones_like(rewards))

            # Combine weighted sum of actor & critic losses.
            mean_loss = jnp.mean(policy_gradient_loss +
                                 baseline_cost * critic_loss +
                                 entropy_cost * entropy_loss)

            return mean_loss
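A sketch of how a loss like this is usually differentiated and applied; the optimiser and update function below are illustrative assumptions (any JAX-compatible optimiser would do), not part of the original learner:

import jax
import optax  # assumed optimiser library, not used in the snippet above

optimizer = optax.adam(1e-4)

def update_step(params, opt_state, sample):
    # Differentiate the scalar loss w.r.t. the Haiku parameters.
    grads = jax.grad(loss)(params, sample)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state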
Example #5
  def forward_pass(batch):
    x = batch['x']
    # [time_steps, batch_size, ...].
    x = jnp.transpose(x)
    # [time_steps, batch_size, embed_dim].
    embedding_layer = hk.Embed(full_vocab_size, embed_size)
    embeddings = embedding_layer(x)

    lstm_layers = []
    for _ in range(lstm_num_layers):
      lstm_layers.extend([hk.LSTM(hidden_size=lstm_hidden_size), jnp.tanh])
    rnn_core = hk.DeepRNN(lstm_layers)
    initial_state = rnn_core.initial_state(batch_size=embeddings.shape[1])
    # [time_steps, batch_size, hidden_size].
    output, _ = hk.static_unroll(rnn_core, embeddings, initial_state)

    output = hk.Linear(full_vocab_size)(output)
    # [batch_size, time_steps, full_vocab_size].
    output = jnp.transpose(output, axes=(1, 0, 2))
    return output
Example #6
 def forward(batch, is_training):
   x, _ = batch
   batch_size = x.shape[0]
   x = hk.Embed(vocab_size=max_features, embed_dim=embedding_size)(x)
   x = hk.Conv1D(output_channels=num_filters, kernel_shape=kernel_size,
                 padding="VALID")(x)
   if use_swish:
       x = jax.nn.swish(x)
   else:
       x = jax.nn.relu(x)
   if use_maxpool:
       x = hk.MaxPool(
           window_shape=pool_size, strides=pool_size, padding='VALID',
           channel_axis=2)(x)
   x = jnp.moveaxis(x, 1, 0)  # [T, B, F]
   lstm_layer = hk.LSTM(hidden_size=cell_size)
   init_state = lstm_layer.initial_state(batch_size)
   x, state = hk.static_unroll(lstm_layer, x, init_state)
   x = x[-1]
   logits = hk.Linear(num_classes)(x)
   return logits
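A sketch of the boilerplate usually needed to run such a forward function end to end; the hyperparameters and input shapes below are placeholders, not values from the original code:

import haiku as hk
import jax
import jax.numpy as jnp

# Placeholder hyperparameters for illustration only.
max_features, embedding_size = 20_000, 128
num_filters, kernel_size = 64, 5
use_swish, use_maxpool, pool_size = False, True, 4
cell_size, num_classes = 128, 2

model = hk.transform(forward)
tokens = jnp.zeros((32, 200), dtype=jnp.int32)  # [B, T] integer token ids
labels = jnp.zeros((32,), dtype=jnp.int32)      # unused by forward, but part of the batch tuple
params = model.init(jax.random.PRNGKey(0), (tokens, labels), is_training=True)
logits = model.apply(params, None, (tokens, labels), is_training=False)  # [B, num_classes]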
Example #7
 def unroll_fn(inputs, state):
     model = MyNetwork(spec.actions.num_values)
     return hk.static_unroll(model, inputs, state)

def lstm_model(x, vocab_size=10_000, seq_len=256, args=None, **_):
    embed_init = hk.initializers.TruncatedNormal(stddev=0.02)
    token_embedding_map = hk.Embed(vocab_size + 4, 100, w_init=embed_init)
    o2 = token_embedding_map(x)

    # [B, T, F] -> [T, B, F]: swap axes rather than reshape, since a reshape
    # would scramble batch and time instead of transposing them.
    o2 = jnp.swapaxes(o2, 0, 1)

    # LSTM Part of Network
    core = hk.LSTM(100)
    if args and args.dynamic_unroll:
        outs, state = hk.dynamic_unroll(core, o2,
                                        core.initial_state(x.shape[0]))
    else:
        outs, state = hk.static_unroll(core, o2,
                                       core.initial_state(x.shape[0]))
    # Back to [B, T, F] before pooling over the time axis.
    outs = jnp.swapaxes(outs, 0, 1)

    # Avg Pool -> Linear
    red_dim_outs = hk.avg_pool(outs, seq_len, seq_len, "SAME").squeeze()
    final_layer = hk.Linear(2)
    ret = final_layer(red_dim_outs)

    return ret
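The args.dynamic_unroll switch above trades trace size for compile time: hk.static_unroll runs the Python loop over all time steps at trace time, while hk.dynamic_unroll wraps the core in jax.lax.scan and keeps the compiled program small for long sequences. A sketch of how this model might be transformed and called, with purely illustrative shapes:

import haiku as hk
import jax
import jax.numpy as jnp

model = hk.transform(lstm_model)               # defaults: static unroll (args=None)
tokens = jnp.zeros((4, 256), dtype=jnp.int32)  # [B, T] token ids
params = model.init(jax.random.PRNGKey(0), tokens)
logits = model.apply(params, None, tokens)     # [B, 2] after average pooling and the final Linear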


def embedding_model(arr, vocab_size=10_000, seq_len=256, **_):
    x = arr
    embed_init = hk.initializers.TruncatedNormal(stddev=0.02)
    token_embedding_map = hk.Embed(vocab_size + 4, 16, w_init=embed_init)