Example #1
0
def universal_transformer_gpt_model(
        max_seq_length: int, vocabulary_size: int,
        word_embedding_size: int, transformer_depth: int,
        num_heads: int, transformer_dropout: float = 0.1,
        embedding_dropout: float = 0.6,
        l2_reg_penalty: float = 1e-6,
        confidence_penalty_weight: float = 0.1):
    """
    Build a GPT-style language model on top of a Universal Transformer.

    The model is similar to the one described by OpenAI in the paper
    "Improving Language Understanding by Generative Pre-Training", except
    that it relies on L2 regularization of the word embedding matrix
    (instead of dropout), and uses the Universal Transformer architecture:
    a single weight-shared `TransformerBlock` is applied
    `transformer_depth` times, and an adaptive computation time (ACT)
    layer produces the final output.

    :param max_seq_length: length of the `word_ids` input sequences.
    :param vocabulary_size: number of rows in the word embedding matrix.
    :param word_embedding_size: dimensionality of the word embeddings.
    :param transformer_depth: number of applications of the shared
        transformer block.
    :param num_heads: number of attention heads in the transformer block.
    :param transformer_dropout: dropout rate used for both the residual
        connections and the attention inside the transformer block.
    :param embedding_dropout: dropout rate of the tied output projection.
    :param l2_reg_penalty: L2 penalty applied to the embedding matrix and to
        the output projection; a falsy value disables the regularization.
    :param confidence_penalty_weight: weight of the output-distribution
        confidence penalty that is added to the model's losses.
    :return: an uncompiled Keras `Model` mapping `word_ids` to word
        probabilities (softmax over the last axis).
    """
    word_ids = Input(shape=(max_seq_length,), dtype='int32', name='word_ids')
    l2_regularizer = (regularizers.l2(l2_reg_penalty) if l2_reg_penalty
                      else None)
    embedding_layer = ReusableEmbedding(
        vocabulary_size, word_embedding_size,
        input_length=max_seq_length,
        name='bpe_embeddings',
        # Regularization is based on paper "A Comparative Study on
        # Regularization Strategies for Embedding-based Neural Networks"
        # https://arxiv.org/pdf/1508.03721.pdf
        embeddings_regularizer=l2_regularizer)
    # Output projection tied to the embedding matrix returned by
    # `embedding_layer` (see the call below).
    output_layer = TiedOutputEmbedding(
        projection_regularizer=l2_regularizer,
        projection_dropout=embedding_dropout,
        name='word_prediction_logits')
    # One coordinate embedding per recurrent step of the universal block.
    coordinate_embedding_layer = TransformerCoordinateEmbedding(
        transformer_depth,
        name='coordinate_embedding')
    transformer_act_layer = TransformerACT(name='adaptive_computation_time')
    # A single block instance: its weights are shared by every step.
    transformer_block = TransformerBlock(
        name='transformer', num_heads=num_heads,
        residual_dropout=transformer_dropout,
        attention_dropout=transformer_dropout,
        use_masking=True, vanilla_wiring=False)
    output_softmax_layer = Softmax(name='word_predictions')

    next_step_input, embedding_matrix = embedding_layer(word_ids)
    act_output = next_step_input

    for i in range(transformer_depth):
        next_step_input = coordinate_embedding_layer(next_step_input, step=i)
        next_step_input = transformer_block(next_step_input)
        next_step_input, act_output = transformer_act_layer(next_step_input)

    transformer_act_layer.finalize()
    next_step_input = act_output
    word_predictions = output_softmax_layer(
        output_layer([next_step_input, embedding_matrix]))
    model = Model(inputs=[word_ids], outputs=[word_predictions])
    # Penalty for confidence of the output distribution, as described in
    # "Regularizing Neural Networks by Penalizing Confident
    # Output Distributions" (https://arxiv.org/abs/1701.06548).
    # sum(p * log p) is the negative entropy, so minimizing it pushes the
    # model towards less confident predictions. K.epsilon() keeps log()
    # finite when a softmax probability underflows to exactly 0, which
    # would otherwise make the loss and its gradient NaN (0 * -inf).
    confidence_penalty = K.mean(
        confidence_penalty_weight *
        K.sum(word_predictions * K.log(word_predictions + K.epsilon()),
              axis=-1))
    model.add_loss(confidence_penalty)
    return model
Example #2
0
def make_transformer_model(max_seq_length: int, vocabulary_size: int,
                           word_embedding_size: int, transformer_depth: int,
                           num_heads: int, transformer_dropout: float = 0.1,
                           embedding_dropout: float = 0.6,
                           l2_reg_penalty: float = 1e-6,
                           confidence_penalty_weight: float = 0.05):
    """
    Build a GPT-style language model on top of a Universal Transformer,
    with dropout applied to the input word embeddings.

    A single weight-shared `TransformerBlock` is applied `transformer_depth`
    times, and an adaptive computation time (ACT) layer produces the final
    output. The output projection is tied to the word embedding matrix.

    :param max_seq_length: length of the `word_ids` input sequences.
    :param vocabulary_size: number of rows in the word embedding matrix.
    :param word_embedding_size: dimensionality of the word embeddings.
    :param transformer_depth: number of applications of the shared
        transformer block.
    :param num_heads: number of attention heads in the transformer block.
    :param transformer_dropout: dropout rate used for both the residual
        connections and the attention inside the transformer block.
    :param embedding_dropout: dropout rate applied to the input embeddings
        and inside the tied output projection.
    :param l2_reg_penalty: L2 penalty applied to the embedding matrix and to
        the output projection; a falsy value disables the regularization.
    :param confidence_penalty_weight: weight of the output-distribution
        confidence penalty that is added to the model's losses.
    :return: an uncompiled Keras `Model` mapping `word_ids` to word
        probabilities (softmax over the last axis).
    """
    word_ids = Input(shape=(max_seq_length,), dtype='int32', name='word_ids')
    l2_regularizer = (regularizers.l2(l2_reg_penalty) if l2_reg_penalty
                      else None)
    embedding_layer = ReusableEmbedding(
        vocabulary_size, word_embedding_size,
        input_length=max_seq_length,
        name='bpe_embeddings',
        # Regularization is based on paper "A Comparative Study on
        # Regularization Strategies for Embedding-based Neural Networks"
        # https://arxiv.org/pdf/1508.03721.pdf
        embeddings_regularizer=l2_regularizer)
    output_layer = TiedOutputEmbedding(
        projection_regularizer=l2_regularizer,
        projection_dropout=embedding_dropout,
        name='word_prediction_logits')
    coordinate_embedding_layer = TransformerCoordinateEmbedding(
        transformer_depth,
        name='coordinate_embedding')
    transformer_act_layer = TransformerACT(name='adaptive_computation_time')
    transformer_block = TransformerBlock(
        name='transformer', num_heads=num_heads,
        residual_dropout=transformer_dropout,
        attention_dropout=transformer_dropout,
        use_masking=True)
    output_softmax_layer = Softmax(name='word_predictions')

    next_step_input, embedding_matrix = embedding_layer(word_ids)
    dropout_layer = Dropout(embedding_dropout, name='input_dropout')

    next_step_input = dropout_layer(next_step_input)
    # Seed the ACT output with the *dropped-out* embeddings so that the
    # input dropout is not silently bypassed when transformer_depth == 0
    # (for any positive depth the loop below overwrites it anyway).
    act_output = next_step_input
    for i in range(transformer_depth):
        next_step_input = coordinate_embedding_layer(next_step_input, step=i)
        next_step_input = transformer_block(next_step_input)
        next_step_input, act_output = transformer_act_layer(next_step_input)

    transformer_act_layer.finalize()
    next_step_input = act_output
    word_predictions = output_softmax_layer(
        output_layer([next_step_input, embedding_matrix]))
    model = Model(inputs=[word_ids], outputs=[word_predictions])
    # Penalty for confidence of the output distribution, as described in
    # "Regularizing Neural Networks by Penalizing Confident
    # Output Distributions" (https://arxiv.org/abs/1701.06548).
    # K.epsilon() keeps log() finite when a probability underflows to 0,
    # which would otherwise make the loss and its gradient NaN.
    confidence_penalty = K.mean(
        confidence_penalty_weight *
        K.sum(word_predictions * K.log(word_predictions + K.epsilon()),
              axis=-1))
    model.add_loss(confidence_penalty)
    return model
