Ejemplo n.º 1
0
 def _get_fake_data(self, inputs, mlm_logits):
     """Build a corrupted copy of the inputs by sampling replacement
     tokens for the masked positions from the generator's MLM logits.
     """
     inputs = pretrain_helpers.unmask(inputs)
     # Optionally forbid the sampler from re-picking the original token
     # at each masked position.
     if self._config.disallow_correct:
         disallow = tf.one_hot(inputs.masked_lm_ids,
                               depth=self._bert_config.vocab_size,
                               dtype=tf.float32)
     else:
         disallow = None
     # Temperature-scaled sampling; no gradients flow to the generator
     # through the sampled tokens.
     scaled_logits = mlm_logits / self._config.temperature
     sampled_tokens = tf.stop_gradient(
         pretrain_helpers.sample_from_softmax(scaled_logits,
                                              disallow=disallow))
     sampled_tokids = tf.argmax(sampled_tokens, -1, output_type=tf.int32)
     updated_input_ids, masked = pretrain_helpers.scatter_update(
         inputs.input_ids, sampled_tokids, inputs.masked_lm_positions)
     if self._config.electric_objective:
         labels = masked
     else:
         # A position counts as fake only if it was masked AND the
         # sampled token differs from the original one.
         unchanged = tf.cast(tf.equal(updated_input_ids, inputs.input_ids),
                             tf.int32)
         labels = masked * (1 - unchanged)
     updated_inputs = pretrain_data.get_updated_inputs(
         inputs, input_ids=updated_input_ids)
     FakedData = collections.namedtuple(
         "FakedData", ["inputs", "is_fake_tokens", "sampled_tokens"])
     return FakedData(inputs=updated_inputs,
                      is_fake_tokens=labels,
                      sampled_tokens=sampled_tokens)
Ejemplo n.º 2
0
def _get_fake_data(inputs, mlm_logits, temperature=1.0):
    """Sample from the generator to create corrupted input.

    Args:
      inputs: pretrain_data features; must expose `input_ids`,
        `masked_lm_positions` and `masked_lm_weights`.
      mlm_logits: generator MLM logits, shape [batch_size, n_pos, n_vocab].
      temperature: softmax temperature for sampling. The default of 1.0
        reproduces the previous hard-coded behavior.

    Returns:
      A FakedData namedtuple with the corrupted inputs, per-position fake
      labels, the one-hot sampled tokens, and the per-example pseudo
      log-probability of the sampled sequence.
    """
    masked_lm_weights = inputs.masked_lm_weights
    inputs = pretrain_helpers.unmask(inputs)
    # Sample replacement tokens; stop_gradient keeps the discriminator's
    # loss from backpropagating into the generator through the samples.
    sampled_tokens = tf.stop_gradient(
        pretrain_helpers.sample_from_softmax(mlm_logits / temperature,
                                             disallow=None))
    # sampled_tokens: one-hot, [batch_size, n_pos, n_vocab]
    sampled_tokens_fp32 = tf.cast(sampled_tokens, dtype=tf.float32)
    # Generator log-probability of each sampled token, restricted to the
    # positions that were actually predicted (masked_lm_weights).
    # mlm_logprobs: [batch_size, n_pos, n_vocab]
    mlm_logprobs = tf.nn.log_softmax(mlm_logits, axis=-1)
    # [batch_size, n_pos]
    pseudo_logprob = tf.reduce_sum(mlm_logprobs * sampled_tokens_fp32, axis=-1)
    pseudo_logprob *= tf.cast(masked_lm_weights, dtype=tf.float32)
    # [batch_size]: sum over positions gives the sequence pseudo log-prob.
    pseudo_logprob = tf.reduce_sum(pseudo_logprob, axis=-1)
    sampled_tokids = tf.argmax(sampled_tokens, -1, output_type=tf.int32)
    updated_input_ids, masked = pretrain_helpers.scatter_update(
        inputs.input_ids, sampled_tokids, inputs.masked_lm_positions)
    # A position is labeled fake only if it was masked AND the sampled
    # token differs from the original token.
    labels = masked * (
        1 - tf.cast(tf.equal(updated_input_ids, inputs.input_ids), tf.int32))
    updated_inputs = pretrain_data.get_updated_inputs(
        inputs, input_ids=updated_input_ids)
    FakedData = collections.namedtuple(
        "FakedData",
        ["inputs", "is_fake_tokens", "sampled_tokens", "pseudo_logprob"])
    return FakedData(inputs=updated_inputs,
                     is_fake_tokens=labels,
                     sampled_tokens=sampled_tokens,
                     pseudo_logprob=pseudo_logprob)