def _get_fake_data(self, inputs, mlm_logits):
  """Sample from the generator to create corrupted input."""
  inputs = pretrain_helpers.unmask(inputs)
  # Optionally forbid the generator from sampling the original (correct)
  # token at each masked position.
  disallow = tf.one_hot(
      inputs.masked_lm_ids, depth=self._bert_config.vocab_size,
      dtype=tf.float32) if self._config.disallow_correct else None
  # Sample replacement tokens from the temperature-scaled generator softmax;
  # stop_gradient keeps the sampling step out of backprop.
  sampled_tokens = tf.stop_gradient(pretrain_helpers.sample_from_softmax(
      mlm_logits / self._config.temperature, disallow=disallow))
  sampled_tokids = tf.argmax(sampled_tokens, -1, output_type=tf.int32)
  # Write the sampled tokens back into the input at the masked positions.
  updated_input_ids, masked = pretrain_helpers.scatter_update(
      inputs.input_ids, sampled_tokids, inputs.masked_lm_positions)
  if self._config.electric_objective:
    labels = masked
  else:
    # A position only counts as "fake" if the sampled token actually
    # differs from the original token.
    labels = masked * (1 - tf.cast(
        tf.equal(updated_input_ids, inputs.input_ids), tf.int32))
  updated_inputs = pretrain_data.get_updated_inputs(
      inputs, input_ids=updated_input_ids)
  FakedData = collections.namedtuple("FakedData", [
      "inputs", "is_fake_tokens", "sampled_tokens"])
  return FakedData(inputs=updated_inputs, is_fake_tokens=labels,
                   sampled_tokens=sampled_tokens)
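
# For reference, a minimal sketch of what the sampling helper used above
# could look like. This is an assumption, not the repository's actual
# pretrain_helpers.sample_from_softmax: it illustrates the standard
# Gumbel-max trick for drawing one-hot samples from a softmax, with
# disallowed tokens suppressed via a large negative logit offset.
def _sample_from_softmax_sketch(logits, disallow=None):
  """Hypothetical stand-in: one-hot samples via the Gumbel-max trick."""
  if disallow is not None:
    # Push disallowed tokens' logits far below everything else.
    logits -= 1000.0 * disallow
  uniform = tf.random.uniform(tf.shape(logits), minval=0, maxval=1)
  gumbel = -tf.math.log(-tf.math.log(uniform + 1e-9) + 1e-9)
  # argmax(logits + Gumbel noise) is an exact sample from softmax(logits).
  return tf.one_hot(
      tf.argmax(logits + gumbel, axis=-1, output_type=tf.int32),
      depth=tf.shape(logits)[-1], dtype=logits.dtype)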
def _get_fake_data(inputs, mlm_logits):
  """Sample from the generator to create corrupted input."""
  masked_lm_weights = inputs.masked_lm_weights
  inputs = pretrain_helpers.unmask(inputs)
  disallow = None
  # Sample replacement tokens from the generator softmax (temperature fixed
  # at 1.0); stop_gradient keeps the sampling step out of backprop.
  sampled_tokens = tf.stop_gradient(pretrain_helpers.sample_from_softmax(
      mlm_logits / 1.0, disallow=disallow))
  # sampled_tokens: [batch_size, n_pos, n_vocab] (one-hot)
  # mlm_logits: [batch_size, n_pos, n_vocab]
  sampled_tokens_fp32 = tf.cast(sampled_tokens, dtype=tf.float32)
  # mlm_logprobs: [batch_size, n_pos, n_vocab]
  mlm_logprobs = tf.nn.log_softmax(mlm_logits, axis=-1)
  # Log-probability of each sampled token: [batch_size, n_pos]
  pseudo_logprob = tf.reduce_sum(mlm_logprobs * sampled_tokens_fp32, axis=-1)
  # Zero out non-masked positions, then sum over positions: [batch_size]
  pseudo_logprob *= tf.cast(masked_lm_weights, dtype=tf.float32)
  pseudo_logprob = tf.reduce_sum(pseudo_logprob, axis=-1)
  # Optional normalization by the number of masked positions (left disabled):
  # pseudo_logprob /= (1e-10 + tf.reduce_sum(
  #     tf.cast(masked_lm_weights, dtype=tf.float32), axis=-1))
  sampled_tokids = tf.argmax(sampled_tokens, -1, output_type=tf.int32)
  updated_input_ids, masked = pretrain_helpers.scatter_update(
      inputs.input_ids, sampled_tokids, inputs.masked_lm_positions)
  # A position only counts as "fake" if the sampled token actually differs
  # from the original token.
  labels = masked * (
      1 - tf.cast(tf.equal(updated_input_ids, inputs.input_ids), tf.int32))
  updated_inputs = pretrain_data.get_updated_inputs(
      inputs, input_ids=updated_input_ids)
  FakedData = collections.namedtuple("FakedData", [
      "inputs", "is_fake_tokens", "sampled_tokens", "pseudo_logprob"])
  return FakedData(inputs=updated_inputs, is_fake_tokens=labels,
                   sampled_tokens=sampled_tokens, pseudo_logprob=pseudo_logprob)
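
# A minimal, self-contained sketch of the pseudo_logprob computation above,
# using dummy tensors (shapes and values are illustrative assumptions, and
# _pseudo_logprob_sketch is a hypothetical helper, not part of this repo).
# It shows that pseudo_logprob is the generator's total log-probability of
# the sampled tokens, accumulated over the masked positions only.
def _pseudo_logprob_sketch():
  batch_size, n_pos, n_vocab = 2, 3, 5
  mlm_logits = tf.random.normal([batch_size, n_pos, n_vocab])
  # Pretend these one-hot tokens were sampled at each position.
  sampled = tf.one_hot(
      tf.random.uniform([batch_size, n_pos], maxval=n_vocab, dtype=tf.int32),
      depth=n_vocab)
  # 1.0 where a position was actually masked, 0.0 elsewhere.
  masked_lm_weights = tf.constant([[1., 1., 0.], [1., 0., 0.]])
  logprobs = tf.nn.log_softmax(mlm_logits, axis=-1)
  # Select each sampled token's log-prob: [batch_size, n_pos]
  per_position = tf.reduce_sum(logprobs * sampled, axis=-1)
  # Sum over masked positions only: [batch_size]
  return tf.reduce_sum(per_position * masked_lm_weights, axis=-1)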