Example #3
0
def transformer_bert_model(max_seq_length: int,
                           vocabulary_size: int,
                           word_embedding_size: int,
                           use_universal_transformer: bool,
                           transformer_depth: int,
                           num_heads: int,
                           transformer_dropout: float = 0.1,
                           embedding_dropout: float = 0.6,
                           l2_reg_penalty: float = 1e-4):
    """
    Build a BERT-style model (Bidirectional Encoder Representations from
    Transformers), following "BERT: Pre-training of Deep Bidirectional
    Transformers for Language Understanding"
    (https://arxiv.org/abs/1810.04805).

    The encoder is an Adaptive Universal Transformer (2018) when
    `use_universal_transformer` is true, and otherwise a stack of
    `transformer_depth` independent vanilla Transformer (2017) blocks; the
    original paper uses the vanilla variant. Attention masking is disabled
    in both variants, so attention is bi-directional.

    The model has two inputs, `word_ids` and `segment_ids`, and two outputs:
      * `word_predictions` - softmax produced through the output projection
        tied to the word embedding matrix;
      * `class_prediction` - a single sigmoid unit computed from the encoder
        output at the first position of each sequence.

    :param l2_reg_penalty: L2 penalty for the word embeddings and the tied
        output projection; a falsy value disables it.
    :param embedding_dropout: dropout rate of the tied output projection.
    :return: an uncompiled Keras `Model`.
    """
    token_input = Input(shape=(max_seq_length, ), dtype='int32',
                        name='word_ids')
    segment_input = Input(shape=(max_seq_length, ), dtype='int32',
                          name='segment_ids')
    if l2_reg_penalty:
        regularizer = regularizers.l2(l2_reg_penalty)
    else:
        regularizer = None

    # Regularization is based on paper "A Comparative Study on
    # Regularization Strategies for Embedding-based Neural Networks"
    # https://arxiv.org/pdf/1508.03721.pdf
    token_embedding = ReusableEmbedding(
        vocabulary_size,
        word_embedding_size,
        input_length=max_seq_length,
        name='bpe_embeddings',
        embeddings_regularizer=regularizer)
    # Two rows: one for "Segment A", one for "Segment B".
    segment_embedding = Embedding(2, word_embedding_size,
                                  name='segment_embeddings')
    add_segment = Add(name='add_segment')
    logits_layer = TiedOutputEmbedding(projection_regularizer=regularizer,
                                       projection_dropout=embedding_dropout,
                                       name='word_prediction_logits')
    softmax = Softmax(name='word_predictions')
    # The universal variant needs one coordinate embedding per step; the
    # vanilla variant applies it only once.
    if use_universal_transformer:
        coordinate_steps = transformer_depth
    else:
        coordinate_steps = 1
    coordinate_embedding = TransformerCoordinateEmbedding(
        coordinate_steps, name='coordinate_embedding')

    # Settings shared by every transformer block; use_masking=False allows
    # bi-directional attention.
    block_settings = dict(num_heads=num_heads,
                          residual_dropout=transformer_dropout,
                          attention_dropout=transformer_dropout,
                          use_masking=False)

    def universal_encoder(hidden):
        # Universal Transformer (2018): one weight-shared block applied
        # repeatedly, the ACT layer deciding the final output.
        act = TransformerACT(name='adaptive_computation_time')
        shared_block = TransformerBlock(name='transformer', **block_settings)
        halted = hidden
        for step in range(transformer_depth):
            hidden = coordinate_embedding(hidden, step=step)
            hidden = add_segment([hidden, segment_vectors])
            hidden = shared_block(hidden)
            hidden, halted = act(hidden)
        act.finalize()
        return halted

    def vanilla_encoder(hidden):
        # Vanilla Transformer ("Attention is all you need", 2017): a stack
        # of independent blocks, each with its own weights.
        hidden = coordinate_embedding(hidden, step=0)
        hidden = add_segment([hidden, segment_vectors])
        for depth_index in range(transformer_depth):
            block = TransformerBlock(name='transformer' + str(depth_index),
                                     vanilla_wiring=True,
                                     **block_settings)
            hidden = block(hidden)
        return hidden

    embedded_tokens, embedding_matrix = token_embedding(token_input)
    segment_vectors = segment_embedding(segment_input)
    build_encoder = (universal_encoder if use_universal_transformer
                     else vanilla_encoder)
    encoded = build_encoder(embedded_tokens)

    word_predictions = softmax(logits_layer([encoded, embedding_matrix]))
    # The first position of each sequence carries the classification output.
    first_position = Lambda(lambda x: x[:, 0],
                            name='cls_node_slicer')(encoded)
    class_prediction = Dense(1, name='class_prediction',
                             activation='sigmoid')(first_position)
    return Model(inputs=[token_input, segment_input],
                 outputs=[word_predictions, class_prediction])
