Example #1
def get_nhnet_layers(params: configs.NHNetConfig):
  """Creates a Mult-doc encoder/decoder.

  Args:
    params: ParamsDict.

  Returns:
    Two Keras layers: bert_model_layer and decoder_layer.
  """
  input_ids = tf.keras.layers.Input(
      shape=(None,), name="input_ids", dtype=tf.int32)
  input_mask = tf.keras.layers.Input(
      shape=(None,), name="input_mask", dtype=tf.int32)
  segment_ids = tf.keras.layers.Input(
      shape=(None,), name="segment_ids", dtype=tf.int32)
  bert_config = utils.get_bert_config_from_params(params)
  bert_model_layer = networks.TransformerEncoder(
      vocab_size=bert_config.vocab_size,
      hidden_size=bert_config.hidden_size,
      num_layers=bert_config.num_hidden_layers,
      num_attention_heads=bert_config.num_attention_heads,
      intermediate_size=bert_config.intermediate_size,
      activation=tf_utils.get_activation(bert_config.hidden_act),
      dropout_rate=bert_config.hidden_dropout_prob,
      attention_dropout_rate=bert_config.attention_probs_dropout_prob,
      sequence_length=None,
      max_sequence_length=bert_config.max_position_embeddings,
      type_vocab_size=bert_config.type_vocab_size,
      initializer=tf.keras.initializers.TruncatedNormal(
          stddev=bert_config.initializer_range),
      return_all_encoder_outputs=True,
      name="bert_encoder")
  bert_model_layer([input_ids, input_mask, segment_ids])

  # Placeholder inputs for the decoder; input_ids is redefined here with a
  # per-document shape of [batch, num_docs, seq_len].
  input_ids = tf.keras.layers.Input(
      shape=(None, None), name="input_ids", dtype=tf.int32)
  all_encoder_outputs = tf.keras.layers.Input((None, None, params.hidden_size),
                                              dtype=tf.float32)
  target_ids = tf.keras.layers.Input(
      shape=(None,), name="target_ids", dtype=tf.int32)
  doc_attention_probs = tf.keras.layers.Input(
      (params.num_decoder_attn_heads, None, None), dtype=tf.float32)
  # pylint: disable=protected-access
  decoder_layer = decoder.Decoder(params, bert_model_layer._embedding_layer)
  # pylint: enable=protected-access
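  # Bias layers: the cross-attention bias masks encoder-side padding, while
  # the decoder self-attention bias applies the causal (look-ahead) mask.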
  cross_attention_bias = decoder.AttentionBias(bias_type="multi_cross")(
      input_ids)
  self_attention_bias = decoder.AttentionBias(bias_type="decoder_self")(
      target_ids)
  decoder_inputs = dict(
      attention_bias=cross_attention_bias,
      self_attention_bias=self_attention_bias,
      target_ids=target_ids,
      all_encoder_outputs=all_encoder_outputs,
      doc_attention_probs=doc_attention_probs)
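  # Dummy call to build the decoder's variables; the output is discarded.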
  _ = decoder_layer(decoder_inputs)

  return bert_model_layer, decoder_layer
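
A minimal usage sketch (not part of the original example), assuming
configs.NHNetConfig can be constructed with its defaults; the toy shapes
below are illustrative, not taken from the library:

import tensorflow as tf

params = configs.NHNetConfig()  # assumption: default-constructible config
bert_model_layer, decoder_layer = get_nhnet_layers(params)

# Toy encoder batch: 6 flattened (batch * num_docs) sequences of length 16.
ids = tf.ones((6, 16), dtype=tf.int32)
mask = tf.ones_like(ids)
segments = tf.zeros_like(ids)
all_encoder_outputs, cls_output = bert_model_layer([ids, mask, segments])
print(len(all_encoder_outputs))  # one output per transformer layer
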
Example #2
def get_bert2bert_layers(params: configs.BERT2BERTConfig):
  """Creates a Bert2Bert stem model and returns Bert encoder/decoder.

  We use a functional style to create the stem model because we need all
  layers to be built in order to restore variables in a customized way. The
  layers are called with placeholder inputs so that they are fully built.

  Args:
    params: ParamsDict.

  Returns:
    Two Keras layers: bert_model_layer and decoder_layer.
  """
  input_ids = tf.keras.layers.Input(
      shape=(None,), name="input_ids", dtype=tf.int32)
  input_mask = tf.keras.layers.Input(
      shape=(None,), name="input_mask", dtype=tf.int32)
  segment_ids = tf.keras.layers.Input(
      shape=(None,), name="segment_ids", dtype=tf.int32)
  target_ids = tf.keras.layers.Input(
      shape=(None,), name="target_ids", dtype=tf.int32)
  bert_config = utils.get_bert_config_from_params(params)
  bert_model_layer = networks.TransformerEncoder(
      vocab_size=bert_config.vocab_size,
      hidden_size=bert_config.hidden_size,
      num_layers=bert_config.num_hidden_layers,
      num_attention_heads=bert_config.num_attention_heads,
      intermediate_size=bert_config.intermediate_size,
      activation=tf_utils.get_activation(bert_config.hidden_act),
      dropout_rate=bert_config.hidden_dropout_prob,
      attention_dropout_rate=bert_config.attention_probs_dropout_prob,
      sequence_length=None,
      max_sequence_length=bert_config.max_position_embeddings,
      type_vocab_size=bert_config.type_vocab_size,
      initializer=tf.keras.initializers.TruncatedNormal(
          stddev=bert_config.initializer_range),
      return_all_encoder_outputs=True,
      name="bert_encoder")
  all_encoder_outputs, _ = bert_model_layer(
      [input_ids, input_mask, segment_ids])
  # pylint: disable=protected-access
  decoder_layer = decoder.Decoder(params, bert_model_layer._embedding_layer)
  # pylint: enable=protected-access
  cross_attention_bias = decoder.AttentionBias(bias_type="single_cross")(
      input_ids)
  self_attention_bias = decoder.AttentionBias(bias_type="decoder_self")(
      target_ids)
  decoder_inputs = dict(
      attention_bias=cross_attention_bias,
      self_attention_bias=self_attention_bias,
      target_ids=target_ids,
      all_encoder_outputs=all_encoder_outputs)
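  # Build the decoder's variables with a dummy call.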
  _ = decoder_layer(decoder_inputs)

  return bert_model_layer, decoder_layer
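
A similar sketch for the Bert2Bert case, again assuming the config is
default-constructible; the variable counts are inspected only to show that
the placeholder calls left both layers fully built:

import tensorflow as tf

params = configs.BERT2BERTConfig()  # assumption: default-constructible config
bert_model_layer, decoder_layer = get_bert2bert_layers(params)

# Both layers now have created variables, so a customized checkpoint
# restore (the stated reason for the functional-style build) can proceed.
print(len(bert_model_layer.trainable_variables))
print(len(decoder_layer.trainable_variables))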