def test_multi_doc_decoder(self):
        """Builds a multi-document `decoder.Decoder` and checks a forward pass.

        Wires the decoder to symbolic Keras inputs (encoder ids, target ids,
        per-document encoder outputs and doc attention probabilities), wraps it
        in a functional `tf.keras.Model`, checks the decoder's trainable-weight
        count, and runs the model on a zero-filled batch to check the output
        shape.
        """
        self._config = utils.get_test_params(cls=configs.NHNetConfig)
        batch_size = 2
        seq_length = 10
        num_docs = 5
        # Take sizes from the config rather than hard-coding them, so the fake
        # inputs and the expected output shape always agree with the Inputs.
        hidden_size = self._config.hidden_size
        num_heads = self._config.num_decoder_attn_heads

        encoder_input_ids = tf.keras.layers.Input(shape=(num_docs, seq_length),
                                                  name="encoder_input_ids",
                                                  dtype=tf.int32)
        target_ids = tf.keras.layers.Input(shape=(seq_length, ),
                                           name="target_ids",
                                           dtype=tf.int32)
        encoder_outputs = tf.keras.layers.Input(
            shape=(num_docs, seq_length, hidden_size),
            name="all_encoder_outputs",
            dtype=tf.float32)
        embedding_lookup = layers.OnDeviceEmbedding(
            vocab_size=self._config.vocab_size,
            embedding_width=hidden_size,
            initializer=tf.keras.initializers.TruncatedNormal(
                stddev=self._config.initializer_range),
            name="word_embeddings")
        doc_attention_probs = tf.keras.layers.Input(
            shape=(num_heads, seq_length, num_docs),
            name="doc_attention_probs",
            dtype=tf.float32)
        # Cross-attention bias is computed from the multi-doc encoder ids and
        # self-attention bias from the target ids.
        cross_attention_bias = decoder.AttentionBias(
            bias_type="multi_cross")(encoder_input_ids)
        self_attention_bias = decoder.AttentionBias(
            bias_type="decoder_self")(target_ids)

        inputs = dict(attention_bias=cross_attention_bias,
                      self_attention_bias=self_attention_bias,
                      target_ids=target_ids,
                      all_encoder_outputs=encoder_outputs,
                      doc_attention_probs=doc_attention_probs)

        decoder_layer = decoder.Decoder(self._config, embedding_lookup)
        outputs = decoder_layer(inputs)
        # The biases are computed inside the model graph, so only the raw
        # tensors are exposed as model inputs.
        model_inputs = dict(encoder_input_ids=encoder_input_ids,
                            target_ids=target_ids,
                            all_encoder_outputs=encoder_outputs,
                            doc_attention_probs=doc_attention_probs)
        model = tf.keras.Model(inputs=model_inputs,
                               outputs=outputs,
                               name="test")
        # NOTE(review): 30 is the expected trainable-weight count for the
        # config returned by get_test_params; presumably it tracks the decoder
        # size in that config — confirm if the test config changes.
        self.assertLen(decoder_layer.trainable_weights, 30)
        # Forward path.
        fake_inputs = {
            "encoder_input_ids":
            np.zeros((batch_size, num_docs, seq_length), dtype=np.int32),
            "target_ids":
            np.zeros((batch_size, seq_length), dtype=np.int32),
            "all_encoder_outputs":
            np.zeros((batch_size, num_docs, seq_length, hidden_size),
                     dtype=np.float32),
            "doc_attention_probs":
            np.zeros((batch_size, num_heads, seq_length, num_docs),
                     dtype=np.float32)
        }
        output_tensor = model(fake_inputs)
        self.assertEqual(output_tensor.shape,
                         (batch_size, seq_length, hidden_size))