Example #4
0
def _pad_lists_to_longest(*lists):
    """
    Return copies of `lists`, each padded to the length of the longest one
    by repeating its own last element.

    Raises:
        ValueError: if any of the lists is empty (there is no last element
            to repeat).
    """
    if any(len(values) == 0 for values in lists):
        raise ValueError("Cannot pad an empty list using its last value.")
    target_length = max(len(values) for values in lists)
    return tuple(
        list(values) + [values[-1]] * (target_length - len(values))
        for values in lists)


def _restore_layer_weights(layer, original_model, layer_name: str):
    """Build `layer` for the input shape of the layer called `layer_name`
    in `original_model`, then copy that layer's weights into it."""
    source_layer = original_model.get_layer(name=layer_name)
    layer.build(source_layer.input_shape)
    layer.set_weights(source_layer.get_weights())


def _restore_transformer_block(transformer_block, original_model,
                               block_name: str, block_index: int):
    """
    Build the sub-layers of a freshly created `TransformerBlock` and copy
    into them the weights of the block called `block_name` in
    `original_model`.

    The sub-layers are built explicitly first because Keras `set_weights`
    only works once a layer's weight variables exist.
    """
    # The attention layer consumes the previous block's output, or the
    # position embedding for the first block.
    if block_index == 0:
        attention_input_name = "position_embedding"
    else:
        attention_input_name = "transformer{}_normalization2".format(
            block_index - 1)
    normalization1_shape = original_model.get_layer(
        block_name + "_normalization1").output_shape

    transformer_block.attention_layer.build(
        original_model.get_layer(attention_input_name).output_shape)
    transformer_block.norm1_layer.build(
        original_model.get_layer(block_name + "_self_attention").output_shape)
    transformer_block.norm2_layer.build(normalization1_shape)
    transformer_block.transition_layer.build(normalization1_shape)

    for sublayer, suffix in (
            (transformer_block.attention_layer, "_self_attention"),
            (transformer_block.norm1_layer, "_normalization1"),
            (transformer_block.norm2_layer, "_normalization2"),
            (transformer_block.transition_layer, "_transition")):
        sublayer.set_weights(
            original_model.get_layer(name=block_name + suffix).get_weights())


def build_model(max_length: int,
                embedding_matrix: Union[np.ndarray, Tuple[int]],
                transformer_depth: int,
                transformer_heads: int,
                filters: List[int],
                kernel_size: List[int],
                pool_size: List[int],
                conv_padding: str,
                pool_padding: str,
                dense_size: List[int],
                loaded_model: Optional[str] = None,
                fine_tune_model: bool = False,
                l2_penalty: Optional[float] = None,
                embedding_dropout: float = 0.6,
                transformer_dropout: float = 0.1,
                conv_dropout: float = 0.1,
                dense_dropout: Union[float, List[float]] = 0.3,
                classifier_dropout: float = 0.1,
                train_lm: bool = True) -> Model:
    """
    Build a transformer language model, optionally topped with a CNN and
    dense-layer binary classifier for fine-tuning.

    :param max_length: length of the integer-encoded input sequences.
    :param embedding_matrix: either the initial word embedding matrix or
        just its ``(vocabulary_size, embedding_size)`` shape as a tuple.
    :param transformer_depth: number of stacked transformer blocks.
    :param transformer_heads: number of attention heads per block.
    :param filters, kernel_size, pool_size: per-layer settings of the
        convolution / max-pooling stack. Lists of different lengths are
        padded with their last value to the longest length.
    :param conv_padding, pool_padding: Keras padding modes for the
        convolution and pooling layers.
    :param dense_size: units of each dense layer after the CNN.
    :param loaded_model: path of a saved model whose weights are restored
        into the layers with matching names.
    :param fine_tune_model: when false, only the language model is built.
    :param l2_penalty: L2 penalty for the embedding and output projection.
    :param dense_dropout: one dropout rate for every dense layer, or a
        list/tuple of per-layer rates (padded like the CNN settings).
    :param train_lm: when fine-tuning, also expose the LM output.
    :return: an uncompiled Keras `Model`.
    :raises ValueError: if no filters, kernel sizes or pool sizes are given,
        or if `dense_dropout` is an empty list while dense layers are
        requested.
    """
    no_cnn_message = (
        "There are no filters, kernel sizes or pool sizes specified for the CNN."
    )
    if 0 in (len(filters), len(kernel_size), len(pool_size)):
        logger.error(no_cnn_message)
        raise ValueError(no_cnn_message)

    # Accept a single rate as well as a list/tuple of per-layer rates
    # (a tuple used to be wrapped as one element and break Dropout).
    if isinstance(dense_dropout, (list, tuple)):
        dense_dropout = list(dense_dropout)
    else:
        dense_dropout = [dense_dropout]

    if len(dense_size) > 0 and len(dense_size) != len(dense_dropout):
        dense_size, dense_dropout = _pad_lists_to_longest(
            dense_size, dense_dropout)
        logger.warning(
            "Lists given for dense layer sizes and dense layer dropout rates are not the same length. "
            "The shorter lists are padded using the last value to match the length of the longest."
        )

    if not len(filters) == len(kernel_size) == len(pool_size):
        filters, kernel_size, pool_size = _pad_lists_to_longest(
            filters, kernel_size, pool_size)
        logger.warning(
            "Lists given for convolutional filters, kernel sizes and pooling sizes had different lengths. "
            "The shorter lists are padded using the last value to match the length of the longest."
        )

    original_model = None
    if loaded_model:
        # load the specified model
        original_model = load_model(loaded_model,
                                    custom_objects={
                                        "perplexity": perplexity,
                                        "lm_accuracy": lm_accuracy
                                    })

    def restored_weights(layer_name):
        # Keyword arguments passing the loaded model's weights for
        # `layer_name` to a layer constructor; empty when not loading.
        if loaded_model:
            return {"weights": original_model.get_layer(
                name=layer_name).get_weights()}
        return {}

    # regularizer for embedding layer
    l2_regularizer = l2(l2_penalty) if l2_penalty else None

    # input encoded as integers
    raw_input = Input(shape=(max_length, ), name="input")

    # embedding layer, initialised with the loaded weights, the given
    # matrix, or (when only a shape was given) no initial weights at all
    if isinstance(embedding_matrix, tuple):
        input_dim, output_dim = embedding_matrix[0], embedding_matrix[1]
    else:
        input_dim = embedding_matrix.shape[0]
        output_dim = embedding_matrix.shape[1]
    if loaded_model:
        embedding_weights = [
            original_model.get_layer(name="word_embedding").get_weights()[0]
        ]
    elif isinstance(embedding_matrix, tuple):
        embedding_weights = None
    else:
        embedding_weights = [embedding_matrix]
    embedding_layer = ReusableEmbedding(
        input_dim=input_dim,
        output_dim=output_dim,
        input_length=max_length,
        name="word_embedding",
        weights=embedding_weights,
        embeddings_regularizer=l2_regularizer)

    # "transpose" of embedding matrix to map back to vocabulary
    output_layer = TiedOutputEmbedding(
        projection_regularizer=l2_regularizer,
        projection_dropout=embedding_dropout,
        name="word_prediction_logits",
        **restored_weights("word_prediction_logits"))

    # transformer as taken from here: https://github.com/kpot/keras-transformer/blob/master/example/models.py
    position_embedding = TransformerCoordinateEmbedding(
        max_transformer_depth=1,
        name="position_embedding",
        **restored_weights("position_embedding"))

    transformer_input, embedding_tensor = embedding_layer(raw_input)
    transformer_output = position_embedding(transformer_input, step=0)
    for i in range(transformer_depth):
        block_name = "transformer" + str(i)
        transformer_block = TransformerBlock(
            name=block_name,
            num_heads=transformer_heads,
            residual_dropout=transformer_dropout,
            attention_dropout=transformer_dropout,
            use_masking=True,
            vanilla_wiring=True)
        if loaded_model:
            _restore_transformer_block(transformer_block, original_model,
                                       block_name, i)
        transformer_output = transformer_block(transformer_output)

    # nothing special to load for softmax
    softmax_layer = Softmax(name="word_predictions")
    lm_output_logits = output_layer([transformer_output, embedding_tensor])
    lm_output = softmax_layer(lm_output_logits)

    if not fine_tune_model:
        return Model(inputs=raw_input, outputs=lm_output)

    loaded_layer_names = set()
    if loaded_model:
        loaded_layer_names = {layer.name for layer in original_model.layers}

    # convolution layer(s); one dropout layer is shared by every conv input
    conv_dropout_layer = Dropout(conv_dropout, name="conv_dropout")
    conv_output = transformer_output
    for i, (n_filters, width, pool) in enumerate(
            zip(filters, kernel_size, pool_size)):
        conv_layer_name = "conv_{}".format(i)
        convolution = Conv1D(n_filters,
                             width,
                             padding=conv_padding,
                             activation="relu",
                             name=conv_layer_name)
        if conv_layer_name in loaded_layer_names:
            _restore_layer_weights(convolution, original_model,
                                   conv_layer_name)

        # construct max pooling, no weights to load
        pooling = MaxPooling1D(pool,
                               padding=pool_padding,
                               name="max_pool_{}".format(i))
        conv_output = pooling(convolution(conv_dropout_layer(conv_output)))

    # dense layer(s)
    flatten = Flatten(name="flatten")
    dense_output = flatten(conv_output)
    for i, (units, rate) in enumerate(zip(dense_size, dense_dropout)):
        dense_layer_name = "dense_{}".format(i)
        # NOTE(review): no activation is set, so at inference time
        # consecutive dense layers compose into a single affine map —
        # confirm this is intended.
        dense = Dense(units, name=dense_layer_name)
        if dense_layer_name in loaded_layer_names:
            _restore_layer_weights(dense, original_model, dense_layer_name)

        # nothing to load for dropout
        dropout = Dropout(rate=rate, name="dense_dropout_{}".format(i))
        dense_output = dense(dropout(dense_output))

    # classification layer
    classifier_dropout_layer = Dropout(classifier_dropout,
                                       name="classifier_dropout")
    classifier = Dense(1, name="classifier")
    classifier_prediction = Activation("sigmoid", name="classifier_prediction")
    classifier_output = classifier_prediction(
        classifier(classifier_dropout_layer(dense_output)))

    if train_lm:
        return Model(inputs=raw_input, outputs=[lm_output, classifier_output])
    return Model(inputs=raw_input, outputs=classifier_output)
