Example #1
def bplayer_test1():
    with tf.variable_scope('b') as b_scope:
        w1 = tf.get_variable('w1', shape=[1], initializer=tf.ones_initializer)
        bplayer_b = BPLayer(w1, b_scope)

    with tf.variable_scope('c') as c_scope:
        w2 = tf.get_variable('w2', shape=[1], initializer=tf.ones_initializer)
        y = w1 + w2
        bplayer = BPLayer(y, c_scope, [bplayer_b])

    trainable_vars_b = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                         b_scope.name)
    trainable_vars_c = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                         c_scope.name)
    print(trainable_vars_b)
    print(trainable_vars_c)

    with tf.variable_scope("bplayer_network") as bp_scope:
        opt_bp = tf.train.GradientDescentOptimizer(0.01)
        with tf.variable_scope("backward_gradients"):
            grads_vals_bp = bplayer.backward_gradients()
        train_op_bp = opt_bp.apply_gradients(grads_vals_bp)
        grads_bp, vals_bp, grad_names_bp, val_names_bp = separate_grads_vals(
            grads_vals_bp)
        print(val_names_bp)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        result_origin = sess.run([train_op_bp] + grads_bp + vals_bp)
        print(result_origin[1:])
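
These examples only show BPLayer at its call sites. As a reading aid, here is a minimal interface sketch implied by those call sites; the attribute names, the None-filtering, and the body of backward_gradients() are assumptions, not the repo's actual implementation (which presumably recomputes activations scope by scope):

import tensorflow as tf

class BPLayer(object):
    """Sketch of the assumed interface; NOT the repo's implementation."""

    def __init__(self, output, scope, backward_layers=None):
        # `output`: the tensor produced inside `scope`.
        # `scope`: the tf.variable_scope owning this layer's variables.
        # `backward_layers`: preceding BPLayer(s) in the backward chain
        # (assumed to tolerate None entries, as in Example #2).
        self.output = output
        self.scope = scope
        self.backward_layers = [l for l in (backward_layers or []) if l]

    def add_backward_layer(self, bplayer):
        # Used in Example #4 to link layers created in a loop.
        self.backward_layers.append(bplayer)

    def backward_gradients(self):
        # Used in Example #1. Trivial stand-in: differentiate the output
        # w.r.t. all trainable variables and return (grad, var) pairs for
        # Optimizer.apply_gradients().
        var_list = tf.trainable_variables()
        grads = tf.gradients(self.output, var_list)
        return [(g, v) for g, v in zip(grads, var_list) if g is not None]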
Example #2
def embedding_lookup(input_ids,
                     vocab_size,
                     embedding_size=128,
                     initializer_range=0.02,
                     word_embedding_name="word_embeddings",
                     use_one_hot_embeddings=False,
                     previor_bplayer=None,
                     name_or_scope=None):
    """
    Looks up word embeddings for an id tensor.
    Args:
        input_ids: int32 Tensor of shape [batch_size, seq_length] containing word
          ids.
        vocab_size: int. Size of the embedding vocabulary.
        embedding_size: int. Width of the word embeddings.
        initializer_range: float. Embedding initialization range.
        word_embedding_name: string. Name of the embedding table.
        use_one_hot_embeddings: bool. If True, use one-hot method for word
          embeddings. If False, use `tf.gather()`.
        previor_bplayer: (optional) BPLayer that precedes this lookup in the
          backward chain.
        name_or_scope: (optional) string or variable scope for this op.
    Returns:
        Tuple of (output, bplayer_output, embedding_table, bplayer_table);
        `output` is a float Tensor of shape [batch_size, seq_length,
        embedding_size].
    """
    # This function assumes that the input is of shape [batch_size, seq_length,
    # num_inputs].
    #
    # If the input is a 2D tensor of shape [batch_size, seq_length], we
    # reshape to [batch_size, seq_length, 1].
    with tf.variable_scope(name_or_scope, default_name='embedding_lookup'):

        with tf.variable_scope("table") as table_scope:
            embedding_table = tf.get_variable(
                name=word_embedding_name,
                shape=[vocab_size, embedding_size],
                initializer=utils.create_initializer(initializer_range))
            bplayer_table = BPLayer(embedding_table, table_scope,
                                    [previor_bplayer])

        with tf.variable_scope("lookup") as lookup_scope:
            if input_ids.shape.ndims == 2:
                input_ids = tf.expand_dims(input_ids, axis=[-1])
            flat_input_ids = tf.reshape(input_ids, [-1])
            if use_one_hot_embeddings:
                one_hot_input_ids = tf.one_hot(flat_input_ids,
                                               depth=vocab_size)
                output = tf.matmul(one_hot_input_ids, embedding_table)
            else:
                output = tf.gather(embedding_table, flat_input_ids)
            input_shape = utils.get_shape_list(input_ids)
            output = tf.reshape(
                output, input_shape[0:-1] + [input_shape[-1] * embedding_size])
            bplayer_output = BPLayer(output, lookup_scope, [bplayer_table])
        return (output, bplayer_output, embedding_table, bplayer_table)
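
A minimal usage sketch for this helper, assuming a 2-D int32 id batch as in the docstring; the placeholder shapes and vocab size below are illustrative assumptions, not values from the repo:

ids = tf.placeholder(tf.int32, shape=[8, 128])  # [batch_size, seq_length]
output, bplayer_output, table, bplayer_table = embedding_lookup(
    ids, vocab_size=30522, embedding_size=128)
# `output` has shape [8, 128, 128]; `bplayer_output` chains back through
# `bplayer_table` to `previor_bplayer` (None here).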
Example #3
def network(inputs, targets):
    with tf.variable_scope("network") as scope:
        with tf.variable_scope("w") as scope_w:
            w = tf.get_variable("w", [1, 2],
                                dtype=tf.float32,
                                initializer=tf.ones_initializer)
            bplayer_w = BPLayer(w, scope_w)
        with tf.variable_scope("add") as scope_calc:
            b = tf.get_variable("b", [1],
                                dtype=tf.float32,
                                initializer=tf.zeros_initializer)
            y = tf.matmul(w, tf.expand_dims(inputs, axis=-1)) + b
            y = tf.squeeze(y, axis=-1)
            loss = tf.losses.mean_squared_error(targets, y)
            loss = tf.reduce_mean(loss)
            bplayer = BPLayer(loss, scope_calc, [bplayer_w])
        return loss, bplayer, [w, b], y
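
Feeding this network into the optimizer pattern from Example #1 might look like the sketch below; the placeholder shapes and the learning rate are illustrative assumptions:

inputs = tf.placeholder(tf.float32, shape=[None, 2])
targets = tf.placeholder(tf.float32, shape=[None, 1])
loss, bplayer, variables, y = network(inputs, targets)

opt = tf.train.GradientDescentOptimizer(0.01)
# As in Example #1, the gradients come from the BPLayer chain rather than
# from tf.gradients on the loss directly.
train_op = opt.apply_gradients(bplayer.backward_gradients())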
Example #4
def network(input, labels, num_layers):
    with tf.variable_scope("network") as network_scope:
        with tf.variable_scope("prepare") as prepare_scope:
            prev_output = input
            prev_bplayer = BPLayer(prev_output, prepare_scope)
        for i in range(num_layers):
            with tf.variable_scope("layer_%s" % i) as layernum_scope:
                prev_output, bplayer = layer_revbp(prev_output)
                bplayer.add_backward_layer(prev_bplayer)
                prev_bplayer = bplayer
        with tf.variable_scope("loss") as loss_scope:
            y = tf.layers.dense(prev_output,
                                units=1,
                                kernel_initializer=tf.ones_initializer)
            y = tf.squeeze(y, axis=-1)
            loss = tf.losses.mean_squared_error(labels, y)
            bplayer = BPLayer(loss, loss_scope)
            bplayer.add_backward_layer(prev_bplayer)
        var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                     network_scope.name)
    return loss, bplayer, var_list, y
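
`layer_revbp` is defined elsewhere in the repo; this call site implies it takes the previous layer's output and returns an (output, BPLayer) pair whose backward link is attached afterwards via add_backward_layer. A hypothetical stand-in with that contract, for reading purposes only:

def layer_revbp(inputs):
    # Stand-in only: the real implementation presumably builds a
    # reversible (RevNet-style) block, not a plain dense layer.
    with tf.variable_scope("revbp") as scope:
        outputs = tf.layers.dense(inputs, units=int(inputs.shape[-1]))
        # The caller attaches backward links via add_backward_layer().
        return outputs, BPLayer(outputs, scope)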
Example #5
def get_masked_lm_output(bert_config, input_tensor, output_weights, positions,
                         label_ids, label_weights, prev_bplayers=None):
    """Get loss and log probs for the masked LM."""
    input_tensor = gather_indexes(input_tensor, positions)

    with tf.variable_scope("cls/predictions") as prediction_scope:
        # We apply one more non-linear transformation before the output layer.
        # This matrix is not used after pre-training.
        with tf.variable_scope("transform"):
            input_tensor = tf.layers.dense(
                input_tensor,
                units=bert_config.hidden_size,
                activation=utils.get_activation(bert_config.hidden_act),
                kernel_initializer=utils.create_initializer(
                    bert_config.initializer_range))
            input_tensor = utils.layer_norm(input_tensor)

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        output_bias = tf.get_variable(
            "output_bias",
            shape=[bert_config.vocab_size],
            initializer=tf.zeros_initializer())
        logits = tf.matmul(input_tensor, output_weights, transpose_b=True)
        logits = tf.nn.bias_add(logits, output_bias)
        log_probs = tf.nn.log_softmax(logits, axis=-1)

        label_ids = tf.reshape(label_ids, [-1])
        label_weights = tf.reshape(label_weights, [-1])

        one_hot_labels = tf.one_hot(
            label_ids, depth=bert_config.vocab_size, dtype=tf.float32)

        # The `positions` tensor might be zero-padded (if the sequence is too
        # short to have the maximum number of predictions). The `label_weights`
        # tensor has a value of 1.0 for every real prediction and 0.0 for the
        # padding predictions.
        per_example_loss = -tf.reduce_sum(log_probs * one_hot_labels, axis=[-1])
        numerator = tf.reduce_sum(label_weights * per_example_loss)
        denominator = tf.reduce_sum(label_weights) + 1e-5
        loss = numerator / denominator
        loss_bplayer = BPLayer(loss, prediction_scope, backward_layers=prev_bplayers)

    return (loss, per_example_loss, log_probs, loss_bplayer)
Example #6
def get_next_sentence_output(bert_config, input_tensor, labels, prev_bplayers=None):
    """Get loss and log probs for the next sentence prediction."""

    # Simple binary classification. Note that 0 is "next sentence" and 1 is
    # "random sentence". This weight matrix is not used after pre-training.
    with tf.variable_scope("cls/seq_relationship") as scope:
        output_weights = tf.get_variable(
            "output_weights",
            shape=[2, bert_config.hidden_size],
            initializer=utils.create_initializer(bert_config.initializer_range))
        output_bias = tf.get_variable(
            "output_bias", shape=[2], initializer=tf.zeros_initializer())

        logits = tf.matmul(input_tensor, output_weights, transpose_b=True)
        logits = tf.nn.bias_add(logits, output_bias)
        log_probs = tf.nn.log_softmax(logits, axis=-1)
        labels = tf.reshape(labels, [-1])
        one_hot_labels = tf.one_hot(labels, depth=2, dtype=tf.float32)
        per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
        loss = tf.reduce_mean(per_example_loss)
        loss_bplayer = BPLayer(loss, scope, backward_layers=prev_bplayers)
        return (loss, per_example_loss, log_probs, loss_bplayer)
Example #7
def embedding_postprocessor(input_tensor,
                            use_token_type=False,
                            token_type_ids=None,
                            token_type_vocab_size=16,
                            token_type_embedding_name="token_type_embeddings",
                            use_position_embeddings=True,
                            position_embedding_name="position_embeddings",
                            initializer_range=0.02,
                            max_position_embeddings=512,
                            dropout_prob=0.1,
                            previor_bplayer=None,
                            name_or_scope=None):
    """
    Performs various post-processing on a word embedding tensor.
    Args:
      input_tensor: float Tensor of shape [batch_size, seq_length,
        embedding_size].
      use_token_type: bool. Whether to add embeddings for `token_type_ids`.
      token_type_ids: (optional) int32 Tensor of shape [batch_size, seq_length].
        Must be specified if `use_token_type` is True.
      token_type_vocab_size: int. The vocabulary size of `token_type_ids`.
      token_type_embedding_name: string. The name of the embedding table variable
        for token type ids.
      use_position_embeddings: bool. Whether to add position embeddings for the
        position of each token in the sequence.
      position_embedding_name: string. The name of the embedding table variable
        for positional embeddings.
      initializer_range: float. Range of the weight initialization.
      max_position_embeddings: int. Maximum sequence length that might ever be
        used with this model. This can be longer than the sequence length of
        input_tensor, but cannot be shorter.
      dropout_prob: float. Dropout probability applied to the final output tensor.
      previor_bplayer: (optional) BPLayer that precedes this op in the
        backward chain.
      name_or_scope: (optional) string or variable scope for this op.
    Returns:
      Tuple of (output, bplayer_output); `output` is a float tensor with the
      same shape as `input_tensor`.
    Raises:
      ValueError: One of the tensor shapes or input values is invalid.
    """
    with tf.variable_scope(name_or_scope,
                           default_name='embedding_postprocessor') as scope:
        input_shape = utils.get_shape_list(input_tensor, expected_rank=3)
        batch_size = input_shape[0]
        seq_length = input_shape[1]
        width = input_shape[2]

        output = input_tensor

        if use_token_type:
            if token_type_ids is None:
                raise ValueError("`token_type_ids` must be specified if "
                                 "`use_token_type` is True.")
            token_type_table = tf.get_variable(
                name=token_type_embedding_name,
                shape=[token_type_vocab_size, width],
                initializer=utils.create_initializer(initializer_range))
            # This vocab will be small so we always do one-hot here, since it is always
            # faster for a small vocabulary.
            flat_token_type_ids = tf.reshape(token_type_ids, [-1])
            one_hot_ids = tf.one_hot(flat_token_type_ids,
                                     depth=token_type_vocab_size)
            token_type_embeddings = tf.matmul(one_hot_ids, token_type_table)
            token_type_embeddings = tf.reshape(token_type_embeddings,
                                               [batch_size, seq_length, width])
            output += token_type_embeddings

        if use_position_embeddings:
            assert_op = tf.assert_less_equal(seq_length,
                                             max_position_embeddings)
            with tf.control_dependencies([assert_op]):
                full_position_embeddings = tf.get_variable(
                    name=position_embedding_name,
                    shape=[max_position_embeddings, width],
                    initializer=utils.create_initializer(initializer_range))
                # Since the position embedding table is a learned variable, we create it
                # using a (long) sequence length `max_position_embeddings`. The actual
                # sequence length might be shorter than this, for faster training of
                # tasks that do not have long sequences.
                #
                # So `full_position_embeddings` is effectively an embedding table
                # for position [0, 1, 2, ..., max_position_embeddings-1], and the current
                # sequence has positions [0, 1, 2, ... seq_length-1], so we can just
                # perform a slice.
                position_embeddings = tf.slice(full_position_embeddings,
                                               [0, 0], [seq_length, -1])
                num_dims = len(output.shape.as_list())

                # Only the last two dimensions are relevant (`seq_length` and `width`), so
                # we broadcast among the first dimensions, which is typically just
                # the batch size.
                position_broadcast_shape = []
                for _ in range(num_dims - 2):
                    position_broadcast_shape.append(1)
                position_broadcast_shape.extend([seq_length, width])
                position_embeddings = tf.reshape(position_embeddings,
                                                 position_broadcast_shape)
                output += position_embeddings

        output = utils.layer_norm_and_dropout(output, dropout_prob)
        bplayer_output = BPLayer(output, scope, [previor_bplayer])
    return output, bplayer_output
Example #8
def transformer_model(input_tensor,
                      input_bplayer,
                      attention_mask=None,
                      hidden_size=768,
                      num_hidden_layers=12,
                      num_attention_heads=12,
                      intermediate_size=3072,
                      intermediate_act_fn=utils.gelu,
                      hidden_dropout_prob=0.1,
                      attention_probs_dropout_prob=0.1,
                      initializer_range=0.02,
                      do_return_all_layers=False):
    """
    Multi-headed, multi-layer Transformer from "Attention is All You Need".
    This is almost an exact implementation of the original Transformer
    encoder, with reversible (RevNet-style) layers added.
    See the original paper:
    https://arxiv.org/abs/1706.03762
    Also see:
    https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/transformer.py
    Args:
      input_tensor: float Tensor of shape [batch_size, seq_length, hidden_size].
      input_bplayer: BPLayer that produced `input_tensor`.
      attention_mask: (optional) int32 Tensor of shape [batch_size, seq_length,
        seq_length], with 1 for positions that can be attended to and 0 in
        positions that should not be.
      hidden_size: int. Hidden size of the Transformer.
      num_hidden_layers: int. Number of layers (blocks) in the Transformer.
      num_attention_heads: int. Number of attention heads in the Transformer.
      intermediate_size: int. The size of the "intermediate" (a.k.a., feed
        forward) layer.
      intermediate_act_fn: function. The non-linear activation function to apply
        to the output of the intermediate/feed-forward layer.
      hidden_dropout_prob: float. Dropout probability for the hidden layers.
      attention_probs_dropout_prob: float. Dropout probability of the attention
        probabilities.
      initializer_range: float. Range of the initializer (stddev of truncated
        normal).
      do_return_all_layers: Whether to also return all layers or just the final
        layer.
    Returns:
      Tuple of (float Tensor of shape [batch_size, seq_length, hidden_size],
      BPLayer): the final hidden layer of the Transformer and its BPLayer. If
      `do_return_all_layers` is True, a list of such tuples, one per layer.
    Raises:
      ValueError: A Tensor shape or parameter is invalid.
    """
    if hidden_size % num_attention_heads != 0:
        raise ValueError(
            "The hidden size (%d) is not a multiple of the number of attention "
            "heads (%d)" % (hidden_size, num_attention_heads))
    with tf.variable_scope("transfomer_model"):
        with tf.variable_scope("prepare") as prepare_scope:
            attention_head_size = int(hidden_size / num_attention_heads)
            input_shape = utils.get_shape_list(input_tensor, expected_rank=3)
            batch_size = input_shape[0]
            seq_length = input_shape[1]
            input_width = input_shape[2]
            # The Transformer performs sum residuals on all layers so the input needs
            # to be the same as the hidden size.
            if input_width != hidden_size:
                raise ValueError(
                    "The width of the input tensor (%d) != hidden size (%d)" %
                    (input_width, hidden_size))
            # We keep the representation as a 2D tensor to avoid re-shaping it back and
            # forth from a 3D tensor to a 2D tensor. Re-shapes are normally free on
            # the GPU/CPU but may not be free on the TPU, so we want to minimize them to
            # help the optimizer.
            prev_output = utils.reshape_to_matrix(
                input_tensor)  # [batch_size * seq_length, input_width]
            prev_bplayer = BPLayer(prev_output, prepare_scope, [input_bplayer])
        all_layer_outputs = []
        all_layer_bplayers = []
        for layer_idx in range(num_hidden_layers):
            with tf.variable_scope("layer_%d" % layer_idx):
                layer_input = prev_output
                layer_output, bplayer = rev_transformer_layer(
                    layer_input, [prev_bplayer], batch_size, seq_length,
                    attention_head_size, attention_mask, num_attention_heads,
                    intermediate_size, intermediate_act_fn,
                    hidden_dropout_prob, attention_probs_dropout_prob,
                    initializer_range)
                all_layer_outputs.append(layer_output)
                prev_output = layer_output
                all_layer_bplayers.append(bplayer)
                prev_bplayer = bplayer
        with tf.variable_scope("output") as output_scope:
            if do_return_all_layers:
                if len(all_layer_outputs) != len(all_layer_bplayers):
                    raise ValueError(
                        "transformer model: the number of layer outputs does "
                        "not equal the number of layer bplayers")
                final_outputs = []
                for i in range(len(all_layer_outputs)):
                    final_output = utils.reshape_from_matrix(
                        all_layer_outputs[i], input_shape)
                    final_bplayer = BPLayer(final_output, output_scope,
                                            [all_layer_bplayers[i]])
                    final_outputs.append((final_output, final_bplayer))
                return final_outputs
            else:
                final_output = utils.reshape_from_matrix(
                    prev_output, input_shape)
                final_bplayer = BPLayer(final_output, output_scope,
                                        [prev_bplayer])
                return (final_output, final_bplayer)
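
Because the return type depends on `do_return_all_layers`, callers unpack it in one of two ways (Example #9 below uses the list form). A sketch, assuming `embeddings` and `embeddings_bplayer` come from the embedding helpers above:

# do_return_all_layers=True: a list of (layer_output, bplayer) tuples.
all_layers = transformer_model(embeddings, embeddings_bplayer,
                               do_return_all_layers=True)
sequence_output, sequence_output_bplayer = all_layers[-1]

# do_return_all_layers=False: only the final (output, bplayer) pair.
final_output, final_bplayer = transformer_model(embeddings, embeddings_bplayer)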
Example #9
    def __init__(self,
                 config,
                 is_training,
                 input_ids,
                 input_mask=None,
                 token_type_ids=None,
                 use_one_hot_embeddings=False,
                 scope=None):
        """Constructor for BertModel.

        Args:
          config: `BertConfig` instance.
          is_training: bool. true for training model, false for eval model. Controls
            whether dropout will be applied.
          input_ids: int32 Tensor of shape [batch_size, seq_length].
          input_mask: (optional) int32 Tensor of shape [batch_size, seq_length].
          token_type_ids: (optional) int32 Tensor of shape [batch_size, seq_length].
          use_one_hot_embeddings: (optional) bool. Whether to use one-hot word
            embeddings or `tf.gather()` for the word embeddings.
          scope: (optional) variable scope. Defaults to "revbert".

        Raises:
          ValueError: The config is invalid or one of the input tensor shapes
            is invalid.
        """
        config = copy.deepcopy(config)
        if not is_training:
            config.hidden_dropout_prob = 0.0
            config.attention_probs_dropout_prob = 0.0

        input_shape = utils.get_shape_list(input_ids, expected_rank=2)
        batch_size = input_shape[0]
        seq_length = input_shape[1]

        if input_mask is None:
            input_mask = tf.ones(shape=[batch_size, seq_length],
                                 dtype=tf.int32)

        if token_type_ids is None:
            token_type_ids = tf.zeros(shape=[batch_size, seq_length],
                                      dtype=tf.int32)

        with tf.variable_scope(scope, default_name="revbert"):
            with tf.variable_scope("embeddings"):
                # Perform embedding lookup on the word ids.
                (self.embedding_output, self.embedding_output_bplayer,
                 self.embedding_table,
                 self.embedding_table_bplayer) = layers.embedding_lookup(
                     input_ids=input_ids,
                     vocab_size=config.vocab_size,
                     embedding_size=config.hidden_size,
                     initializer_range=config.initializer_range,
                     word_embedding_name="word_embeddings",
                     use_one_hot_embeddings=use_one_hot_embeddings)

                # Add positional embeddings and token type embeddings, then layer
                # normalize and perform dropout.
                self.embedding_output, self.embedding_output_bplayer = layers.embedding_postprocessor(
                    input_tensor=self.embedding_output,
                    use_token_type=True,
                    token_type_ids=token_type_ids,
                    token_type_vocab_size=config.type_vocab_size,
                    token_type_embedding_name="token_type_embeddings",
                    use_position_embeddings=True,
                    position_embedding_name="position_embeddings",
                    initializer_range=config.initializer_range,
                    max_position_embeddings=config.max_position_embeddings,
                    dropout_prob=config.hidden_dropout_prob,
                    previor_bplayer=self.embedding_output_bplayer)

            with tf.variable_scope("encoder"):
                # This converts a 2D mask of shape [batch_size, seq_length] to a 3D
                # mask of shape [batch_size, seq_length, seq_length] which is used
                # for the attention scores.
                attention_mask = utils.create_attention_mask_from_input_mask(
                    input_ids, input_mask)

                # Run the stacked transformer.
                # `sequence_output` shape = [batch_size, seq_length, hidden_size].
                self.all_encoder_layers = transformer_model(
                    input_tensor=self.embedding_output,
                    input_bplayer=self.embedding_output_bplayer,
                    attention_mask=attention_mask,
                    hidden_size=config.hidden_size,
                    num_hidden_layers=config.num_hidden_layers,
                    num_attention_heads=config.num_attention_heads,
                    intermediate_size=config.intermediate_size,
                    intermediate_act_fn=utils.get_activation(
                        config.hidden_act),
                    hidden_dropout_prob=config.hidden_dropout_prob,
                    attention_probs_dropout_prob=(
                        config.attention_probs_dropout_prob),
                    initializer_range=config.initializer_range,
                    do_return_all_layers=True)

            (self.sequence_output,
             self.sequence_output_bplayer) = self.all_encoder_layers[-1]
            # The "pooler" converts the encoded sequence tensor of shape
            # [batch_size, seq_length, hidden_size] to a tensor of shape
            # [batch_size, hidden_size]. This is necessary for segment-level
            # (or segment-pair-level) classification tasks where we need a fixed
            # dimensional representation of the segment.
            with tf.variable_scope("pooler") as pooler_scope:
                # We "pool" the model by simply taking the hidden state corresponding
                # to the first token. We assume that this has been pre-trained
                first_token_tensor = tf.squeeze(self.sequence_output[:,
                                                                     0:1, :],
                                                axis=1)
                self.pooled_output = tf.layers.dense(
                    first_token_tensor,
                    config.hidden_size,
                    activation=tf.tanh,
                    kernel_initializer=utils.create_initializer(
                        config.initializer_range))
                self.pooled_output_bplayer = BPLayer(
                    self.pooled_output, pooler_scope,
                    [self.sequence_output_bplayer])
Example #10
    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        """The `model_fn` for TPUEstimator."""

        tf.logging.info("*** Features ***")
        for name in sorted(features.keys()):
            tf.logging.info("  name = %s, shape = %s" % (name, features[name].shape))

        input_ids = features["input_ids"]
        input_mask = features["input_mask"]
        segment_ids = features["segment_ids"]
        masked_lm_positions = features["masked_lm_positions"]
        masked_lm_ids = features["masked_lm_ids"]
        masked_lm_weights = features["masked_lm_weights"]
        next_sentence_labels = features["next_sentence_labels"]

        is_training = (mode == tf.estimator.ModeKeys.TRAIN)

        model = revbert.RevBert(
            config=bert_config,
            is_training=is_training,
            input_ids=input_ids,
            input_mask=input_mask,
            token_type_ids=segment_ids,
            use_one_hot_embeddings=use_one_hot_embeddings)
        sequence_output, sequence_output_bplayer = model.get_sequence_output()
        pooled_output, pooled_output_bplayer = model.get_pooled_output()
        embedding_table, embedding_table_bplayer = model.get_embedding_table()

        (masked_lm_loss, masked_lm_example_loss,
         masked_lm_log_probs, masked_lm_loss_bplayer) = get_masked_lm_output(
            bert_config, sequence_output, embedding_table,
            masked_lm_positions, masked_lm_ids, masked_lm_weights, [sequence_output_bplayer, embedding_table_bplayer])

        (next_sentence_loss, next_sentence_example_loss,
         next_sentence_log_probs, next_sentence_loss_bplayer) = get_next_sentence_output(
            bert_config, pooled_output, next_sentence_labels, [pooled_output_bplayer])

        with tf.variable_scope("total_loss") as total_loss_scope:
            if use_random_next:
                total_loss = masked_lm_loss + next_sentence_loss
                total_loss_bplayer = BPLayer(total_loss, total_loss_scope,
                                             [sequence_output_bplayer, pooled_output_bplayer])
            else:
                total_loss = masked_lm_loss
                total_loss_bplayer = BPLayer(total_loss, total_loss_scope,
                                             [sequence_output_bplayer])

        tvars = tf.trainable_variables()

        initialized_variable_names = {}
        scaffold_fn = None
        if init_checkpoint:
            (assignment_map, initialized_variable_names
             ) = utils.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
            if use_tpu:

                def tpu_scaffold():
                    tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
                    return tf.train.Scaffold()

                scaffold_fn = tpu_scaffold
            else:
                tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

        tf.logging.info("**** Trainable Variables ****")
        for var in tvars:
            init_string = ""
            if var.name in initialized_variable_names:
                init_string = ", *INIT_FROM_CKPT*"
            tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                            init_string)

        output_spec = None
        if mode == tf.estimator.ModeKeys.TRAIN:
            train_op = optimization.create_optimizer_bplayer(
                total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu, total_loss_bplayer)

            output_spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                train_op=train_op,
                scaffold_fn=scaffold_fn)
        elif mode == tf.estimator.ModeKeys.EVAL:

            def metric_fn(masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids,
                          masked_lm_weights, next_sentence_example_loss,
                          next_sentence_log_probs, next_sentence_labels):
                """Computes the loss and accuracy of the model."""
                masked_lm_log_probs = tf.reshape(masked_lm_log_probs,
                                                 [-1, masked_lm_log_probs.shape[-1]])
                masked_lm_predictions = tf.argmax(
                    masked_lm_log_probs, axis=-1, output_type=tf.int64)
                masked_lm_example_loss = tf.reshape(masked_lm_example_loss, [-1])
                masked_lm_ids = tf.reshape(masked_lm_ids, [-1])
                masked_lm_weights = tf.reshape(masked_lm_weights, [-1])
                masked_lm_accuracy = tf.metrics.accuracy(
                    labels=masked_lm_ids,
                    predictions=masked_lm_predictions,
                    weights=masked_lm_weights)
                masked_lm_mean_loss = tf.metrics.mean(
                    values=masked_lm_example_loss, weights=masked_lm_weights)

                next_sentence_log_probs = tf.reshape(
                    next_sentence_log_probs, [-1, next_sentence_log_probs.shape[-1]])
                next_sentence_predictions = tf.argmax(
                    next_sentence_log_probs, axis=-1, output_type=tf.int64)
                next_sentence_labels = tf.reshape(next_sentence_labels, [-1])
                next_sentence_accuracy = tf.metrics.accuracy(
                    labels=next_sentence_labels, predictions=next_sentence_predictions)
                next_sentence_mean_loss = tf.metrics.mean(
                    values=next_sentence_example_loss)

                return {
                    "masked_lm_accuracy": masked_lm_accuracy,
                    "masked_lm_loss": masked_lm_mean_loss,
                    "next_sentence_accuracy": next_sentence_accuracy,
                    "next_sentence_loss": next_sentence_mean_loss,
                }

            eval_metrics = (metric_fn, [
                masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids,
                masked_lm_weights, next_sentence_example_loss,
                next_sentence_log_probs, next_sentence_labels
            ])
            output_spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                eval_metrics=eval_metrics,
                scaffold_fn=scaffold_fn)
        else:
            raise ValueError("Only TRAIN and EVAL modes are supported: %s" % (mode))

        return output_spec