Example #2
0
def get_bert2bert_layers(params: configs.BERT2BERTConfig):
    """Creates a Bert2Bert stem model and returns the Bert encoder and decoder.

    The stem model is built in functional style because every layer must be
    built before variables can be restored in a customized way. The layers
    are therefore called on placeholder (`tf.keras.layers.Input`) tensors so
    that they are fully built; the decoder's output from that call is
    discarded.

    Args:
      params: A `configs.BERT2BERTConfig`. It is converted to a BERT config via
        `utils.get_bert_config_from_params` for the encoder, and passed as-is
        to `decoder.Decoder`.

    Returns:
      A tuple `(bert_model_layer, decoder_layer)`: the `networks.BertEncoder`
      and the `decoder.Decoder` that was given the encoder's embedding layer.
    """
    # Symbolic placeholders with an unknown sequence length.
    input_ids = tf.keras.layers.Input(shape=(None, ),
                                      name="input_ids",
                                      dtype=tf.int32)
    input_mask = tf.keras.layers.Input(shape=(None, ),
                                       name="input_mask",
                                       dtype=tf.int32)
    segment_ids = tf.keras.layers.Input(shape=(None, ),
                                        name="segment_ids",
                                        dtype=tf.int32)
    target_ids = tf.keras.layers.Input(shape=(None, ),
                                       name="target_ids",
                                       dtype=tf.int32)
    bert_config = utils.get_bert_config_from_params(params)
    bert_model_layer = networks.BertEncoder(
        vocab_size=bert_config.vocab_size,
        hidden_size=bert_config.hidden_size,
        num_layers=bert_config.num_hidden_layers,
        num_attention_heads=bert_config.num_attention_heads,
        intermediate_size=bert_config.intermediate_size,
        activation=tf_utils.get_activation(bert_config.hidden_act),
        dropout_rate=bert_config.hidden_dropout_prob,
        attention_dropout_rate=bert_config.attention_probs_dropout_prob,
        max_sequence_length=bert_config.max_position_embeddings,
        type_vocab_size=bert_config.type_vocab_size,
        initializer=tf.keras.initializers.TruncatedNormal(
            stddev=bert_config.initializer_range),
        return_all_encoder_outputs=True,
        name="bert_encoder")
    # Calling the encoder builds it; only the first output is fed to the
    # decoder below.
    all_encoder_outputs, _ = bert_model_layer(
        [input_ids, input_mask, segment_ids])
    # The decoder reuses the encoder's (private) embedding layer.
    # pylint: disable=protected-access
    decoder_layer = decoder.Decoder(params, bert_model_layer._embedding_layer)
    # pylint: enable=protected-access
    cross_attention_bias = decoder.AttentionBias(
        bias_type="single_cross")(input_ids)
    self_attention_bias = decoder.AttentionBias(
        bias_type="decoder_self")(target_ids)
    decoder_inputs = dict(attention_bias=cross_attention_bias,
                          self_attention_bias=self_attention_bias,
                          target_ids=target_ids,
                          all_encoder_outputs=all_encoder_outputs)
    # Called only to build the decoder's variables; the output is unused.
    _ = decoder_layer(decoder_inputs)

    return bert_model_layer, decoder_layer