Example #5
0
def transformer_bert_model(max_seq_length: int,
                           time_window_size: int,
                           vocabulary_size: int,
                           concept_embedding_size: int,
                           depth: int,
                           num_heads: int,
                           transformer_dropout: float = 0.1,
                           embedding_dropout: float = 0.6,
                           l2_reg_penalty: float = 1e-4):
    """
    Builds a BERT-style model (Bidirectional Encoder Representations
    from Transformers), following the paper "BERT: Pre-training of Deep
    Bidirectional Transformers for Language Understanding"
    (https://arxiv.org/abs/1810.04805), adapted to sequences of concepts
    with time stamps.

    The encoder is always the `depth`-layer stack built by `Encoder`; this
    variant has no Universal Transformer option. A `TimeSelfAttention`
    layer computes time-aware attention logits from `concept_ids`,
    `time_stamps` and `mask`, and these are passed to the encoder together
    with `concept_mask` (derived from `mask`).

    Inputs (all int32 tensors of length `max_seq_length`):
      * `masked_concept_ids` - embedded and fed to the encoder (presumably
        the concept ids with some positions masked out; confirm with the
        data pipeline);
      * `concept_ids` - used only by the time attention layer;
      * `time_stamps` - used only by the time attention layer;
      * `mask` - used by both the time attention layer and the encoder.

    Output: `concept_predictions`, a softmax computed through the output
    projection tied to the concept embedding matrix.

    :param time_window_size: forwarded to `TimeSelfAttention`.
    :param depth: number of encoder layers.
    :param transformer_dropout: dropout rate inside the encoder.
    :param embedding_dropout: dropout rate of the tied output projection.
    :param l2_reg_penalty: L2 penalty for the concept embeddings and the
        tied output projection; a falsy value disables it.
    :return: an uncompiled `tf.keras.Model`.
    """
    masked_concept_ids = tf.keras.layers.Input(shape=(max_seq_length, ),
                                               dtype='int32',
                                               name='masked_concept_ids')

    concept_ids = tf.keras.layers.Input(shape=(max_seq_length, ),
                                        dtype='int32',
                                        name='concept_ids')

    time_stamps = tf.keras.layers.Input(shape=(max_seq_length, ),
                                        dtype='int32',
                                        name='time_stamps')

    mask = tf.keras.layers.Input(shape=(max_seq_length, ),
                                 dtype='int32',
                                 name='mask')

    # (batch, seq_len) -> (batch, 1, 1, seq_len) so that it broadcasts over
    # attention heads and query positions; presumably consumed by `Encoder`
    # as an attention mask — confirm the mask polarity there.
    concept_mask = tf.expand_dims(tf.expand_dims(mask, axis=1), axis=1)

    l2_regularizer = (tf.keras.regularizers.l2(l2_reg_penalty)
                      if l2_reg_penalty else None)

    embedding_layer = ReusableEmbedding(
        vocabulary_size,
        concept_embedding_size,
        input_length=max_seq_length,
        name='bpe_embeddings',
        # Regularization is based on paper "A Comparative Study on
        # Regularization Strategies for Embedding-based Neural Networks"
        # https://arxiv.org/pdf/1508.03721.pdf
        embeddings_regularizer=l2_regularizer)

    time_embedding_layer = TimeSelfAttention(vocab_size=vocabulary_size,
                                             target_seq_len=max_seq_length,
                                             context_seq_len=max_seq_length,
                                             time_window_size=time_window_size,
                                             return_logits=True)

    encoder_layer = Encoder(name='encoder',
                            num_layers=depth,
                            d_model=concept_embedding_size,
                            num_heads=num_heads,
                            dropout_rate=transformer_dropout)

    output_layer = TiedOutputEmbedding(projection_regularizer=l2_regularizer,
                                       projection_dropout=embedding_dropout,
                                       name='concept_prediction_logits')

    softmax_layer = tf.keras.layers.Softmax(name='concept_predictions')

    coordinate_embedding_layer = TransformerCoordinateEmbedding(
        1, name='coordinate_embedding')

    next_step_input, embedding_matrix = embedding_layer(masked_concept_ids)

    # Building a Vanilla Transformer (described in
    # "Attention is all you need", 2017)
    next_step_input = coordinate_embedding_layer(next_step_input, step=0)
    # shape = (batch_size, seq_len, seq_len)
    time_attention = time_embedding_layer([concept_ids, time_stamps, mask])
    # pad a dimension to accommodate the head split
    time_attention = tf.expand_dims(time_attention, axis=1)

    # The encoder's second return value is not used here.
    next_step_input, _ = encoder_layer(next_step_input, concept_mask,
                                       time_attention)

    concept_predictions = softmax_layer(
        output_layer([next_step_input, embedding_matrix]))

    model = tf.keras.Model(
        inputs=[masked_concept_ids, concept_ids, time_stamps, mask],
        outputs=[concept_predictions])

    return model
