Example #1
    def test_transformer_xl(self, two_stream, memory_length, reuse_length,
                            tie_attention_biases, state, mask, segment):
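        """Checks TransformerXL attention output and cached-memory shapes."""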
        batch_size, num_heads, head_size, seq_length = 2, 12, 64, 8
        hidden_size, num_predictions, inner_size = 24, 8, 12
        num_layers = 3

        data = create_mock_transformer_xl_data(include_biases=False,
                                               num_heads=num_heads,
                                               head_size=head_size,
                                               hidden_size=hidden_size,
                                               seq_length=seq_length,
                                               batch_size=batch_size,
                                               memory_length=memory_length,
                                               num_predictions=num_predictions,
                                               two_stream=two_stream,
                                               num_layers=num_layers,
                                               include_state=state,
                                               include_mask=mask,
                                               include_segment=segment)
        transformer_xl_layer = transformer_xl.TransformerXL(
            vocab_size=32000,
            num_layers=num_layers,
            head_size=head_size,
            hidden_size=hidden_size,
            num_attention_heads=num_heads,
            inner_size=inner_size,
            dropout_rate=0.,
            attention_dropout_rate=0.,
            initializer=tf.keras.initializers.RandomNormal(stddev=0.1),
            two_stream=two_stream,
            tie_attention_biases=tie_attention_biases,
            memory_length=memory_length,
            reuse_length=reuse_length,
            inner_activation="relu")
        attention_output, cached_memory_states = transformer_xl_layer(**data)
        if two_stream:
            self.assertEqual(attention_output.shape,
                             [batch_size, num_predictions, hidden_size])
        else:
            self.assertEqual(attention_output.shape,
                             [batch_size, seq_length, hidden_size])
        self.assertEqual(len(cached_memory_states), num_layers)
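
The assertions above encode a simple shape contract. As an illustration only (not part of the original test), that contract can be written as a standalone helper:

def expected_transformer_xl_shapes(batch_size, seq_length, num_predictions,
                                   hidden_size, num_layers, two_stream):
  """Output length follows the query stream in two-stream mode and the
  content stream otherwise; one memory tensor is cached per layer."""
  length = num_predictions if two_stream else seq_length
  return [batch_size, length, hidden_size], num_layers

# With the test's settings (batch_size=2, seq_length=8, num_predictions=8,
# hidden_size=24, num_layers=3):
assert expected_transformer_xl_shapes(2, 8, 8, 24, 3, two_stream=False) == (
    [2, 8, 24], 3)
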
Example #2
  def test_get_config(self):
    """Checks the TransformerXL get_config/from_config round trip."""
    transformer_xl_layer = transformer_xl.TransformerXL(
        vocab_size=32000,
        num_layers=12,
        hidden_size=36,
        head_size=12,
        num_attention_heads=12,
        inner_size=12,
        dropout_rate=0.,
        attention_dropout_rate=0.,
        initializer=tf.keras.initializers.RandomNormal(stddev=0.1),
        two_stream=False,
        tie_attention_biases=True,
        memory_length=0,
        reuse_length=0,
        inner_activation="relu")
    transformer_xl_config = transformer_xl_layer.get_config()
    new_transformer_xl = transformer_xl.TransformerXL.from_config(
        transformer_xl_config)
    self.assertEqual(transformer_xl_config,
                     new_transformer_xl.get_config())
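
The round trip above is the standard Keras config-serialization contract. A minimal sketch of the same pattern on a stock Keras layer (tf.keras.layers.Dense standing in for TransformerXL, for illustration only):

import tensorflow as tf

# get_config() returns a plain dict; from_config() rebuilds an equivalent
# layer whose own config matches the original.
dense = tf.keras.layers.Dense(
    units=36,
    activation="relu",
    kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.1))
config = dense.get_config()
rebuilt = tf.keras.layers.Dense.from_config(config)
assert config == rebuilt.get_config()
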
Example #3
  def __init__(self,
               vocab_size,
               num_layers,
               hidden_size,
               num_attention_heads,
               head_size,
               inner_size,
               dropout_rate,
               attention_dropout_rate,
               attention_type,
               bi_data,
               initializer,
               two_stream=False,
               tie_attention_biases=True,
               memory_length=None,
               clamp_length=-1,
               reuse_length=None,
               inner_activation="relu",
               use_cls_mask=False,
               embedding_width=None,
               **kwargs):
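    """Initializes XLNetBase.

    Builds the word embedding, dropout, relative position encoding and
    Transformer-XL sublayers from the given hyperparameters.
    """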
    super(XLNetBase, self).__init__(**kwargs)

    self._vocab_size = vocab_size
    self._initializer = initializer
    self._attention_type = attention_type
    self._num_layers = num_layers
    self._hidden_size = hidden_size
    self._num_attention_heads = num_attention_heads
    self._head_size = head_size
    self._inner_size = inner_size
    self._inner_activation = inner_activation
    self._dropout_rate = dropout_rate
    self._attention_dropout_rate = attention_dropout_rate
    self._tie_attention_biases = tie_attention_biases
    self._two_stream = two_stream

    self._memory_length = memory_length
    self._reuse_length = reuse_length
    self._bi_data = bi_data
    self._clamp_length = clamp_length
    self._use_cls_mask = use_cls_mask

    self._segment_embedding = None
    self._mask_embedding = None
    self._embedding_width = embedding_width

    if embedding_width is None:
      embedding_width = hidden_size

    # Word embedding lookup table (kept in float32); its width defaults to
    # hidden_size when embedding_width is not set.
    self._embedding_layer = layers.OnDeviceEmbedding(
        vocab_size=self._vocab_size,
        embedding_width=embedding_width,
        initializer=self._initializer,
        dtype=tf.float32,
        name="word_embedding")
    self._dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)

    self.embedding_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
    self.position_encoding = RelativePositionEncoding(self._hidden_size)

    # Stack of Transformer-XL blocks producing the final hidden states and
    # the per-layer cached memories.
    self._transformer_xl = transformer_xl.TransformerXL(
        vocab_size=vocab_size,
        num_layers=num_layers,
        hidden_size=hidden_size,
        num_attention_heads=num_attention_heads,
        head_size=head_size,
        inner_size=inner_size,
        dropout_rate=dropout_rate,
        attention_dropout_rate=attention_dropout_rate,
        initializer=initializer,
        two_stream=two_stream,
        tie_attention_biases=tie_attention_biases,
        memory_length=memory_length,
        reuse_length=reuse_length,
        inner_activation=inner_activation,
        name="transformer_xl")