Example #1
    def __call__(self, graphs: jraph.GraphsTuple) -> jraph.GraphsTuple:
        """Compute embeddings for each node in the graphs.

    Args:
      graphs: a set of graphs batched into a single graph.  The nodes and edges
        are represented as feature tensors.

    Returns:
      graphs: new graph with node embeddings updated (shape [n_nodes,
        embed_dim]).
    """
        nodes = hk.Linear(self._embed_dim)(graphs.nodes)
        edges = hk.Linear(self._embed_dim)(graphs.edges)

        nodes = hk.LayerNorm(axis=-1, create_scale=True,
                             create_offset=True)(jax.nn.gelu(nodes))
        edges = hk.LayerNorm(axis=-1, create_scale=True,
                             create_offset=True)(jax.nn.gelu(edges))

        graphs = graphs._replace(nodes=nodes, edges=edges)
        graphs = gn.SimpleGraphNet(
            num_layers=self._num_layers,
            msg_hidden_size_factor=self._msg_hidden_size_factor,
            layer_norm=self._use_layer_norm)(graphs)
        return graphs
Example #2
    def __init__(self, C, attention_fn, name='SlotAttention'):
        super().__init__(name=name)
        he_init = hk.initializers.VarianceScaling(scale=2.0)

        self.num_slots = C['slots']
        self.slot_size = C['slot_size'] 
        self.attn_eps = C['attention_eps']
        self.mlp_hidden_size = C['mlp_hidden_size']
        # Learnable mu and sigma (no covar) (dim slot_dim) to initialize slots

        # Learnable linear transforms for K,Q,V Attention - Use Glorot init here?
        self.k = hk.Linear(self.slot_size, w_init=he_init, with_bias=False)
        self.q = hk.Linear(self.slot_size, w_init=he_init, with_bias=False)
        self.v = hk.Linear(self.slot_size, w_init=he_init, with_bias=False)

        # Layer norm for slots after and before attention (GRU Omitted)
        self.layer_norm_in = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)
        self.layer_norm_1 = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)
        self.layer_norm_2 = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)

        # Slot update function is learned by GRU - #hidden_states = slot_dim
        self.mlp = hk.Sequential([
            hk.Linear(self.mlp_hidden_size, w_init=he_init), jax.nn.relu,  # MLP + Residual Connection improves output
            hk.Linear(self.slot_size, w_init=he_init)
        ])

        self.attention_fn = attention_fn
Example #3
    def __init__(self,
                 layer_sizes: Sequence[int],
                 w_init: hk.initializers.Initializer = uniform_initializer,
                 activation: Callable[[jnp.ndarray], jnp.ndarray] = jax.nn.elu,
                 activate_final: bool = False,
                 name: str = 'feedforward_mlp_torso'):
        """Construct the MLP.

    Args:
      layer_sizes: a sequence of ints specifying the size of each layer.
      w_init: initializer for Linear layers.
      activation: nonlinearity to use in the MLP, defaults to elu.
        Note! The default activation differs from the usual MLP default of ReLU
        for legacy reasons.
      activate_final: whether or not to use the activation function on the final
        layer of the neural network.
      name: a name for the module.
    """
        super().__init__(name=name)

        self._network = hk.Sequential([
            hk.Linear(layer_sizes[0], w_init=w_init),
            hk.LayerNorm(axis=-1, create_scale=True, create_offset=True),
            jax.lax.tanh,
            hk.nets.MLP(layer_sizes[1:],
                        w_init=w_init,
                        activation=activation,
                        activate_final=activate_final),
        ])
Example #4
 def large_critic(x):
     # inspired by the ones used in RL Unplugged
     x = hk.Sequential([
         hk.Linear(400, w_init=rlu_uniform_initializer),
         hk.LayerNorm(axis=-1, create_scale=True, create_offset=True),
         jax.lax.tanh,
     ])(x)
     x = hk.Linear(1024, w_init=rlu_uniform_initializer)(x)
     for i in range(4):
         x = network_utils.ResidualLayerNormBlock(
             [1024, 1024],
             activation=jax.nn.relu,
             w_init=rlu_uniform_initializer,
         )(x)
     h = x
     # v = hk.Linear(1, w_init=rlu_uniform_initializer)(h)
     # v = hk.Linear(critic_output_dim)(h)
     all_vs = []
     for _ in range(critic_output_dim):
         head_v = hk.Linear(256, w_init=rlu_uniform_initializer)(h)
         head_v = jax.nn.relu(head_v)
         head_v = hk.Linear(1, w_init=rlu_uniform_initializer)(head_v)
         all_vs.append(head_v)
     v = jnp.concatenate(all_vs, axis=-1)
     return v, h
Example #5
 def base_network(x, layers):
   x = hk.nets.MLP(layers[:-1], activate_final=True)(x)
   x = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)(x)
   x = hk.Linear(layers[-1])(x)
   x = jax.nn.relu(x)
   x = hk.Linear(self._num_actions)(x)
   return x
Example #6
    def __init__(self, C, position_enc_fn, name=None):
        super().__init__(name=name)
        he_init = hk.initializers.VarianceScaling(scale=2.0)

        channels = C['encoder_cnn_channels']
        kernels  = C['encoder_cnn_kernels']
        strides  = C['encoder_cnn_strides']

        hidden_size = channels[-1]
        self.cnn_layers = hk.Sequential([
            hk.Conv2D(channels[0], kernels[0], stride=strides[0], padding='SAME', w_init=he_init, with_bias=True), jax.nn.relu,
            hk.Conv2D(channels[1], kernels[1], stride=strides[1], padding='SAME', w_init=he_init, with_bias=True), jax.nn.relu,
            hk.Conv2D(channels[2], kernels[2], stride=strides[2], padding='SAME', w_init=he_init, with_bias=True), jax.nn.relu,
            hk.Conv2D(hidden_size, kernels[3], stride=strides[3], padding='SAME', w_init=he_init, with_bias=True), jax.nn.relu,
        ])

        self.pos_embed = SoftPositionEmbed(hidden_size, C['hidden_res'], position_enc_fn)

        self.linears = hk.Sequential([ # i.e. 1x1 convolution (shared 32 neurons across all locations)
            hk.Reshape((-1, hidden_size)), # Flatten spatial dim (works with batch)
            hk.LayerNorm(axis=-1, create_scale=True, create_offset=True),
            hk.Linear(32, w_init=he_init), jax.nn.relu,
            hk.Linear(32, w_init=he_init),
        ])
Example #7
 def single_mlp(inner_name: str):
     """Creates a single MLP performing the update."""
     mlp = hk.nets.MLP(output_sizes=output_sizes,
                       name=inner_name,
                       activation=activation)
     mlp = jraph.concatenated_args(mlp)
     if normalization_type == 'layer_norm':
         norm = hk.LayerNorm(axis=-1,
                             create_scale=True,
                             create_offset=True,
                             name=name + '_layer_norm')
     elif normalization_type == 'batch_norm':
         batch_norm = hk.BatchNorm(
             create_scale=True,
             create_offset=True,
             decay_rate=0.9,
             name=f'{inner_name}_batch_norm',
             cross_replica_axis=None if hk.running_init() else 'i',
         )
         norm = lambda x: batch_norm(x, is_training)
     elif normalization_type == 'none':
         return mlp
     else:
         raise ValueError(
             f'Unknown normalization type {normalization_type}')
     return jraph.concatenated_args(hk.Sequential([mlp, norm]))
Example #8
    def __init__(self, config, name=None):
        super().__init__(name=name)
        out_dim = config["n_vocab"]

        self.dim = out_dim
        self.norm = hk.LayerNorm(-1, True, True)

        self.proj = hk.Linear(self.dim)
Example #9
    def __init__(self,
                 make_inner_op: MakeInnerOp,
                 non_linearity: NonLinearity = jax.nn.relu,
                 use_layer_norm: bool = False,
                 name: str = 'residual_block'):
        super().__init__(name=name)
        self.inner_op1 = make_inner_op()
        self.inner_op2 = make_inner_op()
        self.non_linearity = non_linearity
        self.use_layer_norm = use_layer_norm

        if use_layer_norm:
            self.layernorm1 = hk.LayerNorm(axis=(1, 2, 3),
                                           create_scale=True,
                                           create_offset=True,
                                           eps=1e-6)
            self.layernorm2 = hk.LayerNorm(axis=(1, 2, 3),
                                           create_scale=True,
                                           create_offset=True,
                                           eps=1e-6)
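
The forward pass of this block is not shown in the snippet. A plausible sketch of it (an assumption, not the original module's code) applies each LayerNorm before its nonlinearity and inner op, then adds the skip connection:

    def __call__(self, x):
        # Hypothetical pre-activation residual pass: (norm ->) nonlinearity -> inner op, twice.
        out = x
        if self.use_layer_norm:
            out = self.layernorm1(out)
        out = self.non_linearity(out)
        out = self.inner_op1(out)
        if self.use_layer_norm:
            out = self.layernorm2(out)
        out = self.non_linearity(out)
        out = self.inner_op2(out)
        return x + out
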
Example #10
            def norm(name):
                layer_norm = hk.LayerNorm(axis=-1,
                                          name=name,
                                          create_scale=True,
                                          create_offset=True)

                def norm_apply(
                    x, **kwargs
                ):  # So that this code works with the is_training kwarg
                    return layer_norm(x)

                return norm_apply
Example #11
def quantile_net(x, quantile_fractions):
    x_size = x.shape[-1]
    x_tiled = jnp.tile(x[:, None, :], [num_bins, 1])
    quantiles_emb = quantile_cos_embedding(quantile_fractions)
    quantiles_emb = hk.Linear(x_size)(quantiles_emb)
    quantiles_emb = hk.LayerNorm(axis=-1,
                                 create_scale=True,
                                 create_offset=True)(quantiles_emb)
    quantiles_emb = jax.nn.sigmoid(quantiles_emb)
    x = x_tiled * quantiles_emb
    x = hk.Linear(x_size)(x)
    x = jax.nn.relu(x)
    return x
Example #12
    def __init__(self, config, name=None, init_scale=1.):
        super().__init__(name=name)
        self.dim = config["d_model"]
        self.n_head = config["n_heads"]
        self.d_head = config["d_head"]
        self.d_rotary = config["pe_rotary_dims"]
        self.mp_num = thread_resources.env.shape['mp']

        self.norm = hk.LayerNorm(-1, True, True)
        self.input_proj = hk.Linear(self.d_head * self.n_head * 3 +
                                    self.dim * 4)
        self.output_proj = hk.Linear(
            self.dim,
            w_init=hk.initializers.TruncatedNormal(stddev=init_scale /
                                                   jnp.sqrt(self.dim)))
Example #13
def mlp(
    x: tp.Union[jnp.ndarray, JAXSparse],
    is_training: bool,
    ids=None,
    num_classes: tp.Optional[int] = None,
    hidden_filters: tp.Union[int, tp.Iterable[int]] = 64,
    dropout_rate: float = 0.8,
    use_batch_norm: bool = False,
    use_layer_norm: bool = False,
    use_renormalize: bool = False,
    use_gathered_batch_norm: bool = False,
    activation: Activation = jax.nn.relu,
    final_activation: Activation = lambda x: x,
    input_dropout_rate: tp.Optional[float] = None,
    batch_norm_decay: float = 0.9,
    renorm_scale: bool = True,
    w_init=None,
):
    assert (sum((use_batch_norm, use_layer_norm, use_renormalize,
                 use_gathered_batch_norm)) <= 1)
    if input_dropout_rate is None:
        input_dropout_rate = dropout_rate
    if isinstance(hidden_filters, int):
        hidden_filters = (hidden_filters, )

    x = dropout(x, input_dropout_rate, is_training=is_training)
    for filters in hidden_filters:
        x = Linear(filters, w_init=w_init)(x)
        if use_batch_norm:
            x = hk.BatchNorm(renorm_scale, True, batch_norm_decay)(x,
                                                                   is_training)
        if use_layer_norm:
            x = hk.LayerNorm(0, renorm_scale, True)(x)
        if use_renormalize:
            x = Renormalize(renorm_scale, True)(x)
        if use_gathered_batch_norm:
            assert ids is not None
            x = GatheredBatchNorm(True, True,
                                  batch_norm_decay)(x,
                                                    is_training=is_training,
                                                    ids=ids)
        x = activation(x)
        x = dropout(x, dropout_rate, is_training=is_training)
    if num_classes is not None:
        x = hk.Linear(num_classes, w_init=w_init)(x)
    return final_activation(x)
Example #14
def _build_mlp(
    name: str,
    output_sizes: Sequence[int],
    use_layer_norm=False,
    activation=jax.nn.relu,
):
    """Builds an MLP, optionally with layernorm."""
    net = hk.nets.MLP(output_sizes=output_sizes,
                      name=name + "_mlp",
                      activation=activation)
    if use_layer_norm:
        layer_norm = hk.LayerNorm(axis=-1,
                                  create_scale=True,
                                  create_offset=True,
                                  name=name + "_layer_norm")
        net = hk.Sequential([net, layer_norm])
    return jraph.concatenated_args(net)
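
As a usage sketch (an assumption, not part of the snippet above), the returned update function can be plugged directly into jraph.GraphNetwork inside an hk.transform-ed forward function; the names "edge"/"node" and the sizes [64, 64] are placeholders:

def graph_net(graphs: jraph.GraphsTuple) -> jraph.GraphsTuple:
    # Each update function is an MLP (optionally followed by LayerNorm) that
    # receives the concatenated per-edge / per-node features.
    network = jraph.GraphNetwork(
        update_edge_fn=_build_mlp("edge", [64, 64], use_layer_norm=True),
        update_node_fn=_build_mlp("node", [64, 64], use_layer_norm=True),
        update_global_fn=None)
    return network(graphs)
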
Example #15
def getnorm(type):
    if type == "layernorm":
        return ReplicatedLayerNorm()
    if type == "layernorm-desync":
        return hk.LayerNorm(-1, True, True)
    elif type == "layernorm-nobias":
        return ReplicatedLayerNorm(offset=False)
    elif type == "rmsnorm":
        return RMSNorm(False, True)
    elif type == "scalenorm":
        return RMSNorm(False, False)
    elif type == "rmsnorm-bias":
        return RMSNorm(True, True)
    elif type == "scalenorm-bias":
        return RMSNorm(True, False)
    else:
        raise Exception("Not implemented")
Example #16
def make_downsampling_layer(
    strategy: Union[str, DownsamplingStrategy],
    output_channels: int,
) -> hk.SupportsCall:
    """Returns a sequence of modules corresponding to the desired downsampling."""
    strategy = DownsamplingStrategy(strategy)

    if strategy is DownsamplingStrategy.AVG_POOL:
        return hk.AvgPool(window_shape=(3, 3, 1),
                          strides=(2, 2, 1),
                          padding='SAME')

    elif strategy is DownsamplingStrategy.CONV:
        return hk.Sequential([
            hk.Conv2D(output_channels,
                      kernel_shape=3,
                      stride=2,
                      w_init=hk.initializers.TruncatedNormal(1e-2)),
        ])

    elif strategy is DownsamplingStrategy.LAYERNORM_RELU_CONV:
        return hk.Sequential([
            hk.LayerNorm(axis=(1, 2, 3),
                         create_scale=True,
                         create_offset=True,
                         eps=1e-6),
            jax.nn.relu,
            hk.Conv2D(output_channels,
                      kernel_shape=3,
                      stride=2,
                      w_init=hk.initializers.TruncatedNormal(1e-2)),
        ])

    elif strategy is DownsamplingStrategy.CONV_MAX:
        return hk.Sequential([
            hk.Conv2D(output_channels, kernel_shape=3, stride=1),
            hk.MaxPool(window_shape=(3, 3, 1),
                       strides=(2, 2, 1),
                       padding='SAME')
        ])
    else:
        raise ValueError(
            'Unrecognized downsampling strategy. Expected one of'
            f' {[strategy.value for strategy in DownsamplingStrategy]}'
            f' but received {strategy}.')
Example #17
  def _update_nodes(self, graph: jraph.GraphsTuple,
                    messages: ArrayType) -> ArrayType:
    """Compute updated node representations."""
    x = jax.ops.segment_sum(messages, graph.receivers,
                            num_segments=graph.nodes.shape[0])
    x = jnp.concatenate([graph.nodes, x], axis=-1)

    layer_sizes = self._node_hidden_sizes[:]
    if self._residual:
      layer_sizes += [graph.nodes.shape[-1]]

    x = hk.nets.MLP(layer_sizes, activate_final=False)(x)
    if self._layer_norm:
      x = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)(x)

    if self._residual:
      return graph.nodes + x
    else:
      return x
Example #18
  def __init__(
      self,
      layer_sizes,
      activation,
      with_bias=True,
      w_init=None,
      b_init=None,
      name = 'ResidualLayerNormBlock'):
    super().__init__(name=name)

    self._mlp = hk.nets.MLP(
        output_sizes=layer_sizes,
        w_init=w_init,
        b_init=b_init,
        with_bias=with_bias,
        activation=activation,
        activate_final=False)
    self._layer_norm = hk.LayerNorm(
        axis=-1, create_scale=True, create_offset=True)
Example #19
    def __init__(self,
                 layer_sizes: Sequence[int],
                 activate_final: bool = False):
        """Construct the MLP.

    Args:
      layer_sizes: a sequence of ints specifying the size of each layer.
      activate_final: whether or not to use the activation function on the final
        layer of the neural network.
    """
        super().__init__(name='feedforward_mlp_torso')

        self._network = hk.Sequential([
            hk.Linear(layer_sizes[0], w_init=uniform_initializer),
            hk.LayerNorm(axis=-1, create_scale=True, create_offset=True),
            jax.lax.tanh,
            hk.nets.MLP(layer_sizes[1:],
                        w_init=uniform_initializer,
                        activation=jax.nn.elu,
                        activate_final=activate_final),
        ])
Example #20
def layer_norm(x):
    return hk.LayerNorm(-1, True, True)(x)
Example #21
def layer_norm(x, name=None):
    return hk.LayerNorm(axis=-1,
                        create_scale=True,
                        create_offset=True,
                        name=name)(x)
Example #22
def layer_norm(x: jnp.ndarray, name: Optional[str] = None) -> jnp.ndarray:
    """Apply a unique LayerNorm to x with default settings."""
    return hk.LayerNorm(axis=-1,
                        create_scale=True,
                        create_offset=True,
                        name=name)(x)
Example #23
def layer_norm(x: jnp.ndarray) -> jnp.ndarray:
    """Apply a unique LayerNorm to x with default settings."""
    return hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)(x)
Example #24
def layer_norm(x):
    return hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)(x)
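
The one-line helpers in Examples #20-#24 only make sense inside a Haiku-transformed function, since hk.LayerNorm creates scale/offset parameters. A minimal, self-contained wiring sketch (an assumption, not taken from any of the snippets above):

import haiku as hk
import jax
import jax.numpy as jnp

def forward(x):
    # Linear + ReLU followed by the same kind of LayerNorm helper as above.
    x = jax.nn.relu(hk.Linear(128)(x))
    return hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)(x)

net = hk.without_apply_rng(hk.transform(forward))
params = net.init(jax.random.PRNGKey(0), jnp.ones([4, 32]))
out = net.apply(params, jnp.ones([4, 32]))  # shape (4, 128)
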
Example #25
 def __call__(self, x):
     w_init = hk.initializers.Orthogonal(scale=1.0)
     x = hk.Linear(self.feature_dim, w_init=w_init)(x)
     x = hk.LayerNorm(axis=1, create_scale=True, create_offset=True)(x)
     x = jnp.tanh(x)
     return x
Example #26
    def __call__(self,
                 activations,
                 sequence_mask,
                 update_affine,
                 is_training,
                 initial_act,
                 safe_key=None,
                 static_feat_2d=None,
                 aatype=None):
        c = self.config

        if safe_key is None:
            safe_key = prng.SafeKey(hk.next_rng_key())

        def safe_dropout_fn(tensor, safe_key):
            return prng.safe_dropout(
                tensor=tensor,
                safe_key=safe_key,
                rate=c.dropout,
                is_deterministic=self.global_config.deterministic,
                is_training=is_training)

        affine = quat_affine.QuatAffine.from_tensor(activations['affine'])

        act = activations['act']
        attention_module = InvariantPointAttention(self.config,
                                                   self.global_config)
        # Attention
        attn = attention_module(inputs_1d=act,
                                inputs_2d=static_feat_2d,
                                mask=sequence_mask,
                                affine=affine)
        act += attn
        safe_key, *sub_keys = safe_key.split(3)
        sub_keys = iter(sub_keys)
        act = safe_dropout_fn(act, next(sub_keys))
        act = hk.LayerNorm(axis=[-1],
                           create_scale=True,
                           create_offset=True,
                           name='attention_layer_norm')(act)

        final_init = 'zeros' if self.global_config.zero_init else 'linear'

        # Transition
        input_act = act
        for i in range(c.num_layer_in_transition):
            init = 'relu' if i < c.num_layer_in_transition - 1 else final_init
            act = common_modules.Linear(c.num_channel,
                                        initializer=init,
                                        name='transition')(act)
            if i < c.num_layer_in_transition - 1:
                act = jax.nn.relu(act)
        act += input_act
        act = safe_dropout_fn(act, next(sub_keys))
        act = hk.LayerNorm(axis=[-1],
                           create_scale=True,
                           create_offset=True,
                           name='transition_layer_norm')(act)

        if update_affine:
            # This block corresponds to
            # Jumper et al. (2021) Alg. 23 "Backbone update"
            affine_update_size = 6

            # Affine update
            affine_update = common_modules.Linear(affine_update_size,
                                                  initializer=final_init,
                                                  name='affine_update')(act)

            affine = affine.pre_compose(affine_update)

        sc = MultiRigidSidechain(c.sidechain, self.global_config)(
            affine.scale_translation(c.position_scale), [act, initial_act],
            aatype)

        outputs = {'affine': affine.to_tensor(), 'sc': sc}

        affine = affine.apply_rotation_tensor_fn(jax.lax.stop_gradient)

        new_activations = {'act': act, 'affine': affine.to_tensor()}
        return new_activations, outputs
Example #27
def generate_affines(representations, batch, config, global_config,
                     is_training, safe_key):
    """Generate predicted affines for a single chain.

  Jumper et al. (2021) Suppl. Alg. 20 "StructureModule"

  This is the main part of the structure module - it iteratively applies
  folding to produce a set of predicted residue positions.

  Args:
    representations: Representations dictionary.
    batch: Batch dictionary.
    config: Config for the structure module.
    global_config: Global config.
    is_training: Whether the model is being trained.
    safe_key: A prng.SafeKey object that wraps a PRNG key.

  Returns:
    A dictionary containing residue affines and sidechain positions.
  """
    c = config
    sequence_mask = batch['seq_mask'][:, None]

    act = hk.LayerNorm(axis=[-1],
                       create_scale=True,
                       create_offset=True,
                       name='single_layer_norm')(representations['single'])

    initial_act = act
    act = common_modules.Linear(c.num_channel, name='initial_projection')(act)

    affine = generate_new_affine(sequence_mask)

    fold_iteration = FoldIteration(c, global_config, name='fold_iteration')

    assert len(batch['seq_mask'].shape) == 1

    activations = {
        'act': act,
        'affine': affine.to_tensor(),
    }

    act_2d = hk.LayerNorm(axis=[-1],
                          create_scale=True,
                          create_offset=True,
                          name='pair_layer_norm')(representations['pair'])

    outputs = []
    safe_keys = safe_key.split(c.num_layer)
    for sub_key in safe_keys:
        activations, output = fold_iteration(activations,
                                             initial_act=initial_act,
                                             static_feat_2d=act_2d,
                                             safe_key=sub_key,
                                             sequence_mask=sequence_mask,
                                             update_affine=True,
                                             is_training=is_training,
                                             aatype=batch['aatype'])
        outputs.append(output)

    output = jax.tree_map(lambda *x: jnp.stack(x), *outputs)
    # Include the activations in the output dict for use by the LDDT-Head.
    output['act'] = activations['act']

    return output
Example #28
     shape=(BATCH_SIZE, 2, 2, 3)),
 ModuleDescriptor(
     name="Bias",
     create=lambda: hk.Bias(),
     shape=(BATCH_SIZE, 3, 3, 3)),
 ModuleDescriptor(
     name="Flatten",
     create=lambda: hk.Flatten(),
     shape=(BATCH_SIZE, 3, 3, 3)),
 ModuleDescriptor(
     name="InstanceNorm",
     create=lambda: hk.InstanceNorm(True, True),
     shape=(BATCH_SIZE, 3, 2)),
 ModuleDescriptor(
     name="LayerNorm",
     create=lambda: hk.LayerNorm(1, True, True),
     shape=(BATCH_SIZE, 3, 2)),
 ModuleDescriptor(
     name="SpectralNorm",
     create=lambda: hk.SpectralNorm(),
     shape=(BATCH_SIZE, 3, 2)),
 ModuleDescriptor(
     name="nets.ResNet",
     create=lambda: Training(hk.nets.ResNet((3, 4, 6, 3), 1000)),
     shape=(BATCH_SIZE, 3, 3, 2)),
 # pylint: disable=g-long-lambda
 ModuleDescriptor(
     name="nets.MobileNetV1",
     create=lambda: Training(hk.nets.MobileNetV1(num_classes=1000,
                                                 strides=(1, 1, 1),
                                                 channels=(16, 32, 64))),
Example #29
                  create=lambda: Training(hk.BatchNorm(True, True, 0.9)),
                  shape=(BATCH_SIZE, 2, 2, 3)),
 ModuleDescriptor(name="Bias",
                  create=lambda: hk.Bias(),
                  shape=(BATCH_SIZE, 3, 3, 3)),
 ModuleDescriptor(name="Flatten",
                  create=lambda: hk.Flatten(),
                  shape=(BATCH_SIZE, 3, 3, 3)),
 ModuleDescriptor(name="InstanceNorm",
                  create=lambda: hk.InstanceNorm(True, True),
                  shape=(BATCH_SIZE, 3, 2)),
 ModuleDescriptor(name="GroupNorm",
                  create=lambda: hk.GroupNorm(5),
                  shape=(BATCH_SIZE, 4, 4, 10)),
 ModuleDescriptor(name="LayerNorm",
                  create=lambda: hk.LayerNorm(1, True, True, param_axis=-1),
                  shape=(BATCH_SIZE, 3, 2)),
 ModuleDescriptor(
     name="MultiHeadAttention",
     create=lambda: MultiInput(  # pylint: disable=g-long-lambda
         hk.MultiHeadAttention(num_heads=8, key_size=64, w_init_scale=1.0),
         num_inputs=3),
     shape=(BATCH_SIZE, 3, 2)),
 ModuleDescriptor(name="RMSNorm",
                  create=lambda: hk.RMSNorm(1),
                  shape=(BATCH_SIZE, 3, 2)),
 ModuleDescriptor(name="SpectralNorm",
                  create=lambda: hk.SpectralNorm(),
                  shape=(BATCH_SIZE, 3, 2)),
 ModuleDescriptor(name="nets.ResNet",
                  create=lambda: Training(hk.nets.ResNet(