def vanilla_transformer_gpt_model(
        max_seq_length: int, vocabulary_size: int,
        word_embedding_size: int, transformer_depth: int,
        num_heads: int, transformer_dropout: float = 0.1,
        embedding_dropout: float = 0.6,
        l2_reg_penalty: float = 1e-6,
        confidence_penalty_weight: float = 0.1,
        agglomerative_attention: bool = False,
        use_convolutions: bool = False,
        use_coordinate_embeddings: bool = True,
        convolution_width: int = 0,
        dropout_cls: Type[Layer] = Dropout):
    """
    A model which is almost identical to the one described by OpenAI in paper
    "Improving Language Understanding by Generative Pre-Training", except
    that it uses L2 regularization of the word embedding matrix,
    instead of the dropout.

    The encoder is a stack of `transformer_depth` independent vanilla
    transformer blocks with `use_masking=True`. Optionally, a causal 1D
    convolution and/or a coordinate embedding are applied to the word
    embeddings before the stack.

    :param max_seq_length: length of the `word_ids` input sequences.
    :param vocabulary_size: number of rows in the word embedding matrix.
    :param word_embedding_size: dimensionality of the word embeddings.
    :param transformer_depth: number of stacked transformer blocks.
    :param num_heads: number of attention heads per block.
    :param transformer_dropout: dropout rate for residual connections and
        attention inside every block.
    :param embedding_dropout: dropout rate of the tied output projection.
    :param l2_reg_penalty: L2 penalty for the embedding matrix and the
        output projection; a falsy value disables it.
    :param confidence_penalty_weight: weight of the output-distribution
        confidence penalty that is added to the model's losses.
    :param agglomerative_attention: forwarded to every `TransformerBlock`.
    :param use_convolutions: apply a causal `Conv1D` to the embeddings.
    :param use_coordinate_embeddings: add coordinate embeddings before the
        transformer stack.
    :param convolution_width: kernel size of the convolution; must be a
        positive integer when `use_convolutions` is true.
    :param dropout_cls: dropout layer class forwarded to every block.
    :return: an uncompiled Keras `Model` mapping `word_ids` to word
        probabilities (softmax over the last axis).
    :raises ValueError: if `use_convolutions` is true and
        `convolution_width` is not positive.
    """
    # Fail early with a clear message instead of building a zero-width
    # convolution kernel.
    if use_convolutions and convolution_width < 1:
        raise ValueError(
            'convolution_width must be a positive integer when '
            'use_convolutions is enabled, got {}'.format(convolution_width))

    word_ids = Input(shape=(max_seq_length,), dtype='int32', name='word_ids')
    l2_regularizer = (keras.regularizers.l2(l2_reg_penalty) if l2_reg_penalty
                      else None)
    embedding_layer = ReusableEmbedding(
        vocabulary_size, word_embedding_size,
        input_length=max_seq_length,
        name='bpe_embeddings',
        # Regularization is based on paper "A Comparative Study on
        # Regularization Strategies for Embedding-based Neural Networks"
        # https://arxiv.org/pdf/1508.03721.pdf
        embeddings_regularizer=l2_regularizer)
    output_layer = TiedOutputEmbedding(
        projection_regularizer=l2_regularizer,
        projection_dropout=embedding_dropout,
        name='word_prediction_logits')
    coordinate_embedding_layer = TransformerCoordinateEmbedding(
        1,
        name='coordinate_embedding')
    output_softmax_layer = Softmax(name='word_predictions')

    next_step_input, embedding_matrix = embedding_layer(word_ids)

    if use_convolutions:
        # Built only when used: with the default width of 0 the layer would
        # otherwise be constructed with an invalid kernel size.
        # 'causal' padding means output[t] does not depend on input[t + 1:].
        conv_layer = keras.layers.Conv1D(
            word_embedding_size, convolution_width, padding='causal',
            activation='relu', kernel_initializer='he_uniform',
            name='convolution')
        next_step_input = conv_layer(next_step_input)
    if use_coordinate_embeddings:
        next_step_input = coordinate_embedding_layer(next_step_input, step=0)
    for i in range(transformer_depth):
        next_step_input = (
            TransformerBlock(
                name='transformer' + str(i), num_heads=num_heads,
                residual_dropout=transformer_dropout,
                attention_dropout=transformer_dropout,
                use_masking=True,
                vanilla_wiring=True,
                agglomerative_attention=agglomerative_attention,
                dropout_cls=dropout_cls,
            )
            (next_step_input))

    word_predictions = output_softmax_layer(
        output_layer([next_step_input, embedding_matrix]))
    model = keras.models.Model(inputs=[word_ids], outputs=[word_predictions])
    # Penalty for confidence of the output distribution, as described in
    # "Regularizing Neural Networks by Penalizing Confident
    # Output Distributions" (https://arxiv.org/abs/1701.06548).
    # K.epsilon() keeps log() finite when a probability underflows to 0,
    # which would otherwise make the loss and its gradient NaN.
    confidence_penalty = K.mean(
        confidence_penalty_weight *
        K.sum(word_predictions * K.log(word_predictions + K.epsilon()),
              axis=-1))
    model.add_loss(confidence_penalty)
    return model