Example #3
0
def get_nhnet_layers(params: configs.NHNetConfig):
  """Creates a multi-doc encoder/decoder pair.

  Both layers are called once on placeholder (`tf.keras.layers.Input`)
  tensors so that their variables are built; the outputs of those calls are
  discarded.

  Args:
    params: A `configs.NHNetConfig`. It is converted to a BERT config via
      `utils.get_bert_config_from_params` for the encoder, and passed as-is to
      `decoder.Decoder`; `params.hidden_size` and
      `params.num_decoder_attn_heads` size the decoder placeholders.

  Returns:
    A tuple `(bert_model_layer, decoder_layer)`: the `networks.BertEncoder`
    and the `decoder.Decoder` that was given the encoder's embedding layer.
  """
  # Encoder placeholders: one id sequence per example, unknown length.
  input_ids = tf.keras.layers.Input(
      shape=(None,), name="input_ids", dtype=tf.int32)
  input_mask = tf.keras.layers.Input(
      shape=(None,), name="input_mask", dtype=tf.int32)
  segment_ids = tf.keras.layers.Input(
      shape=(None,), name="segment_ids", dtype=tf.int32)
  bert_config = utils.get_bert_config_from_params(params)
  bert_model_layer = networks.BertEncoder(
      vocab_size=bert_config.vocab_size,
      hidden_size=bert_config.hidden_size,
      num_layers=bert_config.num_hidden_layers,
      num_attention_heads=bert_config.num_attention_heads,
      intermediate_size=bert_config.intermediate_size,
      activation=tf_utils.get_activation(bert_config.hidden_act),
      dropout_rate=bert_config.hidden_dropout_prob,
      attention_dropout_rate=bert_config.attention_probs_dropout_prob,
      max_sequence_length=bert_config.max_position_embeddings,
      type_vocab_size=bert_config.type_vocab_size,
      initializer=tf.keras.initializers.TruncatedNormal(
          stddev=bert_config.initializer_range),
      return_all_encoder_outputs=True,
      name="bert_encoder")
  # Called only to build the encoder's variables; the output is unused.
  bert_model_layer([input_ids, input_mask, segment_ids])

  # Decoder placeholders. The ids carry an extra leading axis compared to the
  # encoder ids above — presumably (num_docs, seq_length) per example for the
  # "multi_cross" bias. They get a distinct Python name so they are not
  # confused with the single-doc encoder ids; the Keras layer name is
  # unchanged.
  doc_input_ids = tf.keras.layers.Input(
      shape=(None, None), name="input_ids", dtype=tf.int32)
  all_encoder_outputs = tf.keras.layers.Input((None, None, params.hidden_size),
                                              dtype=tf.float32)
  target_ids = tf.keras.layers.Input(
      shape=(None,), name="target_ids", dtype=tf.int32)
  doc_attention_probs = tf.keras.layers.Input(
      (params.num_decoder_attn_heads, None, None), dtype=tf.float32)
  # The decoder reuses the encoder's (private) embedding layer.
  # pylint: disable=protected-access
  decoder_layer = decoder.Decoder(params, bert_model_layer._embedding_layer)
  # pylint: enable=protected-access
  cross_attention_bias = decoder.AttentionBias(bias_type="multi_cross")(
      doc_input_ids)
  self_attention_bias = decoder.AttentionBias(bias_type="decoder_self")(
      target_ids)
  decoder_inputs = dict(
      attention_bias=cross_attention_bias,
      self_attention_bias=self_attention_bias,
      target_ids=target_ids,
      all_encoder_outputs=all_encoder_outputs,
      doc_attention_probs=doc_attention_probs)
  # Called only to build the decoder's variables; the output is unused.
  _ = decoder_layer(decoder_inputs)

  return bert_model_layer, decoder_layer
 def test_bert_decoder(self):
     """Builds a single-document `decoder.Decoder` and checks a forward pass.

     Wires the decoder to symbolic Keras inputs (encoder ids, target ids and
     encoder outputs), wraps it in a functional `tf.keras.Model`, checks the
     decoder's trainable-weight count, and runs the model on a zero-filled
     batch to check the output shape.
     """
     batch_size = 2
     seq_length = 10
     # Take the hidden size from the config rather than hard-coding it, so the
     # fake inputs and the expected output shape agree with the Inputs.
     hidden_size = self._config.hidden_size

     encoder_input_ids = tf.keras.layers.Input(shape=(seq_length, ),
                                               name="encoder_input_ids",
                                               dtype=tf.int32)
     target_ids = tf.keras.layers.Input(shape=(seq_length, ),
                                        name="target_ids",
                                        dtype=tf.int32)
     encoder_outputs = tf.keras.layers.Input(
         shape=(seq_length, hidden_size),
         name="all_encoder_outputs",
         dtype=tf.float32)
     embedding_lookup = layers.OnDeviceEmbedding(
         vocab_size=self._config.vocab_size,
         embedding_width=hidden_size,
         initializer=tf.keras.initializers.TruncatedNormal(
             stddev=self._config.initializer_range),
         name="word_embeddings")
     # Cross-attention bias is computed from the encoder ids and
     # self-attention bias from the target ids.
     cross_attention_bias = decoder.AttentionBias(
         bias_type="single_cross")(encoder_input_ids)
     self_attention_bias = decoder.AttentionBias(
         bias_type="decoder_self")(target_ids)
     inputs = dict(attention_bias=cross_attention_bias,
                   self_attention_bias=self_attention_bias,
                   target_ids=target_ids,
                   all_encoder_outputs=encoder_outputs)
     decoder_layer = decoder.Decoder(self._config, embedding_lookup)
     outputs = decoder_layer(inputs)
     # The biases are computed inside the model graph, so only the raw
     # tensors are exposed as model inputs.
     model_inputs = dict(encoder_input_ids=encoder_input_ids,
                         target_ids=target_ids,
                         all_encoder_outputs=encoder_outputs)
     model = tf.keras.Model(inputs=model_inputs,
                            outputs=outputs,
                            name="test")
     # NOTE(review): 30 is the expected trainable-weight count for the test
     # config in self._config; presumably it tracks the decoder size in that
     # config — confirm if the test config changes.
     self.assertLen(decoder_layer.trainable_weights, 30)
     # Forward path.
     fake_inputs = {
         "encoder_input_ids":
         np.zeros((batch_size, seq_length), dtype=np.int32),
         "target_ids":
         np.zeros((batch_size, seq_length), dtype=np.int32),
         "all_encoder_outputs":
         np.zeros((batch_size, seq_length, hidden_size), dtype=np.float32),
     }
     output_tensor = model(fake_inputs)
     self.assertEqual(output_tensor.shape,
                      (batch_size, seq_length, hidden_size))