  def create_layer(self,
                   vocab_size,
                   hidden_size,
                   output='predictions',
                   xformer_stack=None):
    # First, create a transformer stack that we can use to get the LM's
    # vocabulary weight.
    if xformer_stack is None:
      xformer_stack = bert_encoder.BertEncoder(
          vocab_size=vocab_size,
          num_layers=1,
          hidden_size=hidden_size,
          num_attention_heads=4,
      )

    # Create a MaskedLM layer from the transformer stack.
    test_layer = masked_lm.MaskedLM(
        embedding_table=xformer_stack.get_embedding_table(), output=output)
    return test_layer
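
This first variant relies on modules the excerpt never imports. Below is a minimal, self-contained sketch of how the returned layer might be exercised; the import paths (TF Model Garden / tf-models-official layout), the two-argument call convention, and the output shape are assumptions, not confirmed by the snippet.

import tensorflow as tf

# Assumed import paths; adjust to wherever bert_encoder and masked_lm
# live in your checkout.
from official.nlp.modeling.layers import masked_lm
from official.nlp.modeling.networks import bert_encoder

vocab_size, sequence_length, hidden_size, num_predictions = 100, 32, 64, 21

# Build the single-layer encoder exactly as create_layer does, then reuse
# its embedding table for the masked-LM head.
encoder = bert_encoder.BertEncoder(
    vocab_size=vocab_size,
    num_layers=1,
    hidden_size=hidden_size,
    num_attention_heads=4,
)
layer = masked_lm.MaskedLM(
    embedding_table=encoder.get_embedding_table(), output='predictions')

# Invoke the layer on symbolic inputs: encoder outputs plus the token
# positions whose identities should be predicted.
sequence_data = tf.keras.Input(shape=(sequence_length, hidden_size))
masked_positions = tf.keras.Input(shape=(num_predictions,), dtype=tf.int32)
predictions = layer(sequence_data, masked_positions)
# With output='predictions' the result should be log-probabilities of
# shape (batch, num_predictions, vocab_size); output='logits' would skip
# the final log-softmax.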
Example 2
  def create_layer(self,
                   vocab_size,
                   sequence_length,
                   hidden_size,
                   output='predictions',
                   xformer_stack=None):
    # First, create a transformer stack that we can use to get the LM's
    # vocabulary weight.
    if xformer_stack is None:
      xformer_stack = transformer_encoder.TransformerEncoder(
          vocab_size=vocab_size,
          num_layers=1,
          sequence_length=sequence_length,
          hidden_size=hidden_size,
          num_attention_heads=4,
      )

    # Create a MaskedLM layer from the transformer stack.
    test_layer = masked_lm.MaskedLM(
        embedding_table=xformer_stack.get_embedding_table(),
        output=output)
    return test_layer

  def test_layer_creation(self):
    vocab_size = 100
    sequence_length = 32
    hidden_size = 64
    num_predictions = 21
    test_layer = self.create_layer(
        vocab_size=vocab_size,
        sequence_length=sequence_length,
        hidden_size=hidden_size)
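
The excerpt cuts off here, before the test asserts anything. A hedged sketch of the shape check that would naturally complete it, assuming the same call convention as above (the expected shape is an assumption, not taken from this snippet):

    # Hedged completion sketch; not part of the original excerpt.
    lm_input_tensor = tf.keras.Input(shape=(sequence_length, hidden_size))
    masked_positions = tf.keras.Input(
        shape=(num_predictions,), dtype=tf.int32)
    output = test_layer(lm_input_tensor, masked_positions)
    # With output='predictions', each masked position should map to a
    # distribution over the vocabulary.
    expected_output_shape = [None, num_predictions, vocab_size]
    self.assertEqual(expected_output_shape, output.shape.as_list())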