Example #7
0
def _pad_to_longest(*lists):
    """Return copies of ``lists`` padded to the length of the longest one.

    Every shorter list is extended by repeating its own last element, e.g.
    ``_pad_to_longest([1, 2, 3], [0.5])`` -> ``[[1, 2, 3], [0.5, 0.5, 0.5]]``.

    Raises:
        ValueError: if padding is required but one of the lists is empty,
            because there is no last element to repeat.
    """
    longest = max(len(values) for values in lists)
    padded = []
    for values in lists:
        if not values and longest > 0:
            raise ValueError(
                "Cannot pad an empty list to length {}: every list must "
                "contain at least one value.".format(longest))
        padded.append(list(values) + list(values[-1:]) * (longest - len(values)))
    return padded


def build_model(max_length: int,
                embedding_matrix: Union[np.ndarray, Tuple[int, int]],
                transformer_depth: int,
                transformer_heads: int,
                cell_type: str,
                cell_size: int,
                cell_stack_size: int,
                filters: List[int],
                kernel_size: List[int],
                pool_size: List[int],
                conv_padding: str,
                pool_padding: str,
                dense_size: List[int],
                loaded_model: Optional[str] = None,
                l2_penalty: Optional[float] = None,
                embedding_dropout: float = 0.6,
                transformer_dropout: float = 0.1,
                conv_dropout: float = 0.1,
                dense_dropout: Union[float, List[float]] = 0.3,
                classifier_dropout: float = 0.1,
                no_cnn: bool = False,
                train_lm: bool = True) -> "Model":
    """Build a transformer language model joined to an RNN/CNN binary classifier.

    One shared word embedding feeds two branches:

    * ``transformer_depth`` transformer blocks (after a position embedding)
      whose output is projected back onto the vocabulary through the tied
      embedding, giving the language-model output ``word_predictions``;
    * a stack of bidirectional LSTM or GRU layers.

    The two branches are concatenated. Unless ``no_cnn`` is set, the result
    goes through dropout + Conv1D + max-pooling layers. It is then flattened
    and passed through optional dropout + dense layers and a single-unit
    sigmoid classifier.

    Args:
        max_length: Length of the integer-encoded input sequences.
        embedding_matrix: Either an embedding weight array used as
            ``(input_dim, output_dim)`` = (vocabulary size, embedding size),
            or such a shape tuple for a randomly initialised embedding.
        transformer_depth: Number of transformer blocks.
        transformer_heads: Attention heads per transformer block.
        cell_type: ``"lstm"`` or ``"gru"`` (case-insensitive).
        cell_size: Units of each direction of every recurrent layer.
        cell_stack_size: Number of stacked bidirectional recurrent layers.
        filters: Conv1D filter count per convolutional layer.
        kernel_size: Conv1D kernel width per convolutional layer.
        pool_size: MaxPooling1D pool size per convolutional layer.
            When the CNN is used and ``filters``/``kernel_size``/``pool_size``
            differ in length, the shorter ones are padded with their last value.
        conv_padding: ``padding`` argument of every Conv1D layer.
        pool_padding: ``padding`` argument of every MaxPooling1D layer.
        dense_size: Units of each dense layer before the classifier.
        loaded_model: Accepted for interface compatibility; not used here.
        l2_penalty: L2 factor for the embedding and the tied output
            projection; a falsy value disables regularisation.
        embedding_dropout: Dropout of the tied output projection.
        transformer_dropout: Residual and attention dropout of the
            transformer blocks.
        conv_dropout: Dropout applied before every convolution.
        dense_dropout: Dropout rate before each dense layer. A scalar applies
            to every dense layer; a list shorter than ``dense_size`` (or vice
            versa) is padded with its last value.
        classifier_dropout: Dropout before the classifier layer.
        no_cnn: If True, the CNN is skipped. The last recurrent layer then
            returns only its final output, and both branches are flattened
            before concatenation.
        train_lm: If True the model also outputs the language-model
            predictions.

    Returns:
        A Keras ``Model`` on the ``input`` layer whose outputs are
        ``[lm_output, classifier_output]`` if ``train_lm`` else just
        ``classifier_output``.

    Raises:
        ValueError: for an unsupported ``cell_type``, or when two
            hyper-parameter lists need padding but one of them is empty.
    """
    if not isinstance(dense_dropout, (list, tuple)):
        dense_dropout = [dense_dropout]
    dense_dropout = list(dense_dropout)

    if len(dense_size) > 0 and len(dense_size) != len(dense_dropout):
        dense_size, dense_dropout = _pad_to_longest(dense_size, dense_dropout)
        logger.warning(
            "Lists given for dense layer sizes and dense layer dropout rates are not the same length. "
            "The shorter lists are padded using the last value to match the length of the longest."
        )

    # The CNN hyper-parameter lists are only reconciled (and only matter)
    # when the CNN is actually built.
    if not no_cnn and len({len(filters), len(kernel_size), len(pool_size)}) > 1:
        filters, kernel_size, pool_size = _pad_to_longest(
            filters, kernel_size, pool_size)
        logger.warning(
            "Lists given for convolutional filters, kernel sizes and pooling sizes had different lengths. "
            "The shorter lists are padded using the last value to match the length of the longest."
        )

    cell_type_name = cell_type.lower()
    if cell_type_name == "lstm":
        rnn_cls = LSTM
    elif cell_type_name == "gru":
        rnn_cls = GRU
    else:
        raise ValueError(
            "Unsupported cell_type {!r}; expected 'lstm' or 'gru'.".format(
                cell_type))

    # regularizer for the embedding and its tied output projection
    l2_regularizer = l2(l2_penalty) if l2_penalty else None

    # input encoded as integers
    raw_input = Input(shape=(max_length, ), name="input")

    # embedding layer: pretrained weights, or just a shape to initialise randomly
    if isinstance(embedding_matrix, tuple):
        vocab_size, embedding_size = embedding_matrix[0], embedding_matrix[1]
        initial_weights = None
    else:
        vocab_size = embedding_matrix.shape[0]
        embedding_size = embedding_matrix.shape[1]
        initial_weights = [embedding_matrix]
    embedding_layer = ReusableEmbedding(
        input_dim=vocab_size,
        output_dim=embedding_size,
        input_length=max_length,
        name="word_embedding",
        weights=initial_weights,
        embeddings_regularizer=l2_regularizer)

    # "transpose" of embedding matrix to map back to vocabulary
    lm_projection = TiedOutputEmbedding(projection_regularizer=l2_regularizer,
                                        projection_dropout=embedding_dropout,
                                        name="word_prediction_logits")

    # transformer as taken from here: https://github.com/kpot/keras-transformer/blob/master/example/models.py
    position_embedding = TransformerCoordinateEmbedding(
        max_transformer_depth=1, name="position_embedding")
    transformer_blocks = [
        TransformerBlock(name="transformer" + str(i),
                         num_heads=transformer_heads,
                         residual_dropout=transformer_dropout,
                         attention_dropout=transformer_dropout,
                         use_masking=True,
                         vanilla_wiring=True)
        for i in range(transformer_depth)
    ]
    lm_prediction = Softmax(name="word_predictions")

    # stacked bidirectional RNN; all but the last layer return full sequences
    rnn_layers = []
    for i in range(cell_stack_size - 1):
        cell_name = "{}_{}".format(cell_type_name, i)
        rnn_layers.append(
            Bidirectional(rnn_cls(cell_size, return_sequences=True, name=cell_name),
                          name="bidirectional_{}".format(cell_name)))
    # the last layer only keeps the sequence axis when the CNN consumes it
    cell_name = "{}_{}".format(cell_type_name, cell_stack_size - 1)
    rnn_layers.append(
        Bidirectional(rnn_cls(cell_size, return_sequences=(not no_cnn),
                              name=cell_name),
                      name="bidirectional_{}".format(cell_name)))

    # convolution + pooling layer(s), only needed when the CNN is used
    conv_layers = []
    if not no_cnn:
        conv_dropout_layer = Dropout(conv_dropout, name="conv_dropout")
        for i, (n_filters, width, pool) in enumerate(
                zip(filters, kernel_size, pool_size)):
            convolution = Conv1D(n_filters,
                                 width,
                                 padding=conv_padding,
                                 activation="relu",
                                 name="conv_{}".format(i))
            pooling = MaxPooling1D(pool,
                                   padding=pool_padding,
                                   name="max_pool_{}".format(i))
            conv_layers.append((convolution, pooling))

    # dense layer(s), each preceded by its own dropout
    dense_layers = [
        (Dropout(rate=rate, name="dense_dropout_{}".format(i)),
         Dense(units, name="dense_{}".format(i)))
        for i, (units, rate) in enumerate(zip(dense_size, dense_dropout))
    ]

    # classification layer
    classifier_dropout_layer = Dropout(classifier_dropout,
                                       name="classifier_dropout")
    classifier = Dense(1, name="classifier")
    classifier_prediction = Activation("sigmoid", name="classifier_prediction")

    # flattening for RNN/transformer outputs and CNN outputs
    flatten = Flatten(name="flatten")

    # BUILD THE ACTUAL MODEL
    # embedding used for both parts of the model
    embedded, embedding_tensor = embedding_layer(raw_input)

    transformer_output = position_embedding(embedded, step=0)
    for block in transformer_blocks:
        transformer_output = block(transformer_output)
    lm_output = lm_prediction(
        lm_projection([transformer_output, embedding_tensor]))

    rnn_output = embedded
    for rnn_layer in rnn_layers:
        rnn_output = rnn_layer(rnn_output)

    if no_cnn:
        features = Concatenate(name="concat_flattened")(
            [flatten(transformer_output), flatten(rnn_output)])
    else:
        features = Concatenate(name="concat_features")(
            [transformer_output, rnn_output])
        for convolution, pooling in conv_layers:
            features = pooling(convolution(conv_dropout_layer(features)))

    features = flatten(features)
    for dropout, dense in dense_layers:
        features = dense(dropout(features))

    classifier_output = classifier_prediction(
        classifier(classifier_dropout_layer(features)))

    if train_lm:
        return Model(inputs=raw_input, outputs=[lm_output, classifier_output])
    return Model(inputs=raw_input, outputs=classifier_output)
