def _create_global_visual_feature_embeddings(model_config, features,
                                             flags) -> tf.Tensor:
    """Creates global embeddings based on visual features."""
    initializer_range = 0.02

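    # Embed the crossed binary indicator features: each of the 2**N possible
    # combinations of the N indicators in `flags.indicators_to_cross` gets its
    # own embedding id.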
    indicator_cross_emb_lookup = etc_layers.EmbeddingLookup(
        vocab_size=2**len(flags.indicators_to_cross),
        embedding_size=model_config.hidden_size,
        initializer_range=initializer_range,
        use_one_hot_lookup=flags.use_tpu,
        name='indicator_cross_emb_lookup')
    global_embedding_adder = indicator_cross_emb_lookup(
        features['global_indicator_cross'])

    font_id_emb_lookup = etc_layers.EmbeddingLookup(
        vocab_size=generate_examples_lib.FONT_ID_VOCAB_SIZE,
        embedding_size=model_config.hidden_size,
        initializer_range=initializer_range,
        use_one_hot_lookup=flags.use_tpu,
        name='font_id_emb_lookup')
    global_embedding_adder += font_id_emb_lookup(features['global_font_ids'])

    parent_font_id_emb_lookup = etc_layers.EmbeddingLookup(
        vocab_size=generate_examples_lib.FONT_ID_VOCAB_SIZE,
        embedding_size=model_config.hidden_size,
        initializer_range=initializer_range,
        use_one_hot_lookup=flags.use_tpu,
        name='parent_font_id_emb_lookup')
    global_embedding_adder += parent_font_id_emb_lookup(
        features['global_parent_font_ids'])

    # Add transformation of dense features
    dense_feature_projection = tf.keras.layers.Dense(
        units=model_config.hidden_size,
        activation=tensor_utils.get_activation('gelu'),
        kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
        name='dense_feature_projection')
    dense_feature_embeddings = dense_feature_projection(
        features['global_dense_features'])
    if flags.extra_dense_feature_layers > 1:
        raise NotImplementedError(
            '`extra_dense_feature_layers` must be at most 1.')
    elif flags.extra_dense_feature_layers == 1:
        dense_feature_layer2 = tf.keras.layers.Dense(
            units=model_config.hidden_size,
            activation=tensor_utils.get_activation('gelu'),
            kernel_initializer=tf.keras.initializers.TruncatedNormal(
                stddev=0.02),
            name='dense_feature_layer2')
        dense_feature_embeddings = dense_feature_layer2(
            dense_feature_embeddings)
    global_embedding_adder += dense_feature_embeddings

    return global_embedding_adder
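# The indicator-cross embedding above assumes `features['global_indicator_cross']`
# holds one id in [0, 2**N) per global token, where N = len(flags.indicators_to_cross).
# A minimal sketch (hypothetical helper; the real feature generation lives in
# generate_examples_lib) of bit-packing N binary indicator tensors into such ids:
def _pack_indicator_cross_ids_sketch(indicator_tensors):
    """Packs a list of {0, 1} int32 tensors into a single cross id per position."""
    cross_ids = tf.zeros_like(indicator_tensors[0])
    for bit, indicator in enumerate(indicator_tensors):
        cross_ids += indicator * (2 ** bit)  # each indicator contributes one bit
    return cross_ids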
Example 2
  def __init__(self,
               config: EtcConfig,
               is_training: Optional[bool] = None,
               use_one_hot_embeddings=False,
               use_one_hot_relative_embeddings=False,
               name: Text = "etc_document_bert",
               **kwargs):
    """Constructor for `EtcModel`.

    Args:
      config: `EtcConfig` instance.
      is_training: Optional bool. True for training model, False for eval model.
        The None default will defer to the typical Keras `training` argument in
        `call` instead. When `is_training` is specified here, the `training`
        argument from `call` must not be used.
      use_one_hot_embeddings: (optional) bool. Whether to use one-hot word
        embeddings or tf.nn.embedding_lookup() for the word embeddings.
      use_one_hot_relative_embeddings: (optional) bool. Whether to use one-hot
        word embeddings or tf.nn.embedding_lookup() for the relative position
        embeddings.
      name: (Optional) name of the layer.
      **kwargs: Forwarded to super.

    Raises:
      ValueError: The config is invalid.
    """
    super(EtcModel, self).__init__(name=name, **kwargs)

    config = copy.deepcopy(config)
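    # When building an eval/predict graph, disable dropout entirely.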
    if is_training is not None and not is_training:
      config.hidden_dropout_prob = 0.0
      config.attention_probs_dropout_prob = 0.0

    self.config = config
    self.is_training = is_training
    self.use_one_hot_embeddings = use_one_hot_embeddings
    self.use_one_hot_relative_embeddings = use_one_hot_relative_embeddings

    if config.relative_vocab_size is None:
      if config.relative_pos_max_distance != 0:
        raise ValueError(
            "`relative_pos_max_distance` must be 0 when `relative_vocab_size` "
            "is None.")
    elif config.relative_vocab_size < (feature_utils.RelativePositionGenerator(
        config.relative_pos_max_distance).relative_vocab_size +
                                       _NUM_OTHER_RELATIVE_IDS):
      raise ValueError("`relative_vocab_size` ({}) too small for "
                       "`relative_pos_max_distance` ({})".format(
                           config.relative_vocab_size,
                           config.relative_pos_max_distance))
    if config.embedding_size is None:
      config.embedding_size = config.hidden_size

    self.token_embedding = etc_layers.EmbeddingLookup(
        vocab_size=config.vocab_size,
        embedding_size=config.embedding_size,
        projection_size=config.hidden_size,
        initializer_range=config.initializer_range,
        use_one_hot_lookup=use_one_hot_embeddings,
        name="token_emb_lookup")

    self.token_embedding_norm = tf.keras.layers.LayerNormalization(
        axis=-1, epsilon=1e-12, name="long_emb_layer_norm")
    self.token_embedding_dropout = tf.keras.layers.Dropout(
        rate=config.hidden_dropout_prob)

    self.segment_embedding = etc_layers.EmbeddingLookup(
        vocab_size=config.segment_vocab_size,
        embedding_size=config.hidden_size,
        initializer_range=config.initializer_range,
        use_one_hot_lookup=True,
        name="segment_emb_lookup")

    if config.max_absolute_position_embeddings != 0:
      self.position_embedding = etc_layers.EmbeddingLookup(
          vocab_size=config.max_absolute_position_embeddings,
          embedding_size=config.hidden_size,
          initializer_range=config.initializer_range,
          use_one_hot_lookup=use_one_hot_embeddings,
          name="position_emb_lookup_long")
      # We use `max_absolute_position_embeddings` for the maximum global input
      # length even though it's larger than we need. This makes it easier to
      # initialize both long and global position embedding tables with the same
      # values if desired.
      self.global_position_embedding = etc_layers.EmbeddingLookup(
          vocab_size=config.max_absolute_position_embeddings,
          embedding_size=config.hidden_size,
          initializer_range=config.initializer_range,
          use_one_hot_lookup=use_one_hot_embeddings,
          name="position_emb_lookup_global")
      # Call layers to force variable initialization.
      self.position_embedding(tf.ones([1, 1], tf.int32))
      self.global_position_embedding(tf.ones([1, 1], tf.int32))
    else:
      self.position_embedding = None
      self.global_position_embedding = None

    # We use the same embedding table for global tokens to make it easy to place
    # WordPieces in the global memory for finetuning tasks downstream.
    self.global_token_embedding = self.token_embedding
    self.global_token_embedding_norm = tf.keras.layers.LayerNormalization(
        axis=-1, epsilon=1e-12, name="global_emb_layer_norm")
    self.global_token_embedding_dropout = tf.keras.layers.Dropout(
        rate=config.hidden_dropout_prob)

    self.global_local_transformer = etc_layers.GlobalLocalTransformerLayers(
        long_hidden_size=config.hidden_size,
        global_hidden_size=config.hidden_size,
        num_hidden_layers=config.num_hidden_layers,
        num_attention_heads=config.num_attention_heads,
        local_radius=config.local_radius,
        att_size_per_head=config.att_size_per_head,
        long_intermediate_size=config.intermediate_size,
        global_intermediate_size=config.intermediate_size,
        hidden_act=tensor_utils.get_activation(config.hidden_act),
        hidden_dropout_prob=config.hidden_dropout_prob,
        attention_probs_dropout_prob=config.attention_probs_dropout_prob,
        initializer_range=config.initializer_range,
        relative_vocab_size=config.relative_vocab_size,
        share_feed_forward_params=config.share_feed_forward_params,
        share_kv_projections=config.share_kv_projections,
        share_qkv_projections=config.share_qkv_projections,
        share_att_output_projection=config.share_att_output_projection,
        use_pre_activation_order=config.use_pre_activation_order,
        use_one_hot_lookup=use_one_hot_relative_embeddings,
        grad_checkpointing_period=config.grad_checkpointing_period)
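# A small illustration (not from the repo) of the `relative_vocab_size` check
# above, under the common assumption that RelativePositionGenerator uses the
# clipped scheme with 2 * max_distance + 1 position ids, plus the reserved
# "other" relative ids counted by `_NUM_OTHER_RELATIVE_IDS`:
def _min_relative_vocab_size_sketch(relative_pos_max_distance,
                                    num_other_relative_ids):
  position_ids = 2 * relative_pos_max_distance + 1  # [-d, ..., 0, ..., +d]
  return position_ids + num_other_relative_ids

# e.g. a max distance of 12 with 2 (hypothetical) reserved ids needs a relative
# vocabulary of at least 27 entries.
assert _min_relative_vocab_size_sketch(12, 2) == 27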
Example 3
def process_model_output(model_config,
                         mode,
                         global_output_tensor,
                         global_token_type_ids_tensor,
                         labels,
                         is_real_example,
                         add_final_layer=True,
                         label_smoothing=0.0):
    """Process model output embeddings and computes loss, logits etc."""

    global_output_tensor_shape = tensor_utils.get_shape_list(
        global_output_tensor, expected_rank=3)
    batch_size = global_output_tensor_shape[0]
    global_seq_len = global_output_tensor_shape[1]
    hidden_size = global_output_tensor_shape[2]

    global_output_tensor = tf.reshape(
        global_output_tensor, [batch_size * global_seq_len, hidden_size])

    if add_final_layer:
        with tf.variable_scope("global_output_layer/transform"):
            is_training = (mode == tf.estimator.ModeKeys.TRAIN)
            final_layer = wrappers.ResidualBlock(
                inner_intermediate_size=model_config.intermediate_size,
                inner_activation=tensor_utils.get_activation(
                    model_config.hidden_act),
                use_pre_activation_order=False,
                dropout_probability=model_config.hidden_dropout_prob)
            global_output_tensor = final_layer(global_output_tensor,
                                               training=is_training)

    output_weights = tf.get_variable(
        "output_weights", [1, model_config.hidden_size],
        initializer=tf.truncated_normal_initializer(
            stddev=model_config.initializer_range))

    output_bias = tf.get_variable("output_bias", [1],
                                  initializer=tf.zeros_initializer())

    with tf.variable_scope("loss"):
        logits = tf.matmul(global_output_tensor,
                           output_weights,
                           transpose_b=True)
        logits = tf.nn.bias_add(logits, output_bias)

        tf.logging.info("*** logits initial are {} *** ".format(logits))
        logits = tf.reshape(logits, [batch_size, global_seq_len])
        tf.logging.info("*** logits after reshape are {} *** ".format(logits))

        # Consider only candidate global tokens in the global output.
        multiplier_mask = tf.cast(tf.equal(
            global_token_type_ids_tensor,
            multihop_utils.CANDIDATE_GLOBAL_TOKEN_TYPE_ID),
                                  dtype=logits.dtype)

        adder_mask = -10000.0 * (1.0 - multiplier_mask)

        logits = (logits * multiplier_mask + adder_mask)

        tf.logging.info("*** global_token_type_ids_tensor is {} *** ".format(
            global_token_type_ids_tensor))
        tf.logging.info("*** adder_mask is {} *** ".format(adder_mask))
        tf.logging.info(
            "*** multiplier_mask is {} *** ".format(multiplier_mask))
        tf.logging.info("*** logits computed are {} *** ".format(logits))

        # probabilities = tf.nn.softmax(logits, axis=-1)
        log_probs = tf.nn.log_softmax(logits, axis=-1)
        one_hot_labels = tf.one_hot(labels,
                                    depth=global_seq_len,
                                    dtype=tf.float32)
        if label_smoothing > 0:
            num_classes = tf.reduce_sum(multiplier_mask, axis=-1)
            num_classes = tf.expand_dims(num_classes, -1)
            one_hot_labels = (1 - label_smoothing) * one_hot_labels
            one_hot_labels += (label_smoothing / num_classes)
            # Ensure smoothing of labels only for applicable global (candidate)
            # tokens.
            one_hot_labels *= multiplier_mask

        per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)

        numerator = tf.reduce_sum(per_example_loss * is_real_example)
        denominator = tf.reduce_sum(is_real_example) + 1e-5
        loss = numerator / denominator

        return (loss, per_example_loss, logits)
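# A toy NumPy walk-through (illustration only) of the candidate masking and
# label smoothing used above, for a global sequence of length 5 in which
# positions 1, 2, and 4 are candidate tokens and the true answer is position 1:
import numpy as np

toy_logits = np.array([0.3, 2.0, -1.0, 0.7, 0.5], dtype=np.float32)
toy_mask = np.array([0.0, 1.0, 1.0, 0.0, 1.0], dtype=np.float32)

# Zero out non-candidates, then push them to a large negative value so softmax
# assigns them (numerically) zero probability.
toy_masked = toy_logits * toy_mask + -10000.0 * (1.0 - toy_mask)
shifted = toy_masked - toy_masked.max()
toy_log_probs = shifted - np.log(np.exp(shifted).sum())

label_smoothing = 0.1
one_hot = np.eye(5, dtype=np.float32)[1]   # true candidate at index 1
num_candidates = toy_mask.sum()            # 3 candidate positions
smoothed = (1 - label_smoothing) * one_hot + label_smoothing / num_candidates
smoothed *= toy_mask                       # smoothing mass stays on candidates
toy_loss = -(smoothed * toy_log_probs).sum()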
Example 4
def _build_model(model_config, features, is_training, flags):
  """Build an ETC model for OpenKP."""

  global_embedding_adder = None
  long_embedding_adder = None

  # Create `global_embedding_adder` if using visual features.
  if flags.use_visual_features_in_global or flags.use_visual_features_in_long:
    global_embedding_adder = _create_global_visual_feature_embeddings(
        model_config, features, flags)

  if flags.use_visual_features_in_long:
    # Create `long_embedding_adder` based on `global_embedding_adder`
    long_embedding_adder = gather_global_embeddings_to_long(
        global_embedding_adder, features['long_vdom_idx'])

  if not flags.use_visual_features_in_global:
    global_embedding_adder = None

  model = modeling.EtcModel(
      config=model_config,
      is_training=is_training,
      use_one_hot_relative_embeddings=flags.use_tpu)

  model_inputs = dict(
      token_ids=features['long_token_ids'],
      global_token_ids=features['global_token_ids'],
      long_embedding_adder=long_embedding_adder,
      global_embedding_adder=global_embedding_adder)
  for field in attr.fields(input_utils.GlobalLocalTransformerSideInputs):
    model_inputs[field.name] = features[field.name]

  long_output, _ = model(**model_inputs)

  word_embeddings_unnormalized = batch_segment_sum_embeddings(
      long_embeddings=long_output,
      long_word_idx=features['long_word_idx'],
      long_input_mask=features['long_input_mask'])
  word_emb_layer_norm = tf.keras.layers.LayerNormalization(
      axis=-1, epsilon=1e-12, name='word_emb_layer_norm')
  word_embeddings = word_emb_layer_norm(word_embeddings_unnormalized)

  ngram_logit_list = []
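  # Score each n-gram (n = 1..kp_max_length) of words as a keyphrase candidate:
  # a kernel-size-n convolution over the word embeddings yields one logit per
  # n-gram start position.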
  for i in range(flags.kp_max_length):
    conv = tf.keras.layers.Conv1D(
        filters=model_config.hidden_size,
        kernel_size=i + 1,
        padding='valid',
        activation=tensor_utils.get_activation('gelu'),
        kernel_initializer=tf.keras.initializers.TruncatedNormal(
            stddev=0.02 / math.sqrt(i + 1)),
        name=f'{i + 1}gram_conv')
    layer_norm = tf.keras.layers.LayerNormalization(
        axis=-1, epsilon=1e-12, name=f'{i + 1}gram_layer_norm')

    logit_dense = tf.keras.layers.Dense(
        units=1,
        activation=None,
        use_bias=False,
        kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
        name=f'logit_dense{i}')
    # [batch_size, long_max_length - i]
    unpadded_logits = tf.squeeze(
        logit_dense(layer_norm(conv(word_embeddings))), axis=-1)

    # Pad to the right to get back to `long_max_length`.
    padded_logits = tf.pad(unpadded_logits, paddings=[[0, 0], [0, i]])

    # Padding logits should be ignored, so we make a large negative mask adder
    # for them.
    shifted_word_mask = tf.cast(
        tensor_utils.shift_elements_right(
            features['long_word_input_mask'], axis=-1, amount=-i),
        dtype=padded_logits.dtype)
    mask_adder = -10000.0 * (1.0 - shifted_word_mask)

    ngram_logit_list.append(padded_logits * shifted_word_mask + mask_adder)

  # [batch_size, kp_max_length, long_max_length]
  ngram_logits = tf.stack(ngram_logit_list, axis=1)

  extra_model_losses = model.losses

  return ngram_logits, extra_model_losses
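# Shape sketch (illustration only) of the k-gram scoring head above, for k = 3:
# a 'valid' Conv1D with kernel_size=3 over [batch, long_max_length, hidden]
# yields one vector per 3-gram start position (long_max_length - 2 of them),
# and right-padding by k - 1 = 2 restores the original length before masking.
def _kgram_logits_shape_sketch():
  word_embeddings = tf.zeros([2, 10, 8])  # [batch=2, long_max_length=10, hidden=8]
  conv = tf.keras.layers.Conv1D(filters=8, kernel_size=3, padding='valid')
  logit_dense = tf.keras.layers.Dense(units=1, use_bias=False)
  unpadded_logits = tf.squeeze(logit_dense(conv(word_embeddings)), axis=-1)  # [2, 8]
  return tf.pad(unpadded_logits, paddings=[[0, 0], [0, 2]])                  # [2, 10]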
Example 5
    def __init__(self,
                 hidden_size: int,
                 num_hidden_layers: int,
                 num_attention_heads: int,
                 intermediate_size: Optional[int] = None,
                 hidden_act=tensor_utils.get_activation('gelu'),
                 hidden_dropout_prob: float = 0.1,
                 attention_probs_dropout_prob: float = 0.1,
                 initializer_range: float = 0.02,
                 relative_vocab_size: Optional[int] = None,
                 use_pre_activation_order: bool = False,
                 use_one_hot_lookup: bool = False,
                 name: Text = 'relative_transformer_layers',
                 **kwargs):
        """Init.

    Args:
      hidden_size: Size of the output hidden dimension.  Must match the input
        hidden dimension size.
      num_hidden_layers: Number of Transformer layers.  Each layer includes both
        an attention sublayer and a feed-forward sublayer.
      num_attention_heads: Number of attention heads. Must evenly divide
        `hidden_size`.
      intermediate_size: The size of the "intermediate" (i.e. feed-forward)
        layers. Defaults to 4 * hidden_size.
      hidden_act: The non-linear activation function in the intermediate layers.
      hidden_dropout_prob: The dropout probability for the attention and
        feed-forward residual blocks. Must be between 0.0 and 1.0.
      attention_probs_dropout_prob: Dropout probability for attention
        probabilities. Must be between 0.0 and 1.0.
      initializer_range: The standard deviation of the truncated normal
        initializer for initializing weight matrices.
      relative_vocab_size: Size of relative position vocabulary. If left
        unspecified, relative positions will be ignored for attention.
      use_pre_activation_order: If True, use "pre-activation" order for residual
        blocks (see ResidualBlock docstring).
      use_one_hot_lookup: Whether to use tf.one_hot for embedding lookup instead
        of tf.gather. Default is False, but setting to True may be more
        efficient on TPUs for vocab sizes that aren't too large. Currently this
        is only used during lookup of relative position embeddings.
      name: Name of the layer.
      **kwargs: Forwarded to super.
    """
        super(RelativeTransformerLayers, self).__init__(name=name, **kwargs)

        if intermediate_size is None:
            intermediate_size = 4 * hidden_size

        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.initializer_range = initializer_range
        self.relative_vocab_size = relative_vocab_size
        self.use_pre_activation_order = use_pre_activation_order
        self.use_one_hot_lookup = use_one_hot_lookup

        # TODO(jainslie): When using pre-activation order, the recommendation
        # from https://arxiv.org/abs/1904.10509 is to scale some of the
        # initialization by 1 / sqrt(2 * num_hidden_layers).  Add logic
        # to do this scaling (maybe within ResidualBlock rather than through
        # initialization).
        self.initializer = tf.keras.initializers.TruncatedNormal(
            stddev=initializer_range)

        self.attention_layers = []
        self.feed_forward_layers = []
        for i in range(num_hidden_layers):
            self.attention_layers.append(
                wrappers.ResidualBlock(
                    inner_layer=attention.RelativeAttention(
                        hidden_size=hidden_size,
                        num_heads=num_attention_heads,
                        relative_vocab_size=relative_vocab_size,
                        att_dropout_prob=attention_probs_dropout_prob,
                        initializer=self.initializer,
                        use_one_hot_lookup=use_one_hot_lookup),
                    dropout_probability=hidden_dropout_prob,
                    use_pre_activation_order=use_pre_activation_order,
                    name='attention_layer_%d' % i))
            self.feed_forward_layers.append(
                wrappers.ResidualBlock(
                    dropout_probability=hidden_dropout_prob,
                    use_pre_activation_order=use_pre_activation_order,
                    inner_intermediate_size=intermediate_size,
                    inner_activation=hidden_act,
                    inner_kernel_initializer=self.initializer,
                    name='feed_forward_layer_%d' % i))
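# A minimal sketch (not the layer's actual `call`) of how the blocks built
# above are typically applied: each of the `num_hidden_layers` layers runs an
# attention residual block followed by a feed-forward residual block. The real
# forward pass also threads the attention mask and relative attention ids into
# the attention block.
def _apply_relative_transformer_sketch(layer, inputs, training=None):
    output = inputs
    for att_block, ff_block in zip(layer.attention_layers,
                                   layer.feed_forward_layers):
        output = att_block(output, training=training)
        output = ff_block(output, training=training)
    return output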
Example 6
    def __init__(self,
                 long_hidden_size: int,
                 global_hidden_size: int,
                 num_hidden_layers: int,
                 num_attention_heads: int,
                 local_radius: int,
                 att_size_per_head: Optional[int] = None,
                 long_intermediate_size: Optional[int] = None,
                 global_intermediate_size: Optional[int] = None,
                 hidden_act=tensor_utils.get_activation('gelu'),
                 hidden_dropout_prob: float = 0.1,
                 attention_probs_dropout_prob: float = 0.1,
                 initializer_range: float = 0.02,
                 relative_vocab_size: Optional[int] = None,
                 share_feed_forward_params: bool = True,
                 share_kv_projections: bool = False,
                 share_qkv_projections: bool = True,
                 share_att_output_projection: bool = False,
                 use_pre_activation_order: bool = False,
                 use_one_hot_lookup: bool = False,
                 grad_checkpointing_period: int = 0,
                 name: Text = 'global_local_transformer_layers',
                 **kwargs):
        """Init.

    Args:
      long_hidden_size: Size of the long input hidden dimension.
      global_hidden_size: Size of the global input hidden dimension. If this is
        different from `long_hidden_size`, you must turn off parameter sharing
        between long and global operations. In particular, the following
        sharing options which default to True must be set to False instead:
          `share_feed_forward_params`
          `share_qkv_projections`
      num_hidden_layers: Number of Transformer layers.  Each layer includes both
        an attention sublayer and a feed-forward sublayer.
      num_attention_heads: Number of attention heads for global-local attention.
        Must evenly divide both `global_hidden_size` and `long_hidden_size`
        unless `att_size_per_head` is specified.
      local_radius: How many tokens to the left/right for long input tokens to
        locally self-attend to. For example, a value of 1 would allow each token
        to only attend to 1 token to the left and 1 token to the right of it.
      att_size_per_head: Size of attention query/key/value vectors per head.
        By default this will be `long_hidden_size / num_attention_heads`, so
        `num_attention_heads` must evenly divide `long_hidden_size` in this
        case.
      long_intermediate_size: The size of the "intermediate" (i.e. feed-forward)
        layers for long input. Defaults to 4 * long_hidden_size.
      global_intermediate_size: The size of the "intermediate" (i.e.
        feed-forward) layers for global input. Defaults to 4 *
        global_hidden_size. Must not be different from `long_intermediate_size`
        if `share_feed_forward_params` is True (the default).
      hidden_act: The non-linear activation function in the intermediate layers.
      hidden_dropout_prob: The dropout probability for the attention and
        feed-forward residual blocks. Must be between 0.0 and 1.0.
      attention_probs_dropout_prob: Dropout probability for attention
        probabilities. Must be between 0.0 and 1.0.
      initializer_range: The standard deviation of the truncated normal
        initializer for initializing all weight matrices.
      relative_vocab_size: Size of relative position vocabulary. If left
        unspecified, relative positions will be ignored for attention.
      share_feed_forward_params: If True (the default), we share the same
        fully connected feed-forward parameters for the long and global inputs.
      share_kv_projections: If True, key and value projections will be shared
        between long-to-long and long-to-global components, as well as between
        global-to-global and global-to-long components. This results in 2 key
        projections per layer instead of 4 (and similarly for value
        projections). Note that if `share_qkv_projections` is True, then
        `share_kv_projections` is completely ignored since the former results
        in even more sharing.
      share_qkv_projections: If True (the default), all 4 attention operations
        (long-to-long, global-to-global, long-to-global, and global-to-long)
        will share the same query, key, and value projections. The 3 projections
        will still be different from each other and different per layer.
      share_att_output_projection: If True, all 4 attention operations
        (long-to-long, global-to-global, long-to-global, and global-to-long)
        will share the same output projection per layer.
      use_pre_activation_order: If True, use "pre-activation" order for residual
        blocks (see ResidualBlock docstring).
      use_one_hot_lookup: Whether to use tf.one_hot for embedding lookup instead
        of tf.gather. Default is False, but setting to True may be more
        efficient on TPUs for vocab sizes that aren't too large. Currently this
        is only used during lookup of relative position embeddings.
      grad_checkpointing_period: How often to checkpoint activations. The
        default of 0 stores all activations. If greater than 0, activations are
        recomputed as necessary when calculating gradients to save memory. As an
        optimization, we avoid recomputing the last `grad_checkpointing_period`
        layers, so larger values result in less computational overhead but
        reduced memory savings. Using a value of `1` results in potentially the
        greatest memory savings but with the highest recompute cost.
      name: Name of the layer.
      **kwargs: Forwarded to super.
    """
        super(GlobalLocalTransformerLayers, self).__init__(name=name, **kwargs)

        if long_intermediate_size is None:
            long_intermediate_size = 4 * long_hidden_size
        if global_intermediate_size is None:
            global_intermediate_size = 4 * global_hidden_size

        (att_size_per_head, long_total_att_size,
         global_total_att_size) = self._resolve_att_sizes(
             att_size_per_head=att_size_per_head,
             long_hidden_size=long_hidden_size,
             global_hidden_size=global_hidden_size,
             num_attention_heads=num_attention_heads)

        self.long_hidden_size = long_hidden_size
        self.global_hidden_size = global_hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.local_radius = local_radius
        self.att_size_per_head = att_size_per_head
        self.long_intermediate_size = long_intermediate_size
        self.global_intermediate_size = global_intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.initializer_range = initializer_range
        self.relative_vocab_size = relative_vocab_size
        self.share_feed_forward_params = share_feed_forward_params
        self.share_kv_projections = share_kv_projections
        self.share_qkv_projections = share_qkv_projections
        self.share_att_output_projection = share_att_output_projection
        self.use_pre_activation_order = use_pre_activation_order
        self.use_one_hot_lookup = use_one_hot_lookup
        self.grad_checkpointing_period = grad_checkpointing_period

        self._long_total_att_size = long_total_att_size
        self._global_total_att_size = global_total_att_size

        self._validate_init_parameters()

        # TODO(jainslie): When using pre-activation order, the recommendation
        # from https://arxiv.org/abs/1904.10509 is to scale some of the
        # initialization by 1 / sqrt(2 * num_hidden_layers).  Add logic
        # to do this scaling (maybe within ResidualBlock rather than through
        # initialization).
        self.initializer = tf.keras.initializers.TruncatedNormal(
            stddev=initializer_range)

        self.fused_att_layers = []
        self.long_feed_forward_layers = []
        self.global_feed_forward_layers = []

        for i in range(num_hidden_layers):
            normalization_layers = [
                tf.keras.layers.LayerNormalization(axis=-1,
                                                   epsilon=1e-12,
                                                   name='layer_norm_0'),
                tf.keras.layers.LayerNormalization(axis=-1,
                                                   epsilon=1e-12,
                                                   name='layer_norm_1')
            ]
            self.fused_att_layers.append(
                wrappers.ResidualBlock(
                    inner_layer=attention.FusedGlobalLocalAttention(
                        long_hidden_size=long_hidden_size,
                        global_hidden_size=global_hidden_size,
                        num_heads=num_attention_heads,
                        local_radius=local_radius,
                        long_total_att_size=long_total_att_size,
                        global_total_att_size=global_total_att_size,
                        relative_vocab_size=relative_vocab_size,
                        att_dropout_prob=attention_probs_dropout_prob,
                        initializer=self.initializer,
                        share_kv_projections=share_kv_projections,
                        share_qkv_projections=share_qkv_projections,
                        share_att_output_projection=share_att_output_projection,
                        use_one_hot_lookup=use_one_hot_lookup),
                    normalization_layer=normalization_layers,
                    dropout_probability=self.hidden_dropout_prob,
                    use_pre_activation_order=self.use_pre_activation_order,
                    name='fused_att_layer_%d' % i))

            if share_feed_forward_params:
                feed_forward_layer = wrappers.ResidualBlock(
                    dropout_probability=hidden_dropout_prob,
                    use_pre_activation_order=use_pre_activation_order,
                    inner_intermediate_size=long_intermediate_size,
                    inner_activation=hidden_act,
                    inner_kernel_initializer=self.initializer,
                    name='feed_forward_layer_%d' % i)
                feed_forward_layer.build(
                    tf.TensorShape([None, long_hidden_size]))
                self.long_feed_forward_layers.append(feed_forward_layer)
                # Create separate layer to generate a new dropout seed.
                self.global_feed_forward_layers.append(
                    wrappers.ResidualBlock(
                        dropout_probability=hidden_dropout_prob,
                        use_pre_activation_order=use_pre_activation_order,
                        inner_layer=feed_forward_layer.inner_layer,
                        normalization_layer=(
                            feed_forward_layer.normalization_layers),
                        name='global_feed_forward_layer_%d' % i))
            else:
                self.long_feed_forward_layers.append(
                    wrappers.ResidualBlock(
                        dropout_probability=hidden_dropout_prob,
                        use_pre_activation_order=use_pre_activation_order,
                        inner_intermediate_size=long_intermediate_size,
                        inner_activation=hidden_act,
                        inner_kernel_initializer=self.initializer,
                        name='long_feed_forward_layer_%d' % i))
                self.global_feed_forward_layers.append(
                    wrappers.ResidualBlock(
                        dropout_probability=hidden_dropout_prob,
                        use_pre_activation_order=use_pre_activation_order,
                        inner_intermediate_size=global_intermediate_size,
                        inner_activation=hidden_act,
                        inner_kernel_initializer=self.initializer,
                        name='global_feed_forward_layer_%d' % i))
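# Hypothetical usage (all numbers made up for illustration): per the docstring
# above, when the long and global hidden sizes differ, the sharing options that
# default to True must be switched off explicitly.
_sketch_layers = GlobalLocalTransformerLayers(
    long_hidden_size=768,
    global_hidden_size=512,
    num_hidden_layers=12,
    num_attention_heads=8,  # evenly divides both 768 and 512
    local_radius=84,
    relative_vocab_size=32,
    share_feed_forward_params=False,
    share_qkv_projections=False)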