def create_model(batch_size: int,
                 len_seqs: int,
                 num_motifs: int,
                 len_motifs: int,
                 num_denses: int,
                 num_classes: int = 10,
                 embed_size: int = 4,
                 one_hot: bool = True,
                 l2_weight: float = 0.0,
                 dropout_rate: float = 0.1,
                 before_conv_dropout: bool = False,
                 use_mc_dropout: bool = False,
                 spec_norm_hparams: Optional[Dict[str, Any]] = None,
                 gp_layer_hparams: Optional[Dict[str, Any]] = None,
                 **unused_kwargs: Any) -> tf.keras.models.Model:
  """Builds Genomics CNN model.

  Args:
    batch_size: (int) Value of the static per_replica batch size.
    len_seqs: (int) Sequence length.
    num_motifs: (int) Number of motifs (= number of filters) to apply to input.
    len_motifs: (int) Length of the motifs (= size of convolutional filters).
    num_denses: (int) Number of nodes in the dense layer.
    num_classes: (int) Number of output classes.
    embed_size: (int) Static size of hidden dimension of the embedding output.
    one_hot: (bool) Whether to use one-hot encoding for the input sequences.
    l2_weight: (float) L2 regularization coefficient.
    dropout_rate: (float) Fraction of the convolutional output units and dense
      layer output units to drop.
    before_conv_dropout: (bool) Whether to apply filter-wise dropout before the
      convolutional layer.
    use_mc_dropout: (bool) Whether to apply Monte Carlo dropout.
    spec_norm_hparams: (dict) Hyperparameters for spectral normalization.
    gp_layer_hparams: (dict) Hyperparameters for Gaussian Process output layer.
    **unused_kwargs: (dict) Unused keyword arguments that will be ignored by the
      model.

  Returns:
    (tf.keras.Model) The 1D convolutional model for genomic sequences.
  """
  # Define layers.
  if spec_norm_hparams:
    spec_norm_bound = spec_norm_hparams['spec_norm_bound']
    spec_norm_iteration = spec_norm_hparams['spec_norm_iteration']
  else:
    spec_norm_bound = None
    spec_norm_iteration = None

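  # Spectral normalization bounds each layer's spectral norm (and hence its
  # Lipschitz constant), keeping the hidden mapping approximately
  # distance-preserving, which is what the SNGP-style GP output layer relies
  # on.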
  conv_layer = models_util.make_conv2d_layer(
      use_spec_norm=(spec_norm_hparams is not None),
      spec_norm_bound=spec_norm_bound,
      spec_norm_iteration=spec_norm_iteration)

  dense_layer = models_util.make_dense_layer(
      use_spec_norm=(spec_norm_hparams is not None),
      spec_norm_bound=spec_norm_bound,
      spec_norm_iteration=spec_norm_iteration)

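  # The output layer is a plain dense layer by default and a Gaussian process
  # head when gp_layer_hparams is provided (see models_util.make_output_layer).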
  output_layer = models_util.make_output_layer(
      gp_layer_hparams=gp_layer_hparams)

  # Compute outputs given inputs.
  inputs = tf.keras.Input(
      shape=[len_seqs], batch_size=batch_size, dtype=tf.int32)
  x = _input_embedding(
      inputs, VOCAB_SIZE, one_hot=one_hot, embed_size=embed_size)

  # filter-wise dropout before conv,
  # x.shape=[batch_size, len_seqs, vocab_size/embed_size]
  if before_conv_dropout:
    x = models_util.apply_dropout(
        x,
        dropout_rate,
        use_mc_dropout,
        filter_wise_dropout=True,
        name='conv_dropout')

  x = _conv_pooled_block(
      x,
      conv_layer=conv_layer(
          filters=num_motifs,
          kernel_size=(len_motifs, embed_size),
          strides=(1, 1),
          kernel_regularizer=tf.keras.regularizers.l2(l2_weight),
          name='conv'))
  x = models_util.apply_dropout(
      x, dropout_rate, use_mc_dropout, name='dropout1')
  x = dense_layer(
      units=num_denses,
      activation=tf.keras.activations.relu,
      kernel_regularizer=tf.keras.regularizers.l2(l2_weight),
      name='dense')(x)
  x = models_util.apply_dropout(
      x, dropout_rate, use_mc_dropout, name='dropout2')
  if gp_layer_hparams and gp_layer_hparams['gp_input_dim'] > 0:
    # Uses random projection to reduce the input dimension of the GP layer.
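    # A fixed (non-trainable) random normal projection approximately preserves
    # pairwise distances (Johnson-Lindenstrauss), so the GP head remains
    # distance-aware at a lower input dimension.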
    x = tf.keras.layers.Dense(
        gp_layer_hparams['gp_input_dim'],
        kernel_initializer='random_normal',
        use_bias=False,
        trainable=False,
        name='gp_random_projection')(x)
  outputs = output_layer(num_classes, name='logits')(x)
  return tf.keras.Model(inputs=inputs, outputs=outputs)
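

# A minimal usage sketch for `create_model` (all hyperparameter values below
# are hypothetical; the spec_norm_hparams keys match the ones read above):
#
#   model = create_model(
#       batch_size=128,
#       len_seqs=250,
#       num_motifs=1024,
#       len_motifs=20,
#       num_denses=128,
#       spec_norm_hparams={'spec_norm_bound': 6., 'spec_norm_iteration': 1})
#   model.summary()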


def wide_resnet(batch_size: Optional[int],
                input_shape: Iterable[int],
                depth: int,
                width_multiplier: int,
                num_classes: int,
                l2: float,
                dropout_rate: float,
                use_mc_dropout: bool,
                spec_norm_hparams: Optional[Dict[str, Any]] = None,
                gp_layer_hparams: Optional[Dict[str, Any]] = None):
  """Builds Wide ResNet.

  Following Zagoruyko and Komodakis (2016), it accepts a width multiplier on the
  number of filters. Using three groups of residual blocks, the network maps
  spatial features of size 32x32 -> 16x16 -> 8x8.

  Args:
    batch_size: (int) Value of the static per_replica batch size.
    input_shape: (Iterable[int]) Shape of a single input example, excluding
      the batch dimension.
    depth: Total number of convolutional layers. "n" in WRN-n-k. It differs
      from He et al. (2015)'s notation, which counts the maximum depth of the
      network including non-conv layers such as the final dense layer.
    width_multiplier: Integer to multiply the number of typical filters by. "k"
      in WRN-n-k.
    num_classes: Number of output classes.
    l2: L2 regularization coefficient.
    dropout_rate: Dropout rate.
    use_mc_dropout: Whether to apply Monte Carlo dropout.
    spec_norm_hparams: (dict) Hyperparameters for spectral normalization.
    gp_layer_hparams: (dict) Hyperparameters for Gaussian Process output layer.

  Returns:
    tf.keras.Model.
  """
  if (depth - 4) % 6 != 0:
    raise ValueError('depth should be 6n+4 (e.g., 16, 22, 28, 40).')
  num_blocks = (depth - 4) // 6
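  # The usual WRN-n-k depth counts convolutional layers: the stem conv, plus
  # three groups of `num_blocks` residual blocks with two 3x3 convs each, plus
  # one projection shortcut per group, i.e. depth = 6 * num_blocks + 4.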
  inputs = tf.keras.layers.Input(shape=input_shape, batch_size=batch_size)

  # pylint: disable=invalid-name
  if spec_norm_hparams:
    spec_norm_bound = spec_norm_hparams['spec_norm_bound']
    spec_norm_iteration = spec_norm_hparams['spec_norm_iteration']
  else:
    spec_norm_bound = None
    spec_norm_iteration = None
  conv2d = models_util.make_conv2d_layer(
      kernel_size=3,
      use_bias=False,
      kernel_initializer='he_normal',
      activation=None,
      use_spec_norm=(spec_norm_hparams is not None),
      spec_norm_bound=spec_norm_bound,
      spec_norm_iteration=spec_norm_iteration)

  x = conv2d(filters=16,
             strides=1,
             kernel_regularizer=tf.keras.regularizers.l2(l2))(inputs)
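  # Filter-wise dropout drops entire filters of the stem output rather than
  # individual activations (see models_util.apply_dropout).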
  x = models_util.apply_dropout(x,
                                dropout_rate,
                                use_mc_dropout,
                                filter_wise_dropout=True)

  x = group(x,
            filters=16 * width_multiplier,
            strides=1,
            num_blocks=num_blocks,
            l2=l2,
            dropout_rate=dropout_rate,
            use_mc_dropout=use_mc_dropout,
            conv_layer=conv2d)
  x = group(x,
            filters=32 * width_multiplier,
            strides=2,
            num_blocks=num_blocks,
            l2=l2,
            dropout_rate=dropout_rate,
            use_mc_dropout=use_mc_dropout,
            conv_layer=conv2d)
  x = group(x,
            filters=64 * width_multiplier,
            strides=2,
            num_blocks=num_blocks,
            l2=l2,
            dropout_rate=dropout_rate,
            use_mc_dropout=use_mc_dropout,
            conv_layer=conv2d)
  x = BatchNormalization(beta_regularizer=tf.keras.regularizers.l2(l2),
                         gamma_regularizer=tf.keras.regularizers.l2(l2))(x)
  x = tf.keras.layers.Activation('relu')(x)
  x = tf.keras.layers.AveragePooling2D(pool_size=8)(x)
  x = tf.keras.layers.Flatten()(x)

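  # When gp_layer_hparams is set, the output layer approximates a Gaussian
  # process with random features and returns both logits and a posterior
  # covariance matrix; the deterministic head in the else-branch returns an
  # identity covmat so both branches share the same output signature.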
  if gp_layer_hparams:
    gp_output_layer = functools.partial(
        ed.layers.RandomFeatureGaussianProcess,
        num_inducing=gp_layer_hparams['gp_hidden_dim'],
        gp_kernel_scale=gp_layer_hparams['gp_scale'],
        gp_output_bias=gp_layer_hparams['gp_bias'],
        normalize_input=gp_layer_hparams['gp_input_normalization'],
        gp_cov_momentum=gp_layer_hparams['gp_cov_discount_factor'],
        gp_cov_ridge_penalty=gp_layer_hparams['gp_cov_ridge_penalty'])
    if gp_layer_hparams['gp_input_dim'] > 0:
      # Uses a random projection to reduce the input dimension of the GP
      # layer.
      x = tf.keras.layers.Dense(gp_layer_hparams['gp_input_dim'],
                                kernel_initializer='random_normal',
                                use_bias=False,
                                trainable=False)(x)
    logits, covmat = gp_output_layer(num_classes)(x)
  else:
    logits = tf.keras.layers.Dense(
        num_classes,
        kernel_initializer='he_normal',
        kernel_regularizer=tf.keras.regularizers.l2(l2),
        bias_regularizer=tf.keras.regularizers.l2(l2))(x)
    covmat = tf.eye(batch_size)

  return tf.keras.Model(inputs=inputs, outputs=[logits, covmat])
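

# A minimal usage sketch for `wide_resnet` (all hyperparameter values below
# are hypothetical; the gp_layer_hparams keys match the ones read above):
#
#   model = wide_resnet(
#       batch_size=128,
#       input_shape=(32, 32, 3),
#       depth=28,
#       width_multiplier=10,
#       num_classes=10,
#       l2=3e-4,
#       dropout_rate=0.1,
#       use_mc_dropout=False,
#       gp_layer_hparams={
#           'gp_hidden_dim': 1024,
#           'gp_scale': 2.,
#           'gp_bias': 0.,
#           'gp_input_normalization': True,
#           'gp_cov_discount_factor': -1.,
#           'gp_cov_ridge_penalty': 1.,
#           'gp_input_dim': 128,
#       })
#   logits, covmat = model(tf.zeros([128, 32, 32, 3]))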