Example #1
    def _get_masked_lm_output(self, inputs: pretrain_data.Inputs, model):
        """Masked language modeling softmax layer."""
        masked_lm_weights = inputs.masked_lm_weights
        with tf.variable_scope("generator_predictions"):
            if self._config.uniform_generator:
                logits = tf.zeros(self._bert_config.vocab_size)
                logits_tiled = tf.zeros(
                    modeling.get_shape_list(inputs.masked_lm_ids) +
                    [self._bert_config.vocab_size])
                logits_tiled += tf.reshape(
                    logits, [1, 1, self._bert_config.vocab_size])
                logits = logits_tiled
            else:
                relevant_hidden = pretrain_helpers.gather_positions(
                    model.get_sequence_output(), inputs.masked_lm_positions)
                hidden = tf.layers.dense(
                    relevant_hidden,
                    units=modeling.get_shape_list(
                        model.get_embedding_table())[-1],
                    activation=modeling.get_activation(
                        self._bert_config.hidden_act),
                    kernel_initializer=modeling.create_initializer(
                        self._bert_config.initializer_range))
                hidden = modeling.layer_norm(hidden)
                output_bias = tf.get_variable(
                    "output_bias",
                    shape=[self._bert_config.vocab_size],
                    initializer=tf.zeros_initializer())
                logits = tf.matmul(hidden,
                                   model.get_embedding_table(),
                                   transpose_b=True)
                logits = tf.nn.bias_add(logits, output_bias)

            oh_labels = tf.one_hot(inputs.masked_lm_ids,
                                   depth=self._bert_config.vocab_size,
                                   dtype=tf.float32)

            probs = tf.nn.softmax(logits)
            log_probs = tf.nn.log_softmax(logits)
            label_log_probs = -tf.reduce_sum(log_probs * oh_labels, axis=-1)

            numerator = tf.reduce_sum(inputs.masked_lm_weights *
                                      label_log_probs)
            denominator = tf.reduce_sum(masked_lm_weights) + 1e-6
            loss = numerator / denominator
            preds = tf.argmax(log_probs, axis=-1, output_type=tf.int32)

            MLMOutput = collections.namedtuple(
                "MLMOutput",
                ["logits", "probs", "loss", "per_example_loss", "preds"])
            return MLMOutput(logits=logits,
                             probs=probs,
                             per_example_loss=label_log_probs,
                             loss=loss,
                             preds=preds)
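The loss above is a weighted average of per-position negative log-likelihoods: masked_lm_weights zeroes out padded prediction slots, and the 1e-6 term guards against division by zero when nothing is masked. A minimal sketch of the same aggregation on toy values (the numbers below are illustrative, not from the source):

import tensorflow as tf

# -log p(true token) per masked slot, and 1/0 weights marking real slots
label_log_probs = tf.constant([[2.0, 1.0, 0.5], [3.0, 0.0, 0.0]])
masked_lm_weights = tf.constant([[1.0, 1.0, 0.0], [1.0, 0.0, 0.0]])
loss = (tf.reduce_sum(masked_lm_weights * label_log_probs) /
        (tf.reduce_sum(masked_lm_weights) + 1e-6))  # (2.0 + 1.0 + 3.0) / 3 = 2.0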
Example #2
def scatter_update(sequence, updates, positions):
    """Scatter-update a sequence.

  Args:
    sequence: A [batch_size, seq_len] or [batch_size, seq_len, depth] tensor
    updates: A tensor of size batch_size*seq_len(*depth)
    positions: A [batch_size, n_positions] tensor

  Returns: A tuple of two tensors. First is a [batch_size, seq_len] or
    [batch_size, seq_len, depth] tensor of "sequence" with elements at
    "positions" replaced by the values at "updates." Updates to index 0 are
    ignored. If there are duplicated positions the update is only applied once.
    Second is a [batch_size, seq_len] mask tensor of which inputs were updated.
  """
    shape = modeling.get_shape_list(sequence, expected_rank=[2, 3])
    depth_dimension = (len(shape) == 3)
    if depth_dimension:
        B, L, D = shape
    else:
        B, L = shape
        D = 1
        sequence = tf.expand_dims(sequence, -1)
    N = modeling.get_shape_list(positions)[1]

    shift = tf.expand_dims(L * tf.range(B), -1)
    flat_positions = tf.reshape(positions + shift, [-1, 1])
    flat_updates = tf.reshape(updates, [-1, D])
    updates = tf.scatter_nd(flat_positions, flat_updates, [B * L, D])
    updates = tf.reshape(updates, [B, L, D])

    flat_updates_mask = tf.ones([B * N], tf.int32)
    updates_mask = tf.scatter_nd(flat_positions, flat_updates_mask, [B * L])
    updates_mask = tf.reshape(updates_mask, [B, L])
    not_first_token = tf.concat(
        [tf.zeros((B, 1), tf.int32),
         tf.ones((B, L - 1), tf.int32)], -1)
    updates_mask *= not_first_token
    updates_mask_3d = tf.expand_dims(updates_mask, -1)

    # account for duplicate positions
    if sequence.dtype == tf.float32:
        updates_mask_3d = tf.cast(updates_mask_3d, tf.float32)
        updates /= tf.maximum(1.0, updates_mask_3d)
    else:
        assert sequence.dtype == tf.int32
        updates = tf.math.floordiv(updates, tf.maximum(1, updates_mask_3d))
    updates_mask = tf.minimum(updates_mask, 1)
    updates_mask_3d = tf.minimum(updates_mask_3d, 1)

    updated_sequence = (((1 - updates_mask_3d) * sequence) +
                        (updates_mask_3d * updates))
    if not depth_dimension:
        updated_sequence = tf.squeeze(updated_sequence, -1)

    return updated_sequence, updates_mask
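A quick illustration of the contract in the docstring on a hand-built batch, assuming the TF1 graph-mode APIs used throughout these examples (toy values):

import tensorflow as tf

sequence = tf.constant([[10, 11, 12, 13],
                        [20, 21, 22, 23]], tf.int32)  # [batch=2, seq_len=4]
updates = tf.constant([[90, 91],
                       [92, 93]], tf.int32)           # one update per position
positions = tf.constant([[1, 3],
                         [2, 0]], tf.int32)           # position 0 is ignored
updated, mask = scatter_update(sequence, updates, positions)
with tf.Session() as sess:
    print(sess.run([updated, mask]))
    # updated -> [[10, 90, 12, 91], [20, 21, 92, 23]]
    # mask    -> [[ 0,  1,  0,  1], [ 0,  0,  1,  0]]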
Example #3
    def _get_entropy_output(self, inputs: pretrain_data.Inputs, model):
        """Masked language modeling softmax layer."""
        with tf.variable_scope("cls/predictions", reuse=tf.AUTO_REUSE):
            hidden = tf.layers.dense(
                model.get_sequence_output(),
                units=modeling.get_shape_list(model.get_embedding_table())[-1],
                activation=modeling.get_activation(
                    self._bert_config.hidden_act),
                kernel_initializer=modeling.create_initializer(
                    self._bert_config.initializer_range))
            hidden = modeling.layer_norm(hidden)
            output_bias = tf.get_variable("output_bias",
                                          shape=[self._bert_config.vocab_size],
                                          initializer=tf.zeros_initializer())
            logits = tf.matmul(hidden,
                               model.get_embedding_table(),
                               transpose_b=True)
            logits = tf.nn.bias_add(logits, output_bias)

            probs = tf.nn.softmax(logits)
            log_probs = tf.nn.log_softmax(logits)
            entropy = -tf.reduce_sum(log_probs * probs, axis=[2])

            EntropyOutput = collections.namedtuple(
                "EntropyOutput", ["logits", "probs", "log_probs", "entropy"])
            return EntropyOutput(logits=logits,
                                 probs=probs,
                                 log_probs=log_probs,
                                 entropy=entropy)
Example #4
def gather_positions(sequence, positions):
  """Gathers the vectors at the specific positions over a minibatch.

  Args:
    sequence: A [batch_size, seq_length] or
        [batch_size, seq_length, depth] tensor of values
    positions: A [batch_size, n_positions] tensor of indices

  Returns: A [batch_size, n_positions] or
    [batch_size, n_positions, depth] tensor of the values at the indices
  """
  shape = modeling.get_shape_list(sequence, expected_rank=[2, 3])
  depth_dimension = (len(shape) == 3)
  if depth_dimension:
    B, L, D = shape
  else:
    B, L = shape
    D = 1
    sequence = tf.expand_dims(sequence, -1)
  position_shift = tf.expand_dims(L * tf.range(B), -1)
  flat_positions = tf.reshape(positions + position_shift, [-1])
  flat_sequence = tf.reshape(sequence, [B * L, D])
  gathered = tf.gather(flat_sequence, flat_positions)
  if depth_dimension:
    return tf.reshape(gathered, [B, -1, D])
  else:
    return tf.reshape(gathered, [B, -1])
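gather_positions is the counterpart lookup used by the masking helpers: it pulls out the token id (or hidden vector) at each requested position. A small sketch of its behavior on a rank-2 input (toy values):

import tensorflow as tf

sequence = tf.constant([[10, 11, 12, 13],
                        [20, 21, 22, 23]], tf.int32)  # [batch=2, seq_len=4]
positions = tf.constant([[1, 3],
                         [0, 2]], tf.int32)           # [batch=2, n_positions=2]
gathered = gather_positions(sequence, positions)      # -> [[11, 13], [20, 22]]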
Example #5
def sample_from_softmax(logits, disallow=None):
  if disallow is not None:
    logits -= 1000.0 * disallow
  uniform_noise = tf.random.uniform(
      modeling.get_shape_list(logits), minval=0, maxval=1)
  gumbel_noise = -tf.log(-tf.log(uniform_noise + 1e-9) + 1e-9)
  return tf.one_hot(tf.argmax(tf.nn.softmax(logits + gumbel_noise), -1,
                              output_type=tf.int32), logits.shape[-1])
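For reference, this is the Gumbel-max trick: argmax(logits + g) with g drawn from Gumbel(0, 1) is distributed exactly like a sample from softmax(logits), and the tf.nn.softmax call does not change the argmax because softmax is monotone. When gradients through the sample are not needed, the same draw could be taken directly with tf.random.categorical, as in this minimal sketch (toy logits):

import tensorflow as tf

logits = tf.constant([[0.1, 2.0, 0.3, 0.5]])  # [batch=1, vocab=4]
sample_ids = tf.random.categorical(logits, num_samples=1, dtype=tf.int32)
one_hot_sample = tf.one_hot(tf.squeeze(sample_ids, -1), depth=4)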
Example #6
def sample_from_top_k(logits,
                      temperature,
                      disallow=None,
                      straight_through=False,
                      k=20):
    logits_shape = modeling.get_shape_list(logits, expected_rank=[2, 3])
    depth_dimension = (len(logits_shape) == 3)
    if depth_dimension:
        reshape_logits = tf.reshape(logits, [-1, logits_shape[-1]])
    else:
        reshape_logits = logits

    values, _ = tf.nn.top_k(reshape_logits, k=k)
    min_values = values[:, -1, tf.newaxis]

    reshape_topk_logits = tf.where(
        reshape_logits < min_values,
        tf.ones_like(reshape_logits, dtype=logits.dtype) * -1e10,
        reshape_logits,
    )
    topk_logits = tf.reshape(reshape_topk_logits, logits_shape)
    if disallow is not None:
        topk_logits -= 1e10 * disallow
    uniform_noise = tf.random.uniform(modeling.get_shape_list(topk_logits),
                                      minval=0,
                                      maxval=1)
    gumbel_noise = -tf.log(-tf.log(uniform_noise + 1e-9) + 1e-9)

    gumbel_logits = (topk_logits + gumbel_noise) / temperature
    gumbel_probs = tf.nn.softmax(gumbel_logits)
    hard_token_ids = tf.one_hot(
        tf.argmax(gumbel_probs, axis=-1, output_type=tf.int32),
        topk_logits.shape[-1])

    if straight_through:
        gumbel_dense = tf.stop_gradient(hard_token_ids -
                                        gumbel_probs) + gumbel_probs
    else:
        gumbel_dense = gumbel_probs
    return gumbel_dense
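A short usage sketch of the function above on toy shapes: rank-3 logits are flattened to rank 2 for the top-k filtering, everything outside the top k is pushed to roughly -1e10 so it cannot be sampled, and with straight_through=True the output is a one-hot sample whose gradient is taken from the soft Gumbel-softmax distribution:

import tensorflow as tf

generator_logits = tf.random.normal([2, 3, 30])  # [batch, seq_len, vocab_size], toy values
sampled = sample_from_top_k(generator_logits, temperature=1.0,
                            straight_through=True, k=5)
token_ids = tf.argmax(sampled, axis=-1, output_type=tf.int32)  # [2, 3]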
Example #7
    def __init__(self, config: configure_finetuning.FinetuningConfig, tasks,
                 is_training, features, num_train_steps):
        # Create a shared transformer encoder
        bert_config = training_utils.get_bert_config(config)
        self.bert_config = bert_config
        if config.debug:
            bert_config.num_hidden_layers = 3
            bert_config.hidden_size = 144
            bert_config.intermediate_size = 144 * 4
            bert_config.num_attention_heads = 4

        # multi-choice mrc
        if any([isinstance(x, qa_tasks.MQATask) for x in tasks]):
            seq_len = config.max_seq_length
            assert seq_len <= bert_config.max_position_embeddings
            bs, total_len = modeling.get_shape_list(features["input_ids"],
                                                    expected_rank=2)
            to_shape = [
                bs * config.max_options_num * config.evidences_top_k, seq_len
            ]
            bert_model = modeling.BertModel(
                bert_config=bert_config,
                is_training=is_training,
                input_ids=tf.reshape(features["input_ids"], to_shape),
                input_mask=tf.reshape(features["input_mask"], to_shape),
                token_type_ids=tf.reshape(features["segment_ids"], to_shape),
                use_one_hot_embeddings=config.use_tpu,
                embedding_size=config.embedding_size)
        else:
            assert config.max_seq_length <= bert_config.max_position_embeddings
            bert_model = modeling.BertModel(
                bert_config=bert_config,
                is_training=is_training,
                input_ids=features["input_ids"],
                input_mask=features["input_mask"],
                token_type_ids=features["segment_ids"],
                use_one_hot_embeddings=config.use_tpu,
                embedding_size=config.embedding_size)
        percent_done = (
            tf.cast(tf.train.get_or_create_global_step(), tf.float32) /
            tf.cast(num_train_steps, tf.float32))

        # Add specific tasks
        self.outputs = {"task_id": features["task_id"]}
        losses = []
        for task in tasks:
            with tf.variable_scope("task_specific/" + task.name):
                task_losses, task_outputs = task.get_prediction_module(
                    bert_model, features, is_training, percent_done)
                losses.append(task_losses)
                self.outputs[task.name] = task_outputs
        self.loss = tf.reduce_sum(
            tf.stack(losses, -1) *
            tf.one_hot(features["task_id"], len(config.task_names)))
Example #8
def _get_masked_lm_output(inputs, model):
    """Masked language modeling softmax layer."""
    with tf.variable_scope("generator_predictions"):

        vocab_size = 21228  # hard-coded generator vocabulary size
        logits = tf.zeros(vocab_size)
        logits_tiled = tf.zeros(
            modeling.get_shape_list(inputs.masked_lm_ids) + [vocab_size])
        logits_tiled += tf.reshape(logits, [1, 1, vocab_size])
        logits = logits_tiled

        return get_softmax_output(logits, inputs.masked_lm_ids,
                                  inputs.masked_lm_weights, vocab_size)
Example #9
    def get_prediction_module(self, bert_model, features, is_training,
                              percent_done):
        final_hidden = bert_model.get_pooled_output()
        # bs * options_num * top_k, hidden_dim
        final_hidden_shape = modeling.get_shape_list(final_hidden,
                                                     expected_rank=2)
        # bs, options_num * top_k * seq_len
        input_ids_shape = modeling.get_shape_list(features["input_ids"],
                                                  expected_rank=2)
        batch_size = input_ids_shape[0]
        hidden_dim = final_hidden_shape[1]
        final_hidden_reshape = tf.reshape(
            final_hidden,
            [batch_size, self.config.max_options_num,
             self.config.evidences_top_k * hidden_dim])

        # def single(hidden, mask, y):
        #     logits = tf.squeeze(tf.layers.dense(hidden, 1), -1)
        #     mask = mask[:self.config.max_options_num]
        #     y = y[:self.config.max_options_num]
        #     logits_masked = logits + 1e8 * (mask - 1)
        #     loss = tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=logits_masked)
        #     return logits_masked, loss
        #
        # def combination_single(hidden, mask, y):
        #     logits = tf.squeeze(tf.layers.dense(hidden, 1), -1)  # todo: share or not ?
        #     logits = tf.layers.dense(logits, 2 ** self.config.max_options_num)
        #     logits_masked = logits + 1e8 * (mask - 1)
        #     loss = tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=logits_masked)
        #     return logits_masked, loss
        # logits, loss = tf.cond()

        logits = tf.squeeze(tf.layers.dense(final_hidden_reshape, 1), -1)
        loss1 = tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.to_float(features[self.name + "_answer_ids_raw"]),
            logits=logits)
        loss1 = tf.reduce_mean(loss1, axis=-1)
        logits = tf.layers.dense(logits, 2 ** self.config.max_options_num)
        logits_masked = logits + 1e8 * tf.to_float(
            features[self.name + "_answer_mask"] - 1)
        loss = tf.nn.softmax_cross_entropy_with_logits(
            labels=features[self.name + "_answer_ids"], logits=logits)
        loss = loss * 0.2 + loss1 * 0.8
        return loss, dict(
            loss=loss,
            logits=logits_masked,
            eid=features[self.name + "_eid"],
        )
Example #10
def get_token_logits(input_reprs, embedding_table, bert_config):
    hidden = tf.layers.dense(
        input_reprs,
        units=modeling.get_shape_list(embedding_table)[-1],
        activation=modeling.get_activation(bert_config.hidden_act),
        kernel_initializer=modeling.create_initializer(
            bert_config.initializer_range))
    hidden = modeling.layer_norm(hidden)
    output_bias = tf.get_variable("output_bias",
                                  shape=[bert_config.vocab_size],
                                  initializer=tf.zeros_initializer())
    logits = tf.matmul(hidden, embedding_table, transpose_b=True)
    logits = tf.nn.bias_add(logits, output_bias)
    return logits
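get_token_logits follows the BERT-style MLM head: project the hidden states back to the embedding width, apply layer norm, and tie the output projection to the input embedding table via the transposed matmul, adding only a per-token output bias. The shape flow, assuming the rank-3 inputs used by the callers in these examples:

# input_reprs:     [batch, num_positions, embedding_width]
# embedding_table: [vocab_size, embedding_width]
# logits = layer_norm(dense(input_reprs)) @ embedding_table^T + output_bias
#        -> [batch, num_positions, vocab_size]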
Example #11
def gather_indexes(sequence_tensor, positions):
    """Gathers the vectors at the specific positions over a minibatch."""
    sequence_shape = modeling.get_shape_list(sequence_tensor, expected_rank=3)
    batch_size = sequence_shape[0]
    seq_length = sequence_shape[1]
    width = sequence_shape[2]

    flat_offsets = tf.reshape(
        tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1])
    flat_positions = tf.reshape(positions + flat_offsets, [-1])
    flat_sequence_tensor = tf.reshape(sequence_tensor,
                                      [batch_size * seq_length, width])
    output_tensor = tf.gather(flat_sequence_tensor, flat_positions)
    return output_tensor
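gather_indexes does the same per-example offset arithmetic as gather_positions in Example #4, but returns the gathered vectors flattened to [batch_size * n_positions, width] and leaves any reshaping to the caller. A toy comparison:

import tensorflow as tf

sequence = tf.random.normal([2, 4, 8])               # [batch, seq_len, width]
positions = tf.constant([[1, 3], [0, 2]], tf.int32)  # [batch, n_positions]
flat = gather_indexes(sequence, positions)           # [4, 8]
per_example = tf.reshape(flat, [2, -1, 8])           # matches gather_positions(sequence, positions)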
Example #12
    def _build_teacher(self,
                       states,
                       inputs: pretrain_data.Inputs,
                       is_training,
                       name="teacher",
                       reuse=False,
                       **kwargs):
        """Build teacher network to estimate token score."""
        input_shape = get_shape_list(states, expected_rank=3)
        prev_output = states
        hidden_size = self._teacher_config.hidden_size
        num_hidden_layers = self._teacher_config.num_hidden_layers
        if is_training:
            hidden_dropout_prob = self._teacher_config.hidden_dropout_prob
        else:
            hidden_dropout_prob = 0.0
        with tf.variable_scope("teacher", reuse=reuse):
            for layer_idx in range(num_hidden_layers):
                with tf.variable_scope("layer_%d" % layer_idx):
                    layer_input = prev_output
                    layer_output = tf.layers.dense(
                        layer_input,
                        hidden_size,
                        activation=get_activation("gelu"),
                        kernel_initializer=create_initializer(
                            self._teacher_config.initializer_range))
                    layer_output = dropout(layer_output, hidden_dropout_prob)
                    layer_output = layer_norm(layer_output)
                    prev_output = layer_output

            sequence_output = prev_output
            with tf.variable_scope("bernoulli"):
                with tf.variable_scope("transform"):
                    logits = tf.layers.dense(
                        sequence_output,
                        units=1,
                        kernel_initializer=create_initializer(
                            self._teacher_config.initializer_range))
                    action_probs = tf.nn.sigmoid(logits)
                    action_probs = tf.squeeze(action_probs)

            TeacherOutput = collections.namedtuple("TeacherOutput",
                                                   ["action_probs"])
        return TeacherOutput(action_probs=action_probs)
Example #13
    def _get_masked_lm_output(self, inputs: pretrain_data.Inputs, model):
        """Masked language modeling softmax layer."""
        with tf.variable_scope("generator_predictions"):
            if self._config.uniform_generator:
                logits = tf.zeros(self._bert_config.vocab_size)
                logits_tiled = tf.zeros(
                    modeling.get_shape_list(inputs.masked_lm_ids) +
                    [self._bert_config.vocab_size])
                logits_tiled += tf.reshape(
                    logits, [1, 1, self._bert_config.vocab_size])
                logits = logits_tiled
            else:
                relevant_reprs = pretrain_helpers.gather_positions(
                    model.get_sequence_output(), inputs.masked_lm_positions)
                logits = get_token_logits(relevant_reprs,
                                          model.get_embedding_table(),
                                          self._bert_config)
            return get_softmax_output(logits, inputs.masked_lm_ids,
                                      inputs.masked_lm_weights,
                                      self._bert_config.vocab_size)
Example #14
    def _get_autoencoder_output(self, inputs: pretrain_data.Inputs, model):
        """Auto-Encoder softmax layer."""
        with tf.variable_scope("autoencoder_predictions"):
            relevant_hidden = model.get_sequence_output()
            hidden = tf.layers.dense(
                relevant_hidden,
                units=modeling.get_shape_list(model.get_embedding_table())[-1],
                activation=modeling.get_activation(
                    self._bert_config.hidden_act),
                kernel_initializer=modeling.create_initializer(
                    self._bert_config.initializer_range))
            hidden = modeling.layer_norm(hidden)
            output_bias = tf.get_variable("output_bias",
                                          shape=[self._bert_config.vocab_size],
                                          initializer=tf.zeros_initializer())
            logits = tf.matmul(hidden,
                               model.get_embedding_table(),
                               transpose_b=True)
            logits = tf.nn.bias_add(logits, output_bias)

            oh_labels = tf.one_hot(inputs.input_ids,
                                   depth=self._bert_config.vocab_size,
                                   dtype=tf.float32)

            probs = tf.nn.softmax(logits)
            log_probs = tf.nn.log_softmax(logits)
            label_log_probs = -tf.reduce_sum(log_probs * oh_labels, axis=-1)

            weights = tf.cast(inputs.input_mask, tf.float32)
            numerator = tf.reduce_sum(weights * label_log_probs)
            denominator = tf.reduce_sum(weights) + 1e-6
            loss = numerator / denominator
            preds = tf.argmax(log_probs, axis=-1, output_type=tf.int32)

            AEOutput = collections.namedtuple(
                "AEOutput",
                ["logits", "probs", "loss", "per_example_loss", "preds"])
            return AEOutput(logits=logits,
                            probs=probs,
                            per_example_loss=label_log_probs,
                            loss=loss,
                            preds=preds)
Example #15
def sample_from_softmax(logits,
                        temperature,
                        disallow=None,
                        straight_through=False):
    if disallow is not None:
        logits -= 1000.0 * disallow
    uniform_noise = tf.random.uniform(modeling.get_shape_list(logits),
                                      minval=0,
                                      maxval=1)
    gumbel_noise = -tf.log(-tf.log(uniform_noise + 1e-9) + 1e-9)
    gumbel_logits = (logits + gumbel_noise) / temperature
    gumbel_probs = tf.nn.softmax(gumbel_logits)
    hard_token_ids = tf.one_hot(
        tf.argmax(gumbel_probs, axis=-1, output_type=tf.int32),
        logits.shape[-1])
    if straight_through:
        gumbel_dense = tf.stop_gradient(hard_token_ids -
                                        gumbel_probs) + gumbel_probs
    else:
        gumbel_dense = gumbel_probs
    return gumbel_dense
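The straight_through branch is the straight-through Gumbel-softmax estimator: the returned value equals the hard one-hot sample (hard - soft + soft), while the backward pass sees only the soft gumbel_probs term because the difference is wrapped in tf.stop_gradient. The same pattern in isolation, on a generic soft/hard pair (toy logits):

import tensorflow as tf

soft = tf.nn.softmax(tf.constant([[1.0, 2.0, 0.5]]))
hard = tf.one_hot(tf.argmax(soft, axis=-1, output_type=tf.int32), 3)
straight_through = tf.stop_gradient(hard - soft) + soft
# forward value equals hard; gradients flow only through the soft term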
Example #16
    def _sample_masking_subset(self, inputs: pretrain_data.Inputs,
                               action_probs):
        # Calculate action_probs with the special tokens removed
        input_mask = inputs.input_mask
        segment_ids = inputs.segment_ids
        input_ids = inputs.input_ids

        shape = modeling.get_shape_list(input_ids, expected_rank=2)
        batch_size = shape[0]
        max_seq_len = shape[1]

        def _remove_special_token(elems):
            action_prob = tf.cast(elems[0], tf.float32)
            segment = tf.cast(elems[1], tf.int32)
            input = tf.cast(elems[2], tf.int32)
            mask = tf.cast(elems[3], tf.int32)

            seq_len = tf.reduce_sum(mask)
            seg1_len = seq_len - tf.reduce_sum(segment)
            seq1_idx = tf.range(start=1, limit=seg1_len - 1, dtype=tf.int32)
            seq2_limit = tf.math.maximum(seg1_len, seq_len - 1)
            seq2_idx = tf.range(start=seg1_len,
                                limit=seq2_limit,
                                dtype=tf.int32)
            mask_idx = tf.range(start=seq_len,
                                limit=max_seq_len,
                                dtype=tf.int32)
            index_tensor = tf.concat([seq1_idx, seq2_idx, mask_idx], axis=0)

            seq1_prob = action_prob[1:seg1_len - 1]
            seq2_prob = action_prob[seg1_len:seq2_limit]
            mask_prob = tf.ones_like(mask_idx, dtype=tf.float32) * 1e-20
            cleaned_action_prob = tf.concat([seq1_prob, seq2_prob, mask_prob],
                                            axis=0)
            cleaned_mask = tf.concat([
                mask[1:seg1_len - 1], mask[seg1_len:seq_len - 1],
                mask[seq_len:max_seq_len]
            ],
                                     axis=0)

            cleaned_input = tf.concat([
                input[1:seg1_len - 1], input[seg1_len:seq_len - 1],
                input[seq_len:max_seq_len]
            ],
                                      axis=0)

            cleaned_action_prob = cleaned_action_prob[0:max_seq_len - 3]
            index_tensor = index_tensor[0:max_seq_len - 3]
            cleaned_input = cleaned_input[0:max_seq_len - 3]
            cleaned_mask = cleaned_mask[0:max_seq_len - 3]

            return (cleaned_action_prob, index_tensor, cleaned_input,
                    cleaned_mask)

        # Remove CLS and SEP action probs
        elems = tf.stack([
            action_probs,
            tf.cast(segment_ids, tf.float32),
            tf.cast(input_ids, tf.float32),
            tf.cast(input_mask, tf.float32)
        ], 1)
        cleaned_action_probs, index_tensors, cleaned_inputs, cleaned_input_mask = tf.map_fn(
            _remove_special_token,
            elems,
            dtype=(tf.float32, tf.int32, tf.int32, tf.int32),
            parallel_iterations=1)
        logZ, log_prob = self._calculate_partition_table(
            cleaned_input_mask, cleaned_action_probs,
            self._config.max_predictions_per_seq)

        samples, log_q = self._sampling_a_subset(
            logZ, log_prob, self._config.max_predictions_per_seq)

        # Collect masked_lm_ids and masked_lm_positions
        zero_values = tf.zeros_like(index_tensors, tf.int32)
        selected_position = tf.where(tf.equal(samples, 1), index_tensors,
                                     zero_values)
        masked_lm_positions, _ = tf.nn.top_k(
            selected_position,
            self._config.max_predictions_per_seq,
            sorted=False)

        # Get the ids of the masked-out tokens
        shift = tf.expand_dims(max_seq_len * tf.range(batch_size), -1)
        flat_positions = tf.reshape(masked_lm_positions + shift, [-1, 1])
        masked_lm_ids = tf.gather_nd(tf.reshape(input_ids, [-1]),
                                     flat_positions)
        masked_lm_ids = tf.reshape(masked_lm_ids, [batch_size, -1])

        # Update the input ids
        replaced_prob = tf.random.uniform(
            [batch_size, self._config.max_predictions_per_seq])
        replace_with_mask_positions = masked_lm_positions * tf.cast(
            tf.less(replaced_prob, 0.85), tf.int32)
        inputs_ids, _ = scatter_update(
            inputs.input_ids,
            tf.fill([batch_size, self._config.max_predictions_per_seq],
                    self._vocab["[MASK]"]), replace_with_mask_positions)

        # Replace with random tokens
        replace_with_random_positions = masked_lm_positions * tf.cast(
            tf.greater(replaced_prob, 0.925), tf.int32)
        random_tokens = tf.random.uniform(
            [batch_size, self._config.max_predictions_per_seq],
            minval=0,
            maxval=len(self._vocab),
            dtype=tf.int32)

        inputs_ids, _ = scatter_update(inputs_ids, random_tokens,
                                       replace_with_random_positions)

        masked_lm_weights = tf.ones_like(masked_lm_ids, tf.float32)
        inv_vocab = self._inv_vocab
        # Apply mask on input
        if self._config.debug:

            def pretty_print(inputs_ids, masked_lm_ids, masked_lm_positions,
                             masked_lm_weights, tag_ids):
                debug_inputs = Inputs(input_ids=inputs_ids,
                                      input_mask=None,
                                      segment_ids=None,
                                      masked_lm_positions=masked_lm_positions,
                                      masked_lm_ids=masked_lm_ids,
                                      masked_lm_weights=masked_lm_weights,
                                      tag_ids=tag_ids)
                pretrain_data.print_tokens(debug_inputs, inv_vocab)

                ## TODO: save to the mask choice
                return inputs_ids, masked_lm_ids, masked_lm_positions, masked_lm_weights

            mask_shape = masked_lm_ids.get_shape()
            inputs_ids, masked_lm_ids, masked_lm_positions, masked_lm_weights = \
              tf.py_func(pretty_print, [inputs_ids, masked_lm_ids, masked_lm_positions, masked_lm_weights, inputs.tag_ids],
                         (tf.int32, tf.int32, tf.int32, tf.float32))
            inputs_ids.set_shape(inputs.input_ids.get_shape())
            masked_lm_ids.set_shape(mask_shape)
            masked_lm_positions.set_shape(mask_shape)
            masked_lm_weights.set_shape(mask_shape)

        masked_input = pretrain_data.get_updated_inputs(
            inputs,
            input_ids=tf.stop_gradient(inputs_ids),
            masked_lm_positions=tf.stop_gradient(masked_lm_positions),
            masked_lm_ids=tf.stop_gradient(masked_lm_ids),
            masked_lm_weights=tf.stop_gradient(masked_lm_weights),
            tag_ids=inputs.tag_ids)

        return log_q, masked_input
Example #17
def mask(config: configure_pretraining.PretrainingConfig,
         inputs: pretrain_data.Inputs, mask_prob, proposal_distribution=1.0,
         disallow_from_mask=None, already_masked=None):
  """Implementation of dynamic masking. The optional arguments aren't needed for
  BERT/ELECTRA and are from early experiments in "strategically" masking out
  tokens instead of uniformly at random.

  Args:
    config: configure_pretraining.PretrainingConfig
    inputs: pretrain_data.Inputs containing input input_ids/input_mask
    mask_prob: percent of tokens to mask
    proposal_distribution: for non-uniform masking can be a [B, L] tensor
                           of scores for masking each position.
    disallow_from_mask: a boolean tensor of [B, L] of positions that should
                        not be masked out
    already_masked: a boolean tensor of [B, N] of already masked-out tokens
                    for multiple rounds of masking
  Returns: a pretrain_data.Inputs with masking added
  """
  # Get the batch size, sequence length, and max masked-out tokens
  N = config.max_predictions_per_seq
  B, L = modeling.get_shape_list(inputs.input_ids)

  # Find indices where masking out a token is allowed
  tokenizer = tokenization.FullTokenizer(
      config.vocab_file, do_lower_case=config.do_lower_case)
  vocab = tokenizer.vocab
  inv_vocab = tokenizer.inv_vocab
  candidates_mask = _get_candidates_mask(inputs, vocab, disallow_from_mask)

  # Set the number of tokens to mask out per example
  num_tokens = tf.cast(tf.reduce_sum(inputs.input_mask, -1), tf.float32)
  num_to_predict = tf.maximum(1, tf.minimum(
      N, tf.cast(tf.round(num_tokens * mask_prob), tf.int32)))
  masked_lm_weights = tf.cast(tf.sequence_mask(num_to_predict, N), tf.float32)
  if already_masked is not None:
    masked_lm_weights *= (1 - already_masked)

  # Get a probability of masking each position in the sequence
  candidate_mask_float = tf.cast(candidates_mask, tf.float32)

  if config.masking_strategy == RAND_STRATEGY or config.masking_strategy == MIX_ADV_STRATEGY:
    sample_prob = (proposal_distribution * candidate_mask_float)
  elif config.masking_strategy == POS_STRATEGY:
    unfavor_pos_mask = _get_unfavor_pos_mask(inputs)
    unfavor_pos_mask_float = tf.cast(unfavor_pos_mask, tf.float32)
    prefer_pos_mask_float = 1 - unfavor_pos_mask_float

    # Preferred positions get sampling weight 1.0; non-preferred positions get 0.05
    # proposal_distribution = prefer_pos_mask_float
    proposal_distribution = 0.95 * prefer_pos_mask_float + 0.05
    sample_prob = (proposal_distribution * candidate_mask_float)
  elif config.masking_strategy == ENTROPY_STRATEGY:
    sample_prob = (proposal_distribution * candidate_mask_float)
  elif config.masking_strategy == MIX_POS_STRATEGY:
    rand_sample_prob = (proposal_distribution * candidate_mask_float)
    unfavor_pos_mask = _get_unfavor_pos_mask(inputs)
    unfavor_pos_mask_float = tf.cast(unfavor_pos_mask, tf.float32)
    prefer_pos_mask_float = 1 - unfavor_pos_mask_float

    # Preferred positions get sampling weight 1.0; non-preferred positions get 0.05
    # proposal_distribution = prefer_pos_mask_float
    proposal_distribution = 0.95 * prefer_pos_mask_float + 0.05
    pos_sample_prob = (proposal_distribution * candidate_mask_float)

    strategy_prob = tf.random.uniform([B])
    strategy_prob = tf.expand_dims(
        tf.cast(tf.greater(strategy_prob, 0.5), tf.float32), 1)
    strategy_prob = tf.tile(strategy_prob, [1, L])
    sample_prob = (rand_sample_prob * strategy_prob +
                   pos_sample_prob * (1 - strategy_prob))
  elif config.masking_strategy == MIX_ENTROPY_STRATEGY:
    rand_sample_prob = (proposal_distribution * candidate_mask_float)
    entropy_sample_prob = (proposal_distribution * candidate_mask_float)
    strategy_prob = tf.random.uniform([B])
    strategy_prob = tf.expand_dims(
        tf.cast(tf.greater(strategy_prob, 0.5), tf.float32), 1)
    strategy_prob = tf.tile(strategy_prob, [1, L])
    sample_prob = (rand_sample_prob * strategy_prob +
                   entropy_sample_prob * (1 - strategy_prob))
  else:
    raise ValueError("{} strategy is not supported".format(config.masking_strategy))
  sample_prob /= tf.reduce_sum(sample_prob, axis=-1, keepdims=True)

  # Sample the positions to mask out
  sample_prob = tf.stop_gradient(sample_prob)
  sample_logits = tf.log(sample_prob)
  masked_lm_positions = tf.random.categorical(
      sample_logits, N, dtype=tf.int32)
  masked_lm_positions *= tf.cast(masked_lm_weights, tf.int32)

  # Get the ids of the masked-out tokens
  shift = tf.expand_dims(L * tf.range(B), -1)
  flat_positions = tf.reshape(masked_lm_positions + shift, [-1, 1])
  masked_lm_ids = tf.gather_nd(tf.reshape(inputs.input_ids, [-1]),
                               flat_positions)
  masked_lm_ids = tf.reshape(masked_lm_ids, [B, -1])
  masked_lm_ids *= tf.cast(masked_lm_weights, tf.int32)

  # Update the input ids
  replace_prob = tf.random.uniform([B, N])
  replace_with_mask_positions = masked_lm_positions * tf.cast(
      tf.less(replace_prob, 0.85), tf.int32)
  inputs_ids, _ = scatter_update(
      inputs.input_ids, tf.fill([B, N], vocab["[MASK]"]),
      replace_with_mask_positions)

  # Replace with random tokens
  replace_with_random_positions = masked_lm_positions * tf.cast(
      tf.greater(replace_prob, 0.925), tf.int32)
  random_tokens = tf.random.uniform(
      [B, N], minval=0, maxval=len(vocab), dtype=tf.int32)
  inputs_ids, _ = scatter_update(inputs_ids, random_tokens,
                                 replace_with_random_positions)

  if config.debug:
    def pretty_print(inputs_ids, masked_lm_ids, masked_lm_positions,
                     masked_lm_weights, tag_ids):
      debug_inputs = Inputs(
          input_ids=inputs_ids,
          input_mask=None,
          segment_ids=None,
          masked_lm_positions=masked_lm_positions,
          masked_lm_ids=masked_lm_ids,
          masked_lm_weights=masked_lm_weights,
          tag_ids=tag_ids)
      pretrain_data.print_tokens(debug_inputs, inv_vocab)

      ## TODO: save to the mask choice
      return inputs_ids, masked_lm_ids, masked_lm_positions, masked_lm_weights

    mask_shape = masked_lm_ids.get_shape()
    inputs_ids, masked_lm_ids, masked_lm_positions, masked_lm_weights = \
      tf.py_func(pretty_print,[inputs_ids, masked_lm_ids, masked_lm_positions, masked_lm_weights, inputs.tag_ids],
                 (tf.int32, tf.int32, tf.int32, tf.float32))
    inputs_ids.set_shape(inputs.input_ids.get_shape())
    masked_lm_ids.set_shape(mask_shape)
    masked_lm_positions.set_shape(mask_shape)
    masked_lm_weights.set_shape(mask_shape)

  return pretrain_data.get_updated_inputs(
      inputs,
      input_ids=tf.stop_gradient(inputs_ids),
      masked_lm_positions=masked_lm_positions,
      masked_lm_ids=masked_lm_ids,
      masked_lm_weights=masked_lm_weights,
      tag_ids=inputs.tag_ids)
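With a single uniform draw per masked position, the two thresholds above give the usual BERT-style replacement split: values below 0.85 select the [MASK] token, values above 0.925 select a random token, and the remaining band keeps the original token, i.e. roughly 85% / 7.5% / 7.5%. A minimal sketch of just that partition (toy shapes):

import tensorflow as tf

replace_prob = tf.random.uniform([4, 20])                               # [B, N]
use_mask_token = tf.cast(tf.less(replace_prob, 0.85), tf.int32)         # ~85%
use_random_token = tf.cast(tf.greater(replace_prob, 0.925), tf.int32)   # ~7.5%
keep_original = 1 - use_mask_token - use_random_token                   # ~7.5%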
Example #18
def mask(config: configure_pretraining.PretrainingConfig,
         inputs: pretrain_data.Inputs, mask_prob, proposal_distribution=1.0,
         disallow_from_mask=None, already_masked=None):
  """Implementation of dynamic masking. The optional arguments aren't needed for
  BERT/ELECTRA and are from early experiments in "strategically" masking out
  tokens instead of uniformly at random.

  Args:
    config: configure_pretraining.PretrainingConfig
    inputs: pretrain_data.Inputs containing input input_ids/input_mask
    mask_prob: percent of tokens to mask
    proposal_distribution: for non-uniform masking can be a [B, L] tensor
                           of scores for masking each position.
    disallow_from_mask: a boolean tensor of [B, L] of positions that should
                        not be masked out
    already_masked: a boolean tensor of [B, N] of already masked-out tokens
                    for multiple rounds of masking
  Returns: a pretrain_data.Inputs with masking added
  """
  # Get the batch size, sequence length, and max masked-out tokens
  N = config.max_predictions_per_seq
  B, L = modeling.get_shape_list(inputs.input_ids)

  # Find indices where masking out a token is allowed
  vocab = tokenization.FullTokenizer(
      config.vocab_file, do_lower_case=config.do_lower_case).vocab
  candidates_mask = _get_candidates_mask(inputs, vocab, disallow_from_mask)

  # Set the number of tokens to mask out per example
  num_tokens = tf.cast(tf.reduce_sum(inputs.input_mask, -1), tf.float32)
  num_to_predict = tf.maximum(1, tf.minimum(
      N, tf.cast(tf.round(num_tokens * mask_prob), tf.int32)))
  masked_lm_weights = tf.cast(tf.sequence_mask(num_to_predict, N), tf.float32)
  if already_masked is not None:
    masked_lm_weights *= (1 - already_masked)

  # Get a probability of masking each position in the sequence
  candidate_mask_float = tf.cast(candidates_mask, tf.float32)
  sample_prob = (proposal_distribution * candidate_mask_float)
  sample_prob /= tf.reduce_sum(sample_prob, axis=-1, keepdims=True)

  # Sample the positions to mask out
  sample_prob = tf.stop_gradient(sample_prob)
  sample_logits = tf.log(sample_prob)
  masked_lm_positions = tf.random.categorical(
      sample_logits, N, dtype=tf.int32)
  masked_lm_positions *= tf.cast(masked_lm_weights, tf.int32)

  # Get the ids of the masked-out tokens
  shift = tf.expand_dims(L * tf.range(B), -1)
  flat_positions = tf.reshape(masked_lm_positions + shift, [-1, 1])
  masked_lm_ids = tf.gather_nd(tf.reshape(inputs.input_ids, [-1]),
                               flat_positions)
  masked_lm_ids = tf.reshape(masked_lm_ids, [B, -1])
  masked_lm_ids *= tf.cast(masked_lm_weights, tf.int32)

  # Update the input ids
  replace_with_mask_positions = masked_lm_positions * tf.cast(
      tf.less(tf.random.uniform([B, N]), 0.85), tf.int32)
  inputs_ids, _ = scatter_update(
      inputs.input_ids, tf.fill([B, N], vocab["[MASK]"]),
      replace_with_mask_positions)

  return pretrain_data.get_updated_inputs(
      inputs,
      input_ids=tf.stop_gradient(inputs_ids),
      masked_lm_positions=masked_lm_positions,
      masked_lm_ids=masked_lm_ids,
      masked_lm_weights=masked_lm_weights
  )
Example #19
    def get_prediction_module(self, bert_model, features, is_training,
                              percent_done):
        final_hidden = bert_model.get_sequence_output()

        final_hidden_shape = modeling.get_shape_list(final_hidden,
                                                     expected_rank=3)
        batch_size = final_hidden_shape[0]
        seq_length = final_hidden_shape[1]

        answer_mask = tf.cast(features["input_mask"], tf.float32)
        answer_mask *= tf.cast(features["segment_ids"], tf.float32)
        answer_mask += tf.one_hot(0, seq_length)

        start_logits = tf.squeeze(tf.layers.dense(final_hidden, 1), -1)

        start_top_log_probs = tf.zeros([batch_size, self.config.beam_size])
        start_top_index = tf.zeros([batch_size, self.config.beam_size],
                                   tf.int32)
        end_top_log_probs = tf.zeros(
            [batch_size, self.config.beam_size, self.config.beam_size])
        end_top_index = tf.zeros(
            [batch_size, self.config.beam_size, self.config.beam_size],
            tf.int32)
        if self.config.joint_prediction:
            start_logits += 1000.0 * (answer_mask - 1)
            start_log_probs = tf.nn.log_softmax(start_logits)
            start_top_log_probs, start_top_index = tf.nn.top_k(
                start_log_probs, k=self.config.beam_size)

            if not is_training:
                # batch, beam, length, hidden
                end_features = tf.tile(tf.expand_dims(final_hidden, 1),
                                       [1, self.config.beam_size, 1, 1])
                # batch, beam, length
                start_index = tf.one_hot(start_top_index,
                                         depth=seq_length,
                                         axis=-1,
                                         dtype=tf.float32)
                # batch, beam, hidden
                start_features = tf.reduce_sum(
                    tf.expand_dims(final_hidden, 1) *
                    tf.expand_dims(start_index, -1),
                    axis=-2)
                # batch, beam, length, hidden
                start_features = tf.tile(tf.expand_dims(start_features, 2),
                                         [1, 1, seq_length, 1])
            else:
                start_index = tf.one_hot(features[self.name +
                                                  "_start_positions"],
                                         depth=seq_length,
                                         axis=-1,
                                         dtype=tf.float32)
                start_features = tf.reduce_sum(
                    tf.expand_dims(start_index, -1) * final_hidden, axis=1)
                start_features = tf.tile(tf.expand_dims(start_features, 1),
                                         [1, seq_length, 1])
                end_features = final_hidden

            final_repr = tf.concat([start_features, end_features], -1)
            final_repr = tf.layers.dense(final_repr,
                                         512,
                                         activation=modeling.gelu,
                                         name="qa_hidden")
            # batch, beam, length (batch, length when training)
            end_logits = tf.squeeze(tf.layers.dense(final_repr, 1),
                                    -1,
                                    name="qa_logits")
            if is_training:
                end_logits += 1000.0 * (answer_mask - 1)
            else:
                end_logits += tf.expand_dims(1000.0 * (answer_mask - 1), 1)

            if not is_training:
                end_log_probs = tf.nn.log_softmax(end_logits)
                end_top_log_probs, end_top_index = tf.nn.top_k(
                    end_log_probs, k=self.config.beam_size)
                end_logits = tf.zeros([batch_size, seq_length])
        else:
            end_logits = tf.squeeze(tf.layers.dense(final_hidden, 1), -1)
            start_logits += 1000.0 * (answer_mask - 1)
            end_logits += 1000.0 * (answer_mask - 1)

        def compute_loss(logits, positions):
            one_hot_positions = tf.one_hot(positions,
                                           depth=seq_length,
                                           dtype=tf.float32)
            log_probs = tf.nn.log_softmax(logits, axis=-1)
            loss = -tf.reduce_sum(one_hot_positions * log_probs, axis=-1)
            return loss

        start_positions = features[self.name + "_start_positions"]
        end_positions = features[self.name + "_end_positions"]

        start_loss = compute_loss(start_logits, start_positions)
        end_loss = compute_loss(end_logits, end_positions)

        losses = (start_loss + end_loss) / 2.0

        answerable_logit = tf.zeros([batch_size])
        if self.config.answerable_classifier:
            final_repr = final_hidden[:, 0]
            if self.config.answerable_uses_start_logits:
                start_p = tf.nn.softmax(start_logits)
                start_feature = tf.reduce_sum(tf.expand_dims(start_p, -1) *
                                              final_hidden,
                                              axis=1)
                final_repr = tf.concat([final_repr, start_feature], -1)
                final_repr = tf.layers.dense(final_repr,
                                             512,
                                             activation=modeling.gelu)
            answerable_logit = tf.squeeze(tf.layers.dense(final_repr, 1), -1)
            answerable_loss = tf.nn.sigmoid_cross_entropy_with_logits(
                labels=tf.cast(features[self.name + "_is_impossible"],
                               tf.float32),
                logits=answerable_logit)
            losses += answerable_loss * self.config.answerable_weight

        return losses, dict(
            loss=losses,
            start_logits=start_logits,
            end_logits=end_logits,
            answerable_logit=answerable_logit,
            start_positions=features[self.name + "_start_positions"],
            end_positions=features[self.name + "_end_positions"],
            start_top_log_probs=start_top_log_probs,
            start_top_index=start_top_index,
            end_top_log_probs=end_top_log_probs,
            end_top_index=end_top_index,
            eid=features[self.name + "_eid"],
        )
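The answer_mask construction above restricts predictions to the second segment (the passage) plus position 0, which is typically the [CLS] slot that SQuAD-style models use for the no-answer prediction; adding 1000.0 * (answer_mask - 1) drives the logits of disallowed positions to about -1000 so their softmax probability is effectively zero. A small self-contained sketch of that masking (toy values):

import tensorflow as tf

input_mask = tf.constant([[1, 1, 1, 1, 1, 0]], tf.float32)   # last token is padding
segment_ids = tf.constant([[0, 0, 0, 1, 1, 0]], tf.float32)  # positions 3-4 are the passage
answer_mask = input_mask * segment_ids + tf.one_hot(0, 6)    # keep position 0 for "no answer"
start_logits = tf.random.normal([1, 6])
start_logits += 1000.0 * (answer_mask - 1)                   # disallowed positions -> ~-1000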
Example #20
    def get_prediction_module(self, bert_model, features, is_training,
                              percent_done):
        final_hidden = bert_model.get_sequence_output()

        # sgnet
        # dep_mask_x = features[self.name + "_dep_mask_x"]
        # dep_mask_y = features[self.name + "_dep_mask_y"]
        # dep_mask_len = features[self.name + "_dep_mask_len"]
        #
        # def fn(xyz):
        #     x = xyz[0]
        #     y = xyz[1]
        #     length = xyz[2]
        #     x = x[:length]
        #     y = y[:length]
        #     st = tf.SparseTensor(indices=tf.cast(tf.transpose([x, y]), tf.int64),
        #                          values=tf.ones_like(x, dtype=tf.float32),
        #                          dense_shape=[self.config.max_seq_length, self.config.max_seq_length])
        #     dt = tf.sparse_tensor_to_dense(st)
        #     return dt
        #
        # dep_mask = tf.map_fn(fn, (dep_mask_x, dep_mask_y, dep_mask_len), dtype=tf.float32)
        # dep_mask = features["squad_dep_mask"]
        # dep_mask = tf.reshape(dep_mask, [-1, self.config.max_seq_length, self.config.max_seq_length])
        # with tf.variable_scope("dependence"):
        #     bert_config = bert_model.config
        #     dep_att_output, _ = modeling.transformer_model(
        #         input_tensor=final_hidden,
        #         attention_mask=dep_mask,
        #         hidden_size=bert_config.hidden_size,
        #         num_hidden_layers=1,
        #         num_attention_heads=bert_config.num_attention_heads,
        #         intermediate_size=bert_config.intermediate_size,
        #         intermediate_act_fn=modeling.get_activation(bert_config.hidden_act),
        #         hidden_dropout_prob=bert_config.hidden_dropout_prob,
        #         attention_probs_dropout_prob=bert_config.attention_probs_dropout_prob,
        #         initializer_range=bert_config.initializer_range,
        #         do_return_all_layers=False)
        # weight = tf.get_variable(name="weight", dtype=tf.float32, initializer=tf.zeros_initializer(),
        #                          shape=(), trainable=True)
        # weight = tf.sigmoid(weight)
        # final_hidden = weight * final_hidden + (1 - weight) * dep_att_output

        final_hidden_shape = modeling.get_shape_list(final_hidden,
                                                     expected_rank=3)
        batch_size = final_hidden_shape[0]
        seq_length = final_hidden_shape[1]

        answer_mask = tf.cast(features["input_mask"], tf.float32)
        answer_mask *= tf.cast(features["segment_ids"], tf.float32)
        answer_mask += tf.one_hot(0, seq_length)

        start_logits = tf.squeeze(tf.layers.dense(final_hidden, 1), -1)

        start_top_log_probs = tf.zeros([batch_size, self.config.beam_size])
        start_top_index = tf.zeros([batch_size, self.config.beam_size],
                                   tf.int32)
        end_top_log_probs = tf.zeros(
            [batch_size, self.config.beam_size, self.config.beam_size])
        end_top_index = tf.zeros(
            [batch_size, self.config.beam_size, self.config.beam_size],
            tf.int32)
        if self.config.joint_prediction:
            start_logits += 1000.0 * (answer_mask - 1)
            start_log_probs = tf.nn.log_softmax(start_logits)
            start_top_log_probs, start_top_index = tf.nn.top_k(
                start_log_probs, k=self.config.beam_size)

            if not is_training:
                # batch, beam, length, hidden
                end_features = tf.tile(tf.expand_dims(final_hidden, 1),
                                       [1, self.config.beam_size, 1, 1])
                # batch, beam, length
                start_index = tf.one_hot(start_top_index,
                                         depth=seq_length,
                                         axis=-1,
                                         dtype=tf.float32)
                # batch, beam, hidden
                start_features = tf.reduce_sum(
                    tf.expand_dims(final_hidden, 1) *
                    tf.expand_dims(start_index, -1),
                    axis=-2)
                # batch, beam, length, hidden
                start_features = tf.tile(tf.expand_dims(start_features, 2),
                                         [1, 1, seq_length, 1])
            else:
                start_index = tf.one_hot(features[self.name +
                                                  "_start_positions"],
                                         depth=seq_length,
                                         axis=-1,
                                         dtype=tf.float32)
                start_features = tf.reduce_sum(
                    tf.expand_dims(start_index, -1) * final_hidden, axis=1)
                start_features = tf.tile(tf.expand_dims(start_features, 1),
                                         [1, seq_length, 1])
                end_features = final_hidden

            final_repr = tf.concat([start_features, end_features], -1)
            final_repr = tf.layers.dense(final_repr,
                                         512,
                                         activation=modeling.gelu,
                                         name="qa_hidden")
            # batch, beam, length (batch, length when training)
            end_logits = tf.squeeze(tf.layers.dense(final_repr, 1),
                                    -1,
                                    name="qa_logits")
            if is_training:
                end_logits += 1000.0 * (answer_mask - 1)
            else:
                end_logits += tf.expand_dims(1000.0 * (answer_mask - 1), 1)

            if not is_training:
                end_log_probs = tf.nn.log_softmax(end_logits)
                end_top_log_probs, end_top_index = tf.nn.top_k(
                    end_log_probs, k=self.config.beam_size)
                end_logits = tf.zeros([batch_size, seq_length])
        else:
            end_logits = tf.squeeze(tf.layers.dense(final_hidden, 1), -1)
            start_logits += 1000.0 * (answer_mask - 1)
            end_logits += 1000.0 * (answer_mask - 1)

        def compute_loss(logits, positions):
            one_hot_positions = tf.one_hot(positions,
                                           depth=seq_length,
                                           dtype=tf.float32)
            log_probs = tf.nn.log_softmax(logits, axis=-1)
            loss = -tf.reduce_sum(one_hot_positions * log_probs, axis=-1)
            return loss

        start_positions = features[self.name + "_start_positions"]
        end_positions = features[self.name + "_end_positions"]

        start_loss = compute_loss(start_logits, start_positions)
        end_loss = compute_loss(end_logits, end_positions)

        losses = (start_loss + end_loss) / 2.0

        # plausible answer loss
        plau_logits = tf.layers.dense(final_hidden, 2)
        plau_logits = tf.reshape(plau_logits, [batch_size, seq_length, 2])
        plau_logits = tf.transpose(plau_logits, [2, 0, 1])
        unstacked_logits = tf.unstack(plau_logits, axis=0)
        (plau_start_logits, plau_end_logits) = (unstacked_logits[0],
                                                unstacked_logits[1])
        plau_start_logits += 1000.0 * (answer_mask - 1)
        plau_end_logits += 1000.0 * (answer_mask - 1)
        plau_start_positions = features[self.name + "_plau_answer_start"]
        plau_end_positions = features[self.name + "_plau_answer_end"]
        plau_start_loss = compute_loss(plau_start_logits, plau_start_positions)
        plau_end_loss = compute_loss(plau_end_logits, plau_end_positions)
        losses += (plau_start_loss + plau_end_loss) / 2.0

        # def compute_loss_for_plau(start_logits, end_logits, start_positions, end_positions, start_positions_true,
        #                           alpha=1.0, beta=1.0):
        #     start_probs = tf.nn.softmax(start_logits)
        #     end_probs = tf.nn.softmax(end_logits)
        #     log_neg_start_probs = tf.log(tf.clip_by_value(1 - start_probs, 1e-30, 1))
        #     log_neg_end_probs = tf.log(tf.clip_by_value(1 - end_probs, 1e-30, 1))
        #     start_positions_mask = tf.cast(tf.sequence_mask(start_positions, maxlen=seq_length), tf.float32)
        #     end_positions_mask = tf.cast(tf.sequence_mask(end_positions + 1, maxlen=seq_length), tf.float32)
        #     positions_mask = end_positions_mask - start_positions_mask
        #     one_hot_positions = tf.one_hot(
        #         start_positions_true, depth=seq_length, dtype=tf.float32)
        #     positions_mask = positions_mask * (1 - one_hot_positions)  # 忽略切出来的无答案
        #
        #     # mask_0 = tf.zeros([batch_size, 1])
        #     # mask_1 = tf.ones([batch_size, seq_length - 1])
        #     # zero_mask = tf.concat([mask_0, mask_1], axis=1)
        #     # positions_mask = positions_mask * zero_mask
        #     loss1 = - tf.reduce_sum(positions_mask * log_neg_start_probs, axis=-1)
        #     loss1 = tf.reduce_mean(loss1)
        #     loss2 = - tf.reduce_sum(positions_mask * log_neg_end_probs, axis=-1)
        #     loss2 = tf.reduce_mean(loss2)
        #     return (loss1 * alpha + loss2 * beta) * 0.5
        #
        # plau_loss = compute_loss_for_plau(start_logits, end_logits,
        #                                   features[self.name + "_plau_answer_start"],
        #                                   features[self.name + "_plau_answer_end"],
        #                                   features[self.name + "_start_positions"], 1.0, 1.0)
        # losses += plau_loss

        answerable_logit = tf.zeros([batch_size])
        if self.config.answerable_classifier:
            final_repr = final_hidden[:, 0]
            if self.config.answerable_uses_start_logits:
                start_p = tf.nn.softmax(start_logits)
                start_feature = tf.reduce_sum(tf.expand_dims(start_p, -1) *
                                              final_hidden,
                                              axis=1)
                final_repr = tf.concat([final_repr, start_feature], -1)
                final_repr = tf.layers.dense(final_repr,
                                             512,
                                             activation=modeling.gelu)
            answerable_logit = tf.squeeze(tf.layers.dense(final_repr, 1), -1)
            answerable_loss = tf.nn.sigmoid_cross_entropy_with_logits(
                labels=tf.cast(features[self.name + "_is_impossible"],
                               tf.float32),
                logits=answerable_logit)
            losses += answerable_loss * self.config.answerable_weight

        return losses, dict(
            loss=losses,
            start_logits=start_logits,
            end_logits=end_logits,
            answerable_logit=answerable_logit,
            start_positions=features[self.name + "_start_positions"],
            end_positions=features[self.name + "_end_positions"],
            start_top_log_probs=start_top_log_probs,
            start_top_index=start_top_index,
            end_top_log_probs=end_top_log_probs,
            end_top_index=end_top_index,
            eid=features[self.name + "_eid"],
        )
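
A shape-level numpy sketch (made-up sizes, random weights) of the answerable-classifier head above: the [CLS] vector is concatenated with a start-probability-weighted average of the sequence states, passed through a dense + GELU layer, and reduced to a single answerability logit trained with sigmoid cross-entropy against the is_impossible label. This only illustrates the tensor flow, not the trained model.

import numpy as np

rng = np.random.default_rng(0)

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def gelu(x):
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x ** 3)))

batch, seq_len, hidden = 2, 16, 32                  # toy sizes (assumed)
final_hidden = rng.normal(size=(batch, seq_len, hidden))
start_logits = rng.normal(size=(batch, seq_len))
is_impossible = np.array([0.0, 1.0])                # toy labels

cls_repr = final_hidden[:, 0]                                    # [batch, hidden]
start_p = softmax(start_logits)                                  # [batch, seq_len]
start_feature = (start_p[..., None] * final_hidden).sum(axis=1)  # [batch, hidden]
final_repr = np.concatenate([cls_repr, start_feature], axis=-1)  # [batch, 2*hidden]

W1 = rng.normal(size=(2 * hidden, 512)); b1 = np.zeros(512)      # stand-ins for tf.layers.dense
w2 = rng.normal(size=(512,)); b2 = 0.0
answerable_logit = gelu(final_repr @ W1 + b1) @ w2 + b2          # [batch]

# numerically stable sigmoid cross-entropy (same form TF uses)
answerable_loss = (np.maximum(answerable_logit, 0)
                   - answerable_logit * is_impossible
                   + np.log1p(np.exp(-np.abs(answerable_logit))))
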
예제 #21
0
  def _get_masked_lm_output(self, inputs: pretrain_data.Inputs, model):
    """Masked language modeling softmax layer."""
    masked_lm_weights = inputs.masked_lm_weights
    with tf.variable_scope("generator_predictions"):
      if self._config.uniform_generator or self._config.identity_generator or self._config.heuristic_generator:
        logits = tf.zeros(self._bert_config.vocab_size)
        logits_tiled = tf.zeros(
            modeling.get_shape_list(inputs.masked_lm_ids) +
            [self._bert_config.vocab_size])
        logits_tiled += tf.reshape(logits, [1, 1, self._bert_config.vocab_size])
        logits = logits_tiled
      else:
        relevant_hidden = pretrain_helpers.gather_positions(
            model.get_sequence_output(), inputs.masked_lm_positions)
        hidden = tf.layers.dense(
            relevant_hidden,
            units=modeling.get_shape_list(model.get_embedding_table())[-1],
            activation=modeling.get_activation(self._bert_config.hidden_act),
            kernel_initializer=modeling.create_initializer(
                self._bert_config.initializer_range))
        hidden = modeling.layer_norm(hidden)
        output_bias = tf.get_variable(
            "output_bias",
            shape=[self._bert_config.vocab_size],
            initializer=tf.zeros_initializer())
        logits = tf.matmul(hidden, model.get_embedding_table(),
                           transpose_b=True)
        logits = tf.nn.bias_add(logits, output_bias)

      oh_labels = tf.one_hot(
          inputs.masked_lm_ids, depth=self._bert_config.vocab_size,
          dtype=tf.float32)

      probs = tf.nn.softmax(logits)

      if self._config.identity_generator:
          identity_logits = tf.zeros(self._bert_config.vocab_size)
          identity_logits_tiled = tf.zeros(
              modeling.get_shape_list(inputs.masked_lm_ids) +
              [self._bert_config.vocab_size])
          masked_identity_weights = tf.one_hot(inputs.masked_lm_ids, depth=self._bert_config.vocab_size, dtype=tf.float32)
          identity_logits_tiled += 25.0 * masked_identity_weights
          identity_logits_tiled += tf.reshape(identity_logits, [1, 1, self._bert_config.vocab_size])
          identity_logits = identity_logits_tiled
          identity_probs = tf.nn.softmax(identity_logits)

          identity_weight = (self.global_step / tf.cast(self._config.num_train_steps, tf.float32)) * self._config.max_identity_weight
          probs = probs * (1 - identity_weight) + identity_probs * identity_weight
          logits = tf.math.log(probs)  # softmax(log(probs)) = probs
      elif self._config.heuristic_generator:
          synonym_logits = tf.zeros(self._bert_config.vocab_size)
          synonym_logits_tiled = tf.zeros(
              modeling.get_shape_list(inputs.masked_lm_ids) +
              [self._bert_config.vocab_size])
          masked_synonym_weights = tf.reduce_sum(
              tf.one_hot(inputs.masked_synonym_ids, depth=self._bert_config.vocab_size, dtype=tf.float32), -2)
          padded_synonym_mask = tf.concat([tf.zeros([1]), tf.ones([self._bert_config.vocab_size - 1])], 0)
          masked_synonym_weights *= tf.expand_dims(tf.expand_dims(padded_synonym_mask, 0), 0)
          synonym_logits_tiled += 25.0 * masked_synonym_weights
          synonym_logits_tiled += tf.reshape(synonym_logits, [1, 1, self._bert_config.vocab_size])
          synonym_logits = synonym_logits_tiled
          synonym_probs = tf.nn.softmax(synonym_logits)

          if self._config.synonym_scheduler_type == 'linear':
              synonym_weight = (self.global_step / tf.cast(self._config.num_train_steps, tf.float32)) * self._config.max_synonym_weight
              probs = probs * (1 - synonym_weight) + synonym_probs * synonym_weight
              logits = tf.math.log(probs)  # softmax(log(probs)) = probs

      log_probs = tf.nn.log_softmax(logits)
      label_log_probs = -tf.reduce_sum(log_probs * oh_labels, axis=-1)

      numerator = tf.reduce_sum(inputs.masked_lm_weights * label_log_probs)
      denominator = tf.reduce_sum(masked_lm_weights) + 1e-6
      loss = numerator / denominator
      preds = tf.argmax(log_probs, axis=-1, output_type=tf.int32)

      MLMOutput = collections.namedtuple(
          "MLMOutput", ["logits", "probs", "loss", "per_example_loss", "preds"])
      return MLMOutput(
          logits=logits, probs=probs, per_example_loss=label_log_probs,
          loss=loss, preds=preds)
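
The identity- and heuristic-generator branches above interpolate the generator softmax with a peaked distribution whose weight grows linearly with the training step, and then take tf.math.log of the mixture. A minimal numpy sketch (toy shapes and an assumed schedule) of why that is safe: because the mixture already sums to one, softmax(log(mixture)) recovers it exactly, so the downstream log_softmax loss is computed on the mixed distribution.

import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

vocab_size = 5                                        # toy vocab (assumed)
masked_lm_ids = np.array([[1, 3, 2], [4, 0, 1]])      # [batch, masked positions]
gen_probs = softmax(np.random.randn(2, 3, vocab_size))

# peaked "identity" distribution: ~all mass on the original masked-out token
identity_probs = softmax(25.0 * np.eye(vocab_size)[masked_lm_ids])

# linear schedule (values are illustrative, not from the config)
global_step, num_train_steps, max_identity_weight = 5000, 100000, 0.8
identity_weight = (global_step / num_train_steps) * max_identity_weight

mixed = gen_probs * (1 - identity_weight) + identity_probs * identity_weight
logits = np.log(mixed)                        # what the branch feeds forward

assert np.allclose(softmax(logits), mixed)    # softmax(log(p)) == p when p sums to 1
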
예제 #22
0
    def _sampling_a_subset(self, logZ, logp, max_predictions_per_seq):
        shape = modeling.get_shape_list(logp, expected_rank=2)
        seq_len = shape[1]

        def gather_z_indexes(sequence_tensor, positions):
            """Gathers the vectors at the specific positions over a minibatch."""
            # set negative indices to zeros
            mask = tf.zeros_like(positions, dtype=tf.int32)
            masked_position = tf.reduce_max(tf.stack([positions, mask]), 0)

            index = tf.reshape(
                tf.cast(tf.where(tf.equal(mask, 0)), dtype=tf.int32), [-1])
            flat_offsets = index * (max_predictions_per_seq + 1)
            flat_positions = masked_position + flat_offsets
            flat_sequence_tensor = tf.reshape(sequence_tensor, [-1])
            output_tensor = tf.gather(flat_sequence_tensor, flat_positions)
            return output_tensor

        def sampling_loop_cond(j, subset, count, left, log_q):
            # j < N and left > 0
            # we want to exclude the last token, because it is always the special [SEP] token
            return tf.logical_or(tf.less(j, seq_len),
                                 tf.greater(tf.reduce_sum(left), 0))

        def sampling_body(j, subset, count, left, log_q):
            # calculate log_q_yes and log_q_no
            logp_j = logp[:, j]
            log_Z_total = gather_z_indexes(logZ[:, j, :], left)  # b
            log_Z_yes = gather_z_indexes(logZ[:, j + 1, :], left - 1)  # b
            logit_yes = logp_j + log_Z_yes - log_Z_total
            logit_no = tf.log(
                tf.clip_by_value(1 - tf.exp(logit_yes), 1e-20, 1.0))
            # draw 2 Gumbel noise and compute action by argmax
            logits = tf.transpose(tf.stack([logit_no, logit_yes]), [1, 0])
            actions = gumbel.gumbel_softmax(logits)
            action_mask = tf.cast(tf.argmax(actions, 1), dtype=tf.int32)
            no_left_mask = tf.where(tf.greater(left, 0),
                                    tf.ones_like(left, dtype=tf.int32),
                                    tf.zeros_like(left, dtype=tf.int32))
            output = action_mask * no_left_mask
            actions = tf.reduce_max(actions, 1)
            log_actions = tf.log(actions)
            # compute log_q_j and update count and subset
            count = count + output
            left = left - output
            log_q = log_q + log_actions
            subset = subset.write(j, output)

            return [tf.add(j, 1), subset, count, left, log_q]

        with tf.variable_scope("teacher/sampling"):
            # Batch sampling
            subset = tf.TensorArray(tf.int32, size=seq_len)
            count = tf.zeros_like(logp[:, 0], dtype=tf.dtypes.int32)
            left = tf.ones_like(logp[:, 0], dtype=tf.dtypes.int32)
            left = left * max_predictions_per_seq
            log_q = tf.zeros_like(count, dtype=tf.dtypes.float32)

            _, subset, count, left, log_q = tf.while_loop(
                sampling_loop_cond, sampling_body,
                [tf.constant(0), subset, count, left, log_q])

            subset = subset.stack()  # K x b x N
            subset = tf.transpose(subset, [1, 0])
        return subset, log_q
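
A standalone sketch of the per-position decision inside sampling_body: given the probability of selecting position j (which in the graph comes from logp_j + logZ_yes - logZ_total), a Gumbel-max draw picks select/skip and the log-probability of the realized action is accumulated into log_q. This is plain Python with an exact Bernoulli log-probability in place of the project's relaxed gumbel.gumbel_softmax helper, and the per-position probabilities are made up.

import numpy as np

rng = np.random.default_rng(0)

def sample_position(p_yes):
    """Gumbel-max draw between 'skip' (0) and 'select' (1); returns (action, log-prob)."""
    logits = np.array([np.log(np.clip(1.0 - p_yes, 1e-20, 1.0)), np.log(p_yes)])
    gumbel = -np.log(-np.log(rng.uniform(size=2)))
    action = int(np.argmax(logits + gumbel))
    return action, logits[action]

budget, log_q, subset = 3, 0.0, []          # max_predictions_per_seq analogue (assumed)
for j, p in enumerate(rng.uniform(0.1, 0.5, size=10)):
    if budget == 0:                         # no selections left, stop early
        break
    action, log_a = sample_position(p)
    if action == 1:
        subset.append(j)
        budget -= 1
    log_q += log_a                          # accumulated proposal log-probability

print(subset, log_q)
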
예제 #23
0
    def _calculate_partition_table(self, input_mask, action_prob,
                                   max_predictions_per_seq):
        shape = modeling.get_shape_list(action_prob, expected_rank=2)
        seq_len = shape[1]

        with tf.variable_scope("teacher/dp"):
            '''
            Calculate the DP table: the goal is logZ[0, K].
            We add an extra row so that computing log_q_yes never indexes out of bounds:
              Z[b, N+1, k] = log 0, i.e. nothing past the end may be chosen.
            logZ has shape batch_size x (N+1) x (K+1).
            '''
            initZ = tf.TensorArray(tf.float32,
                                   size=max_predictions_per_seq + 1)
            logZ_0 = tf.zeros_like(input_mask, dtype=tf.float32)  # size b x N
            logZ_0 = tf.pad(logZ_0, [[0, 0], [0, 1]],
                            "CONSTANT")  # size b x N+1
            initZ = initZ.write(tf.constant(0), logZ_0)

            # mask logp
            action_prob = tf.cast(input_mask, dtype=tf.float32) * action_prob
            action_prob = tf.clip_by_value(action_prob, 1e-20, 1.0)
            logp = tf.log(action_prob)
            logp_no = tf.log(1 - action_prob)
            accum_logp = tf.cumsum(logp, axis=1, reverse=True)

            def accum_cond(j, logZ_j, logb, loga):
                return tf.greater(j, -1)

            def accum_body(j, logZ_j, logb, loga):
                # logb: log_yes = logp[j] + logZ[j+1, k-1] -- already compute
                # loga: log_no = log(1-p[j]) + logZ[j+1, k]
                logb_j = tf.squeeze(logb[:, j])
                log_one_minus_p_j = tf.squeeze(logp_no[:, j])
                loga = loga + log_one_minus_p_j
                next_logZ_j = tf.math.reduce_logsumexp(
                    tf.stack([loga, logb_j]), 0)
                logZ_j = logZ_j.write(j, next_logZ_j)
                return [tf.subtract(j, 1), logZ_j, logb, next_logZ_j]

            def dp_loop_cond(k, logZ, lastZ):
                return tf.less(k, max_predictions_per_seq + 1)

            def dp_body(k, logZ, lastZ):
                '''
                case j < N-k+1:
                  logZ[j,k] = logsumexp(log(1-pi(j)) + logZ[j+1,k], log(pi(j)) + logZ[j+1,k-1])
                case j == N-k+1:
                  logZ[j,k] = accum_logp[j]
                case j > N-k+1:
                  logZ[j,k] = 0
                '''

                # shift lastZ one step
                shifted_lastZ = tf.roll(lastZ[:, :-1], shift=1,
                                        axis=1)  #logZ[j+1,k-1]
                log_yes = logp + shifted_lastZ  # b x N
                logZ_j = tf.TensorArray(tf.float32, size=seq_len + 1)
                init_value = accum_logp[:, seq_len - k]
                logZ_j = logZ_j.write(seq_len - k, init_value)
                _, logZ_j, logb, loga = tf.while_loop(
                    accum_cond, accum_body,
                    [seq_len - k - 1, logZ_j, log_yes, init_value])
                logZ_j = logZ_j.stack()  # N x b
                logZ_j = tf.transpose(logZ_j, [1, 0])  # b x N
                logZ = logZ.write(k, logZ_j)
                return [tf.add(k, 1), logZ, logZ_j]

            k = tf.constant(1)
            _, logZ, lastZ = tf.while_loop(dp_loop_cond,
                                           dp_body, [k, initZ, logZ_0],
                                           shape_invariants=[
                                               k.get_shape(),
                                               tf.TensorShape([]),
                                               tf.TensorShape([None, None])
                                           ])
            logZ = logZ.stack()  # N x b x N
            logZ = tf.transpose(logZ, [1, 2, 0])
        return logZ, logp
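
For intuition, the same partition-function recursion in plain probability space on a toy example: Z[j][k] is the total probability of selecting exactly k positions from the suffix j..N-1 under independent per-position probabilities p[j], with Z[N][0] = 1 playing the role of the extra padded row. The graph code above computes this table in log space (logsumexp) for numerical stability; the probabilities here are made up and checked against brute-force enumeration.

from itertools import combinations
from math import prod

def partition_table(p, K):
    """Z[j][k] = prob. of selecting exactly k of the positions j..N-1."""
    N = len(p)
    Z = [[0.0] * (K + 1) for _ in range(N + 1)]
    Z[N][0] = 1.0                                   # nothing left to select
    for j in range(N - 1, -1, -1):
        for k in range(K + 1):
            skip = (1.0 - p[j]) * Z[j + 1][k]       # log(1-p_j) + logZ[j+1,k] in log space
            take = p[j] * Z[j + 1][k - 1] if k > 0 else 0.0   # logp_j + logZ[j+1,k-1]
            Z[j][k] = skip + take
    return Z

p = [0.2, 0.5, 0.3, 0.4]                            # toy per-position probabilities (assumed)
K = 2
Z = partition_table(p, K)

# brute-force check: enumerate all K-subsets
brute = sum(
    prod(p[j] if j in s else 1.0 - p[j] for j in range(len(p)))
    for s in combinations(range(len(p)), K)
)
assert abs(Z[0][K] - brute) < 1e-12
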
예제 #24
0
  def _get_richer_data(self, fake_data):
    inputs_tf = fake_data.inputs.input_ids
    labels_tf = fake_data.is_fake_tokens
    lens_tf = tf.reduce_sum(fake_data.inputs.input_mask, 1)
    #retrieve the basic config
    V = self._bert_config.vocab_size
    #sub: 10%, del + ins: 5%
    N = int(self._config.max_predictions_per_seq * self._config.rich_prob)
    B, L = modeling.get_shape_list(inputs_tf)
    nlms = 0
    bilm = None
    if self._config.use_bilm:
      with open(self._config.bilm_file, 'rb') as f:
        bilm = tf.constant(np.load(f), tf.int32)
      _, nlms = modeling.get_shape_list(bilm)
    #make multiple partitions for edit op
    splits_list = []
    for i in range(B):
      one = tf.random.uniform([N * 4], 1, lens_tf[i], tf.int32)
      one, _ = tf.unique(one)
      one = tf.cond(tf.less(tf.shape(one)[0], N * 2 + 1),
                    lambda: tf.expand_dims(tf.range(1, N * 2 + 2), 0),
                    lambda: tf.sort(tf.reshape(one[: N * 2 + 1], [1, N * 2 + 1])))
      splits_list.append(one[:, 2::2])
    splits_tf = tf.concat(splits_list, 0)
    splits_up = tf.concat([splits_tf, tf.expand_dims(tf.constant([L] * B, tf.int32), 1)], 1)
    splits_lo = tf.concat([tf.expand_dims(tf.constant([0] * B, tf.int32), 1), splits_tf], 1)
    size_splits = splits_up - splits_lo
    # update the inputs and labels by applying random insertions and deletions
    new_labels_list = []
    new_inputs_list = []
    for i in range(B):
      inputs_splits = tf.split(inputs_tf[i, :], size_splits[i, :])
      labels_splits = tf.split(labels_tf[i, :], size_splits[i, :])
      one_inputs = []
      one_labels = []
      size_split = len(inputs_splits)
      inputs_end = inputs_splits[-1]
      labels_end = labels_splits[-1]
      for j in range(size_split-1):
        inputs = inputs_splits[j]
        labels = labels_splits[j]  # label 1 for substitution
        rand_op = random.randint(2, self._config.num_preds - 1) 
        if rand_op == 2: #label 2 for insertion
          if bilm is None: #noise
            insert_tok = tf.random.uniform([1], 1, V, tf.int32)
          else: #2-gram prediction
            insert_tok = tf.expand_dims(bilm[inputs[-1], random.randint(0, nlms-1)], 0)
          is_end_valid = tf.less_equal(2, tf.shape(inputs_end)[0])
          inputs = tf.cond(is_end_valid, lambda: tf.concat([inputs, insert_tok], 0), lambda: inputs)
          labels = tf.cond(is_end_valid, lambda: tf.concat([labels, tf.constant([2])], 0), lambda: labels)
          inputs_end = tf.cond(is_end_valid, lambda: inputs_end[:-1], lambda: inputs_end)
          labels_end = tf.cond(is_end_valid, lambda: labels_end[:-1], lambda: labels_end)
        elif rand_op == 3: #label 3 for deletion
          labels = tf.concat([labels[:-2], tf.constant([3])], 0)
          inputs = inputs[:-1]
          inputs_end = tf.concat([inputs_end, tf.constant([0])], 0)
          labels_end = tf.concat([labels_end, tf.constant([0])], 0)
        elif rand_op == 4: #label 4 for swapping
          labels = tf.concat([labels[:-1], tf.constant([4])], 0)
          inputs = tf.concat([inputs[:-2], [inputs[-1]], [inputs[-2]]], 0)
        one_labels.append(labels)
        one_inputs.append(inputs)
      one_inputs.append(inputs_end)
      one_labels.append(labels_end)
      one_inputs_tf = tf.concat(one_inputs, 0)
      one_labels_tf = tf.concat(one_labels, 0)
      one_inputs_tf = tf.cond(tf.less(lens_tf[i], N * 2 + 1), lambda: inputs_tf[i, :], lambda: one_inputs_tf)
      one_labels_tf = tf.cond(tf.less(lens_tf[i], N * 2 + 1), lambda: labels_tf[i, :], lambda: one_labels_tf)
      new_inputs_list.append(tf.expand_dims(one_inputs_tf, 0))
      new_labels_list.append(tf.expand_dims(one_labels_tf, 0))

    new_inputs_tf = tf.concat(new_inputs_list, 0)
    new_labels_tf = tf.concat(new_labels_list, 0)
    new_input_mask = tf.cast(tf.not_equal(new_inputs_tf, 0), tf.int32)
    updated_inputs = pretrain_data.get_updated_inputs(
        fake_data.inputs, input_ids=new_inputs_tf, input_mask=new_input_mask)
    RicherData = collections.namedtuple("RicherData", [
        "inputs", "is_fake_tokens", "sampled_tokens"])
    return RicherData(inputs=updated_inputs, is_fake_tokens=new_labels_tf,
                     sampled_tokens=fake_data.sampled_tokens)
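
A simplified pure-Python sketch of the edit operations above, using the same label scheme (0 = original, 1 = substitution, 2 = insertion, 3 = deletion, 4 = swap) on a single split segment. Total length is kept fixed by trading tokens with a trailing end buffer, mirroring the inputs_end/labels_end bookkeeping; the token values and vocab size are made up, and the 2-gram bilm insertion path is replaced by a uniform random token.

import random

def apply_edit(tokens, labels, end_tokens, end_labels, op, vocab_size=100):
    tokens, labels = list(tokens), list(labels)
    end_tokens, end_labels = list(end_tokens), list(end_labels)
    if op == 2 and len(end_tokens) >= 2:            # insertion: append a random token,
        tokens.append(random.randrange(1, vocab_size))
        labels.append(2)
        end_tokens.pop(); end_labels.pop()          # ...and shrink the end buffer by one
    elif op == 3:                                   # deletion: drop the last token,
        labels = labels[:-2] + [3]                  # mark the new last position with 3
        tokens = tokens[:-1]
        end_tokens.append(0); end_labels.append(0)  # ...and pad the end buffer with zeros
    elif op == 4:                                   # swap: exchange the last two tokens
        labels = labels[:-1] + [4]
        tokens = tokens[:-2] + [tokens[-1], tokens[-2]]
    return tokens, labels, end_tokens, end_labels

seg_tokens, seg_labels = [11, 12, 13], [0, 0, 0]
end_tokens, end_labels = [14, 15, 16], [0, 0, 0]
print(apply_edit(seg_tokens, seg_labels, end_tokens, end_labels, op=random.randint(2, 4)))
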
예제 #25
0
    def __init__(self, config: configure_pretraining.PretrainingConfig,
                 features, is_training):
        # Set up model config
        self._config = config
        self._bert_config = training_utils.get_bert_config(config)
        self._teacher_config = training_utils.get_teacher_config(config)

        embedding_size = (self._bert_config.hidden_size
                          if config.embedding_size is None else
                          config.embedding_size)

        tokenizer = tokenization.FullTokenizer(
            config.vocab_file, do_lower_case=config.do_lower_case)
        self._vocab = tokenizer.vocab
        self._inv_vocab = tokenizer.inv_vocab

        # Mask the input
        inputs = pretrain_data.features_to_inputs(features)
        old_model = self._build_transformer(inputs,
                                            is_training,
                                            embedding_size=embedding_size)
        input_states = old_model.get_sequence_output()
        input_states = tf.stop_gradient(input_states)

        teacher_output = self._build_teacher(input_states,
                                             inputs,
                                             is_training,
                                             embedding_size=embedding_size)
        # calculate the proposal distribution
        action_prob = teacher_output.action_probs  # pi(x_i)

        coin_toss = tf.random.uniform([])
        log_q, masked_inputs = self._sample_masking_subset(inputs, action_prob)
        if config.masking_strategy == pretrain_helpers.MIX_ADV_STRATEGY:
            random_masked_input = pretrain_helpers.mask(
                config, pretrain_data.features_to_inputs(features),
                config.mask_prob)
            B, L = modeling.get_shape_list(inputs.input_ids)
            N = config.max_predictions_per_seq
            strategy_prob = tf.random.uniform([B])
            strategy_prob = tf.expand_dims(
                tf.cast(tf.greater(strategy_prob, 0.5), tf.int32), 1)
            l_strategy_prob = tf.tile(strategy_prob, [1, L])
            n_strategy_prob = tf.tile(strategy_prob, [1, N])
            mix_input_ids = masked_inputs.input_ids * l_strategy_prob + random_masked_input.input_ids * (
                1 - l_strategy_prob)
            mix_masked_lm_positions = masked_inputs.masked_lm_positions * n_strategy_prob + random_masked_input.masked_lm_positions * (
                1 - n_strategy_prob)
            mix_masked_lm_ids = masked_inputs.masked_lm_ids * n_strategy_prob + random_masked_input.masked_lm_ids * (
                1 - n_strategy_prob)
            n_strategy_prob = tf.cast(n_strategy_prob, tf.float32)
            mix_masked_lm_weights = masked_inputs.masked_lm_weights * n_strategy_prob + random_masked_input.masked_lm_weights * (
                1 - n_strategy_prob)
            mix_masked_inputs = pretrain_data.get_updated_inputs(
                inputs,
                input_ids=tf.stop_gradient(mix_input_ids),
                masked_lm_positions=mix_masked_lm_positions,
                masked_lm_ids=mix_masked_lm_ids,
                masked_lm_weights=mix_masked_lm_weights,
                tag_ids=inputs.tag_ids)
            masked_inputs = mix_masked_inputs

        # BERT model
        model = self._build_transformer(masked_inputs,
                                        is_training,
                                        reuse=tf.AUTO_REUSE,
                                        embedding_size=embedding_size)
        mlm_output = self._get_masked_lm_output(masked_inputs, model)
        self.total_loss = mlm_output.loss

        # Teacher reward is the -log p(x_S|x;B)
        reward = tf.stop_gradient(
            tf.reduce_mean(mlm_output.per_example_loss, 1))
        self._baseline = tf.reduce_mean(reward, -1)
        self._std = tf.math.reduce_std(reward, -1)

        # Calculate teacher loss
        def compute_teacher_loss(log_q, reward, baseline, std):
            advantage = tf.abs((reward - baseline) / std)
            advantage = tf.stop_gradient(advantage)
            log_q = tf.Print(log_q, [log_q], "log_q: ")
            teacher_loss = tf.reduce_mean(-log_q * advantage)
            return teacher_loss

        teacher_loss = tf.cond(
            coin_toss < 0.1, lambda: compute_teacher_loss(
                log_q, reward, self._baseline, self._std),
            lambda: tf.constant(0.0))
        self.total_loss = mlm_output.loss + teacher_loss
        self.teacher_loss = teacher_loss
        self.mlm_loss = mlm_output.loss

        # Evaluation
        eval_fn_inputs = {
            "input_ids": masked_inputs.input_ids,
            "masked_lm_preds": mlm_output.preds,
            "mlm_loss": mlm_output.per_example_loss,
            "masked_lm_ids": masked_inputs.masked_lm_ids,
            "masked_lm_weights": masked_inputs.masked_lm_weights,
            "input_mask": masked_inputs.input_mask
        }
        eval_fn_keys = eval_fn_inputs.keys()
        eval_fn_values = [eval_fn_inputs[k] for k in eval_fn_keys]
        """Computes the loss and accuracy of the model."""
        d = {k: arg for k, arg in zip(eval_fn_keys, eval_fn_values)}
        metrics = dict()
        metrics["masked_lm_accuracy"] = tf.metrics.accuracy(
            labels=tf.reshape(d["masked_lm_ids"], [-1]),
            predictions=tf.reshape(d["masked_lm_preds"], [-1]),
            weights=tf.reshape(d["masked_lm_weights"], [-1]))
        metrics["masked_lm_loss"] = tf.metrics.mean(
            values=tf.reshape(d["mlm_loss"], [-1]),
            weights=tf.reshape(d["masked_lm_weights"], [-1]))
        self.eval_metrics = metrics
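
The teacher update above is a REINFORCE-style estimator: the reward is the detached per-example MLM loss of the student, the baseline and standard deviation are batch statistics, and the teacher loss is -log_q times the normalized advantage (applied only when the coin toss falls below 0.1). A compact numpy sketch of that computation with made-up batch values; the tf.cond gating and tf.Print debugging are omitted, and the tf.abs on the advantage is kept as in the code above.

import numpy as np

rng = np.random.default_rng(1)

per_example_mlm_loss = rng.uniform(1.0, 5.0, size=(8, 20))   # assumed [batch, masked positions]
log_q = rng.normal(-3.0, 0.5, size=8)                        # log-prob of each sampled mask set

reward = per_example_mlm_loss.mean(axis=1)       # one reward per sequence (stop_gradient in the graph)
baseline = reward.mean()                         # self._baseline
std = reward.std()                               # self._std

advantage = np.abs((reward - baseline) / std)    # normalized, detached advantage
teacher_loss = np.mean(-log_q * advantage)       # REINFORCE-style objective
total_loss = per_example_mlm_loss.mean() + teacher_loss
print(teacher_loss, total_loss)
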