Example 1
def gpt2_featurizer(
    X,
    encoder,
    config,
    train=False,
    reuse=None,
    **kwargs
):
    """
    The GPT-2 transformer element of the finetuning model. Maps from token ids to a dense embedding of the
    sequence.

    :param X: A tensor of token indexes with shape [batch_size, sequence_length, token_idx]
    :param encoder: A TextEncoder object.
    :param config: A config object, containing all parameters for the featurizer.
    :param train: If this flag is true, dropout and losses are added to the graph.
    :param reuse: Should reuse be set within this scope.
    :return: A dict containing:
        embed_weights: the word embedding matrix.
        features: The output of the featurizer's final state.
        sequence_features: The output of the featurizer at each timestep.
    """
    initial_shape = tf.shape(X)
    X = tf.reshape(X, shape=tf.concat(([-1], initial_shape[-2:]), 0))
    X.set_shape([None, None, None])

    with tf.variable_scope("model/featurizer", reuse=reuse):
        embed_weights = tf.get_variable(
            name="we",
            shape=[encoder.vocab_size + config.max_length, config.n_embed],
            initializer=tf.random_normal_initializer(stddev=config.weight_stddev),
        )
        if config.train_embeddings:
            embed_weights = dropout(embed_weights, config.embed_p_drop, train)
        else:
            embed_weights = tf.stop_gradient(embed_weights)

        X = tf.reshape(X, [-1, config.max_length, 2])
        h = embed(X, embed_weights)

        # Transformer
        pasts = [None] * config.n_layer
        for layer, past in enumerate(pasts):
            if (
                (config.n_layer - layer) == config.num_layers_trained
                and config.num_layers_trained != config.n_layer
                and config.adapter_size is None
            ):
                h = tf.stop_gradient(h)
                train_layer = False
            else:
                train_layer = train

            with tf.variable_scope("h%d" % layer):
                block_fn = functools.partial(
                    block, past=past, hparams=config, train=train_layer
                )
                if config.low_memory_mode and train_layer:
                    block_fn = recompute_grad(block_fn, use_entire_scope=True)
                h = block_fn(h)

        h = norm(h, "ln_f")

        # Use hidden state at classifier token as input to final proj. + softmax
        clf_h = tf.reshape(h, [-1, config.n_embed])  # [batch * seq_len, embed]
        clf_token = encoder["_classify_"]
        pool_idx = tf.cast(
            tf.argmax(tf.cast(tf.equal(X[:, :, 0], clf_token), tf.float32), 1), tf.int32
        )
        clf_h = tf.gather(
            clf_h,
            tf.range(shape_list(X)[0], dtype=tf.int32) * config.max_length + pool_idx,
        )
        clf_h = tf.reshape(
            clf_h, shape=tf.concat((initial_shape[:-2], [config.n_embed]), 0)
        )
        seq_feats = tf.reshape(
            h, shape=tf.concat((initial_shape[:-1], [config.n_embed]), 0)
        )

        lengths = lengths_from_eos_idx(eos_idx=pool_idx, max_length=shape_list(X)[0])

        return {
            "embed_weights": embed_weights,
            "features": clf_h,
            "sequence_features": seq_feats,
            "eos_idx": pool_idx,
            "lengths": lengths
        }
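A minimal NumPy sketch of the pooling step above: the hidden states are flattened to [batch * max_length, n_embed] and row i * max_length + pool_idx[i] selects the classifier-token state for sequence i. All names and shapes here are illustrative, not taken from the library.

import numpy as np

# Illustrative shapes only.
batch, max_length, n_embed = 2, 5, 4
h = np.random.randn(batch, max_length, n_embed)       # featurizer output
pool_idx = np.array([3, 1])                           # position of the classify token per sequence

clf_h = h.reshape(-1, n_embed)                        # [batch * max_length, n_embed]
flat_idx = np.arange(batch) * max_length + pool_idx   # flat row of each classify token
clf_h = clf_h[flat_idx]                               # [batch, n_embed]

# Equivalent to indexing the 3-D tensor directly:
assert np.allclose(clf_h, h[np.arange(batch), pool_idx])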
Example 2
def transformer_model(input_tensor,
                      attention_mask=None,
                      hidden_size=768,
                      num_hidden_layers=12,
                      num_attention_heads=12,
                      intermediate_size=3072,
                      intermediate_act_fn=gelu,
                      hidden_dropout_prob=0.1,
                      attention_probs_dropout_prob=0.1,
                      initializer_range=0.02,
                      do_return_all_layers=False,
                      adapter_size=0,
                      low_memory_mode=False):
    """Multi-headed, multi-layer Transformer from "Attention is All You Need".

    This is almost an exact implementation of the original Transformer encoder.

    See the original paper:
    https://arxiv.org/abs/1706.03762

    Also see:
        https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/transformer.py

    Args:
        input_tensor: float Tensor of shape [batch_size, seq_length, hidden_size].
        attention_mask: (optional) int32 Tensor of shape [batch_size, seq_length,
            seq_length], with 1 for positions that can be attended to and 0 in
            positions that should not be.
        hidden_size: int. Hidden size of the Transformer.
        num_hidden_layers: int. Number of layers (blocks) in the Transformer.
        num_attention_heads: int. Number of attention heads in the Transformer.
        intermediate_size: int. The size of the "intermediate" (a.k.a., feed
            forward) layer.
        intermediate_act_fn: function. The non-linear activation function to apply
            to the output of the intermediate/feed-forward layer.
        hidden_dropout_prob: float. Dropout probability for the hidden layers.
        attention_probs_dropout_prob: float. Dropout probability of the attention
            probabilities.
        initializer_range: float. Range of the initializer (stddev of truncated
            normal).
        do_return_all_layers: Whether to also return all layers or just the final
            layer.
        adapter_size: The size of adapter modules to use. None to disable.
        low_memory_mode: Whether to use gradient checkpointing.

    Returns:
        float Tensor of shape [batch_size, seq_length, hidden_size], the final
        hidden layer of the Transformer.

    Raises:
        ValueError: A Tensor shape or parameter is invalid.
    """
    if hidden_size % num_attention_heads != 0:
        raise ValueError(
            "The hidden size (%d) is not a multiple of the number of attention "
            "heads (%d)" % (hidden_size, num_attention_heads))

    attention_head_size = int(hidden_size / num_attention_heads)
    input_shape = get_shape_list(input_tensor, expected_rank=3)
    batch_size = input_shape[0]
    seq_length = input_shape[1]
    input_width = input_shape[2]

    # The Transformer performs sum residuals on all layers so the input needs
    # to be the same as the hidden size.
    if input_width != hidden_size:
        raise ValueError(
            "The width of the input tensor (%d) != hidden size (%d)" %
            (input_width, hidden_size))

    # We keep the representation as a 2D tensor to avoid re-shaping it back and
    # forth from a 3D tensor to a 2D tensor. Re-shapes are normally free on
    # the GPU/CPU but may not be free on the TPU, so we want to minimize them to
    # help the optimizer.
    prev_output = reshape_to_matrix(input_tensor)

    all_layer_outputs = []
    for layer_idx in range(num_hidden_layers):
        with tf.variable_scope("layer_%d" % layer_idx):
            layer_input = prev_output

            block_fn = functools.partial(
                full_block,
                attention_head_size=attention_head_size,
                batch_size=batch_size,
                seq_length=seq_length,
                attention_mask=attention_mask,
                hidden_size=hidden_size,
                num_attention_heads=num_attention_heads,
                intermediate_size=intermediate_size,
                intermediate_act_fn=intermediate_act_fn,
                hidden_dropout_prob=hidden_dropout_prob,
                attention_probs_dropout_prob=attention_probs_dropout_prob,
                initializer_range=initializer_range,
                adapter_size=adapter_size)

            if low_memory_mode:
                block_fn = recompute_grad(block_fn, use_entire_scope=True)

            layer_output = block_fn(layer_input)
            prev_output = layer_output
            all_layer_outputs.append(layer_output)

    if do_return_all_layers:
        final_outputs = []
        for layer_output in all_layer_outputs:
            final_output = reshape_from_matrix(layer_output, input_shape)
            final_outputs.append(final_output)
        return final_outputs
    else:
        final_output = reshape_from_matrix(prev_output, input_shape)
        return final_output
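The 2-D/3-D bookkeeping described in the comment above (keeping the representation as a matrix and only reshaping at the boundaries) can be sketched in NumPy as follows. These helpers are simplified stand-ins for the library's reshape_to_matrix / reshape_from_matrix, assuming they simply collapse and restore the leading batch and sequence dimensions.

import numpy as np

def to_matrix(x):
    """Collapse [batch, seq, hidden] -> [batch * seq, hidden]."""
    return x.reshape(-1, x.shape[-1])

def from_matrix(x, original_shape):
    """Restore [batch * seq, hidden] -> [batch, seq, hidden]."""
    return x.reshape(original_shape)

batch, seq_length, hidden_size = 2, 8, 16
inp = np.random.randn(batch, seq_length, hidden_size)
flat = to_matrix(inp)                       # what each layer operates on
out = from_matrix(flat, inp.shape)          # restored only once, at the output
assert out.shape == (batch, seq_length, hidden_size)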
Example 3
def featurizer(X, encoder, config, train=False, reuse=None, encoder_state=None, context=None, context_dim=None, **kwargs):
    """
    The main element of the OSCAR model. Maps from tokens ids to a dense, embedding of the sequence.

    :param X: A tensor of token indexes with shape [batch_size, sequence_length, token_idx]
    :param encoder: A TextEncoder object.
    :param config: A config object, containing all parameters for the featurizer.
    :param train: If this flag is true, dropout and losses are added to the graph.
    :param reuse: Should reuse be set within this scope.
    :return: A dict containing;
        embed_weights: the word embedding matrix.
        features: The output of the featurizer_final state.
        sequence_features: The output of the featurizer at each timestep.
    """
    initial_shape = [a or -1 for a in X.get_shape().as_list()]
    if len(initial_shape) != 3:
        X = tf.reshape(X, shape=[-1] + initial_shape[-2:])

    x_shape = tf.shape(X)
    with tf.variable_scope('model/featurizer', reuse=reuse):
        encoder._lazy_init()
        clf_token = encoder.end_token
        pool_idx = tf.cast(tf.argmax(tf.cast(tf.equal(X[:, :, 0], clf_token), tf.float32), 1), tf.int32)
        if encoder_state is None:
            embed_weights = tf.get_variable("we", [encoder.vocab_size + config.max_length, config.n_embed],
                                            initializer=tf.random_normal_initializer(stddev=config.weight_stddev))
        else:
            embed_weights = encoder_state["embed_weights"]

        if config.oscar_use_fp16:
            embed_weights = tf.cast(embed_weights, tf.float16)

        if config.train_embeddings:
            embed_weights = dropout(embed_weights, config.embed_p_drop, train)
        else:
            embed_weights = tf.stop_gradient(embed_weights)

        X = tf.reshape(X, [-1, x_shape[1], 2])

        if config.oscar_use_timing:
            h = embed(X, embed_weights)
        else:
            h = embed_no_timing(X, embed_weights)

        for layer in range(config.n_layer):
            with tf.variable_scope('h%d_' % layer):
                if (
                        (config.n_layer - layer) == config.num_layers_trained and
                        config.num_layers_trained != config.n_layer
                ):
                    h = tf.stop_gradient(h)

                block_fn_fwd = functools.partial(
                    block, block_name='block%d_' % layer, use_fp16=config.oscar_use_fp16,
                    pool_idx=None, encoder_state=encoder_state, train=train,
                    pdrop=config.resid_p_drop, use_fused_kernel=config.oscar_use_fused_kernel,
                )

                if config.low_memory_mode and train:
                    block_fn_fwd = recompute_grad(block_fn_fwd, use_entire_scope=True)
                h = block_fn_fwd(h)

        h = normal_1d_conv_block(h, 1, "output", config.oscar_use_fp16, dilation=1)

        mask = tf.expand_dims(tf.sequence_mask(pool_idx, maxlen=tf.shape(h)[1], dtype=h.dtype), -1)

        if config.oscar_feat_mode == "clf_tok":
            clf_h = tf.gather_nd(h, tf.stack([tf.range(shape_list(h)[0]), pool_idx], 1))
        elif config.oscar_feat_mode == "mean_tok":
            # Masked mean: divide by the number of unmasked tokens, not the sum of the activations.
            clf_h = tf.reduce_sum(h * mask, 1) / tf.reduce_sum(mask, 1)
        elif config.oscar_feat_mode == "max_tok":
            clf_h = tf.reduce_max(h - (1e5 * (1.0 - mask)), 1)
        else:
            raise ValueError("config.feat_mode should be one of clf_tok, mean_tok or max_tok")

        if len(initial_shape) != 3:
            seq_feats = tf.reshape(h, shape=initial_shape[:-1] + [config.n_embed])
        else:
            seq_feats = h

        return {
            'embed_weights': embed_weights,
            'features': cast_maybe(clf_h, tf.float32),
            'sequence_features': seq_feats,
            'eos_idx': pool_idx,
            'encoded_input': X[:, :tf.reduce_min(pool_idx), 0],
            'lengths': lengths_from_eos_idx(eos_idx=pool_idx, max_length=shape_list(X)[0])
        }
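The three oscar_feat_mode options are standard masked pooling strategies. A NumPy sketch of the same arithmetic, with illustrative shapes and a mask built from pool_idx (the mean divides by the number of unmasked tokens):

import numpy as np

batch, seq_len, n_embed = 2, 6, 4
h = np.random.randn(batch, seq_len, n_embed)
pool_idx = np.array([4, 2])                     # position of the end token per sequence

# mask[i, t, 0] = 1.0 for t < pool_idx[i], else 0.0
mask = (np.arange(seq_len)[None, :] < pool_idx[:, None]).astype(h.dtype)[..., None]

clf_tok = h[np.arange(batch), pool_idx]                 # "clf_tok": hidden state at the end token
mean_tok = (h * mask).sum(axis=1) / mask.sum(axis=1)    # "mean_tok": masked mean over real tokens
max_tok = (h - 1e5 * (1.0 - mask)).max(axis=1)          # "max_tok": masked max, padding pushed far negative

print(clf_tok.shape, mean_tok.shape, max_tok.shape)     # all (batch, n_embed)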
Example 4
def gpt_featurizer(X,
                   encoder,
                   config,
                   train=False,
                   reuse=None,
                   explain=False,
                   **kwargs):
    """
    The transformer element of the finetuning model. Maps from tokens ids to a dense, embedding of the sequence.

    :param X: A tensor of token indexes with shape [batch_size, sequence_length, token_idx]
    :param encoder: A TextEncoder object.
    :param config: A config object, containing all parameters for the featurizer.
    :param train: If this flag is true, dropout and losses are added to the graph.
    :param reuse: Should reuse be set within this scope.
    :return: A dict containing;
        embed_weights: the word embedding matrix.
        features: The output of the featurizer_final state.
        sequence_features: The output of the featurizer at each timestep.
    """
    initial_shape = tf.shape(X)
    X = tf.reshape(X, shape=tf.concat(([-1], initial_shape[-2:]), 0))
    sequence_length = tf.shape(X)[1]

    with tf.variable_scope("model/featurizer", reuse=reuse):
        embed_weights = tf.get_variable(
            name="we",
            shape=[encoder.vocab_size + config.max_length, config.n_embed],
            initializer=tf.random_normal_initializer(
                stddev=config.weight_stddev),
        )
        if config.train_embeddings:
            embed_weights = dropout(embed_weights, config.embed_p_drop, train)
        else:
            embed_weights = tf.stop_gradient(embed_weights)

#        X = tf.reshape(X, [-1, config.max_length, 2])

        clf_token = encoder.end_token
        pool_idx = tf.cast(
            tf.argmax(tf.cast(tf.equal(X[:, :, 0], clf_token), tf.float32), 1),
            tf.int32)

        if explain:
            X = add_explain_tokens(X, sequence_length, pool_idx)

        h = embed(X, embed_weights)
        for layer in range(config.n_layer):
            if ((config.n_layer - layer) == config.num_layers_trained
                    and config.num_layers_trained != config.n_layer
                    and config.adapter_size is None):
                h = tf.stop_gradient(h)
                train_layer = False
            else:
                train_layer = train

            with tf.variable_scope("h%d_" % layer):
                block_fn = functools.partial(
                    block,
                    n_head=config.n_heads,
                    act_fn=config.act_fn,
                    resid_pdrop=config.resid_p_drop,
                    attn_pdrop=config.attn_p_drop,
                    scope="h%d" % layer,
                    train=train_layer,
                    scale=True,
                    explain=explain,
                    adptr_size=config.adapter_size,
                )
                if config.low_memory_mode and train_layer:
                    block_fn = recompute_grad(block_fn, use_entire_scope=True)
                if layer < config.n_layer - 1:
                    h = block_fn(h)
                else:
                    h_out = block_fn(h)

            # get the attention weights from the last layer
            if layer == config.n_layer - 1:
                with tf.variable_scope("h%d_/h%d/attn" % (layer, layer),
                                       reuse=True):
                    q, k, v = multihead_qkv(h,
                                            n_state=shape_list(h)[-1],
                                            n_head=config.n_heads,
                                            train=train)
                    w = attn_weights(q, k, v, scale=True)

        if explain:
            explain_out = h_out[:, initial_shape[1]:]
            explain_out = tf.reshape(
                explain_out,
                shape=tf.concat((initial_shape[:-1], [config.n_embed]), 0))
            h_out = h_out[:, :initial_shape[1]]

        # Use hidden state at classifier token as input to final proj. + softmax
        clf_h = tf.reshape(h_out,
                           [-1, config.n_embed])  # [batch * seq_len, embed]
        clf_h = tf.gather(
            clf_h,
            tf.range(shape_list(X)[0], dtype=tf.int32) * sequence_length +
            pool_idx,
        )
        clf_h = tf.reshape(clf_h,
                           shape=tf.concat(
                               (initial_shape[:-2], [config.n_embed]), 0))
        seq_feats = tf.reshape(h_out,
                               shape=tf.concat(
                                   (initial_shape[:-1], [config.n_embed]), 0))

        lengths = lengths_from_eos_idx(eos_idx=pool_idx,
                                       max_length=sequence_length)

        out = {
            "embed_weights": embed_weights,
            "features": clf_h,
            "sequence_features": seq_feats,
            "eos_idx": pool_idx,
            "lengths": lengths,
            "attention_weights": w,  # [n_heads, seq_len, seq_len]
        }
        if explain:
            out["explain_out"] = explain_out
        return out
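The attention_weights entry returned above is, in essence, the softmaxed scaled dot-product score matrix of the last layer. A self-contained NumPy sketch of that computation (illustrative shapes; this is not the library's attn_weights function):

import numpy as np

def scaled_attention_weights(q, k, scale=True):
    """Return softmax(q @ k^T [/ sqrt(head_dim)]) over the key dimension."""
    scores = q @ k.transpose(0, 1, 3, 2)                   # [batch, heads, seq, seq]
    if scale:
        scores = scores / np.sqrt(q.shape[-1])
    scores = scores - scores.max(axis=-1, keepdims=True)   # numerical stability
    weights = np.exp(scores)
    return weights / weights.sum(axis=-1, keepdims=True)

batch, n_heads, seq_len, head_dim = 1, 2, 5, 8
q = np.random.randn(batch, n_heads, seq_len, head_dim)
k = np.random.randn(batch, n_heads, seq_len, head_dim)
w = scaled_attention_weights(q, k)
assert np.allclose(w.sum(axis=-1), 1.0)                    # each row is a distribution over keys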
Example 5
def association(hidden,
                pool_idx,
                targets,
                n_targets,
                config,
                train=False,
                reuse=None,
                **kwargs):
    """
    An attention-based sequence labeler model with association.

    :param hidden: The output of the featurizer. [batch_size, sequence_length, embed_dim]
    :param pool_idx: the index of the classify tokens along the sequence dimension. [batch_size]
    :param targets: A dict containing:
     'labels': The sequence labeling targets. [batch_size, sequence_length],
     'associations': A matrix of class ids for the associations [batch_size, sequence_length, sequence_length]
    :param n_targets: A python int containing the number of classes that the model should be learning to predict over.
    :param config: A config object, containing all parameters for the featurizer.
    :param train: If this flag is true, dropout and losses are added to the graph.
    :param reuse: Should reuse be set within this scope.
    :param kwargs: Spare arguments.
    :return: dict containing:
        "logits": The un-normalised log probabilities of each class being in each location. For usable predictions,
            sampling from this distrobution is not sufficiant and a viterbi decoding method should be used.
        "losses": The negative log likelihood for the sequence targets.
        "predict_params": A dictionary of params to be fed to the viterbi decode function.
    """
    with tf.variable_scope('sequence-labeler', reuse=reuse):
        nx = config.n_embed
        length = config.max_length
        num_associations = len(config.association_types) + 1

        def seq_lab_internal(hidden):
            attn_fn = functools.partial(attn,
                                        scope="seq_label_attn",
                                        n_state=nx,
                                        n_head=config.seq_num_heads,
                                        resid_pdrop=config.resid_p_drop,
                                        attn_pdrop=config.attn_p_drop,
                                        train=train,
                                        scale=False,
                                        mask=False)
            n = norm(attn_fn(hidden) + hidden, 'seq_label_residual')
            flat_logits = tf.layers.dense(n, n_targets)
            logits = tf.reshape(
                flat_logits, tf.concat([tf.shape(hidden)[:2], [n_targets]], 0))

            association_head = tf.layers.dense(n, nx)
            association_head = tf.reshape(
                association_head, tf.concat([tf.shape(hidden)[:2], [nx]], 0))

            a = tf.expand_dims(association_head, 1)
            b = tf.expand_dims(association_head, 2)

            features = tf.concat(
                [
                    a - b,
                    a * b,
                    tf.tile(a, [1, length, 1, 1]),
                    tf.tile(b, [1, 1, length, 1]),
                    # TODO: Think about using prediction as a feature for associations.
                ],
                axis=-1)
            associations_flat = tf.layers.dense(
                tf.reshape(features, shape=[-1, nx * 4]), num_associations)
            associations = tf.reshape(associations_flat,
                                      [-1, length, length, num_associations])

            return logits, associations_flat, associations

        with tf.variable_scope('seq_lab_attn'):
            if config.low_memory_mode and train:
                seq_lab_internal = recompute_grad(seq_lab_internal,
                                                  use_entire_scope=True)

            logits, associations_flat, associations = seq_lab_internal(hidden)

        log_likelihood = 0.0
        association_loss = 0.0
        class_weights = kwargs.get('class_weights')
        if class_weights is not None:
            logits = class_reweighting(class_weights)(logits)

        transition_params = tf.get_variable("Transition_matrix",
                                            shape=[n_targets, n_targets])
        if train and targets is not None:
            log_likelihood, _ = crf_log_likelihood(
                logits,
                targets["labels"],
                kwargs.get('max_length') *
                tf.ones(tf.shape(targets["labels"])[0], dtype=tf.int32),
                transition_params=transition_params)
            sequence_mask = tf.sequence_mask(pool_idx + 1,
                                             maxlen=length,
                                             dtype=tf.float32)
            mask = tf.expand_dims(sequence_mask, 1) * tf.expand_dims(
                sequence_mask, 2)

            association_loss = tf.losses.sparse_softmax_cross_entropy(
                logits=associations_flat,
                labels=tf.reshape(targets["associations"], shape=[-1]),
                weights=tf.reshape(mask, shape=[-1]))

        return {
            'logits': {
                "sequence": logits,
                "association": associations
            },
            'losses': -log_likelihood + config.assocation_loss_weight *
            association_loss,  # TODO: think about weighting.
            'predict_params': {
                'transition_matrix': transition_params
            }
        }
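The association head scores every token pair by broadcasting the per-token representation against itself and concatenating the difference, the product, and the two tiled copies. A NumPy sketch of that pairwise feature construction, with hypothetical shapes:

import numpy as np

batch, length, nx = 2, 4, 3
association_head = np.random.randn(batch, length, nx)

a = association_head[:, None, :, :]        # like tf.expand_dims(..., 1): [batch, 1, length, nx]
b = association_head[:, :, None, :]        # like tf.expand_dims(..., 2): [batch, length, 1, nx]

features = np.concatenate(
    [
        a - b,
        a * b,
        np.broadcast_to(a, (batch, length, length, nx)),
        np.broadcast_to(b, (batch, length, length, nx)),
    ],
    axis=-1,
)
assert features.shape == (batch, length, length, 4 * nx)   # one 4*nx feature vector per token pair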
Example 6
def sequence_labeler(hidden,
                     targets,
                     n_targets,
                     config,
                     pad_id,
                     multilabel=False,
                     train=False,
                     reuse=None,
                     pool_idx=None,
                     **kwargs):
    """
    An attention-based sequence labeler model.

    In the case of unidirectional base models such as GPT this model takes the output of the pre-trained model,
    applies an additional randomly initialised multihead attention block, with residuals on top.
    The extra attention is not future masked to allow the model to label sequences based on context in both directions.
    The representations fed into this model are necessarily future masked because a language modelling loss is the
    original objective of the featurizer.

    For bidirectional base models we apply the crf model directly to the output of the base model.

    :param hidden: The output of the featurizer. [batch_size, sequence_length, embed_dim]
    :param targets: The placeholder representing the sequence labeling targets. [batch_size, sequence_length]
    :param n_targets: A python int containing the number of classes that the model should be learning to predict over.
    :param pad_id: Index of the pad class, used as the reference (negative) class in multilabel mode.
    :param multilabel: If true, each target class is treated as an independent binary CRF tagging problem.
    :param pool_idx: The index of the classify tokens along the sequence dimension, used as sequence lengths. [batch_size]
    :param config: A config object, containing all parameters for the featurizer.
    :param train: If this flag is true, dropout and losses are added to the graph.
    :param reuse: Should reuse be set within this scope.
    :param kwargs: Spare arguments.
    :return: dict containing:
        "logits": The un-normalised log probabilities of each class being in each location. For usable predictions,
            sampling from this distribution is not sufficient and a viterbi decoding method should be used.
        "losses": The negative log likelihood for the sequence targets.
        "predict_params": A dictionary of params to be fed to the viterbi decode function.
    """
    with tf.variable_scope('sequence-labeler', reuse=reuse):
        nx = config.n_embed

        def seq_lab_internal(hidden):
            if config.base_model.is_bidirectional:
                n = hidden
            else:
                attn_fn = functools.partial(attn,
                                            scope="seq_label_attn",
                                            n_state=nx,
                                            n_head=config.seq_num_heads,
                                            resid_pdrop=config.resid_p_drop,
                                            attn_pdrop=config.attn_p_drop,
                                            train=train,
                                            scale=False,
                                            mask=False)
                n = norm(attn_fn(hidden) + hidden, 'seq_label_residual')
            flat_logits = tf.layers.dense(n, n_targets)
            logits = tf.reshape(
                flat_logits, tf.concat([tf.shape(hidden)[:2], [n_targets]], 0))
            return logits

        with tf.variable_scope('seq_lab_attn'):
            if config.low_memory_mode and train:
                seq_lab_internal = recompute_grad(seq_lab_internal,
                                                  use_entire_scope=True)
            logits = seq_lab_internal(hidden)

        class_weights = kwargs.get('class_weights')
        if class_weights is not None and train:
            class_weights = tf.reshape(class_weights, [1, 1, -1])
            one_hot_class_weights = class_weights * tf.one_hot(targets,
                                                               depth=n_targets)
            per_token_weights = tf.reduce_sum(one_hot_class_weights,
                                              axis=-1,
                                              keep_dims=True)
            logits = class_reweighting(per_token_weights)(logits)

        log_likelihood = 0.0

        default_lengths = kwargs.get('max_length') * tf.ones(
            tf.shape(hidden)[0], dtype=tf.int32)
        if pool_idx is None:
            pool_idx = default_lengths
        else:
            pool_idx = tf.where(tf.equal(pool_idx, 0), default_lengths,
                                tf.cast(pool_idx, dtype=tf.int32))

        with tf.device("CPU:0"):
            if multilabel:
                transition_params = []
                logits_individual = tf.unstack(logits, n_targets, axis=-1)
                if targets is not None:
                    targets_individual = tf.unstack(targets,
                                                    n_targets,
                                                    axis=-1)
                logits = []
                for i in range(n_targets):
                    transition_params.append(
                        tf.get_variable("Transition_matrix_{}".format(i),
                                        shape=[2, 2]))
                    logits.append(
                        tf.stack(
                            (logits_individual[pad_id], logits_individual[i]),
                            axis=-1))
                    if targets is not None and train and i != pad_id:
                        log_likelihood += crf_log_likelihood(
                            logits[-1],
                            targets_individual[i],
                            pool_idx,
                            transition_params=transition_params[-1])[0]
                logits = tf.stack(logits, axis=-1)
            else:
                transition_params = tf.get_variable(
                    "Transition_matrix", shape=[n_targets, n_targets])
                if train and targets is not None:
                    log_likelihood, _ = crf_log_likelihood(
                        logits,
                        targets,
                        pool_idx,
                        transition_params=transition_params)

        return {
            'logits': logits,
            'losses': -log_likelihood,
            'predict_params': {
                'transition_matrix': transition_params
            }
        }
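As both sequence-labeler docstrings note, usable predictions require Viterbi decoding over the emitted logits and the learned transition matrix rather than sampling. A minimal NumPy Viterbi decoder for a single sequence, equivalent in spirit to tf.contrib.crf.viterbi_decode (illustrative, not the library call used at predict time):

import numpy as np

def viterbi_decode(score, transition_params):
    """Return the highest-scoring tag sequence and its score.

    score: [seq_len, n_targets] unary potentials (logits) for one sequence.
    transition_params: [n_targets, n_targets] pairwise transition scores.
    """
    trellis = np.zeros_like(score)
    backpointers = np.zeros_like(score, dtype=np.int32)
    trellis[0] = score[0]
    for t in range(1, score.shape[0]):
        v = trellis[t - 1][:, None] + transition_params    # [from_tag, to_tag]
        trellis[t] = score[t] + v.max(axis=0)
        backpointers[t] = v.argmax(axis=0)

    best_path = [int(trellis[-1].argmax())]
    for bp in reversed(backpointers[1:]):
        best_path.append(int(bp[best_path[-1]]))
    best_path.reverse()
    return best_path, float(trellis[-1].max())

tags, best_score = viterbi_decode(np.random.randn(6, 4), np.random.randn(4, 4))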
Example 7
def gpt_featurizer(X, encoder, config, train=False, reuse=None):
    """
    The transformer element of the finetuning model. Maps from tokens ids to a dense, embedding of the sequence.

    :param X: A tensor of token indexes with shape [batch_size, sequence_length, token_idx]
    :param encoder: A TextEncoder object.
    :param config: A config object, containing all parameters for the featurizer.
    :param train: If this flag is true, dropout and losses are added to the graph.
    :param reuse: Should reuse be set within this scope.
    :return: A dict containing;
        embed_weights: the word embedding matrix.
        features: The output of the featurizer_final state.
        sequence_features: The output of the featurizer at each timestep.
    """
    initial_shape = tf.shape(X)
    X = tf.reshape(X, shape=tf.concat(([-1], initial_shape[-2:]), 0))

    with tf.variable_scope('model/featurizer', reuse=reuse):
        embed_weights = tf.get_variable(
            name="we",
            shape=[encoder.vocab_size + config.max_length, config.n_embed],
            initializer=tf.random_normal_initializer(stddev=config.weight_stddev)
        )
        if config.train_embeddings:
            embed_weights = dropout(embed_weights, config.embed_p_drop, train)
        else:
            embed_weights = tf.stop_gradient(embed_weights)

        X = tf.reshape(X, [-1, config.max_length, 2])

        h = embed(X, embed_weights)
        for layer in range(config.n_layer):
            if (config.n_layer - layer) == config.num_layers_trained and config.num_layers_trained != config.n_layer:
                h = tf.stop_gradient(h)
                train_layer = False
            else:
                train_layer = train

            with tf.variable_scope('h%d_' % layer):
                block_fn = functools.partial(block, n_head=config.n_heads, act_fn=config.act_fn,
                                             resid_pdrop=config.resid_p_drop, attn_pdrop=config.attn_p_drop,
                                             scope='h%d' % layer, train=train_layer, scale=True)
                if config.low_memory_mode and train_layer:
                    block_fn = recompute_grad(block_fn, use_entire_scope=True)
                if layer < config.n_layer - 1:
                    h = block_fn(h)
                else:
                    h_out = block_fn(h)

            # get the attention weights from the last layer
            if layer == config.n_layer - 1:
                with tf.variable_scope('h%d_/h%d/attn' % (layer, layer), reuse=True):
                    q, k, v = multihead_qkv(h, n_state=shape_list(h)[-1], n_head=config.n_heads, train=train)
                    w = attn_weights(q, k, v, attn_pdrop=config.attn_p_drop, train=train_layer, scale=True)


        # Use hidden state at classifier token as input to final proj. + softmax
        clf_h = tf.reshape(h_out, [-1, config.n_embed])  # [batch * seq_len, embed]
        clf_token = encoder['_classify_']
        pool_idx = tf.cast(tf.argmax(tf.cast(tf.equal(X[:, :, 0], clf_token), tf.float32), 1), tf.int32)
        clf_h = tf.gather(clf_h, tf.range(shape_list(X)[0], dtype=tf.int32) * config.max_length + pool_idx)
        clf_h = tf.reshape(clf_h, shape=tf.concat((initial_shape[: -2], [config.n_embed]), 0))
        seq_feats = tf.reshape(h_out, shape=tf.concat((initial_shape[:-1], [config.n_embed]), 0))

        return {
            'embed_weights': embed_weights,
            'features': clf_h,
            'sequence_features': seq_feats,
            'pool_idx': pool_idx,
            'attention_weights': w  # [n_heads, seq_len, seq_len]
        }
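The pool_idx trick used throughout these featurizers finds the first occurrence of the classify token in each row by taking argmax over a 0/1 equality mask (argmax returns the first maximal position). The same computation in NumPy, with hypothetical token ids:

import numpy as np

clf_token = 9                                     # hypothetical id of the '_classify_' token
token_ids = np.array([[5, 7, 9, 0, 0],
                      [3, 9, 0, 0, 0]])           # corresponds to X[:, :, 0] in the featurizer

# argmax over the equality mask gives the index of the first match in each row
pool_idx = np.argmax((token_ids == clf_token).astype(np.float32), axis=1).astype(np.int32)
print(pool_idx)                                   # [2 1]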