def _load_transformer_block_weights(transformer_block, original_model,
                                    block_index):
    """Copy the weights of ``transformer{block_index}`` from ``original_model``.

    The block's sub-layers are first built explicitly with the input shapes
    recorded in ``original_model``, because ``set_weights`` needs built layers.
    """
    block_name = "transformer{}".format(block_index)
    if block_index == 0:
        attention_input_name = "position_embedding"
    else:
        attention_input_name = "transformer{}_normalization2".format(
            block_index - 1)

    transformer_block.attention_layer.build(
        original_model.get_layer(attention_input_name).output_shape)
    transformer_block.norm1_layer.build(
        original_model.get_layer(block_name + "_self_attention").output_shape)
    normalization1_shape = original_model.get_layer(
        block_name + "_normalization1").output_shape
    transformer_block.norm2_layer.build(normalization1_shape)
    transformer_block.transition_layer.build(normalization1_shape)

    for sublayer, suffix in (
            (transformer_block.attention_layer, "_self_attention"),
            (transformer_block.norm1_layer, "_normalization1"),
            (transformer_block.norm2_layer, "_normalization2"),
            (transformer_block.transition_layer, "_transition")):
        sublayer.set_weights(
            original_model.get_layer(name=block_name + suffix).get_weights())


def build_model(max_length,
                loaded_model=None,
                fine_tune_model=False,
                embedding_matrix=None,
                transformer_depth=8,
                transformer_heads=8,
                l2_penalty=None,
                embedding_dropout=0.6,
                transformer_dropout=0.1,
                classifier_dropout=0.1,
                transformer_output_handling="flatten",
                print_info=False,
                train_lm=True):
    """Build a transformer language model, optionally with a classifier head.

    Integer-encoded input is embedded with a ``ReusableEmbedding``, a position
    embedding is added, ``transformer_depth`` transformer blocks are applied,
    and the result is projected back onto the vocabulary through the tied
    embedding (softmax output ``word_predictions``).

    When ``loaded_model`` is given, the saved model is loaded. The embedding,
    output projection, position embedding and transformer blocks are
    initialised from its layers of the same names. When fine-tuning, the
    ``dense`` and pooling layers are also restored if present.

    Args:
        max_length: Length of the integer-encoded input sequences.
        loaded_model: Saved model passed to ``load_model``; falsy to start
            from scratch.
        fine_tune_model: If False, return only the language model. Otherwise
            add a classifier head: the transformer output is reduced per
            ``transformer_output_handling``, passed through dropout, then a
            2-unit dense layer and a softmax.
        embedding_matrix: Embedding weights used as ``(input_dim, output_dim)``,
            or such a shape tuple for random initialisation. Required unless
            ``loaded_model`` is given; with ``loaded_model`` the restored
            embedding weights determine the dimensions.
        transformer_depth: Number of transformer blocks.
        transformer_heads: Attention heads per transformer block.
        l2_penalty: L2 factor for the embedding and tied output projection;
            a falsy value disables regularisation.
        embedding_dropout: Dropout of the tied output projection.
        transformer_dropout: Residual and attention dropout of the blocks.
        classifier_dropout: Dropout before the classifier's dense layer.
        transformer_output_handling: One of ``"flatten"``, ``"max_pooling"``,
            ``"mean_pooling"``, ``"self_attention"`` or
            ``"scaled_dot_attention"``; only used when ``fine_tune_model``.
        print_info: If True, log intermediate tensor shapes at debug level.
        train_lm: When fine-tuning, also output the language-model
            predictions.

    Returns:
        A Keras ``Model`` on the ``input`` layer. Without fine-tuning its
        output is the LM prediction. When fine-tuning, the outputs are
        ``[lm_output, classifier_output]`` if ``train_lm`` else
        ``classifier_output``.

    Raises:
        ValueError: for an unknown ``transformer_output_handling`` when
            fine-tuning, or when neither ``loaded_model`` nor
            ``embedding_matrix`` is given.
    """
    # Validate cheaply before loading a saved model or building any layers.
    handling_options = ("flatten", "max_pooling", "mean_pooling",
                        "self_attention", "scaled_dot_attention")
    if fine_tune_model and transformer_output_handling not in handling_options:
        raise ValueError(
            "Unknown transformer_output_handling {!r}; expected one of {}."
            .format(transformer_output_handling, ", ".join(handling_options)))
    if not loaded_model and embedding_matrix is None:
        raise ValueError(
            "embedding_matrix (weights or a (vocabulary_size, embedding_size) "
            "tuple) is required when no loaded_model is given.")

    original_model = None
    if loaded_model:
        original_model = load_model(loaded_model,
                                    custom_objects={
                                        "perplexity": perplexity,
                                        "lm_accuracy": lm_accuracy,
                                        "SeqSelfAttention": SeqSelfAttention,
                                        "ScaledDotProductAttention":
                                        ScaledDotProductAttention,
                                    })

    def restored_weights(layer_name):
        """Keyword args restoring ``layer_name``'s weights from the loaded model."""
        if original_model is None:
            return {}
        return {
            "weights": original_model.get_layer(name=layer_name).get_weights()
        }

    # regularizer for embedding layer
    l2_regularizer = l2(l2_penalty) if l2_penalty else None

    # input encoded as integers
    raw_input = Input(shape=(max_length, ), name="input")

    # embedding layer: restored weights take precedence over embedding_matrix
    if original_model is not None:
        restored_embedding = original_model.get_layer(
            name="word_embedding").get_weights()[0]
        vocab_size = restored_embedding.shape[0]
        embedding_size = restored_embedding.shape[1]
        initial_weights = [restored_embedding]
    elif isinstance(embedding_matrix, tuple):
        vocab_size, embedding_size = embedding_matrix[0], embedding_matrix[1]
        initial_weights = None
    else:
        vocab_size = embedding_matrix.shape[0]
        embedding_size = embedding_matrix.shape[1]
        initial_weights = [embedding_matrix]
    embedding_layer = ReusableEmbedding(
        input_dim=vocab_size,
        output_dim=embedding_size,
        input_length=max_length,
        name="word_embedding",
        weights=initial_weights,
        embeddings_regularizer=l2_regularizer)

    # "transpose" of embedding matrix to map back to vocabulary
    output_layer = TiedOutputEmbedding(
        projection_regularizer=l2_regularizer,
        projection_dropout=embedding_dropout,
        name="word_prediction_logits",
        **restored_weights("word_prediction_logits"))

    # transformer as taken from here: https://github.com/kpot/keras-transformer/blob/master/example/models.py
    position_embedding = TransformerCoordinateEmbedding(
        max_transformer_depth=1,
        name="position_embedding",
        **restored_weights("position_embedding"))

    transformer_input, embedding_tensor = embedding_layer(raw_input)
    transformer_output = position_embedding(transformer_input, step=0)
    for i in range(transformer_depth):
        transformer_block = TransformerBlock(
            name="transformer" + str(i),
            num_heads=transformer_heads,
            residual_dropout=transformer_dropout,
            attention_dropout=transformer_dropout,
            use_masking=True,
            vanilla_wiring=True)
        if original_model is not None:
            _load_transformer_block_weights(transformer_block, original_model, i)
        transformer_output = transformer_block(transformer_output)

    if print_info:
        logger.debug("transformer_output shape: {}".format(
            K.int_shape(transformer_output)))

    lm_output_logits = output_layer([transformer_output, embedding_tensor])
    lm_output = Softmax(name="word_predictions")(lm_output_logits)
    if print_info:
        logger.debug("lm_output_logits shape: {}".format(
            K.int_shape(lm_output_logits)))
        logger.debug("output shape: {}".format(K.int_shape(lm_output)))

    if not fine_tune_model:
        return Model(inputs=raw_input, outputs=lm_output)

    loaded_layer_names = ({layer.name for layer in original_model.layers}
                          if original_model is not None else set())

    # ways of reducing the transformer output sequence for the classifier
    flatten = Flatten(name="flatten_transformer_output")
    options = {
        "flatten": flatten,
        "max_pooling": Lambda(lambda x: K.max(x, axis=1), name="max_pooling"),
        "mean_pooling": Lambda(lambda x: K.mean(x, axis=1),
                               name="mean_pooling"),
        "self_attention": SeqSelfAttention(name="self_attention"),
        "scaled_dot_attention": ScaledDotProductAttention(
            name="scaled_dot_attention"),
    }
    dropout = Dropout(rate=classifier_dropout, name="classifier_dropout")

    dense = Dense(2, activation=None, name="dense")
    if "dense" in loaded_layer_names:
        layer = original_model.get_layer(name="dense")
        dense.build(layer.input_shape)
        dense.set_weights(layer.get_weights())

    pooling_layer = options[transformer_output_handling]
    if transformer_output_handling in loaded_layer_names:
        layer = original_model.get_layer(name=transformer_output_handling)
        pooling_layer.build(layer.input_shape)
        pooling_layer.set_weights(layer.get_weights())

    if "attention" in transformer_output_handling:
        # attention variants are flattened afterwards (presumably because they
        # keep the sequence axis — confirm against the attention layers)
        handled_output = flatten(pooling_layer(transformer_output))
    else:
        handled_output = pooling_layer(transformer_output)

    classifier_logits = dense(dropout(handled_output))
    classifier_output = Softmax(
        name="classifier_prediction")(classifier_logits)

    if train_lm:
        return Model(inputs=raw_input, outputs=[lm_output, classifier_output])
    return Model(inputs=raw_input, outputs=classifier_output)