# Imports needed by this excerpt (module layout as in the ELECTRA codebase).
import collections

import tensorflow.compat.v1 as tf

import configure_pretraining
from model import modeling
from pretrain import pretrain_data
from pretrain import pretrain_helpers
from util import training_utils


def _get_fake_data(self, inputs, mlm_logits):
  """Sample from the generator to create corrupted input."""
  inputs = pretrain_helpers.unmask(inputs)
  # Optionally disallow sampling the original (correct) token for a position.
  disallow = tf.one_hot(
      inputs.masked_lm_ids, depth=self._bert_config.vocab_size,
      dtype=tf.float32) if self._config.disallow_correct else None
  sampled_tokens = tf.stop_gradient(pretrain_helpers.sample_from_softmax(
      mlm_logits / self._config.temperature, disallow=disallow))
  sampled_tokids = tf.argmax(sampled_tokens, -1, output_type=tf.int32)
  updated_input_ids, masked = pretrain_helpers.scatter_update(
      inputs.input_ids, sampled_tokids, inputs.masked_lm_positions)
  if self._config.electric_objective:
    labels = masked
  else:
    # A sampled token only counts as "fake" if it differs from the original.
    labels = masked * (1 - tf.cast(
        tf.equal(updated_input_ids, inputs.input_ids), tf.int32))
  updated_inputs = pretrain_data.get_updated_inputs(
      inputs, input_ids=updated_input_ids)
  FakedData = collections.namedtuple("FakedData", [
      "inputs", "is_fake_tokens", "sampled_tokens"])
  return FakedData(inputs=updated_inputs, is_fake_tokens=labels,
                   sampled_tokens=sampled_tokens)
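# A minimal standalone sketch of the sampling step above, assuming a
# Gumbel-max sampler; the real implementation is
# pretrain_helpers.sample_from_softmax and may differ in details.
# `disallow` is a one-hot [batch_size, n_pos, n_vocab] tensor marking
# tokens whose probability should be pushed to (effectively) zero.
def _sample_from_softmax_sketch(logits, disallow=None):
  if disallow is not None:
    logits -= 1000.0 * disallow  # large negative logit ~ zero probability
  uniform_noise = tf.random.uniform(tf.shape(logits), minval=0, maxval=1)
  gumbel_noise = -tf.log(-tf.log(uniform_noise + 1e-9) + 1e-9)
  # The argmax over (logits + Gumbel noise) is a draw from softmax(logits).
  sampled_ids = tf.argmax(logits + gumbel_noise, -1, output_type=tf.int32)
  return tf.one_hot(sampled_ids, depth=tf.shape(logits)[-1])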
def _get_fake_data(inputs, mlm_logits):
  """Sample from the generator to create corrupted input.

  Variant of the method above that also returns the generator's pseudo
  log-probability of the sampled tokens for each example.
  """
  masked_lm_weights = inputs.masked_lm_weights
  inputs = pretrain_helpers.unmask(inputs)
  disallow = None
  # Temperature is fixed to 1.0 in this variant.
  sampled_tokens = tf.stop_gradient(pretrain_helpers.sample_from_softmax(
      mlm_logits / 1.0, disallow=disallow))
  # sampled_tokens, mlm_logits: [batch_size, n_pos, n_vocab]
  sampled_tokens_fp32 = tf.cast(sampled_tokens, dtype=tf.float32)
  # mlm_logprobs: [batch_size, n_pos, n_vocab]
  mlm_logprobs = tf.nn.log_softmax(mlm_logits, axis=-1)
  # Log-probability of the sampled token at each position: [batch_size, n_pos]
  pseudo_logprob = tf.reduce_sum(mlm_logprobs * sampled_tokens_fp32, axis=-1)
  # Zero out unmasked positions, then sum over positions: [batch_size]
  pseudo_logprob *= tf.cast(masked_lm_weights, dtype=tf.float32)
  pseudo_logprob = tf.reduce_sum(pseudo_logprob, axis=-1)
  # Optionally normalize by the number of masked positions:
  # pseudo_logprob /= (1e-10 + tf.reduce_sum(
  #     tf.cast(masked_lm_weights, dtype=tf.float32), axis=-1))
  sampled_tokids = tf.argmax(sampled_tokens, -1, output_type=tf.int32)
  updated_input_ids, masked = pretrain_helpers.scatter_update(
      inputs.input_ids, sampled_tokids, inputs.masked_lm_positions)
  labels = masked * (1 - tf.cast(
      tf.equal(updated_input_ids, inputs.input_ids), tf.int32))
  updated_inputs = pretrain_data.get_updated_inputs(
      inputs, input_ids=updated_input_ids)
  FakedData = collections.namedtuple("FakedData", [
      "inputs", "is_fake_tokens", "sampled_tokens", "pseudo_logprob"])
  return FakedData(inputs=updated_inputs, is_fake_tokens=labels,
                   sampled_tokens=sampled_tokens,
                   pseudo_logprob=pseudo_logprob)
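# Illustration only (not part of the model): a NumPy check of the
# pseudo_logprob reduction above, with made-up numbers. With one-hot sampled
# tokens, summing log_softmax(logits) * one_hot over the vocab axis recovers
# the log-probability of the sampled token at each position, and multiplying
# by masked_lm_weights zeroes out the unmasked positions.
def _pseudo_logprob_numpy_check():
  import numpy as np
  logits = np.array([[[2.0, 0.0, 0.0], [0.0, 1.0, 0.0]]])   # [1, 2, 3]
  logprobs = logits - np.log(np.exp(logits).sum(-1, keepdims=True))
  one_hot = np.array([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]])  # sampled ids 0, 1
  weights = np.array([[1.0, 0.0]])  # only the first position was masked
  per_position = (logprobs * one_hot).sum(-1)     # [1, 2]
  per_example = (per_position * weights).sum(-1)  # [1]
  assert np.isclose(per_example[0], logprobs[0, 0, 0])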
def __init__(self, config: configure_pretraining.PretrainingConfig,
             features, is_training):
  # Set up model config
  self._config = config
  self._bert_config = training_utils.get_bert_config(config)
  if config.debug:
    self._bert_config.num_hidden_layers = 3
    self._bert_config.hidden_size = 144
    self._bert_config.intermediate_size = 144 * 4
    self._bert_config.num_attention_heads = 4

  # Mask the input. The features are unmasked first so that masking is
  # always applied from scratch to the raw input ids.
  masked_inputs = pretrain_helpers.mask(
      config,
      pretrain_helpers.unmask(pretrain_data.features_to_inputs(features)),
      config.mask_prob)

  # Generator
  embedding_size = (
      self._bert_config.hidden_size if config.embedding_size is None
      else config.embedding_size)
  if config.uniform_generator:
    mlm_output = self._get_masked_lm_output(masked_inputs, None)
  elif config.electra_objective and config.untied_generator:
    generator = self._build_transformer(
        masked_inputs, is_training,
        bert_config=get_generator_config(config, self._bert_config),
        embedding_size=(None if config.untied_generator_embeddings
                        else embedding_size),
        untied_embeddings=config.untied_generator_embeddings,
        name="generator")
    mlm_output = self._get_masked_lm_output(masked_inputs, generator)
  else:
    generator = self._build_transformer(
        masked_inputs, is_training, embedding_size=embedding_size)
    mlm_output = self._get_masked_lm_output(masked_inputs, generator)
  fake_data = self._get_fake_data(masked_inputs, mlm_output.logits)
  self.mlm_output = mlm_output
  self.total_loss = config.gen_weight * mlm_output.loss

  # Discriminator
  disc_output = None
  if config.electra_objective:
    discriminator = self._build_transformer(
        fake_data.inputs, is_training, reuse=not config.untied_generator,
        embedding_size=embedding_size)
    disc_output = self._get_discriminator_output(
        fake_data.inputs, discriminator, fake_data.is_fake_tokens)
    self.total_loss += config.disc_weight * disc_output.loss

  # Optionally initialize from a checkpoint
  tvars = tf.trainable_variables()
  initialized_variable_names = {}
  self.scaffold_fn = None
  if config.init_checkpoint:
    (assignment_map, initialized_variable_names
     ) = modeling.get_assignment_map_from_checkpoint(
         tvars, config.init_checkpoint)
    if config.use_tpu:
      def tpu_scaffold():
        tf.train.init_from_checkpoint(config.init_checkpoint,
                                      assignment_map)
        return tf.train.Scaffold()
      self.scaffold_fn = tpu_scaffold
    else:
      tf.train.init_from_checkpoint(config.init_checkpoint, assignment_map)

  # Evaluation
  eval_fn_inputs = {
      "input_ids": masked_inputs.input_ids,
      "masked_lm_preds": mlm_output.preds,
      "mlm_loss": mlm_output.per_example_loss,
      "masked_lm_ids": masked_inputs.masked_lm_ids,
      "masked_lm_weights": masked_inputs.masked_lm_weights,
      "input_mask": masked_inputs.input_mask
  }
  if config.electra_objective:
    eval_fn_inputs.update({
        "disc_loss": disc_output.per_example_loss,
        "disc_labels": disc_output.labels,
        "disc_probs": disc_output.probs,
        "disc_preds": disc_output.preds,
        "sampled_tokids": tf.argmax(fake_data.sampled_tokens, -1,
                                    output_type=tf.int32)
    })
  eval_fn_keys = eval_fn_inputs.keys()
  eval_fn_values = [eval_fn_inputs[k] for k in eval_fn_keys]

  def metric_fn(*args):
    """Computes the loss and accuracy of the model."""
    d = {k: arg for k, arg in zip(eval_fn_keys, args)}
    metrics = dict()
    metrics["masked_lm_accuracy"] = tf.metrics.accuracy(
        labels=tf.reshape(d["masked_lm_ids"], [-1]),
        predictions=tf.reshape(d["masked_lm_preds"], [-1]),
        weights=tf.reshape(d["masked_lm_weights"], [-1]))
    metrics["masked_lm_loss"] = tf.metrics.mean(
        values=tf.reshape(d["mlm_loss"], [-1]),
        weights=tf.reshape(d["masked_lm_weights"], [-1]))
    if config.electra_objective:
      metrics["sampled_masked_lm_accuracy"] = tf.metrics.accuracy(
          labels=tf.reshape(d["masked_lm_ids"], [-1]),
          predictions=tf.reshape(d["sampled_tokids"], [-1]),
          weights=tf.reshape(d["masked_lm_weights"], [-1]))
      if config.disc_weight > 0:
        metrics["disc_loss"] = tf.metrics.mean(d["disc_loss"])
        metrics["disc_auc"] = tf.metrics.auc(
            d["disc_labels"] * d["input_mask"],
            d["disc_probs"] * tf.cast(d["input_mask"], tf.float32))
        metrics["disc_accuracy"] = tf.metrics.accuracy(
            labels=d["disc_labels"], predictions=d["disc_preds"],
            weights=d["input_mask"])
        # "Precision": accuracy restricted to positions predicted fake.
        metrics["disc_precision"] = tf.metrics.accuracy(
            labels=d["disc_labels"], predictions=d["disc_preds"],
            weights=d["disc_preds"] * d["input_mask"])
        # "Recall": accuracy restricted to positions that are actually fake.
        metrics["disc_recall"] = tf.metrics.accuracy(
            labels=d["disc_labels"], predictions=d["disc_preds"],
            weights=d["disc_labels"] * d["input_mask"])
    # metric_fn must return only the metrics dict; `scaffold_fn` is not
    # defined in this scope and belongs on the estimator spec instead.
    return metrics

  self.eval_metrics = (metric_fn, eval_fn_values)
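# A sketch of how the fields set above might be consumed, assuming the
# TF 1.15-style TPUEstimator API (earlier versions used tf.contrib.tpu).
# `model` stands for an instance of the class defined above; the repo's
# actual model_fn wiring may differ.
def _eval_spec_sketch(model, mode):
  # TPUEstimator calls metric_fn(*eval_fn_values) on the CPU host after
  # gathering the listed tensors from all TPU shards.
  return tf.estimator.tpu.TPUEstimatorSpec(
      mode=mode,
      loss=model.total_loss,
      eval_metrics=model.eval_metrics,
      scaffold_fn=model.scaffold_fn)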