# TransformerEncoder.__init__ from the TF Model Garden
# (official.nlp.modeling.networks). The snippets in this listing assume
# `import tensorflow as tf` plus the Model Garden `layers` and `activations`
# modules are in scope.
def __init__(
            self,
            vocab_size,
            hidden_size=768,
            num_layers=12,
            num_attention_heads=12,
            sequence_length=512,
            max_sequence_length=None,
            type_vocab_size=16,
            intermediate_size=3072,
            activation=activations.gelu,
            dropout_rate=0.1,
            attention_dropout_rate=0.1,
            initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
            return_all_encoder_outputs=False,
            output_range=None,
            **kwargs):
        activation = tf.keras.activations.get(activation)
        initializer = tf.keras.initializers.get(initializer)

        if not max_sequence_length:
            max_sequence_length = sequence_length
        # Disable Keras attribute tracking so layers can be stored on `self`
        # before the functional super().__init__ call at the end.
        self._self_setattr_tracking = False
        self._config_dict = {
            'vocab_size': vocab_size,
            'hidden_size': hidden_size,
            'num_layers': num_layers,
            'num_attention_heads': num_attention_heads,
            'sequence_length': sequence_length,
            'max_sequence_length': max_sequence_length,
            'type_vocab_size': type_vocab_size,
            'intermediate_size': intermediate_size,
            'activation': tf.keras.activations.serialize(activation),
            'dropout_rate': dropout_rate,
            'attention_dropout_rate': attention_dropout_rate,
            'initializer': tf.keras.initializers.serialize(initializer),
            'return_all_encoder_outputs': return_all_encoder_outputs,
            'output_range': output_range,
        }

        word_ids = tf.keras.layers.Input(shape=(sequence_length, ),
                                         dtype=tf.int32,
                                         name='input_word_ids')
        mask = tf.keras.layers.Input(shape=(sequence_length, ),
                                     dtype=tf.int32,
                                     name='input_mask')
        type_ids = tf.keras.layers.Input(shape=(sequence_length, ),
                                         dtype=tf.int32,
                                         name='input_type_ids')

        self._embedding_layer = layers.OnDeviceEmbedding(
            vocab_size=vocab_size,
            embedding_width=hidden_size,
            initializer=initializer,
            name='word_embeddings')
        word_embeddings = self._embedding_layer(word_ids)

        # Always uses dynamic slicing for simplicity.
        self._position_embedding_layer = layers.PositionEmbedding(
            initializer=initializer,
            use_dynamic_slicing=True,
            max_sequence_length=max_sequence_length,
            name='position_embedding')
        position_embeddings = self._position_embedding_layer(word_embeddings)

        # `use_one_hot` implements the lookup as a one-hot matmul, which tends
        # to be faster than a gather for small vocabularies (e.g. on TPUs).
        type_embeddings = (layers.OnDeviceEmbedding(
            vocab_size=type_vocab_size,
            embedding_width=hidden_size,
            initializer=initializer,
            use_one_hot=True,
            name='type_embeddings')(type_ids))

        embeddings = tf.keras.layers.Add()(
            [word_embeddings, position_embeddings, type_embeddings])
        embeddings = (tf.keras.layers.LayerNormalization(
            name='embeddings/layer_norm',
            axis=-1,
            epsilon=1e-12,
            dtype=tf.float32)(embeddings))
        embeddings = (tf.keras.layers.Dropout(rate=dropout_rate)(embeddings))

        self._transformer_layers = []
        data = embeddings
        attention_mask = layers.SelfAttentionMask()([data, mask])
        encoder_outputs = []
        for i in range(num_layers):
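            # Note (inferred from the Transformer layer's output_range
            # argument): when set, only the last layer is affected, and it
            # emits just the first `output_range` sequence positions.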
            if i == num_layers - 1 and output_range is not None:
                transformer_output_range = output_range
            else:
                transformer_output_range = None
            layer = layers.Transformer(
                num_attention_heads=num_attention_heads,
                intermediate_size=intermediate_size,
                intermediate_activation=activation,
                dropout_rate=dropout_rate,
                attention_dropout_rate=attention_dropout_rate,
                output_range=transformer_output_range,
                kernel_initializer=initializer,
                name='transformer/layer_%d' % i)
            self._transformer_layers.append(layer)
            data = layer([data, attention_mask])
            encoder_outputs.append(data)

        # BERT-style pooling: take the first ([CLS]) token's hidden state and
        # pass it through a dense tanh transform.
        first_token_tensor = (
            tf.keras.layers.Lambda(lambda x: tf.squeeze(x[:, 0:1, :], axis=1))(
                encoder_outputs[-1]))
        self._pooler_layer = tf.keras.layers.Dense(
            units=hidden_size,
            activation='tanh',
            kernel_initializer=initializer,
            name='pooler_transform')
        cls_output = self._pooler_layer(first_token_tensor)

        if return_all_encoder_outputs:
            outputs = [encoder_outputs, cls_output]
        else:
            outputs = [encoder_outputs[-1], cls_output]

        super(TransformerEncoder,
              self).__init__(inputs=[word_ids, mask, type_ids],
                             outputs=outputs,
                             **kwargs)
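
# A minimal usage sketch (not from the original listing); the class name
# follows the super() call above, and the import path is assumed:
#   from official.nlp.modeling.networks import TransformerEncoder
encoder = TransformerEncoder(vocab_size=30522, sequence_length=128)
word_ids = tf.ones((2, 128), dtype=tf.int32)    # batch of 2, length 128
mask = tf.ones((2, 128), dtype=tf.int32)        # no padded positions
type_ids = tf.zeros((2, 128), dtype=tf.int32)   # a single segment
sequence_output, cls_output = encoder([word_ids, mask, type_ids])
# With the defaults above: sequence_output is (2, 128, 768), cls_output (2, 768).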

# AlbertTransformerEncoder.__init__ (per the super() call below): same overall
# flow, but with a factorized embedding (embedding_width != hidden_size) and a
# single shared Transformer layer.
def __init__(
            self,
            vocab_size,
            embedding_width=128,
            hidden_size=768,
            num_layers=12,
            num_attention_heads=12,
            sequence_length=512,
            max_sequence_length=None,
            type_vocab_size=16,
            intermediate_size=3072,
            activation=activations.gelu,
            dropout_rate=0.1,
            attention_dropout_rate=0.1,
            initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
            **kwargs):
        activation = tf.keras.activations.get(activation)
        initializer = tf.keras.initializers.get(initializer)

        if not max_sequence_length:
            max_sequence_length = sequence_length
        self._self_setattr_tracking = False
        self._config_dict = {
            'vocab_size': vocab_size,
            'embedding_width': embedding_width,
            'hidden_size': hidden_size,
            'num_layers': num_layers,
            'num_attention_heads': num_attention_heads,
            'sequence_length': sequence_length,
            'max_sequence_length': max_sequence_length,
            'type_vocab_size': type_vocab_size,
            'intermediate_size': intermediate_size,
            'activation': tf.keras.activations.serialize(activation),
            'dropout_rate': dropout_rate,
            'attention_dropout_rate': attention_dropout_rate,
            'initializer': tf.keras.initializers.serialize(initializer),
        }

        word_ids = tf.keras.layers.Input(shape=(sequence_length, ),
                                         dtype=tf.int32,
                                         name='input_word_ids')
        mask = tf.keras.layers.Input(shape=(sequence_length, ),
                                     dtype=tf.int32,
                                     name='input_mask')
        type_ids = tf.keras.layers.Input(shape=(sequence_length, ),
                                         dtype=tf.int32,
                                         name='input_type_ids')

        self._embedding_layer = layers.OnDeviceEmbedding(
            vocab_size=vocab_size,
            embedding_width=embedding_width,
            initializer=initializer,
            name='word_embeddings')
        word_embeddings = self._embedding_layer(word_ids)

        # Always uses dynamic slicing for simplicity.
        self._position_embedding_layer = layers.PositionEmbedding(
            initializer=initializer,
            use_dynamic_slicing=True,
            max_sequence_length=max_sequence_length,
            name='position_embedding')
        position_embeddings = self._position_embedding_layer(word_embeddings)

        type_embeddings = (layers.OnDeviceEmbedding(
            vocab_size=type_vocab_size,
            embedding_width=embedding_width,
            initializer=initializer,
            use_one_hot=True,
            name='type_embeddings')(type_ids))

        embeddings = tf.keras.layers.Add()(
            [word_embeddings, position_embeddings, type_embeddings])
        embeddings = (tf.keras.layers.LayerNormalization(
            name='embeddings/layer_norm',
            axis=-1,
            epsilon=1e-12,
            dtype=tf.float32)(embeddings))
        embeddings = (tf.keras.layers.Dropout(rate=dropout_rate)(embeddings))
        # ALBERT's factorized embedding parameterization: project the
        # `embedding_width`-wide embeddings up to `hidden_size` if they differ.
        if embedding_width != hidden_size:
            embeddings = layers.DenseEinsum(
                output_shape=hidden_size,
                kernel_initializer=initializer,
                name='embedding_projection')(embeddings)

        data = embeddings
        attention_mask = layers.SelfAttentionMask()([data, mask])
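        # One Transformer block is created here and reused for every
        # iteration below: ALBERT's cross-layer parameter sharing.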
        shared_layer = layers.Transformer(
            num_attention_heads=num_attention_heads,
            intermediate_size=intermediate_size,
            intermediate_activation=activation,
            dropout_rate=dropout_rate,
            attention_dropout_rate=attention_dropout_rate,
            kernel_initializer=initializer,
            name='transformer')
        for _ in range(num_layers):
            data = shared_layer([data, attention_mask])

        first_token_tensor = (tf.keras.layers.Lambda(
            lambda x: tf.squeeze(x[:, 0:1, :], axis=1))(data))
        cls_output = tf.keras.layers.Dense(
            units=hidden_size,
            activation='tanh',
            kernel_initializer=initializer,
            name='pooler_transform')(first_token_tensor)

        super(AlbertTransformerEncoder,
              self).__init__(inputs=[word_ids, mask, type_ids],
                             outputs=[data, cls_output],
                             **kwargs)
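
# A parallel sketch (assumed) for the ALBERT encoder: unlike the model above,
# its weight count is independent of num_layers, since one block is reused.
albert = AlbertTransformerEncoder(vocab_size=30000, sequence_length=128)
albert_seq, albert_cls = albert([word_ids, mask, type_ids])  # inputs from the sketch above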

  # EncoderScaffold test: swap a custom two-input embedding network (word_ids
  # and mask only, no positional embeddings) into the scaffold and invoke it.
  def test_network_invocation(self):
    hidden_size = 32
    sequence_length = 21
    vocab_size = 57

    # Build an embedding network to swap in for the default network. This one
    # will have 2 inputs (mask and word_ids) instead of 3, and won't use
    # positional embeddings.

    word_ids = tf.keras.layers.Input(
        shape=(sequence_length,), dtype=tf.int32, name="input_word_ids")
    mask = tf.keras.layers.Input(
        shape=(sequence_length,), dtype=tf.int32, name="input_mask")
    embedding_layer = layers.OnDeviceEmbedding(
        vocab_size=vocab_size,
        embedding_width=hidden_size,
        initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
        name="word_embeddings")
    word_embeddings = embedding_layer(word_ids)
    network = tf.keras.Model([word_ids, mask], [word_embeddings, mask])

    hidden_cfg = {
        "num_attention_heads": 2,
        "intermediate_size": 3072,
        "intermediate_activation": activations.gelu,
        "dropout_rate": 0.1,
        "attention_dropout_rate": 0.1,
        "kernel_initializer": tf.keras.initializers.TruncatedNormal(stddev=0.02),
    }

    # Create a small EncoderScaffold for testing.
    test_network = encoder_scaffold.EncoderScaffold(
        num_hidden_instances=3,
        pooled_output_dim=hidden_size,
        pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
            stddev=0.02),
        hidden_cfg=hidden_cfg,
        embedding_cls=network,
        embedding_data=embedding_layer.embeddings)

    # Create the inputs (note that the first dimension is implicit).
    word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
    mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
    data, pooled = test_network([word_ids, mask])

    # Create a model based on this network:
    model = tf.keras.Model([word_ids, mask], [data, pooled])

    # Invoke the model. We can't validate the output data here (the model is too
    # complex) but this will catch structural runtime errors.
    batch_size = 3
    word_id_data = np.random.randint(
        vocab_size, size=(batch_size, sequence_length))
    mask_data = np.random.randint(2, size=(batch_size, sequence_length))
    _ = model.predict([word_id_data, mask_data])

    # Test that we can get the embedding data that we passed to the object. This
    # is necessary to support standard language model training.
    self.assertIs(embedding_layer.embeddings,
                  test_network.get_embedding_table())

  # EncoderScaffold test: round-trip the config via get_config()/from_config()
  # and verify that the rebuilt network produces identical outputs.
  def test_serialize_deserialize(self):
    hidden_size = 32
    sequence_length = 21
    vocab_size = 57

    # Build an embedding network to swap in for the default network. This one
    # will have 2 inputs (mask and word_ids) instead of 3, and won't use
    # positional embeddings.

    word_ids = tf.keras.layers.Input(
        shape=(sequence_length,), dtype=tf.int32, name="input_word_ids")
    mask = tf.keras.layers.Input(
        shape=(sequence_length,), dtype=tf.int32, name="input_mask")
    embedding_layer = layers.OnDeviceEmbedding(
        vocab_size=vocab_size,
        embedding_width=hidden_size,
        initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
        name="word_embeddings")
    word_embeddings = embedding_layer(word_ids)
    network = tf.keras.Model([word_ids, mask], [word_embeddings, mask])

    hidden_cfg = {
        "num_attention_heads": 2,
        "intermediate_size": 3072,
        "intermediate_activation": activations.gelu,
        "dropout_rate": 0.1,
        "attention_dropout_rate": 0.1,
        "kernel_initializer": tf.keras.initializers.TruncatedNormal(stddev=0.02),
    }

    # Create a small EncoderScaffold for testing.
    test_network = encoder_scaffold.EncoderScaffold(
        num_hidden_instances=3,
        pooled_output_dim=hidden_size,
        pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
            stddev=0.02),
        hidden_cfg=hidden_cfg,
        embedding_cls=network,
        embedding_data=embedding_layer.embeddings)

    # Create another network object from the first object's config.
    new_network = encoder_scaffold.EncoderScaffold.from_config(
        test_network.get_config())

    # Validate that the config can be forced to JSON.
    _ = new_network.to_json()

    # If the serialization was successful, the new config should match the old.
    self.assertAllEqual(test_network.get_config(), new_network.get_config())

    # Create models based on the old and new networks:
    word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
    mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)

    data, pooled = new_network([word_ids, mask])
    new_model = tf.keras.Model([word_ids, mask], [data, pooled])

    data, pooled = test_network([word_ids, mask])
    model = tf.keras.Model([word_ids, mask], [data, pooled])

    # Copy the weights between models.
    new_model.set_weights(model.get_weights())

    # Invoke the models.
    batch_size = 3
    word_id_data = np.random.randint(
        vocab_size, size=(batch_size, sequence_length))
    mask_data = np.random.randint(2, size=(batch_size, sequence_length))
    data, cls = model.predict([word_id_data, mask_data])
    new_data, new_cls = new_model.predict([word_id_data, mask_data])

    # The output should be equal.
    self.assertAllEqual(data, new_data)
    self.assertAllEqual(cls, new_cls)

    # We should not be able to get a reference to the embedding data.
    with self.assertRaisesRegex(RuntimeError, ".*does not have a reference.*"):
      new_network.get_embedding_table()
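
    # The table is supplied via `embedding_data` at construction time and is
    # not serialized into the config, so the rebuilt network has nothing to
    # return here.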
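
    # EncoderScaffold.__init__ (per the super() call below): a configurable
    # encoder whose embedding network and hidden layers can be swapped out.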
    def __init__(
            self,
            pooled_output_dim,
            pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
                stddev=0.02),
            embedding_cls=None,
            embedding_cfg=None,
            embedding_data=None,
            num_hidden_instances=1,
            hidden_cls=layers.Transformer,
            hidden_cfg=None,
            return_all_layer_outputs=False,
            **kwargs):
        self._self_setattr_tracking = False
        self._hidden_cls = hidden_cls
        self._hidden_cfg = hidden_cfg
        self._num_hidden_instances = num_hidden_instances
        self._pooled_output_dim = pooled_output_dim
        self._pooler_layer_initializer = pooler_layer_initializer
        self._embedding_cls = embedding_cls
        self._embedding_cfg = embedding_cfg
        self._embedding_data = embedding_data
        self._return_all_layer_outputs = return_all_layer_outputs
        self._kwargs = kwargs

        if embedding_cls:
            if inspect.isclass(embedding_cls):
                self._embedding_network = embedding_cls(embedding_cfg)
            else:
                self._embedding_network = embedding_cls
            inputs = self._embedding_network.inputs
            embeddings, mask = self._embedding_network(inputs)
        else:
            self._embedding_network = None
            word_ids = tf.keras.layers.Input(
                shape=(embedding_cfg['seq_length'], ),
                dtype=tf.int32,
                name='input_word_ids')
            mask = tf.keras.layers.Input(shape=(embedding_cfg['seq_length'], ),
                                         dtype=tf.int32,
                                         name='input_mask')
            type_ids = tf.keras.layers.Input(
                shape=(embedding_cfg['seq_length'], ),
                dtype=tf.int32,
                name='input_type_ids')
            inputs = [word_ids, mask, type_ids]

            self._embedding_layer = layers.OnDeviceEmbedding(
                vocab_size=embedding_cfg['vocab_size'],
                embedding_width=embedding_cfg['hidden_size'],
                initializer=embedding_cfg['initializer'],
                name='word_embeddings')

            word_embeddings = self._embedding_layer(word_ids)

            # Always uses dynamic slicing for simplicity.
            self._position_embedding_layer = layers.PositionEmbedding(
                initializer=embedding_cfg['initializer'],
                use_dynamic_slicing=True,
                max_sequence_length=embedding_cfg['max_seq_length'],
                name='position_embedding')
            position_embeddings = self._position_embedding_layer(
                word_embeddings)

            type_embeddings = (layers.OnDeviceEmbedding(
                vocab_size=embedding_cfg['type_vocab_size'],
                embedding_width=embedding_cfg['hidden_size'],
                initializer=embedding_cfg['initializer'],
                use_one_hot=True,
                name='type_embeddings')(type_ids))

            embeddings = tf.keras.layers.Add()(
                [word_embeddings, position_embeddings, type_embeddings])
            embeddings = (tf.keras.layers.LayerNormalization(
                name='embeddings/layer_norm',
                axis=-1,
                epsilon=1e-12,
                dtype=tf.float32)(embeddings))
            embeddings = (tf.keras.layers.Dropout(
                rate=embedding_cfg['dropout_rate'])(embeddings))

        attention_mask = layers.SelfAttentionMask()([embeddings, mask])
        data = embeddings

        layer_output_data = []
        self._hidden_layers = []
        for _ in range(num_hidden_instances):
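            # A class gets instantiated once per iteration (independent
            # weights); a layer instance is reused as-is, sharing weights
            # across all hidden instances.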
            if inspect.isclass(hidden_cls):
                layer = hidden_cls(
                    **hidden_cfg) if hidden_cfg else hidden_cls()
            else:
                layer = hidden_cls
            data = layer([data, attention_mask])
            layer_output_data.append(data)
            self._hidden_layers.append(layer)

        first_token_tensor = (
            tf.keras.layers.Lambda(lambda x: tf.squeeze(x[:, 0:1, :], axis=1))(
                layer_output_data[-1]))
        self._pooler_layer = tf.keras.layers.Dense(
            units=pooled_output_dim,
            activation='tanh',
            kernel_initializer=pooler_layer_initializer,
            name='cls_transform')
        cls_output = self._pooler_layer(first_token_tensor)

        if return_all_layer_outputs:
            outputs = [layer_output_data, cls_output]
        else:
            outputs = [layer_output_data[-1], cls_output]

        super(EncoderScaffold, self).__init__(inputs=inputs,
                                              outputs=outputs,
                                              **kwargs)
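
# A minimal usage sketch (assumed) for the default embedding path; the
# embedding_cfg keys below are exactly the ones this constructor reads.
init = tf.keras.initializers.TruncatedNormal(stddev=0.02)
embedding_cfg = {
    'vocab_size': 100,
    'type_vocab_size': 2,
    'hidden_size': 32,
    'seq_length': 16,
    'max_seq_length': 16,
    'initializer': init,
    'dropout_rate': 0.1,
}
hidden_cfg = {
    'num_attention_heads': 2,
    'intermediate_size': 64,
    'intermediate_activation': activations.gelu,
    'dropout_rate': 0.1,
    'attention_dropout_rate': 0.1,
    'kernel_initializer': init,
}
scaffold = EncoderScaffold(
    num_hidden_instances=2,
    pooled_output_dim=32,
    embedding_cfg=embedding_cfg,
    hidden_cfg=hidden_cfg)
word_ids = tf.ones((2, 16), dtype=tf.int32)
mask = tf.ones((2, 16), dtype=tf.int32)
type_ids = tf.zeros((2, 16), dtype=tf.int32)
sequence_output, pooled_output = scaffold([word_ids, mask, type_ids])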